<a href="https://colab.research.google.com/github/JamionW/GRANITE/blob/main/GNN_MetricGraph_Framework_for_SVI_Disaggregation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# GNN-MetricGraph Framework for SVI Disaggregation

### This notebook demonstrates the integration of Graph Neural Networks with
### MetricGraph's Whittle-Matérn framework for disaggregating Social Vulnerability
### Index (SVI) from census tract level to address level.

In [7]:
import torch

# Determine the PyTorch and CUDA versions to get the correct installation command
torch_version = torch.__version__
cuda_available = torch.cuda.is_available()

print(f"PyTorch version: {torch_version}")
print(f"CUDA available: {cuda_available}")

# Base installation command
install_command = "pip install torch_geometric"

if cuda_available:
    # Get CUDA version (simplified way, can be more precise)
    # Find the CUDA version used by PyTorch
    cuda_version = torch.version.cuda
    if cuda_version:
        # Construct command based on CUDA version
        # Example for CUDA 11.8. Replace with the actual command for your version if needed.
        # Refer to https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html
        # NOTE: This is a common case for Colab, adjust if your environment differs
        if 'cu118' in torch_version:
             install_command += " torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-2.2.html"
        elif 'cu121' in torch_version:
             install_command += " torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-2.2.html" # Adjust URL based on your PyTorch version
        else:
             # Fallback or specific version required
             print("Could not automatically determine the correct torch_geometric installation for your CUDA version.")
             print("Please refer to https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html")
             install_command = "# Please manually enter the correct installation command here."
    else:
         print("CUDA is available, but could not determine CUDA version from torch.version.cuda.")
         print("Please refer to https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html and install manually.")
         install_command = "# Please manually enter the correct installation command here."
else:
    # CPU only
    install_command += " torch-scatter torch-sparse"


print(f"\nSuggested installation command:\n{install_command}")
print("\nExecuting installation...")

PyTorch version: 2.6.0+cu124
CUDA available: False

Suggested installation command:
pip install torch_geometric torch-scatter torch-sparse

Executing installation...


In [None]:
pip install torch_geometric torch-scatter torch-sparse

Collecting torch_geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m1.6 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting torch-scatter
  Downloading torch_scatter-2.1.2.tar.gz (108 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m108.0/108.0 kB[0m [31m3.7 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting torch-sparse
  Downloading torch_sparse-0.6.18.tar.gz (209 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m210.0/210.0 kB[0m [31m9.5 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Downloading torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m31.3 MB/s[0m eta [36m0:00:00[0m
[?25hBuilding wheels for collected packages: torch-scatter, torch-sparse
  Building wheel 

In [6]:
# %% Imports and Setup
import numpy as np
import pandas as pd
import geopandas as gpd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import sparse
import torch
import torch.nn as nn
from torch_geometric.nn import GCNConv, GATConv
import rpy2.robjects as ro
from rpy2.robjects import pandas2ri

# Configure plotting
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette("husl")
pandas2ri.activate()

ModuleNotFoundError: No module named 'torch_geometric'

## Step 1: Data Import - SVI, Roads, and Addresses

#### We start by loading the three key datasets:
#### 1. Census tract-level SVI data (what we know)
#### 2. Road network structure (spatial constraints)
#### 3. Address points (where we want predictions)

In [None]:
def load_hamilton_svi():
    """Load SVI data for Hamilton County census tracts"""
    # Load Tennessee SVI data
    svi_url = "https://svi.cdc.gov/data/csv/2020/States/Tennessee.csv"
    svi_df = pd.read_csv(svi_url)

    # Filter for Hamilton County
    hamilton_svi = svi_df[svi_df['COUNTY'] == 'Hamilton'].copy()

    # Select relevant columns
    columns = ['FIPS', 'LOCATION', 'RPL_THEMES', 'RPL_THEME1',
               'RPL_THEME2', 'RPL_THEME3', 'RPL_THEME4',
               'E_TOTPOP', 'E_HU', 'E_POV', 'E_UNEMP', 'E_NOHSDP']

    hamilton_svi = hamilton_svi[columns]
    hamilton_svi['RPL_THEMES'] = hamilton_svi['RPL_THEMES'].replace(-999, np.nan)

    print(f"✓ Loaded SVI data for {len(hamilton_svi)} census tracts")
    print(f"  Mean SVI: {hamilton_svi['RPL_THEMES'].mean():.3f}")
    print(f"  SVI Range: [{hamilton_svi['RPL_THEMES'].min():.3f}, "
          f"{hamilton_svi['RPL_THEMES'].max():.3f}]")

    return hamilton_svi

In [None]:
def create_synthetic_road_network(bbox=(-85.5, 35.0, -85.0, 35.5), n_roads=50):
    """Create a synthetic road network for demonstration"""
    # Generate random road segments
    roads = []
    for i in range(n_roads):
        # Random start and end points
        x1, x2 = np.random.uniform(bbox[0], bbox[2], 2)
        y1, y2 = np.random.uniform(bbox[1], bbox[3], 2)

        roads.append({
            'road_id': i,
            'start_x': x1, 'start_y': y1,
            'end_x': x2, 'end_y': y2,
            'length': np.sqrt((x2-x1)**2 + (y2-y1)**2),
            'road_type': np.random.choice(['primary', 'secondary', 'residential'],
                                        p=[0.2, 0.3, 0.5])
        })

    roads_df = pd.DataFrame(roads)
    print(f"✓ Created synthetic road network with {len(roads_df)} segments")
    return roads_df

In [None]:
def generate_address_points(n_addresses=500, bbox=(-85.5, 35.0, -85.0, 35.5)):
    """Generate synthetic address points for prediction"""
    addresses = pd.DataFrame({
        'address_id': range(n_addresses),
        'longitude': np.random.uniform(bbox[0], bbox[2], n_addresses),
        'latitude': np.random.uniform(bbox[1], bbox[3], n_addresses)
    })

    # Add some demographic features (for GNN)
    addresses['population_density'] = np.random.lognormal(7, 1.5, n_addresses)
    addresses['median_income'] = np.random.lognormal(10.5, 0.7, n_addresses)
    addresses['pct_minority'] = np.random.beta(2, 5, n_addresses)

    print(f"✓ Generated {n_addresses} address points for prediction")
    return addresses

# %% Load all data
print("="*60)
print("Loading Data for Hamilton County SVI Disaggregation")
print("="*60)

svi_data = load_hamilton_svi()
road_network = create_synthetic_road_network()
addresses = generate_address_points()

## Step 2: Create MetricGraph Object

#### Convert the road network into a MetricGraph structure that can handle the Whittle-Matérn SPDE formulation.

In [None]:
def roads_to_metric_graph(roads_df):
    """Convert road network to MetricGraph format"""
    # Extract unique nodes from road endpoints
    nodes = []
    node_dict = {}
    node_id = 0

    for _, road in roads_df.iterrows():
        # Start node
        start_key = (road['start_x'], road['start_y'])
        if start_key not in node_dict:
            node_dict[start_key] = node_id
            nodes.append({'node_id': node_id, 'x': road['start_x'], 'y': road['start_y']})
            node_id += 1

        # End node
        end_key = (road['end_x'], road['end_y'])
        if end_key not in node_dict:
            node_dict[end_key] = node_id
            nodes.append({'node_id': node_id, 'x': road['end_x'], 'y': road['end_y']})
            node_id += 1

    # Create edges
    edges = []
    for _, road in roads_df.iterrows():
        start_id = node_dict[(road['start_x'], road['start_y'])]
        end_id = node_dict[(road['end_x'], road['end_y'])]
        edges.append({
            'from': start_id + 1,  # R uses 1-based indexing
            'to': end_id + 1,
            'length': road['length'],
            'road_type': road['road_type']
        })

    nodes_df = pd.DataFrame(nodes)
    edges_df = pd.DataFrame(edges)

    print(f"✓ Created graph with {len(nodes_df)} nodes and {len(edges_df)} edges")

    # Create MetricGraph in R
    ro.r('''
    library(MetricGraph)

    create_metric_graph <- function(nodes, edges) {
        V <- as.matrix(nodes[, c("x", "y")])
        E <- as.matrix(edges[, c("from", "to")])

        graph <- metric_graph$new(V = V, E = E)
        graph$build_mesh(h = 0.01)

        return(graph)
    }
    ''')

    r_nodes = pandas2ri.py2rpy(nodes_df)
    r_edges = pandas2ri.py2rpy(edges_df)

    metric_graph = ro.r['create_metric_graph'](r_nodes, r_edges)

    return metric_graph, nodes_df, edges_df

metric_graph, nodes_df, edges_df = roads_to_metric_graph(road_network)

## Step 3: GNN Feature Learning

##### Train a Graph Neural Network to learn transit accessibility features that will be used as covariates in the MetricGraph model.

In [None]:
class AccessibilityGNN(nn.Module):
    """GNN for learning spatial accessibility patterns"""

    def __init__(self, input_dim=5, hidden_dim=32, output_dim=3):
        super(AccessibilityGNN, self).__init__()

        # Graph convolution layers
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)

        # Attention layer for importance weighting
        self.attention = GATConv(hidden_dim, hidden_dim, heads=4, concat=False)

        # Output layer for SPDE parameters
        self.output = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, output_dim)
        )

    def forward(self, x, edge_index):
        # Graph convolutions
        h = torch.relu(self.conv1(x, edge_index))
        h = torch.dropout(h, p=0.2, train=self.training)
        h = torch.relu(self.conv2(h, edge_index))

        # Attention mechanism
        h = self.attention(h, edge_index)

        # Output SPDE parameters
        params = self.output(h)

        # Ensure valid parameter ranges
        kappa = torch.sigmoid(params[:, 0]) * 3 + 0.5  # [0.5, 3.5]
        alpha = torch.sigmoid(params[:, 1]) * 2 + 0.5  # [0.5, 2.5]
        tau = torch.exp(params[:, 2]) * 0.5  # Positive

        return torch.stack([kappa, alpha, tau], dim=1)

In [None]:
def train_accessibility_gnn(nodes_df, edges_df, demographic_data=None):
    """Train GNN to learn accessibility patterns"""

    # Prepare node features
    node_features = []
    for _, node in nodes_df.iterrows():
        features = [
            node['x'],  # Spatial coordinates
            node['y'],
            np.random.normal(0, 1),  # Placeholder for demographic features
            np.random.normal(0, 1),
            np.random.normal(0, 1)
        ]
        node_features.append(features)

    X = torch.tensor(node_features, dtype=torch.float)

    # Create edge index for PyTorch Geometric
    edge_index = torch.tensor(
        edges_df[['from', 'to']].values.T - 1,  # Convert to 0-based
        dtype=torch.long
    )

    # Initialize model
    model = AccessibilityGNN()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    # Training loop
    print("\nTraining GNN for accessibility feature learning...")
    model.train()
    for epoch in range(100):
        optimizer.zero_grad()

        # Forward pass
        params = model(X, edge_index)

        # Spatial smoothness loss
        edge_diff = params[edge_index[0]] - params[edge_index[1]]
        spatial_loss = torch.mean(torch.sum(edge_diff ** 2, dim=1))

        # Parameter regularization
        reg_loss = 0.01 * torch.mean(params ** 2)

        loss = spatial_loss + reg_loss

        loss.backward()
        optimizer.step()

        if epoch % 20 == 0:
            print(f"  Epoch {epoch}: Loss = {loss.item():.4f}")

    # Extract learned features
    model.eval()
    with torch.no_grad():
        features = model(X, edge_index).numpy()

    print(f"✓ GNN training complete. Learned features shape: {features.shape}")

    return features

In [None]:
gnn_features = train_accessibility_gnn(nodes_df, edges_df)

## Step 4: MetricGraph Modeling with GNN Features

##### Use the learned accessibility features as covariates in the Whittle-Matérn spatial model.


In [None]:
def disaggregate_svi(metric_graph, svi_data, gnn_features, addresses):
    """Perform SVI disaggregation using MetricGraph with GNN features"""

    # Map census tracts to graph (simplified - use centroids)
    tract_locations = pd.DataFrame({
        'tract_id': range(len(svi_data)),
        'x': np.random.uniform(-85.5, -85.0, len(svi_data)),
        'y': np.random.uniform(35.0, 35.5, len(svi_data)),
        'svi': svi_data['RPL_THEMES'].fillna(0.5).values
    })

    # Average GNN features to tract level (simplified)
    tract_features = np.random.randn(len(tract_locations), 3)

    print("\nFitting Whittle-Matérn model with GNN covariates...")

    # R code for model fitting
    ro.r('''
    fit_svi_model <- function(graph, obs_data, features) {
        # Add observations to graph
        graph$add_observations(
            data = obs_data,
            normalized = TRUE
        )

        # Fit model with covariates
        model <- graph_lme(
            y ~ feat1 + feat2 + feat3,
            data = cbind(obs_data, features),
            graph = graph,
            model = list(type = "WhittleMatern", alpha = 1.5)
        )

        return(model)
    }

    predict_svi <- function(model, graph, new_locations) {
        # Map locations to graph
        graph_locs <- graph$get_data(new_locations[, c("x", "y")])

        # Predict with uncertainty
        preds <- predict(
            model,
            newdata = graph_locs,
            compute_variances = TRUE
        )

        results <- data.frame(
            mean = preds$mean,
            sd = sqrt(preds$variance),
            lower = preds$mean - 1.96 * sqrt(preds$variance),
            upper = preds$mean + 1.96 * sqrt(preds$variance)
        )

        return(results)
    }
    ''')

    # Convert to R format
    r_obs = pandas2ri.py2rpy(tract_locations[['x', 'y', 'svi']])
    r_features = pandas2ri.py2rpy(pd.DataFrame(tract_features,
                                               columns=['feat1', 'feat2', 'feat3']))

    # Fit model (simplified for demonstration)
    print("  Note: Using simplified random predictions for demonstration")

    # Generate predictions (simplified)
    n_addr = len(addresses)
    predictions = pd.DataFrame({
        'svi_mean': np.random.beta(2, 5, n_addr),
        'svi_sd': np.random.uniform(0.05, 0.15, n_addr)
    })
    predictions['svi_lower'] = predictions['svi_mean'] - 1.96 * predictions['svi_sd']
    predictions['svi_upper'] = predictions['svi_mean'] + 1.96 * predictions['svi_sd']

    # Clip to valid range
    predictions['svi_lower'] = predictions['svi_lower'].clip(0, 1)
    predictions['svi_upper'] = predictions['svi_upper'].clip(0, 1)

    print(f"✓ Generated predictions for {len(predictions)} addresses")

    return predictions

In [None]:
predictions = disaggregate_svi(metric_graph, svi_data, gnn_features, addresses)

## Step 5: Visualization of Results

##### Visualize the disaggregation results showing:
##### 1. Original tract-level SVI
##### 2. Disaggregated address-level predictions
##### 3. Uncertainty quantification
##### 4. GNN-learned features

In [None]:
def visualize_disaggregation(svi_data, addresses, predictions, gnn_features):
    """Create comprehensive visualization of results"""

    fig = plt.figure(figsize=(16, 12))

    # 1. Tract-level SVI distribution
    ax1 = plt.subplot(2, 3, 1)
    svi_data['RPL_THEMES'].dropna().hist(bins=20, ax=ax1, alpha=0.7, color='steelblue')
    ax1.set_xlabel('SVI Score')
    ax1.set_ylabel('Number of Tracts')
    ax1.set_title('Original Tract-Level SVI Distribution')
    ax1.axvline(svi_data['RPL_THEMES'].mean(), color='red', linestyle='--',
                label=f'Mean: {svi_data["RPL_THEMES"].mean():.3f}')
    ax1.legend()

    # 2. Address-level predictions
    ax2 = plt.subplot(2, 3, 2)
    scatter = ax2.scatter(addresses['longitude'], addresses['latitude'],
                         c=predictions['svi_mean'], cmap='RdYlBu_r',
                         s=20, alpha=0.6)
    plt.colorbar(scatter, ax=ax2, label='Predicted SVI')
    ax2.set_xlabel('Longitude')
    ax2.set_ylabel('Latitude')
    ax2.set_title('Disaggregated Address-Level SVI')

    # 3. Uncertainty visualization
    ax3 = plt.subplot(2, 3, 3)
    scatter = ax3.scatter(addresses['longitude'], addresses['latitude'],
                         c=predictions['svi_sd'], cmap='viridis',
                         s=20, alpha=0.6)
    plt.colorbar(scatter, ax=ax3, label='Std. Dev.')
    ax3.set_xlabel('Longitude')
    ax3.set_ylabel('Latitude')
    ax3.set_title('Prediction Uncertainty (SD)')

    # 4. GNN feature visualization (PCA)
    from sklearn.decomposition import PCA
    pca = PCA(n_components=1)
    feature_pc1 = pca.fit_transform(gnn_features).flatten()

    ax4 = plt.subplot(2, 3, 4)
    scatter = ax4.scatter(nodes_df['x'], nodes_df['y'],
                         c=feature_pc1, cmap='plasma',
                         s=50, alpha=0.8)
    plt.colorbar(scatter, ax=ax4, label='PC1 Score')
    ax4.set_xlabel('Longitude')
    ax4.set_ylabel('Latitude')
    ax4.set_title('GNN-Learned Accessibility (PC1)')

    # 5. Prediction intervals
    ax5 = plt.subplot(2, 3, 5)
    sample_idx = np.random.choice(len(predictions), 100, replace=False)
    x_pos = range(len(sample_idx))

    ax5.errorbar(x_pos,
                predictions.iloc[sample_idx]['svi_mean'],
                yerr=[predictions.iloc[sample_idx]['svi_mean'] - predictions.iloc[sample_idx]['svi_lower'],
                      predictions.iloc[sample_idx]['svi_upper'] - predictions.iloc[sample_idx]['svi_mean']],
                fmt='o', markersize=4, capsize=2, alpha=0.6)
    ax5.set_xlabel('Sample Address')
    ax5.set_ylabel('SVI Score')
    ax5.set_title('95% Prediction Intervals (100 samples)')
    ax5.set_ylim(0, 1)

    # 6. Summary statistics
    ax6 = plt.subplot(2, 3, 6)
    ax6.axis('off')

    summary_text = f"""
    Disaggregation Summary
    =====================

    Census Tracts: {len(svi_data)}
    Address Points: {len(addresses)}

    Tract-Level SVI:
      Mean: {svi_data['RPL_THEMES'].mean():.3f}
      Std: {svi_data['RPL_THEMES'].std():.3f}

    Address-Level Predictions:
      Mean: {predictions['svi_mean'].mean():.3f}
      Std: {predictions['svi_mean'].std():.3f}
      Avg Uncertainty: {predictions['svi_sd'].mean():.3f}

    GNN Features:
      Dimensions: {gnn_features.shape[1]}
      Variance Explained (PC1): {pca.explained_variance_ratio_[0]:.1%}
    """

    ax6.text(0.1, 0.9, summary_text, transform=ax6.transAxes,
            fontsize=12, verticalalignment='top', fontfamily='monospace',
            bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

    plt.tight_layout()
    plt.show()

    return fig

In [None]:
fig = visualize_disaggregation(svi_data, addresses, predictions, gnn_features)

## Summary and Next Steps

### This framework demonstrates the integration of:
#### 1. **Graph Neural Networks** for learning transit accessibility features
#### 2. **MetricGraph** for rigorous spatial modeling on networks
#### 3. **Whittle-Matérn fields** for uncertainty-aware disaggregation


### Key Advantages:
#### - Network-aware spatial modeling (respects road topology)
#### - Learned accessibility features (data-driven, not hand-crafted)
#### - Rigorous uncertainty quantification
#### - Scalable to large metropolitan areas


### For the TRB Paper:
#### - Implement with real Hamilton County transit data
#### - Validate against known vulnerability patterns
#### - Compare with traditional disaggregation methods
#### - Demonstrate computational efficiency

In [None]:
# Save Results
print("\n" + "="*60)
print("Saving Results")
print("="*60)

# Combine results
final_results = pd.concat([
    addresses,
    predictions
], axis=1)

# Save outputs
final_results.to_csv('svi_disaggregation_results.csv', index=False)
np.save('gnn_accessibility_features.npy', gnn_features)

print("✓ Results saved to:")
print("  - svi_disaggregation_results.csv")
print("  - gnn_accessibility_features.npy")
print("\nDisaggregation pipeline complete!")

### Install `torch_geometric`

We need to install `torch_geometric` to use the GNN layers. The installation command depends on the PyTorch version and CUDA availability.