In [None]:
from __future__ import annotations
import numpy as np
import itertools
from multiprocessing import Pool
from tqdm import tqdm
import pandas as pd
from scipy.optimize import linprog
from typing import Dict, List, Optional, Tuple, Iterable
from scipy.optimize import linprog
import sys
import os
import matplotlib.pyplot as plt
import time

sys.path.append('../')
from src_experiment import get_moons_data, train_model, NeuralNet
from geobin import Region, Tree, RegionTree, TreeNode
import torch

In [None]:
# Build one Tree per saved checkpoint (epochs 0, 10, 20, 30, 40) and report how
# long each construction takes, plus the total wall-clock time for the whole loop.
# NOTE(review): `get_test_data` is not among the names imported in the visible
# import cell (get_moons_data, train_model, NeuralNet) — confirm it is defined or
# imported elsewhere, otherwise this cell raises NameError on a fresh kernel.
trees = {}  # epoch number -> constructed Tree
tot_start = time.time()
for epoch in [0,10,20,30,40]:
    # presumably get_test_data() returns a pathlib.Path-like directory (it is joined with "/") — verify
    state_dict_path = get_test_data() / "state_dicts" / f"epoch{epoch}.pth"
    state = torch.load(state_dict_path)
    # The per-epoch timer starts after the checkpoint is loaded, so "Duration"
    # covers Tree creation + construct_tree only; the total below includes loading.
    start = time.time()
    print(f"\n--- Epoch {epoch} ---")
    tree = Tree(state)
    tree.construct_tree(verbose=True)
    trees[epoch] = tree
    end = time.time()
    print(f"Duration: {end-start:.2f} s")
tot_end = time.time()
print(f"Total duration: {tot_end-tot_start:.2f} s")

In [11]:
import numpy as np
import collections

class DeepPartitionExplorer:
    """Explore the activation regions of a randomly initialised ReLU network.

    The network is a stack of fully connected layers, each followed by a ReLU.
    Every input point is labelled by its activation pattern (which neurons have
    a strictly positive pre-activation). The explorer hops across nearby neuron
    boundaries to discover new patterns.

    Attributes:
        input_dim: Dimension of the input space.
        weights: One ``(size, prev_dim)`` weight matrix per layer.
        biases: One ``(size,)`` bias vector per layer.
    """

    def __init__(self, input_dim, layer_sizes, seed=42):
        """Create random layers.

        Args:
            input_dim: Dimension of the input vectors.
            layer_sizes: Iterable with the number of neurons of each layer.
            seed: Seed passed to ``np.random.seed``. Note that this reseeds
                NumPy's *global* legacy random state as a side effect.
        """
        np.random.seed(seed)
        self.input_dim = input_dim
        self.weights = []
        self.biases = []

        # Initialize random weights
        prev_dim = input_dim
        for size in layer_sizes:
            # He-style scaling sqrt(2 / fan_in) for the weights, small biases
            w = np.random.randn(size, prev_dim) * np.sqrt(2/prev_dim)
            b = np.random.randn(size) * 0.1
            self.weights.append(w)
            self.biases.append(b)
            prev_dim = size

    def get_activation_pattern(self, x):
        """Run a forward pass and return the activation pattern of ``x``.

        Args:
            x: Input vector of length ``input_dim``.

        Returns:
            A tuple with one entry per layer; each entry is a tuple of 0/1
            Python ints, where 1 means the neuron's pre-activation is > 0.
            The result is hashable and can be stored in a set.
        """
        code = []
        h = x
        for W, b in zip(self.weights, self.biases):
            pre = h @ W.T + b
            mask = (pre > 0).astype(int)
            # Plain Python ints compare/hash equal to NumPy ints but print cleanly.
            code.append(tuple(mask.tolist()))
            h = pre * mask  # ReLU
        return tuple(code)

    def get_all_effective_boundaries(self, x):
        """Express every neuron's boundary as a hyperplane in input space at ``x``.

        Inside the linear region containing ``x`` each hidden activation is an
        affine function of the input. Composing the layers with the ReLU masks
        active at ``x`` gives, for every neuron, an effective weight row and
        bias such that its pre-activation equals ``w_eff . x + b_eff``. These
        hyperplanes are only exact within the current linear region.

        Neurons whose effective weight norm is <= 1e-6 (e.g. fed only by
        inactive neurons) define no hyperplane and are skipped.

        Args:
            x: Input vector of length ``input_dim``.

        Returns:
            A pair ``(boundaries, distances)``:
            boundaries: list of dicts with keys ``'normal'`` (unit normal in
                input space), ``'dist_val'`` (signed distance from ``x``),
                ``'layer'`` (layer index) and ``'neuron'`` (index in layer).
            distances: 1-D array holding the same signed distances, aligned
                with ``boundaries``.
        """
        x = np.asarray(x, dtype=float)
        boundaries = []
        distances = []

        # Affine map from the input to the current layer's input: W_curr @ x + b_curr.
        # It starts as the identity.
        W_curr = np.eye(self.input_dim)
        b_curr = np.zeros(self.input_dim)

        # Actual activations at x, used to compute the ReLU masks.
        h_curr = x

        for layer_idx, (W, b) in enumerate(zip(self.weights, self.biases)):
            # z = W (W_curr x + b_curr) + b = (W W_curr) x + (W b_curr + b)
            W_effective = W @ W_curr
            b_effective = W @ b_curr + b

            # Row norms turn the pre-activation into a true Euclidean distance.
            norms = np.linalg.norm(W_effective, axis=1)
            valid_mask = norms > 1e-6

            # Signed distance (w.x + b) / ||w||, evaluated only where the norm
            # is non-negligible so dead neurons never cause a 0-division.
            pre_effective = W_effective @ x + b_effective
            dists = np.divide(
                pre_effective,
                norms,
                out=np.zeros_like(pre_effective),
                where=valid_mask,
            )

            for i in np.flatnonzero(valid_mask):
                boundaries.append({
                    'normal': W_effective[i] / norms[i],  # Unit normal
                    'dist_val': dists[i],
                    'layer': layer_idx,
                    'neuron': int(i)
                })
                distances.append(dists[i])

            # ReLU mask at x, computed from the real forward pass.
            z_actual = W @ h_curr + b
            mask = (z_actual > 0).astype(float)
            h_curr = z_actual * mask

            # Next layer sees mask * (W_effective x + b_effective): zero out
            # the rows that belong to inactive neurons.
            W_curr = mask[:, None] * W_effective
            b_curr = mask * b_effective

        return boundaries, np.array(distances)

    def crawl_deep(self, start_point=None, max_regions=1000, steps_per_point=5):
        """Breadth-first search over activation regions.

        From each queued point the nearest effective boundaries are tried in
        order of absolute distance. The point is pushed just past the boundary
        along its normal, and if the resulting activation pattern is new, the
        pattern is recorded and the new point is queued.

        Args:
            start_point: Starting input vector; defaults to the origin.
            max_regions: Upper bound on the number of patterns collected. The
                start pattern is always included.
            steps_per_point: Maximum number of *new* patterns accepted per
                queued point.

        Returns:
            The set of discovered activation patterns (see
            ``get_activation_pattern`` for their format).
        """
        if start_point is None:
            start_point = np.zeros(self.input_dim)
        else:
            start_point = np.asarray(start_point, dtype=float)

        queue = collections.deque([start_point])
        discovered_codes = {self.get_activation_pattern(start_point)}

        print(f"--- Deep Crawler (Max {max_regions} regions) ---")

        epsilon = 1e-4  # How far past a boundary each step lands

        while queue and len(discovered_codes) < max_regions:
            curr_x = queue.popleft()

            # 1. Boundaries of every neuron in every layer, as seen from curr_x
            bounds, dists = self.get_all_effective_boundaries(curr_x)

            # 2. Closest walls first
            sorted_indices = np.argsort(np.abs(dists))

            # 3. Try to cross walls until enough new regions were found
            count = 0
            for idx in sorted_indices:
                # Also stop as soon as the global cap is reached, so the
                # result never exceeds max_regions.
                if count >= steps_per_point or len(discovered_codes) >= max_regions:
                    break

                b_info = bounds[idx]
                dist = b_info['dist_val']
                normal = b_info['normal']

                # Move to the other side of the wall:
                #   x_new = x - (dist + side * epsilon) * normal
                # side is +1 on the active side (dist > 0) and -1 otherwise.
                # A point exactly on the wall (dist == 0) counts as inactive
                # because activation requires pre > 0, so it steps to the
                # active side instead of staying put.
                side = 1.0 if dist > 0 else -1.0
                new_x = curr_x - (dist + side * epsilon) * normal

                # The linearisation is only local, so confirm with a real forward pass.
                new_code = self.get_activation_pattern(new_x)

                if new_code not in discovered_codes:
                    discovered_codes.add(new_code)
                    queue.append(new_x)
                    count += 1

                    # Optional: Print progress
                    if len(discovered_codes) % 50 == 0:
                        print(f"Found {len(discovered_codes)} regions... (Just crossed Layer {b_info['layer']})")

        return discovered_codes

# ==========================================
# Run the Deep Crawler
# ==========================================

# 1. Setup: 2 Inputs -> 50 Hidden -> 20 Hidden -> 20 Hidden
net = DeepPartitionExplorer(input_dim=2, layer_sizes=[50, 20, 20])

# 2. Crawl from the origin (default start point) until at most 10000 patterns are collected
regions = net.crawl_deep(max_regions=10000)

print(f"\nFinal Count: {len(regions)} unique activation patterns found.")

# 3. Verify Complexity
# Inspect one arbitrary code (set iteration order is not meaningful) and print
# its per-layer patterns to see whether the deeper layers are actually changing.
sample = list(regions)[0]
print(f"Sample Code Structure (Tuple of Tuples): \n L1: {sample[0]} \n L2: {sample[1]} \n L3: {sample[2]}")

--- Deep Crawler (Max 10000 regions) ---
Found 50 regions... (Just crossed Layer 1)
Found 100 regions... (Just crossed Layer 0)
Found 150 regions... (Just crossed Layer 1)
Found 200 regions... (Just crossed Layer 0)
Found 250 regions... (Just crossed Layer 0)
Found 300 regions... (Just crossed Layer 0)
Found 350 regions... (Just crossed Layer 1)
Found 400 regions... (Just crossed Layer 0)
Found 450 regions... (Just crossed Layer 0)
Found 500 regions... (Just crossed Layer 0)
Found 550 regions... (Just crossed Layer 1)
Found 600 regions... (Just crossed Layer 2)
Found 650 regions... (Just crossed Layer 0)
Found 700 regions... (Just crossed Layer 0)
Found 750 regions... (Just crossed Layer 0)
Found 800 regions... (Just crossed Layer 2)
Found 850 regions... (Just crossed Layer 0)
Found 900 regions... (Just crossed Layer 0)
Found 950 regions... (Just crossed Layer 1)
Found 1000 regions... (Just crossed Layer 0)
Found 1050 regions... (Just crossed Layer 0)
Found 1100 regions... (Just crosse