In [1]:
import torch
import torch.nn.functional as F
from torch_geometric.data import Data, Dataset
from torch_geometric.loader import DataLoader
import numpy as np
from torch_geometric.nn import global_mean_pool
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import add_self_loops

# Hyperparameters
batch_size = 1
hidden_layer_size = 512
epochs = 200
learning_rate = 0.01
node_feature_dim = 116
edge_feature_dim = 3
num_classes = 4


In [2]:
ppmi = torch.load('../data/ppmi.pth')
connectivities = ppmi['matrix'].numpy()
labels = ppmi['label']
print(f'matrices: {connectivities.shape}, labels: {labels.shape}')
print(f'Connectivity matrices shape: {connectivities.shape}, Labels shape: {labels.shape}')
print(f'Connectivity matrices dtype: {connectivities.dtype}, Labels dtype: {labels.dtype}')
print(f'Connectivity matrices min: {connectivities.min()}, Labels min: {labels.min()}')
print(f'Connectivity matrices max: {connectivities.max()}, Labels max: {labels.max()}')

matrices: (209, 116, 116), labels: torch.Size([209])
Connectivity matrices shape: (209, 116, 116), Labels shape: torch.Size([209])
Connectivity matrices dtype: float64, Labels dtype: torch.int64
Connectivity matrices min: -0.7707168807098154, Labels min: 0
Connectivity matrices max: 1.0, Labels max: 3


In [3]:
class ConnectivityDataset(Dataset):
    def __init__(self, connectivities, labels):
        """
        Args:
            connectivities (numpy.ndarray): Array of connectivity matrices of shape [n_samples, 116, 116].
            labels (torch.Tensor): Tensor of labels of shape [n_samples,].
        """
        self.connectivities = connectivities
        self.labels = labels
        self.left_indices, self.right_indices = self.get_hemisphere_indices()

    def get_hemisphere_indices(self):
        left_indices = [i for i in range(node_feature_dim) if i % 2 == 0]
        right_indices = [i for i in range(node_feature_dim) if i % 2 != 0]
        return left_indices, right_indices

    def __len__(self):
        return len(self.connectivities)

    def __getitem__(self, idx):
        connectivity = self.connectivities[idx]
        label = self.labels[idx]

        inter_matrix = self.extract_interhemispherical_matrix(connectivity)
        intra_asym_matrix = self.extract_intrahemispherical_asymmetry_matrix(connectivity)
        homotopic_matrix = self.extract_homotopic_matrix(connectivity)
        combined_matrix = self.combine_feature_matrices(inter_matrix, intra_asym_matrix, homotopic_matrix)

        edge_index = []
        edge_attr = []
        for j in range(node_feature_dim):
            for k in range(node_feature_dim):
                if np.any(combined_matrix[j, k] != 0):
                    edge_index.append([j, k])
                    edge_attr.append(combined_matrix[j, k])

        edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
        edge_attr = torch.tensor(np.array(edge_attr), dtype=torch.float)
        x = torch.tensor(np.eye(node_feature_dim), dtype=torch.float)  # Node features as identity matrix
        y = label.clone().detach()  # Correctly define the label tensor as a single long tensor

        data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y)
        return data

    def extract_interhemispherical_matrix(self, connectivity):
        interhemispherical_matrix = np.zeros((58, 58))
        for i, li in enumerate(self.left_indices):
            for j, ri in enumerate(self.right_indices):
                interhemispherical_matrix[i, j] = connectivity[li, ri]
        return interhemispherical_matrix

    def extract_intrahemispherical_asymmetry_matrix(self, connectivity):
        intra_asymmetry_matrix = np.zeros((node_feature_dim, node_feature_dim))
        left_hemisphere = connectivity[np.ix_(self.left_indices, self.left_indices)]
        right_hemisphere = connectivity[np.ix_(self.right_indices, self.right_indices)]
        intra_asymmetry_matrix[np.ix_(self.left_indices, self.left_indices)] = np.abs(left_hemisphere - left_hemisphere.T)
        intra_asymmetry_matrix[np.ix_(self.right_indices, self.right_indices)] = np.abs(right_hemisphere - right_hemisphere.T)
        return intra_asymmetry_matrix

    def extract_homotopic_matrix(self, connectivity):
        homotopic_matrix = np.zeros((58, 58))
        for i, li in enumerate(self.left_indices):
            for j, ri in enumerate(self.right_indices):
                homotopic_matrix[i, j] = connectivity[li, ri]
        return homotopic_matrix

    def combine_feature_matrices(self, inter_matrix, intra_asym_matrix, homotopic_matrix):
        combined_matrix = np.zeros((node_feature_dim, node_feature_dim, edge_feature_dim))
        for i, li in enumerate(self.left_indices):
            for j, ri in enumerate(self.right_indices):
                combined_matrix[li, ri, 0] = inter_matrix[i, j]
                combined_matrix[li, ri, 2] = homotopic_matrix[i, j]
        combined_matrix[:, :, 1] = intra_asym_matrix
        return combined_matrix


In [4]:
class EdgeEnhancedGCNConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(EdgeEnhancedGCNConv, self).__init__(aggr='add')  # "Add" aggregation.
        self.lin = torch.nn.Linear(in_channels + edge_feature_dim, out_channels)  # Include edge features.

    def forward(self, x, edge_index, edge_attr):
        # x has shape [N, in_channels]
        # edge_index has shape [2, E]
        # edge_attr has shape [E, edge_feature_dim]

        # Step 1: Add self-loops to the adjacency matrix.
        num_nodes = x.size(0)
        edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)
        
        # Step 2: Add corresponding self-loop edge attributes.
        self_loop_attr = torch.zeros((num_nodes, edge_feature_dim), device=edge_attr.device)
        edge_attr = torch.cat([edge_attr, self_loop_attr], dim=0)

        # Step 3: Start propagating messages.
        return self.propagate(edge_index, x=x, edge_attr=edge_attr)

    def message(self, x_j, edge_attr):
        # x_j has shape [E, out_channels]
        # edge_attr has shape [E, edge_feature_dim]
        return torch.cat([x_j, edge_attr], dim=-1)

    def update(self, aggr_out):
        # aggr_out has shape [N, out_channels]
        return self.lin(aggr_out)  # Linearly transform the aggregated messages


class GCN(torch.nn.Module):
    def __init__(self, node_in_channels, hidden_channels, out_channels):
        super(GCN, self).__init__()
        self.conv1 = EdgeEnhancedGCNConv(node_in_channels, hidden_channels)
        self.conv2 = EdgeEnhancedGCNConv(hidden_channels, hidden_channels)
        self.lin = torch.nn.Linear(hidden_channels, out_channels)

    def forward(self, data):
        x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch
        x = self.conv1(x, edge_index, edge_attr)
        x = F.relu(x)
        x = self.conv2(x, edge_index, edge_attr)
        x = global_mean_pool(x, batch)  # Global pooling for graph-level classification
        x = self.lin(x)
        return F.log_softmax(x, dim=1)


In [5]:
def train_gcn(train_loader, val_loader=None, epochs=200, lr=0.01):
    model = GCN(node_in_channels=node_feature_dim, hidden_channels=hidden_layer_size, out_channels=num_classes)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    for epoch in range(epochs):
        model.train()
        for data in train_loader:
            optimizer.zero_grad()
            out = model(data)
            loss = F.nll_loss(out, data.y)
            loss.backward()
            optimizer.step()

        print(f'Epoch {epoch+1}, Loss: {loss.item()}')
        
        # Validation step
        if val_loader:
            model.eval()
            correct = 0
            for data in val_loader:
                out = model(data)
                pred = out.argmax(dim=1)
                correct += (pred == data.y).sum().item()
            accuracy = correct / len(val_loader.dataset)
            print(f'Validation Accuracy: {accuracy:.4f}')


dataset = ConnectivityDataset(connectivities, labels)
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Optionally create a validation set
val_size = int(0.1 * len(dataset))
train_size = len(dataset) - val_size
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size)

# Train the GCN model
train_gcn(train_loader, val_loader, epochs=epochs, lr=learning_rate)


Epoch 1, Loss: 0.3125596046447754
Validation Accuracy: 0.6500
Epoch 2, Loss: 4.732076168060303
Validation Accuracy: 0.5000
Epoch 3, Loss: 6.851142406463623
Validation Accuracy: 0.6500


KeyboardInterrupt: 

: 