diff --git a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/BasicBlock.java b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/BasicBlock.java index 9a42c6a..91fe9ac 100644 --- a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/BasicBlock.java +++ b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/BasicBlock.java @@ -49,18 +49,18 @@ public class BasicBlock { * VarKill contains all the variables that are defined * in the block. */ - BitSet varKill; + LiveSet varKill; /** * UEVar contains upward-exposed variables in the block, * i.e. those variables that are used in the block prior to * any redefinition in the block. */ - BitSet UEVar; + LiveSet UEVar; /** * LiveOut is the union of variables that are live at the * head of some block that is a successor of this block. */ - BitSet liveOut; + LiveSet liveOut; // ----------------------- public BasicBlock(int bid, boolean loopHead) { diff --git a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/CompiledFunction.java b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/CompiledFunction.java index ed7502c..6c1170d 100644 --- a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/CompiledFunction.java +++ b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/CompiledFunction.java @@ -560,4 +560,9 @@ public StringBuilder toStr(StringBuilder sb, boolean verbose) { BasicBlock.toStr(sb, entry, new BitSet(), verbose); return sb; } + + public void livenessAnalysis() { + new Liveness(this); + this.hasLiveness = true; + } } diff --git a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/ExitSSA.java b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/ExitSSA.java index 163bbbd..fb9bb03 100644 --- a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/ExitSSA.java +++ b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/ExitSSA.java @@ -17,9 +17,7 @@ public class ExitSSA { public ExitSSA(CompiledFunction function) { this.function = function; if (!function.isSSA) throw new IllegalStateException(); - if (!function.hasLiveness) { - new Liveness().computeLiveness(function); - } + function.livenessAnalysis(); tree = new DominatorTree(function.entry); initStack(); insertCopies(function.entry); @@ -115,6 +113,13 @@ private void scheduleCopies(BasicBlock block, List pushed) { final CopyItem copyItem = workList.remove(0); final Register src = copyItem.src; final Register dest = copyItem.dest; + /* Engineering a Compiler: We can avoid the lost copy + problem by checking the liveness of the target name + for each copy that we try to insert. When we discover + a copy target that is live, we must preserve the live + value in a temporary name and rewrite subsequent uses to + refer to the temporary name. + */ if (block.liveOut.get(dest.id)) { /* Insert a copy from dest to a new temp t at phi node defining dest */ final Register t = addMoveToTempAfterPhi(block, dest); @@ -125,11 +130,19 @@ private void scheduleCopies(BasicBlock block, List pushed) { addMoveAtBBEnd(block, map.get(src.id), dest); map.put(src.id, dest); /* If src is the name of a dest in copySet add item to worklist */ + /* see comment on phi cycles below. */ CopyItem item = isCycle(copySet, src); if (item != null) { workList.add(item); } } + /* Engineering a Compiler: To solve the swap problem + we can detect cases where phi functions reference the + targets of other phi functions in the same block. For each + cycle of references, it must insert a copy to a temporary + that breaks the cycle. Then we can schedule the copies to + respect the dependencies implied by the phi functions. + */ if (!copySet.isEmpty()) { CopyItem copyItem = copySet.remove(0); /* Insert a copy from dst to new temp at the end of Block */ diff --git a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/InterferenceGraph.java b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/InterferenceGraph.java new file mode 100644 index 0000000..ede1f7a --- /dev/null +++ b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/InterferenceGraph.java @@ -0,0 +1,80 @@ +package com.compilerprogramming.ezlang.compiler; + +import java.util.*; + +public class InterferenceGraph { + Map> edges = new HashMap<>(); + + private Set addNode(Integer node) { + var set = edges.get(node); + if (set == null) { + set = new HashSet<>(); + edges.put(node, set); + } + return set; + } + + public void addEdge(Integer from, Integer to) { + if (from == to) { + return; + } + var set1 = addNode(from); + var set2 = addNode(to); + set1.add(to); + set2.add(from); + } + + public boolean containsEdge(Integer from, Integer to) { + var set = edges.get(from); + return set != null && set.contains(to); + } + + public Set adjacents(Integer node) { + return edges.get(node); + } + + public static final class Edge { + public final int from; + public final int to; + public Edge(int from, int to) { + this.from = from; + this.to = to; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Edge edge = (Edge) o; + return (from == edge.from && to == edge.to) + || (from == edge.to && to == edge.from); + } + + @Override + public int hashCode() { + return from+to; + } + } + + public Set getEdges() { + Set all = new HashSet<>(); + for (Integer from: edges.keySet()) { + var set = edges.get(from); + for (Integer to: set) { + all.add(new Edge(from, to)); + } + } + return all; + } + + public String generateDotOutput() { + StringBuilder sb = new StringBuilder(); + sb.append("digraph IGraph {\n"); + for (var edge: getEdges()) { + sb.append(edge.from).append("->").append(edge.to).append(";\n"); + } + sb.append("}\n"); + return sb.toString(); + } + +} diff --git a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/InterferenceGraphBuilder.java b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/InterferenceGraphBuilder.java new file mode 100644 index 0000000..bac880e --- /dev/null +++ b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/InterferenceGraphBuilder.java @@ -0,0 +1,51 @@ +package com.compilerprogramming.ezlang.compiler; + +public class InterferenceGraphBuilder { + + public InterferenceGraph build(CompiledFunction function) { + InterferenceGraph graph = new InterferenceGraph(); + // Calculate liveOut for all basic blocks + function.livenessAnalysis(); + System.out.println(function.toStr(new StringBuilder(), true)); + var blocks = BBHelper.findAllBlocks(function.entry); + for (var b : blocks) { + // Start with the set of live vars at the end of the block + // This liveness will be updated as we look through the + // instructions in the block + var liveNow = b.liveOut.dup(); + // liveNow is initially the set of values that are live (and avail?) at the + // end of the block. + // Process each instruction in the block in reverse order + for (var i: b.instructions.reversed()) { + if (i instanceof Instruction.Move || + i instanceof Instruction.Phi) { + // Move(copy) instructions are handled specially to avoid + // adding an undesirable interference between the source and + // destination (section 2.2.2 in Briggs thesis) + // Engineering a Compiler: The copy operation does not + // create an interference cause both values can occupy the + // same register + // Same argument applies to phi. + liveNow.remove(i.uses()); + } + if (i.definesVar()) { + var def = i.def(); + // Defined vars interfere with all members of the live set + addInterference(graph, def, liveNow); + // Defined vars are removed from the live set + liveNow.dead(def); + } + // All used vars are added to the live set + liveNow.live(i.uses()); + } + } + return graph; + } + + private static void addInterference(InterferenceGraph graph, Register def, LiveSet liveSet) { + for (int regNum = liveSet.nextSetBit(0); regNum >= 0; regNum = liveSet.nextSetBit(regNum+1)) { + if (regNum != def.id) + graph.addEdge(regNum, def.id); + } + } +} diff --git a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/LiveSet.java b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/LiveSet.java new file mode 100644 index 0000000..b5d4da7 --- /dev/null +++ b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/LiveSet.java @@ -0,0 +1,50 @@ +package com.compilerprogramming.ezlang.compiler; + +import java.util.BitSet; +import java.util.List; + +public class LiveSet extends BitSet { + public LiveSet(int numRegs) { + super(numRegs); + } + public LiveSet dup() { + return (LiveSet) clone(); + } + public void live(Register r) { + set(r.id, true); + } + public void dead(Register r) { + set(r.id, false); + } + public void live(List regs) { + for (Register r : regs) { + live(r); + } + } + public void add(Register r) { + set(r.id, true); + } + public void remove(Register r) { + set(r.id, false); + } + public void remove(List regs) { + for (Register r : regs) { + remove(r); + } + } + public boolean isMember(Register r) { + return get(r.id); + } + public LiveSet intersect(LiveSet other) { + and(other); + return this; + } + public LiveSet union(LiveSet other) { + or(other); + return this; + } + public LiveSet intersectNot(LiveSet other) { + andNot(other); + return this; + } +} diff --git a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/Liveness.java b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/Liveness.java index b816231..d9c9382 100644 --- a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/Liveness.java +++ b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/Liveness.java @@ -10,7 +10,7 @@ */ public class Liveness { - public void computeLiveness(CompiledFunction function) { + public Liveness(CompiledFunction function) { List blocks = BBHelper.findAllBlocks(function.entry); RegisterPool regPool = function.registerPool; init(regPool, blocks); @@ -21,19 +21,17 @@ public void computeLiveness(CompiledFunction function) { private void init(RegisterPool regPool, List blocks) { int numRegisters = regPool.numRegisters(); for (BasicBlock block : blocks) { - block.UEVar = new BitSet(numRegisters); - block.varKill = new BitSet(numRegisters); - block.liveOut = new BitSet(numRegisters); + block.UEVar = new LiveSet(numRegisters); + block.varKill = new LiveSet(numRegisters); + block.liveOut = new LiveSet(numRegisters); for (Instruction instruction : block.instructions) { - if (instruction.usesVars()) { - for (Register use : instruction.uses()) { - if (!block.varKill.get(use.id)) - block.UEVar.set(use.id); - } + for (Register use : instruction.uses()) { + if (!block.varKill.isMember(use)) + block.UEVar.add(use); } if (instruction.definesVar()) { Register def = instruction.def(); - block.varKill.set(def.id); + block.varKill.add(def); } } } @@ -51,15 +49,15 @@ private void computeLiveness(List blocks) { } private boolean recomputeLiveOut(BasicBlock block) { - BitSet oldLiveOut = (BitSet) block.liveOut.clone(); + LiveSet oldLiveOut = block.liveOut.dup(); for (BasicBlock m: block.successors) { - BitSet mLiveIn = (BitSet) m.liveOut.clone(); + LiveSet mLiveIn = m.liveOut.dup(); // LiveOut(m) intersect not VarKill(m) - mLiveIn.andNot(m.varKill); + mLiveIn.intersectNot(m.varKill); // UEVar(m) union (LiveOut(m) intersect not VarKill(m)) - mLiveIn.or(m.UEVar); + mLiveIn.union(m.UEVar); // LiveOut(block) =union (UEVar(m) union (LiveOut(m) intersect not VarKill(m))) - block.liveOut.or(mLiveIn); + block.liveOut.union(mLiveIn); } return !oldLiveOut.equals(block.liveOut); } diff --git a/optvm/src/test/java/com/compilerprogramming/ezlang/compiler/TestInterferenceGraph.java b/optvm/src/test/java/com/compilerprogramming/ezlang/compiler/TestInterferenceGraph.java new file mode 100644 index 0000000..674df1f --- /dev/null +++ b/optvm/src/test/java/com/compilerprogramming/ezlang/compiler/TestInterferenceGraph.java @@ -0,0 +1,201 @@ +package com.compilerprogramming.ezlang.compiler; + +import com.compilerprogramming.ezlang.types.Symbol; +import com.compilerprogramming.ezlang.types.Type; +import com.compilerprogramming.ezlang.types.TypeDictionary; +import org.junit.Assert; +import org.junit.Test; + +public class TestInterferenceGraph { + + private CompiledFunction buildTest1() { + TypeDictionary typeDictionary = new TypeDictionary(); + Type.TypeFunction functionType = new Type.TypeFunction("foo"); + functionType.addArg(new Symbol.ParameterSymbol("a", typeDictionary.INT)); + functionType.setReturnType(typeDictionary.INT); + CompiledFunction function = new CompiledFunction(functionType); + RegisterPool regPool = function.registerPool; + Register a = regPool.newReg("a", typeDictionary.INT); + Register b = regPool.newReg("b", typeDictionary.INT); + Register c = regPool.newReg("c", typeDictionary.INT); + Register d = regPool.newReg("d", typeDictionary.INT); + function.code(new Instruction.ArgInstruction(new Operand.LocalRegisterOperand(a))); + function.code(new Instruction.Binary( + "+", + new Operand.RegisterOperand(a), + new Operand.RegisterOperand(b), + new Operand.ConstantOperand(1, typeDictionary.INT))); + function.code(new Instruction.Binary( + "*", + new Operand.RegisterOperand(c), + new Operand.RegisterOperand(b), + new Operand.RegisterOperand(b))); + function.code(new Instruction.Binary( + "+", + new Operand.RegisterOperand(b), + new Operand.RegisterOperand(c), + new Operand.ConstantOperand(1, typeDictionary.INT))); + function.code(new Instruction.Binary( + "*", + new Operand.RegisterOperand(d), + new Operand.RegisterOperand(b), + new Operand.RegisterOperand(a))); + function.code(new Instruction.Ret(new Operand.RegisterOperand(d))); + function.startBlock(function.exit); + function.isSSA = false; + + System.out.println(function.toStr(new StringBuilder(), true)); + + return function; + } + + @Test + public void test1() { + CompiledFunction function = buildTest1(); + var graph = new InterferenceGraphBuilder().build(function); + System.out.println(graph.generateDotOutput()); + var edges = graph.getEdges(); + Assert.assertEquals(2, edges.size()); + Assert.assertTrue(edges.contains(new InterferenceGraph.Edge(0, 1))); + Assert.assertTrue(edges.contains(new InterferenceGraph.Edge(0, 2))); + } + + /* + Engineering a Compiler, 2nd ed, page 700 + + B0 + a = 1 + if a B1 else B2 + B1 + b = 2 + d = b + goto B3 + B2 + c = 1 + d = c + goto B3 + B3 + t = a+d + + */ + private CompiledFunction buildTest2() { + TypeDictionary typeDictionary = new TypeDictionary(); + Type.TypeFunction functionType = new Type.TypeFunction("foo"); + functionType.setReturnType(typeDictionary.VOID); + CompiledFunction function = new CompiledFunction(functionType); + RegisterPool regPool = function.registerPool; + Register a = regPool.newReg("a", typeDictionary.INT); + Register b = regPool.newReg("b", typeDictionary.INT); + Register c = regPool.newReg("c", typeDictionary.INT); + Register d = regPool.newReg("d", typeDictionary.INT); + Register t = regPool.newReg("t", typeDictionary.INT); + BasicBlock b1 = function.createBlock(); + BasicBlock b2 = function.createBlock(); + BasicBlock b3 = function.createBlock(); + + function.code(new Instruction.Move( + new Operand.ConstantOperand(1, typeDictionary.INT), + new Operand.RegisterOperand(a))); + function.code(new Instruction.ConditionalBranch( + function.currentBlock, + new Operand.RegisterOperand(a), + b1, b2)); + function.startBlock(b1); + function.code(new Instruction.Move( + new Operand.ConstantOperand(2, typeDictionary.INT), + new Operand.RegisterOperand(b))); + function.code(new Instruction.Move( + new Operand.RegisterOperand(b), + new Operand.RegisterOperand(d))); + function.jumpTo(b3); + function.startBlock(b2); + function.code(new Instruction.Move( + new Operand.ConstantOperand(1, typeDictionary.INT), + new Operand.RegisterOperand(c))); + function.code(new Instruction.Move( + new Operand.RegisterOperand(c), + new Operand.RegisterOperand(d))); + function.jumpTo(b3); + function.startBlock(b3); + function.code(new Instruction.Binary( + "+", + new Operand.RegisterOperand(t), + new Operand.RegisterOperand(a), + new Operand.RegisterOperand(d))); + function.startBlock(function.exit); + function.isSSA = false; + + System.out.println(function.toStr(new StringBuilder(), true)); + + return function; + } + + @Test + public void test2() { + CompiledFunction function = buildTest2(); + var graph = new InterferenceGraphBuilder().build(function); + System.out.println(graph.generateDotOutput()); + var edges = graph.getEdges(); + Assert.assertEquals(3, edges.size()); + Assert.assertTrue(edges.contains(new InterferenceGraph.Edge(0, 1))); + Assert.assertTrue(edges.contains(new InterferenceGraph.Edge(0, 2))); + Assert.assertTrue(edges.contains(new InterferenceGraph.Edge(0, 1))); + Assert.assertTrue(edges.contains(new InterferenceGraph.Edge(0, 3))); + } + + @Test + public void test3() { + CompiledFunction function = TestLiveness.buildTest3(); + var graph = new InterferenceGraphBuilder().build(function); + System.out.println(graph.generateDotOutput()); + var edges = graph.getEdges(); + Assert.assertEquals(1, edges.size()); + Assert.assertTrue(edges.contains(new InterferenceGraph.Edge(0, 1))); + } + + /* Test move does not interfere with uses */ + private CompiledFunction buildTest4() { + TypeDictionary typeDictionary = new TypeDictionary(); + Type.TypeFunction functionType = new Type.TypeFunction("foo"); + functionType.setReturnType(typeDictionary.VOID); + CompiledFunction function = new CompiledFunction(functionType); + RegisterPool regPool = function.registerPool; + Register a = regPool.newReg("a", typeDictionary.INT); + Register b = regPool.newReg("b", typeDictionary.INT); + Register c = regPool.newReg("c", typeDictionary.INT); + Register t = regPool.newReg("t", typeDictionary.INT); + + function.code(new Instruction.Move( + new Operand.ConstantOperand(1, typeDictionary.INT), + new Operand.RegisterOperand(a))); + function.code(new Instruction.Move( + new Operand.ConstantOperand(2, typeDictionary.INT), + new Operand.RegisterOperand(b))); + function.code(new Instruction.Move( + new Operand.RegisterOperand(b), + new Operand.RegisterOperand(c))); + function.code(new Instruction.Binary( + "+", + new Operand.RegisterOperand(t), + new Operand.RegisterOperand(c), + new Operand.RegisterOperand(a))); + function.startBlock(function.exit); + function.isSSA = false; + + System.out.println(function.toStr(new StringBuilder(), true)); + + return function; + } + + /* Test move does not interfere with uses */ + @Test + public void test4() { + CompiledFunction function = buildTest4(); + var graph = new InterferenceGraphBuilder().build(function); + System.out.println(graph.generateDotOutput()); + var edges = graph.getEdges(); + Assert.assertEquals(2, edges.size()); + Assert.assertTrue(edges.contains(new InterferenceGraph.Edge(0, 1))); + Assert.assertTrue(edges.contains(new InterferenceGraph.Edge(0, 2))); + } +} diff --git a/optvm/src/test/java/com/compilerprogramming/ezlang/compiler/TestLiveness.java b/optvm/src/test/java/com/compilerprogramming/ezlang/compiler/TestLiveness.java index 6511a95..8096dae 100644 --- a/optvm/src/test/java/com/compilerprogramming/ezlang/compiler/TestLiveness.java +++ b/optvm/src/test/java/com/compilerprogramming/ezlang/compiler/TestLiveness.java @@ -1,6 +1,7 @@ package com.compilerprogramming.ezlang.compiler; import com.compilerprogramming.ezlang.types.Symbol; +import com.compilerprogramming.ezlang.types.Type; import com.compilerprogramming.ezlang.types.TypeDictionary; import org.junit.Assert; import org.junit.Test; @@ -34,8 +35,7 @@ func foo() { var typeDict = compileSrc(src); var funcSymbol = typeDict.lookup("foo"); CompiledFunction func = (CompiledFunction) ((Symbol.FunctionTypeSymbol)funcSymbol).code(); - var liveness = new Liveness(); - liveness.computeLiveness(func); + func.livenessAnalysis(); String output = Compiler.dumpIR(typeDict, true); Assert.assertEquals(""" func print(n: Int) @@ -128,8 +128,7 @@ func foo(a: Int, b: Int) { var typeDict = compileSrc(src); var funcSymbol = typeDict.lookup("foo"); CompiledFunction func = (CompiledFunction) ((Symbol.FunctionTypeSymbol)funcSymbol).code(); - var liveness = new Liveness(); - liveness.computeLiveness(func); + func.livenessAnalysis(); String output = Compiler.dumpIR(typeDict, true); Assert.assertEquals(""" func foo(a: Int,b: Int) @@ -192,4 +191,103 @@ func foo(a: Int,b: Int) """, output); } + /* page 448 Engineering a Compiler */ + static CompiledFunction buildTest3() { + TypeDictionary typeDictionary = new TypeDictionary(); + Type.TypeFunction functionType = new Type.TypeFunction("foo"); + functionType.setReturnType(typeDictionary.INT); + CompiledFunction function = new CompiledFunction(functionType); + RegisterPool regPool = function.registerPool; + Register i = regPool.newReg("i", typeDictionary.INT); + Register s = regPool.newReg("s", typeDictionary.INT); + function.code(new Instruction.Move( + new Operand.ConstantOperand(1, typeDictionary.INT), + new Operand.RegisterOperand(i))); + BasicBlock b1 = function.createBlock(); + BasicBlock b2 = function.createBlock(); + BasicBlock b3 = function.createBlock(); + BasicBlock b4 = function.createBlock(); + function.jumpTo(b1); + function.startBlock(b1); + function.code(new Instruction.ConditionalBranch( + function.currentBlock, + new Operand.RegisterOperand(i), + b2, b3)); + function.startBlock(b2); + function.code(new Instruction.Move( + new Operand.ConstantOperand(0, typeDictionary.INT), + new Operand.RegisterOperand(s))); + function.jumpTo(b3); + function.startBlock(b3); + function.code(new Instruction.Binary( + "+", + new Operand.RegisterOperand(s), + new Operand.RegisterOperand(s), + new Operand.RegisterOperand(i))); + function.code(new Instruction.Binary( + "+", + new Operand.RegisterOperand(i), + new Operand.RegisterOperand(i), + new Operand.ConstantOperand(1, typeDictionary.INT))); + function.code(new Instruction.ConditionalBranch( + function.currentBlock, + new Operand.RegisterOperand(i), + b1, b4)); + function.startBlock(b4); + function.code(new Instruction.Ret(new Operand.RegisterOperand(s))); + function.startBlock(function.exit); + function.isSSA = false; + + System.out.println(function.toStr(new StringBuilder(), true)); + + return function; + } + + @Test + public void test3() { + CompiledFunction function = buildTest3(); + function.livenessAnalysis(); + String actual = function.toStr(new StringBuilder(), true).toString(); + Assert.assertEquals(""" +func foo()->Int +Reg #0 i +Reg #1 s +L0: + i = 1 + goto L2 + #UEVAR = {} + #VARKILL = {0} + #LIVEOUT = {0, 1} +L2: + if i goto L3 else goto L4 + #UEVAR = {0} + #VARKILL = {} + #LIVEOUT = {0, 1} +L3: + s = 0 + goto L4 + #UEVAR = {} + #VARKILL = {1} + #LIVEOUT = {0, 1} +L4: + s = s+i + i = i+1 + if i goto L2 else goto L5 + #UEVAR = {0, 1} + #VARKILL = {0, 1} + #LIVEOUT = {0, 1} +L5: + ret s + goto L1 + #UEVAR = {1} + #VARKILL = {} + #LIVEOUT = {} +L1: + #UEVAR = {} + #VARKILL = {} + #LIVEOUT = {} +""", actual); + } + + }