diff --git a/scripts/builtin/raJoin.dml b/scripts/builtin/raJoin.dml index 7fa7572a364..5d3335277d0 100644 --- a/scripts/builtin/raJoin.dml +++ b/scripts/builtin/raJoin.dml @@ -27,7 +27,7 @@ # A Matrix of left input data [shape: N x M] # colA Integer indicating the column index of matrix A to execute inner join command # B Matrix of right left data [shape: N x M] -# colA Integer indicating the column index of matrix B to execute inner join command +# colB Integer indicating the column index of matrix B to execute inner join command # method Join implementation method (nested-loop, sort-merge, hash, hash2) # ------------------------------------------------------------------------------ # diff --git a/src/main/java/org/apache/sysds/hops/OptimizerUtils.java b/src/main/java/org/apache/sysds/hops/OptimizerUtils.java index 9ba3ea3ed77..bae27b65837 100644 --- a/src/main/java/org/apache/sysds/hops/OptimizerUtils.java +++ b/src/main/java/org/apache/sysds/hops/OptimizerUtils.java @@ -302,6 +302,7 @@ public enum MemoryManager { */ public static boolean ALLOW_TRANSITIVE_SPARK_EXEC_TYPE = true; + public static boolean ALLOW_JOIN_REORDERING_REWRITE = true; /** * Enable prefetch and broadcast. Prefetch asynchronously calls acquireReadAndRelease() to trigger remote * operations, which would otherwise make the next instruction wait till completion. 
Broadcast allows diff --git a/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java b/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java index c2602dba510..c526efdccd1 100644 --- a/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java +++ b/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java @@ -78,6 +78,7 @@ public ProgramRewriter(boolean staticRewrites, boolean dynamicRewrites) _dagRuleSet.add( new RewriteRemoveReadAfterWrite() ); //dependency: before blocksize _dagRuleSet.add( new RewriteBlockSizeAndReblock() ); _dagRuleSet.add( new RewriteInjectOOCTee() ); + if( OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION ) _dagRuleSet.add( new RewriteRemoveUnnecessaryCasts() ); if( OptimizerUtils.ALLOW_COMMON_SUBEXPRESSION_ELIMINATION ) @@ -93,7 +94,6 @@ public ProgramRewriter(boolean staticRewrites, boolean dynamicRewrites) _dagRuleSet.add( new RewriteInjectSparkPReadCheckpointing() ); //dependency: reblock if( OptimizerUtils.ALLOW_QUANTIZE_COMPRESS_REWRITE ) _dagRuleSet.add( new RewriteQuantizationFusedCompression() ); - //add statement block rewrite rules if( OptimizerUtils.ALLOW_BRANCH_REMOVAL ) _sbRuleSet.add( new RewriteRemoveUnnecessaryBranches() ); //dependency: constant folding @@ -119,6 +119,9 @@ public ProgramRewriter(boolean staticRewrites, boolean dynamicRewrites) _sbRuleSet.add( new MarkForLineageReuse() ); _sbRuleSet.add( new RewriteRemoveTransformEncodeMeta() ); _dagRuleSet.add( new RewriteNonScalarPrint() ); + if( OptimizerUtils.ALLOW_JOIN_REORDERING_REWRITE ) + _sbRuleSet.add( new RewriteJoinReordering() ); + } // DYNAMIC REWRITES (which do require size information) diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteJoinReordering.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteJoinReordering.java new file mode 100644 index 00000000000..e35cacc201b --- /dev/null +++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteJoinReordering.java @@ -0,0 +1,604 @@ +package 
org.apache.sysds.hops.rewrite;
+
+import java.util.ArrayList;
+import java.util.BitSet;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+
+import org.apache.sysds.common.Types.DataType;
+import org.apache.sysds.common.Types.ValueType;
+import org.apache.sysds.hops.DataOp;
+import org.apache.sysds.hops.FunctionOp;
+import org.apache.sysds.hops.Hop;
+import org.apache.sysds.hops.LiteralOp;
+import org.apache.sysds.parser.StatementBlock;
+import org.apache.sysds.parser.VariableSet;
+import org.apache.sysds.parser.WhileStatement;
+import org.apache.sysds.parser.WhileStatementBlock;
+import org.apache.sysds.parser.DataIdentifier;
+import org.apache.sysds.parser.ForStatement;
+import org.apache.sysds.parser.ForStatementBlock;
+import org.apache.sysds.parser.FunctionStatement;
+import org.apache.sysds.parser.FunctionStatementBlock;
+import org.apache.sysds.parser.IfStatement;
+import org.apache.sysds.parser.IfStatementBlock;
+
+/**
+ * Statement-block rewrite that reorders chains of nested {@code raJoin} calls.
+ *
+ * All {@code m_raJoin} function calls with literal join columns are collected,
+ * flattened into joins between "base" matrices (inputs with known dimensions),
+ * and re-assembled by a dynamic program over contiguous intervals of bases
+ * that minimizes the summed size (rows x columns) of all intermediates, using
+ * cross-product row counts as the size estimate.
+ *
+ * NOTE(review): a consumed inner join is removed from its statement block
+ * without checking whether its output variable is read by anything other than
+ * the enclosing raJoin - confirm that this cannot happen for hoisted calls.
+ */
+public class RewriteJoinReordering extends StatementBlockRewriteRule {
+	private static final String RAJOIN_NAMESPACE = ".builtinNS";
+	private static final String RAJOIN_NAME = "m_raJoin";
+	// number of inputs (A, colA, B, colB, method) read from an raJoin call
+	private static final int RAJOIN_NUM_INPUTS = 5;
+
+	/** Thrown when the base dependencies of a join cannot be determined. */
+	private static final class UnknownCanonicalJoinException extends RuntimeException {
+		private static final long serialVersionUID = 1L;
+	}
+
+	/** Thrown when the dimension information of a non-raJoin input is unknown. */
+	private static final class UnknownDimensionInfoException extends RuntimeException {
+		private static final long serialVersionUID = 1L;
+	}
+
+	// Custom representation of nested join calls.
+	sealed interface JoinNode permits BaseNode, BinaryNode {
+	}
+
+	/** Leaf of a join tree: the i-th base matrix. */
+	private record BaseNode(int i) implements JoinNode {
+	}
+
+	/** Inner node of a join tree; columns are relative to the respective child. */
+	private record BinaryNode(JoinNode left, long leftCol, JoinNode right, long rightCol, String method)
+		implements JoinNode {
+	}
+
+	/** Best known plan for a set of bases, with its estimated output size. */
+	private record Cost(long dim1, long dim2, long cost, JoinNode node) {
+	}
+
+	/** Inclusive range of base indices. */
+	private record IntPair(int left, int right) {
+	}
+
+	/** A join predicate expressed on base matrices with base-local columns. */
+	private record CanonicalJoin(int aBaseIndex, long acol, int bBaseIndex, long bcol, String method) {
+	}
+
+	private boolean isRaJoin(Hop node) {
+		return node instanceof FunctionOp fnode
+			&& RAJOIN_NAMESPACE.equals(fnode.getFunctionNamespace())
+			&& RAJOIN_NAME.equals(fnode.getFunctionName());
+	}
+
+	private boolean isLiteralInt(Hop node) {
+		return node instanceof LiteralOp && node.getValueType() == ValueType.INT64;
+	}
+
+	private boolean isKnownMatrix(Hop hop) {
+		return hop.getDim1() > 0 && hop.getDim2() > 0;
+	}
+
+	@Override
+	public boolean createsSplitDag() {
+		return false;
+	}
+
+	/**
+	 * Recursively collect all raJoin calls below a statement block.
+	 *
+	 * @param hopToSb maps every collected join to the statement block holding it
+	 * @param sb      current statement block to search from
+	 * @param joinMap a mapping from the bound output variable name to the index of
+	 *                the join in the `joins` list.
+	 * @param joins   a list to accumulate all found raJoins
+	 */
+	private void collectRaJoin(HashMap<Hop, StatementBlock> hopToSb, StatementBlock sb,
+		HashMap<String, Integer> joinMap, ArrayList<FunctionOp> joins) {
+		if (sb instanceof FunctionStatementBlock fsb) {
+			FunctionStatement fstmt = (FunctionStatement) fsb.getStatement(0);
+			for (StatementBlock sbi : fstmt.getBody())
+				collectRaJoin(hopToSb, sbi, joinMap, joins);
+		} else if (sb instanceof WhileStatementBlock wsb) {
+			WhileStatement wstmt = (WhileStatement) wsb.getStatement(0);
+			for (StatementBlock sbi : wstmt.getBody())
+				collectRaJoin(hopToSb, sbi, joinMap, joins);
+		} else if (sb instanceof IfStatementBlock isb) {
+			IfStatement istmt = (IfStatement) isb.getStatement(0);
+			for (StatementBlock sbi : istmt.getIfBody())
+				collectRaJoin(hopToSb, sbi, joinMap, joins);
+			for (StatementBlock sbi : istmt.getElseBody())
+				collectRaJoin(hopToSb, sbi, joinMap, joins);
+		} else if (sb instanceof ForStatementBlock forSb) { // incl parfor
+			ForStatement fstmt = (ForStatement) forSb.getStatement(0);
+			for (StatementBlock sbi : fstmt.getBody())
+				collectRaJoin(hopToSb, sbi, joinMap, joins);
+		} else { // generic (last-level)
+			// a last-level block without a HOP DAG has nothing to collect
+			if (sb.getHops() == null)
+				return;
+			for (Hop hop : sb.getHops()) {
+				if (isRaJoin(hop))
+					processRaJoin(sb, hopToSb, (FunctionOp) hop, joinMap, joins);
+			}
+		}
+	}
+
+	/**
+	 * Register an raJoin HOP in the intermediate bookkeeping structures. Joins
+	 * whose columns or method are not literals are ignored, because the later
+	 * passes read these inputs as {@link LiteralOp}s.
+	 *
+	 * @param sb      the statement block that holds the join
+	 * @param hopToSb maps every collected join to the statement block holding it
+	 * @param fhop    the raJoin Hop
+	 * @param joinMap a mapping from the bound output variable name to the index of
+	 *                the join in the `joins` list.
+	 * @param joins   a list to accumulate all found raJoins
+	 */
+	private void processRaJoin(StatementBlock sb, HashMap<Hop, StatementBlock> hopToSb, FunctionOp fhop,
+		HashMap<String, Integer> joinMap, ArrayList<FunctionOp> joins) {
+		ArrayList<Hop> in = fhop.getInput();
+		if (in.size() < RAJOIN_NUM_INPUTS)
+			return;
+		// only support literal values.
+		if (!isLiteralInt(in.get(1)) || !isLiteralInt(in.get(3)) || !(in.get(4) instanceof LiteralOp))
+			return;
+
+		for (String varName : fhop.getOutputVariableNames())
+			joinMap.put(varName, joins.size());
+		joins.add(fhop);
+		hopToSb.put(fhop, sb);
+	}
+
+	/**
+	 * Find the topological order of all joins (consumers before producers).
+	 *
+	 * @param joinMap output variable name to join index
+	 * @param joins   all collected raJoins
+	 * @return the topological order of joins as indices of `joins`
+	 */
+	private ArrayList<Integer> topoOrder(HashMap<String, Integer> joinMap, ArrayList<FunctionOp> joins) {
+		ArrayList<Integer> topoOrder = new ArrayList<>();
+		boolean[] visited = new boolean[joins.size()];
+		for (int i = 0; i < joins.size(); i++)
+			dfsOrder(joinMap, joins, topoOrder, visited, i);
+		Collections.reverse(topoOrder);
+		return topoOrder;
+	}
+
+	/**
+	 * DFS call to find the topological order (post-order append).
+	 *
+	 * @param joinMap output variable name to join index
+	 * @param joins   all collected raJoins
+	 * @param order   accumulated post-order
+	 * @param visited visited flags per join index
+	 * @param i       the current join index we are at
+	 */
+	private void dfsOrder(HashMap<String, Integer> joinMap, ArrayList<FunctionOp> joins,
+		ArrayList<Integer> order, boolean[] visited, int i) {
+		visited[i] = true;
+		FunctionOp join = joins.get(i);
+		// inputs 0 and 2 are the left and right matrix of the join
+		for (int pos : new int[] { 0, 2 }) {
+			Hop in = join.getInput(pos);
+			// recurse if the matrix is not a base matrix.
+			if (isKnownMatrix(in))
+				continue;
+			Integer next = joinMap.get(in.getName());
+			if (next == null)
+				throw new UnknownCanonicalJoinException();
+			if (!visited[next])
+				dfsOrder(joinMap, joins, order, visited, next);
+		}
+		order.add(i);
+	}
+
+	/**
+	 * Rewrite all roots and afterwards drop every join that was folded into a
+	 * rewritten root.
+	 *
+	 * @param sbs     statement blocks of the current level (new blocks are inserted here)
+	 * @param hopToSb maps every collected join to the statement block holding it
+	 * @param joinMap output variable name to join index
+	 * @param joins   all raJoins
+	 * @param order   topological order of joins
+	 */
+	private void rewriteRoots(ArrayList<StatementBlock> sbs, HashMap<Hop, StatementBlock> hopToSb,
+		HashMap<String, Integer> joinMap, ArrayList<FunctionOp> joins, ArrayList<Integer> order) {
+		boolean[] visited = new boolean[joins.size()];
+		for (int i : order) {
+			if (visited[i])
+				continue;
+			// Work on a copy so that a root which cannot be rewritten does not
+			// mark its (still needed) sub-joins as consumed.
+			boolean[] attempt = visited.clone();
+			try {
+				rewriteRoot(sbs, hopToSb, joinMap, joins, attempt, i);
+				visited = attempt;
+			} catch (UnknownCanonicalJoinException | UnknownDimensionInfoException e) {
+				// local failure: leave this root untouched and try the next one.
+			}
+		}
+
+		ArrayList<FunctionOp> consumedJoins = new ArrayList<>();
+		for (int i = 0; i < joins.size(); i++) {
+			if (visited[i])
+				consumedJoins.add(joins.get(i));
+		}
+		for (FunctionOp join : consumedJoins)
+			HopRewriteUtils.cleanupUnreferenced(join);
+		for (FunctionOp join : consumedJoins)
+			hopToSb.get(join).getHops().remove(join);
+	}
+
+	/** Create a bit set with the bits [from, to) set. */
+	private static BitSet interval(int from, int to) {
+		BitSet bs = new BitSet();
+		bs.set(from, to);
+		return bs;
+	}
+
+	/** Record `candidate` for `key` unless a plan that is at least as cheap exists. */
+	private static void offer(HashMap<BitSet, Cost> dp, BitSet key, Cost candidate) {
+		Cost current = dp.get(key);
+		if (current == null || candidate.cost() < current.cost())
+			dp.put(key, candidate);
+	}
+
+	/**
+	 * Rewrite a single root join.
+	 *
+	 * All checks that can fail are done before the first HOP is created or
+	 * rewired, so a thrown local exception leaves the program unchanged.
+	 *
+	 * @param visited scratch flags; every join folded into this root is marked
+	 * @throws UnknownCanonicalJoinException if no plan over all bases exists or
+	 *         the root's statement block is not an element of `sbs`
+	 */
+	private void rewriteRoot(ArrayList<StatementBlock> sbs, HashMap<Hop, StatementBlock> hopToSb,
+		HashMap<String, Integer> joinMap, ArrayList<FunctionOp> joins, boolean[] visited, int rootIndex) {
+		// get bases traversal = base relations (matrices)
+		FunctionOp root = joins.get(rootIndex);
+		ArrayList<Hop> bases = new ArrayList<>();
+		ArrayList<Long> basesLengthPrefixSum = new ArrayList<>();
+		ArrayList<CanonicalJoin> canonicalJoins = new ArrayList<>();
+		dfsInorder(joinMap, joins, canonicalJoins, visited, bases, basesLengthPrefixSum, rootIndex);
+
+		// The TWrite block is inserted in front of the root's block, which is
+		// only possible if that block lives directly in `sbs` (not nested).
+		StatementBlock rootSb = hopToSb.get(root);
+		int rootPos = sbs.indexOf(rootSb);
+		if (rootPos < 0)
+			throw new UnknownCanonicalJoinException();
+
+		int n = bases.size();
+		HashMap<BitSet, Cost> dp = new HashMap<>();
+
+		// seed: joins between two adjacent base relations.
+		for (int i = 0; i + 1 < n; i++) {
+			CanonicalJoin join = getValidJoin(canonicalJoins, interval(i, i + 1), interval(i + 1, i + 2));
+			if (join == null)
+				continue;
+			Hop left = bases.get(i);
+			Hop right = bases.get(i + 1);
+			long dim1 = left.getDim1() * right.getDim1();
+			long dim2 = left.getDim2() + right.getDim2();
+			JoinNode node = new BinaryNode(new BaseNode(i), join.acol(), new BaseNode(i + 1), join.bcol(),
+				join.method());
+			offer(dp, interval(i, i + 2), new Cost(dim1, dim2, dim1 * dim2, node));
+		}
+
+		// extend an interval of `len` bases by one base on either side.
+		for (int len = 2; len < n; len++) {
+			// join base relation from the left
+			for (int start = 1; start + len <= n; start++) {
+				BitSet leftBs = interval(start - 1, start);
+				BitSet rightBs = interval(start, start + len);
+				Cost right = dp.get(rightBs);
+				if (right == null)
+					continue;
+				CanonicalJoin join = getValidJoin(canonicalJoins, leftBs, rightBs);
+				if (join == null)
+					continue;
+				Hop left = bases.get(start - 1);
+				long dim1 = left.getDim1() * right.dim1();
+				long dim2 = left.getDim2() + right.dim2();
+				long cost = dim1 * dim2 + right.cost();
+				long rightCol = getRelativeCol(basesLengthPrefixSum, start, join.bBaseIndex(), join.bcol());
+				JoinNode node = new BinaryNode(new BaseNode(start - 1), join.acol(), right.node(), rightCol,
+					join.method());
+				offer(dp, interval(start - 1, start + len), new Cost(dim1, dim2, cost, node));
+			}
+			// join base relation from the right
+			for (int start = 0; start + len < n; start++) {
+				BitSet leftBs = interval(start, start + len);
+				BitSet rightBs = interval(start + len, start + len + 1);
+				Cost left = dp.get(leftBs);
+				if (left == null)
+					continue;
+				CanonicalJoin join = getValidJoin(canonicalJoins, leftBs, rightBs);
+				if (join == null)
+					continue;
+				Hop right = bases.get(start + len);
+				long dim1 = left.dim1() * right.getDim1();
+				long dim2 = left.dim2() + right.getDim2();
+				long cost = dim1 * dim2 + left.cost();
+				long leftCol = getRelativeCol(basesLengthPrefixSum, start, join.aBaseIndex(), join.acol());
+				JoinNode node = new BinaryNode(left.node(), leftCol, new BaseNode(start + len), join.bcol(),
+					join.method());
+				offer(dp, interval(start, start + len + 1), new Cost(dim1, dim2, cost, node));
+			}
+		}
+
+		// no plan covers all bases: skip this root instead of failing with an NPE.
+		Cost best = dp.get(interval(0, n));
+		if (best == null)
+			throw new UnknownCanonicalJoinException();
+
+		// rewire the nodes.
+		ArrayList<Hop> rootSbHops = rootSb.getHops();
+		ArrayList<DataOp> intermediateWrites = new ArrayList<>();
+		Hop newHop = generateHop(root, intermediateWrites, bases, best.node());
+
+		// remove and replace root
+		rootSbHops.replaceAll(h -> h == root ? newHop : h);
+		HopRewriteUtils.rewireAllParentChildReferences(root, newHop);
+
+		// remove all consumed joins that now aren't used
+		HashSet<Hop> consumed = new HashSet<>();
+		for (int j = 0; j < joins.size(); j++) {
+			if (visited[j])
+				consumed.add(joins.get(j));
+		}
+		rootSbHops.removeIf(consumed::contains);
+
+		// add new Sb containing TWrites to right before it is consumed
+		if (!intermediateWrites.isEmpty())
+			sbs.add(rootPos, createIntermediateStatementBlock(rootSb, intermediateWrites));
+	}
+
+	/** Total number of columns of all bases before `index` (0 for the first base). */
+	private static long columnsBefore(ArrayList<Long> prefixSum, int index) {
+		return index > 0 ? prefixSum.get(index - 1) : 0;
+	}
+
+	/**
+	 * Translate a base-local column into a column of the relation formed by
+	 * the bases starting at `intervalStart`.
+	 *
+	 * @param prefixSum     running sum of the column counts of all bases
+	 * @param intervalStart index of the first base of the relation
+	 * @param baseIndex     index of the base that owns the column
+	 * @param col           1-based column inside that base
+	 * @return the 1-based column inside the relation
+	 */
+	long getRelativeCol(ArrayList<Long> prefixSum, int intervalStart, int baseIndex, long col) {
+		return col - columnsBefore(prefixSum, intervalStart) + columnsBefore(prefixSum, baseIndex);
+	}
+
+	/**
+	 * Find the base in `range` that owns column `col` of the relation formed by
+	 * the bases of `range`.
+	 *
+	 * @return the base index, or -1 if `col` lies outside the relation
+	 */
+	private static int locateBase(ArrayList<Long> prefixSum, IntPair range, long col) {
+		if (col < 1)
+			return -1;
+		// prefix sums are absolute, so shift the relation-relative column first
+		long absolute = col + columnsBefore(prefixSum, range.left());
+		for (int j = range.left(); j <= range.right(); j++) {
+			if (absolute <= prefixSum.get(j))
+				return j;
+		}
+		return -1;
+	}
+
+	// modified from RewriteHoistLoopInvariantOperations.java
+	private StatementBlock createIntermediateStatementBlock(StatementBlock originalSb,
+		List<DataOp> intermediateWrites) {
+		// create empty last-level statement block
+		StatementBlock ret = new StatementBlock();
+		ret.setDMLProg(originalSb.getDMLProg());
+		ret.setParseInfo(originalSb);
+		ret.setLiveIn(new VariableSet(originalSb.liveIn()));
+		ret.setLiveOut(new VariableSet(originalSb.liveIn()));
+
+		// put custom hops
+		ret.setHops(new ArrayList<>(intermediateWrites));
+
+		// live variable analysis: every written temporary leaves the new block
+		// and enters the original one.
+		for (DataOp tWrite : intermediateWrites) {
+			String varName = tWrite.getName();
+			Hop hop = tWrite.getInput().get(0);
+			DataIdentifier diVar = new DataIdentifier(varName);
+			diVar.setDimensions(hop.getDim1(), hop.getDim2());
+			diVar.setBlocksize(hop.getBlocksize());
+			diVar.setDataType(hop.getDataType());
+			diVar.setValueType(hop.getValueType());
+			ret.liveOut().addVariable(varName, diVar);
+			originalSb.liveIn().addVariable(varName, diVar);
+		}
+
+		return ret;
+	}
+
+	/**
+	 * Turn a function call into a TWrite/TRead pair: the write is queued in
+	 * `intermediateWrites`, the read is returned. Any other hop is returned
+	 * unchanged.
+	 */
+	private Hop materialize(Hop hop, ArrayList<DataOp> intermediateWrites) {
+		if (!(hop instanceof FunctionOp fop))
+			return hop;
+
+		String varName = fop.getOutputVariableNames()[0];
+		intermediateWrites.add(HopRewriteUtils.createTransientWrite(varName, fop));
+		return HopRewriteUtils.createTransientRead(varName, fop);
+	}
+
+	/**
+	 * Generate the Hop to replace the existing root.
+	 *
+	 * @param root               root of the current rewrite if `optimalJoin` corresponds
+	 *                           to the root, otherwise null
+	 * @param intermediateWrites accumulates the TWrites of all non-root joins
+	 * @param bases              base matrices in in-order
+	 * @param optimalJoin        the current JoinNode we are constructing
+	 * @return the base hop, the new root call, or a TRead of a non-root call
+	 */
+	private Hop generateHop(FunctionOp root, ArrayList<DataOp> intermediateWrites, ArrayList<Hop> bases,
+		JoinNode optimalJoin) {
+		if (optimalJoin instanceof BaseNode baseNode)
+			return bases.get(baseNode.i());
+		BinaryNode binaryNode = (BinaryNode) optimalJoin;
+
+		Hop a = materialize(generateHop(null, intermediateWrites, bases, binaryNode.left()), intermediateWrites);
+		Hop colA = new LiteralOp(binaryNode.leftCol());
+		Hop b = materialize(generateHop(null, intermediateWrites, bases, binaryNode.right()), intermediateWrites);
+		Hop colB = new LiteralOp(binaryNode.rightCol());
+		Hop method = new LiteralOp(binaryNode.method());
+
+		String[] inputNames = new String[] { "A", "colA", "B", "colB", "method" };
+		ArrayList<Hop> inputs = new ArrayList<>(List.of(a, colA, b, colB, method));
+		String[] outputNames;
+		ArrayList<Hop> outputHops;
+		if (root != null) {
+			// the new root keeps the bindings of the replaced call
+			outputNames = root.getOutputVariableNames();
+			outputHops = root.getOutputs();
+		} else {
+			outputNames = new String[] { "_rajoin_reorder_tmp_" + a.getHopID() + "_" + b.getHopID() };
+			outputHops = new ArrayList<>();
+		}
+
+		FunctionOp fop = new FunctionOp(FunctionOp.FunctionType.DML, RAJOIN_NAMESPACE, RAJOIN_NAME, inputNames,
+			inputs, outputNames, outputHops);
+		fop.setDim2(a.getDim2() + b.getDim2());
+		fop.setDataType(DataType.MATRIX);
+		fop.setValueType(ValueType.FP64);
+		// Return a TRead if it is not the root.
+		return (root == null) ? materialize(fop, intermediateWrites) : fop;
+	}
+
+	/**
+	 * Get a join that is applicable to left and right.
+	 *
+	 * @param canonicalJoins all joins on base relations
+	 * @param left           the bitset representing the left side of the raJoin
+	 * @param right          the bitset representing the right side of the raJoin
+	 * @return the first join whose bases fall into `left` and `right`, or null
+	 */
+	private CanonicalJoin getValidJoin(ArrayList<CanonicalJoin> canonicalJoins, BitSet left, BitSet right) {
+		for (CanonicalJoin join : canonicalJoins) {
+			if (left.get(join.aBaseIndex()) && right.get(join.bBaseIndex()))
+				return join;
+		}
+		return null;
+	}
+
+	/**
+	 * Resolve one input of a join: register it as a new base if its dimensions
+	 * are known, otherwise descend into the join that produces it.
+	 *
+	 * @return inclusive range of base indices covered by the input
+	 */
+	private IntPair collectSide(Hop input, HashMap<String, Integer> joinMap, ArrayList<FunctionOp> joins,
+		ArrayList<CanonicalJoin> cannonicalJoins, boolean[] visited, ArrayList<Hop> bases,
+		ArrayList<Long> basesLengthPrefixSum) {
+		if (isKnownMatrix(input)) {
+			long before = basesLengthPrefixSum.isEmpty() ? 0
+				: basesLengthPrefixSum.get(basesLengthPrefixSum.size() - 1);
+			bases.add(input);
+			basesLengthPrefixSum.add(before + input.getDim2());
+			return new IntPair(bases.size() - 1, bases.size() - 1);
+		}
+		Integer index = joinMap.get(input.getName());
+		if (index == null)
+			throw new UnknownDimensionInfoException();
+		return dfsInorder(joinMap, joins, cannonicalJoins, visited, bases, basesLengthPrefixSum, index);
+	}
+
+	/**
+	 * Inorder traversal of an raJoin that collects its bases and records every
+	 * join as a {@link CanonicalJoin} with base-local columns.
+	 *
+	 * The literal join columns are relative to the (possibly joined) input, so
+	 * they are first shifted by the columns of all bases in front of that input
+	 * before the owning base is looked up, and then made local to that base.
+	 *
+	 * @return inclusive [left, right] range of the indices of `bases` that the
+	 *         current join corresponds to
+	 * @throws UnknownCanonicalJoinException if a join column lies outside its input
+	 * @throws UnknownDimensionInfoException if an input is neither a base nor a
+	 *         collected join
+	 */
+	private IntPair dfsInorder(HashMap<String, Integer> joinMap, ArrayList<FunctionOp> joins,
+		ArrayList<CanonicalJoin> cannonicalJoins, boolean[] visited, ArrayList<Hop> bases,
+		ArrayList<Long> basesLengthPrefixSum, int i) {
+		visited[i] = true;
+		FunctionOp join = joins.get(i);
+		long acol = ((LiteralOp) join.getInput(1)).getLongValue();
+		long bcol = ((LiteralOp) join.getInput(3)).getLongValue();
+		String method = ((LiteralOp) join.getInput(4)).getStringValue();
+
+		// left side first, to keep the bases in in-order
+		IntPair aPair = collectSide(join.getInput(0), joinMap, joins, cannonicalJoins, visited, bases,
+			basesLengthPrefixSum);
+		IntPair bPair = collectSide(join.getInput(2), joinMap, joins, cannonicalJoins, visited, bases,
+			basesLengthPrefixSum);
+
+		int aBaseIndex = locateBase(basesLengthPrefixSum, aPair, acol);
+		int bBaseIndex = locateBase(basesLengthPrefixSum, bPair, bcol);
+		// throw an error and do not rewrite if we cannot figure out the dependencies.
+		if (aBaseIndex < 0 || bBaseIndex < 0)
+			throw new UnknownCanonicalJoinException();
+
+		long aLocal = acol + columnsBefore(basesLengthPrefixSum, aPair.left())
+			- columnsBefore(basesLengthPrefixSum, aBaseIndex);
+		long bLocal = bcol + columnsBefore(basesLengthPrefixSum, bPair.left())
+			- columnsBefore(basesLengthPrefixSum, bBaseIndex);
+		cannonicalJoins.add(new CanonicalJoin(aBaseIndex, aLocal, bBaseIndex, bLocal, method));
+		return new IntPair(aPair.left(), bPair.right());
+	}
+
+	@Override
+	public List<StatementBlock> rewriteStatementBlock(StatementBlock sb, ProgramRewriteStatus state) {
+		List<StatementBlock> ret = new ArrayList<>();
+		ret.add(sb);
+		return ret;
+	}
+
+	/**
+	 * Reorder the raJoin chains found in `sbs`.
+	 *
+	 * @return a copy of `sbs` that additionally contains the statement blocks
+	 *         with the TWrites of re-assembled intermediate joins
+	 */
+	@Override
+	public List<StatementBlock> rewriteStatementBlocks(List<StatementBlock> sbs, ProgramRewriteStatus state) {
+		HashMap<Hop, StatementBlock> hopToSb = new HashMap<>();
+		HashMap<String, Integer> joinMap = new HashMap<>();
+		ArrayList<FunctionOp> joins = new ArrayList<>();
+		ArrayList<StatementBlock> sbsA = new ArrayList<>(sbs);
+		for (StatementBlock sb : sbsA)
+			collectRaJoin(hopToSb, sb, joinMap, joins);
+		if (joins.isEmpty())
+			return sbs;
+		try {
+			ArrayList<Integer> order = topoOrder(joinMap, joins);
+			rewriteRoots(sbsA, hopToSb, joinMap, joins, order);
+		} catch (UnknownCanonicalJoinException | UnknownDimensionInfoException e) {
+			// dependencies of some join are unknown: keep the program as it is.
+		}
+		// the inserted TWrite blocks only exist in the copy
+		return sbsA;
+	}
+}
\ No newline at end of file
diff --git a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteRaJoinTest.java b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteRaJoinTest.java
new file mode 100644
index 00000000000..f05970e3127
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteRaJoinTest.java
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.
You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.rewrite;
+
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+
+import java.util.HashMap;
+
+/**
+ * Runs the nested raJoin script with the join-reordering rewrite enabled and
+ * disabled; both runs must produce the same expected result.
+ */
+public class RewriteRaJoinTest extends AutomatedTestBase {
+	private final static String TEST_NAME = "raJoin";
+	private final static String TEST_DIR = "functions/rewrite/";
+	private final static String TEST_CLASS_DIR = TEST_DIR + RewriteRaJoinTest.class.getSimpleName() + "/";
+
+	@Override
+	public void setUp() {
+		TestUtils.clearAssertionInformation();
+		addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] { "OUT" }));
+	}
+
+	@Test
+	public void testRaJoin() {
+		runRaJoinTest(true);
+	}
+
+	@Test
+	public void testRaJoinNoRewrite() {
+		runRaJoinTest(false);
+	}
+
+	/**
+	 * @param rewrite value of OptimizerUtils.ALLOW_JOIN_REORDERING_REWRITE during the run
+	 */
+	private void runRaJoinTest(boolean rewrite) {
+		boolean oldRewrite = OptimizerUtils.ALLOW_JOIN_REORDERING_REWRITE;
+		// Execute single threaded
+		ExecMode oldPlatform = setExecMode(ExecMode.SINGLE_NODE);
+		try {
+			OptimizerUtils.ALLOW_JOIN_REORDERING_REWRITE = rewrite;
+
+			// Load test configuration (sets up temp input/output folders)
+			getAndLoadTestConfiguration(TEST_NAME);
+
+			// Create inputs
+			double[][] A = {{1,1},{1,1}};
+			double[][] B = {{3,2,1},{1,2,3},{3,1,2}};
+			double[][] C = {
+				{1,1,1,1},
+				{2,2,2,2},
+				{3,3,3,3},
+				{4,4,4,4}
+			};
+
+			writeInputMatrixWithMTD("A", A, true, new MatrixCharacteristics(2,2,-1,-1));
+			writeInputMatrixWithMTD("B", B, true, new MatrixCharacteristics(3,3,-1,-1));
+			writeInputMatrixWithMTD("C", C, true, new MatrixCharacteristics(4,4,-1,-1));
+
+			programArgs = new String[] {
+				"-explain", "hops",
+				"-stats",
+				"-args",
+				input("A"),
+				input("B"),
+				input("C"),
+				output("OUT")
+			};
+			fullDMLScriptName = SCRIPT_DIR + TEST_DIR + TEST_NAME + ".dml";
+
+			runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
+
+			HashMap<CellIndex, Double> out = readDMLMatrixFromOutputDir("OUT");
+			double[][] expected = {
+				{1,1,3,2,1,2,2,2,2},
+				{1,1,3,2,1,2,2,2,2}
+			};
+			HashMap<CellIndex, Double> expectedMap = TestUtils.convert2DDoubleArrayToHashMap(expected);
+			TestUtils.compareMatrices(expectedMap, out, 1e-10, "expected", "actual");
+		} finally {
+			// restore the global state for subsequent tests
+			OptimizerUtils.ALLOW_JOIN_REORDERING_REWRITE = oldRewrite;
+			resetExecMode(oldPlatform);
+		}
+	}
+}
diff --git a/src/test/scripts/functions/rewrite/raJoin.dml b/src/test/scripts/functions/rewrite/raJoin.dml
new file mode 100644
index 00000000000..f0dbc68d9c6
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/raJoin.dml
@@ -0,0 +1,33 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# +#------------------------------------------------------------- + +# A = matrix(2, rows=2,cols=2) +# B = matrix(3, rows=3,cols=3) +# C = matrix(4, rows=4,cols=4) +A = read($1) +B = read($2) +C = read($3) +# A = matrix("1 1 1 1", rows=2, cols=2) +# B = matrix("3 2 1 1 2 3 3 1 2", rows=3, cols=3) +# C = matrix("1 1 1 1 2 2 2 2 3 3 3 3 4 4 4 4", rows=4, cols=4) +ans = raJoin(A, 1, raJoin(B, 2, C, 3, "nested-loop"), 3, "nested-loop") +# ans = raJoin(raJoin(A,1,B,3, "nested-loop"),4,C,3, "nested-loop") +write(ans, $4) \ No newline at end of file