In [1]:
import numpy as np
import torch
import torch.nn as nn

# add parent dict to path
import os, sys
sys.path.append(os.path.abspath(".."))
from src import treenode as tn

In [2]:

class NeuralNet(nn.Module):
    def __init__(self, input_size, hidden_sizes, num_classes, classification=False):
        super(NeuralNet, self).__init__()
        self.hidden_sizes = hidden_sizes

        for i in range(len(hidden_sizes)):
            layer_name = f"l{i + 1}"
            relu_name = f"relu{i + 1}"
            if i == 0:
                setattr(self, layer_name, nn.Linear(input_size, hidden_sizes[i]))
                setattr(self, relu_name, nn.ReLU())
            else:
                setattr(
                    self, layer_name, nn.Linear(hidden_sizes[i - 1], hidden_sizes[i])
                )
                setattr(self, relu_name, nn.ReLU())

        output_layer_name = f"l{len(hidden_sizes) + 1}"

        setattr(self, output_layer_name, nn.Linear(hidden_sizes[-1], num_classes))
        # setattr(self, "output_activation", nn.Softmax(dim=num_classes - 1) if classification else nn.Identity())
        
    def forward(self, x):
        out = x
        for i in range(len(self.hidden_sizes)):
            layer_name = f"l{i + 1}"
            relu_name = f"relu{i + 1}"
            out = getattr(self, layer_name)(out)
            out = getattr(self, relu_name)(out)

        output_layer_name = f"l{len(self.hidden_sizes) + 1}"
        out = getattr(self, output_layer_name)(out)
        # out = getattr(self, "output_activation")(out)
        # if len(out.shape) == 1:
        #     out = out.unsqueeze(0)
        # elif len(out.shape) == 2 and out.shape[0] == 1:
        #     out = out.squeeze(0)
        # elif len(out.shape) == 3 and out.shape[0] == 1:
        #     out = out.squeeze(0)
        return out

In [3]:
# Make model
model = NeuralNet(input_size=2, num_classes=1, hidden_sizes=[4,6,5,3])
tree = tn.RegionTree(model.state_dict())
tree.build_tree()
root = tree.get_root()
print("Tree built successfully")


Processing Layer 1/5: 16it [00:00, 17458.08it/s]
Processing Layer 2/5: 64it [00:00, 1770.75it/s]
Processing Layer 3/5: 32it [00:00, 33.59it/s]
Processing Layer 4/5: 8it [00:05,  1.45it/s]
Processing Layer 5/5: 2it [00:08,  4.28s/it]

Tree built successfully





In [4]:
# some two dimensional inputs
inputs = np.random.rand(10000, 2) * 10 - 5  # Random inputs in range [-5, 5]
print(inputs)

[[-3.9985647   4.44384473]
 [ 3.52662872  4.95983818]
 [-4.96513107 -0.84823465]
 ...
 [ 1.8050007   1.98222272]
 [-1.26830481 -1.37224948]
 [-2.48362775  2.11371599]]


In [None]:
from tqdm import tqdm
tree.pass_input_through_tree(inputs[0])
for input in tqdm(inputs):
    tree.pass_input_through_tree(input, reset=True)
    

100%|██████████| 10000/10000 [00:01<00:00, 6419.04it/s]


In [9]:
tree.read_off_counters()

Layer 0 has 1 nodes.
Counters: [0]
Layer 1 has 16 nodes.
Counters: [930, 0, 11229, 0, 0, 0, 1686, 0, 882, 4308, 1146, 0, 0, 5556, 3621, 645]
Layer 2 has 1024 nodes.
Counters: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 930, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 9102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2127, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 