Skip to content

Commit 117b803

Browse files
authored
optimize data structure for gauss algorithm (#13745)
use array<array<Real>> for jacobian matrix instead of intermediate list<list<Real>> which is converted to flat matrix array<Real> for gauss algorithm
1 parent 696d683 commit 117b803

File tree

1 file changed

+42
-35
lines changed

1 file changed

+42
-35
lines changed

OMCompiler/Compiler/BackEnd/SymbolicJacobian.mo

Lines changed: 42 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -569,15 +569,14 @@ protected
569569
BackendDAE.Variables orderedVars;
570570
BackendDAE.EquationArray orderedEqs;
571571
BackendDAE.Matching matching;
572-
array<Integer> ass1, ass2, ass1add, ass2add;
572+
array<Integer> ass1, ass2, assAdd;
573573
BackendDAE.StrongComponents comps;
574574
algorithm
575575
BackendDAE.EQSYSTEM(matching=BackendDAE.MATCHING(ass1=ass1, ass2=ass2, comps=comps)) := systIn;
576576
if not listEmpty(compsAdd) then
577-
ass1add := arrayCreate(listLength(compsAdd), 0);
578-
ass2add := arrayCreate(listLength(compsAdd), 0);
579-
ass1 := arrayAppend(ass1, ass1add);
580-
ass2 := arrayAppend(ass2, ass1add);
577+
assAdd := arrayCreate(listLength(compsAdd), 0);
578+
ass1 := arrayAppend(ass1, assAdd);
579+
ass2 := arrayAppend(ass2, assAdd);
581580
List.map2_0(compsAdd, updateAssignment, ass1, ass2);
582581
end if;
583582
List.map2_0(compsNew, updateAssignment, ass1, ass2);
@@ -624,14 +623,14 @@ protected
624623
BackendDAE.EquationArray eqns,eqns1, eqns2;
625624
list<DAE.Exp> beqs;
626625
list<DAE.ElementSource> sources;
627-
list<list<Real>> jacVals;
628626
BackendDAE.Matching matching;
629627
DAE.FunctionTree funcs;
630628
BackendDAE.Shared shared;
631629
BackendDAE.StateSets stateSets;
632630
BackendDAE.BaseClockPartitionKind partitionKind;
633631

634-
array<Real> A,b;
632+
array<array<Real>> A;
633+
array<Real> b;
635634
Real entry;
636635
Integer row,col,n, systIdx;
637636
array<Integer> order;
@@ -644,17 +643,12 @@ algorithm
644643
(beqs,sources) := BackendDAEUtil.getEqnSysRhs(eqns1,v,SOME(funcs));
645644
beqs := listReverse(beqs);
646645
//print("bside: \n"+ExpressionDump.printExpListStr(beqs)+"\n");
647-
jacVals := evaluateConstantJacobian(listLength(var_lst),jac);
646+
A := evaluateConstantJacobianArray(listLength(var_lst),jac);
648647
//print("JacVals\n"+stringDelimitList(List.map(jacVals,rListStr),"\n")+"\n\n");
649648

650-
A := arrayCreate(n*n,0.0);
651-
b := arrayCreate(n*n,0.0); // i.e. a matrix for the b-vars to get their coefficients independently [(b1,0,0);(0,b2,0),(0,0,b3)]
649+
b := arrayCreate(n*n,0.0); // i.e. a matrix for the b-vars to get their coefficients independently [(b1,0,0);(0,b2,0),(0,0,b3)]
652650
order := arrayCreate(n,0);
653651
for row in 1:n loop
654-
for col in 1:n loop
655-
entry := listGet(listGet(jacVals,row),col);
656-
arrayUpdate(A,((row-1)*n+col),entry);
657-
end for;
658652
arrayUpdate(b,(row-1)*n+row,1.0);
659653
end for;
660654
//print("b\n"+stringDelimitList(List.map(arrayList(b),realString),", ")+"\n\n");
@@ -664,6 +658,9 @@ algorithm
664658

665659
(bVarsOut,bEqsOut) := createBVecVars(sysIdxIn,compIdxIn,n,DAE.T_REAL_DEFAULT,beqs);
666660
sysEqsOut := createSysEquations(A,b,n,order,var_lst,bVarsOut);
661+
for a in A loop
662+
GCExt.free(a);
663+
end for;
667664
GCExt.free(A);
668665
GCExt.free(b);
669666
sysIdxOut := sysIdxIn+1;
@@ -672,7 +669,7 @@ end solveConstJacLinearSystem;
672669

673670
protected function createSysEquations "creates new equations for a linear system with constant Jacobian matrix.
674671
author: Waurich TUD 2015-03"
675-
input array<Real> A;
672+
input array<array<Real>> A;
676673
input array<Real> b;
677674
input Integer n;
678675
input array<Integer> order;
@@ -691,7 +688,7 @@ algorithm
691688
bExps := List.map(bVars, BackendVariable.varExp2);
692689
for i in 1:n loop
693690
row := arrayGet(order,i);
694-
coeffs := Array.getRange((row-1)*n+1,(row*n),A);
691+
coeffs := arrayList(A[row]);
695692
coeffExps := List.map(coeffs,Expression.makeRealExp);
696693
xProds := List.threadMap1(coeffExps,xExps,makeBinaryExp,DAE.MUL(DAE.T_REAL_DEFAULT));
697694
lhs := List.fold1(xProds,Expression.makeBinaryExp,DAE.ADD(DAE.T_REAL_DEFAULT),DAE.RCONST(0.0));
@@ -742,7 +739,7 @@ algorithm
742739
end createBVecVars;
743740

744741
protected function gauss
745-
input array<Real> A;
742+
input array<array<Real>> A;
746743
input array<Real> b;
747744
input Integer indxIn;
748745
input Integer n;
@@ -764,11 +761,10 @@ algorithm
764761

765762
// the pivot row in the A-matrix divided by the pivot element
766763
for ic in indxIn:n loop
767-
pos := (pivotIdx-1)*n+ic;
768-
entry := arrayGet(A,pos);
764+
entry := arrayGet(A[pivotIdx],ic);
769765
entry := realDiv(entry,pivot); //divide column entry with pivot element
770766
//print(" pos "+intString(pos)+" entry "+realString(arrayGet(A,pos))+"\n");
771-
arrayUpdate(A,pos,entry);
767+
arrayUpdate(A[pivotIdx],ic,entry);
772768
end for;
773769
// the complete pivot row of the b-vector divided by the pivot element
774770
for ic in 1:n loop
@@ -780,20 +776,20 @@ algorithm
780776

781777
// the remaining rows
782778
for ir in range loop
783-
first := arrayGet(A,(ir-1)*n+indxIn); //the first row element, that is going to be zero
779+
first := arrayGet(A[ir],indxIn); //the first row element, that is going to be zero
784780
//print("first "+realString(first)+"\n");
785781
for ic in indxIn:n loop
786-
pos := (ir-1)*n+ic;
787-
entry := arrayGet(A,pos); // the current entry
788-
pivot := arrayGet(A,(pivotIdx-1)*n+ic); // the element from the column in the pivot row
782+
pos := (ir-1)*n+ic;
783+
entry := arrayGet(A[ir],ic); // the current entry
784+
pivot := arrayGet(A[pivotIdx],ic); // the element from the column in the pivot row
789785
//print("pivot "+realString(pivot)+"\n");
790786
//print("ir "+intString(ir)+" pos "+intString(pos)+" entry0 "+realString(entry)+" entry1 "+realString(realSub(entry,realDiv(first,pivot)))+"\n");
791-
entry := realSub(entry,realMul(first,pivot));
792-
arrayUpdate(A,pos,entry);
793-
b_entry := arrayGet(b,pos);
794-
pivot := arrayGet(b,(pivotIdx-1)*n+ic);
795-
b_entry := b_entry - realMul(first,pivot);
796-
arrayUpdate(b,pos,b_entry);
787+
entry := realSub(entry,realMul(first,pivot));
788+
arrayUpdate(A[ir],ic,entry);
789+
b_entry := arrayGet(b,pos);
790+
pivot := arrayGet(b,(pivotIdx-1)*n+ic);
791+
b_entry := b_entry - realMul(first,pivot);
792+
arrayUpdate(b,pos,b_entry);
797793
end for;
798794
end for;
799795
//print("A\n"+stringDelimitList(List.map(arrayList(A),realString),", ")+"\n\n");
@@ -808,7 +804,7 @@ algorithm
808804
end gauss;
809805

810806
protected function getPivotElement "gets the highest element in the startIdx'th to n'th rows and the startidx'th column"
811-
input array<Real> A;
807+
input array<array<Real>> A;
812808
input list<Integer> rangeIn;
813809
input Integer startIdx;
814810
input Integer n;
@@ -819,7 +815,7 @@ protected
819815
Real entry;
820816
algorithm
821817
for i in rangeIn loop
822-
entry := arrayGet(A,(i-1)*n+startIdx);
818+
entry := arrayGet(A[i],startIdx);
823819
//print("i "+intString(i)+" pi "+intString(p_i)+" entry "+realString(entry)+"\n");
824820
if realAbs(entry) > value then
825821
value := entry;
@@ -1116,16 +1112,27 @@ public function evaluateConstantJacobian
11161112
output list<list<Real>> vals;
11171113
protected
11181114
array<array<Real>> valarr;
1115+
list<array<Real>> tmp2;
1116+
algorithm
1117+
valarr := evaluateConstantJacobianArray(size, jac);
1118+
tmp2 := arrayList(valarr);
1119+
vals := List.map(tmp2,arrayList);
1120+
end evaluateConstantJacobian;
1121+
1122+
protected function evaluateConstantJacobianArray
1123+
"Evaluate a constant Jacobian so we can solve a linear system during runtime"
1124+
input Integer size;
1125+
input list<tuple<Integer,Integer,BackendDAE.Equation>> jac;
1126+
output array<array<Real>> valarr;
1127+
protected
11191128
array<Real> tmp;
11201129
list<array<Real>> tmp2;
11211130
algorithm
11221131
tmp := arrayCreate(size,0.0);
11231132
tmp2 := List.map(List.fill(tmp,size),arrayCopy);
11241133
valarr := listArray(tmp2);
11251134
List.map1_0(jac,evaluateConstantJacobian2,valarr);
1126-
tmp2 := arrayList(valarr);
1127-
vals := List.map(tmp2,arrayList);
1128-
end evaluateConstantJacobian;
1135+
end evaluateConstantJacobianArray;
11291136

11301137
protected function evaluateConstantJacobian2
11311138
input tuple<Integer,Integer,BackendDAE.Equation> jac;

0 commit comments

Comments
 (0)