Skip to content

Commit

Permalink
[FLINK-1512] [scala api] Add CsvReader for reading into POJOs
Browse files Browse the repository at this point in the history
  • Loading branch information
chiwanpark authored and fhueske committed Mar 25, 2015
1 parent 7b1c19c commit 7a6f296
Show file tree
Hide file tree
Showing 4 changed files with 378 additions and 188 deletions.
Expand Up @@ -19,66 +19,91 @@
package org.apache.flink.api.scala.operators;


import com.google.common.base.Charsets;
import com.google.common.base.Preconditions;

import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.io.GenericCsvInputFormat;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.typeutils.PojoTypeInfo;
import org.apache.flink.api.java.typeutils.TupleTypeInfoBase;
import org.apache.flink.api.java.typeutils.runtime.TupleSerializerBase;
import org.apache.flink.core.fs.FileInputSplit;
import org.apache.flink.core.fs.Path;
import org.apache.flink.types.parser.FieldParser;
import org.apache.flink.util.StringUtils;

import org.apache.flink.types.parser.FieldParser;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.nio.charset.Charset;
import java.nio.charset.IllegalCharsetNameException;
import java.nio.charset.UnsupportedCharsetException;
import java.util.Map;
import java.util.TreeMap;
import java.lang.reflect.Field;
import java.util.Arrays;

import scala.Product;

public class ScalaCsvInputFormat<OUT extends Product> extends GenericCsvInputFormat<OUT> {
public class ScalaCsvInputFormat<OUT> extends GenericCsvInputFormat<OUT> {

// Serialization version of this input format.
private static final long serialVersionUID = 1L;

// Logger used by close() to report how many lines were skipped.
private static final Logger LOG = LoggerFactory.getLogger(ScalaCsvInputFormat.class);

// Reusable holders for the parsed field values; (re)created in open().
private transient Object[] parsedValues;

// To speed up readRecord processing. Used to find windows line endings.
// It is set in open() so that readRecord does not have to evaluate it.
private boolean lineDelimiterIsLinebreak = false;

// NOTE(review): "serializer" and the second "parsedValues" declaration below duplicate
// "tupleSerializer" / "parsedValues" — this looks like two revisions of the field list
// interleaved; confirm against the real file.
private final TupleSerializerBase<OUT> serializer;
private transient Object[] parsedValues;

// Byte prefix configured through the setCommentPrefix overloads; null when unset.
private byte[] commentPrefix = null;
// Serializer used to instantiate tuple results; the constructor sets it to null for non-tuple types.
private final TupleSerializerBase<OUT> tupleSerializer;

// Per-split counters, reset in open() and reported in close().
private transient int commentCount;
private transient int invalidLineCount;
// POJO support: target class and field names (both set by the constructor's non-tuple branch).
private Class<OUT> pojoTypeClass = null;
private String[] pojoFieldsName = null;
// Reflective handles for the POJO fields, resolved in open().
private transient Field[] pojoFields = null;
// NOTE(review): transient, but in the visible code it is only assigned in the constructor while
// setOrderOfPOJOFields reads it — it would be null on a deserialized instance; confirm usage.
private transient PojoTypeInfo<OUT> pojoTypeInfo = null;

/**
 * Creates the input format for the given file and output type.
 *
 * <p>For a {@code TupleTypeInfoBase} a tuple serializer is created and the tuple's field
 * classes become the CSV field types. For any other type information the value is cast to
 * {@code PojoTypeInfo}, its field classes become the CSV field types and its field names
 * become the initial POJO field order.
 *
 * @param filePath path handed to the super constructor
 * @param typeInfo type information describing the records to produce
 */
public ScalaCsvInputFormat(Path filePath, TypeInformation<OUT> typeInfo) {
super(filePath);

// NOTE(review): the following tuple-only guard is never closed and contradicts the POJO branch
// below — it looks like a leftover from the previous revision; confirm against the real file.
if (!(typeInfo.isTupleType())) {
throw new UnsupportedOperationException("This only works on tuple types.");
// One entry per field of the output type, filled by whichever branch applies.
Class<?>[] classes = new Class[typeInfo.getArity()];

if (typeInfo instanceof TupleTypeInfoBase) {
TupleTypeInfoBase<OUT> tupleType = (TupleTypeInfoBase<OUT>) typeInfo;
// We can use an empty config here, since we only use the serializer to create
// the top-level case class
tupleSerializer = (TupleSerializerBase<OUT>) tupleType.createSerializer(new ExecutionConfig());

for (int i = 0; i < tupleType.getArity(); i++) {
classes[i] = tupleType.getTypeAt(i).getTypeClass();
}

setFieldTypes(classes);
} else {
// A null tuple serializer marks the POJO code path.
tupleSerializer = null;
// NOTE(review): unchecked cast — a type that is neither tuple nor POJO would fail here with a
// ClassCastException rather than a descriptive error; confirm whether callers guarantee a POJO.
pojoTypeInfo = (PojoTypeInfo<OUT>) typeInfo;
pojoTypeClass = typeInfo.getTypeClass();
pojoFieldsName = pojoTypeInfo.getFieldNames();

for (int i = 0, arity = pojoTypeInfo.getArity(); i < arity; i++) {
classes[i] = pojoTypeInfo.getTypeAt(i).getTypeClass();
}

setFieldTypes(classes);
// Default order: the field names exactly as reported by the POJO type information.
setOrderOfPOJOFields(pojoFieldsName);
}
}

/**
 * Sets the order in which the POJO fields are populated.
 *
 * <p>Requires that a POJO type class is set, that {@code fieldsOrder} is non-null, that its
 * length equals the number of {@code true} entries in {@code fieldIncluded}, and that every
 * name is non-null and known to the POJO type information. A copy of the array is stored.
 *
 * @param fieldsOrder POJO field names in the desired order
 */
public void setOrderOfPOJOFields(String[] fieldsOrder) {
Preconditions.checkNotNull(pojoTypeClass, "Field order can only be specified if output type is a POJO.");
Preconditions.checkNotNull(fieldsOrder);

// Count how many fields are currently marked as included.
int includedCount = 0;
for (boolean isIncluded : fieldIncluded) {
if (isIncluded) {
includedCount++;
}
}
// NOTE(review): the lines below reference "typeInfo" and "serializer", neither of which is a
// parameter or local of this method — they appear to be leftovers of the earlier constructor
// body interleaved here; confirm against the real file.
TupleTypeInfoBase<OUT> tupleType = (TupleTypeInfoBase<OUT>) typeInfo;
// We can use an empty config here, since we only use the serializer to create
// the top-level case class
serializer = (TupleSerializerBase<OUT>) tupleType.createSerializer(new ExecutionConfig());

Class<?>[] classes = new Class[tupleType.getArity()];
for (int i = 0; i < tupleType.getArity(); i++) {
classes[i] = tupleType.getTypeAt(i).getTypeClass();

Preconditions.checkArgument(includedCount == fieldsOrder.length,
"The number of selected POJO fields should be the same as that of CSV fields.");

// Every requested name must be non-null and resolvable by the POJO type information.
for (String field : fieldsOrder) {
Preconditions.checkNotNull(field, "The field name cannot be null.");
Preconditions.checkArgument(pojoTypeInfo.getFieldIndex(field) != -1,
"The given field name isn't matched to POJO fields.");
}
setFieldTypes(classes);

// Store a copy so later changes to the caller's array do not affect this format.
pojoFieldsName = Arrays.copyOfRange(fieldsOrder, 0, fieldsOrder.length);
}

public void setFieldTypes(Class<?>[] fieldTypes) {
Expand All @@ -98,98 +123,66 @@ public void setFields(int[] sourceFieldIndices, Class<?>[] fieldTypes) {
setFieldsGeneric(sourceFieldIndices, fieldTypes);
}

/**
 * Returns the currently configured comment prefix.
 *
 * @return the stored prefix bytes (the internal array itself, not a copy); {@code null} if none is set
 */
public byte[] getCommentPrefix() {
    return this.commentPrefix;
}

/**
 * Stores the given bytes as the comment prefix.
 *
 * @param commentPrefix prefix bytes, kept by reference (no copy is made); may be {@code null}
 */
public void setCommentPrefix(byte[] commentPrefix) {
    final byte[] prefixBytes = commentPrefix;
    this.commentPrefix = prefixBytes;
}

/**
 * Sets a single-character comment prefix by delegating to the {@code String} overload.
 *
 * @param commentPrefix the character to use as comment prefix
 */
public void setCommentPrefix(char commentPrefix) {
    final String prefixText = Character.toString(commentPrefix);
    setCommentPrefix(prefixText);
}
/**
 * Configures the fields to read from a boolean source-field mask and the matching field types.
 * Both arguments are null-checked and then forwarded to {@code setFieldsGeneric}.
 *
 * @param sourceFieldMask mask over the source fields; must not be null
 * @param fieldTypes field types passed along with the mask; must not be null
 */
public void setFields(boolean[] sourceFieldMask, Class<?>[] fieldTypes) {
Preconditions.checkNotNull(sourceFieldMask);
Preconditions.checkNotNull(fieldTypes);

// NOTE(review): the next two lines look like a separate setCommentPrefix(String) overload
// (delegating with UTF-8) interleaved into this method — confirm against the real file.
public void setCommentPrefix(String commentPrefix) {
setCommentPrefix(commentPrefix, Charsets.UTF_8);
setFieldsGeneric(sourceFieldMask, fieldTypes);
}

/**
 * Sets the comment prefix from a string, using the charset with the given name.
 * A {@code null} prefix clears the stored prefix; otherwise the charset is looked up by name
 * and the call is delegated to the {@code Charset}-based overload.
 *
 * @param commentPrefix the prefix text, or {@code null} to clear it
 * @param charsetName name of the charset; must not be {@code null}
 * @throws IllegalArgumentException if {@code charsetName} is {@code null}
 */
public void setCommentPrefix(String commentPrefix, String charsetName) throws IllegalCharsetNameException, UnsupportedCharsetException {
if (charsetName == null) {
throw new IllegalArgumentException("Charset name must not be null");
}

if (commentPrefix != null) {
// The charset is only resolved when there is actually a prefix to encode.
Charset charset = Charset.forName(charsetName);
setCommentPrefix(commentPrefix, charset);
} else {
this.commentPrefix = null;
}
// NOTE(review): getFieldTypes() below starts before the method above is closed — the two
// definitions appear interleaved; confirm against the real file.
// Returns the field types reported by the superclass' getGenericFieldTypes().
public Class<?>[] getFieldTypes() {
return super.getGenericFieldTypes();
}

/**
 * Sets the comment prefix from a string, encoded with the given charset.
 *
 * @param commentPrefix the prefix text; {@code null} clears the stored prefix
 * @param charset charset used to encode the prefix; must not be {@code null}
 * @throws IllegalArgumentException if {@code charset} is {@code null}
 */
public void setCommentPrefix(String commentPrefix, Charset charset) {
    // The charset is validated even when the prefix is null.
    if (charset == null) {
        throw new IllegalArgumentException("Charset must not be null");
    }

    this.commentPrefix = (commentPrefix == null) ? null : commentPrefix.getBytes(charset);
}

/**
 * Reports how many invalid and comment lines were skipped for the current split, then closes
 * the underlying format.
 *
 * <p>Invalid lines are reported at WARN level and comment lines at INFO level, each only when
 * the count is positive and the level is enabled.
 *
 * @throws IOException if closing the superclass fails
 */
@Override
public void close() throws IOException {
    if (this.invalidLineCount > 0 && LOG.isWarnEnabled()) {
        final String invalidReport = "In file \"" + this.filePath + "\" (split start: " + this.splitStart + ") "
                + this.invalidLineCount + " invalid line(s) were skipped.";
        LOG.warn(invalidReport);
    }

    if (this.commentCount > 0 && LOG.isInfoEnabled()) {
        final String commentReport = "In file \"" + this.filePath + "\" (split start: " + this.splitStart + ") "
                + this.commentCount + " comment line(s) were skipped.";
        LOG.info(commentReport);
    }

    super.close();
}

/**
 * Returns the next record, retrying while the superclass yields {@code null} and the end of
 * the input has not been reached.
 *
 * @param record reusable record instance handed to the superclass
 * @return the next non-null record, or {@code null} once the end is reached without one
 * @throws IOException if reading from the underlying input fails
 */
@Override
public OUT nextRecord(OUT record) throws IOException {
    OUT next = super.nextRecord(record);
    // A null result is retried until a record arrives or the input is exhausted.
    while (next == null && !reachedEnd()) {
        next = super.nextRecord(record);
    }
    return next;
}

/**
 * Opens the given split and prepares the per-split parsing state.
 *
 * <p>After opening the superclass this creates one value holder per field parser, resets the
 * comment and invalid-line counters, records whether the line delimiter is a single
 * {@code '\n'}, and, for POJO output, resolves the configured field names to accessible
 * {@link Field} handles.
 *
 * @param split the file split to read
 * @throws IOException if the superclass fails to open the split or no field parsers exist
 * @throws RuntimeException if a configured POJO field name is not declared on the POJO class
 */
@Override
public void open(FileInputSplit split) throws IOException {
    super.open(split);

    @SuppressWarnings("unchecked")
    final FieldParser<Object>[] parsers = (FieldParser<Object>[]) getFieldParsers();

    // Without any parser no record could ever be produced.
    if (parsers.length == 0) {
        throw new IOException("CsvInputFormat.open(FileInputSplit split) - no field parsers to parse input");
    }

    // One reusable value holder per parsed field.
    this.parsedValues = new Object[parsers.length];
    int slot = 0;
    for (FieldParser<Object> parser : parsers) {
        this.parsedValues[slot++] = parser.createValue();
    }

    // Skip statistics are tracked per split.
    this.commentCount = 0;
    this.invalidLineCount = 0;

    // Remember a plain '\n' delimiter once so readRecord need not re-check it per call.
    // The length test comes first, which makes the access to index 0 safe.
    final byte[] lineDelimiter = this.getDelimiter();
    if (lineDelimiter.length == 1 && lineDelimiter[0] == '\n') {
        this.lineDelimiterIsLinebreak = true;
    }

    // Nothing more to prepare unless the output type is a POJO.
    if (pojoTypeClass == null) {
        return;
    }

    // Resolve each configured field name to an accessible reflective handle.
    pojoFields = new Field[pojoFieldsName.length];
    for (int idx = 0; idx < pojoFieldsName.length; idx++) {
        final String fieldName = pojoFieldsName[idx];
        try {
            final Field handle = pojoTypeClass.getDeclaredField(fieldName);
            pojoFields[idx] = handle;
            handle.setAccessible(true);
        } catch (NoSuchFieldException e) {
            throw new RuntimeException("There is no field called \"" + pojoFieldsName[idx] + "\" in " + pojoTypeClass.getName(), e);
        }
    }
}

/**
 * Fetches records from the superclass until one is non-null or the end of the input is
 * reached.
 *
 * @param record reusable record instance handed to the superclass
 * @return the first non-null record obtained, or {@code null} if the end was reached first
 * @throws IOException if reading from the underlying input fails
 */
@Override
public OUT nextRecord(OUT record) throws IOException {
    for (;;) {
        final OUT candidate = super.nextRecord(record);
        // Stop on the first real record, or when nothing is left to read.
        if (candidate != null || reachedEnd()) {
            return candidate;
        }
    }
}

@Override
Expand Down Expand Up @@ -219,73 +212,22 @@ public OUT readRecord(OUT reuse, byte[] bytes, int offset, int numBytes) {
}

if (parseRecord(parsedValues, bytes, offset, numBytes)) {
OUT result = serializer.createInstance(parsedValues);
return result;
if (tupleSerializer != null) {
return tupleSerializer.createInstance(parsedValues);
} else {
for (int i = 0; i < pojoFields.length; i++) {
try {
pojoFields[i].set(reuse, parsedValues[i]);
} catch (IllegalAccessException e) {
throw new RuntimeException("Parsed value could not be set in POJO field \"" + pojoFieldsName[i] + "\"", e);
}
}

return reuse;
}
} else {
this.invalidLineCount++;
return null;
}
}


/**
 * Describes this input format by its field delimiter (with control characters made visible)
 * and its file path.
 *
 * @return a string of the form {@code CSV Input (<delimiter>) <path>}
 */
@Override
public String toString() {
    final String delimiterText = String.valueOf(getFieldDelimiter());
    final StringBuilder description = new StringBuilder("CSV Input (");
    description.append(StringUtils.showControlCharacters(delimiterText));
    description.append(") ");
    description.append(getFilePath());
    return description.toString();
}

// --------------------------------------------------------------------------------------------

/**
 * Validates a set of field positions with their types and sorts both arrays in place by
 * ascending position, keeping each type paired with its position.
 *
 * <p>All validation happens before anything is written back, so the arrays are left
 * untouched when an exception is thrown.
 *
 * @param positions field positions; each must be non-negative and unique
 * @param types field types, parallel to {@code positions}; no entry may be {@code null}
 * @throws IllegalArgumentException if the lengths differ, a position is negative, a type is
 *         {@code null}, or a position occurs more than once
 */
@SuppressWarnings("unused")
private static void checkAndCoSort(int[] positions, Class<?>[] types) {
    final int count = positions.length;
    if (count != types.length) {
        throw new IllegalArgumentException("The positions and types must be of the same length");
    }

    // A TreeMap iterates in ascending key order, which yields the sorted pairing.
    final TreeMap<Integer, Class<?>> typeByPosition = new TreeMap<Integer, Class<?>>();

    for (int idx = 0; idx < count; idx++) {
        final int position = positions[idx];
        final Class<?> type = types[idx];

        if (position < 0) {
            throw new IllegalArgumentException("The field " + " (" + position + ") is invalid.");
        }
        if (type == null) {
            throw new IllegalArgumentException("The type " + idx + " is invalid (null)");
        }

        // Stored types are never null, so a non-null previous value means a repeated position.
        if (typeByPosition.put(position, type) != null) {
            throw new IllegalArgumentException("The position " + position + " occurs multiple times.");
        }
    }

    // Write the pairs back in sorted order.
    int target = 0;
    for (Map.Entry<Integer, Class<?>> sorted : typeByPosition.entrySet()) {
        positions[target] = sorted.getKey();
        types[target] = sorted.getValue();
        target++;
    }
}

/**
 * Verifies that the given positions are non-negative and strictly increasing, and that every
 * position has a non-null type. Neither array is modified.
 *
 * @param positions field positions to check
 * @param types field types, parallel to {@code positions}
 * @throws IllegalArgumentException if the lengths differ, a position is negative, a type is
 *         {@code null}, or a position is not greater than its predecessor
 */
private static void checkForMonotonousOrder(int[] positions, Class<?>[] types) {
    if (positions.length != types.length) {
        throw new IllegalArgumentException("The positions and types must be of the same length");
    }

    // Starts below every valid position so the first entry always passes the order check.
    int previous = -1;
    int idx = 0;

    for (final int current : positions) {
        if (current < 0) {
            throw new IllegalArgumentException("The field " + " (" + current + ") is invalid.");
        }
        if (types[idx] == null) {
            throw new IllegalArgumentException("The type " + idx + " is invalid (null)");
        }
        if (previous >= current) {
            throw new IllegalArgumentException("The positions must be strictly increasing (no permutations are supported).");
        }

        previous = current;
        idx++;
    }
}
}

0 comments on commit 7a6f296

Please sign in to comment.