Skip to content

Commit

Permalink
[MINOR] Frame Float detection refinement
Browse files Browse the repository at this point in the history
This commit refine the detection and selection of float values in the
schema detection algorithm located in:

src/main/java/org/apache/sysds/runtime/frame/data/lib/FrameUtil.java

To improve the performance a custom matcher have been made to avoid using
regexes if possible, and we try to avoid parsing the float value from the
string representation if at all possible. The implementation is not
completely fool proof and does not consider many adversarial inputs.

on a small test case of 1000 string values that all are fp32 the
implementation improve performance from 0.5 ms to 0.02 ms.
  • Loading branch information
Baunsgaard committed Sep 6, 2023
1 parent 9028579 commit 4851f34
Show file tree
Hide file tree
Showing 17 changed files with 232 additions and 60 deletions.
91 changes: 54 additions & 37 deletions src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java
Expand Up @@ -31,9 +31,11 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.concurrent.ThreadLocalRandom;
import java.util.function.Function;

Expand All @@ -56,6 +58,7 @@
import org.apache.sysds.runtime.frame.data.iterators.IteratorFactory;
import org.apache.sysds.runtime.frame.data.lib.FrameFromMatrixBlock;
import org.apache.sysds.runtime.frame.data.lib.FrameLibAppend;
import org.apache.sysds.runtime.frame.data.lib.FrameLibApplySchema;
import org.apache.sysds.runtime.frame.data.lib.FrameLibDetectSchema;
import org.apache.sysds.runtime.frame.data.lib.FrameLibRemoveEmpty;
import org.apache.sysds.runtime.frame.data.lib.FrameUtil;
Expand Down Expand Up @@ -214,8 +217,8 @@ public FrameBlock(Array<?>[] data) {
if(debug) {
for(int i = 0; i < data.length; i++) {
if(data[i].size() != getNumRows())
throw new DMLRuntimeException(
"Invalid Frame allocation with different size arrays " + data[i].size() + " vs " + getNumRows());
throw new DMLRuntimeException("Invalid Frame allocation with different size arrays "
+ data[i].size() + " vs " + getNumRows());
}
}
}
Expand All @@ -239,8 +242,8 @@ public FrameBlock(Array<?>[] data, String[] colnames) {
if(debug) {
for(int i = 0; i < data.length; i++) {
if(data[i].size() != getNumRows())
throw new DMLRuntimeException(
"Invalid Frame allocation with different size arrays " + data[i].size() + " vs " + getNumRows());
throw new DMLRuntimeException("Invalid Frame allocation with different size arrays "
+ data[i].size() + " vs " + getNumRows());
}
}
}
Expand Down Expand Up @@ -400,8 +403,8 @@ public Map<String, Integer> getColumnNameIDMap() {
}

/**
* Allocate column data structures if necessary, i.e., if schema specified but not all column data structures created
* yet.
* Allocate column data structures if necessary, i.e., if schema specified but not all column data structures
* created yet.
*
* @param numRows number of rows
*/
Expand Down Expand Up @@ -640,8 +643,8 @@ public void appendColumn(int[] col) {
}

/**
* Append a column of value type LONG as the last column of the data frame. The given array is wrapped but not copied
* and hence might be updated in the future.
* Append a column of value type LONG as the last column of the data frame. The given array is wrapped but not
* copied and hence might be updated in the future.
*
* @param col array of longs
*/
Expand Down Expand Up @@ -701,7 +704,9 @@ public void appendColumns(double[][] cols) {
Array[] tmpData = new Array[ncol];
for(int j = 0; j < ncol; j++)
tmpData[j] = ArrayFactory.create(cols[j]);
_colnames = empty ? null : ArrayUtils.addAll(getColumnNames(), createColNames(getNumColumns(), ncol)); // before schema modification
_colnames = empty ? null : ArrayUtils.addAll(getColumnNames(), createColNames(getNumColumns(), ncol)); // before
// schema
// modification
_schema = empty ? tmpSchema : ArrayUtils.addAll(_schema, tmpSchema);
_coldata = empty ? tmpData : ArrayUtils.addAll(_coldata, tmpData);
_nRow = cols[0].length;
Expand Down Expand Up @@ -859,17 +864,22 @@ private double arraysSizeInMemory() {
if(rlen > 1000 && clen > 10 && ConfigurationManager.isParallelIOEnabled()) {
final ExecutorService pool = CommonThreadPool.get();
try {
size += pool.submit(() -> {
return Arrays.stream(_coldata).parallel() // parallel columns
.map(x ->x.getInMemorySize()).reduce(0L, (a,x) -> a + x);
}).get();
List<Future<Long>> f = new ArrayList<>(clen);
for(int i = 0; i < clen; i++) {
final int j = i;
f.add(pool.submit(() -> _coldata[j].getInMemorySize()));
}

for(Future<Long> e : f) {
size += e.get();
}
}
catch(InterruptedException | ExecutionException e) {
LOG.error(e);
for(Array<?> aa : _coldata)
size += aa.getInMemorySize();
}
finally{
finally {
pool.shutdown();
}
}
Expand Down Expand Up @@ -1012,11 +1022,11 @@ public FrameBlock leftIndexingOperations(FrameBlock rhsFrame, IndexRange ixrange

public FrameBlock leftIndexingOperations(FrameBlock rhsFrame, int rl, int ru, int cl, int cu, FrameBlock ret) {
// check the validity of bounds
if(rl < 0 || rl >= getNumRows() || ru < rl || ru >= getNumRows() || cl < 0 || cu >= getNumColumns() || cu < cl ||
cu >= getNumColumns()) {
if(rl < 0 || rl >= getNumRows() || ru < rl || ru >= getNumRows() || cl < 0 || cu >= getNumColumns() ||
cu < cl || cu >= getNumColumns()) {
throw new DMLRuntimeException(
"Invalid values for frame indexing: [" + (rl + 1) + ":" + (ru + 1) + "," + (cl + 1) + ":" + (cu + 1) + "] "
+ "must be within frame dimensions [" + getNumRows() + "," + getNumColumns() + "].");
"Invalid values for frame indexing: [" + (rl + 1) + ":" + (ru + 1) + "," + (cl + 1) + ":" + (cu + 1)
+ "] " + "must be within frame dimensions [" + getNumRows() + "," + getNumColumns() + "].");
}

if((ru - rl + 1) < rhsFrame.getNumRows() || (cu - cl + 1) < rhsFrame.getNumColumns()) {
Expand Down Expand Up @@ -1132,11 +1142,11 @@ public FrameBlock slice(int rl, int ru, int cl, int cu, boolean deep, FrameBlock
}

protected void validateSliceArgument(int rl, int ru, int cl, int cu) {
if(rl < 0 || rl >= getNumRows() || ru < rl || ru >= getNumRows() || cl < 0 || cu >= getNumColumns() || cu < cl ||
cu >= getNumColumns()) {
if(rl < 0 || rl >= getNumRows() || ru < rl || ru >= getNumRows() || cl < 0 || cu >= getNumColumns() ||
cu < cl || cu >= getNumColumns()) {
throw new DMLRuntimeException(
"Invalid values for frame indexing: [" + (rl + 1) + ":" + (ru + 1) + "," + (cl + 1) + ":" + (cu + 1) + "] "
+ "must be within frame dimensions [" + getNumRows() + "," + getNumColumns() + "]");
"Invalid values for frame indexing: [" + (rl + 1) + ":" + (ru + 1) + "," + (cl + 1) + ":" + (cu + 1)
+ "] " + "must be within frame dimensions [" + getNumRows() + "," + getNumColumns() + "]");
}
}

Expand Down Expand Up @@ -1338,6 +1348,14 @@ public final FrameBlock detectSchema(double sampleFraction, int k) {
return FrameLibDetectSchema.detectSchema(this, sampleFraction, k);
}

public final FrameBlock applySchema(FrameBlock schema) {
return FrameLibApplySchema.applySchema(this, schema);
}

public final FrameBlock applySchema(FrameBlock schema, int k) {
return FrameLibApplySchema.applySchema(this, schema, k);
}

/**
* Drop the cell value which does not confirms to the data type of its column
*
Expand All @@ -1347,8 +1365,8 @@ public final FrameBlock detectSchema(double sampleFraction, int k) {
public FrameBlock dropInvalidType(FrameBlock schema) {
// sanity checks
if(this.getNumColumns() != schema.getNumColumns())
throw new DMLException("mismatch in number of columns in frame and its schema " + this.getNumColumns() + " != "
+ schema.getNumColumns());
throw new DMLException("mismatch in number of columns in frame and its schema " + this.getNumColumns()
+ " != " + schema.getNumColumns());

// extract the schema in String array
String[] schemaString = IteratorFactory.getStringRowIterator(schema).next();
Expand All @@ -1375,8 +1393,8 @@ else if(schemaCol.contains("STRING"))

if(!dataType.toString().contains(type) && !(dataType == ValueType.BOOLEAN && type.equals("INT")) &&
!(dataType == ValueType.BOOLEAN && type.equals("FP"))) {
LOG.warn("Datatype detected: " + dataType + " where expected: " + schemaString[i] + " col: " + (i + 1)
+ ", row:" + (j + 1));
LOG.warn("Datatype detected: " + dataType + " where expected: " + schemaString[i] + " col: "
+ (i + 1) + ", row:" + (j + 1));

this.set(j, i, null);
}
Expand Down Expand Up @@ -1554,12 +1572,9 @@ public FrameBlock map(FrameMapFunction lambdaExpr, long margin) {
else if(margin == 2) {
// Execute map function on columns
for(int j = 0; j < getNumColumns(); j++) {
String[] actualColumn = Arrays.copyOfRange((String[]) getColumnData(j), 0, getNumRows()); // since more rows
// can be
// allocated,
// mutable array
// since more rows can be allocated, mutable array
String[] actualColumn = Arrays.copyOfRange((String[]) getColumnData(j), 0, getNumRows());
String[] outColumn = lambdaExpr.apply(actualColumn);

for(int i = 0; i < getNumRows(); i++)
output[i][j] = outColumn[i];
}
Expand Down Expand Up @@ -1615,7 +1630,8 @@ public static FrameMapFunction getCompiledFunction(String lambdaExpr, long margi
sb.append(" return String.valueOf(" + expr + "); }}\n");
}
else if(varname.length == 2) {
sb.append("public String apply(String " + varname[0].trim() + ", String " + varname[1].trim() + ") {\n");
sb.append(
"public String apply(String " + varname[0].trim() + ", String " + varname[1].trim() + ") {\n");
sb.append(" return String.valueOf(" + expr + "); }}\n");
}
}
Expand Down Expand Up @@ -1651,11 +1667,12 @@ public <T> FrameBlock replaceOperations(String pattern, String replacement) {

boolean NaNp = "NaN".equals(pattern);
boolean NaNr = "NaN".equals(replacement);
ValueType patternType = UtilFunctions.isBoolean(pattern) ? ValueType.BOOLEAN : (NumberUtils.isCreatable(pattern) |
NaNp ? (UtilFunctions.isIntegerNumber(pattern) ? ValueType.INT64 : ValueType.FP64) : ValueType.STRING);
ValueType replacementType = UtilFunctions
.isBoolean(replacement) ? ValueType.BOOLEAN : (NumberUtils.isCreatable(replacement) |
NaNr ? (UtilFunctions.isIntegerNumber(replacement) ? ValueType.INT64 : ValueType.FP64) : ValueType.STRING);
ValueType patternType = UtilFunctions
.isBoolean(pattern) ? ValueType.BOOLEAN : (NumberUtils.isCreatable(pattern) |
NaNp ? (UtilFunctions.isIntegerNumber(pattern) ? ValueType.INT64 : ValueType.FP64) : ValueType.STRING);
ValueType replacementType = UtilFunctions.isBoolean(replacement) ? ValueType.BOOLEAN : (NumberUtils
.isCreatable(replacement) |
NaNr ? (UtilFunctions.isIntegerNumber(replacement) ? ValueType.INT64 : ValueType.FP64) : ValueType.STRING);

if(patternType != replacementType || !ValueType.isSameTypeString(patternType, replacementType))
throw new DMLRuntimeException(
Expand Down
Expand Up @@ -38,4 +38,9 @@ public ABooleanArray(int size) {

@Override
public abstract ABooleanArray select(boolean[] select, int nTrue);

@Override
public boolean possiblyContainsNaN(){
return false;
}
}
Expand Up @@ -394,6 +394,8 @@ public boolean containsNull() {
return false;
}

public abstract boolean possiblyContainsNaN();

public Array<?> changeTypeWithNulls(ValueType t) {
final ABooleanArray nulls = getNulls();
if(nulls == null)
Expand Down
Expand Up @@ -335,6 +335,11 @@ public boolean equals(Array<Character> other){
return false;
}

@Override
public boolean possiblyContainsNaN(){
return false;
}

@Override
public String toString() {
StringBuilder sb = new StringBuilder(_data.length * 2 + 15);
Expand Down
Expand Up @@ -315,6 +315,12 @@ public boolean equals(Array<T> other) {
return false;
}

@Override
public boolean possiblyContainsNaN(){
return dict.possiblyContainsNaN();
}


@Override
public String toString() {
StringBuilder sb = new StringBuilder();
Expand Down
Expand Up @@ -387,6 +387,11 @@ public boolean equals(Array<Double> other) {
return false;
}

@Override
public boolean possiblyContainsNaN(){
return true;
}

@Override
public String toString() {
StringBuilder sb = new StringBuilder(_data.length * 5 + 2);
Expand Down
Expand Up @@ -339,6 +339,11 @@ public boolean equals(Array<Float> other) {
return false;
}

@Override
public boolean possiblyContainsNaN(){
return true;
}

@Override
public String toString() {
StringBuilder sb = new StringBuilder(_data.length * 5 + 2);
Expand Down
Expand Up @@ -344,6 +344,11 @@ public boolean equals(Array<Integer> other) {
return false;
}

@Override
public boolean possiblyContainsNaN(){
return false;
}

@Override
public String toString() {
StringBuilder sb = new StringBuilder(_data.length * 5 + 2);
Expand Down
Expand Up @@ -346,6 +346,11 @@ public boolean equals(Array<Long> other) {
return false;
}

@Override
public boolean possiblyContainsNaN(){
return false;
}

@Override
public String toString() {
StringBuilder sb = new StringBuilder(_data.length * 5 + 2);
Expand Down
Expand Up @@ -456,6 +456,11 @@ public boolean equals(Array<T> other) {
return false;
}

@Override
public boolean possiblyContainsNaN(){
return true;
}

@Override
public String toString() {
StringBuilder sb = new StringBuilder(_size + 2);
Expand Down
Expand Up @@ -395,6 +395,11 @@ public boolean containsNull() {
return (_a.size() < super._size) || _a.containsNull();
}

@Override
public boolean possiblyContainsNaN(){
return true;
}

@Override
public String toString() {
StringBuilder sb = new StringBuilder(_size + 2);
Expand Down
Expand Up @@ -37,6 +37,9 @@
import org.apache.sysds.runtime.transform.encode.ColumnEncoderRecode;
import org.apache.sysds.utils.MemoryEstimates;

import ch.randelshofer.fastdoubleparser.JavaDoubleParser;
import ch.randelshofer.fastdoubleparser.JavaFloatParser;

public class StringArray extends Array<String> {
private String[] _data;

Expand Down Expand Up @@ -466,7 +469,7 @@ protected Array<Double> changeTypeDouble() {
for(int i = 0; i < size(); i++) {
final String s = _data[i];
if(s != null)
ret[i] = Double.parseDouble(s);
ret[i] = JavaDoubleParser.parseDouble(s);
}
return new DoubleArray(ret);
}
Expand All @@ -482,7 +485,7 @@ protected Array<Float> changeTypeFloat() {
for(int i = 0; i < size(); i++) {
final String s = _data[i];
if(s != null)
ret[i] = Float.parseFloat(s);
ret[i] = JavaFloatParser.parseFloat(s);
}
return new FloatArray(ret);
}
Expand Down Expand Up @@ -678,6 +681,11 @@ public boolean equals(Array<String> other) {
return false;
}

@Override
public boolean possiblyContainsNaN(){
return true;
}

@Override
public String toString() {
StringBuilder sb = new StringBuilder(_size * 5 + 2);
Expand Down

0 comments on commit 4851f34

Please sign in to comment.