In [None]:
import os

from src.data.data import *
from src.orcml import *
from src.plotting import *
from src.utils.graph_utils import *
from src.isorc import *
from src.utils.embeddings import *
from sklearn.manifold import TSNE
import umap
import numpy as np
import torch
import torchvision

%load_ext autoreload

# Hyper-parameters shared by the ORCManL and ISORC constructions below.
# NOTE(review): 'mode'='nbrs' with 'n_neighbors' presumably selects a kNN graph
# (vs. an epsilon-ball when 'epsilon' is set) — confirm against ORCManL.
exp_params = dict(
    mode='nbrs',
    n_neighbors=25,
    epsilon=None,
    lda=1e-5,
    delta=0.8,
)


def orc_fa(G, a, b, orc_min=-2.0, orc_max=1.0):
    """Lay out ``G`` with ForceAtlas2, weighting edges by rescaled Ricci curvature.

    Each entry of the ``'ricciCurvature'`` adjacency matrix is mapped linearly
    from ``[orc_min, orc_max]`` onto ``[a, b]``. The result is symmetrised, zeroed
    wherever the ``'weight'`` matrix is not positive, and stored on a copy of the
    graph as the ``'orc_weight'`` edge attribute. That attribute then drives a
    spectral-initialised ForceAtlas2 layout.

    Parameters
    ----------
    G : networkx.Graph
        Graph whose edges carry ``'ricciCurvature'`` and ``'weight'`` attributes.
    a, b : float
        Target range of the linear map; requires ``b >= a``.
    orc_min, orc_max : float, optional
        Source range of the linear map. The defaults ``-2`` and ``1`` are the
        values previously hard-coded. Values outside this range are
        extrapolated, not clipped.

    Returns
    -------
    pos : dict
        Node -> position mapping returned by ``nx.forceatlas2_layout``.
    G_normalized : networkx.Graph
        Copy of ``G`` with ``'orc_weight'`` set on every edge.
    """
    assert b >= a, "b must be greater than or equal to a"
    assert orc_max > orc_min, "orc_max must be greater than orc_min"

    # Fix the node order once so matrix rows/cols and the lookup below agree.
    nodelist = list(G.nodes())
    A_orc = nx.to_numpy_array(G, nodelist=nodelist, weight='ricciCurvature')
    A_edge = nx.to_numpy_array(G, nodelist=nodelist, weight='weight')
    edge_mask = np.where(A_edge > 0, 1, 0)

    # One global linear map [orc_min, orc_max] -> [a, b] (not per row), then
    # symmetrise and keep only entries that correspond to edges.
    scale = (b - a) / (orc_max - orc_min)
    A_orc_normalized = scale * (A_orc - orc_min) + a
    A_orc_normalized = (A_orc_normalized + A_orc_normalized.T) / 2
    A_orc_normalized *= edge_mask

    G_normalized = G.copy()
    # dict lookup is O(1); the previous list.index per edge made this loop O(N*E)
    node_to_idx = {node: idx for idx, node in enumerate(nodelist)}
    for u, v in G_normalized.edges:
        idx_u = node_to_idx[u]
        idx_v = node_to_idx[v]
        assert A_edge[idx_u, idx_v] != 0 or u == v, "index misalignment"
        G_normalized[u][v]['orc_weight'] = A_orc_normalized[idx_u, idx_v]

    def fa(G, weight):
        # Spectral layout (on 'weight') seeds ForceAtlas2, which uses `weight`.
        spectral_init = nx.spectral_layout(G, weight='weight')
        print('Finished spectral initialization')
        # force directed layout with unit mass for every node
        node_mass = {node: 1 for node in G.nodes()}
        pos = nx.forceatlas2_layout(G, pos=spectral_init, seed=42, weight=weight, node_mass=node_mass)
        # message kept as-is; the layout is ForceAtlas2, not a spring layout
        print('Finished spring layout')
        return pos

    pos = fa(G_normalized, weight='orc_weight')
    return pos, G_normalized

In [None]:
%autoreload 2
# quadratics
n_points = 2000
noise = 0.175
noise_thresh = None

return_dict = quadratics(n_points=n_points, noise=noise, noise_thresh=noise_thresh, supersample=True)

In [None]:
# Fit ORCManL on the quadratics point cloud sampled above.
orcmanl = ORCManL(exp_params=exp_params, verbose=True, reattach=True)
orcmanl.fit(return_dict['data'])

In [6]:
def isorc_fa(orcmanl, a=3):
    """Lay out the ISORC graph of a fitted ORCManL model with ForceAtlas2.

    The function works in four steps:

    1. Build an ``ISORC`` model from ``orcmanl``, using the notebook-level
       ``exp_params``.
    2. Convert its ``apsp_energy`` matrix to affinities
       ``a * (exp(-E**2 / E.max()) - 0.5)``.
    3. Store each edge's affinity as ``'orc_weight'`` on a copy of ``isorc.G``.
    4. Run a spectral-initialised ForceAtlas2 layout weighted by that attribute.

    NOTE(review): a later cell redefines ``isorc_fa(orcmanl, a, sigma, plot=False)``
    with a max-normalised energy; once that cell runs, it shadows this definition.

    Parameters
    ----------
    orcmanl : ORCManL
        Fitted model passed straight to ``ISORC``.
    a : float, optional
        Affinity scale. Defaults to 3, the value previously hard-coded.

    Returns
    -------
    pos : dict
        Node -> position mapping from ``nx.forceatlas2_layout``.
    G : networkx.Graph
        Copy of ``isorc.G`` with ``'orc_weight'`` set on every edge.
    """
    isorc = ISORC(
        orcmanl=orcmanl,
        exp_params=exp_params,
        verbose=True,
        uniform=False,
        temperature=1
    )
    apsp_energy = isorc.apsp_energy.copy()

    # convert to affinities
    def energy_to_affinity(energy, a):
        # energy**2 is divided by max(energy), not max(energy)**2, so the
        # exponent is not scale-invariant (unlike the later redefinition).
        affinity = a * (np.exp(-energy**2/energy.max()) - 0.5)
        return affinity
    apsp_affinities = energy_to_affinity(apsp_energy, a=a)

    # now construct the graph; assumes the energy matrices are indexed in
    # isorc.G.nodes() order (the assert below spot-checks this)
    G = isorc.G.copy()
    # dict lookup is O(1); list.index per edge made this loop O(N*E)
    node_to_idx = {node: idx for idx, node in enumerate(G.nodes())}
    for u, v in G.edges:
        idx_u = node_to_idx[u]
        idx_v = node_to_idx[v]
        assert isorc.A_energy[idx_u, idx_v] != 0 or u == v, "index misalignment"
        G[u][v]['orc_weight'] = apsp_affinities[idx_u, idx_v]

    def fa(G, weight):
        # Spectral layout (on 'weight') seeds ForceAtlas2, which uses `weight`.
        spectral_init = nx.spectral_layout(G, weight='weight')
        print('Finished spectral initialization')
        # force directed layout with unit mass for every node
        node_mass = {node: 1 for node in G.nodes()}
        pos = nx.forceatlas2_layout(G, pos=spectral_init, seed=42, weight=weight, node_mass=node_mass)
        print('Finished spring layout')
        return pos

    pos = fa(G, weight='orc_weight')
    return pos, G

In [None]:
# Lay out the quadratics ISORC graph and draw it, passing cluster ids as node
# colours and ORCManL curvatures as edge colours.
# NOTE(review): edge_color comes from orcmanl.orcs while G is a copy of isorc.G;
# this assumes both enumerate edges in the same order — confirm.
pos, G = isorc_fa(orcmanl)
plot_graph_2D(
    pos,
    G,
    node_color=return_dict['cluster'][G.nodes()],
    edge_color=orcmanl.orcs,
    title=None,
    edge_width=0.05,
    node_size=0.2,
)

In [8]:

def get_mnist_data(n_samples, label=None, root=None, seed=0):
    """
    Get n_samples MNIST training points, optionally restricted to one digit.

    Parameters
    ----------
    n_samples: int
        Number of data points to get. When ``label`` is set and that digit has
        fewer than ``n_samples`` examples, all of them are returned.
    label: int or None
        Label of the data points to get. If None, sample ``n_samples`` indices
        without replacement from the whole training set.
    root: str or None
        Directory that holds, or receives, the MNIST download. If None, the
        ``MNIST_ROOT`` environment variable is used, falling back to the
        previously hard-coded path.
    seed: int
        Value passed to ``np.random.seed`` before sampling. This reseeds numpy's
        *global* RNG, as the original did, so later global-RNG draws in the
        notebook stay reproducible.

    Returns
    -------
    mnist_data: np.ndarray
        n_samples x 784 float64 array of flattened images scaled to [0, 1]
    mnist_labels: np.ndarray
        n_samples float64 array of MNIST labels
    """
    if root is None:
        root = os.environ.get('MNIST_ROOT', '/home/tristan/Research/Fa24/isorc/data')

    # Read the raw uint8 image/label tensors directly. Previously all 60k images
    # were pushed through a PIL -> ToTensor transform, twice: once to collect the
    # data and once more to collect the labels.
    mnist = torchvision.datasets.MNIST(root, train=True, download=True)
    all_labels = mnist.targets.numpy().astype(np.float64)

    np.random.seed(seed)
    if label is not None:
        indices = np.where(all_labels == label)[0]
        np.random.shuffle(indices)
        indices = indices[:n_samples]
    else:
        indices = np.random.choice(all_labels.shape[0], n_samples, replace=False)

    # Same arithmetic ToTensor applies to uint8 images (cast to the default float
    # dtype, divide by 255). The images are flattened row-major like the original
    # ``x.view(-1)``. flatten(start_dim=1) also handles an empty selection.
    images = mnist.data[torch.as_tensor(indices, dtype=torch.long)]
    mnist_data = (
        images.to(dtype=torch.get_default_dtype())
        .div(255)
        .flatten(start_dim=1)
        .numpy()
        .astype(np.float64)
    )
    mnist_labels = all_labels[indices]
    return mnist_data, mnist_labels

In [None]:
# mnist_data, mnist_labels = get_mnist_data(n_samples=2000, label=None)
# orcmanl = ORCManL(
#     exp_params=exp_params,
#     verbose=True,
#     reattach=True,
# )
# orcmanl.fit(mnist_data)
%autoreload 2
# isorc = ISORC(
#     orcmanl=orcmanl,
#     exp_params=exp_params,
#     verbose=True,
#     uniform=False,
#     temperature=1
# )
# plot orc vs energy
def plot_energy():
    from src.utils.graph_utils import energy
    orcs = np.linspace(-1.99, 1, 100)
    energies = [energy(orc) for orc in orcs]
    plt.figure()
    plt.scatter(orcs, energies)
    plt.xlabel('ORC')
    plt.ylabel('Energy')
plot_energy()

apsp_energy = isorc.apsp_energy.copy()
# convert to affinities
def energy_to_affinity(energy, a=4, sigma=7.5e-2):
    norm_energy = energy/(energy.max())
    affinity = a * (np.exp(-norm_energy**2/sigma) - 0.5)
    return affinity, norm_energy

def plot_affinity():
    energies = np.linspace(0.01, 1, 100)
    affinities, _ = energy_to_affinity(energies)
    plt.figure()
    plt.scatter(energies, affinities)
    plt.xlabel('Normalized Energy')
    plt.ylabel('Affinity')
plot_affinity()

apsp_affinities, apsp_energy_norm = energy_to_affinity(apsp_energy)

# now construct the graph
G = isorc.G.copy()
node_indices = list(G.nodes())
affinities = []
norm_energies = []
for i, (u, v) in enumerate(G.edges):
    idx_u = node_indices.index(u)
    idx_v = node_indices.index(v)
    assert isorc.A_energy[idx_u, idx_v] != 0 or u == v, "index misalignment"
    G[u][v]['orc_weight'] = apsp_affinities[idx_u, idx_v]
    affinities.append(apsp_affinities[idx_u, idx_v])
    norm_energies.append(apsp_energy_norm[idx_u, idx_v])

plt.figure()
plt.hist(norm_energies, bins=100)
plt.xlim([0, 1])
plt.xlabel('Normalized Energy')
plt.ylabel('Count')
plt.title('Energy distribution')

plt.figure()
plt.hist(affinities, bins=100)
plt.xlabel('Affinity')
plt.ylabel('Count')
plt.title('Affinity distribution')
pos = fa(G, weight='orc_weight')
plot_graph_2D(pos, G, node_color=mnist_labels[G.nodes()], edge_color=affinities, title=None, edge_width=0.03, node_size=0.2)

In [None]:
def isorc_fa(orcmanl, a, sigma, plot=False):
    """Lay out the ISORC graph of a fitted ORCManL model with ForceAtlas2.

    The function works in four steps:

    1. Build an ``ISORC`` model from ``orcmanl``, using the notebook-level
       ``exp_params``.
    2. Normalise its ``apsp_energy`` matrix by the maximum energy and map the
       result to affinities ``a * (exp(-e**2 / sigma) - 0.5)``.
    3. Store each edge's affinity as ``'orc_weight'`` on a copy of ``isorc.G``.
    4. Run a spectral-initialised ForceAtlas2 layout weighted by that attribute.

    Parameters
    ----------
    orcmanl : ORCManL
        Fitted model passed straight to ``ISORC``.
    a : float
        Affinity scale.
    sigma : float
        Divisor of the squared normalised energy in the exponent.
    plot : bool, optional
        If True, also draw four diagnostic figures:
        the ORC -> energy curve, the energy -> affinity curve, and histograms
        of the per-edge normalised energies and affinities.

    Returns
    -------
    pos : dict
        Node -> position mapping from ``nx.forceatlas2_layout``.
    G : networkx.Graph
        Copy of ``isorc.G`` with ``'orc_weight'`` set on every edge.
    norm_energies : list
        Normalised energy of each edge, in ``G.edges`` order.
    affinities : list
        Affinity (the stored ``'orc_weight'``) of each edge, in ``G.edges`` order.
    """
    isorc = ISORC(
        orcmanl=orcmanl,
        exp_params=exp_params,
        verbose=True,
        uniform=False,
        temperature=1
    )
    apsp_energy = isorc.apsp_energy.copy()

    # convert to affinities
    def energy_to_affinity(energy, a=a, sigma=sigma):
        norm_energy = energy/(energy.max())
        affinity = a * (np.exp(-norm_energy**2/sigma) - 0.5)
        return affinity, norm_energy
    apsp_affinities, apsp_energy_norm = energy_to_affinity(apsp_energy)

    # now construct the graph; assumes the energy matrices are indexed in
    # isorc.G.nodes() order (the assert below spot-checks this)
    G = isorc.G.copy()
    # dict lookup is O(1); list.index per edge made this loop O(N*E)
    node_to_idx = {node: idx for idx, node in enumerate(G.nodes())}
    affinities = []
    norm_energies = []
    for u, v in G.edges:
        idx_u = node_to_idx[u]
        idx_v = node_to_idx[v]
        assert isorc.A_energy[idx_u, idx_v] != 0 or u == v, "index misalignment"
        G[u][v]['orc_weight'] = apsp_affinities[idx_u, idx_v]
        affinities.append(apsp_affinities[idx_u, idx_v])
        norm_energies.append(apsp_energy_norm[idx_u, idx_v])

    def fa(G, weight):
        # Spectral layout (on 'weight') seeds ForceAtlas2, which uses `weight`.
        spectral_init = nx.spectral_layout(G, weight='weight')
        print('Finished spectral initialization')
        # force directed layout with unit mass for every node
        node_mass = {node: 1 for node in G.nodes()}
        pos = nx.forceatlas2_layout(G, pos=spectral_init, seed=42, weight=weight, node_mass=node_mass)
        print('Finished spring layout')
        return pos

    if plot:
        def plot_energy():
            # ORC -> energy curve over 100 ORC values in [-1.99, 1]
            from src.utils.graph_utils import energy
            orcs = np.linspace(-1.99, 1, 100)
            energies = [energy(orc) for orc in orcs]
            plt.figure()
            plt.scatter(orcs, energies)
            plt.xlabel('ORC')
            plt.ylabel('Energy')
        plot_energy()

        def plot_affinity():
            # energy -> affinity curve over 100 energies in [0.01, 1]
            energies = np.linspace(0.01, 1, 100)
            affinities, _ = energy_to_affinity(energies)
            plt.figure()
            plt.scatter(energies, affinities)
            plt.xlabel('Normalized Energy')
            plt.ylabel('Affinity')
        plot_affinity()

        plt.figure()
        plt.hist(norm_energies, bins=100)
        plt.xlim([0, 1])
        plt.xlabel('Normalized Energy')
        plt.ylabel('Count')
        plt.title('Energy distribution')

        plt.figure()
        plt.hist(affinities, bins=100)
        plt.xlabel('Affinity')
        plt.ylabel('Count')
        plt.title('Affinity distribution')

    pos = fa(G, weight='orc_weight')
    return pos, G, norm_energies, affinities

# Sample 2000 random MNIST training digits and fit ORCManL on them.
mnist_data, mnist_labels = get_mnist_data(n_samples=2000, label=None)

orcmanl = ORCManL(exp_params=exp_params, verbose=True, reattach=True)
orcmanl.fit(mnist_data)

# Lay out the ISORC graph (with diagnostics), then draw it. Digit labels are
# passed as node colours and per-edge affinities as edge colours.
pos, G, norm_energies, affinities = isorc_fa(
    orcmanl, a=4, sigma=1e-1, plot=True
)
plot_graph_2D(
    pos,
    G,
    node_color=mnist_labels[G.nodes()],
    edge_color=affinities,
    title=None,
    edge_width=0.03,
    node_size=0.2,
)