In [1]:
import torch

!pip uninstall torch-scatter torch-sparse torch-geometric torch-cluster  --y
!pip install torch-scatter -f https://data.pyg.org/whl/torch-{torch.__version__}.html
!pip install torch-sparse -f https://data.pyg.org/whl/torch-{torch.__version__}.html
!pip install torch-cluster -f https://data.pyg.org/whl/torch-{torch.__version__}.html
!pip install git+https://github.com/pyg-team/pytorch_geometric.git

[0mLooking in links: https://data.pyg.org/whl/torch-2.5.1+cu121.html
Collecting torch-scatter
  Downloading https://data.pyg.org/whl/torch-2.5.0%2Bcu121/torch_scatter-2.1.2%2Bpt25cu121-cp310-cp310-linux_x86_64.whl (10.9 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.9/10.9 MB[0m [31m100.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch-scatter
Successfully installed torch-scatter-2.1.2+pt25cu121
Looking in links: https://data.pyg.org/whl/torch-2.5.1+cu121.html
Collecting torch-sparse
  Downloading https://data.pyg.org/whl/torch-2.5.0%2Bcu121/torch_sparse-0.6.18%2Bpt25cu121-cp310-cp310-linux_x86_64.whl (5.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.1/5.1 MB[0m [31m51.7 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: torch-sparse
Successfully installed torch-sparse-0.6.18+pt25cu121
Looking in links: https://data.pyg.org/whl/torch-2.5.1+cu121.html
Collecting torch-cluster
  Downloa

In [2]:
import os
import glob
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import Dataset, random_split, ConcatDataset
from torch_geometric.loader import DataLoader
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.utils import add_self_loops

from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from scipy.stats import ttest_ind
import statsmodels.stats.multitest as smm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


# **Data paths**

In [None]:
MCI_FOLDER_ALL = "/kaggle/input/ADNI/MCI" 
CN_FOLDER_ALL  = "/kaggle/input/ADNI/CN" 
MCI_FOLDER_TRAIN = "/kaggle/input/ADNI/MCI/train"
MCI_FOLDER_TEST = "/kaggle/input/ADNI/MCI/test"
CN_FOLDER_TRAIN = "/kaggle/input/ADNI/CN/train"
CN_FOLDER_TEST = "/kaggle/input/ADNI/CN/test"
AAL_90_PATH = "/kaggle/input/aal-template/AAL90_region_info.xls"

# **Utility functions**

In [None]:
def get_all_txt_files(folder_path):
    return sorted(glob.glob(os.path.join(folder_path, "**", "*.txt"), recursive=True))

def load_timeseries(file_path):
    return np.loadtxt(file_path)

def clean_timeseries(ts):
    """Add a tiny random noise if a column's standard deviation is near zero"""
    eps = 1e-6
    for col_idx in range(ts.shape[1]):
        std_ = np.std(ts[:, col_idx])
        if std_ < eps:
            ts[:, col_idx] += np.random.normal(loc=0, scale=1e-4, size=(ts.shape[0],))
    return ts

def compute_corr(timeseries):
    """ Calculates the correlation matrix, clamping values to [-1,1]"""
    corr = np.corrcoef(timeseries.T)
    corr = np.nan_to_num(corr, nan=0.0, posinf=0.0, neginf=-1.0)
    corr = np.clip(corr, -1.0, 1.0)
    return corr

def get_subject_correlation_matrices(folder_path):
    txt_files = get_all_txt_files(folder_path)
    mats = []
    for f in txt_files:
        ts = load_timeseries(f)
        if ts.ndim != 2 or ts.shape[1] != 90:
            print(f"Skipping {f}, shape={ts.shape}")
            continue
        c = compute_corr(ts)
        mats.append(c)
    return np.array(mats)

# **Statistical test and building the mask**


* We collect correlation matrices for both classes.
* For each pair (i,j) of regions, we run a Welch’s t-test comparing the classes.
* Benjamini–Hochberg (FDR) controls multiple tests.
* The mask records which edges are significant


In [None]:
mci_corr_mats = get_subject_correlation_matrices(MCI_FOLDER_ALL)
cn_corr_mats  = get_subject_correlation_matrices(CN_FOLDER_ALL)

df = pd.read_excel(AAL_90_PATH, header=0)
region_dict = dict(zip(df['Labels'], df['Regions']))
n_regions = 90

assert mci_corr_mats.shape[1] == n_regions and mci_corr_mats.shape[2] == n_regions, "MCI shape not 90x90"
assert cn_corr_mats.shape[1] == n_regions and cn_corr_mats.shape[2] == n_regions, "CN shape not 90x90"

results_list = []
for i in range(n_regions):
    for j in range(i+1, n_regions):
        mci_vals = mci_corr_mats[:, i, j]
        cn_vals  = cn_corr_mats[:, i, j]
        t_stat, p_val = ttest_ind(mci_vals, cn_vals, equal_var=False, nan_policy='omit')
        results_list.append({
            'i': i,
            'j': j,
            'region_i': region_dict.get(i+1, f"R{i+1}"),
            'region_j': region_dict.get(j+1, f"R{j+1}"),
            't_stat': t_stat,
            'p_value': p_val
        })

edges_df = pd.DataFrame(results_list)
pvals = edges_df['p_value'].values
reject, pvals_corrected, _, _ = smm.multipletests(pvals, alpha=0.05, method='fdr_bh')

edges_df['p_value_corrected'] = pvals_corrected
edges_df['reject_null'] = reject

sig_edges = edges_df[edges_df['reject_null']].copy()
sig_edges.sort_values('p_value_corrected', inplace=True)

mask = np.zeros((n_regions, n_regions), dtype=np.float32)
for row in sig_edges.itertuples():
    i, j = row.i, row.j
    mask[i, j] = 1
    mask[j, i] = 1

# **Dataset for graph construction**


For each subject file:
1. We load time-series and compute correlation.
2. We build node features ([mean, std] per region).
3. We apply the group-level mask to keep only significant edges.
4. We create a PyG Data object.



In [6]:
class BrainGraphDataset(Dataset):
    def __init__(self, folder_path, label, mask):
        super().__init__()
        self.files = get_all_txt_files(folder_path)
        self.label = label
        self.mask = mask

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        file_path = self.files[idx]
        ts = load_timeseries(file_path)
        corr = compute_corr(ts)
        node_feats = []
        for r in range(n_regions):
            arr = ts[:, r]
            mean_ = np.mean(arr)
            std_  = np.std(arr)
            if std_ < 1e-9:
                std_ = 1e-9
            node_feats.append([mean_, std_])
        x = torch.tensor(node_feats, dtype=torch.float)

        corr_masked = corr * self.mask
        corr_masked[corr_masked<0] = 0

        edge_index_list = []
        edge_attr_list = []
        for i in range(n_regions):
            for j in range(n_regions):
                if i != j and self.mask[i, j] == 1:
                    val = corr_masked[i, j]
                    edge_index_list.append([i, j])
                    edge_attr_list.append(val)

        if len(edge_index_list)==0:
            e_idx = torch.empty((2,0), dtype=torch.long)
            e_attr= torch.empty((0,), dtype=torch.float)
        else:
            e_idx = torch.tensor(edge_index_list, dtype=torch.long).t().contiguous()
            e_attr= torch.tensor(edge_attr_list, dtype=torch.float)

        e_idx, e_attr = add_self_loops(e_idx, e_attr, fill_value=0.0, num_nodes=n_regions)

        y = torch.tensor([self.label], dtype=torch.long)
        data = Data(x=x, edge_index=e_idx, edge_attr=e_attr, y=y)
        return data

# GNN

We use a simple two-layer GCN. After the second layer, we pool node embeddings and classify them.

In [7]:
class GNNClassifier(nn.Module):
    def __init__(self, in_channels=2, hidden_dim=64, num_classes=2):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_dim, normalize=False)
        self.conv2 = GCNConv(hidden_dim, hidden_dim, normalize=False)
        self.fc    = nn.Linear(hidden_dim, num_classes)

    def forward(self, data):
        x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch
        x = self.conv1(x, edge_index, edge_weight=edge_attr)
        x = F.relu(x)
        x = self.conv2(x, edge_index, edge_weight=edge_attr)
        x = F.relu(x)
        x = global_mean_pool(x, batch)
        return self.fc(x)


# Train and test datasets

In [None]:
mci_train_dataset = BrainGraphDataset(MCI_FOLDER_TRAIN, label=1, mask=mask)
cn_train_dataset  = BrainGraphDataset(CN_FOLDER_TRAIN, label=0, mask=mask)
train_dataset     = ConcatDataset([mci_train_dataset, cn_train_dataset])

mci_test_dataset = BrainGraphDataset(MCI_FOLDER_TEST, label=1, mask=mask)
cn_test_dataset  = BrainGraphDataset(CN_FOLDER_TEST, label=0, mask=mask)
test_dataset     = ConcatDataset([mci_test_dataset, cn_test_dataset])

train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
test_loader  = DataLoader(test_dataset,  batch_size=4, shuffle=False)

# Training

In [None]:
model = GNNClassifier(in_channels=2, hidden_dim=64, num_classes=2).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()

EPOCHS = 150

train_losses = []
test_losses  = []
train_accs   = []
test_accs    = []

for epoch in range(1, EPOCHS+1):
    model.train()
    running_loss = 0.0
    correct = 0
    total   = 0
    for batch_data in train_loader:
        batch_data = batch_data.to(device)
        optimizer.zero_grad()
        out = model(batch_data)
        loss = criterion(out, batch_data.y)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        running_loss += loss.item() * batch_data.num_graphs
        preds = out.argmax(dim=1)
        correct += (preds == batch_data.y).sum().item()
        total   += batch_data.num_graphs

    epoch_train_loss = running_loss / len(train_dataset)
    epoch_train_acc  = correct / total

    model.eval()
    test_loss_sum = 0.0
    test_correct  = 0
    test_total    = 0
    with torch.no_grad():
        for batch_data in test_loader:
            batch_data = batch_data.to(device)
            out = model(batch_data)
            loss = criterion(out, batch_data.y)
            test_loss_sum += loss.item() * batch_data.num_graphs
            preds = out.argmax(dim=1)
            test_correct += (preds == batch_data.y).sum().item()
            test_total   += batch_data.num_graphs

    epoch_test_loss = test_loss_sum / len(test_dataset)
    epoch_test_acc  = test_correct / test_total

    train_losses.append(epoch_train_loss)
    test_losses.append(epoch_test_loss)
    train_accs.append(epoch_train_acc)
    test_accs.append(epoch_test_acc)

    if epoch % 10 == 0:
        print(f"Epoch {epoch:03d}/{EPOCHS} | "
              f"Train Loss={epoch_train_loss:.4f}, Acc={epoch_train_acc:.4f} | "
              f"Test Loss={epoch_test_loss:.4f}, Acc={epoch_test_acc:.4f}")

# Plotting loss and accuracy

# Final evaluation

In [11]:
model.eval()
all_preds = []
all_labels= []
with torch.no_grad():
    for data_batch in test_loader:
        data_batch = data_batch.to(device)
        out = model(data_batch)
        preds = out.argmax(dim=1).cpu().numpy()
        labels= data_batch.y.cpu().numpy()
        all_preds.extend(preds)
        all_labels.extend(labels)

acc = accuracy_score(all_labels, all_preds)
prec= precision_score(all_labels, all_preds, average='binary')
rec = recall_score(all_labels, all_preds, average='binary')
f1  = f1_score(all_labels, all_preds, average='binary')

print("\nFinal Test Metrics:")
print(f"Accuracy= {acc:.4f}")
print(f"Precision={prec:.4f}")
print(f"Recall=   {rec:.4f}")
print(f"F1-score= {f1:.4f}")

model_save_path = "gnn_classifier_mci_cn.pth"
torch.save(model.state_dict(), model_save_path)
print(f"Model saved to {model_save_path}")


Final Test Metrics:
Accuracy= 0.5857
Precision=0.6000
Recall=   0.5143
F1-score= 0.5538
Model saved to gnn_classifier_mci_cn.pth


**Feature importance**

In [None]:
from torch_geometric.explain import Explainer, GNNExplainer
from torch_geometric.explain.config import ModelConfig
from torch_geometric.data import Data

class WrappedModel(nn.Module):
    def __init__(self, base):
        super().__init__()
        self.base = base          

    def forward(self, x, edge_index, edge_weight, batch=None):
        if batch is None:
            batch = torch.zeros(x.size(0), dtype=torch.long,
                                device=x.device)
        g = Data(x=x,
                 edge_index=edge_index,
                 edge_attr=edge_weight,
                 batch=batch)
        return self.base(g)

wrapped = WrappedModel(model).to(device)
wrapped.eval()

test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

model_cfg = ModelConfig(
    mode        = 'binary_classification',   
    task_level  = 'graph',
    return_type = 'raw'                      
)

explainer = Explainer(
    model            = wrapped,
    algorithm        = GNNExplainer(epochs=200),
    model_config     = model_cfg,
    explanation_type = 'model',
    node_mask_type   = None,
    edge_mask_type   = 'object'
)

edge_imp_accum = torch.zeros((n_regions, n_regions), device=device)
count = 0

for data in test_loader:
    data = data.to(device)
    if model(data).argmax(1).item() != data.y.item():
        continue           

    exp  = explainer(
        x           = data.x,
        edge_index  = data.edge_index,
        edge_weight = data.edge_attr
    )
    mask = exp.edge_mask              
    ei   = data.edge_index

    for k, score in enumerate(mask):
        u, v = ei[0, k].item(), ei[1, k].item()
        edge_imp_accum[u, v] += score
        edge_imp_accum[v, u] += score
    count += 1

edge_imp_avg = (edge_imp_accum / max(count, 1)).cpu().numpy()

u, v   = np.triu_indices(n_regions, k=1)
scores = edge_imp_avg[u, v]
best   = np.argsort(scores)[-10:][::-1]

rows = []
for rank, idx in enumerate(best, 1):
    i, j = u[idx], v[idx]
    rows.append({
        'rank'      : rank,
        'region_i'  : region_dict[i+1],
        'region_j'  : region_dict[j+1],
        'importance': float(scores[idx])
    })

df_imp = pd.DataFrame(rows)
print("\nTop-10 Discriminative Edges (MCI vs CN)")
print(df_imp[['rank', 'region_i', 'region_j', 'importance']])