Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/scala.yml
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,4 @@ jobs:
sbt 'testOnly gensym.TestImpCPSGS_Z3'
sbt 'testOnly gensym.TestLibrary'
sbt 'testOnly gensym.wasm.TestEval'
sbt 'testOnly gensym.wasm.TestScriptRun'
8 changes: 8 additions & 0 deletions benchmarks/wasm/script/script_basic.wast
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
(module
(func $one (result i32)
i32.const 1)
(export "one" (func 0))
)

(assert_return (invoke "one") (i32.const 1))

14 changes: 14 additions & 0 deletions src/main/scala/wasm/AST.scala
Original file line number Diff line number Diff line change
Expand Up @@ -292,3 +292,17 @@ case class ExportFunc(i: Int) extends ExportDesc
case class ExportTable(i: Int) extends ExportDesc
case class ExportMemory(i: Int) extends ExportDesc
case class ExportGlobal(i: Int) extends ExportDesc

case class Script(cmds: List[Cmd]) extends WIR
abstract class Cmd extends WIR
// TODO: can we turn abstract class sealed?
case class CmdModule(module: Module) extends Cmd

abstract class Action extends WIR
case class Invoke(instName: Option[String], name: String, args: List[Value]) extends Action

abstract class Assertion extends Cmd
case class AssertReturn(action: Action, expect: List[Num] /* TODO: support multiple expect result type*/)
extends Assertion
case class AssertTrap(action: Action, message: String) extends Assertion

66 changes: 39 additions & 27 deletions src/main/scala/wasm/MiniWasm.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,44 @@ case class ModuleInstance(
exports: List[Export] = List()
)

object ModuleInstance {
def apply(module: Module): ModuleInstance = {
val types = List()
val funcs = module.definitions
.collect({
case FuncDef(_, fndef @ FuncBodyDef(_, _, _, _)) => fndef
})
.toList

val globals = module.definitions
.collect({
case Global(_, GlobalValue(ty, e)) =>
(e.head) match {
case Const(c) => RTGlobal(ty, c)
// Q: What is the default behavior if case in non-exhaustive
case _ => ???
}
})
.toList

// TODO: correct the behavior for memory
val memory = module.definitions
.collect({
case Memory(id, MemoryType(min, max_opt)) =>
RTMemory(min, max_opt)
})
.toList

val exports = module.definitions
.collect({
case e @ Export(_, ExportFunc(_)) => e
})
.toList

ModuleInstance(types, module.funcEnv, memory, globals, exports)
}
}

object Primtives {
def evalBinOp(op: BinOp, lhs: Value, rhs: Value): Value = op match {
case Add(_) =>
Expand Down Expand Up @@ -412,33 +450,7 @@ object Evaluator {

if (instrs.isEmpty) println("Warning: nothing is executed")

val types = List()
val funcs = module.definitions
.collect({
case FuncDef(_, fndef @ FuncBodyDef(_, _, _, _)) => fndef
})
.toList

val globals = module.definitions
.collect({
case Global(_, GlobalValue(ty, e)) =>
(e.head) match {
case Const(c) => RTGlobal(ty, c)
// Q: What is the default behavior if case in non-exhaustive
case _ => ???
}
})
.toList

// TODO: correct the behavior for memory
val memory = module.definitions
.collect({
case Memory(id, MemoryType(min, max_opt)) =>
RTMemory(min, max_opt)
})
.toList

val moduleInst = ModuleInstance(types, module.funcEnv, memory, globals)
val moduleInst = ModuleInstance(module)

Evaluator.eval(instrs, List(), Frame(moduleInst, ArrayBuffer(I32V(0))), halt, List(halt))
}
Expand Down
49 changes: 49 additions & 0 deletions src/main/scala/wasm/MiniWasmScript.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package gensym.wasm.miniwasmscript

import gensym.wasm.miniwasm._
import gensym.wasm.ast._
import scala.collection.mutable.{ListBuffer, Map, ArrayBuffer}

sealed class ScriptRunner {
val instances: ListBuffer[ModuleInstance] = ListBuffer()
val instanceMap: Map[String, ModuleInstance] = Map()

def getInstance(instName: Option[String]): ModuleInstance = {
instName match {
case Some(name) => instanceMap(name)
case None => instances.head
}
}

def assertReturn(action: Action, expect: List[Value]): Unit = {
action match {
case Invoke(instName, name, args) =>
val module = getInstance(instName)
val func = module.exports.collectFirst({
case Export(`name`, ExportFunc(index)) =>
module.funcs(index)
case _ => throw new RuntimeException("Not Supported")
}).get
val instrs = func match {
case FuncDef(_, FuncBodyDef(ty, _, locals, body)) => body
}
val k = (retStack: List[Value]) => retStack
val actual = Evaluator.eval(instrs, List(), Frame(module, ArrayBuffer(args: _*)), k, List(k))
assert(actual == expect)
}
}

def runCmd(cmd: Cmd): Unit = {
cmd match {
case CmdModule(module) => instances += ModuleInstance(module)
case AssertReturn(action, expect) => assertReturn(action, expect)
case AssertTrap(action, message) => ???
}
}

def run(script: Script): Unit = {
for (cmd <- script.cmds) {
runCmd(cmd)
}
}
}
76 changes: 74 additions & 2 deletions src/main/scala/wasm/Parser.scala
Original file line number Diff line number Diff line change
Expand Up @@ -633,7 +633,59 @@ class GSWasmVisitor extends WatParserBaseVisitor[WIR] {
else if (ctx.MEMORY != null) ExportMemory(id)
else if (ctx.GLOBAL != null) ExportGlobal(id)
else error
}

override def visitScriptModule(ctx: ScriptModuleContext): Module = {
if (ctx.module_ != null) {
visitModule_(ctx.module_).asInstanceOf[Module]
} else {
throw new RuntimeException("Unsupported")
}
}

override def visitAction_(ctx: Action_Context): Action = {
if (ctx.INVOKE != null) {
val instName = if (ctx.VAR != null) Some(ctx.VAR().getText) else None
var name = ctx.name.getText.substring(1).dropRight(1)
var args = for (constCtx <- ctx.constList.wconst.asScala) yield {
val Array(ty, _) = constCtx.CONST.getText.split("\\.")
visitLiteralWithType(constCtx.literal, toNumType(ty))
}
Invoke(instName, name, args.toList)
} else {
throw new RuntimeException("Unsupported")
}
}

override def visitAssertion(ctx: AssertionContext): Assertion = {
if (ctx.ASSERT_RETURN != null) {
val action = visitAction_(ctx.action_)
val expect = for (constCtx <- ctx.constList.wconst.asScala) yield {
val Array(ty, _) = constCtx.CONST.getText.split("\\.")
visitLiteralWithType(constCtx.literal, toNumType(ty))
}
println(s"expect = $expect")
AssertReturn(action, expect.toList)
} else {
throw new RuntimeException("Unsupported")
}
}

override def visitCmd(ctx: CmdContext): Cmd = {
if (ctx.assertion != null) {
visitAssertion(ctx.assertion)
} else if (ctx.scriptModule != null) {
CmdModule(visitScriptModule(ctx.scriptModule))
} else {
throw new RuntimeException("Unsupported")
}
}

override def visitScript(ctx: ScriptContext): WIR = {
val cmds = for (cmd <- ctx.cmd.asScala) yield {
visitCmd(cmd)
}
Script(cmds.toList)
}

override def visitTag(ctx: TagContext): WIR = {
Expand All @@ -645,15 +697,35 @@ class GSWasmVisitor extends WatParserBaseVisitor[WIR] {
}

object Parser {
def parse(input: String): Module = {
private def makeWatVisitor(input: String) = {
val charStream = new ANTLRInputStream(input)
val lexer = new WatLexer(charStream)
val tokens = new CommonTokenStream(lexer)
val parser = new WatParser(tokens)
new WatParser(tokens)
}

def parse(input: String): Module = {
val parser = makeWatVisitor(input)
val visitor = new GSWasmVisitor()
val res: Module = visitor.visit(parser.module).asInstanceOf[Module]
res
}

def parseFile(filepath: String): Module = parse(scala.io.Source.fromFile(filepath).mkString)

// parse extended webassembly script language
def parseScript(input: String): Option[Script] = {
val parser = makeWatVisitor(input)
val visitor = new GSWasmVisitor()
val tree = parser.script()
val errorNumer = parser.getNumberOfSyntaxErrors()
if (errorNumer != 0) None
else {
val res: Script = visitor.visitScript(tree).asInstanceOf[Script]
Some(res)
}
}

def parseScriptFile(filepath: String): Option[Script] =
parseScript(scala.io.Source.fromFile(filepath).mkString)
}
19 changes: 19 additions & 0 deletions src/test/scala/genwasym/TestScriptRun.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package gensym.wasm

import gensym.wasm.parser.Parser
import gensym.wasm.miniwasmscript.ScriptRunner

import org.scalatest.FunSuite


class TestScriptRun extends FunSuite {
def testFile(filename: String): Unit = {
val script = Parser.parseScriptFile(filename).get
val runner = new ScriptRunner()
runner.run(script)
}

test("simple script") {
testFile("./benchmarks/wasm/script/script_basic.wast")
}
}
Loading