Skip to content

Commit

Permalink
[SYSTEMDS-2363] Improve persistent writes w/ variable format parameter
Browse files Browse the repository at this point in the history
Although we extended persistent writes with dynamic file names, the
format parameter remained static and thus depended on constant
propagation at compile time, and otherwise failed with errors.
We now support truly dynamic format types (e.g., selected at script
level) and added tests for both CP and Spark.
  • Loading branch information
mboehm7 committed Jun 5, 2023
1 parent 10315ab commit 565c0e5
Show file tree
Hide file tree
Showing 10 changed files with 184 additions and 77 deletions.
3 changes: 2 additions & 1 deletion src/main/java/org/apache/sysds/common/Types.java
Expand Up @@ -632,7 +632,8 @@ public enum FileFormat {
BINARY, // binary block representation (dense/sparse/ultra-sparse)
FEDERATED, // A federated matrix
PROTO, // protocol buffer representation
HDF5; // Hierarchical Data Format (HDF)
HDF5, // Hierarchical Data Format (HDF)
UNKNOWN;

public boolean isIJV() {
return this == TEXT || this == MM;
Expand Down
Expand Up @@ -82,8 +82,8 @@ private void rule_BlockSizeAndReblock(Hop hop, final int blocksize)

// if block size does not match
if( (dop.getDataType() == DataType.MATRIX && (dop.getBlocksize() != blocksize))
||(dop.getDataType() == DataType.FRAME && OptimizerUtils.isSparkExecutionMode() && (dop.getFileFormat()==FileFormat.TEXT
|| dop.getFileFormat()==FileFormat.CSV)) )
||(dop.getDataType() == DataType.FRAME && OptimizerUtils.isSparkExecutionMode()
&& (dop.getFileFormat()==FileFormat.TEXT || dop.getFileFormat()==FileFormat.CSV)) )
{
if( dop.getOp() == OpOpData.PERSISTENTREAD)
{
Expand Down
18 changes: 11 additions & 7 deletions src/main/java/org/apache/sysds/lops/Data.java
Expand Up @@ -68,7 +68,7 @@ public static Data createLiteralLop(ValueType vt, String literalValue) {
* @param fmt file format
*/
public Data(OpOpData op, Lop input, HashMap<String, Lop>
inputParametersLops, String name, String literal, DataType dt, ValueType vt, FileFormat fmt)
inputParametersLops, String name, String literal, DataType dt, ValueType vt, FileFormat fmt)
{
super(Lop.Type.Data, dt, vt);
_op = op;
Expand Down Expand Up @@ -286,14 +286,18 @@ else if (_op.isWrite()) {
OutputParameters oparams = getOutputParameters();
if ( _op.isWrite() ) {
sb.append( OPERAND_DELIMITOR );
String fmt = null;
FileFormat fmt = null;
if ( getDataType() == DataType.MATRIX || getDataType() == DataType.FRAME )
fmt = oparams.getFormat().toString();
fmt = oparams.getFormat();
else // scalars will always be written in text format
fmt = FileFormat.TEXT.toString();

sb.append( prepOperand(fmt, DataType.SCALAR, ValueType.STRING, true));

fmt = FileFormat.TEXT;

//format literal or variable
Lop fmtLop = _inputParams.get(DataExpression.FORMAT_TYPE);
String fmtLabel = (fmt!=FileFormat.UNKNOWN) ? fmt.toString() : fmtLop.getOutputParameters().getLabel();
sb.append(prepOperand(fmtLabel, DataType.SCALAR, ValueType.STRING,
(fmtLop instanceof Data && ((Data)fmtLop).isLiteral()))); //even fmtLop may be Data literal

if(oparams.getFormat() == FileFormat.CSV) {
Data headerLop = (Data) getNamedInputLop(DataExpression.DELIM_HAS_HEADER_ROW);
Data delimLop = (Data) getNamedInputLop(DataExpression.DELIM_DELIMITER);
Expand Down
14 changes: 9 additions & 5 deletions src/main/java/org/apache/sysds/parser/DMLTranslator.java
Expand Up @@ -1047,8 +1047,9 @@ public void constructHops(StatementBlock sb) {
}

DataOp ae = (DataOp)processExpression(source, target, ids);
String formatName = os.getExprParam(DataExpression.FORMAT_TYPE).toString();
ae.setFileFormat(Expression.convertFormatType(formatName));
Expression fmtExpr = os.getExprParam(DataExpression.FORMAT_TYPE);
ae.setFileFormat((fmtExpr instanceof StringIdentifier) ?
Expression.convertFormatType(fmtExpr.toString()) : FileFormat.UNKNOWN);

if (ae.getDataType() == DataType.SCALAR ) {
ae.setOutputParams(ae.getDim1(), ae.getDim2(), ae.getNnz(), ae.getUpdateType(), -1);
Expand All @@ -1065,6 +1066,7 @@ public void constructHops(StatementBlock sb) {
break;
case BINARY:
case COMPRESSED:
case UNKNOWN:
// write output in binary block format
ae.setOutputParams(ae.getDim1(), ae.getDim2(), ae.getNnz(), ae.getUpdateType(), ae.getBlocksize());
break;
Expand All @@ -1077,7 +1079,6 @@ public void constructHops(StatementBlock sb) {
}

output.add(ae);

}

if (current instanceof PrintStatement) {
Expand Down Expand Up @@ -1565,8 +1566,11 @@ else if( source instanceof DataExpression ) {
Hop ae = processDataExpression((DataExpression)source, target, hops);
if (ae instanceof DataOp && ((DataOp) ae).getOp() != OpOpData.SQLREAD &&
((DataOp) ae).getOp() != OpOpData.FEDERATED) {
String formatName = ((DataExpression)source).getVarParam(DataExpression.FORMAT_TYPE).toString();
((DataOp)ae).setFileFormat(Expression.convertFormatType(formatName));
Expression expr = ((DataExpression)source).getVarParam(DataExpression.FORMAT_TYPE);
if( expr instanceof StringIdentifier )
((DataOp)ae).setFileFormat(Expression.convertFormatType(expr.toString()));
else
((DataOp)ae).setFileFormat(FileFormat.UNKNOWN);
}
return ae;
}
Expand Down
36 changes: 21 additions & 15 deletions src/main/java/org/apache/sysds/parser/DataExpression.java
Expand Up @@ -1299,7 +1299,7 @@ else if (valueTypeString.equalsIgnoreCase(ValueType.UNKNOWN.name()))
case WRITE:

// for CSV format, if no delimiter specified THEN set default ","
if (getVarParam(FORMAT_TYPE) == null || getVarParam(FORMAT_TYPE).toString().equalsIgnoreCase(FileFormat.CSV.toString())){
if (getVarParam(FORMAT_TYPE) == null || checkFormatType(FileFormat.CSV) ){
if (getVarParam(DELIM_DELIMITER) == null) {
addVarParam(DELIM_DELIMITER, new StringIdentifier(DEFAULT_DELIM_DELIMITER, this));
}
Expand All @@ -1312,28 +1312,28 @@ else if (valueTypeString.equalsIgnoreCase(ValueType.UNKNOWN.name()))
}

// for LIBSVM format, add the default separators if not specified
if (getVarParam(FORMAT_TYPE) == null || getVarParam(FORMAT_TYPE).toString().equalsIgnoreCase(FileFormat.LIBSVM.toString())) {
if(getVarParam(DELIM_DELIMITER) == null) {
addVarParam(DELIM_DELIMITER, new StringIdentifier(DEFAULT_DELIM_DELIMITER, this));
}
if(getVarParam(LIBSVM_INDEX_DELIM) == null) {
addVarParam(LIBSVM_INDEX_DELIM, new StringIdentifier(DEFAULT_LIBSVM_INDEX_DELIM, this));
}
if(getVarParam(DELIM_SPARSE) == null) {
addVarParam(DELIM_SPARSE, new BooleanIdentifier(DEFAULT_DELIM_SPARSE, this));
}
if (getVarParam(FORMAT_TYPE) == null || checkFormatType(FileFormat.LIBSVM)) {
if(getVarParam(DELIM_DELIMITER) == null) {
addVarParam(DELIM_DELIMITER, new StringIdentifier(DEFAULT_DELIM_DELIMITER, this));
}
if(getVarParam(LIBSVM_INDEX_DELIM) == null) {
addVarParam(LIBSVM_INDEX_DELIM, new StringIdentifier(DEFAULT_LIBSVM_INDEX_DELIM, this));
}
if(getVarParam(DELIM_SPARSE) == null) {
addVarParam(DELIM_SPARSE, new BooleanIdentifier(DEFAULT_DELIM_SPARSE, this));
}
}

//validate read filename
if (getVarParam(FORMAT_TYPE) == null || FileFormat.isTextFormat(getVarParam(FORMAT_TYPE).toString()))
getOutput().setBlocksize(-1);
else if (getVarParam(FORMAT_TYPE).toString().equalsIgnoreCase(FileFormat.BINARY.toString()) || getVarParam(FORMAT_TYPE).toString().equalsIgnoreCase(FileFormat.COMPRESSED.toString())) {
else if (checkFormatType(FileFormat.BINARY, FileFormat.COMPRESSED, FileFormat.UNKNOWN)) {
if( getVarParam(ROWBLOCKCOUNTPARAM)!=null )
getOutput().setBlocksize(Integer.parseInt(getVarParam(ROWBLOCKCOUNTPARAM).toString()));
else
getOutput().setBlocksize(ConfigurationManager.getBlocksize());
}
else
else if( getVarParam(FORMAT_TYPE) instanceof StringIdentifier ) //literal format
raiseValidateError("Invalid format " + getVarParam(FORMAT_TYPE)
+ " in statement: " + toString(), conditional);
break;
Expand Down Expand Up @@ -2189,6 +2189,12 @@ private void handleCSVDefaultParam(String param, ValueType vt, boolean condition
}
}

/**
 * Checks whether the configured format parameter (FORMAT_TYPE) matches,
 * case-insensitively, any of the given candidate file formats.
 * Callers must ensure the FORMAT_TYPE parameter is present (non-null).
 *
 * @param fmts candidate file formats to compare against
 * @return true if the configured format string equals any candidate
 */
private boolean checkFormatType(FileFormat... fmts) {
	String configured = getVarParam(FORMAT_TYPE).toString();
	for( FileFormat candidate : fmts ) {
		if( configured.equalsIgnoreCase(candidate.toString()) )
			return true;
	}
	return false;
}

private boolean checkValueType(Expression expr, ValueType vt) {
return (vt == ValueType.STRING && expr instanceof StringIdentifier)
|| (vt == ValueType.FP64 && (expr instanceof DoubleIdentifier || expr instanceof IntIdentifier))
Expand Down Expand Up @@ -2328,7 +2334,7 @@ public VariableSet variablesUpdated() {

public boolean isCSVReadWithUnknownSize() {
Expression format = getVarParam(FORMAT_TYPE);
if( _opcode == DataOp.READ && format!=null && format.toString().equalsIgnoreCase(FileFormat.CSV.toString()) ) {
if( _opcode == DataOp.READ && format!=null && checkFormatType(FileFormat.CSV) ) {
Expression rows = getVarParam(READROWPARAM);
Expression cols = getVarParam(READCOLPARAM);
return (rows==null || Long.parseLong(rows.toString())<0)
Expand All @@ -2339,7 +2345,7 @@ public boolean isCSVReadWithUnknownSize() {

public boolean isLIBSVMReadWithUnknownSize() {
Expression format = getVarParam(FORMAT_TYPE);
if (_opcode == DataOp.READ && format != null && format.toString().equalsIgnoreCase(FileFormat.LIBSVM.toString())) {
if (_opcode == DataOp.READ && format != null && checkFormatType(FileFormat.LIBSVM)) {
Expression rows = getVarParam(READROWPARAM);
Expression cols = getVarParam(READCOLPARAM);
return (rows == null || Long.parseLong(rows.toString()) < 0)
Expand Down
23 changes: 12 additions & 11 deletions src/main/java/org/apache/sysds/parser/StatementBlock.java
Expand Up @@ -1092,18 +1092,19 @@ public void setStatementFormatType(OutputStatement s, boolean conditionalValidat
if (s.getExprParam(DataExpression.FORMAT_TYPE)!= null )
{
Expression formatTypeExpr = s.getExprParam(DataExpression.FORMAT_TYPE);
if (!(formatTypeExpr instanceof StringIdentifier)){
raiseValidateError("IO statement parameter " + DataExpression.FORMAT_TYPE
+ " can only be a string with one of following values: binary, text, mm, csv.", false, LanguageErrorCodes.INVALID_PARAMETERS);
}
String ft = formatTypeExpr.toString();
try {
s.getIdentifier().setFileFormat(FileFormat.safeValueOf(ft));
if( formatTypeExpr instanceof StringIdentifier ) {
String ft = formatTypeExpr.toString();
try {
s.getIdentifier().setFileFormat(FileFormat.safeValueOf(ft));
}
catch(Exception ex) {
raiseValidateError("IO statement parameter " + DataExpression.FORMAT_TYPE
+ " can only be a string with one of following values: binary, text, mm, csv, libsvm, jsonl;"
+ " invalid format: '"+ft+"'.", false, LanguageErrorCodes.INVALID_PARAMETERS);
}
}
catch(Exception ex) {
raiseValidateError("IO statement parameter " + DataExpression.FORMAT_TYPE
+ " can only be a string with one of following values: binary, text, mm, csv, libsvm, jsonl;"
+ " invalid format: '"+ft+"'.", false, LanguageErrorCodes.INVALID_PARAMETERS);
else {
s.getIdentifier().setFileFormat(FileFormat.UNKNOWN);
}
}
//case of unspecified format parameter, use default
Expand Down
Expand Up @@ -1023,8 +1023,8 @@ private void processCopyInstruction(ExecutionContext ec) {
*/
private void processWriteInstruction(ExecutionContext ec) {
//get filename (literal or variable expression)
String fname = ec.getScalarInput(getInput2().getName(), ValueType.STRING, getInput2().isLiteral()).getStringValue();
String fmtStr = getInput3().getName();
String fname = ec.getScalarInput(getInput2()).getStringValue();
String fmtStr = ec.getScalarInput(getInput3()).getStringValue();
FileFormat fmt = FileFormat.safeValueOf(fmtStr);
if( fmt != FileFormat.LIBSVM && fmt != FileFormat.HDF5) {
String desc = ec.getScalarInput(getInput4().getName(), ValueType.STRING, getInput4().isLiteral()).getStringValue();
Expand Down Expand Up @@ -1110,11 +1110,13 @@ public static void processRmvarInstruction( ExecutionContext ec, String varname
private void writeCSVFile(ExecutionContext ec, String fname) {
MatrixObject mo = ec.getMatrixObject(getInput1().getName());
String outFmt = "csv";

FileFormatProperties fprop = (_formatProperties instanceof FileFormatPropertiesCSV) ?
_formatProperties : new FileFormatPropertiesCSV(); //for dynamic format strings

if(mo.isDirty()) {
// there exist data computed in CP that is not backed up on HDFS
// i.e., it is either in-memory or in evicted space
mo.exportData(fname, outFmt, _formatProperties);
mo.exportData(fname, outFmt, fprop);
}
else {
try {
Expand All @@ -1123,14 +1125,14 @@ private void writeCSVFile(ExecutionContext ec, String fname) {
if( fmt == FileFormat.CSV
&& !getInput1().getName().startsWith(org.apache.sysds.lops.Data.PREAD_PREFIX) )
{
WriterTextCSV writer = new WriterTextCSV((FileFormatPropertiesCSV)_formatProperties);
WriterTextCSV writer = new WriterTextCSV((FileFormatPropertiesCSV)fprop);
writer.addHeaderToCSV(mo.getFileName(), fname, dc.getRows(), dc.getCols());
}
else {
mo.exportData(fname, outFmt, _formatProperties);
mo.exportData(fname, outFmt, fprop);
}
HDFSTool.writeMetaDataFile(fname + ".mtd", mo.getValueType(),
dc, FileFormat.CSV, _formatProperties, mo.getPrivacyConstraint());
dc, FileFormat.CSV, fprop, mo.getPrivacyConstraint());
}
catch(IOException e) {
throw new DMLRuntimeException(e);
Expand Down
Expand Up @@ -144,8 +144,8 @@ public void processInstruction(ExecutionContext ec) {
SparkExecutionContext sec = (SparkExecutionContext) ec;

//get filename (literal or variable expression)
String fname = ec.getScalarInput(input2.getName(), ValueType.STRING, input2.isLiteral()).getStringValue();
String desc = ec.getScalarInput(input4.getName(), ValueType.STRING, input4.isLiteral()).getStringValue();
String fname = ec.getScalarInput(input2).getStringValue();
String desc = ec.getScalarInput(input4).getStringValue();
formatProperties.setDescription(desc);

ValueType[] schema = (input1.getDataType()==DataType.FRAME) ?
Expand All @@ -157,7 +157,8 @@ public void processInstruction(ExecutionContext ec) {
HDFSTool.deleteFileIfExistOnHDFS( fname );

//prepare output info according to meta data
FileFormat fmt = FileFormat.safeValueOf(input3.getName());
String fmtStr = ec.getScalarInput(input3).getStringValue();
FileFormat fmt = FileFormat.safeValueOf(fmtStr);

//core matrix/frame write
switch( input1.getDataType() ) {
Expand Down Expand Up @@ -214,7 +215,9 @@ else if( fmt == FileFormat.CSV ) {
throw new IOException("Write of matrices with zero rows or columns"
+ " not supported ("+mc.getRows()+"x"+mc.getCols()+").");
}

FileFormatProperties fprop = (formatProperties instanceof FileFormatPropertiesCSV) ?
formatProperties : new FileFormatPropertiesCSV(); //for dynamic format strings

//piggyback nnz computation on actual write
LongAccumulator aNnz = null;
if( !mc.nnzKnown() ) {
Expand All @@ -223,7 +226,7 @@ else if( fmt == FileFormat.CSV ) {
}

JavaRDD<String> out = RDDConverterUtils.binaryBlockToCsv(
in1, mc, (FileFormatPropertiesCSV) formatProperties, true);
in1, mc, (FileFormatPropertiesCSV) fprop, true);

customSaveTextFile(out, fname, false);

Expand All @@ -233,7 +236,7 @@ else if( fmt == FileFormat.CSV ) {
else if( fmt == FileFormat.BINARY ) {
//reblock output if needed
int blen = Integer.parseInt(input4.getName());
boolean nonDefaultBlen = ConfigurationManager.getBlocksize() != blen;
boolean nonDefaultBlen = ConfigurationManager.getBlocksize() != blen && blen > 0;
if( nonDefaultBlen )
in1 = RDDConverterUtils.binaryBlockToBinaryBlock(in1, mc,
new MatrixCharacteristics(mc).setBlocksize(blen));
Expand Down

0 comments on commit 565c0e5

Please sign in to comment.