In [29]:
import json
from pathlib import Path

import torch
import onnx
import onnxruntime as ort
import numpy as np
from tqdm.auto import tqdm
from numpyencoder import NumpyEncoder

In [19]:
# Global Configs

DELTA = 1  # activation-frequency threshold used to build each label's A / D neuron sets
NET_PATH = "./mnist-net_256x4.onnx"
# Indexed below as DATA[0] = images and DATA[1] = labels.
# NOTE: torch.load relies on pickle; only point it at trusted files.
DATA = torch.load("./dataset/training.pt")
LAYERS = 4  # number of hidden ReLU layers whose activations are recorded
# Derived from DELTA so running with a new threshold cannot silently overwrite
# another run's NAP file; "256x4" mirrors the architecture encoded in NET_PATH.
EXPT_NAME = f"256x4_delta{DELTA}"

In [3]:
# Data preprocessing

# Divide pixel values by 255 (assumes raw 0-255 pixels — TODO confirm dataset dtype)
# and lay each image out as the network's 1x784x1 input. The leading -1 lets the
# sample count follow the dataset instead of assuming exactly 60000 images.
imgs = (DATA[0] / 255).reshape(-1, 1, 784, 1).numpy()

In [4]:
# Modifying model to get neuron activations

model = onnx.load(NET_PATH)

def add_intermediate_outputs(model):
    """Expose every ReLU output of ``model`` as an additional graph output.

    The graph is modified in place and the same ModelProto is returned. New
    outputs are appended after the existing graph outputs, in the order the Relu
    nodes appear in ``model.graph.node``.

    Names already listed as graph outputs are skipped, so calling this more than
    once on the same model does not create duplicate outputs.
    """
    graph = model.graph
    existing = {out.name for out in graph.output}
    for node in graph.node:
        if node.op_type != "Relu":
            continue
        for output in node.output:
            if output in existing:
                continue
            # Only the element type is declared; the shape is left unspecified (None).
            graph.output.append(
                onnx.helper.make_tensor_value_info(output, onnx.TensorProto.FLOAT, None)
            )
            existing.add(output)
    return model

modified_model = add_intermediate_outputs(model)

# Write "<net file stem>_modified.onnx" into the current working directory.
# Path.stem handles "../" prefixes, extra dots and a missing "./" correctly,
# unlike the previous NET_PATH.split(".")[1][1:].
modified_model_path = f"{Path(NET_PATH).stem}_modified.onnx"
onnx.save(modified_model, modified_model_path)


graph torch-jit-export (
  %0[FLOAT, 1x784x1]
) initializers (
  %layers.0.bias[FLOAT, 256]
  %layers.0.weight[FLOAT, 256x784]
  %layers.2.bias[FLOAT, 256]
  %layers.2.weight[FLOAT, 256x256]
  %layers.4.bias[FLOAT, 256]
  %layers.4.weight[FLOAT, 256x256]
  %layers.6.bias[FLOAT, 256]
  %layers.6.weight[FLOAT, 256x256]
  %layers.8.bias[FLOAT, 10]
  %layers.8.weight[FLOAT, 10x256]
) {
  %11 = Flatten[axis = 1](%0)
  %12 = Gemm[alpha = 1, beta = 1, transB = 1](%11, %layers.0.weight, %layers.0.bias)
  %13 = Relu(%12)
  %14 = Gemm[alpha = 1, beta = 1, transB = 1](%13, %layers.2.weight, %layers.2.bias)
  %15 = Relu(%14)
  %16 = Gemm[alpha = 1, beta = 1, transB = 1](%15, %layers.4.weight, %layers.4.bias)
  %17 = Relu(%16)
  %18 = Gemm[alpha = 1, beta = 1, transB = 1](%17, %layers.6.weight, %layers.6.bias)
  %19 = Relu(%18)
  %20 = Gemm[alpha = 1, beta = 1, transB = 1](%19, %layers.8.weight, %layers.8.bias)
  return %20
}


In [5]:
# Setting up inference in onnx

# Pass `providers` explicitly: ORT builds that ship extra execution providers
# (e.g. onnxruntime-gpu since 1.9) may refuse to create a session without it.
# The 256x4 MLP is small enough that CPU inference is sufficient.
session = ort.InferenceSession(modified_model_path, providers=["CPUExecutionProvider"])
input_name = session.get_inputs()[0].name

In [16]:
# Going through each label: count how often each ReLU neuron fires on that
# label's images, then record the neurons whose activation frequency crosses
# DELTA as always-active (A) or always-inactive (D).
P = {
    'config':
    {
        'net': NET_PATH,
        'delta': DELTA,
        'data_len': DATA[1].shape[0]
    }
}

for label in range(10):
    # Filtering relevant data alone
    mask = (DATA[1] == label).numpy()
    S = imgs[mask]

    # Per-layer, per-neuron number of examples on which the neuron output was > 0
    count = np.zeros((LAYERS, 256))

    print(f"Processing label {label}...")

    for example in tqdm(S):
        outputs = session.run(None, {input_name: example})
        # outputs[0] is the network's original output; the rest are the ReLU
        # activations appended as extra graph outputs. Each is assumed to be
        # (1, 256), so concatenation stacks them to (LAYERS, 256).
        neuron_activations = np.concatenate(outputs[1:])
        count += neuron_activations > 0

    # fr (frequency ratio) = fraction of this label's examples on which each neuron fired.
    #   A: active on at least a DELTA fraction   -> fr >= DELTA
    #   D: inactive on at least a DELTA fraction -> 1 - fr >= DELTA, i.e. fr <= 1 - DELTA
    # The comparison must be `<=`: with DELTA = 1 a strict `<` means fr < 0, so D was always empty.
    fr = count / S.shape[0]
    A = list(zip(*np.where(fr >= DELTA)))
    D = list(zip(*np.where(fr <= 1 - DELTA)))

    P[label] = {
        "A":
        {
            "len": len(A),
            "indices": A
        },
        # NOTE(review): the D set is stored under key "B" (likely a typo for "D");
        # kept as-is so existing readers of the NAP JSON files keep working.
        "B": {
            "len": len(D),
            "indices": D
        }
    }

Processing label 0...


100%|██████████| 5923/5923 [00:00<00:00, 24033.25it/s]


Processing label 1...


100%|██████████| 6742/6742 [00:00<00:00, 25504.12it/s]


Processing label 2...


100%|██████████| 5958/5958 [00:00<00:00, 25183.22it/s]


Processing label 3...


100%|██████████| 6131/6131 [00:00<00:00, 24704.23it/s]


Processing label 4...


100%|██████████| 5842/5842 [00:00<00:00, 23116.72it/s]


Processing label 5...


100%|██████████| 5421/5421 [00:00<00:00, 21979.09it/s]


Processing label 6...


100%|██████████| 5918/5918 [00:00<00:00, 24214.33it/s]


Processing label 7...


100%|██████████| 6265/6265 [00:00<00:00, 22706.75it/s]


Processing label 8...


100%|██████████| 5851/5851 [00:00<00:00, 22693.17it/s]


Processing label 9...


100%|██████████| 5949/5949 [00:00<00:00, 21749.78it/s]


In [31]:
# Writing it out into a file

# Create the output directory if needed so a fresh checkout does not fail in open().
out_path = Path("./NAPs") / f"{EXPT_NAME}.json"
out_path.parent.mkdir(parents=True, exist_ok=True)
# NumpyEncoder serialises the numpy integer indices stored in P.
with open(out_path, "w") as f:
    json.dump(P, f, cls=NumpyEncoder)