# Import module from other notebook


In [0]:
%run /Workspace/Users/nipun.a@bluebik.com/de-learning/src/databricks/merge_functions

In [0]:
%pip install nutter --quiet

[43mNote: you may need to restart the kernel using %restart_python or dbutils.library.restartPython() to use updated packages.[0m
[43mNote: you may need to restart the kernel using %restart_python or dbutils.library.restartPython() to use updated packages.[0m


In [0]:
import json

from databricks.sdk.runtime import *
from delta.tables import *
# Explicit imports stay below the wildcard imports so they win over any same-named wildcard export.
from datetime import datetime
from pyspark.sql import Row
from pyspark.sql import DataFrame
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, IntegerType, StringType, DateType
from pyspark.sql.functions import col, lit, current_date, md5, concat_ws
from delta.tables import DeltaTable
from runtime.nutterfixture import NutterFixture, tag

# Unit test classes

In [0]:
class TestMergePipeline(NutterFixture):
    """Nutter fixture covering MergePipeline reads, schema/table creation and load modes.

    Every ``assertion_*`` method is written to be independent of the others:

    * payload tweaks are applied to a copy, so ``self.PayLoadParquet`` is never
      modified by a test;
    * every table is addressed as ``catalog.schema.table`` instead of relying
      on the session's current catalog;
    * the test schema is dropped in a ``finally`` block, so a failed assertion
      does not leave objects behind for the next test.

    ``MergePipeline`` is not defined in this notebook; it is presumably brought
    into scope by the ``%run`` of ``merge_functions`` (confirm).
    """

    def __init__(self):
        self.spark = SparkSession.builder.appName("Test Merge Pipeline").getOrCreate()
        # Base payload for the parquet-staged tests; tests derive variants via _payload_with().
        self.PayLoadParquet = """{
                            "Path": "/raw/sql_server/ShippedOrders/year=2024/month=7/day=31/run_id=",
                            "LocalRunDatetime": "Jul 23 2024  5:25PM",
                            "DatasetName": "test_dataset",
                            "SchemaName": "test_schema",
                            "LoadType": "I",
                            "PrimaryKeyFields": "order_id",
                            "StagingType": "parquet",
                            "PartitionFields": ["year"],
                            "UpdateType": ""
                        }"""
        self.RunIdParquet = "5b408abd-ec21-4807-8eaf-dff9b17ee6dc"
        # Back-ticked because the catalog name contains a hyphen.
        self.CatalogName = '`nipun-catalog`'
        # Payload for the CSV-staged read test.
        self.PayLoadCsv = """{
                    "LocalRunDatetime": "Aug  7 2024  2:41PM",
                    "DatasetName": "SalesOrderHeader",
                    "LoadType": "F",
                    "UpdateType": null,
                    "PrimaryKeyFields": null,
                    "CsvPath": "abfss://nipun@delearningstdfssandbox.dfs.core.windows.net/raw_csv/",
                    "CsvDelimiter": "','",
                    "StagingType": "csv"
                }
                """
        self.RunIdCsv = "E4D1E20D-9F43-4CB1-A2B2-2E244A7C2F62"
        super().__init__()

    # ------------------------------------------------------------------
    # Private helpers (the leading underscore keeps them out of the
    # before_/run_/assertion_/after_ naming scheme used for test methods).
    # ------------------------------------------------------------------

    def _payload_with(self, **overrides):
        """Return the parquet payload as a JSON string with ``overrides`` applied.

        The payload stored on the fixture is left untouched so one test cannot
        change the input of another.
        """
        payload = json.loads(self.PayLoadParquet)
        payload.update(overrides)
        return json.dumps(payload)

    def _target_names(self):
        """Return ``(schema_name, dataset_name, qualified_table)`` for the parquet payload."""
        payload = json.loads(self.PayLoadParquet)
        schema_name = payload['SchemaName'] + "_Nipun"
        dataset_name = payload['DatasetName'].lower()
        qualified_table = f"{self.CatalogName}.{schema_name}.{dataset_name}"
        return schema_name, dataset_name, qualified_table

    def _drop_test_schema(self, schema_name):
        """Drop the test schema and everything inside it (no-op if it does not exist)."""
        self.spark.sql(f"DROP SCHEMA IF EXISTS {self.CatalogName}.{schema_name} CASCADE")

    def _seed_table(self, schema_name, qualified_table, rows):
        """Create the test schema and (over)write ``rows`` as an ``id``/``value`` Delta table."""
        self.spark.sql(f"CREATE SCHEMA IF NOT EXISTS {self.CatalogName}.{schema_name}")
        seed_df = self.spark.createDataFrame(rows, ["id", "value"])
        seed_df.write.format('delta').mode("overwrite").saveAsTable(qualified_table)

        table_data_before = self.spark.sql(f"SELECT * FROM {qualified_table}")
        assert table_data_before.count() == len(rows), "Initial data not written correctly to the table"

    @staticmethod
    def _assert_same_rows(df_actual, df_expect):
        """Fail unless both DataFrames contain the same set of distinct rows.

        ``subtract`` is a set difference, so duplicate rows are not counted.
        """
        df_diff = df_actual.subtract(df_expect).union(df_expect.subtract(df_actual))
        assert df_diff.count() == 0, "DataFrames do not match"

    # ------------------------------------------------------------------
    # Tests
    # ------------------------------------------------------------------

    def assertion_read_data_csv(self):
        """read_data() for the CSV payload must match a direct read of the source CSV file."""
        etl_pipeline = MergePipeline(payload=self.PayLoadCsv, run_id=self.RunIdCsv)
        df_actual = etl_pipeline.read_data()
        df_expect = (
            self.spark.read.format("csv")
            .option("header", "true")
            .load("abfss://nipun@delearningstdfssandbox.dfs.core.windows.net/raw_csv/SalesOrderHeader.csv")
        )

        self._assert_same_rows(df_actual, df_expect)

    def assertion_read_data_parquet(self):
        """read_data() for the parquet payload must match a direct read of the run's parquet folder."""
        etl_pipeline = MergePipeline(payload=self.PayLoadParquet, run_id=self.RunIdParquet)
        df_actual = etl_pipeline.read_data()
        df_expect = self.spark.read.parquet("abfss://nipun@delearningstdfssandbox.dfs.core.windows.net/raw/sql_server/ShippedOrders/year=2024/month=7/day=31/run_id=5b408abd-ec21-4807-8eaf-dff9b17ee6dc")

        self._assert_same_rows(df_actual, df_expect)

    def assertion_create_database(self):
        """create_database() must create the test schema; the schema must be gone after cleanup."""
        schema_name, _, _ = self._target_names()
        show_schema = f"SHOW SCHEMAS LIKE '{schema_name}'"

        try:
            etl_pipeline = MergePipeline(self.PayLoadParquet, run_id=self.RunIdParquet)
            etl_pipeline.create_database()

            # SHOW SCHEMAS looks in the current catalog, so switch to the test catalog first.
            self.spark.sql(f"USE CATALOG {self.CatalogName}")

            schema_exists = self.spark.sql(show_schema).count() > 0
            assert schema_exists, f"Schema {schema_name} was not created successfully"
        finally:
            self._drop_test_schema(schema_name)

        schema_removed = self.spark.sql(show_schema).count() == 0
        assert schema_removed, f"Schema {schema_name} was not removed after cleanup"

    def assertion_create_table_if_not_exist(self):
        """create_table_if_not_exist() must register the dataset's table in the test schema."""
        schema_name, dataset_name, _ = self._target_names()

        try:
            spark_df = self.spark.createDataFrame([(1, "sample_data")], ["id", "value"])

            etl_pipeline = MergePipeline(payload=self.PayLoadParquet, run_id=self.RunIdParquet)
            etl_pipeline.create_database()
            etl_pipeline.create_table_if_not_exist(spark_df)

            self.spark.sql(f"USE CATALOG {self.CatalogName}")
            tables_in_schema = self.spark.sql(f"SHOW TABLES IN {self.CatalogName}.{schema_name}")
            table_exists_after = tables_in_schema.filter(f"tableName = '{dataset_name}'").count() > 0
            assert table_exists_after, f"Table {dataset_name} was not created successfully"
        finally:
            self._drop_test_schema(schema_name)

    def assertion_apply_full_load(self):
        """LoadType 'F': apply_update() must replace the table contents with the new rows."""
        schema_name, _, qualified_table = self._target_names()

        try:
            self._seed_table(schema_name, qualified_table, [(1, "initial_data_1"), (2, "initial_data_2")])

            new_df = self.spark.createDataFrame([(3, "new_data_1"), (4, "new_data_2")], ["id", "value"])

            merge_pipeline = MergePipeline(payload=self._payload_with(LoadType="F"), run_id=self.RunIdParquet)
            merge_pipeline.apply_update(new_df)

            table_data_after = self.spark.sql(f"SELECT * FROM {qualified_table}")
            assert table_data_after.count() == 2, "Table was not overwritten correctly"
            assert table_data_after.filter("id = 3").count() == 1, "New data was not loaded correctly"
            assert table_data_after.filter("id = 1").count() == 0, "Old data was not overwritten correctly"
        finally:
            self._drop_test_schema(schema_name)

    def assertion_apply_insert_load(self):
        """LoadType 'I' without a primary key: apply_update() must append the new rows."""
        schema_name, _, qualified_table = self._target_names()

        try:
            self._seed_table(schema_name, qualified_table, [(1, "initial_data_1"), (2, "initial_data_2")])

            new_df = self.spark.createDataFrame([(3, "new_data_1"), (4, "new_data_2")], ["id", "value"])

            payload = self._payload_with(LoadType="I", PrimaryKeyFields="")
            merge_pipeline = MergePipeline(payload=payload, run_id=self.RunIdParquet)
            merge_pipeline.apply_update(new_df)

            table_data_after = self.spark.sql(f"SELECT * FROM {qualified_table}")
            assert table_data_after.count() == 4, "Data was not inserted correctly"
            assert table_data_after.filter("id = 3").count() == 1, "New data was not loaded correctly"
            assert table_data_after.filter("id = 1").count() == 1, "Existing data was overwritten, which should not happen"
        finally:
            self._drop_test_schema(schema_name)

    def assertion_apply_upsert_load(self):
        """LoadType 'I' keyed on ``id``: apply_update() must update matching rows and insert new ones."""
        schema_name, _, qualified_table = self._target_names()

        try:
            self._seed_table(schema_name, qualified_table, [(1, "initial_data_1"), (2, "initial_data_2")])

            upsert_df = self.spark.createDataFrame([(2, "updated_data_2"), (3, "new_data_3")], ["id", "value"])

            payload = self._payload_with(LoadType="I", PrimaryKeyFields="id")
            merge_pipeline = MergePipeline(payload=payload, run_id=self.RunIdParquet)
            merge_pipeline.apply_update(upsert_df)

            table_data_after = self.spark.sql(f"SELECT * FROM {qualified_table}")
            assert table_data_after.count() == 3, "Data was not upserted correctly"

            assert table_data_after.filter("id = 2 AND value = 'updated_data_2'").count() == 1, "Existing data was not updated correctly"
            assert table_data_after.filter("id = 3 AND value = 'new_data_3'").count() == 1, "New data was not inserted correctly"
        finally:
            self._drop_test_schema(schema_name)


In [0]:
class TestMergePipeline2(NutterFixture):
    """Nutter fixture covering MergePipeline.apply_op and the SCD2 path of apply_update.

    Tests do not modify ``self.PayLoadParquet`` (payload variants are built on a
    copy), always address tables as ``catalog.schema.table``, and drop the test
    schema in a ``finally`` block so a failed assertion does not leak objects.

    ``MergePipeline`` is not defined in this notebook; it is presumably brought
    into scope by the ``%run`` of ``merge_functions`` (confirm).
    """

    def __init__(self):
        self.spark = SparkSession.builder.appName("Test Merge Pipeline").getOrCreate()
        # Base payload for the parquet-staged tests; tests derive variants via _payload_with().
        self.PayLoadParquet = """{
                            "Path": "/raw/sql_server/ShippedOrders/year=2024/month=7/day=31/run_id=",
                            "LocalRunDatetime": "Jul 23 2024  5:25PM",
                            "DatasetName": "test_dataset",
                            "SchemaName": "test_schema",
                            "LoadType": "I",
                            "PrimaryKeyFields": "order_id",
                            "StagingType": "parquet",
                            "PartitionFields": ["year"],
                            "UpdateType": ""
                        }"""
        self.RunIdParquet = "5b408abd-ec21-4807-8eaf-dff9b17ee6dc"
        # Back-ticked because the catalog name contains a hyphen.
        self.CatalogName = '`nipun-catalog`'
        # Not read by the tests in this class; kept so the fixture's attributes stay unchanged.
        self.PayLoadCsv = """{
                    "LocalRunDatetime": "Aug  7 2024  2:41PM",
                    "DatasetName": "SalesOrderHeader",
                    "LoadType": "F",
                    "UpdateType": null,
                    "PrimaryKeyFields": null,
                    "CsvPath": "abfss://nipun@delearningstdfssandbox.dfs.core.windows.net/raw_csv/",
                    "CsvDelimiter": "','",
                    "StagingType": "csv"
                }
                """
        self.RunIdCsv = "E4D1E20D-9F43-4CB1-A2B2-2E244A7C2F62"
        super().__init__()

    # ------------------------------------------------------------------
    # Private helpers (the leading underscore keeps them out of the
    # before_/run_/assertion_/after_ naming scheme used for test methods).
    # ------------------------------------------------------------------

    def _payload_with(self, **overrides):
        """Return the parquet payload as a JSON string with ``overrides`` applied.

        The payload stored on the fixture is left untouched.
        """
        payload = json.loads(self.PayLoadParquet)
        payload.update(overrides)
        return json.dumps(payload)

    def _target_names(self):
        """Return ``(schema_name, dataset_name, qualified_table)`` for the parquet payload."""
        payload = json.loads(self.PayLoadParquet)
        schema_name = payload['SchemaName'] + "_Nipun"
        dataset_name = payload['DatasetName'].lower()
        qualified_table = f"{self.CatalogName}.{schema_name}.{dataset_name}"
        return schema_name, dataset_name, qualified_table

    def _drop_test_schema(self, schema_name):
        """Drop the test schema and everything inside it (no-op if it does not exist)."""
        self.spark.sql(f"DROP SCHEMA IF EXISTS {self.CatalogName}.{schema_name} CASCADE")

    # ------------------------------------------------------------------
    # Tests
    # ------------------------------------------------------------------

    def assertion_apply_op(self):
        """apply_op() must leave the table equal to the expected rows.

        The expected rows are: 2020 replaced by the incoming 2020 rows, 2021
        untouched, 2022 added. The comparison is made in both directions so
        that rows missing from the table fail the test as well as unexpected
        extra rows.
        """
        schema_name, _, qualified_table = self._target_names()
        columns = ["year", "country", "amount"]

        # Rows that already exist in the table before apply_op runs.
        initial_data = [
            (2020, 'USA', 1000),
            (2020, 'Canada', 1500),
            (2020, 'UK', 2000),
            (2021, 'USA', 2000),
            (2021, 'Canada', 3500),
            (2021, 'UK', 1000),
        ]
        new_data = [
            (2020, 'USA', 1112),
            (2020, 'Canada', 4445),
            (2022, 'USA', 5667),
            (2022, 'Canada', 3244),
            (2022, 'UK', 8976),
        ]
        expected_data = [
            Row(year=2020, country='Canada', amount=4445),
            Row(year=2020, country='USA', amount=1112),
            Row(year=2021, country='Canada', amount=3500),
            Row(year=2021, country='UK', amount=1000),
            Row(year=2021, country='USA', amount=2000),
            Row(year=2022, country='Canada', amount=3244),
            Row(year=2022, country='UK', amount=8976),
            Row(year=2022, country='USA', amount=5667),
        ]

        try:
            self.spark.sql(f"CREATE SCHEMA IF NOT EXISTS {self.CatalogName}.{schema_name}")
            initial_df = self.spark.createDataFrame(initial_data, columns)
            initial_df.write.format('delta').mode('overwrite').saveAsTable(qualified_table)

            new_df = self.spark.createDataFrame(new_data, columns)
            merge_pipeline = MergePipeline(payload=self.PayLoadParquet, run_id=self.RunIdParquet)
            merge_pipeline.apply_op(new_df)

            final_table_df = self.spark.sql(f"SELECT * FROM {qualified_table}")
            expected_df = self.spark.createDataFrame(expected_data)

            # A one-sided subtract would also pass for an empty or partial table.
            unexpected_rows = final_table_df.subtract(expected_df).count()
            missing_rows = expected_df.subtract(final_table_df).count()
            assert unexpected_rows == 0 and missing_rows == 0, "The table data does not match the expected result"
        finally:
            self._drop_test_schema(schema_name)

    def compare_dataframes(self, df1: DataFrame, df2: DataFrame) -> bool:
        """Return True when both DataFrames have the same schema, row count and rows.

        Rows are compared after sorting each DataFrame by all of its columns, so
        the original row order does not matter. The reason for a mismatch (or
        the success message) is printed.
        """
        if df1.schema != df2.schema:
            print("Schemas are different")
            return False

        if df1.count() != df2.count():
            print("Row counts are different")
            return False

        df1_list = df1.orderBy(df1.columns).collect()
        df2_list = df2.orderBy(df2.columns).collect()

        if df1_list != df2_list:
            print("Data is different")
            return False

        print("DataFrames are identical")
        return True

    def assertion_apply_scd2_load(self):
        """UpdateType 'scd2': a changed key is closed and re-opened, a new key is inserted.

        Expected outcome for the seeded table: customer 1 gets its current row
        end-dated today with flag 0 plus a new open row; customer 4 is inserted
        as an open row; customers 2 and 3 are unchanged.
        """
        schema_name, _, qualified_table = self._target_names()

        target_schema = StructType([
            StructField("CustomerID", IntegerType(), True),
            StructField("Name", StringType(), True),
            StructField("Region", StringType(), True),
            StructField("eff_start_date", DateType(), True),
            StructField("eff_end_date", DateType(), True),
            StructField("flag", IntegerType(), True)
        ])
        source_schema = StructType([
            StructField("CustomerID", IntegerType(), True),
            StructField("Name", StringType(), True),
            StructField("Region", StringType(), True)
        ])

        # Use date (not datetime) values so seed and expected rows match the DateType columns.
        start_date = datetime.strptime("2023-01-01", "%Y-%m-%d").date()
        data_target = [
            (1, "John Doe", "North", start_date, None, 1),
            (2, "Jane Smith", "South", start_date, None, 1),
            (3, "Mike Johnson", "East", start_date, None, 1)
        ]
        data_source = [
            (1, "John Doe", "West"),
            (4, "Alice Brown", "North")
        ]

        try:
            self.spark.sql(f"CREATE SCHEMA IF NOT EXISTS {self.CatalogName}.{schema_name}")
            df_target = self.spark.createDataFrame(data_target, target_schema)
            df_target.write.format("delta").mode("overwrite").saveAsTable(qualified_table)

            df_source = self.spark.createDataFrame(data_source, schema=source_schema)

            payload = self._payload_with(UpdateType="scd2", PrimaryKeyFields="CustomerID")
            merge_pipeline = MergePipeline(payload=payload, run_id="mock_run_id")
            merge_pipeline.apply_update(df_source)

            actual_df = DeltaTable.forName(self.spark, qualified_table).toDF()

            # Read the clock once so all expected rows share the same date.
            today = datetime.today().date()
            expected_data = [
                (1, "John Doe", "North", start_date, today, 0),
                (3, "Mike Johnson", "East", start_date, None, 1),
                (4, "Alice Brown", "North", today, None, 1),
                (2, "Jane Smith", "South", start_date, None, 1),
                (1, "John Doe", "West", today, None, 1)
            ]
            expected_df = self.spark.createDataFrame(expected_data, target_schema)

            assert self.compare_dataframes(actual_df, expected_df), "SCD2 update failed"
        finally:
            self._drop_test_schema(schema_name)

# Run unit tests

In [0]:
# Run the TestMergePipeline fixture and print its report as text in the cell output.
result = TestMergePipeline().execute_tests()
print(result.to_string())

# The call below is intentionally left commented out so the report above is shown inside the notebook.
# Re-enable it to hand the results back through dbutils (presumably needed when an external runner
# collects the results - confirm).
#result.exit(dbutils)


Notebook: N/A - Lifecycle State: N/A, Result: N/A
Run Page URL: N/A
PASSING TESTS
------------------------------------------------------------
apply_full_load (11.708774161999827 seconds)
apply_insert_load (7.2016212669987 seconds)
apply_upsert_load (8.904465669000274 seconds)
create_database (1.4783431269988796 seconds)
create_table_if_not_exist (4.525577089998478 seconds)
read_data_csv (1.0168841579998116 seconds)
read_data_parquet (1.0076113070008432 seconds)





In [0]:
# Run the TestMergePipeline2 fixture and print its report as text in the cell output.
# Note: this rebinds `result`, replacing the value from the previous run cell.
result = TestMergePipeline2().execute_tests()
print(result.to_string())

# The call below is intentionally left commented out so the report above is shown inside the notebook.
# Re-enable it to hand the results back through dbutils (presumably needed when an external runner
# collects the results - confirm).
#result.exit(dbutils)

DataFrames are identical

Notebook: N/A - Lifecycle State: N/A, Result: N/A
Run Page URL: N/A
PASSING TESTS
------------------------------------------------------------
apply_op (11.445538974001465 seconds)
apply_scd2_load (10.676581640000222 seconds)





In [0]:
# Manual safety net: drop the test schema (and everything in it) in case a test run left it behind.
CatalogName, schema_name = '`nipun-catalog`', 'test_schema_Nipun'
drop_schema = "DROP SCHEMA IF EXISTS {}.{} CASCADE".format(CatalogName, schema_name)
spark.sql(drop_schema)

DataFrame[]