diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/Schema.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/Schema.java
index d659b25067645..b59e437b47402 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/Schema.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/Schema.java
@@ -302,7 +302,7 @@ public boolean assignableTo(Schema other) {
return equivalent(other, EquivalenceNullablePolicy.WEAKEN);
}
- /** Returns true if this Schema can be assigned to another Schema, igmoring nullable. * */
+ /** Returns true if this Schema can be assigned to another Schema, ignoring nullable. * */
public boolean assignableToIgnoreNullable(Schema other) {
return equivalent(other, EquivalenceNullablePolicy.IGNORE);
}
diff --git a/sdks/java/io/common/src/test/java/org/apache/beam/sdk/io/common/DatabaseTestHelper.java b/sdks/java/io/common/src/test/java/org/apache/beam/sdk/io/common/DatabaseTestHelper.java
index 6d610f2ad9fe7..c324c4ddc24e4 100644
--- a/sdks/java/io/common/src/test/java/org/apache/beam/sdk/io/common/DatabaseTestHelper.java
+++ b/sdks/java/io/common/src/test/java/org/apache/beam/sdk/io/common/DatabaseTestHelper.java
@@ -47,6 +47,15 @@ public static void createTable(DataSource dataSource, String tableName) throws S
}
}
+ public static void createTableForRowWithSchema(DataSource dataSource, String tableName)
+ throws SQLException {
+ try (Connection connection = dataSource.getConnection()) {
+ try (Statement statement = connection.createStatement()) {
+ statement.execute(String.format("create table %s (name VARCHAR(500), id INT)", tableName));
+ }
+ }
+ }
+
public static void deleteTable(DataSource dataSource, String tableName) throws SQLException {
if (tableName != null) {
try (Connection connection = dataSource.getConnection();
@@ -69,4 +78,13 @@ public static String getPostgresDBUrl(PostgresIOTestPipelineOptions options) {
options.getPostgresPort(),
options.getPostgresDatabaseName());
}
+
+ public static void createTableWithStatement(DataSource dataSource, String stmt)
+ throws SQLException {
+ try (Connection connection = dataSource.getConnection()) {
+ try (Statement statement = connection.createStatement()) {
+ statement.execute(stmt);
+ }
+ }
+ }
}
diff --git a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java
index 2574c45cbdc3b..18bd9a6bdf5da 100644
--- a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java
+++ b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java
@@ -17,6 +17,7 @@
*/
package org.apache.beam.sdk.io.jdbc;
+import static org.apache.beam.sdk.io.jdbc.SchemaUtil.checkNullabilityForFields;
import static org.apache.beam.vendor.guava.v20_0.com.google.common.base.Preconditions.checkArgument;
import com.google.auto.value.AutoValue;
@@ -28,7 +29,11 @@
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.List;
+import java.util.Objects;
+import java.util.Optional;
import java.util.function.Predicate;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
import javax.annotation.Nullable;
import javax.sql.DataSource;
import org.apache.beam.sdk.annotations.Experimental;
@@ -904,7 +909,7 @@ public interface RetryStrategy extends Serializable {
*
All methods in this class delegate to the appropriate method of {@link JdbcIO.WriteVoid}.
*/
public static class Write extends PTransform, PDone> {
- final WriteVoid inner;
+ WriteVoid inner;
Write() {
this(JdbcIO.writeVoid());
@@ -948,6 +953,11 @@ public Write withRetryStrategy(RetryStrategy retryStrategy) {
return new Write(inner.withRetryStrategy(retryStrategy));
}
+ /** See {@link WriteVoid#withTable(String)}. */
+ public Write withTable(String table) {
+ return new Write(inner.withTable(table));
+ }
+
/**
* Returns {@link WriteVoid} transform which can be used in {@link Wait#on(PCollection[])} to
* wait until all data is written.
@@ -972,11 +982,142 @@ public void populateDisplayData(DisplayData.Builder builder) {
inner.populateDisplayData(builder);
}
+ private boolean hasStatementAndSetter() {
+ return inner.getStatement() != null && inner.getPreparedStatementSetter() != null;
+ }
+
@Override
public PDone expand(PCollection input) {
+ // fixme: validate invalid table input
+ if (input.hasSchema() && !hasStatementAndSetter()) {
+ checkArgument(
+ inner.getTable() != null, "table cannot be null if statement is not provided");
+ Schema schema = input.getSchema();
+ List fields = getFilteredFields(schema);
+ inner =
+ inner.withStatement(
+ JdbcUtil.generateStatement(
+ inner.getTable(),
+ fields.stream()
+ .map(SchemaUtil.FieldWithIndex::getField)
+ .collect(Collectors.toList())));
+ inner =
+ inner.withPreparedStatementSetter(
+ new AutoGeneratedPreparedStatementSetter(fields, input.getToRowFunction()));
+ }
+
inner.expand(input);
return PDone.in(input.getPipeline());
}
+
+ private List getFilteredFields(Schema schema) {
+ Schema tableSchema;
+
+ try (Connection connection = inner.getDataSourceProviderFn().apply(null).getConnection();
+ PreparedStatement statement =
+ connection.prepareStatement((String.format("SELECT * FROM %s", inner.getTable())))) {
+ tableSchema = SchemaUtil.toBeamSchema(statement.getMetaData());
+ statement.close();
+ } catch (SQLException e) {
+ throw new RuntimeException(
+ "Error while determining columns from table: " + inner.getTable(), e);
+ }
+
+ if (tableSchema.getFieldCount() < schema.getFieldCount()) {
+ throw new RuntimeException("Input schema has more fields than actual table.");
+ }
+
+ // filter out missing fields from output table
+ List missingFields =
+ tableSchema.getFields().stream()
+ .filter(
+ line ->
+ schema.getFields().stream()
+ .noneMatch(s -> s.getName().equalsIgnoreCase(line.getName())))
+ .collect(Collectors.toList());
+
+ // allow insert only if missing fields are nullable
+ if (checkNullabilityForFields(missingFields)) {
+ throw new RuntimeException("Non nullable fields are not allowed without schema.");
+ }
+
+ List tableFilteredFields =
+ tableSchema.getFields().stream()
+ .map(
+ (tableField) -> {
+ Optional optionalSchemaField =
+ schema.getFields().stream()
+ .filter((f) -> SchemaUtil.compareSchemaField(tableField, f))
+ .findFirst();
+ return (optionalSchemaField.isPresent())
+ ? SchemaUtil.FieldWithIndex.of(
+ tableField, schema.getFields().indexOf(optionalSchemaField.get()))
+ : null;
+ })
+ .filter(Objects::nonNull)
+ .collect(Collectors.toList());
+
+ if (tableFilteredFields.size() != schema.getFieldCount()) {
+ throw new RuntimeException("Provided schema doesn't match with database schema.");
+ }
+
+ return tableFilteredFields;
+ }
+
+ /**
+ * A {@link org.apache.beam.sdk.io.jdbc.JdbcIO.PreparedStatementSetter} implementation that
+ * calls related setters on prepared statement.
+ */
+ private class AutoGeneratedPreparedStatementSetter implements PreparedStatementSetter {
+
+ private List fields;
+ private SerializableFunction toRowFn;
+ private List preparedStatementFieldSetterList = new ArrayList<>();
+
+ AutoGeneratedPreparedStatementSetter(
+ List fieldsWithIndex, SerializableFunction toRowFn) {
+ this.fields = fieldsWithIndex;
+ this.toRowFn = toRowFn;
+ populatePreparedStatementFieldSetter();
+ }
+
+ private void populatePreparedStatementFieldSetter() {
+ IntStream.range(0, fields.size())
+ .forEach(
+ (index) -> {
+ Schema.FieldType fieldType = fields.get(index).getField().getType();
+ preparedStatementFieldSetterList.add(
+ JdbcUtil.getPreparedStatementSetCaller(fieldType));
+ });
+ }
+
+ @Override
+ public void setParameters(T element, PreparedStatement preparedStatement) throws Exception {
+ Row row = (element instanceof Row) ? (Row) element : toRowFn.apply(element);
+ IntStream.range(0, fields.size())
+ .forEach(
+ (index) -> {
+ try {
+ preparedStatementFieldSetterList
+ .get(index)
+ .set(row, preparedStatement, index, fields.get(index));
+ } catch (SQLException | NullPointerException e) {
+ throw new RuntimeException("Error while setting data to preparedStatement", e);
+ }
+ });
+ }
+ }
+ }
+
+ /** Interface implemented by functions that sets prepared statement data. */
+ @FunctionalInterface
+ interface PreparedStatementSetCaller extends Serializable {
+ void set(
+ Row element,
+ PreparedStatement preparedStatement,
+ int prepareStatementIndex,
+ SchemaUtil.FieldWithIndex schemaFieldWithIndex)
+ throws SQLException;
}
/** A {@link PTransform} to write to a JDBC datasource. */
@@ -1001,6 +1142,9 @@ public abstract static class WriteVoid extends PTransform, PCo
@Nullable
abstract RetryStrategy getRetryStrategy();
+ @Nullable
+ abstract String getTable();
+
abstract Builder toBuilder();
@AutoValue.Builder
@@ -1020,6 +1164,8 @@ abstract Builder setDataSourceProviderFn(
abstract Builder setRetryStrategy(RetryStrategy deadlockPredicate);
+ abstract Builder setTable(String table);
+
abstract WriteVoid build();
}
@@ -1065,6 +1211,11 @@ public WriteVoid withRetryStrategy(RetryStrategy retryStrategy) {
return toBuilder().setRetryStrategy(retryStrategy).build();
}
+ public WriteVoid withTable(String table) {
+ checkArgument(table != null, "table name can not be null");
+ return toBuilder().setTable(table).build();
+ }
+
@Override
public PCollection expand(PCollection input) {
checkArgument(getStatement() != null, "withStatement() is required");
diff --git a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcUtil.java b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcUtil.java
new file mode 100644
index 0000000000000..17fdae73c91f3
--- /dev/null
+++ b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcUtil.java
@@ -0,0 +1,235 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.jdbc;
+
+import java.sql.Date;
+import java.sql.JDBCType;
+import java.sql.Time;
+import java.sql.Timestamp;
+import java.util.Calendar;
+import java.util.EnumMap;
+import java.util.List;
+import java.util.Map;
+import java.util.TimeZone;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+import org.apache.beam.sdk.schemas.Schema;
+import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.ImmutableMap;
+import org.joda.time.DateTime;
+
+/** Provides utility functions for working with {@link JdbcIO}. */
+public class JdbcUtil {
+
+ /** Generates an insert statement based on {@Link Schema.Field}. * */
+ public static String generateStatement(String tableName, List fields) {
+
+ String fieldNames =
+ IntStream.range(0, fields.size())
+ .mapToObj(
+ (index) -> {
+ return fields.get(index).getName();
+ })
+ .collect(Collectors.joining(", "));
+
+ String valuePlaceholder =
+ IntStream.range(0, fields.size())
+ .mapToObj(
+ (index) -> {
+ return "?";
+ })
+ .collect(Collectors.joining(", "));
+
+ return String.format("INSERT INTO %s(%s) VALUES(%s)", tableName, fieldNames, valuePlaceholder);
+ }
+
+ /** PreparedStatementSetCaller for Schema Field types. * */
+ private static Map typeNamePsSetCallerMap =
+ new EnumMap(
+ ImmutableMap.builder()
+ .put(
+ Schema.TypeName.BYTE,
+ (element, ps, i, fieldWithIndex) ->
+ ps.setByte(i + 1, element.getByte(fieldWithIndex.getIndex())))
+ .put(
+ Schema.TypeName.INT16,
+ (element, ps, i, fieldWithIndex) ->
+ ps.setInt(i + 1, element.getInt16(fieldWithIndex.getIndex())))
+ .put(
+ Schema.TypeName.INT64,
+ (element, ps, i, fieldWithIndex) ->
+ ps.setLong(i + 1, element.getInt64(fieldWithIndex.getIndex())))
+ .put(
+ Schema.TypeName.DECIMAL,
+ (element, ps, i, fieldWithIndex) ->
+ ps.setBigDecimal(i + 1, element.getDecimal(fieldWithIndex.getIndex())))
+ .put(
+ Schema.TypeName.FLOAT,
+ (element, ps, i, fieldWithIndex) ->
+ ps.setFloat(i + 1, element.getFloat(fieldWithIndex.getIndex())))
+ .put(
+ Schema.TypeName.DOUBLE,
+ (element, ps, i, fieldWithIndex) ->
+ ps.setDouble(i + 1, element.getDouble(fieldWithIndex.getIndex())))
+ .put(
+ Schema.TypeName.DATETIME,
+ (element, ps, i, fieldWithIndex) ->
+ ps.setTimestamp(
+ i + 1,
+ new Timestamp(
+ element.getDateTime(fieldWithIndex.getIndex()).getMillis())))
+ .put(
+ Schema.TypeName.BOOLEAN,
+ (element, ps, i, fieldWithIndex) ->
+ ps.setBoolean(i + 1, element.getBoolean(fieldWithIndex.getIndex())))
+ .put(Schema.TypeName.BYTES, createBytesCaller())
+ .put(
+ Schema.TypeName.INT32,
+ (element, ps, i, fieldWithIndex) ->
+ ps.setInt(i + 1, element.getInt32(fieldWithIndex.getIndex())))
+ .put(Schema.TypeName.STRING, createStringCaller())
+ .build());
+
+ /** PreparedStatementSetCaller for Schema Field Logical types. * */
+ public static JdbcIO.PreparedStatementSetCaller getPreparedStatementSetCaller(
+ Schema.FieldType fieldType) {
+ switch (fieldType.getTypeName()) {
+ case ARRAY:
+ return (element, ps, i, fieldWithIndex) -> {
+ ps.setArray(
+ i + 1,
+ ps.getConnection()
+ .createArrayOf(
+ fieldType.getCollectionElementType().getTypeName().name(),
+ element.getArray(fieldWithIndex.getIndex()).toArray()));
+ };
+ case LOGICAL_TYPE:
+ {
+ String logicalTypeName = fieldType.getLogicalType().getIdentifier();
+ JDBCType jdbcType = JDBCType.valueOf(logicalTypeName);
+ switch (jdbcType) {
+ case DATE:
+ return (element, ps, i, fieldWithIndex) -> {
+ ps.setDate(
+ i + 1,
+ new Date(
+ getDateOrTimeOnly(
+ element.getDateTime(fieldWithIndex.getIndex()).toDateTime(), true)
+ .getTime()
+ .getTime()));
+ };
+ case TIME:
+ return (element, ps, i, fieldWithIndex) -> {
+ ps.setTime(
+ i + 1,
+ new Time(
+ getDateOrTimeOnly(
+ element.getDateTime(fieldWithIndex.getIndex()).toDateTime(), false)
+ .getTime()
+ .getTime()));
+ };
+ case TIMESTAMP_WITH_TIMEZONE:
+ return (element, ps, i, fieldWithIndex) -> {
+ Calendar calendar =
+ withTimestampAndTimezone(
+ element.getDateTime(fieldWithIndex.getIndex()).toDateTime());
+ ps.setTimestamp(i + 1, new Timestamp(calendar.getTime().getTime()), calendar);
+ };
+ default:
+ return getPreparedStatementSetCaller(fieldType.getLogicalType().getBaseType());
+ }
+ }
+ default:
+ {
+ if (typeNamePsSetCallerMap.containsKey(fieldType.getTypeName())) {
+ return typeNamePsSetCallerMap.get(fieldType.getTypeName());
+ } else {
+ throw new RuntimeException(
+ fieldType.getTypeName().name()
+ + " in schema is not supported while writing. Please provide statement and preparedStatementSetter");
+ }
+ }
+ }
+ }
+
+ private static JdbcIO.PreparedStatementSetCaller createBytesCaller() {
+ return (element, ps, i, fieldWithIndex) -> {
+ validateLogicalTypeLength(
+ fieldWithIndex.getField(), element.getBytes(fieldWithIndex.getIndex()).length);
+ ps.setBytes(i + 1, element.getBytes(fieldWithIndex.getIndex()));
+ };
+ }
+
+ private static JdbcIO.PreparedStatementSetCaller createStringCaller() {
+ return (element, ps, i, fieldWithIndex) -> {
+ validateLogicalTypeLength(
+ fieldWithIndex.getField(), element.getString(fieldWithIndex.getIndex()).length());
+ ps.setString(i + 1, element.getString(fieldWithIndex.getIndex()));
+ };
+ }
+
+ private static void validateLogicalTypeLength(Schema.Field field, Integer length) {
+ try {
+ if (field.getType().getTypeName().isLogicalType()
+ && !field.getType().getLogicalType().getArgument().isEmpty()) {
+ int maxLimit = Integer.parseInt(field.getType().getLogicalType().getArgument());
+ if (field.getType().getTypeName().isLogicalType() && length >= maxLimit) {
+ throw new RuntimeException(
+ String.format(
+ "Length of Schema.Field[%s] data exceeds database column capacity",
+ field.getName()));
+ }
+ }
+ } catch (NumberFormatException e) {
+ // if argument is not set or not integer then do nothing and proceed with the insertion
+ }
+ }
+
+ private static Calendar getDateOrTimeOnly(DateTime dateTime, boolean wantDateOnly) {
+ Calendar cal = Calendar.getInstance();
+ cal.setTimeZone(TimeZone.getTimeZone(dateTime.getZone().getID()));
+
+ if (wantDateOnly) { // return date only
+ cal.set(Calendar.YEAR, dateTime.getYear());
+ cal.set(Calendar.MONTH, dateTime.getMonthOfYear() - 1);
+ cal.set(Calendar.DATE, dateTime.getDayOfMonth());
+
+ cal.set(Calendar.HOUR_OF_DAY, 0);
+ cal.set(Calendar.MINUTE, 0);
+ cal.set(Calendar.SECOND, 0);
+ cal.set(Calendar.MILLISECOND, 0);
+ } else { // return time only
+ cal.set(Calendar.YEAR, 1970);
+ cal.set(Calendar.MONTH, Calendar.JANUARY);
+ cal.set(Calendar.DATE, 1);
+
+ cal.set(Calendar.HOUR_OF_DAY, dateTime.getHourOfDay());
+ cal.set(Calendar.MINUTE, dateTime.getMinuteOfHour());
+ cal.set(Calendar.SECOND, dateTime.getSecondOfMinute());
+ cal.set(Calendar.MILLISECOND, dateTime.getMillisOfSecond());
+ }
+
+ return cal;
+ }
+
+ private static Calendar withTimestampAndTimezone(DateTime dateTime) {
+ Calendar calendar = Calendar.getInstance(TimeZone.getTimeZone(dateTime.getZone().getID()));
+ calendar.setTimeInMillis(dateTime.getMillis());
+
+ return calendar;
+ }
+}
diff --git a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/SchemaUtil.java b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/SchemaUtil.java
index 9c1dcb050a7e5..99fa06f7ed727 100644
--- a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/SchemaUtil.java
+++ b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/SchemaUtil.java
@@ -28,6 +28,7 @@
import static java.sql.JDBCType.VARBINARY;
import static java.sql.JDBCType.VARCHAR;
import static java.sql.JDBCType.valueOf;
+import static org.apache.beam.vendor.guava.v20_0.com.google.common.base.Preconditions.checkArgument;
import java.io.Serializable;
import java.sql.Array;
@@ -338,4 +339,73 @@ public Row mapRow(ResultSet rs) throws Exception {
return rowBuilder.build();
}
}
+
+ /**
+ * compares two fields. Does not compare nullability of field types.
+ *
+ * @param a field 1
+ * @param b field 2
+ * @return TRUE if fields are equal. Otherwise FALSE
+ */
+ public static boolean compareSchemaField(Schema.Field a, Schema.Field b) {
+ if (!a.getName().equalsIgnoreCase(b.getName())) {
+ return false;
+ }
+
+ return compareSchemaFieldType(a.getType(), b.getType());
+ }
+
+ /**
+ * checks nullability for fields.
+ *
+ * @param fields
+ * @return TRUE if any field is not nullable
+ */
+ public static boolean checkNullabilityForFields(List fields) {
+ return fields.stream().anyMatch(field -> !field.getType().getNullable());
+ }
+
+ /**
+ * compares two FieldType. Does not compare nullability.
+ *
+ * @param a FieldType 1
+ * @param b FieldType 2
+ * @return TRUE if FieldType are equal. Otherwise FALSE
+ */
+ public static boolean compareSchemaFieldType(Schema.FieldType a, Schema.FieldType b) {
+ if (a.getTypeName().equals(b.getTypeName())) {
+ return !a.getTypeName().equals(Schema.TypeName.LOGICAL_TYPE)
+ || compareSchemaFieldType(
+ a.getLogicalType().getBaseType(), b.getLogicalType().getBaseType());
+ } else if (a.getTypeName().isLogicalType()) {
+ return a.getLogicalType().getBaseType().getTypeName().equals(b.getTypeName());
+ } else if (b.getTypeName().isLogicalType()) {
+ return b.getLogicalType().getBaseType().getTypeName().equals(a.getTypeName());
+ }
+ return false;
+ }
+
+ static class FieldWithIndex implements Serializable {
+ private final Schema.Field field;
+ private final Integer index;
+
+ private FieldWithIndex(Schema.Field field, Integer index) {
+ this.field = field;
+ this.index = index;
+ }
+
+ public static FieldWithIndex of(Schema.Field field, Integer index) {
+ checkArgument(field != null);
+ checkArgument(index != null);
+ return new FieldWithIndex(field, index);
+ }
+
+ public Schema.Field getField() {
+ return field;
+ }
+
+ public Integer getIndex() {
+ return index;
+ }
+ }
}
diff --git a/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOTest.java b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOTest.java
index ecd26a139486c..0e815339bcc29 100644
--- a/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOTest.java
+++ b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOTest.java
@@ -19,19 +19,35 @@
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
+import static org.mockito.Matchers.any;
+import static org.mockito.Matchers.anyString;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
import java.io.PrintWriter;
import java.io.Serializable;
import java.io.StringWriter;
+import java.math.BigDecimal;
import java.net.InetAddress;
+import java.nio.charset.Charset;
+import java.sql.Array;
import java.sql.Connection;
+import java.sql.Date;
import java.sql.JDBCType;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
+import java.sql.Time;
+import java.sql.Timestamp;
import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Calendar;
import java.util.Collections;
+import java.util.List;
+import java.util.TimeZone;
import javax.sql.DataSource;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.coders.SerializableCoder;
@@ -56,12 +72,16 @@
import org.apache.commons.dbcp2.PoolingDataSource;
import org.apache.derby.drda.NetworkServerControl;
import org.apache.derby.jdbc.ClientDataSource;
+import org.joda.time.DateTime;
+import org.joda.time.LocalDate;
+import org.joda.time.chrono.ISOChronology;
import org.junit.After;
import org.junit.AfterClass;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Rule;
import org.junit.Test;
+import org.junit.rules.ExpectedException;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.slf4j.Logger;
@@ -85,6 +105,8 @@ public class JdbcIOTest implements Serializable {
@Rule public final transient ExpectedLogs expectedLogs = ExpectedLogs.none(JdbcIO.class);
+ @Rule public transient ExpectedException thrown = ExpectedException.none();
+
@BeforeClass
public static void beforeClass() throws Exception {
port = NetworkTestHelper.getAvailableLocalPort();
@@ -490,6 +512,345 @@ public void tearDown() {
}
}
+ @Test
+ public void testWriteWithoutPreparedStatement() throws Exception {
+ final int rowsToAdd = 10;
+
+ Schema.Builder schemaBuilder = Schema.builder();
+ schemaBuilder.addField(Schema.Field.of("column_boolean", Schema.FieldType.BOOLEAN));
+ schemaBuilder.addField(Schema.Field.of("column_string", Schema.FieldType.STRING));
+ schemaBuilder.addField(Schema.Field.of("column_int", Schema.FieldType.INT32));
+ schemaBuilder.addField(Schema.Field.of("column_long", Schema.FieldType.INT64));
+ schemaBuilder.addField(Schema.Field.of("column_float", Schema.FieldType.FLOAT));
+ schemaBuilder.addField(Schema.Field.of("column_double", Schema.FieldType.DOUBLE));
+ schemaBuilder.addField(Schema.Field.of("column_bigdecimal", Schema.FieldType.DECIMAL));
+ schemaBuilder.addField(Schema.Field.of("column_date", LogicalTypes.JDBC_DATE_TYPE));
+ schemaBuilder.addField(Schema.Field.of("column_time", LogicalTypes.JDBC_TIME_TYPE));
+ schemaBuilder.addField(
+ Schema.Field.of("column_timestamptz", LogicalTypes.JDBC_TIMESTAMP_WITH_TIMEZONE_TYPE));
+ schemaBuilder.addField(Schema.Field.of("column_timestamp", Schema.FieldType.DATETIME));
+ schemaBuilder.addField(Schema.Field.of("column_short", Schema.FieldType.INT16));
+ Schema schema = schemaBuilder.build();
+
+ String tableName = DatabaseTestHelper.getTestTableName("UT_WRITE_PS");
+ StringBuilder stmt = new StringBuilder("CREATE TABLE ");
+ stmt.append(tableName);
+ stmt.append(" (");
+ stmt.append("column_boolean BOOLEAN,"); // boolean
+ stmt.append("column_string VARCHAR(254),"); // String
+ stmt.append("column_int INTEGER,"); // int
+ stmt.append("column_long BIGINT,"); // long
+ stmt.append("column_float REAL,"); // float
+ stmt.append("column_double DOUBLE PRECISION,"); // double
+ stmt.append("column_bigdecimal DECIMAL(13,0),"); // BigDecimal
+ stmt.append("column_date DATE,"); // Date
+ stmt.append("column_time TIME,"); // Time
+ stmt.append("column_timestamptz TIMESTAMP,"); // Timestamp
+ stmt.append("column_timestamp TIMESTAMP,"); // Timestamp
+ stmt.append("column_short SMALLINT"); // short
+ stmt.append(" )");
+ DatabaseTestHelper.createTableWithStatement(dataSource, stmt.toString());
+ try {
+ ArrayList data = getRowsToWrite(rowsToAdd, schema);
+ pipeline
+ .apply(Create.of(data))
+ .setRowSchema(schema)
+ .apply(
+ JdbcIO.write()
+ .withDataSourceConfiguration(
+ JdbcIO.DataSourceConfiguration.create(
+ "org.apache.derby.jdbc.ClientDriver",
+ "jdbc:derby://localhost:" + port + "/target/beam"))
+ .withBatchSize(10L)
+ .withTable(tableName));
+ pipeline.run();
+ assertRowCount(tableName, rowsToAdd);
+ } finally {
+ DatabaseTestHelper.deleteTable(dataSource, tableName);
+ }
+ }
+
+ @Test
+ public void testWriteWithoutPreparedStatementWithReadRows() throws Exception {
+ SerializableFunction dataSourceProvider = ignored -> dataSource;
+ PCollection rows =
+ pipeline.apply(
+ JdbcIO.readRows()
+ .withDataSourceProviderFn(dataSourceProvider)
+ .withQuery(String.format("select name,id from %s where name = ?", readTableName))
+ .withStatementPreparator(
+ preparedStatement ->
+ preparedStatement.setString(1, TestRow.getNameForSeed(1))));
+
+ String writeTableName = DatabaseTestHelper.getTestTableName("UT_WRITE_PS_WITH_READ_ROWS");
+ DatabaseTestHelper.createTableForRowWithSchema(dataSource, writeTableName);
+ try {
+ rows.apply(
+ JdbcIO.write()
+ .withDataSourceConfiguration(
+ JdbcIO.DataSourceConfiguration.create(
+ "org.apache.derby.jdbc.ClientDriver",
+ "jdbc:derby://localhost:" + port + "/target/beam"))
+ .withBatchSize(10L)
+ .withTable(writeTableName));
+ pipeline.run();
+ } finally {
+ DatabaseTestHelper.deleteTable(dataSource, writeTableName);
+ }
+ }
+
+ @Test
+ public void testWriteWithoutPsWithNonNullableTableField() throws Exception {
+ final int rowsToAdd = 10;
+
+ Schema.Builder schemaBuilder = Schema.builder();
+ schemaBuilder.addField(Schema.Field.of("column_boolean", Schema.FieldType.BOOLEAN));
+ schemaBuilder.addField(Schema.Field.of("column_string", Schema.FieldType.STRING));
+ Schema schema = schemaBuilder.build();
+
+ String tableName = DatabaseTestHelper.getTestTableName("UT_WRITE");
+ StringBuilder stmt = new StringBuilder("CREATE TABLE ");
+ stmt.append(tableName);
+ stmt.append(" (");
+ stmt.append("column_boolean BOOLEAN,");
+ stmt.append("column_int INTEGER NOT NULL");
+ stmt.append(" )");
+ DatabaseTestHelper.createTableWithStatement(dataSource, stmt.toString());
+ try {
+ ArrayList data = getRowsToWrite(rowsToAdd, schema);
+ pipeline
+ .apply(Create.of(data))
+ .setRowSchema(schema)
+ .apply(
+ JdbcIO.write()
+ .withDataSourceConfiguration(
+ JdbcIO.DataSourceConfiguration.create(
+ "org.apache.derby.jdbc.ClientDriver",
+ "jdbc:derby://localhost:" + port + "/target/beam"))
+ .withBatchSize(10L)
+ .withTable(tableName));
+ pipeline.run();
+ } finally {
+ DatabaseTestHelper.deleteTable(dataSource, tableName);
+ thrown.expect(RuntimeException.class);
+ }
+ }
+
+ @Test
+ public void testWriteWithoutPreparedStatementAndNonRowType() throws Exception {
+ final int rowsToAdd = 10;
+
+ String tableName = DatabaseTestHelper.getTestTableName("UT_WRITE_PS_NON_ROW");
+ DatabaseTestHelper.createTableForRowWithSchema(dataSource, tableName);
+ try {
+ List data = getRowsWithSchemaToWrite(rowsToAdd);
+
+ pipeline
+ .apply(Create.of(data))
+ .apply(
+ JdbcIO.write()
+ .withDataSourceConfiguration(
+ JdbcIO.DataSourceConfiguration.create(
+ "org.apache.derby.jdbc.ClientDriver",
+ "jdbc:derby://localhost:" + port + "/target/beam"))
+ .withBatchSize(10L)
+ .withTable(tableName));
+ pipeline.run();
+ assertRowCount(tableName, rowsToAdd);
+ } finally {
+ DatabaseTestHelper.deleteTable(dataSource, tableName);
+ }
+ }
+
+ @Test
+ public void testGetPreparedStatementSetCaller() throws Exception {
+
+ Schema schema =
+ Schema.builder()
+ .addField("bigint_col", Schema.FieldType.INT64)
+ .addField("binary_col", Schema.FieldType.BYTES)
+ .addField("bit_col", Schema.FieldType.BOOLEAN)
+ .addField("char_col", Schema.FieldType.STRING)
+ .addField("decimal_col", Schema.FieldType.DECIMAL)
+ .addField("double_col", Schema.FieldType.DOUBLE)
+ .addField("float_col", Schema.FieldType.FLOAT)
+ .addField("integer_col", Schema.FieldType.INT32)
+ .addField("datetime_col", Schema.FieldType.DATETIME)
+ .addField("int16_col", Schema.FieldType.INT16)
+ .addField("byte_col", Schema.FieldType.BYTE)
+ .build();
+ Row row =
+ Row.withSchema(schema)
+ .addValues(
+ 42L,
+ "binary".getBytes(Charset.forName("UTF-8")),
+ true,
+ "char",
+ BigDecimal.valueOf(25L),
+ 20.5D,
+ 15.5F,
+ 10,
+ new DateTime(),
+ (short) 5,
+ Byte.parseByte("1", 2))
+ .build();
+
+ PreparedStatement psMocked = mock(PreparedStatement.class);
+
+ JdbcUtil.getPreparedStatementSetCaller(Schema.FieldType.INT64)
+ .set(row, psMocked, 0, SchemaUtil.FieldWithIndex.of(schema.getField(0), 0));
+ JdbcUtil.getPreparedStatementSetCaller(Schema.FieldType.BYTES)
+ .set(row, psMocked, 1, SchemaUtil.FieldWithIndex.of(schema.getField(1), 1));
+ JdbcUtil.getPreparedStatementSetCaller(Schema.FieldType.BOOLEAN)
+ .set(row, psMocked, 2, SchemaUtil.FieldWithIndex.of(schema.getField(2), 2));
+ JdbcUtil.getPreparedStatementSetCaller(Schema.FieldType.STRING)
+ .set(row, psMocked, 3, SchemaUtil.FieldWithIndex.of(schema.getField(3), 3));
+ JdbcUtil.getPreparedStatementSetCaller(Schema.FieldType.DECIMAL)
+ .set(row, psMocked, 4, SchemaUtil.FieldWithIndex.of(schema.getField(4), 4));
+ JdbcUtil.getPreparedStatementSetCaller(Schema.FieldType.DOUBLE)
+ .set(row, psMocked, 5, SchemaUtil.FieldWithIndex.of(schema.getField(5), 5));
+ JdbcUtil.getPreparedStatementSetCaller(Schema.FieldType.FLOAT)
+ .set(row, psMocked, 6, SchemaUtil.FieldWithIndex.of(schema.getField(6), 6));
+ JdbcUtil.getPreparedStatementSetCaller(Schema.FieldType.INT32)
+ .set(row, psMocked, 7, SchemaUtil.FieldWithIndex.of(schema.getField(7), 7));
+ JdbcUtil.getPreparedStatementSetCaller(Schema.FieldType.DATETIME)
+ .set(row, psMocked, 8, SchemaUtil.FieldWithIndex.of(schema.getField(8), 8));
+ JdbcUtil.getPreparedStatementSetCaller(Schema.FieldType.INT16)
+ .set(row, psMocked, 9, SchemaUtil.FieldWithIndex.of(schema.getField(9), 9));
+ JdbcUtil.getPreparedStatementSetCaller(Schema.FieldType.BYTE)
+ .set(row, psMocked, 10, SchemaUtil.FieldWithIndex.of(schema.getField(10), 10));
+
+ verify(psMocked, times(1)).setLong(1, 42L);
+ verify(psMocked, times(1)).setBytes(2, "binary".getBytes(Charset.forName("UTF-8")));
+ verify(psMocked, times(1)).setBoolean(3, true);
+ verify(psMocked, times(1)).setString(4, "char");
+ verify(psMocked, times(1)).setBigDecimal(5, BigDecimal.valueOf(25L));
+ verify(psMocked, times(1)).setDouble(6, 20.5D);
+ verify(psMocked, times(1)).setFloat(7, 15.5F);
+ verify(psMocked, times(1)).setInt(8, 10);
+ verify(psMocked, times(1))
+ .setTimestamp(9, new Timestamp(row.getDateTime("datetime_col").getMillis()));
+ verify(psMocked, times(1)).setInt(10, (short) 5);
+ verify(psMocked, times(1)).setByte(11, Byte.parseByte("1", 2));
+ }
+
+ @Test
+ public void testGetPreparedStatementSetCallerForLogicalTypes() throws Exception {
+
+ Schema schema =
+ Schema.builder()
+ .addField("logical_date_col", LogicalTypes.JDBC_DATE_TYPE)
+ .addField("logical_time_col", LogicalTypes.JDBC_TIME_TYPE)
+ .addField("logical_time_with_tz_col", LogicalTypes.JDBC_TIMESTAMP_WITH_TIMEZONE_TYPE)
+ .build();
+
+ long epochMilli = 1558719710000L;
+ DateTime dateTime = new DateTime(epochMilli, ISOChronology.getInstanceUTC());
+
+ Row row =
+ Row.withSchema(schema)
+ .addValues(
+ dateTime.withTimeAtStartOfDay(), dateTime.withDate(new LocalDate(0L)), dateTime)
+ .build();
+
+ PreparedStatement psMocked = mock(PreparedStatement.class);
+
+ JdbcUtil.getPreparedStatementSetCaller(LogicalTypes.JDBC_DATE_TYPE)
+ .set(row, psMocked, 0, SchemaUtil.FieldWithIndex.of(schema.getField(0), 0));
+ JdbcUtil.getPreparedStatementSetCaller(LogicalTypes.JDBC_TIME_TYPE)
+ .set(row, psMocked, 1, SchemaUtil.FieldWithIndex.of(schema.getField(1), 1));
+ JdbcUtil.getPreparedStatementSetCaller(LogicalTypes.JDBC_TIMESTAMP_WITH_TIMEZONE_TYPE)
+ .set(row, psMocked, 2, SchemaUtil.FieldWithIndex.of(schema.getField(2), 2));
+
+ verify(psMocked, times(1)).setDate(1, new Date(row.getDateTime(0).getMillis()));
+ verify(psMocked, times(1)).setTime(2, new Time(row.getDateTime(1).getMillis()));
+
+ Calendar cal = Calendar.getInstance(TimeZone.getTimeZone("UTC"));
+ cal.setTimeInMillis(epochMilli);
+
+ verify(psMocked, times(1)).setTimestamp(3, new Timestamp(cal.getTime().getTime()), cal);
+ }
+
+ @Test
+ public void testGetPreparedStatementSetCallerForArray() throws Exception {
+
+ Schema schema =
+ Schema.builder()
+ .addField("string_array_col", Schema.FieldType.array(Schema.FieldType.STRING))
+ .build();
+
+ List stringList = Arrays.asList("string 1", "string 2");
+
+ Row row = Row.withSchema(schema).addValues(stringList).build();
+
+ PreparedStatement psMocked = mock(PreparedStatement.class);
+ Connection connectionMocked = mock(Connection.class);
+ Array arrayMocked = mock(Array.class);
+
+ when(psMocked.getConnection()).thenReturn(connectionMocked);
+ when(connectionMocked.createArrayOf(anyString(), any())).thenReturn(arrayMocked);
+
+ JdbcUtil.getPreparedStatementSetCaller(Schema.FieldType.array(Schema.FieldType.STRING))
+ .set(row, psMocked, 0, SchemaUtil.FieldWithIndex.of(schema.getField(0), 0));
+
+ verify(psMocked, times(1)).setArray(1, arrayMocked);
+ }
+
+ private static ArrayList getRowsToWrite(long rowsToAdd, Schema schema) {
+
+ ArrayList data = new ArrayList<>();
+ for (int i = 0; i < rowsToAdd; i++) {
+ List