## Fine-tuning of language model (LM) and molecule generation
- This notebook fine-tunes a target-task language model (LM), starting from the weights of the pre-trained LM, and then uses the fine-tuned model to generate molecules.
- The code is adapted from https://github.com/marcossantanaioc/De_novo_design_SARSCOV2

#### Install RDKit on Google Colaboratory

In [None]:
%%bash
# Install Miniconda into /usr/local and use it to install RDKit for this Colab runtime.
# NOTE(review): `add-apt-repository` and `apt-get dist-upgrade` run without `-y`; in a
# non-interactive cell they may prompt and abort. RDKit is installed from conda below,
# so confirm whether these apt steps are needed at all.
add-apt-repository ppa:ubuntu-toolchain-r/test
apt-get update --fix-missing
apt-get dist-upgrade
# NOTE(review): repo.continuum.io is the legacy download host (repo.anaconda.com is current),
# and the "latest" installer may no longer make a python=3.7 downgrade solvable — confirm.
wget -c https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh
chmod +x Miniconda3-latest-Linux-x86_64.sh
./Miniconda3-latest-Linux-x86_64.sh -b -f -p /usr/local
# Auto-confirm all later conda prompts and leave the shell prompt unchanged.
conda config --set always_yes yes --set changeps1 no
# The Python version pinned here must match the site-packages path appended to sys.path next.
conda install -q -y -c conda-forge python=3.7
conda install -q -y -c conda-forge rdkit
#conda install -q -y -c openbabel openbabel

In [None]:
import sys

# The conda install above places RDKit in this interpreter-specific site-packages
# directory; append it so the Colab kernel can import rdkit.
# Guarded so re-running the cell does not add duplicate sys.path entries.
CONDA_SITE_PACKAGES = '/usr/local/lib/python3.7/site-packages/'
if CONDA_SITE_PACKAGES not in sys.path:
    sys.path.append(CONDA_SITE_PACKAGES)

Import the required libraries

In [None]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline
import time
import pandas as pd
import sys
import seaborn as sns
import matplotlib.pyplot as plt

from torch.utils.data import WeightedRandomSampler
import random
import numpy as np
from google.colab import drive

from fastai.callbacks import *
from fastai.text import *
from fastai.metrics import *

from sklearn.model_selection import train_test_split

In [None]:
from rdkit import Chem
from rdkit import rdBase
from rdkit.Chem import Draw, AllChem
from IPython.display import display,Image, SVG
from rdkit.Chem import rdmolops
# Silence RDKit's error log: validity checks parse many invalid generated SMILES,
# and each failed parse would otherwise print an error message.
rdBase.DisableLog('rdApp.error')

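Define a helper that seeds all random number generators (Python, NumPy, PyTorch and, optionally, CUDA) so that runs are reproducible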

In [None]:
def random_seed(seed_value, use_cuda):
    """Seed every random number generator this notebook relies on.

    seed_value: integer seed applied to Python's `random`, NumPy and PyTorch.
    use_cuda: when truthy, also seed the CUDA generators and switch cuDNN to
        deterministic, non-benchmark mode (slower but repeatable).
    """
    random.seed(seed_value)          # Python stdlib RNG
    np.random.seed(seed_value)       # NumPy global RNG
    torch.manual_seed(seed_value)    # PyTorch CPU RNG
    if not use_cuda:
        return
    torch.cuda.manual_seed(seed_value)
    torch.cuda.manual_seed_all(seed_value)  # every visible GPU
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Data
Mount Google Drive in Google Colab to access the data and results stored on Drive

In [None]:
# Mount Google Drive so the data/ and results/ folders under "My Drive" are reachable.
# `drive` is re-imported so this cell does not depend on the imports cell.
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [None]:
# Load the target-task dataset (alcohol SMILES) from Google Drive.
alc_smiles = pd.read_csv('/content/gdrive/My Drive/data/alcohol-smiles.csv')
# Every later step reads the `smiles` column; fail fast with a clear message if it is missing.
assert 'smiles' in alc_smiles.columns, "expected a 'smiles' column in alcohol-smiles.csv"
print('Dataset:', alc_smiles.shape)

In [None]:
# Create a path to save the results

# Root folder on Google Drive where generative-model artifacts are written.
GEN = Path('/content/gdrive') / 'My Drive' / 'results' / 'Generative'
GEN.mkdir(exist_ok=True, parents=True)

In [None]:
# Sub-folder of the generative results directory reserved for regressor outputs.
GENREG = Path('/content/gdrive/My Drive/results/Generative', 'Regressor')
GENREG.mkdir(exist_ok=True, parents=True)

## Helper functions
### Sampling callback

In [None]:
def is_valid(smiles):
    """Return `smiles` if RDKit parses it into a molecule with at least one atom, else None.

    Returning the string (rather than a bool) lets `filter(is_valid, ...)` keep
    only the valid SMILES.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None or mol.GetNumAtoms() == 0:
        return None
    return smiles

def uniqueness_score(mols):
    """Return the distinct SMILES strings in `mols` as a set."""
    distinct = set()
    distinct.update(mols)
    return distinct

def novelty_score(mols, ref_mols):
    """Return the molecules in `mols` that do not appear in `ref_mols`.

    mols: any iterable of SMILES strings (set, list, ...).
    ref_mols: iterable of reference SMILES, or None, in which case nothing is
        excluded and every distinct molecule in `mols` is returned.
    Returns a new set. The comparison is by exact string, so a differently
    written SMILES of a reference molecule still counts as novel.
    """
    candidates = set(mols)  # accept non-set iterables (set.difference(list, ...) raises TypeError)
    if ref_mols is None:
        return candidates
    return candidates.difference(ref_mols)

class SamplingCB(LearnerCallback):
  '''Sampling callback: at the end of every training epoch, generate molecules with the
  language model and report validity, uniqueness and novelty as extra metrics.

  learn: Learner -> the language-model learner being trained.

  objective_mols: Collection -> reference SMILES. A generated SMILES counts as novel when
    its exact string is absent from this collection (when finetuning, this is the dataset
    being finetuned to). If None, every unique valid sample counts as novel.

  num_samples: int -> number of molecules generated per epoch; also the denominator of
    the three reported fractions.

  max_size: int -> maximum number of tokens sampled per molecule.

  temperature: float -> sampling temperature; values below 1 sharpen the distribution.
  '''
  _order=-20 # Needs to run before the recorder
  def __init__(self, learn:Learner, objective_mols:Collection=None, num_samples:int=100,
               max_size:int=120, temperature:float=0.70):
    super().__init__(learn)
    self.num_samples = num_samples
    self.max_size = max_size
    self.temperature = temperature
    self.objective_mols = objective_mols

  def on_train_begin(self, **kwargs):
    # Register the extra metric columns so the recorder shows them every epoch.
    self.learn.recorder.add_metric_names(['Valid', 'Unique', 'Novel'])

  def _eos_index(self, vocab) -> int:
    "Index of the end-of-sequence token; falls back to the last vocab slot if EOS is absent."
    return vocab.stoi[EOS] if EOS in vocab.stoi else len(vocab.itos) - 1

  def sampling(self, text:str='', sep:str=''):
    "Sample one sequence seeded with `text`; stop at EOS or after `max_size` tokens."
    m = self.learn
    m.model.reset()  # clear the recurrent state before each new sample
    v = self.learn.data.train_ds.vocab
    eos_idx = self._eos_index(v)
    xb, yb = self.learn.data.one_item(text)
    new_idx = []
    for _ in range(self.max_size):
      # Next-token scores for the last position (assumed to be probabilities from pred_batch).
      res = m.pred_batch(batch=(xb, yb))[0][-1]
      if self.temperature != 1.:
        res.pow_(1 / self.temperature)  # torch.multinomial accepts unnormalized weights
      idx = torch.multinomial(res, 1).item()
      if idx == eos_idx:
        break
      new_idx.append(idx)
      # Only the new token is fed back; presumably the model's hidden state carries the prefix.
      xb = xb.new_tensor([idx])[None]
    return text + sep + sep.join(v.textify(new_idx, sep=None))

  def on_epoch_end(self, last_metrics, **kwargs):
    print('Sampling...')
    special_tokens = ('xxbos', 'xxeos', 'xxunk', 'xxpad')
    p = []
    for _ in range(self.num_samples):
      smi = self.sampling()
      for tok in special_tokens:
        smi = smi.replace(tok, '')
      p.append(smi)
    print('Sample of generated SMILES')
    print(p[:5])
    val = list(filter(is_valid, p))  # Validity
    print(val[0:5])
    uniq = uniqueness_score(val)  # Uniqueness
    # Novelty against the full reference collection (none given -> every unique sample is novel).
    novel = uniq if self.objective_mols is None else novelty_score(uniq, self.objective_mols)
    n = self.num_samples
    return add_metrics(last_metrics, [len(val)/n, len(uniq)/n, len(novel)/n])

In [None]:
def sampling(model,dt,text:str, n_words:int, temperature:float=1., sep:str=' '):
  """Sample up to `n_words` tokens from a language-model learner, seeded with `text`.

  model: the learner (uses `.model.reset()` and `.pred_batch`).
  dt: data object providing `.vocab` and `.one_item` (e.g. `data_fn`).
  text: prompt; '' starts from scratch.
  n_words: maximum number of tokens to sample.
  temperature: values < 1 sharpen and values > 1 flatten the next-token distribution.
  sep: separator placed after the prompt and between sampled tokens.
  Stops early when the end-of-sequence token is drawn. Returns the prompt followed
  by the decoded tokens.
  """
  model.model.reset()  # start each sample from a fresh recurrent state
  v = dt.vocab
  # Look the EOS index up instead of assuming it is the last vocabulary entry.
  eos_idx = v.stoi[EOS] if EOS in v.stoi else len(v.itos) - 1

  xb, yb = dt.one_item(text)
  new_idx = []
  for _ in range(n_words):
    res = model.pred_batch(batch=(xb, yb))[0][-1]
    if temperature != 1.:
      res.pow_(1 / temperature)  # torch.multinomial does not require normalized weights
    idx = torch.multinomial(res, 1).item()
    if idx == eos_idx:
      break
    new_idx.append(idx)
    xb = xb.new_tensor([idx])[None]
  return text + sep + sep.join(v.textify(new_idx, sep=None))

# Temperatures swept when validating the fine-tuned model.
sampling_temperatures = [0.2, 0.5, 0.6, 0.7, 0.75, 0.8, 1.0, 1.2]

def validation(model,dt,sampling_temperatures,iterations,samples,ref,maxsize=100):
  '''Sample molecules at several temperatures and score validity, novelty and uniqueness.

  model, dt: learner and data object passed through to `sampling`.
  sampling_temperatures: sequence of temperatures (one result column each).
  iterations: number of independent sampling rounds per temperature (one result row each).
  samples: number of SMILES generated per round; denominator of every percentage.
  ref: reference SMILES used for novelty (exact string comparison).
  maxsize: maximum number of tokens per sampled SMILES.

  Returns (validity, novelty, uniqueness, mols, unq_mols, novel_mols):
    - three arrays of shape (iterations, len(sampling_temperatures)) holding percentages;
    - the valid list, unique set and novel set from the LAST round at the LAST temperature
      only (empty when no round is run).
  '''
  n_temps = len(sampling_temperatures)
  _validity = np.zeros((iterations, n_temps))
  _novelty = np.zeros((iterations, n_temps))
  _uniqueness = np.zeros((iterations, n_temps))
  # Defaults so the return statement works even when no sampling round runs.
  mols, unq_mols, novel_mols = [], set(), set()
  special_tokens = (PAD, BOS, EOS, UNK)

  for j, temp in enumerate(sampling_temperatures):
    print('Temperatures = {}'.format(temp))
    for i in range(iterations):
      print('Starting iteration {}'.format(i))
      p = []
      for _ in range(samples):
        smi = sampling(model, dt, text='', n_words=maxsize, sep='', temperature=temp)
        for tok in special_tokens:
          smi = smi.replace(tok, '')
        p.append(smi)
      mols = list(filter(is_valid, p))          # valid
      unq_mols = uniqueness_score(mols)         # unique
      novel_mols = novelty_score(unq_mols, ref)  # novel

      _novelty[i, j] = len(novel_mols)/samples*100
      _uniqueness[i, j] = len(unq_mols)/samples*100
      _validity[i, j] = len(mols)/samples*100

      print('Iteration {} ended'.format(i))
    print('----------------------------------')
  return _validity, _novelty, _uniqueness, mols, unq_mols, novel_mols

### Data pre-processing

Define a custom tokenizer

In [None]:
#@title
class MolTokenizer(BaseTokenizer):
  ''' Atom-level SMILES tokenizer.

  Splits a SMILES string into single characters, except that
  - a bracketed special environment (any text inside square brackets, e.g. [nH], [C@@H])
    is kept as one token, and
  - the two-letter halogens Br and Cl are kept as one token each.
  The token list is wrapped as [BOS] + tokens + [EOS].
  An unclosed '[' is emitted as a single-character token.
  '''
  # Alternatives are tried left to right: bracket environment, Br, Cl, then any single char.
  # Matching Br/Cl directly (instead of substituting placeholder letters) keeps bracket
  # atoms such as [Li+] or [Xe] intact.
  _TOKEN_PAT = re.compile(r'\[.*?\]|Br|Cl|[\s\S]')

  def __init__(self, lang:str):
    pass

  def tokenizer(self, t:str) -> List[str]:
    assert isinstance(t, str)
    tokens = self._TOKEN_PAT.findall(t)
    return [BOS] + tokens + [EOS]

class Create_Vocab(object):
  '''Build an atom-level vocabulary from a list of SMILES strings.

  smiles: list of SMILES strings, tokenized with MolTokenizer.
  '''
  def __init__(self, smiles):
    self.smiles = smiles

  def tokenize(self):
    '''Tokenize `self.smiles` and return (unique_tokens, vocab).

    unique_tokens is [UNK, PAD] followed by every distinct token in the corpus,
    sorted; vocab is a fastai Vocab built over that same list.
    '''
    tok = Tokenizer(MolTokenizer, pre_rules=[], post_rules=[])
    tokenized = tok.process_all(self.smiles)
    seen = set()
    for mol_tokens in tokenized:
      seen.update(mol_tokens)
    unique_tokens = [UNK, PAD]
    unique_tokens.extend(sorted(seen))
    return unique_tokens, Vocab(itos=unique_tokens)

#### SMILES augmentation for language model

In [None]:
def randomize_smiles(smiles):
    """Return a random, non-canonical SMILES string for the same molecule.

    The atom order is shuffled with NumPy's global RNG (seeded by random_seed),
    then RDKit writes the molecule in that order, keeping stereochemistry
    (isomericSmiles=True) and aromatic notation (kekuleSmiles=False).

    Raises ValueError if RDKit cannot parse `smiles`.
    """
    m = Chem.MolFromSmiles(smiles)
    if m is None:
        # MolFromSmiles signals a parse failure by returning None; report the offending
        # input instead of failing later with an AttributeError.
        raise ValueError('RDKit could not parse SMILES: {!r}'.format(smiles))
    ans = list(range(m.GetNumAtoms()))
    np.random.shuffle(ans)
    nm = Chem.RenumberAtoms(m, ans)
    return Chem.MolToSmiles(nm, canonical=False, isomericSmiles=True, kekuleSmiles=False)

def lm_smiles_augmentation(df, N_rounds, randomize=None):
    """Augment a SMILES DataFrame with randomized SMILES for language modelling.

    df: DataFrame with a 'smiles' column; any other columns are copied unchanged
        onto each augmented row.
    N_rounds: number of randomized SMILES generated per input row.
    randomize: optional callable str -> str producing one randomized SMILES;
        defaults to randomize_smiles.

    Returns the augmented rows followed by the original rows, with duplicate
    'smiles' strings removed (the first occurrence is kept, so a randomized copy
    that equals the original string wins). The index is 0..n-1 before
    de-duplication, so it may have gaps afterwards.
    """
    if randomize is None:
        randomize = randomize_smiles
    records = []
    for row in df.to_dict('records'):
        for _ in range(N_rounds):
            new_row = dict(row)
            new_row['smiles'] = randomize(row['smiles'])
            records.append(new_row)
    # pd.concat replaces DataFrame.append, which was removed in pandas 2.0.
    frames = [pd.DataFrame(records, columns=df.columns)] if records else []
    df_aug = pd.concat(frames + [df], ignore_index=True)
    return df_aug.drop_duplicates('smiles')

The randomized SMILES are used for data augmentation. The number of randomized SMILES generated per molecule is passed as an argument to the `lm_smiles_augmentation` function.

In [None]:
random_seed(1234, True)

# Add up to 110 randomized SMILES per molecule (duplicate strings are dropped).
alc_smiles_aug = lm_smiles_augmentation(alc_smiles, 110)
print(len(alc_smiles_aug))

Create a text databunch for language modeling:

- It takes SMILES as input
- Pass the custom tokenizer defined in the previous step
- Specify the column containing text data
- Define the batch size according to the GPU memory available

In [None]:
random_seed(1234, True)

# Build the atom-level vocabulary from the augmented corpus: [UNK, PAD] + sorted distinct tokens.
vocab_list = Create_Vocab(list(alc_smiles_aug.smiles))
unique_tokens,vocab = vocab_list.tokenize()

In [None]:
random_seed(1234, True)

# Tokenize with the custom atom-level tokenizer and numericalize with the fixed `vocab`
# built above, so token ids match that vocabulary.
tokenizer = Tokenizer(MolTokenizer,pre_rules=[],post_rules=[],special_cases=[PAD,BOS,EOS,UNK])
# include_bos=False: MolTokenizer already wraps each SMILES in BOS/EOS (presumably to avoid a duplicate BOS).
processors = [TokenizeProcessor(tokenizer=tokenizer, mark_fields=False,include_bos=False), NumericalizeProcessor(vocab=vocab)]
# NOTE(review): the random 10% split happens after augmentation, so randomized SMILES of the
# same molecule can end up in both train and validation; validation accuracy is likely
# optimistic. Consider splitting molecules first and augmenting only the training part.
src = (TextList.from_df(alc_smiles_aug, path=GEN, cols='smiles', processor=processors).split_by_rand_pct(0.10).label_for_lm())

In [None]:
random_seed(1234, True)

# Batch size made explicit (64 is fastai's default); lower it if the GPU runs out of memory.
data_fn = src.databunch(bs=64)
data_fn.show_batch()

## Fine-tuning the target task language model 

Load the pre-trained weights and vocabulary

In [None]:
# Weights (.pth) and vocabulary (.pkl) of the pre-trained language model on Google Drive.
pretrained_model_path = Path('/content/gdrive/My Drive/results/MSPM/models')
pretrained_fnames = ['pre-trained_wt', 'pre-trained_vocab']
fnames = [pretrained_model_path/f'{fn}.{ext}' for fn,ext in zip(pretrained_fnames, ['pth', 'pkl'])]
# Fail early with a readable message instead of deep inside load_pretrained.
missing = [str(f) for f in fnames if not f.exists()]
assert not missing, f'pre-trained files not found: {missing}'

In [None]:
# Reference set for novelty: the distinct SMILES strings of the original (non-augmented) dataset.
# NOTE(review): novelty is an exact string comparison and generated SMILES are not
# canonicalized, so a re-written form of a reference molecule counts as "novel".
# Consider canonicalizing both sides (Chem.MolToSmiles) before comparing.
obj_ref = list(set(alc_smiles.smiles))
print(len(obj_ref))

Create a learner for language modeling:

- Initialize the learner with the pre-trained weights
- Pass the text databunch loaded in the previous step
- `drop_mult` (a multiplier applied to all dropout probabilities of the model) is a hyperparameter that can be tuned
- Accuracy and error rate are the metrics used for model evaluation

In [None]:
random_seed(1234, True)

# AWD-LSTM language model built without fastai's bundled weights (pretrained=False) and then
# initialised from the pre-trained weights/vocab in `fnames`; epoch metrics are logged to CSV (append mode).
learn_fn = language_model_learner(data_fn, AWD_LSTM, pretrained=False, drop_mult=0.4, metrics=[accuracy, error_rate], callback_fns=[partial(CSVLogger,append=True)]).load_pretrained(*fnames)
# Freeze all but the last layer group for the first training stage.
learn_fn.freeze()

Train the model using fit_one_cycle in three steps using gradual unfreezing:

- For the first step, the weights of the LSTM layers are kept frozen and the rest of the model is trained.
- In the second step, the weights of the last LSTM layer are unfrozen as well
- In the third step, all layers are unfrozen so that the LSTM layers can be fine-tuned
- Number of epochs and learning rate are the two hyperparameters that can be tuned here

In [None]:
random_seed(1234, True)

# Stage 1: train only the unfrozen head. SamplingCB adds Valid/Unique/Novel columns each epoch;
# SaveModelCallback saves the weights as 'bestmodel' whenever validation accuracy improves.
# NOTE(review): all three stages use name='bestmodel', so each stage overwrites the previous checkpoint.
learn_fn.fit_one_cycle(5, 1e-1, moms=(0.8,0.7), callbacks=[SamplingCB(learn_fn, num_samples=100, objective_mols=obj_ref),
                                   SaveModelCallback(learn_fn, every='improvement',monitor='accuracy', name='bestmodel')])

epoch,train_loss,valid_loss,accuracy,error_rate,Valid,Unique,Novel,time
0,1.189117,1.041965,0.710491,0.289509,0.0,0.0,0.0,00:19
1,0.885585,0.689171,0.782738,0.217262,0.77,0.77,0.77,00:08
2,0.724809,0.619877,0.804613,0.195387,0.89,0.89,0.87,00:07
3,0.640362,0.607313,0.807069,0.192932,0.93,0.93,0.93,00:07
4,0.591198,0.594575,0.810565,0.189435,0.94,0.94,0.94,00:07


Sampling...
Sample of generated SMILES
[')CC2(CCCC2)CC1=O)C(=O)OC(C)(C)CCCCOC1OCCO1)[C@@H](O)[C@H](O)CO[C@H]1O[C@H](CO)[C@@H](O)[C@H](O)[C@@H]1OC(=O)CCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCC', ')=C(C(OC)=O)C(C)=N1CN1CCCCC1C2=O)C(=O)OCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCCC', ')ccccc1)=O)OCCN(CC)CCOC(=O)C(C)Oc1ccccc1[N+]([O-])=O)c1c([N+](=O)[O-])cc(OC)cc1[N+]([O-])=Oc1c(OC)ccc(C(=O)OC)c1OC(=O)CCCCCCCCCCCCCCCCCCCC', ')cc(C)ccc1OC(C)OC(OC(C)(C)C)=O)C[C@H]1[C@H](O)C(C)(C)Oc2ccc(C(=O)O)cc212CCC2=C(C)CCC/C(C)=C/COC(=O)CC(C)CCCC(C)CCCC(C)CCCC(C)CCC', ')nc(C)n(CCO)c1ncn2C(C)=O)(C)OC(C1CCCCC1)=O)(C)COC(=O)CCC(O)=O[C@@H]1[C@@H]2[C@@H](C(COC(C)=O)=CO[C@H]2O)[C@H]2[C@@H]1[C@]1(C)CCC[C@@](COC(C)=O)(C)[C@H]1CC2=ONC']
[]
Better model found at epoch 0 with accuracy value: 0.7104911208152771.
Sampling...
Sample of generated SMILES
['=[N+](CCCCC)[O-])[C@H]1[C@H](O)CO[C@H]1CO', '=C1C(OC)=CC(O)(C=C)C=C1', 'C(CCCCC)(O)CC(OC)=O', 'c1c(OC)ccc(/C=C2\\C(OC)=C(OC)C(=O)

In [None]:
random_seed(1234, True)

# Stage 2: also unfreeze the second-to-last layer group (the last LSTM layer) and lower the LR.
learn_fn.freeze_to(-2)

learn_fn.fit_one_cycle(6, 1e-2, moms=(0.8,0.7), callbacks=[SamplingCB(learn_fn, num_samples=100, objective_mols=obj_ref),
                                  SaveModelCallback(learn_fn, every='improvement', monitor='accuracy', name='bestmodel')])

epoch,train_loss,valid_loss,accuracy,error_rate,Valid,Unique,Novel,time
0,0.457638,0.516757,0.838095,0.161905,0.97,0.97,0.86,00:07
1,0.396397,0.453073,0.855208,0.144792,1.0,0.98,0.67,00:07
2,0.355367,0.435677,0.855208,0.144792,0.99,0.96,0.6,00:08
3,0.326964,0.425345,0.859747,0.140253,1.0,0.98,0.47,00:07
4,0.308238,0.421426,0.858929,0.141071,1.0,0.96,0.51,00:07
5,0.289998,0.421951,0.859524,0.140476,1.0,0.98,0.46,00:07


Sampling...
Sample of generated SMILES
['OC(C1CCC(O)CC1)(c1ccccc1)c1ccccc1', 'CCN(CC(=O)OC)C(=O)OC1(CO)Cc1ccccc1', 'C1[C@@H](O)[C@H]2[C@H](O)C[C@H]3[C@@](C)([C@H]2CC1)CC[C@H]1[C@]3(C)CC[C@H](O)C1', 'OC(C)CCC(O)CCCCCCC=C', 'CCOC(=O)c1c(CO)n(C)c2ccc(OCCCC(O)(Cn3cncn3)c3ccc(F)cc3F)cc12']
['OC(C1CCC(O)CC1)(c1ccccc1)c1ccccc1', 'C1[C@@H](O)[C@H]2[C@H](O)C[C@H]3[C@@](C)([C@H]2CC1)CC[C@H]1[C@]3(C)CC[C@H](O)C1', 'OC(C)CCC(O)CCCCCCC=C', 'CCOC(=O)c1c(CO)n(C)c2ccc(OCCCC(O)(Cn3cncn3)c3ccc(F)cc3F)cc12', 'C1C2CCC(CC1O)N2C(=O)OC(C)(C)C']
Better model found at epoch 0 with accuracy value: 0.8380951881408691.
Sampling...
Sample of generated SMILES
['C1CC(O)CC2=CC[C@H]3[C@H]4CC=C(c5cccnc5)[C@@]4(C)CC[C@@H]3[C@@]12C', 'C(CCC)N(CC(c1cc(Cl)cc2c1-c1c(cc(Cl)cc1)/C2=C/c1ccc(Cl)cc1)O)CCCC', 'C(CCCC)N(CCCCC)CCCO', 'c1ccccc1C(n1c2c(nc1)c(N(Cc1ccccc1)Cc1ccccc1)ncn2)c1ccccc1', 'C(Cc1ccccc1)C(C)O']
['C1CC(O)CC2=CC[C@H]3[C@H]4CC=C(c5cccnc5)[C@@]4(C)CC[C@@H]3[C@@]12C', 'C(CCC)N(CC(c1cc(Cl)cc2c1-c1c(cc(Cl)cc1)/C2=C/c1c

In [None]:
random_seed(1234, True)

# Stage 3: unfreeze every layer and fine-tune the whole model with a smaller learning rate.
learn_fn.unfreeze()

learn_fn.fit_one_cycle(6, 1e-3, moms=(0.8,0.7), callbacks=[SamplingCB(learn_fn, num_samples=100, objective_mols=obj_ref),
                                   SaveModelCallback(learn_fn, every='improvement', 
                                                     monitor='accuracy', name='bestmodel')])

epoch,train_loss,valid_loss,accuracy,error_rate,Valid,Unique,Novel,time
0,0.263676,0.42961,0.860193,0.139807,1.0,0.98,0.45,00:07
1,0.263043,0.413682,0.862351,0.137649,0.99,0.98,0.46,00:07
2,0.259629,0.430949,0.859152,0.140848,1.0,0.98,0.39,00:08
3,0.257053,0.434126,0.857292,0.142708,1.0,0.97,0.4,00:08
4,0.257199,0.42792,0.857366,0.142634,0.99,0.96,0.36,00:07
5,0.252817,0.433349,0.859077,0.140923,0.99,0.94,0.33,00:08


Sampling...
Sample of generated SMILES
['C(=C1\\c2cc(Cl)cc(C(O)CN(CCCC)CCCC)c2-c2ccc(Cl)cc21)\\c1ccc(Cl)cc1', 'c1ccc(OCC(O)C)cc1', 'C1C[C@@]2(C)OO[C@]34[C@H](O[C@H](O)[C@@H](C)C3CC[C@@H](C)[C@@H]4C1)O2', 'c1(-c2oc(CCCO)nc2-c2ccccc2)ccccc1', 'C(OC)(=O)[C@@H]1N(C(=O)OC(C)(C)C)C[C@@H](O)C1']
['C(=C1\\c2cc(Cl)cc(C(O)CN(CCCC)CCCC)c2-c2ccc(Cl)cc21)\\c1ccc(Cl)cc1', 'c1ccc(OCC(O)C)cc1', 'C1C[C@@]2(C)OO[C@]34[C@H](O[C@H](O)[C@@H](C)C3CC[C@@H](C)[C@@H]4C1)O2', 'c1(-c2oc(CCCO)nc2-c2ccccc2)ccccc1', 'C(OC)(=O)[C@@H]1N(C(=O)OC(C)(C)C)C[C@@H](O)C1']
Better model found at epoch 0 with accuracy value: 0.8601934313774109.
Sampling...
Sample of generated SMILES
['C(C)CN(S(c1ccc(CO)cc1)(=O)=O)CCC', 'c1ccc(CCCCO)cc1', 'c1ccccc1C(c1ccccc1)(N[C@@H](CO)C(=O)OC)c1ccccc1', 'CCCCN(CCCC)CC(O)c1c2c(cc(Cl)c1)/C(=C\\c1ccc(Cl)cc1)c1c-2ccc(Cl)c1', 'C1(C)(C)O[C@@H]([C@@H]2O[C@H]3OC(C)(C)O[C@@H]3[C@@H]2O)CO1']
['C(C)CN(S(c1ccc(CO)cc1)(=O)=O)CCC', 'c1ccc(CCCCO)cc1', 'c1ccccc1C(c1ccccc1)(N[C@@H](CO)C(=O)OC)c1ccccc1', 'CCC

Save the fine-tuned encoder

In [None]:
# Save only the encoder (presumably the AWD-LSTM body: embeddings + LSTM layers).
# NOTE(review): the decoder head needed to generate molecules is not saved here; the full LM
# weights exist only in the 'bestmodel' checkpoint written by SaveModelCallback.
learn_fn.save_encoder('finetuned_encoder_alc')

#### Validate the fine-tuned model in terms of validity, uniqueness, and novelty

In [None]:
# Validation loss followed by the learner's metrics (accuracy, error rate) for the weights currently loaded.
learn_fn.validate()

[0.41006887, tensor(0.8413), tensor(0.1587)]

In [None]:
random_seed(1234, True)

# One round of 500 samples per temperature; each array holds one percentage per temperature.
# NOTE: mols / unq_mols / novel_mols come from the LAST temperature in sampling_temperatures only.
validity, novelty, uniqueness, mols, unq_mols, novel_mols = validation(learn_fn, data_fn, sampling_temperatures, 1, 500, ref=obj_ref)

In [None]:
# One column per sampling temperature, one row per validation round.
temp_columns = ['Temp_{}'.format(t) for t in sampling_temperatures]
val_df, nov_df, unq_df = (pd.DataFrame(arr, columns=temp_columns)
                          for arr in (validity, novelty, uniqueness))

In [None]:
# Save the novel molecules with the other results on Google Drive; the relative path used
# before wrote to the kernel's working directory, which on Colab is presumably not persistent.
# Sorted so the file order does not depend on set iteration order.
# NOTE: novel_mols holds the novel SMILES from the last sampled temperature only.
pd.Series(sorted(novel_mols), name='smiles').to_csv(GEN / 'nov_alc.csv', index=False)