# Pipeline for Taking in and Training the Model

## Basic Setup

In [0]:
%sql
-- Create the project catalog on first run, then make it the current catalog
-- so later statements in this session resolve against it.
CREATE CATALOG IF NOT EXISTS project_mayhem;
USE CATALOG project_mayhem;

In [0]:
%sql
-- Schemas: one for the raw source files plus one per medallion layer
-- (bronze, silver, gold). All statements are idempotent (IF NOT EXISTS).
CREATE SCHEMA IF NOT EXISTS project_mayhem.source_data;
CREATE SCHEMA IF NOT EXISTS project_mayhem.bronze;
CREATE SCHEMA IF NOT EXISTS project_mayhem.silver;
CREATE SCHEMA IF NOT EXISTS project_mayhem.gold;
-- Volumes: file storage referenced later as /Volumes/project_mayhem/<schema>/<volume>.
CREATE VOLUME IF NOT EXISTS project_mayhem.source_data.raw;
CREATE VOLUME IF NOT EXISTS project_mayhem.source_data.models;
CREATE VOLUME IF NOT EXISTS project_mayhem.gold.artifacts;

## Transforming Raw CSV Data towards Medallion Architecture

In [0]:
# Glob matching every CSV file uploaded to the raw source-data volume.
RAW_DATA_PATH = "/Volumes/project_mayhem/source_data/raw/*.csv"

# The first line of each file is the header; column types are inferred from the data.
df = spark.read.csv(RAW_DATA_PATH, header=True, inferSchema=True)

# Replace the bronze table with the freshly ingested rows.
# The option key must be spelled exactly "mergeSchema": the previous spelling
# ("mergreSchema") is not a recognised writer option, so schema merging was
# never actually requested.
(
    df.write.format("delta")
    .mode("overwrite")
    .option("mergeSchema", "true")
    .saveAsTable("project_mayhem.bronze.labeled_data")
)

In [0]:
%pip install torch torchvision --quiet

Note: you may need to restart the kernel using %restart_python or dbutils.library.restartPython() to use updated packages.


In [0]:
# Configuration and Setup

import numpy as np
import pandas as pd
from datetime import datetime
from typing import List, Dict, Tuple, Optional
import json

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split

import mlflow
import mlflow.pytorch
from pyspark.sql import functions as F_spark
from pyspark.sql import Window

# Configuration for the training pipeline
# Unity Catalog catalog that owns every table and volume used below.
CATALOG = "project_mayhem"

# Table references (three-part naming: catalog.schema.table)
BRONZE_TABLE = f"{CATALOG}.bronze.labeled_data"
SILVER_TABLE = f"{CATALOG}.silver.processed_timeseries"
GOLD_FEATURES_TABLE = f"{CATALOG}.gold.ml_features"
GOLD_STATS_TABLE = f"{CATALOG}.gold.normalization_stats"

# Volume paths (these are file system paths within your volumes)
# The format is /Volumes/catalog_name/schema_name/volume_name/
RAW_DATA_VOLUME = f"/Volumes/{CATALOG}/source_data/raw"
MODEL_VOLUME = f"/Volumes/{CATALOG}/source_data/models"
ARTIFACTS_VOLUME = f"/Volumes/{CATALOG}/gold/artifacts"

# For MLflow experiment tracking
# The experiment lives under the workspace folder of whoever runs the notebook;
# the user name is looked up at runtime, so nothing has to be edited by hand.
CURRENT_USER = spark.sql("SELECT current_user()").collect()[0][0]
EXPERIMENT_NAME = f"/Users/{CURRENT_USER}/vehicle_health_training"

# Training hyperparameters
CONFIG = {
    "window_size": 60,      # rows (timesteps) per sliding window
    "stride": 10,           # rows between the starts of consecutive windows
    "batch_size": 64,       # samples per DataLoader batch
    "learning_rate": 1e-4,  # Adam learning rate
    "weight_decay": 1e-5,   # Adam weight decay
    "epochs": 12,           # passes over the training split
    "val_split": 0.15,      # fraction of windows held out for validation
    "device": "cuda" if torch.cuda.is_available() else "cpu",  # GPU when one is visible
    "random_seed": 42       # seed shared by torch, numpy and the train/val split
}

# Set random seeds for reproducibility
torch.manual_seed(CONFIG["random_seed"])
np.random.seed(CONFIG["random_seed"])
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(CONFIG["random_seed"])

# Echo the resolved configuration so the run's settings are visible in the cell output.
print("configuration loaded\n")
print(f"Catalog: {CATALOG}")
print(f"Bronze table: {BRONZE_TABLE}")
print(f"Silver table: {SILVER_TABLE}")
print(f"Gold features table: {GOLD_FEATURES_TABLE}")
print(f"Gold stats table: {GOLD_STATS_TABLE}")
print(f"\nFile storage volumes:")
print(f"  Raw data: {RAW_DATA_VOLUME}")
print(f"  Models: {MODEL_VOLUME}")
print(f"  Artifacts: {ARTIFACTS_VOLUME}")
print(f"\nDevice: {CONFIG['device']}")
print(f"MLflow experiment: {EXPERIMENT_NAME}")

configuration loaded

Catalog: project_mayhem
Bronze table: project_mayhem.bronze.labeled_data
Silver table: project_mayhem.silver.processed_timeseries
Gold features table: project_mayhem.gold.ml_features
Gold stats table: project_mayhem.gold.normalization_stats

File storage volumes:
  Raw data: /Volumes/project_mayhem/source_data/raw
  Models: /Volumes/project_mayhem/source_data/models
  Artifacts: /Volumes/project_mayhem/gold/artifacts

Device: cpu
MLflow experiment: /Users/devforbchain@gmail.com/vehicle_health_training


In [0]:
def _resample_vehicle_to_minutes(group, vehicle_id, sensor_cols, label_cols):
    """Resample one vehicle's readings onto a regular 1-minute grid.

    Args:
        group: pandas DataFrame holding the rows of a single vehicle, including
            a "timestamp" column.
        vehicle_id: Identifier written into every row of the result.
        sensor_cols: Names of the sensor columns to fill.
        label_cols: Names of the label columns to fill.

    Returns:
        pandas DataFrame with one row per minute between the vehicle's first
        and last reading, with "timestamp" as the first column.
    """
    # "urgency" is consumed downstream as an integer class id, so it must never
    # be linearly interpolated into fractional values; it is only carried
    # forward/backward across gaps.
    categorical_cols = {"urgency"}

    # reindex() refuses duplicate labels, so keep the last reading per timestamp.
    group = (
        group.sort_values("timestamp")
        .drop_duplicates(subset="timestamp", keep="last")
        .set_index("timestamp")
    )

    regular_index = pd.date_range(
        start=group.index.min(), end=group.index.max(), freq="1min"
    )

    # Fill gaps on the union of observed and grid timestamps so that readings
    # which fall between grid points still contribute to the interpolation,
    # then keep only the grid points.
    combined = group.reindex(group.index.union(regular_index))

    for col in sensor_cols + label_cols:
        if col not in combined.columns:
            continue
        if col in categorical_cols:
            combined[col] = combined[col].ffill().bfill()
        else:
            # Time-weighted linear interpolation; ffill/bfill cover the edges.
            combined[col] = combined[col].interpolate(method="time").ffill().bfill()

    resampled = combined.reindex(regular_index)
    resampled["vehicle_id"] = vehicle_id
    return resampled.rename_axis("timestamp").reset_index()


def create_silver_layer():
    """
    Transform bronze layer raw data into clean, resampled time series.

    Reads BRONZE_TABLE, drops rows without a vehicle id or a parseable
    timestamp, resamples every vehicle to 1-minute intervals (filling gaps by
    interpolation), and overwrites SILVER_TABLE with the result.

    Every column other than vehicle_id, timestamp and the three label columns
    is treated as a sensor column.

    Returns:
        The Spark DataFrame that was written to the silver table.

    Raises:
        ValueError: If no usable rows remain after cleaning.
    """
    print("Starting bronze to silver transformation...")

    # Read the bronze table
    bronze_df = spark.table(BRONZE_TABLE)

    print(f"Bronze table contains {bronze_df.count()} rows for {bronze_df.select('vehicle_id').distinct().count()} vehicles")

    # Cast first, then filter: a value that cannot be parsed becomes NULL after
    # the cast and must be dropped together with the originally missing ones.
    cleaned_df = (
        bronze_df
        .withColumn("timestamp", F_spark.col("timestamp").cast("timestamp"))
        .filter(
            F_spark.col("vehicle_id").isNotNull() &
            F_spark.col("timestamp").isNotNull()
        )
    )

    # Identify sensor columns (everything that is not a key or a label)
    label_cols = ["engine_health", "battery_health", "urgency"]
    sensor_cols = [c for c in cleaned_df.columns if c not in ["vehicle_id", "timestamp"] + label_cols]

    print(f"Identified sensor columns: {sensor_cols}")

    # Resampling is done in pandas, which pulls the whole table onto the driver.
    # For very large datasets this step would need a Spark-native implementation.
    pdf = cleaned_df.toPandas()
    if pdf.empty:
        raise ValueError(
            f"No usable rows in {BRONZE_TABLE}: every row is missing vehicle_id or a valid timestamp"
        )

    resampled_dfs = [
        _resample_vehicle_to_minutes(group, vehicle_id, sensor_cols, label_cols)
        for vehicle_id, group in pdf.groupby("vehicle_id")
    ]

    # Combine all vehicles back together
    resampled_pdf = pd.concat(resampled_dfs, ignore_index=True)

    print(f"Resampled data shape: {resampled_pdf.shape}")

    # Convert back to Spark DataFrame
    silver_df = spark.createDataFrame(resampled_pdf)

    # Full overwrite of the silver table on every run (not an incremental
    # merge); mergeSchema lets the table pick up newly added columns.
    silver_df.write.format("delta").mode("overwrite").option("mergeSchema", "true").saveAsTable(SILVER_TABLE)

    print(f"Silver layer created successfully: {SILVER_TABLE}")

    return silver_df

# Run the bronze -> silver transformation and preview the first 10 rows of the result
silver_df = create_silver_layer()
display(silver_df.limit(10))

Starting bronze to silver transformation...
Bronze table contains 288000 rows for 200 vehicles
Identified sensor columns: ['engine_temp', 'rpm', 'vibration', 'battery_v', 'speed']
Resampled data shape: (288000, 10)
Silver layer created successfully: project_mayhem.silver.processed_timeseries


timestamp,vehicle_id,engine_temp,rpm,vibration,battery_v,speed,engine_health,battery_health,urgency
2025-01-01T00:00:00.000Z,veh_0,73.03946575199711,832.669045716939,0.1145159942845096,12.712044659960071,41.62090312619082,0.7203102488976955,1.0,0
2025-01-01T00:01:00.000Z,veh_0,73.58894220018173,824.4464927458154,0.1229115029558574,12.711944659960071,43.82002717022266,0.6943612907519229,1.0,0
2025-01-01T00:02:00.000Z,veh_0,73.42552556005242,823.1385374348702,0.1110338022890975,12.711844659960072,31.6788205032476,0.7208403027542646,1.0,0
2025-01-01T00:03:00.000Z,veh_0,73.8549139896555,828.6522224069562,0.1121531094780503,12.711744659960074,46.14705382271266,0.7114452145496409,1.0,0
2025-01-01T00:04:00.000Z,veh_0,72.80092185463594,820.4654455913756,0.1177029211818564,12.711644659960072,43.93612759382521,0.717912126725688,1.0,0
2025-01-01T00:05:00.000Z,veh_0,74.26278407688547,821.5573346753303,0.1233563848781365,12.711544659960072,36.674432707859495,0.6822408289623025,1.0,0
2025-01-01T00:06:00.000Z,veh_0,72.53770645782336,816.5287389250996,0.1211382493731545,12.711444659960073,43.028661992933,0.7154283936233015,1.0,0
2025-01-01T00:07:00.000Z,veh_0,74.12929461632753,816.1345922442816,0.1165517321763614,12.711344659960073,36.19213105554856,0.698074958708485,1.0,0
2025-01-01T00:08:00.000Z,veh_0,72.81809572334583,802.945158512111,0.1390825136344327,12.711244659960071,37.80883340549487,0.6748667106753707,1.0,0
2025-01-01T00:09:00.000Z,veh_0,73.30906754112974,807.4799068178619,0.1273496632404339,12.711144659960071,34.496096408075985,0.6901495478336366,1.0,0


In [0]:
def compute_and_store_normalization_stats():
    """
    Compute normalization statistics (mean and std) for each sensor column.

    Reads SILVER_TABLE, computes the mean and sample standard deviation of
    every sensor column in a single aggregation pass, and overwrites
    GOLD_STATS_TABLE with one row per sensor
    (sensor_name, mean, std, computed_at).

    A missing mean falls back to 0.0; a missing, non-finite or near-zero std
    falls back to 1.0 so that later division by std is always safe.

    Returns:
        Dict mapping sensor name to {"mean": float, "std": float}.

    Raises:
        ValueError: If the silver table has no sensor columns.
    """
    print("Computing normalization statistics...")

    # Read silver layer
    silver_df = spark.table(SILVER_TABLE)

    # Identify sensor columns
    label_cols = ["engine_health", "battery_health", "urgency"]
    sensor_cols = [c for c in silver_df.columns if c not in ["vehicle_id", "timestamp"] + label_cols]

    if not sensor_cols:
        raise ValueError(f"{SILVER_TABLE} contains no sensor columns to compute statistics for")

    # Build all aggregates up front so Spark scans the table once instead of
    # once per column. Aliases are positional to stay independent of the
    # characters used in the sensor names.
    agg_exprs = []
    for idx, col in enumerate(sensor_cols):
        agg_exprs.append(F_spark.mean(col).alias(f"mean_{idx}"))
        agg_exprs.append(F_spark.stddev(col).alias(f"std_{idx}"))

    stats_row = silver_df.select(*agg_exprs).collect()[0]

    stats_dict = {}
    for idx, col in enumerate(sensor_cols):
        raw_mean = stats_row[f"mean_{idx}"]
        raw_std = stats_row[f"std_{idx}"]

        mean_val = float(raw_mean) if raw_mean is not None else 0.0
        std_val = float(raw_std) if raw_std is not None else 1.0

        # Avoid division by zero (constant column) and by NaN/inf, which a
        # plain "< 1e-9" comparison would let through.
        if not np.isfinite(std_val) or std_val < 1e-9:
            std_val = 1.0

        stats_dict[col] = {"mean": mean_val, "std": std_val}

    print("Normalization statistics:")
    for col, stats in stats_dict.items():
        print(f"  {col}: mean={stats['mean']:.4f}, std={stats['std']:.4f}")

    # One timestamp for the whole batch so every row of a run carries the same value.
    computed_at = datetime.now()
    stats_records = [
        {"sensor_name": col, "mean": stats["mean"], "std": stats["std"], "computed_at": computed_at}
        for col, stats in stats_dict.items()
    ]

    stats_df = spark.createDataFrame(stats_records)

    # Write to gold layer
    stats_df.write.format("delta").mode("overwrite").saveAsTable(GOLD_STATS_TABLE)

    print(f"Normalization statistics saved to {GOLD_STATS_TABLE}")

    return stats_dict

# Compute the per-sensor mean/std, persist them to the gold layer, and keep the dict for reuse
norm_stats = compute_and_store_normalization_stats()

Computing normalization statistics...
Normalization statistics:
  engine_temp: mean=70.0553, std=2.6706
  rpm: mean=794.9632, std=54.2376
  vibration: mean=0.0992, std=0.0241
  battery_v: mean=12.5231, std=0.0712
  speed: mean=29.9115, std=6.0570
Normalization statistics saved to project_mayhem.gold.normalization_stats


In [0]:
def create_gold_ml_features():
    """
    Create the gold-layer table of ML-ready sliding windows.

    Steps:
    1. Normalize each sensor column of SILVER_TABLE with the mean/std stored in
       GOLD_STATS_TABLE.
    2. Per vehicle, cut the time-ordered rows into windows of
       CONFIG["window_size"] rows, advancing CONFIG["stride"] rows each time.
       Vehicles with fewer rows than one window are skipped.
    3. Store one record per window: one array per sensor ("sensor_<name>") plus
       the labels of the window's last row.

    Timestamps are written as ISO 8601 strings rather than timestamp values.

    Returns:
        The Spark DataFrame written to GOLD_FEATURES_TABLE, or None when no
        window could be created (nothing is written in that case).
    """
    print("Creating gold layer ML features...")

    # Read silver layer and normalization stats
    silver_df = spark.table(SILVER_TABLE)
    stats_df = spark.table(GOLD_STATS_TABLE)

    # Convert stats to dictionary for easy lookup
    stats_dict = {row["sensor_name"]: {"mean": row["mean"], "std": row["std"]}
                  for row in stats_df.collect()}

    # Apply normalization to sensor columns
    normalized_df = silver_df
    for col, stats in stats_dict.items():
        if col in normalized_df.columns:
            normalized_df = normalized_df.withColumn(
                col,
                (F_spark.col(col) - stats["mean"]) / stats["std"]
            )

    print("Normalization applied to sensor columns")

    # Windowing is done in pandas, which pulls the whole table onto the driver.
    pdf = normalized_df.toPandas()

    window_size = CONFIG["window_size"]
    stride = CONFIG["stride"]

    label_cols = ["engine_health", "battery_health", "urgency"]
    sensor_cols = [c for c in pdf.columns if c not in ["vehicle_id", "timestamp"] + label_cols]

    print(f"Creating windows with size={window_size}, stride={stride}")
    print(f"Sensor columns: {sensor_cols}")

    # One creation timestamp for the whole batch instead of one clock read per window.
    created_at = datetime.now().isoformat()

    window_records = []
    window_id = 0

    for vehicle_id, group in pdf.groupby("vehicle_id"):
        group = group.sort_values("timestamp").reset_index(drop=True)
        n = len(group)

        if n < window_size:
            continue  # Skip vehicles with insufficient data

        # Pull each column out once per vehicle; windows are then cheap slices.
        timestamps = group["timestamp"]
        sensor_values = {col: group[col].to_numpy() for col in sensor_cols}

        # Create sliding windows
        for start_idx in range(0, n - window_size + 1, stride):
            end_idx = start_idx + window_size

            # Use last timestep's labels as window labels
            last_row = group.iloc[end_idx - 1]

            window_start = timestamps.iloc[start_idx]
            window_end = timestamps.iloc[end_idx - 1]

            record = {
                "window_id": window_id,
                "vehicle_id": vehicle_id,
                # ISO 8601 strings keep the inferred Spark schema simple
                "window_start_time": window_start.isoformat() if pd.notna(window_start) else None,
                "window_end_time": window_end.isoformat() if pd.notna(window_end) else None,
                # One list of window_size values per sensor
                **{
                    f"sensor_{col}": values[start_idx:end_idx].tolist()
                    for col, values in sensor_values.items()
                },
                "engine_health": float(last_row["engine_health"]),
                "battery_health": float(last_row["battery_health"]),
                "urgency": int(last_row["urgency"]),
                "created_at": created_at
            }

            window_records.append(record)
            window_id += 1

            # Report progress per window. Checking only once per vehicle (as
            # before) fires only when the running total happens to land
            # exactly on a multiple of 1000.
            if window_id % 1000 == 0:
                print(f"  Processed {window_id} windows...")

    print(f"Created {len(window_records)} total windows")

    # Additional safety check: make sure we have windows before trying to write
    if len(window_records) == 0:
        print("ERROR: No windows were created!")
        print("This usually means vehicles don't have enough timesteps.")
        print(f"Window size required: {window_size}")
        print("Check your silver table to ensure vehicles have sufficient data.")
        return None

    # Convert to Spark DataFrame (schema is inferred from the records)
    gold_df = spark.createDataFrame(window_records)

    # Let's verify the schema looks good before writing
    print("\nDataFrame schema:")
    gold_df.printSchema()

    # Write to gold layer
    gold_df.write.format("delta").mode("overwrite").option("mergeSchema", "true").saveAsTable(GOLD_FEATURES_TABLE)

    print(f"\n✓ Gold features saved to {GOLD_FEATURES_TABLE}")

    # Show a sample of what we created
    print("\nSample of created features:")
    display(gold_df.select("window_id", "vehicle_id", "window_start_time", "engine_health", "battery_health", "urgency").limit(5))

    return gold_df

# Build the sliding-window features and persist them to the gold layer
# (gold_features is None when no window could be created)
gold_features = create_gold_ml_features()

## Model Definition

In [0]:
import math
class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position encodings to a batch-first sequence.

    Even feature indices receive sin(pos * w_k) and odd indices cos(pos * w_k)
    with w_k = 10000^(-2k / d_model).

    Args:
        d_model: Size of the feature dimension (odd values are supported).
        max_len: Longest sequence length that can be encoded.
    """
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer cosine column than sine columns.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # Registered as a buffer so it follows the module across .to()/.cuda().
        # persistent=False keeps it out of state_dict, so checkpoints saved
        # before this change still load without missing/unexpected keys.
        self.register_buffer("pe", pe.unsqueeze(0), persistent=False)  # (1, max_len, d_model)

    def forward(self, x):
        """Return x + positional encoding.

        Args:
            x: Tensor of shape (B, T, D) with T <= max_len.

        Raises:
            ValueError: If the sequence is longer than max_len.
        """
        T = x.size(1)
        if T > self.pe.size(1):
            raise ValueError(
                f"Sequence length {T} exceeds positional encoding max_len {self.pe.size(1)}"
            )
        # .to() is a no-op once module and input share a device.
        return x + self.pe[:, :T, :].to(x.device)

class TemporalSensorEncoder(nn.Module):
    """
    Per-sensor temporal encoder.

    Each sensor channel of the input (B, T, F) is handled as its own
    univariate sequence: the batch is reshaped to (B*F, T, 1), every scalar is
    projected to d_model features, positions are added, a Transformer encoder
    runs along the time axis, and the result is average-pooled over time.
    Output shape: (B, F, d_model).
    """
    def __init__(self, T, d_model=64, nhead=4, num_layers=2, dropout=0.1):
        super().__init__()
        self.T = T
        self.d_model = d_model
        # Lifts a single scalar reading to a d_model-dimensional token.
        self.input_proj = nn.Linear(1, d_model)
        self.pos_enc = PositionalEncoding(d_model, max_len=T)
        layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=d_model * 2,
            dropout=dropout,
        )
        self.transformer = nn.TransformerEncoder(layer, num_layers=num_layers)
        # Collapses the time axis to length 1 by averaging.
        self.pool = nn.AdaptiveAvgPool1d(1)

    def forward(self, x):
        # x: (B, T, F)
        batch_size, seq_len, n_sensors = x.shape
        n_series = batch_size * n_sensors

        # One row per (sample, sensor) pair: (B*F, T, 1)
        per_sensor = x.permute(0, 2, 1).contiguous().view(n_series, seq_len, 1)

        # Project and add positions: (B*F, T, d_model)
        tokens = self.pos_enc(self.input_proj(per_sensor))

        # The encoder is sequence-first, so feed it (T, B*F, d_model).
        encoded = self.transformer(tokens.permute(1, 0, 2))

        # Put time last for the 1-D pooling: (B*F, d_model, T) -> (B*F, d_model)
        time_last = encoded.permute(1, 2, 0)
        pooled = self.pool(time_last).squeeze(-1)

        # Regroup the sensors of each sample: (B, F, d_model)
        return pooled.view(batch_size, n_sensors, self.d_model)

class SensorGraphAttention(nn.Module):
    """
    GAT-style attention layer over a fully connected sensor graph.

    Takes sensor embeddings H of shape (B, F, D) and returns H' of shape
    (B, F, D_out), where each sensor's output is an attention-weighted sum of
    the linearly projected embeddings of all sensors (itself included).
    """
    def __init__(self, in_dim, out_dim, dropout=0.1, leaky_relu_neg_slope=0.2):
        super().__init__()
        # Shared projection applied to every sensor embedding.
        self.W = nn.Linear(in_dim, out_dim, bias=False)
        # Scores a concatenated (sensor i, sensor j) pair with a single scalar.
        self.a = nn.Linear(2 * out_dim, 1, bias=False)
        self.leakyrelu = nn.LeakyReLU(leaky_relu_neg_slope)
        self.dropout = nn.Dropout(dropout)

    def forward(self, H):
        # H: (B, n_sensors, D)
        batch, n_sensors, _ = H.shape
        projected = self.W(H)  # (B, n_sensors, D_out)

        # Build every ordered pair (i, j) by broadcasting along two new axes.
        pair_shape = (batch, n_sensors, n_sensors, projected.shape[-1])
        receivers = projected.unsqueeze(2).expand(pair_shape)  # sensor i repeated over j
        senders = projected.unsqueeze(1).expand(pair_shape)    # sensor j repeated over i
        pairs = torch.cat((receivers, senders), dim=-1)  # (B, n_sensors, n_sensors, 2*D_out)

        # Unnormalised scores, then a softmax over the neighbours j of each i.
        scores = self.leakyrelu(self.a(pairs).squeeze(-1))  # (B, n_sensors, n_sensors)
        weights = self.dropout(F.softmax(scores, dim=-1))

        # Weighted sum of projected neighbour embeddings: (B, n_sensors, D_out)
        return torch.matmul(weights, projected)

class TemporalGraphHealthNet(nn.Module):
    """
    Multi-task vehicle health model.

    Pipeline: per-sensor temporal Transformer encoder -> attention across
    sensors -> mean over sensors -> shared hidden layer -> three heads
    (engine score, battery score, 3-way urgency logits).

    Args:
        T: Window length handed to the temporal encoder.
        F: Number of sensors. Accepted for the caller's convenience; none of
            the layers built here depend on it.
        d_model: Embedding size inside the temporal encoder.
        gat_out: Output size of the sensor attention layer.
        hidden: Width of the shared fully connected layer.
        nhead: Attention heads per Transformer layer.
        num_transformer_layers: Depth of the Transformer encoder.
        dropout: Dropout rate used in every sub-module.
    """
    def __init__(self, T, F, d_model=64, gat_out=64, hidden=128, 
                 nhead=4, num_transformer_layers=2, dropout=0.1):
        super().__init__()

        # Stage 1: one embedding per sensor, summarising its time series.
        self.temporal_encoder = TemporalSensorEncoder(
            T,
            d_model=d_model,
            nhead=nhead,
            num_layers=num_transformer_layers,
            dropout=dropout,
        )

        # Stage 2: let the sensor embeddings attend to each other.
        self.gat = SensorGraphAttention(in_dim=d_model, out_dim=gat_out, dropout=dropout)

        # Stage 3: shared representation feeding all task heads.
        self.shared_fc = nn.Sequential(
            nn.Linear(gat_out, hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
        )

        # Task heads: two scalar scores and one 3-class classifier.
        self.engine_head = nn.Linear(hidden, 1)
        self.battery_head = nn.Linear(hidden, 1)
        self.urgency_head = nn.Linear(hidden, 3)

    def forward(self, x):
        # Unpacking doubles as a rank check: x must be 3-D, (B, T, F).
        _batch, _steps, _sensors = x.shape

        per_sensor = self.gat(self.temporal_encoder(x))
        # Average the sensor axis, then apply the shared layer.
        shared = self.shared_fc(per_sensor.mean(dim=1))

        # Sigmoid squashes both health scores into (0, 1); urgency stays as raw logits.
        engine_score = torch.sigmoid(self.engine_head(shared)).squeeze(-1)
        battery_score = torch.sigmoid(self.battery_head(shared)).squeeze(-1)
        logits = self.urgency_head(shared)

        return {
            "engine": engine_score,
            "battery": battery_score,
            "urgency_logits": logits,
        }

## Utilizing Feature Store

In [0]:
class GoldLayerDataset(Dataset):
    """
    In-memory PyTorch Dataset backed by the gold-layer feature table.

    The whole table is read once in the constructor and every row is turned
    into a ready-to-use sample, so indexing is a plain list lookup.
    """
    def __init__(self, gold_table_name: str, sensor_cols: List[str]):
        """
        Args:
            gold_table_name: Full name of the gold features table
            sensor_cols: Sensor names; their order fixes the column order of
                each sample's feature matrix
        """
        print(f"Loading dataset from {gold_table_name}...")

        # The full table is collected to the driver as a pandas DataFrame.
        windows_pdf = spark.table(gold_table_name).toPandas()

        self.sensor_cols = sensor_cols
        self.samples = []

        for record in windows_pdf.to_dict(orient="records"):
            # Each sensor array becomes one column: (window_size, num_sensors)
            per_sensor = [np.array(record[f"sensor_{name}"]) for name in sensor_cols]
            features = np.column_stack(per_sensor).astype(np.float32)

            # Regression targets: [engine_health, battery_health]
            targets = np.array(
                [record["engine_health"], record["battery_health"]],
                dtype=np.float32,
            )

            # Classification target
            urgency_class = int(record["urgency"])

            self.samples.append((features, targets, urgency_class))

        print(f"Loaded {len(self.samples)} training samples")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]

def collate_fn(batch):
    """Stack dataset samples into batched tensors.

    Each sample is indexed as (features, continuous_targets, class_id).
    Returns (X, y_continuous, y_categorical), with y_categorical as int64.
    """
    feature_tensors = []
    target_tensors = []
    class_ids = []
    for sample in batch:
        feature_tensors.append(torch.tensor(sample[0]))
        target_tensors.append(torch.tensor(sample[1]))
        class_ids.append(sample[2])

    X = torch.stack(feature_tensors)
    y_continuous = torch.stack(target_tensors)
    y_categorical = torch.tensor(class_ids).long()
    return X, y_continuous, y_categorical

In [0]:
# Helper function to ensure directories exist

def ensure_directories_exist():
    """
    Create the model and artifact volume directories when they are missing.

    Call this before any step that writes files. A failure to create a
    directory is reported and then re-raised to the caller.
    """
    print("Ensuring directory structure exists...")

    # Paths that must be present before files can be written into them
    for directory in (MODEL_VOLUME, ARTIFACTS_VOLUME):
        try:
            # mkdirs is a no-op for a directory that already exists
            dbutils.fs.mkdirs(directory)
        except Exception as e:
            print(f"  ✗ Error creating directory {directory}: {e}")
            raise
        else:
            print(f"  ✓ Directory ready: {directory}")

    print("Directory structure verified!\n")

    # Note carried over from the original author: the checkpoint file
    # model_filename = "best_model_checkpoint.pth" is expected at
    # model_path = f"{MODEL_VOLUME}/{model_filename}"
    # (full_path = f"/dbfs{model_path}"); this function only prepares the directory.
    

# Create the model and artifact directories up front so later file writes cannot fail on a missing path
ensure_directories_exist()

Ensuring directory structure exists...
  ✓ Directory ready: /Volumes/project_mayhem/source_data/models
  ✓ Directory ready: /Volumes/project_mayhem/gold/artifacts
Directory structure verified!



## Training Loop

In [0]:
# Training Loop with MLflow Tracking

import io  # for the BytesIO buffer

def train_model():
    """
    Main training function with full MLflow tracking.
    Fixed to use Databricks-compatible file saving approach.
    """
    
    # Ensure directories exist
    ensure_directories_exist()
    
    # Start MLflow run
    mlflow.set_experiment(EXPERIMENT_NAME)
    
    with mlflow.start_run(run_name=f"training_{datetime.now().strftime('%Y%m%d_%H%M%S')}") as run:
        
        # Log all configuration parameters
        mlflow.log_params(CONFIG)
        
        print("="*60)
        print("Starting training run")
        print(f"Run ID: {run.info.run_id}")
        print("="*60)
        
        # Get sensor column names from normalization stats
        stats_df = spark.table(GOLD_STATS_TABLE)
        sensor_cols = [row["sensor_name"] for row in stats_df.collect()]
        mlflow.log_param("sensor_columns", ",".join(sensor_cols))
        mlflow.log_param("num_sensors", len(sensor_cols))
        
        # Create dataset from gold layer
        dataset = GoldLayerDataset(GOLD_FEATURES_TABLE, sensor_cols)
        
        # Split into train and validation
        n_total = len(dataset)
        n_val = int(n_total * CONFIG["val_split"])
        n_train = n_total - n_val
        
        train_dataset, val_dataset = random_split(
            dataset, 
            [n_train, n_val],
            generator=torch.Generator().manual_seed(CONFIG["random_seed"])
        )
        
        # Create data loaders
        train_loader = DataLoader(
            train_dataset, 
            batch_size=CONFIG["batch_size"],
            shuffle=True,
            collate_fn=collate_fn,
            num_workers=0
        )
        
        val_loader = DataLoader(
            val_dataset,
            batch_size=CONFIG["batch_size"],
            shuffle=False,
            collate_fn=collate_fn,
            num_workers=0
        )
        
        print(f"\nDataset split:")
        print(f"  Training samples: {n_train}")
        print(f"  Validation samples: {n_val}")
        print(f"  Batch size: {CONFIG['batch_size']}")
        print(f"  Training batches per epoch: {len(train_loader)}")
        
        # Initialize model
        device = torch.device(CONFIG["device"])
        model = TemporalGraphHealthNet(
            T=CONFIG["window_size"],
            F=len(sensor_cols),
            d_model=64,
            gat_out=64,
            hidden=128
        ).to(device)
        
        total_params = sum(p.numel() for p in model.parameters())
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        
        print(f"\nModel initialized on {device}")
        print(f"  Total parameters: {total_params:,}")
        print(f"  Trainable parameters: {trainable_params:,}")
        
        mlflow.log_param("total_parameters", total_params)
        mlflow.log_param("trainable_parameters", trainable_params)
        
        # Define loss functions and optimizer
        criterion_continuous = nn.MSELoss()
        criterion_categorical = nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(
            model.parameters(), 
            lr=CONFIG["learning_rate"], 
            weight_decay=CONFIG["weight_decay"]
        )
        
        # Training loop
        best_val_loss = float('inf')
        best_epoch = 0
        
        print(f"\nStarting training for {CONFIG['epochs']} epochs...")
        print("="*60)
        
        for epoch in range(CONFIG["epochs"]):
            # Training phase
            model.train()
            train_loss = 0.0
            train_engine_loss = 0.0
            train_battery_loss = 0.0
            train_urgency_loss = 0.0
            
            for batch_idx, (X, y_continuous, y_categorical) in enumerate(train_loader):
                X = X.to(device)
                y_continuous = y_continuous.to(device)
                y_categorical = y_categorical.to(device)
                
                # Forward pass
                predictions = model(X)
                
                # Compute individual task losses
                loss_engine = criterion_continuous(predictions["engine"], y_continuous[:, 0])
                loss_battery = criterion_continuous(predictions["battery"], y_continuous[:, 1])
                loss_urgency = criterion_categorical(predictions["urgency_logits"], y_categorical)
                
                # Combined loss with task weighting
                loss = loss_engine + loss_battery + 0.5 * loss_urgency
                
                # Backward pass
                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()
                
                # Accumulate losses
                train_loss += loss.item()
                train_engine_loss += loss_engine.item()
                train_battery_loss += loss_battery.item()
                train_urgency_loss += loss_urgency.item()
            
            # Average training losses
            train_loss /= len(train_loader)
            train_engine_loss /= len(train_loader)
            train_battery_loss /= len(train_loader)
            train_urgency_loss /= len(train_loader)
            
            # Validation phase
            model.eval()
            val_loss = 0.0
            val_engine_loss = 0.0
            val_battery_loss = 0.0
            val_urgency_loss = 0.0
            
            with torch.no_grad():
                for X, y_continuous, y_categorical in val_loader:
                    X = X.to(device)
                    y_continuous = y_continuous.to(device)
                    y_categorical = y_categorical.to(device)
                    
                    predictions = model(X)
                    
                    loss_engine = criterion_continuous(predictions["engine"], y_continuous[:, 0])
                    loss_battery = criterion_continuous(predictions["battery"], y_continuous[:, 1])
                    loss_urgency = criterion_categorical(predictions["urgency_logits"], y_categorical)
                    
                    loss = loss_engine + loss_battery + 0.5 * loss_urgency
                    
                    val_loss += loss.item()
                    val_engine_loss += loss_engine.item()
                    val_battery_loss += loss_battery.item()
                    val_urgency_loss += loss_urgency.item()
            
            # Average validation losses
            val_loss /= len(val_loader)
            val_engine_loss /= len(val_loader)
            val_battery_loss /= len(val_loader)
            val_urgency_loss /= len(val_loader)
            
            # Log metrics to MLflow
            mlflow.log_metrics({
                "train_loss": train_loss,
                "train_engine_loss": train_engine_loss,
                "train_battery_loss": train_battery_loss,
                "train_urgency_loss": train_urgency_loss,
                "val_loss": val_loss,
                "val_engine_loss": val_engine_loss,
                "val_battery_loss": val_battery_loss,
                "val_urgency_loss": val_urgency_loss
            }, step=epoch)
            
            # Print progress
            print(f"\nEpoch {epoch+1}/{CONFIG['epochs']}")
            print(f"  Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")
            print(f"  Engine: {train_engine_loss:.4f} / {val_engine_loss:.4f}")
            print(f"  Battery: {train_battery_loss:.4f} / {val_battery_loss:.4f}")
            print(f"  Urgency: {train_urgency_loss:.4f} / {val_urgency_loss:.4f}")
            
            # Save best model using Databricks-compatible approach
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                best_epoch = epoch
                
                # Create checkpoint dictionary
                checkpoint = {
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "epoch": epoch,
                    "val_loss": val_loss,
                    "train_loss": train_loss,
                    "config": CONFIG,
                    "sensor_cols": sensor_cols,
                    "timestamp": datetime.now().isoformat()
                }
                
                # CRITICAL FIX: Save to BytesIO buffer first, then write using dbutils
                # This avoids IPython's file handling restrictions
                buffer = io.BytesIO()
                torch.save(checkpoint, buffer)
                buffer.seek(0)  # Reset buffer position to beginning
                
                # Define the destination path
                model_filename = "best_model_checkpoint.pth"
                model_path = f"{MODEL_VOLUME}/{model_filename}"
                
                # Write the buffer contents using dbutils, which is Databricks-native
                # and doesn't go through IPython's modified file operations
                dbutils.fs.put(model_path, buffer.read().decode('latin1'), overwrite=True)
                
                print(f"  ✓ New best model saved! (val_loss: {val_loss:.4f})")
                print(f"    Location: {model_path}")
                
                # Also log the model to MLflow
                # MLflow handles its own file operations safely
                mlflow.pytorch.log_model(model, "model")
        
        print("\n" + "="*60)
        print("TRAINING COMPLETED")
        print("="*60)
        print(f"Best validation loss: {best_val_loss:.4f} (epoch {best_epoch+1})")
        print(f"Model saved to: {model_path}")
        print(f"MLflow run ID: {run.info.run_id}")
        print("="*60)
        
        return model, best_val_loss, model_path

# Run training
# train_model() returns a 3-tuple (model, best validation loss, checkpoint path);
# unpack all three into notebook-level names.
print("Initiating training pipeline...\n")
trained_model, final_loss, saved_model_path = train_model()

Initiating training pipeline...

Ensuring directory structure exists...
  ✓ Directory ready: /Volumes/project_mayhem/source_data/models
  ✓ Directory ready: /Volumes/project_mayhem/gold/artifacts
Directory structure verified!

Starting training run
Run ID: d1bb83884116439fa44be99c55c03674
Loading dataset from project_mayhem.gold.ml_features...
Loaded 27800 training samples

Dataset split:
  Training samples: 23630
  Validation samples: 4170
  Batch size: 64
  Training batches per epoch: 370

Model initialized on cpu
  Total parameters: 80,261
  Trainable parameters: 80,261

Starting training for 12 epochs...





Epoch 1/12
  Train Loss: 0.1207 | Val Loss: 0.0140
  Engine: 0.0137 / 0.0041
  Battery: 0.0319 / 0.0011
  Urgency: 0.1501 / 0.0176
Wrote 1456598 bytes.
  ✓ New best model saved! (val_loss: 0.0140)
    Location: /Volumes/project_mayhem/source_data/models/best_model_checkpoint.pth





Epoch 2/12
  Train Loss: 0.0137 | Val Loss: 0.0096
  Engine: 0.0052 / 0.0033
  Battery: 0.0013 / 0.0010
  Urgency: 0.0145 / 0.0106
Wrote 1456864 bytes.
  ✓ New best model saved! (val_loss: 0.0096)
    Location: /Volumes/project_mayhem/source_data/models/best_model_checkpoint.pth





Epoch 3/12
  Train Loss: 0.0105 | Val Loss: 0.0082
  Engine: 0.0042 / 0.0029
  Battery: 0.0012 / 0.0010
  Urgency: 0.0102 / 0.0086
Wrote 1455926 bytes.
  ✓ New best model saved! (val_loss: 0.0082)
    Location: /Volumes/project_mayhem/source_data/models/best_model_checkpoint.pth





Epoch 4/12
  Train Loss: 0.0093 | Val Loss: 0.0076
  Engine: 0.0036 / 0.0025
  Battery: 0.0011 / 0.0010
  Urgency: 0.0092 / 0.0082
Wrote 1456068 bytes.
  ✓ New best model saved! (val_loss: 0.0076)
    Location: /Volumes/project_mayhem/source_data/models/best_model_checkpoint.pth





Epoch 5/12
  Train Loss: 0.0087 | Val Loss: 0.0074
  Engine: 0.0033 / 0.0025
  Battery: 0.0011 / 0.0009
  Urgency: 0.0086 / 0.0079
Wrote 1456093 bytes.
  ✓ New best model saved! (val_loss: 0.0074)
    Location: /Volumes/project_mayhem/source_data/models/best_model_checkpoint.pth





Epoch 6/12
  Train Loss: 0.0084 | Val Loss: 0.0072
  Engine: 0.0031 / 0.0024
  Battery: 0.0011 / 0.0009
  Urgency: 0.0084 / 0.0076
Wrote 1456157 bytes.
  ✓ New best model saved! (val_loss: 0.0072)
    Location: /Volumes/project_mayhem/source_data/models/best_model_checkpoint.pth





Epoch 7/12
  Train Loss: 0.0083 | Val Loss: 0.0071
  Engine: 0.0030 / 0.0025
  Battery: 0.0011 / 0.0009
  Urgency: 0.0084 / 0.0075
Wrote 1455328 bytes.
  ✓ New best model saved! (val_loss: 0.0071)
    Location: /Volumes/project_mayhem/source_data/models/best_model_checkpoint.pth





Epoch 8/12
  Train Loss: 0.0081 | Val Loss: 0.0070
  Engine: 0.0029 / 0.0024
  Battery: 0.0011 / 0.0009
  Urgency: 0.0082 / 0.0075
Wrote 1455965 bytes.
  ✓ New best model saved! (val_loss: 0.0070)
    Location: /Volumes/project_mayhem/source_data/models/best_model_checkpoint.pth





Epoch 9/12
  Train Loss: 0.0080 | Val Loss: 0.0069
  Engine: 0.0028 / 0.0023
  Battery: 0.0011 / 0.0009
  Urgency: 0.0082 / 0.0073
Wrote 1455419 bytes.
  ✓ New best model saved! (val_loss: 0.0069)
    Location: /Volumes/project_mayhem/source_data/models/best_model_checkpoint.pth





Epoch 10/12
  Train Loss: 0.0079 | Val Loss: 0.0069
  Engine: 0.0027 / 0.0023
  Battery: 0.0011 / 0.0009
  Urgency: 0.0082 / 0.0072
Wrote 1455817 bytes.
  ✓ New best model saved! (val_loss: 0.0069)
    Location: /Volumes/project_mayhem/source_data/models/best_model_checkpoint.pth





Epoch 11/12
  Train Loss: 0.0078 | Val Loss: 0.0068
  Engine: 0.0027 / 0.0024
  Battery: 0.0011 / 0.0009
  Urgency: 0.0081 / 0.0071
Wrote 1455148 bytes.
  ✓ New best model saved! (val_loss: 0.0068)
    Location: /Volumes/project_mayhem/source_data/models/best_model_checkpoint.pth





Epoch 12/12
  Train Loss: 0.0078 | Val Loss: 0.0067
  Engine: 0.0027 / 0.0023
  Battery: 0.0011 / 0.0009
  Urgency: 0.0080 / 0.0069
Wrote 1455119 bytes.
  ✓ New best model saved! (val_loss: 0.0067)
    Location: /Volumes/project_mayhem/source_data/models/best_model_checkpoint.pth





TRAINING COMPLETED
Best validation loss: 0.0067 (epoch 12)
Model saved to: /Volumes/project_mayhem/source_data/models/best_model_checkpoint.pth
MLflow run ID: d1bb83884116439fa44be99c55c03674


## Final Command to Execute It All

In [0]:
def incremental_training_pipeline(new_data_path: str):
    """
    Ingest a new CSV batch, rebuild the downstream layers, and retrain the model.

    Steps:
    1. Load the new CSV into the bronze table (append mode)
    2. Reprocess the silver layer via create_silver_layer()
    3. Rebuild the gold features via create_gold_ml_features()
    4. Retrain via train_model()

    NOTE(review): this function does not itself implement experience replay or
    any other knowledge-preservation mechanism; it simply appends the new rows
    and calls train_model() again. Presumably old knowledge is retained because
    training then sees old and new rows together - confirm against train_model().

    Args:
        new_data_path: Path of the CSV file(s) to ingest. Read with a header
            row and schema inference enabled.

    Returns:
        The trained model object returned by train_model().
    """
    print("="*60)
    print("INCREMENTAL TRAINING PIPELINE")
    print("="*60)

    # Step 1: Load new data into bronze
    print("\n[1/4] Loading new data into bronze layer...")
    new_df = spark.read.csv(new_data_path, header=True, inferSchema=True)
    new_df.write.format("delta").mode("append").saveAsTable(BRONZE_TABLE)
    print(f"  ✓ Added {new_df.count()} new rows to bronze layer")

    # Step 2: Reprocess silver layer (return value is not needed here)
    print("\n[2/4] Reprocessing silver layer...")
    create_silver_layer()
    print("Silver layer updated")

    # Step 3: Update gold features (return value is not needed here)
    print("\n[3/4] Creating new gold features...")
    create_gold_ml_features()
    print("Gold features updated")

    # Step 4: Train with full dataset
    print("\n[4/4] Training model...")
    # train_model() returns (model, best_val_loss, model_path). Unpacking into
    # only two names raised "ValueError: too many values to unpack" after all
    # the expensive ETL steps had already run.
    trained_model, final_loss, saved_model_path = train_model()

    print()
    print("Incremental Training Complete")
    print(f"Final validation loss: {final_loss:.4f}")
    print(f"Model checkpoint: {saved_model_path}")

    return trained_model

# Note to self: to trigger all of this, run incremental_training_pipeline("/Volumes/project_mayhem/source_data/raw/new_batch.csv") in the next cell.