In [None]:
import ast
import h5py
import torch
import numpy as np
import pandas as pd
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch_geometric.data import DataLoader, Data
from torch_geometric.nn import GCNConv, MessagePassing

In [None]:
def convert_adjacency_to_tensors(edge_list):
    """Convert a list of ``[source, target]`` pairs into a ``(2, E)`` long tensor.

    Args:
        edge_list: Sequence of two-element sequences, one per edge.

    Returns:
        A contiguous ``torch.long`` tensor with sources in row 0 and targets
        in row 1. An empty ``edge_list`` yields an empty tensor of shape
        ``(2, 0)`` so callers can always unpack ``row, col = edge_index``.
    """
    pairs = torch.tensor(edge_list, dtype=torch.long)
    if pairs.numel() == 0:
        # Transposing an empty 1-D tensor would keep shape (0,), which is not
        # a valid edge index layout; return an explicit (2, 0) tensor instead.
        return torch.empty((2, 0), dtype=torch.long)
    return pairs.t().contiguous()

In [None]:
filepath = "brain_data_all_labels.h5"
with h5py.File(filepath, 'r') as f:
  # Read the 'brains' dataset from disk once; it is reused for both `data`
  # (numpy) and `inputs` (torch) below.
  data = np.array(f['brains'])
  labels = torch.Tensor(np.array(f['labels']))

# Replace NaNs in the numpy copy (infinities are left as-is in `data`).
data[np.isnan(data)] = 0

# Float tensor view of the same values with NaN and +/-inf all mapped to 0.
# torch.nan_to_num is out-of-place, so `data` itself is not modified here.
inputs = torch.nan_to_num(torch.Tensor(data), nan=0.0, posinf=0, neginf=0)

# Per-node metadata; later cells read the 'edgeList_embeddings' and
# 'network_embeddings' columns.
features = pd.read_csv("feature_embeddings.csv")

In [None]:
# Transposed, NaN-cleaned brain matrix as a float tensor.
node_features = torch.Tensor(data.T)

# Each CSV cell holds a Python-literal list of [source, target] pairs.
edge_list = features['edgeList_embeddings'].apply(ast.literal_eval)
max_length = max(len(lst) for lst in edge_list)

# Pad every edge list up to the longest one with [0, 0] pairs. The padding
# amount is derived from max_length (not a fixed 19), and the parsed lists in
# `edge_list` are left untouched so this cell can be re-run safely.
padded_edge_embeddings = [
    list(lst) + [[0, 0]] * (max_length - len(lst))
    for lst in edge_list
]

# Shift all indices by -1; the [0, 0] padding therefore becomes [-1, -1].
# NOTE(review): presumably the CSV stores 1-based node ids — confirm.
padded_edge_embeddings = torch.tensor(padded_edge_embeddings, dtype=torch.long) - 1

# One numeric vector per row of the CSV, stacked into a float tensor.
network_embeddings = features['network_embeddings'].apply(ast.literal_eval)
network_embeddings = torch.Tensor(network_embeddings.tolist())


In [None]:
# Prepare the dataset

# The edge index is identical for every sample, so build it once instead of
# once per iteration: flatten all padded edge lists into (2, E) and drop the
# columns whose source is -1 (the shifted padding entries).
edge_index = padded_edge_embeddings.view(-1, 2).t().contiguous().long()
edge_index = edge_index[:, edge_index[0] != -1]

# Derive the sample count from the data rather than hard-coding it, and fail
# early if features and labels disagree.
num_samples = node_features.shape[1]
assert num_samples == len(labels), (
    f"node_features has {num_samples} samples but labels has {len(labels)}"
)

data_list = []
for i in range(num_samples):
    # Column i of node_features becomes a single-row feature matrix.
    x = node_features[:, i].unsqueeze(0)
    y = labels[i].long()
    # Named `sample` (not `data`) so the numpy array `data` from the loading
    # cell is not overwritten by a Data object.
    sample = Data(x=x, edge_index=edge_index, network_features=network_embeddings, y=y)
    data_list.append(sample)

# Create a DataLoader
loader = DataLoader(data_list, batch_size=1, shuffle=True)

<class 'torch.Tensor'> torch.Size([400, 20821]) <class 'torch.Tensor'>
<class 'torch.Tensor'> torch.Size([400, 19, 2]) <class 'torch.Tensor'>
<class 'torch.Tensor'> torch.Size([400, 8]) <class 'torch.Tensor'>
<class 'torch.Tensor'> torch.Size([20821])


In [None]:
class BrainGNN(nn.Module):
    """MLP-based graph model producing one output vector per graph.

    Node features and network features are each embedded by a three-layer MLP
    and summed. For every edge, the fused embeddings of its two endpoints are
    concatenated and passed through ``message_layer``. The result is combined
    with the node embedding using the learnable scalars ``alpha`` and ``beta``
    and projected by ``output_mlp``; the rows are then averaged.

    Notes:
        * ``edge_mlp`` is constructed but not used in ``forward``. It is kept
          so that ``state_dict`` keys stay compatible with saved checkpoints.
        * Messages are NOT aggregated per node: ``h_messages`` has one row per
          edge. ``alpha * h_node + beta * h_messages`` therefore relies on
          broadcasting, i.e. ``h_node`` must have a single row or as many
          rows as there are edges.
    """

    def __init__(self, input_dim, network_dim, edge_dim, hidden_dim1, hidden_dim2, hidden_dim3, output_dim):
        """Build the sub-networks.

        Args:
            input_dim: Size of the last dimension of ``node_features``.
            network_dim: Size of the last dimension of ``network_features``.
            edge_dim: Input size of ``edge_mlp`` (unused in ``forward``).
            hidden_dim1: Width of the first hidden layer of each MLP.
            hidden_dim2: Width of the second hidden layer of each MLP.
            hidden_dim3: Embedding size produced by each MLP.
            output_dim: Size of the final output vector.
        """
        super().__init__()

        def create_mlp(layers_dims, dropout=0.5):
            """Stack Linear layers; ReLU + Dropout follow all but the last."""
            layers = []
            last = len(layers_dims) - 2
            for i in range(len(layers_dims) - 1):
                layers.append(nn.Linear(layers_dims[i], layers_dims[i + 1]))
                if i < last:  # No activation/dropout after the final layer
                    layers.append(nn.ReLU())
                    layers.append(nn.Dropout(p=dropout))
            return nn.Sequential(*layers)

        hidden_dims = [hidden_dim1, hidden_dim2, hidden_dim3]

        # Embedding MLPs (creation order is kept so seeded initialisation
        # draws random numbers in the same sequence as before).
        self.node_mlp = create_mlp([input_dim] + hidden_dims)
        self.network_mlp = create_mlp([network_dim] + hidden_dims)
        self.edge_mlp = create_mlp([edge_dim] + hidden_dims)

        # Per-edge transform of the concatenated endpoint embeddings
        self.message_layer = nn.Sequential(
            nn.Linear(2 * hidden_dim3, hidden_dim3),
            nn.ReLU(),
            nn.Dropout(p=0.5)
        )

        # Output layer
        self.output_mlp = nn.Linear(hidden_dim3, output_dim)

        # Learnable scalar weights for combining node and message embeddings
        self.alpha = nn.Parameter(torch.ones(1))
        self.beta = nn.Parameter(torch.ones(1))

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Xavier-uniform weights and a constant 0.01 bias for Linear layers."""
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0.01)

    def forward(self, node_features, network_features, edge_index, return_embeddings=False):
        """Run the model.

        Args:
            node_features: Tensor whose last dimension is ``input_dim``.
            network_features: Tensor whose last dimension is ``network_dim``;
                its embedding is added to the node embedding (broadcast).
            edge_index: Tensor that unpacks into two index tensors
                ``(row, col)`` used to index rows of the fused embedding.
            return_embeddings: If True, return the combined embeddings
                instead of the averaged output.

        Returns:
            ``h_updated`` (one row per edge, ``hidden_dim3`` columns) when
            ``return_embeddings`` is True; otherwise the mean over rows of
            ``output_mlp(h_updated)``, a vector of length ``output_dim``.
        """
        # Step 1: Embed the inputs
        h_node = self.node_mlp(node_features)
        h_network = self.network_mlp(network_features)

        # Step 2: Fuse node and network embeddings by (broadcast) addition
        h_fused = h_node + h_network

        # Step 3: Per-edge messages from both endpoints' fused embeddings
        row, col = edge_index
        messages = torch.cat([h_fused[row], h_fused[col]], dim=-1)
        h_messages = self.message_layer(messages)  # one row per edge, no aggregation

        # Step 4: Weighted combination of node embedding and edge messages
        h_updated = self.alpha * h_node + self.beta * h_messages

        if return_embeddings:
            # Skip the output head entirely; it is not needed for embeddings.
            return h_updated

        # Step 5: Project each row and average over rows
        edge_outputs = self.output_mlp(h_updated)
        return edge_outputs.mean(dim=0)




In [None]:
# Parameters
input_dim = 400        # Size of node features
network_dim = 8        # Size of network features
edge_dim = 19          # Input size of the model's edge MLP
hidden_dim1 = 256      # Width of the first hidden layer
hidden_dim2 = 512      # Width of the second hidden layer
hidden_dim3 = 128      # Size of the final embeddings
output_dim = 180       # Number of output logits (classes for CrossEntropyLoss below)
num_epochs = 3         # Number of training epochs
batch_size = 1         # Batch size (not referenced elsewhere in this cell)
learning_rate = 1e-3   # Learning rate
seed = 42              # Seed for torch's random number generators

# Seed before the model is built so weight initialisation (and any later
# draws from torch's global generator) are repeatable across runs.
torch.manual_seed(seed)

# Model, Loss, Optimizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = BrainGNN(input_dim, network_dim, edge_dim, hidden_dim1, hidden_dim2, hidden_dim3, output_dim).to(device)
criterion = nn.CrossEntropyLoss(label_smoothing=0.0)  # Classification loss over output_dim logits
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
print(model)

In [None]:
collected_loss = []  # loss of every optimisation step, across all epochs
# Training Loop
for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0.0  # sum (not mean) of the per-batch losses in this epoch
    for batch in loader:
        # Move data to device. Batch-local names are used on purpose so the
        # notebook-level `node_features` tensor is not overwritten.
        batch_x = batch.x.to(device)
        batch_network = batch.network_features.to(device)
        batch_edge_index = batch.edge_index.to(device)
        # Labels are shifted by -1 before being used as class indices.
        # NOTE(review): presumably the stored labels are 1-based — confirm.
        target = batch.y.to(device).long() - 1

        # Forward pass
        optimizer.zero_grad()
        prediction = model(batch_x, batch_network, batch_edge_index)

        # The model returns a single logit vector; add a batch dimension so
        # it matches the target passed to the criterion.
        loss = criterion(prediction.unsqueeze(0), target)

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        # Read the scalar once (each .item() call forces a device sync).
        loss_value = loss.item()
        collected_loss.append(loss_value)
        epoch_loss += loss_value

    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {epoch_loss:.4f}")

KeyboardInterrupt: 

In [None]:
# Persist only the model's state_dict (parameters and buffers), not the module object.
checkpoint_path = "brain_gnn_model_3epochs_weighted_features.pth"
torch.save(model.state_dict(), checkpoint_path)

In [None]:
import pickle

# Write the list of per-step training losses to disk.
loss_history_path = "collected_loss_3epochs_weighted_features.pkl"
with open(loss_history_path, "wb") as loss_file:
    pickle.dump(collected_loss, loss_file)

tensor([[-0.1179, -0.5957, -0.5841,  ..., -0.3794,  0.3671,  0.0172],
        [-0.1179, -0.5957, -0.5841,  ..., -0.3794,  0.3671,  0.0172],
        [-0.1179, -0.5957, -0.5841,  ..., -0.3794,  0.3671,  0.0172],
        ...,
        [ 0.3338, -0.3025, -0.5232,  ..., -0.0349, -0.1378,  0.5898],
        [ 0.3338, -0.3025, -0.5232,  ..., -0.0349, -0.1378,  0.5898],
        [ 0.3338, -0.3025, -0.5232,  ..., -0.0349, -0.1378,  0.5898]])

In [None]:
import os
# Create a directory to save embeddings. The same variable is used for
# creating the directory and for writing, so the two cannot drift apart.
embedding_dir = 'embeddings_3e_weighted'
os.makedirs(embedding_dir, exist_ok=True)

# Iterate the samples in their original order so that the index in each file
# name is the sample's position in `data_list` (the training loader shuffles).
ordered_loader = DataLoader(data_list, batch_size=1, shuffle=False)

# Switch to eval mode so the Dropout layers are disabled; otherwise the saved
# embeddings would be randomly perturbed.
model.eval()

# Disable gradient calculation
with torch.no_grad():
    for i, batch in enumerate(ordered_loader):
        # Batch-local names avoid overwriting the notebook-level
        # `node_features` tensor.
        batch_x = batch.x.to(device)
        batch_network = batch.network_features.to(device)
        batch_edge_index = batch.edge_index.to(device)
        target = batch.y.to(device).long() - 1

        # Forward pass to get embeddings
        h_updated = model(batch_x, batch_network, batch_edge_index, return_embeddings=True)
        embeddings_cpu = h_updated.detach().cpu()

        # Save each embedding to a separate file
        embedding_path = os.path.join(embedding_dir, f'embedding_number{i}_class{int(target[0])}.npy')
        np.save(embedding_path, embeddings_cpu.numpy())
        print(f"Embedding {i} saved to {embedding_path}")

print("All embeddings have been saved.")