Skip to content

Commit

Permalink
[MINOR] Cleanup code quality inter-procedural analysis / recompiler
Browse files Browse the repository at this point in the history
  • Loading branch information
mboehm7 committed Mar 22, 2024
1 parent af2c896 commit 6fd96bb
Show file tree
Hide file tree
Showing 16 changed files with 122 additions and 112 deletions.
16 changes: 9 additions & 7 deletions src/main/java/org/apache/sysds/hops/Hop.java
Expand Up @@ -23,6 +23,8 @@
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
Expand Down Expand Up @@ -954,7 +956,7 @@ public void addInput( Hop h ) {
h._parent.add(this);
}

public void addAllInputs( ArrayList<Hop> list ) {
public void addAllInputs( List<Hop> list ) {
for( Hop h : list )
addInput(h);
}
Expand Down Expand Up @@ -1130,13 +1132,13 @@ public boolean colsKnown() {
return _dataType.isScalar() || _dc.colsKnown();
}

public static void resetVisitStatus( ArrayList<Hop> hops ) {
public static void resetVisitStatus( List<Hop> hops ) {
if( hops != null )
for( Hop hopRoot : hops )
hopRoot.resetVisitStatus();
}

public static void resetVisitStatus( ArrayList<Hop> hops, boolean force ) {
public static void resetVisitStatus( List<Hop> hops, boolean force ) {
if( !force )
resetVisitStatus(hops);
else {
Expand Down Expand Up @@ -1413,23 +1415,23 @@ public void refreshRowsParameterInformation( Hop input, LocalVariableMap vars )
setDim1(computeSizeInformation(input, vars));
}

public void refreshRowsParameterInformation( Hop input, LocalVariableMap vars, HashMap<Long,Long> memo ) {
public void refreshRowsParameterInformation( Hop input, LocalVariableMap vars, Map<Long,Long> memo ) {
setDim1(computeSizeInformation(input, vars, memo));
}

public void refreshColsParameterInformation( Hop input, LocalVariableMap vars ) {
setDim2(computeSizeInformation(input, vars));
}

public void refreshColsParameterInformation( Hop input, LocalVariableMap vars, HashMap<Long,Long> memo ) {
public void refreshColsParameterInformation( Hop input, LocalVariableMap vars, Map<Long,Long> memo ) {
setDim2(computeSizeInformation(input, vars, memo));
}

public long computeSizeInformation( Hop input, LocalVariableMap vars ) {
return computeSizeInformation(input, vars, new HashMap<Long,Long>());
}

public long computeSizeInformation( Hop input, LocalVariableMap vars, HashMap<Long,Long> memo )
public long computeSizeInformation( Hop input, LocalVariableMap vars, Map<Long,Long> memo )
{
long ret = -1;
try {
Expand Down Expand Up @@ -1460,7 +1462,7 @@ public static double computeBoundsInformation( Hop input, LocalVariableMap vars
return computeBoundsInformation(input, vars, new HashMap<Long, Double>());
}

public static double computeBoundsInformation( Hop input, LocalVariableMap vars, HashMap<Long, Double> memo ) {
public static double computeBoundsInformation( Hop input, LocalVariableMap vars, Map<Long, Double> memo ) {
double ret = Double.MAX_VALUE;
try {
ret = OptimizerUtils.rEvalSimpleDoubleExpression(input, memo, vars);
Expand Down
18 changes: 9 additions & 9 deletions src/main/java/org/apache/sysds/hops/OptimizerUtils.java
Expand Up @@ -1457,7 +1457,7 @@ public static long getNumIterations(ForProgramBlock fpb, LocalVariableMap vars,
* @param valMemo ?
* @return size expression
*/
public static long rEvalSimpleLongExpression( Hop root, HashMap<Long, Long> valMemo )
public static long rEvalSimpleLongExpression( Hop root, Map<Long, Long> valMemo )
{
long ret = Long.MAX_VALUE;

Expand All @@ -1470,7 +1470,7 @@ public static long rEvalSimpleLongExpression( Hop root, HashMap<Long, Long> valM
return ret;
}

public static long rEvalSimpleLongExpression( Hop root, HashMap<Long, Long> valMemo, LocalVariableMap vars )
public static long rEvalSimpleLongExpression( Hop root, Map<Long, Long> valMemo, LocalVariableMap vars )
{
long ret = Long.MAX_VALUE;

Expand All @@ -1483,7 +1483,7 @@ public static long rEvalSimpleLongExpression( Hop root, HashMap<Long, Long> valM
return ret;
}

public static double rEvalSimpleDoubleExpression( Hop root, HashMap<Long, Double> valMemo )
public static double rEvalSimpleDoubleExpression( Hop root, Map<Long, Double> valMemo )
{
//memoization (prevent redundant computation of common subexpr)
if( valMemo.containsKey(root.getHopID()) )
Expand All @@ -1510,7 +1510,7 @@ else if( root instanceof TernaryOp )
return ret;
}

public static double rEvalSimpleDoubleExpression( Hop root, HashMap<Long, Double> valMemo, LocalVariableMap vars )
public static double rEvalSimpleDoubleExpression( Hop root, Map<Long, Double> valMemo, LocalVariableMap vars )
{
//memoization (prevent redundant computation of common subexpr)
if( valMemo.containsKey(root.getHopID()) )
Expand Down Expand Up @@ -1538,7 +1538,7 @@ else if( root instanceof DataOp ) {
return ret;
}

protected static double rEvalSimpleUnaryDoubleExpression( Hop root, HashMap<Long, Double> valMemo )
protected static double rEvalSimpleUnaryDoubleExpression( Hop root, Map<Long, Double> valMemo )
{
//memoization (prevent redundant computation of common subexpr)
if( valMemo.containsKey(root.getHopID()) )
Expand Down Expand Up @@ -1576,7 +1576,7 @@ else if( uroot.getOp() == OpOp1.NCOL )
return ret;
}

protected static double rEvalSimpleUnaryDoubleExpression( Hop root, HashMap<Long, Double> valMemo, LocalVariableMap vars )
protected static double rEvalSimpleUnaryDoubleExpression( Hop root, Map<Long, Double> valMemo, LocalVariableMap vars )
{
//memoization (prevent redundant computation of common subexpr)
if( valMemo.containsKey(root.getHopID()) )
Expand Down Expand Up @@ -1614,7 +1614,7 @@ else if( uroot.getOp() == OpOp1.NCOL )
return ret;
}

protected static double rEvalSimpleBinaryDoubleExpression( Hop root, HashMap<Long, Double> valMemo )
protected static double rEvalSimpleBinaryDoubleExpression( Hop root, Map<Long, Double> valMemo )
{
//memoization (prevent redundant computation of common subexpr)
if( valMemo.containsKey(root.getHopID()) )
Expand Down Expand Up @@ -1649,7 +1649,7 @@ protected static double rEvalSimpleBinaryDoubleExpression( Hop root, HashMap<Lon
return ret;
}

protected static double rEvalSimpleTernaryDoubleExpression( Hop root, HashMap<Long, Double> valMemo ) {
protected static double rEvalSimpleTernaryDoubleExpression( Hop root, Map<Long, Double> valMemo ) {
//memoization (prevent redundant computation of common subexpr)
if( valMemo.containsKey(root.getHopID()) )
return valMemo.get(root.getHopID());
Expand All @@ -1666,7 +1666,7 @@ else if( HopRewriteUtils.isLiteralOfValue(troot.getInput(0), false) )
return ret;
}

protected static double rEvalSimpleBinaryDoubleExpression( Hop root, HashMap<Long, Double> valMemo, LocalVariableMap vars )
protected static double rEvalSimpleBinaryDoubleExpression( Hop root, Map<Long, Double> valMemo, LocalVariableMap vars )
{
//memoization (prevent redundant computation of common subexpr)
if( valMemo.containsKey(root.getHopID()) )
Expand Down
4 changes: 2 additions & 2 deletions src/main/java/org/apache/sysds/hops/ReorgOp.java
Expand Up @@ -30,7 +30,7 @@
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;

import java.util.ArrayList;
import java.util.List;

/**
* Reorg (cell) operation: aij
Expand Down Expand Up @@ -66,7 +66,7 @@ public ReorgOp(String l, DataType dt, ValueType vt, ReOrgOp o, Hop inp)
refreshSizeInformation();
}

public ReorgOp(String l, DataType dt, ValueType vt, ReOrgOp o, ArrayList<Hop> inp)
public ReorgOp(String l, DataType dt, ValueType vt, ReOrgOp o, List<Hop> inp)
{
super(l, dt, vt);
_op = o;
Expand Down
33 changes: 17 additions & 16 deletions src/main/java/org/apache/sysds/hops/ipa/FunctionCallGraph.java
Expand Up @@ -24,6 +24,7 @@
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.Stack;
Expand Down Expand Up @@ -58,18 +59,18 @@ public class FunctionCallGraph

//unrolled function call graph, in call direction
//(mapping from function keys to called function keys)
private final HashMap<String, HashSet<String>> _fGraph;
private final Map<String, Set<String>> _fGraph;

//program-wide function call operators per target function
//(mapping from function keys to set of its function calls)
private final HashMap<String, ArrayList<FunctionOp>> _fCalls;
private final HashMap<String, ArrayList<StatementBlock>> _fCallsSB;
private final Map<String, List<FunctionOp>> _fCalls;
private final Map<String, List<StatementBlock>> _fCallsSB;

//subset of direct or indirect recursive functions
private final HashSet<String> _fRecursive;
private final Set<String> _fRecursive;

//subset of side-effect-free functions
private final HashSet<String> _fSideEffectFree;
private final Set<String> _fSideEffectFree;

// a boolean value to indicate if exists the second order function (e.g. eval, paramserv)
// and the UDFs that are marked secondorder="true"
Expand Down Expand Up @@ -168,7 +169,7 @@ public void removeFunctionCalls(String fkey) {
_fCallsSB.remove(fkey);
_fRecursive.remove(fkey);
_fGraph.remove(fkey);
for( Entry<String, HashSet<String>> e : _fGraph.entrySet() )
for( Entry<String, Set<String>> e : _fGraph.entrySet() )
e.getValue().removeIf(s -> s.equals(fkey));
}

Expand All @@ -195,8 +196,8 @@ public void removeFunctionCall(String fkey, FunctionOp fop, StatementBlock sb) {
* @param fkey new function key of called function
*/
public void replaceFunctionCalls(String fkeyOld, String fkey) {
ArrayList<FunctionOp> fopTmp = _fCalls.get(fkeyOld);
ArrayList<StatementBlock> sbTmp =_fCallsSB.get(fkeyOld);
List<FunctionOp> fopTmp = _fCalls.get(fkeyOld);
List<StatementBlock> sbTmp =_fCallsSB.get(fkeyOld);
_fCalls.remove(fkeyOld);
_fCallsSB.remove(fkeyOld);
_fCalls.put(fkey, fopTmp);
Expand All @@ -205,7 +206,7 @@ public void replaceFunctionCalls(String fkeyOld, String fkey) {
_fRecursive.remove(fkeyOld);
_fSideEffectFree.remove(fkeyOld);
_fGraph.remove(fkeyOld);
for( HashSet<String> hs : _fGraph.values() )
for( Set<String> hs : _fGraph.values() )
hs.remove(fkeyOld);
}

Expand Down Expand Up @@ -350,7 +351,7 @@ private boolean constructFunctionCallGraph(DMLProgram prog) {
try {
//construct the main function call graph
Stack<String> fstack = new Stack<>();
HashSet<String> lfset = new HashSet<>();
Set<String> lfset = new HashSet<>();
_fGraph.put(MAIN_FUNCTION_KEY, new HashSet<String>());
for( StatementBlock sblk : prog.getStatementBlocks() )
ret |= rConstructFunctionCallGraph(MAIN_FUNCTION_KEY, sblk, fstack, lfset);
Expand All @@ -373,7 +374,7 @@ private boolean constructFunctionCallGraph(StatementBlock sb) {

try {
Stack<String> fstack = new Stack<>();
HashSet<String> lfset = new HashSet<>();
Set<String> lfset = new HashSet<>();
_fGraph.put(MAIN_FUNCTION_KEY, new HashSet<String>());
return rConstructFunctionCallGraph(MAIN_FUNCTION_KEY, sb, fstack, lfset);
}
Expand All @@ -382,7 +383,7 @@ private boolean constructFunctionCallGraph(StatementBlock sb) {
}
}

private boolean rConstructFunctionCallGraph(String fkey, StatementBlock sb, Stack<String> fstack, HashSet<String> lfset) {
private boolean rConstructFunctionCallGraph(String fkey, StatementBlock sb, Stack<String> fstack, Set<String> lfset) {
boolean ret = false;
if (sb instanceof WhileStatementBlock) {
WhileStatement ws = (WhileStatement)sb.getStatement(0);
Expand All @@ -408,7 +409,7 @@ else if (sb instanceof FunctionStatementBlock) {
}
else {
// For generic StatementBlock
ArrayList<Hop> hopsDAG = sb.getHops();
List<Hop> hopsDAG = sb.getHops();
if( hopsDAG == null || hopsDAG.isEmpty() )
return false; //nothing to do

Expand All @@ -428,7 +429,7 @@ else if (sb instanceof FunctionStatementBlock) {
return ret;
}

private boolean rConstructFunctionCallGraph(Hop hop, String fkey, StatementBlock sb, Stack<String> fstack, HashSet<String> lfset) {
private boolean rConstructFunctionCallGraph(Hop hop, String fkey, StatementBlock sb, Stack<String> fstack, Set<String> lfset) {
boolean ret = false;
if( hop.isVisited() )
return ret;
Expand All @@ -452,7 +453,7 @@ private boolean rConstructFunctionCallGraph(Hop hop, String fkey, StatementBlock
return ret;
}

private boolean addFunctionOpToGraph(FunctionOp fop, String fkey, StatementBlock sb, Stack<String> fstack, HashSet<String> lfset) {
private boolean addFunctionOpToGraph(FunctionOp fop, String fkey, StatementBlock sb, Stack<String> fstack, Set<String> lfset) {
try{
boolean ret = false;
String lfkey = fop.getFunctionKey();
Expand Down Expand Up @@ -523,7 +524,7 @@ else if (sb instanceof ForStatementBlock) {
}
else {
// For generic StatementBlock
ArrayList<Hop> hopsDAG = sb.getHops();
List<Hop> hopsDAG = sb.getHops();
if( hopsDAG == null || hopsDAG.isEmpty() )
return false; //nothing to do
//function ops can only occur as root nodes of the dag
Expand Down
Expand Up @@ -252,7 +252,7 @@ else if( InterProceduralAnalysis.ALLOW_MULTIPLE_FUNCTION_CALLS ) {
if( flist == null || flist.isEmpty() ) //robustness removed functions
continue;
FunctionOp first = flist.get(0);
HashSet<Integer> tmp = new HashSet<>();
Set<Integer> tmp = new HashSet<>();
for( int j=0; j<first.getInput().size(); j++ ) {
//if nnz known it is safe to propagate those nnz because for multiple calls
//we checked of equivalence and hence all calls have the same nnz
Expand All @@ -271,7 +271,7 @@ else if( InterProceduralAnalysis.ALLOW_MULTIPLE_FUNCTION_CALLS ) {
continue;
FunctionOp first = flist.get(0);
//initialize w/ all literals of first call
HashSet<Integer> tmp = new HashSet<>();
Set<Integer> tmp = new HashSet<>();
for( int j=0; j<first.getInput().size(); j++ )
if( first.getInput().get(j) instanceof LiteralOp )
tmp.add(j);
Expand Down
Expand Up @@ -19,9 +19,10 @@

package org.apache.sysds.hops.ipa;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

import org.apache.sysds.api.DMLScript;
import org.apache.sysds.hops.FunctionOp;
Expand Down Expand Up @@ -56,7 +57,7 @@ public boolean rewriteProgram (DMLProgram prog, FunctionCallGraph fgraph, Functi

try {
// Find the individual functions and statementblocks with non-determinism.
HashSet<String> ndfncs = new HashSet<>();
Set<String> ndfncs = new HashSet<>();
for (String fkey : fgraph.getReachableFunctions()) {
FunctionStatementBlock fsblock = prog.getFunctionStatementBlock(fkey);
FunctionStatement fnstmt = (FunctionStatement)fsblock.getStatement(0);
Expand Down Expand Up @@ -88,7 +89,7 @@ public boolean rewriteProgram (DMLProgram prog, FunctionCallGraph fgraph, Functi
return false;
}

private boolean rIsNonDeterministicFnc (String fname, ArrayList<StatementBlock> sbs)
private boolean rIsNonDeterministicFnc (String fname, List<StatementBlock> sbs)
{
boolean isND = false;
for (StatementBlock sb : sbs)
Expand Down Expand Up @@ -124,7 +125,7 @@ else if (sb instanceof IfStatementBlock) {
return isND;
}

private void rMarkNondeterministicSBs (ArrayList<StatementBlock> sbs, HashSet<String> ndfncs)
private void rMarkNondeterministicSBs (List<StatementBlock> sbs, Set<String> ndfncs)
{
for (StatementBlock sb : sbs)
{
Expand Down Expand Up @@ -156,7 +157,7 @@ else if (sb instanceof IfStatementBlock) {
}
}

private boolean rMarkNondeterministicHop(Hop hop, HashSet<String> ndfncs) {
private boolean rMarkNondeterministicHop(Hop hop, Set<String> ndfncs) {
if (hop.isVisited())
return false;

Expand All @@ -182,7 +183,7 @@ private boolean rIsNonDeterministicHop(Hop hop) {
return isND;
}

private void propagate2Callers (FunctionCallGraph fgraph, HashSet<String> ndfncs, HashSet<String> fstack, String fkey) {
private void propagate2Callers (FunctionCallGraph fgraph, Set<String> ndfncs, Set<String> fstack, String fkey) {
Collection<String> cfkeys = fgraph.getCalledFunctions(fkey);
if (cfkeys != null) {
for (String cfkey : cfkeys) {
Expand Down

0 comments on commit 6fd96bb

Please sign in to comment.