diff --git a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/ChaitinGraphColoringRegisterAllocator.java b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/ChaitinGraphColoringRegisterAllocator.java
index aa22a75..4838eb0 100644
--- a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/ChaitinGraphColoringRegisterAllocator.java
+++ b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/ChaitinGraphColoringRegisterAllocator.java
@@ -4,26 +4,78 @@
import java.util.stream.IntStream;
/**
- * Implement the original graph coloring algorithm described by Chaitin.
+ * Implements the original graph coloring algorithm described by Chaitin.
+ * Since we are targeting an abstract machine where there are no limits on
+ * number of registers except how we set them, our goal here is to get to
+ * the minimum number of registers required to execute the function.
+ *
+ * We do want to implement spilling even though we do not need it for the
+ * abstract machine, but it is not yet implemented. We would spill to a
+ * stack attached to the abstract machine.
*
* TODO spilling
*/
public class ChaitinGraphColoringRegisterAllocator {
- public ChaitinGraphColoringRegisterAllocator() {
- }
-
public Map assignRegisters(CompiledFunction function, int numRegisters) {
if (function.isSSA) throw new IllegalStateException("Register allocation should be done after exiting SSA");
- var g = coalesce(function);
- var registers = registersInIR(function);
- var colors = IntStream.range(0, numRegisters).boxed().toList();
- // TODO pre-assign regs to args
+ // Remove useless copy operations
+ InterferenceGraph g = coalesce(function);
+ // Get used registers
+ Set registers = registersInIR(function);
+ // Create color set
+ List colors = new ArrayList<>(IntStream.range(0, numRegisters).boxed().toList());
+ // Function args are pre-assigned colors
+ // and we remove them from the register set
+ Map assignments = preAssignArgsToColors(function, registers, colors);
// TODO spilling
- var assignments = colorGraph(g, registers, new HashSet<>(colors));
+ // execute graph coloring on remaining registers
+ assignments = colorGraph(g, registers, new HashSet<>(colors), assignments);
+ // update all instructions
+ // We simply set the slot on each register - rather than actually trying to replace them
+ updateInstructions(function, assignments);
+ // Compute and set the new framesize
+ function.setFrameSize(computeFrameSize(assignments));
return assignments;
}
+ /**
+ * Frame size = max number of registers needed to execute the function
+ */
+ private int computeFrameSize(Map assignments) {
+ return assignments.values().stream().mapToInt(k->k).max().orElse(0);
+ }
+
+ /**
+ * Due to the way function args are received by the abstract machine, we need
+ * to assign them register slots starting from 0. After assigning colors/slots
+ * we remove these from the set so that the graph coloring algo does
+ */
+ private Map preAssignArgsToColors(CompiledFunction function, Set registers, List colors) {
+ int count = 0;
+ Map assignments = new HashMap<>();
+ for (Instruction instruction : function.entry.instructions) {
+ if (instruction instanceof Instruction.ArgInstruction argInstruction) {
+ Integer color = colors.get(count);
+ Register reg = argInstruction.arg().reg;
+ registers.remove(reg.nonSSAId()); // Remove register from set before changing slot
+ assignments.put(reg.nonSSAId(), color);
+ count++;
+ }
+ else break;
+ }
+ return assignments;
+ }
+
+ private void updateInstructions(CompiledFunction function, Map assignments) {
+ var regPool = function.registerPool;
+ for (var entry : assignments.entrySet()) {
+ int reg = entry.getKey();
+ int slot = entry.getValue();
+ regPool.getReg(reg).updateSlot(slot);
+ }
+ }
+
/**
* Chaitin: coalesce_nodes - coalesce away copy operations
*/
@@ -85,9 +137,7 @@ private void rewriteInstructions(CompiledFunction function, Instruction deadInst
private Set registersInIR(CompiledFunction function) {
Set registers = new HashSet<>();
for (var block: function.getBlocks()) {
- Iterator iter = block.instructions.iterator();
- while (iter.hasNext()) {
- Instruction instruction = iter.next();
+ for (Instruction instruction: block.instructions) {
if (instruction.definesVar())
registers.add(instruction.def().id);
for (Register use: instruction.uses())
@@ -112,7 +162,7 @@ private Integer findNodeWithNeighborCountLessThan(InterferenceGraph g, Set getNeighborColors(InterferenceGraph g, Integer node, Map assignedColors) {
Set colors = new HashSet<>();
for (var neighbour: g.neighbors(node)) {
- var c = assignedColors.get(neighbour);
+ Integer c = assignedColors.get(neighbour);
if (c != null) {
colors.add(c);
}
@@ -137,18 +187,18 @@ private static HashSet subtract(Set originalSet, Integer node)
/**
* Chaitin: color_graph
*/
- private Map colorGraph(InterferenceGraph g, Set nodes, Set colors) {
+ private Map colorGraph(InterferenceGraph g, Set nodes, Set colors, Map preAssignedColors) {
if (nodes.size() == 0)
- return new HashMap<>();
- var numColors = colors.size();
- var node = findNodeWithNeighborCountLessThan(g, nodes, numColors);
+ return preAssignedColors;
+ int numColors = colors.size();
+ Integer node = findNodeWithNeighborCountLessThan(g, nodes, numColors);
if (node == null)
return null;
- var coloring = colorGraph(g.dup().subtract(node), subtract(nodes, node), colors);
+ Map coloring = colorGraph(g.dup().subtract(node), subtract(nodes, node), colors, preAssignedColors);
if (coloring == null)
return null;
- var neighbourColors = getNeighborColors(g, node, coloring);
- var color = chooseSomeColorNotAssignedToNeighbors(colors, neighbourColors);
+ Set neighbourColors = getNeighborColors(g, node, coloring);
+ Integer color = chooseSomeColorNotAssignedToNeighbors(colors, neighbourColors);
coloring.put(node, color);
return coloring;
}
diff --git a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/CompiledFunction.java b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/CompiledFunction.java
index 85dd365..3e02457 100644
--- a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/CompiledFunction.java
+++ b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/CompiledFunction.java
@@ -21,7 +21,7 @@ public class CompiledFunction {
private Type.TypeFunction functionType;
public final RegisterPool registerPool;
- private final int frameSlots;
+ private int frameSlots;
public boolean isSSA;
public boolean hasLiveness;
@@ -76,6 +76,9 @@ private void generateArgInstructions(Scope scope) {
public int frameSize() {
return frameSlots;
}
+ public void setFrameSize(int size) {
+ frameSlots = size;
+ }
private void exitBlockIfNeeded() {
if (currentBlock != null &&
@@ -134,6 +137,7 @@ private void compileStatement(AST.Stmt statement) {
case AST.VarStmt letStmt -> {
compileLet(letStmt);
}
+ case AST.VarDeclStmt varDeclStmt -> {}
case AST.IfElseStmt ifElseStmt -> {
compileIf(ifElseStmt);
}
diff --git a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/Compiler.java b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/Compiler.java
index 105102b..e390fd1 100644
--- a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/Compiler.java
+++ b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/Compiler.java
@@ -8,19 +8,24 @@
import com.compilerprogramming.ezlang.types.Type;
import com.compilerprogramming.ezlang.types.TypeDictionary;
-import java.util.BitSet;
-
public class Compiler {
- private void compile(TypeDictionary typeDictionary) {
+ private void compile(TypeDictionary typeDictionary, boolean opt) {
for (Symbol symbol: typeDictionary.getLocalSymbols()) {
if (symbol instanceof Symbol.FunctionTypeSymbol functionSymbol) {
Type.TypeFunction functionType = (Type.TypeFunction) functionSymbol.type;
- functionType.code = new CompiledFunction(functionSymbol);
+ var function = new CompiledFunction(functionSymbol);
+ functionType.code = function;
+ if (opt) {
+ new Optimizer().optimize(function);
+ }
}
}
}
public TypeDictionary compileSrc(String src) {
+ return compileSrc(src, false);
+ }
+ public TypeDictionary compileSrc(String src, boolean opt) {
Parser parser = new Parser();
var program = parser.parse(new Lexer(src));
var typeDict = new TypeDictionary();
@@ -28,7 +33,7 @@ public TypeDictionary compileSrc(String src) {
sema.analyze(program);
var sema2 = new SemaAssignTypes(typeDict);
sema2.analyze(program);
- compile(typeDict);
+ compile(typeDict, opt);
return typeDict;
}
public static String dumpIR(TypeDictionary typeDictionary) {
diff --git a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/EnterSSA.java b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/EnterSSA.java
index 16cf7ee..57e30e5 100644
--- a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/EnterSSA.java
+++ b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/EnterSSA.java
@@ -1,5 +1,7 @@
package com.compilerprogramming.ezlang.compiler;
+import com.compilerprogramming.ezlang.exceptions.CompilerException;
+
import java.util.*;
/**
@@ -178,7 +180,11 @@ static class BBSet {
static class VersionStack {
List stack = new ArrayList<>();
void push(Register.SSARegister r) { stack.add(r); }
- Register.SSARegister top() { return stack.getLast(); }
+ Register.SSARegister top() {
+ if (stack.isEmpty())
+ throw new CompilerException("Variable may not be initialized");
+ return stack.getLast();
+ }
void pop() { stack.removeLast(); }
}
diff --git a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/ExitSSA.java b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/ExitSSA.java
index 205f23a..684d11e 100644
--- a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/ExitSSA.java
+++ b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/ExitSSA.java
@@ -23,6 +23,7 @@ public ExitSSA(CompiledFunction function) {
initStack();
insertCopies(function.entry);
removePhis();
+ function.isSSA = false;
}
private void removePhis() {
diff --git a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/Operand.java b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/Operand.java
index 2d002a9..3f8806c 100644
--- a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/Operand.java
+++ b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/Operand.java
@@ -21,7 +21,7 @@ public String toString() {
}
public static class RegisterOperand extends Operand {
- public Register reg;
+ Register reg;
public RegisterOperand(Register reg) {
this.reg = reg;
if (reg == null)
diff --git a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/Optimizer.java b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/Optimizer.java
new file mode 100644
index 0000000..73acce5
--- /dev/null
+++ b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/Optimizer.java
@@ -0,0 +1,10 @@
+package com.compilerprogramming.ezlang.compiler;
+
+public class Optimizer {
+
+ public void optimize(CompiledFunction function) {
+ new EnterSSA(function);
+ new ExitSSA(function);
+ new ChaitinGraphColoringRegisterAllocator().assignRegisters(function, 64);
+ }
+}
diff --git a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/Register.java b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/Register.java
index 259466e..f9f8296 100644
--- a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/Register.java
+++ b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/Register.java
@@ -21,11 +21,13 @@ public class Register {
* The type of a register
*/
public final Type type;
+ private int slot;
public Register(int id, String name, Type type) {
this.id = id;
this.name = name;
this.type = type;
+ this.slot = id;
}
@Override
public boolean equals(Object o) {
@@ -44,7 +46,10 @@ public String name() {
return name;
}
public int nonSSAId() {
- return id;
+ return slot;
+ }
+ public void updateSlot(int slot) {
+ this.slot = slot;
}
/**
diff --git a/optvm/src/test/java/com/compilerprogramming/ezlang/compiler/TestSSATransform.java b/optvm/src/test/java/com/compilerprogramming/ezlang/compiler/TestSSATransform.java
index ee16af3..52dffb6 100644
--- a/optvm/src/test/java/com/compilerprogramming/ezlang/compiler/TestSSATransform.java
+++ b/optvm/src/test/java/com/compilerprogramming/ezlang/compiler/TestSSATransform.java
@@ -4,6 +4,7 @@
import com.compilerprogramming.ezlang.types.Type;
import com.compilerprogramming.ezlang.types.TypeDictionary;
import org.junit.Assert;
+import org.junit.Ignore;
import org.junit.Test;
import java.util.Arrays;
@@ -766,4 +767,51 @@ public void testSwapProblem() {
Assert.assertEquals(expected, function.toStr(new StringBuilder(), false).toString());
}
+ @Test
+ public void testLiveness() {
+ String src = """
+ func bar(x: Int)->Int {
+ var y = 0
+ var z = 0
+ while( x>1 ){
+ y = x/2;
+ if (y > 3) {
+ x = x-y;
+ }
+ z = x-4;
+ if (z > 0) {
+ x = x/2;
+ }
+ z = z-1;
+ }
+ return x;
+ }
+
+ func foo() {
+ return bar(10);
+ }
+ """;
+ String result = compileSrc(src);
+ System.out.println(result);
+ }
+
+ @Test
+ @Ignore
+ public void testInit() {
+ // see issue #16
+ String src = """
+ func foo(x: Int) {
+ var z: Int
+ while (x > 0) {
+ z = 5
+ if (x == 1)
+ z = z+1
+ x = x - 1
+ }
+ }
+ """;
+ String result = compileSrc(src);
+ System.out.println(result);
+ }
+
}
\ No newline at end of file
diff --git a/optvm/src/test/java/com/compilerprogramming/ezlang/interpreter/TestInterpreter.java b/optvm/src/test/java/com/compilerprogramming/ezlang/interpreter/TestInterpreter.java
index 3be1663..fd87020 100644
--- a/optvm/src/test/java/com/compilerprogramming/ezlang/interpreter/TestInterpreter.java
+++ b/optvm/src/test/java/com/compilerprogramming/ezlang/interpreter/TestInterpreter.java
@@ -7,8 +7,11 @@
public class TestInterpreter {
Value compileAndRun(String src, String mainFunction) {
+ return compileAndRun(src, mainFunction, false);
+ }
+ Value compileAndRun(String src, String mainFunction, boolean opt) {
var compiler = new Compiler();
- var typeDict = compiler.compileSrc(src);
+ var typeDict = compiler.compileSrc(src, opt);
var compiled = compiler.dumpIR(typeDict);
System.out.println(compiled);
var interpreter = new Interpreter(typeDict);
@@ -121,4 +124,53 @@ func foo()->Int {
&& integerValue.value == 42);
}
+ @Test
+ public void testFunction8() {
+ String src = """
+ func factorial(num: Int)->Int {
+ var result = 1
+ while (num > 1)
+ {
+ result = result * num
+ num = num - 1
+ }
+ return result
+ }
+ func foo()->Int {
+ return factorial(5);
+ }
+ """;
+ var value = compileAndRun(src, "foo", true);
+ Assert.assertNotNull(value);
+ Assert.assertTrue(value instanceof Value.IntegerValue integerValue
+ && integerValue.value == 120);
+ }
+
+ @Test
+ public void testFunction9() {
+ String src = """
+ func fib(n: Int)->Int {
+ var i: Int;
+ var temp: Int;
+ var f1=1;
+ var f2=1;
+ i=n;
+ while( i>1 ){
+ temp = f1+f2;
+ f1=f2;
+ f2=temp;
+ i=i-1;
+ }
+ return f2;
+ }
+
+ func foo() {
+ return fib(10);
+ }
+ """;
+ var value = compileAndRun(src, "foo", true);
+ Assert.assertNotNull(value);
+ Assert.assertTrue(value instanceof Value.IntegerValue integerValue
+ && integerValue.value == 89);
+ }
}