Skip to content

Commit

Permalink
source-mysql : chunking queries impl (#29109)
Browse files Browse the repository at this point in the history
  • Loading branch information
akashkulk committed Aug 18, 2023
1 parent 528127a commit 2b18864
Show file tree
Hide file tree
Showing 9 changed files with 306 additions and 92 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
package io.airbyte.integrations.source.mysql;

import static io.airbyte.integrations.source.relationaldb.RelationalDbQueryUtils.getFullyQualifiedTableNameWithQuoting;

import com.fasterxml.jackson.databind.JsonNode;
import com.google.common.base.Preconditions;
import io.airbyte.db.jdbc.JdbcDatabase;
import io.airbyte.db.jdbc.JdbcUtils;
import io.airbyte.protocol.models.AirbyteStreamNameNamespacePair;
import io.airbyte.protocol.models.v0.ConfiguredAirbyteStream;
import java.sql.SQLException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Utility methods for computing per-table size estimates used to size the chunked
 * initial-load queries for the MySQL source.
 */
public class MySqlQueryUtils {

  private static final Logger LOGGER = LoggerFactory.getLogger(MySqlQueryUtils.class);

  /**
   * Size estimates for a single table.
   *
   * @param tableSize total table size in bytes (data + index length)
   * @param avgRowLength average row length in bytes, as reported by information_schema
   */
  public record TableSizeInfo(Long tableSize, Long avgRowLength) { }

  // NOTE: identifiers are interpolated via String.format rather than bound as parameters.
  // The schema/table values come from the configured catalog (not end-user free text), and
  // information_schema figures are estimates only (especially for InnoDB).
  public static final String TABLE_ESTIMATE_QUERY =
      """
         SELECT
           (data_length + index_length) as %s,
           AVG_ROW_LENGTH as %s
        FROM
           information_schema.tables
        WHERE
           table_schema = '%s' AND table_name = '%s';
      """;

  public static final String TABLE_SIZE_BYTES_COL = "TotalSizeBytes";
  public static final String AVG_ROW_LENGTH = "AVG_ROW_LENGTH";

  // Utility class; no instances.
  private MySqlQueryUtils() {}

  /**
   * Queries information_schema for every configured stream and returns a map from stream
   * (name, namespace) to its size estimate. Streams whose estimate cannot be obtained are
   * omitted from the map (the failure is logged, not rethrown), so callers must tolerate
   * missing entries.
   *
   * @param database the JDBC database to query
   * @param streams the configured streams to estimate
   * @param quoteString the identifier quote string for this database (used only for logging
   *        the fully-qualified table name)
   * @return map of stream pair to {@link TableSizeInfo}; possibly missing entries on error
   */
  public static Map<AirbyteStreamNameNamespacePair, TableSizeInfo> getTableSizeInfoForStreams(final JdbcDatabase database,
                                                                                              final List<ConfiguredAirbyteStream> streams,
                                                                                              final String quoteString) {
    final Map<AirbyteStreamNameNamespacePair, TableSizeInfo> tableSizeInfoMap = new HashMap<>();
    streams.forEach(stream -> {
      try {
        final String name = stream.getStream().getName();
        final String namespace = stream.getStream().getNamespace();
        // Bug fix: namespace (schema) comes first — the previous (name, namespace) order
        // produced a reversed "name.namespace" identifier. See the call in
        // MySqlInitialLoadHandler, which passes (schemaName, tableName, quoteString).
        final String fullTableName =
            getFullyQualifiedTableNameWithQuoting(namespace, name, quoteString);
        // getTableEstimate already enforces exactly one result row.
        final List<JsonNode> tableEstimateResult = getTableEstimate(database, namespace, name);
        final long tableEstimateBytes = tableEstimateResult.get(0).get(TABLE_SIZE_BYTES_COL).asLong();
        final long avgTableRowSizeBytes = tableEstimateResult.get(0).get(AVG_ROW_LENGTH).asLong();
        LOGGER.info("Stream {} size estimate is {}, average row size estimate is {}", fullTableName, tableEstimateBytes, avgTableRowSizeBytes);
        final TableSizeInfo tableSizeInfo = new TableSizeInfo(tableEstimateBytes, avgTableRowSizeBytes);
        final AirbyteStreamNameNamespacePair namespacePair =
            new AirbyteStreamNameNamespacePair(name, namespace);
        tableSizeInfoMap.put(namespacePair, tableSizeInfo);
      } catch (final SQLException e) {
        // Best-effort: a missing estimate only degrades chunk sizing, it must not fail the sync.
        LOGGER.warn("Error occurred while attempting to estimate sync size", e);
      }
    });
    return tableSizeInfoMap;
  }

  /**
   * Runs the information_schema size-estimate query for a single table.
   *
   * @return exactly one JSON row containing {@link #TABLE_SIZE_BYTES_COL} and
   *         {@link #AVG_ROW_LENGTH}
   * @throws SQLException if the query fails
   */
  private static List<JsonNode> getTableEstimate(final JdbcDatabase database, final String namespace, final String name)
      throws SQLException {
    // Construct the table estimate query.
    final String tableEstimateQuery =
        String.format(TABLE_ESTIMATE_QUERY, TABLE_SIZE_BYTES_COL, AVG_ROW_LENGTH, namespace, name);
    LOGGER.info("table estimate query: {}", tableEstimateQuery);
    final List<JsonNode> jsonNodes = database.bufferedResultSetQuery(conn -> conn.createStatement().executeQuery(tableEstimateQuery),
        resultSet -> JdbcUtils.getDefaultSourceOperations().rowToJson(resultSet));
    // (schema, table) is a primary key of information_schema.tables, so exactly one row is expected.
    Preconditions.checkState(jsonNodes.size() == 1);
    return jsonNodes;
  }

}
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ public class MySqlInitialLoadGlobalStateManager implements MySqlInitialLoadState
// have completed the snapshot.
private final Set<AirbyteStreamNameNamespacePair> streamsThatHaveCompletedSnapshot;


MySqlInitialLoadGlobalStateManager(final InitialLoadStreams initialLoadStreams,
final Map<AirbyteStreamNameNamespacePair, PrimaryKeyInfo> pairToPrimaryKeyInfo,
final CdcState cdcState, final ConfiguredAirbyteCatalog catalog) {
Expand All @@ -55,6 +56,7 @@ private static Set<AirbyteStreamNameNamespacePair> initStreamsCompletedSnapshot(
return streamsThatHaveCompletedSnapshot;
}


private static Map<AirbyteStreamNameNamespacePair, PrimaryKeyLoadStatus> initPairToPrimaryKeyLoadStatusMap(
final Map<io.airbyte.protocol.models.v0.AirbyteStreamNameNamespacePair, PrimaryKeyLoadStatus> pairToPkStatus) {
final Map<AirbyteStreamNameNamespacePair, PrimaryKeyLoadStatus> map = new HashMap<>();
Expand Down Expand Up @@ -82,6 +84,11 @@ public AirbyteStateMessage createIntermediateStateMessage(final AirbyteStreamNam
.withGlobal(globalState);
}

/**
 * Records the latest primary-key load checkpoint for the given stream.
 * Overwrites any previously stored status for the pair, so the most recent
 * checkpoint always wins.
 */
@Override
public void updatePrimaryKeyLoadState(final AirbyteStreamNameNamespacePair pair, final PrimaryKeyLoadStatus pkLoadStatus) {
pairToPrimaryKeyLoadStatus.put(pair, pkLoadStatus);
}

public AirbyteStateMessage createFinalStateMessage(final AirbyteStreamNameNamespacePair pair, final JsonNode streamStateForIncrementalRun) {
streamsThatHaveCompletedSnapshot.add(pair);
final List<AirbyteStreamState> streamStates = new ArrayList<>();
Expand Down
Original file line number Diff line number Diff line change
@@ -1,18 +1,15 @@
package io.airbyte.integrations.source.mysql.initialsync;

import static io.airbyte.integrations.source.relationaldb.RelationalDbQueryUtils.enquoteIdentifier;
import static io.airbyte.integrations.source.relationaldb.RelationalDbQueryUtils.getFullyQualifiedTableNameWithQuoting;

import com.fasterxml.jackson.databind.JsonNode;
import com.google.common.annotations.VisibleForTesting;
import com.mysql.cj.MysqlType;
import io.airbyte.commons.stream.AirbyteStreamUtils;
import io.airbyte.commons.util.AutoCloseableIterator;
import io.airbyte.commons.util.AutoCloseableIterators;
import io.airbyte.db.jdbc.JdbcDatabase;
import io.airbyte.integrations.source.mysql.initialsync.MySqlInitialReadUtil.PrimaryKeyInfo;
import io.airbyte.integrations.source.mysql.MySqlQueryUtils.TableSizeInfo;
import io.airbyte.integrations.source.mysql.internal.models.PrimaryKeyLoadStatus;
import io.airbyte.integrations.source.relationaldb.DbSourceDiscoverUtil;
import io.airbyte.integrations.source.relationaldb.RelationalDbQueryUtils;
import io.airbyte.integrations.source.relationaldb.TableInfo;
import io.airbyte.protocol.models.AirbyteStreamNameNamespacePair;
import io.airbyte.protocol.models.CommonField;
Expand All @@ -24,9 +21,6 @@
import io.airbyte.protocol.models.v0.ConfiguredAirbyteCatalog;
import io.airbyte.protocol.models.v0.ConfiguredAirbyteStream;
import io.airbyte.protocol.models.v0.SyncMode;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.time.Duration;
import java.time.Instant;
import java.util.ArrayList;
Expand All @@ -35,7 +29,6 @@
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand All @@ -50,18 +43,24 @@ public class MySqlInitialLoadHandler {
private final MySqlInitialLoadStateManager initialLoadStateManager;
private final Function<AirbyteStreamNameNamespacePair, JsonNode> streamStateForIncrementalRunSupplier;

private static final long QUERY_TARGET_SIZE_GB = 1_073_741_824;
private static final long DEFAULT_CHUNK_SIZE = 1_000_000;
final Map<AirbyteStreamNameNamespacePair, TableSizeInfo> tableSizeInfoMap;

/**
 * Creates a handler for the initial (primary-key ordered, chunked) load of MySQL streams.
 *
 * NOTE(review): the scraped diff left both the removed parameter line and its replacement in
 * place, producing a duplicated/dangling `streamStateForIncrementalRunSupplier) {` line; this
 * is the clean post-change signature.
 *
 * @param config connector configuration
 * @param database JDBC database to read from
 * @param sourceOperations row-to-JSON conversion and cursor binding operations
 * @param quoteString identifier quote string for this database
 * @param initialLoadStateManager tracks per-stream primary-key load checkpoints
 * @param streamStateForIncrementalRunSupplier supplies the stream state to emit once the
 *        initial load of a stream completes
 * @param tableSizeInfoMap per-stream size estimates used to compute chunk sizes
 */
public MySqlInitialLoadHandler(final JsonNode config,
                               final JdbcDatabase database,
                               final MySqlInitialLoadSourceOperations sourceOperations,
                               final String quoteString,
                               final MySqlInitialLoadStateManager initialLoadStateManager,
                               final Function<AirbyteStreamNameNamespacePair, JsonNode> streamStateForIncrementalRunSupplier,
                               final Map<AirbyteStreamNameNamespacePair, TableSizeInfo> tableSizeInfoMap) {
  this.config = config;
  this.database = database;
  this.sourceOperations = sourceOperations;
  this.quoteString = quoteString;
  this.initialLoadStateManager = initialLoadStateManager;
  this.streamStateForIncrementalRunSupplier = streamStateForIncrementalRunSupplier;
  this.tableSizeInfoMap = tableSizeInfoMap;
}

public List<AutoCloseableIterator<AirbyteMessage>> getIncrementalIterators(
Expand All @@ -88,7 +87,9 @@ public List<AutoCloseableIterator<AirbyteMessage>> getIncrementalIterators(
.map(CommonField::getName)
.filter(CatalogHelpers.getTopLevelFieldNames(airbyteStream)::contains)
.collect(Collectors.toList());
final AutoCloseableIterator<JsonNode> queryStream = queryTablePk(selectedDatabaseFields, table.getNameSpace(), table.getName());
final AutoCloseableIterator<JsonNode> queryStream =
new MySqlInitialLoadRecordIterator(database, sourceOperations, quoteString, initialLoadStateManager, selectedDatabaseFields, pair,
calculateChunkSize(tableSizeInfoMap.get(pair), pair));
final AutoCloseableIterator<AirbyteMessage> recordIterator =
getRecordIterator(queryStream, streamName, namespace, emittedAt.toEpochMilli());
final AutoCloseableIterator<AirbyteMessage> recordAndMessageIterator = augmentWithState(recordIterator, pair);
Expand All @@ -100,75 +101,18 @@ public List<AutoCloseableIterator<AirbyteMessage>> getIncrementalIterators(
return iteratorList;
}

/**
 * Builds a lazy iterator over all rows of the given table, ordered by primary key.
 * The query is not executed until the iterator is first consumed (lazyIterator), so
 * queueing many streams up front does not open many result sets at once.
 * SQLException from query construction is wrapped in RuntimeException because the
 * iterator contract cannot surface checked exceptions.
 */
private AutoCloseableIterator<JsonNode> queryTablePk(
final List<String> columnNames,
final String schemaName,
final String tableName) {
LOGGER.info("Queueing query for table: {}", tableName);
final AirbyteStreamNameNamespacePair airbyteStream =
AirbyteStreamUtils.convertFromNameAndNamespace(tableName, schemaName);
return AutoCloseableIterators.lazyIterator(() -> {
try {
// Execution is deferred to here — the statement is created only on first use.
final Stream<JsonNode> stream = database.unsafeQuery(
connection -> createPkQueryStatement(connection, columnNames, schemaName, tableName, airbyteStream),
sourceOperations::rowToJson);
return AutoCloseableIterators.fromStream(stream, airbyteStream);
} catch (final SQLException e) {
throw new RuntimeException(e);
}
}, airbyteStream);
}

/**
 * Prepares the primary-key-ordered SELECT for one table, resuming from the saved
 * primary-key checkpoint when one exists (see getPkPreparedStatement).
 * Wraps SQLException in RuntimeException to satisfy the caller's functional interface.
 */
private PreparedStatement createPkQueryStatement(
final Connection connection,
final List<String> columnNames,
final String schemaName,
final String tableName,
final AirbyteStreamNameNamespacePair pair) {
try {
LOGGER.info("Preparing query for table: {}", tableName);
// Fully-qualified, quoted `schema`.`table` identifier.
final String fullTableName = getFullyQualifiedTableNameWithQuoting(schemaName, tableName,
quoteString);

final String wrappedColumnNames = RelationalDbQueryUtils.enquoteIdentifierList(columnNames, quoteString);

// pkLoadStatus is null on a fresh sync; non-null means we resume past the last synced pk value.
final PrimaryKeyLoadStatus pkLoadStatus = initialLoadStateManager.getPrimaryKeyLoadStatus(pair);
final PrimaryKeyInfo pkInfo = initialLoadStateManager.getPrimaryKeyInfo(pair);
final PreparedStatement preparedStatement =
getPkPreparedStatement(connection, wrappedColumnNames, fullTableName, pkLoadStatus, pkInfo);
LOGGER.info("Executing query for table {}: {}", tableName, preparedStatement);
return preparedStatement;
} catch (final SQLException e) {
throw new RuntimeException(e);
}
}

private PreparedStatement getPkPreparedStatement(final Connection connection,
final String wrappedColumnNames,
final String fullTableName,
final PrimaryKeyLoadStatus pkLoadStatus,
final PrimaryKeyInfo pkInfo)
throws SQLException {

if (pkLoadStatus == null) {
final String quotedCursorField = enquoteIdentifier(pkInfo.pkFieldName(), quoteString);
final String sql = String.format("SELECT %s FROM %s ORDER BY %s", wrappedColumnNames, fullTableName,
quotedCursorField, quotedCursorField);
final PreparedStatement preparedStatement = connection.prepareStatement(sql);
return preparedStatement;

} else {
final String quotedCursorField = enquoteIdentifier(pkLoadStatus.getPkName(), quoteString);
// Since a pk is unique, we can issue a > query instead of a >=, as there cannot be two records with the same pk.
final String sql = String.format("SELECT %s FROM %s WHERE %s > ? ORDER BY %s", wrappedColumnNames, fullTableName,
quotedCursorField, quotedCursorField);

final PreparedStatement preparedStatement = connection.prepareStatement(sql);
final MysqlType cursorFieldType = pkInfo.fieldType();
sourceOperations.setCursorField(preparedStatement, 1, cursorFieldType, pkLoadStatus.getPkVal());

return preparedStatement;
// Calculates the number of rows to fetch per query, targeting roughly
// QUERY_TARGET_SIZE_GB bytes per chunk (the constant holds 1 GiB in *bytes*
// despite its name).
@VisibleForTesting
public static long calculateChunkSize(final TableSizeInfo tableSizeInfo, final AirbyteStreamNameNamespacePair pair) {
  // If table size info could not be calculated, a default chunk size will be provided.
  // TableSizeInfo holds boxed Longs, so also guard against null components: unboxing a
  // null in the `== 0` comparison would throw a NullPointerException.
  if (tableSizeInfo == null
      || tableSizeInfo.tableSize() == null || tableSizeInfo.tableSize() == 0
      || tableSizeInfo.avgRowLength() == null || tableSizeInfo.avgRowLength() == 0) {
    LOGGER.info("Chunk size could not be determined for pair: {}, defaulting to {} rows", pair, DEFAULT_CHUNK_SIZE);
    return DEFAULT_CHUNK_SIZE;
  }
  final long avgRowLength = tableSizeInfo.avgRowLength();
  // Floor at 1: if a single row exceeds the byte target, integer division yields 0,
  // which would produce empty chunks and stall the sync.
  final long chunkSize = Math.max(1L, QUERY_TARGET_SIZE_GB / avgRowLength);
  LOGGER.info("Chunk size determined for pair: {}, is {}", pair, chunkSize);
  return chunkSize;
}

// Transforms the given iterator to create an {@link AirbyteRecordMessage}
Expand Down

0 comments on commit 2b18864

Please sign in to comment.