Commit

add refreshes to databricks
edgao committed Jul 12, 2024
1 parent c019d19 commit a459fb0
Showing 7 changed files with 104 additions and 51 deletions.
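
This commit wires the Databricks destination into the CDK's suffix-aware staging API used for refreshes: the staging operations now take a raw-table-name suffix, so a sync can stage records into a temporary raw table and later either replace the live raw table with it or merge into it. The sketch below is a condensed, hypothetical view of that surface, reconstructed from the call sites in the diff; the real CDK `StorageOperation` interface is not shown here and may differ in detail.

```kotlin
import io.airbyte.cdk.integrations.destination.record_buffer.SerializableBuffer
import io.airbyte.integrations.base.destination.typing_deduping.StreamConfig
import io.airbyte.integrations.base.destination.typing_deduping.StreamId

// Hypothetical condensed view of the suffix-aware staging operations implemented below.
// An empty suffix targets the live raw table; a non-empty suffix targets a temp raw table.
interface SuffixAwareStagingOps {
    fun prepareStage(streamId: StreamId, suffix: String, replace: Boolean = false)
    fun writeToStage(streamConfig: StreamConfig, suffix: String, data: SerializableBuffer)
    // Replace the live raw table with the temp table's contents (truncate-style refresh).
    fun overwriteStage(streamId: StreamId, suffix: String)
    // Append the temp table's rows into the live raw table, then drop the temp table.
    fun transferFromTempStage(streamId: StreamId, suffix: String)
    // Generation id of any record already staged under this suffix, or null if none.
    fun getStageGeneration(streamId: StreamId, suffix: String): Long?
}
```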
build.gradle
@@ -20,7 +20,7 @@ plugins {
airbyteJavaConnector {
cdkVersionRequired = '0.38.3'
features = ['db-destinations', 's3-destinations', 'typing-deduping']
useLocalCdk = false
useLocalCdk = true
}

//remove once upgrading the CDK version to 0.4.x or later
DatabricksDestination.kt
@@ -97,8 +97,9 @@ class DatabricksDestination : BaseConnector(), Destination {
syncId = 0
)

val noSuffix = ""
try {
storageOperations.prepareStage(streamId, DestinationSyncMode.OVERWRITE)
storageOperations.prepareStage(streamId, suffix = noSuffix)
} catch (e: Exception) {
log.error(e) { "Failed to prepare stage as part of CHECK" }
return AirbyteConnectionStatus()
@@ -116,7 +117,7 @@
System.currentTimeMillis()
)
it.flush()
storageOperations.writeToStage(streamConfig, writeBuffer)
storageOperations.writeToStage(streamConfig, suffix = noSuffix, writeBuffer)
}
} catch (e: Exception) {
log.error(e) { "Failed to write to stage as part of CHECK" }
DatabricksDestinationHandler.kt
@@ -4,6 +4,7 @@

package io.airbyte.integrations.destination.databricks.jdbc

import com.fasterxml.jackson.databind.JsonNode
import io.airbyte.cdk.db.jdbc.JdbcDatabase
import io.airbyte.cdk.integrations.base.JavaBaseConstants
import io.airbyte.cdk.integrations.base.JavaBaseConstants.COLUMN_NAME_AB_EXTRACTED_AT
@@ -12,6 +13,7 @@ import io.airbyte.cdk.integrations.base.JavaBaseConstants.COLUMN_NAME_AB_META
import io.airbyte.cdk.integrations.base.JavaBaseConstants.COLUMN_NAME_AB_RAW_ID
import io.airbyte.cdk.integrations.destination.jdbc.ColumnDefinition
import io.airbyte.cdk.integrations.destination.jdbc.TableDefinition
import io.airbyte.integrations.base.destination.operation.AbstractStreamOperation
import io.airbyte.integrations.base.destination.typing_deduping.AirbyteProtocolType
import io.airbyte.integrations.base.destination.typing_deduping.AirbyteProtocolType.STRING
import io.airbyte.integrations.base.destination.typing_deduping.AirbyteProtocolType.TIMESTAMP_WITH_TIMEZONE
@@ -22,15 +24,16 @@ import io.airbyte.integrations.base.destination.typing_deduping.Sql
import io.airbyte.integrations.base.destination.typing_deduping.StreamConfig
import io.airbyte.integrations.base.destination.typing_deduping.StreamId
import io.airbyte.integrations.base.destination.typing_deduping.migrators.MinimumDestinationState
import io.airbyte.protocol.models.v0.DestinationSyncMode
import io.github.oshai.kotlinlogging.KotlinLogging
import java.sql.Connection
import java.sql.ResultSet
import java.sql.SQLException
import java.time.Instant
import java.time.LocalDateTime
import java.time.ZoneOffset
import java.util.*
import java.util.Objects
import java.util.Optional
import java.util.UUID
import kotlin.streams.asSequence

class DatabricksDestinationHandler(
@@ -79,35 +82,40 @@ class DatabricksDestinationHandler(
.map {
val namespace = it.id.finalNamespace
val name = it.id.finalName
val initialRawTableStatus =
if (it.destinationSyncMode == DestinationSyncMode.OVERWRITE)
InitialRawTableStatus(
rawTableExists = false,
hasUnprocessedRecords = false,
maxProcessedTimestamp = Optional.empty(),
)
else getInitialRawTableState(it.id)
// finalTablePresent
val initialRawTableStatus = getInitialRawTableState(it.id, suffix = "")
val initialTempRawTableStatus =
getInitialRawTableState(
it.id,
suffix = AbstractStreamOperation.TMP_TABLE_SUFFIX,
)
if (
existingTables.contains(namespace) &&
existingTables[namespace]?.contains(name) == true
) {
// The final table exists. Do some extra querying to find out what it looks
// like.
val isFinalTableSchemaMismatch =
!isSchemaMatch(it, existingTables[namespace]?.get(name)!!)
val isFinalTableEmpty = isFinalTableEmpty(it.id)
DestinationInitialStatus(
it,
true,
initialRawTableStatus,
!isSchemaMatch(it, existingTables[namespace]?.get(name)!!),
isFinalTableEmpty(it.id),
MinimumDestinationState.Impl(false),
isFinalTablePresent = true,
initialRawTableStatus = initialRawTableStatus,
initialTempRawTableStatus = initialTempRawTableStatus,
isFinalTableSchemaMismatch,
isFinalTableEmpty,
MinimumDestinationState.Impl(needsSoftReset = false),
)
} else {
// The final table doesn't exist, so no further querying to do.
DestinationInitialStatus(
it,
false,
initialRawTableStatus,
isFinalTablePresent = false,
initialRawTableStatus = initialRawTableStatus,
initialTempRawTableStatus = initialTempRawTableStatus,
isSchemaMismatch = false,
isFinalTableEmpty = true,
destinationState = MinimumDestinationState.Impl(false),
destinationState = MinimumDestinationState.Impl(needsSoftReset = false),
)
}
}
@@ -130,7 +138,7 @@
"""
|SELECT table_schema, table_name, column_name, data_type, is_nullable
|FROM ${databaseName.lowercase()}.information_schema.columns
|WHERE
|WHERE
| table_catalog = ?
| AND table_schema IN ($paramHolder)
| AND table_name IN ($paramHolder)
@@ -235,7 +243,7 @@
}
}

private fun getInitialRawTableState(id: StreamId): InitialRawTableStatus {
private fun getInitialRawTableState(id: StreamId, suffix: String): InitialRawTableStatus {
jdbcDatabase
.executeMetadataQuery { metadata ->
// Handle resultset call in the function which will be closed
@@ -244,7 +252,7 @@
metadata.getTables(
databaseName,
id.rawNamespace,
id.rawName,
id.rawName + suffix,
null,
)
resultSet?.next() ?: false
@@ -262,13 +270,13 @@
val minExtractedAtLoadedNotNullQuery =
"""
|SELECT min(`$COLUMN_NAME_AB_EXTRACTED_AT`) as last_loaded_at
|FROM $databaseName.${id.rawTableId(DatabricksSqlGenerator.QUOTE)}
|FROM $databaseName.${id.rawTableId(DatabricksSqlGenerator.QUOTE, suffix)}
|WHERE ${JavaBaseConstants.COLUMN_NAME_AB_LOADED_AT} IS NULL
|""".trimMargin()
val maxExtractedAtQuery =
"""
|SELECT max(`$COLUMN_NAME_AB_EXTRACTED_AT`) as last_loaded_at
|FROM $databaseName.${id.rawTableId(DatabricksSqlGenerator.QUOTE)}
|FROM $databaseName.${id.rawTableId(DatabricksSqlGenerator.QUOTE, suffix)}
""".trimMargin()

findLastLoadedTs(minExtractedAtLoadedNotNullQuery)
@@ -295,4 +303,6 @@
) {
// do Nothing
}

fun query(query: String): List<JsonNode> = jdbcDatabase.queryJsons(query)
}
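
The new `query` helper above simply delegates to `JdbcDatabase.queryJsons`, and is what `getStageGeneration` (further down in this diff) uses to read the staged generation id. A minimal usage sketch follows — the catalog, schema, and table names are made up for illustration:

```kotlin
import com.fasterxml.jackson.databind.JsonNode

// Illustrative only: the identifiers in the SQL are hypothetical, not taken from this commit.
fun countRawRecords(handler: DatabricksDestinationHandler): Long {
    val rows: List<JsonNode> =
        handler.query("SELECT count(1) AS record_count FROM `main`.`airbyte_internal`.`users_raw`")
    return rows.first()["record_count"].asLong()
}
```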
DatabricksSqlGenerator.kt
@@ -102,10 +102,10 @@ class DatabricksSqlGenerator(
}

// Start: Functions scattered over other classes needed for T+D
fun createRawTable(streamId: StreamId): Sql {
fun createRawTable(streamId: StreamId, suffix: String): Sql {
return Sql.of(
"""
CREATE TABLE IF NOT EXISTS $unityCatalogName.${streamId.rawNamespace}.${streamId.rawName} (
CREATE TABLE IF NOT EXISTS $unityCatalogName.${streamId.rawNamespace}.${streamId.rawName}$suffix (
$AB_RAW_ID STRING,
$AB_EXTRACTED_AT TIMESTAMP,
$AB_LOADED_AT TIMESTAMP,
DatabricksStorageOperation.kt
@@ -4,27 +4,27 @@

package io.airbyte.integrations.destination.databricks.operation

import io.airbyte.integrations.base.destination.typing_deduping.TyperDeduperUtil as tdutils
import com.databricks.sdk.WorkspaceClient
import io.airbyte.cdk.integrations.base.JavaBaseConstants.COLUMN_NAME_AB_GENERATION_ID
import io.airbyte.cdk.integrations.destination.record_buffer.SerializableBuffer
import io.airbyte.integrations.base.destination.operation.StorageOperation
import io.airbyte.integrations.base.destination.typing_deduping.DestinationHandler
import io.airbyte.integrations.base.destination.typing_deduping.Sql
import io.airbyte.integrations.base.destination.typing_deduping.SqlGenerator
import io.airbyte.integrations.base.destination.typing_deduping.StreamConfig
import io.airbyte.integrations.base.destination.typing_deduping.StreamId
import io.airbyte.integrations.base.destination.typing_deduping.TyperDeduperUtil as tdutils
import io.airbyte.integrations.base.destination.typing_deduping.migrators.MinimumDestinationState
import io.airbyte.integrations.destination.databricks.jdbc.DatabricksDestinationHandler
import io.airbyte.integrations.destination.databricks.jdbc.DatabricksSqlGenerator
import io.airbyte.protocol.models.v0.DestinationSyncMode
import io.github.oshai.kotlinlogging.KotlinLogging
import java.time.Instant
import java.time.ZoneOffset
import java.time.ZonedDateTime
import java.util.*
import java.util.Optional
import java.util.UUID

class DatabricksStorageOperation(
private val sqlGenerator: SqlGenerator,
private val destinationHandler: DestinationHandler<MinimumDestinationState.Impl>,
private val destinationHandler: DatabricksDestinationHandler,
private val workspaceClient: WorkspaceClient,
private val database: String,
private val purgeStagedFiles: Boolean = false
@@ -36,7 +36,11 @@ class DatabricksStorageOperation(
// Hoist them to SqlGenerator interface in CDK, until then using concrete instance.
private val databricksSqlGenerator = sqlGenerator as DatabricksSqlGenerator

override fun writeToStage(streamConfig: StreamConfig, data: SerializableBuffer) {
override fun writeToStage(
streamConfig: StreamConfig,
suffix: String,
data: SerializableBuffer
) {
val streamId = streamConfig.id
val stagedFile = "${stagingDirectory(streamId, database)}/${data.filename}"
workspaceClient.files().upload(stagedFile, data.inputStream)
@@ -46,7 +50,7 @@ class DatabricksStorageOperation(
// which can't be loaded into a bigint (int64) column.
// So we have to explicitly cast it to a bigint.
"""
COPY INTO `$database`.`${streamId.rawNamespace}`.`${streamId.rawName}`
COPY INTO `$database`.`${streamId.rawNamespace}`.`${streamId.rawName}$suffix`
FROM (
SELECT _airbyte_generation_id :: bigint, * except (_airbyte_generation_id)
FROM '$stagedFile'
@@ -93,16 +97,12 @@
)
}

private fun prepareStagingTable(streamId: StreamId, destinationSyncMode: DestinationSyncMode) {
val rawSchema = streamId.rawNamespace
// TODO: Optimize by running SHOW SCHEMAS; rather than CREATE SCHEMA if not exists
destinationHandler.execute(sqlGenerator.createSchema(rawSchema))

private fun prepareStagingTable(streamId: StreamId, suffix: String, replace: Boolean) {
// TODO: Optimize by running SHOW TABLES; truncate or create based on mode
// Create raw tables.
destinationHandler.execute(databricksSqlGenerator.createRawTable(streamId))
destinationHandler.execute(databricksSqlGenerator.createRawTable(streamId, suffix))
// Truncate the raw table if sync in OVERWRITE.
if (destinationSyncMode == DestinationSyncMode.OVERWRITE) {
if (replace) {
destinationHandler.execute(databricksSqlGenerator.truncateRawTable(streamId))
}
}
@@ -116,11 +116,49 @@
workspaceClient.files().createDirectory(stagingDirectory(streamId, database))
}

override fun prepareStage(streamId: StreamId, destinationSyncMode: DestinationSyncMode) {
prepareStagingTable(streamId, destinationSyncMode)
override fun prepareStage(streamId: StreamId, suffix: String, replace: Boolean) {
prepareStagingTable(streamId, suffix, replace)
prepareStagingVolume(streamId)
}

override fun overwriteStage(streamId: StreamId, suffix: String) {
// databricks recommends CREATE OR REPLACE ... AS SELECT
// instead of dropping the table and then doing more operations
// https://docs.databricks.com/en/delta/drop-table.html#when-to-replace-a-table
destinationHandler.execute(
Sql.of(
"""
CREATE OR REPLACE TABLE `$database`.`${streamId.rawNamespace}`.`${streamId.rawName}`
AS SELECT * FROM `$database`.`${streamId.rawNamespace}`.`${streamId.rawName}$suffix`
""".trimIndent(),
// TODO drop table
)
)
}

override fun transferFromTempStage(streamId: StreamId, suffix: String) {
destinationHandler.execute(
// Databricks doesn't support transactions, so we have to do these separately
Sql.separately(
"""
INSERT INTO `$database`.`${streamId.rawNamespace}`.`${streamId.rawName}`
SELECT * FROM `$database`.`${streamId.rawNamespace}`.`${streamId.rawName}$suffix`
""".trimIndent(),
"DROP TABLE `$database`.`${streamId.rawNamespace}`.`${streamId.rawName}$suffix`",
)
)
}

override fun getStageGeneration(streamId: StreamId, suffix: String): Long? {
val generationIds =
destinationHandler.query("SELECT $COLUMN_NAME_AB_GENERATION_ID FROM `$database`.`${streamId.rawNamespace}`.`${streamId.rawName}$suffix` LIMIT 1")
return if (generationIds.isEmpty()) {
null
} else {
generationIds.first()[COLUMN_NAME_AB_GENERATION_ID].asLong()
}
}
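
Taken together, `getStageGeneration`, `prepareStage`, `writeToStage`, and `overwriteStage` support a truncate-style refresh flow roughly like the sketch below. This is a hypothetical driver, not code from this commit — the real orchestration lives in the CDK's stream operation — and it assumes `AbstractStreamOperation.TMP_TABLE_SUFFIX` is the temp-table suffix constant already referenced in `DatabricksDestinationHandler` above.

```kotlin
import io.airbyte.cdk.integrations.destination.record_buffer.SerializableBuffer
import io.airbyte.integrations.base.destination.operation.AbstractStreamOperation
import io.airbyte.integrations.base.destination.typing_deduping.StreamConfig

// Hypothetical orchestration sketch for a truncate-style refresh.
fun runTruncateRefresh(
    storageOperation: DatabricksStorageOperation,
    streamConfig: StreamConfig,
    buffers: List<SerializableBuffer>,
    tempRawTableExists: Boolean, // e.g. from initialTempRawTableStatus in gatherInitialState
    currentGenerationId: Long,
) {
    val streamId = streamConfig.id
    val suffix = AbstractStreamOperation.TMP_TABLE_SUFFIX

    // Keep a leftover temp stage only if it already exists and belongs to this generation;
    // otherwise recreate it from scratch.
    val stagedGeneration: Long? =
        if (tempRawTableExists) storageOperation.getStageGeneration(streamId, suffix) else null
    val stale = stagedGeneration != null && stagedGeneration != currentGenerationId
    storageOperation.prepareStage(streamId, suffix, replace = stale)

    // Stage all serialized buffers into the suffixed raw table.
    buffers.forEach { storageOperation.writeToStage(streamConfig, suffix, it) }

    // Promote the temp raw table over the live one via the CREATE OR REPLACE in overwriteStage.
    storageOperation.overwriteStage(streamId, suffix)
}
```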

override fun cleanupStage(streamId: StreamId) {
if (purgeStagedFiles) {
// This operation might fail if there are files left over for any reason from COPY step
DatabricksStreamOperation.kt
@@ -35,7 +35,11 @@ class DatabricksStreamOperation(
disableTypeDedupe = disableTypeDedupe
) {
private val log = KotlinLogging.logger {}
override fun writeRecords(streamConfig: StreamConfig, stream: Stream<PartialAirbyteMessage>) {
override fun writeRecordsImpl(
streamConfig: StreamConfig,
suffix: String,
stream: Stream<PartialAirbyteMessage>
) {
val writeBuffer = DatabricksFileBufferFactory.createBuffer(fileUploadFormat)
writeBuffer.use {
stream.forEach { record: PartialAirbyteMessage ->
@@ -54,7 +58,7 @@
)
}) to staging"
}
storageOperation.writeToStage(streamConfig, writeBuffer)
storageOperation.writeToStage(streamConfig, suffix, writeBuffer)
}
}
}
DatabricksSqlGeneratorIntegrationTest.kt
@@ -79,11 +79,11 @@ class DatabricksSqlGeneratorIntegrationTest :
}

override fun createRawTable(streamId: StreamId) {
destinationHandler.execute(databricksSqlGenerator.createRawTable(streamId))
destinationHandler.execute(databricksSqlGenerator.createRawTable(streamId, suffix = ""))
}

override fun createV1RawTable(v1RawTable: StreamId) {
TODO("Not yet implemented")
throw NotImplementedError("Databricks does not support a V1->V2 migration")
}

override fun insertRawTableRecords(streamId: StreamId, records: List<JsonNode>) {
@@ -130,7 +130,7 @@
}

override fun insertV1RawTableRecords(streamId: StreamId, records: List<JsonNode>) {
TODO("Not yet implemented")
throw NotImplementedError("Databricks does not support a V1->V2 migration")
}

override fun insertFinalTableRecords(
