In [1]:
import sys
import os
sys.path.append("../../../")

import torch
import numpy as np

from experiments.assumptions.metric_comparison import script
from models.supervised.mlp.model import MLP
from models.data.sklearn_datasets import MoonDataset, BlobsDataset, SpiralDataset, CirclesDataset
models_path = "../../../models/supervised/mlp/saved_models"

# We want to measure the error from using the pullback metric instead of computing the Riemannian
# metric for each layer.

# We will use the Moon dataset for this experiment and the cosine function to measure the 
# between the two given metrics are scale-invariant.

In [2]:
np.random.seed(2)
torch.manual_seed(2)

<torch._C.Generator at 0x7fc3e4c2c4b0>

In [3]:
mode = 'moon'
size = "skinny"
if mode == 'moon':
    dataset = MoonDataset(n_samples=1000, noise=0.01)
elif mode == 'blobs':
    dataset = BlobsDataset(n_samples=1000, noise=0.01)
elif mode == 'spiral':
    dataset = SpiralDataset(n_samples=1000, noise=0.01)
elif mode == 'circles':
    dataset = CirclesDataset(n_samples=1000, noise=0.01)

epoch = 199
os.makedirs(f'figures/{epoch}', exist_ok=True)

if size == "skinny":
    model = MLP(2,7,2,2)
    model.load_state_dict(torch.load(f'{models_path}/2_wide/mlp_{mode}/model_{epoch}.pth'))
else:
    model = MLP(2,7,10,2)
    model.load_state_dict(torch.load(f'{models_path}/vanilla/mlp_{mode}/model_{epoch}.pth'))

In [None]:
local_g_diag_lattice = script.local_plot(model, dataset.X, dataset.y, epoch=epoch, save_path = f"{mode}_{size}", N=50, plot_method='lattice_diagonal')
local_g_lattice = script.local_plot(model, dataset.X, dataset.y, epoch=epoch, save_path = f"{mode}_{size}", N=15, plot_method='lattice')
local_g_surface = script.local_plot(model, dataset.X, dataset.y, epoch=epoch, save_path = f"{mode}_{size}", N=50, plot_method='surface')

In [4]:
pullback_g_diag_lattice, grid_diag_lattice = script.pullback_plot(model, dataset.X, dataset.y, epoch=epoch, save_path = f"{mode}_{size}", N=50, plot_method='lattice_diagonal')
pullback_g_lattice, grid_lattice = script.pullback_plot(model, dataset.X, dataset.y, epoch=epoch, save_path = f"{mode}_{size}", N=15, plot_method='lattice')
pullback_g_surface, grid_surface = script.pullback_plot(model, dataset.X, dataset.y, epoch=epoch, save_path = f"{mode}_{size}", N=50, plot_method='surface')


In [5]:
full_pullback_g_diag_lattice, full_grid_diag_lattice = script.full_pullback_plot(model, dataset.X, dataset.y, epoch=epoch, save_path = f"{mode}_{size}", N=50, plot_method='lattice_diagonal')
full_pullback_g_lattice, full_grid_lattice = script.full_pullback_plot(model, dataset.X, dataset.y, epoch=epoch, save_path = f"{mode}_{size}", N=15, plot_method='lattice')
full_pullback_g_surface, full_grid_surface = script.full_pullback_plot(model, dataset.X, dataset.y, epoch=epoch, save_path = f"{mode}_{size}", N=50, plot_method='surface')

In [12]:
a = 


[[1.44488062e-03 3.73087241e-04]
 [1.69156027e+00 4.42578375e-01]
 [1.38948758e-06 3.77147529e-07]
 [1.18238734e-08 3.57146224e-09]
 [6.34412689e-10 2.65928501e-10]
 [4.99731541e-11 6.48606793e-11]
 [4.25113330e-12 5.69265468e-11]
 [4.46138626e-10 2.45555631e-10]
 [1.17556178e-08 4.33387548e-09]
 [9.97794487e-08 8.10315441e-08]
 [2.59593413e-09 1.64216715e-07]
 [2.60280473e-08 4.21341966e-08]
 [1.30450939e-08 6.87903423e-09]
 [3.96553368e-09 1.38538703e-09]
 [1.25171995e-09 3.67921749e-10]
 [1.19579738e-04 3.09273019e-05]
 [1.74694763e+03 4.58855896e+02]
 [1.71800020e-05 4.70843042e-06]
 [4.15588204e-08 1.28682451e-08]
 [1.48173784e-09 6.63939015e-10]
 [9.26607957e-11 1.48939208e-10]
 [1.77866229e-11 1.43553891e-10]
 [1.39386591e-09 7.99883382e-10]
 [4.75624375e-08 2.17008385e-08]
 [3.82141309e-07 5.37560481e-07]
 [1.51832680e-08 7.44975637e-07]
 [1.06498476e-07 1.11731687e-07]
 [2.92880120e-08 1.32858347e-08]
 [6.88643098e-09 2.25887042e-09]
 [1.90231764e-09 5.43806833e-10]
 [1.908942

In [23]:



cosine_score = compute_cosine_score(pullback_g_lattice, full_pullback_g_diag_lattice)
print(f"cosine score: {cosine_score}")

cosine score: [[1.0, 0.99999654, 0.9999394, 0.9992633, 0.9900533, 0.79018396, 0.33177155, 0.9722117, 0.9966247, 0.9200998, 0.29901457, 0.7515501, 0.9863419, 0.9998923, 0.9938608, 0.9768098, 0.9462595, 0.87168366, 0.717301, 0.56801534, 0.8549899, 0.9986888, 0.7799763, 0.9189171, 0.9332422, 0.44871804, 0.90680104, 0.9970415, 0.9993561, 0.9941957, 0.97920483, 0.94639343, 0.86221987, 0.7007858, 0.5868296, 0.909925, 0.9873849, 0.66442966, 0.8169502, 0.9414682, 0.75553304, 0.99662, 0.9949184, 0.9927614, 0.99429446, 0.9955618, 0.99781865, 0.9994116, 0.99971807, 0.9791169, 0.55065215, 0.4837358, 0.94797796, 0.9151895, 0.35604185, 0.5747716, 0.9588165, 0.9961387, 0.9997846, 0.9999496, 0.999433, 0.9991207, 0.9989591, 0.99972844, 0.990824, 0.64420027, 0.79841876, 0.98353195, 0.9693547, 0.98832434, 0.84181476, 0.532723, 0.6625104, 0.8574014, 0.94951844, 0.98068184, 0.9922737, 0.9968039, 0.9997172, 0.982366, 0.5828419, 0.8630281, 0.9783582, 0.9917369, 0.99562794, 0.7252189, 0.45832032, 0.5277117, 0

In [30]:
import matplotlib.pyplot as plt
def plot_err_heatmap(cosine_scores, xy_grids, model, X, y, N=50, save_name=None, epoch=0):
    fig, ax = plt.subplots(1, len(cosine_scores), figsize=(15*len(cosine_scores), 15))
    X = torch.from_numpy(X).float()
    model.forward(X, save_activations=True)
    for indx, scores in enumerate(cosine_scores):
        x_min, x_max = xy_grids[indx][:, 0].min(), xy_grids[indx][:, 0].max()
        y_min, y_max = xy_grids[indx][:, 1].min(), xy_grids[indx][:, 1].max()
        xx, yy = np.meshgrid(np.linspace(x_min, x_max, N), np.linspace(y_min, y_max, N))
        output = np.array(scores).reshape(xx.shape)
        contour = ax[indx].contourf(xx, yy, output, alpha=0.4, cmap="RdBu_r")
        ax[indx].scatter(model.activations[indx].detach().numpy()[:, 0], model.activations[indx].detach().numpy()[:, 1], c=y, s=20, edgecolor='k')
        ax[indx].set_xlim(xx.min(), xx.max())
        ax[indx].set_ylim(yy.min(), yy.max())

        fig.canvas.draw()
        plt.colorbar(contour)
    if save_name is not None:
        fig.savefig(f"figures/{epoch}/err_heatmap_{save_name}.png")
    else:
        fig.savefig(f"figures/{epoch}/err_heatmap.png")
    plt.close()

plot_err_heatmap(cosine_score, full_grid_lattice, model, dataset.X, dataset.y, N=15, save_name=f"{mode}_{size}", epoch=epoch)

