In [1]:
import torch
from rdkit import Chem
from rdkit.Chem import AllChem, Descriptors
from typing import Tuple
import numpy as np
from torch_geometric.data import Data
import random
from torch_geometric.loader import DataLoader
from torch_geometric.data import Data
from rdkit import RDLogger
from rdkit.Chem import RemoveHs
from datetime import datetime

import torch
import numpy as np
from torch_geometric.data import DataLoader
import os
import json
from tqdm import tqdm
import pickle
from datetime import datetime

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool, MessagePassing
from typing import Tuple, List, Optional
import copy
from dataclasses import dataclass

# Suppress RDKit warnings
RDLogger.DisableLog('rdApp.warning')
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')



In [2]:
class MolecularFeatureExtractor:
    def __init__(self):
        self.atom_list = list(range(1, 119))
        self.chirality_list = [
            Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
            Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
            Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
            Chem.rdchem.ChiralType.CHI_OTHER
        ]
        self.bond_list = [
            Chem.rdchem.BondType.SINGLE,
            Chem.rdchem.BondType.DOUBLE, 
            Chem.rdchem.BondType.TRIPLE,
            Chem.rdchem.BondType.AROMATIC
        ]
        self.bonddir_list = [
            Chem.rdchem.BondDir.NONE,
            Chem.rdchem.BondDir.ENDUPRIGHT,
            Chem.rdchem.BondDir.ENDDOWNRIGHT
        ]

    def calc_atom_features(self, atom: Chem.Atom) -> Tuple[list, list]:
        """Calculate atom features with better error handling"""
        try:
            # Basic features
            atom_feat = [
                self.atom_list.index(atom.GetAtomicNum()),
                self.chirality_list.index(atom.GetChiralTag())
            ]

            # Physical features with error handling
            phys_feat = []
            
            # Molecular weight contribution
            try:
                contrib_mw = Descriptors.ExactMolWt(Chem.MolFromSmiles(f'[{atom.GetSymbol()}]'))
                phys_feat.append(contrib_mw)
            except:
                phys_feat.append(0.0)
                
            # LogP contribution    
            try:
                contrib_logp = Descriptors.MolLogP(Chem.MolFromSmiles(f'[{atom.GetSymbol()}]'))
                phys_feat.append(contrib_logp)
            except:
                phys_feat.append(0.0)
                
            # Add other physical properties
            phys_feat.extend([
                atom.GetFormalCharge(),
                int(atom.GetHybridization()),
                int(atom.GetIsAromatic()),
                atom.GetTotalNumHs(),
                atom.GetTotalValence(),
                atom.GetDegree()
            ])
            
            return atom_feat, phys_feat
            
        except Exception as e:
            print(f"Error calculating atom features: {e}")
            return [0, 0], [0.0] * 9

    def get_atom_features(self, mol: Chem.Mol) -> Tuple[torch.Tensor, torch.Tensor]:
        """Extract atom features for the whole molecule"""
        atom_feats = []
        phys_feats = []
        
        if mol is None:
            return torch.tensor([[0, 0]], dtype=torch.long), torch.tensor([[0.0] * 9], dtype=torch.float)
            
        for atom in mol.GetAtoms():
            atom_feat, phys_feat = self.calc_atom_features(atom)
            atom_feats.append(atom_feat)
            phys_feats.append(phys_feat)

        x = torch.tensor(atom_feats, dtype=torch.long)
        phys = torch.tensor(phys_feats, dtype=torch.float)
        
        return x, phys
    
    def remove_unbonded_hydrogens(mol):
        params = Chem.RemoveHsParameters()
        params.removeDegreeZero = True
        mol = Chem.RemoveHs(mol, params)
        return mol


    def get_bond_features(self, mol: Chem.Mol) -> Tuple[torch.Tensor, torch.Tensor]:
        """Extract bond features with better error handling"""
        if mol is None:
            return torch.tensor([[0], [0]], dtype=torch.long), torch.tensor([[0.0] * 5], dtype=torch.float)
            
        row, col, edge_feat = [], [], []
        
        for bond in mol.GetBonds():
            try:
                start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
                
                # Add edges in both directions
                row += [start, end]
                col += [end, start]
                
                # Bond features
                bond_type = self.bond_list.index(bond.GetBondType())
                bond_dir = self.bonddir_list.index(bond.GetBondDir())
                
                # Calculate additional properties
                feat = [
                    bond_type,
                    bond_dir,
                    int(bond.GetIsConjugated()),
                    int(self._is_rotatable(bond)),
                    self._get_bond_length(mol, start, end)
                ]
                
                edge_feat.extend([feat, feat])
                
            except Exception as e:
                print(f"Error processing bond: {e}")
                continue

        if not row:  # If no valid bonds were processed
            return torch.tensor([[0], [0]], dtype=torch.long), torch.tensor([[0.0] * 5], dtype=torch.float)

        edge_index = torch.tensor([row, col], dtype=torch.long)
        edge_attr = torch.tensor(edge_feat, dtype=torch.float)
        
        return edge_index, edge_attr

    def _is_rotatable(self, bond: Chem.Bond) -> bool:
        """Check if bond is rotatable"""
        return (bond.GetBondType() == Chem.rdchem.BondType.SINGLE and 
                not bond.IsInRing() and
                len(bond.GetBeginAtom().GetNeighbors()) > 1 and
                len(bond.GetEndAtom().GetNeighbors()) > 1)

    def _get_bond_length(self, mol: Chem.Mol, start: int, end: int) -> float:
        """Get bond length with error handling"""
        try:
            conf = mol.GetConformer()
            if conf.Is3D():
                return Chem.rdMolTransforms.GetBondLength(conf, start, end)
        except:
            pass
        return 0.0

    def process_molecule(self, smiles: str) -> Data:
        """Process SMILES string to graph data"""
        try:
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                print(f"Invalid SMILES: {smiles}")
                return None  # Skip invalid molecules
            mol = RemoveHs(mol)

            # Add explicit hydrogens
            mol = Chem.AddHs(mol, addCoords=True)

            # Sanitize molecule
            Chem.SanitizeMol(mol)

            # Check if the molecule has atoms
            if mol.GetNumAtoms() == 0:
                print("Molecule has no atoms, skipping.")
                return None

            # Generate 3D coordinates
            if not mol.GetNumConformers():
                status = AllChem.EmbedMolecule(mol, AllChem.ETKDG())
                if status != 0:
                    print("Failed to generate 3D conformer")
                    return None  # Skip failed molecules

                # Try MMFF or UFF optimization
                try:
                    AllChem.MMFFOptimizeMolecule(mol)
                except:
                    AllChem.UFFOptimizeMolecule(mol)

            # Extract features
            x_cat, x_phys = self.get_atom_features(mol)
            edge_index, edge_attr = self.get_bond_features(mol)

            # Create data object with SMILES
            data = Data(
                x_cat=x_cat, 
                x_phys=x_phys,
                edge_index=edge_index, 
                edge_attr=edge_attr,
                num_nodes=x_cat.size(0)
            )

            # Store the SMILES as an attribute
            data.smiles = smiles

            return data

        except Exception as e:
            print(f"Error processing molecule {smiles}: {e}")
            return None

In [3]:
class MemoryQueue:
    """Memory queue with temporal decay for contrastive learning"""
    def __init__(self, size: int, dim: int, decay: float = 0.99999):
        self.size = size
        self.dim = dim
        self.decay = decay
        self.ptr = 0
        self.full = False
        
        # Initialize queue
        self.queue = nn.Parameter(F.normalize(torch.randn(size, dim), dim=1), requires_grad=False)
        self.queue_age = nn.Parameter(torch.zeros(size), requires_grad=False)
        
#         self.register_buffer("queue", torch.randn(size, dim))
#         self.register_buffer("queue_age", torch.zeros(size))  # Track age of each entry
        self.queue = F.normalize(self.queue, dim=1)
        
    def update_queue(self, keys: torch.Tensor):
        """Update queue with new keys"""
        batch_size = keys.shape[0]
        
        # Increment age of all entries
        self.queue_age += 1
        
        # Add new keys
        if self.ptr + batch_size <= self.size:
            self.queue[self.ptr:self.ptr + batch_size] = keys
            self.queue_age[self.ptr:self.ptr + batch_size] = 0
        else:
            # Handle overflow
            rem = self.size - self.ptr
            self.queue[self.ptr:] = keys[:rem]
            self.queue[:batch_size-rem] = keys[rem:]
            self.queue_age[self.ptr:] = 0
            self.queue_age[:batch_size-rem] = 0
            self.full = True
            
        self.ptr = (self.ptr + batch_size) % self.size
        
    def get_decay_weights(self) -> torch.Tensor:
        """Get temporal decay weights for queue entries"""
        return self.decay ** self.queue_age
        
    def compute_contrastive_loss(self, query: torch.Tensor, positive_key: torch.Tensor, 
                                temperature: float = 0.07) -> torch.Tensor:
        """Compute contrastive loss with temporal decay"""
        # Normalize embeddings
        query = F.normalize(query, dim=1)
        positive_key = F.normalize(positive_key, dim=1)
        queue = F.normalize(self.queue, dim=1)
        
        # Compute logits
        l_pos = torch.einsum('nc,nc->n', [query, positive_key]).unsqueeze(-1)
        l_neg = torch.einsum('nc,ck->nk', [query, queue.T])
        
        # Apply temporal decay to negative samples
        decay_weights = self.get_decay_weights()
        l_neg = l_neg * decay_weights.unsqueeze(0)
        
        # Temperature scaling
        logits = torch.cat([l_pos, l_neg], dim=1) / temperature
        labels = torch.zeros(logits.shape[0], dtype=torch.long, device=query.device)
        
        return F.cross_entropy(logits, labels)

class GraphGenerator(nn.Module):
    """Generator network with proper feature handling"""
    def __init__(self, node_dim: int, edge_dim: int, hidden_dim: int = 128):
        super().__init__()
        
        # Node feature processing
        self.node_encoder = nn.Sequential(
            nn.Linear(node_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        
        # Edge feature processing
        self.edge_encoder = nn.Sequential(
            nn.Linear(edge_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        
        # Graph convolution layers
        self.conv1 = GCNConv(hidden_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.conv3 = GCNConv(hidden_dim, hidden_dim)
        
        # Importance prediction layers
        self.node_importance = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, 1),
            nn.Sigmoid()
        )
        
        self.edge_importance = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )
        
    def normalize_features(self, x_cat, x_phys):
        """Normalize categorical and physical features separately"""
        # Convert categorical features to one-hot
        x_cat = x_cat.float()
        
        # Normalize physical features
        x_phys = x_phys.float()
        if x_phys.size(0) > 1:  # Only normalize if we have more than one sample
            x_phys = (x_phys - x_phys.mean(0)) / (x_phys.std(0) + 1e-5)
            
        return x_cat, x_phys
        
    def forward(self, data) -> Tuple[torch.Tensor, torch.Tensor]:
        # Normalize features
        x_cat, x_phys = self.normalize_features(data.x_cat, data.x_phys)
        
        # Concatenate features
        x = torch.cat([x_cat, x_phys], dim=-1)
        
        edge_index = data.edge_index
        edge_attr = data.edge_attr.float()  # Ensure float type
        
        # Initial feature encoding
        x = self.node_encoder(x)
        edge_attr = self.edge_encoder(edge_attr)
        
        # Graph convolutions
        x = F.relu(self.conv1(x, edge_index))  # Removed edge_attr from GCNConv
        x = F.relu(self.conv2(x, edge_index))
        x = self.conv3(x, edge_index)
        
        # Predict importance scores
        node_scores = self.node_importance(x)
        
        # Edge scores using both connected nodes
        edge_features = torch.cat([
            x[edge_index[0]], 
            x[edge_index[1]]
        ], dim=-1)
        edge_scores = self.edge_importance(edge_features)
        
        return node_scores, edge_scores

def get_model_config(dataset):
    """Get model configuration based on dataset features"""
    sample_data = dataset[0]
    
    # Calculate input dimensions
    node_dim = sample_data.x_cat.shape[1] + sample_data.x_phys.shape[1]
    edge_dim = sample_data.edge_attr.shape[1]
    
    config = GanClConfig(
        node_dim=node_dim,
        edge_dim=edge_dim,
        hidden_dim=128,
        output_dim=128,
        queue_size=65536,
        momentum=0.999,
        temperature=0.07,
        decay=0.99999,
        dropout_ratio=0.25
    )
    
    return config

class GraphDiscriminator(nn.Module):
    """Discriminator/Encoder network"""
    def __init__(self, node_dim: int, edge_dim: int, hidden_dim: int = 128, output_dim: int = 128):
        super().__init__()
        
        # Feature encoding
        self.node_encoder = nn.Sequential(
            nn.Linear(node_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        
        self.edge_encoder = nn.Sequential(
            nn.Linear(edge_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        
        # Graph convolution layers
        self.conv1 = GCNConv(hidden_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.conv3 = GCNConv(hidden_dim, output_dim)
        
        # Projection head for contrastive learning
        self.projection = nn.Sequential(
            nn.Linear(output_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )
        
    def normalize_features(self, x_cat, x_phys):
        """Normalize categorical and physical features separately"""
        # Convert categorical features to one-hot
        x_cat = x_cat.float()
        
        # Normalize physical features
        x_phys = x_phys.float()
        if x_phys.size(0) > 1:  # Only normalize if we have more than one sample
            x_phys = (x_phys - x_phys.mean(0)) / (x_phys.std(0) + 1e-5)
            
        return x_cat, x_phys 
        
    def forward(self, data):
        # Normalize features
        x_cat, x_phys = self.normalize_features(data.x_cat, data.x_phys)
        
        # Concatenate features
        x = torch.cat([x_cat, x_phys], dim=-1)
        
        edge_index = data.edge_index
        edge_attr = data.edge_attr.float()  # Ensure float type
        batch = data.batch
        
        # Initial feature encoding
        x = self.node_encoder(x)
        edge_attr = self.edge_encoder(edge_attr)
        
        # Graph convolutions
        x = F.relu(self.conv1(x, edge_index))  # Removed edge_attr from GCNConv
        x = F.relu(self.conv2(x, edge_index))
        x = self.conv3(x, edge_index)
        
        # Global pooling
        x = global_mean_pool(x, batch)
        
        # Projection
        x = self.projection(x)
        
        return x

@dataclass
class GanClConfig:
    """Configuration for GAN-CL training"""
    node_dim: int
    edge_dim: int
    hidden_dim: int = 128
    output_dim: int = 128
    queue_size: int = 65536
    momentum: float = 0.999
    temperature: float = 0.07
    decay: float = 0.99999
    dropout_ratio: float = 0.25

class MolecularGANCL(nn.Module):
    """Combined GAN and Contrastive Learning framework"""
    def __init__(self, config: GanClConfig):
        super().__init__()
        self.config = config
        
        # Add weight initialization
        def init_weights(m):
            if isinstance(m, nn.Linear):
                torch.nn.init.xavier_uniform_(m.weight)
                m.bias.data.fill_(0.01)
        
        # Initialize networks
        self.generator = GraphGenerator(
            config.node_dim, 
            config.edge_dim, 
            config.hidden_dim * 2
        )
        
        self.encoder = GraphDiscriminator(
            config.node_dim,
            config.edge_dim,
            config.hidden_dim,
            config.output_dim
        )
        self.encoder.apply(init_weights)
        
        # Modified loss weights
        self.contrastive_weight = 1.0
        self.adversarial_weight = 0.1  # Increased from 0.05
        self.similarity_weight = 0.01  # Decreased from 0.1
        
        # Temperature annealing
        self.initial_temperature = 0.1
        self.min_temperature = 0.05        
        
        # Create momentum encoder
        self.momentum_encoder = copy.deepcopy(self.encoder)
        for param in self.momentum_encoder.parameters():
            param.requires_grad = False
            
        # Initialize memory queue
        self.memory_queue = MemoryQueue(
            config.queue_size,
            config.output_dim,
            config.decay
        )
        
    @torch.no_grad()
    def _momentum_update(self):
        """Update momentum encoder"""
        for param_q, param_k in zip(self.encoder.parameters(), 
                                  self.momentum_encoder.parameters()):
            param_k.data = self.config.momentum * param_k.data + \
                          (1 - self.config.momentum) * param_q.data
                          
    def drop_graph_elements(self, data, node_scores: torch.Tensor, 
                          edge_scores: torch.Tensor) -> Data:
        """Apply dropout to graph based on importance scores"""
        # Select elements to keep based on scores and dropout ratio
#         node_mask = (node_scores < self.config.dropout_ratio).float()
#         edge_mask = (edge_scores < self.config.dropout_ratio).float()

        node_mask = (torch.rand_like(node_scores) > self.config.dropout_ratio).float()
        edge_mask = (torch.rand_like(edge_scores) > self.config.dropout_ratio).float()
        
        # Apply masks
        x_cat_new = data.x_cat * node_mask
        x_phys_new = data.x_phys * node_mask
        edge_attr_new = data.edge_attr * edge_mask
        
        # Create new graph data object
        return Data(
            x_cat=x_cat_new,
            x_phys=x_phys_new,
            edge_index=data.edge_index,
            edge_attr=edge_attr_new,
            batch=data.batch
        )
        
    def get_temperature(self, epoch, total_epochs):
        """Anneal temperature during training"""
        progress = epoch / total_epochs
        return max(self.initial_temperature * (1 - progress), self.min_temperature)
    
    def forward(self, data, epoch=0, total_epochs=50):
        # Get current temperature
        temperature = self.get_temperature(epoch, total_epochs)
        
        # Get importance scores from generator
        node_scores, edge_scores = self.generator(data)
        
        # Create perturbed graph
        perturbed_data = self.drop_graph_elements(data, node_scores, edge_scores)
        
        # Get embeddings
        query_emb = self.encoder(perturbed_data)
        with torch.no_grad():
            key_emb = self.momentum_encoder(data)
            original_emb = self.encoder(data).detach()
        
        # Compute losses with modified weights
        contrastive_loss = self.memory_queue.compute_contrastive_loss(
            query_emb, key_emb, temperature
        ) * self.contrastive_weight
        
        adversarial_loss = -F.mse_loss(query_emb, original_emb) * self.adversarial_weight
        similarity_loss = F.mse_loss(query_emb, original_emb) * self.similarity_weight
        
        return contrastive_loss, adversarial_loss, similarity_loss
    
    def get_embeddings(self, data) -> torch.Tensor:
        """Get embeddings for downstream tasks"""
        with torch.no_grad():
            return self.encoder(data)

In [4]:
def extract_molecule_metadata(dataset):
    """Extract metadata from PyG graph data without relying on SMILES strings"""
    from rdkit import Chem
    from rdkit.Chem import Descriptors
    from tqdm import tqdm
    import numpy as np
    import torch
    import networkx as nx
    from collections import defaultdict
    
    metadata = []
    
    for i, data in enumerate(tqdm(dataset, desc="Extracting molecule metadata")):
        # Set graph ID
        mol_id = f"molecule_{i}"
        
        # Initialize empty dictionaries for metadata
        properties = {}
        features = {}
        functional_groups = {}
        ring_info = {"ring_counts": {}, "ring_sizes": {}}
        
        # Extract basic graph properties directly from the PyG data
        if hasattr(data, 'num_nodes') and hasattr(data, 'edge_index'):
            try:
                # Convert to networkx graph for analysis
                G = to_networkx(data)
                
                # Calculate graph-level properties
                num_edges = data.edge_index.size(1) // 2  # Count unique edges
                properties = {
                    "num_nodes": data.num_nodes,
                    "num_edges": num_edges,
                    "avg_node_degree": 2 * num_edges / data.num_nodes if data.num_nodes > 0 else 0
                }
                
                # Calculate average path length if graph is connected
                if nx.is_connected(G):
                    try:
                        properties["avg_path_length"] = nx.average_shortest_path_length(G)
                    except:
                        properties["avg_path_length"] = 0.0
                else:
                    properties["avg_path_length"] = 0.0
                
                # Add more sophisticated graph properties
                try:
                    properties["clustering_coefficient"] = nx.average_clustering(G)
                except:
                    properties["clustering_coefficient"] = 0.0
                
                try:
                    properties["graph_diameter"] = nx.diameter(G) if nx.is_connected(G) else 0
                except:
                    properties["graph_diameter"] = 0

                try:
                    properties["assortativity"] = nx.degree_assortativity_coefficient(G)
                except:
                    properties["assortativity"] = 0.0
                
                # Graph features
                features = {
                    "is_connected": nx.is_connected(G),
                    "num_connected_components": nx.number_connected_components(G),
                    "has_cycles": not nx.is_tree(G),
                    "max_degree": max(dict(G.degree()).values()) if G.number_of_nodes() > 0 else 0,
                    "density": nx.density(G),
                    "is_bipartite": nx.is_bipartite(G) if G.number_of_nodes() > 0 else False
                }
                
                # Get centrality measures
                if G.number_of_nodes() > 0:
                    try:
                        degree_centrality = nx.degree_centrality(G)
                        features["max_centrality"] = max(degree_centrality.values()) if degree_centrality else 0
                        features["avg_centrality"] = sum(degree_centrality.values()) / len(degree_centrality) if degree_centrality else 0
                    except:
                        features["max_centrality"] = 0
                        features["avg_centrality"] = 0
                else:
                    features["max_centrality"] = 0
                    features["avg_centrality"] = 0
                
                # Analyze node features if available
                if hasattr(data, 'x_cat') and hasattr(data, 'x_phys'):
                    # Atomic element distribution (from x_cat)
                    atom_types = {}
                    if data.x_cat.size(1) > 0:
                        for i in range(data.num_nodes):
                            atom_type = int(data.x_cat[i, 0].item())
                            atom_types[atom_type] = atom_types.get(atom_type, 0) + 1
                    
                    features["atom_type_distribution"] = atom_types
                    
                    # Physical property statistics (from x_phys)
                    if data.x_phys.size(1) > 0:
                        phys_means = data.x_phys.mean(dim=0).tolist() 
                        phys_stds = data.x_phys.std(dim=0).tolist()
                        
                        # Map indices to meaningful property names for the first few common properties
                        phys_prop_names = ['contrib_mw', 'contrib_logp', 'formal_charge', 
                                        'hybridization', 'is_aromatic', 'num_h', 'valence', 'degree']
                        
                        for idx, name in enumerate(phys_prop_names):
                            if idx < len(phys_means):
                                properties[f"avg_{name}"] = phys_means[idx]
                                properties[f"std_{name}"] = phys_stds[idx]
                
                # Cycle analysis
                cycles = list(nx.cycle_basis(G))
                cycle_count = len(cycles)
                ring_info["ring_counts"]["total"] = cycle_count
                
                # Count rings by size
                ring_sizes = defaultdict(int)
                for cycle in cycles:
                    size = len(cycle)
                    ring_sizes[str(size)] = ring_sizes.get(str(size), 0) + 1
                
                # Ensure we have entries for common ring sizes
                for size in range(3, 11):
                    if str(size) not in ring_sizes:
                        ring_sizes[str(size)] = 0
                
                ring_info["ring_sizes"] = dict(ring_sizes)
                
                # Estimate ring types
                ring_info["ring_counts"]["single"] = 0
                ring_info["ring_counts"]["fused"] = 0
                
                # Identify single vs fused rings by checking for shared nodes
                if cycles:
                    # Build a mapping of nodes to cycles they belong to
                    node_to_cycles = defaultdict(list)
                    for cycle_idx, cycle in enumerate(cycles):
                        for node in cycle:
                            node_to_cycles[node].append(cycle_idx)
                    
                    # Count single rings (no shared nodes with other rings)
                    shared_cycles = set()
                    for node, cycle_list in node_to_cycles.items():
                        if len(cycle_list) > 1:
                            for c in cycle_list:
                                shared_cycles.add(c)
                    
                    ring_info["ring_counts"]["single"] = cycle_count - len(shared_cycles)
                    ring_info["ring_counts"]["fused"] = len(shared_cycles)
                
                # Edge feature analysis if available
                if hasattr(data, 'edge_attr') and data.edge_attr.size(0) > 0:
                    # Analyze bond types (assuming first dimension is bond type)
                    bond_types = {}
                    for i in range(data.edge_attr.size(0)):
                        if data.edge_attr.size(1) > 0:
                            bond_type = int(data.edge_attr[i, 0].item())
                            bond_types[bond_type] = bond_types.get(bond_type, 0) + 1
                    
                    # Divide by 2 since each bond is counted twice in undirected graph
                    for bt in bond_types:
                        bond_types[bt] = bond_types[bt] // 2
                    
                    functional_groups["bond_types"] = bond_types
                    
                    # Count functional group proxies based on patterns in the graph
                    # This is just an estimate since we don't have chemical information
                    conjugated_bonds = 0
                    for i in range(data.edge_attr.size(0)):
                        if data.edge_attr.size(1) > 1 and data.edge_attr[i, 2].item() > 0:  # IsConjugated flag
                            conjugated_bonds += 1
                    
                    functional_groups["conjugated_bonds"] = conjugated_bonds // 2
                    
            except Exception as e:
                # If any error occurs during analysis, use minimal information
                print(f"Error analyzing graph {i}: {e}")
        
        metadata.append({
            "graph_id": mol_id,
            "properties": properties,
            "features": features,
            "functional_groups": functional_groups,
            "ring_info": ring_info
        })
    
    return metadata

def to_networkx(data):
    """Convert PyG data to networkx graph for analysis"""
    import networkx as nx
    
    G = nx.Graph()
    
    # Add nodes
    for i in range(data.num_nodes):
        G.add_node(i)
    
    # Add edges (removing duplicates and self-loops)
    edge_index = data.edge_index.cpu().numpy()
    edges = set()
    for i in range(edge_index.shape[1]):
        u, v = edge_index[0, i], edge_index[1, i]
        if u != v and (u, v) not in edges and (v, u) not in edges:
            G.add_edge(u, v)
            edges.add((u, v))
    
    return G

In [5]:
def save_embedding_file(embeddings, molecule_indices, training_info, model_config, filepath):
    """Save embeddings with training metadata"""
    data = {
        "embeddings": embeddings,
        "molecule_indices": molecule_indices,
        "training_info": training_info,
        "model_config": {k: v for k, v in model_config.__dict__.items() 
                         if not k.startswith('_') and not callable(v)}
    }
    
    with open(filepath, 'wb') as f:
        pickle.dump(data, f)
        
def extract_negative_pair_distances(model, train_loader, device, epoch_num, num_samples=1000):
    """Extract distances between negative pairs at a specific epoch"""
    model.eval()
    negative_distances = []
    
    with torch.no_grad():
        for batch in tqdm(train_loader, desc=f"Extracting negative pairs at epoch {epoch_num}"):
            batch = batch.to(device)
            
            # Get embeddings from encoder and momentum encoder
            query_emb = model.encoder(batch)
            key_emb = model.momentum_encoder(batch)
            
            # Normalize embeddings
            query_emb = F.normalize(query_emb, dim=1)
            key_emb = F.normalize(key_emb, dim=1)
            
            # Compute distances between all pairs
            similarity_matrix = torch.mm(query_emb, key_emb.T)
            
            # Get off-diagonal elements (negative pairs)
            mask = ~torch.eye(similarity_matrix.shape[0], dtype=torch.bool, device=device)
            negative_similarities = similarity_matrix[mask]
            
            # Convert to distances (1 - similarity)
            distances = 1 - negative_similarities
            
            negative_distances.append(distances.cpu().numpy())
            
            # If we have enough samples, break
            if len(negative_distances) * batch.batch[-1].item() > num_samples:
                break
    
    # Concatenate and sample if necessary
    negative_distances = np.concatenate(negative_distances)
    if len(negative_distances) > num_samples:
        indices = np.random.choice(len(negative_distances), num_samples, replace=False)
        negative_distances = negative_distances[indices]
    
    return negative_distances
        
def save_embeddings_with_molecules(embeddings, dataset, filepath):
    """Save embeddings with corresponding molecule information and graph-level properties"""
    # Create a list to store molecule data
    molecule_data = []
    
    # Extract important info from each molecule in the dataset
    for data in dataset:
        # Create a dictionary with basic graph properties
        mol_info = {
            "num_nodes": data.num_nodes,
            "edge_index": data.edge_index.tolist() if hasattr(data, 'edge_index') else None,
            "x_cat": data.x_cat.tolist() if hasattr(data, 'x_cat') else None,
            "x_phys": data.x_phys.tolist() if hasattr(data, 'x_phys') else None,
            "edge_attr": data.edge_attr.tolist() if hasattr(data, 'edge_attr') else None,
            "smiles": data.smiles if hasattr(data, 'smiles') else ""  # This is missing in your code
        }
        
        # Calculate additional graph properties if possible
        try:
            if hasattr(data, 'edge_index') and hasattr(data, 'num_nodes'):
                # Graph density
                num_edges = len(data.edge_index[0]) // 2  # Undirected edges counted once
                max_edges = data.num_nodes * (data.num_nodes - 1) // 2
                density = num_edges / max_edges if max_edges > 0 else 0
                mol_info["graph_density"] = density
                
                # Average degree
                avg_degree = num_edges * 2 / data.num_nodes if data.num_nodes > 0 else 0
                mol_info["avg_degree"] = avg_degree
                
                # Count atom types if available
                if hasattr(data, 'x_cat') and data.x_cat is not None:
                    atom_types = {}
                    for atom in data.x_cat:
                        atom_type = int(atom[0])
                        atom_types[atom_type] = atom_types.get(atom_type, 0) + 1
                    mol_info["atom_type_counts"] = atom_types
                
                # Count bond types if available
                if hasattr(data, 'edge_attr') and data.edge_attr is not None:
                    bond_types = {}
                    for bond in data.edge_attr:
                        bond_type = int(bond[0])
                        bond_types[bond_type] = bond_types.get(bond_type, 0) + 1
                    mol_info["bond_type_counts"] = bond_types
        except:
            # If calculation fails, continue without these properties
            pass
            
        molecule_data.append(mol_info)
    
    # Save both embeddings and molecule data
    with open(filepath, 'wb') as f:
        pickle.dump({
            'embeddings': embeddings,
            'molecule_data': molecule_data,
            'graph_properties': True,  # Flag to indicate enhanced properties are stored
            'smiles_list': [data.smiles for data in dataset if hasattr(data, 'smiles')] # This is missing
        }, f)
    
    print(f"Saved embeddings and molecule data with graph properties to {filepath}")
        

def save_embeddings(embeddings, labels, filepath):
    """Save embeddings and corresponding labels"""
    with open(filepath, 'wb') as f:
        pickle.dump({
            'embeddings': embeddings,
            'labels': labels
        }, f)

def save_encoder(encoder, save_path, info=None):
    """Save encoder model for downstream tasks"""
    save_dict = {
        'encoder_state_dict': encoder.state_dict(),
        'model_info': info or {}
    }
    torch.save(save_dict, save_path)

def load_encoder(model_path, device='cpu'):
    """Load saved encoder model"""
    checkpoint = torch.load(model_path, map_location=device)
    encoder = GraphDiscriminator(
        node_dim=checkpoint['model_info'].get('node_dim'),
        edge_dim=checkpoint['model_info'].get('edge_dim'),
        hidden_dim=checkpoint['model_info'].get('hidden_dim', 128),
        output_dim=checkpoint['model_info'].get('output_dim', 128)
    )
    encoder.load_state_dict(checkpoint['encoder_state_dict'])
    return encoder     

def extract_embeddings_at_epoch(model, train_loader, device, epoch_num):
    """Extract embeddings at a specific epoch for visualization/analysis"""
    model.eval()
    epoch_embeddings = []
    molecule_indices = []
    
    with torch.no_grad():
        for i, batch in enumerate(tqdm(train_loader, desc=f"Extracting embeddings at epoch {epoch_num}")):
            batch = batch.to(device)
            # Use the momentum encoder for consistency in the momentum-based approach
            embeddings = model.momentum_encoder(batch)
            epoch_embeddings.append(embeddings.cpu().numpy())
            
            # Track molecule indices 
            if hasattr(batch, 'batch'):
                batch_size = batch.batch[-1].item() + 1
                indices = [i * batch_size + j for j in range(batch_size)]
                molecule_indices.extend(indices)
            else:
                molecule_indices.append(i)
    
    # Concatenate all embeddings
    epoch_embeddings = np.vstack(epoch_embeddings)
    
    return epoch_embeddings, molecule_indices
        
def train_gan_cl(train_loader, config, dataset, device='cuda', 
                save_dir='./checkpoints', 
                embedding_dir='./embeddings'):
    """Main training function for GAN-CL with embedding storage for bias analysis"""
    
    # Create directories
    os.makedirs(save_dir, exist_ok=True)
    os.makedirs(embedding_dir, exist_ok=True)
    metadata_dir = os.path.join(embedding_dir, 'metadata')
    os.makedirs(metadata_dir, exist_ok=True)
    encoder_dir = os.path.join(save_dir, 'encoders')
    os.makedirs(encoder_dir, exist_ok=True)
    
    # Extract and save molecule metadata (once, before training)
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    metadata = extract_molecule_metadata(dataset)
    with open(os.path.join(metadata_dir, f'molecule_metadata_{timestamp}.pkl'), 'wb') as f:
        pickle.dump(metadata, f)
    
    # Save molecule indices for consistent order
    molecule_indices = list(range(len(metadata)))
    
    # Initialize model
    model = MolecularGANCL(config).to(device)
    
    # Get pre-training embeddings before any training
    print("Extracting pre-training embeddings...")
    model.eval()
    pre_training_embeddings = []
    with torch.no_grad():
        for batch in tqdm(train_loader, desc="Pre-training embeddings"):
            batch = batch.to(device)
            embeddings = model.get_embeddings(batch)
            pre_training_embeddings.append(embeddings.cpu())
    
    pre_training_embeddings = torch.cat(pre_training_embeddings, dim=0).numpy()
    
    # Save pre-training embeddings
    pre_training_info = {
        "stage": "pre",
        "epoch": 0,
        "timestamp": timestamp,
        "loss_values": {"contrastive": 0, "adversarial": 0, "similarity": 0, "total": 0}
    }
    
    save_embedding_file(
        pre_training_embeddings, 
        molecule_indices,
        pre_training_info, 
        config, 
        os.path.join(embedding_dir, f'pre_training_embeddings_{timestamp}.pkl')
    )
    
    model.train()
    
    # Rest of your training code remains the same...
    # Initialize optimizers
    optimizer_encoder = torch.optim.Adam(model.encoder.parameters(), lr=3e-4)
    optimizer_generator = torch.optim.Adam(model.generator.parameters(), lr=1e-4)
    
    # Save initial model info
    model_info = {
        'node_dim': config.node_dim,
        'edge_dim': config.edge_dim,
        'hidden_dim': config.hidden_dim,
        'output_dim': config.output_dim,
        'training_config': config.__dict__
    }
    
    # Training phases as before...
    best_loss = float('inf')
    
    # Training metrics
    metrics = {
        'contrastive_losses': [],
        'adversarial_losses': [],
        'similarity_losses': [],
        'total_losses': []
    }
    
    # Training phases
    print("Phase 1: Pretraining Contrastive Learning...")
    pretrain_epochs = 10
    for epoch in range(pretrain_epochs):
        contrastive_epoch_loss = 0
        
        for batch in tqdm(train_loader, desc=f'Pretrain Epoch {epoch+1}/{pretrain_epochs}'):
            batch = batch.to(device)
            
            # Forward pass (without generator)
            query_emb = model.encoder(batch)
            with torch.no_grad():
                key_emb = model.momentum_encoder(batch)
            
            # Compute contrastive loss
            contrastive_loss = model.memory_queue.compute_contrastive_loss(
                query_emb, key_emb, model.config.temperature
            )
            
            # Update encoder
            optimizer_encoder.zero_grad()
            contrastive_loss.backward()
            optimizer_encoder.step()
            
            # Update momentum encoder
            model._momentum_update()
            
            # Update memory queue
            model.memory_queue.update_queue(key_emb.detach())
            
            contrastive_epoch_loss += contrastive_loss.item()
            
        avg_loss = contrastive_epoch_loss / len(train_loader)
        metrics['contrastive_losses'].append(avg_loss)
        print(f'Pretrain Epoch {epoch+1}, Avg Loss: {avg_loss:.4f}')
        
        # Save embeddings at key epochs for temporal analysis
        if (epoch + 1) % 2 == 0 or epoch == 0:
            epoch_embeddings, epoch_molecule_indices = extract_embeddings_at_epoch(model, train_loader, device, epoch+1)

            # Save epoch embeddings
            epoch_info = {
                "stage": f"pretrain_epoch_{epoch+1}",
                "epoch": epoch + 1,
                "timestamp": datetime.now().strftime('%Y%m%d_%H%M%S'),
                "loss_values": {"contrastive": avg_loss}
            }

            save_embedding_file(
                epoch_embeddings,
                epoch_molecule_indices,
                epoch_info,
                config,
                os.path.join(embedding_dir, f'pretrain_epoch_{epoch+1}_embeddings_{timestamp}.pkl')
            )

            # Add negative pair extraction here:
            negative_distances = extract_negative_pair_distances(model, train_loader, device, epoch+1)

            # Save to file
            negative_pair_file = os.path.join(embedding_dir, f'pretrain_negative_pairs_epoch_{epoch+1}_{timestamp}.pkl')
            with open(negative_pair_file, 'wb') as f:
                pickle.dump({
                    'epoch': epoch + 1,
                    'negative_distances': negative_distances,
                    'timestamp': datetime.now().strftime('%Y%m%d_%H%M%S')
                }, f)

            print(f"Saved negative pair analysis data to {negative_pair_file}")      
        
        # Save pretrained checkpoint
        if (epoch + 1) % 5 == 0:
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_encoder_state_dict': optimizer_encoder.state_dict(),
                'loss': avg_loss,
            }, os.path.join(save_dir, f'pretrain_checkpoint_{epoch+1}.pt'))
    
    print("\nPhase 2: Training GAN-CL...")
    train_epochs = 50
#     train_epochs = 10
    for epoch in range(train_epochs):
        epoch_losses = {
            'contrastive': 0,
            'adversarial': 0,
            'similarity': 0,
            'total': 0
        }

        for batch in tqdm(train_loader, desc=f'Train Epoch {epoch+1}/{train_epochs}'):
            batch = batch.to(device)

            # Step 1: Train Encoder
            optimizer_encoder.zero_grad()

            # Get importance scores from generator
            with torch.no_grad():
                node_scores, edge_scores = model.generator(batch)

            # Create perturbed graph
            perturbed_data = model.drop_graph_elements(batch, node_scores, edge_scores)

            # Get embeddings
            query_emb = model.encoder(perturbed_data)
            with torch.no_grad():
                key_emb = model.momentum_encoder(batch)
                original_emb = model.encoder(batch).detach()

            # Compute losses for encoder
            contrastive_loss = model.memory_queue.compute_contrastive_loss(
                query_emb, key_emb, model.config.temperature
            )
            similarity_loss = F.mse_loss(query_emb, original_emb)

            # Total loss for encoder
            encoder_loss = contrastive_loss + 0.1 * similarity_loss

            # Update encoder
            encoder_loss.backward()
            optimizer_encoder.step()

            # Update momentum encoder
            model._momentum_update()

            # Step 2: Train Generator
            optimizer_generator.zero_grad()

            # Get new embeddings for adversarial loss
            node_scores, edge_scores = model.generator(batch)
            perturbed_data = model.drop_graph_elements(batch, node_scores, edge_scores)

            with torch.no_grad():
                original_emb = model.encoder(batch)
            perturbed_emb = model.encoder(perturbed_data)

            # Compute adversarial loss
            adversarial_loss = -F.mse_loss(perturbed_emb, original_emb)

            # Update generator
            adversarial_loss.backward()
            optimizer_generator.step()

            # Update memory queue
            model.memory_queue.update_queue(key_emb.detach())

            # Update metrics
            epoch_losses['contrastive'] += contrastive_loss.item()
            epoch_losses['adversarial'] += adversarial_loss.item()
            epoch_losses['similarity'] += similarity_loss.item()
            epoch_losses['total'] += encoder_loss.item()

        # Average losses
        for k in epoch_losses:
            epoch_losses[k] /= len(train_loader)
            metrics[f'{k}_losses'].append(epoch_losses[k])

        print(f'Epoch {epoch+1}, Losses: {epoch_losses}')

        if (epoch + 1) % 10 == 0 or epoch == 0:
            epoch_embeddings, epoch_molecule_indices = extract_embeddings_at_epoch(model, train_loader, device, epoch+1)

            # Save checkpoint embeddings with training info
            checkpoint_info = {
                "stage": f"epoch_{epoch+1}",
                "epoch": epoch + 1,
                "timestamp": datetime.now().strftime('%Y%m%d_%H%M%S'),
                "loss_values": epoch_losses
            }

            save_embedding_file(
                epoch_embeddings,
                epoch_molecule_indices,
                checkpoint_info,
                config,
                os.path.join(embedding_dir, f'epoch_{epoch+1}_embeddings_{timestamp}.pkl')
            )    
            
            # Add negative pair extraction here:
            negative_distances = extract_negative_pair_distances(model, train_loader, device, epoch+1)

            # Save to file
            negative_pair_file = os.path.join(embedding_dir, f'negative_pairs_epoch_{epoch+1}_{timestamp}.pkl')
            with open(negative_pair_file, 'wb') as f:
                pickle.dump({
                    'epoch': epoch + 1,
                    'negative_distances': negative_distances,
                    'timestamp': datetime.now().strftime('%Y%m%d_%H%M%S')
                }, f)

            print(f"Saved negative pair analysis data to {negative_pair_file}")            
        
        # Save checkpoint
        if (epoch + 1) % 10 == 0:
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_encoder_state_dict': optimizer_encoder.state_dict(),
                'optimizer_generator_state_dict': optimizer_generator.state_dict(),
                'losses': epoch_losses,
            }, os.path.join(save_dir, f'gan_cl_checkpoint_{epoch+1}.pt'))            
    
#     # Extract and save embeddings periodically
#     if (epoch + 1) % 10 == 0:
#         model.eval()
#         checkpoint_embeddings = []
        
#         with torch.no_grad():
#             for batch in train_loader:
#                 batch = batch.to(device)
#                 embeddings = model.get_embeddings(batch)
#                 checkpoint_embeddings.append(embeddings.cpu())
        
#         checkpoint_embeddings = torch.cat(checkpoint_embeddings, dim=0).numpy()
        
#         # Save checkpoint embeddings with training info
#         checkpoint_info = {
#             "stage": f"epoch_{epoch+1}",
#             "epoch": epoch + 1,
#             "timestamp": datetime.now().strftime('%Y%m%d_%H%M%S'),
#             "loss_values": epoch_losses
#         }
        
#         save_embedding_file(
#             checkpoint_embeddings,
#             molecule_indices,
#             checkpoint_info,
#             config,
#             os.path.join(embedding_dir, f'epoch_{epoch+1}_embeddings_{timestamp}.pkl')
#         )
        
#         model.train()
    
    # Extract and save post-training embeddings at the end
    print("Extracting post-training embeddings...")
    model.eval()
    post_training_embeddings = []
    
    with torch.no_grad():
        for batch in tqdm(train_loader, desc="Post-training embeddings"):
            batch = batch.to(device)
            embeddings = model.get_embeddings(batch)
            post_training_embeddings.append(embeddings.cpu())
        
    post_training_embeddings = torch.cat(post_training_embeddings, dim=0).numpy()
    
    # Save post-training embeddings
    post_training_info = {
        "stage": "post",
        "epoch": train_epochs,
        "timestamp": datetime.now().strftime('%Y%m%d_%H%M%S'),
        "loss_values": epoch_losses
    }
    
    save_embedding_file(
        post_training_embeddings,
        molecule_indices,
        post_training_info,
        config,
        os.path.join(embedding_dir, f'post_training_embeddings_{timestamp}.pkl')
    )
    
    return model, metrics

In [6]:
def main():
     # Enable anomaly detection during development
    torch.autograd.set_detect_anomaly(True)
    # Your existing data loading code here
    torch.manual_seed(42)
    np.random.seed(42)
    
    print("Starting data loading...")
    extractor = MolecularFeatureExtractor()
#     smiles_file = "D:\\PhD\\Chapter3\\Unsupervised_GAN_Code\\pubchem-41-clean.txt"
    smiles_file = "D:\\PhD\\Chapter3\\Unsupervised_GAN_Code\\pubchem-10m-clean_test10k.txt"
#     smiles_file = "D:\\PhD\\Chapter3\\Unsupervised_GAN_Code\\pubchem-10m-clean_test50k.txt"
    
    dataset = []
    failed_smiles = []
    
    smiles_list = []
    
    with open(smiles_file, 'r') as f:
        for line in f:
            smiles = line.strip()
            smiles_list.append(smiles)
            data = extractor.process_molecule(smiles)
            if data is not None:
                # Store SMILES as an attribute for later use
                data.smiles = smiles
                dataset.append(data)
            else:
                failed_smiles.append(smiles)
    
    print(f"1. Loaded dataset with {len(dataset)} graphs.")
    print(f"2. Failed SMILES count: {len(failed_smiles)}")
    
    if not dataset:
        print("No valid graphs generated.")
        return None
        
    # Setup training
    batch_size = 32
    train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    print(f"3. Created DataLoader with {len(train_loader.dataset)} graphs")
    
    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"4. Using device: {device}")
    
    # Get configuration based on dataset
    config = get_model_config(dataset)   
   
    # Train model
    print("5. Starting GAN-CL training...")
    model, metrics = train_gan_cl(
        train_loader, 
        config,
        dataset,  # Pass the dataset for metadata extraction
        device=device,
        save_dir='./checkpoints',
        embedding_dir='./embeddings'
    )    
    
    
    print("6. Training completed!")
    
    # Extract embeddings for XAI
    print("7. Extracting final embeddings for XAI...")
    model.eval()
    with torch.no_grad():
        all_embeddings = []
        all_graphs = []
        
        for batch in tqdm(train_loader, desc="Extracting embeddings"):
            batch = batch.to(device)
            embeddings = model.get_embeddings(batch)
            all_embeddings.append(embeddings.cpu())
            all_graphs.extend([data for data in batch])
            
    all_embeddings = torch.cat(all_embeddings, dim=0).numpy()
    
    # Save final embeddings and graphs
#     final_embedding_path = './embeddings/final_embeddings.pkl'
#     final_embedding_path = f'./embeddings/final_embeddings_{timestamp}.pkl'
#     save_embeddings(all_embeddings, all_graphs, final_embedding_path)
#     print(f"8. Final embeddings saved to {final_embedding_path}")
    
    # Update your final embedding saving code
    final_embedding_path = f'./embeddings/final_embeddings_molecules_ME_{timestamp}.pkl'
    save_embeddings_with_molecules(all_embeddings, dataset, final_embedding_path)
    
    # Create a dictionary with embeddings and molecule data that includes SMILES
    embedding_data = {
        'embeddings': embeddings,
        'molecule_data': [{
            'smiles': getattr(data, 'smiles', ''),  # Extract SMILES if available
            'num_nodes': data.num_nodes,
            'edge_index': data.edge_index.tolist() if hasattr(data, 'edge_index') else None,
            'x_cat': data.x_cat.tolist() if hasattr(data, 'x_cat') else None,
            'x_phys': data.x_phys.tolist() if hasattr(data, 'x_phys') else None,
            'edge_attr': data.edge_attr.tolist() if hasattr(data, 'edge_attr') else None
        } for data in dataset],
        'smiles_list': [getattr(data, 'smiles', '') for data in dataset]  # Explicit SMILES list
    }

    with open(final_embedding_path, 'wb') as f:
        pickle.dump(embedding_data, f)
    
    print(f"8. Final embeddings saved to {final_embedding_path}")    
    
    
    # Print encoder locations
    print(f"9. Encoders saved in ./checkpoints/encoders/:")
    print(f"   - Best encoder: best_encoder.pt")
    print(f"   - Final encoder: final_encoder.pt")
    print(f"   - Periodic encoders: encoder_epoch_*.pt")
    
    return model, metrics, all_embeddings, all_graphs

if __name__ == "__main__":
    model, metrics, embeddings, graphs = main()

Starting data loading...
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer


[19:16:17] UFFTYPER: Unrecognized atom type: Se2+2 (17)


Failed to generate 3D conformer


[19:16:21] UFFTYPER: Unrecognized charge state for atom: 1


Failed to generate 3D conformer


[19:16:30] UFFTYPER: Unrecognized charge state for atom: 4


Failed to generate 3D conformer


[19:17:00] UFFTYPER: Unrecognized atom type: S_5+4 (11)


Failed to generate 3D conformer


[19:17:21] UFFTYPER: Unrecognized atom type: Se2+2 (14)
[19:17:21] UFFTYPER: Unrecognized charge state for atom: 20
[19:17:21] UFFTYPER: Unrecognized charge state for atom: 40


Failed to generate 3D conformer


[19:17:21] UFFTYPER: Unrecognized charge state for atom: 5
[19:17:26] UFFTYPER: Unrecognized charge state for atom: 9
[19:17:31] UFFTYPER: Unrecognized charge state for atom: 2


Failed to generate 3D conformer


[19:17:44] UFFTYPER: Unrecognized charge state for atom: 8
[19:18:20] UFFTYPER: Unrecognized charge state for atom: 6


Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer


[19:20:06] UFFTYPER: Unrecognized charge state for atom: 2


Failed to generate 3D conformer


[19:20:24] UFFTYPER: Unrecognized charge state for atom: 3
[19:20:24] UFFTYPER: Unrecognized charge state for atom: 7


Failed to generate 3D conformer
Failed to generate 3D conformer


[19:21:17] UFFTYPER: Unrecognized atom type: S_6+6 (17)
[19:21:26] UFFTYPER: Unrecognized atom type: Se2+2 (7)
[19:21:26] UFFTYPER: Unrecognized atom type: Se2+2 (7)


Failed to generate 3D conformer


[19:21:45] UFFTYPER: Unrecognized charge state for atom: 5
[19:21:59] UFFTYPER: Unrecognized charge state for atom: 13


Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer


[19:22:41] UFFTYPER: Unrecognized charge state for atom: 8
[19:22:55] UFFTYPER: Unrecognized atom type: Se2+2 (16)


Failed to generate 3D conformer


[19:23:16] UFFTYPER: Unrecognized charge state for atom: 4


Failed to generate 3D conformer
Failed to generate 3D conformer


[19:23:47] UFFTYPER: Unrecognized charge state for atom: 1


Failed to generate 3D conformer


[19:25:10] UFFTYPER: Unrecognized charge state for atom: 17
[19:25:10] UFFTYPER: Unrecognized charge state for atom: 19
[19:25:13] UFFTYPER: Unrecognized charge state for atom: 8


Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer


[19:27:29] UFFTYPER: Unrecognized charge state for atom: 8
[19:27:31] UFFTYPER: Unrecognized charge state for atom: 1
[19:27:43] UFFTYPER: Unrecognized atom type: Se2+2 (9)


Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer


[19:28:53] UFFTYPER: Unrecognized atom type: S_5+4 (1)
[19:29:03] UFFTYPER: Unrecognized atom type: S_5+4 (10)


Failed to generate 3D conformer


[19:29:21] UFFTYPER: Unrecognized hybridization for atom: 1
[19:29:21] UFFTYPER: Unrecognized atom type: Pt+2 (1)


Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer


[19:33:05] UFFTYPER: Unrecognized charge state for atom: 3
[19:33:07] UFFTYPER: Unrecognized charge state for atom: 15
[19:33:11] UFFTYPER: Unrecognized atom type: S_5+4 (15)


Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer
Failed to generate 3D conformer
1. Loaded dataset with 9937 graphs.
2. Failed SMILES count: 63
3. Created DataLoader with 9937 graphs
4. Using device: cpu
5. Starting GAN-CL training...


Extracting molecule metadata: 100%|███████████████████████████████████████████████| 9937/9937 [01:11<00:00, 139.29it/s]


Extracting pre-training embeddings...


Pre-training embeddings: 100%|███████████████████████████████████████████████████████| 311/311 [00:03<00:00, 82.04it/s]


Phase 1: Pretraining Contrastive Learning...


Pretrain Epoch 1/10: 100%|███████████████████████████████████████████████████████████| 311/311 [00:54<00:00,  5.70it/s]


Pretrain Epoch 1, Avg Loss: 7.0362


Extracting embeddings at epoch 1: 100%|██████████████████████████████████████████████| 311/311 [00:05<00:00, 59.79it/s]
Extracting negative pairs at epoch 1:  10%|████▍                                      | 32/311 [00:01<00:10, 27.09it/s]


Saved negative pair analysis data to ./embeddings\pretrain_negative_pairs_epoch_1_20250331_193711.pkl


Pretrain Epoch 2/10: 100%|███████████████████████████████████████████████████████████| 311/311 [00:51<00:00,  6.03it/s]


Pretrain Epoch 2, Avg Loss: 7.3130


Extracting embeddings at epoch 2: 100%|█████████████████████████████████████████████| 311/311 [00:02<00:00, 108.18it/s]
Extracting negative pairs at epoch 2:  10%|████▍                                      | 32/311 [00:00<00:04, 58.16it/s]


Saved negative pair analysis data to ./embeddings\pretrain_negative_pairs_epoch_2_20250331_193711.pkl


Pretrain Epoch 3/10: 100%|███████████████████████████████████████████████████████████| 311/311 [00:46<00:00,  6.72it/s]


Pretrain Epoch 3, Avg Loss: 6.6586


Pretrain Epoch 4/10: 100%|███████████████████████████████████████████████████████████| 311/311 [00:53<00:00,  5.82it/s]


Pretrain Epoch 4, Avg Loss: 5.7612


Extracting embeddings at epoch 4: 100%|██████████████████████████████████████████████| 311/311 [00:03<00:00, 94.60it/s]
Extracting negative pairs at epoch 4:  10%|████▍                                      | 32/311 [00:00<00:04, 56.81it/s]


Saved negative pair analysis data to ./embeddings\pretrain_negative_pairs_epoch_4_20250331_193711.pkl


Pretrain Epoch 5/10: 100%|███████████████████████████████████████████████████████████| 311/311 [00:52<00:00,  5.94it/s]


Pretrain Epoch 5, Avg Loss: 4.8598


Pretrain Epoch 6/10: 100%|███████████████████████████████████████████████████████████| 311/311 [00:52<00:00,  5.90it/s]


Pretrain Epoch 6, Avg Loss: 4.1123


Extracting embeddings at epoch 6: 100%|██████████████████████████████████████████████| 311/311 [00:03<00:00, 78.58it/s]
Extracting negative pairs at epoch 6:  10%|████▍                                      | 32/311 [00:00<00:04, 61.43it/s]


Saved negative pair analysis data to ./embeddings\pretrain_negative_pairs_epoch_6_20250331_193711.pkl


Pretrain Epoch 7/10: 100%|███████████████████████████████████████████████████████████| 311/311 [00:51<00:00,  6.03it/s]


Pretrain Epoch 7, Avg Loss: 3.4505


Pretrain Epoch 8/10: 100%|███████████████████████████████████████████████████████████| 311/311 [00:50<00:00,  6.17it/s]


Pretrain Epoch 8, Avg Loss: 3.0032


Extracting embeddings at epoch 8: 100%|██████████████████████████████████████████████| 311/311 [00:03<00:00, 94.07it/s]
Extracting negative pairs at epoch 8:  10%|████▍                                      | 32/311 [00:00<00:05, 55.79it/s]


Saved negative pair analysis data to ./embeddings\pretrain_negative_pairs_epoch_8_20250331_193711.pkl


Pretrain Epoch 9/10: 100%|███████████████████████████████████████████████████████████| 311/311 [00:51<00:00,  6.07it/s]


Pretrain Epoch 9, Avg Loss: 2.7002


Pretrain Epoch 10/10: 100%|██████████████████████████████████████████████████████████| 311/311 [00:53<00:00,  5.85it/s]


Pretrain Epoch 10, Avg Loss: 2.4607


Extracting embeddings at epoch 10: 100%|█████████████████████████████████████████████| 311/311 [00:03<00:00, 87.43it/s]
Extracting negative pairs at epoch 10:  10%|████▎                                     | 32/311 [00:00<00:06, 46.26it/s]


Saved negative pair analysis data to ./embeddings\pretrain_negative_pairs_epoch_10_20250331_193711.pkl

Phase 2: Training GAN-CL...


Train Epoch 1/50: 100%|██████████████████████████████████████████████████████████████| 311/311 [02:06<00:00,  2.45it/s]


Epoch 1, Losses: {'contrastive': 4.650007546139683, 'adversarial': -0.0006456751320428751, 'similarity': 0.0007096580100557213, 'total': 4.6500785128479984}


Extracting embeddings at epoch 1: 100%|██████████████████████████████████████████████| 311/311 [00:04<00:00, 67.54it/s]
Extracting negative pairs at epoch 1:  10%|████▍                                      | 32/311 [00:00<00:07, 39.52it/s]


Saved negative pair analysis data to ./embeddings\negative_pairs_epoch_1_20250331_193711.pkl


Train Epoch 2/50: 100%|██████████████████████████████████████████████████████████████| 311/311 [02:33<00:00,  2.03it/s]


Epoch 2, Losses: {'contrastive': 4.356454311076468, 'adversarial': -0.0007530057715204128, 'similarity': 0.000662174768157459, 'total': 4.356520538544731}


Train Epoch 3/50: 100%|██████████████████████████████████████████████████████████████| 311/311 [03:15<00:00,  1.59it/s]


Epoch 3, Losses: {'contrastive': 4.613133618302667, 'adversarial': -0.0006624814531866198, 'similarity': 0.0007266633356120318, 'total': 4.613206276172994}


Train Epoch 4/50: 100%|██████████████████████████████████████████████████████████████| 311/311 [04:49<00:00,  1.08it/s]


Epoch 4, Losses: {'contrastive': 4.632433862931475, 'adversarial': -0.0006414713104812494, 'similarity': 0.0006437572241473308, 'total': 4.632498221382067}


Train Epoch 5/50: 100%|██████████████████████████████████████████████████████████████| 311/311 [05:36<00:00,  1.08s/it]


Epoch 5, Losses: {'contrastive': 4.598774166352496, 'adversarial': -0.0005984925545177084, 'similarity': 0.0005941969322690507, 'total': 4.598833585475418}


Train Epoch 6/50: 100%|██████████████████████████████████████████████████████████████| 311/311 [06:06<00:00,  1.18s/it]


Epoch 6, Losses: {'contrastive': 4.47531638114782, 'adversarial': -0.000522431707959712, 'similarity': 0.0005265018124870021, 'total': 4.475369037155937}


Train Epoch 7/50: 100%|██████████████████████████████████████████████████████████████| 311/311 [05:25<00:00,  1.05s/it]


Epoch 7, Losses: {'contrastive': 4.426307822347071, 'adversarial': -0.00050665133707959, 'similarity': 0.0005595845621755987, 'total': 4.426363784017287}


Train Epoch 8/50: 100%|██████████████████████████████████████████████████████████████| 311/311 [02:02<00:00,  2.54it/s]


Epoch 8, Losses: {'contrastive': 4.285507311774987, 'adversarial': -0.0005298203441884871, 'similarity': 0.0004854243354635295, 'total': 4.285555858704055}


Train Epoch 9/50: 100%|██████████████████████████████████████████████████████████████| 311/311 [01:56<00:00,  2.66it/s]


Epoch 9, Losses: {'contrastive': 4.160951454156464, 'adversarial': -0.00046670709967229913, 'similarity': 0.0004536035560315529, 'total': 4.160996812716174}


Train Epoch 10/50: 100%|█████████████████████████████████████████████████████████████| 311/311 [01:58<00:00,  2.62it/s]


Epoch 10, Losses: {'contrastive': 4.1450638226757475, 'adversarial': -0.0004581726386198337, 'similarity': 0.0005261295686934107, 'total': 4.145116446485857}


Extracting embeddings at epoch 10: 100%|█████████████████████████████████████████████| 311/311 [00:03<00:00, 96.93it/s]
Extracting negative pairs at epoch 10:  10%|████▎                                     | 32/311 [00:00<00:04, 57.15it/s]


Saved negative pair analysis data to ./embeddings\negative_pairs_epoch_10_20250331_193711.pkl


Train Epoch 11/50: 100%|█████████████████████████████████████████████████████████████| 311/311 [01:57<00:00,  2.66it/s]


Epoch 11, Losses: {'contrastive': 4.134959253850857, 'adversarial': -0.0004531207774807714, 'similarity': 0.0004994010913894749, 'total': 4.135009183761008}


Train Epoch 12/50: 100%|█████████████████████████████████████████████████████████████| 311/311 [02:01<00:00,  2.57it/s]


Epoch 12, Losses: {'contrastive': 4.176114017940411, 'adversarial': -0.0004754435370969691, 'similarity': 0.0004916270252263287, 'total': 4.176163182764575}


Train Epoch 13/50: 100%|█████████████████████████████████████████████████████████████| 311/311 [02:02<00:00,  2.54it/s]


Epoch 13, Losses: {'contrastive': 4.101280538215515, 'adversarial': -0.000569565801685107, 'similarity': 0.0005719165970538399, 'total': 4.101337731842826}


Train Epoch 14/50: 100%|█████████████████████████████████████████████████████████████| 311/311 [02:02<00:00,  2.54it/s]


Epoch 14, Losses: {'contrastive': 4.027249680454708, 'adversarial': -0.00048413297146112565, 'similarity': 0.00048244724746784526, 'total': 4.027297923802563}


Train Epoch 15/50: 100%|█████████████████████████████████████████████████████████████| 311/311 [02:02<00:00,  2.55it/s]


Epoch 15, Losses: {'contrastive': 3.983715280075932, 'adversarial': -0.0005413853508903917, 'similarity': 0.00048125083775370356, 'total': 3.983763398464853}


Train Epoch 16/50: 100%|█████████████████████████████████████████████████████████████| 311/311 [01:59<00:00,  2.60it/s]


Epoch 16, Losses: {'contrastive': 3.995094426958507, 'adversarial': -0.0004970230622751596, 'similarity': 0.000549451202746343, 'total': 3.9951493751581078}


Train Epoch 17/50: 100%|█████████████████████████████████████████████████████████████| 311/311 [02:01<00:00,  2.57it/s]


Epoch 17, Losses: {'contrastive': 3.976814408777611, 'adversarial': -0.00047715088345823135, 'similarity': 0.00054916031332986, 'total': 3.97686933857835}


Train Epoch 18/50: 100%|█████████████████████████████████████████████████████████████| 311/311 [02:09<00:00,  2.40it/s]


Epoch 18, Losses: {'contrastive': 4.002392411615304, 'adversarial': -0.0005299180493034255, 'similarity': 0.000540731310615334, 'total': 4.002446476669557}


Train Epoch 19/50: 100%|█████████████████████████████████████████████████████████████| 311/311 [01:58<00:00,  2.63it/s]


Epoch 19, Losses: {'contrastive': 4.017126626140435, 'adversarial': -0.0004877619279871943, 'similarity': 0.0004780608147174611, 'total': 4.017174429448855}


Train Epoch 20/50: 100%|█████████████████████████████████████████████████████████████| 311/311 [01:57<00:00,  2.66it/s]


Epoch 20, Losses: {'contrastive': 4.027648484975195, 'adversarial': -0.0005008852043057346, 'similarity': 0.0005528638080410792, 'total': 4.027703768187397}


Extracting embeddings at epoch 20: 100%|█████████████████████████████████████████████| 311/311 [00:03<00:00, 96.60it/s]
Extracting negative pairs at epoch 20:  10%|████▎                                     | 32/311 [00:00<00:05, 53.84it/s]


Saved negative pair analysis data to ./embeddings\negative_pairs_epoch_20_20250331_193711.pkl


Train Epoch 21/50: 100%|█████████████████████████████████████████████████████████████| 311/311 [06:54<00:00,  1.33s/it]


Epoch 21, Losses: {'contrastive': 4.0083982254531225, 'adversarial': -0.0004983052254431369, 'similarity': 0.0004959777937123966, 'total': 4.00844783108334}


Train Epoch 22/50: 100%|█████████████████████████████████████████████████████████████| 311/311 [02:48<00:00,  1.85it/s]


Epoch 22, Losses: {'contrastive': 3.9837480825626583, 'adversarial': -0.0004975314406862249, 'similarity': 0.0005473980195917656, 'total': 3.9838028161088754}


Train Epoch 23/50: 100%|█████████████████████████████████████████████████████████████| 311/311 [02:03<00:00,  2.52it/s]


Epoch 23, Losses: {'contrastive': 3.9427213806814705, 'adversarial': -0.00045734769091986335, 'similarity': 0.0005088170217926587, 'total': 3.9427722535332683}


Train Epoch 24/50: 100%|█████████████████████████████████████████████████████████████| 311/311 [02:22<00:00,  2.19it/s]


Epoch 24, Losses: {'contrastive': 3.8889980063177765, 'adversarial': -0.0004569157911988628, 'similarity': 0.0004630021944079081, 'total': 3.889044303219418}


Train Epoch 25/50: 100%|█████████████████████████████████████████████████████████████| 311/311 [02:22<00:00,  2.18it/s]


Epoch 25, Losses: {'contrastive': 3.813388456102353, 'adversarial': -0.0004601707330895917, 'similarity': 0.0004613022847771884, 'total': 3.813434593163886}


Train Epoch 26/50: 100%|█████████████████████████████████████████████████████████████| 311/311 [02:11<00:00,  2.36it/s]


Epoch 26, Losses: {'contrastive': 3.804657489923802, 'adversarial': -0.00045800589884472623, 'similarity': 0.00046401393938198734, 'total': 3.8047038949187546}


Train Epoch 27/50: 100%|█████████████████████████████████████████████████████████████| 311/311 [02:04<00:00,  2.50it/s]


Epoch 27, Losses: {'contrastive': 3.721645618558313, 'adversarial': -0.0004444861158704643, 'similarity': 0.0004358659771651365, 'total': 3.721689204311064}


Train Epoch 28/50: 100%|█████████████████████████████████████████████████████████████| 311/311 [02:15<00:00,  2.29it/s]


Epoch 28, Losses: {'contrastive': 3.687793789185895, 'adversarial': -0.0004629484705149325, 'similarity': 0.0004716604817778447, 'total': 3.6878409462343074}


Train Epoch 29/50: 100%|█████████████████████████████████████████████████████████████| 311/311 [02:22<00:00,  2.19it/s]


Epoch 29, Losses: {'contrastive': 3.6240547140311583, 'adversarial': -0.0004826019085918245, 'similarity': 0.00047545919870101275, 'total': 3.6241022559224216}


Train Epoch 30/50: 100%|█████████████████████████████████████████████████████████████| 311/311 [02:39<00:00,  1.95it/s]


Epoch 30, Losses: {'contrastive': 3.596701659573619, 'adversarial': -0.0005277804421130687, 'similarity': 0.00048209648473424856, 'total': 3.596749872256705}


Extracting embeddings at epoch 30: 100%|█████████████████████████████████████████████| 311/311 [00:04<00:00, 69.26it/s]
Extracting negative pairs at epoch 30:  10%|████▎                                     | 32/311 [00:00<00:07, 37.34it/s]


Saved negative pair analysis data to ./embeddings\negative_pairs_epoch_30_20250331_193711.pkl


Train Epoch 31/50: 100%|█████████████████████████████████████████████████████████████| 311/311 [02:31<00:00,  2.05it/s]


Epoch 31, Losses: {'contrastive': 3.5314064531847595, 'adversarial': -0.0004865092748867414, 'similarity': 0.00047441064318345553, 'total': 3.531453895415524}


Train Epoch 32/50: 100%|█████████████████████████████████████████████████████████████| 311/311 [05:08<00:00,  1.01it/s]


Epoch 32, Losses: {'contrastive': 3.493494133090666, 'adversarial': -0.00047735118041776264, 'similarity': 0.0005388090579203028, 'total': 3.4935480137729953}


Train Epoch 33/50: 100%|█████████████████████████████████████████████████████████████| 311/311 [09:01<00:00,  1.74s/it]


Epoch 33, Losses: {'contrastive': 3.421871364691633, 'adversarial': -0.000525836481329449, 'similarity': 0.0005132169546963725, 'total': 3.421922686782297}


Train Epoch 34/50: 100%|█████████████████████████████████████████████████████████████| 311/311 [04:22<00:00,  1.19it/s]


Epoch 34, Losses: {'contrastive': 3.4515383059572176, 'adversarial': -0.0005597971160589176, 'similarity': 0.0005266584120571039, 'total': 3.451590972698003}


Train Epoch 35/50: 100%|█████████████████████████████████████████████████████████████| 311/311 [07:09<00:00,  1.38s/it]


Epoch 35, Losses: {'contrastive': 3.4293617527584552, 'adversarial': -0.0005403732826859897, 'similarity': 0.0005542252584655508, 'total': 3.4294171793284525}


Train Epoch 36/50: 100%|█████████████████████████████████████████████████████████████| 311/311 [04:59<00:00,  1.04it/s]


Epoch 36, Losses: {'contrastive': 3.456750283287269, 'adversarial': -0.000572314433408362, 'similarity': 0.0005501765210160513, 'total': 3.456805305082315}


Train Epoch 37/50: 100%|█████████████████████████████████████████████████████████████| 311/311 [02:35<00:00,  2.00it/s]


Epoch 37, Losses: {'contrastive': 3.4941857999544053, 'adversarial': -0.0005396291111233581, 'similarity': 0.0005420381426953462, 'total': 3.494239997633784}


Train Epoch 38/50: 100%|█████████████████████████████████████████████████████████████| 311/311 [03:28<00:00,  1.49it/s]


Epoch 38, Losses: {'contrastive': 3.5461561396190975, 'adversarial': -0.0005320109133316666, 'similarity': 0.0005594873521065225, 'total': 3.5462120867235485}


Train Epoch 39/50: 100%|█████████████████████████████████████████████████████████████| 311/311 [08:53<00:00,  1.71s/it]


Epoch 39, Losses: {'contrastive': 3.5364475553058736, 'adversarial': -0.0005654316817386333, 'similarity': 0.0005660411168842118, 'total': 3.536504158253072}


Train Epoch 40/50: 100%|█████████████████████████████████████████████████████████████| 311/311 [06:30<00:00,  1.26s/it]


Epoch 40, Losses: {'contrastive': 3.500587451112999, 'adversarial': -0.0006272010603459953, 'similarity': 0.0005855924147405809, 'total': 3.5006460108557698}


Extracting embeddings at epoch 40: 100%|█████████████████████████████████████████████| 311/311 [00:18<00:00, 16.97it/s]
Extracting negative pairs at epoch 40:  10%|████▎                                     | 32/311 [00:03<00:31,  8.94it/s]


Saved negative pair analysis data to ./embeddings\negative_pairs_epoch_40_20250331_193711.pkl


Train Epoch 41/50: 100%|█████████████████████████████████████████████████████████████| 311/311 [02:59<00:00,  1.74it/s]


Epoch 41, Losses: {'contrastive': 3.4838734999346963, 'adversarial': -0.000624191456072338, 'similarity': 0.0005902654563048345, 'total': 3.4839325265485757}


Train Epoch 42/50: 100%|█████████████████████████████████████████████████████████████| 311/311 [02:38<00:00,  1.97it/s]


Epoch 42, Losses: {'contrastive': 3.470278547891083, 'adversarial': -0.0006347023576896314, 'similarity': 0.000615024755853363, 'total': 3.4703400526016086}


Train Epoch 43/50: 100%|█████████████████████████████████████████████████████████████| 311/311 [02:37<00:00,  1.97it/s]


Epoch 43, Losses: {'contrastive': 3.514371895713438, 'adversarial': -0.0006242347676254557, 'similarity': 0.0006617204901175364, 'total': 3.514438069518356}


Train Epoch 44/50: 100%|█████████████████████████████████████████████████████████████| 311/311 [02:37<00:00,  1.98it/s]


Epoch 44, Losses: {'contrastive': 3.4682931800363916, 'adversarial': -0.0006079022074305419, 'similarity': 0.0006135741312923718, 'total': 3.4683545379393355}


Train Epoch 45/50: 100%|█████████████████████████████████████████████████████████████| 311/311 [02:36<00:00,  1.99it/s]


Epoch 45, Losses: {'contrastive': 3.4584362345876416, 'adversarial': -0.0006512198789970548, 'similarity': 0.0005930227152082375, 'total': 3.4584955348845847}


Train Epoch 46/50: 100%|█████████████████████████████████████████████████████████████| 311/311 [02:36<00:00,  1.99it/s]


Epoch 46, Losses: {'contrastive': 3.487029386103345, 'adversarial': -0.0006287259415055586, 'similarity': 0.0006666966572469789, 'total': 3.48709605744414}


Train Epoch 47/50: 100%|█████████████████████████████████████████████████████████████| 311/311 [02:37<00:00,  1.97it/s]


Epoch 47, Losses: {'contrastive': 3.471534126441195, 'adversarial': -0.0006321531115361815, 'similarity': 0.0006401651450754758, 'total': 3.4715981483459473}


Train Epoch 48/50: 100%|█████████████████████████████████████████████████████████████| 311/311 [02:41<00:00,  1.93it/s]


Epoch 48, Losses: {'contrastive': 3.4073390516054207, 'adversarial': -0.0006600061508825414, 'similarity': 0.0006619525241959014, 'total': 3.4074052476422962}


Train Epoch 49/50: 100%|█████████████████████████████████████████████████████████████| 311/311 [02:57<00:00,  1.75it/s]


Epoch 49, Losses: {'contrastive': 3.4287216395043867, 'adversarial': -0.0006740878495594361, 'similarity': 0.0006929966992619252, 'total': 3.428790948015317}


Train Epoch 50/50: 100%|█████████████████████████████████████████████████████████████| 311/311 [05:12<00:00,  1.00s/it]


Epoch 50, Losses: {'contrastive': 3.3692102225263785, 'adversarial': -0.0007152919725094206, 'similarity': 0.0007215246756207737, 'total': 3.36928238056097}


Extracting embeddings at epoch 50: 100%|█████████████████████████████████████████████| 311/311 [00:15<00:00, 19.94it/s]
Extracting negative pairs at epoch 50:  10%|████▎                                     | 32/311 [00:01<00:10, 27.00it/s]


Saved negative pair analysis data to ./embeddings\negative_pairs_epoch_50_20250331_193711.pkl
Extracting post-training embeddings...


Post-training embeddings: 100%|██████████████████████████████████████████████████████| 311/311 [00:10<00:00, 30.48it/s]


6. Training completed!
7. Extracting final embeddings for XAI...


Extracting embeddings: 100%|█████████████████████████████████████████████████████████| 311/311 [00:24<00:00, 12.76it/s]


Saved embeddings and molecule data with graph properties to ./embeddings/final_embeddings_molecules_ME_20250331_191547.pkl
8. Final embeddings saved to ./embeddings/final_embeddings_molecules_ME_20250331_191547.pkl
9. Encoders saved in ./checkpoints/encoders/:
   - Best encoder: best_encoder.pt
   - Final encoder: final_encoder.pt
   - Periodic encoders: encoder_epoch_*.pt
