# 📱 Samsung Sales Analytics Pipeline (PySpark ETL)

This notebook contains an end-to-end PySpark ETL pipeline that:
- Ingests mock sales transactions, product catalog, and store/region data
- Cleans and standardizes fields
- Joins data to enrich transactions
- Calculates KPIs: **revenue and units sold per model and country**
- Writes curated outputs in **Parquet** format

You can run this in Databricks, EMR, or any Spark environment with PySpark.


## 1) Configuration

Update these paths to match your environment (DBFS/S3/ADLS/local).

In [None]:
# ---- Config ----
# Raw CSV inputs. Point these at DBFS, S3, ADLS, or a local directory,
# e.g. "dbfs:/mnt/raw/sales_transactions" or "s3://bucket/raw/sales_transactions".
INPUT_SALES = "/data/raw/sales_transactions"
INPUT_PRODUCTS = "/data/raw/product_catalog"
INPUT_STORES = "/data/raw/store_regions"

# Curated and mart outputs, all written as Parquet.
OUT_ENRICHED = "/data/curated/enriched_transactions"  # enriched fact-like table
OUT_KPI_MODEL_COUNTRY = "/data/marts/kpi_model_country"
OUT_KPI_SUMMARY = "/data/marts/kpi_summary"

# Save mode used for every Parquet write; use "append" for incremental strategies.
WRITE_MODE = "overwrite"


## 2) Imports & Spark Session

In Databricks, a Spark session usually already exists as `spark`; in other environments, the next cell creates one.

In [None]:
from pyspark.sql import SparkSession, DataFrame
from pyspark.sql import functions as F
from pyspark.sql import types as T

# Reuse a session the runtime already provides (e.g. Databricks exposes `spark`).
# Referencing an undefined name raises NameError, which is the signal to build one here.
try:
    spark
except NameError:
    spark = SparkSession.builder.appName("SamsungSalesAnalyticsETL").getOrCreate()

# Bare expression so the notebook renders the session as this cell's output.
spark

## 3) Helper Functions

Utilities for cleaning, casting, and validating schemas.

In [None]:
def standardize_strings(df: DataFrame, cols: list[str]) -> DataFrame:
    """Return ``df`` with every column named in ``cols`` trimmed and upper-cased.

    Intended for join keys and categorical values, so that matches are not broken
    by stray whitespace or letter case. Each listed column is replaced under its
    own name; columns not listed are left untouched.
    """
    result = df
    for name in cols:
        normalized = F.upper(F.trim(F.col(name)))
        result = result.withColumn(name, normalized)
    return result


def safe_cast_numeric(df: DataFrame, col_name: str, to_type: T.DataType) -> DataFrame:
    """Replace ``col_name`` with its value cast to ``to_type``.

    With ANSI mode off (``spark.sql.ansi.enabled=false``), values that cannot be
    cast become null, so later ``isNotNull`` filters can drop them.
    NOTE(review): with ANSI mode on, an invalid cast raises at execution time
    instead of producing null; consider ``try_cast`` if null-on-failure must
    also hold in that mode.
    """
    casted = F.col(col_name).cast(to_type)
    return df.withColumn(col_name, casted)


def assert_has_columns(df: DataFrame, required: list[str], df_name: str) -> None:
    """Raise ``ValueError`` when any name in ``required`` is absent from ``df.columns``.

    The message lists the missing names (in ``required`` order) and every column
    that was actually found, which makes header/schema mismatches easy to spot.
    """
    available = df.columns
    missing = [name for name in required if name not in available]
    if not missing:
        return
    raise ValueError(f"{df_name} missing required columns: {missing}. Found: {df.columns}")


## 4) Extract

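Reads header-based CSV inputs. For production, prefer explicit schemas over `inferSchema`: inference costs an extra pass over the data and can mistype columns (for example, reading zero-padded IDs such as `0012` as integers).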

In [None]:
def _read_csv(spark: SparkSession, path: str, schema: T.StructType | None = None) -> DataFrame:
    """Read every ``*.csv`` file directly under ``path`` as a header-based CSV DataFrame.

    If ``schema`` is given it is applied as-is and no inference pass runs.
    Otherwise column types are inferred (``inferSchema``), which needs an extra
    pass over the data and can mistype columns such as zero-padded IDs.
    """
    reader = spark.read.format("csv").option("header", "true")
    if schema is None:
        reader = reader.option("inferSchema", "true")
    else:
        reader = reader.schema(schema)
    return reader.load(f"{path}/*.csv")


def read_sales_transactions(spark: SparkSession, path: str, schema: T.StructType | None = None) -> DataFrame:
    """Read raw sales transaction CSVs from ``path``.

    Expected columns: transaction_id, transaction_ts, store_id, sku, units,
    unit_price (+ optional model, currency). Pass ``schema`` to skip inference.
    """
    return _read_csv(spark, path, schema)


def read_product_catalog(spark: SparkSession, path: str, schema: T.StructType | None = None) -> DataFrame:
    """Read product catalog CSVs from ``path``.

    Expected columns: sku, model (+ optional series, launch_year, msrp).
    Pass ``schema`` to skip inference.
    """
    return _read_csv(spark, path, schema)


def read_store_regions(spark: SparkSession, path: str, schema: T.StructType | None = None) -> DataFrame:
    """Read store/region reference CSVs from ``path``.

    Expected columns: store_id, country (+ optional store_name, city, region).
    Pass ``schema`` to skip inference.
    """
    return _read_csv(spark, path, schema)


## 5) Transform

Cleans each dataset, enriches transactions via joins, and computes KPIs.

In [None]:
def clean_sales(df: DataFrame) -> DataFrame:
    """Validate, normalize, and filter raw sales transactions.

    - Raises ``ValueError`` (via ``assert_has_columns``) if a required column is missing.
    - Trims and upper-cases ``store_id`` and ``sku``. Trims ``model`` and trims and
      upper-cases ``currency`` when those optional columns exist.
    - Parses ``transaction_ts`` by trying a format-less ``to_timestamp`` first, then
      three explicit patterns, and derives ``transaction_date`` from the result.
      Unparseable timestamps stay null and are not filtered out.
    - Casts ``units`` to int and ``unit_price`` to double.
    - Keeps only rows that have a transaction_id, a non-empty sku and store_id,
      units > 0, and unit_price >= 0.
    """
    assert_has_columns(
        df,
        ["transaction_id", "transaction_ts", "store_id", "sku", "units", "unit_price"],
        "sales_transactions",
    )

    cleaned = standardize_strings(df, ["store_id", "sku"])
    if "model" in cleaned.columns:
        cleaned = cleaned.withColumn("model", F.trim(F.col("model")))

    # First successful parse wins; the format-less attempt goes first.
    ts = "transaction_ts"
    candidate_formats = ["yyyy-MM-dd HH:mm:ss", "MM/dd/yyyy HH:mm:ss", "yyyy-MM-dd'T'HH:mm:ss"]
    parsed_ts = F.coalesce(F.to_timestamp(ts), *[F.to_timestamp(ts, fmt) for fmt in candidate_formats])
    cleaned = cleaned.withColumn(ts, parsed_ts).withColumn("transaction_date", F.to_date(ts))

    cleaned = safe_cast_numeric(cleaned, "units", T.IntegerType())
    cleaned = safe_cast_numeric(cleaned, "unit_price", T.DoubleType())

    # A row survives only if every predicate is true (null/false rows are dropped),
    # which is the same outcome as applying each predicate as its own filter.
    is_valid = (
        F.col("transaction_id").isNotNull()
        & (F.col("sku").isNotNull() & (F.col("sku") != ""))
        & (F.col("store_id").isNotNull() & (F.col("store_id") != ""))
        & (F.col("units").isNotNull() & (F.col("units") > 0))
        & (F.col("unit_price").isNotNull() & (F.col("unit_price") >= 0))
    )
    cleaned = cleaned.filter(is_valid)

    if "currency" in cleaned.columns:
        cleaned = cleaned.withColumn("currency", F.upper(F.trim(F.col("currency"))))

    return cleaned


def clean_products(df: DataFrame) -> DataFrame:
    """Normalize the product catalog and keep one row per ``sku``.

    Requires ``sku`` and ``model`` (``ValueError`` otherwise). ``sku`` is trimmed
    and upper-cased and ``model`` is trimmed. The optional ``msrp`` and
    ``launch_year`` columns are cast to double and int when present.

    NOTE(review): ``dropDuplicates`` keeps an arbitrary row per sku. If the catalog
    lists one sku under two models, which one survives is not deterministic.
    """
    assert_has_columns(df, ["sku", "model"], "product_catalog")

    catalog = standardize_strings(df, ["sku"]).withColumn("model", F.trim(F.col("model")))

    numeric_casts = {"msrp": T.DoubleType(), "launch_year": T.IntegerType()}
    for name, dtype in numeric_casts.items():
        if name in catalog.columns:
            catalog = safe_cast_numeric(catalog, name, dtype)

    return catalog.dropDuplicates(["sku"])


def clean_stores(df: DataFrame) -> DataFrame:
    """Normalize store/region reference data and keep one row per ``store_id``.

    Requires ``store_id`` and ``country`` (``ValueError`` otherwise). Both are
    trimmed and upper-cased. The optional ``region`` column is trimmed only, so
    its case is preserved.

    NOTE(review): ``dropDuplicates`` keeps an arbitrary row per store_id.
    """
    assert_has_columns(df, ["store_id", "country"], "store_regions")

    stores = standardize_strings(df, ["store_id", "country"])
    with_region = (
        stores.withColumn("region", F.trim(F.col("region")))
        if "region" in stores.columns
        else stores
    )
    return with_region.dropDuplicates(["store_id"])


def enrich_transactions(sales: DataFrame, products: DataFrame, stores: DataFrame) -> DataFrame:
    """Left-join cleaned sales to the product catalog (on sku) and stores (on store_id).

    Every sales row is kept; catalog/store attributes are null when no match exists.
    ``model`` prefers the value on the sales row and falls back to the catalog when
    the sales value is null or blank. Optional attributes missing from their source
    frame are emitted as typed null columns, so the output schema is stable and can
    be written to Parquet. ``line_revenue`` is ``units * unit_price``.
    """
    joined = (
        sales.alias("s")
        .join(products.alias("p"), on=F.col("s.sku") == F.col("p.sku"), how="left")
        .join(stores.alias("r"), on=F.col("s.store_id") == F.col("r.store_id"), how="left")
    )

    def _optional(alias, source, name, dtype):
        # Take ``alias.name`` if ``source`` has it, else a null of ``dtype``.
        # A bare F.lit(None) is NullType ("void"), which the Parquet writer rejects,
        # so the placeholder is cast to a concrete type.
        if name in source.columns:
            return F.col(f"{alias}.{name}").alias(name)
        return F.lit(None).cast(dtype).alias(name)

    catalog_model = F.col("p.model")
    if "model" in sales.columns:
        # A blank sales model is not null, so plain coalesce would pick it over a
        # real catalog model; map blanks to null first (when() without otherwise()).
        sales_model = F.when(F.trim(F.col("s.model")) != "", F.col("s.model"))
        model_col = F.coalesce(sales_model, catalog_model)
    else:
        model_col = catalog_model

    out = joined.select(
        F.col("s.transaction_id").alias("transaction_id"),
        F.col("s.transaction_ts").alias("transaction_ts"),
        F.col("s.transaction_date").alias("transaction_date"),
        F.col("s.store_id").alias("store_id"),
        _optional("r", stores, "store_name", T.StringType()),
        _optional("r", stores, "city", T.StringType()),
        F.col("r.country").alias("country"),
        _optional("r", stores, "region", T.StringType()),
        F.col("s.sku").alias("sku"),
        model_col.alias("model"),
        _optional("p", products, "series", T.StringType()),
        _optional("p", products, "launch_year", T.IntegerType()),
        F.col("s.units").alias("units"),
        F.col("s.unit_price").alias("unit_price"),
        (F.col("s.units") * F.col("s.unit_price")).alias("line_revenue"),
        _optional("s", sales, "currency", T.StringType()),
    )

    # Final trim of the merged model, so untrimmed inputs still compare cleanly downstream.
    return out.withColumn("model", F.trim(F.col("model")))


def compute_kpis(enriched: DataFrame) -> DataFrame:
    """Aggregate enriched transaction lines into KPIs per (country, model).

    Rows whose ``model`` or ``country`` is null or empty are excluded. Output columns:
    - units_sold: sum of units
    - revenue: sum of line_revenue, rounded to 2 dp
    - transactions: distinct transaction_id count
    - avg_unit_price: unweighted mean of line unit_price, rounded to 2 dp

    Rows are sorted by revenue, highest first.
    """
    has_model = F.col("model").isNotNull() & (F.col("model") != "")
    has_country = F.col("country").isNotNull() & (F.col("country") != "")

    grouped = enriched.where(has_model).where(has_country).groupBy("country", "model")
    kpis = grouped.agg(
        F.sum("units").alias("units_sold"),
        F.round(F.sum("line_revenue"), 2).alias("revenue"),
        F.countDistinct("transaction_id").alias("transactions"),
        F.round(F.avg("unit_price"), 2).alias("avg_unit_price"),
    )
    return kpis.orderBy(F.desc("revenue"))


def compute_summary(kpi_model_country: DataFrame, enriched: DataFrame | None = None) -> DataFrame:
    """Roll KPIs up to one row per country.

    Output columns: units_sold, revenue (2 dp), transactions, distinct_models,
    sorted by revenue descending.

    Default path (``enriched`` omitted) aggregates ``kpi_model_country``. Caveats:
    ``transactions`` is the SUM of per-model distinct counts, so a transaction that
    contains several models is counted once per model, and ``revenue`` re-sums
    values that were already rounded per model.

    Pass ``enriched`` (the line-level input of ``compute_kpis``) to compute exact
    figures from the lines instead: distinct transactions per country and revenue
    rounded once. ``kpi_model_country`` is not read on that path. The same
    model/country filters as ``compute_kpis`` are applied, so both paths cover the
    same rows.
    """
    if enriched is None:
        return (
            kpi_model_country
            .groupBy("country")
            .agg(
                F.sum("units_sold").alias("units_sold"),
                F.round(F.sum("revenue"), 2).alias("revenue"),
                F.sum("transactions").alias("transactions"),
                F.countDistinct("model").alias("distinct_models"),
            )
            .orderBy(F.col("revenue").desc())
        )

    in_scope = (
        enriched
        .filter(F.col("model").isNotNull() & (F.col("model") != ""))
        .filter(F.col("country").isNotNull() & (F.col("country") != ""))
    )
    return (
        in_scope
        .groupBy("country")
        .agg(
            F.sum("units").alias("units_sold"),
            F.round(F.sum("line_revenue"), 2).alias("revenue"),
            F.countDistinct("transaction_id").alias("transactions"),
            F.countDistinct("model").alias("distinct_models"),
        )
        .orderBy(F.col("revenue").desc())
    )


## 6) Load (Write Parquet)

Writes the enriched transactions and KPI outputs to Parquet.

In [None]:
def write_parquet(df: DataFrame, path: str, mode: str = "overwrite", partition_cols: list[str] | None = None) -> None:
    """Write ``df`` to ``path`` as Parquet with save mode ``mode``.

    When ``partition_cols`` is non-empty, the output is partitioned by those columns.
    ``None`` or an empty list means no partitioning.
    """
    writer = df.write.format("parquet").mode(mode)
    writer = writer.partitionBy(*partition_cols) if partition_cols else writer
    writer.save(path)


## 7) Run the Pipeline

This executes the ETL end-to-end and shows a small preview of the KPI output.

In [None]:
# ---- Extract ----
sales_raw = read_sales_transactions(spark, INPUT_SALES)
products_raw = read_product_catalog(spark, INPUT_PRODUCTS)
stores_raw = read_store_regions(spark, INPUT_STORES)

# ---- Transform ----
sales = clean_sales(sales_raw)
products = clean_products(products_raw)
stores = clean_stores(stores_raw)

enriched = enrich_transactions(sales, products, stores)
kpi_model_country = compute_kpis(enriched)
kpi_summary = compute_summary(kpi_model_country)

# ---- Load ----
write_parquet(enriched, OUT_ENRICHED, mode=WRITE_MODE, partition_cols=["transaction_date"])
write_parquet(kpi_model_country, OUT_KPI_MODEL_COUNTRY, mode=WRITE_MODE)
write_parquet(kpi_summary, OUT_KPI_SUMMARY, mode=WRITE_MODE)

print(f"✅ Enriched transactions written to: {OUT_ENRICHED}")
print(f"✅ KPI (model/country) written to: {OUT_KPI_MODEL_COUNTRY}")
print(f"✅ KPI summary written to: {OUT_KPI_SUMMARY}")

# Use the notebook's rich `display` when the runtime provides one (e.g. Databricks);
# otherwise fall back to a plain-text preview.
if "display" in globals():
    display(kpi_model_country.limit(20))
else:
    kpi_model_country.show(20, truncate=False)


## 8) Notes / Next Enhancements

- Add a **quarantine** output for bad rows (a data-quality / DQ layer)
- Add **incremental processing** with a watermark on `transaction_ts`
- Add more marts: top models by region, ASP by series, MoM trends
