[BEAM-6675] Generate JDBC statement and preparedStatementSetter automatically when schema is available #8962

Merged: 12 commits, Jul 16, 2019
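
The change lets a schema-aware PCollection be written with only a destination table name: JdbcIO derives the INSERT statement and the PreparedStatementSetter from the element schema and the table's metadata. A minimal usage sketch follows (driver, JDBC URL, table name and POJO are hypothetical and not taken from this PR):

import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.io.jdbc.JdbcIO;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.schemas.JavaFieldSchema;
import org.apache.beam.sdk.schemas.annotations.DefaultSchema;
import org.apache.beam.sdk.schemas.annotations.SchemaCreate;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.values.PCollection;

public class SchemaBasedJdbcWriteExample {

  /** Schema-aware POJO; field names are assumed to match the target table's columns. */
  @DefaultSchema(JavaFieldSchema.class)
  public static class Person {
    public final String name;
    public final int id;

    @SchemaCreate
    public Person(String name, int id) {
      this.name = name;
      this.id = id;
    }
  }

  public static void main(String[] args) {
    Pipeline pipeline = Pipeline.create(PipelineOptionsFactory.fromArgs(args).create());

    PCollection<Person> people = pipeline.apply(Create.of(new Person("charles", 1)));

    // No withStatement()/withPreparedStatementSetter(): both are generated from the
    // PCollection's schema and the metadata of the named table.
    people.apply(
        JdbcIO.<Person>write()
            .withDataSourceConfiguration(
                JdbcIO.DataSourceConfiguration.create(
                    "org.apache.derby.jdbc.ClientDriver", "jdbc:derby://localhost:1527/testdb"))
            .withTable("person_table"));

    pipeline.run().waitUntilFinish();
  }
}

If a statement and setter are supplied explicitly, the write behaves as before; the generation path runs only when the input has a schema and a statement/setter pair has not already been provided (see the expand() implementation in the JdbcIO diff below).
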
@@ -302,7 +302,7 @@ public boolean assignableTo(Schema other) {
return equivalent(other, EquivalenceNullablePolicy.WEAKEN);
}

/** Returns true if this Schema can be assigned to another Schema, igmoring nullable. * */
/** Returns true if this Schema can be assigned to another Schema, ignoring nullable. * */
public boolean assignableToIgnoreNullable(Schema other) {
return equivalent(other, EquivalenceNullablePolicy.IGNORE);
}
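
For illustration (not part of the patch), the ignore-nullable check described by the corrected javadoc behaves as in this sketch, where the two schemas differ only in the nullability of one field:

Schema input = Schema.builder().addStringField("name").addInt32Field("id").build();
Schema table = Schema.builder().addStringField("name").addNullableField("id", Schema.FieldType.INT32).build();
// Field names and types match; the nullability difference is ignored.
boolean compatible = input.assignableToIgnoreNullable(table); // true
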
@@ -47,6 +47,15 @@ public static void createTable(DataSource dataSource, String tableName) throws S
}
}

public static void createTableForRowWithSchema(DataSource dataSource, String tableName)
throws SQLException {
try (Connection connection = dataSource.getConnection()) {
try (Statement statement = connection.createStatement()) {
statement.execute(String.format("create table %s (name VARCHAR(500), id INT)", tableName));
}
}
}

public static void deleteTable(DataSource dataSource, String tableName) throws SQLException {
if (tableName != null) {
try (Connection connection = dataSource.getConnection();
@@ -69,4 +78,13 @@ public static String getPostgresDBUrl(PostgresIOTestPipelineOptions options) {
options.getPostgresPort(),
options.getPostgresDatabaseName());
}

public static void createTableWithStatement(DataSource dataSource, String stmt)
throws SQLException {
try (Connection connection = dataSource.getConnection()) {
try (Statement statement = connection.createStatement()) {
statement.execute(stmt);
}
}
}
}
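
For illustration (not part of the patch), an integration test would typically provision and clean up the table around a schema-based write roughly like this, assuming dataSource is an already configured javax.sql.DataSource:

String tableName = "UT_WRITE_SCHEMA_" + System.currentTimeMillis();
createTableForRowWithSchema(dataSource, tableName);
try {
  // ... run a pipeline that writes elements with fields (name, id) using
  // JdbcIO.write().withTable(tableName), then assert on the table contents ...
} finally {
  deleteTable(dataSource, tableName);
}
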
@@ -17,6 +17,7 @@
*/
package org.apache.beam.sdk.io.jdbc;

import static org.apache.beam.sdk.io.jdbc.SchemaUtil.checkNullabilityForFields;
import static org.apache.beam.vendor.guava.v20_0.com.google.common.base.Preconditions.checkArgument;

import com.google.auto.value.AutoValue;
@@ -28,7 +29,11 @@
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import javax.annotation.Nullable;
import javax.sql.DataSource;
import org.apache.beam.sdk.annotations.Experimental;
@@ -904,7 +909,7 @@ public interface RetryStrategy extends Serializable {
* <p>All methods in this class delegate to the appropriate method of {@link JdbcIO.WriteVoid}.
*/
public static class Write<T> extends PTransform<PCollection<T>, PDone> {
final WriteVoid<T> inner;
WriteVoid<T> inner;

Write() {
this(JdbcIO.writeVoid());
@@ -948,6 +953,11 @@ public Write<T> withRetryStrategy(RetryStrategy retryStrategy) {
return new Write(inner.withRetryStrategy(retryStrategy));
}

/** See {@link WriteVoid#withTable(String)}. */
public Write<T> withTable(String table) {
return new Write(inner.withTable(table));
}

/**
* Returns a {@link WriteVoid} transform which can be used in {@link Wait#on(PCollection[])} to
* wait until all data is written.
@@ -972,11 +982,142 @@ public void populateDisplayData(DisplayData.Builder builder) {
inner.populateDisplayData(builder);
}

private boolean hasStatementAndSetter() {
return inner.getStatement() != null && inner.getPreparedStatementSetter() != null;
}

@Override
public PDone expand(PCollection<T> input) {
// FIXME: validate the provided table name
if (input.hasSchema() && !hasStatementAndSetter()) {
checkArgument(
inner.getTable() != null, "table cannot be null if statement is not provided");
Schema schema = input.getSchema();
List<SchemaUtil.FieldWithIndex> fields = getFilteredFields(schema);
inner =
inner.withStatement(
JdbcUtil.generateStatement(
inner.getTable(),
fields.stream()
.map(SchemaUtil.FieldWithIndex::getField)
.collect(Collectors.toList())));
inner =
inner.withPreparedStatementSetter(
new AutoGeneratedPreparedStatementSetter(fields, input.getToRowFunction()));
}

inner.expand(input);
return PDone.in(input.getPipeline());
}
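// Editor's note (illustrative, not part of the patch): for a hypothetical table
// "person_table" whose matched fields are (name, id), the generated statement is
// expected to be a parameterized INSERT with one "?" placeholder per field, roughly
// "INSERT INTO person_table VALUES(?, ?)"; the exact SQL comes from
// JdbcUtil.generateStatement, which is not shown in this excerpt.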

private List<SchemaUtil.FieldWithIndex> getFilteredFields(Schema schema) {
Schema tableSchema;

try (Connection connection = inner.getDataSourceProviderFn().apply(null).getConnection();
PreparedStatement statement =
connection.prepareStatement((String.format("SELECT * FROM %s", inner.getTable())))) {
tableSchema = SchemaUtil.toBeamSchema(statement.getMetaData());
statement.close();
} catch (SQLException e) {
throw new RuntimeException(
"Error while determining columns from table: " + inner.getTable(), e);
}

if (tableSchema.getFieldCount() < schema.getFieldCount()) {
throw new RuntimeException("Input schema has more fields than actual table.");
}

// filter out missing fields from output table
List<Schema.Field> missingFields =
tableSchema.getFields().stream()
.filter(
line ->
schema.getFields().stream()
.noneMatch(s -> s.getName().equalsIgnoreCase(line.getName())))
.collect(Collectors.toList());

// allow insert only if missing fields are nullable
if (checkNullabilityForFields(missingFields)) {
throw new RuntimeException("Non nullable fields are not allowed without schema.");
}

List<SchemaUtil.FieldWithIndex> tableFilteredFields =
tableSchema.getFields().stream()
.map(
(tableField) -> {
Optional<Schema.Field> optionalSchemaField =
schema.getFields().stream()
.filter((f) -> SchemaUtil.compareSchemaField(tableField, f))
.findFirst();
return (optionalSchemaField.isPresent())
? SchemaUtil.FieldWithIndex.of(
tableField, schema.getFields().indexOf(optionalSchemaField.get()))
: null;
})
.filter(Objects::nonNull)
.collect(Collectors.toList());

if (tableFilteredFields.size() != schema.getFieldCount()) {
throw new RuntimeException("Provided schema doesn't match with database schema.");
}

return tableFilteredFields;
}

/**
* A {@link org.apache.beam.sdk.io.jdbc.JdbcIO.PreparedStatementSetter} implementation that
* calls the corresponding setters on the prepared statement.
*/
private class AutoGeneratedPreparedStatementSetter implements PreparedStatementSetter<T> {

private List<SchemaUtil.FieldWithIndex> fields;
private SerializableFunction<T, Row> toRowFn;
private List<PreparedStatementSetCaller> preparedStatementFieldSetterList = new ArrayList<>();

AutoGeneratedPreparedStatementSetter(
List<SchemaUtil.FieldWithIndex> fieldsWithIndex, SerializableFunction<T, Row> toRowFn) {
this.fields = fieldsWithIndex;
this.toRowFn = toRowFn;
populatePreparedStatementFieldSetter();
}

private void populatePreparedStatementFieldSetter() {
IntStream.range(0, fields.size())
.forEach(
(index) -> {
Schema.FieldType fieldType = fields.get(index).getField().getType();
preparedStatementFieldSetterList.add(
JdbcUtil.getPreparedStatementSetCaller(fieldType));
});
}

@Override
public void setParameters(T element, PreparedStatement preparedStatement) throws Exception {
Row row = (element instanceof Row) ? (Row) element : toRowFn.apply(element);
IntStream.range(0, fields.size())
.forEach(
(index) -> {
try {
preparedStatementFieldSetterList
.get(index)
.set(row, preparedStatement, index, fields.get(index));
} catch (SQLException | NullPointerException e) {
throw new RuntimeException("Error while setting data to preparedStatement", e);
}
});
}
}
}

/** Interface implemented by functions that set prepared statement data. */
@FunctionalInterface
interface PreparedStatementSetCaller extends Serializable {
void set(
Row element,
PreparedStatement preparedStatement,
int prepareStatementIndex,
SchemaUtil.FieldWithIndex schemaFieldWithIndex)
throws SQLException;
}
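
// Editor's sketch (not part of the patch): because the interface is functional, a
// caller for a single INT32 column can be a lambda that reads the value by the
// field's index in the input Row and binds it to the 1-based JDBC parameter. It is
// assumed here that SchemaUtil.FieldWithIndex#getIndex() exposes the field's
// position in the input schema.
PreparedStatementSetCaller int32Caller =
    (row, preparedStatement, prepareStatementIndex, fieldWithIndex) ->
        preparedStatement.setInt(
            prepareStatementIndex + 1, row.getInt32(fieldWithIndex.getIndex()));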

/** A {@link PTransform} to write to a JDBC datasource. */
@@ -1001,6 +1142,9 @@ public abstract static class WriteVoid<T> extends PTransform<PCollection<T>, PCo
@Nullable
abstract RetryStrategy getRetryStrategy();

@Nullable
abstract String getTable();

abstract Builder<T> toBuilder();

@AutoValue.Builder
@@ -1020,6 +1164,8 @@ abstract Builder<T> setDataSourceProviderFn(

abstract Builder<T> setRetryStrategy(RetryStrategy deadlockPredicate);

abstract Builder<T> setTable(String table);

abstract WriteVoid<T> build();
}

@@ -1065,6 +1211,11 @@ public WriteVoid<T> withRetryStrategy(RetryStrategy retryStrategy) {
return toBuilder().setRetryStrategy(retryStrategy).build();
}

public WriteVoid<T> withTable(String table) {
checkArgument(table != null, "table name cannot be null");
return toBuilder().setTable(table).build();
}

@Override
public PCollection<Void> expand(PCollection<T> input) {
checkArgument(getStatement() != null, "withStatement() is required");