In [1]:
import torch
from rdkit import Chem
from rdkit.Chem import AllChem, Descriptors
from typing import Tuple
import numpy as np
from torch_geometric.data import Data
import random
from torch_geometric.loader import DataLoader
from torch_geometric.data import Data
from rdkit import RDLogger
from rdkit.Chem import RemoveHs
from datetime import datetime

# Suppress RDKit warnings
RDLogger.DisableLog('rdApp.warning')
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')



In [2]:
class MolecularFeatureExtractor:
    def __init__(self):
        self.atom_list = list(range(1, 119))
        self.chirality_list = [
            Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
            Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
            Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
            Chem.rdchem.ChiralType.CHI_OTHER
        ]
        self.bond_list = [
            Chem.rdchem.BondType.SINGLE,
            Chem.rdchem.BondType.DOUBLE, 
            Chem.rdchem.BondType.TRIPLE,
            Chem.rdchem.BondType.AROMATIC
        ]
        self.bonddir_list = [
            Chem.rdchem.BondDir.NONE,
            Chem.rdchem.BondDir.ENDUPRIGHT,
            Chem.rdchem.BondDir.ENDDOWNRIGHT
        ]

    def calc_atom_features(self, atom: Chem.Atom) -> Tuple[list, list]:
        """Calculate atom features with better error handling"""
        try:
            # Basic features
            atom_feat = [
                self.atom_list.index(atom.GetAtomicNum()),
                self.chirality_list.index(atom.GetChiralTag())
            ]

            # Physical features with error handling
            phys_feat = []
            
            # Molecular weight contribution
            try:
                contrib_mw = Descriptors.ExactMolWt(Chem.MolFromSmiles(f'[{atom.GetSymbol()}]'))
                phys_feat.append(contrib_mw)
            except:
                phys_feat.append(0.0)
                
            # LogP contribution    
            try:
                contrib_logp = Descriptors.MolLogP(Chem.MolFromSmiles(f'[{atom.GetSymbol()}]'))
                phys_feat.append(contrib_logp)
            except:
                phys_feat.append(0.0)
                
            # Add other physical properties
            phys_feat.extend([
                atom.GetFormalCharge(),
                int(atom.GetHybridization()),
                int(atom.GetIsAromatic()),
                atom.GetTotalNumHs(),
                atom.GetTotalValence(),
                atom.GetDegree()
            ])
            
            return atom_feat, phys_feat
            
        except Exception as e:
            print(f"Error calculating atom features: {e}")
            return [0, 0], [0.0] * 9

    def get_atom_features(self, mol: Chem.Mol) -> Tuple[torch.Tensor, torch.Tensor]:
        """Extract atom features for the whole molecule"""
        atom_feats = []
        phys_feats = []
        
        if mol is None:
            return torch.tensor([[0, 0]], dtype=torch.long), torch.tensor([[0.0] * 9], dtype=torch.float)
            
        for atom in mol.GetAtoms():
            atom_feat, phys_feat = self.calc_atom_features(atom)
            atom_feats.append(atom_feat)
            phys_feats.append(phys_feat)

        x = torch.tensor(atom_feats, dtype=torch.long)
        phys = torch.tensor(phys_feats, dtype=torch.float)
        
        return x, phys
    
    def remove_unbonded_hydrogens(mol):
        params = Chem.RemoveHsParameters()
        params.removeDegreeZero = True
        mol = Chem.RemoveHs(mol, params)
        return mol


    def get_bond_features(self, mol: Chem.Mol) -> Tuple[torch.Tensor, torch.Tensor]:
        """Extract bond features with better error handling"""
        if mol is None:
            return torch.tensor([[0], [0]], dtype=torch.long), torch.tensor([[0.0] * 5], dtype=torch.float)
            
        row, col, edge_feat = [], [], []
        
        for bond in mol.GetBonds():
            try:
                start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
                
                # Add edges in both directions
                row += [start, end]
                col += [end, start]
                
                # Bond features
                bond_type = self.bond_list.index(bond.GetBondType())
                bond_dir = self.bonddir_list.index(bond.GetBondDir())
                
                # Calculate additional properties
                feat = [
                    bond_type,
                    bond_dir,
                    int(bond.GetIsConjugated()),
                    int(self._is_rotatable(bond)),
                    self._get_bond_length(mol, start, end)
                ]
                
                edge_feat.extend([feat, feat])
                
            except Exception as e:
                print(f"Error processing bond: {e}")
                continue

        if not row:  # If no valid bonds were processed
            return torch.tensor([[0], [0]], dtype=torch.long), torch.tensor([[0.0] * 5], dtype=torch.float)

        edge_index = torch.tensor([row, col], dtype=torch.long)
        edge_attr = torch.tensor(edge_feat, dtype=torch.float)
        
        return edge_index, edge_attr

    def _is_rotatable(self, bond: Chem.Bond) -> bool:
        """Check if bond is rotatable"""
        return (bond.GetBondType() == Chem.rdchem.BondType.SINGLE and 
                not bond.IsInRing() and
                len(bond.GetBeginAtom().GetNeighbors()) > 1 and
                len(bond.GetEndAtom().GetNeighbors()) > 1)

    def _get_bond_length(self, mol: Chem.Mol, start: int, end: int) -> float:
        """Get bond length with error handling"""
        try:
            conf = mol.GetConformer()
            if conf.Is3D():
                return Chem.rdMolTransforms.GetBondLength(conf, start, end)
        except:
            pass
        return 0.0

    def process_molecule(self, smiles: str) -> Data:
        """Process SMILES string to graph data"""
        try:
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                print(f"Invalid SMILES: {smiles}")
                return None  # Skip invalid molecules
            mol = RemoveHs(mol)

            # Add explicit hydrogens
            mol = Chem.AddHs(mol, addCoords=True)

            # Sanitize molecule
            Chem.SanitizeMol(mol)

            # Check if the molecule has atoms
            if mol.GetNumAtoms() == 0:
                print("Molecule has no atoms, skipping.")
                return None

            # Generate 3D coordinates
            if not mol.GetNumConformers():
                status = AllChem.EmbedMolecule(mol, AllChem.ETKDG())
                if status != 0:
                    print("Failed to generate 3D conformer")
                    return None  # Skip failed molecules

                # Try MMFF or UFF optimization
                try:
                    AllChem.MMFFOptimizeMolecule(mol)
                except:
                    AllChem.UFFOptimizeMolecule(mol)

            # Extract features
            x_cat, x_phys = self.get_atom_features(mol)
            edge_index, edge_attr = self.get_bond_features(mol)

            # Save all necessary information
            return Data(
                x_cat=x_cat, 
                x_phys=x_phys,
                edge_index=edge_index, 
                edge_attr=edge_attr,
                num_nodes=x_cat.size(0),
                smiles=smiles,  # Store original SMILES
                mol=mol  # Store RDKit mol object for later use
            )

        except Exception as e:
            print(f"Error processing molecule {smiles}: {e}")
            return None


In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool, MessagePassing
from typing import Tuple, List, Optional
import copy
from dataclasses import dataclass

class MemoryQueue:
    """Memory queue with temporal decay for contrastive learning"""
    def __init__(self, size: int, dim: int, decay: float = 0.99999):
        self.size = size
        self.dim = dim
        self.decay = decay
        self.ptr = 0
        self.full = False
        
        # Initialize queue
        self.queue = nn.Parameter(F.normalize(torch.randn(size, dim), dim=1), requires_grad=False)
        self.queue_age = nn.Parameter(torch.zeros(size), requires_grad=False)
        
#         self.register_buffer("queue", torch.randn(size, dim))
#         self.register_buffer("queue_age", torch.zeros(size))  # Track age of each entry
        self.queue = F.normalize(self.queue, dim=1)
        
    def update_queue(self, keys: torch.Tensor):
        """Update queue with new keys"""
        batch_size = keys.shape[0]
        
        # Increment age of all entries
        self.queue_age += 1
        
        # Add new keys
        if self.ptr + batch_size <= self.size:
            self.queue[self.ptr:self.ptr + batch_size] = keys
            self.queue_age[self.ptr:self.ptr + batch_size] = 0
        else:
            # Handle overflow
            rem = self.size - self.ptr
            self.queue[self.ptr:] = keys[:rem]
            self.queue[:batch_size-rem] = keys[rem:]
            self.queue_age[self.ptr:] = 0
            self.queue_age[:batch_size-rem] = 0
            self.full = True
            
        self.ptr = (self.ptr + batch_size) % self.size
        
    def get_decay_weights(self) -> torch.Tensor:
        """Get temporal decay weights for queue entries"""
        return self.decay ** self.queue_age
        
    def compute_contrastive_loss(self, query: torch.Tensor, positive_key: torch.Tensor, 
                                temperature: float = 0.07) -> torch.Tensor:
        """Compute contrastive loss with temporal decay"""
        # Normalize embeddings
        query = F.normalize(query, dim=1)
        positive_key = F.normalize(positive_key, dim=1)
        queue = F.normalize(self.queue, dim=1)
        
        # Compute logits
        l_pos = torch.einsum('nc,nc->n', [query, positive_key]).unsqueeze(-1)
        l_neg = torch.einsum('nc,ck->nk', [query, queue.T])
        
        # Apply temporal decay to negative samples
        decay_weights = self.get_decay_weights()
        l_neg = l_neg * decay_weights.unsqueeze(0)
        
        # Temperature scaling
        logits = torch.cat([l_pos, l_neg], dim=1) / temperature
        labels = torch.zeros(logits.shape[0], dtype=torch.long, device=query.device)
        
        return F.cross_entropy(logits, labels)

class GraphGenerator(nn.Module):
    """Generator network with proper feature handling"""
    def __init__(self, node_dim: int, edge_dim: int, hidden_dim: int = 128):
        super().__init__()
        
        # Node feature processing
        self.node_encoder = nn.Sequential(
            nn.Linear(node_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        
        # Edge feature processing
        self.edge_encoder = nn.Sequential(
            nn.Linear(edge_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        
        # Graph convolution layers
        self.conv1 = GCNConv(hidden_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.conv3 = GCNConv(hidden_dim, hidden_dim)
        
        # Importance prediction layers
        self.node_importance = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, 1),
            nn.Sigmoid()
        )
        
        self.edge_importance = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )
        
    def normalize_features(self, x_cat, x_phys):
        """Normalize categorical and physical features separately"""
        # Convert categorical features to one-hot
        x_cat = x_cat.float()
        
        # Normalize physical features
        x_phys = x_phys.float()
        if x_phys.size(0) > 1:  # Only normalize if we have more than one sample
            x_phys = (x_phys - x_phys.mean(0)) / (x_phys.std(0) + 1e-5)
            
        return x_cat, x_phys
        
    def forward(self, data) -> Tuple[torch.Tensor, torch.Tensor]:
        # Normalize features
        x_cat, x_phys = self.normalize_features(data.x_cat, data.x_phys)
        
        # Concatenate features
        x = torch.cat([x_cat, x_phys], dim=-1)
        
        edge_index = data.edge_index
        edge_attr = data.edge_attr.float()  # Ensure float type
        
        # Initial feature encoding
        x = self.node_encoder(x)
        edge_attr = self.edge_encoder(edge_attr)
        
        # Graph convolutions
        x = F.relu(self.conv1(x, edge_index))  # Removed edge_attr from GCNConv
        x = F.relu(self.conv2(x, edge_index))
        x = self.conv3(x, edge_index)
        
        # Predict importance scores
        node_scores = self.node_importance(x)
        
        # Edge scores using both connected nodes
        edge_features = torch.cat([
            x[edge_index[0]], 
            x[edge_index[1]]
        ], dim=-1)
        edge_scores = self.edge_importance(edge_features)
        
        return node_scores, edge_scores

def get_model_config(dataset):
    """Get model configuration based on dataset features"""
    sample_data = dataset[0]
    
    # Calculate input dimensions
    node_dim = sample_data.x_cat.shape[1] + sample_data.x_phys.shape[1]
    edge_dim = sample_data.edge_attr.shape[1]
    
    config = GanClConfig(
        node_dim=node_dim,
        edge_dim=edge_dim,
        hidden_dim=128,
        output_dim=128,
        queue_size=65536,
        momentum=0.999,
        temperature=0.07,
        decay=0.99999,
        dropout_ratio=0.25
    )
    
    return config

class GraphDiscriminator(nn.Module):
    """Discriminator/Encoder network"""
    def __init__(self, node_dim: int, edge_dim: int, hidden_dim: int = 128, output_dim: int = 128):
        super().__init__()
        
        # Feature encoding
        self.node_encoder = nn.Sequential(
            nn.Linear(node_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        
        self.edge_encoder = nn.Sequential(
            nn.Linear(edge_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        
        # Graph convolution layers
        self.conv1 = GCNConv(hidden_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.conv3 = GCNConv(hidden_dim, output_dim)
        
        # Projection head for contrastive learning
        self.projection = nn.Sequential(
            nn.Linear(output_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )
        
    def normalize_features(self, x_cat, x_phys):
        """Normalize categorical and physical features separately"""
        # Convert categorical features to one-hot
        x_cat = x_cat.float()
        
        # Normalize physical features
        x_phys = x_phys.float()
        if x_phys.size(0) > 1:  # Only normalize if we have more than one sample
            x_phys = (x_phys - x_phys.mean(0)) / (x_phys.std(0) + 1e-5)
            
        return x_cat, x_phys 
        
    def forward(self, data):
        # Normalize features
        x_cat, x_phys = self.normalize_features(data.x_cat, data.x_phys)
        
        # Concatenate features
        x = torch.cat([x_cat, x_phys], dim=-1)
        
        edge_index = data.edge_index
        edge_attr = data.edge_attr.float()  # Ensure float type
        batch = data.batch
        
        # Initial feature encoding
        x = self.node_encoder(x)
        edge_attr = self.edge_encoder(edge_attr)
        
        # Graph convolutions
        x = F.relu(self.conv1(x, edge_index))  # Removed edge_attr from GCNConv
        x = F.relu(self.conv2(x, edge_index))
        x = self.conv3(x, edge_index)
        
        # Global pooling
        x = global_mean_pool(x, batch)
        
        # Projection
        x = self.projection(x)
        
        return x

@dataclass
class GanClConfig:
    """Configuration for GAN-CL training"""
    node_dim: int
    edge_dim: int
    hidden_dim: int = 128
    output_dim: int = 128
    queue_size: int = 65536
    momentum: float = 0.999
    temperature: float = 0.07
    decay: float = 0.99999
    dropout_ratio: float = 0.25

class MolecularGANCL(nn.Module):
    """Combined GAN and Contrastive Learning framework"""
    def __init__(self, config: GanClConfig):
        super().__init__()
        self.config = config
        
        # Add weight initialization
        def init_weights(m):
            if isinstance(m, nn.Linear):
                torch.nn.init.xavier_uniform_(m.weight)
                m.bias.data.fill_(0.01)
        
        # Initialize networks
        self.generator = GraphGenerator(
            config.node_dim, 
            config.edge_dim, 
            config.hidden_dim * 2
        )
        
        self.encoder = GraphDiscriminator(
            config.node_dim,
            config.edge_dim,
            config.hidden_dim,
            config.output_dim
        )
        self.encoder.apply(init_weights)
        
        # Modified loss weights
        self.contrastive_weight = 1.0
        self.adversarial_weight = 0.1  # Increased from 0.05
        self.similarity_weight = 0.01  # Decreased from 0.1
        
        # Temperature annealing
        self.initial_temperature = 0.1
        self.min_temperature = 0.05        
        
        # Create momentum encoder
        self.momentum_encoder = copy.deepcopy(self.encoder)
        for param in self.momentum_encoder.parameters():
            param.requires_grad = False
            
        # Initialize memory queue
        self.memory_queue = MemoryQueue(
            config.queue_size,
            config.output_dim,
            config.decay
        )
        
    @torch.no_grad()
    def _momentum_update(self):
        """Update momentum encoder"""
        for param_q, param_k in zip(self.encoder.parameters(), 
                                  self.momentum_encoder.parameters()):
            param_k.data = self.config.momentum * param_k.data + \
                          (1 - self.config.momentum) * param_q.data
                          
    def drop_graph_elements(self, data, node_scores: torch.Tensor, 
                          edge_scores: torch.Tensor) -> Data:
        """Apply dropout to graph based on importance scores"""
        # Select elements to keep based on scores and dropout ratio
#         node_mask = (node_scores < self.config.dropout_ratio).float()
#         edge_mask = (edge_scores < self.config.dropout_ratio).float()

        node_mask = (torch.rand_like(node_scores) > self.config.dropout_ratio).float()
        edge_mask = (torch.rand_like(edge_scores) > self.config.dropout_ratio).float()
        
        # Apply masks
        x_cat_new = data.x_cat * node_mask
        x_phys_new = data.x_phys * node_mask
        edge_attr_new = data.edge_attr * edge_mask
        
        # Create new graph data object
        return Data(
            x_cat=x_cat_new,
            x_phys=x_phys_new,
            edge_index=data.edge_index,
            edge_attr=edge_attr_new,
            batch=data.batch
        )
        
    def get_temperature(self, epoch, total_epochs):
        """Anneal temperature during training"""
        progress = epoch / total_epochs
        return max(self.initial_temperature * (1 - progress), self.min_temperature)
    
    def forward(self, data, epoch=0, total_epochs=50):
        # Get current temperature
        temperature = self.get_temperature(epoch, total_epochs)
        
        # Get importance scores from generator
        node_scores, edge_scores = self.generator(data)
        
        # Create perturbed graph
        perturbed_data = self.drop_graph_elements(data, node_scores, edge_scores)
        
        # Get embeddings
        query_emb = self.encoder(perturbed_data)
        with torch.no_grad():
            key_emb = self.momentum_encoder(data)
            original_emb = self.encoder(data).detach()
        
        # Compute losses with modified weights
        contrastive_loss = self.memory_queue.compute_contrastive_loss(
            query_emb, key_emb, temperature
        ) * self.contrastive_weight
        
        adversarial_loss = -F.mse_loss(query_emb, original_emb) * self.adversarial_weight
        similarity_loss = F.mse_loss(query_emb, original_emb) * self.similarity_weight
        
        return contrastive_loss, adversarial_loss, similarity_loss
    
    def get_embeddings(self, data) -> torch.Tensor:
        """Get embeddings for downstream tasks"""
        with torch.no_grad():
            return self.encoder(data)

In [4]:
import torch
import numpy as np
from torch_geometric.data import DataLoader
import os
import json
from tqdm import tqdm
import pickle
from datetime import datetime

def calculate_molecular_features(mol):
    """Calculate comprehensive molecular features"""
    if mol is None:
        return None
        
    features = {
        # Graph-level properties
        'MW': Descriptors.ExactMolWt(mol),
        'LogP': Descriptors.MolLogP(mol),
        'TPSA': Descriptors.TPSA(mol),
        
        # Ring information
        'num_rings': Chem.rdMolDescriptors.CalcNumRings(mol),
        'num_aromatic_rings': Chem.rdMolDescriptors.CalcNumAromaticRings(mol),
        'num_aliphatic_rings': Chem.rdMolDescriptors.CalcNumAliphaticRings(mol),
        
        # Atom counts
        'num_atoms': mol.GetNumAtoms(),
        'num_heavy_atoms': mol.GetNumHeavyAtoms(),
        'num_rotatable_bonds': Chem.rdMolDescriptors.CalcNumRotatableBonds(mol),
        
        # Element counts
        'num_C': sum(1 for atom in mol.GetAtoms() if atom.GetAtomicNum() == 6),
        'num_N': sum(1 for atom in mol.GetAtoms() if atom.GetAtomicNum() == 7),
        'num_O': sum(1 for atom in mol.GetAtoms() if atom.GetAtomicNum() == 8),
        'num_F': sum(1 for atom in mol.GetAtoms() if atom.GetAtomicNum() == 9),
        'num_P': sum(1 for atom in mol.GetAtoms() if atom.GetAtomicNum() == 15),
        'num_S': sum(1 for atom in mol.GetAtoms() if atom.GetAtomicNum() == 16),
        'num_Cl': sum(1 for atom in mol.GetAtoms() if atom.GetAtomicNum() == 17),
        'num_Br': sum(1 for atom in mol.GetAtoms() if atom.GetAtomicNum() == 35),
        'num_I': sum(1 for atom in mol.GetAtoms() if atom.GetAtomicNum() == 53),
        
        # Functional groups using SMARTS patterns
        'has_alcohol': len(mol.GetSubstructMatches(Chem.MolFromSmarts('[OH]'))),
        'has_amine': len(mol.GetSubstructMatches(Chem.MolFromSmarts('[NH2]'))),
        'has_carboxyl': len(mol.GetSubstructMatches(Chem.MolFromSmarts('[CX3](=O)[OX2H1]'))),
        'has_carbonyl': len(mol.GetSubstructMatches(Chem.MolFromSmarts('[CX3]=O'))),
        'has_ether': len(mol.GetSubstructMatches(Chem.MolFromSmarts('[OR]'))),
        'has_ester': len(mol.GetSubstructMatches(Chem.MolFromSmarts('[#6][CX3](=O)[OX2H0][#6]'))),
        'has_amide': len(mol.GetSubstructMatches(Chem.MolFromSmarts('[NX3][CX3](=[OX1])'))),
        'has_halogen': len(mol.GetSubstructMatches(Chem.MolFromSmarts('[F,Cl,Br,I]')))
    }
    
    return features

def save_embeddings(embeddings, graphs, filepath, smiles_list=None):
    """Save embeddings with comprehensive molecular details"""
    # SMARTS patterns for substructure matching
    patterns = {
        # Ring systems
        'aromatic_ring': '[a;r6]1:a:a:a:a:a:1',
        'heterocycle': '[!#6;!#1;R]',  # Any non-carbon, non-hydrogen ring atom
        'spiro': '[D4R]',
        'bridged': '[R2]([R2])([R2])[R2]',
        'macrocycle': '[r{8,}]',
        
        # Functional groups
        'alcohol': '[OH]',
        'phenol': '[OH]c1ccccc1',
        'amine': '[NH2]',
        'carboxyl': '[CX3](=O)[OX2H1]',
        'carbonyl': '[CX3]=O',
        'ether': '[OR]',
        'ester': '[#6][CX3](=O)[OX2H0][#6]',
        'amide': '[NX3][CX3](=[OX1])',
        'sulfonamide': '[#16X4]([NX3])(=[OX1])(=[OX1])',
        'halogen': '[F,Cl,Br,I]'
    }
    
    molecular_details = []
    print("\nProcessing molecular details...")
    
    for graph in tqdm(graphs):
        try:
            # Convert graph to SMILES if not provided
            smiles = graph.smiles if hasattr(graph, 'smiles') else ''
            mol = Chem.MolFromSmiles(smiles) if smiles else None
            
            if mol is not None:
                # Basic molecular properties
                props = {
                    'smiles': smiles,
                    'basic_properties': {
                        'MW': Descriptors.ExactMolWt(mol),
                        'LogP': Descriptors.MolLogP(mol),
                        'TPSA': Descriptors.TPSA(mol),
                        'num_atoms': mol.GetNumAtoms(),
                        'num_bonds': mol.GetNumBonds(),
                        'num_rotatable_bonds': Descriptors.NumRotatableBonds(mol),
                        'num_h_acceptors': Descriptors.NumHAcceptors(mol),
                        'num_h_donors': Descriptors.NumHDonors(mol)
                    },
                    
                    # Ring information
                    'ring_info': {
                        'total_rings': Chem.rdMolDescriptors.CalcNumRings(mol),
                        'aromatic_rings': Chem.rdMolDescriptors.CalcNumAromaticRings(mol),
                        'aliphatic_rings': Chem.rdMolDescriptors.CalcNumAliphaticRings(mol),
                        'spiro_atoms': len(Chem.GetSpiroAtoms(mol)),
                        'bridgeheads': len(Chem.FindMolChiralCenters(mol, includeUnassigned=True))
                    },
                    
                    # Ring sizes
                    'ring_sizes': {},
                    
                    # Element counts
                    'element_counts': {},
                    
                    # Structural features
                    'structural_features': {},
                    
                    # Functional groups
                    'functional_groups': {}
                }
                
                # Get ring sizes
                sssr = Chem.GetSymmSSSR(mol)
                for ring in sssr:
                    size = len(ring)
                    props['ring_sizes'][str(size)] = props['ring_sizes'].get(str(size), 0) + 1
                
                # Count elements
                for atom in mol.GetAtoms():
                    symbol = atom.GetSymbol()
                    props['element_counts'][symbol] = props['element_counts'].get(symbol, 0) + 1
                
                # Match patterns
                for name, smarts in patterns.items():
                    pattern = Chem.MolFromSmarts(smarts)
                    if pattern:
                        matches = len(mol.GetSubstructMatches(pattern))
                        if name in ['aromatic_ring', 'heterocycle', 'spiro', 'bridged', 'macrocycle']:
                            props['structural_features'][name] = matches
                        else:
                            props['functional_groups'][name] = matches
                
                molecular_details.append(props)
            else:
                molecular_details.append(None)
                
        except Exception as e:
            print(f"Error processing molecule: {e}")
            molecular_details.append(None)
    
    # Save all data
    save_data = {
        'embeddings': embeddings,
        'molecular_details': molecular_details,
        'original_graphs': graphs
    }
    
    with open(filepath, 'wb') as f:
        pickle.dump(save_data, f)
    
    print(f"\nSaved {len(embeddings)} embeddings with molecular details")
    
    # Print summary
    valid_mols = sum(1 for x in molecular_details if x is not None)
    print(f"Successfully processed {valid_mols} molecules")
    
    return molecular_details  # Return for potential validation

def save_encoder(encoder, save_path, info=None):
    """Save encoder model for downstream tasks"""
    save_dict = {
        'encoder_state_dict': encoder.state_dict(),
        'model_info': info or {}
    }
    torch.save(save_dict, save_path)

def load_encoder(model_path, device='cpu'):
    """Load saved encoder model"""
    checkpoint = torch.load(model_path, map_location=device)
    encoder = GraphDiscriminator(
        node_dim=checkpoint['model_info'].get('node_dim'),
        edge_dim=checkpoint['model_info'].get('edge_dim'),
        hidden_dim=checkpoint['model_info'].get('hidden_dim', 128),
        output_dim=checkpoint['model_info'].get('output_dim', 128)
    )
    encoder.load_state_dict(checkpoint['encoder_state_dict'])
    return encoder        
        
def train_gan_cl(train_loader, config, device='cuda', 
                save_dir='./checkpoints', 
                embedding_dir='./embeddings'):
    """Main training function for GAN-CL with fixed gradient computation"""
    
    # Create directories
    os.makedirs(save_dir, exist_ok=True)
    os.makedirs(embedding_dir, exist_ok=True)
    encoder_dir = os.path.join(save_dir, 'encoders')
    os.makedirs(encoder_dir, exist_ok=True)    
    
    # Initialize model
    model = MolecularGANCL(config).to(device)
    
    # Initialize optimizers
    optimizer_encoder = torch.optim.Adam(model.encoder.parameters(), lr=3e-4)
    optimizer_generator = torch.optim.Adam(model.generator.parameters(), lr=1e-4)
    
    # Save initial model info
    model_info = {
        'node_dim': config.node_dim,
        'edge_dim': config.edge_dim,
        'hidden_dim': config.hidden_dim,
        'output_dim': config.output_dim,
        'training_config': config.__dict__
    }
    
    # Training phases as before...
    best_loss = float('inf')
    
    # Training metrics
    metrics = {
        'contrastive_losses': [],
        'adversarial_losses': [],
        'similarity_losses': [],
        'total_losses': []
    }
    
    # Training phases
    print("Phase 1: Pretraining Contrastive Learning...")
    pretrain_epochs = 10
    for epoch in range(pretrain_epochs):
        contrastive_epoch_loss = 0
        
        for batch in tqdm(train_loader, desc=f'Pretrain Epoch {epoch+1}/{pretrain_epochs}'):
            batch = batch.to(device)
            
            # Forward pass (without generator)
            query_emb = model.encoder(batch)
            with torch.no_grad():
                key_emb = model.momentum_encoder(batch)
            
            # Compute contrastive loss
            contrastive_loss = model.memory_queue.compute_contrastive_loss(
                query_emb, key_emb, model.config.temperature
            )
            
            # Update encoder
            optimizer_encoder.zero_grad()
            contrastive_loss.backward()
            optimizer_encoder.step()
            
            # Update momentum encoder
            model._momentum_update()
            
            # Update memory queue
            model.memory_queue.update_queue(key_emb.detach())
            
            contrastive_epoch_loss += contrastive_loss.item()
            
        avg_loss = contrastive_epoch_loss / len(train_loader)
        metrics['contrastive_losses'].append(avg_loss)
        print(f'Pretrain Epoch {epoch+1}, Avg Loss: {avg_loss:.4f}')
        
        # Save pretrained checkpoint
        if (epoch + 1) % 5 == 0:
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_encoder_state_dict': optimizer_encoder.state_dict(),
                'loss': avg_loss,
            }, os.path.join(save_dir, f'pretrain_checkpoint_{epoch+1}.pt'))
    
    print("\nPhase 2: Training GAN-CL...")
#     train_epochs = 50
    train_epochs = 10
    for epoch in range(train_epochs):
        epoch_losses = {
            'contrastive': 0,
            'adversarial': 0,
            'similarity': 0,
            'total': 0
        }
        
        for batch in tqdm(train_loader, desc=f'Train Epoch {epoch+1}/{train_epochs}'):
            batch = batch.to(device)
            
            # Step 1: Train Encoder
            optimizer_encoder.zero_grad()
            
            # Get importance scores from generator
            with torch.no_grad():
                node_scores, edge_scores = model.generator(batch)
            
            # Create perturbed graph
            perturbed_data = model.drop_graph_elements(batch, node_scores, edge_scores)
            
            # Get embeddings
            query_emb = model.encoder(perturbed_data)
            with torch.no_grad():
                key_emb = model.momentum_encoder(batch)
                original_emb = model.encoder(batch).detach()
            
            # Compute losses for encoder
            contrastive_loss = model.memory_queue.compute_contrastive_loss(
                query_emb, key_emb, model.config.temperature
            )
            similarity_loss = F.mse_loss(query_emb, original_emb)
            
            # Total loss for encoder
            encoder_loss = contrastive_loss + 0.1 * similarity_loss
            
            # Update encoder
            encoder_loss.backward()
            optimizer_encoder.step()
            
            # Update momentum encoder
            model._momentum_update()
            
            # Step 2: Train Generator
            optimizer_generator.zero_grad()
            
            # Get new embeddings for adversarial loss
            node_scores, edge_scores = model.generator(batch)
            perturbed_data = model.drop_graph_elements(batch, node_scores, edge_scores)
            
            with torch.no_grad():
                original_emb = model.encoder(batch)
            perturbed_emb = model.encoder(perturbed_data)
            
            # Compute adversarial loss
            adversarial_loss = -F.mse_loss(perturbed_emb, original_emb)
            
            # Update generator
            adversarial_loss.backward()
            optimizer_generator.step()
            
            # Update memory queue
            model.memory_queue.update_queue(key_emb.detach())
            
            # Update metrics
            epoch_losses['contrastive'] += contrastive_loss.item()
            epoch_losses['adversarial'] += adversarial_loss.item()
            epoch_losses['similarity'] += similarity_loss.item()
            epoch_losses['total'] += encoder_loss.item()
        
        # Average losses
        for k in epoch_losses:
            epoch_losses[k] /= len(train_loader)
            metrics[f'{k}_losses'].append(epoch_losses[k])
        
        print(f'Epoch {epoch+1}, Losses: {epoch_losses}')
        
        # Save checkpoint
        if (epoch + 1) % 10 == 0:
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_encoder_state_dict': optimizer_encoder.state_dict(),
                'optimizer_generator_state_dict': optimizer_generator.state_dict(),
                'losses': epoch_losses,
            }, os.path.join(save_dir, f'gan_cl_checkpoint_{epoch+1}.pt'))
            

        # In train_gan_cl function, in the embedding saving section:
        if (epoch + 1) % 10 == 0:
            model.eval()
            all_embeddings = []
            all_graphs = []

            with torch.no_grad():
                for batch in train_loader:
                    batch = batch.to(device)
                    embeddings = model.get_embeddings(batch)
                    all_embeddings.append(embeddings.cpu())
                    all_graphs.extend([data for data in batch])

            all_embeddings = torch.cat(all_embeddings, dim=0)

            # Save embeddings
            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
            save_path = os.path.join(
                embedding_dir, 
                f'embeddings_epoch_{epoch+1}_{timestamp}.pkl'
            )
            save_embeddings(
                embeddings=all_embeddings.numpy(),
                graphs=all_graphs,
                filepath=save_path
            )

            model.train()
    
    # Save final metrics
    with open(os.path.join(save_dir, 'training_metrics.json'), 'w') as f:
        json.dump(metrics, f)
    
            # Save encoder periodically
        if (epoch + 1) % 10 == 0:
            epoch_info = {
                **model_info,
                'epoch': epoch + 1,
                'loss': epoch_losses['total']
            }
            save_encoder(
                model.encoder,
                os.path.join(encoder_dir, f'encoder_epoch_{epoch+1}.pt'),
                epoch_info
            )
        
        # Save best encoder based on total loss
        if epoch_losses['total'] < best_loss:
            best_loss = epoch_losses['total']
            save_encoder(
                model.encoder,
                os.path.join(encoder_dir, f'best_encoder_{timestamp}.pt'),
                {**model_info, 'epoch': epoch + 1, 'loss': best_loss}
            )
    
    # Save final encoder
    save_encoder(
        model.encoder,
        os.path.join(encoder_dir, f'final_encoder_{timestamp}.pt'),
        {**model_info, 'epoch': train_epochs, 'loss': epoch_losses['total']}
    )
    
    return model, metrics


In [5]:
def main():
     # Enable anomaly detection during development
    torch.autograd.set_detect_anomaly(True)
    # Your existing data loading code here
    torch.manual_seed(42)
    np.random.seed(42)
    
    
    dataset = []
    failed_smiles = []
    
    print("Starting data loading...")
    extractor = MolecularFeatureExtractor()
    smiles_file = "D:\\PhD\\Chapter3\\Unsupervised_GAN_Code\\pubchem-41-clean.txt"
    
    # Create dataset with molecules
    original_molecules = []  # Store original RDKit molecules
    with open(smiles_file, 'r') as f:
        for line in f:
            smiles = line.strip()
            mol = Chem.MolFromSmiles(smiles)
            if mol is not None:
                data = extractor.process_molecule(smiles)
                if data is not None:
                    dataset.append(data)
                    original_molecules.append({
                        'smiles': smiles,
                        'mol': mol,
                        'features': {
                            'MW': Descriptors.ExactMolWt(mol),
                            'LogP': Descriptors.MolLogP(mol),
                            'TPSA': Descriptors.TPSA(mol),
                            'num_rings': Chem.rdMolDescriptors.CalcNumRings(mol),
                            'aromatic_rings': Chem.rdMolDescriptors.CalcNumAromaticRings(mol),
                            'aliphatic_rings': Chem.rdMolDescriptors.CalcNumAliphaticRings(mol)
                        }
                    })
            else:
                failed_smiles.append(smiles)
    
    # Save original molecule information
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    with open(f'./original_molecules_{timestamp}.pkl', 'wb') as f:
        pickle.dump(original_molecules, f)
        
    print(f"1. Loaded dataset with {len(dataset)} graphs.")
    print(f"2. Failed SMILES count: {len(failed_smiles)}")
    
    if not dataset:
        print("No valid graphs generated.")
        return None
        
    # Setup training
    batch_size = 32
    train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    print(f"3. Created DataLoader with {len(train_loader.dataset)} graphs")
    
    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"4. Using device: {device}")
    
    # Get configuration based on dataset
    config = get_model_config(dataset)
    
    # Train model
    print("5. Starting GAN-CL training...")
    model, metrics = train_gan_cl(
        train_loader, 
        config,
        device=device,
        save_dir='./checkpoints',
        embedding_dir='./embeddings'
    )
    
    print("6. Training completed!")
    
    # Extract embeddings for XAI
    print("7. Extracting final embeddings for XAI...")
    model.eval()
    with torch.no_grad():
        all_embeddings = []
        all_graphs = []

        for batch in tqdm(train_loader, desc="Extracting embeddings"):
            batch = batch.to(device)
            embeddings = model.get_embeddings(batch)
            all_embeddings.append(embeddings.cpu())
            all_graphs.extend([data for data in batch])

    all_embeddings = torch.cat(all_embeddings, dim=0).numpy()

    # Save final embeddings
    final_embedding_path = f'./embeddings/final_embeddings_{timestamp}.pkl'
    save_embeddings(
        embeddings=all_embeddings,
        graphs=all_graphs,
        filepath=final_embedding_path
    )
    print(f"8. Final embeddings saved to {final_embedding_path}")
    
    return model, metrics, all_embeddings, all_graphs

if __name__ == "__main__":
    model, metrics, embeddings, graphs = main()

Starting data loading...




1. Loaded dataset with 41 graphs.
2. Failed SMILES count: 0
3. Created DataLoader with 41 graphs
4. Using device: cpu
5. Starting GAN-CL training...
Phase 1: Pretraining Contrastive Learning...


Pretrain Epoch 1/10: 100%|██████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  6.87it/s]


Pretrain Epoch 1, Avg Loss: 1.7384


Pretrain Epoch 2/10: 100%|██████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  7.97it/s]


Pretrain Epoch 2, Avg Loss: 3.8909


Pretrain Epoch 3/10: 100%|██████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  7.68it/s]


Pretrain Epoch 3, Avg Loss: 4.3685


Pretrain Epoch 4/10: 100%|██████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  5.34it/s]


Pretrain Epoch 4, Avg Loss: 4.7127


Pretrain Epoch 5/10: 100%|██████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  8.22it/s]


Pretrain Epoch 5, Avg Loss: 4.7612


Pretrain Epoch 6/10: 100%|██████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  8.12it/s]


Pretrain Epoch 6, Avg Loss: 4.9642


Pretrain Epoch 7/10: 100%|██████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  6.53it/s]


Pretrain Epoch 7, Avg Loss: 5.1284


Pretrain Epoch 8/10: 100%|██████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  7.45it/s]


Pretrain Epoch 8, Avg Loss: 5.2073


Pretrain Epoch 9/10: 100%|██████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  8.74it/s]


Pretrain Epoch 9, Avg Loss: 5.3486


Pretrain Epoch 10/10: 100%|█████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  6.90it/s]


Pretrain Epoch 10, Avg Loss: 5.5249

Phase 2: Training GAN-CL...


Train Epoch 1/10: 100%|█████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.12it/s]


Epoch 1, Losses: {'contrastive': 5.605298280715942, 'adversarial': -0.0003780678234761581, 'similarity': 0.0003618255868786946, 'total': 5.605334281921387}


Train Epoch 2/10: 100%|█████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.26it/s]


Epoch 2, Losses: {'contrastive': 5.601036787033081, 'adversarial': -0.0004964996187482029, 'similarity': 0.0003700220986502245, 'total': 5.601073741912842}


Train Epoch 3/10: 100%|█████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.34it/s]


Epoch 3, Losses: {'contrastive': 5.684419393539429, 'adversarial': -0.0005045205180067569, 'similarity': 0.0005598131101578474, 'total': 5.684475421905518}


Train Epoch 4/10: 100%|█████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.35it/s]


Epoch 4, Losses: {'contrastive': 5.761997938156128, 'adversarial': -0.0005649148661177605, 'similarity': 0.00047480646753683686, 'total': 5.762045383453369}


Train Epoch 5/10: 100%|█████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.46it/s]


Epoch 5, Losses: {'contrastive': 5.782212495803833, 'adversarial': -0.0003543853235896677, 'similarity': 0.00037507347587961704, 'total': 5.782249927520752}


Train Epoch 6/10: 100%|█████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.29it/s]


Epoch 6, Losses: {'contrastive': 5.933075428009033, 'adversarial': -0.0004684303712565452, 'similarity': 0.00047023536171764135, 'total': 5.933122396469116}


Train Epoch 7/10: 100%|█████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.46it/s]


Epoch 7, Losses: {'contrastive': 5.924046516418457, 'adversarial': -0.000440751071437262, 'similarity': 0.0005972441722406074, 'total': 5.9241063594818115}


Train Epoch 8/10: 100%|█████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.50it/s]


Epoch 8, Losses: {'contrastive': 6.0389556884765625, 'adversarial': -0.00047894091403577477, 'similarity': 0.00042598682921379805, 'total': 6.038998365402222}


Train Epoch 9/10: 100%|█████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  2.94it/s]


Epoch 9, Losses: {'contrastive': 6.027733564376831, 'adversarial': -0.00043135814485140145, 'similarity': 0.0006277782958932221, 'total': 6.027796268463135}


Train Epoch 10/10: 100%|████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.70it/s]


Epoch 10, Losses: {'contrastive': 6.116659164428711, 'adversarial': -0.0005571308720391244, 'similarity': 0.000566645700018853, 'total': 6.116715908050537}

Processing molecular details...


100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 18/18 [00:00<?, ?it/s]



Saved 41 embeddings with molecular details
Successfully processed 0 molecules
6. Training completed!
7. Extracting final embeddings for XAI...


Extracting embeddings: 100%|███████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 114.49it/s]



Processing molecular details...


100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 18/18 [00:00<?, ?it/s]


Saved 41 embeddings with molecular details
Successfully processed 0 molecules
8. Final embeddings saved to ./embeddings/final_embeddings_20250224_124549.pkl



