In [None]:
!pip -q install torch-geometric
!pip -q install nilearn

In [None]:
import os
import random
import nibabel as nib
import torch
from nilearn import plotting

def set_seed(seed):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN kernels.

    Args:
        seed: Integer seed applied to ``random``, ``numpy.random`` and ``torch``
            (CPU, the current CUDA device, and all CUDA devices).
    """
    # Imported locally: `set_seed(42)` below runs before the module-level
    # `import numpy as np` later in this cell, so relying on the global `np`
    # would raise NameError on a fresh kernel.
    import numpy as np

    random.seed(seed)                     # Python
    np.random.seed(seed)                  # NumPy
    torch.manual_seed(seed)               # PyTorch CPU
    torch.cuda.manual_seed(seed)          # PyTorch current GPU
    torch.cuda.manual_seed_all(seed)      # All GPUs (if using DataParallel or DDP)

    # Ensure deterministic cuDNN behavior
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # NOTE: PYTHONHASHSEED is read only at interpreter start-up, so this does not
    # change hash randomisation in the running kernel; it only affects processes
    # started afterwards that inherit this environment.
    os.environ['PYTHONHASHSEED'] = str(seed)


set_seed(42)



from torch_geometric.data import Dataset
from torch_geometric.loader import DataLoader

import torch.nn.functional as F
from torch_geometric.nn import GATConv, global_mean_pool, SAGEConv, HeteroConv, Linear


import warnings
import numpy as np
from glob import glob

warnings.filterwarnings("ignore", category=FutureWarning, message="You are using `torch.load`")


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(device)

In [None]:
class GraphDataset(Dataset):
    """Dataset of pre-built graphs stored as ``<root_dir>/<class>/*.pt`` files.

    Graphs under ``M/`` are labelled 0 and graphs under ``F/`` are labelled 1. Each item is
    the object stored in the ``.pt`` file, with ``y`` set to a one-element float tensor
    holding its label and its ``'roi'`` node features zero-padded on the right to
    ``roi_feature_dim`` columns.

    Args:
        root_dir: Directory containing the ``M`` and ``F`` sub-directories.
        transform: Optional callable applied to each graph after labelling and padding.
        roi_feature_dim: Number of columns the ``'roi'`` features are padded to
            (default 114, the input width of ``roi_encoder`` in this notebook).

    Note:
        Files are read with ``torch.load(..., weights_only=False)``, which unpickles
        arbitrary Python objects and can therefore execute code; only point this at
        trusted data.
    """

    def __init__(self, root_dir, transform=None, roi_feature_dim=114):
        self.root_dir = root_dir
        self.transform = transform
        self.roi_feature_dim = roi_feature_dim
        self.graph_paths = []
        self.labels = []

        class_map = {"M": 0, "F": 1}

        for class_name, label in class_map.items():
            class_path = os.path.join(root_dir, class_name)
            # glob() returns matches in arbitrary, filesystem-dependent order; sort so that
            # dataset indices (and anything accumulated in index order) are reproducible.
            for graph_file in sorted(glob(os.path.join(class_path, "*.pt"))):
                self.graph_paths.append(graph_file)
                self.labels.append(label)

    def __len__(self):
        return len(self.graph_paths)

    @staticmethod
    def _pad_features(x, target_dim):
        """Return the 2-D tensor ``x`` right-padded with zero columns to ``target_dim`` columns.

        ``x`` is returned unchanged when it already has exactly ``target_dim`` columns. The
        padding has the same dtype and device as ``x``.

        Raises:
            ValueError: If ``x`` has more than ``target_dim`` columns. Such features cannot
                be padded, and passing them on would presumably fail later with a less
                informative shape error in the encoder.
        """
        current_dim = x.size(1)
        if current_dim > target_dim:
            raise ValueError(
                f"ROI features have {current_dim} columns, more than the expected {target_dim}"
            )
        if current_dim == target_dim:
            return x
        pad = x.new_zeros((x.size(0), target_dim - current_dim))
        return torch.cat([x, pad], dim=1)

    def __getitem__(self, idx):
        path = self.graph_paths[idx]
        data = torch.load(path, weights_only=False)  # trusted files only (see class note)

        data.y = torch.tensor([self.labels[idx]], dtype=torch.float)
        data['roi'].x = self._pad_features(data['roi'].x, self.roi_feature_dim)

        if self.transform:
            data = self.transform(data)

        return data

In [None]:
from sklearn.model_selection import train_test_split
from torch.utils.data import Subset

graph_path = "/content/drive/MyDrive/HCP Data/gender_graphs"

# glob() on a missing directory silently matches nothing. That would leave every later
# cell running on an empty dataset and writing no attention maps, so fail loudly instead.
if not os.path.isdir(graph_path):
    raise FileNotFoundError(f"Graph directory not found: {graph_path!r} (is Google Drive mounted?)")

dataset = GraphDataset(graph_path)
if len(dataset) == 0:
    raise RuntimeError(f"No .pt graphs found in the M/ or F/ sub-directories of {graph_path!r}")

In [None]:
class multihead(torch.nn.Module):
    """Heterogeneous graph classifier: node encoders -> hetero SAGE -> hetero multi-head GAT -> pooled linear head.

    Expects graphs with ``'cluster'`` nodes (2 input features) and ``'roi'`` nodes (114 input
    features), connected by ``('cluster', 'intersects', 'roi')`` edges and the reverse
    ``('roi', 'intersects_rev', 'cluster')`` edges. Each node type is encoded to
    ``hidden_channels``. One SAGE layer and one GAT layer (``heads`` heads, concatenated)
    follow, each with a ReLU. Node embeddings are mean-pooled per graph and per node type, and
    the two pooled vectors are concatenated and mapped to a single output value per graph
    (presumably a logit for the binary label).

    Args:
        hidden_channels: Width of the encoders, the SAGE outputs and each GAT head.
        heads: Number of GAT attention heads (default 2). A checkpoint only loads into a
            model built with the same value.
    """

    def __init__(self, hidden_channels, heads=2):
        super().__init__()
        self.heads = heads

        # Per-node-type input encoders projecting raw features to hidden_channels.
        self.cluster_encoder = torch.nn.Sequential(
            Linear(2, hidden_channels),
            torch.nn.ReLU(),
            torch.nn.BatchNorm1d(hidden_channels)
        )
        self.roi_encoder = torch.nn.Sequential(
            Linear(114, hidden_channels),
            torch.nn.ReLU(),
            torch.nn.BatchNorm1d(hidden_channels)
        )

        # Heterogeneous SAGE; (-1, -1) lets PyG infer the input sizes lazily.
        self.sage = HeteroConv({
            ('cluster', 'intersects', 'roi'): SAGEConv((-1, -1), hidden_channels),
            ('roi', 'intersects_rev', 'cluster'): SAGEConv((-1, -1), hidden_channels),
        }, aggr='mean')

        # Heterogeneous GAT; concat=True makes each output heads * hidden_channels wide.
        self.gat = HeteroConv({
            ('cluster', 'intersects', 'roi'): GATConv((-1, -1), hidden_channels, heads=self.heads, concat=True, add_self_loops=False),
            ('roi', 'intersects_rev', 'cluster'): GATConv((-1, -1), hidden_channels, heads=self.heads, concat=True, add_self_loops=False),
        }, aggr='mean')

        # Input is [pooled cluster embedding, pooled ROI embedding], each heads * hidden_channels wide.
        self.classifier = torch.nn.Linear(hidden_channels * 2 * self.heads, 1)

    def _gat_with_attention(self, x_dict, edge_index_dict):
        """Run the GAT layer one edge type at a time so the attention weights can be captured.

        Intended to match the ``aggr='mean'`` behaviour of ``self.gat``:
        - edge types missing from ``edge_index_dict`` are skipped;
        - when several edge types target the same node type, their outputs are averaged
          instead of the last one overwriting the others.

        Returns:
            ``(x_gat, att_dict)``. ``x_gat`` maps node type to its GAT output (before ReLU).
            ``att_dict`` maps edge type to the ``(edge_index, attention)`` pair returned by
            ``GATConv``.
        """
        outputs = {}
        att_dict = {}
        for edge_type, conv in self.gat.convs.items():
            if edge_type not in edge_index_dict:
                continue
            src, _, dst = edge_type
            out, (edge_index_used, attn_weights) = conv(
                (x_dict[src], x_dict[dst]),
                edge_index_dict[edge_type],
                return_attention_weights=True,
            )
            outputs.setdefault(dst, []).append(out)
            att_dict[edge_type] = (edge_index_used, attn_weights)

        x_gat = {
            node_type: outs[0] if len(outs) == 1 else torch.stack(outs).mean(dim=0)
            for node_type, outs in outputs.items()
        }
        return x_gat, att_dict

    def forward(self, data, return_attention=False):
        """Compute one output per graph in ``data``.

        Args:
            data: A (batched) heterogeneous graph exposing ``data['cluster'].x``,
                ``data['roi'].x``, per-type ``.batch`` vectors and ``edge_index_dict``.
            return_attention: If True, also return the GAT attention weights.

        Returns:
            A ``[num_graphs, 1]`` tensor, or ``(output, att_dict)`` when ``return_attention``
            is True. ``att_dict`` maps each edge type present in ``data`` to
            ``(edge_index, attention)``, where attention is ``[E, heads]``.
        """
        # Encode input features
        x_dict = {
            'cluster': self.cluster_encoder(data['cluster'].x),
            'roi': self.roi_encoder(data['roi'].x),
        }

        # Convolutions
        x_dict = self.sage(x_dict, data.edge_index_dict)
        x_dict = {k: F.relu(v) for k, v in x_dict.items()}

        if return_attention:
            x_dict, att_dict = self._gat_with_attention(x_dict, data.edge_index_dict)
        else:
            x_dict = self.gat(x_dict, data.edge_index_dict)
            att_dict = None
        x_dict = {k: F.relu(v) for k, v in x_dict.items()}

        # Global pooling per node type, then classification
        cluster_pool = global_mean_pool(x_dict['cluster'], data['cluster'].batch)
        roi_pool = global_mean_pool(x_dict['roi'], data['roi'].batch)
        out = self.classifier(torch.cat([cluster_pool, roi_pool], dim=1))

        if return_attention:
            return out, att_dict
        return out

# Rebuild the architecture (64 hidden channels) and load the seed_4 best checkpoint.
model = multihead(64)
model.load_state_dict(torch.load("/content/drive/MyDrive/HCP Data/Models/multihead/seed_4/best_model.pt", weights_only=True, map_location=torch.device(device)))
# load_state_dict copies values into the module's existing parameters without moving the
# module. Move it explicitly; otherwise inputs sent to `device` would sit on a different
# device than the weights whenever CUDA is available.
model = model.to(device)
model.eval()

In [None]:
from collections import defaultdict

# Which GAT attention to summarise per ROI.
# GATConv normalises attention with a softmax over the incoming edges of each *target* node.
# Weights on ('cluster', 'intersects', 'roi') edges therefore sum to 1 for every ROI and head,
# and averaging them per target ROI only yields 1 / (number of intersecting clusters),
# regardless of what the model learned. Reading the reverse ('roi' -> 'cluster') edges and
# grouping by the *source* ROI instead measures how much attention clusters pay to each ROI.
ATTENTION_EDGE_TYPE = ('roi', 'intersects_rev', 'cluster')
ROI_ROW = 0  # row of edge_index holding the ROI index for ATTENTION_EDGE_TYPE (0 = source)

roi_attention_sum = defaultdict(lambda: defaultdict(float))
roi_attention_count = defaultdict(lambda: defaultdict(int))

# NOTE: iterates over the whole dataset, not a held-out split.
# batch_size=1 keeps node indices in edge_index local to each graph. Summing per index
# assumes ROI node i is the same region in every graph — TODO confirm against graph building.
test_loader = DataLoader(dataset, batch_size=1, shuffle=False)

model.eval()
with torch.no_grad():  # inference only; no autograd graph needed
    for data in test_loader:
        data = data.to(device)
        _, att_dict = model(data, return_attention=True)
        edge_index, attn_weights = att_dict[ATTENTION_EDGE_TYPE]  # Shape: [E, heads]

        for roi_idx, att_vec in zip(edge_index[ROI_ROW].tolist(), attn_weights.tolist()):
            for head_idx, score in enumerate(att_vec):
                roi_attention_sum[roi_idx][head_idx] += score
                roi_attention_count[roi_idx][head_idx] += 1

# Mean attention per (ROI index, head) over every edge and graph it appeared in.
avg_attention = {
    roi: {
        head: roi_attention_sum[roi][head] / roi_attention_count[roi][head]
        for head in roi_attention_sum[roi]
    }
    for roi in roi_attention_sum
}

# Load the parcellation volume whose label values ROI node indices are mapped onto.
parcellation_nii = nib.load("aparc+aseg.nii.gz")
parcellation_data = parcellation_nii.get_fdata()

roi_labels = np.unique(parcellation_data)   # sorted unique label values
roi_labels = roi_labels[roi_labels != 0]    # drop label 0 (presumably background)

# ASSUMPTION (TODO confirm against the graph-building code): ROI node index i corresponds
# to the i-th smallest non-zero label value in this parcellation.
roi_label_to_index = {label: idx for idx, label in enumerate(roi_labels)}
index_to_label = {v: k for k, v in roi_label_to_index.items()}


headwise_label_attention = defaultdict(dict)  # head -> label -> score

skipped_roi_indices = []
for roi_idx, head_scores in avg_attention.items():
    if roi_idx not in index_to_label:
        skipped_roi_indices.append(roi_idx)
        continue
    label = index_to_label[roi_idx]
    for head_idx, score in head_scores.items():
        headwise_label_attention[head_idx][label] = score

# Surface unmatched indices instead of dropping them silently; they may indicate that this
# parcellation differs from the one the graphs were built from.
if skipped_roi_indices:
    warnings.warn(
        f"{len(skipped_roi_indices)} ROI node indices have no matching parcellation label "
        f"(only {len(roi_labels)} non-zero labels) and were skipped: {sorted(skipped_roi_indices)}"
    )


def _attention_volume(label_scores):
    """Return an array shaped like ``parcellation_data``: each label's voxels hold that label's score, all other voxels 0."""
    volume = np.zeros_like(parcellation_data)
    for lbl, val in label_scores.items():
        volume[parcellation_data == lbl] = val
    return volume


# Write one NIfTI map per attention head, using the parcellation's affine.
output_paths = []
for head_idx, label_score_map in headwise_label_attention.items():
    out_path = f"gat_attention_head{head_idx}.nii.gz"
    head_img = nib.Nifti1Image(_attention_volume(label_score_map), parcellation_nii.affine)
    nib.save(head_img, out_path)
    output_paths.append(out_path)

print("Saved attention maps to:")
for path in output_paths:
    print(path)


In [None]:
# Show the N highest-scoring parcellation labels for every attention head.
N = 5
for head_idx, label_scores in headwise_label_attention.items():
    print(f"\nTop {N} ROIs for Head {head_idx}:")
    ranked = sorted(label_scores.items(), key=lambda item: item[1], reverse=True)
    for label, score in ranked[:N]:
        print(f"  Label {label}: Score = {score:.4f}")


In [None]:
# Glass-brain heatmap of each per-head attention map, saved as a PNG next to its NIfTI file.
GLASS_BRAIN_THRESHOLD = 0.03  # nilearn hides values whose magnitude is below this

for path in output_paths:
    # Named `glass_display`, not `display`, so IPython's built-in display() still works
    # in later cells.
    glass_display = plotting.plot_glass_brain(
        path,
        threshold=GLASS_BRAIN_THRESHOLD,
        display_mode='lyrz',
        colorbar=True,
        plot_abs=False,
    )
    glass_display.savefig(path.replace(".nii.gz", "_glass.png"))