Skip to content

Commit

Permalink
Improve connection handling (#9733)
Browse files Browse the repository at this point in the history
- Only use the array connection handling on models that contain large
  array connections.
- Split the splitting and scalarization of connectors into two separate
  phases, to allow better control over the scalarization.
  • Loading branch information
perost committed Nov 23, 2022
1 parent 719e6c2 commit 2880f4b
Show file tree
Hide file tree
Showing 19 changed files with 309 additions and 175 deletions.
40 changes: 1 addition & 39 deletions OMCompiler/Compiler/NFFrontEnd/NFCheckModel.mo
Expand Up @@ -57,7 +57,7 @@ algorithm
(variables, equations) := countVariableSize(v, variables, equations);
end for;

equations := equations + countEquationListSize(flatModel.equations);
equations := equations + Equation.sizeOfList(flatModel.equations);

for a in flatModel.algorithms loop
equations := equations + countAlgorithmSize(a);
Expand Down Expand Up @@ -90,44 +90,6 @@ algorithm
equations := equations + Type.sizeOf(Binding.getType(binding));
end countVariableSize;

function countEquationListSize
input list<Equation> eqs;
output Integer equations = 0;
algorithm
for e in eqs loop
equations := equations + countEquationSize(e);
end for;
end countEquationListSize;

function countEquationSize
input Equation eq;
output Integer equations;
algorithm
equations := match eq
case Equation.EQUALITY() then Type.sizeOf(eq.ty);
case Equation.ARRAY_EQUALITY() then Type.sizeOf(eq.ty);
case Equation.FOR() then countEquationListSize(eq.body);
case Equation.IF() then countEquationBranchSize(listHead(eq.branches));
case Equation.WHEN() then countEquationBranchSize(listHead(eq.branches));
else 0;
end match;
end countEquationSize;

function countEquationBranchSize
input Equation.Branch branch;
output Integer equations;
algorithm
equations := match branch
case Equation.Branch.BRANCH() then countEquationListSize(branch.body);

else
algorithm
Error.assertion(false, getInstanceName() + " got invalid branch", sourceInfo());
then
fail();
end match;
end countEquationBranchSize;

function countAlgorithmSize
input Algorithm alg;
output Integer equations = 0;
Expand Down
14 changes: 14 additions & 0 deletions OMCompiler/Compiler/NFFrontEnd/NFComponentRef.mo
Expand Up @@ -298,6 +298,20 @@ public
end match;
end firstName;

function first
input output ComponentRef cref;
algorithm
() := match cref
case CREF()
algorithm
cref.restCref := EMPTY();
then
();

else ();
end match;
end first;

function rest
input ComponentRef cref;
output ComponentRef restCref;
Expand Down
70 changes: 63 additions & 7 deletions OMCompiler/Compiler/NFFrontEnd/NFConnection.mo
Expand Up @@ -53,13 +53,7 @@ public
algorithm
cls := Connector.split(conn.lhs);
crs := Connector.split(conn.rhs);

if listLength(cls) <> listLength(crs) then
Error.assertion(false, getInstanceName() + " got unbalanced connection " + toString(conn) + ":" +
List.toString(cls, Connector.toString, "\n lhs: ", "{", ", ", "}", true) +
List.toString(crs, Connector.toString, "\n rhs: ", "{", ", ", "}", true), sourceInfo());
fail();
end if;
checkBalance(cls, crs, conn);

for cl in cls loop
cr :: crs := crs;
Expand All @@ -75,12 +69,74 @@ public
conns := listReverseInPlace(conns);
end split;

function scalarize
input Connection conn;
output list<Connection> conns = {};
protected
list<Connector> cls, crs;
Connector cr;
algorithm
if not Connector.isArray(conn.lhs) then
conns := {conn};
return;
end if;

cls := Connector.scalarize(conn.lhs);
crs := Connector.scalarize(conn.rhs);
checkBalance(cls, crs, conn);

for cl in cls loop
cr :: crs := crs;
conns := CONNECTION(cl, cr) :: conns;
end for;

conns := listReverseInPlace(conns);
end scalarize;

function scalarizePrefix
input Connection conn;
output list<Connection> conns = {};
protected
list<Connector> cls, crs;
Connector cr;
algorithm
if not Connector.isArray(conn.lhs) then
conns := {conn};
return;
end if;

cls := Connector.scalarizePrefix(conn.lhs);
crs := Connector.scalarizePrefix(conn.rhs);
checkBalance(cls, crs, conn);

for cl in cls loop
cr :: crs := crs;
conns := CONNECTION(cl, cr) :: conns;
end for;

conns := listReverseInPlace(conns);
end scalarizePrefix;

function toString
input Connection conn;
output String str;
algorithm
str := "connect(" + Connector.toString(conn.lhs) + ", " + Connector.toString(conn.rhs) + ")";
end toString;

protected
function checkBalance
input list<Connector> leftConnectors;
input list<Connector> rightConnectors;
input Connection conn;
algorithm
if listLength(leftConnectors) <> listLength(rightConnectors) then
Error.assertion(false, getInstanceName() + " got unbalanced connection " + toString(conn) + ":" +
List.toString(leftConnectors, Connector.toString, "\n lhs: ", "{", ", ", "}", true) +
List.toString(rightConnectors, Connector.toString, "\n rhs: ", "{", ", ", "}", true), sourceInfo());
fail();
end if;
end checkBalance;

annotation(__OpenModelica_Interface="frontend");
end NFConnection;
27 changes: 7 additions & 20 deletions OMCompiler/Compiler/NFFrontEnd/NFConnectionSets.mo
Expand Up @@ -70,19 +70,12 @@ package ConnectionSets
listLength(connections.connections) + listLength(connections.flows));

// Add flow variable to the sets, unless disabled by flag.
// Do this here if NF_SCALARIZE to use fast addList for scalarized flows.
if not Flags.isSet(Flags.DISABLE_SINGLE_FLOW_EQ) and Flags.isSet(Flags.NF_SCALARIZE) then
sets := List.fold(connections.flows, addConnector, sets);
if not Flags.isSet(Flags.DISABLE_SINGLE_FLOW_EQ) then
sets := List.fold(connections.flows, addSingleConnector, sets);
end if;

// Add the connections.
sets := List.fold1(connections.connections, addConnection, connections.broken, sets);

// Add remaining flow variables to the sets, unless disabled by flag.
// Do this after addConnection if not NF_SCALARIZE to get array dims right.
if not Flags.isSet(Flags.DISABLE_SINGLE_FLOW_EQ) and not Flags.isSet(Flags.NF_SCALARIZE) then
sets := List.fold(connections.flows, addSingleConnector, sets);
end if;
end fromConnections;

function addScalarConnector
Expand All @@ -97,17 +90,15 @@ package ConnectionSets
input Connector conn;
input output ConnectionSets.Sets sets;
algorithm
sets := addList(Connector.split(conn), sets);
sets := addList(Connector.scalarize(conn), sets);
end addConnector;

function addSingleConnector
"Adds a connector to the sets if it does not already exist"
input Connector conn;
input output ConnectionSets.Sets sets;
algorithm
for c in Connector.split(conn) loop
sets := find(c, sets);
end for;
sets := find(conn, sets);
end addSingleConnector;

function addConnection
Expand All @@ -119,17 +110,13 @@ package ConnectionSets
protected
list<Connection> conns;
algorithm
conns := Connection.split(connection);

if not listEmpty(broken) then
conns := list(c for c guard not isBroken(c.lhs, c.rhs, broken) in conns);
if not listEmpty(broken) and isBroken(connection.lhs, connection.rhs, broken) then
return;
end if;

// TODO: Check variability of connectors. It's an error if either
// connector is constant/parameter while the other isn't.
for conn in conns loop
sets := merge(conn.lhs, conn.rhs, sets);
end for;
sets := merge(connection.lhs, connection.rhs, sets);
end addConnection;

function isBroken
Expand Down
14 changes: 14 additions & 0 deletions OMCompiler/Compiler/NFFrontEnd/NFConnections.mo
Expand Up @@ -202,6 +202,20 @@ public
end if;
end makeConnectors;

function split
input output Connections conns;
algorithm
conns.flows := List.mapFlat(conns.flows, Connector.split);
conns.connections := List.mapFlat(conns.connections, Connection.split);
end split;

function scalarize
input output Connections conns;
algorithm
conns.flows := List.mapFlat(conns.flows, Connector.scalarize);
conns.connections := List.mapFlat(conns.connections, Connection.scalarize);
end scalarize;

function toString
input Connections conns;
output String str;
Expand Down

0 comments on commit 2880f4b

Please sign in to comment.