Skip to content

Commit be5c0d6

Browse files
RevySRAtry
authored andcommitted
Add the plugin Differentiables
1 parent f1bb230 commit be5c0d6

File tree

7 files changed

+43
-69
lines changed
  • plugins-Builtins/src/main/scala/com/thoughtworks/deeplearning/plugins
  • plugins-Differentiables/src/main/scala/com/thoughtworks/deeplearning/plugins
  • plugins-Layers/src/main/scala/com/thoughtworks/deeplearning/plugins
  • plugins-Logging/src/main/scala/com/thoughtworks/deeplearning/plugins
  • plugins-Names/src/main/scala/com/thoughtworks/deeplearning/plugins
  • plugins-Weights/src/main/scala/com/thoughtworks/deeplearning/plugins

7 files changed

+43
-69
lines changed

build.sbt

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,13 @@ includeFilter in unmanagedSources := (includeFilter in unmanagedSources).value &
44

55
lazy val DeepLearning = project
66

7-
lazy val `plugins-Layers` = project.dependsOn(DeepLearning)
7+
lazy val `plugins-Layers` = project.dependsOn(DeepLearning, `plugins-Differentiables`)
88

9-
lazy val `plugins-Weights` = project.dependsOn(DeepLearning)
9+
lazy val `plugins-Weights` = project.dependsOn(DeepLearning, `plugins-Differentiables`)
1010

11-
lazy val `plugins-Names` = project.dependsOn(`plugins-Layers`, `plugins-Weights`)
11+
lazy val `plugins-Names` = project.dependsOn(`plugins-Differentiables`)
1212

13-
lazy val `plugins-Logging` = project.dependsOn(`plugins-Layers`, `plugins-Weights`)
13+
lazy val `plugins-Logging` = project.dependsOn(`plugins-Differentiables`)
1414

1515
lazy val `plugins-Operators` = project
1616

@@ -104,6 +104,7 @@ lazy val `plugins-Builtins` =
104104
`plugins-CumulativeDoubleLayers`,
105105
DeepLearning % "test->test"
106106
)
107+
lazy val `plugins-Differentiables` = project
107108

108109
lazy val `plugins-OpenCLBuffers` =
109110
project.dependsOn(DeepLearning,

plugins-Builtins/src/main/scala/com/thoughtworks/deeplearning/plugins/Builtins.scala

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -37,13 +37,6 @@ trait Builtins
3737

3838
type Implicits <: ImplicitsApi
3939

40-
trait LayerApi extends super[Logging].LayerApi with super[Names].LayerApi { this: Layer =>
41-
}
42-
43-
type Layer <: LayerApi
44-
45-
trait WeightApi extends super[Logging].WeightApi with super[Names].WeightApi { this: Weight =>
46-
}
47-
48-
type Weight <: WeightApi
40+
trait DifferentiableApi extends super[Logging].DifferentiableApi with super[Names].DifferentiableApi
41+
type Differentiable <: DifferentiableApi
4942
}
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
package com.thoughtworks.deeplearning.plugins
2+
3+
/**
4+
* @author 杨博 (Yang Bo)
5+
*/
6+
trait Differentiables {
7+
8+
trait DifferentiableApi {
9+
10+
protected def handleException(throwable: Throwable): Unit = {
11+
throwable.printStackTrace()
12+
}
13+
}
14+
15+
type Differentiable <: DifferentiableApi
16+
17+
}

plugins-Layers/src/main/scala/com/thoughtworks/deeplearning/plugins/Layers.scala

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,23 +31,19 @@ object Layers {
3131
}
3232

3333
/** A plugin that enables [[Layer]] in neural networks. */
34-
trait Layers {
35-
trait LayerApi {
34+
trait Layers extends Differentiables {
35+
trait LayerApi extends DifferentiableApi {
3636
type Data
3737
type Delta
3838

3939
def forward: Do[Tape[Data, Delta]]
4040

41-
protected def handleException(throwable: Throwable): Unit = {
42-
throwable.printStackTrace()
43-
}
44-
4541
}
4642

4743
/** A differentiable operation.
4844
* @template
4945
*/
50-
type Layer <: LayerApi
46+
type Layer <: LayerApi with Differentiable
5147

5248
trait ImplicitsApi {
5349
implicit def layerDeepLearning[From, Data0, Delta0](implicit asLayer: From <:< LayerApi {

plugins-Logging/src/main/scala/com/thoughtworks/deeplearning/plugins/Logging.scala

Lines changed: 10 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -33,53 +33,35 @@ object Logging {
3333
case message => message
3434
}
3535
}
36-
37-
final class ThrownInLayer(val layer: Layers#Layer, getThrown: Throwable)(implicit fullName: sourcecode.FullName,
38-
name: sourcecode.Name,
39-
caller: Caller[_])
40-
extends ContextualLogRecord(Level.SEVERE, thrown = getThrown)
41-
with LazyMessage {
42-
override protected def makeDefaultMessage: Fastring = fast"An exception is thrown in layer $layer"
43-
}
44-
45-
final class ThrownInWeight(val weight: Weights#Weight, getThrown: Throwable)(implicit fullName: sourcecode.FullName,
46-
name: sourcecode.Name,
47-
caller: Caller[_])
36+
final class UncaughtException(val differentiable: Logging#DifferentiableApi, getThrown: Throwable)(
37+
implicit fullName: sourcecode.FullName,
38+
name: sourcecode.Name,
39+
caller: Caller[_])
4840
extends ContextualLogRecord(Level.SEVERE, thrown = getThrown)
4941
with LazyMessage {
50-
override protected def makeDefaultMessage: Fastring = fast"An exception is thrown in weight $weight"
42+
override protected def makeDefaultMessage: Fastring = fast"An exception is thrown in $differentiable"
5143
}
5244

5345
}
5446

55-
/** A plugin that logs uncaught exceptions raised from [[Layer]] and [[Weight]].
47+
/** A plugin that logs uncaught exceptions.
5648
*
5749
* @author 杨博 (Yang Bo)
5850
*/
59-
trait Logging extends Layers with Weights {
51+
trait Logging extends Differentiables {
6052
import Logging._
6153

6254
@transient lazy val logger: Logger = Logger.getLogger(getClass.getName)
6355

64-
trait LayerApi extends super.LayerApi { this: Layer =>
56+
trait DifferentiableApi extends super.DifferentiableApi {
6557
implicit protected def fullName: sourcecode.FullName
6658
implicit protected def name: sourcecode.Name
6759
implicit protected def caller: Caller[_]
6860
override protected def handleException(thrown: Throwable): Unit = {
69-
logger.log(new ThrownInLayer(this, thrown))
61+
logger.log(new UncaughtException(this, thrown))
7062
}
7163
}
72-
override type Layer <: LayerApi
7364

74-
trait WeightApi extends super.WeightApi { this: Weight =>
75-
implicit protected def fullName: sourcecode.FullName
76-
implicit protected def name: sourcecode.Name
77-
implicit protected def caller: Caller[_]
78-
override protected def handleException(thrown: Throwable): Unit = {
79-
logger.log(new ThrownInWeight(this, thrown))
80-
}
81-
}
82-
override type Weight <: WeightApi
83-
override type Implicits <: super[Layers].ImplicitsApi with super[Weights].ImplicitsApi
65+
type Differentiable <: DifferentiableApi
8466

8567
}

plugins-Names/src/main/scala/com/thoughtworks/deeplearning/plugins/Names.scala

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,9 @@ package com.thoughtworks.deeplearning.plugins
44
*
55
* @author 杨博 (Yang Bo)
66
*/
7-
trait Names extends Layers with Weights {
7+
trait Names {
88

9-
trait LayerApi extends super.LayerApi { this: Layer =>
10-
def fullName: sourcecode.FullName
11-
def name: sourcecode.Name
12-
13-
override def toString: String = {
14-
raw"""Layer[fullName=${fullName.value}]"""
15-
}
16-
}
17-
override type Layer <: LayerApi
18-
19-
trait WeightApi extends super.WeightApi { this: Weight =>
9+
trait DifferentiableApi {
2010
def fullName: sourcecode.FullName
2111
def name: sourcecode.Name
2212

@@ -25,7 +15,6 @@ trait Names extends Layers with Weights {
2515
}
2616

2717
}
28-
override type Weight <: WeightApi
29-
override type Implicits <: super[Layers].ImplicitsApi with super[Weights].ImplicitsApi
18+
type Differentiable <: DifferentiableApi
3019

3120
}

plugins-Weights/src/main/scala/com/thoughtworks/deeplearning/plugins/Weights.scala

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@ import com.thoughtworks.continuation._
1515
*
1616
* @author 杨博 (Yang Bo)
1717
*/
18-
trait Weights {
18+
trait Weights extends Differentiables {
1919

20-
trait WeightApi {
20+
trait WeightApi extends DifferentiableApi {
2121

2222
protected type PartiallyAppliedOptimizer
2323

@@ -49,16 +49,12 @@ trait Weights {
4949
type Data
5050
type Delta
5151

52-
protected def handleException(throwable: Throwable): Unit = {
53-
throwable.printStackTrace()
54-
}
55-
5652
var data: Data
5753

5854
}
5955

6056
/** @template */
61-
type Weight <: WeightApi
57+
type Weight <: WeightApi with Differentiable
6258

6359
trait OptimizerApi {
6460
type Delta

0 commit comments

Comments
 (0)