-
Notifications
You must be signed in to change notification settings - Fork 87
/
DeepLearning.scala
73 lines (56 loc) · 2.83 KB
/
DeepLearning.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
package com.thoughtworks.deeplearning
import com.thoughtworks.deeplearning.DeepLearning.Tape
import scalaz.concurrent.{Future, Task}
import scalaz.syntax.all._
import com.thoughtworks.raii.asynchronous.Do
import com.thoughtworks.raii.asynchronous.Do._
import simulacrum.typeclass
import scala.language.implicitConversions
import spire.algebra.MultiplicativeMonoid
// Companion object of the `DeepLearning` type class. It declares the `Tape` node type, the `Aux`
// refinement alias, and the parent trait that carries the Scaladoc for the type class members.
object DeepLearning {
/** A node of the Wengert list created during the [[DeepLearning.forward forward]] pass.
 *
 * The class is covariant in `Data` and contravariant in `Delta`.
 *
 * @param data     the value carried by this node
 * @param backward a function that accepts a `Do` of the delta and returns a `Future[Unit]`;
 *                 [[DeepLearning.train train]] calls it with the multiplicative identity of `Delta`.
 *                 NOTE(review): presumably this performs the backward pass for this node -- confirm
 *                 against the instances that build tapes.
 */
final case class Tape[+Data, -Delta](data: Data, backward: Do[Delta] => Future[Unit])
/** A refinement of [[DeepLearning]] whose `Data` and `Delta` type members are fixed to `Data0` and `Delta0`,
 * so that both can be named as ordinary type parameters.
 */
type Aux[Differentiable, Data0, Delta0] = DeepLearning[Differentiable] {
type Data = Data0
type Delta = Delta0
}
// The Scaladoc of the members of trait DeepLearning must be defined in `SimulacrumIssue82WorkAround`
// rather than in the trait itself, because of https://github.com/mpilquist/simulacrum/issues/82
private[DeepLearning] sealed trait SimulacrumIssue82WorkAround[Differentiable] {
/** The type of the result value of the forward pass */
type Data
/** The type of the partial derivative for [[Data]] */
type Delta
/** Returns an asynchronous operation of the forward pass, which creates a Wengert list.
 *
 * @param differentiable the differentiable expression to evaluate
 * @return a `Do` of the [[Tape]] produced for `differentiable`
 */
def forward(differentiable: Differentiable): Do[Tape[Data, Delta]]
/** Returns a [[scalaz.concurrent.Task Task]] that updates [[plugins.Weights.Weight Weight]] internally used by `differentiable`.
 *
 * The implementation in [[DeepLearning]] runs [[forward]], passes `monoid.one` to the
 * resulting tape's `backward` function, and yields the tape's `data`.
 *
 * @param differentiable the differentiable expression to train
 * @param monoid         supplies `one`, the delta handed to the tape's `backward` function
 */
def train(differentiable: Differentiable)(implicit monoid: MultiplicativeMonoid[Delta]): Task[Data]
/** Returns a [[scalaz.concurrent.Task Task]] of the value of the `differentiable` expression.
 *
 * The implementation in [[DeepLearning]] runs [[forward]] and yields the tape's `data`
 * without calling the tape's `backward` function.
 *
 * @param differentiable the differentiable expression to evaluate
 */
def predict(differentiable: Differentiable): Task[Data]
}
}
import DeepLearning._
/** A type class witnessing that `Differentiable` is a differentiable expression.
  *
  * Common differentiable types that support [[DeepLearning]] are:
  *
  *  - [[scala.Float Float]], [[plugins.FloatWeights.FloatWeight FloatWeight]] or [[plugins.FloatLayers.FloatLayer FloatLayer]]
  *  - [[scala.Double Double]], [[plugins.DoubleWeights.DoubleWeight DoubleWeight]] or [[plugins.DoubleLayers.DoubleLayer DoubleLayer]]
  *  - [[org.nd4j.linalg.api.ndarray.INDArray INDArray]], [[plugins.INDArrayWeights.INDArrayWeight INDArrayWeight]] or [[plugins.INDArrayLayers.INDArrayLayer INDArrayLayer]]
  */
@typeclass(excludeParents = List("SimulacrumIssue82WorkAround"))
trait DeepLearning[Differentiable] extends SimulacrumIssue82WorkAround[Differentiable] {

  // Members here carry plain `//` comments only: their Scaladoc is kept on the
  // `SimulacrumIssue82WorkAround` parent (see the note next to that trait's definition).

  // Type of the value produced by the forward pass.
  type Data

  // Type of the partial derivative for `Data`.
  type Delta

  // The only abstract operation: builds the tape for `differentiable`.
  def forward(differentiable: Differentiable): Do[Tape[Data, Delta]]

  // Runs the forward pass, hands the multiplicative identity of `Delta` to the tape's
  // `backward` function, and completes with the tape's data.
  final def train(differentiable: Differentiable)(implicit monoid: MultiplicativeMonoid[Delta]): Task[Data] = {
    val forwardPass = forward(differentiable)
    val trainingStep = forwardPass.flatMap[Data] { node =>
      Do.delay(node.backward(Do.now(monoid.one))).map(_ => node.data)
    }
    Do.run(trainingStep)
  }

  // Runs the forward pass only and completes with the tape's data; `backward` is never called.
  final def predict(differentiable: Differentiable): Task[Data] = {
    val forwardData = forward(differentiable).map { node =>
      node.data
    }
    Do.run(forwardData)
  }
}