In [9]:
import pickle as pkl
import random
import sys

import numpy as np
import torch

sys.path.append('../../../')

#from experiments.supervised.product_manifold.script import main, _plot_clusters, _plot_holonomy_dist, _plot_holonomy_surface
from models.supervised.mlp.model import MLP
from models.unsupervised.vae.model import VAE, Encoder, Decoder
from models.supervised.bimt.model import BioMLP
from experiments.supervised.holonomy.script import main


In [10]:
# Seed every random number generator the notebook draws from. The cycle sampler
# (find_cycles_random_walk) uses the stdlib `random` module, so seeding numpy and
# torch alone does not make the sampled loops reproducible.
SEED = 2
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

<torch._C.Generator at 0x7fa1dac2ec10>

In [None]:
# Which saved supervised model to analyse.
size = "vanilla"
mode = "moon"
model_name = "mlp"
epoch = 199
#model = BioMLP(shp=[2,20,20,2])
model = MLP(2,7,10,2)
models_path = f"../../../models/supervised/{model_name}/saved_models"

# NOTE(security): pickle.load can execute arbitrary code; only open dataset files
# that were produced by this project.
with open(f'{models_path}/{size}/{mode}/dataset.pkl', 'rb') as f:
    dataset = pkl.load(f)

full_path = f'{models_path}/{size}/{mode}/model_{epoch}.pth'
# map_location="cpu": the rest of the notebook converts tensors with .detach().numpy(),
# which needs CPU tensors, and this also lets a checkpoint saved on a GPU machine load here.
model.load_state_dict(torch.load(full_path, map_location="cpu"))
if size == "overfit":
    # Drop the final layer of the "overfit" variant before analysing it.
    model.num_layers -= 1
    model.layers = model.layers[:-1]
model.eval()


In [None]:
# Hyper-parameters forwarded positionally to `main` below.
sigma = 0.1
quantile = 0.9
tol = 1e-5
# Output directory handed to `main` as `save_path`.
save_path = f"figures/{model_name}/{mode}/{size}/{epoch}"
N = 20
# NOTE(review): when the notebook runs top to bottom, `main` at this point is the one
# imported from experiments.supervised.holonomy.script. The `main` defined further down
# in this notebook returns five values and would not unpack into these three names --
# confirm which one this cell is meant to call.
holonomy_manifolds, loop_point_manifolds, transformation_matrix = main(model, dataset, N, sigma, quantile, tol, save_path, MIN_SIZE=10, wrt="output_wise", plot_hol=True, plot_graph=True, plot_group=True)


In [11]:
import os
import matplotlib.pyplot as plt
from torch import from_numpy
import random
from collections import deque, defaultdict
import networkx as nx
import numpy as np
from riemannian_geometry.differential_geometry.curvature import batch_vectorised_christoffel_symbols
from riemannian_geometry.computations.pullback_metric import pullback_holonomy
from riemannian_geometry.differential_geometry.holonomy import product_manifold
from riemannian_geometry.computations.riemann_metric import LocalDiagPCA
from riemannian_geometry.computations.sample import sample_points_heat_kernel
from riemannian_geometry.differential_geometry.curvature import batch_curvature
import time
from torch.func import vmap, jacfwd, jacrev
from concurrent.futures import ThreadPoolExecutor

def compute_jacobian_multi_layer(layer_func, X, dim_in, dim_out):
    """
    Evaluate the Jacobian of `layer_func` independently for every row of X.

    Parameters:
    - layer_func: callable mapping a single (unbatched) input tensor to an output tensor
    - X: batch of inputs; the leading axis is mapped over with vmap
    - dim_in, dim_out: input / output widths, used only to pick the differentiation mode

    Returns:
    - the batched Jacobian produced by vmap
    """
    # Forward mode when the output is at least as wide as the input, reverse mode otherwise.
    differentiate = jacfwd if dim_out >= dim_in else jacrev
    return vmap(differentiate(layer_func))(X)

def batch_form_pullback(start, end, jacobian, form):
    """
    Pull a batch of 3-index forms back through the matching batch of Jacobians.

    Parameters:
    - start, end: half-open range of leading-axis indices to process
    - jacobian: array indexed (point, output, input)
    - form: array indexed (point, output, output, output)

    Returns:
    - array indexed (point, input, input, input) for the points in [start, end)
    """
    window = slice(start, end)
    jac = jacobian[window]
    # Contract each of the three form indices with one copy of the Jacobian.
    return np.einsum('lai,lbj,lck,labc->lijk', jac, jac, jac, form[window])

def pullback_holonomy(model, activations, N=20, sigma=0.05, method="manifold", wrt="output_wise", normalised=False):
    """
    Pull the last-layer metric, its differential and its Ricci tensor back through the model.

    Note: this definition shadows the `pullback_holonomy` imported from
    riemannian_geometry.computations.pullback_metric at the top of this cell.

    Parameters:
    - model: network exposing `forward(x, save_activations=True)`, `get_activations()`,
      `forward_layers(x, 0)` and `layers` (first/last entries provide
      `in_features` / `out_features`)
    - activations: sequence of torch tensors; the first entry seeds the point sampling and
      the last entry is used to fit the LocalDiagPCA metric
    - N: N**2 points are requested from sample_points_heat_kernel
    - sigma: passed to LocalDiagPCA and used as t=1/sigma for the sampler
    - method, wrt, normalised: accepted for call compatibility; not read by this implementation

    Returns:
    - g_pullback: per-point pulled-back metric, indexed (point, i, j)
    - dg_pullback: per-point pulled-back metric differential, indexed (point, i, j, k)
    - Ricci_pullback: per-point pulled-back Ricci tensor, indexed (point, i, j)
    - xy_grid: the sampled points as recorded in the first saved activation
      (presumably the model input -- confirm against the model's activation hook)
    """
    activations_np = [activation.detach().numpy() for activation in activations]
    manifold = LocalDiagPCA(activations_np[-1], sigma=sigma, rho=1e-5)

    # Sample points from the first activation and push them through the model.
    xy_grid = sample_points_heat_kernel(activations_np[0], num_samples=N**2, connect_components=2, t=1/sigma)
    surface_tensor = torch.from_numpy(xy_grid).float()
    model.forward(surface_tensor, save_activations=True)
    surfaces = model.get_activations()
    _xy_grids = [surface.detach().numpy() for surface in surfaces]
    xy_grid = _xy_grids[-1]

    # Metric, derivatives and Ricci tensor evaluated at the last-activation images.
    output_g, output_dg, output_ddg = manifold.metric_tensor(xy_grid.transpose(), nargout=3)
    _, output_Ricci, _ = batch_curvature(output_g, output_dg, output_ddg)

    start = time.time()
    dim_in = model.layers[0].in_features
    dim_out = model.layers[-1].out_features
    xy_grid = _xy_grids[0]

    xy_grid_tensor = torch.from_numpy(xy_grid).float()
    def forward_layers(x):
        return model.forward_layers(x, 0)
    jacobian = compute_jacobian_multi_layer(forward_layers, xy_grid_tensor, dim_in, dim_out)
    end = time.time()
    print("Jacobian computed in {} seconds".format(end-start))

    start = time.time()
    jacobian = jacobian.detach().numpy()
    end = time.time()
    print("Jacobian converted to numpy in {} seconds".format(end-start))  

    start = time.time()
    g_pullback = np.einsum('lai,lbj,lab->lij', jacobian, jacobian, output_g)
    end = time.time()
    print("Pullback metric computed in {} seconds".format(end-start))

    start = time.time()
    n = g_pullback.shape[0]
    D = g_pullback.shape[1]
    # At least one point per chunk: n // D is 0 when there are fewer sample points than
    # input dimensions, which previously caused a ZeroDivisionError.
    batch_size = max(1, n // D)
    with ThreadPoolExecutor() as executor:
        # Consecutive chunks of `batch_size` points; the last chunk carries the remainder.
        futures = [
            executor.submit(batch_form_pullback, start_idx, min(start_idx + batch_size, n), jacobian, output_dg)
            for start_idx in range(0, n, batch_size)
        ]
        # Collect in submission order so the concatenation keeps the point order.
        dg_batches = [future.result() for future in futures]
    dg_pullback = np.concatenate(dg_batches, axis=0)
    end = time.time()
    print("Pullback differential computed in {} seconds".format(end-start))

    start = time.time()
    Ricci_pullback = np.einsum('Nai,Nbj,Nab->Nij', jacobian, jacobian, output_Ricci)
    end = time.time()
    print("Pullback Ricci computed in {} seconds".format(end-start))

    return g_pullback, dg_pullback, Ricci_pullback, xy_grid


def find_cycles_random_walk(graph, num_cycles, max_walk_length=100, max_attempts=None):
    """
    Sample cycles of `graph` by repeated non-backtracking random walks.

    Each attempt starts at a random vertex and steps to a random neighbour other than the
    vertex it just came from. The attempt succeeds when the walk returns to its start and
    is abandoned when it dead-ends, revisits another vertex, or exceeds `max_walk_length`.

    Parameters:
    - graph: object exposing `nodes()` and `neighbors(v)` (e.g. a networkx graph)
    - num_cycles: number of cycles to collect
    - max_walk_length: maximum number of steps per attempt
    - max_attempts: optional cap on the number of attempts. With the default None the
      function keeps trying until `num_cycles` cycles are found, which never terminates
      on a graph without a reachable cycle; pass a cap to get a partial result instead.

    Returns:
    - list of cycles, each a list of vertices starting at the walk's start vertex
      (the start vertex is not repeated at the end)
    """
    cycles = []
    vertices = list(graph.nodes())
    if not vertices:
        # Nothing to walk on; random.choice would raise on an empty sequence.
        return cycles

    attempts = 0
    while len(cycles) < num_cycles:
        if max_attempts is not None and attempts >= max_attempts:
            break
        attempts += 1

        start_vertex = random.choice(vertices)
        visited = set()
        walk = []
        current_vertex = start_vertex
        last_vertex = None  # vertex we arrived from; excluded to avoid immediate backtracking

        for _ in range(max_walk_length):
            visited.add(current_vertex)
            walk.append(current_vertex)

            neighbors = [n for n in graph.neighbors(current_vertex) if n != last_vertex]
            if not neighbors:
                break  # dead end

            last_vertex, current_vertex = current_vertex, random.choice(neighbors)

            if current_vertex == start_vertex:
                cycles.append(list(walk))
                break

            # Revisiting any other vertex means the walk is no longer a simple path.
            if current_vertex in visited:
                break

    return cycles

def parallel_transport(P, X, Christoffel):
    """
    Approximate the parallel transport of a vector X along a path P with given Christoffel symbols.

    Uses one explicit Euler step per path segment:
        X^k <- X^k - sum_{i,j} Gamma[k, i, j] * X^i * dx^j
    where dx is the difference between consecutive path points and Gamma is taken at the
    segment's starting point.

    Parameters:
    - P: sequence of numpy arrays, representing the points on the path
    - X: numpy array, representing the initial vector at P[0] (not modified)
    - Christoffel: sequence of 3D numpy arrays, the Christoffel symbols at each point P[i],
      indexed as Gamma[k, i, j]

    Returns:
    - X_transported: numpy array, representing the parallel transported vector at the end of the path
    """
    # Work on a copy so the caller's vector is left untouched.
    X_transported = np.copy(X)

    for i in range(len(P) - 1):
        # Finite difference between adjacent points.
        delta_x = P[i + 1] - P[i]
        Gamma = Christoffel[i]

        # Snapshot the vector at the start of the step: every component's increment must be
        # computed from the same state. Updating in place while looping over k would make
        # later components see already-advanced earlier ones (an order-dependent step).
        X_step_start = X_transported.copy()
        for k in range(len(X)):
            X_transported[k] -= np.sum(Gamma[k, :, :] * X_step_start[:, np.newaxis] * delta_x[np.newaxis, :])
    return X_transported

def holonomy_product_manifold(manifolds, metric, form, surface, num_cycles=None):
    """
    Run the holonomy computation on every component graph of a product manifold.

    Parameters:
    - manifolds: iterable of graphs whose `nodes` are row indices into `metric`, `form`
      and `surface`, and whose node attributes carry a "rank" entry
    - metric, form, surface: per-point arrays, indexed by node id on the leading axis
    - num_cycles: number of cycles per component (10000 when None)

    Returns:
    - (holonomy_manifolds, loop_point_manifolds, transformation_matrix, ranks), four lists
      with one entry per component; rank-0 components contribute empty lists
    """
    num_cycles = 10000 if num_cycles is None else num_cycles
    holonomy_manifolds, loop_point_manifolds, transformation_matrix, ranks = [], [], [], []

    for manifold in manifolds:
        node_ids = list(manifold.nodes)
        # Restrict the global per-point arrays to this component's nodes.
        point_cloud_metric = metric[node_ids]
        point_cloud_surface = surface[node_ids]
        point_cloud_form = form[node_ids]
        # Node id -> row position inside the restricted arrays.
        mappable_dict = {node: position for position, node in enumerate(node_ids)}
        # The rank is read from the component's first node.
        rank = int(manifold.nodes[node_ids[0]]["rank"])
        ranks.append(rank)
        print(f"Rank {rank}")

        if rank == 0:
            values, loop_points, transforms = [], [], []
        else:
            values, loop_points, transforms = holonomy(manifold, point_cloud_metric, point_cloud_form, point_cloud_surface, mappable_dict, rank, num_cycles=num_cycles)

        holonomy_manifolds.append(values)
        loop_point_manifolds.append(loop_points)
        transformation_matrix.append(transforms)

    return holonomy_manifolds, loop_point_manifolds, transformation_matrix, ranks

def filter_eigs(g, dg, K):
    """
    Restrict a batch of metrics and their differentials to the K leading eigen-directions.

    Parameters:
    - g: array indexed (point, a, b), a symmetric metric per point
    - dg: array indexed (point, a, b, c), the metric differential per point
    - K: number of leading eigenvectors (largest eigenvalues) to keep

    Returns:
    - V: array indexed (point, k, a); row k is the eigenvector of the k-th kept eigenvalue,
      ordered by ascending eigenvalue
    - g_tilde: reduced metric, indexed (point, i, j)
    - d_g_tilde: reduced metric differential, indexed (point, i, j, k)
    """
    # A metric is symmetric, so use eigh: it returns real eigenvalues in ascending order and
    # real orthonormal eigenvectors, whereas eig can return complex output for
    # rank-deficient input. Symmetrising first removes floating-point asymmetry.
    g_sym = 0.5 * (g + np.swapaxes(g, -1, -2))
    _, eigenvectors = np.linalg.eigh(g_sym)

    # Eigenvectors are the COLUMNS of the returned matrices. Keep the last K columns (largest
    # eigenvalues) and transpose so that each row of V is one eigenvector.
    V = np.swapaxes(eigenvectors[:, :, -K:], 1, 2)

    # Reduced metric \tilde{g} = V g V^T
    g_tilde = np.einsum('nia,njb,nab->nij', V, V, g)

    # Differential of the reduced metric, d\tilde{g}
    d_g_tilde = np.einsum('nia,njb,nkc,nabc->nijk', V, V, V, dg)
    return V, g_tilde, d_g_tilde


def create_dict_from_lists(list1, list2):
    """
    Average the values of `list2` grouped by the matching keys of `list1`.

    Both arguments are lists of sub-lists; sub-lists are paired positionally and, within a
    pair, keys and values are paired positionally. Returns {key: mean of its values}.
    """
    totals = {}
    counts = {}
    for keys, values in zip(list1, list2):
        for key, value in zip(keys, values):
            if key not in totals:
                # First sighting of this key: start both accumulators at zero.
                totals[key] = 0
                counts[key] = 0
            totals[key] += value
            counts[key] += 1

    return {key: totals[key] / counts[key] for key in totals}

def _plot_hol(holonomy_manifolds, save_path, ranks, wrt="output_wise"):
    iter_ = [indx for indx, h in enumerate(holonomy_manifolds) if len(h) > 0]
    M = len(iter_)
    fig, ax = plt.subplots(1, M, figsize=(M * 8, 8))
    if M == 1:
        ax = [ax]
    for indx, i in enumerate(iter_):
        ax[indx].hist(holonomy_manifolds[i], bins=100)
        ax[indx].set_title(f"Manifold {i} - Rank {ranks[i]}")
    fig.savefig(f"{save_path}/_{wrt}_holonomy_hist.png")
    plt.close(fig)

def _plot_graph(loop_point_manifolds, holonomy_manifolds, subgraphs, pos, save_path, dataset, ranks, wrt="output_wise"):
    """
    Save a figure with the combined component graph plus one holonomy-coloured panel per component.

    The first panel overlays the dataset scatter with all subgraphs coloured by component
    index. Each further panel draws one subgraph with nodes coloured by the mean holonomy
    value recorded for that node (0 when the node never started a loop). The figure is
    written to `{save_path}/_{wrt}_holonomy_graph.png`.
    """
    mean_holonomy = create_dict_from_lists(loop_point_manifolds, holonomy_manifolds)

    # Merge all components into one graph, tagging each node with its component index.
    combined_graph = nx.Graph()
    for cluster_id, component in enumerate(subgraphs):
        combined_graph.add_nodes_from(component.nodes)
        for node in component.nodes:
            combined_graph.nodes[node]["cluster"]=cluster_id
        combined_graph.add_edges_from(component.edges)

    M = len(subgraphs)+1
    fig, ax = plt.subplots(1, M, figsize=(16*M, 12))

    overview = ax[0]
    overview.scatter(dataset.X[:,0], dataset.X[:,1], c=dataset.y, cmap=plt.cm.viridis, s=20, edgecolors = 'red')
    cluster_ids = [combined_graph.nodes[node]["cluster"] for node in combined_graph.nodes]
    color = nx.draw_networkx_nodes(combined_graph, pos=pos, node_color=cluster_ids, vmin=0, vmax=len(subgraphs), cmap=plt.cm.Accent, node_size=20, ax=overview)
    nx.draw_networkx_edges(combined_graph, pos=pos, alpha=0.6, ax=overview)
    plt.colorbar(color, ax=overview)
    overview.set_title("Combined graph")

    for indx, subgraph in enumerate(subgraphs):
        panel = ax[indx+1]
        colors = [mean_holonomy.get(node, 0) for node in subgraph.nodes]
        # Colour scale spans exactly this component's value range.
        color = nx.draw_networkx_nodes(subgraph, pos=pos, node_color=colors, cmap=plt.cm.RdBu_r, node_size=20, vmin=min(colors), vmax=max(colors), ax=panel)
        nx.draw_networkx_edges(subgraph, pos=pos, alpha=0.6, ax=panel)
        plt.colorbar(color, ax=panel)
        panel.set_title(f"Manifold {indx} - Rank {ranks[indx]}")

    fig.savefig(f"{save_path}/_{wrt}_holonomy_graph.png")
    plt.close(fig)

def _plot_holonomy_group(transformation_matrix, holonomy_manifolds, ranks, save_path, wrt="output_wise"):
    iter_ = [indx for indx, h in enumerate(holonomy_manifolds) if len(h) > 0]
    M = len(iter_)
    fig, ax = plt.subplots(2, M, figsize=(M * 8, 16))
    if M == 1:
        ax = [[ax[0]], [ax[1]]]
    for indx, i in enumerate(iter_):
        V = np.linalg.det(transformation_matrix[i])
        ax[0][indx].hist(V[np.absolute(V) < 2].real, bins=100)
        ax[0][indx].set_title(f"Manifold {i} - Rank {ranks[i]}")
        ax[0][indx].set_xlabel('Determinant of Transformation Matrix')
        ax[0][indx].set_ylabel('Frequency')
        ax[0][indx].grid()

        cosine_scores = []
        sine_scores = []

        for loop in holonomy_manifolds[0]:
            loop = loop.real
            sine_angle = np.sqrt(1 - loop ** 2)  # Since sin^2(theta) + cos^2(theta) = 1

            cosine_scores.append(loop)
            cosine_scores.append(-loop)
            sine_scores.append(sine_angle)
            sine_scores.append(-sine_angle)

        ax[1][indx].scatter(cosine_scores, sine_scores, s=10)
        ax[1][indx].set_xlim(-1.5, 1.5)
        ax[1][indx].set_ylim(-1.5, 1.5)
        ax[1][indx].set_xlabel('Cosine Scores')
        ax[1][indx].set_ylabel('Sine Scores')
        ax[1][indx].set_title('Holonomy Group Visualization')
        # Set grids on ax[1][i]
        ax[1][indx].grid()
    plt.tight_layout()
    fig.savefig(f"{save_path}/_{wrt}_holonomy_group.png")
    plt.close(fig)

def main(model, dataset, N, sigma, quantile, tol, save_path, MIN_SIZE=None, wrt="output_wise", plot_hol=False, plot_graph=False, plot_group=False):
    """
    Run the full holonomy pipeline for `model` on `dataset`.

    Steps: record activations, pull the metric / differential / Ricci tensor back
    (pullback_holonomy), split the sampled points into component graphs
    (product_manifold), then compute holonomy per component (holonomy_product_manifold).

    The plot_* flags only decide whether `save_path` is created; no figure is drawn here,
    the caller plots from the returned values.

    Returns:
    - (transformation_matrix, holonomy_manifolds, ranks, save_path, wrt)
    """
    wants_figures = plot_hol or plot_graph or plot_group
    if wants_figures and not os.path.exists(save_path):
        os.makedirs(save_path)

    # Tensor datasets are truncated to their first 1000 rows; numpy datasets are used whole.
    if isinstance(dataset.X, torch.Tensor):
        X = dataset.X.float()[:1000]
    else:
        X = from_numpy(dataset.X).float()

    model.forward(X, save_activations=True)
    activations = model.get_activations()

    g_pullback, dg_pullback, Ricci_pullback, input_surface = pullback_holonomy(
        model, activations, N, wrt=wrt, method="manifold", sigma=sigma, normalised=False
    )
    print(f"Pulled back")

    # Node id -> coordinates of the sampled point, used as graph layout positions.
    pos = dict(enumerate(input_surface))
    subgraphs = product_manifold(
        Ricci_pullback, input_surface, g_pullback,
        quantile=quantile, max_K=5, dataset=dataset, pos=pos, plot_V=False,
        save_path=None, tol=tol, MIN_SIZE=MIN_SIZE, wrt=wrt,
    )

    # The loop start points are computed but not part of the return value.
    holonomy_manifolds, loop_point_manifolds, transformation_matrix, ranks = holonomy_product_manifold(
        subgraphs, g_pullback, dg_pullback, input_surface
    )
    return transformation_matrix, holonomy_manifolds, ranks, save_path, wrt


def holonomy(manifold, metric, form, surface, mappable_dict, rank, num_cycles=10000):
    """
    Estimate holonomy on one component graph by parallel-transporting vectors around random cycles.

    Parameters:
    - manifold: graph passed to find_cycles_random_walk (documented by the caller as a networkx.Graph)
    - metric: numpy array, the metric tensor at each point of this component
    - form: numpy array, the metric differential at each point; reduced together with
      `metric` and then used to build the Christoffel symbols
    - surface: numpy array, the coordinates of each point of this component
    - mappable_dict: dict mapping graph node ids to row positions in `metric` / `form` / `surface`
    - rank: int, number of leading eigen-directions kept by filter_eigs
    - num_cycles: int, number of cycles to sample

    Returns:
    - holonomy_manifold: list with one value per transported vector -- the normalised dot
      product (cosine) between the transported vector and its start vector
    - manifold_loop_points: list of the start node of the loop, one entry per transported vector
    - transformation_manifold: list with one matrix per cycle

    NOTE: `tqdm` is not imported in the cell that defines this function; it must be in
    scope (it is imported in a later cell) before this function is called.
    """
    # Reduce the metric and its differential to the `rank` leading eigen-directions.
    eigenvectors, g_tilde, dg_tilde = filter_eigs(metric, form, rank)
    # Express each point's coordinates in its own reduced basis.
    surface_proj = np.einsum('nij, nj -> ni', eigenvectors, surface)
    # Compute the Christoffel symbols of the reduced metric
    g_inv = np.linalg.inv(g_tilde)
    christoffel_tilde = batch_vectorised_christoffel_symbols(g_inv, dg_tilde)
    holonomy_manifold = []
    manifold_loop_points = []
    transformation_manifold = []
    cycles = find_cycles_random_walk(manifold, num_cycles=num_cycles)
    for cycle in tqdm(cycles):
        # Close the loop by returning to the first vertex.
        full_loop = cycle + [cycle[0]]
        points_path = surface_proj[[mappable_dict[p] for p in full_loop]]
        christoffel_path = christoffel_tilde[[mappable_dict[p] for p in full_loop]]
        # Basis at the loop's start point; indexing with a one-element list keeps a leading
        # axis of length 1, which is squeezed away below unless rank == 1.
        start_vectors_path = eigenvectors[[mappable_dict[full_loop[0]]]]
        if rank != 1:
            start_vectors_path = start_vectors_path.squeeze()

        transformation_matrix = np.zeros_like(start_vectors_path.T)
        # Transport each row of the transposed basis around the loop.
        for indx, start_vector in enumerate(start_vectors_path.T):
            transport_vector = parallel_transport(points_path, start_vector, christoffel_path)
            if rank == 1:
                transformation_matrix[indx] = np.array([transport_vector])
            else:
                transformation_matrix[indx] = transport_vector
            # Cosine of the angle between the transported vector and where it started.
            angle_diff = (np.dot(transport_vector, start_vector) / (np.linalg.norm(transport_vector) * np.linalg.norm(start_vector))).squeeze()
            holonomy_manifold.append(angle_diff)
            manifold_loop_points.append(full_loop[0])
        # Combine the start basis with the transported vectors into one matrix per cycle;
        # the rank-1 case is reshaped to a 1x1 matrix.
        if rank != 1:
            transformation_matrix = start_vectors_path @ transformation_matrix
        else:
            transformation_matrix = (start_vectors_path.squeeze() @ transformation_matrix.squeeze()).reshape(1,1)
        transformation_manifold.append(transformation_matrix)

    return holonomy_manifold, manifold_loop_points, transformation_manifold


In [12]:
from tqdm import tqdm
# Which saved unsupervised model to analyse.
mode="disentangled"
model_name = "vae"
size = "disentangled_5"
epoch = 1900


models_path = f"../../../models/unsupervised/{model_name}/saved_models"
res_q_25, res_med, res_q_75 = [], [], []
# NOTE(security): pickle.load can execute arbitrary code; only open dataset files
# that were produced by this project.
with open(f'{models_path}/{size}/dataset.pkl', 'rb') as f:
    dataset = pkl.load(f)

# Encoder widths; the decoder mirrors them in reverse.
features = [64, 32, 16, 8 ,4]
encoder = Encoder(in_features=32, features=features, out_features=2)
decoder = Decoder(in_features=2, features=list(reversed(features)), out_features=32)
model = VAE(encoder, decoder)

# map_location="cpu": the analysis converts tensors with .detach().numpy(), which needs CPU
# tensors, and this also lets a checkpoint saved on a GPU machine load here.
model.load_state_dict(torch.load(f"{models_path}/{size}/model_{epoch}.pth", map_location="cpu"))

<All keys matched successfully>

In [13]:
# sigma is forwarded by `main` to pullback_holonomy; quantile and tol go to product_manifold.
sigma = 0.1
quantile = 0.9
tol = 1e-5
# Directory for the figures written by the plotting cell below.
save_path = f"figures/{model_name}/{mode}/{size}/{epoch}"
# N**2 points are sampled inside pullback_holonomy.
N = 20
# The notebook-local `main` only uses the plot_* flags to decide whether to create
# `save_path`; the figures themselves are drawn in a later cell from the returned values.
transformation_matrix, holonomy_manifolds, ranks, save_path, wrt = main(model, dataset, N, sigma, quantile, tol, save_path, MIN_SIZE=10, wrt="output_wise", plot_hol=True, plot_graph=False, plot_group=True)


Using 3 nearest neighbors
Jacobian computed in 0.12490177154541016 seconds
Jacobian converted to numpy in 0.001081228256225586 seconds
Pullback metric computed in 0.007870912551879883 seconds
Pullback differential computed in 0.4119439125061035 seconds
Pullback Ricci computed in 0.008700847625732422 seconds
Pulled back
K=2 -> 127
Nodes remaining: {2, 116}
K=3 -> 36
Nodes remaining: {42, 380, 378}
K=4 -> 5
Nodes remaining: {1, 2, 9, 51, 116, 21, 246, 22, 308, 156, 349}
Rank 2


100%|██████████| 10000/10000 [00:33<00:00, 295.03it/s]


Rank 1


100%|██████████| 10000/10000 [00:29<00:00, 341.84it/s]


Rank 2


100%|██████████| 10000/10000 [00:36<00:00, 274.28it/s]


Rank 1


100%|██████████| 10000/10000 [00:29<00:00, 340.85it/s]


Rank 2


100%|██████████| 10000/10000 [00:35<00:00, 285.20it/s]


Rank 2


100%|██████████| 10000/10000 [00:39<00:00, 256.40it/s]


Rank 2


100%|██████████| 10000/10000 [00:36<00:00, 276.58it/s]


In [8]:
import matplotlib.pyplot as plt
# Set font size for plots
plt.rcParams.update({'font.size': 16})
# No plt.tight_layout() here: with no figure open it only creates (and displays) an empty
# figure. The plotting helpers below create, save and close their own figures.
_plot_holonomy_group(transformation_matrix, holonomy_manifolds, ranks, save_path, wrt=wrt)
_plot_hol(holonomy_manifolds, save_path, ranks, wrt=wrt)


  sine_angle = np.sqrt(1 - loop ** 2)  # Since sin^2(theta) + cos^2(theta) = 1
  indices = f_indices.astype(np.intp)
  bins = np.array(bins, float)  # causes problems if float16


<Figure size 640x480 with 0 Axes>