In [None]:
import os 
import torch 
#os.environment['CUDA_VISIBLE_DEVICES'] = 5
%env CUDA_VISIBLE_DEVICES=6
# CUDA_VISIBLE_DEVICES=6 (set on the line above) exposes exactly one GPU to this
# process, and visible devices are renumbered from zero, so that card is 'cuda:0'.
# 'cuda:6' would be an invalid device ordinal with a single visible device.
device = torch.device('cuda:0' if torch.cuda.is_available() else "cpu")

In [None]:
import sys
# Add ./fastai1/ to the module search path (presumably a vendored fastai v1
# checkout that the `from fastai ...` imports below should resolve to - confirm).
sys.path.append('./fastai1/')

In [None]:
import pandas as pd 
import numpy as np
import threading
import random

from rdkit import Chem
from rdkit.Chem import Draw
from rdkit.Chem.Draw import IPythonConsole
from rdkit import RDLogger
from IPython.display import display,Image, SVG
from rdkit.Chem import rdmolops
RDLogger.DisableLog('rdApp.*') # switch off RDKit warning messages


from fastai import *
from fastai.text import *
from fastai.vision import *
from fastai.imports import *
from fastai.callbacks import *

import torch
import torchvision
import torch.nn.functional as F
from sklearn.utils import shuffle
from sklearn.model_selection import train_test_split
# Show the working directory: the data, results and checkpoint paths used in
# this notebook are all written relative to it.
current_path = os.getcwd()
print(current_path)



Set the seed value

In [None]:
def random_seed(seed_value, use_cuda):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed_value: seed handed to every generator.
        use_cuda: when truthy, the CUDA generators are seeded too and cuDNN is
            put in deterministic mode with benchmarking switched off.
    """
    random.seed(seed_value)        # Python stdlib generator
    np.random.seed(seed_value)     # NumPy global generator
    torch.manual_seed(seed_value)  # PyTorch generator
    if not use_cuda:
        return
    torch.cuda.manual_seed(seed_value)      # current GPU
    torch.cuda.manual_seed_all(seed_value)  # every GPU
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

In [None]:
# Load the ligand table; later cells read its `smiles` column both for
# augmentation and for building the novelty reference set.
exp_smiles = pd.read_csv('./data/Experimental_Ligands.csv')
print('Dataset:', exp_smiles.shape)

In [None]:
# Output directory for this notebook's results (also passed as `path` when the
# TextList is built further down). Missing parent directories are created and an
# already existing directory is left as it is.
GEN = Path('./results/generative_model')
GEN.mkdir(parents=True, exist_ok=True)

In [None]:
def sanitize_smiles(smiles, canonical=True, throw_warning=False):
    """Return a list, parallel to `smiles`, of sanitized SMILES strings.

    Args:
        smiles: iterable of SMILES strings.
        canonical: if True, every entry is parsed with RDKit (sanitization on)
            and written back with `Chem.MolToSmiles`; if False, entries are
            passed through unchanged.
        throw_warning: if True, emit a `UserWarning` for each entry that could
            not be processed.

    Returns:
        A list with one item per input entry; entries that could not be
        processed are replaced by the empty string.
    """
    new_smiles = []
    for sm in smiles:
        if not canonical:
            new_smiles.append(sm)
            continue
        try:
            mol = Chem.MolFromSmiles(sm, sanitize=True)
            if mol is None:
                # RDKit signals an unparsable string with None rather than raising.
                raise ValueError('RDKit could not parse the SMILES string')
            new_smiles.append(Chem.MolToSmiles(mol))
        except Exception:  # not a bare except: Ctrl-C must still interrupt the loop
            if throw_warning:
                # format() instead of '+' so a non-str entry (None, NaN, ...)
                # still produces a warning instead of a TypeError.
                warnings.warn('Unsanitized SMILES string: {}'.format(sm), UserWarning)
            new_smiles.append('')
    return new_smiles


def canonical_smiles(smiles, sanitize=True, throw_warning=False):
    """Return a list, parallel to `smiles`, of RDKit-rewritten SMILES strings.

    Args:
        smiles: iterable of SMILES strings.
        sanitize: forwarded to `Chem.MolFromSmiles`.
        throw_warning: if True, emit a `UserWarning` for each entry that could
            not be converted.

    Returns:
        A list with one item per input entry, each being the output of
        `Chem.MolToSmiles` for the parsed molecule; entries that could not be
        converted are replaced by the empty string.
    """
    new_smiles = []
    for sm in smiles:
        try:
            mol = Chem.MolFromSmiles(sm, sanitize=sanitize)
            if mol is None:
                # RDKit signals an unparsable string with None rather than raising.
                raise ValueError('RDKit could not parse the SMILES string')
            new_smiles.append(Chem.MolToSmiles(mol))
        except Exception:  # not a bare except: Ctrl-C must still interrupt the loop
            if throw_warning:
                # format() instead of '+' so a non-str entry (None, NaN, ...)
                # still produces a warning instead of a TypeError.
                warnings.warn('{} can not be canonized: invalid '
                              'SMILES string!'.format(sm), UserWarning)
            new_smiles.append('')
    return new_smiles

In [None]:
def is_valid(smiles):
    """Return `smiles` when RDKit parses it into a molecule with at least one atom.

    Otherwise None is returned, which makes the function usable as a `filter`
    predicate.
    """
    parsed = Chem.MolFromSmiles(smiles)
    if parsed is None or parsed.GetNumAtoms() <= 0:
        return None
    return smiles

def uniqueness_score(mols):
    """Return the distinct entries of `mols` as a set."""
    return {mol for mol in mols}

def novelty_score(mols,ref_mols):
    """Return the set of entries of `mols` that do not occur in `ref_mols`.

    Args:
        mols: any iterable of hashable items (a set, list, tuple, ...).
        ref_mols: any iterable of reference items.

    Returns:
        A new set; neither argument is modified.
    """
    # set(mols) rather than set.difference(mols, ...): the unbound-method form
    # raises TypeError unless `mols` is already a set.
    return set(mols).difference(ref_mols)



class SamplingCB(LearnerCallback):
    """Callback that samples strings from the learner at the end of every epoch.

    After each epoch `num_samples` strings are generated, the special tokens are
    stripped, and three extra metrics are appended to the recorder:

    * Valid  - fraction of samples accepted by `is_valid`
    * Unique - distinct valid samples / `num_samples`
    * Novel  - distinct valid samples absent from `objective_mols` / `num_samples`

    NOTE(review): samples are compared with `objective_mols` as raw strings. If the
    reference holds canonical SMILES, a differently written SMILES of a reference
    molecule counts as novel - confirm that this is intended.
    """

    _order=-20 # Needs to run before the recorder
    def __init__(self,learn:Learner, objective_mols:Collection=None, num_samples:int=100):
        """Store the sampling settings.

        Args:
            learn: the learner to sample from.
            objective_mols: reference strings for the novelty metric; None is
                treated as an empty reference.
            num_samples: number of strings generated per epoch (must be >= 1).

        Raises:
            ValueError: if `num_samples` is smaller than 1.
        """
        super().__init__(learn)
        if num_samples < 1:
            # Fail here rather than with a ZeroDivisionError after a full epoch.
            raise ValueError('num_samples must be at least 1, got {}'.format(num_samples))
        self.num_samples= num_samples
        self.max_size = 100      # maximum number of tokens generated per sample
        self.temperature = 1.0   # 1.0 leaves the predicted probabilities untouched
        self.objective_mols = objective_mols

    def on_train_begin(self,**kwargs):
        "Register the three extra metric columns with the recorder."
        self.learn.recorder.add_metric_names(['Valid', 'Unique', 'Novel'])

    def on_epoch_being(self,**kwargs):
        """Replace the reference set by a random subset of at most `num_samples` entries.

        NOTE(review): the method is called `on_epoch_being`, not `on_epoch_begin`,
        so fastai's callback dispatch presumably never invokes it and novelty is in
        practice scored against the full reference set. The name is kept so that
        this behaviour (and the reported metrics) do not change; rename it only if
        per-epoch sub-sampling of the reference is really wanted.
        """
        # The original body read an undefined global `objective_mols` (NameError)
        # and asked random.sample for more items than a short reference holds.
        pool = [] if self.objective_mols is None else list(self.objective_mols)
        self.objective_mols = random.sample(pool, min(self.num_samples, len(pool)))

    def sampling(self,text:str='', sep:str=''):
        "Vanilla sampling. Return `text` followed by at most `self.max_size` sampled tokens joined by `sep`."
        learn = self.learn
        learn.model.reset()
        vocab = learn.data.train_ds.vocab
        # The last vocabulary index is treated as the stop token
        # (presumably EOS - confirm against the vocabulary in use).
        stop_idx = len(vocab.itos) - 1
        xb,yb = learn.data.one_item(text)
        new_idx = []
        for _ in range(self.max_size):
            res = learn.pred_batch(batch=(xb,yb))[0][-1]
            if self.temperature != 1.:
                res.pow_(1 / self.temperature)
            idx = torch.multinomial(res, 1).item()
            if idx == stop_idx:
                break
            new_idx.append(idx)
            xb = xb.new_tensor([idx])[None]  # feed the sampled token back in
        return text + sep + sep.join(vocab.textify(new_idx, sep=None))

    def on_epoch_end(self, last_metrics, **kwargs):
        "Sample `num_samples` strings and append the Valid/Unique/Novel fractions to `last_metrics`."
        print('Sampling...')
        p = []
        for _ in range(self.num_samples):
            sample = self.sampling()
            for special in ('xxbos', 'xxeos', 'xxunk', 'xxpad'):
                sample = sample.replace(special, '')
            p.append(sample)
        print('Sample of generated SMILES')
        print(p[:5])
        val = list(filter(is_valid,p)) # Validity
        print(val[0:5])
        #sanitized = canonical_smiles(val, sanitize=True, throw_warning=True)
        uniq = uniqueness_score(val) # Uniqueness
        # Without a reference set every unique sample counts as novel.
        reference = () if self.objective_mols is None else self.objective_mols
        novel = novelty_score(uniq, reference) # Novelty

        total = self.num_samples
        return add_metrics(last_metrics, [len(val)/total, len(uniq)/total, len(novel)/total])

In [None]:
def sampling(model,dt,text:str, n_words:int, temperature:float=1., sep:str=' '):
    """Vanilla sampling: return `text` followed by up to `n_words` sampled tokens.

    The model state is reset, `text` is turned into a batch with `dt.one_item`,
    and tokens are drawn one at a time from the predicted distribution (raised
    to the power `1 / temperature` when `temperature` is not 1). Generation
    stops early when the last vocabulary index is drawn.
    """
    model.model.reset()
    vocab = dt.vocab

    xb, yb = dt.one_item(text)
    generated = []
    while len(generated) < n_words:
        probs = model.pred_batch(batch=(xb, yb))[0][-1]
        if temperature != 1.:
            probs.pow_(1 / temperature)
        token = torch.multinomial(probs, 1).item()
        if token == len(vocab.itos) - 1:
            # the last vocabulary entry acts as the stop token
            break
        generated.append(token)
        xb = xb.new_tensor([token])[None]  # next input is the token just drawn
    pieces = vocab.textify(generated, sep=None)
    return text + sep + sep.join(pieces)



def validation(model, dt, sampling_temperatures, iterations, samples, ref, maxsize=100):
    '''Vanilla sampling and validation function.

    For every temperature in `sampling_temperatures`, `iterations` rounds are run;
    each round draws `samples` strings of at most `maxsize` tokens with `sampling`,
    strips the special tokens and scores them.

    Returns a 6-tuple:
        validity, novelty, uniqueness - arrays of shape
            (iterations, len(sampling_temperatures)) holding percentages of `samples`;
        mols, unq_mols, novel_mols - the valid / unique / novel molecules of the
            LAST round only (empty when no round was run).

    NOTE(review): generated strings are compared with `ref` as raw text. If `ref`
    holds canonical SMILES, differently written SMILES of the same molecule count
    as unique/novel; canonicalising `mols` first would tighten the metrics -
    confirm which behaviour is intended.
    '''
    n_temps = len(sampling_temperatures)
    _validity = np.zeros((iterations, n_temps))
    _novelty = np.zeros((iterations, n_temps))
    _uniqueness = np.zeros((iterations, n_temps))

    # Defined up front so the return statement works even when no round runs
    # (iterations == 0 or no temperatures); previously that raised NameError.
    mols, unq_mols, novel_mols = [], set(), set()

    for j, temp in enumerate(sampling_temperatures):
        print('Temperatures = {}'.format(temp))
        for i in range(iterations):
            print('Starting iteration {}'.format(i))
            p = []
            for _ in range(samples):
                raw = sampling(model, dt, text='', n_words=maxsize, sep='', temperature=temp)
                p.append(raw.replace(PAD, '').replace(BOS, '').replace(EOS, '').replace(UNK, ''))
            mols = list(filter(is_valid, p))  # Valid
            #sanitized = canonical_smiles(mols, sanitize=True, throw_warning=True)
            unq_mols = uniqueness_score(mols)  # Unique
            novel_mols = novelty_score(unq_mols, ref)  # Novel

            _novelty[i, j] = len(novel_mols) / samples * 100
            _uniqueness[i, j] = len(unq_mols) / samples * 100
            _validity[i, j] = len(mols) / samples * 100

            # Inside the loop: `i` does not exist when the loop body never runs.
            print('Iteration {} ended'.format(i))
    print('----------------------------------')
    return _validity, _novelty, _uniqueness, mols, unq_mols, novel_mols


### Data pre-processing

Define a custom tokenizer

In [None]:
class MolTokenizer(BaseTokenizer):
    ''' Atom-level tokenizer. Splits molecules into individual atoms and special environments.
    A special environment is defined by any elements inside square brackets (e.g., [nH]).
    Outside brackets every character is one token, except the two-letter halogens
    `Br` and `Cl`, which are kept together.
    '''
    def __init__(self, lang:str):
        # `lang` is not used for tokenizing; it is only stored on the instance.
        self.lang = lang

    def tokenizer(self, t:str) -> List[str]:
        "Return the tokens of `t`, wrapped in BOS and EOS."
        assert isinstance(t, str)
        bracket_pat = r'(\[.*?\])'  # Find special environments (e.g., [CH],[NH] etc)
        atom_pat = r'Br|Cl|.'       # two-letter halogens first, then any single character
        tokens = []
        # The capturing group makes re.split return the bracket atoms as well.
        for s in re.split(bracket_pat, t):
            if s.startswith('['):
                # Bracket atoms are kept verbatim. The previous L/X placeholder
                # round-trip rewrote them, e.g. '[Li+]' -> '[Bri+]', '[Xe]' -> '[Cle]'.
                tokens.append(s)
            else:
                tokens += re.findall(atom_pat, s, flags=re.DOTALL)
        return [BOS] + tokens + [EOS]  # + [PAD for i in range(133-len(tokens))]

class Create_Vocab(object):
    '''Build the token vocabulary for a collection of SMILES strings.'''
    def __init__(self, smiles):
        # SMILES strings the vocabulary is built from
        self.smiles = smiles

    def tokenize(self):
        '''Tokenize `self.smiles` with `MolTokenizer` (no pre/post rules).

        Returns `(unique_tokens, vocab)`: the list `[UNK, PAD]` followed by the
        sorted distinct tokens, and a `Vocab` built from that list.
        '''
        mol_tok = Tokenizer(MolTokenizer, pre_rules=[], post_rules=[])
        tokenized = mol_tok.process_all(self.smiles)

        seen = set()
        for toks in tokenized:
            seen.update(toks)
        unique_tokens = [UNK, PAD] + sorted(seen)

        return unique_tokens, Vocab(itos=unique_tokens)


#### SMILES augmentation for language model

In [None]:
def randomize_smiles(smiles):
    """Return a non-canonical SMILES for the same molecule with a random atom order.

    The atoms are renumbered with a permutation drawn from NumPy's global random
    generator, so the result depends on the current `np.random` seed state.

    Raises:
        ValueError: if RDKit cannot parse `smiles`.
    """
    m = Chem.MolFromSmiles(smiles)
    if m is None:
        # MolFromSmiles returns None on failure; without this check the next line
        # fails with an AttributeError that does not name the offending input.
        raise ValueError('Cannot randomize invalid SMILES string: {!r}'.format(smiles))
    ans = list(range(m.GetNumAtoms()))
    np.random.shuffle(ans)  # in place
    nm = Chem.RenumberAtoms(m,ans)
    return Chem.MolToSmiles(nm, canonical=False, isomericSmiles=True, kekuleSmiles=False)

def lm_smiles_augmentation(df, N_rounds):
    """Augment `df` with `N_rounds` randomized SMILES per row.

    Args:
        df: DataFrame with a `smiles` column; any other columns are copied
            unchanged onto the augmented rows.
        N_rounds: number of randomized SMILES generated for every input row
            (values <= 0 add nothing).

    Returns:
        A new DataFrame with the augmented rows first and the original rows
        after them, re-indexed from 0 and then de-duplicated on `smiles`
        (first occurrence kept). `df` itself is not modified.
    """
    # Row i is repeated N_rounds times, in order: 0,0,...,1,1,... This keeps the
    # order in which random numbers are consumed identical to a nested loop.
    row_idx = np.repeat(np.arange(len(df)), max(N_rounds, 0))
    df_aug = df.iloc[row_idx].copy()
    if len(df_aug):
        df_aug['smiles'] = [randomize_smiles(sm) for sm in df_aug['smiles']]
    # pd.concat instead of DataFrame.append, which was removed in pandas 2.0.
    df_aug = pd.concat([df_aug, df], ignore_index=True)
    return df_aug.drop_duplicates('smiles')

The randomized SMILES are used for data augmentation. The number of randomized SMILES generated per molecule is passed as the second argument (`N_rounds`) to the `lm_smiles_augmentation` function.

In [None]:
# Re-seed so that the random atom orderings drawn during augmentation are repeatable.
random_seed(1234, True)

# 200 randomized SMILES per input row plus the original rows, de-duplicated on `smiles`.
exp_smiles_aug = lm_smiles_augmentation(exp_smiles, 200)
print(len(exp_smiles_aug))

Create a text databunch for language modeling:

- It takes SMILES as input
- Pass the custom tokenizer defined in the previous step
- Specify the column containing text data
- Define the batch size according to the GPU memory available

In [None]:
random_seed(1234, True)

# Build the token vocabulary from every augmented SMILES string; `vocab` is
# handed to the numericalizer in the next cell.
vocab_list = Create_Vocab(list(exp_smiles_aug.smiles))
unique_tokens,vocab = vocab_list.tokenize()

In [None]:
random_seed(1234, True)

# Atom-level tokenizer with no pre/post-processing rules.
tokenizer = Tokenizer(MolTokenizer,pre_rules=[],post_rules=[],special_cases=[PAD,BOS,EOS,UNK])
# include_bos=False: MolTokenizer already wraps every sequence in BOS/EOS itself.
# Numericalization uses the vocabulary built in the previous cell.
processors = [TokenizeProcessor(tokenizer=tokenizer, mark_fields=False,include_bos=False), NumericalizeProcessor(vocab=vocab)]
# Text list over the `smiles` column, split with split_by_rand_pct(0.10)
# (presumably a random 10% validation hold-out) and labelled for language modelling.
src = (TextList.from_df(exp_smiles_aug, path=GEN, cols='smiles', processor=processors).split_by_rand_pct(0.10).label_for_lm())

In [None]:
random_seed(1234, True)

# No batch size is passed here, so the library default is used; pass `bs=...`
# to databunch() to adapt it to the available GPU memory.
data_fn = src.databunch()
data_fn.show_batch()

## Fine-tuning the target task language model

Load the pre-trained weights and vocabulary

In [None]:
# Locations of the pre-trained checkpoint: `fnames` resolves to
# pre_trained_wt.pth (weights) and pre_trained_vocab.pkl (vocabulary)
# inside `pretrained_model_path`.
pretrained_model_path = Path('./pre_trained_model_checkpoint/')
pretrained_fnames = ['pre_trained_wt', 'pre_trained_vocab']
fnames = [pretrained_model_path/f'{fn}.{ext}' for fn,ext in zip(pretrained_fnames, ['pth', 'pkl'])]

In [None]:
# Reference dataset for the novelty metric: the distinct input SMILES, rewritten by RDKit.
# Note: entries RDKit cannot parse come back as '' (a warning is emitted), and
# de-duplication happens BEFORE canonicalization, so the list may still contain
# duplicates when different spellings map to the same canonical string.
smiles_ref = canonical_smiles(list(set(exp_smiles.smiles)), sanitize=True, throw_warning=True)
print(len(smiles_ref))

In [None]:
random_seed(1234, True)

# AWD-LSTM language model; pretrained=False skips the library's own weights and
# the checkpoint listed in `fnames` (weights + vocabulary) is loaded instead.
# CSVLogger is created with append=True, so re-running adds to the existing log.
learn_fn = language_model_learner(data_fn, AWD_LSTM, pretrained=False, drop_mult=0.8, metrics=[accuracy, error_rate], callback_fns=[partial(CSVLogger,append=True)]).load_pretrained(*fnames)
# Stage 1 starts from a frozen model (presumably only the last layer group trains - confirm).
learn_fn.freeze()

In [None]:
random_seed(1234, True)

# Stage 1: 5 epochs at max learning rate 1e-1.
# SamplingCB generates 5 strings per epoch and reports Valid/Unique/Novel against smiles_ref;
# SaveModelCallback stores a checkpoint named 'bestmodel' whenever accuracy improves.
learn_fn.fit_one_cycle(5, 1e-1, moms=(0.8,0.7), callbacks=[SamplingCB(learn_fn, num_samples=5, objective_mols=smiles_ref),
                                   SaveModelCallback(learn_fn, every='improvement',monitor='accuracy', name='bestmodel')])

In [None]:
random_seed(1234, True)

# Stage 2: freeze_to(-2) (presumably unfreezes the last two layer groups - confirm).
learn_fn.freeze_to(-2)

# 6 epochs at max learning rate 1e-2; same callbacks as stage 1, and the same
# checkpoint name 'bestmodel' is reused.
learn_fn.fit_one_cycle(6, 1e-2, moms=(0.8,0.7), callbacks=[SamplingCB(learn_fn, num_samples=5, objective_mols=smiles_ref),
                                  SaveModelCallback(learn_fn, every='improvement', monitor='accuracy', name='bestmodel')])

In [None]:
random_seed(1234, True)

# Stage 3: train the whole model.
learn_fn.unfreeze()

# 6 epochs at max learning rate 1e-3; 100 strings are now sampled per epoch for
# the Valid/Unique/Novel metrics, and the checkpoint name 'bestmodel' is reused.
learn_fn.fit_one_cycle(6, 1e-3, moms=(0.8,0.7), callbacks=[SamplingCB(learn_fn, num_samples=100, objective_mols=smiles_ref),
                                   SaveModelCallback(learn_fn, every='improvement',
                                                     monitor='accuracy', name='bestmodel')])

Save the fine-tuned encoder

In [None]:
# Saves the encoder only, under the name 'finetuned_encoder'.
# NOTE(review): if the complete generative model is needed later for sampling,
# it presumably has to be saved separately (e.g. learn_fn.save) - confirm.
learn_fn.save_encoder('finetuned_encoder')

#### Validate the fine-tuned model in terms of validity, uniqueness, and novelty

In [None]:
# Evaluate the fine-tuned learner; the learner was created with the
# `accuracy` and `error_rate` metrics.
learn_fn.validate()

In [None]:
# Temperatures at which the model is sampled in the next cell.
# NOTE(review): 0.4 is absent from the otherwise evenly spaced grid - confirm this is intended.
sampling_temperatures = [0.2,  0.6,  0.8, 1.0, 1.2, 1.4, 1.6, 1.8, 2.0]

In [None]:
random_seed(1234, True)

# One round of 500 samples per temperature. The three arrays hold percentages per
# temperature; `mols`, `unq_mols` and `novel_mols` are the molecules of the LAST
# round only, i.e. of the final temperature in the list (2.0).
validity, novelty, uniqueness, mols, unq_mols, novel_mols = validation(learn_fn, data_fn, sampling_temperatures, 1, 500, ref=smiles_ref)

In [None]:
# One table per metric: a row per sampling round, a column per temperature.
val_df, nov_df, unq_df = [
    pd.DataFrame(scores, columns=['Temp_{}'.format(temp) for temp in sampling_temperatures])
    for scores in (validity, novelty, uniqueness)
]

In [None]:
# Number of valid / unique / novel molecules in the last sampling round
# (`validation` returns the molecules of its final round only).
len(mols), len(unq_mols), len(novel_mols)

In [None]:
# Write the novel molecules of the last sampling round to the results directory.
# - The target is built from GEN (./results/generative_model, created near the
#   top of the notebook) instead of repeating the literal path.
# - sorted(): `novel_mols` is a set of strings, whose iteration order is not
#   reproducible from one kernel session to the next, so the file would
#   otherwise differ between identical, seeded runs.
pd.Series(sorted(novel_mols)).to_csv(GEN/'Generated_Ligands.csv', index=False)