## Complete Tiling and Decision Boundary [Part 2b - run big tiling]

In [6]:
from manimlib import *
from functools import partial
import sys, cv2 

sys.path.append('../_2025/backprop_3') #Point to folder where plane_folding_utils.py is
sys.path.append('../')
from geometric_dl_utils import *
from geometric_dl_utils_simplified import *
from polytope_intersection_utils import intersect_polytopes
import matplotlib.pyplot as plt

import matplotlib.patches as mp
# Reversed matplotlib 'tab20' palette; polygon plots index into it with `j % len(tab20_colors_mpl)`.
tab20_colors_mpl = plt.cm.tab20_r.colors

In [7]:
# Pick one trained BaarleNet checkpoint. Alternatives are kept commented out below; to switch models,
# uncomment one group and comment out the active 32x32x32x32 group at the bottom.

#2x2
# model_path='../models/2_2_1.pth'
# model = BaarleNet([2,2])
# model.load_state_dict(torch.load(model_path))
# viz_scales=[0.25, 0.25, 0.3, 0.3, 0.15]
# num_neurons=[2, 2, 2, 2, 2]

#3x3
# model_path='../models/3_3_1.pth'
# model = BaarleNet([3,3])
# model.load_state_dict(torch.load(model_path))
# viz_scales=[0.1, 0.1, 0.05, 0.05, 0.15]
# num_neurons=[3, 3, 3, 3, 2]

#8x8
# model_path='../models/8_8_1.pth'
# model = BaarleNet([8,8])
# model.load_state_dict(torch.load(model_path))
# viz_scales=[0.1, 0.1, 0.05, 0.05, 0.15]
# num_neurons=[8, 8, 8, 8, 2]

### 16 16 16
# model_path='../models/16_16_16_1.pth'
# model = BaarleNet([16, 16, 16])
# model.load_state_dict(torch.load(model_path))
# num_neurons=[16, 16, 16, 16, 16, 16, 2]

### 32 32 32 32
hidden_sizes = [32, 32, 32, 32]
model_path = '../models/32_32_32_32_1.pth'
model = BaarleNet(hidden_sizes)
# The model is only ever fed CPU tensors in this notebook; map_location='cpu' also lets a
# checkpoint saved from a GPU run load on a CPU-only machine.
model.load_state_dict(torch.load(model_path, map_location='cpu'))
# Each hidden width appears twice (presumably one entry per Linear and one per ReLU output),
# followed by the 2 output logits - the same values as the hand-written lists above,
# i.e. [32]*8 + [2] for this model.
num_neurons = [width for width in hidden_sizes for _ in range(2)] + [2]

In [8]:
# Background map for the decision-boundary plots. cv2 loads images as BGR, so channels are
# reordered to RGB. Named baarle_map (not `map`) so Python's builtin map() is not shadowed.
_baarle_map_path = 'Baarle-Nassau_-_Baarle-Hertog-en no legend.png'
baarle_map = cv2.imread(_baarle_map_path)
if baarle_map is None:  # cv2.imread returns None instead of raising when a file can't be read
    raise FileNotFoundError(f'Could not read map image: {_baarle_map_path}')
baarle_map = baarle_map[:, :, (2, 1, 0)]

def viz_descision_boundary(model, res=256, figsize=(6,6)):
    '''Plot the model's predicted class over [-1, 1] x [-1, 1] on top of the grayscale map.

    Args:
        model: callable taking an (N, 2) float tensor of (x, y) points and returning (N, 2) logits.
        res: number of grid samples along each axis; all res**2 points are evaluated in one batch.
        figsize: figure size passed to matplotlib figure 0.

    Returns:
        The matplotlib Axes holding the plot, so callers can draw on top of it.
    '''
    plt.clf()
    fig=plt.figure(0,figsize)
    ax=fig.add_subplot(111)

    # probe[j, k] = (coords[k], coords[j]): column index -> x, row index -> y.
    coords = np.linspace(-1, 1, res)
    grid_x, grid_y = np.meshgrid(coords, coords)
    probe = np.stack([grid_x, grid_y], axis=-1).reshape(res**2, 2)
    with torch.no_grad():
        probe_logits = model(torch.tensor(probe).float())
        probe_logits = probe_logits.detach().numpy().reshape(res, res, 2)

    ax.imshow(baarle_map.mean(2), cmap='gray', extent=[-1, 1, -1, 1])
    # Grid row 0 is y=-1, but imshow draws row 0 at the top, so flip vertically before drawing.
    ax.imshow(np.flipud(np.argmax(probe_logits, 2)),
              extent=[-1, 1, -1, 1],
              alpha=0.7,
              cmap='viridis')
    return ax

In [9]:
def viz_polygon_list(ax, polygon_list, alpha=0.5):
    '''Draw each polygon in polygon_list onto ax as a filled patch.

    Polygon j is colored tab20_colors_mpl[j % len(tab20_colors_mpl)]. Only the first two columns
    (x, y) of each vertex array are used; entries with fewer than 3 vertices are skipped. The view
    is fixed to [-1, 1] x [-1, 1] and the axis frame is hidden.

    Args:
        ax: matplotlib Axes to draw into.
        polygon_list: sequence of (num_vertices, >=2) arrays.
        alpha: patch transparency.
    '''
    for j, p in enumerate(polygon_list):
        if len(p) < 3:
            continue  # fewer than 3 vertices cannot enclose an area
        color = tab20_colors_mpl[j % len(tab20_colors_mpl)]
        poly = mp.Polygon(p[:, :2].tolist(), facecolor=color,
                          edgecolor=color, linewidth=2, alpha=alpha)
        ax.add_patch(poly)
    # Set limits on ax itself: plt.xlim/plt.ylim act on the *current* axes, which need not be ax.
    ax.set_xlim([-1, 1]); ax.set_ylim([-1, 1]); ax.axis('off')

def viz_layer_polygons(polygon_list, fig_size=(6,6)):
    '''Plot a triple-nested polygon list on a near-square grid of subplots (figure 0).

    polygon_list[i] is a list of polygon lists, all drawn into subplot i+1 via viz_polygon_list.
    The grid is ceil(sqrt(len(polygon_list))) subplots on a side.
    '''
    plt.clf()
    fig = plt.figure(0, fig_size)
    side = int(np.ceil(np.sqrt(len(polygon_list))))
    for panel_num, nested_lists in enumerate(polygon_list, start=1):
        panel_ax = fig.add_subplot(side, side, panel_num)
        for polys in nested_lists:
            viz_polygon_list(panel_ax, polys)

To Do: 
- Zero region merging
- Top polytope computation and validation
- Put polygon computation in a nice loop that adapts to length
- OK, things look pretty good here, but with zero regions computed before merging (which I think is the right way to do it), my big network seems to be taking a long time! Maybe I'll just set it up on Linux and see — I'm worried about a combinatorial explosion. There are only so many surfaces I can render in manim, so this filtering might be pretty important.
- OK, let me look at the top-polytope stuff and then be done with this! Then I can kick off a long run on Linux.

In [None]:
polygons={} #dict of all polygons as we go, keyed '<layer_id>.<stage>'.
# Layer -1 is the input plane itself: one square covering [-1, 1] x [-1, 1] at height z=0.
polygons['-1.new_tiling']=[np.array([[-1., -1, 0],
                                    [-1, 1, 0],
                                    [1, 1, 0],
                                    [1, -1, 0]])]

# Small-polygon filter threshold - start by filtering a bit aggressively and see if we can make it to the end.
MIN_POLYGON_AREA = 1e-4

# Presumably model.model is [Linear, ReLU] per hidden layer followed by one output Linear
# (consistent with the 2*layer_id+1 slicing below) - TODO confirm against BaarleNet.
num_hidden_layers = len(model.model)//2

# NOTE: on the 32x32x32x32 model the third layer's retiling ran for ~13 hours and produced
# ~4.5M polygons; consider caching `polygons` to disk before re-running.
for layer_id in range(num_hidden_layers): #Move polygons through layers
    #Run the previous tiling through model.model[:2*layer_id+1] (up to this layer's Linear)
    polygons[f'{layer_id}.linear_out'] = process_with_layers(model.model[:2*layer_id+1],
                                                             polygons[f'{layer_id-1}.new_tiling'])

    #Split polygons w/ ReLU and clip negative values to z=0
    polygons[f'{layer_id}.split_polygons_nested'] = split_polygons_with_relu_simple(polygons[f'{layer_id}.linear_out']) #Triple nested list so we can simplify merging process later.
    polygons[f'{layer_id}.split_polygons_nested_clipped'] = clip_polygons(polygons[f'{layer_id}.split_polygons_nested'])
    #Merge zero regions
    polygons[f'{layer_id}.split_polygons_merged'] = merge_zero_regions(polygons[f'{layer_id}.split_polygons_nested_clipped'])
    #Compute new tiling
    polygons[f'{layer_id}.new_tiling'] = recompute_tiling_general(polygons[f'{layer_id}.split_polygons_merged'])
    print('Retiled plane into ', str(len(polygons[f'{layer_id}.new_tiling'])), ' polygons.')

    #Optional filtering step (overwrites the unfiltered tiling under the same key)
    polygons[f'{layer_id}.new_tiling'] = filter_small_polygons(polygons[f'{layer_id}.new_tiling'], min_area=MIN_POLYGON_AREA)
    print(str(len(polygons[f'{layer_id}.new_tiling'])), ' polygons remaining after filtering out small polygons')

#Last linear layer & output. Keys come from num_hidden_layers rather than the leaked loop variable,
#so a model whose loop never runs doesn't raise NameError or pick up a stale layer_id from an earlier run.
output_key = f'{num_hidden_layers}.linear_out'
polygons[output_key] = process_with_layers(model.model, polygons[f'{num_hidden_layers-1}.new_tiling'])
intersection_lines, new_2d_tiling, upper_polytope, indicator = intersect_polytopes(*polygons[output_key])
my_indicator, my_top_polygons = compute_top_polytope(model, new_2d_tiling)

Retiling plane...


100%|██████████████████████████████████████████| 31/31 [00:00<00:00, 147.92it/s]


Retiled plane into  225  polygons.
204  polygons remaining after filtering out small polygons
Retiling plane...


100%|███████████████████████████████████████████| 31/31 [00:28<00:00,  1.07it/s]


Retiled plane into  1315  polygons.
1061  polygons remaining after filtering out small polygons
Retiling plane...


100%|██████████████████████████████████████| 31/31 [13:11:55<00:00, 1532.76s/it]


Retiled plane into  4474473  polygons.
1864114  polygons remaining after filtering out small polygons


In [None]:
plt.clf()
fig=plt.figure(0, (9,9))
layer0_panels = list(polygons['0.split_polygons_merged'])
# Size the subplot grid to fit every entry: a fixed 4x4 grid makes add_subplot raise once there are
# more than 16 entries (e.g. if there is one entry per neuron of a 32-wide layer). Never smaller than 4x4.
grid_side = max(4, int(np.ceil(np.sqrt(len(layer0_panels)))))
for i, polygons_by_neuron in enumerate(layer0_panels):
    ax=fig.add_subplot(grid_side, grid_side, i+1)
    # unravelled=[item for sublist in polygons_by_neuron for item in sublist]
    viz_polygon_list(ax, polygons_by_neuron)

In [None]:
# Full retiling of the input plane after the first hidden layer.
plt.clf()
fig = plt.figure(num=0, figsize=(6, 6))
ax = fig.add_subplot(1, 1, 1)
viz_polygon_list(ax, polygons['0.new_tiling'])

In [None]:
plt.clf()
fig=plt.figure(0, (9,9))
layer1_panels = list(polygons['1.split_polygons_merged'])
# Size the subplot grid to fit every entry: a fixed 4x4 grid makes add_subplot raise once there are
# more than 16 entries (e.g. if there is one entry per neuron of a 32-wide layer). Never smaller than 4x4.
grid_side = max(4, int(np.ceil(np.sqrt(len(layer1_panels)))))
for i, polygons_by_neuron in enumerate(layer1_panels):
    ax=fig.add_subplot(grid_side, grid_side, i+1)
    # unravelled=[item for sublist in polygons_by_neuron for item in sublist]
    viz_polygon_list(ax, polygons_by_neuron)

In [None]:
# Full retiling of the input plane after the second hidden layer.
plt.clf()
fig = plt.figure(num=0, figsize=(6, 6))
ax = fig.add_subplot(1, 1, 1)
viz_polygon_list(ax, polygons['1.new_tiling'])

In [None]:
ax=viz_descision_boundary(model)
# Overlay the segments returned by intersect_polytopes (presumably the decision boundary) in magenta.
for seg in intersection_lines:
    xs, ys = seg[:, 0], seg[:, 1]
    ax.plot(xs, ys, color='m', linewidth=3)

In [None]:
# Final 2D tiling (colored by index) with the intersection lines dashed on top.
plt.clf()
fig=plt.figure(0, (6,6))
ax=fig.add_subplot(111)
for j, p in enumerate(new_2d_tiling):
    if len(p) >= 3:
        color = tab20_colors_mpl[j % len(tab20_colors_mpl)]
        poly = mp.Polygon(p[:, :2].tolist(), facecolor=color,
                          edgecolor=color, linewidth=2, alpha=0.5)
        ax.add_patch(poly)
for seg in intersection_lines:
    ax.plot(seg[:, 0], seg[:, 1], 'm--', linewidth=3)
ax.set_xlim([-1, 1]); ax.set_ylim([-1, 1]);

In [None]:
# Top polytope regions: yellow where my_indicator[j] is truthy, blue otherwise.
plt.clf()
fig=plt.figure(0, (8,8))
ax=fig.add_subplot(111)
for j, p in enumerate(my_top_polygons):
    if len(p) < 3:
        continue
    color = 'y' if my_indicator[j] else 'b'
    poly = mp.Polygon(p[:, :2].tolist(), facecolor=color,
                      edgecolor=color, linewidth=1, alpha=0.5)
    ax.add_patch(poly)
ax.set_xlim([-1, 1]); ax.set_ylim([-1, 1]); ax.axis('off')

NICE!