Skip to content

Commit

Permalink
Allow the user to define ShouldResetNestedFunctions = true to reset…
Browse files Browse the repository at this point in the history
… nested functions
  • Loading branch information
Atry committed Dec 24, 2021
1 parent 6a51ad8 commit 53e395e
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 25 deletions.
85 changes: 60 additions & 25 deletions reset/src/main/scala/com/thoughtworks/dsl/reset.scala
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,27 @@ import scala.util.control.Exception.Catcher
* r % 10 should not be r / 10
* }}}
*/
object reset {
trait reset:
type ShouldResetNestedFunctions <: Boolean & Singleton

transparent inline def reify[Value](inline value: Value): Any = ${
reset.Macros.reify[ShouldResetNestedFunctions, Value]('value)
}

class *[Functor[_]]() {
inline def apply[Value](inline value: Value): Functor[Value] = ${
reset.Macros.reset[ShouldResetNestedFunctions, Value, Functor[Value]]('value)
}
}
inline def *[Domain[_]]: *[Domain] = new *[Domain]

inline def apply[Value](inline value: Value): Value = ${
reset.Macros.reset[ShouldResetNestedFunctions, Value, Value]('value)
}

object reset extends reset.DefaultOptions {
trait DefaultOptions extends reset:
type ShouldResetNestedFunctions = false

private class Macros[Q <: Quotes](resetDescendant: Boolean)(using val qctx: Q) {
import qctx.reflect.{_, given}
Expand Down Expand Up @@ -785,32 +805,47 @@ object reset {
}

object Macros {
def reify[V](body: quoted.Expr[_])(using qctx: Quotes, tv: quoted.Type[V]): quoted.Expr[_] = {
Macros[qctx.type](resetDescendant = false).reify[V](body/*.underlyingArgument*/)
}

def reset[From, To](body: quoted.Expr[From])(using qctx: Quotes, fromType: quoted.Type[From], toType: quoted.Type[To]): quoted.Expr[To] = {
import qctx.reflect.{_, given}
val result: quoted.Expr[To] = Macros[qctx.type](resetDescendant = false).reset(body/*.underlyingArgument*/)
// report.warning(result.asTerm.show(using qctx.reflect.Printer.TreeStructure))
// report.warning(result.asTerm.show)
result
def reify[ShouldResetNestedFunctions <: Boolean & Singleton, V](
body: quoted.Expr[_]
)(using
qctx: Quotes,
translateNestedFunctions: quoted.Type[ShouldResetNestedFunctions],
tv: quoted.Type[V]
): quoted.Expr[_] = {
import quoted.quotes.reflect.*
quoted.Type.valueOfConstant[ShouldResetNestedFunctions] match {
case None =>
report.error("ShouldResetNestedFunctions is not defined", body)
'{ ??? }
case Some(translateNestedFunction) =>
Macros[qctx.type](resetDescendant =
quoted.Type.valueOfConstant[ShouldResetNestedFunctions].get
).reify[V](body /*.underlyingArgument*/ )
}
}
}

transparent inline def reify[Value](inline value: Value): Any = ${
Macros.reify[Value]('value)
}

class *[Functor[_]]() {
inline def apply[Value](inline value: Value): Functor[Value] = ${
Macros.reset[Value, Functor[Value]]('value)
def reset[ShouldResetNestedFunctions <: Boolean & Singleton, From, To](
body: quoted.Expr[From]
)(using
qctx: Quotes,
translateNestedFunctions: quoted.Type[ShouldResetNestedFunctions],
fromType: quoted.Type[From],
toType: quoted.Type[To]
): quoted.Expr[To] = {
import quoted.quotes.reflect.{_, given}
quoted.Type.valueOfConstant[ShouldResetNestedFunctions] match {
case None =>
report.error("ShouldResetNestedFunctions is not defined", body)
'{ ??? }
case Some(translateNestedFunction) =>
val result = Macros[qctx.type](resetDescendant =
quoted.Type.valueOfConstant[ShouldResetNestedFunctions].get
).reset[From, To](body /*.underlyingArgument*/ )
// report.warning(result.asTerm.show(using qctx.reflect.Printer.TreeStructure))
// report.warning(result.asTerm.show)
result
}
}
}
inline def *[Domain[_]]: *[Domain] = new *[Domain]

inline def apply[Value](inline value: Value): Value = ${
Macros.reset[Value, Value]('value)
}

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,18 @@ object ReturnSpec extends TestSuite {
assert(result == 42)
}

"reset nested function" - {
new reset {
type ShouldResetNestedFunctions = true
}.apply {
def continuation = { (!Return(42)): Int!!String }
val result = continuation { s =>
throw new java.lang.AssertionError(s)
}
assert(result == 42)
}
}

"return the right domain" - {
def continuation: Int !! String = reset[Int !! String]{!Return("right value") }

Expand Down

0 comments on commit 53e395e

Please sign in to comment.