In [5]:
import os
import pandas as pd
import torch
import torch.nn.functional as F
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import accuracy_score
from tqdm import tqdm
import numpy as np
from nilearn.connectome import ConnectivityMeasure
import nibabel as nib
import networkx as nx


In [8]:
# Directory containing the pconn files
directory = "/home/tico/Desktop/master_classes/project/BSNIP/pconn"
# os.listdir() returns entries in arbitrary, filesystem-dependent order; sorting
# keeps the dataset order (and hence the seeded train/test split) reproducible.
pconn_files = sorted(f for f in os.listdir(directory) if f.endswith('.pconn.nii'))
std = 2  # std-deviation multiplier passed to create_threshold_graph
behavior_path = '/home/tico/Desktop/master_classes/project/behavior/'
behavior_files = sorted(os.listdir(behavior_path))

# Load behavior data: read every tab-separated file, then concatenate once
# (concatenating inside the loop re-copies the accumulated frame each time).
behavior_frames = [
    pd.read_csv(os.path.join(behavior_path, behavior_file), sep='\t')
    for behavior_file in behavior_files
]
behavior_source = pd.concat(behavior_frames, axis=0, ignore_index=True)
# .copy() so the label assignment below writes to an independent frame rather
# than to a slice of the concatenated one (avoids SettingWithCopyWarning).
behavior_source = behavior_source[["session_id", "Group"]].copy()

# Encode labels: map each distinct Group value to an integer class id
label_encoder = LabelEncoder()
behavior_source['Group'] = label_encoder.fit_transform(behavior_source['Group'])

def load_fc_matrix(file_path):
    """Read a .pconn.nii file and return its functional connectivity data.

    The file is opened with nibabel and the array produced by the image's
    ``get_fdata()`` method is returned unchanged.
    """
    return nib.load(file_path).get_fdata()

def create_threshold_graph(fc_matrix, std_multiplier=2):
    """
    Build a thresholded, undirected graph from a functional connectivity matrix.

    Two distinct nodes i and j are connected when ``abs(fc_matrix[i, j])`` is
    strictly greater than ``std_multiplier`` times the standard deviation of the
    absolute values of the whole matrix. Self-loops are never created.

    Parameters
    ----------
    fc_matrix : array-like
        Square (n, n) connectivity matrix. Row i is used as the feature vector
        of node i.
    std_multiplier : float, default 2
        Multiplier applied to the standard deviation to obtain the threshold.

    Returns
    -------
    edge_index : torch.LongTensor of shape (2, 2 * E)
        Edge list with every undirected edge stored in BOTH directions
        (i -> j and j -> i), so message passing is symmetric. When no
        connection passes the threshold the shape is (2, 0).
    x : torch.FloatTensor of shape (n, n)
        Node feature matrix (the connectivity matrix as float32).
    """
    fc_matrix = np.asarray(fc_matrix)

    # Threshold is std_multiplier times the std of the absolute connectivity values
    strength = np.abs(fc_matrix)
    threshold = std_multiplier * np.std(strength)

    # Strict upper triangle (k=1): skips the diagonal (self-loops) and visits
    # each unordered pair exactly once, in row-major (i, j) order with i < j.
    rows, cols = np.nonzero(np.triu(strength > threshold, k=1))

    # Always a (2, E) int64 tensor, including the E == 0 case.
    upper = torch.tensor(np.vstack((rows, cols)), dtype=torch.long)
    # Append the reversed pairs so each undirected edge appears as i->j and j->i.
    edge_index = torch.cat((upper, upper.flip(0)), dim=1).contiguous()

    # Node features: row i of the connectivity matrix describes node i
    x = torch.tensor(fc_matrix, dtype=torch.float)

    return edge_index, x

# Build a session_id -> encoded label lookup once, instead of scanning the whole
# behavior table for every file. keep='first' matches the previous behaviour of
# taking the first matching row when a session_id appears more than once.
label_by_session = (
    behavior_source
    .drop_duplicates(subset='session_id', keep='first')
    .set_index('session_id')['Group']
)

# Prepare a list to store the results (one graph per .pconn.nii file)
data_list = []

for file_name in tqdm(pconn_files, desc="Processing .pconn.nii files"):
    fc_file_path = os.path.join(directory, file_name)
    # The session id is the file name without the '.pconn.nii' suffix
    session_id = file_name[:-len('.pconn.nii')]
    if session_id not in label_by_session.index:
        # Fail with the offending session named rather than an opaque IndexError
        raise KeyError(
            f"No behavior record found for session '{session_id}' (file '{file_name}')"
        )
    label = label_by_session.loc[session_id]
    fc_matrix = load_fc_matrix(fc_file_path)
    edge_index, x = create_threshold_graph(fc_matrix, std)
    # y holds a single class id so that batching yields one label per graph
    data = Data(x=x, edge_index=edge_index, y=torch.tensor([label], dtype=torch.long))
    data_list.append(data)

# Split data into training and testing sets (fixed seed for a repeatable split)
train_data, test_data = train_test_split(data_list, test_size=0.2, random_state=42)

# Create DataLoader for training and testing
train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
test_loader = DataLoader(test_data, batch_size=32, shuffle=False)

Processing .pconn.nii files:   2%|▏         | 12/638 [00:05<05:16,  1.98it/s]

In [None]:
class GNN(torch.nn.Module):
    """Graph classifier built from two GCNConv layers and a linear head.

    Node features go through ``conv1``, a ReLU, and ``conv2``; the resulting
    node embeddings are averaged per graph with ``global_mean_pool`` and mapped
    by ``fc1`` to ``output_dim`` scores, returned as log-probabilities.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.fc1 = torch.nn.Linear(hidden_dim, output_dim)

    def forward(self, x, edge_index, batch):
        """Return per-graph log-probabilities (log_softmax over dim 1)."""
        hidden = F.relu(self.conv1(x, edge_index))
        node_embeddings = self.conv2(hidden, edge_index)
        # `batch` assigns each node to its graph for the per-graph mean
        graph_embeddings = global_mean_pool(node_embeddings, batch)
        scores = self.fc1(graph_embeddings)
        return F.log_softmax(scores, dim=1)

# Seed PyTorch so weight initialisation and DataLoader shuffling are repeatable
# (same seed value as the train/test split).
torch.manual_seed(42)

# Check input dimension: number of features per node in the first graph
input_dim = data_list[0].x.size(1)

# Initialize the GNN model, optimizer, and loss function
model = GNN(input_dim=input_dim, hidden_dim=64, output_dim=len(label_encoder.classes_))
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
# GNN.forward already returns log-probabilities (log_softmax), so the matching
# loss is NLLLoss; CrossEntropyLoss expects raw logits and would apply
# log_softmax a second time.
criterion = torch.nn.NLLLoss()

# Training function
def train(model, loader, optimizer, criterion):
    """Run one optimisation pass over every batch produced by ``loader``.

    The model is put in training mode; for each batch the gradients are reset,
    the loss between the model output and ``batch.y`` is back-propagated, and
    the optimizer takes one step. Nothing is returned.
    """
    model.train()
    for graph_batch in loader:
        optimizer.zero_grad()
        predictions = model(graph_batch.x, graph_batch.edge_index, graph_batch.batch)
        criterion(predictions, graph_batch.y).backward()
        optimizer.step()

# Testing function
def test(model, loader):
    model.eval()
    correct = 0
    for data in loader:
        out = model(data.x, data.edge_index, data.batch)
        pred = out.argmax(dim=1)
        correct += (pred == data.y).sum().item()
    return correct / len(loader.dataset)

# Training loop: 20 epochs, each one a full training pass followed by an
# accuracy report on the training split and then the test split.
for epoch in range(20):
    train(model, train_loader, optimizer, criterion)
    train_acc, test_acc = test(model, train_loader), test(model, test_loader)
    print(f'Epoch {epoch+1}, Train Accuracy: {train_acc:.4f}, Test Accuracy: {test_acc:.4f}')

# Final evaluation on the held-out test set after training has finished
test_accuracy = test(model, test_loader)
print(f'Test Accuracy: {test_accuracy:.4f}')