In [None]:
%load_ext autoreload
%autoreload 2
import os
import sys
sys.path.append("/Users/chenweijia/Documents/code/nettcr_pytorch/src")

In [75]:
import onnxruntime as ort

import onnx
import torch

import numpy as np
import pandas as pd 
import random
import time

from torch.utils.data import DataLoader
from sklearn.metrics import roc_auc_score

# import dataset
# import constants
from dataset import CustomDataset
from constants import blosum50_20aa


In [10]:
def print_model_info(model_path):
    # Load the ONNX model
    model = onnx.load(model_path)

    # Print the model's input information
    print("Model Inputs:")
    for input_tensor in model.graph.input:
        print(f"Name: {input_tensor.name}")
        # Tensor shape may be None or empty in some ONNX models if not explicitly set
        shape = [dim.dim_value for dim in input_tensor.type.tensor_type.shape.dim]
        print(f"Shape: {shape}")
        print(f"Type: {input_tensor.type.tensor_type.elem_type}")  # This is an integer representing a data type

    # Print the model architecture by iterating over all nodes in the graph
    print("\nModel Architecture:")
    for node in model.graph.node:
        print(f"Node name: {node.name}")
        print(f"Node type: {node.op_type}")
        print(f"Inputs: {node.input}")
        print(f"Outputs: {node.output}")
        print("---")

# Example usage
model_path = 'models/onnx/t.0.v.1.onnx'
print_model_info(model_path)

Model Inputs:
Name: serving_default_a2:0
Shape: [0, 8, 20]
Type: 1
Name: serving_default_a1:0
Shape: [0, 7, 20]
Type: 1
Name: serving_default_b2:0
Shape: [0, 7, 20]
Type: 1
Name: serving_default_b1:0
Shape: [0, 6, 20]
Type: 1
Name: serving_default_a3:0
Shape: [0, 22, 20]
Type: 1
Name: serving_default_b3:0
Shape: [0, 23, 20]
Type: 1
Name: serving_default_pep:0
Shape: [0, 12, 20]
Type: 1

Model Architecture:
Node name: model/conv1d_5/Conv1D/ExpandDims
Node type: Unsqueeze
Inputs: ['serving_default_a1:0', 'const_fold_opt__428']
Outputs: ['model/conv1d_5/Conv1D/ExpandDims']
---
Node name: model/conv1d_9/Relu;model/conv1d_9/BiasAdd;model/conv1d_9/Conv1D/Squeeze;model/conv1d_9/BiasAdd/ReadVariableOp;model/conv1d_4/Conv1D;model/conv1d_9/Conv1D__119
Node type: Transpose
Inputs: ['model/conv1d_5/Conv1D/ExpandDims']
Outputs: ['model/conv1d_9/Relu;model/conv1d_9/BiasAdd;model/conv1d_9/Conv1D/Squeeze;model/conv1d_9/BiasAdd/ReadVariableOp;model/conv1d_4/Conv1D;model/conv1d_9/Conv1D__119:0']
---
Nod

In [11]:
def load_model(path):
    # Create an inference session with onnxruntime
    return ort.InferenceSession(path)

def compute_metrics(y_true, y_pred):
    accuracy = (np.array(y_pred) == np.array(y_true)).mean()
    auc = roc_auc_score(y_true, y_pred)
    return accuracy, auc

def eval(test_data, outdir, model_name, model_type, seed=15, batch_size=64):
    # Set random seed
    np.random.seed(seed)  # Numpy module.
    random.seed(seed)  # Python random module.
    # Read in data
    test_df = pd.read_csv(test_data)
    test_data = CustomDataset(test_data, ENCODING)
    data_loader = DataLoader(test_data, batch_size=batch_size, shuffle=True)
    pep_list = list(test_df.peptide.value_counts(ascending=False).index)
    

In [12]:
session = load_model("models/onnx/t.0.v.1.onnx")
input_name = session.get_inputs()[0].name  # Get the input name for the model

In [14]:
ENCODING = blosum50_20aa 

In [42]:
test_data_path = "/Users/chenweijia/Documents/code/nettcr_pytorch/data/train_example.csv"
test_data = CustomDataset(test_data_path, ENCODING, isEval=True)
data_loader = DataLoader(test_data, batch_size=64, shuffle=True)

In [71]:
y_pred = []
y_true = []

In [77]:
input_data = {k:v.cpu().numpy() for k,v in data.items()}

In [80]:
for batch_idx, (data, labels, weight) in enumerate(data_loader):
    # Convert data to a format suitable for ONNX runtime (numpy)
    # data = [i.to(device) for i in data]
    # labels = labels.to(device)
    # Run inference
    outputs = session.run(None, input_data)

    # Convert outputs to tensor and get predictions
    outputs_tensor = torch.tensor(outputs[0])
    preds = (outputs_tensor > 0.5).to(torch.float32)

    preds.squeeze()

    # Store predictions and true labels
    y_pred.extend(preds.numpy())
    y_true.extend(labels.numpy())

    print(f"Batch id: {batch_idx}/{len(data_loader)}")
    accuracy, auc = compute_metrics(y_true, y_pred)
    print(f"Accuracy: {accuracy}, AUC: {auc}")

Batch id: 0/324
Accuracy: 0.7377431441326531, AUC: 0.5347397260273972
Batch id: 1/324
Accuracy: 0.7295074462890625, AUC: 0.5120892494929007
Batch id: 2/324
Accuracy: 0.7279730902777778, AUC: 0.511428448442847
Batch id: 3/324
Accuracy: 0.724599609375, AUC: 0.5311166398866601
Batch id: 4/324
Accuracy: 0.7237780862603306, AUC: 0.5289736407383465
Batch id: 5/324
Accuracy: 0.7195841471354166, AUC: 0.5280978035640282
Batch id: 6/324
Accuracy: 0.7192758413461539, AUC: 0.5222984562607204
Batch id: 7/324
Accuracy: 0.7197464923469388, AUC: 0.5220005659309565
Batch id: 8/324
Accuracy: 0.7153125, AUC: 0.5348124098124099


KeyboardInterrupt: 