diff --git a/optvm/README.md b/optvm/README.md index df4d28e..ed5b9d0 100644 --- a/optvm/README.md +++ b/optvm/README.md @@ -4,4 +4,36 @@ This module implements various compiler optimization techniques such as: * Static Single Assignment * Liveness Analysis -* WIP Graph Coloring Register Allocator (Chaitin) \ No newline at end of file +* Graph Coloring Register Allocator (Chaitin) + +Our goal here is to perform optimizations on the Intermediate Representation targeting an abstract machine, rather than +a physical machine. Therefore, all our optimization passes will work on the instruction set of this abstract machine. + + +## Guide + +* [Register](src/main/java/com/compilerprogramming/ezlang/compiler/Register.java) - implements a virtual register. Virtual registers + have a name, type, and id - the id is unique, but name is not. Initially the compiler generates unique registers for every local + and temporary, the number of registers grows as we convert to SSA and back out of SSA. Finally, as we run the Chaitin register allocator + we shrink down the number of virtual registers to the minimum. +* [RegisterPool](src/main/java/com/compilerprogramming/ezlang/compiler/RegisterPool.java) - simple pool to allow us to find a register + by its id, and to allocate new virtual registers. +* [CompiledFunction](src/main/java/com/compilerprogramming/ezlang/compiler/CompiledFunction.java) - encapsulates the IR for a single function. +* [BasicBlock](src/main/java/com/compilerprogramming/ezlang/compiler/BasicBlock.java) - Defines our basic block - which contains instructions + that execute sequentially. A basic block ends with a branch. There are two distinguished basic blocks in every function: entry and exit. +* [BBHelper](src/main/java/com/compilerprogramming/ezlang/compiler/BBHelper.java) - Some utilities that manipulate basic blocks. +* [Operand](src/main/java/com/compilerprogramming/ezlang/compiler/Operand.java) - Operands in instructions. 
+* [Instruction](src/main/java/com/compilerprogramming/ezlang/compiler/Instruction.java) - Instructions in basic blocks. +* [DominatorTree](src/main/java/com/compilerprogramming/ezlang/compiler/DominatorTree.java) - Calculates dominator tree and dominance frontiers. +* [LiveSet](src/main/java/com/compilerprogramming/ezlang/compiler/LiveSet.java) - Bitset used to track liveness of registers. +* [Liveness](src/main/java/com/compilerprogramming/ezlang/compiler/Liveness.java) - Liveness calculator, works for both SSA and non-SSA forms. +* [EnterSSA](src/main/java/com/compilerprogramming/ezlang/compiler/EnterSSA.java) - Transforms into SSA, using algorithm by Preston Briggs. +* [ExitSSA](src/main/java/com/compilerprogramming/ezlang/compiler/ExitSSA.java) - Exits SSA form, using algorithm by Preston Briggs. +* [LoopFinder](src/main/java/com/compilerprogramming/ezlang/compiler/LoopFinder.java) - Discovers loops. +* [LoopNest](src/main/java/com/compilerprogramming/ezlang/compiler/LoopNest.java) - Representation of loop nesting. +* [InterferenceGraph](src/main/java/com/compilerprogramming/ezlang/compiler/InterferenceGraph.java) - Representation of an Interference Graph + required by the register allocator. +* [InterferenceGraphBuilder](src/main/java/com/compilerprogramming/ezlang/compiler/InterferenceGraphBuilder.java) - Constructs InterferenceGraph for a set + of basic blocks, using liveness information. +* [ChaitinGraphColoringRegisterAllocator](src/main/java/com/compilerprogramming/ezlang/compiler/ChaitinGraphColoringRegisterAllocator.java) - basic + Chaitin Graph Coloring Register Allocator - WIP. 
diff --git a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/BasicBlock.java b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/BasicBlock.java index c34659f..09745e3 100644 --- a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/BasicBlock.java +++ b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/BasicBlock.java @@ -126,6 +126,15 @@ public List phis() { } return list; } + public int whichPred(BasicBlock s) { + int i = 0; + for (BasicBlock p: s.predecessors) { + if (p == this) + return i; + i++; + } + throw new IllegalStateException(); + } public static StringBuilder toStr(StringBuilder sb, BasicBlock bb, BitSet visited, boolean dumpLiveness) { if (visited.get(bb.bid)) diff --git a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/ChaitinGraphColoringRegisterAllocator.java b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/ChaitinGraphColoringRegisterAllocator.java index 8fc2bc7..aa22a75 100644 --- a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/ChaitinGraphColoringRegisterAllocator.java +++ b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/ChaitinGraphColoringRegisterAllocator.java @@ -1,52 +1,75 @@ package com.compilerprogramming.ezlang.compiler; -import java.util.ArrayList; -import java.util.List; +import java.util.*; +import java.util.stream.IntStream; +/** + * Implement the original graph coloring algorithm described by Chaitin. 
+ * + * TODO spilling + */ public class ChaitinGraphColoringRegisterAllocator { - public ChaitinGraphColoringRegisterAllocator(CompiledFunction function) { - coalesce(function); + public ChaitinGraphColoringRegisterAllocator() { } - private void coalesce(CompiledFunction function) { + public Map assignRegisters(CompiledFunction function, int numRegisters) { + if (function.isSSA) throw new IllegalStateException("Register allocation should be done after exiting SSA"); + var g = coalesce(function); + var registers = registersInIR(function); + var colors = IntStream.range(0, numRegisters).boxed().toList(); + // TODO pre-assign regs to args + // TODO spilling + var assignments = colorGraph(g, registers, new HashSet<>(colors)); + return assignments; + } + + /** + * Chaitin: coalesce_nodes - coalesce away copy operations + */ + public InterferenceGraph coalesce(CompiledFunction function) { boolean changed = true; + InterferenceGraph igraph = null; while (changed) { - var igraph = new InterferenceGraphBuilder().build(function); - changed = coalesceRegisters(function, igraph); + igraph = new InterferenceGraphBuilder().build(function); + changed = coalesceCopyOperations(function, igraph); } + return igraph; } - private boolean coalesceRegisters(CompiledFunction function, InterferenceGraph igraph) { + /** + * Chaitin: coalesce_nodes - coalesce away copy operations + */ + private boolean coalesceCopyOperations(CompiledFunction function, InterferenceGraph igraph) { boolean changed = false; for (var block: function.getBlocks()) { - List instructionsToRemove = new ArrayList<>(); - for (int j = 0; j < block.instructions.size(); j++) { - Instruction i = block.instructions.get(j); - if (i instanceof Instruction.Move move - && move.from() instanceof Operand.RegisterOperand targetOperand) { + Iterator iter = block.instructions.iterator(); + while (iter.hasNext()) { + Instruction instruction = iter.next(); + if (instruction instanceof Instruction.Move move + && move.from() instanceof 
Operand.RegisterOperand registerTarget) { Register source = move.def(); - Register target = targetOperand.reg; + Register target = registerTarget.reg; if (source.id != target.id && !igraph.interfere(target.id, source.id)) { igraph.rename(source.id, target.id); - rewriteInstructions(function, i, source, target); - instructionsToRemove.add(j); + rewriteInstructions(function, instruction, source, target); + iter.remove(); changed = true; } } } - for (var j: instructionsToRemove) { - block.instructions.set(j, new Instruction.NoOp()); - } } return changed; } - private void rewriteInstructions(CompiledFunction function, Instruction notNeeded, Register source, Register target) { + /** + * Chaitin: rewrite_il + */ + private void rewriteInstructions(CompiledFunction function, Instruction deadInstruction, Register source, Register target) { for (var block: function.getBlocks()) { for (Instruction i: block.instructions) { - if (i == notNeeded) + if (i == deadInstruction) continue; if (i.definesVar() && source.id == i.def().id) i.replaceDef(target); @@ -54,4 +77,80 @@ private void rewriteInstructions(CompiledFunction function, Instruction notNeede } } } + + /** + * Get the list of registers in use in the Intermediate Code + * Chaitin: registers_in_il() + */ + private Set registersInIR(CompiledFunction function) { + Set registers = new HashSet<>(); + for (var block: function.getBlocks()) { + Iterator iter = block.instructions.iterator(); + while (iter.hasNext()) { + Instruction instruction = iter.next(); + if (instruction.definesVar()) + registers.add(instruction.def().id); + for (Register use: instruction.uses()) + registers.add(use.id); + } + } + return registers; + } + + /** + * Chaitin: color_graph line 2-3 + */ + private Integer findNodeWithNeighborCountLessThan(InterferenceGraph g, Set nodes, int numColors) { + for (var node: nodes) { + if (g.neighbors(node).size() < numColors) { + return node; + } + } + return null; + } + + private Set getNeighborColors(InterferenceGraph 
g, Integer node, Map assignedColors) { + Set colors = new HashSet<>(); + for (var neighbour: g.neighbors(node)) { + var c = assignedColors.get(neighbour); + if (c != null) { + colors.add(c); + } + } + return colors; + } + + private Integer chooseSomeColorNotAssignedToNeighbors(Set colors, Set neighborColors) { + // Create new color set that removes the colors assigned to neighbors + var set = new HashSet<>(colors); + set.removeAll(neighborColors); + // pick a random color (we pick the first) + return set.stream().findAny().orElseThrow(); + } + + private static HashSet subtract(Set originalSet, Integer node) { + var reducedSet = new HashSet<>(originalSet); + reducedSet.remove(node); + return reducedSet; + } + + /** + * Chaitin: color_graph + */ + private Map colorGraph(InterferenceGraph g, Set nodes, Set colors) { + if (nodes.size() == 0) + return new HashMap<>(); + var numColors = colors.size(); + var node = findNodeWithNeighborCountLessThan(g, nodes, numColors); + if (node == null) + return null; + var coloring = colorGraph(g.dup().subtract(node), subtract(nodes, node), colors); + if (coloring == null) + return null; + var neighbourColors = getNeighborColors(g, node, coloring); + var color = chooseSomeColorNotAssignedToNeighbors(colors, neighbourColors); + coloring.put(node, color); + return coloring; + } + } diff --git a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/SSATransform.java b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/EnterSSA.java similarity index 93% rename from optvm/src/main/java/com/compilerprogramming/ezlang/compiler/SSATransform.java rename to optvm/src/main/java/com/compilerprogramming/ezlang/compiler/EnterSSA.java index da802ea..16cf7ee 100644 --- a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/SSATransform.java +++ b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/EnterSSA.java @@ -8,7 +8,7 @@ * 'Practical Improvements to the Construction and Destruction * of Single Static Assigment 
Form' by Preston Briggs. */ -public class SSATransform { +public class EnterSSA { CompiledFunction function; DominatorTree domTree; @@ -22,12 +22,12 @@ public class SSATransform { int[] counters; VersionStack[] stacks; - public SSATransform(CompiledFunction bytecodeFunction) { + public EnterSSA(CompiledFunction bytecodeFunction) { this.function = bytecodeFunction; setupGlobals(); computeDomTreeAndDominanceFrontiers(); this.blocks = domTree.blocks; - findGlobalVars(); + findNonLocalNames(); insertPhis(); renameVars(); bytecodeFunction.isSSA = true; @@ -47,7 +47,7 @@ private void setupGlobals() { * Compute set of registers that are live across multiple blocks * i.e. are not exclusively used in a single block. */ - private void findGlobalVars() { + private void findNonLocalNames() { for (BasicBlock block : blocks) { var varKill = new HashSet(); for (Instruction instruction: block.instructions) { @@ -142,7 +142,7 @@ void search(BasicBlock block) { } // Update phis in successor blocks for (BasicBlock s: block.successors) { - int j = whichPred(s,block); + int j = block.whichPred(s); for (Instruction.Phi phi: s.phis()) { Register oldReg = phi.input(j); phi.replaceInput(j, stacks[oldReg.nonSSAId()].top()); @@ -161,16 +161,6 @@ void search(BasicBlock block) { } } - public static int whichPred(BasicBlock s, BasicBlock block) { - int i = 0; - for (BasicBlock p: s.predecessors) { - if (p == block) - return i; - i++; - } - throw new IllegalStateException(); - } - private void initVersionCounters() { counters = new int[nonLocalNames.length]; stacks = new VersionStack[nonLocalNames.length]; diff --git a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/ExitSSA.java b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/ExitSSA.java index ae64666..205f23a 100644 --- a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/ExitSSA.java +++ b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/ExitSSA.java @@ -82,7 +82,7 @@ private void 
scheduleCopies(BasicBlock block, List pushed) { Map map = new HashMap<>(); BitSet usedByAnother = new BitSet(function.registerPool.numRegisters()*2); for (BasicBlock s: block.successors) { - int j = SSATransform.whichPred(s, block); + int j = block.whichPred(s); for (Instruction.Phi phi: s.phis()) { Register dst = phi.value(); Register src = phi.input(j); // jth operand of phi node diff --git a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/InterferenceGraph.java b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/InterferenceGraph.java index 441a264..6691d06 100644 --- a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/InterferenceGraph.java +++ b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/InterferenceGraph.java @@ -3,7 +3,7 @@ import java.util.*; public class InterferenceGraph { - Map> edges = new HashMap<>(); + private Map> edges = new HashMap<>(); private Set addNode(Integer node) { var set = edges.get(node); @@ -24,6 +24,32 @@ public void addEdge(Integer from, Integer to) { set2.add(from); } + /** + * Remove a node from the interference graph + * deleting it from all adjacency lists + */ + public InterferenceGraph subtract(Integer node) { + edges.remove(node); + for (var key : edges.keySet()) { + var neighbours = edges.get(key); + neighbours.remove(node); + } + return this; + } + + /** + * Duplicate an interference graph + */ + public InterferenceGraph dup() { + var igraph = new InterferenceGraph(); + igraph.edges = new HashMap<>(); + for (var key : edges.keySet()) { + var neighbours = edges.get(key); + igraph.edges.put(key, new HashSet<>(neighbours)); + } + return igraph; + } + public boolean interfere(Integer from, Integer to) { var set = edges.get(from); return set != null && set.contains(to); @@ -50,8 +76,15 @@ public void rename(Integer source, Integer target) { } } - public Set adjacents(Integer node) { - return edges.get(node); + /** + * Get neighbours of the node + * Chaitin: neighbors() + */ + public 
Set neighbors(Integer node) { + var adjacents = edges.get(node); + if (adjacents == null) + adjacents = Collections.emptySet(); + return adjacents; } public static final class Edge { diff --git a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/LiveSet.java b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/LiveSet.java index b5d4da7..4e08232 100644 --- a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/LiveSet.java +++ b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/LiveSet.java @@ -32,7 +32,7 @@ public void remove(List regs) { remove(r); } } - public boolean isMember(Register r) { + public boolean contains(Register r) { return get(r.id); } public LiveSet intersect(LiveSet other) { @@ -43,7 +43,10 @@ public LiveSet union(LiveSet other) { or(other); return this; } - public LiveSet intersectNot(LiveSet other) { + /** + * Computes this - other. + */ + public LiveSet subtract(LiveSet other) { andNot(other); return this; } diff --git a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/Liveness.java b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/Liveness.java index a125547..8d0b322 100644 --- a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/Liveness.java +++ b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/Liveness.java @@ -59,7 +59,7 @@ private void init(List blocks) { } for (Instruction instruction : block.instructions) { for (Register use : instruction.uses()) { - if (!block.varKill.isMember(use)) + if (!block.varKill.contains(use)) block.UEVar.add(use); } if (instruction.definesVar() && !(instruction instanceof Instruction.Phi)) { @@ -74,7 +74,7 @@ private void init(List blocks) { // if there is loop back and there are cycles // such as e.g. 
the swap copy problem if (pred == block && - block.phiDefs.isMember(use)) + block.phiDefs.contains(use)) continue; pred.phiUses.add(use); } @@ -99,11 +99,11 @@ private void computeLiveness(List blocks) { // LiveOut(B) = U all S (LiveIn(S) \ PhiDefs(S)) U PhiUses(B) private boolean recomputeLiveOut(BasicBlock block) { LiveSet oldLiveOut = block.liveOut.dup(); - LiveSet t = block.liveOut.dup().intersectNot(block.varKill); + LiveSet t = block.liveOut.dup().subtract(block.varKill); block.liveIn.union(block.phiDefs).union(block.UEVar).union(t); block.liveOut.clear(); for (BasicBlock s: block.successors) { - t = s.liveIn.dup().intersectNot(s.phiDefs); + t = s.liveIn.dup().subtract(s.phiDefs); block.liveOut.union(t); } block.liveOut.union(block.phiUses); diff --git a/optvm/src/test/java/com/compilerprogramming/ezlang/compiler/TestChaitinRegAllocator.java b/optvm/src/test/java/com/compilerprogramming/ezlang/compiler/TestChaitinRegAllocator.java index 3c09ac0..4197e1f 100644 --- a/optvm/src/test/java/com/compilerprogramming/ezlang/compiler/TestChaitinRegAllocator.java +++ b/optvm/src/test/java/com/compilerprogramming/ezlang/compiler/TestChaitinRegAllocator.java @@ -15,7 +15,17 @@ public void test4() { Assert.assertEquals(2, edges.size()); Assert.assertTrue(edges.contains(new InterferenceGraph.Edge(0, 1))); Assert.assertTrue(edges.contains(new InterferenceGraph.Edge(0, 2))); - new ChaitinGraphColoringRegisterAllocator(function); - System.out.println(function.toStr(new StringBuilder(), true)); + var regAssignments = new ChaitinGraphColoringRegisterAllocator().assignRegisters(function, 64); + String result = function.toStr(new StringBuilder(), false).toString(); + Assert.assertEquals(""" +L0: + a = 1 + b = 2 + t = b+a + goto L1 +L1: +""", result); + Assert.assertEquals(regAssignments.size(), 3); + Assert.assertEquals(regAssignments.values().stream().sorted().distinct().count(), 2); } } \ No newline at end of file diff --git 
a/optvm/src/test/java/com/compilerprogramming/ezlang/compiler/TestInterferenceGraph.java b/optvm/src/test/java/com/compilerprogramming/ezlang/compiler/TestInterferenceGraph.java index bf05f1f..2cfe554 100644 --- a/optvm/src/test/java/com/compilerprogramming/ezlang/compiler/TestInterferenceGraph.java +++ b/optvm/src/test/java/com/compilerprogramming/ezlang/compiler/TestInterferenceGraph.java @@ -205,8 +205,8 @@ public void test5() { graph.addEdge(1, 2); Assert.assertTrue(graph.interfere(1, 2)); Assert.assertTrue(graph.interfere(2, 1)); - Assert.assertTrue(graph.adjacents(1).contains(2)); - Assert.assertTrue(graph.adjacents(2).contains(1)); + Assert.assertTrue(graph.neighbors(1).contains(2)); + Assert.assertTrue(graph.neighbors(2).contains(1)); } @Test @@ -220,10 +220,10 @@ public void test6() { Assert.assertTrue(graph.interfere(3, 1)); Assert.assertFalse(graph.interfere(2, 3)); Assert.assertFalse(graph.interfere(3, 2)); - Assert.assertTrue(graph.adjacents(1).contains(2)); - Assert.assertTrue(graph.adjacents(1).contains(3)); - Assert.assertTrue(graph.adjacents(2).contains(1)); - Assert.assertTrue(graph.adjacents(3).contains(1)); + Assert.assertTrue(graph.neighbors(1).contains(2)); + Assert.assertTrue(graph.neighbors(1).contains(3)); + Assert.assertTrue(graph.neighbors(2).contains(1)); + Assert.assertTrue(graph.neighbors(3).contains(1)); System.out.println(graph.generateDotOutput()); graph.rename(2, 3); System.out.println(graph.generateDotOutput()); @@ -233,9 +233,9 @@ public void test6() { Assert.assertTrue(graph.interfere(3, 1)); Assert.assertFalse(graph.interfere(2, 3)); Assert.assertFalse(graph.interfere(3, 2)); - Assert.assertFalse(graph.adjacents(1).contains(2)); - Assert.assertTrue(graph.adjacents(1).contains(3)); - Assert.assertTrue(graph.adjacents(3).contains(1)); + Assert.assertFalse(graph.neighbors(1).contains(2)); + Assert.assertTrue(graph.neighbors(1).contains(3)); + Assert.assertTrue(graph.neighbors(3).contains(1)); } @Test @@ -249,10 +249,10 @@ public 
void test7() { Assert.assertTrue(graph.interfere(3, 1)); Assert.assertFalse(graph.interfere(2, 3)); Assert.assertFalse(graph.interfere(3, 2)); - Assert.assertTrue(graph.adjacents(1).contains(2)); - Assert.assertTrue(graph.adjacents(1).contains(3)); - Assert.assertTrue(graph.adjacents(2).contains(1)); - Assert.assertTrue(graph.adjacents(3).contains(1)); + Assert.assertTrue(graph.neighbors(1).contains(2)); + Assert.assertTrue(graph.neighbors(1).contains(3)); + Assert.assertTrue(graph.neighbors(2).contains(1)); + Assert.assertTrue(graph.neighbors(3).contains(1)); System.out.println(graph.generateDotOutput()); graph.rename(1, 2); System.out.println(graph.generateDotOutput()); @@ -262,8 +262,8 @@ public void test7() { Assert.assertFalse(graph.interfere(3, 1)); Assert.assertTrue(graph.interfere(2, 3)); Assert.assertTrue(graph.interfere(3, 2)); - Assert.assertTrue(graph.adjacents(2).contains(3)); - Assert.assertTrue(graph.adjacents(3).contains(2)); + Assert.assertTrue(graph.neighbors(2).contains(3)); + Assert.assertTrue(graph.neighbors(3).contains(2)); } } diff --git a/optvm/src/test/java/com/compilerprogramming/ezlang/compiler/TestSSATransform.java b/optvm/src/test/java/com/compilerprogramming/ezlang/compiler/TestSSATransform.java index cbe8cab..ee16af3 100644 --- a/optvm/src/test/java/com/compilerprogramming/ezlang/compiler/TestSSATransform.java +++ b/optvm/src/test/java/com/compilerprogramming/ezlang/compiler/TestSSATransform.java @@ -22,7 +22,7 @@ String compileSrc(String src) { sb.append("Before SSA\n"); sb.append("==========\n"); BasicBlock.toStr(sb, functionBuilder.entry, new BitSet(), false); - new SSATransform(functionBuilder); + new EnterSSA(functionBuilder); sb.append("After SSA\n"); sb.append("=========\n"); BasicBlock.toStr(sb, functionBuilder.entry, new BitSet(), false);