Skip to content

Commit

Permalink
- shorten Paths CSE-option
Browse files Browse the repository at this point in the history
  • Loading branch information
vwaurich committed May 4, 2016
1 parent 7e798ee commit 3a80514
Showing 1 changed file with 230 additions and 15 deletions.
245 changes: 230 additions & 15 deletions Compiler/BackEnd/CommonSubExpression.mo
Expand Up @@ -53,6 +53,7 @@ protected import DAEUtil;
protected import Expression;
protected import ExpressionDump;
protected import ExpressionSolve;
protected import ExpressionSimplify;
protected import Global;
protected import HashTableExpToExp;
protected import HashTableExpToIndex;
Expand Down Expand Up @@ -1013,10 +1014,21 @@ end prepareExpForReplace;
protected
uniontype CommonSubExp
record ASSIGNMENT_CSE
//a = exp1;
//b = exp1;
//--> a = b;
list<Integer> eqIdcs;
list<Integer> sharedVars;
list<Integer> aliasVars;
end ASSIGNMENT_CSE;

record SHORTCUT_CSE
//a = exp1;
//a = exp2;
//--> exp1 = exp2;
list<Integer> eqIdcs;
Integer sharedVar;
end SHORTCUT_CSE;
end CommonSubExp;

public function commonSubExpressionReplacement"detects common sub expressions and introduces alias variables for them.
Expand Down Expand Up @@ -1048,16 +1060,22 @@ algorithm
BackendDAE.IncidenceMatrix m, mT;
list<Integer> eqIdcs;
list<CommonSubExp> cseLst;

BackendDAE.Equation eqTest;
BackendDAE.Var var1,var2;
list<BackendDAE.Equation> eqLst;

case(BackendDAE.EQSYSTEM(orderedVars=vars, orderedEqs=eqs), BackendDAE.SHARED(functionTree=functionTree))
equation
(_, m, mT) = BackendDAEUtil.getIncidenceMatrix(sysIn, BackendDAE.ABSOLUTE(), SOME(functionTree));
algorithm
(_, m, mT) := BackendDAEUtil.getIncidenceMatrix(sysIn, BackendDAE.ABSOLUTE(), SOME(functionTree));
//print("start this eqSystem\n");
//BackendDump.dumpEqSystem(sysIn, "eqSystem input");
//BackendDump.dumpIncidenceMatrix(m);
//BackendDump.dumpIncidenceMatrixT(mT);
cseLst = commonSubExpressionFind(m, mT, vars, eqs);
cseLst := commonSubExpressionFind(m, mT, vars, eqs);
//if not listEmpty(cseLst) then print("update "+stringDelimitList(List.map(cseLst, printCSE), "\n")+"\n");end if;
syst = commonSubExpressionUpdate(cseLst, m, mT, sysIn);
syst := commonSubExpressionUpdate(cseLst, m, mT, sysIn);
syst.orderedEqs := eqs;
//print("done this eqSystem\n");
//BackendDump.dumpEqSystem(syst, "eqSystem");
then (syst, sharedIn);
Expand All @@ -1072,25 +1090,26 @@ protected function commonSubExpressionFind
input BackendDAE.EquationArray eqsIn;
output list<CommonSubExp> cseOut;
protected
list<Integer> eqIdcs, varIdcs, lengthLst, range;
Integer numVars;
list<Integer> eqIdcs, varIdcs,lengthLst, range;
list<list<Integer>> arrLst;
list<list<Integer>> partitions;
BackendDAE.Variables vars;
BackendDAE.Variables vars, linPathVars;
BackendDAE.EquationArray eqs;
BackendDAE.EqSystem eqSys;
BackendDAE.IncidenceMatrix m, mT;
list<BackendDAE.Equation> eqLst;
list<BackendDAE.Var> varLst;
list<CommonSubExp> cseLst2, cseLst3;
list<CommonSubExp> cseLst2, cseLst3, shortenPathsCSE;
list<tuple<Boolean, String>> varAtts, eqAtts;
algorithm
try
range := List.intRange(arrayLength(mIn));
arrLst := arrayList(mIn);
lengthLst := List.map(arrLst, listLength);

// check for CSE of length 1
//print("CHECK FOR CSE 2\n");
// check for CSE of length 1 (all eqs with 2 variables)
//print("CHECK FOR CSE 2\n");
(_, eqIdcs) := List.filter1OnTrueSync(lengthLst, intEq, 2, range);
(eqLst, eqIdcs) := List.filterOnTrueSync(BackendEquation.getEqns(eqIdcs, eqsIn),BackendEquation.isNotAlgorithm,eqIdcs); // no algorithms
eqs := BackendEquation.listEquation(eqLst);
Expand All @@ -1105,13 +1124,16 @@ algorithm
//BackendDump.dumpIncidenceMatrix(mT);
//varAtts := List.threadMap(List.fill(false, listLength(varIdcs)), List.fill("", listLength(varIdcs)), Util.makeTuple);
//eqAtts := List.threadMap(List.fill(false, listLength(eqIdcs)), List.fill("", listLength(eqIdcs)), Util.makeTuple);
//BackendDump.dumpBipartiteGraphStrongComponent2(vars, eqs, m, varAtts, eqAtts, "CSE2");
//BackendDump.dumpBipartiteGraphStrongComponent2(vars, eqs, m, varAtts, eqAtts, "CSE2_"+intString(arrayLength(mIn)));
partitions := arrayList(ResolveLoops.partitionBipartiteGraph(m, mT));
partitions := List.filterOnFalse(partitions,listEmpty);
//print("the partitions for system : \n"+stringDelimitList(List.map(partitions, HpcOmTaskGraph.intLstString), "\n")+"\n");
cseLst2 := List.fold(partitions, function getCSE2(m=m, mT=mT, vars=vars, eqs=eqs, eqMap=eqIdcs, varMap=varIdcs), {});

shortenPathsCSE := shortenPaths(partitions, m, mT, vars, eqs, listArray(eqIdcs), listArray(varIdcs), {});

// check for CSE of length 2
//print("CHECK FOR CSE 3\n");
//print("CHECK FOR CSE 3\n");
(_, eqIdcs) := List.filter1OnTrueSync(lengthLst, intEq, 3, range);
(eqLst, eqIdcs) := List.filterOnTrueSync(BackendEquation.getEqns(eqIdcs, eqsIn),BackendEquation.isNotAlgorithm,eqIdcs); // no algorithms
eqs := BackendEquation.listEquation(eqLst);
Expand All @@ -1126,17 +1148,88 @@ algorithm
//BackendDump.dumpIncidenceMatrix(mT);
//varAtts := List.threadMap(List.fill(false, listLength(varIdcs)), List.fill("", listLength(varIdcs)), Util.makeTuple);
//eqAtts := List.threadMap(List.fill(false, listLength(eqIdcs)), List.fill("", listLength(eqIdcs)), Util.makeTuple);
//BackendDump.dumpBipartiteGraphStrongComponent2(vars, eqs, m, varAtts, eqAtts, "CSE3");
//BackendDump.dumpBipartiteGraphStrongComponent2(vars, eqs, m, varAtts, eqAtts, "CSE3_"+intString(arrayLength(mIn)));
partitions := arrayList(ResolveLoops.partitionBipartiteGraph(m, mT));
//print("the partitions for system : \n"+stringDelimitList(List.map(partitions, HpcOmTaskGraph.intLstString), "\n")+"\n");
cseLst3 := List.fold(partitions, function getCSE3(m=m, mT=mT, vars=vars, eqs=eqs, eqMap=eqIdcs, varMap=varIdcs), {});
cseOut := listAppend(cseLst2, cseLst3);
cseOut := listAppend(cseLst2, listAppend(cseLst3,shortenPathsCSE));
//print("the cses : \n"+stringDelimitList(List.map(cseOut, printCSE), "\n")+"\n");
else
cseOut := {};
end try;
end commonSubExpressionFind;


protected function shortenPaths"looks for a path in the bipartite graph where each variable and equation has only 2 adjacent node.
Then check if variables which are shared by 2 equations can be combined somehow to rearrange edges and create a shortcut of this path.
author:Waurich TUD 2016-05"
input list<list<Integer>> allPartitions;
input BackendDAE.IncidenceMatrix mIn;
input BackendDAE.IncidenceMatrix mTIn;
input BackendDAE.Variables allVars;
input BackendDAE.EquationArray allEqs;
input array<Integer> eqMap;
input array<Integer> varMap;
input list<CommonSubExp> cseIn;
output list<CommonSubExp> cseOut;
protected
BackendDAE.IncidenceMatrix m, mT;
BackendDAE.AdjacencyMatrixEnhanced me,meT;
BackendDAE.EqSystem eqSys;
BackendDAE.Variables vars, pathVars;
list<BackendDAE.Var> varLst;
list<BackendDAE.Equation> eqLst;
BackendDAE.EquationArray eqs;
list<tuple<Boolean, String>> varAtts, eqAtts;
Integer numVars, varIdx;
array<Integer> pathVarIdxMap;
list<Integer> partition, varIdcs, adjEqs, pathVarIdcs;
list<CommonSubExp> cses;
algorithm
// getall vars with only 2 adjacent equations
numVars := BackendVariable.varsSize(allVars);
(_, pathVarIdcs) := List.filter1OnTrueSync(List.map(arrayList(mTIn), listLength), intEq, 2, List.intRange(numVars));
pathVars := BackendVariable.listVar1(List.map1(pathVarIdcs, BackendVariable.getVarAtIndexFirst, allVars));
pathVarIdxMap := listArray(List.map1(pathVarIdcs,Array.getIndexFirst,varMap));
cses := cseIn;

if BackendVariable.varsSize(pathVars) > 0 then
for partition in allPartitions loop
//print("partition "+stringDelimitList(List.map(partition, intString), ", ")+"\n");
//print("pathVarIdxMap "+stringDelimitList(List.map(List.map1(pathVarIdcs,Array.getIndexFirst,varMap), intString), ", ")+"\n");

//get only the partition equations
eqLst := BackendEquation.equationList(allEqs);
eqLst := List.map1(partition,List.getIndexFirst,eqLst);
eqs := BackendEquation.listEquation(eqLst);

eqSys := BackendDAEUtil.createEqSystem(pathVars, eqs);
(_, m, mT) := BackendDAEUtil.getIncidenceMatrix(eqSys, BackendDAE.SOLVABLE(), NONE());

//BackendDump.dumpIncidenceMatrix(m);
//BackendDump.dumpIncidenceMatrixT(mT);
//varAtts := List.threadMap(List.fill(false, arrayLength(mT)), List.fill("", arrayLength(mT)), Util.makeTuple);
//eqAtts := List.threadMap(List.fill(false, arrayLength(m)), List.fill("", arrayLength(m)), Util.makeTuple);
//BackendDump.dumpBipartiteGraphStrongComponent2(pathVars, eqs, m, varAtts, eqAtts, "shortenPaths"+stringDelimitList(List.map(partition,intString),"_"));

for varIdx in List.intRange(arrayLength(mT)) loop
adjEqs := arrayGet(mT,varIdx);

if listLength(adjEqs)==2 then
//print("varIdx1 "+intString(varIdx)+"\n");
//print("adjEqs "+stringDelimitList(List.map(adjEqs,intString),",")+"\n");
adjEqs := List.map1(adjEqs,List.getIndexFirst,partition);
adjEqs := List.map1(adjEqs, Array.getIndexFirst, eqMap);
varIdx := arrayGet(pathVarIdxMap,varIdx);
cses := SHORTCUT_CSE(adjEqs,varIdx)::cses;
end if;
end for; //end the variables
end for; //end all partitions
//print("the SHORTPATH cses : \n"+stringDelimitList(List.map(cses, printCSE), "\n")+"\n");
end if;
cseOut := cses;
end shortenPaths;

protected function getCSE2"traverses the partitions and checks for CSE2 i.e a=b+const. ; c = b+const. --> a=c
author:Waurich TUD 2014-11"
input list<Integer> partition;
Expand Down Expand Up @@ -1268,15 +1361,18 @@ author:Waurich TUD 2014-11"
algorithm
sysOut := matchcontinue (tplsIn, m, mT, sysIn)
local
Integer sharedVar, eqIdx1, eqIdx2, varIdx1, varIdx2, varIdx_remain, varIdxAlias, eqIdxDel, eqIdxLeft;
Integer sharedVar, eqIdx1, eqIdx2, varIdx1, varIdx2, varIdx_remain, varIdxAlias, eqIdxDel, eqIdxLeft, n;
list<Integer> eqIdcs, eqs1, eqs2, vars1, vars2, aliasVars;
list<CommonSubExp> rest;
BackendDAE.Var var1, var2, var_remain, var_alias;
BackendVarTransform.VariableReplacements repl;
BackendDAE.Variables vars;
BackendDAE.Var var;
BackendDAE.Equation eq1,eq2, eqNew;
BackendDAE.EquationArray eqs;
BackendDAE.EqSystem syst;
DAE.Exp varExp_remain, varExp_alias;
DAE.Exp varExp_remain, varExp_alias, lhs1,rhs1,lhs2,rhs2,varExp,exp;
DAE.Type ty;
DAE.ComponentRef cref;
list<BackendDAE.Equation> eqLst;
case({}, _, _, syst as BackendDAE.EQSYSTEM(orderedVars=vars, orderedEqs=eqs))
Expand Down Expand Up @@ -1322,23 +1418,142 @@ algorithm
//replace original equation
BackendEquation.setAtIndex(eqs,eqIdxDel,BackendDAE.EQUATION(varExp_remain,varExp_alias,DAE.emptyElementSource,BackendDAE.EQ_ATTR_DEFAULT_DYNAMIC));
then commonSubExpressionUpdate(rest, m, mT, syst);

case (SHORTCUT_CSE(eqIdcs={eqIdx1, eqIdx2}, sharedVar=sharedVar)::rest, _, _, syst as BackendDAE.EQSYSTEM(orderedVars=vars, orderedEqs=eqs))
equation
{eq1, eq2} = BackendEquation.getEqns({eqIdx1, eqIdx2}, eqs);
var = BackendVariable.getVarAt(vars, sharedVar);
varExp = BackendVariable.varExp(var);
ty = Expression.typeof(varExp);
BackendDAE.EQUATION(exp=lhs1, scalar=rhs1) = eq1;
BackendDAE.EQUATION(exp=lhs2, scalar=rhs2) = eq2;

// since ExpressionSolve is able to solve for vars in if-expressions, stop here
true = hasAlgebraicOperationsOnly(lhs1);
true = hasAlgebraicOperationsOnly(rhs1);
true = hasAlgebraicOperationsOnly(lhs2);
true = hasAlgebraicOperationsOnly(rhs2);

(rhs1, _) = ExpressionSolve.solve(lhs1, rhs1, varExp);
(lhs1, _) = ExpressionSolve.solve(lhs2, rhs2, varExp);

(_,lhs1,rhs1) = cancelExpressions(lhs1,rhs1);
n = listLength(Expression.getAllCrefs(Expression.makeDiff(lhs1,rhs1)));
//print("n1 "+intString(n1)+"\n");
//print("n2 "+intString(n2)+"\n");

if n <= 2 then
//print("FROM "+BackendDump.equationString(eq1)+"\n");
//print("AND "+BackendDump.equationString(eq2)+"\n");
eqNew = BackendDAE.EQUATION(lhs1,rhs1,DAE.emptyElementSource,BackendDAE.EQ_ATTR_DEFAULT_DYNAMIC);
//print("MADE A NEW EQUATION "+BackendDump.equationString(eqNew)+"\n\n");
//replace original equation
BackendEquation.setAtIndex(eqs,eqIdx1,eqNew);
end if;

then commonSubExpressionUpdate(rest, m, mT, syst);
case (_::rest, _, _, _)
then commonSubExpressionUpdate(rest, m, mT, sysIn);
end matchcontinue;
end commonSubExpressionUpdate;


protected function hasAlgebraicOperationsOnly"checks if the expression contains algebraic operations only. (no realtions, ifs, etc.)
author:Waurich TUD 05-2016"
input DAE.Exp exp;
output Boolean isAlgOut;
algorithm
isAlgOut := match(exp)
local
Boolean b;
DAE.Exp e1,e2;
case(DAE.RCONST())
then true;
case(DAE.CREF())
then true;
case(DAE.BINARY(e1,_,e2))
equation
b = hasAlgebraicOperationsOnly(e1);
b = b and hasAlgebraicOperationsOnly(e2);
then b;
case(DAE.UNARY(_,e1))
equation
b = hasAlgebraicOperationsOnly(e1);
then b;
else
then false;
end match;
end hasAlgebraicOperationsOnly;


protected function cancelExpressions"checks if factors on each side of an equation can be cancelled
author: Waurich TUD 2016-05"
input DAE.Exp e1In;//lhs
input DAE.Exp e2In;//rhs
output Boolean canceled = false;
output DAE.Exp e1Out = e1In;
output DAE.Exp e2Out = e2In;
protected
list<DAE.Exp> topLevelFactors1, topLevelFactors2;
algorithm
topLevelFactors1 := getTopLevelFactors(e1In,{});
//print("topLevelFactors1 "+ExpressionDump.printExpListStr(topLevelFactors1)+"\n");
topLevelFactors2 := getTopLevelFactors(e2In,{});
//print("topLevelFactors2 "+ExpressionDump.printExpListStr(topLevelFactors2)+"\n");
if not listEmpty(topLevelFactors1) and not listEmpty(topLevelFactors1) then
topLevelFactors1 := List.intersectionOnTrue(topLevelFactors1,topLevelFactors2,Expression.expEqual);
if listLength(topLevelFactors1) == 1 then
e1Out := Expression.expDiv(e1In,listHead(topLevelFactors1));
e1Out := ExpressionSimplify.simplify(e1Out);
e2Out := Expression.expDiv(e2In,listHead(topLevelFactors2));
e2Out := ExpressionSimplify.simplify(e2Out);
//print("e1Out "+ExpressionDump.printExpListStr({e1Out})+"\n");
//print("e2Out "+ExpressionDump.printExpListStr({e2Out})+"\n");
canceled := true;
end if;
end if;
end cancelExpressions;

protected function getTopLevelFactors"Gets factors(crefs only) of the exp"
input DAE.Exp exp;
input list<DAE.Exp> lstIn;
output list<DAE.Exp> lstOut;
algorithm
lstOut := matchcontinue(exp,lstIn)
local
DAE.Exp e1,e2;
list<DAE.Exp> eLst;
case(DAE.BINARY(e1,DAE.MUL(_),e2),_)
equation
eLst = getTopLevelFactors(e1,lstIn);
eLst = getTopLevelFactors(e2,eLst);
then eLst;
case(DAE.UNARY(_ ,e1 as DAE.CREF()),_)
equation
then e1::lstIn;
case(e1 as DAE.CREF(),_)
equation
then e1::lstIn;
else
then lstIn;
end matchcontinue;
end getTopLevelFactors;

protected function printCSE"prints a CSE tuple string.
author:Waurich TUD 2014-11"
input CommonSubExp cse;
output String s;
algorithm
s := match(cse)
local
Integer sharedVar;
list<Integer> eqIdcs;
list<Integer> sharedVars;
list<Integer> aliasVars;
case(ASSIGNMENT_CSE(eqIdcs=eqIdcs, sharedVars=sharedVars, aliasVars=aliasVars))
then "ASSIGN_CSE: eqs{"+stringDelimitList(List.map(eqIdcs, intString), ", ")+"}"+" sharedVars{"+stringDelimitList(List.map(sharedVars, intString), ", ")+"}"+" aliasVars{"+stringDelimitList(List.map(aliasVars, intString), ", ")+"}";
case(SHORTCUT_CSE(eqIdcs, sharedVar))
then "SHORTCUT_CSE: eqs{"+stringDelimitList(List.map(eqIdcs, intString), ", ")+"}"+" sharedVar{"+intString(sharedVar)+"}";
end match;
end printCSE;
annotation(__OpenModelica_Interface="backend");
Expand Down

0 comments on commit 3a80514

Please sign in to comment.