/
HLists.scala
76 lines (58 loc) · 2.54 KB
/
HLists.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
package com.thoughtworks.deeplearning.plugins
import com.thoughtworks.continuation._
import com.thoughtworks.future._
import com.thoughtworks.deeplearning.DeepLearning
import com.thoughtworks.deeplearning.DeepLearning.Tape
import com.thoughtworks.raii.asynchronous._
import scalaz.Applicative
import scalaz.syntax.all._
import scalaz.Tags.Parallel
import shapeless.{::, HList, HNil}
import java.io.{PrintStream, PrintWriter}
import scalaz.Semigroup
private object HLists {

  /** Applicative over `ParallelDo`.
    *
    * It is used to combine the head and tail forward passes of an `HList`.
    *
    * The result type is written out explicitly. Implicit definitions with
    * inferred types are flagged by the Scala 2 linter, must be annotated under
    * Scala 3, and can make implicit resolution depend on declaration order.
    */
  implicit val doParallelApplicative: Applicative[ParallelDo] = asynchronousDoParallelApplicative

  /** Backward pass for the empty `HList`.
    *
    * The incoming delta is ignored, so its `Do` is never run. The function
    * returns an already-completed continuation.
    */
  private val noop: Do[HNil] => UnitContinuation[Unit] = {
    Function.const(UnitContinuation.now(()))
  }
}
/**
 * Plugin that provides [[com.thoughtworks.deeplearning.DeepLearning]] instances for shapeless `HList`s.
 *
 * Forward pass of a non-empty `HList`:
 *  - the head and tail forward passes are combined through the `Parallel`-tagged applicative;
 *  - the result is a single `Tape` whose data is `headData :: tailData`.
 *
 * Backward pass:
 *  - the incoming `HList` delta is split into its head and tail;
 *  - each part goes to the corresponding backward function;
 *  - the two resulting continuations are joined through the `Parallel`-tagged continuation applicative.
 *
 * @author 杨博 (Yang Bo)
 */
trait HLists {
  import com.thoughtworks.deeplearning.plugins.HLists._

  /** Implicit `DeepLearning` instances contributed by this plugin. */
  trait ImplicitsApi {

    /** Instance for the empty `HList`.
      *
      * Both data and delta are `HNil`. The backward function ignores the delta it receives.
      */
    implicit def hnilDeepLearning[L <: HNil]: DeepLearning.Aux[L, HNil, HNil] = new DeepLearning[L] {
      type Data = HNil
      type Delta = HNil

      def forward(differentiable: L): Do[Tape[Data, Delta]] = Do.now(Tape(HNil, noop))
    }

    /** Instance for `Head :: Tail`, built from the instances of `Head` and `Tail`.
      *
      * @param headDeepLearning instance used for the first element
      * @param tailDeepLearning instance used for the remaining elements
      * @return an instance whose data and delta are the `HList`s of the element data and deltas
      */
    implicit def hconsDeepLearning[Head, Tail <: HList, HeadData, TailData <: HList, HeadDelta, TailDelta <: HList](
        implicit headDeepLearning: DeepLearning.Aux[Head, HeadData, HeadDelta],
        tailDeepLearning: DeepLearning.Aux[Tail, TailData, TailDelta])
      : DeepLearning.Aux[Head :: Tail, HeadData :: TailData, HeadDelta :: TailDelta] = new DeepLearning[Head :: Tail] {
      type Data = HeadData :: TailData
      type Delta = HeadDelta :: TailDelta

      def forward(differentiable: Head :: Tail): Do[Tape[Data, Delta]] = {
        // Forward the head first, then the tail, tagging each as Parallel so
        // the applicative below may combine them.
        val headForward: ParallelDo[Tape[HeadData, HeadDelta]] =
          Parallel(headDeepLearning.forward(differentiable.head))
        val tailForward: ParallelDo[Tape[TailData, TailDelta]] =
          Parallel(tailDeepLearning.forward(differentiable.tail))

        val bothForward = Parallel.unwrap(doParallelApplicative.tuple2(headForward, tailForward))

        bothForward.map {
          case (Tape(headData, headBackward), Tape(tailData, tailBackward)) =>
            // Splits the HList delta and hands each part to its own backward function.
            def propagate(doDelta: Do[HeadDelta :: TailDelta]): UnitContinuation[Unit] = {
              val headPropagation: ParallelContinuation[Unit] = Parallel(headBackward(doDelta.map(_.head)))
              val tailPropagation: ParallelContinuation[Unit] = Parallel(tailBackward(doDelta.map(_.tail)))
              val joined: ParallelContinuation[Unit] =
                continuationParallelApplicative.apply2(headPropagation, tailPropagation) { (_: Unit, _: Unit) =>
                  ()
                }
              Parallel.unwrap(joined)
            }

            Tape(headData :: tailData, propagate)
        }
      }
    }
  }

  type Implicits <: ImplicitsApi
}