In [17]:
import sys
import os
sys.path.append("../../../")

import torch
import numpy as np

from experiments.assumptions.metric_comparison import script
from models.supervised.mlp.model import MLP
from models.data.sklearn_datasets import MoonDataset, BlobsDataset, SpiralDataset, CirclesDataset
from utils.metrics.metrics import compute_cosine_score, compute_magnitude_score



# We want to measure the error from using the pullback metric instead of computing the Riemannian
# metric for each layer.

# We will use the Moon dataset for this experiment and cosine similarity to measure the
# agreement between the two metrics, since cosine similarity is scale-invariant.

In [18]:
SEED = 2  # one seed shared by numpy's legacy global RNG and torch
np.random.seed(SEED)
torch.manual_seed(SEED)

<torch._C.Generator at 0x7f8aa6b40510>

In [19]:
import numpy as np
from riemannian_geometry.computations.riemann_metric import LocalDiagPCA
from riemannian_geometry.computations.pullback_metric import compute_jacobian_multi_layer
import matplotlib.pyplot as plt
import torch
from matplotlib.patches import Ellipse
from utils.plotting.mesh import generate_lattice




def plot_surface(ax, activation, labels, xy_grid, metric_layer, layer, N=50):
    """Draw a metric field on the network's learnt surface as two arrow panels.

    Each grid point gets two arrows, one per column of its (first 2x2 block of
    the) metric tensor:

    * row 0 of ``ax``: log-compressed absolute entries, ``log(1 + |g|)``;
    * row 1 of ``ax``: the columns normalised to unit length. Columns shorter
      than 1e-5 fall back to the matching coordinate axis.

    Arrow components are rescaled by the grid extent along each axis divided by
    sqrt(number of points).

    Args:
        ax: 2-D grid of matplotlib axes; column ``layer`` of rows 0 and 1 is used.
        activation: data points to scatter; columns 0 and 1 are plotted.
        labels: colours for the scattered points.
        xy_grid: positions at which the metric was evaluated, one row per point.
        metric_layer: (n_points, K, K) metric tensors.
        layer: column index into ``ax``; also used in the titles.
        N: unused; kept for signature parity with the other plotters.
    """
    n_points, K, _ = metric_layer.shape
    if K != 2:
        # Only the first two coordinates can be drawn in the plane.
        metric_layer = metric_layer[:, :2, :2]
        xy_grid = xy_grid[:, :2]

    mag_ax, dir_ax = ax[0][layer], ax[1][layer]
    panel_titles = (
        (mag_ax, f'Magnitude Riemann Metric - Surface - Layer {layer}'),
        (dir_ax, f'Direction Riemann Metric - Surface - Layer {layer}'),
    )
    for panel, title in panel_titles:
        panel.scatter(activation[:, 0], activation[:, 1], c=labels, edgecolors='k')
        panel.set_title(title)

    x_step = (np.max(xy_grid[:, 0]) - np.min(xy_grid[:, 0])) / np.sqrt(n_points)
    y_step = (np.max(xy_grid[:, 1]) - np.min(xy_grid[:, 1])) / np.sqrt(n_points)
    xs, ys = zip(*xy_grid)
    first_col = metric_layer[:, 0, :]
    second_col = metric_layer[:, 1, :]

    def _rescale(vectors):
        # In-place per-axis scaling so arrows are sized relative to the grid.
        vectors[:, 0] *= x_step
        vectors[:, 1] *= y_step
        return vectors

    quiver_kw = dict(angles='xy', scale_units='xy', scale=1, color='r')

    # Magnitude panel: log(1 + |entry|) compresses the dynamic range.
    magnitudes = [_rescale(np.log(1 + np.abs(column))) for column in (first_col, second_col)]
    for vectors in magnitudes:
        mag_ax.quiver(xs, ys, vectors[:, 0], vectors[:, 1], **quiver_kw)

    # Direction panel: unit-length columns; near-zero columns get an axis fallback.
    directions = []
    for column, fallback in ((first_col, np.array([1, 0])), (second_col, np.array([0, 1]))):
        lengths = np.linalg.norm(column, axis=1)
        unit = column / (lengths.reshape(-1, 1) + 1e-5)
        unit[lengths < 1e-5] = fallback
        directions.append(_rescale(unit))
    for vectors in directions:
        dir_ax.quiver(xs, ys, vectors[:, 0], vectors[:, 1], **quiver_kw)
    

def plot_lattice_diagonal(ax, activation, labels, xy_grid, metric_layer, layer, N=50):
    """Plot how diagonal a metric field is, and its diagonal entries as arrows.

    * Row 2 of ``ax``: filled contour of the per-point off-diagonal mass,
      ``sum_{i != j} |g_ij|``.
    * Row 1: the first two diagonal entries ``(g_00, g_11)`` normalised to unit
      length, drawn as axis-aligned arrows. Points where both entries are zero
      get a zero arrow instead of NaN.
    * Row 0: the same entries, each divided by its largest value over the grid.

    Arrows are scaled by the grid extent along each axis divided by ``N``.

    Args:
        ax: 2-D grid of matplotlib axes with at least 3 rows; column ``layer`` is used.
        activation: data points to scatter; columns 0 and 1 are plotted.
        labels: colours for the scattered points.
        xy_grid: (N*N, >=2) lattice the metric was evaluated on.
        metric_layer: (N*N, K, K) metric tensors, one per lattice point.
        layer: column index into ``ax``.
        N: lattice side length; the lattice is reshaped to (N, N) for the contour.
    """
    n_points, K, _ = metric_layer.shape
    zeros = np.zeros(n_points)

    # Off-diagonal mass: zero the diagonal of a copy and sum what is left.
    diag_idx = np.arange(K)
    off_diagonal = np.array(metric_layer, dtype=float)
    off_diagonal[:, diag_idx, diag_idx] = 0
    diag_score = np.abs(off_diagonal).sum(axis=2).sum(axis=1)

    xx = xy_grid[:, 0].reshape(N, N)
    yy = xy_grid[:, 1].reshape(N, N)
    Z = diag_score.reshape(N, N)
    score_ax = ax[2][layer]
    score_ax.scatter(activation[:, 0], activation[:, 1], c=labels, edgecolors='k')
    contour = score_ax.contourf(xx, yy, Z, alpha=0.4, cmap="RdBu_r")
    # Attach the colorbar to this panel explicitly instead of pyplot's current figure/axes.
    score_ax.figure.colorbar(contour, ax=score_ax)

    if K != 2:
        metric_layer = metric_layer[:, :2, :2]

    diag_metric = np.diagonal(metric_layer, axis1=1, axis2=2).copy()
    max_x, max_y = np.max(xy_grid[:, 0]), np.max(xy_grid[:, 1])
    min_x, min_y = np.min(xy_grid[:, 0]), np.min(xy_grid[:, 1])
    extent = np.array([max_x - min_x, max_y - min_y])
    x, y = xy_grid[:, 0], xy_grid[:, 1]

    # Direction: unit-normalise (g_00, g_11). Rows with zero norm stay zero
    # rather than producing 0/0 = NaN (and a RuntimeWarning).
    norms = np.linalg.norm(diag_metric, axis=1, keepdims=True)
    direction_metric = np.divide(diag_metric, norms, out=np.zeros(diag_metric.shape), where=norms > 0)
    direction_metric = (direction_metric * extent) / N

    direction_ax = ax[1][layer]
    direction_ax.scatter(activation[:, 0], activation[:, 1], c=labels, edgecolors='k')
    direction_ax.quiver(x, y, direction_metric[:, 0], zeros, angles='xy', scale_units='xy', scale=1, color='r')
    direction_ax.quiver(x, y, zeros, direction_metric[:, 1], angles='xy', scale_units='xy', scale=1, color='r')
    direction_ax.set_title(f'Vector Direction - Layer {layer+1}')

    # Magnitude: divide each diagonal entry by its own maximum over the grid.
    # The previous code used the max of the whole metric row, so an off-diagonal
    # entry could set the scale. Zero maxima leave the column at zero.
    max_diag = diag_metric.max(axis=0)
    magnitude_metric = np.divide(diag_metric, max_diag, out=np.zeros(diag_metric.shape), where=max_diag != 0)
    magnitude_metric = (magnitude_metric * extent) / N

    magnitude_ax = ax[0][layer]
    magnitude_ax.scatter(activation[:, 0], activation[:, 1], c=labels, edgecolors='k')
    magnitude_ax.quiver(x, y, magnitude_metric[:, 0], zeros, angles='xy', scale_units='xy', scale=1, color='r')
    magnitude_ax.quiver(x, y, zeros, magnitude_metric[:, 1], angles='xy', scale_units='xy', scale=1, color='r')
    magnitude_ax.set_title(f'Vector Magnitude - Layer {layer+1}')

def plot_lattice(ax, activation, labels, xy_grid, metric_layer, layer, N=15):
    """Draw a metric field as ellipses, plus a contour of its negative-eigenvalue error.

    The first 2x2 block of each metric is eigen-decomposed with
    ``torch.linalg.eigh``, which returns eigenvalues in ascending order and
    eigenvectors as columns.

    * Row 0 of ``ax``: one ellipse per lattice point. Its axes are the square
      roots of the (negative-clipped) eigenvalues, each normalised by its
      maximum over the grid and scaled to grid extent / ``N``. It is rotated to
      the direction of the first eigenvector.
    * Row 1: contour of ``log(1 + |lambda_min|)`` where the smallest eigenvalue
      is negative, and 0 elsewhere.

    Args:
        ax: 2-D grid of matplotlib axes; column ``layer`` of rows 0 and 1 is used.
        activation: data points to scatter; columns 0 and 1 are plotted.
        labels: colours for the scattered points.
        xy_grid: (N*N, >=2) lattice the metric was evaluated on.
        metric_layer: (N*N, K, K) metric tensors.
        layer: column index into ``ax``.
        N: lattice side length.

    Returns:
        (N*N, 2) array of the rescaled ellipse axis lengths.
    """
    _, K, _ = metric_layer.shape
    if K != 2:
        metric_layer = metric_layer[:, :2, :2]

    metric_layer_tensor = torch.from_numpy(metric_layer).float()
    eigenvalues, eigenvectors = torch.linalg.eigh(metric_layer_tensor)
    eigenvalues = eigenvalues.detach().numpy()
    eigenvectors = eigenvectors.detach().numpy()

    # Column 0 is the smallest eigenvalue; a negative value means the metric is
    # not positive semi-definite at that point.
    errors = np.log(1 - eigenvalues[:, 0] * (eigenvalues[:, 0] < 0))
    eigenvalues = np.sqrt(eigenvalues * (eigenvalues > 0))

    ellipse_ax = ax[0][layer]
    ellipse_ax.scatter(activation[:, 0], activation[:, 1], c=labels, edgecolors='k')
    ellipse_ax.set_title(f'Ellipse Representing Metric - Layer {layer}')

    # Scale ellipse shapes to the plotted region; all-zero columns stay zero.
    max_x, max_y = np.max(xy_grid[:, 0]), np.max(xy_grid[:, 1])
    min_x, min_y = np.min(xy_grid[:, 0]), np.min(xy_grid[:, 1])
    col_max = eigenvalues.max(axis=0)
    eigenvalues = np.divide(eigenvalues, col_max, out=np.zeros_like(eigenvalues), where=col_max > 0)
    eigenvalues = (eigenvalues * np.array([max_x - min_x, max_y - min_y])) / N

    # Eigenvectors are the COLUMNS of each matrix: the first one is
    # (v[0, 0], v[1, 0]). Its direction gives the orientation of the ellipse
    # width (the first eigenvalue).
    angles = np.degrees(np.arctan2(eigenvectors[:, 1, 0], eigenvectors[:, 0, 0]))

    for point, (width, height), angle in zip(xy_grid, eigenvalues, angles):
        ellipse = Ellipse(xy=point[:2], width=width, height=height,
                          angle=angle, edgecolor='r', facecolor='none')
        ellipse_ax.add_patch(ellipse)

    xx = xy_grid[:, 0].reshape(N, N)
    yy = xy_grid[:, 1].reshape(N, N)
    Z = errors.reshape(N, N)
    error_ax = ax[1][layer]
    error_ax.scatter(activation[:, 0], activation[:, 1], c=labels, edgecolors='k')
    error_ax.set_title(f'Negative Eigenvalue Log Error - Layer {layer}')
    contour = error_ax.contourf(xx, yy, Z, alpha=0.4, cmap="RdBu_r")
    # Attach the colorbar to this panel explicitly instead of pyplot's current figure/axes.
    error_ax.figure.colorbar(contour, ax=error_ax)

    return eigenvalues




def pullback_plot(model, X, labels, save_path, epoch=0, N=15, plot_method='lattice'):
    """Pull the output-space metric back one layer at a time and plot every layer.

    The metric on the last activation space is estimated with ``LocalDiagPCA``.
    For each earlier layer ``i`` it is pulled back through that single layer's
    Jacobian, ``g_i = J_i^T g_{i+1} J_i``. Every layer is drawn with the plotter
    chosen by ``plot_method``:

    * ``'lattice'``: ``plot_lattice``;
    * ``'lattice_diagonal'``: ``plot_lattice_diagonal``;
    * ``'surface'``: ``plot_surface``, on the image of an input-space lattice.

    The figure is written to ``figures/{epoch}[/{save_path}]/pullback_{plot_method}.png``.

    Returns:
        (g, store_plot_grids): per-layer metric arrays and the grid each one
        was evaluated on.

    Raises:
        ValueError: ``plot_method`` is not one of the three above.
    """
    plotters = {
        'lattice': plot_lattice,
        'lattice_diagonal': plot_lattice_diagonal,
        'surface': plot_surface,
    }
    if plot_method not in plotters:
        raise ValueError(f'Plot method {plot_method} not recognised. Please use either lattice or surface.')
    plotter = plotters[plot_method]

    X_tensor = torch.from_numpy(X).float()
    model.forward(X_tensor, save_activations=True)
    activations_np = [activation.detach().numpy() for activation in model.get_activations()]
    N_layers = len(activations_np)
    g = [0 for _ in activations_np]

    manifold = LocalDiagPCA(activations_np[-1], sigma=0.05, rho=1e-3)

    if plot_method == 'surface':
        # Push a lattice over the input space through the network. surface[i]
        # holds the saved activations of layer i for those lattice points.
        input_lattice = generate_lattice(activations_np[0], N)
        model.forward(torch.from_numpy(input_lattice).float(), save_activations=True)
        surface = model.get_activations()
        surface_np = [activation.detach().numpy() for activation in surface]
        # Evaluate the output metric where it is plotted and pulled back (the
        # images of the lattice), not on a separate output-space lattice.
        xy_grid = surface_np[-1]
    else:
        xy_grid = generate_lattice(activations_np[-1], N)
    g[-1] = manifold.metric_tensor(xy_grid.transpose())

    n_rows = 3 if plot_method == 'lattice_diagonal' else 2
    # squeeze=False keeps ax 2-D even for a single layer.
    fig, ax = plt.subplots(n_rows, N_layers, figsize=(N_layers * 16, 8 * n_rows), squeeze=False)
    plotter(ax, activations_np[-1], labels, xy_grid, g[-1], -1, N=N)

    store_plot_grids = list(activations_np)
    store_plot_grids[-1] = xy_grid
    for indx in reversed(range(N_layers - 1)):
        if plot_method == 'surface':
            xy_grid_tensor = surface[indx]
            xy_grid = surface_np[indx]
        else:
            # NOTE(review): in the lattice modes this lattice is independent of
            # the one g[indx + 1] was evaluated on, so J and g_{i+1} are not taken
            # at corresponding points. Only 'surface' mode (presumably) chains
            # point-for-point. Confirm this is acceptable for the comparison.
            xy_grid = generate_lattice(activations_np[indx], N)
            xy_grid_tensor = torch.from_numpy(xy_grid).float()

        layer = model.layers[indx]
        # Jacobian of this single layer, so its output size is this layer's
        # width, not the network's final width (they only coincide when all
        # layers have equal width).
        # NOTE(review): only model.layers[indx] is differentiated. If the
        # nonlinearity is applied outside that module it is not included; confirm.
        jacobian = compute_jacobian_multi_layer(layer, xy_grid_tensor, layer.in_features, layer.out_features)
        jacobian = jacobian.detach().numpy()
        # Batched pullback over the point axis l: g_i = J^T g_{i+1} J.
        g[indx] = np.einsum('lai,lbj,lab->lij', jacobian, jacobian, g[indx + 1])

        plotter(ax, activations_np[indx], labels, xy_grid, g[indx], indx, N=N)
        store_plot_grids[indx] = xy_grid

    out_dir = f"figures/{epoch}" if save_path is None else f"figures/{epoch}/{save_path}"
    fig.savefig(f"{out_dir}/pullback_{plot_method}.png")
    plt.close(fig)
    return g, store_plot_grids

def full_pullback_plot(model, X, labels, save_path, epoch=0, N=15, plot_method='lattice'):
    """Pull the output-space metric back to every layer in one step and plot it.

    The metric on the last activation space is estimated with ``LocalDiagPCA``.
    For each earlier layer ``i`` it is pulled back through the Jacobian of
    ``model.forward_layers(x, i)``, ``g_i = J^T g_out J``, rather than chained
    layer by layer as in ``pullback_plot``. ``forward_layers(x, i)`` presumably
    runs layers ``i..end``; confirm. Plot styles and output naming follow
    ``pullback_plot``. The figure is written to
    ``figures/{epoch}[/{save_path}]/full_pullback_{plot_method}.png``.

    Returns:
        (g, store_plot_grids): per-layer metric arrays and the grid each one
        was evaluated on.

    Raises:
        ValueError: ``plot_method`` is not 'lattice', 'lattice_diagonal' or 'surface'.
    """
    plotters = {
        'lattice': plot_lattice,
        'lattice_diagonal': plot_lattice_diagonal,
        'surface': plot_surface,
    }
    if plot_method not in plotters:
        raise ValueError(f'Plot method {plot_method} not recognised. Please use either lattice or surface.')
    plotter = plotters[plot_method]

    X_tensor = torch.from_numpy(X).float()
    model.forward(X_tensor, save_activations=True)
    activations_np = [activation.detach().numpy() for activation in model.get_activations()]
    N_layers = len(activations_np)
    g = [0 for _ in activations_np]

    manifold = LocalDiagPCA(activations_np[-1], sigma=0.05, rho=1e-3)

    if plot_method == 'surface':
        # Push a lattice over the input space through the network. surface[i]
        # holds the saved activations of layer i for those lattice points.
        input_lattice = generate_lattice(activations_np[0], N)
        model.forward(torch.from_numpy(input_lattice).float(), save_activations=True)
        surface = model.get_activations()
        surface_np = [activation.detach().numpy() for activation in surface]
        # Evaluate the output metric at the images of the lattice, where it is
        # plotted and pulled back, not on a separate output-space lattice.
        xy_grid = surface_np[-1]
    else:
        xy_grid = generate_lattice(activations_np[-1], N)
    g[-1] = manifold.metric_tensor(xy_grid.transpose())

    n_rows = 3 if plot_method == 'lattice_diagonal' else 2
    # squeeze=False keeps ax 2-D even for a single layer.
    fig, ax = plt.subplots(n_rows, N_layers, figsize=(N_layers * 16, 8 * n_rows), squeeze=False)
    plotter(ax, activations_np[-1], labels, xy_grid, g[-1], -1, N=N)

    store_plot_grids = list(activations_np)
    store_plot_grids[-1] = xy_grid
    final_layer_metric = g[-1]
    dim_out = model.layers[-1].out_features
    for indx in reversed(range(N_layers - 1)):
        def forward_layers(x):
            # Used only within this iteration, so the late-bound indx is current.
            return model.forward_layers(x, indx)

        if plot_method == 'surface':
            xy_grid_tensor = surface[indx]
            xy_grid = surface_np[indx]
        else:
            # NOTE(review): in the lattice modes the Jacobian is taken on a
            # lattice over layer indx, while final_layer_metric was evaluated on
            # a separate output-space lattice, so points do not correspond.
            xy_grid = generate_lattice(activations_np[indx], N)
            xy_grid_tensor = torch.from_numpy(xy_grid).float()

        dim_in = model.layers[indx].in_features
        jacobian = compute_jacobian_multi_layer(forward_layers, xy_grid_tensor, dim_in, dim_out)
        jacobian = jacobian.detach().numpy()
        # Batched pullback over the point axis l: g_i = J^T g_out J.
        g[indx] = np.einsum('lai,lbj,lab->lij', jacobian, jacobian, final_layer_metric)

        plotter(ax, activations_np[indx], labels, xy_grid, g[indx], indx, N=N)
        store_plot_grids[indx] = xy_grid

    out_dir = f"figures/{epoch}" if save_path is None else f"figures/{epoch}/{save_path}"
    fig.savefig(f"{out_dir}/full_pullback_{plot_method}.png")
    plt.close(fig)
    return g, store_plot_grids

def local_plot(model, X, labels, save_path, epoch=0, N=15, plot_method='lattice'):
    """Estimate and plot each layer's Riemannian metric independently.

    Every activation space, the input included, gets its own ``LocalDiagPCA``
    estimate. That estimate is evaluated either on a lattice over the layer's
    activations, or (for ``plot_method='surface'``) on the layer's activations
    of an input-space lattice. Each layer is drawn in its own column. The
    figure is saved to ``figures/{epoch}[/{save_path}]/local_{plot_method}.png``.

    Returns:
        (g, store_plot_grids): per-layer metric arrays and the grid each one
        was evaluated on.

    Raises:
        ValueError: ``plot_method`` is not 'lattice', 'lattice_diagonal' or 'surface'.
    """
    model.forward(torch.from_numpy(X).float(), save_activations=True)
    layer_points = [act.detach().numpy() for act in model.get_activations()]

    input_manifold = LocalDiagPCA(layer_points[0], sigma=0.05, rho=1e-3)
    n_layers = len(layer_points)
    metrics = [0] * n_layers

    grid = generate_lattice(layer_points[0], N)
    metrics[0] = input_manifold.metric_tensor(grid.transpose())

    # plot_method -> (subplot rows, figure height, per-layer plotter)
    layouts = {
        'lattice': (2, 8 * 2, plot_lattice),
        'lattice_diagonal': (3, 24, plot_lattice_diagonal),
        'surface': (2, 16, plot_surface),
    }
    if plot_method not in layouts:
        raise ValueError(f'Plot method {plot_method} not recognised. Please use either lattice or surface.')
    n_rows, fig_height, plotter = layouts[plot_method]

    if plot_method == 'surface':
        # Run an input-space lattice through the network to get its per-layer images.
        grid = generate_lattice(layer_points[0], N)
        model.forward(torch.from_numpy(grid).float(), save_activations=True)
        surface_points = [act.detach().numpy() for act in model.get_activations()]

    fig, ax = plt.subplots(n_rows, n_layers, figsize=(n_layers * 16, fig_height))
    plotter(ax, layer_points[0], labels, grid, metrics[0], 0, N=N)

    grids = list(layer_points)
    grids[0] = grid
    for indx in range(n_layers - 1, 0, -1):
        layer_manifold = LocalDiagPCA(layer_points[indx], sigma=0.05, rho=1e-3)
        if plot_method == 'surface':
            grid = surface_points[indx]
        else:
            grid = generate_lattice(layer_points[indx], N)
        metrics[indx] = layer_manifold.metric_tensor(grid.transpose())
        plotter(ax, layer_points[indx], labels, grid, metrics[indx], indx, N=N)
        grids[indx] = grid

    target_dir = f"figures/{epoch}" if save_path is None else f"figures/{epoch}/{save_path}"
    fig.savefig(f"{target_dir}/local_{plot_method}.png")
    plt.close()
    return metrics, grids
    





    


def violin_plot(cosine_scores, magnitude_scores, save_name=None, epoch=0):
    """Save side-by-side violin plots of per-layer cosine and magnitude scores.

    Args:
        cosine_scores: sequence of per-layer score collections, passed to ``Axes.violinplot``.
        magnitude_scores: same layout, for the magnitude differences.
        save_name: optional sub-directory under ``figures/{epoch}``.
        epoch: epoch number used in the output path.

    Output: ``figures/{epoch}[/{save_name}]/violin_plot.png``.
    """
    fig, ax = plt.subplots(1, 2, figsize=(16*2, 16))
    panels = (
        (ax[0], cosine_scores, 'Cosine Similarity',
         'Distribution of Cosine Similarity of Riemannian Metrics by Layer'),
        (ax[1], magnitude_scores, 'Magnitude Difference',
         'Distribution of Magnitude Difference of Riemannian Metrics by Layer'),
    )
    for panel, scores, ylabel, title in panels:
        panel.violinplot(scores, showmeans=True, showmedians=True)
        panel.set_xlabel('Layer Index')
        panel.set_ylabel(ylabel)
        panel.set_title(title)
        # plt.grid only affects pyplot's current axes, which left one panel
        # without a grid. Grid each panel explicitly.
        panel.grid(axis='y')

    fig.tight_layout()
    if save_name is not None:
        fig.savefig(f"figures/{epoch}/{save_name}/violin_plot.png")
    else:
        fig.savefig(f"figures/{epoch}/violin_plot.png")
    plt.close(fig)
    

def plot_err_heatmap(cosine_scores, magnitude_scores, xy_grids, model, X, y, N=50, save_name=None, epoch=0):
    """Save per-layer contour maps of the cosine and magnitude scores.

    Row 0 shows ``cosine_scores[i]`` and row 1 shows ``magnitude_scores[i]``.
    Both are drawn over ``xy_grids[i]``, with the model's layer-``i``
    activations on ``X`` scattered on top.

    Args:
        cosine_scores: per-layer sequences of N*N scores, ordered like ``xy_grids[i]``.
        magnitude_scores: same layout as ``cosine_scores``.
        xy_grids: per-layer (N*N, >=2) grids the scores were computed on.
        model: model exposing ``forward(x, save_activations=True)`` and ``activations``.
        X: input array fed to the model for the scatter overlay.
        y: colours for the scattered points.
        N: grid side length.
        save_name: optional sub-directory under ``figures/{epoch}``.
        epoch: epoch number used in the output path.

    Output: ``figures/{epoch}[/{save_name}]/err_heatmap.png``.
    """
    n_layers = len(cosine_scores)
    # squeeze=False keeps ax 2-D even when there is a single layer.
    fig, ax = plt.subplots(2, n_layers, figsize=(16*n_layers, 16), squeeze=False)
    X_tensor = torch.from_numpy(X).float()
    model.forward(X_tensor, save_activations=True)

    panels = (
        (0, cosine_scores, 'Cosine Difference between Two Pullback Metrics'),
        (1, magnitude_scores, 'Magnitude Difference between Two Pullback Metrics'),
    )
    for indx in range(n_layers):
        grid = np.asarray(xy_grids[indx])
        # Take contour coordinates straight from the grid the scores were
        # computed on, so each score stays paired with its own point whatever
        # order the grid was generated in. A rebuilt meshgrid only matches a
        # meshgrid-ordered lattice.
        xx = grid[:, 0].reshape(N, N)
        yy = grid[:, 1].reshape(N, N)
        points = model.activations[indx].detach().numpy()
        for row, scores, title in panels:
            panel = ax[row][indx]
            output = np.array(scores[indx]).reshape(N, N)
            contour = panel.contourf(xx, yy, output, alpha=0.4, cmap="RdBu_r")
            panel.scatter(points[:, 0], points[:, 1], c=y, s=20, edgecolor='k')
            panel.set_xlim(xx.min(), xx.max())
            panel.set_ylim(yy.min(), yy.max())
            panel.set_title(f'{title} - Layer {indx+1}')
            fig.colorbar(contour, ax=panel)

    if save_name is not None:
        fig.savefig(f"figures/{epoch}/{save_name}/err_heatmap.png")
    else:
        fig.savefig(f"figures/{epoch}/err_heatmap.png")
    plt.close(fig)






In [20]:
dataset_name = 'moon'
size = "skinny"
model_name = "mlp"
models_path = f"../../../models/supervised/{model_name}/saved_models"

# Dataset constructors by name; all use the same sample count and noise level.
dataset_classes = {
    'moon': MoonDataset,
    'blobs': BlobsDataset,
    'spiral': SpiralDataset,
    'circles': CirclesDataset,
}
if dataset_name not in dataset_classes:
    raise ValueError(f"Unknown dataset_name {dataset_name!r}; expected one of {sorted(dataset_classes)}")
dataset = dataset_classes[dataset_name](n_samples=1000, noise=0.01)

epoch = 199

# size -> (MLP constructor arguments, checkpoint sub-directory under models_path)
model_configs = {
    "skinny": ((2, 7, 2, 2), "2_wide"),
    "overfit": ((2, 7, 2, 1), "overfit"),
    "vanilla": ((2, 7, 10, 2), "vanilla"),
}
if size not in model_configs:
    raise ValueError(f"Unknown size {size!r}; expected one of {sorted(model_configs)}")
mlp_args, checkpoint_dir = model_configs[size]
model = MLP(*mlp_args)
model.load_state_dict(torch.load(f'{models_path}/{checkpoint_dir}/{dataset_name}/model_{epoch}.pth'))
if size == "overfit":
    # Drop the last layer (kept in final_layer) so the analysis stops at the
    # activation space before it.
    final_layer = model.layers[-1]
    model.layers = model.layers[:-1]

# makedirs creates the intermediate figures/{epoch} and figures/{epoch}/{dataset_name} too.
os.makedirs(f'figures/{epoch}/{dataset_name}/{size}', exist_ok=True)

In [8]:
# Per-layer LocalDiagPCA metrics in each plot style (N = lattice side length).
plot_kwargs = dict(epoch=epoch, save_path=f"{dataset_name}/{size}")
local_g_diag_lattice = script.local_plot(model, dataset.X, dataset.y, N=50, plot_method='lattice_diagonal', **plot_kwargs)
local_g_lattice = script.local_plot(model, dataset.X, dataset.y, N=15, plot_method='lattice', **plot_kwargs)
local_g_surface = script.local_plot(model, dataset.X, dataset.y, N=50, plot_method='surface', **plot_kwargs)

In [22]:
# Layer-by-layer pullback metrics in each plot style (N = lattice side length).
plot_kwargs = dict(epoch=epoch, save_path=f"{dataset_name}/{size}")
pullback_g_diag_lattice, grid_diag_lattice = pullback_plot(model, dataset.X, dataset.y, N=50, plot_method='lattice_diagonal', **plot_kwargs)
pullback_g_lattice, grid_lattice = pullback_plot(model, dataset.X, dataset.y, N=15, plot_method='lattice', **plot_kwargs)
pullback_g_surface, grid_surface = pullback_plot(model, dataset.X, dataset.y, N=50, plot_method='surface', **plot_kwargs)


8
8
8


In [None]:
# One-step (full) pullback metrics in each plot style (N = lattice side length).
plot_kwargs = dict(epoch=epoch, save_path=f"{dataset_name}/{size}")
full_pullback_g_diag_lattice, full_grid_diag_lattice = script.full_pullback_plot(model, dataset.X, dataset.y, N=50, plot_method='lattice_diagonal', **plot_kwargs)
full_pullback_g_lattice, full_grid_lattice = script.full_pullback_plot(model, dataset.X, dataset.y, N=15, plot_method='lattice', **plot_kwargs)
full_pullback_g_surface, full_grid_surface = script.full_pullback_plot(model, dataset.X, dataset.y, N=50, plot_method='surface', **plot_kwargs)

In [None]:
# Compare layer-by-layer vs one-step pullback on the N=15 lattice.
heatmap_kwargs = dict(save_name=f"{dataset_name}/{size}", epoch=epoch)
cosine_score = compute_cosine_score(pullback_g_lattice, full_pullback_g_lattice)
magnitude_score = compute_magnitude_score(pullback_g_lattice, full_pullback_g_lattice)
script.plot_err_heatmap(cosine_score, magnitude_score, full_grid_lattice, model, dataset.X, dataset.y, N=15, **heatmap_kwargs)
script.violin_plot(cosine_score, magnitude_score, **heatmap_kwargs)

In [15]:
def _iter_plot(dataset_name, size, epoch):
    """Regenerate every metric-comparison figure for one (dataset, model size, epoch) run.

    Loads the dataset and the matching checkpoint from the module-level
    ``models_path``, then runs the following via ``script``:

    * the local, layer-wise pullback and full pullback plots, in all three
      plot styles;
    * the pullback-vs-full-pullback error heatmap and violin plot.

    Raises:
        ValueError: unknown ``dataset_name`` or ``size``. Previously this
        surfaced later as an UnboundLocalError.
    """
    dataset_classes = {
        'moon': MoonDataset,
        'blobs': BlobsDataset,
        'spiral': SpiralDataset,
        'circles': CirclesDataset,
    }
    if dataset_name not in dataset_classes:
        raise ValueError(f"Unknown dataset_name {dataset_name!r}; expected one of {sorted(dataset_classes)}")
    dataset = dataset_classes[dataset_name](n_samples=1000, noise=0.01)

    # size -> (MLP constructor arguments, checkpoint sub-directory under models_path)
    model_configs = {
        "skinny": ((2, 7, 2, 2), "2_wide"),
        "overfit": ((2, 7, 2, 1), "overfit"),
        "vanilla": ((2, 7, 10, 2), "vanilla"),
    }
    if size not in model_configs:
        raise ValueError(f"Unknown size {size!r}; expected one of {sorted(model_configs)}")
    mlp_args, checkpoint_dir = model_configs[size]
    model = MLP(*mlp_args)
    model.load_state_dict(torch.load(f'{models_path}/{checkpoint_dir}/{dataset_name}/model_{epoch}.pth'))
    if size == "overfit":
        # Drop the last layer so the analysis stops at the activation space before it.
        model.layers = model.layers[:-1]

    # NOTE(review): this creates figures/mlp/..., whereas the plotting functions
    # defined in this notebook save under figures/{epoch}/... Confirm which root
    # the `script` module writes to.
    os.makedirs(f'figures/mlp/{epoch}/{dataset_name}/{size}', exist_ok=True)

    save_path = f"{dataset_name}/{size}"
    # (plot_method, lattice side length N) for each plot style, in call order.
    plot_styles = (('lattice_diagonal', 50), ('lattice', 15), ('surface', 50))

    for method, n in plot_styles:
        script.local_plot(model, dataset.X, dataset.y, epoch=epoch, save_path=save_path, N=n, plot_method=method)
    pullback = {
        method: script.pullback_plot(model, dataset.X, dataset.y, epoch=epoch, save_path=save_path, N=n, plot_method=method)
        for method, n in plot_styles
    }
    full_pullback = {
        method: script.full_pullback_plot(model, dataset.X, dataset.y, epoch=epoch, save_path=save_path, N=n, plot_method=method)
        for method, n in plot_styles
    }

    pullback_g_lattice, _ = pullback['lattice']
    full_pullback_g_lattice, full_grid_lattice = full_pullback['lattice']
    cosine_score = compute_cosine_score(pullback_g_lattice, full_pullback_g_lattice)
    magnitude_score = compute_magnitude_score(pullback_g_lattice, full_pullback_g_lattice)
    script.plot_err_heatmap(cosine_score, magnitude_score, full_grid_lattice, model, dataset.X, dataset.y, N=15, save_name=save_path, epoch=epoch)
    script.violin_plot(cosine_score, magnitude_score, save_name=save_path, epoch=epoch)


In [16]:
import tqdm 
# Defaults; the loop below rebinds all three names on every iteration.
dataset_name, size, epoch = 'moon', "overfit", 199
# (dataset, model size, checkpoint epoch) runs: the overfit model across
# training, bracketed by the skinny model at an early and a late checkpoint.
iter_ = (
    [('moon', 'skinny', 5)]
    + [('moon', 'overfit', checkpoint) for checkpoint in (5, 199, 499, 999, 4999, 9999)]
    + [('moon', 'skinny', 199)]
)

for run in tqdm.tqdm(iter_):
    dataset_name, size, epoch = run
    _iter_plot(dataset_name, size, epoch)


  direction_metric = direction_metric/np.linalg.norm(direction_metric, axis=1).reshape(-1, 1)
100%|██████████| 8/8 [07:16<00:00, 54.59s/it]
