# 🚦 Reshaping Traffic: Spatio-Temporal Traffic Forecasting & Control

**Complete Project Consolidation**

This notebook contains the entire traffic forecasting and routing system:
- **Model Architecture**: Graph Attention Networks + Mamba (SSM) for spatio-temporal forecasting
- **Training Pipeline**: Complete training loop with early stopping, checkpoints, and logging
- **Evaluation**: Test-set MAE/RMSE/MAPE with Monte-Carlo dropout uncertainty estimates
- **Traffic Routing**: Congestion-aware routing controller (Model-2) driven by Model-1 forecasts
- **Simulation**: Traffic flow simulator for closed-loop evaluation

---

## Project Structure

1. **Imports & Dependencies**
2. **Utility Functions** (metrics, seed, early stopping, checkpoint, logger, uncertainty)
3. **Model Components** (GAT, Mamba, ST-Mamba blocks, multi-scale temporal, dynamic graph, uncertainty head)
4. **Main Model** (NewtonGraphMamba)
5. **Dataset** (TrafficDataset)
6. **Training Pipeline**
7. **Evaluation Pipeline**
8. **Traffic Routing Controller** (Model-2)
9. **Traffic Simulator**
10. **Example Usage & Demo**

## 1. Imports & Dependencies

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
import random
import logging
import csv
import os
from pathlib import Path
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.cuda.amp import autocast, GradScaler

# NetworkX for routing
try:
    import networkx as nx
except ImportError:
    print("Warning: networkx not installed. Traffic routing features are unavailable.")
    nx = None

# Mamba SSM with fallback
try:
    from mamba_ssm import Mamba
    MAMBA_AVAILABLE = True
except ImportError:
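    print("Warning: mamba_ssm not found. Falling back to a per-timestep MLP placeholder (no temporal mixing).")
    MAMBA_AVAILABLE = False

    class Mamba(nn.Module):
        """
        Stand-in for mamba_ssm.Mamba with the same constructor arguments
        (d_state, d_conv, expand are accepted but unused). It applies the
        same small MLP to every timestep independently, so it does NOT mix
        information across the sequence.
        """
        def __init__(self, d_model, d_state=16, d_conv=4, expand=2):
            super().__init__()
            self.in_proj = nn.Linear(d_model, d_model * 2)
            self.out_proj = nn.Linear(d_model, d_model)

        def forward(self, x):
            # x: [batch, seq_len, d_model] -> same shape
            # Project up, apply SiLU, keep the first d_model channels, project back
            return self.out_proj(F.silu(self.in_proj(x))[..., :x.shape[-1]])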

print(f"✅ PyTorch version: {torch.__version__}")
print(f"✅ Device: {'CUDA' if torch.cuda.is_available() else 'CPU'}")
print(f"✅ Mamba SSM: {'Available' if MAMBA_AVAILABLE else 'Using fallback'}")
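The fallback `Mamba` class above processes each timestep independently, so it cannot carry information forward in time. To build intuition for what the real state-space layer adds, here is a minimal sketch of a scalar linear state-space recurrence, h_t = a·h_{t-1} + b·x_t and y_t = c·h_t.

The fixed scalars `a`, `b`, `c` and the helper name `linear_ssm_scan` are illustrative assumptions. Mamba itself uses input-dependent (selective) parameters, vector-valued states, and a parallel scan, none of which is shown here.

```python
def linear_ssm_scan(xs, a=0.5, b=1.0, c=1.0):
    """Run h_t = a*h_{t-1} + b*x_t, y_t = c*h_t over a sequence (h_0 = 0).

    Illustrative scalar SSM only; not the selective scan used by Mamba.
    """
    h = 0.0
    ys = []
    for x in xs:
        h = a * h + b * x   # the state carries information forward in time
        ys.append(c * h)
    return ys

# An impulse at t=0 keeps influencing later outputs, decaying by `a` per step
impulse_response = linear_ssm_scan([1.0, 0.0, 0.0, 0.0])
print(impulse_response)  # [1.0, 0.5, 0.25, 0.125]
```

A single impulse keeps contributing to later outputs with weight a^k. A per-timestep MLP like the placeholder cannot reproduce this.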

## 2. Utility Functions

In [None]:
# =======================
# Metrics
# =======================
def masked_mae(pred, true, eps=1e-5):
    """Masked Mean Absolute Error.

    Entries where `true == 0` are treated as missing and excluded; this
    mask is only meaningful when `true` is in raw (unnormalized) units.
    """
    mask = (true != 0).float()
    loss = torch.abs(pred - true)
    return (loss * mask).sum() / (mask.sum() + eps)

def masked_rmse(pred, true, eps=1e-5):
    """Masked Root Mean Squared Error (same masking rule as masked_mae)"""
    mask = (true != 0).float()
    loss = (pred - true) ** 2
    return torch.sqrt((loss * mask).sum() / (mask.sum() + eps))

def masked_mape(pred, true, eps=1e-5):
    """Masked Mean Absolute Percentage Error, returned as a fraction
    (multiply by 100 for percent); same masking rule as masked_mae"""
    mask = (true != 0).float()
    loss = torch.abs((pred - true) / (true + eps))
    return (loss * mask).sum() / (mask.sum() + eps)

# =======================
# Seed for Reproducibility
# =======================
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

# =======================
# Early Stopping
# =======================
class EarlyStopping:
    def __init__(self, patience=10):
        self.patience = patience
        self.best = float("inf")
        self.counter = 0

    def load(self, best, counter):
        self.best = best
        self.counter = counter

    def step(self, loss):
        if loss < self.best:
            self.best = loss
            self.counter = 0
            return False
        else:
            self.counter += 1
            return self.counter >= self.patience

# =======================
# Checkpoint Management
# =======================
def save_checkpoint(state, path="checkpoint.pt"):
    torch.save(state, path)

def load_checkpoint(model, optimizer, scheduler, path="checkpoint.pt"):
    if not os.path.exists(path):
        return 0, float("inf"), 0
    ckpt = torch.load(path, map_location="cpu")
    model.load_state_dict(ckpt["model"])
    optimizer.load_state_dict(ckpt["optimizer"])
    scheduler.load_state_dict(ckpt["scheduler"])
    return ckpt["epoch"], ckpt["best_val"], ckpt["early_stop_counter"]

# =======================
# CSV Logger
# =======================
class CSVLogger:
    def __init__(self, path="training_metrics.csv"):
        self.path = Path(path)
        if not self.path.exists():
            with open(self.path, "w", newline="") as f:
                writer = csv.writer(f)
                writer.writerow(["epoch", "train_loss", "val_mae"])

    def log(self, epoch, train_loss, val_mae):
        with open(self.path, "a", newline="") as f:
            writer = csv.writer(f)
            writer.writerow([epoch, train_loss, val_mae])

# =======================
# Uncertainty Quantification (Monte-Carlo dropout)
# =======================
@torch.no_grad()
def mc_dropout_predict(model, X, adj, runs=20):
    """
    Monte-Carlo dropout inference for uncertainty estimation.
    Runs `runs` stochastic forward passes with dropout active and returns
    the element-wise mean and standard deviation of the predictions.
    """
    model.train()  # IMPORTANT: enable dropout (train() would also switch any BatchNorm to batch statistics)
    preds = []
    for _ in range(runs):
        preds.append(model(X, adj))
    preds = torch.stack(preds, dim=0)
    mean = preds.mean(dim=0)
    std = preds.std(dim=0)
    model.eval()  # restore inference mode
    return mean, std

print("✅ Utility functions loaded")
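Here is a worked example of the masking convention used by `masked_mae` and `masked_rmse`. Entries whose ground truth is exactly 0 are treated as missing readings and excluded from both the numerator and the denominator.

The list-based helpers below are hypothetical mirrors of the tensor functions, written in plain Python so the arithmetic is easy to follow. The speed values are made up.

```python
import math

def masked_mae_list(pred, true, eps=1e-5):
    """List-based mirror of masked_mae: ignore positions where true == 0."""
    mask = [1.0 if t != 0 else 0.0 for t in true]
    total = sum(abs(p - t) * m for p, t, m in zip(pred, true, mask))
    return total / (sum(mask) + eps)

def masked_rmse_list(pred, true, eps=1e-5):
    """List-based mirror of masked_rmse."""
    mask = [1.0 if t != 0 else 0.0 for t in true]
    total = sum((p - t) ** 2 * m for p, t, m in zip(pred, true, mask))
    return math.sqrt(total / (sum(mask) + eps))

pred = [10.0, 52.0, 61.0, 58.0]
true = [0.0, 50.0, 64.0, 58.0]   # first reading is missing (0)

print(round(masked_mae_list(pred, true), 4))   # 1.6667 = (2 + 3 + 0) / 3
print(round(masked_rmse_list(pred, true), 4))  # 2.0817 = sqrt((4 + 9 + 0) / 3)
```

Without the mask, the missing first reading would contribute an error of 10 and raise the MAE to 3.75.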

## 3. Model Components

In [None]:
# =======================
# GAT Layer (Graph Attention Network)
# =======================
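class GATLayer(nn.Module):
    """Multi-head graph attention layer for spatial modeling.

    Scaled dot-product attention restricted to node pairs with a non-zero
    adjacency weight.
    """
    def __init__(self, in_features, out_features, num_heads=4, dropout=0.1):
        super().__init__()
        assert out_features % num_heads == 0, "out_features must be divisible by num_heads"

        self.num_heads = num_heads
        self.head_dim = out_features // num_heads

        self.Wq = nn.Linear(in_features, out_features, bias=False)
        self.Wk = nn.Linear(in_features, out_features, bias=False)
        self.Wv = nn.Linear(in_features, out_features, bias=False)

        self.out_proj = nn.Linear(out_features, out_features)
        self.dropout = nn.Dropout(dropout)

    def forward(self, h, adj):
        """
        h:   [B, N, F] node features
        adj: [N, N] shared adjacency, or [B', N, N] per-sample adjacency where
             B is a multiple of B' (rows of h grouped per sample, as produced
             by reshape(B' * T, N, F) in STMambaBlock)
        Returns the attention update [B, N, out_features]; the residual
        connection is added by the caller (STMambaBlock).
        """
        B, N, _ = h.shape

        # Multi-head projections: [B, heads, N, head_dim]
        q = self.Wq(h).view(B, N, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.Wk(h).view(B, N, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.Wv(h).view(B, N, self.num_heads, self.head_dim).transpose(1, 2)

        # Scaled dot-product attention scores: [B, heads, N, N]
        scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.head_dim)

        # Mask node pairs with zero adjacency weight. A fused adjacency from
        # DynamicGraphLearner is strictly positive, so it masks nothing.
        if adj.dim() == 2:
            mask = (adj == 0).view(1, 1, N, N)
        else:
            assert B % adj.size(0) == 0, "batch of h must be a multiple of adj's batch"
            # Repeat each sample's graph for every row of h that belongs to it
            adj = adj.repeat_interleave(B // adj.size(0), dim=0)
            mask = (adj == 0).unsqueeze(1)
        # Use the dtype's most negative value: -1e9 overflows float16 under autocast
        scores = scores.masked_fill(mask, torch.finfo(scores.dtype).min)

        attn = self.dropout(F.softmax(scores, dim=-1))
        out = torch.matmul(attn, v)  # [B, heads, N, head_dim]

        out = out.transpose(1, 2).reshape(B, N, -1)
        return self.out_proj(out)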
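To see what masking does to one row of attention scores, here is a single-head, single-node sketch in plain Python. Non-neighbours receive a very large negative score, so after the softmax their weight underflows to 0. The remaining weights then renormalize over the neighbours.

The helper name `masked_softmax`, the scores, and the adjacency row are illustrative assumptions.

```python
import math

def masked_softmax(scores, adj_row, fill=-1e9):
    """Softmax over scores, with entries where adj_row == 0 masked out."""
    masked = [s if a != 0 else fill for s, a in zip(scores, adj_row)]
    m = max(masked)                          # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in masked]
    z = sum(exps)
    return [e / z for e in exps]

scores = [2.0, 1.0, 3.0, 0.5]   # raw q·k / sqrt(d) scores from one node
adj_row = [1, 1, 0, 1]          # node 2 is not a neighbour

weights = masked_softmax(scores, adj_row)
print([round(w, 3) for w in weights])
print(weights[2])               # 0.0
```

A row with no neighbours at all (every entry masked) falls back to uniform attention instead of producing NaN. This is one reason to use a large finite fill value rather than `-inf`.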

In [None]:
# =======================
# Multi-Scale Temporal Encoder
# =======================
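class TemporalBlock(nn.Module):
    """Temporal modeling block using Mamba"""
    def __init__(self, dim):
        super().__init__()
        self.mamba = Mamba(d_model=dim)

    def forward(self, x):
        # x: [B, T, N, D] -> run Mamba along T independently for each node
        B, T, N, D = x.shape
        # Move nodes next to batch so each (batch, node) pair is one sequence.
        # permute + reshape (not view): the strided input slices are non-contiguous.
        x = x.permute(0, 2, 1, 3).reshape(B * N, T, D)
        out = self.mamba(x)                    # [B*N, T, D]
        return out[:, -1].reshape(B, N, D)     # take the last timestep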

class MultiScaleTemporal(nn.Module):
    """Multi-scale temporal encoder with attention fusion"""
    def __init__(self, dim):
        super().__init__()
        self.short = TemporalBlock(dim)  # short-term patterns (5-15 min)
        self.mid = TemporalBlock(dim)    # mid-term patterns (30-60 min)
        self.long = TemporalBlock(dim)   # long-term patterns (1-2 hours)
        self.attn = nn.Linear(dim, 1)
    
    def forward(self, x):
        # x: [B, T>=48, N, F]
        short = x[:, -12:]      # Last 12 timesteps
        mid = x[:, -24::2]      # Last 24 timesteps, every 2nd
        long = x[:, -48::4]     # Last 48 timesteps, every 4th
        
        f_s = self.short(short)
        f_m = self.mid(mid)
        f_l = self.long(long)
        
        # Attention-weighted fusion
        scores = torch.stack([
            self.attn(f_s),
            self.attn(f_m),
            self.attn(f_l)
        ], dim=0)
        
        weights = torch.softmax(scores, dim=0)
        return weights[0] * f_s + weights[1] * f_m + weights[2] * f_l
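For a 48-step input, the three slices in `MultiScaleTemporal.forward` each pick 12 timesteps at different strides. Indexing a plain Python list the same way shows which positions each branch sees.

Two details stand out. The strided branches start at the window's left edge, so the most recent step they see is 46 and 44 rather than 47. Also, assuming 5-minute sampling, the three windows span roughly 1, 2, and 4 hours of history.

```python
T = 48
steps = list(range(T))  # timestep indices: 0 (oldest) .. 47 (most recent)

short = steps[-12:]     # 12 consecutive steps: 36..47
mid = steps[-24::2]     # every 2nd step of the last 24: 24, 26, ..., 46
long = steps[-48::4]    # every 4th step of the last 48: 0, 4, ..., 44

for name, idx in [("short", short), ("mid", mid), ("long", long)]:
    span = idx[-1] - idx[0] + 1
    print(f"{name:5s}: {len(idx)} samples, span {span} steps, last index {idx[-1]}")
```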

In [None]:
# =======================
# Dynamic Graph Learner
# =======================
class DynamicGraphLearner(nn.Module):
    """Learns dynamic graph structure adapting to traffic conditions"""
    def __init__(self, num_nodes, dim):
        super().__init__()
        self.node_emb = nn.Parameter(torch.randn(num_nodes, dim))
        self.proj = nn.Linear(dim, dim)
    
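    def forward(self, h, A_static):
        """
        h: [B, N, D] - node features
        A_static: [N, N] - static adjacency matrix
        Returns: [B, N, N] fused adjacency. The softmax term is strictly
        positive, so the fused graph is dense even where A_static is 0.
        """
        q = self.proj(h)                # [B, N, D] queries from the current state
        k = self.node_emb.unsqueeze(0)  # [1, N, D] keys from learned node embeddings

        # Dynamic adjacency: row-normalized attention over nodes, [B, N, N]
        A_dyn = torch.softmax(
            torch.matmul(q, k.transpose(-1, -2)), dim=-1
        )

        # Fixed-weight fusion of static and dynamic graphs (broadcasts over B)
        return 0.7 * A_static + 0.3 * A_dyn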
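Here is a small numeric check of the `0.7 * A_static + 0.3 * softmax(scores)` fusion, using made-up scores for two nodes that are disconnected in the static graph. The helper names are hypothetical, and plain lists stand in for tensors.

The check shows that every fused entry is positive, so a mask of the form `adj == 0` removes nothing once the dynamic term is added. It also shows that rows still sum to 1 when the static rows do.

```python
import math

def row_softmax(matrix):
    """Row-wise softmax, like torch.softmax(..., dim=-1)."""
    out = []
    for row in matrix:
        m = max(row)
        exps = [math.exp(v - m) for v in row]
        z = sum(exps)
        out.append([e / z for e in exps])
    return out

def fuse_adjacency(a_static, scores, w_static=0.7, w_dyn=0.3):
    """w_static * A_static + w_dyn * softmax(scores), as in DynamicGraphLearner."""
    a_dyn = row_softmax(scores)
    return [[w_static * s + w_dyn * d for s, d in zip(rs, rd)]
            for rs, rd in zip(a_static, a_dyn)]

a_static = [[1.0, 0.0], [0.0, 1.0]]   # self-loops only: nodes 0 and 1 not connected
scores = [[2.0, 0.0], [0.0, 2.0]]     # made-up q·k scores

fused = fuse_adjacency(a_static, scores)
print([[round(v, 4) for v in row] for row in fused])  # [[0.9642, 0.0358], [0.0358, 0.9642]]
```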

In [None]:
# =======================
# Uncertainty Head (Probabilistic Output)
# =======================
class ProbabilisticHead(nn.Module):
    """Point-forecast head; its dropout enables Monte-Carlo dropout uncertainty at inference"""
    def __init__(self, dim, horizon=12):
        super().__init__()
        self.fc = nn.Linear(dim, horizon)  # Output H predictions per node
        self.dropout = nn.Dropout(0.3)
    
    def forward(self, x):
        # x: [B, N, F] -> [B, N, H]
        return self.fc(self.dropout(x))

In [None]:
# =======================
# ST-Mamba Block (Spatio-Temporal)
# =======================
class STMambaBlock(nn.Module):
    """Spatio-temporal block combining GAT and Mamba"""
    def __init__(self, d_model, num_heads=4, dropout=0.1):
        super().__init__()
        
        self.gat = GATLayer(d_model, d_model, num_heads, dropout)
        self.norm_s = nn.LayerNorm(d_model)
        
        self.mamba = Mamba(d_model)
        self.norm_t = nn.LayerNorm(d_model)
        
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x, adj):
        """
        x:   [B, T, N, D]
        adj: [N, N] or [B, N, N]
        """
        B, T, N, D = x.shape

        # Spatial modeling (attention across nodes at every timestep).
        # reshape rather than view: x is non-contiguous after a previous block.
        x_s = x.reshape(B * T, N, D)
        x_s = self.norm_s(x_s + self.dropout(self.gat(x_s, adj)))
        x = x_s.reshape(B, T, N, D)

        # Temporal modeling (Mamba across time for every node)
        x_t = x.permute(0, 2, 1, 3).reshape(B * N, T, D)
        x_t = self.norm_t(x_t + self.dropout(self.mamba(x_t)))
        x = x_t.reshape(B, N, T, D).permute(0, 2, 1, 3).contiguous()

        return x

print("✅ Model components loaded")
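The two regroupings in `STMambaBlock.forward` view the same [B, T, N, D] tensor in two ways. Spatial attention sees one row per (batch, timestep) pair, while temporal Mamba sees one sequence per (batch, node) pair.

The plain-Python sketch below uses nested lists with the feature dimension dropped, and each element's value encodes its (b, t, n) position. The helper names are hypothetical.

```python
def spatial_rows(x):
    """[B][T][N] -> B*T rows of N nodes (like x.reshape(B*T, N, D)); row = b*T + t."""
    return [nodes for per_batch in x for nodes in per_batch]

def temporal_sequences(x):
    """[B][T][N] -> B*N sequences of T steps (like permute(0, 2, 1, 3).reshape(B*N, T, D))."""
    B, T, N = len(x), len(x[0]), len(x[0][0])
    return [[x[b][t][n] for t in range(T)] for b in range(B) for n in range(N)]

B, T, N = 2, 3, 2
x = [[[f"b{b}t{t}n{n}" for n in range(N)] for t in range(T)] for b in range(B)]

print(spatial_rows(x)[1])        # ['b0t1n0', 'b0t1n1']  (row index b*T + t)
print(temporal_sequences(x)[3])  # ['b1t0n1', 'b1t1n1', 'b1t2n1']  (row index b*N + n)
```

A plain `view(B * N, T, D)` applied directly to the [B, T, N, D] layout would not produce these per-node sequences. It would interleave timesteps and nodes, which is why the temporal path permutes first.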

## 4. Main Model: NewtonGraphMamba

In [None]:
class NewtonGraphMamba(nn.Module):
    """
    Complete spatio-temporal traffic forecasting model
    
    Architecture:
    - Input projection
    - Multi-scale temporal encoder
    - Dynamic graph learner
    - ST-Mamba blocks (spatial + temporal)
    - Probabilistic prediction head
    """
    def __init__(
        self,
        in_features=5,
        d_model=64,
        num_nodes=100,
        num_layers=4,
        num_heads=4,
        prediction_horizon=12
    ):
        super().__init__()
        
        self.num_nodes = num_nodes
        self.horizon = prediction_horizon
        
        # Input projection
        self.input_proj = nn.Linear(in_features, d_model)
        
        # Multi-scale temporal encoder
        self.multi_scale_temporal = MultiScaleTemporal(d_model)
        
        # Dynamic graph learner
        self.graph_learner = DynamicGraphLearner(num_nodes, d_model)
        
        # ST-Mamba blocks
        self.layers = nn.ModuleList([
            STMambaBlock(d_model, num_heads)
            for _ in range(num_layers)
        ])
        
        # Probabilistic prediction head
        self.head = ProbabilisticHead(d_model, prediction_horizon)
    
    def forward(self, x, adj_static):
        """
        x: [B, T>=48, N, F] - input traffic data
        adj_static: [N, N] - static adjacency matrix
        Returns: [B, N, H] - traffic predictions for H time steps
        """
        # Input projection
        x = self.input_proj(x)
        
        # Multi-scale temporal fusion
        h = self.multi_scale_temporal(x)  # [B, N, F]
        
        # Dynamic graph learning
        adj = self.graph_learner(h, adj_static)
        
        # Tile the fused [B, N, D] summary `horizon` times to form a
        # pseudo-sequence for the ST blocks: [B, H, N, D]
        h = h.unsqueeze(1).repeat(1, self.horizon, 1, 1)
        
        # ST-Mamba blocks
        for layer in self.layers:
            h = layer(h, adj)
        
        # Prediction head - take last timestep and predict all H horizons
        # h[:, -1] is [B, N, F], head outputs [B, N, H]
        out = self.head(h[:, -1])  # [B, N, H]
        
        return out  # [B, N, H]

print("✅ Main model loaded")
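The shape bookkeeping in `NewtonGraphMamba.forward` can be summarized stage by stage. The helper below (hypothetical, not part of the model) simply restates the documented shapes for given sizes. This makes it easy to check the `[B, N, H]` output expected by `test_model_forward`.

```python
def trace_shapes(B, T, N, F_in, d_model=64, horizon=12):
    """Return (stage, shape) pairs following NewtonGraphMamba.forward."""
    return [
        ("input x", (B, T, N, F_in)),
        ("input_proj", (B, T, N, d_model)),
        ("multi_scale_temporal", (B, N, d_model)),
        ("graph_learner adj", (B, N, N)),
        ("tiled for ST blocks", (B, horizon, N, d_model)),
        ("ST-Mamba blocks", (B, horizon, N, d_model)),
        ("head on last step", (B, N, horizon)),
    ]

for stage, shape in trace_shapes(B=2, T=48, N=20, F_in=5, d_model=32):
    print(f"{stage:22s} {shape}")
```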

## 5. Dataset: TrafficDataset

In [None]:
class TrafficDataset(Dataset):
    """
    Traffic dataset with sliding window approach

    Args:
        data: np.ndarray [T, N, F] - time series data
        history_len: history window size (default: 48, the minimum input
            length NewtonGraphMamba's multi-scale encoder expects)
        horizon: prediction horizon (default: 12)
        mean: normalization mean (optional; pass the training-set value
            for the validation and test splits)
        std: normalization std (optional)

    Each item is (X [history_len, N, F], Y [horizon, N]); Y is the first
    feature (traffic speed) in normalized units.
    """
    def __init__(self, data, history_len=48, horizon=12, mean=None, std=None):
        self.history_len = history_len
        self.horizon = horizon

        # Normalization statistics (one global scalar over all features)
        if mean is None:
            self.mean = float(data.mean())
            self.std = float(data.std()) + 1e-6
        else:
            self.mean = mean
            self.std = std

        # Normalize once; windows are sliced lazily in __getitem__, so the
        # heavily overlapping windows are never materialized in memory
        self.data = torch.as_tensor((data - self.mean) / self.std, dtype=torch.float32)

    def inverse_transform(self, x):
        """Map normalized values back to original units"""
        return x * self.std + self.mean

    def __len__(self):
        # Number of complete (history + horizon) windows
        return max(self.data.shape[0] - self.history_len - self.horizon + 1, 0)

    def __getitem__(self, idx):
        t = idx + self.history_len
        X = self.data[idx:t]                      # [history_len, N, F]
        Y = self.data[t:t + self.horizon, :, 0]   # [horizon, N]; target = traffic speed
        return X, Y

print("✅ Dataset class loaded")
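The window count follows from the slicing. A window starting at index t needs indices t through t + history_len + horizon - 1, so a series of length T contains T - history_len - horizon + 1 complete windows.

The plain-Python sketch below (helper names hypothetical, T = 100 as a toy length) checks that count. It also shows the normalize/de-normalize round trip with fixed statistics, which is how normalized predictions would be mapped back to original units.

```python
def window_starts(T, history_len=48, horizon=12):
    """Start indices of every complete (history + horizon) window."""
    return list(range(T - history_len - horizon + 1))

def normalize(values, mean, std):
    return [(v - mean) / std for v in values]

def denormalize(values, mean, std):
    return [v * std + mean for v in values]

starts = window_starts(T=100)
print(len(starts))              # 41 windows
print(starts[0], starts[-1])    # 0 40
# The last window's target ends exactly at the final timestep (index 99)
print(starts[-1] + 48 + 12 - 1) # 99

speeds = [40.0, 50.0, 65.0]
z = normalize(speeds, mean=50.0, std=10.0)
print(z)                                      # [-1.0, 0.0, 1.5]
print(denormalize(z, mean=50.0, std=10.0))    # [40.0, 50.0, 65.0]
```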

## 6. Training Pipeline

In [None]:
def auto_accumulation(device_mem_gb):
    """Pick (batch_size, grad_accum_steps) from available GPU memory.

    Effective batch size = batch_size * grad_accum_steps.
    """
    if device_mem_gb <= 4:
        return 4, 8  # effective batch = 32
    elif device_mem_gb <= 6:
        return 8, 4  # effective batch = 32
    else:
        return 8, 8  # effective batch = 64

def train_model(
    data_path="data/metr_la/metr_la.npz",
    adj_path="data/metr_la/adj.npy",
    max_epochs=500,
    device_mem_gb=4,
    seed=42
):
    """Complete training pipeline"""
    
    # Setup
    set_seed(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # Logging
    logging.basicConfig(
        filename="training.log",
        level=logging.INFO,
        format="%(asctime)s | %(message)s"
    )
    logger = CSVLogger()
    
    # Load data
    data = np.load(data_path)["data"]
    adj = torch.tensor(np.load(adj_path)).float().to(device)
    
    # Chronological split: first 70% train, next 10% validation
    # (the final 20% is held out for evaluate_model)
    T = len(data)
    train_data = data[:int(0.7 * T)]
    val_data = data[int(0.7 * T):int(0.8 * T)]

    # Create datasets (validation reuses the training-set normalization statistics)
    train_ds = TrafficDataset(train_data)
    val_ds = TrafficDataset(val_data, mean=train_ds.mean, std=train_ds.std)

    # Data loaders (pinned memory only helps host-to-GPU copies)
    batch_size, grad_accum_steps = auto_accumulation(device_mem_gb)
    pin = torch.cuda.is_available()
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, pin_memory=pin)
    val_loader = DataLoader(val_ds, batch_size=batch_size, pin_memory=pin)
    
    # Model
    model = NewtonGraphMamba(
        in_features=data.shape[-1],
        num_nodes=data.shape[1]
    ).to(device)
    
    # Optimizer & scheduler
    optimizer = AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
    scheduler = ReduceLROnPlateau(optimizer, patience=5)
    scaler = GradScaler(enabled=torch.cuda.is_available())
    
    # Early stopping & checkpoint
    early_stop = EarlyStopping(patience=10)
    start_epoch, best_val, counter = load_checkpoint(model, optimizer, scheduler)
    early_stop.load(best_val, counter)
    
    print(f"▶ Resuming from epoch {start_epoch}")
    print(f"‚ñ∂ Batch size: {batch_size}, Gradient accumulation: {grad_accum_steps}")
    
    # Training loop
    for epoch in range(start_epoch, max_epochs):
        model.train()
        optimizer.zero_grad(set_to_none=True)
        running_loss = 0.0

        # Training
        for step, (X, Y) in enumerate(train_loader):
            X = X.to(device, non_blocking=True)
            Y = Y.to(device, non_blocking=True).permute(0, 2, 1)  # [B, H, N] -> [B, N, H]

            with autocast(enabled=torch.cuda.is_available()):
                pred = model(X, adj)
                loss = masked_mae(pred, Y)

            # Scale so the accumulated gradient equals that of one large batch
            scaler.scale(loss / grad_accum_steps).backward()
            running_loss += loss.item()

            # Step every grad_accum_steps batches, and on the final batch so a
            # partial accumulation at the end of the epoch is not discarded
            if (step + 1) % grad_accum_steps == 0 or (step + 1) == len(train_loader):
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad(set_to_none=True)

        running_loss /= len(train_loader)  # mean training MAE per batch

        # Validation
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for X, Y in val_loader:
                X = X.to(device, non_blocking=True)
                Y = Y.to(device, non_blocking=True).permute(0, 2, 1)
                val_loss += masked_mae(model(X, adj), Y).item()

        val_loss /= len(val_loader)
        scheduler.step(val_loss)
        logger.log(epoch, running_loss, val_loss)

        # Save best model
        if val_loss < best_val:
            best_val = val_loss
            torch.save(model.state_dict(), "best_model.pt")

        # Update early stopping before checkpointing so a resumed run
        # restores the current patience counter
        should_stop = early_stop.step(val_loss)

        # Save checkpoint
        save_checkpoint({
            "epoch": epoch + 1,
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "scheduler": scheduler.state_dict(),
            "best_val": best_val,
            "early_stop_counter": early_stop.counter
        })

        # Logging
        logging.info(f"Epoch {epoch} | Train MAE {running_loss:.4f} | Val MAE {val_loss:.4f}")
        print(f"Epoch {epoch:03d} | Train MAE {running_loss:.4f} | Val MAE {val_loss:.4f}")

        # Early stopping
        if should_stop:
            print("🛑 Early stopping triggered")
            break

    print("✅ Training complete!")
    return model, best_val

print("✅ Training pipeline loaded")
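Dividing each micro-batch loss by `grad_accum_steps` before `backward()` makes the summed gradients equal the gradient of the mean loss over the whole effective batch, provided the micro-batches are equal-sized.

Here is a plain-Python check with a one-parameter squared-error model. The toy data and the helper name `grad_mse` are illustrative assumptions. The setup mirrors `auto_accumulation(4)`: micro-batches of 8 with 4 accumulation steps, for an effective batch of 32.

```python
import math

def grad_mse(w, xs, ys):
    """Gradient d/dw of mean((w * x - y) ** 2) over a batch."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)

xs = [float(i) for i in range(1, 33)]   # 32 toy samples
ys = [3.0 * x + 1.0 for x in xs]
w = 0.5
accum_steps = 4                          # plays the role of grad_accum_steps
micro = len(xs) // accum_steps           # micro-batch size 8

# Gradient of the mean loss over one batch of 32
full_grad = grad_mse(w, xs, ys)

# Four micro-batches of 8: each gradient scaled by 1/accum_steps, then summed
accum_grad = 0.0
for k in range(accum_steps):
    sl = slice(k * micro, (k + 1) * micro)
    accum_grad += grad_mse(w, xs[sl], ys[sl]) / accum_steps

print(full_grad, accum_grad, math.isclose(full_grad, accum_grad))  # -1820.5 -1820.5 True
```

If the last micro-batch of an epoch is smaller, the equality is only approximate. Stepping on that partial accumulation is still better than discarding it.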

## 7. Evaluation Pipeline

In [None]:
def evaluate_model(
    model,
    data_path="data/metr_la/metr_la.npz",
    adj_path="data/metr_la/adj.npy",
    model_path="best_model.pt",
    mc_runs=20
):
    """Evaluate model with uncertainty quantification"""
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # Load model weights
    if os.path.exists(model_path):
        model.load_state_dict(torch.load(model_path, map_location=device))
        print(f"✅ Loaded model from {model_path}")
    else:
        print(f"⚠️  {model_path} not found; evaluating the model's current weights")
    
    # Load data
    data = np.load(data_path)["data"]
    adj = torch.tensor(np.load(adj_path)).float().to(device)
    
    # Test split (last 20%)
    T = len(data)
    train_data = data[:int(0.7 * T)]
    test_data = data[int(0.8 * T):]
    
    # Create test dataset (the training split is only used for its normalization statistics)
    train_ds = TrafficDataset(train_data)
    test_ds = TrafficDataset(test_data, mean=train_ds.mean, std=train_ds.std)
    test_loader = DataLoader(test_ds, batch_size=32, pin_memory=torch.cuda.is_available())
    
    model.eval()
    
    # Metrics
    mae = rmse = mape = 0.0
    uncertainty = 0.0
    
    with torch.no_grad():
        for X, Y in test_loader:
            X = X.to(device)
            Y = Y.to(device).permute(0, 2, 1)
            
            # Monte-Carlo dropout for uncertainty
            mean, std = mc_dropout_predict(model, X, adj, runs=mc_runs)
            
            # Compute metrics
            mae += masked_mae(mean, Y).item()
            rmse += masked_rmse(mean, Y).item()
            mape += masked_mape(mean, Y).item()
            uncertainty += std.mean().item()
    
    mae /= len(test_loader)
    rmse /= len(test_loader)
    mape /= len(test_loader)
    uncertainty /= len(test_loader)
    
    print("=" * 50)
    print("Evaluation Results (on normalized targets):")
    print(f"  MAE:  {mae:.4f}")
    print(f"  RMSE: {rmse:.4f}")
    print(f"  MAPE: {mape:.4f} (fraction, not %)")
    print(f"  Avg Predictive Uncertainty (MC-dropout std): {uncertainty:.4f}")
    print("=" * 50)
    
    return {
        "mae": mae,
        "rmse": rmse,
        "mape": mape,
        "uncertainty": uncertainty
    }

print("✅ Evaluation pipeline loaded")
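Averaging several forward passes with dropout left on approximates the prediction, and the spread of those passes serves as the uncertainty estimate.

Here is a plain-Python sketch of that mechanism on a three-weight linear model. The weights, inputs, `p=0.3`, and helper names are made-up assumptions; `mc_dropout_predict` applies the same idea to the full network. With inverted dropout (rescaling kept units by 1/(1-p)), the Monte-Carlo mean stays close to the deterministic output, while the standard deviation is non-zero.

```python
import random
import statistics

def dropout_forward(weights, x, p, rng):
    """One stochastic pass of a linear model with inverted dropout on its inputs."""
    keep = 1.0 - p
    return sum(w * xi * (1.0 if rng.random() < keep else 0.0) / keep
               for w, xi in zip(weights, x))

def mc_dropout(weights, x, p=0.3, runs=2000, seed=0):
    """Mean and sample std over `runs` stochastic passes."""
    rng = random.Random(seed)
    preds = [dropout_forward(weights, x, p, rng) for _ in range(runs)]
    # statistics.stdev uses the n-1 denominator, like torch.std's default
    return statistics.fmean(preds), statistics.stdev(preds)

weights = [0.5, -0.2, 0.8]
x = [1.0, 2.0, 3.0]
deterministic = sum(w * xi for w, xi in zip(weights, x))   # about 2.5 with dropout off

mc_mean, mc_std = mc_dropout(weights, x)
print(f"deterministic={deterministic:.3f}  mc_mean={mc_mean:.3f}  mc_std={mc_std:.3f}")
```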

## 8. Traffic Routing Controller (Model-2)

In [None]:
# Traffic Routing & Flow Control Engine
# Consumes future traffic predictions and produces stable, congestion-aware route decisions

def aggregate_congestion(pred_traffic, mode="risk"):
    """
    Convert future trajectory into a single congestion risk score
    
    Args:
        pred_traffic: np.ndarray [N, H] - predicted traffic for H time steps
        mode: aggregation mode ("mean" or "risk")
    Returns:
        np.ndarray [N] - congestion risk score per node
    """
    if mode == "mean":
        return pred_traffic.mean(axis=1)
    elif mode == "risk":
        # Penalize spikes more than average
        return 0.6 * pred_traffic.mean(axis=1) + 0.4 * pred_traffic.max(axis=1)
    else:
        raise ValueError(f"Unknown aggregation mode: {mode}")

def edge_cost(distance, congestion, alpha=0.7):
    """
    Compute edge cost based on distance and congestion.
    Negative congestion scores are clipped to 0 so every weight stays
    non-negative, as Dijkstra-based shortest paths require.
    """
    return distance * (1.0 + alpha * max(float(congestion), 0.0))

def build_weighted_graph(adj, dist, node_congestion, alpha=0.7):
    """
    Build a directed NetworkX graph with congestion-aware edge weights

    Args:
        adj: [N, N] adjacency matrix (any entry > 0 is treated as an edge)
        dist: [N, N] distance matrix
        node_congestion: [N] congestion score per node
        alpha: congestion penalty weight
    Returns:
        NetworkX DiGraph; edge i -> j is penalized by the congestion at j
    """
    if nx is None:
        raise ImportError("networkx is required for traffic routing")

    # A DiGraph keeps i -> j and j -> i separate. With an undirected Graph,
    # the (j, i) pass would overwrite the weight set by the (i, j) pass.
    G = nx.DiGraph()
    N = adj.shape[0]
    G.add_nodes_from(range(N))

    for i in range(N):
        for j in range(N):
            if i != j and adj[i, j] > 0:
                G.add_edge(
                    i, j,
                    weight=edge_cost(dist[i, j], node_congestion[j], alpha),
                    distance=dist[i, j]
                )
    return G

def compute_routes(G, src, dst):
    """Compute optimized (congestion-aware) and shortest (distance-only) routes"""
    if nx is None:
        raise ImportError("networkx is required for traffic routing")

    optimized_route = nx.shortest_path(G, src, dst, weight="weight")
    shortest_route = nx.shortest_path(G, src, dst, weight="distance")
    return optimized_route, shortest_route

def assign_route(p_optimized=0.75):
    """Probabilistic assignment to prevent route collapse (all traffic on one path)"""
    return "optimized" if random.random() < p_optimized else "shortest"

class RouteStabilityGuard:
    """Rate-limits changes to the route split to prevent oscillations"""
    def __init__(self, initial_split=0.75, max_change=0.15):
        self.current_split = initial_split
        self.max_change = max_change

    def smooth(self, target_split):
        # Move toward the target by at most max_change per call
        delta = target_split - self.current_split
        delta = np.clip(delta, -self.max_change, self.max_change)
        self.current_split += delta
        return self.current_split

class TrafficRouter:
    """
    Model-2 Controller: Consumes Model-1 predictions and returns user routes

    Design goals:
    - Minimal: two shortest-path queries per request
    - Congestion-aware routing
    - Probabilistic split between optimized and shortest routes
    - Stability guard to avoid oscillations

    Note: higher values in pred_traffic are treated as worse congestion.
    If Model-1 predicts speed, convert it to a congestion measure first.
    """
    def __init__(self, adj, dist, alpha=0.7):
        self.adj = adj
        self.dist = dist
        self.alpha = alpha
        self.guard = RouteStabilityGuard()

    def route(self, pred_traffic, src, dst):
        """
        Compute route based on traffic predictions

        Args:
            pred_traffic: np.ndarray [N, H] - Model-1 output (predicted traffic)
            src, dst: source & destination nodes
        Returns:
            dict with route information
        """
        # Aggregate future congestion
        node_cong = aggregate_congestion(pred_traffic)

        # Build congestion-aware graph (alpha sets the congestion penalty)
        G = build_weighted_graph(self.adj, self.dist, node_cong, self.alpha)

        # Compute dual routes
        opt_route, short_route = compute_routes(G, src, dst)

        # Stable flow split (the target is fixed at 0.75, so the guard holds it there)
        split = self.guard.smooth(0.75)

        # Assign route
        choice = assign_route(split)

        return {
            "chosen_route": opt_route if choice == "optimized" else short_route,
            "optimized_route": opt_route,
            "shortest_route": short_route,
            "split_ratio": split,
            "policy": choice
        }

print("✅ Traffic routing controller loaded")
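To see how the congestion penalty diverts a route, here is a hypothetical 4-node network (made-up distances and congestion scores). It is routed with a plain-Python Dijkstra built on `heapq`, which stands in for `nx.shortest_path` so the example runs without networkx.

The example assumes edge i -> j is charged the congestion of node j, with cost `distance * (1 + alpha * congestion)`. Route 0 -> 1 -> 3 is shortest by distance, but node 1's congestion makes 0 -> 2 -> 3 cheaper by weight.

```python
import heapq

def toy_edge_cost(distance, congestion, alpha=0.7):
    # distance * (1 + alpha * congestion), the same formula as edge_cost
    return distance * (1.0 + alpha * congestion)

def dijkstra(edges, src, dst, weight):
    """Shortest path on a directed graph {u: [(v, attrs), ...]}; assumes dst is reachable."""
    dist = {src: 0.0}
    prev = {}
    heap = [(0.0, src)]
    while heap:
        d, u = heapq.heappop(heap)
        if u == dst:
            break
        if d > dist.get(u, float("inf")):
            continue  # stale heap entry
        for v, attrs in edges.get(u, []):
            nd = d + weight(attrs)
            if nd < dist.get(v, float("inf")):
                dist[v] = nd
                prev[v] = u
                heapq.heappush(heap, (nd, v))
    path = [dst]
    while path[-1] != src:
        path.append(prev[path[-1]])
    return path[::-1]

congestion = {0: 0.0, 1: 2.0, 2: 0.0, 3: 0.0}                 # node 1 predicted congested
links = {(0, 1): 1.0, (1, 3): 1.0, (0, 2): 1.2, (2, 3): 1.2}  # made-up distances

edges = {}
for (u, v), d in links.items():
    edges.setdefault(u, []).append(
        (v, {"distance": d, "weight": toy_edge_cost(d, congestion[v])}))

shortest = dijkstra(edges, 0, 3, lambda a: a["distance"])
optimized = dijkstra(edges, 0, 3, lambda a: a["weight"])
print(shortest)   # [0, 1, 3]: total distance 2.0
print(optimized)  # [0, 2, 3]: total weight 2.4 vs 3.4 through node 1
```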

## 9. Traffic Simulator

In [None]:
class TrafficSimulator:
    """
    Closed-loop traffic flow simulator

    Simulates how routing decisions modify traffic state with a simple
    rule-based update (decay old congestion, then add route load); it is
    a system model, not a learned one.
    """
    def __init__(self, adj, decay=0.85):
        """
        Args:
            adj: adjacency matrix [N, N] (stored; not used by the current update rule)
            decay: how fast congestion dissipates (0-1)
        """
        self.adj = adj
        self.decay = decay
    
    def apply_routes(self, traffic, route, load=1.0):
        """
        Apply route to traffic state
        
        Args:
            traffic: [N] current congestion
            route: list of nodes (route path)
            load: traffic volume
        Returns:
            Updated traffic state
        """
        traffic = traffic.copy()
        for node in route:
            traffic[node] += load
        return traffic
    
    def step(self, traffic, routes):
        """
        Simulate one time step
        
        Args:
            traffic: [N] current congestion
            routes: list of (route, load) tuples
        Returns:
            Updated traffic state after one time step
        """
        # Decay old congestion
        traffic = traffic * self.decay
        
        # Apply routes
        for route, load in routes:
            traffic = self.apply_routes(traffic, route, load)
        
        return traffic

print("✅ Traffic simulator loaded")
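The simulator's update rule, x_n = d·x_{n-1} + L for a node that receives a constant load L every step, is a geometric series. Starting from 0 it gives x_n = L(1 - d^n)/(1 - d), so for 0 < d < 1 congestion settles at L/(1 - d).

Here is a plain-Python check for a single node (hypothetical helper name `simulate`). With `decay=0.85` and a load of 1 per step, the steady state is 1/0.15, or about 6.67.

```python
def simulate(decay=0.85, load=1.0, steps=200, start=0.0):
    """Congestion at one node that receives `load` every step (TrafficSimulator.step rule)."""
    x = start
    history = []
    for _ in range(steps):
        x = x * decay + load   # decay old congestion, then add the new route load
        history.append(x)
    return history

hist = simulate()
steady = 1.0 / (1.0 - 0.85)
print(round(hist[0], 4), round(hist[1], 4), round(hist[-1], 4), round(steady, 4))
# 1.0 1.85 6.6667 6.6667
```

A higher `decay` means congestion lingers longer and the steady state grows as 1/(1 - decay).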

## 10. Example Usage & Demo

In [None]:
# =======================
# Example: Quick Model Test
# =======================
def test_model_forward():
    """Test model forward pass"""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # Create dummy data (num_feats avoids shadowing torch.nn.functional as F)
    B, T, N, num_feats = 2, 48, 20, 5
    x = torch.randn(B, T, N, num_feats).to(device)
    adj = torch.eye(N).to(device)

    # Create model
    model = NewtonGraphMamba(
        in_features=num_feats,
        d_model=32,
        num_nodes=N,
        num_layers=2,
        prediction_horizon=12
    ).to(device)

    # Forward pass
    y = model(x, adj)
    assert y.shape == (B, N, 12), f"unexpected output shape {tuple(y.shape)}"

    print("✅ Forward pass successful!")
    print(f"   Input shape:  {x.shape}")
    print(f"   Output shape: {y.shape}")
    print(f"   Expected:     [B={B}, N={N}, H=12]")
    
    return model, y

# Uncomment to test:
# test_model_forward()

In [None]:
# =======================
# Example: Training (commented out)
# =======================
# Uncomment and run to train the model:
#
# model, best_val = train_model(
#     data_path="data/metr_la/metr_la.npz",
#     adj_path="data/metr_la/adj.npy",
#     max_epochs=100,
#     device_mem_gb=4,
#     seed=42
# )
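After training, `training_metrics.csv` (written by `CSVLogger` with the columns `epoch`, `train_loss`, `val_mae`) can be inspected without re-running anything. The helper `best_epoch` below is hypothetical and not part of the pipeline; it returns the epoch with the lowest validation MAE. It is demonstrated on a toy file with made-up numbers written to a temporary directory.

```python
import csv
import os
import tempfile

def best_epoch(csv_path):
    """Return (epoch, val_mae) for the lowest validation MAE in a CSVLogger file."""
    with open(csv_path, newline="") as f:
        rows = list(csv.DictReader(f))
    best = min(rows, key=lambda r: float(r["val_mae"]))
    return int(best["epoch"]), float(best["val_mae"])

# Demo on a small file with the same header CSVLogger writes (toy numbers)
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "training_metrics.csv")
    with open(path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["epoch", "train_loss", "val_mae"])
        writer.writerows([[0, 0.91, 0.62], [1, 0.55, 0.48], [2, 0.47, 0.50]])
    result = best_epoch(path)

print(result)  # (1, 0.48)
```

On a real run, call `best_epoch("training_metrics.csv")` in the working directory used by `train_model`.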

In [None]:
# =======================
# Example: Evaluation (commented out)
# =======================
# Uncomment and run to evaluate the model:
#
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# data = np.load("data/metr_la/metr_la.npz")["data"]
# model = NewtonGraphMamba(
#     in_features=data.shape[-1],
#     num_nodes=data.shape[1]
# ).to(device)
#
# results = evaluate_model(
#     model,
#     data_path="data/metr_la/metr_la.npz",
#     adj_path="data/metr_la/adj.npy",
#     model_path="best_model.pt",
#     mc_runs=20
# )

In [None]:
# =======================
# Example: Closed-Loop Traffic Control
# =======================
def demo_closed_loop_control():
    """
    Demo of closed-loop traffic control system:
    1. Model-1 predicts future traffic
    2. Model-2 (Router) computes routes based on predictions
    3. Simulator updates traffic state
    4. Loop continues
    """
    if nx is None:
        print("⚠️  networkx not available. Skipping closed-loop demo.")
        return
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # Setup (dummy data for demo)
    N = 20
    H = 12  # prediction horizon
    
    # Create dummy adjacency and distance matrices
    adj = np.random.randint(0, 2, (N, N))
    adj = adj | adj.T  # Make symmetric
    np.fill_diagonal(adj, 0)  # No self-loops
    
    dist = np.random.rand(N, N) * 10
    dist = (dist + dist.T) / 2  # Symmetric
    
    # Initialize model-1 (forecaster)
    model = NewtonGraphMamba(
        in_features=5,
        num_nodes=N,
        d_model=32,
        num_layers=2,
        prediction_horizon=H
    ).to(device)
    
    # Initialize model-2 (router)
    router = TrafficRouter(adj, dist, alpha=0.7)
    
    # Initialize simulator
    simulator = TrafficSimulator(adj, decay=0.85)
    
    # Initial traffic state
    current_traffic = np.random.rand(N)
    
    print("🚦 Closed-Loop Traffic Control Demo")
    print("=" * 50)
    
    # Simulate a few time steps
    for step in range(5):
        print(f"\n--- Time Step {step + 1} ---")
        
        # Step 1: Prepare input for Model-1
        # (random history stands in for real observations; in this demo the
        # simulated traffic state is not fed back into Model-1)
        B, T = 1, 48
        historical_data = torch.randn(B, T, N, 5).to(device)
        adj_tensor = torch.tensor(adj).float().to(device)
        
        # Step 2: Model-1 predicts future traffic
        model.eval()
        with torch.no_grad():
            pred_traffic = model(historical_data, adj_tensor)
            pred_traffic = pred_traffic[0].cpu().numpy()  # [N, H]
        
        print(f"Predicted traffic shape: {pred_traffic.shape}")
        print(f"Avg predicted congestion: {pred_traffic.mean():.3f}")
        
        # Step 3: Model-2 computes routes
        src, dst = 0, N - 1
        route_info = router.route(pred_traffic, src, dst)
        
        print(f"Route from {src} to {dst}:")
        chosen = route_info['chosen_route']
        print(f"  Chosen: {chosen[:5]}{'...' if len(chosen) > 5 else ''}")
        print(f"  Policy: {route_info['policy']}")
        print(f"  Split ratio: {route_info['split_ratio']:.2f}")
        
        # Step 4: Simulator updates traffic
        routes_to_apply = [(route_info['chosen_route'], 1.0)]
        current_traffic = simulator.step(current_traffic, routes_to_apply)
        
        print(f"Current traffic state: avg={current_traffic.mean():.3f}, max={current_traffic.max():.3f}")
    
    print("\n✅ Closed-loop demo complete!")
    print("=" * 50)

# Uncomment to run demo:
# demo_closed_loop_control()
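In the demo the router always targets a 0.75 split, so the stability guard never has to act. The sketch below is a standalone mirror of `RouteStabilityGuard.smooth`; the helper name `smooth_trace` and the target values are hypothetical. It shows how the split moves by at most `max_change` per call when the desired split changes abruptly.

```python
def smooth_trace(targets, initial_split=0.75, max_change=0.15):
    """Rate-limited split updates, mirroring RouteStabilityGuard.smooth."""
    current = initial_split
    trace = []
    for target in targets:
        delta = max(-max_change, min(max_change, target - current))  # clip the step
        current += delta
        trace.append(round(current, 2))
    return trace

# A sudden drop in the desired split to 0.2, then a jump to 0.9
print(smooth_trace([0.2, 0.2, 0.2, 0.2, 0.9]))  # [0.6, 0.45, 0.3, 0.2, 0.35]
```

The rate limit trades responsiveness for stability. Reaching a far-away target takes several routing rounds, but drivers are never shifted en masse in a single round.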

---

## 📝 Notes

- **Data Paths**: Update `data_path` and `adj_path` in training/evaluation functions to match your data location
- **GPU Memory**: Adjust `device_mem_gb` parameter based on your GPU capacity
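- **Model Checkpoints**: Training saves `checkpoint.pt` every epoch (used to resume) and `best_model.pt` whenever validation MAE improves
- **NetworkX**: Required for traffic routing features (`pip install networkx`)
- **Mamba SSM**: Optional but recommended (`pip install mamba-ssm`, which needs a CUDA GPU); without it, a per-timestep MLP placeholder with no temporal mixing is used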

## 🎯 Quick Start

1. **Run all cells** to define all functions and classes
2. **Test model** by running the `test_model_forward()` function
3. **Train model** by uncommenting and running the training cell (update data paths first)
4. **Evaluate model** by uncommenting and running the evaluation cell
5. **Demo closed-loop** by uncommenting and running the closed-loop control demo

---

**Project consolidated and ready to use!** 🚀