Skip to content

Commit

Permalink
Change sliced crefs to array constructors (#8339)
Browse files Browse the repository at this point in the history
- Change crefs like `a.x` or `a[:].x` into
  `{a[$i1].x for $i1 in 1:size(a, 1)}`
  • Loading branch information
perost committed Dec 17, 2021
1 parent 6f1e7d1 commit 5a6b507
Show file tree
Hide file tree
Showing 15 changed files with 620 additions and 8 deletions.
6 changes: 1 addition & 5 deletions OMCompiler/Compiler/NFFrontEnd/NFArrayConnections.mo
Expand Up @@ -719,11 +719,7 @@ protected

iterators := arrayCreate(Vector.size(vCount), InstNode.EMPTY_NODE());
for i in 1:arrayLength(iterators) loop
iterators[i] := InstNode.fromComponent(
"$i" + String(i),
Component.newIterator(Type.INTEGER(), AbsynUtil.dummyInfo),
InstNode.EMPTY_NODE()
);
iterators[i] := InstNode.newIndexedIterator(i);
end for;

iter_expl := list(Expression.fromCref(ComponentRef.makeIterator(i, Type.INTEGER())) for i in iterators);
Expand Down
110 changes: 110 additions & 0 deletions OMCompiler/Compiler/NFFrontEnd/NFComponentRef.mo
Expand Up @@ -1518,5 +1518,115 @@ public
end match;
end mapTypes;

function isSliced
input ComponentRef cref;
output Boolean sliced;
protected
function is_sliced_impl
input ComponentRef cref;
output Boolean sliced;
algorithm
sliced := match cref
case CREF(origin = Origin.CREF)
algorithm
sliced := Type.dimensionCount(cref.ty) > listLength(cref.subscripts) or
List.any(cref.subscripts, Subscript.isSliced);
then
sliced or is_sliced_impl(cref.restCref);

else false;
end match;
end is_sliced_impl;
algorithm
sliced := match cref
case CREF() then is_sliced_impl(cref.restCref);
else false;
end match;
end isSliced;

function iterate
input output ComponentRef cref;
output list<tuple<InstNode, Expression>> iterators;
protected
ComponentRef rest_cref;

function iterate_impl
"Replaces any slice subscripts (including implicit :) with an iterator,
and returns a list of all iterators with the corresponding ranges."
input output ComponentRef cref;
input output list<tuple<InstNode, Expression>> iterators = {};
input Integer index = 1;
protected
ComponentRef rest_cref;
Dimension dim;
list<Dimension> dims;
Integer dim_count, sub_count;
list<Subscript> subs, isubs;
Integer dim_index, iter_index;
InstNode iterator;
Expression range;
algorithm
() := match cref
case CREF(origin = Origin.CREF)
algorithm
dims := listReverse(Type.arrayDims(cref.ty));
dim_count := listLength(dims);
sub_count := listLength(cref.subscripts);
subs := List.consN(dim_count - sub_count, Subscript.WHOLE(), cref.subscripts);
isubs := {};
iter_index := index;
dim_index := dim_count;

for s in listReverse(subs) loop
dim :: dims := dims;

if not Subscript.isIndex(s) then
range := match s
// Slices like 1:3 are used directly.
case Subscript.SLICE() then s.slice;
// : are turned into 1:size(x, dim).
case Subscript.WHOLE()
then Expression.makeRange(Dimension.lowerBoundExp(dim),
NONE(),
Dimension.endExp(dim, cref, dim_index));
end match;

iterator := InstNode.newIndexedIterator(iter_index);
iterators := (iterator, range) :: iterators;
dim_index := dim_index - 1;
iter_index := iter_index + 1;

s := Subscript.INDEX(Expression.fromCref(makeIterator(iterator, Type.INTEGER())));
end if;

isubs := s :: isubs;
end for;

cref.subscripts := isubs;
(rest_cref, iterators) := iterate_impl(cref.restCref, iterators, iter_index);
cref.restCref := rest_cref;
then
();

else ();
end match;
end iterate_impl;
algorithm
iterators := match cref
case CREF()
algorithm
(rest_cref, iterators) := iterate_impl(cref.restCref);

if not listEmpty(iterators) then
cref.restCref := rest_cref;
iterators := listReverseInPlace(iterators);
end if;
then
iterators;

else {};
end match;
end iterate;

annotation(__OpenModelica_Interface="frontend");
end NFComponentRef;
2 changes: 1 addition & 1 deletion OMCompiler/Compiler/NFFrontEnd/NFDimension.mo
Expand Up @@ -328,7 +328,7 @@ public
then Expression.makeEnumLiteral(ty, listLength(ty.literals));
case EXP() then dim.exp;
case UNKNOWN()
then Expression.SIZE(Expression.CREF(Type.UNKNOWN(), ComponentRef.stripSubscripts(cref)),
then Expression.SIZE(Expression.fromCref(ComponentRef.stripSubscripts(cref)),
SOME(Expression.INTEGER(index)));
end match;
end endExp;
Expand Down
83 changes: 83 additions & 0 deletions OMCompiler/Compiler/NFFrontEnd/NFEquation.mo
Expand Up @@ -595,6 +595,89 @@ public
end match;
end mapExp;

function mapExpShallow
input output Equation eq;
input MapExpFn func;
algorithm
eq := match eq
local
Expression e1, e2, e3;

case EQUALITY()
algorithm
e1 := func(eq.lhs);
e2 := func(eq.rhs);
then
if referenceEq(e1, eq.lhs) and referenceEq(e2, eq.rhs)
then eq else EQUALITY(e1, e2, eq.ty, eq.source);

case ARRAY_EQUALITY()
algorithm
e1 := func(eq.lhs);
e2 := func(eq.rhs);
then
if referenceEq(e1, eq.lhs) and referenceEq(e2, eq.rhs)
then eq else ARRAY_EQUALITY(e1, e2, eq.ty, eq.source);

case CONNECT()
algorithm
e1 := func(eq.lhs);
e2 := func(eq.rhs);
then
if referenceEq(e1, eq.lhs) and referenceEq(e2, eq.rhs)
then eq else CONNECT(e1, e2, eq.source);

case FOR()
algorithm
eq.range := Util.applyOption(eq.range, func);
then
eq;

case IF()
algorithm
eq.branches := list(Branch.mapExp(b, func, mapBody = false) for b in eq.branches);
then
eq;

case WHEN()
algorithm
eq.branches := list(Branch.mapExp(b, func, mapBody = false) for b in eq.branches);
then
eq;

case ASSERT()
algorithm
e1 := func(eq.condition);
e2 := func(eq.message);
e3 := func(eq.level);
then
if referenceEq(e1, eq.condition) and referenceEq(e2, eq.message) and
referenceEq(e3, eq.level) then eq else ASSERT(e1, e2, e3, eq.source);

case TERMINATE()
algorithm
e1 := func(eq.message);
then
if referenceEq(e1, eq.message) then eq else TERMINATE(e1, eq.source);

case REINIT()
algorithm
e1 := func(eq.cref);
e2 := func(eq.reinitExp);
then
if referenceEq(e1, eq.cref) and referenceEq(e2, eq.reinitExp) then
eq else REINIT(e1, e2, eq.source);

case NORETCALL()
algorithm
e1 := func(eq.exp);
then
if referenceEq(e1, eq.exp) then eq else NORETCALL(e1, eq.source);

else eq;
end match;
end mapExpShallow;

function foldExpList<ArgT>
input list<Equation> eq;
input FoldFunc func;
Expand Down
24 changes: 24 additions & 0 deletions OMCompiler/Compiler/NFFrontEnd/NFFlatModel.mo
Expand Up @@ -113,6 +113,30 @@ public
flatModel.initialAlgorithms := Algorithm.mapExpList(flatModel.initialAlgorithms, fn);
end mapExp;

function mapEquations
input output FlatModel flatModel;
input MapFn fn;

partial function MapFn
input output Equation eq;
end MapFn;
algorithm
flatModel.equations := list(Equation.map(eq, fn) for eq in flatModel.equations);
flatModel.initialEquations := list(Equation.map(eq, fn) for eq in flatModel.initialEquations);
end mapEquations;

function mapAlgorithms
input output FlatModel flatModel;
input MapFn fn;

partial function MapFn
input output Algorithm alg;
end MapFn;
algorithm
flatModel.algorithms := list(fn(alg) for alg in flatModel.algorithms);
flatModel.initialAlgorithms := list(fn(alg) for alg in flatModel.initialAlgorithms);
end mapAlgorithms;

function toString
input FlatModel flatModel;
input Boolean printBindingTypes = false;
Expand Down
3 changes: 1 addition & 2 deletions OMCompiler/Compiler/NFFrontEnd/NFFlatten.mo
Expand Up @@ -907,8 +907,7 @@ algorithm
prefix_node := ComponentRef.node(prefix);

for dim in dimensions loop
iter_comp := Component.newIterator(Type.INTEGER(), InstNode.info(prefix_node));
iter := InstNode.fromComponent("$i" + String(index), iter_comp, InstNode.parent(prefix_node));
iter := InstNode.newIndexedIterator(index, Type.INTEGER(), InstNode.info(prefix_node));
iterators := iter :: iterators;
index := index + 1;

Expand Down
1 change: 1 addition & 0 deletions OMCompiler/Compiler/NFFrontEnd/NFInst.mo
Expand Up @@ -206,6 +206,7 @@ algorithm
InstUtil.dumpFlatModelDebug("scalarize", flatModel, functions);

VerifyModel.verify(flatModel);
(flatModel, functions) := InstUtil.expandSlicedCrefs(flatModel, functions);
flatModel := InstUtil.combineSubscripts(flatModel);

//(var_count, eq_count) := CheckModel.checkModel(flatModel);
Expand Down
18 changes: 18 additions & 0 deletions OMCompiler/Compiler/NFFrontEnd/NFInstNode.mo
Expand Up @@ -290,6 +290,24 @@ uniontype InstNode
InstNodeType.BASE_CLASS(parent, definition));
end newExtends;

function newIterator
input String name;
input Type ty;
input SourceInfo info;
output InstNode iterator;
algorithm
iterator := fromComponent(name, Component.newIterator(ty, info), EMPTY_NODE());
end newIterator;

function newIndexedIterator
input Integer index;
input Type ty = Type.INTEGER();
input SourceInfo info = AbsynUtil.dummyInfo;
output InstNode iterator;
algorithm
iterator := newIterator("$i" + String(index), ty, info);
end newIndexedIterator;

function fromComponent
input String name;
input Component component;
Expand Down

0 comments on commit 5a6b507

Please sign in to comment.