diff --git a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/CompiledFunction.java b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/CompiledFunction.java index 7d53293..caf691e 100644 --- a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/CompiledFunction.java +++ b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/CompiledFunction.java @@ -23,6 +23,9 @@ public class CompiledFunction { private final int frameSlots; + public boolean isSSA; + public boolean hasLiveness; + /** * We essentially do a form of abstract interpretation as we generate * the bytecode instructions. For this purpose we use a virtual operand stack. diff --git a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/ExitSSA.java b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/ExitSSA.java new file mode 100644 index 0000000..f4d96c6 --- /dev/null +++ b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/ExitSSA.java @@ -0,0 +1,211 @@ +package com.compilerprogramming.ezlang.compiler; + +import java.util.*; + +/** + * Converts from SSA form to non-SSA form. + * Implementation is based on description in + * 'Practical Improvements to the Construction and Destruction + * of Static Single Assignment Form' by Preston Briggs. + */ +public class ExitSSA { + + CompiledFunction function; + NameStack[] stacks; + DominatorTree tree; + + public ExitSSA(CompiledFunction function) { + this.function = function; + if (!function.isSSA) throw new IllegalStateException(); + if (!function.hasLiveness) { + new Liveness().computeLiveness(function); + } + tree = new DominatorTree(function.entry); + initStack(); + insertCopies(function.entry); + removePhis(); + } + + private void removePhis() { + for (BasicBlock block : tree.blocks) { + block.instructions.removeIf(instruction -> instruction instanceof Instruction.Phi); + } + } + + /* Algorithm for iterating through blocks to perform phi replacement */ + private void insertCopies(BasicBlock block) { + List pushed = new ArrayList<>(); + for (Instruction i: block.instructions) { + // replace all uses u with stacks[i] + if (i.usesVars()) { + replaceUses(i); + } + } + scheduleCopies(block, pushed); + for (BasicBlock c: block.dominatedChildren) { + insertCopies(c); + } + for (Register name: pushed) { + stacks[name.id].pop(); + } + } + + /** + * replace all uses u with stacks[i] + */ + private void replaceUses(Instruction i) { + var oldUses = i.uses(); + Register[] newUses = new Register[oldUses.size()]; + for (int u = 0; u < oldUses.size(); u++) { + Register use = oldUses.get(u); + if (!stacks[use.id].isEmpty()) + newUses[u] = stacks[use.id].top(); + else + newUses[u] = use; + } + i.replaceUses(newUses); + } + + static class CopyItem { + Register src; + Register dest; + boolean removed; + + public CopyItem(Register src, Register dest) { + this.src = src; + this.dest = dest; + this.removed = false; + } + } + + private void scheduleCopies(BasicBlock block, List pushed) { + /* Pass 1 - Initialize data structures */ + /* In this pass we count the number of times a name is used by other phi-nodes */ + List copySet = new ArrayList<>(); + Map map = new HashMap<>(); + BitSet usedByAnother = new BitSet(function.registerPool.numRegisters()*2); + for (BasicBlock s: block.successors) { + int j = SSATransform.whichPred(s, block); + for (Instruction.Phi phi: s.phis()) { + Register dst = phi.def(); + Register src = phi.inputs.get(j).reg; // jth operand of phi node + copySet.add(new CopyItem(src, dst)); + map.put(src.id, src); + map.put(dst.id, dst); + usedByAnother.set(src.id); + } + } + + /* Pass 2: setup up the worklist of initial copies */ + /* In this pass we build a worklist of names that are not used in other phi nodes */ + List workList = new ArrayList<>(); + for (CopyItem copyItem: copySet) { + if (!usedByAnother.get(copyItem.dest.id)) { + copyItem.removed = true; + workList.add(copyItem); + } + } + copySet.removeIf(copyItem -> copyItem.removed); + + /* Pass 3: iterate over the worklist, inserting copies */ + /* Copy operations whose destinations are not used by other copy operations can be scheduled immediately */ + /* Each time we insert a copy operation we add the source of that op to the worklist */ + while (!workList.isEmpty() || !copySet.isEmpty()) { + while (!workList.isEmpty()) { + CopyItem copyItem = workList.remove(0); + Register src = copyItem.src; + Register dest = copyItem.dest; + if (block.liveOut.get(dest.id)) { + /* Insert a copy from dest to a new temp t at phi node defining dest */ + Register t = insertCopy(block, dest); + stacks[dest.id].push(t); + pushed.add(t); + } + /* Insert a copy operation from map[src] to dst at end of BB */ + appendCopy(block, map.get(src.id), dest); + map.put(src.id, dest); + /* If src is the name of a dest in copySet add item to worklist */ + CopyItem item = isDest(copySet, src); + if (item != null) { + workList.add(item); + } + } + if (!copySet.isEmpty()) { + CopyItem copyItem = copySet.remove(0); + /* Insert a copy from dst to new temp at the end of Block */ + Register t = appendCopy(block, copyItem.dest); + map.put(copyItem.dest.id, t); + workList.add(copyItem); + } + } + } + + private void insertAtEnd(BasicBlock bb, Instruction i) { + assert bb.instructions.size() > 0; + int pos = bb.instructions.size()-1; + bb.instructions.add(pos, i); + } + + private void insertAfterPhi(BasicBlock bb, Register phiDef, Instruction newInst) { + assert bb.instructions.size() > 0; + int pos = 0; + while (pos < bb.instructions.size()) { + Instruction i = bb.instructions.get(pos); + if (i instanceof Instruction.Phi phi && + phi.def().id == phiDef.id) { + pos += 1; + break; + } + } + if (pos == bb.instructions.size()) { + throw new IllegalStateException(); + } + bb.instructions.add(pos, newInst); + } + + + /* Insert a copy from dest to new temp at end of BB, and return temp */ + private Register appendCopy(BasicBlock block, Register dest) { + var temp = function.registerPool.newReg(dest.name(), dest.type); + var inst = new Instruction.Move(new Operand.RegisterOperand(dest), new Operand.RegisterOperand(temp)); + insertAtEnd(block, inst); + return temp; + } + + /* If src is the name of a dest in copySet return the item */ + private CopyItem isDest(List copySet, Register src) { + for (CopyItem copyItem: copySet) { + if (copyItem.dest.id == src.id) + return copyItem; + } + return null; + } + + /* Insert a copy from src to dst at end of BB */ + private void appendCopy(BasicBlock block, Register src, Register dest) { + var inst = new Instruction.Move(new Operand.RegisterOperand(src), new Operand.RegisterOperand(dest)); + insertAtEnd(block, inst); + } + + /* Insert a copy dest to a new temp at phi node defining dest, return temp */ + private Register insertCopy(BasicBlock block, Register dst) { + var temp = function.registerPool.newReg(dst.name(), dst.type); + var inst = new Instruction.Move(new Operand.RegisterOperand(dst), new Operand.RegisterOperand(temp)); + insertAfterPhi(block, dst, inst); + return temp; + } + + private void initStack() { + stacks = new NameStack[function.registerPool.numRegisters()]; + for (int i = 0; i < stacks.length; i++) + stacks[i] = new NameStack(); + } + + static class NameStack { + List stack = new ArrayList<>(); + void push(Register r) { stack.add(r); } + Register top() { return stack.getLast(); } + void pop() { stack.removeLast(); } + boolean isEmpty() { return stack.isEmpty(); } + } +} diff --git a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/Liveness.java b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/Liveness.java index 72e70db..b816231 100644 --- a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/Liveness.java +++ b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/Liveness.java @@ -15,6 +15,7 @@ public void computeLiveness(CompiledFunction function) { RegisterPool regPool = function.registerPool; init(regPool, blocks); computeLiveness(blocks); + function.hasLiveness = true; } private void init(RegisterPool regPool, List blocks) { diff --git a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/SSATransform.java b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/SSATransform.java index d222ca8..30a6e8f 100644 --- a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/SSATransform.java +++ b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/SSATransform.java @@ -30,6 +30,8 @@ public SSATransform(CompiledFunction bytecodeFunction) { findGlobalVars(); insertPhis(); renameVars(); + bytecodeFunction.isSSA = true; + bytecodeFunction.hasLiveness = false; } private void computeDomTreeAndDominanceFrontiers() { @@ -161,7 +163,7 @@ void search(BasicBlock block) { } } - private int whichPred(BasicBlock s, BasicBlock block) { + public static int whichPred(BasicBlock s, BasicBlock block) { int i = 0; for (BasicBlock p: s.predecessors) { if (p == block) diff --git a/optvm/src/test/java/com/compilerprogramming/ezlang/compiler/TestSSATransform.java b/optvm/src/test/java/com/compilerprogramming/ezlang/compiler/TestSSATransform.java index b56b77f..f690a36 100644 --- a/optvm/src/test/java/com/compilerprogramming/ezlang/compiler/TestSSATransform.java +++ b/optvm/src/test/java/com/compilerprogramming/ezlang/compiler/TestSSATransform.java @@ -23,6 +23,10 @@ String compileSrc(String src) { sb.append("After SSA\n"); sb.append("=========\n"); BasicBlock.toStr(sb, functionBuilder.entry, new BitSet(), false); + new ExitSSA(functionBuilder); + sb.append("After exiting SSA\n"); + sb.append("=================\n"); + BasicBlock.toStr(sb, functionBuilder.entry, new BitSet(), false); } } return sb.toString(); @@ -70,6 +74,20 @@ func foo(d: Int) { c_1 = %t7_0 goto L1 L1: +After exiting SSA +================= +L0: + arg d_0 + a_0 = 42 + b_0 = a_0 + %t5_0 = a_0+b_0 + c_0 = %t5_0 + %t6_0 = c_0+23 + a_1 = %t6_0 + %t7_0 = a_1+d_0 + c_1 = %t7_0 + goto L1 +L1: """, result); } @@ -130,6 +148,26 @@ func foo(d: Int)->Int { %t4_0 = a_0-1 a_1 = %t4_0 goto L4 +After exiting SSA +================= +L0: + arg d_0 + a_0 = 42 + if d_0 goto L2 else goto L3 +L2: + %t3_0 = a_0+1 + a_2 = %t3_0 + a_3 = a_2 + goto L4 +L4: + %ret_0 = a_3 + goto L1 +L1: +L3: + %t4_0 = a_0-1 + a_1 = %t4_0 + a_3 = a_1 + goto L4 """, result); } @@ -190,6 +228,29 @@ func factorial(num: Int)->Int { %ret_0 = result_1 goto L1 L1: +After exiting SSA +================= +L0: + arg num_0 + result_0 = 1 + result_1 = result_0 + num_1 = num_0 + goto L2 +L2: + %t3_0 = num_1>1 + if %t3_0 goto L3 else goto L4 +L3: + %t4_0 = result_1*num_1 + result_2 = %t4_0 + %t5_0 = num_1-1 + num_2 = %t5_0 + result_1 = result_2 + num_1 = num_2 + goto L2 +L4: + %ret_0 = result_1 + goto L1 +L1: """, result); } @@ -252,6 +313,15 @@ func example14_66(p: Int, q: Int, r: Int, s: Int, t: Int) { arg d_0 goto L1 L1: +After exiting SSA +================= +L0: + arg a_0 + arg b_0 + arg c_0 + arg d_0 + goto L1 +L1: func example14_66 Before SSA ========== @@ -404,6 +474,104 @@ func example14_66(p: Int, q: Int, r: Int, s: Int, t: Int) { %t11_0 = k_1+2 k_2 = %t11_0 goto L7 +After exiting SSA +================= +L0: + arg p_0 + arg q_0 + arg r_0 + arg s_0 + arg t_0 + i_0 = 1 + j_0 = 1 + k_0 = 1 + l_0 = 1 + l_1 = l_0 + k_1 = k_0 + j_1 = j_0 + i_1 = i_0 + goto L2 +L2: + l_10 = l_1 + k_5 = k_1 + j_4 = j_1 + i_3 = i_1 + if 1 goto L3 else goto L4 +L3: + if p_0 goto L5 else goto L6 +L5: + j_2 = i_1 + if q_0 goto L8 else goto L9 +L8: + l_3 = 2 + l_4 = l_3 + goto L10 +L10: + %t10_0 = k_1+1 + k_3 = %t10_0 + l_5 = l_4 + k_4 = k_3 + j_3 = j_2 + goto L7 +L7: + %t12_0 = i_1 + %t13_0 = j_3 + %t14_0 = k_4 + %t15_0 = l_5 + call print params %t12_0, %t13_0, %t14_0, %t15_0 + l_6 = l_5 + goto L11 +L11: + l_9 = l_6 + if 1 goto L12 else goto L13 +L12: + l_8 = l_6 + if r_0 goto L14 else goto L15 +L14: + %t16_0 = l_6+4 + l_7 = %t16_0 + l_8 = l_7 + goto L15 +L15: + %t17_0 = !s_0 + if %t17_0 goto L16 else goto L17 +L16: + l_9 = l_8 + goto L13 +L13: + %t18_0 = i_1+6 + i_2 = %t18_0 + %t19_0 = !t_0 + if %t19_0 goto L18 else goto L19 +L18: + l_10 = l_9 + k_5 = k_4 + j_4 = j_3 + i_3 = i_2 + goto L4 +L4: + goto L1 +L1: +L19: + l_1 = l_9 + k_1 = k_4 + j_1 = j_3 + i_1 = i_2 + goto L2 +L17: + l_6 = l_8 + goto L11 +L9: + l_2 = 3 + l_4 = l_2 + goto L10 +L6: + %t11_0 = k_1+2 + k_2 = %t11_0 + l_5 = l_1 + k_4 = k_2 + j_3 = j_1 + goto L7 """, result); } @@ -413,7 +581,7 @@ public void test5() { func bar(arg: Int)->Int { if (arg) return 42; - return 0; + return 0; } """; String result = compileSrc(src); @@ -440,6 +608,18 @@ func bar(arg: Int)->Int { %ret_1 = 42 goto L1 L1: +L3: + %ret_0 = 0 + goto L1 +After exiting SSA +================= +L0: + arg arg_0 + if arg_0 goto L2 else goto L3 +L2: + %ret_1 = 42 + goto L1 +L1: L3: %ret_0 = 0 goto L1