Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.SparseVector;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo;
import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.util.ParamUtils;
Expand All @@ -42,7 +43,6 @@

import java.io.IOException;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Map;

/**
Expand Down Expand Up @@ -90,14 +90,27 @@ public AssemblerFunc(String[] inputCols, String handleInvalid) {
}

@Override
public void flatMap(Row value, Collector<Row> out) throws Exception {
public void flatMap(Row value, Collector<Row> out) {
int nnz = 0;
int vectorSize = 0;
try {
Object[] objects = new Object[inputCols.length];
for (int i = 0; i < objects.length; ++i) {
objects[i] = value.getField(inputCols[i]);
for (String inputCol : inputCols) {
Object object = value.getField(inputCol);
Preconditions.checkNotNull(object, "Input column value should not be null.");
if (object instanceof Number) {
nnz += 1;
vectorSize += 1;
} else if (object instanceof SparseVector) {
nnz += ((SparseVector) object).indices.length;
vectorSize += ((SparseVector) object).size();
} else if (object instanceof DenseVector) {
nnz += ((DenseVector) object).size();
vectorSize += ((DenseVector) object).size();
} else {
throw new IllegalArgumentException(
"Input type has not been supported yet.");
}
}
Vector assembledVector = assemble(objects);
out.collect(Row.join(value, Row.of(assembledVector)));
} catch (Exception e) {
switch (handleInvalid) {
case ERROR_INVALID:
Expand All @@ -112,6 +125,13 @@ public void flatMap(Row value, Collector<Row> out) throws Exception {
"Unsupported " + HANDLE_INVALID + " type: " + handleInvalid);
}
}

boolean toDense = nnz * RATIO > vectorSize;
Vector assembledVec =
toDense
? assembleDense(inputCols, value, vectorSize)
: assembleSparse(inputCols, value, vectorSize, nnz);
out.collect(Row.join(value, Row.of(assembledVec)));
}
}

Expand All @@ -129,57 +149,69 @@ public Map<Param<?>, Object> getParamMap() {
return paramMap;
}

private static Vector assemble(Object[] objects) {
int offset = 0;
Map<Integer, Double> map = new LinkedHashMap<>(objects.length);
for (Object object : objects) {
Preconditions.checkNotNull(object, "Input column value should not be null.");
/** Assembles the input columns into a dense vector. */
private static Vector assembleDense(String[] inputCols, Row inputRow, int vectorSize) {
double[] values = new double[vectorSize];
int currentOffset = 0;

for (String inputCol : inputCols) {
Object object = inputRow.getField(inputCol);
if (object instanceof Number) {
map.put(offset++, ((Number) object).doubleValue());
} else if (object instanceof Vector) {
offset = appendVector((Vector) object, map, offset);
values[currentOffset++] = ((Number) object).doubleValue();
} else if (object instanceof SparseVector) {
SparseVector sparseVector = (SparseVector) object;
for (int i = 0; i < sparseVector.indices.length; i++) {
values[currentOffset + sparseVector.indices[i]] = sparseVector.values[i];
}
currentOffset += sparseVector.size();

} else {
throw new IllegalArgumentException("Input type has not been supported yet.");
}
}
DenseVector denseVector = (DenseVector) object;
System.arraycopy(
denseVector.values, 0, values, currentOffset, denseVector.values.length);

if (map.size() * RATIO > offset) {
DenseVector assembledVector = new DenseVector(offset);
for (int key : map.keySet()) {
assembledVector.values[key] = map.get(key);
currentOffset += denseVector.size();
}
return assembledVector;
} else {
return convertMapToSparseVector(offset, map);
}
return Vectors.dense(values);
}

private static int appendVector(Vector vec, Map<Integer, Double> map, int offset) {
if (vec instanceof SparseVector) {
SparseVector sparseVector = (SparseVector) vec;
int[] indices = sparseVector.indices;
double[] values = sparseVector.values;
for (int i = 0; i < indices.length; ++i) {
map.put(offset + indices[i], values[i]);
}
offset += sparseVector.size();
} else {
DenseVector denseVector = (DenseVector) vec;
for (int i = 0; i < denseVector.size(); ++i) {
map.put(offset++, denseVector.values[i]);
}
}
return offset;
}
/** Assembles the input columns into a sparse vector. */
private static Vector assembleSparse(
String[] inputCols, Row inputRow, int vectorSize, int nnz) {
int[] indices = new int[nnz];
double[] values = new double[nnz];

private static SparseVector convertMapToSparseVector(int size, Map<Integer, Double> map) {
int[] indices = new int[map.size()];
double[] values = new double[map.size()];
int offset = 0;
for (Map.Entry<Integer, Double> entry : map.entrySet()) {
indices[offset] = entry.getKey();
values[offset++] = entry.getValue();
int currentIndex = 0;
int currentOffset = 0;

for (String inputCol : inputCols) {
Object object = inputRow.getField(inputCol);
if (object instanceof Number) {
indices[currentOffset] = currentIndex;
values[currentOffset] = ((Number) object).doubleValue();
currentOffset++;
currentIndex++;
} else if (object instanceof SparseVector) {
SparseVector sparseVector = (SparseVector) object;
for (int i = 0; i < sparseVector.indices.length; i++) {
indices[currentOffset + i] = sparseVector.indices[i] + currentIndex;
}
System.arraycopy(
sparseVector.values, 0, values, currentOffset, sparseVector.values.length);
currentIndex += sparseVector.size();
currentOffset += sparseVector.indices.length;
} else {
DenseVector denseVector = (DenseVector) object;
for (int i = 0; i < denseVector.size(); ++i) {
indices[currentOffset + i] = i + currentIndex;
}
System.arraycopy(
denseVector.values, 0, values, currentOffset, denseVector.values.length);
currentIndex += denseVector.size();
currentOffset += denseVector.size();
}
}
return new SparseVector(size, indices, values);
return new SparseVector(vectorSize, indices, values);
}
}