In [1]:
import torch
from rdkit import Chem
from rdkit.Chem import AllChem, Descriptors
from typing import Tuple
import numpy as np
from torch_geometric.data import Data
import random
from torch_geometric.loader import DataLoader
from torch_geometric.data import Data
from rdkit import RDLogger
from rdkit.Chem import RemoveHs

# Disable RDKit's 'rdApp.warning' log channel (suppresses RDKit warning
# messages for the rest of the session; no other channel is touched here).
RDLogger.DisableLog('rdApp.warning')

In [2]:
class MolecularFeatureExtractor:
    """Turn SMILES strings into graph tensors (atom features, bond features, edge index).

    ``process_molecule`` is the entry point: it parses a SMILES string with
    RDKit, adds explicit hydrogens, embeds a 3D conformer and returns a
    ``torch_geometric.data.Data`` object, or ``None`` when any step fails.
    """

    # Length of the per-atom physical feature vector built by
    # ``calc_atom_features``: [element weight, element logP, formal charge,
    # hybridization, aromatic flag, total Hs, total valence, degree].
    NUM_PHYS_FEATURES = 8
    # Length of the per-bond feature vector built by ``get_bond_features``:
    # [bond type index, bond direction index, conjugated flag, rotatable flag,
    # bond length].
    NUM_BOND_FEATURES = 5

    def __init__(self):
        # Lookup tables: a feature value is the position of the RDKit value
        # inside the corresponding list.
        self.atom_list = list(range(1, 119))
        self.chirality_list = [
            Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
            Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
            Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
            Chem.rdchem.ChiralType.CHI_OTHER
        ]
        self.bond_list = [
            Chem.rdchem.BondType.SINGLE,
            Chem.rdchem.BondType.DOUBLE,
            Chem.rdchem.BondType.TRIPLE,
            Chem.rdchem.BondType.AROMATIC
        ]
        self.bonddir_list = [
            Chem.rdchem.BondDir.NONE,
            Chem.rdchem.BondDir.ENDUPRIGHT,
            Chem.rdchem.BondDir.ENDDOWNRIGHT
        ]
        # symbol -> (exact weight, logP) of the single-atom molecule '[symbol]'.
        # Both values depend only on the element symbol, so they are computed
        # once per element instead of once per atom.
        self._element_cache = {}

    def _element_descriptors(self, symbol: str) -> Tuple[float, float]:
        """Return (ExactMolWt, MolLogP) of the single-atom molecule ``[symbol]``.

        Each descriptor independently falls back to 0.0 when RDKit cannot
        compute it (for example when the bracket SMILES does not parse).
        """
        cached = self._element_cache.get(symbol)
        if cached is not None:
            return cached

        try:
            element_mol = Chem.MolFromSmiles(f'[{symbol}]')
        except Exception:
            element_mol = None

        try:
            contrib_mw = Descriptors.ExactMolWt(element_mol)
        except Exception:
            contrib_mw = 0.0

        try:
            contrib_logp = Descriptors.MolLogP(element_mol)
        except Exception:
            contrib_logp = 0.0

        cached = (contrib_mw, contrib_logp)
        self._element_cache[symbol] = cached
        return cached

    def calc_atom_features(self, atom: Chem.Atom) -> Tuple[list, list]:
        """Compute the categorical and physical feature lists of one atom.

        Returns:
            ``(atom_feat, phys_feat)`` where ``atom_feat`` is
            ``[atomic-number index, chirality index]`` and ``phys_feat`` has
            ``NUM_PHYS_FEATURES`` entries. If anything fails (e.g. an atomic
            number or chiral tag that is not in the lookup tables) the error is
            printed and all-zero vectors of the same lengths are returned, so
            every atom of a molecule yields rows of identical width.
        """
        try:
            atom_feat = [
                self.atom_list.index(atom.GetAtomicNum()),
                self.chirality_list.index(atom.GetChiralTag())
            ]

            contrib_mw, contrib_logp = self._element_descriptors(atom.GetSymbol())

            phys_feat = [
                contrib_mw,
                contrib_logp,
                atom.GetFormalCharge(),
                int(atom.GetHybridization()),
                int(atom.GetIsAromatic()),
                atom.GetTotalNumHs(),
                atom.GetTotalValence(),
                atom.GetDegree()
            ]

            return atom_feat, phys_feat

        except Exception as e:
            print(f"Error calculating atom features: {e}")
            return [0, 0], [0.0] * self.NUM_PHYS_FEATURES

    def get_atom_features(self, mol: Chem.Mol) -> Tuple[torch.Tensor, torch.Tensor]:
        """Extract atom features for the whole molecule.

        Returns:
            ``(x, phys)``: a long tensor with one ``[atomic-number index,
            chirality index]`` row per atom and a float tensor with one
            physical-feature row per atom. For ``mol is None`` a single
            all-zero row is returned for each.
        """
        if mol is None:
            return (torch.tensor([[0, 0]], dtype=torch.long),
                    torch.tensor([[0.0] * self.NUM_PHYS_FEATURES], dtype=torch.float))

        atom_feats = []
        phys_feats = []
        for atom in mol.GetAtoms():
            atom_feat, phys_feat = self.calc_atom_features(atom)
            atom_feats.append(atom_feat)
            phys_feats.append(phys_feat)

        x = torch.tensor(atom_feats, dtype=torch.long)
        phys = torch.tensor(phys_feats, dtype=torch.float)

        return x, phys

    @staticmethod
    def remove_unbonded_hydrogens(mol):
        """Return ``mol`` after ``Chem.RemoveHs`` with ``removeDegreeZero`` enabled.

        Declared as a static method because it takes no ``self``; it can be
        called on the class or on an instance.
        """
        params = Chem.RemoveHsParameters()
        params.removeDegreeZero = True
        mol = Chem.RemoveHs(mol, params)
        return mol

    def _placeholder_edges(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Single placeholder edge (0 -> 0) with all-zero attributes."""
        return (torch.tensor([[0], [0]], dtype=torch.long),
                torch.tensor([[0.0] * self.NUM_BOND_FEATURES], dtype=torch.float))

    def get_bond_features(self, mol: Chem.Mol) -> Tuple[torch.Tensor, torch.Tensor]:
        """Extract the edge index and bond features of a molecule.

        Every bond contributes two directed edges (both directions) that share
        one feature vector. A bond whose features cannot be computed (e.g. a
        bond type or direction missing from the lookup tables) is reported and
        skipped entirely.

        Returns:
            ``(edge_index, edge_attr)``: a ``(2, E)`` long tensor and an
            ``(E, NUM_BOND_FEATURES)`` float tensor. When ``mol`` is ``None``
            or no bond could be processed, a single placeholder edge 0 -> 0
            with zero attributes is returned.
        """
        if mol is None:
            return self._placeholder_edges()

        row, col, edge_feat = [], [], []

        for bond in mol.GetBonds():
            try:
                start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()

                # Build the complete feature vector first. The lookups below
                # can raise, and the edge lists must only grow once the bond is
                # known to be valid; otherwise edge_index and edge_attr would
                # end up with different numbers of edges.
                feat = [
                    self.bond_list.index(bond.GetBondType()),
                    self.bonddir_list.index(bond.GetBondDir()),
                    int(bond.GetIsConjugated()),
                    int(self._is_rotatable(bond)),
                    self._get_bond_length(mol, start, end)
                ]
            except Exception as e:
                print(f"Error processing bond: {e}")
                continue

            # Add edges in both directions
            row += [start, end]
            col += [end, start]
            edge_feat.extend([feat, feat])

        if not row:  # If no valid bonds were processed
            return self._placeholder_edges()

        edge_index = torch.tensor([row, col], dtype=torch.long)
        edge_attr = torch.tensor(edge_feat, dtype=torch.float)

        return edge_index, edge_attr

    def _is_rotatable(self, bond: Chem.Bond) -> bool:
        """Return True for an acyclic single bond whose two atoms each have more than one neighbor."""
        return (bond.GetBondType() == Chem.rdchem.BondType.SINGLE and
                not bond.IsInRing() and
                len(bond.GetBeginAtom().GetNeighbors()) > 1 and
                len(bond.GetEndAtom().GetNeighbors()) > 1)

    def _get_bond_length(self, mol: Chem.Mol, start: int, end: int) -> float:
        """Return the bond length from the molecule's 3D conformer, or 0.0.

        0.0 is returned when the conformer is missing, is not 3D, or the
        length cannot be computed.
        """
        try:
            conf = mol.GetConformer()
            if conf.Is3D():
                return Chem.rdMolTransforms.GetBondLength(conf, start, end)
        except Exception:
            # Deliberately best-effort: a missing length is encoded as 0.0.
            pass
        return 0.0

    def process_molecule(self, smiles: str) -> Data:
        """Convert a SMILES string into a graph ``Data`` object.

        Returns:
            A ``Data`` object with ``x_cat``, ``x_phys``, ``edge_index``,
            ``edge_attr`` and ``num_nodes``, or ``None`` when the SMILES is
            invalid, the molecule has no atoms, conformer embedding fails, or
            any other step raises (the reason is printed).
        """
        try:
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                print(f"Invalid SMILES: {smiles}")
                return None  # Skip invalid molecules
            mol = RemoveHs(mol)

            # Add explicit hydrogens
            mol = Chem.AddHs(mol, addCoords=True)

            # Sanitize molecule
            Chem.SanitizeMol(mol)

            # Check if the molecule has atoms
            if mol.GetNumAtoms() == 0:
                print("Molecule has no atoms, skipping.")
                return None

            # Generate 3D coordinates
            if not mol.GetNumConformers():
                status = AllChem.EmbedMolecule(mol, AllChem.ETKDG())
                if status != 0:
                    print("Failed to generate 3D conformer")
                    return None  # Skip failed molecules

                # MMFFOptimizeMolecule reports "force field could not be set
                # up" by returning -1 rather than by raising, so the return
                # value is checked in addition to catching exceptions before
                # falling back to UFF.
                try:
                    mmff_status = AllChem.MMFFOptimizeMolecule(mol)
                except Exception:
                    mmff_status = -1
                if mmff_status == -1:
                    AllChem.UFFOptimizeMolecule(mol)

            # Extract features
            x_cat, x_phys = self.get_atom_features(mol)
            edge_index, edge_attr = self.get_bond_features(mol)

            return Data(
                x_cat=x_cat,
                x_phys=x_phys,
                edge_index=edge_index,
                edge_attr=edge_attr,
                num_nodes=x_cat.size(0)
            )

        except Exception as e:
            print(f"Error processing molecule {smiles}: {e}")
            return None


In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool, MessagePassing
from typing import Tuple, List, Optional
import copy
from dataclasses import dataclass

class MemoryQueue:
    """Fixed-size ring buffer of key embeddings with age-based decay weights.

    The queue supplies the negatives of the contrastive loss. Each slot has an
    age (number of ``update_queue`` calls since it was written); negatives are
    down-weighted by ``decay ** age``.

    The queue is not an ``nn.Module``, so a surrounding ``model.to(device)``
    does not move its tensors. Both public methods therefore move the storage
    to the device of the tensors they receive before using it.
    """
    def __init__(self, size: int, dim: int, decay: float = 0.99999):
        self.size = size
        self.dim = dim
        self.decay = decay
        self.ptr = 0        # next slot to be written
        self.full = False   # True once the write pointer has reached the end of the queue

        # Plain tensors (no gradient): random unit-norm rows until real keys
        # overwrite them, and an age counter per slot.
        self.queue = F.normalize(torch.randn(size, dim), dim=1)
        self.queue_age = torch.zeros(size)

    def _sync_device(self, reference: torch.Tensor) -> None:
        """Move the queue storage to the device of ``reference`` if it differs."""
        if self.queue.device != reference.device:
            self.queue = self.queue.to(reference.device)
            self.queue_age = self.queue_age.to(reference.device)

    def update_queue(self, keys: torch.Tensor):
        """Age every slot by one step and write ``keys`` at the write pointer.

        Writing wraps around the end of the queue. If the batch is larger than
        the queue, only the most recent ``size`` keys survive, in the slots
        they would occupy after writing the whole batch sequentially.
        """
        keys = keys.detach()
        batch_size = keys.shape[0]
        self._sync_device(keys)

        # Increment age of all entries
        self.queue_age += 1

        if batch_size > 0:
            # Keys before `offset` would be overwritten by later keys of the
            # same batch, so they are skipped up front.
            offset = max(batch_size - self.size, 0)
            kept = keys[offset:]
            slots = (self.ptr + offset
                     + torch.arange(kept.shape[0], device=self.queue.device)) % self.size
            self.queue[slots] = kept.to(self.queue.dtype)
            self.queue_age[slots] = 0

            if self.ptr + batch_size >= self.size:
                self.full = True

        self.ptr = (self.ptr + batch_size) % self.size

    def get_decay_weights(self) -> torch.Tensor:
        """Return ``decay ** age`` for every queue slot."""
        return self.decay ** self.queue_age

    def compute_contrastive_loss(self, query: torch.Tensor, positive_key: torch.Tensor,
                                temperature: float = 0.07) -> torch.Tensor:
        """Cross-entropy contrastive loss with decayed negatives from the queue.

        Args:
            query: ``(N, dim)`` query embeddings.
            positive_key: ``(N, dim)`` positive keys, row-aligned with ``query``.
            temperature: divisor applied to all logits.

        Returns:
            Scalar loss; the positive logit is class 0 for every row.
        """
        self._sync_device(query)

        # Normalize embeddings
        query = F.normalize(query, dim=1)
        positive_key = F.normalize(positive_key, dim=1)
        queue = F.normalize(self.queue, dim=1)

        # Compute logits
        l_pos = torch.einsum('nc,nc->n', [query, positive_key]).unsqueeze(-1)
        l_neg = torch.einsum('nc,ck->nk', [query, queue.T])

        # Apply temporal decay to negative samples
        decay_weights = self.get_decay_weights()
        l_neg = l_neg * decay_weights.unsqueeze(0)

        # Temperature scaling
        logits = torch.cat([l_pos, l_neg], dim=1) / temperature
        labels = torch.zeros(logits.shape[0], dtype=torch.long, device=query.device)

        return F.cross_entropy(logits, labels)

class GraphGenerator(nn.Module):
    """Generator: predicts one sigmoid-squashed importance score per node and per edge.

    Node inputs are the concatenation of ``data.x_cat`` (cast to float) and the
    standardized ``data.x_phys``. They pass through an MLP encoder and three
    GCN layers; the scores are read out by two small MLP heads.
    """
    def __init__(self, node_dim: int, edge_dim: int, hidden_dim: int = 128):
        super().__init__()

        # Input encoders for node and edge features
        self.node_encoder = self._mlp(node_dim, hidden_dim, hidden_dim)
        self.edge_encoder = self._mlp(edge_dim, hidden_dim, hidden_dim)

        # Message passing stack
        self.conv1 = GCNConv(hidden_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.conv3 = GCNConv(hidden_dim, hidden_dim)

        # Score heads; the edge head sees both endpoint embeddings
        self.node_importance = self._mlp(hidden_dim, hidden_dim // 2, 1, squash=True)
        self.edge_importance = self._mlp(hidden_dim * 2, hidden_dim, 1, squash=True)

    @staticmethod
    def _mlp(in_dim: int, mid_dim: int, out_dim: int, squash: bool = False) -> nn.Sequential:
        """Linear -> ReLU -> Linear, optionally followed by a Sigmoid."""
        layers = [nn.Linear(in_dim, mid_dim), nn.ReLU(), nn.Linear(mid_dim, out_dim)]
        if squash:
            layers.append(nn.Sigmoid())
        return nn.Sequential(*layers)

    def normalize_features(self, x_cat, x_phys):
        """Cast both feature tensors to float and standardize the physical ones.

        ``x_cat`` is only cast to float (no one-hot encoding is applied).
        ``x_phys`` is standardized per column over all rows, using
        ``std + 1e-5`` as divisor; with a single row it is returned as-is.
        """
        cat_out = x_cat.float()
        phys_out = x_phys.float()

        if phys_out.size(0) <= 1:
            # A single row has no meaningful standard deviation.
            return cat_out, phys_out

        centered = phys_out - phys_out.mean(0)
        return cat_out, centered / (phys_out.std(0) + 1e-5)

    def forward(self, data) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(node_scores, edge_scores)`` for the graph(s) in ``data``.

        ``node_scores`` has one row per node and ``edge_scores`` one row per
        column of ``data.edge_index``; both have a single score column.
        """
        cat_feats, phys_feats = self.normalize_features(data.x_cat, data.x_phys)
        edge_index = data.edge_index

        # The edge attributes are encoded, but the result is not consumed
        # below: the GCNConv layers are called with edge_index only.
        edge_attr = self.edge_encoder(data.edge_attr.float())

        # Encode nodes and run the three graph convolutions
        h = self.node_encoder(torch.cat([cat_feats, phys_feats], dim=-1))
        h = F.relu(self.conv1(h, edge_index))
        h = F.relu(self.conv2(h, edge_index))
        h = self.conv3(h, edge_index)

        node_scores = self.node_importance(h)

        # An edge is scored from the embeddings of its two endpoints
        src, dst = edge_index[0], edge_index[1]
        edge_scores = self.edge_importance(torch.cat([h[src], h[dst]], dim=-1))

        return node_scores, edge_scores

def get_model_config(dataset):
    """Build a GanClConfig whose input widths are read from the first graph.

    ``node_dim`` is the number of ``x_cat`` columns plus the number of
    ``x_phys`` columns of ``dataset[0]``; ``edge_dim`` is its number of
    ``edge_attr`` columns. All other hyperparameters are fixed here.
    """
    first_graph = dataset[0]

    return GanClConfig(
        node_dim=first_graph.x_cat.shape[1] + first_graph.x_phys.shape[1],
        edge_dim=first_graph.edge_attr.shape[1],
        hidden_dim=128,
        output_dim=128,
        queue_size=65536,
        momentum=0.999,
        temperature=0.07,
        decay=0.99999,
        dropout_ratio=0.25,
    )

class GraphDiscriminator(nn.Module):
    """Discriminator/encoder: maps each graph to one projected embedding vector.

    Node inputs are the concatenation of ``data.x_cat`` (cast to float) and the
    standardized ``data.x_phys``. They pass through an MLP encoder and three
    GCN layers, are mean-pooled per graph and sent through a projection head.
    """
    def __init__(self, node_dim: int, edge_dim: int, hidden_dim: int = 128, output_dim: int = 128):
        super().__init__()

        # Input encoders for node and edge features
        self.node_encoder = self._mlp(node_dim, hidden_dim, hidden_dim)
        self.edge_encoder = self._mlp(edge_dim, hidden_dim, hidden_dim)

        # Message passing stack; the last layer maps to the output width
        self.conv1 = GCNConv(hidden_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.conv3 = GCNConv(hidden_dim, output_dim)

        # Projection head used for the contrastive objective
        self.projection = self._mlp(output_dim, hidden_dim, output_dim)

    @staticmethod
    def _mlp(in_dim: int, mid_dim: int, out_dim: int) -> nn.Sequential:
        """Linear -> ReLU -> Linear."""
        return nn.Sequential(
            nn.Linear(in_dim, mid_dim),
            nn.ReLU(),
            nn.Linear(mid_dim, out_dim),
        )

    def normalize_features(self, x_cat, x_phys):
        """Cast both feature tensors to float and standardize the physical ones.

        ``x_cat`` is only cast to float (no one-hot encoding is applied).
        ``x_phys`` is standardized per column over all rows, using
        ``std + 1e-5`` as divisor; with a single row it is returned as-is.
        """
        cat_out = x_cat.float()
        phys_out = x_phys.float()

        if phys_out.size(0) <= 1:
            # A single row has no meaningful standard deviation.
            return cat_out, phys_out

        centered = phys_out - phys_out.mean(0)
        return cat_out, centered / (phys_out.std(0) + 1e-5)

    def forward(self, data):
        """Return the projected graph-level embeddings for ``data``.

        Nodes are grouped into graphs by ``data.batch`` for the mean pooling.
        """
        cat_feats, phys_feats = self.normalize_features(data.x_cat, data.x_phys)
        edge_index = data.edge_index
        graph_ids = data.batch

        # The edge attributes are encoded, but the result is not consumed
        # below: the GCNConv layers are called with edge_index only.
        edge_attr = self.edge_encoder(data.edge_attr.float())

        # Encode nodes and run the three graph convolutions
        h = self.node_encoder(torch.cat([cat_feats, phys_feats], dim=-1))
        h = F.relu(self.conv1(h, edge_index))
        h = F.relu(self.conv2(h, edge_index))
        h = self.conv3(h, edge_index)

        # One vector per graph, then the projection head
        pooled = global_mean_pool(h, graph_ids)
        return self.projection(pooled)

@dataclass
class GanClConfig:
    """Configuration for GAN-CL training.

    Holds the network dimensions and hyperparameters that MolecularGANCL reads
    when it builds the generator, the encoder and the memory queue.
    """
    # Width of the per-node input: x_cat columns + x_phys columns.
    node_dim: int
    # Width of the per-edge input (edge_attr columns).
    edge_dim: int
    # Hidden width passed to the generator and the encoder.
    hidden_dim: int = 128
    # Width of the encoder output; also the dimension of the memory queue rows.
    output_dim: int = 128
    # Number of slots in the memory queue.
    queue_size: int = 65536
    # Weight kept from the old momentum-encoder parameters at each momentum update.
    momentum: float = 0.999
    # Temperature that divides the contrastive logits.
    temperature: float = 0.07
    # Base of the age-based decay applied to queue entries (decay ** age).
    decay: float = 0.99999
    # Threshold of the random node/edge masks: an element is kept when a
    # uniform sample exceeds this value.
    dropout_ratio: float = 0.25

class MolecularGANCL(nn.Module):
    """Combined GAN and contrastive-learning framework.

    Holds a generator (importance scores), an encoder, a frozen momentum copy
    of the encoder that is updated as an exponential moving average, and a
    memory queue of past key embeddings used as contrastive negatives.
    """
    def __init__(self, config: GanClConfig):
        super().__init__()
        self.config = config

        # Initialize networks
        self.generator = GraphGenerator(
            config.node_dim,
            config.edge_dim,
            config.hidden_dim
        )

        self.encoder = GraphDiscriminator(
            config.node_dim,
            config.edge_dim,
            config.hidden_dim,
            config.output_dim
        )

        # Momentum encoder: a copy of the encoder that never receives gradients
        self.momentum_encoder = copy.deepcopy(self.encoder)
        for param in self.momentum_encoder.parameters():
            param.requires_grad = False

        # Memory queue of key embeddings (plain object, not a submodule)
        self.memory_queue = MemoryQueue(
            config.queue_size,
            config.output_dim,
            config.decay
        )

    @torch.no_grad()
    def _momentum_update(self):
        """Move the momentum encoder towards the encoder.

        Each momentum parameter becomes
        ``momentum * param_k + (1 - momentum) * param_q``; the update is done
        in place so the parameter tensors keep their identity.
        """
        momentum = self.config.momentum
        for param_q, param_k in zip(self.encoder.parameters(),
                                    self.momentum_encoder.parameters()):
            param_k.mul_(momentum).add_(param_q.detach(), alpha=1.0 - momentum)

    def drop_graph_elements(self, data, node_scores: torch.Tensor,
                          edge_scores: torch.Tensor) -> Data:
        """Return a copy of ``data`` with randomly zeroed node and edge features.

        Each node row and each edge row is kept when a uniform random sample
        exceeds ``config.dropout_ratio`` and zeroed otherwise. The graph
        structure (``edge_index``, ``batch``) is unchanged.

        NOTE(review): the score tensors only determine the shape and device of
        the random masks; their values are not used. The masks therefore do not
        depend on the generator, and no gradient reaches it through this
        method.
        """
        keep_nodes = torch.rand_like(node_scores) > self.config.dropout_ratio
        keep_edges = torch.rand_like(edge_scores) > self.config.dropout_ratio

        # Cast each mask to the dtype of the tensor it multiplies so the
        # perturbed features keep their original dtype (x_cat stays integer
        # instead of being promoted to float).
        x_cat_new = data.x_cat * keep_nodes.to(data.x_cat.dtype)
        x_phys_new = data.x_phys * keep_nodes.to(data.x_phys.dtype)
        edge_attr_new = data.edge_attr * keep_edges.to(data.edge_attr.dtype)

        # Create new graph data object; num_nodes is passed explicitly because
        # the object has no `x` attribute to infer it from.
        return Data(
            x_cat=x_cat_new,
            x_phys=x_phys_new,
            edge_index=data.edge_index,
            edge_attr=edge_attr_new,
            batch=data.batch,
            num_nodes=data.x_cat.size(0)
        )

    def forward(self, data) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Forward pass returning losses for both GAN and CL.

        Side effects: updates the momentum encoder and pushes the new key
        embeddings into the memory queue.

        Returns:
            contrastive_loss: Loss for contrastive learning
            adversarial_loss: Negative mean absolute difference between the
                embeddings of the original and the perturbed graph
            similarity_loss: MSE between the same two embeddings
        """
        # Generate importance scores
        node_scores, edge_scores = self.generator(data)

        # Create perturbed graph
        perturbed_data = self.drop_graph_elements(data, node_scores, edge_scores)

        # Get embeddings
        query_emb = self.encoder(perturbed_data)
        original_emb = self.encoder(data)

        with torch.no_grad():
            self._momentum_update()
            key_emb = self.momentum_encoder(data)

        # Compute losses
        contrastive_loss = self.memory_queue.compute_contrastive_loss(
            query_emb, key_emb, self.config.temperature
        )

        adversarial_loss = -torch.mean(torch.abs(original_emb - query_emb))
        similarity_loss = F.mse_loss(query_emb, original_emb)

        # Update memory queue
        self.memory_queue.update_queue(key_emb.detach())

        return contrastive_loss, adversarial_loss, similarity_loss

    def get_embeddings(self, data) -> torch.Tensor:
        """Return encoder embeddings for ``data`` without tracking gradients."""
        with torch.no_grad():
            return self.encoder(data)

In [4]:
import torch
import numpy as np
from torch_geometric.data import DataLoader
import os
import json
from tqdm import tqdm
import pickle
from datetime import datetime

def save_embeddings(embeddings, labels, filepath):
    """Pickle ``embeddings`` and ``labels`` together into ``filepath``.

    The file contains a dict with the keys ``'embeddings'`` and ``'labels'``.
    """
    payload = {'embeddings': embeddings, 'labels': labels}
    with open(filepath, 'wb') as handle:
        pickle.dump(payload, handle)

def save_encoder(encoder, save_path, info=None):
    """Write the encoder's state dict plus a metadata dict to ``save_path``.

    The checkpoint is a dict with the keys ``'encoder_state_dict'`` and
    ``'model_info'``; a missing or empty ``info`` is stored as ``{}``.
    """
    model_info = info if info else {}
    torch.save(
        {'encoder_state_dict': encoder.state_dict(), 'model_info': model_info},
        save_path,
    )

def load_encoder(model_path, device='cpu'):
    """Rebuild a GraphDiscriminator from a checkpoint written by save_encoder.

    Args:
        model_path: path of the checkpoint file.
        device: device the checkpoint tensors are mapped to and on which the
            returned encoder is placed.

    Returns:
        The encoder with the saved weights loaded, on ``device``.

    Raises:
        ValueError: if the checkpoint's ``model_info`` lacks ``node_dim`` or
            ``edge_dim``, which are required to rebuild the network.
    """
    checkpoint = torch.load(model_path, map_location=device)
    model_info = checkpoint.get('model_info') or {}

    # Fail with a clear message instead of an opaque error inside nn.Linear.
    missing = [key for key in ('node_dim', 'edge_dim') if model_info.get(key) is None]
    if missing:
        raise ValueError(
            f"Checkpoint {model_path} has no {', '.join(missing)} in 'model_info'; "
            "cannot rebuild the encoder"
        )

    encoder = GraphDiscriminator(
        node_dim=model_info['node_dim'],
        edge_dim=model_info['edge_dim'],
        hidden_dim=model_info.get('hidden_dim', 128),
        output_dim=model_info.get('output_dim', 128)
    )
    encoder.load_state_dict(checkpoint['encoder_state_dict'])
    # map_location only affects the loaded tensors; the module itself is
    # created on the CPU and has to be moved explicitly.
    return encoder.to(device)
        
def train_gan_cl(train_loader, config, device='cuda',
                save_dir='./checkpoints',
                embedding_dir='./embeddings',
                pretrain_epochs=10,
                train_epochs=50):
    """Train the GAN-CL model in two phases and save checkpoints along the way.

    Phase 1 trains the encoder with the contrastive loss only. Phase 2
    alternates an encoder step (contrastive + 0.1 * similarity loss on a
    perturbed graph) with a generator step (negative MSE between perturbed and
    original embeddings).

    Args:
        train_loader: loader yielding batched graphs.
        config: GanClConfig used to build the model.
        device: device the model and the batches are moved to.
        save_dir: directory for checkpoints, the metrics JSON and the
            ``encoders`` sub-directory.
        embedding_dir: directory for the periodically pickled embeddings.
        pretrain_epochs: number of phase-1 epochs.
        train_epochs: number of phase-2 epochs.

    Returns:
        ``(model, metrics)`` where ``metrics`` maps loss names to per-epoch
        averages.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    num_batches = len(train_loader)
    if num_batches == 0:
        raise ValueError("train_loader yields no batches; nothing to train on")

    # Create directories
    os.makedirs(save_dir, exist_ok=True)
    os.makedirs(embedding_dir, exist_ok=True)
    encoder_dir = os.path.join(save_dir, 'encoders')
    os.makedirs(encoder_dir, exist_ok=True)

    # Initialize model
    model = MolecularGANCL(config).to(device)

    # Initialize optimizers
    optimizer_encoder = torch.optim.Adam(model.encoder.parameters(), lr=3e-4)
    optimizer_generator = torch.optim.Adam(model.generator.parameters(), lr=1e-4)

    # Metadata stored next to every saved encoder so it can be rebuilt later
    model_info = {
        'node_dim': config.node_dim,
        'edge_dim': config.edge_dim,
        'hidden_dim': config.hidden_dim,
        'output_dim': config.output_dim,
        'training_config': config.__dict__
    }

    best_loss = float('inf')
    final_loss = None  # total loss of the last phase-2 epoch, if any ran

    # Training metrics
    metrics = {
        'contrastive_losses': [],
        'adversarial_losses': [],
        'similarity_losses': [],
        'total_losses': []
    }

    print("Phase 1: Pretraining Contrastive Learning...")
    for epoch in range(pretrain_epochs):
        contrastive_epoch_loss = 0

        for batch in tqdm(train_loader, desc=f'Pretrain Epoch {epoch+1}/{pretrain_epochs}'):
            batch = batch.to(device)

            # Forward pass (without generator)
            query_emb = model.encoder(batch)
            with torch.no_grad():
                key_emb = model.momentum_encoder(batch)

            # Compute contrastive loss
            contrastive_loss = model.memory_queue.compute_contrastive_loss(
                query_emb, key_emb, model.config.temperature
            )

            # Update encoder
            optimizer_encoder.zero_grad()
            contrastive_loss.backward()
            optimizer_encoder.step()

            # Update momentum encoder
            model._momentum_update()

            # Update memory queue
            model.memory_queue.update_queue(key_emb.detach())

            contrastive_epoch_loss += contrastive_loss.item()

        avg_loss = contrastive_epoch_loss / num_batches
        metrics['contrastive_losses'].append(avg_loss)
        print(f'Pretrain Epoch {epoch+1}, Avg Loss: {avg_loss:.4f}')

        # Save pretrained checkpoint
        if (epoch + 1) % 5 == 0:
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_encoder_state_dict': optimizer_encoder.state_dict(),
                'loss': avg_loss,
            }, os.path.join(save_dir, f'pretrain_checkpoint_{epoch+1}.pt'))

    print("\nPhase 2: Training GAN-CL...")
    for epoch in range(train_epochs):
        epoch_losses = {
            'contrastive': 0,
            'adversarial': 0,
            'similarity': 0,
            'total': 0
        }

        for batch in tqdm(train_loader, desc=f'Train Epoch {epoch+1}/{train_epochs}'):
            batch = batch.to(device)

            # Step 1: Train Encoder
            optimizer_encoder.zero_grad()

            # Get importance scores from generator
            with torch.no_grad():
                node_scores, edge_scores = model.generator(batch)

            # Create perturbed graph
            perturbed_data = model.drop_graph_elements(batch, node_scores, edge_scores)

            # Get embeddings
            query_emb = model.encoder(perturbed_data)
            with torch.no_grad():
                key_emb = model.momentum_encoder(batch)
                original_emb = model.encoder(batch).detach()

            # Compute losses for encoder
            contrastive_loss = model.memory_queue.compute_contrastive_loss(
                query_emb, key_emb, model.config.temperature
            )
            similarity_loss = F.mse_loss(query_emb, original_emb)

            # Total loss for encoder
            encoder_loss = contrastive_loss + 0.1 * similarity_loss

            # Update encoder
            encoder_loss.backward()
            optimizer_encoder.step()

            # Update momentum encoder
            model._momentum_update()

            # Step 2: Train Generator
            # NOTE(review): the generator only receives gradients if
            # drop_graph_elements makes the perturbed graph depend on the
            # scores in a differentiable way; verify that this is the case.
            optimizer_generator.zero_grad()

            # Get new embeddings for adversarial loss
            node_scores, edge_scores = model.generator(batch)
            perturbed_data = model.drop_graph_elements(batch, node_scores, edge_scores)

            with torch.no_grad():
                original_emb = model.encoder(batch)
            perturbed_emb = model.encoder(perturbed_data)

            # Compute adversarial loss
            adversarial_loss = -F.mse_loss(perturbed_emb, original_emb)

            # Update generator
            adversarial_loss.backward()
            optimizer_generator.step()

            # Update memory queue
            model.memory_queue.update_queue(key_emb.detach())

            # Update metrics
            epoch_losses['contrastive'] += contrastive_loss.item()
            epoch_losses['adversarial'] += adversarial_loss.item()
            epoch_losses['similarity'] += similarity_loss.item()
            epoch_losses['total'] += encoder_loss.item()

        # Average losses
        for k in epoch_losses:
            epoch_losses[k] /= num_batches
            metrics[f'{k}_losses'].append(epoch_losses[k])
        final_loss = epoch_losses['total']

        print(f'Epoch {epoch+1}, Losses: {epoch_losses}')

        is_periodic_epoch = (epoch + 1) % 10 == 0

        # Save checkpoint
        if is_periodic_epoch:
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_encoder_state_dict': optimizer_encoder.state_dict(),
                'optimizer_generator_state_dict': optimizer_generator.state_dict(),
                'losses': epoch_losses,
            }, os.path.join(save_dir, f'gan_cl_checkpoint_{epoch+1}.pt'))

        # Extract and save embeddings periodically
        if is_periodic_epoch:
            model.eval()
            all_embeddings = []
            all_graphs = []

            with torch.no_grad():
                for batch in train_loader:
                    batch = batch.to(device)
                    embeddings = model.get_embeddings(batch)
                    all_embeddings.append(embeddings.cpu())
                    # Iterating a batch object yields (attribute name, value)
                    # pairs, not graphs; to_data_list() recovers the
                    # individual graphs. They are moved to the CPU first so
                    # the pickle does not hold device tensors.
                    all_graphs.extend(batch.to('cpu').to_data_list())

            all_embeddings = torch.cat(all_embeddings, dim=0)

            # Save embeddings
            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
            save_embeddings(
                all_embeddings.numpy(),
                all_graphs,
                os.path.join(embedding_dir, f'embeddings_epoch_{epoch+1}_{timestamp}.pkl')
            )

            model.train()

        # Save encoder periodically (inside the epoch loop, so every tenth
        # epoch gets its own file)
        if is_periodic_epoch:
            epoch_info = {
                **model_info,
                'epoch': epoch + 1,
                'loss': epoch_losses['total']
            }
            save_encoder(
                model.encoder,
                os.path.join(encoder_dir, f'encoder_epoch_{epoch+1}.pt'),
                epoch_info
            )

        # Save best encoder based on total loss; checked after every epoch so
        # the file really holds the best epoch, not just the last one
        if epoch_losses['total'] < best_loss:
            best_loss = epoch_losses['total']
            save_encoder(
                model.encoder,
                os.path.join(encoder_dir, 'best_encoder.pt'),
                {**model_info, 'epoch': epoch + 1, 'loss': best_loss}
            )

    # Save final metrics
    with open(os.path.join(save_dir, 'training_metrics.json'), 'w') as f:
        json.dump(metrics, f)

    # Save final encoder
    save_encoder(
        model.encoder,
        os.path.join(encoder_dir, 'final_encoder.pt'),
        {**model_info, 'epoch': train_epochs, 'loss': final_loss}
    )

    return model, metrics


In [5]:
def main(
    smiles_file: str = "D:\\PhD\\Chapter3\\Unsupervised_GAN_Code\\pubchem-10m-clean_test.txt",
    batch_size: int = 32,
    save_dir: str = './checkpoints',
    embedding_dir: str = './embeddings',
    seed: int = 42,
    detect_anomaly: bool = True,
):
    """Build the molecular graph dataset, train GAN-CL and export embeddings.

    Steps:
      1. Read one SMILES string per line from ``smiles_file`` and convert each
         with ``MolecularFeatureExtractor.process_molecule``; entries for which
         it returns ``None`` are collected as failures.
      2. Wrap the graphs in a shuffled ``DataLoader`` and train with
         ``train_gan_cl`` using the configuration from ``get_model_config``.
      3. Run the trained model in eval mode over the loader, collect
         ``model.get_embeddings`` per batch, and save embeddings together with
         their graphs via ``save_embeddings``.

    Args:
        smiles_file: Text file with one SMILES string per line. The default is
            the original machine-specific path; pass your own location.
        batch_size: Batch size for the ``DataLoader``.
        save_dir: Forwarded to ``train_gan_cl`` as ``save_dir``.
        embedding_dir: Forwarded to ``train_gan_cl`` as ``embedding_dir`` and
            used as the folder for ``final_embeddings.pkl``.
        seed: Seed applied to ``torch``, ``numpy`` and ``random``.
        detect_anomaly: Passed to ``torch.autograd.set_detect_anomaly``.
            Anomaly detection is a debugging aid that slows training; pass
            ``False`` for real runs.

    Returns:
        ``(model, metrics, all_embeddings, all_graphs)`` on success, where
        ``all_embeddings`` is a NumPy array and ``all_graphs`` is the list of
        per-graph objects in the same (shuffled) order. Returns ``None`` when
        no SMILES could be converted into a graph.
    """
    # Debug aid only: this toggles a global autograd setting.
    torch.autograd.set_detect_anomaly(detect_anomaly)

    # Seed every RNG the notebook imports; the stdlib `random` module was
    # previously left unseeded, which made runs non-reproducible.
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    print("Starting data loading...")
    extractor = MolecularFeatureExtractor()

    dataset = []
    failed_smiles = []

    with open(smiles_file, 'r') as f:
        for line in f:
            smiles = line.strip()
            if not smiles:
                # Blank / whitespace-only lines are not molecules; do not hand
                # them to the extractor or count them as failures.
                continue
            data = extractor.process_molecule(smiles)
            if data is not None:
                dataset.append(data)
            else:
                failed_smiles.append(smiles)

    print(f"1. Loaded dataset with {len(dataset)} graphs.")
    print(f"2. Failed SMILES count: {len(failed_smiles)}")

    if not dataset:
        print("No valid graphs generated.")
        return None

    # Setup training
    train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    print(f"3. Created DataLoader with {len(train_loader.dataset)} graphs")

    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"4. Using device: {device}")

    # Get configuration based on dataset
    config = get_model_config(dataset)

    # Train model
    print("5. Starting GAN-CL training...")
    model, metrics = train_gan_cl(
        train_loader,
        config,
        device=device,
        save_dir=save_dir,
        embedding_dir=embedding_dir
    )

    print("6. Training completed!")

    # Extract embeddings for XAI
    print("7. Extracting final embeddings for XAI...")
    model.eval()
    all_embeddings = []
    all_graphs = []
    with torch.no_grad():
        for batch in tqdm(train_loader, desc="Extracting embeddings"):
            batch = batch.to(device)
            embeddings = model.get_embeddings(batch)
            all_embeddings.append(embeddings.cpu())
            # Keep the individual graphs in the same order as their embeddings;
            # the loader shuffles, so the order differs from `dataset`.
            all_graphs.extend([data for data in batch])

    all_embeddings = torch.cat(all_embeddings, dim=0).numpy()

    # Save final embeddings and graphs next to the periodic ones.
    final_embedding_path = f"{embedding_dir}/final_embeddings.pkl"
    save_embeddings(all_embeddings, all_graphs, final_embedding_path)
    print(f"8. Final embeddings saved to {final_embedding_path}")

    # Print encoder locations.
    # NOTE(review): assumes train_gan_cl writes encoders to <save_dir>/encoders,
    # as the original hard-coded message did — confirm against train_gan_cl.
    print(f"9. Encoders saved in {save_dir}/encoders/:")
    print("   - Best encoder: best_encoder.pt")
    print("   - Final encoder: final_encoder.pt")
    print("   - Periodic encoders: encoder_epoch_*.pt")

    return model, metrics, all_embeddings, all_graphs

if __name__ == "__main__":
    # main() returns None when no valid graph could be built from the SMILES
    # file; unpacking that directly would raise
    # "TypeError: cannot unpack non-iterable NoneType object".
    result = main()
    if result is None:
        # Define the names anyway so later cells fail on an explicit None
        # check rather than on a NameError.
        model = metrics = embeddings = graphs = None
    else:
        model, metrics, embeddings, graphs = result

Starting data loading...
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer


[16:38:35] UFFTYPER: Unrecognized atom type: Se2+2 (17)


Failed to generate 3D conformer


[16:38:38] UFFTYPER: Unrecognized charge state for atom: 1


Failed to generate 3D conformer
1. Loaded dataset with 397 graphs.
2. Failed SMILES count: 5
3. Created DataLoader with 397 graphs
4. Using device: cpu
5. Starting GAN-CL training...




Phase 1: Pretraining Contrastive Learning...


Pretrain Epoch 1/10: 100%|█████████████████████████████████████████████████████████████| 13/13 [00:01<00:00,  6.75it/s]


Pretrain Epoch 1, Avg Loss: 4.6552


Pretrain Epoch 2/10: 100%|█████████████████████████████████████████████████████████████| 13/13 [00:01<00:00,  6.84it/s]


Pretrain Epoch 2, Avg Loss: 6.1092


Pretrain Epoch 3/10: 100%|█████████████████████████████████████████████████████████████| 13/13 [00:02<00:00,  6.34it/s]


Pretrain Epoch 3, Avg Loss: 6.4764


Pretrain Epoch 4/10: 100%|█████████████████████████████████████████████████████████████| 13/13 [00:02<00:00,  6.15it/s]


Pretrain Epoch 4, Avg Loss: 6.7511


Pretrain Epoch 5/10: 100%|█████████████████████████████████████████████████████████████| 13/13 [00:01<00:00,  6.58it/s]


Pretrain Epoch 5, Avg Loss: 6.9143


Pretrain Epoch 6/10: 100%|█████████████████████████████████████████████████████████████| 13/13 [00:01<00:00,  6.84it/s]


Pretrain Epoch 6, Avg Loss: 7.0698


Pretrain Epoch 7/10: 100%|█████████████████████████████████████████████████████████████| 13/13 [00:01<00:00,  6.83it/s]


Pretrain Epoch 7, Avg Loss: 7.1939


Pretrain Epoch 8/10: 100%|█████████████████████████████████████████████████████████████| 13/13 [00:02<00:00,  6.44it/s]


Pretrain Epoch 8, Avg Loss: 7.2831


Pretrain Epoch 9/10: 100%|█████████████████████████████████████████████████████████████| 13/13 [00:02<00:00,  6.38it/s]


Pretrain Epoch 9, Avg Loss: 7.3691


Pretrain Epoch 10/10: 100%|████████████████████████████████████████████████████████████| 13/13 [00:01<00:00,  6.59it/s]


Pretrain Epoch 10, Avg Loss: 7.4304

Phase 2: Training GAN-CL...


Train Epoch 1/50: 100%|████████████████████████████████████████████████████████████████| 13/13 [00:04<00:00,  3.15it/s]


Epoch 1, Losses: {'contrastive': 7.706097199366643, 'adversarial': -0.0006047741581614201, 'similarity': 0.0005707601309180832, 'total': 7.7061543464660645}


Train Epoch 2/50: 100%|████████████████████████████████████████████████████████████████| 13/13 [00:04<00:00,  3.04it/s]


Epoch 2, Losses: {'contrastive': 7.669712250049297, 'adversarial': -0.0005941907449876174, 'similarity': 0.0006228644230689567, 'total': 7.669774568997896}


Train Epoch 3/50: 100%|████████████████████████████████████████████████████████████████| 13/13 [00:04<00:00,  3.17it/s]


Epoch 3, Losses: {'contrastive': 7.708826578580416, 'adversarial': -0.0005725304609558617, 'similarity': 0.0006167498090340254, 'total': 7.708888273972732}


Train Epoch 4/50: 100%|████████████████████████████████████████████████████████████████| 13/13 [00:04<00:00,  3.05it/s]


Epoch 4, Losses: {'contrastive': 7.7228118456327, 'adversarial': -0.000633134560372967, 'similarity': 0.0006356180170909143, 'total': 7.722875338334304}


Train Epoch 5/50: 100%|████████████████████████████████████████████████████████████████| 13/13 [00:04<00:00,  3.07it/s]


Epoch 5, Losses: {'contrastive': 7.722977014688345, 'adversarial': -0.0006155915504608017, 'similarity': 0.0005506445147777692, 'total': 7.723032107720008}


Train Epoch 6/50: 100%|████████████████████████████████████████████████████████████████| 13/13 [00:04<00:00,  3.21it/s]


Epoch 6, Losses: {'contrastive': 7.73885022676908, 'adversarial': -0.0006341060910087365, 'similarity': 0.0005957282824405971, 'total': 7.738909794734075}


Train Epoch 7/50: 100%|████████████████████████████████████████████████████████████████| 13/13 [00:04<00:00,  3.06it/s]


Epoch 7, Losses: {'contrastive': 7.717137300051176, 'adversarial': -0.0006529632597588575, 'similarity': 0.0006308327831972677, 'total': 7.717200425954966}


Train Epoch 8/50: 100%|████████████████████████████████████████████████████████████████| 13/13 [00:04<00:00,  3.17it/s]


Epoch 8, Losses: {'contrastive': 7.726694914010855, 'adversarial': -0.0005400236628842182, 'similarity': 0.0005688592970657807, 'total': 7.726751804351807}


Train Epoch 9/50: 100%|████████████████████████████████████████████████████████████████| 13/13 [00:04<00:00,  3.06it/s]


Epoch 9, Losses: {'contrastive': 7.728627938490647, 'adversarial': -0.0005072676397573489, 'similarity': 0.0005199697938783524, 'total': 7.7286799137408915}


Train Epoch 10/50: 100%|███████████████████████████████████████████████████████████████| 13/13 [00:04<00:00,  3.10it/s]


Epoch 10, Losses: {'contrastive': 7.704584048344539, 'adversarial': -0.0005138122181121546, 'similarity': 0.0005341046467387619, 'total': 7.70463752746582}


Train Epoch 11/50: 100%|███████████████████████████████████████████████████████████████| 13/13 [00:04<00:00,  3.16it/s]


Epoch 11, Losses: {'contrastive': 7.6994566183823805, 'adversarial': -0.0005237226678918188, 'similarity': 0.0005702240898524626, 'total': 7.699513582082895}


Train Epoch 12/50: 100%|███████████████████████████████████████████████████████████████| 13/13 [00:04<00:00,  2.85it/s]


Epoch 12, Losses: {'contrastive': 7.6806116837721605, 'adversarial': -0.00044770564552611456, 'similarity': 0.00043091593901268567, 'total': 7.6806547458355245}


Train Epoch 13/50: 100%|███████████████████████████████████████████████████████████████| 13/13 [00:04<00:00,  3.14it/s]


Epoch 13, Losses: {'contrastive': 7.676091707669771, 'adversarial': -0.00042436516378074884, 'similarity': 0.0004577109827480924, 'total': 7.676137484036959}


Train Epoch 14/50: 100%|███████████████████████████████████████████████████████████████| 13/13 [00:04<00:00,  3.03it/s]


Epoch 14, Losses: {'contrastive': 7.664036787473238, 'adversarial': -0.0005220542348419818, 'similarity': 0.0005538632820109621, 'total': 7.6640921372633715}


Train Epoch 15/50: 100%|███████████████████████████████████████████████████████████████| 13/13 [00:04<00:00,  3.11it/s]


Epoch 15, Losses: {'contrastive': 7.635131689218374, 'adversarial': -0.0005276481477686992, 'similarity': 0.000513652827626524, 'total': 7.635183077592116}


Train Epoch 16/50: 100%|███████████████████████████████████████████████████████████████| 13/13 [00:04<00:00,  3.13it/s]


Epoch 16, Losses: {'contrastive': 7.614208661592924, 'adversarial': -0.0004610629738845791, 'similarity': 0.0004810637798912537, 'total': 7.61425674878634}


Train Epoch 17/50: 100%|███████████████████████████████████████████████████████████████| 13/13 [00:04<00:00,  3.11it/s]


Epoch 17, Losses: {'contrastive': 7.605614295372596, 'adversarial': -0.00046748431318869384, 'similarity': 0.0005194004625082016, 'total': 7.605666197263277}


Train Epoch 18/50: 100%|███████████████████████████████████████████████████████████████| 13/13 [00:04<00:00,  3.15it/s]


Epoch 18, Losses: {'contrastive': 7.584592782534086, 'adversarial': -0.0004899169725831598, 'similarity': 0.0005023234317867229, 'total': 7.584643033834604}


Train Epoch 19/50: 100%|███████████████████████████████████████████████████████████████| 13/13 [00:04<00:00,  2.97it/s]


Epoch 19, Losses: {'contrastive': 7.542650919694167, 'adversarial': -0.0004673386328459646, 'similarity': 0.00044696091656358197, 'total': 7.542695558988131}


Train Epoch 20/50: 100%|███████████████████████████████████████████████████████████████| 13/13 [00:04<00:00,  3.21it/s]


Epoch 20, Losses: {'contrastive': 7.512714055868296, 'adversarial': -0.00046433610035679664, 'similarity': 0.0005237194373666381, 'total': 7.512766434596135}


Train Epoch 21/50: 100%|███████████████████████████████████████████████████████████████| 13/13 [00:04<00:00,  3.05it/s]


Epoch 21, Losses: {'contrastive': 7.488494946406438, 'adversarial': -0.000496066332114144, 'similarity': 0.0005022055747059102, 'total': 7.488545197706956}


Train Epoch 22/50: 100%|███████████████████████████████████████████████████████████████| 13/13 [00:04<00:00,  3.12it/s]


Epoch 22, Losses: {'contrastive': 7.449380251077505, 'adversarial': -0.0004755713251562646, 'similarity': 0.0004981757665518671, 'total': 7.449430062220647}


Train Epoch 23/50: 100%|███████████████████████████████████████████████████████████████| 13/13 [00:04<00:00,  3.19it/s]


Epoch 23, Losses: {'contrastive': 7.442277614886944, 'adversarial': -0.0004891140896898622, 'similarity': 0.0005128935422712507, 'total': 7.442328893221342}


Train Epoch 24/50: 100%|███████████████████████████████████████████████████████████████| 13/13 [00:04<00:00,  3.08it/s]


Epoch 24, Losses: {'contrastive': 7.395681784703181, 'adversarial': -0.0004503590283163178, 'similarity': 0.00046446014983722795, 'total': 7.395728257986216}


Train Epoch 25/50: 100%|███████████████████████████████████████████████████████████████| 13/13 [00:04<00:00,  3.18it/s]


Epoch 25, Losses: {'contrastive': 7.333456442906306, 'adversarial': -0.00045742820321510616, 'similarity': 0.0004798666005416845, 'total': 7.333504456740159}


Train Epoch 26/50: 100%|███████████████████████████████████████████████████████████████| 13/13 [00:04<00:00,  3.06it/s]


Epoch 26, Losses: {'contrastive': 7.327721118927002, 'adversarial': -0.0005170731903770223, 'similarity': 0.0005093262924884374, 'total': 7.327772067143367}


Train Epoch 27/50: 100%|███████████████████████████████████████████████████████████████| 13/13 [00:04<00:00,  3.13it/s]


Epoch 27, Losses: {'contrastive': 7.28427714567918, 'adversarial': -0.0005567484544231915, 'similarity': 0.0005061267852747383, 'total': 7.284327763777513}


Train Epoch 28/50: 100%|███████████████████████████████████████████████████████████████| 13/13 [00:04<00:00,  3.14it/s]


Epoch 28, Losses: {'contrastive': 7.268845411447378, 'adversarial': -0.00047777983351037477, 'similarity': 0.0005082897774767703, 'total': 7.268896212944617}


Train Epoch 29/50: 100%|███████████████████████████████████████████████████████████████| 13/13 [00:04<00:00,  3.04it/s]


Epoch 29, Losses: {'contrastive': 7.243943727933443, 'adversarial': -0.0004906712856609374, 'similarity': 0.0004894758637349766, 'total': 7.24399262208205}


Train Epoch 30/50: 100%|███████████████████████████████████████████████████████████████| 13/13 [00:04<00:00,  3.23it/s]


Epoch 30, Losses: {'contrastive': 7.179747141324556, 'adversarial': -0.0005208368887766623, 'similarity': 0.00043367041731611465, 'total': 7.179790533505953}


Train Epoch 31/50: 100%|███████████████████████████████████████████████████████████████| 13/13 [00:04<00:00,  3.06it/s]


Epoch 31, Losses: {'contrastive': 7.1403242624723, 'adversarial': -0.0005822869626661906, 'similarity': 0.0005131078763112712, 'total': 7.140375577486479}


Train Epoch 32/50: 100%|███████████████████████████████████████████████████████████████| 13/13 [00:04<00:00,  3.11it/s]


Epoch 32, Losses: {'contrastive': 7.12304427073552, 'adversarial': -0.0005459086006829658, 'similarity': 0.0005307310829476381, 'total': 7.1230973830589885}


Train Epoch 33/50: 100%|███████████████████████████████████████████████████████████████| 13/13 [00:04<00:00,  3.12it/s]


Epoch 33, Losses: {'contrastive': 7.087850057161772, 'adversarial': -0.0005509648494458256, 'similarity': 0.0005530352668406871, 'total': 7.087905370272123}


Train Epoch 34/50: 100%|███████████████████████████████████████████████████████████████| 13/13 [00:04<00:00,  2.99it/s]


Epoch 34, Losses: {'contrastive': 7.028794948871319, 'adversarial': -0.0005617415836940592, 'similarity': 0.0005307901726784901, 'total': 7.028847987835224}


Train Epoch 35/50: 100%|███████████████████████████████████████████████████████████████| 13/13 [00:04<00:00,  3.14it/s]


Epoch 35, Losses: {'contrastive': 7.018270089076116, 'adversarial': -0.0005209805542388215, 'similarity': 0.0005638674819447959, 'total': 7.018326502579909}


Train Epoch 36/50: 100%|███████████████████████████████████████████████████████████████| 13/13 [00:04<00:00,  3.09it/s]


Epoch 36, Losses: {'contrastive': 6.990226672245906, 'adversarial': -0.0005851619554540286, 'similarity': 0.0005700545779501016, 'total': 6.9902837093059835}


Train Epoch 37/50: 100%|███████████████████████████████████████████████████████████████| 13/13 [00:04<00:00,  3.24it/s]


Epoch 37, Losses: {'contrastive': 6.949682969313401, 'adversarial': -0.0005281559737900702, 'similarity': 0.0005227932482599639, 'total': 6.94973531136146}


Train Epoch 38/50: 100%|███████████████████████████████████████████████████████████████| 13/13 [00:04<00:00,  3.07it/s]


Epoch 38, Losses: {'contrastive': 6.9080071816077595, 'adversarial': -0.0005638800592770657, 'similarity': 0.0005188588044033027, 'total': 6.90805904681866}


Train Epoch 39/50: 100%|███████████████████████████████████████████████████████████████| 13/13 [00:04<00:00,  3.12it/s]


Epoch 39, Losses: {'contrastive': 6.909762712625357, 'adversarial': -0.0005682685407989013, 'similarity': 0.0005398884957405523, 'total': 6.9098166685837965}


Train Epoch 40/50: 100%|███████████████████████████████████████████████████████████████| 13/13 [00:04<00:00,  3.10it/s]


Epoch 40, Losses: {'contrastive': 6.903601756462684, 'adversarial': -0.0005732920052161297, 'similarity': 0.0006206382263021973, 'total': 6.903663818652813}


Train Epoch 41/50: 100%|███████████████████████████████████████████████████████████████| 13/13 [00:04<00:00,  3.01it/s]


Epoch 41, Losses: {'contrastive': 6.844861617455115, 'adversarial': -0.0006918677513917478, 'similarity': 0.0006565779868441706, 'total': 6.8449272742638225}


Train Epoch 42/50: 100%|███████████████████████████████████████████████████████████████| 13/13 [00:05<00:00,  2.45it/s]


Epoch 42, Losses: {'contrastive': 6.804459755237286, 'adversarial': -0.0005525188753381371, 'similarity': 0.0005759478618319218, 'total': 6.804517379173865}


Train Epoch 43/50: 100%|███████████████████████████████████████████████████████████████| 13/13 [00:04<00:00,  2.81it/s]


Epoch 43, Losses: {'contrastive': 6.794143493358906, 'adversarial': -0.0006987940645418488, 'similarity': 0.0005975442092256764, 'total': 6.794203208043025}


Train Epoch 44/50: 100%|███████████████████████████████████████████████████████████████| 13/13 [00:04<00:00,  3.22it/s]


Epoch 44, Losses: {'contrastive': 6.776719570159912, 'adversarial': -0.0006407162290997803, 'similarity': 0.0006057270063767926, 'total': 6.776780201838567}


Train Epoch 45/50: 100%|███████████████████████████████████████████████████████████████| 13/13 [00:04<00:00,  2.88it/s]


Epoch 45, Losses: {'contrastive': 6.712311011094314, 'adversarial': -0.0006414836075586769, 'similarity': 0.000623107471395857, 'total': 6.7123733300429125}


Train Epoch 46/50: 100%|███████████████████████████████████████████████████████████████| 13/13 [00:04<00:00,  3.09it/s]


Epoch 46, Losses: {'contrastive': 6.677046445699839, 'adversarial': -0.0006836728649572111, 'similarity': 0.0007262945784112582, 'total': 6.677119108346792}


Train Epoch 47/50: 100%|███████████████████████████████████████████████████████████████| 13/13 [00:04<00:00,  2.75it/s]


Epoch 47, Losses: {'contrastive': 6.685994258293738, 'adversarial': -0.0006714242738850701, 'similarity': 0.0007265576233084386, 'total': 6.686066884260911}


Train Epoch 48/50: 100%|███████████████████████████████████████████████████████████████| 13/13 [00:04<00:00,  3.20it/s]


Epoch 48, Losses: {'contrastive': 6.627720502706675, 'adversarial': -0.0007385410628138253, 'similarity': 0.0006017488843868845, 'total': 6.627780657548171}


Train Epoch 49/50: 100%|███████████████████████████████████████████████████████████████| 13/13 [00:04<00:00,  3.00it/s]


Epoch 49, Losses: {'contrastive': 6.609088017390325, 'adversarial': -0.0007057215082638252, 'similarity': 0.0006788085223748707, 'total': 6.609155985025259}


Train Epoch 50/50: 100%|███████████████████████████████████████████████████████████████| 13/13 [00:05<00:00,  2.44it/s]


Epoch 50, Losses: {'contrastive': 6.580638115222637, 'adversarial': -0.0007157067325002012, 'similarity': 0.0006788143520959868, 'total': 6.580705972818228}
6. Training completed!
7. Extracting final embeddings for XAI...


Extracting embeddings: 100%|███████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 98.27it/s]

8. Final embeddings saved to ./embeddings/final_embeddings.pkl
9. Encoders saved in ./checkpoints/encoders/:
   - Best encoder: best_encoder.pt
   - Final encoder: final_encoder.pt
   - Periodic encoders: encoder_epoch_*.pt



