In [1]:
import os
import argparse
import subprocess
import sys
import torch
import torch.nn as nn
import utility as u
import models as md
import numpy as np
import time
import pandas as pd

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# --- Run configuration ---------------------------------------------------
# Execution settings handed to the embedding and prediction helpers below.
device = 'cpu'
batch_size = 2

# Directory that receives data.pt and the result_*.csv files, and the sequence
# file passed to the embedding step.
# NOTE: `input` shadows the Python builtin of the same name; the name is kept
# because the embedding cell reads it.
output = '/home/zhenglei/DL_code/enzyme_second_prediction_project/result'
input = '/home/zhenglei/DL_code/enzyme_second_prediction_project/test.fasta'

In [3]:
# Process 1: embed every sequence in `input` and persist the embeddings.
# Create the result directory first: torch.save cannot write into a missing
# folder, and failing only after the slow embedding step would waste the run.
os.makedirs(output, exist_ok=True)

print('Process 1: Embedding')
start = time.time()
# `id` holds the sequence identifiers and `data` the embeddings; both are reused
# by the prediction cells. NOTE: `id` shadows the Python builtin but is kept
# because later cells reference it.
id, data = u.data_embedding_vector(input, batch_size, device)
torch.save(data, f'{output}/data.pt')
end = time.time()
time_str = u.consumption_time(start, end)
print(f'Process 1: Embedding finished. Consumption time: {time_str}.')

Process 1: Embedding
Process 1: Embedding finished. Consumption time: 0h1m35s.


In [24]:
# Process 2: classify every embedded sequence as enzyme / non-enzyme.
print('Process 2: IsEnzyme prediction')
start = time.time()
net = md.MLP_esm3B_1()
# map_location remaps the stored tensors onto the configured device, so a
# checkpoint saved on a GPU still loads when running on a CPU-only machine.
net.load_state_dict(torch.load("./models_param/esm2_3B_MLP_1.pt", map_location=device))
probs, labels = u.prediction(net, data, batch_size, device)
# `probs` is iterated batch by batch, then sample by sample; keep the highest
# score of each sample as the confidence of its predicted label.
probs = [np.array(n).max() for i in probs for n in i]
# Flatten the per-batch label arrays into one list aligned with `id`.
labels = np.concatenate(labels).tolist()
IsEnzyme = u.transfer_name(labels, u.enzyme_dic)
result_IsEnzyme = pd.DataFrame({"id": id, "IsEnzyme": IsEnzyme, "prob": probs})
result_IsEnzyme.to_csv(f'{output}/result_IsEnzyme.csv', index=False)
end = time.time()
time_str = u.consumption_time(start, end)
print(f'Process 2: IsEnzyme prediction finished. Consumption time: {time_str}.')

Process 2: IsEnzyme prediction
Process 2: IsEnzyme prediction finished. Consumption time: 0h0m0s.


In [25]:
# Process 3: predict the first EC digit for every sequence that Process 2
# labelled as an enzyme; non-enzyme rows keep their Process 2 label and score.
print('Process 3: First EC prediction')
start = time.time()
# Boolean mask (one entry per sequence) selecting the predicted enzymes.
bool_list = result_IsEnzyme['IsEnzyme'] == 'enzyme'
# Work on a renamed copy instead of aliasing and renaming in place: the
# Process 2 frame stays intact, so this cell can be re-run without a KeyError
# on the 'IsEnzyme' column.
result_First_EC = result_IsEnzyme.rename(columns={'IsEnzyme': 'First_EC'})
if bool_list.any():
    bool_tensor = torch.tensor(bool_list.to_numpy())
    enzyme_data = data[bool_tensor]
    net = md.MLP_esm3B_2()
    # map_location keeps GPU-saved checkpoints loadable on the configured device.
    net.load_state_dict(torch.load("./models_param/esm2_3B_MLP_2.pt", map_location=device))
    probs, labels = u.prediction(net, enzyme_data, batch_size, device)
    # Highest score per sample = confidence of the predicted first EC digit.
    probs = [np.array(n).max() for i in probs for n in i]
    print(probs)
    labels = np.concatenate(labels).tolist()
    First_EC = u.transfer_name(labels, u.first_enzyme_dic)
    # Overwrite only the enzyme rows; the mask shares the frame's index.
    result_First_EC.loc[bool_list, 'First_EC'] = First_EC
    result_First_EC.loc[bool_list, 'prob'] = probs
# When no sequence was predicted to be an enzyme the model is skipped entirely
# and the frame is written out with the Process 2 labels unchanged.
result_First_EC.to_csv(f'{output}/result_First_EC.csv', index=False, columns=['id', 'First_EC', 'prob'])
end = time.time()
time_str = u.consumption_time(start, end)
print(f'Process 3: First EC prediction finished. Consumption time: {time_str}.')
result_First_EC

Process 3: First EC prediction
[0.99998677, 0.9975073, 0.996336, 0.9997508, 0.99999785]
Process 3: First EC prediction finished. Consumption time: 0h0m0s.


Unnamed: 0,id,First_EC,prob
0,PVP0,non_enzyme,0.999745
1,PVP1,non_enzyme,0.999649
2,nonPVP0,2,0.999987
3,nonPVP1,non_enzyme,0.964391
4,en1,non_enzyme,0.871626
5,en4,4,0.997507
6,en5.1.1,5,0.996336
7,en6.2.1.12,6,0.999751
8,en7.6.2,7,0.999998


In [26]:
# Process 4: predict the second EC digit. Each first-level EC class has its own
# model, weight file and label dictionary; non-enzymes are passed through.
print('Process 4: Second EC prediction')
start = time.time()

# First EC class -> (model class name in `md`, weight file, label-dictionary
# name in `u`). The names are resolved lazily with getattr so that only the
# classes that actually occur in the data are looked up and loaded.
second_ec_models = {
    '1': ('MLP_esm3B_3', "./models_param/esm2_3B_MLP_3.pt", 'second_enzyme_first_dic'),
    '2': ('MLP_esm3B_4', "./models_param/esm2_3B_MLP_4.pt", 'second_enzyme_second_dic'),
    '3': ('MLP_esm3B_5', "./models_param/esm2_3B_MLP_5.pt", 'second_enzyme_third_dic'),
    '4': ('MLP_esm3B_6', "./models_param/esm2_3B_MLP_6.pt", 'second_enzyme_fourth_dic'),
    '5': ('MLP_esm3B_7', "./models_param/esm2_3B_MLP_7.pt", 'second_enzyme_fiveth_dic'),
    '6': ('MLP_esm3B_8', "./models_param/esm2_3B_MLP_8.pt", 'second_enzyme_sixth_dic'),
    '7': ('MLP_esm3B_9', "./models_param/esm2_3B_MLP_9.pt", 'second_enzyme_seventh_dic'),
}
# Networks already built in this run, keyed by first EC class, so each weight
# file is read from disk at most once instead of once per sequence.
loaded_nets = {}

First_EC_value = result_First_EC['First_EC'].tolist()
Second_EC_value = []
Second_EC_prob = []
for i, n in enumerate(First_EC_value):
    if n == 'non_enzyme':
        # No second digit to predict; keep the label and record a zero score.
        Second_EC_value.append(n)
        Second_EC_prob.append(0)
        continue
    if n not in second_ec_models:
        # Skipping the row silently would leave the result lists shorter than
        # `id` and only fail later with an opaque length-mismatch error.
        raise ValueError(f'Unexpected First_EC value {n!r} at row {i}')
    model_name, weight_path, dic_name = second_ec_models[n]
    if n not in loaded_nets:
        net = getattr(md, model_name)()
        # map_location keeps GPU-saved checkpoints loadable on the configured device.
        net.load_state_dict(torch.load(weight_path, map_location=device))
        loaded_nets[n] = net
    net = loaded_nets[n]
    # Index with a list so the single sample keeps its leading batch dimension.
    Second_EC_data = data[[i]]
    probs, labels = u.prediction(net, Second_EC_data, batch_size, device)
    # Highest score of the (single) sample = confidence of its predicted label.
    probs = [np.array(na).max() for ia in probs for na in ia]
    labels = labels[0]
    Second_EC_value.append(u.transfer_name(labels, getattr(u, dic_name))[0])
    Second_EC_prob.append(probs[0])
result_Second_EC = pd.DataFrame({"id": id, "Second_EC": Second_EC_value, "prob": Second_EC_prob})
result_Second_EC.to_csv(f'{output}/result_Second_EC.csv', index=False)
end = time.time()
time_str = u.consumption_time(start, end)
print(f'Process 4: Second EC prediction finished. Consumption time: {time_str}.')

Process 4: Second EC prediction
Process 3: Second EC prediction finished. Consumption time: 0h0m0s.
