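## Model predictions on the NSynth validation set

This notebook runs the trained model on the pickled validation features in two ways and compares the resulting accuracies:

- through ONNX Runtime, using the exported `trained_model.onnx`;
- as a PyTorch module rebuilt with `onnx2pytorch`, with the trained weights loaded from a checkpoint.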

In [1]:
import onnx
import onnxruntime as rt
import numpy as np
import matplotlib.pyplot as plt
import pickle
import json

In [2]:
# Load the exported ONNX graph, the pickled validation split and the label -> class-index map.
# NOTE: pickle.load can execute arbitrary code; only open pickle files you produced yourself.
model = onnx.load_model("./trained_model.onnx")

with open("./validation_features.pickle", "rb") as features_file:
    test_features, test_labels = pickle.load(features_file)

with open("./nsynth_train/class_map.json", "r") as class_map_file:
    class_map = json.load(class_map_file)

# Stack the features into a single float32 array, then insert a singleton channel axis at position 1.
inputs = np.expand_dims(np.array(test_features, dtype=np.float32), axis=1)

In [3]:
# Sanity check of the model input layout: (num_samples, 1, H, W).
np.shape(inputs)

(12678, 1, 28, 28)

## Accuracy of the ONNX model (ONNX Runtime, CPU)

In [14]:
# Evaluate the exported model with ONNX Runtime on the CPU, feeding the whole validation set as one batch.
providers = ['CPUExecutionProvider']
onnx_file_path = "trained_model.onnx"
m = rt.InferenceSession(onnx_file_path, providers=providers)

# Take the input/output names from the session itself so they always match the model actually being run,
# instead of hard-coding "input" or reading names from a separately loaded graph.
# NOTE(review): assumes the model has a single input tensor.
input_name = m.get_inputs()[0].name
output_names = [o.name for o in m.get_outputs()]
onnx_pred = m.run(output_names, {input_name: inputs})[0]

# Vectorised scoring.
# Predicted class = argmax over each sample's flattened output (same as np.argmax on each row).
# True class = class_map lookup of the label.
predicted_classes = onnx_pred.reshape(len(onnx_pred), -1).argmax(axis=1)
true_classes = np.array([class_map[label] for label in test_labels])
assert len(true_classes) == len(predicted_classes), "number of labels and predictions differ"

is_correct = predicted_classes == true_classes
correct_preds_onnx = np.flatnonzero(is_correct).tolist()
false_preds_onnx = np.flatnonzero(~is_correct).tolist()

print(f"Correct predictions: {len(correct_preds_onnx)}, False predictions: {len(false_preds_onnx)}")
accuracy = len(correct_preds_onnx) / (len(correct_preds_onnx) + len(false_preds_onnx))
print(f"Accuracy: {accuracy}")

Correct predictions: 6786, False predictions: 5892
Accuracy: 0.5352579271178419


In [16]:
# First ten validation indices that the ONNX model classified correctly.
correct_preds_onnx[0:10]

[0, 2, 4, 6, 7, 8, 9, 11, 12, 13]

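## Accuracy of the PyTorch model

The architecture is rebuilt from the ONNX graph with `onnx2pytorch`, and the trained weights are loaded from a checkpoint. If the conversion is faithful, the predictions should match the ONNX Runtime results above.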

In [10]:
import torch
from onnx2pytorch import ConvertModel

# Build the PyTorch architecture from the (untrained) ONNX graph, then load the trained weights into it.
onnx_model_untrained = onnx.load_model("./nsynth_train/mnist_relu_4_1024.onnx")
torch_model = ConvertModel(onnx_model_untrained)

# map_location="cpu" lets a checkpoint saved from a GPU run load on a CPU-only machine.
# The model built here is never moved to a GPU, so the weights are CPU tensors either way.
# NOTE: torch.load unpickles the file; only load checkpoints you trust.
state_dict = torch.load("./model_20240307_165159_0", map_location="cpu")
torch_model.load_state_dict(state_dict)
torch_model

  layer.weight.data = torch.from_numpy(numpy_helper.to_array(weight))


ConvertModel(
  (Transpose_9): Transpose()
  (Constant_10): Constant(constant=tensor([ -1, 784]))
  (Reshape_11): Reshape(shape=None)
  (Gemm_12): Linear(in_features=784, out_features=1024, bias=True)
  (Relu_13): ReLU(inplace=True)
  (Gemm_14): Linear(in_features=1024, out_features=1024, bias=True)
  (Relu_15): ReLU(inplace=True)
  (Gemm_16): Linear(in_features=1024, out_features=1024, bias=True)
  (Relu_17): ReLU(inplace=True)
  (Gemm_18): Linear(in_features=1024, out_features=10, bias=True)
)

In [15]:
from nsynth_train.train_model import NSynthDataset

# Re-read the same validation pickle through the training Dataset class.
# NSynthDataset is given class_map, so 'instrument' is presumably already a class index — confirm in train_model.
test_dataset = NSynthDataset(picklefile="./validation_features.pickle", class_map="./nsynth_train/class_map.json")
# NOTE(review): batch_size stays 1. The scalar argmax/label comparison below relies on it, and
# onnx2pytorch's ConvertModel is believed to reject batches > 1 unless built with experimental=True.
test_dataset_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False)

torch_model.eval()
correct_preds = []
false_preds = []
with torch.no_grad():
    for i, vdata in enumerate(test_dataset_loader):
        vinputs, vlabels = vdata['melfeatures'], vdata['instrument']
        # Insert the channel axis at position 1, matching `inputs[:, None, :, :]` in the ONNX cell.
        # Assumes each batch is (1, H, W), which becomes (1, 1, H, W).
        vinputs = vinputs.to(torch.float).unsqueeze(1)
        voutputs = torch_model(vinputs)
        # Stay in torch rather than passing a tensor to np.argmax (which relies on implicit conversion).
        # Without `dim` this is the flat argmax, i.e. the class index for the single sample in the batch.
        class_pred = torch.argmax(voutputs)
        if class_pred.item() == vlabels.item():
            correct_preds.append(i)
        else:
            false_preds.append(i)

print(f"Correct predictions: {len(correct_preds)}, False predictions: {len(false_preds)}")
accuracy = len(correct_preds) / (len(correct_preds) + len(false_preds))
print(f"Accuracy: {accuracy}")

Correct predictions: 6786, False predictions: 5892
Accuracy: 0.5352579271178419


In [13]:
# First ten validation indices that the PyTorch model classified correctly (compare with the ONNX list above).
correct_preds[0:10]

[0, 2, 4, 6, 7, 8, 9, 11, 12, 13]