// FloatWeights.scala
package com.thoughtworks.deeplearning
package plugins
import com.thoughtworks.deeplearning.DeepLearning.Tape
import com.thoughtworks.feature.Factory.inject
import com.thoughtworks.feature.{Factory, ImplicitApply, PartialApply}
import com.thoughtworks.raii.asynchronous.Do
import com.thoughtworks.raii.asynchronous.Do._
import shapeless.Witness
import annotation.meta.getter
import scalaz.{-\/, \/-}
import scalaz.concurrent.Future
/**
  * Plugin layered on top of [[Weights]] whose weights hold a single `Float` and whose
  * optimizers produce `Float` deltas.
  *
  * Weights and optimizers are never constructed directly here; they are created through the
  * [[com.thoughtworks.feature.Factory]] / [[com.thoughtworks.feature.PartialApply]] members
  * declared below, all of which are marked `@inject`.
  *
  * @author 杨博 (Yang Bo)
  */
trait FloatWeights extends Weights {

  /** Behaviour common to every [[FloatWeight]]: both `Data` and `Delta` are fixed to `Float`. */
  trait FloatWeightApi extends WeightApi { this: FloatWeight =>

    override type Delta = Float
    override type Data = Float

    /** The optimizer constructor after its `weight` and `originalDelta` parameters have been applied. */
    override protected type PartiallyAppliedOptimizer = floatPartialApplyOriginalDelta.Rest

    /**
      * Describes the update of this weight for the given gradient.
      *
      * When the returned [[com.thoughtworks.raii.asynchronous.Do]] is run, a new optimizer is built
      * for this weight and `originalDelta`, and the optimizer's `delta` is subtracted from `data`
      * while holding this weight's monitor.
      *
      * @param originalDelta     the value handed to the optimizer's `originalDelta` parameter
      * @param implicitApplyRest resolves the optimizer constructor parameters that remain after
      *                          `weight` and `originalDelta`
      * @param asOptimizer       evidence that the constructed value is an optimizer with a `Float` delta
      * @tparam SubtypeOfOptimizer the concrete type produced by `implicitApplyRest`
      */
    override protected def backward[SubtypeOfOptimizer](originalDelta: Float)(
        implicit implicitApplyRest: ImplicitApply.Aux[PartiallyAppliedOptimizer, SubtypeOfOptimizer],
        asOptimizer: SubtypeOfOptimizer <:< Optimizer.Aux[Delta]): Do[Unit] =
      Do.delay {
        // Everything below stays inside `Do.delay`, so the optimizer is only built when the `Do` runs.
        val constructorWithWeight =
          floatPartialApplyWeight(floatOptimizerFactory.newInstance, floatWeightParameter(this))
        val constructorWithOriginalDelta =
          floatPartialApplyOriginalDelta(constructorWithWeight, floatOriginalDeltaParameter(originalDelta))
        val optimizer = implicitApplyRest(constructorWithOriginalDelta)
        // `delta` is reached through the `asOptimizer` evidence, and is read before the monitor is taken.
        val step = optimizer.delta
        synchronized {
          data -= step
        }
      }

  }

  /**
    * The concrete weight type of this plugin.
    *
    * @template
    */
  type FloatWeight <: FloatWeightApi with Weight

  /** Factory whose `newInstance` is the constructor function for [[FloatWeight]]. */
  @inject
  protected val floatWeightFactory: Factory[FloatWeight]

  /** Applies the parameter named `data` of the [[FloatWeight]] constructor. */
  @inject
  protected val floatPartialApplyData: PartialApply[floatWeightFactory.Constructor, Witness.`"data"`.T]

  /** Evidence that a `Float` is acceptable as the `data` parameter. */
  @inject
  protected def floatDataParameter: Float <:< floatPartialApplyData.Parameter

  /** Entry point for creating [[FloatWeight]]s. */
  object FloatWeight {

    /**
      * Builds a weight from its initial value.
      *
      * The three type parameters are not referenced by the signature or the body; they are kept
      * so that the method's interface is unchanged.
      *
      * @param data              the initial value, passed as the constructor's `data` parameter
      * @param implicitApplyRest resolves whatever constructor parameters remain after `data`
      * @return the result of applying the remaining constructor parameters implicitly
      */
    def apply[SubtypeOfWeight, OptimizerFunction, Optimizer](data: Float)(
        implicit implicitApplyRest: ImplicitApply[floatPartialApplyData.Rest]) = {
      val constructorWithData = floatPartialApplyData(floatWeightFactory.newInstance, floatDataParameter(data))
      implicitApplyRest(constructorWithData)
    }

  }

  /** Behaviour common to every [[FloatOptimizer]]: the delta is a `Float`. */
  trait FloatOptimizerApi extends OptimizerApi { this: FloatOptimizer =>

    override type Delta = Float

    /** The weight this optimizer was created for. */
    val weight: FloatWeight

  }

  /**
    * The concrete optimizer type of this plugin.
    *
    * @template
    */
  type FloatOptimizer <: FloatOptimizerApi with Optimizer

  /** Factory whose `newInstance` is the constructor function for [[FloatOptimizer]]. */
  @inject
  protected val floatOptimizerFactory: Factory[FloatOptimizer]

  /** Applies the parameter named `weight` of the [[FloatOptimizer]] constructor. */
  @inject
  protected val floatPartialApplyWeight: PartialApply[floatOptimizerFactory.Constructor, Witness.`"weight"`.T]

  /** Evidence that a [[FloatWeight]] is acceptable as the `weight` parameter. */
  @inject
  protected def floatWeightParameter: FloatWeight <:< floatPartialApplyWeight.Parameter

  /** Applies the parameter named `originalDelta`, after `weight` has already been applied. */
  @inject
  protected val floatPartialApplyOriginalDelta: PartialApply[floatPartialApplyWeight.Rest, Witness.`"originalDelta"`.T]

  /** Evidence that a `Float` is acceptable as the `originalDelta` parameter. */
  @inject
  protected def floatOriginalDeltaParameter: Float <:< floatPartialApplyOriginalDelta.Parameter

}