In [1]:
import pandas as pd
from rdkit import Chem

In [2]:
def generate_variants_smiles_with_labels(row, num_variants=1):
    """Build randomized-SMILES copies of a single row.

    Each returned row is a copy of ``row`` in which only the ``'SMILES'``
    entry is replaced by a non-canonical, randomly ordered SMILES string of
    the same molecule; every other entry (e.g. labels) is carried over as-is.

    Parameters
    ----------
    row : mapping-like with a ``'SMILES'`` entry and a ``copy()`` method
        Typically a ``pandas.Series`` produced by ``DataFrame.iterrows()``.
    num_variants : int, default 1
        Number of randomized variants to generate.

    Returns
    -------
    list
        ``num_variants`` row copies, or an empty list when
        ``num_variants <= 0``, when the SMILES value is not a string
        (e.g. ``None``/NaN), or when RDKit cannot parse it.
    """
    if num_variants <= 0:
        return []

    smiles = row['SMILES']
    # Missing values (None / NaN) are not valid input for MolFromSmiles;
    # skip them the same way unparseable strings are skipped.
    if not isinstance(smiles, str):
        return []

    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return []

    # Kekulize mutates `mol` in place and clears its aromatic flags; doing it
    # once up front is sufficient, so it does not need to run per variant.
    Chem.Kekulize(mol, clearAromaticFlags=True)

    variant_rows = []
    for _ in range(num_variants):
        # canonical=False + doRandom=True yields a randomly ordered SMILES;
        # isomericSmiles=False leaves stereo/isotope information out.
        variant_smiles = Chem.MolToSmiles(
            mol, canonical=False, doRandom=True, isomericSmiles=False
        )
        new_row = row.copy()
        new_row['SMILES'] = variant_smiles
        variant_rows.append(new_row)
    return variant_rows

In [3]:
def smiles_augmentation(df, original_multiplier=1, num_variants=0):
    """Augment a DataFrame by repeating rows and adding SMILES variants.

    For every row of ``df``, in order, the output contains
    ``original_multiplier`` unchanged copies of the row followed by the rows
    returned by ``generate_variants_smiles_with_labels(row, num_variants)``
    (only requested when ``num_variants > 0``).

    Parameters
    ----------
    df : pandas.DataFrame
        Input frame; the variant generator reads its ``'SMILES'`` column.
    original_multiplier : int, default 1
        How many unchanged copies of each original row to keep.
    num_variants : int, default 0
        How many SMILES variants to request per row.

    Returns
    -------
    pandas.DataFrame
        The augmented rows with a fresh 0..n-1 index. When no rows are
        produced, an empty frame with the same columns as ``df`` is returned.
    """
    augmented_rows = []
    for _, row in df.iterrows():
        augmented_rows.extend(row.copy() for _ in range(original_multiplier))
        if num_variants > 0:
            augmented_rows.extend(
                generate_variants_smiles_with_labels(row, num_variants)
            )

    if not augmented_rows:
        # pd.DataFrame([]) would drop the column schema; keep it so callers
        # can still select columns on an empty result.
        return df.iloc[0:0].reset_index(drop=True)

    return pd.DataFrame(augmented_rows).reset_index(drop=True)