diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderComposite.java b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderComposite.java index 243bfe7caa6..608b1de3f1f 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderComposite.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderComposite.java @@ -269,6 +269,9 @@ public void updateAllDCEncoders() { ColumnEncoderDummycode dc = getEncoder(ColumnEncoderDummycode.class); if(dc != null) dc.updateDomainSizes(_columnEncoders); + ColumnEncoderUDF udf = getEncoder(ColumnEncoderUDF.class); + if (udf != null && dc != null) + udf.updateDomainSizes(_columnEncoders); } public void addEncoder(ColumnEncoder other) { @@ -385,7 +388,10 @@ public void computeRCDMapSizeEstimate(CacheBlock in, int[] sampleIndices) { public void setNumPartitions(int nBuild, int nApply) { _columnEncoders.forEach(e -> { e.setBuildRowBlocksPerColumn(nBuild); - e.setApplyRowBlocksPerColumn(nApply); + if (e.getClass().equals(ColumnEncoderUDF.class)) + e.setApplyRowBlocksPerColumn(1); + else + e.setApplyRowBlocksPerColumn(nApply); }); } diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderFeatureHash.java b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderFeatureHash.java index 161959441b7..fec65dd93d0 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderFeatureHash.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderFeatureHash.java @@ -83,8 +83,11 @@ protected double[] getCodeCol(CacheBlock in, int startInd, int blkSize) { codes[i-startInd] = Double.NaN; else { // Calculate non-negative modulo - double mod = key.hashCode() % _K > 0 ? key.hashCode() % _K : _K + key.hashCode() % _K; - codes[i - startInd] = mod + 1; + //double mod = key.hashCode() % _K > 0 ? key.hashCode() % _K : _K + key.hashCode() % _K; + double mod = (key.hashCode() % _K) + 1; + if (mod < 0) + mod += _K; + codes[i - startInd] = mod; } } return codes; diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderUDF.java b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderUDF.java index a3f76623f26..00e588ee17e 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderUDF.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderUDF.java @@ -21,6 +21,7 @@ import java.util.List; +import org.apache.sysds.api.DMLScript; import org.apache.sysds.common.Types.DataType; import org.apache.sysds.common.Types.ValueType; import org.apache.sysds.parser.DMLProgram; @@ -28,7 +29,6 @@ import org.apache.sysds.runtime.controlprogram.Program; import org.apache.sysds.runtime.controlprogram.caching.CacheBlock; import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; -import org.apache.sysds.runtime.controlprogram.caching.MatrixObject.UpdateType; import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; import org.apache.sysds.runtime.controlprogram.context.ExecutionContextFactory; import org.apache.sysds.runtime.controlprogram.paramserv.ParamservUtils; @@ -39,15 +39,16 @@ import org.apache.sysds.runtime.matrix.data.FrameBlock; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.util.DependencyTask; +import org.apache.sysds.utils.stats.TransformStatistics; public class ColumnEncoderUDF extends ColumnEncoder { //TODO pass execution context through encoder factory for arbitrary functions not just builtin - //TODO handling udf after dummy coding //TODO integration into IPA to ensure existence of unoptimized functions private final String _fName; - + public int _domainSize = 1; + protected ColumnEncoderUDF(int ptCols, String name) { super(ptCols); // 1-based _fName = name; @@ -73,10 +74,12 @@ public List> getBuildTasks(CacheBlock in) { } @Override - public MatrixBlock apply(CacheBlock in, MatrixBlock out, int outputCol, int rowStart, int blk) { + public void applyDense(CacheBlock in, MatrixBlock out, int outputCol, int rowStart, int blk) { + long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0; //create execution context and input ExecutionContext ec = ExecutionContextFactory.createContext(new Program(new DMLProgram())); - MatrixBlock col = out.slice(0, in.getNumRows()-1, _colID-1, _colID-1, new MatrixBlock()); + //MatrixBlock col = out.slice(0, in.getNumRows()-1, _colID-1, _colID-1, new MatrixBlock()); + MatrixBlock col = out.slice(0, in.getNumRows()-1, outputCol, outputCol+_domainSize-1, new MatrixBlock()); ec.setVariable("I", new ListObject(new Data[] {ParamservUtils.newMatrixObject(col, true)})); ec.setVariable("O", ParamservUtils.newMatrixObject(col, true)); @@ -87,11 +90,39 @@ public MatrixBlock apply(CacheBlock in, MatrixBlock out, int outputCol, int rowS new CPOperand(_fName, ValueType.STRING, DataType.SCALAR, true), new CPOperand("I", ValueType.UNKNOWN, DataType.LIST)}); fun.processInstruction(ec); - + //obtain result and in-place write back MatrixBlock ret = ((MatrixObject)ec.getCacheableData("O")).acquireReadAndRelease(); - out.leftIndexingOperations(ret, 0, in.getNumRows()-1, _colID-1, _colID-1, ret, UpdateType.INPLACE); - return out; + //out.leftIndexingOperations(ret, 0, in.getNumRows()-1, _colID-1, _colID-1, ret, UpdateType.INPLACE); + //out.leftIndexingOperations(ret, 0, in.getNumRows()-1, outputCol, outputCol+_domainSize-1, ret, UpdateType.INPLACE); + //out.copy(0, in.getNumRows()-1, _colID-1, _colID-1, ret, true); + out.copy(0, in.getNumRows()-1, outputCol, outputCol+_domainSize-1, ret, true); + + if (DMLScript.STATISTICS) + TransformStatistics.incUDFApplyTime(System.nanoTime() - t0); + } + + public void updateDomainSizes(List columnEncoders) { + if(_colID == -1) + return; + for(ColumnEncoder columnEncoder : columnEncoders) { + int distinct = -1; + if(columnEncoder instanceof ColumnEncoderRecode) { + ColumnEncoderRecode columnEncoderRecode = (ColumnEncoderRecode) columnEncoder; + distinct = columnEncoderRecode.getNumDistinctValues(); + } + else if(columnEncoder instanceof ColumnEncoderBin) { + distinct = ((ColumnEncoderBin) columnEncoder)._numBin; + } + else if(columnEncoder instanceof ColumnEncoderFeatureHash){ + distinct = (int) ((ColumnEncoderFeatureHash) columnEncoder).getK(); + } + + if(distinct != -1) { + _domainSize = distinct; + LOG.debug("DummyCoder for column: " + _colID + " has domain size: " + _domainSize); + } + } } @Override diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java b/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java index d84d00e531c..07d1aa7f2a6 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java @@ -314,11 +314,12 @@ public MatrixBlock apply(CacheBlock in) { public MatrixBlock apply(CacheBlock in, int k) { // domain sizes are not updated if called from transformapply + boolean hasUDF = _columnEncoders.stream().anyMatch(e -> e.hasEncoder(ColumnEncoderUDF.class)); for(ColumnEncoderComposite columnEncoder : _columnEncoders) columnEncoder.updateAllDCEncoders(); int numCols = in.getNumColumns() + getNumExtraCols(); - long estNNz = (long) in.getNumColumns() * (long) in.getNumRows(); - boolean sparse = MatrixBlock.evalSparseFormatInMemory(in.getNumRows(), numCols, estNNz); + long estNNz = (long) in.getNumRows() * (hasUDF ? numCols : (long) in.getNumColumns()); + boolean sparse = MatrixBlock.evalSparseFormatInMemory(in.getNumRows(), numCols, estNNz) && !hasUDF; MatrixBlock out = new MatrixBlock(in.getNumRows(), numCols, sparse, estNNz); return apply(in, out, 0, k); } @@ -379,16 +380,15 @@ private List> getApplyTasks(CacheBlock in, MatrixBlock out, in private void applyMT(CacheBlock in, MatrixBlock out, int outputCol, int k) { DependencyThreadPool pool = new DependencyThreadPool(k); try { - if(APPLY_ENCODER_SEPARATE_STAGES){ + if(APPLY_ENCODER_SEPARATE_STAGES) { int offset = outputCol; for (ColumnEncoderComposite e : _columnEncoders) { pool.submitAllAndWait(e.getApplyTasks(in, out, e._colID - 1 + offset)); if (e.hasEncoder(ColumnEncoderDummycode.class)) offset += e.getEncoder(ColumnEncoderDummycode.class)._domainSize - 1; } - }else{ + } else pool.submitAllAndWait(getApplyTasks(in, out, outputCol)); - } } catch(ExecutionException | InterruptedException e) { LOG.error("MT Column apply failed"); @@ -455,7 +455,7 @@ private void deriveNumRowPartitions(CacheBlock in, int k) { long memBudget = (long) (OptimizerUtils.getLocalMemBudget() - in.getInMemorySize()); // Worst case scenario: all partial maps contain all distinct values (if < #rows) long totMemOverhead = getTotalMemOverhead(in, rcdNumBuildBlks, recodeEncoders); - // Reduce recode build blocks count till they fit int the memory budget + // Reduce recode build blocks count till they fit in the memory budget while (rcdNumBuildBlks > 1 && totMemOverhead > memBudget) { rcdNumBuildBlks--; totMemOverhead = getTotalMemOverhead(in, rcdNumBuildBlks, recodeEncoders); @@ -1078,10 +1078,11 @@ private InitOutputMatrixTask(MultiColumnEncoder encoder, CacheBlock input, Matri @Override public Object call() throws Exception { + boolean hasUDF = _encoder.getColumnEncoders().stream().anyMatch(e -> e.hasEncoder(ColumnEncoderUDF.class)); int numCols = _input.getNumColumns() + _encoder.getNumExtraCols(); boolean hasDC = _encoder.getColumnEncoders(ColumnEncoderDummycode.class).size() > 0; - long estNNz = (long) _input.getNumColumns() * (long) _input.getNumRows(); - boolean sparse = MatrixBlock.evalSparseFormatInMemory(_input.getNumRows(), numCols, estNNz); + long estNNz = (long) _input.getNumRows() * (hasUDF ? numCols : (long) _input.getNumColumns()); + boolean sparse = MatrixBlock.evalSparseFormatInMemory(_input.getNumRows(), numCols, estNNz) && !hasUDF; _output.reset(_input.getNumRows(), numCols, sparse, estNNz); outputMatrixPreProcessing(_output, _input, hasDC); return null; diff --git a/src/main/java/org/apache/sysds/utils/stats/TransformStatistics.java b/src/main/java/org/apache/sysds/utils/stats/TransformStatistics.java index 05f06b065c8..b7779e4ee19 100644 --- a/src/main/java/org/apache/sysds/utils/stats/TransformStatistics.java +++ b/src/main/java/org/apache/sysds/utils/stats/TransformStatistics.java @@ -35,6 +35,7 @@ public class TransformStatistics { private static final LongAdder passThroughApplyTime = new LongAdder(); private static final LongAdder featureHashingApplyTime = new LongAdder(); private static final LongAdder binningApplyTime = new LongAdder(); + private static final LongAdder UDFApplyTime = new LongAdder(); private static final LongAdder omitApplyTime = new LongAdder(); private static final LongAdder imputeApplyTime = new LongAdder(); @@ -58,6 +59,10 @@ public static void incBinningApplyTime(long t) { binningApplyTime.add(t); } + public static void incUDFApplyTime(long t) { + UDFApplyTime.add(t); + } + public static void incPassThroughApplyTime(long t) { passThroughApplyTime.add(t); } @@ -106,8 +111,8 @@ public static long getEncodeBuildTime() { public static long getEncodeApplyTime() { return dummyCodeApplyTime.longValue() + binningApplyTime.longValue() + featureHashingApplyTime.longValue() + passThroughApplyTime.longValue() + - recodeApplyTime.longValue() + omitApplyTime.longValue() + - imputeApplyTime.longValue(); + recodeApplyTime.longValue() + UDFApplyTime.longValue() + + omitApplyTime.longValue() + imputeApplyTime.longValue(); } public static void reset() { @@ -122,6 +127,7 @@ public static void reset() { passThroughApplyTime.reset(); featureHashingApplyTime.reset(); binningApplyTime.reset(); + UDFApplyTime.reset(); omitApplyTime.reset(); imputeApplyTime.reset(); outMatrixPreProcessingTime.reset(); @@ -163,6 +169,9 @@ public static String displayStatistics() { if(passThroughApplyTime.longValue() > 0) sb.append("\tPassThrough apply time:\t").append(String.format("%.3f", passThroughApplyTime.longValue()*1e-9)).append(" sec.\n"); + if(UDFApplyTime.longValue() > 0) + sb.append("\tUDF apply time:\t").append(String.format("%.3f", + UDFApplyTime.longValue()*1e-9)).append(" sec.\n"); if(omitApplyTime.longValue() > 0) sb.append("\tOmit apply time:\t").append(String.format("%.3f", omitApplyTime.longValue()*1e-9)).append(" sec.\n");