From 1256402247f231f1f62d221b6af134510cd1b439 Mon Sep 17 00:00:00 2001 From: dibyendumajumdar Date: Sat, 15 Feb 2025 20:48:40 +0000 Subject: [PATCH 1/4] Cleanup types a bit - arrays do not have the second type anymore. Any is now Unknown. --- .../ezlang/compiler/TestCompiler.java | 4 +-- .../ezlang/compiler/TestCompiler.java | 4 +-- .../ezlang/semantic/SemaAssignTypes.java | 5 +-- .../ezlang/semantic/SemaDefineTypes.java | 2 +- .../ezlang/semantic/TestSemaDefineTypes.java | 8 ++--- .../ezlang/compiler/TestCompiler.java | 4 +-- .../ezlang/types/Type.java | 35 ++++++++++++------- .../ezlang/types/TypeDictionary.java | 14 ++++---- 8 files changed, 43 insertions(+), 33 deletions(-) diff --git a/optvm/src/test/java/com/compilerprogramming/ezlang/compiler/TestCompiler.java b/optvm/src/test/java/com/compilerprogramming/ezlang/compiler/TestCompiler.java index 15c31d4..fa148b6 100644 --- a/optvm/src/test/java/com/compilerprogramming/ezlang/compiler/TestCompiler.java +++ b/optvm/src/test/java/com/compilerprogramming/ezlang/compiler/TestCompiler.java @@ -214,7 +214,7 @@ func foo()->[Int] { String result = compileSrc(src); Assert.assertEquals(""" L0: - %t0 = New([Int,Int]) + %t0 = New([Int]) %t0.append(1) %t0.append(2) %t0.append(3) @@ -235,7 +235,7 @@ func foo(n: Int) -> [Int] { Assert.assertEquals(""" L0: arg n - %t1 = New([Int,Int]) + %t1 = New([Int]) %t1.append(n) ret %t1 goto L1 diff --git a/registervm/src/test/java/com/compilerprogramming/ezlang/compiler/TestCompiler.java b/registervm/src/test/java/com/compilerprogramming/ezlang/compiler/TestCompiler.java index 294d6ac..c2f9767 100644 --- a/registervm/src/test/java/com/compilerprogramming/ezlang/compiler/TestCompiler.java +++ b/registervm/src/test/java/com/compilerprogramming/ezlang/compiler/TestCompiler.java @@ -203,7 +203,7 @@ func foo()->[Int] { String result = compileSrc(src); Assert.assertEquals(""" L0: - %t0 = New([Int,Int]) + %t0 = New([Int]) %t0.append(1) %t0.append(2) %t0.append(3) @@ -223,7 +223,7 @@ func foo(n: Int) -> [Int] { String result = compileSrc(src); Assert.assertEquals(""" L0: - %t1 = New([Int,Int]) + %t1 = New([Int]) %t1.append(n) %ret = %t1 goto L1 diff --git a/semantic/src/main/java/com/compilerprogramming/ezlang/semantic/SemaAssignTypes.java b/semantic/src/main/java/com/compilerprogramming/ezlang/semantic/SemaAssignTypes.java index 0ddf525..07908a1 100644 --- a/semantic/src/main/java/com/compilerprogramming/ezlang/semantic/SemaAssignTypes.java +++ b/semantic/src/main/java/com/compilerprogramming/ezlang/semantic/SemaAssignTypes.java @@ -294,7 +294,8 @@ public ASTVisitor visit(AST.AssignStmt assignStmt, boolean enter) { if (!enter) { validType(assignStmt.lhs.type); validType(assignStmt.rhs.type); - // TODO check assignment + if (!assignStmt.lhs.type.isAssignable(assignStmt.rhs.type)) + throw new CompilerException("Value of type " + assignStmt.rhs.type + " cannot be assigned to type " + assignStmt.lhs.type); } return this; } @@ -306,7 +307,7 @@ public void analyze(AST.Program program) { private void validType(Type t) { if (t == null) throw new CompilerException("Undefined type"); - if (t == typeDictionary.ANY) + if (t == typeDictionary.UNKNOWN) throw new CompilerException("Undefined type"); } } diff --git a/semantic/src/main/java/com/compilerprogramming/ezlang/semantic/SemaDefineTypes.java b/semantic/src/main/java/com/compilerprogramming/ezlang/semantic/SemaDefineTypes.java index 1c126b1..95d1ab8 100644 --- a/semantic/src/main/java/com/compilerprogramming/ezlang/semantic/SemaDefineTypes.java +++ b/semantic/src/main/java/com/compilerprogramming/ezlang/semantic/SemaDefineTypes.java @@ -293,7 +293,7 @@ public ASTVisitor visit(AST.VarStmt varStmt, boolean enter) { if (enter) { if (currentScope.localLookup(varStmt.varName) != null) throw new CompilerException("Variable " + varStmt.varName + " already declared in current scope"); - varStmt.symbol = (Symbol.VarSymbol) currentScope.install(varStmt.varName, new Symbol.VarSymbol(varStmt.varName, typeDictionary.ANY)); + varStmt.symbol = (Symbol.VarSymbol) currentScope.install(varStmt.varName, new Symbol.VarSymbol(varStmt.varName, typeDictionary.UNKNOWN)); } return this; } diff --git a/semantic/src/test/java/com/compilerprogramming/ezlang/semantic/TestSemaDefineTypes.java b/semantic/src/test/java/com/compilerprogramming/ezlang/semantic/TestSemaDefineTypes.java index 28ef408..8de6c06 100644 --- a/semantic/src/test/java/com/compilerprogramming/ezlang/semantic/TestSemaDefineTypes.java +++ b/semantic/src/test/java/com/compilerprogramming/ezlang/semantic/TestSemaDefineTypes.java @@ -56,7 +56,7 @@ public void test3() { sema.analyze(program); var symbol = typeDict.lookup("TreeArray"); Assert.assertNotNull(symbol); - Assert.assertEquals("struct TreeArray{data: [Tree,Int];}", symbol.type.describe()); + Assert.assertEquals("struct TreeArray{data: [Tree];}", symbol.type.describe()); } @Test @@ -72,7 +72,7 @@ public void test4() { sema.analyze(program); var symbol = typeDict.lookup("TreeArray"); Assert.assertNotNull(symbol); - Assert.assertEquals("struct TreeArray{data: [Tree?,Int];}", symbol.type.describe()); + Assert.assertEquals("struct TreeArray{data: [Tree?];}", symbol.type.describe()); } @Test @@ -88,7 +88,7 @@ public void test5() { sema.analyze(program); var symbol = typeDict.lookup("TreeArray"); Assert.assertNotNull(symbol); - Assert.assertEquals("struct TreeArray{data: [Tree?,Int]?;}", symbol.type.describe()); + Assert.assertEquals("struct TreeArray{data: [Tree?]?;}", symbol.type.describe()); } @Test @@ -147,7 +147,7 @@ public void test9() { sema.analyze(program); var symbol = typeDict.lookup("TreeArray"); Assert.assertNotNull(symbol); - Assert.assertEquals("struct TreeArray{data: [Int,Int];}", symbol.type.describe()); + Assert.assertEquals("struct TreeArray{data: [Int];}", symbol.type.describe()); } } diff --git a/stackvm/src/test/java/com/compilerprogramming/ezlang/compiler/TestCompiler.java b/stackvm/src/test/java/com/compilerprogramming/ezlang/compiler/TestCompiler.java index eeb0a02..2a9a21e 100644 --- a/stackvm/src/test/java/com/compilerprogramming/ezlang/compiler/TestCompiler.java +++ b/stackvm/src/test/java/com/compilerprogramming/ezlang/compiler/TestCompiler.java @@ -227,7 +227,7 @@ func foo()->[Int] { String result = compileSrc(src); Assert.assertEquals(""" L0: - new [Int,Int] + new [Int] pushi 1 storeappend pushi 2 @@ -249,7 +249,7 @@ func foo(n: Int) -> [Int] { String result = compileSrc(src); Assert.assertEquals(""" L0: - new [Int,Int] + new [Int] load 0 storeappend jump L1 diff --git a/types/src/main/java/com/compilerprogramming/ezlang/types/Type.java b/types/src/main/java/com/compilerprogramming/ezlang/types/Type.java index 2ba5aab..7e1d44c 100644 --- a/types/src/main/java/com/compilerprogramming/ezlang/types/Type.java +++ b/types/src/main/java/com/compilerprogramming/ezlang/types/Type.java @@ -14,7 +14,7 @@ public abstract class Type { // Type classes static final byte TVOID = 0; - static final byte TANY = 1; + static final byte TUNKNOWN = 1; static final byte TNULL = 2; static final byte TINT = 3; // Int, Bool static final byte TNULLABLE = 4; // Null, or not null ptr @@ -52,19 +52,31 @@ public String toString() { } public String name() { return name; } + public boolean isAssignable(Type other) { + if (other instanceof TypeVoid || other instanceof TypeUnknown) + return false; + if (this == other || equals(other)) return true; + if (this instanceof TypeNullable nullable) { + if (other instanceof TypeNull) + return true; + return nullable.baseType.isAssignable(other); + } + return false; + } + + /** + * Represents no type - useful for defining functions + * that do not return a value + */ public static class TypeVoid extends Type { public TypeVoid() { super(TVOID, "$Void"); } } - /** - * We give it the name $Any so that it cannot be referenced in - * the language - */ - public static class TypeAny extends Type { - public TypeAny() { - super(TANY, "$Any"); + public static class TypeUnknown extends Type { + public TypeUnknown() { + super(TUNKNOWN, "$Unknown"); } } @@ -127,11 +139,9 @@ public int getFieldIndex(String name) { public static class TypeArray extends Type { Type elementType; - Type.TypeInteger size; - public TypeArray(Type baseType, Type.TypeInteger size) { - super(TARRAY, "[" + baseType.name() + "," + size.name() + "]"); - this.size = size; + public TypeArray(Type baseType) { + super(TARRAY, "[" + baseType.name() + "]"); this.elementType = baseType; if (baseType instanceof TypeArray) throw new CompilerException("Array of array type not supported"); @@ -139,7 +149,6 @@ public TypeArray(Type baseType, Type.TypeInteger size) { public Type getElementType() { return elementType; } - public Type getSizeType() { return size; } } // This is really a dedicated Union type for T|Null. diff --git a/types/src/main/java/com/compilerprogramming/ezlang/types/TypeDictionary.java b/types/src/main/java/com/compilerprogramming/ezlang/types/TypeDictionary.java index 9f296c0..9fc233b 100644 --- a/types/src/main/java/com/compilerprogramming/ezlang/types/TypeDictionary.java +++ b/types/src/main/java/com/compilerprogramming/ezlang/types/TypeDictionary.java @@ -3,7 +3,7 @@ import com.compilerprogramming.ezlang.exceptions.CompilerException; public class TypeDictionary extends Scope { - public final Type.TypeAny ANY; + public final Type.TypeUnknown UNKNOWN; public final Type.TypeInteger INT; public final Type.TypeNull NULL; public final Type.TypeVoid VOID; @@ -11,22 +11,22 @@ public class TypeDictionary extends Scope { public TypeDictionary() { super(null); INT = (Type.TypeInteger) intern(new Type.TypeInteger()); - ANY = (Type.TypeAny) intern(new Type.TypeAny()); + UNKNOWN = (Type.TypeUnknown) intern(new Type.TypeUnknown()); NULL = (Type.TypeNull) intern(new Type.TypeNull()); VOID = (Type.TypeVoid) intern(new Type.TypeVoid()); } public Type makeArrayType(Type elementType, boolean isNullable) { switch (elementType) { case Type.TypeInteger ti -> { - var arrayType = intern(new Type.TypeArray(ti, INT)); + var arrayType = intern(new Type.TypeArray(ti)); return isNullable ? intern(new Type.TypeNullable(arrayType)) : arrayType; } case Type.TypeStruct ts -> { - var arrayType = intern(new Type.TypeArray(ts, INT)); + var arrayType = intern(new Type.TypeArray(ts)); return isNullable ? intern(new Type.TypeNullable(arrayType)) : arrayType; } case Type.TypeNullable nullable when nullable.baseType instanceof Type.TypeStruct -> { - var arrayType = intern(new Type.TypeArray(elementType, INT)); + var arrayType = intern(new Type.TypeArray(elementType)); return isNullable ? intern(new Type.TypeNullable(arrayType)) : arrayType; } case null, default -> throw new CompilerException("Unsupported array element type: " + elementType); @@ -50,9 +50,9 @@ else if (t1 instanceof Type.TypeArray && t2 instanceof Type.TypeNull) { else if (t2 instanceof Type.TypeArray && t1 instanceof Type.TypeNull) { return intern(new Type.TypeNullable(t2)); } - else if (t1 instanceof Type.TypeAny) + else if (t1 instanceof Type.TypeUnknown) return t2; - else if (t2 instanceof Type.TypeAny) + else if (t2 instanceof Type.TypeUnknown) return t1; else if (!t1.equals(t2)) throw new CompilerException("Unsupported merge type: " + t1 + " and " + t2); From 35b6814fa45acbb00d002f34bb31efc0b708fe4d Mon Sep 17 00:00:00 2001 From: dibyendumajumdar Date: Sat, 15 Feb 2025 23:10:54 +0000 Subject: [PATCH 2/4] Start work on supporting null --- .../ezlang/parser/Parser.java | 7 +- .../ezlang/semantic/SemaAssignTypes.java | 61 +++- .../ezlang/semantic/TestSemaAssignTypes.java | 331 +++++------------- .../ezlang/semantic/TestSemaDefineTypes.java | 83 +---- .../ezlang/types/Type.java | 6 + 5 files changed, 154 insertions(+), 334 deletions(-) diff --git a/parser/src/main/java/com/compilerprogramming/ezlang/parser/Parser.java b/parser/src/main/java/com/compilerprogramming/ezlang/parser/Parser.java index 2249d4d..ca9deb9 100644 --- a/parser/src/main/java/com/compilerprogramming/ezlang/parser/Parser.java +++ b/parser/src/main/java/com/compilerprogramming/ezlang/parser/Parser.java @@ -358,7 +358,12 @@ private AST.Expr parsePrimary(Lexer lexer) { return x; } case IDENT -> { - if (isToken(currentToken, "new")) { + if (isToken(currentToken, "null")) { + var x = new AST.LiteralExpr(currentToken); + nextToken(lexer); + return x; + } + else if (isToken(currentToken, "new")) { return parseNew(lexer); } else { diff --git a/semantic/src/main/java/com/compilerprogramming/ezlang/semantic/SemaAssignTypes.java b/semantic/src/main/java/com/compilerprogramming/ezlang/semantic/SemaAssignTypes.java index 07908a1..f3a87d2 100644 --- a/semantic/src/main/java/com/compilerprogramming/ezlang/semantic/SemaAssignTypes.java +++ b/semantic/src/main/java/com/compilerprogramming/ezlang/semantic/SemaAssignTypes.java @@ -67,11 +67,19 @@ public ASTVisitor visit(AST.VarDecl varDecl, boolean enter) { @Override public ASTVisitor visit(AST.BinaryExpr binaryExpr, boolean enter) { if (!enter) { - validType(binaryExpr.expr1.type); - validType(binaryExpr.expr2.type); - if (binaryExpr.expr1.type instanceof Type.TypeInteger t1 && - binaryExpr.expr2.type instanceof Type.TypeInteger t2) { - binaryExpr.type = typeDictionary.merge(t1, t2); + validType(binaryExpr.expr1.type, true); + validType(binaryExpr.expr2.type, true); + if (binaryExpr.expr1.type instanceof Type.TypeInteger && + binaryExpr.expr2.type instanceof Type.TypeInteger) { + // booleans are int too + binaryExpr.type = typeDictionary.INT; + } + else if (((binaryExpr.expr1.type instanceof Type.TypeNull && + binaryExpr.expr2.type instanceof Type.TypeNullable) || + (binaryExpr.expr1.type instanceof Type.TypeNullable && + binaryExpr.expr2.type instanceof Type.TypeNull)) && + (binaryExpr.op.str.equals("==") || binaryExpr.op.str.equals("!="))) { + binaryExpr.type = typeDictionary.INT; } else { throw new CompilerException("Binary operator " + binaryExpr.op + " not supported for operands"); @@ -85,7 +93,7 @@ public ASTVisitor visit(AST.UnaryExpr unaryExpr, boolean enter) { if (enter) { return this; } - validType(unaryExpr.expr.type); + validType(unaryExpr.expr.type, false); if (unaryExpr.expr.type instanceof Type.TypeInteger ti) { unaryExpr.type = unaryExpr.expr.type; } @@ -99,7 +107,7 @@ public ASTVisitor visit(AST.UnaryExpr unaryExpr, boolean enter) { public ASTVisitor visit(AST.FieldExpr fieldExpr, boolean enter) { if (enter) return this; - validType(fieldExpr.object.type); + validType(fieldExpr.object.type, false); Type.TypeStruct structType = null; if (fieldExpr.object.type instanceof Type.TypeStruct ts) { structType = ts; @@ -120,7 +128,7 @@ else if (fieldExpr.object.type instanceof Type.TypeNullable ptr && @Override public ASTVisitor visit(AST.CallExpr callExpr, boolean enter) { if (!enter) { - validType(callExpr.callee.type); + validType(callExpr.callee.type, false); if (callExpr.callee.type instanceof Type.TypeFunction f) { callExpr.type = f.returnType; } @@ -133,7 +141,7 @@ public ASTVisitor visit(AST.CallExpr callExpr, boolean enter) { @Override public ASTVisitor visit(AST.SetFieldExpr setFieldExpr, boolean enter) { if (!enter) { - validType(setFieldExpr.value.type); + validType(setFieldExpr.value.type, true); } return this; } @@ -168,6 +176,10 @@ public ASTVisitor visit(AST.LiteralExpr literalExpr, boolean enter) { if (literalExpr.value.kind == Token.Kind.NUM) { literalExpr.type = typeDictionary.INT; } + else if (literalExpr.value.kind == Token.Kind.IDENT + && literalExpr.value.str.equals("null")) { + literalExpr.type = typeDictionary.NULL; + } else { throw new CompilerException("Unsupported literal " + literalExpr.value); } @@ -178,7 +190,7 @@ public ASTVisitor visit(AST.LiteralExpr literalExpr, boolean enter) { @Override public ASTVisitor visit(AST.ArrayIndexExpr arrayIndexExpr, boolean enter) { if (!enter) { - validType(arrayIndexExpr.array.type); + validType(arrayIndexExpr.array.type, false); Type.TypeArray arrayType = null; if (arrayIndexExpr.array.type instanceof Type.TypeArray ta) { arrayType = ta; @@ -189,8 +201,10 @@ else if (arrayIndexExpr.array.type instanceof Type.TypeNullable ptr && } else throw new CompilerException("Unexpected array type " + arrayIndexExpr.array.type); + if (!(arrayIndexExpr.expr.type instanceof Type.TypeInteger)) + throw new CompilerException("Array index must be integer type"); arrayIndexExpr.type = arrayType.getElementType(); - validType(arrayIndexExpr.type); + validType(arrayIndexExpr.type, false); } return this; } @@ -201,6 +215,7 @@ public ASTVisitor visit(AST.NewExpr newExpr, boolean enter) { return this; if (newExpr.typeExpr.type == null) throw new CompilerException("Unresolved type in new expression"); + validType(newExpr.typeExpr.type, false); if (newExpr.typeExpr.type instanceof Type.TypeStruct || newExpr.typeExpr.type instanceof Type.TypeArray) { newExpr.type = newExpr.typeExpr.type; @@ -223,7 +238,7 @@ public ASTVisitor visit(AST.NameExpr nameExpr, boolean enter) { if (symbol == null) { throw new CompilerException("Unknown symbol " + nameExpr.name); } - validType(symbol.type); + validType(symbol.type, false); nameExpr.symbol = symbol; nameExpr.type = symbol.type; return this; @@ -244,7 +259,7 @@ public ASTVisitor visit(AST.ReturnStmt returnStmt, boolean enter) { if (enter) return this; if (returnStmt.expr != null) - validType(returnStmt.expr.type); + validType(returnStmt.expr.type, false); return this; } @@ -261,9 +276,11 @@ public ASTVisitor visit(AST.WhileStmt whileStmt, boolean enter) { @Override public ASTVisitor visit(AST.VarStmt varStmt, boolean enter) { if (!enter) { - validType(varStmt.expr.type); + validType(varStmt.expr.type, true); var symbol = currentScope.lookup(varStmt.varName); symbol.type = typeDictionary.merge(varStmt.expr.type, symbol.type); + if (symbol.type == typeDictionary.NULL) + throw new CompilerException("Variable " + varStmt.varName + " cannot be Null type"); } return this; } @@ -292,10 +309,9 @@ public ASTVisitor visit(AST.ExprStmt exprStmt, boolean enter) { @Override public ASTVisitor visit(AST.AssignStmt assignStmt, boolean enter) { if (!enter) { - validType(assignStmt.lhs.type); - validType(assignStmt.rhs.type); - if (!assignStmt.lhs.type.isAssignable(assignStmt.rhs.type)) - throw new CompilerException("Value of type " + assignStmt.rhs.type + " cannot be assigned to type " + assignStmt.lhs.type); + validType(assignStmt.lhs.type, false); + validType(assignStmt.rhs.type, true); + checkAssignmentCompatible(assignStmt.lhs.type, assignStmt.rhs.type); } return this; } @@ -304,10 +320,17 @@ public void analyze(AST.Program program) { program.accept(this); } - private void validType(Type t) { + private void validType(Type t, boolean allowNull) { if (t == null) throw new CompilerException("Undefined type"); if (t == typeDictionary.UNKNOWN) throw new CompilerException("Undefined type"); + if (!allowNull && t == typeDictionary.NULL) + throw new CompilerException("Null type not allowed"); + } + + private void checkAssignmentCompatible(Type var, Type value) { + if (!var.isAssignable(value)) + throw new CompilerException("Value of type " + value + " cannot be assigned to type " + var); } } diff --git a/semantic/src/test/java/com/compilerprogramming/ezlang/semantic/TestSemaAssignTypes.java b/semantic/src/test/java/com/compilerprogramming/ezlang/semantic/TestSemaAssignTypes.java index 3f68c20..39b8302 100644 --- a/semantic/src/test/java/com/compilerprogramming/ezlang/semantic/TestSemaAssignTypes.java +++ b/semantic/src/test/java/com/compilerprogramming/ezlang/semantic/TestSemaAssignTypes.java @@ -1,5 +1,6 @@ package com.compilerprogramming.ezlang.semantic; +import com.compilerprogramming.ezlang.exceptions.CompilerException; import com.compilerprogramming.ezlang.lexer.Lexer; import com.compilerprogramming.ezlang.parser.Parser; import com.compilerprogramming.ezlang.types.TypeDictionary; @@ -8,494 +9,290 @@ public class TestSemaAssignTypes { - @Test - public void test1() { + static void analyze(String src, String symbolName, String typeSig) { Parser parser = new Parser(); - String src = """ - func foo(a: Int, b: Int)->Int - { - return a+b; - } -"""; var program = parser.parse(new Lexer(src)); var typeDict = new TypeDictionary(); var sema = new SemaDefineTypes(typeDict); sema.analyze(program); - var symbol = typeDict.lookup("foo"); + var symbol = typeDict.lookup(symbolName); Assert.assertNotNull(symbol); - Assert.assertEquals("func foo(a: Int,b: Int)->Int", symbol.type.describe()); + Assert.assertEquals(typeSig, symbol.type.describe()); var sema2 = new SemaAssignTypes(typeDict); sema2.analyze(program); } + @Test + public void test1() { + String src = """ + func foo(a: Int, b: Int)->Int + { + return a+b; + } +"""; + analyze(src, "foo", "func foo(a: Int,b: Int)->Int"); + } + @Test public void test51() { - Parser parser = new Parser(); String src = """ func foo()->Int { return 1+1; } """; - var program = parser.parse(new Lexer(src)); - var typeDict = new TypeDictionary(); - var sema = new SemaDefineTypes(typeDict); - sema.analyze(program); - var symbol = typeDict.lookup("foo"); - Assert.assertNotNull(symbol); - Assert.assertEquals("func foo()->Int", symbol.type.describe()); - var sema2 = new SemaAssignTypes(typeDict); - sema2.analyze(program); + analyze(src, "foo", "func foo()->Int"); } @Test public void test2() { - Parser parser = new Parser(); String src = """ func foo(a: Int, b: Int)->Int { return a-b; } """; - var program = parser.parse(new Lexer(src)); - var typeDict = new TypeDictionary(); - var sema = new SemaDefineTypes(typeDict); - sema.analyze(program); - var symbol = typeDict.lookup("foo"); - Assert.assertNotNull(symbol); - Assert.assertEquals("func foo(a: Int,b: Int)->Int", symbol.type.describe()); - var sema2 = new SemaAssignTypes(typeDict); - sema2.analyze(program); + analyze(src, "foo", "func foo(a: Int,b: Int)->Int"); } @Test public void test52() { - Parser parser = new Parser(); String src = """ func foo()->Int { return 1-1; } """; - var program = parser.parse(new Lexer(src)); - var typeDict = new TypeDictionary(); - var sema = new SemaDefineTypes(typeDict); - sema.analyze(program); - var symbol = typeDict.lookup("foo"); - Assert.assertNotNull(symbol); - Assert.assertEquals("func foo()->Int", symbol.type.describe()); - var sema2 = new SemaAssignTypes(typeDict); - sema2.analyze(program); + analyze(src, "foo", "func foo()->Int"); } @Test public void test3() { - Parser parser = new Parser(); String src = """ func foo(a: Int, b: Int)->Int { return a*b; } """; - var program = parser.parse(new Lexer(src)); - var typeDict = new TypeDictionary(); - var sema = new SemaDefineTypes(typeDict); - sema.analyze(program); - var symbol = typeDict.lookup("foo"); - Assert.assertNotNull(symbol); - Assert.assertEquals("func foo(a: Int,b: Int)->Int", symbol.type.describe()); - var sema2 = new SemaAssignTypes(typeDict); - sema2.analyze(program); + analyze(src, "foo", "func foo(a: Int,b: Int)->Int"); } @Test public void test53() { - Parser parser = new Parser(); String src = """ func foo()->Int { return 4*2; } """; - var program = parser.parse(new Lexer(src)); - var typeDict = new TypeDictionary(); - var sema = new SemaDefineTypes(typeDict); - sema.analyze(program); - var symbol = typeDict.lookup("foo"); - Assert.assertNotNull(symbol); - Assert.assertEquals("func foo()->Int", symbol.type.describe()); - var sema2 = new SemaAssignTypes(typeDict); - sema2.analyze(program); + analyze(src, "foo", "func foo()->Int"); } @Test public void test4() { - Parser parser = new Parser(); String src = """ func foo(a: Int, b: Int)->Int { return a/b; } """; - var program = parser.parse(new Lexer(src)); - var typeDict = new TypeDictionary(); - var sema = new SemaDefineTypes(typeDict); - sema.analyze(program); - var symbol = typeDict.lookup("foo"); - Assert.assertNotNull(symbol); - Assert.assertEquals("func foo(a: Int,b: Int)->Int", symbol.type.describe()); - var sema2 = new SemaAssignTypes(typeDict); - sema2.analyze(program); + analyze(src, "foo", "func foo(a: Int,b: Int)->Int"); } @Test public void test54() { - Parser parser = new Parser(); String src = """ func foo()->Int { return 4/2; } """; - var program = parser.parse(new Lexer(src)); - var typeDict = new TypeDictionary(); - var sema = new SemaDefineTypes(typeDict); - sema.analyze(program); - var symbol = typeDict.lookup("foo"); - Assert.assertNotNull(symbol); - Assert.assertEquals("func foo()->Int", symbol.type.describe()); - var sema2 = new SemaAssignTypes(typeDict); - sema2.analyze(program); + analyze(src, "foo", "func foo()->Int"); } @Test public void test5() { - Parser parser = new Parser(); String src = """ func foo(a: Int, b: Int)->Int { return a==b; } """; - var program = parser.parse(new Lexer(src)); - var typeDict = new TypeDictionary(); - var sema = new SemaDefineTypes(typeDict); - sema.analyze(program); - var symbol = typeDict.lookup("foo"); - Assert.assertNotNull(symbol); - Assert.assertEquals("func foo(a: Int,b: Int)->Int", symbol.type.describe()); - var sema2 = new SemaAssignTypes(typeDict); - sema2.analyze(program); + analyze(src, "foo", "func foo(a: Int,b: Int)->Int"); } @Test public void test55() { - Parser parser = new Parser(); String src = """ func foo()->Int { return 4==2; } """; - var program = parser.parse(new Lexer(src)); - var typeDict = new TypeDictionary(); - var sema = new SemaDefineTypes(typeDict); - sema.analyze(program); - var symbol = typeDict.lookup("foo"); - Assert.assertNotNull(symbol); - Assert.assertEquals("func foo()->Int", symbol.type.describe()); - var sema2 = new SemaAssignTypes(typeDict); - sema2.analyze(program); + analyze(src, "foo", "func foo()->Int"); } @Test public void test6() { - Parser parser = new Parser(); String src = """ func foo(a: Int, b: Int)->Int { return a!=b; } """; - var program = parser.parse(new Lexer(src)); - var typeDict = new TypeDictionary(); - var sema = new SemaDefineTypes(typeDict); - sema.analyze(program); - var symbol = typeDict.lookup("foo"); - Assert.assertNotNull(symbol); - Assert.assertEquals("func foo(a: Int,b: Int)->Int", symbol.type.describe()); - var sema2 = new SemaAssignTypes(typeDict); - sema2.analyze(program); + analyze(src, "foo", "func foo(a: Int,b: Int)->Int"); } @Test public void test56() { - Parser parser = new Parser(); String src = """ func foo()->Int { return 4!=2; } """; - var program = parser.parse(new Lexer(src)); - var typeDict = new TypeDictionary(); - var sema = new SemaDefineTypes(typeDict); - sema.analyze(program); - var symbol = typeDict.lookup("foo"); - Assert.assertNotNull(symbol); - Assert.assertEquals("func foo()->Int", symbol.type.describe()); - var sema2 = new SemaAssignTypes(typeDict); - sema2.analyze(program); + analyze(src, "foo", "func foo()->Int"); } @Test public void test7() { - Parser parser = new Parser(); String src = """ func foo(a: Int, b: Int)->Int { return aInt", symbol.type.describe()); - var sema2 = new SemaAssignTypes(typeDict); - sema2.analyze(program); + analyze(src, "foo", "func foo(a: Int,b: Int)->Int"); } @Test public void test57() { - Parser parser = new Parser(); String src = """ func foo()->Int { return 4<2; } """; - var program = parser.parse(new Lexer(src)); - var typeDict = new TypeDictionary(); - var sema = new SemaDefineTypes(typeDict); - sema.analyze(program); - var symbol = typeDict.lookup("foo"); - Assert.assertNotNull(symbol); - Assert.assertEquals("func foo()->Int", symbol.type.describe()); - var sema2 = new SemaAssignTypes(typeDict); - sema2.analyze(program); + analyze(src, "foo", "func foo()->Int"); } @Test public void test8() { - Parser parser = new Parser(); String src = """ func foo(a: Int, b: Int)->Int { return a<=b; } """; - var program = parser.parse(new Lexer(src)); - var typeDict = new TypeDictionary(); - var sema = new SemaDefineTypes(typeDict); - sema.analyze(program); - var symbol = typeDict.lookup("foo"); - Assert.assertNotNull(symbol); - Assert.assertEquals("func foo(a: Int,b: Int)->Int", symbol.type.describe()); - var sema2 = new SemaAssignTypes(typeDict); - sema2.analyze(program); + analyze(src, "foo", "func foo(a: Int,b: Int)->Int"); } @Test public void test58() { - Parser parser = new Parser(); String src = """ func foo()->Int { return 4<=2; } """; - var program = parser.parse(new Lexer(src)); - var typeDict = new TypeDictionary(); - var sema = new SemaDefineTypes(typeDict); - sema.analyze(program); - var symbol = typeDict.lookup("foo"); - Assert.assertNotNull(symbol); - Assert.assertEquals("func foo()->Int", symbol.type.describe()); - var sema2 = new SemaAssignTypes(typeDict); - sema2.analyze(program); + analyze(src, "foo", "func foo()->Int"); } @Test public void test9() { - Parser parser = new Parser(); String src = """ func foo(a: Int, b: Int)->Int { return a>b; } """; - var program = parser.parse(new Lexer(src)); - var typeDict = new TypeDictionary(); - var sema = new SemaDefineTypes(typeDict); - sema.analyze(program); - var symbol = typeDict.lookup("foo"); - Assert.assertNotNull(symbol); - Assert.assertEquals("func foo(a: Int,b: Int)->Int", symbol.type.describe()); - var sema2 = new SemaAssignTypes(typeDict); - sema2.analyze(program); + analyze(src, "foo", "func foo(a: Int,b: Int)->Int"); } @Test public void test59() { - Parser parser = new Parser(); String src = """ func foo()->Int { return 4>2; } """; - var program = parser.parse(new Lexer(src)); - var typeDict = new TypeDictionary(); - var sema = new SemaDefineTypes(typeDict); - sema.analyze(program); - var symbol = typeDict.lookup("foo"); - Assert.assertNotNull(symbol); - Assert.assertEquals("func foo()->Int", symbol.type.describe()); - var sema2 = new SemaAssignTypes(typeDict); - sema2.analyze(program); + analyze(src, "foo", "func foo()->Int"); } @Test public void test10() { - Parser parser = new Parser(); String src = """ func foo(a: Int, b: Int)->Int { return a>=b; } """; - var program = parser.parse(new Lexer(src)); - var typeDict = new TypeDictionary(); - var sema = new SemaDefineTypes(typeDict); - sema.analyze(program); - var symbol = typeDict.lookup("foo"); - Assert.assertNotNull(symbol); - Assert.assertEquals("func foo(a: Int,b: Int)->Int", symbol.type.describe()); - var sema2 = new SemaAssignTypes(typeDict); - sema2.analyze(program); + analyze(src, "foo", "func foo(a: Int,b: Int)->Int"); } @Test public void test60() { - Parser parser = new Parser(); String src = """ func foo()->Int { return 4>=2; } """; - var program = parser.parse(new Lexer(src)); - var typeDict = new TypeDictionary(); - var sema = new SemaDefineTypes(typeDict); - sema.analyze(program); - var symbol = typeDict.lookup("foo"); - Assert.assertNotNull(symbol); - Assert.assertEquals("func foo()->Int", symbol.type.describe()); - var sema2 = new SemaAssignTypes(typeDict); - sema2.analyze(program); + analyze(src, "foo", "func foo()->Int"); } @Test public void test11() { - Parser parser = new Parser(); String src = """ func foo(a: Int, b: Int)->Int { return a&&b; } """; - var program = parser.parse(new Lexer(src)); - var typeDict = new TypeDictionary(); - var sema = new SemaDefineTypes(typeDict); - sema.analyze(program); - var symbol = typeDict.lookup("foo"); - Assert.assertNotNull(symbol); - Assert.assertEquals("func foo(a: Int,b: Int)->Int", symbol.type.describe()); - var sema2 = new SemaAssignTypes(typeDict); - sema2.analyze(program); + analyze(src, "foo", "func foo(a: Int,b: Int)->Int"); } @Test public void test61() { - Parser parser = new Parser(); String src = """ func foo()->Int { return 4&&2; } """; - var program = parser.parse(new Lexer(src)); - var typeDict = new TypeDictionary(); - var sema = new SemaDefineTypes(typeDict); - sema.analyze(program); - var symbol = typeDict.lookup("foo"); - Assert.assertNotNull(symbol); - Assert.assertEquals("func foo()->Int", symbol.type.describe()); - var sema2 = new SemaAssignTypes(typeDict); - sema2.analyze(program); + analyze(src, "foo", "func foo()->Int"); } @Test public void test12() { - Parser parser = new Parser(); String src = """ func foo(a: Int, b: Int)->Int { return a||b; } """; - var program = parser.parse(new Lexer(src)); - var typeDict = new TypeDictionary(); - var sema = new SemaDefineTypes(typeDict); - sema.analyze(program); - var symbol = typeDict.lookup("foo"); - Assert.assertNotNull(symbol); - Assert.assertEquals("func foo(a: Int,b: Int)->Int", symbol.type.describe()); - var sema2 = new SemaAssignTypes(typeDict); - sema2.analyze(program); + analyze(src, "foo", "func foo(a: Int,b: Int)->Int"); } @Test public void test62() { - Parser parser = new Parser(); String src = """ func foo()->Int { return 4||2; } """; - var program = parser.parse(new Lexer(src)); - var typeDict = new TypeDictionary(); - var sema = new SemaDefineTypes(typeDict); - sema.analyze(program); - var symbol = typeDict.lookup("foo"); - Assert.assertNotNull(symbol); - Assert.assertEquals("func foo()->Int", symbol.type.describe()); - var sema2 = new SemaAssignTypes(typeDict); - sema2.analyze(program); + analyze(src, "foo", "func foo()->Int"); } @Test public void test13() { - Parser parser = new Parser(); String src = """ struct Foo { @@ -508,15 +305,49 @@ func foo()->Foo return f.bar[0] } """; - var program = parser.parse(new Lexer(src)); - var typeDict = new TypeDictionary(); - var sema = new SemaDefineTypes(typeDict); - sema.analyze(program); - var symbol = typeDict.lookup("foo"); - Assert.assertNotNull(symbol); - Assert.assertEquals("func foo()->Foo", symbol.type.describe()); - var sema2 = new SemaAssignTypes(typeDict); - sema2.analyze(program); + analyze(src, "foo", "func foo()->Foo"); } + @Test(expected = CompilerException.class) + public void test14() { + String src = """ + struct Foo + { + var bar: [Int] + } + func foo() + { + var f: Foo + f = null + } +"""; + analyze(src, "foo", "func foo()"); + } + + @Test + public void test15() { + String src = """ + struct Foo + { + var bar: [Int] + } + func foo() + { + var f: Foo? + f = null + } +"""; + analyze(src, "foo", "func foo()"); + } + + @Test(expected = CompilerException.class) + public void test16() { + String src = """ + func foo() + { + var f = null + } +"""; + analyze(src, "foo", "func foo()"); + } } diff --git a/semantic/src/test/java/com/compilerprogramming/ezlang/semantic/TestSemaDefineTypes.java b/semantic/src/test/java/com/compilerprogramming/ezlang/semantic/TestSemaDefineTypes.java index 8de6c06..0b87090 100644 --- a/semantic/src/test/java/com/compilerprogramming/ezlang/semantic/TestSemaDefineTypes.java +++ b/semantic/src/test/java/com/compilerprogramming/ezlang/semantic/TestSemaDefineTypes.java @@ -9,116 +9,78 @@ public class TestSemaDefineTypes { + private void analyze(String src, String symName, String typeSig) { + Parser parser = new Parser(); + var program = parser.parse(new Lexer(src)); + var typeDict = new TypeDictionary(); + var sema = new SemaDefineTypes(typeDict); + sema.analyze(program); + var symbol = typeDict.lookup(symName); + Assert.assertNotNull(symbol); + Assert.assertEquals(typeSig, symbol.type.describe()); + } + @Test public void test1() { - Parser parser = new Parser(); String src = """ struct Tree { var left: Tree? var right: Tree? } """; - var program = parser.parse(new Lexer(src)); - var typeDict = new TypeDictionary(); - var sema = new SemaDefineTypes(typeDict); - sema.analyze(program); - var symbol = typeDict.lookup("Tree"); - Assert.assertNotNull(symbol); - Assert.assertEquals("struct Tree{left: Tree?;right: Tree?;}", symbol.type.describe()); + analyze(src, "Tree", "struct Tree{left: Tree?;right: Tree?;}"); } @Test public void test2() { - Parser parser = new Parser(); String src = """ struct Tree { var left: Tree var right: Tree } """; - var program = parser.parse(new Lexer(src)); - var typeDict = new TypeDictionary(); - var sema = new SemaDefineTypes(typeDict); - sema.analyze(program); - var symbol = typeDict.lookup("Tree"); - Assert.assertNotNull(symbol); - Assert.assertEquals("struct Tree{left: Tree;right: Tree;}", symbol.type.describe()); + analyze(src, "Tree", "struct Tree{left: Tree;right: Tree;}"); } @Test public void test3() { - Parser parser = new Parser(); String src = """ struct TreeArray { var data: [Tree] } """; - var program = parser.parse(new Lexer(src)); - var typeDict = new TypeDictionary(); - var sema = new SemaDefineTypes(typeDict); - sema.analyze(program); - var symbol = typeDict.lookup("TreeArray"); - Assert.assertNotNull(symbol); - Assert.assertEquals("struct TreeArray{data: [Tree];}", symbol.type.describe()); + analyze(src, "TreeArray", "struct TreeArray{data: [Tree];}"); } @Test public void test4() { - Parser parser = new Parser(); String src = """ struct TreeArray { var data: [Tree?] } """; - var program = parser.parse(new Lexer(src)); - var typeDict = new TypeDictionary(); - var sema = new SemaDefineTypes(typeDict); - sema.analyze(program); - var symbol = typeDict.lookup("TreeArray"); - Assert.assertNotNull(symbol); - Assert.assertEquals("struct TreeArray{data: [Tree?];}", symbol.type.describe()); + analyze(src, "TreeArray", "struct TreeArray{data: [Tree?];}"); } @Test public void test5() { - Parser parser = new Parser(); String src = """ struct TreeArray { var data: [Tree?]? } """; - var program = parser.parse(new Lexer(src)); - var typeDict = new TypeDictionary(); - var sema = new SemaDefineTypes(typeDict); - sema.analyze(program); - var symbol = typeDict.lookup("TreeArray"); - Assert.assertNotNull(symbol); - Assert.assertEquals("struct TreeArray{data: [Tree?]?;}", symbol.type.describe()); + analyze(src, "TreeArray", "struct TreeArray{data: [Tree?]?;}"); } @Test public void test6() { - Parser parser = new Parser(); String src = """ func print(t: Tree) { } """; - var program = parser.parse(new Lexer(src)); - var typeDict = new TypeDictionary(); - var sema = new SemaDefineTypes(typeDict); - sema.analyze(program); - var symbol = typeDict.lookup("print"); - Assert.assertNotNull(symbol); - Assert.assertEquals("func print(t: Tree)", symbol.type.describe()); + analyze(src, "print", "func print(t: Tree)"); } @Test public void test7() { - Parser parser = new Parser(); String src = """ func makeTree()->Tree { } """; - var program = parser.parse(new Lexer(src)); - var typeDict = new TypeDictionary(); - var sema = new SemaDefineTypes(typeDict); - sema.analyze(program); - var symbol = typeDict.lookup("makeTree"); - Assert.assertNotNull(symbol); - Assert.assertEquals("func makeTree()->Tree", symbol.type.describe()); + analyze(src, "makeTree", "func makeTree()->Tree"); } @Test(expected = CompilerException.class) @@ -136,18 +98,11 @@ public void test8() { @Test public void test9() { - Parser parser = new Parser(); String src = """ struct TreeArray { var data: [Int] } """; - var program = parser.parse(new Lexer(src)); - var typeDict = new TypeDictionary(); - var sema = new SemaDefineTypes(typeDict); - sema.analyze(program); - var symbol = typeDict.lookup("TreeArray"); - Assert.assertNotNull(symbol); - Assert.assertEquals("struct TreeArray{data: [Int];}", symbol.type.describe()); + analyze(src, "TreeArray", "struct TreeArray{data: [Int];}"); } } diff --git a/types/src/main/java/com/compilerprogramming/ezlang/types/Type.java b/types/src/main/java/com/compilerprogramming/ezlang/types/Type.java index 7e1d44c..1233a64 100644 --- a/types/src/main/java/com/compilerprogramming/ezlang/types/Type.java +++ b/types/src/main/java/com/compilerprogramming/ezlang/types/Type.java @@ -61,6 +61,12 @@ public boolean isAssignable(Type other) { return true; return nullable.baseType.isAssignable(other); } + else if (other instanceof TypeNullable nullable) { + // At compile time we allow nullable value to be + // assigned to base type, but null check must be inserted + // Optimizer may remove null check + return isAssignable(nullable.baseType); + } return false; } From 5d4824c2c64450729857dc194ddb2addc277a9bc Mon Sep 17 00:00:00 2001 From: dibyendumajumdar Date: Sun, 16 Feb 2025 10:28:14 +0000 Subject: [PATCH 3/4] Sync register vm compiler and interpreter with optvm so that both are close. This will make it easier to see what the differences are and also to make the changes to support null --- .../ezlang/compiler/CompiledFunction.java | 13 +- .../ezlang/compiler/Instruction.java | 187 ++++++++++-------- .../ezlang/compiler/Operand.java | 25 +-- .../ezlang/interpreter/Interpreter.java | 116 +++++------ .../ezlang/compiler/TestCompiler.java | 48 ++--- 5 files changed, 201 insertions(+), 188 deletions(-) diff --git a/registervm/src/main/java/com/compilerprogramming/ezlang/compiler/CompiledFunction.java b/registervm/src/main/java/com/compilerprogramming/ezlang/compiler/CompiledFunction.java index dedc42e..56d5f55 100644 --- a/registervm/src/main/java/com/compilerprogramming/ezlang/compiler/CompiledFunction.java +++ b/registervm/src/main/java/com/compilerprogramming/ezlang/compiler/CompiledFunction.java @@ -13,8 +13,8 @@ public class CompiledFunction { public BasicBlock entry; public BasicBlock exit; - private int bid = 0; - private BasicBlock currentBlock; + private int BID = 0; + public BasicBlock currentBlock; private BasicBlock currentBreakTarget; private BasicBlock currentContinueTarget; public int maxLocalReg; @@ -33,7 +33,7 @@ public class CompiledFunction { public CompiledFunction(Symbol.FunctionTypeSymbol functionSymbol) { AST.FuncDecl funcDecl = (AST.FuncDecl) functionSymbol.functionDecl; setVirtualRegisters(funcDecl.scope); - this.bid = 0; + this.BID = 0; this.entry = this.currentBlock = createBlock(); this.exit = createBlock(); this.currentBreakTarget = null; @@ -71,11 +71,11 @@ private void setVirtualRegisters(Scope scope) { } private BasicBlock createBlock() { - return new BasicBlock(bid++); + return new BasicBlock(BID++); } private BasicBlock createLoopHead() { - return new BasicBlock(bid++, true); + return new BasicBlock(BID++, true); } private void compileBlock(AST.BlockStmt block) { @@ -90,7 +90,7 @@ private void compileReturn(AST.ReturnStmt returnStmt) { if (isIndexed) codeIndexedLoad(); if (virtualStack.size() == 1) - code(new Instruction.Return(pop())); + code(new Instruction.Ret(pop())); else if (virtualStack.size() > 1) throw new CompilerException("Virtual stack has more than one item at return"); } @@ -109,6 +109,7 @@ private void compileStatement(AST.Stmt statement) { case AST.VarStmt letStmt -> { compileLet(letStmt); } + case AST.VarDeclStmt varDeclStmt -> {} case AST.IfElseStmt ifElseStmt -> { compileIf(ifElseStmt); } diff --git a/registervm/src/main/java/com/compilerprogramming/ezlang/compiler/Instruction.java b/registervm/src/main/java/com/compilerprogramming/ezlang/compiler/Instruction.java index 448c0da..35c1fda 100644 --- a/registervm/src/main/java/com/compilerprogramming/ezlang/compiler/Instruction.java +++ b/registervm/src/main/java/com/compilerprogramming/ezlang/compiler/Instruction.java @@ -4,37 +4,69 @@ public abstract class Instruction { - public boolean isTerminal() { - return false; + static final int I_NOOP = 0; + static final int I_MOVE = 1; + static final int I_RET = 2; + static final int I_UNARY = 3; + static final int I_BINARY = 4; + static final int I_BR = 5; + static final int I_CBR = 6; + static final int I_ARG = 7; + static final int I_CALL = 8; + static final int I_PHI = 9; + static final int I_NEW_ARRAY = 10; + static final int I_NEW_STRUCT = 11; + static final int I_ARRAY_STORE = 12; + static final int I_ARRAY_LOAD = 13; + static final int I_ARRAY_APPEND = 14; + static final int I_FIELD_GET = 15; + static final int I_FIELD_SET = 16; + + public final int opcode; + protected Operand.RegisterOperand def; + protected Operand[] uses; + + protected Instruction(int opcode, Operand... uses) { + this.opcode = opcode; + this.def = null; + this.uses = new Operand[uses.length]; + System.arraycopy(uses, 0, this.uses, 0, uses.length); + } + protected Instruction(int opcode, Operand.RegisterOperand def, Operand... uses) { + this.opcode = opcode; + this.def = def; + this.uses = new Operand[uses.length]; + System.arraycopy(uses, 0, this.uses, 0, uses.length); } + + public boolean isTerminal() { return false; } @Override public String toString() { return toStr(new StringBuilder()).toString(); } public static class Move extends Instruction { - public final Operand from; - public final Operand to; public Move(Operand from, Operand to) { - this.from = from; - this.to = to; + super(I_MOVE, (Operand.RegisterOperand) to, from); } + public Operand from() { return uses[0]; } + public Operand.RegisterOperand to() { return def; } @Override public StringBuilder toStr(StringBuilder sb) { - return sb.append(to).append(" = ").append(from); + return sb.append(to()).append(" = ").append(from()); } } public static class NewArray extends Instruction { public final Type.TypeArray type; - public final Operand.RegisterOperand destOperand; public NewArray(Type.TypeArray type, Operand.RegisterOperand destOperand) { + super(I_NEW_ARRAY, destOperand); this.type = type; - this.destOperand = destOperand; } + public Operand.RegisterOperand destOperand() { return def; } @Override public StringBuilder toStr(StringBuilder sb) { - return sb.append(destOperand) + return sb.append(def) .append(" = ") .append("New(") .append(type) @@ -44,14 +76,14 @@ public StringBuilder toStr(StringBuilder sb) { public static class NewStruct extends Instruction { public final Type.TypeStruct type; - public final Operand.RegisterOperand destOperand; public NewStruct(Type.TypeStruct type, Operand.RegisterOperand destOperand) { + super(I_NEW_STRUCT, destOperand); this.type = type; - this.destOperand = destOperand; } + public Operand.RegisterOperand destOperand() { return def; } @Override public StringBuilder toStr(StringBuilder sb) { - return sb.append(destOperand) + return sb.append(def) .append(" = ") .append("New(") .append(type) @@ -60,182 +92,177 @@ public StringBuilder toStr(StringBuilder sb) { } public static class ArrayLoad extends Instruction { - public final Operand arrayOperand; - public final Operand indexOperand; - public final Operand.RegisterOperand destOperand; public ArrayLoad(Operand.LoadIndexedOperand from, Operand.RegisterOperand to) { - arrayOperand = from.arrayOperand; - indexOperand = from.indexOperand; - destOperand = to; + super(I_ARRAY_LOAD, to, from.arrayOperand, from.indexOperand); } + public Operand arrayOperand() { return uses[0]; } + public Operand indexOperand() { return uses[1]; } + public Operand.RegisterOperand destOperand() { return def; } @Override public StringBuilder toStr(StringBuilder sb) { - return sb.append(destOperand) + return sb.append(destOperand()) .append(" = ") - .append(arrayOperand) + .append(arrayOperand()) .append("[") - .append(indexOperand) + .append(indexOperand()) .append("]"); } } public static class ArrayStore extends Instruction { - public final Operand arrayOperand; - public final Operand indexOperand; - public final Operand sourceOperand; public ArrayStore(Operand from, Operand.LoadIndexedOperand to) { - arrayOperand = to.arrayOperand; - indexOperand = to.indexOperand; - sourceOperand = from; + super(I_ARRAY_STORE, (Operand.RegisterOperand) null, to.arrayOperand, to.indexOperand, from); } + public Operand arrayOperand() { return uses[0]; } + public Operand indexOperand() { return uses[1]; } + public Operand sourceOperand() { return uses[2]; } @Override public StringBuilder toStr(StringBuilder sb) { return sb - .append(arrayOperand) + .append(arrayOperand()) .append("[") - .append(indexOperand) + .append(indexOperand()) .append("] = ") - .append(sourceOperand); + .append(sourceOperand()); } } public static class GetField extends Instruction { - public final Operand structOperand; public final String fieldName; public final int fieldIndex; - public final Operand.RegisterOperand destOperand; - public GetField(Operand.LoadFieldOperand from, Operand.RegisterOperand to) - { - this.structOperand = from.structOperand; + public GetField(Operand.LoadFieldOperand from, Operand.RegisterOperand to) { + super(I_FIELD_GET, to, from.structOperand); this.fieldName = from.fieldName; this.fieldIndex = from.fieldIndex; - this.destOperand = to; } + public Operand structOperand() { return uses[0]; } + public Operand.RegisterOperand destOperand() { return def; } @Override public StringBuilder toStr(StringBuilder sb) { - return sb.append(destOperand) + return sb.append(def) .append(" = ") - .append(structOperand) + .append(uses[0]) .append(".") .append(fieldName); } } public static class SetField extends Instruction { - public final Operand structOperand; public final String fieldName; public final int fieldIndex; - public final Operand sourceOperand; - public SetField(Operand from,Operand.LoadFieldOperand to) - { - this.structOperand = to.structOperand; + public SetField(Operand from,Operand.LoadFieldOperand to) { + super(I_FIELD_SET, (Operand.RegisterOperand) null, to.structOperand, from); this.fieldName = to.fieldName; this.fieldIndex = to.fieldIndex; - this.sourceOperand = from; } + public Operand structOperand() { return uses[0]; } + public Operand sourceOperand() { return uses[1]; } @Override public StringBuilder toStr(StringBuilder sb) { return sb - .append(structOperand) + .append(structOperand()) .append(".") .append(fieldName) .append(" = ") - .append(sourceOperand); + .append(sourceOperand()); } } - public static class Return extends Move { - public Return(Operand from) { - super(from, new Operand.ReturnRegisterOperand()); + public static class Ret extends Instruction { + public Ret(Operand value) { + super(I_RET, (Operand.RegisterOperand) null, value); + } + public Operand value() { return uses[0]; } + @Override + public StringBuilder toStr(StringBuilder sb) { + sb.append("ret"); + if (uses[0] != null) + sb.append(" ").append(value()); + return sb; } } public static class Unary extends Instruction { public final String unop; - public final Operand.RegisterOperand result; - public final Operand operand; public Unary(String unop, Operand.RegisterOperand result, Operand operand) { + super(I_UNARY, result, operand); this.unop = unop; - this.result = result; - this.operand = operand; } + public Operand.RegisterOperand result() { return def; } + public Operand operand() { return uses[0]; } @Override public StringBuilder toStr(StringBuilder sb) { - return sb.append(result).append(" = ").append(unop).append(operand); + return sb.append(result()).append(" = ").append(unop).append(operand()); } } public static class Binary extends Instruction { public final String binOp; - public final Operand.RegisterOperand result; - public final Operand left; - public final Operand right; public Binary(String binop, Operand.RegisterOperand result, Operand left, Operand right) { + super(I_BINARY, result, left, right); this.binOp = binop; - this.result = result; - this.left = left; - this.right = right; } + public Operand.RegisterOperand result() { return def; } + public Operand left() { return uses[0]; } + public Operand right() { return uses[1]; } @Override public StringBuilder toStr(StringBuilder sb) { - return sb.append(result).append(" = ").append(left).append(binOp).append(right); + return sb.append(def).append(" = ").append(uses[0]).append(binOp).append(uses[1]); } } public static class AStoreAppend extends Instruction { - public final Operand.RegisterOperand array; - public final Operand value; public AStoreAppend(Operand.RegisterOperand array, Operand value) { - this.array = array; - this.value = value; + super(I_ARRAY_APPEND, (Operand.RegisterOperand) null, array, value); } + public Operand.RegisterOperand array() { return (Operand.RegisterOperand) uses[0]; } + public Operand value() { return uses[1]; } @Override public StringBuilder toStr(StringBuilder sb) { - return sb.append(array).append(".append(").append(value).append(")"); + return sb.append(uses[0]).append(".append(").append(uses[1]).append(")"); } } public static class ConditionalBranch extends Instruction { - public final Operand condition; public final BasicBlock trueBlock; public final BasicBlock falseBlock; public ConditionalBranch(BasicBlock currentBlock, Operand condition, BasicBlock trueBlock, BasicBlock falseBlock) { - this.condition = condition; + super(I_CBR, (Operand.RegisterOperand) null, condition); this.trueBlock = trueBlock; this.falseBlock = falseBlock; currentBlock.addSuccessor(trueBlock); currentBlock.addSuccessor(falseBlock); } + public Operand condition() { return uses[0]; } @Override public boolean isTerminal() { return true; } @Override public StringBuilder toStr(StringBuilder sb) { - return sb.append("if ").append(condition).append(" goto L").append(trueBlock.bid).append(" else goto L").append(falseBlock.bid); + return sb.append("if ").append(condition()).append(" goto L").append(trueBlock.bid).append(" else goto L").append(falseBlock.bid); } } public static class Call extends Instruction { public final Type.TypeFunction callee; - public final Operand.RegisterOperand[] args; - public final Operand.RegisterOperand returnOperand; public final int newbase; public Call(int newbase, Operand.RegisterOperand returnOperand, Type.TypeFunction callee, Operand.RegisterOperand... args) { - this.returnOperand = returnOperand; + super(I_CALL, returnOperand, args); this.callee = callee; - this.args = args; this.newbase = newbase; } + public Operand.RegisterOperand returnOperand() { return def; } + public Operand[] args() { return uses; } @Override public StringBuilder toStr(StringBuilder sb) { - if (returnOperand != null) { - sb.append(returnOperand).append(" = "); + if (def != null) { + sb.append(def).append(" = "); } sb.append("call ").append(callee); - if (args.length > 0) + if (uses.length > 0) sb.append(" params "); - for (int i = 0; i < args.length; i++) { + for (int i = 0; i < uses.length; i++) { if (i > 0) sb.append(", "); - sb.append(args[i]); + sb.append(uses[i]); } return sb; } @@ -244,6 +271,7 @@ public StringBuilder toStr(StringBuilder sb) { public static class Jump extends Instruction { public final BasicBlock jumpTo; public Jump(BasicBlock jumpTo) { + super(I_BR); this.jumpTo = jumpTo; } @Override @@ -256,5 +284,6 @@ public StringBuilder toStr(StringBuilder sb) { } } + public abstract StringBuilder toStr(StringBuilder sb); } diff --git a/registervm/src/main/java/com/compilerprogramming/ezlang/compiler/Operand.java b/registervm/src/main/java/com/compilerprogramming/ezlang/compiler/Operand.java index bff7d66..9cb9a70 100644 --- a/registervm/src/main/java/com/compilerprogramming/ezlang/compiler/Operand.java +++ b/registervm/src/main/java/com/compilerprogramming/ezlang/compiler/Operand.java @@ -23,6 +23,7 @@ public static abstract class RegisterOperand extends Operand { public RegisterOperand(int regnum) { this.regnum = regnum; } + public int frameSlot() { return regnum; } } public static class LocalRegisterOperand extends RegisterOperand { @@ -48,17 +49,6 @@ public String toString() { } } - /** - * Represents the return register, which is the location where - * the caller will expect to see any return value. The VM must map - * this to appropriate location. - */ - public static class ReturnRegisterOperand extends RegisterOperand { - public ReturnRegisterOperand() { super(0); } - @Override - public String toString() { return "%ret"; } - } - /** * Represents a temp register, maps to a location on the * virtual stack. Temps start at offset 0, but this is a relative @@ -108,17 +98,4 @@ public String toString() { return structOperand + "." + fieldName; } } - - public static class NewTypeOperand extends Operand { - public final Type type; - public NewTypeOperand(Type type) { - this.type = type; - } - - @Override - public String toString() { - return "New(" + type + ")"; - } - } - } diff --git a/registervm/src/main/java/com/compilerprogramming/ezlang/interpreter/Interpreter.java b/registervm/src/main/java/com/compilerprogramming/ezlang/interpreter/Interpreter.java index 3c06472..7be8ef9 100644 --- a/registervm/src/main/java/com/compilerprogramming/ezlang/interpreter/Interpreter.java +++ b/registervm/src/main/java/com/compilerprogramming/ezlang/interpreter/Interpreter.java @@ -44,23 +44,24 @@ public Value interpret(ExecutionStack execStack, Frame frame) { ip++; instruction = currentBlock.instructions.get(ip); switch (instruction) { - case Instruction.Return returnInst -> { - if (returnInst.from instanceof Operand.ConstantOperand constantOperand) { + case Instruction.Ret retInst -> { + if (retInst.value() instanceof Operand.ConstantOperand constantOperand) { execStack.stack[base] = new Value.IntegerValue(constantOperand.value); } - else if (returnInst.from instanceof Operand.RegisterOperand registerOperand) { - execStack.stack[base] = execStack.stack[base+registerOperand.regnum]; + else if (retInst.value() instanceof Operand.RegisterOperand registerOperand) { + execStack.stack[base] = execStack.stack[base+registerOperand.frameSlot()]; } else throw new IllegalStateException(); returnValue = execStack.stack[base]; } + case Instruction.Move moveInst -> { - if (moveInst.to instanceof Operand.RegisterOperand toReg) { - if (moveInst.from instanceof Operand.RegisterOperand fromReg) { - execStack.stack[base + toReg.regnum] = execStack.stack[base + fromReg.regnum]; + if (moveInst.to() instanceof Operand.RegisterOperand toReg) { + if (moveInst.from() instanceof Operand.RegisterOperand fromReg) { + execStack.stack[base + toReg.frameSlot()] = execStack.stack[base + fromReg.frameSlot()]; } - else if (moveInst.from instanceof Operand.ConstantOperand constantOperand) { - execStack.stack[base + toReg.regnum] = new Value.IntegerValue(constantOperand.value); + else if (moveInst.from() instanceof Operand.ConstantOperand constantOperand) { + execStack.stack[base + toReg.frameSlot()] = new Value.IntegerValue(constantOperand.value); } else throw new IllegalStateException(); } @@ -74,8 +75,8 @@ else if (moveInst.from instanceof Operand.ConstantOperand constantOperand) { } case Instruction.ConditionalBranch cbrInst -> { boolean condition; - if (cbrInst.condition instanceof Operand.RegisterOperand registerOperand) { - Value value = execStack.stack[base + registerOperand.regnum]; + if (cbrInst.condition() instanceof Operand.RegisterOperand registerOperand) { + Value value = execStack.stack[base + registerOperand.frameSlot()]; if (value instanceof Value.IntegerValue integerValue) { condition = integerValue.value != 0; } @@ -83,7 +84,7 @@ else if (moveInst.from instanceof Operand.ConstantOperand constantOperand) { condition = value != null; } } - else if (cbrInst.condition instanceof Operand.ConstantOperand constantOperand) { + else if (cbrInst.condition() instanceof Operand.ConstantOperand constantOperand) { condition = constantOperand.value != 0; } else throw new IllegalStateException(); @@ -99,8 +100,13 @@ else if (cbrInst.condition instanceof Operand.ConstantOperand constantOperand) { // Copy args to new frame int baseReg = base+currentFunction.frameSize(); int reg = baseReg; - for (Operand.RegisterOperand arg: callInst.args) { - execStack.stack[base + reg] = execStack.stack[base + arg.regnum]; + for (Operand arg: callInst.args()) { + if (arg instanceof Operand.RegisterOperand param) { + execStack.stack[base + reg] = execStack.stack[base + param.frameSlot()]; + } + else if (arg instanceof Operand.ConstantOperand constantOperand) { + execStack.stack[base + reg] = new Value.IntegerValue(constantOperand.value); + } reg += 1; } // Call function @@ -108,18 +114,18 @@ else if (cbrInst.condition instanceof Operand.ConstantOperand constantOperand) { interpret(execStack, newFrame); // Copy return value in expected location if (!(callInst.callee.returnType instanceof Type.TypeVoid)) { - execStack.stack[base + callInst.returnOperand.regnum] = execStack.stack[baseReg]; + execStack.stack[base + callInst.returnOperand().frameSlot()] = execStack.stack[baseReg]; } } case Instruction.Unary unaryInst -> { // We don't expect constant here because we fold constants in unary expressions - Operand.RegisterOperand unaryOperand = (Operand.RegisterOperand) unaryInst.operand; - Value unaryValue = execStack.stack[base + unaryOperand.regnum]; + Operand.RegisterOperand unaryOperand = (Operand.RegisterOperand) unaryInst.operand(); + Value unaryValue = execStack.stack[base + unaryOperand.frameSlot()]; if (unaryValue instanceof Value.IntegerValue integerValue) { switch (unaryInst.unop) { - case "-": execStack.stack[base + unaryInst.result.regnum] = new Value.IntegerValue(-integerValue.value); break; + case "-": execStack.stack[base + unaryInst.result().frameSlot()] = new Value.IntegerValue(-integerValue.value); break; // Maybe below we should explicitly set Int - case "!": execStack.stack[base + unaryInst.result.regnum] = new Value.IntegerValue(integerValue.value==0?1:0); break; + case "!": execStack.stack[base + unaryInst.result().frameSlot()] = new Value.IntegerValue(integerValue.value==0?1:0); break; default: throw new CompilerException("Invalid unary op"); } } @@ -129,15 +135,15 @@ else if (cbrInst.condition instanceof Operand.ConstantOperand constantOperand) { case Instruction.Binary binaryInst -> { long x, y; long value = 0; - if (binaryInst.left instanceof Operand.ConstantOperand constant) + if (binaryInst.left() instanceof Operand.ConstantOperand constant) x = constant.value; - else if (binaryInst.left instanceof Operand.RegisterOperand registerOperand) - x = ((Value.IntegerValue) execStack.stack[base + registerOperand.regnum]).value; + else if (binaryInst.left() instanceof Operand.RegisterOperand registerOperand) + x = ((Value.IntegerValue) execStack.stack[base + registerOperand.frameSlot()]).value; else throw new IllegalStateException(); - if (binaryInst.right instanceof Operand.ConstantOperand constant) + if (binaryInst.right() instanceof Operand.ConstantOperand constant) y = constant.value; - else if (binaryInst.right instanceof Operand.RegisterOperand registerOperand) - y = ((Value.IntegerValue) execStack.stack[base + registerOperand.regnum]).value; + else if (binaryInst.right() instanceof Operand.RegisterOperand registerOperand) + y = ((Value.IntegerValue) execStack.stack[base + registerOperand.frameSlot()]).value; else throw new IllegalStateException(); switch (binaryInst.binOp) { case "+": value = x + y; break; @@ -153,80 +159,80 @@ else if (binaryInst.right instanceof Operand.RegisterOperand registerOperand) case ">=": value = x <= y ? 1 : 0; break; default: throw new IllegalStateException(); } - execStack.stack[base + binaryInst.result.regnum] = new Value.IntegerValue(value); + execStack.stack[base + binaryInst.result().frameSlot()] = new Value.IntegerValue(value); } case Instruction.NewArray newArrayInst -> { - execStack.stack[base + newArrayInst.destOperand.regnum] = new Value.ArrayValue(newArrayInst.type); + execStack.stack[base + newArrayInst.destOperand().frameSlot()] = new Value.ArrayValue(newArrayInst.type); } case Instruction.NewStruct newStructInst -> { - execStack.stack[base + newStructInst.destOperand.regnum] = new Value.StructValue(newStructInst.type); + execStack.stack[base + newStructInst.destOperand().frameSlot()] = new Value.StructValue(newStructInst.type); } case Instruction.AStoreAppend arrayAppendInst -> { - Value.ArrayValue arrayValue = (Value.ArrayValue) execStack.stack[base + arrayAppendInst.array.regnum]; - if (arrayAppendInst.value instanceof Operand.ConstantOperand constant) { + Value.ArrayValue arrayValue = (Value.ArrayValue) execStack.stack[base + arrayAppendInst.array().frameSlot()]; + if (arrayAppendInst.value() instanceof Operand.ConstantOperand constant) { arrayValue.values.add(new Value.IntegerValue(constant.value)); } - else if (arrayAppendInst.value instanceof Operand.RegisterOperand registerOperand) { - arrayValue.values.add(execStack.stack[base + registerOperand.regnum]); + else if (arrayAppendInst.value() instanceof Operand.RegisterOperand registerOperand) { + arrayValue.values.add(execStack.stack[base + registerOperand.frameSlot()]); } else throw new IllegalStateException(); } case Instruction.ArrayStore arrayStoreInst -> { - if (arrayStoreInst.arrayOperand instanceof Operand.RegisterOperand arrayOperand) { - Value.ArrayValue arrayValue = (Value.ArrayValue) execStack.stack[base + arrayOperand.regnum]; + if (arrayStoreInst.arrayOperand() instanceof Operand.RegisterOperand arrayOperand) { + Value.ArrayValue arrayValue = (Value.ArrayValue) execStack.stack[base + arrayOperand.frameSlot()]; int index = 0; - if (arrayStoreInst.indexOperand instanceof Operand.ConstantOperand constant) { + if (arrayStoreInst.indexOperand() instanceof Operand.ConstantOperand constant) { index = (int) constant.value; } - else if (arrayStoreInst.indexOperand instanceof Operand.RegisterOperand registerOperand) { - Value.IntegerValue indexValue = (Value.IntegerValue) execStack.stack[base + registerOperand.regnum]; + else if (arrayStoreInst.indexOperand() instanceof Operand.RegisterOperand registerOperand) { + Value.IntegerValue indexValue = (Value.IntegerValue) execStack.stack[base + registerOperand.frameSlot()]; index = (int) indexValue.value; } else throw new IllegalStateException(); Value value; - if (arrayStoreInst.sourceOperand instanceof Operand.ConstantOperand constantOperand) { + if (arrayStoreInst.sourceOperand() instanceof Operand.ConstantOperand constantOperand) { value = new Value.IntegerValue(constantOperand.value); } - else if (arrayStoreInst.sourceOperand instanceof Operand.RegisterOperand registerOperand) { - value = execStack.stack[base + registerOperand.regnum]; + else if (arrayStoreInst.sourceOperand() instanceof Operand.RegisterOperand registerOperand) { + value = execStack.stack[base + registerOperand.frameSlot()]; } else throw new IllegalStateException(); arrayValue.values.set(index, value); } else throw new IllegalStateException(); } case Instruction.ArrayLoad arrayLoadInst -> { - if (arrayLoadInst.arrayOperand instanceof Operand.RegisterOperand arrayOperand) { - Value.ArrayValue arrayValue = (Value.ArrayValue) execStack.stack[base + arrayOperand.regnum]; - if (arrayLoadInst.indexOperand instanceof Operand.ConstantOperand constant) { - execStack.stack[base + arrayLoadInst.destOperand.regnum] = arrayValue.values.get((int) constant.value); + if (arrayLoadInst.arrayOperand() instanceof Operand.RegisterOperand arrayOperand) { + Value.ArrayValue arrayValue = (Value.ArrayValue) execStack.stack[base + arrayOperand.frameSlot()]; + if (arrayLoadInst.indexOperand() instanceof Operand.ConstantOperand constant) { + execStack.stack[base + arrayLoadInst.destOperand().frameSlot()] = arrayValue.values.get((int) constant.value); } - else if (arrayLoadInst.indexOperand instanceof Operand.RegisterOperand registerOperand) { - Value.IntegerValue index = (Value.IntegerValue) execStack.stack[base + registerOperand.regnum]; - execStack.stack[base + arrayLoadInst.destOperand.regnum] = arrayValue.values.get((int) index.value); + else if (arrayLoadInst.indexOperand() instanceof Operand.RegisterOperand registerOperand) { + Value.IntegerValue index = (Value.IntegerValue) execStack.stack[base + registerOperand.frameSlot()]; + execStack.stack[base + arrayLoadInst.destOperand().frameSlot()] = arrayValue.values.get((int) index.value); } else throw new IllegalStateException(); } else throw new IllegalStateException(); } case Instruction.SetField setFieldInst -> { - if (setFieldInst.structOperand instanceof Operand.RegisterOperand structOperand) { - Value.StructValue structValue = (Value.StructValue) execStack.stack[base + structOperand.regnum]; + if (setFieldInst.structOperand() instanceof Operand.RegisterOperand structOperand) { + Value.StructValue structValue = (Value.StructValue) execStack.stack[base + structOperand.frameSlot()]; int index = setFieldInst.fieldIndex; Value value; - if (setFieldInst.sourceOperand instanceof Operand.ConstantOperand constant) { + if (setFieldInst.sourceOperand() instanceof Operand.ConstantOperand constant) { value = new Value.IntegerValue(constant.value); } - else if (setFieldInst.sourceOperand instanceof Operand.RegisterOperand registerOperand) { - value = execStack.stack[base + registerOperand.regnum]; + else if (setFieldInst.sourceOperand() instanceof Operand.RegisterOperand registerOperand) { + value = execStack.stack[base + registerOperand.frameSlot()]; } else throw new IllegalStateException(); structValue.fields[index] = value; } else throw new IllegalStateException(); } case Instruction.GetField getFieldInst -> { - if (getFieldInst.structOperand instanceof Operand.RegisterOperand structOperand) { - Value.StructValue structValue = (Value.StructValue) execStack.stack[base + structOperand.regnum]; + if (getFieldInst.structOperand() instanceof Operand.RegisterOperand structOperand) { + Value.StructValue structValue = (Value.StructValue) execStack.stack[base + structOperand.frameSlot()]; int index = getFieldInst.fieldIndex; - execStack.stack[base + getFieldInst.destOperand.regnum] = structValue.fields[index]; + execStack.stack[base + getFieldInst.destOperand().frameSlot()] = structValue.fields[index]; } else throw new IllegalStateException(); } default -> throw new IllegalStateException("Unexpected value: " + instruction); diff --git a/registervm/src/test/java/com/compilerprogramming/ezlang/compiler/TestCompiler.java b/registervm/src/test/java/com/compilerprogramming/ezlang/compiler/TestCompiler.java index c2f9767..0b0d91f 100644 --- a/registervm/src/test/java/com/compilerprogramming/ezlang/compiler/TestCompiler.java +++ b/registervm/src/test/java/com/compilerprogramming/ezlang/compiler/TestCompiler.java @@ -21,7 +21,7 @@ func foo(n: Int)->Int { String result = compileSrc(src); Assert.assertEquals(""" L0: - %ret = 1 + ret 1 goto L1 L1: """, result); @@ -37,7 +37,7 @@ func foo(n: Int)->Int { String result = compileSrc(src); Assert.assertEquals(""" L0: - %ret = -1 + ret -1 goto L1 L1: """, result); @@ -53,7 +53,7 @@ func foo(n: Int)->Int { String result = compileSrc(src); Assert.assertEquals(""" L0: - %ret = n + ret n goto L1 L1: """, result); @@ -70,7 +70,7 @@ func foo(n: Int)->Int { Assert.assertEquals(""" L0: %t1 = -n - %ret = %t1 + ret %t1 goto L1 L1: """, result); @@ -87,7 +87,7 @@ func foo(n: Int)->Int { Assert.assertEquals(""" L0: %t1 = n+1 - %ret = %t1 + ret %t1 goto L1 L1: """, result); @@ -103,7 +103,7 @@ func foo(n: Int)->Int { String result = compileSrc(src); Assert.assertEquals(""" L0: - %ret = 2 + ret 2 goto L1 L1: """, result); @@ -119,7 +119,7 @@ func foo(n: Int)->Int { String result = compileSrc(src); Assert.assertEquals(""" L0: - %ret = 1 + ret 1 goto L1 L1: """, result); @@ -135,7 +135,7 @@ func foo(n: Int)->Int { String result = compileSrc(src); Assert.assertEquals(""" L0: - %ret = 1 + ret 1 goto L1 L1: """, result); @@ -151,7 +151,7 @@ func foo(n: Int)->Int { String result = compileSrc(src); Assert.assertEquals(""" L0: - %ret = 0 + ret 0 goto L1 L1: """, result); @@ -168,7 +168,7 @@ func foo(n: [Int])->Int { Assert.assertEquals(""" L0: %t1 = n[0] - %ret = %t1 + ret %t1 goto L1 L1: """, result); @@ -187,7 +187,7 @@ func foo(n: [Int])->Int { %t1 = n[0] %t2 = n[1] %t1 = %t1+%t2 - %ret = %t1 + ret %t1 goto L1 L1: """, result); @@ -207,7 +207,7 @@ func foo()->[Int] { %t0.append(1) %t0.append(2) %t0.append(3) - %ret = %t0 + ret %t0 goto L1 L1: """, result); @@ -225,7 +225,7 @@ func foo(n: Int) -> [Int] { L0: %t1 = New([Int]) %t1.append(n) - %ret = %t1 + ret %t1 goto L1 L1: """, result); @@ -242,7 +242,7 @@ func add(x: Int, y: Int) -> Int { Assert.assertEquals(""" L0: %t2 = x+y - %ret = %t2 + ret %t2 goto L1 L1: """, result); @@ -287,7 +287,7 @@ func foo() -> Person { %t0 = New(Person) %t0.age = 10 %t0.children = 0 - %ret = %t0 + ret %t0 goto L1 L1: """, result); @@ -326,11 +326,11 @@ func min(x: Int, y: Int) -> Int { %t2 = x Int { Assert.assertEquals(""" L0: %t1 = p.age - %ret = %t1 + ret %t1 goto L1 L1: """, result); @@ -510,7 +510,7 @@ func foo(p: Person) -> Int { L0: %t1 = p.parent %t1 = %t1.age - %ret = %t1 + ret %t1 goto L1 L1: """, result); @@ -534,7 +534,7 @@ func foo(p: [Person], i: Int) -> Int { %t2 = p[i] %t2 = %t2.parent %t2 = %t2.age - %ret = %t2 + ret %t2 goto L1 L1: """, result); @@ -550,7 +550,7 @@ public void testFunction28() { Assert.assertEquals(""" L0: %t2 = x+y - %ret = %t2 + ret %t2 goto L1 L1: L0: @@ -559,7 +559,7 @@ public void testFunction28() { %t2 = call foo params %t2, %t3 t = %t2 %t2 = t+1 - %ret = %t2 + ret %t2 goto L1 L1: """, result); From f01de7b70636ea8a57214cd1a832263e33d28213 Mon Sep 17 00:00:00 2001 From: dibyendumajumdar Date: Sun, 16 Feb 2025 11:42:48 +0000 Subject: [PATCH 4/4] Improved type checking - also some fixes to test cases that broke after type checking was improved --- .../ezlang/compiler/TestSSATransform.java | 2 +- .../ezlang/interpreter/TestInterpreter.java | 20 ++--- .../ezlang/semantic/SemaAssignTypes.java | 21 +++++- .../ezlang/semantic/SemaDefineTypes.java | 2 +- .../ezlang/semantic/TestSemaAssignTypes.java | 74 ++++++++++++++++++- 5 files changed, 102 insertions(+), 17 deletions(-) diff --git a/optvm/src/test/java/com/compilerprogramming/ezlang/compiler/TestSSATransform.java b/optvm/src/test/java/com/compilerprogramming/ezlang/compiler/TestSSATransform.java index 27387b2..d4d1026 100644 --- a/optvm/src/test/java/com/compilerprogramming/ezlang/compiler/TestSSATransform.java +++ b/optvm/src/test/java/com/compilerprogramming/ezlang/compiler/TestSSATransform.java @@ -788,7 +788,7 @@ func bar(x: Int)->Int { return x; } - func foo() { + func foo()->Int { return bar(10); } """; diff --git a/optvm/src/test/java/com/compilerprogramming/ezlang/interpreter/TestInterpreter.java b/optvm/src/test/java/com/compilerprogramming/ezlang/interpreter/TestInterpreter.java index 2ccaca5..01eb89c 100644 --- a/optvm/src/test/java/com/compilerprogramming/ezlang/interpreter/TestInterpreter.java +++ b/optvm/src/test/java/com/compilerprogramming/ezlang/interpreter/TestInterpreter.java @@ -167,7 +167,7 @@ func fib(n: Int)->Int { return f2; } - func foo() { + func foo()->Int { return fib(10); } """; @@ -229,7 +229,7 @@ func bar(data: [Int]) { else data[2] = 123 + j + 10 } - func foo() { + func foo()->Int { var data = new [Int] {0,0,0} bar(data) return data[0]+data[1]+data[2]; @@ -253,7 +253,7 @@ func bar(data: [Int]) { else data[2] = 123 - j + 10 } - func foo() { + func foo()->Int { var data = new [Int] {0,0,0} bar(data) return data[0]+data[1]+data[2]; @@ -277,7 +277,7 @@ func bar(data: [Int]) { data[2] = 15 data[3] = j + 21 } - func foo() { + func foo()->Int { var data = new [Int] {0,0,0,0} bar(data) return data[0]+data[1]+data[2]+data[3]; @@ -301,7 +301,7 @@ func bar(data: [Int]) { data[2] = j * 15 data[3] = j * 21 } - func foo() { + func foo()->Int { var data = new [Int] {0,0,0,0} bar(data) return data[0]+data[1]+data[2]+data[3]; @@ -329,7 +329,7 @@ func bar(data: [Int]) { } data[3] = j + 21 } - func foo() { + func foo()->Int { var data = new [Int] {1,0,0,0} bar(data) return data[0]+data[1]+data[2]+data[3] @@ -361,7 +361,7 @@ func bar(data: [Int]) { } data[3] = (j+k) * 21 } - func foo() { + func foo()->Int { var data = new [Int] {1,0,0,0} bar(data) return data[0]+data[1]+data[2]+data[3] @@ -389,7 +389,7 @@ func bar(data: [Int]) { data[1] = j data[2] = i } - func foo() { + func foo()->Int { var data = new [Int] {2,0,0} bar(data) return data[0]+data[1]+data[2]; @@ -411,7 +411,7 @@ func bar(data: [Int]) { else j = data[0] data[0] = j * 21 + data[1] } - func foo() { + func foo()->Int { var data = new [Int] {2,3} bar(data) return data[0]+data[1]; @@ -433,7 +433,7 @@ func bar(data: [Int]) { j = j * 21 + 25 / j data[1] = j } - func foo() { + func foo()->Int { var data = new [Int] {5,3} bar(data) return data[0]+data[1]; diff --git a/semantic/src/main/java/com/compilerprogramming/ezlang/semantic/SemaAssignTypes.java b/semantic/src/main/java/com/compilerprogramming/ezlang/semantic/SemaAssignTypes.java index f3a87d2..d996df5 100644 --- a/semantic/src/main/java/com/compilerprogramming/ezlang/semantic/SemaAssignTypes.java +++ b/semantic/src/main/java/com/compilerprogramming/ezlang/semantic/SemaAssignTypes.java @@ -216,15 +216,24 @@ public ASTVisitor visit(AST.NewExpr newExpr, boolean enter) { if (newExpr.typeExpr.type == null) throw new CompilerException("Unresolved type in new expression"); validType(newExpr.typeExpr.type, false); - if (newExpr.typeExpr.type instanceof Type.TypeStruct || - newExpr.typeExpr.type instanceof Type.TypeArray) { + if (newExpr.typeExpr.type instanceof Type.TypeNullable) + throw new CompilerException("new cannot be used to create a Nullable type"); + if (newExpr.typeExpr.type instanceof Type.TypeStruct typeStruct) { newExpr.type = newExpr.typeExpr.type; for (AST.Expr expr: newExpr.initExprList) { if (expr instanceof AST.SetFieldExpr setFieldExpr) { setFieldExpr.objectType = newExpr.typeExpr.type; + var fieldType = typeStruct.getField(setFieldExpr.fieldName); + checkAssignmentCompatible(fieldType, setFieldExpr.value.type); } } } + else if (newExpr.typeExpr.type instanceof Type.TypeArray arrayType) { + newExpr.type = newExpr.typeExpr.type; + for (AST.Expr expr: newExpr.initExprList) { + checkAssignmentCompatible(arrayType.getElementType(), expr.type); + } + } else throw new CompilerException("Unsupported type in new expression"); return this; @@ -258,8 +267,14 @@ public ASTVisitor visit(AST.ContinueStmt continueStmt, boolean enter) { public ASTVisitor visit(AST.ReturnStmt returnStmt, boolean enter) { if (enter) return this; - if (returnStmt.expr != null) + Type.TypeFunction functionType = (Type.TypeFunction) currentFuncDecl.symbol.type; + if (returnStmt.expr != null) { validType(returnStmt.expr.type, false); + checkAssignmentCompatible(functionType.returnType, returnStmt.expr.type); + } + else if (!(functionType.returnType instanceof Type.TypeVoid)) { + throw new CompilerException("A return value of type " + functionType.returnType + " is expected"); + } return this; } diff --git a/semantic/src/main/java/com/compilerprogramming/ezlang/semantic/SemaDefineTypes.java b/semantic/src/main/java/com/compilerprogramming/ezlang/semantic/SemaDefineTypes.java index 95d1ab8..f0bf0c6 100644 --- a/semantic/src/main/java/com/compilerprogramming/ezlang/semantic/SemaDefineTypes.java +++ b/semantic/src/main/java/com/compilerprogramming/ezlang/semantic/SemaDefineTypes.java @@ -162,7 +162,7 @@ Type getNullableSimpleType(AST.NullableSimpleTypeExpr simpleTypeExpr) { else baseType = typeSymbol.type; if (baseType.isPrimitive()) - throw new CompilerException("Cannot make nullable instance of primitive type"); + throw new CompilerException("Cannot make Nullable instance of primitive type"); return typeDictionary.intern(new Type.TypeNullable(baseType)); } diff --git a/semantic/src/test/java/com/compilerprogramming/ezlang/semantic/TestSemaAssignTypes.java b/semantic/src/test/java/com/compilerprogramming/ezlang/semantic/TestSemaAssignTypes.java index 39b8302..91f6700 100644 --- a/semantic/src/test/java/com/compilerprogramming/ezlang/semantic/TestSemaAssignTypes.java +++ b/semantic/src/test/java/com/compilerprogramming/ezlang/semantic/TestSemaAssignTypes.java @@ -298,14 +298,14 @@ public void test13() { { var bar: [Int] } - func foo()->Foo + func foo()->Int { var f: Foo f = new Foo{} return f.bar[0] } """; - analyze(src, "foo", "func foo()->Foo"); + analyze(src, "foo", "func foo()->Int"); } @Test(expected = CompilerException.class) @@ -347,6 +347,76 @@ func foo() { var f = null } +"""; + analyze(src, "foo", "func foo()"); + } + + @Test(expected = CompilerException.class) + public void test17() { + String src = """ + func foo() + { + var f = new [Int] {null} + } +"""; + analyze(src, "foo", "func foo()"); + } + + @Test(expected = CompilerException.class) + public void test18() { + String src = """ + func foo() + { + var f = new [Int?] {null} + } +"""; + analyze(src, "foo", "func foo()"); + } + + @Test + public void test19() { + String src = """ + struct Foo { var bar: Int } + func foo() + { + var f = new [Foo?] {null} + } +"""; + analyze(src, "foo", "func foo()"); + } + + @Test + public void test20() { + String src = """ + struct Foo { var bar: Int } + func foo() + { + var f = new [Foo?] {new Foo{ bar = 1}} + } +"""; + analyze(src, "foo", "func foo()"); + } + + @Test(expected = CompilerException.class) + public void test21() { + String src = """ + struct Foo { var bar: Int } + func foo() + { + var f = new [Foo?] {new Foo{ bar = null}} + } +"""; + analyze(src, "foo", "func foo()"); + } + + @Test(expected = CompilerException.class) + public void test22() { + String src = """ + struct Foo { var bar: Int } + func foo() + { + var f = new [Foo?]? {} + } """; analyze(src, "foo", "func foo()"); }