In [1]:
import sys
import torch
import numpy as np
import pickle as pkl
from tqdm import tqdm

sys.path.append('../../../')

from experiments.assumptions.degeneracy.script import eigenvalue_result, eigenvalue_results_large, plot_rank_train, rank_over_training
from models.supervised.mlp.model import MLP
from models.supervised.bimt.model import BioMLP



In [2]:
import torch
import os
import numpy as np
import matplotlib.pyplot as plt

def degeneracy_plot(g, surface, activations_np, labels, save_path, wrt="layer_wise", precision=7):
    """Scatter the per-point rank of the metric tensor for every layer but the last.

    For each layer the eigenvalues of the metrics in ``g[layer]`` are rounded
    to ``precision`` decimals; the number of eigenvalues whose magnitude exceeds
    ``1e-5`` colours the sample points of ``surface[layer]``. The activations of
    the same layer are overlaid, coloured by ``labels``. The figure is saved as
    ``{save_path}/_{wrt}_degeneracy.png``.

    Args:
        g: per-layer batches of metric matrices (presumably shaped
            (n_points, d, d) -- TODO confirm against ``pullback_metric``).
        surface: per-layer sample points; only columns 0 and 1 are drawn.
        activations_np: per-layer activations; only columns 0 and 1 are drawn.
        labels: colour values for the activation scatter.
        save_path: output directory, created when missing.
        wrt: tag embedded in the output file name.
        precision: number of decimals kept when rounding eigenvalues.
    """
    n_panels = len(g) - 1
    # squeeze=False keeps the Axes container two-dimensional, so indexing also
    # works when there is only a single panel.
    fig, axes = plt.subplots(1, n_panels, figsize=(10*n_panels, 10), squeeze=False)

    for layer, panel in enumerate(axes[0]):
        surface_ = surface[layer]
        activations_np_ = activations_np[layer]
        eigenval = np.round(np.linalg.eigvals(g[layer]), precision)
        # Number of eigenvalues that are not numerically zero, i.e. the rank.
        rank = np.sum(np.abs(eigenval) > 1e-5, axis=1)

        color = panel.scatter(surface_[:, 0], surface_[:, 1], c=rank, vmin=rank.min(), vmax=rank.max(), s=10, cmap="viridis")
        panel.scatter(activations_np_[:, 0], activations_np_[:, 1], c=labels, s=10, alpha=0.5, cmap="RdBu_r")
        panel.set_title(f"Rank of Metric Tensor over the Manifold - Layer {layer+1}")
        plt.colorbar(color, ax=panel)
        panel.set_xlabel("x")
        panel.set_ylabel("y")

    path = f"{save_path}"
    os.makedirs(path, exist_ok=True)
    fig.savefig(f"{path}/_{wrt}_degeneracy.png")
    plt.close(fig)

def plot_eigenvalue_spectra(g, surface, activations_np, labels, save_path, wrt="layer_wise", precision=7):
    """Plot the smallest and largest metric eigenvalue over the sample points.

    Produces a 2 x (len(g) - 1) grid: the top row is coloured by the minimum
    eigenvalue at each sample point, the bottom row by the maximum eigenvalue.
    Eigenvalues are rounded to ``precision`` decimals and only their real part
    is used. The activations of each layer are overlaid, coloured by
    ``labels``. The figure is saved as
    ``{save_path}/_{wrt}_eigenvalue_spectra.png``.

    Args:
        g: per-layer batches of metric matrices (presumably shaped
            (n_points, d, d) -- TODO confirm against ``pullback_metric``).
        surface: per-layer sample points; only columns 0 and 1 are drawn.
        activations_np: per-layer activations; only columns 0 and 1 are drawn.
        labels: colour values for the activation scatter.
        save_path: output directory, created when missing.
        wrt: tag embedded in the output file name.
        precision: number of decimals kept when rounding eigenvalues.
    """
    n_panels = len(g) - 1
    # squeeze=False keeps ``axes`` two-dimensional even with a single column.
    fig, axes = plt.subplots(2, n_panels, figsize=(10*n_panels, 20), squeeze=False)

    for layer in range(n_panels):
        surface_ = surface[layer]
        activations_np_ = activations_np[layer]
        eigenval = np.round(np.linalg.eigvals(g[layer]), precision).real

        rows = (
            (axes[0][layer], np.min(eigenval, axis=1), "Minimum"),
            (axes[1][layer], np.max(eigenval, axis=1), "Maximum"),
        )
        for panel, values, kind in rows:
            # Each row is coloured by its own statistic; the colour limits are
            # taken from the same values so data and scale always agree.
            color = panel.scatter(surface_[:, 0], surface_[:, 1], c=values, vmin=values.min(), vmax=values.max(), s=10, cmap="viridis")
            panel.scatter(activations_np_[:, 0], activations_np_[:, 1], c=labels, s=10, alpha=0.5, cmap="RdBu_r")
            panel.set_title(f"{kind} Eigenvalue - Layer {layer+1}")
            plt.colorbar(color, ax=panel)
            panel.set_xlabel("x")
            panel.set_ylabel("y")

    path = f"{save_path}"
    os.makedirs(path, exist_ok=True)
    fig.savefig(f"{path}/_{wrt}_eigenvalue_spectra.png")
    plt.close(fig)

def pseudo_riemann(g, surface, activations_np, labels, save_path, wrt="layer_wise", precision=7):
    """Plot how many metric eigenvalues are negative at each sample point.

    For every layer but the last, the eigenvalues of the metrics in
    ``g[layer]`` are rounded to ``precision`` decimals and the number of
    eigenvalues with a negative real part colours the points of
    ``surface[layer]``. The activations are overlaid, coloured by ``labels``.
    The figure is saved as ``{save_path}/_{wrt}_pseudo_riemann.png``.

    Args:
        g: per-layer batches of metric matrices (presumably shaped
            (n_points, d, d) -- TODO confirm against ``pullback_metric``).
        surface: per-layer sample points; only columns 0 and 1 are drawn.
        activations_np: per-layer activations; only columns 0 and 1 are drawn.
        labels: colour values for the activation scatter.
        save_path: output directory, created when missing.
        wrt: tag embedded in the output file name.
        precision: number of decimals kept when rounding eigenvalues.
    """
    n_panels = len(g) - 1
    # squeeze=False keeps the Axes container indexable for a single panel.
    fig, axes = plt.subplots(1, n_panels, figsize=(10*n_panels, 10), squeeze=False)

    for layer, panel in enumerate(axes[0]):
        surface_ = surface[layer]
        activations_np_ = activations_np[layer]
        # Use the real part, as the other spectrum plots do, so the sign test
        # does not depend on how complex numbers are ordered.
        eigenval = np.round(np.linalg.eigvals(g[layer]), precision).real
        neg_eigenvalues = np.sum(eigenval < 0, axis=1)

        color = panel.scatter(surface_[:, 0], surface_[:, 1], c=neg_eigenvalues, vmin=neg_eigenvalues.min(), vmax=neg_eigenvalues.max(), s=10, cmap="viridis")
        panel.scatter(activations_np_[:, 0], activations_np_[:, 1], c=labels, s=10, alpha=0.5, cmap="RdBu_r")
        panel.set_title(f"Number of Negative Eigenvalues over the Manifold - Layer {layer+1}")
        plt.colorbar(color, ax=panel)
        panel.set_xlabel("x")
        panel.set_ylabel("y")

    path = f"{save_path}"
    os.makedirs(path, exist_ok=True)
    fig.savefig(f"{path}/_{wrt}_pseudo_riemann.png")
    plt.close(fig)


def pseudo_riemann_det(g, surface, activations_np, labels, save_path, wrt="layer_wise", precision=7):
    """Plot the determinant of the metric tensor at each sample point.

    For every layer but the last, the determinants of the metrics in
    ``g[layer]`` are rounded to ``precision`` decimals and used to colour the
    points of ``surface[layer]``. The activations are overlaid, coloured by
    ``labels``. The figure is saved as
    ``{save_path}/_{wrt}_det_pseudo_riemann.png``.

    Args:
        g: per-layer batches of metric matrices (presumably shaped
            (n_points, d, d) -- TODO confirm against ``pullback_metric``).
        surface: per-layer sample points; only columns 0 and 1 are drawn.
        activations_np: per-layer activations; only columns 0 and 1 are drawn.
        labels: colour values for the activation scatter.
        save_path: output directory, created when missing.
        wrt: tag embedded in the output file name.
        precision: number of decimals kept when rounding determinants.
    """
    n_panels = len(g) - 1
    # squeeze=False keeps the Axes container indexable for a single panel.
    fig, axes = plt.subplots(1, n_panels, figsize=(10*n_panels, 10), squeeze=False)

    for layer, panel in enumerate(axes[0]):
        surface_ = surface[layer]
        activations_np_ = activations_np[layer]
        det_g = np.round(np.linalg.det(g[layer]), precision)

        color = panel.scatter(surface_[:, 0], surface_[:, 1], c=det_g, vmin=det_g.min(), vmax=det_g.max(), s=10, cmap="viridis")
        panel.scatter(activations_np_[:, 0], activations_np_[:, 1], c=labels, s=10, alpha=0.5, cmap="RdBu_r")
        panel.set_title(f"Determinant of $g$ over the Manifold - Layer {layer+1}")
        plt.colorbar(color, ax=panel)
        panel.set_xlabel("x")
        panel.set_ylabel("y")

    path = f"{save_path}"
    os.makedirs(path, exist_ok=True)
    fig.savefig(f"{path}/_{wrt}_det_pseudo_riemann.png")
    plt.close(fig)

def eigenvalue_distribution(g, save_path, wrt="layer_wise", precision=7):
    """Histogram the largest metric eigenvalues of every layer but the last.

    The number of histogram rows is the largest matrix rank found in any of the
    plotted layers (at least one). For each layer the real parts of the
    eigenvalues are rounded to ``precision`` decimals, transformed to
    ``|x| * log(|x| + 1)``, sorted per sample, and the largest ``max_dim`` of
    them are histogrammed, one eigenvalue position per row. The figure is saved
    as ``{save_path}/_{wrt}_eigenvalue_distribution.png``.

    Args:
        g: per-layer batches of metric matrices; the last entry is skipped.
        save_path: output directory, created when missing.
        wrt: tag embedded in the output file name.
        precision: number of decimals kept when rounding eigenvalues.
    """
    layer_metrics = g[:-1]

    # Start from 1 so a fully degenerate metric still yields a valid grid
    # (plt.subplots rejects zero rows).
    max_dim = 1
    for layer_g in layer_metrics:
        max_dim = max(max_dim, int(np.max(np.linalg.matrix_rank(layer_g))))

    n_layers = len(layer_metrics)
    # squeeze=False keeps ``ax`` two-dimensional when there is a single row or
    # column, so ``ax[row][col]`` always works.
    fig, ax = plt.subplots(max_dim, n_layers, figsize=(10*n_layers, 10*max_dim), squeeze=False)

    for layer, layer_g in enumerate(layer_metrics):
        eigenvals = np.linalg.eigvals(layer_g).real
        eigenvals = np.round(eigenvals, precision)
        eigenvals = np.abs(eigenvals)*np.log(np.abs(eigenvals)+1)
        sorted_eigenvals = np.sort(eigenvals, axis=1)
        if sorted_eigenvals.shape[-1] > max_dim:
            # Keep only the ``max_dim`` largest values of each sample.
            sorted_eigenvals = sorted_eigenvals[:, -max_dim:]

        for indx, eig in enumerate(sorted_eigenvals.T):
            ax[indx][layer].hist(eig)
            ax[indx][layer].set_title(f"Log Eigenvalue {indx+1} - Layer {layer}")

    path = f"{save_path}"
    os.makedirs(path, exist_ok=True)
    fig.savefig(f"{path}/_{wrt}_eigenvalue_distribution.png")
    plt.close(fig)

def plot_rank(g, save_path, wrt="layer_wise", precision=7):
    """Draw a violin plot of the normalised metric rank for every layer but the last.

    The rank of a sample is the fraction of its eigenvalues (real part,
    rounded to ``precision`` decimals) whose magnitude exceeds ``1e-5``. The
    figure is saved as ``{save_path}/_{wrt}_rank_distribution.png``.

    Returns:
        Three lists ``(q_25, med, q_75)`` holding, per plotted layer, the
        0.25, 0.5 and 0.75 quantiles of the normalised rank.
    """
    figure_width = max(2*(len(g)-1), 10)
    fig, ax = plt.subplots(1, 1, figsize=(figure_width, 8))

    rank_per_layer = []
    quartiles = {0.25: [], 0.5: [], 0.75: []}
    for metric in g[:-1]:
        spectrum = np.round(np.linalg.eigvals(metric).real, precision)
        is_nonzero = np.abs(spectrum) > 1e-5
        normalised_rank = np.sum(is_nonzero, axis=1)/np.shape(spectrum)[-1]
        rank_per_layer.append(normalised_rank)
        for level, bucket in quartiles.items():
            bucket.append(np.quantile(normalised_rank, level))

    ax.violinplot(rank_per_layer, showmedians=True)
    ax.set_title('Distribution of Metric Rank over Layers with Medians')
    ax.set_xlabel('Layer Index')
    ax.set_ylabel('Metric Rank')

    plt.tight_layout()
    plt.grid(axis='y')

    path = f"{save_path}"
    if not os.path.exists(path):
        os.makedirs(path)
    fig.savefig(f"{path}/_{wrt}_rank_distribution.png")
    plt.close(fig)
    return quartiles[0.25], quartiles[0.5], quartiles[0.75]

def eigenvalue_result(input_, model, N, labels, save_path, wrt="layer_wise", sigma=0.05, precision=7):
    """Run the model on ``input_`` and write every metric diagnostic figure.

    The activations recorded by the forward pass are handed to
    ``pullback_metric`` (manifold sampling, normalised metrics); the resulting
    metrics and sample points are then passed to each plotting routine in
    turn, all writing into ``save_path``.
    """
    inputs = torch.from_numpy(input_).float()
    model.forward(inputs, save_activations=True)

    layer_activations = model.get_activations()
    activations_np = [act.detach().numpy() for act in layer_activations]

    g, surface = pullback_metric(model, layer_activations, N, wrt=wrt, method="manifold", sigma=sigma, normalised=True)

    # Options shared by every plotting routine below.
    shared = dict(wrt=wrt, save_path=save_path, precision=precision)
    eigenvalue_distribution(g, **shared)
    plot_rank(g, **shared)
    for plot_fn in (degeneracy_plot, plot_eigenvalue_spectra, pseudo_riemann, pseudo_riemann_det):
        plot_fn(g, surface, activations_np, labels, **shared)

def plot_rank_train(q_25, med, q_75, savepath, wrt="layer_wise"):
    """Plot the median metric rank with its inter-quartile band over training.

    One panel is drawn per column of the inputs; the x axis indexes the rows.
    The figure is saved as ``{savepath}/_{wrt}_rank_evolution.png``.

    Args:
        q_25, med, q_75: 2-D array-likes of identical shape, indexed as
            ``[row][column]`` (the axis label treats rows as epochs).
        savepath: output directory, created when missing.
        wrt: tag embedded in the output file name.
    """
    q_25, med, q_75 = np.asarray(q_25), np.asarray(med), np.asarray(q_75)
    n_steps, n_panels = q_25.shape[0], q_25.shape[1]

    # squeeze=False keeps the Axes container indexable for a single panel.
    fig, axes = plt.subplots(1, n_panels, figsize=(10*n_panels, 5), squeeze=False)
    steps = np.arange(n_steps)

    for indx, panel in enumerate(axes[0]):
        panel.plot(med[:, indx], label="Median")
        panel.fill_between(steps, q_25[:, indx], q_75[:, indx], alpha=0.2)
        panel.set_xlabel('Epoch')
        panel.set_ylabel('Metric Rank')
        panel.set_title('Normalised Metric Rank throughout Training')
        panel.legend()

    fig.tight_layout()
    path = f"{savepath}"
    os.makedirs(path, exist_ok=True)
    fig.savefig(f"{path}/_{wrt}_rank_evolution.png")
    plt.close(fig)


def eigenvalue_results_large(input_, model, N, save_path, wrt="layer_wise", sigma=0.05, precision=7, sampling="manifold"):
    """Compute pulled-back metrics for ``input_`` and plot their spectrum and rank.

    Only the eigenvalue histograms and the rank violin plot are produced (no
    2-D scatter plots). ``sampling`` is forwarded to ``pullback_metric`` as its
    ``method``; metrics are not normalised.

    Returns:
        The ``(q_25, med, q_75)`` per-layer rank quantiles from ``plot_rank``.
    """
    print('checkpoint 1')
    batch = torch.from_numpy(input_).float()
    model.forward(batch, save_activations=True)
    print("Gone forward")

    metrics, _ = pullback_metric(model, model.get_activations(), N, wrt=wrt, method=sampling, sigma=sigma, normalised=False)
    print("Gone back")

    plot_options = dict(wrt=wrt, precision=precision, save_path=save_path)
    eigenvalue_distribution(metrics, **plot_options)
    q_25, med, q_75 = plot_rank(metrics, **plot_options)
    return q_25, med, q_75


def rank_over_training(input_, model, N, wrt="layer_wise", sigma=0.05, precision=7):
    """Return per-layer quartiles of the normalised metric rank, without plotting.

    The metrics come from ``pullback_metric`` with manifold sampling and no
    normalisation. The rank of a sample is the fraction of its eigenvalues
    (real part, rounded to ``precision`` decimals) whose magnitude exceeds
    ``1e-5``.

    Returns:
        Three lists ``(q_25, med, q_75)`` with one entry per layer except the
        last.
    """
    inputs = torch.from_numpy(input_).float()
    model.forward(inputs, save_activations=True)
    metrics, _ = pullback_metric(model, model.get_activations(), N, wrt=wrt, method="manifold", sigma=sigma, normalised=False)

    levels = (0.25, 0.5, 0.75)
    per_level = [[] for _ in levels]
    for metric in metrics[:-1]:
        spectrum = np.round(np.linalg.eigvals(metric).real, precision)
        ranks = np.sum(np.abs(spectrum) > 1e-5, axis=1)/np.shape(spectrum)[-1]
        for bucket, level in zip(per_level, levels):
            bucket.append(np.quantile(ranks, level))

    q_25, med, q_75 = per_level
    return q_25, med, q_75

In [3]:
import numpy as np
from riemannian_geometry.computations.riemann_metric import LocalDiagPCA
from utils.plotting.mesh import generate_lattice
from utils.metrics.metrics import z_normalise
from riemannian_geometry.computations.sample import generate_manifold_sample, sample_points_heat_kernel
from riemannian_geometry.differential_geometry.curvature import batch_curvature, batch_vectorised_christoffel_symbols
import torch
from torch.func import vmap, jacfwd, jacrev
import time
from concurrent.futures import ThreadPoolExecutor

from scipy.sparse.linalg import eigsh
import numpy as np
import networkx as nx
import ghalton
from scipy.spatial import KDTree
from annoy import AnnoyIndex

def generate_halton_points(point_dataset, N):
    """Return ``N`` Halton-sequence points scaled to the dataset's bounding box.

    The sequence dimension is ``point_dataset.shape[1]``. The box spans the
    per-column minimum and maximum of the dataset, widened by ``1e-2`` on
    either side.
    """
    n_dims = point_dataset.shape[1]
    unit_points = np.array(ghalton.Halton(n_dims).get(N))

    # Per-column bounds of the data, padded slightly outward.
    upper = np.max(point_dataset, axis=0)+1e-2
    lower = np.min(point_dataset, axis=0)-1e-2

    return unit_points * (upper - lower) + lower

def rejection_sampling(manifold, sampled_points, tol=1e-4):
    """Keep the sampled points whose inverse metric diagonal product exceeds ``tol``.

    ``manifold.metric_tensor`` is evaluated on the transposed points; for each
    returned matrix the reciprocals of its diagonal entries are multiplied
    together and compared against ``tol``.
    """
    metric = manifold.metric_tensor(sampled_points.transpose(), nargout=1)
    inverse_diagonal = 1/np.diagonal(metric, axis1=1, axis2=2)
    keep = np.prod(inverse_diagonal, axis=1) > tol
    return sampled_points[keep]

def generate_manifold_sample(manifold, activations, N, tol=None):
    """Draw ``N`` Halton points around ``activations`` and reject those far from the manifold.

    When ``tol`` is ``None`` the rejection threshold defaults to
    ``manifold.rho**2``.
    """
    threshold = manifold.rho**2 if tol is None else tol
    candidates = generate_halton_points(activations, N)
    return rejection_sampling(manifold, candidates, tol=threshold)

def find_k_approximate_neighbors(annoy, query_vector, k=5):
    """Return the ``k`` neighbours of ``query_vector`` reported by the ``annoy`` index."""
    neighbours = annoy.get_nns_by_vector(query_vector, k)
    return neighbours

def construct_graph(points, k_neighbors=5):
    """Construct a symmetric k-nearest-neighbour graph over ``points``.

    Inputs with more than two dimensions are flattened per sample and searched
    with an approximate Annoy index; the resulting adjacency is unweighted
    (0/1). Two-dimensional inputs are searched exactly with a KD-tree and
    weighted by ``exp(-||x_i - x_j||^2)``.

    In both cases row/column ``i`` of the result refers to ``points[i]``.

    Args:
        points: array whose first axis enumerates the samples.
        k_neighbors: number of neighbours linked to every point; values larger
            than the number of other points are clamped.

    Returns:
        ``(W, D)``: the dense adjacency matrix and the diagonal degree matrix
        (row sums of ``W``), both as ``numpy.ndarray``.
    """
    n = len(points)

    if len(points.shape) > 2:
        flat = points.reshape(n, -1)
        dim = flat.shape[-1]

        annoy = AnnoyIndex(dim, metric='euclidean')
        for i, vector in enumerate(flat):
            annoy.add_item(i, vector)
        annoy.build(int(np.sqrt(dim)))

        G = nx.Graph()
        # Register the nodes in index order first: networkx orders adjacency
        # rows by node insertion, which would otherwise follow the order in
        # which edges happen to be discovered.
        G.add_nodes_from(range(n))
        for indx, point in enumerate(flat):
            hits = find_k_approximate_neighbors(annoy, point, k=k_neighbors+1)
            # The search is approximate, so the query point is not guaranteed
            # to come back first; drop it explicitly instead of slicing.
            neighbours = [j for j in hits if j != indx][:k_neighbors]
            G.add_edges_from((indx, j) for j in neighbours)

        # to_numpy_array always yields an ndarray in the requested node order.
        W = nx.to_numpy_array(G, nodelist=list(range(n)))
    else:
        tree = KDTree(points)
        # +1 because the query point itself is among its own neighbours; clamp
        # to n because KDTree pads missing neighbours with the invalid index n.
        k = min(k_neighbors+1, n)
        _, indices = tree.query(points, k=k)
        indices = np.asarray(indices).reshape(n, -1)

        W = np.zeros((n, n))
        for i, row in enumerate(indices):
            for j in row:
                if i != j:
                    W[i, j] = np.exp(-np.linalg.norm(points[i] - points[j])**2)
                    W[j, i] = W[i, j]

    D = np.diag(np.sum(W, axis=1))
    return W, D

def find_minimal_connected_graph(points, start_k=None, connected_components=2):
    """Find the sparsest k-NN graph with at most ``connected_components`` components.

    The search starts at ``start_k`` (``int(sqrt(len(points)))`` when omitted).
    If that graph is already connected enough, ``k`` is decreased until the
    next smaller graph would break the requirement. Otherwise ``k`` is
    increased, up to ``len(points) // 2 - 1``, until the requirement is met.

    Returns:
        ``(W, D, k)``: adjacency matrix, degree matrix and the neighbour count
        they were built with. If no ``k`` in the searched range satisfies the
        requirement, the densest graph that was tried is returned.
    """
    if start_k is None:
        start_k = int(np.sqrt(len(points)))

    def _connected_enough(adjacency):
        # True when the graph has no more components than allowed.
        graph = nx.from_numpy_array(adjacency)
        return nx.number_connected_components(graph) <= connected_components

    W, D = construct_graph(points, k_neighbors=start_k)
    best_k = start_k

    if _connected_enough(W):
        # Shrink k while the candidate graph itself still qualifies.
        for k in range(start_k - 1, 0, -1):
            W_tmp, D_tmp = construct_graph(points, k_neighbors=k)
            if not _connected_enough(W_tmp):
                break
            W, D, best_k = W_tmp, D_tmp, k
        return W, D, best_k

    # Grow k until the candidate graph qualifies.
    k_max = len(points)//2
    for k in range(start_k + 1, k_max):
        W, D = construct_graph(points, k_neighbors=k)
        best_k = k
        if _connected_enough(W):
            break
    return W, D, best_k

def compute_heat_kernel(W, D, t, dim):
    """Compute a truncated heat kernel from the normalised graph Laplacian.

    The Laplacian is ``D^{-1/2} (D - W) D^{-1/2}``. Its ``dim`` smallest
    eigenpairs ``(lambda_i, phi_i)`` are combined into
    ``K_t = sum_i exp(-lambda_i * t) * phi_i phi_i^T``.

    Args:
        W: symmetric adjacency matrix, shape (n, n).
        D: diagonal degree matrix, shape (n, n).
        t: diffusion time.
        dim: number of eigenpairs to keep; when ``dim >= n`` the full spectrum
            is used.

    Returns:
        The (n, n) ``float64`` kernel matrix.
    """
    W = np.asarray(W, dtype=np.float64)
    D = np.asarray(D, dtype=np.float64)
    n = W.shape[0]

    # D is diagonal, so D^{-1/2} is an element-wise operation on its diagonal.
    # Isolated nodes (degree 0) get a zero scaling instead of a singular inverse.
    degrees = D.diagonal()
    inv_sqrt_deg = np.zeros(n, dtype=np.float64)
    has_edges = degrees > 0
    inv_sqrt_deg[has_edges] = 1.0 / np.sqrt(degrees[has_edges])
    L = inv_sqrt_deg[:, None] * (D - W) * inv_sqrt_deg[None, :]

    k = int(dim)
    if k >= n:
        # ARPACK cannot return all n eigenpairs; use the dense symmetric solver.
        lambdas, phis = np.linalg.eigh(L)
    else:
        # The k smallest-magnitude eigenvalues and their eigenvectors.
        lambdas, phis = eigsh(L, k=k, which='SM', tol=1e-5)

    # sum_i exp(-lambda_i t) phi_i phi_i^T, written as one matrix product.
    decay = np.exp(-lambdas.real * t)
    return (phis.real * decay) @ phis.real.T

def sample_using_heat_kernel(points, K_t, num_samples=10):
    """Create ``num_samples`` new points as midpoints of heat-kernel-linked pairs.

    Each sample picks a uniformly random source point, draws a partner with
    probability proportional to the squared heat-kernel row of the source, and
    returns the midpoint of the two.
    """
    n_points = len(points)
    midpoints = []

    for _ in range(num_samples):
        source = np.random.randint(n_points)

        # Squared kernel row of the source, normalised into a distribution.
        weights = K_t[source]**2
        partner = np.random.choice(n_points, p=weights / np.sum(weights))

        midpoints.append((points[source] + points[partner]) / 2)

    return np.array(midpoints)

def sample_points_heat_kernel(points, num_samples=10, t=0.1, connect_components=1):
    """Sample ``num_samples`` new points from ``points`` using a graph heat kernel.

    A k-NN graph with at most ``connect_components`` components is built, its
    heat kernel at time ``t`` is computed, and new points are drawn as
    midpoints of kernel-linked pairs. Inputs with more than two dimensions are
    flattened per sample for the computation.

    Returns:
        Array of shape ``(num_samples,) + points.shape[1:]``, i.e. every new
        point has the same shape as one input sample.
    """
    W, D, k = find_minimal_connected_graph(points, connected_components=connect_components, start_k=None)
    print(f"Using {k} nearest neighbors")

    # Flatten unconditionally so 2-D inputs take the same code path (for them
    # the reshape is a no-op).
    flat_points = points.reshape(points.shape[0], -1)
    K_t = compute_heat_kernel(W, D, t, dim=flat_points.shape[1])
    new_points = sample_using_heat_kernel(flat_points, K_t, num_samples=num_samples)

    # Restore the full per-sample shape, including any channel axis.
    return new_points.reshape((num_samples,) + tuple(points.shape[1:]))

def uniform_sample(n_samples, dataset):
    """Draw ``n_samples`` points uniformly from a padded bounding box of ``dataset``.

    The box spans the per-column minimum and maximum of the data; each side is
    pushed outward by an independent random margin in ``[0, 0.1)`` so the box
    always contains every data point.
    """
    n_dims = dataset.shape[1]
    upper = np.max(dataset, axis=0) + np.random.uniform(0, 0.1, size=n_dims)
    # Subtract the margin: adding it would move the lower bound into the data.
    lower = np.min(dataset, axis=0) - np.random.uniform(0, 0.1, size=n_dims)
    return np.random.uniform(lower, upper, size=(n_samples, n_dims))




In [4]:
import numpy as np
from riemannian_geometry.computations.riemann_metric import LocalDiagPCA
from utils.plotting.mesh import generate_lattice
from utils.metrics.metrics import z_normalise
from riemannian_geometry.computations.sample import generate_manifold_sample, sample_points_heat_kernel
from riemannian_geometry.differential_geometry.curvature import batch_curvature, batch_vectorised_christoffel_symbols
import torch
from torch.func import vmap, jacfwd, jacrev
import time
from concurrent.futures import ThreadPoolExecutor

from scipy.sparse.linalg import eigsh
import numpy as np
import networkx as nx
import ghalton
from scipy.spatial import KDTree
from annoy import AnnoyIndex

def generate_halton_points(point_dataset, N):
    """Return ``N`` Halton-sequence points scaled to the dataset's bounding box.

    The sequence dimension is ``point_dataset.shape[1]``. The box spans the
    per-column minimum and maximum of the dataset, widened by ``1e-2`` on
    either side.
    """
    n_dims = point_dataset.shape[1]
    unit_points = np.array(ghalton.Halton(n_dims).get(N))

    # Per-column bounds of the data, padded slightly outward.
    upper = np.max(point_dataset, axis=0)+1e-2
    lower = np.min(point_dataset, axis=0)-1e-2

    return unit_points * (upper - lower) + lower

def rejection_sampling(manifold, sampled_points, tol=1e-4):
    """Keep the sampled points whose inverse metric diagonal product exceeds ``tol``.

    ``manifold.metric_tensor`` is evaluated on the transposed points; for each
    returned matrix the reciprocals of its diagonal entries are multiplied
    together and compared against ``tol``.
    """
    metric = manifold.metric_tensor(sampled_points.transpose(), nargout=1)
    inverse_diagonal = 1/np.diagonal(metric, axis1=1, axis2=2)
    keep = np.prod(inverse_diagonal, axis=1) > tol
    return sampled_points[keep]

def generate_manifold_sample(manifold, activations, N, tol=None):
    """Draw ``N`` Halton points around ``activations`` and reject those far from the manifold.

    When ``tol`` is ``None`` the rejection threshold defaults to
    ``manifold.rho**2``.
    """
    threshold = manifold.rho**2 if tol is None else tol
    candidates = generate_halton_points(activations, N)
    return rejection_sampling(manifold, candidates, tol=threshold)

def find_k_approximate_neighbors(annoy, query_vector, k=5):
    """Return the ``k`` neighbours of ``query_vector`` reported by the ``annoy`` index."""
    neighbours = annoy.get_nns_by_vector(query_vector, k)
    return neighbours

def construct_graph(points, k_neighbors=5):
    """Construct a symmetric k-nearest-neighbour graph over ``points``.

    Inputs with more than two dimensions are flattened per sample and searched
    with an approximate Annoy index; the resulting adjacency is unweighted
    (0/1). Two-dimensional inputs are searched exactly with a KD-tree and
    weighted by ``exp(-||x_i - x_j||^2)``.

    In both cases row/column ``i`` of the result refers to ``points[i]``.

    Args:
        points: array whose first axis enumerates the samples.
        k_neighbors: number of neighbours linked to every point; values larger
            than the number of other points are clamped.

    Returns:
        ``(W, D)``: the dense adjacency matrix and the diagonal degree matrix
        (row sums of ``W``), both as ``numpy.ndarray``.
    """
    n = len(points)

    if len(points.shape) > 2:
        flat = points.reshape(n, -1)
        dim = flat.shape[-1]

        annoy = AnnoyIndex(dim, metric='euclidean')
        for i, vector in enumerate(flat):
            annoy.add_item(i, vector)
        annoy.build(int(np.sqrt(dim)))

        G = nx.Graph()
        # Register the nodes in index order first: networkx orders adjacency
        # rows by node insertion, which would otherwise follow the order in
        # which edges happen to be discovered.
        G.add_nodes_from(range(n))
        for indx, point in enumerate(flat):
            hits = find_k_approximate_neighbors(annoy, point, k=k_neighbors+1)
            # The search is approximate, so the query point is not guaranteed
            # to come back first; drop it explicitly instead of slicing.
            neighbours = [j for j in hits if j != indx][:k_neighbors]
            G.add_edges_from((indx, j) for j in neighbours)

        # to_numpy_array always yields an ndarray in the requested node order.
        W = nx.to_numpy_array(G, nodelist=list(range(n)))
    else:
        tree = KDTree(points)
        # +1 because the query point itself is among its own neighbours; clamp
        # to n because KDTree pads missing neighbours with the invalid index n.
        k = min(k_neighbors+1, n)
        _, indices = tree.query(points, k=k)
        indices = np.asarray(indices).reshape(n, -1)

        W = np.zeros((n, n))
        for i, row in enumerate(indices):
            for j in row:
                if i != j:
                    W[i, j] = np.exp(-np.linalg.norm(points[i] - points[j])**2)
                    W[j, i] = W[i, j]

    D = np.diag(np.sum(W, axis=1))
    return W, D

def find_minimal_connected_graph(points, start_k=None, connected_components=2):
    """Find the sparsest k-NN graph with at most ``connected_components`` components.

    The search starts at ``start_k`` (``int(sqrt(len(points)))`` when omitted).
    If that graph is already connected enough, ``k`` is decreased until the
    next smaller graph would break the requirement. Otherwise ``k`` is
    increased, up to ``len(points) // 2 - 1``, until the requirement is met.

    Returns:
        ``(W, D, k)``: adjacency matrix, degree matrix and the neighbour count
        they were built with. If no ``k`` in the searched range satisfies the
        requirement, the densest graph that was tried is returned.
    """
    if start_k is None:
        start_k = int(np.sqrt(len(points)))

    def _connected_enough(adjacency):
        # True when the graph has no more components than allowed.
        graph = nx.from_numpy_array(adjacency)
        return nx.number_connected_components(graph) <= connected_components

    W, D = construct_graph(points, k_neighbors=start_k)
    best_k = start_k

    if _connected_enough(W):
        # Shrink k while the candidate graph itself still qualifies.
        for k in range(start_k - 1, 0, -1):
            W_tmp, D_tmp = construct_graph(points, k_neighbors=k)
            if not _connected_enough(W_tmp):
                break
            W, D, best_k = W_tmp, D_tmp, k
        return W, D, best_k

    # Grow k until the candidate graph qualifies.
    k_max = len(points)//2
    for k in range(start_k + 1, k_max):
        W, D = construct_graph(points, k_neighbors=k)
        best_k = k
        if _connected_enough(W):
            break
    return W, D, best_k

def compute_heat_kernel(W, D, t, dim):
    """Compute a truncated heat kernel from the normalised graph Laplacian.

    The Laplacian is ``D^{-1/2} (D - W) D^{-1/2}``. Its ``dim`` smallest
    eigenpairs ``(lambda_i, phi_i)`` are combined into
    ``K_t = sum_i exp(-lambda_i * t) * phi_i phi_i^T``.

    Args:
        W: symmetric adjacency matrix, shape (n, n).
        D: diagonal degree matrix, shape (n, n).
        t: diffusion time.
        dim: number of eigenpairs to keep; when ``dim >= n`` the full spectrum
            is used.

    Returns:
        The (n, n) ``float64`` kernel matrix.
    """
    W = np.asarray(W, dtype=np.float64)
    D = np.asarray(D, dtype=np.float64)
    n = W.shape[0]

    # D is diagonal, so D^{-1/2} is an element-wise operation on its diagonal.
    # Isolated nodes (degree 0) get a zero scaling instead of a singular inverse.
    degrees = D.diagonal()
    inv_sqrt_deg = np.zeros(n, dtype=np.float64)
    has_edges = degrees > 0
    inv_sqrt_deg[has_edges] = 1.0 / np.sqrt(degrees[has_edges])
    L = inv_sqrt_deg[:, None] * (D - W) * inv_sqrt_deg[None, :]

    k = int(dim)
    if k >= n:
        # ARPACK cannot return all n eigenpairs; use the dense symmetric solver.
        lambdas, phis = np.linalg.eigh(L)
    else:
        # The k smallest-magnitude eigenvalues and their eigenvectors.
        lambdas, phis = eigsh(L, k=k, which='SM', tol=1e-5)

    # sum_i exp(-lambda_i t) phi_i phi_i^T, written as one matrix product.
    decay = np.exp(-lambdas.real * t)
    return (phis.real * decay) @ phis.real.T

def sample_using_heat_kernel(points, K_t, num_samples=10):
    """Create ``num_samples`` new points as midpoints of heat-kernel-linked pairs.

    Each sample picks a uniformly random source point, draws a partner with
    probability proportional to the squared heat-kernel row of the source, and
    returns the midpoint of the two.
    """
    n_points = len(points)
    midpoints = []

    for _ in range(num_samples):
        source = np.random.randint(n_points)

        # Squared kernel row of the source, normalised into a distribution.
        weights = K_t[source]**2
        partner = np.random.choice(n_points, p=weights / np.sum(weights))

        midpoints.append((points[source] + points[partner]) / 2)

    return np.array(midpoints)

def sample_points_heat_kernel(points, num_samples=10, t=0.1, connect_components=1):
    """Sample ``num_samples`` new points from ``points`` using a graph heat kernel.

    A k-NN graph with at most ``connect_components`` components is built, its
    heat kernel at time ``t`` is computed, and new points are drawn as
    midpoints of kernel-linked pairs. Inputs with more than two dimensions are
    flattened per sample for the computation.

    Returns:
        Array of shape ``(num_samples,) + points.shape[1:]``, i.e. every new
        point has the same shape as one input sample.
    """
    W, D, k = find_minimal_connected_graph(points, connected_components=connect_components, start_k=None)
    print(f"Using {k} nearest neighbors")

    # Flatten unconditionally so 2-D inputs take the same code path (for them
    # the reshape is a no-op).
    flat_points = points.reshape(points.shape[0], -1)
    K_t = compute_heat_kernel(W, D, t, dim=flat_points.shape[1])
    new_points = sample_using_heat_kernel(flat_points, K_t, num_samples=num_samples)

    # Restore the full per-sample shape instead of assuming a single channel.
    return new_points.reshape((num_samples,) + tuple(points.shape[1:]))

def uniform_sample(n_samples, dataset):
    """Draw ``n_samples`` points uniformly from a padded bounding box of ``dataset``.

    The box spans the per-column minimum and maximum of the data; each side is
    pushed outward by an independent random margin in ``[0, 0.1)`` so the box
    always contains every data point.
    """
    n_dims = dataset.shape[1]
    upper = np.max(dataset, axis=0) + np.random.uniform(0, 0.1, size=n_dims)
    # Subtract the margin: adding it would move the lower bound into the data.
    lower = np.min(dataset, axis=0) - np.random.uniform(0, 0.1, size=n_dims)
    return np.random.uniform(lower, upper, size=(n_samples, n_dims))


def batch_form_pullback(start, end, jacobian, form):
    """Pull a batch of rank-3 forms back through the matching Jacobians.

    Only the entries ``start:end`` along the first axis are processed. Each of
    the three form indices is contracted with the first index of the Jacobian
    (einsum ``'lai,lbj,lck,labc->lijk'``).
    """
    window = slice(start, end)
    J = jacobian[window]
    return np.einsum('lai,lbj,lck,labc->lijk', J, J, J, form[window])

def compute_jacobian_layer(model, X, layer_indx):
    """Return the per-sample Jacobian of ``model.layers[layer_indx]`` at ``X``.

    Forward-mode differentiation is used when the layer has at least as many
    outputs as inputs, reverse-mode otherwise; the computation is vectorised
    over the first axis of ``X`` with ``vmap``.
    """
    layer = model.layers[layer_indx]
    n_in = layer.in_features
    n_out = layer.out_features
    differentiate = jacfwd if n_out >= n_in else jacrev
    return vmap(differentiate(layer.forward))(X)

def compute_jacobian_multi_layer(layer_func, X, dim_in, dim_out):
    """Return the per-sample Jacobian of ``layer_func`` at ``X``.

    ``dim_in`` and ``dim_out`` only select the differentiation mode:
    forward-mode when ``dim_out >= dim_in``, reverse-mode otherwise. The
    computation is vectorised over the first axis of ``X`` with ``vmap``.
    """
    differentiate = jacfwd if dim_out >= dim_in else jacrev
    return vmap(differentiate(layer_func))(X)


def pullback_metric(model, activations, N=50, wrt="output_wise", method="lattice", normalised=False, sigma=0.05):
    """Pull the metric estimated on the last activations back through the layers.

    A ``LocalDiagPCA`` manifold is fitted on the last activations and its
    metric tensor is evaluated at a set of sample points. For every earlier
    layer the Jacobian ``J`` of the downstream map is computed at that layer's
    sample points and the reference metric ``g_ref`` is pulled back as
    ``J^T g_ref J``.

    Args:
        model: network exposing ``layers``, ``forward(x, save_activations=True)``,
            ``get_activations()`` and (for ``wrt="output_wise"``)
            ``forward_layers(x, index)``.
        activations: per-layer activation tensors, input first.
        N: sampling resolution; "manifold" and "heat" request ``N**2`` points.
        wrt: "layer_wise" pulls the metric of layer ``i + 1`` back through
            layer ``i`` only; "output_wise" pulls the last metric back through
            ``model.forward_layers(x, i)``.
        method: how sample points are generated -- "lattice" (a lattice per
            layer), "manifold" (Halton points with rejection on the input,
            pushed through the network) or "heat" (heat-kernel samples of the
            input, pushed through the network).
        normalised: apply ``z_normalise`` to every pulled-back metric.
        sigma: bandwidth passed to ``LocalDiagPCA``; "heat" also uses
            ``1 / sigma`` as diffusion time.

    Returns:
        ``(g, save_grids)``: per-layer metrics and the per-layer sample points
        they were evaluated at.

    Raises:
        ValueError: if ``wrt`` or ``method`` is not one of the supported values.
    """
    if wrt not in ("layer_wise", "output_wise"):
        # Fail early instead of hitting an undefined ``jacobian`` in the loop.
        raise ValueError("wrt must be either 'layer_wise' or 'output_wise'")

    print('pull 1')
    activations_np = [activation.detach().numpy() for activation in activations]
    N_layers = len(activations_np)
    manifold = LocalDiagPCA(activations_np[-1], sigma=sigma, rho=1e-5)

    if method == "lattice":
        xy_grid = generate_lattice(activations_np[-1], N)

    elif method == "manifold":
        manifold_2 = LocalDiagPCA(activations_np[0], sigma=sigma, rho=1e-5)
        xy_grid = generate_manifold_sample(manifold_2, activations_np[0], N=N**2)
        del manifold_2
        # Give every sample the per-sample shape of the network input; the
        # number of samples is inferred because rejection sampling makes it
        # data dependent.
        sample_shape = (-1,) + tuple(activations_np[0].shape[1:])
        surface_tensor = torch.from_numpy(xy_grid.reshape(sample_shape)).float()
        model.forward(surface_tensor, save_activations=True)
        surfaces = model.get_activations()
        _xy_grids = [surface.detach().numpy() for surface in surfaces]
        xy_grid = _xy_grids[-1]
    elif method == "heat":
        xy_grid = sample_points_heat_kernel(activations_np[0], num_samples=N**2, connect_components=2, t=1/sigma)
        print('Sampled points')
        print(xy_grid.shape)
        surface_tensor = torch.from_numpy(xy_grid).float()
        model.forward(surface_tensor, save_activations=True)
        surfaces = model.get_activations()
        print('Forwarded surface')
        print([s.shape for s in surfaces])
        _xy_grids = [surface.detach().numpy() for surface in surfaces]
        xy_grid = _xy_grids[-1]
    else:
        raise ValueError("method must be one of 'lattice', 'manifold' or 'heat'")
    g_ = manifold.metric_tensor(xy_grid.transpose(), nargout=1)

    g = [0 for _ in activations_np]
    g[-1] = g_

    save_grids = [0 for _ in activations_np]
    save_grids[-1] = xy_grid
    print('checkpoint pre-pullback')

    # Walk from the second-to-last layer back to the input.
    for indx in tqdm(reversed(range(0, N_layers-1)), total=N_layers-1):
        dim_in = model.layers[indx].in_features
        dim_out = model.layers[indx].out_features
        if method == "manifold" or method == "heat":
            xy_grid = _xy_grids[indx]
            xy_grid_tensor = torch.from_numpy(xy_grid).float()

        elif method == "lattice":
            xy_grid = generate_lattice(activations_np[indx], N)
            xy_grid_tensor = torch.from_numpy(xy_grid).float()

        if wrt == "layer_wise":
            jacobian = compute_jacobian_multi_layer(model.layers[indx], xy_grid_tensor, dim_in, dim_out)
            ref = indx + 1

        elif wrt == "output_wise":
            def forward_layers(x):
                return model.forward_layers(x, indx)
            # The map now ends at the network output, so its width decides
            # between forward- and reverse-mode differentiation.
            dim_out = model.layers[-1].out_features

            jacobian = compute_jacobian_multi_layer(forward_layers, xy_grid_tensor, dim_in, dim_out)
            ref = -1
        print('comp jac')
        jacobian = jacobian.detach().numpy()
        # J^T g J for every sample point.
        g_pullback = np.einsum('lai,lbj,lab->lij', jacobian, jacobian, g[ref])
        print('pulled back')
        if normalised:
            g_pullback = z_normalise(g_pullback)
        g[indx] = g_pullback
        save_grids[indx] = xy_grid

    return g, save_grids



In [5]:
from models.data.mnist import MNISTDataset
from models.supervised.cnn.model import CNN
import torch.nn as nn
import os
from torch.utils.data import DataLoader
from torch.utils.data import Subset

# Seed the global RNGs so the random MNIST subset (and the heat-kernel sampling
# that draws from ``np.random`` later on) is reproducible across kernel restarts.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

dataset = MNISTDataset(train=True, root="../../../data")
batch_size = 128
N_batches = batch_size//8
# Indices are drawn with replacement, so the subset may contain repeated images.
subset_mnist = np.random.randint(0, len(dataset), batch_size**2//8)
random_subset = Subset(dataset, subset_mnist)

val_data = DataLoader(random_subset, batch_size=batch_size, shuffle=False)
mode="moon"
model_name = "cnn"
size = "vanilla"

# Each conv entry is passed to ``CNN`` as given; the first fully connected layer
# expects the flattened 32 x 24 x 24 output of the second convolution.
cnn_layers = [(1, 16, 3, 1), (16, 32, 3, 1)]
fc_layers = [(32 * 24 * 24, 128, nn.ReLU()), (128, 64, nn.ReLU())]
output_dim = 10

model = CNN(cnn_layers, fc_layers, output_dim)
models_path = f"../../../models/supervised/{model_name}/saved_models"

# Per-epoch accumulators filled by the evaluation loop.
res_q_25, res_med, res_q_75 = [], [], []

checkpoint_files = os.listdir(f"{models_path}/{size}")
tmp = checkpoint_files  # legacy name kept in case other cells still read it

# Checkpoints are named like ``model_<epoch>.pth``; collect the epoch numbers.
epochs = sorted(
    int(file_name.split('_')[1].split('.')[0])
    for file_name in checkpoint_files
    if file_name.endswith(".pth")
)

In [6]:
# Evaluate the saved checkpoint(s): for every validation batch compute the
# output-wise pulled-back metric with heat-kernel sampling and collect the
# rank quartiles.
for epoch in tqdm([1]):
    checkpoint_path = f'{models_path}/{size}/model_{epoch}.pth'
    model.load_state_dict(torch.load(checkpoint_path))
    model.eval()
    save_path = f"figures/{model_name}/{size}/{epoch}/"

    q_25, med, q_75 = [], [], []
    for X, y in val_data:
        grid_side = int(np.sqrt(len(X)*2))
        q_25_tmp, med_tmp, q_75_tmp = eigenvalue_results_large(
            X.detach().numpy(),
            model,
            N=grid_side,
            wrt="output_wise",
            sigma=0.05,
            precision=7,
            save_path=save_path,
            sampling="heat",
        )
        # Per-batch lists are concatenated, so each epoch entry is one flat
        # list running over batches and layers.
        q_25.extend(q_25_tmp)
        med.extend(med_tmp)
        q_75.extend(q_75_tmp)

    res_q_25.append(q_25)
    res_med.append(med)
    res_q_75.append(q_75)


  0%|          | 0/1 [00:00<?, ?it/s]

checkpoint 1
Gone forward
pull 1




Using 11 nearest neighbors
Sampled points
(256, 1, 28, 28)
Forwarded surface
[torch.Size([256, 1, 28, 28]), torch.Size([256, 16, 26, 26]), torch.Size([256, 32, 24, 24]), torch.Size([256, 18432]), torch.Size([256, 128]), torch.Size([256, 64])]
checkpoint pre-pullback




comp jac




pulled back
comp jac


: 

: 