In [1]:
import os
import sys
import time
from typing import Dict, List

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import tqdm

# Presumably the project-local modules imported below live in the parent directory.
sys.path.append('../')

from src_experiment import (
    QUANTITIES_TO_ESTIMATE,
    DivergenceEngine,
    NeuralNet,
    get_moons_data,
    get_wbc_data,
    moon_path,
    wbc_path,
)
from geobin import RegionTree, TreeNode
from visualization import plot_all_quantities

In [2]:
# Test parameters selecting which trained run to analyse.
experiment = "moon"  # one of the keys of EXPERIMENT_FNS below
arch = "small"
dropout = 0.0
noise = 0.0
run_number = 0

# Experiment name -> (checkpoint-path helper, data-loader helper).
# An explicit table makes a misspelled `experiment` fail loudly instead of
# silently falling back to the WBC helpers.
EXPERIMENT_FNS = {
    "moon": (moon_path, get_moons_data),
    "wbc": (wbc_path, get_wbc_data),
}
if experiment not in EXPERIMENT_FNS:
    raise ValueError(
        f"unknown experiment {experiment!r}; expected one of {sorted(EXPERIMENT_FNS)}"
    )
path_fn, data_fn = EXPERIMENT_FNS[experiment]

In [3]:
# Checkpoint epochs to analyse: every 5th epoch plus 74
# (presumably the last saved epoch of the run — confirm against training config).
epochs = [0, 5, 10, 15, 20, 25, 30, 35, 40, 45, 50, 55, 60, 65, 70, 74]

# path_fn(...) does not depend on the epoch, so resolve the run's checkpoint
# directory once instead of once per comprehension iteration.
state_dict_dir = path_fn(arch, dropout, noise, run_number) / "state_dicts"
state_dicts = {epoch: state_dict_dir / f"epoch{epoch}.pth" for epoch in epochs}

In [4]:
# Get the testing data
# data_fn returns two loaders; only the second is kept and is what gets passed
# through every RegionTree below (presumably the test split — confirm in data_fn).
_, data = data_fn(noise, batch_size=32)
# An empty loader would silently produce all-zero number counts downstream.
assert len(data) > 0, "data loader is empty"
len(data)  # number of batches

7

In [5]:
# Number counts per epoch on trees built WITHOUT the feasibility check.
# In the saved run this gives 1608 rows at epoch 0 (vs 91 with the check),
# most of them with zero counts — presumably regions the check would prune.
ncounts_per_epoch = {}
for epoch in tqdm.tqdm(epochs, desc="unchecked trees"):
    # map_location="cpu" lets checkpoints saved on a GPU load on CPU-only machines.
    state_dict = torch.load(state_dicts[epoch], map_location="cpu")
    tree = RegionTree(state_dict)
    tree.build_tree(verbose=False, check_feasibility=False)
    tree.pass_dataloader_through_tree(data)
    tree.collect_number_counts()
    ncounts_per_epoch[epoch] = tree.get_number_counts()

100%|██████████| 16/16 [00:00<00:00, 24.09it/s]


In [6]:
# Number counts per epoch on trees built WITH the feasibility check
# (slower to build: ~4.5 it/s vs ~24 it/s unchecked in the saved run).
ncounts_per_epoch_checked = {}
for epoch in tqdm.tqdm(epochs, desc="checked trees"):
    # map_location="cpu" lets checkpoints saved on a GPU load on CPU-only machines.
    state_dict = torch.load(state_dicts[epoch], map_location="cpu")
    tree = RegionTree(state_dict)
    tree.build_tree(verbose=False, check_feasibility=True)
    tree.pass_dataloader_through_tree(data)
    tree.collect_number_counts()
    ncounts_per_epoch_checked[epoch] = tree.get_number_counts()

100%|██████████| 16/16 [00:03<00:00,  4.50it/s]


In [7]:
# Row count of the epoch-0 number-count table (unchecked tree).
ncounts_per_epoch[0].shape[0]

1608

In [8]:
# Epoch-0 number-count table for the unchecked tree (pandas truncates the display);
# most rows in the saved output have zero counts.
ncounts_per_epoch[0]

Unnamed: 0,layer_idx,region_idx,1,0,total
1,1,0,19.0,6.0,25.0
2,2,0,0.0,0.0,0.0
3,3,0,0.0,0.0,0.0
4,4,0,0.0,0.0,0.0
5,4,512,0.0,0.0,0.0
...,...,...,...,...,...
1604,4,447,0.0,0.0,0.0
1605,4,959,0.0,0.0,0.0
1606,3,511,0.0,0.0,0.0
1607,4,511,0.0,0.0,0.0


In [9]:
# Row count of the epoch-0 number-count table (feasibility-checked tree).
ncounts_per_epoch_checked[0].shape[0]

91

In [10]:
# Epoch-0 number-count table for the feasibility-checked tree
# (91 rows here vs 1608 for the unchecked tree in the saved run).
ncounts_per_epoch_checked[0]

Unnamed: 0,layer_idx,region_idx,1,0,total
1,1,0,19.0,6.0,25.0
2,2,14,19.0,6.0,25.0
3,3,79,19.0,6.0,25.0
4,4,3,19.0,6.0,25.0
5,2,21,0.0,0.0,0.0
...,...,...,...,...,...
87,3,113,5.0,15.0,20.0
88,4,21,3.0,5.0,8.0
89,4,52,2.0,10.0,12.0
90,3,132,0.0,0.0,0.0


In [13]:
class EstimateQuantities1Run:
    """
    Experiment-specific wrapper for one training run that:
    - takes per-epoch number counts,
    - runs ``DivergenceEngine(frame).compute()`` for each epoch,
    - tags every returned frame with an ``epoch`` column and concatenates
      the epochs into one DataFrame per estimated quantity.

    All of the math is delegated to ``DivergenceEngine``.

    Parameters
    ----------
    ncounts : dict
        Mapping ``epoch -> number-count table`` as built by the counting
        cells (``tree.get_number_counts()``); the tables are presumably
        DataFrames — confirm in ``RegionTree``.
    """

    def __init__(self, ncounts: Dict[int, pd.DataFrame]):
        self.ncounts = ncounts
        # quantity name -> DataFrame with an "epoch" column first.
        self.estimates: Dict[str, pd.DataFrame] = {}

        # Perform calculations
        self.calculate_estimates()

    # ------------------------------------------------------------------

    def calculate_estimates(self) -> None:
        """
        (Re)compute ``self.estimates`` from ``self.ncounts``.

        Safe to call repeatedly: per-epoch frames are gathered into fresh
        local lists, so a second call rebuilds the results instead of trying
        to append to already-concatenated DataFrames. Frames returned by
        ``DivergenceEngine`` are copied before the ``epoch`` column is
        inserted, so the engine's outputs are never mutated.

        Every name in ``QUANTITIES_TO_ESTIMATE`` gets an entry. A quantity
        with no per-epoch frames (e.g. empty ``ncounts``) maps to an empty
        DataFrame with only an ``epoch`` column rather than making
        ``pd.concat`` raise. Keys returned by the engine that are not in
        ``QUANTITIES_TO_ESTIMATE`` are kept instead of raising ``KeyError``.
        """
        frames_per_quantity: Dict[str, List[pd.DataFrame]] = {
            q: [] for q in QUANTITIES_TO_ESTIMATE
        }

        for epoch, frame in self.ncounts.items():
            # NOTE(review): the saved run printed warnings from divisions inside
            # DivergenceEngine (e.g. `self.m_kw / self.m_w`); likely zero-count
            # regions producing NaN/inf — check the estimates for non-finite values.
            epoch_results = DivergenceEngine(frame).compute()

            for key, df in epoch_results.items():
                tagged = df.copy()
                tagged.insert(0, "epoch", epoch)
                frames_per_quantity.setdefault(key, []).append(tagged)

        # Concatenate epochs
        self.estimates = {
            key: (
                pd.concat(frames, ignore_index=True).rename_axis(None, axis=1)
                if frames
                else pd.DataFrame(columns=["epoch"])
            )
            for key, frames in frames_per_quantity.items()
        }

    def get_estimates(self) -> Dict[str, pd.DataFrame]:
        """Return the mapping ``quantity name -> per-epoch estimates DataFrame``."""
        return self.estimates


In [20]:
# est1: estimates from the unchecked trees; est2: from the feasibility-checked trees.
est1 = EstimateQuantities1Run(ncounts=ncounts_per_epoch)
est2 = EstimateQuantities1Run(ncounts=ncounts_per_epoch_checked)

  logterm = self.m_kw / (self.m_w @ self.m_k)
  term = self.m_kw / (self.m_w @ self.m_k)
  term = self.m_kw / self.m_w
  logterm = self.m_kw / (self.m_w @ self.m_k)
  term = self.m_kw / (self.m_w @ self.m_k)
  term = self.m_kw / self.m_w
  logterm = self.m_kw / (self.m_w @ self.m_k)
  term = self.m_kw / (self.m_w @ self.m_k)
  term = self.m_kw / self.m_w
  logterm = self.m_kw / (self.m_w @ self.m_k)
  term = self.m_kw / (self.m_w @ self.m_k)
  term = self.m_kw / self.m_w
  logterm = self.m_kw / (self.m_w @ self.m_k)
  term = self.m_kw / (self.m_w @ self.m_k)
  term = self.m_kw / self.m_w
  logterm = self.m_kw / (self.m_w @ self.m_k)
  term = self.m_kw / (self.m_w @ self.m_k)
  term = self.m_kw / self.m_w
  logterm = self.m_kw / (self.m_w @ self.m_k)
  term = self.m_kw / (self.m_w @ self.m_k)
  term = self.m_kw / self.m_w
  logterm = self.m_kw / (self.m_w @ self.m_k)
  term = self.m_kw / (self.m_w @ self.m_k)
  term = self.m_kw / self.m_w
  logterm = self.m_kw / (self.m_w @ self.m_k)
  

In [None]:
# Plot every estimated quantity vs. epoch: unchecked trees first, then checked.
# NOTE: the saved "AttributeError: 'dict' object has no attribute 'get_estimates'"
# below came from stale kernel state — est1/est2 had been rebound to a dict by a
# cell that is presumably no longer in the notebook (see the non-sequential
# execution counts). Restart & Run All so est1/est2 are the objects built above.
plot_all_quantities(est1.get_estimates())
plot_all_quantities(est2.get_estimates())

AttributeError: 'dict' object has no attribute 'get_estimates'