## Imports

In [None]:
import copy
import logging
from pathlib import Path
from typing import Dict
import math
import itertools

import hydra
import matplotlib
import matplotlib.colors as colors
import matplotlib.pyplot as plt
import numpy as np
import omegaconf
import seaborn as sns
import torch  # noqa
import wandb
from hydra.utils import instantiate
from matplotlib import tri
from matplotlib.offsetbox import AnnotationBbox, OffsetImage
from omegaconf import DictConfig
from pytorch_lightning import LightningModule
from scipy.stats import qmc
from torch.utils.data import DataLoader
from tqdm import tqdm
from ccmm.matching.utils import perm_indices_to_perm_matrix
from ccmm.utils.utils import normalize_unit_norm, project_onto
from functools import partial

from nn_core.callbacks import NNTemplateCore
from nn_core.common import PROJECT_ROOT
from nn_core.common.utils import seed_index_everything
from nn_core.model_logging import NNLogger
from ccmm.utils.utils import fuse_batch_norm_into_conv
from torch.utils.data import DataLoader, Subset, SubsetRandomSampler
import autograd.numpy as anp

import torch
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
from sklearn.neighbors import NearestNeighbors
import numpy as np
from scipy.linalg import eig
from numpy.linalg import svd
from scipy.optimize import linear_sum_assignment
import scipy
import json

import pymanopt
import pymanopt.manifolds
import pymanopt.optimizers


import ccmm  # noqa
from ccmm.matching.utils import (
    apply_permutation_to_statedict,
    get_all_symbols_combinations,
    plot_permutation_history_animation,
    restore_original_weights,
)
from ccmm.utils.utils import (
    linear_interpolate,
    load_model_from_info,
    map_model_seed_to_symbol,
    save_factored_permutations,
)
from ccmm.pl_modules.pl_module import MyLightningModule

from ccmm.matching.utils import load_permutations

from ccmm.utils.utils import vector_to_state_dict, get_interpolated_loss_acc_curves, cumulative_sum


from ccmm.fm_utils import refine as zoomOut_refine


import pytorch_lightning

In [None]:
# Global plotting style: serif fonts, seaborn "talk" sizing, and the name of the
# colormap passed to every imshow call further down.
matplotlib.rcParams["font.family"] = "serif"
sns.set_context("talk")
cmap_name = "coolwarm_r"

# Only show warnings and above from the Lightning / torch loggers.
logging.getLogger("lightning.pytorch").setLevel(logging.WARNING)
logging.getLogger("torch").setLevel(logging.WARNING)
logging.getLogger("pytorch_lightning.accelerators.cuda").setLevel(logging.WARNING)
pylogger = logging.getLogger(__name__)

## Configuration

In [None]:
# Reload edited project modules automatically before each cell executes.
%load_ext autoreload
%autoreload 2

import hydra
from hydra import initialize, compose
from typing import Dict, List

# Reset Hydra's global state before initialising it again, then point it at the
# config directory (path is relative to this notebook).
hydra.core.global_hydra.GlobalHydra.instance().clear()
initialize(version_base=None, config_path=str("../conf"), job_name="matching_n_models")

In [None]:
cfg = compose(config_name="func_maps", overrides=[])

In [None]:
# Keep the full config as `core_cfg` (the dataset settings are read from it below);
# from here on `cfg` refers only to the `matching` sub-config.
# NOTE: this cell is not idempotent -- re-running it without re-running the
# `compose` cell would evaluate `cfg.matching` on the already-narrowed config.
core_cfg = cfg  # NOQA
cfg = cfg.matching

seed_index_everything(cfg)

## Hyperparameters

Set these values to any positive integer $x \in \{1, \dots, N\}$ (with $N$ the size of the corresponding dataset) to use only the first $x$ samples; a negative value such as `-1` keeps the whole dataset.

In [None]:
num_test_samples = -1
num_train_samples = -1

## Load dataset

In [None]:
# The *test* transform is instantiated once and applied to both splits.
transform = instantiate(core_cfg.dataset.test.transform)

train_dataset = instantiate(core_cfg.dataset.train, transform=transform)
test_dataset = instantiate(core_cfg.dataset.test, transform=transform)

# A negative sample count means "use the whole split".
num_train_samples = len(train_dataset) if num_train_samples < 0 else num_train_samples

# Subsets take the first `num_*_samples` indices (no shuffling).
train_subset = Subset(train_dataset, list(range(num_train_samples)))
train_loader = DataLoader(train_subset, batch_size=1000, num_workers=cfg.num_workers)

num_test_samples = len(test_dataset) if num_test_samples < 0 else num_test_samples
test_subset = Subset(test_dataset, list(range(num_test_samples)))

test_loader = DataLoader(test_subset, batch_size=1000, num_workers=cfg.num_workers)

In [None]:
# NOTE(review): `trainer` is re-instantiated in the training cells below before
# this instance is ever used for fit/test; this cell may be redundant.
trainer = instantiate(cfg.trainer, enable_progress_bar=False, enable_model_summary=False, max_epochs=10)

## Train models

In [None]:
import SpectralUtils


EPOCHS = 10

### Model definition

Here we define a standard MLP with an input layer, three hidden layers and an output layer. We use ReLU as the activation function and `log_softmax` on the output. The forward pass also returns the activations of every layer, since we will use them to match the networks.

In [None]:
import torch.nn as nn


class MLP(nn.Module):
    """Fully connected classifier that also exposes its per-layer activations.

    Architecture: ``input -> 512 -> 512 -> 512 -> 256 -> num_classes`` with ReLU
    after every layer except the last one.

    Args:
        input: Number of features per sample after flattening.
        num_classes: Size of the output layer.
    """

    def __init__(self, input=28 * 28, num_classes=10):
        super().__init__()
        # forward() reshapes every incoming batch to (-1, input).
        self.input = input

        wide, narrow = 512, 256
        # Attribute names are part of the state-dict keys read later
        # (e.g. "layer1.weight"), so they stay as layer0..layer4.
        self.layer0 = nn.Linear(input, wide)
        self.layer1 = nn.Linear(wide, wide)
        self.layer2 = nn.Linear(wide, wide)
        self.layer3 = nn.Linear(wide, narrow)
        self.layer4 = nn.Linear(narrow, num_classes)

    def forward(self, x):
        """Return ``(log_probabilities, embeddings)``.

        ``embeddings`` is the list ``[h0, h1, h2, h3, h4]``: the four post-ReLU
        hidden activations followed by the raw output of ``layer4``.
        """
        hidden = x.view(-1, self.input)
        embeddings = []

        # Layers are looked up at call time, so replacing an attribute after
        # construction is still honoured.
        for block in (self.layer0, self.layer1, self.layer2, self.layer3):
            hidden = nn.functional.relu(block(hidden))
            embeddings.append(hidden)

        logits = self.layer4(hidden)
        embeddings.append(logits)

        return nn.functional.log_softmax(logits, dim=-1), embeddings

A permutation specification (`PermutationSpec`) tells us which permutations to apply to which layers. It has two attributes: 
- `perm_to_layers_and_axes`: a dictionary that maps each permutation matrix to the layers it permutes and the axis it acts on, e.g. `{'P0': [('conv1.weight', 0), ('conv2.weight', 1)], 'P1': ...}`
- `layer_and_axes_to_perm`: a dictionary that maps each layer to a tuple with one entry per dimension of the layer's tensor; each entry names the permutation acting on that dimension (or `None` if there is none), e.g. `{ 'conv2.weight': ('P1', 'P0', None, None), 'conv3.weight': ...}`

In [None]:
from ccmm.matching.permutation_spec import MLPPermutationSpecBuilder, SpectralPermutationSpecBuilder


# One spec per model family (plain MLP and its spectral variant).
# NOTE(review): the argument 4 is presumably the number of hidden layers of the
# MLP defined above -- confirm against the builders' signatures.
spectral_permutation_spec_builder = SpectralPermutationSpecBuilder(4)
permutation_spec_builder = MLPPermutationSpecBuilder(4)

spectral_permutation_spec = spectral_permutation_spec_builder.create_permutation_spec()
permutation_spec = permutation_spec_builder.create_permutation_spec()

### Train and test first model

In [None]:
# Seed index 0 is used for the first pair of models.
cfg.seed_index = 0
seed_index_everything(cfg)
model_a = MyLightningModule(MLP(), num_classes=10)

# The spectral variant starts from its own freshly initialised MLP.
# NOTE(review): the return value of spectral_all is discarded, so it presumably
# rewrites the wrapped model in place -- confirm in SpectralUtils.
spectral_model_a = MyLightningModule(MLP(), num_classes=10)
SpectralUtils.spectral_all(model=spectral_model_a.model, verbose=True)

In [None]:
print("Model A state dict keys: ", model_a.state_dict().keys())
print("Spectral Model A state dict keys: ", spectral_model_a.state_dict().keys())

In [None]:
trainer = instantiate(cfg.trainer, enable_progress_bar=True, enable_model_summary=False, max_epochs=EPOCHS)

trainer.fit(model_a, train_loader)
trainer.test(model_a, test_loader)

In [None]:
trainer = instantiate(cfg.trainer, enable_progress_bar=True, enable_model_summary=False, max_epochs=EPOCHS)

trainer.fit(spectral_model_a, train_loader)
trainer.test(spectral_model_a, test_loader)

### Train and test second model

In [None]:
# Seed index 1 is used for the second pair of models.
cfg.seed_index = 1
seed_index_everything(cfg)

model_b = MyLightningModule(MLP(), num_classes=10)

# NOTE(review): the return value of spectral_all is discarded, so it presumably
# rewrites the wrapped model in place -- confirm in SpectralUtils.
spectral_model_b = MyLightningModule(MLP(), num_classes=10)
SpectralUtils.spectral_all(model=spectral_model_b.model, verbose=True)

# A fresh trainer is instantiated for every fit/test run.
trainer = instantiate(cfg.trainer, enable_progress_bar=True, enable_model_summary=False, max_epochs=EPOCHS)
trainer.fit(model_b, train_loader)
trainer.test(model_b, test_loader)

trainer = instantiate(cfg.trainer, enable_progress_bar=True, enable_model_summary=False, max_epochs=EPOCHS)
trainer.fit(spectral_model_b, train_loader)
trainer.test(spectral_model_b, test_loader)

## Matching

We use the permutations obtained from `git_rebasin` as ground truth to visualize the functional maps. We obtain these using function `weight_matching`

In [None]:
print(model_a)

In [None]:
print(spectral_model_a)

In [None]:
from ccmm.matching.weight_matching import weight_matching

# Permutations for the plain MLP pair; indexed as permutations["P_<layer>"] below.
permutations = weight_matching(permutation_spec, model_a.model.state_dict(), model_b.model.state_dict())

# Same procedure for the spectral pair, using the spectral permutation spec.
spectral_permutations = weight_matching(
    spectral_permutation_spec, spectral_model_a.model.state_dict(), spectral_model_b.model.state_dict()
)

## Focus on a single layer

In [None]:
layer_idx = 1


perm_gt = permutations[f"P_{layer_idx}"]

In [None]:
spectral_permutations_gt = spectral_permutations[f"P_{layer_idx}"]

In [None]:
print(spectral_permutations_gt)

### Descriptor 1: weights

In [None]:
layer_a_weights = model_a.model.state_dict()[f"layer{layer_idx}.weight"]
layer_b_weights = model_b.model.state_dict()[f"layer{layer_idx}.weight"]

In [None]:
W_a = layer_a_weights.detach().numpy()
W_b = layer_b_weights.detach().numpy()

W_a.shape

### Descriptor 2: Activations

We run a forward pass over a single batch of size `num_activations` to obtain the activations from both models. 

In [None]:
num_activations = 10000
# NOTE: this rebinds `train_loader` with a batch size of `num_activations`; any
# training cell re-run after this point will use this loader instead of the
# batch-size-1000 one created in the dataset section.
train_loader = DataLoader(train_subset, batch_size=num_activations, num_workers=cfg.num_workers)

In [None]:
# Only the first batch (up to `num_activations` samples) is needed. Descriptor
# extraction does not require gradients, so the forward passes run under
# no_grad to avoid building an autograd graph over the whole batch.
batch = next(iter(train_loader))
x, y = batch

with torch.no_grad():
    # the model returns (log-probabilities, list of embeddings); keep the embeddings
    features_a = model_a.model(x)[-1]
    features_b = model_b.model(x)[-1]

In [None]:
# (descriptor_dim, num_neurons), where descriptor_dim is the number of samples for which we are considering the neuron activation
layer_a = features_a[layer_idx]
layer_b = features_b[layer_idx]

In [None]:
num_neurons = layer_a.shape[1]

In [None]:
# normalize to have unit norm

# dim=0 gives one norm per column, i.e. (given the (num_samples, num_neurons)
# layout noted above) one norm per neuron; 1e-6 guards against neurons that
# never fire. NOTE: re-running this cell normalises already-normalised values.
layer_a = layer_a / (torch.norm(layer_a, dim=0) + 1e-6)
layer_b = layer_b / (torch.norm(layer_b, dim=0) + 1e-6)

In [None]:
print(layer_a.shape, layer_b.shape)

In [None]:
from ccmm.utils.utils import to_np

layer_a = to_np(layer_a)
layer_b = to_np(layer_b)

### Descriptor 3: denoised activations 

In [None]:
# (num_samples, num_neurons)
layer_a.shape

In [None]:
import numpy as np


def svd_threshold(matrix, variance_threshold=0.99):
    """Truncate the thin SVD of ``matrix`` at a cumulative-variance threshold.

    Args:
        matrix: 2-D array to decompose.
        variance_threshold: Fraction of the total squared singular-value mass
            that the retained components must reach.

    Returns:
        ``(U_reduced, S_reduced, Vt_reduced, explained_variance)`` where the
        first three are truncated to the retained components and
        ``explained_variance`` is the full cumulative curve (one entry per
        singular value).
    """
    left, singular_values, right_t = np.linalg.svd(matrix, full_matrices=False)

    # Cumulative share of the squared singular values.
    energy = singular_values**2
    explained_variance = np.cumsum(energy) / np.sum(energy)

    # Index of the first component at which the threshold is reached, as a count.
    reached = explained_variance >= variance_threshold
    keep = np.argmax(reached) + 1

    return left[:, :keep], singular_values[:keep], right_t[:keep, :], explained_variance


def svd_num_components(matrix, num_components=10):
    """Return the leading ``num_components`` terms of the thin SVD of ``matrix``.

    Args:
        matrix: 2-D array of shape ``(num_samples, num_neurons)``; either
            dimension may be the larger one.
        num_components: Number of leading singular triplets to keep. Values
            larger than ``min(num_samples, num_neurons)`` simply return every
            available component.

    Returns:
        ``(U_reduced, S_reduced, Vt_reduced)`` with shapes
        ``(num_samples, K)``, ``(K,)`` and ``(K, num_neurons)``.
    """
    # matrix is ~ (num_samples, num_neurons)
    num_samples, num_neurons = matrix.shape
    K = num_components

    U, S, Vt = np.linalg.svd(matrix, full_matrices=False)

    # With full_matrices=False the inner dimension is min(num_samples, num_neurons),
    # not num_neurons; checking against that keeps wide matrices working.
    inner_dim = min(num_samples, num_neurons)
    assert U.shape == (num_samples, inner_dim)
    assert S.shape == (inner_dim,)
    assert Vt.shape == (inner_dim, num_neurons)

    U_reduced = U[:, :K]
    S_reduced = S[:K]
    Vt_reduced = Vt[:K, :]

    return U_reduced, S_reduced, Vt_reduced

In [None]:
# (num_samples, num_comps), (num_comps), (num_comps, num_neurons)
# Keep the leading `num_components` singular triplets of each activation matrix.
# Resulting shapes: (num_samples, num_comps), (num_comps,), (num_comps, num_neurons)
num_components = 256
U_a, S_a, Vt_a = svd_num_components(layer_a, num_components=num_components)
U_b, S_b, Vt_b = svd_num_components(layer_b, num_components=num_components)

In [None]:
print(U_a.shape, S_a.shape, Vt_a.shape)
print(U_b.shape, S_b.shape, Vt_b.shape)

#### Reconstruction error 

In [None]:
# express each layer as a linear combination of the singular vectors
# Low-rank reconstruction of each activation matrix from its truncated SVD
# (this is the "denoised" descriptor used by get_descriptors).
layer_a_reconstructed = U_a @ np.diag(S_a) @ Vt_a
layer_b_reconstructed = U_b @ np.diag(S_b) @ Vt_b

layer_a_reconstructed.shape

In [None]:
# check if the reconstruction is close to the original layer by computing the norm
np.linalg.norm(layer_a_reconstructed - layer_a)

In [None]:
# check if the reconstruction is close to the original layer by computing the norm
np.linalg.norm(layer_b_reconstructed - layer_b)

In [None]:
# compute the norm of the two models for comparison
np.linalg.norm(layer_a - layer_b)

### Descriptor 4: eigenneurons

In [None]:
# Scale row k of Vt by 1 / (sqrt(s_k) + 1e-6), then transpose so that every row
# describes one neuron: shape (num_neurons, num_components).
# The scaling is built as a diagonal matrix of the *inverted* values; inverting
# np.diag(S) elementwise would also invert its off-diagonal zeros and mix rows.
eigenneurons_a = (np.diag(1 / (np.sqrt(S_a) + 1e-6)) @ Vt_a).T
eigenneurons_b = (np.diag(1 / (np.sqrt(S_b) + 1e-6)) @ Vt_b).T

### Descriptor 5: Spectral Models

In [None]:
print(spectral_model_a)
print(spectral_model_a.state_dict().keys())

In [None]:
spectral_model_a.model.state_dict().keys()

In [None]:
# In the spectral models each `layer<idx>` exposes sub-modules ".0" (weight) and
# ".1" (eigvals), as the state-dict keys printed above show.
spectral_layer_a_weights = spectral_model_a.model.state_dict()[f"layer{layer_idx}.0.weight"]
spectral_layer_a_eigvals = spectral_model_a.model.state_dict()[f"layer{layer_idx}.1.eigvals"]

spectral_layer_b_weights = spectral_model_b.model.state_dict()[f"layer{layer_idx}.0.weight"]
spectral_layer_b_eigvals = spectral_model_b.model.state_dict()[f"layer{layer_idx}.1.eigvals"]

# Elementwise (broadcast) product, not a matrix product.
# NOTE(review): the shape of `eigvals` is not visible here, so which weight axis
# gets scaled depends on broadcasting -- confirm it is the intended one.
spectral_a = spectral_layer_a_weights * spectral_layer_a_eigvals
spectral_b = spectral_layer_b_weights * spectral_layer_b_eigvals

## Functional maps

In [None]:
def get_descriptors(descriptor_type):
    """Select the pair of per-neuron descriptor matrices for models A and B.

    Reads the notebook-level arrays computed in the "Descriptor" sections.

    Args:
        descriptor_type: One of ``"weights"``, ``"features"``,
            ``"features_denoised"``, ``"eigenneurons"`` or ``"spectral"``.

    Returns:
        ``(X, Y)``: the descriptors of model A and model B.

    Raises:
        ValueError: If ``descriptor_type`` is not one of the supported names.

    Side effects:
        Choosing ``"spectral"`` rebinds the notebook-level ``perm_gt`` to the
        spectral ground-truth permutation of the current layer. It is *not*
        restored when another descriptor type is requested afterwards.
    """
    # Declared up front because the "spectral" branch assigns to it.
    global perm_gt

    if descriptor_type == "weights":
        X, Y = W_a, W_b
    elif descriptor_type == "features":
        X, Y = layer_a.T, layer_b.T
    elif descriptor_type == "features_denoised":
        X, Y = layer_a_reconstructed.T, layer_b_reconstructed.T
    elif descriptor_type == "eigenneurons":
        X, Y = eigenneurons_a, eigenneurons_b
    elif descriptor_type == "spectral":
        X, Y = spectral_a, spectral_b
        print("Since we are using spectral descriptors, we will use the spectral permutations")
        perm_gt = spectral_permutations[f"P_{layer_idx}"]
    else:
        # Name the actual parameter and the offending value.
        raise ValueError(
            f"Invalid descriptor_type {descriptor_type!r}; expected one of "
            "'weights', 'features', 'features_denoised', 'eigenneurons', 'spectral'"
        )
    return X, Y


descriptor_type = "spectral"  # weights, features, features_denoised, eigenneurons, spectral
X, Y = get_descriptors(descriptor_type)

### Build the KNN graph

In [None]:
def build_laplacian(A, normalized=True):
    """Build the graph Laplacian of ``A`` and its eigendecomposition.

    Args:
        A: Square adjacency matrix, either a dense array or a scipy sparse
            matrix (anything exposing ``toarray()``).
        normalized: If True, return the symmetric normalized Laplacian
            ``D^{-1/2} (D - A) D^{-1/2}``; otherwise ``D - A``.

    Returns:
        ``(A, L, evals, evecs)``: the adjacency exactly as passed in, the dense
        Laplacian, its eigenvalues in ascending order and the matching
        eigenvectors as columns.
    """
    # Sparse graphs (e.g. straight from kneighbors_graph) are densified first.
    adjacency = A.toarray() if hasattr(A, "toarray") else np.asarray(A)

    # The epsilon regularises the degrees only; adding it to the whole matrix
    # would leak 1e-6 into every off-diagonal entry of L.
    degrees = np.sum(adjacency, axis=1) + 1e-6

    assert not np.any(degrees < 0)

    L = np.diag(degrees) - adjacency

    if normalized:
        D_inv_sqrt = np.diag(1 / np.sqrt(degrees))
        L = D_inv_sqrt @ L @ D_inv_sqrt
        # Remove floating-point asymmetry before the symmetric eigensolver.
        L = (L + L.T) / 2

    assert not np.any(np.isnan(L))

    evals, evecs = np.linalg.eigh(L)

    idx = evals.argsort()
    evals = evals[idx]
    evecs = evecs[:, idx]

    return A, L, evals, evecs

In [None]:
def build_knn_graph(X, radius=None, num_neighbors=None, mode="distance"):
    """Build a symmetric neighbourhood graph over the rows of ``X``.

    Args:
        X: Descriptor matrix with one row per node.
        radius: If given, connect points whose distance is within this radius.
            Takes precedence over ``num_neighbors``.
        num_neighbors: If given (and ``radius`` is None), connect each point to
            its ``num_neighbors`` nearest neighbours. The query set is the
            fitted set and the diagonal is cleared afterwards.
        mode: ``"distance"`` or ``"connectivity"``, forwarded to scikit-learn.

    Returns:
        Dense symmetric adjacency matrix with a zero diagonal, obtained by
        averaging the directed graph with its transpose.

    Raises:
        ValueError: If neither ``radius`` nor ``num_neighbors`` is provided.
    """
    if radius is None and num_neighbors is None:
        raise ValueError("Either radius or num_neighbors must be provided")

    if radius is not None:
        Xneigh = NearestNeighbors(radius=radius)
        Xneigh.fit(X)
        # kneighbors_graph ignores `radius` and falls back to the estimator's
        # default n_neighbors, so the radius graph must be requested explicitly.
        X_knn_graph = Xneigh.radius_neighbors_graph(X, mode=mode)
    else:
        Xneigh = NearestNeighbors(n_neighbors=num_neighbors)
        Xneigh.fit(X)
        X_knn_graph = Xneigh.kneighbors_graph(X, mode=mode)

    # (num_neurons, num_neurons)
    X_adj = X_knn_graph.toarray()

    # Drop self-edges.
    np.fill_diagonal(X_adj, 0)

    # Averaging with the transpose symmetrises the directed neighbour graph.
    X_adj_sym = (X_adj + X_adj.T) / 2

    assert np.allclose(X_adj_sym, X_adj_sym.T), "Adjacences are not symmetric"

    return X_adj_sym

### Functional maps

In [None]:
def zoomOut(elements, steps=2):  # elements should be (func_map, Xevec, Yevec)
    """Refine a functional map with the project's ZoomOut routine.

    Args:
        elements: Triple ``(func_map, Xevec, Yevec)``, i.e. the output of
            ``compute_func_map(..., returnEigenvectors=True)``.
        steps: Forwarded as the ``step`` argument of ``zoomout_refine``.

    Returns:
        Whatever ``zoomOut_refine.zoomout_refine`` returns for these inputs.
    """
    initial_map, basis_x, basis_y = elements

    refine_kwargs = dict(FM_12=initial_map, evects1=basis_x, evects2=basis_y, n_jobs=1, step=steps)
    return zoomOut_refine.zoomout_refine(**refine_kwargs)

In [None]:
def compute_func_map(
    X,
    Y,
    P,
    radius=None,
    num_neighbors=None,
    mode="distance",
    normalize_lap=True,
    num_eigenvectors=50,
    returnEigenvectors=False,
):
    """Express the permutation ``P`` in the Laplacian eigenbases of two graphs.

    Neighbourhood graphs are built over the rows of ``X`` and ``Y``; the map is
    ``C = Phi_X^T P Phi_Y`` with ``Phi_*`` the first ``num_eigenvectors``
    Laplacian eigenvectors of each graph.

    Args:
        X, Y: Descriptor matrices, one row per node.
        P: Matrix relating the nodes of the two graphs.
        radius, num_neighbors, mode: Forwarded to ``build_knn_graph``.
        normalize_lap: Forwarded to ``build_laplacian``.
        num_eigenvectors: Number of leading eigenvectors used on each side.
        returnEigenvectors: If True, also return the *full* eigenvector
            matrices of both graphs.

    Returns:
        ``C``, or ``(C, Xevecs, Yevecs)`` when ``returnEigenvectors`` is True.
        If either graph has no edges, ``C`` is an all-zero
        ``(num_nodes_X, num_nodes_Y)`` placeholder and the eigenvector slots
        are ``None``.
    """
    X_adj_sym = build_knn_graph(X, radius, num_neighbors, mode)
    Y_adj_sym = build_knn_graph(Y, radius, num_neighbors, mode)

    if X_adj_sym.sum() == 0 or Y_adj_sym.sum() == 0:
        # Edgeless graph: keep the return arity consistent with the flag so
        # callers that unpack three values do not fail on a bare array.
        C = np.zeros((X_adj_sym.shape[0], Y_adj_sym.shape[0]))
        return (C, None, None) if returnEigenvectors else C

    _, _, _, Xevecs = build_laplacian(X_adj_sym, normalize_lap)
    _, _, _, Yevecs = build_laplacian(Y_adj_sym, normalize_lap)

    C = Xevecs[:, :num_eigenvectors].T @ P @ Yevecs[:, :num_eigenvectors]

    if returnEigenvectors:
        return C, Xevecs, Yevecs

    return C

In [None]:
def plot_func_maps(func_maps, fig_name, vmin, vmax, labels=None, grid_size=7):
    """Plot functional maps on a square grid and save the figure as a PNG.

    Args:
        func_maps: Sequence of 2-D arrays; at most ``grid_size ** 2`` are drawn.
        fig_name: Used for the title (underscores become spaces) and, with
            newlines removed, as the file name under ``figures/``. It may
            contain sub-directories, which are created if missing.
        vmin, vmax: Colour limits shared by all panels.
        labels: Values shown as ``k=<label>`` above each panel. Defaults to
            ``range(1, 100, 2)``; pass the actual sweep values when the maps
            come from a different sweep.
        grid_size: Number of rows and columns of the grid.
    """
    if labels is None:
        labels = range(1, 100, 2)

    fig, axs = plt.subplots(grid_size, grid_size, figsize=(20, 20), squeeze=False)

    # Add title to the whole figure
    fig.suptitle(" ".join(fig_name.split("_")), fontsize=30)

    for panel_idx, ax in enumerate(axs.flat):
        if panel_idx < len(func_maps):
            ax.imshow(func_maps[panel_idx], cmap=cmap_name, vmin=vmin, vmax=vmax)
            if panel_idx < len(labels):
                ax.set_title(f"k={labels[panel_idx]}")
        # Panels without a map are left blank instead of raising IndexError.
        ax.axis("off")

    # remove \n from fig_name
    fig_name = fig_name.replace("\n", "")

    # Create the target directory first: savefig does not create missing folders.
    out_path = Path("figures") / f"{fig_name}.png"
    out_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_path)

In [None]:
# Ground-truth permutation as a dense matrix. `perm_gt` is the notebook-level
# variable, which get_descriptors("spectral") rebinds to the spectral permutation.
P = perm_indices_to_perm_matrix(perm_gt).numpy()
normalize_lap = True
mode = "connectivity"  # connectivity or distance
num_eigenvectors = 20

In [None]:
# One functional map per neighbourhood size k = 1, 3, ..., 99 (50 maps).
func_maps_neighbors = [
    compute_func_map(
        X, Y, P, num_neighbors=k, mode=mode, normalize_lap=normalize_lap, num_eigenvectors=num_eigenvectors
    )
    for k in range(1, 100, 2)
]

In [None]:
plot_name = f"func_maps_{descriptor_type}_{mode}_normalizeLap_{normalize_lap}_numEigenvectors_{num_eigenvectors}"
plot_func_maps(func_maps_neighbors, plot_name, vmin=-0.6, vmax=0.6)

In [None]:
# Same sweep as above but starting at k = 3 (49 entries); each entry is the
# (C, Xevecs, Yevecs) triple that zoomOut expects.
zoomOut_func_maps_neighbors = [
    (
        compute_func_map(
            X,
            Y,
            P,
            num_neighbors=k,
            mode=mode,
            normalize_lap=normalize_lap,
            num_eigenvectors=num_eigenvectors,
            returnEigenvectors=True,
        )
    )
    for k in range(3, 100, 2)
]

In [None]:
# Refine every map with each ZoomOut step size and save one figure per step size.
# NOTE(review): the maps come from k = 3, 5, ... while plot_func_maps titles its
# panels starting at k=1 by default, so the panel titles are shifted by one
# sweep step here.
zoomOutSteps = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12, 15, 20, 30]

for steps in zoomOutSteps:
    zoomOut_func_maps = [zoomOut(elements, steps=steps) for elements in zoomOut_func_maps_neighbors]
    # Side length of the refined maps, recorded in the figure name.
    fmpSize = zoomOut_func_maps[0].shape[0]
    plot_func_maps(
        zoomOut_func_maps,
        f"zoomOut/zoomOut_func_maps_{descriptor_type}_{mode}_normalizeLap_{normalize_lap}_numEigenvectors_{num_eigenvectors}_\nzoomOutStep_{steps}_fmapSize_{fmpSize}x{fmpSize}",
        vmin=-0.5,
        vmax=0.5,
    )

In [None]:
# Radius sweep over 50 values in [0.01, 1]. Only `radius` is passed, so the
# function defaults apply: mode="distance", normalize_lap=True,
# num_eigenvectors=50 (not the notebook-level settings used for the k sweep).
func_maps_radius = [compute_func_map(X, Y, P, radius=r) for r in np.linspace(0.01, 1, 50)]

In [None]:
plot_func_maps(func_maps_radius, f"func_maps_{descriptor_type}_radius", vmin=-0.5, vmax=0.5)

### Transfer indicator function

In [None]:
mode = "distance"
normalize_lap = True
num_neighbors = 80
num_eigenvectors = 20

In [None]:
descriptor_type = "spectral"  # weights, features, features_denoised, eigenneurons, spectral
X, Y = get_descriptors(descriptor_type)

In [None]:
C = compute_func_map(
    X, Y, P, num_neighbors=num_neighbors, mode=mode, normalize_lap=normalize_lap, num_eigenvectors=num_eigenvectors
)

plt.imshow(C, cmap=cmap_name, vmin=-0.6, vmax=0.6)
plt.axis("off")

Consider an indicator function $f$ on the neurons:

$f(x_j) = 1$ if $j = i$ for a selected neuron $i$, and $f(x_j) = 0$ otherwise.

In [None]:
indicator_func = np.zeros((num_neurons,))
selected_neuron_idx = 32
indicator_func[selected_neuron_idx] = 1

Project the indicator functions onto the eigenvectors. Since the identity matrix is just the stack of all indicator functions, we do not actually need to project each one separately.

In [None]:
X_adj_sym = build_knn_graph(X, num_neighbors=num_neighbors, mode=mode)
XA, XL, Xevals, Xevecs = build_laplacian(X_adj_sym, normalize_lap)

Y_adj_sym = build_knn_graph(Y, num_neighbors=num_neighbors, mode=mode)
YA, YL, Yevals, Yevecs = build_laplacian(Y_adj_sym, normalize_lap)

In [None]:
# get degree from weighted adj matrix
X_adj_sym.mean()

In [None]:
# Truncated eigenbases of the two graphs (columns are eigenvectors).
Phi = Xevecs[:, :num_eigenvectors].real
Psi = Yevecs[:, :num_eigenvectors].real
# Lift the functional map back to a node-to-node matrix.
# NOTE(review): compute_func_map builds C as Xevecs.T @ P @ Yevecs, so
# Phi @ C @ Psi.T would approximate P; confirm that Psi @ C @ Phi.T (transposed
# again before the LAP below) is the intended orientation.
P_tilde = Psi @ C @ Phi.T

take the argmax of P_tilde (not guaranteed to be a permutation)


In [None]:
mapped_points_argmax = P_tilde.argmax(axis=1)

solve an LAP to get a permutation 


In [None]:
from ccmm.matching.weight_matching import solve_linear_assignment_problem

P_tilde_lap = solve_linear_assignment_problem(P_tilde.T, return_matrix=True)

### Comparing permutation matrices
For each point, we map it to the other graph

In [None]:
from ccmm.matching.utils import perm_matrix_to_perm_indices

# mapped_points[i] = j means that the i-th point in the first set is mapped to the j-th point in the second set
mapped_points = perm_matrix_to_perm_indices(P_tilde_lap)
mapped_points[:10]

In [None]:
num_exact_matchings = (mapped_points == perm_gt).sum().item()
num_exact_matchings

We compute the shortest path from each mapped point to its ground-truth point and plot a cumulative curve:
* the x axis is a radius in $[0, \text{diameter of the graph}]$
* the y axis is the fraction of matchings that lie within that radius of the ground-truth point
* for radius = 0 we count only the exact matchings; for radius = diameter every reachable matching counts as a match
* the curve goes from 0 to 100%; the faster it reaches 100%, the better the matching

In [None]:
from collections import deque


def bfs_shortest_distance(adj, start):
    """Breadth-first hop distances from ``start`` to every node.

    Args:
        adj: Square adjacency matrix; an entry ``> 0`` denotes an edge (edge
            weights are otherwise ignored).
        start: Index of the source node.

    Returns:
        List of length ``n`` with the number of hops to each node, ``np.inf``
        for nodes that cannot be reached.
    """
    node_count = adj.shape[0]

    # np.inf doubles as the "not visited yet" marker.
    distance = [np.inf for _ in range(node_count)]
    distance[start] = 0

    frontier = deque((start,))
    while len(frontier) > 0:
        node = frontier.popleft()
        hops_to_neighbor = distance[node] + 1

        for neighbor in range(node_count):
            if distance[neighbor] != np.inf:
                continue  # already reached through a path that is at least as short
            if adj[node, neighbor] > 0:
                distance[neighbor] = hops_to_neighbor
                frontier.append(neighbor)

    return distance


def bfs_shortest_path(adj, u, v):
    """Find a shortest (fewest-hops) path from ``u`` to ``v`` with BFS.

    Args:
        adj: Square adjacency matrix; an entry ``> 0`` denotes an edge.
        u: Index of the start node.
        v: Index of the target node.

    Returns:
        The list of node indices from ``u`` to ``v`` inclusive (``[u]`` when
        ``u == v``), or ``None`` if ``v`` cannot be reached from ``u``.
    """
    node_count = adj.shape[0]

    seen = [False] * node_count
    # predecessor[i] is the node from which i was first reached; -1 marks "none".
    predecessor = [-1] * node_count

    seen[u] = True
    frontier = deque([u])

    while frontier:
        node = frontier.popleft()

        # Stop expanding as soon as the target is dequeued.
        if node == v:
            break

        for neighbor in range(node_count):
            if seen[neighbor] or not adj[node, neighbor] > 0:
                continue
            seen[neighbor] = True
            predecessor[neighbor] = node
            frontier.append(neighbor)

    if not seen[v]:
        return None

    # Walk the predecessor chain back from the target, then flip it.
    backwards = []
    node = v
    while node != -1:
        backwards.append(node)
        node = predecessor[node]

    return backwards[::-1]

In [None]:
print(num_neurons)

In [None]:
# For every neuron, count the hops in graph Y between the predicted match and
# the ground-truth match (0 = exact match, np.inf = not connected).
path_lengths = []

for i in range(num_neurons):
    pred_mapping = mapped_points[i]
    gt_mapping = perm_gt[i]

    shortest_path = bfs_shortest_path(Y_adj_sym, pred_mapping.item(), gt_mapping.item())
    # A path lists its nodes, so the hop count is one less than its length.
    shortest_path_length = len(shortest_path) - 1 if shortest_path is not None else np.inf

    path_lengths.append(shortest_path_length)

In [None]:
path_lengths[:10]

# look for argmin_{P_tilde} P_tilde - Psi C Phi^T
# multiply Psi^T to the left
# Psi^T P_tilde - C Phi^T      --- Phi^T ~ (N, K)
# look for the binary P_tilde that minimizes this measure
# P_tilde_i = nearest_neighbor(Psi^T _i , C Phi^T _i)

In [None]:
from collections import Counter

path_length_count = Counter(path_lengths)

In [None]:
path_length_frequencies = {k: v / num_neurons for k, v in path_length_count.items()}
path_length_frequencies

In [None]:
def compute_graph_diameter(adj):
    """Largest finite shortest-path length (in hops) between any two nodes.

    Unreachable pairs are ignored, so for a disconnected graph this is the
    largest diameter among its connected components.

    Args:
        adj: Square adjacency matrix; an entry ``> 0`` denotes an edge.

    Returns:
        The diameter as an int (0 for a graph without edges).
    """
    n = adj.shape[0]
    diameter = 0

    for u in range(n):
        distance = bfs_shortest_distance(adj, u)
        # Drop unreachable nodes *before* taking the maximum. Taking max() over
        # the raw list yields inf whenever any node is unreachable, which made
        # every source node be skipped on disconnected graphs.
        reachable = [d for d in distance if d != np.inf]
        max_distance = max(reachable)

        if max_distance > diameter:
            diameter = max_distance

    return diameter


# NOTE(review): the path lengths above are measured on Y_adj_sym, while the
# diameter used as the upper radius bound is taken from X_adj_sym -- confirm
# that graph X is the intended one.
diameter = compute_graph_diameter(X_adj_sym)
print(diameter)

In [None]:
# Every integer radius from 0 to the diameter gets an entry, so radii at which
# no matching lands show up as frequency 0.
radiuses = range(0, diameter + 1)
print(radiuses)

for r in radiuses:
    if r not in path_length_frequencies:
        path_length_frequencies[r] = 0

In [None]:
# NOTE(review): path_length_frequencies may contain an np.inf key (unreachable
# pairs); how cumulative_sum orders and treats it is not visible here -- confirm.
ys = cumulative_sum(path_length_frequencies)

xs = radiuses

# Cumulative fraction of neurons matched within each radius.
plt.plot(xs, [ys[x] for x in xs], marker="o")

## Visualize graphs

In [None]:
# Directed k-NN connectivity graphs used only for the visualisation below; they
# are kept as scipy sparse matrices and are not symmetrised.
k = 10

Xneigh = NearestNeighbors(n_neighbors=k)
Xneigh.fit(X)

# (num_neurons, num_neurons)
X_knn_graph = Xneigh.kneighbors_graph(X, mode="connectivity")

Yneigh = NearestNeighbors(n_neighbors=k)
Yneigh.fit(Y)
Y_knn_graph = Yneigh.kneighbors_graph(Y, mode="connectivity")

In [None]:
# PCA is fitted on the transposed descriptors, so each of the three principal
# components has one coordinate per neuron; those are used as 3-D positions.
pca = PCA(n_components=3)
pca.fit(X.T)

Xx = pca.components_[0, :]
Xy = pca.components_[1, :]
Xz = pca.components_[2, :]

pca = PCA(n_components=3)
pca.fit(Y.T)

Yx = pca.components_[0, :]
Yy = pca.components_[1, :]
Yz = pca.components_[2, :]

# Create both panels as 3-D axes directly. Adding 3-D subplots on top of the
# 2-D axes returned by plt.subplots leaves the empty 2-D axes drawn underneath.
fig, ax = plt.subplots(1, 2, figsize=(10, 5), subplot_kw={"projection": "3d"})

ax[0].scatter(Xx, Xy, Xz, c="tab:blue")
ax[1].scatter(Yx, Yy, Yz, c="tab:red")

plt.show()

In [None]:
# Create both panels as 3-D axes directly. Adding 3-D subplots on top of the
# 2-D axes returned by plt.subplots leaves the empty 2-D axes drawn underneath.
fig, ax = plt.subplots(1, 2, figsize=(10, 5), subplot_kw={"projection": "3d"})

# Kept from the original cell: it rebinds the notebook-level `num_neurons`.
num_neurons = W_a.shape[0]

# Draw one segment per stored edge. Iterating over the non-zero entries of the
# sparse graph avoids probing all num_neurons**2 index pairs one by one.
for i, j in zip(*X_knn_graph.nonzero()):
    ax[0].plot([Xx[i], Xx[j]], [Xy[i], Xy[j]], [Xz[i], Xz[j]], "b-", alpha=0.5)

ax[0].scatter(Xx, Xy, Xz, c="tab:blue")

for i, j in zip(*Y_knn_graph.nonzero()):
    ax[1].plot([Yx[i], Yx[j]], [Yy[i], Yy[j]], [Yz[i], Yz[j]], "b-", alpha=0.5)

ax[1].scatter(Yx, Yy, Yz, c="tab:red")

plt.show()

# Hic sunt leones: you can ignore this part

In [None]:
# NOTE(review): unlike the output of build_knn_graph, these graphs are scipy
# sparse matrices and are not symmetrised -- confirm build_laplacian is meant to
# receive them in this form. This also overwrites the Xevecs / Yevecs computed
# in the "Transfer indicator function" section.
XA, XL, Xevals, Xevecs = build_laplacian(X_knn_graph, True)
YA, YL, Yevals, Yevecs = build_laplacian(Y_knn_graph, True)

### Solve a LAP in the reduced space

In [None]:
from scipy.optimize import linear_sum_assignment

# Neuron-by-neuron similarity between the denoised activations; the assignment
# maximises the total similarity. `ci` holds the column chosen for each row.
# _, ci = linear_sum_assignment(U_a.T @ U_b + Vt_a.T @ Vt_b.T, maximize=True)
_, ci = linear_sum_assignment(layer_a_reconstructed.T @ layer_b_reconstructed, maximize=True)

In [None]:
perm_matrix = perm_indices_to_perm_matrix(torch.tensor(ci)).numpy()

In [None]:
perm_matrix.shape

In [None]:
layer_b_reconstructed_perm = perm_matrix @ layer_b_reconstructed.T

layer_b_reconstructed_perm = layer_b_reconstructed_perm.T

In [None]:
# Scale every column (axis=0 norm) to unit length so that the traces computed
# next are sums of per-column cosine similarities; 1e-6 avoids division by zero.
layer_b_recon_perm_norm = layer_b_reconstructed_perm / (np.linalg.norm(layer_b_reconstructed_perm, axis=0) + 1e-6)
layer_a_norm = layer_a / (np.linalg.norm(layer_a, axis=0) + 1e-6)
layer_b_norm = layer_b / (np.linalg.norm(layer_b, axis=0) + 1e-6)

In [None]:
np.trace(layer_b_recon_perm_norm.T @ layer_a_norm)

In [None]:
np.trace(layer_b_norm.T @ layer_a_norm)

### LAP in the original space

In [None]:
sim_matrix_orig_space = layer_a @ layer_b.T

_, ci = linear_sum_assignment(-sim_matrix_orig_space, maximize=True)
perm_matrix = perm_indices_to_perm_matrix(torch.tensor(ci)).numpy()

In [None]:
layer_b_perm = perm_matrix @ layer_b

layer_b_perm_norm = layer_b_perm / (np.linalg.norm(layer_b_perm, axis=0) + 1e-6)
np.trace(layer_a_norm @ layer_b_perm_norm.T)