Commit 744e0d5

Refactor Snowflake internal Staging as a base class for other staging classes (#10865)

* Refactor Snowflake internal Staging as a model to share staging abilities across JDBC destinations

ChristopheDuong committed Mar 11, 2022
1 parent e27cb91 · commit 744e0d5

Showing 37 changed files with 298 additions and 171 deletions.
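The core of the change is extracting the Snowflake-specific staging calls behind a generic abstraction that other JDBC destinations can implement. As a rough orientation only, here is a sketch of what that abstraction presumably looks like, reverse-engineered from the call sites in StagingConsumerFactory below; the real interface's package, name, and exact signatures may differ:

// Hypothetical sketch, inferred from this diff's call sites -- not the actual Airbyte source.
import io.airbyte.db.jdbc.JdbcDatabase;
import io.airbyte.integrations.destination.jdbc.SqlOperations;
import org.joda.time.DateTime;

public interface StagingOperations extends SqlOperations {

  // Naming helpers used when preparing streams and building upload paths.
  String getStageName(String schemaName, String outputTableName);

  String getStagingPath(String connectionId, String schemaName, String tableName, DateTime writeDatetime);

  // Stage lifecycle, driven by onStartFunction and onCloseFunction below.
  void createStageIfNotExists(JdbcDatabase database, String stageName) throws Exception;

  void copyIntoTmpTableFromStage(JdbcDatabase database, String stagingPath, String tmpTableName, String schemaName) throws Exception;

  void cleanUpStage(JdbcDatabase database, String stagingPath) throws Exception;

  void dropStageIfExists(JdbcDatabase database, String stageName) throws Exception;
}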
===== changed file =====
@@ -27,10 +27,6 @@ protected String disabled_convertStreamName(final String input) {
}
}

- protected String applyDefaultCase(final String input) {
- return input;
- }

protected boolean useExtendedIdentifiers(final String input) {
boolean result = false;
if (input.matches("[^\\p{Alpha}_].*")) {
===== changed file =====
@@ -45,4 +45,8 @@ public interface NamingConventionTransformer {
@Deprecated
String getTmpTableName(String name);

+ String convertStreamName(final String input);
+
+ String applyDefaultCase(final String input);

}
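With convertStreamName and applyDefaultCase promoted into the interface, callers can hold any NamingConventionTransformer and compose the two steps without knowing the concrete destination. A minimal, hypothetical usage sketch (import paths are assumptions; the printed value is approximate):

import io.airbyte.integrations.destination.NamingConventionTransformer;             // assumed path
import io.airbyte.integrations.destination.clickhouse.ClickhouseSQLNameTransformer; // assumed path

public class NamingDemo {

  public static void main(String[] args) {
    final NamingConventionTransformer resolver = new ClickhouseSQLNameTransformer();
    // convertStreamName sanitizes to alphanumerics and underscores (inherited from
    // StandardNameTransformer); applyDefaultCase then applies the destination's
    // preferred casing, which Clickhouse overrides to lower-case below.
    final String tableName = resolver.applyDefaultCase(resolver.convertStreamName("My Stream-Name"));
    System.out.println(tableName); // roughly "my_stream_name"
  }
}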
===== changed file =====
@@ -32,10 +32,16 @@ public String getTmpTableName(final String streamName) {
return convertStreamName(Strings.addRandomSuffix("_airbyte_tmp", "_", 3) + "_" + streamName);
}

- protected String convertStreamName(final String input) {
+ @Override
+ public String convertStreamName(final String input) {
return Names.toAlphanumericAndUnderscore(input);
}

+ @Override
+ public String applyDefaultCase(final String input) {
+ return input;
+ }

/**
* Rebuild a JsonNode adding sanitized property names (a subset of special characters replaced by
* underscores) while keeping original property names too. This is needed by some destinations as
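The Javadoc above describes producing a JsonNode that keeps each original property name and adds an underscore-sanitized twin. A self-contained illustration of that described behavior (this mimics the idea with Jackson; it is not Airbyte's implementation):

import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ObjectNode;

public class SanitizedPropsDemo {

  public static void main(String[] args) {
    final ObjectMapper mapper = new ObjectMapper();
    final ObjectNode node = mapper.createObjectNode().put("user-name", "ada");

    // Keep the original property and add a sanitized twin, as the comment describes.
    final ObjectNode out = node.deepCopy();
    node.fieldNames().forEachRemaining(name -> {
      final String sanitized = name.replaceAll("[^A-Za-z0-9_]", "_");
      if (!sanitized.equals(name)) {
        out.set(sanitized, node.get(name));
      }
    });
    System.out.println(out); // {"user-name":"ada","user_name":"ada"}
  }
}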
===== changed file =====
@@ -3,19 +3,16 @@
#


- from typing import Mapping, Any, Iterable
+ from typing import Any, Iterable, Mapping

from airbyte_cdk import AirbyteLogger
from airbyte_cdk.destinations import Destination
- from airbyte_cdk.models import AirbyteConnectionStatus, ConfiguredAirbyteCatalog, AirbyteMessage, Status
+ from airbyte_cdk.models import AirbyteConnectionStatus, AirbyteMessage, ConfiguredAirbyteCatalog, Status


class Destination{{properCase name}}(Destination):
def write(
- self,
- config: Mapping[str, Any],
- configured_catalog: ConfiguredAirbyteCatalog,
- input_messages: Iterable[AirbyteMessage]
+ self, config: Mapping[str, Any], configured_catalog: ConfiguredAirbyteCatalog, input_messages: Iterable[AirbyteMessage]
) -> Iterable[AirbyteMessage]:

"""
@@ -54,6 +51,3 @@ class Destination{{properCase name}}(Destination):
return AirbyteConnectionStatus(status=Status.SUCCEEDED)
except Exception as e:
return AirbyteConnectionStatus(status=Status.FAILED, message=f"An exception occurred: {repr(e)}")



===== changed file =====
@@ -9,9 +9,7 @@
"airbyte-cdk",
]

- TEST_REQUIREMENTS = [
- "pytest~=6.1"
- ]
+ TEST_REQUIREMENTS = ["pytest~=6.1"]

setup(
name="destination_{{snakeCase name}}",
===== changed file =====
@@ -9,7 +9,7 @@
public class ClickhouseSQLNameTransformer extends ExtendedNameTransformer {

@Override
- protected String applyDefaultCase(final String input) {
+ public String applyDefaultCase(final String input) {
return input.toLowerCase();
}

===== changed file =====
@@ -29,7 +29,7 @@ public String getRawTableName(final String streamName) {
}

@Override
- protected String applyDefaultCase(final String input) {
+ public String applyDefaultCase(final String input) {
return input.toLowerCase();
}

===== changed file =====
@@ -5,6 +5,10 @@
package io.airbyte.integrations.destination.jdbc;

import io.airbyte.protocol.models.DestinationSyncMode;
+ import java.util.ArrayList;
+ import java.util.List;
+ import org.joda.time.DateTime;
+ import org.joda.time.DateTimeZone;

/**
* Write configuration POJO for all destinations extending {@link AbstractJdbcDestination}.
@@ -19,19 +23,33 @@ public class WriteConfig {
private final String tmpTableName;
private final String outputTableName;
private final DestinationSyncMode syncMode;
+ private final DateTime writeDatetime;
+ private final List<String> stagedFiles;

public WriteConfig(final String streamName,
final String namespace,
final String outputSchemaName,
final String tmpTableName,
final String outputTableName,
final DestinationSyncMode syncMode) {
+ this(streamName, namespace, outputSchemaName, tmpTableName, outputTableName, syncMode, DateTime.now(DateTimeZone.UTC));
}

+ public WriteConfig(final String streamName,
+ final String namespace,
+ final String outputSchemaName,
+ final String tmpTableName,
+ final String outputTableName,
+ final DestinationSyncMode syncMode,
+ final DateTime writeDatetime) {
this.streamName = streamName;
this.namespace = namespace;
this.outputSchemaName = outputSchemaName;
this.tmpTableName = tmpTableName;
this.outputTableName = outputTableName;
this.syncMode = syncMode;
+ this.stagedFiles = new ArrayList<>();
+ this.writeDatetime = writeDatetime;
}

public String getStreamName() {
@@ -58,6 +76,22 @@ public DestinationSyncMode getSyncMode() {
return syncMode;
}

+ public DateTime getWriteDatetime() {
+ return writeDatetime;
+ }
+
+ public List<String> getStagedFiles() {
+ return stagedFiles;
+ }
+
+ public void addStagedFile(final String file) {
+ stagedFiles.add(file);
+ }
+
+ public void clearStagedFiles() {
+ stagedFiles.clear();
+ }

@Override
public String toString() {
return "WriteConfig{" +
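A quick sketch of how a staging consumer might use the new WriteConfig surface, based only on the API added above (the table and file names are made up):

import io.airbyte.protocol.models.DestinationSyncMode;
import org.joda.time.DateTimeZone;

public class WriteConfigDemo {

  public static void main(String[] args) {
    // The 6-argument constructor defaults writeDatetime to DateTime.now(DateTimeZone.UTC).
    final WriteConfig cfg = new WriteConfig(
        "users", "public", "airbyte_out", "_airbyte_tmp_abc_users", "users",
        DestinationSyncMode.APPEND);

    cfg.addStagedFile("users_part_0.csv"); // hypothetical staged-file name
    System.out.println(cfg.getStagedFiles()); // files awaiting a COPY into the tmp table

    cfg.clearStagedFiles(); // e.g. once the staged data has been loaded
    System.out.println(cfg.getWriteDatetime().withZone(DateTimeZone.UTC));
  }
}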
===== changed file =====
@@ -2,7 +2,7 @@
* Copyright (c) 2021 Airbyte, Inc., all rights reserved.
*/

- package io.airbyte.integrations.destination.snowflake;
+ package io.airbyte.integrations.destination.staging;

import com.fasterxml.jackson.databind.JsonNode;
import com.google.common.base.Preconditions;
@@ -16,7 +16,6 @@
import io.airbyte.integrations.destination.buffered_stream_consumer.OnCloseFunction;
import io.airbyte.integrations.destination.buffered_stream_consumer.OnStartFunction;
import io.airbyte.integrations.destination.buffered_stream_consumer.RecordWriter;
- import io.airbyte.integrations.destination.jdbc.SqlOperations;
import io.airbyte.integrations.destination.jdbc.WriteConfig;
import io.airbyte.protocol.models.AirbyteMessage;
import io.airbyte.protocol.models.AirbyteStream;
@@ -30,38 +29,40 @@
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.stream.Collectors;
+ import org.joda.time.DateTime;
+ import org.joda.time.DateTimeZone;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Snowflake Internal Staging consists of 4 main parts:
 *
 * CREATE STAGE @TEMP_STAGE_NAME -- creates a new named internal stage to use for loading data from files into Snowflake tables and unloading data from tables into files.
 * PUT file://local/<file-patterns> @TEMP_STAGE_NAME -- the JDBC driver uploads the files into the stage.
 * COPY FROM @TEMP_STAGE_NAME -- loads data from staged files to an existing table.
 * DROP @TEMP_STAGE_NAME -- drops the temporary stage after the sync.
 */
- public class SnowflakeInternalStagingConsumerFactory {
+ public class StagingConsumerFactory {

- private static final Logger LOGGER = LoggerFactory.getLogger(SnowflakeInternalStagingConsumerFactory.class);
+ private static final Logger LOGGER = LoggerFactory.getLogger(StagingConsumerFactory.class);

private static final long MAX_BATCH_SIZE_BYTES = 128 * 1024 * 1024; // 128mb
- private final String CURRENT_SYNC_PATH = UUID.randomUUID().toString();
+ private final DateTime CURRENT_SYNC_PATH = DateTime.now(DateTimeZone.UTC);
+ // Using a random string here as a placeholder for the moment. This avoids mixing data in the
+ // staging area between different syncs (especially if they manipulate streams with similar names).
+ // If we replaced the random connection id with the actual connection_id, we'd gain the opportunity
+ // to leverage data that was uploaded to the stage in a previous attempt but failed to load to the
+ // warehouse for some reason (interrupted?) instead of discarding it. This would also allow other
+ // programs/scripts to load (or reload backups?) into the connection's staging area, to be loaded at
+ // the next sync.
+ private final String RANDOM_CONNECTION_ID = UUID.randomUUID().toString();

public AirbyteMessageConsumer create(final Consumer<AirbyteMessage> outputRecordCollector,
final JdbcDatabase database,
- final SnowflakeStagingSqlOperations sqlOperations,
- final SnowflakeSQLNameTransformer namingResolver,
+ final StagingOperations sqlOperations,
+ final NamingConventionTransformer namingResolver,
final JsonNode config,
final ConfiguredAirbyteCatalog catalog) {
final List<WriteConfig> writeConfigs = createWriteConfigs(namingResolver, config, catalog);

return new BufferedStreamConsumer(
outputRecordCollector,
- onStartFunction(database, sqlOperations, writeConfigs, namingResolver),
- recordWriterFunction(database, sqlOperations, writeConfigs, catalog, namingResolver),
- onCloseFunction(database, sqlOperations, writeConfigs, namingResolver),
+ onStartFunction(database, sqlOperations, writeConfigs),
+ recordWriterFunction(database, sqlOperations, writeConfigs, catalog),
+ onCloseFunction(database, sqlOperations, writeConfigs),
catalog,
sqlOperations::isValidData,
MAX_BATCH_SIZE_BYTES);
@@ -74,8 +75,7 @@ private static List<WriteConfig> createWriteConfigs(final NamingConventionTransformer
return catalog.getStreams().stream().map(toWriteConfig(namingResolver, config)).collect(Collectors.toList());
}

- private static Function<ConfiguredAirbyteStream, WriteConfig> toWriteConfig(
- final NamingConventionTransformer namingResolver,
+ private static Function<ConfiguredAirbyteStream, WriteConfig> toWriteConfig(final NamingConventionTransformer namingResolver,
final JsonNode config) {
return stream -> {
Preconditions.checkNotNull(stream.getDestinationSyncMode(), "Undefined destination sync mode");
@@ -104,26 +104,25 @@ private static String getOutputSchema(final AirbyteStream stream,
}

private static OnStartFunction onStartFunction(final JdbcDatabase database,
- final SnowflakeStagingSqlOperations snowflakeSqlOperations,
- final List<WriteConfig> writeConfigs,
- final SnowflakeSQLNameTransformer namingResolver) {
+ final StagingOperations stagingOperations,
+ final List<WriteConfig> writeConfigs) {
return () -> {
LOGGER.info("Preparing tmp tables in destination started for {} streams", writeConfigs.size());

for (final WriteConfig writeConfig : writeConfigs) {
final String schema = writeConfig.getOutputSchemaName();
final String stream = writeConfig.getStreamName();
final String tmpTable = writeConfig.getTmpTableName();
- final String stage = namingResolver.getStageName(schema, writeConfig.getOutputTableName());
+ final String stage = stagingOperations.getStageName(schema, writeConfig.getOutputTableName());

LOGGER.info("Preparing stage in destination started for schema {} stream {}: tmp table: {}, stage: {}",
schema, stream, tmpTable, stage);

AirbyteSentry.executeWithTracing("PrepareStreamStage",
() -> {
- snowflakeSqlOperations.createSchemaIfNotExists(database, schema);
- snowflakeSqlOperations.createTableIfNotExists(database, schema, tmpTable);
- snowflakeSqlOperations.createStageIfNotExists(database, stage);
+ stagingOperations.createSchemaIfNotExists(database, schema);
+ stagingOperations.createTableIfNotExists(database, schema, tmpTable);
+ stagingOperations.createStageIfNotExists(database, stage);
},
Map.of("schema", schema, "stream", stream, "tmpTable", tmpTable, "stage", stage));

@@ -139,14 +138,13 @@ private static AirbyteStreamNameNamespacePair toNameNamespacePair(final WriteConfig
}

private RecordWriter recordWriterFunction(final JdbcDatabase database,
- final SqlOperations snowflakeSqlOperations,
+ final StagingOperations stagingOperations,
final List<WriteConfig> writeConfigs,
- final ConfiguredAirbyteCatalog catalog,
- final SnowflakeSQLNameTransformer namingResolver) {
+ final ConfiguredAirbyteCatalog catalog) {
final Map<AirbyteStreamNameNamespacePair, WriteConfig> pairToWriteConfig =
writeConfigs.stream()
.collect(Collectors.toUnmodifiableMap(
- SnowflakeInternalStagingConsumerFactory::toNameNamespacePair, Function.identity()));
+ StagingConsumerFactory::toNameNamespacePair, Function.identity()));

return (pair, records) -> {
if (!pairToWriteConfig.containsKey(pair)) {
Expand All @@ -157,16 +155,14 @@ private RecordWriter recordWriterFunction(final JdbcDatabase database,
final WriteConfig writeConfig = pairToWriteConfig.get(pair);
final String schemaName = writeConfig.getOutputSchemaName();
final String tableName = writeConfig.getOutputTableName();
- final String path = namingResolver.getStagingPath(schemaName, tableName, CURRENT_SYNC_PATH);
-
- snowflakeSqlOperations.insertRecords(database, records, schemaName, path);
+ final String path = stagingOperations.getStagingPath(RANDOM_CONNECTION_ID, schemaName, tableName, CURRENT_SYNC_PATH);
+ stagingOperations.insertRecords(database, records, schemaName, path);
};
}

private OnCloseFunction onCloseFunction(final JdbcDatabase database,
- final SnowflakeStagingSqlOperations sqlOperations,
- final List<WriteConfig> writeConfigs,
- final SnowflakeSQLNameTransformer namingResolver) {
+ final StagingOperations stagingOperations,
+ final List<WriteConfig> writeConfigs) {
return (hasFailed) -> {
if (!hasFailed) {
final List<String> queryList = new ArrayList<>();
@@ -177,29 +173,29 @@ private OnCloseFunction onCloseFunction(final JdbcDatabase database,
final String streamName = writeConfig.getStreamName();
final String srcTableName = writeConfig.getTmpTableName();
final String dstTableName = writeConfig.getOutputTableName();
- final String path = namingResolver.getStagingPath(schemaName, dstTableName, CURRENT_SYNC_PATH);
+ final String path = stagingOperations.getStagingPath(RANDOM_CONNECTION_ID, schemaName, dstTableName, CURRENT_SYNC_PATH);
LOGGER.info("Finalizing stream {}. schema {}, tmp table {}, final table {}, stage path {}",
streamName, schemaName, srcTableName, dstTableName, path);

try {
- sqlOperations.copyIntoTmpTableFromStage(database, path, srcTableName, schemaName);
+ stagingOperations.copyIntoTmpTableFromStage(database, path, srcTableName, schemaName);
} catch (final Exception e) {
- sqlOperations.cleanUpStage(database, path);
+ stagingOperations.cleanUpStage(database, path);
LOGGER.info("Cleaning stage path {}", path);
throw new RuntimeException("Failed to upload data from stage " + path, e);
}

- sqlOperations.createTableIfNotExists(database, schemaName, dstTableName);
+ stagingOperations.createTableIfNotExists(database, schemaName, dstTableName);
switch (writeConfig.getSyncMode()) {
- case OVERWRITE -> queryList.add(sqlOperations.truncateTableQuery(database, schemaName, dstTableName));
+ case OVERWRITE -> queryList.add(stagingOperations.truncateTableQuery(database, schemaName, dstTableName));
case APPEND, APPEND_DEDUP -> {}
default -> throw new IllegalStateException("Unrecognized sync mode: " + writeConfig.getSyncMode());
}
- queryList.add(sqlOperations.copyTableQuery(database, schemaName, srcTableName, dstTableName));
+ queryList.add(stagingOperations.copyTableQuery(database, schemaName, srcTableName, dstTableName));
}

LOGGER.info("Executing finalization of tables.");
- sqlOperations.executeTransaction(database, queryList);
+ stagingOperations.executeTransaction(database, queryList);
LOGGER.info("Finalizing tables in destination completed.");
}
LOGGER.info("Cleaning tmp tables in destination started for {} streams", writeConfigs.size());
@@ -209,12 +205,12 @@ private OnCloseFunction onCloseFunction(final JdbcDatabase database,
LOGGER.info("Cleaning tmp table in destination started for stream {}. schema {}, tmp table name: {}", writeConfig.getStreamName(), schemaName,
tmpTableName);

- sqlOperations.dropTableIfExists(database, schemaName, tmpTableName);
+ stagingOperations.dropTableIfExists(database, schemaName, tmpTableName);
final String outputTableName = writeConfig.getOutputTableName();
- final String stageName = namingResolver.getStageName(schemaName, outputTableName);
+ final String stageName = stagingOperations.getStageName(schemaName, outputTableName);
LOGGER.info("Cleaning stage in destination started for stream {}. schema {}, stage: {}", writeConfig.getStreamName(), schemaName,
stageName);
- sqlOperations.dropStageIfExists(database, stageName);
+ stagingOperations.dropStageIfExists(database, stageName);
}
LOGGER.info("Cleaning tmp tables and stages in destination completed.");
};
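Putting it together: the Snowflake classes from the pre-refactor code presumably now implement the generic types, so a destination wires the factory roughly like this (a sketch; that SnowflakeStagingSqlOperations implements StagingOperations after this refactor is an assumption, as is the helper's shape):

// Hypothetical wiring; parameter types are taken from the new create(...) signature above.
static AirbyteMessageConsumer buildConsumer(final Consumer<AirbyteMessage> outputRecordCollector,
                                            final JdbcDatabase database,
                                            final JsonNode config,
                                            final ConfiguredAirbyteCatalog catalog) {
  return new StagingConsumerFactory().create(
      outputRecordCollector,
      database,
      new SnowflakeStagingSqlOperations(), // assumed to implement StagingOperations
      new SnowflakeSQLNameTransformer(),   // a NamingConventionTransformer
      config,
      catalog);
}
// The returned consumer then runs onStartFunction (create schema/table/stage), buffers records
// into the stage via insertRecords, and on close copies staged data into tmp tables and finalizes them.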