In [1]:
from src.data.data import *
from src.orcml import *
from src.plotting import *
from src.utils.graph_utils import *
from src.isorc import *
from src.utils.embeddings import *
from sklearn.manifold import TSNE
import umap
import torch
import torchvision
%load_ext autoreload

In [2]:

def get_mnist_data(n_samples, label=None, seed=0):
    """
    Get n_samples MNIST training data points with the specified label. If label is None, get n_samples random data points.

    Parameters:
    ----------
    n_samples: int
        Number of data points to get
    label: int or None
        Label of the data points to get. If None, get random data points.
    seed: int
        Value passed to np.random.seed before sampling. Defaults to 0, the
        seed that used to be hard-coded, so existing calls select the same
        points. Note that this reseeds NumPy's global random state.

    Returns:
    ----------
    mnist_data: np.ndarray
        n_samples x 784 float64 array of flattened MNIST images. When label is
        given and fewer than n_samples images carry it, all of them are returned.
    mnist_labels: np.ndarray
        float64 array with the label of each returned row

    Raises:
    ----------
    ValueError
        From np.random.choice when label is None and n_samples exceeds the
        size of the dataset.
    """
    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Lambda(lambda x: x.view(-1))
    ])
    mnist = torchvision.datasets.MNIST('../data', train=True, download=True, transform=transform)

    # Walk the dataset once: every item access re-runs the transform, so
    # collecting images and labels in separate passes would convert each
    # image twice.
    images, targets = [], []
    for image, target in mnist:
        images.append(image)
        targets.append(target)
    mnist_data = torch.stack(images).numpy().astype(np.float64)
    mnist_labels = torch.tensor(targets).numpy().astype(np.float64)

    np.random.seed(seed)
    if label is not None:
        # Shuffle the rows carrying the requested label and keep the first n_samples.
        indices = np.where(mnist_labels == label)[0]
        np.random.shuffle(indices)
        indices = indices[:n_samples]
    else:
        indices = np.random.choice(mnist_data.shape[0], n_samples, replace=False)
    return mnist_data[indices], mnist_labels[indices]

# kmnist: path data/KMNIST/t10k-images-idx3-ubyte.gz

def get_kmnist_data(n_samples, label=None, seed=0):
    """
    Get n_samples KMNIST training data points with the specified label. If label is None, get n_samples random data points.

    Parameters:
    ----------
    n_samples: int
        Number of data points to get
    label: int or None
        Label of the data points to get. If None, get random data points.
    seed: int
        Value passed to np.random.seed before sampling. Defaults to 0, the
        seed that used to be hard-coded, so existing calls select the same
        points. Note that this reseeds NumPy's global random state.

    Returns:
    ----------
    kmnist_data: np.ndarray
        n_samples x 784 float64 array of flattened KMNIST images. When label is
        given and fewer than n_samples images carry it, all of them are returned.
    kmnist_labels: np.ndarray
        float64 array with the label of each returned row

    Raises:
    ----------
    ValueError
        From np.random.choice when label is None and n_samples exceeds the
        size of the dataset.
    """
    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Lambda(lambda x: x.view(-1))
    ])
    # NOTE(review): the root here is 'data' while get_mnist_data uses '../data' — confirm this is intended.
    kmnist = torchvision.datasets.KMNIST('data', train=True, download=True, transform=transform)

    # Walk the dataset once: every item access re-runs the transform, so
    # collecting images and labels in separate passes would convert each
    # image twice.
    images, targets = [], []
    for image, target in kmnist:
        images.append(image)
        targets.append(target)
    # No extra rescaling happens here; the only value scaling is whatever
    # ToTensor applies in the transform above.
    kmnist_data = torch.stack(images).numpy().astype(np.float64)
    kmnist_labels = torch.tensor(targets).numpy().astype(np.float64)

    np.random.seed(seed)
    if label is not None:
        # Shuffle the rows carrying the requested label and keep the first n_samples.
        indices = np.where(kmnist_labels == label)[0]
        np.random.shuffle(indices)
        indices = indices[:n_samples]
    else:
        indices = np.random.choice(kmnist_data.shape[0], n_samples, replace=False)
    return kmnist_data[indices], kmnist_labels[indices]

In [4]:
# Draw a seeded random subsample of 4000 MNIST training images across all digits.
mnist_data, mnist_labels = get_mnist_data(n_samples=4000)

In [None]:
# Settings handed to ORCManL; the same dict is reused for the KMNIST run further down.
exp_params = {
    'mode': 'nbrs',
    'n_neighbors': 15,
    'epsilon': None,
    'lda': 0.01,
    'delta': 0.8,
}
# Build the ORCManL object and fit it on the MNIST subsample.
orcmanl = ORCManL(exp_params=exp_params, verbose=True, reattach=True)
orcmanl.fit(mnist_data)

In [None]:
%autoreload 2
import gc
torch.cuda.empty_cache()
gc.collect()
isorc = ISORC(orcmanl, init='spectral', dim=2)
X_opt, X_frames = isorc.fit_graph(lr=0.2, n_iter=2000)

In [None]:
# get list of shortcut edges from orcmanl
# NOTE(review): assumes these are integer positions into orcmanl.G.edges — confirm against ORCManL
shortcut_edges_indices = orcmanl.shortcut_edges
# convert to binary array: one float entry per edge of orcmanl.G,
# set to 1 at the shortcut positions and 0 elsewhere (passed as edge_color when plotting)
shortcut_edges = np.zeros(len(orcmanl.G.edges))
shortcut_edges[shortcut_edges_indices] = 1

In [None]:
%autoreload 2
plot_graph_2D(X_opt, orcmanl.G, title=None, node_color=mnist_labels[orcmanl.G.nodes()], edge_width=0.1, edge_color=shortcut_edges, node_size=0.1)

In [None]:
# Draw a seeded random subsample of 2000 KMNIST training images across all classes.
kmnist_data, kmnist_labels = get_kmnist_data(n_samples=2000)

In [None]:
# Re-fit with the same exp_params on KMNIST. This rebinds `orcmanl`, so every
# later cell that reads `orcmanl` sees the KMNIST fit, not the MNIST one.
orcmanl = ORCManL(exp_params=exp_params, verbose=True, reattach=True)
orcmanl.fit(kmnist_data)

In [None]:
%autoreload 2
import gc
torch.cuda.empty_cache()
gc.collect()
isorc = ISORC(orcmanl, init='ambient', dim=2)
X_opt, losses = isorc.fit_graph(lr=0.2, n_iter=100)

In [None]:
# Plot the second value returned by isorc.fit_graph in the previous cell.
# NOTE(review): presumably one loss value per optimisation step — confirm.
plt.plot(losses)

In [None]:
# get list of shortcut edges from orcmanl (at this point: the KMNIST fit)
# NOTE(review): assumes these are integer positions into orcmanl.G.edges — confirm against ORCManL
shortcut_edges_indices = orcmanl.shortcut_edges
# convert to binary array: one float entry per edge of orcmanl.G,
# set to 1 at the shortcut positions and 0 elsewhere (passed as edge_color when plotting)
shortcut_edges = np.zeros(len(orcmanl.G.edges))
shortcut_edges[shortcut_edges_indices] = 1

In [None]:
# Import the submodule explicitly: a bare `import sklearn` does not guarantee
# that `sklearn.manifold` is loaded, so the attribute access below would only
# work when some other import had already pulled it in.
import sklearn.manifold
# Spectral-embedding baseline: 2-D coordinates computed directly from the KMNIST subsample.
X_opt_spectral = sklearn.manifold.SpectralEmbedding(n_components=2).fit_transform(kmnist_data)

In [None]:
%autoreload 2
plot_graph_2D(X_opt_spectral, orcmanl.G, title=None, node_color=kmnist_labels[orcmanl.G.nodes()], edge_width=0.1, edge_color=shortcut_edges, node_size=0.1)

In [None]:
# UMAP baseline on the KMNIST subsample, drawn on the same ORCManL graph with
# nodes coloured by label and edges coloured by the shortcut mask.
umap_data = umap.UMAP(n_neighbors=15, min_dist=0.1).fit_transform(kmnist_data)
plot_graph_2D(
    umap_data,
    orcmanl.G,
    title=None,
    node_color=kmnist_labels[orcmanl.G.nodes()],
    edge_width=0.1,
    edge_color=shortcut_edges,
    node_size=0.1,
)

In [None]:
from sklearn.manifold import Isomap
# Isomap baseline on the KMNIST subsample; unlike the neighbouring plots this
# one passes neither edge_color nor node_size.
iso = Isomap(n_components=2)
X_iso = iso.fit_transform(kmnist_data)
plot_graph_2D(
    X_iso,
    orcmanl.G,
    title=None,
    node_color=kmnist_labels[orcmanl.G.nodes()],
    edge_width=0.1,
)

In [None]:
# get list of shortcut edges from orcmanl
# NOTE(review): this cell repeats an earlier shortcut-mask cell verbatim; `orcmanl`
# has not been re-fit since the KMNIST fit, so on a top-to-bottom run it
# recomputes the same mask.
shortcut_edges_indices = orcmanl.shortcut_edges
# convert to binary array: one float entry per edge of orcmanl.G,
# set to 1 at the shortcut positions and 0 elsewhere
shortcut_edges = np.zeros(len(orcmanl.G.edges))
shortcut_edges[shortcut_edges_indices] = 1

In [None]:
# t-SNE baseline computed on the MNIST subsample.
# NOTE(review): on a top-to-bottom run `orcmanl` (and therefore orcmanl.G and
# shortcut_edges) comes from the KMNIST fit above, while the embedding and the
# node colours come from MNIST — the graph and the points do not describe the
# same data. Confirm which dataset is intended, or keep a separate ORCManL
# object per dataset.
X_tsne = TSNE(n_components=2).fit_transform(mnist_data)
plot_graph_2D(X_tsne, orcmanl.G, title=None, node_color=mnist_labels[orcmanl.G.nodes()], edge_width=0.1, edge_color=shortcut_edges)

In [None]:
# UMAP baseline computed on the MNIST subsample.
# NOTE(review): on a top-to-bottom run `orcmanl` was last fitted on KMNIST, so
# orcmanl.G does not belong to the MNIST points embedded and coloured here —
# confirm which dataset is intended.
X_umap = umap.UMAP(n_components=2).fit_transform(mnist_data)
plot_graph_2D(X_umap, orcmanl.G, title=None, node_color=mnist_labels[orcmanl.G.nodes()], edge_width=0.1)

In [None]:
# Same UMAP coordinates drawn on orcmanl.G_pruned instead of orcmanl.G.
# NOTE(review): X_umap and mnist_labels come from MNIST, but on a top-to-bottom
# run `orcmanl` was last fitted on KMNIST — confirm which dataset is intended.
plot_graph_2D(X_umap, orcmanl.G_pruned, title=None, node_color=mnist_labels[orcmanl.G_pruned.nodes()], edge_width=0.1)