In [1]:
import torch
import pandas as pd
import torch.nn as nn
import copy
import torch.optim as optim
import random
import numpy as np
import pandas as pd
from torch.utils.data import DataLoader, Dataset
from torch.cuda.amp import autocast, GradScaler
from tqdm import tqdm
import os
from datasets import load_dataset, load_metric
from transformers import BertModel, BertTokenizer, AdamW, get_linear_schedule_with_warmup

os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [2]:
BENCHMARKS_DIR = 'data'
BENCHMARK_NAME = 'virus_us'
train_set_file_path = os.path.join(BENCHMARKS_DIR, '%s.train.csv' % BENCHMARK_NAME)
train_set = pd.read_csv(train_set_file_path).dropna().drop_duplicates()

valid_set_file_path = os.path.join(BENCHMARKS_DIR, '%s.valid.csv' % BENCHMARK_NAME)
valid_set = pd.read_csv(valid_set_file_path).dropna().drop_duplicates()

test_set_file_path = os.path.join(BENCHMARKS_DIR, '%s.test.csv' % BENCHMARK_NAME)
test_set = pd.read_csv(test_set_file_path).dropna().drop_duplicates()

print(f'{len(train_set)} training set records, {len(valid_set)} validation set records, {len(test_set)} test set records.')

4286 training set records, 535 validation set records, 535 test set records.


In [3]:
class CustomDataset(Dataset):

    def __init__(self, data, maxlen, with_labels=True, bert_model='Rostlab/prot_bert'):

        self.data = data  # pandas dataframe
        #Initialize the tokenizer
        self.tokenizer = BertTokenizer.from_pretrained(bert_model)  

        self.maxlen = maxlen
        self.with_labels = with_labels 

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        
        # Selecting sentence1 and sentence2 at the specified index in the data frame
        seq = str(self.data.iloc[index,0])

        # Tokenize the pair of sentences to get token ids, attention masks and token type ids
        encoded_input = self.tokenizer(seq, 
                                      padding='max_length',  # Pad to max_length
                                      truncation=True,  # Truncate to max_length
                                      max_length=self.maxlen,  
                                      return_tensors='pt')  # Return torch.Tensor objects
        
        token_ids = encoded_input['input_ids'].squeeze(0)  # tensor of token ids
        attn_masks = encoded_input['attention_mask'].squeeze(0)  # binary tensor with "0" for padded values and "1" for the other values
        token_type_ids = encoded_input['token_type_ids'].squeeze(0)  # binary tensor with "0" for the 1st sentence tokens & "1" for the 2nd sentence tokens

        if self.with_labels:  # True if the dataset has labels
            label = self.data.iloc[index,1]
            return token_ids, attn_masks, token_type_ids, label  
        else:
            return token_ids, attn_masks, token_type_ids

In [4]:
class ProteinRegressor(nn.Module):

    def __init__(self, bert_model="Rostlab/prot_bert", freeze_bert=False):
        super(ProteinRegressor, self).__init__()
        #  Instantiating BERT-based model object
        self.bert_layer = BertModel.from_pretrained(bert_model)

        #  Fix the hidden-state size of the encoder outputs (If you want to add other pre-trained models here, search for the encoder output size)
        if bert_model == "Rostlab/prot_bert":  # 12M parameters
            hidden_size = 1024

        # Freeze bert layers and only train the classification layer weights
        if freeze_bert:
            for p in self.bert_layer.parameters():
                p.requires_grad = False

        # Classification layer
        self.cls_layer = nn.Linear(hidden_size, 1)
        self.act = nn.ReLU()
        self.dropout = nn.Dropout(p=0.1)

    @autocast()  # run in mixed precision
    def forward(self, input_ids, attn_masks, token_type_ids):
        '''
        Inputs:
            -input_ids : Tensor  containing token ids
            -attn_masks : Tensor containing attention masks to be used to focus on non-padded values
            -token_type_ids : Tensor containing token type ids to be used to identify sentence1 and sentence2
        '''

        # Feeding the inputs to the BERT-based model to obtain contextualized representations
        cont_reps, pooler_output = self.bert_layer(input_ids, attn_masks, token_type_ids, return_dict=False)
        pooler_output = self.act(pooler_output)
        # Feeding to the classifier layer the last layer hidden-state of the [CLS] token further processed by a
        # Linear Layer and a Tanh activation. The Linear layer weights were trained from the sentence order prediction (ALBERT) or next sentence prediction (BERT)
        # objective during pre-training.
        logits = self.cls_layer(self.dropout(pooler_output))

        return logits

In [5]:
def set_seed(seed):
    """ Set all seeds to make results reproducible """
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    

def evaluate_loss(net, device, criterion, dataloader):
    net.eval()

    mean_loss = 0
    count = 0

    with torch.no_grad():
        for it, (seq, attn_masks, token_type_ids, labels) in enumerate(tqdm(dataloader)):
            seq, attn_masks, token_type_ids, labels = \
                seq.to(device), attn_masks.to(device), token_type_ids.to(device), labels.to(device)
            logits = net(seq, attn_masks, token_type_ids)
            mean_loss += criterion(logits, labels).item()
            count += 1

    return mean_loss / count

In [6]:
def train_bert(net, criterion, opti, lr, lr_scheduler, train_loader, val_loader, epochs, iters_to_accumulate):

    best_loss = np.Inf
    best_ep = 1
    nb_iterations = len(train_loader)
    print_every = nb_iterations // 5  # print the training loss 5 times per epoch
    iters = []
    train_losses = []
    val_losses = []

    scaler = GradScaler()

    for ep in range(epochs):

        net.train()
        running_loss = 0.0
        for it, (seq, attn_masks, token_type_ids, labels) in enumerate(tqdm(train_loader)):

            # Converting to cuda tensors
            seq, attn_masks, token_type_ids, labels = \
                seq.to(device), attn_masks.to(device), token_type_ids.to(device), labels.to(device)
    
            # Enables autocasting for the forward pass (model + loss)
            with autocast():
                # Obtaining the logits from the model
                logits = net(seq, attn_masks, token_type_ids)
                
                # Computing loss
                loss = criterion(logits, labels.to(torch.float32))
                
                loss = loss / iters_to_accumulate  # Normalize the loss because it is averaged
                # loss = loss.float()
            # Backpropagating the gradients
            # Scales loss.  Calls backward() on scaled loss to create scaled gradients.
            scaler.scale(loss).backward()

            if (it + 1) % iters_to_accumulate == 0:
                # Optimization step
                # scaler.step() first unscales the gradients of the optimizer's assigned params.
                # If these gradients do not contain infs or NaNs, opti.step() is then called,
                # otherwise, opti.step() is skipped.
                scaler.step(opti)
                # Updates the scale for next iteration.
                scaler.update()
                # Adjust the learning rate based on the number of iterations.
                lr_scheduler.step()
                # Clear gradients
                opti.zero_grad()


            running_loss += loss.item()

            if (it + 1) % print_every == 0:  # Print training loss information
                print()
                print("Iteration {}/{} of epoch {} complete. Loss : {} "
                      .format(it+1, nb_iterations, ep+1, running_loss / print_every))

                running_loss = 0.0


        val_loss = evaluate_loss(net, device, criterion, val_loader)  # Compute validation loss
        print()
        print("Epoch {} complete! Validation Loss : {}".format(ep+1, val_loss))

        if val_loss < best_loss:
            print("Best validation loss improved from {} to {}".format(best_loss, val_loss))
            print()
            net_copy = copy.deepcopy(net)  # save a copy of the model
            best_loss = val_loss
            best_ep = ep + 1

    # Saving the model
    path_to_model='models/{}_lr_{}_val_loss_{}_ep_{}.pt'.format(bert_model, lr, round(best_loss, 5), best_ep)
    torch.save(net_copy.state_dict(), path_to_model)
    print("The model has been saved in {}".format(path_to_model))

    del loss
    torch.cuda.empty_cache()

In [7]:
bert_model = "Rostlab/prot_bert"  # 'albert-base-v2', 'albert-large-v2', 'albert-xlarge-v2', 'albert-xxlarge-v2', 'bert-base-uncased', ...
freeze_bert = False  # if True, freeze the encoder weights and only update the classification layer weights
maxlen = 1024  # maximum length of the tokenized input sentence pair : if greater than "maxlen", the input is truncated and else if smaller, the input is padded
bs = 8  # batch size
iters_to_accumulate = 1  # the gradient accumulation adds gradients over an effective batch of size : bs * iters_to_accumulate. If set to "1", you get the usual batch size
lr = 1e-5  # learning rate
epochs = 4  # number of training epochs

In [8]:
set_seed(1)

# Creating instances of training and validation set
print("Reading training data...")
train_set = CustomDataset(train_set, maxlen, bert_model)
print("Reading validation data...")
val_set = CustomDataset(valid_set, maxlen, bert_model)
# Creating instances of training and validation dataloaders
train_loader = DataLoader(train_set, batch_size=bs, num_workers=5)
val_loader = DataLoader(val_set, batch_size=bs, num_workers=5)


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net = ProteinRegressor(bert_model, freeze_bert=freeze_bert)

if torch.cuda.device_count() > 1:  # if multiple GPUs
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    net = nn.DataParallel(net)

net.to(device)

criterion = nn.MSELoss()
opti = AdamW(net.parameters(), lr=lr, weight_decay=1e-2)
num_warmup_steps = 0 # The number of steps for the warmup phase.
num_training_steps = epochs * len(train_loader)  # The total number of training steps
t_total = (len(train_loader) // iters_to_accumulate) * epochs  # Necessary to take into account Gradient accumulation
lr_scheduler = get_linear_schedule_with_warmup(optimizer=opti, num_warmup_steps=num_warmup_steps, num_training_steps=t_total)

train_bert(net, criterion, opti, lr, lr_scheduler, train_loader, val_loader, epochs, iters_to_accumulate)


Reading training data...
Reading validation data...


Some weights of the model checkpoint at Rostlab/prot_bert were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Let's use 3 GPUs!


  return F.mse_loss(input, target, reduction=self.reduction)
 20%|█▉        | 107/536 [02:05<08:04,  1.13s/it]


Iteration 107/536 of epoch 1 complete. Loss : 0.0516243433843913 


 40%|███▉      | 214/536 [04:06<06:04,  1.13s/it]


Iteration 214/536 of epoch 1 complete. Loss : 0.0008869569113999757 


 60%|█████▉    | 321/536 [06:07<04:03,  1.13s/it]


Iteration 321/536 of epoch 1 complete. Loss : 0.0006275077835750814 


 80%|███████▉  | 428/536 [08:08<02:02,  1.13s/it]


Iteration 428/536 of epoch 1 complete. Loss : 0.0005513824891121122 


100%|█████████▉| 535/536 [10:09<00:01,  1.13s/it]


Iteration 535/536 of epoch 1 complete. Loss : 0.00031981009679184696 


  return F.mse_loss(input, target, reduction=self.reduction)
100%|██████████| 536/536 [10:10<00:00,  1.14s/it]
  return F.mse_loss(input, target, reduction=self.reduction)
100%|██████████| 67/67 [00:47<00:00,  1.41it/s]



Epoch 1 complete! Validation Loss : 0.00027634124520057365
Best validation loss improved from inf to 0.00027634124520057365



 20%|█▉        | 107/536 [02:01<08:05,  1.13s/it]


Iteration 107/536 of epoch 2 complete. Loss : 0.0003368130493182885 


 40%|███▉      | 214/536 [04:02<06:04,  1.13s/it]


Iteration 214/536 of epoch 2 complete. Loss : 0.0003339804787123217 


 60%|█████▉    | 321/536 [06:03<04:03,  1.13s/it]


Iteration 321/536 of epoch 2 complete. Loss : 0.0003616311884551553 


 80%|███████▉  | 428/536 [08:04<02:02,  1.13s/it]


Iteration 428/536 of epoch 2 complete. Loss : 0.0004652953766963377 


100%|█████████▉| 535/536 [10:05<00:01,  1.13s/it]


Iteration 535/536 of epoch 2 complete. Loss : 0.00028460957683164794 


100%|██████████| 536/536 [10:06<00:00,  1.13s/it]
100%|██████████| 67/67 [00:47<00:00,  1.41it/s]



Epoch 2 complete! Validation Loss : 0.00028111797923687046


 20%|█▉        | 107/536 [02:01<08:05,  1.13s/it]


Iteration 107/536 of epoch 3 complete. Loss : 0.000273900016602787 


 40%|███▉      | 214/536 [04:02<06:04,  1.13s/it]


Iteration 214/536 of epoch 3 complete. Loss : 0.0003024631797314078 


 60%|█████▉    | 321/536 [06:03<04:03,  1.13s/it]


Iteration 321/536 of epoch 3 complete. Loss : 0.0003294086059935937 


 80%|███████▉  | 428/536 [08:04<02:02,  1.13s/it]


Iteration 428/536 of epoch 3 complete. Loss : 0.0004155317215646332 


100%|█████████▉| 535/536 [10:05<00:01,  1.13s/it]


Iteration 535/536 of epoch 3 complete. Loss : 0.0002901788603789967 


100%|██████████| 536/536 [10:06<00:00,  1.13s/it]
 16%|█▋        | 11/67 [00:08<00:39,  1.41it/s]

: 

: 