Skip to content

Commit

Permalink
Remove MultipleException defined in DeepLearning.scala, use MultipleE…
Browse files Browse the repository at this point in the history
…xception in future.scala instead
  • Loading branch information
Atry committed May 25, 2018
1 parent 3a73afa commit d6a0300
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 63 deletions.
Expand Up @@ -14,66 +14,6 @@ import scalaz.Semigroup

object DeepLearning {

implicit object multipleExceptionThrowableSemigroup extends Semigroup[Throwable] {
override def append(f1: Throwable, f2: => Throwable): Throwable =
f1 match {
case me1: AbstractMultipleException =>
f2 match {
case me2: AbstractMultipleException => MultipleException(me1.throwableSet ++ me2.throwableSet)
case e: Throwable => MultipleException(me1.throwableSet + e)
}
case _: Throwable =>
f2 match {
case me2: AbstractMultipleException => MultipleException(me2.throwableSet + f1)
case `f1` => f1
case e: Throwable => MultipleException(Set(f1, e))
}
}
}

private final case class MultipleException(throwableSet: Set[Throwable])
extends DeepLearning.AbstractMultipleException

abstract class AbstractMultipleException extends RuntimeException("Multiple exceptions found") {

def throwableSet: Set[Throwable]

override def toString: String = throwableSet.mkString("\n")

override def printStackTrace(): Unit = {
for (throwable <- throwableSet) {
throwable.printStackTrace()
}
}

override def printStackTrace(s: PrintStream): Unit = {
for (throwable <- throwableSet) {
throwable.printStackTrace(s)
}
}

override def printStackTrace(s: PrintWriter): Unit = {
for (throwable <- throwableSet) {
throwable.printStackTrace(s)
}
}

override def getStackTrace: Array[StackTraceElement] = synchronized {
super.getStackTrace match {
case null =>
setStackTrace(throwableSet.flatMap(_.getStackTrace)(collection.breakOut))
super.getStackTrace
case stackTrace =>
stackTrace
}
}

override def fillInStackTrace(): this.type = {
this
}

}

/** The node of wengert list created during [[DeepLearning.forward forward]] pass */
final case class Tape[+Data, -Delta](data: Data, backward: Do[Delta] => UnitContinuation[Unit])

Expand Down
@@ -1,6 +1,7 @@
package com.thoughtworks.deeplearning.plugins

import com.thoughtworks.continuation._
import com.thoughtworks.future._
import com.thoughtworks.deeplearning.DeepLearning
import com.thoughtworks.deeplearning.DeepLearning.Tape
import com.thoughtworks.raii.asynchronous._
Expand All @@ -15,8 +16,7 @@ import scalaz.Semigroup

private object HLists {

implicit val doParallelApplicative =
asynchronousDoParallelApplicative(DeepLearning.multipleExceptionThrowableSemigroup)
implicit val doParallelApplicative = asynchronousDoParallelApplicative

private val noop: Do[HNil] => UnitContinuation[Unit] = {
Function.const(UnitContinuation.now(()))
Expand Down Expand Up @@ -53,7 +53,7 @@ trait HLists {

val doTail: ParallelDo[Tape[TailData, TailDelta]] = Parallel(tailDeepLearning.forward(tail))

Parallel.unwrap(Applicative[ParallelDo].tuple2(doHead, doTail)).map {
Parallel.unwrap(doParallelApplicative.tuple2(doHead, doTail)).map {
case (Tape(headData, headBackward), Tape(tailData, tailBackward)) =>
def backward(doDelta: Do[HeadDelta :: TailDelta]) = {
val continuationHead: ParallelContinuation[Unit] = Parallel(headBackward(doDelta.map(_.head)))
Expand Down

0 comments on commit d6a0300

Please sign in to comment.