In [1]:
import numpy as np 
import matplotlib.pyplot as plt 
import src_experiment as scf
import geobin as gb 
import pathlib as pl
import torch
from sklearn.datasets import make_moons


training_seed = 42
testing_seed = 41
inference_seed = 40



# Read config file and/or parse cmb-line arguments
# args = scf.get_args()



# # Check if the experiment path exists, if not, create 
# experiment_path = scf.get_path_to_experiment_storage(args.experiment_name)
# state_dict_path = experiment_path/"state_dicts"
# if not state_dict_path.exists():
#     scf.createfolders(state_dict_path)

# Make the data
train_data = scf.make_moon_dataloader(n_samples=100000, noise=0.15, random_state=training_seed, batch_size=250)
test_data = scf.make_moon_dataloader(n_samples=2000, noise=0.15, random_state=testing_seed, batch_size=250)

# # Set up the model based on the parsed arguments or the config file
model = scf.NeuralNet(input_size=2, hidden_sizes=[3,3,3,3], num_classes=1)

test_tree = gb.RegionTree(model.state_dict())


In [2]:
test_tree.build_tree()

Building tree...


Layer 1 / 5: 100%|██████████| 8/8 [00:00<00:00, 453.63it/s]
Layer 2 / 5: 100%|██████████| 8/8 [00:00<00:00, 499.40it/s]
Layer 3 / 5: 100%|██████████| 8/8 [00:00<00:00, 335.88it/s]
Layer 4 / 5: 100%|██████████| 8/8 [00:00<00:00, 93.82it/s]
Layer 5 / 5: 100%|██████████| 2/2 [00:00<00:00, 13.87it/s]


In [3]:
for i, (inputs,labels) in enumerate(train_data):
    point = inputs[0], labels[0]

In [4]:
print(point)

(tensor([1.2440, 0.0346], dtype=torch.float64), tensor(0))


In [5]:
point = (point[0].numpy(), point[1].numpy())

In [6]:
test_tree.pass_single_point_through_tree(point)

In [7]:
counts = test_tree.get_number_counts()

In [8]:
print(counts.to_string())

       layer_idx  region_idx    0  total
0              0           0  0.0    0.0
1              1           0  0.0    0.0
2              2           0  0.0    0.0
3              3           0  0.0    0.0
4              4           0  0.0    0.0
5              5           0  0.0    0.0
6              5        4096  0.0    0.0
7              4         512  0.0    0.0
8              5         512  0.0    0.0
9              5        4608  0.0    0.0
10             4        1024  0.0    0.0
11             5        1024  0.0    0.0
12             5        5120  0.0    0.0
13             4        1536  0.0    0.0
14             5        1536  0.0    0.0
15             5        5632  0.0    0.0
16             4        2048  0.0    0.0
17             5        2048  0.0    0.0
18             5        6144  0.0    0.0
19             4        2560  0.0    0.0
20             5        2560  0.0    0.0
21             5        6656  0.0    0.0
22             4        3072  0.0    0.0
23             5

In [9]:
print(counts.sort_values("0"))

      layer_idx  region_idx    0  total
8586          5        1877  0.0    0.0
8574          5        7957  0.0    0.0
8575          3         341  0.0    0.0
8576          4         341  0.0    0.0
8577          5         341  0.0    0.0
...         ...         ...  ...    ...
5675          4        2659  1.0    1.0
5677          5        6755  1.0    1.0
5633          2          35  1.0    1.0
5659          3          99  1.0    1.0
4828          1           3  1.0    1.0

[12873 rows x 4 columns]


In [10]:
test_tree.reset_counters()

In [11]:
test_tree.pass_dataloader_through_tree(train_data)

In [12]:
counts = test_tree.get_number_counts()
print(counts.to_string())

       layer_idx  region_idx        1        0    total
0              0           0      0.0      0.0      0.0
1              1           0  27099.0    431.0  27530.0
2              2           0      0.0      0.0      0.0
3              3           0      0.0      0.0      0.0
4              4           0      0.0      0.0      0.0
5              5           0      0.0      0.0      0.0
6              5        4096      0.0      0.0      0.0
7              4         512      0.0      0.0      0.0
8              5         512      0.0      0.0      0.0
9              5        4608      0.0      0.0      0.0
10             4        1024      0.0      0.0      0.0
11             5        1024      0.0      0.0      0.0
12             5        5120      0.0      0.0      0.0
13             4        1536      0.0      0.0      0.0
14             5        1536      0.0      0.0      0.0
15             5        5632      0.0      0.0      0.0
16             4        2048      0.0      0.0  

In [13]:
print(counts.sort_values("layer_idx").to_string())

       layer_idx  region_idx        1        0    total
0              0           0      0.0      0.0      0.0
1              1           0  27099.0    431.0  27530.0
11264          1           7      1.0   3528.0   3529.0
1610           1           1   1354.0   1621.0   2975.0
9655           1           6      0.0      0.0      0.0
6437           1           4    518.0  22580.0  23098.0
3219           1           2      0.0      0.0      0.0
8046           1           5     21.0   4519.0   4540.0
4828           1           3  21007.0  17321.0  38328.0
11265          2           7      0.0      0.0      0.0
1007           2          40      0.0      0.0      0.0
11063          2          62      0.0      0.0      0.0
10661          2          46      0.0      0.0      0.0
11868          2          31      0.0      0.0      0.0
5030           2          11      0.0      0.0      0.0
10862          2          54      0.0      0.0      0.0
10259          2          30      0.0      0.0  

In [14]:
first_layer_counts = counts[counts.layer_idx==1]
second_layer_counts = counts[counts.layer_idx==2]
third_layer_counts = counts[counts.layer_idx==3]


In [15]:
print(np.nansum(first_layer_counts["0"]))
print(np.nansum(first_layer_counts["1"]))

50000.0
50000.0


In [16]:
print(np.nansum(second_layer_counts["0"]))
print(np.nansum(second_layer_counts["1"]))

50000.0
50000.0


In [17]:
print(np.nansum(third_layer_counts["0"]))
print(np.nansum(third_layer_counts["1"]))

50000.0
50000.0


In [18]:
# Replace nan numbers:

# counts["0"] = np.nan_to_num(counts["0"], nan=0.0)
# counts["1"] = np.nan_to_num(counts["1"], nan=0.0)



In [19]:
print(counts.sort_values("layer_idx").to_string())

       layer_idx  region_idx        1        0    total
0              0           0      0.0      0.0      0.0
1              1           0  27099.0    431.0  27530.0
11264          1           7      1.0   3528.0   3529.0
1610           1           1   1354.0   1621.0   2975.0
9655           1           6      0.0      0.0      0.0
6437           1           4    518.0  22580.0  23098.0
3219           1           2      0.0      0.0      0.0
8046           1           5     21.0   4519.0   4540.0
4828           1           3  21007.0  17321.0  38328.0
11265          2           7      0.0      0.0      0.0
1007           2          40      0.0      0.0      0.0
11063          2          62      0.0      0.0      0.0
10661          2          46      0.0      0.0      0.0
11868          2          31      0.0      0.0      0.0
5030           2          11      0.0      0.0      0.0
10862          2          54      0.0      0.0      0.0
10259          2          30      0.0      0.0  

In [20]:
# Make a new column with total points in a region
# column_names = counts.columns.values
# classes = column_names[2:]
# print(classes)
# counts["total"] = counts.apply(lambda row: np.sum([row[classi] for classi in classes]), axis=1)

In [21]:
print(counts.sort_values("layer_idx").to_string())

       layer_idx  region_idx        1        0    total
0              0           0      0.0      0.0      0.0
1              1           0  27099.0    431.0  27530.0
11264          1           7      1.0   3528.0   3529.0
1610           1           1   1354.0   1621.0   2975.0
9655           1           6      0.0      0.0      0.0
6437           1           4    518.0  22580.0  23098.0
3219           1           2      0.0      0.0      0.0
8046           1           5     21.0   4519.0   4540.0
4828           1           3  21007.0  17321.0  38328.0
11265          2           7      0.0      0.0      0.0
1007           2          40      0.0      0.0      0.0
11063          2          62      0.0      0.0      0.0
10661          2          46      0.0      0.0      0.0
11868          2          31      0.0      0.0      0.0
5030           2          11      0.0      0.0      0.0
10862          2          54      0.0      0.0      0.0
10259          2          30      0.0      0.0  