diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/Schema.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/Schema.java
index d659b25067645..b59e437b47402 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/Schema.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/schemas/Schema.java
@@ -302,7 +302,7 @@ public boolean assignableTo(Schema other) {
     return equivalent(other, EquivalenceNullablePolicy.WEAKEN);
   }
 
-  /** Returns true if this Schema can be assigned to another Schema, igmoring nullable. * */
+  /** Returns true if this Schema can be assigned to another Schema, ignoring nullable. * */
   public boolean assignableToIgnoreNullable(Schema other) {
     return equivalent(other, EquivalenceNullablePolicy.IGNORE);
   }
diff --git a/sdks/java/io/common/src/test/java/org/apache/beam/sdk/io/common/DatabaseTestHelper.java b/sdks/java/io/common/src/test/java/org/apache/beam/sdk/io/common/DatabaseTestHelper.java
index 6d610f2ad9fe7..c324c4ddc24e4 100644
--- a/sdks/java/io/common/src/test/java/org/apache/beam/sdk/io/common/DatabaseTestHelper.java
+++ b/sdks/java/io/common/src/test/java/org/apache/beam/sdk/io/common/DatabaseTestHelper.java
@@ -47,6 +47,15 @@ public static void createTable(DataSource dataSource, String tableName) throws S
     }
   }
 
+  public static void createTableForRowWithSchema(DataSource dataSource, String tableName)
+      throws SQLException {
+    try (Connection connection = dataSource.getConnection()) {
+      try (Statement statement = connection.createStatement()) {
+        statement.execute(String.format("create table %s (name VARCHAR(500), id INT)", tableName));
+      }
+    }
+  }
+
   public static void deleteTable(DataSource dataSource, String tableName) throws SQLException {
     if (tableName != null) {
       try (Connection connection = dataSource.getConnection();
@@ -69,4 +78,13 @@ public static String getPostgresDBUrl(PostgresIOTestPipelineOptions options) {
         options.getPostgresPort(),
         options.getPostgresDatabaseName());
   }
+
+  public static void createTableWithStatement(DataSource dataSource, String stmt)
+      throws SQLException {
+    try (Connection connection = dataSource.getConnection()) {
+      try (Statement statement = connection.createStatement()) {
+        statement.execute(stmt);
+      }
+    }
+  }
 }
diff --git a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java
index 2574c45cbdc3b..18bd9a6bdf5da 100644
--- a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java
+++ b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java
@@ -17,6 +17,7 @@
  */
 package org.apache.beam.sdk.io.jdbc;
 
+import static org.apache.beam.sdk.io.jdbc.SchemaUtil.checkNullabilityForFields;
 import static org.apache.beam.vendor.guava.v20_0.com.google.common.base.Preconditions.checkArgument;
 
 import com.google.auto.value.AutoValue;
@@ -28,7 +29,11 @@
 import java.sql.SQLException;
 import java.util.ArrayList;
 import java.util.List;
+import java.util.Objects;
+import java.util.Optional;
 import java.util.function.Predicate;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
 import javax.annotation.Nullable;
 import javax.sql.DataSource;
 import org.apache.beam.sdk.annotations.Experimental;
@@ -904,7 +909,7 @@ public interface RetryStrategy extends Serializable {
    *
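   * <p>When the input {@link PCollection} has a schema, the insert statement and the prepared
   * statement setter can instead be generated from a table name, e.g. (driver, URL, and table
   * name here are illustrative):
   *
   * <pre>{@code
   * rows.apply(
   *     JdbcIO.<Row>write()
   *         .withDataSourceConfiguration(JdbcIO.DataSourceConfiguration.create(
   *             "org.apache.derby.jdbc.ClientDriver", "jdbc:derby://localhost:1527/testdb"))
   *         .withTable("person"));
   * }</pre>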

All methods in this class delegate to the appropriate method of {@link JdbcIO.WriteVoid}. */ public static class Write extends PTransform, PDone> { - final WriteVoid inner; + WriteVoid inner; Write() { this(JdbcIO.writeVoid()); @@ -948,6 +953,11 @@ public Write withRetryStrategy(RetryStrategy retryStrategy) { return new Write(inner.withRetryStrategy(retryStrategy)); } + /** See {@link WriteVoid#withTable(String)}. */ + public Write withTable(String table) { + return new Write(inner.withTable(table)); + } + /** * Returns {@link WriteVoid} transform which can be used in {@link Wait#on(PCollection[])} to * wait until all data is written. @@ -972,11 +982,142 @@ public void populateDisplayData(DisplayData.Builder builder) { inner.populateDisplayData(builder); } + private boolean hasStatementAndSetter() { + return inner.getStatement() != null && inner.getPreparedStatementSetter() != null; + } + @Override public PDone expand(PCollection input) { + // fixme: validate invalid table input + if (input.hasSchema() && !hasStatementAndSetter()) { + checkArgument( + inner.getTable() != null, "table cannot be null if statement is not provided"); + Schema schema = input.getSchema(); + List fields = getFilteredFields(schema); + inner = + inner.withStatement( + JdbcUtil.generateStatement( + inner.getTable(), + fields.stream() + .map(SchemaUtil.FieldWithIndex::getField) + .collect(Collectors.toList()))); + inner = + inner.withPreparedStatementSetter( + new AutoGeneratedPreparedStatementSetter(fields, input.getToRowFunction())); + } + inner.expand(input); return PDone.in(input.getPipeline()); } + + private List getFilteredFields(Schema schema) { + Schema tableSchema; + + try (Connection connection = inner.getDataSourceProviderFn().apply(null).getConnection(); + PreparedStatement statement = + connection.prepareStatement((String.format("SELECT * FROM %s", inner.getTable())))) { + tableSchema = SchemaUtil.toBeamSchema(statement.getMetaData()); + statement.close(); + } catch (SQLException e) { + throw new RuntimeException( + "Error while determining columns from table: " + inner.getTable(), e); + } + + if (tableSchema.getFieldCount() < schema.getFieldCount()) { + throw new RuntimeException("Input schema has more fields than actual table."); + } + + // filter out missing fields from output table + List missingFields = + tableSchema.getFields().stream() + .filter( + line -> + schema.getFields().stream() + .noneMatch(s -> s.getName().equalsIgnoreCase(line.getName()))) + .collect(Collectors.toList()); + + // allow insert only if missing fields are nullable + if (checkNullabilityForFields(missingFields)) { + throw new RuntimeException("Non nullable fields are not allowed without schema."); + } + + List tableFilteredFields = + tableSchema.getFields().stream() + .map( + (tableField) -> { + Optional optionalSchemaField = + schema.getFields().stream() + .filter((f) -> SchemaUtil.compareSchemaField(tableField, f)) + .findFirst(); + return (optionalSchemaField.isPresent()) + ? SchemaUtil.FieldWithIndex.of( + tableField, schema.getFields().indexOf(optionalSchemaField.get())) + : null; + }) + .filter(Objects::nonNull) + .collect(Collectors.toList()); + + if (tableFilteredFields.size() != schema.getFieldCount()) { + throw new RuntimeException("Provided schema doesn't match with database schema."); + } + + return tableFilteredFields; + } + + /** + * A {@link org.apache.beam.sdk.io.jdbc.JdbcIO.PreparedStatementSetter} implementation that + * calls related setters on prepared statement. 
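+   * One setter is created per matched schema field via
+   * {@link JdbcUtil#getPreparedStatementSetCaller(Schema.FieldType)}, and non-{@link Row}
+   * elements are first converted to {@link Row} with the input's to-Row function.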
+ */ + private class AutoGeneratedPreparedStatementSetter implements PreparedStatementSetter { + + private List fields; + private SerializableFunction toRowFn; + private List preparedStatementFieldSetterList = new ArrayList<>(); + + AutoGeneratedPreparedStatementSetter( + List fieldsWithIndex, SerializableFunction toRowFn) { + this.fields = fieldsWithIndex; + this.toRowFn = toRowFn; + populatePreparedStatementFieldSetter(); + } + + private void populatePreparedStatementFieldSetter() { + IntStream.range(0, fields.size()) + .forEach( + (index) -> { + Schema.FieldType fieldType = fields.get(index).getField().getType(); + preparedStatementFieldSetterList.add( + JdbcUtil.getPreparedStatementSetCaller(fieldType)); + }); + } + + @Override + public void setParameters(T element, PreparedStatement preparedStatement) throws Exception { + Row row = (element instanceof Row) ? (Row) element : toRowFn.apply(element); + IntStream.range(0, fields.size()) + .forEach( + (index) -> { + try { + preparedStatementFieldSetterList + .get(index) + .set(row, preparedStatement, index, fields.get(index)); + } catch (SQLException | NullPointerException e) { + throw new RuntimeException("Error while setting data to preparedStatement", e); + } + }); + } + } + } + + /** Interface implemented by functions that sets prepared statement data. */ + @FunctionalInterface + interface PreparedStatementSetCaller extends Serializable { + void set( + Row element, + PreparedStatement preparedStatement, + int prepareStatementIndex, + SchemaUtil.FieldWithIndex schemaFieldWithIndex) + throws SQLException; } /** A {@link PTransform} to write to a JDBC datasource. */ @@ -1001,6 +1142,9 @@ public abstract static class WriteVoid extends PTransform, PCo @Nullable abstract RetryStrategy getRetryStrategy(); + @Nullable + abstract String getTable(); + abstract Builder toBuilder(); @AutoValue.Builder @@ -1020,6 +1164,8 @@ abstract Builder setDataSourceProviderFn( abstract Builder setRetryStrategy(RetryStrategy deadlockPredicate); + abstract Builder setTable(String table); + abstract WriteVoid build(); } @@ -1065,6 +1211,11 @@ public WriteVoid withRetryStrategy(RetryStrategy retryStrategy) { return toBuilder().setRetryStrategy(retryStrategy).build(); } + public WriteVoid withTable(String table) { + checkArgument(table != null, "table name can not be null"); + return toBuilder().setTable(table).build(); + } + @Override public PCollection expand(PCollection input) { checkArgument(getStatement() != null, "withStatement() is required"); diff --git a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcUtil.java b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcUtil.java new file mode 100644 index 0000000000000..17fdae73c91f3 --- /dev/null +++ b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcUtil.java @@ -0,0 +1,235 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.jdbc;
+
+import java.sql.Date;
+import java.sql.JDBCType;
+import java.sql.Time;
+import java.sql.Timestamp;
+import java.util.Calendar;
+import java.util.EnumMap;
+import java.util.List;
+import java.util.Map;
+import java.util.TimeZone;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+import org.apache.beam.sdk.schemas.Schema;
+import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.ImmutableMap;
+import org.joda.time.DateTime;
+
+/** Provides utility functions for working with {@link JdbcIO}. */
+public class JdbcUtil {
+
+  /** Generates an insert statement based on {@link Schema.Field}. */
+  public static String generateStatement(String tableName, List<Schema.Field> fields) {
+
+    String fieldNames =
+        IntStream.range(0, fields.size())
+            .mapToObj(
+                (index) -> {
+                  return fields.get(index).getName();
+                })
+            .collect(Collectors.joining(", "));
+
+    String valuePlaceholder =
+        IntStream.range(0, fields.size())
+            .mapToObj(
+                (index) -> {
+                  return "?";
+                })
+            .collect(Collectors.joining(", "));
+
+    return String.format("INSERT INTO %s(%s) VALUES(%s)", tableName, fieldNames, valuePlaceholder);
+  }
+
+  /** PreparedStatementSetCaller for Schema Field types. */
+  private static Map<Schema.TypeName, JdbcIO.PreparedStatementSetCaller> typeNamePsSetCallerMap =
+      new EnumMap<>(
+          ImmutableMap.<Schema.TypeName, JdbcIO.PreparedStatementSetCaller>builder()
+              .put(
+                  Schema.TypeName.BYTE,
+                  (element, ps, i, fieldWithIndex) ->
+                      ps.setByte(i + 1, element.getByte(fieldWithIndex.getIndex())))
+              .put(
+                  Schema.TypeName.INT16,
+                  (element, ps, i, fieldWithIndex) ->
+                      ps.setInt(i + 1, element.getInt16(fieldWithIndex.getIndex())))
+              .put(
+                  Schema.TypeName.INT64,
+                  (element, ps, i, fieldWithIndex) ->
+                      ps.setLong(i + 1, element.getInt64(fieldWithIndex.getIndex())))
+              .put(
+                  Schema.TypeName.DECIMAL,
+                  (element, ps, i, fieldWithIndex) ->
+                      ps.setBigDecimal(i + 1, element.getDecimal(fieldWithIndex.getIndex())))
+              .put(
+                  Schema.TypeName.FLOAT,
+                  (element, ps, i, fieldWithIndex) ->
+                      ps.setFloat(i + 1, element.getFloat(fieldWithIndex.getIndex())))
+              .put(
+                  Schema.TypeName.DOUBLE,
+                  (element, ps, i, fieldWithIndex) ->
+                      ps.setDouble(i + 1, element.getDouble(fieldWithIndex.getIndex())))
+              .put(
+                  Schema.TypeName.DATETIME,
+                  (element, ps, i, fieldWithIndex) ->
+                      ps.setTimestamp(
+                          i + 1,
+                          new Timestamp(
+                              element.getDateTime(fieldWithIndex.getIndex()).getMillis())))
+              .put(
+                  Schema.TypeName.BOOLEAN,
+                  (element, ps, i, fieldWithIndex) ->
+                      ps.setBoolean(i + 1, element.getBoolean(fieldWithIndex.getIndex())))
+              .put(Schema.TypeName.BYTES, createBytesCaller())
+              .put(
+                  Schema.TypeName.INT32,
+                  (element, ps, i, fieldWithIndex) ->
+                      ps.setInt(i + 1, element.getInt32(fieldWithIndex.getIndex())))
+              .put(Schema.TypeName.STRING, createStringCaller())
+              .build());
+
+  /** PreparedStatementSetCaller for Schema Field Logical types.
* */ + public static JdbcIO.PreparedStatementSetCaller getPreparedStatementSetCaller( + Schema.FieldType fieldType) { + switch (fieldType.getTypeName()) { + case ARRAY: + return (element, ps, i, fieldWithIndex) -> { + ps.setArray( + i + 1, + ps.getConnection() + .createArrayOf( + fieldType.getCollectionElementType().getTypeName().name(), + element.getArray(fieldWithIndex.getIndex()).toArray())); + }; + case LOGICAL_TYPE: + { + String logicalTypeName = fieldType.getLogicalType().getIdentifier(); + JDBCType jdbcType = JDBCType.valueOf(logicalTypeName); + switch (jdbcType) { + case DATE: + return (element, ps, i, fieldWithIndex) -> { + ps.setDate( + i + 1, + new Date( + getDateOrTimeOnly( + element.getDateTime(fieldWithIndex.getIndex()).toDateTime(), true) + .getTime() + .getTime())); + }; + case TIME: + return (element, ps, i, fieldWithIndex) -> { + ps.setTime( + i + 1, + new Time( + getDateOrTimeOnly( + element.getDateTime(fieldWithIndex.getIndex()).toDateTime(), false) + .getTime() + .getTime())); + }; + case TIMESTAMP_WITH_TIMEZONE: + return (element, ps, i, fieldWithIndex) -> { + Calendar calendar = + withTimestampAndTimezone( + element.getDateTime(fieldWithIndex.getIndex()).toDateTime()); + ps.setTimestamp(i + 1, new Timestamp(calendar.getTime().getTime()), calendar); + }; + default: + return getPreparedStatementSetCaller(fieldType.getLogicalType().getBaseType()); + } + } + default: + { + if (typeNamePsSetCallerMap.containsKey(fieldType.getTypeName())) { + return typeNamePsSetCallerMap.get(fieldType.getTypeName()); + } else { + throw new RuntimeException( + fieldType.getTypeName().name() + + " in schema is not supported while writing. Please provide statement and preparedStatementSetter"); + } + } + } + } + + private static JdbcIO.PreparedStatementSetCaller createBytesCaller() { + return (element, ps, i, fieldWithIndex) -> { + validateLogicalTypeLength( + fieldWithIndex.getField(), element.getBytes(fieldWithIndex.getIndex()).length); + ps.setBytes(i + 1, element.getBytes(fieldWithIndex.getIndex())); + }; + } + + private static JdbcIO.PreparedStatementSetCaller createStringCaller() { + return (element, ps, i, fieldWithIndex) -> { + validateLogicalTypeLength( + fieldWithIndex.getField(), element.getString(fieldWithIndex.getIndex()).length()); + ps.setString(i + 1, element.getString(fieldWithIndex.getIndex())); + }; + } + + private static void validateLogicalTypeLength(Schema.Field field, Integer length) { + try { + if (field.getType().getTypeName().isLogicalType() + && !field.getType().getLogicalType().getArgument().isEmpty()) { + int maxLimit = Integer.parseInt(field.getType().getLogicalType().getArgument()); + if (field.getType().getTypeName().isLogicalType() && length >= maxLimit) { + throw new RuntimeException( + String.format( + "Length of Schema.Field[%s] data exceeds database column capacity", + field.getName())); + } + } + } catch (NumberFormatException e) { + // if argument is not set or not integer then do nothing and proceed with the insertion + } + } + + private static Calendar getDateOrTimeOnly(DateTime dateTime, boolean wantDateOnly) { + Calendar cal = Calendar.getInstance(); + cal.setTimeZone(TimeZone.getTimeZone(dateTime.getZone().getID())); + + if (wantDateOnly) { // return date only + cal.set(Calendar.YEAR, dateTime.getYear()); + cal.set(Calendar.MONTH, dateTime.getMonthOfYear() - 1); + cal.set(Calendar.DATE, dateTime.getDayOfMonth()); + + cal.set(Calendar.HOUR_OF_DAY, 0); + cal.set(Calendar.MINUTE, 0); + cal.set(Calendar.SECOND, 0); + cal.set(Calendar.MILLISECOND, 0); + 
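+      // at this point the calendar holds the date portion only: the time-of-day
+      // fields were zeroed above, so the resulting java.sql.Date is midnight-aligned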
} else { // return time only + cal.set(Calendar.YEAR, 1970); + cal.set(Calendar.MONTH, Calendar.JANUARY); + cal.set(Calendar.DATE, 1); + + cal.set(Calendar.HOUR_OF_DAY, dateTime.getHourOfDay()); + cal.set(Calendar.MINUTE, dateTime.getMinuteOfHour()); + cal.set(Calendar.SECOND, dateTime.getSecondOfMinute()); + cal.set(Calendar.MILLISECOND, dateTime.getMillisOfSecond()); + } + + return cal; + } + + private static Calendar withTimestampAndTimezone(DateTime dateTime) { + Calendar calendar = Calendar.getInstance(TimeZone.getTimeZone(dateTime.getZone().getID())); + calendar.setTimeInMillis(dateTime.getMillis()); + + return calendar; + } +} diff --git a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/SchemaUtil.java b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/SchemaUtil.java index 9c1dcb050a7e5..99fa06f7ed727 100644 --- a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/SchemaUtil.java +++ b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/SchemaUtil.java @@ -28,6 +28,7 @@ import static java.sql.JDBCType.VARBINARY; import static java.sql.JDBCType.VARCHAR; import static java.sql.JDBCType.valueOf; +import static org.apache.beam.vendor.guava.v20_0.com.google.common.base.Preconditions.checkArgument; import java.io.Serializable; import java.sql.Array; @@ -338,4 +339,73 @@ public Row mapRow(ResultSet rs) throws Exception { return rowBuilder.build(); } } + + /** + * compares two fields. Does not compare nullability of field types. + * + * @param a field 1 + * @param b field 2 + * @return TRUE if fields are equal. Otherwise FALSE + */ + public static boolean compareSchemaField(Schema.Field a, Schema.Field b) { + if (!a.getName().equalsIgnoreCase(b.getName())) { + return false; + } + + return compareSchemaFieldType(a.getType(), b.getType()); + } + + /** + * checks nullability for fields. + * + * @param fields + * @return TRUE if any field is not nullable + */ + public static boolean checkNullabilityForFields(List fields) { + return fields.stream().anyMatch(field -> !field.getType().getNullable()); + } + + /** + * compares two FieldType. Does not compare nullability. + * + * @param a FieldType 1 + * @param b FieldType 2 + * @return TRUE if FieldType are equal. 
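+   *     (logical types are compared through their base types, so e.g. a VARCHAR logical type
+   *     is considered equal to {@code FieldType.STRING}).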
Otherwise FALSE + */ + public static boolean compareSchemaFieldType(Schema.FieldType a, Schema.FieldType b) { + if (a.getTypeName().equals(b.getTypeName())) { + return !a.getTypeName().equals(Schema.TypeName.LOGICAL_TYPE) + || compareSchemaFieldType( + a.getLogicalType().getBaseType(), b.getLogicalType().getBaseType()); + } else if (a.getTypeName().isLogicalType()) { + return a.getLogicalType().getBaseType().getTypeName().equals(b.getTypeName()); + } else if (b.getTypeName().isLogicalType()) { + return b.getLogicalType().getBaseType().getTypeName().equals(a.getTypeName()); + } + return false; + } + + static class FieldWithIndex implements Serializable { + private final Schema.Field field; + private final Integer index; + + private FieldWithIndex(Schema.Field field, Integer index) { + this.field = field; + this.index = index; + } + + public static FieldWithIndex of(Schema.Field field, Integer index) { + checkArgument(field != null); + checkArgument(index != null); + return new FieldWithIndex(field, index); + } + + public Schema.Field getField() { + return field; + } + + public Integer getIndex() { + return index; + } + } } diff --git a/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOTest.java b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOTest.java index ecd26a139486c..0e815339bcc29 100644 --- a/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOTest.java +++ b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOTest.java @@ -19,19 +19,35 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.anyString; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; import java.io.PrintWriter; import java.io.Serializable; import java.io.StringWriter; +import java.math.BigDecimal; import java.net.InetAddress; +import java.nio.charset.Charset; +import java.sql.Array; import java.sql.Connection; +import java.sql.Date; import java.sql.JDBCType; import java.sql.PreparedStatement; import java.sql.ResultSet; import java.sql.SQLException; import java.sql.Statement; +import java.sql.Time; +import java.sql.Timestamp; import java.util.ArrayList; +import java.util.Arrays; +import java.util.Calendar; import java.util.Collections; +import java.util.List; +import java.util.TimeZone; import javax.sql.DataSource; import org.apache.beam.sdk.coders.KvCoder; import org.apache.beam.sdk.coders.SerializableCoder; @@ -56,12 +72,16 @@ import org.apache.commons.dbcp2.PoolingDataSource; import org.apache.derby.drda.NetworkServerControl; import org.apache.derby.jdbc.ClientDataSource; +import org.joda.time.DateTime; +import org.joda.time.LocalDate; +import org.joda.time.chrono.ISOChronology; import org.junit.After; import org.junit.AfterClass; import org.junit.Assert; import org.junit.BeforeClass; import org.junit.Rule; import org.junit.Test; +import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.slf4j.Logger; @@ -85,6 +105,8 @@ public class JdbcIOTest implements Serializable { @Rule public final transient ExpectedLogs expectedLogs = ExpectedLogs.none(JdbcIO.class); + @Rule public transient ExpectedException thrown = ExpectedException.none(); + @BeforeClass public static void beforeClass() throws Exception { port = NetworkTestHelper.getAvailableLocalPort(); @@ -490,6 +512,345 
@@ public void tearDown() { } } + @Test + public void testWriteWithoutPreparedStatement() throws Exception { + final int rowsToAdd = 10; + + Schema.Builder schemaBuilder = Schema.builder(); + schemaBuilder.addField(Schema.Field.of("column_boolean", Schema.FieldType.BOOLEAN)); + schemaBuilder.addField(Schema.Field.of("column_string", Schema.FieldType.STRING)); + schemaBuilder.addField(Schema.Field.of("column_int", Schema.FieldType.INT32)); + schemaBuilder.addField(Schema.Field.of("column_long", Schema.FieldType.INT64)); + schemaBuilder.addField(Schema.Field.of("column_float", Schema.FieldType.FLOAT)); + schemaBuilder.addField(Schema.Field.of("column_double", Schema.FieldType.DOUBLE)); + schemaBuilder.addField(Schema.Field.of("column_bigdecimal", Schema.FieldType.DECIMAL)); + schemaBuilder.addField(Schema.Field.of("column_date", LogicalTypes.JDBC_DATE_TYPE)); + schemaBuilder.addField(Schema.Field.of("column_time", LogicalTypes.JDBC_TIME_TYPE)); + schemaBuilder.addField( + Schema.Field.of("column_timestamptz", LogicalTypes.JDBC_TIMESTAMP_WITH_TIMEZONE_TYPE)); + schemaBuilder.addField(Schema.Field.of("column_timestamp", Schema.FieldType.DATETIME)); + schemaBuilder.addField(Schema.Field.of("column_short", Schema.FieldType.INT16)); + Schema schema = schemaBuilder.build(); + + String tableName = DatabaseTestHelper.getTestTableName("UT_WRITE_PS"); + StringBuilder stmt = new StringBuilder("CREATE TABLE "); + stmt.append(tableName); + stmt.append(" ("); + stmt.append("column_boolean BOOLEAN,"); // boolean + stmt.append("column_string VARCHAR(254),"); // String + stmt.append("column_int INTEGER,"); // int + stmt.append("column_long BIGINT,"); // long + stmt.append("column_float REAL,"); // float + stmt.append("column_double DOUBLE PRECISION,"); // double + stmt.append("column_bigdecimal DECIMAL(13,0),"); // BigDecimal + stmt.append("column_date DATE,"); // Date + stmt.append("column_time TIME,"); // Time + stmt.append("column_timestamptz TIMESTAMP,"); // Timestamp + stmt.append("column_timestamp TIMESTAMP,"); // Timestamp + stmt.append("column_short SMALLINT"); // short + stmt.append(" )"); + DatabaseTestHelper.createTableWithStatement(dataSource, stmt.toString()); + try { + ArrayList data = getRowsToWrite(rowsToAdd, schema); + pipeline + .apply(Create.of(data)) + .setRowSchema(schema) + .apply( + JdbcIO.write() + .withDataSourceConfiguration( + JdbcIO.DataSourceConfiguration.create( + "org.apache.derby.jdbc.ClientDriver", + "jdbc:derby://localhost:" + port + "/target/beam")) + .withBatchSize(10L) + .withTable(tableName)); + pipeline.run(); + assertRowCount(tableName, rowsToAdd); + } finally { + DatabaseTestHelper.deleteTable(dataSource, tableName); + } + } + + @Test + public void testWriteWithoutPreparedStatementWithReadRows() throws Exception { + SerializableFunction dataSourceProvider = ignored -> dataSource; + PCollection rows = + pipeline.apply( + JdbcIO.readRows() + .withDataSourceProviderFn(dataSourceProvider) + .withQuery(String.format("select name,id from %s where name = ?", readTableName)) + .withStatementPreparator( + preparedStatement -> + preparedStatement.setString(1, TestRow.getNameForSeed(1)))); + + String writeTableName = DatabaseTestHelper.getTestTableName("UT_WRITE_PS_WITH_READ_ROWS"); + DatabaseTestHelper.createTableForRowWithSchema(dataSource, writeTableName); + try { + rows.apply( + JdbcIO.write() + .withDataSourceConfiguration( + JdbcIO.DataSourceConfiguration.create( + "org.apache.derby.jdbc.ClientDriver", + "jdbc:derby://localhost:" + port + "/target/beam")) + 
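+                // rows produced by readRows() carry a schema derived from the result-set
+                // metadata, so this write can generate its own insert statement for the
+                // target table; only batch size and table name are configured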
.withBatchSize(10L) + .withTable(writeTableName)); + pipeline.run(); + } finally { + DatabaseTestHelper.deleteTable(dataSource, writeTableName); + } + } + + @Test + public void testWriteWithoutPsWithNonNullableTableField() throws Exception { + final int rowsToAdd = 10; + + Schema.Builder schemaBuilder = Schema.builder(); + schemaBuilder.addField(Schema.Field.of("column_boolean", Schema.FieldType.BOOLEAN)); + schemaBuilder.addField(Schema.Field.of("column_string", Schema.FieldType.STRING)); + Schema schema = schemaBuilder.build(); + + String tableName = DatabaseTestHelper.getTestTableName("UT_WRITE"); + StringBuilder stmt = new StringBuilder("CREATE TABLE "); + stmt.append(tableName); + stmt.append(" ("); + stmt.append("column_boolean BOOLEAN,"); + stmt.append("column_int INTEGER NOT NULL"); + stmt.append(" )"); + DatabaseTestHelper.createTableWithStatement(dataSource, stmt.toString()); + try { + ArrayList data = getRowsToWrite(rowsToAdd, schema); + pipeline + .apply(Create.of(data)) + .setRowSchema(schema) + .apply( + JdbcIO.write() + .withDataSourceConfiguration( + JdbcIO.DataSourceConfiguration.create( + "org.apache.derby.jdbc.ClientDriver", + "jdbc:derby://localhost:" + port + "/target/beam")) + .withBatchSize(10L) + .withTable(tableName)); + pipeline.run(); + } finally { + DatabaseTestHelper.deleteTable(dataSource, tableName); + thrown.expect(RuntimeException.class); + } + } + + @Test + public void testWriteWithoutPreparedStatementAndNonRowType() throws Exception { + final int rowsToAdd = 10; + + String tableName = DatabaseTestHelper.getTestTableName("UT_WRITE_PS_NON_ROW"); + DatabaseTestHelper.createTableForRowWithSchema(dataSource, tableName); + try { + List data = getRowsWithSchemaToWrite(rowsToAdd); + + pipeline + .apply(Create.of(data)) + .apply( + JdbcIO.write() + .withDataSourceConfiguration( + JdbcIO.DataSourceConfiguration.create( + "org.apache.derby.jdbc.ClientDriver", + "jdbc:derby://localhost:" + port + "/target/beam")) + .withBatchSize(10L) + .withTable(tableName)); + pipeline.run(); + assertRowCount(tableName, rowsToAdd); + } finally { + DatabaseTestHelper.deleteTable(dataSource, tableName); + } + } + + @Test + public void testGetPreparedStatementSetCaller() throws Exception { + + Schema schema = + Schema.builder() + .addField("bigint_col", Schema.FieldType.INT64) + .addField("binary_col", Schema.FieldType.BYTES) + .addField("bit_col", Schema.FieldType.BOOLEAN) + .addField("char_col", Schema.FieldType.STRING) + .addField("decimal_col", Schema.FieldType.DECIMAL) + .addField("double_col", Schema.FieldType.DOUBLE) + .addField("float_col", Schema.FieldType.FLOAT) + .addField("integer_col", Schema.FieldType.INT32) + .addField("datetime_col", Schema.FieldType.DATETIME) + .addField("int16_col", Schema.FieldType.INT16) + .addField("byte_col", Schema.FieldType.BYTE) + .build(); + Row row = + Row.withSchema(schema) + .addValues( + 42L, + "binary".getBytes(Charset.forName("UTF-8")), + true, + "char", + BigDecimal.valueOf(25L), + 20.5D, + 15.5F, + 10, + new DateTime(), + (short) 5, + Byte.parseByte("1", 2)) + .build(); + + PreparedStatement psMocked = mock(PreparedStatement.class); + + JdbcUtil.getPreparedStatementSetCaller(Schema.FieldType.INT64) + .set(row, psMocked, 0, SchemaUtil.FieldWithIndex.of(schema.getField(0), 0)); + JdbcUtil.getPreparedStatementSetCaller(Schema.FieldType.BYTES) + .set(row, psMocked, 1, SchemaUtil.FieldWithIndex.of(schema.getField(1), 1)); + JdbcUtil.getPreparedStatementSetCaller(Schema.FieldType.BOOLEAN) + .set(row, psMocked, 2, 
SchemaUtil.FieldWithIndex.of(schema.getField(2), 2)); + JdbcUtil.getPreparedStatementSetCaller(Schema.FieldType.STRING) + .set(row, psMocked, 3, SchemaUtil.FieldWithIndex.of(schema.getField(3), 3)); + JdbcUtil.getPreparedStatementSetCaller(Schema.FieldType.DECIMAL) + .set(row, psMocked, 4, SchemaUtil.FieldWithIndex.of(schema.getField(4), 4)); + JdbcUtil.getPreparedStatementSetCaller(Schema.FieldType.DOUBLE) + .set(row, psMocked, 5, SchemaUtil.FieldWithIndex.of(schema.getField(5), 5)); + JdbcUtil.getPreparedStatementSetCaller(Schema.FieldType.FLOAT) + .set(row, psMocked, 6, SchemaUtil.FieldWithIndex.of(schema.getField(6), 6)); + JdbcUtil.getPreparedStatementSetCaller(Schema.FieldType.INT32) + .set(row, psMocked, 7, SchemaUtil.FieldWithIndex.of(schema.getField(7), 7)); + JdbcUtil.getPreparedStatementSetCaller(Schema.FieldType.DATETIME) + .set(row, psMocked, 8, SchemaUtil.FieldWithIndex.of(schema.getField(8), 8)); + JdbcUtil.getPreparedStatementSetCaller(Schema.FieldType.INT16) + .set(row, psMocked, 9, SchemaUtil.FieldWithIndex.of(schema.getField(9), 9)); + JdbcUtil.getPreparedStatementSetCaller(Schema.FieldType.BYTE) + .set(row, psMocked, 10, SchemaUtil.FieldWithIndex.of(schema.getField(10), 10)); + + verify(psMocked, times(1)).setLong(1, 42L); + verify(psMocked, times(1)).setBytes(2, "binary".getBytes(Charset.forName("UTF-8"))); + verify(psMocked, times(1)).setBoolean(3, true); + verify(psMocked, times(1)).setString(4, "char"); + verify(psMocked, times(1)).setBigDecimal(5, BigDecimal.valueOf(25L)); + verify(psMocked, times(1)).setDouble(6, 20.5D); + verify(psMocked, times(1)).setFloat(7, 15.5F); + verify(psMocked, times(1)).setInt(8, 10); + verify(psMocked, times(1)) + .setTimestamp(9, new Timestamp(row.getDateTime("datetime_col").getMillis())); + verify(psMocked, times(1)).setInt(10, (short) 5); + verify(psMocked, times(1)).setByte(11, Byte.parseByte("1", 2)); + } + + @Test + public void testGetPreparedStatementSetCallerForLogicalTypes() throws Exception { + + Schema schema = + Schema.builder() + .addField("logical_date_col", LogicalTypes.JDBC_DATE_TYPE) + .addField("logical_time_col", LogicalTypes.JDBC_TIME_TYPE) + .addField("logical_time_with_tz_col", LogicalTypes.JDBC_TIMESTAMP_WITH_TIMEZONE_TYPE) + .build(); + + long epochMilli = 1558719710000L; + DateTime dateTime = new DateTime(epochMilli, ISOChronology.getInstanceUTC()); + + Row row = + Row.withSchema(schema) + .addValues( + dateTime.withTimeAtStartOfDay(), dateTime.withDate(new LocalDate(0L)), dateTime) + .build(); + + PreparedStatement psMocked = mock(PreparedStatement.class); + + JdbcUtil.getPreparedStatementSetCaller(LogicalTypes.JDBC_DATE_TYPE) + .set(row, psMocked, 0, SchemaUtil.FieldWithIndex.of(schema.getField(0), 0)); + JdbcUtil.getPreparedStatementSetCaller(LogicalTypes.JDBC_TIME_TYPE) + .set(row, psMocked, 1, SchemaUtil.FieldWithIndex.of(schema.getField(1), 1)); + JdbcUtil.getPreparedStatementSetCaller(LogicalTypes.JDBC_TIMESTAMP_WITH_TIMEZONE_TYPE) + .set(row, psMocked, 2, SchemaUtil.FieldWithIndex.of(schema.getField(2), 2)); + + verify(psMocked, times(1)).setDate(1, new Date(row.getDateTime(0).getMillis())); + verify(psMocked, times(1)).setTime(2, new Time(row.getDateTime(1).getMillis())); + + Calendar cal = Calendar.getInstance(TimeZone.getTimeZone("UTC")); + cal.setTimeInMillis(epochMilli); + + verify(psMocked, times(1)).setTimestamp(3, new Timestamp(cal.getTime().getTime()), cal); + } + + @Test + public void testGetPreparedStatementSetCallerForArray() throws Exception { + + Schema schema = + Schema.builder() + 
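+            // a single ARRAY field: the caller under test must translate it through
+            // Connection.createArrayOf, hence the mocked Connection and Array below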
.addField("string_array_col", Schema.FieldType.array(Schema.FieldType.STRING)) + .build(); + + List stringList = Arrays.asList("string 1", "string 2"); + + Row row = Row.withSchema(schema).addValues(stringList).build(); + + PreparedStatement psMocked = mock(PreparedStatement.class); + Connection connectionMocked = mock(Connection.class); + Array arrayMocked = mock(Array.class); + + when(psMocked.getConnection()).thenReturn(connectionMocked); + when(connectionMocked.createArrayOf(anyString(), any())).thenReturn(arrayMocked); + + JdbcUtil.getPreparedStatementSetCaller(Schema.FieldType.array(Schema.FieldType.STRING)) + .set(row, psMocked, 0, SchemaUtil.FieldWithIndex.of(schema.getField(0), 0)); + + verify(psMocked, times(1)).setArray(1, arrayMocked); + } + + private static ArrayList getRowsToWrite(long rowsToAdd, Schema schema) { + + ArrayList data = new ArrayList<>(); + for (int i = 0; i < rowsToAdd; i++) { + List fields = new ArrayList<>(); + + Row row = + schema.getFields().stream() + .map(field -> dummyFieldValue(field.getType())) + .collect(Row.toRow(schema)); + data.add(row); + } + return data; + } + + private static ArrayList getRowsWithSchemaToWrite(long rowsToAdd) { + + ArrayList data = new ArrayList<>(); + for (int i = 0; i < rowsToAdd; i++) { + data.add(new RowWithSchema("Test", i)); + } + return data; + } + + private static Object dummyFieldValue(Schema.FieldType fieldType) { + long epochMilli = 1558719710000L; + if (fieldType.equals(Schema.FieldType.STRING)) { + return "string value"; + } else if (fieldType.equals(Schema.FieldType.INT32)) { + return 100; + } else if (fieldType.equals(Schema.FieldType.DOUBLE)) { + return 20.5D; + } else if (fieldType.equals(Schema.FieldType.BOOLEAN)) { + return Boolean.TRUE; + } else if (fieldType.equals(Schema.FieldType.INT16)) { + return Short.MAX_VALUE; + } else if (fieldType.equals(Schema.FieldType.INT64)) { + return Long.MAX_VALUE; + } else if (fieldType.equals(Schema.FieldType.FLOAT)) { + return 15.5F; + } else if (fieldType.equals(Schema.FieldType.DECIMAL)) { + return BigDecimal.ONE; + } else if (fieldType.equals(LogicalTypes.JDBC_DATE_TYPE)) { + return new DateTime(epochMilli, ISOChronology.getInstanceUTC()).withTimeAtStartOfDay(); + } else if (fieldType.equals(LogicalTypes.JDBC_TIME_TYPE)) { + return new DateTime(epochMilli, ISOChronology.getInstanceUTC()).withDate(new LocalDate(0L)); + } else if (fieldType.equals(LogicalTypes.JDBC_TIMESTAMP_WITH_TIMEZONE_TYPE)) { + return new DateTime(epochMilli, ISOChronology.getInstanceUTC()); + } else if (fieldType.equals(Schema.FieldType.DATETIME)) { + return new DateTime(epochMilli, ISOChronology.getInstanceUTC()); + } else { + return null; + } + } + @Test public void testWriteWithEmptyPCollection() { pipeline diff --git a/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcUtilTest.java b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcUtilTest.java new file mode 100644 index 0000000000000..a867f8e14eb37 --- /dev/null +++ b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcUtilTest.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.jdbc; + +import static org.junit.Assert.assertEquals; + +import org.apache.beam.sdk.schemas.Schema; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Test JdbcUtil. */ +@RunWith(JUnit4.class) +public class JdbcUtilTest { + + @Test + public void testGetPreparedStatementSetCaller() throws Exception { + Schema wantSchema = + Schema.builder() + .addField("col1", Schema.FieldType.INT64) + .addField("col2", Schema.FieldType.INT64) + .addField("col3", Schema.FieldType.INT64) + .build(); + + String generatedStmt = JdbcUtil.generateStatement("test_table", wantSchema.getFields()); + String expectedStmt = "INSERT INTO test_table(col1, col2, col3) VALUES(?, ?, ?)"; + assertEquals(expectedStmt, generatedStmt); + } +} diff --git a/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/SchemaUtilTest.java b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/SchemaUtilTest.java index f606bde055f93..70487e8ac626b 100644 --- a/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/SchemaUtilTest.java +++ b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/SchemaUtilTest.java @@ -18,6 +18,8 @@ package org.apache.beam.sdk.io.jdbc; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; import static org.mockito.Matchers.any; import static org.mockito.Matchers.eq; import static org.mockito.Mockito.mock; @@ -301,4 +303,43 @@ private static JdbcFieldInfo of(String columnLabel, int columnType, int precisio return new JdbcFieldInfo(columnLabel, columnType, null, false, precision, scale); } } + + @Test + public void testSchemaFieldComparator() { + assertTrue( + SchemaUtil.compareSchemaField( + Schema.Field.of("name", Schema.FieldType.STRING), + Schema.Field.of("name", Schema.FieldType.STRING))); + assertFalse( + SchemaUtil.compareSchemaField( + Schema.Field.of("name", Schema.FieldType.STRING), + Schema.Field.of("anotherName", Schema.FieldType.STRING))); + assertFalse( + SchemaUtil.compareSchemaField( + Schema.Field.of("name", Schema.FieldType.STRING), + Schema.Field.of("name", Schema.FieldType.INT64))); + } + + @Test + public void testSchemaFieldTypeComparator() { + assertTrue(SchemaUtil.compareSchemaFieldType(Schema.FieldType.STRING, Schema.FieldType.STRING)); + assertFalse(SchemaUtil.compareSchemaFieldType(Schema.FieldType.STRING, Schema.FieldType.INT16)); + assertTrue( + SchemaUtil.compareSchemaFieldType( + LogicalTypes.variableLengthString(JDBCType.VARCHAR, 255), + LogicalTypes.variableLengthString(JDBCType.VARCHAR, 255))); + assertFalse( + SchemaUtil.compareSchemaFieldType( + LogicalTypes.variableLengthString(JDBCType.VARCHAR, 255), + LogicalTypes.fixedLengthBytes(JDBCType.BIT, 255))); + assertTrue( + SchemaUtil.compareSchemaFieldType( + Schema.FieldType.STRING, LogicalTypes.variableLengthString(JDBCType.VARCHAR, 255))); + assertFalse( + SchemaUtil.compareSchemaFieldType( + Schema.FieldType.INT16, LogicalTypes.variableLengthString(JDBCType.VARCHAR, 255))); + assertTrue( + SchemaUtil.compareSchemaFieldType( + 
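+            // a VARCHAR logical type compares equal to its STRING base type: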
LogicalTypes.variableLengthString(JDBCType.VARCHAR, 255), Schema.FieldType.STRING)); + } }
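Taken together, the utilities above are what the schema-driven write executes per element. Below is a standalone sketch of that flow under stated assumptions: it must live in package org.apache.beam.sdk.io.jdbc because SchemaUtil.FieldWithIndex is package-private, the in-memory Derby URL is illustrative, and the person table is assumed to already exist.

```java
package org.apache.beam.sdk.io.jdbc;

import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.PreparedStatement;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.values.Row;

public class GeneratedWriteSketch {
  public static void main(String[] args) throws Exception {
    Schema schema = Schema.builder().addStringField("name").addInt32Field("id").build();
    Row row = Row.withSchema(schema).addValues("Alice", 1).build();

    // what Write.expand() derives from withTable("person") plus the input schema
    String insert = JdbcUtil.generateStatement("person", schema.getFields());
    // -> INSERT INTO person(name, id) VALUES(?, ?)

    try (Connection conn = DriverManager.getConnection("jdbc:derby:memory:demo;create=true");
        PreparedStatement ps = conn.prepareStatement(insert)) {
      // one PreparedStatementSetCaller per field, mirroring
      // AutoGeneratedPreparedStatementSetter.setParameters()
      for (int i = 0; i < schema.getFieldCount(); i++) {
        Schema.Field field = schema.getField(i);
        JdbcUtil.getPreparedStatementSetCaller(field.getType())
            .set(row, ps, i, SchemaUtil.FieldWithIndex.of(field, i));
      }
      ps.executeUpdate(); // assumes the person table was created beforehand
    }
  }
}
```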