diff --git a/.github/workflows/scala.yml b/.github/workflows/scala.yml index 6594dc8ae..a691b7444 100644 --- a/.github/workflows/scala.yml +++ b/.github/workflows/scala.yml @@ -71,3 +71,4 @@ jobs: sbt 'testOnly gensym.TestImpCPSGS_Z3' sbt 'testOnly gensym.TestLibrary' sbt 'testOnly gensym.wasm.TestEval' + sbt 'testOnly gensym.wasm.TestScriptRun' diff --git a/benchmarks/wasm/script/script_basic.wast b/benchmarks/wasm/script/script_basic.wast new file mode 100644 index 000000000..4d1d1cb55 --- /dev/null +++ b/benchmarks/wasm/script/script_basic.wast @@ -0,0 +1,8 @@ +(module + (func $one (result i32) + i32.const 1) + (export "one" (func 0)) +) + +(assert_return (invoke "one") (i32.const 1)) + diff --git a/src/main/scala/wasm/AST.scala b/src/main/scala/wasm/AST.scala index bb4d606c2..fb4c3fdbf 100644 --- a/src/main/scala/wasm/AST.scala +++ b/src/main/scala/wasm/AST.scala @@ -292,3 +292,17 @@ case class ExportFunc(i: Int) extends ExportDesc case class ExportTable(i: Int) extends ExportDesc case class ExportMemory(i: Int) extends ExportDesc case class ExportGlobal(i: Int) extends ExportDesc + +case class Script(cmds: List[Cmd]) extends WIR +abstract class Cmd extends WIR +// TODO: can we turn abstract class sealed? +case class CmdModule(module: Module) extends Cmd + +abstract class Action extends WIR +case class Invoke(instName: Option[String], name: String, args: List[Value]) extends Action + +abstract class Assertion extends Cmd +case class AssertReturn(action: Action, expect: List[Num] /* TODO: support multiple expect result type*/) + extends Assertion +case class AssertTrap(action: Action, message: String) extends Assertion + diff --git a/src/main/scala/wasm/MiniWasm.scala b/src/main/scala/wasm/MiniWasm.scala index 642936242..ba2b7a85b 100644 --- a/src/main/scala/wasm/MiniWasm.scala +++ b/src/main/scala/wasm/MiniWasm.scala @@ -18,6 +18,44 @@ case class ModuleInstance( exports: List[Export] = List() ) +object ModuleInstance { + def apply(module: Module): ModuleInstance = { + val types = List() + val funcs = module.definitions + .collect({ + case FuncDef(_, fndef @ FuncBodyDef(_, _, _, _)) => fndef + }) + .toList + + val globals = module.definitions + .collect({ + case Global(_, GlobalValue(ty, e)) => + (e.head) match { + case Const(c) => RTGlobal(ty, c) + // Q: What is the default behavior if case in non-exhaustive + case _ => ??? + } + }) + .toList + + // TODO: correct the behavior for memory + val memory = module.definitions + .collect({ + case Memory(id, MemoryType(min, max_opt)) => + RTMemory(min, max_opt) + }) + .toList + + val exports = module.definitions + .collect({ + case e @ Export(_, ExportFunc(_)) => e + }) + .toList + + ModuleInstance(types, module.funcEnv, memory, globals, exports) + } +} + object Primtives { def evalBinOp(op: BinOp, lhs: Value, rhs: Value): Value = op match { case Add(_) => @@ -412,33 +450,7 @@ object Evaluator { if (instrs.isEmpty) println("Warning: nothing is executed") - val types = List() - val funcs = module.definitions - .collect({ - case FuncDef(_, fndef @ FuncBodyDef(_, _, _, _)) => fndef - }) - .toList - - val globals = module.definitions - .collect({ - case Global(_, GlobalValue(ty, e)) => - (e.head) match { - case Const(c) => RTGlobal(ty, c) - // Q: What is the default behavior if case in non-exhaustive - case _ => ??? - } - }) - .toList - - // TODO: correct the behavior for memory - val memory = module.definitions - .collect({ - case Memory(id, MemoryType(min, max_opt)) => - RTMemory(min, max_opt) - }) - .toList - - val moduleInst = ModuleInstance(types, module.funcEnv, memory, globals) + val moduleInst = ModuleInstance(module) Evaluator.eval(instrs, List(), Frame(moduleInst, ArrayBuffer(I32V(0))), halt, List(halt)) } diff --git a/src/main/scala/wasm/MiniWasmScript.scala b/src/main/scala/wasm/MiniWasmScript.scala new file mode 100644 index 000000000..26103f080 --- /dev/null +++ b/src/main/scala/wasm/MiniWasmScript.scala @@ -0,0 +1,49 @@ +package gensym.wasm.miniwasmscript + +import gensym.wasm.miniwasm._ +import gensym.wasm.ast._ +import scala.collection.mutable.{ListBuffer, Map, ArrayBuffer} + +sealed class ScriptRunner { + val instances: ListBuffer[ModuleInstance] = ListBuffer() + val instanceMap: Map[String, ModuleInstance] = Map() + + def getInstance(instName: Option[String]): ModuleInstance = { + instName match { + case Some(name) => instanceMap(name) + case None => instances.head + } + } + + def assertReturn(action: Action, expect: List[Value]): Unit = { + action match { + case Invoke(instName, name, args) => + val module = getInstance(instName) + val func = module.exports.collectFirst({ + case Export(`name`, ExportFunc(index)) => + module.funcs(index) + case _ => throw new RuntimeException("Not Supported") + }).get + val instrs = func match { + case FuncDef(_, FuncBodyDef(ty, _, locals, body)) => body + } + val k = (retStack: List[Value]) => retStack + val actual = Evaluator.eval(instrs, List(), Frame(module, ArrayBuffer(args: _*)), k, List(k)) + assert(actual == expect) + } + } + + def runCmd(cmd: Cmd): Unit = { + cmd match { + case CmdModule(module) => instances += ModuleInstance(module) + case AssertReturn(action, expect) => assertReturn(action, expect) + case AssertTrap(action, message) => ??? + } + } + + def run(script: Script): Unit = { + for (cmd <- script.cmds) { + runCmd(cmd) + } + } +} diff --git a/src/main/scala/wasm/Parser.scala b/src/main/scala/wasm/Parser.scala index a2019de26..079370e72 100644 --- a/src/main/scala/wasm/Parser.scala +++ b/src/main/scala/wasm/Parser.scala @@ -633,7 +633,59 @@ class GSWasmVisitor extends WatParserBaseVisitor[WIR] { else if (ctx.MEMORY != null) ExportMemory(id) else if (ctx.GLOBAL != null) ExportGlobal(id) else error + } + + override def visitScriptModule(ctx: ScriptModuleContext): Module = { + if (ctx.module_ != null) { + visitModule_(ctx.module_).asInstanceOf[Module] + } else { + throw new RuntimeException("Unsupported") + } + } + + override def visitAction_(ctx: Action_Context): Action = { + if (ctx.INVOKE != null) { + val instName = if (ctx.VAR != null) Some(ctx.VAR().getText) else None + var name = ctx.name.getText.substring(1).dropRight(1) + var args = for (constCtx <- ctx.constList.wconst.asScala) yield { + val Array(ty, _) = constCtx.CONST.getText.split("\\.") + visitLiteralWithType(constCtx.literal, toNumType(ty)) + } + Invoke(instName, name, args.toList) + } else { + throw new RuntimeException("Unsupported") + } + } + + override def visitAssertion(ctx: AssertionContext): Assertion = { + if (ctx.ASSERT_RETURN != null) { + val action = visitAction_(ctx.action_) + val expect = for (constCtx <- ctx.constList.wconst.asScala) yield { + val Array(ty, _) = constCtx.CONST.getText.split("\\.") + visitLiteralWithType(constCtx.literal, toNumType(ty)) + } + println(s"expect = $expect") + AssertReturn(action, expect.toList) + } else { + throw new RuntimeException("Unsupported") + } + } + override def visitCmd(ctx: CmdContext): Cmd = { + if (ctx.assertion != null) { + visitAssertion(ctx.assertion) + } else if (ctx.scriptModule != null) { + CmdModule(visitScriptModule(ctx.scriptModule)) + } else { + throw new RuntimeException("Unsupported") + } + } + + override def visitScript(ctx: ScriptContext): WIR = { + val cmds = for (cmd <- ctx.cmd.asScala) yield { + visitCmd(cmd) + } + Script(cmds.toList) } override def visitTag(ctx: TagContext): WIR = { @@ -645,15 +697,35 @@ class GSWasmVisitor extends WatParserBaseVisitor[WIR] { } object Parser { - def parse(input: String): Module = { + private def makeWatVisitor(input: String) = { val charStream = new ANTLRInputStream(input) val lexer = new WatLexer(charStream) val tokens = new CommonTokenStream(lexer) - val parser = new WatParser(tokens) + new WatParser(tokens) + } + + def parse(input: String): Module = { + val parser = makeWatVisitor(input) val visitor = new GSWasmVisitor() val res: Module = visitor.visit(parser.module).asInstanceOf[Module] res } def parseFile(filepath: String): Module = parse(scala.io.Source.fromFile(filepath).mkString) + + // parse extended webassembly script language + def parseScript(input: String): Option[Script] = { + val parser = makeWatVisitor(input) + val visitor = new GSWasmVisitor() + val tree = parser.script() + val errorNumer = parser.getNumberOfSyntaxErrors() + if (errorNumer != 0) None + else { + val res: Script = visitor.visitScript(tree).asInstanceOf[Script] + Some(res) + } + } + + def parseScriptFile(filepath: String): Option[Script] = + parseScript(scala.io.Source.fromFile(filepath).mkString) } diff --git a/src/test/scala/genwasym/TestScriptRun.scala b/src/test/scala/genwasym/TestScriptRun.scala new file mode 100644 index 000000000..2c1e3f9d2 --- /dev/null +++ b/src/test/scala/genwasym/TestScriptRun.scala @@ -0,0 +1,19 @@ +package gensym.wasm + +import gensym.wasm.parser.Parser +import gensym.wasm.miniwasmscript.ScriptRunner + +import org.scalatest.FunSuite + + +class TestScriptRun extends FunSuite { + def testFile(filename: String): Unit = { + val script = Parser.parseScriptFile(filename).get + val runner = new ScriptRunner() + runner.run(script) + } + + test("simple script") { + testFile("./benchmarks/wasm/script/script_basic.wast") + } +}