🎉 Snowflake Destination internal staging support #8253

Merged
SnowflakeDestination.java:
@@ -20,7 +20,8 @@ public class SnowflakeDestination extends SwitchingDestination<DestinationType> {
enum DestinationType {
INSERT,
COPY_S3,
COPY_GCS
COPY_GCS,
INTERNAL_STAGING
}

public SnowflakeDestination() {
@@ -32,11 +33,18 @@ public static DestinationType getTypeFromConfig(final JsonNode config) {
return DestinationType.COPY_S3;
} else if (isGcsCopy(config)) {
return DestinationType.COPY_GCS;
} else if (isInternalStaging(config)) {
return DestinationType.INTERNAL_STAGING;
} else {
return DestinationType.INSERT;
}
}

public static boolean isInternalStaging(final JsonNode config) {
// Guard on has("method") so a loading_method object without a "method" field cannot NPE.
return config.has("loading_method") && config.get("loading_method").isObject()
&& config.get("loading_method").has("method")
&& config.get("loading_method").get("method").asText().equals("Internal Staging");
}
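
As a quick sanity check on the routing above, here is a minimal sketch of a config that selects internal staging (hypothetical values; only the fields these predicates inspect are shown):

// Minimal sketch (hypothetical config values).
final JsonNode config = Jsons.deserialize(
    "{\"schema\": \"PUBLIC\", \"loading_method\": {\"method\": \"Internal Staging\"}}");
// getTypeFromConfig(config) now returns DestinationType.INTERNAL_STAGING;
// without the loading_method block it falls through to INSERT.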

public static boolean isS3Copy(final JsonNode config) {
return config.has("loading_method") && config.get("loading_method").isObject() && config.get("loading_method").has("s3_bucket_name");
}
@@ -49,11 +57,13 @@ public static Map<DestinationType, Destination> getTypeToDestination() {
final SnowflakeInsertDestination insertDestination = new SnowflakeInsertDestination();
final SnowflakeCopyS3Destination copyS3Destination = new SnowflakeCopyS3Destination();
final SnowflakeCopyGcsDestination copyGcsDestination = new SnowflakeCopyGcsDestination();
final SnowflakeInternalStagingDestination internalStagingDestination = new SnowflakeInternalStagingDestination();

return ImmutableMap.of(
DestinationType.INSERT, insertDestination,
DestinationType.COPY_S3, copyS3Destination,
DestinationType.COPY_GCS, copyGcsDestination);
DestinationType.COPY_GCS, copyGcsDestination,
DestinationType.INTERNAL_STAGING, internalStagingDestination);
}

public static void main(final String[] args) throws Exception {
// ...

SnowflakeInternalStagingConsumerFactory.java (new file):
@@ -0,0 +1,200 @@
/*
* Copyright (c) 2021 Airbyte, Inc., all rights reserved.
*/

package io.airbyte.integrations.destination.snowflake;

import com.fasterxml.jackson.databind.JsonNode;
import com.google.common.base.Preconditions;
import io.airbyte.commons.json.Jsons;
import io.airbyte.db.jdbc.JdbcDatabase;
import io.airbyte.integrations.base.AirbyteMessageConsumer;
import io.airbyte.integrations.base.AirbyteStreamNameNamespacePair;
import io.airbyte.integrations.destination.NamingConventionTransformer;
import io.airbyte.integrations.destination.buffered_stream_consumer.BufferedStreamConsumer;
import io.airbyte.integrations.destination.buffered_stream_consumer.OnCloseFunction;
import io.airbyte.integrations.destination.buffered_stream_consumer.OnStartFunction;
import io.airbyte.integrations.destination.buffered_stream_consumer.RecordWriter;
import io.airbyte.integrations.destination.jdbc.SqlOperations;
import io.airbyte.integrations.destination.jdbc.WriteConfig;
import io.airbyte.protocol.models.AirbyteMessage;
import io.airbyte.protocol.models.AirbyteStream;
import io.airbyte.protocol.models.ConfiguredAirbyteCatalog;
import io.airbyte.protocol.models.ConfiguredAirbyteStream;
import io.airbyte.protocol.models.DestinationSyncMode;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class SnowflakeInternalStagingConsumerFactory {

private static final Logger LOGGER = LoggerFactory.getLogger(SnowflakeInternalStagingConsumerFactory.class);

private static final int MAX_BATCH_SIZE_BYTES = 1024 * 1024 * 1024 / 4; // 256 MiB

public static AirbyteMessageConsumer create(final Consumer<AirbyteMessage> outputRecordCollector,
final JdbcDatabase database,
final SnowflakeStagingSqlOperations sqlOperations,
final SnowflakeSQLNameTransformer namingResolver,
final JsonNode config,
final ConfiguredAirbyteCatalog catalog) {
final List<WriteConfig> writeConfigs = createWriteConfigs(namingResolver, config, catalog);

return new BufferedStreamConsumer(
outputRecordCollector,
onStartFunction(database, sqlOperations, writeConfigs, namingResolver),
recordWriterFunction(database, sqlOperations, writeConfigs, catalog, namingResolver),
onCloseFunction(database, sqlOperations, writeConfigs, namingResolver),
catalog,
sqlOperations::isValidData,
MAX_BATCH_SIZE_BYTES);
}
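
For orientation, the consumer built here follows the standard AirbyteMessageConsumer lifecycle. A minimal driver sketch, assuming database, config, catalog, and a messages iterable are already in scope (hypothetical wiring, not part of this PR):

final AirbyteMessageConsumer consumer = SnowflakeInternalStagingConsumerFactory.create(
    message -> System.out.println(Jsons.serialize(message)), // outputRecordCollector
    database, new SnowflakeStagingSqlOperations(), new SnowflakeSQLNameTransformer(), config, catalog);
consumer.start();            // onStartFunction: create schemas, tmp tables, and stages
for (final AirbyteMessage message : messages) {
  consumer.accept(message);  // buffered; flushed to the stage once the batch limit is hit
}
consumer.close();            // onCloseFunction: COPY INTO tmp tables, commit, clean up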

private static List<WriteConfig> createWriteConfigs(final NamingConventionTransformer namingResolver,
final JsonNode config,
final ConfiguredAirbyteCatalog catalog) {

return catalog.getStreams().stream().map(toWriteConfig(namingResolver, config)).collect(Collectors.toList());
}

private static Function<ConfiguredAirbyteStream, WriteConfig> toWriteConfig(
final NamingConventionTransformer namingResolver,
final JsonNode config) {
return stream -> {
Preconditions.checkNotNull(stream.getDestinationSyncMode(), "Undefined destination sync mode");
final AirbyteStream abStream = stream.getStream();

final String outputSchema = getOutputSchema(abStream, namingResolver.getIdentifier(config.get("schema").asText()));

final String streamName = abStream.getName();
final String tableName = namingResolver.getRawTableName(streamName);
final String tmpTableName = namingResolver.getTmpTableName(streamName);
final DestinationSyncMode syncMode = stream.getDestinationSyncMode();

final WriteConfig writeConfig = new WriteConfig(streamName, abStream.getNamespace(), outputSchema, tmpTableName, tableName, syncMode);
LOGGER.info("Write config: {}", writeConfig);

return writeConfig;
};
}

private static String getOutputSchema(final AirbyteStream stream, final String defaultDestSchema) {
final String sourceSchema = stream.getNamespace();
if (sourceSchema != null) {
return sourceSchema;
}
return defaultDestSchema;
}

private static OnStartFunction onStartFunction(final JdbcDatabase database,
final SnowflakeStagingSqlOperations snowflakeSqlOperations,
final List<WriteConfig> writeConfigs,
final SnowflakeSQLNameTransformer namingResolver) {
return () -> {
LOGGER.info("Preparing tmp tables in destination started for {} streams", writeConfigs.size());
for (final WriteConfig writeConfig : writeConfigs) {
final String schemaName = writeConfig.getOutputSchemaName();
final String tmpTableName = writeConfig.getTmpTableName();
LOGGER.info("Preparing tmp table in destination started for stream {}. schema: {}, tmp table name: {}", writeConfig.getStreamName(),
schemaName, tmpTableName);
final String outputTableName = writeConfig.getOutputTableName();
final String stageName = namingResolver.getStageName(schemaName, outputTableName);
LOGGER.info("Preparing stage in destination started for stream {}. schema: {}, stage: {}", writeConfig.getStreamName(),
schemaName, stageName);

snowflakeSqlOperations.createSchemaIfNotExists(database, schemaName);
snowflakeSqlOperations.createTableIfNotExists(database, schemaName, tmpTableName);
snowflakeSqlOperations.createStageIfNotExists(database, stageName);
LOGGER.info("Preparing stages in destination completed " + stageName);

}
LOGGER.info("Preparing tables in destination completed.");
};
}
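
Concretely, for a hypothetical stream users landing in schema PUBLIC (and assuming Airbyte's standard _airbyte_raw_ prefix; the tmp-table suffix is generated at runtime), the loop above issues roughly:

// CREATE SCHEMA IF NOT EXISTS PUBLIC;
// CREATE TABLE IF NOT EXISTS PUBLIC._airbyte_tmp_abc_users (...);  -- see createTableQuery below
// CREATE STAGE IF NOT EXISTS PUBLIC_AIRBYTE_RAW_USERS
//   encryption = (type = 'SNOWFLAKE_SSE') copy_options = (on_error='skip_file');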

private static AirbyteStreamNameNamespacePair toNameNamespacePair(final WriteConfig config) {
return new AirbyteStreamNameNamespacePair(config.getStreamName(), config.getNamespace());
}

private static RecordWriter recordWriterFunction(final JdbcDatabase database,
final SqlOperations snowflakeSqlOperations,
final List<WriteConfig> writeConfigs,
final ConfiguredAirbyteCatalog catalog,
final SnowflakeSQLNameTransformer namingResolver) {
final Map<AirbyteStreamNameNamespacePair, WriteConfig> pairToWriteConfig =
writeConfigs.stream()
.collect(Collectors.toUnmodifiableMap(
SnowflakeInternalStagingConsumerFactory::toNameNamespacePair, Function.identity()));

return (pair, records) -> {
if (!pairToWriteConfig.containsKey(pair)) {
throw new IllegalArgumentException(
String.format("Message contained record from a stream that was not in the catalog. \ncatalog: %s", Jsons.serialize(catalog)));
}

final WriteConfig writeConfig = pairToWriteConfig.get(pair);
final String schemaName = writeConfig.getOutputSchemaName();
final String tableName = writeConfig.getOutputTableName();
final String stageName = namingResolver.getStageName(schemaName, tableName);

snowflakeSqlOperations.insertRecords(database, records, schemaName, stageName);
};
}

private static OnCloseFunction onCloseFunction(final JdbcDatabase database,
final SnowflakeStagingSqlOperations sqlOperations,
final List<WriteConfig> writeConfigs,
final SnowflakeSQLNameTransformer namingResolver) {
return (hasFailed) -> {
if (!hasFailed) {
final List<String> queryList = new ArrayList<>();
LOGGER.info("Finalizing tables in destination started for {} streams", writeConfigs.size());
for (final WriteConfig writeConfig : writeConfigs) {
final String schemaName = writeConfig.getOutputSchemaName();
final String srcTableName = writeConfig.getTmpTableName();
final String dstTableName = writeConfig.getOutputTableName();
LOGGER.info("Finalizing stream {}. schema {}, tmp table {}, final table {}", writeConfig.getStreamName(), schemaName, srcTableName,
dstTableName);

final String stageName = namingResolver.getStageName(schemaName, dstTableName);
LOGGER.info("Uploading data from stage: stream {}. schema {}, tmp table {}, stage {}", writeConfig.getStreamName(), schemaName,
srcTableName,
stageName);
sqlOperations.copyIntoTmpTableFromStage(database, stageName, srcTableName, schemaName);
sqlOperations.createTableIfNotExists(database, schemaName, dstTableName);
switch (writeConfig.getSyncMode()) {
case OVERWRITE -> queryList.add(sqlOperations.truncateTableQuery(database, schemaName, dstTableName));
case APPEND, APPEND_DEDUP -> {}
default -> throw new IllegalStateException("Unrecognized sync mode: " + writeConfig.getSyncMode());
}
queryList.add(sqlOperations.copyTableQuery(database, schemaName, srcTableName, dstTableName));
}

LOGGER.info("Executing finalization of tables.");
sqlOperations.executeTransaction(database, queryList);
LOGGER.info("Finalizing tables in destination completed.");
}
LOGGER.info("Cleaning tmp tables in destination started for {} streams", writeConfigs.size());
for (final WriteConfig writeConfig : writeConfigs) {
final String schemaName = writeConfig.getOutputSchemaName();
final String tmpTableName = writeConfig.getTmpTableName();
LOGGER.info("Cleaning tmp table in destination started for stream {}. schema {}, tmp table name: {}", writeConfig.getStreamName(), schemaName,
tmpTableName);

sqlOperations.dropTableIfExists(database, schemaName, tmpTableName);
final String outputTableName = writeConfig.getOutputTableName();
final String stageName = namingResolver.getStageName(schemaName, outputTableName);
LOGGER.info("Cleaning stage in destination started for stream {}. schema {}, stage: {}", writeConfig.getStreamName(), schemaName,
stageName);
sqlOperations.dropStageIfExists(database, stageName);
}
LOGGER.info("Cleaning tmp tables and stages in destination completed.");
};
}
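
Putting the close path together for a stream in OVERWRITE mode, and assuming the generic JDBC truncate/copy queries (TRUNCATE TABLE and INSERT INTO ... SELECT), the finalization transaction is roughly this sketch with hypothetical names:

// queryList for one OVERWRITE stream:
//   TRUNCATE TABLE PUBLIC._airbyte_raw_users;
//   INSERT INTO PUBLIC._airbyte_raw_users SELECT * FROM PUBLIC._airbyte_tmp_abc_users;
// ...followed by DROP TABLE IF EXISTS and DROP STAGE IF EXISTS during cleanup.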

}

SnowflakeInternalStagingDestination.java (new file):
@@ -0,0 +1,42 @@
/*
* Copyright (c) 2021 Airbyte, Inc., all rights reserved.
*/

package io.airbyte.integrations.destination.snowflake;

import com.fasterxml.jackson.databind.JsonNode;
import io.airbyte.commons.json.Jsons;
import io.airbyte.db.jdbc.JdbcDatabase;
import io.airbyte.integrations.base.AirbyteMessageConsumer;
import io.airbyte.integrations.base.Destination;
import io.airbyte.integrations.destination.jdbc.AbstractJdbcDestination;
import io.airbyte.protocol.models.AirbyteMessage;
import io.airbyte.protocol.models.ConfiguredAirbyteCatalog;
import java.util.function.Consumer;

public class SnowflakeInternalStagingDestination extends AbstractJdbcDestination implements Destination {

public SnowflakeInternalStagingDestination() {
super("", new SnowflakeSQLNameTransformer(), new SnowflakeStagingSqlOperations());
}

@Override
protected JdbcDatabase getDatabase(final JsonNode config) {
return SnowflakeDatabase.getDatabase(config);
}

// This is a no-op since we override getDatabase.
@Override
public JsonNode toJdbcConfig(final JsonNode config) {
return Jsons.emptyObject();
}

@Override
public AirbyteMessageConsumer getConsumer(final JsonNode config,
final ConfiguredAirbyteCatalog catalog,
final Consumer<AirbyteMessage> outputRecordCollector) {
return SnowflakeInternalStagingConsumerFactory.create(outputRecordCollector, getDatabase(config),
new SnowflakeStagingSqlOperations(), new SnowflakeSQLNameTransformer(), config, catalog);
}

}

SnowflakeSQLNameTransformer.java:
@@ -13,4 +13,8 @@ protected String applyDefaultCase(final String input) {
return input.toUpperCase();
}

public String getStageName(final String schemaName, final String outputTableName) {
// Schema and table are concatenated without a separator, so distinct
// (schema, table) pairs can, in principle, collide on the same stage name.
return schemaName.concat(outputTableName).replaceAll("-", "_").toUpperCase();
}

}
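
The name shapes this produces, sketched with hypothetical inputs:

final SnowflakeSQLNameTransformer transformer = new SnowflakeSQLNameTransformer();
// Dashes become underscores and the result is upper-cased:
transformer.getStageName("public", "_airbyte_raw_my-stream"); // -> "PUBLIC_AIRBYTE_RAW_MY_STREAM"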

SnowflakeStagingSqlOperations.java (new file):
@@ -0,0 +1,81 @@
/*
* Copyright (c) 2021 Airbyte, Inc., all rights reserved.
*/

package io.airbyte.integrations.destination.snowflake;

import com.google.common.collect.Iterables;
import io.airbyte.db.jdbc.JdbcDatabase;
import io.airbyte.integrations.base.JavaBaseConstants;
import io.airbyte.integrations.destination.jdbc.JdbcSqlOperations;
import io.airbyte.integrations.destination.jdbc.SqlOperations;
import io.airbyte.protocol.models.AirbyteRecordMessage;
import java.io.File;
import java.nio.file.Files;
import java.sql.SQLException;
import java.util.List;
import java.util.UUID;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class SnowflakeStagingSqlOperations extends JdbcSqlOperations implements SqlOperations {

private static final Logger LOGGER = LoggerFactory.getLogger(SnowflakeStagingSqlOperations.class);

@Override
protected void insertRecordsInternal(JdbcDatabase database, List<AirbyteRecordMessage> records, String schemaName, String stage) throws Exception {
LOGGER.info("actual size of batch for staging: {}", records.size());

if (records.isEmpty()) {
return;
}
try {
loadDataIntoStage(database, stage, records);
} catch (Exception e) {
LOGGER.error("Failed to upload records into stage {}", stage, e);
throw new RuntimeException(e);
}
tuliren marked this conversation as resolved.
Show resolved Hide resolved
}

private void loadDataIntoStage(final JdbcDatabase database, final String stage, final List<AirbyteRecordMessage> partition) throws Exception {
final File tempFile = Files.createTempFile(UUID.randomUUID().toString(), ".csv").toFile();
try {
writeBatchToFile(tempFile, partition);
database.execute(String.format("PUT file://%s @%s PARALLEL = %d", tempFile.getAbsolutePath(), stage, Runtime.getRuntime().availableProcessors()));
} finally {
// Clean up the local CSV even if the PUT fails.
Files.delete(tempFile.toPath());
}
}

public void createStageIfNotExists(final JdbcDatabase database, final String stageName) throws SQLException {
database.execute(String.format("CREATE STAGE IF NOT EXISTS %s encryption = (type = 'SNOWFLAKE_SSE')" +
" copy_options = (on_error='skip_file');", stageName));
}

public void copyIntoTmpTableFromStage(final JdbcDatabase database, final String stageName, final String dstTableName, final String schemaName)
throws SQLException {
database.execute(String.format("COPY INTO %s.%s FROM @%s file_format = " +
"(type = csv field_delimiter = ',' skip_header = 0 FIELD_OPTIONALLY_ENCLOSED_BY = '\"')",
schemaName,
dstTableName,
stageName));
}

public void dropStageIfExists(final JdbcDatabase database, final String stageName) throws SQLException {
database.execute(String.format("DROP STAGE IF EXISTS %s;", stageName));
}

@Override
public void createTableIfNotExists(final JdbcDatabase database, final String schemaName, final String tableName) throws SQLException {
database.execute(createTableQuery(database, schemaName, tableName));
}

@Override
public String createTableQuery(final JdbcDatabase database, final String schemaName, final String tableName) {
return String.format(
"CREATE TABLE IF NOT EXISTS %s.%s ( \n"
+ "%s VARCHAR PRIMARY KEY,\n"
+ "%s VARIANT,\n"
+ "%s TIMESTAMP WITH TIME ZONE DEFAULT current_timestamp()\n"
+ ") data_retention_time_in_days = 0;",
schemaName, tableName, JavaBaseConstants.COLUMN_NAME_AB_ID, JavaBaseConstants.COLUMN_NAME_DATA, JavaBaseConstants.COLUMN_NAME_EMITTED_AT);
}
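
Rendered for a hypothetical PUBLIC._airbyte_raw_users, and assuming the standard Airbyte column constants (_airbyte_ab_id, _airbyte_data, _airbyte_emitted_at), the DDL above comes out as roughly:

// CREATE TABLE IF NOT EXISTS PUBLIC._airbyte_raw_users (
//   _airbyte_ab_id VARCHAR PRIMARY KEY,
//   _airbyte_data VARIANT,
//   _airbyte_emitted_at TIMESTAMP WITH TIME ZONE DEFAULT current_timestamp()
// ) data_retention_time_in_days = 0;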

}
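
End to end, each flushed batch therefore becomes one PUT at write time and one COPY INTO at close time. A sketch of the rendered SQL, with hypothetical names and an 8-core machine:

// PUT file:///tmp/<uuid>.csv @PUBLIC_AIRBYTE_RAW_USERS PARALLEL = 8
// COPY INTO PUBLIC._airbyte_tmp_abc_users FROM @PUBLIC_AIRBYTE_RAW_USERS file_format =
//   (type = csv field_delimiter = ',' skip_header = 0 FIELD_OPTIONALLY_ENCLOSED_BY = '"')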

spec.json:
@@ -200,6 +200,19 @@
"order": 3
}
}
},
{
"title": "Internal Staging",
"additionalProperties": false,
"description": "Writes large batches of records to a file, uploads the file to Snowflake, then uses <pre>COPY INTO table</pre> to upload the file. Recommended for large production workloads for better speed and scalability.",
"required": ["method"],
"properties": {
"method": {
"type": "string",
"enum": ["Internal Staging"],
"default": "Internal Staging"
}
}
}
]
}