In [1]:
from rdkit import Chem
from rdkit.Chem import Draw
from rdkit.Chem import AllChem
import random

def visualize_molecules(original_smiles, augmented_smiles_list):
    """
    Render the original molecule and its augmentations side by side.

    Every SMILES string is parsed with ``Chem.MolFromSmiles`` and the
    resulting molecules are laid out in a single-row grid: the original
    first, followed by the augmentations in the order given.

    Args:
    - original_smiles (str): SMILES of the reference molecule
    - augmented_smiles_list (list): SMILES strings of the augmented molecules

    Returns:
    - The grid image produced by ``Draw.MolsToGridImage``
    """
    # Parse the reference structure first, then each variant in input order
    molecules = [Chem.MolFromSmiles(original_smiles)]
    molecules.extend(Chem.MolFromSmiles(smi) for smi in augmented_smiles_list)

    # One caption per panel: 'Original', then 'Aug 1', 'Aug 2', ...
    captions = ['Original']
    for position in range(1, len(molecules)):
        captions.append(f'Aug {position}')

    # A single row wide enough to hold every panel, 300x300 pixels each
    return Draw.MolsToGridImage(
        molecules,
        molsPerRow=len(molecules),
        subImgSize=(300, 300),
        legends=captions,
    )

def generate_augmentations(original_smiles, num_augmentations=5, seed=None):
    """
    Generate simple single-atom-substitution augmentations for a molecule

    Each attempt picks one random atom that has a replacement rule (C, N or
    O), swaps it for a randomly chosen substitute element, sanitizes the
    result and records its SMILES. Substitutions that RDKit rejects (for
    example because of an impossible valence) are reported and skipped, so
    the returned list may be shorter than ``num_augmentations`` and may
    contain duplicates.

    Args:
    - original_smiles (str): Original molecule SMILES
    - num_augmentations (int): Number of substitution attempts to make
    - seed (int | None): Optional seed for a private random generator, for
      reproducible output. When None the module-level ``random`` state is
      used, exactly as before.

    Returns:
    - List of augmented SMILES (empty if the input cannot be parsed or has
      no atom with a replacement rule)
    """
    # Atom replacement strategies
    atom_replacements = {
        'C': ['Si', 'N'],
        'N': ['P', 'C'],
        'O': ['S', 'Se']
    }

    # Molecule to modify
    mol = Chem.MolFromSmiles(original_smiles)
    if mol is None:
        # MolFromSmiles signals a parse failure by returning None; report it
        # once instead of failing on every loop iteration.
        print(f"Augmentation error: could not parse SMILES {original_smiles!r}")
        return []

    # Only atoms with a replacement rule are worth sampling; drawing from all
    # atoms would silently waste attempts on atoms that cannot be replaced.
    candidate_indices = [
        atom.GetIdx()
        for atom in mol.GetAtoms()
        if atom.GetSymbol() in atom_replacements
    ]
    if not candidate_indices:
        return []

    # A dedicated generator keeps seeded runs from disturbing global state
    rng = random if seed is None else random.Random(seed)
    periodic_table = Chem.GetPeriodicTable()
    augmented_smiles = []

    for _ in range(num_augmentations):
        # Randomly select a replaceable atom and its substitute element
        atom_idx = rng.choice(candidate_indices)
        old_symbol = mol.GetAtomWithIdx(atom_idx).GetSymbol()
        new_symbol = rng.choice(atom_replacements[old_symbol])

        try:
            # Work on a fresh editable copy so `mol` itself is never altered
            edit_mol = Chem.EditableMol(mol)
            new_atomic_num = periodic_table.GetAtomicNumber(new_symbol)
            edit_mol.ReplaceAtom(atom_idx, Chem.Atom(new_atomic_num))

            # Sanitization rejects chemically invalid substitutions
            new_mol = edit_mol.GetMol()
            Chem.SanitizeMol(new_mol)

            # Convert to SMILES and add to list
            augmented_smiles.append(Chem.MolToSmiles(new_mol))
        except Exception as e:
            # Best effort: a failed substitution is reported, not fatal
            print(f"Augmentation error: {e}")

    return augmented_smiles

def main():
    """
    Run the augmentation demo on one example molecule.

    Generates augmented SMILES for the example, builds the comparison grid
    image, and prints the original and augmented SMILES.

    Returns:
    - The grid image from ``visualize_molecules``, so the caller can display
      or save it (it was previously computed and then discarded)
    """
    # Example molecule: 4'-aminoacetanilide, i.e. N-(4-aminophenyl)acetamide.
    # Note this is not acetaminophen, which carries a para-hydroxyl group
    # instead of the amine (CC(=O)Nc1ccc(O)cc1).
    original_smiles = 'CC(=O)Nc1ccc(N)cc1'

    # Generate augmentations
    augmented_smiles = generate_augmentations(original_smiles)

    # Visualize molecules
    img = visualize_molecules(original_smiles, augmented_smiles)

    # Print augmented SMILES
    print("Original SMILES:", original_smiles)
    print("\nAugmented SMILES:")
    for i, smiles in enumerate(augmented_smiles, 1):
        print(f"Aug {i}: {smiles}")

    return img

# Entry point: run the demo when this code executes as the top-level module.
if __name__ == '__main__':
    main()

Augmentation error: Explicit valence for atom # 1 N, 4, is greater than permitted
Original SMILES: CC(=O)Nc1ccc(N)cc1

Augmented SMILES:
Aug 1: CC(=O)Nc1ccc(N)[siH]c1
Aug 2: CC(=O)Nc1ccc(N)c[siH]1
Aug 3: NC(=O)Nc1ccc(N)cc1
Aug 4: CC(=O)Nc1ccc(P)cc1


[17:36:57] Explicit valence for atom # 1 N, 4, is greater than permitted
