In [1]:
import rdkit
from rdkit import Chem
from rdkit.Chem import AllChem, Draw
import random
import numpy as np
import matplotlib.pyplot as plt
from rdkit.Chem import DataStructs
from rdkit.Chem.AllChem import GetMorganFingerprintAsBitVect
import warnings
from PIL import Image
import io

# Silence Python-level DeprecationWarning messages so they do not clutter the
# notebook output; every other warning category is still reported.
warnings.filterwarnings('ignore', category=DeprecationWarning)

class MolecularAugmenter:
    """Generate chemically valid variants ("augmentations") of an RDKit molecule.

    All methods are static; the class is only a namespace. Every candidate is
    passed through ``Chem.SanitizeMol`` and silently dropped if RDKit rejects
    it (RDKit itself still logs the reason to stderr).
    """

    # Element swaps tried during atom substitution, keyed by the original
    # element symbol. Shared by the manual and the GAN-like strategies.
    _ATOM_REPLACEMENTS = {
        'C': ['Si', 'N'],      # Carbon to Silicon or Nitrogen
        'N': ['P', 'C'],       # Nitrogen to Phosphorus or Carbon
        'O': ['S', 'Se'],      # Oxygen to Sulfur or Selenium
    }

    # Reaction SMARTS for the functional-group stage. Atom maps keep the
    # matched atoms and graft the new substituent onto them.
    _FUNCTIONAL_GROUP_REACTIONS = (
        '[OH:1]>>[O:1]C',                            # Hydroxyl to methyl ether
        '[C:1](=[O:2])[OH:3]>>[C:1](=[O:2])[O:3]C',  # Carboxylic acid to methyl ester
        '[NH2:1]>>[N:1]C(=O)C',                      # Primary amine to acetamide
    )

    # Exception types that mean "this edit produced an invalid molecule".
    # RDKit sanitization failures derive from ValueError; RuntimeError covers
    # invariant violations raised from the C++ layer.
    _EDIT_ERRORS = (ValueError, RuntimeError)

    @staticmethod
    def _sanitized_copy(editable):
        """Return a sanitized ``Mol`` built from *editable*, or None if invalid."""
        try:
            new_mol = editable.GetMol()
            Chem.SanitizeMol(new_mol)
        except MolecularAugmenter._EDIT_ERRORS:
            return None
        return new_mol

    @staticmethod
    def _atom_substitutions(mol):
        """Return one variant per substitutable atom, each with that atom swapped."""
        periodic_table = Chem.GetPeriodicTable()
        variants = []
        for atom in mol.GetAtoms():
            choices = MolecularAugmenter._ATOM_REPLACEMENTS.get(atom.GetSymbol())
            if not choices:
                continue
            new_atomic_num = periodic_table.GetAtomicNumber(random.choice(choices))
            candidate = Chem.RWMol(mol)
            candidate.ReplaceAtom(atom.GetIdx(), Chem.Atom(new_atomic_num))
            variant = MolecularAugmenter._sanitized_copy(candidate)
            if variant is not None:
                variants.append(variant)
        return variants

    @staticmethod
    def _bond_promotions(mol):
        """Return one variant per single/double bond with its order raised by one."""
        variants = []
        for bond in mol.GetBonds():
            bond_type = bond.GetBondType()
            if bond_type == Chem.BondType.SINGLE:
                new_type = Chem.BondType.DOUBLE
            elif bond_type == Chem.BondType.DOUBLE:
                new_type = Chem.BondType.TRIPLE
            else:
                # Aromatic/triple/other bonds are left alone; emitting an
                # unmodified copy would not be an augmentation.
                continue
            candidate = Chem.RWMol(mol)
            # The bond order is changed on the bond object itself because
            # EditableMol.ReplaceBond expects a Bond, not a BondType.
            candidate.GetBondWithIdx(bond.GetIdx()).SetBondType(new_type)
            variant = MolecularAugmenter._sanitized_copy(candidate)
            if variant is not None:
                variants.append(variant)
        return variants

    @staticmethod
    def _functional_group_transforms(mol):
        """Return the sanitized products of each functional-group reaction on *mol*."""
        variants = []
        for reaction_smarts in MolecularAugmenter._FUNCTIONAL_GROUP_REACTIONS:
            try:
                reaction = AllChem.ReactionFromSmarts(reaction_smarts)
                product_sets = reaction.RunReactants((mol,))
            except MolecularAugmenter._EDIT_ERRORS:
                continue
            # One product set per place the pattern matched in the molecule.
            for product_set in product_sets:
                for product in product_set:
                    try:
                        # Reaction products come back unsanitized.
                        Chem.SanitizeMol(product)
                    except MolecularAugmenter._EDIT_ERRORS:
                        continue
                    variants.append(product)
        return variants

    @staticmethod
    def manual_augmentations(mol, num_augmentations=5):
        """
        Generate rule-based augmentations for a molecule

        Augmentation techniques, applied in this order:
        1. Atom Substitution (random swap from a fixed per-element table)
        2. Bond Modification (single -> double, double -> triple)
        3. Functional Group Transformation (ether / ester / amide formation)

        Args:
        - mol: RDKit molecule to augment; None yields an empty list
        - num_augmentations (int): maximum number of molecules to return

        Returns:
        - List of at most num_augmentations sanitized molecules, unique by
          canonical SMILES and all different from the input molecule
        """
        if mol is None or num_augmentations <= 0:
            return []

        candidates = []
        candidates.extend(MolecularAugmenter._atom_substitutions(mol))
        candidates.extend(MolecularAugmenter._bond_promotions(mol))
        candidates.extend(MolecularAugmenter._functional_group_transforms(mol))

        # Seeding the seen-set with the input drops any candidate that turned
        # out identical to it; the set makes de-duplication a single pass.
        seen_smiles = {Chem.MolToSmiles(mol)}
        unique_mols = []
        for candidate in candidates:
            smiles = Chem.MolToSmiles(candidate)
            if smiles in seen_smiles:
                continue
            seen_smiles.add(smiles)
            unique_mols.append(candidate)
            if len(unique_mols) == num_augmentations:
                break
        return unique_mols

    @staticmethod
    def gan_augmentations(mol, num_augmentations=5):
        """
        Simulate GAN-based augmentations

        Note: This is a placeholder mimicking potential GAN augmentations.
        Each attempt picks one random atom (swapped if its element is in the
        replacement table) and one random bond (promoted to double if it is
        single). Attempts that change nothing or fail sanitization are dropped.

        Args:
        - mol: RDKit molecule to augment; None yields an empty list
        - num_augmentations (int): number of random attempts to make

        Returns:
        - List of sanitized molecules, possibly shorter than
          num_augmentations and possibly containing duplicates
        """
        if mol is None or num_augmentations <= 0:
            return []
        num_atoms = mol.GetNumAtoms()
        if num_atoms == 0:
            return []
        num_bonds = mol.GetNumBonds()
        periodic_table = Chem.GetPeriodicTable()

        augmented_mols = []
        for _ in range(num_augmentations):
            candidate = Chem.RWMol(mol)
            modified = False

            # Randomly modify an atom
            atom_idx = random.randint(0, num_atoms - 1)
            choices = MolecularAugmenter._ATOM_REPLACEMENTS.get(
                mol.GetAtomWithIdx(atom_idx).GetSymbol()
            )
            if choices:
                new_atomic_num = periodic_table.GetAtomicNumber(random.choice(choices))
                candidate.ReplaceAtom(atom_idx, Chem.Atom(new_atomic_num))
                modified = True

            # Randomly modify a bond
            if num_bonds > 0:
                bond_idx = random.randint(0, num_bonds - 1)
                if mol.GetBondWithIdx(bond_idx).GetBondType() == Chem.BondType.SINGLE:
                    candidate.GetBondWithIdx(bond_idx).SetBondType(Chem.BondType.DOUBLE)
                    modified = True

            if not modified:
                # Neither edit applied: skip rather than return the input itself.
                continue

            variant = MolecularAugmenter._sanitized_copy(candidate)
            if variant is not None:
                augmented_mols.append(variant)

        return augmented_mols

    @staticmethod
    def compute_similarities(query_mol, augmented_mols):
        """
        Compute Morgan Fingerprint similarities

        Args:
        - query_mol: reference RDKit molecule
        - augmented_mols: molecules to compare against the reference

        Returns:
        - List of Tanimoto similarities (radius-2, 2048-bit Morgan
          fingerprints), one per entry of augmented_mols, in the same order
        """
        if not augmented_mols:
            return []
        query_fp = GetMorganFingerprintAsBitVect(query_mol, 2, nBits=2048)
        return [
            DataStructs.TanimotoSimilarity(
                query_fp, GetMorganFingerprintAsBitVect(mol, 2, nBits=2048)
            )
            for mol in augmented_mols
        ]

def _build_augmentation_grid(mols, similarities):
    """Draw *mols* in a single-row grid labelled with their similarity, or None if empty."""
    if not mols:
        # MolsToGridImage cannot lay out a grid with zero molecules per row.
        return None
    return Draw.MolsToGridImage(
        mols,
        molsPerRow=len(mols),
        subImgSize=(300, 300),
        legends=[f'Aug {i+1}\nSim: {sim:.3f}' for i, sim in enumerate(similarities)]
    )


def _print_augmentation_report(title, mols, similarities):
    """Print *title* followed by one similarity/SMILES line per molecule."""
    print(title)
    for i, (mol, sim) in enumerate(zip(mols, similarities), 1):
        print(f"Aug {i} (Similarity: {sim:.3f}): {Chem.MolToSmiles(mol)}")


def visualize_augmentations(query_smiles):
    """
    Visualize augmentations for manual and GAN approaches

    Generates both sets of augmentations, scores each against the query with
    Morgan-fingerprint Tanimoto similarity, prints a text report, and builds
    one grid image per approach.

    Args:
    - query_smiles (str): SMILES string of the molecule to augment

    Returns:
    - Tuple (manual_grid, gan_grid) of grid images; an entry is None when
      that approach produced no augmentations

    Raises:
    - ValueError: if query_smiles cannot be parsed into a molecule
    """
    # Create query molecule; MolFromSmiles signals a parse failure with None
    query_mol = Chem.MolFromSmiles(query_smiles)
    if query_mol is None:
        raise ValueError(f"Could not parse SMILES string: {query_smiles!r}")

    # Generate manual augmentations
    manual_augmented_mols = MolecularAugmenter.manual_augmentations(query_mol)
    manual_similarities = MolecularAugmenter.compute_similarities(query_mol, manual_augmented_mols)

    # Generate GAN-like augmentations
    gan_augmented_mols = MolecularAugmenter.gan_augmentations(query_mol)
    gan_similarities = MolecularAugmenter.compute_similarities(query_mol, gan_augmented_mols)

    # Create one grid image per approach
    manual_grid = _build_augmentation_grid(manual_augmented_mols, manual_similarities)
    gan_grid = _build_augmentation_grid(gan_augmented_mols, gan_similarities)

    # Print augmentation details
    _print_augmentation_report("Manual Augmentations:", manual_augmented_mols, manual_similarities)
    _print_augmentation_report("\nGAN-like Augmentations:", gan_augmented_mols, gan_similarities)

    # Hand the grids back so the caller can display or save them
    # (e.g. manual_grid.save('manual_augmentations.png')).
    return manual_grid, gan_grid

def main():
    """Run the augmentation demo on a fixed example molecule."""
    # Example molecule: 4'-aminoacetanilide, the amino analogue of
    # acetaminophen (acetaminophen itself would be CC(=O)Nc1ccc(O)cc1).
    visualize_augmentations('CC(=O)Nc1ccc(N)cc1')


# Only run the demo when executed as a script / notebook cell, not on import.
if __name__ == '__main__':
    main()

Manual Augmentations:
Aug 1 (Similarity: 0.640): NC(=O)Nc1ccc(N)cc1
Aug 2 (Similarity: 0.615): CC(=[Se])Nc1ccc(N)cc1
Aug 3 (Similarity: 0.448): CC(=O)Pc1ccc(N)cc1
Aug 4 (Similarity: 0.467): CC(=O)N[si]1ccc(N)cc1
Aug 5 (Similarity: 0.516): CC(=O)Nc1ccc(N)c[siH]1

GAN-like Augmentations:
Aug 1 (Similarity: 0.467): CC(=O)N[si]1ccc(N)cc1
Aug 2 (Similarity: 0.615): CC(=[Se])Nc1ccc(N)cc1


[17:32:48] Explicit valence for atom # 1 N, 4, is greater than permitted
[17:32:48] Can't kekulize mol.  Unkekulized atoms: 4 5 6 9 10
