In [None]:
from src.data.data import *
from src.orcml import *
from src.plotting import *
from src.utils.graph_utils import *
from src.isorc import *
from src.utils.embeddings import *
from sklearn.manifold import TSNE
import umap
%load_ext autoreload

# Settings dict handed to every ORCManL / ISORC constructor in this notebook
# through their `exp_params` argument.
exp_params = dict(
    mode='nbrs',
    n_neighbors=15,
    epsilon=None,
    lda=0.01,
    delta=0.8,
)

In [292]:
# Swiss-roll generator settings.
n_points = 4000
noise = 6.2
noise_thresh = 2.2

# Record of the dataset configuration used for this run.
dataset_info = dict(
    name='3D_swiss_roll',
    n_points=n_points,
    noise=noise,
    noise_thresh=noise_thresh,
)

return_dict = swiss_roll(
    n_points=n_points,
    noise=noise,
    noise_thresh=noise_thresh,
    supersample=True,
    dim=3,
    hole=False,
)

# Pull the individual arrays out of the generator's result dict; every key is
# looked up before any of the names is bound.
(
    swiss_roll_data,
    color,
    cluster,
    swiss_roll_supersample,
    subsample_indices,
) = [
    return_dict[key]
    for key in ('data', 'color', 'cluster', 'data_supersample', 'subsample_indices')
]

In [None]:
%autoreload 2
# Fit ORCManL (reattach disabled) on the samples currently held in return_dict.
orcmanl = ORCManL(exp_params=exp_params, verbose=True, reattach=False)
orcmanl.fit(return_dict['data'])

In [None]:
%autoreload 2
from src.isorc import *

# Build ISORC on top of the ORCManL instance and compute both embeddings.
isorc = ISORC(orcmanl=orcmanl, exp_params=exp_params, verbose=True, temperature=5)
emb_mds = isorc.fit_MDS()
emb_iso = isorc.fit_isomap(n_components=2)

# Colour every embedded point with the swiss-roll colour value of the
# corresponding node of isorc.G.
plot_data_2D(emb_mds, title=None, color=color[isorc.G.nodes()])
plot_data_2D(emb_iso, title=None, color=color[isorc.G.nodes()])

In [None]:
# Shortcut-edge positions as reported by orcmanl.
shortcut_edges_indices = orcmanl.shortcut_edges

# Indicator vector with one entry per edge of orcmanl.G: 1 at the shortcut
# positions, 0 everywhere else.
shortcut_edges = np.zeros(len(orcmanl.G.edges))
shortcut_edges[shortcut_edges_indices] = 1

# orcmanl.orcs against log(isorc.lb_dists), coloured by the shortcut indicator.
plt.scatter(x=orcmanl.orcs, y=np.log(isorc.lb_dists), c=shortcut_edges)


In [None]:
# UMAP baseline on the samples currently held in return_dict.
# UMAP is stochastic; fixing random_state makes the embedding (and therefore
# the figure) identical across kernel restarts. Seed 0 matches the seed used
# by the MNIST/KMNIST samplers in this notebook.
umap_emb = umap.UMAP(
    n_components=2,
    n_neighbors=15,
    min_dist=0.1,
    random_state=0,
).fit_transform(return_dict['data'])
plot_data_2D(umap_emb, color=color, title=None)


In [361]:
%autoreload 2
# quadratics dataset -- rebinds n_points / noise / noise_thresh and return_dict
n_points, noise, noise_thresh = 2500, 0.15, 0.4

return_dict = quadratics(
    n_points=n_points,
    noise=noise,
    noise_thresh=noise_thresh,
    supersample=True,
)

In [None]:
# Fresh ORCManL instance (reattach disabled) fitted on the samples currently
# held in return_dict.
orcmanl = ORCManL(exp_params=exp_params, verbose=True, reattach=False)
orcmanl.fit(return_dict['data'])

In [None]:
%autoreload 2
from src.isorc import *

# Build ISORC on top of the ORCManL instance and compute both embeddings.
isorc = ISORC(orcmanl=orcmanl, exp_params=exp_params, verbose=True, temperature=5)
emb_mds = isorc.fit_MDS()
emb_iso = isorc.fit_isomap(n_components=2)

# Colour every embedded point by the 'cluster' entry of the corresponding
# node of isorc.G.
plot_data_2D(emb_mds, title=None, color=return_dict['cluster'][isorc.G.nodes()])
plot_data_2D(emb_iso, title=None, color=return_dict['cluster'][isorc.G.nodes()])

In [None]:
# Draw orcmanl.G_ann over the data points, with edges coloured by orcmanl.orcs.
plot_graph_2D(
    return_dict['data'],
    orcmanl.G_ann,
    edge_color=orcmanl.orcs,
    title=None,
)

In [None]:
# UMAP baseline on the samples currently held in return_dict.
# UMAP is stochastic; fixing random_state makes the embedding (and therefore
# the figure) identical across kernel restarts. Seed 0 matches the seed used
# by the MNIST/KMNIST samplers in this notebook.
umap_emb = umap.UMAP(
    n_components=2,
    n_neighbors=15,
    random_state=0,
).fit_transform(return_dict['data'])
plot_data_2D(umap_emb, color=return_dict['cluster'], title=None)

In [366]:
import numpy as np
import torchvision

def get_mnist_data(n_samples, label=None):
    """
    Get n_samples MNIST data points with the specified label. If label is None, get n_samples random data points.

    Samples come from the MNIST training split stored under '../data' (downloaded if missing).
    Each image is converted with ToTensor and flattened to a vector. The selection is seeded with
    np.random.seed(0), so repeated calls return the same samples; note that this resets NumPy's
    global random state.

    Parameters:
    ----------
    n_samples: int
        Number of data points to get. When label is given and fewer than n_samples
        points carry that label, all of them are returned.
    label: int or None
        Label of the data points to get. If None, get random data points.
    Returns:
    ----------
    mnist_data: np.ndarray
        n_samples x 784 float64 array of MNIST data points
    mnist_labels: np.ndarray
        n_samples float64 array of MNIST labels
    Raises:
    ----------
    ValueError
        If label is None and n_samples exceeds the number of available points
        (raised by np.random.choice with replace=False).
    """
    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Lambda(lambda x: x.view(-1))
    ])
    mnist = torchvision.datasets.MNIST('../data', train=True, download=True, transform=transform)

    # The labels are available without decoding any image, so choose the indices first
    # and run the transform only on the selected samples instead of on the whole split.
    all_labels = np.asarray(mnist.targets, dtype=np.float64)

    np.random.seed(0)
    if label is not None:
        indices = np.where(all_labels == label)[0]
        np.random.shuffle(indices)
        indices = indices[:n_samples]
    else:
        indices = np.random.choice(len(mnist), n_samples, replace=False)

    # Preallocating keeps the (0, n_features) shape when nothing is selected.
    n_features = int(np.prod(mnist.data.shape[1:]))
    mnist_data = np.empty((len(indices), n_features), dtype=np.float64)
    for row, index in enumerate(indices):
        mnist_data[row] = mnist[int(index)][0].numpy()
    mnist_labels = all_labels[indices]
    return mnist_data, mnist_labels

# kmnist: path data/KMNIST/t10k-images-idx3-ubyte.gz

def get_kmnist_data(n_samples, label=None):
    """
    Get n_samples KMNIST data points with the specified label. If label is None, get n_samples random data points.

    Samples come from the KMNIST training split stored under 'data' (downloaded if missing).
    Each image is converted with ToTensor and flattened to a vector. The selection is seeded with
    np.random.seed(0), so repeated calls return the same samples; note that this resets NumPy's
    global random state.

    Parameters:
    ----------
    n_samples: int
        Number of data points to get. When label is given and fewer than n_samples
        points carry that label, all of them are returned.
    label: int or None
        Label of the data points to get. If None, get random data points.
    Returns:
    ----------
    kmnist_data: np.ndarray
        n_samples x 784 float64 array of KMNIST data points
    kmnist_labels: np.ndarray
        n_samples float64 array of KMNIST labels
    Raises:
    ----------
    ValueError
        If label is None and n_samples exceeds the number of available points
        (raised by np.random.choice with replace=False).
    """
    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Lambda(lambda x: x.view(-1))
    ])
    kmnist = torchvision.datasets.KMNIST('data', train=True, download=True, transform=transform)

    # The labels are available without decoding any image, so choose the indices first
    # and run the transform only on the selected samples instead of on the whole split.
    all_labels = np.asarray(kmnist.targets, dtype=np.float64)

    np.random.seed(0)
    if label is not None:
        indices = np.where(all_labels == label)[0]
        np.random.shuffle(indices)
        indices = indices[:n_samples]
    else:
        indices = np.random.choice(len(kmnist), n_samples, replace=False)

    # Preallocating keeps the (0, n_features) shape when nothing is selected.
    n_features = int(np.prod(kmnist.data.shape[1:]))
    kmnist_data = np.empty((len(indices), n_features), dtype=np.float64)
    for row, index in enumerate(indices):
        kmnist_data[row] = kmnist[int(index)][0].numpy()
    kmnist_labels = all_labels[indices]
    return kmnist_data, kmnist_labels

In [375]:
# Draw a random MNIST subset (no label filter).
n_samples = 10_000
mnist_data, mnist_labels = get_mnist_data(n_samples=n_samples, label=None)

In [None]:
# Fresh ORCManL instance (reattach disabled) fitted on the MNIST vectors.
orcmanl = ORCManL(exp_params=exp_params, verbose=True, reattach=False)
orcmanl.fit(mnist_data)

In [None]:
# Build ISORC on top of the ORCManL instance and compute both embeddings.
isorc = ISORC(
    orcmanl=orcmanl,
    exp_params=exp_params,
    verbose=True,
    temperature=5
)
emb_mds = isorc.fit_MDS()
# Request two components explicitly, as the swiss-roll and quadratics cells do:
# the result is passed straight to plot_data_2D, so it must not depend on the
# method's default.
emb_iso = isorc.fit_isomap(n_components=2)

# Colour every embedded point by the digit label of the corresponding node of isorc.G.
plot_data_2D(emb_mds, color=mnist_labels[isorc.G.nodes()], title=None)
plot_data_2D(emb_iso, color=mnist_labels[isorc.G.nodes()], title=None)