diff --git a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/BasicBlock.java b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/BasicBlock.java index 845b2f6..bb82949 100644 --- a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/BasicBlock.java +++ b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/BasicBlock.java @@ -98,10 +98,19 @@ public void add(Instruction instruction) { instructions.add(instruction); instruction.block = this; } + public void deleteInstruction(Instruction instruction) { + instructions.remove(instruction); + } public void addSuccessor(BasicBlock successor) { + assert successors.contains(successor) == false; successors.add(successor); + assert successor.predecessors.contains(this) == false; successor.predecessors.add(this); } + public void removeSuccessor(BasicBlock successor) { + successors.remove(successor); + successor.predecessors.remove(this); + } /** * Initially the phi has the form @@ -132,15 +141,25 @@ public List phis() { } return list; } - public int whichPred(BasicBlock s) { + public int whichPred(BasicBlock pred) { int i = 0; - for (BasicBlock p: s.predecessors) { - if (p == this) + for (BasicBlock p: predecessors) { + if (p == pred) return i; i++; } throw new IllegalStateException(); } + public int whichSucc(BasicBlock succ) { + int i = 0; + for (BasicBlock s: successors) { + if (s == succ) + return i; + i++; + } + throw new IllegalStateException(); + } + public static StringBuilder toStr(StringBuilder sb, BasicBlock bb, BitSet visited, boolean dumpLiveness) { if (visited.get(bb.bid)) diff --git a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/DominatorTree.java b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/DominatorTree.java index ed97438..ea3dc2f 100644 --- a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/DominatorTree.java +++ b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/DominatorTree.java @@ -193,7 +193,9 @@ private void 
calculateDominanceFrontiers() { if (b.predecessors.size() >= 2) { for (BasicBlock p : b.predecessors) { BasicBlock runner = p; - while (runner != b.idom) { + // re runner != null: Dominance frontier calc fails in infinite loop + // scenario - need to check what the correct solution is + while (runner != b.idom && runner != null) { runner.dominationFrontier.add(b); runner = runner.idom; } diff --git a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/EnterSSA.java b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/EnterSSA.java index 54afbc3..9365c82 100644 --- a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/EnterSSA.java +++ b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/EnterSSA.java @@ -153,9 +153,9 @@ void search(BasicBlock block) { } // Update phis in successor blocks for (BasicBlock s: block.successors) { - int j = block.whichPred(s); + int j = s.whichPred(block); for (Instruction.Phi phi: s.phis()) { - Register oldReg = phi.input(j); + Register oldReg = phi.inputAsRegister(j); phi.replaceInput(j, stacks[oldReg.nonSSAId()].top()); } } diff --git a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/ExitSSA.java b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/ExitSSA.java index 684d11e..5bac8f8 100644 --- a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/ExitSSA.java +++ b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/ExitSSA.java @@ -52,6 +52,11 @@ private void insertCopies(BasicBlock block) { * replace all uses u with stacks[i] */ private void replaceUses(Instruction i) { + if (i instanceof Instruction.Phi) + // FIXME check this can never be valid + // tests 8/9 in TestInterpreter invoke on Phi but + // replacements are same as existing inputs + return; var oldUses = i.uses(); Register[] newUses = new Register[oldUses.size()]; for (int u = 0; u < oldUses.size(); u++) { @@ -65,13 +70,15 @@ private void replaceUses(Instruction i) { } static class CopyItem { - final 
Register src; + final Operand src; final Register dest; + final BasicBlock destBlock; boolean removed; - public CopyItem(Register src, Register dest) { + public CopyItem(Operand src, Register dest, BasicBlock destBlock) { this.src = src; this.dest = dest; + this.destBlock = destBlock; this.removed = false; } } @@ -83,14 +90,17 @@ private void scheduleCopies(BasicBlock block, List pushed) { Map map = new HashMap<>(); BitSet usedByAnother = new BitSet(function.registerPool.numRegisters()*2); for (BasicBlock s: block.successors) { - int j = block.whichPred(s); + int j = s.whichPred(block); for (Instruction.Phi phi: s.phis()) { Register dst = phi.value(); - Register src = phi.input(j); // jth operand of phi node - copySet.add(new CopyItem(src, dst)); - map.put(src.id, src); + Operand srcOperand = phi.input(j); // jth operand of phi node + if (srcOperand instanceof Operand.RegisterOperand srcRegisterOperand) { + Register src = srcRegisterOperand.reg; + map.put(src.id, src); + usedByAnother.set(src.id); + } + copySet.add(new CopyItem(srcOperand, dst, s)); map.put(dst.id, dst); - usedByAnother.set(src.id); } } @@ -111,8 +121,9 @@ private void scheduleCopies(BasicBlock block, List pushed) { while (!workList.isEmpty() || !copySet.isEmpty()) { while (!workList.isEmpty()) { final CopyItem copyItem = workList.remove(0); - final Register src = copyItem.src; + final Operand src = copyItem.src; final Register dest = copyItem.dest; + final BasicBlock destBlock = copyItem.destBlock; /* Engineering a Compiler: We can avoid the lost copy problem by checking the liveness of the target name for each copy that we try to insert. 
When we discover @@ -122,18 +133,23 @@ private void scheduleCopies(BasicBlock block, List pushed) { */ if (block.liveOut.get(dest.id)) { /* Insert a copy from dest to a new temp t at phi node defining dest */ - final Register t = addMoveToTempAfterPhi(block, dest); + final Register t = addMoveToTempAfterPhi(destBlock, dest); stacks[dest.id].push(t); pushed.add(dest.id); } /* Insert a copy operation from map[src] to dest at end of BB */ - addMoveAtBBEnd(block, map.get(src.id), dest); - map.put(src.id, dest); - /* If src is the name of a dest in copySet add item to worklist */ - /* see comment on phi cycles below. */ - CopyItem item = isCycle(copySet, src); - if (item != null) { - workList.add(item); + if (src instanceof Operand.RegisterOperand srcRegisterOperand) { + addMoveAtBBEnd(block, map.get(srcRegisterOperand.reg.id), dest); + map.put(srcRegisterOperand.reg.id, dest); + /* If src is the name of a dest in copySet add item to worklist */ + /* see comment on phi cycles below. */ + CopyItem item = isCycle(copySet, srcRegisterOperand.reg); + if (item != null) { + workList.add(item); + } + } + else if (src instanceof Operand.ConstantOperand srcConstantOperand) { + addMoveAtBBEnd(block, srcConstantOperand, dest); } } /* Engineering a Compiler: To solve the swap problem @@ -204,7 +220,10 @@ private void addMoveAtBBEnd(BasicBlock block, Register src, Register dest) { var inst = new Instruction.Move(new Operand.RegisterOperand(src), new Operand.RegisterOperand(dest)); insertAtEnd(block, inst); } - + private void addMoveAtBBEnd(BasicBlock block, Operand.ConstantOperand src, Register dest) { + var inst = new Instruction.Move(src, new Operand.RegisterOperand(dest)); + insertAtEnd(block, inst); + } /* Insert a copy dest to a new temp at phi node defining dest, return temp */ private Register addMoveToTempAfterPhi(BasicBlock block, Register dest) { var temp = function.registerPool.newTempReg(dest.name(), dest.type); diff --git 
a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/Instruction.java b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/Instruction.java index 9a98f37..20f70a2 100644 --- a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/Instruction.java +++ b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/Instruction.java @@ -27,19 +27,21 @@ public abstract class Instruction { static final int I_FIELD_SET = 16; public final int opcode; - public Operand.RegisterOperand def; - public Operand[] uses; + protected Operand.RegisterOperand def; + protected Operand[] uses; public BasicBlock block; public Instruction(int opcode, Operand... uses) { this.opcode = opcode; this.def = null; - this.uses = uses; + this.uses = new Operand[uses.length]; + System.arraycopy(uses, 0, this.uses, 0, uses.length); } public Instruction(int opcode, Operand.RegisterOperand def, Operand... uses) { this.opcode = opcode; this.def = def; - this.uses = uses; + this.uses = new Operand[uses.length]; + System.arraycopy(uses, 0, this.uses, 0, uses.length); } public boolean isTerminal() { return false; } @@ -87,7 +89,14 @@ public boolean replaceUse(Register source, Register target) { } return replaced; } - + public void replaceWithConstant(Register register, Operand.ConstantOperand constantOperand) { + for (int i = 0; i < uses.length; i++) { + Operand operand = uses[i]; + if (operand != null && operand instanceof Operand.RegisterOperand registerOperand && registerOperand.reg.id == register.id) { + uses[i] = constantOperand; + } + } + } public static class NoOp extends Instruction { public NoOp() { super(I_NOOP); @@ -340,21 +349,35 @@ public StringBuilder toStr(StringBuilder sb) { /** * Phi does not generate uses or defs directly, instead * they are treated as a special case. - * To avoid bugs we do not use the def or uses. 
+ * To avoid bugs we disable the standard interfaces */ public static class Phi extends Instruction { - public Register value; - public final Register[] inputs; + private Register value; public Phi(Register value, List inputs) { super(I_PHI); this.value = value; - this.inputs = inputs.toArray(new Register[inputs.size()]); + this.uses = new Operand[inputs.size()]; + for (int i = 0; i < inputs.size(); i++) { + this.uses[i] = new Operand.RegisterOperand(inputs.get(i)); + } } public void replaceInput(int i, Register newReg) { - inputs[i] = newReg; + uses[i].replaceRegister(newReg); + } + /** + * This will fail if the input was replaced by a constant + */ + public Register inputAsRegister(int i) { + return ((Operand.RegisterOperand) uses[i]).reg; + } + public Operand input(int i) { + return uses[i]; + } + public boolean isRegisterInput(int i) { + return uses[i] instanceof Operand.RegisterOperand; } - public Register input(int i) { - return inputs[i]; + public Register[] inputRegisters() { + return super.uses().toArray(new Register[super.uses().size()]); } @Override public Register def() { @@ -368,18 +391,36 @@ public void replaceDef(Register newReg) { public boolean definesVar() { return false; } + @Override + public List uses() { + return Collections.emptyList(); + } + @Override + public void replaceUses(Register[] newUses) { + throw new UnsupportedOperationException(); + } + @Override + public boolean replaceUse(Register source, Register target) { + throw new UnsupportedOperationException(); + } public Register value() { return value; } public void replaceValue(Register newReg) { this.value = newReg; } + public void removeInput(int i) { + var newUses = new Operand[uses.length - 1]; + System.arraycopy(uses, 0, newUses, 0, i); + System.arraycopy(uses, i + 1, newUses, i, uses.length - i - 1); + this.uses = newUses; + } @Override public StringBuilder toStr(StringBuilder sb) { sb.append(value().name()).append(" = phi("); - for (int i = 0; i < inputs.length; i++) { + for (int i 
= 0; i < uses.length; i++) { if (i > 0) sb.append(", "); - sb.append(inputs[i].name()); + sb.append(uses[i].toString()); } sb.append(")"); return sb; diff --git a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/InterferenceGraph.java b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/InterferenceGraph.java index 6691d06..1714f6b 100644 --- a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/InterferenceGraph.java +++ b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/InterferenceGraph.java @@ -63,6 +63,10 @@ public void rename(Integer source, Integer target) { // Move all interferences var fromSet = edges.remove(source); var toSet = edges.get(target); + if (toSet == null) { + //throw new RuntimeException("Cannot find edge " + target + " from " + source); + return; // FIXME this is a workaround to handle the scenario where target is arg register but we need a better way + } toSet.addAll(fromSet); // If any node interfered with from // it should now interfere with to diff --git a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/Liveness.java b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/Liveness.java index 8d0b322..f5b2802 100644 --- a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/Liveness.java +++ b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/Liveness.java @@ -69,7 +69,9 @@ private void init(List blocks) { if (instruction instanceof Instruction.Phi phi) { for (int i = 0; i < block.predecessors.size(); i++) { BasicBlock pred = block.predecessors.get(i); - Register use = phi.input(i); + if (!phi.isRegisterInput(i)) + continue; + Register use = phi.inputAsRegister(i); // We can have a block referring it its own phis // if there is loop back and there are cycles // such as e.g. 
the swap copy problem diff --git a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/Optimizer.java b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/Optimizer.java index f886a4a..8eeebbc 100644 --- a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/Optimizer.java +++ b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/Optimizer.java @@ -4,7 +4,7 @@ public class Optimizer { public void optimize(CompiledFunction function) { new EnterSSA(function); - new SparseConditionalConstantPropagation().constantPropagation(function); + new SparseConditionalConstantPropagation().constantPropagation(function).apply(); new ExitSSA(function); new ChaitinGraphColoringRegisterAllocator().assignRegisters(function, 64); } diff --git a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/SSAEdges.java b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/SSAEdges.java index 262a3ba..ed8b8ea 100644 --- a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/SSAEdges.java +++ b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/SSAEdges.java @@ -58,7 +58,7 @@ private static void recordUses(CompiledFunction function, Map for (BasicBlock block : function.getBlocks()) { for (Instruction instruction : block.instructions) { if (instruction instanceof Instruction.Phi phi) { - recordUses(defUseChains, phi.inputs, block, instruction); + recordUses(defUseChains, phi.inputRegisters(), block, instruction); } else { List uses = instruction.uses(); diff --git a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/SparseConditionalConstantPropagation.java b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/SparseConditionalConstantPropagation.java index 08c558a..a8e77e3 100644 --- a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/SparseConditionalConstantPropagation.java +++ b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/SparseConditionalConstantPropagation.java @@ -14,8 +14,6 @@ *
  • Modern Compiler Implementation in C, Andrew Appel, section 19.3
  • *
  • Building an Optimizing Compiler, Bob Morgan, section 8.3
  • * - * - * */ public class SparseConditionalConstantPropagation { @@ -50,6 +48,9 @@ public class SparseConditionalConstantPropagation { Map ssaEdges; CompiledFunction function; + /** Used to track reachable blocks when the SCCP changes are applied */ + BitSet executableBlocks = new BitSet(); + public SparseConditionalConstantPropagation constantPropagation(CompiledFunction function) { init(function); while (!flowWorklist.isEmpty() || !instructionWorkList.isEmpty()) { @@ -119,6 +120,99 @@ private void init(CompiledFunction function) { visited = new BitSet(); } + public SparseConditionalConstantPropagation apply() { + /* + The constant propagation algorithm does not change the flow graph - it computes + information about the flow graph. The compiler now uses this information to improve + the graph in the following ways: + + * The instructions corresponding to temporaries that evaluate as constants are modified + to be load constant instructions. + + • An edge that has not become executable is eliminated, and the conditional branching + instruction representing that edge is modified to be a simpler instruction. + The phi-nodes at the head of the edge are modified to have one less operand. + + • Blocks that become unreachable are eliminated. + + Bob Morgan. 
Building an Optimizing Compiler + */ + markExecutableBlocks(); + removeBranchesThatAreNotExecutable(); + replaceVarsWithConstants(); + // Unreachable blocks are eliminated as there are no paths to them + return this; + } + + private void markExecutableBlocks() { + var blocks = function.getBlocks(); + executableBlocks = new BitSet(blocks.size()); + executableBlocks.set(function.entry.bid); + for (FlowEdge edge: flowEdges.keySet()) { + if (flowEdges.get(edge)) { + executableBlocks.set(edge.source.bid); + executableBlocks.set(edge.target.bid); + } + } + } + + /** + * Where we know which branch will be executed on a CBR, + * we replace such a branch with a jump to the known + * basic block + */ + private void removeBranchesThatAreNotExecutable() { + for (var flowEdge : flowEdges.keySet()) { + if (!flowEdges.get(flowEdge)) { + if (executableBlocks.get(flowEdge.source.bid) || + executableBlocks.get(flowEdge.target.bid)) + removeEdge(flowEdge.source, flowEdge.target); + } + } + } + + private void removeEdge(BasicBlock source, BasicBlock target) { + int j = target.whichPred(source); + // Replace cbr with jump + int idx = source.instructions.size()-1; + Instruction instruction = source.instructions.get(idx); + if (instruction instanceof Instruction.ConditionalBranch cbr) { + BasicBlock remainingExecutableBlock = (cbr.falseBlock == target) ? cbr.trueBlock : cbr.falseBlock; + source.instructions.set(idx, new Instruction.Jump(remainingExecutableBlock)); + } + // Remove phis in target corresponding to the input + for (var phi: target.phis()) { + phi.removeInput(j); + } + // update cfg + source.removeSuccessor(target); + } + + /** + * Where a definition is known to be a constant, + * replace all uses with the constant and then delete + * the defining instruction. 
+ */ + private void replaceVarsWithConstants() { + for (var register: valueLattice.getRegisters()) { + var latticeElement = valueLattice.get(register); + if (latticeElement.kind == V_CONSTANT) { + var constant = new Operand.ConstantOperand(latticeElement.value, register.type); + var defUseChain = this.ssaEdges.get(register); + // replace uses with constant + for (var usingInstruction: defUseChain.useList) { + if (executableBlocks.get(usingInstruction.block.bid)) + usingInstruction.replaceWithConstant(register, constant); + } + defUseChain.useList.clear(); + var block = defUseChain.instruction.block; + // delete defining instruction + block.deleteInstruction(defUseChain.instruction); + ssaEdges.remove(register); + } + } + } + static final byte V_UNDEFINED = 1; // TOP static final byte V_CONSTANT = 2; static final byte V_VARYING = 3; // BOTTOM @@ -349,7 +443,7 @@ private boolean visitPhi(BasicBlock block, Instruction.Phi phiInst) { BasicBlock pred = block.predecessors.get(j); // We ignore non-executable edges if (isEdgeExecutable(pred, block)) { - LatticeElement varValue = valueLattice.get(phiInst.input(j)); + LatticeElement varValue = valueLattice.get(phiInst.inputAsRegister(j)); newValue.meet(varValue); } } diff --git a/optvm/src/main/java/com/compilerprogramming/ezlang/interpreter/Interpreter.java b/optvm/src/main/java/com/compilerprogramming/ezlang/interpreter/Interpreter.java index a62375f..6aec1b8 100644 --- a/optvm/src/main/java/com/compilerprogramming/ezlang/interpreter/Interpreter.java +++ b/optvm/src/main/java/com/compilerprogramming/ezlang/interpreter/Interpreter.java @@ -101,8 +101,12 @@ else if (cbrInst.condition() instanceof Operand.ConstantOperand constantOperand) int baseReg = base+currentFunction.frameSize(); int reg = baseReg; for (Operand arg: callInst.args()) { - Operand.RegisterOperand param = (Operand.RegisterOperand) arg; - execStack.stack[base + reg] = execStack.stack[base + param.frameSlot()]; + if (arg instanceof Operand.RegisterOperand 
param) { + execStack.stack[base + reg] = execStack.stack[base + param.frameSlot()]; + } + else if (arg instanceof Operand.ConstantOperand constantOperand) { + execStack.stack[base + reg] = new Value.IntegerValue(constantOperand.value); + } reg += 1; } // Call function diff --git a/optvm/src/test/java/com/compilerprogramming/ezlang/compiler/TestSCCP.java b/optvm/src/test/java/com/compilerprogramming/ezlang/compiler/TestSCCP.java index c0d4cae..530fbd9 100644 --- a/optvm/src/test/java/com/compilerprogramming/ezlang/compiler/TestSCCP.java +++ b/optvm/src/test/java/com/compilerprogramming/ezlang/compiler/TestSCCP.java @@ -18,7 +18,11 @@ String compileSrc(String src) { new EnterSSA(functionBuilder); BasicBlock.toStr(sb, functionBuilder.entry, new BitSet(), false); //functionBuilder.toDot(sb, false); - sb.append(new SparseConditionalConstantPropagation().constantPropagation(functionBuilder).toString()); + var sccp = new SparseConditionalConstantPropagation().constantPropagation(functionBuilder); + sb.append(sccp.toString()); + sccp.apply(); + sb.append("After SCCP changes:\n"); + functionBuilder.toStr(sb, false); } } return sb.toString(); @@ -64,6 +68,15 @@ func foo()->Int { %t1_0=0 i_1=3 i_3=3 +After SCCP changes: +L0: + goto L3 +L3: + goto L4 +L4: + ret 3 + goto L1 +L1: """; Assert.assertEquals(expected, actual); } @@ -146,6 +159,26 @@ func foo()->Int { j_1=1 %t3_0=varying %t4_0=1 +After SCCP changes: +L0: + goto L2 +L2: + k_1 = phi(0, k_4) + %t3_0 = k_1<100 + if %t3_0 goto L3 else goto L4 +L3: + goto L5 +L5: + %t5_0 = k_1+1 + k_3 = %t5_0 + goto L7 +L7: + k_4 = phi(k_3) + goto L2 +L4: + ret 1 + goto L1 +L1: """; Assert.assertEquals(expected, actual); } diff --git a/optvm/src/test/java/com/compilerprogramming/ezlang/compiler/TestSSATransform.java b/optvm/src/test/java/com/compilerprogramming/ezlang/compiler/TestSSATransform.java index 8b46834..d2b4f04 100644 --- a/optvm/src/test/java/com/compilerprogramming/ezlang/compiler/TestSSATransform.java +++ 
b/optvm/src/test/java/com/compilerprogramming/ezlang/compiler/TestSSATransform.java @@ -840,4 +840,210 @@ func foo(x: Int, y: Int)->Int { } + @Test + public void testContinue() { + String src = """ +func foo(x: Int)->Int { + var sum = 0 + var i = 0 + while (i < x) { + if (i % 2 == 0) + continue + if (i / 3 == 1) + continue + sum = sum + 1 + i = i + 1 + } + return sum +} + """; + String result = compileSrc(src); + Assert.assertEquals(""" +func foo +Before SSA +========== +L0: + arg x + sum = 0 + i = 0 + goto L2 +L2: + %t3 = i