From f0f8f0c19083114be5383d91ebbe80d362f6abbe Mon Sep 17 00:00:00 2001 From: baunsgaard Date: Mon, 3 Jul 2023 19:37:37 +0200 Subject: [PATCH] [SYSTEMDS-3592] Frame Compress This commit adds a compression pipeline for frames that first analyzes a sample, which is then used to determine the compression of individual columns. The distinct-count estimation tools of the matrix compression framework are used. The next step is parallelization of the compression. Closes #1856 --- .../apache/sysds/api/mlcontext/MLContext.java | 6 +- .../estim/sample/SampleEstimatorFactory.java | 16 ++- .../sysds/runtime/frame/data/FrameBlock.java | 51 +++++++- .../frame/data/columns/ACompressedArray.java | 7 ++ .../runtime/frame/data/columns/Array.java | 60 ++++++--- .../frame/data/columns/BitSetArray.java | 9 +- .../frame/data/columns/BooleanArray.java | 9 +- .../runtime/frame/data/columns/DDCArray.java | 33 +++-- .../compress/ArrayCompressionStatistics.java | 50 ++++++++ .../compress/CompressedFrameBlockFactory.java | 119 ++++++++++++++++++ ...ics.java => FrameCompressionSettings.java} | 15 ++- .../FrameCompressionSettingsBuilder.java | 54 ++++++++ .../frame/data/lib/FrameLibCompress.java | 10 +- .../cp/CompressionCPInstruction.java | 8 +- .../runtime/io/FrameWriterCompressed.java | 2 +- .../frame/array/FrameArrayTests.java | 1 + .../frame/compress/FrameCompressTest.java | 57 ++++++++- .../compress/FrameCompressTestLogging.java | 99 +++++++++++++++ .../functions/codegen/APICodegenTest.java | 12 +- 19 files changed, 558 insertions(+), 60 deletions(-) create mode 100644 src/main/java/org/apache/sysds/runtime/frame/data/compress/ArrayCompressionStatistics.java create mode 100644 src/main/java/org/apache/sysds/runtime/frame/data/compress/CompressedFrameBlockFactory.java rename src/main/java/org/apache/sysds/runtime/frame/data/compress/{FrameCompressionStatistics.java => FrameCompressionSettings.java} (72%) create mode 100644 
src/main/java/org/apache/sysds/runtime/frame/data/compress/FrameCompressionSettingsBuilder.java create mode 100644 src/test/java/org/apache/sysds/test/component/frame/compress/FrameCompressTestLogging.java diff --git a/src/main/java/org/apache/sysds/api/mlcontext/MLContext.java b/src/main/java/org/apache/sysds/api/mlcontext/MLContext.java index 838f8c76ef6..64dafd3f5ca 100644 --- a/src/main/java/org/apache/sysds/api/mlcontext/MLContext.java +++ b/src/main/java/org/apache/sysds/api/mlcontext/MLContext.java @@ -73,6 +73,9 @@ public class MLContext implements ConfigurableAPI */ private static MLContext activeMLContext = null; + /** Welcome message */ + public static boolean welcomePrint = false; + /** * Contains cleanup methods used by MLContextProxy. */ @@ -262,8 +265,9 @@ private void initMLContext(SparkSession spark) { } } - if (activeMLContext == null) { + if (!welcomePrint) { System.out.println(MLContextUtil.welcomeMessage()); + welcomePrint = true; } this.spark = spark; diff --git a/src/main/java/org/apache/sysds/runtime/compress/estim/sample/SampleEstimatorFactory.java b/src/main/java/org/apache/sysds/runtime/compress/estim/sample/SampleEstimatorFactory.java index 39cb706e347..01d3c8449ee 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/estim/sample/SampleEstimatorFactory.java +++ b/src/main/java/org/apache/sysds/runtime/compress/estim/sample/SampleEstimatorFactory.java @@ -30,8 +30,20 @@ public interface SampleEstimatorFactory { public enum EstimationType { HassAndStokes, ShlosserEstimator, // - ShlosserJackknifeEstimator, SmoothedJackknifeEstimator, - HassAndStokesNoSolveCache, + ShlosserJackknifeEstimator, SmoothedJackknifeEstimator, HassAndStokesNoSolveCache, + } + + /** + * Estimate a distinct number of values based on frequencies. 
+ * + * @param frequencies A list of frequencies of unique values, Note all values contained should be larger than zero + * @param nRows The total number of rows to consider, Note should always be larger than or equal to sum(frequencies) + * @param sampleSize The size of the sample, Note this should ideally be scaled to match the sum(frequencies) and + * should always be lower than or equal to nRows + * @return An estimated number of unique values + */ + public static int distinctCount(int[] frequencies, int nRows, int sampleSize) { + return distinctCount(frequencies, nRows, sampleSize, EstimationType.HassAndStokes, null); } /** diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java b/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java index 5faacaf8d1d..486bac29feb 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java @@ -86,7 +86,7 @@ public class FrameBlock implements CacheBlock, Externalizable { /** Buffer size variable: 1M elements, size of default matrix block */ public static final int BUFFER_SIZE = 1 * 1000 * 1000; - /** If debugging is enabled for the FrameBlocks in stable state*/ + /** If debugging is enabled for the FrameBlocks in stable state */ public static boolean debug = false; /** The schema of the data frame as an ordered list of value types */ @@ -197,6 +197,55 @@ public FrameBlock(ValueType[] schema, String[] colNames, ColumnMetadata[] meta, _nRow = data[0].size(); } + /** + * Create a FrameBlock containing columns of the specified arrays + * + * @param data The column data contained + */ + public FrameBlock(Array[] data) { + _schema = new ValueType[data.length]; + for(int i = 0; i < data.length; i++) + _schema[i] = data[i].getValueType(); + + _colnames = null; + ensureAllocateMeta(); + _coldata = data; + _nRow = data[0].size(); + + if(debug) { + for(int i = 0; i < data.length; i++) { + if(data[i].size() != 
getNumRows()) + 
throw new DMLRuntimeException( + "Invalid Frame allocation with different size arrays " + data[i].size() + " vs " + getNumRows()); + } + } + } + + /** + * Create a FrameBlock containing columns of the specified arrays and names + * + * @param data The column data contained + * @param colnames The column names of the contained columns + */ + public FrameBlock(Array[] data, String[] colnames) { + _schema = new ValueType[data.length]; + for(int i = 0; i < data.length; i++) + _schema[i] = data[i].getValueType(); + + _colnames = colnames; + ensureAllocateMeta(); + _coldata = data; + _nRow = data[0].size(); + + if(debug) { + for(int i = 0; i < data.length; i++) { + if(data[i].size() != getNumRows()) + throw new DMLRuntimeException( + "Invalid Frame allocation with different size arrays " + data[i].size() + " vs " + getNumRows()); + } + } + } + /** * Get the number of rows of the frame block. * diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/ACompressedArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/ACompressedArray.java index a36a0c3cc53..90ceb5f6a20 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/ACompressedArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/ACompressedArray.java @@ -20,6 +20,7 @@ package org.apache.sysds.runtime.frame.data.columns; import org.apache.sysds.runtime.compress.DMLCompressionException; +import org.apache.sysds.runtime.frame.data.compress.ArrayCompressionStatistics; /** * A Compressed Array, in general does not allow us to set or modify the array. 
@@ -102,4 +103,10 @@ public void reset(int size) { throw new DMLCompressionException("Invalid to reset compressed array"); } + @Override + public ArrayCompressionStatistics statistics(int nSamples) { + // already compressed + return null; + } + } diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/Array.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/Array.java index 3fbf3ed2d0c..b544104df06 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/Array.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/Array.java @@ -23,6 +23,7 @@ import java.util.HashMap; import java.util.Iterator; import java.util.Map; +import java.util.Map.Entry; import org.apache.commons.lang.NotImplementedException; import org.apache.commons.logging.Log; @@ -30,7 +31,9 @@ import org.apache.hadoop.io.Writable; import org.apache.sysds.common.Types.ValueType; import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.compress.estim.sample.SampleEstimatorFactory; import org.apache.sysds.runtime.frame.data.columns.ArrayFactory.FrameArrayType; +import org.apache.sysds.runtime.frame.data.compress.ArrayCompressionStatistics; import org.apache.sysds.runtime.matrix.data.Pair; /** @@ -97,14 +100,15 @@ public Map getRecodeMap() { /** * Recreate the recode map from what is already there. + * * @return */ - protected Map createRecodeMap(){ + protected Map createRecodeMap() { Map map = new HashMap<>(); long id = 0; for(int i = 0; i < size(); i++) { T val = get(i); - if(val != null){ + if(val != null) { Long v = map.putIfAbsent(val, id); if(v == null) id++; @@ -113,19 +117,18 @@ protected Map createRecodeMap(){ return map; } - /** * Get the dictionary of the contained values, including null. * * @return a dictionary containing all unique values. 
*/ - protected Map getDictionary(){ + protected Map getDictionary() { Map dict = new HashMap<>(); int id = 0; - for(int i = 0 ; i < size(); i ++){ + for(int i = 0; i < size(); i++) { T val = get(i); Integer v = dict.putIfAbsent(val, id); - if(v== null) + if(v == null) id++; } @@ -371,7 +374,7 @@ public ABooleanArray getNulls() { * * @return If the array contains null. */ - public boolean containsNull(){ + public boolean containsNull() { return false; } @@ -424,7 +427,7 @@ public final Array changeType(ValueType t) { return changeTypeFloat(); case FP64: return changeTypeDouble(); - case UINT4: + case UINT4: case UINT8: throw new NotImplementedException(); case INT32: @@ -556,7 +559,7 @@ public Pair getMinMaxLength() { * * @param select Modify this to true in indexes that are not empty. */ - public final void findEmpty(boolean[] select){ + public final void findEmpty(boolean[] select) { for(int i = 0; i < select.length; i++) if(isNotEmpty(i)) select[i] = true; @@ -592,28 +595,57 @@ public String toString() { } /** - * Hash the given index of the array. - * It is allowed to return NaN on null elements. + * Hash the given index of the array. It is allowed to return NaN on null elements. * * @param idx The index to hash * @return The hash value of that index. 
*/ public abstract double hashDouble(int idx); - public ArrayIterator getIterator(){ + public ArrayIterator getIterator() { return new ArrayIterator(); } + public ArrayCompressionStatistics statistics(int nSamples) { + + Map d = new HashMap<>(); + for(int i = 0; i < nSamples; i++) { + // super inefficient, but startup + T key = get(i); + if(d.containsKey(key)) + d.put(key, d.get(key) + 1); + else + d.put(key, 1); + } + + final int[] freq = new int[d.size()]; + int id = 0; + for(Entry e : d.entrySet()) + freq[id++] = e.getValue(); + + int estDistinct = SampleEstimatorFactory.distinctCount(freq, size(), nSamples); + long memSize = getInMemorySize(); // uncompressed size + int memSizePerElement = (int) ((memSize * 8L) / size()); + + long ddcSize = DDCArray.estimateInMemorySize(memSizePerElement, estDistinct, size()); + + if(ddcSize < memSize) + return new ArrayCompressionStatistics(memSizePerElement, // + estDistinct, true, FrameArrayType.DDC, memSize, ddcSize); + + return null; + } + public class ArrayIterator implements Iterator { int index = -1; - public int getIndex(){ + public int getIndex() { return index; } @Override public boolean hasNext() { - return index < size()-1; + return index < size() - 1; } @Override diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/BitSetArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/BitSetArray.java index d6c2489ec2c..27a58e6e6cd 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/BitSetArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/BitSetArray.java @@ -30,6 +30,7 @@ import org.apache.sysds.common.Types.ValueType; import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.frame.data.columns.ArrayFactory.FrameArrayType; +import org.apache.sysds.runtime.frame.data.compress.ArrayCompressionStatistics; import org.apache.sysds.runtime.matrix.data.Pair; import org.apache.sysds.runtime.util.UtilFunctions; import 
org.apache.sysds.utils.MemoryEstimates; @@ -539,10 +540,16 @@ public static String longToBits(long l) { } @Override - public double hashDouble(int idx){ + public double hashDouble(int idx) { return get(idx) ? 1.0 : 0.0; } + @Override + public ArrayCompressionStatistics statistics(int nSamples) { + // Unlikely to compress so lets just say... no + return null; + } + @Override public String toString() { StringBuilder sb = new StringBuilder(_size + 10); diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/BooleanArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/BooleanArray.java index e74f8bcd653..0d40ebe938b 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/BooleanArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/BooleanArray.java @@ -28,6 +28,7 @@ import org.apache.sysds.common.Types.ValueType; import org.apache.sysds.runtime.frame.data.columns.ArrayFactory.FrameArrayType; +import org.apache.sysds.runtime.frame.data.compress.ArrayCompressionStatistics; import org.apache.sysds.runtime.matrix.data.Pair; import org.apache.sysds.runtime.util.UtilFunctions; import org.apache.sysds.utils.MemoryEstimates; @@ -339,10 +340,16 @@ public static boolean parseBoolean(String value) { } @Override - public double hashDouble(int idx){ + public double hashDouble(int idx) { return get(idx) ? 1.0 : 0.0; } + @Override + public ArrayCompressionStatistics statistics(int nSamples) { + // Unlikely to compress so lets just say... 
no + return null; + } + @Override public String toString() { StringBuilder sb = new StringBuilder(_data.length * 2 + 10); diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/DDCArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/DDCArray.java index f7a810d0fdb..7c995769f17 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/DDCArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/DDCArray.java @@ -22,7 +22,6 @@ import java.io.DataInput; import java.io.DataOutput; import java.io.IOException; -import java.util.HashMap; import java.util.Map; import java.util.Map.Entry; @@ -56,13 +55,6 @@ public DDCArray(Array dict, AMapToData map) { } } - private static Map invert(Map map) { - Map invMap = new HashMap(); - for(Entry e : map.entrySet()) - invMap.put(e.getValue(), e.getKey()); - return invMap; - } - /** * Try to compress array into DDC format. * @@ -72,30 +64,33 @@ private static Map invert(Map map) { */ @SuppressWarnings("unchecked") public static Array compressToDDC(Array arr) { - // two pass algorithm - if(arr.size() <= 10) + // Early aborts + // if the size is small do not consider + // or if the instance is a RaggedArray where all values typically are unique. + if(arr.size() <= 10 || arr instanceof RaggedArray) return arr; - // 1. Get unique + // Two pass algorithm + // 1. full iteration: Get unique Map rcd = arr.getDictionary(); + // Abort if there are too many unique values. if(rcd.size() > arr.size() / 2) return arr; + // Allocate the correct dictionary output Array ar; - if(rcd.keySet().contains(null)) ar = (Array) ArrayFactory.allocateOptional(arr.getValueType(), rcd.size()); else ar = (Array) ArrayFactory.allocate(arr.getValueType(), rcd.size()); - Map rcdInv = invert(rcd); - for(int i = 0; i < rcd.size(); i++) - ar.set(i, rcdInv.get(Integer.valueOf(i))); + // Set elements in the Dictionary array --- much smaller. 
+ for(Entry e : rcd.entrySet()) + ar.set(e.getValue(), e.getKey()); - // 2. Make map + // 2. full iteration: Make map AMapToData m = MapToFactory.create(arr.size(), rcd.size()); - for(int i = 0; i < arr.size(); i++) m.set(i, rcd.get(arr.get(i))); @@ -285,6 +280,10 @@ protected Map getDictionary() { return dict.getDictionary(); } + public static long estimateInMemorySize(int memSizeBitPerElement, int estDistinct, int nRow) { + return (estDistinct * memSizeBitPerElement) / 8 + MapToFactory.estimateInMemorySize(nRow, estDistinct); + } + @Override public String toString() { StringBuilder sb = new StringBuilder(); diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/compress/ArrayCompressionStatistics.java b/src/main/java/org/apache/sysds/runtime/frame/data/compress/ArrayCompressionStatistics.java new file mode 100644 index 00000000000..ae416554822 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/frame/data/compress/ArrayCompressionStatistics.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysds.runtime.frame.data.compress; + +import org.apache.sysds.runtime.frame.data.columns.ArrayFactory.FrameArrayType; + +public class ArrayCompressionStatistics { + + public final long originalSize; + public final long compressedSizeEstimate; + public final boolean shouldCompress; + public final FrameArrayType bestType; + public final int bitPerValue; + public final int nUnique; + + public ArrayCompressionStatistics(int bitPerValue, int nUnique, boolean shouldCompress, FrameArrayType bestType, + long originalSize, long compressedSizeEstimate) { + this.bitPerValue = bitPerValue; + this.nUnique = nUnique; + this.shouldCompress = shouldCompress; + this.bestType = bestType; + this.originalSize = originalSize; + this.compressedSizeEstimate = compressedSizeEstimate; + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(String.format("Compressed Stats: size:%6d->%6d, Use:%10s, Unique:%5d", originalSize, + compressedSizeEstimate, bestType.toString(), nUnique)); + return sb.toString(); + } +} diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/compress/CompressedFrameBlockFactory.java b/src/main/java/org/apache/sysds/runtime/frame/data/compress/CompressedFrameBlockFactory.java new file mode 100644 index 00000000000..b3246863c92 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/frame/data/compress/CompressedFrameBlockFactory.java @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.frame.data.compress; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.sysds.runtime.compress.workload.WTreeRoot; +import org.apache.sysds.runtime.frame.data.FrameBlock; +import org.apache.sysds.runtime.frame.data.columns.Array; +import org.apache.sysds.runtime.frame.data.columns.DDCArray; + +public class CompressedFrameBlockFactory { + + private static final Log LOG = LogFactory.getLog(CompressedFrameBlockFactory.class.getName()); + + private final FrameBlock in; + private final FrameCompressionSettings cs; + private final ArrayCompressionStatistics[] stats; + private final Array[] compressedColumns; + + private CompressedFrameBlockFactory(FrameBlock fb, FrameCompressionSettings cs) { + this.in = fb; + this.cs = cs; + this.stats = new ArrayCompressionStatistics[in.getNumColumns()]; + this.compressedColumns = new Array[in.getNumColumns()]; + } + + public static FrameBlock compress(FrameBlock fb) { + FrameCompressionSettings cs = new FrameCompressionSettingsBuilder().create(); + return compress(fb, cs); + } + + public static FrameBlock compress(FrameBlock fb, int k, WTreeRoot root) { + FrameCompressionSettings cs = new FrameCompressionSettingsBuilder()// + .threads(k).wTreeRoot(root).create(); + return compress(fb, cs); + } + + public static FrameBlock compress(FrameBlock fb, FrameCompressionSettingsBuilder csb) { + return compress(fb, csb.create()); + } + + public static FrameBlock compress(FrameBlock fb, FrameCompressionSettings cs) { + return new 
CompressedFrameBlockFactory(fb, cs).compressFrame(); + } + + private FrameBlock compressFrame() { + extractStatistics(); + logStatistics(); + encodeColumns(); + final FrameBlock ret = new FrameBlock(compressedColumns, in.getColumnNames(false)); + logRet(ret); + return ret; + } + + private void extractStatistics() { + final int nSamples = Math.min(in.getNumRows(), (int) Math.ceil(in.getNumRows() * cs.sampleRatio)); + for(int i = 0; i < stats.length; i++) { + stats[i] = in.getColumn(i).statistics(nSamples); + } + } + + private void encodeColumns() { + for(int i = 0; i < compressedColumns.length; i++) { + if(stats[i] != null) { + // commented out because no other encodings are supported yet + // switch(stats[i].bestType) { + // case DDC: + compressedColumns[i] = DDCArray.compressToDDC(in.getColumn(i)); + // break; + // default: + // compressedColumns[i] = in.getColumn(i); + // break; + // } + } + else + compressedColumns[i] = in.getColumn(i); + } + } + + private void logStatistics() { + if(LOG.isDebugEnabled()) { + for(int i = 0; i < compressedColumns.length; i++) { + if(stats[i] != null) + LOG.debug(stats[i]); + else + LOG.debug("no Comp col: " + i); + } + } + } + + private void logRet(FrameBlock ret) { + if(LOG.isDebugEnabled()) { + final long before = in.getInMemorySize(); + final long after = ret.getInMemorySize(); + LOG.debug(String.format("Uncompressed Size: %15d", before)); + LOG.debug(String.format("compressed Size: %15d", after)); + LOG.debug(String.format("ratio: %15.3f", (double) before / (double) after)); + } + } + +} diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/compress/FrameCompressionStatistics.java b/src/main/java/org/apache/sysds/runtime/frame/data/compress/FrameCompressionSettings.java similarity index 72% rename from src/main/java/org/apache/sysds/runtime/frame/data/compress/FrameCompressionStatistics.java rename to src/main/java/org/apache/sysds/runtime/frame/data/compress/FrameCompressionSettings.java index 
c235995e1e9..84a23bf6480 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/compress/FrameCompressionStatistics.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/compress/FrameCompressionSettings.java @@ -16,8 +16,21 @@ * specific language governing permissions and limitations * under the License. */ + package org.apache.sysds.runtime.frame.data.compress; -public class FrameCompressionStatistics { +import org.apache.sysds.runtime.compress.workload.WTreeRoot; + +public class FrameCompressionSettings { + + public final float sampleRatio; + public final int k; + public final WTreeRoot wt; + + protected FrameCompressionSettings(float sampleRatio, int k, WTreeRoot wt) { + this.sampleRatio = sampleRatio; + this.k = k; + this.wt = wt; + } } diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/compress/FrameCompressionSettingsBuilder.java b/src/main/java/org/apache/sysds/runtime/frame/data/compress/FrameCompressionSettingsBuilder.java new file mode 100644 index 00000000000..936cd42898d --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/frame/data/compress/FrameCompressionSettingsBuilder.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysds.runtime.frame.data.compress; + +import org.apache.sysds.runtime.compress.workload.WTreeRoot; + +public class FrameCompressionSettingsBuilder { + + private float sampleRatio; + private int k; + private WTreeRoot wt; + + public FrameCompressionSettingsBuilder() { + this.sampleRatio = 0.1f; + this.k = 1; + this.wt = null; + } + + public FrameCompressionSettingsBuilder wTreeRoot(WTreeRoot wt) { + this.wt = wt; + return this; + } + + public FrameCompressionSettingsBuilder threads(int k) { + this.k = k; + return this; + } + + public FrameCompressionSettingsBuilder sampleRatio(float sampleRatio) { + this.sampleRatio = sampleRatio; + return this; + } + + public FrameCompressionSettings create() { + return new FrameCompressionSettings(sampleRatio, k, wt); + } +} diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameLibCompress.java b/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameLibCompress.java index 584462200d7..207ece8d264 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameLibCompress.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameLibCompress.java @@ -18,19 +18,17 @@ */ package org.apache.sysds.runtime.frame.data.lib; -import org.apache.commons.lang3.tuple.ImmutablePair; -import org.apache.commons.lang3.tuple.Pair; import org.apache.sysds.runtime.compress.workload.WTreeRoot; import org.apache.sysds.runtime.frame.data.FrameBlock; -import org.apache.sysds.runtime.frame.data.compress.FrameCompressionStatistics; +import org.apache.sysds.runtime.frame.data.compress.CompressedFrameBlockFactory; public class FrameLibCompress { - public static Pair compress(FrameBlock in, int k) { + public static FrameBlock compress(FrameBlock in, int k) { return compress(in, k, null); } - public static Pair compress(FrameBlock in, int k, WTreeRoot root) { - return new ImmutablePair<>(in, new FrameCompressionStatistics()); + public static FrameBlock compress(FrameBlock in, int k, WTreeRoot 
root) { + return CompressedFrameBlockFactory.compress(in, k, root); } } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/CompressionCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/CompressionCPInstruction.java index 38b53e090c4..22766079e6e 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/CompressionCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/CompressionCPInstruction.java @@ -29,7 +29,6 @@ import org.apache.sysds.runtime.compress.workload.WTreeRoot; import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; import org.apache.sysds.runtime.frame.data.FrameBlock; -import org.apache.sysds.runtime.frame.data.compress.FrameCompressionStatistics; import org.apache.sysds.runtime.frame.data.lib.FrameLibCompress; import org.apache.sysds.runtime.instructions.InstructionUtils; import org.apache.sysds.runtime.matrix.data.MatrixBlock; @@ -89,12 +88,9 @@ private void processMatrixBlockCompression(ExecutionContext ec, MatrixBlock in, } private void processFrameBlockCompression(ExecutionContext ec, FrameBlock in, int k, WTreeRoot root) { - Pair compResult = FrameLibCompress.compress(in, k, root); - if(LOG.isTraceEnabled()) - LOG.trace(compResult.getRight()); - FrameBlock out = compResult.getLeft(); + FrameBlock compResult = FrameLibCompress.compress(in, k, root); // Set output and release input ec.releaseFrameInput(input1.getName()); - ec.setFrameOutput(output.getName(), out); + ec.setFrameOutput(output.getName(), compResult); } } diff --git a/src/main/java/org/apache/sysds/runtime/io/FrameWriterCompressed.java b/src/main/java/org/apache/sysds/runtime/io/FrameWriterCompressed.java index 70b6e89a9a1..82c5a08e2c0 100644 --- a/src/main/java/org/apache/sysds/runtime/io/FrameWriterCompressed.java +++ b/src/main/java/org/apache/sysds/runtime/io/FrameWriterCompressed.java @@ -40,7 +40,7 @@ public FrameWriterCompressed(boolean parallel) { protected void 
writeBinaryBlockFrameToHDFS(Path path, JobConf job, FrameBlock src, long rlen, long clen) throws IOException, DMLRuntimeException { int k = parallel ? OptimizerUtils.getParallelBinaryWriteParallelism() : 1; - FrameBlock compressed = FrameLibCompress.compress(src, k).getLeft(); + FrameBlock compressed = FrameLibCompress.compress(src, k); super.writeBinaryBlockFrameToHDFS(path, job, compressed, rlen, clen); } diff --git a/src/test/java/org/apache/sysds/test/component/frame/array/FrameArrayTests.java b/src/test/java/org/apache/sysds/test/component/frame/array/FrameArrayTests.java index 81908f8d39c..a5acb133816 100644 --- a/src/test/java/org/apache/sysds/test/component/frame/array/FrameArrayTests.java +++ b/src/test/java/org/apache/sysds/test/component/frame/array/FrameArrayTests.java @@ -1995,6 +1995,7 @@ public static String[] generateRandomStringNUniqueLength(int size, int seed, int } public static String[] generateRandomStringNUniqueLengthOpt(int size, int seed, int nUnique, int stringLength) { + nUnique = Math.max(1, nUnique); String[] rands = generateRandomStringLength(nUnique, seed, stringLength); rands[rands.length - 1] = null; Random r = new Random(seed + 1); diff --git a/src/test/java/org/apache/sysds/test/component/frame/compress/FrameCompressTest.java b/src/test/java/org/apache/sysds/test/component/frame/compress/FrameCompressTest.java index bdf6038550d..fc4e69d7525 100644 --- a/src/test/java/org/apache/sysds/test/component/frame/compress/FrameCompressTest.java +++ b/src/test/java/org/apache/sysds/test/component/frame/compress/FrameCompressTest.java @@ -19,12 +19,63 @@ package org.apache.sysds.test.component.frame.compress; -import org.apache.sysds.runtime.frame.data.compress.FrameCompressionStatistics; +import static org.junit.Assert.fail; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.sysds.runtime.frame.data.FrameBlock; +import org.apache.sysds.runtime.frame.data.columns.Array; +import 
org.apache.sysds.runtime.frame.data.columns.ArrayFactory; +import org.apache.sysds.runtime.frame.data.compress.CompressedFrameBlockFactory; +import org.apache.sysds.runtime.frame.data.compress.FrameCompressionSettings; +import org.apache.sysds.runtime.frame.data.lib.FrameLibCompress; +import org.apache.sysds.test.TestUtils; +import org.apache.sysds.test.component.frame.array.FrameArrayTests; import org.junit.Test; public class FrameCompressTest { + protected static final Log LOG = LogFactory.getLog(FrameCompressTest.class.getName()); + + @Test + public void testSingleThread() { + FrameBlock a = generateCompressableBlock(200, 5, 1232); + runTest(a, 1); + } + @Test - public void testCompressionStatisticsConstruction() { - new FrameCompressionStatistics(); + public void testParallel() { + FrameBlock a = generateCompressableBlock(200, 5, 1232); + runTest(a, 4); + } + + public void runTest(FrameBlock a, int k) { + try { + FrameBlock b = FrameLibCompress.compress(a, k); + TestUtils.compareFrames(a, b, true); + } + catch(Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + public void runTestConfig(FrameBlock a, FrameCompressionSettings cs) { + try { + FrameBlock b = CompressedFrameBlockFactory.compress(a, cs); + TestUtils.compareFrames(a, b, true); + } + catch(Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + private FrameBlock generateCompressableBlock(int rows, int cols, int seed) { + Array[] data = new Array[cols]; + for(int i = 0; i < cols; i++) { + data[i] = ArrayFactory.create(// + FrameArrayTests.generateRandomStringNUniqueLengthOpt(rows, seed + i, i + 1, 55 + i)); + } + return new FrameBlock(data); } } diff --git a/src/test/java/org/apache/sysds/test/component/frame/compress/FrameCompressTestLogging.java b/src/test/java/org/apache/sysds/test/component/frame/compress/FrameCompressTestLogging.java new file mode 100644 index 00000000000..45de293d249 --- /dev/null +++ 
b/src/test/java/org/apache/sysds/test/component/frame/compress/FrameCompressTestLogging.java @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.component.frame.compress; + +import static org.junit.Assert.fail; + +import java.util.List; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.log4j.Level; +import org.apache.log4j.Logger; +import org.apache.log4j.spi.LoggingEvent; +import org.apache.sysds.runtime.frame.data.FrameBlock; +import org.apache.sysds.runtime.frame.data.columns.Array; +import org.apache.sysds.runtime.frame.data.columns.ArrayFactory; +import org.apache.sysds.runtime.frame.data.compress.CompressedFrameBlockFactory; +import org.apache.sysds.runtime.frame.data.lib.FrameLibCompress; +import org.apache.sysds.test.LoggingUtils; +import org.apache.sysds.test.LoggingUtils.TestAppender; +import org.apache.sysds.test.TestUtils; +import org.apache.sysds.test.component.frame.array.FrameArrayTests; +import org.junit.Test; + +public class FrameCompressTestLogging { + protected static final Log LOG = LogFactory.getLog(FrameCompressTestLogging.class.getName()); + + @Test + public void testCompressable() { 
+ testLogging(generateCompressableBlock(200, 3, 3214)); + } + + @Test + public void testUnCompressable() { + testLogging(generateIncompressableBlock(200, 3, 2321)); + } + + public void testLogging(FrameBlock a) { + final TestAppender appender = LoggingUtils.overwrite(); + try { + Logger.getLogger(CompressedFrameBlockFactory.class).setLevel(Level.TRACE); + + FrameBlock b = FrameLibCompress.compress(a, 1); + + TestUtils.compareFrames(a, b, true); + + final List<LoggingEvent> log = LoggingUtils.reinsert(appender); + for(LoggingEvent l : log) { + if(l.getMessage().toString().contains("ratio: ")) + return; + } + fail("Log did not contain Dictionary sizes"); + } + catch(Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + finally { + Logger.getLogger(CompressedFrameBlockFactory.class).setLevel(Level.WARN); + LoggingUtils.reinsert(appender); + } + + } + + private FrameBlock generateCompressableBlock(int rows, int cols, int seed) { + Array<?>[] data = new Array<?>[cols]; + for(int i = 0; i < cols; i++) { + data[i] = ArrayFactory.create(// + FrameArrayTests.generateRandomStringNUniqueLengthOpt(rows, seed + i, i + 1, 55 + i)); + } + return new FrameBlock(data); + } + + private FrameBlock generateIncompressableBlock(int rows, int cols, int seed) { + Array<?>[] data = new Array<?>[cols]; + for(int i = 0; i < cols; i++) { + data[i] = ArrayFactory.create(// + FrameArrayTests.generateRandomStringNUniqueLengthOpt(rows, seed + i, rows, 55 + i)); + } + return new FrameBlock(data); + } +} diff --git a/src/test/java/org/apache/sysds/test/functions/codegen/APICodegenTest.java b/src/test/java/org/apache/sysds/test/functions/codegen/APICodegenTest.java index f20ba75c26d..b316e2c5874 100644 --- a/src/test/java/org/apache/sysds/test/functions/codegen/APICodegenTest.java +++ b/src/test/java/org/apache/sysds/test/functions/codegen/APICodegenTest.java @@ -23,21 +23,20 @@ import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaSparkContext; -import org.junit.After; -import 
org.junit.Assert; -import org.junit.Test; import org.apache.sysds.api.DMLScript; import org.apache.sysds.api.jmlc.Connection; import org.apache.sysds.api.jmlc.PreparedScript; import org.apache.sysds.api.mlcontext.MLContext; import org.apache.sysds.api.mlcontext.Script; -import org.apache.sysds.conf.DMLConfig; import org.apache.sysds.conf.CompilerConfig.ConfigType; +import org.apache.sysds.conf.DMLConfig; import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.util.DataConverter; import org.apache.sysds.test.AutomatedTestBase; -import org.apache.sysds.utils.Statistics; +import org.junit.After; +import org.junit.Assert; +import org.junit.Test; public class APICodegenTest extends AutomatedTestBase @@ -85,12 +84,13 @@ private void runMLContextParforDatasetTest(boolean jmlc) pscript.setMatrix("X", mX, false); pscript.executeScript(); conn.close(); - System.out.println(Statistics.display()); + // System.out.println(Statistics.display()); } else { SparkConf conf = SparkExecutionContext.createSystemDSSparkConf() .setAppName("MLContextTest").setMaster("local"); JavaSparkContext sc = new JavaSparkContext(conf); + MLContext.welcomePrint = true; MLContext ml = new MLContext(sc); ml.setConfigProperty(DMLConfig.CODEGEN, "true"); ml.setStatistics(true);