Skip to content

Commit

Permalink
[SYSTEMDS-3709] Additional tests for UDF backwards compatibility
Browse files Browse the repository at this point in the history
This patch adds tests for the old SystemML UDF MultiInputCbind,
ensuring the related DML script is properly compiled to an nary cbind
and if the inputs are vectors and are reshaped to vectors, we also
eliminate the unnecessary reshape.
  • Loading branch information
mboehm7 committed Jun 7, 2024
1 parent da7889e commit 5015f63
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,20 @@

package org.apache.sysds.test.functions.binary.matrix;

import org.junit.Assert;
import org.junit.Test;

import org.apache.sysds.common.Types.ExecMode;
import org.apache.sysds.common.Types.ExecType;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
import org.apache.sysds.utils.Statistics;

public class UDFBackwardsCompatibilityTest extends AutomatedTestBase
{
private final static String TEST_NAME1 = "RowClassMeetTest";
private final static String TEST_NAME2 = "MultiInputCbindTest";
private final static String TEST_DIR = "functions/binary/matrix/";
private final static String TEST_CLASS_DIR = TEST_DIR +
UDFBackwardsCompatibilityTest.class.getSimpleName() + "/";
Expand All @@ -44,29 +48,46 @@ public class UDFBackwardsCompatibilityTest extends AutomatedTestBase
public void setUp() {
addTestConfiguration( TEST_NAME1,
new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "C" }) );
addTestConfiguration( TEST_NAME2,
new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] { "C" }) );
}

@Test
public void testRowClassMeetDenseDense() {
runUDFTest(TEST_NAME1, false, false, ExecType.CP);
runUDFTest(TEST_NAME1, false, false, false, false, ExecType.CP);
}

@Test
public void testRowClassMeetDenseSparse() {
runUDFTest(TEST_NAME1, false, true, ExecType.CP);
runUDFTest(TEST_NAME1, false, true, false, false, ExecType.CP);
}

@Test
public void testRowClassMeetSparseDense() {
runUDFTest(TEST_NAME1, true, false, ExecType.CP);
runUDFTest(TEST_NAME1, true, false, false, false, ExecType.CP);
}

@Test
public void testRowClassMeetSparseSparse() {
runUDFTest(TEST_NAME1, true, true, ExecType.CP);
runUDFTest(TEST_NAME1, true, true, false, false, ExecType.CP);
}

@Test
public void testMultiInputCBindDenseDenseMatrixMatrix() {
runUDFTest(TEST_NAME2, false, false, false, false, ExecType.CP);
}

@Test
public void testMultiInputCBindDenseDenseMatrixVector() {
runUDFTest(TEST_NAME2, false, false, false, true, ExecType.CP);
}

@Test
public void testMultiInputCBindDenseDenseVectorVector() {
runUDFTest(TEST_NAME2, false, false, true, true, ExecType.CP);
}

private void runUDFTest(String testname, boolean sparseM1, boolean sparseM2, ExecType instType)
private void runUDFTest(String testname, boolean sparseM1, boolean sparseM2, boolean vectorData, boolean vectorize, ExecType instType)
{
ExecMode platformOld = setExecMode(instType);
String TEST_NAME = testname;
Expand All @@ -76,18 +97,27 @@ private void runUDFTest(String testname, boolean sparseM1, boolean sparseM2, Exe

String HOME = SCRIPT_DIR + TEST_DIR;
fullDMLScriptName = HOME + TEST_NAME + ".dml";
programArgs = new String[]{"-explain","-args", input("A"), input("B"), output("C")};
programArgs = new String[]{"-stats", "-explain","-args",
input("A"), input("B"), String.valueOf(vectorize).toUpperCase(), output("C")};

//generate actual dataset
int nr = vectorData ? rows*cols : rows;
int nc = vectorData ? 1 : cols;

double[][] A = TestUtils.round(
getRandomMatrix(rows, cols, 0, 10, sparseM1?sparsity2:sparsity1, 7));
getRandomMatrix(nr, nc, 0, 10, sparseM1?sparsity2:sparsity1, 7));
writeInputMatrixWithMTD("A", A, false);
double[][] B = TestUtils.round(
getRandomMatrix(rows, cols, 0, 10, sparseM2?sparsity2:sparsity1, 3));
getRandomMatrix(nr, nc, 0, 10, sparseM2?sparsity2:sparsity1, 3));
writeInputMatrixWithMTD("B", B, false);

//run test case
runTest(true, false, null, -1);
runTest(true, false, null, -1);

if( TEST_NAME.equals(TEST_NAME2) ) //check nary cbind
Assert.assertEquals(1, Statistics.getCPHeavyHitterCount("cbind"));
if( vectorData && vectorize ) //check eliminated reshape
Assert.assertFalse(heavyHittersContainsString("rshape"));
}
finally {
rtplatform = platformOld;
Expand Down
32 changes: 32 additions & 0 deletions src/test/scripts/functions/binary/matrix/MultiInputCbindTest.dml
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------

A = read($1);
B = read($2);

if( as.logical($3) ) {
A = matrix(A, rows=length(A), cols=1)
B = matrix(B, rows=length(B), cols=1)
}

R = cbind(cbind(A, B), A);
write(R, $4);

Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,5 @@
A = read($1);
B = read($2);
[C,N] = rowClassMeet(A, B);
write(C, $3);
write(C, $4);

0 comments on commit 5015f63

Please sign in to comment.