In [1]:
import sys
import imageio
sys.path.append('../../')
import torch
import tqdm
import glob
import numpy as np
import matplotlib.pyplot as plt

from experiments.visuals.plot_decision_boundary import plot_decision_boundary
from experiments.visuals.plot_metric import plot_riemann_metric
from experiments.visuals.plot_diff_geom import plot_geometry
from models.supervised.mlp.model import MLP
from models.data.sklearn_datasets import MoonDataset, SpiralDataset, BlobsDataset, CirclesDataset

In [2]:
# Select the toy 2-D dataset by name; later cells read both `mode` and `dataset`.
mode = 'moon'
dataset_classes = {
    'moon': MoonDataset,
    'spiral': SpiralDataset,
    'blobs': BlobsDataset,
    'circles': CirclesDataset,
}
if mode not in dataset_classes:
    # Without this guard an unknown mode leaves `dataset` undefined (or stale
    # from a previous run) and the failure only surfaces in a later cell.
    raise ValueError(f"Unknown mode {mode!r}; expected one of {sorted(dataset_classes)}")
dataset = dataset_classes[mode](n_samples=1000, noise=0.01)

# NOTE(review): meaning of the positional MLP arguments is defined in
# models.supervised.mlp.model; the repr printed by this cell shows 7 layers,
# each Linear(2 -> 2), six with Tanh and a final Sigmoid.
model = MLP(2, 7, 2, 2)

model.eval()



MLP(
  (layers): ModuleList(
    (0-5): 6 x Layer(
      (act_func): Tanh()
      (linear_map): Linear(in_features=2, out_features=2, bias=True)
    )
    (6): Layer(
      (act_func): Sigmoid()
      (linear_map): Linear(in_features=2, out_features=2, bias=True)
    )
  )
)

In [4]:
%%capture
# Get all corresponding files. Glob and load MUST use the same directory:
# the load path previously omitted `2_wide/`, so frames were rendered from a
# different checkpoint set than the one whose files were counted (and than
# the one the Riemannian-metric and geometry cells use).
model_dir = f'../../models/supervised/mlp/saved_models/2_wide/mlp_{mode}'
imgs = []
files = glob.glob(f'{model_dir}/*')
# Assumes the folder holds exactly model_0.pth ... model_{N-1}.pth and no
# other files — TODO confirm.
for epoch in tqdm.tqdm(range(len(files))):
    model.load_state_dict(torch.load(f'{model_dir}/model_{epoch}.pth'))
    img = plot_decision_boundary(model, dataset.X, dataset.y, epoch, f"./figures/decision_boundaries/mlp_{mode}")
    imgs.append(img)


In [5]:
# Stitch the per-epoch decision-boundary frames collected above into one GIF.
# NOTE(review): the unit of `duration` depends on the imageio version/plugin
# (seconds in older releases, milliseconds in newer ones) — confirm.
boundary_gif_path = f'./figures/decision_boundary_{mode}.gif'
imageio.mimsave(boundary_gif_path, imgs, duration=50/3)


In [17]:
%%capture
# Fresh model for the Riemannian-metric animation. (The former
# `sigma = 0.05` was never read anywhere and has been removed.)
model = MLP(2, 7, 2, 2)
model.eval()

# Get all corresponding files; glob and load share one directory so the
# epoch count and the loaded checkpoints cannot drift apart.
model_dir = f'../../models/supervised/mlp/saved_models/2_wide/mlp_{mode}'
imgs = []
files = glob.glob(f'{model_dir}/*')
for epoch in tqdm.tqdm(range(len(files))):
    model.load_state_dict(torch.load(f'{model_dir}/model_{epoch}.pth'))
    img = plot_riemann_metric(model, dataset.X, dataset.y, epoch, f"./figures/curvature/mlp_{mode}")
    imgs.append(img)


In [17]:
# Animate the per-epoch Riemannian-metric frames into a single GIF.
metric_gif_path = f'./figures/riemannian_metric_{mode}.gif'
imageio.mimsave(metric_gif_path, imgs, duration=50/3)


In [18]:
%%capture
# Split the points by class label; plot_geometry takes the two classes separately.
indices_zero = np.where(dataset.y == 0)[0]
indices_one = np.where(dataset.y == 1)[0]

# create arrays based on these indices
x_a = dataset.X[indices_zero]
x_b = dataset.X[indices_one]

model = MLP(2, 7, 2, 2)
# Match the other animation cells, which all put the model in eval mode
# (load_state_dict below does not change the train/eval flag).
model.eval()
grid_dim = 50
model_dir = f'../../models/supervised/mlp/saved_models/2_wide/mlp_{mode}'
files = glob.glob(f'{model_dir}/*')

imgs = []
for epoch in tqdm.tqdm(range(len(files))):
    model.load_state_dict(torch.load(f'{model_dir}/model_{epoch}.pth'))
    img = plot_geometry(model, x_a, x_b, epoch, grid_dim, grid_dim,
                        plot_rows=1, plot_cols=model.num_layers, plot_grids_=True,
                        plot_tensors_=False, plot_classifier_=True,
                        save_folder=f"./figures/geometry/mlp_{mode}")
    imgs.append(img)
    



In [19]:
# Combine the per-epoch geometry frames into an animated GIF.
geometry_gif_path = f'./figures/geometry_{mode}.gif'
imageio.mimsave(geometry_gif_path, imgs, duration=50/3)


In [11]:
# OVERFIT A 2D MODEL
# Plot decision boundary, Riemannian metric and layer geometry for a separately
# trained MLP(2, 7, 2, 1) at every 1000th epoch (epochs 999 ... 9999).
# NOTE(review): checkpoints are always read from `overfit/mlp_moon`, while
# `dataset` follows `mode`; the plots are only meaningful when mode == 'moon'.
indices_zero = np.where(dataset.y == 0)[0]
indices_one = np.where(dataset.y == 1)[0]

# create arrays based on these indices
x_a = dataset.X[indices_zero]
x_b = dataset.X[indices_one]

model = MLP(2, 7, 2, 1)
# Consistent with the other cells: evaluate in eval mode.
model.eval()

grid_dim = 50

overfit_dir = '../../models/supervised/mlp/saved_models/overfit/mlp_moon'
epochs = list(range(999, 10000, 1000))  # 999, 1999, ..., 9999
for epoch in epochs:
    model.load_state_dict(torch.load(f'{overfit_dir}/model_{epoch}.pth'))
    plot_decision_boundary(model, dataset.X, dataset.y, epoch, "./figures/decision_boundaries/mlp_overfit")
    plot_riemann_metric(model, dataset.X, dataset.y, epoch, "./figures/curvature/mlp_overfit")
    plot_geometry(model, x_a, x_b, epoch, grid_dim, grid_dim,
                  plot_rows=1, plot_cols=model.num_layers, plot_grids_=True,
                  plot_tensors_=False, plot_classifier_=True,
                  save_folder="./figures/geometry/mlp_overfit")

