**Check whether a GPU is available**

In [1]:
# List the NVIDIA driver / CUDA versions and the GPUs visible on this node.
!nvidia-smi

Thu May 27 17:23:09 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 450.102.04   Driver Version: 450.102.04   CUDA Version: 11.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  A100-SXM4-40GB      On   | 00000000:07:00.0 Off |                    0 |
| N/A   35C    P0    52W / 400W |      0MiB / 40537MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   1  A100-SXM4-40GB      On   | 00000000:0F:00.0 Off |                    0 |
| N/A   35C    P0    54W / 400W |      0MiB / 40537MiB |      0%      Default |
|       

**Load the necessary libraries, including Hugging Face Transformers**

In [2]:
import torch

from transformers import AutoModelForSequenceClassification, BertModel, RobertaModel, BertTokenizer, RobertaTokenizer
from transformers import PreTrainedModel, BertConfig, RobertaConfig
from transformers import Trainer, TrainingArguments
from transformers.data.data_collator import default_data_collator
from transformers.tokenization_utils_base import BatchEncoding
from transformers import EvalPrediction

from transformers import AutoModelForMaskedLM
from transformers import AdamW
from scipy.special import softmax
from datasets import load_dataset
from torch.utils.data import Dataset

from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
from sklearn.metrics import mean_squared_error, mean_absolute_error

import re
import gc
import os
import pandas as pd
import numpy as np
import requests
from tqdm.auto import tqdm

import torch.multiprocessing as mp
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DD

In [3]:
# Display the installed PyTorch version so the run can be reproduced later.
torch.__version__

'1.8.1'

**Select Model**

In [4]:
# Checkpoint name of the protein-sequence encoder; it is passed to
# `from_pretrained` for the tokenizer, the config and the model below.
#model_name = "Rostlab/prot_t5_xl_uniref50" # for embedding only
seq_model_name = "Rostlab/prot_bert_bfd" # for fine-tuning

**Load the vocabulary**

In [5]:
# Only BERT-style checkpoints are supported. Lower-casing stays off so the
# upper-case residue letters of the protein sequences are preserved.
if "bert" in seq_model_name:
    seq_tokenizer = BertTokenizer.from_pretrained(seq_model_name, do_lower_case=False)
else:
    # Fail here instead of leaving `seq_tokenizer` undefined, which would only
    # surface as a NameError in a later cell.
    raise ValueError(f"Unknown model name: {seq_model_name!r}")

**Ligand SMILES model (BERT)**

In [6]:
import json

In [7]:
# Local directory of the pre-trained SMILES encoder (loaded below with
# `BertConfig.from_pretrained` / `BertModel.from_pretrained`).
# NOTE(review): absolute, user-specific path; adjust it before running elsewhere.
#model_directory = '/home/xvg/maskedevolution/examples/molecules/pretrained_model'
model_directory = '/home/xvg/maskedevolution/models/bert_large_1B/model'

In [8]:
# Local directory of the SMILES tokenizer files; `config.json` is read from it below.
# NOTE(review): absolute, user-specific path; adjust it before running elsewhere.
#tokenizer_directory =  '/home/xvg/maskedevolution/examples/molecules/pretrained_tokenizer'
tokenizer_directory =  '/home/xvg/maskedevolution/models/bert_large_1B/tokenizer'

In [9]:
# Keyword arguments for the SMILES tokenizer, stored next to its vocabulary.
# The context manager closes the file handle that a bare open() would leak.
with open(tokenizer_directory+'/config.json','r') as config_file:
    tokenizer_config = json.load(config_file)

In [10]:
# Build the SMILES tokenizer; every entry of `tokenizer_config` is forwarded
# to `from_pretrained` as a keyword argument.
#smiles_tokenizer =  RobertaTokenizer.from_pretrained(tokenizer_directory, **tokenizer_config, do_lower_case=False)
smiles_tokenizer =  BertTokenizer.from_pretrained(tokenizer_directory, **tokenizer_config)
# Display the lower-casing flag to check that the tokenizer preserves case.
smiles_tokenizer.do_lower_case

False

In [11]:
# maximum sequence length
# Each input gets a fixed token budget (200 for SMILES, 2048 for proteins),
# capped by the `max_position_embeddings` value of the matching encoder config.
# `max_seq_length` is also the offset at which the model's forward pass splits
# the concatenated protein + SMILES input.
#max_smiles_length = min(120,RobertaConfig.from_pretrained(model_directory).max_position_embeddings)
max_smiles_length = min(200,BertConfig.from_pretrained(model_directory).max_position_embeddings)
max_seq_length = min(2048,BertConfig.from_pretrained(seq_model_name).max_position_embeddings)

Define the ensemble model

In [12]:
class MLP(torch.nn.Module):
    '''
    Feed-forward regression head.

    Maps a vector of size ``ninput`` to a single value through two hidden
    layers of width 32, each followed by a ReLU.
    '''
    def __init__(self, ninput):
        super().__init__()
        width = 32
        # The stack is kept under the attribute name `layers` with the same
        # ordering, so parameter names (layers.0, layers.2, layers.4) are unchanged.
        stack = [
            torch.nn.Linear(ninput, width),
            torch.nn.ReLU(),
            torch.nn.Linear(width, width),
            torch.nn.ReLU(),
            torch.nn.Linear(width, 1),
        ]
        self.layers = torch.nn.Sequential(*stack)

    def forward(self, x):
        '''Apply the layer stack to ``x`` and return the result.'''
        prediction = self.layers(x)
        return prediction

In [13]:
class EnsembleSequenceRegressor(torch.nn.Module):
    """Regressor that combines a protein-sequence encoder with a SMILES encoder.

    Both encoders are loaded with ``BertModel.from_pretrained``. Their pooled
    outputs are concatenated and fed to an ``MLP`` head that emits one scalar
    per example.

    Args:
        seq_model_name: name or path of the protein-sequence checkpoint.
        smiles_model_name: name or path of the SMILES checkpoint.
        *args, **kwargs: accepted but not used.
    """
    def __init__(self, seq_model_name, smiles_model_name, *args, **kwargs):
        super().__init__()

        self.seq_model = BertModel.from_pretrained(seq_model_name)
        self.smiles_model = BertModel.from_pretrained(smiles_model_name)

        seq_config = self.seq_model.config
        smiles_config = self.smiles_model.config

        # The head consumes the two pooled vectors side by side.
        self.cls = MLP(seq_config.hidden_size+smiles_config.hidden_size)

    def forward(
            self,
            input_ids=None,
            attention_mask=None,
            token_type_ids=None,
            position_ids=None,
            head_mask=None,
            inputs_embeds=None,
            labels=None,
    ):
        """Predict one value per example from a concatenated protein + SMILES input.

        The first ``max_seq_length`` columns (a module-level global) of
        ``input_ids``, ``attention_mask`` and ``token_type_ids`` go to the
        protein encoder and the remaining columns go to the SMILES encoder.
        ``position_ids``, ``head_mask`` and ``inputs_embeds`` are accepted but
        ignored.

        Returns:
            ``(loss, logits)`` when ``labels`` is given, where ``loss`` is the
            mean squared error; otherwise the one-element tuple ``(logits,)``.
        """
        def split_at_boundary(tensor):
            # An optional input that was not supplied stays None for both encoders.
            if tensor is None:
                return None, None
            return tensor[:, :max_seq_length], tensor[:, max_seq_length:]

        seq_ids, smiles_ids = split_at_boundary(input_ids)
        seq_mask, smiles_mask = split_at_boundary(attention_mask)
        seq_types, smiles_types = split_at_boundary(token_type_ids)

        seq_outputs = self.seq_model(input_ids=seq_ids,
                                     attention_mask=seq_mask,
                                     token_type_ids=seq_types,
                                     return_dict=False)
        smiles_outputs = self.smiles_model(input_ids=smiles_ids,
                                           attention_mask=smiles_mask,
                                           token_type_ids=smiles_types,
                                           return_dict=False)

        # With return_dict=False each encoder returns a tuple; element 1 is the
        # pooled output, which is what the regression head is fed.
        pooled = torch.cat([seq_outputs[1], smiles_outputs[1]], dim=1)

        logits = self.cls(pooled).squeeze(-1)

        if labels is not None:
            # Regression objective: mean squared error against the labels.
            loss_fct = torch.nn.MSELoss()
            loss = loss_fct(logits.view(-1, 1), labels.view(-1,1))
            return (loss, logits)

        # Without labels, return the predictions in a tuple, mirroring the
        # (loss, logits) layout above minus the loss. The original returned an
        # undefined name here and raised NameError.
        return (logits,)

Create the dataset

In [14]:
def expand_seqs(seqs):
    """Turn protein sequences into lists of single-letter residues.

    For every string in ``seqs`` all whitespace is removed, each of the
    letters U, Z, O and B is replaced by ``X``, and the result is split into
    one-character strings.

    Returns:
        A list with one list of characters per input sequence.
    """
    rare_codes = re.compile(r"[UZOB]")
    residues_per_seq = []
    for raw in seqs:
        compact = "".join(raw.split())
        residues_per_seq.append(list(rare_codes.sub("X", compact)))
    return residues_per_seq

def expand_smiles(seqs):
    """Split each SMILES string into a list of single characters.

    Whitespace is removed first. No characters are substituted: the
    ``[UZOB] -> X`` replacement used for protein sequences in ``expand_seqs``
    would overwrite SMILES atom symbols such as oxygen (``O``) and boron
    (``B``, also the first letter of ``Br``).

    Returns:
        A list with one list of characters per input string.
    """
    return [list("".join(seq.split())) for seq in seqs]

# on-the-fly tokenization
def encode(item):
    """Tokenize a batch of (protein sequence, SMILES) pairs on the fly.

    ``item`` is a batch in column layout: ``item['seq']`` and
    ``item['smiles']`` are lists with one entry per example. Every example is
    encoded; the original handled only the first one, which gave columns of
    mismatched length whenever more than one row was requested.

    For each example the protein encoding, padded or truncated to
    ``max_seq_length``, is concatenated with the SMILES encoding, padded or
    truncated to ``max_smiles_length``.

    Returns:
        The same ``item`` with ``input_ids``, ``token_type_ids`` and
        ``attention_mask`` set to lists of 1-D tensors, one per example.
    """
    input_ids = []
    token_type_ids = []
    attention_mask = []

    for residues, smiles in zip(expand_seqs(item['seq']), item['smiles']):
        seq_encodings = seq_tokenizer(residues,
                                      is_split_into_words=True,
                                      return_offsets_mapping=False,
                                      truncation=True,
                                      padding='max_length',
                                      add_special_tokens=True,
                                      max_length=max_seq_length)

        smiles_encodings = smiles_tokenizer(smiles,
                                            padding='max_length',
                                            max_length=max_smiles_length,
                                            add_special_tokens=True,
                                            truncation=True)

        # Protein tokens first, SMILES tokens second, for every field.
        input_ids.append(torch.cat([torch.tensor(seq_encodings['input_ids']),
                                    torch.tensor(smiles_encodings['input_ids'])]))
        token_type_ids.append(torch.cat([torch.tensor(seq_encodings['token_type_ids']),
                                         torch.tensor(smiles_encodings['token_type_ids'])]))
        attention_mask.append(torch.cat([torch.tensor(seq_encodings['attention_mask']),
                                         torch.tensor(smiles_encodings['attention_mask'])]))

    item['input_ids'] = input_ids
    item['token_type_ids'] = token_type_ids
    item['attention_mask'] = attention_mask
    return item
        
class AffinityDataset(Dataset):
    """Map-style dataset that turns encoded rows into training examples.

    Each row of the wrapped ``dataset`` must provide the keys ``smiles``,
    ``seq``, ``neg_log10_affinity_M`` and ``affinity``. ``affinity`` becomes
    the float ``labels`` entry and the four raw keys are removed; all other
    keys are passed through unchanged.
    """
    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        example = self.dataset[idx]

        # Regression target comes from the `affinity` column.
        example['labels'] = float(example['affinity'])

        # Drop the non-encoded input. The row returned by the wrapped dataset
        # is modified in place.
        for raw_column in ('smiles', 'seq', 'neg_log10_affinity_M', 'affinity'):
            example.pop(raw_column)
        return example

In [15]:
# Split on the full data set, with 90% of the examples used for training.
# NOTE(review): no seed is passed to train_test_split, and the next cell
# reassigns `data_all`, `split`, `train` and `validation`.
data_all = load_dataset("jglaser/binding_affinity",split='train')
f = 0.9  # fraction of examples used for training
split = data_all.train_test_split(train_size=f)
train = split['train']
validation = split['test']
# Register `encode` as the on-the-fly tokenization transform for both splits.
train.set_transform(encode)
validation.set_transform(encode)

Using custom data configuration default
Reusing dataset binding_affinity (/home/xvg/.cache/huggingface/datasets/binding_affinity/default/1.0.0/c5436e2040ce420c1c5fbc8df1a6b522e2ebf93d678786f4e761663a5ccaf89c)
Loading cached split indices for dataset at /home/xvg/.cache/huggingface/datasets/binding_affinity/default/1.0.0/c5436e2040ce420c1c5fbc8df1a6b522e2ebf93d678786f4e761663a5ccaf89c/cache-7667d8a5a3fe570d.arrow and /home/xvg/.cache/huggingface/datasets/binding_affinity/default/1.0.0/c5436e2040ce420c1c5fbc8df1a6b522e2ebf93d678786f4e761663a5ccaf89c/cache-e6151d8ec58389ee.arrow


In [16]:
# Small subset: 250 training and 10 validation examples.
# NOTE: this reassigns `data_all`, `split`, `train` and `validation`, so it
# replaces the 90/10 split created in the previous cell.
data_all = load_dataset("jglaser/binding_affinity",split='train')
split = data_all.train_test_split(train_size=250,test_size=10)
train = split['train']
validation = split['test']
# Register `encode` as the on-the-fly tokenization transform for both splits.
train.set_transform(encode)
validation.set_transform(encode)

Using custom data configuration default
Reusing dataset binding_affinity (/home/xvg/.cache/huggingface/datasets/binding_affinity/default/1.0.0/c5436e2040ce420c1c5fbc8df1a6b522e2ebf93d678786f4e761663a5ccaf89c)


In [17]:
# from scipy.stats import boxcox
# import matplotlib.pyplot as plt
# fig, ax = plt.subplots()
# data = np.array(split['train'][:]['neg_log10_affinity_M'])
# data_transform, lmbda = boxcox(data)
# ax.hist([data,data_transform])

In [18]:
# Wrap both splits so every item carries a float `labels` entry and the raw
# text / affinity columns are removed.
train_dataset = AffinityDataset(train)
val_dataset = AffinityDataset(validation)

Define the evaluation metrics

In [19]:
def compute_metrics(p: EvalPrediction):
    """Compute regression metrics for one evaluation pass.

    Args:
        p: evaluation result whose ``predictions`` and ``label_ids`` attributes
            hold the model outputs and the reference values.

    Returns:
        A dict with the mean squared error under ``"mse"`` and the mean
        absolute error under ``"mae"``.
    """
    predicted = p.predictions
    target = p.label_ids

    metrics = {}
    metrics["mse"] = mean_squared_error(target, predicted)
    metrics["mae"] = mean_absolute_error(target, predicted)
    return metrics


Define the training args and start the trainer

In [20]:
def model_init():
    """Build a new regressor from the configured protein and SMILES checkpoints.

    Takes no arguments; reads the module-level ``seq_model_name`` and
    ``model_directory``.
    """
    regressor = EnsembleSequenceRegressor(seq_model_name, model_directory)
    return regressor

In [21]:
# Training configuration. With a per-device batch size of 1 and 32 accumulation
# steps, each optimizer update sees 32 examples per device.
training_args = TrainingArguments(
    output_dir='./results',          # output directory
    num_train_epochs=50,              # total number of training epochs
    per_device_train_batch_size=1,   # batch size per device during training
    per_device_eval_batch_size=1,    # batch size for evaluation
#    warmup_steps=200,                # number of warmup steps for learning rate scheduler
    warmup_steps=0,                # number of warmup steps for learning rate scheduler
    learning_rate=3e-05,#3e-03,             # learning rate
    weight_decay=0.0,                # strength of weight decay
    logging_dir='./logs',            # directory for storing logs
    logging_steps=1,                # log after every update step
    do_train=True,                   # Perform training
    do_eval=True,                    # Perform evaluation
    evaluation_strategy="epoch",     # evaluate after each epoch
    gradient_accumulation_steps=32,#128, # batches accumulated before each optimizer update
    fp16=True,                       # Use mixed precision
    fp16_opt_level="O2",             # Apex AMP level: letter "O" + digit (was the invalid "02")
    run_name="seq_smiles_affinity",     # experiment name
    seed=3,                          # Seed for experiment reproducibility
    load_best_model_at_end=True,     # reload the best checkpoint when training ends
    metric_for_best_model="eval_mse", # "best" means lowest validation MSE ...
    greater_is_better=False,         # ... so smaller values win
)

trainer = Trainer(
    model_init=model_init,                # factory that returns a fresh model to train
    args=training_args,                   # training arguments, defined above
    train_dataset=train_dataset,          # training dataset
    eval_dataset=val_dataset,             # evaluation dataset
    compute_metrics = compute_metrics,    # evaluation metric
)

trainer.train()

Some weights of the model checkpoint at Rostlab/prot_bert_bfd were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at /home/xvg/maskedevolution/models/bert_large_1B/model were not used when initial

Epoch,Training Loss,Validation Loss,Mse,Mae
1,0.9691,1.524236,1.524236,0.944886
2,0.9319,1.447094,1.447095,0.906003
3,0.9033,1.402751,1.402751,0.878541
4,0.9042,1.399923,1.399923,0.87816
5,0.8328,1.382009,1.382009,0.860786
6,0.8176,1.349485,1.349485,0.835797
7,0.7561,1.352248,1.352248,0.825396
8,0.7147,1.391916,1.391916,0.825916
9,0.6854,1.429179,1.429179,0.823149
10,0.6318,1.436488,1.436488,0.817318




RuntimeError: [enforce fail at inline_container.cc:274] . unexpected pos 25321984 vs 25321872

In [22]:
# Run inference on the first 100 items of the training set.
# NOTE(review): these are training examples, so this is a sanity check of the
# fit rather than a held-out evaluation.
trainer.predict([train_dataset[i] for i in range(100)])



PredictionOutput(predictions=array([-0.60650367,  1.2087307 ,  1.2201444 , -0.6108015 , -0.5705229 ,
        1.2634574 ,  1.0566033 ,  0.72792447, -0.42828178, -0.56595844,
       -0.6310915 , -0.43187314, -0.5463716 ,  1.2415129 ,  1.116701  ,
       -0.55604273,  0.7806607 , -0.48442325, -0.18449609, -0.20868082,
       -0.62431455, -0.56523424, -0.5763343 ,  1.2617589 ,  0.59965086,
        1.2774117 , -0.5354856 , -0.5674191 ,  0.6502618 , -0.61222655,
        1.2677677 ,  1.2661772 , -0.58664244, -0.46732855,  1.2667332 ,
       -0.4920587 , -0.62003183,  0.39638826, -0.461332  ,  1.2721995 ,
        1.2018718 , -0.20630965,  0.38191348,  1.2586787 , -0.58568317,
        0.6763148 , -0.1173532 ,  0.5488733 , -0.6299681 ,  0.5065101 ,
       -0.58408207, -0.59855276,  1.2667154 ,  1.1689699 ,  0.4735074 ,
       -0.6322884 , -0.6086712 ,  1.172646  , -0.5534608 , -0.61082566,
       -0.55447704,  1.2280465 , -0.6101769 ,  1.1278055 , -0.6279281 ,
       -0.3384112 , -0.53190655,  1