In [86]:
import ast
import h5py
import torch
import numpy as np
import pandas as pd
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch_geometric.data import DataLoader, Data
from torch_geometric.nn import GCNConv, MessagePassing

In [87]:
def convert_adjacency_to_tensors(edge_list):
    """Convert a list of two-element index pairs into a ``(2, E)`` index tensor.

    Args:
        edge_list: Sequence of ``[a, b]`` integer pairs, one pair per edge.

    Returns:
        A contiguous ``torch.long`` tensor of shape ``(2, E)``: row 0 holds the
        first element of every pair, row 1 the second. An empty ``edge_list``
        yields an empty tensor of shape ``(2, 0)``.
    """
    if len(edge_list) == 0:
        # torch.tensor([]) is 1-D, and .t() leaves 1-D tensors unchanged, so the
        # result would have shape (0,) rather than the (2, 0) edge-index layout.
        return torch.empty((2, 0), dtype=torch.long)
    return torch.tensor(edge_list, dtype=torch.long).t().contiguous()

In [88]:
# Load the brain array and its labels from HDF5, plus the per-row feature table.
# NOTE(review): absolute, machine-specific path — consider a configurable data dir.
filepath = "/Users/zacharykaras/Desktop/brain_data_all_labels.h5"
with h5py.File(filepath, 'r') as f:
    # Materialise each dataset exactly once while the file is open.
    data = np.array(f['brains'])
    labels = torch.Tensor(np.array(f['labels']))

# `data` (NumPy): only NaNs are zeroed; infinities are left as they are.
data[np.isnan(data)] = 0

# `inputs` (float tensor): NaN, +inf and -inf are all replaced by 0.
# torch.nan_to_num is out-of-place, so `data` is not modified by this call.
inputs = torch.nan_to_num(torch.Tensor(data), nan=0.0, posinf=0, neginf=0)

features = pd.read_csv("../feature_embeddings.csv")

In [89]:

# Node features are the transpose of the cleaned `data` array from the loading cell.
# NOTE: `data` must still be that NumPy array; re-run the loading cell first if the
# name has since been rebound to something else.
node_features = torch.Tensor(data.T)

# Each CSV cell holds the Python-literal text of a list of [a, b] index pairs.
edge_list = features['edgeList_embeddings'].apply(ast.literal_eval)
max_length = max(len(lst) for lst in edge_list)

# Right-pad every edge list with [0, 0] placeholders up to the longest list.
# New lists are built (rather than extending in place) so `edge_list` keeps the
# parsed, unpadded values, and the pad width follows `max_length` instead of a
# hard-coded length.
padded_edge_embeddings = [
    list(lst) + [[0, 0]] * (max_length - len(lst))
    for lst in edge_list
]

# Shift every index down by one; the [0, 0] placeholders become [-1, -1], which
# is the sentinel used to recognise padding afterwards.
padded_edge_embeddings = torch.tensor(padded_edge_embeddings, dtype=torch.long) - 1

# Network embeddings are stored the same way (literal list per CSV cell).
network_embeddings = features['network_embeddings'].apply(ast.literal_eval)
network_embeddings = torch.Tensor(network_embeddings.tolist())

# TODO: not wired up yet — concatenating `node_features` with `network_embeddings`
# and building a single PyG `Data` object from the combined features.

In [90]:
# Sanity check: container type, shape and element type of each prepared tensor.
for prepared in (node_features, padded_edge_embeddings, network_embeddings):
    print(type(prepared), prepared.shape, type(prepared[0]))
# Labels are reported without an element type.
print(type(labels), labels.shape)

<class 'torch.Tensor'> torch.Size([400, 20821]) <class 'torch.Tensor'>
<class 'torch.Tensor'> torch.Size([400, 19, 2]) <class 'torch.Tensor'>
<class 'torch.Tensor'> torch.Size([400, 8]) <class 'torch.Tensor'>
<class 'torch.Tensor'> torch.Size([20821])


In [91]:
# Build one PyG `Data` object per column of `node_features` and wrap them in a loader.
# Uses `node_features`, `padded_edge_embeddings`, `network_embeddings` and `labels`
# from the cells above.

# The edge structure is identical for every sample, so it is built once:
# flatten the padded pairs to (2, E) and drop the columns whose first index is the
# -1 padding sentinel.
edge_index = padded_edge_embeddings.view(-1, 2).t().contiguous().long()
edge_index = edge_index[:, edge_index[0] != -1]

# One sample per column of `node_features`; labels must line up with the columns.
num_samples = node_features.shape[1]
assert labels.shape[0] == num_samples, (
    f"labels has {labels.shape[0]} entries but node_features has {num_samples} columns"
)

data_list = []
for i in range(num_samples):
    # Column i as a single row: shape (1, node_features.shape[0]).
    x = node_features[:, i].unsqueeze(0)
    network_features = network_embeddings
    y = labels[i].long()
    # Named `graph` (not `data`) so the NumPy array `data` from the loading cell
    # is not overwritten and earlier cells stay re-runnable.
    graph = Data(x=x, edge_index=edge_index, network_features=network_features, y=y)
    data_list.append(graph)

# Create a DataLoader that yields one graph per batch, in shuffled order.
# NOTE(review): `DataLoader` comes from torch_geometric.data here; newer PyG
# releases expose it from torch_geometric.loader — confirm for the installed version.
loader = DataLoader(data_list, batch_size=1, shuffle=True)



In [94]:
class BrainGNN(nn.Module):
    """Linear model that fuses node and network features and mixes them along edges.

    The forward pass projects node and network features into a shared hidden
    space, adds them, builds one message per edge from the two fused endpoint
    rows, adds those messages to the projected node features and maps the
    result to ``output_dim`` values averaged over the first dimension.

    Args:
        input_dim: Size of the last dimension of ``node_features``.
        network_dim: Size of the last dimension of ``network_features``.
        edge_dim: Input size of ``edge_mlp`` (that layer is not used by ``forward``).
        hidden_dim: Width of the shared hidden space.
        output_dim: Number of values produced per graph.
    """

    def __init__(self, input_dim, network_dim, edge_dim, hidden_dim, output_dim):
        super().__init__()

        # Input projections into the shared hidden space.
        self.node_mlp = nn.Linear(input_dim, hidden_dim)
        self.network_mlp = nn.Linear(network_dim, hidden_dim)
        # Defined but not called in forward(); the edge-feature path is disabled.
        self.edge_mlp = nn.Linear(edge_dim, hidden_dim)

        # Maps the concatenated endpoint features of an edge to one message.
        self.message_layer = nn.Linear(2 * hidden_dim, hidden_dim)

        # Final projection to the output space.
        self.output_mlp = nn.Linear(hidden_dim, output_dim)

    def forward(self, node_features, network_features, edge_index, return_embeddings=False):
        """Run the model on one graph.

        Args:
            node_features: Tensor whose last dimension is ``input_dim``.
            network_features: Tensor whose last dimension is ``network_dim``; its
                projection must be broadcastable against the projected node features.
            edge_index: Tensor that unpacks into exactly two index tensors, both
                used to index rows of the fused features.
            return_embeddings: When true, return the updated hidden
                representation instead of the averaged output.

        Returns:
            The ``(output_dim,)`` mean over the first dimension of the projected
            updated representation, or the updated representation itself
            (one row per edge) when ``return_embeddings`` is true.
        """
        # Project both inputs and fuse them by (broadcast) addition.
        node_hidden = self.node_mlp(node_features)
        fused = node_hidden + self.network_mlp(network_features)

        # One message per edge from the fused rows at both of its indices.
        first_idx, second_idx = edge_index
        endpoint_pairs = torch.cat([fused[first_idx], fused[second_idx]], dim=-1)
        edge_messages = self.message_layer(endpoint_pairs)

        # The messages have one row per edge, so `node_hidden` must be
        # broadcastable to that shape (e.g. a single row) for this sum to work.
        updated = node_hidden + edge_messages

        if return_embeddings:
            return updated

        # Project every row, then average over rows to get one vector per graph.
        return self.output_mlp(updated).mean(dim=0)

In [99]:
# Train BrainGNN on the graphs served by `loader` (built in an earlier cell,
# one graph per batch).

# --- Hyperparameters ---------------------------------------------------------
input_dim = 400        # last dimension of each graph's `x`
network_dim = 8        # last dimension of `network_features`
edge_dim = 19          # input size of the (currently unused) edge layer
hidden_dim = 512       # width of the hidden representation
output_dim = 180       # number of classes scored by the model
num_epochs = 50        # passes over `loader`
batch_size = 1         # not read in this cell; `loader` was built earlier
learning_rate = 1e-3   # Adam step size

# --- Model, loss, optimizer ----------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = BrainGNN(input_dim, network_dim, edge_dim, hidden_dim, output_dim).to(device)
criterion = nn.CrossEntropyLoss(label_smoothing=0.0)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# --- Training loop -------------------------------------------------------------
for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0.0
    for batch in loader:
        # NOTE: these names are rebound to the current batch's tensors and keep
        # the last batch's values after the loop; later cells read them.
        node_features = batch.x.to(device)
        network_features = batch.network_features.to(device)
        edge_index = batch.edge_index.to(device)
        # Stored labels are shifted down by one to obtain the class index.
        target = batch.y.to(device).long() - 1

        optimizer.zero_grad()
        prediction = model(node_features, network_features, edge_index)

        # The model returns one score vector per graph; add a leading batch
        # dimension so it pairs with `target`.
        loss = criterion(prediction.unsqueeze(0), target)
        loss.backward()
        optimizer.step()

        # Bookkeeping: summed (not averaged) loss and the arg-max class.
        epoch_loss += loss.item()
        predicted_class = torch.argmax(prediction, dim=0)
    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {epoch_loss:.4f}")

# Not implemented here: a validation pass (model.eval() under torch.no_grad()
# over a held-out loader) and saving the weights, e.g.
# torch.save(model.state_dict(), "brain_gnn_model.pth")

KeyboardInterrupt: 

In [19]:
# Recompute the model's intermediate tensors by hand, without gradient tracking.
# `node_features`, `network_features` and `edge_index` hold whatever the training
# loop assigned last (the final batch it processed).
with torch.no_grad():
    # Project both inputs into the hidden space and fuse them by addition.
    h_node = model.node_mlp(node_features)
    h_network = model.network_mlp(network_features)
    h_fused = h_node + h_network

    # Per-edge messages from the fused rows at both indices of each edge.
    row, col = edge_index
    row_feats = h_fused[row]
    col_feats = h_fused[col]
    messages = torch.cat((row_feats, col_feats), dim=-1)
    h_messages = model.message_layer(messages)

    # Same update the model applies in its forward pass.
    embeddings = h_node + h_messages

In [22]:
# Display the fused (node + network) hidden representation from the previous cell.
h_fused

tensor([[-0.1179, -0.5957, -0.5841,  ..., -0.3794,  0.3671,  0.0172],
        [-0.1179, -0.5957, -0.5841,  ..., -0.3794,  0.3671,  0.0172],
        [-0.1179, -0.5957, -0.5841,  ..., -0.3794,  0.3671,  0.0172],
        ...,
        [ 0.3338, -0.3025, -0.5232,  ..., -0.0349, -0.1378,  0.5898],
        [ 0.3338, -0.3025, -0.5232,  ..., -0.0349, -0.1378,  0.5898],
        [ 0.3338, -0.3025, -0.5232,  ..., -0.0349, -0.1378,  0.5898]])