Commit

add DestinationState stuff
edgao committed Feb 21, 2024
1 parent d2c07e5 commit c0221ad
Showing 20 changed files with 778 additions and 173 deletions.
@@ -36,12 +36,14 @@
import io.airbyte.integrations.base.destination.typing_deduping.NoopV2TableMigrator;
import io.airbyte.integrations.base.destination.typing_deduping.ParsedCatalog;
import io.airbyte.integrations.base.destination.typing_deduping.TyperDeduper;
import io.airbyte.integrations.base.destination.typing_deduping.migrators.MinimumDestinationState;
import io.airbyte.protocol.models.v0.AirbyteConnectionStatus;
import io.airbyte.protocol.models.v0.AirbyteConnectionStatus.Status;
import io.airbyte.protocol.models.v0.AirbyteMessage;
import io.airbyte.protocol.models.v0.ConfiguredAirbyteCatalog;
import io.airbyte.protocol.models.v0.ConfiguredAirbyteStream;
import java.sql.SQLException;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
@@ -249,7 +251,7 @@ private void assertCustomParametersDontOverwriteDefaultParameters(final Map<Stri

protected abstract JdbcSqlGenerator getSqlGenerator();

protected abstract JdbcDestinationHandler getDestinationHandler(final String databaseName, final JdbcDatabase database);
protected abstract JdbcDestinationHandler<? extends MinimumDestinationState> getDestinationHandler(final String databaseName, final JdbcDatabase database);

/**
* "database" key at root of the config json, for any other variants in config, override this
@@ -293,15 +295,15 @@ public SerializedAirbyteMessageConsumer getSerializedMessageConsumer(final JsonN
final String databaseName = getDatabaseName(config);
final var migrator = new JdbcV1V2Migrator(namingResolver, database, databaseName);
final NoopV2TableMigrator v2TableMigrator = new NoopV2TableMigrator();
final DestinationHandler destinationHandler = getDestinationHandler(databaseName, database);
final DestinationHandler<? extends MinimumDestinationState> destinationHandler = getDestinationHandler(databaseName, database);
final boolean disableTypeDedupe = config.has(DISABLE_TYPE_DEDUPE) && config.get(DISABLE_TYPE_DEDUPE).asBoolean(false);
final TyperDeduper typerDeduper;
if (disableTypeDedupe) {
typerDeduper = new NoOpTyperDeduperWithV1V2Migrations(sqlGenerator, destinationHandler, parsedCatalog, migrator, v2TableMigrator,
8);
} else {
typerDeduper =
new DefaultTyperDeduper<>(sqlGenerator, destinationHandler, parsedCatalog, migrator, v2TableMigrator, 8);
new DefaultTyperDeduper<>(sqlGenerator, destinationHandler, parsedCatalog, migrator, v2TableMigrator, Collections.emptyList(), 8);
}
return JdbcBufferedConsumerFactory.createAsync(
outputRecordCollector,
JdbcDestinationHandler.java
@@ -8,69 +8,99 @@
import static io.airbyte.cdk.integrations.base.JavaBaseConstants.COLUMN_NAME_AB_META;
import static io.airbyte.cdk.integrations.base.JavaBaseConstants.COLUMN_NAME_AB_RAW_ID;
import static io.airbyte.cdk.integrations.base.JavaBaseConstants.V2_FINAL_TABLE_METADATA_COLUMNS;
import static java.util.stream.Collectors.toMap;
import static org.jooq.impl.DSL.asterisk;
import static org.jooq.impl.DSL.createTableIfNotExists;
import static org.jooq.impl.DSL.deleteFrom;
import static org.jooq.impl.DSL.exists;
import static org.jooq.impl.DSL.field;
import static org.jooq.impl.DSL.insertInto;
import static org.jooq.impl.DSL.name;
import static org.jooq.impl.DSL.quotedName;
import static org.jooq.impl.DSL.select;
import static org.jooq.impl.DSL.selectOne;
import static org.jooq.impl.DSL.table;

import com.fasterxml.jackson.databind.JsonNode;
import io.airbyte.cdk.db.jdbc.JdbcDatabase;
import io.airbyte.cdk.integrations.destination.jdbc.ColumnDefinition;
import io.airbyte.cdk.integrations.destination.jdbc.TableDefinition;
import io.airbyte.cdk.integrations.util.ConnectorExceptionUtil;
import io.airbyte.commons.concurrency.CompletableFutures;
import io.airbyte.commons.exceptions.SQLRuntimeException;
import io.airbyte.commons.functional.Either;
import io.airbyte.commons.json.Jsons;
import io.airbyte.integrations.base.destination.typing_deduping.AirbyteProtocolType;
import io.airbyte.integrations.base.destination.typing_deduping.AirbyteType;
import io.airbyte.integrations.base.destination.typing_deduping.DestinationHandler;
import io.airbyte.integrations.base.destination.typing_deduping.DestinationInitialState;
import io.airbyte.integrations.base.destination.typing_deduping.DestinationInitialStateImpl;
import io.airbyte.integrations.base.destination.typing_deduping.InitialRawTableState;
import io.airbyte.integrations.base.destination.typing_deduping.Sql;
import io.airbyte.integrations.base.destination.typing_deduping.StreamConfig;
import io.airbyte.integrations.base.destination.typing_deduping.StreamId;
import io.airbyte.integrations.base.destination.typing_deduping.Struct;
import io.airbyte.protocol.models.v0.AirbyteStreamNameNamespacePair;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Timestamp;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.stream.Stream;
import lombok.extern.slf4j.Slf4j;
import org.jooq.Condition;
import org.jooq.DSLContext;
import org.jooq.InsertValuesStep3;
import org.jooq.Record;
import org.jooq.SQLDialect;
import org.jooq.conf.ParamType;
import org.jooq.impl.DSL;
import org.jooq.impl.SQLDataType;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@Slf4j
public abstract class JdbcDestinationHandler implements DestinationHandler {
public abstract class JdbcDestinationHandler<DestinationState> implements DestinationHandler<DestinationState> {

private static final Logger LOGGER = LoggerFactory.getLogger(JdbcDestinationHandler.class);
private static final String DESTINATION_STATE_TABLE_NAME = "_airbyte_destination_state";
private static final String DESTINATION_STATE_TABLE_COLUMN_NAME = "name";
private static final String DESTINATION_STATE_TABLE_COLUMN_NAMESPACE = "namespace";
private static final String DESTINATION_STATE_TABLE_COLUMN_STATE = "state";

protected final String databaseName;
protected final JdbcDatabase jdbcDatabase;
protected final String rawTableSchemaName;
private final SQLDialect dialect;

public JdbcDestinationHandler(final String databaseName,
final JdbcDatabase jdbcDatabase) {
final JdbcDatabase jdbcDatabase,
final String rawTableSchemaName,
final SQLDialect dialect) {
this.databaseName = databaseName;
this.jdbcDatabase = jdbcDatabase;
this.rawTableSchemaName = rawTableSchemaName;
this.dialect = dialect;
}

protected DSLContext getDslContext() {
return DSL.using(dialect);
}

private Optional<TableDefinition> findExistingTable(final StreamId id) throws Exception {
return findExistingTable(jdbcDatabase, databaseName, id.finalNamespace(), id.finalName());
return findExistingTable(id);
}

private boolean isFinalTableEmpty(final StreamId id) throws Exception {
return !jdbcDatabase.queryBoolean(
select(
getDslContext().select(
field(exists(
selectOne()
.from(name(id.finalNamespace(), id.finalName()))
@@ -99,7 +129,7 @@ private InitialRawTableState getInitialRawTableState(final StreamId id) throws E
// but it's also the only method in the JdbcDatabase interface to return non-string/int types
try (final Stream<Timestamp> timestampStream = jdbcDatabase.unsafeQuery(
conn -> conn.prepareStatement(
select(field("MIN(_airbyte_extracted_at)").as("min_timestamp"))
getDslContext().select(field("MIN(_airbyte_extracted_at)").as("min_timestamp"))
.from(name(id.rawNamespace(), id.rawName()))
.where(DSL.condition("_airbyte_loaded_at IS NULL"))
.getSQL()),
@@ -118,7 +148,7 @@ record -> record.getTimestamp("min_timestamp"))) {
// This second query just finds the newest raw record.
try (final Stream<Timestamp> timestampStream = jdbcDatabase.unsafeQuery(
conn -> conn.prepareStatement(
select(field("MAX(_airbyte_extracted_at)").as("min_timestamp"))
getDslContext().select(field("MAX(_airbyte_extracted_at)").as("min_timestamp"))
.from(name(id.rawNamespace(), id.rawName()))
.getSQL()),
record -> record.getTimestamp("min_timestamp"))) {
@@ -149,16 +179,45 @@ public void execute(final Sql sql) throws Exception {
}

@Override
public List<DestinationInitialState> gatherInitialState(List<StreamConfig> streamConfigs) throws Exception {
final List<CompletionStage<DestinationInitialState>> initialStates = streamConfigs.stream()
.map(this::retrieveState)
public List<DestinationInitialState<DestinationState>> gatherInitialState(List<StreamConfig> streamConfigs) throws Exception {
// Use stream n/ns pair because we don't want to build the full StreamId here
CompletableFuture<Map<AirbyteStreamNameNamespacePair, DestinationState>> destinationStatesFuture = CompletableFuture.supplyAsync(() -> {
try {
// Guarantee the table exists.
jdbcDatabase.execute(
getDslContext().createTableIfNotExists(quotedName(rawTableSchemaName, DESTINATION_STATE_TABLE_NAME))
.column(quotedName(DESTINATION_STATE_TABLE_COLUMN_NAME), SQLDataType.VARCHAR)
.column(quotedName(DESTINATION_STATE_TABLE_COLUMN_NAMESPACE), SQLDataType.VARCHAR)
// Just use a string type, even if the destination has a json type.
// We're never going to query this column in a fancy way - all our processing can happen client-side.
.column(quotedName(DESTINATION_STATE_TABLE_COLUMN_STATE), SQLDataType.VARCHAR)
.getSQL(ParamType.INLINED)
);
// Fetch all records from it. We _could_ filter down to just our streams... but meh. This is small data.
return jdbcDatabase.queryJsons(
getDslContext().select(asterisk())
.from(quotedName(rawTableSchemaName, DESTINATION_STATE_TABLE_NAME))
.getSQL()
).stream().collect(toMap(
record -> new AirbyteStreamNameNamespacePair(
record.get(DESTINATION_STATE_TABLE_COLUMN_NAME).asText(),
record.get(DESTINATION_STATE_TABLE_COLUMN_NAMESPACE).asText()),
record -> toDestinationState(Jsons.deserialize(record.get(DESTINATION_STATE_TABLE_COLUMN_STATE).asText()))
));
} catch (SQLException e) {
throw new RuntimeException(e);
}
});

final List<CompletionStage<DestinationInitialState<DestinationState>>> initialStates = streamConfigs.stream()
.map(streamConfig -> retrieveState(destinationStatesFuture, streamConfig))
.toList();
final List<Either<? extends Exception, DestinationInitialState>> states = CompletableFutures.allOf(initialStates).toCompletableFuture().join();
final List<Either<? extends Exception, DestinationInitialState<DestinationState>>> states = CompletableFutures.allOf(initialStates).toCompletableFuture().join();
return ConnectorExceptionUtil.getResultsOrLogAndThrowFirst("Failed to retrieve initial state", states);
}

private CompletionStage<DestinationInitialState> retrieveState(final StreamConfig streamConfig) {
return CompletableFuture.supplyAsync(() -> {
private CompletionStage<DestinationInitialState<DestinationState>> retrieveState(final CompletableFuture<Map<AirbyteStreamNameNamespacePair, DestinationState>> destinationStatesFuture, final StreamConfig streamConfig) {
return destinationStatesFuture.thenApply(destinationStates -> {
try {
final Optional<TableDefinition> finalTableDefinition = findExistingTable(streamConfig.id());
// Only evaluate schema mismatch & final table emptiness if the final table exists.
@@ -169,8 +228,9 @@ private CompletionStage<DestinationInitialState> retrieveState(final StreamConfi
isFinalTableEmpty = isFinalTableEmpty(streamConfig.id());
}
final InitialRawTableState initialRawTableState = getInitialRawTableState(streamConfig.id());
return new DestinationInitialStateImpl(streamConfig, finalTableDefinition.isPresent(), initialRawTableState,
isSchemaMismatch, isFinalTableEmpty);
DestinationState destinationState = destinationStates.getOrDefault(streamConfig.id().asPair(), toDestinationState(Jsons.emptyObject()));
return new DestinationInitialState<>(streamConfig, finalTableDefinition.isPresent(), initialRawTableState,
isSchemaMismatch, isFinalTableEmpty, destinationState);
} catch (Exception e) {
throw new RuntimeException(e);
}
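Instead of one supplyAsync per stream (the removed code above), gatherInitialState now scans the state table once in a single shared future, and each stream derives its initial state from it with thenApply, falling back to an empty state when the table has no row for that stream. A standalone sketch of that pattern follows; the stream names and state strings are invented for illustration.

import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;

public class SharedStateFutureSketch {

  public static void main(String[] args) {
    // One shared async lookup, analogous to the single scan of _airbyte_destination_state.
    final CompletableFuture<Map<String, String>> sharedStates =
        CompletableFuture.supplyAsync(() -> Map.of("users", "{\"softReset\":false}"));

    // Per-stream work chains off the shared future instead of re-querying,
    // and streams without a stored row fall back to a default (cf. getOrDefault above).
    final List<CompletableFuture<String>> perStream = List.of("users", "orders").stream()
        .map(stream -> sharedStates.thenApply(states ->
            stream + " -> " + states.getOrDefault(stream, "{}")))
        .toList();

    perStream.forEach(future -> System.out.println(future.join()));
  }
}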
@@ -253,6 +313,35 @@ protected boolean existingSchemaMatchesStreamConfig(final StreamConfig stream, f
return actualColumns.equals(intendedColumns);
}

@Override
public void commitDestinationStates(final Map<StreamId, DestinationState> destinationStates) throws Exception {
// Delete all state records where the stream name+namespace match one of our states
String deleteStates = getDslContext().deleteFrom(table(quotedName(rawTableSchemaName, DESTINATION_STATE_TABLE_NAME)))
.where(destinationStates.keySet().stream()
.map(streamId -> field(DESTINATION_STATE_TABLE_COLUMN_NAME).eq(streamId.originalName())
.and(field(DESTINATION_STATE_TABLE_COLUMN_NAMESPACE).eq(streamId.originalNamespace())))
.reduce(
DSL.trueCondition(),
Condition::or
))
.getSQL(ParamType.INLINED);

// Reinsert all of our states
InsertValuesStep3<Record, String, String, String> insertStatesStep = getDslContext().insertInto(table(quotedName(rawTableSchemaName, DESTINATION_STATE_TABLE_NAME)))
.columns(
field(quotedName(DESTINATION_STATE_TABLE_COLUMN_NAME), String.class),
field(quotedName(DESTINATION_STATE_TABLE_COLUMN_NAMESPACE), String.class),
field(quotedName(DESTINATION_STATE_TABLE_COLUMN_STATE), String.class));
for (Map.Entry<StreamId, DestinationState> destinationState : destinationStates.entrySet()) {
final StreamId streamId = destinationState.getKey();
final String stateJson = Jsons.serialize(destinationState.getValue());
insertStatesStep = insertStatesStep.values(streamId.originalName(), streamId.originalNamespace(), stateJson);
}
String insertStates = insertStatesStep.getSQL(ParamType.INLINED);

jdbcDatabase.executeWithinTransaction(List.of(deleteStates, insertStates));
}

/**
* Convert to the TYPE_NAME retrieved from {@link java.sql.DatabaseMetaData#getColumns}
*
@@ -261,4 +350,6 @@ protected boolean existingSchemaMatchesStreamConfig(final StreamConfig stream, f
*/
protected abstract String toJdbcTypeName(final AirbyteType airbyteType);

protected abstract DestinationState toDestinationState(final JsonNode json);

}
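To make the jOOQ plumbing above more tangible, here is a standalone sketch that prints roughly the SQL this handler generates for the state table. The airbyte_internal schema, the POSTGRES dialect, the users/public stream, and the state JSON are assumptions for illustration only; the table and column names mirror the constants defined in this class, and the delete/insert pair corresponds to what commitDestinationStates runs inside a single transaction.

import static org.jooq.impl.DSL.field;
import static org.jooq.impl.DSL.quotedName;
import static org.jooq.impl.DSL.table;
import static org.jooq.impl.DSL.trueCondition;

import org.jooq.DSLContext;
import org.jooq.SQLDialect;
import org.jooq.conf.ParamType;
import org.jooq.impl.DSL;
import org.jooq.impl.SQLDataType;

public class DestinationStateSqlSketch {

  public static void main(String[] args) {
    final DSLContext dsl = DSL.using(SQLDialect.POSTGRES); // dialect chosen arbitrarily for the sketch
    final String schema = "airbyte_internal"; // assumed rawTableSchemaName

    // DDL issued by gatherInitialState to guarantee the state table exists.
    System.out.println(dsl.createTableIfNotExists(quotedName(schema, "_airbyte_destination_state"))
        .column(quotedName("name"), SQLDataType.VARCHAR)
        .column(quotedName("namespace"), SQLDataType.VARCHAR)
        .column(quotedName("state"), SQLDataType.VARCHAR)
        .getSQL(ParamType.INLINED));

    // Delete of any existing rows for the committed streams...
    System.out.println(dsl.deleteFrom(table(quotedName(schema, "_airbyte_destination_state")))
        .where(trueCondition().or(field("name").eq("users").and(field("namespace").eq("public"))))
        .getSQL(ParamType.INLINED));

    // ...followed by reinserting the current state, stored as a plain JSON string.
    System.out.println(dsl.insertInto(table(quotedName(schema, "_airbyte_destination_state")))
        .columns(
            field(quotedName("name"), String.class),
            field(quotedName("namespace"), String.class),
            field(quotedName("state"), String.class))
        .values("users", "public", "{\"example\":true}")
        .getSQL(ParamType.INLINED));
  }
}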
JdbcSqlGeneratorIntegrationTest.java
@@ -36,7 +36,7 @@
import org.jooq.impl.DSL;
import org.jooq.impl.SQLDataType;

public abstract class JdbcSqlGeneratorIntegrationTest extends BaseSqlGeneratorIntegrationTest {
public abstract class JdbcSqlGeneratorIntegrationTest<DestinationState> extends BaseSqlGeneratorIntegrationTest<DestinationState> {

protected abstract JdbcDatabase getDatabase();

BaseDestinationV1V2Migrator.java
@@ -20,7 +20,7 @@ public abstract class BaseDestinationV1V2Migrator<DialectTableDefinition> implem
@Override
public void migrateIfNecessary(
final SqlGenerator sqlGenerator,
final DestinationHandler destinationHandler,
final DestinationHandler<?> destinationHandler,
final StreamConfig streamConfig)
throws Exception {
LOGGER.info("Assessing whether migration is necessary for stream {}", streamConfig.id().finalName());
@@ -60,7 +60,7 @@ protected boolean shouldMigrate(final StreamConfig streamConfig) throws Exceptio
* @param streamConfig the stream to migrate the raw table of
*/
public void migrate(final SqlGenerator sqlGenerator,
final DestinationHandler destinationHandler,
final DestinationHandler<?> destinationHandler,
final StreamConfig streamConfig)
throws TableNotMigratedException {
final var namespacedTableName = convertToV1RawName(streamConfig);
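The widening to DestinationHandler<?> works because the V1-to-V2 migrator never reads or writes per-stream destination state; it only needs state-agnostic operations such as execute(Sql), implemented by JdbcDestinationHandler above. A hypothetical helper, purely to illustrate why the unbounded wildcard suffices:

// Hypothetical helper, not part of this commit. The wildcard is fine because the only
// handler method used here is independent of the DestinationState type parameter.
final class MigrationSqlRunner {

  static void runMigrationSql(final DestinationHandler<?> destinationHandler, final Sql migrationSql) throws Exception {
    destinationHandler.execute(migrationSql);
  }
}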