Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 74 additions & 70 deletions src/main/scala/wasm/MiniWasmFX.scala
Original file line number Diff line number Diff line change
Expand Up @@ -15,49 +15,50 @@ case class EvaluatorFX(module: ModuleInstance) {
type Stack = List[Value]

trait Cont[A] {
def apply(stack: Stack, trail: Trail[A], mcont: MCont[A], handler: Handlers[A]): A
def apply(stack: Stack, trail: Trail[A], handler: Handlers[A]): A
}
type Trail[A] = List[(Cont[A], List[Int])] // trail items are pairs of continuation and tags
type MCont[A] = Stack => A

type Handler[A] = Stack => A
type Handlers[A] = List[(Int, Handler[A])]

case class ContV[A](k: (Stack, Cont[A], Trail[A], MCont[A], Handlers[A]) => A) extends Value {
case class ContV[A](k: (Stack, Cont[A], Trail[A], Handlers[A]) => A) extends Value {
def tipe(implicit m: ModuleInstance): ValueType = ???

// override def toString: String = "ContV"
}

// initK is a continuation that simply returns the inputed stack
def initK[Ans](s: Stack, trail: Trail[Ans], mkont: MCont[Ans], hs: Handlers[Ans]): Ans =
def initK[Ans](s: Stack, trail: Trail[Ans], hs: Handlers[Ans]): Ans =
trail match {
case (k1, _) :: trail => k1(s, trail, mkont, hs)
case Nil => mkont(s)
// Currently, the last element of the Trail is the halt continuation
// the exception will never be thrown
case (k1, _) :: trail => k1(s, trail, hs)
case Nil => throw new Exception("No halting continuation in trail")
}

def eval1[Ans](inst: Instr, stack: Stack, frame: Frame,
kont: Cont[Ans], trail: Trail[Ans], mkont: MCont[Ans],
brTable: List[Cont[Ans]], hs: Handlers[Ans]): Ans = {
def eval1[Ans](inst: Instr, stack: Stack, frame: Frame, kont: Cont[Ans],
trail: Trail[Ans], brTable: List[Cont[Ans]], hs: Handlers[Ans]): Ans = {
// System.err.println(f"[DEBUG] ${inst} | ${frame} | ${stack.reverse} | handlers: ${hs}");
inst match {
case Drop => kont(stack.tail, trail, mkont, hs)
case Drop => kont(stack.tail, trail, hs)
case Select(_) =>
val I32V(cond) :: v2 :: v1 :: newStack = stack
val value = if (cond == 0) v1 else v2
kont(value :: newStack, trail, mkont, hs)
kont(value :: newStack, trail, hs)
case LocalGet(i) =>
kont(frame.locals(i) :: stack, trail, mkont, hs)
kont(frame.locals(i) :: stack, trail, hs)
case LocalSet(i) =>
val value :: newStack = stack
frame.locals(i) = value
kont(newStack, trail, mkont, hs)
kont(newStack, trail, hs)
case LocalTee(i) =>
val value :: newStack = stack
frame.locals(i) = value
kont(stack, trail, mkont, hs)
kont(stack, trail, hs)
case GlobalGet(i) =>
kont(module.globals(i).value :: stack, trail, mkont, hs)
kont(module.globals(i).value :: stack, trail, hs)
case GlobalSet(i) =>
val value :: newStack = stack
module.globals(i).ty match {
Expand All @@ -66,107 +67,107 @@ case class EvaluatorFX(module: ModuleInstance) {
case GlobalType(_, true) => throw new Exception("Invalid type")
case _ => throw new Exception("Cannot set immutable global")
}
kont(newStack, trail, mkont, hs)
kont(newStack, trail, hs)
case MemorySize =>
kont(I32V(module.memory.head.size) :: stack, trail, mkont, hs)
kont(I32V(module.memory.head.size) :: stack, trail, hs)
case MemoryGrow =>
val I32V(delta) :: newStack = stack
val mem = module.memory.head
val oldSize = mem.size
mem.grow(delta) match {
case Some(e) => kont(I32V(-1) :: newStack, trail, mkont, hs)
case _ => kont(I32V(oldSize) :: newStack, trail, mkont, hs)
case Some(e) => kont(I32V(-1) :: newStack, trail, hs)
case _ => kont(I32V(oldSize) :: newStack, trail, hs)
}
case MemoryFill =>
val I32V(value) :: I32V(offset) :: I32V(size) :: newStack = stack
if (memOutOfBound(module, 0, offset, size))
throw new Exception("Out of bounds memory access") // GW: turn this into a `trap`?
else {
module.memory.head.fill(offset, size, value.toByte)
kont(newStack, trail, mkont, hs)
kont(newStack, trail, hs)
}
case MemoryCopy =>
val I32V(n) :: I32V(src) :: I32V(dest) :: newStack = stack
if (memOutOfBound(module, 0, src, n) || memOutOfBound(module, 0, dest, n))
throw new Exception("Out of bounds memory access")
else {
module.memory.head.copy(dest, src, n)
kont(newStack, trail, mkont, hs)
kont(newStack, trail, hs)
}
case Const(n) => kont(n :: stack, trail, mkont, hs)
case Const(n) => kont(n :: stack, trail, hs)
case Binary(op) =>
val v2 :: v1 :: newStack = stack
kont(evalBinOp(op, v1, v2) :: newStack, trail, mkont, hs)
kont(evalBinOp(op, v1, v2) :: newStack, trail, hs)
case Unary(op) =>
val v :: newStack = stack
kont(evalUnaryOp(op, v) :: newStack, trail, mkont, hs)
kont(evalUnaryOp(op, v) :: newStack, trail, hs)
case Compare(op) =>
val v2 :: v1 :: newStack = stack
kont(evalRelOp(op, v1, v2) :: newStack, trail, mkont, hs)
kont(evalRelOp(op, v1, v2) :: newStack, trail, hs)
case Test(op) =>
val v :: newStack = stack
kont(evalTestOp(op, v) :: newStack, trail, mkont, hs)
kont(evalTestOp(op, v) :: newStack, trail, hs)
case Store(StoreOp(align, offset, ty, None)) =>
val I32V(v) :: I32V(addr) :: newStack = stack
module.memory(0).storeInt(addr + offset, v)
kont(newStack, trail, mkont, hs)
kont(newStack, trail, hs)
case Load(LoadOp(align, offset, ty, None, None)) =>
val I32V(addr) :: newStack = stack
val value = module.memory(0).loadInt(addr + offset)
kont(I32V(value) :: newStack, trail, mkont, hs)
case Nop => kont(stack, trail, mkont, hs)
kont(I32V(value) :: newStack, trail, hs)
case Nop => kont(stack, trail, hs)
case Unreachable => throw Trap()
case Block(ty, inner) =>
val funcTy = getFuncType(ty)
val (inputs, restStack) = stack.splitAt(funcTy.inps.size)
val escape: Cont[Ans] = (s1, t1, m1, h1) => kont(s1.take(funcTy.out.size) ++ restStack, t1, m1, h1)
evalList(inner, inputs, frame, escape, trail, mkont, escape::brTable, hs)
val escape: Cont[Ans] = (s1, t1, h1) => kont(s1.take(funcTy.out.size) ++ restStack, t1, h1)
evalList(inner, inputs, frame, escape, trail, escape::brTable, hs)
case Loop(ty, inner) =>
val funcTy = getFuncType(ty)
val (inputs, restStack) = stack.splitAt(funcTy.inps.size)
val escape: Cont[Ans] = (s1, t1, m1, h1) => kont(s1.take(funcTy.out.size) ++ restStack, t1, m1, h1)
def loop(retStack: List[Value], trail1: Trail[Ans], mkont: MCont[Ans], h1: Handlers[Ans]): Ans =
evalList(inner, retStack.take(funcTy.inps.size), frame, escape, trail, mkont, (loop _ : Cont[Ans])::brTable, h1)
loop(inputs, trail, mkont, hs)
val escape: Cont[Ans] = (s1, t1, h1) => kont(s1.take(funcTy.out.size) ++ restStack, t1, h1)
def loop(retStack: List[Value], trail1: Trail[Ans], h1: Handlers[Ans]): Ans =
evalList(inner, retStack.take(funcTy.inps.size), frame, escape, trail, (loop _ : Cont[Ans])::brTable, h1)
loop(inputs, trail, hs)
case If(ty, thn, els) =>
val funcTy = getFuncType(ty)
val I32V(cond) :: newStack = stack
val inner = if (cond != 0) thn else els
val (inputs, restStack) = newStack.splitAt(funcTy.inps.size)
val escape: Cont[Ans] = (s1, t1, m1, h1) => kont(s1.take(funcTy.out.size) ++ restStack, t1, m1, h1)
evalList(inner, inputs, frame, escape, trail, mkont, escape::brTable, hs)
val escape: Cont[Ans] = (s1, t1, h1) => kont(s1.take(funcTy.out.size) ++ restStack, t1, h1)
evalList(inner, inputs, frame, escape, trail, escape::brTable, hs)
case Br(label) =>
brTable(label)(stack, trail, mkont, hs)
brTable(label)(stack, trail, hs)
case BrIf(label) =>
val I32V(cond) :: newStack = stack
if (cond != 0) brTable(label)(newStack, trail, mkont, hs)
else kont(newStack, trail, mkont, hs)
if (cond != 0) brTable(label)(newStack, trail, hs)
else kont(newStack, trail, hs)
case BrTable(labels, default) =>
val I32V(cond) :: newStack = stack
val goto = if (cond < labels.length) labels(cond) else default
brTable(goto)(newStack, trail, mkont, hs)
brTable(goto)(newStack, trail, hs)
case Return =>
brTable.last(stack, trail, mkont, hs)
case Call(f) => evalCall1(f, stack, frame, kont, trail, mkont, brTable, hs, false)
brTable.last(stack, trail, hs)
case Call(f) => evalCall1(f, stack, frame, kont, trail, brTable, hs, false)
case ReturnCall(f) =>
// System.err.println(s"[DEBUG] return call: $f")
evalCall1(f, stack, frame, kont, trail, mkont, brTable, hs, true)
evalCall1(f, stack, frame, kont, trail, brTable, hs, true)
case RefFunc(f) =>
// TODO: RefFuncV stores an applicable function, instead of a syntactic structure
kont(RefFuncV(f) :: stack, trail, mkont, hs)
kont(RefFuncV(f) :: stack, trail, hs)
// WasmFX effect handlers:
case ContNew(ty) =>
val RefFuncV(f) :: newStack = stack
def kr(s: Stack, k1: Cont[Ans], t1: Trail[Ans], m1: MCont[Ans], hs: Handlers[Ans]): Ans = {
evalCall1(f, s, frame/*?*/, k1, t1, m1, List(), hs, false)
def kr(s: Stack, k1: Cont[Ans], t1: Trail[Ans], hs: Handlers[Ans]): Ans = {
evalCall1(f, s, frame/*?*/, k1, t1, List(), hs, false)
}
kont(ContV(kr) :: newStack, trail, mkont, hs)
kont(ContV(kr) :: newStack, trail, hs)
case Suspend(tagId) =>
val FuncType(_, inps, out) = module.tags(tagId)
val (inputs, restStack) = stack.splitAt(inps.size)
// System.err.println(s"[DEBUG] handlers: $hs")
// System.err.println(s"[DEBUG] trail: $trail")
val kr = (s: Stack, _: Cont[Ans], t1: Trail[Ans], m1: MCont[Ans], hs1: Handlers[Ans]) => {
val kr = (s: Stack, _: Cont[Ans], t1: Trail[Ans], hs1: Handlers[Ans]) => {
// construct a new trail by ignoring the default handler
val index = trail.indexWhere { case (_, tags) => tags.contains(tagId) }
val newTrail = if (index >= 0) trail.take(index) else trail
Expand All @@ -175,7 +176,7 @@ case class EvaluatorFX(module: ModuleInstance) {
// Q: Should we clear tags in the `newTrail`? Is that possible suspend target tag in hs1 but also in newTrail?
// A: Yes, we should maintain the consistency between `hs1` and `newTrail + t1`.
// mkont lost here, and it's safe if we never modify it
kont(s ++ restStack, newTrail.map({ case (c, _) => (c, List()) }) ++ t1, m1, hs1)
kont(s ++ restStack, newTrail.map({ case (c, _) => (c, List()) }) ++ t1, hs1)
}
val newStack = ContV(kr) :: inputs
hs.find(_._1 == tagId) match {
Expand All @@ -191,12 +192,12 @@ case class EvaluatorFX(module: ModuleInstance) {
val (inputs, restStack) = newStack.splitAt(inps.size)
val newHs: List[(Int, Handler[Ans])] = handler.map {
case Handler(tagId, labelId) =>
val hh: Handler[Ans] = s1 => brTable(labelId)(s1, trail, mkont/*???*/, hs)
val hh: Handler[Ans] = s1 => brTable(labelId)(s1, trail, hs)
(tagId, hh)
}
val tags = handler.map(_.tag)
// rather than push `kont` to meta-continuation, maybe we can push it to `trail`?
f.k(inputs, initK, List((kont,tags)) ++ trail, mkont, newHs ++ hs)
f.k(inputs, initK, List((kont,tags)) ++ trail, newHs ++ hs)

case ContBind(oldContTyId, newConTyId) =>
val (f: ContV[Ans]) :: newStack = stack
Expand All @@ -209,29 +210,28 @@ case class EvaluatorFX(module: ModuleInstance) {
val inputSize = oldParamTy.size - newParamTy.size
val (inputs, restStack) = newStack.splitAt(inputSize)
// partially apply the old continuation
def kr(s: Stack, k1: Cont[Ans], t1: Trail[Ans], mk: MCont[Ans], handlers: Handlers[Ans]): Ans = {
f.k(s ++ inputs, k1, t1, mk, handlers)
def kr(s: Stack, k1: Cont[Ans], t1: Trail[Ans], handlers: Handlers[Ans]): Ans = {
f.k(s ++ inputs, k1, t1, handlers)
}
kont(ContV(kr) :: restStack, trail, mkont, hs)
kont(ContV(kr) :: restStack, trail, hs)

case CallRef(ty) =>
val RefFuncV(f) :: newStack = stack
evalCall1(f, newStack, frame, kont, trail, mkont, brTable, hs, false)
evalCall1(f, newStack, frame, kont, trail, brTable, hs, false)

case _ =>
println(inst)
throw new Exception(s"instruction $inst not implemented")
}
}

def evalList[Ans](insts: List[Instr], stack: Stack, frame: Frame,
kont: Cont[Ans], trail1: Trail[Ans], mkont: MCont[Ans],
brTable: List[Cont[Ans]], hs: Handlers[Ans]): Ans = {
def evalList[Ans](insts: List[Instr], stack: Stack, frame: Frame, kont: Cont[Ans],
trail1: Trail[Ans], brTable: List[Cont[Ans]], hs: Handlers[Ans]): Ans = {
insts match {
case Nil => kont(stack, trail1, mkont, hs)
case Nil => kont(stack, trail1, hs)
case inst :: rest =>
val newKont: Cont[Ans] = (s1, t1, m1, h1) => evalList(rest, s1, frame, kont, t1, m1, brTable, h1)
eval1(inst, stack, frame, newKont, trail1, mkont, brTable, hs)
val newKont: Cont[Ans] = (s1, t1, h1) => evalList(rest, s1, frame, kont, t1, brTable, h1)
eval1(inst, stack, frame, newKont, trail1, brTable, hs)
}
}

Expand All @@ -240,7 +240,6 @@ case class EvaluatorFX(module: ModuleInstance) {
frame: Frame,
kont: Cont[Ans],
trail: Trail[Ans],
mkont: MCont[Ans],
brTable: List[Cont[Ans]], // can be removed
h: Handlers[Ans],
isTail: Boolean): Ans =
Expand All @@ -252,31 +251,31 @@ case class EvaluatorFX(module: ModuleInstance) {
val newFrame = Frame(ArrayBuffer(frameLocals: _*))
if (isTail) {
// when tail call, share the continuation for returning with the callee
evalList(body, List(), newFrame, brTable.last, trail, mkont, List(brTable.last), h)
evalList(body, List(), newFrame, brTable.last, trail, List(brTable.last), h)
}
else {
val restK: Cont[Ans] = (s1, t1, m1, h1) => kont(s1.take(ty.out.size) ++ newStack, t1, m1, h1)
val restK: Cont[Ans] = (s1, t1, h1) => kont(s1.take(ty.out.size) ++ newStack, t1, h1)
// We make a new brTable by `restK`, since function creates a new block to escape
// (more or less like `return`)
evalList(body, List(), newFrame, restK, trail, mkont, List(restK), h)
evalList(body, List(), newFrame, restK, trail, List(restK), h)
}
case Import("console", "log", _) =>
// println(s"[DEBUG] current stack: $stack")
val I32V(v) :: newStack = stack
println(v)
kont(newStack, trail, mkont, h)
kont(newStack, trail, h)
case Import("spectest", "print_i32", _) =>
// println(s"[DEBUG] current stack: $stack")
val I32V(v) :: newStack = stack
println(v)
kont(newStack, trail, mkont, h)
kont(newStack, trail, h)
case Import(_, _, _) => throw new Exception(s"Unknown import at $funcIndex")
case _ => throw new Exception(s"Definition at $funcIndex is not callable")
}

// If `main` is given, then we use that function as the entry point of the program;
// otherwise, we look up the top-level `start` instruction to locate the entry point.
def evalTop[Ans](halt: Cont[Ans], mhalt: MCont[Ans], main: Option[String] = None): Ans = {
def evalTop[Ans](halt: Cont[Ans], main: Option[String] = None): Ans = {
val instrs = main match {
case Some(func_name) =>
module.defs.flatMap({
Expand Down Expand Up @@ -326,9 +325,14 @@ case class EvaluatorFX(module: ModuleInstance) {
if (instrs.isEmpty) println("Warning: nothing is executed")
// initialized locals
val frame = Frame(ArrayBuffer(locals.map(zero(_)): _*))
evalList(instrs, List(), frame, halt, List(), mhalt, List(halt), List())
evalList(instrs, List(), frame, initK[Ans], List((halt, List())), List(initK: Cont[Ans]), List())
}

def evalTop(m: ModuleInstance): Unit = evalTop(initK[Unit], stack => ())
def evalTop(m: ModuleInstance): Unit =
evalTop(((stack, trail, _hs) => {
if (!trail.isEmpty) {
throw new Exception("Composing something after halt continuation")
}
}): Cont[Unit])
}

4 changes: 2 additions & 2 deletions src/main/scala/wasm/MiniWasmScript.scala
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,9 @@ sealed class ScriptRunner {
type MCont = evaluator.MCont[evaluator.Stack]
type Handler = evaluator.Handler[evaluator.Stack]
val k: Cont = evaluator.initK
val mk: MCont = (retStack) => retStack
val halt: Cont = (retStack, _, _) => retStack
// Note: change this back to Evaluator if we are just testing original stuff
evaluator.evalList(instrs, List(), Frame(ArrayBuffer(args: _*)), k, List(), mk, List(k), List())
evaluator.evalList(instrs, List(), Frame(ArrayBuffer(args: _*)), k, List((halt, List())), List(k), List())
}

def runCmd(cmd: Cmd): Unit = {
Expand Down
10 changes: 7 additions & 3 deletions src/test/scala/genwasym/TestFx.scala
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,20 @@ class TestFx extends FunSuite {
val evaluator = EvaluatorFX(ModuleInstance(module))
type Cont = evaluator.Cont[Unit]
type MCont = evaluator.MCont[Unit]
val haltK: Cont = evaluator.initK
val haltMK: MCont = (stack) => {
val haltK: Cont = (stack, trail, _hs) => {
if (!trail.isEmpty) {
// this throw will never reach, trail will never been appended
System.err.println(s"[Debug]: $trail")
throw new Exception("Trail is not empty")
}
// println(s"halt cont: $stack")
expected match {
case ExpInt(e) => assert(stack(0) == I32V(e))
case ExpStack(e) => assert(stack == e)
case Ignore => ()
}
}
evaluator.evalTop(haltK, haltMK, main)
evaluator.evalTop(haltK, main)
}

// So far it assumes that the output is multi-line integers
Expand Down