# Large Language Models are Fragment Based Drug Designers
# Manas Mahale <manas.mahale@bcp.edu.in>

In [16]:
import os
import random
from rdkit import Chem
from rdkit.Chem import BRICS
from transformers import BertForMaskedLM, BertModel, pipeline, PreTrainedTokenizerFast
from tokenizers.processors import TemplateProcessing
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.trainers import WordLevelTrainer
from tokenizers.pre_tokenizers import WhitespaceSplit

In [2]:
# Load the tokenizer saved in the local ./tokenizer/ directory.
tokenizer = PreTrainedTokenizerFast.from_pretrained(
    './tokenizer/'
)

In [3]:
# Register the special-token strings on the tokenizer (assigned left to right:
# mask, unk, pad, sep, cls). The fill-mask pipeline relies on mask_token.
(
    tokenizer.mask_token,
    tokenizer.unk_token,
    tokenizer.pad_token,
    tokenizer.sep_token,
    tokenizer.cls_token,
) = ("[MASK]", "[UNK]", "[PAD]", "[SEP]", "[CLS]")

In [4]:
# Load the masked-LM weights from the local directory model/checkpoint-250.
model = BertForMaskedLM.from_pretrained(
    os.path.join('model/', "checkpoint-250")
)

In [5]:
# Fill-mask pipeline pairing the loaded model with its tokenizer.
fill_mask = pipeline(
    "fill-mask",
    model=model,
    tokenizer=tokenizer,
)

In [6]:
# Query molecule (SMILES) to decompose into fragments.
smi = "CCc1nc(CNc2cccc3cccnc23)cs1"

In [7]:
# Parse the query molecule; RDKit returns None (it does not raise) on an
# unparseable SMILES, which would otherwise surface as an obscure error
# inside BRICSDecompose.
parent_mol = Chem.MolFromSmiles(smi)
if parent_mol is None:
    raise ValueError(f"could not parse SMILES: {smi!r}")
# BRICSDecompose returns a set of fragment SMILES. A set's iteration order
# depends on string hashing (PYTHONHASHSEED) and can differ between runs, so
# sort it to get a reproducible fragment order for the sentence built below.
d = sorted(BRICS.BRICSDecompose(parent_mol))

In [8]:
# One-atom dummy molecule; strip_dummy_atoms uses it as the substructure
# pattern for the attachment points left on BRICS fragments.
dummy = Chem.MolFromSmiles("[*]")

def mol_to_smiles(mol):
    """Return the SMILES string for an RDKit molecule.

    ``isomericSmiles=True`` asks RDKit to keep isomeric information
    (stereochemistry, isotope labels) in the output.
    """
    return Chem.MolToSmiles(mol, isomericSmiles=True)

def mol_from_smiles(smi):
    """Parse the SMILES string ``smi`` into an RDKit molecule.

    RDKit's parser returns ``None`` instead of raising when ``smi`` is not
    valid SMILES, so callers should expect ``None`` for bad input.
    """
    parsed = Chem.MolFromSmiles(smi)
    return parsed


# Replacement molecule used by strip_dummy_atoms. It never changes, so build
# it once here rather than re-parsing it on every call.
_HYDROGEN = mol_from_smiles('[H]')


def strip_dummy_atoms(mol):
    """Cap every dummy-atom attachment point in ``mol`` with hydrogen.

    Each atom that matches the module-level ``dummy`` pattern is replaced by
    an explicit hydrogen. The explicit hydrogens are then removed, leaving a
    plain fragment (e.g. a BRICS fragment without its attachment labels).

    Args:
        mol: RDKit molecule, typically the result of ``mol_from_smiles``.

    Returns:
        A new RDKit molecule with the dummy atoms removed.

    Raises:
        ValueError: if ``mol`` is None, as returned for an unparseable SMILES.
    """
    if mol is None:
        raise ValueError(
            "strip_dummy_atoms() received None; did the SMILES fail to parse?"
        )
    # With replaceAll=True, ReplaceSubstructs returns a one-element tuple.
    replaced = Chem.ReplaceSubstructs(mol, dummy, _HYDROGEN, replaceAll=True)[0]
    return Chem.RemoveHs(replaced)

In [63]:
# Turn each BRICS fragment into a dummy-free SMILES "word", drop a [MASK]
# token at a random position (0..len inclusive, so it may land at the end),
# and join everything into one space-separated sentence for the model.
fragments = [
    mol_to_smiles(strip_dummy_atoms(mol_from_smiles(frag)))
    for frag in d
]
fragments.insert(random.randint(0, len(fragments)), '[MASK]')
a = ' '.join(fragments)
a



'CC c1ccc2ncccc2c1 c1cscn1 [MASK] C N'

In [64]:
# Token ids for the masked sentence; the output shows [CLS]/[SEP] ids added.
b = tokenizer.encode(text=a)
b

[1, 14, 25, 64, 4, 13, 5, 2]

In [65]:
# Round-trip the ids back to text to check the tokenization.
tokenizer.decode(token_ids=b)

'[CLS] CC c1ccc2ncccc2c1 c1cscn1 [MASK] C N [SEP]'

In [66]:
# Show each candidate fragment the model proposes for the [MASK] slot.
for candidate in fill_mask(a):
    print(candidate)

{'score': 0.2568870782852173, 'token': 6, 'token_str': 'CC=O', 'sequence': 'CC c1ccc2ncccc2c1 c1cscn1 CC=O C N'}
{'score': 0.23006771504878998, 'token': 8, 'token_str': 'C=O', 'sequence': 'CC c1ccc2ncccc2c1 c1cscn1 C=O C N'}
{'score': 0.08648885041475296, 'token': 9, 'token_str': 'c1ccccc1', 'sequence': 'CC c1ccc2ncccc2c1 c1cscn1 c1ccccc1 C N'}
{'score': 0.07768130302429199, 'token': 12, 'token_str': 'CO', 'sequence': 'CC c1ccc2ncccc2c1 c1cscn1 CO C N'}
{'score': 0.045079827308654785, 'token': 24, 'token_str': 'c1ccsc1', 'sequence': 'CC c1ccc2ncccc2c1 c1cscn1 c1ccsc1 C N'}
