In [1]:
import torch
from rdkit import Chem
from rdkit.Chem import AllChem, Descriptors
from typing import Tuple
import numpy as np
from torch_geometric.data import Data
import random
from torch_geometric.loader import DataLoader
from torch_geometric.data import Data
from rdkit import RDLogger
from rdkit.Chem import RemoveHs
from datetime import datetime

import torch
import numpy as np
from torch_geometric.data import DataLoader
import os
import json
from tqdm import tqdm
import pickle
from datetime import datetime

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool, MessagePassing
from typing import Tuple, List, Optional
import copy
from dataclasses import dataclass

import numpy as np
import torch
from rdkit import Chem
from rdkit.Chem import AllChem, Descriptors
from rdkit.Chem.EnumerateStereoisomers import EnumerateStereoisomers, StereoEnumerationOptions
from tqdm import tqdm
import random
import copy
from torch_geometric.data import Data, Batch, DataLoader
import os
import pickle
from datetime import datetime

# Suppress RDKit warnings
RDLogger.DisableLog('rdApp.warning')
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')



In [2]:
class MolecularFeatureExtractor:
    def __init__(self):
        self.atom_list = list(range(1, 119))
        self.chirality_list = [
            Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
            Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
            Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
            Chem.rdchem.ChiralType.CHI_OTHER
        ]
        self.bond_list = [
            Chem.rdchem.BondType.SINGLE,
            Chem.rdchem.BondType.DOUBLE, 
            Chem.rdchem.BondType.TRIPLE,
            Chem.rdchem.BondType.AROMATIC
        ]
        self.bonddir_list = [
            Chem.rdchem.BondDir.NONE,
            Chem.rdchem.BondDir.ENDUPRIGHT,
            Chem.rdchem.BondDir.ENDDOWNRIGHT
        ]

    def calc_atom_features(self, atom: Chem.Atom) -> Tuple[list, list]:
        """Calculate atom features with better error handling"""
        try:
            # Basic features
            atom_feat = [
                self.atom_list.index(atom.GetAtomicNum()),
                self.chirality_list.index(atom.GetChiralTag())
            ]

            # Physical features with error handling
            phys_feat = []
            
            # Molecular weight contribution
            try:
                contrib_mw = Descriptors.ExactMolWt(Chem.MolFromSmiles(f'[{atom.GetSymbol()}]'))
                phys_feat.append(contrib_mw)
            except:
                phys_feat.append(0.0)
                
            # LogP contribution    
            try:
                contrib_logp = Descriptors.MolLogP(Chem.MolFromSmiles(f'[{atom.GetSymbol()}]'))
                phys_feat.append(contrib_logp)
            except:
                phys_feat.append(0.0)
                
            # Add other physical properties
            phys_feat.extend([
                atom.GetFormalCharge(),
                int(atom.GetHybridization()),
                int(atom.GetIsAromatic()),
                atom.GetTotalNumHs(),
                atom.GetTotalValence(),
                atom.GetDegree()
            ])
            
            return atom_feat, phys_feat
            
        except Exception as e:
            print(f"Error calculating atom features: {e}")
            return [0, 0], [0.0] * 9

    def get_atom_features(self, mol: Chem.Mol) -> Tuple[torch.Tensor, torch.Tensor]:
        """Extract atom features for the whole molecule"""
        atom_feats = []
        phys_feats = []
        
        if mol is None:
            return torch.tensor([[0, 0]], dtype=torch.long), torch.tensor([[0.0] * 9], dtype=torch.float)
            
        for atom in mol.GetAtoms():
            atom_feat, phys_feat = self.calc_atom_features(atom)
            atom_feats.append(atom_feat)
            phys_feats.append(phys_feat)

        x = torch.tensor(atom_feats, dtype=torch.long)
        phys = torch.tensor(phys_feats, dtype=torch.float)
        
        return x, phys
    
    def remove_unbonded_hydrogens(mol):
        params = Chem.RemoveHsParameters()
        params.removeDegreeZero = True
        mol = Chem.RemoveHs(mol, params)
        return mol


    def get_bond_features(self, mol: Chem.Mol) -> Tuple[torch.Tensor, torch.Tensor]:
        """Extract bond features with better error handling"""
        if mol is None:
            return torch.tensor([[0], [0]], dtype=torch.long), torch.tensor([[0.0] * 5], dtype=torch.float)
            
        row, col, edge_feat = [], [], []
        
        for bond in mol.GetBonds():
            try:
                start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
                
                # Add edges in both directions
                row += [start, end]
                col += [end, start]
                
                # Bond features
                bond_type = self.bond_list.index(bond.GetBondType())
                bond_dir = self.bonddir_list.index(bond.GetBondDir())
                
                # Calculate additional properties
                feat = [
                    bond_type,
                    bond_dir,
                    int(bond.GetIsConjugated()),
                    int(self._is_rotatable(bond)),
                    self._get_bond_length(mol, start, end)
                ]
                
                edge_feat.extend([feat, feat])
                
            except Exception as e:
                print(f"Error processing bond: {e}")
                continue

        if not row:  # If no valid bonds were processed
            return torch.tensor([[0], [0]], dtype=torch.long), torch.tensor([[0.0] * 5], dtype=torch.float)

        edge_index = torch.tensor([row, col], dtype=torch.long)
        edge_attr = torch.tensor(edge_feat, dtype=torch.float)
        
        return edge_index, edge_attr

    def _is_rotatable(self, bond: Chem.Bond) -> bool:
        """Check if bond is rotatable"""
        return (bond.GetBondType() == Chem.rdchem.BondType.SINGLE and 
                not bond.IsInRing() and
                len(bond.GetBeginAtom().GetNeighbors()) > 1 and
                len(bond.GetEndAtom().GetNeighbors()) > 1)

    def _get_bond_length(self, mol: Chem.Mol, start: int, end: int) -> float:
        """Get bond length with error handling"""
        try:
            conf = mol.GetConformer()
            if conf.Is3D():
                return Chem.rdMolTransforms.GetBondLength(conf, start, end)
        except:
            pass
        return 0.0

    def process_molecule(self, smiles: str) -> Data:
        """Process SMILES string to graph data"""
        try:
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                print(f"Invalid SMILES: {smiles}")
                return None  # Skip invalid molecules
            mol = RemoveHs(mol)

            # Add explicit hydrogens
            mol = Chem.AddHs(mol, addCoords=True)

            # Sanitize molecule
            Chem.SanitizeMol(mol)

            # Check if the molecule has atoms
            if mol.GetNumAtoms() == 0:
                print("Molecule has no atoms, skipping.")
                return None

            # Generate 3D coordinates
            if not mol.GetNumConformers():
                status = AllChem.EmbedMolecule(mol, AllChem.ETKDG())
                if status != 0:
                    print("Failed to generate 3D conformer")
                    return None  # Skip failed molecules

                # Try MMFF or UFF optimization
                try:
                    AllChem.MMFFOptimizeMolecule(mol)
                except:
                    AllChem.UFFOptimizeMolecule(mol)

            # Extract features
            x_cat, x_phys = self.get_atom_features(mol)
            edge_index, edge_attr = self.get_bond_features(mol)

            # Create data object with SMILES
            data = Data(
                x_cat=x_cat, 
                x_phys=x_phys,
                edge_index=edge_index, 
                edge_attr=edge_attr,
                num_nodes=x_cat.size(0)
            )

            # Store the SMILES as an attribute
            data.smiles = smiles

            return data

        except Exception as e:
            print(f"Error processing molecule {smiles}: {e}")
            return None

In [3]:
import numpy as np
import torch
from rdkit import Chem
from rdkit.Chem import AllChem, Descriptors
from rdkit.Chem.EnumerateStereoisomers import EnumerateStereoisomers, StereoEnumerationOptions
from tqdm import tqdm
import random
import copy
from torch_geometric.data import Data, Batch, DataLoader
import os
import pickle
from rdkit.Chem import MolStandardize
from collections import defaultdict

class MolecularAugmenter:
    """Class for applying chemical augmentations to molecules"""
    
    def __init__(self, smiles_list, feature_extractor=None, seed=42):
        """
        Args:
            smiles_list: List of SMILES strings to augment
            feature_extractor: The feature extractor used to convert SMILES to graph data
            seed: Random seed for reproducibility
        """
        self.smiles_list = smiles_list
        self.feature_extractor = feature_extractor
        random.seed(seed)
        np.random.seed(seed)
        
        # Don't use MolStandardize since it's not available in this version
        # We'll implement our own simple standardization function
            
    def standardize_mol(self, mol):
        """Standardize a molecule to ensure consistency"""
        if mol is None:
            return None
        try:
            # Basic standardization
            mol = Chem.RemoveHs(mol)
            Chem.SanitizeMol(mol)
            return mol
        except:
            return None
            
    def add_remove_hydrogen(self, mol):
        """Add or remove hydrogen atoms from the molecule"""
        if mol is None:
            return None
        
        try:
            # 50% chance to add hydrogens, 50% chance to remove them
            if random.random() < 0.5:
                # Add hydrogens
                mol = Chem.AddHs(mol)
            else:
                # Remove hydrogens
                mol = Chem.RemoveHs(mol)
                
            return mol
        except:
            return mol
            
    def change_bond_order(self, mol):
        """Change the order of a random bond (single↔double, triple↔double)"""
        if mol is None:
            return None
            
        try:
            # Make a copy to avoid modifying the original
            mol_copy = Chem.Mol(mol)
            
            # Get all bonds
            bonds = list(mol_copy.GetBonds())
            if not bonds:
                return mol_copy
                
            # Select a random bond
            bond = random.choice(bonds)
            bond_idx = bond.GetIdx()
            bond_type = bond.GetBondType()
            
            # Change bond type
            edit_mol = Chem.EditableMol(mol_copy)
            
            if bond_type == Chem.BondType.SINGLE:
                edit_mol.ReplaceBond(bond_idx, Chem.BondType.DOUBLE)
            elif bond_type == Chem.BondType.DOUBLE:
                # 50% chance to go to single, 50% to triple
                new_type = random.choice([Chem.BondType.SINGLE, Chem.BondType.TRIPLE])
                edit_mol.ReplaceBond(bond_idx, new_type)
            elif bond_type == Chem.BondType.TRIPLE:
                edit_mol.ReplaceBond(bond_idx, Chem.BondType.DOUBLE)
            
            mod_mol = edit_mol.GetMol()
            
            # Check if the molecule is valid
            try:
                Chem.SanitizeMol(mod_mol)
                return mod_mol
            except:
                return mol_copy
        except:
            return mol
            
    def enumerate_stereoisomer(self, mol):
        """Generate a random stereoisomer of the molecule"""
        if mol is None:
            return None
            
        try:
            # Set up stereoisomer options
            opts = StereoEnumerationOptions(tryEmbedding=True, maxIsomers=5)
            
            # Get all possible stereoisomers
            isomers = list(EnumerateStereoisomers(mol, options=opts))
            
            if isomers:
                # Return a random stereoisomer
                return random.choice(isomers)
            else:
                return mol
        except:
            return mol
            
    def mutate_functional_group(self, mol):
        """Replace a functional group with another similar one"""
        if mol is None:
            return None
            
        # Define some functional group transformations
        transformations = [
            # Alcohol to ether
            ('[OH]', '[OC]'),
            # Carboxylic acid to ester
            ('C(=O)[OH]', 'C(=O)OC'),
            # Amide to ester
            ('C(=O)[NH2]', 'C(=O)OC'),
            # Amine to amide
            ('[NH2]', '[NH]C(=O)C'),
            # Ketone to aldehyde
            ('C(=O)C', 'C(=O)[H]'),
            # Ether to alcohol
            ('COC', 'CO'),
            # Nitro to amine
            ('[N+](=O)[O-]', '[NH2]')
        ]
        
        try:
            # Choose a random transformation
            old_pattern, new_pattern = random.choice(transformations)
            
            # Apply the transformation using SMILES replacement
            smiles = Chem.MolToSmiles(mol)
            old_pattern_mol = Chem.MolFromSmarts(old_pattern)
            
            if mol.HasSubstructMatch(old_pattern_mol):
                # Create the replacement
                modified_smiles = smiles.replace(old_pattern, new_pattern, 1)
                new_mol = Chem.MolFromSmiles(modified_smiles)
                
                if new_mol is not None:
                    return new_mol
            
            return mol
        except:
            return mol
            
    def change_atom(self, mol):
        """Substitute an atom with another from the same group"""
        if mol is None:
            return None
            
        # Define atom replacements (elements from same group tend to have similar properties)
        replacements = {
            'C': ['Si'],  # Carbon -> Silicon
            'N': ['P'],   # Nitrogen -> Phosphorus
            'O': ['S'],   # Oxygen -> Sulfur
            'F': ['Cl', 'Br'],  # Halogen replacements
            'Cl': ['F', 'Br'],
            'Br': ['F', 'Cl'],
            'S': ['O', 'Se']   # Sulfur -> Oxygen or Selenium
        }
        
        try:
            # Make an editable copy
            edit_mol = Chem.EditableMol(Chem.Mol(mol))
            atoms = list(mol.GetAtoms())
            
            if not atoms:
                return mol
                
            # Select a random atom
            atom = random.choice(atoms)
            symbol = atom.GetSymbol()
            
            # Check if we have a replacement for this atom
            if symbol in replacements and random.random() < 0.7:  # 70% chance to replace
                new_symbol = random.choice(replacements[symbol])
                new_atomic_num = Chem.GetPeriodicTable().GetAtomicNumber(new_symbol)
                
                # Replace the atom
                edit_mol.ReplaceAtom(atom.GetIdx(), Chem.Atom(new_atomic_num))
                
                new_mol = edit_mol.GetMol()
                
                # Check if valid
                try:
                    Chem.SanitizeMol(new_mol)
                    return new_mol
                except:
                    return mol
            
            return mol
        except:
            return mol
    
    def add_ring_substituent(self, mol):
        """Add a small substituent to a ring"""
        if mol is None:
            return None
            
        # Common substituents
        substituents = [
            ('c1ccccc1', 'c1ccccc1[C]'),  # Add methyl to benzene
            ('c1ccccc1', 'c1ccccc1[O]'),  # Add hydroxyl to benzene
            ('c1ccccc1', 'c1ccccc1[F]'),  # Add fluorine to benzene
            ('c1ccccc1', 'c1ccccc1[Cl]'),  # Add chlorine to benzene
            ('c1ccccc1', 'c1ccccc1[N]')   # Add amino to benzene
        ]
        
        try:
            # Choose a random substituent
            old_pattern, new_pattern = random.choice(substituents)
            
            # Apply the transformation using SMILES replacement
            smiles = Chem.MolToSmiles(mol)
            old_pattern_mol = Chem.MolFromSmarts(old_pattern)
            
            if mol.HasSubstructMatch(old_pattern_mol):
                # Create the replacement
                modified_smiles = smiles.replace(old_pattern, new_pattern, 1)
                new_mol = Chem.MolFromSmiles(modified_smiles)
                
                if new_mol is not None:
                    return new_mol
            
            return mol
        except:
            return mol
    
    def apply_random_augmentation(self, mol):
        """Apply a random augmentation from the available methods"""
        if mol is None:
            return None
            
        # Define augmentation methods with their probabilities
        augmentations = [
            (self.add_remove_hydrogen, 0.1),
            (self.change_bond_order, 0.2),
            (self.enumerate_stereoisomer, 0.15),
            (self.mutate_functional_group, 0.25),
            (self.change_atom, 0.2),
            (self.add_ring_substituent, 0.1)
        ]
        
        # Normalize probabilities
        total_prob = sum(prob for _, prob in augmentations)
        normalized_probs = [prob / total_prob for _, prob in augmentations]
        
        # Select augmentation based on probability
        aug_func = np.random.choice([aug for aug, _ in augmentations], p=normalized_probs)
        
        # Apply the selected augmentation
        augmented_mol = aug_func(mol)
        
        # Standardize the result
        return self.standardize_mol(augmented_mol)
    
    def generate_augmented_dataset(self, num_augmentations=1):  
        """Generate augmented dataset from original SMILES"""
        augmented_data_list = []

        for i, smiles in enumerate(tqdm(self.smiles_list, desc="Generating augmentations")):
            # Convert to molecule
            mol = Chem.MolFromSmiles(smiles)

            if mol is None:
                continue

            # Add the original molecule
            if self.feature_extractor is not None:
                original_data = self.feature_extractor.process_molecule(smiles)
                if original_data is not None:
                    # Store original SMILES as attribute
                    original_data.smiles = smiles
                    # Add the original_idx attribute to ALL objects for consistency
                    original_data.original_idx = i
                    augmented_data_list.append(original_data)

            # Generate augmentations
            augmented_mol = self.apply_random_augmentation(mol)

            if augmented_mol is not None:
                # Convert back to SMILES
                aug_smiles = Chem.MolToSmiles(augmented_mol)

                # Convert to graph data
                if self.feature_extractor is not None:
                    aug_data = self.feature_extractor.process_molecule(aug_smiles)
                    if aug_data is not None:
                        # Store SMILES as attribute
                        aug_data.smiles = aug_smiles
                        # Store the original molecule index for reference
                        aug_data.original_idx = i
                        augmented_data_list.append(aug_data)

        return augmented_data_list

def train_encoder_with_manual_augmentations(dataset, feature_extractor, output_dir='./embeddings', 
                                           batch_size=32, epochs=50, lr=1e-4, hidden_dim=128, 
                                           output_dim=128, device='cuda'):
    """Train encoder using manual augmentations"""
    from torch_geometric.nn import GCNConv, global_mean_pool
    import torch.nn as nn
    import torch.nn.functional as F
    import torch.optim as optim
    from torch_geometric.data import Batch
    
    # Create output directory
    os.makedirs(output_dir, exist_ok=True)
    
    # Extract SMILES strings for augmentation
    smiles_list = []
    for i, data in enumerate(dataset):
        mol_id = f"molecule_{i}"
        # If you have SMILES in the dataset, extract them
        if hasattr(data, 'smiles'):
            smiles_list.append(data.smiles)
        else:
            # Otherwise, try to convert PyG data to SMILES using RDKit
            smiles = convert_pyg_to_smiles(data)
            smiles_list.append(smiles)
    
    # Create augmenter and generate augmented dataset
    augmenter = MolecularAugmenter(smiles_list, feature_extractor)
    augmented_dataset = augmenter.generate_augmented_dataset(num_augmentations=1)
    
    print(f"Generated {len(augmented_dataset)} augmented molecules (including originals)")
    
    # Define custom collate function to handle additional attributes
    def custom_collate(data_list):
        keys_to_exclude = ['smiles', 'original_idx']  # Add any custom attributes here
        
        # Store the excluded attributes
        excluded_data = {key: [getattr(data, key, None) for data in data_list] for key in keys_to_exclude}
        
        # Create batch without excluded keys
        batch = Batch.from_data_list(data_list)
        
        # Add back the excluded attributes as lists
        for key, values in excluded_data.items():
            setattr(batch, key, values)
            
        return batch
    
    # Define the encoder model
    class GraphEncoder(nn.Module):
        def __init__(self, node_dim, edge_dim, hidden_dim=128, output_dim=128):
            super().__init__()
            
            # Feature encoding layers
            self.node_encoder = nn.Sequential(
                nn.Linear(node_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim)
            )
            
            self.edge_encoder = nn.Sequential(
                nn.Linear(edge_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, hidden_dim)
            )
            
            # Graph convolution layers
            self.conv1 = GCNConv(hidden_dim, hidden_dim)
            self.conv2 = GCNConv(hidden_dim, hidden_dim)
            self.conv3 = GCNConv(hidden_dim, output_dim)
            
            # Projection head for contrastive learning
            self.projection = nn.Sequential(
                nn.Linear(output_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, output_dim)
            )
        
        def normalize_features(self, x_cat, x_phys):
            """Normalize categorical and physical features separately"""
            # Convert categorical features to one-hot
            x_cat = x_cat.float()
            
            # Normalize physical features
            x_phys = x_phys.float()
            if x_phys.size(0) > 1:  # Only normalize if we have more than one sample
                x_phys = (x_phys - x_phys.mean(0)) / (x_phys.std(0) + 1e-5)
                
            return x_cat, x_phys
        
        def forward(self, data):
            # Normalize features
            x_cat, x_phys = self.normalize_features(data.x_cat, data.x_phys)
            
            # Concatenate features
            x = torch.cat([x_cat, x_phys], dim=-1)
            
            edge_index = data.edge_index
            edge_attr = data.edge_attr.float()
            batch = data.batch
            
            # Initial feature encoding
            x = self.node_encoder(x)
            edge_attr = self.edge_encoder(edge_attr)
            
            # Graph convolutions
            x = F.relu(self.conv1(x, edge_index))
            x = F.relu(self.conv2(x, edge_index))
            x = self.conv3(x, edge_index)
            
            # Global pooling
            x = global_mean_pool(x, batch)
            
            # Projection
            x = self.projection(x)
            
            return x
    
    # Create data loader with custom collate function
    train_loader = DataLoader(augmented_dataset, batch_size=batch_size, shuffle=True, collate_fn=custom_collate)
    
    # Calculate input dimensions
    sample_data = augmented_dataset[0]
    node_dim = sample_data.x_cat.shape[1] + sample_data.x_phys.shape[1]
    edge_dim = sample_data.edge_attr.shape[1]
    
    # Initialize model
    model = GraphEncoder(node_dim, edge_dim, hidden_dim, output_dim).to(device)
    
    # Define optimizer
    optimizer = optim.Adam(model.parameters(), lr=lr)
    
    # Define contrastive loss function
    def contrastive_loss(query, key, temperature=0.07):
        # Normalize embeddings
        query = F.normalize(query, dim=1)
        key = F.normalize(key, dim=1)
        
        # Calculate similarity
        logits = torch.mm(query, key.T) / temperature
        
        # Set labels to be the positive samples (diagonal elements)
        labels = torch.arange(logits.shape[0], device=device)
        
        return F.cross_entropy(logits, labels)
    
    # Training loop
    print("Training encoder with manual augmentations...")
    model.train()
    
    for epoch in range(epochs):
        epoch_loss = 0
        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
            batch = batch.to(device)
            
            # Forward pass
            embeddings = model(batch)
            
            # Create positive pairs by splitting the batch
            batch_size = embeddings.size(0)
            half_size = batch_size // 2
            
            if half_size > 0:
                # Split embeddings
                query_emb = embeddings[:half_size]
                key_emb = embeddings[half_size:2*half_size]
                
                # Calculate loss for this mini-batch
                loss = contrastive_loss(query_emb, key_emb)
                
                # Backward pass and optimization
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                
                epoch_loss += loss.item()
        
        avg_loss = epoch_loss / len(train_loader)
        print(f'Epoch {epoch+1}, Avg Loss: {avg_loss:.4f}')
        
        # Save checkpoint every 10 epochs
        if (epoch + 1) % 10 == 0:
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': avg_loss,
            }, os.path.join(output_dir, f'manual_encoder_checkpoint_{epoch+1}.pt'))
    
    # Extract embeddings for the original dataset
    model.eval()
    
    # Create a custom collate function for the original dataset too
    def original_collate(data_list):
        keys_to_exclude = ['smiles']  # Only exclude 'smiles' for original data
        
        # Store the excluded attributes
        excluded_data = {key: [getattr(data, key, None) for data in data_list] for key in keys_to_exclude}
        
        # Create batch without excluded keys
        batch = Batch.from_data_list(data_list)
        
        # Add back the excluded attributes as lists
        for key, values in excluded_data.items():
            setattr(batch, key, values)
            
        return batch
    
    original_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, collate_fn=original_collate)
    all_embeddings = []
    
    with torch.no_grad():
        for batch in tqdm(original_loader, desc="Extracting embeddings"):
            batch = batch.to(device)
            embeddings = model(batch)
            all_embeddings.append(embeddings.cpu().numpy())
    
    all_embeddings = np.vstack(all_embeddings)
    
    # Save the embeddings
    embeddings_data = {
        'embeddings': all_embeddings,
        'molecule_data': [{
            'smiles': getattr(data, 'smiles', ''),
            'num_nodes': data.num_nodes,
            'edge_index': data.edge_index.tolist() if hasattr(data, 'edge_index') else None,
            'x_cat': data.x_cat.tolist() if hasattr(data, 'x_cat') else None,
            'x_phys': data.x_phys.tolist() if hasattr(data, 'x_phys') else None,
            'edge_attr': data.edge_attr.tolist() if hasattr(data, 'edge_attr') else None
        } for data in dataset]
    }
    
    with open(os.path.join(output_dir, 'manual_embeddings.pkl'), 'wb') as f:
        pickle.dump(embeddings_data, f)
    
    print(f"Manual embeddings saved to {os.path.join(output_dir, 'manual_embeddings.pkl')}")
    
    return model, all_embeddings

def convert_pyg_to_smiles(pyg_data):
    """Attempt to convert PyG data to SMILES string using RDKit"""
    try:
        # Extract node features and edge information
        x_cat = pyg_data.x_cat.numpy() if hasattr(pyg_data, 'x_cat') else None
        edge_index = pyg_data.edge_index.numpy() if hasattr(pyg_data, 'edge_index') else None
        
        if x_cat is None or edge_index is None:
            return ''
        
        # Create empty RDKit molecule
        mol = Chem.RWMol()
        
        # Add atoms (assuming first column of x_cat is atomic number)
        for atom_feat in x_cat:
            atom_num = int(atom_feat[0])
            atom = Chem.Atom(atom_num)
            mol.AddAtom(atom)
        
        # Add bonds (assuming edges represent single bonds for simplicity)
        for i in range(edge_index.shape[1]):
            src, dst = edge_index[0, i], edge_index[1, i]
            # Avoid adding duplicate bonds (only add if src < dst)
            if src < dst:
                mol.AddBond(int(src), int(dst), Chem.BondType.SINGLE)
        
        # Convert to molecule and get SMILES
        mol = mol.GetMol()
        return Chem.MolToSmiles(mol)
    except Exception as e:
        print(f"Error converting PyG data to SMILES: {e}")
        return ''



In [4]:
def extract_molecule_metadata(dataset):
    """Extract metadata from PyG graph data without relying on SMILES strings"""
    from rdkit import Chem
    from rdkit.Chem import Descriptors
    from tqdm import tqdm
    import numpy as np
    import torch
    import networkx as nx
    from collections import defaultdict
    
    metadata = []
    
    for i, data in enumerate(tqdm(dataset, desc="Extracting molecule metadata")):
        # Set graph ID
        mol_id = f"molecule_{i}"
        
        # Initialize empty dictionaries for metadata
        properties = {}
        features = {}
        functional_groups = {}
        ring_info = {"ring_counts": {}, "ring_sizes": {}}
        
        # Extract basic graph properties directly from the PyG data
        if hasattr(data, 'num_nodes') and hasattr(data, 'edge_index'):
            try:
                # Convert to networkx graph for analysis
                G = to_networkx(data)
                
                # Calculate graph-level properties
                num_edges = data.edge_index.size(1) // 2  # Count unique edges
                properties = {
                    "num_nodes": data.num_nodes,
                    "num_edges": num_edges,
                    "avg_node_degree": 2 * num_edges / data.num_nodes if data.num_nodes > 0 else 0
                }
                
                # Calculate average path length if graph is connected
                if nx.is_connected(G):
                    try:
                        properties["avg_path_length"] = nx.average_shortest_path_length(G)
                    except:
                        properties["avg_path_length"] = 0.0
                else:
                    properties["avg_path_length"] = 0.0
                
                # Add more sophisticated graph properties
                try:
                    properties["clustering_coefficient"] = nx.average_clustering(G)
                except:
                    properties["clustering_coefficient"] = 0.0
                
                try:
                    properties["graph_diameter"] = nx.diameter(G) if nx.is_connected(G) else 0
                except:
                    properties["graph_diameter"] = 0

                try:
                    properties["assortativity"] = nx.degree_assortativity_coefficient(G)
                except:
                    properties["assortativity"] = 0.0
                
                # Graph features
                features = {
                    "is_connected": nx.is_connected(G),
                    "num_connected_components": nx.number_connected_components(G),
                    "has_cycles": not nx.is_tree(G),
                    "max_degree": max(dict(G.degree()).values()) if G.number_of_nodes() > 0 else 0,
                    "density": nx.density(G),
                    "is_bipartite": nx.is_bipartite(G) if G.number_of_nodes() > 0 else False
                }
                
                # Get centrality measures
                if G.number_of_nodes() > 0:
                    try:
                        degree_centrality = nx.degree_centrality(G)
                        features["max_centrality"] = max(degree_centrality.values()) if degree_centrality else 0
                        features["avg_centrality"] = sum(degree_centrality.values()) / len(degree_centrality) if degree_centrality else 0
                    except:
                        features["max_centrality"] = 0
                        features["avg_centrality"] = 0
                else:
                    features["max_centrality"] = 0
                    features["avg_centrality"] = 0
                
                # Analyze node features if available
                if hasattr(data, 'x_cat') and hasattr(data, 'x_phys'):
                    # Atomic element distribution (from x_cat)
                    atom_types = {}
                    if data.x_cat.size(1) > 0:
                        for i in range(data.num_nodes):
                            atom_type = int(data.x_cat[i, 0].item())
                            atom_types[atom_type] = atom_types.get(atom_type, 0) + 1
                    
                    features["atom_type_distribution"] = atom_types
                    
                    # Physical property statistics (from x_phys)
                    if data.x_phys.size(1) > 0:
                        phys_means = data.x_phys.mean(dim=0).tolist() 
                        phys_stds = data.x_phys.std(dim=0).tolist()
                        
                        # Map indices to meaningful property names for the first few common properties
                        phys_prop_names = ['contrib_mw', 'contrib_logp', 'formal_charge', 
                                        'hybridization', 'is_aromatic', 'num_h', 'valence', 'degree']
                        
                        for idx, name in enumerate(phys_prop_names):
                            if idx < len(phys_means):
                                properties[f"avg_{name}"] = phys_means[idx]
                                properties[f"std_{name}"] = phys_stds[idx]
                
                # Cycle analysis
                cycles = list(nx.cycle_basis(G))
                cycle_count = len(cycles)
                ring_info["ring_counts"]["total"] = cycle_count
                
                # Count rings by size
                ring_sizes = defaultdict(int)
                for cycle in cycles:
                    size = len(cycle)
                    ring_sizes[str(size)] = ring_sizes.get(str(size), 0) + 1
                
                # Ensure we have entries for common ring sizes
                for size in range(3, 11):
                    if str(size) not in ring_sizes:
                        ring_sizes[str(size)] = 0
                
                ring_info["ring_sizes"] = dict(ring_sizes)
                
                # Estimate ring types
                ring_info["ring_counts"]["single"] = 0
                ring_info["ring_counts"]["fused"] = 0
                
                # Identify single vs fused rings by checking for shared nodes
                if cycles:
                    # Build a mapping of nodes to cycles they belong to
                    node_to_cycles = defaultdict(list)
                    for cycle_idx, cycle in enumerate(cycles):
                        for node in cycle:
                            node_to_cycles[node].append(cycle_idx)
                    
                    # Count single rings (no shared nodes with other rings)
                    shared_cycles = set()
                    for node, cycle_list in node_to_cycles.items():
                        if len(cycle_list) > 1:
                            for c in cycle_list:
                                shared_cycles.add(c)
                    
                    ring_info["ring_counts"]["single"] = cycle_count - len(shared_cycles)
                    ring_info["ring_counts"]["fused"] = len(shared_cycles)
                
                # Edge feature analysis if available
                if hasattr(data, 'edge_attr') and data.edge_attr.size(0) > 0:
                    # Analyze bond types (assuming first dimension is bond type)
                    bond_types = {}
                    for i in range(data.edge_attr.size(0)):
                        if data.edge_attr.size(1) > 0:
                            bond_type = int(data.edge_attr[i, 0].item())
                            bond_types[bond_type] = bond_types.get(bond_type, 0) + 1
                    
                    # Divide by 2 since each bond is counted twice in undirected graph
                    for bt in bond_types:
                        bond_types[bt] = bond_types[bt] // 2
                    
                    functional_groups["bond_types"] = bond_types
                    
                    # Count functional group proxies based on patterns in the graph
                    # This is just an estimate since we don't have chemical information
                    conjugated_bonds = 0
                    for i in range(data.edge_attr.size(0)):
                        if data.edge_attr.size(1) > 1 and data.edge_attr[i, 2].item() > 0:  # IsConjugated flag
                            conjugated_bonds += 1
                    
                    functional_groups["conjugated_bonds"] = conjugated_bonds // 2
                    
            except Exception as e:
                # If any error occurs during analysis, use minimal information
                print(f"Error analyzing graph {i}: {e}")
        
        metadata.append({
            "graph_id": mol_id,
            "properties": properties,
            "features": features,
            "functional_groups": functional_groups,
            "ring_info": ring_info
        })
    
    return metadata

def to_networkx(data):
    """Convert PyG data to networkx graph for analysis"""
    import networkx as nx
    
    G = nx.Graph()
    
    # Add nodes
    for i in range(data.num_nodes):
        G.add_node(i)
    
    # Add edges (removing duplicates and self-loops)
    edge_index = data.edge_index.cpu().numpy()
    edges = set()
    for i in range(edge_index.shape[1]):
        u, v = edge_index[0, i], edge_index[1, i]
        if u != v and (u, v) not in edges and (v, u) not in edges:
            G.add_edge(u, v)
            edges.add((u, v))
    
    return G

In [5]:
def save_embedding_file(embeddings, molecule_indices, training_info, model_config, filepath):
    """Save embeddings with training metadata"""
    data = {
        "embeddings": embeddings,
        "molecule_indices": molecule_indices,
        "training_info": training_info,
        "model_config": {k: v for k, v in model_config.__dict__.items() 
                         if not k.startswith('_') and not callable(v)}
    }
    
    with open(filepath, 'wb') as f:
        pickle.dump(data, f)
        
def save_embeddings_with_molecules(embeddings, dataset, filepath):
    """Save embeddings with corresponding molecule information and graph-level properties"""
    # Create a list to store molecule data
    molecule_data = []
    
    # Extract important info from each molecule in the dataset
    for data in dataset:
        # Create a dictionary with basic graph properties
        mol_info = {
            "num_nodes": data.num_nodes,
            "edge_index": data.edge_index.tolist() if hasattr(data, 'edge_index') else None,
            "x_cat": data.x_cat.tolist() if hasattr(data, 'x_cat') else None,
            "x_phys": data.x_phys.tolist() if hasattr(data, 'x_phys') else None,
            "edge_attr": data.edge_attr.tolist() if hasattr(data, 'edge_attr') else None,
            "smiles": data.smiles if hasattr(data, 'smiles') else ""  # This is missing in your code
        }
        
        # Calculate additional graph properties if possible
        try:
            if hasattr(data, 'edge_index') and hasattr(data, 'num_nodes'):
                # Graph density
                num_edges = len(data.edge_index[0]) // 2  # Undirected edges counted once
                max_edges = data.num_nodes * (data.num_nodes - 1) // 2
                density = num_edges / max_edges if max_edges > 0 else 0
                mol_info["graph_density"] = density
                
                # Average degree
                avg_degree = num_edges * 2 / data.num_nodes if data.num_nodes > 0 else 0
                mol_info["avg_degree"] = avg_degree
                
                # Count atom types if available
                if hasattr(data, 'x_cat') and data.x_cat is not None:
                    atom_types = {}
                    for atom in data.x_cat:
                        atom_type = int(atom[0])
                        atom_types[atom_type] = atom_types.get(atom_type, 0) + 1
                    mol_info["atom_type_counts"] = atom_types
                
                # Count bond types if available
                if hasattr(data, 'edge_attr') and data.edge_attr is not None:
                    bond_types = {}
                    for bond in data.edge_attr:
                        bond_type = int(bond[0])
                        bond_types[bond_type] = bond_types.get(bond_type, 0) + 1
                    mol_info["bond_type_counts"] = bond_types
        except:
            # If calculation fails, continue without these properties
            pass
            
        molecule_data.append(mol_info)
    
    # Save both embeddings and molecule data
    with open(filepath, 'wb') as f:
        pickle.dump({
            'embeddings': embeddings,
            'molecule_data': molecule_data,
            'graph_properties': True,  # Flag to indicate enhanced properties are stored
            'smiles_list': [data.smiles for data in dataset if hasattr(data, 'smiles')] # This is missing
        }, f)
    
    print(f"Saved embeddings and molecule data with graph properties to {filepath}")
        

def save_embeddings(embeddings, labels, filepath):
    """Save embeddings and corresponding labels"""
    with open(filepath, 'wb') as f:
        pickle.dump({
            'embeddings': embeddings,
            'labels': labels
        }, f)

def save_encoder(encoder, save_path, info=None):
    """Save encoder model for downstream tasks"""
    save_dict = {
        'encoder_state_dict': encoder.state_dict(),
        'model_info': info or {}
    }
    torch.save(save_dict, save_path)

def load_encoder(model_path, device='cpu'):
    """Load saved encoder model"""
    checkpoint = torch.load(model_path, map_location=device)
    encoder = GraphDiscriminator(
        node_dim=checkpoint['model_info'].get('node_dim'),
        edge_dim=checkpoint['model_info'].get('edge_dim'),
        hidden_dim=checkpoint['model_info'].get('hidden_dim', 128),
        output_dim=checkpoint['model_info'].get('output_dim', 128)
    )
    encoder.load_state_dict(checkpoint['encoder_state_dict'])
    return encoder        
        


In [6]:
def main():
    # Enable anomaly detection during development
    torch.autograd.set_detect_anomaly(True)
    # Set random seeds for reproducibility
    torch.manual_seed(42)
    np.random.seed(42)
    
    # Create timestamp for file naming
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    
    print("Starting data loading...")
    extractor = MolecularFeatureExtractor()
#     smiles_file = "D:\\PhD\\Chapter3\\Unsupervised_GAN_Code\\pubchem-10m-clean_test10k.txt"
    smiles_file = "D:\\PhD\\Chapter3\\Unsupervised_GAN_Code\\pubchem-41-clean.txt"
    
    dataset = []
    failed_smiles = []
    smiles_list = []
    
    with open(smiles_file, 'r') as f:
        for line in f:
            smiles = line.strip()
            smiles_list.append(smiles)
            data = extractor.process_molecule(smiles)
            if data is not None:
                # Store SMILES as an attribute for later use
                data.smiles = smiles
                dataset.append(data)
            else:
                failed_smiles.append(smiles)
    
    print(f"1. Loaded dataset with {len(dataset)} graphs.")
    print(f"2. Failed SMILES count: {len(failed_smiles)}")
    
    if not dataset:
        print("No valid graphs generated.")
        return None
    
    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"3. Using device: {device}")
    
    # Create output directories
    save_dir = './checkpoints_manual'
    embedding_dir = './embeddings'
    os.makedirs(save_dir, exist_ok=True)
    os.makedirs(embedding_dir, exist_ok=True)
    
    # Initialize the augmenter
    print("4. Creating molecular augmenter...")
    augmenter = MolecularAugmenter(smiles_list, extractor)
    
    # Generate augmented dataset
    print("5. Generating augmented molecules...")
    augmented_dataset = augmenter.generate_augmented_dataset(num_augmentations=1)
    print(f"   - Generated {len(augmented_dataset)} molecules (including originals)")
    
    # Setup training parameters
    batch_size = 32
    hidden_dim = 128
    output_dim = 128
    epochs = 50
    lr = 1e-4
    
    # Train the encoder with manual augmentations
    print("6. Training encoder with manual augmentations...")
    model, embeddings = train_encoder_with_manual_augmentations(
        dataset,
        extractor,
        output_dir=embedding_dir,
        batch_size=batch_size,
        epochs=epochs,
        lr=lr,
        hidden_dim=hidden_dim,
        output_dim=output_dim,
        device=device
    )
    
    print("7. Training completed!")
    
    # Save additional metadata
    metadata_dir = os.path.join(embedding_dir, 'metadata')
    os.makedirs(metadata_dir, exist_ok=True)
    
    # Extract and save molecule metadata
    print("8. Extracting molecule metadata...")
    metadata = extract_molecule_metadata(dataset)
    with open(os.path.join(metadata_dir, f'molecule_metadata_manual_{timestamp}.pkl'), 'wb') as f:
        pickle.dump(metadata, f)
    
    # Save the final model
    print("9. Saving final model...")
    torch.save({
        'model_state_dict': model.state_dict(),
        'config': {
            'node_dim': model.node_encoder[0].in_features,
            'edge_dim': model.edge_encoder[0].in_features,
            'hidden_dim': hidden_dim,
            'output_dim': output_dim
        }
    }, os.path.join(save_dir, 'final_manual_encoder.pt'))
    
    # Save embeddings with molecules
    final_embedding_path = f'./embeddings/final_embeddings_molecules_{timestamp}.pkl'

    # Create a dictionary with embeddings and molecule data that includes SMILES
    embedding_data = {
        'embeddings': embeddings,
        'molecule_data': [{
            'smiles': getattr(data, 'smiles', ''),  # Extract SMILES if available
            'num_nodes': data.num_nodes,
            'edge_index': data.edge_index.tolist() if hasattr(data, 'edge_index') else None,
            'x_cat': data.x_cat.tolist() if hasattr(data, 'x_cat') else None,
            'x_phys': data.x_phys.tolist() if hasattr(data, 'x_phys') else None,
            'edge_attr': data.edge_attr.tolist() if hasattr(data, 'edge_attr') else None
        } for data in dataset],
        'smiles_list': [getattr(data, 'smiles', '') for data in dataset]  # Explicit SMILES list
    }

    with open(final_embedding_path, 'wb') as f:
        pickle.dump(embedding_data, f)
    
    print(f"10. Final embeddings saved to {final_embedding_path}")
    print(f"11. Manual encoder saved to {os.path.join(save_dir, 'final_manual_encoder.pt')}")
    
    return model, embeddings, dataset

if __name__ == "__main__":
    model, embeddings, dataset = main()

Starting data loading...
1. Loaded dataset with 41 graphs.
2. Failed SMILES count: 0
3. Using device: cpu
4. Creating molecular augmenter...
5. Generating augmented molecules...


Generating augmentations: 100%|████████████████████████████████████████████████████████| 41/41 [00:05<00:00,  7.52it/s]


   - Generated 82 molecules (including originals)
6. Training encoder with manual augmentations...


Generating augmentations: 100%|████████████████████████████████████████████████████████| 41/41 [00:05<00:00,  7.42it/s]


Generated 82 augmented molecules (including originals)
Training encoder with manual augmentations...


Epoch 1/50: 100%|████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 12.08it/s]


Epoch 1, Avg Loss: 2.5647


Epoch 2/50: 100%|████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 15.09it/s]


Epoch 2, Avg Loss: 2.5957


Epoch 3/50: 100%|████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 14.18it/s]


Epoch 3, Avg Loss: 2.5956


Epoch 4/50: 100%|████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 13.06it/s]


Epoch 4, Avg Loss: 2.6001


Epoch 5/50: 100%|████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 12.44it/s]


Epoch 5, Avg Loss: 2.5484


Epoch 6/50: 100%|████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 12.82it/s]


Epoch 6, Avg Loss: 2.5728


Epoch 7/50: 100%|████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 14.37it/s]


Epoch 7, Avg Loss: 2.5765


Epoch 8/50: 100%|████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 13.47it/s]


Epoch 8, Avg Loss: 2.5940


Epoch 9/50: 100%|████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 14.05it/s]


Epoch 9, Avg Loss: 2.5829


Epoch 10/50: 100%|███████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 12.28it/s]


Epoch 10, Avg Loss: 2.5872


Epoch 11/50: 100%|███████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 16.00it/s]


Epoch 11, Avg Loss: 2.5806


Epoch 12/50: 100%|███████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 13.34it/s]


Epoch 12, Avg Loss: 2.5941


Epoch 13/50: 100%|███████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 14.26it/s]


Epoch 13, Avg Loss: 2.5735


Epoch 14/50: 100%|███████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 14.53it/s]


Epoch 14, Avg Loss: 2.5929


Epoch 15/50: 100%|███████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 12.89it/s]


Epoch 15, Avg Loss: 2.5822


Epoch 16/50: 100%|███████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 12.79it/s]


Epoch 16, Avg Loss: 2.5834


Epoch 17/50: 100%|███████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 14.01it/s]


Epoch 17, Avg Loss: 2.5823


Epoch 18/50: 100%|███████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 14.10it/s]


Epoch 18, Avg Loss: 2.5782


Epoch 19/50: 100%|███████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 14.74it/s]


Epoch 19, Avg Loss: 2.5806


Epoch 20/50: 100%|███████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 13.87it/s]


Epoch 20, Avg Loss: 2.5824


Epoch 21/50: 100%|███████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 13.07it/s]


Epoch 21, Avg Loss: 2.5783


Epoch 22/50: 100%|███████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 13.18it/s]


Epoch 22, Avg Loss: 2.5758


Epoch 23/50: 100%|███████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 14.35it/s]


Epoch 23, Avg Loss: 2.5869


Epoch 24/50: 100%|███████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 13.45it/s]


Epoch 24, Avg Loss: 2.5821


Epoch 25/50: 100%|███████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 14.06it/s]


Epoch 25, Avg Loss: 2.5805


Epoch 26/50: 100%|███████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 14.18it/s]


Epoch 26, Avg Loss: 2.5871


Epoch 27/50: 100%|███████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 11.85it/s]


Epoch 27, Avg Loss: 2.5762


Epoch 28/50: 100%|███████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 14.24it/s]


Epoch 28, Avg Loss: 2.5806


Epoch 29/50: 100%|███████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 14.01it/s]


Epoch 29, Avg Loss: 2.5766


Epoch 30/50: 100%|███████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 13.80it/s]


Epoch 30, Avg Loss: 2.5796


Epoch 31/50: 100%|███████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 13.75it/s]


Epoch 31, Avg Loss: 2.5851


Epoch 32/50: 100%|███████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 13.44it/s]


Epoch 32, Avg Loss: 2.5735


Epoch 33/50: 100%|███████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 11.55it/s]


Epoch 33, Avg Loss: 2.5886


Epoch 34/50: 100%|███████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 13.80it/s]


Epoch 34, Avg Loss: 2.5759


Epoch 35/50: 100%|███████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 13.73it/s]


Epoch 35, Avg Loss: 2.5778


Epoch 36/50: 100%|███████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 14.04it/s]


Epoch 36, Avg Loss: 2.5752


Epoch 37/50: 100%|███████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 14.71it/s]


Epoch 37, Avg Loss: 2.5894


Epoch 38/50: 100%|███████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 13.28it/s]


Epoch 38, Avg Loss: 2.5773


Epoch 39/50: 100%|███████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 12.93it/s]


Epoch 39, Avg Loss: 2.5880


Epoch 40/50: 100%|███████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 13.45it/s]


Epoch 40, Avg Loss: 2.5914


Epoch 41/50: 100%|███████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 14.26it/s]


Epoch 41, Avg Loss: 2.5825


Epoch 42/50: 100%|███████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 14.07it/s]


Epoch 42, Avg Loss: 2.5882


Epoch 43/50: 100%|███████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 13.33it/s]


Epoch 43, Avg Loss: 2.5816


Epoch 44/50: 100%|███████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 12.23it/s]


Epoch 44, Avg Loss: 2.5738


Epoch 45/50: 100%|███████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 13.74it/s]


Epoch 45, Avg Loss: 2.5858


Epoch 46/50: 100%|███████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 14.43it/s]


Epoch 46, Avg Loss: 2.5807


Epoch 47/50: 100%|███████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 15.32it/s]


Epoch 47, Avg Loss: 2.5862


Epoch 48/50: 100%|███████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 13.50it/s]


Epoch 48, Avg Loss: 2.5834


Epoch 49/50: 100%|███████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 12.58it/s]


Epoch 49, Avg Loss: 2.5751


Epoch 50/50: 100%|███████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 13.47it/s]


Epoch 50, Avg Loss: 2.5831


Extracting embeddings: 100%|████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 158.48it/s]


Manual embeddings saved to ./embeddings\manual_embeddings.pkl
7. Training completed!
8. Extracting molecule metadata...


Extracting molecule metadata: 100%|███████████████████████████████████████████████████| 41/41 [00:00<00:00, 211.02it/s]

9. Saving final model...
10. Final embeddings saved to ./embeddings/final_embeddings_molecules_20250310_143038.pkl
11. Manual encoder saved to ./checkpoints_manual\final_manual_encoder.pt



