In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim 
from torch_geometric.data import HeteroData, Data
from torch_geometric.nn import (
    GATv2Conv,        
    Linear,            
    LayerNorm,        
    BatchNorm,        
    HeteroConv        
)
from torch_geometric.loader import DataLoader as GraphDataLoader

from torch_scatter import scatter_mean, scatter_sum, scatter_max, scatter 
# --- Scikit-learn ---
from sklearn.model_selection import KFold, StratifiedKFold   
from sklearn.metrics import roc_auc_score, accuracy_score, precision_score, recall_score, f1_score, roc_curve
from sklearn.preprocessing import (
    RobustScaler,       
    StandardScaler,     
    OneHotEncoder      
)
from sklearn.metrics.pairwise import cosine_similarity 
from sklearn.neighbors import kneighbors_graph      

import numpy as np

import pandas as pd

from tqdm.notebook import tqdm as tqdm_notebook 

import matplotlib.pyplot as plt

from typing import Dict, List, Tuple, Optional, Any, Union

import os
import wandb

In [None]:
class ViewEncoder(nn.Module):
    """
    Graph VAE Encoder using GATv2Conv to map node features and graph structure
    of a specific view to parameters of a latent Gaussian distribution (mu, logvar).
    """
    def __init__(self, in_channels: int, hidden_channels: int, latent_dim: int,
                 heads: int = 4, dropout: float = 0.5, num_gnn_layers: int = 2, edge_dim: int = -1):
        """
        Args:
            in_channels: Dimensionality of input node features for this view.
            hidden_channels: Dimensionality of hidden layers in the GNN.
            latent_dim: Dimensionality of the output latent space (mu and logvar).
            heads: Number of attention heads in GATv2Conv layers.
            dropout: Dropout rate.
            num_gnn_layers: Number of GATv2Conv layers (supports 1 or 2).
            edge_dim: Dimensionality of edge features (-1 if no edge features).
        """
        super().__init__()
        if num_gnn_layers not in [1, 2]:
            raise ValueError("ViewEncoder currently supports 1 or 2 GNN layers.")

        self.num_gnn_layers = num_gnn_layers
        self.dropout_p = dropout
        current_dim = hidden_channels

        self.conv1 = GATv2Conv(in_channels, hidden_channels, heads=heads, concat=True,
                               dropout=dropout, edge_dim=edge_dim, add_self_loops=True)
        self.bn1 = BatchNorm(hidden_channels * heads)

        if num_gnn_layers > 1:
            
            self.conv2 = GATv2Conv(hidden_channels * heads, hidden_channels, heads=heads, concat=True,
                                   dropout=dropout, edge_dim=edge_dim, add_self_loops=True)
            self.bn2 = BatchNorm(hidden_channels * heads)
            current_dim = hidden_channels * heads

        self.fc_mu = Linear(current_dim, latent_dim)
        self.fc_logvar = Linear(current_dim, latent_dim)

        self.reset_parameters()

    def reset_parameters(self):
        self.conv1.reset_parameters()
        self.bn1.reset_parameters()
        if self.num_gnn_layers > 1:
            self.conv2.reset_parameters()
            self.bn2.reset_parameters()
        self.fc_mu.reset_parameters()
        self.fc_logvar.reset_parameters()

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor,
                edge_attr: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            x: Node feature matrix [num_nodes, in_channels].
            edge_index: Graph connectivity [2, num_edges].
            edge_attr: Edge feature matrix [num_edges, edge_dim] (optional).

        Returns:
            mu: Latent mean [num_nodes, latent_dim].
            logvar: Latent log variance [num_nodes, latent_dim].
        """
        # Layer 1
        x = self.conv1(x, edge_index, edge_attr=edge_attr)
        x = self.bn1(x)
        x = F.elu(x)
        x = F.dropout(x, p=self.dropout_p, training=self.training)

        # Layer 2 (if exists)
        if self.num_gnn_layers == 2:
            x = self.conv2(x, edge_index, edge_attr=edge_attr)
            x = self.bn2(x)
            x = F.elu(x)
            x = F.dropout(x, p=self.dropout_p, training=self.training)

        # Output Projections
        mu = self.fc_mu(x)
        logvar = self.fc_logvar(x)
        #logvar = torch.tanh(self.fc_logvar(x)) * 5.0
        #logvar = F.hardtanh(self.fc_logvar(x), min_val=-6.0, max_val=2.0)
        return mu, logvar

# --- 2. Structure Decoder (Adjacency Reconstruction) ---
class StructureDecoder(nn.Module):
    """
    Decodes latent embeddings to reconstruct graph adjacency matrix logits
    using inner product.
    """
    def __init__(self, activation: str = 'none'):
        """
        Args:
            activation: Output activation ('sigmoid' or 'none'). 'none' is suitable
                        for BCEWithLogitsLoss.
        """
        super().__init__()
        if activation not in ['sigmoid', 'none']:
             raise ValueError("Activation must be 'sigmoid' or 'none'")
        self.activation = activation

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """
        Args:
            z: Latent node embeddings [num_nodes, latent_dim].

        Returns:
            adj_rec_logits or adj_rec_probs: Reconstructed adjacency [num_nodes, num_nodes].
        """
        adj_rec_logits = torch.matmul(z, z.t())
        if self.activation == 'sigmoid':
            return torch.sigmoid(adj_rec_logits)
        return adj_rec_logits

# --- 3. Attribute Decoder (Feature Reconstruction) ---
class AttributeDecoder(nn.Module):
    """
    Decodes latent embeddings back to the original node feature space using an MLP.
    Applies Tanh activation to the output to control scale and prevent explosion.
    """
    def __init__(self, latent_dim: int, original_feature_dim: int, hidden_decoder_dim: Optional[int] = None):
        """
        Args:
            latent_dim: Dimensionality of the latent embeddings.
            original_feature_dim: Dimensionality of the original node features to reconstruct.
            hidden_decoder_dim: Dimensionality of the hidden layer in the MLP decoder.
                                Defaults to latent_dim if None.
        """
        super().__init__()
        if hidden_decoder_dim is None:
            hidden_decoder_dim = latent_dim

        self.mlp = nn.Sequential(
            Linear(latent_dim, hidden_decoder_dim),
            nn.ReLU(),
            #nn.ELU(),
            Linear(hidden_decoder_dim, original_feature_dim),
        )
        self.norm_layer = LayerNorm(original_feature_dim)
        self.reset_parameters()

    def reset_parameters(self):
         for layer in self.mlp:
             if hasattr(layer, 'reset_parameters'):
                 layer.reset_parameters()

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        x_hat = self.mlp(z)
        return self.norm_layer(x_hat)

# --- 4. Attention Fusion Layer ---
class AttentionFusionLayer(nn.Module):
    """
    Fuses embeddings from multiple views using an attention mechanism.
    Takes view embeddings stacked along a dimension.
    """
    def __init__(self, embed_dim: int, num_views: int, hidden_dim_attention: int, dropout: float = 0.3):
        """
        Args:
            embed_dim: Dimensionality of the input embeddings from each view.
            num_views: Number of views being fused.
            hidden_dim_attention: Hidden dimension for the attention MLP.
            dropout: Dropout rate applied to the fused embedding.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.num_views = num_views

        self.attention_mlp = nn.Sequential(
            # Input is concatenation of all view embeddings
            Linear(embed_dim * num_views, hidden_dim_attention),
            nn.Tanh(),
            Linear(hidden_dim_attention, num_views) 
        )

        self.dropout = nn.Dropout(dropout)
        self.final_norm = LayerNorm(embed_dim) 

        self.reset_parameters()

    def reset_parameters(self):
         for layer in self.attention_mlp:
             if hasattr(layer, 'reset_parameters'):
                 layer.reset_parameters()
         self.final_norm.reset_parameters()

    def forward(self, view_embeddings_stacked: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            view_embeddings_stacked: Tensor of view embeddings, shape [batch_size, num_views, embed_dim].

        Returns:
            fused_embedding: Fused representation [batch_size, embed_dim].
            attention_weights: Attention weights applied [batch_size, num_views].
        """
        batch_size = view_embeddings_stacked.shape[0]

        # Reshape for attention MLP: [batch_size, num_views * embed_dim]
        concatenated_embeddings = view_embeddings_stacked.view(batch_size, -1)

        # Get attention scores
        attn_scores = self.attention_mlp(concatenated_embeddings) # [batch_size, num_views]

        # Apply softmax to get attention weights
        attention_weights = F.softmax(attn_scores, dim=1) # [batch_size, num_views]

        # Calculate weighted sum using original embeddings
        # Reshape weights to allow broadcasting: [batch_size, num_views, 1]
        attention_weights_expanded = attention_weights.unsqueeze(-1)

        # Element-wise multiplication and sum over the num_views dimension
        fused_embedding = (view_embeddings_stacked * attention_weights_expanded).sum(dim=1) 

        fused_embedding = self.dropout(fused_embedding)
        fused_embedding = self.final_norm(fused_embedding)

        return fused_embedding, attention_weights


# --- 5. Classifier MLP ---
class ClassifierMLP(nn.Module):
    """
    Simple MLP for binary classification based on the fused embedding.
    """
    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int = 1, dropout: float = 0.5):
        """
        Args:
            input_dim: Dimensionality of the fused input embedding.
            hidden_dim: Dimensionality of the hidden layer.
            output_dim: Dimensionality of the output (1 for binary classification logits).
            dropout: Dropout rate.
        """
        super().__init__()
        self.mlp = nn.Sequential(
            Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(p=dropout),
            Linear(hidden_dim, output_dim) 
        )
        self.reset_parameters()

    def reset_parameters(self):
         for layer in self.mlp:
             if hasattr(layer, 'reset_parameters'):
                 layer.reset_parameters()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Fused input embedding [batch_size, input_dim].

        Returns:
            logits: Classification logits [batch_size, output_dim].
        """
        return self.mlp(x)

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MHA_CLSToken_FusionLayer(nn.Module):
    """
    Fuses view embeddings using a learnable [CLS] token and Multi-Head Self-Attention.
    """
    def __init__(self, embed_dim: int, num_heads: int, 
                 ffn_dim_multiplier: int = 2, dropout: float = 0.1, 
                 output_dim: Optional[int] = None):
        """
        Args:
            embed_dim: Dimensionality of the input view embeddings.
            num_heads: Number of attention heads.
            ffn_dim_multiplier: Multiplier for the feed-forward layer's hidden dim.
            dropout: Dropout rate.
            output_dim: Final dimension of the fused embedding. Defaults to embed_dim.
        """
        super().__init__()
        self.output_dim = output_dim if output_dim is not None else embed_dim

        # 1. The learnable [CLS] token
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))

        # 2. The Multi-Head Attention layer
        self.mha = nn.MultiheadAttention(
            embed_dim=embed_dim, 
            num_heads=num_heads, 
            dropout=dropout, 
            batch_first=True 
        )
        
        # 3. A standard Feed-Forward Network (part of a Transformer block)
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * ffn_dim_multiplier),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(embed_dim * ffn_dim_multiplier, embed_dim)
        )

        # 4. Layer Normalization
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)

        # 5. Optional final projection layer
        self.final_projection = nn.Linear(embed_dim, self.output_dim) if embed_dim != self.output_dim else nn.Identity()
        
    def forward(self, view_embeddings_stacked: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Args:
            view_embeddings_stacked: Tensor of view embeddings, shape [batch_size, num_views, embed_dim].
        
        Returns:
            fused_embedding: A single fused vector per patient, shape [batch_size, output_dim].
            attention_weights: None, as extracting them is complex and not the primary goal.
        """
        batch_size = view_embeddings_stacked.shape[0]

        # Prepend the CLS token to the sequence of view embeddings
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat((cls_tokens, view_embeddings_stacked), dim=1) # Shape: [batch_size, num_views + 1, embed_dim]

        # --- First part of Transformer Block: MHA + Residual + Norm ---
        # Self-attention: query, key, and value are all the same
        attn_output, _ = self.mha(x, x, x)
        # Residual connection
        x = x + attn_output
        x = self.norm1(x)

        # --- Second part of Transformer Block: FFN + Residual + Norm ---
        ffn_output = self.ffn(x)
        # Residual connection
        x = x + ffn_output
        x = self.norm2(x)
        
        # The final fused representation is the output of the CLS token (at position 0)
        cls_output = x[:, 0, :] # Shape: [batch_size, embed_dim]

        # Apply final projection
        fused_embedding = self.final_projection(cls_output)

        return fused_embedding, None # Return None for attention weights

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Linear, LayerNorm # Ensure Linear is imported if not already

class ProjectionHead(nn.Module):
    """
    Projects embeddings (typically mu from VAE) to a new space for contrastive learning.
    """
    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, dropout: float = 0.1):
        """
        Args:
            input_dim: Dimensionality of the input embeddings (e.g., d_embed).
            hidden_dim: Dimensionality of the hidden layer.
            output_dim: Dimensionality of the projected embeddings for CL.
            dropout: Dropout rate.
        """
        super().__init__()
        self.net = nn.Sequential(
            Linear(input_dim, hidden_dim),
            LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Dropout(p=dropout),
            Linear(hidden_dim, output_dim)
        )
        self.reset_parameters()

    def reset_parameters(self):
        for layer in self.net:
            if hasattr(layer, 'reset_parameters'):
                layer.reset_parameters()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Input embeddings [batch_size, input_dim] or [num_nodes, input_dim].

        Returns:
            Projected and L2-normalized embeddings [batch_size, output_dim].
        """
        projected_x = self.net(x)
        return F.normalize(projected_x, p=2, dim=-1)

In [None]:
def calculate_contrastive_loss(
    sampled_zs_per_view_batch: Dict[str, Tuple[torch.Tensor, torch.Tensor]],
    temperature: float
) -> torch.Tensor:
    """
    Calculates cross-view contrastive loss for patients present in multiple views within the batch.
    Args:
        sampled_zs_per_view_batch: Dict where key is view_name, value is a tuple:
                                   (Tensor of z_embeddings for patients in batch having this view,
                                    Tensor of global_indices for these patients).
        temperature: Temperature for InfoNCE.
    Returns:
        Contrastive loss scalar.
    """
    if sampled_zs_per_view_batch:
        first_emb = list(sampled_zs_per_view_batch.values())[0][0]
        if first_emb is not None:
            device = first_emb.device

    total_contrastive_loss = torch.tensor(0.0, device=device)
    num_contrastive_pairs_total = 0

    # Create a list of (global_patient_idx, view_name, embedding_tensor)
    all_embeddings_flat = []
    for view_name, (embeddings, global_indices) in sampled_zs_per_view_batch.items():
        if embeddings is not None and global_indices is not None and embeddings.numel() > 0: # Check if embeddings exist
            for i in range(embeddings.shape[0]):
                all_embeddings_flat.append((global_indices[i].item(), view_name, embeddings[i]))

    if not all_embeddings_flat:
        return total_contrastive_loss

    # Group embeddings by global_patient_idx
    patient_to_view_embeddings = {}
    for global_idx, view_name, emb in all_embeddings_flat:
        if global_idx not in patient_to_view_embeddings:
            patient_to_view_embeddings[global_idx] = []
        patient_to_view_embeddings[global_idx].append(emb)

    # For each patient with embeddings from multiple views
    for global_idx, view_embs_list in patient_to_view_embeddings.items():
        if len(view_embs_list) < 2: # Need at least two views for this patient
            continue

        # Form positive pairs for this patient
        for i in range(len(view_embs_list)):
            for j in range(i + 1, len(view_embs_list)):
                z_i = view_embs_list[i].unsqueeze(0) # Anchor [1, d_embed]
                z_j = view_embs_list[j].unsqueeze(0) # Positive [1, d_embed]

                # Negative samples: all other embeddings in all_embeddings_flat NOT from this patient
                negatives = torch.stack([
                    other_emb for other_global_idx, _, other_emb in all_embeddings_flat
                    if other_global_idx != global_idx
                ])

                if negatives.numel() == 0: # Only one patient in batch, no negatives
                    continue

                # Cosine similarity
                sim_positive = F.cosine_similarity(z_i, z_j, dim=1) / temperature # Shape [1]
                sim_negatives = F.cosine_similarity(z_i.expand(negatives.shape[0], -1), negatives, dim=1) / temperature # Shape [num_negatives]

                # Concatenate positive score with negative scores for logits
                logits = torch.cat([sim_positive, sim_negatives]) # Shape [1 + num_negatives]
                target_index = torch.tensor([0], device=device, dtype=torch.long) 


                total_contrastive_loss += F.cross_entropy(logits.unsqueeze(0), target_index)
                num_contrastive_pairs_total += 1

    return total_contrastive_loss / num_contrastive_pairs_total if num_contrastive_pairs_total > 0 else torch.tensor(0.0, device=device)

In [None]:
def get_view_subgraph_and_features(
    full_data: HeteroData,
    view_name: str,
    batch_patient_global_indices: torch.Tensor
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
    """
    Extracts features and a local subgraph for a specific view.
    For 'radiology', x_view_subset_batch will be None as features are handled by lesion aggregator.
    It still extracts the patient-patient similarity graph for radiology if it exists.
    """
    device = batch_patient_global_indices.device
    x_view_feature_key = f'x_{view_name}'
    edge_type_sim = ('patient', f'similar_to_{view_name}', 'patient')

    x_view_subset_batch: Optional[torch.Tensor] = None # Initialize
    global_indices_of_subset_in_batch: torch.Tensor = batch_patient_global_indices 

    if view_name == 'radiology':
        mask_feature_key = f'{view_name}_mask'
        if mask_feature_key not in full_data['patient']:
            print(f"Warning: Mask key '{mask_feature_key}' not found for view '{view_name}'. Assuming no patients have this view in batch.")
            return None, torch.empty((2,0), dtype=torch.long, device=device), None, torch.empty(0, dtype=torch.long, device=device)

        view_presence_mask_all_patients = full_data['patient'][mask_feature_key]
        view_presence_in_batch = view_presence_mask_all_patients[batch_patient_global_indices]
        global_indices_of_subset_in_batch = batch_patient_global_indices[view_presence_in_batch]

        if global_indices_of_subset_in_batch.numel() == 0:
            return None, torch.empty((2,0), dtype=torch.long, device=device), None, global_indices_of_subset_in_batch

    elif view_name == 'clinical':
        if x_view_feature_key not in full_data['patient']:
             print(f"Warning: Clinical features '{x_view_feature_key}' not found.")
             return None, torch.empty((2,0), dtype=torch.long, device=device), None, torch.empty(0, dtype=torch.long, device=device)
        global_indices_of_subset_in_batch = batch_patient_global_indices
        x_view_subset_batch = full_data['patient'][x_view_feature_key][global_indices_of_subset_in_batch]

    else:
        mask_feature_key = f'{view_name}_mask'
        if mask_feature_key not in full_data['patient']:
            print(f"Warning: Mask key '{mask_feature_key}' not found for view '{view_name}'.")
            return None, torch.empty((2,0), dtype=torch.long, device=device), None, torch.empty(0, dtype=torch.long, device=device)
        if x_view_feature_key not in full_data['patient']:
             print(f"Warning: Features '{x_view_feature_key}' not found for view '{view_name}'.")
             return None, torch.empty((2,0), dtype=torch.long, device=device), None, torch.empty(0, dtype=torch.long, device=device)

        view_presence_mask_all_patients = full_data['patient'][mask_feature_key]
        view_presence_in_batch = view_presence_mask_all_patients[batch_patient_global_indices]
        global_indices_of_subset_in_batch = batch_patient_global_indices[view_presence_in_batch]

        if global_indices_of_subset_in_batch.numel() == 0:
            return None, torch.empty((2,0), dtype=torch.long, device=device), None, global_indices_of_subset_in_batch
        x_view_subset_batch = full_data['patient'][x_view_feature_key][global_indices_of_subset_in_batch]


    if edge_type_sim not in full_data.edge_types:
        # print(f"Warning: Similarity edge type '{edge_type_sim}' not found for view '{view_name}")
        return x_view_subset_batch, torch.empty((2,0), dtype=torch.long, device=device), None, global_indices_of_subset_in_batch

    view_full_edge_index = full_data[edge_type_sim].edge_index
    view_full_edge_attr = getattr(full_data[edge_type_sim], 'edge_attr', None)

    num_batch_patients_with_view_for_sim_graph = global_indices_of_subset_in_batch.numel()
    if num_batch_patients_with_view_for_sim_graph == 0: #
        return x_view_subset_batch, torch.empty((2,0), dtype=torch.long, device=device), None, global_indices_of_subset_in_batch

    global_to_local_idx_map_sim_graph = {global_idx.item(): local_idx for local_idx, global_idx in enumerate(global_indices_of_subset_in_batch)}

    src_nodes_global = view_full_edge_index[0]
    dst_nodes_global = view_full_edge_index[1]
    mask_src_in_subset = torch.isin(src_nodes_global, global_indices_of_subset_in_batch)
    mask_dst_in_subset = torch.isin(dst_nodes_global, global_indices_of_subset_in_batch)
    edge_selection_mask = mask_src_in_subset & mask_dst_in_subset

    if not edge_selection_mask.any():
        empty_edge_index = torch.empty((2,0), dtype=torch.long, device=device)
        empty_edge_attr = None
        if view_full_edge_attr is not None:
             empty_edge_attr = torch.empty((0, view_full_edge_attr.shape[1]), dtype=view_full_edge_attr.dtype, device=device)
        return x_view_subset_batch, empty_edge_index, empty_edge_attr, global_indices_of_subset_in_batch

    selected_edges_global_src = src_nodes_global[edge_selection_mask]
    selected_edges_global_dst = dst_nodes_global[edge_selection_mask]

    try:
        local_edge_src = torch.tensor([global_to_local_idx_map_sim_graph[idx.item()] for idx in selected_edges_global_src], dtype=torch.long, device=device)
        local_edge_dst = torch.tensor([global_to_local_idx_map_sim_graph[idx.item()] for idx in selected_edges_global_dst], dtype=torch.long, device=device)
        local_edge_index_batch_sim_graph = torch.stack([local_edge_src, local_edge_dst], dim=0)
    except KeyError as e:
         print(f"Error remapping indices for view {view_name} similarity graph. Missing global index in map: {e}. Returning no edges.")
         empty_edge_index = torch.empty((2,0), dtype=torch.long, device=device)
         empty_edge_attr = None
         if view_full_edge_attr is not None:
             empty_edge_attr = torch.empty((0, view_full_edge_attr.shape[1]), dtype=view_full_edge_attr.dtype, device=device)
         return x_view_subset_batch, empty_edge_index, empty_edge_attr, global_indices_of_subset_in_batch
    
    local_edge_attr_batch_sim_graph = None
    if view_full_edge_attr is not None:
        local_edge_attr_batch_sim_graph = view_full_edge_attr[edge_selection_mask]

    return x_view_subset_batch, local_edge_index_batch_sim_graph, local_edge_attr_batch_sim_graph, global_indices_of_subset_in_batch

def get_dense_adj_for_reconstruction(local_edge_index: Optional[torch.Tensor], num_nodes_in_subset: int, device: torch.device) -> torch.Tensor:
    """Creates a dense adjacency matrix from local_edge_index for reconstruction loss."""
    adj = torch.zeros((num_nodes_in_subset, num_nodes_in_subset), device=device)
    if local_edge_index is not None and local_edge_index.numel() > 0:
        adj[local_edge_index[0], local_edge_index[1]] = 1
        #make symmetric
        adj = torch.max(adj, adj.t()) 
    return adj

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_scatter import scatter_mean, scatter_add, scatter_max 

class RadiologyLesionAttentionAggregator(nn.Module):
    def __init__(self, lesion_feature_dim: int, patient_embed_dim: int,
                 attention_hidden_dim: Optional[int] = None, dropout: float = 0.1):
        super().__init__()
        self.lesion_feature_dim = lesion_feature_dim
        self.patient_embed_dim = patient_embed_dim
        
        if attention_hidden_dim is None:
            attention_hidden_dim = lesion_feature_dim

        # Attention mechanism: learns to score lesions
        # Takes individual lesion features
        self.attention_mlp = nn.Sequential(
            nn.Linear(lesion_feature_dim, attention_hidden_dim),
            nn.Tanh(),
            nn.Linear(attention_hidden_dim, 1) 
        )

        if lesion_feature_dim != patient_embed_dim:
            self.output_projection = nn.Linear(lesion_feature_dim, patient_embed_dim)
        else:
            self.output_projection = nn.Identity()

        self.dropout = nn.Dropout(dropout)
        self.norm_layer = nn.LayerNorm(patient_embed_dim) # Normalize the final patient embedding

        self.reset_parameters()

    def reset_parameters(self):
        for layer in self.attention_mlp:
            if hasattr(layer, 'reset_parameters'):
                layer.reset_parameters()
        if isinstance(self.output_projection, nn.Linear):
            self.output_projection.reset_parameters()
        self.norm_layer.reset_parameters()

    def forward(self, lesion_x: torch.Tensor, patient_to_lesion_edge_index: torch.Tensor,
                num_patients_in_batch: int) -> torch.Tensor:
        """
        Aggregates lesion features for patients using attention.

        Args:
            lesion_x: Tensor of lesion features [total_lesions_in_batch, lesion_feature_dim].
            patient_to_lesion_edge_index: Edge index [2, num_edges] connecting
                                           batch-local patient indices to batch-local lesion indices.
                                           edge_index[0] = batch_local_patient_idx
                                           edge_index[1] = batch_local_lesion_idx
            num_patients_in_batch: The number of unique patients in this batch for whom
                                   we need to produce aggregated embeddings.

        Returns:
            patient_radiology_embeddings: Tensor [num_patients_in_batch, patient_embed_dim].
                                          Contains aggregated features for patients who have lesions.
                                          For patients with no lesions, their rows will be zeros.
        """
        if lesion_x.numel() == 0 or patient_to_lesion_edge_index.numel() == 0:
            # No lesions in this batch, return zeros for all patients
            return torch.zeros((num_patients_in_batch, self.patient_embed_dim),
                               device=lesion_x.device, dtype=lesion_x.dtype)

        batch_local_patient_indices = patient_to_lesion_edge_index[0]
        batch_local_lesion_indices = patient_to_lesion_edge_index[1] 

        relevant_lesion_features = lesion_x[batch_local_lesion_indices]

        # 1. Calculate attention scores for each lesion
        attn_scores = self.attention_mlp(relevant_lesion_features)  # [num_batch_edges, 1]

        # 2. Apply softmax grouped by patient to get attention weights

        attn_scores_max_per_patient = scatter_max(attn_scores.squeeze(-1), batch_local_patient_indices, dim=0, dim_size=num_patients_in_batch)[0]
        attn_scores_stabilized = attn_scores.squeeze(-1) - attn_scores_max_per_patient[batch_local_patient_indices]
        
        attn_exp = torch.exp(attn_scores_stabilized)
        attn_exp_sum_per_patient = scatter_add(attn_exp, batch_local_patient_indices, dim=0, dim_size=num_patients_in_batch)
        
        attn_exp_sum_per_patient = attn_exp_sum_per_patient.clamp(min=1e-12) 
        
        alpha = attn_exp / attn_exp_sum_per_patient[batch_local_patient_indices] # [num_batch_edges]
        alpha = alpha.unsqueeze(-1) # [num_batch_edges, 1]

        # 3. Calculate weighted sum of lesion features for each patient
        weighted_lesion_features = relevant_lesion_features * alpha # [num_batch_edges, lesion_feature_dim]
        
        # Aggregate weighted features per patient
        aggregated_patient_features = scatter_add(
            weighted_lesion_features, batch_local_patient_indices, dim=0, dim_size=num_patients_in_batch
        ) # [num_patients_in_batch, lesion_feature_dim]

        # 4. Optional output projection and normalization
        projected_features = self.output_projection(aggregated_patient_features)
        projected_features = self.dropout(projected_features)
        normalized_features = self.norm_layer(projected_features)
        
        return normalized_features

In [None]:
class EndToEndMultiViewVAE_CL_AttentionRadiology(nn.Module):
    def __init__(self, 
                 view_configs: Dict[str, Any], 
                 radiology_aggregator_config: Dict[str, Any],
                 projection_head_config: Dict[str, Any],
                 fusion_config: Dict[str, Any], 
                 classifier_config: Dict[str, Any], 
                 d_embed: int,
                 missing_strategy: str ='zero'):
        super().__init__()
        self.views = list(view_configs.keys())
        self.d_embed = d_embed
        self.missing_strategy = missing_strategy

        self.radiology_lesion_aggregator = None
        if 'radiology' in self.views and radiology_aggregator_config:
            self.radiology_lesion_aggregator = RadiologyLesionAttentionAggregator(
                lesion_feature_dim=radiology_aggregator_config['lesion_feature_dim'],
                patient_embed_dim=radiology_aggregator_config['aggregated_output_dim'],
                attention_hidden_dim=radiology_aggregator_config.get('attention_hidden_dim'),
                dropout=radiology_aggregator_config.get('dropout', 0.1)
            )
            if view_configs['radiology']['in_channels'] != radiology_aggregator_config['aggregated_output_dim']:
                raise ValueError("Radiology VAE in_channels must match aggregator output_dim")
        
        # --- VAE Components ---
        self.vae_encoders = nn.ModuleDict()
        self.structure_decoders = nn.ModuleDict()
        self.attribute_decoders = nn.ModuleDict()
        if self.missing_strategy == 'learnable':
             self.missing_embeddings_params = nn.ParameterDict()

        self.projection_heads = nn.ModuleDict()

        for view, config in view_configs.items():
            self.vae_encoders[view] = ViewEncoder(
                config['in_channels'], config['hidden_channels_vae'], d_embed,
                config.get('heads', 4), config.get('dropout', 0.3),
                config.get('num_gnn_layers_vae', 2), config.get('edge_dim', -1)
            )
            self.structure_decoders[view] = StructureDecoder()
            self.attribute_decoders[view] = AttributeDecoder(d_embed, config['in_channels'])

            self.projection_heads[view] = ProjectionHead(
                input_dim=d_embed,
                hidden_dim=projection_head_config.get('hidden_dim', d_embed),
                output_dim=projection_head_config.get('output_dim', d_embed),
                dropout=projection_head_config.get('dropout', 0.1)
            )
            if self.missing_strategy == 'learnable':
                self.missing_embeddings_params[view] = nn.Parameter(torch.randn(1, d_embed))

        self.fusion_layer = MHA_CLSToken_FusionLayer(
                embed_dim=d_embed, # The dimension of your view embeddings
                num_heads=fusion_config.get('num_fusion_heads', 4),
                ffn_dim_multiplier=fusion_config.get('fusion_ffn_multiplier', 2),
                dropout=fusion_config.get('dropout_fusion', 0.1),
                output_dim=fusion_config.get('fused_dim', d_embed)
        )
        
        self.classifier = ClassifierMLP(
                input_dim=fusion_config.get('fused_dim', d_embed),
                hidden_dim=classifier_config['hidden_dim_classifier'],
                output_dim=1, dropout=classifier_config.get('dropout_class', 0.5)
        )

    def reparameterize(self, mu, logvar):
        if self.training:
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            return mu + eps * std
        else:
            return mu

    def forward(self, full_data: HeteroData, batch_patient_global_indices: torch.Tensor):
        device = batch_patient_global_indices.device

        vae_outputs_for_loss = {view: {} for view in self.views}
        # all_patient_view_zs_for_fusion = {
        #     idx.item(): {view: torch.zeros(self.d_embed, device=device) for view in self.views}
        #     for idx in batch_patient_global_indices
        # }
        all_patient_zs_for_fusion = {}
        
        #mu_projected' embeddings for the contrastive loss
        all_patient_mus_projected_for_cl = {}

        for view in self.views:
            x_patient_level_subset, local_patient_sim_edge_idx, local_patient_sim_edge_attr, global_indices_subset_patients = \
                get_view_subgraph_and_features(full_data, view, batch_patient_global_indices)
            
            if global_indices_subset_patients.numel() == 0:
                vae_outputs_for_loss[view] = { # Ensure structure exists for loss calculation
                    'mu': None, 'logvar': None, 'z_sampled_for_dec': None,
                    'rec_adj_logits': None, 'rec_x': None,
                    'original_x_subset': None, 'original_adj_subset': None
                }
                continue

            num_active_patients_for_view = global_indices_subset_patients.shape[0]
            x_for_vae_encoder, original_x_for_vae_reconstruction = None, None
            if view == 'radiology' and self.radiology_lesion_aggregator:
                                
                patient_lesion_edges_all = full_data['patient', 'has_lesion', 'lesion'].edge_index
                all_lesion_features_all = full_data['lesion'].x

                # Map active global patient indices to their 0-based local indices within this subset
                active_patient_global_to_local_map = {glob_idx.item(): i for i, glob_idx in enumerate(global_indices_subset_patients)}
                
                batch_lesion_src_patient_local_idx_list = []
                batch_lesion_node_global_idx_list = []

                for i_edge in range(patient_lesion_edges_all.shape[1]):
                    src_patient_global = patient_lesion_edges_all[0, i_edge].item()
                    dst_lesion_global = patient_lesion_edges_all[1, i_edge].item()
                    if src_patient_global in active_patient_global_to_local_map:
                        batch_lesion_src_patient_local_idx_list.append(active_patient_global_to_local_map[src_patient_global])
                        batch_lesion_node_global_idx_list.append(dst_lesion_global)
                
                if batch_lesion_node_global_idx_list:
                    batch_lesion_src_patient_local_idx_t = torch.tensor(batch_lesion_src_patient_local_idx_list, dtype=torch.long, device=device)
                    batch_lesion_node_global_idx_t = torch.tensor(batch_lesion_node_global_idx_list, dtype=torch.long, device=device)
                    
                    lesion_features_for_batch_agg = all_lesion_features_all[batch_lesion_node_global_idx_t]
                    unique_lesions_in_batch, inverse_indices = torch.unique(batch_lesion_node_global_idx_t, return_inverse=True)
                    batch_local_lesion_indices_for_agg = torch.arange(lesion_features_for_batch_agg.shape[0], device=device)

                    patient_to_batch_lesion_edge_index = torch.stack([
                        batch_lesion_src_patient_local_idx_t, # patient indices (local to current subset)
                        batch_local_lesion_indices_for_agg    # lesion indices (local to lesions_for_batch_agg)
                    ], dim=0)

                    x_for_vae_encoder = self.radiology_lesion_aggregator(
                        lesion_features_for_batch_agg,
                        patient_to_batch_lesion_edge_index,
                        num_active_patients_for_view 
                    )
                    original_x_for_vae_reconstruction = x_for_vae_encoder
                
                pass # Assume this part is correctly implemented
            elif x_patient_level_subset is not None and x_patient_level_subset.numel() > 0:
                x_for_vae_encoder = x_patient_level_subset
                original_x_for_vae_reconstruction = x_for_vae_encoder

            if x_for_vae_encoder is not None and x_for_vae_encoder.numel() > 0:
                num_nodes_for_vae = x_for_vae_encoder.shape[0]
                # 1. Get mu and logvar from the VAE encoder
                mu, logvar = self.vae_encoders[view](x_for_vae_encoder, local_patient_sim_edge_idx, local_patient_sim_edge_attr)
                # 2. Get z_sampled (noisy version) for reconstruction and fusion
                z_sampled = self.reparameterize(mu, logvar)

                mu_projected = self.projection_heads[view](mu)

                # Store for VAE loss
                vae_outputs_for_loss[view] = {
                    'mu': mu, 'logvar': logvar, 'z_sampled_for_dec': z_sampled,
                    'original_x_subset': original_x_for_vae_reconstruction,
                    'original_adj_subset': get_dense_adj_for_reconstruction(local_patient_sim_edge_idx, num_nodes_for_vae, device),
                    'rec_adj_logits': self.structure_decoders[view](z_sampled),
                    'rec_x': self.attribute_decoders[view](z_sampled)
                }
                
                # 5. Gather embeddings for downstream tasks
                for i, global_idx_tensor in enumerate(global_indices_subset_patients):
                    global_idx_item = global_idx_tensor.item()
                    
                    # Store z_sampled for the FUSION layer
                    if global_idx_item not in all_patient_zs_for_fusion:
                        all_patient_zs_for_fusion[global_idx_item] = {}
                    all_patient_zs_for_fusion[global_idx_item][view] = z_sampled[i]

                    # Store mu_projected for the CONTRASTIVE loss
                    if global_idx_item not in all_patient_mus_projected_for_cl:
                        all_patient_mus_projected_for_cl[global_idx_item] = {}
                    all_patient_mus_projected_for_cl[global_idx_item][view] = mu_projected[i]

        # 1. FUSION: Use the 'z_sampled' embeddings
        batch_fusion_input_list = []
        for global_idx_tensor in batch_patient_global_indices:
            global_idx_item = global_idx_tensor.item()
            patient_embs = []
            patient_z_data = all_patient_zs_for_fusion.get(global_idx_item, {})
            for view in self.views:
                emb = patient_z_data.get(view, torch.zeros(self.d_embed, device=device))
                if view not in patient_z_data and self.missing_strategy == 'learnable':
                    emb = self.missing_embeddings_params[view].squeeze(0)
                patient_embs.append(emb)
            batch_fusion_input_list.append(torch.stack(patient_embs))

        batch_fusion_input_tensor = torch.stack(batch_fusion_input_list)
        z_fused, fusion_attention = self.fusion_layer(batch_fusion_input_tensor)
        logits = self.classifier(z_fused)

        # 2. CONTRASTIVE LEARNING: Use the 'mu_projected' embeddings
        mus_projected_for_cl_formatted = {}
        for patient_idx, view_data in all_patient_mus_projected_for_cl.items():
            for view, emb in view_data.items():
                if view not in mus_projected_for_cl_formatted:
                    mus_projected_for_cl_formatted[view] = {'embeddings': [], 'indices': []}
                mus_projected_for_cl_formatted[view]['embeddings'].append(emb)
                mus_projected_for_cl_formatted[view]['indices'].append(patient_idx)

        final_mus_projected_for_cl = {
            view: (torch.stack(data['embeddings']), torch.tensor(data['indices'], device=device))
            for view, data in mus_projected_for_cl_formatted.items()
        }

        return logits, vae_outputs_for_loss, final_mus_projected_for_cl, fusion_attention

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import roc_curve, auc
from scipy import interpolate

def plot_kfold_roc_curves(roc_data_per_fold: List[Dict], title: str = "ROC 10-CV", save_path: Optional[str] = None):
    """
    Generates a K-fold ROC plot with individual, mean, and merged curves.

    Args:
        roc_data_per_fold (List[Dict]): A list where each element is a dictionary
            from a fold containing {'fpr', 'tpr', 'auc', 'y_true', 'y_pred'}.
        title (str): The title of the plot.
    """
    fig, ax = plt.subplots(figsize=(10, 9))
    
    tprs = []
    aucs = []
    mean_fpr = np.linspace(0, 1, 100)
    
    all_y_true = []
    all_y_pred = []

    # Plot individual fold ROC curves
    for i, data in enumerate(roc_data_per_fold):
        ax.plot(data['fpr'], data['tpr'], lw=1.5, alpha=0.6,
                label=f"Fold {i+1} (AUC = {data['auc']:.2f}), N={len(data['y_true'])} patients")
        
        # For calculating the mean ROC
        interp_tpr = interpolate.interp1d(data['fpr'], data['tpr'], kind='linear', bounds_error=False, fill_value=(0.0, 1.0))(mean_fpr)
        tprs.append(interp_tpr)
        aucs.append(data['auc'])
        

    # Plot chance line
    ax.plot([0, 1], [0, 1], linestyle='--', lw=2, color='r', label='Chance', alpha=0.8)

    # Calculate and plot MEAN ROC
    mean_tpr = np.mean(tprs, axis=0)
    mean_tpr[-1] = 1.0
    mean_auc = np.mean(aucs)
    std_auc = np.std(aucs)
    ax.plot(mean_fpr, mean_tpr, color='blue',
            label=f'Mean ROC (AUC = {mean_auc:.2f} $\\pm$ {std_auc:.2f}), N={len(roc_data_per_fold)} folds',
            lw=2.5, alpha=0.9)

    # Plot standard deviation area
    std_tpr = np.std(tprs, axis=0)
    tprs_upper = np.minimum(mean_tpr + std_tpr, 1)
    tprs_lower = np.maximum(mean_tpr - std_tpr, 0)
    ax.fill_between(mean_fpr, tprs_lower, tprs_upper, color='grey', alpha=0.3,
                    label=r'$\pm$ 1 std. dev.')

    # Final plot settings
    ax.set(xlim=[-0.05, 1.05], ylim=[-0.05, 1.05],
           xlabel='False Positive Rate',
           ylabel='True Positive Rate',
           title=title)
    ax.legend(loc="lower right", fontsize=11)
    ax.grid(alpha=0.5)

    if save_path:
        # Ensure the directory exists
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"   Plot saved to: {save_path}")
        
    plt.show()
    plt.close(fig)

In [None]:
# --- Imports ---
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import gc
from typing import Dict, Any, Tuple, Optional, List

# --- Scikit-learn and PyTorch Geometric Imports ---
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import roc_auc_score, accuracy_score, precision_score, recall_score, f1_score, roc_curve
from torch_geometric.data import HeteroData

# --- Assumed Helper Imports (ensure these are available) ---
# from my_models import EndToEndMultiViewVAE_CL_AttentionRadiology, calculate_contrastive_loss

# --- Helper function for linear annealing ---
def linear_anneal(epoch: int, start_epoch: int, end_epoch: int, start_val: float, end_val: float) -> float:
    """Performs linear annealing."""
    if start_epoch >= end_epoch: return end_val
    if epoch < start_epoch: return start_val
    if epoch >= end_epoch: return end_val
    return start_val + (end_val - start_val) * (epoch - start_epoch) / (end_epoch - start_epoch)


# --- Main Training Function ---
def kfold_train_attention_radiology_cl(
    full_multi_view_data: HeteroData,
    model_config: Dict,
    train_config: Dict
) -> Tuple[Dict[str, float], List[Dict]]:
    """
    Performs K-fold cross-validation.
    - Scheduler/Early Stopping monitor TOTAL VALIDATION LOSS.
    - Best model state is saved based on highest VALIDATION AUC.
    - Returns aggregated metrics and ROC data from the best AUC epoch.
    """
    device = train_config['device']
    full_multi_view_data = full_multi_view_data.to(device)

    # --- Loss and Annealing Setup ---
    criterion_bce_logits = nn.BCEWithLogitsLoss()
    criterion_mse = nn.MSELoss()
    loss_weights_config = train_config['loss_weights']
    anneal_config = train_config.get('annealing', {})
    
    base_w_class = loss_weights_config['class']
    base_w_cross_cl = loss_weights_config.get('cross_cl', 0.0)
    base_w_kl = loss_weights_config['kl']
    w_rec_attr_config = loss_weights_config['rec_attr']
    w_rec_struct_config = loss_weights_config['rec_struct']
    print('base_w_cross_cl: ', base_w_cross_cl)

    kl_params = anneal_config.get('kl', {}); kl_start_w = kl_params.get('start_weight', base_w_kl); kl_end_w = kl_params.get('end_weight', base_w_kl); kl_start_e = kl_params.get('start_epoch', 0); kl_end_e = kl_params.get('end_epoch', 0)
    cl_params = anneal_config.get('cross_cl', {}); cl_start_w = cl_params.get('start_weight', base_w_cross_cl); cl_end_w = cl_params.get('end_weight', base_w_cross_cl); cl_start_e = cl_params.get('start_epoch', 0); cl_end_e = cl_params.get('end_epoch', 0)

    # --- Data Splitting ---
    all_patient_indices_np = np.arange(full_multi_view_data['patient'].num_nodes)
    y_for_stratification = full_multi_view_data['patient']['binary_label'].cpu().numpy()
    kf = KFold(n_splits=train_config['n_splits'], shuffle=True, random_state=train_config.get('random_seed', 40))

    fold_metrics_list = []
    roc_data_per_fold = []
    all_folds_detailed_logs = []

    for fold, (train_global_idx_np, val_global_idx_np) in enumerate(kf.split(all_patient_indices_np)):
        print(f"\n===== Fold {fold+1}/{train_config['n_splits']} =====")
        val_labels_for_check = y_for_stratification[val_global_idx_np]
        unique_labels, counts = np.unique(val_labels_for_check, return_counts=True)
        print(f"  Fold {fold+1} Validation Label Distribution: {dict(zip(unique_labels, counts))}")

        if wandb_params:
            full_wandb_config = {
                **wandb_params['run_config'],
                'fold': fold + 1,
                'train_config': train_config,
                'model_config': model_config 
            }
            
            wandb.init(
                project=wandb_params['project_name'],
                group=wandb_params['group_name'], 
                name=f"fold-{fold+1}",
                config=full_wandb_config,
                reinit=True 
            )

        
        train_fold_global_indices = torch.from_numpy(train_global_idx_np).to(device)
        val_fold_global_indices = torch.from_numpy(val_global_idx_np).to(device)

        model = EndToEndMultiViewVAE_CL_AttentionRadiology(
            view_configs=model_config['view_configs'],
            radiology_aggregator_config=model_config.get('radiology_aggregator_config'),
            projection_head_config=model_config['projection_head_config'],
            fusion_config=model_config['fusion_config'],
            classifier_config=model_config['classifier_config'],
            d_embed=model_config['d_embed'],
            missing_strategy=model_config.get('missing_strategy', 'zero')
        ).to(device)
        
        optimizer = torch.optim.AdamW(model.parameters(), lr=train_config['lr'], weight_decay=train_config['wd'])
        
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, 
            mode='min', # We want to MINIMIZE loss
            factor=0.5, 
            patience=train_config.get('patience', 10), 
            verbose=True
        )

        best_val_loss = float('inf')
        best_val_auc_this_fold = -1.0
        epochs_no_improve = 0 # Now based on loss improvement
        best_model_state_fold = None
        best_fold_roc_data = {}
        fold_epoch_logs = []

        for epoch in range(1, train_config['epochs'] + 1):
            # --- Training Phase ---
            model.train()
            current_w_kl = linear_anneal(epoch, kl_start_e, kl_end_e, kl_start_w, kl_end_w)
            current_w_cl = linear_anneal(epoch, cl_start_e, cl_end_e, cl_start_w, cl_end_w)
            
            # (Training loss calculation logic remains identical)
            logits_train, vae_outputs_loss_train, mus_projected_for_cl_train, _ = model(full_multi_view_data, train_fold_global_indices)
            true_labels_train = full_multi_view_data['patient']['binary_label'][train_fold_global_indices]
            raw_loss_class_train = criterion_bce_logits(logits_train.squeeze(), true_labels_train.float())
            raw_loss_cl_train = calculate_contrastive_loss(mus_projected_for_cl_train, train_config['cross_cl_temp'])
            
            total_loss_rec_attr_train, total_loss_rec_struct_train, total_loss_kl_train = 0.0, 0.0, 0.0
            num_active_views_train = 0
            for view_name in model.views:
                vo = vae_outputs_loss_train.get(view_name, {})
                if vo and vo.get('mu') is not None:
                    num_active_views_train += 1
                    w_attr = w_rec_attr_config if isinstance(w_rec_attr_config, float) else w_rec_attr_config.get(view_name, 1.0)
                    total_loss_rec_attr_train += w_attr * criterion_mse(vo['rec_x'], vo['original_x_subset'])
                    w_struct = w_rec_struct_config if isinstance(w_rec_struct_config, float) else w_rec_struct_config.get(view_name, 1.0)
                    total_loss_rec_struct_train += w_struct * criterion_bce_logits(vo['rec_adj_logits'].reshape(-1), vo['original_adj_subset'].reshape(-1))
                    kl_div = -0.5 * torch.sum(1 + vo['logvar'] - vo['mu'].pow(2) - vo['logvar'].exp(), dim=1).mean()
                    total_loss_kl_train += kl_div
            
            avg_loss_rec_attr_train = total_loss_rec_attr_train / num_active_views_train if num_active_views_train > 0 else 0.0
            avg_loss_rec_struct_train = total_loss_rec_struct_train / num_active_views_train if num_active_views_train > 0 else 0.0
            avg_raw_loss_kl_train = total_loss_kl_train / num_active_views_train if num_active_views_train > 0 else 0.0
            
            total_train_loss = (base_w_class * raw_loss_class_train +
                                current_w_cl * raw_loss_cl_train +
                                avg_loss_rec_attr_train + 
                                avg_loss_rec_struct_train +
                                current_w_kl * avg_raw_loss_kl_train)

            epoch_log['train_total_loss'] = total_train_loss.item()
            epoch_log['train_class_loss'] = raw_loss_class_train.item()
            epoch_log['train_cl_loss'] = raw_loss_cl_train.item()
            epoch_log['train_rec_attr_loss'] = avg_loss_rec_attr_train.item()
            epoch_log['train_rec_struct_loss'] = avg_loss_rec_struct_train.item()
            epoch_log['train_kl_loss'] = avg_raw_loss_kl_train.item()
            epoch_log['w_kl'] = current_w_kl
            epoch_log['w_cl'] = current_w_cl

            if not torch.isnan(total_train_loss):
                optimizer.zero_grad()
                total_train_loss.backward()
                nn.utils.clip_grad_norm_(model.parameters(), max_norm=train_config.get('grad_clip_norm', 1.0))
                optimizer.step()

            # --- Validation Phase ---
            model.eval()
            total_validation_loss = torch.tensor(float('inf'), device=device)
            current_val_auc = -1.0

            if val_fold_global_indices.numel() > 0:
                with torch.no_grad():
                    val_logits_raw, vae_outputs_loss_val, mus_projected_for_cl_val, _ = model(full_multi_view_data, val_fold_global_indices)
                    val_labels = full_multi_view_data['patient']['binary_label'][val_fold_global_indices]

                    raw_loss_class_val = criterion_bce_logits(val_logits_raw.squeeze(), val_labels.float())
                    raw_loss_cl_val = calculate_contrastive_loss(mus_projected_for_cl_val, train_config['cross_cl_temp'])
                    total_loss_rec_attr_val, total_loss_rec_struct_val, total_loss_kl_val = 0.0, 0.0, 0.0
                    num_active_views_val = 0
                    for view_name in model.views:
                        vo_val = vae_outputs_loss_val.get(view_name, {})
                        if vo_val and vo_val.get('mu') is not None:
                            num_active_views_val += 1
                            w_attr = w_rec_attr_config if isinstance(w_rec_attr_config, float) else w_rec_attr_config.get(view_name, 1.0)
                            total_loss_rec_attr_val += w_attr * criterion_mse(vo_val['rec_x'], vo_val['original_x_subset'])
                            w_struct = w_rec_struct_config if isinstance(w_rec_struct_config, float) else w_rec_struct_config.get(view_name, 1.0)
                            total_loss_rec_struct_val += w_struct * criterion_bce_logits(vo_val['rec_adj_logits'].reshape(-1), vo_val['original_adj_subset'].reshape(-1))
                            kl_div_val = -0.5 * torch.sum(1 + vo_val['logvar'] - vo_val['mu'].pow(2) - vo_val['logvar'].exp(), dim=1).mean()
                            total_loss_kl_val += kl_div_val
                    
                    avg_loss_rec_attr_val = total_loss_rec_attr_val / num_active_views_val if num_active_views_val > 0 else 0.0
                    avg_loss_rec_struct_val = total_loss_rec_struct_val / num_active_views_val if num_active_views_val > 0 else 0.0
                    avg_raw_loss_kl_val = total_loss_kl_val / num_active_views_val if num_active_views_val > 0 else 0.0
                    
                    # Use the final annealed weights for the validation loss to have a stable target
                    total_validation_loss = (base_w_class * raw_loss_class_val +
                                            cl_end_w * raw_loss_cl_val +
                                            avg_loss_rec_attr_val + 
                                            avg_loss_rec_struct_val +
                                            kl_end_w * avg_raw_loss_kl_val)

                    if not torch.isnan(val_logits_raw).any() and len(np.unique(val_labels.cpu().numpy())) > 1:
                        val_probs_np = torch.sigmoid(val_logits_raw.squeeze()).cpu().numpy()
                        current_val_auc = roc_auc_score(val_labels.cpu().numpy(), val_probs_np)

                    val_loss_metrics['val_total_loss'] = total_validation_loss.item()
                    val_loss_metrics['val_class_loss'] = raw_loss_class_val.item()
                    val_loss_metrics['val_cl_loss'] = raw_loss_cl_val.item()
                    val_loss_metrics['val_rec_attr_loss'] = avg_loss_rec_attr_val.item()
                    val_loss_metrics['val_rec_struct_loss'] = avg_loss_rec_struct_val.item()
                    val_loss_metrics['val_kl_loss'] = avg_raw_loss_kl_val.item()
                    val_loss_metrics['val_auc'] = current_val_auc

            epoch_log.update(val_loss_metrics)
            
            if epoch % train_config.get('print_every_k_epochs', 10) == 0:
                print(f"  F{fold+1} Ep{epoch:03d} TLoss:{total_train_loss.item():.4f} | "
                      f"VLoss:{total_validation_loss.item():.4f} (Best VLoss: {best_val_loss:.4f}) | "
                      f"ValAUC:{current_val_auc:.4f} (Best ValAUC: {best_val_auc_this_fold:.4f})")

            # 1. Scheduler and Early Stopping are driven by Validation Loss
            scheduler.step(total_validation_loss)
            if total_validation_loss < best_val_loss:
                best_val_loss = total_validation_loss
                epochs_no_improve = 0
            else:
                epochs_no_improve += 1

            # 2. Best Model State saving is driven by Validation AUC
            if current_val_auc != -1.0 and current_val_auc > best_val_auc_this_fold:
                best_val_auc_this_fold = current_val_auc
                best_model_state_fold = model.state_dict().copy()
                val_labels_np = val_labels.cpu().numpy()
                fpr, tpr, _ = roc_curve(val_labels_np, val_probs_np)
                best_fold_roc_data = {
                    'fpr': fpr, 'tpr': tpr, 'auc': current_val_auc,
                    'y_true': val_labels_np, 'y_pred': val_probs_np
                }

            fold_epoch_logs.append(epoch_log)
            
            if epochs_no_improve >= train_config.get('patience_early_stopping', 20):
                print(f"  Early stopping at epoch {epoch} for fold {fold+1} due to validation loss stagnation.")
                break
        
        # --- End of Fold ---
        if best_fold_roc_data:
            roc_data_per_fold.append(best_fold_roc_data)
            y_true = best_fold_roc_data['y_true']
            y_pred_probs = best_fold_roc_data['y_pred']
            y_pred_binary = (y_pred_probs > 0.5).astype(int)

            fold_results = {
                'auc': best_fold_roc_data['auc'],
                'f1': f1_score(y_true, y_pred_binary, zero_division=0),
                'accuracy': accuracy_score(y_true, y_pred_binary),
                'precision': precision_score(y_true, y_pred_binary, zero_division=0),
                'recall': recall_score(y_true, y_pred_binary, zero_division=0)
            }
            fold_metrics_list.append(fold_results)
            
        else:
           
            fold_metrics_list.append({'auc': np.nan, 'f1': np.nan, 'accuracy': np.nan, 'precision': np.nan, 'recall': np.nan})
        
        del model, optimizer, scheduler, best_model_state_fold
        if torch.cuda.is_available(): torch.cuda.empty_cache(); gc.collect()

    if not fold_metrics_list:
        print("Warning: No metrics were collected during cross-validation.")
        return {}, []

    df_fold_metrics = pd.DataFrame(fold_metrics_list)
    mean_metrics = df_fold_metrics.mean()
    std_metrics = df_fold_metrics.std()
    
    results_summary = {}
    for metric in ['auc', 'f1', 'accuracy', 'precision', 'recall']:
        results_summary[f'mean_{metric}'] = mean_metrics.get(metric, np.nan)
        results_summary[f'std_{metric}'] = std_metrics.get(metric, np.nan)

    print("\n--- Cross-Validation Summary (based on true best AUC epochs) ---")
    for key, value in results_summary.items():
        print(f"  {key}: {value:.4f}")

    return results_summary, df_fold_metrics, roc_data_per_fold

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

data = torch.load('multi_view_pdl1_data_lesions_247_thresh_07_robust.pt')
DIM_CLINICAL = data['patient'].x_clinical.shape[1]
DIM_PATHOLOGY = data['patient'].x_pathology.shape[1] # Dimension of GLCM features
DIM_RADIOLOGY = data['lesion'].x.shape[1] 
print(f"Clinical Dim: {DIM_CLINICAL}, Path Dim: {DIM_PATHOLOGY}, Rad Dim: {DIM_RADIOLOGY}")
# -------------------------------------------

if any(d is None or d <= 0 for d in [DIM_CLINICAL, DIM_PATHOLOGY, DIM_RADIOLOGY]):
     raise ValueError("Please determine and set the actual input dimensions (DIM_...)")

train_config_new = {
    # --- Data and Device ---
    'data_path': 'multi_view_pdl1_data_lesions_247_thresh_07_robust.pt', 
    'device': torch.device('cuda' if torch.cuda.is_available() else 'cpu'),

    # --- KFold and Epochs ---
    'n_splits': 10,      
    'epochs': 100,       
    'patience': 10,
    'patience_early_stopping': 35,
    # --- Optimizer ---
    'lr': 0.0001,        # Learning rate (might need to be smaller for complex models)
    'wd': 1e-5,          # Weight decay for AdamW

    # --- Loss Component Weights ---
    # These are critical and require tuning!
    'loss_weights': {
        'class': 1.0,          # Weight for the main classification loss
        'cross_cl': 0.5,    
        'rec_struct': 0.1,    
        'kl': 0.001,           
    },
    'annealing': {
        'kl': {
            'start_weight': 0.00001,
            'end_weight': 0.001,     # Matches loss_weights['kl']
            'start_epoch': 40,        # Start KL annealing from epoch 1
            'end_epoch': 100          # Reach target KL weight by epoch 50
        },
        'cross_cl': {
            'start_weight': 0.5,
            'end_weight': 0.5,       # Matches loss_weights['cross_cl']
            'start_epoch': 20,       
            'end_epoch': 100
        },
    },
    # --- Contrastive Learning Temperatures ---
    'cross_cl_temp': 0.1,  # Temperature for cross-view contrastive loss
    'intra_cl_temp': 0.1, 
    'grad_clip_norm': 1.0,
    'print_epoch_freq': 1, 
}

model_config_new = {
    'view_configs': {
        'clinical': {'in_channels': DIM_CLINICAL, 'hidden_channels_vae': 64, 'heads': 8, 'dropout': 0.5, 'num_gnn_layers': 2, 'edge_dim': 1},
        'pathology': {'in_channels': DIM_PATHOLOGY, 'hidden_channels_vae': 64, 'heads': 8, 'dropout': 0.5, 'num_gnn_layers': 2, 'edge_dim': 1}, 
        'radiology': {
            'in_channels': 128,  
            'hidden_channels_vae': 64, 
            'heads': 8, 'dropout': 0.5, 'num_gnn_layers': 2, 'edge_dim': 1 # Patient-patient sim graph
        },
    },
    'radiology_aggregator_config': {
        'lesion_feature_dim': DIM_RADIOLOGY, # dim of one scaled lesion feature vector 1671
        'aggregated_output_dim': 128, # Output dim of aggregator, input to VAE['radiology']
        'attention_hidden_dim': 128,   # Hidden dim for attention MLP inside aggregator
        'dropout': 0.5
    },
    'fusion_config': {
        'hidden_dim_attention': 64,  # Hidden dimension for the Attention MLP within the fusion layer
        'fused_dim': 64,             # Dimension AFTER fusion (should match d_embed if attention returns weighted sum)
        'dropout_class': 0.5,          # Dropout for the final classifier MLP
        'num_fusion_heads': 8,
        'fusion_ffn_multiplier' : 2,
        'num_fusion_transformer_layers': 2
    },
    'classifier_config': {            
         'hidden_dim_classifier': 64
    },
    'projection_head_config': {
        'hidden_dim': 64,    # Hidden dimension of the projection MLP
        'output_dim': 64,    # Output dimension for contrastive learning (can be same as d_embed or different)
        'dropout': 0.5
    },
    'd_embed': 64,                   
    'missing_strategy': 'learnable'
}

full_multi_view_data = torch.load(train_config_new['data_path'])
# results = kfold_train_attention_radiology_cl(full_multi_view_data, model_config_new, train_config_new)

Clinical Dim: 64, Path Dim: 137, Rad Dim: 1671


In [None]:
print("Starting K-Fold Cross-Validation...")
results_summary, df_fold_metrics, roc_data_per_fold = kfold_train_attention_radiology_cl(
    full_multi_view_data, 
    model_config_new, 
    train_config_new
)


Starting K-Fold Cross-Validation...

===== Fold 1/10 =====
  Fold 1 Validation Label Distribution: {0: 8, 1: 17}
  F1 Ep010 TLoss:6.0158 | VLoss:5.9845 (Best VLoss: 6.0601) | ValAUC:0.6912 (Best ValAUC: 0.6838)
  F1 Ep020 TLoss:5.6368 | VLoss:5.1109 (Best VLoss: 5.1752) | ValAUC:0.6618 (Best ValAUC: 0.6985)
  F1 Ep030 TLoss:5.3358 | VLoss:4.6901 (Best VLoss: 4.7251) | ValAUC:0.6471 (Best ValAUC: 0.6985)
  F1 Ep040 TLoss:5.1301 | VLoss:4.5106 (Best VLoss: 4.5207) | ValAUC:0.6250 (Best ValAUC: 0.6985)
  F1 Ep050 TLoss:4.9983 | VLoss:4.3873 (Best VLoss: 4.3902) | ValAUC:0.6029 (Best ValAUC: 0.6985)
  F1 Ep060 TLoss:4.9496 | VLoss:4.3280 (Best VLoss: 4.3303) | ValAUC:0.5882 (Best ValAUC: 0.6985)
  F1 Ep070 TLoss:4.9328 | VLoss:4.2138 (Best VLoss: 4.2242) | ValAUC:0.6029 (Best ValAUC: 0.6985)
  F1 Ep080 TLoss:4.8585 | VLoss:4.1262 (Best VLoss: 4.1390) | ValAUC:0.6324 (Best ValAUC: 0.6985)
  F1 Ep090 TLoss:4.8604 | VLoss:4.0900 (Best VLoss: 4.0859) | ValAUC:0.6324 (Best ValAUC: 0.6985)
  F1 