In [1]:
import re
import torch
import pickle
import argparse
import numpy as np
import pandas as pd

from transformers import BertModel, BertTokenizer

### 1. Load data

In [2]:
# Read the tab-separated training table for the binding-site (BS) task.
BS_data_df = pd.read_csv("../../input_data/PDB/BS/Training_BS_data.tsv", sep = "\t")

# Columns are addressed by position, not by name:
#   column 1 -> protein identifiers, column 2 -> amino-acid sequences
# (presumably UniProt IDs / sequences, judging by the variable names -- verify against the TSV header).
BS_uniprot_IDs = BS_data_df.iloc[:, 1].values
BS_uniprot_seqs = BS_data_df.iloc[:, 2].values

### 2. Get protein features

In [3]:
# ID -> sequence lookup. When an ID occurs on several rows, the sequence from
# its last row wins, while the ID keeps the position of its first occurrence.
protein_seqs_dict = dict(zip(BS_uniprot_IDs, BS_uniprot_seqs))

# ID -> per-residue feature matrix; starts out empty.
protein_features_dict = {}

print(f"Uniprot_IDs: {len(protein_seqs_dict)}")

Uniprot_IDs: 5598


In [4]:
# Run on the first CUDA GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

# Tokenizer for the pretrained ProtBert checkpoint; do_lower_case = False
# leaves the case of the residue letters untouched.
tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case = False)

# Load the pretrained encoder, move it to the selected device and switch it
# to evaluation mode (it is only used for feature extraction below).
prots_model = BertModel.from_pretrained("Rostlab/prot_bert").to(device).eval()

Some weights of the model checkpoint at Rostlab/prot_bert were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.decoder.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [5]:
# Maximum number of residues kept per protein. Longer proteins are still
# encoded in full, but only the embeddings of their first MAX_RESIDUES
# residues are stored (same cut-off as the former 1503 / 1502 token checks).
MAX_RESIDUES = 1500

for PID, raw_seq in protein_seqs_dict.items():
    # The tokenizer is fed space-separated residues; the amino-acid codes
    # U, Z, O and B are replaced by X before encoding.
    spaced_seq = " ".join(re.sub(r"[UZOB]", "X", raw_seq))

    # One sequence per batch, so `padding = True` (pad to the longest item in
    # the batch) adds no pad tokens. It replaces the deprecated
    # `pad_to_max_length = True` argument.
    ids = tokenizer.batch_encode_plus([spaced_seq], add_special_tokens = True, padding = True)
    input_ids = torch.tensor(ids['input_ids']).to(device)
    attention_mask = torch.tensor(ids['attention_mask']).to(device)

    with torch.no_grad():
        # First model output: one hidden vector per token, shape (1, n_tokens, hidden).
        embedding = prots_model(input_ids = input_ids, attention_mask = attention_mask)[0]

    # Count the attended tokens as a plain Python int (not a device tensor)
    # so it can safely drive arithmetic and slicing.
    n_tokens = int((attention_mask[0] == 1).sum())

    # Two of those tokens are the special tokens added around the sequence
    # (first and last position); the rest are residues, capped at MAX_RESIDUES.
    n_residues = min(n_tokens - 2, MAX_RESIDUES)

    # Slice on the device first and copy only the kept rows to host memory.
    # Copying everything and slicing afterwards would store a NumPy view that
    # keeps the full-length embedding alive for every over-long protein.
    protein_features_dict[PID] = embedding[0, 1:1 + n_residues].cpu().numpy()

print(len(protein_features_dict))



5598


In [6]:
# Serialize the ID -> feature-matrix dictionary to disk with pickle.
features_path = "../../input_data/PDB/BS/Training_BS_protein_features.pkl"
with open(features_path, "wb") as features_file:
    pickle.dump(protein_features_dict, features_file)

### 3. Get binding site labels

In [26]:
# Per-ID accumulators for the two binding-site label sets; both start empty.
binding_sites_8A_dict = {}
binding_sites_4A_dict = {}

# Label columns are taken by position: column 3 -> 8A labels, column 4 -> 4A labels
# (presumably distance cut-offs in angstroms, judging by the names -- verify against the TSV header).
BS_8A_labels = BS_data_df.iloc[:, 3].values
BS_4A_labels = BS_data_df.iloc[:, 4].values

In [27]:
# Collect the binding-site positions of every row under its protein ID.
# Rows sharing an ID are concatenated (duplicates are kept at this stage).
for PID, labels_8A, labels_4A in zip(BS_uniprot_IDs, BS_8A_labels, BS_4A_labels):
    # Each label cell is a comma-separated string of integer positions.
    tokens_8A = labels_8A.split(",")
    tokens_4A = labels_4A.split(",")

    # First time this ID is seen: start a fresh list in both dictionaries.
    if PID not in binding_sites_4A_dict:
        binding_sites_4A_dict[PID] = [int(token) for token in tokens_4A]
        binding_sites_8A_dict[PID] = [int(token) for token in tokens_8A]
        continue

    # ID already present: append this row's positions to the existing lists.
    binding_sites_4A_dict[PID].extend([int(token) for token in tokens_4A])
    binding_sites_8A_dict[PID].extend([int(token) for token in tokens_8A])

In [28]:
# One output row per protein ID, in the order the IDs were first collected.
unique_PIDs = list(binding_sites_4A_dict)

# For each ID: drop duplicate positions, sort them numerically and store them
# as a comma-separated string next to the ID and its sequence.
uniprot_binding_sites_dict = {
    "Uniprot_IDs": unique_PIDs,
    "Uniprot_Seqs": [protein_seqs_dict[pid] for pid in unique_PIDs],
    "BS_4A": [
        ",".join(map(str, sorted(set(binding_sites_4A_dict[pid]))))
        for pid in unique_PIDs
    ],
    "BS_8A": [
        ",".join(map(str, sorted(set(binding_sites_8A_dict[pid]))))
        for pid in unique_PIDs
    ],
}

# Write the merged labels as a tab-separated file without the index column.
uniprot_binding_sites_df = pd.DataFrame(uniprot_binding_sites_dict)
uniprot_binding_sites_df.to_csv(
    f"../../input_data/PDB/BS/Training_BS_labels.tsv",
    sep = "\t",
    index = False,
)