diff --git a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/ChaitinGraphColoringRegisterAllocator.java b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/ChaitinGraphColoringRegisterAllocator.java index aa22a75..4838eb0 100644 --- a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/ChaitinGraphColoringRegisterAllocator.java +++ b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/ChaitinGraphColoringRegisterAllocator.java @@ -4,26 +4,78 @@ import java.util.stream.IntStream; /** - * Implement the original graph coloring algorithm described by Chaitin. + * Implements the original graph coloring algorithm described by Chaitin. + * Since we are targeting an abstract machine where there are no limits on + * number of registers except how we set them, our goal here is to get to + * the minimum number of registers required to execute the function. + *

+ * We do want to implement spilling even though we do not need it for the + * abstract machine, but it is not yet implemented. We would spill to a + * stack attached to the abstract machine. * * TODO spilling */ public class ChaitinGraphColoringRegisterAllocator { - public ChaitinGraphColoringRegisterAllocator() { - } - public Map assignRegisters(CompiledFunction function, int numRegisters) { if (function.isSSA) throw new IllegalStateException("Register allocation should be done after exiting SSA"); - var g = coalesce(function); - var registers = registersInIR(function); - var colors = IntStream.range(0, numRegisters).boxed().toList(); - // TODO pre-assign regs to args + // Remove useless copy operations + InterferenceGraph g = coalesce(function); + // Get used registers + Set registers = registersInIR(function); + // Create color set + List colors = new ArrayList<>(IntStream.range(0, numRegisters).boxed().toList()); + // Function args are pre-assigned colors + // and we remove them from the register set + Map assignments = preAssignArgsToColors(function, registers, colors); // TODO spilling - var assignments = colorGraph(g, registers, new HashSet<>(colors)); + // execute graph coloring on remaining registers + assignments = colorGraph(g, registers, new HashSet<>(colors), assignments); + // update all instructions + // We simply set the slot on each register - rather than actually trying to replace them + updateInstructions(function, assignments); + // Compute and set the new framesize + function.setFrameSize(computeFrameSize(assignments)); return assignments; } + /** + * Frame size = max number of registers needed to execute the function + */ + private int computeFrameSize(Map assignments) { + return assignments.values().stream().mapToInt(k->k).max().orElse(0); + } + + /** + * Due to the way function args are received by the abstract machine, we need + * to assign them register slots starting from 0. After assigning colors/slots + * we remove these from the set so that the graph coloring algo does + */ + private Map preAssignArgsToColors(CompiledFunction function, Set registers, List colors) { + int count = 0; + Map assignments = new HashMap<>(); + for (Instruction instruction : function.entry.instructions) { + if (instruction instanceof Instruction.ArgInstruction argInstruction) { + Integer color = colors.get(count); + Register reg = argInstruction.arg().reg; + registers.remove(reg.nonSSAId()); // Remove register from set before changing slot + assignments.put(reg.nonSSAId(), color); + count++; + } + else break; + } + return assignments; + } + + private void updateInstructions(CompiledFunction function, Map assignments) { + var regPool = function.registerPool; + for (var entry : assignments.entrySet()) { + int reg = entry.getKey(); + int slot = entry.getValue(); + regPool.getReg(reg).updateSlot(slot); + } + } + /** * Chaitin: coalesce_nodes - coalesce away copy operations */ @@ -85,9 +137,7 @@ private void rewriteInstructions(CompiledFunction function, Instruction deadInst private Set registersInIR(CompiledFunction function) { Set registers = new HashSet<>(); for (var block: function.getBlocks()) { - Iterator iter = block.instructions.iterator(); - while (iter.hasNext()) { - Instruction instruction = iter.next(); + for (Instruction instruction: block.instructions) { if (instruction.definesVar()) registers.add(instruction.def().id); for (Register use: instruction.uses()) @@ -112,7 +162,7 @@ private Integer findNodeWithNeighborCountLessThan(InterferenceGraph g, Set getNeighborColors(InterferenceGraph g, Integer node, Map assignedColors) { Set colors = new HashSet<>(); for (var neighbour: g.neighbors(node)) { - var c = assignedColors.get(neighbour); + Integer c = assignedColors.get(neighbour); if (c != null) { colors.add(c); } @@ -137,18 +187,18 @@ private static HashSet subtract(Set originalSet, Integer node) /** * Chaitin: color_graph */ - private Map colorGraph(InterferenceGraph g, Set nodes, Set colors) { + private Map colorGraph(InterferenceGraph g, Set nodes, Set colors, Map preAssignedColors) { if (nodes.size() == 0) - return new HashMap<>(); - var numColors = colors.size(); - var node = findNodeWithNeighborCountLessThan(g, nodes, numColors); + return preAssignedColors; + int numColors = colors.size(); + Integer node = findNodeWithNeighborCountLessThan(g, nodes, numColors); if (node == null) return null; - var coloring = colorGraph(g.dup().subtract(node), subtract(nodes, node), colors); + Map coloring = colorGraph(g.dup().subtract(node), subtract(nodes, node), colors, preAssignedColors); if (coloring == null) return null; - var neighbourColors = getNeighborColors(g, node, coloring); - var color = chooseSomeColorNotAssignedToNeighbors(colors, neighbourColors); + Set neighbourColors = getNeighborColors(g, node, coloring); + Integer color = chooseSomeColorNotAssignedToNeighbors(colors, neighbourColors); coloring.put(node, color); return coloring; } diff --git a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/CompiledFunction.java b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/CompiledFunction.java index 85dd365..3e02457 100644 --- a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/CompiledFunction.java +++ b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/CompiledFunction.java @@ -21,7 +21,7 @@ public class CompiledFunction { private Type.TypeFunction functionType; public final RegisterPool registerPool; - private final int frameSlots; + private int frameSlots; public boolean isSSA; public boolean hasLiveness; @@ -76,6 +76,9 @@ private void generateArgInstructions(Scope scope) { public int frameSize() { return frameSlots; } + public void setFrameSize(int size) { + frameSlots = size; + } private void exitBlockIfNeeded() { if (currentBlock != null && @@ -134,6 +137,7 @@ private void compileStatement(AST.Stmt statement) { case AST.VarStmt letStmt -> { compileLet(letStmt); } + case AST.VarDeclStmt varDeclStmt -> {} case AST.IfElseStmt ifElseStmt -> { compileIf(ifElseStmt); } diff --git a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/Compiler.java b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/Compiler.java index 105102b..e390fd1 100644 --- a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/Compiler.java +++ b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/Compiler.java @@ -8,19 +8,24 @@ import com.compilerprogramming.ezlang.types.Type; import com.compilerprogramming.ezlang.types.TypeDictionary; -import java.util.BitSet; - public class Compiler { - private void compile(TypeDictionary typeDictionary) { + private void compile(TypeDictionary typeDictionary, boolean opt) { for (Symbol symbol: typeDictionary.getLocalSymbols()) { if (symbol instanceof Symbol.FunctionTypeSymbol functionSymbol) { Type.TypeFunction functionType = (Type.TypeFunction) functionSymbol.type; - functionType.code = new CompiledFunction(functionSymbol); + var function = new CompiledFunction(functionSymbol); + functionType.code = function; + if (opt) { + new Optimizer().optimize(function); + } } } } public TypeDictionary compileSrc(String src) { + return compileSrc(src, false); + } + public TypeDictionary compileSrc(String src, boolean opt) { Parser parser = new Parser(); var program = parser.parse(new Lexer(src)); var typeDict = new TypeDictionary(); @@ -28,7 +33,7 @@ public TypeDictionary compileSrc(String src) { sema.analyze(program); var sema2 = new SemaAssignTypes(typeDict); sema2.analyze(program); - compile(typeDict); + compile(typeDict, opt); return typeDict; } public static String dumpIR(TypeDictionary typeDictionary) { diff --git a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/EnterSSA.java b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/EnterSSA.java index 16cf7ee..57e30e5 100644 --- a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/EnterSSA.java +++ b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/EnterSSA.java @@ -1,5 +1,7 @@ package com.compilerprogramming.ezlang.compiler; +import com.compilerprogramming.ezlang.exceptions.CompilerException; + import java.util.*; /** @@ -178,7 +180,11 @@ static class BBSet { static class VersionStack { List stack = new ArrayList<>(); void push(Register.SSARegister r) { stack.add(r); } - Register.SSARegister top() { return stack.getLast(); } + Register.SSARegister top() { + if (stack.isEmpty()) + throw new CompilerException("Variable may not be initialized"); + return stack.getLast(); + } void pop() { stack.removeLast(); } } diff --git a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/ExitSSA.java b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/ExitSSA.java index 205f23a..684d11e 100644 --- a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/ExitSSA.java +++ b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/ExitSSA.java @@ -23,6 +23,7 @@ public ExitSSA(CompiledFunction function) { initStack(); insertCopies(function.entry); removePhis(); + function.isSSA = false; } private void removePhis() { diff --git a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/Operand.java b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/Operand.java index 2d002a9..3f8806c 100644 --- a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/Operand.java +++ b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/Operand.java @@ -21,7 +21,7 @@ public String toString() { } public static class RegisterOperand extends Operand { - public Register reg; + Register reg; public RegisterOperand(Register reg) { this.reg = reg; if (reg == null) diff --git a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/Optimizer.java b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/Optimizer.java new file mode 100644 index 0000000..73acce5 --- /dev/null +++ b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/Optimizer.java @@ -0,0 +1,10 @@ +package com.compilerprogramming.ezlang.compiler; + +public class Optimizer { + + public void optimize(CompiledFunction function) { + new EnterSSA(function); + new ExitSSA(function); + new ChaitinGraphColoringRegisterAllocator().assignRegisters(function, 64); + } +} diff --git a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/Register.java b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/Register.java index 259466e..f9f8296 100644 --- a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/Register.java +++ b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/Register.java @@ -21,11 +21,13 @@ public class Register { * The type of a register */ public final Type type; + private int slot; public Register(int id, String name, Type type) { this.id = id; this.name = name; this.type = type; + this.slot = id; } @Override public boolean equals(Object o) { @@ -44,7 +46,10 @@ public String name() { return name; } public int nonSSAId() { - return id; + return slot; + } + public void updateSlot(int slot) { + this.slot = slot; } /** diff --git a/optvm/src/test/java/com/compilerprogramming/ezlang/compiler/TestSSATransform.java b/optvm/src/test/java/com/compilerprogramming/ezlang/compiler/TestSSATransform.java index ee16af3..52dffb6 100644 --- a/optvm/src/test/java/com/compilerprogramming/ezlang/compiler/TestSSATransform.java +++ b/optvm/src/test/java/com/compilerprogramming/ezlang/compiler/TestSSATransform.java @@ -4,6 +4,7 @@ import com.compilerprogramming.ezlang.types.Type; import com.compilerprogramming.ezlang.types.TypeDictionary; import org.junit.Assert; +import org.junit.Ignore; import org.junit.Test; import java.util.Arrays; @@ -766,4 +767,51 @@ public void testSwapProblem() { Assert.assertEquals(expected, function.toStr(new StringBuilder(), false).toString()); } + @Test + public void testLiveness() { + String src = """ + func bar(x: Int)->Int { + var y = 0 + var z = 0 + while( x>1 ){ + y = x/2; + if (y > 3) { + x = x-y; + } + z = x-4; + if (z > 0) { + x = x/2; + } + z = z-1; + } + return x; + } + + func foo() { + return bar(10); + } + """; + String result = compileSrc(src); + System.out.println(result); + } + + @Test + @Ignore + public void testInit() { + // see issue #16 + String src = """ + func foo(x: Int) { + var z: Int + while (x > 0) { + z = 5 + if (x == 1) + z = z+1 + x = x - 1 + } + } + """; + String result = compileSrc(src); + System.out.println(result); + } + } \ No newline at end of file diff --git a/optvm/src/test/java/com/compilerprogramming/ezlang/interpreter/TestInterpreter.java b/optvm/src/test/java/com/compilerprogramming/ezlang/interpreter/TestInterpreter.java index 3be1663..fd87020 100644 --- a/optvm/src/test/java/com/compilerprogramming/ezlang/interpreter/TestInterpreter.java +++ b/optvm/src/test/java/com/compilerprogramming/ezlang/interpreter/TestInterpreter.java @@ -7,8 +7,11 @@ public class TestInterpreter { Value compileAndRun(String src, String mainFunction) { + return compileAndRun(src, mainFunction, false); + } + Value compileAndRun(String src, String mainFunction, boolean opt) { var compiler = new Compiler(); - var typeDict = compiler.compileSrc(src); + var typeDict = compiler.compileSrc(src, opt); var compiled = compiler.dumpIR(typeDict); System.out.println(compiled); var interpreter = new Interpreter(typeDict); @@ -121,4 +124,53 @@ func foo()->Int { && integerValue.value == 42); } + @Test + public void testFunction8() { + String src = """ + func factorial(num: Int)->Int { + var result = 1 + while (num > 1) + { + result = result * num + num = num - 1 + } + return result + } + func foo()->Int { + return factorial(5); + } + """; + var value = compileAndRun(src, "foo", true); + Assert.assertNotNull(value); + Assert.assertTrue(value instanceof Value.IntegerValue integerValue + && integerValue.value == 120); + } + + @Test + public void testFunction9() { + String src = """ + func fib(n: Int)->Int { + var i: Int; + var temp: Int; + var f1=1; + var f2=1; + i=n; + while( i>1 ){ + temp = f1+f2; + f1=f2; + f2=temp; + i=i-1; + } + return f2; + } + + func foo() { + return fib(10); + } + """; + var value = compileAndRun(src, "foo", true); + Assert.assertNotNull(value); + Assert.assertTrue(value instanceof Value.IntegerValue integerValue + && integerValue.value == 89); + } }