Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,26 +4,78 @@
import java.util.stream.IntStream;

/**
* Implement the original graph coloring algorithm described by Chaitin.
* Implements the original graph coloring algorithm described by Chaitin.
* Since we are targeting an abstract machine where there are no limits on
* number of registers except how we set them, our goal here is to get to
* the minimum number of registers required to execute the function.
* <p>
* We do want to implement spilling even though we do not need it for the
* abstract machine, but it is not yet implemented. We would spill to a
* stack attached to the abstract machine.
*
* TODO spilling
*/
public class ChaitinGraphColoringRegisterAllocator {

public ChaitinGraphColoringRegisterAllocator() {
}

public Map<Integer, Integer> assignRegisters(CompiledFunction function, int numRegisters) {
if (function.isSSA) throw new IllegalStateException("Register allocation should be done after exiting SSA");
var g = coalesce(function);
var registers = registersInIR(function);
var colors = IntStream.range(0, numRegisters).boxed().toList();
// TODO pre-assign regs to args
// Remove useless copy operations
InterferenceGraph g = coalesce(function);
// Get used registers
Set<Integer> registers = registersInIR(function);
// Create color set
List<Integer> colors = new ArrayList<>(IntStream.range(0, numRegisters).boxed().toList());
// Function args are pre-assigned colors
// and we remove them from the register set
Map<Integer, Integer> assignments = preAssignArgsToColors(function, registers, colors);
// TODO spilling
var assignments = colorGraph(g, registers, new HashSet<>(colors));
// execute graph coloring on remaining registers
assignments = colorGraph(g, registers, new HashSet<>(colors), assignments);
// update all instructions
// We simply set the slot on each register - rather than actually trying to replace them
updateInstructions(function, assignments);
// Compute and set the new framesize
function.setFrameSize(computeFrameSize(assignments));
return assignments;
}

/**
* Frame size = max number of registers needed to execute the function
*/
private int computeFrameSize(Map<Integer, Integer> assignments) {
return assignments.values().stream().mapToInt(k->k).max().orElse(0);
}

/**
* Due to the way function args are received by the abstract machine, we need
* to assign them register slots starting from 0. After assigning colors/slots
* we remove these from the set so that the graph coloring algo does
*/
private Map<Integer, Integer> preAssignArgsToColors(CompiledFunction function, Set<Integer> registers, List<Integer> colors) {
int count = 0;
Map<Integer, Integer> assignments = new HashMap<>();
for (Instruction instruction : function.entry.instructions) {
if (instruction instanceof Instruction.ArgInstruction argInstruction) {
Integer color = colors.get(count);
Register reg = argInstruction.arg().reg;
registers.remove(reg.nonSSAId()); // Remove register from set before changing slot
assignments.put(reg.nonSSAId(), color);
count++;
}
else break;
}
return assignments;
}

private void updateInstructions(CompiledFunction function, Map<Integer, Integer> assignments) {
var regPool = function.registerPool;
for (var entry : assignments.entrySet()) {
int reg = entry.getKey();
int slot = entry.getValue();
regPool.getReg(reg).updateSlot(slot);
}
}

/**
* Chaitin: coalesce_nodes - coalesce away copy operations
*/
Expand Down Expand Up @@ -85,9 +137,7 @@ private void rewriteInstructions(CompiledFunction function, Instruction deadInst
private Set<Integer> registersInIR(CompiledFunction function) {
Set<Integer> registers = new HashSet<>();
for (var block: function.getBlocks()) {
Iterator<Instruction> iter = block.instructions.iterator();
while (iter.hasNext()) {
Instruction instruction = iter.next();
for (Instruction instruction: block.instructions) {
if (instruction.definesVar())
registers.add(instruction.def().id);
for (Register use: instruction.uses())
Expand All @@ -112,7 +162,7 @@ private Integer findNodeWithNeighborCountLessThan(InterferenceGraph g, Set<Integ
private Set<Integer> getNeighborColors(InterferenceGraph g, Integer node, Map<Integer,Integer> assignedColors) {
Set<Integer> colors = new HashSet<>();
for (var neighbour: g.neighbors(node)) {
var c = assignedColors.get(neighbour);
Integer c = assignedColors.get(neighbour);
if (c != null) {
colors.add(c);
}
Expand All @@ -137,18 +187,18 @@ private static HashSet<Integer> subtract(Set<Integer> originalSet, Integer node)
/**
* Chaitin: color_graph
*/
private Map<Integer, Integer> colorGraph(InterferenceGraph g, Set<Integer> nodes, Set<Integer> colors) {
private Map<Integer, Integer> colorGraph(InterferenceGraph g, Set<Integer> nodes, Set<Integer> colors, Map<Integer, Integer> preAssignedColors) {
if (nodes.size() == 0)
return new HashMap<>();
var numColors = colors.size();
var node = findNodeWithNeighborCountLessThan(g, nodes, numColors);
return preAssignedColors;
int numColors = colors.size();
Integer node = findNodeWithNeighborCountLessThan(g, nodes, numColors);
if (node == null)
return null;
var coloring = colorGraph(g.dup().subtract(node), subtract(nodes, node), colors);
Map<Integer, Integer> coloring = colorGraph(g.dup().subtract(node), subtract(nodes, node), colors, preAssignedColors);
if (coloring == null)
return null;
var neighbourColors = getNeighborColors(g, node, coloring);
var color = chooseSomeColorNotAssignedToNeighbors(colors, neighbourColors);
Set<Integer> neighbourColors = getNeighborColors(g, node, coloring);
Integer color = chooseSomeColorNotAssignedToNeighbors(colors, neighbourColors);
coloring.put(node, color);
return coloring;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ public class CompiledFunction {
private Type.TypeFunction functionType;
public final RegisterPool registerPool;

private final int frameSlots;
private int frameSlots;

public boolean isSSA;
public boolean hasLiveness;
Expand Down Expand Up @@ -76,6 +76,9 @@ private void generateArgInstructions(Scope scope) {
public int frameSize() {
return frameSlots;
}
public void setFrameSize(int size) {
frameSlots = size;
}

private void exitBlockIfNeeded() {
if (currentBlock != null &&
Expand Down Expand Up @@ -134,6 +137,7 @@ private void compileStatement(AST.Stmt statement) {
case AST.VarStmt letStmt -> {
compileLet(letStmt);
}
case AST.VarDeclStmt varDeclStmt -> {}
case AST.IfElseStmt ifElseStmt -> {
compileIf(ifElseStmt);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,27 +8,32 @@
import com.compilerprogramming.ezlang.types.Type;
import com.compilerprogramming.ezlang.types.TypeDictionary;

import java.util.BitSet;

public class Compiler {

private void compile(TypeDictionary typeDictionary) {
private void compile(TypeDictionary typeDictionary, boolean opt) {
for (Symbol symbol: typeDictionary.getLocalSymbols()) {
if (symbol instanceof Symbol.FunctionTypeSymbol functionSymbol) {
Type.TypeFunction functionType = (Type.TypeFunction) functionSymbol.type;
functionType.code = new CompiledFunction(functionSymbol);
var function = new CompiledFunction(functionSymbol);
functionType.code = function;
if (opt) {
new Optimizer().optimize(function);
}
}
}
}
public TypeDictionary compileSrc(String src) {
return compileSrc(src, false);
}
public TypeDictionary compileSrc(String src, boolean opt) {
Parser parser = new Parser();
var program = parser.parse(new Lexer(src));
var typeDict = new TypeDictionary();
var sema = new SemaDefineTypes(typeDict);
sema.analyze(program);
var sema2 = new SemaAssignTypes(typeDict);
sema2.analyze(program);
compile(typeDict);
compile(typeDict, opt);
return typeDict;
}
public static String dumpIR(TypeDictionary typeDictionary) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package com.compilerprogramming.ezlang.compiler;

import com.compilerprogramming.ezlang.exceptions.CompilerException;

import java.util.*;

/**
Expand Down Expand Up @@ -178,7 +180,11 @@ static class BBSet {
static class VersionStack {
List<Register.SSARegister> stack = new ArrayList<>();
void push(Register.SSARegister r) { stack.add(r); }
Register.SSARegister top() { return stack.getLast(); }
Register.SSARegister top() {
if (stack.isEmpty())
throw new CompilerException("Variable may not be initialized");
return stack.getLast();
}
void pop() { stack.removeLast(); }
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ public ExitSSA(CompiledFunction function) {
initStack();
insertCopies(function.entry);
removePhis();
function.isSSA = false;
}

private void removePhis() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ public String toString() {
}

public static class RegisterOperand extends Operand {
public Register reg;
Register reg;
public RegisterOperand(Register reg) {
this.reg = reg;
if (reg == null)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package com.compilerprogramming.ezlang.compiler;

public class Optimizer {

public void optimize(CompiledFunction function) {
new EnterSSA(function);
new ExitSSA(function);
new ChaitinGraphColoringRegisterAllocator().assignRegisters(function, 64);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,13 @@ public class Register {
* The type of a register
*/
public final Type type;
private int slot;

public Register(int id, String name, Type type) {
this.id = id;
this.name = name;
this.type = type;
this.slot = id;
}
@Override
public boolean equals(Object o) {
Expand All @@ -44,7 +46,10 @@ public String name() {
return name;
}
public int nonSSAId() {
return id;
return slot;
}
public void updateSlot(int slot) {
this.slot = slot;
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import com.compilerprogramming.ezlang.types.Type;
import com.compilerprogramming.ezlang.types.TypeDictionary;
import org.junit.Assert;
import org.junit.Ignore;
import org.junit.Test;

import java.util.Arrays;
Expand Down Expand Up @@ -766,4 +767,51 @@ public void testSwapProblem() {
Assert.assertEquals(expected, function.toStr(new StringBuilder(), false).toString());
}

@Test
public void testLiveness() {
String src = """
func bar(x: Int)->Int {
var y = 0
var z = 0
while( x>1 ){
y = x/2;
if (y > 3) {
x = x-y;
}
z = x-4;
if (z > 0) {
x = x/2;
}
z = z-1;
}
return x;
}

func foo() {
return bar(10);
}
""";
String result = compileSrc(src);
System.out.println(result);
}

@Test
@Ignore
public void testInit() {
// see issue #16
String src = """
func foo(x: Int) {
var z: Int
while (x > 0) {
z = 5
if (x == 1)
z = z+1
x = x - 1
}
}
""";
String result = compileSrc(src);
System.out.println(result);
}

}
Loading
Loading