In [1]:
import torch
from rdkit import Chem
from rdkit.Chem import AllChem, Descriptors
from typing import Tuple
import numpy as np
from torch_geometric.data import Data
import random
from torch_geometric.loader import DataLoader
from torch_geometric.data import Data
from rdkit import RDLogger
from rdkit.Chem import RemoveHs
from datetime import datetime

# Silence RDKit's warning-level log channel so per-molecule warnings do not flood cell output.
RDLogger.DisableLog('rdApp.warning')
# Wall-clock stamp (YYYYmmdd_HHMMSS) captured once when this cell runs;
# presumably used later to tag this run's outputs — confirm against later cells.
timestamp = format(datetime.now(), '%Y%m%d_%H%M%S')



In [2]:
class MolecularFeatureExtractor:
    """Convert SMILES strings into PyTorch Geometric ``Data`` graphs using RDKit.

    Node features come in two tensors:

    * ``x_cat`` (long, 2 columns): index of the atomic number in ``atom_list``
      (atomic number - 1) and index of the chiral tag in ``chirality_list``.
    * ``x_phys`` (float, ``NUM_PHYS_FEATURES`` columns): exact mass and Crippen
      logP of the bare element ``[Symbol]``, formal charge, hybridization enum
      value, aromatic flag, total H count, total valence and degree.

    Edge features (float, ``NUM_BOND_FEATURES`` columns): bond-type index,
    bond-direction index, conjugated flag, rotatable flag and 3D bond length.
    Each bond is stored twice in ``edge_index``, once per direction.
    """

    # Columns produced per atom by calc_atom_features (2 element + 6 atom properties).
    NUM_PHYS_FEATURES = 8
    # Columns produced per bond by get_bond_features.
    NUM_BOND_FEATURES = 5

    def __init__(self):
        # Lookup tables: categorical features are stored as positions in these lists.
        self.atom_list = list(range(1, 119))
        self.chirality_list = [
            Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
            Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
            Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
            Chem.rdchem.ChiralType.CHI_OTHER
        ]
        self.bond_list = [
            Chem.rdchem.BondType.SINGLE,
            Chem.rdchem.BondType.DOUBLE,
            Chem.rdchem.BondType.TRIPLE,
            Chem.rdchem.BondType.AROMATIC
        ]
        self.bonddir_list = [
            Chem.rdchem.BondDir.NONE,
            Chem.rdchem.BondDir.ENDUPRIGHT,
            Chem.rdchem.BondDir.ENDDOWNRIGHT
        ]
        # element symbol -> (exact mass, logP) of the isolated element; filled lazily.
        self._element_props_cache = {}

    def _element_properties(self, symbol: str) -> Tuple[float, float]:
        """Return (exact mass, Crippen logP) of a lone ``[symbol]`` atom, cached per element.

        The values describe the isolated element, not the atom's contribution inside
        its molecule. Either value falls back to 0.0 when RDKit cannot compute it
        (e.g. the bracket SMILES does not parse and MolFromSmiles returns None).
        """
        cache = self._element_props_cache
        if symbol not in cache:
            element = Chem.MolFromSmiles(f'[{symbol}]')
            try:
                mass = Descriptors.ExactMolWt(element)
            except Exception:
                mass = 0.0
            try:
                logp = Descriptors.MolLogP(element)
            except Exception:
                logp = 0.0
            cache[symbol] = (mass, logp)
        return cache[symbol]

    def calc_atom_features(self, atom: Chem.Atom) -> Tuple[list, list]:
        """Compute the categorical and physical feature lists for one atom.

        Returns:
            ``(atom_feat, phys_feat)`` where ``atom_feat`` has 2 ints and
            ``phys_feat`` has ``NUM_PHYS_FEATURES`` numbers (layout in the class
            docstring). If any lookup fails (e.g. an atomic number or chiral tag
            missing from the lookup lists), the error is printed and all-zero lists
            of the same widths are returned, so stacking rows stays rectangular.
        """
        try:
            atom_feat = [
                self.atom_list.index(atom.GetAtomicNum()),
                self.chirality_list.index(atom.GetChiralTag())
            ]
            mass, logp = self._element_properties(atom.GetSymbol())
            phys_feat = [
                mass,
                logp,
                atom.GetFormalCharge(),
                int(atom.GetHybridization()),
                int(atom.GetIsAromatic()),
                atom.GetTotalNumHs(),
                atom.GetTotalValence(),
                atom.GetDegree()
            ]
            return atom_feat, phys_feat

        except Exception as e:
            print(f"Error calculating atom features: {e}")
            return [0, 0], [0.0] * self.NUM_PHYS_FEATURES

    def get_atom_features(self, mol: Chem.Mol) -> Tuple[torch.Tensor, torch.Tensor]:
        """Stack per-atom features for ``mol``.

        Returns:
            ``(x_cat, x_phys)``: a long tensor of shape ``(num_atoms, 2)`` and a float
            tensor of shape ``(num_atoms, NUM_PHYS_FEATURES)`` for a molecule with at
            least one atom. For ``mol is None`` a single all-zero row is returned in each.
        """
        if mol is None:
            return (torch.tensor([[0, 0]], dtype=torch.long),
                    torch.tensor([[0.0] * self.NUM_PHYS_FEATURES], dtype=torch.float))

        atom_feats = []
        phys_feats = []
        for atom in mol.GetAtoms():
            atom_feat, phys_feat = self.calc_atom_features(atom)
            atom_feats.append(atom_feat)
            phys_feats.append(phys_feat)

        x = torch.tensor(atom_feats, dtype=torch.long)
        phys = torch.tensor(phys_feats, dtype=torch.float)
        return x, phys

    @staticmethod
    def remove_unbonded_hydrogens(mol):
        """Return a copy of ``mol`` with hydrogens removed, including degree-zero ones.

        Declared static because it needs no instance state. It can be called either as
        ``MolecularFeatureExtractor.remove_unbonded_hydrogens(mol)`` or on an instance.
        """
        params = Chem.RemoveHsParameters()
        params.removeDegreeZero = True
        return Chem.RemoveHs(mol, params)

    def get_bond_features(self, mol: Chem.Mol) -> Tuple[torch.Tensor, torch.Tensor]:
        """Build a bidirectional edge list and per-edge features for ``mol``.

        Returns:
            ``(edge_index, edge_attr)``: a long tensor ``(2, 2 * num_valid_bonds)`` and a
            float tensor ``(2 * num_valid_bonds, NUM_BOND_FEATURES)``. Both directions of
            a bond share the same feature row. Bonds whose type or direction is not in
            ``bond_list`` / ``bonddir_list`` are reported and skipped entirely. When
            ``mol`` is None or no bond survives, a single self-loop on node 0 with an
            all-zero feature row is returned.
        """
        def placeholder():
            return (torch.tensor([[0], [0]], dtype=torch.long),
                    torch.tensor([[0.0] * self.NUM_BOND_FEATURES], dtype=torch.float))

        if mol is None:
            return placeholder()

        row, col, edge_feat = [], [], []
        for bond in mol.GetBonds():
            try:
                start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
                # Build the feature vector BEFORE recording the edge: the .index()
                # lookups raise ValueError for unlisted bond types/directions, and the
                # edge must not be added without its features (edge_index and edge_attr
                # would otherwise fall out of alignment).
                feat = [
                    self.bond_list.index(bond.GetBondType()),
                    self.bonddir_list.index(bond.GetBondDir()),
                    int(bond.GetIsConjugated()),
                    int(self._is_rotatable(bond)),
                    self._get_bond_length(mol, start, end)
                ]
            except Exception as e:
                print(f"Error processing bond: {e}")
                continue

            row += [start, end]
            col += [end, start]
            edge_feat.extend([feat, feat])

        if not row:  # no bond produced valid features
            return placeholder()

        edge_index = torch.tensor([row, col], dtype=torch.long)
        edge_attr = torch.tensor(edge_feat, dtype=torch.float)
        return edge_index, edge_attr

    def _is_rotatable(self, bond: Chem.Bond) -> bool:
        """True for a non-ring single bond whose end atoms each have more than one neighbour."""
        return (bond.GetBondType() == Chem.rdchem.BondType.SINGLE and
                not bond.IsInRing() and
                len(bond.GetBeginAtom().GetNeighbors()) > 1 and
                len(bond.GetEndAtom().GetNeighbors()) > 1)

    def _get_bond_length(self, mol: Chem.Mol, start: int, end: int) -> float:
        """Return the start-end distance from the molecule's 3D conformer, else 0.0.

        0.0 is returned when there is no conformer, when the conformer is not 3D, or
        when RDKit raises while measuring. The measurement is deliberately best-effort.
        """
        try:
            conf = mol.GetConformer()
            if conf.Is3D():
                return Chem.rdMolTransforms.GetBondLength(conf, start, end)
        except Exception:
            pass
        return 0.0

    @staticmethod
    def _optimize_geometry(mol) -> None:
        """Relax the embedded conformer in place with MMFF94, falling back to UFF.

        MMFFOptimizeMolecule reports a force field that cannot be set up by returning
        -1 rather than raising, so both that status and an exception trigger UFF.
        Errors from UFF propagate to the caller.
        """
        try:
            status = AllChem.MMFFOptimizeMolecule(mol)
        except Exception:
            status = -1
        if status == -1:
            AllChem.UFFOptimizeMolecule(mol)

    def process_molecule(self, smiles: str) -> Data:
        """Convert one SMILES string into a graph with explicit hydrogens and 3D bond lengths.

        Steps: parse, normalise hydrogens (RemoveHs then AddHs), sanitize, embed one
        conformer with ETKDG, optimise it, then extract atom and bond features.

        Returns:
            ``Data(x_cat, x_phys, edge_index, edge_attr, num_nodes)``, or None when the
            SMILES is invalid, the molecule has no atoms, embedding fails, or any other
            error occurs (the reason is printed).
        """
        try:
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                print(f"Invalid SMILES: {smiles}")
                return None  # Skip invalid molecules
            mol = RemoveHs(mol)

            # Explicit hydrogens become graph nodes; addCoords places them if coordinates exist.
            mol = Chem.AddHs(mol, addCoords=True)
            Chem.SanitizeMol(mol)

            if mol.GetNumAtoms() == 0:
                print("Molecule has no atoms, skipping.")
                return None

            # Generate 3D coordinates so bond lengths can be measured.
            if not mol.GetNumConformers():
                status = AllChem.EmbedMolecule(mol, AllChem.ETKDG())
                if status != 0:
                    print("Failed to generate 3D conformer")
                    return None  # Skip failed molecules
                self._optimize_geometry(mol)

            x_cat, x_phys = self.get_atom_features(mol)
            edge_index, edge_attr = self.get_bond_features(mol)

            return Data(
                x_cat=x_cat,
                x_phys=x_phys,
                edge_index=edge_index,
                edge_attr=edge_attr,
                num_nodes=x_cat.size(0)
            )

        except Exception as e:
            print(f"Error processing molecule {smiles}: {e}")
            return None


In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool, MessagePassing
from typing import Tuple, List, Optional
import copy
from dataclasses import dataclass

class MemoryQueue:
    """Ring buffer of negative keys for MoCo-style contrastive learning, with age-based decay.

    Holds ``size`` vectors of width ``dim`` (initialised as random unit vectors). It
    also records, per slot, how many ``update_queue`` calls have happened since that
    slot was last written (``queue_age``). In ``compute_contrastive_loss`` each
    negative logit is scaled by ``decay ** queue_age``, so stale keys count less.

    This is a plain object, not an ``nn.Module``, so ``module.to(device)`` does not
    reach it. Instead, the buffers move lazily to the device of the tensors passed to
    ``update_queue`` / ``compute_contrastive_loss``.
    """
    def __init__(self, size: int, dim: int, decay: float = 0.99999):
        self.size = size
        self.dim = dim
        self.decay = decay
        self.ptr = 0        # next slot to overwrite
        self.full = False   # True once every slot has been written at least once

        # Plain tensors (not nn.Parameter): the queue is never optimised.
        self.queue = F.normalize(torch.randn(size, dim), dim=1)
        self.queue_age = torch.zeros(size)

    def _sync_device(self, device: torch.device) -> None:
        """Move the queue buffers to ``device`` if they currently live elsewhere."""
        if self.queue.device != device:
            self.queue = self.queue.to(device)
            self.queue_age = self.queue_age.to(device)

    @torch.no_grad()
    def update_queue(self, keys: torch.Tensor):
        """Write ``keys`` (shape ``(batch, dim)``) into the ring buffer at ``ptr``, wrapping around.

        Every existing entry ages by one; written slots reset to age 0. If the batch
        is larger than the queue, only the last ``size`` keys are kept.
        """
        keys = keys.detach()
        self._sync_device(keys.device)

        self.queue_age += 1

        if keys.shape[0] > self.size:
            keys = keys[-self.size:]  # older keys would be overwritten in the same call anyway
        batch_size = keys.shape[0]

        # Slot indices modulo size handle the wrap-around in a single assignment.
        idx = (self.ptr + torch.arange(batch_size, device=keys.device)) % self.size
        self.queue[idx] = keys.to(self.queue.dtype)
        self.queue_age[idx] = 0

        if self.ptr + batch_size >= self.size:
            self.full = True
        self.ptr = (self.ptr + batch_size) % self.size

    def get_decay_weights(self) -> torch.Tensor:
        """Per-slot weights ``decay ** queue_age`` (shape ``(size,)``)."""
        return self.decay ** self.queue_age

    def compute_contrastive_loss(self, query: torch.Tensor, positive_key: torch.Tensor,
                                 temperature: float = 0.07) -> torch.Tensor:
        """InfoNCE loss with one positive per query and the queue as (decayed) negatives.

        Args:
            query, positive_key: ``(batch, dim)`` embeddings; L2-normalised here.
            temperature: logit divisor.

        Returns:
            Scalar cross-entropy loss, where class 0 is the positive pair.
        """
        self._sync_device(query.device)

        query = F.normalize(query, dim=1)
        positive_key = F.normalize(positive_key, dim=1)
        queue = F.normalize(self.queue, dim=1)

        l_pos = (query * positive_key).sum(dim=1, keepdim=True)   # (batch, 1)
        l_neg = query @ queue.t()                                   # (batch, size)

        # Down-weight negatives by how long they have sat in the queue.
        l_neg = l_neg * self.get_decay_weights().unsqueeze(0)

        logits = torch.cat([l_pos, l_neg], dim=1) / temperature
        labels = torch.zeros(logits.shape[0], dtype=torch.long, device=query.device)
        return F.cross_entropy(logits, labels)

class GraphGenerator(nn.Module):
    """Scores every node and edge of a (batched) molecular graph with a value in (0, 1).

    The categorical node features (cast to float) and standardised physical node
    features are concatenated, encoded by an MLP and propagated through three GCN
    layers. Sigmoid heads then give per-node scores, and per-edge scores computed
    from the concatenated embeddings of each edge's endpoints.
    """
    def __init__(self, node_dim: int, edge_dim: int, hidden_dim: int = 128):
        super().__init__()

        # Node feature processing
        self.node_encoder = nn.Sequential(
            nn.Linear(node_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

        # Retained so state_dict keys (and saved checkpoints) stay unchanged. forward()
        # does not use it because GCNConv consumes only edge_index.
        self.edge_encoder = nn.Sequential(
            nn.Linear(edge_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

        # Graph convolution layers
        self.conv1 = GCNConv(hidden_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.conv3 = GCNConv(hidden_dim, hidden_dim)

        # Importance heads: one score per node / per edge.
        self.node_importance = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, 1),
            nn.Sigmoid()
        )

        self.edge_importance = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )

    def normalize_features(self, x_cat, x_phys):
        """Cast categorical features to float and z-score physical features column-wise.

        The categorical indices are used as raw float values (no one-hot encoding).
        Physical features are standardised with statistics over all rows passed in (the
        whole batch), and left unchanged when there is only one row.
        """
        x_cat = x_cat.float()

        x_phys = x_phys.float()
        if x_phys.size(0) > 1:  # std needs at least two rows
            x_phys = (x_phys - x_phys.mean(0)) / (x_phys.std(0) + 1e-5)

        return x_cat, x_phys

    def forward(self, data) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(node_scores, edge_scores)`` of shapes ``(num_nodes, 1)`` and ``(num_edges, 1)``.

        ``data`` must provide ``x_cat``, ``x_phys`` and ``edge_index``; ``edge_attr``
        is not used.
        """
        x_cat, x_phys = self.normalize_features(data.x_cat, data.x_phys)
        x = torch.cat([x_cat, x_phys], dim=-1)
        edge_index = data.edge_index

        x = self.node_encoder(x)

        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = self.conv3(x, edge_index)

        node_scores = self.node_importance(x)

        # Edge scores from the embeddings of both endpoints.
        edge_features = torch.cat([
            x[edge_index[0]],
            x[edge_index[1]]
        ], dim=-1)
        edge_scores = self.edge_importance(edge_features)

        return node_scores, edge_scores

def get_model_config(dataset):
    """Build a GanClConfig whose input widths match the first graph in ``dataset``.

    ``node_dim`` is the number of ``x_cat`` columns plus ``x_phys`` columns, and
    ``edge_dim`` the number of ``edge_attr`` columns. All other hyper-parameters are
    fixed to the values listed below.
    """
    first_graph = dataset[0]
    n_cat_cols = first_graph.x_cat.shape[1]
    n_phys_cols = first_graph.x_phys.shape[1]

    fixed_hparams = dict(
        hidden_dim=128,
        output_dim=128,
        queue_size=65536,
        momentum=0.999,
        temperature=0.07,
        decay=0.99999,
        dropout_ratio=0.25,
    )
    return GanClConfig(
        node_dim=n_cat_cols + n_phys_cols,
        edge_dim=first_graph.edge_attr.shape[1],
        **fixed_hparams,
    )

class GraphDiscriminator(nn.Module):
    """Graph encoder: node MLP -> 3 GCN layers -> per-graph mean pool -> projection head.

    Produces one ``output_dim`` embedding per graph in the batch, and serves as both
    the contrastive encoder and the discriminator.
    """
    def __init__(self, node_dim: int, edge_dim: int, hidden_dim: int = 128, output_dim: int = 128):
        super().__init__()

        # Feature encoding
        self.node_encoder = nn.Sequential(
            nn.Linear(node_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

        # Retained so state_dict keys (and saved checkpoints) stay unchanged. forward()
        # does not use it because GCNConv consumes only edge_index.
        self.edge_encoder = nn.Sequential(
            nn.Linear(edge_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

        # Graph convolution layers
        self.conv1 = GCNConv(hidden_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.conv3 = GCNConv(hidden_dim, output_dim)

        # Projection head for contrastive learning
        self.projection = nn.Sequential(
            nn.Linear(output_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )

    def normalize_features(self, x_cat, x_phys):
        """Cast categorical features to float and z-score physical features column-wise.

        The categorical indices are used as raw float values (no one-hot encoding).
        Physical features are standardised with statistics over all rows passed in (the
        whole batch), and left unchanged when there is only one row.
        """
        x_cat = x_cat.float()

        x_phys = x_phys.float()
        if x_phys.size(0) > 1:  # std needs at least two rows
            x_phys = (x_phys - x_phys.mean(0)) / (x_phys.std(0) + 1e-5)

        return x_cat, x_phys

    def forward(self, data):
        """Return projected graph embeddings of shape ``(num_graphs, output_dim)``.

        ``data`` must provide ``x_cat``, ``x_phys``, ``edge_index`` and ``batch``
        (node-to-graph assignment for pooling). ``edge_attr`` is not used.
        """
        x_cat, x_phys = self.normalize_features(data.x_cat, data.x_phys)
        x = torch.cat([x_cat, x_phys], dim=-1)
        edge_index = data.edge_index
        batch = data.batch

        x = self.node_encoder(x)

        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = self.conv3(x, edge_index)

        # Average node embeddings within each graph.
        x = global_mean_pool(x, batch)

        return self.projection(x)

@dataclass
class GanClConfig:
    """Hyper-parameters shared by MolecularGANCL and its training loop.

    Out-of-range values are rejected in ``__post_init__`` with ``ValueError``. A bad
    setting (e.g. ``temperature=0``) therefore fails at construction instead of
    surfacing as NaN/inf losses mid-training.
    """
    node_dim: int                # input node width: x_cat columns + x_phys columns
    edge_dim: int                # input edge width: edge_attr columns
    hidden_dim: int = 128        # encoder hidden width (MolecularGANCL gives the generator 2x this)
    output_dim: int = 128        # embedding width, also the memory-queue vector width
    queue_size: int = 65536      # number of negative keys kept in the memory queue
    momentum: float = 0.999      # EMA coefficient for the momentum encoder, in [0, 1]
    temperature: float = 0.07    # contrastive-logit divisor, must be > 0
    decay: float = 0.99999       # per-update age decay of queued negatives, in (0, 1]
    dropout_ratio: float = 0.25  # probability of masking each node/edge, in [0, 1)

    def __post_init__(self):
        """Validate field ranges; raise ValueError naming the offending field."""
        for name in ('node_dim', 'edge_dim', 'hidden_dim', 'output_dim', 'queue_size'):
            value = getattr(self, name)
            if value <= 0:
                raise ValueError(f"{name} must be positive, got {value!r}")
        if not 0.0 <= self.momentum <= 1.0:
            raise ValueError(f"momentum must lie in [0, 1], got {self.momentum!r}")
        if self.temperature <= 0:
            raise ValueError(f"temperature must be > 0, got {self.temperature!r}")
        if not 0.0 < self.decay <= 1.0:
            raise ValueError(f"decay must lie in (0, 1], got {self.decay!r}")
        if not 0.0 <= self.dropout_ratio < 1.0:
            raise ValueError(f"dropout_ratio must lie in [0, 1), got {self.dropout_ratio!r}")

class MolecularGANCL(nn.Module):
    """Graph augmentation generator combined with MoCo-style contrastive learning.

    Components:
      * ``generator`` (GraphGenerator): per-node / per-edge importance scores.
      * ``encoder`` (GraphDiscriminator): graph embeddings trained by gradient descent.
      * ``momentum_encoder``: frozen deep copy of ``encoder``, updated by EMA in
        ``_momentum_update``.
      * ``memory_queue`` (MemoryQueue): negative keys for the contrastive loss. It is
        a plain attribute, not an ``nn.Module``, so it is not part of ``state_dict()``.
    """
    def __init__(self, config: GanClConfig):
        super().__init__()
        self.config = config

        def init_weights(m):
            # Xavier-uniform weights and a small constant bias for every nn.Linear.
            if isinstance(m, nn.Linear):
                torch.nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:  # nn.Linear(bias=False) has no bias tensor
                    m.bias.data.fill_(0.01)

        # The generator's hidden width is twice the encoder's.
        self.generator = GraphGenerator(
            config.node_dim,
            config.edge_dim,
            config.hidden_dim * 2
        )

        self.encoder = GraphDiscriminator(
            config.node_dim,
            config.edge_dim,
            config.hidden_dim,
            config.output_dim
        )
        # Only the encoder gets the custom init; the generator keeps PyTorch defaults.
        self.encoder.apply(init_weights)

        # Loss weights applied in forward().
        self.contrastive_weight = 1.0
        self.adversarial_weight = 0.1
        self.similarity_weight = 0.01

        # Linear temperature annealing used by forward() via get_temperature().
        # NOTE(review): the visible training loop passes config.temperature to
        # compute_contrastive_loss directly, so this schedule only applies when
        # forward() is used.
        self.initial_temperature = 0.1
        self.min_temperature = 0.05

        # Momentum encoder: starts as an exact copy of the (initialised) encoder.
        self.momentum_encoder = copy.deepcopy(self.encoder)
        for param in self.momentum_encoder.parameters():
            param.requires_grad = False

        self.memory_queue = MemoryQueue(
            config.queue_size,
            config.output_dim,
            config.decay
        )

    @torch.no_grad()
    def _momentum_update(self):
        """EMA step: ``k <- m * k + (1 - m) * q`` for each (encoder, momentum-encoder) parameter pair."""
        m = self.config.momentum
        for param_q, param_k in zip(self.encoder.parameters(),
                                    self.momentum_encoder.parameters()):
            # In-place update keeps the existing storage instead of rebinding .data.
            param_k.mul_(m).add_(param_q, alpha=1.0 - m)

    def drop_graph_elements(self, data, node_scores: torch.Tensor,
                            edge_scores: torch.Tensor) -> Data:
        """Return a copy of ``data`` with random nodes' and edges' features zeroed.

        Each node row of ``x_cat``/``x_phys`` and each edge row of ``edge_attr`` is
        zeroed independently with probability ``config.dropout_ratio``. The topology
        (``edge_index``) and ``batch`` are kept. Note that multiplying by the float mask
        turns ``x_cat`` into a float tensor.

        NOTE(review): the masks come from ``torch.rand_like`` and do not depend on the
        score values. The scores therefore do not influence which elements are dropped,
        and no gradient reaches the generator through this function. Also, the
        encoder's GCN layers read only ``edge_index``, so masking ``edge_attr`` does not
        change the resulting embedding. A score-based threshold was tried earlier and
        replaced with random masking.
        """
        # Shapes: node_scores (num_nodes, 1), edge_scores (num_edges, 1), so masks broadcast over feature columns.
        node_mask = (torch.rand_like(node_scores) > self.config.dropout_ratio).float()
        edge_mask = (torch.rand_like(edge_scores) > self.config.dropout_ratio).float()

        return Data(
            x_cat=data.x_cat * node_mask,
            x_phys=data.x_phys * node_mask,
            edge_index=data.edge_index,
            edge_attr=data.edge_attr * edge_mask,
            batch=data.batch
        )

    def get_temperature(self, epoch, total_epochs):
        """Linearly anneal from ``initial_temperature`` towards 0, clamped at ``min_temperature``."""
        progress = epoch / total_epochs
        return max(self.initial_temperature * (1 - progress), self.min_temperature)

    def forward(self, data, epoch=0, total_epochs=50):
        """Compute the weighted losses for one batch.

        Returns:
            ``(contrastive_loss, adversarial_loss, similarity_loss)``, each already
            multiplied by its weight. The adversarial term is the negated MSE between
            perturbed and clean embeddings, and the similarity term is the same MSE with
            a positive weight. This method does not update the momentum encoder or the
            memory queue; the caller must do that.
        """
        temperature = self.get_temperature(epoch, total_epochs)

        node_scores, edge_scores = self.generator(data)
        perturbed_data = self.drop_graph_elements(data, node_scores, edge_scores)

        query_emb = self.encoder(perturbed_data)
        with torch.no_grad():
            key_emb = self.momentum_encoder(data)
            original_emb = self.encoder(data)

        contrastive_loss = self.memory_queue.compute_contrastive_loss(
            query_emb, key_emb, temperature
        ) * self.contrastive_weight

        # Both auxiliary terms are scalings of the same distance; compute it once.
        embedding_mse = F.mse_loss(query_emb, original_emb)
        adversarial_loss = -embedding_mse * self.adversarial_weight
        similarity_loss = embedding_mse * self.similarity_weight

        return contrastive_loss, adversarial_loss, similarity_loss

    def get_embeddings(self, data) -> torch.Tensor:
        """Encoder embeddings for downstream tasks, computed without tracking gradients."""
        with torch.no_grad():
            return self.encoder(data)

In [None]:
import pickle
import numpy as np
from rdkit import Chem
from rdkit.Chem import Descriptors, Lipinski, MolSurf

def extract_molecular_properties(mol):
    """Compute scalar whole-molecule descriptors for an RDKit molecule.

    Returns an empty dict for ``None``. Otherwise returns a dict with keys MW, LogP,
    TPSA, HBA, HBD, RotBonds, HeavyAtoms, Rings (number of rings in the smallest set
    of smallest rings), NumRings, NumAromaticRings, NumAliphaticRings and FractionCSP3.
    Every value is a number, so the result can be tabulated or pickled directly.
    """
    if mol is None:
        return {}

    # Chem.GetSSSR returns an int in older RDKit releases and a sequence of rings in
    # newer ones. Normalise it to a count so 'Rings' is always a number.
    sssr = Chem.GetSSSR(mol)
    num_sssr_rings = sssr if isinstance(sssr, int) else len(sssr)

    properties = {
        'MW': Descriptors.MolWt(mol),
        'LogP': Descriptors.MolLogP(mol),
        'TPSA': Descriptors.TPSA(mol),
        'HBA': Lipinski.NumHAcceptors(mol),
        'HBD': Lipinski.NumHDonors(mol),
        'RotBonds': Descriptors.NumRotatableBonds(mol),
        'HeavyAtoms': mol.GetNumHeavyAtoms(),
        'Rings': num_sssr_rings,
        'NumRings': Chem.rdMolDescriptors.CalcNumRings(mol),
        'NumAromaticRings': Chem.rdMolDescriptors.CalcNumAromaticRings(mol),
        'NumAliphaticRings': Chem.rdMolDescriptors.CalcNumAliphaticRings(mol),
        'FractionCSP3': Chem.rdMolDescriptors.CalcFractionCSP3(mol),
    }

    return properties

def identify_structural_features(mol):
    """Flag coarse ring and chain topology features of an RDKit molecule.

    Returns an empty dict for ``None``. Otherwise returns booleans:

    * HasAromatic     - at least one aromatic ring.
    * HasHeterocycle  - some ring contains a non-carbon atom.
    * HasFusedRings   - two rings share at least two atoms (a bond); bridged
                        systems also satisfy this.
    * HasSpiroRings   - at least one spiro atom.
    * HasBridgedRings - at least one bridgehead atom.
    * HasMacrocycle   - a ring of 8 or more atoms.
    * HasLongChain    - the largest connected group of acyclic heavy atoms has 6
                        or more atoms (a component size, not a path length).
    * IsBranched      - some acyclic atom has more than two heavy-atom neighbours.

    Ring membership is read from ``mol.GetRingInfo()``, so ``mol`` is assumed to
    carry ring information (e.g. a sanitized molecule).
    """
    if mol is None:
        return {}

    from itertools import combinations

    atom_rings = [set(ring) for ring in mol.GetRingInfo().AtomRings()]

    def heavy_degree(atom):
        # Neighbour count ignoring hydrogens, so explicit Hs do not change the answer.
        return sum(1 for nbr in atom.GetNeighbors() if nbr.GetAtomicNum() > 1)

    # Flood-fill the subgraph of acyclic heavy atoms and keep the largest component size.
    unvisited = {atom.GetIdx() for atom in mol.GetAtoms()
                 if atom.GetAtomicNum() > 1 and not atom.IsInRing()}
    largest_chain = 0
    while unvisited:
        stack = [unvisited.pop()]
        component_size = 0
        while stack:
            idx = stack.pop()
            component_size += 1
            for nbr in mol.GetAtomWithIdx(idx).GetNeighbors():
                nbr_idx = nbr.GetIdx()
                if nbr_idx in unvisited:
                    unvisited.remove(nbr_idx)
                    stack.append(nbr_idx)
        largest_chain = max(largest_chain, component_size)

    features = {
        # Ring systems
        'HasAromatic': Chem.rdMolDescriptors.CalcNumAromaticRings(mol) > 0,
        'HasHeterocycle': any(mol.GetAtomWithIdx(idx).GetAtomicNum() != 6
                              for ring in atom_rings for idx in ring),

        # Advanced ring features
        'HasFusedRings': any(len(a & b) >= 2 for a, b in combinations(atom_rings, 2)),
        'HasSpiroRings': Chem.rdMolDescriptors.CalcNumSpiroAtoms(mol) > 0,
        'HasBridgedRings': Chem.rdMolDescriptors.CalcNumBridgeheadAtoms(mol) > 0,
        'HasMacrocycle': any(len(ring) >= 8 for ring in atom_rings),

        # Chain features
        'HasLongChain': largest_chain >= 6,
        'IsBranched': any(not atom.IsInRing() and heavy_degree(atom) > 2
                          for atom in mol.GetAtoms()),
    }

    return features

def identify_functional_groups(mol):
    """Flag which functional groups occur in an RDKit molecule via SMARTS matching.

    Returns an empty dict for ``None``. Otherwise returns ``{'Has<Group>': bool}`` for
    each pattern below.

    Raises:
        ValueError: if one of the SMARTS patterns fails to parse (a programming error
            in this table, reported instead of crashing inside the matcher).
    """
    if mol is None:
        return {}

    patterns = {
        'Alcohol': '[OX2H]',
        'Amine': '[NX3;H2,H1,H0;!$(NC=O)]',
        'Carboxyl': '[CX3](=O)[OX2H1]',
        'Carbonyl': '[CX3]=O',
        'Ether': '[OD2](C)C',
        'Ester': '[#6][CX3](=O)[OX2H0][#6]',
        'Amide': '[NX3][CX3](=[OX1])',
        'Halogen': '[F,Cl,Br,I]',
        'Nitrile': '[CX2]#[NX1]',
        # Both the pentavalent and the charge-separated ([N+](=O)[O-]) nitro forms.
        'Nitro': '[$([NX3](=O)=O),$([NX3+](=O)[O-])][!#8]',
        'Sulfide': '[#16X2]',
        'Sulfoxide': '[#16X3](=[OX1])',
        # The previous pattern '[#16X4](=[OX1])2' had an unclosed ring-closure digit
        # and did not parse. Spell out both S=O bonds and the two carbon neighbours.
        'Sulfone': '[#16X4](=[OX1])(=[OX1])([#6])[#6]',
    }

    functional_groups = {}
    for name, smarts in patterns.items():
        pattern = Chem.MolFromSmarts(smarts)
        if pattern is None:  # MolFromSmarts returns None on a parse error
            raise ValueError(f"Invalid SMARTS for {name}: {smarts}")
        functional_groups[f'Has{name}'] = mol.HasSubstructMatch(pattern)

    return functional_groups

def _count_fused_ring_systems(atom_rings):
    """Count ring systems made of two or more rings joined by shared bonds.

    Two rings are joined when they share at least two atoms. Joins are followed
    transitively (union-find), and each resulting group containing more than one
    ring counts once. Bridged systems also qualify; spiro junctions (one shared
    atom) do not.
    """
    rings = [set(ring) for ring in atom_rings]
    parent = list(range(len(rings)))

    def find(i):
        while parent[i] != i:
            parent[i] = parent[parent[i]]  # path halving
            i = parent[i]
        return i

    for i in range(len(rings)):
        for j in range(i + 1, len(rings)):
            if len(rings[i] & rings[j]) >= 2:
                parent[find(i)] = find(j)

    group_sizes = {}
    for i in range(len(rings)):
        root = find(i)
        group_sizes[root] = group_sizes.get(root, 0) + 1
    return sum(1 for size in group_sizes.values() if size > 1)


def analyze_rings(mol):
    """Summarise ring sizes and ring-system types of an RDKit molecule.

    Returns an empty dict for ``None``. Otherwise returns:

    * RingSize3 ... RingSize8 - number of rings with exactly that many atoms.
    * NumSpiroRings   - number of spiro ATOMS (name kept for compatibility).
    * NumBridgedRings - number of bridgehead ATOMS (name kept for compatibility).
    * NumFusedRingSystems - number of multi-ring systems whose rings share bonds.

    Rings are read from ``mol.GetRingInfo()``, so ``mol`` is assumed to carry ring
    information (e.g. a sanitized molecule).
    """
    if mol is None:
        return {}

    atom_rings = [set(ring) for ring in mol.GetRingInfo().AtomRings()]

    ring_info = {}
    for size in (3, 4, 5, 6, 7, 8):
        ring_info[f'RingSize{size}'] = sum(1 for ring in atom_rings if len(ring) == size)

    ring_info['NumSpiroRings'] = Chem.rdMolDescriptors.CalcNumSpiroAtoms(mol)
    ring_info['NumBridgedRings'] = Chem.rdMolDescriptors.CalcNumBridgeheadAtoms(mol)
    ring_info['NumFusedRingSystems'] = _count_fused_ring_systems(atom_rings)

    return ring_info

def extract_all_features(smiles):
    """Parse ``smiles`` and merge every per-molecule feature dictionary into one.

    Returns None when RDKit cannot parse the string. Otherwise the dict starts with
    'SMILES' (the input) and 'Mol' (the parsed molecule). It then holds the entries
    from extract_molecular_properties, identify_structural_features,
    identify_functional_groups and analyze_rings, merged in that order, so a later
    key overwrites an earlier one with the same name.
    """
    molecule = Chem.MolFromSmiles(smiles)
    if molecule is None:
        return None

    combined = {'SMILES': smiles, 'Mol': molecule}
    extractors = (
        extract_molecular_properties,
        identify_structural_features,
        identify_functional_groups,
        analyze_rings,
    )
    for extractor in extractors:
        combined.update(extractor(molecule))
    return combined

def save_enhanced_embeddings(embeddings, graphs, filepath, stage="after_training"):
    """Pickle one record per (embedding, graph) pair and return the records.

    Parameters:
    - embeddings: sequence of 1-D embeddings (e.g. rows of a 2-D numpy array), one per graph
    - graphs: iterable of graph objects, aligned one-to-one with ``embeddings``
    - filepath: destination file; overwritten with a pickled list of records
    - stage: label stored in every record, e.g. "before_training" or "after_training"

    Each record is ``{'embedding', 'graph', 'embedding_dim', 'stage', 'features'}``,
    where ``embedding_dim`` is ``embedding.shape[0]`` and ``features`` is an empty dict
    reserved for later analysis.

    Raises:
    - ValueError: if the number of embeddings and graphs differ (checked before
      anything is written).

    Note: the output is a pickle; only load it from trusted sources.
    """
    graphs = list(graphs)  # accept any iterable and allow len()
    if len(embeddings) != len(graphs):
        raise ValueError(
            f"Got {len(embeddings)} embeddings but {len(graphs)} graphs; "
            "they must align one-to-one"
        )

    # TODO: store each graph's SMILES once the graph objects carry it.
    molecules_data = [
        {
            'embedding': embedding,
            'graph': graph,  # original graph for reference
            'embedding_dim': embedding.shape[0],
            'stage': stage,
            'features': {},  # placeholder, filled later during analysis
        }
        for embedding, graph in zip(embeddings, graphs)
    ]

    with open(filepath, 'wb') as f:
        pickle.dump(molecules_data, f)

    return molecules_data

In [4]:
import torch
import numpy as np
from torch_geometric.data import DataLoader
import os
import json
from tqdm import tqdm
import pickle
from datetime import datetime

def save_embeddings(embeddings, labels, filepath):
    """Pickle ``embeddings`` and ``labels`` together as a single dict at ``filepath``.

    The file holds ``{'embeddings': ..., 'labels': ...}`` and is overwritten if it
    already exists.
    """
    payload = dict(embeddings=embeddings, labels=labels)
    with open(filepath, 'wb') as handle:
        pickle.dump(payload, handle)

def save_encoder(encoder, save_path, info=None):
    """Write ``encoder``'s weights plus optional metadata to ``save_path`` with torch.save.

    The checkpoint is a dict with 'encoder_state_dict' and 'model_info'. 'model_info'
    is ``info``, or an empty dict when ``info`` is falsy. This is the layout that
    load_encoder reads.
    """
    checkpoint = {}
    checkpoint['encoder_state_dict'] = encoder.state_dict()
    checkpoint['model_info'] = info if info else {}
    torch.save(checkpoint, save_path)

def load_encoder(model_path, device='cpu'):
    """Rebuild a GraphDiscriminator from a checkpoint written by save_encoder.

    Args:
        model_path: path to the checkpoint file.
        device: device that both the checkpoint tensors and the returned encoder
            are placed on.

    Returns:
        The encoder with loaded weights, moved to ``device``.

    Raises:
        KeyError: if 'model_info' lacks 'node_dim' or 'edge_dim' (the layer widths
            cannot be inferred without them).

    Note: torch.load deserialises via pickle; only load checkpoints from trusted sources.
    """
    checkpoint = torch.load(model_path, map_location=device)
    info = checkpoint.get('model_info') or {}

    missing = [key for key in ('node_dim', 'edge_dim') if info.get(key) is None]
    if missing:
        raise KeyError(f"Checkpoint {model_path!r} is missing model_info entries: {missing}")

    encoder = GraphDiscriminator(
        node_dim=info['node_dim'],
        edge_dim=info['edge_dim'],
        hidden_dim=info.get('hidden_dim', 128),
        output_dim=info.get('output_dim', 128)
    )
    encoder.load_state_dict(checkpoint['encoder_state_dict'])
    # map_location only places the checkpoint tensors. The freshly built module starts
    # on CPU, and load_state_dict copies values into it in place, so move it explicitly.
    return encoder.to(device)
        
def train_gan_cl(train_loader, config, device='cuda', 
                save_dir='./checkpoints', 
                embedding_dir='./embeddings'):
    """Main training function for GAN-CL with fixed gradient computation"""
    
    # Create directories
    os.makedirs(save_dir, exist_ok=True)
    os.makedirs(embedding_dir, exist_ok=True)
    encoder_dir = os.path.join(save_dir, 'encoders')
    os.makedirs(encoder_dir, exist_ok=True)    
    
    # Initialize model
    model = MolecularGANCL(config).to(device)
    
    # Initialize optimizers
    optimizer_encoder = torch.optim.Adam(model.encoder.parameters(), lr=3e-4)
    optimizer_generator = torch.optim.Adam(model.generator.parameters(), lr=1e-4)
    
    # Save initial model info
    model_info = {
        'node_dim': config.node_dim,
        'edge_dim': config.edge_dim,
        'hidden_dim': config.hidden_dim,
        'output_dim': config.output_dim,
        'training_config': config.__dict__
    }
    
    # Training phases as before...
    best_loss = float('inf')
    
    # Training metrics
    metrics = {
        'contrastive_losses': [],
        'adversarial_losses': [],
        'similarity_losses': [],
        'total_losses': []
    }
    
    # Training phases
    print("Phase 1: Pretraining Contrastive Learning...")
    pretrain_epochs = 10
    for epoch in range(pretrain_epochs):
        contrastive_epoch_loss = 0
        
        for batch in tqdm(train_loader, desc=f'Pretrain Epoch {epoch+1}/{pretrain_epochs}'):
            batch = batch.to(device)
            
            # Forward pass (without generator)
            query_emb = model.encoder(batch)
            with torch.no_grad():
                key_emb = model.momentum_encoder(batch)
            
            # Compute contrastive loss
            contrastive_loss = model.memory_queue.compute_contrastive_loss(
                query_emb, key_emb, model.config.temperature
            )
            
            # Update encoder
            optimizer_encoder.zero_grad()
            contrastive_loss.backward()
            optimizer_encoder.step()
            
            # Update momentum encoder
            model._momentum_update()
            
            # Update memory queue
            model.memory_queue.update_queue(key_emb.detach())
            
            contrastive_epoch_loss += contrastive_loss.item()
            
        avg_loss = contrastive_epoch_loss / len(train_loader)
        metrics['contrastive_losses'].append(avg_loss)
        print(f'Pretrain Epoch {epoch+1}, Avg Loss: {avg_loss:.4f}')
        
        # Save pretrained checkpoint
        if (epoch + 1) % 5 == 0:
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_encoder_state_dict': optimizer_encoder.state_dict(),
                'loss': avg_loss,
            }, os.path.join(save_dir, f'pretrain_checkpoint_{epoch+1}.pt'))
    
    print("\nPhase 2: Training GAN-CL...")
    train_epochs = 50
#     train_epochs = 10
    for epoch in range(train_epochs):
        epoch_losses = {
            'contrastive': 0,
            'adversarial': 0,
            'similarity': 0,
            'total': 0
        }
        
        for batch in tqdm(train_loader, desc=f'Train Epoch {epoch+1}/{train_epochs}'):
            batch = batch.to(device)
            
            # Step 1: Train Encoder
            optimizer_encoder.zero_grad()
            
            # Get importance scores from generator
            with torch.no_grad():
                node_scores, edge_scores = model.generator(batch)
            
            # Create perturbed graph
            perturbed_data = model.drop_graph_elements(batch, node_scores, edge_scores)
            
            # Get embeddings
            query_emb = model.encoder(perturbed_data)
            with torch.no_grad():
                key_emb = model.momentum_encoder(batch)
                original_emb = model.encoder(batch).detach()
            
            # Compute losses for encoder
            contrastive_loss = model.memory_queue.compute_contrastive_loss(
                query_emb, key_emb, model.config.temperature
            )
            similarity_loss = F.mse_loss(query_emb, original_emb)
            
            # Total loss for encoder
            encoder_loss = contrastive_loss + 0.1 * similarity_loss
            
            # Update encoder
            encoder_loss.backward()
            optimizer_encoder.step()
            
            # Update momentum encoder
            model._momentum_update()
            
            # Step 2: Train Generator
            optimizer_generator.zero_grad()
            
            # Get new embeddings for adversarial loss
            node_scores, edge_scores = model.generator(batch)
            perturbed_data = model.drop_graph_elements(batch, node_scores, edge_scores)
            
            with torch.no_grad():
                original_emb = model.encoder(batch)
            perturbed_emb = model.encoder(perturbed_data)
            
            # Compute adversarial loss
            adversarial_loss = -F.mse_loss(perturbed_emb, original_emb)
            
            # Update generator
            adversarial_loss.backward()
            optimizer_generator.step()
            
            # Update memory queue
            model.memory_queue.update_queue(key_emb.detach())
            
            # Update metrics
            epoch_losses['contrastive'] += contrastive_loss.item()
            epoch_losses['adversarial'] += adversarial_loss.item()
            epoch_losses['similarity'] += similarity_loss.item()
            epoch_losses['total'] += encoder_loss.item()
        
        # Average losses
        for k in epoch_losses:
            epoch_losses[k] /= len(train_loader)
            metrics[f'{k}_losses'].append(epoch_losses[k])
        
        print(f'Epoch {epoch+1}, Losses: {epoch_losses}')
        
        # Save checkpoint
        if (epoch + 1) % 10 == 0:
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_encoder_state_dict': optimizer_encoder.state_dict(),
                'optimizer_generator_state_dict': optimizer_generator.state_dict(),
                'losses': epoch_losses,
            }, os.path.join(save_dir, f'gan_cl_checkpoint_{epoch+1}.pt'))
            
        # Extract and save embeddings periodically
        if (epoch + 1) % 10 == 0:
            model.eval()
            all_embeddings = []
            all_graphs = []
            
            with torch.no_grad():
                for batch in train_loader:
                    batch = batch.to(device)
                    embeddings = model.get_embeddings(batch)
                    all_embeddings.append(embeddings.cpu())
                    all_graphs.extend([data for data in batch])
            
            all_embeddings = torch.cat(all_embeddings, dim=0)
            
            # Save embeddings
#             timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
            save_embeddings(
                all_embeddings.numpy(),
                all_graphs,
                os.path.join(embedding_dir, f'embeddings_epoch_{epoch+1}_{timestamp}.pkl')
            )
            
            model.train()
    
    # Save final metrics
    with open(os.path.join(save_dir, 'training_metrics.json'), 'w') as f:
        json.dump(metrics, f)
    
            # Save encoder periodically
        if (epoch + 1) % 10 == 0:
            epoch_info = {
                **model_info,
                'epoch': epoch + 1,
                'loss': epoch_losses['total']
            }
            save_encoder(
                model.encoder,
                os.path.join(encoder_dir, f'encoder_epoch_{epoch+1}.pt'),
                epoch_info
            )
        
        # Save best encoder based on total loss
        if epoch_losses['total'] < best_loss:
            best_loss = epoch_losses['total']
            save_encoder(
                model.encoder,
                os.path.join(encoder_dir, f'best_encoder_{timestamp}.pt'),
                {**model_info, 'epoch': epoch + 1, 'loss': best_loss}
            )
    
    # Save final encoder
    save_encoder(
        model.encoder,
        os.path.join(encoder_dir, f'final_encoder_{timestamp}.pt'),
        {**model_info, 'epoch': train_epochs, 'loss': epoch_losses['total']}
    )
    
    return model, metrics


In [5]:
def _load_smiles_dataset(extractor, smiles_file):
    """Featurize every non-blank SMILES line of ``smiles_file``.

    Returns:
        ``(dataset, failed_smiles)``: the objects returned by
        ``extractor.process_molecule`` (those that were not None) and the
        SMILES strings for which it returned None.
    """
    dataset = []
    failed_smiles = []
    with open(smiles_file, 'r') as f:
        for line in f:
            smiles = line.strip()
            if not smiles:
                # Skip blank lines (e.g. a trailing newline) instead of
                # featurizing an empty SMILES string.
                continue
            data = extractor.process_molecule(smiles)
            if data is not None:
                dataset.append(data)
            else:
                failed_smiles.append(smiles)
    return dataset, failed_smiles


def _extract_embeddings(model, dataset, batch_size, device):
    """Run ``model.get_embeddings`` over ``dataset`` in its original order.

    Returns:
        ``(embeddings, graphs)``: a NumPy array built by concatenating the
        per-batch outputs along dim 0, and the per-graph ``Data`` objects
        (on CPU) in the same order as the batches were processed.
        NOTE(review): assumes get_embeddings returns one row per graph.
    """
    # Unshuffled loader: the order of rows/graphs is deterministic and
    # follows `dataset`, unlike the shuffled training loader.
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    model.eval()
    embedding_chunks = []
    graphs = []
    with torch.no_grad():
        for batch in tqdm(loader, desc="Extracting embeddings"):
            # Split into per-graph Data objects while still on the CPU.
            # Iterating a PyG Batch directly yields (attribute, value) pairs,
            # not individual graphs.
            graphs.extend(batch.to_data_list())
            batch = batch.to(device)
            embedding_chunks.append(model.get_embeddings(batch).cpu())
    embeddings = torch.cat(embedding_chunks, dim=0).numpy()
    return embeddings, graphs


def main(smiles_file=r"D:\PhD\Chapter3\Unsupervised_GAN_Code\pubchem-10m-clean_test50k.txt",
         batch_size=32, seed=42, detect_anomaly=True):
    """Load SMILES, featurize them, train GAN-CL and export final embeddings.

    Args:
        smiles_file: text file with one SMILES string per line; blank lines
            are skipped.
        batch_size: graphs per batch for training and embedding extraction.
        seed: seed for torch, numpy and Python's ``random``.
        detect_anomaly: turn on autograd anomaly detection (useful when
            debugging NaN/inf gradients, but it slows every backward pass).

    Returns:
        ``(model, metrics, embeddings, graphs)``, or ``None`` when no SMILES
        could be converted into a graph.
    """
    # Anomaly detection is a debugging aid; pass detect_anomaly=False for
    # full-length training runs.
    torch.autograd.set_detect_anomaly(detect_anomaly)

    # Seed every RNG the pipeline may touch (the notebook also imports `random`).
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    print("Starting data loading...")
    extractor = MolecularFeatureExtractor()
    dataset, failed_smiles = _load_smiles_dataset(extractor, smiles_file)

    print(f"1. Loaded dataset with {len(dataset)} graphs.")
    print(f"2. Failed SMILES count: {len(failed_smiles)}")

    if not dataset:
        print("No valid graphs generated.")
        return None

    # Setup training
    train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    print(f"3. Created DataLoader with {len(train_loader.dataset)} graphs")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"4. Using device: {device}")

    # Configuration is derived from the featurized dataset by get_model_config.
    config = get_model_config(dataset)

    save_dir = './checkpoints'
    embedding_dir = './embeddings'

    print("5. Starting GAN-CL training...")
    model, metrics = train_gan_cl(
        train_loader,
        config,
        device=device,
        save_dir=save_dir,
        embedding_dir=embedding_dir
    )
    print("6. Training completed!")

    print("7. Extracting final embeddings for XAI...")
    all_embeddings, all_graphs = _extract_embeddings(model, dataset, batch_size, device)

    final_embedding_path = f'{embedding_dir}/final_embeddings_{timestamp}.pkl'
    save_embeddings(all_embeddings, all_graphs, final_embedding_path)
    print(f"8. Final embeddings saved to {final_embedding_path}")

    # Names match the timestamped files that train_gan_cl writes.
    print("9. Encoders saved in ./checkpoints/encoders/:")
    print(f"   - Best encoder: best_encoder_{timestamp}.pt")
    print(f"   - Final encoder: final_encoder_{timestamp}.pt")
    print("   - Periodic encoders: encoder_epoch_*.pt")

    return model, metrics, all_embeddings, all_graphs

if __name__ == "__main__":
    results = main()
    # main() returns None when no molecule could be featurized; unpacking
    # that directly into four names would raise TypeError.
    if results is not None:
        model, metrics, embeddings, graphs = results

Starting data loading...
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer


[23:40:14] UFFTYPER: Unrecognized atom type: Se2+2 (17)


Failed to generate 3D conformer


[23:40:18] UFFTYPER: Unrecognized charge state for atom: 1


Failed to generate 3D conformer


[23:40:26] UFFTYPER: Unrecognized charge state for atom: 4


Failed to generate 3D conformer


[23:40:56] UFFTYPER: Unrecognized atom type: S_5+4 (11)


Failed to generate 3D conformer


[23:41:17] UFFTYPER: Unrecognized atom type: Se2+2 (14)
[23:41:17] UFFTYPER: Unrecognized charge state for atom: 20
[23:41:17] UFFTYPER: Unrecognized charge state for atom: 40


Failed to generate 3D conformer


[23:41:18] UFFTYPER: Unrecognized charge state for atom: 5
[23:41:23] UFFTYPER: Unrecognized charge state for atom: 9
[23:41:27] UFFTYPER: Unrecognized charge state for atom: 2


Failed to generate 3D conformer


[23:41:40] UFFTYPER: Unrecognized charge state for atom: 8
[23:42:17] UFFTYPER: Unrecognized charge state for atom: 6


Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer


[23:43:52] UFFTYPER: Unrecognized charge state for atom: 2


Failed to generate 3D conformer


[23:44:10] UFFTYPER: Unrecognized charge state for atom: 3
[23:44:10] UFFTYPER: Unrecognized charge state for atom: 7


Failed to generate 3D conformer
Failed to generate 3D conformer


[23:45:04] UFFTYPER: Unrecognized atom type: S_6+6 (17)
[23:45:12] UFFTYPER: Unrecognized atom type: Se2+2 (7)
[23:45:12] UFFTYPER: Unrecognized atom type: Se2+2 (7)


Failed to generate 3D conformer


[23:45:28] UFFTYPER: Unrecognized charge state for atom: 5
[23:45:42] UFFTYPER: Unrecognized charge state for atom: 13


Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer


[23:46:20] UFFTYPER: Unrecognized charge state for atom: 8
[23:46:33] UFFTYPER: Unrecognized atom type: Se2+2 (16)


Failed to generate 3D conformer


[23:46:52] UFFTYPER: Unrecognized charge state for atom: 4


Failed to generate 3D conformer
Failed to generate 3D conformer


[23:47:23] UFFTYPER: Unrecognized charge state for atom: 1


Failed to generate 3D conformer


[23:48:35] UFFTYPER: Unrecognized charge state for atom: 17
[23:48:35] UFFTYPER: Unrecognized charge state for atom: 19
[23:48:38] UFFTYPER: Unrecognized charge state for atom: 8


Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer


[23:50:52] UFFTYPER: Unrecognized charge state for atom: 8
[23:50:54] UFFTYPER: Unrecognized charge state for atom: 1
[23:51:08] UFFTYPER: Unrecognized atom type: Se2+2 (9)


Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer


[23:52:15] UFFTYPER: Unrecognized atom type: S_5+4 (1)
[23:52:25] UFFTYPER: Unrecognized atom type: S_5+4 (10)


Failed to generate 3D conformer


[23:52:45] UFFTYPER: Unrecognized hybridization for atom: 1
[23:52:45] UFFTYPER: Unrecognized atom type: Pt+2 (1)


Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer


[23:56:19] UFFTYPER: Unrecognized charge state for atom: 3
[23:56:20] UFFTYPER: Unrecognized charge state for atom: 15
[23:56:25] UFFTYPER: Unrecognized atom type: S_5+4 (15)


Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer


[00:00:37] UFFTYPER: Unrecognized atom type: S_5+4 (20)


Failed to generate 3D conformer
Failed to generate 3D conformer


[00:00:50] UFFTYPER: Unrecognized atom type: S_5+4 (8)
[00:00:50] UFFTYPER: Unrecognized atom type: S_5+4 (18)
[00:00:52] UFFTYPER: Unrecognized charge state for atom: 3
[00:00:52] UFFTYPER: Unrecognized charge state for atom: 5
[00:00:52] UFFTYPER: Unrecognized charge state for atom: 6
[00:00:52] UFFTYPER: Unrecognized charge state for atom: 7
[00:00:52] UFFTYPER: Unrecognized atom type: Se2+2 (7)


Failed to generate 3D conformer
Failed to generate 3D conformer


[00:01:04] UFFTYPER: Unrecognized charge state for atom: 30


Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer


[00:01:52] UFFTYPER: Unrecognized atom type: S_6+6 (1)
[00:01:57] UFFTYPER: Unrecognized charge state for atom: 3


Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer


[00:02:28] UFFTYPER: Unrecognized charge state for atom: 1


Failed to generate 3D conformer


[00:02:55] UFFTYPER: Unrecognized charge state for atom: 8


Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer


[00:03:39] UFFTYPER: Unrecognized charge state for atom: 11
[00:03:43] UFFTYPER: Unrecognized charge state for atom: 7


Failed to generate 3D conformer
Failed to generate 3D conformer


[00:04:28] UFFTYPER: Unrecognized atom type: S_5+4 (17)
[00:04:29] UFFTYPER: Unrecognized charge state for atom: 17


Failed to generate 3D conformer


[00:04:40] UFFTYPER: Unrecognized charge state for atom: 19
[00:04:45] UFFTYPER: Unrecognized charge state for atom: 6


Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer


[00:05:49] UFFTYPER: Unrecognized charge state for atom: 10


Failed to generate 3D conformer


[00:05:54] UFFTYPER: Unrecognized charge state for atom: 15
[00:06:04] UFFTYPER: Unrecognized charge state for atom: 7


Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer


[00:07:05] UFFTYPER: Unrecognized atom type: S_5+4 (36)


Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer


[00:07:55] UFFTYPER: Unrecognized charge state for atom: 8


Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer


[00:08:18] UFFTYPER: Unrecognized charge state for atom: 25


Failed to generate 3D conformer


[00:09:13] UFFTYPER: Unrecognized charge state for atom: 6
[00:09:13] UFFTYPER: Unrecognized charge state for atom: 8
[00:09:18] UFFTYPER: Unrecognized charge state for atom: 3
[00:09:20] UFFTYPER: Unrecognized charge state for atom: 16


Failed to generate 3D conformer
Failed to generate 3D conformer


[00:09:28] UFFTYPER: Unrecognized atom type: S_5+4 (8)


Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer


[00:10:10] UFFTYPER: Unrecognized charge state for atom: 2


Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer


[00:11:16] UFFTYPER: Unrecognized atom type: Ge2 (0)


Failed to generate 3D conformer


[00:11:18] UFFTYPER: Unrecognized atom type: Se2+2 (33)
[00:11:20] UFFTYPER: Unrecognized charge state for atom: 40
[00:11:20] UFFTYPER: Unrecognized charge state for atom: 47
[00:11:46] UFFTYPER: Unrecognized charge state for atom: 12
[00:11:46] UFFTYPER: Unrecognized charge state for atom: 44


Failed to generate 3D conformer
Failed to generate 3D conformer


[00:11:53] UFFTYPER: Unrecognized atom type: S_5+4 (29)


Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer


[00:13:13] UFFTYPER: Unrecognized atom type: S_6+6 (1)
[00:13:20] UFFTYPER: Unrecognized atom type: S_5+4 (1)
[00:13:35] UFFTYPER: Unrecognized charge state for atom: 5
[00:13:44] UFFTYPER: Unrecognized charge state for atom: 2


Failed to generate 3D conformer
Failed to generate 3D conformer


[00:14:14] UFFTYPER: Unrecognized charge state for atom: 1
[00:14:14] UFFTYPER: Unrecognized atom type: Se2+2 (16)


Failed to generate 3D conformer


[00:14:30] UFFTYPER: Unrecognized atom type: S_5+4 (20)
[00:14:33] UFFTYPER: Unrecognized charge state for atom: 4
[00:14:33] UFFTYPER: Unrecognized atom type: Se2+2 (4)
[00:14:33] UFFTYPER: Unrecognized charge state for atom: 5
[00:14:33] UFFTYPER: Unrecognized atom type: Se5+2 (5)
[00:14:33] UFFTYPER: Unrecognized charge state for atom: 12
[00:14:33] UFFTYPER: Unrecognized atom type: Se2+2 (12)


Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer


[00:15:16] UFFTYPER: Unrecognized charge state for atom: 6


Failed to generate 3D conformer


[00:15:55] UFFTYPER: Unrecognized atom type: S_5+4 (12)
[00:15:58] UFFTYPER: Unrecognized charge state for atom: 10
[00:16:00] UFFTYPER: Unrecognized charge state for atom: 9


Failed to generate 3D conformer
Failed to generate 3D conformer


[00:16:43] UFFTYPER: Unrecognized charge state for atom: 9
[00:16:43] UFFTYPER: Unrecognized charge state for atom: 10


Failed to generate 3D conformer
Failed to generate 3D conformer


[00:17:14] UFFTYPER: Unrecognized charge state for atom: 2


Failed to generate 3D conformer
Failed to generate 3D conformer


[00:17:18] UFFTYPER: Unrecognized atom type: S_5+4 (23)
[00:17:21] UFFTYPER: Unrecognized atom type: Se2+2 (7)


Failed to generate 3D conformer
Failed to generate 3D conformer


[00:17:40] UFFTYPER: Unrecognized charge state for atom: 5


Failed to generate 3D conformer
Failed to generate 3D conformer


[00:18:02] UFFTYPER: Unrecognized atom type: S_5+4 (35)


Failed to generate 3D conformer


[00:18:14] UFFTYPER: Unrecognized atom type: Se2+2 (5)


Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer


[00:19:21] UFFTYPER: Unrecognized charge state for atom: 3


Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer


[00:21:02] UFFTYPER: Unrecognized charge state for atom: 24


Failed to generate 3D conformer


[00:21:47] UFFTYPER: Unrecognized charge state for atom: 1
[00:21:47] UFFTYPER: Unrecognized charge state for atom: 2
[00:21:47] UFFTYPER: Unrecognized charge state for atom: 4
[00:21:47] UFFTYPER: Unrecognized charge state for atom: 6
[00:21:47] UFFTYPER: Unrecognized charge state for atom: 8


Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer


[00:22:43] UFFTYPER: Unrecognized charge state for atom: 25
[00:24:14] UFFTYPER: Unrecognized atom type: S_5+4 (11)


Failed to generate 3D conformer
Failed to generate 3D conformer


[00:25:18] UFFTYPER: Unrecognized charge state for atom: 5
[00:25:18] UFFTYPER: Unrecognized atom type: Ga2+3 (5)


Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer


[00:25:37] UFFTYPER: Unrecognized charge state for atom: 3


Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer


[00:25:55] UFFTYPER: Unrecognized charge state for atom: 7
[00:25:55] UFFTYPER: Unrecognized atom type: Se2+2 (7)


Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer


[00:27:10] UFFTYPER: Unrecognized charge state for atom: 9
[00:27:27] UFFTYPER: Unrecognized charge state for atom: 9
[00:27:27] UFFTYPER: Unrecognized charge state for atom: 22


Failed to generate 3D conformer


[00:27:50] UFFTYPER: Unrecognized charge state for atom: 16
[00:27:50] UFFTYPER: Unrecognized atom type: Se2+2 (16)


Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer


[00:28:05] UFFTYPER: Unrecognized charge state for atom: 8
[00:28:07] UFFTYPER: Unrecognized charge state for atom: 7
[00:28:07] UFFTYPER: Unrecognized charge state for atom: 19


Failed to generate 3D conformer


[00:28:21] UFFTYPER: Unrecognized atom type: S_5+4 (23)


Failed to generate 3D conformer


[00:28:52] UFFTYPER: Unrecognized charge state for atom: 22


Failed to generate 3D conformer


[00:29:00] UFFTYPER: Unrecognized atom type: S_5+4 (2)


Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer


[00:29:43] UFFTYPER: Unrecognized charge state for atom: 5


Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer


[00:30:19] UFFTYPER: Unrecognized charge state for atom: 5
[00:30:21] UFFTYPER: Unrecognized charge state for atom: 1


Failed to generate 3D conformer


[00:30:26] UFFTYPER: Unrecognized charge state for atom: 2


Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer


[00:31:16] UFFTYPER: Unrecognized charge state for atom: 8


Failed to generate 3D conformer


[00:31:34] UFFTYPER: Unrecognized atom type: S_5+4 (2)
[00:31:42] UFFTYPER: Unrecognized atom type: S_6+6 (30)


Failed to generate 3D conformer


[00:32:46] UFFTYPER: Unrecognized charge state for atom: 3


Failed to generate 3D conformer


[00:33:13] UFFTYPER: Unrecognized charge state for atom: 4
[00:33:13] UFFTYPER: Unrecognized charge state for atom: 16


Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer


[00:34:05] UFFTYPER: Unrecognized charge state for atom: 6


Failed to generate 3D conformer


[00:34:15] UFFTYPER: Unrecognized charge state for atom: 8
[00:34:15] UFFTYPER: Unrecognized charge state for atom: 10
[00:34:33] UFFTYPER: Unrecognized atom type: S_5+4 (15)
[00:34:33] UFFTYPER: Unrecognized charge state for atom: 6


Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer


[00:36:13] UFFTYPER: Unrecognized charge state for atom: 5


Failed to generate 3D conformer


[00:36:20] UFFTYPER: Unrecognized charge state for atom: 1


Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer


[00:37:30] UFFTYPER: Unrecognized charge state for atom: 9
[00:37:30] UFFTYPER: Unrecognized charge state for atom: 19
[00:37:31] UFFTYPER: Unrecognized charge state for atom: 11


Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer


[00:38:11] UFFTYPER: Unrecognized charge state for atom: 0


Failed to generate 3D conformer
Failed to generate 3D conformer


[00:38:17] UFFTYPER: Unrecognized charge state for atom: 9


Failed to generate 3D conformer
Failed to generate 3D conformer


[00:38:37] UFFTYPER: Unrecognized charge state for atom: 1


Failed to generate 3D conformer
Failed to generate 3D conformer


[00:39:07] UFFTYPER: Unrecognized atom type: S_5+4 (21)
[00:39:08] UFFTYPER: Unrecognized charge state for atom: 0


Failed to generate 3D conformer
Failed to generate 3D conformer


[00:39:34] UFFTYPER: Unrecognized atom type: Se2+2 (25)
[00:39:34] UFFTYPER: Unrecognized atom type: Se2+2 (25)
[00:39:55] UFFTYPER: Unrecognized charge state for atom: 14


Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer


[00:43:49] UFFTYPER: Unrecognized atom type: B_1 (0)
[00:43:51] UFFTYPER: Unrecognized charge state for atom: 20


Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer


[00:47:09] UFFTYPER: Unrecognized charge state for atom: 22


Failed to generate 3D conformer


[00:47:24] UFFTYPER: Unrecognized charge state for atom: 19


Failed to generate 3D conformer


[00:47:26] UFFTYPER: Unrecognized charge state for atom: 13


Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer


[00:47:46] UFFTYPER: Unrecognized charge state for atom: 11
[00:47:46] UFFTYPER: Unrecognized charge state for atom: 21
[00:47:50] UFFTYPER: Unrecognized charge state for atom: 14


Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer


[00:48:38] UFFTYPER: Unrecognized charge state for atom: 19


Failed to generate 3D conformer


[00:49:03] UFFTYPER: Unrecognized charge state for atom: 12


Failed to generate 3D conformer
Failed to generate 3D conformer


[00:49:33] UFFTYPER: Unrecognized atom type: S_5+4 (1)


Failed to generate 3D conformer


[00:50:01] UFFTYPER: Unrecognized charge state for atom: 13
[00:50:01] UFFTYPER: Unrecognized charge state for atom: 14


Failed to generate 3D conformer


[00:50:12] UFFTYPER: Unrecognized charge state for atom: 15


Failed to generate 3D conformer


[00:50:26] UFFTYPER: Unrecognized charge state for atom: 3


Failed to generate 3D conformer
Failed to generate 3D conformer


[00:50:54] UFFTYPER: Unrecognized charge state for atom: 14
[00:50:54] UFFTYPER: Unrecognized charge state for atom: 24
[00:51:00] UFFTYPER: Unrecognized charge state for atom: 22


Failed to generate 3D conformer


[00:51:09] UFFTYPER: Unrecognized atom type: Se2+2 (18)
[00:51:20] UFFTYPER: Unrecognized charge state for atom: 3


Failed to generate 3D conformer


[00:51:35] UFFTYPER: Unrecognized charge state for atom: 2


Failed to generate 3D conformer


[00:51:45] UFFTYPER: Unrecognized atom type: Pt5+2 (8)
[00:51:51] UFFTYPER: Unrecognized charge state for atom: 14
[00:51:51] UFFTYPER: Unrecognized charge state for atom: 15
[00:51:54] UFFTYPER: Unrecognized charge state for atom: 13
[00:52:01] UFFTYPER: Unrecognized charge state for atom: 31


Failed to generate 3D conformer


[00:52:35] UFFTYPER: Unrecognized atom type: B_1 (7)


Failed to generate 3D conformer
Failed to generate 3D conformer


[00:53:07] UFFTYPER: Unrecognized charge state for atom: 6


Failed to generate 3D conformer


[00:53:48] UFFTYPER: Unrecognized charge state for atom: 5
[00:53:48] UFFTYPER: Unrecognized charge state for atom: 6
[00:53:48] UFFTYPER: Unrecognized charge state for atom: 21
[00:53:48] UFFTYPER: Unrecognized charge state for atom: 22


Failed to generate 3D conformer


[00:54:01] UFFTYPER: Unrecognized charge state for atom: 5
[00:54:01] UFFTYPER: Unrecognized charge state for atom: 10
[00:54:01] UFFTYPER: Unrecognized charge state for atom: 13
[00:54:01] UFFTYPER: Unrecognized charge state for atom: 21
[00:54:08] UFFTYPER: Unrecognized charge state for atom: 3


Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer


[00:54:26] UFFTYPER: Unrecognized charge state for atom: 1


Failed to generate 3D conformer
Failed to generate 3D conformer


[00:54:45] UFFTYPER: Unrecognized charge state for atom: 10
[00:54:45] UFFTYPER: Unrecognized atom type: Pb3+3 (10)
[00:54:57] UFFTYPER: Unrecognized atom type: S_5+4 (16)
[00:54:59] UFFTYPER: Unrecognized charge state for atom: 7
[00:54:59] UFFTYPER: Unrecognized charge state for atom: 10


Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer


[00:55:27] UFFTYPER: Unrecognized charge state for atom: 21
[00:55:38] UFFTYPER: Unrecognized charge state for atom: 19
[00:55:39] UFFTYPER: Unrecognized charge state for atom: 8
[00:55:50] UFFTYPER: Unrecognized charge state for atom: 2
[00:55:50] UFFTYPER: Unrecognized atom type: Pb3+3 (2)
[00:55:55] UFFTYPER: Unrecognized charge state for atom: 5


Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer


[00:56:44] UFFTYPER: Unrecognized charge state for atom: 29
[00:56:46] UFFTYPER: Unrecognized charge state for atom: 5


Failed to generate 3D conformer


[00:57:10] UFFTYPER: Unrecognized atom type: B_1 (1)


Failed to generate 3D conformer


[00:57:11] Can't kekulize mol.  Unkekulized atoms: 7 8 10 12
[00:57:11] Can't kekulize mol.  Unkekulized atoms: 7 8 10 12


Error processing molecule CCN(CC)P(=NS1=NSN=S=N1)(c1ccccc1)N(C1CCCCC1)C1CCCCC1: Can't kekulize mol.  Unkekulized atoms: 7 8 10 12
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer


[00:58:06] UFFTYPER: Unrecognized charge state for atom: 21


Failed to generate 3D conformer
Failed to generate 3D conformer


[00:58:29] UFFTYPER: Unrecognized atom type: S_5+4 (24)
[00:58:34] UFFTYPER: Unrecognized charge state for atom: 1


Failed to generate 3D conformer


[00:59:12] UFFTYPER: Unrecognized charge state for atom: 9
[00:59:22] UFFTYPER: Unrecognized charge state for atom: 4


Failed to generate 3D conformer


[00:59:44] UFFTYPER: Unrecognized charge state for atom: 5


Failed to generate 3D conformer
Failed to generate 3D conformer


[00:59:56] UFFTYPER: Unrecognized charge state for atom: 15
[01:00:04] UFFTYPER: Unrecognized charge state for atom: 29


Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer


[01:00:18] UFFTYPER: Unrecognized charge state for atom: 3
[01:00:31] UFFTYPER: Unrecognized charge state for atom: 5


Failed to generate 3D conformer


[01:00:36] UFFTYPER: Unrecognized charge state for atom: 90


Failed to generate 3D conformer
Failed to generate 3D conformer


[01:01:12] UFFTYPER: Unrecognized charge state for atom: 9


Failed to generate 3D conformer


[01:01:50] UFFTYPER: Unrecognized atom type: S_5+4 (6)


Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer


[01:02:13] UFFTYPER: Unrecognized charge state for atom: 10
[01:02:13] UFFTYPER: Unrecognized charge state for atom: 13
[01:02:13] UFFTYPER: Unrecognized charge state for atom: 32


Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer


[01:02:48] UFFTYPER: Unrecognized charge state for atom: 4


Failed to generate 3D conformer
Failed to generate 3D conformer


[01:02:58] UFFTYPER: Unrecognized charge state for atom: 5


Failed to generate 3D conformer


[01:03:17] UFFTYPER: Unrecognized charge state for atom: 8


Failed to generate 3D conformer


[01:03:24] UFFTYPER: Unrecognized charge state for atom: 35


Failed to generate 3D conformer


[01:03:37] UFFTYPER: Unrecognized charge state for atom: 6
[01:03:37] UFFTYPER: Unrecognized charge state for atom: 29


Failed to generate 3D conformer


[01:04:00] UFFTYPER: Unrecognized charge state for atom: 5
[01:04:00] UFFTYPER: Unrecognized charge state for atom: 8


Failed to generate 3D conformer
Failed to generate 3D conformer


[01:04:43] UFFTYPER: Unrecognized atom type: Zr1 (14)


Failed to generate 3D conformer


[01:05:00] UFFTYPER: Unrecognized atom type: S_5+4 (1)


Failed to generate 3D conformer


[01:05:10] UFFTYPER: Unrecognized charge state for atom: 20
[01:05:14] UFFTYPER: Unrecognized charge state for atom: 4


Failed to generate 3D conformer
1. Loaded dataset with 49693 graphs.
2. Failed SMILES count: 308
3. Created DataLoader with 49693 graphs
4. Using device: cpu
5. Starting GAN-CL training...




Phase 1: Pretraining Contrastive Learning...


Pretrain Epoch 1/10: 100%|█████████████████████████████████████████████████████████| 1553/1553 [03:46<00:00,  6.85it/s]


Pretrain Epoch 1, Avg Loss: 6.3369


Pretrain Epoch 2/10: 100%|█████████████████████████████████████████████████████████| 1553/1553 [03:45<00:00,  6.88it/s]


Pretrain Epoch 2, Avg Loss: 3.1049


Pretrain Epoch 3/10: 100%|█████████████████████████████████████████████████████████| 1553/1553 [03:45<00:00,  6.88it/s]


Pretrain Epoch 3, Avg Loss: 1.8073


Pretrain Epoch 4/10: 100%|█████████████████████████████████████████████████████████| 1553/1553 [03:45<00:00,  6.89it/s]


Pretrain Epoch 4, Avg Loss: 1.2444


Pretrain Epoch 5/10: 100%|█████████████████████████████████████████████████████████| 1553/1553 [03:45<00:00,  6.89it/s]


Pretrain Epoch 5, Avg Loss: 0.9977


Pretrain Epoch 6/10: 100%|█████████████████████████████████████████████████████████| 1553/1553 [03:46<00:00,  6.87it/s]


Pretrain Epoch 6, Avg Loss: 0.8098


Pretrain Epoch 7/10: 100%|█████████████████████████████████████████████████████████| 1553/1553 [03:45<00:00,  6.88it/s]


Pretrain Epoch 7, Avg Loss: 0.6895


Pretrain Epoch 8/10: 100%|█████████████████████████████████████████████████████████| 1553/1553 [03:49<00:00,  6.77it/s]


Pretrain Epoch 8, Avg Loss: 0.5692


Pretrain Epoch 9/10: 100%|█████████████████████████████████████████████████████████| 1553/1553 [03:46<00:00,  6.86it/s]


Pretrain Epoch 9, Avg Loss: 0.5115


Pretrain Epoch 10/10: 100%|████████████████████████████████████████████████████████| 1553/1553 [03:48<00:00,  6.80it/s]


Pretrain Epoch 10, Avg Loss: 0.4494

Phase 2: Training GAN-CL...


Train Epoch 1/50: 100%|████████████████████████████████████████████████████████████| 1553/1553 [08:56<00:00,  2.89it/s]


Epoch 1, Losses: {'contrastive': 3.640679861604208, 'adversarial': -0.0008367752162981011, 'similarity': 0.0008184608752206292, 'total': 3.640761707340297}


Train Epoch 2/50: 100%|████████████████████████████████████████████████████████████| 1553/1553 [08:57<00:00,  2.89it/s]


Epoch 2, Losses: {'contrastive': 3.738550600434455, 'adversarial': -0.0008148053263580032, 'similarity': 0.0008145790753351777, 'total': 3.7386320632884216}


Train Epoch 3/50: 100%|████████████████████████████████████████████████████████████| 1553/1553 [09:10<00:00,  2.82it/s]


Epoch 3, Losses: {'contrastive': 3.7529484018232466, 'adversarial': -0.0006950421262217934, 'similarity': 0.0007159144669752904, 'total': 3.7530199935646573}


Train Epoch 4/50: 100%|████████████████████████████████████████████████████████████| 1553/1553 [08:59<00:00,  2.88it/s]


Epoch 4, Losses: {'contrastive': 3.6035216914709354, 'adversarial': -0.0006949069019920071, 'similarity': 0.0006813468442748842, 'total': 3.6035898279852048}


Train Epoch 5/50: 100%|████████████████████████████████████████████████████████████| 1553/1553 [09:02<00:00,  2.86it/s]


Epoch 5, Losses: {'contrastive': 3.545553888878512, 'adversarial': -0.0006607495961689522, 'similarity': 0.0006608095673008547, 'total': 3.54561996820736}


Train Epoch 6/50: 100%|████████████████████████████████████████████████████████████| 1553/1553 [09:11<00:00,  2.82it/s]


Epoch 6, Losses: {'contrastive': 3.37273174965681, 'adversarial': -0.0006863637086341645, 'similarity': 0.0006939695749579673, 'total': 3.3728011457365863}


Train Epoch 7/50: 100%|████████████████████████████████████████████████████████████| 1553/1553 [09:11<00:00,  2.82it/s]


Epoch 7, Losses: {'contrastive': 3.3163442039059885, 'adversarial': -0.0007273135721587252, 'similarity': 0.0007442864174106029, 'total': 3.316418635177367}


Train Epoch 8/50: 100%|████████████████████████████████████████████████████████████| 1553/1553 [09:09<00:00,  2.82it/s]


Epoch 8, Losses: {'contrastive': 3.2504490022112766, 'adversarial': -0.0008125077857062009, 'similarity': 0.0008044872803332287, 'total': 3.2505294524387476}


Train Epoch 9/50: 100%|████████████████████████████████████████████████████████████| 1553/1553 [09:12<00:00,  2.81it/s]


Epoch 9, Losses: {'contrastive': 3.242375027342758, 'adversarial': -0.0009067275500336106, 'similarity': 0.0008983867477674484, 'total': 3.242464866165338}


Train Epoch 10/50: 100%|███████████████████████████████████████████████████████████| 1553/1553 [09:15<00:00,  2.80it/s]


Epoch 10, Losses: {'contrastive': 3.2373495491105206, 'adversarial': -0.0009785970131067844, 'similarity': 0.0010016392688557785, 'total': 3.2374497131616624}


Train Epoch 11/50: 100%|███████████████████████████████████████████████████████████| 1553/1553 [09:15<00:00,  2.79it/s]


Epoch 11, Losses: {'contrastive': 3.2442655125818787, 'adversarial': -0.0010553830072669947, 'similarity': 0.0010603354642685222, 'total': 3.2443715462128884}


Train Epoch 12/50: 100%|███████████████████████████████████████████████████████████| 1553/1553 [09:18<00:00,  2.78it/s]


Epoch 12, Losses: {'contrastive': 3.2874330521704533, 'adversarial': -0.001178071790219495, 'similarity': 0.0011982103699970576, 'total': 3.287552874317034}


Train Epoch 13/50: 100%|███████████████████████████████████████████████████████████| 1553/1553 [09:22<00:00,  2.76it/s]


Epoch 13, Losses: {'contrastive': 3.2203921870270316, 'adversarial': -0.001339205050805292, 'similarity': 0.0013362189899294662, 'total': 3.22052580904976}


Train Epoch 14/50: 100%|███████████████████████████████████████████████████████████| 1553/1553 [09:11<00:00,  2.82it/s]


Epoch 14, Losses: {'contrastive': 3.2215679520111427, 'adversarial': -0.0015179311886762015, 'similarity': 0.0015088506705992412, 'total': 3.2217188342801233}


Train Epoch 15/50: 100%|███████████████████████████████████████████████████████████| 1553/1553 [08:28<00:00,  3.06it/s]


Epoch 15, Losses: {'contrastive': 3.311419454760039, 'adversarial': -0.0015777612077051274, 'similarity': 0.0015577306083858695, 'total': 3.3115752262218643}


Train Epoch 16/50: 100%|███████████████████████████████████████████████████████████| 1553/1553 [08:28<00:00,  3.05it/s]


Epoch 16, Losses: {'contrastive': 3.184621890974444, 'adversarial': -0.0015423765550463485, 'similarity': 0.001538973058473753, 'total': 3.1847757911651424}


Train Epoch 17/50: 100%|███████████████████████████████████████████████████████████| 1553/1553 [08:27<00:00,  3.06it/s]


Epoch 17, Losses: {'contrastive': 3.0912682390642874, 'adversarial': -0.0016904489163618776, 'similarity': 0.0016844800274779556, 'total': 3.091436684707327}


Train Epoch 18/50: 100%|███████████████████████████████████████████████████████████| 1553/1553 [08:31<00:00,  3.03it/s]


Epoch 18, Losses: {'contrastive': 3.023737264232181, 'adversarial': -0.0016941353018291856, 'similarity': 0.0016960145909212054, 'total': 3.0239068657370898}


Train Epoch 19/50: 100%|███████████████████████████████████████████████████████████| 1553/1553 [08:32<00:00,  3.03it/s]


Epoch 19, Losses: {'contrastive': 2.961476913745067, 'adversarial': -0.0018558541265363084, 'similarity': 0.0018727858951141374, 'total': 2.9616641936882497}


Train Epoch 20/50: 100%|███████████████████████████████████████████████████████████| 1553/1553 [08:36<00:00,  3.00it/s]


Epoch 20, Losses: {'contrastive': 2.979823577503811, 'adversarial': -0.0019339073150426494, 'similarity': 0.0019299125864667636, 'total': 2.980016572047109}


Train Epoch 21/50: 100%|███████████████████████████████████████████████████████████| 1553/1553 [08:36<00:00,  3.00it/s]


Epoch 21, Losses: {'contrastive': 2.953134519625539, 'adversarial': -0.00195318744583829, 'similarity': 0.001992800674045606, 'total': 2.9533337967210023}


Train Epoch 22/50: 100%|███████████████████████████████████████████████████████████| 1553/1553 [08:41<00:00,  2.98it/s]


Epoch 22, Losses: {'contrastive': 2.853040630174159, 'adversarial': -0.002128224081677041, 'similarity': 0.0021252262070823006, 'total': 2.8532531530873975}


Train Epoch 23/50: 100%|███████████████████████████████████████████████████████████| 1553/1553 [08:43<00:00,  2.96it/s]


Epoch 23, Losses: {'contrastive': 2.8182093405523996, 'adversarial': -0.002252003161883712, 'similarity': 0.002274825120151657, 'total': 2.818436825344506}


Train Epoch 24/50: 100%|███████████████████████████████████████████████████████████| 1553/1553 [08:41<00:00,  2.98it/s]


Epoch 24, Losses: {'contrastive': 2.724774585335161, 'adversarial': -0.0023859420230405866, 'similarity': 0.0023751100185684048, 'total': 2.725012097831549}


Train Epoch 25/50: 100%|███████████████████████████████████████████████████████████| 1553/1553 [08:45<00:00,  2.95it/s]


Epoch 25, Losses: {'contrastive': 2.789448172986239, 'adversarial': -0.002622662416160951, 'similarity': 0.002589060964526972, 'total': 2.7897070792898853}


Train Epoch 26/50: 100%|███████████████████████████████████████████████████████████| 1553/1553 [08:48<00:00,  2.94it/s]


Epoch 26, Losses: {'contrastive': 2.7743315864054527, 'adversarial': -0.0027093371519271983, 'similarity': 0.0027139629776327507, 'total': 2.7746029823577105}


Train Epoch 27/50: 100%|███████████████████████████████████████████████████████████| 1553/1553 [08:51<00:00,  2.92it/s]


Epoch 27, Losses: {'contrastive': 2.723916147803769, 'adversarial': -0.002837898086379261, 'similarity': 0.002822837118882175, 'total': 2.7241984324845205}


Train Epoch 28/50: 100%|███████████████████████████████████████████████████████████| 1553/1553 [08:53<00:00,  2.91it/s]


Epoch 28, Losses: {'contrastive': 2.763966690780114, 'adversarial': -0.0029202186357390665, 'similarity': 0.0029172276881183406, 'total': 2.7642584118010993}


Train Epoch 29/50: 100%|███████████████████████████████████████████████████████████| 1553/1553 [08:54<00:00,  2.90it/s]


Epoch 29, Losses: {'contrastive': 2.7498059774780765, 'adversarial': -0.0030204057651502963, 'similarity': 0.0030015206101026053, 'total': 2.750106128967737}


Train Epoch 30/50: 100%|███████████████████████████████████████████████████████████| 1553/1553 [08:54<00:00,  2.91it/s]


Epoch 30, Losses: {'contrastive': 2.753765388680364, 'adversarial': -0.003218449057449862, 'similarity': 0.0032047854235853954, 'total': 2.7540858680021354}


Train Epoch 31/50: 100%|███████████████████████████████████████████████████████████| 1553/1553 [08:54<00:00,  2.91it/s]


Epoch 31, Losses: {'contrastive': 2.7876128889404415, 'adversarial': -0.0033902865422995136, 'similarity': 0.003466512262086574, 'total': 2.7879595382706395}


Train Epoch 32/50: 100%|███████████████████████████████████████████████████████████| 1553/1553 [08:55<00:00,  2.90it/s]


Epoch 32, Losses: {'contrastive': 2.7714747966525635, 'adversarial': -0.0038786886109390606, 'similarity': 0.003885616339066684, 'total': 2.771863358372346}


Train Epoch 33/50: 100%|███████████████████████████████████████████████████████████| 1553/1553 [08:59<00:00,  2.88it/s]


Epoch 33, Losses: {'contrastive': 2.8386682032464168, 'adversarial': -0.003979010908508821, 'similarity': 0.004035081704646704, 'total': 2.8390717134887145}


Train Epoch 34/50: 100%|███████████████████████████████████████████████████████████| 1553/1553 [08:57<00:00,  2.89it/s]


Epoch 34, Losses: {'contrastive': 2.8185987543921582, 'adversarial': -0.0039425612842881215, 'similarity': 0.0038984868249988466, 'total': 2.818988603925367}


Train Epoch 35/50: 100%|███████████████████████████████████████████████████████████| 1553/1553 [09:00<00:00,  2.88it/s]


Epoch 35, Losses: {'contrastive': 2.7742703262637374, 'adversarial': -0.004166036423119426, 'similarity': 0.0040658947734938215, 'total': 2.7746769167574006}


Train Epoch 36/50: 100%|███████████████████████████████████████████████████████████| 1553/1553 [09:08<00:00,  2.83it/s]


Epoch 36, Losses: {'contrastive': 2.754845130696883, 'adversarial': -0.0045138838609796875, 'similarity': 0.004530115545105432, 'total': 2.7552981411957846}


Train Epoch 37/50: 100%|███████████████████████████████████████████████████████████| 1553/1553 [09:08<00:00,  2.83it/s]


Epoch 37, Losses: {'contrastive': 2.773259278043807, 'adversarial': -0.00485453854572251, 'similarity': 0.004831294026897461, 'total': 2.773742407349257}


Train Epoch 38/50: 100%|███████████████████████████████████████████████████████████| 1553/1553 [09:08<00:00,  2.83it/s]


Epoch 38, Losses: {'contrastive': 2.7711171882811163, 'adversarial': -0.0052400891510475875, 'similarity': 0.005342843167541377, 'total': 2.7716514725725037}


Train Epoch 39/50: 100%|███████████████████████████████████████████████████████████| 1553/1553 [09:11<00:00,  2.82it/s]


Epoch 39, Losses: {'contrastive': 2.7520868608587725, 'adversarial': -0.005840647435391469, 'similarity': 0.005764251746535925, 'total': 2.7526632822949275}


Train Epoch 40/50: 100%|███████████████████████████████████████████████████████████| 1553/1553 [09:12<00:00,  2.81it/s]


Epoch 40, Losses: {'contrastive': 2.780091933565453, 'adversarial': -0.006097130433051185, 'similarity': 0.0061794547478057306, 'total': 2.78070987621892}


Train Epoch 41/50: 100%|███████████████████████████████████████████████████████████| 1553/1553 [18:44<00:00,  1.38it/s]


Epoch 41, Losses: {'contrastive': 2.7754576353587725, 'adversarial': -0.006372250932438511, 'similarity': 0.006432862532266919, 'total': 2.7761009231507203}


Train Epoch 42/50: 100%|███████████████████████████████████████████████████████████| 1553/1553 [28:15<00:00,  1.09s/it]


Epoch 42, Losses: {'contrastive': 2.7824320410883816, 'adversarial': -0.007187329295960286, 'similarity': 0.007107279118161267, 'total': 2.783142772706493}


Train Epoch 43/50: 100%|███████████████████████████████████████████████████████████| 1553/1553 [37:16<00:00,  1.44s/it]


Epoch 43, Losses: {'contrastive': 2.855771768269659, 'adversarial': -0.007698649682970544, 'similarity': 0.007571251177470663, 'total': 2.8565288943317118}


Train Epoch 44/50: 100%|███████████████████████████████████████████████████████████| 1553/1553 [11:51<00:00,  2.18it/s]


Epoch 44, Losses: {'contrastive': 2.8261750318584946, 'adversarial': -0.007666976607290108, 'similarity': 0.007675767805743951, 'total': 2.826942607731644}


Train Epoch 45/50: 100%|███████████████████████████████████████████████████████████| 1553/1553 [12:08<00:00,  2.13it/s]


Epoch 45, Losses: {'contrastive': 2.835218757640448, 'adversarial': -0.008011934788824471, 'similarity': 0.008225697316847027, 'total': 2.8360413307538743}


Train Epoch 46/50: 100%|███████████████████████████████████████████████████████████| 1553/1553 [12:17<00:00,  2.11it/s]


Epoch 46, Losses: {'contrastive': 2.7790552093686403, 'adversarial': -0.00809892841224383, 'similarity': 0.008062502257441964, 'total': 2.779861456029122}


Train Epoch 47/50: 100%|███████████████████████████████████████████████████████████| 1553/1553 [12:10<00:00,  2.13it/s]


Epoch 47, Losses: {'contrastive': 2.7469651615397055, 'adversarial': -0.008854261347610222, 'similarity': 0.008628800554238187, 'total': 2.7478280415937197}


Train Epoch 48/50: 100%|███████████████████████████████████████████████████████████| 1553/1553 [12:22<00:00,  2.09it/s]


Epoch 48, Losses: {'contrastive': 2.744010975187391, 'adversarial': -0.008798275725509437, 'similarity': 0.008961790064991442, 'total': 2.7449071543952073}


Train Epoch 49/50: 100%|███████████████████████████████████████████████████████████| 1553/1553 [12:19<00:00,  2.10it/s]


Epoch 49, Losses: {'contrastive': 2.727302418627128, 'adversarial': -0.009186991676028532, 'similarity': 0.009185440081225064, 'total': 2.7282209624341434}


Train Epoch 50/50: 100%|███████████████████████████████████████████████████████████| 1553/1553 [11:51<00:00,  2.18it/s]


Epoch 50, Losses: {'contrastive': 2.7366841472661347, 'adversarial': -0.009587227682302978, 'similarity': 0.009615160049588754, 'total': 2.7376456644484404}
6. Training completed!
7. Extracting final embeddings for XAI...


Extracting embeddings: 100%|███████████████████████████████████████████████████████| 1553/1553 [00:17<00:00, 87.81it/s]


8. Final embeddings saved to ./embeddings/final_embeddings_20250226_233946.pkl
9. Encoders saved in ./checkpoints/encoders/:
   - Best encoder: best_encoder.pt
   - Final encoder: final_encoder.pt
   - Periodic encoders: encoder_epoch_*.pt
