## 1. Load Required Libraries

In [25]:
import numpy as np
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import networkx as nx
import matplotlib.pyplot as plt
from nilearn import datasets, plotting
from nilearn.maskers import NiftiMapsMasker
from nilearn.connectome import ConnectivityMeasure
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.loader import DataLoader
from torch_geometric.data import Data


## 2. Load fMRI Dataset

In [26]:
# Fetch the dataset (all subjects)
development_dataset = datasets.fetch_development_fmri(n_subjects=None)
print("fMRI dataset loaded successfully.")


[get_dataset_dir] Dataset found in /Users/anushamourshed/nilearn_data/development_fmri
[get_dataset_dir] Dataset found in /Users/anushamourshed/nilearn_data/development_fmri/development_fmri
[get_dataset_dir] Dataset found in /Users/anushamourshed/nilearn_data/development_fmri/development_fmri
fMRI dataset loaded successfully.


## 3. Load the Brain Atlas (MSDL)

In [27]:
# Fetch the MSDL atlas
msdl_data = datasets.fetch_atlas_msdl()

# Print basic info
msdl_coords = msdl_data.region_coords
n_regions = len(msdl_coords)

print(f"MSDL has {n_regions} ROIs, part of the following networks:\n{msdl_data.networks}")


[get_dataset_dir] Dataset found in /Users/anushamourshed/nilearn_data/msdl_atlas
MSDL has 39 ROIs, part of the following networks:
['Aud', 'Aud', 'Striate', 'DMN', 'DMN', 'DMN', 'DMN', 'Occ post', 'Motor', 'R V Att', 'R V Att', 'R V Att', 'R V Att', 'Basal', 'L V Att', 'L V Att', 'L V Att', 'D Att', 'D Att', 'Vis Sec', 'Vis Sec', 'Vis Sec', 'Salience', 'Salience', 'Salience', 'Temporal', 'Temporal', 'Language', 'Language', 'Language', 'Language', 'Language', 'Cereb', 'Dors PCC', 'Cing-Ins', 'Cing-Ins', 'Cing-Ins', 'Ant IPS', 'Ant IPS']


## 4. Initialize the Masker

In [28]:
# Set cache directory for Nilearn
cache_dir = os.path.expanduser("~/nilearn_cache")

# Create the masker object
masker = NiftiMapsMasker(
    msdl_data.maps,
    resampling_target="data",
    t_r=2,
    detrend=True,
    low_pass=0.1,
    high_pass=0.01,
    memory=cache_dir,  # Cache results
    memory_level=1,
    standardize="zscore_sample",
    standardize_confounds=True,
).fit()
print("Masker initialized successfully.")


Masker initialized successfully.


## 5. Extract Time Series and Save Data

In [29]:
# Check if saved time series exists
if os.path.exists("fmri_time_series.npz"):
    print("Loading saved time series...")
    data = np.load("fmri_time_series.npz", allow_pickle=True)
    pooled_subjects = data["pooled_subjects"]
    children = data["children"]
    groups = data["groups"]
else:
    print("Extracting time series (this will be saved for future use)...")

    children = []
    pooled_subjects = []
    groups = []  # Store 'child' or 'adult' labels

    for func_file, confound_file, phenotype in zip(
        development_dataset.func,
        development_dataset.confounds,
        development_dataset.phenotypic["Child_Adult"],
    ):
        time_series = masker.transform(func_file, confounds=confound_file)
        
        pooled_subjects.append(time_series)  # Store all subjects

        if phenotype == "child":
            children.append(time_series)  # Store only children

        groups.append(phenotype)  # Store class labels

    # Save extracted time series and labels
    np.savez("fmri_time_series.npz", pooled_subjects=pooled_subjects, children=children, groups=groups)
    print("Time series and labels saved successfully.")

print(f"Total subjects: {len(pooled_subjects)}")
print(f"Total children: {len(children)}")


Loading saved time series...
Total subjects: 155
Total children: 122


## 6. Compute Correlation Matrices

In [30]:
# Compute correlation matrices
from nilearn.connectome import ConnectivityMeasure

kinds = ["correlation", "partial correlation", "tangent"]
connectivity_matrices = {}

for kind in kinds:
    print(f"Computing {kind} matrices...")
    
    # Compute correlation matrices
    correlation_measure = ConnectivityMeasure(kind=kind, standardize="zscore_sample")
    connectivity_matrices[kind] = correlation_measure.fit_transform(pooled_subjects)
    
    # **Fix: Replace NaNs with 0 to prevent training issues**
    connectivity_matrices[kind] = np.nan_to_num(connectivity_matrices[kind], nan=0.0)

# Print shapes to verify correctness
for kind in kinds:
    print(f"{kind.capitalize()} matrix shape: {connectivity_matrices[kind].shape}")


Computing correlation matrices...
Computing partial correlation matrices...
Computing tangent matrices...
Correlation matrix shape: (155, 39, 39)
Partial correlation matrix shape: (155, 39, 39)
Tangent matrix shape: (155, 39, 39)


## 7. Convert Correlation Matrices to Graph Data


In [43]:
import networkx as nx
import torch
from torch_geometric.data import Data

thresholds = [0.05, 0.1, 0.3]  # Lower thresholds to avoid empty graphs

# Convert labels: 'child' -> 1, 'adult' -> 0
labels = np.array([1 if group == "child" else 0 for group in groups])

def create_graph_data(correlation_matrices, labels, threshold):
    """
    Converts correlation matrices into PyTorch Geometric graph data.
    - Removes NaNs
    - Applies threshold
    - Ensures graphs have valid edges/nodes
    """
    graph_data_list = []
    
    for i, matrix in enumerate(correlation_matrices):
        # **Replace NaNs in adjacency matrix**
        matrix = np.nan_to_num(matrix, nan=0.0)
        
        # **Apply threshold to filter connections**
        adj_matrix = np.where(np.abs(matrix) > threshold, matrix, 0)
        
        # **Ensure matrix is not fully zero before converting to a graph**
        if np.count_nonzero(adj_matrix) == 0:
            print(f"⚠ Warning: Graph {i} for threshold {threshold} is completely empty. Skipping...")
            continue  

        # Convert adjacency matrix to NetworkX graph
        G = nx.from_numpy_array(adj_matrix)
        
        # **Skip empty graphs (no edges)**
        if len(G.edges) == 0:  
            print(f"⚠ Warning: Graph {i} for threshold {threshold} has no edges. Skipping...")
            continue  

        edge_index = torch.tensor(list(G.edges), dtype=torch.long).t().contiguous()
        edge_weight = torch.tensor([matrix[u, v] for u, v in G.edges], dtype=torch.float)

        # **Ensure node features are not empty**
        if len(G.nodes) == 0:
            print(f"⚠ Warning: Graph {i} has no nodes! Skipping...")
            continue  

        # Node features: Degree centrality
        node_features = torch.tensor([[d] for _, d in G.degree()], dtype=torch.float)

        # **Ensure labels are valid integers**
        graph_data = Data(
            x=node_features, 
            edge_index=edge_index, 
            edge_attr=edge_weight, 
            y=torch.tensor(labels[i], dtype=torch.long)  # 0 for adult, 1 for child
        )
        graph_data_list.append(graph_data)
    
    return graph_data_list

# **Generate graphs for selected thresholds**
graph_datasets = {
    kind: {threshold: create_graph_data(connectivity_matrices[kind], labels, threshold) 
           for threshold in thresholds} 
    for kind in ["correlation", "tangent", "partial correlation"]
}

print("✅ Graph data successfully generated for selected thresholds and correlation methods.")


✅ Graph data successfully generated for selected thresholds and correlation methods.


## 8. Splitting into Training and Test Data

In [44]:
from torch_geometric.loader import DataLoader

# Split datasets and create DataLoaders
train_loaders = {}
test_loaders = {}

for kind in kinds:  # Loop over correlation methods
    train_loaders[kind] = {}
    test_loaders[kind] = {}

    for threshold in thresholds:  # Loop over thresholds
        train_graphs, test_graphs = train_test_split(graph_datasets[kind][threshold], test_size=0.2, random_state=42)
        
        train_loaders[kind][threshold] = DataLoader(train_graphs, batch_size=8, shuffle=True)
        test_loaders[kind][threshold] = DataLoader(test_graphs, batch_size=8, shuffle=False)

print("✅ Data loaders created for all thresholds and correlation methods.")


✅ Data loaders created for all thresholds and correlation methods.


## 9. Train GNN Models

In [45]:
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool

class GNNClassifier(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(GNNClassifier, self).__init__()
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, data):
        x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
        
        # Ensure edge_attr is not None
        x = self.conv1(x, edge_index, edge_weight=edge_attr if edge_attr is not None else None)
        x = F.relu(x)
        x = self.conv2(x, edge_index, edge_weight=edge_attr if edge_attr is not None else None)

        # Ensure batch exists for pooling
        batch = data.batch if hasattr(data, 'batch') else torch.zeros(x.size(0), dtype=torch.long, device=x.device)
        x = global_mean_pool(x, batch)  # Pool over nodes

        x = self.fc(x)
        return F.log_softmax(x, dim=1)

# Define model parameters
input_dim = 1  # Placeholder, update based on actual node features
hidden_dim = 16
output_dim = 2  # Binary classification (child vs. adult)


In [46]:
# Device setup for training
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Dictionary to store models for all combinations
models = {}

for kind in kinds:  # Loop over correlation methods
    models[kind] = {}
    for threshold in thresholds:  # Loop over thresholds
        print(f"🚀 Initializing GNN model for {kind} with threshold {threshold}...")
        
        # Create and move model to device
        model = GNNClassifier(input_dim, hidden_dim, output_dim).to(device)
        models[kind][threshold] = model


🚀 Initializing GNN model for correlation with threshold 0.05...
🚀 Initializing GNN model for correlation with threshold 0.1...
🚀 Initializing GNN model for correlation with threshold 0.3...
🚀 Initializing GNN model for partial correlation with threshold 0.05...
🚀 Initializing GNN model for partial correlation with threshold 0.1...
🚀 Initializing GNN model for partial correlation with threshold 0.3...
🚀 Initializing GNN model for tangent with threshold 0.05...
🚀 Initializing GNN model for tangent with threshold 0.1...
🚀 Initializing GNN model for tangent with threshold 0.3...


In [47]:
import torch.optim as optim

def train_model(model, train_loader, epochs=50, lr=0.001):
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.NLLLoss()  # Negative log likelihood loss for classification
    model.train()
    
    for epoch in range(epochs):
        total_loss = 0
        for batch in train_loader:
            batch = batch.to(device)
            optimizer.zero_grad()

            # Forward pass
            output = model(batch)
            loss = criterion(output, batch.y)
            
            # Backpropagation
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
        
        if epoch % 10 == 0:
            print(f"Epoch {epoch}, Loss: {total_loss / len(train_loader):.4f}")
    
    print("✅ Training complete!\n")



In [48]:
epochs = 50  # Adjust as needed

for kind in kinds:  # Loop over correlation methods
    for threshold in thresholds:  # Loop over thresholds
        print(f"\n🚀 Training GNN on {kind} graphs with threshold {threshold}...\n")

        # Get model & train loader
        model = models[kind][threshold]
        train_loader = train_loaders[kind][threshold]

        # Train the model
        train_model(model, train_loader, epochs=epochs)



🚀 Training GNN on correlation graphs with threshold 0.05...

Epoch 0, Loss: nan
Epoch 10, Loss: nan
Epoch 20, Loss: nan
Epoch 30, Loss: nan
Epoch 40, Loss: nan
✅ Training complete!


🚀 Training GNN on correlation graphs with threshold 0.1...

Epoch 0, Loss: nan
Epoch 10, Loss: nan
Epoch 20, Loss: nan
Epoch 30, Loss: nan
Epoch 40, Loss: nan
✅ Training complete!


🚀 Training GNN on correlation graphs with threshold 0.3...

Epoch 0, Loss: nan
Epoch 10, Loss: nan
Epoch 20, Loss: nan
Epoch 30, Loss: nan
Epoch 40, Loss: nan
✅ Training complete!


🚀 Training GNN on partial correlation graphs with threshold 0.05...

Epoch 0, Loss: nan
Epoch 10, Loss: nan
Epoch 20, Loss: nan
Epoch 30, Loss: nan
Epoch 40, Loss: nan
✅ Training complete!


🚀 Training GNN on partial correlation graphs with threshold 0.1...

Epoch 0, Loss: 0.5125
Epoch 10, Loss: 0.4761
Epoch 20, Loss: 0.4786
Epoch 30, Loss: 0.4814
Epoch 40, Loss: 0.4814
✅ Training complete!


🚀 Training GNN on partial correlation graphs with thresh