Skip to content

Commit

Permalink
Handle fill in ExpandExp (#8391)
Browse files Browse the repository at this point in the history
- Expand fill calls in ExpandExp.
- Move the implementation in Ceval.evalBuiltinFills to a more generic
  Expression.fillArgs function, since the functionality is used in
  several places in the compiler.
  • Loading branch information
perost committed Jan 10, 2022
1 parent a499198 commit d469b93
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 31 deletions.
36 changes: 12 additions & 24 deletions OMCompiler/Compiler/NFFrontEnd/NFCeval.mo
Expand Up @@ -2199,30 +2199,18 @@ public
function evalBuiltinFill
input list<Expression> args;
output Expression result;
algorithm
result := evalBuiltinFill2(listHead(args), listRest(args));
end evalBuiltinFill;

function evalBuiltinFill2
input Expression fillValue;
input list<Expression> dims;
output Expression result = fillValue;
protected
Integer dim_size;
list<Expression> arr;
Type arr_ty = Expression.typeOf(result);
Expression fill_exp;
list<Expression> dims;
algorithm
for d in listReverse(dims) loop
() := match d
case Expression.INTEGER(value = dim_size) then ();
else algorithm printWrongArgsError(getInstanceName(), {d}, sourceInfo()); then fail();
end match;

arr := list(result for e in 1:dim_size);
arr_ty := Type.liftArrayLeft(arr_ty, Dimension.fromInteger(dim_size));
result := Expression.makeArray(arr_ty, arr, Expression.isLiteral(fillValue));
end for;
end evalBuiltinFill2;
try
fill_exp :: dims := args;
result := Expression.fillArgs(fill_exp, dims);
else
printWrongArgsError(getInstanceName(), args, sourceInfo());
fail();
end try;
end evalBuiltinFill;

protected
function evalBuiltinFloor
Expand Down Expand Up @@ -2514,7 +2502,7 @@ function evalBuiltinOnes
input list<Expression> args;
output Expression result;
algorithm
result := evalBuiltinFill2(Expression.INTEGER(1), args);
result := evalBuiltinFill(Expression.INTEGER(1) :: args);
end evalBuiltinOnes;

function evalBuiltinProduct
Expand Down Expand Up @@ -2874,7 +2862,7 @@ function evalBuiltinZeros
input list<Expression> args;
output Expression result;
algorithm
result := evalBuiltinFill2(Expression.INTEGER(0), args);
result := evalBuiltinFill(Expression.INTEGER(0) :: args);
end evalBuiltinZeros;

function evalUriToFilename
Expand Down
21 changes: 15 additions & 6 deletions OMCompiler/Compiler/NFFrontEnd/NFExpandExp.mo
Expand Up @@ -280,12 +280,13 @@ public
Absyn.Path fn_path = Function.nameConsiderBuiltin(fn);
algorithm
(outExp, expanded) := match AbsynUtil.pathFirstIdent(fn_path)
case "cat" then expandBuiltinCat(args, call);
case "der" then expandBuiltinGeneric(call);
case "diagonal" then expandBuiltinDiagonal(listHead(args));
case "pre" then expandBuiltinGeneric(call);
case "previous" then expandBuiltinGeneric(call);
case "promote" then expandBuiltinPromote(args);
case "cat" then expandBuiltinCat(args, call);
case "der" then expandBuiltinGeneric(call);
case "diagonal" then expandBuiltinDiagonal(listHead(args));
case "fill" then expandBuiltinFill(args);
case "pre" then expandBuiltinGeneric(call);
case "previous" then expandBuiltinGeneric(call);
case "promote" then expandBuiltinPromote(args);
case "transpose" then expandBuiltinTranspose(listHead(args));
end match;
end expandBuiltinCall;
Expand Down Expand Up @@ -336,6 +337,14 @@ public
end if;
end expandBuiltinDiagonal;

function expandBuiltinFill
input list<Expression> args;
output Expression outExp;
output Boolean expanded = true;
algorithm
outExp := Expression.fillArgs(listHead(args), listRest(args));
end expandBuiltinFill;

function expandBuiltinTranspose
input Expression arg;
output Expression outExp;
Expand Down
21 changes: 21 additions & 0 deletions OMCompiler/Compiler/NFFrontEnd/NFExpression.mo
Expand Up @@ -3826,6 +3826,27 @@ public
end for;
end fillType;

function fillArgs
"Creates an array from the given fill expression and list of dimensions,
similar to fill(fillExp, dims...). Fails if not all dimensions can be
converted to Integer values."
input Expression fillExp;
input list<Expression> dims;
output Expression result = fillExp;
protected
Integer dim_size;
list<Expression> arr;
Type arr_ty = typeOf(result);
Boolean is_literal = isLiteral(fillExp);
algorithm
for d in listReverse(dims) loop
dim_size := toInteger(d);
arr := list(result for e in 1:dim_size);
arr_ty := Type.liftArrayLeft(arr_ty, Dimension.fromInteger(dim_size));
result := Expression.makeArray(arr_ty, arr, is_literal);
end for;
end fillArgs;

function liftArray
"Creates an array with the given dimension, where each element is the given
expression. Example: liftArray([3], 1) => {1, 1, 1}"
Expand Down
2 changes: 1 addition & 1 deletion OMCompiler/Compiler/NFFrontEnd/NFSimplifyExp.mo
Expand Up @@ -355,7 +355,7 @@ function simplifyFill
output Expression exp;
algorithm
if List.all(dimArgs, Expression.isLiteral) then
exp := Ceval.evalBuiltinFill2(fillArg, dimArgs);
exp := Expression.fillArgs(fillArg, dimArgs);
else
exp := Expression.CALL(call);
end if;
Expand Down

0 comments on commit d469b93

Please sign in to comment.