In [None]:
from src.data.data import *
from src.orcml import *
from src.plotting import *
from src.utils.graph_utils import *
from src.isorc import *
from src.utils.embeddings import *
from sklearn.manifold import TSNE
import umap
import numpy as np
import torchvision

%load_ext autoreload

# Hyperparameters shared by every ORCManL(...) and ISORC(...) call in this notebook.
exp_params = dict(
    mode='nbrs',
    n_neighbors=25,
    epsilon=None,
    lda=1e-5,
    delta=0.8,
)

def orc_fa(G, a, b, scaling_ratio=2, seed=42):
    """Lay out ``G`` with ForceAtlas2, using rescaled Ollivier-Ricci curvature as edge weights.

    Each edge's ``'ricciCurvature'`` is mapped affinely from the fixed range [-2, 1]
    onto [a, b] and stored on a copy of the graph as ``'orc_weight'``. A spectral
    layout (weighted by ``'weight'``) is used as the ForceAtlas2 starting position.

    Parameters
    ----------
    G : networkx.Graph
        Graph whose edges carry ``'ricciCurvature'`` and ``'weight'`` attributes.
    a, b : float
        Target interval for the rescaled curvature; ``b >= a`` is required.
        With ``a == b`` every edge gets the same weight ``a``.
    scaling_ratio : float, default 2
        Passed through to ``nx.forceatlas2_layout``.
    seed : int, default 42
        Random seed passed through to ``nx.forceatlas2_layout``.

    Returns
    -------
    pos : dict
        Mapping node -> 2D position.
    G_normalized : networkx.Graph
        Copy of ``G`` with the extra ``'orc_weight'`` edge attribute.

    Raises
    ------
    AssertionError
        If ``b < a`` or a graph edge has no nonzero ``'weight'`` entry in the adjacency matrix.
    """
    assert b >= a, "b must be greater than or equal to a"

    A_orc = nx.to_numpy_array(G, weight='ricciCurvature')
    A_edge = nx.to_numpy_array(G, weight='weight')
    edge_mask = np.where(A_edge > 0, 1, 0)

    # Ollivier-Ricci curvature on graphs lies in [-2, 1]; that fixed range (not a
    # per-row min/max) is mapped linearly onto [a, b].
    orc_min, orc_max = -2, 1
    A_orc_normalized = (b - a) / (orc_max - orc_min) * (A_orc - orc_min) + a
    A_orc_normalized = (A_orc_normalized + A_orc_normalized.T) / 2
    # Non-edges have A_orc == 0, which maps to a nonzero value; zero them out again.
    A_orc_normalized *= edge_mask

    G_normalized = G.copy()
    # Matrix rows/columns follow G.nodes() order; a dict gives O(1) lookups
    # instead of an O(|V|) list.index call per edge.
    node_to_idx = {node: idx for idx, node in enumerate(G.nodes())}
    for u, v in G_normalized.edges:
        idx_u, idx_v = node_to_idx[u], node_to_idx[v]
        assert A_edge[idx_u, idx_v] != 0 or u == v, "index misalignment"
        G_normalized[u][v]['orc_weight'] = A_orc_normalized[idx_u, idx_v]

    spectral_init = nx.spectral_layout(G_normalized, weight='weight')
    print('Finished spectral initialization')
    # Force-directed refinement; unit node mass so only the ORC weights shape attraction.
    node_mass = {node: 1 for node in G_normalized.nodes()}
    pos = nx.forceatlas2_layout(
        G_normalized,
        pos=spectral_init,
        seed=seed,
        weight='orc_weight',
        scaling_ratio=scaling_ratio,
        node_mass=node_mass,
    )
    print('Finished ForceAtlas2 layout')
    return pos, G_normalized

def _sample_torchvision_digits(dataset_cls, root, n_samples, label=None):
    """Load the training split of an MNIST-style torchvision dataset and subsample it.

    Parameters
    ----------
    dataset_cls : type
        A torchvision dataset class exposing ``.data`` (uint8 images) and ``.targets``,
        e.g. ``torchvision.datasets.MNIST`` or ``torchvision.datasets.KMNIST``.
    root : str
        Download / cache directory passed to ``dataset_cls``.
    n_samples : int
        Number of points to return (fewer if ``label`` has fewer examples).
    label : int or None
        If given, sample only points with this label; otherwise sample uniformly
        without replacement from the whole split.

    Returns
    -------
    data : np.ndarray
        ``(n, 784)`` float64 array of flattened images scaled to [0, 1].
    labels : np.ndarray
        ``(n,)`` float64 array of labels.
    """
    import warnings

    dataset = dataset_cls(root, train=True, download=True)
    # Bit-identical to ToTensor() + flatten applied per image (ToTensor casts uint8 to
    # the default float dtype and divides by 255), but vectorised over the whole split
    # instead of building 60k PIL images.
    data = (
        dataset.data.to(torch.get_default_dtype())
        .div(255)
        .reshape(len(dataset.data), -1)
        .numpy()
        .astype(np.float64)
    )
    labels = dataset.targets.numpy().astype(np.float64)

    # A private RandomState(0) draws exactly what np.random.seed(0) + the global
    # functions did, without resetting the notebook-wide global RNG.
    rng = np.random.RandomState(0)
    if label is not None:
        indices = np.where(labels == label)[0]
        rng.shuffle(indices)
        if len(indices) < n_samples:
            warnings.warn(
                f"only {len(indices)} samples with label {label}; returning fewer than {n_samples}"
            )
        indices = indices[:n_samples]
    else:
        indices = rng.choice(data.shape[0], n_samples, replace=False)
    return data[indices], labels[indices]


def get_mnist_data(n_samples, label=None):
    """
    Get n_samples MNIST training points with the specified label, or random points if label is None.

    Parameters
    ----------
    n_samples: int
        Number of data points to get.
    label: int or None
        Label of the data points to get. If None, get random data points.

    Returns
    -------
    mnist_data: np.ndarray
        n_samples x 784 float64 array of MNIST images scaled to [0, 1].
    mnist_labels: np.ndarray
        n_samples float64 array of MNIST labels.
    """
    return _sample_torchvision_digits(torchvision.datasets.MNIST, '../data', n_samples, label)


# KMNIST files are downloaded under data/KMNIST/ (e.g. t10k-images-idx3-ubyte.gz);
# note the root differs from MNIST's '../data', and only the train split is loaded.

def get_kmnist_data(n_samples, label=None):
    """
    Get n_samples KMNIST training points with the specified label, or random points if label is None.

    Parameters
    ----------
    n_samples: int
        Number of data points to get.
    label: int or None
        Label of the data points to get. If None, get random data points.

    Returns
    -------
    kmnist_data: np.ndarray
        n_samples x 784 float64 array of KMNIST images scaled to [0, 1].
    kmnist_labels: np.ndarray
        n_samples float64 array of KMNIST labels.
    """
    return _sample_torchvision_digits(torchvision.datasets.KMNIST, 'data', n_samples, label)

In [78]:
%autoreload 2
# Quadratics dataset; `return_dict` provides at least 'data' and 'cluster' (used below).
n_points, noise, noise_thresh = 2000, 0.175, None

return_dict = quadratics(
    n_points=n_points,
    noise=noise,
    noise_thresh=noise_thresh,
    supersample=True,
)

In [None]:
# Fit ORCManL on the quadratics point cloud.
orcmanl = ORCManL(exp_params=exp_params, verbose=True, reattach=True)
orcmanl.fit(return_dict['data'])

In [None]:
# ORC-weighted ForceAtlas2 layout: nodes coloured by cluster, edges by ORC.
G = orcmanl.G.copy()
pos, G_normalized = orc_fa(G, a=-3, b=5)
plot_graph_2D(
    pos,
    G,
    node_color=return_dict['cluster'][G.nodes()],
    edge_color=orcmanl.orcs,
    edge_width=0.05,
    node_size=0.2,
    title=None,
)

In [None]:
# UMAP baseline on the quadratics data, drawn with the ORCManL graph (edges coloured by ORC).
# random_state makes the embedding reproducible across Restart & Run All.
umap_emb = umap.UMAP(n_neighbors=25, min_dist=0.1, metric='euclidean', random_state=42).fit_transform(return_dict['data'])
plot_graph_2D(umap_emb, G, node_color=return_dict['cluster'][G.nodes()], edge_color=orcmanl.orcs, title=None, edge_width=0.05, node_size=0.2)

In [None]:
# Uniformly spaced points on a noisy circle.
# NOTE: this cell rebinds `orcmanl`, `G`, `pos` and `umap_emb` used by the quadratics cells above.
n_points = 2000


def circle_dataset(n_points, radius=1.0, noise_std=0.1, seed=0):
    """Sample points evenly spaced in angle on a circle, plus isotropic Gaussian noise.

    Parameters
    ----------
    n_points : int
        Number of points.
    radius : float, default 1.0
        Circle radius.
    noise_std : float, default 0.1
        Standard deviation of the Gaussian noise added to each coordinate.
    seed : int or None, default 0
        Seed for the noise generator; None gives a fresh, non-reproducible draw.

    Returns
    -------
    np.ndarray
        ``(n_points, 2)`` array of noisy circle points.
    """
    # endpoint=False: with the endpoint included, angles 0 and 2*pi coincide, which
    # duplicates one point and breaks the uniform spacing.
    theta = np.linspace(0, 2 * np.pi, n_points, endpoint=False)
    X = radius * np.column_stack((np.cos(theta), np.sin(theta)))
    rng = np.random.default_rng(seed)
    return X + noise_std * rng.standard_normal((n_points, 2))


data = circle_dataset(n_points)
orcmanl = ORCManL(
    exp_params=exp_params,
    verbose=True,
    reattach=True,
)
orcmanl.fit(data)

G = orcmanl.G.copy()
# a == b, so every edge gets the same ForceAtlas2 weight (1).
pos, G_normalized = orc_fa(G, 1, 1)
plot_graph_2D(pos, G, node_color=None, edge_color=orcmanl.orcs, title=None, edge_width=0.05, node_size=0.2)

umap_emb = umap.UMAP(n_neighbors=25, min_dist=0.1, metric='euclidean', random_state=42).fit_transform(data)
plot_graph_2D(umap_emb, orcmanl.G, node_color=None, edge_color=orcmanl.orcs, title=None, edge_width=0.05, node_size=0.2)

In [None]:
# The ORCManL graph drawn at the raw 2D circle coordinates.
plot_graph_2D(data, graph=orcmanl.G, edge_color=orcmanl.orcs, node_color=None, node_size=0.2, edge_width=0.05, title=None)

In [None]:
%autoreload 2
from src.isorc import *

# Refit ORCManL on the quadratics data. Under Restart & Run All, `orcmanl` has by now
# been rebound to the circle dataset (cell above); its graph distances would be
# scored against the quadratics labels.
orcmanl_quad = ORCManL(
    exp_params=exp_params,
    verbose=True,
    reattach=True,
)
orcmanl_quad.fit(return_dict['data'])

isorc = ISORC(
    orcmanl=orcmanl_quad,
    exp_params=exp_params,
    verbose=True,
    uniform=False,
    temperature=2
)

# knn classifier
from sklearn.model_selection import cross_val_score
from sklearn.neighbors import KNeighborsClassifier

n_neighbors = 200
n_folds = 5

# Cross-validated accuracy. Scoring on the training set lets every point vote for
# itself (self-distance is 0), which inflates the accuracy. With metric='precomputed',
# cross_val_score slices the distance matrix to (test rows x train columns).
node_labels = return_dict['cluster'][isorc.G.nodes()]
comparisons = [
    ("Ambient nn accuracy", KNeighborsClassifier(n_neighbors=n_neighbors),
     return_dict['data'], return_dict['cluster']),
    ("Isorc APSP nn accuracy", KNeighborsClassifier(n_neighbors=n_neighbors, metric='precomputed'),
     isorc.apsp_energy, node_labels),
    ("Euc APSP nn accuracy", KNeighborsClassifier(n_neighbors=n_neighbors, metric='precomputed'),
     isorc.apsp_euc, node_labels),
]
for name, knn, X, y in comparisons:
    accuracy = cross_val_score(knn, X, y, cv=n_folds).mean()
    print(f"{name}: {accuracy}")

In [57]:
# Random (label=None) subset of the MNIST training split.
n_samples = 2000
mnist_data, mnist_labels = get_mnist_data(n_samples=n_samples, label=None)

In [None]:
# Fit ORCManL on the MNIST subset (reattach=False, unlike the synthetic experiments).
orcmanl = ORCManL(exp_params=exp_params, verbose=True, reattach=False)
orcmanl.fit(mnist_data)

In [None]:
G = orcmanl.G.copy()
# orc_fa returns (positions, normalised graph); plot_graph_2D needs only the positions.
# Binding the whole tuple to `pos` passed a tuple where a layout was expected.
pos, G_normalized = orc_fa(G, -3, 5)
plot_graph_2D(pos, G, node_color=mnist_labels[G.nodes()], edge_color=orcmanl.orcs, title=None, edge_width=0.05, node_size=0.2)

In [None]:
# UMAP baseline on the MNIST subset, drawn with the ORCManL graph (edges coloured by ORC).
# Uses `G` from the ForceAtlas2 cell above; random_state makes the embedding reproducible.
umap_emb = umap.UMAP(n_neighbors=15, min_dist=0.1, metric='euclidean', random_state=42).fit_transform(mnist_data)

plt.figure()
plot_graph_2D(umap_emb, G, node_color=mnist_labels[G.nodes()], edge_color=orcmanl.orcs, title=None, edge_width=0.05, node_size=0.2)


In [None]:
%autoreload 2
from src.isorc import *
isorc = ISORC(
    orcmanl=orcmanl,
    exp_params=exp_params,
    verbose=True,
    uniform=True,
    temperature=0
)

# knn classifier
from sklearn.model_selection import cross_val_score
from sklearn.neighbors import KNeighborsClassifier

n_neighbors = 5
n_folds = 5

# Cross-validated accuracy. With k=5, scoring on the training set gives each point a
# guaranteed self-vote (self-distance 0), i.e. 20% of its neighbourhood, which
# inflates the accuracy. With metric='precomputed', cross_val_score slices the
# distance matrix to (test rows x train columns).
node_labels = mnist_labels[isorc.G.nodes()]
comparisons = [
    ("Ambient nn accuracy", KNeighborsClassifier(n_neighbors=n_neighbors),
     mnist_data, mnist_labels),
    ("Isorc APSP nn accuracy", KNeighborsClassifier(n_neighbors=n_neighbors, metric='precomputed'),
     isorc.apsp_energy, node_labels),
    ("Euc APSP nn accuracy", KNeighborsClassifier(n_neighbors=n_neighbors, metric='precomputed'),
     isorc.apsp_euc, node_labels),
]
for name, knn, X, y in comparisons:
    accuracy = cross_val_score(knn, X, y, cv=n_folds).mean()
    print(f"{name}: {accuracy}")

In [None]:
# UMAP of the MNIST subset (seeded for reproducibility), drawn with the ISORC graph.
umap_emb = umap.UMAP(n_components=2, n_neighbors=15, min_dist=0.1, random_state=42).fit_transform(mnist_data)

# plotting with orcs
plot_graph_2D(umap_emb, graph=isorc.G, node_color=mnist_labels[isorc.G.nodes()], title=None, node_size=1, edge_color=orcmanl.orcs, edge_width=0.1)

In [None]:
# Peek at the id of the second node in the ISORC graph.
[*isorc.G.nodes()][1]

In [None]:
# Compare the ORC distribution of edges joining different digit classes with that of all edges.
# Graph node ids index directly into `mnist_labels`.
edges = [(u, v) for u, v in isorc.G.edges() if mnist_labels[u] != mnist_labels[v]]

print(f"{len(edges)} of {len(isorc.G.edges())} edges are inter-class")
orcs = [isorc.G.edges[edge]['ricciCurvature'] for edge in edges]

plt.hist([orcs, orcmanl.orcs], bins=30, label=['inter-class edges', 'all edges'], density=True)
plt.xlabel('Ollivier-Ricci curvature')
plt.ylabel('density')
# The series labels above are only shown once a legend is drawn.
plt.legend();

In [None]:
# NOTE(review): `emb_mds` was never defined in this notebook (NameError on a fresh kernel).
# Assumption: it was a 2D metric MDS of `isorc.apsp_energy`, whose rows follow
# `isorc.G.nodes()` order (as in the kNN cell) — confirm. MDS will fail if the
# distance matrix contains inf (disconnected graph).
from sklearn.manifold import MDS

emb_mds = MDS(n_components=2, dissimilarity='precomputed', random_state=0).fit_transform(isorc.apsp_energy)

# reversed_indices[node] = row of `node` in emb_mds, so emb_mds[reversed_indices] is
# indexed by node id. Ids not present in the graph keep the default row 0.
reversed_indices = np.zeros(len(mnist_labels), dtype=int)
reversed_indices[list(isorc.G.nodes())] = np.arange(isorc.G.number_of_nodes())
plot_graph_2D(emb_mds[reversed_indices], isorc.G, node_color=mnist_labels[isorc.G.nodes()], title=None, edge_width=0.25)

In [None]:
from src.utils.graph_utils import *
%autoreload 2
# Plot the `energy` function over the ORC range [-2, 1) for a few values of `a`.
# The previous loop iterated over a `t` that energy() never received, so every curve
# was identical and the legend label "t=40.0000000000" did not describe the plot.
a_values = [10]
orc_grid = np.arange(start=-2, stop=1, step=0.05)

fig, ax = plt.subplots()
ax.set_ylim((-3, 10))
for a in a_values:
    energies = [energy(orc, d=2, a=a) for orc in orc_grid]
    ax.scatter(orc_grid, energies, label=f'a={a}')
ax.set_xlabel('Ollivier-Ricci curvature')
ax.set_ylabel('energy')
ax.legend();