# Directory settings

In [1]:
VER = 4

In [2]:
# ====================================================
# Directory settings
# ====================================================
import os

OUTPUT_DIR =  f'/home/wangjingqi/input/ck/nesp'
if not os.path.exists(OUTPUT_DIR):
    os.makedirs(OUTPUT_DIR)

# CFG

In [3]:
# ====================================================
# CFG
# ====================================================
class CFG:
    wandb=False
    competition='FB3'
    _wandb_kernel='nakama'
    debug=False
    apex=True
    print_freq=20
    num_workers=0
    #model="Rostlab/prot_bert"
    model="facebook/esm2_t33_650M_UR50D" 
    gradient_checkpointing=False
    scheduler='constant' # ['linear', 'cosine', 'constant']
    batch_scheduler=True
    num_warmup_steps=0
    
    # LEARNING RATE. 
    # Suggested: Prot_bert = 5e-6, ESM2 = 5e-5
    epochs=1
    num_cycles=1.0
    encoder_lr=5e-5
    decoder_lr=5e-5
    batch_size=8
    
    # MODEL INFO - PROT_BERT
    total_layers = 30 
    initial_layers = 5 
    layers_per_block = 16 
    # MODEL INFO - FACEBOOK ESM2
    if 'esm2' in model:
        total_layers = int(model.split('_')[1][1:])
        initial_layers = 2 
        layers_per_block = 15 
        
    # FREEZE
    # Suggested: Prot_bert -8, ESM2 -1
    num_freeze_layers = total_layers - 1
    # NO FREEZE
    #num_freeze_layers = 0
    
    min_lr=1e-6
    eps=1e-6
    betas=(0.9, 0.999)
    max_len=512
    weight_decay=0.01
    gradient_accumulation_steps=1
    max_grad_norm=1000
    target_cols=['target']
    seed=42
    n_fold=5
    trn_fold=[0,1,2,3,4]
    train=True
    pca_dim = 64
    device = 2
if CFG.debug:
    CFG.epochs = 2
    CFG.trn_fold = [0]

# Library

In [4]:
# ====================================================
# Library
# ====================================================
import os
import gc
import re
import ast
import sys
import copy
import json
import time
import math
import string
import pickle
import random
import joblib
import itertools
import warnings
warnings.filterwarnings("ignore")

import scipy as sp
import numpy as np
import pandas as pd
pd.set_option('display.max_rows', 500)
pd.set_option('display.max_columns', 500)
pd.set_option('display.width', 1000)
from tqdm.auto import tqdm
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import StratifiedKFold, GroupKFold, KFold

#os.system('pip install iterative-stratification==0.1.7')
#from iterstrat.ml_stratifiers import MultilabelStratifiedKFold

import torch
import torch.nn as nn
from torch.nn import Parameter
import torch.nn.functional as F
from torch.optim import Adam, SGD, AdamW
from torch.utils.data import DataLoader, Dataset

#os.system('pip uninstall -y transformers')
#os.system('pip uninstall -y tokenizers')
#os.system('python -m pip install --no-index --find-links=../input/fb3-pip-wheels transformers')
#os.system('python -m pip install --no-index --find-links=../input/fb3-pip-wheels tokenizers')
# os.system('pip install transformers --upgrade')
#os.system('pip install tokenizers --upgrade')

import tokenizers
import transformers
print(f"tokenizers.__version__: {tokenizers.__version__}")
print(f"transformers.__version__: {transformers.__version__}")
from transformers import AutoTokenizer, AutoModel, AutoConfig
from transformers import get_linear_schedule_with_warmup, get_cosine_schedule_with_warmup
from transformers import get_constant_schedule_with_warmup
%env TOKENIZERS_PARALLELISM=true

device =CFG.device

tokenizers.__version__: 0.12.1
transformers.__version__: 4.23.1
env: TOKENIZERS_PARALLELISM=true


# Utils

In [5]:
# ====================================================
# Utils
# ====================================================
def MCRMSE(y_trues, y_preds):
    scores = []
    idxes = y_trues.shape[1]
    for i in range(idxes):
        y_true = y_trues[:,i]
        y_pred = y_preds[:,i]
        score = mean_squared_error(y_true, y_pred, squared=False) # RMSE
        scores.append(score)
    mcrmse_score = np.mean(scores)
    return mcrmse_score, scores


def get_score(y_trues, y_preds):
    mcrmse_score, scores = MCRMSE(y_trues, y_preds)
    return mcrmse_score, scores


def get_logger(filename=OUTPUT_DIR+'train'):
    from logging import getLogger, INFO, StreamHandler, FileHandler, Formatter
    logger = getLogger(__name__)
    logger.setLevel(INFO)
    handler1 = StreamHandler()
    handler1.setFormatter(Formatter("%(message)s"))
    handler2 = FileHandler(filename=f"{filename}.log")
    handler2.setFormatter(Formatter("%(message)s"))
    logger.addHandler(handler1)
    logger.addHandler(handler2)
    return logger

LOGGER = get_logger()


def seed_everything(seed=42):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    
seed_everything(seed=42)

# Data Loading

In [6]:
train = pd.read_csv('/home/wangjingqi/nesp/all_train_data_v17.csv')
train = train.loc[train.source.str.contains('jin')]
print('Train has shape:', train.shape )
train.head()

Train has shape: (5072, 10)


Unnamed: 0,PDB,wildtype,position,mutation,ddG,sequence,mutant_seq,source,dTm,CIF
0,1A5E,L,121,R,0.66,MEPAAGSSMEPSADWLATAAARGRVEEVRALLEAGALPNAPNSYGR...,MEPAAGSSMEPSADWLATAAARGRVEEVRALLEAGALPNAPNSYGR...,jin_train.csv,,
1,1A5E,L,37,S,0.71,MEPAAGSSMEPSADWLATAAARGRVEEVRALLEAGALPNAPNSYGR...,MEPAAGSSMEPSADWLATAAARGRVEEVRALLEAGASPNAPNSYGR...,jin_train.csv,,
2,1A5E,W,15,D,0.17,MEPAAGSSMEPSADWLATAAARGRVEEVRALLEAGALPNAPNSYGR...,MEPAAGSSMEPSADDLATAAARGRVEEVRALLEAGALPNAPNSYGR...,jin_train.csv,,
3,1A5E,D,74,N,-2.0,MEPAAGSSMEPSADWLATAAARGRVEEVRALLEAGALPNAPNSYGR...,MEPAAGSSMEPSADWLATAAARGRVEEVRALLEAGALPNAPNSYGR...,jin_train.csv,,
4,1A5E,P,81,L,0.0,MEPAAGSSMEPSADWLATAAARGRVEEVRALLEAGALPNAPNSYGR...,MEPAAGSSMEPSADWLATAAARGRVEEVRALLEAGALPNAPNSYGR...,jin_train.csv,,


In [7]:
def add_spaces(x):
    return " ".join(list(x))
train.sequence = train.sequence.map(add_spaces)
train.mutant_seq = train.mutant_seq.map(add_spaces)
print('Train has shape',train.shape)
train.head()

Train has shape (5072, 10)


Unnamed: 0,PDB,wildtype,position,mutation,ddG,sequence,mutant_seq,source,dTm,CIF
0,1A5E,L,121,R,0.66,M E P A A G S S M E P S A D W L A T A A A R G ...,M E P A A G S S M E P S A D W L A T A A A R G ...,jin_train.csv,,
1,1A5E,L,37,S,0.71,M E P A A G S S M E P S A D W L A T A A A R G ...,M E P A A G S S M E P S A D W L A T A A A R G ...,jin_train.csv,,
2,1A5E,W,15,D,0.17,M E P A A G S S M E P S A D W L A T A A A R G ...,M E P A A G S S M E P S A D D L A T A A A R G ...,jin_train.csv,,
3,1A5E,D,74,N,-2.0,M E P A A G S S M E P S A D W L A T A A A R G ...,M E P A A G S S M E P S A D W L A T A A A R G ...,jin_train.csv,,
4,1A5E,P,81,L,0.0,M E P A A G S S M E P S A D W L A T A A A R G ...,M E P A A G S S M E P S A D W L A T A A A R G ...,jin_train.csv,,


In [8]:
REMOVE_CT = 20
train['n'] = train.groupby('PDB').PDB.transform('count')
train = train.loc[train.n>REMOVE_CT]
print(f'After removing mutation groups with less than {REMOVE_CT} mutations, our data has shape:')
print( train.shape )

After removing mutation groups with less than 20 mutations, our data has shape:
(4085, 11)


In [9]:
# RANK NORMALIZE EACH GROUP OF MUTATION PROTEINS
from scipy.stats import rankdata
train['target'] = 0.5
for p in train.PDB.unique():
    target = 'dTm'
    tmp = train.loc[train.PDB==p,'dTm']
    if tmp.isna().sum()>len(tmp)/2: target = 'ddG'
    train.loc[train.PDB==p,'target'] =\
        rankdata( train.loc[train.PDB==p,target] )/len( train.loc[train.PDB==p,target] )
train = train.reset_index(drop=True)

In [10]:
print('Our train data has',train.PDB.nunique(),'unique mutation groups')

Our train data has 61 unique mutation groups


# CV split

In [11]:
from sklearn.model_selection import GroupKFold

In [12]:
# ====================================================
# CV split
# ====================================================
Fold = GroupKFold(n_splits=CFG.n_fold) 
for n, (train_index, val_index) in enumerate(Fold.split(train, train[CFG.target_cols], train.PDB)):
    train.loc[val_index, 'fold'] = int(n)
train['fold'] = train['fold'].astype(int)
display(train.groupby('fold').size())

fold
0    813
1    832
2    815
3    813
4    812
dtype: int64

In [13]:
if CFG.debug:
    display(train.groupby('fold').size())
    train = train.sample(n=1000, random_state=0).reset_index(drop=True)
    display(train.groupby('fold').size())

# Tokenizer

In [14]:
# ====================================================
# tokenizer
# ====================================================
tokenizer = AutoTokenizer.from_pretrained(CFG.model)
tokenizer.save_pretrained(OUTPUT_DIR+'tokenizer/')
CFG.tokenizer = tokenizer

In [15]:
tokenizer

PreTrainedTokenizer(name_or_path='facebook/esm2_t33_650M_UR50D', vocab_size=33, model_max_len=1000000000000000019884624838656, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'eos_token': '<eos>', 'unk_token': '<unk>', 'pad_token': '<pad>', 'cls_token': '<cls>', 'mask_token': '<mask>'})

# Dataset

In [16]:
# ====================================================
# Dataset
# ====================================================
def prepare_input(cfg, text):
    inputs = cfg.tokenizer.encode_plus(
        text, 
        return_tensors=None, 
        add_special_tokens=True, 
        max_length=cfg.max_len,
        pad_to_max_length=True,
        truncation=True
    )
    for k, v in inputs.items():
        inputs[k] = torch.tensor(v, dtype=torch.long)
    return inputs

def get_context(cfg, inputs, position):
    context = []
    for i in range(position-cfg.context_size, position+cfg.context_size+1):
        if i < 0 or i >= len(inputs["input_ids"][0]):
            context.append(cfg.tokenizer.pad_token_id)
        else:
            context.append(inputs["input_ids"][0][i])
    return torch.tensor(context, dtype=torch.long)
class TrainDataset(Dataset):
    def __init__(self, cfg, df):
        self.cfg = cfg
        self.texts1 = df['sequence'].values
        self.texts2 = df['mutant_seq'].values
        self.labels = df[cfg.target_cols].values
        self.position = df['position'].values

    def __len__(self):
        return len(self.texts1)

    def __getitem__(self, item):
        inputs1 = prepare_input(self.cfg, self.texts1[item])
        inputs2 = prepare_input(self.cfg, self.texts2[item])
        position = np.zeros(self.cfg.max_len)
        position[self.position[item]] = 1
        position = torch.tensor(position, dtype=torch.int)
        label = torch.tensor(self.labels[item], dtype=torch.float)
        return inputs1, inputs2, position, label
    

# EDA Dataloader
Below displays the output of our data loader using a fake data example. Our data loader provides us with "two sentences" (i.e. both the wild type sequence and mutation sequence) and it provides us with the location of the mutation. Our model will use all this information. Note in the output that the first token id is `<cls>` and the last token id is `<eos>` and the tokens inbetween are the amino acids.

In [17]:
fake = pd.DataFrame(columns=['sequence','mutant_seq','position','target'])
fake['sequence'] = ['V P V N P E','V P V N P E']
fake['mutant_seq'] = ['V P A N P E','V P V N P C']
fake['position'] = [3,6]
fake['target'] = [0.2,0.8]
fake.head()

Unnamed: 0,sequence,mutant_seq,position,target
0,V P V N P E,V P A N P E,3,0.2
1,V P V N P E,V P V N P C,6,0.8


In [18]:
class CFG_fake:
    max_len = 8
    target_cols = ['target']
    tokenizer = CFG.tokenizer
    
train_dataset = TrainDataset(CFG_fake, fake)
train_loader = DataLoader(train_dataset,
                          batch_size=2,
                          shuffle=False)
for inputs1,inputs2,position,label in train_loader:
    break
    
print('Inputs1 =')
print(inputs1,'\n')
print('Inputs2 =')
print(inputs2,'\n')
print('Position =')
print(position,'\n')
print('Label =')
print(label)

Inputs1 =
{'input_ids': tensor([[ 0,  7, 14,  7, 17, 14,  9,  2],
        [ 0,  7, 14,  7, 17, 14,  9,  2]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1]])} 

Inputs2 =
{'input_ids': tensor([[ 0,  7, 14,  5, 17, 14,  9,  2],
        [ 0,  7, 14,  7, 17, 14, 23,  2]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1]])} 

Position =
tensor([[0, 0, 0, 1, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 1, 0]], dtype=torch.int32) 

Label =
tensor([[0.2000],
        [0.8000]])


# Model
The model architecture takes both wild type sequence and mutation sequence and location of mutation as input. Then it subtracts their embeddings and concatenates that with their individual embeddings. It does this with both single mutation token position and mean pooling of entire sequence. Finally a dense layer makes the regression prediction. See the model architecture below.

In [19]:
# ====================================================
# Model
# ====================================================
class MeanPooling(nn.Module):
    def __init__(self):
        super(MeanPooling, self).__init__()
        
    def forward(self, last_hidden_state, attention_mask):
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
        sum_embeddings = torch.sum(last_hidden_state * input_mask_expanded, 1)
        sum_mask = input_mask_expanded.sum(1)
        sum_mask = torch.clamp(sum_mask, min=1e-9)
        mean_embeddings = sum_embeddings / sum_mask
        return mean_embeddings
    

class CustomModel(nn.Module):
    def __init__(self, cfg, config_path=None, pretrained=False):
        super().__init__()
        self.cfg = cfg
        if config_path is None:
            self.config = AutoConfig.from_pretrained(cfg.model, output_hidden_states=True)
            self.config.hidden_dropout = 0.
            self.config.hidden_dropout_prob = 0.
            self.config.attention_dropout = 0.
            self.config.attention_probs_dropout_prob = 0.
            LOGGER.info(self.config)
        else:
            self.config = torch.load(config_path)
            #self.config = AutoConfig.from_pretrained(config_path)
        if pretrained:
            self.model = AutoModel.from_pretrained(cfg.model, config=self.config)
        else:
            self.model = AutoModel.from_config(self.config)
            
        if self.cfg.gradient_checkpointing:
            self.model.gradient_checkpointing_enable()
        self.pool = MeanPooling()
        self.fc1 = nn.Linear(self.config.hidden_size, self.cfg.pca_dim)
        self.fc2 = nn.Linear(self.cfg.pca_dim*6, 1)
        self._init_weights(self.fc1)
        self._init_weights(self.fc2)
        
    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        
    def feature(self, inputs, position):
        outputs = self.model(**inputs)
        last_hidden_states = outputs[0]
        feature = self.pool(last_hidden_states, position)
        return feature

    def forward(self, inputs1, inputs2, position):
        feature1 = self.fc1(self.feature(inputs1, position))
        feature2 = self.fc1(self.feature(inputs2, position))
        
        feature3 = self.fc1(self.feature(inputs1, inputs1['attention_mask']))
        feature4 = self.fc1(self.feature(inputs2, inputs2['attention_mask']))        
        
        feature = torch.cat((feature1, feature2, feature2 - feature1,
                             feature3, feature4, feature4 - feature3),axis=-1)

        output = self.fc2(feature)
        return output

# Loss

In [20]:
# ====================================================
# Loss
# ====================================================
class RMSELoss(nn.Module):
    def __init__(self, reduction='mean', eps=1e-9):
        super().__init__()
        self.mse = nn.MSELoss(reduction='none')
        self.reduction = reduction
        self.eps = eps

    def forward(self, y_pred, y_true):
        loss = torch.sqrt(self.mse(y_pred, y_true) + self.eps)
        if self.reduction == 'none':
            loss = loss
        elif self.reduction == 'sum':
            loss = loss.sum()
        elif self.reduction == 'mean':
            loss = loss.mean()
        return loss

# Helpler functions

In [21]:
# ====================================================
# Helper functions
# ====================================================
class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def asMinutes(s):
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)


def timeSince(since, percent):
    now = time.time()
    s = now - since
    es = s / (percent)
    rs = es - s
    return '%s (remain %s)' % (asMinutes(s), asMinutes(rs))


def train_fn(fold, train_loader, model, criterion, optimizer, epoch, scheduler, device):
    model.train()
    scaler = torch.cuda.amp.GradScaler(enabled=CFG.apex)
    losses = AverageMeter()
    start = end = time.time()
    global_step = 0
    for step, (inputs1, inputs2, position, labels) in enumerate(train_loader):
     
        for k, v in inputs1.items():
            inputs1[k] = v.to(device)

        for k, v in inputs2.items():
            inputs2[k] = v.to(device)
        position = position.to(device)
        labels = labels.to(device)
        batch_size = labels.size(0)
        with torch.cuda.amp.autocast(enabled=CFG.apex):
            y_preds = model(inputs1,inputs2,position)
            loss = criterion(y_preds, labels)
        if CFG.gradient_accumulation_steps > 1:
            loss = loss / CFG.gradient_accumulation_steps
        losses.update(loss.item(), batch_size)
        scaler.scale(loss).backward()
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), CFG.max_grad_norm)
        if (step + 1) % CFG.gradient_accumulation_steps == 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
            global_step += 1
            if CFG.batch_scheduler:
                scheduler.step()
        end = time.time()
        if step % CFG.print_freq == 0 or step == (len(train_loader)-1):
            print('Epoch: [{0}][{1}/{2}] '
                  'Elapsed {remain:s} '
                  'Loss: {loss.val:.4f}({loss.avg:.4f}) '
                  'Grad: {grad_norm:.4f}  '
                  'LR: {lr:.8f}  '
                  .format(epoch+1, step, len(train_loader), 
                          remain=timeSince(start, float(step+1)/len(train_loader)),
                          loss=losses,
                          grad_norm=grad_norm,
                          lr=scheduler.get_lr()[0]))
        if CFG.wandb:
            wandb.log({f"[fold{fold}] loss": losses.val,
                       f"[fold{fold}] lr": scheduler.get_lr()[0]})
    return losses.avg


def valid_fn(valid_loader, model, criterion, device):
    losses = AverageMeter()
    model.eval()
    preds = []
    start = end = time.time()
    for step, (inputs1, inputs2, position, labels) in enumerate(valid_loader):
   
        for k, v in inputs1.items():
            inputs1[k] = v.to(device)
        for k, v in inputs2.items():
            inputs2[k] = v.to(device)
        position = position.to(device)
        labels = labels.to(device)
        batch_size = labels.size(0)
        with torch.no_grad():
            y_preds = model(inputs1,inputs2,position)
            loss = criterion(y_preds, labels)
        if CFG.gradient_accumulation_steps > 1:
            loss = loss / CFG.gradient_accumulation_steps
        losses.update(loss.item(), batch_size)
        preds.append(y_preds.to('cpu').numpy())
        end = time.time()
        if step % CFG.print_freq == 0 or step == (len(valid_loader)-1):
            print('EVAL: [{0}/{1}] '
                  'Elapsed {remain:s} '
                  'Loss: {loss.val:.4f}({loss.avg:.4f}) '
                  .format(step, len(valid_loader),
                          loss=losses,
                          remain=timeSince(start, float(step+1)/len(valid_loader))))
    predictions = np.concatenate(preds)
    return losses.avg, predictions

# Train loop

In [22]:
# ====================================================
# train loop
# ====================================================
def train_loop(folds, fold):
    
    LOGGER.info(f"========== fold: {fold} training ==========")

    # ====================================================
    # loader
    # ====================================================
    train_folds = folds[folds['fold'] != fold].reset_index(drop=True)
    valid_folds = folds[folds['fold'] == fold].reset_index(drop=True)
    valid_labels = valid_folds[CFG.target_cols].values
    print('### train shape:',train_folds.shape)
    
    train_dataset = TrainDataset(CFG, train_folds)
    valid_dataset = TrainDataset(CFG, valid_folds)

    train_loader = DataLoader(train_dataset,
                              batch_size=CFG.batch_size,
                              shuffle=True,
                              num_workers=CFG.num_workers, pin_memory=True, drop_last=True)
    valid_loader = DataLoader(valid_dataset,
                              batch_size=CFG.batch_size * 2,
                              shuffle=False,
                              num_workers=CFG.num_workers, pin_memory=True, drop_last=False)

    # ====================================================
    # model & optimizer
    # ====================================================
    model = CustomModel(CFG, config_path=None, pretrained=True)
    torch.save(model.config, OUTPUT_DIR+'config.pth')
    
    # FREEZE LAYERS
    if CFG.num_freeze_layers>0:
        print(f'### Freezing first {CFG.num_freeze_layers} layers.',
              f'Leaving {CFG.total_layers-CFG.num_freeze_layers} layers unfrozen')
        for name, param in list(model.named_parameters())\
            [:CFG.initial_layers+CFG.layers_per_block*CFG.num_freeze_layers]:     
                param.requires_grad = False
    model.to(device)
    
    def get_optimizer_params(model, encoder_lr, decoder_lr, weight_decay=0.0):
        param_optimizer = list(model.named_parameters())
        no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
        optimizer_parameters = [
            {'params': [p for n, p in model.model.named_parameters() if not any(nd in n for nd in no_decay)],
             'lr': encoder_lr, 'weight_decay': weight_decay},
            {'params': [p for n, p in model.model.named_parameters() if any(nd in n for nd in no_decay)],
             'lr': encoder_lr, 'weight_decay': 0.0},
            {'params': [p for n, p in model.named_parameters() if "model" not in n],
             'lr': decoder_lr, 'weight_decay': 0.0}
        ]
        return optimizer_parameters

    optimizer_parameters = get_optimizer_params(model,
                                                encoder_lr=CFG.encoder_lr, 
                                                decoder_lr=CFG.decoder_lr,
                                                weight_decay=CFG.weight_decay)
    optimizer = AdamW(optimizer_parameters, lr=CFG.encoder_lr, eps=CFG.eps, betas=CFG.betas)
    
    # ====================================================
    # scheduler
    # ====================================================
    def get_scheduler(cfg, optimizer, num_train_steps):
        if cfg.scheduler == 'linear':
            scheduler = get_linear_schedule_with_warmup(
                optimizer, num_warmup_steps=cfg.num_warmup_steps, num_training_steps=num_train_steps
            )
        elif cfg.scheduler == 'constant':
            scheduler = get_constant_schedule_with_warmup(
                optimizer, num_warmup_steps=cfg.num_warmup_steps
            )
        elif cfg.scheduler == 'cosine':
            scheduler = get_cosine_schedule_with_warmup(
                optimizer, num_warmup_steps=cfg.num_warmup_steps, num_training_steps=num_train_steps, num_cycles=cfg.num_cycles
            )
        return scheduler
    
    num_train_steps = int(len(train_folds) / CFG.batch_size * CFG.epochs)
    scheduler = get_scheduler(CFG, optimizer, num_train_steps)

    # ====================================================
    # loop
    # ====================================================
    criterion = RMSELoss(reduction="mean") #nn.SmoothL1Loss(reduction='mean')
    
    best_score = np.inf

    for epoch in range(CFG.epochs):

        start_time = time.time()

        # train
        avg_loss = train_fn(fold, train_loader, model, criterion, optimizer, epoch, scheduler, device)

        # eval
        avg_val_loss, predictions = valid_fn(valid_loader, model, criterion, device)
        
        # scoring
        score, scores = get_score(valid_labels, predictions)

        elapsed = time.time() - start_time

        LOGGER.info(f'Epoch {epoch+1} - avg_train_loss: {avg_loss:.4f}  avg_val_loss: {avg_val_loss:.4f}  time: {elapsed:.0f}s')
        LOGGER.info(f'Epoch {epoch+1} - Score: {score:.4f}  Scores: {scores}')
        if CFG.wandb:
            wandb.log({f"[fold{fold}] epoch": epoch+1, 
                       f"[fold{fold}] avg_train_loss": avg_loss, 
                       f"[fold{fold}] avg_val_loss": avg_val_loss,
                       f"[fold{fold}] score": score})
        
        if best_score > score:
            best_score = score
            LOGGER.info(f'Epoch {epoch+1} - Save Best Score: {best_score:.4f} Model')
            torch.save({'model': model.state_dict(),
                        'predictions': predictions},
                        OUTPUT_DIR+f"{CFG.model.replace('/', '-')}_fold{fold}_best.pth")

    predictions = torch.load(OUTPUT_DIR+f"{CFG.model.replace('/', '-')}_fold{fold}_best.pth", 
                             map_location=torch.device('cpu'))['predictions']
    valid_folds[[f"pred_{c}" for c in CFG.target_cols]] = predictions

    torch.cuda.empty_cache()
    gc.collect()
    
    return valid_folds

In [23]:
if __name__ == '__main__':
    
    def get_result(oof_df):
        labels = oof_df[CFG.target_cols].values
        preds = oof_df[[f"pred_{c}" for c in CFG.target_cols]].values
        score, scores = get_score(labels, preds)
        LOGGER.info(f'Score: {score:<.4f}  Scores: {scores}')
    
    if CFG.train:
        oof_df = pd.DataFrame()
        for fold in range(CFG.n_fold):
            if fold in CFG.trn_fold:
                _oof_df = train_loop(train, fold)
                oof_df = pd.concat([oof_df, _oof_df])
                LOGGER.info(f"========== fold: {fold} result ==========")
                get_result(_oof_df)
        oof_df = oof_df.reset_index(drop=True)
        LOGGER.info(f"========== CV ==========")
        get_result(oof_df)
        oof_df.to_pickle(OUTPUT_DIR+'oof_df.pkl')
        
    if CFG.wandb:
        wandb.finish()



### train shape: (3272, 13)


EsmConfig {
  "_name_or_path": "facebook/esm2_t33_650M_UR50D",
  "architectures": [
    "EsmForMaskedLM"
  ],
  "attention_dropout": 0.0,
  "attention_probs_dropout_prob": 0.0,
  "classifier_dropout": null,
  "emb_layer_norm_before": false,
  "hidden_act": "gelu",
  "hidden_dropout": 0.0,
  "hidden_dropout_prob": 0.0,
  "hidden_size": 1280,
  "initializer_range": 0.02,
  "intermediate_size": 5120,
  "layer_norm_eps": 1e-05,
  "mask_token_id": 32,
  "max_position_embeddings": 1026,
  "model_type": "esm",
  "num_attention_heads": 20,
  "num_hidden_layers": 33,
  "output_hidden_states": true,
  "pad_token_id": 1,
  "position_embedding_type": "rotary",
  "token_dropout": true,
  "torch_dtype": "float32",
  "transformers_version": "4.23.1",
  "use_cache": true,
  "vocab_size": 33
}

Some weights of the model checkpoint at facebook/esm2_t33_650M_UR50D were not used when initializing EsmModel: ['lm_head.dense.weight', 'lm_head.decoder.bias', 'lm_head.dense.bias', 'lm_head.layer_norm.weight', 

### Freezing first 32 layers. Leaving 1 layers unfrozen
Epoch: [1][0/409] Elapsed 0m 1s (remain 8m 55s) Loss: 0.4600(0.4600) Grad: inf  LR: 0.00005000  
Epoch: [1][20/409] Elapsed 0m 18s (remain 5m 38s) Loss: 0.4344(0.4696) Grad: 78413.0859  LR: 0.00005000  
Epoch: [1][40/409] Elapsed 0m 35s (remain 5m 17s) Loss: 0.3782(0.3983) Grad: 41116.3281  LR: 0.00005000  
Epoch: [1][60/409] Elapsed 0m 52s (remain 4m 59s) Loss: 0.2951(0.3591) Grad: 44885.5898  LR: 0.00005000  
Epoch: [1][80/409] Elapsed 1m 9s (remain 4m 41s) Loss: 0.1501(0.3304) Grad: 36153.2695  LR: 0.00005000  
Epoch: [1][100/409] Elapsed 1m 26s (remain 4m 24s) Loss: 0.2551(0.3137) Grad: 108206.0625  LR: 0.00005000  
Epoch: [1][120/409] Elapsed 1m 43s (remain 4m 7s) Loss: 0.2095(0.3002) Grad: 49134.2617  LR: 0.00005000  
Epoch: [1][140/409] Elapsed 2m 1s (remain 3m 50s) Loss: 0.2563(0.2931) Grad: 31308.6895  LR: 0.00005000  
Epoch: [1][160/409] Elapsed 2m 18s (remain 3m 33s) Loss: 0.2159(0.2869) Grad: 80354.9219  LR: 0.00005000

# Compute OOF Score

In [None]:
oof_df.head()

In [None]:
print('Here are CV scores for each mutation group in OOF:')
from scipy.stats import spearmanr
rs = []
for p in oof_df.PDB.unique():
    tmp = oof_df.loc[oof_df.PDB==p]
    
    target = 'dTm'
    if tmp.dTm.isna().sum()>len(tmp)/2: target = 'ddG'
    
    r = np.abs( spearmanr( tmp.target, tmp.pred_target ).correlation )
    print(p,'ct=',len(tmp),'r=',r,'type=',target)
    rs.append(r)
print('==> CV',np.mean(rs))

# Infer Test Data

In [None]:
CFG.path = OUTPUT_DIR
CFG.config_path = CFG.path+'config.pth'

In [None]:
class TestDataset(Dataset):
    def __init__(self, cfg, df):
        self.cfg = cfg
        self.texts1 = df['sequence'].values
        self.texts2 = df['mutant_seq'].values
        self.position = df['position'].values

    def __len__(self):
        return len(self.texts1)

    def __getitem__(self, item):
        inputs1 = prepare_input(self.cfg, self.texts1[item])
        inputs2 = prepare_input(self.cfg, self.texts2[item])
        position = np.zeros(self.cfg.max_len)
        position[self.position[item]] = 1
        position = torch.tensor(position, dtype=torch.int)
        return inputs1, inputs2, position

In [None]:
# ====================================================
# inference
# ====================================================
def inference_fn(test_loader, model, device):
    preds = []
    model.eval()
    model.to(device)
    tk0 = tqdm(test_loader, total=len(test_loader))
    for (inputs1,inputs2, position) in tk0:
        for k, v in inputs1.items():
            inputs1[k] = v.to(device)
        for k, v in inputs2.items():
            inputs2[k] = v.to(device)
        position = position.to(device)
        with torch.no_grad():
            y_preds = model(inputs1,inputs2,position)
        preds.append(y_preds.to('cpu').numpy())
    predictions = np.concatenate(preds)
    return predictions

In [None]:
# LOAD TEST WILDTYPE
base = 'VPVNPEPDATSVENVALKTGSGDSQSDPIKADLEVKGQSALPFDVDCWAILCKGAPNVLQRVNEKTKNSNRDRSGANKGPFKDPQKWGIKALPPKNPSWSAQDFKSPEEYAFASSLQGGTNAILAPVNLASQNSQGGVLNGFYSANKVAQFDPSKPQQTKGTWFQITKFTGAAGPYCKALGSNDKSVCDKNKNIAGDWGFDPAKWAYQYDEKNNKFNYVGK'
len(base)

In [None]:
def get_test_mutation(row):
    for i,(a,b) in enumerate(zip(row.protein_sequence,base)):
        if a!=b: break
    row['wildtype'] = base[i]
    row['mutation'] = row.protein_sequence[i]
    row['position'] = i+1
    return row

In [None]:
submission = pd.read_csv('../input/novozymes-enzyme-stability-prediction/sample_submission.csv')
test = pd.read_csv('../input/novozymes-enzyme-stability-prediction/test.csv')
test = test.apply(get_test_mutation,axis=1)
deletions = test.loc[test.protein_sequence.str.len()==220,'seq_id'].values
test['sequence'] = base
test.sequence = test.sequence.map(add_spaces)
test = test.rename({'protein_sequence':'mutant_seq'},axis=1)
test.mutant_seq = test.mutant_seq.map(add_spaces)
print('Test data shape:', test.shape )
test.head()

In [None]:
test_dataset = TestDataset(CFG, test)
test_loader = DataLoader(test_dataset,
                         batch_size=CFG.batch_size,
                         shuffle=False,
                         #collate_fn=DataCollatorWithPadding(tokenizer=CFG.tokenizer, padding='longest'),
                         num_workers=CFG.num_workers, pin_memory=True, drop_last=False)
predictions_ = []
for fold in CFG.trn_fold:
    model = CustomModel(CFG, config_path=CFG.config_path, pretrained=False)
    state = torch.load(CFG.path+f"{CFG.model.replace('/', '-')}_fold{fold}_best.pth",
                       map_location=torch.device('cpu'))
    model.load_state_dict(state['model'])
    prediction = inference_fn(test_loader, model, device)
    predictions_.append(prediction)
    del model, state, prediction; gc.collect()
    torch.cuda.empty_cache()
predictions = np.mean(predictions_, axis=0)

In [None]:
test['tm'] = predictions
submission = submission.drop(columns=['tm']).merge(test[['seq_id','tm']], on='seq_id', how='left')
display(submission.head())
submission[['seq_id','tm']].to_csv(f'submission_bert_{VER}.csv', index=False)

In [None]:
import matplotlib.pyplot as plt
plt.hist(submission.tm, bins=100)
plt.show()