In [76]:
import random
from rdkit import Chem
from rdkit.Chem import AllChem, Descriptors
from rdkit.Chem import rdmolops
import joblib
from tqdm.auto import tqdm
import pandas as pd


from src.utils import all_rools_valid, sa_score

In [57]:
estimation_model, feature_names = list(joblib.load('models/model_iter_2').values())

desc_dict = dict(Descriptors.descList)

descriptor_funcs = []
for f_name in feature_names:
    descriptor_funcs.append((f_name, desc_dict[f_name]))

In [58]:
initial_molecules = \
pd.read_csv('data/quantitive_neftekod25_data.csv')['SMILES'].drop_duplicates().tolist() +\
pd.read_csv('data/export.csv')['SMILES'].tolist()

In [59]:
initial_molecules[:10]

['C1=CC=C(C=C1)NC2=CC=CC=C2',
 'CC(C)(C)CC(C)(C)C1=CC=CC=C1NC2=CC=CC3=CC=CC=C32',
 'C1(=CC=CC=C1N(C2=CC=CC=C2CCCCCCCCC)[H])CCCCCCCCC',
 'C1=CC=C(C=C1)NC2=CC=CC3=CC=CC=C32',
 'CC1=C(C(=CC=C1)O)C',
 'CC1=CC(=C(C=C1)C)O',
 'CC1=C(C=C(C=C1)O)C',
 'CC1=CC(=CC(=C1)O)C',
 'CCC1=CC=C(C=C1)O',
 'CC1=CC(=C(C(=C1)C(C)(C)C)O)C(C)(C)C']

In [60]:
class GeneticAlgorithm:
    def __init__(self, mutation_func, crossover_func, fitness_func, population_size, mutation_prob, crossover_prob):

        self.mutation_func = mutation_func
        self.crossover_func = crossover_func
        self.fitness_func = fitness_func
        self.population_size = population_size
        self.mutation_prob = mutation_prob
        self.crossover_prob = crossover_prob
        self.log = []

    def run(self, initial_population, iterations):

        population = initial_population.copy()

        for _ in range(iterations):

            fitness_values = [self.fitness_func(ind) for ind in population]
            
            min_fitness = min(fitness_values)
            max_fitness = max(fitness_values)
            avg_fitness = sum(fitness_values) / len(fitness_values)
            sorted_fitness = sorted(fitness_values)
            median_fitness = sorted_fitness[len(sorted_fitness) // 2] if len(sorted_fitness) % 2 == 1 else (
                sorted_fitness[len(sorted_fitness) // 2 - 1] + sorted_fitness[len(sorted_fitness) // 2])/ 2
            
            self.log.append({
                'min': min_fitness,
                'max': max_fitness,
                'avg': avg_fitness,
                'median': median_fitness
            })

            print(self.log[-1])
            pd.DataFrame({'SMILES': population, 'res_fitness': fitness_values}).to_csv('gen_algo_cache/population_mol_based.csv', index=False)
            
            offspring = []
            for i in tqdm(range(min(len(population), self.population_size))):

                if random.random() < self.crossover_prob:
                    parent1 = self._tournament_selection(population, fitness_values)
                    parent2 = self._tournament_selection(population, fitness_values)
                    child1, child2 = self.crossover_func(parent1, parent2)
    
                    offspring.extend([child1, child2])
                

                if random.random() < self.mutation_prob:
                    mut_child = self.mutation_func(population[i])
                    offspring.append(mut_child)

            combined = population + offspring
            combined = list(set(combined))
            combined.sort(key=lambda x: -self.fitness_func(x))
            population = combined[:self.population_size]
        
        return population, self.log

    def _tournament_selection(self, population, fitness_values, tournament_size=3):
        participants = random.sample(list(zip(population, fitness_values)), tournament_size)
        participants.sort(key=lambda x: -x[1])
        return participants[0][0]

In [61]:
firness_cache = {}

groups = [
    'O',                 # -OH (гидроксил)
    'N',                 # -NH2 (аминогруппа)
    'OC(=O)C',           # Ацетилокси-группа
    'c1ccccc1',          # Бензольное кольцо
    'c1cc(-c2ccccc2)cc(-c2ccccc2)c1',  # Бифенильные структуры
    'c1ccc2c(c1)ccc1c2ccccc1',           # Нафталиновые ядра
    'C(C)(C)C',          # Трет-бутильная группа
    'C#C',               # Алкиновые группы
    'C=O',               # Карбонильная группа
    'NC(=O)',            # Амидная группа
    'S(=O)(=O)',         # Сульфоновая группа
    'CSC',               # Тиоэфир
    'COC(=O)',           # Сложноэфирная группа
    'C1CCCCC1',          # Циклогексильное кольцо
    'C(C)=C',            # Алкеновая группа
    'CN(C)C',            # Диметиламиногруппа
    'c1ccncc1',          # Пиридиновое кольцо
    'Oc1ccccc1',         # Феноксильная группа
    'C(C)(O)N',          # Аминоспиртовая группа
    'C1=CC=C(C=C1)O'     # Замещённое фенольное кольцо
]

def mol_based_mutation(smiles: str, max_attempts=100):

    mol = Chem.MolFromSmiles(smiles)
    
    def _add_group_to_bond(mol):
        if len(mol.GetBonds()) == 0:
            return mol
        bond = random.choice(mol.GetBonds())
        mol = Chem.RWMol(mol)
        mol.RemoveBond(bond.GetBeginAtomIdx(), bond.GetEndAtomIdx())
        
        group = random.choice(groups)
        group_mol = Chem.MolFromSmiles(group)
        mol.InsertMol(group_mol)

        return Chem.MolFromSmiles(Chem.MolToSmiles(mol))
    
    
    def _add_functional_group(mol):
        
        atom_idx = random.choice([a.GetIdx() for a in mol.GetAtoms()])
        
        rw_mol = Chem.RWMol(mol)
        
        new_group = Chem.MolFromSmiles(random.choice(groups))
        new_group_atoms = list(new_group.GetAtoms())
        
        combo = Chem.CombineMols(rw_mol, new_group)
        
        rw_mol.AddBond(
            atom_idx,
            rw_mol.GetNumAtoms() - 1,
            Chem.BondType.SINGLE
        )
        
        new_mol = rw_mol.GetMol()

        return Chem.MolFromSmiles(Chem.MolToSmiles(new_mol))
        
    transform = random.choice((_add_group_to_bond, _add_functional_group))

    for attempt in range(max_attempts):
        try:
            new_mol = transform(mol)
            new_smiles = Chem.MolToSmiles(new_mol)

            new_smiles = new_smiles.replace('.', '')
            
            assert all_rools_valid(new_smiles)
            return new_smiles
            
        except Exception as e:
            continue

    return smiles


def mol_based_crossover(smiles1: str, smiles2: str, max_attempts=100):

    def _get_breakable_bonds(mol):
        breakable = []
        for bond in mol.GetBonds():
            if bond.GetBondType() == Chem.BondType.SINGLE:
                if bond.GetBeginAtom().GetSymbol() in ['C','O','N'] and \
                   bond.GetEndAtom().GetSymbol() in ['C','O','N']:
                    breakable.append(bond.GetIdx())
        return breakable
    
    def _get_fragments(mol, bond_idx):

        broken = rdmolops.FragmentOnBonds(mol, [bond_idx], addDummies=True)
        return Chem.GetMolFrags(broken, asMols=True, sanitizeFrags=False)

    def crossover(mol1, mol2):        
        bond1 = random.choice(_get_breakable_bonds(mol1))
        bond2 = random.choice(_get_breakable_bonds(mol2))
        
        frags1 = rdmolops.FragmentOnBonds(mol1, [bond1])
        frags2 = rdmolops.FragmentOnBonds(mol2, [bond2])

        frags1 = _get_fragments(mol1, bond1)
        frags2 = _get_fragments(mol2, bond2)
        
        frag1 = random.choice(frags1)
        frag2 = random.choice(frags2)
        
        combined = Chem.CombineMols(frag1, frag2)
        editable = Chem.EditableMol(combined)
        
        dummies = [atom.GetIdx() for atom in combined.GetAtoms() if atom.GetAtomicNum() == 0]

        if len(dummies) >= 2:
            editable.AddBond(dummies[-2], dummies[-1], Chem.BondType.SINGLE)
            editable.RemoveAtom(dummies[-1])
            editable.RemoveAtom(dummies[-2])


        tmp_mol = editable.GetMol()

        for atom in reversed(list(tmp_mol.GetAtoms())):
            if atom.GetAtomicNum() == 0:
                editable.RemoveAtom(atom.GetIdx())

        new_mol = editable.GetMol()
        
        return editable.GetMol()

    mol1 = Chem.MolFromSmiles(smiles1)
    mol2 = Chem.MolFromSmiles(smiles2)

    for attempt in range(max_attempts):
        
        try:
            new_mol = crossover(mol1, mol2)
            new_smiles = Chem.MolToSmiles(new_mol)
            
            new_smiles = new_smiles.replace('.', '')
            
            assert all_rools_valid(new_smiles)
            return new_smiles, new_smiles
            
        except Exception as e:
            continue

    return smiles1, smiles2


def calculate_fitness(smiles: str) -> float:

    if smiles in firness_cache:
        return firness_cache[smiles]
    
    mol = Chem.MolFromSmiles(smiles)
    descriptors = {desc_name: desc_f(mol) for desc_name, desc_f in descriptor_funcs}
    X = pd.DataFrame([descriptors])
    prediction = estimation_model.predict(X)[0]

    raw_sa = sa_score(smiles)

    penalty = (raw_sa - 3) ** 4

    res = prediction - penalty 
    
    firness_cache[smiles] = res
    
    return res

In [62]:
from rdkit import rdBase

rdBase.DisableLog('rdApp.error')  
rdBase.DisableLog('rdApp.warning')

In [63]:
mol_based_genetic_algo = GeneticAlgorithm(
    mol_based_mutation,
    mol_based_crossover,
    calculate_fitness,
    population_size=200,
    mutation_prob=0.8,
    crossover_prob=0.8
)

_ = mol_based_genetic_algo.run(initial_molecules, 10)

{'min': -18.627716952186717, 'max': 341.82751146461203, 'avg': 155.0205943069376, 'median': 192.30471646230126}


100%|██████████| 85/85 [00:02<00:00, 33.75it/s]


{'min': 0.32924627539441076, 'max': 372.00136281639135, 'avg': 193.75069924046727, 'median': 233.03262860226732}


100%|██████████| 200/200 [00:05<00:00, 36.35it/s]


{'min': 260.70775889861363, 'max': 417.29783382444884, 'avg': 306.33300754621405, 'median': 303.63735628202573}


100%|██████████| 200/200 [00:07<00:00, 27.96it/s]


{'min': 314.0653048449017, 'max': 417.29783382444884, 'avg': 340.5848931659966, 'median': 335.2604619448733}


100%|██████████| 200/200 [00:10<00:00, 18.42it/s]


{'min': 340.5805255776993, 'max': 437.25360514841145, 'avg': 364.2421495210327, 'median': 358.9107081943385}


100%|██████████| 200/200 [00:12<00:00, 15.59it/s]


{'min': 359.9193382204805, 'max': 519.4383590947284, 'avg': 383.1081980853802, 'median': 376.3898325698281}


100%|██████████| 200/200 [00:17<00:00, 11.75it/s]


{'min': 375.8640372739951, 'max': 519.4383590947284, 'avg': 401.5598042420841, 'median': 394.97232439140436}


100%|██████████| 200/200 [00:17<00:00, 11.16it/s]


{'min': 395.2601525260487, 'max': 527.5001192084488, 'avg': 424.11790319959283, 'median': 414.89780696543926}


100%|██████████| 200/200 [00:21<00:00,  9.52it/s]


{'min': 414.7240262377604, 'max': 531.8626506551088, 'avg': 445.74116086683824, 'median': 435.88415925491455}


100%|██████████| 200/200 [00:22<00:00,  8.88it/s]


{'min': 434.15738007987386, 'max': 535.2543762412361, 'avg': 466.80422056333714, 'median': 457.9500916474683}


100%|██████████| 200/200 [00:21<00:00,  9.40it/s]


(['C=CCC1Cc2cc(N)ccc2ONc2cccc(N)c2-c2ccccc2OCNc2ccc3cc2-c2cc(ccc2N)-c2c(N)cccc2-c2ccc(N)c(c2)Oc2cc(ccc2N)-c2cc(N)cc-3c2N1',
  'CNc1ccc2cc1-c1cc(ccc1N)-c1c(N)cccc1-c1ccc(N)c(c1)Oc1cc(ccc1N)-c1cc(N)cc-2c1NCc1cc(N)cc(C)c1ONc1cccc(N)c1-c1ccccc1N',
  'Cc1c(N)ccc2c1CNOc1cc(c3c(c1)-c1cc(N)cc(c1Nc1ccc4c(c1N)-c1cccc-4c1)-c1c(N)cc(-c4cc(N)c(N)c(-c5ccccc5)c4)cc1-c1ccc(O)c(O)c1-c1ccccc1-3)O2',
  'CNCc1c(N)ccc2c1CNOc1cc(c3c(c1)-c1cc(N)cc(c1Nc1ccc4c(c1)-c1cccc-4c1)-c1c(N)cc(-c4cc(N)c(N)c(-c5ccccc5)c4)cc1-c1ccc(O)c(O)c1-c1ccccc1-3)O2',
  'C#CCNc1cc(-c2ccccc2-c2c(N)cccc2NOc2c(C)cc(N)cc2CN)c2cc1-c1cc(ccc1N)-c1c(N)cccc1-c1ccc(N)c(c1)Oc1cc(ccc1N)-c1cc(N)cc-2c1',
  'CNc1cc2c3cc1-c1cc(ccc1N)-c1c(N)cccc1-c1ccc(N)c(c1)Oc1cc(ccc1N)-c1cc(N)cc-3c1NCc1cc(N)cc(C)c1ONONc1cccc(N)c1-c1ccccc1-2',
  'CCNc1ccc2cc1-c1c(N)ccc(c1-c1ccccc1-c1c(N)cccc1NOc1c(C)cc(N)cc1CN)-c1ccccc1-c1ccc(N)c(c1)Oc1cc(ccc1N)-c1cc(N)cc-2c1Nc1ccccc1',
  'CNc1cc2c3cc1-c1cc(ccc1N)-c1c(N)cccc1-c1ccc(N)c(c1)Oc1cc(ccc1N)-c1cc(N)cc-3c1NCc1cc(N)cc(C)c1

In [64]:
new_mols_df = pd.read_csv('gen_algo_cache/population_mol_based.csv')
new_mols_df = new_mols_df[~new_mols_df.SMILES.isin(initial_molecules)]

In [74]:
new_mols_df.head(10)

Unnamed: 0,SMILES,pred_pdsc
0,CNc1ccc2cc1-c1cc(ccc1N)-c1c(N)cccc1-c1ccc(N)c(...,535.254376
1,Cc1c(N)ccc2c1CNOc1cc(c3c(c1)-c1cc(N)cc(c1Nc1cc...,531.862651
2,CNCc1c(N)ccc2c1CNOc1cc(c3c(c1)-c1cc(N)cc(c1Nc1...,531.35758
3,CNc1cc2c3cc1-c1cc(ccc1N)-c1c(N)cccc1-c1ccc(N)c...,530.285149
4,CCNc1ccc2cc1-c1c(N)ccc(c1-c1ccccc1-c1c(N)cccc1...,529.216314
5,CNc1cc2c3cc1-c1cc(ccc1N)-c1c(N)cccc1-c1ccc(N)c...,527.500119
6,CNc1ccc2cc1-c1c(N)ccc(c1-c1ccccc1-c1ccccc1N)-c...,524.66564
7,CNc1cc(-c2ccccc2-c2c(N)cccc2NOc2c(C)cc(N)cc2CN...,521.775853
8,CNc1ccc2cc1-c1c(N)ccc(c1-c1ccccc1-c1c(N)cccc1N...,520.850048
9,CNc1ccc2cc1-c1cc(ccc1N)-c1c(N)cccc1-c1ccc(N)c(...,519.438359


In [67]:
new_mols_df = new_mols_df[(~new_mols_df['SMILES'].apply(lambda el: '.' in el)).tolist()]
new_mols_df = new_mols_df[~new_mols_df['SMILES'].isin(initial_molecules)]

In [78]:
new_mols_df['SMILES'].head(15).to_csv('submit.csv', encoding='UTF-8', index=False)

In [79]:
import os
import zipfile
with zipfile.ZipFile('submission.zip', 'w', zipfile.ZIP_DEFLATED) as zipf:
    zipf.write('submit.csv', os.path.basename('submit.csv'))
    # zipf.write('4_ml_PDSC_prediction.ipynb', os.path.basename('4_ml_PDSC_prediction.ipynb'))
    zipf.write('6_genetic_algo_graph_based.ipynb', os.path.basename('6_genetic_algo_graph_based.ipynb'))
    zipf.write('src/utils.py', os.path.basename('src/utils.py'))

In [80]:
sa_score('CNc1ccc2cc1-c1cc(ccc1N)-c1c(N)cccc1-c1ccc(N)c(c1)Oc1cc(ccc1N)-c1cc(N)cc-2c1NCc1cc(N)cc(C)c1ONc1cccc(N)c1-c1ccccc1N')

6.1090883795804976