CumulativeFloatLayers.scala
package com.thoughtworks.deeplearning
package plugins
import com.thoughtworks.deeplearning.DeepLearning.Tape
import com.thoughtworks.raii.asynchronous.Do
import com.thoughtworks.raii.asynchronous.Do._
import com.thoughtworks.raii.shared._
import com.thoughtworks.raii.covariant.{Releasable, ResourceT}
import com.thoughtworks.tryt.covariant.TryT
import scalaz.concurrent.Future
import scala.util.{Success, Try}
import scalaz.{-\/, \/-}
import scalaz.syntax.all._
/** A plugin that provides differentiable operators
  * on neural networks whose [[DeepLearning.Data Data]] and [[DeepLearning.Delta Delta]] are [[scala.Float]].
  *
  * @note Unlike [[FloatLayers]], a [[FloatLayer]] in this `CumulativeFloatLayers` shares the [[DeepLearning.Tape Tape]]s
  *       created in its [[FloatLayer.forward forward]] pass among all dependencies, avoiding re-evaluation
  *       in the case of diamond dependencies in a neural network.
  *
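  * @example The sharing matters when a layer appears more than once in a diamond-shaped network.
  *          The identifiers below are illustrative only (they are not part of this file) and assume
  *          the arithmetic operators provided by [[FloatLayers]]; this is a sketch, not library API:
  *          {{{
  *          val output = hidden + hidden
  *          // With CumulativeFloatLayers mixed in, `hidden` is forwarded only once per iteration,
  *          // and the two deltas flowing back into `hidden` are summed before one upstream flush.
  *          }}}
  *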
  * @author 杨博 (Yang Bo)
  */
trait CumulativeFloatLayers extends FloatLayers {

  trait FloatLayerApi extends super[FloatLayers].FloatLayerApi {

    // Wraps the upstream tape in a reference-counted resource that accumulates
    // deltas from every dependency and flushes their sum exactly once on release.
    private def doCumulativeTape: Do[Tape[Float, Float]] = {
      super.forward.flatMap {
        case Tape(data, flushBackward) =>
          Do(Future.delay(new Releasable[Future, Try[Tape[Float, Float]]] {

            @volatile
            private var currentDelta: Float = 0

            override def value: Try[Tape[Float, Float]] = {
              // Each dependency's backward pass only adds its delta to `currentDelta`;
              // nothing is propagated upstream until `release()` is called.
              def cumulativeBackward(doDelta: Do[Float]): Future[Unit] = {
                Do.run(doDelta)
                  .map { delta =>
                    synchronized {
                      currentDelta += delta
                    }
                  }
                  .get
                  .map {
                    case \/-(()) => // Success. Do nothing
                    case -\/(e)  => handleException(e)
                  }
              }
              Success(Tape(data, cumulativeBackward))
            }

            // Invoked when the last reference to the shared tape is released:
            // the accumulated delta is flushed to the upstream layer in one go.
            override def release(): Future[Unit] = {
              flushBackward(Do.delay {
                synchronized {
                  val delta = currentDelta
                  currentDelta = 0
                  delta
                }
              })
            }

          }))
      }
    }

    // Sharing the `Do` makes every dependency observe the same tape,
    // so the forward pass of this layer runs at most once per iteration.
    @transient
    private lazy val sharedForward: Do[Tape[Float, Float]] = {
      Do.shared(doCumulativeTape)
    }

    abstract override def forward: Do[DeepLearning.Tape[Float, Float]] = sharedForward

  }

  override type FloatLayer <: FloatLayerApi with Layer

}
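
// ---------------------------------------------------------------------------------------------
// Illustrative sketch (not part of the original file): a minimal, self-contained model of the
// accumulate-then-flush pattern used in `doCumulativeTape` above, stripped of the RAII/`Do`
// machinery. All identifiers below are hypothetical and exist only for demonstration.
object CumulativeDeltaSketch {

  final class CumulativeCell(flushUpstream: Float => Unit) {

    @volatile
    private var currentDelta: Float = 0

    // Each downstream dependency contributes its delta here; nothing is propagated yet.
    def backward(delta: Float): Unit = synchronized {
      currentDelta += delta
    }

    // Called once when the shared tape would be released: the summed delta goes upstream exactly once.
    def release(): Unit = {
      val delta = synchronized {
        val result = currentDelta
        currentDelta = 0
        result
      }
      flushUpstream(delta)
    }
  }

  def main(args: Array[String]): Unit = {
    var received = 0.0f
    val cell = new CumulativeCell(delta => received += delta)
    cell.backward(1.5f) // first branch of a diamond dependency
    cell.backward(2.5f) // second branch of the same diamond
    cell.release()      // upstream receives a single delta of 4.0
    println(received)
  }
}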