In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import random
import os
import json
import pickle
import copy
import pandas as pd
from tqdm import tqdm
from datetime import datetime
from typing import Tuple, List, Dict, Any, Optional
from collections import defaultdict
from dataclasses import dataclass

from rdkit import Chem, RDLogger
from rdkit.Chem import (
    AllChem, Descriptors, MolSurf, Fragments, Lipinski, RemoveHs
)
from rdkit.Chem.rdMolDescriptors import (
    CalcNumRings, CalcNumAromaticRings, CalcNumHeterocycles, 
    CalcNumAliphaticRings, CalcTPSA
)

from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool, MessagePassing

# Suppress RDKit warnings
RDLogger.DisableLog('rdApp.warning')

# Working timestamp
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')


import numpy as np
import pandas as pd
import torch
from rdkit import Chem
from rdkit.Chem import Descriptors, Fragments, Lipinski
from rdkit.Chem.rdMolDescriptors import CalcNumRings, CalcNumAromaticRings, CalcNumHeterocycles
from rdkit.Chem.rdMolDescriptors import CalcNumAliphaticRings, CalcTPSA
from typing import List, Dict, Any, Tuple
import os
import json
import pickle
from tqdm import tqdm
from datetime import datetime
from collections import defaultdict





In [2]:
class MolecularFeatureExtractor:
    """Turn a SMILES string into the tensors of a molecular graph.

    Atoms become nodes described by two categorical indices (element and
    chirality) plus ``NUM_PHYS_FEATURES`` numeric values; bonds become pairs
    of directed edges described by ``NUM_BOND_FEATURES`` values.
    """

    # Width of the numeric atom vector built in ``calc_atom_features``:
    # 2 single-atom descriptors + 6 RDKit atom properties.
    NUM_PHYS_FEATURES = 8
    # Width of the bond vector built in ``get_bond_features``.
    NUM_BOND_FEATURES = 5
    # Memo of (ExactMolWt, MolLogP) per element symbol. The values depend only
    # on the symbol, so they are shared by all instances instead of re-parsing
    # a SMILES string for every atom of every molecule.
    _ELEMENT_DESCRIPTOR_CACHE = {}

    def __init__(self):
        # Lookup tables: a feature's categorical index is its position here.
        self.atom_list = list(range(1, 119))
        self.chirality_list = [
            Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
            Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
            Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
            Chem.rdchem.ChiralType.CHI_OTHER
        ]
        self.bond_list = [
            Chem.rdchem.BondType.SINGLE,
            Chem.rdchem.BondType.DOUBLE,
            Chem.rdchem.BondType.TRIPLE,
            Chem.rdchem.BondType.AROMATIC
        ]
        self.bonddir_list = [
            Chem.rdchem.BondDir.NONE,
            Chem.rdchem.BondDir.ENDUPRIGHT,
            Chem.rdchem.BondDir.ENDDOWNRIGHT
        ]

    def _element_descriptors(self, symbol: str) -> Tuple[float, float]:
        """Return ``(ExactMolWt, MolLogP)`` of the single-atom molecule ``[symbol]``.

        Each descriptor independently falls back to ``0.0`` when RDKit cannot
        compute it. Results are memoised per symbol.
        """
        cached = self._ELEMENT_DESCRIPTOR_CACHE.get(symbol)
        if cached is not None:
            return cached

        try:
            probe = Chem.MolFromSmiles(f'[{symbol}]')
        except Exception:
            probe = None

        try:
            weight = Descriptors.ExactMolWt(probe)
        except Exception:
            weight = 0.0

        try:
            logp = Descriptors.MolLogP(probe)
        except Exception:
            logp = 0.0

        cached = (weight, logp)
        self._ELEMENT_DESCRIPTOR_CACHE[symbol] = cached
        return cached

    def calc_atom_features(self, atom: Chem.Atom) -> Tuple[list, list]:
        """Compute the categorical and numeric features of one atom.

        Args:
            atom: RDKit atom.

        Returns:
            ``(atom_feat, phys_feat)`` where ``atom_feat`` is
            ``[element_index, chirality_index]`` and ``phys_feat`` holds
            ``NUM_PHYS_FEATURES`` numbers: single-atom molecular weight,
            single-atom logP, formal charge, hybridization, aromatic flag,
            total hydrogens, total valence and degree. If anything fails
            (for example an element or chirality tag missing from the lookup
            tables) the error is printed and all-zero vectors of the same
            widths are returned so that rows stay rectangular.
        """
        try:
            atom_feat = [
                self.atom_list.index(atom.GetAtomicNum()),
                self.chirality_list.index(atom.GetChiralTag())
            ]

            # Note: these are descriptors of an isolated atom of this element,
            # not a per-atom decomposition of the parent molecule's values.
            contrib_mw, contrib_logp = self._element_descriptors(atom.GetSymbol())

            phys_feat = [
                contrib_mw,
                contrib_logp,
                atom.GetFormalCharge(),
                int(atom.GetHybridization()),
                int(atom.GetIsAromatic()),
                atom.GetTotalNumHs(),
                atom.GetTotalValence(),
                atom.GetDegree()
            ]

            return atom_feat, phys_feat

        except Exception as e:
            print(f"Error calculating atom features: {e}")
            return [0, 0], [0.0] * self.NUM_PHYS_FEATURES

    def get_atom_features(self, mol: Chem.Mol) -> Tuple[torch.Tensor, torch.Tensor]:
        """Extract atom features for the whole molecule.

        Returns:
            ``(x, phys)``: a long tensor of shape ``(num_atoms, 2)`` and a
            float tensor of shape ``(num_atoms, NUM_PHYS_FEATURES)``.
            ``mol is None`` yields a single all-zero placeholder atom.
        """
        if mol is None:
            return (
                torch.tensor([[0, 0]], dtype=torch.long),
                torch.tensor([[0.0] * self.NUM_PHYS_FEATURES], dtype=torch.float),
            )

        if mol.GetNumAtoms() == 0:
            # Keep the tensors two-dimensional even without atoms.
            return (
                torch.empty((0, 2), dtype=torch.long),
                torch.empty((0, self.NUM_PHYS_FEATURES), dtype=torch.float),
            )

        atom_feats = []
        phys_feats = []
        for atom in mol.GetAtoms():
            atom_feat, phys_feat = self.calc_atom_features(atom)
            atom_feats.append(atom_feat)
            phys_feats.append(phys_feat)

        x = torch.tensor(atom_feats, dtype=torch.long)
        phys = torch.tensor(phys_feats, dtype=torch.float)

        return x, phys

    @staticmethod
    def remove_unbonded_hydrogens(mol):
        """Return ``mol`` with hydrogens removed, including degree-zero ones.

        Declared static because it never used instance state; it can be called
        on the class or on an instance.
        """
        params = Chem.RemoveHsParameters()
        params.removeDegreeZero = True
        mol = Chem.RemoveHs(mol, params)
        return mol

    def get_bond_features(self, mol: Chem.Mol) -> Tuple[torch.Tensor, torch.Tensor]:
        """Extract directed edges and bond features.

        Every bond contributes two edges (one per direction) that share the
        same feature vector: bond type index, bond direction index,
        conjugation flag, rotatable flag and bond length.

        Returns:
            ``(edge_index, edge_attr)`` with shapes ``(2, num_edges)`` (long)
            and ``(num_edges, NUM_BOND_FEATURES)`` (float). When ``mol`` is
            ``None`` or no bond could be processed, a single ``0 -> 0`` edge
            with all-zero features is returned.
        """
        placeholder = (
            torch.tensor([[0], [0]], dtype=torch.long),
            torch.tensor([[0.0] * self.NUM_BOND_FEATURES], dtype=torch.float),
        )
        if mol is None:
            return placeholder

        row, col, edge_feat = [], [], []

        for bond in mol.GetBonds():
            try:
                start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()

                # Build the complete feature vector first: the lookups raise
                # for bond types/directions missing from the tables, and the
                # index lists must only grow once the features exist,
                # otherwise edge_index and edge_attr fall out of step.
                feat = [
                    self.bond_list.index(bond.GetBondType()),
                    self.bonddir_list.index(bond.GetBondDir()),
                    int(bond.GetIsConjugated()),
                    int(self._is_rotatable(bond)),
                    self._get_bond_length(mol, start, end)
                ]
            except Exception as e:
                print(f"Error processing bond: {e}")
                continue

            # Add edges in both directions
            row += [start, end]
            col += [end, start]
            edge_feat.extend([feat, feat])

        if not row:  # If no valid bonds were processed
            return placeholder

        edge_index = torch.tensor([row, col], dtype=torch.long)
        edge_attr = torch.tensor(edge_feat, dtype=torch.float)

        return edge_index, edge_attr

    def _is_rotatable(self, bond: Chem.Bond) -> bool:
        """Return True for an acyclic single bond whose two atoms each have more than one neighbour."""
        return (bond.GetBondType() == Chem.rdchem.BondType.SINGLE and
                not bond.IsInRing() and
                bond.GetBeginAtom().GetDegree() > 1 and
                bond.GetEndAtom().GetDegree() > 1)

    def _get_bond_length(self, mol: Chem.Mol, start: int, end: int) -> float:
        """Return the bond length from the molecule's 3D conformer, or 0.0.

        Best effort: a missing or non-3D conformer, or any RDKit error,
        results in ``0.0`` rather than an exception.
        """
        # Imported explicitly so the lookup does not depend on another module
        # having loaded the ``rdMolTransforms`` submodule as a side effect.
        from rdkit.Chem import rdMolTransforms

        try:
            conf = mol.GetConformer()
            if conf.Is3D():
                return rdMolTransforms.GetBondLength(conf, start, end)
        except Exception:
            pass
        return 0.0

    def process_molecule(self, smiles: str) -> Optional[Data]:
        """Process a SMILES string into graph data.

        Parses the SMILES, adds explicit hydrogens, embeds a 3D conformer,
        optimises it with MMFF (falling back to UFF) and extracts atom and
        bond features.

        Returns:
            A ``Data`` object with ``x_cat``, ``x_phys``, ``edge_index``,
            ``edge_attr`` and ``num_nodes``; or ``None`` when the SMILES is
            invalid, the molecule is empty, embedding fails, or any other
            error occurs (the reason is printed).
        """
        try:
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                print(f"Invalid SMILES: {smiles}")
                return None  # Skip invalid molecules
            mol = RemoveHs(mol)

            # Add explicit hydrogens
            mol = Chem.AddHs(mol, addCoords=True)

            # Sanitize molecule
            Chem.SanitizeMol(mol)

            # Check if the molecule has atoms
            if mol.GetNumAtoms() == 0:
                print("Molecule has no atoms, skipping.")
                return None

            # Generate 3D coordinates
            if not mol.GetNumConformers():
                status = AllChem.EmbedMolecule(mol, AllChem.ETKDG())
                if status != 0:
                    print("Failed to generate 3D conformer")
                    return None  # Skip failed molecules

                # MMFF signals "force field could not be set up" through a
                # return value of -1 rather than an exception, so the return
                # code must be inspected for the UFF fallback to be reached.
                try:
                    mmff_status = AllChem.MMFFOptimizeMolecule(mol)
                except Exception:
                    mmff_status = -1

                if mmff_status == -1:
                    try:
                        AllChem.UFFOptimizeMolecule(mol)
                    except Exception as e:
                        # The embedded geometry is still usable; keep it.
                        print(f"Force-field optimization failed for {smiles}: {e}")

            # Extract features
            x_cat, x_phys = self.get_atom_features(mol)
            edge_index, edge_attr = self.get_bond_features(mol)

            return Data(
                x_cat=x_cat,
                x_phys=x_phys,
                edge_index=edge_index,
                edge_attr=edge_attr,
                num_nodes=x_cat.size(0)
            )

        except Exception as e:
            print(f"Error processing molecule {smiles}: {e}")
            return None


In [3]:
class MemoryQueue:
    """Fixed-size FIFO bank of negative keys with age-based decay weights.

    The bank starts out filled with random unit vectors; ``update_queue``
    overwrites slots in ring-buffer order. Each slot has an age (number of
    ``update_queue`` calls since it was last written) and the negative logits
    in ``compute_contrastive_loss`` are scaled by ``decay ** age``.

    This is a plain object, not an ``nn.Module``, so ``model.to(device)`` does
    not move it. The tensors are therefore moved lazily to the device of
    whatever tensor they are combined with.

    Args:
        size: Number of slots in the bank.
        dim: Dimensionality of each key.
        decay: Base of the per-step exponential decay weight.
    """
    def __init__(self, size: int, dim: int, decay: float = 0.99999):
        self.size = size
        self.dim = dim
        self.decay = decay
        self.ptr = 0        # next slot to overwrite
        self.full = False   # True once every slot has been written at least once

        # Plain tensors (no autograd involvement): random unit-norm keys and
        # a zero age per slot.
        self.queue = F.normalize(torch.randn(size, dim), dim=1)
        self.queue_age = torch.zeros(size)

    def _move_to(self, device: torch.device) -> None:
        """Move the bank and the ages to ``device`` if they live elsewhere."""
        if self.queue.device != device:
            self.queue = self.queue.to(device)
        if self.queue_age.device != device:
            self.queue_age = self.queue_age.to(device)

    @torch.no_grad()
    def update_queue(self, keys: torch.Tensor):
        """Insert ``keys`` (shape ``(batch, dim)``) at the write pointer.

        All existing ages are incremented first; the written slots get age 0.
        Keys are detached so the bank never becomes part of an autograd graph.
        If the batch is larger than the bank, only the most recent ``size``
        keys are kept.
        """
        keys = keys.detach()
        self._move_to(keys.device)
        keys = keys.to(self.queue.dtype)

        if keys.shape[0] > self.size:
            # Older keys in the batch would be overwritten immediately anyway.
            keys = keys[-self.size:]
        batch_size = keys.shape[0]

        # Increment age of all entries
        self.queue_age += 1

        end = self.ptr + batch_size
        if end <= self.size:
            self.queue[self.ptr:end] = keys
            self.queue_age[self.ptr:end] = 0
        else:
            # Wrap around: fill the tail, then continue from the start.
            rem = self.size - self.ptr
            self.queue[self.ptr:] = keys[:rem]
            self.queue[:batch_size - rem] = keys[rem:]
            self.queue_age[self.ptr:] = 0
            self.queue_age[:batch_size - rem] = 0

        if end >= self.size:
            # Reaching the last slot (with or without wrapping) fills the bank.
            self.full = True

        self.ptr = end % self.size

    def get_decay_weights(self) -> torch.Tensor:
        """Return ``decay ** age`` for every slot (shape ``(size,)``)."""
        return self.decay ** self.queue_age

    def compute_contrastive_loss(self, query: torch.Tensor, positive_key: torch.Tensor,
                                temperature: float = 0.07) -> torch.Tensor:
        """Cross-entropy contrastive loss of ``query`` against the bank.

        Args:
            query: ``(batch, dim)`` query embeddings.
            positive_key: ``(batch, dim)`` positive key for each query.
            temperature: Divisor applied to all logits.

        Returns:
            Scalar loss whose target class is the positive (column 0); the
            bank entries act as negatives, each scaled by its decay weight.
        """
        self._move_to(query.device)

        # Normalize embeddings
        query = F.normalize(query, dim=1)
        positive_key = F.normalize(positive_key, dim=1)
        queue = F.normalize(self.queue, dim=1).to(query.dtype)

        # Compute logits
        l_pos = torch.einsum('nc,nc->n', query, positive_key).unsqueeze(-1)
        l_neg = torch.einsum('nc,kc->nk', query, queue)

        # Apply temporal decay to negative samples
        decay_weights = self.get_decay_weights().to(query.dtype)
        l_neg = l_neg * decay_weights.unsqueeze(0)

        # Temperature scaling
        logits = torch.cat([l_pos, l_neg], dim=1) / temperature
        labels = torch.zeros(logits.shape[0], dtype=torch.long, device=query.device)

        return F.cross_entropy(logits, labels)

class GraphGenerator(nn.Module):
    """Scores every node and every edge of a graph with a sigmoid importance value.

    Args:
        node_dim: Width of the concatenated node features (``x_cat`` + ``x_phys``).
        edge_dim: Width of the edge features.
        hidden_dim: Hidden width of the encoders and graph convolutions.
    """
    def __init__(self, node_dim: int, edge_dim: int, hidden_dim: int = 128):
        super().__init__()

        # Node feature processing
        self.node_encoder = nn.Sequential(
            nn.Linear(node_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

        # Edge feature processing. Not used by ``forward`` (GCNConv takes no
        # edge features); kept so parameter names and state dicts stay the same.
        self.edge_encoder = nn.Sequential(
            nn.Linear(edge_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

        # Graph convolution layers
        self.conv1 = GCNConv(hidden_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.conv3 = GCNConv(hidden_dim, hidden_dim)

        # Importance prediction layers
        self.node_importance = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, 1),
            nn.Sigmoid()
        )

        self.edge_importance = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )

    def normalize_features(self, x_cat, x_phys):
        """Cast both feature blocks to float and standardise the physical ones.

        ``x_cat`` is only cast to float (the raw indices are used as numbers,
        no one-hot encoding is applied). ``x_phys`` is standardised per column
        over all rows passed in, unless there is only a single row.
        """
        x_cat = x_cat.float()

        x_phys = x_phys.float()
        if x_phys.size(0) > 1:  # std of a single row is undefined
            x_phys = (x_phys - x_phys.mean(0)) / (x_phys.std(0) + 1e-5)

        return x_cat, x_phys

    def forward(self, data) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute importance scores.

        Args:
            data: Object exposing ``x_cat``, ``x_phys`` and ``edge_index``.

        Returns:
            ``(node_scores, edge_scores)`` with shapes ``(num_nodes, 1)`` and
            ``(num_edges, 1)``; an edge score is computed from the embeddings
            of the edge's two endpoints.
        """
        x_cat, x_phys = self.normalize_features(data.x_cat, data.x_phys)
        x = torch.cat([x_cat, x_phys], dim=-1)
        edge_index = data.edge_index

        # The edge features are deliberately not encoded here: the result was
        # never consumed by the convolutions, so computing it was wasted work.
        x = self.node_encoder(x)

        # Graph convolutions
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = self.conv3(x, edge_index)

        # Predict importance scores
        node_scores = self.node_importance(x)

        # Edge scores using both connected nodes
        edge_features = torch.cat([
            x[edge_index[0]],
            x[edge_index[1]]
        ], dim=-1)
        edge_scores = self.edge_importance(edge_features)

        return node_scores, edge_scores

def get_model_config(dataset):
    """Build a ``GanClConfig`` whose input widths match the dataset.

    The first item of ``dataset`` is inspected: the node width is the sum of
    the second dimensions of ``x_cat`` and ``x_phys``, the edge width is the
    second dimension of ``edge_attr``. All other hyper-parameters are fixed.
    """
    first_graph = dataset[0]

    # Input widths derived from the sample
    widths = {
        'node_dim': first_graph.x_cat.shape[1] + first_graph.x_phys.shape[1],
        'edge_dim': first_graph.edge_attr.shape[1],
    }

    # Fixed training hyper-parameters
    hyper_params = {
        'hidden_dim': 128,
        'output_dim': 128,
        'queue_size': 65536,
        'momentum': 0.999,
        'temperature': 0.07,
        'decay': 0.99999,
        'dropout_ratio': 0.25,
    }

    return GanClConfig(**widths, **hyper_params)

class GraphDiscriminator(nn.Module):
    """Graph encoder: three GCN layers, mean pooling and a projection head.

    Args:
        node_dim: Width of the concatenated node features (``x_cat`` + ``x_phys``).
        edge_dim: Width of the edge features.
        hidden_dim: Hidden width of the encoders and first two convolutions.
        output_dim: Width of the pooled graph embedding and of the projection.
    """
    def __init__(self, node_dim: int, edge_dim: int, hidden_dim: int = 128, output_dim: int = 128):
        super().__init__()

        # Feature encoding
        self.node_encoder = nn.Sequential(
            nn.Linear(node_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

        # Not used by ``forward`` (GCNConv takes no edge features); kept so
        # parameter names and state dicts stay the same.
        self.edge_encoder = nn.Sequential(
            nn.Linear(edge_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

        # Graph convolution layers
        self.conv1 = GCNConv(hidden_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.conv3 = GCNConv(hidden_dim, output_dim)

        # Projection head for contrastive learning
        self.projection = nn.Sequential(
            nn.Linear(output_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )

    def normalize_features(self, x_cat, x_phys):
        """Cast both feature blocks to float and standardise the physical ones.

        ``x_cat`` is only cast to float (the raw indices are used as numbers,
        no one-hot encoding is applied). ``x_phys`` is standardised per column
        over all rows passed in, unless there is only a single row.
        """
        x_cat = x_cat.float()

        x_phys = x_phys.float()
        if x_phys.size(0) > 1:  # std of a single row is undefined
            x_phys = (x_phys - x_phys.mean(0)) / (x_phys.std(0) + 1e-5)

        return x_cat, x_phys

    def forward(self, data):
        """Embed each graph in the batch.

        Args:
            data: Object exposing ``x_cat``, ``x_phys``, ``edge_index`` and
                ``batch`` (node-to-graph assignment passed to the pooling).

        Returns:
            Projected graph embeddings, one row per graph.
        """
        x_cat, x_phys = self.normalize_features(data.x_cat, data.x_phys)
        x = torch.cat([x_cat, x_phys], dim=-1)

        edge_index = data.edge_index
        batch = data.batch

        # The edge features are deliberately not encoded here: the result was
        # never consumed by the convolutions, so computing it was wasted work.
        x = self.node_encoder(x)

        # Graph convolutions
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = self.conv3(x, edge_index)

        # Global pooling
        x = global_mean_pool(x, batch)

        # Projection
        x = self.projection(x)

        return x

@dataclass
class GanClConfig:
    """Configuration for GAN-CL training.

    ``node_dim`` and ``edge_dim`` are required; every other field has a
    default. ``get_model_config`` derives the two required widths from a
    sample graph.
    """
    # Width of the node input: x_cat columns + x_phys columns.
    node_dim: int
    # Width of the edge feature vectors (edge_attr columns).
    edge_dim: int
    # Hidden width of the encoder; the generator is built with twice this value.
    hidden_dim: int = 128
    # Width of the encoder's graph embedding and of each memory-queue key.
    output_dim: int = 128
    # Number of slots in the memory queue.
    queue_size: int = 65536
    # Weight kept from the momentum encoder's old parameters at each update.
    momentum: float = 0.999
    # Logit temperature used by the contrastive pretraining loop.
    # MolecularGANCL.forward uses its own annealed temperature instead.
    temperature: float = 0.07
    # Base of the age-dependent weight applied to queue negatives.
    decay: float = 0.99999
    # Probability that a node's / edge's features are zeroed when perturbing.
    dropout_ratio: float = 0.25

class MolecularGANCL(nn.Module):
    """Combined generator / contrastive-encoder framework.

    Holds a ``GraphGenerator``, a trainable ``GraphDiscriminator`` encoder, a
    frozen momentum copy of that encoder and a ``MemoryQueue`` of negatives.

    Args:
        config: Hyper-parameters; see ``GanClConfig``.
    """
    def __init__(self, config: GanClConfig):
        super().__init__()
        self.config = config

        def init_weights(m):
            """Xavier-initialise linear layers and set their bias to 0.01."""
            if isinstance(m, nn.Linear):
                torch.nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:  # layers built with bias=False have none
                    m.bias.data.fill_(0.01)

        # Initialize networks (the generator is twice as wide as the encoder)
        self.generator = GraphGenerator(
            config.node_dim,
            config.edge_dim,
            config.hidden_dim * 2
        )

        self.encoder = GraphDiscriminator(
            config.node_dim,
            config.edge_dim,
            config.hidden_dim,
            config.output_dim
        )
        self.encoder.apply(init_weights)

        # Loss weights
        self.contrastive_weight = 1.0
        self.adversarial_weight = 0.1
        self.similarity_weight = 0.01

        # Temperature annealing bounds used by ``get_temperature``
        self.initial_temperature = 0.1
        self.min_temperature = 0.05

        # Momentum encoder: frozen copy of the (already initialised) encoder
        self.momentum_encoder = copy.deepcopy(self.encoder)
        for param in self.momentum_encoder.parameters():
            param.requires_grad = False

        # Initialize memory queue
        self.memory_queue = MemoryQueue(
            config.queue_size,
            config.output_dim,
            config.decay
        )

    @torch.no_grad()
    def _momentum_update(self):
        """Move the momentum encoder towards the encoder (exponential moving average)."""
        for param_q, param_k in zip(self.encoder.parameters(),
                                  self.momentum_encoder.parameters()):
            param_k.data = self.config.momentum * param_k.data + \
                          (1 - self.config.momentum) * param_q.data

    def drop_graph_elements(self, data, node_scores: torch.Tensor,
                          edge_scores: torch.Tensor) -> Data:
        """Return a copy of ``data`` with random node/edge features zeroed.

        Each node and each edge is kept with probability
        ``1 - config.dropout_ratio``; dropped rows of ``x_cat``, ``x_phys``
        and ``edge_attr`` are multiplied by zero. The graph structure
        (``edge_index``) is unchanged.

        NOTE(review): ``node_scores`` and ``edge_scores`` only provide the
        shape, dtype and device of the random masks; their values are not
        used, so the generator does not influence the perturbation and gets
        no gradient from it. Confirm whether that is intended.
        """
        node_mask = (torch.rand_like(node_scores) > self.config.dropout_ratio).float()
        edge_mask = (torch.rand_like(edge_scores) > self.config.dropout_ratio).float()

        # Apply masks (masks have one column and broadcast across features)
        x_cat_new = data.x_cat * node_mask
        x_phys_new = data.x_phys * node_mask
        edge_attr_new = data.edge_attr * edge_mask

        # Create new graph data object
        return Data(
            x_cat=x_cat_new,
            x_phys=x_phys_new,
            edge_index=data.edge_index,
            edge_attr=edge_attr_new,
            batch=data.batch
        )

    def get_temperature(self, epoch, total_epochs):
        """Linearly anneal the temperature, never going below ``min_temperature``.

        A non-positive ``total_epochs`` describes no schedule at all; the
        floor value is returned instead of dividing by zero.
        """
        if total_epochs <= 0:
            return self.min_temperature
        progress = epoch / total_epochs
        return max(self.initial_temperature * (1 - progress), self.min_temperature)

    def forward(self, data, epoch=0, total_epochs=50):
        """Compute the three weighted training losses for one batch.

        Returns:
            ``(contrastive_loss, adversarial_loss, similarity_loss)``. The
            contrastive loss compares the encoder's embedding of the perturbed
            graph with the momentum encoder's embedding of the original graph
            and the queue negatives. The similarity loss is the MSE between
            the encoder's embeddings of the perturbed and the original graph;
            the adversarial loss is the negative of that MSE. Each is scaled
            by its weight.
        """
        # Get current temperature
        temperature = self.get_temperature(epoch, total_epochs)

        # Get importance scores from generator
        node_scores, edge_scores = self.generator(data)

        # Create perturbed graph
        perturbed_data = self.drop_graph_elements(data, node_scores, edge_scores)

        # Get embeddings; targets are computed without gradient tracking
        query_emb = self.encoder(perturbed_data)
        with torch.no_grad():
            key_emb = self.momentum_encoder(data)
            original_emb = self.encoder(data).detach()

        # Compute weighted losses
        contrastive_loss = self.memory_queue.compute_contrastive_loss(
            query_emb, key_emb, temperature
        ) * self.contrastive_weight

        adversarial_loss = -F.mse_loss(query_emb, original_emb) * self.adversarial_weight
        similarity_loss = F.mse_loss(query_emb, original_emb) * self.similarity_weight

        return contrastive_loss, adversarial_loss, similarity_loss

    def get_embeddings(self, data) -> torch.Tensor:
        """Return the encoder's embeddings of ``data`` without tracking gradients."""
        with torch.no_grad():
            return self.encoder(data)

In [4]:
import numpy as np
import pandas as pd
import torch
from rdkit import Chem
from tqdm import tqdm
import os
from datetime import datetime
import json


class SMILESTracker:
    """Collects graph embeddings together with the SMILES they belong to.

    Two sources of SMILES are supported:

    * ``store_dataset_smiles`` enumerates a loader once and remembers the
      SMILES at each running position.
    * ``add_batch`` remembers the SMILES carried by the batch itself, which
      stays correct even if the loader order differs between passes.

    ``save_embeddings`` prefers the SMILES recorded by ``add_batch`` and falls
    back to the position-based table.
    """

    def __init__(self, output_dir='./embeddings'):
        """Create ``output_dir`` if needed and start with empty storage."""
        self.output_dir = output_dir
        os.makedirs(output_dir, exist_ok=True)
        self.timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        self.original_smiles = {}  # Maps dataset position to SMILES
        self.embeddings = []       # One array per add_batch call
        self.batch_indices = []    # Running positions of those embeddings
        self._embedding_smiles = {}  # Position -> SMILES seen in add_batch
        self._next_index = 0         # Position of the next embedding

    @staticmethod
    def _smiles_per_graph(batch):
        """Return one entry per graph in ``batch``: its SMILES or ``None``.

        Handles a plain list/tuple of items, and batch objects exposing a
        ``smiles`` attribute (a string, or a sequence with one string per
        graph). Iterating a batch object directly is avoided on purpose:
        PyTorch Geometric batches iterate over ``(key, value)`` attribute
        pairs, not over graphs.
        """
        if isinstance(batch, (list, tuple)):
            return [getattr(item, 'smiles', None) for item in batch]

        smiles = getattr(batch, 'smiles', None)
        if isinstance(smiles, str):
            return [smiles]
        if smiles is not None:
            return list(smiles)

        # No SMILES available; still report how many graphs there are so that
        # running positions stay aligned.
        try:
            num_graphs = getattr(batch, 'num_graphs', None)
        except Exception:
            num_graphs = None
        return [None] * int(num_graphs) if num_graphs is not None else []

    def store_dataset_smiles(self, train_loader):
        """Extract and store all SMILES strings from the dataset

        This should be called once at the start of training. Positions are
        assigned in iteration order, so they only match later passes if the
        loader yields the same order again.
        """
        print("Storing original SMILES strings from dataset...")
        dataset_idx = 0

        for batch in train_loader:
            for smiles in self._smiles_per_graph(batch):
                if smiles is not None:
                    self.original_smiles[dataset_idx] = smiles
                dataset_idx += 1

        print(f"Stored {len(self.original_smiles)} SMILES strings from dataset")

        # Save the SMILES as a reference
        smiles_file = os.path.join(self.output_dir, f"dataset_smiles_{self.timestamp}.txt")
        with open(smiles_file, 'w') as f:
            for idx, smiles in sorted(self.original_smiles.items()):
                f.write(f"{idx},{smiles}\n")

        print(f"Saved original SMILES to {smiles_file}")

    def add_batch(self, batch, embeddings):
        """Store one batch of graph embeddings.

        Args:
            batch: The batch the embeddings were computed from. If it carries
                one SMILES per embedding, those are recorded directly.
            embeddings: Tensor with one row per graph.
        """
        # Convert embeddings to numpy
        embeddings_np = embeddings.detach().cpu().numpy()
        num_embeddings = len(embeddings_np)

        # One running position per graph embedding. (``batch.batch`` is a
        # per-node vector and cannot serve as a per-graph index.)
        indices = np.arange(self._next_index, self._next_index + num_embeddings)
        self._next_index += num_embeddings

        per_graph = self._smiles_per_graph(batch)
        if len(per_graph) == num_embeddings:
            for idx, smiles in zip(indices, per_graph):
                if smiles is not None:
                    self._embedding_smiles[int(idx)] = smiles
        elif per_graph:
            print(f"Warning: SMILES count ({len(per_graph)}) doesn't match "
                  f"embeddings count ({num_embeddings}); falling back to dataset positions")

        # Store
        self.embeddings.append(embeddings_np)
        self.batch_indices.append(indices)

    def save_embeddings(self, prefix="embeddings"):
        """Save embeddings with their original SMILES strings

        Args:
            prefix: Filename prefix

        Returns:
            Path to saved file, or ``None`` if nothing has been added.
        """
        if not self.embeddings:
            print("Warning: No embeddings to save")
            return None

        # Concatenate all embeddings and indices
        all_embeddings = np.vstack(self.embeddings)
        all_indices = np.concatenate(self.batch_indices)

        # Map positions back to SMILES
        all_smiles = []
        for idx in all_indices:
            key = int(idx)
            smiles = self._embedding_smiles.get(key, self.original_smiles.get(key))
            all_smiles.append(smiles if smiles is not None else f"unknown_{idx}")

        # Create filename
        filename = f"{prefix}_{self.timestamp}.npz"
        filepath = os.path.join(self.output_dir, filename)

        # Save as npz file
        np.savez(filepath, embeddings=all_embeddings, smiles=all_smiles)

        print(f"Saved {len(all_smiles)} embeddings with SMILES to {filepath}")
        return filepath

    def reset_embeddings(self):
        """Clear current embeddings (keeping original SMILES)"""
        self.embeddings = []
        self.batch_indices = []
        self._embedding_smiles = {}
        self._next_index = 0


import numpy as np
import pandas as pd
import torch
from rdkit import Chem
from rdkit.Chem import Descriptors, Lipinski
from rdkit.Chem.rdMolDescriptors import CalcTPSA
from tqdm import tqdm
import os
from datetime import datetime
import json
from torch_geometric.data import Batch, Data


def process_and_save_dataset(train_loader, output_dir='./analysis', prefix='before_training'):
    """Extract and analyze SMILES from a dataloader directly

    The SMILES are written one per line to
    ``<output_dir>/<prefix>_smiles_<timestamp>.txt`` and then passed to
    ``analyze_smiles_list``.

    Args:
        train_loader: Iterable of batches. A batch is either an object with a
            ``smiles`` attribute holding one SMILES per graph (as produced by
            PyTorch Geometric's collation of string attributes), or a plain
            list/tuple of items that each may have a ``smiles`` attribute.
        output_dir: Directory to save analysis results
        prefix: Prefix for output files

    Returns:
        List of all extracted SMILES, in iteration order.
    """
    # Extract SMILES from the dataset
    all_smiles = []
    for batch in train_loader:
        batch_smiles = None
        if not isinstance(batch, (list, tuple)):
            batch_smiles = getattr(batch, 'smiles', None)

        if isinstance(batch_smiles, str):
            # A single graph carries a single string, not a list.
            all_smiles.append(batch_smiles)
        elif batch_smiles is not None:
            # Read the collated list directly instead of relying on
            # len(batch) / batch[i], whose meaning depends on the batch type
            # and which rebuilds every graph just to read one attribute.
            all_smiles.extend(batch_smiles)
        elif not hasattr(batch, 'to_data_list'):
            # Plain container of individual items.
            all_smiles.extend(
                item.smiles for item in batch if hasattr(item, 'smiles')
            )

    print(f"Extracted {len(all_smiles)} SMILES strings from the dataset")

    # Save SMILES for reference
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    smiles_path = os.path.join(output_dir, f"{prefix}_smiles_{timestamp}.txt")
    os.makedirs(os.path.dirname(smiles_path), exist_ok=True)

    with open(smiles_path, 'w') as f:
        for smiles in all_smiles:
            f.write(f"{smiles}\n")

    print(f"Saved SMILES to {smiles_path}")

    # Analyze SMILES
    analyze_smiles_list(all_smiles, output_dir=output_dir, prefix=prefix)

    return all_smiles


# Utility function to ensure Data objects have SMILES attributes
def ensure_smiles_in_batch(batch):
    """Guarantee that ``batch`` exposes a ``smiles`` attribute.

    A batch that already has ``smiles`` is returned untouched. Otherwise the
    items ``batch[0] .. batch[len(batch) - 1]`` are inspected and a list is
    attached to the batch in place: each item's ``smiles`` where present,
    else the placeholder ``"unknown_<position>"``.

    Args:
        batch: PyTorch Geometric batch

    Returns:
        The same batch object, now carrying ``smiles``.
    """
    # Nothing to do when the attribute is already there.
    if hasattr(batch, 'smiles'):
        return batch

    collected = []
    for position in range(len(batch)):
        item = batch[position]
        collected.append(getattr(item, 'smiles', f"unknown_{position}"))

    # Attach the per-item list to the batch itself.
    batch.smiles = collected
    return batch


def modified_train_gan_cl(train_loader, config, device='cuda', 
                         save_dir='./checkpoints', 
                         embedding_dir='./embeddings',
                         pretrain_epochs=10,
                         train_epochs=50,
                         analysis_dir='./analysis'):
    """Train the GAN-CL model and return final embeddings paired with SMILES.

    The run has three stages:

    1. Contrastive pretraining of the encoder against the momentum encoder
       and memory queue (no generator involved).
    2. Alternating GAN-CL training: the encoder is updated on graphs
       perturbed by the (frozen) generator, then the generator is updated to
       maximise the distance between perturbed and original embeddings.
    3. Extraction of one embedding per graph, saved together with the SMILES
       string read from each graph's ``smiles`` attribute.

    The dataset is analysed before training (via ``process_and_save_dataset``)
    and again after training (via ``analyze_smiles_list``).

    Args:
        train_loader: Iterable of PyTorch Geometric batches.
        config: Model configuration passed to ``MolecularGANCL``.
        device: Device the model and batches are moved to.
        save_dir: Directory for checkpoints and ``training_metrics.json``.
        embedding_dir: Directory for the final ``.npz`` embedding file.
        pretrain_epochs: Number of contrastive pretraining epochs.
        train_epochs: Number of GAN-CL epochs.
        analysis_dir: Directory for the before/after analysis files.

    Returns:
        Tuple ``(model, metrics, all_final_embeddings, all_final_smiles)``.
        ``metrics`` maps ``'<name>_losses'`` to per-epoch averages. When the
        loader yields nothing, the embeddings are an empty array and the
        SMILES list is empty.
    """
    # Create directories
    os.makedirs(save_dir, exist_ok=True)
    os.makedirs(embedding_dir, exist_ok=True)
    os.makedirs(analysis_dir, exist_ok=True)

    # Process and analyze original dataset (writes files; return value not needed here)
    print("Analyzing dataset before training...")
    process_and_save_dataset(train_loader, output_dir=analysis_dir, prefix='before_training')

    # Initialize model
    model = MolecularGANCL(config).to(device)

    # Initialize optimizers
    optimizer_encoder = torch.optim.Adam(model.encoder.parameters(), lr=3e-4)
    optimizer_generator = torch.optim.Adam(model.generator.parameters(), lr=1e-4)

    # Training metrics
    metrics = {
        'contrastive_losses': [],
        'adversarial_losses': [],
        'similarity_losses': [],
        'total_losses': []
    }

    # Guard the per-epoch averages against an empty loader so the
    # "no embeddings" warning path below stays reachable.
    num_batches = max(len(train_loader), 1)

    # ------------------------------------------------------------------
    # Phase 1: contrastive pretraining
    # ------------------------------------------------------------------
    print("Phase 1: Pretraining Contrastive Learning...")
    for epoch in range(pretrain_epochs):
        contrastive_epoch_loss = 0

        for batch in tqdm(train_loader, desc=f'Pretrain Epoch {epoch+1}/{pretrain_epochs}'):
            batch = batch.to(device)

            # Forward pass (without generator); keys carry no gradient
            query_emb = model.encoder(batch)
            with torch.no_grad():
                key_emb = model.momentum_encoder(batch)

            # Compute contrastive loss
            contrastive_loss = model.memory_queue.compute_contrastive_loss(
                query_emb, key_emb, model.config.temperature
            )

            # Update encoder
            optimizer_encoder.zero_grad()
            contrastive_loss.backward()
            optimizer_encoder.step()

            # Update momentum encoder
            model._momentum_update()

            # Update memory queue
            model.memory_queue.update_queue(key_emb.detach())

            contrastive_epoch_loss += contrastive_loss.item()

        avg_loss = contrastive_epoch_loss / num_batches
        metrics['contrastive_losses'].append(avg_loss)
        print(f'Pretrain Epoch {epoch+1}, Avg Loss: {avg_loss:.4f}')

        # Save pretrained checkpoint
        if (epoch + 1) % 5 == 0:
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_encoder_state_dict': optimizer_encoder.state_dict(),
                'loss': avg_loss,
            }, os.path.join(save_dir, f'pretrain_checkpoint_{epoch+1}.pt'))

    # ------------------------------------------------------------------
    # Phase 2: alternating encoder / generator updates
    # ------------------------------------------------------------------
    print("\nPhase 2: Training GAN-CL...")

    for epoch in range(train_epochs):
        epoch_losses = {
            'contrastive': 0,
            'adversarial': 0,
            'similarity': 0,
            'total': 0
        }

        for batch in tqdm(train_loader, desc=f'Train Epoch {epoch+1}/{train_epochs}'):
            batch = batch.to(device)

            # Step 1: Train Encoder
            optimizer_encoder.zero_grad()

            # Get importance scores from generator (generator is frozen in this step)
            with torch.no_grad():
                node_scores, edge_scores = model.generator(batch)

            # Create perturbed graph
            perturbed_data = model.drop_graph_elements(batch, node_scores, edge_scores)

            # Get embeddings
            query_emb = model.encoder(perturbed_data)
            with torch.no_grad():
                key_emb = model.momentum_encoder(batch)
                original_emb = model.encoder(batch).detach()

            # Compute losses for encoder
            contrastive_loss = model.memory_queue.compute_contrastive_loss(
                query_emb, key_emb, model.config.temperature
            )
            similarity_loss = F.mse_loss(query_emb, original_emb)

            # Total loss for encoder
            encoder_loss = contrastive_loss + 0.1 * similarity_loss

            # Update encoder
            encoder_loss.backward()
            optimizer_encoder.step()

            # Update momentum encoder
            model._momentum_update()

            # Step 2: Train Generator
            optimizer_generator.zero_grad()

            # Get new embeddings for adversarial loss
            node_scores, edge_scores = model.generator(batch)
            perturbed_data = model.drop_graph_elements(batch, node_scores, edge_scores)

            with torch.no_grad():
                original_emb = model.encoder(batch)
            perturbed_emb = model.encoder(perturbed_data)

            # Negated MSE: the generator is rewarded for pushing the
            # perturbed embedding away from the original one.
            adversarial_loss = -F.mse_loss(perturbed_emb, original_emb)

            # Update generator (only the generator optimizer steps here; any
            # encoder gradients are cleared at the next zero_grad above)
            adversarial_loss.backward()
            optimizer_generator.step()

            # Update memory queue
            model.memory_queue.update_queue(key_emb.detach())

            # Update metrics
            epoch_losses['contrastive'] += contrastive_loss.item()
            epoch_losses['adversarial'] += adversarial_loss.item()
            epoch_losses['similarity'] += similarity_loss.item()
            epoch_losses['total'] += encoder_loss.item()

        # Average losses
        for k in epoch_losses:
            epoch_losses[k] /= num_batches
            metrics[f'{k}_losses'].append(epoch_losses[k])

        print(f'Epoch {epoch+1}, Losses: {epoch_losses}')

        # Save checkpoint
        if (epoch + 1) % 10 == 0:
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_encoder_state_dict': optimizer_encoder.state_dict(),
                'optimizer_generator_state_dict': optimizer_generator.state_dict(),
                'losses': epoch_losses,
            }, os.path.join(save_dir, f'gan_cl_checkpoint_{epoch+1}.pt'))

    # ------------------------------------------------------------------
    # Final embeddings, stored alongside their SMILES
    # ------------------------------------------------------------------
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    all_final_embeddings = []
    all_final_smiles = []

    # NOTE: the model is left in eval mode when this function returns.
    model.eval()
    print("Extracting final embeddings...")

    with torch.no_grad():
        for batch in tqdm(train_loader, desc="Final embeddings"):
            batch = batch.to(device)

            # Collect SMILES in batch order so they line up with the rows
            # of the embedding matrix; index each graph only once.
            batch_smiles = []
            for i in range(len(batch)):
                graph = batch[i]
                if hasattr(graph, 'smiles'):
                    batch_smiles.append(graph.smiles)
                else:
                    batch_smiles.append(f"unknown_{i}")

            # Get embeddings
            final_embeddings = model.get_embeddings(batch)

            # Store embeddings and SMILES
            all_final_embeddings.append(final_embeddings.cpu().numpy())
            all_final_smiles.extend(batch_smiles)

    if all_final_embeddings:
        all_final_embeddings = np.vstack(all_final_embeddings)

        # Save embeddings with SMILES
        final_embeddings_path = os.path.join(embedding_dir, f"after_training_{timestamp}.npz")
        np.savez(final_embeddings_path, embeddings=all_final_embeddings, smiles=all_final_smiles)
        print(f"Saved {len(all_final_smiles)} embeddings with SMILES to {final_embeddings_path}")

        # Analyze final embeddings
        print("Analyzing final embeddings...")
        analyze_smiles_list(all_final_smiles, output_dir=analysis_dir, prefix="after_training")
    else:
        print("Warning: No final embeddings to save")
        all_final_embeddings = np.array([])
        all_final_smiles = []

    # Save final metrics
    with open(os.path.join(save_dir, 'training_metrics.json'), 'w') as f:
        json.dump(metrics, f)

    return model, metrics, all_final_embeddings, all_final_smiles


def _smiles_basic_properties(mol):
    """Return whole-molecule descriptors (weight, LogP, counts) for ``mol``."""
    prop = {
        'MW': Descriptors.ExactMolWt(mol),
        'LogP': Descriptors.MolLogP(mol),
        'TPSA': CalcTPSA(mol),
        'NumHAcceptors': Lipinski.NumHAcceptors(mol), 
        'NumHDonors': Lipinski.NumHDonors(mol),
        'NumRotatableBonds': Descriptors.NumRotatableBonds(mol),
        'NumAtoms': mol.GetNumAtoms(),
        'NumHeavyAtoms': mol.GetNumHeavyAtoms(),
        'NumBonds': mol.GetNumBonds(),
    }
    # Chem.GetSSSR returns a plain count on older RDKit releases and the
    # sequence of rings on newer ones; normalise to an integer either way.
    sssr = Chem.GetSSSR(mol)
    prop['NumRings'] = sssr if isinstance(sssr, int) else len(sssr)
    return prop


def _smiles_structural_features(mol, bridged_patt):
    """Return 0/1 flags describing the ring and chain topology of ``mol``."""
    rings = mol.GetRingInfo().AtomRings()
    ring_sets = [set(ring) for ring in rings]

    feat = {
        'Aromatic': 1 if any(atom.GetIsAromatic() for atom in mol.GetAtoms()) else 0,
        'Heterocycles': 1 if any(
            any(mol.GetAtomWithIdx(idx).GetAtomicNum() != 6 for idx in ring)
            for ring in rings
        ) else 0,
        'FusedRings': 0,
        'SpiroRings': 0,
        'BridgedRings': 0,
        'Macrocycles': 0,
        'LinearChain': 1 if mol.GetNumBonds() == mol.GetNumAtoms() - 1 and len(rings) == 0 else 0,
        'Branched': 1 if any(atom.GetDegree() > 2 for atom in mol.GetAtoms()) else 0
    }

    # One pass over ring pairs: more than one shared atom marks a fused
    # pair, exactly one shared atom marks a spiro pair.
    for i in range(len(ring_sets)):
        for j in range(i + 1, len(ring_sets)):
            shared = len(ring_sets[i] & ring_sets[j])
            if shared > 1:
                feat['FusedRings'] = 1
            elif shared == 1:
                feat['SpiroRings'] = 1
        if feat['FusedRings'] and feat['SpiroRings']:
            break

    # NOTE(review): '[D4R]' matches any ring atom with four connections, so
    # it also fires on spiro centres and quaternary ring atoms -- confirm
    # this is the intended "bridged" heuristic.
    if bridged_patt and mol.HasSubstructMatch(bridged_patt):
        feat['BridgedRings'] = 1

    # Rings of eight or more atoms are treated as macrocycles
    if any(len(ring) >= 8 for ring in rings):
        feat['Macrocycles'] = 1

    return feat


def _smiles_functional_groups(mol, compiled_patterns):
    """Return functional-group counts for ``mol``.

    ``compiled_patterns`` maps a group name to a compiled SMARTS query (or
    ``None`` when the SMARTS failed to compile, which yields a count of 0).
    """
    fg = {}

    # Alcohol: oxygen carrying at least one hydrogen and bonded to a carbon
    fg['Alcohol'] = sum(
        1 for atom in mol.GetAtoms()
        if atom.GetAtomicNum() == 8
        and atom.GetTotalNumHs() > 0
        and any(nbr.GetAtomicNum() == 6 for nbr in atom.GetNeighbors())
    )

    # Amine: any nitrogen carrying at least one hydrogen
    fg['Amine'] = sum(
        1 for atom in mol.GetAtoms()
        if atom.GetAtomicNum() == 7 and atom.GetTotalNumHs() > 0
    )

    # Other functional groups via SMARTS patterns
    for name, patt in compiled_patterns.items():
        fg[name] = len(mol.GetSubstructMatches(patt)) if patt else 0

    # Count halogens: F, Cl, Br, I
    fg['Halogen'] = sum(1 for atom in mol.GetAtoms()
                        if atom.GetAtomicNum() in (9, 17, 35, 53))
    return fg


def analyze_smiles_list(smiles_list, output_dir='./analysis', prefix='analysis'):
    """Analyze a list of SMILES strings for molecular properties

    Each parseable SMILES contributes one row to three tables: basic
    properties, structural feature flags and functional-group counts. The
    tables and the list of valid SMILES are written to
    ``<output_dir>/<prefix>_<timestamp>_*``.

    Entries that are not strings, start with ``"unknown_"``, fail to parse,
    or raise while descriptors are computed are reported and skipped. A
    molecule is only recorded once all three descriptor sets have been
    computed, so the tables and their SMILES index always have equal length.

    Args:
        smiles_list: List of SMILES strings
        output_dir: Directory to save analysis results
        prefix: Prefix for output files

    Returns:
        Tuple of DataFrames with properties, features, functional groups
        (all empty when no valid SMILES was found)
    """
    # Make sure directory exists
    os.makedirs(output_dir, exist_ok=True)

    print(f"Analyzing {len(smiles_list)} molecules...")

    # Prepare data storage
    properties_list = []
    features_list = []
    func_groups_list = []
    valid_smiles = []

    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')

    # Compile the SMARTS queries once instead of once per molecule
    bridged_patt = Chem.MolFromSmarts('[D4R]')
    compiled_patterns = {
        name: Chem.MolFromSmarts(smarts)
        for name, smarts in {
            'Carboxyl': 'C(=O)[OH]',
            'Carbonyl': 'C=O',
            'Ether': 'COC',
            'Ester': 'C(=O)OC',
            'Amide': 'C(=O)N'
        }.items()
    }

    # Process each SMILES
    for smiles in tqdm(smiles_list):
        # Skip if not a valid SMILES string (e.g., "unknown_0")
        if not isinstance(smiles, str) or smiles.startswith("unknown_"):
            print(f"Skipping invalid SMILES placeholder: {smiles}")
            continue

        try:
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                print(f"Warning: Invalid SMILES: {smiles}")
                continue

            # Properties first: the GetSSSR call inside precedes the ring
            # queries made by the feature extraction.
            prop = _smiles_basic_properties(mol)
            feat = _smiles_structural_features(mol, bridged_patt)
            fg = _smiles_functional_groups(mol, compiled_patterns)
        except Exception as e:
            print(f"Error processing SMILES {smiles}: {e}")
            continue

        # Record the molecule only after every descriptor succeeded so the
        # four lists cannot drift out of alignment.
        valid_smiles.append(smiles)
        properties_list.append(prop)
        features_list.append(feat)
        func_groups_list.append(fg)

    # If no valid SMILES were found, create empty DataFrames
    if not valid_smiles:
        print("Warning: No valid SMILES found for analysis")
        props_df = pd.DataFrame()
        features_df = pd.DataFrame()
        func_groups_df = pd.DataFrame()
    else:
        props_df = pd.DataFrame(properties_list, index=valid_smiles)
        features_df = pd.DataFrame(features_list, index=valid_smiles)
        func_groups_df = pd.DataFrame(func_groups_list, index=valid_smiles)

    # Save to CSV
    output_prefix = os.path.join(output_dir, f"{prefix}_{timestamp}")

    props_df.to_csv(f"{output_prefix}_properties.csv")
    features_df.to_csv(f"{output_prefix}_features.csv")
    func_groups_df.to_csv(f"{output_prefix}_functional_groups.csv")

    # Save the valid SMILES list for reference
    with open(f"{output_prefix}_valid_smiles.txt", 'w') as f:
        for smiles in valid_smiles:
            f.write(f"{smiles}\n")

    print(f"Analysis saved to {output_prefix}_*.csv")
    print(f"Found {len(valid_smiles)} valid molecules out of {len(smiles_list)} SMILES")

    return props_df, features_df, func_groups_df

In [5]:
def save_embeddings(embeddings, labels, filepath):
    """Pickle ``embeddings`` and their ``labels`` together into ``filepath``.

    The file holds a single dict with the keys ``'embeddings'`` and
    ``'labels'``.
    """
    payload = {'embeddings': embeddings, 'labels': labels}
    with open(filepath, 'wb') as handle:
        pickle.dump(payload, handle)

def save_encoder(encoder, save_path, info=None):
    """Write the encoder's weights plus optional metadata to ``save_path``.

    The saved dict has two entries: ``'encoder_state_dict'`` (from
    ``encoder.state_dict()``) and ``'model_info'`` (``info``, or an empty
    dict when ``info`` is missing or empty).
    """
    model_info = info if info else {}
    torch.save(
        {
            'encoder_state_dict': encoder.state_dict(),
            'model_info': model_info,
        },
        save_path,
    )

def load_encoder(model_path, device='cpu'):
    """Load an encoder saved by ``save_encoder`` and place it on ``device``.

    Args:
        model_path: Path to a checkpoint containing ``'encoder_state_dict'``
            and ``'model_info'``.
        device: Device the checkpoint tensors are mapped to and the returned
            encoder is moved to.

    Returns:
        A ``GraphDiscriminator`` with the saved weights loaded, on ``device``.
        ``hidden_dim`` and ``output_dim`` fall back to 128 when absent from
        ``model_info``; ``node_dim`` and ``edge_dim`` fall back to ``None``.
    """
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    checkpoint = torch.load(model_path, map_location=device)
    model_info = checkpoint['model_info']
    encoder = GraphDiscriminator(
        node_dim=model_info.get('node_dim'),
        edge_dim=model_info.get('edge_dim'),
        hidden_dim=model_info.get('hidden_dim', 128),
        output_dim=model_info.get('output_dim', 128)
    )
    encoder.load_state_dict(checkpoint['encoder_state_dict'])
    # map_location only controls where the checkpoint tensors land; the
    # freshly built module is still on its default device, so move it
    # explicitly to honour the `device` argument.
    return encoder.to(device)
        
def train_gan_cl(train_loader, config, device='cuda', 
                save_dir='./checkpoints', 
                embedding_dir='./embeddings',
                pretrain_epochs=10,
                train_epochs=50):
    """Main training function for GAN-CL with fixed gradient computation

    Runs contrastive pretraining of the encoder, then alternating
    encoder/generator GAN-CL training. Every 10 GAN-CL epochs a full
    checkpoint, a pickle of embeddings with their graphs, and a standalone
    encoder file are written. The encoder with the lowest average total loss
    so far is kept as ``best_encoder_<timestamp>.pt`` and the encoder at the
    end of training as ``final_encoder_<timestamp>.pt``.

    Args:
        train_loader: Iterable of PyTorch Geometric batches.
        config: Model configuration; ``node_dim``, ``edge_dim``,
            ``hidden_dim`` and ``output_dim`` are stored with saved encoders.
        device: Device the model and batches are moved to.
        save_dir: Directory for checkpoints, ``encoders/`` and
            ``training_metrics.json``.
        embedding_dir: Directory for periodic embedding pickles.
        pretrain_epochs: Number of contrastive pretraining epochs.
        train_epochs: Number of GAN-CL epochs.

    Returns:
        Tuple ``(model, metrics)`` where ``metrics`` maps ``'<name>_losses'``
        to the list of per-epoch averages.
    """
    # Create directories
    os.makedirs(save_dir, exist_ok=True)
    os.makedirs(embedding_dir, exist_ok=True)
    encoder_dir = os.path.join(save_dir, 'encoders')
    os.makedirs(encoder_dir, exist_ok=True)

    # One timestamp per run so every artefact of this call shares a suffix,
    # without relying on a notebook-level global being defined.
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')

    # Initialize model
    model = MolecularGANCL(config).to(device)

    # Initialize optimizers
    optimizer_encoder = torch.optim.Adam(model.encoder.parameters(), lr=3e-4)
    optimizer_generator = torch.optim.Adam(model.generator.parameters(), lr=1e-4)

    # Metadata stored alongside every saved encoder
    model_info = {
        'node_dim': config.node_dim,
        'edge_dim': config.edge_dim,
        'hidden_dim': config.hidden_dim,
        'output_dim': config.output_dim,
        'training_config': config.__dict__
    }

    best_loss = float('inf')
    last_total_loss = None  # stays None when train_epochs == 0

    # Training metrics
    metrics = {
        'contrastive_losses': [],
        'adversarial_losses': [],
        'similarity_losses': [],
        'total_losses': []
    }

    # Avoid dividing by zero when the loader yields no batches
    num_batches = max(len(train_loader), 1)

    # ------------------------------------------------------------------
    # Phase 1: contrastive pretraining
    # ------------------------------------------------------------------
    print("Phase 1: Pretraining Contrastive Learning...")
    for epoch in range(pretrain_epochs):
        contrastive_epoch_loss = 0

        for batch in tqdm(train_loader, desc=f'Pretrain Epoch {epoch+1}/{pretrain_epochs}'):
            batch = batch.to(device)

            # Forward pass (without generator); keys carry no gradient
            query_emb = model.encoder(batch)
            with torch.no_grad():
                key_emb = model.momentum_encoder(batch)

            # Compute contrastive loss
            contrastive_loss = model.memory_queue.compute_contrastive_loss(
                query_emb, key_emb, model.config.temperature
            )

            # Update encoder
            optimizer_encoder.zero_grad()
            contrastive_loss.backward()
            optimizer_encoder.step()

            # Update momentum encoder
            model._momentum_update()

            # Update memory queue
            model.memory_queue.update_queue(key_emb.detach())

            contrastive_epoch_loss += contrastive_loss.item()

        avg_loss = contrastive_epoch_loss / num_batches
        metrics['contrastive_losses'].append(avg_loss)
        print(f'Pretrain Epoch {epoch+1}, Avg Loss: {avg_loss:.4f}')

        # Save pretrained checkpoint
        if (epoch + 1) % 5 == 0:
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_encoder_state_dict': optimizer_encoder.state_dict(),
                'loss': avg_loss,
            }, os.path.join(save_dir, f'pretrain_checkpoint_{epoch+1}.pt'))

    # ------------------------------------------------------------------
    # Phase 2: alternating encoder / generator updates
    # ------------------------------------------------------------------
    print("\nPhase 2: Training GAN-CL...")
    for epoch in range(train_epochs):
        epoch_losses = {
            'contrastive': 0,
            'adversarial': 0,
            'similarity': 0,
            'total': 0
        }

        for batch in tqdm(train_loader, desc=f'Train Epoch {epoch+1}/{train_epochs}'):
            batch = batch.to(device)

            # Step 1: Train Encoder
            optimizer_encoder.zero_grad()

            # Get importance scores from generator (generator is frozen in this step)
            with torch.no_grad():
                node_scores, edge_scores = model.generator(batch)

            # Create perturbed graph
            perturbed_data = model.drop_graph_elements(batch, node_scores, edge_scores)

            # Get embeddings
            query_emb = model.encoder(perturbed_data)
            with torch.no_grad():
                key_emb = model.momentum_encoder(batch)
                original_emb = model.encoder(batch).detach()

            # Compute losses for encoder
            contrastive_loss = model.memory_queue.compute_contrastive_loss(
                query_emb, key_emb, model.config.temperature
            )
            similarity_loss = F.mse_loss(query_emb, original_emb)

            # Total loss for encoder
            encoder_loss = contrastive_loss + 0.1 * similarity_loss

            # Update encoder
            encoder_loss.backward()
            optimizer_encoder.step()

            # Update momentum encoder
            model._momentum_update()

            # Step 2: Train Generator
            optimizer_generator.zero_grad()

            # Get new embeddings for adversarial loss
            node_scores, edge_scores = model.generator(batch)
            perturbed_data = model.drop_graph_elements(batch, node_scores, edge_scores)

            with torch.no_grad():
                original_emb = model.encoder(batch)
            perturbed_emb = model.encoder(perturbed_data)

            # Negated MSE: the generator is rewarded for pushing the
            # perturbed embedding away from the original one.
            adversarial_loss = -F.mse_loss(perturbed_emb, original_emb)

            # Update generator
            adversarial_loss.backward()
            optimizer_generator.step()

            # Update memory queue
            model.memory_queue.update_queue(key_emb.detach())

            # Update metrics
            epoch_losses['contrastive'] += contrastive_loss.item()
            epoch_losses['adversarial'] += adversarial_loss.item()
            epoch_losses['similarity'] += similarity_loss.item()
            epoch_losses['total'] += encoder_loss.item()

        # Average losses
        for k in epoch_losses:
            epoch_losses[k] /= num_batches
            metrics[f'{k}_losses'].append(epoch_losses[k])
        last_total_loss = epoch_losses['total']

        print(f'Epoch {epoch+1}, Losses: {epoch_losses}')

        # Periodic artefacts: checkpoint, embeddings and standalone encoder
        if (epoch + 1) % 10 == 0:
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_encoder_state_dict': optimizer_encoder.state_dict(),
                'optimizer_generator_state_dict': optimizer_generator.state_dict(),
                'losses': epoch_losses,
            }, os.path.join(save_dir, f'gan_cl_checkpoint_{epoch+1}.pt'))

            model.eval()
            all_embeddings = []
            all_graphs = []

            with torch.no_grad():
                for batch in train_loader:
                    # Split into individual graphs before the device move so
                    # the pickled labels are one Data object per embedding
                    # row; iterating a batch directly yields its
                    # (attribute name, value) pairs instead of graphs.
                    all_graphs.extend(batch.to_data_list())
                    batch = batch.to(device)
                    embeddings = model.get_embeddings(batch)
                    all_embeddings.append(embeddings.cpu())

            if all_embeddings:
                save_embeddings(
                    torch.cat(all_embeddings, dim=0).numpy(),
                    all_graphs,
                    os.path.join(embedding_dir, f'embeddings_epoch_{epoch+1}_{timestamp}.pkl')
                )

            model.train()

            # Save encoder periodically
            save_encoder(
                model.encoder,
                os.path.join(encoder_dir, f'encoder_epoch_{epoch+1}.pt'),
                {**model_info, 'epoch': epoch + 1, 'loss': epoch_losses['total']}
            )

        # Save best encoder based on total loss (checked every epoch)
        if epoch_losses['total'] < best_loss:
            best_loss = epoch_losses['total']
            save_encoder(
                model.encoder,
                os.path.join(encoder_dir, f'best_encoder_{timestamp}.pt'),
                {**model_info, 'epoch': epoch + 1, 'loss': best_loss}
            )

    # Save final metrics
    with open(os.path.join(save_dir, 'training_metrics.json'), 'w') as f:
        json.dump(metrics, f)

    # Save final encoder
    save_encoder(
        model.encoder,
        os.path.join(encoder_dir, f'final_encoder_{timestamp}.pt'),
        {**model_info, 'epoch': train_epochs, 'loss': last_total_loss}
    )

    return model, metrics


In [6]:
def main(smiles_file="D:\\PhD\\Chapter3\\Unsupervised_GAN_Code\\pubchem-41-clean.txt",
         max_lines=10001,
         batch_size=32):
    """Load SMILES, build the graph dataset and run GAN-CL training.

    Args:
        smiles_file: Text file with one SMILES string per line. The default
            is the original author's local path; pass your own location.
        max_lines: Maximum number of lines read from ``smiles_file``
            (``None`` reads the whole file).
        batch_size: Batch size for the training ``DataLoader``.

    Returns:
        Tuple ``(model, metrics, final_embeddings, final_smiles)``. When no
        molecule could be converted to a graph, all four entries are ``None``
        so callers that unpack the result keep working.
    """
    # Enable anomaly detection during development (slows training down)
    torch.autograd.set_detect_anomaly(True)

    # Seed every RNG that is in scope for reproducible runs
    torch.manual_seed(42)
    np.random.seed(42)
    random.seed(42)

    print("Starting data loading...")
    extractor = MolecularFeatureExtractor()

    dataset = []
    failed_smiles = []

    with open(smiles_file, 'r') as f:
        for i, line in enumerate(f):
            smiles = line.strip()
            # Blank lines are not molecules; skip them rather than handing
            # an empty string to the featurizer.
            if smiles:
                data = extractor.process_molecule(smiles)
                if data is not None:
                    # Store original SMILES in the data object
                    data.smiles = smiles
                    dataset.append(data)
                else:
                    failed_smiles.append(smiles)

            # Limit dataset size for testing
            if max_lines is not None and i + 1 >= max_lines:
                break

    print(f"1. Loaded dataset with {len(dataset)} graphs.")
    print(f"2. Failed SMILES count: {len(failed_smiles)}")

    if not dataset:
        print("No valid graphs generated.")
        # Same arity as the success path so tuple-unpacking callers do not crash
        return None, None, None, None

    # Setup training
    train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    print(f"3. Created DataLoader with {len(train_loader.dataset)} graphs")

    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"4. Using device: {device}")

    # Get configuration based on dataset
    config = get_model_config(dataset)

    # Train model with direct SMILES tracking
    print("5. Starting GAN-CL training with SMILES tracking...")
    model, metrics, final_embeddings, final_smiles = modified_train_gan_cl(
        train_loader, 
        config,
        device=device,
        save_dir='./checkpoints',
        embedding_dir='./embeddings'
    )

    print("6. Training completed!")
    print(f"7. Final results: {len(final_smiles)} embeddings with associated SMILES")

    # At this point, all analysis should be complete
    print("8. Analysis complete. The data can now be used for visualization and comparison.")

    return model, metrics, final_embeddings, final_smiles


# Entry point: run the pipeline and keep its four results as module-level
# names so they can be inspected afterwards.
if __name__ == "__main__":
    model, metrics, embeddings, smiles = main()

Starting data loading...




1. Loaded dataset with 41 graphs.
2. Failed SMILES count: 0
3. Created DataLoader with 41 graphs
4. Using device: cpu
5. Starting GAN-CL training with SMILES tracking...
Analyzing dataset before training...
Extracted 41 SMILES strings from the dataset
Saved SMILES to ./analysis\before_training_smiles_20250301_140313.txt
Analyzing 41 molecules...


100%|█████████████████████████████████████████████████████████████████████████████████| 41/41 [00:00<00:00, 898.48it/s]


Analysis saved to ./analysis\before_training_20250301_140313_*.csv
Found 41 valid molecules out of 41 SMILES
Phase 1: Pretraining Contrastive Learning...


Pretrain Epoch 1/10: 100%|███████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  7.46it/s]


Pretrain Epoch 1, Avg Loss: 1.7414


Pretrain Epoch 2/10: 100%|███████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 10.02it/s]


Pretrain Epoch 2, Avg Loss: 3.8775


Pretrain Epoch 3/10: 100%|███████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  9.28it/s]


Pretrain Epoch 3, Avg Loss: 4.3861


Pretrain Epoch 4/10: 100%|███████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  6.26it/s]


Pretrain Epoch 4, Avg Loss: 4.6080


Pretrain Epoch 5/10: 100%|███████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  5.99it/s]


Pretrain Epoch 5, Avg Loss: 4.8150


Pretrain Epoch 6/10: 100%|███████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  8.12it/s]


Pretrain Epoch 6, Avg Loss: 5.0336


Pretrain Epoch 7/10: 100%|███████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  9.15it/s]


Pretrain Epoch 7, Avg Loss: 5.0830


Pretrain Epoch 8/10: 100%|███████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  7.36it/s]


Pretrain Epoch 8, Avg Loss: 5.2246


Pretrain Epoch 9/10: 100%|███████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  7.62it/s]


Pretrain Epoch 9, Avg Loss: 5.3733


Pretrain Epoch 10/10: 100%|██████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  7.97it/s]


Pretrain Epoch 10, Avg Loss: 5.3945

Phase 2: Training GAN-CL...


Train Epoch 1/50: 100%|██████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  2.28it/s]


Epoch 1, Losses: {'contrastive': 5.613857984542847, 'adversarial': -0.00035622032009996474, 'similarity': 0.0003208323469152674, 'total': 5.613890171051025}


Train Epoch 2/50: 100%|██████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  2.42it/s]


Epoch 2, Losses: {'contrastive': 5.602139234542847, 'adversarial': -0.00041644911107141525, 'similarity': 0.00019903276552213356, 'total': 5.602159261703491}


Train Epoch 3/50: 100%|██████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  2.68it/s]


Epoch 3, Losses: {'contrastive': 5.776403903961182, 'adversarial': -0.0002487349775037728, 'similarity': 0.0004396298172650859, 'total': 5.776448011398315}


Train Epoch 4/50: 100%|██████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  2.87it/s]


Epoch 4, Losses: {'contrastive': 5.786062479019165, 'adversarial': -0.0003670359874377027, 'similarity': 0.00023863442038418725, 'total': 5.786086320877075}


Train Epoch 5/50: 100%|██████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.89it/s]


Epoch 5, Losses: {'contrastive': 5.810200214385986, 'adversarial': -0.0002592162272776477, 'similarity': 0.00024219768965849653, 'total': 5.810224533081055}


Train Epoch 6/50: 100%|██████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.77it/s]


Epoch 6, Losses: {'contrastive': 5.822868347167969, 'adversarial': -0.00019123798119835556, 'similarity': 0.0002630089147714898, 'total': 5.82289457321167}


Train Epoch 7/50: 100%|██████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.89it/s]


Epoch 7, Losses: {'contrastive': 5.9388792514801025, 'adversarial': -0.00025760434073163196, 'similarity': 0.00025341274886159226, 'total': 5.938904523849487}


Train Epoch 8/50: 100%|██████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.88it/s]


Epoch 8, Losses: {'contrastive': 5.943818092346191, 'adversarial': -0.00032640770950820297, 'similarity': 0.0002051167539320886, 'total': 5.943838596343994}


Train Epoch 9/50: 100%|██████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.99it/s]


Epoch 9, Losses: {'contrastive': 6.04681134223938, 'adversarial': -0.0002637593570398167, 'similarity': 0.0003642529190983623, 'total': 6.046847820281982}


Train Epoch 10/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.96it/s]


Epoch 10, Losses: {'contrastive': 5.995693922042847, 'adversarial': -0.0003292931796750054, 'similarity': 0.00029201857978478074, 'total': 5.995723247528076}


Train Epoch 11/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.54it/s]


Epoch 11, Losses: {'contrastive': 6.0029308795928955, 'adversarial': -0.00021977775759296492, 'similarity': 0.0002688825625227764, 'total': 6.002957582473755}


Train Epoch 12/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.89it/s]


Epoch 12, Losses: {'contrastive': 6.124785423278809, 'adversarial': -0.0003300555399619043, 'similarity': 0.0003512067924020812, 'total': 6.124820709228516}


Train Epoch 13/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.42it/s]


Epoch 13, Losses: {'contrastive': 6.159255504608154, 'adversarial': -0.00018644799274625257, 'similarity': 0.0002371559021412395, 'total': 6.159279108047485}


Train Epoch 14/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.02it/s]


Epoch 14, Losses: {'contrastive': 6.194928169250488, 'adversarial': -0.0002618615690153092, 'similarity': 0.00026680598239181563, 'total': 6.194954872131348}


Train Epoch 15/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.71it/s]


Epoch 15, Losses: {'contrastive': 6.206578016281128, 'adversarial': -0.00022813717077951878, 'similarity': 0.0002982755977427587, 'total': 6.206607818603516}


Train Epoch 16/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.60it/s]


Epoch 16, Losses: {'contrastive': 6.226555347442627, 'adversarial': -0.00027981864695902914, 'similarity': 0.00028140388894826174, 'total': 6.226583480834961}


Train Epoch 17/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.89it/s]


Epoch 17, Losses: {'contrastive': 6.169885873794556, 'adversarial': -0.00015823738067410886, 'similarity': 0.00021405993902590126, 'total': 6.169907331466675}


Train Epoch 18/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.88it/s]


Epoch 18, Losses: {'contrastive': 6.2140398025512695, 'adversarial': -0.00021041298168711364, 'similarity': 0.00040543244540458545, 'total': 6.214080333709717}


Train Epoch 19/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  4.08it/s]


Epoch 19, Losses: {'contrastive': 6.357293605804443, 'adversarial': -0.0003338971146149561, 'similarity': 0.0002608720023999922, 'total': 6.357319593429565}


Train Epoch 20/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.77it/s]


Epoch 20, Losses: {'contrastive': 6.403562545776367, 'adversarial': -0.0003783313150051981, 'similarity': 0.00029251378146000206, 'total': 6.403591632843018}


Train Epoch 21/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.81it/s]


Epoch 21, Losses: {'contrastive': 6.357892990112305, 'adversarial': -0.0003155117155984044, 'similarity': 0.000340967089869082, 'total': 6.357927083969116}


Train Epoch 22/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.91it/s]


Epoch 22, Losses: {'contrastive': 6.441501617431641, 'adversarial': -0.00022769557835999876, 'similarity': 0.0003027850325452164, 'total': 6.4415318965911865}


Train Epoch 23/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.68it/s]


Epoch 23, Losses: {'contrastive': 6.385432243347168, 'adversarial': -0.00029272359097376466, 'similarity': 0.00025911908596754074, 'total': 6.38545823097229}


Train Epoch 24/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  4.03it/s]


Epoch 24, Losses: {'contrastive': 6.517337799072266, 'adversarial': -0.00024220493651228026, 'similarity': 0.000358695222530514, 'total': 6.517373561859131}


Train Epoch 25/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.92it/s]


Epoch 25, Losses: {'contrastive': 6.560001373291016, 'adversarial': -0.0003033413377124816, 'similarity': 0.0003411675716051832, 'total': 6.560035467147827}


Train Epoch 26/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  4.11it/s]


Epoch 26, Losses: {'contrastive': 6.443934917449951, 'adversarial': -0.00028237666992936283, 'similarity': 0.00022344150784192607, 'total': 6.443957090377808}


Train Epoch 27/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  4.00it/s]


Epoch 27, Losses: {'contrastive': 6.515333890914917, 'adversarial': -0.0002538797562010586, 'similarity': 0.0002691253539524041, 'total': 6.515360593795776}


Train Epoch 28/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  2.80it/s]


Epoch 28, Losses: {'contrastive': 6.400022983551025, 'adversarial': -0.0002919069229392335, 'similarity': 0.0002425112688797526, 'total': 6.400047302246094}


Train Epoch 29/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  2.29it/s]


Epoch 29, Losses: {'contrastive': 6.483670473098755, 'adversarial': -0.00027928894269280136, 'similarity': 0.0002263242204207927, 'total': 6.4836931228637695}


Train Epoch 30/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.69it/s]


Epoch 30, Losses: {'contrastive': 6.493528842926025, 'adversarial': -0.00021912228839937598, 'similarity': 0.00026391082792542875, 'total': 6.493555307388306}


Train Epoch 31/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.64it/s]


Epoch 31, Losses: {'contrastive': 6.619554042816162, 'adversarial': -0.0002499188922229223, 'similarity': 0.0003219181817257777, 'total': 6.619586229324341}


Train Epoch 32/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  4.04it/s]


Epoch 32, Losses: {'contrastive': 6.604148864746094, 'adversarial': -0.00023477452486986294, 'similarity': 0.00029323616763576865, 'total': 6.604178190231323}


Train Epoch 33/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  4.18it/s]


Epoch 33, Losses: {'contrastive': 6.554488182067871, 'adversarial': -0.0002808406570693478, 'similarity': 0.0002567828050814569, 'total': 6.554513931274414}


Train Epoch 34/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.79it/s]


Epoch 34, Losses: {'contrastive': 6.640108346939087, 'adversarial': -0.0002570294527686201, 'similarity': 0.00034352824150118977, 'total': 6.6401426792144775}


Train Epoch 35/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  4.16it/s]


Epoch 35, Losses: {'contrastive': 6.737628698348999, 'adversarial': -0.000251511788519565, 'similarity': 0.000392052810639143, 'total': 6.737668037414551}


Train Epoch 36/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.83it/s]


Epoch 36, Losses: {'contrastive': 6.689039468765259, 'adversarial': -0.00032414545421488583, 'similarity': 0.0003797968529397622, 'total': 6.689077377319336}


Train Epoch 37/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.67it/s]


Epoch 37, Losses: {'contrastive': 6.685063123703003, 'adversarial': -0.00031769835914019495, 'similarity': 0.00027533697721082717, 'total': 6.6850905418396}


Train Epoch 38/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.84it/s]


Epoch 38, Losses: {'contrastive': 6.646392583847046, 'adversarial': -0.00029649617499671876, 'similarity': 0.0003797268000198528, 'total': 6.646430730819702}


Train Epoch 39/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.90it/s]


Epoch 39, Losses: {'contrastive': 6.721464157104492, 'adversarial': -0.00025778677809285, 'similarity': 0.0003120310138911009, 'total': 6.7214953899383545}


Train Epoch 40/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  4.05it/s]


Epoch 40, Losses: {'contrastive': 6.763017177581787, 'adversarial': -0.00026221504231216386, 'similarity': 0.0003566572268027812, 'total': 6.763052701950073}


Train Epoch 41/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  4.02it/s]


Epoch 41, Losses: {'contrastive': 6.728412866592407, 'adversarial': -0.00040262505353894085, 'similarity': 0.00045368142309598625, 'total': 6.7284581661224365}


Train Epoch 42/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.98it/s]


Epoch 42, Losses: {'contrastive': 6.72764253616333, 'adversarial': -0.0002239174209535122, 'similarity': 0.00024661546194693074, 'total': 6.7276670932769775}


Train Epoch 43/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.82it/s]


Epoch 43, Losses: {'contrastive': 6.7495832443237305, 'adversarial': -0.00024061810108833015, 'similarity': 0.00031625023984815925, 'total': 6.749614715576172}


Train Epoch 44/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  4.05it/s]


Epoch 44, Losses: {'contrastive': 6.719606161117554, 'adversarial': -0.0002649614034453407, 'similarity': 0.00024376938381465152, 'total': 6.719630718231201}


Train Epoch 45/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.97it/s]


Epoch 45, Losses: {'contrastive': 6.876022577285767, 'adversarial': -0.00033359423105139285, 'similarity': 0.0002836329076671973, 'total': 6.87605094909668}


Train Epoch 46/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.64it/s]


Epoch 46, Losses: {'contrastive': 6.817465305328369, 'adversarial': -0.0003284262929810211, 'similarity': 0.00038551961188204587, 'total': 6.817503929138184}


Train Epoch 47/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  4.22it/s]


Epoch 47, Losses: {'contrastive': 6.7435901165008545, 'adversarial': -0.00023320678155869246, 'similarity': 0.0003012162778759375, 'total': 6.743620157241821}


Train Epoch 48/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.78it/s]


Epoch 48, Losses: {'contrastive': 6.757992267608643, 'adversarial': -0.0003213839663658291, 'similarity': 0.00033199235622305423, 'total': 6.758025407791138}


Train Epoch 49/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.52it/s]


Epoch 49, Losses: {'contrastive': 6.683620929718018, 'adversarial': -0.0003573875583242625, 'similarity': 0.00019149315630784258, 'total': 6.683640003204346}


Train Epoch 50/50: 100%|█████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  2.89it/s]


Epoch 50, Losses: {'contrastive': 6.729233026504517, 'adversarial': -0.00024506183399353176, 'similarity': 0.00034168182173743844, 'total': 6.729267120361328}
Extracting final embeddings...


Final embeddings: 100%|██████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 83.70it/s]


Saved 41 embeddings with SMILES to ./embeddings\after_training_20250301_140344.npz
Analyzing final embeddings...
Analyzing 41 molecules...


100%|█████████████████████████████████████████████████████████████████████████████████| 41/41 [00:00<00:00, 771.78it/s]

Analysis saved to ./analysis\after_training_20250301_140344_*.csv
Found 41 valid molecules out of 41 SMILES
6. Training completed!
7. Final results: 41 embeddings with associated SMILES
8. Analysis complete. The data can now be used for visualization and comparison.



