Skip to content

Commit

Permalink
[NB] correctly handle tuple assignments in when (OpenModelica#12545)
Browse files Browse the repository at this point in the history
  • Loading branch information
kabdelhak committed Jun 10, 2024
1 parent 5ffd12a commit f2815de
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 44 deletions.
128 changes: 85 additions & 43 deletions OMCompiler/Compiler/NBackEnd/Classes/NBEquation.mo
Original file line number Diff line number Diff line change
Expand Up @@ -2622,49 +2622,55 @@ public
input WhenEquationBody body;
output list<WhenEquationBody> bodies = {};
protected
UnorderedSet<ComponentRef> discr_map = UnorderedSet.new(ComponentRef.hash, ComponentRef.isEqual);
UnorderedSet<ComponentRef> state_map = UnorderedSet.new(ComponentRef.hash, ComponentRef.isEqual);
UnorderedMap<ComponentRef, CrefSet> discr_map = UnorderedMap.new<CrefSet>(ComponentRef.hash, ComponentRef.isEqual);
UnorderedSet<ComponentRef> state_set = UnorderedSet.new(ComponentRef.hash, ComponentRef.isEqual);
UnorderedSet<ComponentRef> discr_marks = UnorderedSet.new(ComponentRef.hash, ComponentRef.isEqual);
list<tuple<Expression, list<WhenStatement>>> flat_when;
list<tuple<Expression, list<WhenStatement>>> flat_new;
list<ComponentRef> discretes, states;
CrefSet set;
Expression condition, acc_condition = Expression.EMPTY(Type.INTEGER());
list<WhenStatement> stmts;
list<WhenStatement> stmts, assigns;
Option<WhenStatement> stmt;
Option<WhenEquationBody> new_body;
algorithm
// collect all discretes and states contained in the when equation body
// and also flatten the when equation to a list
flat_when := collectForSplit(SOME(body), discr_map, state_map);
discretes := UnorderedSet.toList(discr_map);
states := UnorderedSet.toList(state_map);
flat_when := collectForSplit(SOME(body), discr_map, state_set);
discretes := UnorderedMap.keyList(discr_map);
states := UnorderedSet.toList(state_set);

// create a when equation for each discrete state
for disc in discretes loop
flat_new := {};
for tpl in flat_when loop
(condition, stmts) := tpl;
// get first assignment - each branch should only have one
// assignment per discrete state
stmt := getFirstAssignment(disc, stmts);
// if there is a statement: create the when body and combine with previous
// conditions. if there is no statement in this branch, save the condition
// negated for the next branch
if Util.isSome(stmt) then
condition := combineConditions(acc_condition, condition, false);
acc_condition := Expression.EMPTY(Type.INTEGER());
flat_new := (condition, {Util.getOption(stmt)}) :: flat_new;
if not UnorderedSet.contains(disc, discr_marks) then
set := UnorderedMap.getSafe(disc, discr_map, sourceInfo());
for marked in UnorderedSet.toList(set) loop
UnorderedSet.add(marked, discr_marks);
end for;
flat_new := {};
for tpl in flat_when loop
(condition, stmts) := tpl;
assigns := getAssignments(set, stmts);
// if there is a statement: create the when body and combine with previous
// conditions. if there is no statement in this branch, save the condition
// negated for the next branch
if not listEmpty(assigns) then
condition := combineConditions(acc_condition, condition, false);
acc_condition := Expression.EMPTY(Type.INTEGER());
flat_new := (condition, assigns) :: flat_new;
else
acc_condition := combineConditions(acc_condition, condition, true);
end if;
end for;
// create body from flat list and add to new bodies
new_body := fromFlatList(flat_new);
if Util.isSome(new_body) then
bodies := Util.getOption(new_body) :: bodies;
else
acc_condition := combineConditions(acc_condition, condition, true);
Error.addMessage(Error.INTERNAL_ERROR,{getInstanceName()
+ " failed because when partition for: " + ComponentRef.toString(disc)
+ " could not be recovered."});
end if;
end for;
// create body from flat list and add to new bodies
new_body := fromFlatList(flat_new);
if Util.isSome(new_body) then
bodies := Util.getOption(new_body) :: bodies;
else
Error.addMessage(Error.INTERNAL_ERROR,{getInstanceName()
+ " failed because when partition for: " + ComponentRef.toString(disc)
+ " could not be recovered."});
end if;
end for;

Expand Down Expand Up @@ -2758,12 +2764,13 @@ public
end getAllAssigned;

protected
type CrefSet = UnorderedSet<ComponentRef>;
function collectForSplit
"collects all discrete states and regular states for splitting up
of a when equation. also flattens it to a list"
input Option<WhenEquationBody> body_opt;
input UnorderedSet<ComponentRef> discr_map;
input UnorderedSet<ComponentRef> state_map;
input UnorderedMap<ComponentRef, CrefSet> discr_map;
input UnorderedSet<ComponentRef> state_set;
output list<tuple<Expression, list<WhenStatement>>> flat_when;
protected
WhenEquationBody body;
Expand All @@ -2774,11 +2781,16 @@ public
() := match stmt
local
ComponentRef cref;
Expression tpl;

case WhenStatement.ASSIGN(lhs = Expression.CREF(cref = cref)) algorithm
UnorderedSet.add(cref, discr_map);
addCrefsMap(discr_map, {cref});
then ();
case WhenStatement.ASSIGN(lhs = tpl as Expression.TUPLE()) algorithm
addCrefsMap(discr_map, UnorderedSet.toList(Expression.extractCrefs(tpl)));
then ();
case WhenStatement.REINIT(stateVar = cref) algorithm
UnorderedSet.add(cref, state_map);
UnorderedSet.add(cref, state_set);
then ();
case WhenStatement.ASSIGN() algorithm
Error.addMessage(Error.INTERNAL_ERROR,{getInstanceName()
Expand All @@ -2787,30 +2799,60 @@ public
else ();
end match;
end for;
flat_when := (body.condition, body.when_stmts) :: collectForSplit(body.else_when, discr_map, state_map);
flat_when := (body.condition, body.when_stmts) :: collectForSplit(body.else_when, discr_map, state_set);
else
flat_when := {};
end if;
end collectForSplit;

function getFirstAssignment
"returns the first assignment in the list that is solved for cref"
input ComponentRef cref;
function addCrefsMap
input UnorderedMap<ComponentRef, CrefSet> discr_map;
input list<ComponentRef> crefs;
protected
CrefSet set_new, set = UnorderedSet.new(ComponentRef.hash, ComponentRef.isEqual);
algorithm
for c in crefs loop
if UnorderedMap.contains(c, discr_map) then
set_new := UnorderedMap.getSafe(c, discr_map, sourceInfo());
if not referenceEq(set, set_new) then
set := UnorderedSet.union(set, set_new);
end if;
else
UnorderedSet.add(c, set);
end if;
end for;

for c in crefs loop
UnorderedMap.add(c, set, discr_map);
end for;
end addCrefsMap;

function getAssignments
"returns all assignments for the crefs in crefSet and merges if necessary"
input UnorderedSet<ComponentRef> crefSet;
input list<WhenStatement> stmts;
output Option<WhenStatement> assign = NONE();
output list<WhenStatement> assigns = {};
algorithm
for stmt in stmts loop
() := match stmt
local
ComponentRef lhs;
case WhenStatement.ASSIGN(lhs = Expression.CREF(cref = lhs))
guard(ComponentRef.isEqual(cref, lhs)) algorithm
assign := SOME(stmt); break;
ComponentRef cref;
Expression tpl;

case WhenStatement.ASSIGN(lhs = Expression.CREF(cref = cref))
guard(UnorderedSet.contains(cref, crefSet)) algorithm
assigns := stmt :: assigns;
then ();

case WhenStatement.ASSIGN(lhs = tpl as Expression.TUPLE())
guard(List.any(list(UnorderedSet.contains(c, crefSet) for c in UnorderedSet.toList(Expression.extractCrefs(tpl))), Util.id)) algorithm
assigns := stmt :: assigns;
then ();

else ();
end match;
end for;
end getFirstAssignment;
end getAssignments;

function getFirstReinit
"returns the first reinit in the list that reinitializes cref"
Expand Down
6 changes: 5 additions & 1 deletion OMCompiler/Compiler/NBackEnd/Util/NBSlice.mo
Original file line number Diff line number Diff line change
Expand Up @@ -1201,7 +1201,11 @@ protected
// I. resolve the skips
d := UnorderedMap.getSafe(cref, dep, sourceInfo());
(start, _) := mapping.eqn_AtS[eqn_arr_idx];
(skip_idx, skip_ty) := resolveSkips(start, ty, d.skips);
if not UnorderedSet.contains(cref, rep) then
(skip_idx, skip_ty) := resolveSkips(start, ty, d.skips);
else
(skip_idx, skip_ty) := (start, ty);
end if;

// get equation and iterator sizes and frames
body_size := Type.sizeOf(skip_ty);
Expand Down

0 comments on commit f2815de

Please sign in to comment.