In [1]:
import sys
import os
sys.path.append("../../../")

import torch
import numpy as np

from experiments.assumptions.metric_comparison import script
from models.supervised.mlp.model import MLP
from models.data.sklearn_datasets import MoonDataset, BlobsDataset, SpiralDataset, CirclesDataset
from utils.metrics.metrics import compute_cosine_score, compute_magnitude_score


# Root directory of the saved MLP checkpoints. The path is relative, so it is resolved against
# the notebook's working directory; the cells below append
# "<variant>/mlp_<dataset_name>/model_<epoch>.pth" to it when calling torch.load.
models_path = "../../../models/supervised/mlp/saved_models"

# We want to measure the error from using the pullback metric instead of computing the Riemannian
# metric for each layer.

# We will use the Moon dataset for this experiment and a cosine score to measure the
# discrepancy between the two metrics, since the cosine is scale-invariant.

In [2]:
# Seed NumPy's and PyTorch's global random number generators with the same fixed value.
np.random.seed(2)
# Kept as the last expression of the cell: torch.manual_seed returns the torch Generator,
# which is the value the cell displays.
torch.manual_seed(2)

<torch._C.Generator at 0x7fb81ca9c490>

In [3]:
# ---------------------------------------------------------------------------
# Experiment configuration for the single-run cells that follow.
#   dataset_name: one of 'moon', 'blobs', 'spiral', 'circles'
#   size:         one of 'skinny', 'overfit', 'vanilla' (selects architecture + checkpoint folder)
#   epoch:        index of the checkpoint file model_<epoch>.pth to load
# ---------------------------------------------------------------------------
dataset_name = 'moon'
size = "skinny"
epoch = 199

# Instantiate the selected dataset (1000 samples, noise 0.01 in every case).
if dataset_name == 'moon':
    dataset = MoonDataset(n_samples=1000, noise=0.01)
elif dataset_name == 'blobs':
    dataset = BlobsDataset(n_samples=1000, noise=0.01)
elif dataset_name == 'spiral':
    dataset = SpiralDataset(n_samples=1000, noise=0.01)
elif dataset_name == 'circles':
    dataset = CirclesDataset(n_samples=1000, noise=0.01)
else:
    # Fail loudly: without this branch `dataset` would stay undefined, or worse, keep the
    # value left over from a previous execution of this cell.
    raise ValueError(
        f"Unknown dataset_name {dataset_name!r}; expected 'moon', 'blobs', 'spiral' or 'circles'."
    )

# Build the architecture matching the checkpoint and load its weights.
# NOTE(review): the meaning of the MLP positional arguments is defined in
# models.supervised.mlp.model (not visible here); presumably the third argument is the hidden
# width, given the "2_wide" folder name -- confirm.
if size == "skinny":
    model = MLP(2,7,2,2)
    model.load_state_dict(torch.load(f'{models_path}/2_wide/mlp_{dataset_name}/model_{epoch}.pth'))
elif size == "overfit":
    model = MLP(2,7,2,1)
    model.load_state_dict(torch.load(f'{models_path}/overfit/mlp_{dataset_name}/model_{epoch}.pth'))
    # Detach the last layer: it stays available as `final_layer`, and the model keeps only
    # the layers before it.
    final_layer = model.layers[-1]
    model.layers = model.layers[:-1]
elif size == "vanilla":
    model = MLP(2,7,10,2)
    model.load_state_dict(torch.load(f'{models_path}/vanilla/mlp_{dataset_name}/model_{epoch}.pth'))
else:
    # Same reasoning as for the dataset: never continue with a missing or stale `model`.
    raise ValueError(
        f"Unknown size {size!r}; expected 'skinny', 'overfit' or 'vanilla'."
    )

# Output folder for this configuration. os.makedirs creates the intermediate
# figures/<epoch> and figures/<epoch>/<dataset_name> directories as needed.
os.makedirs(f'figures/{epoch}/{dataset_name}/{size}', exist_ok=True)

In [4]:
# Local metric, produced with three plot methods. The calls share the model, the data, the
# epoch and the save path; they differ only in `plot_method` and `N`.
# NOTE(review): N presumably controls the evaluation grid resolution -- confirm in script.local_plot.
local_g_diag_lattice = script.local_plot(
    model, dataset.X, dataset.y,
    epoch=epoch,
    save_path=f"{dataset_name}/{size}",
    N=50,
    plot_method='lattice_diagonal',
)
local_g_lattice = script.local_plot(
    model, dataset.X, dataset.y,
    epoch=epoch,
    save_path=f"{dataset_name}/{size}",
    N=15,
    plot_method='lattice',
)
local_g_surface = script.local_plot(
    model, dataset.X, dataset.y,
    epoch=epoch,
    save_path=f"{dataset_name}/{size}",
    N=50,
    plot_method='surface',
)

In [4]:
# Pullback metric, produced with the same three plot methods. Each call returns a pair that is
# unpacked into (metric, grid); the 'lattice' pair is the one compared against the full
# pullback metric in the scoring cell.
pullback_g_diag_lattice, grid_diag_lattice = script.pullback_plot(
    model, dataset.X, dataset.y,
    epoch=epoch,
    save_path=f"{dataset_name}/{size}",
    N=50,
    plot_method='lattice_diagonal',
)
pullback_g_lattice, grid_lattice = script.pullback_plot(
    model, dataset.X, dataset.y,
    epoch=epoch,
    save_path=f"{dataset_name}/{size}",
    N=15,
    plot_method='lattice',
)
pullback_g_surface, grid_surface = script.pullback_plot(
    model, dataset.X, dataset.y,
    epoch=epoch,
    save_path=f"{dataset_name}/{size}",
    N=50,
    plot_method='surface',
)


  s = (x.conj() * x).real


In [4]:
# Full pullback metric, produced with the same three plot methods and the same N values as
# the pullback cell, so the 'lattice' results of both cells can be compared directly.
full_pullback_g_diag_lattice, full_grid_diag_lattice = script.full_pullback_plot(
    model, dataset.X, dataset.y,
    epoch=epoch,
    save_path=f"{dataset_name}/{size}",
    N=50,
    plot_method='lattice_diagonal',
)
full_pullback_g_lattice, full_grid_lattice = script.full_pullback_plot(
    model, dataset.X, dataset.y,
    epoch=epoch,
    save_path=f"{dataset_name}/{size}",
    N=15,
    plot_method='lattice',
)
full_pullback_g_surface, full_grid_surface = script.full_pullback_plot(
    model, dataset.X, dataset.y,
    epoch=epoch,
    save_path=f"{dataset_name}/{size}",
    N=50,
    plot_method='surface',
)

In [17]:
# Compare the pullback metric with the full pullback metric on the 'lattice' grid (N=15):
# one cosine score and one magnitude score.
cosine_score = compute_cosine_score(
    pullback_g_lattice,
    full_pullback_g_lattice,
)
magnitude_score = compute_magnitude_score(
    pullback_g_lattice,
    full_pullback_g_lattice,
)

# Plot both scores as a heatmap over the lattice grid, then as a violin plot.
# Note that these helpers take `save_name`, not `save_path` like the *_plot helpers.
script.plot_err_heatmap(
    cosine_score, magnitude_score, full_grid_lattice,
    model, dataset.X, dataset.y,
    N=15,
    save_name=f"{dataset_name}/{size}",
    epoch=epoch,
)
script.violin_plot(
    cosine_score, magnitude_score,
    save_name=f"{dataset_name}/{size}",
    epoch=epoch,
)

In [3]:
def _iter_plot(dataset_name, size, epoch):
    if dataset_name == 'moon':
        dataset = MoonDataset(n_samples=1000, noise=0.01)
    elif dataset_name == 'blobs':
        dataset = BlobsDataset(n_samples=1000, noise=0.01)
    elif dataset_name == 'spiral':
        dataset = SpiralDataset(n_samples=1000, noise=0.01)
    elif dataset_name == 'circles':
        dataset = CirclesDataset(n_samples=1000, noise=0.01)



    if size == "skinny":
        model = MLP(2,7,2,2)
        model.load_state_dict(torch.load(f'{models_path}/2_wide/mlp_{dataset_name}/model_{epoch}.pth'))
    elif size == "overfit":
        model = MLP(2,7,2,1)
        model.load_state_dict(torch.load(f'{models_path}/overfit/mlp_{dataset_name}/model_{epoch}.pth'))
        final_layer = model.layers[-1]
        model.layers = model.layers[:-1]
    elif size == "vanilla":
        model = MLP(2,7,10,2)
        model.load_state_dict(torch.load(f'{models_path}/vanilla/mlp_{dataset_name}/model_{epoch}.pth'))

    os.makedirs(f'figures/{epoch}', exist_ok=True)
    os.makedirs(f'figures/{epoch}/{dataset_name}', exist_ok=True)
    os.makedirs(f'figures/{epoch}/{dataset_name}/{size}', exist_ok=True)

    local_g_diag_lattice = script.local_plot(model, dataset.X, dataset.y, epoch=epoch, save_path = f"{dataset_name}/{size}", N=50, plot_method='lattice_diagonal')
    local_g_lattice = script.local_plot(model, dataset.X, dataset.y, epoch=epoch, save_path = f"{dataset_name}/{size}", N=15, plot_method='lattice')
    local_g_surface = script.local_plot(model, dataset.X, dataset.y, epoch=epoch, save_path = f"{dataset_name}/{size}", N=50, plot_method='surface')
    
    pullback_g_diag_lattice, grid_diag_lattice = script.pullback_plot(model, dataset.X, dataset.y, epoch=epoch, save_path = f"{dataset_name}/{size}", N=50, plot_method='lattice_diagonal')
    pullback_g_lattice, grid_lattice = script.pullback_plot(model, dataset.X, dataset.y, epoch=epoch, save_path = f"{dataset_name}/{size}", N=15, plot_method='lattice')
    pullback_g_surface, grid_surface = script.pullback_plot(model, dataset.X, dataset.y, epoch=epoch, save_path = f"{dataset_name}/{size}", N=50, plot_method='surface')

    full_pullback_g_diag_lattice, full_grid_diag_lattice = script.full_pullback_plot(model, dataset.X, dataset.y, epoch=epoch, save_path = f"{dataset_name}/{size}", N=50, plot_method='lattice_diagonal')
    full_pullback_g_lattice, full_grid_lattice = script.full_pullback_plot(model, dataset.X, dataset.y, epoch=epoch, save_path = f"{dataset_name}/{size}", N=15, plot_method='lattice')
    full_pullback_g_surface, full_grid_surface = script.full_pullback_plot(model, dataset.X, dataset.y, epoch=epoch, save_path = f"{dataset_name}/{size}", N=50, plot_method='surface')

    cosine_score = compute_cosine_score(pullback_g_lattice, full_pullback_g_lattice)
    magnitude_score = compute_magnitude_score(pullback_g_lattice, full_pullback_g_lattice)
    script.plot_err_heatmap(cosine_score, magnitude_score, full_grid_lattice, model, dataset.X, dataset.y, N=15, save_name=f"{dataset_name}/{size}", epoch=epoch)
    script.violin_plot(cosine_score, magnitude_score, save_name=f"{dataset_name}/{size}", epoch=epoch)


In [4]:
# Initial values; the loop below rebinds these three notebook-level names on every iteration,
# so after the cell they hold the last configuration of the sweep.
dataset_name, size, epoch = 'moon', "skinny", 199

# Sweep over (dataset_name, size, epoch) triples, in this exact order:
# two 'skinny' checkpoints first, then six 'overfit' checkpoints.
iter_ = (
    [('moon', 'skinny', ckpt) for ckpt in (199, 5)]
    + [('moon', 'overfit', ckpt) for ckpt in (5, 199, 499, 999, 4999, 9999)]
)
for dataset_name, size, epoch in iter_:
    _iter_plot(dataset_name, size, epoch)


  direction_metric = direction_metric/np.linalg.norm(direction_metric, axis=1).reshape(-1, 1)
