In [1]:
import json
from pathlib import Path

import torch
import onnx
import onnxruntime as ort
import numpy as np
from tqdm.auto import tqdm
from numpyencoder import NumpyEncoder

In [6]:
# Global Configs

# --- Network -----------------------------------------------------------
# ONNX file of the network whose neuron activations are analysed.
NET_PATH = "./mnist-net_256x4.onnx"
# Number of rows of the per-label activation counter (one per counted layer).
LAYERS = 4

# --- Experiment --------------------------------------------------------
# Frequency threshold used when deciding which neurons go into the
# per-label index sets.
DELTA = 0.99
# Base name of the JSON file the results are written to.
EXPT_NAME = "256x4_delta99_nt"

# --- Dataset -----------------------------------------------------------
# NOTE(review): DATA is reassigned in the "Data preprocessing" cell before it
# is read anywhere, so this load looks redundant for this experiment -- confirm
# before removing.
DATA = torch.load("./dataset/training.pt")

In [11]:
# Data preprocessing

# Load the normalized training images and lay them out as one (1, 784, 1)
# array per image. The leading dimension is inferred (-1) instead of being
# hard-coded to 60000, so subsets of the dataset work as well.
imgs = torch.load("./dataset/normalized_train.pt").reshape(-1, 1, 784, 1).numpy()

# Only the targets are needed from here on; the first slot is left as None so
# that the labels stay addressable as DATA[1].
DATA = (None, torch.load("./dataset/nt_targets.pt"))

# The labels are used as a boolean mask over imgs, so both must have the same length.
assert imgs.shape[0] == DATA[1].shape[0], "number of images and number of labels differ"

In [12]:
# Modifying model to get neuron activations

# Load the ONNX network from NET_PATH; this object is the one handed to
# add_intermediate_outputs() below.
model = onnx.load(NET_PATH)

def add_intermediate_outputs(model):
    """Expose the output of every ``Relu`` node as an additional graph output.

    The model is modified in place and the same object is returned. Each new
    output is declared with element type FLOAT and no shape information.
    Tensors that are already listed in ``graph.output`` are skipped, so calling
    this function more than once on the same model does not register the same
    output name twice.

    Args:
        model: ONNX model whose ``graph.node`` entries are scanned and whose
            ``graph.output`` list is extended.

    Returns:
        The same ``model`` object, with the extra graph outputs appended in
        node order.
    """
    graph = model.graph
    # Names that are already graph outputs; used to avoid duplicate entries.
    exposed = {value_info.name for value_info in graph.output}
    for node in graph.node:
        if node.op_type != "Relu":
            continue
        for output in node.output:
            if output in exposed:
                continue
            graph.output.append(
                onnx.helper.make_tensor_value_info(output, onnx.TensorProto.FLOAT, None)
            )
            exposed.add(output)
    return model

modified_model = add_intermediate_outputs(model)

# Derive "<name>_modified.onnx" next to the original file. pathlib is used
# instead of splitting on "." so that paths without a leading "./" or with
# dots in a directory name are handled too; for "./mnist-net_256x4.onnx" this
# yields "mnist-net_256x4_modified.onnx".
net_path = Path(NET_PATH)
# Kept as a plain string because it is passed on to onnx.save and onnxruntime.
modified_model_path = str(net_path.with_name(f"{net_path.stem}_modified.onnx"))
onnx.save(modified_model, modified_model_path)


In [13]:
# Setting up inference in onnx

# The execution provider is passed explicitly: onnxruntime builds that ship
# more than one provider (e.g. GPU builds, ORT >= 1.9) can refuse to create a
# session without it. The CPU provider is present in every build and keeps the
# computed activations independent of the machine the notebook runs on.
session = ort.InferenceSession(modified_model_path, providers=["CPUExecutionProvider"])
# Name of the model's first input; used as the feed key for session.run().
input_name = session.get_inputs()[0].name

In [14]:
# Going through each label

# Number of neurons counted per layer (second dimension of the counter).
NEURONS_PER_LAYER = 256

# Result container: a 'config' entry describing the run, plus one entry per
# label (added in the loop below) holding the selected neuron indices.
P = {
    'config':
    {
        'net': NET_PATH,
        'delta': DELTA,
        'data_len': DATA[1].shape[0]
    }
}

for label in range(10):
    # Filtering relevant data alone
    mask = (DATA[1] == label).numpy()
    S = imgs[mask]
    num_examples = S.shape[0]

    # Initializing a counter: count[l, j] is the number of examples of this
    # label for which neuron j of layer l had a strictly positive activation.
    count = np.zeros((LAYERS, NEURONS_PER_LAYER))

    # Counting across relevant data
    print(f"Processing label {label}...")

    for example in tqdm(S):
        outputs = session.run(None, {input_name: example})
        # outputs[0] is skipped -- presumably the network's own output, with the
        # added ReLU outputs following it (TODO confirm the model has exactly
        # one original output).
        neuron_activations = np.concatenate(outputs[1:])
        count += neuron_activations > 0

    if num_examples:
        # fr (frequency ratio): fraction of this label's examples that activate each neuron.
        fr = count / num_examples
        # A: neurons active in at least a DELTA fraction of the examples.
        greater_than_delta = np.where(fr >= DELTA)
        # D: neurons active in at most a (1 - DELTA) fraction of the examples.
        # The comparison is inclusive to mirror the one used for A; with a
        # strict "<", DELTA = 1.0 could never select a neuron (fr is never
        # negative), not even one that never fires.
        lesser_than_delta = np.where(fr <= 1 - DELTA)
        # Each index is a (layer, neuron) pair.
        A = list(zip(*greater_than_delta))
        D = list(zip(*lesser_than_delta))
    else:
        # No example carries this label: there is no frequency to compute, so
        # both sets are empty (avoids a division by zero).
        A, D = [], []

    # NOTE(review): the second set is stored under the key "B" although it is
    # called D everywhere else in this cell. The key is left unchanged because
    # readers of the JSON file may depend on it -- confirm and rename if not.
    P[label] = {
        "A":
        {
            "len": len(A),
            "indices": A
        },
        "B": {
            "len": len(D),
            "indices": D
        }
    }

Processing label 0...


  0%|          | 0/5923 [00:00<?, ?it/s]

Processing label 1...


  0%|          | 0/6742 [00:00<?, ?it/s]

Processing label 2...


  0%|          | 0/5958 [00:00<?, ?it/s]

Processing label 3...


  0%|          | 0/6131 [00:00<?, ?it/s]

Processing label 4...


  0%|          | 0/5842 [00:00<?, ?it/s]

Processing label 5...


  0%|          | 0/5421 [00:00<?, ?it/s]

Processing label 6...


  0%|          | 0/5918 [00:00<?, ?it/s]

Processing label 7...


  0%|          | 0/6265 [00:00<?, ?it/s]

Processing label 8...


  0%|          | 0/5851 [00:00<?, ?it/s]

Processing label 9...


  0%|          | 0/5949 [00:00<?, ?it/s]

In [15]:
# Writing it out into a file

# Make sure the output directory exists so that the results of the long
# counting loop are not lost to a FileNotFoundError.
Path("./NAPs").mkdir(parents=True, exist_ok=True)

with open(f"./NAPs/{EXPT_NAME}.json", "w") as f:
    # NumpyEncoder serialises the NumPy integers stored in the index lists.
    # The integer label keys of P are written as strings, as JSON requires.
    json.dump(P, f, cls=NumpyEncoder)