In [15]:
import pandas as pd
import os
import json
import numpy as np
import matplotlib.pyplot as plt
import networkx as nx
import torch 

import pickle
from kg import KnowledgeGraph


In [16]:
def load_data(save_dir, max_patients):
    graphs = []
    files = [f for f in os.listdir(save_dir) if f.endswith(".pt")]

    # Optionally, sort files to ensure consistent order
    files.sort()

    for file in files[:max_patients]:  # Limit to max_patients
        graphs.append(torch.load(os.path.join(save_dir, file)))
    
    return graphs

save_dir = "./Data/saved_graphs_3/train"
train_pg_subgraph = load_data(save_dir, max_patients=None)

save_dir = "./Data/saved_graphs_3/val"
val_pg_subgraph = load_data(save_dir, max_patients=7)

save_dir = "./Data/saved_graphs_3/test"
test_pg_subgraph = load_data(save_dir, max_patients=3)


  graphs.append(torch.load(os.path.join(save_dir, file)))


In [17]:
def identify_overlapping_training_data(train_data, val_data, test_data):
    """
    Identify training samples with gene IDs that overlap with validation and test data,
    and assign them to a new variable.
    
    Args:
        train_data: List of training samples
        val_data: List of validation samples
        test_data: List of test samples
        
    Returns:
        List of training samples that have overlapping gene IDs with val/test data
    """
    # Extract gene IDs from validation and test data
    val_test_gene_ids = set()
    
    # Process validation data
    for sample in val_data:
        if hasattr(sample, 'true_gene_ids'):
            gene_ids = sample.true_gene_ids
            if isinstance(gene_ids, list):
                val_test_gene_ids.update(gene_ids)
            else:
                val_test_gene_ids.add(gene_ids)
    
    # Process test data
    for sample in test_data:
        if hasattr(sample, 'true_gene_ids'):
            gene_ids = sample.true_gene_ids
            if isinstance(gene_ids, list):
                val_test_gene_ids.update(gene_ids)
            else:
                val_test_gene_ids.add(gene_ids)
    
    print(f"Found {len(val_test_gene_ids)} unique gene IDs in validation and test data")
    
    # Find training samples with overlapping gene IDs
    training_pg_subgraph_test = []
    overlapping_gene_map = {}  # To track which genes are found in which training samples
    
    for sample in train_data:
        if hasattr(sample, 'true_gene_ids'):
            gene_ids = sample.true_gene_ids
            found_overlap = False
            
            if isinstance(gene_ids, list):
                # Check if any gene ID in the list is in val_test_gene_ids
                for gene_id in gene_ids:
                    if gene_id in val_test_gene_ids:
                        found_overlap = True
                        overlapping_gene_map[gene_id] = sample
            else:
                # Check if the single gene ID is in val_test_gene_ids
                if gene_ids in val_test_gene_ids:
                    found_overlap = True
                    overlapping_gene_map[gene_ids] = sample
            
            if found_overlap:
                training_pg_subgraph_test.append(sample)
    
    print(f"Identified {len(training_pg_subgraph_test)} training samples with overlapping gene IDs")
    print(f"Found matches for {len(overlapping_gene_map)}/{len(val_test_gene_ids)} validation/test gene IDs")
    
    return training_pg_subgraph_test, overlapping_gene_map

# Identify and assign overlapping training data to training_pg_subgraph_test
training_pg_subgraph_test, gene_to_sample_map = identify_overlapping_training_data(
    train_pg_subgraph, val_pg_subgraph, test_pg_subgraph
)

# Display information abou

Found 10 unique gene IDs in validation and test data
Identified 206 training samples with overlapping gene IDs
Found matches for 10/10 validation/test gene IDs


In [18]:
def extract_overlapping_samples(train_data, val_data, test_data):
    """
    Extract training samples with gene IDs that overlap with validation and test data.
    
    Args:
        train_data: List of training samples
        val_data: List of validation samples
        test_data: List of test samples
        
    Returns:
        List of training samples with overlapping gene IDs
    """
    # Extract gene IDs from validation and test data
    val_test_gene_ids = set()
    
    # Process validation and test data to get all unique gene IDs
    for dataset in [val_data, test_data]:
        for sample in dataset:
            if hasattr(sample, 'true_gene_ids'):
                gene_ids = sample.true_gene_ids
                if isinstance(gene_ids, list):
                    val_test_gene_ids.update(gene_ids)
                else:
                    val_test_gene_ids.add(gene_ids)
    
    # Find and collect training samples with overlapping gene IDs
    overlapping_samples = []
    
    for sample in train_data:
        if hasattr(sample, 'true_gene_ids'):
            gene_ids = sample.true_gene_ids
            found_overlap = False
            
            if isinstance(gene_ids, list):
                # Check if any gene ID in the list is in val_test_gene_ids
                for gene_id in gene_ids:
                    if gene_id in val_test_gene_ids:
                        found_overlap = True
                        break
            else:
                # Check if the single gene ID is in val_test_gene_ids
                if gene_ids in val_test_gene_ids:
                    found_overlap = True
            
            if found_overlap:
                overlapping_samples.append(sample)
    
    return overlapping_samples

# Extract the 520 overlapping training samples
training_pg_subgraph_test = extract_overlapping_samples(
    train_pg_subgraph, val_pg_subgraph, test_pg_subgraph
)

# Verify the extraction
print(f"Successfully extracted {len(training_pg_subgraph_test)} training samples with overlapping gene IDs")

# Optional: Examine a few samples to confirm
if len(training_pg_subgraph_test) > 0:
    sample = training_pg_subgraph_test[0]
    gene_ids = sample.true_gene_ids
    print(f"First sample gene IDs: {gene_ids if not isinstance(gene_ids, list) else gene_ids[:3]}")

Successfully extracted 206 training samples with overlapping gene IDs
First sample gene IDs: [6607]


In [19]:
## To get the number of true gene and set true gene as data.y label
all_true_gene_ids = []

for patient in training_pg_subgraph_test:
    all_true_gene_ids.extend(patient.true_gene_ids)

for patient in val_pg_subgraph:
    all_true_gene_ids.extend(patient.true_gene_ids)
    
for patient in test_pg_subgraph:
    all_true_gene_ids.extend(patient.true_gene_ids)

## Get the unique true gene ids
unique_true_gene_ids = set(all_true_gene_ids)
print("the number of unique true gene ids is:",len(unique_true_gene_ids))

## Mapping all unique true gene ids to a index from 0 to the number of unique true gene ids
gene_id_mapping = {gene_id: idx for idx, gene_id in enumerate(unique_true_gene_ids)}
print(gene_id_mapping)

## Add the true gene ids back to the graph
for patient in training_pg_subgraph_test:
    patient.y = torch.tensor([gene_id_mapping[gene_id] for gene_id in patient.true_gene_ids], dtype=torch.long)

for patient in val_pg_subgraph:
    patient.y = torch.tensor([gene_id_mapping[gene_id] for gene_id in patient.true_gene_ids], dtype=torch.long)

for patient in test_pg_subgraph:
    patient.y = torch.tensor([gene_id_mapping[gene_id] for gene_id in patient.true_gene_ids], dtype=torch.long)

the number of unique true gene ids is: 10
{5281: 0, 642: 1, 2889: 2, 362: 3, 12783: 4, 6607: 5, 5199: 6, 10802: 7, 8566: 8, 6137: 9}


In [20]:
## Preprocess the trainign data, extract only x, y, edge_index

from torch_geometric.data import Data

def preprocess_graph_data(dataset):
    processed_graphs = []
   
    for data in dataset:
        
        new_data = Data(
            edge_index=data.edge_index,
            y=data.y,
            x=data.x,
            original_ids = data.original_ids,
            edge_attr=data.edge_attr
        )
        processed_graphs.append(new_data)
    
    return processed_graphs

train_data = preprocess_graph_data(training_pg_subgraph_test)
val_data = preprocess_graph_data(val_pg_subgraph)
test_data = preprocess_graph_data(test_pg_subgraph)

In [21]:
## Define collate function for handling batched data
def optimized_collate_fn(batch):

    batch_size = len(batch)
    cumsum_nodes = 0
    
    # Adjust edge indices to account for the node offset in each batch
    adjusted_edge_indices = []
    for data in batch:
        edge_index = data.edge_index + cumsum_nodes
        adjusted_edge_indices.append(edge_index)
        cumsum_nodes += data.num_nodes

    # Concatenate with adjusted indices
    x = torch.cat([data.x for data in batch], dim=0)
    y = torch.cat([data.y for data in batch], dim=0)
    edge_index = torch.cat(adjusted_edge_indices, dim=1)
    edge_attr = torch.cat([data.edge_attr for data in batch], dim=0) if batch[0].edge_attr is not None else None
    batch_tensor = torch.cat([torch.full((data.num_nodes,), i, dtype=torch.long) for i, data in enumerate(batch)])
    
    ## Additional attributes
    original_ids = torch.cat([torch.tensor(data.original_ids, dtype=torch.long) if isinstance(data.original_ids, list) else data.original_ids for data in batch if data.original_ids is not None])
    
    return Data(
        x = x,
        y = y,
        edge_index=edge_index,
        edge_attr=edge_attr,
        batch=batch_tensor,
        original_ids=original_ids,
        batch_size=batch_size,
      
    )
## torch.dataloader doesn't consider custom data types
from torch.utils.data import DataLoader

train_loader = DataLoader(train_data, batch_size=512, shuffle=True, collate_fn=optimized_collate_fn)
val_loader = DataLoader(val_data, batch_size=512, collate_fn=optimized_collate_fn)
test_loader = DataLoader(test_data, batch_size=128, collate_fn=optimized_collate_fn)


In [22]:
import torch.nn as nn

class GlobalNodeEmbedding(nn.Module):
    def __init__(self, num_global_nodes, embedding_dim):
        super().__init__()
        self.embedding = nn.Embedding(num_global_nodes, embedding_dim)
   
    def forward(self, node_ids):
        if isinstance(node_ids, list):
            node_ids = torch.tensor(node_ids, dtype=torch.long)
        node_ids = node_ids.view(-1)  # Flatten any multi-dimensional input

        embeddings = self.embedding(node_ids)
        return embeddings


In [23]:
class SoftHistogram(nn.Module):
    def __init__(self, bins, min, max, sigma):
        super().__init__()
        self.bins = bins
        self.min = min
        self.max = max
        self.sigma = sigma
        self.delta = (max - min) / bins
        self.centers = torch.linspace(min + self.delta/2, max - self.delta/2, bins)
        
    def forward(self, x):
        x = x.view(-1, 1)
        centers = self.centers.to(x.device).view(1, -1)
        # Compute Gaussian smoothed histogram
        x = torch.exp(-(x - centers)**2 / (2 * self.sigma**2))
        # Normalize each data point's contribution
        x = x / x.sum(dim=1, keepdim=True)
        # Sum over all data points
        histogram = x.sum(dim=0)
        return histogram

In [24]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import (
    GCNConv, GraphConv, SAGEConv, GIN, 
    global_mean_pool, global_add_pool
)
import numpy as np

## Node-Level Module (F1)
class F1NodeLevelModule(nn.Module):
    def __init__(self, input_dim, hidden_dim, embedding_dim, conv_type='GCN',
                 dropout=0.5, pooling="mean", num_layers=2):
        super().__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.embedding_dim = embedding_dim
        self.dropout = dropout

        # Choose GNN Layer Type
        if conv_type == "GCN":
            Conv = GCNConv
        elif conv_type == "Graph":
            Conv = GraphConv
        elif conv_type == "SAGE":
            Conv = SAGEConv
        elif conv_type == "GIN":
            Conv = lambda in_dim, out_dim: GIN(
                nn=nn.Sequential(
                    nn.Linear(in_dim, out_dim),
                    nn.ReLU(),
                    nn.Linear(out_dim, out_dim)
                )
            )
        else:
            raise ValueError(f"Unknown conv_type: {conv_type}")

        # Define GNN layers
        self.convs = nn.ModuleList()
        self.bns = nn.ModuleList()

        # First Layer
        self.convs.append(Conv(input_dim, hidden_dim))
        self.bns.append(nn.LayerNorm(hidden_dim))  # Using LayerNorm for stability

        # Hidden Layers (Adding Residual Connections)
        for _ in range(num_layers):
            self.convs.append(Conv(hidden_dim, hidden_dim))
            self.bns.append(nn.LayerNorm(hidden_dim))

        # Last Layer
        self.convs.append(Conv(hidden_dim, embedding_dim))
        self.bns.append(nn.LayerNorm(embedding_dim))  # Ensure correct dimension

        self.dropout_layer = nn.Dropout(dropout)
        self.pooling = global_mean_pool if pooling == "mean" else global_add_pool


    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch

        for conv, bn in zip(self.convs[:-1], self.bns[:-1]):
            x = conv(x, edge_index)
            x = bn(x)
            x = F.relu(x)
            x = self.dropout_layer(x)

        x = self.convs[-1](x, edge_index)
        x = self.bns[-1](x)
        x = F.leaky_relu(x, 0.2)
        x = self.dropout_layer(x)

        node_embeddings = x
        graph_embeddings = self.pooling(node_embeddings, batch)

        return node_embeddings, graph_embeddings


## Population-Level Module (F2)
class F2PopulationLevelGraph(nn.Module):
    def __init__(self, embedding_dim, latent_dim, temperature=0.5, threshold=0.1):
        super().__init__()
        
        # Simplify the transformation network
        self.latent_transform = nn.Sequential(
            nn.Linear(embedding_dim, latent_dim),
            nn.LayerNorm(latent_dim),
            nn.LeakyReLU(0.2)
        )

        # Use direct parameters like the tutor
        self.temp = nn.Parameter(torch.tensor(temperature, dtype=torch.float32))
        self.theta = nn.Parameter(torch.tensor(threshold, dtype=torch.float32))
        self.mu = nn.Parameter(torch.tensor(2.0, dtype=torch.float32))
        self.sigma = nn.Parameter(torch.tensor(1.0, dtype=torch.float32))
        
    def forward(self, graph_embeddings):
        # Transform to latent space with the neural network
        latent_space = self.latent_transform(graph_embeddings)
        
        # Compute the pairwise differences exactly as tutor does
        diff = latent_space.unsqueeze(1) - latent_space.unsqueeze(0)
        # Compute the squared norm
        diff = torch.pow(diff, 2).sum(2)
        mask_diff = diff != 0.0
        dist = - torch.sqrt(diff + torch.finfo(torch.float32).eps)
        dist = dist * mask_diff
        
        # Apply temperature and threshold like tutor
        prob_matrix = self.temp * dist + self.theta
        
        # Add eye and sigmoid
        adj = prob_matrix + torch.eye(prob_matrix.shape[0]).to(prob_matrix.device)
        adjacency_matrix = torch.sigmoid(adj)
        
        # Extract edges for the GNN
        edge_indices = torch.nonzero(adjacency_matrix > 0.1, as_tuple=False)
        edge_index = edge_indices.t()
        edge_weight = adjacency_matrix[edge_indices[:, 0], edge_indices[:, 1]]
        
        # Calculate KL loss using tutor's approach
        n_nodes = adjacency_matrix.shape[0]
        softhist = SoftHistogram(bins=n_nodes, min=0.5, max=n_nodes + 0.5, sigma=0.6)
        kl_loss = self._compute_kl_loss(adjacency_matrix, n_nodes, softhist)
        
        return adjacency_matrix, edge_index, edge_weight, kl_loss
    
    def _compute_kl_loss(self, adj, batch_size, softhist):
        # Create binary adjacency matrix
        binarized_adj = torch.zeros(adj.shape).to(adj.device)
        binarized_adj[adj > 0.5] = 1
        
        # Get distribution and degrees
        dist, deg = self._compute_distr(adj * binarized_adj, softhist)
        
        # Get target distribution
        target_dist = self._compute_target_distribution(batch_size)
        
        # Calculate KL divergence
        kl_loss = self._kl_div(dist, target_dist)
        return kl_loss
    
    def _kl_div(self, p, q):
        return torch.sum(p * torch.log(p / (q + 1e-8) + 1e-8))
    
    def _compute_distr(self, adj, softhist):
        deg = adj.sum(-1)
        distr = softhist(deg)
        return distr / torch.sum(distr), deg
    
    def _compute_target_distribution(self, batch_size):
        """Compute Gaussian target distribution like tutor"""
        device = self.mu.device
        target_distribution = torch.zeros(batch_size).to(device)
        
        # Use all bins
        indices = torch.arange(batch_size, device=device)
        
        # Create Gaussian distribution 
        target_distribution = torch.exp(
            -((self.mu - indices) ** 2) / (self.sigma ** 2)
        )
        
        # Normalize
        return target_distribution / target_distribution.sum()

## Classifier Module (F3)
class F3Classifier(nn.Module):
    def __init__(self, input_dim_h, gnn_hidden_dim, num_classes, conv_type="GCN", gnn_layers=2, dropout=0.3):
        super().__init__()
        
        self.input_dim_h = input_dim_h
        self.gnn_hidden_dim = gnn_hidden_dim
        self.num_classes = num_classes
        
        # Choose GNN Layer
        if conv_type == "GCN":
            Conv = GCNConv
        else:
            Conv = GraphConv
        
        # Complex input transformation with multiple pathways
        self.input_transform = nn.Sequential(
            # Branch 1: Direct mapping
            nn.Linear(input_dim_h, gnn_hidden_dim),
            nn.BatchNorm1d(gnn_hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout)
        )
        
        self.deep_transform = nn.Sequential(
            nn.Linear(input_dim_h, gnn_hidden_dim),
            nn.LayerNorm(gnn_hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Dropout(dropout/2),
            nn.Linear(gnn_hidden_dim, gnn_hidden_dim),
            nn.LayerNorm(gnn_hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Dropout(dropout/2)
        )
        
        # Feature attention layer
        self.feature_attention = nn.Sequential(
            nn.Linear(gnn_hidden_dim, gnn_hidden_dim // 4),
            nn.ReLU(),
            nn.Linear(gnn_hidden_dim // 4, gnn_hidden_dim),
            nn.Sigmoid()
        )
        
        # Integration layer
        self.integration = nn.Sequential(
            nn.Linear(gnn_hidden_dim * 2, gnn_hidden_dim),
            nn.LayerNorm(gnn_hidden_dim),
            nn.ReLU()
        )
        
        # Simpler GNN stack
        self.gnn_layers = nn.ModuleList([Conv(gnn_hidden_dim, gnn_hidden_dim) for _ in range(gnn_layers)])
        
        # In F3Classifier
        self.classifier = nn.Sequential(
            nn.Linear(gnn_hidden_dim, gnn_hidden_dim*2),
            nn.BatchNorm1d(gnn_hidden_dim*2),
            nn.LeakyReLU(negative_slope=0.1),
            nn.Dropout(dropout),
            nn.Linear(gnn_hidden_dim*2, gnn_hidden_dim),
            nn.BatchNorm1d(gnn_hidden_dim),
            nn.LeakyReLU(negative_slope=0.1), 
            nn.Linear(gnn_hidden_dim, num_classes)
        )

    def forward(self, h, edge_index, batch, Ap=None, edge_weight=None, gene_ids=None):
        # First transform input
        h = F.relu(self.input_transform(h))
        
        # Process through GNN layers
        for gnn in self.gnn_layers:
            if edge_weight is not None and isinstance(gnn, (GCNConv, GraphConv)):
                h = gnn(h, edge_index, edge_weight)
            else:
                h = gnn(h, edge_index)
            h = F.relu(h)
        
        # Simple pooling
        graph_embeddings = global_mean_pool(h, batch)
        
        # Final classification
        logits = self.classifier(graph_embeddings)
        
        return logits

In [25]:
class GiG(nn.Module):
    def __init__(self, config):
        super().__init__()
        
        self.node_level_module = F1NodeLevelModule(
            input_dim=config["input_dim"],
            hidden_dim=config["hidden_dim"],
            embedding_dim=config["embedding_dim"],
            conv_type=config["conv_type"],
            dropout=config["dropout"]
        )
        
        self.population_level_module = F2PopulationLevelGraph(
            embedding_dim=config["embedding_dim"],
            latent_dim=config["latent_dim"]
        )
        
        self.classifier = F3Classifier(
            input_dim_h=config["embedding_dim"],  # Use embedding_dim from F1
            gnn_hidden_dim=config["gnn_hidden_dim"],
            num_classes=config["num_classes"],
            conv_type=config["conv_type"],
            gnn_layers=config["gnn_layers"],
            dropout=config["dropout"]
        )
    def forward(self, data):
        # Process inputs through node-level module
        node_embeddings, graph_embeddings = self.node_level_module(data)
        
        # Process through population-level module
        adjacency_matrix, edge_index, edge_weight, kl_loss = self.population_level_module(graph_embeddings)
        
        # Classifier takes embeddings and edge structure
        logits = self.classifier(
            node_embeddings,
            edge_index, 
            data.batch,
            adjacency_matrix,
            edge_weight
        )
        
        return logits, adjacency_matrix, kl_loss


In [26]:
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torchmetrics
from torch.optim.lr_scheduler import ReduceLROnPlateau, CosineAnnealingLR

# Define Loss Dictionary
losses = nn.ModuleDict({
    'BCEWithLogitsLoss': nn.BCEWithLogitsLoss(),
    'CrossEntropyLoss': nn.CrossEntropyLoss(),
    'MultiTaskBCE': nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([10]))
})

class GiGTrainer(pl.LightningModule):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.save_hyperparameters(config)
        self.automatic_optimization = False  # Manual optimization

        # Initialize the model
        self.model = GiG(config)

        # Check and define GlobalNodeEmbedding properly
        if "GlobalNodeEmbedding" in globals():
            self.global_node_embedding = GlobalNodeEmbedding(
                num_global_nodes=105220, embedding_dim=config["input_dim"]
            )
        else:
            self.global_node_embedding = nn.Embedding(
                num_embeddings=105220, embedding_dim=config["input_dim"]
            )

        # Set loss function
        self.initial_loss = losses[config["loss"]]
        self.alpha = config["alpha"]

        # Store embeddings for debugging
        self.node_embeddings = None

    def forward(self, data):
        return self.model(data)

    def _shared_step(self, data, addition):
        """Common logic for train, validation, and test steps with NaN protection."""
        # Ensure input embeddings are used correctly
        data.x = self.global_node_embedding(data.original_ids.long().to(self.device))
        
        # Forward pass
        logits, adj_matrix, kl_loss = self.model(data)
        
        # Prepare labels
        labels = data.y.view(-1).long().to(self.device)
        
        # Compute classification loss
        classification_loss = self.initial_loss(logits, data.y.view(-1).long())

        
        # Handle NaN classification loss
        if torch.isnan(classification_loss):
            print(f"Warning: NaN in classification loss detected in {addition} step")
            classification_loss = torch.tensor(5.0, device=self.device)  # Reasonable default
        
        # Apply very small KL loss weight initially
        kl_weight = self.alpha 
    
        # Total loss
        total_loss = classification_loss + kl_weight * kl_loss
        
        # Compute metrics safely
        try:
            acc = torchmetrics.functional.accuracy(
                logits.argmax(dim=-1), labels, task="multiclass", num_classes=logits.shape[1]
            )
            f1 = torchmetrics.functional.f1_score(
                logits.argmax(dim=-1), labels, task="multiclass", num_classes=logits.shape[1]
            )
        except Exception as e:
            print(f"Error computing metrics: {e}")
            acc = torch.tensor(0.0)
            f1 = torch.tensor(0.0)
                
        # Log metrics
        metrics = {
            f"{addition}_acc": acc,
            f"{addition}_f1": f1,
            f"{addition}_loss": total_loss,
            f"{addition}_classification_loss": classification_loss,
            f"{addition}_kl_loss": kl_loss,
        }
        
        return metrics, total_loss

    def training_step(self, batch, batch_idx):
        """Training step with gradient clipping and NaN detection."""
        # Get optimizers
        main_optimizer, lgl_optimizer = self.optimizers()
        
        # Zero gradients
        main_optimizer.zero_grad()
        lgl_optimizer.zero_grad()
        
        # Forward and loss calculation
        metrics, loss = self._shared_step(batch, "train")
    
        # Check if loss is valid
        if not torch.isfinite(loss):
            print(f"Warning: Non-finite loss detected: {loss}")
            # Return a default loss to continue training
            placeholder_loss = torch.tensor(5.0, device=self.device, requires_grad=True)
            # Use log_dict with on_step=True, on_epoch=True
            self.log_dict(metrics, prog_bar=True, batch_size=len(batch.y), 
                        on_step=True, on_epoch=True)
            return placeholder_loss
        
        # Backward pass
        self.manual_backward(loss)
        
        # Check for NaN/Inf gradients before optimizer step
        valid_gradients = True
        for name, param in self.named_parameters():
            if param.grad is not None:
                if not torch.isfinite(param.grad).all():
                    print(f"Warning: Non-finite gradients in {name}")
                    valid_gradients = False
                    param.grad = torch.zeros_like(param.grad)  # Reset problematic gradients
        
        # Clip gradients to prevent explosion
        if valid_gradients:
            torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=1.0)
        
        # Update parameters
        main_optimizer.step()
        lgl_optimizer.step()
        
        # Log metrics - correctly specify on_step and on_epoch
        self.log_dict(metrics, prog_bar=True, batch_size=len(batch.y),
                    on_step=True, on_epoch=True)
        
        return loss

    def validation_step(self, batch, batch_idx):
        metrics, _ = self._shared_step(batch, "val")
        
        with torch.no_grad():  # Ensure we don't compute gradients
            node_embeddings, _ = self.model.node_level_module(batch)
            node_embeddings = node_embeddings.detach().cpu()
            
            # Store embeddings for later analysis
            if self.node_embeddings is None:
                self.node_embeddings = node_embeddings
            else:
                self.node_embeddings = torch.cat([self.node_embeddings, node_embeddings], dim=0)
    
        # Add on_step and on_epoch parameters
        self.log_dict(metrics, prog_bar=True, batch_size=len(batch.y),
                    on_step=False, on_epoch=True)
        return metrics

    def test_step(self, batch, batch_idx):
        metrics, _ = self._shared_step(batch, "test")
        # Add on_step and on_epoch parameters
        self.log_dict(metrics, batch_size=len(batch.y),
                    on_step=False, on_epoch=True)
        return metrics
    
    def configure_optimizers(self):
        """Set up optimizers and learning rate schedulers."""
        # Main parameters excluding population graph learnable scalars
        main_params = [
            param for name_, param in self.model.population_level_module.named_parameters()
            if name_ not in ["log_temperature", "log_threshold", "mu", "sigma"]
        ]
        main_params.extend(self.model.node_level_module.parameters())
        main_params.extend(self.model.classifier.parameters())

        # Define optimizers with weight decay
        main_optimizer = torch.optim.Adam(
            main_params, 
            lr=self.config["lr"],
            weight_decay=self.config.get("weight_decay", 1e-5)
        )
        
        # Updated parameter names for log-space parameters
        # In configure_optimizers method
        lgl_optimizer = torch.optim.Adam([
            self.model.population_level_module.temp,      # Changed from log_temperature
            self.model.population_level_module.theta,     # Changed from log_threshold
            self.model.population_level_module.mu,
            self.model.population_level_module.sigma
        ], lr=self.config["lr_theta_temp"])

        # Define learning rate scheduler
        if self.config["scheduler"] == "ReduceLROnPlateau":
            scheduler_dict = {
                "scheduler": ReduceLROnPlateau(
                    main_optimizer, mode="min", patience=5,
                    threshold=0.001, verbose=True
                ),
                "interval": "epoch",
                "monitor": "val_loss",
                "frequency": 1
            }
        else:
            scheduler_dict = {
                "scheduler": CosineAnnealingLR(main_optimizer, T_max=10),
                "interval": "epoch",
                "monitor": "val_loss",
                "frequency": 1
            }

        return [main_optimizer, lgl_optimizer], [scheduler_dict]

In [27]:
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import (
    ModelCheckpoint,
    EarlyStopping,
    LearningRateMonitor
)
import wandb
from pytorch_lightning.loggers import WandbLogger

# # Setup callbacks
# callbacks = [
#     ModelCheckpoint(
#         monitor='val_loss',
#         dirpath='checkpoints',
#         filename='gig-{epoch:02d}-{val_loss:.2f}',
#         save_top_k=3,
#         mode='min',
#     ),
#     EarlyStopping(
#         monitor='val_loss',
#         min_delta=0.1,
#         patience=10,
#         mode='min',
#         verbose=True
#     ),
#     LearningRateMonitor(logging_interval='epoch')
# ]

# Setup logger
wandb_logger = WandbLogger(project="gig-model")

config = {
    "input_dim": 64,
    "hidden_dim": 256,     
    "embedding_dim": 256,   
    "latent_dim": 16,      
    "gnn_hidden_dim": 32,  
    "num_classes": len(unique_true_gene_ids),
    "gene_embedding_dim": 16,
    
    "conv_type": "GCN",
    "gnn_layers": 5,
    "dropout": 0.3,        
    "lr": 0.0001,            
    "optimizer_lr": 0.001,
    "lr_theta_temp": 0.001,  
    "alpha": 0.01,     
   
    "loss": "CrossEntropyLoss",
    "scheduler": "ReduceLROnPlateau",
    "weight_decay": 1e-3   # Added weight decay
}


model = GiGTrainer(config)

trainer = Trainer(
    max_epochs=300,
    accelerator='gpu',
    devices='auto',
    # callbacks=callbacks,
    logger=wandb_logger,
    deterministic=False,
    benchmark=True
)

trainer.fit(model, train_loader, val_loader)
trainer.test(model, test_loader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/kai/anaconda3/lib/python3.11/site-packages/pytorch_lightning/loggers/wandb.py:397: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.
/home/kai/anaconda3/lib/python3.11/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:654: Checkpoint directory ./gig-model/iyp2qqm8/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name                  | Type                | Params | Mode 
----------------------------------------------------------------------
0 | model                 | GiG                 | 250 K  | train
1 | global_node_embedding | GlobalNodeEmbedding | 6.7 M  | train
2 | initial_loss          | CrossEntropyLoss    | 0      | train
---------------------------------------------------

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=300` reached.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: |          | 0/? [00:00<?, ?it/s]

[{'test_acc': 0.3333333432674408,
  'test_f1': 0.3333333432674408,
  'test_loss': 2.114711284637451,
  'test_classification_loss': 2.094419240951538,
  'test_kl_loss': 2.029202938079834}]