In [None]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# reloaded before each cell runs and edits to local packages take effect
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import copy
import os
import itertools
import numpy as np
import torch
import torch.distributions as dist
import matplotlib.pyplot as plt
import matplotlib
from matplotlib_inline.backend_inline import set_matplotlib_formats
import shapeflow as sf
import sklearn.datasets as datasets
import seaborn as sns
import extratorch as etorch

In [None]:
# Make the run reproducible.
# torch.manual_seed seeds torch's global RNG and returns its Generator, which
# is kept in `seed`.
seed = torch.manual_seed(0)
# scikit-learn falls back to NumPy's global RNG whenever no random_state is
# given, so that RNG has to be seeded as well.
np.random.seed(0)

# Better plotting: vector formats for inline figures, larger font, and a
# colour-blind friendly palette on a white seaborn background.
set_matplotlib_formats("pdf", "svg")
matplotlib.rcParams.update({"font.size": 12})
plt.style.use("tableau-colorblind10")
sns.set_style("white")

In [None]:
# Size of the data set and how many of its leading samples keep their true
# label (0 means fully unsupervised).
num_total = 2048
num_supervised = 0

# Two interleaving half circles with Gaussian noise. The fixed random_state
# makes the generated points identical on every run; without it the data would
# change with each kernel restart.
x, y = datasets.make_moons(num_total, noise=0.05, random_state=0)

In [None]:
# Standardize with one global mean and one global standard deviation, both
# taken over every entry of x rather than per feature.
x = (x - np.mean(x)) / np.std(x)

# Two-column label matrix: column 0 holds y, column 1 holds |y - 1|.
q = np.column_stack((y, np.abs(y - 1)))

# Quick look at the data, coloured by label.
plt.scatter(*x.T, c=y)

In [None]:
# Samples and label matrix as float32 tensors.
x_tensor = torch.as_tensor(x, dtype=torch.float32)
q_tensor = torch.as_tensor(q, dtype=torch.float32)

# Per-row class priors: start from 0.5 / 0.5 and shift the two columns in
# opposite directions by a uniform random offset in [0, 0.1), so each row
# still sums to one.
priors = torch.zeros_like(q_tensor)
eps = 0.1 * torch.rand(len(priors))
priors[:, 0] = 0.5 - eps
priors[:, 1] = 0.5 + eps

# The first `num_supervised` rows take the label matrix as their prior.
priors[:num_supervised] = q_tensor[:num_supervised]

# The first guess of the posterior is an independent copy of the priors.
init_posterior = priors.detach().clone()

# Each dataset row is (x, p(c_k|x), p(c_k)); the priors are stored per row
# because they are used for each row individually.
data = torch.utils.data.TensorDataset(x_tensor, init_posterior, priors)

In [None]:
# Base distribution: a standard multivariate normal whose dimension is the
# number of features of one sample.
event_shape = data[0][0].shape
base_dist = dist.MultivariateNormal(
    torch.zeros(event_shape[0]), torch.eye(event_shape[0])
)


def lr_scheduler(optim):
    """Build the learning-rate scheduler for the optimizer ``optim``.

    Returns a ``ReduceLROnPlateau`` scheduler that halves the learning rate
    after 10 epochs without improvement of the monitored (minimised) metric.

    The ``verbose`` argument is intentionally not passed: it is deprecated in
    recent PyTorch releases and ``False`` was its default, so omitting it
    keeps the behaviour and avoids the deprecation warning.
    """
    return torch.optim.lr_scheduler.ReduceLROnPlateau(
        optim, mode="min", factor=0.5, patience=10
    )


# Displayed as the cell output.
base_dist.batch_shape

In [None]:
# ---- experiment configuration ----
# Figures of this run are written below DIR/SET_NAME.
DIR = "../figures/moons_cluster/"
SET_NAME = "cnf_2"
PATH_FIGURES = os.path.join(DIR, SET_NAME)
trials = 5
# ----------------------------------

# Model grid: every value is a list of candidates; here each list has exactly
# one entry. One flow is requested per column of `priors`.
MODEL_PARAMS = dict(
    model=[sf.nf.get_flow],
    get_transform=[sf.transforms.NDETransform],
    compose=[False],
    activation=["tanh"],
    get_net=[etorch.models.FFNN],
    n_hidden_layers=[3],
    neurons=[8],
    num_flows=[priors.shape[-1]],
    base_dist=[base_dist],
    inverse_model=[True],
    trace_estimator=["autograd_trace"],
)

# Training grid, same one-candidate-per-key layout as the model grid.
TRAINING_PARAMS = dict(
    optimizer=["ADAM"],
    batch_size=[256],
    num_epochs=[200],
    learning_rate=[0.01],
    verbose=[True],
    lr_scheduler=[lr_scheduler],
    compute_loss=[sf.nf.get_monte_carlo_elbo_loss()],
)

In [None]:
# Turn both parameter grids into iterators of sub-dictionaries
# (product=True; presumably the Cartesian product of the candidate lists,
# confirm against extratorch).
model_params_iter = etorch.create_subdictionary_iterator(
    MODEL_PARAMS,
    product=True,
)
training_params_iter = etorch.create_subdictionary_iterator(
    TRAINING_PARAMS,
    product=True,
)

# Fit over the grid for `trials` repetitions. The returned mapping is indexed
# with "models" and "histories" further down in the notebook.
cv_results = etorch.k_fold_cv_grid(
    fit=etorch.fit_module,
    data=data,
    model_params=model_params_iter,
    training_params=training_params_iter,
    trials=trials,
    copy_data=True,
    verbose=True,
)

In [None]:
# Keyword arguments handed on to the per-model plot function; the axis limits
# are reused by the scatter plots later in the notebook.
function_kwargs = {
    "x_lim": (-3, 3),
    "y_lim": (-2, 2),
    "num_samples": 500,
    "grid_shape": 500,
}

# Plot every result of the grid run with the 2D cluster plot, passing
# PATH_FIGURES as the figure path.
etorch.plotting.plot_result(
    path_figures=PATH_FIGURES,
    plot_function=sf.plotting.plot_2d_cluster,
    **cv_results,
    function_kwargs=function_kwargs,
)

Test that the wrapper works

In [None]:
# Training history of the first entry in the results, drawn with its own
# plot method.
histories = cv_results["histories"]
hist = histories[0]
hist.plot()
plt.show()

In [None]:
# not valid
# NOTE(review): `noise` is not used in this cell; it is kept because drawing
# it advances torch's global RNG state.
noise = base_dist.sample([100])
print("Cross loss")

# For every entry of the results, evaluate the log-likelihood of all samples
# under each of its models and print the mean of the label-weighted values.
for models in cv_results["models"]:
    log_x_cond_c = torch.zeros_like(q_tensor)
    for k, model_k in enumerate(models):
        # column k holds log p(x | c_k) from the k-th model
        log_x_cond_c[:, k] = model_k.log_prob(x_tensor)

    print(torch.mean(q_tensor * log_x_cond_c).item())

In [None]:
# Count the trainable scalars of one model: the first model of the first
# entry in the results.
print("Num parameters:")
total = 0
models = cv_results["models"][0]
# numel() is the number of elements of a parameter tensor. len() would only
# return the size of its first dimension and undercount every weight matrix.
total += sum(params.numel() for params in models[0].parameters())
print(total)

In [None]:
# Plot the complete data set as plain black dots and save it as a PDF.
num_points = num_total  # all generated samples (was a hard-coded 2048)

plt.scatter(x[:num_points, 0], x[:num_points, 1], marker=".", color="black")
# Same window as the model plots so the figures can be compared directly.
plt.xlim(function_kwargs["x_lim"])
plt.ylim(function_kwargs["y_lim"])
plt.gca().set_aspect("equal", "box")
plt.axis("off")
plt.gca().set_xticklabels([])
plt.gca().set_yticklabels([])

# savefig does not create missing directories; make sure the target exists.
os.makedirs(PATH_FIGURES, exist_ok=True)
plt.savefig(
    os.path.join(PATH_FIGURES, "points_total.pdf"),
    bbox_inches="tight",
    pad_inches=0,
)
plt.show()

In [None]:
# Plot only the supervised samples, one marker per class, and save as a PDF.
# With num_supervised == 0 both scatters are empty and the figure is blank.
num_points = num_supervised

# Boolean class masks over the supervised slice, computed once.
is_class_1 = y[:num_points] == 1
is_class_0 = y[:num_points] == 0

plt.scatter(
    x[:num_points, 0][is_class_1],
    x[:num_points, 1][is_class_1],
    marker=".",
)
plt.scatter(
    x[:num_points, 0][is_class_0],
    x[:num_points, 1][is_class_0],
    marker="x",
)
# Same window as the model plots so the figures can be compared directly.
plt.xlim(function_kwargs["x_lim"])
plt.ylim(function_kwargs["y_lim"])
plt.gca().set_aspect("equal", "box")
plt.axis("off")
plt.gca().set_xticklabels([])
plt.gca().set_yticklabels([])

# savefig does not create missing directories; make sure the target exists.
os.makedirs(PATH_FIGURES, exist_ok=True)
plt.savefig(
    os.path.join(PATH_FIGURES, "points_supervised.pdf"),
    bbox_inches="tight",
    pad_inches=0,
)
plt.show()