/
Weights.scala
112 lines (95 loc) · 3.83 KB
/
Weights.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
package com.thoughtworks.deeplearning.plugins
import com.thoughtworks.deeplearning.DeepLearning
import com.thoughtworks.deeplearning.DeepLearning.Tape
import com.thoughtworks.feature.{Factory, ImplicitApply, PartialApply}
import com.thoughtworks.raii.asynchronous.Do
import shapeless.Witness
import scalaz.{-\/, \/-}
/** Plugin that declares weights and the optimizers that are run against them.
  *
  * The concrete `Weight`, `Optimizer` and `Implicits` types are abstract type members, upper-bounded by
  * [[WeightApi]], [[OptimizerApi]] and [[ImplicitsApi]] respectively, so traits mixed in alongside this
  * one can narrow them.
  *
  * @author 杨博 (Yang Bo)
  */
trait Weights {

  /** Members shared by every [[Weight]].
    *
    * `Data` is the type of the value held in [[data]]. `Delta` is the type of the deltas that reach this
    * weight through the callback built by `ImplicitsApi.weightDeepLearning`.
    */
  trait WeightApi { self =>

    /** Optimizer type able to handle this particular weight instance: its `Weight` member must accept
      * `self.type` and its `Delta` member must equal this weight's `Delta`.
      */
    protected type Optimizer <: Weights.this.Optimizer {
      type Weight >: self.type
      type Delta = self.Delta
    }

    /** Type of the value stored in [[data]]. */
    type Data

    /** Type of the deltas handed to optimizers for this weight. */
    type Delta

    // Package-private bridge so the derived DeepLearning instance (inside `Weights`) can call the
    // protected `handleException` hook.
    private[Weights] def handleExceptionFriend(throwable: Throwable): Unit = handleException(throwable)

    /** Invoked when running an optimizer update for this weight fails.
      *
      * The default implementation does nothing, so such failures are dropped silently unless a
      * subclass overrides this method.
      *
      * @param throwable the failure reported by the update action
      */
    protected def handleException(throwable: Throwable): Unit = ()

    /** Current value of the weight; `forward` reads it each time the weight is evaluated.
      *
      * NOTE(review): presumably reassigned by optimizer `update()` implementations (not shown here).
      */
    var data: Data
  }

  /** Abstract weight type, bounded by [[WeightApi]].
    *
    * @template
    */
  type Weight <: WeightApi

  object Weight {

    /** A [[Weight]] whose `Data`, `Delta` and `Optimizer` members are fixed to the given types, so
      * implicit derivation (see `weightDeepLearning`) can recover them as type parameters.
      */
    type Aux[Optimizer0, Data0, Delta0] = Weight {
      type Data = Data0
      type Delta = Delta0
      type Optimizer = Optimizer0
    }
  }

  /** Members shared by every [[Optimizer]].
    *
    * An optimizer is bound to a single [[weight]] and a single [[originalDelta]].
    * `weightDeepLearning` constructs a fresh instance for every delta it receives and then runs its
    * `update()`.
    */
  trait OptimizerApi { self =>

    /** Type of the delta this optimizer consumes. */
    type Delta

    /** Weight type this optimizer can handle: its `Optimizer` member must accept `self.type`. */
    protected type Weight <: Weights.this.Weight {
      type Optimizer >: self.type
    }

    /** The weight this optimizer works on; `weightDeepLearning` fills it through the constructor
      * parameter named `weight`.
      */
    protected val weight: Weight

    /** The delta exactly as received; `weightDeepLearning` fills it through the constructor
      * parameter named `originalDelta`.
      */
    protected val originalDelta: Delta

    /** Delta exposed to `update()` implementations. This is an overridable hook; by default it
      * returns [[originalDelta]] unchanged.
      */
    protected def delta: Delta = originalDelta

    /** The update action for [[weight]]; abstract here and supplied by concrete optimizers. */
    protected def update(): Do[Unit]

    // Package-private bridge so the derived DeepLearning instance can trigger the protected `update()`.
    private[Weights] def updateFriend(): Do[Unit] = update()
  }

  /** Abstract optimizer type, bounded by [[OptimizerApi]].
    *
    * @template
    */
  type Optimizer <: OptimizerApi

  /** Holder of the implicit `DeepLearning` instance for weights. */
  trait ImplicitsApi {

    /** Derives a [[DeepLearning]] instance for any subtype of [[Weight]].
      *
      * `forward` wraps the weight's `data` in a `Tape`, together with a callback that accepts a
      * `Do[Delta0]`. For every delta it receives, the callback:
      *  - builds a new `Optimizer0` from `factory.newInstance`;
      *  - fills the constructor parameter named `weight` with the weight itself;
      *  - fills the constructor parameter named `originalDelta` with the delta;
      *  - applies the remaining constructor parameters implicitly;
      *  - runs the optimizer's `update()`.
      *
      * If running that action fails, the exception is passed to the weight's `handleException`.
      *
      * NOTE: in Scala 2, implicit parameters are resolved left to right, and type parameters inferred
      * from earlier ones are fixed for later ones. The order of the implicit parameters below is
      * therefore significant.
      *
      * @tparam SubtypeOfWeight        the concrete weight type the instance is derived for
      * @tparam Optimizer0             the weight's `Optimizer` member type
      * @tparam Data0                  the weight's `Data` member type
      * @tparam Delta0                 the weight's `Delta` member type
      * @tparam OptimizerConstructor   type of `factory.newInstance` for `Optimizer0`
      * @tparam WeightParameter        declared type of the `weight` constructor parameter
      * @tparam WeightRest             constructor type left after `weight` is applied
      * @tparam OriginalDeltaParameter declared type of the `originalDelta` constructor parameter
      * @tparam OriginalDeltaRest      constructor type left after `originalDelta` is applied
      * @tparam SubtypeOfOptimizer     result of applying the remaining parameters implicitly
      * @return a `DeepLearning` instance whose `Data`/`Delta` are `Data0`/`Delta0`
      */
    implicit def weightDeepLearning[SubtypeOfWeight,
                                    Optimizer0 <: Optimizer,
                                    Data0,
                                    Delta0,
                                    OptimizerConstructor,
                                    WeightParameter,
                                    WeightRest,
                                    OriginalDeltaParameter,
                                    OriginalDeltaRest,
                                    SubtypeOfOptimizer](
        // Fixes Optimizer0 / Data0 / Delta0 from the weight's type members.
        implicit asWeight: SubtypeOfWeight <:< Weight.Aux[Optimizer0, Data0, Delta0],
        factory: Factory.Aux[Optimizer0, OptimizerConstructor],
        // Applies the constructor parameter literally named "weight".
        partialApplyWeight: PartialApply.Aux[OptimizerConstructor, Witness.`"weight"`.T, WeightParameter, WeightRest],
        asWeightParameter: SubtypeOfWeight <:< WeightParameter,
        // Applies the constructor parameter literally named "originalDelta".
        partialApplyOriginalDelta: PartialApply.Aux[WeightRest,
                                                    Witness.`"originalDelta"`.T,
                                                    OriginalDeltaParameter,
                                                    OriginalDeltaRest],
        asOriginalDeltaParameter: Delta0 <:< OriginalDeltaParameter,
        // Fills every remaining constructor parameter from implicit scope.
        implicitApplyRest: ImplicitApply.Aux[OriginalDeltaRest, SubtypeOfOptimizer],
        asOptimizer: SubtypeOfOptimizer <:< Optimizer0
    ): DeepLearning.Aux[SubtypeOfWeight, Data0, Delta0] = {
      new DeepLearning[SubtypeOfWeight] {
        override type Data = Data0
        override type Delta = Delta0
        override def forward(subtypeOfWeight: SubtypeOfWeight): Do[Tape[Data0, Delta0]] = {
          val weight = asWeight(subtypeOfWeight)
          // The Tape pairs the weight's current `data` with a callback that consumes deltas
          // (presumably invoked during back-propagation).
          Do.now(
            Tape[Data0, Delta0](
              weight.data, { doDelta: Do[Delta0] =>
                // A fresh optimizer is built for every delta: `weight` and `originalDelta` are
                // supplied by parameter name, the rest implicitly; then its update is chained
                // after `doDelta`.
                val doUpdate: Do[Unit] = Do.releaseFlatMap(doDelta) { delta =>
                  asOptimizer(
                    implicitApplyRest(
                      partialApplyOriginalDelta(partialApplyWeight(factory.newInstance,
                                                                   asWeightParameter(subtypeOfWeight)),
                                                asOriginalDeltaParameter(delta)))).updateFriend()
                }
                // Run the combined action; a failure (the left case) goes to the weight's
                // exception hook, success yields unit.
                Do.run(doUpdate).get.map {
                  case \/-(()) => ()
                  case -\/(e) => weight.handleExceptionFriend(e)
                }
              }
            ))
        }
      }
    }
  }

  /** Abstract implicits-container type, bounded by [[ImplicitsApi]].
    *
    * @template
    */
  type Implicits <: ImplicitsApi
}