In [1]:

from torch_geometric.data import DataLoader
from tqdm import tqdm
import torch, os
from model import get_inter_feature, main_model
from src import data_loader, model_utils, metric
# Select the compute device.
# The original experiments ran on GPU 7 of a multi-GPU host. torch.device("cuda:7") can be
# constructed on any CUDA machine, but moving a tensor to it fails with an invalid-device-ordinal
# error when fewer than 8 GPUs exist. Use the preferred index only when it exists, otherwise
# fall back to GPU 0, and to CPU when CUDA is unavailable.
# Override the preferred GPU with the CUDA_DEVICE_INDEX environment variable.
PREFERRED_GPU = int(os.environ.get("CUDA_DEVICE_INDEX", "7"))
if torch.cuda.is_available():
    gpu_index = PREFERRED_GPU if 0 <= PREFERRED_GPU < torch.cuda.device_count() else 0
    device = torch.device(f"cuda:{gpu_index}")
else:
    device = torch.device("cpu")

# Report the PID so a long-running job can be located / killed from the shell.
current_pid = os.getpid()
print(f"THE PROCESS IS:{current_pid}")

THE PROCESS IS:987725


In [2]:
# Run configuration.
#   batch_size    - number of protein graphs per DataLoader batch.
#   epoches       - training-epoch count; not referenced by any inference cell in this notebook.
#   learning_rate - only used to build the optimizer whose state is restored from the checkpoint.
batch_size, epoches, learning_rate = 1, 50, 1e-4

In [3]:
# Demo proteins shipped with the repository.
# shuffle=False keeps the printed predictions in a reproducible order.
# drop_last=False keeps every protein: this loader is used only for inference, and with
# drop_last=True any batch_size that does not evenly divide the dataset size would silently
# skip the final proteins. (With batch_size=1 both settings yield the same batches.)
demo_dataset_path = "./src/demo_dataset/"
X_demo = data_loader.protein_dataset(demo_dataset_path)
demo_loader = DataLoader(X_demo, batch_size=batch_size, shuffle=False, drop_last=False)



In [4]:
# Width of the per-node feature vector fed to the graph branch. It has to match the last
# dimension of torch.cat([one_hot_seq, h_V_geo, DSSP]) built in the inference loop; the three
# terms are presumably the widths of those three tensors in that order — TODO confirm in src.
NODE_FEAT_SIZE = 20 + 184 + 9

model = main_model.CL_interpro_model(
    inter_size=18847,
    inter_hid=1280,
    graph_size=NODE_FEAT_SIZE,
    graph_hid=1280,
    seq_size=1280,
    seq_hid=1280,
    label_num=10,  # one output per subcellular location printed by the inference loop
    head=4,
).to(device)

# The optimizer is built only so load_ckp can receive (and return) it; no optimizer step
# appears anywhere in this notebook.
optim = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)

In [5]:
# Restore the pretrained checkpoint. load_ckp hands back four values (see the unpacking
# below); current_epoch and min_val_loss are kept for inspection but not used further here.
# NOTE(review): checkpoint loaders are usually pickle-based (torch.load) — only load
# checkpoint files from a trusted source.
ckp = "./src/ckp/dataset0_ckp.pt"
restored = model_utils.load_ckp(ckp, model, optim, device=device)
model, optim, current_epoch, min_val_loss = restored

In [6]:
# Subcellular locations in the order of the model's output columns (label_num=10).
# The order reproduces the original hard-coded print format exactly.
LOCATION_NAMES = [
    "Cytoplasm", "Nucleus", "Extracellular", "Cell membrane", "Mitochondrion",
    "Plastid", "Endoplasmic reticulum", "Lysosome/Vacuole", "Golgi apparatus", "Peroxisome",
]

# Run the demo set through the model and print per-location probabilities plus the true labels.
model.eval()
with torch.no_grad():
    for data in demo_loader:
        esm_tokens = data.esm_tokens.to(device, torch.float32)
        esm_representations = data.esm_representations.to(device)
        edge_index = data.edge_index.to(device)
        interpro = data.interpro.to(device)
        inter_features = get_inter_feature.get_interpro_data(interpro, device)
        label = data.SL_label.float()
        # Node-to-graph assignment for the batched graphs.
        batch = data.batch.to(device)
        # Per-node features; the concatenated width must equal graph_size given to the model.
        node_feat = torch.cat([data.one_hot_seq, data.h_V_geo, data.DSSP], dim=-1).to(device)
        ID = data["ID"]
        structure_embedding, sequence_embedding, y_pred = model(
            esm_tokens, edge_index, batch, inter_features, esm_representations, node_feat
        )
        # Independent sigmoid per location (multi-label). Copy to CPU once rather than forcing
        # a device sync for every formatted value; no_grad already makes detach() unnecessary.
        probs = torch.sigmoid(y_pred).to(torch.float32).cpu()
        # Report every protein in the batch (the original only printed index 0).
        for i, protein_id in enumerate(ID):
            scores = ", ".join(
                f"{name}:{p:.3f}" for name, p in zip(LOCATION_NAMES, probs[i].tolist())
            )
            print(f"ID:{protein_id}, {scores}")
            # SL_label appears to be (num_graphs, label_num) per the printed output; slicing
            # keeps it 2-D so a single-protein batch prints exactly as before.
            print(label[i:i + 1])



ID:P39707, Cytoplasm:0.119, Nucleus:0.750, Extracellular:0.000, Cell membrane:0.000, Mitochondrion:0.899, Plastid:0.000, Endoplasmic reticulum:0.001, Lysosome/Vacuole:0.000, Golgi apparatus:0.000, Peroxisome:0.000
tensor([[0., 1., 0., 0., 1., 0., 0., 0., 0., 0.]])
ID:P39655, Cytoplasm:0.988, Nucleus:0.239, Extracellular:0.003, Cell membrane:0.306, Mitochondrion:0.003, Plastid:0.000, Endoplasmic reticulum:0.009, Lysosome/Vacuole:0.003, Golgi apparatus:0.031, Peroxisome:0.000
tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])
ID:P39677, Cytoplasm:0.002, Nucleus:0.000, Extracellular:0.000, Cell membrane:0.000, Mitochondrion:0.999, Plastid:0.000, Endoplasmic reticulum:0.000, Lysosome/Vacuole:0.000, Golgi apparatus:0.000, Peroxisome:0.000
tensor([[0., 0., 0., 0., 1., 0., 0., 0., 0., 0.]])
ID:P39683, Cytoplasm:0.993, Nucleus:0.787, Extracellular:0.000, Cell membrane:0.004, Mitochondrion:0.003, Plastid:0.000, Endoplasmic reticulum:0.000, Lysosome/Vacuole:0.001, Golgi apparatus:0.000, Peroxiso