In [1]:
import torch
from rdkit import Chem
from rdkit.Chem import AllChem, Descriptors
from typing import Tuple
import numpy as np
from torch_geometric.data import Data
import random
from torch_geometric.loader import DataLoader
from torch_geometric.data import Data
from rdkit import RDLogger
from rdkit.Chem import RemoveHs
from datetime import datetime

import os
import json
from tqdm import tqdm
import pickle

# Suppress RDKit warnings
RDLogger.DisableLog('rdApp.warning')
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')



In [2]:
class MolecularFeatureExtractor:
    def __init__(self):
        self.atom_list = list(range(1, 119))
        self.chirality_list = [
            Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
            Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
            Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
            Chem.rdchem.ChiralType.CHI_OTHER
        ]
        self.bond_list = [
            Chem.rdchem.BondType.SINGLE,
            Chem.rdchem.BondType.DOUBLE, 
            Chem.rdchem.BondType.TRIPLE,
            Chem.rdchem.BondType.AROMATIC
        ]
        self.bonddir_list = [
            Chem.rdchem.BondDir.NONE,
            Chem.rdchem.BondDir.ENDUPRIGHT,
            Chem.rdchem.BondDir.ENDDOWNRIGHT
        ]

    def calc_atom_features(self, atom: Chem.Atom) -> Tuple[list, list]:
        """Calculate atom features with better error handling"""
        try:
            # Basic features
            atom_feat = [
                self.atom_list.index(atom.GetAtomicNum()),
                self.chirality_list.index(atom.GetChiralTag())
            ]

            # Physical features with error handling
            phys_feat = []
            
            # Molecular weight contribution
            try:
                contrib_mw = Descriptors.ExactMolWt(Chem.MolFromSmiles(f'[{atom.GetSymbol()}]'))
                phys_feat.append(contrib_mw)
            except:
                phys_feat.append(0.0)
                
            # LogP contribution    
            try:
                contrib_logp = Descriptors.MolLogP(Chem.MolFromSmiles(f'[{atom.GetSymbol()}]'))
                phys_feat.append(contrib_logp)
            except:
                phys_feat.append(0.0)
                
            # Add other physical properties
            phys_feat.extend([
                atom.GetFormalCharge(),
                int(atom.GetHybridization()),
                int(atom.GetIsAromatic()),
                atom.GetTotalNumHs(),
                atom.GetTotalValence(),
                atom.GetDegree()
            ])
            
            return atom_feat, phys_feat
            
        except Exception as e:
            print(f"Error calculating atom features: {e}")
            return [0, 0], [0.0] * 9

    def get_atom_features(self, mol: Chem.Mol) -> Tuple[torch.Tensor, torch.Tensor]:
        """Extract atom features for the whole molecule"""
        atom_feats = []
        phys_feats = []
        
        if mol is None:
            return torch.tensor([[0, 0]], dtype=torch.long), torch.tensor([[0.0] * 9], dtype=torch.float)
            
        for atom in mol.GetAtoms():
            atom_feat, phys_feat = self.calc_atom_features(atom)
            atom_feats.append(atom_feat)
            phys_feats.append(phys_feat)

        x = torch.tensor(atom_feats, dtype=torch.long)
        phys = torch.tensor(phys_feats, dtype=torch.float)
        
        return x, phys
    
    def remove_unbonded_hydrogens(mol):
        params = Chem.RemoveHsParameters()
        params.removeDegreeZero = True
        mol = Chem.RemoveHs(mol, params)
        return mol


    def get_bond_features(self, mol: Chem.Mol) -> Tuple[torch.Tensor, torch.Tensor]:
        """Extract bond features with better error handling"""
        if mol is None:
            return torch.tensor([[0], [0]], dtype=torch.long), torch.tensor([[0.0] * 5], dtype=torch.float)
            
        row, col, edge_feat = [], [], []
        
        for bond in mol.GetBonds():
            try:
                start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
                
                # Add edges in both directions
                row += [start, end]
                col += [end, start]
                
                # Bond features
                bond_type = self.bond_list.index(bond.GetBondType())
                bond_dir = self.bonddir_list.index(bond.GetBondDir())
                
                # Calculate additional properties
                feat = [
                    bond_type,
                    bond_dir,
                    int(bond.GetIsConjugated()),
                    int(self._is_rotatable(bond)),
                    self._get_bond_length(mol, start, end)
                ]
                
                edge_feat.extend([feat, feat])
                
            except Exception as e:
                print(f"Error processing bond: {e}")
                continue

        if not row:  # If no valid bonds were processed
            return torch.tensor([[0], [0]], dtype=torch.long), torch.tensor([[0.0] * 5], dtype=torch.float)

        edge_index = torch.tensor([row, col], dtype=torch.long)
        edge_attr = torch.tensor(edge_feat, dtype=torch.float)
        
        return edge_index, edge_attr

    def _is_rotatable(self, bond: Chem.Bond) -> bool:
        """Check if bond is rotatable"""
        return (bond.GetBondType() == Chem.rdchem.BondType.SINGLE and 
                not bond.IsInRing() and
                len(bond.GetBeginAtom().GetNeighbors()) > 1 and
                len(bond.GetEndAtom().GetNeighbors()) > 1)

    def _get_bond_length(self, mol: Chem.Mol, start: int, end: int) -> float:
        """Get bond length with error handling"""
        try:
            conf = mol.GetConformer()
            if conf.Is3D():
                return Chem.rdMolTransforms.GetBondLength(conf, start, end)
        except:
            pass
        return 0.0

    def process_molecule(self, smiles: str) -> Data:
        """Process SMILES string to graph data"""
        try:
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                print(f"Invalid SMILES: {smiles}")
                return None  # Skip invalid molecules
            mol = RemoveHs(mol)

            # Add explicit hydrogens
            mol = Chem.AddHs(mol, addCoords=True)

            # Sanitize molecule
            Chem.SanitizeMol(mol)

            # Check if the molecule has atoms
            if mol.GetNumAtoms() == 0:
                print("Molecule has no atoms, skipping.")
                return None

            # Generate 3D coordinates
            if not mol.GetNumConformers():
                status = AllChem.EmbedMolecule(mol, AllChem.ETKDG())
                if status != 0:
                    print("Failed to generate 3D conformer")
                    return None  # Skip failed molecules

                # Try MMFF or UFF optimization
                try:
                    AllChem.MMFFOptimizeMolecule(mol)
                except:
                    AllChem.UFFOptimizeMolecule(mol)

            # Extract features
            x_cat, x_phys = self.get_atom_features(mol)
            edge_index, edge_attr = self.get_bond_features(mol)

            return Data(
                x_cat=x_cat, 
                x_phys=x_phys,
                edge_index=edge_index, 
                edge_attr=edge_attr,
                num_nodes=x_cat.size(0)
            )

        except Exception as e:
            print(f"Error processing molecule {smiles}: {e}")
            return None


In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool, MessagePassing
from typing import Tuple, List, Optional
import copy
from dataclasses import dataclass

class MemoryQueue:
    """Memory queue with temporal decay for contrastive learning"""
    def __init__(self, size: int, dim: int, decay: float = 0.99999):
        self.size = size
        self.dim = dim
        self.decay = decay
        self.ptr = 0
        self.full = False
        
        # Initialize queue
        self.queue = nn.Parameter(F.normalize(torch.randn(size, dim), dim=1), requires_grad=False)
        self.queue_age = nn.Parameter(torch.zeros(size), requires_grad=False)
        
#         self.register_buffer("queue", torch.randn(size, dim))
#         self.register_buffer("queue_age", torch.zeros(size))  # Track age of each entry
        self.queue = F.normalize(self.queue, dim=1)
        
    def update_queue(self, keys: torch.Tensor):
        """Update queue with new keys"""
        batch_size = keys.shape[0]
        
        # Increment age of all entries
        self.queue_age += 1
        
        # Add new keys
        if self.ptr + batch_size <= self.size:
            self.queue[self.ptr:self.ptr + batch_size] = keys
            self.queue_age[self.ptr:self.ptr + batch_size] = 0
        else:
            # Handle overflow
            rem = self.size - self.ptr
            self.queue[self.ptr:] = keys[:rem]
            self.queue[:batch_size-rem] = keys[rem:]
            self.queue_age[self.ptr:] = 0
            self.queue_age[:batch_size-rem] = 0
            self.full = True
            
        self.ptr = (self.ptr + batch_size) % self.size
        
    def get_decay_weights(self) -> torch.Tensor:
        """Get temporal decay weights for queue entries"""
        return self.decay ** self.queue_age
        
    def compute_contrastive_loss(self, query: torch.Tensor, positive_key: torch.Tensor, 
                                temperature: float = 0.07) -> torch.Tensor:
        """Compute contrastive loss with temporal decay"""
        # Normalize embeddings
        query = F.normalize(query, dim=1)
        positive_key = F.normalize(positive_key, dim=1)
        queue = F.normalize(self.queue, dim=1)
        
        # Compute logits
        l_pos = torch.einsum('nc,nc->n', [query, positive_key]).unsqueeze(-1)
        l_neg = torch.einsum('nc,ck->nk', [query, queue.T])
        
        # Apply temporal decay to negative samples
        decay_weights = self.get_decay_weights()
        l_neg = l_neg * decay_weights.unsqueeze(0)
        
        # Temperature scaling
        logits = torch.cat([l_pos, l_neg], dim=1) / temperature
        labels = torch.zeros(logits.shape[0], dtype=torch.long, device=query.device)
        
        return F.cross_entropy(logits, labels)

class GraphGenerator(nn.Module):
    """Generator network with proper feature handling"""
    def __init__(self, node_dim: int, edge_dim: int, hidden_dim: int = 128):
        super().__init__()
        
        # Node feature processing
        self.node_encoder = nn.Sequential(
            nn.Linear(node_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        
        # Edge feature processing
        self.edge_encoder = nn.Sequential(
            nn.Linear(edge_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        
        # Graph convolution layers
        self.conv1 = GCNConv(hidden_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.conv3 = GCNConv(hidden_dim, hidden_dim)
        
        # Importance prediction layers
        self.node_importance = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, 1),
            nn.Sigmoid()
        )
        
        self.edge_importance = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )
        
    def normalize_features(self, x_cat, x_phys):
        """Normalize categorical and physical features separately"""
        # Convert categorical features to one-hot
        x_cat = x_cat.float()
        
        # Normalize physical features
        x_phys = x_phys.float()
        if x_phys.size(0) > 1:  # Only normalize if we have more than one sample
            x_phys = (x_phys - x_phys.mean(0)) / (x_phys.std(0) + 1e-5)
            
        return x_cat, x_phys
        
    def forward(self, data) -> Tuple[torch.Tensor, torch.Tensor]:
        # Normalize features
        x_cat, x_phys = self.normalize_features(data.x_cat, data.x_phys)
        
        # Concatenate features
        x = torch.cat([x_cat, x_phys], dim=-1)
        
        edge_index = data.edge_index
        edge_attr = data.edge_attr.float()  # Ensure float type
        
        # Initial feature encoding
        x = self.node_encoder(x)
        edge_attr = self.edge_encoder(edge_attr)
        
        # Graph convolutions
        x = F.relu(self.conv1(x, edge_index))  # Removed edge_attr from GCNConv
        x = F.relu(self.conv2(x, edge_index))
        x = self.conv3(x, edge_index)
        
        # Predict importance scores
        node_scores = self.node_importance(x)
        
        # Edge scores using both connected nodes
        edge_features = torch.cat([
            x[edge_index[0]], 
            x[edge_index[1]]
        ], dim=-1)
        edge_scores = self.edge_importance(edge_features)
        
        return node_scores, edge_scores

def get_model_config(dataset):
    """Get model configuration based on dataset features"""
    sample_data = dataset[0]
    
    # Calculate input dimensions
    node_dim = sample_data.x_cat.shape[1] + sample_data.x_phys.shape[1]
    edge_dim = sample_data.edge_attr.shape[1]
    
    config = GanClConfig(
        node_dim=node_dim,
        edge_dim=edge_dim,
        hidden_dim=128,
        output_dim=128,
        queue_size=65536,
        momentum=0.999,
        temperature=0.07,
        decay=0.99999,
        dropout_ratio=0.25
    )
    
    return config

class GraphDiscriminator(nn.Module):
    """Discriminator/Encoder network"""
    def __init__(self, node_dim: int, edge_dim: int, hidden_dim: int = 128, output_dim: int = 128):
        super().__init__()
        
        # Feature encoding
        self.node_encoder = nn.Sequential(
            nn.Linear(node_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        
        self.edge_encoder = nn.Sequential(
            nn.Linear(edge_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        
        # Graph convolution layers
        self.conv1 = GCNConv(hidden_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.conv3 = GCNConv(hidden_dim, output_dim)
        
        # Projection head for contrastive learning
        self.projection = nn.Sequential(
            nn.Linear(output_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )
        
    def normalize_features(self, x_cat, x_phys):
        """Normalize categorical and physical features separately"""
        # Convert categorical features to one-hot
        x_cat = x_cat.float()
        
        # Normalize physical features
        x_phys = x_phys.float()
        if x_phys.size(0) > 1:  # Only normalize if we have more than one sample
            x_phys = (x_phys - x_phys.mean(0)) / (x_phys.std(0) + 1e-5)
            
        return x_cat, x_phys 
        
    def forward(self, data):
        # Normalize features
        x_cat, x_phys = self.normalize_features(data.x_cat, data.x_phys)
        
        # Concatenate features
        x = torch.cat([x_cat, x_phys], dim=-1)
        
        edge_index = data.edge_index
        edge_attr = data.edge_attr.float()  # Ensure float type
        batch = data.batch
        
        # Initial feature encoding
        x = self.node_encoder(x)
        edge_attr = self.edge_encoder(edge_attr)
        
        # Graph convolutions
        x = F.relu(self.conv1(x, edge_index))  # Removed edge_attr from GCNConv
        x = F.relu(self.conv2(x, edge_index))
        x = self.conv3(x, edge_index)
        
        # Global pooling
        x = global_mean_pool(x, batch)
        
        # Projection
        x = self.projection(x)
        
        return x

@dataclass
class GanClConfig:
    """Configuration for GAN-CL training"""
    node_dim: int
    edge_dim: int
    hidden_dim: int = 128
    output_dim: int = 128
    queue_size: int = 65536
    momentum: float = 0.999
    temperature: float = 0.07
    decay: float = 0.99999
    dropout_ratio: float = 0.25

class MolecularGANCL(nn.Module):
    """Combined GAN and Contrastive Learning framework"""
    def __init__(self, config: GanClConfig):
        super().__init__()
        self.config = config
        
        # Add weight initialization
        def init_weights(m):
            if isinstance(m, nn.Linear):
                torch.nn.init.xavier_uniform_(m.weight)
                m.bias.data.fill_(0.01)
        
        # Initialize networks
        self.generator = GraphGenerator(
            config.node_dim, 
            config.edge_dim, 
            config.hidden_dim * 2
        )
        
        self.encoder = GraphDiscriminator(
            config.node_dim,
            config.edge_dim,
            config.hidden_dim,
            config.output_dim
        )
        self.encoder.apply(init_weights)
        
        # Modified loss weights
        self.contrastive_weight = 1.0
        self.adversarial_weight = 0.1  # Increased from 0.05
        self.similarity_weight = 0.01  # Decreased from 0.1
        
        # Temperature annealing
        self.initial_temperature = 0.1
        self.min_temperature = 0.05        
        
        # Create momentum encoder
        self.momentum_encoder = copy.deepcopy(self.encoder)
        for param in self.momentum_encoder.parameters():
            param.requires_grad = False
            
        # Initialize memory queue
        self.memory_queue = MemoryQueue(
            config.queue_size,
            config.output_dim,
            config.decay
        )
        
    @torch.no_grad()
    def _momentum_update(self):
        """Update momentum encoder"""
        for param_q, param_k in zip(self.encoder.parameters(), 
                                  self.momentum_encoder.parameters()):
            param_k.data = self.config.momentum * param_k.data + \
                          (1 - self.config.momentum) * param_q.data
                          
    def drop_graph_elements(self, data, node_scores: torch.Tensor, 
                          edge_scores: torch.Tensor) -> Data:
        """Apply dropout to graph based on importance scores"""
        # Select elements to keep based on scores and dropout ratio
#         node_mask = (node_scores < self.config.dropout_ratio).float()
#         edge_mask = (edge_scores < self.config.dropout_ratio).float()

        node_mask = (torch.rand_like(node_scores) > self.config.dropout_ratio).float()
        edge_mask = (torch.rand_like(edge_scores) > self.config.dropout_ratio).float()
        
        # Apply masks
        x_cat_new = data.x_cat * node_mask
        x_phys_new = data.x_phys * node_mask
        edge_attr_new = data.edge_attr * edge_mask
        
        # Create new graph data object
        return Data(
            x_cat=x_cat_new,
            x_phys=x_phys_new,
            edge_index=data.edge_index,
            edge_attr=edge_attr_new,
            batch=data.batch
        )
        
    def get_temperature(self, epoch, total_epochs):
        """Anneal temperature during training"""
        progress = epoch / total_epochs
        return max(self.initial_temperature * (1 - progress), self.min_temperature)
    
    def forward(self, data, epoch=0, total_epochs=50):
        # Get current temperature
        temperature = self.get_temperature(epoch, total_epochs)
        
        # Get importance scores from generator
        node_scores, edge_scores = self.generator(data)
        
        # Create perturbed graph
        perturbed_data = self.drop_graph_elements(data, node_scores, edge_scores)
        
        # Get embeddings
        query_emb = self.encoder(perturbed_data)
        with torch.no_grad():
            key_emb = self.momentum_encoder(data)
            original_emb = self.encoder(data).detach()
        
        # Compute losses with modified weights
        contrastive_loss = self.memory_queue.compute_contrastive_loss(
            query_emb, key_emb, temperature
        ) * self.contrastive_weight
        
        adversarial_loss = -F.mse_loss(query_emb, original_emb) * self.adversarial_weight
        similarity_loss = F.mse_loss(query_emb, original_emb) * self.similarity_weight
        
        return contrastive_loss, adversarial_loss, similarity_loss
    
    def get_embeddings(self, data) -> torch.Tensor:
        """Get embeddings for downstream tasks"""
        with torch.no_grad():
            return self.encoder(data)

In [4]:

def save_embeddings(embeddings, labels, filepath):
    """Save embeddings and corresponding labels"""
    with open(filepath, 'wb') as f:
        pickle.dump({
            'embeddings': embeddings,
            'labels': labels
        }, f)

def save_encoder(encoder, save_path, info=None):
    """Save encoder model for downstream tasks"""
    save_dict = {
        'encoder_state_dict': encoder.state_dict(),
        'model_info': info or {}
    }
    torch.save(save_dict, save_path)

def load_encoder(model_path, device='cpu'):
    """Load saved encoder model"""
    checkpoint = torch.load(model_path, map_location=device)
    encoder = GraphDiscriminator(
        node_dim=checkpoint['model_info'].get('node_dim'),
        edge_dim=checkpoint['model_info'].get('edge_dim'),
        hidden_dim=checkpoint['model_info'].get('hidden_dim', 128),
        output_dim=checkpoint['model_info'].get('output_dim', 128)
    )
    encoder.load_state_dict(checkpoint['encoder_state_dict'])
    return encoder        

def save_embeddings_with_smiles(embeddings, smiles_list, filepath):
    """Save embeddings and corresponding SMILES
    
    Args:
        embeddings: Numpy array of embeddings
        smiles_list: List of SMILES strings corresponding to embeddings
        filepath: Path to save the data
    """
    import numpy as np
    
    # Save as npz format
    np.savez(filepath, embeddings=embeddings, smiles=smiles_list)
    print(f"Saved {len(smiles_list)} embeddings to {filepath}")

def extract_and_save_embeddings(model, train_loader, device, filepath):
    """Extract embeddings using a model and save them with SMILES
    
    Args:
        model: GAN-CL model with encoder
        train_loader: DataLoader with molecular data
        device: Device to run extraction on
        filepath: Path to save the embeddings
        
    Returns:
        embeddings: Numpy array of extracted embeddings
        smiles_list: List of SMILES strings
    """
    import torch
    import numpy as np
    from tqdm import tqdm
    
    model.eval()
    all_embeddings = []
    all_smiles = []
    
    with torch.no_grad():
        for batch in tqdm(train_loader, desc="Extracting embeddings"):
            batch = batch.to(device)
            embeddings = model.get_embeddings(batch)
            all_embeddings.append(embeddings.cpu().numpy())
            
            # Extract SMILES from batch
            for data in batch:
                if hasattr(data, 'smiles'):
                    all_smiles.append(data.smiles)
                else:
                    all_smiles.append("unknown")
    
    # Concatenate embeddings
    embeddings = np.vstack(all_embeddings)
    
    # Save embeddings with SMILES
    save_embeddings_with_smiles(embeddings, all_smiles, filepath)
    
    return embeddings, all_smiles

def train_gan_cl_with_embeddings(train_loader, config, device='cuda', 
                                save_dir='./checkpoints', 
                                embedding_dir='./embeddings'):
    """Modified training function that saves before and after embeddings
    
    Args:
        train_loader: DataLoader with molecular data
        config: GAN-CL configuration
        device: Device to run training on
        save_dir: Directory to save checkpoints
        embedding_dir: Directory to save embeddings
        
    Returns:
        model: Trained GAN-CL model
        metrics: Training metrics
        before_embeddings: Embeddings before training
        after_embeddings: Embeddings after training
    """
    import os
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import json
    from tqdm import tqdm
    from datetime import datetime
    
    # Create directories
    os.makedirs(save_dir, exist_ok=True)
    os.makedirs(embedding_dir, exist_ok=True)
    encoder_dir = os.path.join(save_dir, 'encoders')
    os.makedirs(encoder_dir, exist_ok=True)
    
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    
    # Initialize model
    model = MolecularGANCL(config).to(device)
    
    # Extract SMILES from dataloader
    all_smiles = []
    for batch in train_loader:
        for data in batch:
            if hasattr(data, 'smiles'):
                all_smiles.append(data.smiles)
    
    print(f"Found {len(all_smiles)} molecules with SMILES")
    
    # Extract and save embeddings BEFORE training
    print("Extracting embeddings BEFORE training...")
    before_emb_path = os.path.join(embedding_dir, f"before_training_{timestamp}.npz")
    before_embeddings, before_smiles = extract_and_save_embeddings(
        model, train_loader, device, before_emb_path
    )
    
    # Initialize optimizers
    optimizer_encoder = torch.optim.Adam(model.encoder.parameters(), lr=3e-4)
    optimizer_generator = torch.optim.Adam(model.generator.parameters(), lr=1e-4)
    
    # Save initial model info
    model_info = {
        'node_dim': config.node_dim,
        'edge_dim': config.edge_dim,
        'hidden_dim': config.hidden_dim,
        'output_dim': config.output_dim,
        'training_config': config.__dict__
    }
    
    # Training phases as before...
    best_loss = float('inf')
    
    # Training metrics
    metrics = {
        'contrastive_losses': [],
        'adversarial_losses': [],
        'similarity_losses': [],
        'total_losses': []
    }
    
    # Training phases
    print("Phase 1: Pretraining Contrastive Learning...")
    pretrain_epochs = 10
    for epoch in range(pretrain_epochs):
        contrastive_epoch_loss = 0
        
        for batch in tqdm(train_loader, desc=f'Pretrain Epoch {epoch+1}/{pretrain_epochs}'):
            batch = batch.to(device)
            
            # Forward pass (without generator)
            query_emb = model.encoder(batch)
            with torch.no_grad():
                key_emb = model.momentum_encoder(batch)
            
            # Compute contrastive loss
            contrastive_loss = model.memory_queue.compute_contrastive_loss(
                query_emb, key_emb, model.config.temperature
            )
            
            # Update encoder
            optimizer_encoder.zero_grad()
            contrastive_loss.backward()
            optimizer_encoder.step()
            
            # Update momentum encoder
            model._momentum_update()
            
            # Update memory queue
            model.memory_queue.update_queue(key_emb.detach())
            
            contrastive_epoch_loss += contrastive_loss.item()
            
        avg_loss = contrastive_epoch_loss / len(train_loader)
        metrics['contrastive_losses'].append(avg_loss)
        print(f'Pretrain Epoch {epoch+1}, Avg Loss: {avg_loss:.4f}')
        
        # Save pretrained checkpoint
        if (epoch + 1) % 5 == 0:
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_encoder_state_dict': optimizer_encoder.state_dict(),
                'loss': avg_loss,
            }, os.path.join(save_dir, f'pretrain_checkpoint_{epoch+1}.pt'))
    
    print("\nPhase 2: Training GAN-CL...")
    train_epochs = 50
    for epoch in range(train_epochs):
        epoch_losses = {
            'contrastive': 0,
            'adversarial': 0,
            'similarity': 0,
            'total': 0
        }
        
        for batch in tqdm(train_loader, desc=f'Train Epoch {epoch+1}/{train_epochs}'):
            batch = batch.to(device)
            
            # Step 1: Train Encoder
            optimizer_encoder.zero_grad()
            
            # Get importance scores from generator
            with torch.no_grad():
                node_scores, edge_scores = model.generator(batch)
            
            # Create perturbed graph
            perturbed_data = model.drop_graph_elements(batch, node_scores, edge_scores)
            
            # Get embeddings
            query_emb = model.encoder(perturbed_data)
            with torch.no_grad():
                key_emb = model.momentum_encoder(batch)
                original_emb = model.encoder(batch).detach()
            
            # Compute losses for encoder
            contrastive_loss = model.memory_queue.compute_contrastive_loss(
                query_emb, key_emb, model.config.temperature
            )
            similarity_loss = F.mse_loss(query_emb, original_emb)
            
            # Total loss for encoder
            encoder_loss = contrastive_loss + 0.1 * similarity_loss
            
            # Update encoder
            encoder_loss.backward()
            optimizer_encoder.step()
            
            # Update momentum encoder
            model._momentum_update()
            
            # Step 2: Train Generator
            optimizer_generator.zero_grad()
            
            # Get new embeddings for adversarial loss
            node_scores, edge_scores = model.generator(batch)
            perturbed_data = model.drop_graph_elements(batch, node_scores, edge_scores)
            
            with torch.no_grad():
                original_emb = model.encoder(batch)
            perturbed_emb = model.encoder(perturbed_data)
            
            # Compute adversarial loss
            adversarial_loss = -F.mse_loss(perturbed_emb, original_emb)
            
            # Update generator
            adversarial_loss.backward()
            optimizer_generator.step()
            
            # Update memory queue
            model.memory_queue.update_queue(key_emb.detach())
            
            # Update metrics
            epoch_losses['contrastive'] += contrastive_loss.item()
            epoch_losses['adversarial'] += adversarial_loss.item()
            epoch_losses['similarity'] += similarity_loss.item()
            epoch_losses['total'] += encoder_loss.item()
        
        # Average losses
        for k in epoch_losses:
            epoch_losses[k] /= len(train_loader)
            metrics[f'{k}_losses'].append(epoch_losses[k])
        
        print(f'Epoch {epoch+1}, Losses: {epoch_losses}')
        
        # Save checkpoint
        if (epoch + 1) % 10 == 0:
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_encoder_state_dict': optimizer_encoder.state_dict(),
                'optimizer_generator_state_dict': optimizer_generator.state_dict(),
                'losses': epoch_losses,
            }, os.path.join(save_dir, f'gan_cl_checkpoint_{epoch+1}.pt'))
            
        # Extract and save embeddings periodically
        if (epoch + 1) % 10 == 0:
            epoch_emb_path = os.path.join(embedding_dir, f"epoch_{epoch+1}_{timestamp}.npz")
            extract_and_save_embeddings(model, train_loader, device, epoch_emb_path)
    
    # Save final metrics
    with open(os.path.join(save_dir, 'training_metrics.json'), 'w') as f:
        json.dump(metrics, f)
    
    # Save encoder periodically
    if (epoch + 1) % 10 == 0:
        epoch_info = {
            **model_info,
            'epoch': epoch + 1,
            'loss': epoch_losses['total']
        }
        save_encoder(
            model.encoder,
            os.path.join(encoder_dir, f'encoder_epoch_{epoch+1}.pt'),
            epoch_info
        )
    
    # Save best encoder based on total loss
    if epoch_losses['total'] < best_loss:
        best_loss = epoch_losses['total']
        save_encoder(
            model.encoder,
            os.path.join(encoder_dir, f'best_encoder_{timestamp}.pt'),
            {**model_info, 'epoch': epoch + 1, 'loss': best_loss}
        )
    
    # Save final encoder
    save_encoder(
        model.encoder,
        os.path.join(encoder_dir, f'final_encoder_{timestamp}.pt'),
        {**model_info, 'epoch': train_epochs, 'loss': epoch_losses['total']}
    )
    
    # Extract and save embeddings AFTER training
    print("Extracting embeddings AFTER training...")
    after_emb_path = os.path.join(embedding_dir, f"after_training_{timestamp}.npz")
    after_embeddings, after_smiles = extract_and_save_embeddings(
        model, train_loader, device, after_emb_path
    )
    
    return model, metrics, before_embeddings, after_embeddings


In [5]:
def main():
    # Enable anomaly detection during development
    torch.autograd.set_detect_anomaly(True)
    # Your existing data loading code here
    torch.manual_seed(42)
    np.random.seed(42)
    
    print("Starting data loading...")
    extractor = MolecularFeatureExtractor()
    smiles_file = "D:\\PhD\\Chapter3\\Unsupervised_GAN_Code\\pubchem-41-clean.txt"
    
    dataset = []
    failed_smiles = []
    
    with open(smiles_file, 'r') as f:
        for i, line in enumerate(f):
            smiles = line.strip()
            data = extractor.process_molecule(smiles)
            if data is not None:
                # Store original SMILES in the data object
                data.smiles = smiles
                dataset.append(data)
            else:
                failed_smiles.append(smiles)
            
            # Limit dataset size for testing
            if i >= 10000:  # Adjust as needed
                break
    
    print(f"1. Loaded dataset with {len(dataset)} graphs.")
    print(f"2. Failed SMILES count: {len(failed_smiles)}")
    
    if not dataset:
        print("No valid graphs generated.")
        return None
        
    # Setup training
    batch_size = 32
    train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    print(f"3. Created DataLoader with {len(train_loader.dataset)} graphs")
    
    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"4. Using device: {device}")
    
    # Get configuration based on dataset
    config = get_model_config(dataset)
    
    # Train model with embedding extraction
    print("5. Starting GAN-CL training with embedding extraction...")
    model, metrics, before_embeddings, after_embeddings = train_gan_cl_with_embeddings(
        train_loader, 
        config,
        device=device,
        save_dir='./checkpoints',
        embedding_dir='./embeddings'
    )
    
    print("6. Training completed!")
    print("7. Embeddings saved and ready for analysis")
    
    return model, metrics, before_embeddings, after_embeddings

if __name__ == "__main__":
    model, metrics, before_embeddings, after_embeddings = main()

Starting data loading...
1. Loaded dataset with 41 graphs.
2. Failed SMILES count: 0
3. Created DataLoader with 41 graphs
4. Using device: cpu
5. Starting GAN-CL training with embedding extraction...
Found 0 molecules with SMILES
Extracting embeddings BEFORE training...


Extracting embeddings: 100%|█████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 95.72it/s]


Saved 16 embeddings to ./embeddings\before_training_20250301_193202.npz
Phase 1: Pretraining Contrastive Learning...


Pretrain Epoch 1/10: 100%|███████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  7.32it/s]


Pretrain Epoch 1, Avg Loss: 1.7194


Pretrain Epoch 2/10: 100%|███████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  8.41it/s]


Pretrain Epoch 2, Avg Loss: 3.8957


Pretrain Epoch 3/10: 100%|███████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  8.82it/s]


Pretrain Epoch 3, Avg Loss: 4.3224


Pretrain Epoch 4/10: 100%|███████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  9.27it/s]


Pretrain Epoch 4, Avg Loss: 4.6567


Pretrain Epoch 5/10: 100%|███████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 10.01it/s]


Pretrain Epoch 5, Avg Loss: 4.8997


Pretrain Epoch 6/10: 100%|███████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  9.66it/s]


Pretrain Epoch 6, Avg Loss: 4.9795


Pretrain Epoch 7/10: 100%|███████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  8.28it/s]


Pretrain Epoch 7, Avg Loss: 5.1467


Pretrain Epoch 8/10: 100%|███████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  8.53it/s]


Pretrain Epoch 8, Avg Loss: 5.3348


Pretrain Epoch 9/10: 100%|███████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  9.52it/s]


Pretrain Epoch 9, Avg Loss: 5.3925


Pretrain Epoch 10/10: 100%|██████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  8.93it/s]


Pretrain Epoch 10, Avg Loss: 5.4766

Phase 2: Training GAN-CL...


Train Epoch 1/50: 100%|██████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.48it/s]


Epoch 1, Losses: {'contrastive': 5.596378564834595, 'adversarial': -0.00047190464101731777, 'similarity': 0.00038258724089246243, 'total': 5.59641695022583}


Train Epoch 2/50: 100%|██████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  4.11it/s]


Epoch 2, Losses: {'contrastive': 5.658815622329712, 'adversarial': -0.0004642332933144644, 'similarity': 0.00036507094046100974, 'total': 5.6588521003723145}


Train Epoch 3/50: 100%|██████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.80it/s]


Epoch 3, Losses: {'contrastive': 5.806594610214233, 'adversarial': -0.0003904362383764237, 'similarity': 0.0005367717531044036, 'total': 5.806648254394531}


Train Epoch 4/50: 100%|██████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.58it/s]


Epoch 4, Losses: {'contrastive': 5.83513617515564, 'adversarial': -0.0005119767738506198, 'similarity': 0.0004924715176457539, 'total': 5.835185289382935}


Train Epoch 5/50: 100%|██████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.91it/s]


Epoch 5, Losses: {'contrastive': 5.860702037811279, 'adversarial': -0.0004348681395640597, 'similarity': 0.0004323863686295226, 'total': 5.860745191574097}


Train Epoch 6/50: 100%|██████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.68it/s]


Epoch 6, Losses: {'contrastive': 5.844198226928711, 'adversarial': -0.0003904422919731587, 'similarity': 0.00045226978545542806, 'total': 5.84424352645874}


Train Epoch 7/50: 100%|██████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.70it/s]


Epoch 7, Losses: {'contrastive': 5.951338768005371, 'adversarial': -0.00046108711103443056, 'similarity': 0.00044237282418180257, 'total': 5.951383113861084}


Train Epoch 8/50: 100%|██████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.63it/s]


Epoch 8, Losses: {'contrastive': 5.9330055713653564, 'adversarial': -0.0006327366572804749, 'similarity': 0.00040989211993291974, 'total': 5.933046579360962}


Train Epoch 9/50: 100%|██████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.79it/s]


Epoch 9, Losses: {'contrastive': 5.950452566146851, 'adversarial': -0.0005251338734524325, 'similarity': 0.0005411187885329127, 'total': 5.950506687164307}


Train Epoch 10/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.43it/s]


Epoch 10, Losses: {'contrastive': 6.017263889312744, 'adversarial': -0.00046169135021045804, 'similarity': 0.0005470262403832749, 'total': 6.0173187255859375}


Extracting embeddings: 100%|████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 111.51it/s]


Saved 16 embeddings to ./embeddings\epoch_10_20250301_193202.npz


Train Epoch 11/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.62it/s]


Epoch 11, Losses: {'contrastive': 6.051062822341919, 'adversarial': -0.000581209606025368, 'similarity': 0.0005201846070121974, 'total': 6.051114797592163}


Train Epoch 12/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.67it/s]


Epoch 12, Losses: {'contrastive': 6.075205087661743, 'adversarial': -0.000592513766605407, 'similarity': 0.0006574607396032661, 'total': 6.075270891189575}


Train Epoch 13/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.84it/s]


Epoch 13, Losses: {'contrastive': 6.193239688873291, 'adversarial': -0.0005709250108338892, 'similarity': 0.0005657148722093552, 'total': 6.193296194076538}


Train Epoch 14/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.39it/s]


Epoch 14, Losses: {'contrastive': 6.199094772338867, 'adversarial': -0.00078483548713848, 'similarity': 0.0006220081704668701, 'total': 6.199156999588013}


Train Epoch 15/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.23it/s]


Epoch 15, Losses: {'contrastive': 6.176629543304443, 'adversarial': -0.0006407227192539722, 'similarity': 0.00045280103222467005, 'total': 6.1766746044158936}


Train Epoch 16/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.55it/s]


Epoch 16, Losses: {'contrastive': 6.243843078613281, 'adversarial': -0.0006534174899570644, 'similarity': 0.0007087438716553152, 'total': 6.2439141273498535}


Train Epoch 17/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.31it/s]


Epoch 17, Losses: {'contrastive': 6.240437030792236, 'adversarial': -0.0004926197871100157, 'similarity': 0.0007130775193218142, 'total': 6.240508317947388}


Train Epoch 18/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.71it/s]


Epoch 18, Losses: {'contrastive': 6.289941072463989, 'adversarial': -0.0005826707347296178, 'similarity': 0.0010103120584972203, 'total': 6.290042161941528}


Train Epoch 19/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.64it/s]


Epoch 19, Losses: {'contrastive': 6.360161542892456, 'adversarial': -0.0007306703773792833, 'similarity': 0.0006183400982990861, 'total': 6.3602235317230225}


Train Epoch 20/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.84it/s]


Epoch 20, Losses: {'contrastive': 6.288281202316284, 'adversarial': -0.0006001486908644438, 'similarity': 0.0005023618577979505, 'total': 6.288331508636475}


Extracting embeddings: 100%|████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 116.50it/s]


Saved 16 embeddings to ./embeddings\epoch_20_20250301_193202.npz


Train Epoch 21/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.58it/s]


Epoch 21, Losses: {'contrastive': 6.4971232414245605, 'adversarial': -0.0005636138084810227, 'similarity': 0.0006563501083292067, 'total': 6.4971888065338135}


Train Epoch 22/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.73it/s]


Epoch 22, Losses: {'contrastive': 6.38019871711731, 'adversarial': -0.0006153088470455259, 'similarity': 0.0005920703988522291, 'total': 6.380257844924927}


Train Epoch 23/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.29it/s]


Epoch 23, Losses: {'contrastive': 6.41270637512207, 'adversarial': -0.000652632734272629, 'similarity': 0.0005175917176529765, 'total': 6.412758111953735}


Train Epoch 24/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.81it/s]


Epoch 24, Losses: {'contrastive': 6.481417655944824, 'adversarial': -0.0006924055924173445, 'similarity': 0.0005414673214545473, 'total': 6.48147177696228}


Train Epoch 25/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.82it/s]


Epoch 25, Losses: {'contrastive': 6.577829360961914, 'adversarial': -0.000531294965185225, 'similarity': 0.0005799898353870958, 'total': 6.577887296676636}


Train Epoch 26/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.54it/s]


Epoch 26, Losses: {'contrastive': 6.514339208602905, 'adversarial': -0.0006074373086448759, 'similarity': 0.0005319837073329836, 'total': 6.514392375946045}


Train Epoch 27/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.83it/s]


Epoch 27, Losses: {'contrastive': 6.478627920150757, 'adversarial': -0.000607440946623683, 'similarity': 0.0005073939828434959, 'total': 6.4786787033081055}


Train Epoch 28/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.46it/s]


Epoch 28, Losses: {'contrastive': 6.5912299156188965, 'adversarial': -0.0006039880681782961, 'similarity': 0.0004394151910673827, 'total': 6.591273784637451}


Train Epoch 29/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.74it/s]


Epoch 29, Losses: {'contrastive': 6.658490180969238, 'adversarial': -0.0004665793967433274, 'similarity': 0.0007512524316553026, 'total': 6.658565521240234}


Train Epoch 30/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.39it/s]


Epoch 30, Losses: {'contrastive': 6.5169079303741455, 'adversarial': -0.0006816058303229511, 'similarity': 0.0005706476658815518, 'total': 6.516964912414551}


Extracting embeddings: 100%|████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 115.85it/s]


Saved 16 embeddings to ./embeddings\epoch_30_20250301_193202.npz


Train Epoch 31/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.96it/s]


Epoch 31, Losses: {'contrastive': 6.654305934906006, 'adversarial': -0.0005562263249885291, 'similarity': 0.0007086952100507915, 'total': 6.654376745223999}


Train Epoch 32/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.85it/s]


Epoch 32, Losses: {'contrastive': 6.563824892044067, 'adversarial': -0.0005418880900833756, 'similarity': 0.0004958454665029421, 'total': 6.5638744831085205}


Train Epoch 33/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.69it/s]


Epoch 33, Losses: {'contrastive': 6.637869358062744, 'adversarial': -0.0005862440739292651, 'similarity': 0.0006415430689230561, 'total': 6.6379334926605225}


Train Epoch 34/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.75it/s]


Epoch 34, Losses: {'contrastive': 6.571667194366455, 'adversarial': -0.0005910538020543754, 'similarity': 0.0004835508589167148, 'total': 6.571715593338013}


Train Epoch 35/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.76it/s]


Epoch 35, Losses: {'contrastive': 6.728488445281982, 'adversarial': -0.0006582586793228984, 'similarity': 0.0005158727435627952, 'total': 6.728539943695068}


Train Epoch 36/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.79it/s]


Epoch 36, Losses: {'contrastive': 6.735221862792969, 'adversarial': -0.0005385930417105556, 'similarity': 0.0006359810940921307, 'total': 6.735285520553589}


Train Epoch 37/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.57it/s]


Epoch 37, Losses: {'contrastive': 6.712361812591553, 'adversarial': -0.0005380775837693363, 'similarity': 0.0005040030519012362, 'total': 6.712412118911743}


Train Epoch 38/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  4.07it/s]


Epoch 38, Losses: {'contrastive': 6.78844690322876, 'adversarial': -0.0004915576719213277, 'similarity': 0.0008539023110643029, 'total': 6.788532257080078}


Train Epoch 39/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.61it/s]


Epoch 39, Losses: {'contrastive': 6.724242448806763, 'adversarial': -0.0005567369807977229, 'similarity': 0.0008958661928772926, 'total': 6.724332094192505}


Train Epoch 40/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.74it/s]


Epoch 40, Losses: {'contrastive': 6.812379598617554, 'adversarial': -0.0005975289386697114, 'similarity': 0.00046634182217530906, 'total': 6.812426328659058}


Extracting embeddings: 100%|████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 121.20it/s]


Saved 16 embeddings to ./embeddings\epoch_40_20250301_193202.npz


Train Epoch 41/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.65it/s]


Epoch 41, Losses: {'contrastive': 6.721327781677246, 'adversarial': -0.0005296022573020309, 'similarity': 0.0005246388027444482, 'total': 6.721380233764648}


Train Epoch 42/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.90it/s]


Epoch 42, Losses: {'contrastive': 6.893293142318726, 'adversarial': -0.0005508774484042078, 'similarity': 0.0005812654562760144, 'total': 6.893351316452026}


Train Epoch 43/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.73it/s]


Epoch 43, Losses: {'contrastive': 6.861454248428345, 'adversarial': -0.0005620890879072249, 'similarity': 0.0005763945227954537, 'total': 6.861511707305908}


Train Epoch 44/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.71it/s]


Epoch 44, Losses: {'contrastive': 6.811563968658447, 'adversarial': -0.0005953053187113255, 'similarity': 0.0006016761180944741, 'total': 6.811624050140381}


Train Epoch 45/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.61it/s]


Epoch 45, Losses: {'contrastive': 6.849266290664673, 'adversarial': -0.0005060021940153092, 'similarity': 0.0007777492864988744, 'total': 6.849344253540039}


Train Epoch 46/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.91it/s]


Epoch 46, Losses: {'contrastive': 6.788744688034058, 'adversarial': -0.0004619746468961239, 'similarity': 0.0004679817648138851, 'total': 6.7887914180755615}


Train Epoch 47/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.77it/s]


Epoch 47, Losses: {'contrastive': 6.830874919891357, 'adversarial': -0.00039488876063842326, 'similarity': 0.0004902023501927033, 'total': 6.830924034118652}


Train Epoch 48/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.82it/s]


Epoch 48, Losses: {'contrastive': 6.832772731781006, 'adversarial': -0.00047386328515131027, 'similarity': 0.0004629102040780708, 'total': 6.832818984985352}


Train Epoch 49/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.81it/s]


Epoch 49, Losses: {'contrastive': 6.809685707092285, 'adversarial': -0.000608099828241393, 'similarity': 0.0005322796496329829, 'total': 6.809738874435425}


Train Epoch 50/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.79it/s]


Epoch 50, Losses: {'contrastive': 6.788842678070068, 'adversarial': -0.00044866233656648546, 'similarity': 0.0004706869076471776, 'total': 6.788889646530151}


Extracting embeddings: 100%|████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 161.93it/s]


Saved 16 embeddings to ./embeddings\epoch_50_20250301_193202.npz
Extracting embeddings AFTER training...


Extracting embeddings: 100%|████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 102.59it/s]

Saved 16 embeddings to ./embeddings\after_training_20250301_193202.npz
6. Training completed!
7. Embeddings saved and ready for analysis



