From 35b5ed0a4f795108df4aabb39b7c4989351bc180 Mon Sep 17 00:00:00 2001
From: baunsgaard
Date: Thu, 3 Nov 2022 15:06:42 +0100
Subject: [PATCH] [MINOR] Move map function test
This commit moves the map function tests for frames from
functions/binary/frame to functions/frame, to colocate the frame tests.
---
.../sysds/runtime/frame/data/FrameBlock.java | 87 ++++++++----------
.../runtime/frame/data/columns/Array.java | 3 +-
.../frame/data/columns/ArrayFactory.java | 1 +
.../frame/data/columns/BooleanArray.java | 15 +++-
.../frame/data/columns/ColumnMetadata.java | 46 +++++++---
.../frame/data/columns/DoubleArray.java | 11 +++
.../frame/data/columns/FloatArray.java | 11 +++
.../frame/data/columns/IntegerArray.java | 11 +++
.../runtime/frame/data/columns/LongArray.java | 10 +++
.../frame/data/columns/StringArray.java | 11 +++
...turnParameterizedBuiltinSPInstruction.java | 14 +--
.../sysds/runtime/io/FrameWriterFactory.java | 9 +-
.../transform/decode/DecoderDummycode.java | 7 +-
.../transform/decode/DecoderPassThrough.java | 4 +-
.../transform/encode/ColumnEncoderBin.java | 4 +-
.../encode/ColumnEncoderDummycode.java | 1 -
.../sysds/runtime/util/UtilFunctions.java | 1 +
.../python/systemds/operator/nodes/frame.py | 6 ++
.../tests/frame/test_transform_apply.py | 3 -
.../java/org/apache/sysds/test/TestUtils.java | 40 ++++++---
.../frame/transform/transformCustomTest.java | 90 +++++++++++++++++++
.../primitives/FederatedRightIndexTest.java | 54 ++++++-----
.../frame/FrameDropInvalidLengthTest.java | 7 +-
.../frame/FrameDropInvalidTypeTest.java | 42 ++++-----
.../{binary => }/frame/FrameEqualTest.java | 6 +-
.../frame/FrameIndexingDistTest.java | 1 -
.../frame/FrameMapMarginTest.java | 6 +-
.../{binary => }/frame/FrameMapTest.java | 6 +-
28 files changed, 358 insertions(+), 149 deletions(-)
create mode 100644 src/test/java/org/apache/sysds/test/component/frame/transform/transformCustomTest.java
rename src/test/java/org/apache/sysds/test/functions/{binary => }/frame/FrameEqualTest.java (98%)
rename src/test/java/org/apache/sysds/test/functions/{binary => }/frame/FrameMapMarginTest.java (97%)
rename src/test/java/org/apache/sysds/test/functions/{binary => }/frame/FrameMapTest.java (98%)
diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java b/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java
index 6078ad54207..f774d166ba2 100644
--- a/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java
@@ -128,37 +128,30 @@ public FrameBlock(FrameBlock that) {
}
public FrameBlock(int ncols, ValueType vt) {
- this();
- _schema = UtilFunctions.nCopies(ncols, vt);
- _colnames = null; //default not materialized
- _colmeta = new ColumnMetadata[ncols];
- for( int j=0; j call(Tuple2> arg0) thr
// compute global mode of categorical feature, i.e., value with highest frequency
if(_encoder.getMethod(colix) == MVMethod.GLOBAL_MODE) {
HashMap hist = new HashMap<>();
- while(iter.hasNext()) {
+ while(iter.hasNext() ) {
ColumnMetadata cmeta = iter.next();
- Long tmp = hist.get(cmeta.getMvValue());
- hist.put(cmeta.getMvValue(), cmeta.getNumDistinct() + ((tmp != null) ? tmp : 0));
+ if(!cmeta.isDefault()){
+ Long tmp = hist.get(cmeta.getMvValue());
+ hist.put(cmeta.getMvValue(), cmeta.getNumDistinct() + ((tmp != null) ? tmp : 0));
+ }
}
long max = Long.MIN_VALUE;
String mode = null;
@@ -442,8 +444,10 @@ else if(_encoder.getMethod(colix) == MVMethod.GLOBAL_MEAN) {
int count = 0;
while(iter.hasNext()) {
ColumnMetadata cmeta = iter.next();
- kplus.execute2(kbuff, Double.parseDouble(cmeta.getMvValue()));
- count += cmeta.getNumDistinct();
+ if(!cmeta.isDefault()){
+ kplus.execute2(kbuff, Double.parseDouble(cmeta.getMvValue()));
+ count += cmeta.getNumDistinct();
+ }
}
if(count > 0)
ret.add("-2 " + colix + " " + kbuff._sum / count);
diff --git a/src/main/java/org/apache/sysds/runtime/io/FrameWriterFactory.java b/src/main/java/org/apache/sysds/runtime/io/FrameWriterFactory.java
index 3df8191ba1e..d573c049192 100644
--- a/src/main/java/org/apache/sysds/runtime/io/FrameWriterFactory.java
+++ b/src/main/java/org/apache/sysds/runtime/io/FrameWriterFactory.java
@@ -19,13 +19,16 @@
package org.apache.sysds.runtime.io;
-import org.apache.sysds.conf.ConfigurationManager;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
import org.apache.sysds.common.Types.FileFormat;
import org.apache.sysds.conf.CompilerConfig.ConfigType;
+import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.runtime.DMLRuntimeException;
-public class FrameWriterFactory
-{
+public class FrameWriterFactory {
+ protected static final Log LOG = LogFactory.getLog(FrameWriterFactory.class.getName());
+
public static FrameWriter createFrameWriter(FileFormat fmt) {
return createFrameWriter(fmt, null);
}
diff --git a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderDummycode.java b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderDummycode.java
index 6d9480b9f7d..dec1486bebc 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderDummycode.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderDummycode.java
@@ -25,8 +25,10 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
+
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.runtime.frame.data.FrameBlock;
+import org.apache.sysds.runtime.frame.data.columns.ColumnMetadata;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.util.UtilFunctions;
@@ -117,8 +119,9 @@ public void initMetaData(FrameBlock meta) {
_cuPos = new int[_colList.length]; //col upper pos
for( int j=0, off=0; j<_colList.length; j++ ) {
int colID = _colList[j];
- int ndist = (int)meta.getColumnMetadata()[colID-1]
- .getNumDistinct();
+ ColumnMetadata d = meta.getColumnMetadata()[colID-1];
+ int ndist = d.isDefault() ? 0 : (int)d.getNumDistinct();
+ ndist = ndist < -1 ? 0: ndist;
_clPos[j] = off + colID;
_cuPos[j] = _clPos[j] + ndist;
off += ndist - 1;
diff --git a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderPassThrough.java b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderPassThrough.java
index e4b4c3771a5..2a90696747e 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderPassThrough.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderPassThrough.java
@@ -28,6 +28,7 @@
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.runtime.frame.data.FrameBlock;
+import org.apache.sysds.runtime.frame.data.columns.ColumnMetadata;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.util.UtilFunctions;
@@ -107,7 +108,8 @@ public void initMetaData(FrameBlock meta) {
ix1 ++;
}
else { //_colList[ix1] > _dcCols[ix2]
- off += (int)meta.getColumnMetadata()[_dcCols[ix2]-1].getNumDistinct() - 1;
+ ColumnMetadata d =meta.getColumnMetadata()[_dcCols[ix2]-1];
+ off += d.isDefault() ? -1 : d.getNumDistinct() - 1;
ix2 ++;
}
}
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderBin.java b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderBin.java
index fdb23588258..2b71abcc351 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderBin.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderBin.java
@@ -165,7 +165,7 @@ protected double[] getCodeCol(CacheBlock in, int startInd, int blkSize) {
int endInd = getEndIndex(in.getNumRows(), startInd, blkSize);
double[] codes = new double[endInd-startInd];
for (int i=startInd; i 'Frame':
"""
return Frame(self.sds_context, "replace", named_input_nodes={"target": self, "pattern": f"'{pattern}'", "replacement": f"'{replacement}'"})
+ def to_string(self, **kwargs: Dict[str, VALID_INPUT_TYPES]) -> 'Scalar':
+ """ Converts the input to a string representation.
+ :return: `Scalar` containing the string.
+ """
+ return Scalar(self.sds_context, 'toString', [self], kwargs, output_type=OutputType.STRING)
+
def __str__(self):
return "FrameNode"
diff --git a/src/main/python/tests/frame/test_transform_apply.py b/src/main/python/tests/frame/test_transform_apply.py
index 79259943e4d..9cbd22292e7 100644
--- a/src/main/python/tests/frame/test_transform_apply.py
+++ b/src/main/python/tests/frame/test_transform_apply.py
@@ -20,9 +20,6 @@
# -------------------------------------------------------------
import json
-import os
-import shutil
-import sys
import unittest
import numpy as np
diff --git a/src/test/java/org/apache/sysds/test/TestUtils.java b/src/test/java/org/apache/sysds/test/TestUtils.java
index d61a5f5f3c0..6cbbca96160 100644
--- a/src/test/java/org/apache/sysds/test/TestUtils.java
+++ b/src/test/java/org/apache/sysds/test/TestUtils.java
@@ -303,11 +303,17 @@ public static void readValuesFromFileStreamAndPut(BufferedReader inReader, HashM
{
String line = null;
while ((line = inReader.readLine()) != null) {
- StringTokenizer st = new StringTokenizer(line, " ");
- int i = Integer.parseInt(st.nextToken());
- int j = Integer.parseInt(st.nextToken());
- double v = Double.parseDouble(st.nextToken());
- values.put(new CellIndex(i, j), v);
+ try{
+
+ StringTokenizer st = new StringTokenizer(line, " ");
+ int i = Integer.parseInt(st.nextToken());
+ int j = Integer.parseInt(st.nextToken());
+ double v = Double.parseDouble(st.nextToken());
+ values.put(new CellIndex(i, j), v);
+ }
+ catch(Exception e){
+ throw new IOException("failed parsing line:" + line,e);
+ }
}
}
@@ -477,11 +483,10 @@ public static HashMap readDMLMatrixFromHDFS(String filePath)
{
HashMap expectedValues = new HashMap<>();
+ Path outDirectory = new Path(filePath);
try
{
- Path outDirectory = new Path(filePath);
FileSystem fs = IOUtilFunctions.getFileSystem(outDirectory, conf);
-
FileStatus[] outFiles = fs.listStatus(outDirectory);
for (FileStatus file : outFiles) {
FSDataInputStream outIn = fs.open(file.getPath());
@@ -489,7 +494,20 @@ public static HashMap readDMLMatrixFromHDFS(String filePath)
}
}
catch (IOException e) {
- assertTrue("could not read from file " + filePath+": "+e.getMessage(), false);
+ e.printStackTrace();
+ try{
+ FileSystem fs = IOUtilFunctions.getFileSystem(outDirectory, conf);
+ FileStatus[] outFiles = fs.listStatus(outDirectory);
+ String fileContent = "";
+ for (FileStatus file : outFiles) {
+ FSDataInputStream outIn = fs.open(file.getPath());
+ fileContent += new String(outIn.readAllBytes());
+ }
+ fail("could not read from file " + filePath+": "+e.getMessage() + "\ncontent:\n" + fileContent);
+
+ }catch (IOException e2){
+ fail("could not read from file " + filePath+": "+e.getMessage());
+ }
}
return expectedValues;
@@ -2140,10 +2158,7 @@ public static void generateTestMatrixToFile(String file, int rows, int cols, dou
*
*/
public static FrameBlock generateRandomFrameBlock(int rows, int cols, ValueType[] schema, Random random){
- String[] names = new String[cols];
- for(int i = 0; i < cols; i++)
- names[i] = schema[i].toString();
- FrameBlock frameBlock = new FrameBlock(schema, names);
+ FrameBlock frameBlock = new FrameBlock(schema);
frameBlock.ensureAllocatedColumns(rows);
for(int row = 0; row < rows; row++)
for(int col = 0; col < cols; col++)
@@ -2299,6 +2314,7 @@ public static String generateRandomJSONPath(int len, long seed){
*/
public static Object generateRandomValueFromValueType(ValueType valueType, Random random){
switch (valueType){
+ case UINT8: return random.nextInt(256);
case FP32: return random.nextFloat();
case FP64: return random.nextDouble();
case INT32: return random.nextInt();
diff --git a/src/test/java/org/apache/sysds/test/component/frame/transform/transformCustomTest.java b/src/test/java/org/apache/sysds/test/component/frame/transform/transformCustomTest.java
new file mode 100644
index 00000000000..be6d0fe4974
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/component/frame/transform/transformCustomTest.java
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.component.frame.transform;
+
+import static org.junit.Assert.fail;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.common.Types.ValueType;
+import org.apache.sysds.runtime.frame.data.FrameBlock;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.transform.encode.EncoderFactory;
+import org.apache.sysds.runtime.transform.encode.MultiColumnEncoder;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+
+public class transformCustomTest {
+ protected static final Log LOG = LogFactory.getLog(transformCustomTest.class.getName());
+
+ final FrameBlock data;
+
+ public transformCustomTest() {
+ data = TestUtils.generateRandomFrameBlock(100, 1, new ValueType[] {ValueType.UINT8}, 231);
+ data.setSchema(new ValueType[] {ValueType.INT32});
+ }
+
+ @Test
+ public void testRecode() {
+ test("{recode:[C1]}");
+ }
+
+ @Test
+ public void testBin() {
+ test("{ids:true, bin:[{id:1, method:equi-width, numbins:4}]}");
+ }
+
+ @Test
+ public void testBin2() {
+ test("{ids:true, bin:[{id:1, method:equi-width, numbins:100}]}");
+ }
+
+ @Test
+ public void testBin3() {
+ test("{ids:true, bin:[{id:1, method:equi-width, numbins:2}]}");
+ }
+
+ @Test
+ public void testBin4() {
+ test("{ids:true, bin:[{id:1, method:equi-height, numbins:2}]}");
+ }
+
+ @Test
+ public void testBin5() {
+ test("{ids:true, bin:[{id:1, method:equi-height, numbins:10}]}");
+ }
+
+ public void test(String spec) {
+ try {
+
+ FrameBlock meta = null;
+ MultiColumnEncoder encoder = EncoderFactory.createEncoder(spec, data.getColumnNames(), data.getNumColumns(),
+ meta);
+ MatrixBlock out = encoder.encode(data);
+ MatrixBlock out2 = encoder.apply(data);
+
+ TestUtils.compareMatrices(out, out2, 0, "Not Equal after apply");
+ }
+ catch(Exception e) {
+ e.printStackTrace();
+ fail(e.getMessage());
+ }
+ }
+}
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRightIndexTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRightIndexTest.java
index 0139137cd61..7b8e73b4554 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRightIndexTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRightIndexTest.java
@@ -19,6 +19,8 @@
package org.apache.sysds.test.functions.federated.primitives;
+import static org.junit.Assert.fail;
+
import java.util.Arrays;
import java.util.Collection;
@@ -67,8 +69,10 @@ public class FederatedRightIndexTest extends AutomatedTestBase {
@Parameterized.Parameters
public static Collection