From 21c0c07c3e95fc1d8e087bd658c70dc21f238aea Mon Sep 17 00:00:00 2001 From: baunsgaard Date: Wed, 21 Dec 2022 19:40:47 +0100 Subject: [PATCH 1/5] [SYSTEMDS-3481] FrameFromMatrix Improvements This commit introduces various updates and refinements to the FrameBlock infrastructure. Specifically, the modification and conversion of MatrixBlock to FrameBlock are optimized. In this process the parallelization of the instructions is critical; therefore this commit also contains a larger change to the Unary instructions so that they all now contain a thread count in the instruction string. This change also affects instructions that do not necessarily need a thread count, such as broadcast, but it gave the opportunity to refine the applySchema, toMatrix, toFrame, and other instructions to run in parallel. Example: changing a 64k x 2k matrix to a boolean frame took 2.2 sec single-threaded before; after, it takes 0.9 sec single-threaded and 0.13 sec in parallel. The reading time of frames is also improved: before, the time varied drastically depending on the block size saved; it is now improved from 0.56 sec to 0.13 sec with a block size of 500. A final update that also improves overall execution concerns compile time. I observed that the compile time increases to 0.6 sec if we include IO operations, while it is 0.3 sec if we do not have IO operations. This is due to the Hadoop IO we are using, which takes up to 70% of the compile time in cases where we have simple scripts with only a read and a single operation. This is a constant overhead on the first IO operation that does not affect subsequent IO operations. To improve this, I have moved it to a parallel operation when we construct the JobConfiguration. This improves the compile time of SystemDS in general from ~0.6 sec when using IO to ~0.2 sec. 
--- .../org/apache/sysds/hops/AggUnaryOp.java | 6 +- .../java/org/apache/sysds/hops/BinaryOp.java | 2 +- src/main/java/org/apache/sysds/hops/Hop.java | 2 +- .../org/apache/sysds/hops/LeftIndexingOp.java | 3 +- .../java/org/apache/sysds/hops/UnaryOp.java | 25 +- src/main/java/org/apache/sysds/lops/Lop.java | 2 +- .../java/org/apache/sysds/lops/Unary.java | 14 +- .../java/org/apache/sysds/lops/UnaryCP.java | 19 +- .../dictionary/DictLibMatrixMult.java | 1 - .../context/ExecutionContext.java | 2 +- .../sysds/runtime/frame/data/FrameBlock.java | 809 +++++++++--------- .../runtime/frame/data/columns/Array.java | 259 +++++- .../frame/data/columns/ArrayFactory.java | 10 +- .../frame/data/columns/BitSetArray.java | 307 ++++--- .../frame/data/columns/BooleanArray.java | 105 ++- .../frame/data/columns/DoubleArray.java | 78 +- .../frame/data/columns/FloatArray.java | 76 +- .../frame/data/columns/IntegerArray.java | 85 +- .../runtime/frame/data/columns/LongArray.java | 83 +- .../frame/data/columns/StringArray.java | 127 +-- .../frame/data/lib/FrameFromMatrixBlock.java | 408 +++++++-- .../frame/data/lib/FrameLibApplySchema.java | 14 +- .../frame/data/lib/FrameLibDetectSchema.java | 58 +- .../instructions/CPInstructionParser.java | 21 +- .../instructions/InstructionUtils.java | 25 +- .../instructions/cp/BinaryCPInstruction.java | 25 +- .../cp/BinaryFrameFrameCPInstruction.java | 2 +- .../cp/BinaryScalarScalarCPInstruction.java | 3 + .../cp/BroadcastCPInstruction.java | 2 +- .../instructions/cp/DataGenCPInstruction.java | 10 +- .../cp/PrefetchCPInstruction.java | 3 +- .../instructions/cp/UnaryCPInstruction.java | 21 + .../cp/UnaryFrameCPInstruction.java | 8 +- .../cp/VariableCPInstruction.java | 60 +- .../spark/UnaryFrameSPInstruction.java | 2 +- .../sysds/runtime/io/FrameWriterTextCSV.java | 7 +- .../runtime/matrix/data/LibMatrixReorg.java | 3 - .../matrix/operators/UnaryOperator.java | 4 + .../sysds/runtime/util/DataConverter.java | 54 +- 
.../sysds/runtime/util/UtilFunctions.java | 84 ++ src/main/python/systemds/utils/converters.py | 8 +- .../java/org/apache/sysds/test/TestUtils.java | 84 +- .../AbstractCompressedUnaryTests.java | 3 - .../compress/colgroup/JolEstimateRLETest.java | 3 - .../frame/FrameFromMatrixBlockTest.java | 174 +++- .../frame/array/CustomArrayTests.java | 21 + .../frame/array/FrameArrayConstantTests.java | 123 +++ .../frame/array/FrameArrayTests.java | 437 +++++++++- .../frame/array/NegativeArrayTests.java | 27 + .../functions/frame/FrameConstructorTest.java | 28 +- 50 files changed, 2732 insertions(+), 1005 deletions(-) create mode 100644 src/test/java/org/apache/sysds/test/component/frame/array/FrameArrayConstantTests.java diff --git a/src/main/java/org/apache/sysds/hops/AggUnaryOp.java b/src/main/java/org/apache/sysds/hops/AggUnaryOp.java index 23439b182e8..ac4e018a5f9 100644 --- a/src/main/java/org/apache/sysds/hops/AggUnaryOp.java +++ b/src/main/java/org/apache/sysds/hops/AggUnaryOp.java @@ -138,7 +138,7 @@ else if( et != ExecType.FED && isUnaryAggregateOuterCPRewriteApplicable() ) if (getDataType() == DataType.SCALAR) { UnaryCP unary1 = new UnaryCP(agg1, OpOp1.CAST_AS_SCALAR, - getDataType(), getValueType()); + getDataType(), getValueType(), 1); unary1.getOutputParameters().setDimensions(0, 0, 0, -1); setLineNumbers(unary1); agg1 = unary1; @@ -180,7 +180,7 @@ else if( isUnaryAggregateOuterSPRewriteApplicable() ) if (getDataType() == DataType.SCALAR) { UnaryCP unary1 = new UnaryCP(transform1, - OpOp1.CAST_AS_SCALAR, getDataType(), getValueType()); + OpOp1.CAST_AS_SCALAR, getDataType(), getValueType(), 1); unary1.getOutputParameters().setDimensions(0, 0, 0, -1); setLineNumbers(unary1); setLops(unary1); @@ -200,7 +200,7 @@ else if( isUnaryAggregateOuterSPRewriteApplicable() ) if (getDataType() == DataType.SCALAR) { UnaryCP unary1 = new UnaryCP(aggregate, - OpOp1.CAST_AS_SCALAR, getDataType(), getValueType()); + OpOp1.CAST_AS_SCALAR, getDataType(), getValueType(), 1); 
unary1.getOutputParameters().setDimensions(0, 0, 0, -1); setLineNumbers(unary1); setLops(unary1); diff --git a/src/main/java/org/apache/sysds/hops/BinaryOp.java b/src/main/java/org/apache/sysds/hops/BinaryOp.java index 2346eeebfe6..549bf53e33d 100644 --- a/src/main/java/org/apache/sysds/hops/BinaryOp.java +++ b/src/main/java/org/apache/sysds/hops/BinaryOp.java @@ -453,7 +453,7 @@ op, getDataType(), getValueType(), et, getInput().get(0) == getInput().get(1).getInput().get(0); if(isGPUSoftmax) { UnaryCP softmax = new UnaryCP(getInput().get(0).getInput().get(0).constructLops(), - OpOp1.SOFTMAX, getDataType(), getValueType(), et); + OpOp1.SOFTMAX, getDataType(), getValueType(), et, OptimizerUtils.getConstrainedNumThreads(_maxNumThreads)); setOutputDimensions(softmax); setLineNumbers(softmax); setLops(softmax); diff --git a/src/main/java/org/apache/sysds/hops/Hop.java b/src/main/java/org/apache/sysds/hops/Hop.java index fa911749ee7..3119208f84b 100644 --- a/src/main/java/org/apache/sysds/hops/Hop.java +++ b/src/main/java/org/apache/sysds/hops/Hop.java @@ -65,7 +65,7 @@ import org.apache.sysds.runtime.util.UtilFunctions; public abstract class Hop implements ParseInfo { - private static final Log LOG = LogFactory.getLog(Hop.class.getName()); + protected static final Log LOG = LogFactory.getLog(Hop.class.getName()); public static final long CPThreshold = 2000; diff --git a/src/main/java/org/apache/sysds/hops/LeftIndexingOp.java b/src/main/java/org/apache/sysds/hops/LeftIndexingOp.java index 79d863adf9c..ed8cc192d0f 100644 --- a/src/main/java/org/apache/sysds/hops/LeftIndexingOp.java +++ b/src/main/java/org/apache/sysds/hops/LeftIndexingOp.java @@ -122,9 +122,10 @@ public Lop constructLops() //insert cast to matrix if necessary (for reuse broadcast runtime) Lop rightInput = right.constructLops(); if (isRightHandSideScalar()) { + // one thread because it is cast from scalar. 
rightInput = new UnaryCP(rightInput, (left.getDataType()==DataType.MATRIX?OpOp1.CAST_AS_MATRIX:OpOp1.CAST_AS_FRAME), - left.getDataType(), right.getValueType()); + left.getDataType(), right.getValueType(), 1); long bsize = ConfigurationManager.getBlocksize(); rightInput.getOutputParameters().setDimensions( 1, 1, bsize, -1); } diff --git a/src/main/java/org/apache/sysds/hops/UnaryOp.java b/src/main/java/org/apache/sysds/hops/UnaryOp.java index b250ce2c1b9..97b3b4096df 100644 --- a/src/main/java/org/apache/sysds/hops/UnaryOp.java +++ b/src/main/java/org/apache/sysds/hops/UnaryOp.java @@ -130,35 +130,31 @@ public Lop constructLops() //reuse existing lop if( getLops() != null ) return getLops(); - int k; + final int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads); try { Hop input = getInput().get(0); Lop ret = null; switch(_op){ case COMPRESS: - k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads); ret = new Compression(input.constructLops(), getDataType(), getValueType(), optFindExecType(), 0); break; case DECOMPRESS: - k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads); ret = new DeCompression(input.constructLops(), getDataType(), getValueType(), optFindExecType()); break; case LOCAL: ret = new Local(input.constructLops(), getDataType(), getValueType()); break; default: + final boolean isScalarIn = getInput().get(0).getDataType() == DataType.SCALAR; if(getDataType() == DataType.SCALAR // value type casts or matrix to scalar - || (_op == OpOp1.CAST_AS_MATRIX && getInput().get(0).getDataType() == DataType.SCALAR) || - (_op == OpOp1.CAST_AS_FRAME && getInput().get(0).getDataType() == DataType.SCALAR)) { - if(_op == OpOp1.IQM) { // special handling IQM + || (_op == OpOp1.CAST_AS_MATRIX && isScalarIn) // cast matrix + || (_op == OpOp1.CAST_AS_FRAME && isScalarIn)) { // cast frame + if(_op == OpOp1.IQM) // special handling IQM ret = constructLopsIQM(); - } - else if(_op == OpOp1.MEDIAN) { + else if(_op == OpOp1.MEDIAN) ret = 
constructLopsMedian(); - } - else { // general case SCALAR/CAST (always in CP) - ret = new UnaryCP(input.constructLops(), _op, getDataType(), getValueType()); - } + else // general case SCALAR/CAST (always in CP) & always single threaded + ret = new UnaryCP(input.constructLops(), _op, getDataType(), getValueType(), k); } else { // general case MATRIX ExecType et = optFindExecType(); @@ -168,13 +164,10 @@ else if(_op == OpOp1.MEDIAN) { // TODO additional physical operation if offsets fit in memory ret = constructLopsSparkCumulativeUnary(); } - else // default unary - { + else {// default unary final boolean inplace = OptimizerUtils.ALLOW_UNARY_UPDATE_IN_PLACE && input.getParent().size() == 1 && (!(input instanceof DataOp) || !((DataOp) input).isRead()); - k = isCumulativeUnaryOperation() || isExpensiveUnaryOperation() ? - OptimizerUtils.getConstrainedNumThreads(_maxNumThreads) : 1; ret = new Unary(input.constructLops(), _op, getDataType(), getValueType(), et, k, inplace); } } diff --git a/src/main/java/org/apache/sysds/lops/Lop.java b/src/main/java/org/apache/sysds/lops/Lop.java index 37c2097fcba..ecc9e7f893e 100644 --- a/src/main/java/org/apache/sysds/lops/Lop.java +++ b/src/main/java/org/apache/sysds/lops/Lop.java @@ -38,7 +38,7 @@ public abstract class Lop { - private static final Log LOG = LogFactory.getLog(Lop.class.getName()); + protected static final Log LOG = LogFactory.getLog(Lop.class.getName()); public enum Type { Data, DataGen, //CP/MR read/write/datagen diff --git a/src/main/java/org/apache/sysds/lops/Unary.java b/src/main/java/org/apache/sysds/lops/Unary.java index e68d060e71b..c5323325f11 100644 --- a/src/main/java/org/apache/sysds/lops/Unary.java +++ b/src/main/java/org/apache/sysds/lops/Unary.java @@ -164,14 +164,20 @@ public String getInstructions(String input1, String output) { sb.append( OPERAND_DELIMITOR ); sb.append( prepOutputOperand(output) ); - //num threads for cumulative cp ops - if( (getExecType() == ExecType.CP || getExecType() == 
ExecType.FED) && isMultiThreadedOp(operation) ) { + if( getExecType() == ExecType.CP || getExecType() == ExecType.FED){ sb.append( OPERAND_DELIMITOR ); sb.append( _numThreads ); - sb.append( OPERAND_DELIMITOR ); - sb.append( _inplace ); + if( isMultiThreadedOp(operation)){ + + sb.append( OPERAND_DELIMITOR ); + sb.append( _inplace ); + } } + // //num threads for cumulative cp ops + // if( (getExecType() == ExecType.CP || getExecType() == ExecType.FED) && isMultiThreadedOp(operation) ) { + // } + appendFedOut(sb); return sb.toString(); diff --git a/src/main/java/org/apache/sysds/lops/UnaryCP.java b/src/main/java/org/apache/sysds/lops/UnaryCP.java index 9c2a67f94bc..09ce9dd5681 100644 --- a/src/main/java/org/apache/sysds/lops/UnaryCP.java +++ b/src/main/java/org/apache/sysds/lops/UnaryCP.java @@ -29,6 +29,7 @@ public class UnaryCP extends Lop { private OpOp1 operation; + private int _numThreads = 1; /** * Constructor to perform a scalar operation @@ -38,22 +39,31 @@ public class UnaryCP extends Lop * @param dt data type of the output * @param vt value type of the output * @param et exec type + * @param k parallelization degree */ - public UnaryCP(Lop input, OpOp1 op, DataType dt, ValueType vt, ExecType et) { + public UnaryCP(Lop input, OpOp1 op, DataType dt, ValueType vt, ExecType et, int k) { super(Lop.Type.UnaryCP, dt, vt); operation = op; addInput(input); input.addOutput(this); lps.setProperties(inputs, et); } + + public UnaryCP(Lop input, OpOp1 op, DataType dt, ValueType vt, ExecType et) { + this(input, op, dt, vt, et, 1); + } + public UnaryCP(Lop input, OpOp1 op, DataType dt, ValueType vt, int k) { + this(input, op, dt, vt, ExecType.CP, k); + } + public UnaryCP(Lop input, OpOp1 op, DataType dt, ValueType vt) { - this(input, op, dt, vt, ExecType.CP); + this(input, op, dt, vt, ExecType.CP, 1); } @Override public String toString() { - return "Operation: " + operation; + return "Operation: " + getInstructions("", ""); } private String getOpCode() { @@ -65,6 +75,7 @@ 
public String getInstructions(String input, String output) { return InstructionUtils.concatOperands( getExecType().name(), getOpCode(), getInputs().get(0).prepScalarInputOperand(getExecType()), - prepOutputOperand(output)); + prepOutputOperand(output), + Integer.toString(_numThreads)); } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DictLibMatrixMult.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DictLibMatrixMult.java index 2b2e41600a6..3d59f854e32 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DictLibMatrixMult.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DictLibMatrixMult.java @@ -426,7 +426,6 @@ protected static void MMToUpperTriangleSparseDenseDiagonal(SparseBlock left, dou protected static void MMToUpperTriangleDenseDense(double[] left, double[] right, int[] rowsLeft, int[] colsRight, MatrixBlock result) { final int loc = location(rowsLeft, colsRight); - // LOG.error("loc:" + loc); if(loc < 0) MMToUpperTriangleDenseDenseAllUpperTriangle(left, right, rowsLeft, colsRight, result); else if(loc > 0) diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java b/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java index f075eea78aa..a39a04166be 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/context/ExecutionContext.java @@ -933,7 +933,7 @@ public Set getTmpParforFunctions() { @Override public String toString(){ StringBuilder sb = new StringBuilder(); - sb.append(super.toString()); + sb.append(this.getClass().getSimpleName().toString()); if(_prog != null) sb.append("\nProgram: " + _prog.toString()); if(_variables != null) diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java 
b/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java index bf7ce6bfb71..da69265c143 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java @@ -51,7 +51,6 @@ import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.codegen.CodegenUtils; import org.apache.sysds.runtime.controlprogram.caching.CacheBlock; -import org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer; import org.apache.sysds.runtime.controlprogram.parfor.util.IDSequence; import org.apache.sysds.runtime.frame.data.columns.Array; import org.apache.sysds.runtime.frame.data.columns.ArrayFactory; @@ -79,7 +78,7 @@ import org.apache.sysds.runtime.util.UtilFunctions; import org.apache.sysds.utils.MemoryEstimates; -@SuppressWarnings({"rawtypes","unchecked"}) //allow generic native arrays +@SuppressWarnings({"rawtypes", "unchecked"}) // allow generic native arrays public class FrameBlock implements CacheBlock, Externalizable { private static final long serialVersionUID = -3993450030207130665L; private static final Log LOG = LogFactory.getLog(FrameBlock.class.getName()); @@ -100,7 +99,7 @@ public class FrameBlock implements CacheBlock, Externalizable { /** The column names of the data frame as an ordered list of strings, allocated on-demand */ private String[] _colnames = null; - /** The column metadata */ + /** The column metadata */ private ColumnMetadata[] _colmeta = null; /** The data frame data as an ordered list of columns */ @@ -114,9 +113,8 @@ public FrameBlock() { } /** - * Copy constructor for frame blocks, which uses a shallow copy for - * the schema (column types and names) but a deep copy for meta data - * and actual column data. + * Copy constructor for frame blocks, which uses a shallow copy for the schema (column types and names) but a deep + * copy for meta data and actual column data. 
* * @param that frame block */ @@ -139,25 +137,40 @@ public FrameBlock(ValueType[] schema, String[] names) { } public FrameBlock(ValueType[] schema, String[][] data) { - //default column names not materialized + // default column names not materialized this(schema, null, data); } + /** + * FrameBlock constructor with constant + * + * @param schema The schema to allocate (also specifying number of columns) + * @param constant The constant to allocate in all cells + * @param nRow the number of rows + */ + public FrameBlock(ValueType[] schema, String constant, int nRow) { + this(); + // allocate the values. + + for(int i = 0; i < schema.length; i++) + appendColumn(ArrayFactory.allocate(schema[i], nRow, constant)); + } + public FrameBlock(ValueType[] schema, String[] names, String[][] data) { - _numRows = 0; //maintained on append + _numRows = 0; // maintained on append _schema = schema; _colnames = names; ensureAllocateMeta(); if(data != null) - for( int i=0; i[] data ){ + public FrameBlock(ValueType[] schema, String[] colNames, ColumnMetadata[] meta, Array[] data) { _numRows = data[0].size(); _schema = schema; _colnames = colNames; - _colmeta = meta; + _colmeta = meta; _coldata = data; } @@ -202,8 +215,7 @@ public void setNumRows(int numRows) { } /** - * Get the number of columns of the frame block, that is - * the number of columns defined in the schema. + * Get the number of columns of the frame block, that is the number of columns defined in the schema. * * @return number of columns */ @@ -236,8 +248,7 @@ public void setSchema(ValueType[] schema) { } /** - * Returns the column names of the frame block. This method - * allocates default column names if required. + * Returns the column names of the frame block. This method allocates default column names if required. * * @return column names */ @@ -252,27 +263,25 @@ public FrameBlock getColumnNamesAsFrame() { } /** - * Returns the column names of the frame block. 
This method - * allocates default column names if required. + * Returns the column names of the frame block. This method allocates default column names if required. * * @param alloc if true, create column names * @return array of column names */ public String[] getColumnNames(boolean alloc) { - if( _colnames == null && alloc ) + if(_colnames == null && alloc) _colnames = createColNames(getNumColumns()); return _colnames; } /** - * Returns the column name for the requested column. This - * method allocates default column names if required. + * Returns the column name for the requested column. This method allocates default column names if required. * * @param c column index * @return column name */ public String getColumnName(int c) { - if( _colnames == null ) + if(_colnames == null) _colnames = createColNames(getNumColumns()); return _colnames[c]; } @@ -289,13 +298,13 @@ public ColumnMetadata getColumnMetadata(int c) { return _colmeta[c]; } - public Array[] getColumns(){ + public Array[] getColumns() { return _coldata; } public boolean isColumnMetadataDefault() { boolean ret = true; - for( int j=0; j getColumnNameIDMap() { + public Map getColumnNameIDMap() { Map ret = new HashMap<>(); - for( int j=0; j 0 && _numRows != newLen ) - throw new RuntimeException("Mismatch in number of rows: "+newLen+" (expected: "+_numRows+")"); + if(_coldata != null && _coldata.length > 0 && _numRows != newLen) + throw new RuntimeException("Mismatch in number of rows: " + newLen + " (expected: " + _numRows + ")"); _numRows = newLen; } @@ -384,8 +392,8 @@ public static String[] createColNames(int size) { public static String[] createColNames(int off, int size) { String[] ret = new String[size]; - for( int i=off+1; i<=off+size; i++ ) - ret[i-off-1] = createColName(i); + for(int i = off + 1; i <= off + size; i++) + ret[i - off - 1] = createColName(i); return ret; } @@ -395,20 +403,19 @@ public static String createColName(int i) { public boolean isColNamesDefault() { boolean ret = (_colnames 
!= null); - for( int j=0; j getColumn(int c) { } public void setColumn(int c, Array column) { - if( _coldata == null ) + if(_coldata == null) _coldata = new Array[getNumColumns()]; _coldata[c] = column; _msize = -1; @@ -688,7 +686,8 @@ public void readFields(DataInput in) throws IOException { if(!isDefaultMeta) { // If not default meta read in meta _colnames[j] = in.readUTF(); _colmeta[j] = ColumnMetadata.read(in); - }else{ + } + else { _colmeta[j] = new ColumnMetadata(); // must be allocated. } if(type >= 0) // if in allocated column data then read it @@ -699,13 +698,13 @@ public void readFields(DataInput in) throws IOException { @Override public void writeExternal(ObjectOutput out) throws IOException { - //redirect serialization to writable impl + // redirect serialization to writable impl write(out); } @Override public void readExternal(ObjectInput in) throws IOException { - //redirect deserialization to writable impl + // redirect deserialization to writable impl readFields(in); } @@ -773,40 +772,40 @@ public boolean isShallowSerialize() { @Override public boolean isShallowSerialize(boolean inclConvert) { - //shallow serialize if non-string schema because a frame block - //is always dense but strings have large array overhead per cell + // shallow serialize if non-string schema because a frame block + // is always dense but strings have large array overhead per cell boolean ret = true; - for( int j=0; j<_schema.length && ret; j++ ) + for(int j = 0; j < _schema.length && ret; j++) ret &= (_schema[j] != ValueType.STRING); return ret; } @Override public void toShallowSerializeBlock() { - //do nothing (not applicable). + // do nothing (not applicable). 
} @Override public void compactEmptyBlock() { - //do nothing + // do nothing } /** - * This method performs the value comparison on two frames - * if the values in both frames are equal, not equal, less than, greater than, less than/greater than and equal to - * the output frame will store boolean value for each each comparison + * This method performs the value comparison on two frames if the values in both frames are equal, not equal, less + * than, greater than, less than/greater than and equal to the output frame will store boolean value for each each + * comparison * - * @param bop binary operator - * @param that frame block of rhs of m * n dimensions - * @param out output frame block - * @return a boolean frameBlock + * @param bop binary operator + * @param that frame block of rhs of m * n dimensions + * @param out output frame block + * @return a boolean frameBlock */ public FrameBlock binaryOperations(BinaryOperator bop, FrameBlock that, FrameBlock out) { if(getNumColumns() != that.getNumColumns() && getNumRows() != that.getNumColumns()) - throw new DMLRuntimeException("Frame dimension mismatch "+getNumRows()+" * "+getNumColumns()+ - " != "+that.getNumRows()+" * "+that.getNumColumns()); + throw new DMLRuntimeException("Frame dimension mismatch " + getNumRows() + " * " + getNumColumns() + " != " + + that.getNumRows() + " * " + that.getNumColumns()); String[][] outputData = new String[getNumRows()][getNumColumns()]; - //compare output value, incl implicit type promotion if necessary + // compare output value, incl implicit type promotion if necessary if(bop.fn instanceof ValueComparisonFunction) { ValueComparisonFunction vcomp = (ValueComparisonFunction) bop.fn; out = executeValueComparisons(this, that, vcomp, outputData); @@ -829,9 +828,8 @@ private FrameBlock executeValueComparisons(FrameBlock frameBlock, FrameBlock tha outputData[j][i] = String.valueOf(vcomp.compare(v1, v2)); } } - else if(getSchema()[i] == ValueType.FP64 || that - .getSchema()[i] == 
ValueType.FP64 || getSchema()[i] == ValueType.FP32 || that - .getSchema()[i] == ValueType.FP32) { + else if(getSchema()[i] == ValueType.FP64 || that.getSchema()[i] == ValueType.FP64 || + getSchema()[i] == ValueType.FP32 || that.getSchema()[i] == ValueType.FP32) { for(int j = 0; j < getNumRows(); j++) { if(checkAndSetEmpty(frameBlock, that, outputData, j, i)) continue; @@ -840,9 +838,8 @@ else if(getSchema()[i] == ValueType.FP64 || that outputData[j][i] = String.valueOf(vcomp.compare(so1.getDoubleValue(), so2.getDoubleValue())); } } - else if(getSchema()[i] == ValueType.INT64 || that - .getSchema()[i] == ValueType.INT64 || getSchema()[i] == ValueType.INT32 || that - .getSchema()[i] == ValueType.INT32) { + else if(getSchema()[i] == ValueType.INT64 || that.getSchema()[i] == ValueType.INT64 || + getSchema()[i] == ValueType.INT32 || that.getSchema()[i] == ValueType.INT32) { for(int j = 0; j < this.getNumRows(); j++) { if(checkAndSetEmpty(frameBlock, that, outputData, j, i)) continue; @@ -876,29 +873,28 @@ private static boolean checkAndSetEmpty(FrameBlock fb1, FrameBlock fb2, String[] // indexing and append operations public FrameBlock leftIndexingOperations(FrameBlock rhsFrame, IndexRange ixrange, FrameBlock ret) { - return leftIndexingOperations(rhsFrame, - (int)ixrange.rowStart, (int)ixrange.rowEnd, - (int)ixrange.colStart, (int)ixrange.colEnd, ret); + return leftIndexingOperations(rhsFrame, (int) ixrange.rowStart, (int) ixrange.rowEnd, (int) ixrange.colStart, + (int) ixrange.colEnd, ret); } public FrameBlock leftIndexingOperations(FrameBlock rhsFrame, int rl, int ru, int cl, int cu, FrameBlock ret) { // check the validity of bounds - if ( rl < 0 || rl >= getNumRows() || ru < rl || ru >= getNumRows() - || cl < 0 || cu >= getNumColumns() || cu < cl || cu >= getNumColumns() ) { - throw new DMLRuntimeException("Invalid values for frame indexing: ["+(rl+1)+":"+(ru+1)+"," + (cl+1)+":"+(cu+1)+"] " + - "must be within frame dimensions 
["+getNumRows()+","+getNumColumns()+"]."); + if(rl < 0 || rl >= getNumRows() || ru < rl || ru >= getNumRows() || cl < 0 || cu >= getNumColumns() || cu < cl || + cu >= getNumColumns()) { + throw new DMLRuntimeException( + "Invalid values for frame indexing: [" + (rl + 1) + ":" + (ru + 1) + "," + (cl + 1) + ":" + (cu + 1) + "] " + + "must be within frame dimensions [" + getNumRows() + "," + getNumColumns() + "]."); } - if ( (ru-rl+1) < rhsFrame.getNumRows() || (cu-cl+1) < rhsFrame.getNumColumns()) { - throw new DMLRuntimeException("Invalid values for frame indexing: " + - "dimensions of the source frame ["+rhsFrame.getNumRows()+"x" + rhsFrame.getNumColumns() + "] " + - "do not match the shape of the frame specified by indices [" + - (rl+1) +":" + (ru+1) + ", " + (cl+1) + ":" + (cu+1) + "]."); + if((ru - rl + 1) < rhsFrame.getNumRows() || (cu - cl + 1) < rhsFrame.getNumColumns()) { + throw new DMLRuntimeException( + "Invalid values for frame indexing: " + "dimensions of the source frame [" + rhsFrame.getNumRows() + "x" + + rhsFrame.getNumColumns() + "] " + "do not match the shape of the frame specified by indices [" + + (rl + 1) + ":" + (ru + 1) + ", " + (cl + 1) + ":" + (cu + 1) + "]."); } - - //allocate output frame (incl deep copy schema) - if( ret == null ) + // allocate output frame (incl deep copy schema) + if(ret == null) ret = new FrameBlock(); ret._numRows = _numRows; ret._schema = _schema.clone(); @@ -906,18 +902,17 @@ public FrameBlock leftIndexingOperations(FrameBlock rhsFrame, int rl, int ru, in ret._colmeta = _colmeta.clone(); ret._coldata = new Array[getNumColumns()]; - //copy data to output and partial overwrite w/ rhs - for( int j=0; j=cl && j<=cu ) { - //fast-path for homogeneous column schemas - if( _schema[j]==rhsFrame._schema[j-cl] ) - tmp.set(rl, ru, rhsFrame._coldata[j-cl]); - //general-path for heterogeneous column schemas + if(j >= cl && j <= cu) { + // fast-path for homogeneous column schemas + if(_schema[j] == rhsFrame._schema[j - cl]) + 
tmp.set(rl, ru, rhsFrame._coldata[j - cl]); + // general-path for heterogeneous column schemas else { - for( int i=rl; i<=ru; i++ ) - tmp.set(i, UtilFunctions.objectToObject( - _schema[j], rhsFrame._coldata[j-cl].get(i-rl))); + for(int i = rl; i <= ru; i++) + tmp.set(i, UtilFunctions.objectToObject(_schema[j], rhsFrame._coldata[j - cl].get(i - rl))); } } ret._coldata[j] = tmp; @@ -933,12 +928,12 @@ public final FrameBlock slice(IndexRange ixrange, FrameBlock ret) { @Override public final FrameBlock slice(int rl, int ru) { - return slice(rl, ru, 0, getNumColumns()-1, false, null); + return slice(rl, ru, 0, getNumColumns() - 1, false, null); } @Override public final FrameBlock slice(int rl, int ru, boolean deep) { - return slice(rl, ru, 0, getNumColumns()-1, deep, null); + return slice(rl, ru, 0, getNumColumns() - 1, deep, null); } @Override @@ -959,54 +954,55 @@ public final FrameBlock slice(int rl, int ru, int cl, int cu, boolean deep) { @Override public FrameBlock slice(int rl, int ru, int cl, int cu, boolean deep, FrameBlock ret) { validateSliceArgument(rl, ru, cl, cu); - //allocate output frame - if( ret == null ) + // allocate output frame + if(ret == null) ret = new FrameBlock(); else - ret.reset(ru-rl+1, true); + ret.reset(ru - rl + 1, true); - //copy output schema and colnames - int numCols = cu-cl+1; + // copy output schema and colnames + int numCols = cu - cl + 1; boolean isDefNames = isColNamesDefault(); ret._schema = new ValueType[numCols]; ret._colnames = !isDefNames ? 
new String[numCols] : null; ret._colmeta = new ColumnMetadata[numCols]; - for( int j=cl; j<=cu; j++ ) { - ret._schema[j-cl] = _schema[j]; - ret._colmeta[j-cl] = _colmeta[j]; - if( !isDefNames ) - ret._colnames[j-cl] = getColumnName(j); + for(int j = cl; j <= cu; j++) { + ret._schema[j - cl] = _schema[j]; + ret._colmeta[j - cl] = _colmeta[j]; + if(!isDefNames) + ret._colnames[j - cl] = getColumnName(j); } - ret._numRows = ru-rl+1; - if(ret._coldata == null ) + ret._numRows = ru - rl + 1; + if(ret._coldata == null) ret._coldata = new Array[numCols]; - //fast-path: shallow copy column indexing - if( ret._numRows == _numRows && !deep ) { - //this shallow copy does not only avoid an array copy, but - //also allows for bi-directional reuses of recodemaps - for( int j=cl; j<=cu; j++ ) - ret._coldata[j-cl] = _coldata[j]; + // fast-path: shallow copy column indexing + if(ret._numRows == _numRows && !deep) { + // this shallow copy does not only avoid an array copy, but + // also allows for bi-directional reuses of recodemaps + for(int j = cl; j <= cu; j++) + ret._coldata[j - cl] = _coldata[j]; } - //copy output data + // copy output data else { - for( int j=cl; j<=cu; j++ ) { - if( ret._coldata[j-cl] == null ) - ret._coldata[j-cl] = _coldata[j].slice(rl,ru+1); + for(int j = cl; j <= cu; j++) { + if(ret._coldata[j - cl] == null) + ret._coldata[j - cl] = _coldata[j].slice(rl, ru + 1); else - ret._coldata[j-cl].set(0, ru-rl, _coldata[j], rl); + ret._coldata[j - cl].set(0, ru - rl, _coldata[j], rl); } } return ret; } - protected void validateSliceArgument(int rl, int ru, int cl, int cu){ - if ( rl < 0 || rl >= getNumRows() || ru < rl || ru >= getNumRows() - || cl < 0 || cu >= getNumColumns() || cu < cl || cu >= getNumColumns() ) { - throw new DMLRuntimeException("Invalid values for frame indexing: ["+(rl+1)+":"+(ru+1)+"," + (cl+1)+":"+(cu+1)+"] " + - "must be within frame dimensions ["+getNumRows()+","+getNumColumns()+"]"); + protected void validateSliceArgument(int rl, int ru, 
int cl, int cu) { + if(rl < 0 || rl >= getNumRows() || ru < rl || ru >= getNumRows() || cl < 0 || cu >= getNumColumns() || cu < cl || + cu >= getNumColumns()) { + throw new DMLRuntimeException( + "Invalid values for frame indexing: [" + (rl + 1) + ":" + (ru + 1) + "," + (cl + 1) + ":" + (cu + 1) + "] " + + "must be within frame dimensions [" + getNumRows() + "," + getNumColumns() + "]"); } } @@ -1016,72 +1012,69 @@ public void slice(ArrayList> outList, IndexRange range, i throw new NotImplementedException("Not implemented slice of more than 1 block out"); int r = (int) range.rowStart; final FrameBlock out = outList.get(0).getValue(); - if(range.rowStart < rowCut) - slice(r, (int)Math.min(rowCut, range.rowEnd + 1), - (int)range.colStart, (int)range.colEnd,out); - - if(range.rowEnd >= rowCut) - slice(r, (int)range.rowEnd, - (int)range.colStart, (int)range.colEnd, out); - + if(range.rowStart < rowCut) + slice(r, (int) Math.min(rowCut, range.rowEnd + 1), (int) range.colStart, (int) range.colEnd, out); + + if(range.rowEnd >= rowCut) + slice(r, (int) range.rowEnd, (int) range.colStart, (int) range.colEnd, out); + } } /** - * Appends the given argument FrameBlock 'that' to this FrameBlock by - * creating a deep copy to prevent side effects. For cbind, the frames - * are appended column-wise (same number of rows), while for rbind the - * frames are appended row-wise (same number of columns). + * Appends the given argument FrameBlock 'that' to this FrameBlock by creating a deep copy to prevent side effects. + * For cbind, the frames are appended column-wise (same number of rows), while for rbind the frames are appended + * row-wise (same number of columns). 
* - * @param that frame block to append to current frame block + * @param that frame block to append to current frame block * @param cbind if true, column append * @return frame block */ public FrameBlock append(FrameBlock that, boolean cbind) { FrameBlock ret = new FrameBlock(); - if( cbind ) //COLUMN APPEND + if(cbind) // COLUMN APPEND { - //sanity check row dimension mismatch - if( getNumRows() != that.getNumRows() ) { - throw new DMLRuntimeException("Incompatible number of rows for cbind: "+ - that.getNumRows()+" (expected: "+getNumRows()+")"); + // sanity check row dimension mismatch + if(getNumRows() != that.getNumRows()) { + throw new DMLRuntimeException( + "Incompatible number of rows for cbind: " + that.getNumRows() + " (expected: " + getNumRows() + ")"); } - //allocate output frame + // allocate output frame ret._numRows = _numRows; - //concatenate schemas (w/ deep copy to prevent side effects) + // concatenate schemas (w/ deep copy to prevent side effects) ret._schema = (ValueType[]) ArrayUtils.addAll(_schema, that._schema); ret._colnames = (String[]) ArrayUtils.addAll(getColumnNames(), that.getColumnNames()); ret._colmeta = (ColumnMetadata[]) ArrayUtils.addAll(_colmeta, that._colmeta); - //check and enforce unique columns names - if( !Arrays.stream(ret._colnames).allMatch(new HashSet<>()::add) ) + // check and enforce unique columns names + if(!Arrays.stream(ret._colnames).allMatch(new HashSet<>()::add)) ret._colnames = createColNames(ret.getNumColumns()); - //concatenate column data (w/ shallow copy which is safe due to copy on write semantics) + // concatenate column data (w/ shallow copy which is safe due to copy on write semantics) ret._coldata = (Array[]) ArrayUtils.addAll(_coldata, that._coldata); } - else //ROW APPEND + else // ROW APPEND { - //sanity check column dimension mismatch - if( getNumColumns() != that.getNumColumns() ) { - throw new DMLRuntimeException("Incompatible number of columns for rbind: "+ - that.getNumColumns()+" (expected: 
"+getNumColumns()+")"); + // sanity check column dimension mismatch + if(getNumColumns() != that.getNumColumns()) { + throw new DMLRuntimeException("Incompatible number of columns for rbind: " + that.getNumColumns() + + " (expected: " + getNumColumns() + ")"); } ret._numRows = _numRows; // note set to previous since each row is appended on. ret._schema = _schema.clone(); - ret._colnames = (_colnames!=null) ? _colnames.clone() : null; + ret._colnames = (_colnames != null) ? _colnames.clone() : null; ret._colmeta = new ColumnMetadata[getNumColumns()]; - for( int j=0; j<_schema.length; j++ ) + for(int j = 0; j < _schema.length; j++) ret._colmeta[j] = new ColumnMetadata(); - //concatenate data (deep copy first, append second) + // concatenate data (deep copy first, append second) ret._coldata = new Array[getNumColumns()]; - for( int j=0; j iter = IteratorFactory.getObjectRowIterator(that, _schema); - while( iter.hasNext() ) + while(iter.hasNext()) ret.appendRow(iter.next()); } @@ -1089,7 +1082,7 @@ public FrameBlock append(FrameBlock that, boolean cbind) { return ret; } - public FrameBlock copy(){ + public FrameBlock copy() { FrameBlock ret = new FrameBlock(); ret.copy(this); return ret; @@ -1103,27 +1096,27 @@ public void copy(FrameBlock src) { _colnames = Arrays.copyOf(src._colnames, nCol); if(!src.isColumnMetadataDefault()) _colmeta = Arrays.copyOf(src._colmeta, nCol); - if(src._coldata != null){ + if(src._coldata != null) { _coldata = new Array[nCol]; - for(int i = 0; i < nCol; i ++) + for(int i = 0; i < nCol; i++) _coldata[i] = src._coldata[i].clone(); } - _msize = -1; + _msize = -1; } /** * Copy src matrix into the index range of the existing current matrix. 
* - * @param rl row start - * @param ru row end inclusive - * @param cl col start - * @param cu col end inclusive + * @param rl row start + * @param ru row end inclusive + * @param cl col start + * @param cu col end inclusive * @param src source FrameBlock */ public void copy(int rl, int ru, int cl, int cu, FrameBlock src) { // If full copy, fall back to default copy - if(rl == 0 && cl == 0 && ru +1 == this.getNumRows() && cu +1 == this.getNumColumns()){ + if(rl == 0 && cl == 0 && ru + 1 == this.getNumRows() && cu + 1 == this.getNumColumns()) { copy(src); return; } @@ -1135,199 +1128,189 @@ public void copy(int rl, int ru, int cl, int cu, FrameBlock src) { if(_schema[j].equals(src._schema[j - cl])) _coldata[j].set(rl, ru, src._coldata[j - cl]); else {// general case w/ schema transformation - for(int i = rl; i <= ru; i++) + for(int i = rl; i <= ru; i++) set(i, j, UtilFunctions.objectToObject(_schema[j], src.get(i - rl, j - cl))); } } } - /////// // transform specific functionality /** - * This function will split every Recode map in the column using delimiter Lop.DATATYPE_PREFIX, - * as Recode map generated earlier in the form of Code+Lop.DATATYPE_PREFIX+Token and store it in a map - * which contains token and code for every unique tokens. + * This function will split every Recode map in the column using delimiter Lop.DATATYPE_PREFIX, as Recode map + * generated earlier in the form of Code+Lop.DATATYPE_PREFIX+Token and store it in a map which contains token and + * code for every unique tokens. * - * @param col is the column # from frame data which contains Recode map generated earlier. + * @param col is the column # from frame data which contains Recode map generated earlier. * @return map of token and code for every element in the input column of a frame containing Recode map */ - public HashMap getRecodeMap(int col) { - //probe cache for existing map - if( REUSE_RECODE_MAPS ) { - SoftReference> tmp = _coldata[col].getCache(); - HashMap map = (tmp!=null) ? 
tmp.get() : null; - if( map != null ) return map; + public HashMap getRecodeMap(int col) { + // probe cache for existing map + if(REUSE_RECODE_MAPS) { + SoftReference> tmp = _coldata[col].getCache(); + HashMap map = (tmp != null) ? tmp.get() : null; + if(map != null) + return map; } - //construct recode map - HashMap map = new HashMap<>(); + // construct recode map + HashMap map = new HashMap<>(); Array ldata = _coldata[col]; - for( int i=0; i(map)); return map; } @Override - public void merge(FrameBlock that, boolean bDummy) { + public void merge(FrameBlock that, boolean appendOnly) { merge(that); } public void merge(FrameBlock that) { - //check for empty input source (nothing to merge) - if( that == null || that.getNumRows() == 0 ) + // check for empty input source (nothing to merge) + if(that == null || that.getNumRows() == 0) return; - //check dimensions (before potentially copy to prevent implicit dimension change) - if ( getNumRows() != that.getNumRows() || getNumColumns() != that.getNumColumns() ) - throw new DMLRuntimeException("Dimension mismatch on merge disjoint (target="+getNumRows()+"x"+getNumColumns()+", source="+that.getNumRows()+"x"+that.getNumColumns()+")"); + // check dimensions (before potentially copy to prevent implicit dimension change) + if(getNumRows() != that.getNumRows() || getNumColumns() != that.getNumColumns()) + throw new DMLRuntimeException("Dimension mismatch on merge disjoint (target=" + getNumRows() + "x" + + getNumColumns() + ", source=" + that.getNumRows() + "x" + that.getNumColumns() + ")"); - //meta data copy if necessary - for( int j=0; j vt.toString()).toArray(String[]::new)); + FrameBlock fb = new FrameBlock(UtilFunctions.nCopies(getNumColumns(), ValueType.STRING)); + fb.appendRow(Arrays.stream(_schema).map(vt -> vt.toString()).toArray(String[]::new)); return fb; } - public final FrameBlock detectSchema(){ - return FrameLibDetectSchema.detectSchema(this); + public final FrameBlock detectSchema(int k) { + return 
FrameLibDetectSchema.detectSchema(this, k); } - public final FrameBlock detectSchema(double sampleFraction) { - return FrameLibDetectSchema.detectSchema(this, sampleFraction); + public final FrameBlock detectSchema(double sampleFraction, int k) { + return FrameLibDetectSchema.detectSchema(this, sampleFraction, k); } /** * Drop the cell value which does not confirms to the data type of its column + * * @param schema of the frame * @return original frame where invalid values are replaced with null */ public FrameBlock dropInvalidType(FrameBlock schema) { - //sanity checks + // sanity checks if(this.getNumColumns() != schema.getNumColumns()) - throw new DMLException("mismatch in number of columns in frame and its schema "+this.getNumColumns()+" != "+schema.getNumColumns()); + throw new DMLException("mismatch in number of columns in frame and its schema " + this.getNumColumns() + " != " + + schema.getNumColumns()); // extract the schema in String array String[] schemaString = IteratorFactory.getStringRowIterator(schema).next(); - for (int i = 0; i < this.getNumColumns(); i++) { + for(int i = 0; i < this.getNumColumns(); i++) { Array obj = this.getColumn(i); String schemaCol = schemaString[i]; String type; if(schemaCol.contains("FP")) type = "FP"; - else if(schemaCol.contains("INT")) + else if(schemaCol.contains("INT")) type = "INT"; - else if(schemaCol.contains("STRING")) + else if(schemaCol.contains("STRING")) // In case of String columns, don't do any verification or replacements. 
continue; - else + else type = schemaCol; - - for (int j = 0; j < this.getNumRows(); j++){ + + for(int j = 0; j < this.getNumRows(); j++) { if(obj.get(j) == null) continue; - String dataValue = obj.get(j).toString().trim().replace("\"", "").toLowerCase() ; + String dataValue = obj.get(j).toString().trim().replace("\"", "").toLowerCase(); ValueType dataType = FrameUtil.isType(dataValue); if(!dataType.toString().contains(type) && !(dataType == ValueType.BOOLEAN && type.equals("INT")) && - !(dataType == ValueType.BOOLEAN && type.equals("FP"))){ - LOG.warn("Datatype detected: " + dataType + " where expected: " + schemaString[i] + " col: " + - (i+1) + ", row:" +(j+1)); + !(dataType == ValueType.BOOLEAN && type.equals("FP"))) { + LOG.warn("Datatype detected: " + dataType + " where expected: " + schemaString[i] + " col: " + (i + 1) + + ", row:" + (j + 1)); - this.set(j,i,null); + this.set(j, i, null); } } } @@ -1335,26 +1318,25 @@ else if(schemaCol.contains("STRING")) } /** - * This method validates the frame data against an attribute length constrain - * if data value in any cell is greater than the specified threshold of that attribute - * the output frame will store a null on that cell position, thus removing the length-violating values. + * This method validates the frame data against an attribute length constrain if data value in any cell is greater + * than the specified threshold of that attribute the output frame will store a null on that cell position, thus + * removing the length-violating values. 
* - * @param feaLen vector of valid lengths - * @return FrameBlock with invalid values converted into missing values (null) + * @param feaLen vector of valid lengths + * @return FrameBlock with invalid values converted into missing values (null) */ public FrameBlock invalidByLength(MatrixBlock feaLen) { - //sanity checks + // sanity checks if(this.getNumColumns() != feaLen.getNumColumns()) throw new DMLException("mismatch in number of columns in frame and corresponding feature-length vector"); FrameBlock outBlock = new FrameBlock(this); - for (int i = 0; i < this.getNumColumns(); i++) { + for(int i = 0; i < this.getNumColumns(); i++) { if(feaLen.quickGetValue(0, i) == -1) continue; - int validLength = (int)feaLen.quickGetValue(0, i); + int validLength = (int) feaLen.quickGetValue(0, i); Array obj = this.getColumn(i); - for (int j = 0; j < obj.size(); j++) - { + for(int j = 0; j < obj.size(); j++) { if(obj.get(j) == null) continue; String dataValue = obj.get(j).toString(); @@ -1371,51 +1353,52 @@ public static FrameBlock mergeSchema(FrameBlock temp1, FrameBlock temp2) { String[] rowTemp2 = IteratorFactory.getStringRowIterator(temp2).next(); if(rowTemp1.length != rowTemp2.length) - throw new DMLRuntimeException("Schema dimension " - + "mismatch: "+rowTemp1.length+" vs "+rowTemp2.length); + throw new DMLRuntimeException("Schema dimension " + "mismatch: " + rowTemp1.length + " vs " + rowTemp2.length); - for(int i=0; i< rowTemp1.length; i++ ) { - //modify schema1 if necessary (different schema2) + for(int i = 0; i < rowTemp1.length; i++) { + // modify schema1 if necessary (different schema2) if(!rowTemp1[i].equals(rowTemp2[i])) { if(rowTemp1[i].equals("STRING") || rowTemp2[i].equals("STRING")) rowTemp1[i] = "STRING"; - else if (rowTemp1[i].equals("FP64") || rowTemp2[i].equals("FP64")) + else if(rowTemp1[i].equals("FP64") || rowTemp2[i].equals("FP64")) rowTemp1[i] = "FP64"; - else if (rowTemp1[i].equals("FP32") && new ArrayList<>(Arrays.asList("INT64", "INT32", 
"CHARACTER")).contains(rowTemp2[i]) ) + else if(rowTemp1[i].equals("FP32") && + new ArrayList<>(Arrays.asList("INT64", "INT32", "CHARACTER")).contains(rowTemp2[i])) rowTemp1[i] = "FP32"; - else if (rowTemp1[i].equals("INT64") && new ArrayList<>(Arrays.asList("INT32", "CHARACTER")).contains(rowTemp2[i])) + else if(rowTemp1[i].equals("INT64") && + new ArrayList<>(Arrays.asList("INT32", "CHARACTER")).contains(rowTemp2[i])) rowTemp1[i] = "INT64"; - else if (rowTemp1[i].equals("INT32") || rowTemp2[i].equals("CHARACTER")) + else if(rowTemp1[i].equals("INT32") || rowTemp2[i].equals("CHARACTER")) rowTemp1[i] = "INT32"; } } - //create output block one row representing the schema as strings - FrameBlock mergedFrame = new FrameBlock( - UtilFunctions.nCopies(temp1.getNumColumns(), ValueType.STRING)); + // create output block one row representing the schema as strings + FrameBlock mergedFrame = new FrameBlock(UtilFunctions.nCopies(temp1.getNumColumns(), ValueType.STRING)); mergedFrame.appendRow(rowTemp1); return mergedFrame; } public void mapInplace(Function fun) { - for(int j=0; j")) { String args = lambdaExpr.substring(lambdaExpr.indexOf('(') + 1, lambdaExpr.indexOf(')')); if(args.contains(",")) { String[] arguments = args.split(","); return DMVUtils.syntacticalPatternDiscovery(this, Double.parseDouble(arguments[0]), arguments[1]); - } else if (args.contains(";")) { + } + else if(args.contains(";")) { String[] arguments = args.split(";"); return EMAUtils.exponentialMovingAverageImputation(this, Integer.parseInt(arguments[0]), arguments[1], - Integer.parseInt(arguments[2]), Double.parseDouble(arguments[3]), Double.parseDouble(arguments[4]), Double.parseDouble(arguments[5])); + Integer.parseInt(arguments[2]), Double.parseDouble(arguments[3]), Double.parseDouble(arguments[4]), + Double.parseDouble(arguments[5])); } } if(lambdaExpr.contains("jaccardSim")) @@ -1428,9 +1411,9 @@ public FrameBlock frameRowReplication(FrameBlock rowToreplicate) { if(this.getNumColumns() != 
rowToreplicate.getNumColumns()) throw new DMLRuntimeException("Mismatch number of columns"); if(rowToreplicate.getNumRows() > 1) - throw new DMLRuntimeException("only supported single rows frames to replicate"); - for(int i=0; i minMax = _coldata[k].getMinMaxLength(); + Pair minMax = _coldata[k].getMinMaxLength(); maxColLength[k] = minMax.getKey(); minColLength[k] = minMax.getValue(); } - + ArrayList probColList = new ArrayList(); for(int i = 0; i < this.getNumColumns(); i++) { for(int j = 0; j < this.getNumRows(); j++) { @@ -1462,15 +1445,16 @@ public FrameBlock valueSwap(FrameBlock schema) { ValueType dataType = FrameUtil.isType(dataValue); String type = dataType.toString().replaceAll("\\d", ""); - // get the avergae column length - if(!dataType.toString().contains(schemaString[i]) && !(dataType == ValueType.BOOLEAN && schemaString[i] - .equals("INT")) && !(dataType == ValueType.BOOLEAN && schemaString[i].equals("FP")) && !(dataType - .toString().contains("INT") && schemaString[i].equals("FP"))) { + // get the avergae column length + if(!dataType.toString().contains(schemaString[i]) && + !(dataType == ValueType.BOOLEAN && schemaString[i].equals("INT")) && + !(dataType == ValueType.BOOLEAN && schemaString[i].equals("FP")) && + !(dataType.toString().contains("INT") && schemaString[i].equals("FP"))) { LOG.warn("conflict " + dataType + " " + schemaString[i] + " " + dataValue); - // check the other column with satisfy the data type of this value + // check the other column with satisfy the data type of this value for(int w = 0; w < schemaString.length; w++) { - if(schemaString[w].equals(type) && dataValue.length() > minColLength[w] && dataValue - .length() < maxColLength[w] && (w != i)) { + if(schemaString[w].equals(type) && dataValue.length() > minColLength[w] && + dataValue.length() < maxColLength[w] && (w != i)) { Object item = this.get(j, w); String dataValueProb = (item != null) ? 
item.toString().trim().replace("\"", "") .toLowerCase() : "0"; @@ -1514,11 +1498,11 @@ else if(probColList.size() > 0) { return this; } - public FrameBlock map (FrameMapFunction lambdaExpr, long margin) { + public FrameBlock map(FrameMapFunction lambdaExpr, long margin) { // Prepare temporary output array String[][] output = new String[getNumRows()][getNumColumns()]; - if (margin == 1) { + if(margin == 1) { // Execute map function on rows for(int i = 0; i < getNumRows(); i++) { String[] row = new String[getNumColumns()]; @@ -1528,16 +1512,21 @@ public FrameBlock map (FrameMapFunction lambdaExpr, long margin) { } output[i] = lambdaExpr.apply(row); } - } else if (margin == 2) { + } + else if(margin == 2) { // Execute map function on columns for(int j = 0; j < getNumColumns(); j++) { - String[] actualColumn = Arrays.copyOfRange((String[]) getColumnData(j), 0, getNumRows()); // since more rows can be allocated, mutable array + String[] actualColumn = Arrays.copyOfRange((String[]) getColumnData(j), 0, getNumRows()); // since more rows + // can be + // allocated, + // mutable array String[] outColumn = lambdaExpr.apply(actualColumn); for(int i = 0; i < getNumRows(); i++) output[i][j] = outColumn[i]; } - } else { + } + else { // Execute map function on all cells for(int j = 0; j < getNumColumns(); j++) { Array input = getColumn(j); @@ -1549,7 +1538,7 @@ public FrameBlock map (FrameMapFunction lambdaExpr, long margin) { return new FrameBlock(UtilFunctions.nCopies(getNumColumns(), ValueType.STRING), output); } - public FrameBlock mapDist (FrameMapFunction lambdaExpr) { + public FrameBlock mapDist(FrameMapFunction lambdaExpr) { String[][] output = new String[getNumRows()][getNumRows()]; for(String[] row : output) Arrays.fill(row, "0.0"); @@ -1563,7 +1552,7 @@ public FrameBlock mapDist (FrameMapFunction lambdaExpr) { return new FrameBlock(UtilFunctions.nCopies(getNumRows(), ValueType.STRING), output); } - public static FrameMapFunction getCompiledFunction (String lambdaExpr, 
long margin) { + public static FrameMapFunction getCompiledFunction(String lambdaExpr, long margin) { String cname = "StringProcessing" + CLASS_ID.getNextID(); StringBuilder sb = new StringBuilder(); String[] parts = lambdaExpr.split("->"); @@ -1581,7 +1570,8 @@ public static FrameMapFunction getCompiledFunction (String lambdaExpr, long marg if(margin != 0) { sb.append("public String[] apply(String[] " + varname[0].trim() + ") {\n"); sb.append(" return UtilFunctions.toStringArray(" + expr + "); }}\n"); - } else { + } + else { if(varname.length == 1) { sb.append("public String apply(String " + varname[0].trim() + ") {\n"); sb.append(" return String.valueOf(" + expr + "); }}\n"); @@ -1593,11 +1583,10 @@ else if(varname.length == 2) { } // compile class, and create FrameMapFunction object try { - return (FrameMapFunction) CodegenUtils.compileClass(cname, sb.toString()) - .getDeclaredConstructor().newInstance(); + return (FrameMapFunction) CodegenUtils.compileClass(cname, sb.toString()).getDeclaredConstructor() + .newInstance(); } - catch(InstantiationException | IllegalAccessException - | IllegalArgumentException | InvocationTargetException + catch(InstantiationException | IllegalAccessException | IllegalArgumentException | InvocationTargetException | NoSuchMethodException | SecurityException e) { throw new DMLRuntimeException("Failed to compile FrameMapFunction.", e); } @@ -1605,9 +1594,18 @@ else if(varname.length == 2) { public static class FrameMapFunction implements Serializable { private static final long serialVersionUID = -8398572153616520873L; - public String apply(String input) {return null;} - public String apply(String input1, String input2) { return null;} - public String[] apply(String[] input1) { return null;} + + public String apply(String input) { + return null; + } + + public String apply(String input1, String input2) { + return null; + } + + public String[] apply(String[] input1) { + return null; + } } public FrameBlock replaceOperations(String 
pattern, String replacement) { @@ -1615,34 +1613,39 @@ public FrameBlock replaceOperations(String pattern, String replacement) { boolean NaNp = "NaN".equals(pattern); boolean NaNr = "NaN".equals(replacement); - ValueType patternType = UtilFunctions.isBoolean(pattern) ? ValueType.BOOLEAN : (NumberUtils.isCreatable(pattern) | NaNp ? - (UtilFunctions.isIntegerNumber(pattern) ? ValueType.INT64 : ValueType.FP64) : ValueType.STRING); - ValueType replacementType = UtilFunctions.isBoolean(replacement) ? ValueType.BOOLEAN : (NumberUtils.isCreatable(replacement) | NaNr ? - (UtilFunctions.isIntegerNumber(replacement) ? ValueType.INT64 : ValueType.FP64) : ValueType.STRING); + ValueType patternType = UtilFunctions.isBoolean(pattern) ? ValueType.BOOLEAN : (NumberUtils.isCreatable(pattern) | + NaNp ? (UtilFunctions.isIntegerNumber(pattern) ? ValueType.INT64 : ValueType.FP64) : ValueType.STRING); + ValueType replacementType = UtilFunctions + .isBoolean(replacement) ? ValueType.BOOLEAN : (NumberUtils.isCreatable(replacement) | + NaNr ? (UtilFunctions.isIntegerNumber(replacement) ? 
ValueType.INT64 : ValueType.FP64) : ValueType.STRING); if(patternType != replacementType || !ValueType.isSameTypeString(patternType, replacementType)) - throw new DMLRuntimeException("Pattern and replacement types should be same: "+patternType+" "+replacementType); + throw new DMLRuntimeException( + "Pattern and replacement types should be same: " + patternType + " " + replacementType); - for(int i = 0; i < ret.getNumColumns(); i++){ + for(int i = 0; i < ret.getNumColumns(); i++) { Array colData = ret._coldata[i]; - for(int j = 0; j < colData.size() && (ValueType.isSameTypeString(_schema[i], patternType) || _schema[i] == ValueType.STRING); j++) { - T patternNew = (T) UtilFunctions.stringToObject(_schema[i], pattern); + for(int j = 0; + j < colData.size() && + (ValueType.isSameTypeString(_schema[i], patternType) || _schema[i] == ValueType.STRING); + j++) { + T patternNew = (T) UtilFunctions.stringToObject(_schema[i], pattern); T replacementNew = (T) UtilFunctions.stringToObject(_schema[i], replacement); Object ent = colData.get(j); if(ent != null && ent.toString().equals(patternNew.toString())) - colData.set(j,replacementNew); - else if(ent instanceof String && ent.equals(pattern)) + colData.set(j, replacementNew); + else if(ent instanceof String && ent.equals(pattern)) colData.set(j, replacement); } } return ret; } - public FrameBlock removeEmptyOperations(boolean rows, boolean emptyReturn, MatrixBlock select) { - if( rows ) + public FrameBlock removeEmptyOperations(boolean rows, boolean emptyReturn, MatrixBlock select) { + if(rows) return removeEmptyRows(select, emptyReturn); - else //cols + else // cols return removeEmptyColumns(select, emptyReturn); } @@ -1652,13 +1655,13 @@ private FrameBlock removeEmptyRows(MatrixBlock select, boolean emptyReturn) { FrameBlock ret = new FrameBlock(_schema, _colnames); - if (select == null) { + if(select == null) { Object[] row = new Object[getNumColumns()]; for(int i = 0; i < _numRows; i++) { boolean isEmpty = true; for(int j 
= 0; j < getNumColumns(); j++) { row[j] = _coldata[j].get(i); - isEmpty &= ArrayUtils.contains(new double[]{0.0, Double.NaN}, + isEmpty &= ArrayUtils.contains(new double[] {0.0, Double.NaN}, UtilFunctions.objectToDoubleSafe(_schema[j], row[j])); } if(!isEmpty) @@ -1679,9 +1682,9 @@ private FrameBlock removeEmptyRows(MatrixBlock select, boolean emptyReturn) { } } - if (ret.getNumRows() == 0 && emptyReturn) { + if(ret.getNumRows() == 0 && emptyReturn) { String[][] arr = new String[1][getNumColumns()]; - Arrays.fill(arr, new String[]{null}); + Arrays.fill(arr, new String[] {null}); ValueType[] schema = new ValueType[getNumColumns()]; Arrays.fill(schema, ValueType.STRING); return new FrameBlock(schema, arr); @@ -1699,20 +1702,20 @@ private FrameBlock removeEmptyColumns(MatrixBlock select, boolean emptyReturn) { FrameBlock ret = new FrameBlock(); List columnMetadata = new ArrayList<>(); - if (select == null) { + if(select == null) { for(int i = 0; i < getNumColumns(); i++) { Array colData = _coldata[i]; ValueType type = _schema[i]; - boolean isEmpty = IntStream.range(0, colData.size()) - .mapToObj((IntFunction) colData::get) - .allMatch(e -> ArrayUtils.contains(new double[]{0.0, Double.NaN}, UtilFunctions.objectToDoubleSafe(type, e))); + boolean isEmpty = IntStream.range(0, colData.size()).mapToObj((IntFunction) colData::get).allMatch( + e -> ArrayUtils.contains(new double[] {0.0, Double.NaN}, UtilFunctions.objectToDoubleSafe(type, e))); if(!isEmpty) { ret.appendColumn(_coldata[i]); columnMetadata.add(new ColumnMetadata(_colmeta[i])); } } - } else { + } + else { if(select.getNonZeros() == getNumColumns()) return new FrameBlock(this); @@ -1728,8 +1731,8 @@ private FrameBlock removeEmptyColumns(MatrixBlock select, boolean emptyReturn) { if(ret.getNumColumns() == 0 && emptyReturn) { String[][] arr = new String[_numRows][]; - Arrays.fill(arr, new String[]{null}); - return new FrameBlock(new ValueType[]{ValueType.STRING}, arr); + Arrays.fill(arr, new String[] {null}); + 
return new FrameBlock(new ValueType[] {ValueType.STRING}, arr); } ret._colmeta = new ColumnMetadata[columnMetadata.size()]; @@ -1754,20 +1757,20 @@ public FrameBlock applySchema(ValueType[] schema) { * Method to create a new FrameBlock where the given schema is applied. * * @param schema of value types. - * @param k parallelization degree + * @param k parallelization degree * @return A new FrameBlock with the schema applied. */ - public FrameBlock applySchema(ValueType[] schema, int k){ + public FrameBlock applySchema(ValueType[] schema, int k) { return FrameLibApplySchema.applySchema(this, schema, k); } @Override - public String toString(){ + public String toString() { StringBuilder sb = new StringBuilder(); sb.append("FrameBlock"); - if(!isColumnMetadataDefault()){ + if(!isColumnMetadataDefault()) { - if(_colnames != null){ + if(_colnames != null) { sb.append("\n"); sb.append(Arrays.toString(_colnames)); } @@ -1777,8 +1780,8 @@ public String toString(){ sb.append("\n"); sb.append(Arrays.toString(_schema)); sb.append("\n"); - if(_coldata != null){ - for(int i = 0; i < _coldata.length; i++){ + if(_coldata != null) { + for(int i = 0; i < _coldata.length; i++) { sb.append(_coldata[i]); sb.append("\n"); } diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/Array.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/Array.java index baa5e03b9d0..c2f892f35f1 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/Array.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/Array.java @@ -32,88 +32,201 @@ import org.apache.sysds.runtime.matrix.data.Pair; /** - * generic, resizable native arrays - * - * Base class for generic, resizable array of various value types. We use this custom class hierarchy instead of Trove - * or other libraries in order to avoid unnecessary dependencies. + * Generic, resizable native arrays for the internal representation of the columns in the FrameBlock. 
We use this custom + * class hierarchy instead of Trove or other libraries in order to avoid unnecessary dependencies. */ public abstract class Array implements Writable { protected static final Log LOG = LogFactory.getLog(Array.class.getName()); + /** A soft reference to a memorization of this arrays mapping, used in transformEncode */ protected SoftReference> _rcdMapCache = null; + /** The current allocated number of elements in this Array */ protected int _size; + protected Array(int size) { + _size = size; + } + protected int newSize() { return Math.max(_size * 2, 4); } + /** + * Get the current cached element. + * + * @return The cached object + */ public final SoftReference> getCache() { return _rcdMapCache; } + /** + * Set the cached hashmap cache of this Array allocation, to be used in transformEncode. + * + * @param m The element to cache. + */ public final void setCache(SoftReference> m) { _rcdMapCache = m; } + /** + * Get the number of elements in the array, this does not necessarily reflect the current allocated size. + * + * @return the current number of elements + */ public final int size() { return _size; } + /** + * Get the value at a given index. + * + * This method returns objects that have a high overhead in allocation. Therefore it is not as efficient as using the + * vectorized operations specified in the object. + * + * @param index The index to query + * @return The value returned as an object + */ public abstract T get(int index); /** - * Get the underlying array out of the column Group, it is the responsibility of the caller to know what type it is + * Get the underlying array out of the column Group, + * + * it is the responsibility of the caller to know what type it is. + * + * Also it is not guaranteed that the underlying data structure does not allocate an appropriate response to the + * caller. This in practice means that if called there is a possibility that the entire array is allocated again. 
So + * the method should only be used for debugging purposes not for performance. * * @return The underlying array. */ public abstract Object get(); + public abstract double getAsDouble(int i); + + /** + * Set index to the given value of same type + * + * @param index The index to set + * @param value The value to assign + */ public abstract void set(int index, T value); + /** + * Set index to given double value (cast to the correct type of this array) + * + * @param index the index to set + * @param value the value to set it to (before casting to correct value type) + */ public abstract void set(int index, double value); + /** + * Set range to given arrays value + * + * @param rl row lower + * @param ru row upper (inclusive) + * @param value value array to take values from (other type) + */ public abstract void setFromOtherType(int rl, int ru, Array value); + /** + * Set range to given arrays value + * + * @param rl row lower + * @param ru row upper (inclusive) + * @param value value array to take values from (same type) + */ public abstract void set(int rl, int ru, Array value); + /** + * Set range to given arrays value with an offset into other array + * + * @param rl row lower + * @param ru row upper (inclusive) + * @param value value array to take values from + * @param rlSrc the offset into the value array to take values from + */ public abstract void set(int rl, int ru, Array value, int rlSrc); + /** + * Set non default values from the value array given + * + * @param value array of same type + */ + public final void setNz(Array value) { + setNz(0, value.size() - 1, value); + } + + /** + * Set non default values in the range from the value array given + * + * @param rl row start + * @param ru row upper inclusive + * @param value value array of same type + */ public abstract void setNz(int rl, int ru, Array value); + /** + * Set non default values from the value array given + * + * @param value array of other type + */ + public final void 
setFromOtherTypeNz(Array value) { + setFromOtherTypeNz(0, value.size() - 1, value); + } + + /** + * Set non default values in the range from the value array given + * + * @param rl row start + * @param ru row end inclusive + * @param value value array of different type + */ + public abstract void setFromOtherTypeNz(int rl, int ru, Array value); + + /** + * Append a string value to the current Array, this should in general be avoided, and appending larger blocks at a + * time should be preferred. + * + * @param value The value to append + */ public abstract void append(String value); + /** + * Append a value of the same type of the Array. This should in general be avoided, and appending larger blocks at a + * time should be preferred. + * + * @param value The value to append + */ public abstract void append(T value); - @Override - public abstract Array clone(); - /** * Slice out the sub range and return new array with the specified type. * * If the conversion fails fallback to normal slice * * @param rl row start - * @param ru row end + * @param ru row end (not included) * @return A new array of sub range. */ public abstract Array slice(int rl, int ru); /** - * Slice out the sub range and return new array with the specified type. - * - * If the conversion fails fallback to normal slice + * Reset the Array and set to a different size. This method is used to reuse an already allocated Array, without + * extra allocation. It should only be done in cases where the Array is no longer in use in any FrameBlocks. * - * @param rl row start - * @param ru row end - * @param vt valuetype target - * @return A new array of sub range. + * @param size The size to reallocate into. */ - public abstract Array sliceTransform(int rl, int ru, ValueType vt); - public abstract void reset(int size); - public abstract byte[] getAsByteArray(int nRow); + /** + * Return the current allocated Array as a byte[], this is used to serialize the allocated Arrays out to the + * PythonAPI. 
+ * + * @return The array as bytes + */ + public abstract byte[] getAsByteArray(); /** * Get the current value type of this array. @@ -129,6 +242,12 @@ public final int size() { */ public abstract ValueType analyzeValueType(); + /** + * Get the internal FrameArrayType, to specify the encoding of the Types, note there are more Frame Array Types than + * there is ValueTypes. + * + * @return The FrameArrayType + */ public abstract FrameArrayType getFrameArrayType(); /** @@ -136,15 +255,25 @@ public final int size() { * * @return the size in memory of this object. */ - public long getInMemorySize(){ - return baseMemoryCost(); + public long getInMemorySize() { + return baseMemoryCost(); } - public static long baseMemoryCost(){ - // Object header , int size, padding, softref. + /** + * Get the base memory cost of the Arrays allocation. + * + * @return The base memory cost + */ + public static long baseMemoryCost() { + // Object header , int size, padding, softref. return 16 + 4 + 4 + 8; } + /** + * Get the exact serialized size on disk of this array. 
+ * + * @return The exact size on disk + */ public abstract long getExactSerializedSize(); /** @@ -178,29 +307,89 @@ public final Array changeType(ValueType t) { } } - protected abstract Array changeTypeBitSet(); + /** + * Change type to a bitSet, of underlying longs to store the individual values + * + * @return A Boolean type of array + */ + protected abstract Array changeTypeBitSet(); - protected abstract Array changeTypeBoolean(); + /** + * Change type to a boolean array + * + * @return A Boolean type of array + */ + protected abstract Array changeTypeBoolean(); - protected abstract Array changeTypeDouble(); + /** + * Change type to a Double array type + * + * @return Double type of array + */ + protected abstract Array changeTypeDouble(); - protected abstract Array changeTypeFloat(); + /** + * Change type to a Float array type + * + * @return Float type of array + */ + protected abstract Array changeTypeFloat(); - protected abstract Array changeTypeInteger(); + /** + * Change type to an Integer array type + * + * @return Integer type of array + */ + protected abstract Array changeTypeInteger(); - protected abstract Array changeTypeLong(); + /** + * Change type to a Long array type + * + * @return Long type of array + */ + protected abstract Array changeTypeLong(); - protected Array changeTypeString() { - String[] ret = new String[size()]; - for(int i = 0; i < size(); i++) - ret[i] = get(i).toString(); - return new StringArray(ret); - } + /** + * Change type to a String array type + * + * @return String type of array + */ + protected abstract Array changeTypeString(); + /** + * Get the minimum and maximum length of the contained values as string type. + * + * @return A Pair of first the minimum length, second the maximum length + */ public Pair getMinMaxLength() { throw new DMLRuntimeException("Length is only relevant if case is String"); } + /** + * fill the entire array with specific value. + * + * @param val the value to fill with. 
+ */ + public abstract void fill(String val); + + /** + * fill the entire array with specific value. + * + * @param val the value to fill with. + */ + public abstract void fill(T val); + + /** + * Overwrite of the java internal clone function for arrays, return a clone of underlying data that is mutable, (not + * immutable data.) + * + * Immutable data is dependent on the individual allocated arrays + * + * @return A clone + */ + @Override + public abstract Array clone(); + @Override public String toString() { return this.getClass().getSimpleName(); diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/ArrayFactory.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/ArrayFactory.java index b237c4d550b..c8414fb3cdb 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/ArrayFactory.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/ArrayFactory.java @@ -29,7 +29,7 @@ public interface ArrayFactory { - public static int bitSetSwitchPoint = 64; + public final static int bitSetSwitchPoint = 64; public enum FrameArrayType { STRING, BOOLEAN, BITSET, INT32, INT64, FP32, FP64; @@ -67,7 +67,7 @@ public static long getInMemorySize(ValueType type, int _numRows) { switch(type) { case BOOLEAN: if(_numRows > bitSetSwitchPoint) - return Array.baseMemoryCost() + 8 + (long) MemoryEstimates.bitSetCost(_numRows); + return Array.baseMemoryCost() + (long) MemoryEstimates.longArrayCost((_numRows >> 6) + 1); else return Array.baseMemoryCost() + (long) MemoryEstimates.booleanArrayCost(_numRows); case INT64: @@ -88,6 +88,12 @@ public static long getInMemorySize(ValueType type, int _numRows) { } } + public static Array allocate(ValueType v, int nRow, String val){ + Array a = allocate(v, nRow); + a.fill(val); + return a; + } + public static Array allocate(ValueType v, int nRow) { switch(v) { case STRING: diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/BitSetArray.java 
b/src/main/java/org/apache/sysds/runtime/frame/data/columns/BitSetArray.java index 601cb552a9d..785c39ce8b6 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/BitSetArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/BitSetArray.java @@ -27,52 +27,78 @@ import java.util.Arrays; import java.util.BitSet; -import org.apache.commons.lang.NotImplementedException; import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.frame.data.columns.ArrayFactory.FrameArrayType; +import org.apache.sysds.runtime.util.UtilFunctions; import org.apache.sysds.utils.MemoryEstimates; public class BitSetArray extends Array { - private static boolean useVectorizedKernel = true; - private BitSet _data; + private static final boolean useVectorizedKernel = true; + + /** Vectorized "words" containing all the bits set */ + long[] _data; protected BitSetArray(int size) { - _size = size; - _data = new BitSet(size); + super(size); + _data = new long[size / 64 + 1]; } public BitSetArray(boolean[] data) { - _size = data.length; - _data = new BitSet(data.length); + super(data.length); + _data = new long[_size / 64 + 1]; // set bits. for(int i = 0; i < data.length; i++) if(data[i]) // slightly more efficient to check. 
- _data.set(i); + set(i, true); } - public BitSetArray(BitSet data, int size) { - _size = size; + public BitSetArray(long[] data, int size) { + super(size); _data = data; + if(_size > _data.length * 64) + throw new DMLRuntimeException("Invalid allocation long array must be long enough"); + if(_data.length > _size / 64 + 1) + throw new DMLRuntimeException( + "Invalid allocation long array must not be to long" + _data.length + " " + _size + " " + (size / 64 + 1)); + } + + public BitSetArray(BitSet data, int size) { + super(size); + _data = toLongArrayPadded(data, size); } public BitSet get() { + return BitSet.valueOf(_data); + } + + public long[] getLongs() { return _data; } @Override public Boolean get(int index) { - return _data.get(index); + int wIdx = index >> 6; // same as divide by 64 bit faster + return (_data[wIdx] & (1L << index)) != 0; } @Override public void set(int index, Boolean value) { - _data.set(index, value != null ? value : false); + set(index, value != null && value); + } + + public void set(int index, boolean value) { + int wIdx = index >> 6; // same as divide by 64 bit faster + if(value) + _data[wIdx] |= (1L << index); + else + _data[wIdx] &= ~(1L << index); } @Override public void set(int index, double value) { - _data.set(index, value == 0 ? 
false : true); + set(index, value == 1.0); } @Override @@ -82,7 +108,10 @@ public void set(int rl, int ru, Array value) { @Override public void setFromOtherType(int rl, int ru, Array value) { - throw new NotImplementedException(); + final ValueType vt = value.getValueType(); + for(int i = rl; i <= ru; i++) + set(i, UtilFunctions.objectToBoolean(vt, value.get(i))); + } private static long[] toLongArrayPadded(BitSet data, int minLength) { @@ -99,164 +128,211 @@ public void set(int rl, int ru, Array value, int rlSrc) { setVectorized(rl, ru, (BitSetArray) value, rlSrc); else // default for(int i = rl, off = rlSrc; i <= ru; i++, off++) - _data.set(i, value.get(off)); + set(i, value.get(off)); } private void setVectorized(int rl, int ru, BitSetArray value, int rlSrc) { final int rangeLength = ru - rl + 1; - final long[] otherValues = toLongArrayPadded(// - (BitSet) value.get().get(rlSrc, rangeLength + rlSrc), rangeLength); - long[] ret = toLongArrayPadded(_data, size()); - - ret = setVectorizedLongs(rl, ru, otherValues, ret); - _data = BitSet.valueOf(ret); + final BitSetArray v = value.slice(rlSrc, rangeLength + rlSrc); + final long[] otherValues = v.getLongs(); + setVectorizedLongs(rl, ru, otherValues); } - private static long[] setVectorizedLongs(int rl, int ru, long[] ov, long[] ret) { + private void setVectorizedLongs(int rl, int ru, long[] ov) { final long remainder = rl % 64L; if(remainder == 0) - return setVectorizedLongsNoOffset(rl, ru, ov, ret); - else - return setVectorizedLongsWithOffset(rl, ru, ov, ret); + setVectorizedLongsNoOffset(rl, ru, ov); + else + setVectorizedLongsWithOffset(rl, ru, ov); } - private static long[] setVectorizedLongsNoOffset(int rl, int ru, long[] ov, long[] ret) { + private void setVectorizedLongsNoOffset(int rl, int ru, long[] ov) { final long remainderEnd = (ru + 1) % 64L; final long remainderEndInv = 64L - remainderEnd; - final int last = ov.length -1; + final int last = ov.length - 1; int retP = rl / 64; // assign all full. 
- for(int j = 0; j < last; j++) { - ret[retP] = ov[j]; - retP++; - } + for(int j = 0; j < last; j++, retP++) + _data[retP] = ov[j]; // handle tail. if(remainderEnd != 0) { // clear ret in the area. - final long r = (ret[retP] >>> remainderEnd) << remainderEnd; + final long r = (_data[retP] >>> remainderEnd) << remainderEnd; final long v = (ov[last] << remainderEndInv) >>> remainderEndInv; // assign ret in the area. - ret[retP] = r ^ v; + _data[retP] = r ^ v; } else - ret[retP] = ov[last]; - return ret; + _data[retP] = ov[last]; } - private static long[] setVectorizedLongsWithOffset(int rl, int ru, long[] ov, long[] ret) { + private void setVectorizedLongsWithOffset(int rl, int ru, long[] ov) { final long remainder = rl % 64L; final long invRemainder = 64L - remainder; - final int last = ov.length -1; - final int lastP = (ru+1) / 64; - final long finalOriginal = ret[lastP]; // original log at the ru location. + final int last = ov.length - 1; + final int lastP = (ru + 1) / 64; + final long finalOriginal = _data[lastP]; // original log at the ru location. int retP = rl / 64; // pointer for current long to edit - + // first mask out previous and then continue // mask by shifting two times (easier than constructing a mask) - ret[retP] = (ret[retP] << invRemainder) >>> invRemainder; - + _data[retP] = (_data[retP] << invRemainder) >>> invRemainder; + // middle full 64 bit overwrite no need to mask first. 
// do not include last (it has to be specially handled) for(int j = 0; j < last; j++) { final long v = ov[j]; - ret[retP] = ret[retP] ^ (v << remainder); + _data[retP] = _data[retP] ^ (v << remainder); retP++; - ret[retP] = v >>> invRemainder; + _data[retP] = v >>> invRemainder; } - - ret[retP] = (ov[last] << remainder) ^ ret[retP]; + + _data[retP] = (ov[last] << remainder) ^ _data[retP]; retP++; - if(retP < ret.length && retP <= lastP) // aka there is a remainder - ret[retP] = ov[last] >>> invRemainder; - + if(retP < _data.length && retP <= lastP) // aka there is a remainder + _data[retP] = ov[last] >>> invRemainder; + // reassign everything outside range of ru. final long remainderEnd = (ru + 1) % 64L; final long remainderEndInv = 64L - remainderEnd; - ret[lastP] = (ret[lastP] << remainderEndInv) >>> remainderEndInv; - ret[lastP] = ret[lastP] ^ (finalOriginal >>> remainderEnd) << remainderEnd; - - return ret; + _data[lastP] = (_data[lastP] << remainderEndInv) >>> remainderEndInv; + _data[lastP] = _data[lastP] ^ (finalOriginal >>> remainderEnd) << remainderEnd; + } @Override public void setNz(int rl, int ru, Array value) { - if(value instanceof BitSetArray) { - throw new NotImplementedException(); + if(value instanceof BooleanArray) { + final boolean[] data2 = ((BooleanArray) value)._data; + for(int i = rl; i <= ru; i++) + if(data2[i]) + set(i, data2[i]); } else { + // TODO add an vectorized setNz. 
+ for(int i = rl; i <= ru; i++) { + final boolean v = value.get(i); + if(v) + set(i, v); + } + } + } - boolean[] data2 = ((BooleanArray) value)._data; - for(int i = rl; i < ru + 1; i++) - if(data2[i]) - _data.set(i, data2[i]); + @Override + public void setFromOtherTypeNz(int rl, int ru, Array value) { + final ValueType vt = value.getValueType(); + for(int i = rl; i <= ru; i++) { + boolean v = UtilFunctions.objectToBoolean(vt, value.get(i)); + if(v) + set(i, v); } } @Override public void append(String value) { - append(Boolean.parseBoolean(value)); + append(BooleanArray.parseBoolean(value)); } @Override public void append(Boolean value) { - _data.set(_size, value); + if(_data.length * 64 < _size + 1) + _data = Arrays.copyOf(_data, newSize()); + set(_size, value); _size++; } + @Override + public int newSize() { + return _data.length * 2; + } + @Override public void write(DataOutput out) throws IOException { out.writeByte(FrameArrayType.BITSET.ordinal()); - long[] internals = _data.toLongArray(); - out.writeInt(internals.length); - for(int i = 0; i < internals.length; i++) - out.writeLong(internals[i]); + out.writeInt(_data.length); + for(int i = 0; i < _data.length; i++) + out.writeLong(_data[i]); } @Override public void readFields(DataInput in) throws IOException { - long[] internalLong = new long[in.readInt()]; - for(int i = 0; i < internalLong.length; i++) - internalLong[i] = in.readLong(); - _data = BitSet.valueOf(internalLong); + _data = new long[in.readInt()]; + for(int i = 0; i < _data.length; i++) + _data[i] = in.readLong(); } @Override - public Array clone() { - long[] d = _data.toLongArray(); - int ln = d.length; - long[] nd = Arrays.copyOf(d, ln); - BitSet nBS = BitSet.valueOf(nd); - return new BitSetArray(nBS, _size); + public BitSetArray clone() { + return new BitSetArray(Arrays.copyOf(_data, _data.length), _size); } @Override - public Array slice(int rl, int ru) { - return new BitSetArray(_data.get(rl, ru), ru - rl); + public BitSetArray slice(int rl, 
int ru) { + return ru - rl > 30 // if over threshold + ? sliceVectorized(rl, ru) // slice vectorized + : sliceSimple(rl, ru); // slice via get } - @Override - public Array sliceTransform(int rl, int ru, ValueType vt) { - return slice(rl, ru); + private BitSetArray sliceSimple(int rl, int ru) { + final boolean[] ret = new boolean[ru - rl]; + for(int i = rl, off = 0; i < ru; i++, off++) + ret[off] = get(i); + return new BitSetArray(ret); + } + + private BitSetArray sliceVectorized(int rl, int ru) { + + final long[] ret = new long[(ru - rl) / 64 + 1]; + + final long BitIndexMask = (1 << 6) - 1L; + final long lastMask = 0xffffffffffffffffL >>> -ru; + + // targetWords + final int tW = ((ru - rl - 1) >>> 6) + 1; + // sourceIndex + int sI = rl >> 6; + + boolean aligned = (rl & BitIndexMask) == 0; + + // all but last + if(aligned) { + for(int i = 0; i < tW - 1; i++, sI++) { + ret[i] = _data[sI]; + } + } + else { + for(int i = 0; i < tW - 1; i++, sI++) { + ret[i] = (_data[sI] >>> rl) | (_data[sI + 1] << -rl); + } + } + + // last + ret[tW - 1] = // + (((ru - 1) & BitIndexMask)) < (rl & BitIndexMask) // + ? (_data[sI] >>> rl) | (_data[sI + 1] & lastMask) << -rl // + : (_data[sI] & lastMask) >>> rl; + + return new BitSetArray(ret, ru - rl); } @Override public void reset(int size) { - _data = new BitSet(); + _data = new long[size / 64 + 1]; _size = size; } @Override - public byte[] getAsByteArray(int nRow) { + public byte[] getAsByteArray() { // over allocating here.. we could maybe bit pack? - ByteBuffer booleanBuffer = ByteBuffer.allocate(nRow); + ByteBuffer booleanBuffer = ByteBuffer.allocate(_size); booleanBuffer.order(ByteOrder.nativeOrder()); // TODO: fix inefficient transfer 8 x bigger. // We should do bit unpacking on the python side. - for(int i = 0; i < nRow; i++) - booleanBuffer.put((byte) (_data.get(i) ? 1 : 0)); + for(int i = 0; i < _size; i++) + booleanBuffer.put((byte) (get(i) ? 
1 : 0)); return booleanBuffer.array(); } @@ -277,77 +353,104 @@ public FrameArrayType getFrameArrayType() { @Override public long getInMemorySize() { - long size = super.getInMemorySize() + 8; // object header + object reference - size += MemoryEstimates.bitSetCost(_size); + return estimateInMemorySize(_size); + } + + public static long estimateInMemorySize(int nRow) { + long size = baseMemoryCost(); // object header + object reference + size += MemoryEstimates.longArrayCost((nRow >> 6) + 1); return size; } @Override public long getExactSerializedSize() { long size = 1 + 4; - size += _data.toLongArray().length * 8; + size += _data.length * 8; return size; } @Override - protected Array changeTypeBitSet() { + protected Array changeTypeBitSet() { return clone(); } @Override - protected Array changeTypeBoolean() { + protected Array changeTypeBoolean() { boolean[] ret = new boolean[size()]; for(int i = 0; i < size(); i++) // if ever relevant use next set bit instead. // to increase speed, but it should not be the case in general - ret[i] = _data.get(i); + ret[i] = get(i); return new BooleanArray(ret); } @Override - protected Array changeTypeDouble() { + protected Array changeTypeDouble() { double[] ret = new double[size()]; for(int i = 0; i < size(); i++) - ret[i] = _data.get(i) ? 1.0 : 0.0; + ret[i] = get(i) ? 1.0 : 0.0; return new DoubleArray(ret); } @Override - protected Array changeTypeFloat() { + protected Array changeTypeFloat() { float[] ret = new float[size()]; for(int i = 0; i < size(); i++) - ret[i] = _data.get(i) ? 1.0f : 0.0f; + ret[i] = get(i) ? 1.0f : 0.0f; return new FloatArray(ret); } @Override - protected Array changeTypeInteger() { + protected Array changeTypeInteger() { int[] ret = new int[size()]; for(int i = 0; i < size(); i++) - ret[i] = _data.get(i) ? 1 : 0; + ret[i] = get(i) ? 
1 : 0; return new IntegerArray(ret); } @Override - protected Array changeTypeLong() { + protected Array changeTypeLong() { long[] ret = new long[size()]; for(int i = 0; i < size(); i++) - ret[i] = _data.get(i) ? 1L : 0L; + ret[i] = get(i) ? 1L : 0L; return new LongArray(ret); } + @Override + protected Array changeTypeString() { + String[] ret = new String[size()]; + for(int i = 0; i < size(); i++) + ret[i] = get(i).toString(); + return new StringArray(ret); + } + + @Override + public void fill(String value) { + fill(BooleanArray.parseBoolean(value)); + } + + @Override + public void fill(Boolean value) { + for(int i = 0; i < _size / 64 + 1; i++) + _data[i] = value ? -1L : 0L; + } + @Override public String toString() { StringBuilder sb = new StringBuilder(_size * 5 + 2); sb.append(super.toString() + ":["); - for(int i = 0; i < _size - 1; i++) - sb.append((_data.get(i) ? 1 : 0) + ","); - sb.append(_data.get(_size - 1) ? 1 : 0); + for(int i = 0; i < _size; i++) + sb.append((get(i) ? 1 : 0)); sb.append("]"); return sb.toString(); } + @Override + public double getAsDouble(int i){ + return get(i) ? 
1.0: 0.0; + } + public static String longToBits(long l) { String bits = Long.toBinaryString(l); StringBuilder sb = new StringBuilder(64); diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/BooleanArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/BooleanArray.java index 59dff6f11f9..72a67421a49 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/BooleanArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/BooleanArray.java @@ -26,17 +26,17 @@ import java.nio.ByteOrder; import java.util.Arrays; -import org.apache.commons.lang.NotImplementedException; import org.apache.sysds.common.Types.ValueType; import org.apache.sysds.runtime.frame.data.columns.ArrayFactory.FrameArrayType; +import org.apache.sysds.runtime.util.UtilFunctions; import org.apache.sysds.utils.MemoryEstimates; public class BooleanArray extends Array { protected boolean[] _data; public BooleanArray(boolean[] data) { + super(data.length); _data = data; - _size = _data.length; } public boolean[] get() { @@ -50,12 +50,12 @@ public Boolean get(int index) { @Override public void set(int index, Boolean value) { - _data[index] = (value != null) ? value : false; + _data[index] = value != null && value; } - + @Override public void set(int index, double value) { - _data[index] = value == 0 ? 
false : true; + _data[index] = value == 1.0; } @Override @@ -64,8 +64,10 @@ public void set(int rl, int ru, Array value) { } @Override - public void setFromOtherType(int rl, int ru, Array value){ - throw new NotImplementedException(); + public void setFromOtherType(int rl, int ru, Array value) { + final ValueType vt = value.getValueType(); + for(int i = rl; i <= ru; i++) + set(i, UtilFunctions.objectToBoolean(vt, value.get(i))); } @Override @@ -79,15 +81,34 @@ public void set(int rl, int ru, Array value, int rlSrc) { @Override public void setNz(int rl, int ru, Array value) { - boolean[] data2 = ((BooleanArray) value)._data; - for(int i = rl; i < ru + 1; i++) - if(data2[i]) - _data[i] = data2[i]; + if(value instanceof BooleanArray) { + boolean[] data2 = ((BooleanArray) value)._data; + for(int i = rl; i <= ru; i++) + if(data2[i]) + _data[i] = data2[i]; + } + else { + for(int i = rl; i <= ru; i++) { + final boolean v = value.get(i); + if(v) + _data[i] = v; + } + } + } + + @Override + public void setFromOtherTypeNz(int rl, int ru, Array value) { + final ValueType vt = value.getValueType(); + for(int i = rl; i <= ru; i++) { + boolean v = UtilFunctions.objectToBoolean(vt, value.get(i)); + if(v) + _data[i] = v; + } } @Override public void append(String value) { - append(Boolean.parseBoolean(value)); + append(parseBoolean(value)); } @Override @@ -121,11 +142,6 @@ public Array slice(int rl, int ru) { return new BooleanArray(Arrays.copyOfRange(_data, rl, ru)); } - @Override - public Array sliceTransform(int rl, int ru, ValueType vt) { - return slice(rl, ru); - } - @Override public void reset(int size) { if(_data.length < size) @@ -134,11 +150,11 @@ public void reset(int size) { } @Override - public byte[] getAsByteArray(int nRow) { + public byte[] getAsByteArray() { // over allocating here.. we could maybe bit pack? 
- ByteBuffer booleanBuffer = ByteBuffer.allocate(nRow); + ByteBuffer booleanBuffer = ByteBuffer.allocate(_size); booleanBuffer.order(ByteOrder.nativeOrder()); - for(int i = 0; i < nRow; i++) + for(int i = 0; i < _size; i++) booleanBuffer.put((byte) (_data[i] ? 1 : 0)); return booleanBuffer.array(); } @@ -160,8 +176,12 @@ public FrameArrayType getFrameArrayType() { @Override public long getInMemorySize() { - long size = super.getInMemorySize() ; // object header + object reference - size += MemoryEstimates.booleanArrayCost(_data.length); + return estimateInMemorySize(_size); + } + + public static long estimateInMemorySize(int nRow) { + long size = baseMemoryCost(); // object header + object reference + size += MemoryEstimates.booleanArrayCost(nRow); return size; } @@ -171,17 +191,17 @@ public long getExactSerializedSize() { } @Override - protected Array changeTypeBitSet() { + protected Array changeTypeBitSet() { return new BitSetArray(_data); } @Override - protected Array changeTypeBoolean() { + protected Array changeTypeBoolean() { return clone(); } @Override - protected Array changeTypeDouble() { + protected Array changeTypeDouble() { double[] ret = new double[size()]; for(int i = 0; i < size(); i++) ret[i] = _data[i] ? 1.0 : 0.0; @@ -189,7 +209,7 @@ protected Array changeTypeDouble() { } @Override - protected Array changeTypeFloat() { + protected Array changeTypeFloat() { float[] ret = new float[size()]; for(int i = 0; i < size(); i++) ret[i] = _data[i] ? 1.0f : 0.0f; @@ -197,7 +217,7 @@ protected Array changeTypeFloat() { } @Override - protected Array changeTypeInteger() { + protected Array changeTypeInteger() { int[] ret = new int[size()]; for(int i = 0; i < size(); i++) ret[i] = _data[i] ? 1 : 0; @@ -205,13 +225,33 @@ protected Array changeTypeInteger() { } @Override - protected Array changeTypeLong() { + protected Array changeTypeLong() { long[] ret = new long[size()]; for(int i = 0; i < size(); i++) ret[i] = _data[i] ? 
1L : 0L; return new LongArray(ret); } + @Override + protected Array changeTypeString() { + String[] ret = new String[size()]; + for(int i = 0; i < size(); i++) + ret[i] = get(i).toString(); + return new StringArray(ret); + } + + @Override + public void fill(String value) { + fill(parseBoolean(value)); + } + + @Override + public void fill(Boolean value) { + for(int i = 0; i < _size; i++) + _data[i] = value; + + } + @Override public String toString() { StringBuilder sb = new StringBuilder(_data.length * 5 + 2); @@ -222,4 +262,13 @@ public String toString() { sb.append("]"); return sb.toString(); } + + @Override + public double getAsDouble(int i){ + return _data[i] ? 1.0: 0.0; + } + + protected static boolean parseBoolean(String value) { + return value != null && (Boolean.parseBoolean(value) || value.equals("1") || value.equals("1.0")); + } } diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/DoubleArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/DoubleArray.java index bd70e1a8bf5..8102a2e4306 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/DoubleArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/DoubleArray.java @@ -27,19 +27,19 @@ import java.util.Arrays; import java.util.BitSet; -import org.apache.commons.lang.NotImplementedException; import org.apache.sysds.common.Types.ValueType; import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.frame.data.FrameUtil; import org.apache.sysds.runtime.frame.data.columns.ArrayFactory.FrameArrayType; +import org.apache.sysds.runtime.util.UtilFunctions; import org.apache.sysds.utils.MemoryEstimates; public class DoubleArray extends Array { private double[] _data; public DoubleArray(double[] data) { + super(data.length); _data = data; - _size = _data.length; } public double[] get() { @@ -68,7 +68,9 @@ public void set(int rl, int ru, Array value) { @Override public void setFromOtherType(int rl, int ru, Array value) 
{ - throw new NotImplementedException(); + final ValueType vt = value.getValueType(); + for(int i = rl; i <= ru; i++) + _data[i] = UtilFunctions.objectToDouble(vt, value.get(i)); } @Override @@ -79,14 +81,24 @@ public void set(int rl, int ru, Array value, int rlSrc) { @Override public void setNz(int rl, int ru, Array value) { double[] data2 = ((DoubleArray) value)._data; - for(int i = rl; i < ru + 1; i++) + for(int i = rl; i <= ru; i++) if(data2[i] != 0) _data[i] = data2[i]; } + @Override + public void setFromOtherTypeNz(int rl, int ru, Array value) { + final ValueType vt = value.getValueType(); + for(int i = rl; i <= ru; i++) { + double v = UtilFunctions.objectToDouble(vt, value.get(i)); + if(v != 0) + _data[i] = v; + } + } + @Override public void append(String value) { - append((value != null) ? Double.parseDouble(value) : null); + append(parseDouble(value)); } @Override @@ -120,11 +132,6 @@ public Array slice(int rl, int ru) { return new DoubleArray(Arrays.copyOfRange(_data, rl, ru)); } - @Override - public Array sliceTransform(int rl, int ru, ValueType vt) { - return slice(rl, ru); - } - @Override public void reset(int size) { if(_data.length < size) @@ -133,10 +140,10 @@ public void reset(int size) { } @Override - public byte[] getAsByteArray(int nRow) { - ByteBuffer doubleBuffer = ByteBuffer.allocate(8 * nRow); + public byte[] getAsByteArray() { + ByteBuffer doubleBuffer = ByteBuffer.allocate(8 * _size); doubleBuffer.order(ByteOrder.nativeOrder()); - for(int i = 0; i < nRow; i++) + for(int i = 0; i < _size; i++) doubleBuffer.putDouble(_data[i]); return doubleBuffer.array(); } @@ -212,19 +219,19 @@ public long getExactSerializedSize() { } @Override - protected Array changeTypeBitSet() { + protected Array changeTypeBitSet() { BitSet ret = new BitSet(size()); for(int i = 0; i < size(); i++) { if(_data[i] != 0 && _data[i] != 1) throw new DMLRuntimeException( "Unable to change to Boolean from Integer array because of value:" + _data[i]); - ret.set(i, _data[i] == 0 
? false : true); + ret.set(i, _data[i] == 0 ? false : true); } return new BitSetArray(ret, size()); } @Override - protected Array changeTypeBoolean() { + protected Array changeTypeBoolean() { boolean[] ret = new boolean[size()]; for(int i = 0; i < size(); i++) { if(_data[i] != 0 && _data[i] != 1) @@ -236,12 +243,12 @@ protected Array changeTypeBoolean() { } @Override - protected Array changeTypeDouble() { + protected Array changeTypeDouble() { return clone(); } @Override - protected Array changeTypeFloat() { + protected Array changeTypeFloat() { float[] ret = new float[size()]; for(int i = 0; i < size(); i++) ret[i] = (float) _data[i]; @@ -249,7 +256,7 @@ protected Array changeTypeFloat() { } @Override - protected Array changeTypeInteger() { + protected Array changeTypeInteger() { int[] ret = new int[size()]; for(int i = 0; i < size(); i++) { if(_data[i] != (int) _data[i]) @@ -260,7 +267,7 @@ protected Array changeTypeInteger() { } @Override - protected Array changeTypeLong() { + protected Array changeTypeLong() { long[] ret = new long[size()]; for(int i = 0; i < size(); i++) { if(_data[i] != (long) _data[i]) @@ -270,6 +277,37 @@ protected Array changeTypeLong() { return new LongArray(ret); } + @Override + protected Array changeTypeString() { + String[] ret = new String[size()]; + for(int i = 0; i < size(); i++) + ret[i] = get(i).toString(); + return new StringArray(ret); + } + + @Override + public void fill(String value) { + fill(parseDouble(value)); + } + + @Override + public void fill(Double value) { + for(int i = 0; i < _size; i++) + _data[i] = value; + } + + @Override + public double getAsDouble(int i) { + return _data[i]; + } + + protected static double parseDouble(String value) { + if(value == null) + return 0.0; + else + return Double.parseDouble(value); + } + @Override public String toString() { StringBuilder sb = new StringBuilder(_data.length * 5 + 2); diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/FloatArray.java 
b/src/main/java/org/apache/sysds/runtime/frame/data/columns/FloatArray.java index 25b5144cece..839e64d41f7 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/FloatArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/FloatArray.java @@ -27,18 +27,18 @@ import java.util.Arrays; import java.util.BitSet; -import org.apache.commons.lang.NotImplementedException; import org.apache.sysds.common.Types.ValueType; import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.frame.data.columns.ArrayFactory.FrameArrayType; +import org.apache.sysds.runtime.util.UtilFunctions; import org.apache.sysds.utils.MemoryEstimates; public class FloatArray extends Array { private float[] _data; public FloatArray(float[] data) { + super(data.length); _data = data; - _size = _data.length; } public float[] get() { @@ -67,7 +67,9 @@ public void set(int rl, int ru, Array value) { @Override public void setFromOtherType(int rl, int ru, Array value) { - throw new NotImplementedException(); + final ValueType vt = value.getValueType(); + for(int i = rl; i <= ru; i++) + _data[i] = UtilFunctions.objectToFloat(vt, value.get(i)); } @Override @@ -78,14 +80,24 @@ public void set(int rl, int ru, Array value, int rlSrc) { @Override public void setNz(int rl, int ru, Array value) { float[] data2 = ((FloatArray) value)._data; - for(int i = rl; i < ru + 1; i++) + for(int i = rl; i <= ru; i++) if(data2[i] != 0) _data[i] = data2[i]; } + @Override + public void setFromOtherTypeNz(int rl, int ru, Array value) { + final ValueType vt = value.getValueType(); + for(int i = rl; i <= ru; i++) { + float v = UtilFunctions.objectToFloat(vt, value.get(i)); + if(v != 0) + _data[i] = v; + } + } + @Override public void append(String value) { - append((value != null) ? 
Float.parseFloat(value) : null); + append(parseFloat(value)); } @Override @@ -119,11 +131,6 @@ public Array slice(int rl, int ru) { return new FloatArray(Arrays.copyOfRange(_data, rl, ru)); } - @Override - public Array sliceTransform(int rl, int ru, ValueType vt) { - return slice(rl, ru); - } - @Override public void reset(int size) { if(_data.length < size) @@ -132,10 +139,10 @@ public void reset(int size) { } @Override - public byte[] getAsByteArray(int nRow) { - ByteBuffer floatBuffer = ByteBuffer.allocate(8 * nRow); + public byte[] getAsByteArray() { + ByteBuffer floatBuffer = ByteBuffer.allocate(8 * _size); floatBuffer.order(ByteOrder.nativeOrder()); - for(int i = 0; i < nRow; i++) + for(int i = 0; i < _size; i++) floatBuffer.putFloat(_data[i]); return floatBuffer.array(); } @@ -168,7 +175,7 @@ public long getExactSerializedSize() { } @Override - protected Array changeTypeBitSet() { + protected Array changeTypeBitSet() { BitSet ret = new BitSet(size()); for(int i = 0; i < size(); i++) { if(_data[i] != 0 && _data[i] != 1) @@ -180,7 +187,7 @@ protected Array changeTypeBitSet() { } @Override - protected Array changeTypeBoolean() { + protected Array changeTypeBoolean() { boolean[] ret = new boolean[size()]; for(int i = 0; i < size(); i++) { if(_data[i] != 0 && _data[i] != 1) @@ -192,7 +199,7 @@ protected Array changeTypeBoolean() { } @Override - protected Array changeTypeDouble() { + protected Array changeTypeDouble() { double[] ret = new double[size()]; for(int i = 0; i < size(); i++) ret[i] = (double) _data[i]; @@ -200,7 +207,7 @@ protected Array changeTypeDouble() { } @Override - protected Array changeTypeInteger() { + protected Array changeTypeInteger() { int[] ret = new int[size()]; for(int i = 0; i < size(); i++) { if(_data[i] != (int) _data[i]) @@ -211,7 +218,7 @@ protected Array changeTypeInteger() { } @Override - protected Array changeTypeLong() { + protected Array changeTypeLong() { long[] ret = new long[size()]; for(int i = 0; i < size(); i++) ret[i] = 
(int) _data[i]; @@ -219,10 +226,41 @@ protected Array changeTypeLong() { } @Override - protected Array changeTypeFloat() { + protected Array changeTypeFloat() { return clone(); } + @Override + protected Array changeTypeString() { + String[] ret = new String[size()]; + for(int i = 0; i < size(); i++) + ret[i] = get(i).toString(); + return new StringArray(ret); + } + + @Override + public void fill(String value) { + fill(parseFloat(value)); + } + + @Override + public void fill(Float value) { + for(int i = 0; i < _size; i++) + _data[i] = value; + } + + @Override + public double getAsDouble(int i){ + return _data[i]; + } + + protected static float parseFloat(String value) { + if(value == null) + return 0.0f; + else + return Float.parseFloat(value); + } + @Override public String toString() { StringBuilder sb = new StringBuilder(_data.length * 5 + 2); diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/IntegerArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/IntegerArray.java index 6cc839d9458..37f6f2b9101 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/IntegerArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/IntegerArray.java @@ -27,18 +27,18 @@ import java.util.Arrays; import java.util.BitSet; -import org.apache.commons.lang.NotImplementedException; import org.apache.sysds.common.Types.ValueType; import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.frame.data.columns.ArrayFactory.FrameArrayType; +import org.apache.sysds.runtime.util.UtilFunctions; import org.apache.sysds.utils.MemoryEstimates; public class IntegerArray extends Array { private int[] _data; public IntegerArray(int[] data) { + super(data.length); _data = data; - _size = _data.length; } public int[] get() { @@ -67,7 +67,10 @@ public void set(int rl, int ru, Array value) { @Override public void setFromOtherType(int rl, int ru, Array value) { - throw new NotImplementedException(); + final 
ValueType vt = value.getValueType(); + for(int i = rl; i <= ru; i++) + _data[i] = UtilFunctions.objectToInteger(vt, value.get(i)); + } @Override @@ -78,14 +81,24 @@ public void set(int rl, int ru, Array value, int rlSrc) { @Override public void setNz(int rl, int ru, Array value) { int[] data2 = ((IntegerArray) value)._data; - for(int i = rl; i < ru + 1; i++) + for(int i = rl; i <= ru; i++) if(data2[i] != 0) _data[i] = data2[i]; } + @Override + public void setFromOtherTypeNz(int rl, int ru, Array value) { + final ValueType vt = value.getValueType(); + for(int i = rl; i <= ru; i++) { + int v = UtilFunctions.objectToInteger(vt, value.get(i)); + if(v != 0) + _data[i] = v; + } + } + @Override public void append(String value) { - append((value != null) ? Integer.parseInt(value.trim()) : null); + append(parseInt(value)); } @Override @@ -119,11 +132,6 @@ public Array slice(int rl, int ru) { return new IntegerArray(Arrays.copyOfRange(_data, rl, ru)); } - @Override - public Array sliceTransform(int rl, int ru, ValueType vt) { - return slice(rl, ru); - } - public void reset(int size) { if(_data.length < size) _data = new int[size]; @@ -131,10 +139,10 @@ public void reset(int size) { } @Override - public byte[] getAsByteArray(int nRow) { - ByteBuffer intBuffer = ByteBuffer.allocate(4 * nRow); + public byte[] getAsByteArray() { + ByteBuffer intBuffer = ByteBuffer.allocate(4 * _size); intBuffer.order(ByteOrder.LITTLE_ENDIAN); - for(int i = 0; i < nRow; i++) + for(int i = 0; i < _size; i++) intBuffer.putInt(_data[i]); return intBuffer.array(); } @@ -167,7 +175,7 @@ public long getExactSerializedSize() { } @Override - protected Array changeTypeBitSet() { + protected Array changeTypeBitSet() { BitSet ret = new BitSet(size()); for(int i = 0; i < size(); i++) { if(_data[i] != 0 && _data[i] != 1) @@ -179,7 +187,7 @@ protected Array changeTypeBitSet() { } @Override - protected Array changeTypeBoolean() { + protected Array changeTypeBoolean() { boolean[] ret = new boolean[size()]; 
for(int i = 0; i < size(); i++) { if(_data[i] < 0 || _data[i] > 1) @@ -191,7 +199,7 @@ protected Array changeTypeBoolean() { } @Override - protected Array changeTypeDouble() { + protected Array changeTypeDouble() { double[] ret = new double[size()]; for(int i = 0; i < size(); i++) ret[i] = (double) _data[i]; @@ -199,7 +207,7 @@ protected Array changeTypeDouble() { } @Override - protected Array changeTypeFloat() { + protected Array changeTypeFloat() { float[] ret = new float[size()]; for(int i = 0; i < size(); i++) ret[i] = (float) _data[i]; @@ -207,18 +215,57 @@ protected Array changeTypeFloat() { } @Override - protected Array changeTypeInteger() { + protected Array changeTypeInteger() { return clone(); } @Override - protected Array changeTypeLong() { + protected Array changeTypeLong() { long[] ret = new long[size()]; for(int i = 0; i < size(); i++) ret[i] = _data[i]; return new LongArray(ret); } + @Override + protected Array changeTypeString() { + String[] ret = new String[size()]; + for(int i = 0; i < size(); i++) + ret[i] = get(i).toString(); + return new StringArray(ret); + } + + @Override + public void fill(String value) { + fill(parseInt(value)); + } + + @Override + public void fill(Integer value) { + for(int i = 0; i < _size; i++) + _data[i] = value; + } + + @Override + public double getAsDouble(int i){ + return _data[i]; + } + + protected static int parseInt(String s) { + if(s == null) + return 0; + try { + return Integer.parseInt(s); + } + catch(NumberFormatException e) { + if(s.contains(".")){ + return Integer.parseInt(s.split("\\.")[0]); + } + else + throw e; + } + } + @Override public String toString() { StringBuilder sb = new StringBuilder(_data.length * 5 + 2); diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/LongArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/LongArray.java index bf217ecf055..36ceda15d8e 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/LongArray.java +++ 
b/src/main/java/org/apache/sysds/runtime/frame/data/columns/LongArray.java @@ -27,18 +27,18 @@ import java.util.Arrays; import java.util.BitSet; -import org.apache.commons.lang.NotImplementedException; import org.apache.sysds.common.Types.ValueType; import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.frame.data.columns.ArrayFactory.FrameArrayType; +import org.apache.sysds.runtime.util.UtilFunctions; import org.apache.sysds.utils.MemoryEstimates; public class LongArray extends Array { private long[] _data; public LongArray(long[] data) { + super(data.length); _data = data; - _size = _data.length; } public long[] get() { @@ -67,7 +67,9 @@ public void set(int rl, int ru, Array value) { @Override public void setFromOtherType(int rl, int ru, Array value) { - throw new NotImplementedException(); + final ValueType vt = value.getValueType(); + for(int i = rl; i <= ru; i++) + _data[i] = UtilFunctions.objectToLong(vt, value.get(i)); } @Override @@ -78,14 +80,24 @@ public void set(int rl, int ru, Array value, int rlSrc) { @Override public void setNz(int rl, int ru, Array value) { long[] data2 = ((LongArray) value)._data; - for(int i = rl; i < ru + 1; i++) + for(int i = rl; i <= ru; i++) if(data2[i] != 0) _data[i] = data2[i]; } + @Override + public void setFromOtherTypeNz(int rl, int ru, Array value) { + final ValueType vt = value.getValueType(); + for(int i = rl; i <= ru; i++) { + long v = UtilFunctions.objectToLong(vt, value.get(i)); + if(v != 0) + _data[i] = v; + } + } + @Override public void append(String value) { - append((value != null) ? 
Long.parseLong(value.trim()) : null); + append(parseLong(value)); } @Override @@ -119,11 +131,6 @@ public Array slice(int rl, int ru) { return new LongArray(Arrays.copyOfRange(_data, rl, ru)); } - @Override - public Array sliceTransform(int rl, int ru, ValueType vt) { - return slice(rl, ru); - } - @Override public void reset(int size) { if(_data.length < size) @@ -132,10 +139,10 @@ public void reset(int size) { } @Override - public byte[] getAsByteArray(int nRow) { - ByteBuffer longBuffer = ByteBuffer.allocate(8 * nRow); + public byte[] getAsByteArray() { + ByteBuffer longBuffer = ByteBuffer.allocate(8 * _size); longBuffer.order(ByteOrder.LITTLE_ENDIAN); - for(int i = 0; i < nRow; i++) + for(int i = 0; i < _size; i++) longBuffer.putLong(_data[i]); return longBuffer.array(); } @@ -168,7 +175,7 @@ public long getExactSerializedSize() { } @Override - protected Array changeTypeBitSet() { + protected Array changeTypeBitSet() { BitSet ret = new BitSet(size()); for(int i = 0; i < size(); i++) { if(_data[i] != 0 && _data[i] != 1) @@ -180,7 +187,7 @@ protected Array changeTypeBitSet() { } @Override - protected Array changeTypeBoolean() { + protected Array changeTypeBoolean() { boolean[] ret = new boolean[size()]; for(int i = 0; i < size(); i++) { if(_data[i] < 0 || _data[i] > 1) @@ -192,7 +199,7 @@ protected Array changeTypeBoolean() { } @Override - protected Array changeTypeDouble() { + protected Array changeTypeDouble() { double[] ret = new double[size()]; for(int i = 0; i < size(); i++) ret[i] = (double) _data[i]; @@ -200,7 +207,7 @@ protected Array changeTypeDouble() { } @Override - protected Array changeTypeFloat() { + protected Array changeTypeFloat() { float[] ret = new float[size()]; for(int i = 0; i < size(); i++) ret[i] = (float) _data[i]; @@ -208,7 +215,7 @@ protected Array changeTypeFloat() { } @Override - protected Array changeTypeInteger() { + protected Array changeTypeInteger() { int[] ret = new int[size()]; for(int i = 0; i < size(); i++) { if(_data[i] != 
(long) (int) _data[i]) @@ -219,10 +226,48 @@ protected Array changeTypeInteger() { } @Override - protected Array changeTypeLong() { + protected Array changeTypeLong() { return clone(); } + @Override + protected Array changeTypeString() { + String[] ret = new String[size()]; + for(int i = 0; i < size(); i++) + ret[i] = get(i).toString(); + return new StringArray(ret); + } + + @Override + public void fill(String value) { + fill(parseLong(value)); + } + + @Override + public void fill(Long value) { + for(int i = 0; i < _size; i++) + _data[i] = value; + } + + @Override + public double getAsDouble(int i) { + return _data[i]; + } + + protected static long parseLong(String s) { + if(s == null) + return 0; + try { + return Long.parseLong(s); + } + catch(NumberFormatException e) { + if(s.contains(".")) + return Long.parseLong(s.split("\\.")[0]); + else + throw e; + } + } + @Override public String toString() { StringBuilder sb = new StringBuilder(_data.length * 5 + 2); diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java index 5996431b46e..807c21088a2 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java @@ -38,8 +38,8 @@ public class StringArray extends Array { private String[] _data; public StringArray(String[] data) { + super(data.length); _data = data; - _size = _data.length; } public String[] get() { @@ -68,7 +68,13 @@ public void set(int rl, int ru, Array value) { @Override public void setFromOtherType(int rl, int ru, Array value) { - throw new NotImplementedException(); + for(int i = rl; i <= ru; i++) { + final Object v = value.get(i); + if(v != null) + _data[i] = v.toString(); + else + _data[i] = null; + } } @Override @@ -79,11 +85,20 @@ public void set(int rl, int ru, Array value, int rlSrc) { @Override public void setNz(int rl, int ru, Array 
value) { String[] data2 = ((StringArray) value)._data; - for(int i = rl; i < ru + 1; i++) + for(int i = rl; i <= ru; i++) if(data2[i] != null) _data[i] = data2[i]; } + @Override + public void setFromOtherTypeNz(int rl, int ru, Array value) { + for(int i = rl; i <= ru; i++) { + Object v = value.get(i); + if(v != null) + _data[i] = v.toString(); + } + } + @Override public void append(String value) { if(_data.length <= _size) @@ -117,71 +132,6 @@ public Array slice(int rl, int ru) { return new StringArray(Arrays.copyOfRange(_data, rl, ru)); } - @Override - public Array sliceTransform(int rl, int ru, ValueType vt) { - LOG.error(rl + " " + ru + " len: " + _data.length); - try { - switch(vt) { - case BOOLEAN: - return sliceTransformBoolean(rl, ru); - case INT32: - return sliceTransformInt32(rl, ru); - case INT64: - return sliceTransformInt64(rl, ru); - case FP64: - return sliceTransformFP64(rl, ru); - case FP32: - return sliceTransformFP32(rl, ru); - default: - return slice(rl, ru); - } - } - catch(Exception e) { - LOG.error("Failed to slice with transform to " + vt); - return slice(rl, ru); - } - } - - private Array sliceTransformBoolean(int rl, int ru) { - boolean[] ret = new boolean[ru - rl]; - for(int i = rl, off = 0; i < ru; i++, off++) { - String val = _data[i].toLowerCase(); - if(val.matches("true|t|1|1\\.0+")) // if true - ret[off] = true; - else if(!val.matches("false|f|0|0\\.0+")) // if not false - throw new DMLRuntimeException("Invalid transform to boolean on: " + val); - } - return new BooleanArray(ret); - } - - private Array sliceTransformInt32(int rl, int ru) { - int[] ret = new int[ru - rl]; - for(int i = rl, off = 0; i < ru; i++, off++) - ret[off] = Integer.parseInt(_data[i]); - return new IntegerArray(ret); - } - - private Array sliceTransformInt64(int rl, int ru) { - long[] ret = new long[ru - rl]; - for(int i = rl, off = 0; i < ru; i++, off++) - ret[off] = Long.parseLong(_data[i]); - return new LongArray(ret); - } - - private Array 
sliceTransformFP64(int rl, int ru) { - double[] ret = new double[ru - rl]; - for(int i = rl, off = 0; i < ru; i++, off++) - ret[off] = Double.parseDouble(_data[i]); - return new DoubleArray(ret); - } - - private Array sliceTransformFP32(int rl, int ru) { - float[] ret = new float[ru - rl]; - for(int i = rl, off = 0; i < ru; i++, off++) - ret[off] = Float.parseFloat(_data[i]); - return new FloatArray(ret); - } - @Override public void reset(int size) { if(_data.length < size) @@ -190,7 +140,7 @@ public void reset(int size) { } @Override - public byte[] getAsByteArray(int nRow) { + public byte[] getAsByteArray() { throw new NotImplementedException("Not Implemented getAsByte for string"); } @@ -281,12 +231,12 @@ public long getExactSerializedSize() { } @Override - protected Array changeTypeBitSet(){ + protected Array changeTypeBitSet() { return changeTypeBoolean(); } @Override - protected Array changeTypeBoolean() { + protected Array changeTypeBoolean() { // detect type of transform. if(_data[0].toLowerCase().equals("true") || _data[0].toLowerCase().equals("false")) return changeTypeBooleanStandard(); @@ -296,14 +246,14 @@ protected Array changeTypeBoolean() { throw new DMLRuntimeException("Not supported type of Strings to change to Booleans value: " + _data[0]); } - protected Array changeTypeBooleanStandard() { + protected Array changeTypeBooleanStandard() { if(size() > ArrayFactory.bitSetSwitchPoint) return changeTypeBooleanStandardBitSet(); else return changeTypeBooleanStandardArray(); } - protected Array changeTypeBooleanStandardBitSet() { + protected Array changeTypeBooleanStandardBitSet() { BitSet ret = new BitSet(size()); for(int i = 0; i < size(); i++) ret.set(i, Boolean.parseBoolean(_data[i])); @@ -311,21 +261,21 @@ protected Array changeTypeBooleanStandardBitSet() { return new BitSetArray(ret, size()); } - protected Array changeTypeBooleanStandardArray() { + protected Array changeTypeBooleanStandardArray() { boolean[] ret = new boolean[size()]; for(int i = 0; 
i < size(); i++) ret[i] = Boolean.parseBoolean(_data[i]); return new BooleanArray(ret); } - protected Array changeTypeBooleanNumeric() { + protected Array changeTypeBooleanNumeric() { if(size() > ArrayFactory.bitSetSwitchPoint) - return changeTypeBooleanStandardBitSet(); + return changeTypeBooleanNumericBitSet(); else return changeTypeBooleanNumericArray(); } - protected Array changeTypeBooleanNumericBitSet() { + protected Array changeTypeBooleanNumericBitSet() { BitSet ret = new BitSet(size()); for(int i = 0; i < size(); i++) { final boolean zero = _data[i].equals("0"); @@ -338,7 +288,7 @@ protected Array changeTypeBooleanNumericBitSet() { return new BitSetArray(ret, size()); } - protected Array changeTypeBooleanNumericArray() { + protected Array changeTypeBooleanNumericArray() { boolean[] ret = new boolean[size()]; for(int i = 0; i < size(); i++) { final boolean zero = _data[i].equals("0"); @@ -352,7 +302,7 @@ protected Array changeTypeBooleanNumericArray() { } @Override - protected Array changeTypeDouble() { + protected Array changeTypeDouble() { try { double[] ret = new double[size()]; for(int i = 0; i < size(); i++) @@ -365,7 +315,7 @@ protected Array changeTypeDouble() { } @Override - protected Array changeTypeFloat() { + protected Array changeTypeFloat() { try { float[] ret = new float[size()]; for(int i = 0; i < size(); i++) @@ -378,7 +328,7 @@ protected Array changeTypeFloat() { } @Override - protected Array changeTypeInteger() { + protected Array changeTypeInteger() { try { int[] ret = new int[size()]; for(int i = 0; i < size(); i++) @@ -391,7 +341,7 @@ protected Array changeTypeInteger() { } @Override - protected Array changeTypeLong() { + protected Array changeTypeLong() { try { long[] ret = new long[size()]; for(int i = 0; i < size(); i++) @@ -404,7 +354,7 @@ protected Array changeTypeLong() { } @Override - public Array changeTypeString() { + public Array changeTypeString() { return clone(); } @@ -423,6 +373,17 @@ public Pair getMinMaxLength() { return 
new Pair<>(minLength, maxLength); } + @Override + public void fill(String value) { + for(int i = 0; i < _size; i++) + _data[i] = value; + } + + @Override + public double getAsDouble(int i){ + return Double.parseDouble(_data[i]); + } + @Override public String toString() { StringBuilder sb = new StringBuilder(_data.length * 5 + 2); diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameFromMatrixBlock.java b/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameFromMatrixBlock.java index 5cece033a61..ac8baaa40b6 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameFromMatrixBlock.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameFromMatrixBlock.java @@ -19,108 +19,374 @@ package org.apache.sysds.runtime.frame.data.lib; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.Callable; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Future; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.data.DenseBlock; import org.apache.sysds.runtime.data.SparseBlock; import org.apache.sysds.runtime.frame.data.FrameBlock; import org.apache.sysds.runtime.frame.data.columns.Array; import org.apache.sysds.runtime.frame.data.columns.ArrayFactory; import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.util.CommonThreadPool; import org.apache.sysds.runtime.util.UtilFunctions; public class FrameFromMatrixBlock { - public static FrameBlock convertToFrameBlock(MatrixBlock mb, ValueType vt) { - return convertToFrameBlock(mb, UtilFunctions.nCopies(mb.getNumColumns(), vt)); - } + protected static final Log LOG = LogFactory.getLog(FrameFromMatrixBlock.class.getName()); + + private final MatrixBlock mb; + private final ValueType[] 
schema; + private final FrameBlock frame; + + private final int blocksizeIJ = 32; // blocks of a/c+overhead in L1 cache + private final int blocksizeParallel = 64; // block size for each task - public static FrameBlock convertToFrameBlock(MatrixBlock mb, ValueType[] schema) { - if(mb.isInSparseFormat()) - return convertToFrameBlockSparse(mb, schema); + private final int m; + private final int n; + + /** Parallelization degree */ + private final int k; + private final ExecutorService pool; + + private FrameFromMatrixBlock(MatrixBlock mb, ValueType[] schema, int k) { + this.mb = mb; + m = mb.getNumRows(); + n = mb.getNumColumns(); + this.schema = schema; + this.frame = new FrameBlock(schema); + this.k = k; + if(k > 1) + pool = CommonThreadPool.get(k); else - return convertToFrameBlockDense(mb, schema); + pool = null; } - private static FrameBlock convertToFrameBlockSparse(MatrixBlock mb, ValueType[] schema) { + public static FrameBlock convertToFrameBlock(MatrixBlock mb, ValueType vt, int k) { + return new FrameFromMatrixBlock(mb, UtilFunctions.nCopies(mb.getNumColumns(), vt), k).apply(); + } + + public static FrameBlock convertToFrameBlock(MatrixBlock mb, ValueType[] schema, int k) { + return new FrameFromMatrixBlock(mb, schema, k).apply(); + } + + private FrameBlock apply() { + try { + + if(mb.isEmpty()) + convertToEmptyFrameBlock(); + else if(mb.isInSparseFormat()) + convertToFrameBlockSparse(); + else + convertToFrameBlockDense(); + if(pool != null) + pool.shutdown(); + return frame; + } + catch(InterruptedException | ExecutionException e) { + pool.shutdown(); + throw new DMLRuntimeException("failed to convert to matrix block"); + } + } + + private void convertToEmptyFrameBlock() { + frame.ensureAllocatedColumns(mb.getNumRows()); + } + + private void convertToFrameBlockSparse() { SparseBlock sblock = mb.getSparseBlock(); - FrameBlock frame = new FrameBlock(); Array[] columns = new Array[mb.getNumColumns()]; for(int i = 0; i < columns.length; i++) columns[i] = 
ArrayFactory.allocate(schema[i], mb.getNumRows()); - + for(int i = 0; i < mb.getNumRows(); i++) { - // Arrays.fill(row, null); // reset - if(sblock != null && !sblock.isEmpty(i)) { - int apos = sblock.pos(i); - int alen = sblock.size(i); - int[] aix = sblock.indexes(i); - double[] aval = sblock.values(i); - for(int j = apos; j < apos + alen; j++) { - columns[aix[j]].set(i, aval[j]); - } - } + if(sblock.isEmpty(i)) + continue; + + int apos = sblock.pos(i); + int alen = sblock.size(i); + int[] aix = sblock.indexes(i); + double[] aval = sblock.values(i); + for(int j = apos; j < apos + alen; j++) + columns[aix[j]].set(i, aval[j]); + } + frame.reset(); for(int i = 0; i < columns.length; i++) frame.appendColumn(columns[i]); - return frame; } - private static FrameBlock convertToFrameBlockDense(MatrixBlock mb, ValueType[] schema) { - FrameBlock frame = new FrameBlock(schema); - Object[] row = new Object[mb.getNumColumns()]; - int dFreq = UtilFunctions.frequency(schema, ValueType.FP64); + private void convertToFrameBlockDense() throws InterruptedException, ExecutionException { + // the frequency of double type + final int dFreq = UtilFunctions.frequency(schema, ValueType.FP64); - if(schema.length == 1 && dFreq == 1 && mb.isAllocated()) { - // special case double schema and single columns which - // allows for a shallow copy since the physical representation - // of row-major matrix and column-major frame match exactly - frame.reset(); - frame.appendColumns(new double[][] {mb.getDenseBlockValues()}); + if(schema.length == 1) { + if(dFreq == 1) + convertToFrameDenseSingleColDouble(); + else + convertToFrameDenseSingleColOther(schema[0]); } - else if(dFreq == schema.length) { - // special case double schema (without cell-object creation, - // col pre-allocation, and cache-friendly row-column copy) - int m = mb.getNumRows(); - int n = mb.getNumColumns(); - double[][] c = new double[n][m]; - int blocksizeIJ = 32; // blocks of a/c+overhead in L1 cache - 
if(!mb.isEmptyBlock(false)) { - if(mb.getDenseBlock().isContiguous()) { - double[] a = mb.getDenseBlockValues(); - for(int bi = 0; bi < m; bi += blocksizeIJ) - for(int bj = 0; bj < n; bj += blocksizeIJ) { - int bimin = Math.min(bi + blocksizeIJ, m); - int bjmin = Math.min(bj + blocksizeIJ, n); - for(int i = bi, aix = bi * n; i < bimin; i++, aix += n) - for(int j = bj; j < bjmin; j++) - c[j][i] = a[aix + j]; - } - } - else { // large dense blocks - DenseBlock a = mb.getDenseBlock(); - for(int bi = 0; bi < m; bi += blocksizeIJ) - for(int bj = 0; bj < n; bj += blocksizeIJ) { - int bimin = Math.min(bi + blocksizeIJ, m); - int bjmin = Math.min(bj + blocksizeIJ, n); - for(int i = bi; i < bimin; i++) { - double[] avals = a.values(i); - int apos = a.pos(i); - for(int j = bj; j < bjmin; j++) - c[j][i] = avals[apos + j]; - } - } + else if(dFreq == schema.length) + convertToFrameDenseMultiColDouble(); + else + convertToFrameDenseMultiColGeneric(); + + } + + private void convertToFrameDenseSingleColDouble() { + frame.reset(); + frame.appendColumn(mb.getDenseBlockValues()); + } + + private void convertToFrameDenseSingleColOther(ValueType vt) { + Array d = ArrayFactory.create(mb.getDenseBlockValues()); + frame.reset(); + frame.appendColumn(d.changeType(vt)); + } + + private void convertToFrameDenseMultiColDouble() throws InterruptedException, ExecutionException { + double[][] c = (mb.getDenseBlock().isContiguous()) // + ? convertToFrameDenseMultiColContiguous() // + : convertToFrameDenseMultiColMultiBlock(); + + frame.reset(); + frame.appendColumns(c); + } + + private double[][] convertToFrameDenseMultiColContiguous() throws InterruptedException, ExecutionException { + return k == 1 // + ? 
convertToFrameDenseMultiColContiguousSingleThread() // + : convertToFrameDenseMultiColContiguousMultiThread(); + } + + private double[][] convertToFrameDenseMultiColContiguousSingleThread() { + + final double[][] c = new double[n][m]; + final double[] a = mb.getDenseBlockValues(); + for(int bi = 0; bi < m; bi += blocksizeIJ) { + for(int bj = 0; bj < n; bj += blocksizeIJ) { + int bimin = Math.min(bi + blocksizeIJ, m); + int bjmin = Math.min(bj + blocksizeIJ, n); + for(int i = bi, aix = bi * n; i < bimin; i++, aix += n) + for(int j = bj; j < bjmin; j++) + c[j][i] = a[aix + j]; + } + } + return c; + } + + private double[][] convertToFrameDenseMultiColContiguousMultiThread() + throws InterruptedException, ExecutionException { + + final double[][] c = new double[n][m]; + + final List tasks = new ArrayList<>(); + for(int bi = 0; bi < m; bi += blocksizeParallel) + for(int bj = 0; bj < n; bj += blocksizeParallel) + tasks.add(new CVB(bi, bj, c)); + + for(Future rt : pool.invokeAll(tasks)) + rt.get(); + return c; + + } + + protected class CVB implements Callable { + + private final int bi; + private final int bj; + private final double[][] c; + + protected CVB(int bi, int bj, double[][] c) { + this.bi = bi; + this.bj = bj; + this.c = c; + } + + @Override + public Object call() { + final double[] a = mb.getDenseBlockValues(); + int bimin = Math.min(bi + blocksizeParallel, m); + int bjmin = Math.min(bj + blocksizeParallel, n); + for(int i = bi, aix = bi * n; i < bimin; i++, aix += n) + for(int j = bj; j < bjmin; j++) + c[j][i] = a[aix + j]; + return null; + } + } + + private double[][] convertToFrameDenseMultiColMultiBlock() throws InterruptedException, ExecutionException { + return k == 1 // + ? 
convertToFrameDenseMultiColMultiBlockSingleThread() // + : convertToFrameDenseMultiColMultiBlockMultiThread();// + } + + private double[][] convertToFrameDenseMultiColMultiBlockSingleThread() { + + final double[][] c = new double[n][m]; + final DenseBlock a = mb.getDenseBlock(); + for(int bi = 0; bi < m; bi += blocksizeIJ) { + for(int bj = 0; bj < n; bj += blocksizeIJ) { + int bimin = Math.min(bi + blocksizeIJ, m); + int bjmin = Math.min(bj + blocksizeIJ, n); + for(int i = bi; i < bimin; i++) { + double[] avals = a.values(i); + int apos = a.pos(i); + for(int j = bj; j < bjmin; j++) + c[j][i] = avals[apos + j]; } } - frame.reset(); - frame.appendColumns(c); } - else { // general case - for(int i = 0; i < mb.getNumRows(); i++) { - for(int j = 0; j < mb.getNumColumns(); j++) - row[j] = UtilFunctions.doubleToObject(schema[j], mb.quickGetValue(i, j)); + return c; + } + + private double[][] convertToFrameDenseMultiColMultiBlockMultiThread() + throws InterruptedException, ExecutionException { + + final double[][] c = new double[n][m]; + + final List tasks = new ArrayList<>(); + for(int bi = 0; bi < m; bi += blocksizeParallel) + for(int bj = 0; bj < n; bj += blocksizeParallel) + tasks.add(new CVMB(bi, bj, c)); + + for(Future rt : pool.invokeAll(tasks)) + rt.get(); + return c; + } + + protected class CVMB implements Callable { + + private final int bi; + private final int bj; + private final double[][] c; + + protected CVMB(int bi, int bj, double[][] c) { + this.bi = bi; + this.bj = bj; + this.c = c; + } + + @Override + public Object call() { + final DenseBlock a = mb.getDenseBlock(); + int bimin = Math.min(bi + blocksizeParallel, m); + int bjmin = Math.min(bj + blocksizeParallel, n); + for(int i = bi; i < bimin; i++) { + double[] avals = a.values(i); + int apos = a.pos(i); + for(int j = bj; j < bjmin; j++) + c[j][i] = avals[apos + j]; + } + return null; + } + } + + private void convertToFrameDenseMultiColGeneric() throws InterruptedException, ExecutionException { + Array[] 
c = (mb.getDenseBlock().isContiguous()) // + ? convertToFrameDenseMultiColGenericContiguous() // + : convertToFrameDenseMultiColGenericMultiBlock(); + + frame.reset(); + for(Array col : c) + frame.appendColumn(col); + } + + private Array[] convertToFrameDenseMultiColGenericContiguous() throws InterruptedException, ExecutionException { + return k == 1 // + ? convertToFrameDenseMultiColGenericContiguousSingleThread() // + : convertToFrameDenseMultiColGenericContiguousMultiThread();// + } + + private Array[] convertToFrameDenseMultiColGenericContiguousSingleThread() { + + final Array[] c = new Array[n]; + + for(int i = 0; i < n; i++) + c[i] = ArrayFactory.allocate(schema[i], m); + + final double[] a = mb.getDenseBlockValues(); + for(int bi = 0; bi < m; bi += blocksizeIJ) { + for(int bj = 0; bj < n; bj += blocksizeIJ) { + int bimin = Math.min(bi + blocksizeIJ, m); + int bjmin = Math.min(bj + blocksizeIJ, n); + for(int i = bi, aix = bi * n; i < bimin; i++, aix += n) + for(int j = bj; j < bjmin; j++) + c[j].set(i, a[aix + j]); + } + } + return c; + + } + + private Array[] convertToFrameDenseMultiColGenericContiguousMultiThread() + throws InterruptedException, ExecutionException { + + final Array[] c = new Array[n]; + + for(int i = 0; i < n; i++) + c[i] = ArrayFactory.allocate(schema[i], m); + + final List tasks = new ArrayList<>(); + for(int bi = 0; bi < m; bi += blocksizeParallel) + for(int bj = 0; bj < n; bj += blocksizeParallel) + tasks.add(new CVTAB(bi, bj, c)); - frame.appendRow(row); + for(Future rt : pool.invokeAll(tasks)) + rt.get(); + return c; + } + + protected class CVTAB implements Callable { + + private final int bi; + private final int bj; + private final Array[] c; + + protected CVTAB(int bi, int bj, Array[] c) { + this.bi = bi; + this.bj = bj; + this.c = c; + } + + @Override + public Object call() { + final double[] a = mb.getDenseBlockValues(); + int bimin = Math.min(bi + blocksizeParallel, m); + int bjmin = Math.min(bj + blocksizeParallel, n); + for(int 
i = bi, aix = bi * n; i < bimin; i++, aix += n) + for(int j = bj; j < bjmin; j++) + c[j].set(i, a[aix + j]); + return null; + } + } + + private Array[] convertToFrameDenseMultiColGenericMultiBlock() { + final Array[] c = new Array[n]; + for(int i = 0; i < n; i++) + c[i] = ArrayFactory.allocate(schema[i], m); + + final DenseBlock a = mb.getDenseBlock(); + for(int bi = 0; bi < m; bi += blocksizeIJ) { + for(int bj = 0; bj < n; bj += blocksizeIJ) { + final int bimin = Math.min(bi + blocksizeIJ, m); + final int bjmin = Math.min(bj + blocksizeIJ, n); + for(int i = bi; i < bimin; i++) { + final double[] avals = a.values(i); + int apos = a.pos(i); + for(int j = bj; j < bjmin; j++) + c[j].set(i, avals[apos + j]); + } } } - return frame; + return c; } + } diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameLibApplySchema.java b/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameLibApplySchema.java index 57c79a49a9e..6fc4515b538 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameLibApplySchema.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameLibApplySchema.java @@ -42,6 +42,8 @@ public class FrameLibApplySchema { private final Array[] columnsIn; private final Array[] columnsOut; + private final int k; + /** * Method to create a new FrameBlock where the given schema is applied, k is parallelization degree. * @@ -51,24 +53,24 @@ public class FrameLibApplySchema { * @return A new FrameBlock allocated with new arrays. 
*/ public static FrameBlock applySchema(FrameBlock fb, ValueType[] schema, int k) { - return new FrameLibApplySchema(fb, schema).apply(k); + return new FrameLibApplySchema(fb, schema, k).apply(); } - private FrameLibApplySchema(FrameBlock fb, ValueType[] schema) { + private FrameLibApplySchema(FrameBlock fb, ValueType[] schema, int k) { this.fb = fb; this.schema = schema; + this.k = k; verifySize(); nCol = fb.getNumColumns(); columnsIn = fb.getColumns(); columnsOut = new Array[nCol]; - } - private FrameBlock apply(int k) { + private FrameBlock apply() { if(k <= 1 || nCol == 1) applySingleThread(); else - applyMultiThread(k); + applyMultiThread(); final String[] colNames = fb.getColumnNames(false); final ColumnMetadata[] meta = fb.getColumnMetadata(); @@ -80,7 +82,7 @@ private void applySingleThread() { columnsOut[i] = columnsIn[i].changeType(schema[i]); } - private void applyMultiThread(int k) { + private void applyMultiThread() { final ExecutorService pool = CommonThreadPool.get(k); try { diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameLibDetectSchema.java b/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameLibDetectSchema.java index 3219617f27a..6c21359ea6c 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameLibDetectSchema.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameLibDetectSchema.java @@ -36,40 +36,58 @@ public final class FrameLibDetectSchema { // private static final Log LOG = LogFactory.getLog(FrameLibDetectSchema.class.getName()); - private FrameLibDetectSchema() { - // private constructor + private final FrameBlock in; + // private final double sampleFraction; + private final int k; + + private FrameLibDetectSchema(FrameBlock in, double sampleFraction, int k) { + this.in = in; + // this.sampleFraction = sampleFraction; + this.k = k; + } + + public static FrameBlock detectSchema(FrameBlock in, int k) { + return new FrameLibDetectSchema(in, 1.0, k).apply(); } - public static 
FrameBlock detectSchema(FrameBlock in) { - return detectSchema(in, 1.0); + public static FrameBlock detectSchema(FrameBlock in, double sampleFraction, int k) { + return new FrameLibDetectSchema(in, sampleFraction, k).apply(); } - public static FrameBlock detectSchema(FrameBlock in, double sampleFraction) { - // LOG.error(Arrays.toString(in.getSchema())); + private FrameBlock apply() { final int cols = in.getNumColumns(); - ArrayList tasks = new ArrayList<>(cols); - for(int i = 0; i < cols; i++) - tasks.add(new DetectValueTypeTask(in.getColumn(i))); + final FrameBlock fb = new FrameBlock(UtilFunctions.nCopies(cols, ValueType.STRING)); + String[] schemaInfo = (k == 1) ? singleThreadApply() : parallelApply(); + fb.appendRow(schemaInfo); + return fb; + } - List> ret; + private String[] singleThreadApply() { + final int cols = in.getNumColumns(); + final String[] schemaInfo = new String[cols]; + for(int i = 0; i < cols; i++) { + schemaInfo[i] = in.getColumn(i).analyzeValueType().toString(); + } + return schemaInfo; + } - ExecutorService pool = CommonThreadPool.get(cols); + private String[] parallelApply() { + final ExecutorService pool = CommonThreadPool.get(k); try { - ret = pool.invokeAll(tasks); - final FrameBlock fb = new FrameBlock(UtilFunctions.nCopies(cols, ValueType.STRING)); + final int cols = in.getNumColumns(); + final ArrayList tasks = new ArrayList<>(cols); + for(int i = 0; i < cols; i++) + tasks.add(new DetectValueTypeTask(in.getColumn(i))); + final List> ret = pool.invokeAll(tasks); final String[] schemaInfo = new String[cols]; pool.shutdown(); for(int i = 0; i < cols; i++) schemaInfo[i] = ret.get(i).get().toString(); - - fb.appendRow(schemaInfo); - return fb; + return schemaInfo; } catch(ExecutionException | InterruptedException e) { - throw new DMLRuntimeException("Exception Interupted or Exception thrown in Detect Schema", e); - } - finally { pool.shutdown(); + throw new DMLRuntimeException("Exception interrupted or exception thrown in detectSchema", 
e); } } @@ -82,7 +100,7 @@ protected DetectValueTypeTask(Array obj) { @Override public ValueType call() { - return _obj.analyzeValueType(); + return _obj.analyzeValueType(); } } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java index f1ba7410505..1dc7b068b8b 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java @@ -21,6 +21,8 @@ import java.util.HashMap; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; import org.apache.sysds.common.Types.ExecType; import org.apache.sysds.common.Types.OpOp1; import org.apache.sysds.hops.FunctionOp; @@ -71,10 +73,10 @@ import org.apache.sysds.runtime.instructions.cp.UnaryCPInstruction; import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction; import org.apache.sysds.runtime.instructions.cpfile.MatrixIndexingCPFileInstruction; -import org.apache.sysds.runtime.util.UtilFunctions; -public class CPInstructionParser extends InstructionParser -{ +public class CPInstructionParser extends InstructionParser { + protected static final Log LOG = LogFactory.getLog(CPInstructionParser.class.getName()); + public static final HashMap String2CPInstructionType; static { String2CPInstructionType = new HashMap<>(); @@ -435,17 +437,12 @@ public static CPInstruction parseSingleInstruction ( CPType cptype, String str ) case Builtin: String[] parts = InstructionUtils.getInstructionPartsWithValueType(str); - if ( parts[0].equals("log") || parts[0].equals("log_nz") ) { - if ( parts.length == 3 || (parts.length == 5 && - UtilFunctions.isIntegerNumber(parts[3])) ) { - // B=log(A), y=log(x) + if(parts[0].equals("log") || parts[0].equals("log_nz")) { + if(InstructionUtils.isInteger(parts[3])) // B=log(A), y=log(x) + // We exploit the fact the number of threads is specified as 
an integer at parts 3. return UnaryCPInstruction.parseInstruction(str); - } else if ( parts.length == 4 || (parts.length == 5 && - UtilFunctions.isIntegerNumber(parts[4])) ) { - // B=log(A,10), y=log(x,10) - // num threads non-existing for scalar-scalar + else // B=log(A,10), y=log(x,10) return BinaryCPInstruction.parseInstruction(str); - } } throw new DMLRuntimeException("Invalid Builtin Instruction: " + str ); diff --git a/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java b/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java index e4e641e8f13..cd30034019b 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java @@ -23,6 +23,8 @@ import java.util.StringTokenizer; import org.apache.commons.lang.ArrayUtils; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; import org.apache.sysds.common.Types; import org.apache.sysds.common.Types.AggOp; import org.apache.sysds.common.Types.CorrectionLocationType; @@ -108,8 +110,9 @@ import org.apache.sysds.runtime.matrix.operators.UnarySketchOperator; -public class InstructionUtils -{ +public class InstructionUtils { + protected static final Log LOG = LogFactory.getLog(InstructionUtils.class.getName()); + //thread-local string builders for instruction concatenation (avoid allocation) private static ThreadLocal _strBuilders = new ThreadLocal() { @Override @@ -581,6 +584,12 @@ public static UnaryOperator parseUnaryOperator(String opcode) { new UnaryOperator(Builtin.getBuiltinFnObject(opcode)); } + public static UnaryOperator parseUnaryOperator(String opcode, int k) { + return opcode.equals("!") ? 
+ new UnaryOperator(Not.getNotFnObject(), k) : + new UnaryOperator(Builtin.getBuiltinFnObject(opcode), k); + } + public static MultiThreadedOperator parseBinaryOrBuiltinOperator(String opcode, CPOperand in1, CPOperand in2) { if( LibCommonsMath.isSupportedMatrixMatrixOperation(opcode) ) return null; @@ -1241,4 +1250,16 @@ private static String replaceOperand(String linst, CPOperand oldOperand, String Lop.OPERAND_DELIMITOR+oldOperand.getName()+Lop.DATATYPE_PREFIX, Lop.OPERAND_DELIMITOR+newOperandName+Lop.DATATYPE_PREFIX); } + + protected static boolean isInteger(String s){ + if(s.isEmpty()) return false; + for(int i = 0; i < s.length(); i++) { + if(i == 0 && s.charAt(i) == '-') { + if(s.length() == 1) return false; + else continue; + } + if(Character.digit(s.charAt(i),10) < 0) return false; + } + return true; + } } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryCPInstruction.java index 4c2ae8a2d5a..fddb8301a96 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryCPInstruction.java @@ -41,12 +41,15 @@ public static BinaryCPInstruction parseInstruction( String str ) { CPOperand in1 = new CPOperand("", ValueType.UNKNOWN, DataType.UNKNOWN); CPOperand in2 = new CPOperand("", ValueType.UNKNOWN, DataType.UNKNOWN); CPOperand out = new CPOperand("", ValueType.UNKNOWN, DataType.UNKNOWN); - String opcode = parseBinaryInstruction(str, in1, in2, out); + final String[] parts = parseBinaryInstruction(str, in1, in2, out); + final String opcode = parts[0]; if(!(in1.getDataType() == DataType.FRAME || in2.getDataType() == DataType.FRAME)) checkOutputDataType(in1, in2, out); MultiThreadedOperator operator = InstructionUtils.parseBinaryOrBuiltinOperator(opcode, in1, in2); + if(parts.length == 5 && operator != null) + operator.setNumThreads(Integer.parseInt(parts[4])); if 
(in1.getDataType() == DataType.SCALAR && in2.getDataType() == DataType.SCALAR) return new BinaryScalarScalarCPInstruction(operator, in1, in2, out, opcode, str); @@ -62,29 +65,17 @@ else if (in1.getDataType() == DataType.FRAME && in2.getDataType() == DataType.MA return new BinaryMatrixScalarCPInstruction(operator, in1, in2, out, opcode, str); } - protected static String parseBinaryInstruction(String instr, CPOperand in1, CPOperand in2, CPOperand out) { + private static String[] parseBinaryInstruction(String instr, CPOperand in1, CPOperand in2, CPOperand out) { String[] parts = InstructionUtils.getInstructionPartsWithValueType(instr); - InstructionUtils.checkNumFields ( parts, 3, 4 ); - String opcode = parts[0]; + InstructionUtils.checkNumFields ( parts, 3, 4, 5 ); in1.split(parts[1]); in2.split(parts[2]); out.split(parts[3]); - return opcode; - } - - protected static String parseBinaryInstruction(String instr, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out) { - String[] parts = InstructionUtils.getInstructionPartsWithValueType(instr); - InstructionUtils.checkNumFields ( parts, 4 ); - String opcode = parts[0]; - in1.split(parts[1]); - in2.split(parts[2]); - in3.split(parts[3]); - out.split(parts[4]); - - return opcode; + return parts; } + public Operator getOperator() { return _optr; } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryFrameFrameCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryFrameFrameCPInstruction.java index bd0ad427c10..fd749f84300 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryFrameFrameCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryFrameFrameCPInstruction.java @@ -62,7 +62,7 @@ else if(getOpcode().equals("applySchema")) { ValueType[] schema = new ValueType[inBlock2.getNumColumns()]; for(int i=0; i 0)?data.length/lrows:0; if(data.length != schemaLength && data.length > 1 && rowLength != schemaLength) throw new 
DMLRuntimeException( @@ -375,12 +376,9 @@ else if(data.length > 1 && rowLength == schemaLength) outF.appendRow(data1); } } - else { - String[] data1 = new String[lcols]; - Arrays.fill(data1, frame_data); - for(int i = 0; i < lrows; i++) - outF.appendRow(data1); - } + else + out = new FrameBlock(vt, frame_data, lrows); + } } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/PrefetchCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/PrefetchCPInstruction.java index fa0d1c0e831..96e4b7afe2d 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/PrefetchCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/PrefetchCPInstruction.java @@ -34,11 +34,12 @@ private PrefetchCPInstruction(Operator op, CPOperand in, CPOperand out, String o } public static PrefetchCPInstruction parseInstruction (String str) { - InstructionUtils.checkNumFields(str, 2); + InstructionUtils.checkNumFields(str, 3); String[] parts = InstructionUtils.getInstructionPartsWithValueType(str); String opcode = parts[0]; CPOperand in = new CPOperand(parts[1]); CPOperand out = new CPOperand(parts[2]); + // int k = Integer.parseInt(parts[3]); return new PrefetchCPInstruction(null, in, out, opcode, str); } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/UnaryCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/UnaryCPInstruction.java index 8f92c07ee8e..a55c6448783 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/UnaryCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/UnaryCPInstruction.java @@ -69,6 +69,18 @@ public static UnaryCPInstruction parseInstruction ( String str ) { else return new UnaryScalarCPInstruction(null, in, out, opcode, str); } + else if(parts.length==4){ + opcode = parseUnaryInstructionWithThreads(str, in, out); + int k = Integer.parseInt(parts[3]); + if(in.getDataType() == DataType.SCALAR) + return new 
UnaryScalarCPInstruction(InstructionUtils.parseUnaryOperator(opcode, k), in, out, opcode, str); + else if(in.getDataType() == DataType.MATRIX) + return new UnaryMatrixCPInstruction( + LibCommonsMath.isSupportedUnaryOperation(opcode) ? null : InstructionUtils.parseUnaryOperator(opcode, k), + in, out, opcode, str); + else if(in.getDataType() == DataType.FRAME) + return new UnaryFrameCPInstruction(InstructionUtils.parseUnaryOperator(opcode, k), in, out, opcode, str); + } else { //2+1, general case opcode = parseUnaryInstruction(str, in, out); @@ -83,6 +95,15 @@ else if(in.getDataType() == DataType.FRAME) return null; } + private static String parseUnaryInstructionWithThreads(String instr, CPOperand in, CPOperand out){ + InstructionUtils.checkNumFields(instr, 3); + String[] parts = InstructionUtils.getInstructionPartsWithValueType(instr); + String opcode = parts[0]; + out.split(parts[parts.length-2]); + in.split(parts[1]); + return opcode; + } + static String parseUnaryInstruction(String instr, CPOperand in, CPOperand out) { InstructionUtils.checkNumFields(instr, 2); return parse(instr, in, null, null, out); diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/UnaryFrameCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/UnaryFrameCPInstruction.java index 7ab354405dd..c572c0fb893 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/UnaryFrameCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/UnaryFrameCPInstruction.java @@ -22,10 +22,12 @@ import org.apache.sysds.runtime.DMLScriptException; import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; import org.apache.sysds.runtime.frame.data.FrameBlock; -import org.apache.sysds.runtime.matrix.operators.Operator; +import org.apache.sysds.runtime.matrix.operators.MultiThreadedOperator; public class UnaryFrameCPInstruction extends UnaryCPInstruction { - protected UnaryFrameCPInstruction(Operator op, CPOperand in, CPOperand out, 
String opcode, String instr) { + + protected UnaryFrameCPInstruction(MultiThreadedOperator op, CPOperand in, CPOperand out, String opcode, + String instr) { super(CPType.Unary, op, in, out, opcode, instr); } @@ -39,7 +41,7 @@ public void processInstruction(ExecutionContext ec) { } else if(getOpcode().equals("detectSchema")) { FrameBlock inBlock = ec.getFrameInput(input1.getName()); - FrameBlock retBlock = inBlock.detectSchema(); + FrameBlock retBlock = inBlock.detectSchema(((MultiThreadedOperator) _optr).getNumThreads()); ec.releaseFrameInput(input1.getName()); ec.setFrameOutput(output.getName(), retBlock); } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java index 26e99c89820..6026f6692aa 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java @@ -48,17 +48,18 @@ import org.apache.sysds.runtime.instructions.InstructionUtils; import org.apache.sysds.runtime.io.FileFormatProperties; import org.apache.sysds.runtime.io.FileFormatPropertiesCSV; -import org.apache.sysds.runtime.io.FileFormatPropertiesLIBSVM; import org.apache.sysds.runtime.io.FileFormatPropertiesHDF5; +import org.apache.sysds.runtime.io.FileFormatPropertiesLIBSVM; import org.apache.sysds.runtime.io.ListReader; import org.apache.sysds.runtime.io.ListWriter; +import org.apache.sysds.runtime.io.WriterHDF5; import org.apache.sysds.runtime.io.WriterMatrixMarket; import org.apache.sysds.runtime.io.WriterTextCSV; -import org.apache.sysds.runtime.io.WriterHDF5; import org.apache.sysds.runtime.lineage.LineageItem; import org.apache.sysds.runtime.lineage.LineageItemUtils; import org.apache.sysds.runtime.lineage.LineageTraceable; import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.matrix.operators.MultiThreadedOperator; import 
org.apache.sysds.runtime.meta.DataCharacteristics; import org.apache.sysds.runtime.meta.MatrixCharacteristics; import org.apache.sysds.runtime.meta.MetaData; @@ -104,7 +105,22 @@ public enum VariableOperationCode { CastAsBooleanVariable, Write, Read, - SetFileName, + SetFileName; + + public boolean isCast() { + switch(this) { + case CastAsScalarVariable: + case CastAsMatrixVariable: + case CastAsFrameVariable: + case CastAsListVariable: + case CastAsDoubleVariable: + case CastAsIntegerVariable: + case CastAsBooleanVariable: + return true; + default: + return false; + } + } } private static final IDSequence _uniqueVarID = new IDSequence(true); @@ -120,11 +136,14 @@ public enum VariableOperationCode { // Frame related members private final String _schema; + // parallelization degree for non IO related operations + private final int k; + // CSV and LIBSVM related members (used only in createvar instructions) private final FileFormatProperties _formatProperties; private VariableCPInstruction(VariableOperationCode op, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out, - MetaData meta, FileFormatProperties fprops, String schema, UpdateType utype, String sopcode, String istr) { + MetaData meta, FileFormatProperties fprops, String schema, UpdateType utype, String sopcode, String istr, int k) { super(CPType.Variable, sopcode, istr); opcode = op; inputs = new ArrayList<>(); @@ -138,11 +157,23 @@ private VariableCPInstruction(VariableOperationCode op, CPOperand in1, CPOperand _updateType = utype; _containsPreadPrefix = in1 != null && in1.getName() .contains(org.apache.sysds.lops.Data.PREAD_PREFIX); + this.k = k; } + private VariableCPInstruction(VariableOperationCode op, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out, - String sopcode, String istr) { - this(op, in1, in2, in3, out, null, null, null, null, sopcode, istr); + MetaData meta, FileFormatProperties fprops, String schema, UpdateType utype, String sopcode, String istr) { + this(op 
,in1,in2,in3,out,meta, fprops, schema, utype, sopcode, istr, 1); + } + + private VariableCPInstruction(VariableOperationCode op, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out, + String sopcode, String istr) { + this(op, in1, in2, in3, out, null, null, null, null, sopcode, istr, 1); + } + + private VariableCPInstruction(VariableOperationCode op, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out, + String sopcode, String istr, int k) { + this(op, in1, in2, in3, out, null, null, null, null, sopcode, istr, k); } // This version of the constructor is used only in case of CreateVariable @@ -294,6 +325,8 @@ public CPOperand getOutput(){ } private static int getArity(VariableOperationCode op) { + if(op.isCast()) + return 3; switch(op) { case Write: case SetFileName: @@ -326,11 +359,17 @@ else if ( voc == VariableOperationCode.Write ) { throw new DMLRuntimeException("Invalid number of operands in write instruction: " + str); } else { - if( voc != VariableOperationCode.RemoveVariable ) - InstructionUtils.checkNumFields ( parts, getArity(voc) ); // no output + try{ + if( voc != VariableOperationCode.RemoveVariable ) + InstructionUtils.checkNumFields ( parts, getArity(voc) ); // no output + } + catch(Exception e){ + throw new DMLRuntimeException("Invalid number of fields with operation code: " + voc, e); + } } CPOperand in1=null, in2=null, in3=null, in4=null, out=null; + int k = 1; switch (voc) { @@ -515,6 +554,7 @@ else if(fmt.equalsIgnoreCase("hdf5")) { case CastAsBooleanVariable: in1 = new CPOperand(parts[1]); // first operand is a variable name => string value type out = new CPOperand(parts[2]); // output variable name + k = Integer.parseInt(parts[3]); // thread count break; case Write: @@ -562,7 +602,7 @@ else if(in3.getName().equalsIgnoreCase("hdf5") ){ break; } - return new VariableCPInstruction(getVariableOperationCode(opcode), in1, in2, in3, out, opcode, str); + return new VariableCPInstruction(getVariableOperationCode(opcode), in1, in2, in3, 
out, opcode, str, k); } @Override @@ -928,7 +968,7 @@ private void processCastAsFrameVariableInstruction(ExecutionContext ec){ } else if(getInput1().getDataType()==DataType.MATRIX) { //DataType.FRAME MatrixBlock min = ec.getMatrixInput(getInput1().getName()); - out = DataConverter.convertToFrameBlock(min); + out = DataConverter.convertToFrameBlock(min, k); ec.releaseMatrixInput(getInput1().getName()); ec.setFrameOutput(output.getName(), out); } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/UnaryFrameSPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/UnaryFrameSPInstruction.java index 1b9224fefd9..75a226d1e6a 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/spark/UnaryFrameSPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/UnaryFrameSPInstruction.java @@ -78,7 +78,7 @@ private static class DetectSchemaUsingRows implements PairFunction call(Tuple2 arg0) throws Exception { - FrameBlock resultBlock = new FrameBlock(arg0._2.detectSchema()); + FrameBlock resultBlock = new FrameBlock(arg0._2.detectSchema(1)); return new Tuple2<>(1L, resultBlock); } } diff --git a/src/main/java/org/apache/sysds/runtime/io/FrameWriterTextCSV.java b/src/main/java/org/apache/sysds/runtime/io/FrameWriterTextCSV.java index 92abe0932d7..418da607654 100644 --- a/src/main/java/org/apache/sysds/runtime/io/FrameWriterTextCSV.java +++ b/src/main/java/org/apache/sysds/runtime/io/FrameWriterTextCSV.java @@ -24,6 +24,8 @@ import java.io.OutputStreamWriter; import java.util.Iterator; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; import org.apache.hadoop.mapred.JobConf; @@ -38,8 +40,9 @@ * Single-threaded frame text csv writer. 
* */ -public class FrameWriterTextCSV extends FrameWriter -{ +public class FrameWriterTextCSV extends FrameWriter{ + protected static final Log LOG = LogFactory.getLog(FrameWriterTextCSV.class.getName()); + //blocksize for string concatenation in order to prevent write OOM //(can be set to very large value to disable blocking) public static final int BLOCKSIZE_J = 32; //32 cells (typically ~512B, should be less than write buffer of 1KB) diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java index f09be4b4bf9..26513279578 100644 --- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java +++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java @@ -1317,8 +1317,6 @@ private static void c2r(MatrixBlock in, int k){ int a = m/c; int b = n/c; - // LOG.error("c2r"); - double[] tmp = memPool.get(); if(tmp == null) { memPool.set(new double[Math.max(m,n)]); @@ -1633,7 +1631,6 @@ private static void sj_inv(double[] tmp, double[] A, int j, int a, int n, int m) int sji = ((j + i - (off/a)) % m ); tmp[sji] = A[i + j]; } - // LOG.error(Arrays.toString(sjiList)); for (int i = j, off = 0; i< m * n; i+= n, off++){ A[i] = tmp[off]; } diff --git a/src/main/java/org/apache/sysds/runtime/matrix/operators/UnaryOperator.java b/src/main/java/org/apache/sysds/runtime/matrix/operators/UnaryOperator.java index d1d2abc56e5..8f5e2ff6d69 100644 --- a/src/main/java/org/apache/sysds/runtime/matrix/operators/UnaryOperator.java +++ b/src/main/java/org/apache/sysds/runtime/matrix/operators/UnaryOperator.java @@ -34,6 +34,10 @@ public class UnaryOperator extends MultiThreadedOperator public UnaryOperator(ValueFunction p) { this(p, 1, false); //default single-threaded } + + public UnaryOperator(ValueFunction p, int k){ + this(p, k, false); // multithreaded + } public UnaryOperator(ValueFunction p, int numThreads, boolean inPlace) { super(p instanceof Builtin && diff 
--git a/src/main/java/org/apache/sysds/runtime/util/DataConverter.java b/src/main/java/org/apache/sysds/runtime/util/DataConverter.java index c3426fce17f..2817569dad2 100644 --- a/src/main/java/org/apache/sysds/runtime/util/DataConverter.java +++ b/src/main/java/org/apache/sysds/runtime/util/DataConverter.java @@ -754,7 +754,29 @@ public static FrameBlock convertToFrameBlock(String[][] data, ValueType[] schema * @return frame block of type double */ public static FrameBlock convertToFrameBlock(MatrixBlock mb) { - return convertToFrameBlock(mb, ValueType.FP64); + return convertToFrameBlock(mb, ValueType.FP64, 1); + } + + /** + * Converts a matrix block into a frame block of value type double. + * + * @param mb matrix block + * @param k parallelization degree + * @return frame block of type double + */ + public static FrameBlock convertToFrameBlock(MatrixBlock mb, int k) { + return convertToFrameBlock(mb, ValueType.FP64, k); + } + + /** + * Converts a matrix block into a frame block of value type given. 
+ * + * @param mb matrix block + * @param vt value type target + * @return frame block of the given type + */ + public static FrameBlock convertToFrameBlock(MatrixBlock mb, ValueType vt) { + return FrameFromMatrixBlock.convertToFrameBlock(mb, vt, 1); } /** @@ -762,14 +784,34 @@ public static FrameBlock convertToFrameBlock(MatrixBlock mb) { * * @param mb matrix block * @param vt value type - * @return frame block + * @param k parallelization degree + * @return a frame block of the given value type */ - public static FrameBlock convertToFrameBlock(MatrixBlock mb, ValueType vt) { - return FrameFromMatrixBlock.convertToFrameBlock(mb, vt); + public static FrameBlock convertToFrameBlock(MatrixBlock mb, ValueType vt, int k) { + return FrameFromMatrixBlock.convertToFrameBlock(mb, vt, k); } - public static FrameBlock convertToFrameBlock(MatrixBlock mb, ValueType[] schema){ - return FrameFromMatrixBlock.convertToFrameBlock(mb, schema); + /** + * Converts a matrix block into a frame block with the given schema. + * + * @param mb matrix block + * @param schema schema + * @return a frame block with the given schema + */ + public static FrameBlock convertToFrameBlock(MatrixBlock mb, ValueType[] schema) { + return FrameFromMatrixBlock.convertToFrameBlock(mb, schema, 1); + } + + /** + * Converts a matrix block into a frame block with the given schema. + * + * @param mb matrix block + * @param schema schema + * @param k parallelization degree + * @return a frame block with the given schema + */ + public static FrameBlock convertToFrameBlock(MatrixBlock mb, ValueType[] schema, int k) { + return FrameFromMatrixBlock.convertToFrameBlock(mb, schema, k); } public static TensorBlock convertToTensorBlock(MatrixBlock mb, ValueType vt, boolean toBasicTensor) { diff --git a/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java b/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java index 16a86ae5285..dd080cdfac7 100644 ---
a/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java +++ b/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java @@ -515,6 +515,90 @@ public static double objectToDouble(ValueType vt, Object in) { } } + public static float objectToFloat(ValueType vt, Object in) { + if(in == null) + return Float.NaN; + switch(vt) { + case FP64: + return ((Double) in).floatValue(); + case FP32: + return (Float) in; + case INT64: + return (Long) in; + case INT32: + return (Integer) in; + case BOOLEAN: + return ((Boolean) in) ? 1 : 0; + case STRING: + return !((String) in).isEmpty() ? Float.parseFloat((String) in) : 0; + default: + throw new DMLRuntimeException("Unsupported value type: " + vt); + } + } + + public static int objectToInteger(ValueType vt, Object in) { + if(in == null) + return 0; + switch(vt) { + case FP64: + return ((Double) in).intValue(); + case FP32: + return ((Float) in).intValue(); + case INT64: + return ((Long) in).intValue(); + case INT32: + return (Integer) in; + case BOOLEAN: + return ((Boolean) in) ? 1 : 0; + case STRING: + return !((String) in).isEmpty() ? Integer.parseInt((String) in) : 0; + default: + throw new DMLRuntimeException("Unsupported value type: " + vt); + } + } + + public static long objectToLong(ValueType vt, Object in) { + if(in == null) + return 0; + switch(vt) { + case FP64: + return ((Double) in).longValue(); + case FP32: + return ((Float) in).longValue(); + case INT64: + return (Long) in; + case INT32: + return (Integer) in; + case BOOLEAN: + return ((Boolean) in) ? 1 : 0; + case STRING: + return !((String) in).isEmpty() ? 
Long.parseLong((String) in) : 0; + default: + throw new DMLRuntimeException("Unsupported value type: " + vt); + } + } + + public static boolean objectToBoolean(ValueType vt, Object in) { + if(in == null) + return false; + switch(vt) { + case FP64: + return ((Double) in) == 1.0; + case FP32: + return ((Float) in) == 1.0; + case INT64: + return (Long) in == 1; + case INT32: + return (Integer) in == 1; + case BOOLEAN: + return ((Boolean) in); + case STRING: + return Boolean.parseBoolean((String) in); + default: + throw new DMLRuntimeException("Unsupported value type: " + vt); + } + } + public static String objectToString( Object in ) { return (in !=null) ? in.toString() : null; } diff --git a/src/main/python/systemds/utils/converters.py b/src/main/python/systemds/utils/converters.py index b86ac6c5413..fa1da33bf8d 100644 --- a/src/main/python/systemds/utils/converters.py +++ b/src/main/python/systemds/utils/converters.py @@ -155,18 +155,18 @@ def frame_block_to_pandas(sds: "SystemDSContext", fb: JavaObject): else: ret.append(None) elif d_type == "INT32": - byteArray = fb.getColumn(c_index).getAsByteArray(num_rows) + byteArray = fb.getColumn(c_index).getAsByteArray() ret = np.frombuffer(byteArray, dtype=np.int32) elif d_type == "INT64": - byteArray = fb.getColumn(c_index).getAsByteArray(num_rows) + byteArray = fb.getColumn(c_index).getAsByteArray() ret = np.frombuffer(byteArray, dtype=np.int64) elif d_type == "FP64": - byteArray = fb.getColumn(c_index).getAsByteArray(num_rows) + byteArray = fb.getColumn(c_index).getAsByteArray() ret = np.frombuffer(byteArray, dtype=np.float64) elif d_type == "BOOLEAN" or d_type == "BITSET": # TODO maybe it is more efficient to bit pack the booleans. 
# https://stackoverflow.com/questions/5602155/numpy-boolean-array-with-1-bit-entries - byteArray = fb.getColumn(c_index).getAsByteArray(num_rows) + byteArray = fb.getColumn(c_index).getAsByteArray() ret = np.frombuffer(byteArray, dtype=np.dtype("?")) else: raise NotImplementedError( diff --git a/src/test/java/org/apache/sysds/test/TestUtils.java b/src/test/java/org/apache/sysds/test/TestUtils.java index 8565a1e2e50..88211197d1d 100644 --- a/src/test/java/org/apache/sysds/test/TestUtils.java +++ b/src/test/java/org/apache/sysds/test/TestUtils.java @@ -810,16 +810,18 @@ public static void compareMatrices(double[][] expectedMatrix, double[][] actualM public static void compareFrames(String[][] expectedFrame, String[][] actualFrame, int rows, int cols ) { int countErrors = 0; - for (int i = 0; i < rows; i++) { - for (int j = 0; j < cols; j++) { + StringBuilder sb = new StringBuilder(); + for (int i = 0; i < rows&& countErrors < 10; i++) { + for (int j = 0; j < cols && countErrors < 10; j++) { if( !( (expectedFrame[i][j]==null && actualFrame[i][j]==null) || expectedFrame[i][j].equals(actualFrame[i][j]) || (expectedFrame[i][j]+".0").equals(actualFrame[i][j])) ) { - System.out.println("Expected:" + expectedFrame[i][j] +" vs actual: "+actualFrame[i][j]+" at "+i+" "+j); + sb.append("Expected:" + expectedFrame[i][j] +" vs actual: "+actualFrame[i][j]+" at "+i+" "+j + "\n"); countErrors++; } } } - assertTrue("" + countErrors + " values are not in equal", countErrors == 0); + sb.append("at least " + countErrors + " values are not equal" ); + assertTrue(sb.toString(), countErrors == 0); } public static void compareFrames(FrameBlock expected, FrameBlock actual, boolean checkMeta) { @@ -828,37 +830,59 @@ public static void compareFrames(FrameBlock expected, FrameBlock actual, boolean int rows = expected.getNumRows(); int cols = expected.getNumColumns(); - if(checkMeta) { - ColumnMetadata[] expectedMeta = expected.getColumnMetadata(); - ColumnMetadata[] actualMeta = 
actual.getColumnMetadata(); - - if((expectedMeta == null && actualMeta != null) || (expectedMeta != null && actualMeta == null)) - fail("wrongly allocated metadata"); - else if(expectedMeta != null && actualMeta != null) { - assertEquals("MetaData not correct size", expectedMeta.length, cols); - assertEquals("MetaData not correct size", actualMeta.length, cols); - for(int i = 0; i < cols; i++) - if(!expectedMeta[i].equals(actualMeta[i])) - fail("Meta data not equivalent: " + expectedMeta[i] + " vs " + actualMeta[i]); - } - - String[] expectedColNames = expected.getColumnNames(false); - String[] actualColNames = expected.getColumnNames(false); - if((expectedColNames == null && actualColNames != null) || - (expectedColNames != null && actualColNames == null)) - fail("wrongly allocated metadata"); - else if(expectedColNames != null && actualColNames != null) { - assertEquals("Column names not correct size", expectedColNames.length, cols); - assertEquals("Column names not correct size", actualColNames.length, cols); - for(int i = 0; i < cols; i++) { - assertEquals("Column names not equivalent", expectedColNames[i], actualColNames[i]); - } + if(checkMeta) + checkMetadata(expected, actual); + + for(int i = 0; i < rows; i++) { + for(int j = 0; j < cols; j++) { + assertEquals("Values not equivalent at: " + i + ", " + j, expected.get(i, j), actual.get(i, j)); + } + } + } + + private static void checkMetadata(FrameBlock expected, FrameBlock actual) { + int cols = expected.getNumColumns(); + ColumnMetadata[] expectedMeta = expected.getColumnMetadata(); + ColumnMetadata[] actualMeta = actual.getColumnMetadata(); + + if((expectedMeta == null && actualMeta != null) || (expectedMeta != null && actualMeta == null)) + fail("wrongly allocated metadata"); + else if(expectedMeta != null && actualMeta != null) { + assertEquals("MetaData not correct size", expectedMeta.length, cols); + assertEquals("MetaData not correct size", actualMeta.length, cols); + for(int i = 0; i < cols; i++) + 
if(!expectedMeta[i].equals(actualMeta[i])) + fail("Meta data not equivalent: " + expectedMeta[i] + " vs " + actualMeta[i]); + } + + String[] expectedColNames = expected.getColumnNames(false); + String[] actualColNames = expected.getColumnNames(false); + if((expectedColNames == null && actualColNames != null) || (expectedColNames != null && actualColNames == null)) + fail("wrongly allocated metadata"); + else if(expectedColNames != null && actualColNames != null) { + assertEquals("Column names not correct size", expectedColNames.length, cols); + assertEquals("Column names not correct size", actualColNames.length, cols); + for(int i = 0; i < cols; i++) { + assertEquals("Column names not equivalent", expectedColNames[i], actualColNames[i]); } } + } + + public static void compareFramesAsString(FrameBlock expected, FrameBlock actual, boolean checkMeta) { + assertEquals("Number of columns and rows are not equivalent", expected.getNumRows(), actual.getNumRows()); + assertEquals("Number of columns and rows are not equivalent", expected.getNumColumns(), actual.getNumColumns()); + + int rows = expected.getNumRows(); + int cols = expected.getNumColumns(); + + if(checkMeta) + checkMetadata(expected, actual); + for(int i = 0; i < rows; i++) { for(int j = 0; j < cols; j++) { - assertEquals("Values not equivalent at: " + i + ", " + j, expected.get(i, j), actual.get(i, j)); + assertEquals("Values not equivalent at: " + i + ", " + j, expected.get(i, j).toString(), + actual.get(i, j).toString()); } } } diff --git a/src/test/java/org/apache/sysds/test/component/compress/AbstractCompressedUnaryTests.java b/src/test/java/org/apache/sysds/test/component/compress/AbstractCompressedUnaryTests.java index 4a8ced11589..f579ae53324 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/AbstractCompressedUnaryTests.java +++ b/src/test/java/org/apache/sysds/test/component/compress/AbstractCompressedUnaryTests.java @@ -266,9 +266,6 @@ public void testUnaryOperators(AggType aggType, 
AggregateUnaryOperator auop, boo // matrix-vector compressed MatrixBlock ret2 = cmb.aggregateUnaryOperations(auop, new MatrixBlock(), Math.max(rows, cols), null, inCP); - // LOG.error(cmb); - // LOG.error(ret1 + " " + ret2); - final int ruc = ret1.getNumRows(); final int cuc = ret1.getNumColumns(); final int rc = ret2.getNumRows(); diff --git a/src/test/java/org/apache/sysds/test/component/compress/colgroup/JolEstimateRLETest.java b/src/test/java/org/apache/sysds/test/component/compress/colgroup/JolEstimateRLETest.java index 786b0483a7d..411f634b9a8 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/colgroup/JolEstimateRLETest.java +++ b/src/test/java/org/apache/sysds/test/component/compress/colgroup/JolEstimateRLETest.java @@ -113,9 +113,6 @@ public static Collection data() { mb = genRLE(100, 1, 10, 1); tests.add(new Object[] {mb}); - // mb = genRLE(1, 100, 10, 1); - // LOG.error(mb); - // tests.add(new Object[] {mb}); mb = genRLE(1000, 1, 10, 1312); tests.add(new Object[] {mb}); mb = genRLE(10000, 1, 10, 14512); diff --git a/src/test/java/org/apache/sysds/test/component/frame/FrameFromMatrixBlockTest.java b/src/test/java/org/apache/sysds/test/component/frame/FrameFromMatrixBlockTest.java index 985a0847660..76c71973222 100644 --- a/src/test/java/org/apache/sysds/test/component/frame/FrameFromMatrixBlockTest.java +++ b/src/test/java/org/apache/sysds/test/component/frame/FrameFromMatrixBlockTest.java @@ -18,9 +18,14 @@ */ package org.apache.sysds.test.component.frame; +import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.runtime.data.DenseBlockFP64; import org.apache.sysds.runtime.frame.data.FrameBlock; import org.apache.sysds.runtime.frame.data.lib.FrameFromMatrixBlock; import 
org.apache.sysds.runtime.matrix.data.MatrixBlock; @@ -29,30 +34,171 @@ public class FrameFromMatrixBlockTest { + protected static final Log LOG = LogFactory.getLog(FrameFromMatrixBlockTest.class.getName()); + @Test - public void booleanColumn() { - MatrixBlock mb = new MatrixBlock(10, 3, 1.0); - FrameBlock fb = FrameFromMatrixBlock.convertToFrameBlock(mb, ValueType.BOOLEAN); - for(int i = 0; i < mb.getNumColumns(); i++) { - assertTrue(fb.getColumn(i).getValueType() == ValueType.BOOLEAN); + public void toBoolean() { + try { + MatrixBlock mb = new MatrixBlock(10, 3, 1.0); + FrameBlock fb = FrameFromMatrixBlock.convertToFrameBlock(mb, ValueType.BOOLEAN, 1); + verifyEquivalence(mb, fb, ValueType.BOOLEAN); + } + catch(Exception e) { + e.printStackTrace(); + fail("failed"); } } @Test - public void booleanColumnEmpty() { + public void toBooleanEmpty() { MatrixBlock mb = new MatrixBlock(10, 3, 0.0); - FrameBlock fb = FrameFromMatrixBlock.convertToFrameBlock(mb, ValueType.BOOLEAN); - for(int i = 0; i < mb.getNumColumns(); i++) { - assertTrue(fb.getColumn(i).getValueType() == ValueType.BOOLEAN); - } + FrameBlock fb = FrameFromMatrixBlock.convertToFrameBlock(mb, ValueType.BOOLEAN, 1); + verifyEquivalence(mb, fb, ValueType.BOOLEAN); } @Test - public void booleanColumnSparse() { + public void toBooleanSparse() { MatrixBlock mb = TestUtils.generateTestMatrixBlock(100, 100, 1, 1, 0.2, 213); - FrameBlock fb = FrameFromMatrixBlock.convertToFrameBlock(mb, ValueType.BOOLEAN); - for(int i = 0; i < mb.getNumColumns(); i++) { - assertTrue(fb.getColumn(i).getValueType() == ValueType.BOOLEAN); + FrameBlock fb = FrameFromMatrixBlock.convertToFrameBlock(mb, ValueType.BOOLEAN, 1); + verifyEquivalence(mb, fb, ValueType.BOOLEAN); + } + + @Test + public void toBooleanVerySparse() { + MatrixBlock mb = TestUtils.generateTestMatrixBlock(100, 100, 1, 1, 0.001, 213); + FrameBlock fb = FrameFromMatrixBlock.convertToFrameBlock(mb, ValueType.BOOLEAN, 1); + verifyEquivalence(mb, fb, ValueType.BOOLEAN); 
+ } + + @Test + public void singleColShortcut() { + MatrixBlock mb = TestUtils.generateTestMatrixBlock(100, 1, 0, 1, 0.2, 213); + FrameBlock fb = FrameFromMatrixBlock.convertToFrameBlock(mb, ValueType.FP64, 1); + verifyEquivalence(mb, fb, ValueType.FP64); + } + + @Test + public void singleColShortcutToBoolean() { + MatrixBlock mb = TestUtils.generateTestMatrixBlock(100, 1, 1, 1, 0.2, 213); + FrameBlock fb = FrameFromMatrixBlock.convertToFrameBlock(mb, ValueType.BOOLEAN, 1); + verifyEquivalence(mb, fb, ValueType.BOOLEAN); + } + + @Test + public void toFloatDense() { + MatrixBlock mb = TestUtils.generateTestMatrixBlock(100, 10, 0, 1, 1.0, 213); + FrameBlock fb = FrameFromMatrixBlock.convertToFrameBlock(mb, ValueType.FP64, 1); + verifyEquivalence(mb, fb, ValueType.FP64); + } + + @Test + public void toFloatDenseParallel() { + MatrixBlock mb = TestUtils.generateTestMatrixBlock(1000, 100, 0, 1, 1.0, 213); + FrameBlock fb = FrameFromMatrixBlock.convertToFrameBlock(mb, ValueType.FP64, 4); + verifyEquivalence(mb, fb, ValueType.FP64); + } + + @Test + public void toBooleanDenseParallel() { + MatrixBlock mb = TestUtils.generateTestMatrixBlock(1000, 100, 1, 1, 0.5, 213); + FrameBlock fb = FrameFromMatrixBlock.convertToFrameBlock(mb, ValueType.BOOLEAN, 4); + verifyEquivalence(mb, fb, ValueType.BOOLEAN); + } + + @Test + public void toFloatDenseMultiBlock() { + MatrixBlock mb = mock(TestUtils.generateTestMatrixBlock(100, 10, 0, 1, 1.0, 213)); + FrameBlock fb = FrameFromMatrixBlock.convertToFrameBlock(mb, ValueType.FP64, 1); + verifyEquivalence(mb, fb, ValueType.FP64); + } + + @Test + public void toFloatDenseMultiBlockParallel() { + MatrixBlock mb = mock(TestUtils.generateTestMatrixBlock(100, 10, 0, 1, 1.0, 213)); + FrameBlock fb = FrameFromMatrixBlock.convertToFrameBlock(mb, ValueType.FP64, 4); + verifyEquivalence(mb, fb, ValueType.FP64); + } + + @Test + public void toBooleanDenseMultiBlock() { + MatrixBlock mb = mock(TestUtils.generateTestMatrixBlock(100, 10, 1, 1, 0.7, 213)); + 
FrameBlock fb = FrameFromMatrixBlock.convertToFrameBlock(mb, ValueType.BOOLEAN, 1); + verifyEquivalence(mb, fb, ValueType.BOOLEAN); + } + + @Test + public void shortcutEmpty() { + MatrixBlock mb = TestUtils.generateTestMatrixBlock(100, 10, 0, 0, 1.0, 213); + assertTrue(mb.isEmpty()); + FrameBlock fb = FrameFromMatrixBlock.convertToFrameBlock(mb, ValueType.FP64, 1); + verifyEquivalence(mb, fb, ValueType.FP64); + } + + @Test + public void timeChange() { + // MatrixBlock mb = TestUtils.generateTestMatrixBlock(64000, 2000, 1, 1, 0.5, 2340); + + // for(int i = 0; i < 10; i++) { + // Timing time = new Timing(true); + // FrameFromMatrixBlock.convertToFrameBlock(mb, ValueType.BOOLEAN, 1); + // LOG.error(time.stop()); + // } + + // for(int i = 0; i < 10; i++) { + // Timing time = new Timing(true); + // FrameFromMatrixBlock.convertToFrameBlock(mb, ValueType.BOOLEAN, 16); + // LOG.error(time.stop()); + // } + + // for(int i = 0; i < 10; i ++){ + // Timing time = new Timing(true); + // FrameFromMatrixBlock.convertToFrameBlock(mb, ValueType.FP64, 1); + // LOG.error(time.stop()); + // } + + // for(int i = 0; i < 10; i++) { + // Timing time = new Timing(true); + // FrameFromMatrixBlock.convertToFrameBlock(mb, ValueType.FP64, 16); + // LOG.error(time.stop()); + // } + } + + private void verifyEquivalence(MatrixBlock mb, FrameBlock fb, ValueType vt) { + int nRow = mb.getNumRows(); + int nCol = mb.getNumColumns(); + assertEquals(mb.getNumColumns(), fb.getSchema().length); + for(int i = 0; i < nCol; i++) + assertTrue(fb.getColumn(i).getValueType() == vt); + + for(int i = 0; i < nRow; i++) + for(int j = 0; j < nCol; j++) + assertEquals(mb.getValue(i, j), fb.getDouble(i, j), 0.0000001); + + } + + private MatrixBlock mock(MatrixBlock m) { + MatrixBlock ret = new MatrixBlock(m.getNumRows(), m.getNumColumns(), + new DenseBlockFP64Mock(new int[] {m.getNumRows(), m.getNumColumns()}, m.getDenseBlockValues())); + ret.setNonZeros(m.getNumRows() * m.getNumColumns()); + return ret; + + } + + 
private class DenseBlockFP64Mock extends DenseBlockFP64 { + private static final long serialVersionUID = -3601232958390554672L; + + public DenseBlockFP64Mock(int[] dims, double[] data) { + super(dims, data); + } + + @Override + public boolean isContiguous() { + return false; + } + + @Override + public int numBlocks() { + return 1; } } diff --git a/src/test/java/org/apache/sysds/test/component/frame/array/CustomArrayTests.java b/src/test/java/org/apache/sysds/test/component/frame/array/CustomArrayTests.java index ec9a8b61471..d5276d99752 100644 --- a/src/test/java/org/apache/sysds/test/component/frame/array/CustomArrayTests.java +++ b/src/test/java/org/apache/sysds/test/component/frame/array/CustomArrayTests.java @@ -19,6 +19,7 @@ package org.apache.sysds.test.component.frame.array; +import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; @@ -466,6 +467,26 @@ public void setRangeBitSet_VectorizedKernel_3() { } } + @Test + public void LongToBits_0(){ + assertEquals(BitSetArray.longToBits(0), "0000000000000000000000000000000000000000000000000000000000000000"); + } + + @Test + public void LongToBits_2(){ + assertEquals(BitSetArray.longToBits(2), "0000000000000000000000000000000000000000000000000000000000000010"); + } + + @Test + public void LongToBits_5(){ + assertEquals(BitSetArray.longToBits(5), "0000000000000000000000000000000000000000000000000000000000000101"); + } + + @Test + public void LongToBits_minusOne(){ + assertEquals(BitSetArray.longToBits(-1), "1111111111111111111111111111111111111111111111111111111111111111"); + } + public static BitSetArray createTrueBitArray(int length) { BitSet init = new BitSet(); diff --git a/src/test/java/org/apache/sysds/test/component/frame/array/FrameArrayConstantTests.java b/src/test/java/org/apache/sysds/test/component/frame/array/FrameArrayConstantTests.java new file mode 100644 index 00000000000..36d37605d15 --- 
/dev/null +++ b/src/test/java/org/apache/sysds/test/component/frame/array/FrameArrayConstantTests.java @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.component.frame.array; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; + +import java.util.ArrayList; +import java.util.Collection; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.runtime.frame.data.columns.Array; +import org.apache.sysds.runtime.frame.data.columns.ArrayFactory; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameters; + +@RunWith(value = Parameterized.class) +public class FrameArrayConstantTests { + protected static final Log LOG = LogFactory.getLog(FrameArrayConstantTests.class.getName()); + + final public ValueType t; + final public int nRow; + + @Parameters + public static Collection data() { + ArrayList tests = new ArrayList<>(); + try { + for(ValueType t : ValueType.values()) { + if(t == ValueType.UNKNOWN) + continue; + tests.add(new Object[] {t, 
10}); + tests.add(new Object[] {t, 100}); + tests.add(new Object[] {t, 1}); + } + } + catch(Exception e) { + e.printStackTrace(); + fail("failed constructing tests"); + } + + return tests; + } + + public FrameArrayConstantTests(ValueType t, int nRow) { + this.t = t; + this.nRow = nRow; + } + + @Test + public void testConstruction() { + try { + Array a = ArrayFactory.allocate(t, nRow, "0"); + for(int i = 0; i < nRow; i++) + assertEquals(a.getAsDouble(i), 0.0, 0.0000000001); + } + catch(Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void testConstruction_default() { + try { + Array a = ArrayFactory.allocate(t, nRow); + if(t != ValueType.STRING) + for(int i = 0; i < nRow; i++) + assertEquals(a.getAsDouble(i), 0.0, 0.0000000001); + } + catch(Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void testConstruction_1() { + try { + Array a = ArrayFactory.allocate(t, nRow, "1.0"); + for(int i = 0; i < nRow; i++) + assertEquals(a.getAsDouble(i), 1.0, 0.0000000001); + } + catch(Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void testConstruction_null() { + try { + Array a = ArrayFactory.allocate(t, nRow, null); + if(t != ValueType.STRING) + for(int i = 0; i < nRow; i++) + assertEquals(a.getAsDouble(i), 0.0, 0.0000000001); + } + catch(Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } +} diff --git a/src/test/java/org/apache/sysds/test/component/frame/array/FrameArrayTests.java b/src/test/java/org/apache/sysds/test/component/frame/array/FrameArrayTests.java index abba5b2c08e..303c55d244e 100644 --- a/src/test/java/org/apache/sysds/test/component/frame/array/FrameArrayTests.java +++ b/src/test/java/org/apache/sysds/test/component/frame/array/FrameArrayTests.java @@ -22,7 +22,6 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; -import static 
org.junit.Assume.assumeTrue; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; @@ -41,6 +40,8 @@ import org.apache.sysds.runtime.frame.data.columns.Array; import org.apache.sysds.runtime.frame.data.columns.ArrayFactory; import org.apache.sysds.runtime.frame.data.columns.ArrayFactory.FrameArrayType; +import org.apache.sysds.runtime.frame.data.columns.BitSetArray; +import org.apache.sysds.runtime.frame.data.columns.BooleanArray; import org.apache.sysds.runtime.frame.data.columns.StringArray; import org.junit.Test; import org.junit.runner.RunWith; @@ -84,6 +85,12 @@ public static Collection data() { tests.add(new Object[] {ArrayFactory.create(new double[] {0.0, 1.0, 1.0, 0.0}), FrameArrayType.FP64}); tests.add(new Object[] {ArrayFactory.create(new long[] {0, 1, 1, 0, 0, 1}), FrameArrayType.INT64}); tests.add(new Object[] {ArrayFactory.create(new int[] {0, 1, 1, 0, 0, 1}), FrameArrayType.INT32}); + tests.add(new Object[] {ArrayFactory.create(generateRandom01String(100, 324)), FrameArrayType.STRING}); + tests.add(new Object[] {ArrayFactory.create(generateRandom01String(80, 22)), FrameArrayType.STRING}); + tests.add(new Object[] {ArrayFactory.create(generateRandom01String(32, 221)), FrameArrayType.STRING}); + tests.add(new Object[] {ArrayFactory.create(generateRandomTrueFalseString(32, 221)), FrameArrayType.STRING}); + tests.add(new Object[] {ArrayFactory.create(generateRandomTrueFalseString(80, 221)), FrameArrayType.STRING}); + tests.add(new Object[] {ArrayFactory.create(generateRandomTrueFalseString(150, 221)), FrameArrayType.STRING}); // Long to int tests.add(new Object[] {ArrayFactory.create(new long[] {3214, 424, 13, 22, 111, 134}), FrameArrayType.INT64}); @@ -118,7 +125,7 @@ public void serialize() { public void testGet() { int size = a.size(); for(int i = 0; i < size; i++) - assumeTrue(a.get(i).toString().equals(s.get(i))); + assertTrue(a.get(i).toString().equals(s.get(i))); } @Test(expected = ArrayIndexOutOfBoundsException.class) @@ 
-137,7 +144,20 @@ public void testGetOutOfBoundsLower() { @Test public void getSizeEstimateVsReal() { - assumeTrue(a.getInMemorySize() <= ArrayFactory.getInMemorySize(a.getValueType(), a.size())); + long memSize = a.getInMemorySize(); + long estSize = ArrayFactory.getInMemorySize(a.getValueType(), a.size()); + switch(a.getValueType()) { + case BOOLEAN: + if(a instanceof BitSetArray) + estSize = BitSetArray.estimateInMemorySize(a.size()); + else + estSize = BooleanArray.estimateInMemorySize(a.size()); + default: // nothing + } + if(memSize > estSize) + fail("Estimated size is not smaller than actual:" + memSize + " " + estSize + "\n" + a.getValueType() + " " + + a.getClass().getSimpleName()); + } @Test @@ -318,6 +338,75 @@ public void testSetRange(int start, int end, int off) { } } + @Test + public void testSetRange_1() { + if(a.size() > 10) + testSetRange(0, 10, 20, 132); + } + + @SuppressWarnings("unchecked") + public void testSetRange(int start, int end, int otherSize, int seed) { + try { + Array other = null; + switch(a.getFrameArrayType()) { + case BITSET: + other = ArrayFactory.create(generateRandomBitSet(otherSize, seed), otherSize); + break; + case BOOLEAN: + other = ArrayFactory.create(generateRandomBoolean(otherSize, seed)); + break; + case FP32: + other = ArrayFactory.create(generateRandomFloat(otherSize, seed)); + break; + case FP64: + other = ArrayFactory.create(generateRandomDouble(otherSize, seed)); + break; + case INT32: + other = ArrayFactory.create(generateRandomInteger(otherSize, seed)); + break; + case INT64: + other = ArrayFactory.create(generateRandomLong(otherSize, seed)); + break; + case STRING: + other = ArrayFactory.create(generateRandomString(otherSize, seed)); + break; + default: + throw new NotImplementedException(); + } + + Array aa = a.clone(); + switch(a.getFrameArrayType()) { + case FP64: + ((Array) aa).set(start, end, (Array) other); + break; + case FP32: + ((Array) aa).set(start, end, (Array) other); + break; + case INT32: + 
((Array) aa).set(start, end, (Array) other); + break; + case INT64: + ((Array) aa).set(start, end, (Array) other); + break; + case BOOLEAN: + case BITSET: + ((Array) aa).set(start, end, (Array) other); + break; + case STRING: + ((Array) aa).set(start, end, (Array) other); + break; + default: + throw new NotImplementedException(); + } + compareSetSubRange(aa, other, start, end, 0, aa.getValueType()); + + } + catch(Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + @Test @SuppressWarnings("unchecked") public void set() { @@ -392,6 +481,315 @@ public void setDouble() { } } + @Test + @SuppressWarnings("unchecked") + public void setDouble_2() { + Double vd = 0.0d; + a.set(0, vd); + switch(a.getFrameArrayType()) { + case FP64: + assertEquals(((Array) a).get(0), vd, 0.0000001); + return; + case FP32: + assertEquals(((Array) a).get(0), vd, 0.0000001); + return; + case INT32: + assertEquals(((Array) a).get(0), Integer.valueOf((int) (double) vd)); + return; + case INT64: + assertEquals(((Array) a).get(0), Long.valueOf((long) (double) vd)); + return; + case BOOLEAN: + case BITSET: + assertEquals(((Array) a).get(0), false); + return; + case STRING: + assertEquals(((Array) a).get(0), Double.toString(vd)); + return; + default: + throw new NotImplementedException(); + } + } + + @Test + public void analyzeValueType() { + ValueType av = a.analyzeValueType(); + switch(a.getValueType()) { + case BOOLEAN: + switch(av) { + case BOOLEAN: + return; + default: + fail("Invalid type returned from analyze valueType"); + } + case INT32: + switch(av) { + case BOOLEAN: + case INT32: + case UINT8: + return; + default: + fail("Invalid type returned from analyze valueType"); + } + case INT64: + switch(av) { + case BOOLEAN: + case INT32: + case UINT8: + case INT64: + return; + default: + fail("Invalid type returned from analyze valueType"); + } + case UINT8: + switch(av) { + case BOOLEAN: + case UINT8: + return; + default: + fail("Invalid type returned from analyze valueType"); 
+ } + case FP32: + switch(av) { + case BOOLEAN: + case INT32: + case UINT8: + case INT64: + case FP32: + return; + default: + fail("Invalid type returned from analyze valueType"); + } + case FP64: + switch(av) { + case BOOLEAN: + case INT32: + case UINT8: + case INT64: + case FP32: + case FP64: + return; + default: + fail("Invalid type returned from analyze valueType"); + } + case STRING: + break;// all allowed + case UNKNOWN: + fail("Not allowed to be unknown"); + default: + break; + } + } + + @Test + public void setNull() { + // should not crash + a.set(0, null); + } + + @Test + public void toByteArray() { + if(a.getValueType() == ValueType.STRING) + return; + // just test that it serialize as byte array with no crashes + a.getAsByteArray(); + } + + @Test + public void appendString() { + Array aa = a.clone(); + + switch(a.getValueType()) { + case BOOLEAN: + aa.append("0"); + assertEquals((Boolean) aa.get(aa.size() - 1), false); + aa.append("1"); + assertEquals((Boolean) aa.get(aa.size() - 1), true); + break; + case FP32: + float vf = 3215216.222f; + String vfs = vf + ""; + aa.append(vfs); + assertEquals((float) aa.get(aa.size() - 1), vf, 0.00001); + + vf = 32152336.222f; + vfs = vf + ""; + aa.append(vfs); + assertEquals((float) aa.get(aa.size() - 1), vf, 0.00001); + break; + case FP64: + double vd = 3215216.222; + String vds = vd + ""; + aa.append(vds); + assertEquals((double) aa.get(aa.size() - 1), vd, 0.00001); + + vd = 222.222; + vds = vd + ""; + aa.append(vds); + assertEquals((double) aa.get(aa.size() - 1), vd, 0.00001); + break; + case INT32: + int vi = 321521; + String vis = vi + ""; + aa.append(vis); + assertEquals((int) aa.get(aa.size() - 1), vi); + + vi = -2321; + vis = vi + ""; + aa.append(vis); + assertEquals((int) aa.get(aa.size() - 1), vi); + break; + case INT64: + long vl = 321521; + String vls = vl + ""; + aa.append(vls); + assertEquals((long) aa.get(aa.size() - 1), vl); + + vl = -22223; + vls = vl + ""; + aa.append(vls); + assertEquals((long) 
aa.get(aa.size() - 1), vl); + break; + case STRING: + String vs = "ThisIsAMonkeyTestSting"; + aa.append(vs); + assertEquals((String) aa.get(aa.size() - 1), vs); + + vs = "£$&*%!))"; + aa.append(vs); + assertEquals((String) aa.get(aa.size() - 1), vs); + break; + case UINT8: + int vi8 = 234; + String vi8s = vi8 + ""; + aa.append(vi8s); + assertEquals((int) aa.get(aa.size() - 1), vi8); + + vi8 = 42; + vi8s = vi8 + ""; + aa.append(vi8s); + assertEquals((int) aa.get(aa.size() - 1), vi8); + break; + case UNKNOWN: + default: + throw new DMLRuntimeException("Invalid type"); + } + } + + @Test + public void appendNull() { + Array aa = a.clone(); + + aa.append((String) null); + switch(a.getValueType()) { + case BOOLEAN: + assertEquals((Boolean) aa.get(aa.size() - 1), false); + break; + case FP32: + assertEquals((float) aa.get(aa.size() - 1), 0.0, 0.00001); + break; + case FP64: + assertEquals((double) aa.get(aa.size() - 1), 0.0, 0.00001); + break; + case INT32: + assertEquals((int) aa.get(aa.size() - 1), 0); + break; + case INT64: + assertEquals((long) aa.get(aa.size() - 1), 0); + break; + case STRING: + assertEquals((String) aa.get(aa.size() - 1), null); + break; + case UINT8: + assertEquals((int) aa.get(aa.size() - 1), 0); + break; + case UNKNOWN: + default: + throw new DMLRuntimeException("Invalid type"); + } + } + + @Test + public void append60Null() { + Array aa = a.clone(); + + try{ + + for(int i = 0; i < 60; i++) + aa.append((String) null); + + switch(a.getValueType()) { + case BOOLEAN: + assertEquals((Boolean) aa.get(aa.size() - 1), false); + break; + case FP32: + assertEquals((float) aa.get(aa.size() - 1), 0.0, 0.00001); + break; + case FP64: + assertEquals((double) aa.get(aa.size() - 1), 0.0, 0.00001); + break; + case INT32: + assertEquals((int) aa.get(aa.size() - 1), 0); + break; + case INT64: + assertEquals((long) aa.get(aa.size() - 1), 0); + break; + case STRING: + assertEquals((String) aa.get(aa.size() - 1), null); + break; + case UINT8: + assertEquals((int) 
aa.get(aa.size() - 1), 0); + break; + case UNKNOWN: + default: + throw new DMLRuntimeException("Invalid type"); + } + } + catch(Exception e){ + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + @SuppressWarnings("unchecked") + public void testSetNzSelf() { + Array aa = a.clone(); + try { + + switch(a.getValueType()) { + case BOOLEAN: + ((Array) aa).setNz((Array) a); + break; + case FP32: + ((Array) aa).setNz((Array) a); + break; + case FP64: + ((Array) aa).setNz((Array) a); + break; + case INT32: + case UINT8: + ((Array) aa).setNz((Array) a); + break; + case INT64: + ((Array) aa).setNz((Array) a); + break; + case STRING: + ((Array) aa).setNz((Array) a); + break; + case UNKNOWN: + default: + throw new DMLRuntimeException("Invalid type"); + } + } + catch(Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + + compare(aa, a); + } + protected static void compare(Array a, Array b) { int size = a.size(); assertTrue(a.size() == b.size()); @@ -407,20 +805,11 @@ protected static void compare(Array sub, Array b, int off) { } protected static void compareSetSubRange(Array out, Array in, int rl, int ru, int off, ValueType vt) { - switch(vt) { - // case FP64: - // case FP32: - // return; - default: - for(int i = rl; i <= ru; i++, off++) { - String v1 = out.get(i).toString(); - String v2 = in.get(off).toString(); - - assertEquals("i: " + i + " args: " + rl + " " + ru + " " + (off - i) + " " + out.size(), v1, v2); - } - + for(int i = rl; i <= ru; i++, off++) { + String v1 = out.get(i).toString(); + String v2 = in.get(off).toString(); + assertEquals("i: " + i + " args: " + rl + " " + ru + " " + (off - i) + " " + out.size(), v1, v2); } - } protected static Array serializeAndBack(Array g) { @@ -476,6 +865,22 @@ public static String[] generateRandomString(int size, int seed) { return ret; } + public static String[] generateRandom01String(int size, int seed) { + Random r = new Random(seed); + String[] ret = new String[size]; + for(int i = 0; i < size; i++) 
+ ret[i] = r.nextInt(1) + ""; + return ret; + } + + public static String[] generateRandomTrueFalseString(int size, int seed) { + Random r = new Random(seed); + String[] ret = new String[size]; + for(int i = 0; i < size; i++) + ret[i] = r.nextInt(1) == 1 ? "true" : "false"; + return ret; + } + protected static boolean[] generateRandomBoolean(int size, int seed) { Random r = new Random(seed); boolean[] ret = new boolean[size]; diff --git a/src/test/java/org/apache/sysds/test/component/frame/array/NegativeArrayTests.java b/src/test/java/org/apache/sysds/test/component/frame/array/NegativeArrayTests.java index 15d313770e8..59f0135fe13 100644 --- a/src/test/java/org/apache/sysds/test/component/frame/array/NegativeArrayTests.java +++ b/src/test/java/org/apache/sysds/test/component/frame/array/NegativeArrayTests.java @@ -24,6 +24,7 @@ import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.frame.data.columns.Array; import org.apache.sysds.runtime.frame.data.columns.ArrayFactory; +import org.apache.sysds.runtime.frame.data.columns.BitSetArray; import org.apache.sysds.runtime.frame.data.columns.StringArray; import org.junit.Test; @@ -73,4 +74,30 @@ public void changeTypeBoolean_3() { StringArray a = ArrayFactory.create(new String[] {"HI", "false", "0"}); a.changeType(ValueType.BOOLEAN); } + + @Test(expected = DMLRuntimeException.class) + public void changeTypeBoolean_4() { + String[] s = new String[100]; + s[0] = "1"; + s[1] = "10"; + StringArray a = ArrayFactory.create(s); + a.changeType(ValueType.BOOLEAN); + } + + + @Test(expected = DMLRuntimeException.class) + public void invalidConstructionBitArrayToSmall(){ + new BitSetArray(new long[0], 10 ); + } + + @Test(expected = DMLRuntimeException.class) + public void invalidConstructionBitArrayToSmall_2(){ + new BitSetArray(new long[1], 80 ); + } + + @Test(expected = DMLRuntimeException.class) + public void invalidConstructionBitArrayToBig(){ + new BitSetArray(new long[10], 10 ); + } + } diff --git 
a/src/test/java/org/apache/sysds/test/functions/frame/FrameConstructorTest.java b/src/test/java/org/apache/sysds/test/functions/frame/FrameConstructorTest.java index 039692c9e50..146c560ae50 100644 --- a/src/test/java/org/apache/sysds/test/functions/frame/FrameConstructorTest.java +++ b/src/test/java/org/apache/sysds/test/functions/frame/FrameConstructorTest.java @@ -19,23 +19,27 @@ package org.apache.sysds.test.functions.frame; +import java.util.Random; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; import org.apache.sysds.api.DMLScript; import org.apache.sysds.common.Types; +import org.apache.sysds.common.Types.ValueType; import org.apache.sysds.hops.OptimizerUtils; import org.apache.sysds.runtime.frame.data.FrameBlock; import org.apache.sysds.runtime.io.FrameReaderFactory; -import org.apache.sysds.runtime.util.DataConverter; -import org.apache.sysds.test.TestConfiguration; -import org.junit.Test; -import org.apache.sysds.common.Types.ValueType; import org.apache.sysds.runtime.meta.MatrixCharacteristics; import org.apache.sysds.runtime.util.UtilFunctions; import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; import org.apache.sysds.test.TestUtils; - -import java.util.Random; +import org.junit.Test; public class FrameConstructorTest extends AutomatedTestBase { + + protected static final Log LOG = LogFactory.getLog(FrameConstructorTest.class.getName()); + private final static String TEST_DIR = "functions/frame/"; private final static String TEST_NAME = "FrameConstructorTest"; private final static String TEST_CLASS_DIR = TEST_DIR + FrameConstructorTest.class.getSimpleName() + "/"; @@ -151,13 +155,17 @@ private void runFrameTest(TestType type, FrameBlock expectedOutput, Types.ExecMo fullDMLScriptName = HOME + TEST_NAME + ".dml"; programArgs = new String[] {"-explain", "-args", String.valueOf(type), output("F2")}; - runTest(true, false, null, -1); + + runTest(null); + FrameBlock fB = 
FrameReaderFactory .createFrameReader(Types.FileFormat.CSV) .readFrameFromHDFS(output("F2"), rows, cols); - String[][] R1 = DataConverter.convertToStringFrame(expectedOutput); - String[][] R2 = DataConverter.convertToStringFrame(fB); - TestUtils.compareFrames(R1, R2, R1.length, R1[0].length); + + if( type == TestType.MULTI_ROW_DATA) + fB = fB.slice(0, expectedOutput.getNumRows() -1); + + TestUtils.compareFramesAsString(expectedOutput, fB, false); int nrow = type == TestType.MULTI_ROW_DATA ? 5 : 40; checkDMLMetaDataFile("F2", new MatrixCharacteristics(nrow, cols)); } From ed7980737bd8b33d672d253e0c07327bbfcd10ad Mon Sep 17 00:00:00 2001 From: baunsgaard Date: Wed, 4 Jan 2023 22:07:17 +0100 Subject: [PATCH 2/5] prepare for merge --- .../org/apache/sysds/hops/AggUnaryOp.java | 6 ++-- .../java/org/apache/sysds/hops/BinaryOp.java | 2 +- .../org/apache/sysds/hops/LeftIndexingOp.java | 3 +- .../java/org/apache/sysds/lops/Unary.java | 32 ++++++++----------- .../java/org/apache/sysds/lops/UnaryCP.java | 8 ++--- .../dictionary/DictLibMatrixMult.java | 1 + .../cp/BinaryScalarScalarCPInstruction.java | 3 -- .../sysds/runtime/io/FrameWriterTextCSV.java | 7 ++-- 8 files changed, 26 insertions(+), 36 deletions(-) diff --git a/src/main/java/org/apache/sysds/hops/AggUnaryOp.java b/src/main/java/org/apache/sysds/hops/AggUnaryOp.java index ac4e018a5f9..23439b182e8 100644 --- a/src/main/java/org/apache/sysds/hops/AggUnaryOp.java +++ b/src/main/java/org/apache/sysds/hops/AggUnaryOp.java @@ -138,7 +138,7 @@ else if( et != ExecType.FED && isUnaryAggregateOuterCPRewriteApplicable() ) if (getDataType() == DataType.SCALAR) { UnaryCP unary1 = new UnaryCP(agg1, OpOp1.CAST_AS_SCALAR, - getDataType(), getValueType(), 1); + getDataType(), getValueType()); unary1.getOutputParameters().setDimensions(0, 0, 0, -1); setLineNumbers(unary1); agg1 = unary1; @@ -180,7 +180,7 @@ else if( isUnaryAggregateOuterSPRewriteApplicable() ) if (getDataType() == DataType.SCALAR) { UnaryCP unary1 = new 
UnaryCP(transform1, - OpOp1.CAST_AS_SCALAR, getDataType(), getValueType(), 1); + OpOp1.CAST_AS_SCALAR, getDataType(), getValueType()); unary1.getOutputParameters().setDimensions(0, 0, 0, -1); setLineNumbers(unary1); setLops(unary1); @@ -200,7 +200,7 @@ else if( isUnaryAggregateOuterSPRewriteApplicable() ) if (getDataType() == DataType.SCALAR) { UnaryCP unary1 = new UnaryCP(aggregate, - OpOp1.CAST_AS_SCALAR, getDataType(), getValueType(), 1); + OpOp1.CAST_AS_SCALAR, getDataType(), getValueType()); unary1.getOutputParameters().setDimensions(0, 0, 0, -1); setLineNumbers(unary1); setLops(unary1); diff --git a/src/main/java/org/apache/sysds/hops/BinaryOp.java b/src/main/java/org/apache/sysds/hops/BinaryOp.java index 549bf53e33d..2346eeebfe6 100644 --- a/src/main/java/org/apache/sysds/hops/BinaryOp.java +++ b/src/main/java/org/apache/sysds/hops/BinaryOp.java @@ -453,7 +453,7 @@ op, getDataType(), getValueType(), et, getInput().get(0) == getInput().get(1).getInput().get(0); if(isGPUSoftmax) { UnaryCP softmax = new UnaryCP(getInput().get(0).getInput().get(0).constructLops(), - OpOp1.SOFTMAX, getDataType(), getValueType(), et, OptimizerUtils.getConstrainedNumThreads(_maxNumThreads)); + OpOp1.SOFTMAX, getDataType(), getValueType(), et); setOutputDimensions(softmax); setLineNumbers(softmax); setLops(softmax); diff --git a/src/main/java/org/apache/sysds/hops/LeftIndexingOp.java b/src/main/java/org/apache/sysds/hops/LeftIndexingOp.java index ed8cc192d0f..79d863adf9c 100644 --- a/src/main/java/org/apache/sysds/hops/LeftIndexingOp.java +++ b/src/main/java/org/apache/sysds/hops/LeftIndexingOp.java @@ -122,10 +122,9 @@ public Lop constructLops() //insert cast to matrix if necessary (for reuse broadcast runtime) Lop rightInput = right.constructLops(); if (isRightHandSideScalar()) { - // one thread because it is cast from scalar. 
rightInput = new UnaryCP(rightInput, (left.getDataType()==DataType.MATRIX?OpOp1.CAST_AS_MATRIX:OpOp1.CAST_AS_FRAME), - left.getDataType(), right.getValueType(), 1); + left.getDataType(), right.getValueType()); long bsize = ConfigurationManager.getBlocksize(); rightInput.getOutputParameters().setDimensions( 1, 1, bsize, -1); } diff --git a/src/main/java/org/apache/sysds/lops/Unary.java b/src/main/java/org/apache/sysds/lops/Unary.java index c5323325f11..d95235798ad 100644 --- a/src/main/java/org/apache/sysds/lops/Unary.java +++ b/src/main/java/org/apache/sysds/lops/Unary.java @@ -156,28 +156,24 @@ public String getInstructions(String input1, String output) { // Unary operators with one input StringBuilder sb = new StringBuilder(); - sb.append( getExecType() ); - sb.append( Lop.OPERAND_DELIMITOR ); - sb.append( getOpcode() ); - sb.append( OPERAND_DELIMITOR ); - sb.append( getInputs().get(0).prepInputOperand(input1) ); - sb.append( OPERAND_DELIMITOR ); - sb.append( prepOutputOperand(output) ); - - if( getExecType() == ExecType.CP || getExecType() == ExecType.FED){ - sb.append( OPERAND_DELIMITOR ); - sb.append( _numThreads ); - if( isMultiThreadedOp(operation)){ + sb.append(getExecType()); + sb.append(Lop.OPERAND_DELIMITOR); + sb.append(getOpcode()); + sb.append(OPERAND_DELIMITOR); + sb.append(getInputs().get(0).prepInputOperand(input1)); + sb.append(OPERAND_DELIMITOR); + sb.append(prepOutputOperand(output)); + + if(getExecType() == ExecType.CP || getExecType() == ExecType.FED) { + sb.append(OPERAND_DELIMITOR); + sb.append(_numThreads); + if(isMultiThreadedOp(operation)) { - sb.append( OPERAND_DELIMITOR ); - sb.append( _inplace ); + sb.append(OPERAND_DELIMITOR); + sb.append(_inplace); } } - // //num threads for cumulative cp ops - // if( (getExecType() == ExecType.CP || getExecType() == ExecType.FED) && isMultiThreadedOp(operation) ) { - // } - appendFedOut(sb); return sb.toString(); diff --git a/src/main/java/org/apache/sysds/lops/UnaryCP.java 
b/src/main/java/org/apache/sysds/lops/UnaryCP.java index 09ce9dd5681..4b95e8c1b07 100644 --- a/src/main/java/org/apache/sysds/lops/UnaryCP.java +++ b/src/main/java/org/apache/sysds/lops/UnaryCP.java @@ -26,10 +26,9 @@ import org.apache.sysds.common.Types.OpOp1; import org.apache.sysds.common.Types.ValueType; -public class UnaryCP extends Lop -{ - private OpOp1 operation; - private int _numThreads = 1; +public class UnaryCP extends Lop { + private final OpOp1 operation; + private final int _numThreads; /** * Constructor to perform a scalar operation @@ -47,6 +46,7 @@ public UnaryCP(Lop input, OpOp1 op, DataType dt, ValueType vt, ExecType et, int addInput(input); input.addOutput(this); lps.setProperties(inputs, et); + _numThreads = k; } public UnaryCP(Lop input, OpOp1 op, DataType dt, ValueType vt, ExecType et) { diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DictLibMatrixMult.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DictLibMatrixMult.java index 3d59f854e32..2b2e41600a6 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DictLibMatrixMult.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DictLibMatrixMult.java @@ -426,6 +426,7 @@ protected static void MMToUpperTriangleSparseDenseDiagonal(SparseBlock left, dou protected static void MMToUpperTriangleDenseDense(double[] left, double[] right, int[] rowsLeft, int[] colsRight, MatrixBlock result) { final int loc = location(rowsLeft, colsRight); + // LOG.error("loc:" + loc); if(loc < 0) MMToUpperTriangleDenseDenseAllUpperTriangle(left, right, rowsLeft, colsRight, result); else if(loc > 0) diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryScalarScalarCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryScalarScalarCPInstruction.java index 167a57096f0..83c119e6c34 100644 --- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryScalarScalarCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryScalarScalarCPInstruction.java @@ -19,8 +19,6 @@ package org.apache.sysds.runtime.instructions.cp; -import org.apache.commons.logging.Log; -import org.apache.commons.logging.LogFactory; import org.apache.sysds.common.Types.ValueType; import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; @@ -29,7 +27,6 @@ import org.apache.sysds.runtime.matrix.operators.Operator; public class BinaryScalarScalarCPInstruction extends BinaryCPInstruction { - private static final Log LOG = LogFactory.getLog(BinaryScalarScalarCPInstruction.class.getName()); protected BinaryScalarScalarCPInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand out, String opcode, String istr) { super(CPType.Binary, op, in1, in2, out, opcode, istr); diff --git a/src/main/java/org/apache/sysds/runtime/io/FrameWriterTextCSV.java b/src/main/java/org/apache/sysds/runtime/io/FrameWriterTextCSV.java index 418da607654..92abe0932d7 100644 --- a/src/main/java/org/apache/sysds/runtime/io/FrameWriterTextCSV.java +++ b/src/main/java/org/apache/sysds/runtime/io/FrameWriterTextCSV.java @@ -24,8 +24,6 @@ import java.io.OutputStreamWriter; import java.util.Iterator; -import org.apache.commons.logging.Log; -import org.apache.commons.logging.LogFactory; import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; import org.apache.hadoop.mapred.JobConf; @@ -40,9 +38,8 @@ * Single-threaded frame text csv writer. 
* */ -public class FrameWriterTextCSV extends FrameWriter{ - protected static final Log LOG = LogFactory.getLog(FrameWriterTextCSV.class.getName()); - +public class FrameWriterTextCSV extends FrameWriter +{ //blocksize for string concatenation in order to prevent write OOM //(can be set to very large value to disable blocking) public static final int BLOCKSIZE_J = 32; //32 cells (typically ~512B, should be less than write buffer of 1KB) From 6b222bb3944e15129af5b699b933895206a33820 Mon Sep 17 00:00:00 2001 From: baunsgaard Date: Wed, 4 Jan 2023 23:02:13 +0100 Subject: [PATCH 3/5] FrameAppend --- .../sysds/runtime/frame/data/FrameBlock.java | 136 +++++++----------- .../frame/data/columns/DoubleArray.java | 2 +- .../frame/data/columns/StringArray.java | 2 +- .../frame/data/lib/FrameLibAppend.java | 103 +++++++++++++ .../frame/data/{ => lib}/FrameUtil.java | 42 +++++- .../spark/UnaryFrameSPInstruction.java | 3 +- 6 files changed, 201 insertions(+), 87 deletions(-) create mode 100644 src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameLibAppend.java rename src/main/java/org/apache/sysds/runtime/frame/data/{ => lib}/FrameUtil.java (62%) diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java b/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java index da69265c143..b037b3b5cc2 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java @@ -57,8 +57,10 @@ import org.apache.sysds.runtime.frame.data.columns.ColumnMetadata; import org.apache.sysds.runtime.frame.data.iterators.IteratorFactory; import org.apache.sysds.runtime.frame.data.lib.FrameFromMatrixBlock; +import org.apache.sysds.runtime.frame.data.lib.FrameLibAppend; import org.apache.sysds.runtime.frame.data.lib.FrameLibApplySchema; import org.apache.sysds.runtime.frame.data.lib.FrameLibDetectSchema; +import org.apache.sysds.runtime.frame.data.lib.FrameUtil; import 
org.apache.sysds.runtime.functionobjects.ValueComparisonFunction; import org.apache.sysds.runtime.instructions.cp.BooleanObject; import org.apache.sysds.runtime.instructions.cp.DoubleObject; @@ -1031,55 +1033,57 @@ public void slice(ArrayList> outList, IndexRange range, i * @return frame block */ public FrameBlock append(FrameBlock that, boolean cbind) { - FrameBlock ret = new FrameBlock(); - if(cbind) // COLUMN APPEND - { - // sanity check row dimension mismatch - if(getNumRows() != that.getNumRows()) { - throw new DMLRuntimeException( - "Incompatible number of rows for cbind: " + that.getNumRows() + " (expected: " + getNumRows() + ")"); - } - - // allocate output frame - ret._numRows = _numRows; - - // concatenate schemas (w/ deep copy to prevent side effects) - ret._schema = (ValueType[]) ArrayUtils.addAll(_schema, that._schema); - ret._colnames = (String[]) ArrayUtils.addAll(getColumnNames(), that.getColumnNames()); - ret._colmeta = (ColumnMetadata[]) ArrayUtils.addAll(_colmeta, that._colmeta); - - // check and enforce unique columns names - if(!Arrays.stream(ret._colnames).allMatch(new HashSet<>()::add)) - ret._colnames = createColNames(ret.getNumColumns()); - - // concatenate column data (w/ shallow copy which is safe due to copy on write semantics) - ret._coldata = (Array[]) ArrayUtils.addAll(_coldata, that._coldata); - } - else // ROW APPEND - { - // sanity check column dimension mismatch - if(getNumColumns() != that.getNumColumns()) { - throw new DMLRuntimeException("Incompatible number of columns for rbind: " + that.getNumColumns() - + " (expected: " + getNumColumns() + ")"); - } - ret._numRows = _numRows; // note set to previous since each row is appended on. - ret._schema = _schema.clone(); - ret._colnames = (_colnames != null) ? 
_colnames.clone() : null; - ret._colmeta = new ColumnMetadata[getNumColumns()]; - for(int j = 0; j < _schema.length; j++) - ret._colmeta[j] = new ColumnMetadata(); - - // concatenate data (deep copy first, append second) - ret._coldata = new Array[getNumColumns()]; - for(int j = 0; j < getNumColumns(); j++) - ret._coldata[j] = _coldata[j].clone(); - Iterator iter = IteratorFactory.getObjectRowIterator(that, _schema); - while(iter.hasNext()) - ret.appendRow(iter.next()); - } - - ret._msize = -1; - return ret; + return FrameLibAppend.append(this, that, cbind); + + // FrameBlock ret = new FrameBlock(); + // if(cbind) // COLUMN APPEND + // { + // // sanity check row dimension mismatch + // if(getNumRows() != that.getNumRows()) { + // throw new DMLRuntimeException( + // "Incompatible number of rows for cbind: " + that.getNumRows() + " (expected: " + getNumRows() + ")"); + // } + + // // allocate output frame + // ret._numRows = _numRows; + + // // concatenate schemas (w/ deep copy to prevent side effects) + // ret._schema = (ValueType[]) ArrayUtils.addAll(_schema, that._schema); + // ret._colnames = (String[]) ArrayUtils.addAll(getColumnNames(), that.getColumnNames()); + // ret._colmeta = (ColumnMetadata[]) ArrayUtils.addAll(_colmeta, that._colmeta); + + // // check and enforce unique columns names + // if(!Arrays.stream(ret._colnames).allMatch(new HashSet<>()::add)) + // ret._colnames = createColNames(ret.getNumColumns()); + + // // concatenate column data (w/ shallow copy which is safe due to copy on write semantics) + // ret._coldata = (Array[]) ArrayUtils.addAll(_coldata, that._coldata); + // } + // else // ROW APPEND + // { + // // sanity check column dimension mismatch + // if(getNumColumns() != that.getNumColumns()) { + // throw new DMLRuntimeException("Incompatible number of columns for rbind: " + that.getNumColumns() + // + " (expected: " + getNumColumns() + ")"); + // } + // ret._numRows = _numRows; // note set to previous since each row is appended on. 
+ // ret._schema = _schema.clone(); + // ret._colnames = (_colnames != null) ? _colnames.clone() : null; + // ret._colmeta = new ColumnMetadata[getNumColumns()]; + // for(int j = 0; j < _schema.length; j++) + // ret._colmeta[j] = new ColumnMetadata(); + + // // concatenate data (deep copy first, append second) + // ret._coldata = new Array[getNumColumns()]; + // for(int j = 0; j < getNumColumns(); j++) + // ret._coldata[j] = _coldata[j].clone(); + // Iterator iter = IteratorFactory.getObjectRowIterator(that, _schema); + // while(iter.hasNext()) + // ret.appendRow(iter.next()); + // } + + // ret._msize = -1; + // return ret; } public FrameBlock copy() { @@ -1134,9 +1138,6 @@ public void copy(int rl, int ru, int cl, int cu, FrameBlock src) { } } - /////// - // transform specific functionality - /** * This function will split every Recode map in the column using delimiter Lop.DATATYPE_PREFIX, as Recode map * generated earlier in the form of Code+Lop.DATATYPE_PREFIX+Token and store it in a map which contains token and @@ -1348,37 +1349,6 @@ public FrameBlock invalidByLength(MatrixBlock feaLen) { return outBlock; } - public static FrameBlock mergeSchema(FrameBlock temp1, FrameBlock temp2) { - String[] rowTemp1 = IteratorFactory.getStringRowIterator(temp1).next(); - String[] rowTemp2 = IteratorFactory.getStringRowIterator(temp2).next(); - - if(rowTemp1.length != rowTemp2.length) - throw new DMLRuntimeException("Schema dimension " + "mismatch: " + rowTemp1.length + " vs " + rowTemp2.length); - - for(int i = 0; i < rowTemp1.length; i++) { - // modify schema1 if necessary (different schema2) - if(!rowTemp1[i].equals(rowTemp2[i])) { - if(rowTemp1[i].equals("STRING") || rowTemp2[i].equals("STRING")) - rowTemp1[i] = "STRING"; - else if(rowTemp1[i].equals("FP64") || rowTemp2[i].equals("FP64")) - rowTemp1[i] = "FP64"; - else if(rowTemp1[i].equals("FP32") && - new ArrayList<>(Arrays.asList("INT64", "INT32", "CHARACTER")).contains(rowTemp2[i])) - rowTemp1[i] = "FP32"; - else 
if(rowTemp1[i].equals("INT64") && - new ArrayList<>(Arrays.asList("INT32", "CHARACTER")).contains(rowTemp2[i])) - rowTemp1[i] = "INT64"; - else if(rowTemp1[i].equals("INT32") || rowTemp2[i].equals("CHARACTER")) - rowTemp1[i] = "INT32"; - } - } - - // create output block one row representing the schema as strings - FrameBlock mergedFrame = new FrameBlock(UtilFunctions.nCopies(temp1.getNumColumns(), ValueType.STRING)); - mergedFrame.appendRow(rowTemp1); - return mergedFrame; - } - public void mapInplace(Function fun) { for(int j = 0; j < getNumColumns(); j++) for(int i = 0; i < getNumRows(); i++) { diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/DoubleArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/DoubleArray.java index 8102a2e4306..d382e935015 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/DoubleArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/DoubleArray.java @@ -29,8 +29,8 @@ import org.apache.sysds.common.Types.ValueType; import org.apache.sysds.runtime.DMLRuntimeException; -import org.apache.sysds.runtime.frame.data.FrameUtil; import org.apache.sysds.runtime.frame.data.columns.ArrayFactory.FrameArrayType; +import org.apache.sysds.runtime.frame.data.lib.FrameUtil; import org.apache.sysds.runtime.util.UtilFunctions; import org.apache.sysds.utils.MemoryEstimates; diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java index 807c21088a2..87201d1758f 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java @@ -28,8 +28,8 @@ import org.apache.commons.lang.NotImplementedException; import org.apache.sysds.common.Types.ValueType; import org.apache.sysds.runtime.DMLRuntimeException; -import org.apache.sysds.runtime.frame.data.FrameUtil; import 
org.apache.sysds.runtime.frame.data.columns.ArrayFactory.FrameArrayType; +import org.apache.sysds.runtime.frame.data.lib.FrameUtil; import org.apache.sysds.runtime.io.IOUtilFunctions; import org.apache.sysds.runtime.matrix.data.Pair; import org.apache.sysds.utils.MemoryEstimates; diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameLibAppend.java b/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameLibAppend.java new file mode 100644 index 00000000000..58a17b14e23 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameLibAppend.java @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysds.runtime.frame.data.lib; + +import java.util.Arrays; +import java.util.HashSet; +import java.util.Iterator; + +import org.apache.commons.lang.ArrayUtils; +import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.frame.data.FrameBlock; +import org.apache.sysds.runtime.frame.data.columns.Array; +import org.apache.sysds.runtime.frame.data.columns.ColumnMetadata; +import org.apache.sysds.runtime.frame.data.iterators.IteratorFactory; + +public class FrameLibAppend { + + /** + * Appends the given argument FrameBlock 'that' to this FrameBlock by creating a deep copy to prevent side effects. + * For cbind, the frames are appended column-wise (same number of rows), while for rbind the frames are appended + * row-wise (same number of columns). + * + * @param a FrameBlock to append to + * @param that frame block to append + * @param cbind if true, column append + * @return frame block + */ + public static FrameBlock append(FrameBlock a, FrameBlock b, boolean cbind) { + if(cbind) + return appendCbind(a, b); + else + return appendRbind(a, b); + } + + public static FrameBlock appendCbind(FrameBlock a, FrameBlock b) { + final int nRow = a.getNumRows(); + // sanity check row dimension mismatch + if(nRow != b.getNumRows()) + throw new DMLRuntimeException( + "Incompatible number of rows for cbind: " + b.getNumRows() + " (expected: " + nRow + ")"); + + // concatenate schemas (w/ deep copy to prevent side effects) + ValueType[] _schema = addAll(a.getSchema(), b.getSchema()); + String[] _colnames = addAll(a.getColumnNames(), b.getColumnNames()); + ColumnMetadata[] _colmeta = addAll(a.getColumnMetadata(), b.getColumnMetadata()); + + // check and enforce unique columns names + if(!Arrays.stream(_colnames).allMatch(new HashSet<>()::add)) + _colnames = null; // set to default of null. 
+ + // concatenate column data (w/ shallow copy which is safe due to copy on write semantics) + Array[] _coldata = (Array[]) ArrayUtils.addAll(a.getColumns(), b.getColumns()); + return new FrameBlock(_schema, _colnames, _colmeta, _coldata); + } + + public static FrameBlock appendRbind(FrameBlock a, FrameBlock b) { + final int nCol = a.getNumColumns(); + // sanity check column dimension mismatch + if(nCol != b.getNumColumns()) { + throw new DMLRuntimeException("Incompatible number of columns for rbind: " + b.getNumColumns() + " (expected: " + + nCol + ")"); + } + + // ret._schema = a.getSchema().clone(); + String[] _colnames = (a.getColumnNames(false) != null) ? a.getColumnNames().clone() : null; + ColumnMetadata[] _colmeta = new ColumnMetadata[a.getNumColumns()]; + for(int j = 0; j < nCol; j++) + _colmeta[j] = new ColumnMetadata(); + + // concatenate data (deep copy first, append second) + ret._coldata = new Array[a.getNumColumns()]; + for(int j = 0; j < a.getNumColumns(); j++) + ret._coldata[j] = a._coldata[j].clone(); + Iterator iter = IteratorFactory.getObjectRowIterator(b, a._schema); + while(iter.hasNext()) + ret.appendRow(iter.next()); + + return new FrameBlock(a.getSchema().clone(), _colnames, _colmeta, _coldata); + } + + @SuppressWarnings("unchecked") + private static T[] addAll(T[] a, T[] b) { + return (T[]) ArrayUtils.addAll(a, b); + } +} diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/FrameUtil.java b/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameUtil.java similarity index 62% rename from src/main/java/org/apache/sysds/runtime/frame/data/FrameUtil.java rename to src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameUtil.java index 3fb72b93263..b60f8a4fad0 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/FrameUtil.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameUtil.java @@ -17,12 +17,19 @@ * under the License. 
*/ -package org.apache.sysds.runtime.frame.data; +package org.apache.sysds.runtime.frame.data.lib; + +import java.util.ArrayList; +import java.util.Arrays; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.frame.data.FrameBlock; import org.apache.sysds.runtime.frame.data.columns.Array; +import org.apache.sysds.runtime.frame.data.iterators.IteratorFactory; +import org.apache.sysds.runtime.util.UtilFunctions; public interface FrameUtil { public static final Log LOG = LogFactory.getLog(FrameUtil.class.getName()); @@ -87,4 +94,37 @@ else if((double)((float) val) == val ) return ValueType.FP64; } + + + public static FrameBlock mergeSchema(FrameBlock temp1, FrameBlock temp2) { + String[] rowTemp1 = IteratorFactory.getStringRowIterator(temp1).next(); + String[] rowTemp2 = IteratorFactory.getStringRowIterator(temp2).next(); + + if(rowTemp1.length != rowTemp2.length) + throw new DMLRuntimeException("Schema dimension " + "mismatch: " + rowTemp1.length + " vs " + rowTemp2.length); + + for(int i = 0; i < rowTemp1.length; i++) { + // modify schema1 if necessary (different schema2) + if(!rowTemp1[i].equals(rowTemp2[i])) { + if(rowTemp1[i].equals("STRING") || rowTemp2[i].equals("STRING")) + rowTemp1[i] = "STRING"; + else if(rowTemp1[i].equals("FP64") || rowTemp2[i].equals("FP64")) + rowTemp1[i] = "FP64"; + else if(rowTemp1[i].equals("FP32") && + new ArrayList<>(Arrays.asList("INT64", "INT32", "CHARACTER")).contains(rowTemp2[i])) + rowTemp1[i] = "FP32"; + else if(rowTemp1[i].equals("INT64") && + new ArrayList<>(Arrays.asList("INT32", "CHARACTER")).contains(rowTemp2[i])) + rowTemp1[i] = "INT64"; + else if(rowTemp1[i].equals("INT32") || rowTemp2[i].equals("CHARACTER")) + rowTemp1[i] = "INT32"; + } + } + + // create output block one row representing the schema as strings + FrameBlock mergedFrame = new 
FrameBlock(UtilFunctions.nCopies(temp1.getNumColumns(), ValueType.STRING)); + mergedFrame.appendRow(rowTemp1); + return mergedFrame; + } + } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/UnaryFrameSPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/UnaryFrameSPInstruction.java index 75a226d1e6a..d11eb39205a 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/spark/UnaryFrameSPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/UnaryFrameSPInstruction.java @@ -28,6 +28,7 @@ import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext; import org.apache.sysds.runtime.frame.data.FrameBlock; +import org.apache.sysds.runtime.frame.data.lib.FrameUtil; import org.apache.sysds.runtime.instructions.InstructionUtils; import org.apache.sysds.runtime.instructions.cp.CPOperand; import org.apache.sysds.runtime.matrix.operators.Operator; @@ -88,7 +89,7 @@ private static class MergeFrame implements Function2 Date: Wed, 4 Jan 2023 23:11:30 +0100 Subject: [PATCH 4/5] Frame append lib function --- .../frame/data/lib/FrameLibAppend.java | 22 +++++++++---------- 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameLibAppend.java b/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameLibAppend.java index 58a17b14e23..f0e035e950a 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameLibAppend.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameLibAppend.java @@ -57,17 +57,15 @@ public static FrameBlock appendCbind(FrameBlock a, FrameBlock b) { throw new DMLRuntimeException( "Incompatible number of rows for cbind: " + b.getNumRows() + " (expected: " + nRow + ")"); - // concatenate schemas (w/ deep copy to prevent side effects) - ValueType[] _schema = addAll(a.getSchema(), b.getSchema()); + final 
ValueType[] _schema = addAll(a.getSchema(), b.getSchema()); + final ColumnMetadata[] _colmeta = addAll(a.getColumnMetadata(), b.getColumnMetadata()); + final Array[] _coldata = addAll(a.getColumns(), b.getColumns()); String[] _colnames = addAll(a.getColumnNames(), b.getColumnNames()); - ColumnMetadata[] _colmeta = addAll(a.getColumnMetadata(), b.getColumnMetadata()); // check and enforce unique columns names if(!Arrays.stream(_colnames).allMatch(new HashSet<>()::add)) _colnames = null; // set to default of null. - // concatenate column data (w/ shallow copy which is safe due to copy on write semantics) - Array[] _coldata = (Array[]) ArrayUtils.addAll(a.getColumns(), b.getColumns()); return new FrameBlock(_schema, _colnames, _colmeta, _coldata); } @@ -75,8 +73,8 @@ public static FrameBlock appendRbind(FrameBlock a, FrameBlock b) { final int nCol = a.getNumColumns(); // sanity check column dimension mismatch if(nCol != b.getNumColumns()) { - throw new DMLRuntimeException("Incompatible number of columns for rbind: " + b.getNumColumns() + " (expected: " - + nCol + ")"); + throw new DMLRuntimeException( + "Incompatible number of columns for rbind: " + b.getNumColumns() + " (expected: " + nCol + ")"); } // ret._schema = a.getSchema().clone(); @@ -86,14 +84,14 @@ public static FrameBlock appendRbind(FrameBlock a, FrameBlock b) { _colmeta[j] = new ColumnMetadata(); // concatenate data (deep copy first, append second) - ret._coldata = new Array[a.getNumColumns()]; + Array[] _coldata = new Array[a.getNumColumns()]; for(int j = 0; j < a.getNumColumns(); j++) - ret._coldata[j] = a._coldata[j].clone(); - Iterator iter = IteratorFactory.getObjectRowIterator(b, a._schema); + _coldata[j] = a.getColumn(j).clone(); + Iterator iter = IteratorFactory.getObjectRowIterator(b, a.getSchema()); + FrameBlock ret = new FrameBlock(a.getSchema().clone(), _colnames, _colmeta, _coldata); while(iter.hasNext()) ret.appendRow(iter.next()); - - return new FrameBlock(a.getSchema().clone(), _colnames, 
_colmeta, _coldata); + return ret; } @SuppressWarnings("unchecked") From 602ea45769df496ddc2230d876c1d2d13c5d2c47 Mon Sep 17 00:00:00 2001 From: baunsgaard Date: Wed, 4 Jan 2023 23:16:21 +0100 Subject: [PATCH 5/5] fix docs --- .../apache/sysds/runtime/frame/data/lib/FrameLibAppend.java | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameLibAppend.java b/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameLibAppend.java index f0e035e950a..9a3995f6106 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameLibAppend.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameLibAppend.java @@ -39,9 +39,9 @@ public class FrameLibAppend { * row-wise (same number of columns). * * @param a FrameBlock to append to - * @param that frame block to append + * @param b FrameBlock to append * @param cbind if true, column append - * @return frame block + * @return frame block of the two blocks combined. */ public static FrameBlock append(FrameBlock a, FrameBlock b, boolean cbind) { if(cbind) @@ -64,7 +64,7 @@ public static FrameBlock appendCbind(FrameBlock a, FrameBlock b) { // check and enforce unique columns names if(!Arrays.stream(_colnames).allMatch(new HashSet<>()::add)) - _colnames = null; // set to default of null. + _colnames = null; // set to default of null to allocate on demand return new FrameBlock(_schema, _colnames, _colmeta, _coldata); }