In [1]:
import torch
from torch import nn
import torch.nn.functional as F
from tape import ProteinBertModel, TAPETokenizer
import pandas as pd
import numpy as np
import os,re,math

# Run on the first GPU when CUDA is available; otherwise fall back to the CPU so
# the notebook can still execute on machines without a GPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

### Load the pre-trained TAPE BERT protein model

In [2]:
# Pre-trained TAPE BERT encoder plus the matching IUPAC amino-acid tokenizer.
# .eval() makes inference mode explicit (any dropout layers switch to inference
# behavior) regardless of the mode from_pretrained leaves the model in.
model = ProteinBertModel.from_pretrained('model/tape_bert/').to(device).eval()
tokenizer = TAPETokenizer(vocab='iupac')

def seq_embed(seq):
    """Embed one protein sequence with the pre-trained TAPE BERT model.

    Args:
        seq: amino-acid sequence string, tokenized with the module-level
            ``tokenizer`` (built with ``vocab='iupac'``).

    Returns:
        The first element of the model output for the single-sequence batch:
        one row per token, including any special tokens the tokenizer adds
        (presumably start/end markers -- the caller later slices ``[1:-1]``).
        The tensor lives on ``device`` and does not track gradients.
    """
    token_ids = torch.tensor([tokenizer.encode(seq)]).to(device)
    # Inference only: skip building an autograd graph to save (GPU) memory.
    with torch.no_grad():
        seq_bert = model(token_ids)[0][0]
    return seq_bert

### Degpred architecture

In [3]:
class DEG_LSTM(nn.Module):
    """Per-position scorer: 2-layer bidirectional LSTM -> Linear -> Linear -> sigmoid.

    NOTE(review): the ensemble members loaded after this class appear to be
    pickled instances of it. Unpickling resolves the class by name and uses
    this definition's ``forward``, so the attribute names (``deg_lstm``,
    ``deg_fc1``, ``deg_fc2``) and the computation must stay compatible with
    the saved checkpoints.
    """

    def __init__(self, input_size, deg_lstm_hidden_size, fc1_output_size, output_size):
        super().__init__()
        # Forward and backward directions are concatenated on the last axis.
        lstm_out_features = 2 * deg_lstm_hidden_size
        self.deg_lstm = nn.LSTM(
            input_size,
            deg_lstm_hidden_size,
            num_layers=2,
            bidirectional=True,
            batch_first=True,
        )
        self.deg_fc1 = nn.Linear(lstm_out_features, fc1_output_size)
        self.deg_fc2 = nn.Linear(fc1_output_size, output_size)

    def forward(self, input):
        """Map ``input`` of shape (batch, seq_len, input_size) to scores in [0, 1].

        Returns a tensor of shape (batch, seq_len, output_size).
        """
        hidden_states, _ = self.deg_lstm(input)
        # NOTE(review): there is no activation between the two linear layers,
        # so they compose to one affine map; kept to match the trained weights.
        logits = self.deg_fc2(self.deg_fc1(hidden_states))
        return torch.sigmoid(logits)

def _load_degpred_model(index):
    """Load ensemble member ``index`` (1-5) onto ``device`` in eval mode."""
    path = f'model/five_model/degpred_model{index}.pkl'
    # NOTE(review): these checkpoints are whole pickled modules, so loading runs
    # arbitrary unpickling code -- only use files from a trusted source.
    # Unpickling presumably needs DEG_LSTM to be defined first, hence the class
    # definition above. map_location lets checkpoints saved on another device
    # load onto `device`.
    try:
        # torch >= 2.6 defaults to weights_only=True, which rejects pickled
        # nn.Module objects such as these; opt out explicitly.
        net = torch.load(path, map_location=device, weights_only=False)
    except TypeError:
        # Older torch versions do not accept the weights_only keyword.
        net = torch.load(path, map_location=device)
    # Inference only: make sure training-only behavior is switched off.
    return net.to(device).eval()

lstm1, lstm2, lstm3, lstm4, lstm5 = (_load_degpred_model(k) for k in range(1, 6))

In [4]:
def degpred(seq_bert, models=None):
    """Average the outputs of the DegPred ensemble on one input.

    Args:
        seq_bert: tensor passed unchanged to every model (in this notebook,
            the TAPE embedding with a batch dimension added).
        models: iterable of callables to ensemble. Defaults to the five
            loaded networks ``lstm1`` ... ``lstm5``.

    Returns:
        The element-wise mean of the models' outputs, computed without autograd.

    Raises:
        ValueError: if ``models`` is empty.
    """
    if models is None:
        models = (lstm1, lstm2, lstm3, lstm4, lstm5)
    models = list(models)
    if not models:
        raise ValueError('degpred needs at least one model')
    # Inference only: no autograd graph is needed, which saves (GPU) memory.
    with torch.no_grad():
        preds = [m(seq_bert) for m in models]
        # sum() adds left to right, matching (pred1 + ... + pred5) / 5.
        return sum(preds) / len(preds)

### Predict per-residue degron probabilities for a sequence (example: p53)

In [14]:
seq = 'MEEPQSDPSVEPPLSQETFSDLWKLLPENNVLSPLPSQAMDDLMLSPDDIEQWFTEDPGPDEAPRMPEAAPPVAPAPAAPTPAAPAPAPSWPLSSSVPSQKTYQGSYGFRLGFLHSGTAKSVTCTYSPALNKMFCQLAKTCPVQLWVDSTPPPGTRVRAMAIYKQSQHMTEVVRRCPHHERCSDSDGLAPPQHLIRVEGNLRVEYLDDRNTFRHSVVVPYEPPEVGSDCTTIHYNYMCNSSCMGGMNRRPILTIITLEDSSGNLLGRNSFEVRVCACPGRDRRTEEENLRKKGEPHHELPPGSTKRALPNNTSSSPQPKKKPLDGEYFTLQIRGRERFEMFRELNEALELKDAQAGKEPGGSRAHSSHLKSKKGQSTSRHKKLMFKTEGPDSD'
seq_bert = seq_embed(seq)
# Add a batch dimension, average the ensemble, then drop the first and last
# token positions (presumably tokenizer special tokens) so pred[k] scores seq[k].
pred = degpred(seq_bert.unsqueeze(0)).squeeze().cpu().detach().numpy()[1:-1]
# Later cells index `seq` with positions from `pred`; guard that alignment.
assert len(pred) == len(seq), (len(pred), len(seq))

### Find degron regions in the sequence

In [10]:
def continusFind(num_list, max_gap=4):
    """Group ascending positions into runs of nearby positions.

    Consecutive entries whose difference is smaller than ``max_gap`` are merged
    into one run (with the default, gaps of up to 3 are bridged).

    Args:
        num_list: ascending sequence of positions (e.g. residue indices whose
            score is above a threshold).
        max_gap: consecutive positions closer than this belong to the same run.
            Defaults to 4, the value previously hard-coded.

    Returns:
        One entry per run, in input order: ``[first, last]`` for a run of two or
        more positions, ``[pos]`` for an isolated position. Every input position
        is covered, including an isolated final one (previously dropped).
    """
    find_list = []
    n = len(num_list)
    start = 0
    while start < n:
        end = start
        # Extend the run while the next position is within max_gap.
        while end + 1 < n and num_list[end + 1] - num_list[end] < max_gap:
            end += 1
        if end > start:
            find_list.append([num_list[start], num_list[end]])
        else:
            find_list.append([num_list[start]])
        start = end + 1
    return find_list

In [15]:
thre = 0.3  # per-residue probability cutoff for calling a position degron-like
deg_position = list((pred > thre).nonzero()[0])
# Merge nearby above-threshold positions into intervals; keep only intervals
# whose last and first positions are more than 2 apart.
deg_interval = [
    interval
    for interval in continusFind(deg_position)
    if interval[-1] - interval[0] > 2
]

# Positions are 0-based indices into `seq`; the slice end is made inclusive.
for interval in deg_interval:
    start, end = interval[0], interval[1]
    print('start:', start, 'end: ', end, 'degron_seq:', seq[start: end+1])

start: 16 end:  27 degron_seq: ETFSDLWKLLPE
start: 256 end:  262 degron_seq: LEDSSGN
start: 282 end:  290 degron_seq: RTEEENLRK


### Predict E3 ubiquitin ligases that may bind each degron

In [None]:
# E3 motif table: each row names a PSSM file motifs/<E3_entry>_<length>_pssm.txt.
e3 = pd.read_csv('motifs/pssm_cutoffs.csv')


def _read_pssm(entry, length):
    """Read one motif PSSM; its column labels become integer motif positions."""
    table = pd.read_table(f'motifs/{entry}_{length}_pssm.txt', index_col=0)
    table.columns = table.columns.astype(int)
    return table


e3['pssm'] = [
    _read_pssm(entry, length)
    for entry, length in zip(e3['E3_entry'], e3['length'])
]

In [None]:
def findE3(dseq, e3_table=None):
    """Return the names of E3s whose motif PSSM matches ``dseq``.

    For each row of the E3 table, the row's PSSM (``length`` positions wide) is
    slid along ``dseq``. Each window scores the sum of its per-position PSSM
    entries. The row's ``E3`` name is reported when the best window score
    exceeds the row's ``thre1000`` cutoff.

    Args:
        dseq: amino-acid sequence to scan (here, a degron plus flanking residues).
        e3_table: DataFrame with columns ``E3``, ``length``, ``thre1000`` and
            ``pssm`` (a DataFrame indexed by residue letter with integer columns
            ``1..length``). Defaults to the module-level ``e3`` table.

    Returns:
        List of matching ``E3`` names, in table order.

    A row whose PSSM lacks a residue that occurs in ``dseq`` is reported on
    stdout and skipped. A sequence shorter than a motif never matches it.
    """
    if e3_table is None:
        e3_table = e3
    e3s = []
    for i in e3_table.index:
        ps = e3_table.loc[i, 'pssm']
        length = e3_table.loc[i, 'length']
        try:
            scores = [
                sum(ps.at[dseq[k + j], j + 1] for j in range(length))
                for k in range(len(dseq) - length + 1)
            ]
        except KeyError:
            # Residue (e.g. 'X') not present in this PSSM's index.
            print('error in degron sequence', dseq)
            continue
        # No window fits -> no evidence. Do not seed the max with 0, which would
        # let an empty or all-negative scan beat a non-positive cutoff.
        if scores and max(scores) > e3_table.loc[i, 'thre1000']:
            e3s.append(e3_table.loc[i, 'E3'])
    return e3s

In [None]:
# Scan each degron plus up to 3 flanking residues on either side (clipped at the
# sequence ends) for E3 motif matches.
for interval in deg_interval:
    start, end = interval[0], interval[1]
    context = seq[max(0, start - 3): end + 4]
    print(interval, seq[start: end+1], findE3(context))