In [None]:
from src.data.data import *
from src.orcml import *
from src.plotting import *
from src.utils.graph_utils import *
from src.isorc import *
from src.utils.embeddings import *
from sklearn.manifold import TSNE
import umap
%load_ext autoreload

# ORCManL hyper-parameters shared by every experiment in this notebook (the same dict
# object is passed to each ORCManL / ISORC constructor below).
# NOTE(review): the semantics of each key are defined in src.orcml, which is not visible
# here — the inline notes are assumptions to confirm there.
exp_params = {
    'mode': 'nbrs',        # presumably k-nearest-neighbour graph construction — confirm
    'n_neighbors': 15,     # presumably k for the kNN graph when mode == 'nbrs' — confirm
    'epsilon': None,       # presumably an epsilon-graph radius, unused in 'nbrs' mode — confirm
    'lda': 0.01,           # NOTE(review): meaning defined inside ORCManL — confirm
    'delta': 0.8           # NOTE(review): meaning defined inside ORCManL — confirm
}

In [None]:
# --- 3D Swiss-roll experiment ---
n_points = 2500
noise = 6.2
noise_thresh = 2.2

# NOTE(review): `dataset_info` is not read by any cell of this notebook; presumably kept
# for bookkeeping/logging.
dataset_info = {
    'name': '3D_swiss_roll',
    'n_points': n_points,
    'noise': noise,
    'noise_thresh': noise_thresh
}

# `swiss_roll` comes from the star import of src.data.data; it returns a dict whose
# entries are unpacked below. Only `color` is used by later cells (for plotting).
return_dict = swiss_roll(n_points=n_points, noise=noise, noise_thresh=noise_thresh, supersample=True, dim=3, hole=False)
swiss_roll_data, color, cluster, swiss_roll_supersample, subsample_indices = return_dict['data'], return_dict['color'], return_dict['cluster'], return_dict['data_supersample'], return_dict['subsample_indices']

In [None]:
%autoreload 2
# Build and fit the ORC-based manifold-learning graph on the Swiss-roll points.
# `reattach` and `nbrhood_size` are ORCManL options (src.orcml); their semantics are
# defined there, not here.
orcmanl = ORCManL(
    exp_params=exp_params,
    verbose=True,
    reattach=False,
    nbrhood_size=1
)
orcmanl.fit(return_dict['data'])

In [None]:
def kernel_distances(k_ij, d_ij, tau):
    """
    Length of a single edge under the logarithmic barrier kernel.

    Algebraically this equals ``d_ij + log(3 / (k_ij + 2)) / tau``:
    an edge with curvature 1 keeps its original length ``d_ij``, and the
    length grows without bound as the curvature approaches -2
    (``np.log(0)`` is ``-inf``, so the result becomes ``+inf``).

    Parameters
    ----------
    k_ij : float or array-like
        Edge curvature.
    d_ij : float or array-like
        Original edge length.
    tau : float
        Barrier temperature; the curvature term is divided by it, so smaller
        values amplify the penalty.

    Returns
    -------
    float or np.ndarray
        The kernel distance, broadcast over the inputs.
    """
    neg_inv_tau = -1 / tau
    barrier = np.log(k_ij + 2)
    shifted_length = (d_ij * tau + np.log(3)) / tau
    return neg_inv_tau * barrier + shifted_length

def compute_kernel_distances(G, tau, rep_factor=10):
    """
    Annotate every edge of ``G`` with its logarithmic-barrier kernel distance.

    For each edge ``(u, v)`` this reads the ``'ricciCurvature'`` and ``'weight'``
    edge attributes, writes ``'kernel_distance'``, and collects the per-edge
    kernel distances and curvatures.

    Side effects
    ------------
    ``G`` is modified in place. Edges whose ``'shortcut'`` attribute equals 1
    have their ``'weight'`` multiplied by ``rep_factor`` *before* the kernel
    distance is computed. Calling this twice on the same graph therefore
    compounds the penalty; pass a copy if the original weights must be kept.

    Parameters
    ----------
    G : networkx.Graph
        Graph whose edges carry ``'ricciCurvature'`` and ``'weight'``.
        A missing ``'shortcut'`` attribute is treated as "not a shortcut".
    tau : float
        Barrier temperature forwarded to ``kernel_distances``.
    rep_factor : float, optional
        Multiplier applied to the weight of shortcut edges (default 10).

    Returns
    -------
    G : networkx.Graph
        The same (mutated) graph object.
    kdists : list of float
        Kernel distances, in ``G.edges()`` order.
    orcs : list of float
        Edge curvatures, in the same order as ``kdists``.
    """
    kdists = []
    # Previously never initialised: the first call raised NameError (or silently
    # appended to a leftover global `orcs` from an earlier cell).
    orcs = []
    for u, v in G.edges():
        edge = G[u][v]
        k_ij = edge['ricciCurvature']
        if edge.get('shortcut', 0) == 1:
            edge['weight'] = edge['weight'] * rep_factor
        edge['kernel_distance'] = kernel_distances(k_ij, edge['weight'], tau)
        kdists.append(edge['kernel_distance'])
        orcs.append(k_ij)
    return G, kdists, orcs

In [None]:
# Work on a copy: compute_kernel_distances rescales shortcut-edge weights in place.
# G_ann must carry 'ricciCurvature', 'weight' and 'shortcut' edge attributes (read by
# compute_kernel_distances) — presumably set by ORCManL.fit; confirm in src.orcml.
G_orc = orcmanl.G_ann.copy()
G_orc, kdists, orcs = compute_kernel_distances(G_orc, 0.1)

In [None]:
import numpy as np
import networkx as nx
from scipy.sparse.csgraph import shortest_path
from multiprocessing import Pool

def process_row(args):
    """
    Compute one row of the dual-weight all-pairs shortest-path matrix.

    The predecessor row ``predecessors[i]`` (as returned by
    ``scipy.sparse.csgraph.shortest_path``) encodes a shortest-path tree rooted
    at source ``i``. This walks that tree and sums ``A_euc_distance`` along each
    tree path.

    Each node's distance is computed once and then reused by every node
    below it in the tree. A row therefore costs O(n) rather than
    O(n * path length), which is what re-walking every path from scratch cost.

    Parameters
    ----------
    args : tuple
        ``(i, predecessors, A_euc_distance, n)`` packed into a single tuple so
        the function can be used with ``Pool.map``:

        - ``i``: source node index.
        - ``predecessors``: (n, n) predecessor matrix; ``-9999`` marks
          "no predecessor" (scipy's sentinel).
        - ``A_euc_distance``: (n, n) matrix of the weights to accumulate.
        - ``n``: number of nodes.

    Returns
    -------
    (int, np.ndarray)
        ``i`` and the length-``n`` row of accumulated distances. The entry
        at ``i`` is 0, and nodes unreachable from ``i`` get ``np.inf``.
    """
    i, predecessors, A_euc_distance, n = args
    pred_row = predecessors[i]
    # NaN marks "not computed yet"; the source is known to be at distance 0.
    row_distances = np.full(n, np.nan)
    row_distances[i] = 0.0
    for j in range(n):
        # Climb towards the source until reaching a node whose distance is known.
        pending = []
        current = j
        while np.isnan(row_distances[current]):
            prev = pred_row[current]
            if prev == -9999:  # Path does not exist
                row_distances[current] = np.inf
                break
            pending.append(current)
            current = prev
        # Unwind: every pending node is its predecessor's distance plus one edge.
        total = row_distances[current]
        for node in reversed(pending):
            total += A_euc_distance[pred_row[node], node]
            row_distances[node] = total
    return i, row_distances

def compute_apsp_with_dual_weights_multiprocessing(G, weight_B, weight_A, processes=None):
    """
    Compute all-pairs shortest paths under one weight, measured with another.

    - Paths are chosen to be shortest with respect to ``weight_B``.
    - Their lengths are then accumulated with respect to ``weight_A``.

    ``apsp_weight_A[i, j]`` is therefore the ``weight_A`` length of the
    ``weight_B``-shortest path. It is not necessarily the ``weight_A``-shortest
    path. Path reconstruction is parallelised row-by-row with multiprocessing
    (see ``process_row``).

    Parameters
    ----------
    G : networkx.Graph
        Graph whose edges carry both weight attributes.
    weight_B : str
        Edge attribute that decides which path is shortest.
    weight_A : str
        Edge attribute summed along the chosen path.
    processes : int or None, optional
        Worker count passed to ``multiprocessing.Pool``. ``None`` (the default)
        lets the pool choose, exactly as before. ``1`` computes the rows in the
        current process without creating a pool. Use ``1`` where worker
        functions defined in a notebook cannot be pickled to spawned children,
        for example under the 'spawn' start method on macOS/Windows.

    Returns
    -------
    apsp_weight_A, apsp_weight_B : np.ndarray
        (n, n) matrices whose rows/columns follow ``G.nodes()`` order (the
        ``nx.to_numpy_array`` default). Unreachable pairs are ``np.inf``.

    Notes
    -----
    ``nx.to_numpy_array`` fills non-edges with 0, and scipy treats zero entries
    of a dense matrix as missing edges. An edge whose ``weight_B`` is exactly 0
    is therefore ignored.
    """
    # Step 1: dense adjacency matrices for both weights (G.nodes() order).
    A_kernel_distance = nx.to_numpy_array(G, weight=weight_B)
    A_euc_distance = nx.to_numpy_array(G, weight=weight_A)

    # Step 2: shortest paths and predecessor matrix with respect to weight_B.
    apsp_weight_B, predecessors = shortest_path(
        A_kernel_distance, directed=False, unweighted=False, return_predecessors=True
    )

    # Step 3: re-accumulate every row's tree paths using weight_A.
    n = A_kernel_distance.shape[0]
    args = [(i, predecessors, A_euc_distance, n) for i in range(n)]
    if processes == 1:
        # In-process path: no pickling of the worker function or the matrices.
        results = map(process_row, args)
    else:
        with Pool(processes=processes) as pool:
            results = pool.map(process_row, args)

    # Step 4: assemble the weight_A APSP matrix.
    apsp_weight_A = np.zeros((n, n))
    for i, row_distances in results:
        apsp_weight_A[i] = row_distances

    return apsp_weight_A, apsp_weight_B

In [None]:
# Paths are chosen by 'kernel_distance' but their lengths are summed with 'weight'
# (which compute_kernel_distances has already multiplied by rep_factor on shortcut edges).
apsp_weight_A, apsp_weight_B = compute_apsp_with_dual_weights_multiprocessing(G_orc, 'kernel_distance', 'weight')

In [None]:
%autoreload 2
from src.isorc import *

# ISORC (src.isorc) is built from the fitted `orcmanl`; fit_kPCA() returns the
# embedding plotted in 2D below.
# Colours are reindexed by isorc.G.nodes(), presumably because isorc.G need not contain
# every input point — confirm in src.isorc.
isorc = ISORC(
    orcmanl=orcmanl,
    exp_params=exp_params,
    verbose=True,
)
emb = isorc.fit_kPCA()
plot_data_2D(emb, color=color[isorc.G.nodes()], title=None)

In [None]:
# Draw orcmanl.G_pruned (presumably the curvature-pruned graph — see src.orcml) over the
# 3D Swiss-roll coordinates.
plot_graph_3D(return_dict['data'], orcmanl.G_pruned, title=None)

In [None]:
from sklearn.manifold import Isomap

# NOTE(review): with metric='precomputed', sklearn's Isomap treats the input as a distance
# matrix, builds its own n_neighbors-NN graph on it (default n_neighbors=5) and re-runs
# shortest paths before kernel PCA. It does not apply MDS to apsp_weight_A directly.
# It will likely reject np.inf entries (disconnected graph, reattach=False) — confirm.
isomap = Isomap(n_components=2, metric='precomputed')
emb = isomap.fit_transform(apsp_weight_A)

# Rows of apsp_weight_A follow G_orc.nodes() order, so colours are reindexed the same way.
plot_data_2D(emb, color=color[G_orc.nodes()], title=None)

In [None]:
# UMAP baseline on the raw points (all of them, so `color` is used without reindexing).
# NOTE(review): no random_state is set, so this embedding differs between runs.
umap_emb = umap.UMAP(n_components=2, n_neighbors=15, min_dist=0.1).fit_transform(return_dict['data'])
plot_data_2D(umap_emb, color=color, title=None)


In [None]:
%autoreload 2
# --- Quadratics experiment (noise = 0.15) ---
# Overwrites n_points / noise / noise_thresh / return_dict from the Swiss-roll section;
# the following cells (until the next dataset is generated) use the quadratics data.
n_points = 1200
noise = 0.15
noise_thresh = 0.4

return_dict = quadratics(n_points=n_points, noise=noise, noise_thresh=noise_thresh, supersample=True)

In [None]:
# Refit ORCManL on the quadratics data, this time with reattach=True (an ORCManL option;
# semantics defined in src.orcml). The shared exp_params dict is reused.
orcmanl = ORCManL(
    exp_params=exp_params,
    verbose=True,
    reattach=True
)
orcmanl.fit(return_dict['data'])

In [None]:
# Copy first: compute_kernel_distances rescales shortcut-edge weights in place.
# tau = 1e-2; rep_factor=10 is spelled out but equals the default.
G_orc = orcmanl.G_ann.copy()
G_orc, kdists, orcs = compute_kernel_distances(G_orc, 1e-2, rep_factor=10)

In [None]:
# A: 'weight' length of the kernel-distance-shortest paths; B: their kernel-distance length.
apsp_weight_A, apsp_weight_B = compute_apsp_with_dual_weights_multiprocessing(G_orc, 'kernel_distance', 'weight')

In [None]:
# Compare Isomap embeddings of the two APSP matrices. fit_transform refits the same
# estimator, so emb_A is not affected by the second call.
# NOTE(review): Isomap(metric='precomputed') rebuilds its own 5-NN graph on these distances
# before embedding; it is not MDS on the raw matrices.
isomap = Isomap(n_components=2, metric='precomputed')
emb_A = isomap.fit_transform(apsp_weight_A)
emb_B = isomap.fit_transform(apsp_weight_B)

# Colour by cluster, reindexed to the APSP row order (G_orc.nodes()).
plot_data_2D(emb_A, color=return_dict['cluster'][G_orc.nodes()], title=None)
plot_data_2D(emb_B, color=return_dict['cluster'][G_orc.nodes()], title=None)

In [None]:
# UMAP baseline on all raw quadratics points (no node reindexing needed).
# NOTE(review): no random_state is set, so this embedding differs between runs.
umap_emb = umap.UMAP(n_components=2, n_neighbors=15, min_dist=0.1).fit_transform(return_dict['data'])
plot_data_2D(umap_emb, color=return_dict['cluster'], title=None)

In [None]:
# --- Quadratics experiment, more noise (noise = 0.20) ---
# Regenerates return_dict; the following cells use this noisier dataset.
n_points = 1200
noise = 0.20
noise_thresh = 0.4

return_dict = quadratics(n_points=n_points, noise=noise, noise_thresh=noise_thresh, supersample=True)

In [None]:
# Refit ORCManL (reattach=True) on the noisier quadratics data with the shared exp_params.
orcmanl = ORCManL(
    exp_params=exp_params,
    verbose=True,
    reattach=True
)
orcmanl.fit(return_dict['data'])

In [None]:
G_orc = orcmanl.G_ann.copy()
# tau = 1e-10: the curvature term log(3 / (k + 2)) / tau is scaled by 1e10, so it dominates
# the edge length for any edge with curvature below 1. Shortcut edges use the default
# rep_factor=10.
G_orc, kdists, orcs = compute_kernel_distances(G_orc, 1e-10)

In [None]:
# A: 'weight' length of the kernel-distance-shortest paths; B: their kernel-distance length.
apsp_weight_A, apsp_weight_B = compute_apsp_with_dual_weights_multiprocessing(G_orc, 'kernel_distance', 'weight')

In [None]:
%autoreload 2
# Graph over the quadratics data. Nodes are coloured by cluster (reindexed to
# G_orc.nodes() order) and edges by kernel distance; kdists is in G_orc.edges() order,
# as produced by compute_kernel_distances.
plot_graph_2D(return_dict['data'], G_orc, node_color=return_dict['cluster'][G_orc.nodes()], edge_color=kdists, title=None)

In [None]:

# Same comparison as for the noise=0.15 data: Isomap on the 'weight'-measured (A) and
# kernel-distance-measured (B) APSP matrices, coloured by cluster in G_orc.nodes() order.
# NOTE(review): Isomap(metric='precomputed') rebuilds its own 5-NN graph on these distances.
isomap = Isomap(n_components=2, metric='precomputed')
emb_A = isomap.fit_transform(apsp_weight_A)
emb_B = isomap.fit_transform(apsp_weight_B)

plot_data_2D(emb_A, color=return_dict['cluster'][G_orc.nodes()], title=None)
plot_data_2D(emb_B, color=return_dict['cluster'][G_orc.nodes()], title=None)

In [None]:
# UMAP baseline on all raw points of the noisier quadratics data.
# NOTE(review): no random_state is set, so this embedding differs between runs.
umap_emb = umap.UMAP(n_components=2, n_neighbors=15, min_dist=0.1).fit_transform(return_dict['data'])
plot_data_2D(umap_emb, color=return_dict['cluster'], title=None)

In [None]:
import numpy as np
import torchvision

def _load_mnist_like(dataset_cls, root, n_samples, label):
    """
    Shared loader for torchvision MNIST-style training sets (MNIST, KMNIST).

    Reads the raw image/label tensors directly instead of pushing all
    60,000 training images through the per-item PIL + ``ToTensor`` pipeline
    just to subsample them. The pixel values are identical to ``ToTensor()``
    followed by flattening: float32 in [0, 1], then cast to float64.

    Side effect: seeds the global NumPy RNG with 0 (kept for reproducibility
    with earlier runs of this notebook).
    """
    dataset = dataset_cls(root, train=True, download=True)
    images = dataset.data  # uint8 image tensor, one 28x28 image per row
    # Same arithmetic as ToTensor (convert to float32, divide by 255), then flatten to 784.
    data = images.reshape(images.shape[0], -1).float().div(255).numpy().astype(np.float64)
    labels = dataset.targets.numpy().astype(np.float64)
    if label is not None:
        indices = np.where(labels == label)[0]
        np.random.seed(0)
        np.random.shuffle(indices)
        # May return fewer than n_samples if that class has fewer examples.
        indices = indices[:n_samples]
    else:
        np.random.seed(0)
        # Raises ValueError if n_samples exceeds the dataset size (replace=False).
        indices = np.random.choice(data.shape[0], n_samples, replace=False)
    return data[indices], labels[indices]


def get_mnist_data(n_samples, label=None, root='../data'):
    """
    Get ``n_samples`` MNIST training points, optionally restricted to one label.

    Parameters
    ----------
    n_samples : int
        Number of data points to get.
    label : int or None
        Label of the data points to get. If None, sample uniformly at random
        (without replacement) from all labels.
    root : str, optional
        torchvision dataset root; downloaded there if missing (default ``'../data'``).

    Returns
    -------
    mnist_data : np.ndarray
        ``n_samples x 784`` float64 array of pixel values in [0, 1].
    mnist_labels : np.ndarray
        ``n_samples`` float64 array of MNIST labels.
    """
    return _load_mnist_like(torchvision.datasets.MNIST, root, n_samples, label)


# KMNIST files are stored under <root>/KMNIST/ (e.g. data/KMNIST/...). Only the training
# split (train=True) is loaded here.
# NOTE(review): the default root differs from get_mnist_data ('data' vs '../data'); kept
# as in the original — confirm which directory is intended.

def get_kmnist_data(n_samples, label=None, root='data'):
    """
    Get ``n_samples`` KMNIST training points, optionally restricted to one label.

    Parameters
    ----------
    n_samples : int
        Number of data points to get.
    label : int or None
        Label of the data points to get. If None, sample uniformly at random
        (without replacement) from all labels.
    root : str, optional
        torchvision dataset root; downloaded there if missing (default ``'data'``).

    Returns
    -------
    kmnist_data : np.ndarray
        ``n_samples x 784`` float64 array of pixel values in [0, 1]; no
        further rescaling is applied.
    kmnist_labels : np.ndarray
        ``n_samples`` float64 array of KMNIST labels.
    """
    return _load_mnist_like(torchvision.datasets.KMNIST, root, n_samples, label)

In [None]:
# Random (label=None) sample of 1500 MNIST training images. get_mnist_data seeds the
# global NumPy RNG with 0, so the sample is reproducible across runs.
n_samples = 1500
mnist_data, mnist_labels = get_mnist_data(n_samples, label=None)

In [None]:
# Fit ORCManL on the flattened MNIST pixel vectors (reattach=False, shared exp_params).
orcmanl = ORCManL(
    exp_params=exp_params,
    verbose=True,
    reattach=False
)
orcmanl.fit(mnist_data)

In [None]:
# Copy first (shortcut weights are rescaled in place); tau = 1e-3, default rep_factor=10.
G_orc = orcmanl.G_ann.copy()
G_orc, kdists, orcs = compute_kernel_distances(G_orc, 1e-3)

In [None]:

# A: 'weight' length of the kernel-distance-shortest paths; B: their kernel-distance length.
apsp_weight_A, apsp_weight_B = compute_apsp_with_dual_weights_multiprocessing(G_orc, 'kernel_distance', 'weight')


In [None]:
# Isomap on both APSP matrices, coloured by digit label reindexed to G_orc.nodes() order.
# NOTE(review): with reattach=False the graph may be disconnected; np.inf entries in the
# APSP matrices would likely make Isomap(metric='precomputed') fail — confirm.
isomap = Isomap(n_components=2, metric='precomputed')
emb_A = isomap.fit_transform(apsp_weight_A)
emb_B = isomap.fit_transform(apsp_weight_B)

plot_data_2D(emb_A, color=mnist_labels[G_orc.nodes()], title=None)
plot_data_2D(emb_B, color=mnist_labels[G_orc.nodes()], title=None)

In [None]:
# UMAP baseline on all sampled MNIST points (labels used without reindexing).
# NOTE(review): no random_state is set, so this embedding differs between runs.
umap_emb = umap.UMAP(n_components=2, n_neighbors=15, min_dist=0.1).fit_transform(mnist_data)
plot_data_2D(umap_emb, color=mnist_labels, title=None)