Skip to content

Commit

Permalink
[NB] update record array handling (OpenModelica#12583)
Browse files Browse the repository at this point in the history
- use full map to correctly determine if something is an unknown when skipping to a subset of a record (even if its not part of the current adjacency update)
 - record array entry to adjacency updated
 - type checking updated
 - udapte debugging and error handling
  • Loading branch information
kabdelhak committed Jun 14, 2024
1 parent 7af7386 commit 025d351
Show file tree
Hide file tree
Showing 8 changed files with 75 additions and 48 deletions.
2 changes: 1 addition & 1 deletion OMCompiler/Compiler/NBackEnd/Classes/NBackendDAE.mo
Original file line number Diff line number Diff line change
Expand Up @@ -1117,7 +1117,7 @@ protected
list<ComponentRef> inputs, outputs;
EquationAttributes attr;
algorithm
size := sum(ComponentRef.size(out) for out in alg.outputs);
size := sum(ComponentRef.size(out, true) for out in alg.outputs);
if listEmpty(alg.outputs) then
attr := EquationAttributes.default(EquationKind.EMPTY, init);
Expand Down
2 changes: 0 additions & 2 deletions OMCompiler/Compiler/NBackEnd/Modules/1_Main/NBCausalize.mo
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,6 @@ protected
// #################################################
vn := UnorderedMap.subMap(system.unknowns.map, list(BVariable.getVarName(var) for var in unfixable));
en := UnorderedMap.subMap(system.equations.map, list(Equation.getEqnName(eqn) for eqn in initials));

adj_matching := Adjacency.Matrix.fromFull(full, vn, en, system.equations, NBAdjacency.MatrixStrictness.MATCHING);
matching := Matching.regular(NBMatching.EMPTY_MATCHING, adj_matching, true, true);

Expand All @@ -267,7 +266,6 @@ protected
eo := en;
vn := UnorderedMap.new<Integer>(ComponentRef.hash, ComponentRef.isEqual);
en := UnorderedMap.subMap(system.equations.map, list(Equation.getEqnName(eqn) for eqn in simulation));

(adj_matching, full) := Adjacency.Matrix.expand(adj_matching, full, vo, vn, eo, en, system.unknowns, system.equations);
matching := Matching.regular(matching, adj_matching, true, true);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,7 @@ protected
new_exp := match ty
case Type.TUPLE() algorithm
names := list(Call_Aux.createName(sub_ty, new_iter, index, init) for sub_ty in ty.types);
tpl_lst := list(if ComponentRef.size(cref) == 0 then Expression.fromCref(ComponentRef.WILD()) else Expression.fromCref(cref) for cref in names);
tpl_lst := list(if ComponentRef.size(cref, true) == 0 then Expression.fromCref(ComponentRef.WILD()) else Expression.fromCref(cref) for cref in names);
then Expression.TUPLE(ty, tpl_lst);
else algorithm
name := Call_Aux.createName(ty, new_iter, index, init);
Expand Down
28 changes: 19 additions & 9 deletions OMCompiler/Compiler/NBackEnd/Util/NBAdjacency.mo
Original file line number Diff line number Diff line change
Expand Up @@ -504,7 +504,7 @@ public
index := UnorderedMap.getSafe(name, eqns_map, sourceInfo());
filtered := Solvability.filter(UnorderedSet.toList(occ[index]), sol[index], vars_map, min, max);
// upgrade the row and all meta data
upgradeRow(EquationPointers.getEqnAt(eqns, index), index, filtered, dep[index], rep[index], vars_map, adj.m, adj.mapping, adj.modes);
upgradeRow(EquationPointers.getEqnAt(eqns, index), index, filtered, dep[index], rep[index], vars_map, vars_map, adj.m, adj.mapping, adj.modes);
end for;
adj.mT := transposeScalar(adj.m, arrayLength(adj.mapping.var_StA));
result := adj;
Expand Down Expand Up @@ -554,10 +554,10 @@ public
Integer size_vo, size_vn, size_eo, size_en; //only for debugging
algorithm
if Flags.isSet(Flags.BLT_MATRIX_DUMP) then
size_vo := sum(ComponentRef.size(var) for var in UnorderedMap.keyList(vo));
size_vn := sum(ComponentRef.size(var) for var in UnorderedMap.keyList(vn)) + size_vo;
size_eo := sum(ComponentRef.size(eqn) for eqn in UnorderedMap.keyList(eo));
size_en := sum(ComponentRef.size(eqn) for eqn in UnorderedMap.keyList(en)) + size_eo;
size_vo := sum(ComponentRef.size(var, true) for var in UnorderedMap.keyList(vo));
size_vn := sum(ComponentRef.size(var, true) for var in UnorderedMap.keyList(vn)) + size_vo;
size_eo := sum(ComponentRef.size(eqn, true) for eqn in UnorderedMap.keyList(eo));
size_en := sum(ComponentRef.size(eqn, true) for eqn in UnorderedMap.keyList(en)) + size_eo;
print(StringUtil.headline_1("Expanding from size [vars: " + intString(size_vo) + "| eqns: " + intString(size_eo) + "] to [vars: " + intString(size_vn) + "| eqns: " + intString(size_en) + "]") + "\n");
end if;

Expand Down Expand Up @@ -600,15 +600,15 @@ public
if not UnorderedMap.isEmpty(vn) then
for e in UnorderedMap.valueList(eo) loop
filtered := Solvability.filter(UnorderedSet.toList(full.occurences[e]), full.solvabilities[e], vn, 0, rank);
upgradeRow(EquationPointers.getEqnAt(eqns, e), e, filtered, full.dependencies[e], full.repetitions[e], vn, adj.m, adj.mapping, adj.modes);
upgradeRow(EquationPointers.getEqnAt(eqns, e), e, filtered, full.dependencies[e], full.repetitions[e], vn, vars.map, adj.m, adj.mapping, adj.modes);
end for;
end if;

// II. update new equations with all variables
if not UnorderedMap.isEmpty(en) then
for e in UnorderedMap.valueList(en) loop
filtered := Solvability.filter(UnorderedSet.toList(full.occurences[e]), full.solvabilities[e], v, 0, rank);
upgradeRow(EquationPointers.getEqnAt(eqns, e), e, filtered, full.dependencies[e], full.repetitions[e], v, adj.m, adj.mapping, adj.modes);
upgradeRow(EquationPointers.getEqnAt(eqns, e), e, filtered, full.dependencies[e], full.repetitions[e], v, vars.map, adj.m, adj.mapping, adj.modes);
end for;
end if;

Expand Down Expand Up @@ -1175,6 +1175,7 @@ public
input UnorderedMap<ComponentRef, Dependency> dep "dependency map";
input UnorderedSet<ComponentRef> rep "repetition set";
input UnorderedMap<ComponentRef, Integer> map "unordered map to check for relevance";
input UnorderedMap<ComponentRef, Integer> fullmap "unordered map to check for general relevance";
input array<list<Integer>> m;
input Mapping mapping;
input UnorderedMap<Mode.Key, Mode> modes;
Expand All @@ -1196,7 +1197,7 @@ public
end for;
else
// todo: if, when single equation (needs to be updated for if)
Slice.upgradeRow(Equation.getEqnName(eqn_ptr), eqn_arr_idx, iter, ty, dependencies, dep, rep, map, m, mapping, modes);
Slice.upgradeRow(Equation.getEqnName(eqn_ptr), eqn_arr_idx, iter, ty, dependencies, dep, rep, map, fullmap, m, mapping, modes);
end if;
else
Error.addMessage(Error.INTERNAL_ERROR,{getInstanceName() + " failed for:\n" + Equation.pointerToString(eqn_ptr)});
Expand Down Expand Up @@ -1843,6 +1844,8 @@ public
protected
Pointer<Variable> var;
Integer skips = 1;
list<Subscript> subs;
list<Option<Integer>> int_subs;
algorithm
if UnorderedMap.contains(cref, map) then
if not UnorderedMap.contains(cref, dep_map) then
Expand All @@ -1853,7 +1856,14 @@ public
else
var := BVariable.getVarPointer(cref);
if BVariable.isRecord(var) then
crefs := List.flatten(list(collectDependenciesCref(BVariable.getVarName(child), map, dep_map, sol_map) for child in BVariable.getRecordChildren(var)));
subs := ComponentRef.subscriptsAllFlat(cref);
int_subs := list(Subscript.toIntegerOpt(sub) for sub in subs);
// get all Record children
crefs := list(BVariable.getVarName(child) for child in BVariable.getRecordChildren(var));
// add original subscripts
crefs := list(ComponentRef.mergeSubscripts(subs, child) for child in crefs);
// collect dependencies
crefs := List.flatten(list(collectDependenciesCref(child, map, dep_map, sol_map) for child in crefs));
for cref in crefs loop
Dependency.skip(cref, skips, dep_map);
skips := skips + 1;
Expand Down
54 changes: 36 additions & 18 deletions OMCompiler/Compiler/NBackEnd/Util/NBSlice.mo
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ public
stripped := ComponentRef.stripSubscriptsAll(cref);
var_arr_idx := UnorderedMap.getSafe(stripped, map, sourceInfo());
(var_start, _) := mapping.var_AtS[var_arr_idx];
sizes := ComponentRef.sizes(stripped);
sizes := ComponentRef.sizes(stripped, false);
int_subs := ComponentRef.subscriptsToInteger(cref);
var_scal_idx := locationToIndex(List.zip(sizes, int_subs), var_start);
indices := var_scal_idx :: indices;
Expand Down Expand Up @@ -1034,7 +1034,7 @@ public
stripped := ComponentRef.stripSubscriptsAll(cref);
var_arr_idx := UnorderedMap.getSafe(stripped, map, sourceInfo());
(var_start, _) := mapping.var_AtS[var_arr_idx];
sizes := ComponentRef.sizes(stripped);
sizes := ComponentRef.sizes(stripped, false);
int_subs := ComponentRef.subscriptsToInteger(cref);
var_scal_idx := locationToIndex(List.zip(sizes, int_subs), var_start);
indices := var_scal_idx :: indices;
Expand All @@ -1052,12 +1052,13 @@ public
input UnorderedMap<ComponentRef, Dependency> dep "dependency map";
input UnorderedSet<ComponentRef> rep "repetition set";
input UnorderedMap<ComponentRef, Integer> map "unordered map to check for relevance";
input UnorderedMap<ComponentRef, Integer> fullmap "unordered map to check for general relevance";
input array<list<Integer>> m;
input Mapping mapping "array <-> scalar index mapping";
input UnorderedMap<Mode.Key, Mode> modes;
algorithm
for cref in dependencies loop
resolveDependency(cref, eqn_name, eqn_arr_idx, iter, ty, dep, rep, map, m, mapping, modes);
resolveDependency(cref, eqn_name, eqn_arr_idx, iter, ty, dep, rep, map, fullmap, m, mapping, modes);
end for;
end upgradeRow;

Expand All @@ -1071,7 +1072,7 @@ protected
input output Type ty;
input list<Integer> skips;
input ComponentRef cref;
input UnorderedMap<ComponentRef, Integer> map "unordered map to check for relevance";
input UnorderedMap<ComponentRef, Integer> fullmap "unordered map to check for general relevance";
algorithm
(index, ty) := match (ty, skips)
local
Expand All @@ -1082,6 +1083,7 @@ protected
Pointer<Variable> parent;
list<ComponentRef> crefs;
ComponentRef field;
list<Subscript> subs;

// 0 skips are full dependencies
case (Type.TUPLE(types = rest_ty), 0::rest) then (index, ty);
Expand All @@ -1095,34 +1097,43 @@ protected
end for;
sub_ty :: rest_ty := rest_ty;
// see if there is nested skips
then resolveSkips(index, sub_ty, rest, cref, map);
then resolveSkips(index, sub_ty, rest, cref, fullmap);

// skip to a record element
case (Type.COMPLEX(complexTy = ComplexType.RECORD()), skip::rest) algorithm
// get the children and skip to correct one
field := match BVariable.getParent(BVariable.getVarPointer(cref))
case SOME(parent) algorithm
subs := ComponentRef.subscriptsAllFlat(cref);
crefs := list(BVariable.getVarName(child) for child in BVariable.getRecordChildren(parent));
crefs := list(c for c guard(UnorderedMap.contains(c, map)) in crefs);
for i in 1:skip-1 loop
crefs := list(c for c guard(UnorderedMap.contains(c, fullmap)) in crefs);
if skip <= listLength(crefs) then
for i in 1:skip-1 loop
field :: crefs := crefs;
field := ComponentRef.mergeSubscripts(subs, field);
index := index + Type.sizeOf(ComponentRef.getSubscriptedType(field));
end for;
field :: crefs := crefs;
index := index + Type.sizeOf(ComponentRef.getSubscriptedType(field));
end for;
field :: crefs := crefs;
field := ComponentRef.mergeSubscripts(subs, field);
else
Error.addMessage(Error.INTERNAL_ERROR,{getInstanceName() + " failed because skip of " + intString(skip)
+ " is too large for record elements " + List.toString(crefs, ComponentRef.toString) + "."});
fail();
end if;
then field;
else algorithm
Error.addMessage(Error.INTERNAL_ERROR,{getInstanceName() + " failed because skip of " + intString(skip)
+ " for type " + Type.toString(ty) + " is requested, but the cref is not part of a record:" + ComponentRef.toString(cref) + "."});
then fail();
end match;
// see if there is nested skips
then resolveSkips(index, ComponentRef.getSubscriptedType(field), rest, cref, map);
then resolveSkips(index, ComponentRef.getSubscriptedType(field), rest, cref, fullmap);

// skip to an array element
case (Type.ARRAY(), rest) guard(listLength(rest) >= listLength(ty.dimensions)) algorithm
(rest, tail) := List.split(rest, listLength(ty.dimensions));
index := locationToIndex(List.zip(list(Dimension.size(dim) for dim in ty.dimensions), rest), index);
then resolveSkips(index, ty.elementType, tail, cref, map);
then resolveSkips(index, ty.elementType, tail, cref, fullmap);

// skip for tuple or array, but the skip is too large
case (_, skip::_) guard(Type.isTuple(ty) or Type.isArray(ty)) algorithm
Expand Down Expand Up @@ -1184,6 +1195,7 @@ protected
input UnorderedMap<ComponentRef, Dependency> dep "dependency map";
input UnorderedSet<ComponentRef> rep "repetition set";
input UnorderedMap<ComponentRef, Integer> map "unordered map to check for relevance";
input UnorderedMap<ComponentRef, Integer> fullmap "unordered map to check for general relevance";
input array<list<Integer>> m;
input Mapping mapping "array <-> scalar index mapping";
input UnorderedMap<Mode.Key, Mode> modes;
Expand Down Expand Up @@ -1212,7 +1224,7 @@ protected
d := UnorderedMap.getSafe(cref, dep, sourceInfo());
(start, _) := mapping.eqn_AtS[eqn_arr_idx];
if not UnorderedSet.contains(cref, rep) then
(skip_idx, skip_ty) := resolveSkips(start, ty, d.skips, cref, map);
(skip_idx, skip_ty) := resolveSkips(start, ty, d.skips, cref, fullmap);
else
(skip_idx, skip_ty) := (start, ty);
end if;
Expand Down Expand Up @@ -1247,8 +1259,8 @@ protected
end for;
else
Error.addMessage(Error.INTERNAL_ERROR,{getInstanceName() + " (single dependency) failed because list of scalar variables("
+ intString(scal_size) + ") " + List.toString(scalarized, ComponentRef.toString)
+ ", does not fit the equation size " + intString(size) + ".\n"});
+ intString(scal_size) + ") " + List.toString(scalarized, ComponentRef.toString)
+ ", does not fit the equation size " + intString(size) + ".\n"});
fail();
end if;

Expand Down Expand Up @@ -1376,8 +1388,14 @@ protected
input Mode mode;
algorithm
//print("adding eqn: " + intString(eqn_idx) + " var: " + intString(var_idx) + " with mode " + Mode.toString(mode) + "\n");
arrayUpdate(m, eqn_idx, var_idx :: m[eqn_idx]);
UnorderedMap.addUpdate((eqn_idx, var_idx), function Mode.mergeCreate(mode = mode), modes);
try
arrayUpdate(m, eqn_idx, var_idx :: m[eqn_idx]);
UnorderedMap.addUpdate((eqn_idx, var_idx), function Mode.mergeCreate(mode = mode), modes);
else
Error.addMessage(Error.INTERNAL_ERROR,{getInstanceName() + " failed because index " + intString(eqn_idx)
+ " could not be added. Matrix size: " + intString(arrayLength(m)) + "."});
fail();
end try;
end addMatrixEntry;

function resolveReductions
Expand Down Expand Up @@ -1502,7 +1520,7 @@ protected
stripped := if listEmpty(frames) then cref else ComponentRef.stripSubscriptsAll(cref);
var_arr_idx := UnorderedMap.getSafe(stripped, map, sourceInfo());
(var_start, _) := mapping.var_AtS[var_arr_idx];
sizes := ComponentRef.sizes(stripped);
sizes := ComponentRef.sizes(stripped, false);
subs := ComponentRef.subscriptsToExpression(cref, true);
scal_lst := listReverse(combineFrames2Indices(var_start, sizes, subs, frames, UnorderedMap.new<Expression>(ComponentRef.hash, ComponentRef.isEqual)));
end getCrefInFrameIndices;
Expand Down
13 changes: 8 additions & 5 deletions OMCompiler/Compiler/NFFrontEnd/NFComponentRef.mo
Original file line number Diff line number Diff line change
Expand Up @@ -727,7 +727,7 @@ public
list<Subscript> subs;

case CREF(subscripts = {}) guard(not backendCref(cref)) algorithm
sizes_ := sizes_local(cref);
sizes_ := sizes_local(cref, false);
subs := {};
for size in listReverse(sizes_) loop
if size <> 1 then
Expand Down Expand Up @@ -1619,26 +1619,29 @@ public

function size
input ComponentRef cref;
output Integer s = product(i for i in sizes(cref));
input Boolean withComplex;
output Integer s = product(i for i in sizes(cref, withComplex));
end size;

function sizes
input ComponentRef cref;
input Boolean withComplex;
input output list<Integer> s_lst = {};
algorithm
s_lst := match cref
local
list<Integer> local_lst = {};
case EMPTY() then listReverse(s_lst);
case CREF() algorithm
local_lst := sizes_local(cref);
local_lst := sizes_local(cref, withComplex);
s_lst := listAppend(local_lst, s_lst);
then sizes(cref.restCref, s_lst);
then sizes(cref.restCref, withComplex, s_lst);
end match;
end sizes;

function sizes_local
input ComponentRef cref;
input Boolean withComplex;
output list<Integer> s_lst = {};
protected
Option<Integer> complex_size;
Expand All @@ -1648,7 +1651,7 @@ public
case CREF() algorithm
complex_size := Type.complexSize(cref.ty);
s_lst := list(Dimension.size(dim) for dim in Type.arrayDims(cref.ty));
if Util.isSome(complex_size) then
if withComplex and Util.isSome(complex_size) then
s_lst := Util.getOption(complex_size) :: s_lst;
end if;
s_lst := if listEmpty(s_lst) then {1} else s_lst;
Expand Down
10 changes: 10 additions & 0 deletions OMCompiler/Compiler/NFFrontEnd/NFSubscript.mo
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,16 @@ public
end match;
end toInteger;

function toIntegerOpt
input Subscript subscript;
output Option<Integer> int;
algorithm
int := match subscript
case INDEX() then SOME(Expression.toInteger(subscript.index));
else NONE();
end match;
end toIntegerOpt;

function toIndexList
input Subscript subscript;
input Integer length;
Expand Down
12 changes: 0 additions & 12 deletions OMCompiler/Compiler/NFFrontEnd/NFVariable.mo
Original file line number Diff line number Diff line change
Expand Up @@ -210,21 +210,9 @@ public
"Expands a variable into itself and its children if its complex."
input Variable var;
output list<Variable> children;
protected
function expandChildType
"helper function to inherit the array type dimensions"
input output Variable child;
input list<Dimension> dimensions;
algorithm
child.ty := Type.liftArrayLeftList(child.ty, dimensions);
end expandChildType;
algorithm
// for non-complex variables the children are empty therefore it will be returned itself
var.children := List.flatten(list(expandChildren(v) for v in var.children));
if isComplexArray(var) then
// if the variable is an array, inherit the array dimensions
var.children := list(expandChildType(v, Type.arrayDims(var.ty)) for v in var.children);
end if;
// return all children and the variable itself
children := var :: var.children;
end expandChildren;
Expand Down

0 comments on commit 025d351

Please sign in to comment.