In [1]:
import sys
import imageio
sys.path.append('../../')
import torch
import numpy as np
import pickle as pkl

from experiments.visuals.plot_decision_boundary import plot_decision_boundary
from experiments.visuals.plot_metric import plot_riemann_metric
from experiments.visuals.plot_diff_geom import plot_geometry
from models.supervised.mlp.model import MLP
from models.supervised.bimt.model import BioMLP



In [2]:
# Seed numpy's global RNG and torch's default generator with the same value so reruns are repeatable.
SEED = 2
np.random.seed(SEED)
torch.manual_seed(SEED)

<torch._C.Generator at 0x7fb3a8313a50>

In [12]:
# Active configuration: which trained family (`model_name`), dataset (`mode`) and run folder (`size`)
# the cells below visualise.
model_name = "bimt"
mode = 'moon'
size="vanilla"

models_path = f"../../models/supervised/{model_name}/saved_models"

# NOTE: pickle.load can execute arbitrary code -- only open dataset files produced by this project's
# own training runs.
with open(f"{models_path}/{size}/{mode}/dataset.pkl", 'rb') as f:
    dataset = pkl.load(f)

# Build the architecture matching the checkpoints under `models_path`. An unknown configuration now
# fails loudly instead of leaving `model` undefined (NameError below) or silently reusing a `model`
# left over from an earlier execution of this cell.
if model_name == "bimt":
    # This cell builds the same BioMLP for every `size`; `size` only selects the run folder.
    model = BioMLP(shp=[2,20,20,2])
elif model_name == "mlp":
    mlp_args_by_size = {
        "vanilla": (2, 7, 10, 2),
        "overfit": (2, 7, 2, 1),
        "2_wide": (2, 7, 2, 2),
    }
    if size not in mlp_args_by_size:
        raise ValueError(f"Unknown MLP size {size!r}; expected one of {sorted(mlp_args_by_size)}")
    model = MLP(*mlp_args_by_size[size])
else:
    raise ValueError(f"Unknown model_name {model_name!r}; expected 'bimt' or 'mlp'")
model.eval()



BioMLP(
  (layers): ModuleList(
    (0): BioLinear(
      (linear): Linear(in_features=2, out_features=20, bias=True)
    )
    (1): BioLinear(
      (linear): Linear(in_features=20, out_features=20, bias=True)
    )
    (2): BioLinear(
      (linear): Linear(in_features=20, out_features=2, bias=True)
    )
  )
)

In [14]:
%%capture

# Get all corresponding files:
imgs = []
for epoch in tqdm.tqdm(range(0, 10000, 250)):
    model.load_state_dict(torch.load(f'../../models/supervised/{model_name}/saved_models/{size}/{mode}/model_{epoch}.pth'))
    img = plot_decision_boundary(model, dataset.X, dataset.y, epoch, f"./figures/{model_name}/decision_boundaries/{mode}/{size}")
    imgs.append(img)


In [15]:
# Stitch the decision-boundary frames collected above into an animated GIF.
# NOTE(review): the GIF path has no `size` component (the frames do), so runs for different sizes
# overwrite each other's GIF -- confirm that is intended.
# NOTE(review): the unit of `duration` has changed between imageio releases (seconds in older GIF
# writers, milliseconds in the pillow plugin) -- confirm 50/3 gives the intended frame rate.
boundary_gif_path = f'./figures/{model_name}/decision_boundaries/{mode}/decision_boundary_{mode}.gif'
imageio.mimsave(boundary_gif_path, imgs, duration=50/3)


In [8]:
%%capture
sigma = 0.05
model = MLP(2, 7, 2, 2)
model.eval()

# Get all corresponding files:
imgs = []
for epoch in tqdm.tqdm(range(200)):
    model.load_state_dict(torch.load(f'../../models/supervised/{model_name}/saved_models/{size}/{mode}/model_{epoch}.pth'))
    img = plot_riemann_metric(model, dataset.X, dataset.y, epoch, f"./figures/{model_name}/curvature/{mode}/{size}")
    imgs.append(img)


In [9]:
# Stitch the Riemannian-metric frames collected above into an animated GIF.
# NOTE(review): the GIF path has no `size` component, so runs for different sizes overwrite it.
# NOTE(review): `duration` units differ across imageio releases -- confirm 50/3 is the intended value.
metric_gif_path = f'./figures/{model_name}/curvature/{mode}/riemannian_metric_{mode}.gif'
imageio.mimsave(metric_gif_path, imgs, duration=50/3)



In [6]:
%%capture
indices_zero = np.where(dataset.y == 0)[0]
indices_one = np.where(dataset.y == 1)[0]

# create arrays based on these indices
x_a = dataset.X[indices_zero]
x_b = dataset.X[indices_one]

model = MLP(2, 7, 2, 2)
grid_dim = 50


imgs = []
for epoch in tqdm.tqdm(range(200)):
    model.load_state_dict(torch.load(f'../../models/supervised/{model_name}/saved_models/{size}/{mode}/model_{epoch}.pth'))
    img = plot_geometry(model, x_a, x_b, epoch, grid_dim, grid_dim, 
                        plot_rows=1, plot_cols=model.num_layers+1, plot_grids_=True, 
                        plot_tensors_=False, plot_classifier_=True, 
                        save_folder=f"./figures/{model_name}/geometry/{mode}/{size}")
                        
    imgs.append(img)
    



In [7]:
# Stitch the per-epoch geometry frames collected above into an animated GIF.
# NOTE(review): the GIF path has no `size` component, so runs for different sizes overwrite it.
# NOTE(review): `duration` units differ across imageio releases -- confirm 50/3 is the intended value.
geometry_gif_path = f'./figures/{model_name}/geometry/{mode}/geometry_{mode}.gif'
imageio.mimsave(geometry_gif_path, imgs, duration=50/3)


In [None]:
# OVERFIT A 2D MODEL
# Plot decision boundary, Riemannian metric and layer geometry at epochs 999, 1999, ..., 9999.
# MLP(2, 7, 2, 1) is the architecture the config cell builds for model_name="mlp", size="overfit".
# The checkpoints, the dataset and all figure folders are therefore pinned to that run, rather than
# to the active `model_name`/`size`, which may describe a different architecture whose checkpoints
# would not load into this model.
overfit_model_name = "mlp"
overfit_size = "overfit"
overfit_run_path = f"../../models/supervised/{overfit_model_name}/saved_models/{overfit_size}/{mode}"

# NOTE: pickle.load can execute arbitrary code -- only open files produced by this project's training runs.
with open(f"{overfit_run_path}/dataset.pkl", 'rb') as f:
    overfit_dataset = pkl.load(f)

# Split the inputs by class label; plot_geometry receives the two classes separately.
indices_zero = np.where(overfit_dataset.y == 0)[0]
indices_one = np.where(overfit_dataset.y == 1)[0]
x_a = overfit_dataset.X[indices_zero]
x_b = overfit_dataset.X[indices_one]

model = MLP(2, 7, 2, 1)
model.eval()  # consistent with the config cell, which puts its model in eval mode

grid_dim = 50

epochs = list(range(999, 10000, 1000))  # 999, 1999, ..., 9999
for epoch in epochs:
    # map_location="cpu" lets checkpoints written from a GPU process load on a CPU-only kernel.
    model.load_state_dict(torch.load(f"{overfit_run_path}/model_{epoch}.pth", map_location="cpu"))
    plot_decision_boundary(model, overfit_dataset.X, overfit_dataset.y, epoch,
                           f"./figures/{overfit_model_name}/decision_boundaries/{mode}/{overfit_size}")
    plot_riemann_metric(model, overfit_dataset.X, overfit_dataset.y, epoch,
                        f"./figures/{overfit_model_name}/curvature/{mode}/{overfit_size}")
    # NOTE(review): the per-epoch geometry cell passes plot_cols=model.num_layers+1 with the same
    # flags; confirm which panel count is intended here.
    plot_geometry(model, x_a, x_b, epoch, grid_dim, grid_dim,
                  plot_rows=1, plot_cols=model.num_layers, plot_grids_=True,
                  plot_tensors_=False, plot_classifier_=True,
                  save_folder=f"./figures/{overfit_model_name}/geometry/{mode}/{overfit_size}")

