/
FloatWeights.scala
64 lines (48 loc) · 1.69 KB
/
FloatWeights.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
package com.thoughtworks.deeplearning
package plugins
import com.thoughtworks.deeplearning.DeepLearning.Tape
import com.thoughtworks.feature.Factory.inject
import com.thoughtworks.feature.{Factory, ImplicitApply, PartialApply}
import com.thoughtworks.raii.asynchronous.Do
import com.thoughtworks.raii.asynchronous.Do._
import shapeless.Witness
import annotation.meta.getter
import scalaz.{-\/, \/-}
import scalaz.concurrent.Future
/**
  * A plugin that provides trainable weights whose `data` and `delta` are single `Float` values.
  *
  * The abstract types marked `@template` ([[FloatWeight]] and [[FloatOptimizer]]) are left open
  * here; presumably they are concretized when plugins are combined by the `Factory`-based
  * plugin machinery — confirm against the plugin composition code.
  *
  * @author 杨博 (Yang Bo)
  */
trait FloatWeights extends Weights {

  /** API of a weight whose `Data` and `Delta` are both `Float`, optimized by a [[FloatOptimizer]]. */
  trait FloatWeightApi extends WeightApi { this: FloatWeight =>
    override type Delta = Float
    override type Data = Float
    override protected type Optimizer = FloatOptimizer
  }

  /** @template */
  type FloatWeight <: FloatWeightApi with Weight

  /** Factory that creates new [[FloatWeight]] instances. */
  @inject
  protected val floatWeightFactory: Factory[FloatWeight]

  /** Partially applies the `data` parameter of `floatWeightFactory`'s constructor. */
  @inject
  protected val floatPartialApplyData: PartialApply[floatWeightFactory.Constructor, Witness.`"data"`.T]

  /** Evidence that a `Float` can be passed as the constructor's `data` parameter. */
  @inject
  protected def floatDataParameter: Float <:< floatPartialApplyData.Parameter

  /** Companion-style constructor for [[FloatWeight]] instances. */
  object FloatWeight {

    /** Creates a new weight whose initial value is `data`.
      *
      * The constructor's `data` parameter is supplied explicitly. All remaining constructor
      * parameters are filled in by `implicitApplyRest`.
      *
      * @param data the initial value of the weight
      * @tparam SubtypeOfWeight not referenced by this method; kept so call sites that pass
      *                         type arguments explicitly keep compiling
      * @tparam OptimizerFunction not referenced by this method; kept for the same reason
      * @tparam Optimizer not referenced by this method; kept for the same reason (note that it
      *                   shadows the enclosing `Optimizer` type member inside this method)
      * @return the newly created weight, typed as whatever `implicitApplyRest` produces
      */
    def apply[SubtypeOfWeight, OptimizerFunction, Optimizer](data: Float)(
        implicit implicitApplyRest: ImplicitApply[floatPartialApplyData.Rest]): implicitApplyRest.Out = {
      implicitApplyRest(floatPartialApplyData(floatWeightFactory.newInstance, floatDataParameter(data)))
    }
  }

  /** API of an optimizer that applies a `Float` delta to a [[FloatWeight]]. */
  trait FloatOptimizerApi extends OptimizerApi { this: FloatOptimizer =>
    override type Delta = Float
    override protected type Weight = FloatWeight

    /** Returns a `Do` that subtracts `delta` from `weight.data`.
      *
      * The subtraction is wrapped in `Do.delay`. Presumably this means it runs when the
      * returned `Do` is run rather than when `update()` is called — confirm against the raii
      * `Do.delay` docs.
      */
    override protected def update(): Do[Unit] = {
      Do.delay {
        // `-=` is a read-modify-write of `weight.data`. Holding the weight's monitor prevents
        // lost updates between writers that lock on the same weight.
        weight.synchronized {
          weight.data -= delta
        }
      }
    }
  }

  /** @template */
  type FloatOptimizer <: FloatOptimizerApi with Optimizer
}