From 68509350d86a999a4854d01cbb0afbc303a737b8 Mon Sep 17 00:00:00 2001 From: dibyendumajumdar Date: Mon, 16 Dec 2024 23:20:21 +0000 Subject: [PATCH] Liveness Analysis --- .../ezlang/compiler/BBHelper.java | 27 +++ .../ezlang/compiler/BasicBlock.java | 28 ++- .../ezlang/compiler/CompiledFunction.java | 9 + .../ezlang/compiler/Compiler.java | 10 +- .../ezlang/compiler/DominatorTree.java | 21 +- .../ezlang/compiler/Liveness.java | 65 ++++++ .../ezlang/compiler/Register.java | 2 +- .../ezlang/compiler/RegisterPool.java | 8 +- .../ezlang/compiler/TestLiveness.java | 198 ++++++++++++++++++ .../ezlang/compiler/TestSSATransform.java | 4 +- 10 files changed, 342 insertions(+), 30 deletions(-) create mode 100644 optvm/src/main/java/com/compilerprogramming/ezlang/compiler/BBHelper.java create mode 100644 optvm/src/main/java/com/compilerprogramming/ezlang/compiler/Liveness.java create mode 100644 optvm/src/test/java/com/compilerprogramming/ezlang/compiler/TestLiveness.java diff --git a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/BBHelper.java b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/BBHelper.java new file mode 100644 index 0000000..c20521d --- /dev/null +++ b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/BBHelper.java @@ -0,0 +1,27 @@ +package com.compilerprogramming.ezlang.compiler; + +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.function.Consumer; + +public class BBHelper { + /** + * Utility to locate all the basic blocks, order does not matter. + */ + public static List findAllBlocks(BasicBlock root) { + List nodes = new ArrayList<>(); + postOrderWalk(root, (n) -> nodes.add(n), new HashSet<>()); + return nodes; + } + + static void postOrderWalk(BasicBlock n, Consumer consumer, HashSet visited) { + visited.add(n); + /* For each successor node */ + for (BasicBlock s : n.successors) { + if (!visited.contains(s)) + postOrderWalk(s, consumer, visited); + } + consumer.accept(n); + } +} diff --git a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/BasicBlock.java b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/BasicBlock.java index e327e90..9a42c6a 100644 --- a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/BasicBlock.java +++ b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/BasicBlock.java @@ -44,6 +44,25 @@ public class BasicBlock { */ public LoopNest loop; + // Liveness computation + /** + * VarKill contains all the variables that are defined + * in the block. + */ + BitSet varKill; + /** + * UEVar contains upward-exposed variables in the block, + * i.e. those variables that are used in the block prior to + * any redefinition in the block. + */ + BitSet UEVar; + /** + * LiveOut is the union of variables that are live at the + * head of some block that is a successor of this block. + */ + BitSet liveOut; + // ----------------------- + public BasicBlock(int bid, boolean loopHead) { this.bid = bid; this.loopHead = loopHead; @@ -94,7 +113,7 @@ public List phis() { } return list; } - public static StringBuilder toStr(StringBuilder sb, BasicBlock bb, BitSet visited) + public static StringBuilder toStr(StringBuilder sb, BasicBlock bb, BitSet visited, boolean dumpLiveness) { if (visited.get(bb.bid)) return sb; @@ -104,8 +123,13 @@ public static StringBuilder toStr(StringBuilder sb, BasicBlock bb, BitSet visite sb.append(" "); n.toStr(sb).append("\n"); } + if (dumpLiveness) { + if (bb.UEVar != null) sb.append(" #UEVAR = ").append(bb.UEVar.toString()).append("\n"); + if (bb.varKill != null) sb.append(" #VARKILL = ").append(bb.varKill.toString()).append("\n"); + if (bb.liveOut != null) sb.append(" #LIVEOUT = ").append(bb.liveOut.toString()).append("\n"); + } for (BasicBlock succ: bb.successors) { - toStr(sb, succ, visited); + toStr(sb, succ, visited, dumpLiveness); } return sb; } diff --git a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/CompiledFunction.java b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/CompiledFunction.java index 9656ed1..7d53293 100644 --- a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/CompiledFunction.java +++ b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/CompiledFunction.java @@ -7,6 +7,7 @@ import com.compilerprogramming.ezlang.types.Type; import java.util.ArrayList; +import java.util.BitSet; import java.util.List; public class CompiledFunction { @@ -536,4 +537,12 @@ else if (indexed instanceof Operand.LoadFieldOperand loadFieldOperand) { private boolean vstackEmpty() { return virtualStack.isEmpty(); } + + public void toStr(StringBuilder sb, boolean verbose) { + if (verbose) { + sb.append(this.functionType.describe()).append("\n"); + registerPool.toStr(sb); + } + BasicBlock.toStr(sb, entry, new BitSet(), verbose); + } } diff --git a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/Compiler.java b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/Compiler.java index aa173b7..105102b 100644 --- a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/Compiler.java +++ b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/Compiler.java @@ -31,15 +31,17 @@ public TypeDictionary compileSrc(String src) { compile(typeDict); return typeDict; } - public String dumpIR(TypeDictionary typeDictionary) { + public static String dumpIR(TypeDictionary typeDictionary) { + return dumpIR(typeDictionary, false); + } + public static String dumpIR(TypeDictionary typeDictionary, boolean verbose) { StringBuilder sb = new StringBuilder(); for (Symbol s: typeDictionary.bindings.values()) { if (s instanceof Symbol.FunctionTypeSymbol f) { - var functionBuilder = (CompiledFunction) f.code(); - BasicBlock.toStr(sb, functionBuilder.entry, new BitSet()); + var function = (CompiledFunction) f.code(); + function.toStr(sb, verbose); } } return sb.toString(); } - } diff --git a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/DominatorTree.java b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/DominatorTree.java index 21f68a9..ed97438 100644 --- a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/DominatorTree.java +++ b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/DominatorTree.java @@ -29,32 +29,13 @@ public class DominatorTree { */ public DominatorTree(BasicBlock entry) { this.entry = entry; - blocks = findAllBlocks(entry); + blocks = BBHelper.findAllBlocks(entry); calculateDominatorTree(); populateTree(); setDepth(); calculateDominanceFrontiers(); } - /** - * Utility to locate all the basic blocks, order does not matter. - */ - private static List findAllBlocks(BasicBlock root) { - List nodes = new ArrayList<>(); - postOrderWalk(root, (n) -> nodes.add(n), new HashSet<>()); - return nodes; - } - - static void postOrderWalk(BasicBlock n, Consumer consumer, HashSet visited) { - visited.add(n); - /* For each successor node */ - for (BasicBlock s : n.successors) { - if (!visited.contains(s)) - postOrderWalk(s, consumer, visited); - } - consumer.accept(n); - } - private void calculateDominatorTree() { resetDomInfo(); annotateBlocksWithRPO(); diff --git a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/Liveness.java b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/Liveness.java new file mode 100644 index 0000000..72e70db --- /dev/null +++ b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/Liveness.java @@ -0,0 +1,65 @@ +package com.compilerprogramming.ezlang.compiler; + +import java.util.BitSet; +import java.util.List; + +/** + * Compute LiveOut for each Basic Block + * Implementation is based on description in 'Engineering a Compiler' 2nd ed. + * pages 446-447. + */ +public class Liveness { + + public void computeLiveness(CompiledFunction function) { + List blocks = BBHelper.findAllBlocks(function.entry); + RegisterPool regPool = function.registerPool; + init(regPool, blocks); + computeLiveness(blocks); + } + + private void init(RegisterPool regPool, List blocks) { + int numRegisters = regPool.numRegisters(); + for (BasicBlock block : blocks) { + block.UEVar = new BitSet(numRegisters); + block.varKill = new BitSet(numRegisters); + block.liveOut = new BitSet(numRegisters); + for (Instruction instruction : block.instructions) { + if (instruction.usesVars()) { + for (Register use : instruction.uses()) { + if (!block.varKill.get(use.id)) + block.UEVar.set(use.id); + } + } + if (instruction.definesVar()) { + Register def = instruction.def(); + block.varKill.set(def.id); + } + } + } + } + + private void computeLiveness(List blocks) { + boolean changed = true; + while (changed) { + changed = false; + for (BasicBlock block : blocks) { + if (recomputeLiveOut(block)) + changed = true; + } + } + } + + private boolean recomputeLiveOut(BasicBlock block) { + BitSet oldLiveOut = (BitSet) block.liveOut.clone(); + for (BasicBlock m: block.successors) { + BitSet mLiveIn = (BitSet) m.liveOut.clone(); + // LiveOut(m) intersect not VarKill(m) + mLiveIn.andNot(m.varKill); + // UEVar(m) union (LiveOut(m) intersect not VarKill(m)) + mLiveIn.or(m.UEVar); + // LiveOut(block) =union (UEVar(m) union (LiveOut(m) intersect not VarKill(m))) + block.liveOut.or(mLiveIn); + } + return !oldLiveOut.equals(block.liveOut); + } +} diff --git a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/Register.java b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/Register.java index 5b9f3b7..259466e 100644 --- a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/Register.java +++ b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/Register.java @@ -10,7 +10,7 @@ public class Register { /** * Unique virtual ID */ - private final int id; + public final int id; /** * The base name - for local variables and function params this should be the name * in the source program. For temps this is a made up name. diff --git a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/RegisterPool.java b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/RegisterPool.java index 5013366..7a011e0 100644 --- a/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/RegisterPool.java +++ b/optvm/src/main/java/com/compilerprogramming/ezlang/compiler/RegisterPool.java @@ -5,7 +5,7 @@ import java.util.ArrayList; /** - * The RegisterPool is ued when compiling functions + * The RegisterPool is used when compiling functions * to assign IDs to registers. Initially the registers get * sequential IDs. For SSA registers we assign new IDs but also * retain the old ID and attach a version number - the old ID is @@ -48,4 +48,10 @@ public Register.SSARegister ssaReg(Register original, int version) { public int numRegisters() { return registers.size(); } + + public void toStr(StringBuilder sb) { + for (Register reg : registers) { + sb.append("Reg #").append(reg.id).append(" ").append(reg.name()).append("\n"); + } + } } diff --git a/optvm/src/test/java/com/compilerprogramming/ezlang/compiler/TestLiveness.java b/optvm/src/test/java/com/compilerprogramming/ezlang/compiler/TestLiveness.java new file mode 100644 index 0000000..da2890e --- /dev/null +++ b/optvm/src/test/java/com/compilerprogramming/ezlang/compiler/TestLiveness.java @@ -0,0 +1,198 @@ +package com.compilerprogramming.ezlang.compiler; + +import com.compilerprogramming.ezlang.types.Symbol; +import com.compilerprogramming.ezlang.types.TypeDictionary; +import org.junit.Assert; +import org.junit.Test; + +public class TestLiveness { + + TypeDictionary compileSrc(String src) { + var compiler = new Compiler(); + return compiler.compileSrc(src); + } + + @Test + public void test1() { + String src = """ + func print(n: Int) {} + func foo() { + var i = 1 + var s = 1; + while (1) { + if (i == 5) + s = 0; + s = s + 1 + i = i + 1 + if (i < 10) + continue; + break; + } + print(s); + } + """; + var typeDict = compileSrc(src); + var funcSymbol = typeDict.lookup("foo"); + CompiledFunction func = (CompiledFunction) ((Symbol.FunctionTypeSymbol)funcSymbol).code(); + var liveness = new Liveness(); + liveness.computeLiveness(func); + String output = Compiler.dumpIR(typeDict, true); + Assert.assertEquals(output, """ +func print(n: Int) +Reg #0 %ret +Reg #1 n +L0: + arg n + goto L1 +L1: +func foo() +Reg #0 %ret +Reg #1 i +Reg #2 s +Reg #3 %t3 +Reg #4 %t4 +Reg #5 %t5 +Reg #6 %t6 +Reg #7 %t7 +L0: + i = 1 + s = 1 + goto L2 + #UEVAR = {} + #VARKILL = {1, 2} + #LIVEOUT = {1, 2} +L2: + if 1 goto L3 else goto L4 + #UEVAR = {} + #VARKILL = {} + #LIVEOUT = {1, 2} +L3: + %t3 = i==5 + if %t3 goto L5 else goto L6 + #UEVAR = {1} + #VARKILL = {3} + #LIVEOUT = {1, 2} +L5: + s = 0 + goto L6 + #UEVAR = {} + #VARKILL = {2} + #LIVEOUT = {1, 2} +L6: + %t4 = s+1 + s = %t4 + %t5 = i+1 + i = %t5 + %t6 = i<10 + if %t6 goto L7 else goto L8 + #UEVAR = {1, 2} + #VARKILL = {1, 2, 4, 5, 6} + #LIVEOUT = {1, 2} +L7: + goto L2 + #UEVAR = {} + #VARKILL = {} + #LIVEOUT = {1, 2} +L8: + goto L4 + #UEVAR = {} + #VARKILL = {} + #LIVEOUT = {2} +L4: + %t7 = s + call print params %t7 + goto L1 + #UEVAR = {2} + #VARKILL = {7} + #LIVEOUT = {} +L1: + #UEVAR = {} + #VARKILL = {} + #LIVEOUT = {} +"""); + } + + @Test + public void test2() { + String src = """ + func foo(a: Int, b: Int) { + while (b < 10) { + if (b < a) { + a = a * 7 + b = a + 1 + } + else { + a = b - 1 + } + } + } + """; + var typeDict = compileSrc(src); + var funcSymbol = typeDict.lookup("foo"); + CompiledFunction func = (CompiledFunction) ((Symbol.FunctionTypeSymbol)funcSymbol).code(); + var liveness = new Liveness(); + liveness.computeLiveness(func); + String output = Compiler.dumpIR(typeDict, true); + Assert.assertEquals(output, """ +func foo(a: Int,b: Int) +Reg #0 %ret +Reg #1 a +Reg #2 b +Reg #3 %t3 +Reg #4 %t4 +Reg #5 %t5 +Reg #6 %t6 +Reg #7 %t7 +L0: + arg a + arg b + goto L2 + #UEVAR = {} + #VARKILL = {1, 2} + #LIVEOUT = {1, 2} +L2: + %t3 = b<10 + if %t3 goto L3 else goto L4 + #UEVAR = {2} + #VARKILL = {3} + #LIVEOUT = {1, 2} +L3: + %t4 = b