In [7]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import f1_score
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl
import os
import pickle
from typing import List, Tuple, Dict
import logging
from tqdm import tqdm
from torch_geometric.nn import GCNConv
import torch_geometric
from joblib import Parallel, delayed


In [8]:
# Compute device used for the tensors loaded in this notebook
# (the Trainer further down is likewise pinned to accelerator="cpu").
device = torch.device("cpu")

In [9]:
## Loading node embedding data
# map_location: load straight onto `device`, so a tensor saved from a GPU
#   session still loads on a CPU-only machine.
# weights_only=True: only tensors / primitive containers are unpickled,
#   which avoids executing arbitrary pickled code.
node_embeddings = torch.load(
    './Graph Outputs/node_embeddings_initial.pt',
    map_location=device,
    weights_only=True,
)
# Later cells index rows by node id and read .size(1) as the embedding width.
assert node_embeddings.dim() == 2, f"expected a 2-D embedding table, got shape {tuple(node_embeddings.shape)}"

  node_embeddings = torch.load('./Graph Outputs/node_embeddings_initial.pt')


In [10]:
## Load Patient Subgraph Data
# SECURITY: pickle.load can execute arbitrary code embedded in the file —
# only point this at subgraph files produced by this project's own pipeline.
def _unpickle(path):
    """Open ``path`` in binary mode and return the single pickled object it holds."""
    with open(path, 'rb') as handle:
        return pickle.load(handle)


train_subgraphs, val_subgraphs = (
    _unpickle(path)
    for path in (
        './Graph Outputs/train_pg_subgraph.pkl',
        './Graph Outputs/val_pg_subgraph.pkl',
    )
)

In [11]:
def attach_node_features(subgraphs, embeddings):
    """Set ``subgraph.x`` from the global embedding table for every subgraph, in place.

    The keys of ``subgraph.node_mapping`` are cast with ``int`` and used as
    row indices into ``embeddings``; the selected rows become ``subgraph.x``
    and ``node_mapping`` is then removed from the object.

    Subgraphs that no longer carry ``node_mapping`` (i.e. already processed
    by an earlier run of this cell) are skipped, so re-running the cell no
    longer raises ``AttributeError``.

    NOTE(review): row i of ``x`` follows the iteration order of
    ``node_mapping``'s keys. This assumes that order matches the local node
    numbering used by ``edge_index`` — confirm against the code that built
    the subgraphs.
    """
    for subgraph in subgraphs:
        node_mapping = getattr(subgraph, "node_mapping", None)
        if node_mapping is None:
            continue  # already processed on a previous run
        global_node_ids = [int(key) for key in node_mapping.keys()]
        subgraph.x = embeddings[global_node_ids]
        # Remove node_mapping from the Data object once consumed
        # (presumably so batching does not have to collate a dict).
        del subgraph.node_mapping


attach_node_features(train_subgraphs, node_embeddings)
attach_node_features(val_subgraphs, node_embeddings)

In [21]:
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from torch_geometric.nn import GCNConv, global_mean_pool

class PopulationLevelGraph(pl.LightningModule):
    """Learn one representation per patient from that patient's subgraph.

    Each patient is a small graph whose node features arrive in ``data.x``.
    Two GCN layers process the subgraph, node features are mean-pooled into
    one vector per graph, and an MLP maps that vector to the patient
    representation. A dense patient x patient "population graph" is derived
    from pairwise distances between representations in the same batch.

    Args:
        embedding_dim: Width of the input node features and of the output
            patient representation.
        num_patients: Number of patients; stored on the module but not used
            in any computation below.
        node_embeddings: Global node-embedding table. Stored as a plain
            attribute (not a parameter or buffer), so it is not trained, not
            included in ``state_dict()``, and not moved by ``.to()``.
        sparsity_weight: Multiplier applied to the summed adjacency matrix in
            the sparsity penalty (default ``1e-3``).
        lr: Adam learning rate (default ``1e-3``).
    """

    def __init__(self, embedding_dim, num_patients, node_embeddings,
                 sparsity_weight=1e-3, lr=1e-3):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.num_patients = num_patients
        self.node_embeddings = node_embeddings
        self.sparsity_weight = sparsity_weight
        self.lr = lr

        # GNN layers to process patient-specific subgraphs
        self.conv1 = GCNConv(embedding_dim, 128)
        self.conv2 = GCNConv(128, embedding_dim)

        # MLP to learn patient representations
        self.mlp = nn.Sequential(
            nn.Linear(embedding_dim, 128),
            nn.ReLU(),
            nn.Linear(128, embedding_dim),
        )

    def forward(self, data):
        """Return one representation per graph in ``data``.

        Args:
            data: PyG ``Data``/``Batch`` with ``x`` (node features),
                ``edge_index`` and, for batches, ``batch`` (graph id per node).
                When ``batch`` is ``None`` all nodes are pooled into one graph.

        Returns:
            Tensor of shape ``(num_graphs, embedding_dim)``.
        """
        # Use the device the module's weights live on. ``node_embeddings`` is
        # a plain attribute that Lightning never moves, so on an accelerator
        # its device would differ from the weights' device.
        device = self.device
        x = data.x.to(device)
        edge_index = data.edge_index.to(device)
        batch = data.batch.to(device) if data.batch is not None else None

        # Apply GNN layers
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index)

        # Aggregate node features to create graph-level representation
        patient_rep = global_mean_pool(x, batch)
        return self.mlp(patient_rep)

    def compute_population_graph(self, patient_representations):
        """Return the dense ``(B, B)`` affinity matrix ``sigmoid(-||r_i - r_j||_2)``.

        Distances are non-negative, so entries lie in (0, 0.5] and the
        diagonal is exactly 0.5.
        """
        return torch.sigmoid(-torch.cdist(patient_representations, patient_representations, p=2))

    def compute_similarity_loss(self, patient_representations, latent_adj_matrix=None):
        """MSE between the distance-based adjacency and ``sigmoid`` of dot-product similarity.

        Args:
            patient_representations: ``(B, D)`` representations.
            latent_adj_matrix: Optional precomputed output of
                :meth:`compute_population_graph` for the same representations;
                computed here when omitted.
        """
        if latent_adj_matrix is None:
            latent_adj_matrix = self.compute_population_graph(patient_representations)
        similarity_matrix = torch.mm(patient_representations, patient_representations.t())
        return F.mse_loss(latent_adj_matrix, torch.sigmoid(similarity_matrix))

    def compute_sparsity_loss(self, latent_adj_matrix):
        """Penalty ``sparsity_weight * sum(A)`` that pushes adjacency entries toward zero.

        The sum runs over the full ``B x B`` matrix, so the penalty grows with
        the square of the batch size (a smaller final batch is penalized less).
        """
        return latent_adj_matrix.sum() * self.sparsity_weight

    def _shared_step(self, batch, log_name, prog_bar=False):
        """Compute the total (similarity + sparsity) loss for a batch and log it under ``log_name``."""
        batch = batch.to(self.device)
        patient_representations = self(batch)
        # Build the adjacency once and reuse it for both loss terms.
        latent_adj_matrix = self.compute_population_graph(patient_representations)
        similarity_loss = self.compute_similarity_loss(patient_representations, latent_adj_matrix)
        sparsity_loss = self.compute_sparsity_loss(latent_adj_matrix)
        loss = similarity_loss + sparsity_loss
        # Explicit batch_size: Lightning cannot reliably infer it from a PyG Batch.
        self.log(log_name, loss, prog_bar=prog_bar, batch_size=patient_representations.size(0))
        return loss

    def training_step(self, batch, batch_idx):
        """Return the training loss for one batch (logged as ``train_loss``)."""
        return self._shared_step(batch, "train_loss")

    def validation_step(self, batch, batch_idx):
        """Return the validation loss for one batch (logged as ``val_loss`` on the progress bar)."""
        return self._shared_step(batch, "val_loss", prog_bar=True)

    def configure_optimizers(self):
        """Adam over all trainable parameters with learning rate ``self.lr``."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)




In [22]:
from torch_geometric.loader import DataLoader

# Run configuration.
SEED = 42
BATCH_SIZE = 128
MAX_EPOCHS = 10

# Seed Python, NumPy and torch RNGs (and dataloader workers) before building
# the model and loaders, so weight init and the shuffled training order are
# reproducible under Restart & Run All.
pl.seed_everything(SEED, workers=True)

train_loader = DataLoader(
    train_subgraphs,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

val_loader = DataLoader(
    val_subgraphs,
    batch_size=BATCH_SIZE,
    shuffle=False,
)

model = PopulationLevelGraph(
    embedding_dim=node_embeddings.size(1),
    num_patients=len(train_subgraphs),
    node_embeddings=node_embeddings,
)

trainer = pl.Trainer(max_epochs=MAX_EPOCHS, accelerator="cpu", devices='auto')
trainer.fit(model, train_loader, val_loader)

GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name  | Type       | Params | Mode 
---------------------------------------------
0 | conv1 | GCNConv    | 32.9 K | train
1 | conv2 | GCNConv    | 33.0 K | train
2 | mlp   | Sequential | 65.9 K | train
---------------------------------------------
131 K     Trainable params
0         Non-trainable params
131 K     Total params
0.527     Total estimated model params size (MB)
10        Modules in train mode
0         Modules in eval mode


Training: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=10` reached.


In [14]:
# # Alternative configuration kept for reference (all lines commented out):
# # GPU accelerator, 50 epochs. Note the constructor requires node_embeddings.
# # Initialize model
# num_patients = len(train_subgraphs)  # Each subgraph corresponds to a patient
# embedding_dim = node_embeddings.size(1)
# model = PopulationLevelGraph(embedding_dim, num_patients, node_embeddings)

# # Define trainer
# from pytorch_lightning import Trainer
# trainer = Trainer(max_epochs=50, accelerator="gpu", devices="auto")

# # Train the model
# trainer.fit(model, train_loader, val_loader)
