-
Notifications
You must be signed in to change notification settings - Fork 87
/
FloatWeights.scala
90 lines (67 loc) · 2.96 KB
/
FloatWeights.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
package com.thoughtworks.deeplearning
package plugins
import com.thoughtworks.feature.Factory.inject
import com.thoughtworks.feature.{Factory, ImplicitApply, PartialApply}
import com.thoughtworks.raii.asynchronous.Do
import com.thoughtworks.raii.asynchronous.Do._
import shapeless.Witness
/** Plugin that contributes trainable weights holding a single [[scala.Float]] value.
  *
  * During back-propagation each [[FloatWeight]] builds a fresh [[FloatOptimizer]] for itself, asks it
  * for a `delta`, and subtracts that `delta` from its own `data`.
  *
  * @note To plug in a custom optimization algorithm for [[FloatWeight]], write a plugin that provides an
  *       overridden [[FloatOptimizer]], which in turn provides an overridden [[FloatOptimizer.delta]].
  *
  * @author 杨博 (Yang Bo)
  */
trait FloatWeights extends Weights {

  trait FloatWeightApi extends WeightApi { this: FloatWeight =>

    override type Delta = Float
    override type Data = Float

    // The optimizer constructor after both `weight` and `originalDelta` have been bound;
    // whatever constructor parameters remain are left for an `ImplicitApply` to fill.
    override protected type PartiallyAppliedOptimizer = floatPartialApplyOriginalDelta.Rest

    /** Creates an optimizer for this weight from `originalDelta` and subtracts the optimizer's
      * `delta` from `data`.
      *
      * The update is wrapped in `Do.delay`, so it is performed only when the returned `Do` runs.
      *
      * @param originalDelta the raw delta handed to the optimizer's `originalDelta` parameter
      * @param implicitApplyRest completes the optimizer constructor once `weight` and `originalDelta` are bound
      * @param asOptimizer evidence used to read `delta` from the constructed optimizer
      */
    override protected def backward[SubtypeOfOptimizer](originalDelta: Float)(
        implicit implicitApplyRest: ImplicitApply.Aux[PartiallyAppliedOptimizer, SubtypeOfOptimizer],
        asOptimizer: SubtypeOfOptimizer <:< OptimizerApi { type Delta <: Float }): Do[Unit] = {
      Do.delay {
        // Stage 1: bind `weight` to this weight instance.
        val withWeight =
          floatPartialApplyWeight(floatOptimizerFactory.newInstance, floatWeightParameter(this))
        // Stage 2: bind `originalDelta`.
        val withOriginalDelta =
          floatPartialApplyOriginalDelta(withWeight, floatOriginalDeltaParameter(originalDelta))
        // Stage 3: let the implicit complete construction, then view the result as an optimizer.
        val optimizer = implicitApplyRest(withOriginalDelta)
        val step = asOptimizer(optimizer).delta
        // The read-modify-write of `data` happens while holding this weight's monitor.
        synchronized {
          data -= step
        }
      }
    }
  }

  /** @template */
  type FloatWeight <: FloatWeightApi with Weight

  // Factory that instantiates `FloatWeight` (see `FloatWeight.apply`).
  @inject
  protected val floatWeightFactory: Factory[FloatWeight]

  // Binds the `data` parameter of the `FloatWeight` constructor.
  @inject
  protected val floatPartialApplyData: PartialApply[floatWeightFactory.Constructor, Witness.`"data"`.T]

  // Evidence that a `Float` is acceptable as the `data` parameter.
  @inject
  protected def floatDataParameter: Float <:< floatPartialApplyData.Parameter

  object FloatWeight {

    /** Creates a new [[FloatWeight]] whose `data` constructor parameter is the given value.
      *
      * NOTE(review): the type parameters `SubtypeOfWeight`, `OptimizerFunction` and `Optimizer` are not
      * referenced by the body; presumably they are kept for source compatibility — confirm before removing.
      *
      * @usecase def apply(data: Float): FloatWeight = ???
      */
    def apply[SubtypeOfWeight, OptimizerFunction, Optimizer](data: Float)(
        implicit implicitApplyRest: ImplicitApply[floatPartialApplyData.Rest]) = {
      // Fix `data` first; `implicitApplyRest` then supplies the rest of the constructor.
      val withData = floatPartialApplyData(floatWeightFactory.newInstance, floatDataParameter(data))
      implicitApplyRest(withData)
    }

  }

  trait FloatOptimizerApi extends OptimizerApi { this: FloatOptimizer =>

    override type Delta = Float

    // The weight this optimizer was constructed for (bound in `FloatWeightApi.backward`).
    val weight: FloatWeight
  }

  /** @template */
  type FloatOptimizer <: FloatOptimizerApi with Optimizer

  // Factory that instantiates `FloatOptimizer`.
  @inject
  protected val floatOptimizerFactory: Factory[FloatOptimizer]

  // Binds the `weight` parameter of the `FloatOptimizer` constructor.
  @inject
  protected val floatPartialApplyWeight: PartialApply[floatOptimizerFactory.Constructor, Witness.`"weight"`.T]

  // Evidence that a `FloatWeight` is acceptable as the `weight` parameter.
  @inject
  protected def floatWeightParameter: FloatWeight <:< floatPartialApplyWeight.Parameter

  // Binds the `originalDelta` parameter once `weight` has already been bound.
  @inject
  protected val floatPartialApplyOriginalDelta: PartialApply[floatPartialApplyWeight.Rest, Witness.`"originalDelta"`.T]

  // Evidence that a `Float` is acceptable as the `originalDelta` parameter.
  @inject
  protected def floatOriginalDeltaParameter: Float <:< floatPartialApplyOriginalDelta.Parameter

}