In [0]:
# Mount gold zone
# Mounts the gold-zone blob container under /mnt so that later cells can reach
# it through DBFS paths. `storage_account_name` is not assigned in this cell
# and must already be defined when it runs.
gold_container_name = "gold-zone-bdm"
gold_mount_point = "/mnt/gold-zone-bdm"

# SECURITY: the storage account key used to be pasted here as a string literal.
# A key that has been saved in a notebook should be treated as leaked and
# rotated. It is now fetched from a Databricks secret scope at run time.
# NOTE(review): the scope and key names below are placeholders -- point them at
# the secret that actually holds this storage account's key.
secret_scope = "bdm-storage"
secret_key = "gold-zone-account-key"
account_key = dbutils.secrets.get(scope=secret_scope, key=secret_key)

# Mount only when the mount point is not registered yet, so re-running the
# cell does not attempt a second mount of the same path.
if not any(mount.mountPoint == gold_mount_point for mount in dbutils.fs.mounts()):
    dbutils.fs.mount(
        source=f"wasbs://{gold_container_name}@{storage_account_name}.blob.core.windows.net",
        mount_point=gold_mount_point,
        extra_configs={
            f"fs.azure.account.key.{storage_account_name}.blob.core.windows.net": account_key
        }
    )
    print(f"Mounted {gold_container_name} to {gold_mount_point}")
else:
    print(f"Mount point {gold_mount_point} already exists")

## Mounting Storage
This section above mounts the Azure Blob Storage container (`gold-zone-bdm`) to `/mnt/gold-zone-bdm` in Databricks.
- Uses `dbutils.fs.mount()` and authenticates with the storage account key (an account key grants full access to the storage account, so handle it as a secret).
- Checks if the mount point exists to avoid remounting, ensuring efficiency.
- Note: `storage_account_name` must be set to your actual Azure storage account name before this cell runs. Never keep a real account key in the notebook itself: read it from a Databricks secret scope (`dbutils.secrets.get`) and rotate any key that has been saved or shared in plain text.

In [0]:
# Disabled one-off snippet, kept for reference. When uncommented it reads up to
# 1,000 rows from the enriched_transactions Delta table and overwrites a parquet
# copy at /mnt/gold-zone-bdm/subset_enriched_transactions.
#sparkDF = spark.read.format("delta").load("/mnt/gold-zone-bdm/enriched_transactions").limit(1000)
#sparkDF.write.format("parquet").mode("overwrite").save("/mnt/gold-zone-bdm/subset_enriched_transactions")

## Step 2: Load and Build Graph (M4)
Loads enriched transaction data from a Delta table, draws a 0.35% random sample (roughly 100,000 of ~28M records, capped at 100,000) for the PoC, and constructs a property graph.
- Sampling uses a fixed seed (42) for reproducibility.
- Derives `age_group` for customer segmentation.
- Creates vertices (Customer, Article) and edges (PURCHASE) with optimized partitioning (20 partitions) and persistence for performance.

In [0]:
from pyspark.sql.functions import col, when, lit
from graphframes import GraphFrame
import time

# Start time for logging; later cells reuse start_time to report elapsed time.
start_time = time.time()

# Step 2: Load and Sample Data, Build Graph (M4)
print("Starting Step 2: Loading and sampling data...")
sparkDF = spark.read.format("delta").load("/mnt/gold-zone-bdm/enriched_transactions")
# Draw a 0.35% random sample with a fixed seed, then cap it at 100,000 rows.
# The sampled row count is approximate, so it can come out below the cap.
sparkDF = sparkDF.sample(fraction=0.0035, seed=42).limit(100000)
print(f"Sampled {sparkDF.count()} records")

# Derive age_group
# "Senior" is an explicit branch rather than the fallback: with
# .otherwise("Senior") every row whose age is NULL was labelled "Senior",
# because a comparison against NULL never matches the earlier branches.
# Rows with a missing age now get a NULL age_group instead.
sparkDF = sparkDF.withColumn(
    "age_group",
    when(col("age") < 30, "Youth")
    .when((col("age") >= 30) & (col("age") <= 50), "Adult")
    .when(col("age") > 50, "Senior")
)

Starting Step 2: Loading and sampling data...
Sampled 99620 records


In [0]:
# Create vertices with aligned schemas
print("Creating vertices...")
# Customers and articles share one (id, type, property1, property2) layout so
# the two vertex sets can be stacked with a positional union.
customer_vertex_columns = [
    col("customer_id").alias("id"),
    lit("Customer").alias("type"),
    col("age_group").alias("property1"),
    lit(None).cast("string").alias("property2"),
]
article_vertex_columns = [
    col("article_id").alias("id"),
    lit("Article").alias("type"),
    col("prod_name").alias("property1"),
    col("color_category").alias("property2"),
]

customer_vertices = sparkDF.select(*customer_vertex_columns).distinct()
article_vertices = sparkDF.select(*article_vertex_columns).distinct()
vertices = customer_vertices.union(article_vertices)

# Create edges
print("Creating edges...")
# One PURCHASE edge per sampled transaction, pointing from customer to article.
edge_attribute_columns = [
    col(name)
    for name in ("t_dat", "price", "transaction_year", "transaction_month")
]
edges = sparkDF.select(
    col("customer_id").alias("src"),
    col("article_id").alias("dst"),
    *edge_attribute_columns,
    lit("PURCHASE").alias("relationship"),
)

# Optimize with smaller partitions for the sample
vertices = vertices.repartition(20).persist()
edges = edges.repartition(20).persist()
print(f"Vertices: {vertices.count()}, Edges: {edges.count()}")

# Create GraphFrame
graph = GraphFrame(vertices, edges)
print(f"Step 2 completed in {(time.time() - start_time) / 60:.2f} minutes")

Creating vertices...
Creating edges...
Vertices: 123502, Edges: 99620
Step 2 completed in 0.04 minutes




%md
## Step 3: Run Graph Analytics (M5)
Executes graph analytics on the constructed graph:
- **PageRank**: Identifies trending articles (3 iterations, resetProbability=0.15) to support inventory optimization.
- **Connected Components**: Segments customers into communities for demand forecasting, using a checkpoint directory for Spark.
- Outputs the 10 highest-ranked articles and the community assignments of 10 customers for validation.

In [0]:
# Step 3: Run Analytics (M5)
# Runs PageRank and Connected Components on `graph` and keeps a 10-row result
# of each (trending_articles, customer_segments) for display and for the
# metadata cells that follow.
print("Starting Step 3: Running graph analytics...")
step3_start = time.time()

print("Running PageRank...")
pagerank_start = time.time()
pagerank_result = graph.pageRank(resetProbability=0.15, maxIter=3)
# Keep article vertices only and restore their descriptive column names.
# article_id is a tie-breaker so that equal scores cannot reorder the top 10
# between evaluations of this lazy DataFrame.
trending_articles = pagerank_result.vertices.filter(col("type") == "Article").select(
    col("id").alias("article_id"),
    col("property1").alias("prod_name"),
    col("property2").alias("color_category"),
    col("pagerank")
).orderBy(col("pagerank").desc(), col("article_id")).limit(10)
print(f"PageRank completed in {(time.time() - pagerank_start) / 60:.2f} minutes")

print("Setting checkpoint directory...")
spark.sparkContext.setCheckpointDir("dbfs:/tmp/checkpoints")

print("Running Connected Components...")
cc_start = time.time()
cc_result = graph.connectedComponents()
# limit() without an ordering may return a different set of rows each time the
# DataFrame is evaluated. customer_segments is evaluated by show() below and
# again when the metadata is written, so order by customer_id first to make
# the displayed rows and the stored rows the same 10 customers.
customer_segments = cc_result.filter(col("type") == "Customer").select(
    col("id").alias("customer_id"),
    col("component").alias("community")
).orderBy(col("customer_id")).limit(10)
print(f"Connected Components completed in {(time.time() - cc_start) / 60:.2f} minutes")

print("PageRank Results:")
trending_articles.show()
print("Connected Components Results:")
customer_segments.show()
print(f"Step 3 completed in {(time.time() - step3_start) / 60:.2f} minutes")

Starting Step 3: Running graph analytics...
Running PageRank...




PageRank completed in 0.27 minutes
Setting checkpoint directory...
Running Connected Components...
Connected Components completed in 4.04 minutes
PageRank Results:
+----------+--------------------+----------------+------------------+
|article_id|           prod_name|  color_category|          pagerank|
+----------+--------------------+----------------+------------------+
| 706016001|Jade HW Skinny De...|      Black Dark| 72.90340895414506|
| 706016002|Jade HW Skinny De...|Light Blue Light| 66.34432907630007|
| 610776002|               Tilly|      Black Dark| 48.91214362914148|
| 372860002|  7p Basic Shaftless|     White Light| 44.46605592670959|
| 759871002|          Tilda tank|      Black Dark| 42.83729112482861|
| 610776001|               Tilly|     White Light|42.661208443544155|
| 372860001|  7p Basic Shaftless|      Black Dark|40.328112916525455|
| 464297007|Greta Thong Mynta...|      Black Dark| 39.79986487267216|
| 399223001|Curvy Jeggings HW...|      Black Dark|  34.34130175285

## Step 4: Store Metadata (M6)
Generates and stores metadata from analytics results:
- PageRank scores for the 10 top-ranked articles and community IDs for the 10 selected customers.
- Saves to a Delta table (`/mnt/gold-zone-bdm/graph_metadata`) for reuse in downstream processes (e.g., forecasting, inventory).

In [0]:
# Step 4: Store Metadata (M6)
print("Starting Step 4: Storing metadata...")


def _to_metadata_rows(results, id_column, entity_type, metric, value_column):
    """Reshape an analytics result into (entity_id, entity_type, metric, value) rows.

    The value column is cast to double so that both metrics share one schema
    and can be stacked with a positional union.
    """
    return results.select(
        col(id_column).alias("entity_id"),
        lit(entity_type).alias("entity_type"),
        lit(metric).alias("metric"),
        col(value_column).cast("double").alias("value"),
    )


pagerank_metadata = _to_metadata_rows(
    trending_articles, "article_id", "Article", "PageRank", "pagerank"
)
community_metadata = _to_metadata_rows(
    customer_segments, "customer_id", "Customer", "Community", "community"
)

# Article rows first, customer rows second, written as one Delta table that
# replaces whatever was stored at this path before.
metadata_df = pagerank_metadata.union(community_metadata)
metadata_writer = metadata_df.write.format("delta").mode("overwrite")
metadata_writer.save("/mnt/gold-zone-bdm/graph_metadata")
print("Metadata saved to Delta")

Starting Step 4: Storing metadata...
Metadata saved to Delta


## Step 5: Verify Results (M7)
Verifies the PoC by reloading metadata, joining with article details, and displaying final outputs.
- Combines article metadata with `prod_name` and `color_category` for context.
- Ensures end-to-end functionality with a sample output of trending articles and customer segments.

In [0]:
# Step 5: Verify (M7)
# Reloads the stored metadata, attaches product details to the article rows,
# and shows article and customer rows together in one table.
print("Starting Step 5: Verifying results...")
metadata_df = spark.read.format("delta").load("/mnt/gold-zone-bdm/graph_metadata")

# Join with trending_articles to include prod_name and color_category for articles
articles_with_details = metadata_df.filter(col("entity_type") == "Article")\
    .join(trending_articles, metadata_df.entity_id == trending_articles.article_id, "left")\
    .select(
        metadata_df.entity_id,
        metadata_df.entity_type,
        metadata_df.metric,
        metadata_df.value,
        trending_articles.prod_name,
        trending_articles.color_category
    )

# Customers remain as-is
customers = metadata_df.filter(col("entity_type") == "Customer")

# Customer rows have no product details. The placeholder columns are typed as
# string explicitly (as the customer vertices are when the graph is built)
# rather than left as untyped NULL literals for the union to coerce.
customers_padded = customers.select(
    "entity_id",
    "entity_type",
    "metric",
    "value",
    lit(None).cast("string").alias("prod_name"),
    lit(None).cast("string").alias("color_category"),
)

# Combine and display; columns are matched by name, not by position.
final_output = articles_with_details.unionByName(customers_padded)
final_output.show()
print(f"Total runtime: {(time.time() - start_time) / 60:.2f} minutes")


Starting Step 5: Verifying results...
+--------------------+-----------+---------+------------------+--------------------+----------------+
|           entity_id|entity_type|   metric|             value|           prod_name|  color_category|
+--------------------+-----------+---------+------------------+--------------------+----------------+
|           706016001|    Article| PageRank| 72.90340895414506|Jade HW Skinny De...|      Black Dark|
|           706016002|    Article| PageRank| 66.34432907630007|Jade HW Skinny De...|Light Blue Light|
|           610776002|    Article| PageRank| 48.91214362914148|               Tilly|      Black Dark|
|           372860002|    Article| PageRank| 44.46605592670959|  7p Basic Shaftless|     White Light|
|           759871002|    Article| PageRank| 42.83729112482861|          Tilda tank|      Black Dark|
|           610776001|    Article| PageRank|42.661208443544155|               Tilly|     White Light|
|           372860001|    Article| PageRank|

DataFrame[src: string, dst: string, t_dat: date, price: double, transaction_year: int, transaction_month: int, relationship: string]

In [0]:
# Release the vertex and edge DataFrames that were persisted when the graph was built.
vertices.unpersist()
edges.unpersist()

DataFrame[src: string, dst: string, t_dat: date, price: double, transaction_year: int, transaction_month: int, relationship: string]