In [124]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import f1_score
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl
import os
import pickle
from typing import List, Tuple, Dict
import logging
from tqdm import tqdm
from torch_geometric.nn import GCNConv
import torch_geometric
from joblib import Parallel, delayed


In [125]:
## Define the model
class PopulationGraphModel(pl.LightningModule):
    def __init__(self, num_nodes, embedding_dim=256):  # Fixed parameter definition
        super().__init__()
        
        # Save parameters
        self.num_nodes = num_nodes
        self.embedding_dim = embedding_dim
        
        # Node embedding layer
        self.node_embedding = nn.Embedding(num_nodes, embedding_dim)
        
        # GNN layers
        self.conv1 = GCNConv(embedding_dim, 512)
        self.conv2 = GCNConv(512, embedding_dim)
        
        ## Set up model parameters
        self.encoder = nn.Sequential(
            nn.Linear(embedding_dim * 2, 512),
            nn.ReLU(),
            nn.BatchNorm1d(512),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.BatchNorm1d(256),
            nn.Dropout(0.3)
        )
        
        self.classifier = nn.Sequential(
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.BatchNorm1d(128),
            nn.Dropout(0.3),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )
        
        # Metrics
        self.train_losses = []
        self.val_losses = []
        self.train_f1s = []
        self.val_f1s = []
        
    def forward(self, x1: torch.Tensor, x2: torch.Tensor, edge_index: torch.Tensor = None) -> torch.Tensor:
        
        if edge_index is not None:
            # Apply GNN layers if we have graph structure
            x = self.node_embedding(torch.arange(self.hparams.num_nodes).to(x1.device))
            x = self.conv1(x, edge_index).relu()
            x = self.conv2(x, edge_index)
            
            # Get relevant node embeddings
            x1_emb = x[x1]
            x2_emb = x[x2]
        else:
            # Direct embedding lookup if no graph structure
            x1_emb = self.node_embedding(x1)
            x2_emb = self.node_embedding(x2)
        
        # Combine embeddings
        combined = torch.cat([x1_emb, x2_emb], dim=1)
        features = self.encoder(combined)
        return self.classifier(features)
    
    def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], 
                     batch_idx: int) -> torch.Tensor:
        x1, x2, y = batch
        y_hat = self(x1, x2)
        loss = nn.BCELoss()(y_hat, y)
        
        # Calculate F1 score
        with torch.no_grad():
            predictions = (y_hat > 0.5).float()
            f1 = f1_score(y.cpu().numpy(), predictions.cpu().numpy())
        
        # Log metrics
        self.log('train_loss', loss, prog_bar=True)
        self.log('train_f1', f1, prog_bar=True)
        self.train_losses.append(loss.item())
        self.train_f1s.append(f1)
        
        return loss
    
    def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], 
                       batch_idx: int) -> Dict:
        x1, x2, y = batch
        y_hat = self(x1, x2)
        loss = nn.BCELoss()(y_hat, y)
        
        # Calculate F1 score
        predictions = (y_hat > 0.5).float()
        f1 = f1_score(y.cpu().numpy(), predictions.cpu().numpy())
        
        # Log metrics
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_f1', f1, prog_bar=True)
        self.val_losses.append(loss.item())
        self.val_f1s.append(f1)
        
        return {'val_loss': loss, 'val_f1': f1}
    
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=0.001, weight_decay=1e-5)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode='min',
            factor=0.5,
            patience=10,
            verbose=True
        )
        
        return {
            'optimizer': optimizer,
            'lr_scheduler': scheduler,
            'monitor': 'val_loss'
        }

In [126]:
def prepare_data(patient_pheno_lists: List, 
                val_split: float = 0.2) -> Tuple[TensorDataset, TensorDataset]:
    """
    Prepare training and validation datasets
    """
    n_patients = len(patient_pheno_lists)
    
    # Create pairs and labels
    pairs = []
    labels = []
    for i in range(n_patients):
        for j in range(i+1, min(i+100, n_patients)):
            pairs.append((i, j))
            shared = len(set(patient_pheno_lists[i]).intersection(
                set(patient_pheno_lists[j]))) > 0
            labels.append(float(shared))
    
    pairs = torch.tensor(pairs)
    labels = torch.tensor(labels)
    
    # Split train/val
    n_samples = len(pairs)
    n_val = int(n_samples * val_split)
    indices = torch.randperm(n_samples)
    
    train_indices = indices[n_val:]
    val_indices = indices[:n_val]
    
    # Create datasets
    train_data = TensorDataset(
        pairs[train_indices, 0],
        pairs[train_indices, 1],
        labels[train_indices].unsqueeze(1)
    )
    
    val_data = TensorDataset(
        pairs[val_indices, 0],
        pairs[val_indices, 1],
        labels[val_indices].unsqueeze(1)
    )
    
    logger.info(f"Created {len(train_data)} training samples and {len(val_data)} validation samples")
    return train_data, val_data


In [127]:
def plot_metrics(trainer: pl.Trainer, save_dir: str = './outputs'):
    """
    Plot training metrics
    """
    plt.figure(figsize=(15, 5))
    
    # Plot losses
    plt.subplot(121)
    plt.plot(trainer.model.train_losses, label='Train Loss')
    plt.plot(trainer.model.val_losses, label='Val Loss')
    plt.title('Loss Over Time')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    
    # Plot F1 scores
    plt.subplot(122)
    plt.plot(trainer.model.train_f1s, label='Train F1')
    plt.plot(trainer.model.val_f1s, label='Val F1')
    plt.title('F1 Score Over Time')
    plt.xlabel('Epoch')
    plt.ylabel('F1 Score')
    plt.legend()
    
    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, 'training_metrics.png'))
    plt.close()

In [128]:
def train_model(num_nodes: int,
               patient_pheno_lists: List,
               save_dir: str = './outputs',
               max_epochs: int = 100) -> Tuple[PopulationGraphModel, pl.Trainer]:
    """
    Train the population graph model
    """
    os.makedirs(save_dir, exist_ok=True)
    
    # Prepare data
    train_data, val_data = prepare_data(patient_pheno_lists)
    
    # Create dataloaders
    train_loader = DataLoader(train_data, batch_size=4096, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_data, batch_size=4096, num_workers=4)
    
    # Initialize model
    model = PopulationGraphModel(num_nodes=num_nodes)
    
    # Initialize trainer
    trainer = pl.Trainer(
        max_epochs=max_epochs,
        accelerator='gpu',
        callbacks=[
            pl.callbacks.EarlyStopping(
                monitor='val_loss',
                patience=20,
                mode='min'
            ),
            pl.callbacks.ModelCheckpoint(
                dirpath=save_dir,
                filename='best_model',
                monitor='val_loss',
                mode='min'
            )
        ],
        logger=True
    )
    
    # Train model
    trainer.fit(model, train_loader, val_loader)
    
    # Plot and save metrics
    plot_metrics(trainer, save_dir)
    
    return model, trainer

In [129]:
def process_batch(i, batch_start, batch_end, model, device):
    """Process a batch of adjacency matrix rows in parallel."""
    pairs_i = torch.tensor([i] * (batch_end - batch_start)).to(device)
    pairs_j = torch.arange(batch_start, batch_end).to(device)
    preds = model(pairs_i, pairs_j).cpu()  # Move results back to CPU to save memory
    results = []
    for idx, k in enumerate(range(batch_start, batch_end)):
        results.append((i, k, preds[idx].item()))
    return results

In [130]:
def generate_adjacency_matrix_parallel(model, patient_pheno_lists, batch_size=1000, num_jobs=-1, save_path=None):
    """
    Generate adjacency matrix for a population graph using parallel processing.

    Args:
        model: Trained PyTorch model for inference.
        patient_pheno_lists: List of patient phenotypes.
        batch_size: Number of samples to process in a batch.
        num_jobs: Number of parallel jobs for processing (-1 uses all available cores).
        save_path: Path to save the adjacency matrix incrementally (optional).

    Returns:
        adj_matrix: Generated adjacency matrix as a torch tensor.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device).eval()  # Move model to device and set to eval mode

    n_patients = len(patient_pheno_lists)
    adj_matrix = torch.zeros((n_patients, n_patients))

    with torch.no_grad():
        for i in tqdm(range(n_patients), desc="Processing patients"):
            # Use joblib to parallelize the batch processing
            results = Parallel(n_jobs=num_jobs)(
                delayed(process_batch)(i, j, min(j + batch_size, n_patients), model, device)
                for j in range(i + 1, n_patients, batch_size)
            )

            # Update adjacency matrix from results
            for batch_results in results:
                for i, k, value in batch_results:
                    adj_matrix[i, k] = value
                    adj_matrix[k, i] = value

            # Save progress incrementally if a save_path is provided
            if save_path and i % 100 == 0:
                torch.save(adj_matrix, save_path)

    # Save the final matrix if save_path is provided
    if save_path:
        torch.save(adj_matrix, save_path)

    return adj_matrix

In [131]:
def generate_adjacency_matrix(model: PopulationGraphModel,
                            patient_pheno_lists: List) -> torch.Tensor:
    """
    Generate adjacency matrix for population graph
    """
    model.eval()
    with torch.no_grad():
        n_patients = len(patient_pheno_lists)
        adj_matrix = torch.zeros((n_patients, n_patients))
        
        # Process in batches for memory efficiency
        batch_size = 100000
        for i in tqdm(range(n_patients)):
            for j in range(i+1, n_patients, batch_size):
                batch_end = min(j + batch_size, n_patients)
                
                # Create pairs for this batch
                pairs_i = torch.tensor(list(range(i, i+1))).repeat(batch_end - j)
                pairs_j = torch.tensor(list(range(j, batch_end)))
                
                # Get predictions
                preds = model(pairs_i, pairs_j)
                
                # Fill adjacency matrix
                for idx, k in enumerate(range(j, batch_end)):
                    adj_matrix[i,k] = preds[idx]
                    adj_matrix[k,i] = preds[idx]
    
    return adj_matrix

In [132]:
def main():
    # Set random seeds
    torch.manual_seed(42)
    np.random.seed(42)
    
    # Create output directory
    output_dir = './Output'
    os.makedirs(output_dir, exist_ok=True)
    
    # Load preprocessed data
    with open(f'{output_dir}/train_patients_phenotypes_list.pkl', 'rb') as f:
        train_patients_phenotypes_list = pickle.load(f)
        
    with open(f'{output_dir}/val_patients_phenotypes_list.pkl', 'rb') as f:
        val_patients_phenotypes_list = pickle.load(f)
    
    # Get max node index instead of counting unique phenotypes
    max_node_idx = max(
        max(max(phenos) if phenos else 0 for phenos in train_patients_phenotypes_list),
        max(max(phenos) if phenos else 0 for phenos in val_patients_phenotypes_list)
    )
    num_nodes = max_node_idx + 1  # Add 1 because indices are 0-based
    
    print(f"Max node index: {max_node_idx}")
    print(f"Number of nodes: {num_nodes}")
    
    # Train model with correct number of nodes
    model, trainer = train_model(num_nodes, train_patients_phenotypes_list)
    
    # Save model
    torch.save(model.state_dict(), os.path.join(output_dir, 'final_model.pt'))
    
    # Generate and save adjacency matrices for train and val sets
    train_adj_matrix = generate_adjacency_matrix(model, train_patients_phenotypes_list)
    val_adj_matrix = generate_adjacency_matrix(model, val_patients_phenotypes_list)
    
    torch.save(train_adj_matrix, os.path.join(output_dir, 'train_adjacency_matrix.pt'))
    torch.save(val_adj_matrix, os.path.join(output_dir, 'val_adjacency_matrix.pt'))

if __name__ == "__main__":
    main()

Max node index: 70272
Number of nodes: 70273


INFO:__main__:Created 2864981 training samples and 716245 validation samples
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
/home/kai/anaconda3/lib/python3.11/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:654: Checkpoint directory /kai/Kai_Backup/Study/GiG in rare diease detection/outputs exists and is not empty.
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name           | Type       | Params | Mode 
------------------------------------------------------
0 | node_embedding | Embedding  | 18.0 M | train
1 | conv1          | GCNConv    | 131 K  | train
2 | conv2          | GCNConv    | 131 K  | train
3 | encoder        | Sequential | 395 K  | train
4 | classifier     | Sequential | 33.3 K

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

100%|██████████| 36224/36224 [3:15:13<00:00,  3.09it/s]  
100%|██████████| 6400/6400 [06:08<00:00, 17.38it/s] 


In [3]:
import torch 
file_path = './Output/train_adjacency_matrix.pt'
data = torch.load(file_path)

print(data)

  data = torch.load(file_path)


tensor([[0.0000e+00, 2.4411e-01, 3.8471e-01,  ..., 8.0281e-01, 2.0837e-04,
         1.1505e-04],
        [2.4411e-01, 0.0000e+00, 1.3909e-02,  ..., 3.6500e-01, 9.9997e-01,
         9.9706e-01],
        [3.8471e-01, 1.3909e-02, 0.0000e+00,  ..., 1.1961e-03, 1.4693e-04,
         1.4688e-04],
        ...,
        [8.0281e-01, 3.6500e-01, 1.1961e-03,  ..., 0.0000e+00, 1.8188e-03,
         3.3910e-03],
        [2.0837e-04, 9.9997e-01, 1.4693e-04,  ..., 1.8188e-03, 0.0000e+00,
         9.7018e-01],
        [1.1505e-04, 9.9706e-01, 1.4688e-04,  ..., 3.3910e-03, 9.7018e-01,
         0.0000e+00]])
