From cd9e1957d97c6572312a2ad8c08e6793662916d6 Mon Sep 17 00:00:00 2001 From: Nathan <148575555+nathan-artie@users.noreply.github.com> Date: Mon, 13 May 2024 16:01:16 -0700 Subject: [PATCH] [mssql] Move merge query building to `MSSQLDialect` (#633) --- clients/mssql/dialect/dialect.go | 70 ++++++++++++++++++++ clients/mssql/dialect/dialect_test.go | 57 +++++++++++++++- lib/destination/dml/merge.go | 75 ++++----------------- lib/destination/dml/merge_mssql_test.go | 87 ------------------------- 4 files changed, 136 insertions(+), 153 deletions(-) delete mode 100644 lib/destination/dml/merge_mssql_test.go diff --git a/clients/mssql/dialect/dialect.go b/clients/mssql/dialect/dialect.go index f548a4df..74590375 100644 --- a/clients/mssql/dialect/dialect.go +++ b/clients/mssql/dialect/dialect.go @@ -1,14 +1,17 @@ package dialect import ( + "errors" "fmt" "strconv" "strings" + "github.com/artie-labs/transfer/lib/array" "github.com/artie-labs/transfer/lib/config/constants" "github.com/artie-labs/transfer/lib/kafkalib" "github.com/artie-labs/transfer/lib/sql" "github.com/artie-labs/transfer/lib/typing" + "github.com/artie-labs/transfer/lib/typing/columns" "github.com/artie-labs/transfer/lib/typing/ext" ) @@ -173,3 +176,70 @@ func (MSSQLDialect) BuildProcessToastStructColExpression(colName string) string func (MSSQLDialect) BuildDedupeQueries(tableID, stagingTableID sql.TableIdentifier, primaryKeys []string, topicConfig kafkalib.TopicConfig) []string { panic("not implemented") // We don't currently support deduping for MS SQL. } + +func (md MSSQLDialect) BuildMergeQueries( + tableID sql.TableIdentifier, + subQuery string, + idempotentKey string, + primaryKeys []columns.Column, + _ []string, + cols []columns.Column, + softDelete bool, + _ *bool, +) ([]string, error) { + var idempotentClause string + if idempotentKey != "" { + idempotentClause = fmt.Sprintf("AND cc.%s >= c.%s ", idempotentKey, idempotentKey) + } + + var equalitySQLParts []string + for _, primaryKey := range primaryKeys { + // We'll need to escape the primary key as well. + quotedPrimaryKey := md.QuoteIdentifier(primaryKey.Name()) + equalitySQL := fmt.Sprintf("c.%s = cc.%s", quotedPrimaryKey, quotedPrimaryKey) + equalitySQLParts = append(equalitySQLParts, equalitySQL) + } + + if softDelete { + return []string{fmt.Sprintf(` +MERGE INTO %s c +USING %s AS cc ON %s +WHEN MATCHED %sTHEN UPDATE SET %s +WHEN NOT MATCHED AND COALESCE(cc.%s, 0) = 0 THEN INSERT (%s) VALUES (%s);`, + tableID.FullyQualifiedName(), subQuery, strings.Join(equalitySQLParts, " and "), + // Update + Soft Deletion + idempotentClause, columns.BuildColumnsUpdateFragment(cols, md), + // Insert + md.QuoteIdentifier(constants.DeleteColumnMarker), strings.Join(columns.QuoteColumns(cols, md), ","), + array.StringsJoinAddPrefix(array.StringsJoinAddPrefixArgs{ + Vals: columns.QuoteColumns(cols, md), + Separator: ",", + Prefix: "cc.", + }))}, nil + } + + // We also need to remove __artie flags since it does not exist in the destination table + cols, removed := columns.RemoveDeleteColumnMarker(cols) + if !removed { + return nil, errors.New("artie delete flag doesn't exist") + } + + return []string{fmt.Sprintf(` +MERGE INTO %s c +USING %s AS cc ON %s +WHEN MATCHED AND cc.%s = 1 THEN DELETE +WHEN MATCHED AND COALESCE(cc.%s, 0) = 0 %sTHEN UPDATE SET %s +WHEN NOT MATCHED AND COALESCE(cc.%s, 1) = 0 THEN INSERT (%s) VALUES (%s);`, + tableID.FullyQualifiedName(), subQuery, strings.Join(equalitySQLParts, " and "), + // Delete + md.QuoteIdentifier(constants.DeleteColumnMarker), + // Update + md.QuoteIdentifier(constants.DeleteColumnMarker), idempotentClause, columns.BuildColumnsUpdateFragment(cols, md), + // Insert + md.QuoteIdentifier(constants.DeleteColumnMarker), strings.Join(columns.QuoteColumns(cols, md), ","), + array.StringsJoinAddPrefix(array.StringsJoinAddPrefixArgs{ + Vals: columns.QuoteColumns(cols, md), + Separator: ",", + Prefix: "cc.", + }))}, nil +} diff --git a/clients/mssql/dialect/dialect_test.go b/clients/mssql/dialect/dialect_test.go index c70ada2f..f8ed8715 100644 --- a/clients/mssql/dialect/dialect_test.go +++ b/clients/mssql/dialect/dialect_test.go @@ -2,12 +2,15 @@ package dialect import ( "fmt" + "strings" "testing" + "time" "github.com/artie-labs/transfer/lib/config/constants" "github.com/artie-labs/transfer/lib/mocks" "github.com/artie-labs/transfer/lib/ptr" "github.com/artie-labs/transfer/lib/typing" + "github.com/artie-labs/transfer/lib/typing/columns" "github.com/artie-labs/transfer/lib/typing/ext" "github.com/stretchr/testify/assert" ) @@ -146,10 +149,60 @@ func TestMSSQLDialect_BuildAlterColumnQuery(t *testing.T) { ) } -func TestBuildProcessToastColExpression(t *testing.T) { +func TestMSSQLDialect_BuildProcessToastColExpression(t *testing.T) { assert.Equal(t, `CASE WHEN COALESCE(cc.bar, '') != '__debezium_unavailable_value' THEN cc.bar ELSE c.bar END`, MSSQLDialect{}.BuildProcessToastColExpression("bar")) } -func TestBuildProcessToastStructColExpression(t *testing.T) { +func TestMSSQLDialect_BuildProcessToastStructColExpression(t *testing.T) { assert.Equal(t, `CASE WHEN COALESCE(cc.foo, {}) != {'key': '__debezium_unavailable_value'} THEN cc.foo ELSE c.foo END`, MSSQLDialect{}.BuildProcessToastStructColExpression("foo")) } + +func TestMSSQLDialect_BuildMergeQueries(t *testing.T) { + var _cols = []columns.Column{ + columns.NewColumn("id", typing.String), + columns.NewColumn("bar", typing.String), + columns.NewColumn("updated_at", typing.String), + columns.NewColumn("start", typing.String), + columns.NewColumn(constants.DeleteColumnMarker, typing.Boolean), + } + cols := make([]string, len(_cols)) + for i, col := range _cols { + cols[i] = col.Name() + } + + tableValues := []string{ + fmt.Sprintf("('%s', '%s', '%v', '%v', false)", "1", "456", "foo", time.Now().Round(0).UTC()), + fmt.Sprintf("('%s', '%s', '%v', '%v', false)", "2", "bb", "bar", time.Now().Round(0).UTC()), + fmt.Sprintf("('%s', '%s', '%v', '%v', false)", "3", "dd", "world", time.Now().Round(0).UTC()), + } + + // select cc.foo, cc.bar from (values (12, 34), (44, 55)) as cc(foo, bar); + subQuery := fmt.Sprintf("SELECT %s from (values %s) as %s(%s)", + strings.Join(cols, ","), strings.Join(tableValues, ","), "_tbl", strings.Join(cols, ",")) + + fqTable := "database.schema.table" + fakeID := &mocks.FakeTableIdentifier{} + fakeID.FullyQualifiedNameReturns(fqTable) + + queries, err := MSSQLDialect{}.BuildMergeQueries( + fakeID, + subQuery, + "", + []columns.Column{columns.NewColumn("id", typing.Invalid)}, + []string{}, + _cols, + false, + nil, + ) + assert.Len(t, queries, 1) + mergeSQL := queries[0] + assert.NoError(t, err) + assert.Contains(t, mergeSQL, fmt.Sprintf("MERGE INTO %s", fqTable), mergeSQL) + assert.NotContains(t, mergeSQL, fmt.Sprintf(`cc."%s" >= c."%s"`, "updated_at", "updated_at"), fmt.Sprintf("Idempotency key: %s", mergeSQL)) + // Check primary keys clause + assert.Contains(t, mergeSQL, `AS cc ON c."id" = cc."id"`, mergeSQL) + + assert.Contains(t, mergeSQL, `SET "id"=cc."id","bar"=cc."bar","updated_at"=cc."updated_at","start"=cc."start"`, mergeSQL) + assert.Contains(t, mergeSQL, `id,bar,updated_at,start`, mergeSQL) + assert.Contains(t, mergeSQL, `cc."id",cc."bar",cc."updated_at",cc."start"`, mergeSQL) +} diff --git a/lib/destination/dml/merge.go b/lib/destination/dml/merge.go index e04292ce..688f3eb3 100644 --- a/lib/destination/dml/merge.go +++ b/lib/destination/dml/merge.go @@ -244,78 +244,25 @@ WHEN NOT MATCHED AND IFNULL(cc.%s, false) = false THEN INSERT (%s) VALUES (%s);` })), nil } -func (m *MergeArgument) buildMSSQLStatement() (string, error) { - var idempotentClause string - if m.IdempotentKey != "" { - idempotentClause = fmt.Sprintf("AND cc.%s >= c.%s ", m.IdempotentKey, m.IdempotentKey) - } - - var equalitySQLParts []string - for _, primaryKey := range m.PrimaryKeys { - // We'll need to escape the primary key as well. - quotedPrimaryKey := m.Dialect.QuoteIdentifier(primaryKey.Name()) - equalitySQL := fmt.Sprintf("c.%s = cc.%s", quotedPrimaryKey, quotedPrimaryKey) - equalitySQLParts = append(equalitySQLParts, equalitySQL) - } - - if m.SoftDelete { - return fmt.Sprintf(` -MERGE INTO %s c -USING %s AS cc ON %s -WHEN MATCHED %sTHEN UPDATE SET %s -WHEN NOT MATCHED AND COALESCE(cc.%s, 0) = 0 THEN INSERT (%s) VALUES (%s);`, - m.TableID.FullyQualifiedName(), m.SubQuery, strings.Join(equalitySQLParts, " and "), - // Update + Soft Deletion - idempotentClause, columns.BuildColumnsUpdateFragment(m.Columns, m.Dialect), - // Insert - m.Dialect.QuoteIdentifier(constants.DeleteColumnMarker), strings.Join(columns.QuoteColumns(m.Columns, m.Dialect), ","), - array.StringsJoinAddPrefix(array.StringsJoinAddPrefixArgs{ - Vals: columns.QuoteColumns(m.Columns, m.Dialect), - Separator: ",", - Prefix: "cc.", - })), nil - } - - // We also need to remove __artie flags since it does not exist in the destination table - cols, removed := columns.RemoveDeleteColumnMarker(m.Columns) - if !removed { - return "", errors.New("artie delete flag doesn't exist") - } - - return fmt.Sprintf(` -MERGE INTO %s c -USING %s AS cc ON %s -WHEN MATCHED AND cc.%s = 1 THEN DELETE -WHEN MATCHED AND COALESCE(cc.%s, 0) = 0 %sTHEN UPDATE SET %s -WHEN NOT MATCHED AND COALESCE(cc.%s, 1) = 0 THEN INSERT (%s) VALUES (%s);`, - m.TableID.FullyQualifiedName(), m.SubQuery, strings.Join(equalitySQLParts, " and "), - // Delete - m.Dialect.QuoteIdentifier(constants.DeleteColumnMarker), - // Update - m.Dialect.QuoteIdentifier(constants.DeleteColumnMarker), idempotentClause, columns.BuildColumnsUpdateFragment(cols, m.Dialect), - // Insert - m.Dialect.QuoteIdentifier(constants.DeleteColumnMarker), strings.Join(columns.QuoteColumns(cols, m.Dialect), ","), - array.StringsJoinAddPrefix(array.StringsJoinAddPrefixArgs{ - Vals: columns.QuoteColumns(cols, m.Dialect), - Separator: ",", - Prefix: "cc.", - })), nil -} - func (m *MergeArgument) BuildStatements() ([]string, error) { if err := m.Valid(); err != nil { return nil, err } - switch m.Dialect.(type) { + switch specificDialect := m.Dialect.(type) { case redshiftDialect.RedshiftDialect: return m.buildRedshiftStatements() case mssqlDialect.MSSQLDialect: - mergeQuery, err := m.buildMSSQLStatement() - if err != nil { - return nil, err - } - return []string{mergeQuery}, nil + return specificDialect.BuildMergeQueries( + m.TableID, + m.SubQuery, + m.IdempotentKey, + m.PrimaryKeys, + m.AdditionalEqualityStrings, + m.Columns, + m.SoftDelete, + m.ContainsHardDeletes, + ) default: mergeQuery, err := m.buildDefaultStatement() if err != nil { diff --git a/lib/destination/dml/merge_mssql_test.go b/lib/destination/dml/merge_mssql_test.go deleted file mode 100644 index 19adc9dc..00000000 --- a/lib/destination/dml/merge_mssql_test.go +++ /dev/null @@ -1,87 +0,0 @@ -package dml - -import ( - "fmt" - "strings" - "testing" - "time" - - "github.com/artie-labs/transfer/clients/mssql/dialect" - "github.com/artie-labs/transfer/lib/config/constants" - "github.com/artie-labs/transfer/lib/mocks" - "github.com/artie-labs/transfer/lib/ptr" - "github.com/artie-labs/transfer/lib/typing" - "github.com/artie-labs/transfer/lib/typing/columns" - "github.com/stretchr/testify/assert" -) - -func Test_BuildMSSQLStatement(t *testing.T) { - - var _cols = []columns.Column{ - columns.NewColumn("id", typing.String), - columns.NewColumn("bar", typing.String), - columns.NewColumn("updated_at", typing.String), - columns.NewColumn("start", typing.String), - columns.NewColumn(constants.DeleteColumnMarker, typing.Boolean), - } - cols := make([]string, len(_cols)) - for i, col := range _cols { - cols[i] = col.Name() - } - - tableValues := []string{ - fmt.Sprintf("('%s', '%s', '%v', '%v', false)", "1", "456", "foo", time.Now().Round(0).UTC()), - fmt.Sprintf("('%s', '%s', '%v', '%v', false)", "2", "bb", "bar", time.Now().Round(0).UTC()), - fmt.Sprintf("('%s', '%s', '%v', '%v', false)", "3", "dd", "world", time.Now().Round(0).UTC()), - } - - // select cc.foo, cc.bar from (values (12, 34), (44, 55)) as cc(foo, bar); - subQuery := fmt.Sprintf("SELECT %s from (values %s) as %s(%s)", - strings.Join(cols, ","), strings.Join(tableValues, ","), "_tbl", strings.Join(cols, ",")) - - fqTable := "database.schema.table" - fakeID := &mocks.FakeTableIdentifier{} - fakeID.FullyQualifiedNameReturns(fqTable) - mergeArg := MergeArgument{ - TableID: fakeID, - SubQuery: subQuery, - IdempotentKey: "", - PrimaryKeys: []columns.Column{columns.NewColumn("id", typing.Invalid)}, - Columns: _cols, - Dialect: dialect.MSSQLDialect{}, - SoftDelete: false, - } - - mergeSQL, err := mergeArg.buildMSSQLStatement() - assert.NoError(t, err) - assert.Contains(t, mergeSQL, fmt.Sprintf("MERGE INTO %s", fqTable), mergeSQL) - assert.NotContains(t, mergeSQL, fmt.Sprintf(`cc."%s" >= c."%s"`, "updated_at", "updated_at"), fmt.Sprintf("Idempotency key: %s", mergeSQL)) - // Check primary keys clause - assert.Contains(t, mergeSQL, `AS cc ON c."id" = cc."id"`, mergeSQL) - - assert.Contains(t, mergeSQL, `SET "id"=cc."id","bar"=cc."bar","updated_at"=cc."updated_at","start"=cc."start"`, mergeSQL) - assert.Contains(t, mergeSQL, `id,bar,updated_at,start`, mergeSQL) - assert.Contains(t, mergeSQL, `cc."id",cc."bar",cc."updated_at",cc."start"`, mergeSQL) -} - -func TestMergeArgument_BuildStatements_MSSQL(t *testing.T) { - var cols = []columns.Column{ - columns.NewColumn("id", typing.String), - columns.NewColumn(constants.DeleteColumnMarker, typing.Boolean), - } - - mergeArg := MergeArgument{ - TableID: &mocks.FakeTableIdentifier{}, - SubQuery: "{SUB_QUERY}", - PrimaryKeys: []columns.Column{cols[0]}, - Columns: cols, - Dialect: dialect.MSSQLDialect{}, - ContainsHardDeletes: ptr.ToBool(true), - } - - statement, err := mergeArg.buildMSSQLStatement() - assert.NoError(t, err) - statements, err := mergeArg.BuildStatements() - assert.NoError(t, err) - assert.Equal(t, statements, []string{statement}) -}