# Project Newton-GraphMamba Training

This notebook trains the Graph-Mamba model on Kaggle's free GPU tier.

## Setup Instructions
1. Create a new Kaggle Notebook
2. Enable GPU (Settings → Accelerator → GPU T4 x2)
3. Run all cells sequentially
4. Download the trained model from the Output section


## Step 1: Magic Install Block (Run FIRST!)

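This installs the CUDA 11.8 build of PyTorch first, then `causal-conv1d` and `mamba-ssm` (the latter with `--no-build-isolation`, so its build uses the already-installed PyTorch). This forces Mamba to install correctly on T4 GPUs.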


In [None]:
# Run this cell FIRST in Kaggle
# Installs the CUDA 11.8 builds of PyTorch/torchvision, then ninja and packaging
# ahead of the causal-conv1d / mamba-ssm installs.
!pip install torch==2.1.2+cu118 torchvision==0.16.2+cu118 --extra-index-url https://download.pytorch.org/whl/cu118
!pip install ninja packaging
# The requirement must be quoted: unquoted, the shell treats `>=1.1.0` as an
# output redirection, so pip never sees the version constraint and its output
# is written to a file named "=1.1.0".
!pip install "causal-conv1d>=1.1.0"
!pip install mamba-ssm --no-build-isolation

# Additional dependencies
!pip install torch-geometric osmnx networkx pandas numpy tqdm


## Step 2: Import Libraries


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
import osmnx as ox
import networkx as nx
import numpy as np
import pandas as pd
from tqdm import tqdm
import json
from pathlib import Path

# Import our model
from newton_graphmamba import NewtonGraphMamba
from graph_utils import networkx_to_pyg

# Report the runtime environment up front so a missing GPU is noticed before training.
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    # Device name and CUDA version are only meaningful when a GPU is visible.
    print(
        f"GPU: {torch.cuda.get_device_name(0)}",
        f"CUDA version: {torch.version.cuda}",
        sep="\n",
    )


## Step 3: Download City Graph Data


In [None]:
# Fetch the street graph for the target place with OSMnx
# (requesting the 'drive' network type).
place_name = "San Francisco, California, USA"
print(f"Downloading graph for {place_name}...")
G = ox.graph_from_place(place_name, network_type='drive')

# Size summary of the downloaded graph
print(
    f"Graph nodes: {G.number_of_nodes()}",
    f"Graph edges: {G.number_of_edges()}",
    sep="\n",
)

# Convert to PyTorch Geometric format, then persist the result so it can be
# reloaded later without repeating the conversion.
pyg_data = networkx_to_pyg(G)
print(f"PyG Data: {pyg_data}")
torch.save(pyg_data, 'city_graph.pt')
print("Graph saved to city_graph.pt")


## Step 4: Initialize Model


In [None]:
# Prefer the GPU when one is visible; otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print(f"Using device: {device}")

# Graph-dependent sizes: node count comes from the converted graph, and the
# feature width from its node-feature matrix (6 is used when the graph has none).
num_nodes = pyg_data.num_nodes
if pyg_data.x is not None:
    node_features = pyg_data.x.shape[1]
else:
    node_features = 6

# Model hyperparameters
d_model, n_layers = 256, 4
d_state, d_conv, expand = 16, 4, 2

# Build the model, then move its parameters to the selected device.
model = NewtonGraphMamba(
    num_nodes=num_nodes,
    node_features=node_features,
    d_model=d_model,
    n_layers=n_layers,
    d_state=d_state,
    d_conv=d_conv,
    expand=expand,
)
model = model.to(device)

print(f"Model initialized with {sum(p.numel() for p in model.parameters())} parameters")
print(f"Model device: {next(model.parameters()).device}")


## Step 5: Prepare Training Data

**Note:** The cell below generates random synthetic routes as a placeholder. Replace it with your actual trajectory data loader.


In [None]:
def generate_synthetic_routes(num_routes=1000, seq_len=100, num_nodes=None, num_vehicles=100):
    """Generate synthetic route data for training.

    Placeholder data source: replace this with your actual trajectory data loader.
    Every node index is drawn independently and uniformly at random; graph
    connectivity is NOT consulted, so consecutive nodes in a route are generally
    not adjacent in the real graph.

    Args:
        num_routes: Number of routes to generate.
        seq_len: Number of node indices in each route.
        num_nodes: Exclusive upper bound for node indices. Required, despite
            the ``None`` default kept for signature compatibility.
        num_vehicles: Exclusive upper bound for vehicle ids (default 100).

    Returns:
        A ``(routes, vehicle_ids)`` tuple: ``routes`` is a list of 1-D integer
        tensors of length ``seq_len``; ``vehicle_ids`` is a list of Python ints,
        one per route.

    Raises:
        ValueError: If ``num_nodes`` is missing or < 1, or ``num_vehicles`` < 1.
    """
    # Fail early with a readable message instead of an opaque torch.randint error.
    if num_nodes is None or num_nodes < 1:
        raise ValueError(f"num_nodes must be a positive integer, got {num_nodes!r}")
    if num_vehicles < 1:
        raise ValueError(f"num_vehicles must be a positive integer, got {num_vehicles!r}")

    routes = []
    vehicle_ids = []

    for _ in range(num_routes):
        # Uniform random node ids (not a walk along graph edges).
        # The route is sampled before the vehicle id so the RNG stream is
        # consumed in a fixed order for a given seed.
        routes.append(torch.randint(0, num_nodes, (seq_len,)))
        vehicle_ids.append(torch.randint(0, num_vehicles, (1,)).item())

    return routes, vehicle_ids

# Build the placeholder training set; node ids are bounded by num_nodes.
train_routes, train_vehicle_ids = generate_synthetic_routes(
    num_nodes=num_nodes,
    seq_len=100,
    num_routes=1000,
)

# Quick sanity summary: how many routes there are and how long the first one is.
print(
    f"Generated {len(train_routes)} training routes",
    f"Route length: {len(train_routes[0])}",
    sep="\n",
)


## Step 6: Training Loop


In [None]:
# Training hyperparameters
batch_size = 32
learning_rate = 1e-4
num_epochs = 10

# Target value written into padded positions. CrossEntropyLoss skips targets
# equal to its ignore_index, so padding is never scored as a real "next node".
# (The pad value 0 used for the inputs is itself a valid node id, so without
# this mask shorter routes would be trained to predict node 0 after they end.)
PAD_TARGET = -100

# Optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=1e-5)

# Loss function (padded target positions are excluded via ignore_index)
criterion = nn.CrossEntropyLoss(ignore_index=PAD_TARGET)

# Training loop
model.train()
losses = []  # one average loss per epoch, stored as plain Python floats

for epoch in range(num_epochs):
    epoch_losses = []

    # Batch training: consecutive slices of the route list, in fixed order
    for i in tqdm(range(0, len(train_routes), batch_size), desc=f"Epoch {epoch+1}/{num_epochs}"):
        batch_routes = train_routes[i:i+batch_size]
        batch_vehicle_ids = train_vehicle_ids[i:i+batch_size]

        # Right-pad every route to the longest one in the batch, remembering
        # the true lengths so padding can be masked out of the loss below.
        route_lengths = [len(route) for route in batch_routes]
        max_len = max(route_lengths)
        padded_routes = [
            F.pad(route, (0, max_len - len(route)), value=0)
            for route in batch_routes
        ]

        routes_tensor = torch.stack(padded_routes).to(device)
        vehicle_ids_tensor = torch.tensor(batch_vehicle_ids, dtype=torch.long).to(device)
        lengths_tensor = torch.tensor(route_lengths, dtype=torch.long, device=device)

        # Forward pass
        optimizer.zero_grad()

        # Only the route sequence and vehicle ids are passed here (demo setup).
        logits = model.forward_decoder(
            routes_tensor,
            vehicle_id=vehicle_ids_tensor,
            graph_memory=None  # original note: "will be computed from graph" - confirm in forward_decoder
        )

        # Next-node prediction: the target at step t is the route's node at t+1.
        # Target column j corresponds to route index j+1, which is real only
        # when j+1 < route length; everything else is padding and is ignored.
        next_positions = torch.arange(1, max_len, device=device)
        is_real = next_positions.unsqueeze(0) < lengths_tensor.unsqueeze(1)
        targets = routes_tensor[:, 1:].masked_fill(~is_real, PAD_TARGET).reshape(-1)
        # Drop the last step's logits: it has no following node to predict.
        logits_flat = logits[:, :-1].contiguous().view(-1, num_nodes)

        loss = criterion(logits_flat, targets)

        # Backward pass, with gradient-norm clipping before the optimizer step
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        epoch_losses.append(loss.item())

    # float(...) keeps numpy scalar types out of `losses`, so anything that
    # later serializes the list stores plain Python numbers.
    avg_loss = float(np.mean(epoch_losses))
    losses.append(avg_loss)
    print(f"Epoch {epoch+1}/{num_epochs} - Average Loss: {avg_loss:.4f}")

print("\nTraining completed!")


In [None]:
# Save model weights together with the hyperparameters needed to rebuild the
# model and the per-epoch loss history.
model_path = "newton_mamba_v1.pt"
torch.save({
    'model_state_dict': model.state_dict(),
    'model_config': {
        'num_nodes': num_nodes,
        'node_features': node_features,
        'd_model': d_model,
        'n_layers': n_layers,
        'd_state': d_state,
        'd_conv': d_conv,
        'expand': expand
    },
    # Coerce to plain Python floats: numpy scalars (e.g. results of np.mean)
    # inside a checkpoint are rejected by torch.load(weights_only=True).
    'training_losses': [float(loss_value) for loss_value in losses]
}, model_path)

print(f"Model saved to {model_path}")
print(f"\nDownload this file from the Output section of Kaggle!")
