convert destination-snowflake to Kotlin CDK
stephane-airbyte committed Apr 8, 2024
1 parent 9cd72c3 · commit 625e34b
Showing 19 changed files with 196 additions and 185 deletions.
@@ -145,6 +145,7 @@ class AirbyteExceptionHandler : Thread.UncaughtExceptionHandler {
         }
     }

+    @JvmStatic
     fun addAllStringsInConfigForDeinterpolation(node: JsonNode) {
         if (node.isTextual) {
             addStringForDeinterpolation(node.asText())
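A pattern that recurs throughout this diff: functions on a Kotlin object or companion object are not static methods from Java's point of view, so converted CDK classes add @JvmStatic wherever Java connector code still calls them statically. A minimal Kotlin sketch of the interop difference, using hypothetical names rather than CDK ones:

    object InteropExample {
        // Without @JvmStatic, Java callers must route through the singleton:
        //   InteropExample.INSTANCE.instanceStyle("x")
        fun instanceStyle(s: String): String = s.lowercase()

        // With @JvmStatic the compiler also emits a true static method, so
        // pre-conversion Java call sites compile unchanged:
        //   InteropExample.staticStyle("x")
        @JvmStatic
        fun staticStyle(s: String): String = s.uppercase()
    }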
@@ -20,6 +20,7 @@ object AdaptiveDestinationRunner {
     private const val DEPLOYMENT_MODE_KEY = EnvVariableFeatureFlags.DEPLOYMENT_MODE
     private const val CLOUD_MODE = "CLOUD"

+    @JvmStatic
     fun baseOnEnv(): OssDestinationBuilder {
         val mode = System.getenv(DEPLOYMENT_MODE_KEY)
         return OssDestinationBuilder(mode)
@@ -45,7 +45,9 @@ abstract class JdbcSqlOperations : SqlOperations {
      * @param e the exception to check.
      * @return A ConfigErrorException with a message with actionable feedback to the user.
      */
-    protected fun checkForKnownConfigExceptions(e: Exception?): Optional<ConfigErrorException> {
+    protected open fun checkForKnownConfigExceptions(
+        e: Exception?
+    ): Optional<ConfigErrorException> {
         return Optional.empty()
     }
@@ -206,7 +208,7 @@ abstract class JdbcSqlOperations : SqlOperations {
         }
     }

-    fun dropTableIfExistsQuery(schemaName: String?, tableName: String?): String {
+    open fun dropTableIfExistsQuery(schemaName: String?, tableName: String?): String {
         return String.format("DROP TABLE IF EXISTS %s.%s;\n", schemaName, tableName)
     }
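The open modifiers added above follow from Kotlin's defaults: classes and members are final unless declared open, whereas Java methods are overridable by default. Any base-class method that destination-snowflake (or another connector) overrides therefore has to be opened up. A sketch of the rule, with hypothetical names:

    abstract class BaseOperationsExample {
        // final by default: a subclass cannot override this
        fun fixedQuery(): String = "SELECT 1;"

        // `open` restores Java's default overridability
        open fun dropQuery(schema: String, table: String): String =
            "DROP TABLE IF EXISTS $schema.$table;"
    }

    class SnowflakeOperationsExample : BaseOperationsExample() {
        // legal only because the base member is marked `open`
        override fun dropQuery(schema: String, table: String): String =
            "DROP TABLE IF EXISTS \"$schema\".\"$table\";"
    }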
@@ -348,7 +348,7 @@ abstract class JdbcDestinationHandler<DestinationState>(
         existingTable.columns[JavaBaseConstants.COLUMN_NAME_AB_META]!!.type
     }

-    protected fun existingSchemaMatchesStreamConfig(
+    open protected fun existingSchemaMatchesStreamConfig(
         stream: StreamConfig?,
         existingTable: TableDefinition
     ): Boolean {
@@ -543,6 +543,7 @@ abstract class JdbcDestinationHandler<DestinationState>(
         return Optional.of(TableDefinition(retrievedColumnDefns))
     }

+    @JvmStatic
     fun fromIsNullableIsoString(isNullable: String?): Boolean {
         return "YES".equals(isNullable, ignoreCase = true)
     }
@@ -1464,7 +1464,7 @@ abstract class DestinationAcceptanceTest {
     }

     /** Whether the destination should be tested against different namespaces. */
-    protected fun supportNamespaceTest(): Boolean {
+    open protected fun supportNamespaceTest(): Boolean {
         return false
     }
@@ -172,7 +172,7 @@ abstract class BaseDestinationV1V2Migrator<DialectTableDefinition> : Destination
      * @return whether it exists and is in the correct format
      */
     @Throws(Exception::class)
-    protected fun doesValidV1RawTableExist(namespace: String?, tableName: String?): Boolean {
+    protected open fun doesValidV1RawTableExist(namespace: String?, tableName: String?): Boolean {
         val existingV1RawTable = getTableIfExists(namespace, tableName)
         return existingV1RawTable.isPresent &&
             doesV1RawTableMatchExpectedSchema(existingV1RawTable.get())
@@ -30,6 +30,7 @@ object CollectionUtils {
      * @param searchTerms the keys you're looking for
      * @return whether all searchTerms are in the searchCollection
      */
+    @JvmStatic
     fun containsAllIgnoreCase(
         searchCollection: Collection<String>,
         searchTerms: Collection<String>
@@ -59,7 +59,7 @@ abstract class BaseSqlGeneratorIntegrationTest<DestinationState : MinimumDestina
     protected var cdcIncrementalAppendStream: StreamConfig = mock()

     protected var generator: SqlGenerator = mock()
-    protected abstract var destinationHandler: DestinationHandler<DestinationState>
+    protected abstract val destinationHandler: DestinationHandler<DestinationState>
     protected var namespace: String = mock()

     protected var streamId: StreamId = mock()
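The var-to-val change is also interop-motivated: an abstract var obliges every subclass to provide both a getter and a setter, while the test only ever reads the handler. With val, a Java subclass satisfies the contract by overriding just the generated getter. A minimal sketch (hypothetical names and types):

    abstract class HolderExample {
        // read-only contract: subclasses supply a getter, no setter required
        abstract val destinationHandler: String
    }

    class SnowflakeHolderExample : HolderExample() {
        override val destinationHandler: String = "snowflake-handler"
    }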
@@ -74,7 +74,7 @@ abstract class BaseSqlGeneratorIntegrationTest<DestinationState : MinimumDestina
      * Subclasses should override this method if they need to make changes to the stream ID. For
      * example, you could upcase the final table name here.
      */
-    protected fun buildStreamId(
+    open protected fun buildStreamId(
         namespace: String,
         finalTableName: String,
         rawTableName: String
@@ -143,7 +143,7 @@ abstract class BaseSqlGeneratorIntegrationTest<DestinationState : MinimumDestina
         /** Identical to [BaseTypingDedupingTest.getRawMetadataColumnNames]. */
         get() = HashMap()

-    protected val finalMetadataColumnNames: Map<String, String>
+    open protected val finalMetadataColumnNames: Map<String, String>
         /** Identical to [BaseTypingDedupingTest.getFinalMetadataColumnNames]. */
         get() = HashMap()
@@ -728,7 +728,7 @@ abstract class BaseSqlGeneratorIntegrationTest<DestinationState : MinimumDestina
      */
     @Test
     @Throws(Exception::class)
-    fun ignoreOldRawRecords() {
+    open fun ignoreOldRawRecords() {
         createRawTable(streamId)
         createFinalTable(incrementalAppendStream, "")
         insertRawTableRecords(
@@ -1523,7 +1523,10 @@ abstract class BaseSqlGeneratorIntegrationTest<DestinationState : MinimumDestina
         executeSoftReset(generator!!, destinationHandler!!, incrementalAppendStream!!)
     }

-    protected fun migrationAssertions(v1RawRecords: List<JsonNode>, v2RawRecords: List<JsonNode>) {
+    protected open fun migrationAssertions(
+        v1RawRecords: List<JsonNode>,
+        v2RawRecords: List<JsonNode>
+    ) {
         val v2RecordMap =
             v2RawRecords
                 .stream()
@@ -1574,7 +1577,7 @@ abstract class BaseSqlGeneratorIntegrationTest<DestinationState : MinimumDestina
     }

     @Throws(Exception::class)
-    protected fun dumpV1RawTableRecords(streamId: StreamId): List<JsonNode> {
+    open protected fun dumpV1RawTableRecords(streamId: StreamId): List<JsonNode> {
         return dumpRawTableRecords(streamId)
     }
@@ -152,7 +152,7 @@ abstract class BaseTypingDedupingTest {
         /** Conceptually identical to [.getFinalMetadataColumnNames], but for the raw table. */
         get() = HashMap()

-    val finalMetadataColumnNames: Map<String, String>
+    open val finalMetadataColumnNames: Map<String, String>
         /**
          * If the destination connector uses a nonstandard schema for the final table, override this
          * method. For example, destination-snowflake upcases all column names in the final tables.
@@ -1075,7 +1075,7 @@ abstract class BaseTypingDedupingTest {

     companion object {
         private val LOGGER: Logger = LoggerFactory.getLogger(BaseTypingDedupingTest::class.java)
-        protected val SCHEMA: JsonNode
+        @JvmField protected val SCHEMA: JsonNode

         init {
             try {
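@JvmField does for properties what @JvmStatic does for functions: it exposes the backing field itself, so Java subclasses keep reading SCHEMA as a protected static field instead of calling a generated accessor. A sketch of the pattern (names and schema content are illustrative, not the CDK's):

    import com.fasterxml.jackson.databind.JsonNode
    import com.fasterxml.jackson.databind.ObjectMapper

    abstract class BaseTestExample {
        companion object {
            // Java sees a plain protected static field, not a getter method
            @JvmField
            protected val SCHEMA_EXAMPLE: JsonNode =
                ObjectMapper().readTree("""{"type": "object"}""")
        }
    }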
@@ -5,7 +5,7 @@ plugins {
 airbyteJavaConnector {
     cdkVersionRequired = '0.27.7'
     features = ['db-destinations', 's3-destinations', 'typing-deduping']
-    useLocalCdk = false
+    useLocalCdk = true
 }

 java {
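Flipping useLocalCdk to true is what lets the connector compile against the converted CDK before a new release is published: Gradle resolves the CDK from the source tree in the same repository instead of the pinned 0.27.7 artifact. For reference, the same block in Gradle Kotlin DSL syntax, assuming the airbyteJavaConnector extension exposes the same properties (the repository itself uses the Groovy DSL shown above):

    airbyteJavaConnector {
        cdkVersionRequired = "0.27.7"
        features = listOf("db-destinations", "s3-destinations", "typing-deduping")
        // build against the locally checked-out CDK, not the published artifact
        useLocalCdk = true
    }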
@@ -7,7 +7,7 @@
 import com.fasterxml.jackson.databind.JsonNode;
 import io.airbyte.cdk.db.jdbc.JdbcDatabase;
 import io.airbyte.cdk.integrations.base.JavaBaseConstants;
-import io.airbyte.cdk.integrations.destination.async.partial_messages.PartialAirbyteMessage;
+import io.airbyte.cdk.integrations.destination.async.model.PartialAirbyteMessage;
 import io.airbyte.cdk.integrations.destination.jdbc.JdbcSqlOperations;
 import io.airbyte.cdk.integrations.destination.jdbc.SqlOperations;
 import io.airbyte.cdk.integrations.destination.jdbc.SqlOperationsUtils;
@@ -34,10 +34,10 @@ class SnowflakeSqlOperations extends JdbcSqlOperations implements SqlOperations
   @Override
   public void createSchemaIfNotExists(final JdbcDatabase database, final String schemaName) throws Exception {
     try {
-      if (!schemaSet.contains(schemaName) && !isSchemaExists(database, schemaName)) {
+      if (!getSchemaSet().contains(schemaName) && !isSchemaExists(database, schemaName)) {
         // 1s1t is assuming a lowercase airbyte_internal schema name, so we need to quote it
         database.execute(String.format("CREATE SCHEMA IF NOT EXISTS \"%s\";", schemaName));
-        schemaSet.add(schemaName);
+        getSchemaSet().add(schemaName);
       }
     } catch (final Exception e) {
       throw checkForKnownConfigExceptions(e).orElseThrow(() -> e);
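The getSchemaSet() rewrites are the Java-facing half of the conversion: schemaSet used to be a field inherited from the Java JdbcSqlOperations, but a Kotlin val compiles to a private backing field plus an accessor, so Java subclasses can no longer touch the field directly. A sketch of why (hypothetical class, not the actual CDK source):

    abstract class JdbcOperationsExample {
        // Compiles to a private backing field plus a protected getSchemaSet()
        // accessor; Java code must call getSchemaSet(), while Kotlin code
        // still writes `schemaSet` as plain property access.
        protected val schemaSet: MutableSet<String> = HashSet()
    }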
@@ -70,8 +70,8 @@ public static LinkedHashMap<String, LinkedHashMap<String, TableDefinition>> findExistingTables(
     final LinkedHashMap<String, LinkedHashMap<String, TableDefinition>> existingTables = new LinkedHashMap<>();
     final String paramHolder = String.join(",", Collections.nCopies(streamIds.size(), "?"));
     // convert list stream to array
-    final String[] namespaces = streamIds.stream().map(StreamId::finalNamespace).toArray(String[]::new);
-    final String[] names = streamIds.stream().map(StreamId::finalName).toArray(String[]::new);
+    final String[] namespaces = streamIds.stream().map(StreamId::getFinalNamespace).toArray(String[]::new);
+    final String[] names = streamIds.stream().map(StreamId::getFinalName).toArray(String[]::new);
     final String query = """
                          SELECT table_schema, table_name, column_name, data_type, is_nullable
                          FROM information_schema.columns
@@ -103,8 +103,8 @@ private LinkedHashMap<String, LinkedHashMap<String, Integer>> getFinalTableRowCount(
     final LinkedHashMap<String, LinkedHashMap<String, Integer>> tableRowCounts = new LinkedHashMap<>();
     final String paramHolder = String.join(",", Collections.nCopies(streamIds.size(), "?"));
     // convert list stream to array
-    final String[] namespaces = streamIds.stream().map(StreamId::finalNamespace).toArray(String[]::new);
-    final String[] names = streamIds.stream().map(StreamId::finalName).toArray(String[]::new);
+    final String[] namespaces = streamIds.stream().map(StreamId::getFinalNamespace).toArray(String[]::new);
+    final String[] names = streamIds.stream().map(StreamId::getFinalName).toArray(String[]::new);
     final String query = """
                          SELECT table_schema, table_name, row_count
                          FROM information_schema.tables
@@ -133,8 +133,8 @@ private InitialRawTableStatus getInitialRawTableState(final StreamId id, final D
     }
     final ResultSet tables = database.getMetaData().getTables(
         databaseName,
-        id.rawNamespace(),
-        id.rawName(),
+        id.getRawNamespace(),
+        id.getRawName(),
         null);
     if (!tables.next()) {
       return new InitialRawTableStatus(false, false, Optional.empty());
@@ -227,25 +227,26 @@ public void execute(final Sql sql) throws Exception {
   }

   private Set<String> getPks(final StreamConfig stream) {
-    return stream.primaryKey() != null ? stream.primaryKey().stream().map(ColumnId::name).collect(Collectors.toSet()) : Collections.emptySet();
+    return stream.getPrimaryKey() != null ? stream.getPrimaryKey().stream().map(ColumnId::getName).collect(Collectors.toSet())
+        : Collections.emptySet();
   }

   private boolean isAirbyteRawIdColumnMatch(final TableDefinition existingTable) {
     final String abRawIdColumnName = COLUMN_NAME_AB_RAW_ID.toUpperCase();
     return existingTable.columns().containsKey(abRawIdColumnName) &&
-        toJdbcTypeName(AirbyteProtocolType.STRING).equals(existingTable.columns().get(abRawIdColumnName).type());
+        toJdbcTypeName(AirbyteProtocolType.STRING).equals(existingTable.columns().get(abRawIdColumnName).getType());
   }

   private boolean isAirbyteExtractedAtColumnMatch(final TableDefinition existingTable) {
     final String abExtractedAtColumnName = COLUMN_NAME_AB_EXTRACTED_AT.toUpperCase();
     return existingTable.columns().containsKey(abExtractedAtColumnName) &&
-        toJdbcTypeName(AirbyteProtocolType.TIMESTAMP_WITH_TIMEZONE).equals(existingTable.columns().get(abExtractedAtColumnName).type());
+        toJdbcTypeName(AirbyteProtocolType.TIMESTAMP_WITH_TIMEZONE).equals(existingTable.columns().get(abExtractedAtColumnName).getType());
   }

   private boolean isAirbyteMetaColumnMatch(TableDefinition existingTable) {
     final String abMetaColumnName = COLUMN_NAME_AB_META.toUpperCase();
     return existingTable.columns().containsKey(abMetaColumnName) &&
-        "VARIANT".equals(existingTable.columns().get(abMetaColumnName).type());
+        "VARIANT".equals(existingTable.columns().get(abMetaColumnName).getType());
   }

   protected boolean existingSchemaMatchesStreamConfig(final StreamConfig stream, final TableDefinition existingTable) {
@@ -259,17 +260,17 @@ protected boolean existingSchemaMatchesStreamConfig(final StreamConfig stream, final TableDefinition existingTable) {
       // Missing AB meta columns from final table, we need them to do proper T+D so trigger soft-reset
       return false;
     }
-    final LinkedHashMap<String, String> intendedColumns = stream.columns().entrySet().stream()
+    final LinkedHashMap<String, String> intendedColumns = stream.getColumns().entrySet().stream()
         .collect(LinkedHashMap::new,
-            (map, column) -> map.put(column.getKey().name(), toJdbcTypeName(column.getValue())),
+            (map, column) -> map.put(column.getKey().getName(), toJdbcTypeName(column.getValue())),
             LinkedHashMap::putAll);

     // Filter out Meta columns since they don't exist in stream config.
     final LinkedHashMap<String, String> actualColumns = existingTable.columns().entrySet().stream()
         .filter(column -> V2_FINAL_TABLE_METADATA_COLUMNS.stream().map(String::toUpperCase)
             .noneMatch(airbyteColumnName -> airbyteColumnName.equals(column.getKey())))
         .collect(LinkedHashMap::new,
-            (map, column) -> map.put(column.getKey(), column.getValue().type()),
+            (map, column) -> map.put(column.getKey(), column.getValue().getType()),
             LinkedHashMap::putAll);
     // soft-resetting https://github.com/airbytehq/airbyte/pull/31082
     @SuppressWarnings("deprecation")
@@ -285,13 +286,13 @@ protected boolean existingSchemaMatchesStreamConfig(final StreamConfig stream, final TableDefinition existingTable) {
   public List<DestinationInitialStatus<SnowflakeState>> gatherInitialState(List<StreamConfig> streamConfigs) throws Exception {
     final Map<AirbyteStreamNameNamespacePair, SnowflakeState> destinationStates = super.getAllDestinationStates();

-    List<StreamId> streamIds = streamConfigs.stream().map(StreamConfig::id).toList();
+    List<StreamId> streamIds = streamConfigs.stream().map(StreamConfig::getId).toList();
     final LinkedHashMap<String, LinkedHashMap<String, TableDefinition>> existingTables = findExistingTables(database, databaseName, streamIds);
     final LinkedHashMap<String, LinkedHashMap<String, Integer>> tableRowCounts = getFinalTableRowCount(streamIds);
     return streamConfigs.stream().map(streamConfig -> {
       try {
-        final String namespace = streamConfig.id().finalNamespace().toUpperCase();
-        final String name = streamConfig.id().finalName().toUpperCase();
+        final String namespace = streamConfig.getId().getFinalNamespace().toUpperCase();
+        final String name = streamConfig.getId().getFinalName().toUpperCase();
         boolean isSchemaMismatch = false;
         boolean isFinalTableEmpty = true;
         boolean isFinalTablePresent = existingTables.containsKey(namespace) && existingTables.get(namespace).containsKey(name);
@@ -301,8 +302,9 @@ public List<DestinationInitialStatus<SnowflakeState>> gatherInitialState(List<StreamConfig> streamConfigs) throws Exception {
           isSchemaMismatch = !existingSchemaMatchesStreamConfig(streamConfig, existingTable);
           isFinalTableEmpty = hasRowCount && tableRowCounts.get(namespace).get(name) == 0;
         }
-        final InitialRawTableStatus initialRawTableState = getInitialRawTableState(streamConfig.id(), streamConfig.destinationSyncMode());
-        final SnowflakeState destinationState = destinationStates.getOrDefault(streamConfig.id().asPair(), toDestinationState(Jsons.emptyObject()));
+        final InitialRawTableStatus initialRawTableState = getInitialRawTableState(streamConfig.getId(), streamConfig.getDestinationSyncMode());
+        final SnowflakeState destinationState =
+            destinationStates.getOrDefault(streamConfig.getId().asPair(), toDestinationState(Jsons.emptyObject()));
         return new DestinationInitialStatus<>(
             streamConfig,
             isFinalTablePresent,
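Most of the remaining churn in this file is one mechanical pattern: StreamId, StreamConfig, and ColumnId were Java records, whose accessors carry the component name (id.finalName(), StreamId::finalName); converted to Kotlin data classes, the same properties surface to Java as JavaBean getters (id.getFinalName(), StreamId::getFinalName), and method references change accordingly. A minimal sketch of the before and after, using a hypothetical class rather than the CDK's:

    // Kotlin data class standing in for what used to be a Java record
    data class StreamIdExample(val finalNamespace: String, val finalName: String)

    fun main() {
        val ids = listOf(StreamIdExample("PUBLIC", "USERS"))

        // Kotlin call sites are unaffected: plain property access
        val names = ids.map { it.finalName }

        // Java call sites see JavaBean getters, so the record-style reference
        //   ids.stream().map(StreamIdExample::finalName)
        // becomes
        //   ids.stream().map(StreamIdExample::getFinalName)
        println(names)
    }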