Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,9 @@ public void updateAllDCEncoders() {
ColumnEncoderDummycode dc = getEncoder(ColumnEncoderDummycode.class);
if(dc != null)
dc.updateDomainSizes(_columnEncoders);
ColumnEncoderUDF udf = getEncoder(ColumnEncoderUDF.class);
if (udf != null && dc != null)
udf.updateDomainSizes(_columnEncoders);
}

public void addEncoder(ColumnEncoder other) {
Expand Down Expand Up @@ -385,7 +388,10 @@ public void computeRCDMapSizeEstimate(CacheBlock in, int[] sampleIndices) {
/**
 * Sets the number of row partitions (row blocks per column) that every
 * contained encoder uses for its build and apply phases.
 *
 * Encoders whose runtime class is exactly {@code ColumnEncoderUDF} are always
 * configured with a single apply block, regardless of {@code nApply}.
 * NOTE(review): presumably because the UDF apply works on the full row range
 * in one call - confirm against ColumnEncoderUDF.
 *
 * @param nBuild number of row blocks per column for the build phase
 * @param nApply number of row blocks per column for the apply phase
 *               (ignored for UDF encoders, which always get 1)
 */
public void setNumPartitions(int nBuild, int nApply) {
	_columnEncoders.forEach(e -> {
		e.setBuildRowBlocksPerColumn(nBuild);
		// Exact class match on purpose: only the UDF encoder itself is pinned to one block.
		boolean isUDF = e.getClass().equals(ColumnEncoderUDF.class);
		// Set the apply partitioning exactly once per encoder.
		e.setApplyRowBlocksPerColumn(isUDF ? 1 : nApply);
	});
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,11 @@ protected double[] getCodeCol(CacheBlock in, int startInd, int blkSize) {
codes[i-startInd] = Double.NaN;
else {
// Calculate non-negative modulo
double mod = key.hashCode() % _K > 0 ? key.hashCode() % _K : _K + key.hashCode() % _K;
codes[i - startInd] = mod + 1;
//double mod = key.hashCode() % _K > 0 ? key.hashCode() % _K : _K + key.hashCode() % _K;
double mod = (key.hashCode() % _K) + 1;
if (mod < 0)
mod += _K;
codes[i - startInd] = mod;
}
}
return codes;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,14 @@

import java.util.List;

import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.parser.DMLProgram;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.Program;
import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject.UpdateType;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContextFactory;
import org.apache.sysds.runtime.controlprogram.paramserv.ParamservUtils;
Expand All @@ -39,15 +39,16 @@
import org.apache.sysds.runtime.matrix.data.FrameBlock;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.util.DependencyTask;
import org.apache.sysds.utils.stats.TransformStatistics;

public class ColumnEncoderUDF extends ColumnEncoder {

//TODO pass execution context through encoder factory for arbitrary functions not just builtin
//TODO handling udf after dummy coding
//TODO integration into IPA to ensure existence of unoptimized functions

private final String _fName;

public int _domainSize = 1;

protected ColumnEncoderUDF(int ptCols, String name) {
super(ptCols); // 1-based
_fName = name;
Expand All @@ -73,10 +74,12 @@ public List<DependencyTask<?>> getBuildTasks(CacheBlock in) {
}

@Override
public MatrixBlock apply(CacheBlock in, MatrixBlock out, int outputCol, int rowStart, int blk) {
public void applyDense(CacheBlock in, MatrixBlock out, int outputCol, int rowStart, int blk) {
long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0;
//create execution context and input
ExecutionContext ec = ExecutionContextFactory.createContext(new Program(new DMLProgram()));
MatrixBlock col = out.slice(0, in.getNumRows()-1, _colID-1, _colID-1, new MatrixBlock());
//MatrixBlock col = out.slice(0, in.getNumRows()-1, _colID-1, _colID-1, new MatrixBlock());
MatrixBlock col = out.slice(0, in.getNumRows()-1, outputCol, outputCol+_domainSize-1, new MatrixBlock());
ec.setVariable("I", new ListObject(new Data[] {ParamservUtils.newMatrixObject(col, true)}));
ec.setVariable("O", ParamservUtils.newMatrixObject(col, true));

Expand All @@ -87,11 +90,39 @@ public MatrixBlock apply(CacheBlock in, MatrixBlock out, int outputCol, int rowS
new CPOperand(_fName, ValueType.STRING, DataType.SCALAR, true),
new CPOperand("I", ValueType.UNKNOWN, DataType.LIST)});
fun.processInstruction(ec);

//obtain result and in-place write back
MatrixBlock ret = ((MatrixObject)ec.getCacheableData("O")).acquireReadAndRelease();
out.leftIndexingOperations(ret, 0, in.getNumRows()-1, _colID-1, _colID-1, ret, UpdateType.INPLACE);
return out;
//out.leftIndexingOperations(ret, 0, in.getNumRows()-1, _colID-1, _colID-1, ret, UpdateType.INPLACE);
//out.leftIndexingOperations(ret, 0, in.getNumRows()-1, outputCol, outputCol+_domainSize-1, ret, UpdateType.INPLACE);
//out.copy(0, in.getNumRows()-1, _colID-1, _colID-1, ret, true);
out.copy(0, in.getNumRows()-1, outputCol, outputCol+_domainSize-1, ret, true);

if (DMLScript.STATISTICS)
TransformStatistics.incUDFApplyTime(System.nanoTime() - t0);
}

/**
 * Derives this UDF encoder's output domain size from the other encoders in
 * the given list.
 *
 * The domain size is taken from a recode encoder (number of distinct values),
 * a binning encoder (number of bins) or a feature-hash encoder (K). Encoders
 * of any other type are ignored. If several encoders in the list match, the
 * last one wins because {@code _domainSize} is overwritten on every match.
 * If none match, {@code _domainSize} keeps its previous value.
 *
 * @param columnEncoders encoders to inspect; NOTE(review): presumably the
 *                       encoders registered for the same column - confirm at the caller
 */
public void updateDomainSizes(List<ColumnEncoder> columnEncoders) {
	// A column id of -1 means there is nothing to update for this encoder.
	if(_colID == -1)
		return;
	for(ColumnEncoder columnEncoder : columnEncoders) {
		int distinct = -1; // -1 marks "this encoder does not define a domain size"
		if(columnEncoder instanceof ColumnEncoderRecode)
			distinct = ((ColumnEncoderRecode) columnEncoder).getNumDistinctValues();
		else if(columnEncoder instanceof ColumnEncoderBin)
			distinct = ((ColumnEncoderBin) columnEncoder)._numBin;
		else if(columnEncoder instanceof ColumnEncoderFeatureHash)
			distinct = (int) ((ColumnEncoderFeatureHash) columnEncoder).getK();

		if(distinct != -1) {
			_domainSize = distinct;
			// Name the UDF encoder here so the log line is not mistaken for the dummycode encoder's.
			LOG.debug("UDF encoder for column: " + _colID + " has domain size: " + _domainSize);
		}
	}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -314,11 +314,12 @@ public MatrixBlock apply(CacheBlock in) {

/**
 * Allocates the output matrix for the given input and applies all column
 * encoders to it.
 *
 * The output has {@code in.getNumColumns() + getNumExtraCols()} columns. When
 * any column carries a UDF encoder, the non-zero estimate covers the full
 * output width and the output is forced to dense format; otherwise the
 * estimate uses the input width and the format is chosen by
 * {@code MatrixBlock.evalSparseFormatInMemory}.
 *
 * @param in input block to encode
 * @param k  passed through unchanged to {@code apply(in, out, 0, k)};
 *           NOTE(review): presumably the degree of parallelism - confirm
 * @return the result of {@code apply(in, out, 0, k)} on the newly allocated output
 */
public MatrixBlock apply(CacheBlock in, int k) {
	// domain sizes are not updated if called from transformapply
	boolean hasUDF = _columnEncoders.stream().anyMatch(e -> e.hasEncoder(ColumnEncoderUDF.class));
	for(ColumnEncoderComposite columnEncoder : _columnEncoders)
		columnEncoder.updateAllDCEncoders();
	int numCols = in.getNumColumns() + getNumExtraCols();
	// With a UDF present, assume every output cell may be non-zero and never allocate sparse.
	long estNNz = (long) in.getNumRows() * (hasUDF ? numCols : (long) in.getNumColumns());
	boolean sparse = MatrixBlock.evalSparseFormatInMemory(in.getNumRows(), numCols, estNNz) && !hasUDF;
	MatrixBlock out = new MatrixBlock(in.getNumRows(), numCols, sparse, estNNz);
	return apply(in, out, 0, k);
}
Expand Down Expand Up @@ -379,16 +380,15 @@ private List<DependencyTask<?>> getApplyTasks(CacheBlock in, MatrixBlock out, in
private void applyMT(CacheBlock in, MatrixBlock out, int outputCol, int k) {
DependencyThreadPool pool = new DependencyThreadPool(k);
try {
if(APPLY_ENCODER_SEPARATE_STAGES){
if(APPLY_ENCODER_SEPARATE_STAGES) {
int offset = outputCol;
for (ColumnEncoderComposite e : _columnEncoders) {
pool.submitAllAndWait(e.getApplyTasks(in, out, e._colID - 1 + offset));
if (e.hasEncoder(ColumnEncoderDummycode.class))
offset += e.getEncoder(ColumnEncoderDummycode.class)._domainSize - 1;
}
}else{
} else
pool.submitAllAndWait(getApplyTasks(in, out, outputCol));
}
}
catch(ExecutionException | InterruptedException e) {
LOG.error("MT Column apply failed");
Expand Down Expand Up @@ -455,7 +455,7 @@ private void deriveNumRowPartitions(CacheBlock in, int k) {
long memBudget = (long) (OptimizerUtils.getLocalMemBudget() - in.getInMemorySize());
// Worst case scenario: all partial maps contain all distinct values (if < #rows)
long totMemOverhead = getTotalMemOverhead(in, rcdNumBuildBlks, recodeEncoders);
// Reduce recode build blocks count till they fit int the memory budget
// Reduce recode build blocks count till they fit in the memory budget
while (rcdNumBuildBlks > 1 && totMemOverhead > memBudget) {
rcdNumBuildBlks--;
totMemOverhead = getTotalMemOverhead(in, rcdNumBuildBlks, recodeEncoders);
Expand Down Expand Up @@ -1078,10 +1078,11 @@ private InitOutputMatrixTask(MultiColumnEncoder encoder, CacheBlock input, Matri

@Override
public Object call() throws Exception {
boolean hasUDF = _encoder.getColumnEncoders().stream().anyMatch(e -> e.hasEncoder(ColumnEncoderUDF.class));
int numCols = _input.getNumColumns() + _encoder.getNumExtraCols();
boolean hasDC = _encoder.getColumnEncoders(ColumnEncoderDummycode.class).size() > 0;
long estNNz = (long) _input.getNumColumns() * (long) _input.getNumRows();
boolean sparse = MatrixBlock.evalSparseFormatInMemory(_input.getNumRows(), numCols, estNNz);
long estNNz = (long) _input.getNumRows() * (hasUDF ? numCols : (long) _input.getNumColumns());
boolean sparse = MatrixBlock.evalSparseFormatInMemory(_input.getNumRows(), numCols, estNNz) && !hasUDF;
_output.reset(_input.getNumRows(), numCols, sparse, estNNz);
outputMatrixPreProcessing(_output, _input, hasDC);
return null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ public class TransformStatistics {
private static final LongAdder passThroughApplyTime = new LongAdder();
private static final LongAdder featureHashingApplyTime = new LongAdder();
private static final LongAdder binningApplyTime = new LongAdder();
private static final LongAdder UDFApplyTime = new LongAdder();
private static final LongAdder omitApplyTime = new LongAdder();
private static final LongAdder imputeApplyTime = new LongAdder();

Expand All @@ -58,6 +59,10 @@ public static void incBinningApplyTime(long t) {
binningApplyTime.add(t);
}

/**
 * Adds the given duration to the accumulated UDF apply time.
 *
 * @param t elapsed time to add; the total is scaled by 1e-9 and labelled
 *          "sec." when displayed, so nanoseconds are expected
 */
public static void incUDFApplyTime(long t) {
UDFApplyTime.add(t);
}

/**
 * Adds the given duration to the accumulated pass-through apply time.
 *
 * @param t elapsed time to add; the total is scaled by 1e-9 and labelled
 *          "sec." when displayed, so nanoseconds are expected
 */
public static void incPassThroughApplyTime(long t) {
passThroughApplyTime.add(t);
}
Expand Down Expand Up @@ -106,8 +111,8 @@ public static long getEncodeBuildTime() {
public static long getEncodeApplyTime() {
return dummyCodeApplyTime.longValue() + binningApplyTime.longValue() +
featureHashingApplyTime.longValue() + passThroughApplyTime.longValue() +
recodeApplyTime.longValue() + omitApplyTime.longValue() +
imputeApplyTime.longValue();
recodeApplyTime.longValue() + UDFApplyTime.longValue() +
omitApplyTime.longValue() + imputeApplyTime.longValue();
}

public static void reset() {
Expand All @@ -122,6 +127,7 @@ public static void reset() {
passThroughApplyTime.reset();
featureHashingApplyTime.reset();
binningApplyTime.reset();
UDFApplyTime.reset();
omitApplyTime.reset();
imputeApplyTime.reset();
outMatrixPreProcessingTime.reset();
Expand Down Expand Up @@ -163,6 +169,9 @@ public static String displayStatistics() {
if(passThroughApplyTime.longValue() > 0)
sb.append("\tPassThrough apply time:\t").append(String.format("%.3f",
passThroughApplyTime.longValue()*1e-9)).append(" sec.\n");
if(UDFApplyTime.longValue() > 0)
sb.append("\tUDF apply time:\t").append(String.format("%.3f",
UDFApplyTime.longValue()*1e-9)).append(" sec.\n");
if(omitApplyTime.longValue() > 0)
sb.append("\tOmit apply time:\t").append(String.format("%.3f",
omitApplyTime.longValue()*1e-9)).append(" sec.\n");
Expand Down