Skip to content

Commit

Permalink
Isolate disjoint-set forest and use union by rank (#10611)
Browse files Browse the repository at this point in the history
Hopefully fixes time complexity of NB partitioning
  • Loading branch information
phannebohm committed Apr 25, 2023
1 parent 65d028c commit 8c558c5
Showing 1 changed file with 87 additions and 42 deletions.
129 changes: 87 additions & 42 deletions OMCompiler/Compiler/NBackEnd/Modules/1_Main/NBPartitioning.mo
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,79 @@ protected
end toSystem;
end Cluster;

// Perhaps this deserves its own place in Util/*.mo
uniontype DisjointSetForest
"Custom implementation of disjoint-set data structure with constant number of elements."
record FOREST
Pointer<array<Integer>> parent;
Pointer<array<Integer>> rank;
end FOREST;

function new
"Creates n disjoit subsets of size 1."
input Integer n;
output DisjointSetForest dsf;
algorithm
dsf := FOREST(
parent = Pointer.create(listArray(list(i for i in 1:n))),
rank = Pointer.create(arrayCreate(n, 0))
);
end new;

function find
input DisjointSetForest dsf;
input output Integer index;
protected
array<Integer> parent = Pointer.access(dsf.parent);
algorithm
while index <> parent[index] loop
parent[index] := parent[parent[index]] "path halving";
index := parent[index];
end while;
Pointer.update(dsf.parent, parent);
end find;

function unite
input DisjointSetForest dsf;
input list<Integer> indices;
output Integer root;
protected
list<Integer> roots = list(find(dsf, i) for i in indices);
array<Integer> parent = Pointer.access(dsf.parent);
array<Integer> rank = Pointer.access(dsf.rank);
Integer maxRank;
Boolean tied = false;
algorithm
// find root with highest rank
root := listHead(roots);
maxRank := rank[root];
for r in listRest(roots) loop
if r <> root then
if rank[r] > maxRank then
root := r;
maxRank := rank[root];
tied := false;
elseif rank[r] == maxRank then
tied := true;
end if;
end if;
end for;

// update parents
for r in roots loop
parent[find(dsf, r)] := root;
end for;

// if necessary increment rank
if tied then
rank[root] := rank[root] + 1;
end if;

Pointer.update(dsf.parent, parent);
Pointer.update(dsf.rank, rank);
end unite;
end DisjointSetForest;

function partitioningNone extends Module.partitioningInterface;
protected
Boolean isInit = systemType == System.SystemType.INI;
Expand All @@ -249,69 +322,41 @@ protected

function partitioningClocked extends Module.partitioningInterface;
protected
array<Integer> eqn_map = arrayCreate(equations.eqArr.lastUsedIndex[1], -1);
DisjointSetForest eqn_dsf = DisjointSetForest.new(equations.eqArr.lastUsedIndex[1]);
array<Integer> var_map = arrayCreate(variables.varArr.lastUsedIndex[1], -1);
Pointer<Equation> eqn;
UnorderedSet<ComponentRef> var_crefs;
list<ComponentRef> var_cref_list;
list<Integer> local_indices;
Integer part_idx, root_idx, idx;
list<Integer> var_indices;
Integer part_idx;
UnorderedMap<Integer, Cluster> cluster_map = UnorderedMap.new<Cluster>(Util.id, intEq);
ComponentRef name_cref;
Cluster cluster;
Pointer<Integer> index = Pointer.create(1);
array<Boolean> marked_vars;
list<Pointer<Variable>> single_vars;
algorithm
for eq_idx in UnorderedMap.valueList(equations.map) loop
if eq_idx > 0 then
eqn_map[eq_idx] := eq_idx;
eqn := EquationPointers.getEqnAt(equations, eq_idx);
var_crefs := UnorderedSet.new(ComponentRef.hash, ComponentRef.isEqual);

// collect all crefs in equation
_ := Equation.map(Pointer.access(eqn), function collectPartitioningCrefs(var_crefs = var_crefs), NONE(), Expression.mapReverse);
var_cref_list := UnorderedSet.toList(var_crefs);

// find minimal partition index for current equation and all connected variables
local_indices := list(VariablePointers.getVarIndex(variables, cref) for cref in var_cref_list);
// find all indices of connected variables
var_indices := list(VariablePointers.getVarIndex(variables, cref) for cref in UnorderedSet.toList(var_crefs));
// filter indices of non existant variables (e.g. time)
local_indices := list(i for i guard(i > 0) in local_indices);
part_idx := intMin(i for i in eq_idx :: list(var_map[j] for j guard(var_map[j] > 0) in local_indices));
// find root index
while part_idx <> eqn_map[part_idx] loop
part_idx := eqn_map[part_idx];
end while;
eqn_map[eq_idx] := part_idx;

// update connected variable partition indices and further connected equation partition indices
for i in local_indices loop
if var_map[i] > 0 then
// find root index and connect the whole path to part_idx
root_idx := var_map[i];
while root_idx <> eqn_map[root_idx] loop
idx := root_idx;
root_idx := eqn_map[root_idx];
eqn_map[idx] := part_idx;
end while;
eqn_map[root_idx] := part_idx;
end if;
var_indices := list(i for i guard(i > 0) in var_indices);

// unite current equation and all variables that already belong to a partition
part_idx := DisjointSetForest.unite(eqn_dsf, eq_idx :: list(var_map[j] for j guard(var_map[j] > 0) in var_indices));

// update connected variable partition indices
for i in var_indices loop
var_map[i] := part_idx;
end for;
end if;
end for;

// canonicalize eqn_map
for eq_idx in UnorderedMap.valueList(equations.map) loop
if eq_idx > 0 then
root_idx := eq_idx;
while root_idx <> eqn_map[root_idx] loop
root_idx := eqn_map[root_idx];
end while;
eqn_map[eq_idx] := root_idx;
end if;
end for;

// find and report variables that could not be assigned to a partition
marked_vars := listArray(list(var_map[var_idx] < 0 for var_idx in UnorderedMap.valueList(variables.map)));
single_vars := VariablePointers.getMarkedVars(variables, marked_vars);
Expand All @@ -327,14 +372,14 @@ protected
for eq_idx in UnorderedMap.valueList(equations.map) loop
if eq_idx > 0 then
name_cref := Equation.getEqnName(EquationPointers.getEqnAt(equations, eq_idx));
UnorderedMap.addUpdate(eqn_map[eq_idx], function Cluster.addElement(cref = name_cref, ty = ClusterElementType.EQUATION), cluster_map);
UnorderedMap.addUpdate(DisjointSetForest.find(eqn_dsf, eq_idx), function Cluster.addElement(cref = name_cref, ty = ClusterElementType.EQUATION), cluster_map);
end if;
end for;

for var_idx in UnorderedMap.valueList(variables.map) loop
if var_idx > 0 then
name_cref := BVariable.getVarName(VariablePointers.getVarAt(variables, var_idx));
UnorderedMap.addUpdate(eqn_map[var_map[var_idx]], function Cluster.addElement(cref = name_cref, ty = ClusterElementType.VARIABLE), cluster_map);
UnorderedMap.addUpdate(DisjointSetForest.find(eqn_dsf, var_map[var_idx]), function Cluster.addElement(cref = name_cref, ty = ClusterElementType.VARIABLE), cluster_map);
end if;
end for;

Expand Down

0 comments on commit 8c558c5

Please sign in to comment.