In [28]:
import json
import os
import pickle

import numpy as np
import onnx
import onnxruntime as ort
from numpyencoder import NumpyEncoder
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from tqdm.auto import tqdm

In [13]:
# Path to the ONNX network whose Relu activations are analysed.
NET_PATH = './models/mnist_fc_64x4_adv_1.onnx'
# Base name for saved artifacts, derived from NET_PATH (directory and extension
# stripped) so the two values cannot drift apart when the model is swapped.
FILE_NAME = NET_PATH.split("/")[-1].split('.')[0]
# With DELTA == 1 a relationship is kept only if it was observed on every
# example of a label; any other value raises NotImplementedError downstream.
DELTA = 1

In [14]:
# Convert images to tensors; the Normalize step uses mean 0 and std 1.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.,), (1.,))]
)

# MNIST training split, downloaded into ./data when it is not already there.
trainset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)

# One un-shuffled batch of 50000: taking the first batch yields the first
# 50000 training examples in dataset order.
trainloader = DataLoader(trainset, shuffle=False, batch_size=50000)
imgs, labels = next(iter(trainloader))

In [15]:

# Load the ONNX model stored at NET_PATH.
model = onnx.load(NET_PATH)

def add_intermediate_outputs(model):
    """Expose the output of every Relu node as an additional graph output.

    The model is modified in place and also returned. Each new output is
    declared as a FLOAT tensor with no shape information. Names that are
    already graph outputs are skipped, so calling this function more than once
    on the same model does not register duplicate outputs.

    Args:
        model: A loaded ONNX model whose ``graph`` will be extended.

    Returns:
        The same ``model`` object, with the extra graph outputs appended after
        the existing ones.
    """
    graph = model.graph
    # Track names that are already outputs; appending one twice would produce
    # a graph with duplicate output names.
    exposed = {existing.name for existing in graph.output}
    for node in graph.node:
        if node.op_type != "Relu":
            continue
        for output in node.output:
            if output in exposed:
                continue
            graph.output.append(
                onnx.helper.make_tensor_value_info(output, onnx.TensorProto.FLOAT, None)
            )
            exposed.add(output)
    return model

# add_intermediate_outputs edits `model` in place and returns it, so
# `modified_model` and `model` refer to the same object.
modified_model = add_intermediate_outputs(model)

# Base name of the network file: directory and extension stripped from NET_PATH.
modelname = NET_PATH.split("/")[-1].split('.')[0]
modified_model_path = f'./.temp/{modelname}_modified.onnx'
# Make sure the scratch directory exists before writing; on a fresh checkout
# it is typically missing and the save would fail.
os.makedirs(os.path.dirname(modified_model_path), exist_ok=True)
onnx.save(modified_model, modified_model_path)

In [16]:
# Inference session over the model file saved at modified_model_path.
session = ort.InferenceSession(modified_model_path)
# Name of the session's first input; used as the feed key when running inference.
input_name = session.get_inputs()[0].name

In [17]:
def get_relationships(l1, l2):
    """Tabulate the four joint on/off combinations for every pair of entries.

    Both inputs are flattened. For entry ``i`` of ``l1`` and entry ``j`` of
    ``l2`` the result holds, at ``[k, i, j]``, a 1 when combination ``k``
    applies and a 0 otherwise:

        k = 0: l1[i] true  and l2[j] true
        k = 1: l1[i] false and l2[j] true
        k = 2: l1[i] true  and l2[j] false
        k = 3: l1[i] false and l2[j] false

    Args:
        l1: Array of truth values (callers in this notebook pass ``x > 0`` masks).
        l2: Array of truth values, same convention as ``l1``.

    Returns:
        Integer array of shape ``(4, l1.size, l2.size)``.
    """
    # A column vector against a flat row broadcasts to an (l1.size, l2.size) grid.
    column = l1.reshape((-1, 1))
    row = l2.reshape((-1))

    column_variants = (column, np.logical_not(column))
    row_variants = (row, np.logical_not(row))

    # The outer loop runs over l2 (as-is, then negated) and the inner loop over
    # l1, which produces the k = 0..3 ordering documented above.
    planes = [
        np.logical_and(col_state, row_state)
        for row_state in row_variants
        for col_state in column_variants
    ]
    return np.stack(planes).astype(int)

In [18]:
# label -> (per-layer-pair counters, number of examples processed for that label)
res = {}

for label in range(10):
    # Keep only the images whose label equals the current one.
    mask = (labels == label).numpy()
    S = imgs[mask]

    # Allocated lazily: the array sizes are only known after the first inference.
    counter = None

    print(f"Processing label {label}...")

    for example in tqdm(S):
        feed = {input_name: example.numpy().reshape(1, 1, 28, 28)}
        outputs = session.run(None, feed)

        # Consecutive output pairs (op, op + 1); index 0 is skipped.
        pair_starts = range(1, len(outputs) - 1)

        if counter is None:
            counter = {
                f'{op-1}:{op}': np.zeros(
                    (4, outputs[op].size, outputs[op+1].size), dtype=np.int64
                )
                for op in pair_starts
            }

        for op in pair_starts:
            # An entry counts as "on" when its value is strictly positive.
            counter[f'{op-1}:{op}'] += get_relationships(
                outputs[op] > 0., outputs[op+1] > 0.
            )

    res[label] = (counter, S.shape[0])


Processing label 0...


  0%|          | 0/4932 [00:00<?, ?it/s]

Processing label 1...


  0%|          | 0/5678 [00:00<?, ?it/s]

Processing label 2...


  0%|          | 0/4968 [00:00<?, ?it/s]

Processing label 3...


  0%|          | 0/5101 [00:00<?, ?it/s]

Processing label 4...


  0%|          | 0/4859 [00:00<?, ?it/s]

Processing label 5...


  0%|          | 0/4506 [00:00<?, ?it/s]

Processing label 6...


  0%|          | 0/4951 [00:00<?, ?it/s]

Processing label 7...


  0%|          | 0/5175 [00:00<?, ?it/s]

Processing label 8...


  0%|          | 0/4842 [00:00<?, ?it/s]

Processing label 9...


  0%|          | 0/4988 [00:00<?, ?it/s]

In [19]:
# Persist the raw counters so the expensive counting loop need not be re-run.
# Create the target directory first: a missing folder would otherwise raise
# only after all inference work has already been done.
os.makedirs("counters", exist_ok=True)
with open(f"counters/{FILE_NAME}.pkl", "wb") as f:
    pickle.dump(res, f)

In [25]:
# Filled in below: label -> {layer-pair key -> list of index tuples}.
nap2s = {}

In [26]:
# Only the strict setting is supported.
if DELTA != 1:
    raise NotImplementedError()

for label, (layer_counts, n_examples) in res.items():
    # Keep the (combination, i, j) index triples whose count equals the number
    # of examples, i.e. the combination was observed on every example of the label.
    nap2s[label] = {
        layerpair: list(zip(*(counts == n_examples).nonzero()))
        for layerpair, counts in layer_counts.items()
    }

In [29]:
# Write the extracted patterns as JSON. The index tuples hold NumPy integers
# (they come from ndarray.nonzero()), hence the NumpyEncoder.
# Create the output directory first so the dump cannot fail on a missing folder.
os.makedirs("NAP2s", exist_ok=True)
with open(f"NAP2s/{FILE_NAME}.json", "w") as f:
    json.dump(nap2s, f, cls=NumpyEncoder)