diff --git a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java index 895e85886cc5a..34133b2583685 100644 --- a/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java +++ b/sdks/java/io/jdbc/src/main/java/org/apache/beam/sdk/io/jdbc/JdbcIO.java @@ -24,6 +24,8 @@ import java.sql.Connection; import java.sql.PreparedStatement; import java.sql.ResultSet; +import java.util.ArrayList; +import java.util.List; import javax.annotation.Nullable; import javax.sql.DataSource; @@ -57,14 +59,25 @@ * * pipeline.apply(JdbcIO.read() * .withDataSource(myDataSource) + * .withRowMapper(new JdbcIO.RowMapper() { + * public MyElement mapRow(ResultSet resultSet) { + * // use the resultSet to build the element of the PCollection + * // for instance: + * // return resultSet.getString(2); + * } + * }) * * } * + *

+ * You can find a full {@link RowMapper} example in {@code JdbcIOTest}. + *

*

Writing to JDBC datasource

*

* JDBC sink supports writing records into a database. It expects a * {@code PCollection}, converts the {@code T} elements as SQL statement - * and insert into the database. T is the type expected by the provided {@link ElementInserter}. + * and inserts them into the database. T is the type expected by the + * provided {@link PreparedStatementSetter}. *

*

* Like the source, to configure JDBC sink, you have to provide a datasource. For instance: @@ -74,10 +87,22 @@ * * pipeline * .apply(...) - * .apply(JdbcIO.write().withDataSource(myDataSource) + * .apply(JdbcIO.write() + * .withDataSource(myDataSource) + * .withQuery(query) + * .withPreparedStatementSetter(new JdbcIO.PreparedStatementSetter() { + * public void setParameters(MyElement element, PreparedStatement statement) { + * // use the PCollection element to set parameters of the SQL statement used to insert + * // in the database + * // for instance: statement.setString(1, element.toString()); + * } + * }) * * } * + *

+ * You can find a full {@link PreparedStatementSetter} example in {@code JdbcIOTest}. + *

*/ public class JdbcIO { @@ -98,7 +123,8 @@ public static Read read() { * @return a {@link Write} {@link PTransform}. */ public static Write write() { - return new Write(new Write.JdbcWriter(null, null, null, null)); + return new Write(new Write.JdbcWriter(null, null, null, null, + null, 1024L)); } private JdbcIO() { @@ -110,9 +136,7 @@ private JdbcIO() { * object used in the {@link PCollection}. */ public interface RowMapper extends Serializable { - T mapRow(ResultSet resultSet); - } /** @@ -261,9 +285,7 @@ public void processElement(ProcessContext context) throws Exception { try (ResultSet resultSet = statement.executeQuery()) { while (resultSet.next()) { T record = rowMapper.mapRow(resultSet); - if (record != null) { context.output(record); - } } } } @@ -273,15 +295,11 @@ public void processElement(ProcessContext context) throws Exception { } /** - * An interface used by the JdbcIO Write for mapping {@link PCollection} elements as a rows of a - * ResultSet on a per-row basis. - * Implementations of this interface perform the actual work of mapping each row to a result - * object used in the {@link PCollection}. + * An interface used by the JdbcIO Write to set the parameters of the {@link PreparedStatement} + * used to insert into the database. 
*/ - public interface ElementInserter extends Serializable { - - PreparedStatement insert(T element, Connection connection); - + public interface PreparedStatementSetter extends Serializable { + void setParameters(T element, PreparedStatement preparedStatement) throws Exception; } /** @@ -293,6 +311,10 @@ public Write withDataSource(DataSource dataSource) { return new Write<>(writer.withDataSource(dataSource)); } + public Write withQuery(String query) { + return new Write<>(writer.withQuery(query)); + } + public Write withUsername(String username) { return new Write<>(writer.withUsername(username)); } @@ -301,8 +323,13 @@ public Write withPassword(String password) { return new Write<>(writer.withPassword(password)); } - public Write withElementInserter(ElementInserter elementInserter) { - return new Write<>(writer.withElementInserter(elementInserter)); + public Write withPreparedStatementSetter( + PreparedStatementSetter preparedStatementSetter) { + return new Write<>(writer.withPreparedStatementSetter(preparedStatementSetter)); + } + + public Write withBatchSize(long batchSize) { + return new Write<>(writer.withBatchSize(batchSize)); } private final JdbcWriter writer; @@ -325,39 +352,60 @@ public void validate(PCollection input) { private static class JdbcWriter extends DoFn { private final DataSource dataSource; + private final String query; private final String username; private final String password; - private final ElementInserter elementInserter; + private final PreparedStatementSetter preparedStatementSetter; + private long batchSize; private Connection connection; + private List batch; - public JdbcWriter(DataSource dataSource, String username, String password, - ElementInserter elementInserter) { + public JdbcWriter(DataSource dataSource, String query, String username, String password, + PreparedStatementSetter preparedStatementSetter, long batchSize) { this.dataSource = dataSource; + this.query = query; this.username = username; this.password = password; 
- this.elementInserter = elementInserter; + this.preparedStatementSetter = preparedStatementSetter; + this.batchSize = batchSize; } public JdbcWriter withDataSource(DataSource dataSource) { - return new JdbcWriter<>(dataSource, username, password, elementInserter); + return new JdbcWriter<>(dataSource, query, username, password, preparedStatementSetter, + batchSize); + } + + public JdbcWriter withQuery(String query) { + return new JdbcWriter<>(dataSource, query, username, password, preparedStatementSetter, + batchSize); } public JdbcWriter withUsername(String username) { - return new JdbcWriter<>(dataSource, username, password, elementInserter); + return new JdbcWriter<>(dataSource, query, username, password, preparedStatementSetter, + batchSize); } public JdbcWriter withPassword(String password) { - return new JdbcWriter<>(dataSource, username, password, elementInserter); + return new JdbcWriter<>(dataSource, query, username, password, preparedStatementSetter, + batchSize); + } + + public JdbcWriter withPreparedStatementSetter( + PreparedStatementSetter preparedStatementSetter) { + return new JdbcWriter<>(dataSource, query, username, password, preparedStatementSetter, + batchSize); } - public JdbcWriter withElementInserter(ElementInserter elementInserter) { - return new JdbcWriter<>(dataSource, username, password, elementInserter); + public JdbcWriter withBatchSize(long batchSize) { + return new JdbcWriter<>(dataSource, query, username, password, preparedStatementSetter, + batchSize); } public void validate() { Preconditions.checkNotNull(dataSource, "dataSource"); - Preconditions.checkNotNull(elementInserter, "elementInserter"); + Preconditions.checkNotNull(query, "query"); + Preconditions.checkNotNull(preparedStatementSetter, "preparedStatementSetter"); } @Setup @@ -367,29 +415,40 @@ public void connectToDatabase() throws Exception { } else { connection = dataSource.getConnection(); } - connection.setAutoCommit(false); + connection.setAutoCommit(true); + } + + 
@StartBundle + public void startBundle(Context context) { + batch = new ArrayList<>(); } @ProcessElement - public void processElement(ProcessContext context) { + public void processElement(ProcessContext context) throws Exception { T record = context.element(); - try { - PreparedStatement statement = elementInserter.insert(record, connection); - if (statement != null) { + + batch.add(record); + if (batch.size() >= batchSize) { + finishBundle(context); + } + } + + @FinishBundle + public void finishBundle(Context context) throws Exception { + for (T record : batch) { + try { + PreparedStatement statement = connection.prepareStatement(query); + preparedStatementSetter.setParameters(record, statement); try { statement.executeUpdate(); } finally { statement.close(); } + } catch (Exception e) { + LOGGER.error("Can't insert into database", e); } - } catch (Exception e) { - LOGGER.warn("Can't insert data into table", e); } - } - - @FinishBundle - public void finishBundle(Context context) throws Exception { - connection.commit(); + batch.clear(); } @Teardown diff --git a/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOTest.java b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOTest.java index 586916132e337..2d8fe34c5698c 100644 --- a/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOTest.java +++ b/sdks/java/io/jdbc/src/test/java/org/apache/beam/sdk/io/jdbc/JdbcIOTest.java @@ -20,25 +20,14 @@ import static org.junit.Assert.assertEquals; import java.io.Serializable; -import java.math.BigDecimal; import java.net.InetAddress; -import java.sql.Array; -import java.sql.Blob; -import java.sql.Clob; import java.sql.Connection; -import java.sql.Date; -import java.sql.NClob; import java.sql.PreparedStatement; import java.sql.ResultSet; import java.sql.ResultSetMetaData; import java.sql.Statement; -import java.sql.Time; -import java.sql.Timestamp; import java.sql.Types; import java.util.ArrayList; -import java.util.HashMap; -import 
java.util.List; -import java.util.Map; import org.apache.beam.sdk.coders.SerializableCoder; import org.apache.beam.sdk.testing.NeedsRunner; @@ -351,128 +340,21 @@ public void setColumnValue(Object columnValue) { public void testWrite() throws Exception { TestPipeline pipeline = TestPipeline.create(); - ArrayList data = new ArrayList<>(); + ArrayList> data = new ArrayList<>(); for (int i = 0; i < 1000; i++) { - JdbcDataRecord record = new JdbcDataRecord(2); - record.getTableNames()[0] = "TEST"; - record.getColumnNames()[0] = "ID"; - record.getColumnTypes()[0] = Types.INTEGER; - record.getColumnValues()[0] = i; - record.getTableNames()[1] = "TEST"; - record.getColumnNames()[1] = "NAME"; - record.getColumnTypes()[1] = Types.VARCHAR; - record.getColumnValues()[1] = "Test"; - data.add(record); + KV kv = KV.of(i, "TEST"); + data.add(kv); } pipeline.apply(Create.of(data)) .apply(JdbcIO.write().withDataSource(dataSource) - .withElementInserter(new JdbcIO.ElementInserter() { - public PreparedStatement insert(JdbcDataRecord element, Connection connection) { - PreparedStatement statement = null; - // map record per table - Map> tableMap = new HashMap<>(); - Map insertPerTable = new HashMap<>(); - for (int i = 0; i < element.getTableNames().length; i++) { - String tableName = element.getTableNames()[i]; - List recordList = tableMap.get(tableName); - if (recordList == null) { - recordList = new ArrayList<>(); - } - recordList.add(new InsertRecord( - element.getColumnTypes()[i], - element.getColumnValues()[i])); - tableMap.put(tableName, recordList); - } - // create insert string - for (String tableName : tableMap.keySet()) { - String insertString = "insert into " + tableName + " values("; - for (InsertRecord insertRecord : tableMap.get(tableName)) { - insertString = insertString + "?,"; - } - // remove trailing ',' and close parentheses - insertString = insertString.substring(0, insertString.length() - 1) + ")"; - LOGGER.debug(insertString); - try { - statement = 
connection.prepareStatement(insertString); - int index = 1; - for (InsertRecord insertRecord : tableMap.get(tableName)) { - switch (insertRecord.getColumnType()) { - case Types.ARRAY: - statement.setArray(index, (Array) insertRecord.getColumnValue()); - break; - case Types.BIGINT: - statement.setInt(index, (int) insertRecord.getColumnValue()); - break; - case Types.BIT: - statement.setInt(index, (int) insertRecord.getColumnValue()); - break; - case Types.BLOB: - statement.setBlob(index, (Blob) insertRecord.getColumnValue()); - break; - case Types.BOOLEAN: - statement.setBoolean(index, (boolean) insertRecord.getColumnValue()); - break; - case Types.CHAR: - statement.setString(index, (String) insertRecord.getColumnValue()); - break; - case Types.CLOB: - statement.setClob(index, (Clob) insertRecord.getColumnValue()); - break; - case Types.DATE: - statement.setDate(index, (Date) insertRecord.getColumnValue()); - break; - case Types.DECIMAL: - statement.setBigDecimal(index, - (BigDecimal) insertRecord.getColumnValue()); - break; - case Types.DOUBLE: - statement.setDouble(index, (double) insertRecord.getColumnValue()); - break; - case Types.FLOAT: - statement.setFloat(index, (float) insertRecord.getColumnValue()); - break; - case Types.INTEGER: - statement.setInt(index, (int) insertRecord.getColumnValue()); - break; - case Types.LONGNVARCHAR: - statement.setString(index, (String) insertRecord.getColumnValue()); - break; - case Types.LONGVARCHAR: - statement.setString(index, (String) insertRecord.getColumnValue()); - break; - case Types.NCHAR: - statement.setNString(index, (String) insertRecord.getColumnValue()); - break; - case Types.NCLOB: - statement.setNClob(index, (NClob) insertRecord.getColumnValue()); - break; - case Types.SMALLINT: - statement.setInt(index, (int) insertRecord.getColumnValue()); - break; - case Types.TIME: - statement.setTime(index, (Time) insertRecord.getColumnValue()); - break; - case Types.TIMESTAMP: - statement.setTimestamp(index, (Timestamp) 
insertRecord.getColumnValue()); - break; - case Types.TINYINT: - statement.setInt(index, (int) insertRecord.getColumnValue()); - break; - case Types.VARCHAR: - statement.setString(index, (String) insertRecord.getColumnValue()); - break; - default: - statement.setObject(index, insertRecord.getColumnValue()); - break; - } - index++; - } - } catch (Exception e) { - LOGGER.error("Can't prepare statement", e); - } - } - return statement; - }})); + .withQuery("insert into TEST values(?, ?)") + .withPreparedStatementSetter(new JdbcIO.PreparedStatementSetter>() { + public void setParameters(KV element, PreparedStatement statement) + throws Exception { + statement.setInt(1, element.getKey()); + statement.setString(2, element.getValue()); + } + })); pipeline.run();