In [11]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
from sklearn.metrics import roc_auc_score, balanced_accuracy_score
from sklearn.model_selection import KFold, StratifiedKFold
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import numpy as np
import os
import sys
from tqdm import tqdm
import pandas as pd
import plotly.express as px
import matplotlib.pyplot as plt
from transformers import T5Tokenizer, T5EncoderModel
from Bio import SeqIO
import re

# Compute per-residue ProtT5 embeddings for the FASTA proteome

In [27]:
# Load the ProtT5-XL encoder and its tokenizer once (T5Tokenizer / T5EncoderModel
# come from the imports cell). `get_embeddings_aa` reads the globals `tokenizer`,
# `model` and `device` defined here, so do not rebind `model` later in the notebook.
PROT_T5_NAME = 'Rostlab/prot_t5_xl_half_uniref50-enc'
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
tokenizer = T5Tokenizer.from_pretrained(PROT_T5_NAME, do_lower_case=False)
model = T5EncoderModel.from_pretrained(PROT_T5_NAME).to(device).eval();  # inference only

In [28]:
def get_embeddings_aa(seq, encoder=None, tok=None):
    """Return per-residue encoder embeddings for a single amino-acid sequence.

    Residues U, Z, O and B are replaced with X. Residues are space-separated,
    so the tokenizer sees one token per residue, and the sequence is encoded
    as a batch of one.

    Args:
        seq: amino-acid string.
        encoder: encoder to run; defaults to the global ``model``. It must
            expose ``.device`` and return an object with ``last_hidden_state``.
            Passing it explicitly avoids depending on a global that later
            cells might rebind.
        tok: tokenizer with ``batch_encode_plus``; defaults to the global
            ``tokenizer``.

    Returns:
        Tensor with one row per residue (``len(seq)`` rows), on the encoder's
        device. Rows past ``len(seq)`` are dropped, which removes any trailing
        special tokens the tokenizer appended.
    """
    encoder = model if encoder is None else encoder
    tok = tokenizer if tok is None else tok

    spaced = " ".join(re.sub(r"[UZOB]", "X", seq))
    encoded = tok.batch_encode_plus([spaced], add_special_tokens=True, padding="longest")

    input_ids = torch.tensor(encoded['input_ids'], device=encoder.device)
    attention_mask = torch.tensor(encoded['attention_mask'], device=encoder.device)

    with torch.no_grad():  # inference only: no autograd graph
        hidden = encoder(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state

    # Batch of one: take row 0 and keep exactly one vector per input residue.
    return hidden[0, :len(seq)]

In [29]:
# Read every record of the proteome FASTA file into memory; the cell displays the count.
fn = 'UP000005640_9606.fasta'
with open(fn) as fasta_handle:
    test_seq = list(SeqIO.parse(fasta_handle, "fasta"))
len(test_seq)

20586

In [30]:
# Embed every protein. The encoder is run on chunks of at most CHUNK_LEN residues
# (presumably to bound memory; TODO confirm). The per-residue embeddings are then
# concatenated back into one array per protein. After this cell, test_prots[k]
# belongs to ids[k].
CHUNK_LEN = 1000

ids = []
test_prots = []
for record in tqdm(test_seq):
    sequence = str(record.seq)
    # Chunk starts never produce an empty trailing chunk (the old `// 1000 + 1`
    # loop did when the length was a multiple of 1000). A zero-length sequence
    # still gets a single (empty) chunk.
    starts = range(0, len(sequence), CHUNK_LEN) or range(1)
    chunks = [get_embeddings_aa(sequence[s:s + CHUNK_LEN]) for s in starts]
    # Append the id only after embedding succeeded, so ids and test_prots stay aligned.
    ids.append(record.id)
    test_prots.append(torch.cat(chunks, dim=0).detach().cpu().numpy())

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20586/20586 [1:31:26<00:00,  3.75it/s]


# Load the cross-validation fold models and score each protein with sliding windows

In [31]:
class DubFC(nn.Module):
    """Binary classifier head: BatchNorm -> 3 ReLU layers -> linear -> sigmoid.

    Args:
        input_size: width of each input embedding vector.
        hidden_size: width of the hidden layers.
        dropout: dropout probability, applied after the first two hidden activations.

    ``forward(emb)`` is used in this notebook with a 2-D ``(batch, input_size)``
    float tensor. It returns a ``(batch, 1)`` tensor of sigmoid scores in (0, 1).
    """

    def __init__(self, input_size, hidden_size, dropout):
        super().__init__()
        # Attribute names are the state_dict keys of the saved checkpoints, and
        # creation order determines seeded initialisation; keep both unchanged.
        self.bn = nn.BatchNorm1d(input_size)
        self.dropout = nn.Dropout(dropout)
        self.fc1 = nn.Linear(input_size, hidden_size)
        # NOTE: forward() applies this single hidden->hidden layer twice (shared
        # weights). Adding a separate layer would break load_state_dict on
        # existing checkpoints.
        self.fcint = nn.Linear(hidden_size, hidden_size)
        self.fcout = nn.Linear(hidden_size, 1)
        self.relu = nn.ReLU()
        self.activation = nn.Sigmoid()

    def forward(self, emb):
        hidden = self.dropout(self.relu(self.fc1(self.bn(emb))))
        hidden = self.dropout(self.relu(self.fcint(hidden)))
        hidden = self.relu(self.fcint(hidden))  # same fcint again; no dropout after it
        logits = self.fcout(hidden)
        return self.activation(logits)

In [32]:
# DubFC architecture. The size entries must match the saved fold checkpoints,
# because load_state_dict is called with its default strict matching.
EMB_SIZE = 1024     # per-residue embedding width
HIDDEN_SIZE = 512
DROPOUT = 0.5

params = dict(input_size=EMB_SIZE, hidden_size=HIDDEN_SIZE, dropout=DROPOUT)

In [33]:
# Score every protein with the N_FOLDS DubFC ensemble.
# For each window size, evenly spaced windows are mean-pooled and scored. The
# whole-protein mean embedding is scored too, under the key "0-<length>". For
# each protein, the TOP_K best scores are kept in win_pred[protein_id], keyed
# by "start-end" and ordered by ascending score.

# NOTE(review): absolute, user-specific path; consider a config cell or env var.
CHECKPOINT_TEMPLATE = "/home/alexey_bondarev/DUB/torch_metrics/domain_human_fold{fold}.pt"
N_FOLDS = 5
WINDOWS = [30, 50, 100, 130, 180, 250, 340, 400, 500, 600]
STEP_DIVISOR = 19  # step = (length - window) // 19 -> roughly 20 windows per size
MIN_MARGIN = 21    # a window size is used only if window + MIN_MARGIN < length
TOP_K = 3


def _load_fold_models():
    """Load every fold checkpoint into a DubFC once, moved to `device`, in eval mode."""
    fold_models = []
    for fold in range(1, N_FOLDS + 1):
        state = torch.load(CHECKPOINT_TEMPLATE.format(fold=fold), map_location='cpu')
        clf = DubFC(**params)
        clf.load_state_dict(state)
        fold_models.append(clf.to(device).eval())
    return fold_models


def _ensemble_scores(fold_models, embeddings):
    """Return the mean fold-model score for each row of a 2-D embedding array, as floats."""
    batch = torch.from_numpy(np.asarray(embeddings)).float().to(device)
    with torch.no_grad():  # inference only: no autograd graph
        # Accumulate in float64, matching the original float64 accumulator.
        total = sum(clf(batch).double() for clf in fold_models)
    return (total / len(fold_models))[:, 0].cpu().tolist()


def _window_scores(fold_models, prot_emb):
    """Map 'start-end' position strings to ensemble scores for one protein."""
    length = prot_emb.shape[0]
    scores = {}
    for window in WINDOWS:
        if window + MIN_MARGIN >= length:
            continue
        # >= 1 with the current constants; the guard protects against edits to them.
        step = max(1, (length - window) // STEP_DIVISOR)
        starts = range(0, length - window + 1, step)
        pooled = np.stack([prot_emb[s:s + window].mean(axis=0) for s in starts])
        keys = [f"{s}-{s + window}" for s in starts]
        scores.update(zip(keys, _ensemble_scores(fold_models, pooled)))
    scores[f"0-{length}"] = _ensemble_scores(fold_models, prot_emb.mean(axis=0, keepdims=True))[0]
    return scores


def _top_k_ascending(scores, k=TOP_K):
    """Keep exactly the k highest-scoring entries (ties included), in ascending score order."""
    best = sorted(scores.items(), key=lambda item: item[1], reverse=True)[:k]
    return dict(reversed(best))


# Load the checkpoints once. The name `model` is deliberately not reused:
# it must keep referring to the ProtT5 encoder.
fold_models = _load_fold_models()
win_pred = {}
for prot_id, prot_emb in tqdm(zip(ids, test_prots), total=len(test_prots)):
    win_pred[prot_id] = _top_k_ascending(_window_scores(fold_models, prot_emb))

20586it [1:26:24,  3.97it/s]


In [34]:
# Distribution of each protein's best retained window score.
max_values = {prot_id: max(scores.values()) for prot_id, scores in win_pred.items()}

fig = px.histogram(
    x=list(max_values.values()),  # pass the scores, not the dict itself
    nbins=50,
    title='Predictions distribution',
    labels={'x': 'max window score'},
)
fig.show()

# Define the threshold, add annotations and save the prediction list

In [35]:
# Keep proteins whose best retained window score reaches the threshold,
# mapping protein id -> that best score.
threshold = 0.05

win_threshold = {}
for prot_id, scores in win_pred.items():
    best = max(scores.values())  # no need to sort the whole dict just to take the max
    if best >= threshold:
        win_threshold[prot_id] = best

In [36]:
# Annotation lookup from a headerless CSV: column 0 -> column 1. The keys are
# matched against the accession field of the FASTA IDs in the result table.
df_list = pd.read_csv('human_protein_list.csv', header=None)
proteins_dict = dict(zip(df_list[0], df_list[1]))

# Record names of the reference DUBs, collected from every regular file in
# human_fasta/ except combined.fa. A missing combined.fa or a stray
# subdirectory (e.g. .ipynb_checkpoints) no longer crashes the cell.
files_dir = "human_fasta/"
files = [
    os.path.join(files_dir, name)
    for name in sorted(os.listdir(files_dir))
    if name != 'combined.fa' and os.path.isfile(os.path.join(files_dir, name))
]
dub_human = [record.name for path in files for record in SeqIO.parse(path, "fasta")]

In [37]:
def _accession(name):
    """Return the middle field of a 'db|ACCESSION|ENTRY' id, or the whole id if it has no '|'."""
    parts = name.split('|')
    return parts[1] if len(parts) > 1 else name


# One row per protein above the threshold, with the known-DUB flag, the retained
# window scores and the annotation, sorted by best score. Building the frame
# from (name, score) pairs keeps the columns even when win_threshold is empty.
df_out = (
    pd.DataFrame(list(win_threshold.items()), columns=['name', 'score'])
    .assign(
        is_dub=lambda d: np.where(d['name'].isin(dub_human), 'dub', 'no'),
        scores=lambda d: d['name'].map(lambda n: win_pred[n]),
        annotation=lambda d: d['name'].map(lambda n: proteins_dict.get(_accession(n), '')),
    )
    .sort_values(by='score', ascending=False)
    .reset_index(drop=True)
)
df_out

Unnamed: 0,name,score,is_dub,scores,annotation
0,sp|Q9NQC7|CYLD_HUMAN,0.998386,dub,"{'544-884': 0.9983218908309937, '629-879': 0.9...",Ubiquitin carboxyl-terminal hydrolase CYLD
1,sp|Q70CQ2|UBP34_HUMAN,0.998289,dub,"{'1815-2215': 0.9975494265556336, '1903-2153':...",Ubiquitin carboxyl-terminal hydrolase 34
2,sp|Q70EK8|UBP53_HUMAN,0.998226,dub,"{'86-336': 0.9981784939765931, '270-300': 0.99...",Inactive ubiquitin carboxyl-terminal hydrolase 53
3,sp|P09936|UCHL1_HUMAN,0.998221,dub,"{'60-160': 0.998074734210968, '78-178': 0.9981...",Ubiquitin carboxyl-terminal hydrolase isozyme L1
4,sp|Q504Q3|PAN2_HUMAN,0.998155,dub,"{'462-862': 0.9974869370460511, '504-904': 0.9...",PAN2-PAN3 deadenylation complex catalytic subu...
...,...,...,...,...,...
1238,sp|P54253|ATX1_HUMAN,0.050786,no,"{'399-799': 0.02540142107754946, '425-765': 0....",Ataxin-1
1239,sp|Q9BZA8|PC11Y_HUMAN,0.050359,no,"{'833-1233': 0.04338116163053201, '884-1224': ...",Protocadherin-11 Y-linked
1240,sp|P52735|VAV2_HUMAN,0.050307,no,"{'112-712': 0.04642800716683269, '126-726': 0....",Guanine nucleotide exchange factor VAV2
1241,sp|Q6UXB0|F131A_HUMAN,0.050224,no,"{'221-251': 0.015844495084820664, '204-334': 0...",Protein FAM131A
