In [1]:
import re
import torch
import pickle
import argparse
import numpy as np
import pandas as pd

from transformers import BertModel, BertTokenizer

### 1. Load data

In [2]:
# Every table is read positionally: column 1 is taken as the UniProt ID and the
# sequence comes from column 4 (*_BA_data.tsv) or column 3 (*_IS_data.tsv).
PDBbind_df = pd.read_csv("../../input_data/PDB/BA/Training_BA_data.tsv", sep="\t")
PDBbind_uniprot_IDs = PDBbind_df.iloc[:, 1].values
PDBbind_uniprot_seqs = PDBbind_df.iloc[:, 4].values

CASF2016_df = pd.read_csv("../../input_data/PDB/BA/CASF2016_BA_data.tsv", sep="\t")
CASF2016_uniprot_IDs = CASF2016_df.iloc[:, 1].values
CASF2016_uniprot_seqs = CASF2016_df.iloc[:, 4].values

CASF2013_df = pd.read_csv("../../input_data/PDB/BA/CASF2013_BA_data.tsv", sep="\t")
CASF2013_uniprot_IDs = CASF2013_df.iloc[:, 1].values
CASF2013_uniprot_seqs = CASF2013_df.iloc[:, 4].values

CSAR2014_df = pd.read_csv("../../input_data/PDB/BA/CSAR2014_BA_data.tsv", sep="\t")
CSAR2014_uniprot_IDs = CSAR2014_df.iloc[:, 1].values
CSAR2014_uniprot_seqs = CSAR2014_df.iloc[:, 4].values

CSAR2012_df = pd.read_csv("../../input_data/PDB/BA/CSAR2012_BA_data.tsv", sep="\t")
CSAR2012_uniprot_IDs = CSAR2012_df.iloc[:, 1].values
CSAR2012_uniprot_seqs = CSAR2012_df.iloc[:, 4].values

CSARset1_df = pd.read_csv("../../input_data/PDB/BA/CSARset1_BA_data.tsv", sep="\t")
CSARset1_uniprot_IDs = CSARset1_df.iloc[:, 1].values
CSARset1_uniprot_seqs = CSARset1_df.iloc[:, 4].values

CSARset2_df = pd.read_csv("../../input_data/PDB/BA/CSARset2_BA_data.tsv", sep="\t")
CSARset2_uniprot_IDs = CSARset2_df.iloc[:, 1].values
CSARset2_uniprot_seqs = CSARset2_df.iloc[:, 4].values

Astex_df = pd.read_csv("../../input_data/PDB/BA/Astex_BA_data.tsv", sep="\t")
Astex_uniprot_IDs = Astex_df.iloc[:, 1].values
Astex_uniprot_seqs = Astex_df.iloc[:, 4].values

# The interaction-site (IS) tables keep the sequence one column earlier.
COACH420_df = pd.read_csv("../../input_data/PDB/BA/COACH420_IS_data.tsv", sep="\t")
COACH420_uniprot_IDs = COACH420_df.iloc[:, 1].values
COACH420_uniprot_seqs = COACH420_df.iloc[:, 3].values

HOLO4K_df = pd.read_csv("../../input_data/PDB/BA/HOLO4K_IS_data.tsv", sep="\t")
HOLO4K_uniprot_IDs = HOLO4K_df.iloc[:, 1].values
HOLO4K_uniprot_seqs = HOLO4K_df.iloc[:, 3].values

### 2. Get protein features

In [3]:
def get_info(dataset, uniprot_ids, uniprot_seqs):
    """Pair UniProt IDs with their sequences and report the unique-ID count.

    Args:
        dataset: label used only in the printed summary line.
        uniprot_ids: iterable of protein IDs.
        uniprot_seqs: iterable of sequences, aligned with ``uniprot_ids``.

    Returns:
        dict mapping each ID to a sequence. An ID that appears several times
        keeps the sequence from its last occurrence, at the position of its
        first occurrence.
    """
    id_to_seq = dict(zip(uniprot_ids, uniprot_seqs))
    print(f"[{dataset}] Uniprot_IDs: {len(id_to_seq)}")
    return id_to_seq

In [4]:
# One {UniProt ID: sequence} dict per dataset; get_info prints each unique-ID count.
PDBbind_protein_seqs_dict = get_info(
    "PDBbind", uniprot_ids=PDBbind_uniprot_IDs, uniprot_seqs=PDBbind_uniprot_seqs)
CASF2016_protein_seqs_dict = get_info(
    "CASF2016", uniprot_ids=CASF2016_uniprot_IDs, uniprot_seqs=CASF2016_uniprot_seqs)
CASF2013_protein_seqs_dict = get_info(
    "CASF2013", uniprot_ids=CASF2013_uniprot_IDs, uniprot_seqs=CASF2013_uniprot_seqs)
CSAR2014_protein_seqs_dict = get_info(
    "CSAR2014", uniprot_ids=CSAR2014_uniprot_IDs, uniprot_seqs=CSAR2014_uniprot_seqs)
CSAR2012_protein_seqs_dict = get_info(
    "CSAR2012", uniprot_ids=CSAR2012_uniprot_IDs, uniprot_seqs=CSAR2012_uniprot_seqs)
CSARset1_protein_seqs_dict = get_info(
    "CSARset1", uniprot_ids=CSARset1_uniprot_IDs, uniprot_seqs=CSARset1_uniprot_seqs)
CSARset2_protein_seqs_dict = get_info(
    "CSARset2", uniprot_ids=CSARset2_uniprot_IDs, uniprot_seqs=CSARset2_uniprot_seqs)
Astex_protein_seqs_dict = get_info(
    "Astex", uniprot_ids=Astex_uniprot_IDs, uniprot_seqs=Astex_uniprot_seqs)
COACH420_protein_seqs_dict = get_info(
    "COACH420", uniprot_ids=COACH420_uniprot_IDs, uniprot_seqs=COACH420_uniprot_seqs)
HOLO4K_protein_seqs_dict = get_info(
    "HOLO4K", uniprot_ids=HOLO4K_uniprot_IDs, uniprot_seqs=HOLO4K_uniprot_seqs)

[PDBbind] Uniprot_IDs: 2422
[CASF2016] Uniprot_IDs: 63
[CASF2013] Uniprot_IDs: 63
[CSAR2014] Uniprot_IDs: 3
[CSAR2012] Uniprot_IDs: 7
[CSARset1] Uniprot_IDs: 107
[CSARset2] Uniprot_IDs: 86
[Astex] Uniprot_IDs: 72
[COACH420] Uniprot_IDs: 239
[HOLO4K] Uniprot_IDs: 1086


In [5]:
# Use the first GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Keep case: do_lower_case=False so residue letters reach the vocabulary
# unchanged (presumably ProtBert's tokens are upper-case letters).
tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case = False)

# Inference only: place the encoder on `device` and switch to eval mode
# (Module.to / Module.eval both return the module itself).
prots_model = BertModel.from_pretrained("Rostlab/prot_bert").to(device).eval()

Some weights of the model checkpoint at Rostlab/prot_bert were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.predictions.decoder.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [6]:
def get_features(protein_seqs_dict, bert_tokenizer = None, bert_model = None,
                 target_device = None, max_residues = 1500):
    """Embed each protein sequence with ProtBert and return per-residue features.

    Before tokenization, each sequence has the letters U, Z, O and B replaced by
    X and is split into space-separated single characters. The first token
    ([CLS]) and the last real token ([SEP]) are dropped from the model's first
    output, and at most ``max_residues`` residue embeddings are kept per protein
    (1500 reproduces the previous hard-coded 1503/1502 cap).

    Args:
        protein_seqs_dict: mapping of protein ID -> amino-acid sequence string.
        bert_tokenizer: tokenizer exposing ``batch_encode_plus``; defaults to the
            module-level ``tokenizer``.
        bert_model: encoder whose output ``[0]`` is the last hidden state;
            defaults to the module-level ``prots_model``.
        target_device: torch device for the forward pass; defaults to the
            module-level ``device``.
        max_residues: maximum number of residue embeddings kept per protein.

    Returns:
        dict mapping each protein ID to a NumPy array of shape
        (number of residue tokens capped at ``max_residues``, hidden size).
    """
    # Fall back to the notebook-level objects so existing callers are unchanged.
    tok = tokenizer if bert_tokenizer is None else bert_tokenizer
    model = prots_model if bert_model is None else bert_model
    dev = device if target_device is None else target_device

    protein_features_dict = dict()
    for PID, seq in protein_seqs_dict.items():
        spaced_seq = " ".join(re.sub(r"[UZOB]", "X", seq))

        # padding="max_length" is the non-deprecated spelling of pad_to_max_length=True.
        ids = tok.batch_encode_plus([spaced_seq], add_special_tokens = True, padding = "max_length")
        input_ids = torch.tensor(ids['input_ids']).to(dev)
        attention_mask = torch.tensor(ids['attention_mask']).to(dev)

        with torch.no_grad():
            embedding = model(input_ids = input_ids, attention_mask = attention_mask)[0]

        # Real (unpadded) token count as a plain int, including [CLS] and [SEP].
        seq_len = int((attention_mask[0] == 1).sum())
        n_residues = min(max(seq_len - 2, 0), max_residues)

        # Slice before leaving the device and force a copy, so the stored array
        # does not keep the whole padded hidden-state buffer alive.
        protein_features_dict[PID] = embedding[0, 1:1 + n_residues].to("cpu", copy = True).numpy()

    return protein_features_dict

In [7]:
# Embed the training (PDBbind) proteins and pickle the {ID: per-residue features} dict.
PDBbind_features_dict = get_features(PDBbind_protein_seqs_dict)
print(f"PDBbind features: {len(PDBbind_features_dict)}")
with open(
    "../../input_data/PDB/BA/Training_protein_features.pkl",
    "wb",
) as features_file:
    pickle.dump(PDBbind_features_dict, features_file)



PDBbind features: 2422


In [8]:
# Embed the CASF2016 proteins and pickle the {ID: per-residue features} dict.
CASF2016_features_dict = get_features(CASF2016_protein_seqs_dict)
# Report under this dataset's own name, like every other feature cell.
print(f"CASF2016 features: {len(CASF2016_features_dict)}")
with open("../../input_data/PDB/BA/CASF2016_protein_features.pkl", "wb") as f:
    pickle.dump(CASF2016_features_dict, f)

PDBbind features: 63


In [10]:
# Embed the CASF2013 proteins and pickle the {ID: per-residue features} dict.
CASF2013_features_dict = get_features(CASF2013_protein_seqs_dict)
print(f"CASF2013 features: {len(CASF2013_features_dict)}")
with open(
    "../../input_data/PDB/BA/CASF2013_protein_features.pkl",
    "wb",
) as features_file:
    pickle.dump(CASF2013_features_dict, features_file)

CASF2013 features: 63


In [11]:
# Embed the CSAR2014 proteins and pickle the {ID: per-residue features} dict.
CSAR2014_features_dict = get_features(CSAR2014_protein_seqs_dict)
print(f"CSAR2014 features: {len(CSAR2014_features_dict)}")
with open(
    "../../input_data/PDB/BA/CSAR2014_protein_features.pkl",
    "wb",
) as features_file:
    pickle.dump(CSAR2014_features_dict, features_file)

CSAR2014 features: 3


In [12]:
# Embed the CSAR2012 proteins and pickle the {ID: per-residue features} dict.
CSAR2012_features_dict = get_features(CSAR2012_protein_seqs_dict)
print(f"CSAR2012 features: {len(CSAR2012_features_dict)}")
with open(
    "../../input_data/PDB/BA/CSAR2012_protein_features.pkl",
    "wb",
) as features_file:
    pickle.dump(CSAR2012_features_dict, features_file)

CSAR2012 features: 7


In [13]:
# Embed the CSARset1 proteins and pickle the {ID: per-residue features} dict.
CSARset1_features_dict = get_features(CSARset1_protein_seqs_dict)
print(f"CSARset1 features: {len(CSARset1_features_dict)}")
with open(
    "../../input_data/PDB/BA/CSARset1_protein_features.pkl",
    "wb",
) as features_file:
    pickle.dump(CSARset1_features_dict, features_file)

CSARset1 features: 107


In [14]:
# Embed the CSARset2 proteins and pickle the {ID: per-residue features} dict.
CSARset2_features_dict = get_features(CSARset2_protein_seqs_dict)
print(f"CSARset2 features: {len(CSARset2_features_dict)}")
with open(
    "../../input_data/PDB/BA/CSARset2_protein_features.pkl",
    "wb",
) as features_file:
    pickle.dump(CSARset2_features_dict, features_file)

CSARset2 features: 86


In [15]:
# Embed the Astex proteins and pickle the {ID: per-residue features} dict.
Astex_features_dict = get_features(Astex_protein_seqs_dict)
print(f"Astex features: {len(Astex_features_dict)}")
with open(
    "../../input_data/PDB/BA/Astex_protein_features.pkl",
    "wb",
) as features_file:
    pickle.dump(Astex_features_dict, features_file)

Astex features: 72


In [16]:
# Embed the COACH420 proteins and pickle the {ID: per-residue features} dict.
COACH420_features_dict = get_features(COACH420_protein_seqs_dict)
print(f"COACH420 features: {len(COACH420_features_dict)}")
with open(
    "../../input_data/PDB/BA/COACH420_protein_features.pkl",
    "wb",
) as features_file:
    pickle.dump(COACH420_features_dict, features_file)

COACH420 features: 239


In [17]:
# Embed the HOLO4K proteins and pickle the {ID: per-residue features} dict.
HOLO4K_features_dict = get_features(HOLO4K_protein_seqs_dict)
print(f"HOLO4K features: {len(HOLO4K_features_dict)}")
with open(
    "../../input_data/PDB/BA/HOLO4K_protein_features.pkl",
    "wb",
) as features_file:
    pickle.dump(HOLO4K_features_dict, features_file)

HOLO4K features: 1086
