Skip to content

Commit

Permalink
[SYSTEMDS-3572] CLA Morphing Interfaces
Browse files Browse the repository at this point in the history
This commit adds multiple Interfaces for the different intended behavior
of the Column Groups to enable

- Morphing
- Re-compressing
- getting compression statistics from a column groups
- Combining column groups

Also contained is a change of the CLALib classes to be final and
their constructors private.

Closes #1830
  • Loading branch information
Baunsgaard committed May 22, 2023
1 parent 4c90e89 commit 34ca59f
Show file tree
Hide file tree
Showing 39 changed files with 508 additions and 140 deletions.
Expand Up @@ -56,24 +56,21 @@
import org.apache.sysds.runtime.compress.lib.CLALibSlice;
import org.apache.sysds.runtime.compress.lib.CLALibSquash;
import org.apache.sysds.runtime.compress.lib.CLALibTSMM;
import org.apache.sysds.runtime.compress.lib.CLALibTernaryOp;
import org.apache.sysds.runtime.compress.lib.CLALibUnary;
import org.apache.sysds.runtime.compress.lib.CLALibUtils;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject.UpdateType;
import org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
import org.apache.sysds.runtime.data.DenseBlock;
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.data.SparseRow;
import org.apache.sysds.runtime.functionobjects.MinusMultiply;
import org.apache.sysds.runtime.functionobjects.PlusMultiply;
import org.apache.sysds.runtime.functionobjects.TernaryValueFunction.ValueFunctionWithConstant;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CM_COV_Object;
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
import org.apache.sysds.runtime.matrix.data.CTableMap;
import org.apache.sysds.runtime.matrix.data.IJV;
import org.apache.sysds.runtime.matrix.data.LibMatrixDatagen;
import org.apache.sysds.runtime.matrix.data.LibMatrixTercell;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
import org.apache.sysds.runtime.matrix.data.MatrixValue;
Expand Down Expand Up @@ -575,21 +572,6 @@ public MatrixBlock reorgOperations(ReorgOperator op, MatrixValue ret, int startR
return tmp.reorgOperations(op, ret, startRow, startColumn, length);
}

@Override
public String toString() {
StringBuilder sb = new StringBuilder();
sb.append("CompressedMatrixBlock:");
sb.append("\nCols:" + getNumColumns() + " Rows:" + getNumRows() + " Overlapping: " + isOverlapping() + " nnz: "
+ nonZeros);
if(_colGroups != null)
for(AColGroup cg : _colGroups) {
sb.append("\n" + cg);
}
else
sb.append("\nEmptyColGroups");
return sb.toString();
}

public boolean isOverlapping() {
return _colGroups.size() != 1 && overlappingColGroups;
}
Expand Down Expand Up @@ -881,54 +863,7 @@ public void ctableOperations(Operator op, MatrixValue that, MatrixValue that2, C

@Override
public MatrixBlock ternaryOperations(TernaryOperator op, MatrixBlock m2, MatrixBlock m3, MatrixBlock ret) {

// prepare inputs
final int r1 = getNumRows();
final int r2 = m2.getNumRows();
final int r3 = m3.getNumRows();
final int c1 = getNumColumns();
final int c2 = m2.getNumColumns();
final int c3 = m3.getNumColumns();
final boolean s1 = (r1 == 1 && c1 == 1);
final boolean s2 = (r2 == 1 && c2 == 1);
final boolean s3 = (r3 == 1 && c3 == 1);
final double d1 = s1 ? quickGetValue(0, 0) : Double.NaN;
final double d2 = s2 ? m2.quickGetValue(0, 0) : Double.NaN;
final double d3 = s3 ? m3.quickGetValue(0, 0) : Double.NaN;
final int m = Math.max(Math.max(r1, r2), r3);
final int n = Math.max(Math.max(c1, c2), c3);

ternaryOperationCheck(s1, s2, s3, m, r1, r2, r3, n, c1, c2, c3);

final boolean PM_Or_MM = (op.fn instanceof PlusMultiply || op.fn instanceof MinusMultiply);
if(PM_Or_MM && ((s2 && d2 == 0) || (s3 && d3 == 0))) {
ret = new CompressedMatrixBlock();
ret.copy(this);
return ret;
}

if(m2 instanceof CompressedMatrixBlock)
m2 = ((CompressedMatrixBlock) m2).getUncompressed("Ternary Operator arg2 " + op.fn.getClass().getSimpleName(),
op.getNumThreads());
if(m3 instanceof CompressedMatrixBlock)
m3 = ((CompressedMatrixBlock) m3).getUncompressed("Ternary Operator arg3 " + op.fn.getClass().getSimpleName(),
op.getNumThreads());

if(s2 != s3 && (op.fn instanceof PlusMultiply || op.fn instanceof MinusMultiply)) {
// SPECIAL CASE for sparse-dense combinations of common +* and -*
BinaryOperator bop = ((ValueFunctionWithConstant) op.fn).setOp2Constant(s2 ? d2 : d3);
bop.setNumThreads(op.getNumThreads());
ret = CLALibBinaryCellOp.binaryOperationsRight(bop, this, s2 ? m3 : m2, ret);
}
else {
final boolean sparseOutput = evalSparseFormatInMemory(m, n, (s1 ? m * n * (d1 != 0 ? 1 : 0) : getNonZeros()) +
Math.min(s2 ? m * n : m2.getNonZeros(), s3 ? m * n : m3.getNonZeros()));
ret.reset(m, n, sparseOutput);
final MatrixBlock thisUncompressed = getUncompressed("Ternary Operation not supported");
LibMatrixTercell.tercellOp(thisUncompressed, m2, m3, ret, op);
ret.examSparsity();
}
return ret;
return CLALibTernaryOp.ternaryOperations(this, op, m2, m3, ret);
}

@Override
Expand Down Expand Up @@ -1293,4 +1228,19 @@ public boolean allocateSparseRowsBlock(boolean clearNNZ) {
public void allocateAndResetSparseBlock(boolean clearNNZ, SparseBlock.Type stype) {
throw new DMLCompressionException("Invalid to allocate block on a compressed MatrixBlock");
}

@Override
public String toString() {
StringBuilder sb = new StringBuilder();
sb.append("CompressedMatrixBlock:");
sb.append("\nCols:" + getNumColumns() + " Rows:" + getNumRows() + " Overlapping: " + isOverlapping() + " nnz: "
+ nonZeros);
if(_colGroups != null)
for(AColGroup cg : _colGroups) {
sb.append("\n" + cg);
}
else
sb.append("\nEmptyColGroups");
return sb.toString();
}
}
Expand Up @@ -267,7 +267,7 @@ public static CompressedMatrixBlock createConstant(int numRows, int numCols, dou
private Pair<MatrixBlock, CompressionStatistics> compressMatrix() {

if(mb instanceof CompressedMatrixBlock) // Redundant compression
return returnSelf();
return recompress((CompressedMatrixBlock) mb);

_stats.denseSize = MatrixBlock.estimateSizeInMemory(mb.getNumRows(), mb.getNumColumns(), 1.0);
_stats.originalSize = mb.getInMemorySize();
Expand Down Expand Up @@ -305,8 +305,8 @@ private void classifyPhase() {
if(LOG.isTraceEnabled()) {
LOG.trace("Logging all individual columns estimated cost:");
for(CompressedSizeInfoColGroup g : compressionGroups.getInfo())
LOG.trace(String.format("Cost: %8.0f Size: %16d %15s", costEstimator.getCost(g), g.getMinSize(),
g.getColumns()));
LOG.trace(
String.format("Cost: %8.0f Size: %16d %15s", costEstimator.getCost(g), g.getMinSize(), g.getColumns()));
}

_stats.estimatedSizeCols = compressionGroups.memoryEstimate();
Expand Down Expand Up @@ -452,6 +452,17 @@ private Pair<MatrixBlock, CompressionStatistics> abortCompression() {
return new ImmutablePair<>(mb, _stats);
}

private Pair<MatrixBlock, CompressionStatistics> recompress(CompressedMatrixBlock cmb) {
LOG.debug("Recompressing an already compressed MatrixBlock");
LOG.error("Not Implemented Recompress yet");
return new ImmutablePair<>(cmb, null);
// _stats.originalSize = cmb.getInMemorySize();
// CompressedMatrixBlock combined = CLALibCombineGroups.combine(cmb, k);
// CompressedMatrixBlock squashed = CLALibSquash.squash(combined, k);
// _stats.compressedSize = squashed.getInMemorySize();
// return new ImmutablePair<>(squashed, _stats);
}

private void logPhase() {
setNextTimePhase(time.stop());
DMLCompressionStatistics.addCompressionTime(getLastTimePhase(), phase);
Expand Down Expand Up @@ -559,11 +570,6 @@ private Pair<MatrixBlock, CompressionStatistics> createEmpty() {
return new ImmutablePair<>(res, _stats);
}

private Pair<MatrixBlock, CompressionStatistics> returnSelf() {
LOG.info("MatrixBlock already compressed or is Empty");
return new ImmutablePair<>(mb, null);
}

private static String constructNrColumnString(List<AColGroup> cg) {
StringBuilder sb = new StringBuilder();
sb.append("[");
Expand Down
Expand Up @@ -24,12 +24,14 @@
import java.io.Serializable;
import java.util.Collection;

import org.apache.commons.lang.NotImplementedException;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex;
import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex.SliceResult;
import org.apache.sysds.runtime.compress.colgroup.scheme.ICLAScheme;
import org.apache.sysds.runtime.compress.cost.ComputationCostEstimator;
import org.apache.sysds.runtime.compress.estim.CompressedSizeInfoColGroup;
import org.apache.sysds.runtime.data.DenseBlock;
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.instructions.cp.CM_COV_Object;
Expand Down Expand Up @@ -610,12 +612,37 @@ public static AColGroup appendN(AColGroup[] groups) {
public abstract ICLAScheme getCompressionScheme();

/**
* Clear variables that can be recomputed from the allocation of this columngroup.
* Clear variables that can be recomputed from the allocation of this column group.
*/
public void clear(){
public void clear() {
// do nothing
}

/**
* Recompress this column group into a new column group.
*
* @return A new or the same column group depending on optimization goal.
*/
public abstract AColGroup recompress();

/**
* Recompress this column group into a new column group of the given type.
*
* @param ct The compressionType that the column group should morph into
* @return A new column group
*/
public AColGroup morph(CompressionType ct){
throw new NotImplementedException();
}

/**
* Get the compression info for this column group.
*
* @param nRow The number of rows in this column group.
* @return The compression info for this group.
*/
public abstract CompressedSizeInfoColGroup getCompressionInfo(int nRow);

@Override
public String toString() {
StringBuilder sb = new StringBuilder();
Expand Down
Expand Up @@ -33,6 +33,7 @@
import org.apache.sysds.runtime.compress.colgroup.scheme.ConstScheme;
import org.apache.sysds.runtime.compress.colgroup.scheme.ICLAScheme;
import org.apache.sysds.runtime.compress.cost.ComputationCostEstimator;
import org.apache.sysds.runtime.compress.estim.CompressedSizeInfoColGroup;
import org.apache.sysds.runtime.compress.lib.CLALibLeftMultBy;
import org.apache.sysds.runtime.data.DenseBlock;
import org.apache.sysds.runtime.data.SparseBlock;
Expand Down Expand Up @@ -566,6 +567,16 @@ public ICLAScheme getCompressionScheme() {
return ConstScheme.create(this);
}

@Override
public AColGroup recompress(){
return this;
}

@Override
public CompressedSizeInfoColGroup getCompressionInfo(int nRow){
return new CompressedSizeInfoColGroup(_colIndexes, 1, nRow, CompressionType.CONST);
}

@Override
public String toString() {
StringBuilder sb = new StringBuilder();
Expand Down
Expand Up @@ -37,6 +37,7 @@
import org.apache.sysds.runtime.compress.colgroup.scheme.DDCScheme;
import org.apache.sysds.runtime.compress.colgroup.scheme.ICLAScheme;
import org.apache.sysds.runtime.compress.cost.ComputationCostEstimator;
import org.apache.sysds.runtime.compress.estim.CompressedSizeInfoColGroup;
import org.apache.sysds.runtime.data.DenseBlock;
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.functionobjects.Builtin;
Expand All @@ -55,7 +56,7 @@ public class ColGroupDDC extends APreAgg implements AMapToDataGroup {

protected final AMapToData _data;

private ColGroupDDC(IColIndex colIndexes, ADictionary dict, AMapToData data, int[] cachedCounts) {
private ColGroupDDC(IColIndex colIndexes, ADictionary dict, AMapToData data, int[] cachedCounts) {
super(colIndexes, dict, cachedCounts);
_data = data;
}
Expand Down Expand Up @@ -117,8 +118,8 @@ else if(offC == 0)
decompressToDenseBlockDenseDictGeneric(db, rl, ru, offR, offC, values);
}

private final void decompressToDenseBlockDenseDictSingleColContiguous(DenseBlock db, int rl, int ru, int offR, int offC,
double[] values) {
private final void decompressToDenseBlockDenseDictSingleColContiguous(DenseBlock db, int rl, int ru, int offR,
int offC, double[] values) {
final double[] c = db.values(0);
final int nCols = db.getDim(1);
final int colOff = _colIndexes.get(0) + offC;
Expand All @@ -128,12 +129,12 @@ private final void decompressToDenseBlockDenseDictSingleColContiguous(DenseBlock
}

@Override
public AMapToData getMapToData(){
public AMapToData getMapToData() {
return _data;
}

private final void decompressToDenseBlockDenseDictSingleColOutContiguous(DenseBlock db, int rl, int ru, int offR, int offC,
double[] values) {
private final void decompressToDenseBlockDenseDictSingleColOutContiguous(DenseBlock db, int rl, int ru, int offR,
int offC, double[] values) {
final double[] c = db.values(0);
for(int i = rl, offT = rl + offR + _colIndexes.get(0) + offC; i < ru; i++, offT++)
c[offT] += values[_data.getIndex(i)];
Expand All @@ -152,7 +153,8 @@ private final void decompressToDenseBlockDenseDictAllColumnsContiguous(DenseBloc
}
}

private final void decompressToDenseBlockDenseDictNoColOffset(DenseBlock db, int rl, int ru, int offR, double[] values) {
private final void decompressToDenseBlockDenseDictNoColOffset(DenseBlock db, int rl, int ru, int offR,
double[] values) {
final int nCol = _colIndexes.size();
final int colOut = db.getDim(1);
int off = (rl + offR) * colOut;
Expand Down Expand Up @@ -508,8 +510,7 @@ public AColGroup append(AColGroup g) {
LOG.warn("Not same Dictionaries therefore not appending DDC\n" + _dict + "\n\n" + gDDC._dict);
}
else
LOG.warn("Not same columns therefore not appending DDC\n" + _colIndexes + "\n\n"
+ g.getColIndices());
LOG.warn("Not same columns therefore not appending DDC\n" + _colIndexes + "\n\n" + g.getColIndices());
}
else
LOG.warn("Not DDC but " + g.getClass().getSimpleName() + ", therefore not appending DDC");
Expand All @@ -519,9 +520,8 @@ public AColGroup append(AColGroup g) {
@Override
public AColGroup appendNInternal(AColGroup[] g) {
for(int i = 1; i < g.length; i++) {
if(!_colIndexes.equals( g[i]._colIndexes)) {
LOG.warn("Not same columns therefore not appending DDC\n" + _colIndexes + "\n\n"
+ g[i]._colIndexes);
if(!_colIndexes.equals(g[i]._colIndexes)) {
LOG.warn("Not same columns therefore not appending DDC\n" + _colIndexes + "\n\n" + g[i]._colIndexes);
return null;
}

Expand All @@ -540,12 +540,21 @@ public AColGroup appendNInternal(AColGroup[] g) {
return create(_colIndexes, _dict, nd, null);
}


@Override
public ICLAScheme getCompressionScheme() {
return DDCScheme.create(this);
}

@Override
public AColGroup recompress() {
return this;
}

@Override
public CompressedSizeInfoColGroup getCompressionInfo(int nRow) {
throw new NotImplementedException();
}

@Override
public String toString() {
StringBuilder sb = new StringBuilder();
Expand Down
Expand Up @@ -35,6 +35,7 @@
import org.apache.sysds.runtime.compress.colgroup.mapping.MapToFactory;
import org.apache.sysds.runtime.compress.colgroup.scheme.ICLAScheme;
import org.apache.sysds.runtime.compress.cost.ComputationCostEstimator;
import org.apache.sysds.runtime.compress.estim.CompressedSizeInfoColGroup;
import org.apache.sysds.runtime.functionobjects.Builtin;
import org.apache.sysds.runtime.functionobjects.Divide;
import org.apache.sysds.runtime.functionobjects.Minus;
Expand All @@ -59,7 +60,8 @@ public class ColGroupDDCFOR extends AMorphingMMColGroup {
/** Reference values in this column group */
protected final double[] _reference;

private ColGroupDDCFOR(IColIndex colIndexes, ADictionary dict, double[] reference, AMapToData data, int[] cachedCounts) {
private ColGroupDDCFOR(IColIndex colIndexes, ADictionary dict, double[] reference, AMapToData data,
int[] cachedCounts) {
super(colIndexes, dict, cachedCounts);
_data = data;
_reference = reference;
Expand Down Expand Up @@ -439,7 +441,7 @@ protected AColGroup copyAndSet(IColIndex colIndexes, ADictionary newDictionary)
public AColGroup append(AColGroup g) {
if(g instanceof ColGroupDDCFOR && g.getColIndices().equals(_colIndexes)) {
ColGroupDDCFOR gDDC = (ColGroupDDCFOR) g;
if(Arrays.equals(_reference , gDDC._reference) && gDDC._dict.equals(_dict)){
if(Arrays.equals(_reference, gDDC._reference) && gDDC._dict.equals(_dict)) {
AMapToData nd = _data.append(gDDC._data);
return create(_colIndexes, _dict, nd, null, _reference);
}
Expand All @@ -457,6 +459,16 @@ public ICLAScheme getCompressionScheme() {
return null;
}

@Override
public AColGroup recompress() {
return this;
}

@Override
public CompressedSizeInfoColGroup getCompressionInfo(int nRow) {
throw new NotImplementedException();
}

@Override
public String toString() {
StringBuilder sb = new StringBuilder();
Expand Down

0 comments on commit 34ca59f

Please sign in to comment.