### Interpolate the moons dataset


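In this experiment we train a normalizing flow on the `moons` dataset and interpolate between two points. We compare linear interpolation in the flow's latent space (mapped back to feature space) with straight-line interpolation directly in feature space.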


In [None]:
# Automatically reload imported modules before each cell runs, so edits to
# their source take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os
import numpy as np
import torch
import torch.distributions as dist
import matplotlib.pyplot as plt
from matplotlib_inline.backend_inline import set_matplotlib_formats
import matplotlib
import shapeflow as sf
import sklearn.datasets as datasets
import seaborn as sns
import extratorch as etorch

In [None]:
# make reproducible: seed every RNG used below. `datasets.make_moons` draws from
# NumPy's global RNG when no `random_state` is given, so seeding torch alone
# still produces different training data on every run.
SEED = 0
np.random.seed(SEED)
seed = torch.manual_seed(SEED)  # torch.manual_seed returns the default torch.Generator

# better plotting
set_matplotlib_formats("pdf", "svg")
matplotlib.rcParams.update({"font.size": 12})
plt.style.use("tableau-colorblind10")
sns.set_style("white")

Generate the data, standardize it, and plot it


In [None]:
# Sample the two-moons point cloud. Passing `random_state` explicitly keeps this
# cell idempotent: re-running it regenerates the same points instead of drawing
# fresh ones from NumPy's global RNG.
x, y = datasets.make_moons(8 * 1024, noise=0.05, random_state=0)

# NOTE: `mean` and `std` are scalars taken over *both* coordinates (no `axis`),
# so the standardization below is an isotropic shift-and-scale that preserves
# the shape of the moons. The same scalars are reused later in the notebook to
# map the interpolation endpoints into this space.
mean = x.mean()
std = x.std()
# One-hot encoding of the class labels as columns [y, 1 - y].
# NOTE(review): not used by any visible cell.
q = np.stack((y, np.abs(y - 1)), axis=-1)

# standardize
x_tensor = torch.as_tensor((x - mean) / std, dtype=torch.float32)

fig, ax = plt.subplots()
ax.scatter(
    x_tensor[:, 0],
    x_tensor[:, 1],
)
ax.set_title("Standardized moons data")
plt.show()

In [None]:
# Training dataset: a single-tensor TensorDataset, so each item is the 1-tuple `(x,)`.
data = torch.utils.data.TensorDataset(x_tensor)

Define model parameters

In [None]:
#######
DIR = "../figures/interpolate_moons/"
SET_NAME = "cnf_2"
PATH_FIGURES = os.path.join(DIR, SET_NAME)
# The figures below are saved into PATH_FIGURES; create it up front so
# `savefig` does not fail with FileNotFoundError on a fresh checkout.
os.makedirs(PATH_FIGURES, exist_ok=True)
########

# Standard-normal base distribution whose dimension matches one data sample.
event_shape = data[0][0].shape
base_dist = dist.MultivariateNormal(
    torch.zeros(event_shape[0]), torch.eye(event_shape[0])
)


def lr_scheduler(optim):
    """Build a ReduceLROnPlateau scheduler for `optim`.

    The learning rate is halved (factor=0.5) once the metric passed to
    `scheduler.step()` (presumably the training loss) has not decreased for
    more than 5 epochs.
    """
    # NOTE(review): `verbose` is deprecated in recent PyTorch releases (2.2+);
    # drop it if your version warns or rejects it.
    return torch.optim.lr_scheduler.ReduceLROnPlateau(
        optim, mode="min", factor=0.5, patience=5, verbose=True
    )

In [None]:
# Define the model: `stack` NDE transforms combined into one flow (compose=True).
# Every per-transform option is passed as a list with one entry per transform.
stack = 4
flow_layer_options = {
    "get_net": etorch.FFNN,
    "activation": "tanh",
    "n_hidden_layers": 3,
    "neurons": 8,
    "trace_estimator": "autograd",
}
flows = sf.nf.get_flow(
    base_dist=base_dist,
    inverse_model=True,
    compose=True,
    get_transform=sf.transforms.NDETransform,
    **{name: [value] * stack for name, value in flow_layer_options.items()},
)

Train model

In [None]:
# Fit the flow on `data`, using `sf.nf.monte_carlo_dkl_loss` as the training loss.
results = etorch.fit_module(
    model=flows,
    data=data,
    compute_loss=sf.nf.monte_carlo_dkl_loss,
    # optimizer settings
    optimizer="ADAM",
    learning_rate=0.01,
    lr_scheduler=lr_scheduler,
    # training loop settings
    num_epochs=100,
    batch_size=256,
    verbose=True,
)

In [None]:
# Unpack the trained model and the history object returned by `fit_module`.
[model, hist] = results

In [None]:
# Sanity check: mean log-likelihood under the fitted model of 100 base-distribution
# samples vs. the full training set. A well-fit model is expected to score the
# training data higher.
noise = base_dist.sample([100])
print("Log vals=")

for label, samples in (("Noise :", noise), ("Train data:", data[:][0])):
    print(label, model.log_prob(samples).mean().item())

### Interpolation

Interpolate between two points

In [None]:
def _lerp(start, end, weight):
    """Linear blend: `start * weight + end * (1 - weight)` (weight 1 -> start, 0 -> end)."""
    return start * weight + end * (1 - weight)


# Endpoints, mapped into the same standardized space as the training data.
p1 = (torch.tensor([[2.0, 0.25]]) - mean) / std
p2 = (torch.tensor([[0.0, 0.25]]) - mean) / std

# Latent-space codes of the two endpoints.
z1 = model.rnormalize(p1)
z2 = model.rnormalize(p2)

# 200 interpolation weights in [0, 1] as a column, so each row blends one point pair.
line = torch.linspace(0, 1, 200)[:, None]

# Interpolate in latent space, then push the path back to feature space.
interp_line_z = _lerp(z1, z2, line)
interp_line_x = model.bijector.forward(interp_line_z).detach()

# Straight-line interpolation directly in feature space, for comparison.
interp_line_x_feature = _lerp(p1, p2, line)

Plot both interpolation paths in feature space, on top of samples drawn from the trained model

In [None]:
# Overlay both interpolation paths on 10000 samples drawn from the trained model.
fig, ax = plt.subplots(1)
t_points = model.sample([10000]).detach().numpy()

ax.scatter(
    t_points[:, 0], t_points[:, 1],
    marker=".", alpha=0.3, color="grey", label="Generated samples",
)
# Latent-space path: solid line with a circle marker on every 10th point.
ax.plot(
    interp_line_x[:, 0], interp_line_x[:, 1], "o",
    ls="-", lw=2, markevery=10, label="Latent space interp.",
)
# Feature-space path: dash-dot straight line.
ax.plot(
    interp_line_x_feature[:, 0], interp_line_x_feature[:, 1], "-.",
    lw=2, label="Feature space interp.",
)

ax.set_xlim(-3, 3)
ax.set_ylim(-2, 2)
ax.set_aspect("equal", "box")
ax.axis("off")
ax.legend()

figure_path = os.path.join(PATH_FIGURES, "interpolation_path.pdf")
fig.savefig(figure_path, bbox_inches="tight", pad_inches=0)
plt.show()

Plot the log-density of the points along both interpolation paths

In [None]:
# Latent codes of the feature-space path.
# NOTE(review): not used by any visible cell; kept so the name stays available.
interp_line_z_naive = model.rnormalize(interp_line_x_feature).detach()

# Interpolation parameter for the x-axis (t = 0 -> p2, t = 1 -> p1, matching the
# blend weights used to build the paths). Its length is derived from the path
# itself rather than repeating the hard-coded 200, so changing the resolution
# in the interpolation cell cannot desynchronize the two.
t = np.linspace(0, 1, len(interp_line_x))

fig, ax = plt.subplots()
ax.plot(t, model.log_prob(interp_line_x).detach(), "-", label="Latent space interp.")
ax.plot(
    t,
    model.log_prob(interp_line_x_feature).detach(),
    "-.",
    label="Feature space interp.",
)
ax.legend()
ax.set_xlabel("$t$")
# Raw string: "\l" is an invalid escape sequence in a normal string literal
# (DeprecationWarning, SyntaxWarning since Python 3.12).
ax.set_ylabel(r"$\log p_{T(Z)}$")
fig.savefig(
    os.path.join(
        PATH_FIGURES,
        "interpolation_log_prob.pdf",
    ),
    bbox_inches="tight",
    pad_inches=0,
)
plt.show()