-
Notifications
You must be signed in to change notification settings - Fork 87
/
INDArrayWeights.scala
69 lines (52 loc) · 1.96 KB
/
INDArrayWeights.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
package com.thoughtworks.deeplearning
package plugins
import com.thoughtworks.feature.Factory.inject
import com.thoughtworks.feature.{Factory, ImplicitApply, PartialApply}
import com.thoughtworks.raii.asynchronous.Do
import com.thoughtworks.raii.asynchronous.Do._
import org.nd4j.linalg.api.ndarray.INDArray
import shapeless.Witness
import org.nd4s.Implicits._
import annotation.meta.getter
import scala.concurrent.ExecutionContext
import scalaz.syntax.all._
/**
 * A plugin that provides [[INDArrayWeight]]s (trainable weights whose `Data` and `Delta`
 * are ND4J `INDArray`s) together with the matching [[INDArrayOptimizer]].
 *
 * It extends `Weights` and `ImplicitsSingleton`. The abstract members annotated with
 * `@inject` (imported from `com.thoughtworks.feature.Factory`) are presumably filled in by
 * `Factory` when the plugin is instantiated. NOTE(review): confirm.
 *
 * @author 杨博 (Yang Bo) <pop.atry@gmail.com>
 */
trait INDArrayWeights extends Weights with ImplicitsSingleton {
// Injected execution context. Because it is implicit, it is in scope for the `Do.jump()`
// call in `INDArrayOptimizerApi.update`. It is presumably the context that call hops onto;
// confirm against the signature of raii's `Do.jump`.
@inject
implicit protected def deepLearningExecutionContext: ExecutionContext
// Bounds `Implicits` by `ImplicitsApi`; both are presumably declared by an inherited
// plugin. `import implicits._` brings the members of `implicits` (not declared in this
// trait) into scope for the rest of the body.
type Implicits <: ImplicitsApi
import implicits._
/**
 * Mixin for INDArray-backed weights.
 *
 * Fixes both `Delta` and `Data` to `INDArray` and selects [[INDArrayOptimizer]] as the
 * optimizer type. The self-type restricts mixing to [[INDArrayWeight]].
 */
trait INDArrayWeightApi extends WeightApi { this: INDArrayWeight =>
type Delta = INDArray
override type Data = INDArray
override protected type Optimizer = INDArrayOptimizer
}
/**
 * The concrete weight type. Only an upper bound is given here, presumably so that other
 * plugins can mix further refinements into it.
 */
type INDArrayWeight <: INDArrayWeightApi with Weight
// Factory used by `INDArrayWeight.apply` to obtain a fresh weight constructor.
@inject
protected val indArrayWeightFactory: Factory[INDArrayWeight]
// Partially applies the constructor parameter named "data". The remaining constructor
// parameters are exposed as `indArrayPartialApplyData.Rest`.
@inject
protected val indArrayPartialApplyData: PartialApply[indArrayWeightFactory.Constructor, Witness.`"data"`.T]
// Evidence that an `INDArray` can be supplied as the constructor's "data" parameter.
// `INDArrayWeight.apply` uses it to convert its argument.
@inject
protected def indArrayDataParameter: INDArray <:< indArrayPartialApplyData.Parameter
/**
 * Constructor-like entry point for [[INDArrayWeight]].
 *
 * NOTE(review): `extends` followed only by a template body (no parent type) is, in
 * Scala 2, equivalent to a plain `object INDArrayWeight { ... }`. The keyword looks
 * redundant.
 */
object INDArrayWeight extends {
/**
 * Creates a new [[INDArrayWeight]] backed by `data`.
 *
 * The steps are:
 *  1. Take a fresh constructor from `indArrayWeightFactory.newInstance`.
 *  2. Pass `data` (through `indArrayDataParameter`) as its "data" parameter.
 *  3. Let `implicitApplyRest` supply every remaining constructor parameter from
 *     implicit scope.
 *
 * The array is passed by reference; no copy is made here.
 *
 * NOTE(review): the type parameters `SubtypeOfWeight`, `OptimizerFunction` and
 * `Optimizer` are not used in the signature or body. `Optimizer` also shadows the
 * enclosing type member of the same name.
 *
 * @param data the array passed as the weight's "data" constructor parameter
 * @param implicitApplyRest resolves the constructor parameters other than "data"
 * @return the result of `implicitApplyRest`, presumably the newly constructed weight
 */
def apply[SubtypeOfWeight, OptimizerFunction, Optimizer](data: INDArray)(
implicit implicitApplyRest: ImplicitApply[indArrayPartialApplyData.Rest]) = {
implicitApplyRest(indArrayPartialApplyData(indArrayWeightFactory.newInstance, indArrayDataParameter(data)))
}
}
/**
 * Mixin for optimizers of [[INDArrayWeight]]s.
 *
 * Fixes `Delta` to `INDArray` and the optimized `Weight` type to [[INDArrayWeight]].
 */
trait INDArrayOptimizerApi extends OptimizerApi { this: INDArrayOptimizer =>
type Delta = INDArray
override protected type Weight = INDArrayWeight
/**
 * Applies this optimizer's `delta` to its `weight`.
 *
 * First sequences after `Do.jump()`, which presumably moves execution onto the implicit
 * `deepLearningExecutionContext` (confirm). Then, while holding the monitor of `weight`,
 * it subtracts `delta` from `weight.data`.
 *
 * `weight` and `delta` are not declared in this trait; they are presumably inherited
 * from `OptimizerApi`.
 */
override protected def update(): Do[Unit] = {
Do.jump().map { _: Unit =>
// Holding the weight's monitor serializes this update with any other code that also
// synchronizes on the same weight object.
weight.synchronized {
// With org.nd4s.Implicits imported, `-=` presumably resolves to nd4s's in-place
// subtraction. Scala rewrites `x -= y` into `x = x - y` only when no `-=` member is
// reachable, even via implicit conversion. Confirm against the nd4s version in use.
weight.data -= delta
// Discard the result of `-=` so the block, and therefore the Do, yields Unit.
()
}
}
}
}
/**
 * The concrete optimizer type. Like [[INDArrayWeight]], only an upper bound is given here.
 */
type INDArrayOptimizer <: INDArrayOptimizerApi with Optimizer
}