### Cluster the moons dataset

In this experiment we cluster simple synthetic datasets. The default dataset is `moons`, although other datasets can easily be used by replacing the `make_moons` call in the data-loading cell below.

In [None]:
# Enable IPython's autoreload extension in mode 2, so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os
import numpy as np
import torch
import torch.distributions as dist
import matplotlib.pyplot as plt
import matplotlib
from matplotlib_inline.backend_inline import set_matplotlib_formats
import shapeflow as sf
import sklearn.datasets as datasets
import seaborn as sns
import extratorch as etorch

In [None]:
# Make the run reproducible. torch's global generator is seeded for the
# torch-side randomness (e.g. `torch.rand`). NumPy's global generator is
# seeded as well, because scikit-learn's dataset generators draw from it
# whenever no explicit `random_state` is passed.
seed = torch.manual_seed(0)
np.random.seed(0)

# Render inline figures as vector graphics (PDF and SVG). One call is enough.
set_matplotlib_formats("pdf", "svg")

# Global plot appearance: font size, colour-blind-friendly palette and a
# plain white seaborn background.
matplotlib.rcParams.update({"font.size": 12})
plt.style.use("tableau-colorblind10")
sns.set_style("white")

Load data and standardize

In [None]:
# Generate the two-moons toy dataset.
num_total = 2048  # number of points to sample
# NOTE(review): `num_supervised` is not read by any visible cell; the
# upper-case NUM_SUPERVISED defined further down is what the code uses.
num_supervised = 0

# A fixed `random_state` yields the same sample on every run. Without it the
# sample depends on NumPy's global random state.
x, y = datasets.make_moons(num_total, noise=0.05, random_state=0)

In [None]:
# Rescale the data with the mean and standard deviation taken over *all*
# entries of `x` (one scalar each, not one value per feature).
# NOTE(review): per-feature standardisation would use axis=0 — confirm that
# the global variant is intended.
x = (x - np.mean(x)) / np.std(x)

# Two-column label matrix: column 0 holds `y`, column 1 holds |y - 1|.
q = np.stack([y, np.abs(y - 1)], axis=-1)

# Quick look at the rescaled points, coloured by label.
plt.scatter(x=x[:, 0], y=x[:, 1], c=y)

#### Make priors
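Since we pretend not to know the class of each observation, the estimated prior probability ($P($`class_1`$)$, $P($`class_2`$)$) is $(0.5 - \epsilon, 0.5 + \epsilon)$ for every observation, where $\epsilon$ is a small random offset drawn separately for each observation to break the symmetry between the two classes.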


In [None]:
NUM_SUPERVISED = 0  # number of leading observations that keep their true label

# Observations as a float32 tensor.
x_tensor = torch.as_tensor(x, dtype=torch.float32)

# True posterior probabilities, one column per class.
q_tensor = torch.as_tensor(q, dtype=torch.float32)

num_points = q_tensor.size(0)
num_classes = q_tensor.size(-1)

# Start from a uniform prior and shift it by a random amount per observation
# to break the symmetry of the initial clustering problem.
eps = torch.rand(num_points) * 0.1
priors = torch.zeros_like(q_tensor)
priors[:, 0] = 1 / num_classes - eps
priors[:, 1] = 1 / num_classes + eps

# Optionally pin the first NUM_SUPERVISED observations to their true labels.
if NUM_SUPERVISED > 0:
    priors[:NUM_SUPERVISED] = q_tensor[:NUM_SUPERVISED]

# The initial posterior starts out as an independent copy of the priors.
init_posterior = priors.detach().clone()

In [None]:
# Bundle the tensors used for training. Each dataset item is the triple
# (x, p(c_k | x), p(c_k)): observation, initial posterior, prior.
data = torch.utils.data.TensorDataset(x_tensor, init_posterior, priors)

In [None]:
# Helpers shared by the training runs.

# Shape of a single observation, read off the first dataset item.
event_shape = data[0][0].shape

# Base distribution handed to the model: zero mean, identity covariance.
base_dist = dist.MultivariateNormal(
    loc=torch.zeros(event_shape[0]),
    covariance_matrix=torch.eye(event_shape[0]),
)


def lr_scheduler(optim):
    """Build the learning-rate scheduler for the optimizer ``optim``.

    Args:
        optim: the optimizer whose learning rate should be scheduled.

    Returns:
        A ``ReduceLROnPlateau`` scheduler in ``"min"`` mode with
        ``factor=0.5`` and ``patience=10``.
    """
    # `verbose` is deliberately not passed: False was its default, and recent
    # PyTorch releases deprecate the argument.
    return torch.optim.lr_scheduler.ReduceLROnPlateau(
        optim, mode="min", factor=0.5, patience=10
    )

Define model parameters

In [None]:
# Output location for the figures produced by this experiment.
DIR = "../figures/moons_cluster/"
SET_NAME = "cnf_2"  # plain string: there is nothing to interpolate
PATH_FIGURES = os.path.join(DIR, SET_NAME)

# Number of trials passed to the cross-validation helper.
TRIALS = 1

# Model hyper-parameter grid. Every value is wrapped in a list so that the
# grid helper can expand the combinations; here each list has one entry.
MODEL_PARAMS = {
    "model": [sf.nf.get_flow],
    "get_transform": [sf.transforms.NDETransform],
    "compose": [False],
    "activation": ["tanh"],
    "get_net": [etorch.models.FFNN],
    "n_hidden_layers": [3],
    "neurons": [8],
    # one flow per class (last dimension of the prior tensor)
    "num_flows": [priors.shape[-1]],
    "base_dist": [base_dist],
    "inverse_model": [True],
    "trace_estimator": ["autograd"],
}

# Training hyper-parameter grid, in the same one-entry-per-list layout.
TRAINING_PARAMS = {
    "optimizer": ["ADAM"],
    "batch_size": [256],
    "num_epochs": [200],
    "learning_rate": [0.01],
    "verbose": [True],
    "lr_scheduler": [lr_scheduler],
    "compute_loss": [sf.nf.get_monte_carlo_conditional_dkl_loss()],
}

Train the model(s)

In [None]:
# Expand both hyper-parameter grids (product=True) into iterators.
model_params_iter = etorch.create_subdictionary_iterator(
    MODEL_PARAMS,
    product=True,
)
training_params_iter = etorch.create_subdictionary_iterator(
    TRAINING_PARAMS,
    product=True,
)

# Fit the model(s) on `data` for every parameter combination and keep the
# returned results for the plotting cells below.
cv_results = etorch.k_fold_cv_grid(
    data=data,
    fit=etorch.fit_module,
    model_params=model_params_iter,
    training_params=training_params_iter,
    trials=TRIALS,
    copy_data=True,
    verbose=True,
)

Plot and store the results. The figures are not shown in the notebook; they are saved under the path `PATH_FIGURES`. For other datasets you might need to change the plotting parameters below.


In [None]:
# Settings forwarded to the 2-D cluster plotting function. The axis limits
# are reused by the scatter plots further down.
function_kwargs = {
    "x_lim": (-3, 3),
    "y_lim": (-2, 2),
    "num_samples": 500,
    "grid_shape": 500,
}

# Draw the result figures and write them to PATH_FIGURES.
etorch.plotting.plot_result(
    path_figures=PATH_FIGURES,
    plot_function=sf.plotting.plot_2d_cluster,
    function_kwargs=function_kwargs,
    **cv_results,
)

Plot all the points used


In [None]:
# Scatter every data point in black and save the figure as a PDF.
# Taking the count from the data keeps "all points" true whatever size the
# dataset was generated with.
num_points = x.shape[0]

# savefig does not create missing directories, so make sure the target exists.
os.makedirs(PATH_FIGURES, exist_ok=True)

plt.scatter(x[:num_points, 0], x[:num_points, 1], marker=".", color="black")

# Use the same limits as the result figures, with equal axis scaling and no
# visible axes.
plt.xlim(function_kwargs["x_lim"])
plt.ylim(function_kwargs["y_lim"])
plt.gca().set_aspect("equal", "box")
plt.axis("off")

plt.savefig(
    os.path.join(PATH_FIGURES, "points_total.pdf"),
    bbox_inches="tight",
    pad_inches=0,
)
plt.show()

Plot all the supervised points used


In [None]:
# Scatter only the supervised points (the first NUM_SUPERVISED observations),
# one marker style per label, and save the figure as a PDF.
# With NUM_SUPERVISED == 0 both slices are empty and the figure has no points.
x_supervised = x[:NUM_SUPERVISED]
y_supervised = y[:NUM_SUPERVISED]

# Label 1 is drawn first with ".", then label 0 with "x".
for label, marker in ((1, "."), (0, "x")):
    mask = y_supervised == label
    plt.scatter(x_supervised[mask, 0], x_supervised[mask, 1], marker=marker)

# Use the same limits as the result figures, with equal axis scaling and no
# visible axes or tick labels.
plt.xlim(function_kwargs["x_lim"])
plt.ylim(function_kwargs["y_lim"])
plt.gca().set_aspect("equal", "box")
plt.axis("off")
plt.gca().set_xticklabels([])
plt.gca().set_yticklabels([])

# savefig does not create missing directories, so make sure the target exists.
os.makedirs(PATH_FIGURES, exist_ok=True)

plt.savefig(
    os.path.join(PATH_FIGURES, "points_supervised.pdf"),
    bbox_inches="tight",
    pad_inches=0,
)
plt.show()