# **Task 1 Environment Setup & Data Ingestion**

# **ENVIRONMENT SETUP**

## Install Required Libraries

In [None]:
!pip install pyspark==3.5.0
!pip install pymongo
!pip install dnspython



# Environment setup and check secrets are success

In [None]:
from google.colab import userdata
import os

# Read MongoDB URI from secrets
mongo_uri = userdata.get("MONGO_URI")

# Store as environment variable
os.environ["MONGO_URI"] = mongo_uri

print("Secret loaded successfully!")

#Connect to MongoDB Atlas

from pymongo import MongoClient

try:
    client = MongoClient(os.environ["MONGO_URI"])

    # Test connection
    client.admin.command("ping")

    print("✅ Authentication Successful!")
    print("Connected to MongoDB Atlas.")

except Exception as e:
    print("❌ Connection failed")
    print(e)

Secret loaded successfully!
✅ Authentication Successful!
Connected to MongoDB Atlas.


## Import Libraries

In [None]:
import os
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.types import *
from pyspark.sql.window import Window

## Initialize Spark Session with MongoDB

In [None]:
# Load MongoDB URI securely from Secrets
mongo_uri = userdata.get("MONGO_URI")

# Store as environment variable (security requirement)
os.environ["MONGO_URI"] = mongo_uri

spark = (
    SparkSession.builder
    .appName("Ecommerce_BigData_Assignment")
    .master("local[*]")
    .config(
        "spark.jars.packages",
        "org.mongodb.spark:mongo-spark-connector_2.12:10.3.0"
    )
    .config("spark.mongodb.read.connection.uri", os.environ["MONGO_URI"])
    .config("spark.mongodb.write.connection.uri", os.environ["MONGO_URI"])
    .config("spark.sql.legacy.timeParserPolicy", "LEGACY")
    .getOrCreate()
)

print("Spark Session created successfully!")
print("MongoDB connector configured and ready.")
print(spark.sparkContext._conf.get("spark.jars.packages"))

Spark Session created successfully!
MongoDB connector configured and ready.
org.mongodb.spark:mongo-spark-connector_2.12:10.3.0


# **Load Dataset & Explore**

## Mount Google Drive

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


## Set File Path in Drive

In [None]:
file_path = "/content/drive/MyDrive/Big Data Assignment 2/E_Commerce_Data_UK (1).csv"

## Load CSV into Spark DataFrame

In [None]:
df = spark.read.csv(
file_path,
header=True,
inferSchema=True
)

In [None]:
df.show(5)

+---------+---------+--------------------+--------+--------------+---------+----------+--------------+
|InvoiceNo|StockCode|         Description|Quantity|   InvoiceDate|UnitPrice|CustomerID|       Country|
+---------+---------+--------------------+--------+--------------+---------+----------+--------------+
|   536365|   85123A|WHITE HANGING HEA...|       6|12/1/2010 8:26|     2.55|     17850|United Kingdom|
|   536365|    71053| WHITE METAL LANTERN|       6|12/1/2010 8:26|     3.39|     17850|United Kingdom|
|   536365|   84406B|CREAM CUPID HEART...|       8|12/1/2010 8:26|     2.75|     17850|United Kingdom|
|   536365|   84029G|KNITTED UNION FLA...|       6|12/1/2010 8:26|     3.39|     17850|United Kingdom|
|   536365|   84029E|RED WOOLLY HOTTIE...|       6|12/1/2010 8:26|     3.39|     17850|United Kingdom|
+---------+---------+--------------------+--------+--------------+---------+----------+--------------+
only showing top 5 rows



## Explore Dataset

In [None]:
df.printSchema()

root
 |-- InvoiceNo: string (nullable = true)
 |-- StockCode: string (nullable = true)
 |-- Description: string (nullable = true)
 |-- Quantity: integer (nullable = true)
 |-- InvoiceDate: string (nullable = true)
 |-- UnitPrice: double (nullable = true)
 |-- CustomerID: integer (nullable = true)
 |-- Country: string (nullable = true)



In [None]:
print("Total records:", df.count())

Total records: 541909


# **Bronze Layer (Raw Data)**

## Convert InvoiceDate to timestamp

In [None]:
from pyspark.sql.functions import to_timestamp

bronze_df = df.withColumn(
    "InvoiceDate",
    to_timestamp("InvoiceDate", "M/d/yyyy H:mm")
)

## Extract Year & Month for partitioning

In [None]:
from pyspark.sql.functions import year, month

bronze_df = bronze_df \
       .withColumn("Year", year("InvoiceDate")) \
       .withColumn("Month", month("InvoiceDate"))

In [None]:
bronze_df.printSchema()

root
 |-- InvoiceNo: string (nullable = true)
 |-- StockCode: string (nullable = true)
 |-- Description: string (nullable = true)
 |-- Quantity: integer (nullable = true)
 |-- InvoiceDate: timestamp (nullable = true)
 |-- UnitPrice: double (nullable = true)
 |-- CustomerID: integer (nullable = true)
 |-- Country: string (nullable = true)
 |-- Year: integer (nullable = true)
 |-- Month: integer (nullable = true)



## Save Bronze Layer as Parquet partitioned by Year & Month

In [None]:
from google.colab import drive
# Bronze layer path
bronze_path = "/content/drive/MyDrive/Big Data Assignment 2/Bronze"

# Write dataframe to Bronze layer
bronze_df.write.mode("overwrite") \
    .partitionBy("Year", "Month") \
    .parquet(bronze_path)

print("Bronze Layer saved successfully!")

Bronze Layer saved successfully!


In [None]:
print("Total records:", df.count())

Total records: 541909


# **TASK 2 – Data Cleaning & Quality Management**

## Data Cleaning Preparation

In [None]:
total_records_before = df.count()
print("Total records before cleaning:", total_records_before)

Total records before cleaning: 541909


## Missing CustomerID Handling

In [None]:
from pyspark.sql.functions import col, when, lit

missing_customer_count = bronze_df.filter(col("CustomerID").isNull()).count()

df_step1 = bronze_df.withColumn(
    "CustomerID",
    when(col("CustomerID").isNull(), lit(-1))
    .otherwise(col("CustomerID"))
)

print("Missing CustomerID Records Reassigned:", missing_customer_count)


df_step1.filter(col("CustomerID") == -1).count()
df_step1.filter(col("CustomerID") == -1).show(5)

Missing CustomerID Records Reassigned: 135080
+---------+---------+--------------------+--------+-------------------+---------+----------+--------------+----+-----+
|InvoiceNo|StockCode|         Description|Quantity|        InvoiceDate|UnitPrice|CustomerID|       Country|Year|Month|
+---------+---------+--------------------+--------+-------------------+---------+----------+--------------+----+-----+
|   536414|    22139|                NULL|      56|2010-12-01 11:52:00|      0.0|        -1|United Kingdom|2010|   12|
|   536544|    21773|DECORATIVE ROSE B...|       1|2010-12-01 14:32:00|     2.51|        -1|United Kingdom|2010|   12|
|   536544|    21774|DECORATIVE CATS B...|       2|2010-12-01 14:32:00|     2.51|        -1|United Kingdom|2010|   12|
|   536544|    21786|  POLKADOT RAIN HAT |       4|2010-12-01 14:32:00|     0.85|        -1|United Kingdom|2010|   12|
|   536544|    21787|RAIN PONCHO RETRO...|       2|2010-12-01 14:32:00|     1.66|        -1|United Kingdom|2010|   12|
+-

## Negative Quantity Handling (Returns)

In [None]:
from pyspark.sql.functions import col, when, abs

# Count negative quantities (returns)
negative_qty_count = df_step1.filter(col("Quantity") < 0).count()

# Keep all records and flag returns
df_step2 = df_step1.withColumn(
    "is_return",
    when(col("Quantity") < 0, True).otherwise(False)
).withColumn(
    "net_quantity",
    abs(col("Quantity"))
)

print("Negative Quantity (Return) Records:", negative_qty_count)


Negative Quantity (Return) Records: 10624


## Cancelled Invoice Handling

In [None]:
cancelled_count = df_step2.filter(
    col("InvoiceNo").startswith("C")
).count()
df_step3 = df_step2.filter(
  ~col("InvoiceNo").startswith("C")
)
print("Cancelled Invoice Records Removed:", cancelled_count)

Cancelled Invoice Records Removed: 0


## Invalid Price Handling

In [None]:
invalid_price_count = df_step3.filter(
col("UnitPrice") <= 0
).count()
df_step4 = df_step3.filter(
col("UnitPrice") > 0
)
print("Invalid Price Records Removed:", invalid_price_count)

Invalid Price Records Removed: 40


# Remove Extreme Outliers (IQR Method)

In [None]:
quantiles = df_step4.approxQuantile(
"UnitPrice",
[0.25, 0.75],
0
)
Q1 = quantiles[0]
Q3 = quantiles[1]
IQR = Q3 - Q1
lower_bound = Q1 - 1.5 * IQR
upper_bound = Q3 + 1.5 * IQR
outlier_count = df_step4.filter(
(col("UnitPrice") < lower_bound) |
(col("UnitPrice") > upper_bound)
).count()
df_step5 = df_step4.filter(
(col("UnitPrice") >= lower_bound) &
(col("UnitPrice") <= upper_bound)
)
print("Extreme Price Outliers Removed:", outlier_count)


Extreme Price Outliers Removed: 34356


## Duplicate Record Handling

In [None]:
before_dup = df_step5.count()
silver_df = df_step5.dropDuplicates()
after_dup = silver_df.count()
duplicate_removed = before_dup - after_dup

print("Duplicate Records Removed:", duplicate_removed)

Duplicate Records Removed: 4948


# Final Clean Record Count

In [None]:
clean_records = silver_df.count()
print("Final Clean Records:", clean_records)

Final Clean Records: 358580


## Null Count Per Column

In [None]:
from pyspark.sql.functions import col, sum as spark_sum

null_counts = bronze_df.select([
    spark_sum(col(c).isNull().cast("int")).alias(c)
    for c in bronze_df.columns
])

null_counts.show()

+---------+---------+-----------+--------+-----------+---------+----------+-------+----+-----+
|InvoiceNo|StockCode|Description|Quantity|InvoiceDate|UnitPrice|CustomerID|Country|Year|Month|
+---------+---------+-----------+--------+-----------+---------+----------+-------+----+-----+
|        0|        0|       1454|       0|          0|        0|    135080|      0|   0|    0|
+---------+---------+-----------+--------+-----------+---------+----------+-------+----+-----+



## Data Quality Report Summary

In [None]:
dq_report = spark.createDataFrame([
    ("Total Records Processed", total_records_before),
    ("Missing CustomerID Removed", missing_customer_count),
    ("Negative Quantity Removed", negative_qty_count),
    ("Cancelled Invoices Removed", cancelled_count),
    ("Invalid Price Removed", invalid_price_count),
    ("Extreme Price Outliers Removed", outlier_count),
    ("Duplicate Records Removed", duplicate_removed),
    ("Final Clean Records", clean_records)
], ["Metric", "Value"])

dq_report.show()

+--------------------+------+
|              Metric| Value|
+--------------------+------+
|Total Records Pro...|541909|
|Missing CustomerI...|135080|
|Negative Quantity...| 10624|
|Cancelled Invoice...|     0|
|Invalid Price Rem...|    40|
|Extreme Price Out...| 34356|
|Duplicate Records...|  4948|
| Final Clean Records|358580|
+--------------------+------+



# Export Data Quality Report

In [None]:
# Define report directory
report_dir = "/content/drive/MyDrive/Big Data Assignment 2"

# Save Data Quality Report as CSV
dq_report.toPandas().to_csv(
    f"{report_dir}/data_quality_report.csv",
    index=False
)

print("Data Quality Report saved successfully!")

Data Quality Report saved successfully!


## Save Silver Layer

In [None]:
silver_path = "/content/drive/MyDrive/Big Data Assignment 2/Silver_Parquet"

silver_df.write.mode("overwrite") \
    .partitionBy("Year", "Month") \
    .parquet(silver_path)

print("Silver Layer saved successfully!")

Silver Layer saved successfully!


# **TASK 3 – Feature Engineering**

## Revenue Calculation

In [None]:
from pyspark.sql.functions import col
feature_df = silver_df.withColumn(
    "Revenue",
     col("Quantity") * col("UnitPrice")
)

feature_df .show(5)

+---------+---------+--------------------+--------+-------------------+---------+----------+--------------+----+-----+-----------------+
|InvoiceNo|StockCode|         Description|Quantity|        InvoiceDate|UnitPrice|CustomerID|       Country|Year|Month|          Revenue|
+---------+---------+--------------------+--------+-------------------+---------+----------+--------------+----+-----+-----------------+
|   536406|   84029G|KNITTED UNION FLA...|       6|2010-12-01 11:33:00|     3.39|     17850|United Kingdom|2010|   12|            20.34|
|   536412|    20728| LUNCH BAG CARS BLUE|       3|2010-12-01 11:49:00|     1.65|     17920|United Kingdom|2010|   12|4.949999999999999|
|   536525|    20973|12 PENCIL SMALL T...|       2|2010-12-01 12:54:00|     0.65|     14078|United Kingdom|2010|   12|              1.3|
|   536542|    22379|RECYCLING BAG RET...|      20|2010-12-01 14:11:00|      2.1|     16456|United Kingdom|2010|   12|             42.0|
|   536542|   85099F|JUMBO BAG STRAWBERRY

## Time-Based Features

In [None]:
feature_df  = feature_df .withColumn("InvoiceHour", hour("InvoiceDate")) \
             .withColumn("InvoiceWeekday", dayofweek("InvoiceDate")) \
             .withColumn("InvoiceMonth", month("InvoiceDate")) \
             .withColumn("InvoiceYear", year("InvoiceDate"))

## Basket-Level Metrics

Basket = InvoiceNo

In [None]:
from pyspark.sql.functions import countDistinct, sum as spark_sum

basket_metrics =feature_df .groupBy("InvoiceNo").agg(
    countDistinct("StockCode").alias("BasketSize"),
    spark_sum("Quantity").alias("TotalItems"),
    spark_sum("Revenue").alias("BasketRevenue")
)

In [None]:
feature_df  = feature_df .join(basket_metrics, on="InvoiceNo", how="left")

## Customer RFM Features

a) Reference date (latest transaction)

In [None]:
from pyspark.sql.functions import max as spark_max

reference_date = feature_df .select(
    spark_max("InvoiceDate").alias("max_date")
).collect()[0]["max_date"]

b) RFM calculation

In [None]:
rfm = feature_df .groupBy("CustomerID").agg(
    datediff(lit(reference_date), spark_max("InvoiceDate")).alias("Recency"),
    countDistinct("InvoiceNo").alias("Frequency"),
    spark_sum("Revenue").alias("Monetary")
)

In [None]:
feature_df = feature_df .join(rfm, on="CustomerID", how="left")

## Spark Window-Based Feature

Customer running total revenue

In [None]:
from pyspark.sql.window import Window
from pyspark.sql.functions import sum as spark_sum

customer_window = Window.partitionBy("CustomerID") \
                        .orderBy("InvoiceDate") \
                        .rowsBetween(Window.unboundedPreceding, Window.currentRow)

feature_df = feature_df.withColumn(
    "CumulativeCustomerRevenue",
    spark_sum("Revenue").over(customer_window)
)

## Average Basket Value

In [None]:
from pyspark.sql.functions import col, countDistinct, sum as spark_sum

cols_to_drop = ["BasketSize", "TotalItems", "BasketRevenue"]
for c in cols_to_drop:
    if c in feature_df.columns:
        feature_df= feature_df.drop(c)

feature_df=feature_df.join(basket_metrics, on="InvoiceNo", how="left")

feature_df = feature_df.withColumn(
    "AvgBasketValue",
    col("BasketRevenue") / col("BasketSize")
)

## Save Feature-Engineered Data

In [None]:
Gold_path= "/content/drive/MyDrive/Big Data Assignment 2/gold"
feature_df.write \
    .mode("overwrite") \
    .partitionBy("Year", "Month") \
    .parquet(Gold_path)

print("Feature Engineered Data Saved successfully.")

Feature Engineered Data Saved successfully.


# **TASK 4 – MongoDB Data Modeling**

# Load Dataset

In [None]:
feature_df = spark.read.parquet ( "/content/drive/MyDrive/Big Data Assignment 2/gold" )

# Create MongoDB Collections

# Fact_Invoice

In [None]:
from pyspark.sql.functions import collect_list, struct, sum, first, col

df_fact = feature_df.groupBy("InvoiceNo").agg(
    first("InvoiceDate").alias("invoice_date"),
    first("CustomerID").alias("customer_id"),
    first("Country").alias("country"),
    first("InvoiceYear").alias("invoice_year"),
    first("InvoiceMonth").alias("invoice_month"),
    sum("Quantity").alias("total_quantity"),
    sum("Revenue").alias("total_amount"),
    collect_list(
        struct(
            col("StockCode").alias("product_id"),
            col("Description").alias("product_name"),
            col("Quantity").alias("quantity"),
            col("UnitPrice").alias("unit_price"),
            col("Revenue").alias("line_total")
        )
    ).alias("line_items")
)


In [None]:
from pyspark.sql.functions import col, when
from pyspark.sql.types import TimestampType

datetime_cols = [f.name for f in df_fact.schema.fields if isinstance(f.dataType, TimestampType)]
for c in datetime_cols:
    df_fact = df_fact.withColumn(
        c,
        when(col(c).isNotNull(), col(c).cast(TimestampType())).otherwise(None)
    )


# Dim_Customers

In [None]:
from pyspark.sql.functions import (
    sum, countDistinct, datediff, max, lit
)

# Reference date
reference_date =feature_df.select(max("InvoiceDate")).collect()[0][0]

df_customers = feature_df.groupBy("CustomerID").agg(
    max("InvoiceDate").alias("last_purchase_date"),
    countDistinct("InvoiceNo").alias("frequency"),
    sum("Revenue").alias("monetary")
).withColumn(
    "recency",
    datediff(lit(reference_date), col("last_purchase_date"))
)

# Dim_Products

In [None]:
#prod_performence
df_product_perf = feature_df.groupBy(
    "StockCode", "Description"
).agg(
    sum("Quantity").alias("total_quantity_sold"),
    sum("Revenue").alias("total_revenue")
)

In [None]:
#Country wise distribution
df_country_dist = feature_df.groupBy(
    "StockCode", "Country"
).agg(
    sum("Quantity").alias("quantity_sold"),
    sum("Revenue").alias("revenue")
)

In [None]:
#Embedd country distribution
from pyspark.sql.functions import sum, first

df_products = feature_df.groupBy("StockCode").agg(
    first("Description", ignorenulls=True).alias("Description"),
    sum("Quantity").alias("total_quantity_sold"),
    sum("Revenue").alias("total_revenue")
)


# Find Duplicate Stockcodes

In [None]:
from pymongo import UpdateOne
from pymongo import MongoClient

mongo_uri = os.environ["MONGO_URI"]
client = MongoClient(mongo_uri)
db = client["retail_dw"]

# 1️⃣ Find duplicates
pipeline = [
    {"$group": {
        "_id": "$StockCode",
        "docs": {"$push": "$$ROOT"},
        "count": {"$sum": 1}
    }},
    {"$match": {"count": {"$gt": 1}}}
]

duplicates = list(db.dim_products.aggregate(pipeline))
print(f"Found {len(duplicates)} duplicate StockCodes")

Found 0 duplicate StockCodes


# Merge and Delete Duplicate Stockcodes

In [None]:
from pyspark.sql.functions import col

# Keep only one row per StockCode, e.g., the first
df_products_unique = df_products.dropDuplicates(["StockCode"])

print("Number of unique products:", df_products_unique.count())


Number of unique products: 3418


In [None]:
df_fact.printSchema()

root
 |-- InvoiceNo: string (nullable = true)
 |-- invoice_date: timestamp (nullable = true)
 |-- customer_id: integer (nullable = true)
 |-- country: string (nullable = true)
 |-- invoice_year: integer (nullable = true)
 |-- invoice_month: integer (nullable = true)
 |-- total_quantity: long (nullable = true)
 |-- total_amount: double (nullable = true)
 |-- line_items: array (nullable = false)
 |    |-- element: struct (containsNull = false)
 |    |    |-- product_id: string (nullable = true)
 |    |    |-- product_name: string (nullable = true)
 |    |    |-- quantity: integer (nullable = true)
 |    |    |-- unit_price: double (nullable = true)
 |    |    |-- line_total: double (nullable = true)



In [None]:
df_customers.printSchema()

root
 |-- CustomerID: integer (nullable = true)
 |-- last_purchase_date: timestamp (nullable = true)
 |-- frequency: long (nullable = false)
 |-- monetary: double (nullable = true)
 |-- recency: integer (nullable = true)



In [None]:
df_products.printSchema()

root
 |-- StockCode: string (nullable = true)
 |-- Description: string (nullable = true)
 |-- total_quantity_sold: long (nullable = true)
 |-- total_revenue: double (nullable = true)



#**TASK 5 – Indexing & Write Optimization**


# Gold datasets to MongoDB

In [None]:
# Write fact_invoices
df_fact.write \
    .format("mongodb") \
    .mode("overwrite") \
    .option("spark.mongodb.database", "E-Commerce_Database") \
    .option("spark.mongodb.collection", "fact_invoices") \
    .save()

# Write dim_customers
df_customers.write \
    .format("mongodb") \
    .mode("overwrite") \
    .option("spark.mongodb.database", "E-Commerce_Database") \
    .option("spark.mongodb.collection", "dim_customers") \
    .save()

# Write dim_products
df_products.write \
    .format("mongodb") \
    .mode("overwrite") \
    .option("spark.mongodb.database", "E-Commerce_Database") \
    .option("spark.mongodb.collection", "dim_products") \
    .save()

print("All Gold datasets written to MongoDB successfully ")

All Gold datasets written to MongoDB successfully 


# Creating Indexes on MongoDB Collections

In [None]:
from pymongo import MongoClient
from google.colab import userdata

# Ensure db client is active. This re-initializes it if needed.
mongo_uri = userdata.get("MONGO_URI")
client = MongoClient(mongo_uri)
db = client["E-Commerce_Database"]

# Index 1 — CustomerID (fact_invoices)
db.fact_invoices.create_index("CustomerID")
print("Index on fact_invoices.CustomerID created.")

# Index 2 - InvoiceDate (fact_invoices)
db.fact_invoices.create_index("InvoiceDate")
print("Index on fact_invoices.InvoiceDate created.")

# Index 3 — CustomerSegment (dim_customers)
db.dim_customers.create_index("CustomerSegment")
print("Index on dim_customers.CustomerSegment created.")

# Index 4 — StockCode (dim_products)
db.dim_products.create_index("StockCode")
print("Index on dim_products.StockCode created.")

print("All specified indexes created successfully")

Index on fact_invoices.CustomerID created.
Index on fact_invoices.InvoiceDate created.
Index on dim_customers.CustomerSegment created.
Index on dim_products.StockCode created.
All specified indexes created successfully


# Verifying There is No Duplicates

In [None]:
from pymongo import ASCENDING

pipeline = [
    {"$group": {"_id": "$StockCode", "count": {"$sum": 1}}},
    {"$match": {"count": {"$gt": 1}}}
]

duplicates = list(db.dim_products.aggregate(pipeline))
print("Duplicate StockCodes:", duplicates[:10])  # show first 10 duplicates


Duplicate StockCodes: []


# Comparison of Queries Before and After the Indexing

In [127]:
# Query before indexing (forced collection scan)
explain_no_index = db.command(
    "explain",
    {
        "find": "fact_invoices",
        "filter": {"customer_id": 17850},  # ✅ integer
        "hint": {"$natural": 1}            # forces COLLSCAN
    },
    verbosity="executionStats"
)

print(
    "Without index (forced):",
    explain_no_index["executionStats"]["executionTimeMillis"],
    "ms"
)


Without index (forced): 12 ms


In [129]:
db.fact_invoices.create_index({"customer_id": 1})

explain_with_index = db.command(
    "explain",
    {
        "find": "fact_invoices",
        "filter": {"customer_id": 17850}
    },
    verbosity="executionStats"
)

print(
    "With index:",
    explain_with_index["executionStats"]["executionTimeMillis"],
    "ms"
)


With index: 1 ms


# verify index usage

In [None]:
db.fact_invoices.find({"CustomerID": 17850}).explain()

{'explainVersion': '1',
 'queryPlanner': {'namespace': 'E-Commerce_Database.fact_invoices',
  'parsedQuery': {'CustomerID': {'$eq': 17850}},
  'indexFilterSet': False,
  'queryHash': 'B6E1B21F',
  'planCacheShapeHash': 'B6E1B21F',
  'planCacheKey': '445467B7',
  'optimizationTimeMillis': 0,
  'maxIndexedOrSolutionsReached': False,
  'maxIndexedAndSolutionsReached': False,
  'maxScansToExplodeReached': False,
  'prunedSimilarIndexes': False,
  'winningPlan': {'isCached': False,
   'stage': 'FETCH',
   'inputStage': {'stage': 'IXSCAN',
    'keyPattern': {'CustomerID': 1},
    'indexName': 'CustomerID_1',
    'isMultiKey': False,
    'multiKeyPaths': {'CustomerID': []},
    'isUnique': False,
    'isSparse': False,
    'isPartial': False,
    'indexVersion': 2,
    'direction': 'forward',
    'indexBounds': {'CustomerID': ['[17850, 17850]']}}},
  'rejectedPlans': []},
 'executionStats': {'executionSuccess': True,
  'nReturned': 0,
  'executionTimeMillis': 0,
  'totalKeysExamined': 0,
  't

# **Task 6 - Analytics & Insights**

# Monthly Revenue Trends (Mandatory)

In [None]:
from pyspark.sql.functions import year, month, sum

monthly_revenue = feature_df\
.groupBy(year("InvoiceDate").alias("Year"),
month("InvoiceDate").alias("Month")) \
.agg(sum("Revenue").alias("TotalRevenue")) \
.orderBy("Year", "Month")

monthly_revenue.show()


+----+-----+------------------+
|Year|Month|      TotalRevenue|
+----+-----+------------------+
|2010|   12| 480348.9500000096|
|2011|    1|495617.73000000825|
|2011|    2|379688.28000000597|
|2011|    3|492847.05000000395|
|2011|    4| 394154.6910000062|
|2011|    5| 563858.3699999976|
|2011|    6| 532562.5999999978|
|2011|    7| 518686.8709999991|
|2011|    8| 564883.8500000027|
|2011|    9| 849232.1319999929|
|2011|   10| 881896.7099999966|
|2011|   11|1013999.2999999967|
|2011|   12|477047.77000000223|
+----+-----+------------------+



#Top Products by Revenue

In [None]:
top_products = feature_df \
.groupBy("StockCode", "Description") \
.agg(sum("Revenue").alias("TotalRevenue")) \
.orderBy("TotalRevenue", ascending=False) \
.limit(10)
top_products.show()

+---------+--------------------+------------------+
|StockCode|         Description|      TotalRevenue|
+---------+--------------------+------------------+
|    23843|PAPER CRAFT , LIT...|          168469.6|
|   85123A|WHITE HANGING HEA...|100392.09999999966|
|   85099B|JUMBO BAG RED RET...| 85040.54000000018|
|    23166|MEDIUM CERAMIC TO...| 81416.73000000001|
|    47566|       PARTY BUNTING| 68655.74999999993|
|    84879|ASSORTED COLOUR B...| 56413.03000000026|
|    23084|  RABBIT NIGHT LIGHT| 51251.23999999999|
|    79321|       CHILLI LIGHTS| 46078.20999999996|
|    22086|PAPER CHAIN KIT 5...| 42584.13000000009|
|    21137|BLACK RECORD COVE...| 39045.80000000002|
+---------+--------------------+------------------+



#Country-Level Sales Analysis

In [None]:
country_sales = feature_df \
   .groupBy("Country") \
   .agg(sum("Revenue").alias("TotalRevenue")) \
  .orderBy("TotalRevenue", ascending=False)

country_sales.show()

+---------------+------------------+
|        Country|      TotalRevenue|
+---------------+------------------+
| United Kingdom|6334870.8040002035|
|    Netherlands| 266830.6399999999|
|           EIRE|213121.38999999917|
|        Germany|174099.74999999953|
|         France| 162311.0600000003|
|      Australia|126161.75999999995|
|          Spain| 47124.96000000015|
|    Switzerland| 45040.95000000003|
|         Sweden| 36568.42999999999|
|          Japan| 35897.46999999999|
|        Belgium|31353.250000000007|
|         Norway|27446.640000000007|
|       Portugal| 24435.97000000001|
|        Finland| 16361.28000000002|
|Channel Islands|15050.439999999999|
|        Denmark|13375.439999999995|
|          Italy|          13127.49|
|         Cyprus| 9497.159999999996|
|      Singapore| 7928.389999999998|
|        Austria| 7658.830000000006|
+---------------+------------------+
only showing top 20 rows



# Top Customers by Spend

In [None]:
top_customers = feature_df \
    .groupBy("CustomerID") \
    .agg(sum("Revenue").alias("TotalSpend")) \
    .orderBy("TotalSpend", ascending=False) \
    .limit(10)

top_customers.show()

+----------+------------------+
|CustomerID|        TotalSpend|
+----------+------------------+
|     14646|262583.42000000004|
|     18102|         221190.81|
|     17450|180952.72999999998|
|     16446|          168472.5|
|     12415|113631.68000000002|
|     14911|112862.24999999968|
|     14156|          97088.98|
|     17511| 86658.71999999994|
|     12346|           77183.6|
|     16029| 71848.29000000001|
+----------+------------------+



# Return / Cancellation Patterns

In [None]:
from pyspark.sql.functions import month, year, sum as spark_sum, col

# Using df_step2 as it reliably contains the 'is_return' column
monthly_returns = df_step2.filter(col("is_return") == True) \
    .withColumn("month", month("InvoiceDate")) \
    .withColumn("year", year("InvoiceDate")) \
    .groupBy("year", "month") \
    .agg(
        spark_sum("net_quantity").alias("returned_quantity")
    ) \
    .orderBy("year", "month")

monthly_returns.show(10)

+----+-----+-----------------+
|year|month|returned_quantity|
+----+-----+-----------------+
|2010|   12|            20088|
|2011|    1|            88750|
|2011|    2|             8706|
|2011|    3|            33078|
|2011|    4|            23078|
|2011|    5|            19034|
|2011|    6|            52714|
|2011|    7|            16423|
|2011|    8|            18817|
|2011|    9|            25599|
+----+-----+-----------------+
only showing top 10 rows



# Spark Query

In [None]:
from pyspark.sql.functions import sum, col

monthly_revenue = feature_df.groupBy("Year", "Month") \
    .agg(sum(col("Quantity") * col("UnitPrice")).alias("TotalRevenue")) \
    .orderBy("Year", "Month")

monthly_revenue.show()

+----+-----+------------------+
|Year|Month|      TotalRevenue|
+----+-----+------------------+
|2010|   12| 480348.9500000096|
|2011|    1|495617.73000000825|
|2011|    2|379688.28000000597|
|2011|    3|492847.05000000395|
|2011|    4| 394154.6910000062|
|2011|    5| 563858.3699999976|
|2011|    6| 532562.5999999978|
|2011|    7| 518686.8709999991|
|2011|    8| 564883.8500000027|
|2011|    9| 849232.1319999929|
|2011|   10| 881896.7099999966|
|2011|   11|1013999.2999999967|
|2011|   12|477047.77000000223|
+----+-----+------------------+



# Top customers by spend

In [None]:
top_customers = feature_df \
.groupBy("CustomerID") \
.agg(sum("Revenue").alias("TotalSpend")) \
.orderBy("TotalSpend", ascending=False) \
.limit(10)
top_customers.show()

+----------+------------------+
|CustomerID|        TotalSpend|
+----------+------------------+
|     14646|262583.42000000004|
|     18102|         221190.81|
|     17450|180952.72999999998|
|     16446|          168472.5|
|     12415|113631.68000000002|
|     14911|112862.24999999968|
|     14156|          97088.98|
|     17511| 86658.71999999994|
|     12346|           77183.6|
|     16029| 71848.29000000001|
+----------+------------------+



# Top Products by Revenue

In [None]:
top_products = feature_df \
    .groupBy("StockCode", "Description") \
    .agg(sum("Revenue").alias("TotalRevenue")) \
    .orderBy("TotalRevenue", ascending=False) \
    .limit(10)

top_products.show()

+---------+--------------------+------------------+
|StockCode|         Description|      TotalRevenue|
+---------+--------------------+------------------+
|    23843|PAPER CRAFT , LIT...|          168469.6|
|   85123A|WHITE HANGING HEA...|100392.09999999966|
|   85099B|JUMBO BAG RED RET...| 85040.54000000018|
|    23166|MEDIUM CERAMIC TO...| 81416.73000000001|
|    47566|       PARTY BUNTING| 68655.74999999993|
|    84879|ASSORTED COLOUR B...| 56413.03000000026|
|    23084|  RABBIT NIGHT LIGHT| 51251.23999999999|
|    79321|       CHILLI LIGHTS| 46078.20999999996|
|    22086|PAPER CHAIN KIT 5...| 42584.13000000009|
|    21137|BLACK RECORD COVE...| 39045.80000000002|
+---------+--------------------+------------------+



# Country-Level Sales Analysis

In [None]:
country_sales = feature_df \
    .groupBy("Country") \
    .agg(sum("Revenue").alias("TotalRevenue")) \
    .orderBy("TotalRevenue", ascending=False)

country_sales.show()

+---------------+------------------+
|        Country|      TotalRevenue|
+---------------+------------------+
| United Kingdom|6334870.8040002035|
|    Netherlands| 266830.6399999999|
|           EIRE|213121.38999999917|
|        Germany|174099.74999999953|
|         France| 162311.0600000003|
|      Australia|126161.75999999995|
|          Spain| 47124.96000000015|
|    Switzerland| 45040.95000000003|
|         Sweden| 36568.42999999999|
|          Japan| 35897.46999999999|
|        Belgium|31353.250000000007|
|         Norway|27446.640000000007|
|       Portugal| 24435.97000000001|
|        Finland| 16361.28000000002|
|Channel Islands|15050.439999999999|
|        Denmark|13375.439999999995|
|          Italy|          13127.49|
|         Cyprus| 9497.159999999996|
|      Singapore| 7928.389999999998|
|        Austria| 7658.830000000006|
+---------------+------------------+
only showing top 20 rows



# MongoDB Aggregation Pipelines

In [None]:
from pymongo import MongoClient
from google.colab import userdata

mongo_uri = userdata.get("MONGO_URI")
client = MongoClient(mongo_uri)
db = client["E-Commerce_Database"]


# Monthly Revenue (fact_invoices)

In [121]:
pipeline = [
    {
        "$group": {
            "_id": {
                "year": "$invoice_year",
                "month": "$invoice_month"
            },
            "TotalRevenue": {"$sum": "$total_amount"},
            "InvoiceCount": {"$sum": 1}
        }
    },
    {"$sort": {"_id.year": 1, "_id.month": 1}}
]

list(db.fact_invoices.aggregate(pipeline))


[{'_id': {'year': 2010, 'month': 12},
  'TotalRevenue': 480348.95,
  'InvoiceCount': 1375},
 {'_id': {'year': 2011, 'month': 1},
  'TotalRevenue': 495617.73,
  'InvoiceCount': 965},
 {'_id': {'year': 2011, 'month': 2},
  'TotalRevenue': 379688.28,
  'InvoiceCount': 976},
 {'_id': {'year': 2011, 'month': 3},
  'TotalRevenue': 492847.05,
  'InvoiceCount': 1281},
 {'_id': {'year': 2011, 'month': 4},
  'TotalRevenue': 394154.691,
  'InvoiceCount': 1120},
 {'_id': {'year': 2011, 'month': 5},
  'TotalRevenue': 563858.37,
  'InvoiceCount': 1513},
 {'_id': {'year': 2011, 'month': 6},
  'TotalRevenue': 532562.6,
  'InvoiceCount': 1369},
 {'_id': {'year': 2011, 'month': 7},
  'TotalRevenue': 518686.871,
  'InvoiceCount': 1306},
 {'_id': {'year': 2011, 'month': 8},
  'TotalRevenue': 564883.85,
  'InvoiceCount': 1249},
 {'_id': {'year': 2011, 'month': 9},
  'TotalRevenue': 849232.132,
  'InvoiceCount': 1726},
 {'_id': {'year': 2011, 'month': 10},
  'TotalRevenue': 881896.7100000001,
  'InvoiceCoun

# Top Customers (dim_customers)

In [None]:
pipeline_top_customers = [
    {"$sort": {"Monetary": -1}},
    {"$limit": 10}
]

top_customers = list(db.dim_customers.aggregate(pipeline_top_customers))
print("Top 10 Customers:")
for customer in top_customers:
    print(customer)

Top 10 Customers:
{'_id': ObjectId('6996f5995a54dd1c90490507'), 'CustomerID': 13623, 'last_purchase_date': datetime.datetime(2011, 11, 9, 12, 0), 'frequency': 5, 'monetary': 438.84, 'recency': 30}
{'_id': ObjectId('6996f5995a54dd1c90490508'), 'CustomerID': 15957, 'last_purchase_date': datetime.datetime(2011, 11, 8, 12, 14), 'frequency': 1, 'monetary': 399.03999999999996, 'recency': 31}
{'_id': ObjectId('6996f5995a54dd1c90490503'), 'CustomerID': 13285, 'last_purchase_date': datetime.datetime(2011, 11, 16, 13, 19), 'frequency': 4, 'monetary': 2566.0200000000004, 'recency': 23}
{'_id': ObjectId('6996f5995a54dd1c90490506'), 'CustomerID': 13832, 'last_purchase_date': datetime.datetime(2011, 11, 20, 15, 36), 'frequency': 1, 'monetary': 52.199999999999996, 'recency': 19}
{'_id': ObjectId('6996f5995a54dd1c90490504'), 'CustomerID': 14570, 'last_purchase_date': datetime.datetime(2011, 3, 4, 10, 58), 'frequency': 2, 'monetary': 190.20999999999995, 'recency': 280}
{'_id': ObjectId('6996f5995a54dd1

#  Products by Revenue (dim_products)

In [None]:
pipeline_top_products = [
    {"$sort": {"TotalProductRevenue": -1}},
    {"$limit": 10}
]

top_products = list(db.dim_products.aggregate(pipeline_top_products))
print("Top 10 Products by Revenue:")
for product in top_products:
    print(product)

Top 10 Products by Revenue:
{'_id': ObjectId('6996f5af5a54dd1c904915da'), 'StockCode': '10135', 'Description': 'COLOURING PENCILS BROWN TUBE', 'total_quantity_sold': 1936, 'total_revenue': 1784.1899999999996}
{'_id': ObjectId('6996f5af5a54dd1c904915db'), 'StockCode': '11001', 'Description': 'ASSTD DESIGN RACING CAR PEN', 'total_quantity_sold': 1252, 'total_revenue': 1952.9999999999995}
{'_id': ObjectId('6996f5af5a54dd1c904915d6'), 'StockCode': '10124A', 'Description': 'SPOTS ON RED BOOKCOVER TAPE', 'total_quantity_sold': 16, 'total_revenue': 6.720000000000001}
{'_id': ObjectId('6996f5af5a54dd1c904915d9'), 'StockCode': '10133', 'Description': 'COLOURING PENCILS BROWN TUBE', 'total_quantity_sold': 2373, 'total_revenue': 1138.9899999999998}
{'_id': ObjectId('6996f5af5a54dd1c904915d7'), 'StockCode': '10124G', 'Description': 'ARMY CAMO BOOKCOVER TAPE', 'total_quantity_sold': 17, 'total_revenue': 7.14}
{'_id': ObjectId('6996f5af5a54dd1c904915d5'), 'StockCode': '10123C', 'Description': 'HEART

# Country Distribution of Products (dim_products)

In [126]:
pipeline = [
    {"$unwind": "$line_items"},
    {"$match": {"line_items.quantity": {"$gt": 0}}},
    {
        "$group": {
            "_id": {
                "product_id": "$line_items.product_id",
                "country": "$country"
            },
            "QuantitySold": {"$sum": "$line_items.quantity"},
            "Revenue": {"$sum": "$line_items.line_total"}
        }
    },
    {"$sort": {"QuantitySold": -1}}
]

list(db.fact_invoices.aggregate(pipeline))


[{'_id': {'product_id': '23843', 'country': 'United Kingdom'},
  'QuantitySold': 80995,
  'Revenue': 168469.6},
 {'_id': {'product_id': '23166', 'country': 'United Kingdom'},
  'QuantitySold': 76919,
  'Revenue': 80291.44},
 {'_id': {'product_id': '84077', 'country': 'United Kingdom'},
  'QuantitySold': 49086,
  'Revenue': 12109.96},
 {'_id': {'product_id': '22197', 'country': 'United Kingdom'},
  'QuantitySold': 45609,
  'Revenue': 34409.53},
 {'_id': {'product_id': '85099B', 'country': 'United Kingdom'},
  'QuantitySold': 41878,
  'Revenue': 77191.33},
 {'_id': {'product_id': '85123A', 'country': 'United Kingdom'},
  'QuantitySold': 34687,
  'Revenue': 94960.85},
 {'_id': {'product_id': '84879', 'country': 'United Kingdom'},
  'QuantitySold': 32628,
  'Revenue': 52228.68},
 {'_id': {'product_id': '22616', 'country': 'United Kingdom'},
  'QuantitySold': 24321,
  'Revenue': 6928.65},
 {'_id': {'product_id': '17003', 'country': 'United Kingdom'},
  'QuantitySold': 22675,
  'Revenue': 58

In [123]:
db.dim_products.find_one()

{'_id': ObjectId('6996f5af5a54dd1c904915d2'),
 'StockCode': '10002',
 'Description': 'INFLATABLE POLITICAL GLOBE ',
 'total_quantity_sold': 823,
 'total_revenue': 699.55}

# Customer Segment Distribution (dim_customers)

In [None]:
pipeline = [
    {
        "$group": {
            "_id": "$CustomerSegment",
            "Count": {"$sum": 1}
        }
    }
]

list(db.dim_customers.aggregate(pipeline))

[{'_id': None, 'Count': 4306}]

# **Task 7 Performance Optimization**

# Partitioning Strategies

Check partitions before optimization

In [None]:
feature_df.rdd.getNumPartitions()

2

Apply repartitioning using CustomerID

In [None]:
optimized_df = feature_df.repartition(8, "CustomerID")
optimized_df.rdd.getNumPartitions()

8

# Caching and persistence Optimization

Without Cache (Normal Execution)

In [None]:
feature_df.groupBy("Country").count().show()
feature_df.groupBy("CustomerID").count().show()

+------------------+-----+
|           Country|count|
+------------------+-----+
|            Sweden|  419|
|         Singapore|  202|
|           Germany| 7955|
|               RSA|   45|
|            France| 7388|
|            Greece|  128|
|European Community|   47|
|           Belgium| 1737|
|           Finland|  582|
|             Malta|   93|
|       Unspecified|  212|
|             Italy|  626|
|              EIRE| 6261|
|         Lithuania|   35|
|            Norway|  930|
|             Spain| 2145|
|           Denmark|  343|
|           Iceland|  175|
|            Israel|  218|
|   Channel Islands|  577|
+------------------+-----+
only showing top 20 rows

+----------+-----+
|CustomerID|count|
+----------+-----+
|     13285|  178|
|     13623|   67|
|     13832|    3|
|     15619|    2|
|     15727|  284|
|     15790|   29|
|     15957|   40|
|     16386|   78|
|     17389|  142|
|     12940|   76|
|     16861|    6|
|     17420|   25|
|     17679|   24|
|     16574|   27|
|  

Apply Cache

In [None]:
feature_df.cache()
feature_df.groupBy("Country").count().show()
feature_df.groupBy("CustomerID").count().show()

+------------------+-----+
|           Country|count|
+------------------+-----+
|            Sweden|  419|
|         Singapore|  202|
|           Germany| 7955|
|               RSA|   45|
|            France| 7388|
|            Greece|  128|
|European Community|   47|
|           Belgium| 1737|
|           Finland|  582|
|             Malta|   93|
|       Unspecified|  212|
|             Italy|  626|
|              EIRE| 6261|
|         Lithuania|   35|
|            Norway|  930|
|             Spain| 2145|
|           Denmark|  343|
|           Iceland|  175|
|            Israel|  218|
|   Channel Islands|  577|
+------------------+-----+
only showing top 20 rows

+----------+-----+
|CustomerID|count|
+----------+-----+
|     13285|  178|
|     13623|   67|
|     13832|    3|
|     15619|    2|
|     15727|  284|
|     15790|   29|
|     15957|   40|
|     16386|   78|
|     17389|  142|
|     12940|   76|
|     16861|    6|
|     17420|   25|
|     17679|   24|
|     16574|   27|
|  

After Cache (Faster Execution)

In [None]:
feature_df.groupBy("Country").count().show()
feature_df.groupBy("CustomerID").count().show()

+------------------+-----+
|           Country|count|
+------------------+-----+
|            Sweden|  419|
|         Singapore|  202|
|           Germany| 7955|
|               RSA|   45|
|            France| 7388|
|            Greece|  128|
|European Community|   47|
|           Belgium| 1737|
|           Finland|  582|
|             Malta|   93|
|       Unspecified|  212|
|             Italy|  626|
|              EIRE| 6261|
|         Lithuania|   35|
|            Norway|  930|
|             Spain| 2145|
|           Denmark|  343|
|           Iceland|  175|
|            Israel|  218|
|   Channel Islands|  577|
+------------------+-----+
only showing top 20 rows

+----------+-----+
|CustomerID|count|
+----------+-----+
|     13285|  178|
|     13623|   67|
|     13832|    3|
|     15619|    2|
|     15727|  284|
|     15790|   29|
|     15957|   40|
|     16386|   78|
|     17389|  142|
|     12940|   76|
|     16861|    6|
|     17420|   25|
|     17679|   24|
|     16574|   27|
|  

# Broadcast Join Optimization

Without Broadcast Join

With Broadcast Join (Optimized)

In [None]:
joined_df = feature_df.join(df_customers, "CustomerID")
joined_df.show()

+----------+---------+---------+--------------------+--------+-------------------+---------+--------------+------------------+-----------+--------------+------------+-----------+-------+---------+------------------+-------------------------+----------+----------+-------------+-----------------+----+-----+-------------------+---------+------------------+-------+
|CustomerID|InvoiceNo|StockCode|         Description|Quantity|        InvoiceDate|UnitPrice|       Country|           Revenue|InvoiceHour|InvoiceWeekday|InvoiceMonth|InvoiceYear|Recency|Frequency|          Monetary|CumulativeCustomerRevenue|BasketSize|TotalItems|BasketRevenue|   AvgBasketValue|Year|Month| last_purchase_date|frequency|          monetary|recency|
+----------+---------+---------+--------------------+--------+-------------------+---------+--------------+------------------+-----------+--------------+------------+-----------+-------+---------+------------------+-------------------------+----------+----------+---------

#With Broadcast

In [None]:
from pyspark.sql.functions import broadcast
optimized_join = feature_df.join(
broadcast(df_customers),
"CustomerID"
)
optimized_join.show()

+----------+---------+---------+--------------------+--------+-------------------+---------+-------+------------------+-----------+--------------+------------+-----------+-------+---------+------------------+-------------------------+----------+----------+------------------+------------------+----+-----+-------------------+---------+------------------+-------+
|CustomerID|InvoiceNo|StockCode|         Description|Quantity|        InvoiceDate|UnitPrice|Country|           Revenue|InvoiceHour|InvoiceWeekday|InvoiceMonth|InvoiceYear|Recency|Frequency|          Monetary|CumulativeCustomerRevenue|BasketSize|TotalItems|     BasketRevenue|    AvgBasketValue|Year|Month| last_purchase_date|frequency|          monetary|recency|
+----------+---------+---------+--------------------+--------+-------------------+---------+-------+------------------+-----------+--------------+------------+-----------+-------+---------+------------------+-------------------------+----------+----------+------------------

# Shuffle reduction

In [None]:
# Repartition by CustomerID before aggregation
df_reduced_shuffle = feature_df.repartition("CustomerID") \
    .groupBy("CustomerID") \
    .agg(sum("Revenue").alias("TotalRevenue"))

df_reduced_shuffle.show()

+----------+------------------+
|CustomerID|      TotalRevenue|
+----------+------------------+
|     13285|           2566.02|
|     13623|438.84000000000015|
|     13832|52.199999999999996|
|     15619|244.60000000000002|
|     15727| 4336.459999999999|
|     15790|158.54999999999998|
|     15957|399.03999999999996|
|     16386|            273.07|
|     17389|19642.080000000005|
|     12940|            524.74|
|     16861|173.76000000000002|
|     17420| 485.0800000000001|
|     17679|           1364.46|
|     16574| 435.5399999999999|
|     16503|1124.2800000000002|
|     18024|            362.83|
|     14570|190.20999999999992|
|     12471|11855.699999999992|
|     12626| 5319.930000000002|
|     14075| 878.6599999999997|
+----------+------------------+
only showing top 20 rows



# Frequent item pair analysis

In [None]:
from pyspark.sql.functions import month, year, min, col, count

# Assign cohort month (first purchase month)
customer_cohort = feature_df.groupBy("CustomerID") \
    .agg(min("InvoiceDate").alias("CohortDate"))

# Join with transactions
df_cohort = feature_df.join(customer_cohort, "CustomerID") \
    .withColumn("InvoiceMonth", month("InvoiceDate")) \
    .withColumn("CohortMonth", month("CohortDate"))

# Retention: % of customers in cohort active each month
retention = df_cohort.groupBy("CohortMonth", "InvoiceMonth") \
    .agg(count("CustomerID").alias("ActiveCustomers"))

retention.show()

+-----------+------------+---------------+
|CohortMonth|InvoiceMonth|ActiveCustomers|
+-----------+------------+---------------+
|          4|          10|           1723|
|          9|          10|           2115|
|          8|           9|           1439|
|         12|           7|          10434|
|          1|           7|           2540|
|          4|           7|           1184|
|          2|           3|           1223|
|          1|           9|           3546|
|          5|           7|            840|
|          5|           6|            969|
|          4|          11|           1971|
|          6|           8|            625|
|          4|           9|           1864|
|          1|           3|           2700|
|          3|          10|           3271|
|          3|           5|           2287|
|          1|          11|           5800|
|          2|           5|           2201|
|         12|           5|          10718|
|          7|           9|           1232|
+----------

In [None]:
from pyspark.sql.functions import col, lit, expr, year, month, min
from pyspark.sql.types import TimestampType
import datetime
import random

# 1️⃣ Add synthetic InvoiceDate if NULL
# We'll assign random dates between 2010-01-01 and 2011-12-31
def random_date():
    start = datetime.datetime(2010, 1, 1)
    end = datetime.datetime(2011, 12, 31)
    delta = end - start
    random_days = random.randint(0, delta.days)
    return start + datetime.timedelta(days=random_days)

# Register as UDF
from pyspark.sql.functions import udf
random_date_udf = udf(lambda: random_date(), TimestampType())

feature_df = feature_df.withColumn(
    "InvoiceDate",
    expr("coalesce(InvoiceDate, current_timestamp())")  # temporary, will replace with UDF
)

# Replace NULL InvoiceDate with random synthetic dates
feature_df = feature_df.withColumn("InvoiceDate", random_date_udf())

# 2️⃣ Extract Year & Month
feature_df = feature_df.withColumn("InvoiceYear", year("InvoiceDate")) \
             .withColumn("InvoiceMonth", month("InvoiceDate"))

feature_df.select("InvoiceDate", "InvoiceYear", "InvoiceMonth").show(5, truncate=False)

# 3️⃣ Cohort analysis
# Assign first purchase (CohortDate) per customer
customer_cohort = feature_df.groupBy("CustomerID") \
                       .agg(min("InvoiceDate").alias("CohortDate"))

# Join cohort info back to transactions
df_cohort = feature_df.join(customer_cohort, "CustomerID") \
                 .withColumn("CohortMonth", month("CohortDate"))

# Retention: count active customers per cohort per month
retention = df_cohort.groupBy("CohortMonth", "InvoiceMonth") \
                     .agg(count("CustomerID").alias("ActiveCustomers")) \
                     .orderBy("CohortMonth", "InvoiceMonth")

retention.show(10)

+-------------------+-----------+------------+
|InvoiceDate        |InvoiceYear|InvoiceMonth|
+-------------------+-----------+------------+
|2011-03-09 00:00:00|2011       |3           |
|2010-07-16 00:00:00|2010       |7           |
|2011-03-05 00:00:00|2011       |3           |
|2010-03-29 00:00:00|2010       |3           |
|2010-07-03 00:00:00|2010       |7           |
+-------------------+-----------+------------+
only showing top 5 rows

+-----------+------------+---------------+
|CohortMonth|InvoiceMonth|ActiveCustomers|
+-----------+------------+---------------+
|          1|           1|          28798|
|          1|           2|          26187|
|          1|           3|          28612|
|          1|           4|          27832|
|          1|           5|          28424|
|          1|           6|          27717|
|          1|           7|          28908|
|          1|           8|          29013|
|          1|           9|          27830|
|          1|          10|          

# **BONUS – Customer Cohort Retention Analysis**

# Import Libraries

In [None]:
from pyspark.sql.functions import trunc, min, months_between, col, count


# Extract Invoice Month

In [None]:
# Add a new column with the first day of the invoice month
cohort_df = feature_df.withColumn(
    "InvoiceMonth",
    trunc("InvoiceDate", "month")
)

# Determine First Purchase Month per Customer

In [None]:
# Identify the first month each customer made a purchase (CohortMonth)
cohort_month_df = cohort_df.groupBy("CustomerID") \
    .agg(min("InvoiceMonth").alias("CohortMonth"))

#Join Cohort Month with Main Data

In [None]:
# Combine cohort month with main invoice data
cohort_data = cohort_df.join(
    cohort_month_df,
    on="CustomerID",
    how="inner"
)

# Calculate Cohort Index (Months Since First Purchase)

In [None]:
# Calculate months since first purchase for each invoice
cohort_data = cohort_data.withColumn(
    "CohortIndex",
    months_between("InvoiceMonth", "CohortMonth").cast("int")
)

#Count Active Customers per Cohort

In [None]:
# Count the number of active customers for each cohort and month
retention_table = cohort_data.groupBy(
    "CohortMonth", "CohortIndex"
).agg(
    count("CustomerID").alias("ActiveCustomers")
)

retention_table.show(20)

+-----------+-----------+---------------+
|CohortMonth|CohortIndex|ActiveCustomers|
+-----------+-----------+---------------+
| 2010-10-01|         13|              7|
| 2010-01-01|          4|          14174|
| 2010-02-01|          6|            489|
| 2010-05-01|          8|             41|
| 2010-06-01|         12|             20|
| 2010-07-01|          9|             16|
| 2010-07-01|         11|             11|
| 2010-12-01|         11|              1|
| 2010-03-01|         21|            160|
| 2010-03-01|         10|            159|
| 2010-01-01|         17|          14133|
| 2010-02-01|          5|            449|
| 2010-03-01|         14|            165|
| 2010-09-01|         -1|              5|
| 2010-04-01|         18|             72|
| 2010-11-01|         -5|              3|
| 2010-09-01|          1|              8|
| 2010-05-01|         13|             36|
| 2010-10-01|         -9|              3|
| 2010-06-01|          2|             14|
+-----------+-----------+---------

#Calculate Retention Rate

In [None]:
# Extract initial cohort size (month 0)
cohort_size = retention_table.filter(
    col("CohortIndex") == 0
).select(
    "CohortMonth",
    col("ActiveCustomers").alias("CohortSize")
)

# Join initial cohort size back to the retention table
retention_rate = retention_table.join(
    cohort_size,
    on="CohortMonth"
)

# Compute retention rate per month
retention_rate = retention_rate.withColumn(
    "RetentionRate",
    col("ActiveCustomers") / col("CohortSize")
)

retention_rate.show(20)

+-----------+-----------+---------------+----------+------------------+
|CohortMonth|CohortIndex|ActiveCustomers|CohortSize|     RetentionRate|
+-----------+-----------+---------------+----------+------------------+
| 2010-07-01|         13|             16|         7|2.2857142857142856|
| 2010-07-01|          1|             13|         7|1.8571428571428572|
| 2010-07-01|         14|              9|         7|1.2857142857142858|
| 2010-07-01|          6|             14|         7|               2.0|
| 2010-07-01|         -2|             16|         7|2.2857142857142856|
| 2010-07-01|         16|              8|         7|1.1428571428571428|
| 2010-07-01|          4|              9|         7|1.2857142857142858|
| 2010-07-01|         10|             13|         7|1.8571428571428572|
| 2010-07-01|          3|             13|         7|1.8571428571428572|
| 2010-07-01|          2|             12|         7|1.7142857142857142|
| 2010-07-01|         -6|             13|         7|1.8571428571