Skip to content

Commit

Permalink
[NB] update record inlining (#11957)
Browse files Browse the repository at this point in the history
- if the record equation is inside for loop, carry the iterator over
  • Loading branch information
kabdelhak committed Feb 6, 2024
1 parent 6abca0b commit 2ab36e8
Showing 1 changed file with 22 additions and 11 deletions.
33 changes: 22 additions & 11 deletions OMCompiler/Compiler/NBackEnd/Modules/2_Pre/NBInline.mo
Expand Up @@ -161,7 +161,7 @@ public
Pointer<Integer> index = EqData.getUniqueIndex(eqData);
Pointer<list<Pointer<Equation>>> new_eqns = Pointer.create({});
algorithm
eqData := EqData.map(eqData, function inlineRecordEquation(variables = variables, record_eqns = new_eqns, index = index, inlineSimple = true));
eqData := EqData.map(eqData, function inlineRecordEquation(iter = Iterator.EMPTY(), variables = variables, record_eqns = new_eqns, index = index, inlineSimple = true));
eqData := EqData.addUntypedList(eqData, Pointer.access(new_eqns), false);
eqData := EqData.compress(eqData);
end inlineRecords;
Expand All @@ -175,7 +175,7 @@ public
protected
Pointer<list<Pointer<Equation>>> record_eqns = Pointer.create({});
algorithm
inlineRecordEquation(Pointer.access(Slice.getT(slice)), variables, record_eqns, index, inlineSimple);
inlineRecordEquation(Pointer.access(Slice.getT(slice)), Iterator.EMPTY(), variables, record_eqns, index, inlineSimple);
// somehow split slice.indices
slices := list(Slice.SLICE(eqn, {}) for eqn in Pointer.access(record_eqns));
end inlineRecordSliceEquation;
Expand Down Expand Up @@ -226,7 +226,7 @@ protected
Pointer<Integer> index = EqData.getUniqueIndex(eqData);
Pointer<list<Pointer<Equation>>> new_eqns = Pointer.create({});
algorithm
eqData := EqData.map(eqData, function inlineRecordEquation(variables = variables, record_eqns = new_eqns, index = index, inlineSimple = false));
eqData := EqData.map(eqData, function inlineRecordEquation(iter = Iterator.EMPTY(), variables = variables, record_eqns = new_eqns, index = index, inlineSimple = false));
eqData := EqData.map(eqData, function inlineTupleEquation(index = index, iter = Iterator.EMPTY(), tuple_eqns = new_eqns));
eqData := EqData.addUntypedList(eqData, Pointer.access(new_eqns), false);
eqData := EqData.compress(eqData);
Expand All @@ -237,6 +237,7 @@ protected
and appends new equations to the mutable list.
EquationPointers.compress() should be used afterwards to remove the dummy equations."
input output Equation eqn;
input Iterator iter;
input VariablePointers variables;
input Pointer<list<Pointer<Equation>>> record_eqns;
input Pointer<Integer> index;
Expand All @@ -246,6 +247,7 @@ protected
local
Equation new_eqn;
Integer size;
String str;

// don't inline simple cref equalities
case Equation.RECORD_EQUATION(lhs = Expression.CREF(), rhs = Expression.CREF()) guard(not inlineSimple) then eqn;
Expand All @@ -254,8 +256,16 @@ protected
// try to inline other record equations. try catch to be sure to not discard
case Equation.RECORD_EQUATION(ty = Type.COMPLEX()) algorithm
try
if Flags.isSet(Flags.DUMPBACKENDINLINE) then print("[" + getInstanceName() + "] Inlining: " + Equation.toString(eqn) + "\n"); end if;
new_eqn := inlineRecordEquationWork(eqn.lhs, eqn.rhs, eqn.attr, eqn.source, eqn.recordSize, variables, record_eqns, index, inlineSimple);
if Flags.isSet(Flags.DUMPBACKENDINLINE) then
str := "[" + getInstanceName() + "] Inlining: ";
if Iterator.isEmpty(iter) then
str := str + Equation.toString(eqn);
else
str := str + "\n" + Equation.forEquationToString(iter, {eqn}, "", "[----] ", "[FOR-] " + "(" + intString(Equation.size(Pointer.create(eqn)) * Iterator.size(iter)) + ")" + EquationAttributes.toString(eqn.attr, " "));
end if;
print(str + "\n");
end if;
new_eqn := inlineRecordEquationWork(eqn.lhs, eqn.rhs, iter, eqn.attr, eqn.recordSize, variables, record_eqns, index, inlineSimple);
if Flags.isSet(Flags.DUMPBACKENDINLINE) then print("\n"); end if;
else
// inlining failed, keep old equation
Expand All @@ -267,7 +277,7 @@ protected
case Equation.ARRAY_EQUATION(recordSize = SOME(size)) algorithm
try
if Flags.isSet(Flags.DUMPBACKENDINLINE) then print("[" + getInstanceName() + "] Inlining: " + Equation.toString(eqn) + "\n"); end if;
new_eqn := inlineRecordEquationWork(eqn.lhs, eqn.rhs, eqn.attr, eqn.source, size, variables, record_eqns, index, inlineSimple);
new_eqn := inlineRecordEquationWork(eqn.lhs, eqn.rhs, iter, eqn.attr, size, variables, record_eqns, index, inlineSimple);
else
// inlining failed, keep old equation
new_eqn := eqn;
Expand All @@ -276,8 +286,9 @@ protected

// iterate over body equations of for-loop
case Equation.FOR_EQUATION() algorithm
eqn.body := list(inlineRecordEquation(body_eqn, variables, record_eqns, index, inlineSimple) for body_eqn in eqn.body);
then eqn;
new_eqn := inlineRecordEquation(List.first(eqn.body), eqn.iter, variables, record_eqns, index, inlineSimple);
new_eqn := if Equation.isDummy(new_eqn) then new_eqn else eqn;
then new_eqn;

else eqn;
end match;
Expand All @@ -286,8 +297,8 @@ protected
function inlineRecordEquationWork
input Expression lhs;
input Expression rhs;
input Iterator iter;
input EquationAttributes attr;
input DAE.ElementSource src;
input Integer recordSize;
input VariablePointers variables;
input Pointer<list<Pointer<Equation>>> record_eqns;
Expand All @@ -314,12 +325,12 @@ protected
new_rhs := Expression.map(new_rhs, function BackendDAE.lowerComponentReferenceExp(variables = variables));

// create new equation
tmp_eqn := Equation.fromLHSandRHS(new_lhs, new_rhs, index, NBEquation.SIMULATION_STR, attr, src);
tmp_eqn := Equation.makeAssignment(new_lhs, new_rhs, index, NBEquation.SIMULATION_STR, iter, attr);

// if the equation still has a record type, inline it further
if Equation.isRecordEquation(tmp_eqn) then
tmp_eqns_ptr := Pointer.create(tmp_eqns);
_ := inlineRecordEquation(Pointer.access(tmp_eqn), variables, tmp_eqns_ptr, index, inlineSimple);
_ := inlineRecordEquation(Pointer.access(tmp_eqn), iter, variables, tmp_eqns_ptr, index, inlineSimple);
tmp_eqns := Pointer.access(tmp_eqns_ptr);
else
tmp_eqns := tmp_eqn :: tmp_eqns;
Expand Down

0 comments on commit 2ab36e8

Please sign in to comment.