Skip to content

Commit

Permalink
[SYSTEMDS-3580] Add word embedding encoder
Browse files Browse the repository at this point in the history
This patch extends the transformapply API to accept the pre-trained word
embeddings along with the dictionary as inputs. The new word embedding
column encoder is placed after recode and replaces the recoded indices
with the embedding vectors. This addition removes the requirement of
a matrix multiplication to produce the embedding matrix.
The current implementation is slower than the baseline (w/ MatMult).
The future commits will introduce a new dense block to deduplicate
the large embeddings and multi-threading.

Closes #1839
  • Loading branch information
e-strauss authored and phaniarnab committed Jun 9, 2023
1 parent 1fb8c5e commit 384a707
Show file tree
Hide file tree
Showing 14 changed files with 593 additions and 18 deletions.
Expand Up @@ -49,6 +49,7 @@ public class ParameterizedBuiltinFunctionExpression extends DataIdentifier
public static final String TF_FN_PARAM_DATA = "target";
public static final String TF_FN_PARAM_MTD2 = "meta";
public static final String TF_FN_PARAM_SPEC = "spec";
public static final String TF_FN_PARAM_EMBD = "embedding";
public static final String LINEAGE_TRACE = "lineage";
public static final String TF_FN_PARAM_MTD = "transformPath"; //NOTE MB: for backwards compatibility

Expand Down Expand Up @@ -617,11 +618,14 @@ private void validateTransformApply(DataIdentifier output, boolean conditional)
//validate data / metadata (recode maps)
checkDataType(false, "transformapply", TF_FN_PARAM_DATA, DataType.FRAME, conditional);
checkDataType(false, "transformapply", TF_FN_PARAM_MTD2, DataType.FRAME, conditional);

//validate specification
checkDataValueType(false, "transformapply", TF_FN_PARAM_SPEC, DataType.SCALAR, ValueType.STRING, conditional);
validateTransformSpec(TF_FN_PARAM_SPEC, conditional);


//validate additional argument for word_embeddings transform
checkDataType(true, "transformapply", TF_FN_PARAM_EMBD, DataType.MATRIX, conditional);

//set output dimensions
output.setDataType(DataType.MATRIX);
output.setValueType(ValueType.FP64);
Expand Down
Expand Up @@ -178,7 +178,7 @@ public DenseBlock set(int r, double[] v) {
System.arraycopy(v, 0, _data, pos(r), _odims[0]);
return this;
}

@Override
public DenseBlock set(int[] ix, double v) {
_data[pos(ix)] = v;
Expand Down
Expand Up @@ -54,7 +54,11 @@
import org.apache.sysds.runtime.util.AutoDiff;
import org.apache.sysds.runtime.util.DataConverter;

import java.util.*;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

Expand Down Expand Up @@ -310,11 +314,12 @@ else if(opcode.equalsIgnoreCase("transformapply")) {
// acquire locks
FrameBlock data = ec.getFrameInput(params.get("target"));
FrameBlock meta = ec.getFrameInput(params.get("meta"));
MatrixBlock embeddings = params.get("embedding") != null ? ec.getMatrixInput(params.get("embedding")) : null;
String[] colNames = data.getColumnNames();

// compute transformapply
MultiColumnEncoder encoder = EncoderFactory
.createEncoder(params.get("spec"), colNames, data.getNumColumns(), meta);
.createEncoder(params.get("spec"), colNames, data.getNumColumns(), meta, embeddings);
MatrixBlock mbout = encoder.apply(data, OptimizerUtils.getTransformNumThreads());

// release locks
Expand Down Expand Up @@ -346,7 +351,7 @@ else if(opcode.equalsIgnoreCase("transformcolmap")) {

// compute transformapply
MultiColumnEncoder encoder = EncoderFactory
.createEncoder(params.get("spec"), colNames, meta.getNumColumns(), null);
.createEncoder(params.get("spec"), colNames, meta.getNumColumns(), null, null);
MatrixBlock mbout = encoder.getColMapping(meta);

// release locks
Expand Down Expand Up @@ -532,6 +537,8 @@ else if(opcode.equalsIgnoreCase("transformdecode") || opcode.equalsIgnoreCase("t
CPOperand target = new CPOperand(params.get("target"), ValueType.FP64, DataType.FRAME);
CPOperand meta = getLiteral("meta", ValueType.UNKNOWN, DataType.FRAME);
CPOperand spec = getStringLiteral("spec");
//FIXME: Taking only spec file name as a literal leads to wrong reuse
//TODO: Add Embedding to the lineage item
return Pair.of(output.getName(),
new LineageItem(getOpcode(), LineageItemUtils.getLineage(ec, target, meta, spec)));
}
Expand Down
Expand Up @@ -47,7 +47,7 @@ protected byte toID() {

//transform methods
public enum TfMethod {
IMPUTE, RECODE, HASH, BIN, DUMMYCODE, UDF, OMIT;
IMPUTE, RECODE, HASH, BIN, DUMMYCODE, UDF, OMIT, WORD_EMBEDDING;
@Override
public String toString() {
return name().toLowerCase();
Expand Down
Expand Up @@ -65,8 +65,13 @@ public abstract class ColumnEncoder implements Encoder, Comparable<ColumnEncoder
protected int _nBuildPartitions = 0;
protected int _nApplyPartitions = 0;

/**
 * Hands the pre-trained embeddings matrix to this encoder.
 * Default implementation is a no-op; ColumnEncoderWordEmbedding
 * overrides it to retain a reference for the apply step.
 */
public void initEmbeddings(MatrixBlock embeddings) {
	// intentionally empty: only the word-embedding encoder consumes the matrix
}

protected enum TransformType{
BIN, RECODE, DUMMYCODE, FEATURE_HASH, PASS_THROUGH, UDF, N_A
BIN, RECODE, DUMMYCODE, FEATURE_HASH, PASS_THROUGH, UDF, WORD_EMBEDDING, N_A
}

protected ColumnEncoder(int colID) {
Expand Down Expand Up @@ -106,6 +111,9 @@ public MatrixBlock apply(CacheBlock<?> in, MatrixBlock out, int outputCol, int r
case DUMMYCODE:
TransformStatistics.incDummyCodeApplyTime(t);
break;
case WORD_EMBEDDING:
TransformStatistics.incWordEmbeddingApplyTime(t);
break;
case FEATURE_HASH:
TransformStatistics.incFeatureHashingApplyTime(t);
break;
Expand Down
Expand Up @@ -319,6 +319,12 @@ public void initMetaData(FrameBlock out) {
columnEncoder.initMetaData(out);
}

/**
 * Forwards the embeddings matrix to every contained column encoder.
 * Only ColumnEncoderWordEmbedding actually stores the reference; all
 * other encoder types treat the call as a no-op.
 */
public void initEmbeddings(MatrixBlock embeddings) {
	_columnEncoders.forEach(enc -> enc.initEmbeddings(embeddings));
}

@Override
public String toString() {
StringBuilder sb = new StringBuilder();
Expand Down
@@ -0,0 +1,111 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.sysds.runtime.transform.encode;

import org.apache.commons.lang.NotImplementedException;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
import org.apache.sysds.runtime.frame.data.FrameBlock;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;

import static org.apache.sysds.runtime.util.UtilFunctions.getEndIndex;

public class ColumnEncoderWordEmbedding extends ColumnEncoder {
	//pre-trained embeddings matrix; row (i-1) holds the embedding vector of recoded index i
	//(renamed to _wordEmbeddings for consistency with the file's field naming convention)
	private MatrixBlock _wordEmbeddings;

	//domain size is equal to the number of columns of the embedding matrix
	//(equal to length of a single embedding vector)
	@Override
	public int getDomainSize(){
		return _wordEmbeddings.getNumColumns();
	}

	protected ColumnEncoderWordEmbedding(int colID) {
		super(colID);
	}

	@Override
	protected double getCode(CacheBlock<?> in, int row) {
		throw new NotImplementedException();
	}

	@Override
	protected double[] getCodeCol(CacheBlock<?> in, int startInd, int blkSize) {
		throw new NotImplementedException();
	}

	//previous recode replaced strings with indices of the corresponding matrix row index
	//now, the indices are replaced with actual word embedding vectors
	//current limitation: in case the transform is done on multiple cols, the same embedding
	//matrix is used for all transforms
	@Override
	public void applyDense(CacheBlock<?> in, MatrixBlock out, int outputCol, int rowStart, int blk){
		if (!(in instanceof MatrixBlock)){
			throw new DMLRuntimeException("ColumnEncoderWordEmbedding called with: " + in.getClass().getSimpleName() +
					" and not MatrixBlock");
		}
		int rowEnd = getEndIndex(in.getNumRows(), rowStart, blk);
		//hoist the loop-invariant domain size out of the per-row loops
		//(getDomainSize() is a virtual call, previously evaluated per element)
		final int domainSize = getDomainSize();
		//map each recoded index to the corresponding embedding vector
		for(int i=rowStart; i<rowEnd; i++){
			double embeddingIndex = in.getDouble(i, outputCol);
			//NaN (e.g., token unseen during recode): fill row segment with zeroes
			if(Double.isNaN(embeddingIndex)){
				for (int j = outputCol; j < outputCol + domainSize; j++)
					out.quickSetValue(i, j, 0.0);
			}
			//copy the (1-based) indexed embedding row into the output
			else{
				for (int j = outputCol; j < outputCol + domainSize; j++){
					out.quickSetValue(i, j, _wordEmbeddings.quickGetValue((int) embeddingIndex - 1, j - outputCol));
				}
			}
		}
	}

	@Override
	protected TransformType getTransformType() {
		return TransformType.WORD_EMBEDDING;
	}

	@Override
	public void build(CacheBlock<?> in) {
		throw new NotImplementedException();
	}

	@Override
	public void allocateMetaData(FrameBlock meta) {
		throw new NotImplementedException();
	}

	@Override
	public FrameBlock getMetaData(FrameBlock out) {
		throw new NotImplementedException();
	}

	@Override
	public void initMetaData(FrameBlock meta) {
		//no-op: no frame meta data needed beyond the embeddings matrix itself
	}

	//save embeddings matrix reference for apply step
	@Override
	public void initEmbeddings(MatrixBlock embeddings){
		this._wordEmbeddings = embeddings;
	}
}
Expand Up @@ -36,6 +36,7 @@
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.frame.data.FrameBlock;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.transform.TfUtils.TfMethod;
import org.apache.sysds.runtime.transform.encode.ColumnEncoder.EncoderType;
import org.apache.sysds.runtime.transform.meta.TfMetaUtils;
Expand Down Expand Up @@ -68,7 +69,21 @@ public static MultiColumnEncoder createEncoder(String spec, String[] colnames, V
}

public static MultiColumnEncoder createEncoder(String spec, String[] colnames, ValueType[] schema, FrameBlock meta,
int minCol, int maxCol) {
int minCol, int maxCol){
return createEncoder(spec, colnames, schema, meta, null, minCol, maxCol);
}

/**
 * Creates a multi-column encoder for clen columns, all assumed STRING,
 * optionally passing a pre-trained word-embeddings matrix.
 */
public static MultiColumnEncoder createEncoder(String spec, String[] colnames, int clen, FrameBlock meta, MatrixBlock embeddings) {
	ValueType[] schema = UtilFunctions.nCopies(clen, ValueType.STRING);
	return createEncoder(spec, colnames, schema, meta, embeddings);
}

/**
 * Creates a multi-column encoder over the full column range
 * (no minCol/maxCol restriction, encoded as -1/-1).
 */
public static MultiColumnEncoder createEncoder(String spec, String[] colnames, ValueType[] schema,
		FrameBlock meta, MatrixBlock embeddings)
{
	return createEncoder(spec, colnames, schema, meta, embeddings, -1, -1);
}

public static MultiColumnEncoder createEncoder(String spec, String[] colnames, ValueType[] schema, FrameBlock meta,
MatrixBlock embeddings, int minCol, int maxCol) {
MultiColumnEncoder encoder;
int clen = schema.length;

Expand All @@ -88,9 +103,18 @@ public static MultiColumnEncoder createEncoder(String spec, String[] colnames, V
List<Integer> dcIDs = Arrays.asList(ArrayUtils
.toObject(TfMetaUtils.parseJsonIDList(jSpec, colnames, TfMethod.DUMMYCODE.toString(), minCol, maxCol)));
List<Integer> binIDs = TfMetaUtils.parseBinningColIDs(jSpec, colnames, minCol, maxCol);
List<Integer> weIDs = Arrays.asList(ArrayUtils
.toObject(TfMetaUtils.parseJsonIDList(jSpec, colnames, TfMethod.WORD_EMBEDDING.toString(), minCol, maxCol)));

//check if user passed an embeddings matrix
if(!weIDs.isEmpty() && embeddings == null)
throw new DMLRuntimeException("Missing argument Embeddings Matrix for transform [" + TfMethod.WORD_EMBEDDING + "]");

// NOTE: any dummycode column requires recode as preparation, unless the dummycode
// column follows binning or feature hashing
rcIDs = unionDistinct(rcIDs, except(except(dcIDs, binIDs), haIDs));
// NOTE: Word Embeddings requires recode as preparation
rcIDs = unionDistinct(rcIDs, weIDs);
// Error out if the first level encoders have overlaps
if (intersect(rcIDs, binIDs, haIDs))
throw new DMLRuntimeException("More than one encoders (recode, binning, hashing) on one column is not allowed");
Expand All @@ -114,7 +138,9 @@ public static MultiColumnEncoder createEncoder(String spec, String[] colnames, V
if(!ptIDs.isEmpty())
for(Integer id : ptIDs)
addEncoderToMap(new ColumnEncoderPassThrough(id), colEncoders);

if(!weIDs.isEmpty())
for(Integer id : weIDs)
addEncoderToMap(new ColumnEncoderWordEmbedding(id), colEncoders);
if(!binIDs.isEmpty())
for(Object o : (JSONArray) jSpec.get(TfMethod.BIN.toString())) {
JSONObject colspec = (JSONObject) o;
Expand Down Expand Up @@ -185,6 +211,9 @@ else if ("EQUI-HEIGHT".equals(method))
}
encoder.initMetaData(meta);
}
//initialize embeddings matrix block in the encoders in case word embedding transform is used
if(!weIDs.isEmpty())
encoder.initEmbeddings(embeddings);
}
catch(Exception ex) {
throw new DMLRuntimeException(ex);
Expand Down
Expand Up @@ -314,10 +314,12 @@ public MatrixBlock apply(CacheBlock<?> in) {
public MatrixBlock apply(CacheBlock<?> in, int k) {
// domain sizes are not updated if called from transformapply
boolean hasUDF = _columnEncoders.stream().anyMatch(e -> e.hasEncoder(ColumnEncoderUDF.class));
boolean hasWE = _columnEncoders.stream().anyMatch(e -> e.hasEncoder(ColumnEncoderWordEmbedding.class));
for(ColumnEncoderComposite columnEncoder : _columnEncoders)
columnEncoder.updateAllDCEncoders();
int numCols = getNumOutCols();
long estNNz = (long) in.getNumRows() * (hasUDF ? numCols : (long) in.getNumColumns());
long estNNz = (long) in.getNumRows() * (hasUDF ? numCols : hasWE ? getEstNNzRow() : (long) in.getNumColumns());
// FIXME: estimate nnz for multiple encoders including dummycode and embedding
boolean sparse = MatrixBlock.evalSparseFormatInMemory(in.getNumRows(), numCols, estNNz) && !hasUDF;
MatrixBlock out = new MatrixBlock(in.getNumRows(), numCols, sparse, estNNz);
return apply(in, out, 0, k);
Expand Down Expand Up @@ -353,8 +355,7 @@ public MatrixBlock apply(CacheBlock<?> in, MatrixBlock out, int outputCol, int k
int offset = outputCol;
for(ColumnEncoderComposite columnEncoder : _columnEncoders) {
columnEncoder.apply(in, out, columnEncoder._colID - 1 + offset);
if (columnEncoder.hasEncoder(ColumnEncoderDummycode.class))
offset += columnEncoder.getEncoder(ColumnEncoderDummycode.class)._domainSize - 1;
offset = getOffset(offset, columnEncoder);
}
}
// Recomputing NNZ since we access the block directly
Expand All @@ -373,21 +374,27 @@ private List<DependencyTask<?>> getApplyTasks(CacheBlock<?> in, MatrixBlock out,
int offset = outputCol;
for(ColumnEncoderComposite e : _columnEncoders) {
tasks.addAll(e.getApplyTasks(in, out, e._colID - 1 + offset));
if(e.hasEncoder(ColumnEncoderDummycode.class))
offset += e.getEncoder(ColumnEncoderDummycode.class)._domainSize - 1;
offset = getOffset(offset, e);
}
return tasks;
}

/**
 * Advances the running output-column offset past encoders that expand one
 * input column into several output columns: dummycode (domain size) and
 * word embedding (embedding vector length).
 */
private int getOffset(int offset, ColumnEncoderComposite e) {
	int advanced = offset;
	if(e.hasEncoder(ColumnEncoderDummycode.class))
		advanced += e.getEncoder(ColumnEncoderDummycode.class)._domainSize - 1;
	if(e.hasEncoder(ColumnEncoderWordEmbedding.class))
		advanced += e.getEncoder(ColumnEncoderWordEmbedding.class).getDomainSize() - 1;
	return advanced;
}

private void applyMT(CacheBlock<?> in, MatrixBlock out, int outputCol, int k) {
DependencyThreadPool pool = new DependencyThreadPool(k);
try {
if(APPLY_ENCODER_SEPARATE_STAGES) {
int offset = outputCol;
for (ColumnEncoderComposite e : _columnEncoders) {
pool.submitAllAndWait(e.getApplyTasks(in, out, e._colID - 1 + offset));
if (e.hasEncoder(ColumnEncoderDummycode.class))
offset += e.getEncoder(ColumnEncoderDummycode.class)._domainSize - 1;
offset = getOffset(offset, e);
}
} else
pool.submitAllAndWait(getApplyTasks(in, out, outputCol));
Expand Down Expand Up @@ -696,6 +703,12 @@ public void initMetaData(FrameBlock meta) {
_legacyMVImpute.initMetaData(meta);
}

/**
 * Propagates the embeddings matrix to all composite column encoders;
 * only word-embedding encoders actually keep a reference to it.
 */
public void initEmbeddings(MatrixBlock embeddings) {
	_columnEncoders.forEach(enc -> enc.initEmbeddings(embeddings));
}

@Override
public void prepareBuildPartial() {
for(Encoder encoder : _columnEncoders)
Expand Down Expand Up @@ -855,6 +868,13 @@ public List<Class<? extends ColumnEncoder>> getEncoderTypes() {
return getEncoderTypes(-1);
}

/**
 * Estimates the number of non-zeros per output row as the sum of the
 * contained encoders' domain sizes (e.g., embedding vector lengths).
 */
public int getEstNNzRow() {
	int nnz = 0;
	for(ColumnEncoderComposite enc : _columnEncoders)
		nnz += enc.getDomainSize();
	return nnz;
}

public int getNumOutCols() {
int sum = 0;
for(int i = 0; i < _columnEncoders.size(); i++)
Expand Down

0 comments on commit 384a707

Please sign in to comment.