In [1]:
import rdkit
from rdkit import Chem
from rdkit.Chem import AllChem, Draw
import random
import numpy as np
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from rdkit.Chem import DataStructs
from rdkit.Chem.AllChem import GetMorganFingerprintAsBitVect
import warnings

# Suppress RDKit warnings
warnings.filterwarnings('ignore', category=DeprecationWarning)

class MolecularAugmenter:
    @staticmethod
    def manual_augmentations(mol, num_augmentations=5):
        """
        Generate manual augmentations for a molecule
        
        Augmentation types:
        1. Add/Remove Hydrogens
        2. Change Bond Order
        3. Enumerate Stereoisomers
        4. Mutate Functional Groups
        5. Change Atom
        6. Add Ring Substituent
        """
        augmented_mols = []
        
        # 1. Add/Remove Hydrogens
        mol_with_h = Chem.AddHs(mol)
        mol_without_h = Chem.RemoveHs(mol)
        augmented_mols.extend([mol_with_h, mol_without_h])
        
        # 2. Change Bond Order
        try:
            for bond in mol.GetBonds():
                new_mol = Chem.Mol(mol)
                new_mol_edit = Chem.EditableMol(new_mol)
                
                # Try changing single to double, double to triple, etc.
                if bond.GetBondType() == Chem.BondType.SINGLE:
                    new_mol_edit.ReplaceBond(bond.GetIdx(), Chem.BondType.DOUBLE)
                elif bond.GetBondType() == Chem.BondType.DOUBLE:
                    new_mol_edit.ReplaceBond(bond.GetIdx(), Chem.BondType.TRIPLE)
                
                try:
                    new_mol = new_mol_edit.GetMol()
                    Chem.SanitizeMol(new_mol)
                    augmented_mols.append(new_mol)
                except:
                    pass
        except:
            pass
        
        # 3. Enumerate Stereoisomers
        try:
            from rdkit.Chem.EnumerateStereoisomers import EnumerateStereoisomers
            stereoisomers = list(EnumerateStereoisomers(mol))
            augmented_mols.extend(stereoisomers)
        except:
            pass
        
        # 4. Mutate Functional Groups
        functional_group_transforms = [
            ('[OH]', '[OC]'),        # Alcohol to ether
            ('C(=O)[OH]', 'C(=O)OC'), # Carboxylic acid to ester
            ('[NH2]', '[NH]C(=O)C')   # Amine to amide
        ]
        
        for old_pattern, new_pattern in functional_group_transforms:
            try:
                old_pattern_mol = Chem.MolFromSmarts(old_pattern)
                if mol.HasSubstructMatch(old_pattern_mol):
                    smiles = Chem.MolToSmiles(mol)
                    new_smiles = smiles.replace(old_pattern, new_pattern, 1)
                    new_mol = Chem.MolFromSmiles(new_smiles)
                    if new_mol:
                        augmented_mols.append(new_mol)
            except:
                pass
        
        # 5. Change Atom 
        atom_replacements = {
            'C': ['Si'],  # Carbon to Silicon
            'N': ['P'],   # Nitrogen to Phosphorus
            'O': ['S']    # Oxygen to Sulfur
        }
        
        for atom in mol.GetAtoms():
            symbol = atom.GetSymbol()
            if symbol in atom_replacements:
                try:
                    new_mol = Chem.Mol(mol)
                    new_mol_edit = Chem.EditableMol(new_mol)
                    
                    # Replace atom with another from the same group
                    new_symbol = random.choice(atom_replacements[symbol])
                    new_atomic_num = Chem.GetPeriodicTable().GetAtomicNumber(new_symbol)
                    
                    new_mol_edit.ReplaceAtom(atom.GetIdx(), Chem.Atom(new_atomic_num))
                    
                    new_mol = new_mol_edit.GetMol()
                    Chem.SanitizeMol(new_mol)
                    augmented_mols.append(new_mol)
                except:
                    pass
        
        # 6. Add Ring Substituent
        substituents = [
            '[C]',    # Methyl
            '[O]',    # Hydroxyl
            '[F]',    # Fluorine
            '[N]'     # Amino group
        ]
        
        try:
            for substituent_smarts in substituents:
                substituent = Chem.MolFromSmarts(substituent_smarts)
                if mol.HasSubstructMatch(Chem.MolFromSmarts('c1ccccc1')):
                    # Add substituent to benzene ring
                    new_smiles = Chem.MolToSmiles(mol).replace('c1ccccc1', f'c1ccccc1{substituent_smarts}')
                    new_mol = Chem.MolFromSmiles(new_smiles)
                    if new_mol:
                        augmented_mols.append(new_mol)
        except:
            pass
        
        # Remove duplicates and sanitize
        unique_augmented_mols = []
        for aug_mol in augmented_mols:
            try:
                Chem.SanitizeMol(aug_mol)
                smiles = Chem.MolToSmiles(aug_mol)
                if not any(Chem.MolToSmiles(umol) == smiles for umol in unique_augmented_mols):
                    unique_augmented_mols.append(aug_mol)
            except:
                pass
        
        # Limit to specified number of augmentations
        return unique_augmented_mols[:num_augmentations]

    @staticmethod
    def gan_augmentations(mol, num_augmentations=5):
        """
        Simulate GAN-based augmentations
        
        Note: This is a placeholder and should be replaced with actual 
        GAN augmentation logic from your trained model
        """
        augmented_mols = []
        
        # Create base molecule with some modifications
        for _ in range(num_augmentations):
            # Simple perturbation - this is just a placeholder
            try:
                # Create a copy of the molecule
                new_mol = Chem.Mol(mol)
                
                # Randomly modify a few atoms or bonds
                edit_mol = Chem.EditableMol(new_mol)
                
                # Randomly change an atom type (for demonstration)
                atom_idx = random.randint(0, mol.GetNumAtoms() - 1)
                atom = new_mol.GetAtomWithIdx(atom_idx)
                
                # Replace with a different atom from the same group
                replacements = {
                    'C': 'Si',
                    'N': 'P',
                    'O': 'S'
                }
                
                if atom.GetSymbol() in replacements:
                    new_symbol = replacements[atom.GetSymbol()]
                    new_atomic_num = Chem.GetPeriodicTable().GetAtomicNumber(new_symbol)
                    edit_mol.ReplaceAtom(atom_idx, Chem.Atom(new_atomic_num))
                
                # Create new molecule
                new_mol = edit_mol.GetMol()
                Chem.SanitizeMol(new_mol)
                
                augmented_mols.append(new_mol)
            except:
                pass
        
        return augmented_mols

    @staticmethod
    def compute_fingerprint_similarities(query_mol, mol_list):
        """Compute Morgan Fingerprint similarities"""
        query_fp = GetMorganFingerprintAsBitVect(query_mol, 2, nBits=2048)
        mol_fps = [GetMorganFingerprintAsBitVect(mol, 2, nBits=2048) for mol in mol_list]
        similarities = [DataStructs.TanimotoSimilarity(query_fp, mol_fp) for mol_fp in mol_fps]
        return similarities

def visualize_augmentations(query_smiles, augmentation_type='manual', num_augmentations=5):
    """
    Visualize augmentations for a given SMILES string
    
    Args:
    - query_smiles (str): SMILES string of the molecule to augment
    - augmentation_type (str): 'manual' or 'gan'
    - num_augmentations (int): Number of augmentations to generate
    """
    # Create query molecule
    query_mol = Chem.MolFromSmiles(query_smiles)
    
    # Generate augmentations
    if augmentation_type == 'manual':
        augmented_mols = MolecularAugmenter.manual_augmentations(query_mol, num_augmentations)
    elif augmentation_type == 'gan':
        augmented_mols = MolecularAugmenter.gan_augmentations(query_mol, num_augmentations)
    else:
        print(f"Augmentation type {augmentation_type} not implemented")
        return
    
    # Compute similarities
    similarities = MolecularAugmenter.compute_fingerprint_similarities(query_mol, augmented_mols)
    
    # Prepare molecules for grid visualization
    mols_to_draw = [query_mol] + augmented_mols
    
    # Generate legends with similarities
    legends = ['Original'] + [f'Aug {i+1}\nSim: {sim:.3f}' for i, sim in enumerate(similarities)]
    
    # Create grid image
    grid_img = Draw.MolsToGridImage(
        mols_to_draw, 
        molsPerRow=3, 
        subImgSize=(300, 300),
        legends=legends
    )
    
    # Save the image using RDKit's drawing method
#     Draw.MolToFile(grid_img, f'{augmentation_type}_molecule_augmentations.png', size=(900, 300))
    
    # Optionally, plot similarity distribution
    plt.figure(figsize=(8, 4))
    plt.subplot(1, 2, 1)
    plt.title(f'{augmentation_type.capitalize()} Augmentation\nFP Similarities')
    sns.histplot(similarities, kde=True, color='skyblue')
    plt.xlabel('Fingerprint Similarity')
    plt.ylabel('Frequency')
    plt.tight_layout()
    plt.savefig(f'{augmentation_type}_similarity_distribution.png')
    plt.close()
    
    # Print out augmented SMILES
    print(f"{augmentation_type.capitalize()} Augmentations:")
    for i, (mol, sim) in enumerate(zip(augmented_mols, similarities), 1):
        print(f"Aug {i} (Similarity: {sim:.3f}): {Chem.MolToSmiles(mol)}")
    
    return mols_to_draw, similarities

def main():
    # Example molecule (Acetaminophen)
    query_smiles = 'CC(=O)Nc1ccc(N)cc1'
    
    # Visualize manual augmentations
    visualize_augmentations(query_smiles, 'manual')
    
    # Visualize GAN-like augmentations
    visualize_augmentations(query_smiles, 'gan')

if __name__ == '__main__':
    main()



Manual Augmentations:
Aug 1 (Similarity: 0.289): [H]c1c([H])c(N([H])C(=O)C([H])([H])[H])c([H])c([H])c1N([H])[H]
Aug 2 (Similarity: 1.000): CC(=O)Nc1ccc(N)cc1
Aug 3 (Similarity: 0.615): Nc1ccc(NC(=O)[SiH3])cc1
Aug 4 (Similarity: 0.448): C[Si](=O)Nc1ccc(N)cc1
Aug 5 (Similarity: 0.615): CC(=S)Nc1ccc(N)cc1
Gan Augmentations:
Aug 1 (Similarity: 0.448): CC(=O)Pc1ccc(N)cc1
Aug 2 (Similarity: 0.516): CC(=O)Nc1ccc(N)c[siH]1
Aug 3 (Similarity: 0.516): CC(=O)Nc1ccc(N)c[siH]1
Aug 4 (Similarity: 0.615): CC(=O)Nc1ccc(P)cc1
Aug 5 (Similarity: 0.615): CC(=O)Nc1ccc(P)cc1
