[SYSTEMDS-3579] Word embedding transformapply in Spark
This patch adds support for word embedding transformapply in Spark.

Closes #1882 #1918
e-strauss authored and phaniarnab committed Sep 22, 2023
1 parent 0d78859 commit 6fe47a4
Showing 17 changed files with 531 additions and 206 deletions.
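At its core, the patch threads an optional embedding matrix through the existing transform-apply path. A minimal local sketch of the encoder API this commit wires up (the helper name and inputs are assumed for illustration, using the SystemDS classes that appear in the diff below; the actual Spark plumbing follows in the instruction changes):

// Hypothetical helper: build an embedding-aware encoder and apply it to one frame block.
public static MatrixBlock applyWithEmbeddings(String spec, String[] colnames,
		FrameBlock data, FrameBlock meta, MatrixBlock embeddings) {
	MultiColumnEncoder encoder = EncoderFactory.createEncoder(
		spec, colnames, data.getSchema(), data.getNumColumns(), meta, embeddings);
	encoder.updateAllDCEncoders(); // refresh derived output-column counts before reading them
	return encoder.apply(data);    // encoded rows, with embedding columns expanded
}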
@@ -49,19 +49,21 @@ protected void allocateBlock(int bix, int length) {

@Override
public void reset(int rlen, int[] odims, double v) {
-	if(rlen > capacity() / _odims[0])
+	if(rlen > _rlen)
		_data = new double[rlen][];
-	else {
-		if(v == 0.0) {
-			for(int i = 0; i < rlen; i++)
-				_data[i] = null;
+	else{
+		if(_data == null)
+			_data = new double[rlen][];
+		if(v == 0.0){
+			for(int i = 0; i < rlen; i++)
+				_data[i] = null;
+		}
		else {
-			for(int i = 0; i < rlen; i++) {
-				if(odims[0] > _odims[0] ||_data[i] == null )
-					allocateBlock(i, odims[0]);
-				Arrays.fill(_data[i], 0, odims[0], v);
-			}
+			for(int i = 0; i < rlen; i++) {
+				if(odims[0] > _odims[0] ||_data[i] == null )
+					allocateBlock(i, odims[0]);
+				Arrays.fill(_data[i], 0, odims[0], v);
+			}
		}
	}
	_rlen = rlen;
@@ -178,6 +180,12 @@ public int pos(int[] ix){
public int blockSize(int bix) {
return 1;
}

@Override
public boolean isContiguous() {
return false;
}
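// note: rows live in separate (possibly shared) arrays, so the block as a
// whole is never contiguous; only a single-row range is reported as such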
@Override
public boolean isContiguous(int rl, int ru){
return rl == ru;
}
@@ -252,6 +260,25 @@ public DenseBlock set(DenseBlock db) {
throw new NotImplementedException();
}

@Override
public DenseBlock set(int rl, int ru, int ol, int ou, DenseBlock db) {
if( !(db instanceof DenseBlockFP64DEDUP))
throw new NotImplementedException();
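// the cache maps source rows by array identity, so rows that share one
// physical array are copied once and remain shared in the target block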
HashMap<double[], double[]> cache = new HashMap<>();
int len = ou - ol;
for(int i=rl, ix1 = 0; i<ru; i++, ix1++){
double[] row = db.values(ix1);
double[] newRow = cache.get(row);
if (newRow == null) {
newRow = new double[len];
System.arraycopy(row, 0, newRow, 0, len);
cache.put(row, newRow);
}
set(i, newRow);
}
return this;
}

@Override
public DenseBlock set(int[] ix, double v) {
return set(ix[0], pos(ix), v);
@@ -52,8 +52,13 @@ public static DenseBlock createDenseBlock(ValueType vt, int[] dims) {
}

public static DenseBlock createDenseBlock(ValueType vt, int[] dims, boolean dedup) {
-	DenseBlock.Type type = (UtilFunctions.prod(dims) < Integer.MAX_VALUE) ?
-		DenseBlock.Type.DRB : DenseBlock.Type.LDRB;
+	DenseBlock.Type type;
+	if(dedup)
+		type = (dims[0] < Integer.MAX_VALUE) ?
+			DenseBlock.Type.DRB : DenseBlock.Type.LDRB;
+	else
+		type = (UtilFunctions.prod(dims) < Integer.MAX_VALUE) ?
+			DenseBlock.Type.DRB : DenseBlock.Type.LDRB;
return createDenseBlock(vt, type, dims, dedup);
}
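For intuition, a small sketch (dimensions assumed, imports as in the factory above) of why only the row count matters for dedup blocks: each row is its own array, so the total cell count never has to fit into one int-indexed array.

// hypothetical dimensions: ~3e9 cells exceed Integer.MAX_VALUE, but a dedup
// block only indexes its 10M rows, so a single-block DRB layout still works
DenseBlock db = DenseBlockFactory.createDenseBlock(
	ValueType.FP64, new int[]{10_000_000, 300}, true);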

@@ -60,8 +60,10 @@
import org.apache.sysds.runtime.instructions.spark.functions.PerformGroupByAggInCombiner;
import org.apache.sysds.runtime.instructions.spark.functions.PerformGroupByAggInReducer;
import org.apache.sysds.runtime.instructions.spark.functions.ReplicateVectorFunction;
import org.apache.sysds.runtime.instructions.spark.utils.FrameRDDAggregateUtils;
import org.apache.sysds.runtime.instructions.spark.utils.FrameRDDConverterUtils;
import org.apache.sysds.runtime.instructions.spark.utils.RDDAggregateUtils;
import org.apache.sysds.runtime.instructions.spark.utils.RDDConverterUtils;
import org.apache.sysds.runtime.instructions.spark.utils.SparkUtils;
import org.apache.sysds.runtime.lineage.LineageItem;
import org.apache.sysds.runtime.lineage.LineageItemUtils;
@@ -504,6 +506,8 @@ else if(opcode.equalsIgnoreCase("transformapply")) {
JavaPairRDD<Long, FrameBlock> in = (JavaPairRDD<Long, FrameBlock>) sec.getRDDHandleForFrameObject(fo,
FileFormat.BINARY);
FrameBlock meta = sec.getFrameInput(params.get("meta"));
MatrixBlock embeddings = params.get("embedding") != null ? ec.getMatrixInput(params.get("embedding")) : null;

DataCharacteristics mcIn = sec.getDataCharacteristics(params.get("target"));
DataCharacteristics mcOut = sec.getDataCharacteristics(output.getName());
String[] colnames = !TfMetaUtils.isIDSpec(params.get("spec")) ? in.lookup(1L).get(0)
@@ -518,20 +522,41 @@ else if(opcode.equalsIgnoreCase("transformapply")) {

// create encoder broadcast (avoiding replication per task)
MultiColumnEncoder encoder = EncoderFactory
.createEncoder(params.get("spec"), colnames, fo.getSchema(), (int) fo.getNumColumns(), meta);
mcOut.setDimension(mcIn.getRows() - ((omap != null) ? omap.getNumRmRows() : 0), encoder.getNumOutCols());
.createEncoder(params.get("spec"), colnames, fo.getSchema(), (int) fo.getNumColumns(), meta, embeddings);
encoder.updateAllDCEncoders();
mcOut.setDimension(mcIn.getRows() - ((omap != null) ? omap.getNumRmRows() : 0),
(int)encoder.getNumOutCols());
Broadcast<MultiColumnEncoder> bmeta = sec.getSparkContext().broadcast(encoder);
Broadcast<TfOffsetMap> bomap = (omap != null) ? sec.getSparkContext().broadcast(omap) : null;

// execute transform apply
-JavaPairRDD<Long, FrameBlock> tmp = in.mapToPair(new RDDTransformApplyFunction(bmeta, bomap));
-JavaPairRDD<MatrixIndexes, MatrixBlock> out = FrameRDDConverterUtils
-	.binaryBlockToMatrixBlock(tmp, mcOut, mcOut);
+JavaPairRDD<MatrixIndexes, MatrixBlock> out;
+Tuple2<Boolean, Integer> aligned = FrameRDDAggregateUtils.checkRowAlignment(in, -1);
+// NOTE: currently disabled for legacy encoders: OMIT likely yields misaligned blocks,
+// and for IMPUTE the "testHomesImputeColnamesSparkCSV" test showed an inaccuracy
+// (expected: 8.150349617004395 vs. actual: 8.15035 at 0 8; the expected value comes
+// from transform encode, which currently always takes the else branch), so the
+// inaccuracy must stem either from MatrixBlock serialization or from the
+// binaryBlockToBinaryBlock reblock.
+if(aligned._1 && mcOut.getCols() <= aligned._2 && !encoder.hasLegacyEncoder() /*&& containsWE*/) {
+	//blocks are aligned and #cols fits into a single column block (necessary for matrix-matrix reblock)
+	JavaPairRDD<Long, MatrixBlock> tmp = in.mapToPair(new RDDTransformApplyFunction2(bmeta, bomap));
+	mcIn.setBlocksize(aligned._2);
+	mcIn.setDimension(mcIn.getRows(), mcOut.getCols());
+	JavaPairRDD<MatrixIndexes, MatrixBlock> tmp2 = tmp.mapToPair(
+		(PairFunction<Tuple2<Long, MatrixBlock>, MatrixIndexes, MatrixBlock>) in12 ->
+			new Tuple2<>(new MatrixIndexes(UtilFunctions.computeBlockIndex(in12._1, aligned._2), 1), in12._2));
+	out = RDDConverterUtils.binaryBlockToBinaryBlock(tmp2, mcIn, mcOut);
+	//out = RDDConverterUtils.matrixBlockToAlignedMatrixBlock(tmp, mcOut, mcOut);
+} else {
+	JavaPairRDD<Long, FrameBlock> tmp = in.mapToPair(new RDDTransformApplyFunction(bmeta, bomap));
+	out = FrameRDDConverterUtils.binaryBlockToMatrixBlock(tmp, mcOut, mcOut);
+}

// set output and maintain lineage/output characteristics
sec.setRDDHandleForVariable(output.getName(), out);
sec.addLineageRDD(output.getName(), params.get("target"));
ec.releaseFrameInput(params.get("meta"));
if(params.get("embedding") != null)
ec.releaseMatrixInput(params.get("embedding"));
}
else if(opcode.equalsIgnoreCase("transformdecode")) {
// get input RDD and meta data
@@ -979,7 +1004,6 @@ public Tuple2<Long, FrameBlock> call(Tuple2<Long, FrameBlock> in) throws Exception
// execute block transform apply
MultiColumnEncoder encoder = _bencoder.getValue();
MatrixBlock tmp = encoder.apply(blk);

// remap keys
if(_omap != null) {
key = _omap.getValue().getOffset(key);
Expand All @@ -990,6 +1014,8 @@ public Tuple2<Long, FrameBlock> call(Tuple2<Long, FrameBlock> in) throws Excepti
}
}



public static class RDDTransformApplyOffsetFunction implements PairFunction<Tuple2<Long, FrameBlock>, Long, Long> {
private static final long serialVersionUID = 3450977356721057440L;

Expand Down Expand Up @@ -1026,6 +1052,35 @@ public Tuple2<Long, Long> call(Tuple2<Long, FrameBlock> in) throws Exception {
}
}

public static class RDDTransformApplyFunction2 implements PairFunction<Tuple2<Long, FrameBlock>, Long, MatrixBlock> {
private static final long serialVersionUID = 5759813006068230916L;

private Broadcast<MultiColumnEncoder> _bencoder = null;
private Broadcast<TfOffsetMap> _omap = null;

public RDDTransformApplyFunction2(Broadcast<MultiColumnEncoder> bencoder, Broadcast<TfOffsetMap> omap) {
_bencoder = bencoder;
_omap = omap;
}

@Override
public Tuple2<Long, MatrixBlock> call(Tuple2<Long, FrameBlock> in) throws Exception {
long key = in._1();
FrameBlock blk = in._2();

// execute block transform apply
MultiColumnEncoder encoder = _bencoder.getValue();
MatrixBlock tmp = encoder.apply(blk);
// remap keys
if(_omap != null) {
key = _omap.getValue().getOffset(key);
}

// return the matrix block directly to enable the matrix-matrix reblock
return new Tuple2<>(key, tmp);
}
}
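For intuition, a small sketch (block size assumed) of the key mapping used in the aligned branch above: with row-aligned input blocks, a frame block's 1-based row offset maps directly to its output matrix block index.

// assuming blen = 1000: rows 1..1000 -> block 1, rows 1001..2000 -> block 2
long rowOffset = 1001; // key of the second input frame block
long bix = UtilFunctions.computeBlockIndex(rowOffset, 1000); // = 2
MatrixIndexes mi = new MatrixIndexes(bix, 1); // single column block in the output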

public static class RDDTransformDecodeFunction
implements PairFunction<Tuple2<MatrixIndexes, MatrixBlock>, Long, FrameBlock> {
private static final long serialVersionUID = -4797324742568170756L;
@@ -20,14 +20,77 @@
package org.apache.sysds.runtime.instructions.spark.utils;

import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.frame.data.FrameBlock;
import scala.Function3;
import scala.Tuple2;
import scala.Tuple3;
import scala.Tuple4;
import scala.Tuple5;


public class FrameRDDAggregateUtils
{
public static Tuple2<Boolean, Integer> checkRowAlignment(JavaPairRDD<Long,FrameBlock> in, int blen){
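// per-block tuple: (aligned so far, index of the block with the min row count,
// max rows per block (or the given blen), min rows per block, single block seen);
// note: assumes a non-empty input RDD, since the fold seed is null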
JavaRDD<Tuple5<Boolean, Long, Integer, Integer, Boolean>> row_rdd = in.map((Function<Tuple2<Long, FrameBlock>, Tuple5<Boolean, Long, Integer, Integer, Boolean>>) in1 -> {
long key = in1._1();
FrameBlock blk = in1._2();
return new Tuple5<>(true, key, blen == -1 ? blk.getNumRows() : blen, blk.getNumRows(), true);
});
Tuple5<Boolean, Long, Integer, Integer, Boolean> result = row_rdd.fold(null, (Function2<Tuple5<Boolean, Long, Integer, Integer, Boolean>, Tuple5<Boolean, Long, Integer, Integer, Boolean>, Tuple5<Boolean, Long, Integer, Integer, Boolean>>) (in1, in2) -> {
//trivial cases: fold seed (null) and already-detected misalignment
if (in1 == null)
return in2;
if (in2 == null)
return in1;
if (!in1._1() || !in2._1())
return new Tuple5<>(false, null, null, null, null);

//general case: compare the aggregated row-count statistics of both partial results
int in1_max = in1._3();
int in1_min = in1._4();
long in1_min_index = in1._2(); //index of the block with the min row count; for alignment it must be the last block
int in2_max = in2._3();
int in2_min = in2._4();
long in2_min_index = in2._2();

boolean in1_isSingleBlock = in1._5();
boolean in2_isSingleBlock = in2._5();
boolean min_index_comp = in1_min_index > in2_min_index;

if (in1_max == in2_max) {
	if (in1_min == in1_max) {
		if (in2_min == in2_max)
			return new Tuple5<>(true, min_index_comp ? in1_min_index : in2_min_index, in1_max, in1_max, false);
		else if (!min_index_comp)
			return new Tuple5<>(true, in2_min_index, in1_max, in2_min, false);
		//else: in1_min_index > in2_min_index --> in2's short block is not the last block --> not aligned
	} else {
		if (in2_min == in2_max)
			if (min_index_comp)
				return new Tuple5<>(true, in1_min_index, in1_max, in1_min, false);
			//else: in1_min_index < in2_min_index --> in1's short block is not the last block --> not aligned
		//else: both sides already contain blocks with fewer rows than the max
	}
}
else {
	if (in1_max > in2_max && in1_min == in1_max && in2_isSingleBlock && in1_min_index < in2_min_index)
		return new Tuple5<>(true, in2_min_index, in1_max, in2_min, false);
	/* else:
		in1_min != in1_max -> in1 already contains blocks with fewer rows than its max
		!in2_isSingleBlock -> in2 contains at least 2 blocks with fewer rows than in1's max
		in1_min_index > in2_min_index -> in2's min block is not the last block
	*/
	if (in1_max < in2_max && in2_min == in2_max && in1_isSingleBlock && in2_min_index < in1_min_index)
		return new Tuple5<>(true, in1_min_index, in2_max, in1_min, false);
}
return new Tuple5<>(false, null, null, null, null);
});
return new Tuple2<>(result._1(), result._3());
}
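For orientation, a brief sketch (row counts hypothetical) of how a caller consumes the check, mirroring the transformapply call site changed earlier in this commit:

// block row counts [1000, 1000, 250] -> aligned, blen = 1000;
// [1000, 250, 1000] -> not aligned, since the short block is not last
public static int alignedBlockLength(JavaPairRDD<Long, FrameBlock> in) {
	Tuple2<Boolean, Integer> res = checkRowAlignment(in, -1);
	return res._1() ? res._2() : -1; // -1 signals misaligned blocks
}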

public static JavaPairRDD<Long, FrameBlock> mergeByKey( JavaPairRDD<Long, FrameBlock> in )
{
@@ -57,6 +57,8 @@
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.data.DenseBlockFP64DEDUP;
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.instructions.spark.data.ReblockBuffer;
import org.apache.sysds.runtime.instructions.spark.data.SerLongWritable;
@@ -380,6 +382,18 @@ public static void libsvmToBinaryBlock(JavaSparkContext sc, String pathIn,
}
}

//currently unused (its call site above is commented out): essentially the frame-matrix reblock, but for matrix blocks; can be removed if unnecessary
public static JavaPairRDD<MatrixIndexes, MatrixBlock> matrixBlockToAlignedMatrixBlock(JavaPairRDD<Long,
MatrixBlock> input, DataCharacteristics mcIn, DataCharacteristics mcOut)
{
//align matrix blocks
JavaPairRDD<MatrixIndexes, MatrixBlock> out = input
.flatMapToPair(new RDDConverterUtils.MatrixBlockToAlignedMatrixBlockFunction(mcIn, mcOut));

//aggregate partial matrix blocks
return RDDAggregateUtils.mergeByKey(out, false);
}

public static JavaPairRDD<LongWritable, Text> stringToSerializableText(JavaPairRDD<Long,String> in)
{
return in.mapToPair(new TextToSerTextFunction());
@@ -1436,5 +1450,51 @@ public static JavaPairRDD<MatrixIndexes, MatrixBlock> libsvmToBinaryBlock(JavaSparkContext sc,
}
///////////////////////////////
// END LIBSVM FUNCTIONS

private static class MatrixBlockToAlignedMatrixBlockFunction implements PairFlatMapFunction<Tuple2<Long,MatrixBlock>,MatrixIndexes, MatrixBlock> {
private static final long serialVersionUID = -2654986510471835933L;

private DataCharacteristics _mcIn;
private DataCharacteristics _mcOut;
public MatrixBlockToAlignedMatrixBlockFunction(DataCharacteristics mcIn, DataCharacteristics mcOut) {
_mcIn = mcIn; //Frame Characteristics
_mcOut = mcOut; //Matrix Characteristics
}
@Override
public Iterator<Tuple2<MatrixIndexes, MatrixBlock>> call(Tuple2<Long, MatrixBlock> arg0)
throws Exception
{
long rowIndex = arg0._1();
MatrixBlock blk = arg0._2();
boolean dedup = blk.getDenseBlock() instanceof DenseBlockFP64DEDUP;
ArrayList<Tuple2<MatrixIndexes, MatrixBlock>> ret = new ArrayList<>();
long rlen = _mcIn.getRows();
long clen = _mcIn.getCols();
int blen = _mcOut.getBlocksize();

//slice aligned matrix blocks out of given frame block
long rstartix = UtilFunctions.computeBlockIndex(rowIndex, blen);
long rendix = UtilFunctions.computeBlockIndex(rowIndex+blk.getNumRows()-1, blen);
long cendix = UtilFunctions.computeBlockIndex(blk.getNumColumns(), blen);
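// fix/fix2: first/last source row of this input block that falls into row
// block rix; mix/mix2: the corresponding target rows within that output block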
for( long rix=rstartix; rix<=rendix; rix++ ) { //for all row blocks
long rpos = UtilFunctions.computeCellIndex(rix, blen, 0);
int lrlen = UtilFunctions.computeBlockSize(rlen, rix, blen);
int fix = (int)((rpos-rowIndex>=0) ? rpos-rowIndex : 0);
int fix2 = (int)Math.min(rpos+lrlen-rowIndex-1,blk.getNumRows()-1);
int mix = UtilFunctions.computeCellInBlock(rowIndex+fix, blen);
int mix2 = mix + (fix2-fix);
for( long cix=1; cix<=cendix; cix++ ) { //for all column blocks
long cpos = UtilFunctions.computeCellIndex(cix, blen, 0);
int lclen = UtilFunctions.computeBlockSize(clen, cix, blen);
MatrixBlock tmp = blk.slice(fix, fix2,
(int)cpos-1, (int)cpos+lclen-2, new MatrixBlock());
MatrixBlock newBlock = new MatrixBlock(lrlen, lclen, false);
ret.add(new Tuple2<>(new MatrixIndexes(rix, cix), newBlock.leftIndexingOperations(tmp, mix, mix2, 0, lclen-1,
new MatrixBlock(), MatrixObject.UpdateType.INPLACE_PINNED)));
}
}
return ret.iterator();
}
}
}

@@ -25,6 +25,7 @@

import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.io.Text;
import org.apache.sysds.runtime.data.DenseBlockFP64DEDUP;
import org.apache.sysds.runtime.util.UtilFunctions;


@@ -74,7 +75,17 @@ public void convert(MatrixIndexes k1, MatrixBlock v1)
{
if(v1.getDenseBlock()==null)
return;
-	denseArray=v1.getDenseBlockValues();
+	if(v1.getDenseBlock() instanceof DenseBlockFP64DEDUP){
+		DenseBlockFP64DEDUP db = (DenseBlockFP64DEDUP) v1.getDenseBlock();
+		denseArray = new double[v1.rlen*v1.clen];
+		for (int i = 0; i < v1.rlen; i++) {
+			double[] row = db.values(i);
+			for (int j = 0; j < v1.clen; j++) {
+				denseArray[i*v1.clen + j] = row[j];
+			}
+		}
+	} else
+		denseArray=v1.getDenseBlockValues();
nextInDenseArray=0;
denseArraySize=v1.getNumRows()*v1.getNumColumns();
}