## Notebook to get some model predictions

In [1]:
import onnx
import onnxruntime as rt
import numpy as np
import matplotlib.pyplot as plt
import pickle
import json

In [2]:
model = onnx.load_model("./trained_model.onnx")
with open("./validation_features.pickle", "rb") as fp:
    test_features, test_labels = pickle.load(fp)
with open("./nsynth_train/class_map.json", "r") as fp:
    class_map = json.load(fp)

inputs = np.array(test_features, dtype=np.float32)
inputs = inputs[:, None, :, :]

In [3]:
inputs.shape

(12678, 1, 28, 28)

## Get Accuracy of onnx Model

In [4]:
providers = ['CPUExecutionProvider']
onnx_file_path = "trained_model.onnx"
output_names = [n.name for n in model.graph.output]
m = rt.InferenceSession(onnx_file_path, providers=providers)
onnx_pred = m.run(output_names, {"input": inputs})[0]

correct_preds_onnx = []
false_preds_onnx = []
for i, preds in enumerate(onnx_pred):
    label_pred = np.argmax(preds)
    if label_pred == class_map[test_labels[i]]:
        correct_preds_onnx.append(i)
    else:
        false_preds_onnx.append(i)

print(f"Correct predictions: {len(correct_preds_onnx)}, False predictions: {len(false_preds_onnx)}")
accuracy = len(correct_preds_onnx) / (len(correct_preds_onnx) + len(false_preds_onnx))
print(f"Accuracy: {accuracy}")

Correct predictions: 6786, False predictions: 5892
Accuracy: 0.5352579271178419


In [5]:
correct_preds_onnx[:10]

[0, 2, 4, 6, 7, 8, 9, 11, 12, 13]

## Get Accuracy of torch model

In [6]:
import torch
from onnx2pytorch import ConvertModel
onnx_model_untrained = onnx.load_model("./nsynth_train/mnist_relu_4_1024.onnx")
torch_model = ConvertModel(onnx_model_untrained)
torch_model.load_state_dict(torch.load("./model_20240307_165159_0"))
torch_model

  layer.weight.data = torch.from_numpy(numpy_helper.to_array(weight))


ConvertModel(
  (Transpose_9): Transpose()
  (Constant_10): Constant(constant=tensor([ -1, 784]))
  (Reshape_11): Reshape(shape=None)
  (Gemm_12): Linear(in_features=784, out_features=1024, bias=True)
  (Relu_13): ReLU(inplace=True)
  (Gemm_14): Linear(in_features=1024, out_features=1024, bias=True)
  (Relu_15): ReLU(inplace=True)
  (Gemm_16): Linear(in_features=1024, out_features=1024, bias=True)
  (Relu_17): ReLU(inplace=True)
  (Gemm_18): Linear(in_features=1024, out_features=10, bias=True)
)

In [7]:
from nsynth_train.train_model import NSynthDataset
test_dataset = NSynthDataset(picklefile="./validation_features.pickle", class_map="./nsynth_train/class_map.json")
test_dataset_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False)

torch_model.eval()
correct_preds = []
false_preds = []
with torch.no_grad():
    for i, vdata in enumerate(test_dataset_loader):
        vinputs, vlabels = vdata['melfeatures'], vdata['instrument']
        vinputs = vinputs.to(torch.float)
        vinputs = vinputs[None, :, :, :]
        voutputs = torch_model(vinputs)
        class_pred = np.argmax(voutputs)
        if class_pred == vlabels:
            correct_preds.append(i)
        else:
            false_preds.append(i)

print(f"Correct predictions: {len(correct_preds)}, False predictions: {len(false_preds)}")
accuracy = len(correct_preds) / (len(correct_preds) + len(false_preds))
print(f"Accuracy: {accuracy}")

Correct predictions: 6786, False predictions: 5892
Accuracy: 0.5352579271178419


In [8]:
correct_preds[:10]

[0, 2, 4, 6, 7, 8, 9, 11, 12, 13]

## Add modified input layer to trained model

In [8]:
from nsynth_train.modify_model_for_verification import modify_model
torch_model = modify_model(torch_model)

In [9]:
from nsynth_train.train_model import NSynthDataset
test_dataset = NSynthDataset(picklefile="./validation_features.pickle", class_map="./nsynth_train/class_map.json")
test_dataset_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False)

torch_model.eval()
correct_preds = []
false_preds = []
with torch.no_grad():
    for i, vdata in enumerate(test_dataset_loader):
        vinputs, vlabels = vdata['melfeatures'], vdata['instrument']
        vinputs = vinputs.to(torch.float)
        vinputs = vinputs[None, :, :, :]
        vinputs = torch.cat([vinputs, torch.normal(mean=1, std=0.05, size=(1, 1, 1, 28))], dim=2)
        voutputs = torch_model(vinputs)
        class_pred = np.argmax(voutputs)
        if class_pred == vlabels:
            correct_preds.append(i)
        else:
            false_preds.append(i)

print(f"Correct predictions: {len(correct_preds)}, False predictions: {len(false_preds)}")
accuracy = len(correct_preds) / (len(correct_preds) + len(false_preds))
print(f"Accuracy: {accuracy}")

  filtered_features = (features * filter.T).reshape(new_shape)


Correct predictions: 6770, False predictions: 5908
Accuracy: 0.5339958984066887


In [10]:
correct_preds[:10] # Here, the test sample at index 4 is now misclassified.

[0, 2, 6, 7, 8, 9, 11, 12, 17, 18]

## Modify feature files for filter application

In [4]:
feature_path = "./validation_features.pickle"
modified_features_path = "./validation_features_mod.pickle"
with open(feature_path, "rb") as fp:
    features, labels = pickle.load(fp)
# add vector of ones to every feature
modified_features = [np.concatenate([sample, np.ones((1, sample.shape[1]))]) for sample in features]

with open(modified_features_path, "wb") as fp:
    pickle.dump((modified_features, labels), fp)


In [5]:
from nsynth_train.generate_features import modify_features_for_verification
_, __ = modify_features_for_verification(feature_path, modified_features_path)

In [6]:
with open(modified_features_path, "rb") as fp:
    features_mod, labels = pickle.load(fp)
print(features_mod[:10])

[array([[1.94688782, 0.82821479, 0.62989824, 0.47311177, 0.46633712,
        0.43526602, 0.42756692, 0.41879633, 0.4167358 , 0.39321551,
        0.36755499, 0.31671499, 0.30134821, 0.2966519 , 0.3065595 ,
        0.26862622, 0.28942391, 0.2923381 , 0.26817182, 0.31736253,
        0.34207496, 0.8572281 , 0.66690438, 0.2374782 , 0.12159718,
        0.04733063, 0.03166663, 0.00988467],
       [1.97362895, 0.76966093, 0.52250157, 0.40954744, 0.35449662,
        0.31441901, 0.28612494, 0.2616344 , 0.25445395, 0.24188897,
        0.23524277, 0.21890228, 0.21565717, 0.20215947, 0.22136185,
        0.2047519 , 0.20407975, 0.22675489, 0.21970521, 0.23419508,
        0.31047135, 1.01216838, 0.29902232, 0.07887312, 0.03368206,
        0.02350871, 0.01218765, 0.00624659],
       [1.91487672, 0.75402094, 0.50446603, 0.3998452 , 0.34125425,
        0.30129292, 0.26945313, 0.24831998, 0.22850787, 0.21576501,
        0.20757328, 0.19353355, 0.18631936, 0.17936971, 0.17243153,
        0.19337589, 0.158