diff --git a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java index 1d30d13fea3..1906ee818e8 100644 --- a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java +++ b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java @@ -49,6 +49,7 @@ public class ParameterizedBuiltinFunctionExpression extends DataIdentifier public static final String TF_FN_PARAM_DATA = "target"; public static final String TF_FN_PARAM_MTD2 = "meta"; public static final String TF_FN_PARAM_SPEC = "spec"; + public static final String TF_FN_PARAM_EMBD = "embedding"; public static final String LINEAGE_TRACE = "lineage"; public static final String TF_FN_PARAM_MTD = "transformPath"; //NOTE MB: for backwards compatibility @@ -617,11 +618,14 @@ private void validateTransformApply(DataIdentifier output, boolean conditional) //validate data / metadata (recode maps) checkDataType(false, "transformapply", TF_FN_PARAM_DATA, DataType.FRAME, conditional); checkDataType(false, "transformapply", TF_FN_PARAM_MTD2, DataType.FRAME, conditional); - + //validate specification checkDataValueType(false, "transformapply", TF_FN_PARAM_SPEC, DataType.SCALAR, ValueType.STRING, conditional); validateTransformSpec(TF_FN_PARAM_SPEC, conditional); - + + //validate additional argument for word_embeddings tranform + checkDataType(true, "transformapply", TF_FN_PARAM_EMBD, DataType.MATRIX, conditional); + //set output dimensions output.setDataType(DataType.MATRIX); output.setValueType(ValueType.FP64); diff --git a/src/main/java/org/apache/sysds/runtime/data/DenseBlockFP64.java b/src/main/java/org/apache/sysds/runtime/data/DenseBlockFP64.java index 44f8846ea97..719ad3a9cd2 100644 --- a/src/main/java/org/apache/sysds/runtime/data/DenseBlockFP64.java +++ b/src/main/java/org/apache/sysds/runtime/data/DenseBlockFP64.java @@ -178,7 +178,7 @@ public DenseBlock set(int r, double[] v) { System.arraycopy(v, 0, _data, pos(r), _odims[0]); return this; } - + @Override public DenseBlock set(int[] ix, double v) { _data[pos(ix)] = v; diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java index 9dfbdbec7f4..18a199e9308 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java @@ -54,7 +54,11 @@ import org.apache.sysds.runtime.util.AutoDiff; import org.apache.sysds.runtime.util.DataConverter; -import java.util.*; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.List; import java.util.stream.Collectors; import java.util.stream.IntStream; @@ -310,11 +314,12 @@ else if(opcode.equalsIgnoreCase("transformapply")) { // acquire locks FrameBlock data = ec.getFrameInput(params.get("target")); FrameBlock meta = ec.getFrameInput(params.get("meta")); + MatrixBlock embeddings = params.get("embedding") != null ? ec.getMatrixInput(params.get("embedding")) : null; String[] colNames = data.getColumnNames(); // compute transformapply MultiColumnEncoder encoder = EncoderFactory - .createEncoder(params.get("spec"), colNames, data.getNumColumns(), meta); + .createEncoder(params.get("spec"), colNames, data.getNumColumns(), meta, embeddings); MatrixBlock mbout = encoder.apply(data, OptimizerUtils.getTransformNumThreads()); // release locks @@ -346,7 +351,7 @@ else if(opcode.equalsIgnoreCase("transformcolmap")) { // compute transformapply MultiColumnEncoder encoder = EncoderFactory - .createEncoder(params.get("spec"), colNames, meta.getNumColumns(), null); + .createEncoder(params.get("spec"), colNames, meta.getNumColumns(), null, null); MatrixBlock mbout = encoder.getColMapping(meta); // release locks @@ -532,6 +537,8 @@ else if(opcode.equalsIgnoreCase("transformdecode") || opcode.equalsIgnoreCase("t CPOperand target = new CPOperand(params.get("target"), ValueType.FP64, DataType.FRAME); CPOperand meta = getLiteral("meta", ValueType.UNKNOWN, DataType.FRAME); CPOperand spec = getStringLiteral("spec"); + //FIXME: Taking only spec file name as a literal leads to wrong reuse + //TODO: Add Embedding to the lineage item return Pair.of(output.getName(), new LineageItem(getOpcode(), LineageItemUtils.getLineage(ec, target, meta, spec))); } diff --git a/src/main/java/org/apache/sysds/runtime/transform/TfUtils.java b/src/main/java/org/apache/sysds/runtime/transform/TfUtils.java index ec4758a819e..b264004b612 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/TfUtils.java +++ b/src/main/java/org/apache/sysds/runtime/transform/TfUtils.java @@ -47,7 +47,7 @@ protected byte toID() { //transform methods public enum TfMethod { - IMPUTE, RECODE, HASH, BIN, DUMMYCODE, UDF, OMIT; + IMPUTE, RECODE, HASH, BIN, DUMMYCODE, UDF, OMIT, WORD_EMBEDDING; @Override public String toString() { return name().toLowerCase(); diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java index 610e0cc4145..3020553e713 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java @@ -65,8 +65,13 @@ public abstract class ColumnEncoder implements Encoder, Comparable in, MatrixBlock out, int outputCol, int r case DUMMYCODE: TransformStatistics.incDummyCodeApplyTime(t); break; + case WORD_EMBEDDING: + TransformStatistics.incWordEmbeddingApplyTime(t); + break; case FEATURE_HASH: TransformStatistics.incFeatureHashingApplyTime(t); break; diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderComposite.java b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderComposite.java index 6f18263a26d..fd69d5bf26d 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderComposite.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderComposite.java @@ -319,6 +319,12 @@ public void initMetaData(FrameBlock out) { columnEncoder.initMetaData(out); } + //pass down init to actual encoders, only ColumnEncoderWordEmbedding has actually implemented the init method + public void initEmbeddings(MatrixBlock embeddings){ + for(ColumnEncoder columnEncoder : _columnEncoders) + columnEncoder.initEmbeddings(embeddings); + } + @Override public String toString() { StringBuilder sb = new StringBuilder(); diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderWordEmbedding.java b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderWordEmbedding.java new file mode 100644 index 00000000000..03584cf5ee8 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderWordEmbedding.java @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.transform.encode; + +import org.apache.commons.lang.NotImplementedException; +import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.controlprogram.caching.CacheBlock; +import org.apache.sysds.runtime.frame.data.FrameBlock; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; + +import static org.apache.sysds.runtime.util.UtilFunctions.getEndIndex; + +public class ColumnEncoderWordEmbedding extends ColumnEncoder { + private MatrixBlock wordEmbeddings; + + //domain size is equal to the number columns of the embedding column (equal to length of an embedding vector) + @Override + public int getDomainSize(){ + return wordEmbeddings.getNumColumns(); + } + protected ColumnEncoderWordEmbedding(int colID) { + super(colID); + } + + @Override + protected double getCode(CacheBlock in, int row) { + throw new NotImplementedException(); + } + + @Override + protected double[] getCodeCol(CacheBlock in, int startInd, int blkSize) { + throw new NotImplementedException(); + } + + //previous recode replaced strings with indices of the corresponding matrix row index + //now, the indices are replaced with actual word embedding vectors + //current limitation: in case the transform is done on multiple cols, the same embedding + //matrix is used for both transform + @Override + public void applyDense(CacheBlock in, MatrixBlock out, int outputCol, int rowStart, int blk){ + if (!(in instanceof MatrixBlock)){ + throw new DMLRuntimeException("ColumnEncoderWordEmbedding called with: " + in.getClass().getSimpleName() + + " and not MatrixBlock"); + } + int rowEnd = getEndIndex(in.getNumRows(), rowStart, blk); + //map each recoded index to the corresponding embedding vector + for(int i=rowStart; i in) { + throw new NotImplementedException(); + } + + @Override + public void allocateMetaData(FrameBlock meta) { + throw new NotImplementedException(); + } + + @Override + public FrameBlock getMetaData(FrameBlock out) { + throw new NotImplementedException(); + } + + @Override + public void initMetaData(FrameBlock meta) { + return; + } + + //save embeddings matrix reference for apply step + @Override + public void initEmbeddings(MatrixBlock embeddings){ + this.wordEmbeddings = embeddings; + } +} diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java index 075b6fbdd40..313258831ae 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java @@ -36,6 +36,7 @@ import org.apache.sysds.common.Types.ValueType; import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.frame.data.FrameBlock; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.transform.TfUtils.TfMethod; import org.apache.sysds.runtime.transform.encode.ColumnEncoder.EncoderType; import org.apache.sysds.runtime.transform.meta.TfMetaUtils; @@ -68,7 +69,21 @@ public static MultiColumnEncoder createEncoder(String spec, String[] colnames, V } public static MultiColumnEncoder createEncoder(String spec, String[] colnames, ValueType[] schema, FrameBlock meta, - int minCol, int maxCol) { + int minCol, int maxCol){ + return createEncoder(spec, colnames, schema, meta, null, minCol, maxCol); + } + + public static MultiColumnEncoder createEncoder(String spec, String[] colnames, int clen, FrameBlock meta, MatrixBlock embeddings) { + return createEncoder(spec, colnames, UtilFunctions.nCopies(clen, ValueType.STRING), meta, embeddings); + } + + public static MultiColumnEncoder createEncoder(String spec, String[] colnames, ValueType[] schema, + FrameBlock meta, MatrixBlock embeddings) { + return createEncoder(spec, colnames, schema, meta, embeddings, -1, -1); + } + + public static MultiColumnEncoder createEncoder(String spec, String[] colnames, ValueType[] schema, FrameBlock meta, + MatrixBlock embeddings, int minCol, int maxCol) { MultiColumnEncoder encoder; int clen = schema.length; @@ -88,9 +103,18 @@ public static MultiColumnEncoder createEncoder(String spec, String[] colnames, V List dcIDs = Arrays.asList(ArrayUtils .toObject(TfMetaUtils.parseJsonIDList(jSpec, colnames, TfMethod.DUMMYCODE.toString(), minCol, maxCol))); List binIDs = TfMetaUtils.parseBinningColIDs(jSpec, colnames, minCol, maxCol); + List weIDs = Arrays.asList(ArrayUtils + .toObject(TfMetaUtils.parseJsonIDList(jSpec, colnames, TfMethod.WORD_EMBEDDING.toString(), minCol, maxCol))); + + //check if user passed an embeddings matrix + if(!weIDs.isEmpty() && embeddings == null) + throw new DMLRuntimeException("Missing argument Embeddings Matrix for transform [" + TfMethod.WORD_EMBEDDING + "]"); + // NOTE: any dummycode column requires recode as preparation, unless the dummycode // column follows binning or feature hashing rcIDs = unionDistinct(rcIDs, except(except(dcIDs, binIDs), haIDs)); + // NOTE: Word Embeddings requires recode as preparation + rcIDs = unionDistinct(rcIDs, weIDs); // Error out if the first level encoders have overlaps if (intersect(rcIDs, binIDs, haIDs)) throw new DMLRuntimeException("More than one encoders (recode, binning, hashing) on one column is not allowed"); @@ -114,7 +138,9 @@ public static MultiColumnEncoder createEncoder(String spec, String[] colnames, V if(!ptIDs.isEmpty()) for(Integer id : ptIDs) addEncoderToMap(new ColumnEncoderPassThrough(id), colEncoders); - + if(!weIDs.isEmpty()) + for(Integer id : weIDs) + addEncoderToMap(new ColumnEncoderWordEmbedding(id), colEncoders); if(!binIDs.isEmpty()) for(Object o : (JSONArray) jSpec.get(TfMethod.BIN.toString())) { JSONObject colspec = (JSONObject) o; @@ -185,6 +211,9 @@ else if ("EQUI-HEIGHT".equals(method)) } encoder.initMetaData(meta); } + //initialize embeddings matrix block in the encoders in case word embedding transform is used + if(!weIDs.isEmpty()) + encoder.initEmbeddings(embeddings); } catch(Exception ex) { throw new DMLRuntimeException(ex); diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java b/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java index 6838cdd1e29..59c22f5640c 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java @@ -314,10 +314,12 @@ public MatrixBlock apply(CacheBlock in) { public MatrixBlock apply(CacheBlock in, int k) { // domain sizes are not updated if called from transformapply boolean hasUDF = _columnEncoders.stream().anyMatch(e -> e.hasEncoder(ColumnEncoderUDF.class)); + boolean hasWE = _columnEncoders.stream().anyMatch(e -> e.hasEncoder(ColumnEncoderWordEmbedding.class)); for(ColumnEncoderComposite columnEncoder : _columnEncoders) columnEncoder.updateAllDCEncoders(); int numCols = getNumOutCols(); - long estNNz = (long) in.getNumRows() * (hasUDF ? numCols : (long) in.getNumColumns()); + long estNNz = (long) in.getNumRows() * (hasUDF ? numCols : hasWE ? getEstNNzRow() : (long) in.getNumColumns()); + // FIXME: estimate nnz for multiple encoders including dummycode and embedding boolean sparse = MatrixBlock.evalSparseFormatInMemory(in.getNumRows(), numCols, estNNz) && !hasUDF; MatrixBlock out = new MatrixBlock(in.getNumRows(), numCols, sparse, estNNz); return apply(in, out, 0, k); @@ -353,8 +355,7 @@ public MatrixBlock apply(CacheBlock in, MatrixBlock out, int outputCol, int k int offset = outputCol; for(ColumnEncoderComposite columnEncoder : _columnEncoders) { columnEncoder.apply(in, out, columnEncoder._colID - 1 + offset); - if (columnEncoder.hasEncoder(ColumnEncoderDummycode.class)) - offset += columnEncoder.getEncoder(ColumnEncoderDummycode.class)._domainSize - 1; + offset = getOffset(offset, columnEncoder); } } // Recomputing NNZ since we access the block directly @@ -373,12 +374,19 @@ private List> getApplyTasks(CacheBlock in, MatrixBlock out, int offset = outputCol; for(ColumnEncoderComposite e : _columnEncoders) { tasks.addAll(e.getApplyTasks(in, out, e._colID - 1 + offset)); - if(e.hasEncoder(ColumnEncoderDummycode.class)) - offset += e.getEncoder(ColumnEncoderDummycode.class)._domainSize - 1; + offset = getOffset(offset, e); } return tasks; } + private int getOffset(int offset, ColumnEncoderComposite e) { + if(e.hasEncoder(ColumnEncoderDummycode.class)) + offset += e.getEncoder(ColumnEncoderDummycode.class)._domainSize - 1; + if(e.hasEncoder(ColumnEncoderWordEmbedding.class)) + offset += e.getEncoder(ColumnEncoderWordEmbedding.class).getDomainSize() - 1; + return offset; + } + private void applyMT(CacheBlock in, MatrixBlock out, int outputCol, int k) { DependencyThreadPool pool = new DependencyThreadPool(k); try { @@ -386,8 +394,7 @@ private void applyMT(CacheBlock in, MatrixBlock out, int outputCol, int k) { int offset = outputCol; for (ColumnEncoderComposite e : _columnEncoders) { pool.submitAllAndWait(e.getApplyTasks(in, out, e._colID - 1 + offset)); - if (e.hasEncoder(ColumnEncoderDummycode.class)) - offset += e.getEncoder(ColumnEncoderDummycode.class)._domainSize - 1; + offset = getOffset(offset, e); } } else pool.submitAllAndWait(getApplyTasks(in, out, outputCol)); @@ -696,6 +703,12 @@ public void initMetaData(FrameBlock meta) { _legacyMVImpute.initMetaData(meta); } + //pass down init to composite encoders + public void initEmbeddings(MatrixBlock embeddings) { + for(ColumnEncoder columnEncoder : _columnEncoders) + columnEncoder.initEmbeddings(embeddings); + } + @Override public void prepareBuildPartial() { for(Encoder encoder : _columnEncoders) @@ -855,6 +868,13 @@ public List> getEncoderTypes() { return getEncoderTypes(-1); } + public int getEstNNzRow(){ + int nnz = 0; + for(int i = 0; i < _columnEncoders.size(); i++) + nnz += _columnEncoders.get(i).getDomainSize(); + return nnz; + } + public int getNumOutCols() { int sum = 0; for(int i = 0; i < _columnEncoders.size(); i++) diff --git a/src/main/java/org/apache/sysds/utils/stats/TransformStatistics.java b/src/main/java/org/apache/sysds/utils/stats/TransformStatistics.java index b7779e4ee19..9ace7294627 100644 --- a/src/main/java/org/apache/sysds/utils/stats/TransformStatistics.java +++ b/src/main/java/org/apache/sysds/utils/stats/TransformStatistics.java @@ -32,6 +32,8 @@ public class TransformStatistics { //private static final LongAdder applyTime = new LongAdder(); private static final LongAdder recodeApplyTime = new LongAdder(); private static final LongAdder dummyCodeApplyTime = new LongAdder(); + + private static final LongAdder wordEmbeddingApplyTime = new LongAdder(); private static final LongAdder passThroughApplyTime = new LongAdder(); private static final LongAdder featureHashingApplyTime = new LongAdder(); private static final LongAdder binningApplyTime = new LongAdder(); @@ -55,6 +57,10 @@ public static void incDummyCodeApplyTime(long t) { dummyCodeApplyTime.add(t); } + public static void incWordEmbeddingApplyTime(long t){ + wordEmbeddingApplyTime.add(t); + } + public static void incBinningApplyTime(long t) { binningApplyTime.add(t); } @@ -112,7 +118,7 @@ public static long getEncodeApplyTime() { return dummyCodeApplyTime.longValue() + binningApplyTime.longValue() + featureHashingApplyTime.longValue() + passThroughApplyTime.longValue() + recodeApplyTime.longValue() + UDFApplyTime.longValue() + - omitApplyTime.longValue() + imputeApplyTime.longValue(); + omitApplyTime.longValue() + imputeApplyTime.longValue() + wordEmbeddingApplyTime.longValue(); } public static void reset() { @@ -163,6 +169,9 @@ public static String displayStatistics() { if(dummyCodeApplyTime.longValue() > 0) sb.append("\tDummyCode apply time:\t").append(String.format("%.3f", dummyCodeApplyTime.longValue()*1e-9)).append(" sec.\n"); + if(wordEmbeddingApplyTime.longValue() > 0) + sb.append("\tWordEmbedding apply time:\t").append(String.format("%.3f", + wordEmbeddingApplyTime.longValue()*1e-9)).append(" sec.\n"); if(featureHashingApplyTime.longValue() > 0) sb.append("\tHashing apply time:\t").append(String.format("%.3f", featureHashingApplyTime.longValue()*1e-9)).append(" sec.\n"); diff --git a/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeWordEmbedding2Test.java b/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeWordEmbedding2Test.java new file mode 100644 index 00000000000..8ab52d9f640 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeWordEmbedding2Test.java @@ -0,0 +1,258 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.transform; + +import org.apache.sysds.common.Types.ExecMode; +import org.apache.sysds.lops.Lop; +import org.apache.sysds.runtime.matrix.data.MatrixValue; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.junit.Ignore; +import org.junit.Test; + +import java.io.BufferedWriter; +import java.io.FileWriter; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Date; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Random; + +public class TransformFrameEncodeWordEmbedding2Test extends AutomatedTestBase +{ + private final static String TEST_NAME1 = "TransformFrameEncodeWordEmbeddings2"; + private final static String TEST_NAME2 = "TransformFrameEncodeWordEmbeddings2MultiCols1"; + private final static String TEST_NAME3 = "TransformFrameEncodeWordEmbeddings2MultiCols2"; + + private final static String TEST_DIR = "functions/transform/"; + private final static String TEST_CLASS_DIR = TEST_DIR + TransformFrameEncodeWordEmbeddingTest.class.getSimpleName() + "/"; + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_DIR, TEST_NAME1)); + addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_DIR, TEST_NAME2)); + addTestConfiguration(TEST_NAME3, new TestConfiguration(TEST_DIR, TEST_NAME3)); + } + + @Test + public void testTransformToWordEmbeddings() { + runTransformTest(TEST_NAME1, ExecMode.SINGLE_NODE); + } + + @Test + @Ignore + public void testNonRandomTransformToWordEmbeddings2Cols() { + runTransformTest(TEST_NAME2, ExecMode.SINGLE_NODE); + } + + @Test + @Ignore + public void testRandomTransformToWordEmbeddings4Cols() { + runTransformTestMultiCols(TEST_NAME3, ExecMode.SINGLE_NODE); + } + + private void runTransformTest(String testname, ExecMode rt) + { + //set runtime platform + ExecMode rtold = setExecMode(rt); + try + { + int rows = 100; + int cols = 100; + getAndLoadTestConfiguration(testname); + fullDMLScriptName = getScript(); + + // Generate random embeddings for the distinct tokens + double[][] a = createRandomMatrix("embeddings", rows, cols, 0, 10, 1, new Date().getTime()); + + // Generate random distinct tokens + List strings = generateRandomStrings(rows, 10); + + // Generate the dictionary by assigning unique ID to each distinct token + Map map = writeDictToCsvFile(strings, baseDirectory + INPUT_DIR + "dict"); + + // Create the dataset by repeating and shuffling the distinct tokens + List stringsColumn = shuffleAndMultiplyStrings(strings, 320); + writeStringsToCsvFile(stringsColumn, baseDirectory + INPUT_DIR + "data"); + + //run script + programArgs = new String[]{"-stats","-args", input("embeddings"), input("data"), input("dict"), output("result")}; + runTest(true, EXCEPTION_NOT_EXPECTED, null, -1); + + // Manually derive the expected result + double[][] res_expected = manuallyDeriveWordEmbeddings(cols, a, map, stringsColumn); + + // Compare results + HashMap res_actual = readDMLMatrixFromOutputDir("result"); + double[][] resultActualDouble = TestUtils.convertHashMapToDoubleArray(res_actual); + //System.out.println("Actual Result [" + resultActualDouble.length + "x" + resultActualDouble[0].length + "]:"); + //print2DimDoubleArray(resultActualDouble); + //System.out.println("\nExpected Result [" + res_expected.length + "x" + res_expected[0].length + "]:"); + //print2DimDoubleArray(res_expected); + TestUtils.compareMatrices(resultActualDouble, res_expected, 1e-6); + } + catch(Exception ex) { + throw new RuntimeException(ex); + + } + finally { + resetExecMode(rtold); + } + } + + private void print2DimDoubleArray(double[][] resultActualDouble) { + Arrays.stream(resultActualDouble).forEach( + e -> System.out.println(Arrays.stream(e).mapToObj(d -> String.format("%06.1f", d)) + .reduce("", (sub, elem) -> sub + " " + elem))); + } + + private void runTransformTestMultiCols(String testname, ExecMode rt) + { + //set runtime platform + ExecMode rtold = setExecMode(rt); + try + { + int rows = 100; + int cols = 100; + getAndLoadTestConfiguration(testname); + fullDMLScriptName = getScript(); + + // Generate random embeddings for the distinct tokens + double[][] a = createRandomMatrix("embeddings", rows, cols, 0, 10, 1, new Date().getTime()); + + // Generate random distinct tokens + List strings = generateRandomStrings(rows, 10); + + // Generate the dictionary by assigning unique ID to each distinct token + Map map = writeDictToCsvFile(strings, baseDirectory + INPUT_DIR + "dict"); + + // Create the dataset by repeating and shuffling the distinct tokens + List stringsColumn = shuffleAndMultiplyStrings(strings, 10); + writeStringsToCsvFile(stringsColumn, baseDirectory + INPUT_DIR + "data"); + + //run script + programArgs = new String[]{"-stats","-args", input("embeddings"), input("data"), input("dict"), output("result"), output("result2")}; + runTest(true, EXCEPTION_NOT_EXPECTED, null, -1); + + // Manually derive the expected result + double[][] res_expected = manuallyDeriveWordEmbeddings(cols, a, map, stringsColumn); + + // Compare results + HashMap res_actual = readDMLMatrixFromOutputDir("result"); + HashMap res_actual2 = readDMLMatrixFromOutputDir("result2"); + double[][] resultActualDouble = TestUtils.convertHashMapToDoubleArray(res_actual); + double[][] resultActualDouble2 = TestUtils.convertHashMapToDoubleArray(res_actual2); + //System.out.println("Actual Result1 [" + resultActualDouble.length + "x" + resultActualDouble[0].length + "]:"); + ///print2DimDoubleArray(resultActualDouble); + //System.out.println("\nActual Result2 [" + resultActualDouble.length + "x" + resultActualDouble[0].length + "]:"); + //print2DimDoubleArray(resultActualDouble2); + //System.out.println("\nExpected Result [" + res_expected.length + "x" + res_expected[0].length + "]:"); + //print2DimDoubleArray(res_expected); + TestUtils.compareMatrices(resultActualDouble, res_expected, 1e-6); + TestUtils.compareMatrices(resultActualDouble, resultActualDouble2, 1e-6); + } + catch(Exception ex) { + throw new RuntimeException(ex); + + } + finally { + resetExecMode(rtold); + } + } + + private double[][] manuallyDeriveWordEmbeddings(int cols, double[][] a, Map map, List stringsColumn) { + // Manually derive the expected result + double[][] res_expected = new double[stringsColumn.size()][cols]; + for (int i = 0; i < stringsColumn.size(); i++) { + int rowMapped = map.get(stringsColumn.get(i)); + System.arraycopy(a[rowMapped], 0, res_expected[i], 0, cols); + } + return res_expected; + } + + private double[][] generateWordEmbeddings(int rows, int cols) { + double[][] a = new double[rows][cols]; + for (int i = 0; i < a.length; i++) { + for (int j = 0; j < a[i].length; j++) { + a[i][j] = cols *i + j; + } + + } + return a; + } + + public static List shuffleAndMultiplyStrings(List strings, int multiply){ + List out = new ArrayList<>(); + Random random = new Random(); + for (int i = 0; i < strings.size()*multiply; i++) { + out.add(strings.get(random.nextInt(strings.size()))); + } + return out; + } + + public static List generateRandomStrings(int numStrings, int stringLength) { + List randomStrings = new ArrayList<>(); + Random random = new Random(); + String characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"; + for (int i = 0; i < numStrings; i++) { + randomStrings.add(generateRandomString(random, stringLength, characters)); + } + return randomStrings; + } + + public static String generateRandomString(Random random, int stringLength, String characters){ + StringBuilder randomString = new StringBuilder(); + for (int j = 0; j < stringLength; j++) { + int randomIndex = random.nextInt(characters.length()); + randomString.append(characters.charAt(randomIndex)); + } + return randomString.toString(); + } + + public static void writeStringsToCsvFile(List strings, String fileName) { + try (BufferedWriter bw = new BufferedWriter(new FileWriter(fileName))) { + for (String line : strings) { + bw.write(line); + bw.newLine(); + } + } catch (IOException e) { + e.printStackTrace(); + } + } + + public static Map writeDictToCsvFile(List strings, String fileName) { + try (BufferedWriter bw = new BufferedWriter(new FileWriter(fileName))) { + Map map = new HashMap<>(); + for (int i = 0; i < strings.size(); i++) { + map.put(strings.get(i), i); + bw.write(strings.get(i) + Lop.DATATYPE_PREFIX + (i+1) + "\n"); + } + return map; + } catch (IOException e) { + e.printStackTrace(); + return null; + } + } +} diff --git a/src/test/scripts/functions/transform/TransformFrameEncodeWordEmbeddings2.dml b/src/test/scripts/functions/transform/TransformFrameEncodeWordEmbeddings2.dml new file mode 100644 index 00000000000..29a4bfab74a --- /dev/null +++ b/src/test/scripts/functions/transform/TransformFrameEncodeWordEmbeddings2.dml @@ -0,0 +1,36 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +# Read the pre-trained word embeddings +E = read($1, rows=100, cols=100, format="text"); +# Read the token sequence (1K) w/ 100 distinct tokens +Data = read($2, data_type="frame", format="csv"); +# Read the recode map for the distinct tokens +Meta = read($3, data_type="frame", format="csv"); + +jspec = "{ids: true, word_embedding: [1]}"; +Data_enc = transformapply(target=Data, spec=jspec, meta=Meta, embedding=E); + +write(Data_enc, $4, format="text"); + + + + diff --git a/src/test/scripts/functions/transform/TransformFrameEncodeWordEmbeddings2MultiCols1.dml b/src/test/scripts/functions/transform/TransformFrameEncodeWordEmbeddings2MultiCols1.dml new file mode 100644 index 00000000000..00484697d6a --- /dev/null +++ b/src/test/scripts/functions/transform/TransformFrameEncodeWordEmbeddings2MultiCols1.dml @@ -0,0 +1,43 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +# Read the pre-trained word embeddings +E = read($1, rows=100, cols=100, format="text"); +# Read the token sequence (1K) w/ 100 distinct tokens +Data = read($2, data_type="frame", format="csv"); +# Read the recode map for the distinct tokens +Meta = read($3, data_type="frame", format="csv"); + +DataExtension = as.frame(matrix(1, rows=length(Data), cols=1)) +Data = cbind(Data, DataExtension) +Data = cbind(DataExtension, Data) +Meta = cbind(Meta, Meta) + +jspec = "{ids: true, word_embedding: [2]}"; +#jspec = "{ids: true, dummycode: [2]}"; +Data_enc = transformapply(target=Data, spec=jspec, meta=Meta, embedding=E); + +Data_enc = Data_enc[,2:101] +write(Data_enc, $4, format="text"); + + + + diff --git a/src/test/scripts/functions/transform/TransformFrameEncodeWordEmbeddings2MultiCols2.dml b/src/test/scripts/functions/transform/TransformFrameEncodeWordEmbeddings2MultiCols2.dml new file mode 100644 index 00000000000..fd742520e7d --- /dev/null +++ b/src/test/scripts/functions/transform/TransformFrameEncodeWordEmbeddings2MultiCols2.dml @@ -0,0 +1,44 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +# Read the pre-trained word embeddings +E = read($1, rows=100, cols=100, format="text"); +# Read the token sequence (1K) w/ 100 distinct tokens +Data = read($2, data_type="frame", format="csv"); +# Read the recode map for the distinct tokens +Meta = read($3, data_type="frame", format="csv"); + +DataExtension = as.frame(matrix(1, rows=length(Data), cols=1)) +Data = cbind(Data, DataExtension) +Data = cbind(Data, Data) +Meta = cbind(Meta, Meta) +Meta = cbind(Meta, Meta) + +jspec = "{ids: true, word_embedding: [1,3]}"; +Data_enc = transformapply(target=Data, spec=jspec, meta=Meta, embedding=E); + +Data_enc1 = Data_enc[,1:100] +Data_enc2 = Data_enc[,102:201] +write(Data_enc1, $4, format="text"); +write(Data_enc2, $5, format="text"); + + +