# ML Enhancement: Market Basket Analysis (FP-Growth)

Discover products frequently purchased together using the FP-Growth algorithm.

## Business Value
- Cross-sell recommendations ("Customers also bought")
- Optimize store layouts based on co-purchase patterns
- Bundle pricing opportunities
- Promotion targeting with complementary products

## Data Flow
```
Silver (fact_receipts + fact_receipt_lines) --> FP-Growth --> Gold (gold_product_associations)
```

## Usage
Schedule this notebook to run **weekly** via Fabric pipeline.

## Output
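- `gold_product_associations`: Association rules with support, confidence, lift
- Top 100 rules by lift (minimum support: 0.01, minimum confidence: 0.3)
- `product_recommendations`: One row per (product, recommended product) pair expanded from the saved rules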

In [None]:
from pyspark.sql import functions as F
from pyspark.ml.fpm import FPGrowth
from pyspark.sql.window import Window
from pyspark.sql.utils import AnalysisException
from datetime import datetime, timezone
import os

In [None]:
# =============================================================================
# PARAMETERS
# =============================================================================

def get_env(var_name, default=None):
    """Return the value of environment variable ``var_name``.

    Falls back to ``default`` (``None`` unless given) when the variable is
    not set.
    """
    try:
        return os.environ[var_name]
    except KeyError:
        return default

# Source (Silver) and destination (Gold) database names; each can be
# overridden through an environment variable of the same name.
SILVER_DB = get_env("SILVER_DB", default="ag")
GOLD_DB = get_env("GOLD_DB", default="au")

# FP-Growth thresholds (acceptance criteria), also overridable via environment.
# Environment values arrive as strings, hence the explicit conversions.
MIN_SUPPORT = float(get_env("MIN_SUPPORT", default="0.01"))  # 0.01 == 1%
MIN_CONFIDENCE = float(get_env("MIN_CONFIDENCE", default="0.3"))  # 0.3 == 30%
TOP_N_RULES = int(get_env("TOP_N_RULES", default="100"))  # rules kept, ranked by lift

# Echo the effective configuration so each run's log shows what was used.
print(
    "Configuration:",
    f"  SILVER_DB={SILVER_DB}, GOLD_DB={GOLD_DB}",
    f"  MIN_SUPPORT={MIN_SUPPORT}, MIN_CONFIDENCE={MIN_CONFIDENCE}",
    f"  TOP_N_RULES={TOP_N_RULES}",
    sep="\n",
)

In [None]:
# =============================================================================
# HELPER FUNCTIONS
# =============================================================================

def ensure_database(name):
    """Create the database ``name`` unless it already exists."""
    ddl = f"CREATE DATABASE IF NOT EXISTS {name}"
    spark.sql(ddl)

def read_silver(table_name):
    """Return table ``table_name`` from the ``SILVER_DB`` database as a DataFrame."""
    qualified_name = f"{SILVER_DB}.{table_name}"
    return spark.table(qualified_name)

def save_gold(df, table_name):
    """Overwrite a Gold Delta table with ``df`` and print its row count.

    Args:
        df: DataFrame to persist.
        table_name: Unqualified table name; the table is written inside
            the ``GOLD_DB`` database.
    """
    full_name = f"{GOLD_DB}.{table_name}"
    df.write.format("delta").mode("overwrite").saveAsTable(full_name)
    # Count the rows that were actually written. Calling df.count() here would
    # re-execute df's entire lineage a second time just to print a number.
    saved_rows = spark.table(full_name).count()
    print(f"  {full_name}: {saved_rows} rows")

def silver_exists(table_name):
    """Report whether ``table_name`` can be resolved in the ``SILVER_DB`` database.

    Returns:
        ``True`` when ``spark.table`` resolves the table, ``False`` when it
        raises ``AnalysisException``.
    """
    qualified_name = f"{SILVER_DB}.{table_name}"
    try:
        spark.table(qualified_name)
    except AnalysisException:
        return False
    else:
        return True

# Create the Gold database up front so that save_gold() has a database to write into.
ensure_database(GOLD_DB)

In [None]:
# Section banner for the transaction-preparation stage.
print("=" * 60, "MARKET BASKET ANALYSIS - TRANSACTION PREPARATION", "=" * 60, sep="\n")

In [None]:
# Step 1: Prepare transaction baskets
# Join fact_receipts with fact_receipt_lines to get all items per receipt

# Fail fast with a clear message when either source table is missing.
if not silver_exists("fact_receipts") or not silver_exists("fact_receipt_lines"):
    raise RuntimeError("Required tables fact_receipts and fact_receipt_lines not found in Silver")

print("\nPreparing transaction baskets...")

# Read receipt headers (only SALE type, exclude returns)
receipts = (
    read_silver("fact_receipts")
    .filter(F.col("receipt_type") == "SALE")
    .select("receipt_id_ext", "store_id", "event_ts")
)

# Read receipt lines
receipt_lines = (
    read_silver("fact_receipt_lines")
    .select("receipt_id_ext", "product_id")
)

# Join and create baskets (list of distinct product IDs per receipt).
# collect_set de-duplicates products within a receipt, which FP-Growth requires.
baskets = (
    receipts
    .join(receipt_lines, "receipt_id_ext")
    .groupBy("receipt_id_ext")
    .agg(
        F.collect_set("product_id").alias("items")
    )
    # Filter out single-item transactions (no associations possible).
    # NOTE: every support value is therefore relative to multi-item baskets only.
    .filter(F.size(F.col("items")) > 1)
    # The baskets are counted here and then consumed again by model training;
    # caching avoids running the join and aggregation once per use.
    .cache()
)

total_baskets = baskets.count()
print(f"  Total transaction baskets (multi-item): {total_baskets:,}")

if total_baskets == 0:
    raise RuntimeError("No multi-item transactions found. Cannot proceed with market basket analysis.")

In [None]:
print("\n" + "=" * 60, "FP-GROWTH MODEL TRAINING", "=" * 60, sep="\n")

# Configure the FP-Growth estimator: it mines the "items" column of each basket
# using the support / confidence thresholds from the parameters cell.
fpGrowth = FPGrowth(itemsCol="items", minSupport=MIN_SUPPORT, minConfidence=MIN_CONFIDENCE)

print(f"\nTraining FP-Growth model with {total_baskets:,} baskets...")
# Fit on the prepared baskets; the fitted model exposes the frequent itemsets
# and association rules used by the following cells.
model = fpGrowth.fit(baskets)
print("  Model training complete")

In [None]:
print("\n" + "="*60)
print("ASSOCIATION RULE EXTRACTION")
print("="*60)

# Extract frequent itemsets
frequent_itemsets = model.freqItemsets
print(f"\nFrequent itemsets found: {frequent_itemsets.count():,}")

# Extract association rules. On Spark 3.x the model already reports for each rule:
#   support    - fraction of baskets containing antecedent and consequent together
#   confidence - support(antecedent + consequent) / support(antecedent)
#   lift       - confidence / support(consequent); higher means stronger association
association_rules = model.associationRules
total_rules = association_rules.count()
print(f"Association rules found: {total_rules:,}")

if total_rules == 0:
    print("\nWARNING: No association rules found with current thresholds.")
    print("Consider lowering MIN_SUPPORT or MIN_CONFIDENCE parameters.")

# Use the model's own metrics instead of re-deriving them. The earlier version
# wrote `confidence` into the `support` column, so the saved support was wrong.
rules_with_lift = association_rules.select(
    "antecedent",
    "consequent",
    "support",
    "confidence",
    "lift",
)

# Get top N rules by lift; confidence and support break ties so the selection
# is stable between runs. When no rules exist this yields an empty DataFrame
# whose schema (including the product id type) matches the non-empty case,
# so no hand-written schema is needed.
top_rules = (
    rules_with_lift
    .orderBy(F.desc("lift"), F.desc("confidence"), F.desc("support"))
    .limit(TOP_N_RULES)
    .withColumn(
        "computed_at",
        F.lit(datetime.now(timezone.utc))
    )
)

# The result holds at most TOP_N_RULES rows but is counted, saved and exploded
# several times in later cells; caching stops each use from re-running rule
# generation.
result_df = top_rules.cache()

if total_rules > 0:
    print(f"\nTop {TOP_N_RULES} association rules by lift prepared")

In [None]:
print("\n" + "=" * 60, "SAVING TO GOLD LAYER", "=" * 60, sep="\n")

# Persist the selected rules as the gold_product_associations table.
save_gold(result_df, "gold_product_associations")

# Show a small preview of the strongest rules, if there are any.
print("\nSample association rules (top 10 by lift):")
if result_df.count() == 0:
    print("  No rules to display")
else:
    result_df.orderBy(F.col("lift").desc()).limit(10).show(truncate=False)

In [None]:
print("\n" + "=" * 60, "CREATING PRODUCT RECOMMENDATION VIEW", "=" * 60, sep="\n")

# Flatten the rules into (product, recommended product) rows so the question
# "what is bought together with product X?" becomes a simple filter.

if result_df.count() == 0:
    print("  Skipping: no rules to create recommendations")
else:
    # Each rule is expanded to every antecedent item x consequent item pair;
    # every pair carries the metrics of the rule it came from.
    # NOTE(review): the same product pair can appear in several rows when it
    # is produced by more than one rule - confirm whether consumers expect that.
    recommendations = (
        result_df
        .withColumn("antecedent_product", F.explode("antecedent"))
        .withColumn("consequent_product", F.explode("consequent"))
        .select(
            F.col("antecedent_product").alias("product_id"),
            F.col("consequent_product").alias("recommended_product_id"),
            F.col("support"),
            F.col("confidence"),
            F.col("lift"),
            F.col("computed_at"),
        )
        .orderBy(F.col("product_id"), F.col("lift").desc())
    )

    save_gold(recommendations, "product_recommendations")

    print("\nSample recommendations (top 10):")
    recommendations.limit(10).show(truncate=False)

In [None]:
print("\n" + "=" * 60, "MARKET BASKET ANALYSIS COMPLETE", "=" * 60, sep="\n")

# Run summary: basket, itemset and rule counts from the cells above.
print("\nSummary:")
print(f"  Transaction baskets analyzed: {total_baskets:,}")
print(f"  Frequent itemsets: {frequent_itemsets.count():,}")
print(f"  Association rules (all): {total_rules:,}")
print(f"  Top rules saved: {result_df.count()}")

# List the tables now present in the Gold database and report how many there are.
show_tables_sql = f"SHOW TABLES IN {GOLD_DB}"
gold_tables = spark.sql(show_tables_sql).collect()
print(f"\nGold ({GOLD_DB}): {len(gold_tables)} tables")