diff --git a/README.md b/README.md index d3498d8..0239e64 100644 --- a/README.md +++ b/README.md @@ -2,5 +2,45 @@ This project is part of the https://compilerprogramming.github.io/ project. -The EZ (pronounced EeeZee) programming language is designed to allow us to learn various compiler techniques. +The EZ (pronounced EeZee) programming language is designed to allow us to learn various compiler techniques. + +## The EZ Language + +The EZ programming language is a tiny statically typed language with syntax inspired by Swift. +The language has the following features: + +* Integer, Struct and 1-Dimensional Array types +* If and While statements +* Functions + +The language syntax is described [ANTLR Grammar](antlr-parser/src/main/antlr4/com/compilerprogramming/ezlang/antlr/EZLanguage.g4). +The language is intentionally very simple and is meant to have just enough functionality to experiment with compiler implementation techniques. + +## Modules + +The project is under development and subject to change. At this point in time, we have following initial implementations: + +* lexer - a simple tokenizer +* parser - a recursive descent parser and AST +* types - the type definitions +* semantic - semantic analyzer +* stackvm - a bytecode compiler that generates stack IR (bytecode interpreter not yet available) +* registervm - a bytecode compiler that generates a linear register IR and a bytecode interpreter that can execute the IR + +## How can you contribute? + +Obviously firstly any contributes that improve and fix bugs are welcome. I am not keen on language extensions at this stage, but eventually +we will be extending the language to explore more advanced features. + +I am also interested in creating implementations of this project in C++, Go, Rust, swift, D, C, etc. If you are interested in working on such a +port please contact me via [Discussions](https://github.com/orgs/CompilerProgramming/discussions). + +## Community Discussions + +There is a [community discussion forum](https://github.com/orgs/CompilerProgramming/discussions). + +## What's next + +The project has only just got started, there is lots to do!. See the plan in the [website](https://compilerprogramming.github.io/). More documentation to follow, but for now please refer to the source code and the site above. + diff --git a/common/src/main/java/com/compilerprogramming/ezlang/exceptions/InterpreterException.java b/common/src/main/java/com/compilerprogramming/ezlang/exceptions/InterpreterException.java new file mode 100644 index 0000000..a08f059 --- /dev/null +++ b/common/src/main/java/com/compilerprogramming/ezlang/exceptions/InterpreterException.java @@ -0,0 +1,8 @@ +package com.compilerprogramming.ezlang.exceptions; + +public class InterpreterException extends RuntimeException { + public InterpreterException(String message) {super(message);} + public InterpreterException(String message, Throwable cause) { + super(message, cause); + } +} diff --git a/registervm/src/main/java/com/compilerprogramming/ezlang/bytecode/RegisterVMCompiler.java b/registervm/src/main/java/com/compilerprogramming/ezlang/bytecode/BytecodeCompiler.java similarity index 61% rename from registervm/src/main/java/com/compilerprogramming/ezlang/bytecode/RegisterVMCompiler.java rename to registervm/src/main/java/com/compilerprogramming/ezlang/bytecode/BytecodeCompiler.java index 7ca606a..3bef310 100644 --- a/registervm/src/main/java/com/compilerprogramming/ezlang/bytecode/RegisterVMCompiler.java +++ b/registervm/src/main/java/com/compilerprogramming/ezlang/bytecode/BytecodeCompiler.java @@ -1,14 +1,16 @@ package com.compilerprogramming.ezlang.bytecode; import com.compilerprogramming.ezlang.types.Symbol; +import com.compilerprogramming.ezlang.types.Type; import com.compilerprogramming.ezlang.types.TypeDictionary; -public class RegisterVMCompiler { +public class BytecodeCompiler { public void compile(TypeDictionary typeDictionary) { for (Symbol symbol: typeDictionary.getLocalSymbols()) { if (symbol instanceof Symbol.FunctionTypeSymbol functionSymbol) { - functionSymbol.code = new FunctionBuilder(functionSymbol); + Type.TypeFunction functionType = (Type.TypeFunction) functionSymbol.type; + functionType.code = new BytecodeFunction(functionSymbol); } } } diff --git a/registervm/src/main/java/com/compilerprogramming/ezlang/bytecode/FunctionBuilder.java b/registervm/src/main/java/com/compilerprogramming/ezlang/bytecode/BytecodeFunction.java similarity index 80% rename from registervm/src/main/java/com/compilerprogramming/ezlang/bytecode/FunctionBuilder.java rename to registervm/src/main/java/com/compilerprogramming/ezlang/bytecode/BytecodeFunction.java index 6c0da79..0513127 100644 --- a/registervm/src/main/java/com/compilerprogramming/ezlang/bytecode/FunctionBuilder.java +++ b/registervm/src/main/java/com/compilerprogramming/ezlang/bytecode/BytecodeFunction.java @@ -9,10 +9,12 @@ import java.util.ArrayList; import java.util.List; -public class FunctionBuilder { +public class BytecodeFunction { public BasicBlock entry; public BasicBlock exit; + public int maxLocalReg; + public int maxStackSize; private int bid = 0; private BasicBlock currentBlock; private BasicBlock currentBreakTarget; @@ -28,7 +30,7 @@ public class FunctionBuilder { */ private List virtualStack = new ArrayList<>(); - public FunctionBuilder(Symbol.FunctionTypeSymbol functionSymbol) { + public BytecodeFunction(Symbol.FunctionTypeSymbol functionSymbol) { AST.FuncDecl funcDecl = (AST.FuncDecl) functionSymbol.functionDecl; setVirtualRegisters(funcDecl.scope); this.bid = 0; @@ -40,6 +42,10 @@ public FunctionBuilder(Symbol.FunctionTypeSymbol functionSymbol) { exitBlockIfNeeded(); } + public int frameSize() { + return maxLocalReg+maxStackSize; + } + private void exitBlockIfNeeded() { if (currentBlock != null && currentBlock != exit) { @@ -57,6 +63,8 @@ private void setVirtualRegisters(Scope scope) { } } scope.maxReg = reg; + if (maxLocalReg < scope.maxReg) + maxLocalReg = scope.maxReg; for (Scope childScope: scope.children) { setVirtualRegisters(childScope); } @@ -78,9 +86,11 @@ private void compileBlock(AST.BlockStmt block) { private void compileReturn(AST.ReturnStmt returnStmt) { if (returnStmt.expr != null) { - compileExpr(returnStmt.expr); + boolean isIndexed = compileExpr(returnStmt.expr); + if (isIndexed) + codeIndexedLoad(); if (virtualStack.size() == 1) - code(new Instruction.Move(pop(), new Operand.ReturnRegisterOperand())); + code(new Instruction.Return(pop())); else if (virtualStack.size() > 1) throw new CompilerException("Virtual stack has more than one item at return"); } @@ -269,13 +279,13 @@ private boolean compileExpr(AST.Expr expr) { private boolean compileCallExpr(AST.CallExpr callExpr) { compileExpr(callExpr.callee); - var callee = top(); - if (!(callee instanceof Operand.TempRegisterOperand) ) { - var origCallee = pop(); - callee = createTemp(); - code(new Instruction.Move(origCallee, callee)); - } - List args = new ArrayList<>(); + var callee = pop(); + Type.TypeFunction calleeType = null; + if (callee instanceof Operand.LocalFunctionOperand functionOperand) + calleeType = functionOperand.functionType; + else throw new CompilerException("Cannot call a non function type"); + var returnStackPos = virtualStack.size(); + List args = new ArrayList<>(); for (AST.Expr expr: callExpr.args) { boolean indexed = compileExpr(expr); if (indexed) @@ -283,18 +293,21 @@ private boolean compileCallExpr(AST.CallExpr callExpr) { var arg = top(); if (!(arg instanceof Operand.TempRegisterOperand) ) { var origArg = pop(); - arg = createTemp(); + arg = createTemp(origArg.type); code(new Instruction.Move(origArg, arg)); } - args.add(arg); + args.add((Operand.RegisterOperand) arg); } - code(new Instruction.Call(callee, args.toArray(new Operand[args.size()]))); - // Similute the actions on the stack - for (int i = 0; i < args.size()+1; i++) + // Simulate the actions on the stack + for (int i = 0; i < args.size(); i++) pop(); + Operand.TempRegisterOperand ret = null; if (callExpr.callee.type instanceof Type.TypeFunction tf && - tf.returnType != null) - createTemp(); + !(tf.returnType instanceof Type.TypeVoid)) { + ret = createTemp(tf.returnType); + assert ret.regnum-maxLocalReg == returnStackPos; + } + code(new Instruction.Call(returnStackPos, ret, calleeType, args.toArray(new Operand.RegisterOperand[args.size()]))); return false; } @@ -347,13 +360,20 @@ private boolean compileSetFieldExpr(AST.SetFieldExpr setFieldExpr) { } private void codeNew(Type type) { - var temp = createTemp(); - code(new Instruction.Move(new Operand.NewTypeOperand(type), temp)); + var temp = createTemp(type); + if (type instanceof Type.TypeArray typeArray) { + code(new Instruction.NewArray(typeArray, temp)); + } + else if (type instanceof Type.TypeStruct typeStruct) { + code(new Instruction.NewStruct(typeStruct, temp)); + } + else + throw new CompilerException("Unexpected type: " + type); } private void codeStoreAppend() { var operand = pop(); - code(new Instruction.AStoreAppend(top(), operand)); + code(new Instruction.AStoreAppend((Operand.RegisterOperand) top(), operand)); } private boolean compileNewExpr(AST.NewExpr newExpr) { @@ -399,7 +419,7 @@ private boolean compileBinaryExpr(AST.BinaryExpr binaryExpr) { Operand right = pop(); Operand left = pop(); if (left instanceof Operand.ConstantOperand leftconstant && - right instanceof Operand.ConstantOperand rightconstant) { + right instanceof Operand.ConstantOperand rightconstant) { long value = 0; switch (opCode) { case "+": value = leftconstant.value + rightconstant.value; break; @@ -415,11 +435,11 @@ private boolean compileBinaryExpr(AST.BinaryExpr binaryExpr) { case ">=": value = leftconstant.value <= rightconstant.value ? 1 : 0; break; default: throw new CompilerException("Invalid binary op"); } - pushConstant(value); + pushConstant(value, leftconstant.type); } else { - var temp = createTemp(); - code(new Instruction.BinaryInstruction(opCode, temp, left, right)); + var temp = createTemp(binaryExpr.type); + code(new Instruction.Binary(opCode, temp, left, right)); } return false; } @@ -433,35 +453,38 @@ private boolean compileUnaryExpr(AST.UnaryExpr unaryExpr) { Operand top = pop(); if (top instanceof Operand.ConstantOperand constant) { switch (opCode) { - case "-": pushConstant(-constant.value); break; - case "!": pushConstant(constant.value == 0?1:0); break; + case "-": pushConstant(-constant.value, constant.type); break; + // Maybe below we should explicitly set Int + case "!": pushConstant(constant.value == 0?1:0, constant.type); break; default: throw new CompilerException("Invalid unary op"); } } else { - var temp = createTemp(); - code(new Instruction.UnaryInstruction(opCode, temp, top)); + var temp = createTemp(unaryExpr.type); + code(new Instruction.Unary(opCode, temp, top)); } return false; } private boolean compileConstantExpr(AST.LiteralExpr constantExpr) { - pushConstant(constantExpr.value.num.intValue()); + pushConstant(constantExpr.value.num.intValue(), constantExpr.type); return false; } - private void pushConstant(long value) { - virtualStack.add(new Operand.ConstantOperand(value)); + private void pushConstant(long value, Type type) { + pushOperand(new Operand.ConstantOperand(value, type)); } - private Operand.TempRegisterOperand createTemp() { - var tempRegister = new Operand.TempRegisterOperand(virtualStack.size()); - virtualStack.add(tempRegister); + private Operand.TempRegisterOperand createTemp(Type type) { + var tempRegister = new Operand.TempRegisterOperand(virtualStack.size()+maxLocalReg, type); + pushOperand(tempRegister); + if (maxStackSize < virtualStack.size()) + maxStackSize = virtualStack.size(); return tempRegister; } private void pushLocal(int regnum, String varName) { - virtualStack.add(new Operand.LocalRegisterOperand(regnum, varName)); + pushOperand(new Operand.LocalRegisterOperand(regnum, varName)); } private void pushOperand(Operand operand) { @@ -478,14 +501,28 @@ private Operand top() { private void codeIndexedLoad() { Operand indexed = pop(); - var temp = createTemp(); - code(new Instruction.Move(indexed, temp)); + var temp = createTemp(indexed.type); + if (indexed instanceof Operand.LoadIndexedOperand loadIndexedOperand) { + code(new Instruction.ArrayLoad(loadIndexedOperand, temp)); + } + else if (indexed instanceof Operand.LoadFieldOperand loadFieldOperand) { + code(new Instruction.GetField(loadFieldOperand, temp)); + } + else + code(new Instruction.Move(indexed, temp)); } private void codeIndexedStore() { Operand value = pop(); Operand indexed = pop(); - code(new Instruction.Move(value, indexed)); + if (indexed instanceof Operand.LoadIndexedOperand loadIndexedOperand) { + code(new Instruction.ArrayStore(value, loadIndexedOperand)); + } + else if (indexed instanceof Operand.LoadFieldOperand loadFieldOperand) { + code(new Instruction.SetField(value, loadFieldOperand)); + } + else + code(new Instruction.Move(value, indexed)); } private boolean vstackEmpty() { diff --git a/registervm/src/main/java/com/compilerprogramming/ezlang/bytecode/Instruction.java b/registervm/src/main/java/com/compilerprogramming/ezlang/bytecode/Instruction.java index 8f5ba7b..b1e220a 100644 --- a/registervm/src/main/java/com/compilerprogramming/ezlang/bytecode/Instruction.java +++ b/registervm/src/main/java/com/compilerprogramming/ezlang/bytecode/Instruction.java @@ -1,5 +1,7 @@ package com.compilerprogramming.ezlang.bytecode; +import com.compilerprogramming.ezlang.types.Type; + public abstract class Instruction { public boolean isTerminal() { @@ -23,11 +25,134 @@ public StringBuilder toStr(StringBuilder sb) { } } - public static class UnaryInstruction extends Instruction { + public static class NewArray extends Instruction { + public final Type.TypeArray type; + public final Operand.RegisterOperand destOperand; + public NewArray(Type.TypeArray type, Operand.RegisterOperand destOperand) { + this.type = type; + this.destOperand = destOperand; + } + @Override + public StringBuilder toStr(StringBuilder sb) { + return sb.append(destOperand) + .append(" = ") + .append("New(") + .append(type) + .append(")"); + } + } + + public static class NewStruct extends Instruction { + public final Type.TypeStruct type; + public final Operand.RegisterOperand destOperand; + public NewStruct(Type.TypeStruct type, Operand.RegisterOperand destOperand) { + this.type = type; + this.destOperand = destOperand; + } + @Override + public StringBuilder toStr(StringBuilder sb) { + return sb.append(destOperand) + .append(" = ") + .append("New(") + .append(type) + .append(")"); + } + } + + public static class ArrayLoad extends Instruction { + public final Operand arrayOperand; + public final Operand indexOperand; + public final Operand.RegisterOperand destOperand; + public ArrayLoad(Operand.LoadIndexedOperand from, Operand.RegisterOperand to) { + arrayOperand = from.arrayOperand; + indexOperand = from.indexOperand; + destOperand = to; + } + @Override + public StringBuilder toStr(StringBuilder sb) { + return sb.append(destOperand) + .append(" = ") + .append(arrayOperand) + .append("[") + .append(indexOperand) + .append("]"); + } + } + + public static class ArrayStore extends Instruction { + public final Operand arrayOperand; + public final Operand indexOperand; + public final Operand sourceOperand; + public ArrayStore(Operand from, Operand.LoadIndexedOperand to) { + arrayOperand = to.arrayOperand; + indexOperand = to.indexOperand; + sourceOperand = from; + } + @Override + public StringBuilder toStr(StringBuilder sb) { + return sb + .append(arrayOperand) + .append("[") + .append(indexOperand) + .append("] = ") + .append(sourceOperand); + } + } + + public static class GetField extends Instruction { + public final Operand structOperand; + public final String fieldName; + public final int fieldIndex; + public final Operand.RegisterOperand destOperand; + public GetField(Operand.LoadFieldOperand from, Operand.RegisterOperand to) + { + this.structOperand = from.structOperand; + this.fieldName = from.fieldName; + this.fieldIndex = from.fieldIndex; + this.destOperand = to; + } + @Override + public StringBuilder toStr(StringBuilder sb) { + return sb.append(destOperand) + .append(" = ") + .append(structOperand) + .append(".") + .append(fieldName); + } + } + + public static class SetField extends Instruction { + public final Operand structOperand; + public final String fieldName; + public final int fieldIndex; + public final Operand sourceOperand; + public SetField(Operand from,Operand.LoadFieldOperand to) + { + this.structOperand = to.structOperand; + this.fieldName = to.fieldName; + this.fieldIndex = to.fieldIndex; + this.sourceOperand = from; + } + @Override + public StringBuilder toStr(StringBuilder sb) { + return sb + .append(structOperand) + .append(".") + .append(fieldName) + .append(" = ") + .append(sourceOperand); + } + } + public static class Return extends Move { + public Return(Operand from) { + super(from, new Operand.ReturnRegisterOperand()); + } + } + public static class Unary extends Instruction { public final String unop; - public final Operand result; + public final Operand.RegisterOperand result; public final Operand operand; - public UnaryInstruction(String unop, Operand result, Operand operand) { + public Unary(String unop, Operand.RegisterOperand result, Operand operand) { this.unop = unop; this.result = result; this.operand = operand; @@ -38,12 +163,12 @@ public StringBuilder toStr(StringBuilder sb) { } } - public static class BinaryInstruction extends Instruction { + public static class Binary extends Instruction { public final String binOp; - public final Operand result; + public final Operand.RegisterOperand result; public final Operand left; public final Operand right; - public BinaryInstruction(String binop, Operand result, Operand left, Operand right) { + public Binary(String binop, Operand.RegisterOperand result, Operand left, Operand right) { this.binOp = binop; this.result = result; this.left = left; @@ -56,9 +181,9 @@ public StringBuilder toStr(StringBuilder sb) { } public static class AStoreAppend extends Instruction { - public final Operand array; + public final Operand.RegisterOperand array; public final Operand value; - public AStoreAppend(Operand array, Operand value) { + public AStoreAppend(Operand.RegisterOperand array, Operand value) { this.array = array; this.value = value; } @@ -90,14 +215,21 @@ public StringBuilder toStr(StringBuilder sb) { } public static class Call extends Instruction { - public final Operand callee; - public final Operand[] args; - public Call(Operand callee, Operand... args) { + public final Type.TypeFunction callee; + public final Operand.RegisterOperand[] args; + public final Operand.RegisterOperand returnOperand; + public final int newbase; + public Call(int newbase, Operand.RegisterOperand returnOperand, Type.TypeFunction callee, Operand.RegisterOperand... args) { + this.returnOperand = returnOperand; this.callee = callee; this.args = args; + this.newbase = newbase; } @Override public StringBuilder toStr(StringBuilder sb) { + if (returnOperand != null) { + sb.append(returnOperand).append(" = "); + } sb.append("call ").append(callee); if (args.length > 0) sb.append(" params "); diff --git a/registervm/src/main/java/com/compilerprogramming/ezlang/bytecode/Operand.java b/registervm/src/main/java/com/compilerprogramming/ezlang/bytecode/Operand.java index cee6a9c..6ff8248 100644 --- a/registervm/src/main/java/com/compilerprogramming/ezlang/bytecode/Operand.java +++ b/registervm/src/main/java/com/compilerprogramming/ezlang/bytecode/Operand.java @@ -4,10 +4,13 @@ public class Operand { + Type type; + public static class ConstantOperand extends Operand { public final long value; - public ConstantOperand(long value) { + public ConstantOperand(long value, Type type) { this.value = value; + this.type = type; } @Override public String toString() { @@ -15,11 +18,17 @@ public String toString() { } } - public static class LocalRegisterOperand extends Operand { + public static abstract class RegisterOperand extends Operand { public final int regnum; + public RegisterOperand(int regnum) { + this.regnum = regnum; + } + } + + public static class LocalRegisterOperand extends RegisterOperand { public final String varName; public LocalRegisterOperand(int regnum, String varName) { - this.regnum = regnum; + super(regnum); this.varName = varName; } @Override @@ -44,8 +53,8 @@ public String toString() { * the caller will expect to see any return value. The VM must map * this to appropriate location. */ - public static class ReturnRegisterOperand extends Operand { - public ReturnRegisterOperand() {} + public static class ReturnRegisterOperand extends RegisterOperand { + public ReturnRegisterOperand() { super(0); } @Override public String toString() { return "%ret"; } } @@ -55,10 +64,10 @@ public ReturnRegisterOperand() {} * virtual stack. Temps start at offset 0, but this is a relative * register number from start of temp area. */ - public static class TempRegisterOperand extends Operand { - public final int regnum; - public TempRegisterOperand(int regnum) { - this.regnum = regnum; + public static class TempRegisterOperand extends RegisterOperand { + public TempRegisterOperand(int regnum, Type type) { + super(regnum); + this.type = type; } @Override public String toString() { @@ -66,12 +75,16 @@ public String toString() { } } - public static class LoadIndexedOperand extends Operand { + public static class IndexedOperand extends Operand {} + + public static class LoadIndexedOperand extends IndexedOperand { public final Operand arrayOperand; public final Operand indexOperand; public LoadIndexedOperand(Operand arrayOperand, Operand indexOperand) { this.arrayOperand = arrayOperand; this.indexOperand = indexOperand; + assert !(indexOperand instanceof IndexedOperand) && + !(arrayOperand instanceof IndexedOperand); } @Override public String toString() { @@ -79,7 +92,7 @@ public String toString() { } } - public static class LoadFieldOperand extends Operand { + public static class LoadFieldOperand extends IndexedOperand { public final Operand structOperand; public final int fieldIndex; public final String fieldName; @@ -87,6 +100,7 @@ public LoadFieldOperand(Operand structOperand, String fieldName, int field) { this.structOperand = structOperand; this.fieldName = fieldName; this.fieldIndex = field; + assert !(structOperand instanceof IndexedOperand); } @Override diff --git a/registervm/src/main/java/com/compilerprogramming/ezlang/interpreter/ExecutionStack.java b/registervm/src/main/java/com/compilerprogramming/ezlang/interpreter/ExecutionStack.java new file mode 100644 index 0000000..dbfcac9 --- /dev/null +++ b/registervm/src/main/java/com/compilerprogramming/ezlang/interpreter/ExecutionStack.java @@ -0,0 +1,12 @@ +package com.compilerprogramming.ezlang.interpreter; + +public class ExecutionStack { + + public Value[] stack; + public int sp; + + public ExecutionStack(int maxStackSize) { + this.stack = new Value[maxStackSize]; + this.sp = -1; + } +} diff --git a/registervm/src/main/java/com/compilerprogramming/ezlang/interpreter/Interpreter.java b/registervm/src/main/java/com/compilerprogramming/ezlang/interpreter/Interpreter.java new file mode 100644 index 0000000..46c0b65 --- /dev/null +++ b/registervm/src/main/java/com/compilerprogramming/ezlang/interpreter/Interpreter.java @@ -0,0 +1,255 @@ +package com.compilerprogramming.ezlang.interpreter; + +import com.compilerprogramming.ezlang.bytecode.BasicBlock; +import com.compilerprogramming.ezlang.bytecode.BytecodeFunction; +import com.compilerprogramming.ezlang.bytecode.Instruction; +import com.compilerprogramming.ezlang.bytecode.Operand; +import com.compilerprogramming.ezlang.exceptions.CompilerException; +import com.compilerprogramming.ezlang.exceptions.InterpreterException; +import com.compilerprogramming.ezlang.types.Symbol; +import com.compilerprogramming.ezlang.types.Type; +import com.compilerprogramming.ezlang.types.TypeDictionary; + +public class Interpreter { + + TypeDictionary typeDictionary; + + public Interpreter(TypeDictionary typeDictionary) { + this.typeDictionary = typeDictionary; + } + + public Value run(String functionName) { + Symbol symbol = typeDictionary.lookup(functionName); + if (symbol instanceof Symbol.FunctionTypeSymbol functionSymbol) { + Frame frame = new Frame(functionSymbol); + ExecutionStack execStack = new ExecutionStack(1024); + return interpret(execStack, frame); + } + else { + throw new InterpreterException("Unknown function: " + functionName); + } + } + + public Value interpret(ExecutionStack execStack, Frame frame) { + BytecodeFunction currentFunction = frame.bytecodeFunction; + BasicBlock currentBlock = currentFunction.entry; + int ip = -1; + int base = frame.base; + boolean done = false; + Value returnValue = null; + + while (!done) { + Instruction instruction; + + ip++; + instruction = currentBlock.instructions.get(ip); + switch (instruction) { + case Instruction.Return returnInst -> { + if (returnInst.from instanceof Operand.ConstantOperand constantOperand) { + execStack.stack[base] = new Value.IntegerValue(constantOperand.value); + } + else if (returnInst.from instanceof Operand.RegisterOperand registerOperand) { + execStack.stack[base] = execStack.stack[base+registerOperand.regnum]; + } + else throw new IllegalStateException(); + returnValue = execStack.stack[base]; + } + case Instruction.Move moveInst -> { + if (moveInst.to instanceof Operand.RegisterOperand toReg) { + if (moveInst.from instanceof Operand.RegisterOperand fromReg) { + execStack.stack[base + toReg.regnum] = execStack.stack[base + fromReg.regnum]; + } + else if (moveInst.from instanceof Operand.ConstantOperand constantOperand) { + execStack.stack[base + toReg.regnum] = new Value.IntegerValue(constantOperand.value); + } + else throw new IllegalStateException(); + } + else throw new IllegalStateException(); + } + case Instruction.Jump jumpInst -> { + currentBlock = jumpInst.jumpTo; + ip = -1; + if (currentBlock == currentFunction.exit) + done = true; + } + case Instruction.ConditionalBranch cbrInst -> { + boolean condition; + if (cbrInst.condition instanceof Operand.RegisterOperand registerOperand) { + Value value = execStack.stack[base + registerOperand.regnum]; + if (value instanceof Value.IntegerValue integerValue) { + condition = integerValue.value != 0; + } + else { + condition = value != null; + } + } + else if (cbrInst.condition instanceof Operand.ConstantOperand constantOperand) { + condition = constantOperand.value != 0; + } + else throw new IllegalStateException(); + if (condition) + currentBlock = cbrInst.trueBlock; + else + currentBlock = cbrInst.falseBlock; + ip = -1; + if (currentBlock == currentFunction.exit) + done = true; + } + case Instruction.Call callInst -> { + // Copy args to new frame + int baseReg = base+currentFunction.frameSize(); + int reg = baseReg; + for (Operand.RegisterOperand arg: callInst.args) { + execStack.stack[base + reg] = execStack.stack[base + arg.regnum]; + reg += 1; + } + // Call function + Frame newFrame = new Frame(frame, baseReg, callInst.callee); + interpret(execStack, newFrame); + // Copy return value in expected location + if (!(callInst.callee.returnType instanceof Type.TypeVoid)) { + execStack.stack[base + callInst.returnOperand.regnum] = execStack.stack[baseReg]; + } + } + case Instruction.Unary unaryInst -> { + // We don't expect constant here because we fold constants in unary expressions + Operand.RegisterOperand unaryOperand = (Operand.RegisterOperand) unaryInst.operand; + Value unaryValue = execStack.stack[base + unaryOperand.regnum]; + if (unaryValue instanceof Value.IntegerValue integerValue) { + switch (unaryInst.unop) { + case "-": execStack.stack[base + unaryInst.result.regnum] = new Value.IntegerValue(-integerValue.value); break; + // Maybe below we should explicitly set Int + case "!": execStack.stack[base + unaryInst.result.regnum] = new Value.IntegerValue(integerValue.value==0?1:0); break; + default: throw new CompilerException("Invalid unary op"); + } + } + else + throw new IllegalStateException("Unexpected unary operand: " + unaryOperand); + } + case Instruction.Binary binaryInst -> { + long x, y; + long value = 0; + if (binaryInst.left instanceof Operand.ConstantOperand constant) + x = constant.value; + else if (binaryInst.left instanceof Operand.RegisterOperand registerOperand) + x = ((Value.IntegerValue) execStack.stack[base + registerOperand.regnum]).value; + else throw new IllegalStateException(); + if (binaryInst.right instanceof Operand.ConstantOperand constant) + y = constant.value; + else if (binaryInst.right instanceof Operand.RegisterOperand registerOperand) + y = ((Value.IntegerValue) execStack.stack[base + registerOperand.regnum]).value; + else throw new IllegalStateException(); + switch (binaryInst.binOp) { + case "+": value = x + y; break; + case "-": value = x - y; break; + case "*": value = x * y; break; + case "/": value = x / y; break; + case "%": value = x % y; break; + case "==": value = x == y ? 1 : 0; break; + case "!=": value = x != y ? 1 : 0; break; + case "<": value = x < y ? 1: 0; break; + case ">": value = x > y ? 1 : 0; break; + case "<=": value = x <= y ? 1 : 0; break; + case ">=": value = x <= y ? 1 : 0; break; + default: throw new IllegalStateException(); + } + execStack.stack[base + binaryInst.result.regnum] = new Value.IntegerValue(value); + } + case Instruction.NewArray newArrayInst -> { + execStack.stack[base + newArrayInst.destOperand.regnum] = new Value.ArrayValue(newArrayInst.type); + } + case Instruction.NewStruct newStructInst -> { + execStack.stack[base + newStructInst.destOperand.regnum] = new Value.StructValue(newStructInst.type); + } + case Instruction.AStoreAppend arrayAppendInst -> { + Value.ArrayValue arrayValue = (Value.ArrayValue) execStack.stack[base + arrayAppendInst.array.regnum]; + if (arrayAppendInst.value instanceof Operand.ConstantOperand constant) { + arrayValue.values.add(new Value.IntegerValue(constant.value)); + } + else if (arrayAppendInst.value instanceof Operand.RegisterOperand registerOperand) { + arrayValue.values.add(execStack.stack[base + registerOperand.regnum]); + } + else throw new IllegalStateException(); + } + case Instruction.ArrayStore arrayStoreInst -> { + if (arrayStoreInst.arrayOperand instanceof Operand.RegisterOperand arrayOperand) { + Value.ArrayValue arrayValue = (Value.ArrayValue) execStack.stack[base + arrayOperand.regnum]; + int index = 0; + if (arrayStoreInst.indexOperand instanceof Operand.ConstantOperand constant) { + index = (int) constant.value; + } + else if (arrayStoreInst.indexOperand instanceof Operand.RegisterOperand registerOperand) { + Value.IntegerValue indexValue = (Value.IntegerValue) execStack.stack[base + registerOperand.regnum]; + index = (int) indexValue.value; + } + else throw new IllegalStateException(); + Value value; + if (arrayStoreInst.sourceOperand instanceof Operand.ConstantOperand constantOperand) { + value = new Value.IntegerValue(constantOperand.value); + } + else if (arrayStoreInst.sourceOperand instanceof Operand.RegisterOperand registerOperand) { + value = execStack.stack[base + registerOperand.regnum]; + } + else throw new IllegalStateException(); + arrayValue.values.set(index, value); + } else throw new IllegalStateException(); + } + case Instruction.ArrayLoad arrayLoadInst -> { + if (arrayLoadInst.arrayOperand instanceof Operand.RegisterOperand arrayOperand) { + Value.ArrayValue arrayValue = (Value.ArrayValue) execStack.stack[base + arrayOperand.regnum]; + if (arrayLoadInst.indexOperand instanceof Operand.ConstantOperand constant) { + execStack.stack[base + arrayLoadInst.destOperand.regnum] = arrayValue.values.get((int) constant.value); + } + else if (arrayLoadInst.indexOperand instanceof Operand.RegisterOperand registerOperand) { + Value.IntegerValue index = (Value.IntegerValue) execStack.stack[base + registerOperand.regnum]; + execStack.stack[base + arrayLoadInst.destOperand.regnum] = arrayValue.values.get((int) index.value); + } + else throw new IllegalStateException(); + } else throw new IllegalStateException(); + } + case Instruction.SetField setFieldInst -> { + if (setFieldInst.structOperand instanceof Operand.RegisterOperand structOperand) { + Value.StructValue structValue = (Value.StructValue) execStack.stack[base + structOperand.regnum]; + int index = setFieldInst.fieldIndex; + Value value; + if (setFieldInst.sourceOperand instanceof Operand.ConstantOperand constant) { + value = new Value.IntegerValue(constant.value); + } + else if (setFieldInst.sourceOperand instanceof Operand.RegisterOperand registerOperand) { + value = execStack.stack[base + registerOperand.regnum]; + } + else throw new IllegalStateException(); + structValue.fields[index] = value; + } else throw new IllegalStateException(); + } + case Instruction.GetField getFieldInst -> { + if (getFieldInst.structOperand instanceof Operand.RegisterOperand structOperand) { + Value.StructValue structValue = (Value.StructValue) execStack.stack[base + structOperand.regnum]; + int index = getFieldInst.fieldIndex; + execStack.stack[base + getFieldInst.destOperand.regnum] = structValue.fields[index]; + } else throw new IllegalStateException(); + } + default -> throw new IllegalStateException("Unexpected value: " + instruction); + } + } + return returnValue; + } + + static class Frame { + Frame caller; + int base; + BytecodeFunction bytecodeFunction; + + public Frame(Symbol.FunctionTypeSymbol functionSymbol) { + this.caller = null; + this.base = 0; + this.bytecodeFunction = (BytecodeFunction) functionSymbol.code(); + } + + Frame(Frame caller, int base, Type.TypeFunction functionType) { + this.caller = caller; + this.base = base; + this.bytecodeFunction = (BytecodeFunction) functionType.code; + } + } +} diff --git a/registervm/src/main/java/com/compilerprogramming/ezlang/interpreter/Value.java b/registervm/src/main/java/com/compilerprogramming/ezlang/interpreter/Value.java new file mode 100644 index 0000000..e03fb80 --- /dev/null +++ b/registervm/src/main/java/com/compilerprogramming/ezlang/interpreter/Value.java @@ -0,0 +1,30 @@ +package com.compilerprogramming.ezlang.interpreter; + +import com.compilerprogramming.ezlang.types.Type; + +import java.util.ArrayList; + +public class Value { + static public class IntegerValue extends Value { + public IntegerValue(long value) { + this.value = value; + } + public final long value; + } + static public class ArrayValue extends Value { + public final Type.TypeArray arrayType; + public final ArrayList values; + public ArrayValue(Type.TypeArray arrayType) { + this.arrayType = arrayType; + values = new ArrayList<>(); + } + } + static public class StructValue extends Value { + public final Type.TypeStruct structType; + public final Value[] fields; + public StructValue(Type.TypeStruct structType) { + this.structType = structType; + this.fields = new Value[structType.numFields()]; + } + } +} diff --git a/registervm/src/test/java/com/compilerprogramming/ezlang/bytecode/TestCompiler.java b/registervm/src/test/java/com/compilerprogramming/ezlang/bytecode/TestCompiler.java index ca17ccc..f9fbc74 100644 --- a/registervm/src/test/java/com/compilerprogramming/ezlang/bytecode/TestCompiler.java +++ b/registervm/src/test/java/com/compilerprogramming/ezlang/bytecode/TestCompiler.java @@ -21,12 +21,12 @@ String compileSrc(String src) { sema.analyze(program); var sema2 = new SemaAssignTypes(typeDict); sema2.analyze(program); - RegisterVMCompiler byteCodeCompiler = new RegisterVMCompiler(); + BytecodeCompiler byteCodeCompiler = new BytecodeCompiler(); byteCodeCompiler.compile(typeDict); StringBuilder sb = new StringBuilder(); for (Symbol s: typeDict.bindings.values()) { if (s instanceof Symbol.FunctionTypeSymbol f) { - var functionBuilder = (FunctionBuilder) f.code; + var functionBuilder = (BytecodeFunction) f.code(); BasicBlock.toStr(sb, functionBuilder.entry, new BitSet()); } } @@ -91,8 +91,8 @@ func foo(n: Int)->Int { String result = compileSrc(src); Assert.assertEquals(""" L0: - %t0 = -n - %ret = %t0 + %t1 = -n + %ret = %t1 goto L1 L1: """, result); @@ -108,12 +108,13 @@ func foo(n: Int)->Int { String result = compileSrc(src); Assert.assertEquals(""" L0: - %t0 = n+1 - %ret = %t0 + %t1 = n+1 + %ret = %t1 goto L1 L1: """, result); } + @Test public void testFunction6() { String src = """ @@ -129,6 +130,7 @@ func foo(n: Int)->Int { L1: """, result); } + @Test public void testFunction7() { String src = """ @@ -144,6 +146,7 @@ func foo(n: Int)->Int { L1: """, result); } + @Test public void testFunction8() { String src = """ @@ -159,6 +162,7 @@ func foo(n: Int)->Int { L1: """, result); } + @Test public void testFunction9() { String src = """ @@ -185,7 +189,8 @@ func foo(n: [Int])->Int { String result = compileSrc(src); Assert.assertEquals(""" L0: - %ret = n[0] + %t1 = n[0] + %ret = %t1 goto L1 L1: """, result); @@ -201,10 +206,10 @@ func foo(n: [Int])->Int { String result = compileSrc(src); Assert.assertEquals(""" L0: - %t0 = n[0] - %t1 = n[1] - %t0 = %t0+%t1 - %ret = %t0 + %t1 = n[0] + %t2 = n[1] + %t1 = %t1+%t2 + %ret = %t1 goto L1 L1: """, result); @@ -240,9 +245,9 @@ func foo(n: Int) -> [Int] { String result = compileSrc(src); Assert.assertEquals(""" L0: - %t0 = New([Int,Int]) - %t0.append(n) - %ret = %t0 + %t1 = New([Int,Int]) + %t1.append(n) + %ret = %t1 goto L1 L1: """, result); @@ -258,8 +263,8 @@ func add(x: Int, y: Int) -> Int { String result = compileSrc(src); Assert.assertEquals(""" L0: - %t0 = x+y - %ret = %t0 + %t2 = x+y + %ret = %t2 goto L1 L1: """, result); @@ -340,8 +345,8 @@ func min(x: Int, y: Int) -> Int { String result = compileSrc(src); Assert.assertEquals(""" L0: - %t0 = x0 - if %t0 goto L3 else goto L4 + %t1 = n>0 + if %t1 goto L3 else goto L4 L3: - %t0 = n-1 - n = %t0 + %t1 = n-1 + n = %t1 goto L2 L4: goto L1 @@ -437,8 +442,7 @@ func foo() {} goto L1 L1: L0: - %t0 = foo - call %t0 + call foo goto L1 L1: """, result); @@ -456,10 +460,9 @@ func foo(x: Int, y: Int) {} goto L1 L1: L0: - %t0 = foo - %t1 = 1 - %t2 = 2 - call %t0 params %t1, %t2 + %t0 = 1 + %t1 = 2 + call foo params %t0, %t1 goto L1 L1: """, result); @@ -474,18 +477,17 @@ public void testFunction24() { String result = compileSrc(src); Assert.assertEquals(""" L0: - %t0 = x+y - %ret = %t0 + %t2 = x+y + %ret = %t2 goto L1 L1: L0: - %t0 = foo %t1 = 1 %t2 = 2 - call %t0 params %t1, %t2 - t = %t0 - %t0 = t+1 - %ret = %t0 + %t1 = call foo params %t1, %t2 + t = %t1 + %t1 = t+1 + %ret = %t1 goto L1 L1: """, result); @@ -506,7 +508,8 @@ func foo(p: Person) -> Int { String result = compileSrc(src); Assert.assertEquals(""" L0: - %ret = p.age + %t1 = p.age + %ret = %t1 goto L1 L1: """, result); @@ -527,8 +530,9 @@ func foo(p: Person) -> Int { String result = compileSrc(src); Assert.assertEquals(""" L0: - %t0 = p.parent - %ret = %t0.age + %t1 = p.parent + %t1 = %t1.age + %ret = %t1 goto L1 L1: """, result); @@ -549,13 +553,15 @@ func foo(p: [Person], i: Int) -> Int { String result = compileSrc(src); Assert.assertEquals(""" L0: - %t0 = p[i] - %t0 = %t0.parent - %ret = %t0.age + %t2 = p[i] + %t2 = %t2.parent + %t2 = %t2.age + %ret = %t2 goto L1 L1: """, result); } + @Test public void testFunction28() { String src = """ @@ -565,18 +571,17 @@ public void testFunction28() { String result = compileSrc(src); Assert.assertEquals(""" L0: - %t0 = x+y - %ret = %t0 + %t2 = x+y + %ret = %t2 goto L1 L1: L0: - %t0 = foo - %t1 = a - %t2 = 2 - call %t0 params %t1, %t2 - t = %t0 - %t0 = t+1 - %ret = %t0 + %t2 = a + %t3 = 2 + %t2 = call foo params %t2, %t3 + t = %t2 + %t2 = t+1 + %ret = %t2 goto L1 L1: """, result); diff --git a/registervm/src/test/java/com/compilerprogramming/ezlang/bytecode/TestInterpreter.java b/registervm/src/test/java/com/compilerprogramming/ezlang/bytecode/TestInterpreter.java new file mode 100644 index 0000000..15c1257 --- /dev/null +++ b/registervm/src/test/java/com/compilerprogramming/ezlang/bytecode/TestInterpreter.java @@ -0,0 +1,127 @@ +package com.compilerprogramming.ezlang.bytecode; + +import com.compilerprogramming.ezlang.interpreter.Interpreter; +import com.compilerprogramming.ezlang.interpreter.Value; +import com.compilerprogramming.ezlang.lexer.Lexer; +import com.compilerprogramming.ezlang.parser.Parser; +import com.compilerprogramming.ezlang.semantic.SemaAssignTypes; +import com.compilerprogramming.ezlang.semantic.SemaDefineTypes; +import com.compilerprogramming.ezlang.types.Symbol; +import com.compilerprogramming.ezlang.types.TypeDictionary; +import org.junit.Assert; +import org.junit.Test; + +import java.util.BitSet; + +public class TestInterpreter { + + Value compileAndRun(String src, String mainFunction) { + Parser parser = new Parser(); + var program = parser.parse(new Lexer(src)); + var typeDict = new TypeDictionary(); + var sema = new SemaDefineTypes(typeDict); + sema.analyze(program); + var sema2 = new SemaAssignTypes(typeDict); + sema2.analyze(program); + BytecodeCompiler byteCodeCompiler = new BytecodeCompiler(); + byteCodeCompiler.compile(typeDict); + StringBuilder sb = new StringBuilder(); + for (Symbol s : typeDict.bindings.values()) { + if (s instanceof Symbol.FunctionTypeSymbol f) { + var functionBuilder = (BytecodeFunction) f.code(); + BasicBlock.toStr(sb, functionBuilder.entry, new BitSet()); + } + } + System.out.println(sb.toString()); + var interpreter = new Interpreter(typeDict); + return interpreter.run(mainFunction); + } + + @Test + public void testFunction1() { + String src = """ + func foo()->Int { + return 42; + } + """; + var value = compileAndRun(src, "foo"); + Assert.assertNotNull(value); + Assert.assertTrue(value instanceof Value.IntegerValue integerValue + && integerValue.value == 42); + } + + @Test + public void testFunction2() { + String src = """ + func bar()->Int { + return 42; + } + func foo()->Int { + return bar(); + } + """; + var value = compileAndRun(src, "foo"); + Assert.assertNotNull(value); + Assert.assertTrue(value instanceof Value.IntegerValue integerValue + && integerValue.value == 42); + } + + @Test + public void testFunction3() { + String src = """ + func negate(n: Int)->Int { + return -n; + } + func foo()->Int { + return negate(42); + } + """; + var value = compileAndRun(src, "foo"); + Assert.assertNotNull(value); + Assert.assertTrue(value instanceof Value.IntegerValue integerValue + && integerValue.value == -42); + } + + @Test + public void testFunction4() { + String src = """ + func foo(x: Int, y: Int)->Int { return x+y; } + func bar()->Int { var t = foo(1,2); return t+1; } + """; + var value = compileAndRun(src, "bar"); + Assert.assertNotNull(value); + Assert.assertTrue(value instanceof Value.IntegerValue integerValue + && integerValue.value == 4); + } + + @Test + public void testFunction5() { + String src = """ + func bar()->Int { var t = new [Int] {1,21,3}; return t[1]; } + """; + var value = compileAndRun(src, "bar"); + Assert.assertNotNull(value); + Assert.assertTrue(value instanceof Value.IntegerValue integerValue + && integerValue.value == 21); + } + + @Test + public void testFunction6() { + String src = """ + struct Test + { + var field: Int + } + func foo()->Test + { + var test = new Test{ field = 42 } + return test + } + func bar()->Int { return foo().field } + """; + var value = compileAndRun(src, "bar"); + Assert.assertNotNull(value); + Assert.assertTrue(value instanceof Value.IntegerValue integerValue + && integerValue.value == 42); + } +} diff --git a/semantic/src/main/java/com/compilerprogramming/ezlang/semantic/SemaDefineTypes.java b/semantic/src/main/java/com/compilerprogramming/ezlang/semantic/SemaDefineTypes.java index 0030d1b..c51dcdf 100644 --- a/semantic/src/main/java/com/compilerprogramming/ezlang/semantic/SemaDefineTypes.java +++ b/semantic/src/main/java/com/compilerprogramming/ezlang/semantic/SemaDefineTypes.java @@ -229,12 +229,15 @@ public ASTVisitor visit(AST.ReturnTypeExpr returnTypeExpr, boolean enter) { // We override the visitor and visit the return type here because // we need to associate the return type to the function's return type // The visitor mechanism doesn't allow us to associate values between two steps + Type.TypeFunction type = (Type.TypeFunction) currentFuncDecl.symbol.type; if (returnTypeExpr.returnType != null) { returnTypeExpr.returnType.accept(this); - Type.TypeFunction type = (Type.TypeFunction) currentFuncDecl.symbol.type; returnTypeExpr.type = returnTypeExpr.returnType.type; type.setReturnType(returnTypeExpr.type); } + else { + type.setReturnType(typeDictionary.VOID); + } visitor = null; } return visitor; diff --git a/stackvm/src/main/java/com/compilerprogramming/ezlang/bytecode/ByteCodeCompiler.java b/stackvm/src/main/java/com/compilerprogramming/ezlang/bytecode/ByteCodeCompiler.java index 8d1bfc4..85a846a 100644 --- a/stackvm/src/main/java/com/compilerprogramming/ezlang/bytecode/ByteCodeCompiler.java +++ b/stackvm/src/main/java/com/compilerprogramming/ezlang/bytecode/ByteCodeCompiler.java @@ -1,6 +1,7 @@ package com.compilerprogramming.ezlang.bytecode; import com.compilerprogramming.ezlang.types.Symbol; +import com.compilerprogramming.ezlang.types.Type; import com.compilerprogramming.ezlang.types.TypeDictionary; public class ByteCodeCompiler { @@ -8,7 +9,8 @@ public class ByteCodeCompiler { public void compile(TypeDictionary typeDictionary) { for (Symbol symbol: typeDictionary.getLocalSymbols()) { if (symbol instanceof Symbol.FunctionTypeSymbol functionSymbol) { - functionSymbol.code = new FunctionBuilder(functionSymbol); + Type.TypeFunction functionType = (Type.TypeFunction) functionSymbol.type; + functionType.code = new BytecodeFunction(functionSymbol); } } } diff --git a/stackvm/src/main/java/com/compilerprogramming/ezlang/bytecode/FunctionBuilder.java b/stackvm/src/main/java/com/compilerprogramming/ezlang/bytecode/BytecodeFunction.java similarity index 99% rename from stackvm/src/main/java/com/compilerprogramming/ezlang/bytecode/FunctionBuilder.java rename to stackvm/src/main/java/com/compilerprogramming/ezlang/bytecode/BytecodeFunction.java index 2b1e97c..ea26f5f 100644 --- a/stackvm/src/main/java/com/compilerprogramming/ezlang/bytecode/FunctionBuilder.java +++ b/stackvm/src/main/java/com/compilerprogramming/ezlang/bytecode/BytecodeFunction.java @@ -6,7 +6,7 @@ import com.compilerprogramming.ezlang.types.Symbol; import com.compilerprogramming.ezlang.types.Type; -public class FunctionBuilder { +public class BytecodeFunction { BasicBlock entry; BasicBlock exit; @@ -15,7 +15,7 @@ public class FunctionBuilder { BasicBlock currentBreakTarget; BasicBlock currentContinueTarget; - public FunctionBuilder(Symbol.FunctionTypeSymbol functionSymbol) { + public BytecodeFunction(Symbol.FunctionTypeSymbol functionSymbol) { AST.FuncDecl funcDecl = (AST.FuncDecl) functionSymbol.functionDecl; setVirtualRegisters(funcDecl.scope); this.bid = 0; diff --git a/stackvm/src/test/java/com/compilerprogramming/ezlang/bytecode/TestCompiler.java b/stackvm/src/test/java/com/compilerprogramming/ezlang/bytecode/TestCompiler.java index 847df92..5a5fd9c 100644 --- a/stackvm/src/test/java/com/compilerprogramming/ezlang/bytecode/TestCompiler.java +++ b/stackvm/src/test/java/com/compilerprogramming/ezlang/bytecode/TestCompiler.java @@ -24,7 +24,7 @@ void compileSrc(String src, String functionName) { byteCodeCompiler.compile(typeDict); for (Symbol s: typeDict.bindings.values()) { if (s instanceof Symbol.FunctionTypeSymbol f) { - var functionBuilder = (FunctionBuilder) f.code; + var functionBuilder = (BytecodeFunction) f.code(); System.out.println(BasicBlock.toStr(new StringBuilder(), functionBuilder.entry, new BitSet())); } } diff --git a/types/src/main/java/com/compilerprogramming/ezlang/types/Symbol.java b/types/src/main/java/com/compilerprogramming/ezlang/types/Symbol.java index b982452..0252f7b 100644 --- a/types/src/main/java/com/compilerprogramming/ezlang/types/Symbol.java +++ b/types/src/main/java/com/compilerprogramming/ezlang/types/Symbol.java @@ -1,5 +1,7 @@ package com.compilerprogramming.ezlang.types; +import java.util.Objects; + /** * A symbol is something that has a name and a type. */ @@ -21,11 +23,14 @@ public TypeSymbol(String name, Type type) { public static class FunctionTypeSymbol extends Symbol { public final Object functionDecl; - public Object code; public FunctionTypeSymbol(String name, Type.TypeFunction type, Object functionDecl) { super(name, type); this.functionDecl = functionDecl; } + public Object code() { + Type.TypeFunction function = (Type.TypeFunction) type; + return function.code; + } } public static class VarSymbol extends Symbol { diff --git a/types/src/main/java/com/compilerprogramming/ezlang/types/Type.java b/types/src/main/java/com/compilerprogramming/ezlang/types/Type.java index 406735a..2ba5aab 100644 --- a/types/src/main/java/com/compilerprogramming/ezlang/types/Type.java +++ b/types/src/main/java/com/compilerprogramming/ezlang/types/Type.java @@ -13,13 +13,14 @@ public abstract class Type { // Type classes - static final byte TANY = 0; - static final byte TNULL = 1; - static final byte TINT = 2; // Int, Bool - static final byte TNULLABLE = 3; // Null, or not null ptr - static final byte TFUNC = 4; // Function types - static final byte TSTRUCT = 5; - static final byte TARRAY = 6; + static final byte TVOID = 0; + static final byte TANY = 1; + static final byte TNULL = 2; + static final byte TINT = 3; // Int, Bool + static final byte TNULLABLE = 4; // Null, or not null ptr + static final byte TFUNC = 5; // Function types + static final byte TSTRUCT = 6; + static final byte TARRAY = 7; public final byte tclass; // type class public final String name; // type name, always unique @@ -51,6 +52,12 @@ public String toString() { } public String name() { return name; } + public static class TypeVoid extends Type { + public TypeVoid() { + super(TVOID, "$Void"); + } + } + /** * We give it the name $Any so that it cannot be referenced in * the language @@ -114,6 +121,7 @@ public Type getField(String name) { public int getFieldIndex(String name) { return fieldNames.indexOf(name); } + public int numFields() { return fieldNames.size(); } public void complete() { pending = false; } } @@ -146,6 +154,7 @@ public TypeNullable(Type baseType) { public static class TypeFunction extends Type { public final List args = new ArrayList<>(); public Type returnType; + public Object code; public TypeFunction(String name) { super(TFUNC, name); } @@ -166,7 +175,7 @@ public String describe() { sb.append(arg.name).append(": ").append(arg.type.name()); } sb.append(")"); - if (returnType != null) { + if (!(returnType instanceof Type.TypeVoid)) { sb.append("->").append(returnType.name()); } return sb.toString(); diff --git a/types/src/main/java/com/compilerprogramming/ezlang/types/TypeDictionary.java b/types/src/main/java/com/compilerprogramming/ezlang/types/TypeDictionary.java index f4b5bc2..9f296c0 100644 --- a/types/src/main/java/com/compilerprogramming/ezlang/types/TypeDictionary.java +++ b/types/src/main/java/com/compilerprogramming/ezlang/types/TypeDictionary.java @@ -6,12 +6,14 @@ public class TypeDictionary extends Scope { public final Type.TypeAny ANY; public final Type.TypeInteger INT; public final Type.TypeNull NULL; + public final Type.TypeVoid VOID; public TypeDictionary() { super(null); INT = (Type.TypeInteger) intern(new Type.TypeInteger()); ANY = (Type.TypeAny) intern(new Type.TypeAny()); NULL = (Type.TypeNull) intern(new Type.TypeNull()); + VOID = (Type.TypeVoid) intern(new Type.TypeVoid()); } public Type makeArrayType(Type elementType, boolean isNullable) { switch (elementType) {