In [1]:
# Install the third-party packages this notebook imports (mrl, catboost, rdkit, pandas).
!pip install mrl-pypi
!pip install catboost
# The install log shows mrl-pypi pulling in the `rdkit-pypi` distribution.
# NOTE(review): presumably the uninstall/reinstall below swaps in the `rdkit` wheel so that
# `rdkit.Contrib.SA_Score` is importable — confirm; neither rdkit nor catboost is version-pinned.
!pip uninstall -y rdkit
!pip install rdkit
# pandas is pinned to an exact version.
!pip install pandas==1.4.1

Collecting mrl-pypi
  Downloading mrl_pypi-0.1.5-py3-none-any.whl (109 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m109.9/109.9 kB[0m [31m1.1 MB/s[0m eta [36m0:00:00[0m
Collecting selfies>=2.0.0 (from mrl-pypi)
  Downloading selfies-2.1.1-py3-none-any.whl (35 kB)
Collecting rdkit-pypi (from mrl-pypi)
  Downloading rdkit_pypi-2022.9.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (29.4 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m29.4/29.4 MB[0m [31m15.1 MB/s[0m eta [36m0:00:00[0m
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch>=1.8.0->mrl-pypi)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch>=1.8.0->mrl-pypi)
  Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch>=1.8.0->mrl-pypi)
  Using cached nvidia_cuda_cupti_cu12-12.1.10

In [2]:
from rdkit.DataStructs import TanimotoSimilarity
from rdkit.Contrib.SA_Score import sascorer
from rdkit.Chem.rdFingerprintGenerator import GetRDKitFPGenerator
from rdkit.Chem import Descriptors

from mrl.imports import *
from mrl.core import *
from mrl.chem import *
from mrl.templates.all import *
from mrl.combichem import *
from mrl.templates.all import *
from mrl.torch_imports import *
from mrl.torch_core import *
from mrl.layers import *
from mrl.dataloaders import *
from mrl.g_models.all import *
from mrl.vocab import *
from mrl.policy_gradient import *
from mrl.train.all import *
from mrl.model_zoo import *
from mrl.combichem import *


from collections import defaultdict
import pandas as pd
import numpy as np
from catboost import CatBoostRegressor

## Base filters

In [4]:
class NumAtomFilter():
    """Pass/fail filter on elemental composition and heteroatom count.

    A molecule passes ``check`` when all of the following hold:

    * every atom is carbon, hydrogen, or one of the element symbols in ``atoms``;
    * at least two distinct elements are present;
    * the combined number of atoms whose symbol is in ``atoms`` does not exceed ``n``.

    Parameters
    ----------
    atoms : iterable of str
        Element symbols of the heteroatoms that are allowed (and counted).
    n : int
        Maximum combined number of those heteroatoms.
    """

    def __init__(self, atoms=['O', 'N', 'P', 'S'], n=12):
        # Copy so neither the shared default list nor a caller's list is aliased.
        self.atoms = list(atoms)
        self.n = n
        self.name = "NumAtomFilter"

    def __call__(self, mols, with_score=False):
        """Run ``check`` over the converted input; ``with_score`` is accepted but not used."""
        return maybe_parallel(self.check, to_mols(mols))

    def check(self, mol):
        """Return True if ``mol`` satisfies the composition rules, else False."""
        counts = defaultdict(int)
        for atom in mol.GetAtoms():
            counts[atom.GetSymbol()] += 1

        # Carbon and hydrogen are always permitted; everything else must be configured.
        allowed = {'C', 'H'}.union(self.atoms)
        if any(symbol not in allowed for symbol in counts):
            return False

        # Reject molecules made of a single element.
        if len(counts) < 2:
            return False

        # set() guards against double counting if a symbol was listed twice.
        hetero_total = sum(counts.get(symbol, 0) for symbol in set(self.atoms))
        if hetero_total > self.n:
            return False

        return True


class HeavyAtomFilter():
    """Pass/fail filter that rejects molecules whose heavy-atom molecular weight exceeds ``n``."""

    def __init__(self, n=500):
        self.name = 'wt'
        self.n = n

    def __call__(self, mols, with_score=False):
        """Run ``check`` over the converted input; ``with_score`` is accepted but not used."""
        converted = to_mols(mols)
        return maybe_parallel(self.check, converted)

    def check(self, mol):
        """Return False when ``Descriptors.HeavyAtomMolWt(mol)`` is greater than ``n``."""
        weight = Descriptors.HeavyAtomMolWt(mol)
        return not (weight > self.n)


class SorcerFilter():
    """Pass/fail filter that rejects molecules whose ``sascorer`` score is ``n`` or higher."""

    def __init__(self, n=5):
        self.name = 'sas'
        self.n = n

    def __call__(self, mols, with_score=False):
        """Run ``check`` over the converted input; ``with_score`` is accepted but not used."""
        converted = to_mols(mols)
        return maybe_parallel(self.check, converted)

    def check(self, mol):
        """Return True only when ``sascorer.calculateScore(mol)`` is below ``n``."""
        score = sascorer.calculateScore(mol)
        return not (score >= self.n)

class CatBoostFilter():
    """Scores molecules with a CatBoost regressor loaded from disk.

    The feature matrix is built by applying every callable taken from
    ``Descriptors._descList`` (the second item of each entry) to every molecule.
    """

    def __init__(self, model_path):
        regressor = CatBoostRegressor()
        regressor.load_model(model_path)
        self.model = regressor
        self.name = 'cb'
        self.mols = []
        # Keep only the second item of each descriptor entry.
        self.descriptions = [entry[1] for entry in Descriptors._descList]

    def __call__(self, mols, with_score=False):
        """Return the model predictions for ``mols`` as a plain list; ``with_score`` is unused."""
        features = self.prepare(mols)
        predictions = self.model.predict(features)
        return predictions.tolist()

    def f(self, x):
        """Apply one descriptor callable ``x`` to the molecules stored by ``prepare``."""
        return maybe_parallel(x, self.mols)

    def prepare(self, mols):
        """Store the converted molecules and return a (molecule, descriptor) matrix."""
        self.mols = to_mols(mols)
        per_descriptor = maybe_parallel(self.f, self.descriptions)
        # Stacking gives one row per descriptor; transpose to one row per molecule.
        return np.transpose(np.stack(per_descriptor))

class SumTanimotoSimilarityPenalty():
    """Penalty equal to the negated mean Tanimoto similarity against the current batch.

    ``__call__`` fingerprints the batch and stores the result on ``self.mols``;
    ``check`` then compares one fingerprint against everything stored there.
    """

    def __init__(self):
        self.name = "SumTanimotoSimilarityPenalty"
        self.mols = []
        self.fpgen = GetRDKitFPGenerator()

    def __call__(self, mols, with_score=False):
        """Return one penalty per molecule for list input, or 0 for non-list input."""
        converted = to_mols(mols)
        # Fingerprints are stored before the list check, so self.mols is always refreshed.
        self.mols = maybe_parallel(self.fpgen.GetFingerprint, converted)

        if isinstance(converted, list):
            return maybe_parallel(self.check, self.mols)
        return 0

    def check(self, mol):
        """Return -10 for a missing fingerprint, else minus the mean similarity to ``self.mols``."""
        if mol is None:
            return -10
        # Missing (None) or zero entries contribute a similarity of 0.
        similarities = [
            0 if ((other is None) or other == 0) else TanimotoSimilarity(mol, other)
            for other in self.mols
        ]
        return -np.mean(similarities)

## Init and run algorithm

In [8]:
# Scoring template: the first list holds ValidityFilter, SingleCompoundFilter and the three
# pass/fail filters defined in this notebook; the second list holds the CatBoost model score
# and the similarity penalty. fail_score=0 is passed through to Template.
# NOTE(review): presumably mrl treats the first list as hard filters and the second as soft
# (score-contributing) filters — confirm against the mrl Template documentation.
# 'model.cb' is a relative path, so the model file must sit in the working directory.
template = Template([ValidityFilter(),
                     SingleCompoundFilter(),
                     NumAtomFilter(),
                     HeavyAtomFilter(),
                     SorcerFilter(),
                     ],
                    [CatBoostFilter('model.cb'),
                     SumTanimotoSimilarityPenalty(),
                     ],
                    fail_score=0)

# Mutation operators handed to MutatorCollection in the next cell.
# NOTE(review): ChangeAtom is given '6'..'9'; if these are atomic numbers, '9' (fluorine) yields
# molecules that NumAtomFilter's default element whitelist rejects — confirm this is intended.
# NOTE(review): the integer arguments (10, 50) are passed positionally; their meaning is not
# visible here — see the mrl combichem documentation.
mutators = [
            ChangeAtom(['6', '7', '8', '9']),
            AppendAtomSingle(['C', 'O', 'N', 'P', 'S']),
            AppendAtomsDouble(['C', 'N', 'O', 'P', 'S']),
            AppendAtomsTriple(),
            DeleteAtom(),
            ChangeBond(),
            InsertAtomSingle(['C', 'O', 'N', 'P', 'S']),
            InsertAtomDouble(['C', 'O', 'N', 'P', 'S']),
            InsertAtomTriple(),
            AddRing(),
            ShuffleNitrogen(10),
            SelfiesInsert(50),
            SelfiesReplace(50),
            SelfiesRemove(50),
]


In [None]:
# Load the seed molecules; only the 'smiles' column is used below.
# NOTE(review): absolute path — looks like a Colab working directory; adjust when running elsewhere.
df = pd.read_csv("/content/dataset.csv")

# Assemble the optimiser from the mutators and template defined in the previous cell,
# with FragmentCrossover as the only crossover and no extra rewards.
mc = MutatorCollection(mutators)
crossovers = [FragmentCrossover()]
cbc = CombiChem(mc, crossovers, template=template, rewards=[],
                prune_percentile=70, max_library_size=400, log=True, p_explore=0.2)

# Seed the library, then run 10 steps, printing the mean library score after each one.
# NOTE(review): no random seed is set anywhere in this notebook, so runs are presumably not
# reproducible — confirm whether mrl exposes a seed.
cbc.add_data(df[['smiles']])
for i in range(10):
    cbc.step()
    print(np.mean(cbc.library.score))

In [10]:
# Write the 'smiles' column of the 100 highest-scoring library rows (ascending sort, reversed)
# to a CSV with no index and no header row. `index`/`header` take the documented boolean False
# instead of relying on None being falsy.
cbc.library.sort_values('score')[::-1][:100][['smiles']].to_csv("genetic-alg.csv", index=False, header=False)