changed merge statement and added testcase (#120)
dgcaron committed Sep 27, 2023
1 parent ff3f4fd commit 5a93f4e
Showing 2 changed files with 52 additions and 3 deletions.
5 changes: 2 additions & 3 deletions mack/__init__.py
@@ -126,11 +126,10 @@ def type_2_scd_generic_upsert(
         delta_table.alias("base")
         .merge(
             source=staged_updates.alias("staged_updates"),
-            condition=pyspark.sql.functions.expr(
-                f"base.{primary_key} = mergeKey AND base.{is_current_col_name} = true AND ({staged_updates_attrs})"
-            ),
+            condition=pyspark.sql.functions.expr(f"base.{primary_key} = mergeKey"),
         )
         .whenMatchedUpdate(
+            condition=f"base.{is_current_col_name} = true AND ({staged_updates_attrs})",
             set={
                 is_current_col_name: "false",
                 end_time_col_name: f"staged_updates.{effective_time_col_name}",
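Under the old merge condition, a staged row whose key matches an existing record but whose tracked attributes are unchanged fails the `({staged_updates_attrs})` predicate, so the merge treats it as not matched and it falls through to the insert clause, producing a duplicate. The fix matches on the key alone and moves the attribute check into `whenMatchedUpdate`, so an unchanged row matches but triggers no clause at all. Below is a minimal sketch of the resulting pattern on the Delta Lake Python API, with hypothetical column names (pkey, attr, cur, effective_date, end_date) and a plain key equality standing in for the mergeKey column that mack prepares earlier in this function; the full builder also carries an insert clause for genuinely new rows, which this diff does not show.

# Sketch only: hypothetical names, not mack's exact builder.
from delta.tables import DeltaTable
from pyspark.sql import functions as F


def type_2_close_out(spark, path, staged_updates):
    delta_table = DeltaTable.forPath(spark, path)
    (
        delta_table.alias("base")
        .merge(
            source=staged_updates.alias("staged_updates"),
            # Match on the key alone, so an unchanged duplicate still
            # "matches" and never falls through to an insert clause.
            condition=F.expr("base.pkey = staged_updates.pkey"),
        )
        .whenMatchedUpdate(
            # Only rows whose tracked attributes actually differ are
            # closed out; identical rows match but update nothing.
            condition="base.cur = true AND base.attr <> staged_updates.attr",
            set={
                "cur": "false",
                "end_date": "staged_updates.effective_date",
            },
        )
        .execute()
    )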
50 changes: 50 additions & 0 deletions tests/test_public_interface.py
@@ -319,6 +319,56 @@ def test_upserts_based_on_version_number(tmp_path):
chispa.assert_df_equality(res, expected, ignore_row_order=True)


def test_upserts_does_not_insert_duplicate(tmp_path):
    path = f"{tmp_path}/tmp/delta-no-duplicate"
    # create Delta Lake
    data2 = [
        (1, "A", True, dt(2019, 1, 1), None),
        (2, "B", True, dt(2019, 1, 1), None),
        (4, "D", True, dt(2019, 1, 1), None),
    ]

    schema = StructType(
        [
            StructField("pkey", IntegerType(), True),
            StructField("attr", StringType(), True),
            StructField("cur", BooleanType(), True),
            StructField("effective_date", DateType(), True),
            StructField("end_date", DateType(), True),
        ]
    )

    df = spark.createDataFrame(data=data2, schema=schema)
    df.write.format("delta").save(path)

    # create updates DF
    updates_df = spark.createDataFrame(
        [
            (1, "A", dt(2019, 1, 1)),  # duplicate row
        ]
    ).toDF("pkey", "attr", "effective_date")

    # perform upsert
    delta_table = DeltaTable.forPath(spark, path)
    mack.type_2_scd_generic_upsert(
        delta_table, updates_df, "pkey", ["attr"], "cur", "effective_date", "end_date"
    )

    actual_df = spark.read.format("delta").load(path)

    # the table should be unchanged: the duplicate row must not be inserted
    expected_df = spark.createDataFrame(
        [
            (1, "A", True, dt(2019, 1, 1), None),
            (2, "B", True, dt(2019, 1, 1), None),
            (4, "D", True, dt(2019, 1, 1), None),
        ],
        schema,
    )

    chispa.assert_df_equality(actual_df, expected_df, ignore_row_order=True)


# def describe_kill_duplicates():
def test_kills_duplicates_in_a_delta_table(tmp_path):
path = f"{tmp_path}/deduplicate1"
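The new regression test can be run on its own with pytest's node-id syntax, assuming the suite's shared Spark session fixture is available as for the neighboring tests:

pytest tests/test_public_interface.py::test_upserts_does_not_insert_duplicate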