-
Notifications
You must be signed in to change notification settings - Fork 87
/
INDArrayWeights.scala
105 lines (78 loc) · 3.57 KB
/
INDArrayWeights.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
package com.thoughtworks.deeplearning
package plugins
import com.thoughtworks.feature.Factory.inject
import com.thoughtworks.feature.{Factory, ImplicitApply, PartialApply}
import com.thoughtworks.raii.asynchronous.Do
import com.thoughtworks.raii.asynchronous.Do._
import org.nd4j.linalg.api.ndarray.INDArray
import shapeless.Witness
import org.nd4s.Implicits._
import annotation.meta.getter
import scala.concurrent.ExecutionContext
import scalaz.syntax.all._
/** A plugin to create [[org.nd4j.linalg.api.ndarray.INDArray]] weights.
*
* @note A custom optimization algorithm for updating [[INDArrayWeight]] can be implemented by creating a plugin
* that provides an overridden [[INDArrayOptimizer]], which in turn provides an overridden [[INDArrayOptimizer.delta]].
*
* @author 杨博 (Yang Bo)
*/
trait INDArrayWeights extends Weights with ImplicitsSingleton {
// Execution context obtained through `@inject`. It is implicit, so calls inside this trait that take an
// implicit `ExecutionContext` can pick it up (presumably `Do.jump()` in `INDArrayWeightApi.backward` -- confirm).
@inject
implicit protected def deepLearningExecutionContext: ExecutionContext
// Narrows the inherited `Implicits` type so that it is at least an `ImplicitsApi`.
override type Implicits <: ImplicitsApi
// Brings the members of `implicits` into scope for the rest of this trait.
import implicits._
/** Behaviour shared by every [[INDArrayWeight]]: both the stored data and the deltas are
  * [[org.nd4j.linalg.api.ndarray.INDArray]]s.
  */
trait INDArrayWeightApi extends WeightApi { this: INDArrayWeight =>
  override type Delta = INDArray
  override type Data = INDArray

  // What is left of the optimizer constructor once "weight" and "originalDelta" have been supplied.
  override protected type PartiallyAppliedOptimizer = indArrayPartialApplyOriginalDelta.Rest

  /** Updates `data` from a raw delta.
    *
    * A new optimizer is built for every call: the injected constructor receives this weight as its
    * `"weight"` parameter and `originalDelta` as its `"originalDelta"` parameter, and the remaining
    * parameters are resolved by `implicitApplyRest`. The optimizer's `delta` is then subtracted from `data`.
    *
    * @tparam SubtypeOfOptimizer the optimizer type produced by `implicitApplyRest`
    * @param originalDelta the delta handed to the optimizer as its `"originalDelta"` parameter
    * @param implicitApplyRest resolves the constructor parameters that are still missing
    * @param asOptimizer evidence that the built optimizer exposes a `delta` that is an `INDArray`
    * @return a `Do[Unit]` wrapping the update
    */
  override protected def backward[SubtypeOfOptimizer](originalDelta: INDArray)(
      implicit implicitApplyRest: ImplicitApply.Aux[PartiallyAppliedOptimizer, SubtypeOfOptimizer],
      asOptimizer: SubtypeOfOptimizer <:< OptimizerApi { type Delta <: INDArray }): Do[Unit] = {
    // NOTE(review): `Do.jump()` presumably continues on `deepLearningExecutionContext` -- confirm.
    Do.jump().map { _: Unit =>
      // Step 1: fix the "weight" constructor parameter to this weight.
      val constructorWithWeight =
        indArrayPartialApplyWeight(indArrayOptimizerFactory.newInstance, indArrayWeightParameter(this))
      // Step 2: fix the "originalDelta" constructor parameter.
      val constructorWithOriginalDelta =
        indArrayPartialApplyOriginalDelta(constructorWithWeight, indArrayOriginalDeltaParameter(originalDelta))
      // Step 3: let implicit resolution supply whatever is left, then read the optimizer's delta.
      val builtOptimizer = implicitApplyRest(constructorWithOriginalDelta)
      val appliedDelta = builtOptimizer.delta
      // The update runs while holding this weight's monitor.
      // NOTE(review): `-=` presumably comes from `org.nd4s.Implicits._` and works in place -- confirm.
      synchronized {
        data -= appliedDelta
        ()
      }
    }
  }
}
/** The abstract type of weights handled by this plugin; it must conform to both
  * [[INDArrayWeightApi]] and `Weight`.
  *
  * @template
  */
type INDArrayWeight <: INDArrayWeightApi with Weight
// Factory for concrete `INDArrayWeight` instances.
@inject
protected val indArrayWeightFactory: Factory[INDArrayWeight]
// Partial application of the weight constructor at the parameter named "data".
@inject
protected val indArrayPartialApplyData: PartialApply[indArrayWeightFactory.Constructor, Witness.`"data"`.T]
// Evidence that an `INDArray` can be passed as that "data" parameter.
@inject
protected def indArrayDataParameter: INDArray <:< indArrayPartialApplyData.Parameter
object INDArrayWeight {

  /** Builds a weight from `data`.
    *
    * The `"data"` parameter of the injected weight constructor is filled with `data`; the parameters that
    * remain are resolved by `implicitApplyRest`, whose result is returned.
    *
    * The type parameters are not used by the signature or the body; they are kept so that existing call
    * sites that spell them out keep compiling.
    *
    * @usecase def apply(data: INDArray): INDArrayWeight = ???
    */
  def apply[SubtypeOfWeight, OptimizerFunction, Optimizer](data: INDArray)(
      implicit implicitApplyRest: ImplicitApply[indArrayPartialApplyData.Rest]) = {
    // Fix the "data" constructor parameter first; everything else is left to implicit resolution.
    val constructorWithData =
      indArrayPartialApplyData(indArrayWeightFactory.newInstance, indArrayDataParameter(data))
    implicitApplyRest(constructorWithData)
  }
}
// Optimizer interface whose `Delta` is an `INDArray`. The self-type requires it to be mixed into an
// `INDArrayOptimizer`.
trait INDArrayOptimizerApi extends OptimizerApi { this: INDArrayOptimizer =>
override type Delta = INDArray
// The weight this optimizer was built for. `INDArrayWeightApi.backward` passes the weight being updated
// through the constructor parameter named "weight" (presumably this member -- confirm).
val weight: INDArrayWeight
}
/** The abstract type of optimizers used to update an [[INDArrayWeight]]; it must conform to both
  * `Optimizer` and [[INDArrayOptimizerApi]].
  *
  * @template
  */
type INDArrayOptimizer <: Optimizer with INDArrayOptimizerApi
// Factory for concrete `INDArrayOptimizer` instances.
@inject
protected val indArrayOptimizerFactory: Factory[INDArrayOptimizer]
// Partial application of the optimizer constructor at the parameter named "weight".
@inject
protected val indArrayPartialApplyWeight: PartialApply[indArrayOptimizerFactory.Constructor, Witness.`"weight"`.T]
// Evidence that an `INDArrayWeight` can be passed as that "weight" parameter.
@inject
protected def indArrayWeightParameter: INDArrayWeight <:< indArrayPartialApplyWeight.Parameter
// Partial application, at the parameter named "originalDelta", of what is left once "weight" was supplied.
@inject
protected val indArrayPartialApplyOriginalDelta: PartialApply[indArrayPartialApplyWeight.Rest,
Witness.`"originalDelta"`.T]
// Evidence that an `INDArray` can be passed as that "originalDelta" parameter.
@inject
protected def indArrayOriginalDeltaParameter: INDArray <:< indArrayPartialApplyOriginalDelta.Parameter
}