Skip to content

Commit

Permalink
+ Fix #2387. Handle Matrix Multiplication of record types according t…
Browse files Browse the repository at this point in the history
…o spec.

git-svn-id: https://openmodelica.org/svn/OpenModelica/trunk@17828 f25d12d1-65f4-0310-ae8a-bbce733d8d8e
  • Loading branch information
mahge committed Oct 22, 2013
1 parent 4797d69 commit f9a7bc6
Show file tree
Hide file tree
Showing 3 changed files with 235 additions and 2 deletions.
7 changes: 6 additions & 1 deletion Compiler/FrontEnd/Expression.mo
Original file line number Diff line number Diff line change
Expand Up @@ -6967,6 +6967,11 @@ algorithm
// any other call is a function call
case (DAE.CALL(path = _)) then true;

// mahge: Commented out because it doesn't seem neccessary to
// traverse all expressions. It is not needed (I think) and it
// is expensive because success here means exps will be elaborated
// again by InstSecion.condenseArrayEquation for no apparent use .
/*
// partial evaluation functions
case (DAE.PARTEVALFUNCTION(path = _, expList = elst)) // stefan
equation
Expand Down Expand Up @@ -7110,7 +7115,7 @@ algorithm
true = containFunctioncall(e2);
then
true;

*/
// anything else
case (_) then false;

Expand Down
100 changes: 100 additions & 0 deletions Compiler/FrontEnd/ExpressionSimplify.mo
Original file line number Diff line number Diff line change
Expand Up @@ -5012,5 +5012,105 @@ algorithm
end match;
end condSimplifyAddSymbolicOperation;


public function simplifyMatrixProductOfRecords
"mahge: Simplifies matrix multiplication of record types by using overloaded
multiplication and addition functions."
input DAE.Exp inMatrix1;
input DAE.Exp inMatrix2;
input Absyn.Path mulFunc;
input Absyn.Path sumFunc;
output DAE.Exp outProduct;
protected
DAE.Exp mat1, mat2;
algorithm
mat1 := Expression.matrixToArray(inMatrix1);
mat2 := Expression.matrixToArray(inMatrix2);
// Transpose the second matrix. This makes it easier to do the multiplication,
// since we can do row-row multiplications instead of row-column.
mat2 := Expression.transposeArray(mat2);
outProduct := simplifyMatrixProductOfRecords2(mat1, mat2, mulFunc, sumFunc);
end simplifyMatrixProductOfRecords;


protected function simplifyMatrixProductOfRecords2
" Simplifies the scalar product of two vectors of record types using overloaded
scalar addition and multiplication functions."
input DAE.Exp inMatrix1;
input DAE.Exp inMatrix2;
input Absyn.Path mulFunc;
input Absyn.Path sumFunc;
output DAE.Exp outProduct;
algorithm
outProduct := matchcontinue(inMatrix1, inMatrix2, mulFunc, sumFunc)
local
DAE.Dimension n, m, p;
list<DAE.Exp> expl1, expl2;
DAE.Type ty, row_ty;
DAE.TypeSource tp;
list<list<DAE.Exp>> matrix;

// Matrix-matrix multiplication, c[n, p] = a[n, m] * b[m, p].
case (DAE.ARRAY(ty = DAE.T_ARRAY(ty, {n, m}, tp), array = expl1),
DAE.ARRAY(ty = DAE.T_ARRAY(dims = {p, _}), array = expl2), _, _)
equation
// c[i, j] = a[i, :] * b[:, j] for i in 1:n, j in 1:p
matrix = List.map3(expl1, simplifyMatrixProductOfRecords3, expl2, mulFunc, sumFunc);
row_ty = DAE.T_ARRAY(ty, {p}, tp);
expl1 = List.map2(matrix, Expression.makeArray, row_ty, true);
then
DAE.ARRAY(DAE.T_ARRAY(ty, {n, p}, tp), false, expl1);

end matchcontinue;
end simplifyMatrixProductOfRecords2;

protected function simplifyMatrixProductOfRecords3
"mahge: Simplifies the scalar product of two vectors of record types using overloaded
scalar addition and multiplication functions."
input DAE.Exp inRow;
input list<DAE.Exp> inMatrix;
input Absyn.Path mulFunc;
input Absyn.Path sumFunc;
output list<DAE.Exp> outRow;
algorithm
outRow := List.map3(inMatrix, simplifyScalarProductOfRecords, inRow, mulFunc, sumFunc);
end simplifyMatrixProductOfRecords3;


public function simplifyScalarProductOfRecords
"mahge: Simplifies the scalar product of two vectors of record types using overloaded
scalar addition and multiplication functions."
input DAE.Exp inVector1;
input DAE.Exp inVector2;
input Absyn.Path mulFunc;
input Absyn.Path sumFunc;
output DAE.Exp outProduct;
algorithm
outProduct := match(inVector1, inVector2, mulFunc, sumFunc)
local
list<DAE.Exp> expl, expl1, expl2;
DAE.Exp exp;
Type tp;

case (DAE.ARRAY(array = expl1), DAE.ARRAY(array = expl2), _, _)
equation
expl = List.threadMap1(expl2, expl1, makeDaeCall, mulFunc);
exp = List.reduce1(expl, makeDaeCall, sumFunc);
then
exp;

end match;
end simplifyScalarProductOfRecords;

protected function makeDaeCall
input DAE.Exp inArg1;
input DAE.Exp inArg2;
input Absyn.Path funcPath;
output DAE.Exp outExp;
algorithm
/* mahge: TODO: Fix the type of the call attributes. Type should propagate and reach here from handleMatMultOfRecords*/
outExp := DAE.CALL(funcPath,{inArg1,inArg2},DAE.CALL_ATTR(DAE.T_UNKNOWN_DEFAULT,false,false,false,DAE.NO_INLINE(),DAE.NO_TAIL()));
end makeDaeCall;

end ExpressionSimplify;

130 changes: 129 additions & 1 deletion Compiler/FrontEnd/Static.mo
Original file line number Diff line number Diff line change
Expand Up @@ -13160,7 +13160,7 @@ algorithm
equation
str1 = "\n" +&
"- Failed to deoverload operator '" +& Dump.opSymbol(inOper) +& "' " +&
" for record of type: '" +& Absyn.pathString(Absyn.pathPrefix(inPath));
" for record of type: '" +& Absyn.pathString(Absyn.pathPrefix(inPath)) +& "'";
Error.addSourceMessage(Error.OPERATOR_OVERLOADING_ERROR,
{str1}, inInfo);
then fail();
Expand Down Expand Up @@ -13367,6 +13367,7 @@ algorithm
DAE.Properties prop, props1, props2;
Absyn.Exp absexp1, absexp2;
Boolean lastRound;
DAE.Dimension n,m1,m2,p;

// handle tuple op non_tuple
case (_, _, aboper, props1 as DAE.PROP_TUPLE(type_ = _), exp1, props2 as DAE.PROP(type_ = _), exp2, _, _, _, _, _, _, _)
Expand Down Expand Up @@ -13401,6 +13402,33 @@ algorithm
warnUnsafeRelations(inEnv,AbExp,const, type1,type2,exp1,exp2,oper,inPre);
then
(inCache,exp, prop);

/* We have a matrix multiplication of records. According to Spec. 3.2 Section 14.4 and 10.6.4, this should be handled
the same way as matrix multilication of numeric matrics.
- Not sure what will happen when users want to overload multiplication '*' for matrices of their records with their own algorithm.
Which one should be chosen?
- Also if the user hasn't overloaded either of '+' or '*'(for scalar records) then what should happen? The matrix multiplication needs both to be overloaded.
*/
case (cache, env, Absyn.MUL(), DAE.PROP(type1,const1), exp1, DAE.PROP(type2,const2), exp2, _, absexp1, absexp2, _, _, _, _)
equation
true = typeIsRecord(Types.arrayElementType(type1));
true = typeIsRecord(Types.arrayElementType(type2));
2 = Types.numberOfDimensions(type1);
2 = Types.numberOfDimensions(type2);
n = Types.getDimensionNth(type1, 1);
m1 = Types.getDimensionNth(type1, 2);
m2 = Types.getDimensionNth(type2, 1);
p = Types.getDimensionNth(type2, 2);

true = isValidMatrixProductDims(m1, m2);
otype = Types.arrayElementType(type1);
otype = Types.liftArrayListDims(otype, {n, p});

exp = handleMatMultOfRecords(cache,env,type1,type2,exp1,exp2,inInfo);
const = Types.constAnd(const1, const2);
prop = DAE.PROP(otype,const);
then
(inCache,exp, prop);

// The order of this two cases determines the priority given to operators
// Now left has priority for all.
Expand Down Expand Up @@ -13433,6 +13461,106 @@ algorithm
end matchcontinue;
end operatorDeoverloadBinary;


protected function handleMatMultOfRecords
"handles matrix multiplication of record types. It looks up the scalar versions of overloaded
addition and multiplication operations and uses them to expand and simplify the matrix multiplication."
input Env.Cache inCache;
input Env.Env inEnv;
input DAE.Type inType1;
input DAE.Type inType2;
input DAE.Exp inDAEExp1;
input DAE.Exp inDAEExp2;
input Absyn.Info inInfo;
output DAE.Exp outDAEExp;
algorithm
(outDAEExp) :=
matchcontinue (inCache,inEnv,inType1,inType2,inDAEExp1,inDAEExp2,inInfo)
local
Absyn.Path path, multPath, sumPath;
list<Absyn.Path> operNames;
Env.Env recordEnv,env;
SCode.Element operatorCl;
Env.Cache cache;
list<DAE.Type> types,scalartypes,arraytypes;
DAE.Type type1, type2, funcType;

case (cache, env, type1, type2, _, _, _)
equation
path = getRecordPath(type1);
path = Absyn.makeFullyQualified(path);
(cache,_,recordEnv) = Lookup.lookupClass(cache,env,path, false);

// Get the overloaded scalar multiplication function
multPath = getOverloadedScalarOperator(cache, recordEnv, path, "*", inInfo);
// Get the overloaded scalar addition function
sumPath = getOverloadedScalarOperator(cache, recordEnv, path, "+", inInfo);

outDAEExp = ExpressionSimplify.simplifyMatrixProductOfRecords(inDAEExp1,inDAEExp2,multPath,sumPath);
then
(outDAEExp);
case (_, _, _, _, _, _, _)
equation
// Error.addSourceMessage(Error.INTERNAL_ERROR, {"- Static.handleMatMultOfRecords failed."}, inInfo);
then fail();

end matchcontinue;
end handleMatMultOfRecords;

protected function getOverloadedScalarOperator
"Given the symobl of an operator this function finds its overloaded operator function (scalar version)
and returns the full path to it. Currently used in handleMatMultOfRecords to find '*' and '+' operators
to handle matrix multiplication of records."
input Env.Cache inCache;
input Env.Env inRecordEnv;
input Absyn.Path inRecordPath;
input String opSymbol;
input Absyn.Info inInfo;
output Absyn.Path outMultPath;
algorithm
(outMultPath) :=
matchcontinue (inCache,inRecordEnv,inRecordPath,opSymbol,inInfo)
local
Env.Cache cache;
Absyn.Path path;
list<Absyn.Path> operNames;
Env.Env operatorEnv;
SCode.Element operatorCl;
list<DAE.Type> types,scalartypes;
DAE.Type funcType;
String str1;

case (_, _, _, _, _)
equation
str1 = "'" +& opSymbol +& "'";
path = Absyn.joinPaths(inRecordPath, Absyn.IDENT(str1));

// check if the operator is defined. i.e overloaded
(cache,operatorCl,operatorEnv) = Lookup.lookupClass(inCache,inRecordEnv,path, false);
true = SCode.isOperator(operatorCl);

// get the list of functions in the operator. there can be multiple options
operNames = SCodeUtil.getListofQualOperatorFuncsfromOperator(operatorCl);
(cache,types) = Lookup.lookupFunctionsListInEnv(inCache, operatorEnv, operNames, inInfo, {});

(_, scalartypes) = List.splitOnTrue(types,isFuncWithArrayInput);
funcType::_ = scalartypes;
path = Types.getClassname(funcType);
then
path;

case (_, _, _, _, _)
equation
str1 = "- Failed to find scalar version of overloaded operator '" +& opSymbol +& "'" +&
" for expanding matrix multiplication of record type: '" +& Absyn.pathStringNoQual(inRecordPath) +&
"'. OMC will try to vectorize the multiplication";
Error.addSourceMessage(Error.OPERATOR_OVERLOADING_ERROR,
{str1}, inInfo);
then fail();
end matchcontinue;
end getOverloadedScalarOperator;


protected function operatorDeoverloadUnary
"used to resolve unary operations.

Expand Down

0 comments on commit f9a7bc6

Please sign in to comment.