In [0]:
%sql
-- Drop the existing dim_str_location table, if there is one.
DROP TABLE IF EXISTS processed_data.dim_str_location

In [0]:
from pyspark.sql import *
from pyspark.sql.functions import *
from pyspark.sql.types import *

#--- Load the three raw location dimensions
df_inv_loc = spark.sql("select * from raw_data.dim_inv_loc")  # Inventory locations
df_pos_site = spark.sql("select * from raw_data.dim_pos_site")  # POS sites
df_rtl_loc = spark.sql("select * from raw_data.dim_rtl_loc")  # Retail locations

# Source labels that are all collapsed to the single "Retail Store" location type.
RETAIL_STORE_LABELS = ["Store", "Retail Store", "Retail POS"]


def _normalize_location_type(df):
    """Return `df` with retail-store variants in `location_type` mapped to "Retail Store".

    Values of `location_type` not listed in RETAIL_STORE_LABELS pass through unchanged.
    """
    return df.withColumn(
        "location_type",
        when(col("location_type").isin(RETAIL_STORE_LABELS), "Retail Store")
        .otherwise(col("location_type")),
    )


# Rename the source columns to the shared dimension vocabulary.
df_inv_loc = _normalize_location_type(
    df_inv_loc.select(
        col("loc").alias("location_id"),
        col("loc_label").alias("location_label"),
        col("loctype_label").alias("location_type"),
    )
)

df_pos_site = _normalize_location_type(
    df_pos_site.select(
        col("site_id").alias("location_id"),
        col("site_label").alias("location_label"),
        col("chnl_label").alias("location_type"),
        col("subchnl_label").alias("channel"),
    )
)

# location_type is the literal "Retail Store" here, so it needs no normalisation.
df_rtl_loc = df_rtl_loc.select(
    col("str").alias("location_id"),
    col("str_label").alias("location_label"),
    lit("Retail Store").alias("location_type"),
    col("dstr_label").alias("district"),
    col("rgn_label").alias("region"),
)

# POS site ids: keep only rows whose id is made up purely of digits.
# regexp_extract yields "" when the pattern does not match, so those rows are
# removed by the filter before the id is cast to an integer.
df_pos_site = (
    df_pos_site
    .withColumn("location_id", regexp_extract(col("location_id"), r'^\d+$', 0))
    .filter(col("location_id") != "")
    .withColumn("location_id", col("location_id").cast(IntegerType()))
)

# Cast the join key on the other two inputs as well, so that all three sides
# join on the same type and the resulting location_id column is an integer.
# The inner join keeps the left-hand (inventory) key column in its output.
df_inv_loc = df_inv_loc.withColumn("location_id", col("location_id").cast(IntegerType()))
df_rtl_loc = df_rtl_loc.withColumn("location_id", col("location_id").cast(IntegerType()))

# Combine all locations and remove duplicates.
# Only rows that are identical in every column are dropped; rows that share a
# location_id but differ elsewhere are kept so the duplicate check can flag them.
df_dim_str_location = (
    df_inv_loc.select("location_id", "location_label", "location_type")
    .join(df_pos_site.select("location_id", "channel"), on="location_id", how="inner")
    .join(df_rtl_loc.select("location_id", "region", "district"), on="location_id", how="inner")
    .dropDuplicates()
)

# Surrogate primary key. monotonically_increasing_id() gives unique, increasing
# 64-bit values, but they are not consecutive.
# NOTE(review): the DataFrame is lazy, so the ids may differ between separate
# evaluations of this plan - confirm whether stable keys are required.
df_dim_str_location = df_dim_str_location.withColumn("dim_location_id", monotonically_increasing_id())


In [0]:
# ------ Initial load of the table ------
# Saves the assembled DataFrame as processed_data.dim_str_location in overwrite mode.
# NOTE(review): this write runs before the quality checks in the following cell -
# confirm that loading unvalidated data here is intended.
df_dim_str_location.write.saveAsTable("processed_data.dim_str_location", mode="overwrite")

In [0]:
# Expected schema of dim_str_location: column name -> Spark simpleString() type.
expected_schema = {
    "location_id": "int",
    "location_label": "string",
    "location_type": "string",
    "channel": "string",
    "region": "string",
    "district": "string",
    "dim_location_id": "bigint",
}

# 1. Schema validation - every expected column must exist with the expected type.
#    Mismatches are collected (not just printed) so they take part in the load gate.
actual_types = {
    field.name: field.dataType.simpleString()
    for field in df_dim_str_location.schema.fields
}
schema_mismatches = []
for column, dtype in expected_schema.items():
    actual_dtype = actual_types.get(column)
    if actual_dtype is None:
        # Report a missing column instead of failing on the schema lookup.
        schema_mismatches.append(column)
        print(f"Schema Mismatch: Column {column} expected {dtype}, but it is missing")
    elif actual_dtype != dtype:
        schema_mismatches.append(column)
        print(f"Schema Mismatch: Column {column} expected {dtype}, but got {actual_dtype}")

# 2. Null value check - show the number of nulls per expected column.
#    This is informational only; it does not affect the load gate.
present_columns = [c for c in expected_schema if c in actual_types]
null_check_df = df_dim_str_location.select(
    [count(when(col(c).isNull(), c)).alias(c) for c in present_columns]
)
null_check_df.show()

# 3. Duplicate check - number of location_id values that occur more than once.
duplicate_count = df_dim_str_location.groupBy("location_id").count().filter(col("count") > 1).count()
if duplicate_count > 0:
    print(f"Warning: Found {duplicate_count} duplicate location_id records")

# 4. Invalid location id check - ids must be non-null and strictly positive.
invalid_location_ids = df_dim_str_location.filter(col("location_id").isNull() | (col("location_id") <= 0))
# Count once and reuse the result; every count() launches a separate Spark job.
invalid_count = invalid_location_ids.count()
if invalid_count > 0:
    print("Warning: Found invalid location_id values")

# Load gate: write the table only when the schema, duplicate and id checks all pass.
if not schema_mismatches and duplicate_count == 0 and invalid_count == 0:
    df_dim_str_location.write.mode("overwrite").saveAsTable("processed_data.dim_str_location")
    print("Data passed all quality checks and loaded successfully")
else:
    print("Data quality checks failed,need fix")

In [0]:
%sql
-- Preview the table this notebook writes (processed_data.dim_str_location).
select * from processed_data.dim_str_location