In [19]:
import os
# Environment configuration for torch hub and CUDA.
# The 'TORCH_HOME' directory will be used to store the original ESM-1v weights.
os.environ['TORCH_HOME'] = "../torch_hub"
# Expose only the GPU with index 1 to CUDA.
# NOTE(review): presumably this has to run before CUDA is initialised to take effect — confirm.
os.environ['CUDA_VISIBLE_DEVICES'] = "1"

In [20]:
import scipy
import sklearn
import esm
import pickle

import pandas as pd
import numpy as np

import torch
from torch.utils.data import Dataset
from torch import nn
import math

import transformers
from transformers.modeling_outputs import SequenceClassifierOutput
from transformers import Trainer, TrainingArguments, EvalPrediction
from transformers import EsmTokenizer, EsmForMaskedLM

from esm.pretrained import load_model_and_alphabet_hub
from esm.inverse_folding.util import CoordBatchConverter

from sklearn.metrics import r2_score, mean_squared_error

from biotite.structure.residues import get_residues

from saprot_utils.foldseek_util import get_struc_seq
from pathlib import Path

In [21]:
# One-letter -> three-letter residue codes for the 20 standard amino acids.
onetothree = dict(
    A='ALA', R='ARG', N='ASN', D='ASP', C='CYS',   # Alanine, Arginine, Asparagine, Aspartic Acid, Cysteine
    Q='GLN', E='GLU', G='GLY', H='HIS', I='ILE',   # Glutamine, Glutamic Acid, Glycine, Histidine, Isoleucine
    L='LEU', K='LYS', M='MET', F='PHE', P='PRO',   # Leucine, Lysine, Methionine, Phenylalanine, Proline
    S='SER', T='THR', W='TRP', Y='TYR', V='VAL',   # Serine, Threonine, Tryptophan, Tyrosine, Valine
)

def load_data(path, filtr_length = None):
    '''
    Load a per-residue table and collapse it to one row per protein chain.

        Parameters:
            path: path to a CSV table with the columns pdb_id_chain, resi_pos, resi_aa,
                contact_number and contact_number_binary (one row per residue)
            filtr_length (int, optional): when given (and non-zero), keep only chains
                that have at least this many residues
        Returns:
            df (pandas.DataFrame): one row per pdb_id_chain; resi_pos, resi_aa,
                contact_number and contact_number_binary hold per-residue lists, and
                resi_name holds the three-letter residue names (None for any residue
                code that is not in `onetothree`)
    '''
    df = pd.read_csv(path)

    df = df.groupby('pdb_id_chain').agg({
        'resi_pos': list,
        'resi_aa': list,
        'contact_number': list,
        'contact_number_binary': list,
        }).reset_index()

    if filtr_length:
        df = df[df['resi_aa'].str.len() >= filtr_length]

    def to_residue_names(codes):
        # Unknown one-letter codes map to None instead of raising.
        return [onetothree.get(code) for code in codes]

    # assign() builds a new frame, so the column is never written into a filtered
    # slice (which is what triggers pandas' SettingWithCopyWarning).
    return df.assign(resi_name=df['resi_aa'].apply(to_residue_names))


def get_features_from_pdb(pdb_path, chain):
    '''
    Load the needed features from a PDB file.

    The chain is looked up by its upper-case id first; if that attempt fails,
    the lower-case id is tried (errors from the second attempt propagate).

        Parameters:
            pdb_path (Path): path to pdb-file
            chain (str): chain identifier inside the pdb-file
        Returns:
            dict (dict): dictionary with the keys
                "seq"          -- sequence reported by foldseek for the chain
                "foldseek_seq" -- second element returned by get_struc_seq for the chain
                "combined_seq" -- third element returned by get_struc_seq for the chain
                "residues"     -- list of (residue name, residue index) tuples
        Raises:
            AssertionError: if the number of residues in the loaded structure
                differs from the length of the foldseek sequence
    '''
    foldseek_bin = "saprot_utils/bin/foldseek"

    try:
        chain_id = chain.upper()
        parsed_seqs = get_struc_seq(foldseek_bin, pdb_path)[chain_id]
        structure = esm.inverse_folding.util.load_structure(str(pdb_path), chain_id)
    except Exception:
        # `Exception` (not a bare except) so that Ctrl-C / SystemExit are not swallowed.
        chain_id = chain.lower()
        parsed_seqs = get_struc_seq(foldseek_bin, pdb_path)[chain_id]
        structure = esm.inverse_folding.util.load_structure(str(pdb_path), chain_id)

    seq, foldseek_seq, combined_seq = parsed_seqs

    # Query the structure once; element 0 holds the residue indices, element 1 the names.
    residues = get_residues(structure)
    resi_keys = [
        (str(resi_name), int(resi_index))
        for resi_index, resi_name in zip(residues[0], residues[1])
    ]

    assert len(resi_keys) == len(seq), 'foldseek and esm.inverse_folding.util.load_structure returns different seqs'

    return {"seq":seq,
            "foldseek_seq":foldseek_seq,
            "combined_seq":combined_seq,
            "residues": resi_keys
            }


def get_contact_number_from_struct_data(df_row, struct_data):
    '''
    Align the contact numbers of one table row with the residues of a parsed structure.

    Residues are matched on the pair (residue name, residue position).

        Parameters:
            df_row: mapping (e.g. a DataFrame row) with the equally long lists
                'resi_name', 'resi_pos' and 'cn'
            struct_data (dict): needs 'seq' and 'residues', where 'residues' is a
                sequence of (residue name, residue position) pairs
        Returns:
            cn (list): one value per position of struct_data['seq']; positions
                without a matching table entry keep the placeholder -100
        Raises:
            AssertionError: if not a single residue could be matched
    '''
    cn = [-100] * len(struct_data['seq'])

    # (name, position) -> index into the row's lists.
    # NOTE(review): table positions are used exactly as stored; if they arrive as
    # strings they will not match integer structure positions — confirm upstream dtype.
    key_map = {
        (name, pos): i
        for i, (name, pos) in enumerate(zip(df_row['resi_name'], df_row['resi_pos']))
    }

    check_keys = {}  # only used to make the assertion message below informative
    for i, residue in enumerate(struct_data["residues"]):
        try:
            key = (residue[0], int(residue[1]))
        except (TypeError, ValueError):
            # Position is not a plain integer: fall back to its string form.
            key = (residue[0], str(residue[1]))

        check_keys[key] = i

        if key in key_map:
            cn[i] = df_row['cn'][key_map[key]]

    assert any(value != -100 for value in cn), \
            'all contact numbers are equal to -100:' + '\n' + \
            f'{key_map}\n{check_keys}\n{struct_data["residues"]}\n{cn}'

    return cn


def process_data(df, pdb_folder_path, for_classification = False):
    '''
    Build the model-ready table (structure-aware sequence + aligned labels) for every protein.

    A protein is skipped (and counted in the printed summary) when all of its
    label values are identical or when its PDB file is missing.

        Parameters:
            df (pandas.DataFrame): table produced by `load_data`; it is not modified
            pdb_folder_path: folder that contains '<pdb_id_chain>.pdb' files
            for_classification (bool): take labels from 'contact_number_binary'
                instead of 'contact_number'
        Returns:
            processed_data (pandas.DataFrame): columns 'combined_seq' and 'cn';
                empty (but with both columns) when every protein was skipped
    '''
    colname = 'contact_number_binary' if for_classification else 'contact_number'

    # Work on a copy that carries the chosen labels as 'cn' so the caller's frame
    # is left untouched and re-running the cell stays idempotent.
    df = df.assign(cn=df[colname])

    processed_data = []
    bad_prots = []

    for _, df_row in df.iterrows():
        pdb_id = df_row["pdb_id_chain"]
        chain = pdb_id.split('_')[1]
        pdb_path = os.path.join(pdb_folder_path, pdb_id + '.pdb')

        # A constant label vector carries no signal for this protein.
        if np.unique(df_row['cn']).shape[0] == 1:
            bad_prots.append(pdb_id)
            continue

        if not os.path.exists(pdb_path):
            bad_prots.append(pdb_id)
            continue

        struct_data = get_features_from_pdb(pdb_path, chain)
        struct_data["cn"] = get_contact_number_from_struct_data(df_row, struct_data)
        processed_data.append(struct_data)

    print('Number of proteins with invalid contact_number: ', len(bad_prots))

    # Passing `columns=` selects the two needed fields and also yields a well-formed
    # empty frame when nothing was processed (plain column selection raised KeyError).
    return pd.DataFrame(processed_data, columns=['combined_seq', 'cn'])

In [22]:
class SaProtDataset(Dataset):
    """
    Torch dataset that serves one protein per item for the SaProt model.

    Every item is a dict with
        'token_ids' -- the 'input_ids' produced by the tokenizer for the value in column 0
        'labels'    -- the values of column 1 as a tensor with an extra leading axis

    Attributes:
        df (pandas.DataFrame): dataframe with two columns:
                0 -- protein sequence passed to the tokenizer
                1 -- contact number values in list [0, 0.123, 0.23, -100, 1.34] format
        tokenizer (EsmTokenizer): tokenizer loaded from "westlake-repl/SaProt_650M_PDB"
        label_type (str):
                type of model: regression or binary (stored, not used by this class)
    """
    def __init__(self, df, label_type ='regression'):
        """
        Store the dataframe and load the SaProt tokenizer.

        Parameters:
            df (pandas.DataFrame): dataframe with two columns:
                0 -- protein sequence passed to the tokenizer
                1 -- contact number values in list [0, 0.123, 0.23, -100, 1.34] format
            label_type (str):
                type of model: regression or binary
        """
        self.tokenizer = EsmTokenizer.from_pretrained("westlake-repl/SaProt_650M_PDB")
        self.label_type = label_type
        self.df = df

    def __getitem__(self, idx):
        """Return the tokenized sequence and the label tensor for row `idx`."""
        sequence = self.df.iloc[idx, 0]
        encoded = self.tokenizer(sequence, padding = False, return_tensors="pt")
        contact_numbers = self.df.iloc[idx, 1]

        return {
            'token_ids': encoded['input_ids'],
            'labels': torch.tensor(contact_numbers).unsqueeze(0),
        }

    def __len__(self):
        """Number of proteins (rows of the dataframe)."""
        return len(self.df)


In [14]:
# Input locations: folder with per-chain PDB files and the train / test residue tables.
# NOTE(review): absolute, user-specific paths — adjust them before running elsewhere.
pdb_path = '/home/shevtsov/SEMAi/data/new_train'
train_path = '/home/shevtsov/SEMAi/data/train_set_03.csv'
test_path = '/home/shevtsov/SEMAi/data/test_set_prev.csv'

# One row per protein chain, with per-residue lists.
train_data = load_data(train_path)
test_data = load_data(test_path)

# Regression labels ('contact_number'); both splits read their structures from the same `pdb_path` folder.
processed_train_data = process_data(train_data, pdb_path, for_classification = False)
processed_test_data = process_data(test_data, pdb_path, for_classification = False)

# Wrap the processed tables as torch datasets for the Trainer.
train_ds = SaProtDataset(processed_train_data)
test_ds = SaProtDataset(processed_test_data)

print(f'Data preparation was completed\nNumber of enteries:\ntrain= {train_ds.__len__()}\ntest= {test_ds.__len__()}')


Number of proteins with invalid contact_number:  27
Number of proteins with invalid contact_number:  0
Data preparation was completed
Number of enteries:
train= 1517
test= 101


In [23]:
def compute_metrics_regr(p: EvalPrediction):
    """
    Regression metrics over all labelled positions of an evaluation run.

    Predictions are read from channel 1 of the last axis of `p.predictions`
    (indexed as [:, :, 1]); a position counts as labelled when its value in
    `p.label_ids` is greater than -1, so the -100 placeholder is ignored.

    Parameters:
        p (EvalPrediction): needs `predictions` (3-D array) and `label_ids`
    Returns:
        dict: "pearson_r", "mse" and "r2_score" computed on the labelled positions
    """
    # Imported explicitly: a bare `import scipy` does not guarantee that the
    # `scipy.stats` submodule has been loaded.
    from scipy.stats import pearsonr

    preds = p.predictions[:, :, 1]
    batch_size, seq_len = preds.shape

    # Only the region covered by the predictions is inspected.
    labels = np.asarray(p.label_ids)[:batch_size, :seq_len]
    labelled = labels > -1

    # Boolean indexing keeps row-major order, i.e. the same order as a nested i/j loop.
    out_labels = labels[labelled]
    out_preds = preds[labelled]

    return {
        "pearson_r": pearsonr(out_labels, out_preds)[0],
        "mse": mean_squared_error(out_labels, out_preds),
        "r2_score": r2_score(out_labels, out_preds)
    }

    
class SaProtForTokenClassification(nn.Module):
    """
    SaProt masked-LM encoder followed by a linear per-token head.

    The head is applied to the encoder's LM logits (not to hidden states), so its
    input width is the encoder's vocabulary size.

    Parameters:
        num_labels (int): number of output channels per token
        model_path (str): Hugging Face checkpoint to load the encoder from
    """

    def __init__(self, num_labels = 2, model_path = "westlake-repl/SaProt_650M_PDB"):
        super().__init__()
        self.num_labels = num_labels
        self.encoder = EsmForMaskedLM.from_pretrained(model_path)

        # Width of the LM logits; read from the config instead of hard-coding it
        # (446 for the default checkpoint) so other checkpoints work as well.
        self.classifier = nn.Linear(self.encoder.config.vocab_size, self.num_labels)

    def forward(self, token_ids, labels = None):
        """
        Parameters:
            token_ids: token id tensor passed to the encoder as `input_ids`
            labels: accepted so the Trainer can pass them through; not used here
        Returns:
            SequenceClassifierOutput whose `logits` have one row per token, with the
            first and last sequence positions removed
        """
        lm_logits = self.encoder(input_ids = token_ids)['logits']
        # Drop the first and last position along the sequence axis
        # (presumably the tokenizer's start / end special tokens — confirm).
        lm_logits = lm_logits[:, 1:-1, :]
        logits = self.classifier(lm_logits)
        return SequenceClassifierOutput(logits=logits)

In [24]:
def model_init():
    """Build a fresh SaProtForTokenClassification with default settings and move it to the GPU."""
    model = SaProtForTokenClassification()
    return model.cuda()

In [25]:
class MaskedMSELoss(torch.nn.Module):
    """
    Mean squared error restricted to the positions selected by a mask.

    Predictions are taken from channel 1 of the last axis of `inputs`
    (indexed as [:, :, 1]); predictions, targets and mask are flattened before
    they are combined.
    """
    def __init__(self):
        super().__init__()

    def forward(self, inputs, target, mask):
        """
        Return sum(masked squared error) / number of selected positions, or the
        plain sum of the masked squared errors when the mask selects nothing.
        """
        predictions = inputs[:, :, 1].flatten()
        squared_error = (predictions - target.flatten()) ** 2.0 * mask.flatten()

        total = torch.sum(squared_error)
        n_selected = torch.sum(mask)

        if n_selected == 0:
            return total
        return total / n_selected

    
class MaskedRegressTrainer(Trainer):
    """Trainer whose loss is a masked MSE: positions labelled -100 are ignored."""

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """
        Parameters:
            model: model called as `model(**inputs)`; its output must expose `.logits`
            inputs (dict): model inputs; must contain "labels"
            return_outputs (bool): also return the model outputs
            num_items_in_batch: accepted for compatibility with Trainer versions that
                pass this argument; not used
        Returns:
            loss, or (loss, outputs) when `return_outputs` is True
        """
        # Keep the labels on the device they already live on and lay them out as a
        # single row. This replaces a tensor -> numpy -> list -> tensor round trip
        # that hard-coded CUDA and failed for a chain with a single residue.
        labels = inputs["labels"].detach().to(dtype=torch.float32).reshape(1, -1)
        masks = labels != -100

        outputs = model(**inputs)
        logits = outputs.logits

        loss_fn = MaskedMSELoss()
        loss = loss_fn(logits, labels, masks)

        return (loss, outputs) if return_outputs else loss
        
def collator_fn(x):
    """Unwrap a batch that holds exactly one sample; return any other batch unchanged."""
    return x[0] if len(x) == 1 else x


## Model training and test 

In [26]:
# Training configuration for the single-model run.
training_args = TrainingArguments(
    output_dir='./results_fold',     # output directory
    num_train_epochs=2,              # total number of training epochs
    per_device_train_batch_size=1,   # batch size per device during training
    per_device_eval_batch_size=1,    # batch size for evaluation
    warmup_steps=20,                 # number of warmup steps for learning rate scheduler
    learning_rate=1e-05,             # learning rate
    weight_decay=0.0,                # strength of weight decay
    logging_dir='./logs',            # directory for storing logs
    logging_steps=200,               # how often to print logs
    save_strategy = "no",            # do not write checkpoints
    do_train=True,                   # perform training
    do_eval=False,                   # NOTE(review): evaluation is still requested by evaluation_strategy below
    evaluation_strategy="epoch",     # evaluate after each epoch
    gradient_accumulation_steps=1,   # number of batches accumulated per optimizer update
    fp16=False,                      # no mixed precision
    run_name="PDB_regr",             # experiment name
    seed=42,                         # seed for experiment reproducibility
    load_best_model_at_end=False,
    # compute_metrics_regr reports the key "r2_score"; evaluation metrics carry the "eval_" prefix.
    metric_for_best_model="eval_r2_score",
    greater_is_better=True,

)

In [27]:
# Create the directory for storing model weights (no error if it already exists).
os.makedirs("../models/", exist_ok=True)

### Train one model

In [None]:
# Train a single model with the `training_args` defined above.
# torch.cuda.empty_cache()
trainer = MaskedRegressTrainer(
    model=model_init(),                 # freshly built model returned by model_init()
    args=training_args,                   # training arguments, defined above
    train_dataset = train_ds,    # training dataset
    eval_dataset  = test_ds,    # evaluation dataset
    data_collator = collator_fn,    # unwraps the single-sample batch
    compute_metrics = compute_metrics_regr    # pearson_r / mse / r2_score on labelled positions
)
trainer.train()

# Save weights (disabled: uncomment to write the state dict to ../models/)
# torch.save(trainer.model.state_dict(), "../models/newdata_sema_saprot_continous_noncut_0.pth")

### Train ensemble

In [None]:
import gc

# Train one model per seed; only the seed differs between the runs.
for seed in [0, 1, 2, 3, 4]:
    training_args = TrainingArguments(
        output_dir='./results_fold',     # output directory
        num_train_epochs=2,              # total number of training epochs
        per_device_train_batch_size=1,   # batch size per device during training
        per_device_eval_batch_size=1,    # batch size for evaluation
        warmup_steps=20,                 # number of warmup steps for learning rate scheduler
        learning_rate=1e-05,             # learning rate
        weight_decay=0.0,                # strength of weight decay
        logging_dir='./logs',            # directory for storing logs
        logging_steps=200,               # how often to print logs
        save_strategy = "no",            # do not write checkpoints
        do_train=True,                   # perform training
        do_eval=False,                   # NOTE(review): evaluation is still requested by evaluation_strategy below
        evaluation_strategy="epoch",     # evaluate after each epoch
        gradient_accumulation_steps=1,   # number of batches accumulated per optimizer update
        fp16=False,                      # no mixed precision
        run_name="PDB_regr",             # experiment name
        seed=seed,                       # seed for experiment reproducibility
        load_best_model_at_end=False,
        # compute_metrics_regr reports the key "r2_score"; evaluation metrics carry the "eval_" prefix.
        metric_for_best_model="eval_r2_score",
        greater_is_better=True,
    )

    # Release the previous trainer (model + optimizer state) BEFORE building the next
    # model: empty_cache() can only hand back memory that is no longer referenced.
    trainer = None
    gc.collect()
    torch.cuda.empty_cache()

    trainer = MaskedRegressTrainer(
        model=model_init(),                 # freshly built model returned by model_init()
        args=training_args,                 # training arguments for this seed
        train_dataset = train_ds,           # training dataset
        eval_dataset  = test_ds,            # evaluation dataset
        data_collator = collator_fn,
        compute_metrics = compute_metrics_regr
    )
    trainer.train()

    # Save weights (disabled: uncomment to keep every ensemble member)
    # torch.save(trainer.model.state_dict(), f"../models/newdata_sema_saprot_continous_noncut_{seed}.pth")