In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import random
import os
import json
import pickle
import copy
import pandas as pd
from tqdm import tqdm
from datetime import datetime
from typing import Tuple, List, Dict, Any, Optional
from collections import defaultdict
from dataclasses import dataclass

from rdkit import Chem, RDLogger
from rdkit.Chem import (
    AllChem, Descriptors, MolSurf, Fragments, Lipinski, RemoveHs
)
from rdkit.Chem.rdMolDescriptors import (
    CalcNumRings, CalcNumAromaticRings, CalcNumHeterocycles, 
    CalcNumAliphaticRings, CalcTPSA
)

from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool, MessagePassing

# Suppress RDKit warnings
RDLogger.DisableLog('rdApp.warning')

# Working timestamp
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')


import numpy as np
import pandas as pd
import torch
from rdkit import Chem
from rdkit.Chem import Descriptors, Fragments, Lipinski
from rdkit.Chem.rdMolDescriptors import CalcNumRings, CalcNumAromaticRings, CalcNumHeterocycles
from rdkit.Chem.rdMolDescriptors import CalcNumAliphaticRings, CalcTPSA
from typing import List, Dict, Any, Tuple
import os
import json
import pickle
from tqdm import tqdm
from datetime import datetime
from collections import defaultdict





In [2]:
class MolecularFeatureExtractor:
    def __init__(self):
        self.atom_list = list(range(1, 119))
        self.chirality_list = [
            Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
            Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
            Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
            Chem.rdchem.ChiralType.CHI_OTHER
        ]
        self.bond_list = [
            Chem.rdchem.BondType.SINGLE,
            Chem.rdchem.BondType.DOUBLE, 
            Chem.rdchem.BondType.TRIPLE,
            Chem.rdchem.BondType.AROMATIC
        ]
        self.bonddir_list = [
            Chem.rdchem.BondDir.NONE,
            Chem.rdchem.BondDir.ENDUPRIGHT,
            Chem.rdchem.BondDir.ENDDOWNRIGHT
        ]

    def calc_atom_features(self, atom: Chem.Atom) -> Tuple[list, list]:
        """Calculate atom features with better error handling"""
        try:
            # Basic features
            atom_feat = [
                self.atom_list.index(atom.GetAtomicNum()),
                self.chirality_list.index(atom.GetChiralTag())
            ]

            # Physical features with error handling
            phys_feat = []
            
            # Molecular weight contribution
            try:
                contrib_mw = Descriptors.ExactMolWt(Chem.MolFromSmiles(f'[{atom.GetSymbol()}]'))
                phys_feat.append(contrib_mw)
            except:
                phys_feat.append(0.0)
                
            # LogP contribution    
            try:
                contrib_logp = Descriptors.MolLogP(Chem.MolFromSmiles(f'[{atom.GetSymbol()}]'))
                phys_feat.append(contrib_logp)
            except:
                phys_feat.append(0.0)
                
            # Add other physical properties
            phys_feat.extend([
                atom.GetFormalCharge(),
                int(atom.GetHybridization()),
                int(atom.GetIsAromatic()),
                atom.GetTotalNumHs(),
                atom.GetTotalValence(),
                atom.GetDegree()
            ])
            
            return atom_feat, phys_feat
            
        except Exception as e:
            print(f"Error calculating atom features: {e}")
            return [0, 0], [0.0] * 9

    def get_atom_features(self, mol: Chem.Mol) -> Tuple[torch.Tensor, torch.Tensor]:
        """Extract atom features for the whole molecule"""
        atom_feats = []
        phys_feats = []
        
        if mol is None:
            return torch.tensor([[0, 0]], dtype=torch.long), torch.tensor([[0.0] * 9], dtype=torch.float)
            
        for atom in mol.GetAtoms():
            atom_feat, phys_feat = self.calc_atom_features(atom)
            atom_feats.append(atom_feat)
            phys_feats.append(phys_feat)

        x = torch.tensor(atom_feats, dtype=torch.long)
        phys = torch.tensor(phys_feats, dtype=torch.float)
        
        return x, phys
    
    def remove_unbonded_hydrogens(mol):
        params = Chem.RemoveHsParameters()
        params.removeDegreeZero = True
        mol = Chem.RemoveHs(mol, params)
        return mol


    def get_bond_features(self, mol: Chem.Mol) -> Tuple[torch.Tensor, torch.Tensor]:
        """Extract bond features with better error handling"""
        if mol is None:
            return torch.tensor([[0], [0]], dtype=torch.long), torch.tensor([[0.0] * 5], dtype=torch.float)
            
        row, col, edge_feat = [], [], []
        
        for bond in mol.GetBonds():
            try:
                start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
                
                # Add edges in both directions
                row += [start, end]
                col += [end, start]
                
                # Bond features
                bond_type = self.bond_list.index(bond.GetBondType())
                bond_dir = self.bonddir_list.index(bond.GetBondDir())
                
                # Calculate additional properties
                feat = [
                    bond_type,
                    bond_dir,
                    int(bond.GetIsConjugated()),
                    int(self._is_rotatable(bond)),
                    self._get_bond_length(mol, start, end)
                ]
                
                edge_feat.extend([feat, feat])
                
            except Exception as e:
                print(f"Error processing bond: {e}")
                continue

        if not row:  # If no valid bonds were processed
            return torch.tensor([[0], [0]], dtype=torch.long), torch.tensor([[0.0] * 5], dtype=torch.float)

        edge_index = torch.tensor([row, col], dtype=torch.long)
        edge_attr = torch.tensor(edge_feat, dtype=torch.float)
        
        return edge_index, edge_attr

    def _is_rotatable(self, bond: Chem.Bond) -> bool:
        """Check if bond is rotatable"""
        return (bond.GetBondType() == Chem.rdchem.BondType.SINGLE and 
                not bond.IsInRing() and
                len(bond.GetBeginAtom().GetNeighbors()) > 1 and
                len(bond.GetEndAtom().GetNeighbors()) > 1)

    def _get_bond_length(self, mol: Chem.Mol, start: int, end: int) -> float:
        """Get bond length with error handling"""
        try:
            conf = mol.GetConformer()
            if conf.Is3D():
                return Chem.rdMolTransforms.GetBondLength(conf, start, end)
        except:
            pass
        return 0.0

    def process_molecule(self, smiles: str) -> Data:
        """Process SMILES string to graph data"""
        try:
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                print(f"Invalid SMILES: {smiles}")
                return None  # Skip invalid molecules
            mol = RemoveHs(mol)

            # Add explicit hydrogens
            mol = Chem.AddHs(mol, addCoords=True)

            # Sanitize molecule
            Chem.SanitizeMol(mol)

            # Check if the molecule has atoms
            if mol.GetNumAtoms() == 0:
                print("Molecule has no atoms, skipping.")
                return None

            # Generate 3D coordinates
            if not mol.GetNumConformers():
                status = AllChem.EmbedMolecule(mol, AllChem.ETKDG())
                if status != 0:
                    print("Failed to generate 3D conformer")
                    return None  # Skip failed molecules

                # Try MMFF or UFF optimization
                try:
                    AllChem.MMFFOptimizeMolecule(mol)
                except:
                    AllChem.UFFOptimizeMolecule(mol)

            # Extract features
            x_cat, x_phys = self.get_atom_features(mol)
            edge_index, edge_attr = self.get_bond_features(mol)

            return Data(
                x_cat=x_cat, 
                x_phys=x_phys,
                edge_index=edge_index, 
                edge_attr=edge_attr,
                num_nodes=x_cat.size(0)
            )

        except Exception as e:
            print(f"Error processing molecule {smiles}: {e}")
            return None


In [3]:
class MemoryQueue:
    """Memory queue with temporal decay for contrastive learning"""
    def __init__(self, size: int, dim: int, decay: float = 0.99999):
        self.size = size
        self.dim = dim
        self.decay = decay
        self.ptr = 0
        self.full = False
        
        # Initialize queue
        self.queue = nn.Parameter(F.normalize(torch.randn(size, dim), dim=1), requires_grad=False)
        self.queue_age = nn.Parameter(torch.zeros(size), requires_grad=False)
        
#         self.register_buffer("queue", torch.randn(size, dim))
#         self.register_buffer("queue_age", torch.zeros(size))  # Track age of each entry
        self.queue = F.normalize(self.queue, dim=1)
        
    def update_queue(self, keys: torch.Tensor):
        """Update queue with new keys"""
        batch_size = keys.shape[0]
        
        # Increment age of all entries
        self.queue_age += 1
        
        # Add new keys
        if self.ptr + batch_size <= self.size:
            self.queue[self.ptr:self.ptr + batch_size] = keys
            self.queue_age[self.ptr:self.ptr + batch_size] = 0
        else:
            # Handle overflow
            rem = self.size - self.ptr
            self.queue[self.ptr:] = keys[:rem]
            self.queue[:batch_size-rem] = keys[rem:]
            self.queue_age[self.ptr:] = 0
            self.queue_age[:batch_size-rem] = 0
            self.full = True
            
        self.ptr = (self.ptr + batch_size) % self.size
        
    def get_decay_weights(self) -> torch.Tensor:
        """Get temporal decay weights for queue entries"""
        return self.decay ** self.queue_age
        
    def compute_contrastive_loss(self, query: torch.Tensor, positive_key: torch.Tensor, 
                                temperature: float = 0.07) -> torch.Tensor:
        """Compute contrastive loss with temporal decay"""
        # Normalize embeddings
        query = F.normalize(query, dim=1)
        positive_key = F.normalize(positive_key, dim=1)
        queue = F.normalize(self.queue, dim=1)
        
        # Compute logits
        l_pos = torch.einsum('nc,nc->n', [query, positive_key]).unsqueeze(-1)
        l_neg = torch.einsum('nc,ck->nk', [query, queue.T])
        
        # Apply temporal decay to negative samples
        decay_weights = self.get_decay_weights()
        l_neg = l_neg * decay_weights.unsqueeze(0)
        
        # Temperature scaling
        logits = torch.cat([l_pos, l_neg], dim=1) / temperature
        labels = torch.zeros(logits.shape[0], dtype=torch.long, device=query.device)
        
        return F.cross_entropy(logits, labels)

class GraphGenerator(nn.Module):
    """Generator network with proper feature handling"""
    def __init__(self, node_dim: int, edge_dim: int, hidden_dim: int = 128):
        super().__init__()
        
        # Node feature processing
        self.node_encoder = nn.Sequential(
            nn.Linear(node_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        
        # Edge feature processing
        self.edge_encoder = nn.Sequential(
            nn.Linear(edge_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        
        # Graph convolution layers
        self.conv1 = GCNConv(hidden_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.conv3 = GCNConv(hidden_dim, hidden_dim)
        
        # Importance prediction layers
        self.node_importance = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, 1),
            nn.Sigmoid()
        )
        
        self.edge_importance = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )
        
    def normalize_features(self, x_cat, x_phys):
        """Normalize categorical and physical features separately"""
        # Convert categorical features to one-hot
        x_cat = x_cat.float()
        
        # Normalize physical features
        x_phys = x_phys.float()
        if x_phys.size(0) > 1:  # Only normalize if we have more than one sample
            x_phys = (x_phys - x_phys.mean(0)) / (x_phys.std(0) + 1e-5)
            
        return x_cat, x_phys
        
    def forward(self, data) -> Tuple[torch.Tensor, torch.Tensor]:
        # Normalize features
        x_cat, x_phys = self.normalize_features(data.x_cat, data.x_phys)
        
        # Concatenate features
        x = torch.cat([x_cat, x_phys], dim=-1)
        
        edge_index = data.edge_index
        edge_attr = data.edge_attr.float()  # Ensure float type
        
        # Initial feature encoding
        x = self.node_encoder(x)
        edge_attr = self.edge_encoder(edge_attr)
        
        # Graph convolutions
        x = F.relu(self.conv1(x, edge_index))  # Removed edge_attr from GCNConv
        x = F.relu(self.conv2(x, edge_index))
        x = self.conv3(x, edge_index)
        
        # Predict importance scores
        node_scores = self.node_importance(x)
        
        # Edge scores using both connected nodes
        edge_features = torch.cat([
            x[edge_index[0]], 
            x[edge_index[1]]
        ], dim=-1)
        edge_scores = self.edge_importance(edge_features)
        
        return node_scores, edge_scores

def get_model_config(dataset):
    """Get model configuration based on dataset features"""
    sample_data = dataset[0]
    
    # Calculate input dimensions
    node_dim = sample_data.x_cat.shape[1] + sample_data.x_phys.shape[1]
    edge_dim = sample_data.edge_attr.shape[1]
    
    config = GanClConfig(
        node_dim=node_dim,
        edge_dim=edge_dim,
        hidden_dim=128,
        output_dim=128,
        queue_size=65536,
        momentum=0.999,
        temperature=0.07,
        decay=0.99999,
        dropout_ratio=0.25
    )
    
    return config

class GraphDiscriminator(nn.Module):
    """Discriminator/Encoder network"""
    def __init__(self, node_dim: int, edge_dim: int, hidden_dim: int = 128, output_dim: int = 128):
        super().__init__()
        
        # Feature encoding
        self.node_encoder = nn.Sequential(
            nn.Linear(node_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        
        self.edge_encoder = nn.Sequential(
            nn.Linear(edge_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        
        # Graph convolution layers
        self.conv1 = GCNConv(hidden_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.conv3 = GCNConv(hidden_dim, output_dim)
        
        # Projection head for contrastive learning
        self.projection = nn.Sequential(
            nn.Linear(output_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )
        
    def normalize_features(self, x_cat, x_phys):
        """Normalize categorical and physical features separately"""
        # Convert categorical features to one-hot
        x_cat = x_cat.float()
        
        # Normalize physical features
        x_phys = x_phys.float()
        if x_phys.size(0) > 1:  # Only normalize if we have more than one sample
            x_phys = (x_phys - x_phys.mean(0)) / (x_phys.std(0) + 1e-5)
            
        return x_cat, x_phys 
        
    def forward(self, data):
        # Normalize features
        x_cat, x_phys = self.normalize_features(data.x_cat, data.x_phys)
        
        # Concatenate features
        x = torch.cat([x_cat, x_phys], dim=-1)
        
        edge_index = data.edge_index
        edge_attr = data.edge_attr.float()  # Ensure float type
        batch = data.batch
        
        # Initial feature encoding
        x = self.node_encoder(x)
        edge_attr = self.edge_encoder(edge_attr)
        
        # Graph convolutions
        x = F.relu(self.conv1(x, edge_index))  # Removed edge_attr from GCNConv
        x = F.relu(self.conv2(x, edge_index))
        x = self.conv3(x, edge_index)
        
        # Global pooling
        x = global_mean_pool(x, batch)
        
        # Projection
        x = self.projection(x)
        
        return x

@dataclass
class GanClConfig:
    """Configuration for GAN-CL training"""
    node_dim: int
    edge_dim: int
    hidden_dim: int = 128
    output_dim: int = 128
    queue_size: int = 65536
    momentum: float = 0.999
    temperature: float = 0.07
    decay: float = 0.99999
    dropout_ratio: float = 0.25

class MolecularGANCL(nn.Module):
    """Combined GAN and Contrastive Learning framework"""
    def __init__(self, config: GanClConfig):
        super().__init__()
        self.config = config
        
        # Add weight initialization
        def init_weights(m):
            if isinstance(m, nn.Linear):
                torch.nn.init.xavier_uniform_(m.weight)
                m.bias.data.fill_(0.01)
        
        # Initialize networks
        self.generator = GraphGenerator(
            config.node_dim, 
            config.edge_dim, 
            config.hidden_dim * 2
        )
        
        self.encoder = GraphDiscriminator(
            config.node_dim,
            config.edge_dim,
            config.hidden_dim,
            config.output_dim
        )
        self.encoder.apply(init_weights)
        
        # Modified loss weights
        self.contrastive_weight = 1.0
        self.adversarial_weight = 0.1  # Increased from 0.05
        self.similarity_weight = 0.01  # Decreased from 0.1
        
        # Temperature annealing
        self.initial_temperature = 0.1
        self.min_temperature = 0.05        
        
        # Create momentum encoder
        self.momentum_encoder = copy.deepcopy(self.encoder)
        for param in self.momentum_encoder.parameters():
            param.requires_grad = False
            
        # Initialize memory queue
        self.memory_queue = MemoryQueue(
            config.queue_size,
            config.output_dim,
            config.decay
        )
        
    @torch.no_grad()
    def _momentum_update(self):
        """Update momentum encoder"""
        for param_q, param_k in zip(self.encoder.parameters(), 
                                  self.momentum_encoder.parameters()):
            param_k.data = self.config.momentum * param_k.data + \
                          (1 - self.config.momentum) * param_q.data
                          
    def drop_graph_elements(self, data, node_scores: torch.Tensor, 
                          edge_scores: torch.Tensor) -> Data:
        """Apply dropout to graph based on importance scores"""
        # Select elements to keep based on scores and dropout ratio
#         node_mask = (node_scores < self.config.dropout_ratio).float()
#         edge_mask = (edge_scores < self.config.dropout_ratio).float()

        node_mask = (torch.rand_like(node_scores) > self.config.dropout_ratio).float()
        edge_mask = (torch.rand_like(edge_scores) > self.config.dropout_ratio).float()
        
        # Apply masks
        x_cat_new = data.x_cat * node_mask
        x_phys_new = data.x_phys * node_mask
        edge_attr_new = data.edge_attr * edge_mask
        
        # Create new graph data object
        return Data(
            x_cat=x_cat_new,
            x_phys=x_phys_new,
            edge_index=data.edge_index,
            edge_attr=edge_attr_new,
            batch=data.batch
        )
        
    def get_temperature(self, epoch, total_epochs):
        """Anneal temperature during training"""
        progress = epoch / total_epochs
        return max(self.initial_temperature * (1 - progress), self.min_temperature)
    
    def forward(self, data, epoch=0, total_epochs=50):
        # Get current temperature
        temperature = self.get_temperature(epoch, total_epochs)
        
        # Get importance scores from generator
        node_scores, edge_scores = self.generator(data)
        
        # Create perturbed graph
        perturbed_data = self.drop_graph_elements(data, node_scores, edge_scores)
        
        # Get embeddings
        query_emb = self.encoder(perturbed_data)
        with torch.no_grad():
            key_emb = self.momentum_encoder(data)
            original_emb = self.encoder(data).detach()
        
        # Compute losses with modified weights
        contrastive_loss = self.memory_queue.compute_contrastive_loss(
            query_emb, key_emb, temperature
        ) * self.contrastive_weight
        
        adversarial_loss = -F.mse_loss(query_emb, original_emb) * self.adversarial_weight
        similarity_loss = F.mse_loss(query_emb, original_emb) * self.similarity_weight
        
        return contrastive_loss, adversarial_loss, similarity_loss
    
    def get_embeddings(self, data) -> torch.Tensor:
        """Get embeddings for downstream tasks"""
        with torch.no_grad():
            return self.encoder(data)

In [4]:
class SMILESEmbeddingTracker:
    """Track the association between SMILES and embeddings during training"""
    
    def __init__(self, output_dir='./embeddings'):
        """Initialize the tracker
        
        Args:
            output_dir: Directory to save embeddings and associated data
        """
        self.output_dir = output_dir
        os.makedirs(output_dir, exist_ok=True)
        self.timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        self.smiles_to_data = {}
        
    def add_batch(self, batch, embeddings, smiles_list=None):
        """Add a batch of data to the tracker
        
        Args:
            batch: Batch of data from the dataloader
            embeddings: Corresponding embeddings
            smiles_list: List of SMILES strings (if available separately)
        """
        if smiles_list is None:
            # Extract SMILES from batch data
            smiles_list = []
            for data in batch:
                if hasattr(data, 'smiles'):
                    smiles_list.append(data.smiles)
                else:
                    # If no SMILES available, use a placeholder
                    smiles_list.append(f"mol_{len(self.smiles_to_data)}")
        
        # Map embeddings to SMILES
        for smiles, emb in zip(smiles_list, embeddings):
            self.smiles_to_data[smiles] = emb.detach().cpu().numpy()
    
    def save_embeddings(self, epoch=None, is_final=False):
        """Save current embeddings with SMILES mapping
        
        Args:
            epoch: Current training epoch (if applicable)
            is_final: Whether these are final embeddings
        
        Returns:
            Path to saved file
        """
        if not self.smiles_to_data:
            print("Warning: No embeddings to save")
            return None
        
        # Prepare data for saving
        smiles_list = list(self.smiles_to_data.keys())
        embeddings_array = np.stack([self.smiles_to_data[s] for s in smiles_list])
        
        # Create filename
        if is_final:
            filename = f"final_embeddings_{self.timestamp}.npz"
        elif epoch is not None:
            filename = f"embeddings_epoch_{epoch}_{self.timestamp}.npz"
        else:
            filename = f"embeddings_{self.timestamp}.npz"
        
        filepath = os.path.join(self.output_dir, filename)
        
        # Save as npz file
        np.savez(filepath, embeddings=embeddings_array, smiles=smiles_list)
        
        print(f"Saved {len(smiles_list)} embeddings to {filepath}")
        return filepath
    
    def reset(self):
        """Clear current data (e.g., between epochs)"""
        self.smiles_to_data = {}

In [5]:
class MolecularAnalyzer:
    """Analyze molecules and their properties for bias detection"""
    
    def __init__(self, output_dir='./analysis'):
        """Initialize the analyzer
        
        Args:
            output_dir: Directory to save analysis results
        """
        self.output_dir = output_dir
        os.makedirs(output_dir, exist_ok=True)
        self.timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    
    def extract_properties(self, mol):
        """Extract basic molecular properties"""
        if mol is None:
            return {
                'MW': 0.0,
                'LogP': 0.0,
                'TPSA': 0.0,
                'NumHAcceptors': 0, 
                'NumHDonors': 0,
                'NumRotatableBonds': 0,
                'NumAtoms': 0,
                'NumHeavyAtoms': 0,
                'NumBonds': 0
            }
            
        return {
            'MW': Descriptors.ExactMolWt(mol),
            'LogP': Descriptors.MolLogP(mol),
            'TPSA': CalcTPSA(mol),
            'NumHAcceptors': Lipinski.NumHAcceptors(mol), 
            'NumHDonors': Lipinski.NumHDonors(mol),
            'NumRotatableBonds': Descriptors.NumRotatableBonds(mol),
            'NumAtoms': mol.GetNumAtoms(),
            'NumHeavyAtoms': mol.GetNumHeavyAtoms(),
            'NumBonds': mol.GetNumBonds()
        }
    
    def extract_features(self, mol):
        """Extract structural features"""
        if mol is None:
            return {
                'Aromatic': 0,
                'Heterocycles': 0,
                'FusedRings': 0,
                'SpiroRings': 0,
                'BridgedRings': 0,
                'Macrocycles': 0,
                'LinearChain': 0,
                'Branched': 0
            }
        
        # Get ring information
        ri = mol.GetRingInfo()
        rings = ri.AtomRings()
        
        # Count different ring types
        num_aromatic = CalcNumAromaticRings(mol)
        num_heterocycles = CalcNumHeterocycles(mol)
        
        # Define feature presence (1 if present, 0 if not)
        features = {
            'Aromatic': 1 if num_aromatic > 0 else 0,
            'Heterocycles': 1 if num_heterocycles > 0 else 0,
            'FusedRings': 0,
            'SpiroRings': 0,
            'BridgedRings': 0,
            'Macrocycles': 0,
            'LinearChain': 0,
            'Branched': 0
        }
        
        # Check for fused rings
        if len(rings) >= 2:
            # Check for fused rings (rings sharing atoms)
            for i in range(len(rings)):
                for j in range(i+1, len(rings)):
                    if set(rings[i]).intersection(set(rings[j])):
                        features['FusedRings'] = 1
                        break
                if features['FusedRings'] == 1:
                    break
        
        # Check for spiro rings (rings sharing exactly one atom)
        for i in range(len(rings)):
            for j in range(i+1, len(rings)):
                if len(set(rings[i]).intersection(set(rings[j]))) == 1:
                    features['SpiroRings'] = 1
                    break
            if features['SpiroRings'] == 1:
                break
        
        # Check for bridged rings (complex structures)
        if mol.GetSubstructMatches(Chem.MolFromSmarts('[D4R]')):
            features['BridgedRings'] = 1
            
        # Check for macrocycles (rings with >= 8 atoms)
        for ring in rings:
            if len(ring) >= 8:
                features['Macrocycles'] = 1
                break
        
        # Check for linear chains (molecules with no branches)
        if mol.GetNumBonds() == mol.GetNumAtoms() - 1 and mol.GetNumRings() == 0:
            features['LinearChain'] = 1
        
        # Check for branched structures
        branch_count = sum(1 for atom in mol.GetAtoms() if atom.GetDegree() > 2)
        if branch_count > 0:
            features['Branched'] = 1
            
        return features
    
    def extract_functional_groups(self, mol):
        """Extract functional group presence"""
        if mol is None:
            return {
                'Alcohol': 0,
                'Amine': 0,
                'Carboxyl': 0,
                'Carbonyl': 0,
                'Ether': 0,
                'Ester': 0,
                'Amide': 0,
                'Halogen': 0
            }
        
        # Count alcohols
        alcohol_count = 0
        for atom in mol.GetAtoms():
            if atom.GetAtomicNum() == 8:  # Oxygen
                if atom.GetTotalNumHs() > 0:  # Has at least one hydrogen
                    # Check if connected to carbon
                    for neighbor in atom.GetNeighbors():
                        if neighbor.GetAtomicNum() == 6:  # Carbon
                            alcohol_count += 1
                            break
        
        # Count amines using SMARTS patterns
        amine_patt = Chem.MolFromSmarts('[NX3;H2,H1,H0]')
        amine_count = len(mol.GetSubstructMatches(amine_patt)) if amine_patt else 0
        
        # Count carboxylic acids
        carboxylic_patt = Chem.MolFromSmarts('C(=O)[OH]')
        carboxyl_count = len(mol.GetSubstructMatches(carboxylic_patt)) if carboxylic_patt else 0
        
        # Count carbonyls (ketones and aldehydes)
        ketone_patt = Chem.MolFromSmarts('[#6]C(=O)[#6]')
        ketone_count = len(mol.GetSubstructMatches(ketone_patt)) if ketone_patt else 0
        
        aldehyde_patt = Chem.MolFromSmarts('[#6]C(=O)[H]')
        aldehyde_count = len(mol.GetSubstructMatches(aldehyde_patt)) if aldehyde_patt else 0
        
        # Count ethers
        ether_patt = Chem.MolFromSmarts('[#6]-[O]-[#6]')
        ether_count = len(mol.GetSubstructMatches(ether_patt)) if ether_patt else 0
        
        # Count esters
        ester_patt = Chem.MolFromSmarts('[#6]C(=O)O[#6]')
        ester_count = len(mol.GetSubstructMatches(ester_patt)) if ester_patt else 0
        
        # Count amides
        amide_patt = Chem.MolFromSmarts('C(=O)N')
        amide_count = len(mol.GetSubstructMatches(amide_patt)) if amide_patt else 0
        
        # Count halogens
        halogen_count = sum(1 for atom in mol.GetAtoms() 
                          if atom.GetAtomicNum() in [9, 17, 35, 53])  # F, Cl, Br, I
        
        return {
            'Alcohol': alcohol_count,
            'Amine': amine_count,
            'Carboxyl': carboxyl_count,
            'Carbonyl': ketone_count + aldehyde_count,
            'Ether': ether_count,
            'Ester': ester_count,
            'Amide': amide_count,
            'Halogen': halogen_count
        }
    
    def extract_ring_info(self, mol):
        """Extract ring information"""
        if mol is None:
            return {
                'Spiro': {},
                'Bridged': {},
                'Fused': {},
                'Single': {}
            }
        
        # Get ring information
        ri = mol.GetRingInfo()
        rings = ri.AtomRings()
        
        # Initialize counts
        ring_info = {
            'Spiro': defaultdict(int),
            'Bridged': defaultdict(int),
            'Fused': defaultdict(int),
            'Single': defaultdict(int)
        }
        
        # Process each ring
        processed_rings = set()
        
        # First identify spiro rings (rings sharing exactly one atom)
        for i in range(len(rings)):
            for j in range(i+1, len(rings)):
                if len(set(rings[i]).intersection(set(rings[j]))) == 1:
                    ring_info['Spiro'][len(rings[i])] += 1
                    ring_info['Spiro'][len(rings[j])] += 1
                    processed_rings.add(i)
                    processed_rings.add(j)
        
        # Identify fused rings (rings sharing more than one atom)
        for i in range(len(rings)):
            for j in range(i+1, len(rings)):
                if i in processed_rings or j in processed_rings:
                    continue
                    
                shared = len(set(rings[i]).intersection(set(rings[j])))
                if shared > 1:
                    ring_info['Fused'][len(rings[i])] += 1
                    ring_info['Fused'][len(rings[j])] += 1
                    processed_rings.add(i)
                    processed_rings.add(j)
        
        # Try to identify bridged rings
        bridged_patt = Chem.MolFromSmarts('[D4R]')
        if mol.HasSubstructMatch(bridged_patt):
            for i in range(len(rings)):
                if i in processed_rings:
                    continue
                    
                if len(rings[i]) >= 6:
                    ring_info['Bridged'][len(rings[i])] += 1
                    processed_rings.add(i)
        
        # Remaining rings are single
        for i in range(len(rings)):
            if i not in processed_rings:
                ring_info['Single'][len(rings[i])] += 1
                
        # Convert defaultdicts to regular dicts for JSON serialization
        return {k: dict(v) for k, v in ring_info.items()}
    
    def analyze_smiles_list(self, smiles_list, prefix="molecules"):
        """Analyze a list of SMILES strings and extract properties
        
        Args:
            smiles_list: List of SMILES strings
            prefix: Prefix for output files
            
        Returns:
            Tuple of DataFrames with properties, features, etc.
        """
        print(f"Analyzing {len(smiles_list)} molecules...")
        
        properties_list = []
        features_list = []
        func_groups_list = []
        ring_info_list = []
        valid_smiles = []
        
        for smiles in tqdm(smiles_list):
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                print(f"Warning: Could not parse SMILES: {smiles}")
                continue
                
            valid_smiles.append(smiles)
                
            # Extract and store properties
            props = self.extract_properties(mol)
            properties_list.append(props)
            
            # Extract and store features
            features = self.extract_features(mol)
            features_list.append(features)
            
            # Extract and store functional groups
            func_groups = self.extract_functional_groups(mol)
            func_groups_list.append(func_groups)
            
            # Extract and store ring information
            ring_info = self.extract_ring_info(mol)
            ring_info_list.append(ring_info)
        
        # Create DataFrames with SMILES as index
        props_df = pd.DataFrame(properties_list, index=valid_smiles)
        features_df = pd.DataFrame(features_list, index=valid_smiles)
        func_groups_df = pd.DataFrame(func_groups_list, index=valid_smiles)
        
        # Ring info requires special handling due to nested structure
        ring_df = pd.DataFrame(index=valid_smiles)
        
        # Flatten the ring info for DataFrame representation
        for smiles, ring_data in zip(valid_smiles, ring_info_list):
            for ring_type, size_dict in ring_data.items():
                for size, count in size_dict.items():
                    ring_df.at[smiles, f"{ring_type}_Size{size}"] = count
        
        # Fill NaN values with 0
        ring_df = ring_df.fillna(0)
        
        # Save to CSV
        output_prefix = os.path.join(self.output_dir, f"{prefix}_{self.timestamp}")
        
        props_df.to_csv(f"{output_prefix}_properties.csv")
        features_df.to_csv(f"{output_prefix}_features.csv")
        func_groups_df.to_csv(f"{output_prefix}_functional_groups.csv")
        ring_df.to_csv(f"{output_prefix}_ring_info.csv")
        
        print(f"Analysis saved to {output_prefix}_*.csv")
        
        return props_df, features_df, func_groups_df, ring_df, valid_smiles
    
    def analyze_embeddings_file(self, embeddings_file, output_prefix=None):
        """Analyze embeddings from a file with SMILES mapping
        
        Args:
            embeddings_file: Path to .npz file with embeddings and SMILES
            output_prefix: Prefix for output files (default: based on input filename)
            
        Returns:
            Tuple of DataFrames with properties, features, etc.
        """
        # Load the embeddings file
        try:
            data = np.load(embeddings_file, allow_pickle=True)
            embeddings = data['embeddings']
            smiles_list = data['smiles']
            
            print(f"Loaded {len(smiles_list)} SMILES with embeddings of shape {embeddings.shape}")
            
            # Generate output prefix if not provided
            if output_prefix is None:
                base_name = os.path.splitext(os.path.basename(embeddings_file))[0]
                output_prefix = os.path.join(self.output_dir, base_name)
            
            # Analyze the SMILES
            return self.analyze_smiles_list(smiles_list, prefix=output_prefix)
            
        except Exception as e:
            print(f"Error analyzing embeddings file: {e}")
            return None, None, None, None, None

In [6]:
class MolecularBiasAnalyzer:
    """Molecular bias analysis for comparing raw molecules and embeddings"""
    
    def __init__(self, output_dir: str = './bias_analysis'):
        """Initialize the bias analyzer
        
        Args:
            output_dir: Directory to save analysis results
        """
        self.output_dir = output_dir
        os.makedirs(output_dir, exist_ok=True)
        self.timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        
        # Setup SMILES to property mapping for quick lookup
        self.smiles_to_props = {}
        self.smiles_to_features = {}
        self.smiles_to_funcs = {}
        self.smiles_to_rings = {}
    
    def extract_properties(self, mol: Chem.Mol) -> Dict[str, float]:
        """Extract basic molecular properties
        
        Args:
            mol: RDKit molecule object
            
        Returns:
            Dictionary of properties
        """
        if mol is None:
            return {
                'MW': 0.0,
                'LogP': 0.0,
                'TPSA': 0.0,
                'NumHAcceptors': 0, 
                'NumHDonors': 0,
                'NumRotatableBonds': 0,
                'NumAtoms': 0,
                'NumHeavyAtoms': 0,
                'NumBonds': 0
            }
            
        return {
            'MW': Descriptors.ExactMolWt(mol),
            'LogP': Descriptors.MolLogP(mol),
            'TPSA': CalcTPSA(mol),
            'NumHAcceptors': Lipinski.NumHAcceptors(mol), 
            'NumHDonors': Lipinski.NumHDonors(mol),
            'NumRotatableBonds': Descriptors.NumRotatableBonds(mol),
            'NumAtoms': mol.GetNumAtoms(),
            'NumHeavyAtoms': mol.GetNumHeavyAtoms(),
            'NumBonds': mol.GetNumBonds()
        }
    
    def extract_features(self, mol: Chem.Mol) -> Dict[str, int]:
        """Extract structural features
        
        Args:
            mol: RDKit molecule object
            
        Returns:
            Dictionary of feature presence (0/1)
        """
        if mol is None:
            return {
                'Aromatic': 0,
                'Heterocycles': 0,
                'FusedRings': 0,
                'SpiroRings': 0,
                'BridgedRings': 0,
                'Macrocycles': 0,
                'LinearChain': 0,
                'Branched': 0
            }
        
        # Get ring information
        ri = mol.GetRingInfo()
        rings = ri.AtomRings()
        
        # Count different ring types
        num_aromatic = CalcNumAromaticRings(mol)
        num_heterocycles = CalcNumHeterocycles(mol)
        
        # Define feature presence (1 if present, 0 if not)
        features = {
            'Aromatic': 1 if num_aromatic > 0 else 0,
            'Heterocycles': 1 if num_heterocycles > 0 else 0,
            'FusedRings': 0,  # Will be set below if detected
            'SpiroRings': 0,  # Will be set below if detected
            'BridgedRings': 0,  # Will be set below if detected
            'Macrocycles': 0,  # Will be set below if detected
            'LinearChain': 0,  # Will be set below if detected
            'Branched': 0      # Will be set below if detected
        }
        
        # Check for fused rings
        if len(rings) >= 2:
            # Check for fused rings (rings sharing atoms)
            for i in range(len(rings)):
                for j in range(i+1, len(rings)):
                    if set(rings[i]).intersection(set(rings[j])):
                        features['FusedRings'] = 1
                        break
                if features['FusedRings'] == 1:
                    break
        
        # Check for spiro rings (rings sharing exactly one atom)
        for i in range(len(rings)):
            for j in range(i+1, len(rings)):
                if len(set(rings[i]).intersection(set(rings[j]))) == 1:
                    features['SpiroRings'] = 1
                    break
            if features['SpiroRings'] == 1:
                break
        
        # Check for bridged rings (complex structures)
        # This is a simplification - detailed analysis would require more complex algorithms
        if Fragments.fr_bicyclic(mol) > 0:
            features['BridgedRings'] = 1
            
        # Check for macrocycles (rings with >= 8 atoms)
        for ring in rings:
            if len(ring) >= 8:
                features['Macrocycles'] = 1
                break
        
        # Check for linear chains (molecules with no branches)
        if mol.GetNumBonds() == mol.GetNumAtoms() - 1 and mol.GetNumRings() == 0:
            features['LinearChain'] = 1
        
        # Check for branched structures
        # Simplification: if multiple atoms have degree > 2, it's branched
        branch_count = sum(1 for atom in mol.GetAtoms() if atom.GetDegree() > 2)
        if branch_count > 0:
            features['Branched'] = 1
            
        return features
    
    def extract_functional_groups(self, mol: Chem.Mol) -> Dict[str, int]:
        """Extract functional group presence
        
        Args:
            mol: RDKit molecule object
            
        Returns:
            Dictionary of functional group counts
        """
        if mol is None:
            return {
                'Alcohol': 0,
                'Amine': 0,
                'Carboxyl': 0,
                'Carbonyl': 0,
                'Ether': 0,
                'Ester': 0,
                'Amide': 0,
                'Halogen': 0
            }
        
        # Count alcohols
        alcohol_count = 0
        for atom in mol.GetAtoms():
            if atom.GetAtomicNum() == 8:  # Oxygen
                if atom.GetTotalNumHs() > 0:  # Has at least one hydrogen
                    # Check if connected to carbon
                    for neighbor in atom.GetNeighbors():
                        if neighbor.GetAtomicNum() == 6:  # Carbon
                            alcohol_count += 1
                            break
        
        # Count various functional groups using RDKit's built-in fragment counts
        try:
            amine_count = Fragments.fr_NH2(mol) + Fragments.fr_NH1(mol) + Fragments.fr_NH0(mol)
        except:
            # Fallback if fragment functions are not available
            amine_count = sum(1 for atom in mol.GetAtoms() 
                             if atom.GetAtomicNum() == 7 and atom.GetTotalNumHs() > 0)
        
        try:
            carboxyl_count = Fragments.fr_COO(mol) + Fragments.fr_COOH(mol)
        except:
            # Fallback using SMARTS patterns
            carboxyl_patt = Chem.MolFromSmarts('C(=O)[OH]')
            carboxyl_count = len(mol.GetSubstructMatches(carboxyl_patt)) if carboxyl_patt else 0
            
            ester_patt = Chem.MolFromSmarts('C(=O)O[#6]')
            carboxyl_count += len(mol.GetSubstructMatches(ester_patt)) if ester_patt else 0
        
        try:
            carbonyl_count = Fragments.fr_ketone(mol) + Fragments.fr_aldehyde(mol)
        except:
            # Fallback using SMARTS patterns
            ketone_patt = Chem.MolFromSmarts('[#6]C(=O)[#6]')
            carbonyl_count = len(mol.GetSubstructMatches(ketone_patt)) if ketone_patt else 0
            
            aldehyde_patt = Chem.MolFromSmarts('[#6]C(=O)[H]')
            carbonyl_count += len(mol.GetSubstructMatches(aldehyde_patt)) if aldehyde_patt else 0
        
        # Count ethers
        try:
            ether_count = len(mol.GetSubstructMatches(Chem.MolFromSmarts('[#6]-[O]-[#6]')))
        except:
            ether_count = 0
        
        # Count esters
        try:
            ester_count = len(mol.GetSubstructMatches(Chem.MolFromSmarts('[#6]-C(=O)-O-[#6]')))
        except:
            ester_count = 0
        
        # Count amides
        try:
            amide_count = len(mol.GetSubstructMatches(Chem.MolFromSmarts('C(=O)-N')))
        except:
            amide_count = 0
        
        # Count halogens
        halogen_count = sum(1 for atom in mol.GetAtoms() 
                           if atom.GetAtomicNum() in [9, 17, 35, 53])  # F, Cl, Br, I
        
        return {
            'Alcohol': alcohol_count,
            'Amine': amine_count,
            'Carboxyl': carboxyl_count,
            'Carbonyl': carbonyl_count,
            'Ether': ether_count,
            'Ester': ester_count,
            'Amide': amide_count,
            'Halogen': halogen_count
        }
    
    def extract_ring_info(self, mol: Chem.Mol) -> Dict[str, Dict[int, int]]:
        """Extract ring information
        
        Args:
            mol: RDKit molecule object
            
        Returns:
            Dictionary of ring types and their size distributions
        """
        if mol is None:
            return {
                'Spiro': {},
                'Bridged': {},
                'Fused': {},
                'Single': {}
            }
        
        # Get ring information
        ri = mol.GetRingInfo()
        rings = ri.AtomRings()
        
        # Initialize counts
        ring_info = {
            'Spiro': defaultdict(int),
            'Bridged': defaultdict(int),
            'Fused': defaultdict(int),
            'Single': defaultdict(int)
        }
        
        # Process each ring
        processed_rings = set()
        
        # First identify spiro rings (rings sharing exactly one atom)
        for i in range(len(rings)):
            for j in range(i+1, len(rings)):
                if len(set(rings[i]).intersection(set(rings[j]))) == 1:
                    ring_info['Spiro'][len(rings[i])] += 1
                    ring_info['Spiro'][len(rings[j])] += 1
                    processed_rings.add(i)
                    processed_rings.add(j)
        
        # Identify fused rings (rings sharing more than one atom)
        for i in range(len(rings)):
            for j in range(i+1, len(rings)):
                if i in processed_rings or j in processed_rings:
                    continue
                    
                shared = len(set(rings[i]).intersection(set(rings[j])))
                if shared > 1:
                    ring_info['Fused'][len(rings[i])] += 1
                    ring_info['Fused'][len(rings[j])] += 1
                    processed_rings.add(i)
                    processed_rings.add(j)
        
        # Identify bridged rings (simplification)
        if Fragments.fr_bicyclic(mol) > 0:
            for i in range(len(rings)):
                if i in processed_rings:
                    continue
                    
                # This is a simplification - true bridged detection is complex
                if len(rings[i]) >= 6:
                    ring_info['Bridged'][len(rings[i])] += 1
                    processed_rings.add(i)
        
        # Remaining rings are single
        for i in range(len(rings)):
            if i not in processed_rings:
                ring_info['Single'][len(rings[i])] += 1
                
        # Convert defaultdicts to regular dicts for JSON serialization
        return {k: dict(v) for k, v in ring_info.items()}
    
    def analyze_raw_molecules(self, smiles_list: List[str]) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame]:
        """Analyze raw molecules from SMILES strings
        
        Args:
            smiles_list: List of SMILES strings
            
        Returns:
            DataFrames with properties, features, functional groups, and ring info
        """
        print("Analyzing raw molecules...")
        
        properties_list = []
        features_list = []
        func_groups_list = []
        ring_info_list = []
        valid_smiles = []
        
        for smiles in tqdm(smiles_list):
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                continue
                
            valid_smiles.append(smiles)
                
            # Extract and store properties
            props = self.extract_properties(mol)
            properties_list.append(props)
            self.smiles_to_props[smiles] = props
            
            # Extract and store features
            features = self.extract_features(mol)
            features_list.append(features)
            self.smiles_to_features[smiles] = features
            
            # Extract and store functional groups
            func_groups = self.extract_functional_groups(mol)
            func_groups_list.append(func_groups)
            self.smiles_to_funcs[smiles] = func_groups
            
            # Extract and store ring information
            ring_info = self.extract_ring_info(mol)
            ring_info_list.append(ring_info)
            self.smiles_to_rings[smiles] = ring_info
        
        # Create DataFrames with SMILES as index
        props_df = pd.DataFrame(properties_list, index=valid_smiles)
        features_df = pd.DataFrame(features_list, index=valid_smiles)
        func_groups_df = pd.DataFrame(func_groups_list, index=valid_smiles)
        
        # Ring info requires special handling due to nested structure
        ring_df = pd.DataFrame(index=valid_smiles)
        
        # Flatten the ring info for DataFrame representation
        for smiles, ring_data in zip(valid_smiles, ring_info_list):
            for ring_type, size_dict in ring_data.items():
                for size, count in size_dict.items():
                    ring_df.at[smiles, f"{ring_type}_Size{size}"] = count
        
        # Fill NaN values with 0
        ring_df = ring_df.fillna(0)
        
        return props_df, features_df, func_groups_df, ring_df
    
    def save_raw_analysis(self, props_df, features_df, func_groups_df, ring_df, prefix="raw"):
        """Save raw molecule analysis to CSV files"""
        output_prefix = os.path.join(self.output_dir, f"{prefix}_{self.timestamp}")
        
        props_df.to_csv(f"{output_prefix}_properties.csv")
        features_df.to_csv(f"{output_prefix}_features.csv")
        func_groups_df.to_csv(f"{output_prefix}_functional_groups.csv")
        ring_df.to_csv(f"{output_prefix}_ring_info.csv")
        
        print(f"Analysis saved to {output_prefix}_*.csv")
        
        # Also save the lookup mappings for later use
        with open(f"{output_prefix}_lookups.pkl", 'wb') as f:
            pickle.dump({
                'smiles_to_props': self.smiles_to_props,
                'smiles_to_features': self.smiles_to_features,
                'smiles_to_funcs': self.smiles_to_funcs,
                'smiles_to_rings': self.smiles_to_rings
            }, f)
        
        return output_prefix

In [7]:
import numpy as np
import pandas as pd
import torch
from rdkit import Chem
from rdkit.Chem import Descriptors, Fragments, Lipinski
from rdkit.Chem.rdMolDescriptors import CalcNumRings, CalcNumAromaticRings, CalcNumHeterocycles
from rdkit.Chem.rdMolDescriptors import CalcNumAliphaticRings, CalcTPSA
from typing import List, Dict, Any, Tuple
import os
import json
import pickle
from tqdm import tqdm
from datetime import datetime
from collections import defaultdict


class DirectSMILESTracker:
    """A simpler tracker that directly stores SMILES with their embeddings"""
    
    def __init__(self, output_dir='./embeddings'):
        """Initialize the tracker
        
        Args:
            output_dir: Directory to save embeddings and associated data
        """
        self.output_dir = output_dir
        os.makedirs(output_dir, exist_ok=True)
        self.timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        # Simple list storage instead of dict
        self.embeddings = []
        self.smiles = []
        
    def add_batch(self, batch_data, batch_embeddings):
        """Add batch data to tracker, ensuring SMILES are properly associated with embeddings
        
        Args:
            batch_data: Batch of molecular graph data
            batch_embeddings: Tensor of embeddings
        """
        # Convert embeddings to numpy
        embeddings_np = batch_embeddings.detach().cpu().numpy()
        
        # Extract SMILES from each data object
        batch_smiles = []
        for data in batch_data:
            if hasattr(data, 'smiles'):
                batch_smiles.append(data.smiles)
            else:
                # If no SMILES available, use a placeholder
                batch_smiles.append(f"unknown_molecule_{len(self.smiles)}")
        
        # Verify dimensions match
        if len(batch_smiles) != len(embeddings_np):
            print(f"Warning: SMILES count ({len(batch_smiles)}) doesn't match embeddings count ({len(embeddings_np)})")
            # Use the smaller size
            min_size = min(len(batch_smiles), len(embeddings_np))
            batch_smiles = batch_smiles[:min_size]
            embeddings_np = embeddings_np[:min_size]
        
        # Append to lists
        self.embeddings.extend(embeddings_np)
        self.smiles.extend(batch_smiles)
        
        print(f"Added {len(batch_smiles)} SMILES-embedding pairs. Total now: {len(self.smiles)}")
    
    def save_embeddings(self, prefix="embeddings"):
        """Save current embeddings with SMILES mapping
        
        Args:
            prefix: Prefix for the filename
            
        Returns:
            Path to saved file
        """
        if not self.smiles or not self.embeddings:
            print("Warning: No embeddings to save")
            return None
        
        # Convert to numpy arrays
        embeddings_array = np.array(self.embeddings)
        smiles_array = np.array(self.smiles)
        
        # Create filename
        filename = f"{prefix}_{self.timestamp}.npz"
        filepath = os.path.join(self.output_dir, filename)
        
        # Save as npz file
        np.savez(filepath, embeddings=embeddings_array, smiles=smiles_array)
        
        print(f"Saved {len(self.smiles)} embeddings to {filepath}")
        return filepath
    
    def reset(self):
        """Clear current data"""
        self.embeddings = []
        self.smiles = []


def modified_train_gan_cl(train_loader, config, device='cuda', 
                         save_dir='./checkpoints', 
                         embedding_dir='./embeddings'):
    """Modified training function with SMILES-embedding tracking"""
    
    # Create directories
    os.makedirs(save_dir, exist_ok=True)
    os.makedirs(embedding_dir, exist_ok=True)
    
    # Initialize tracker
    tracker = DirectSMILESTracker(output_dir=embedding_dir)
    
    # Extract all SMILES from the dataset for the "before" analysis
    all_smiles = []
    for batch in train_loader:
        for data in batch:
            if hasattr(data, 'smiles'):
                all_smiles.append(data.smiles)
    
    print(f"Found {len(all_smiles)} molecules for 'before training' analysis")
    
    # Save the raw SMILES list for before-training analysis
    smiles_path = os.path.join(embedding_dir, f"before_training_smiles_{tracker.timestamp}.txt")
    with open(smiles_path, 'w') as f:
        for smiles in all_smiles:
            f.write(f"{smiles}\n")
    
    print(f"Saved 'before training' SMILES to {smiles_path}")
    
    # Initialize model
    model = MolecularGANCL(config).to(device)
    
    # Initialize optimizers
    optimizer_encoder = torch.optim.Adam(model.encoder.parameters(), lr=3e-4)
    optimizer_generator = torch.optim.Adam(model.generator.parameters(), lr=1e-4)
    
    # Training metrics
    metrics = {
        'contrastive_losses': [],
        'adversarial_losses': [],
        'similarity_losses': [],
        'total_losses': []
    }
    
    # Get initial embeddings
    model.eval()
    print("Extracting initial embeddings...")
    with torch.no_grad():
        for batch in tqdm(train_loader, desc="Initial embeddings"):
            batch = batch.to(device)
            initial_embeddings = model.encoder(batch)
            tracker.add_batch(batch, initial_embeddings)
    
    # Save initial embeddings
    initial_embeddings_path = tracker.save_embeddings(prefix="before_training")
    tracker.reset()
    model.train()
    
    # Training phases
    print("Phase 1: Pretraining Contrastive Learning...")
    pretrain_epochs = 10
    for epoch in range(pretrain_epochs):
        contrastive_epoch_loss = 0
        
        for batch in tqdm(train_loader, desc=f'Pretrain Epoch {epoch+1}/{pretrain_epochs}'):
            batch = batch.to(device)
            
            # Forward pass (without generator)
            query_emb = model.encoder(batch)
            with torch.no_grad():
                key_emb = model.momentum_encoder(batch)
            
            # Compute contrastive loss
            contrastive_loss = model.memory_queue.compute_contrastive_loss(
                query_emb, key_emb, model.config.temperature
            )
            
            # Update encoder
            optimizer_encoder.zero_grad()
            contrastive_loss.backward()
            optimizer_encoder.step()
            
            # Update momentum encoder
            model._momentum_update()
            
            # Update memory queue
            model.memory_queue.update_queue(key_emb.detach())
            
            contrastive_epoch_loss += contrastive_loss.item()
            
        avg_loss = contrastive_epoch_loss / len(train_loader)
        metrics['contrastive_losses'].append(avg_loss)
        print(f'Pretrain Epoch {epoch+1}, Avg Loss: {avg_loss:.4f}')
        
        # Save pretrained checkpoint
        if (epoch + 1) % 5 == 0:
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_encoder_state_dict': optimizer_encoder.state_dict(),
                'loss': avg_loss,
            }, os.path.join(save_dir, f'pretrain_checkpoint_{epoch+1}.pt'))
    
    print("\nPhase 2: Training GAN-CL...")
    train_epochs = 50
    best_loss = float('inf')
    
    for epoch in range(train_epochs):
        epoch_losses = {
            'contrastive': 0,
            'adversarial': 0,
            'similarity': 0,
            'total': 0
        }
        
        for batch in tqdm(train_loader, desc=f'Train Epoch {epoch+1}/{train_epochs}'):
            batch = batch.to(device)
            
            # Step 1: Train Encoder
            optimizer_encoder.zero_grad()
            
            # Get importance scores from generator
            with torch.no_grad():
                node_scores, edge_scores = model.generator(batch)
            
            # Create perturbed graph
            perturbed_data = model.drop_graph_elements(batch, node_scores, edge_scores)
            
            # Get embeddings
            query_emb = model.encoder(perturbed_data)
            with torch.no_grad():
                key_emb = model.momentum_encoder(batch)
                original_emb = model.encoder(batch).detach()
            
            # Compute losses for encoder
            contrastive_loss = model.memory_queue.compute_contrastive_loss(
                query_emb, key_emb, model.config.temperature
            )
            similarity_loss = torch.nn.functional.mse_loss(query_emb, original_emb)
            
            # Total loss for encoder
            encoder_loss = contrastive_loss + 0.1 * similarity_loss
            
            # Update encoder
            encoder_loss.backward()
            optimizer_encoder.step()
            
            # Update momentum encoder
            model._momentum_update()
            
            # Step 2: Train Generator
            optimizer_generator.zero_grad()
            
            # Get new embeddings for adversarial loss
            node_scores, edge_scores = model.generator(batch)
            perturbed_data = model.drop_graph_elements(batch, node_scores, edge_scores)
            
            with torch.no_grad():
                original_emb = model.encoder(batch)
            perturbed_emb = model.encoder(perturbed_data)
            
            # Compute adversarial loss
            adversarial_loss = -torch.nn.functional.mse_loss(perturbed_emb, original_emb)
            
            # Update generator
            adversarial_loss.backward()
            optimizer_generator.step()
            
            # Update memory queue
            model.memory_queue.update_queue(key_emb.detach())
            
            # Update metrics
            epoch_losses['contrastive'] += contrastive_loss.item()
            epoch_losses['adversarial'] += adversarial_loss.item()
            epoch_losses['similarity'] += similarity_loss.item()
            epoch_losses['total'] += encoder_loss.item()
        
        # Average losses
        for k in epoch_losses:
            epoch_losses[k] /= len(train_loader)
            metrics[f'{k}_losses'].append(epoch_losses[k])
        
        print(f'Epoch {epoch+1}, Losses: {epoch_losses}')
        
        # Save checkpoint
        if (epoch + 1) % 10 == 0:
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_encoder_state_dict': optimizer_encoder.state_dict(),
                'optimizer_generator_state_dict': optimizer_generator.state_dict(),
                'losses': epoch_losses,
            }, os.path.join(save_dir, f'gan_cl_checkpoint_{epoch+1}.pt'))
    
    # Extract final embeddings after training
    model.eval()
    tracker.reset()  # Clear previous embeddings
    
    print("Extracting final embeddings...")
    with torch.no_grad():
        for batch in tqdm(train_loader, desc="Extracting final embeddings"):
            batch = batch.to(device)
            final_embeddings = model.get_embeddings(batch)
            tracker.add_batch(batch, final_embeddings)
    
    # Save final embeddings
    final_embeddings_path = tracker.save_embeddings(prefix="after_training")
    
    # Save final metrics
    with open(os.path.join(save_dir, 'training_metrics.json'), 'w') as f:
        json.dump(metrics, f)
    
    # Make sure we have a valid path to return
    if final_embeddings_path is None:
        print("WARNING: Failed to save final embeddings. Using empty arrays.")
        return model, metrics, np.array([]), np.array([])
    
    # Load the final embeddings to return
    try:
        data = np.load(final_embeddings_path)
        final_embeddings = data['embeddings']
        final_smiles = data['smiles']
        print(f"Loaded {len(final_smiles)} final embeddings from {final_embeddings_path}")
    except Exception as e:
        print(f"Error loading final embeddings: {e}")
        print("Using empty arrays for return values")
        final_embeddings = np.array([])
        final_smiles = np.array([])
    
    return model, metrics, final_embeddings, final_smiles


def analyze_smiles_file(smiles_file, output_dir='./analysis', prefix='analysis'):
    """Analyze a file of SMILES strings for molecular properties
    
    Args:
        smiles_file: Path to text file with SMILES strings (one per line)
        output_dir: Directory to save analysis results
        prefix: Prefix for output files
    """
    # Make sure directory exists
    os.makedirs(output_dir, exist_ok=True)
    
    # Load SMILES
    smiles_list = []
    with open(smiles_file, 'r') as f:
        for line in f:
            smiles = line.strip()
            if smiles:
                smiles_list.append(smiles)
    
    print(f"Loaded {len(smiles_list)} SMILES strings for analysis")
    
    # Analyze properties
    analyze_smiles_list(smiles_list, output_dir, prefix)


def analyze_embedding_file(embedding_file, output_dir='./analysis', prefix=None):
    """Analyze an embedding file with SMILES mapping
    
    Args:
        embedding_file: Path to .npz file with embeddings and SMILES
        output_dir: Directory to save analysis results
        prefix: Prefix for output files (default: derived from embedding filename)
    """
    # Make sure directory exists
    os.makedirs(output_dir, exist_ok=True)
    
    # Determine prefix if not provided
    if prefix is None:
        base_name = os.path.splitext(os.path.basename(embedding_file))[0]
        prefix = base_name
    
    # Load embeddings and SMILES
    try:
        data = np.load(embedding_file)
        embeddings = data['embeddings']
        smiles_list = data['smiles']
        
        print(f"Loaded {len(smiles_list)} SMILES strings from embedding file")
        
        # Analyze properties
        analyze_smiles_list(smiles_list, output_dir, prefix)
        
    except Exception as e:
        print(f"Error loading embedding file: {e}")


def analyze_smiles_list(smiles_list, output_dir='./analysis', prefix='analysis'):
    """Analyze a list of SMILES strings for molecular properties
    
    Args:
        smiles_list: List of SMILES strings
        output_dir: Directory to save analysis results
        prefix: Prefix for output files
    """
    from rdkit import Chem
    from rdkit.Chem import Descriptors, Lipinski
    from rdkit.Chem.rdMolDescriptors import CalcTPSA
    from collections import defaultdict
    import pandas as pd
    
    print(f"Analyzing {len(smiles_list)} molecules...")
    
    # Prepare data storage
    properties = []
    features = []
    func_groups = []
    ring_info = []
    valid_smiles = []
    
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    
    # Process each SMILES
    for smiles in tqdm(smiles_list):
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            print(f"Warning: Invalid SMILES: {smiles}")
            continue
            
        valid_smiles.append(smiles)
        
        # Extract basic properties
        prop = {
            'MW': Descriptors.ExactMolWt(mol),
            'LogP': Descriptors.MolLogP(mol),
            'TPSA': CalcTPSA(mol),
            'NumHAcceptors': Lipinski.NumHAcceptors(mol), 
            'NumHDonors': Lipinski.NumHDonors(mol),
            'NumRotatableBonds': Descriptors.NumRotatableBonds(mol),
            'NumAtoms': mol.GetNumAtoms(),
            'NumHeavyAtoms': mol.GetNumHeavyAtoms(),
            'NumBonds': mol.GetNumBonds(),
            'NumRings': Chem.GetSSSR(mol)
        }
        properties.append(prop)
        
        # Extract structural features
        ri = mol.GetRingInfo()
        rings = ri.AtomRings()
        
        num_aromatic = sum(1 for atom in mol.GetAtoms() if atom.GetIsAromatic())
        
        feat = {
            'Aromatic': 1 if num_aromatic > 0 else 0,
            'Heterocycles': 1 if any(any(mol.GetAtomWithIdx(idx).GetAtomicNum() != 6 for idx in ring) for ring in rings) else 0,
            'FusedRings': 0,
            'SpiroRings': 0,
            'BridgedRings': 0,
            'Macrocycles': 0,
            'LinearChain': 1 if mol.GetNumBonds() == mol.GetNumAtoms() - 1 and len(rings) == 0 else 0,
            'Branched': 1 if sum(1 for atom in mol.GetAtoms() if atom.GetDegree() > 2) > 0 else 0
        }
        
        # Check for fused rings
        if len(rings) >= 2:
            for i in range(len(rings)):
                for j in range(i+1, len(rings)):
                    if len(set(rings[i]).intersection(set(rings[j]))) > 1:
                        feat['FusedRings'] = 1
                        break
                if feat['FusedRings'] == 1:
                    break
        
        # Check for spiro rings
        if len(rings) >= 2:
            for i in range(len(rings)):
                for j in range(i+1, len(rings)):
                    if len(set(rings[i]).intersection(set(rings[j]))) == 1:
                        feat['SpiroRings'] = 1
                        break
                if feat['SpiroRings'] == 1:
                    break
        
        # Check for bridged rings
        bridged_patt = Chem.MolFromSmarts('[D4R]')
        if bridged_patt and mol.HasSubstructMatch(bridged_patt):
            feat['BridgedRings'] = 1
            
        # Check for macrocycles
        for ring in rings:
            if len(ring) >= 8:
                feat['Macrocycles'] = 1
                break
                
        features.append(feat)
        
        # Extract functional groups
        fg = {
            'Alcohol': len(mol.GetSubstructMatches(Chem.MolFromSmarts('CO[H]'))) if Chem.MolFromSmarts('CO[H]') else 0,
            'Amine': len(mol.GetSubstructMatches(Chem.MolFromSmarts('[NX3]'))) if Chem.MolFromSmarts('[NX3]') else 0,
            'Carboxyl': len(mol.GetSubstructMatches(Chem.MolFromSmarts('C(=O)[OH]'))) if Chem.MolFromSmarts('C(=O)[OH]') else 0,
            'Carbonyl': len(mol.GetSubstructMatches(Chem.MolFromSmarts('C=O'))) if Chem.MolFromSmarts('C=O') else 0,
            'Ether': len(mol.GetSubstructMatches(Chem.MolFromSmarts('COC'))) if Chem.MolFromSmarts('COC') else 0,
            'Ester': len(mol.GetSubstructMatches(Chem.MolFromSmarts('C(=O)OC'))) if Chem.MolFromSmarts('C(=O)OC') else 0,
            'Amide': len(mol.GetSubstructMatches(Chem.MolFromSmarts('C(=O)N'))) if Chem.MolFromSmarts('C(=O)N') else 0,
            'Halogen': sum(1 for atom in mol.GetAtoms() if atom.GetAtomicNum() in [9, 17, 35, 53])
        }
        func_groups.append(fg)
        
        # Extract ring information
        rings_data = {
            'Spiro': defaultdict(int),
            'Bridged': defaultdict(int),
            'Fused': defaultdict(int),
            'Single': defaultdict(int)
        }
        
        # Process ring information
        processed = set()
        
        # First identify spiro rings
        for i in range(len(rings)):
            for j in range(i+1, len(rings)):
                if len(set(rings[i]).intersection(set(rings[j]))) == 1:
                    rings_data['Spiro'][len(rings[i])] += 1
                    rings_data['Spiro'][len(rings[j])] += 1
                    processed.add(i)
                    processed.add(j)
        
        # Identify fused rings
        for i in range(len(rings)):
            for j in range(i+1, len(rings)):
                if i in processed or j in processed:
                    continue
                    
                shared = len(set(rings[i]).intersection(set(rings[j])))
                if shared > 1:
                    rings_data['Fused'][len(rings[i])] += 1
                    rings_data['Fused'][len(rings[j])] += 1
                    processed.add(i)
                    processed.add(j)
        
        # Try to identify bridged rings
        if mol.HasSubstructMatch(Chem.MolFromSmarts('[D4R]')):
            for i in range(len(rings)):
                if i in processed:
                    continue
                    
                if len(rings[i]) >= 6:
                    rings_data['Bridged'][len(rings[i])] += 1
                    processed.add(i)
        
        # Remaining rings are single
        for i in range(len(rings)):
            if i not in processed:
                rings_data['Single'][len(rings[i])] += 1
        
        # Convert defaultdicts to regular dicts
        rings_data = {k: dict(v) for k, v in rings_data.items()}
        ring_info.append(rings_data)
    
    # Create DataFrames
    props_df = pd.DataFrame(properties, index=valid_smiles)
    features_df = pd.DataFrame(features, index=valid_smiles)
    func_groups_df = pd.DataFrame(func_groups, index=valid_smiles)
    
    # Ring info requires special handling
    ring_df = pd.DataFrame(index=valid_smiles)
    
    # Flatten ring info
    for smiles, rings_data in zip(valid_smiles, ring_info):
        for ring_type, size_dict in rings_data.items():
            for size, count in size_dict.items():
                ring_df.at[smiles, f"{ring_type}_Size{size}"] = count
    
    # Fill NaN values with 0
    ring_df = ring_df.fillna(0)
    
    # Save to CSV
    output_prefix = os.path.join(output_dir, f"{prefix}_{timestamp}")
    
    props_df.to_csv(f"{output_prefix}_properties.csv")
    features_df.to_csv(f"{output_prefix}_features.csv")
    func_groups_df.to_csv(f"{output_prefix}_functional_groups.csv")
    ring_df.to_csv(f"{output_prefix}_ring_info.csv")
    
    print(f"Analysis complete. Results saved to {output_prefix}_*.csv")
    
    return props_df, features_df, func_groups_df, ring_df

In [8]:
# Function to modify the original train_gan_cl function
def train_gan_cl_with_bias_analysis(train_loader, config, device='cuda', 
                                   save_dir='./checkpoints', 
                                   embedding_dir='./embeddings'):
    """Main training function for GAN-CL with bias analysis"""
    
    # Create directories
    os.makedirs(save_dir, exist_ok=True)
    os.makedirs(embedding_dir, exist_ok=True)
    encoder_dir = os.path.join(save_dir, 'encoders')
    os.makedirs(encoder_dir, exist_ok=True)
    
    # Create bias analyzer
    analyzer = MolecularBiasAnalyzer(output_dir='./bias_analysis')
    
    # Extract and analyze SMILES from the dataset
    smiles_list = []
    for batch in train_loader:
        for data in batch:
            if hasattr(data, 'smiles'):
                smiles_list.append(data.smiles)
    
    print(f"Found {len(smiles_list)} SMILES strings for analysis")
    
    # Analyze raw molecules (before training)
    props_df, features_df, func_groups_df, ring_df = analyzer.analyze_raw_molecules(smiles_list)
    before_prefix = analyzer.save_raw_analysis(props_df, features_df, func_groups_df, ring_df, "before_training")
    
    # Initialize model
    model = MolecularGANCL(config).to(device)
    
    # Initialize optimizers
    optimizer_encoder = torch.optim.Adam(model.encoder.parameters(), lr=3e-4)
    optimizer_generator = torch.optim.Adam(model.generator.parameters(), lr=1e-4)
    
    # Save initial model info
    model_info = {
        'node_dim': config.node_dim,
        'edge_dim': config.edge_dim,
        'hidden_dim': config.hidden_dim,
        'output_dim': config.output_dim,
        'training_config': config.__dict__
    }
    
    # Training phases as before...
    best_loss = float('inf')
    
    # Training metrics
    metrics = {
        'contrastive_losses': [],
        'adversarial_losses': [],
        'similarity_losses': [],
        'total_losses': []
    }
    
    # Training phases
    print("Phase 1: Pretraining Contrastive Learning...")
    pretrain_epochs = 10
    for epoch in range(pretrain_epochs):
        contrastive_epoch_loss = 0
        
        for batch in tqdm(train_loader, desc=f'Pretrain Epoch {epoch+1}/{pretrain_epochs}'):
            batch = batch.to(device)
            
            # Forward pass (without generator)
            query_emb = model.encoder(batch)
            with torch.no_grad():
                key_emb = model.momentum_encoder(batch)
            
            # Compute contrastive loss
            contrastive_loss = model.memory_queue.compute_contrastive_loss(
                query_emb, key_emb, model.config.temperature
            )
            
            # Update encoder
            optimizer_encoder.zero_grad()
            contrastive_loss.backward()
            optimizer_encoder.step()
            
            # Update momentum encoder
            model._momentum_update()
            
            # Update memory queue
            model.memory_queue.update_queue(key_emb.detach())
            
            contrastive_epoch_loss += contrastive_loss.item()
            
        avg_loss = contrastive_epoch_loss / len(train_loader)
        metrics['contrastive_losses'].append(avg_loss)
        print(f'Pretrain Epoch {epoch+1}, Avg Loss: {avg_loss:.4f}')
        
        # Save pretrained checkpoint
        if (epoch + 1) % 5 == 0:
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_encoder_state_dict': optimizer_encoder.state_dict(),
                'loss': avg_loss,
            }, os.path.join(save_dir, f'pretrain_checkpoint_{epoch+1}.pt'))
    
    print("\nPhase 2: Training GAN-CL...")
    train_epochs = 50
    for epoch in range(train_epochs):
        epoch_losses = {
            'contrastive': 0,
            'adversarial': 0,
            'similarity': 0,
            'total': 0
        }
        
        for batch in tqdm(train_loader, desc=f'Train Epoch {epoch+1}/{train_epochs}'):
            batch = batch.to(device)
            
            # Step 1: Train Encoder
            optimizer_encoder.zero_grad()
            
            # Get importance scores from generator
            with torch.no_grad():
                node_scores, edge_scores = model.generator(batch)
            
            # Create perturbed graph
            perturbed_data = model.drop_graph_elements(batch, node_scores, edge_scores)
            
            # Get embeddings
            query_emb = model.encoder(perturbed_data)
            with torch.no_grad():
                key_emb = model.momentum_encoder(batch)
                original_emb = model.encoder(batch).detach()
            
            # Compute losses for encoder
            contrastive_loss = model.memory_queue.compute_contrastive_loss(
                query_emb, key_emb, model.config.temperature
            )
            similarity_loss = torch.nn.functional.mse_loss(query_emb, original_emb)
            
            # Total loss for encoder
            encoder_loss = contrastive_loss + 0.1 * similarity_loss
            
            # Update encoder
            encoder_loss.backward()
            optimizer_encoder.step()
            
            # Update momentum encoder
            model._momentum_update()
            
            # Step 2: Train Generator
            optimizer_generator.zero_grad()
            
            # Get new embeddings for adversarial loss
            node_scores, edge_scores = model.generator(batch)
            perturbed_data = model.drop_graph_elements(batch, node_scores, edge_scores)
            
            with torch.no_grad():
                original_emb = model.encoder(batch)
            perturbed_emb = model.encoder(perturbed_data)
            
            # Compute adversarial loss
            adversarial_loss = -torch.nn.functional.mse_loss(perturbed_emb, original_emb)
            
            # Update generator
            adversarial_loss.backward()
            optimizer_generator.step()
            
            # Update memory queue
            model.memory_queue.update_queue(key_emb.detach())
            
            # Update metrics
            epoch_losses['contrastive'] += contrastive_loss.item()
            epoch_losses['adversarial'] += adversarial_loss.item()
            epoch_losses['similarity'] += similarity_loss.item()
            epoch_losses['total'] += encoder_loss.item()
        
        # Average losses
        for k in epoch_losses:
            epoch_losses[k] /= len(train_loader)
            metrics[f'{k}_losses'].append(epoch_losses[k])
        
        print(f'Epoch {epoch+1}, Losses: {epoch_losses}')
        
        # Save checkpoint
        if (epoch + 1) % 10 == 0:
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_encoder_state_dict': optimizer_encoder.state_dict(),
                'optimizer_generator_state_dict': optimizer_generator.state_dict(),
                'losses': epoch_losses,
            }, os.path.join(save_dir, f'gan_cl_checkpoint_{epoch+1}.pt'))
            
        # Extract and save embeddings periodically
        if (epoch + 1) % 10 == 0 or epoch == train_epochs - 1:
            model.eval()
            all_embeddings = []
            all_smiles = []
            
            with torch.no_grad():
                for batch in train_loader:
                    batch = batch.to(device)
                    embeddings = model.get_embeddings(batch)
                    all_embeddings.append(embeddings.cpu().numpy())
                    
                    # Extract SMILES for each data point
                    for data in batch:
                        if hasattr(data, 'smiles'):
                            all_smiles.append(data.smiles)
                        
            # Convert to numpy arrays
            all_embeddings = np.vstack(all_embeddings)
            
            # Save embeddings with SMILES mapping
            embedding_file = os.path.join(embedding_dir, f'embeddings_epoch_{epoch+1}.npz')
            np.savez(embedding_file, embeddings=all_embeddings, smiles=all_smiles)
            
            print(f"Saved embeddings for epoch {epoch+1} with {len(all_smiles)} SMILES mappings")
            
            model.train()
            
        # Save encoder periodically
        if (epoch + 1) % 10 == 0:
            epoch_info = {
                **model_info,
                'epoch': epoch + 1,
                'loss': epoch_losses['total']
            }
            save_encoder(
                model.encoder,
                os.path.join(encoder_dir, f'encoder_epoch_{epoch+1}.pt'),
                epoch_info
            )
        
        # Save best encoder based on total loss
        if epoch_losses['total'] < best_loss:
            best_loss = epoch_losses['total']
            save_encoder(
                model.encoder,
                os.path.join(encoder_dir, 'best_encoder.pt'),
                {**model_info, 'epoch': epoch + 1, 'loss': best_loss}
            )
    
    # Save final encoder
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    save_encoder(
        model.encoder,
        os.path.join(encoder_dir, f'final_encoder_{timestamp}.pt'),
        {**model_info, 'epoch': train_epochs, 'loss': epoch_losses['total']}
    )
    
    # Extract and analyze final embeddings
    model.eval()
    final_embeddings = []
    final_smiles = []
    
    with torch.no_grad():
        for batch in tqdm(train_loader, desc="Extracting final embeddings"):
            batch = batch.to(device)
            embeddings = model.get_embeddings(batch)
            final_embeddings.append(embeddings.cpu().numpy())
            
            # Extract SMILES for each data point
            for data in batch:
                if hasattr(data, 'smiles'):
                    final_smiles.append(data.smiles)
                    
    # Convert to numpy arrays
    final_embeddings = np.vstack(final_embeddings)
    
    # Save final embeddings with SMILES mapping
    final_embedding_file = os.path.join(embedding_dir, f'final_embeddings_{timestamp}.npz')
    np.savez(final_embedding_file, embeddings=final_embeddings, smiles=final_smiles)
    
    print(f"Saved final embeddings with {len(final_smiles)} SMILES mappings")
    
    # Save final metrics
    with open(os.path.join(save_dir, 'training_metrics.json'), 'w') as f:
        json.dump(metrics, f)
    
    return model, metrics, final_embeddings, final_smiles


def save_embeddings_with_properties(embeddings, smiles_list, analyzer, output_dir, prefix="after_training"):
    """Save embeddings with molecular properties for further analysis
    
    Args:
        embeddings: Numpy array of embeddings
        smiles_list: List of SMILES strings corresponding to embeddings
        analyzer: MolecularBiasAnalyzer instance with property data
        output_dir: Directory to save results
        prefix: Prefix for output filenames
    """
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    
    # Extract properties for each SMILES
    properties_list = []
    features_list = []
    func_groups_list = []
    ring_info_list = []
    valid_indices = []
    valid_smiles = []
    
    for i, smiles in enumerate(smiles_list):
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            continue
            
        valid_indices.append(i)
        valid_smiles.append(smiles)
        
        # Extract properties
        properties_list.append(analyzer.extract_properties(mol))
        features_list.append(analyzer.extract_features(mol))
        func_groups_list.append(analyzer.extract_functional_groups(mol))
        ring_info_list.append(analyzer.extract_ring_info(mol))
    
    # Filter embeddings to match valid SMILES
    valid_embeddings = embeddings[valid_indices]
    
    # Create DataFrames
    props_df = pd.DataFrame(properties_list, index=valid_smiles)
    features_df = pd.DataFrame(features_list, index=valid_smiles)
    func_groups_df = pd.DataFrame(func_groups_list, index=valid_smiles)
    
    # Ring info requires special handling due to nested structure
    ring_df = pd.DataFrame(index=valid_smiles)
    
    # Flatten the ring info for DataFrame representation
    for smiles, ring_data in zip(valid_smiles, ring_info_list):
        for ring_type, size_dict in ring_data.items():
            for size, count in size_dict.items():
                ring_df.at[smiles, f"{ring_type}_Size{size}"] = count
    
    # Fill NaN values with 0
    ring_df = ring_df.fillna(0)
    
    # Save all data
    output_prefix = os.path.join(output_dir, f"{prefix}_{timestamp}")
    
    # Save DataFrames
    props_df.to_csv(f"{output_prefix}_properties.csv")
    features_df.to_csv(f"{output_prefix}_features.csv")
    func_groups_df.to_csv(f"{output_prefix}_functional_groups.csv")
    ring_df.to_csv(f"{output_prefix}_ring_info.csv")
    
    # Save embeddings with SMILES mapping
    np.savez(f"{output_prefix}_embeddings.npz", 
             embeddings=valid_embeddings, 
             smiles=valid_smiles)
    
    print(f"Saved {prefix} analysis to {output_prefix}_*.csv/npz")
    
    return output_prefix

In [9]:
def save_embeddings(embeddings, labels, filepath):
    """Save embeddings and corresponding labels"""
    with open(filepath, 'wb') as f:
        pickle.dump({
            'embeddings': embeddings,
            'labels': labels
        }, f)

def save_encoder(encoder, save_path, info=None):
    """Save encoder model for downstream tasks"""
    save_dict = {
        'encoder_state_dict': encoder.state_dict(),
        'model_info': info or {}
    }
    torch.save(save_dict, save_path)

def load_encoder(model_path, device='cpu'):
    """Load saved encoder model"""
    checkpoint = torch.load(model_path, map_location=device)
    encoder = GraphDiscriminator(
        node_dim=checkpoint['model_info'].get('node_dim'),
        edge_dim=checkpoint['model_info'].get('edge_dim'),
        hidden_dim=checkpoint['model_info'].get('hidden_dim', 128),
        output_dim=checkpoint['model_info'].get('output_dim', 128)
    )
    encoder.load_state_dict(checkpoint['encoder_state_dict'])
    return encoder        
        
def train_gan_cl(train_loader, config, device='cuda', 
                save_dir='./checkpoints', 
                embedding_dir='./embeddings'):
    """Main training function for GAN-CL with fixed gradient computation"""
    
    # Create directories
    os.makedirs(save_dir, exist_ok=True)
    os.makedirs(embedding_dir, exist_ok=True)
    encoder_dir = os.path.join(save_dir, 'encoders')
    os.makedirs(encoder_dir, exist_ok=True)    
    
    # Initialize model
    model = MolecularGANCL(config).to(device)
    
    # Initialize optimizers
    optimizer_encoder = torch.optim.Adam(model.encoder.parameters(), lr=3e-4)
    optimizer_generator = torch.optim.Adam(model.generator.parameters(), lr=1e-4)
    
    # Save initial model info
    model_info = {
        'node_dim': config.node_dim,
        'edge_dim': config.edge_dim,
        'hidden_dim': config.hidden_dim,
        'output_dim': config.output_dim,
        'training_config': config.__dict__
    }
    
    # Training phases as before...
    best_loss = float('inf')
    
    # Training metrics
    metrics = {
        'contrastive_losses': [],
        'adversarial_losses': [],
        'similarity_losses': [],
        'total_losses': []
    }
    
    # Training phases
    print("Phase 1: Pretraining Contrastive Learning...")
    pretrain_epochs = 10
    for epoch in range(pretrain_epochs):
        contrastive_epoch_loss = 0
        
        for batch in tqdm(train_loader, desc=f'Pretrain Epoch {epoch+1}/{pretrain_epochs}'):
            batch = batch.to(device)
            
            # Forward pass (without generator)
            query_emb = model.encoder(batch)
            with torch.no_grad():
                key_emb = model.momentum_encoder(batch)
            
            # Compute contrastive loss
            contrastive_loss = model.memory_queue.compute_contrastive_loss(
                query_emb, key_emb, model.config.temperature
            )
            
            # Update encoder
            optimizer_encoder.zero_grad()
            contrastive_loss.backward()
            optimizer_encoder.step()
            
            # Update momentum encoder
            model._momentum_update()
            
            # Update memory queue
            model.memory_queue.update_queue(key_emb.detach())
            
            contrastive_epoch_loss += contrastive_loss.item()
            
        avg_loss = contrastive_epoch_loss / len(train_loader)
        metrics['contrastive_losses'].append(avg_loss)
        print(f'Pretrain Epoch {epoch+1}, Avg Loss: {avg_loss:.4f}')
        
        # Save pretrained checkpoint
        if (epoch + 1) % 5 == 0:
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_encoder_state_dict': optimizer_encoder.state_dict(),
                'loss': avg_loss,
            }, os.path.join(save_dir, f'pretrain_checkpoint_{epoch+1}.pt'))
    
    print("\nPhase 2: Training GAN-CL...")
    train_epochs = 50
#     train_epochs = 10
    for epoch in range(train_epochs):
        epoch_losses = {
            'contrastive': 0,
            'adversarial': 0,
            'similarity': 0,
            'total': 0
        }
        
        for batch in tqdm(train_loader, desc=f'Train Epoch {epoch+1}/{train_epochs}'):
            batch = batch.to(device)
            
            # Step 1: Train Encoder
            optimizer_encoder.zero_grad()
            
            # Get importance scores from generator
            with torch.no_grad():
                node_scores, edge_scores = model.generator(batch)
            
            # Create perturbed graph
            perturbed_data = model.drop_graph_elements(batch, node_scores, edge_scores)
            
            # Get embeddings
            query_emb = model.encoder(perturbed_data)
            with torch.no_grad():
                key_emb = model.momentum_encoder(batch)
                original_emb = model.encoder(batch).detach()
            
            # Compute losses for encoder
            contrastive_loss = model.memory_queue.compute_contrastive_loss(
                query_emb, key_emb, model.config.temperature
            )
            similarity_loss = F.mse_loss(query_emb, original_emb)
            
            # Total loss for encoder
            encoder_loss = contrastive_loss + 0.1 * similarity_loss
            
            # Update encoder
            encoder_loss.backward()
            optimizer_encoder.step()
            
            # Update momentum encoder
            model._momentum_update()
            
            # Step 2: Train Generator
            optimizer_generator.zero_grad()
            
            # Get new embeddings for adversarial loss
            node_scores, edge_scores = model.generator(batch)
            perturbed_data = model.drop_graph_elements(batch, node_scores, edge_scores)
            
            with torch.no_grad():
                original_emb = model.encoder(batch)
            perturbed_emb = model.encoder(perturbed_data)
            
            # Compute adversarial loss
            adversarial_loss = -F.mse_loss(perturbed_emb, original_emb)
            
            # Update generator
            adversarial_loss.backward()
            optimizer_generator.step()
            
            # Update memory queue
            model.memory_queue.update_queue(key_emb.detach())
            
            # Update metrics
            epoch_losses['contrastive'] += contrastive_loss.item()
            epoch_losses['adversarial'] += adversarial_loss.item()
            epoch_losses['similarity'] += similarity_loss.item()
            epoch_losses['total'] += encoder_loss.item()
        
        # Average losses
        for k in epoch_losses:
            epoch_losses[k] /= len(train_loader)
            metrics[f'{k}_losses'].append(epoch_losses[k])
        
        print(f'Epoch {epoch+1}, Losses: {epoch_losses}')
        
        # Save checkpoint
        if (epoch + 1) % 10 == 0:
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_encoder_state_dict': optimizer_encoder.state_dict(),
                'optimizer_generator_state_dict': optimizer_generator.state_dict(),
                'losses': epoch_losses,
            }, os.path.join(save_dir, f'gan_cl_checkpoint_{epoch+1}.pt'))
            
        # Extract and save embeddings periodically
        if (epoch + 1) % 10 == 0:
            model.eval()
            all_embeddings = []
            all_graphs = []
            
            with torch.no_grad():
                for batch in train_loader:
                    batch = batch.to(device)
                    embeddings = model.get_embeddings(batch)
                    all_embeddings.append(embeddings.cpu())
                    all_graphs.extend([data for data in batch])
            
            all_embeddings = torch.cat(all_embeddings, dim=0)
            
            # Save embeddings
#             timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
            save_embeddings(
                all_embeddings.numpy(),
                all_graphs,
                os.path.join(embedding_dir, f'embeddings_epoch_{epoch+1}_{timestamp}.pkl')
            )
            
            model.train()
    
    # Save final metrics
    with open(os.path.join(save_dir, 'training_metrics.json'), 'w') as f:
        json.dump(metrics, f)
    
            # Save encoder periodically
        if (epoch + 1) % 10 == 0:
            epoch_info = {
                **model_info,
                'epoch': epoch + 1,
                'loss': epoch_losses['total']
            }
            save_encoder(
                model.encoder,
                os.path.join(encoder_dir, f'encoder_epoch_{epoch+1}.pt'),
                epoch_info
            )
        
        # Save best encoder based on total loss
        if epoch_losses['total'] < best_loss:
            best_loss = epoch_losses['total']
            save_encoder(
                model.encoder,
                os.path.join(encoder_dir, f'best_encoder_{timestamp}.pt'),
                {**model_info, 'epoch': epoch + 1, 'loss': best_loss}
            )
    
    # Save final encoder
    save_encoder(
        model.encoder,
        os.path.join(encoder_dir, f'final_encoder_{timestamp}.pt'),
        {**model_info, 'epoch': train_epochs, 'loss': epoch_losses['total']}
    )
    
    return model, metrics


In [10]:
import torch
import numpy as np
from torch_geometric.data import DataLoader
import os
import json
import pickle
from tqdm import tqdm
from datetime import datetime
from rdkit import Chem
import pandas as pd


def main():
    """Main function with integrated bias analysis"""
    # Enable anomaly detection during development
    torch.autograd.set_detect_anomaly(True)
    # Your existing data loading code here
    torch.manual_seed(42)
    np.random.seed(42)
    
    print("Starting data loading...")
    extractor = MolecularFeatureExtractor()
    smiles_file = "D:\\PhD\\Chapter3\\Unsupervised_GAN_Code\\pubchem-41-clean.txt"
    
    dataset = []
    failed_smiles = []
    
    with open(smiles_file, 'r') as f:
        for i, line in enumerate(f):
            smiles = line.strip()
            data = extractor.process_molecule(smiles)
            if data is not None:
                # Store original SMILES in the data object
                data.smiles = smiles
                dataset.append(data)
            else:
                failed_smiles.append(smiles)
            
            # Limit dataset size for testing
            if i >= 10000:  # Adjust as needed
                break
    
    print(f"1. Loaded dataset with {len(dataset)} graphs.")
    print(f"2. Failed SMILES count: {len(failed_smiles)}")
    
    if not dataset:
        print("No valid graphs generated.")
        return None
        
    # Setup training
    batch_size = 32
    train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    print(f"3. Created DataLoader with {len(train_loader.dataset)} graphs")
    
    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"4. Using device: {device}")
    
    # Get configuration based on dataset
    config = get_model_config(dataset)
    
    # Train model with bias analysis
    print("5. Starting GAN-CL training with bias analysis...")
    model, metrics, final_embeddings, final_smiles = modified_train_gan_cl(
        train_loader, 
        config,
        device=device,
        save_dir='./checkpoints',
        embedding_dir='./embeddings'
    )
    
    print("6. Training completed!")
    
    # At this point, we've already saved and analyzed the embeddings during training
    # But let's do one final sanity check to make sure the analysis was completed
    
    # Check if after_training_*.csv files exist
    analysis_dir = './analysis'
    found_analysis = False
    
    for filename in os.listdir(analysis_dir):
        if filename.startswith('after_training_') and filename.endswith('.csv'):
            found_analysis = True
            break
    
    if not found_analysis:
        print("7. Warning: After-training analysis files not found!")
        print("   Running final analysis on embeddings...")
        
        # Instead of importing, use the analyze_smiles_list function directly
        # This function should be defined in your script or properly imported
        
        # Create a mapping from embeddings to SMILES
        valid_smiles = []
        for smiles in final_smiles:
            if Chem.MolFromSmiles(smiles) is not None:
                valid_smiles.append(smiles)
        
        # Use analyze_smiles_list function (this should be defined elsewhere in your code)
        props_df, features_df, func_groups_df, ring_df = analyze_smiles_list(
            valid_smiles, output_dir=analysis_dir, prefix="after_training")
        
        print(f"8. Final analysis completed with {len(valid_smiles)} valid molecules")
    else:
        print("7. After-training analysis found!")
        
    print("8. Analysis complete. The data can now be used for visualization and comparison.")
    
    return model, metrics, final_embeddings, final_smiles


if __name__ == "__main__":
    model, metrics, embeddings, smiles = main()

Starting data loading...
1. Loaded dataset with 41 graphs.
2. Failed SMILES count: 0
3. Created DataLoader with 41 graphs
4. Using device: cpu
5. Starting GAN-CL training with bias analysis...
Found 0 molecules for 'before training' analysis
Saved 'before training' SMILES to ./embeddings\before_training_smiles_20250301_125314.txt




Extracting initial embeddings...


Initial embeddings: 100%|████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 80.60it/s]


Added 8 SMILES-embedding pairs. Total now: 8
Added 8 SMILES-embedding pairs. Total now: 16
Saved 16 embeddings to ./embeddings\before_training_20250301_125314.npz
Phase 1: Pretraining Contrastive Learning...


Pretrain Epoch 1/10: 100%|███████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  6.77it/s]


Pretrain Epoch 1, Avg Loss: 1.7559


Pretrain Epoch 2/10: 100%|███████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  7.45it/s]


Pretrain Epoch 2, Avg Loss: 3.8909


Pretrain Epoch 3/10: 100%|███████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  4.21it/s]


Pretrain Epoch 3, Avg Loss: 4.3830


Pretrain Epoch 4/10: 100%|███████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  4.82it/s]


Pretrain Epoch 4, Avg Loss: 4.6739


Pretrain Epoch 5/10: 100%|███████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  5.01it/s]


Pretrain Epoch 5, Avg Loss: 4.9190


Pretrain Epoch 6/10: 100%|███████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  5.63it/s]


Pretrain Epoch 6, Avg Loss: 4.9916


Pretrain Epoch 7/10: 100%|███████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  5.89it/s]


Pretrain Epoch 7, Avg Loss: 5.1455


Pretrain Epoch 8/10: 100%|███████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  5.63it/s]


Pretrain Epoch 8, Avg Loss: 5.3087


Pretrain Epoch 9/10: 100%|███████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  2.94it/s]


Pretrain Epoch 9, Avg Loss: 5.3186


Pretrain Epoch 10/10: 100%|██████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  4.91it/s]


Pretrain Epoch 10, Avg Loss: 5.4348

Phase 2: Training GAN-CL...


Train Epoch 1/50: 100%|██████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  2.63it/s]


Epoch 1, Losses: {'contrastive': 5.613492250442505, 'adversarial': -0.0003769537579501048, 'similarity': 0.0002952148352051154, 'total': 5.613521575927734}


Train Epoch 2/50: 100%|██████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  2.61it/s]


Epoch 2, Losses: {'contrastive': 5.631539344787598, 'adversarial': -0.0003998180909547955, 'similarity': 0.0002768647391349077, 'total': 5.631567001342773}


Train Epoch 3/50: 100%|██████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.01it/s]


Epoch 3, Losses: {'contrastive': 5.779375076293945, 'adversarial': -0.00025428216758882627, 'similarity': 0.000361294747563079, 'total': 5.77941107749939}


Train Epoch 4/50: 100%|██████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  2.04it/s]


Epoch 4, Losses: {'contrastive': 5.815361976623535, 'adversarial': -0.00029112392803654075, 'similarity': 0.00028955365996807814, 'total': 5.8153910636901855}


Train Epoch 5/50: 100%|██████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.11it/s]


Epoch 5, Losses: {'contrastive': 5.826724290847778, 'adversarial': -0.00026467108546057716, 'similarity': 0.0002450275205774233, 'total': 5.826748847961426}


Train Epoch 6/50: 100%|██████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.13it/s]


Epoch 6, Losses: {'contrastive': 5.831117630004883, 'adversarial': -0.0002166351696359925, 'similarity': 0.00024627868697280064, 'total': 5.83114218711853}


Train Epoch 7/50: 100%|██████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  4.22it/s]


Epoch 7, Losses: {'contrastive': 5.909837961196899, 'adversarial': -0.0002601938249426894, 'similarity': 0.0002730983978835866, 'total': 5.909865379333496}


Train Epoch 8/50: 100%|██████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.70it/s]


Epoch 8, Losses: {'contrastive': 5.97320556640625, 'adversarial': -0.00037819512363057584, 'similarity': 0.0003055665365536697, 'total': 5.973236083984375}


Train Epoch 9/50: 100%|██████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  4.11it/s]


Epoch 9, Losses: {'contrastive': 5.921234846115112, 'adversarial': -0.00029328963137231767, 'similarity': 0.0002580540021881461, 'total': 5.921260595321655}


Train Epoch 10/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.91it/s]


Epoch 10, Losses: {'contrastive': 6.1182098388671875, 'adversarial': -0.00022079909103922546, 'similarity': 0.0003254311886848882, 'total': 6.118242263793945}


Train Epoch 11/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  4.23it/s]


Epoch 11, Losses: {'contrastive': 6.118496656417847, 'adversarial': -0.00025310095224995166, 'similarity': 0.0002672597474884242, 'total': 6.118523597717285}


Train Epoch 12/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.80it/s]


Epoch 12, Losses: {'contrastive': 6.098068952560425, 'adversarial': -0.0003603752702474594, 'similarity': 0.00024339219089597464, 'total': 6.098093271255493}


Train Epoch 13/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  4.01it/s]


Epoch 13, Losses: {'contrastive': 6.130922079086304, 'adversarial': -0.00024502797896275297, 'similarity': 0.0002553040030761622, 'total': 6.130947589874268}


Train Epoch 14/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  4.15it/s]


Epoch 14, Losses: {'contrastive': 6.2653303146362305, 'adversarial': -0.00019840085587929934, 'similarity': 0.0004251596692483872, 'total': 6.2653727531433105}


Train Epoch 15/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.94it/s]


Epoch 15, Losses: {'contrastive': 6.124017715454102, 'adversarial': -0.00034639908699318767, 'similarity': 0.0003271595051046461, 'total': 6.1240503787994385}


Train Epoch 16/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  4.07it/s]


Epoch 16, Losses: {'contrastive': 6.2178051471710205, 'adversarial': -0.00029189563065301627, 'similarity': 0.00031289718754123896, 'total': 6.217836380004883}


Train Epoch 17/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  2.93it/s]


Epoch 17, Losses: {'contrastive': 6.227486848831177, 'adversarial': -0.00025084860681090504, 'similarity': 0.00021638851467287168, 'total': 6.227508306503296}


Train Epoch 18/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  2.36it/s]


Epoch 18, Losses: {'contrastive': 6.391809940338135, 'adversarial': -0.00021781565010314807, 'similarity': 0.0004804972995771095, 'total': 6.391858100891113}


Train Epoch 19/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.33it/s]


Epoch 19, Losses: {'contrastive': 6.235967636108398, 'adversarial': -0.00036357255885377526, 'similarity': 0.0003705114941112697, 'total': 6.236004829406738}


Train Epoch 20/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.43it/s]


Epoch 20, Losses: {'contrastive': 6.293101787567139, 'adversarial': -0.00033551343949511647, 'similarity': 0.00035090722667519003, 'total': 6.293137073516846}


Train Epoch 21/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.55it/s]


Epoch 21, Losses: {'contrastive': 6.286855936050415, 'adversarial': -0.0002657333607203327, 'similarity': 0.0002863065165001899, 'total': 6.286884784698486}


Train Epoch 22/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.29it/s]


Epoch 22, Losses: {'contrastive': 6.349819898605347, 'adversarial': -0.0004775566776515916, 'similarity': 0.0003069901140406728, 'total': 6.349850654602051}


Train Epoch 23/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.25it/s]


Epoch 23, Losses: {'contrastive': 6.398885726928711, 'adversarial': -0.00032853049924597144, 'similarity': 0.0003799564437940717, 'total': 6.398923635482788}


Train Epoch 24/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.55it/s]


Epoch 24, Losses: {'contrastive': 6.432774543762207, 'adversarial': -0.00025558078777976334, 'similarity': 0.0002489579128450714, 'total': 6.432799339294434}


Train Epoch 25/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.87it/s]


Epoch 25, Losses: {'contrastive': 6.300262928009033, 'adversarial': -0.0003665765543701127, 'similarity': 0.0002543073205742985, 'total': 6.300288438796997}


Train Epoch 26/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.33it/s]


Epoch 26, Losses: {'contrastive': 6.5520758628845215, 'adversarial': -0.00030530168442055583, 'similarity': 0.0005046597216278315, 'total': 6.552126407623291}


Train Epoch 27/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.52it/s]


Epoch 27, Losses: {'contrastive': 6.427116394042969, 'adversarial': -0.0004642554195015691, 'similarity': 0.00019306931790197268, 'total': 6.427135705947876}


Train Epoch 28/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.68it/s]


Epoch 28, Losses: {'contrastive': 6.414320230484009, 'adversarial': -0.00037512776907533407, 'similarity': 0.00026436924235895276, 'total': 6.414346694946289}


Train Epoch 29/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.77it/s]


Epoch 29, Losses: {'contrastive': 6.5350518226623535, 'adversarial': -0.00039180500607471913, 'similarity': 0.0004203114513074979, 'total': 6.5350940227508545}


Train Epoch 30/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.73it/s]


Epoch 30, Losses: {'contrastive': 6.466502666473389, 'adversarial': -0.00041668719495646656, 'similarity': 0.0002736851602094248, 'total': 6.466530084609985}


Train Epoch 31/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.08it/s]


Epoch 31, Losses: {'contrastive': 6.5103089809417725, 'adversarial': -0.00042712807771749794, 'similarity': 0.00031958760519046336, 'total': 6.510340929031372}


Train Epoch 32/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  2.87it/s]


Epoch 32, Losses: {'contrastive': 6.581608057022095, 'adversarial': -0.0003468715585768223, 'similarity': 0.00029273383552208543, 'total': 6.581637382507324}


Train Epoch 33/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  4.29it/s]


Epoch 33, Losses: {'contrastive': 6.565536022186279, 'adversarial': -0.00033151390380226076, 'similarity': 0.00025446587096666917, 'total': 6.565561532974243}


Train Epoch 34/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  4.23it/s]


Epoch 34, Losses: {'contrastive': 6.587994813919067, 'adversarial': -0.0003595510861487128, 'similarity': 0.00031529198167845607, 'total': 6.588026285171509}


Train Epoch 35/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.95it/s]


Epoch 35, Losses: {'contrastive': 6.626386880874634, 'adversarial': -0.0002857803483493626, 'similarity': 0.0003517637378536165, 'total': 6.626421928405762}


Train Epoch 36/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.85it/s]


Epoch 36, Losses: {'contrastive': 6.597581624984741, 'adversarial': -0.00043703374103643, 'similarity': 0.00037367276672739536, 'total': 6.59761905670166}


Train Epoch 37/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.29it/s]


Epoch 37, Losses: {'contrastive': 6.6728434562683105, 'adversarial': -0.00043058054870925844, 'similarity': 0.00034692489134613425, 'total': 6.672878265380859}


Train Epoch 38/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  4.40it/s]


Epoch 38, Losses: {'contrastive': 6.734581470489502, 'adversarial': -0.00025015894789248705, 'similarity': 0.0004612239863490686, 'total': 6.7346274852752686}


Train Epoch 39/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  4.28it/s]


Epoch 39, Losses: {'contrastive': 6.659841537475586, 'adversarial': -0.00036192568950355053, 'similarity': 0.00043350951455067843, 'total': 6.659884929656982}


Train Epoch 40/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  4.20it/s]


Epoch 40, Losses: {'contrastive': 6.779620409011841, 'adversarial': -0.00033001833799062297, 'similarity': 0.00037916119617875665, 'total': 6.779658317565918}


Train Epoch 41/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.36it/s]


Epoch 41, Losses: {'contrastive': 6.7663490772247314, 'adversarial': -0.00030055329261813313, 'similarity': 0.00032083211408462375, 'total': 6.76638126373291}


Train Epoch 42/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.94it/s]


Epoch 42, Losses: {'contrastive': 6.7481818199157715, 'adversarial': -0.0003252384049119428, 'similarity': 0.00043949067185167223, 'total': 6.748225927352905}


Train Epoch 43/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  4.18it/s]


Epoch 43, Losses: {'contrastive': 6.762255668640137, 'adversarial': -0.0003380371490493417, 'similarity': 0.0003589333064155653, 'total': 6.762291669845581}


Train Epoch 44/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.91it/s]


Epoch 44, Losses: {'contrastive': 6.6950461864471436, 'adversarial': -0.00032445545366499573, 'similarity': 0.0002454178174957633, 'total': 6.695070743560791}


Train Epoch 45/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.10it/s]


Epoch 45, Losses: {'contrastive': 6.740059852600098, 'adversarial': -0.00034892596886493266, 'similarity': 0.0002761537762125954, 'total': 6.740087509155273}


Train Epoch 46/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.17it/s]


Epoch 46, Losses: {'contrastive': 6.79088830947876, 'adversarial': -0.00028544574888655916, 'similarity': 0.00029969183378852904, 'total': 6.7909181118011475}


Train Epoch 47/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.44it/s]


Epoch 47, Losses: {'contrastive': 6.67271876335144, 'adversarial': -0.00028900242614327, 'similarity': 0.0003392087819520384, 'total': 6.672752618789673}


Train Epoch 48/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  4.04it/s]


Epoch 48, Losses: {'contrastive': 6.8084094524383545, 'adversarial': -0.00033863364660646766, 'similarity': 0.00043579822522588074, 'total': 6.80845308303833}


Train Epoch 49/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  4.03it/s]


Epoch 49, Losses: {'contrastive': 6.768752574920654, 'adversarial': -0.0003772854106500745, 'similarity': 0.0004153896734351292, 'total': 6.768794059753418}


Train Epoch 50/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  4.19it/s]


Epoch 50, Losses: {'contrastive': 6.825029611587524, 'adversarial': -0.00033936907129827887, 'similarity': 0.0004176431684754789, 'total': 6.825071334838867}
Extracting final embeddings...


Extracting final embeddings: 100%|██████████████████████████████████████████████████████| 2/2 [00:00<00:00, 153.14it/s]

Added 8 SMILES-embedding pairs. Total now: 8
Added 8 SMILES-embedding pairs. Total now: 16
Saved 16 embeddings to ./embeddings\after_training_20250301_125314.npz
Loaded 16 final embeddings from ./embeddings\after_training_20250301_125314.npz
6. Training completed!
7. After-training analysis found!
8. Analysis complete. The data can now be used for visualization and comparison.



