In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Display the shapeflow helper sub-module. This cell sits above the main import
# cell, so `sf` is imported here as well; otherwise a fresh kernel
# (Restart & Run All) fails with a NameError.
import shapeflow as sf

sf.utils

In [None]:
import copy
import os

import numpy as np
from flowtorch.distributions import Flow
import torch
import torch.distributions as dist
import matplotlib.pyplot as plt
from matplotlib_inline.backend_inline import set_matplotlib_formats
import matplotlib
import shapeflow as sf
import sklearn.datasets as datasets
import seaborn as sns
import extratorch as etorch
from signatureshape.animation.src.mayavi_animate import mayavi_animate

# Global figure styling used by every plot in the notebook.
matplotlib.rcParams.update({"font.size": 12})
# Inline figures are emitted as PDF and SVG. This call used to appear twice in
# this cell with identical arguments; one call is sufficient.
set_matplotlib_formats("pdf", "svg")
plt.style.use("tableau-colorblind10")
sns.set_style("white")

# make reproducible (this seeds torch's global RNG only)
seed = torch.manual_seed(0)

In [None]:
# Two-moons toy data set (8192 points). `random_state` is fixed because the
# sampler draws from NumPy's RNG, which torch.manual_seed does not control;
# without it the data differ on every run.
x, y = datasets.make_moons(8 * 1024, noise=0.05, random_state=0)

In [None]:
# Quick look at the raw samples: first column on the x-axis, second on the y-axis.
plt.scatter(x=x[:, 0], y=x[:, 1])

In [None]:
# Convert the samples to a float32 tensor and wrap them in a dataset that
# holds this single tensor (the labels `y` are not included).
x_tensor = torch.tensor(x, dtype=torch.float32)
data = torch.utils.data.TensorDataset(x_tensor)

In [None]:
# ---- output directory for figure frames ---------------------------------
DIR = "../figures/frames/"
SET_NAME = "walk_residual"
PATH_FIGURES = os.path.join(DIR, SET_NAME)
# exist_ok avoids the separate existence check (and the race it implies).
os.makedirs(PATH_FIGURES, exist_ok=True)
# --------------------------------------------------------------------------

# Shape of one sample in the dataset; its first entry sets the dimension of
# the standard normal used as base distribution.
event_shape = data[0][0].shape
base_dist = dist.MultivariateNormal(
    torch.zeros(event_shape[0]), torch.eye(event_shape[0])
)


def lr_scheduler(optim):
    """Build a ReduceLROnPlateau scheduler for the given optimizer.

    The learning rate is halved (factor=0.5) after 5 epochs without
    improvement of the monitored quantity (mode="min").
    """
    return torch.optim.lr_scheduler.ReduceLROnPlateau(
        optim, mode="min", factor=0.5, patience=5, verbose=True
    )


base_dist.batch_shape

In [None]:
# Show the last entry of the per-sample shape (event_shape = data[0][0].shape).
event_shape[-1]

In [None]:
# define model
# `stack` transforms are requested with compose=True; every per-transform
# option is passed as a list holding one entry per transform.
stack = 4
flows = sf.nf.get_flow(
    base_dist=base_dist,
    inverse_model=True,
    compose=True,
    get_transform=sf.transforms.NDETransform,
    get_net=[etorch.FFNN for _ in range(stack)],
    activation=["tanh" for _ in range(stack)],
    n_hidden_layers=[3 for _ in range(stack)],
    neurons=[8 for _ in range(stack)],
    trace_estimator=["autograd_trace" for _ in range(stack)],
)

# Alternative kept for reference: a single, wider transform (compose=False).
# flows = sf.nf.get_flow(
#     base_dist=base_dist,
#     inverse_model=True,
#     compose=False,
#     get_transform=sf.transforms.NDETransform,
#     get_net=etorch.FFNN,
#     activation="tanh",
#     n_hidden_layers=3,
#     neurons=16,
# )

In [None]:
# Train the flow on the moons data set.
results = etorch.fit_module(
    # what is trained, and on what
    model=flows,
    data=data,
    compute_loss=sf.nf.monte_carlo_dkl_loss,
    # optimisation settings
    optimizer="ADAM",
    learning_rate=0.01,
    lr_scheduler=lr_scheduler,
    batch_size=256,
    num_epochs=100,
    # progress reporting
    verbose=True,
)

In [None]:
# Split the fit result into its two parts: `model` is used for all sampling
# and log-prob evaluation below, `hist` is plotted in the next cell.
model, hist = results

Test that the wrapper works

In [None]:
# Plot the history object returned by the fit and render the figure.
hist.plot()
plt.show()

In [None]:
# Compare the mean model log-probability of pure base-distribution noise with
# that of the training data.
noise = base_dist.sample([100])
print("Log vals=")

# Average over all 100 noise samples; previously only the first sample was
# evaluated, which made the reported "mean" a single-point value.
print("Noise :", model.log_prob(noise).mean().item())
print("Train data:", model.log_prob(data[:][0]).mean().item())

In [None]:
# Points along a lattice of straight lines covering the square
# [xmin, xmax] x [ymin, ymax]: `pts` vertical and `pts` horizontal lines, each
# sampled at `gridlines` positions.
# The bounds were previously inverted (xmin = 2, xmax = -2); they now satisfy
# min < max, which yields the same set of points in ascending order.
xmin = -2
xmax = -xmin
ymax = xmax
ymin = xmin
pts = 7
gridlines = pts * 1000
xpts = np.linspace(xmin, xmax, pts)
ypts = np.linspace(ymin, ymax, pts)
xgrid = np.linspace(xmin, xmax, gridlines)
ygrid = np.linspace(ymin, ymax, gridlines)
# Vertical lines: x fixed to one of `xpts`, y runs over the dense `ygrid`.
xlines = np.stack([a.ravel() for a in np.meshgrid(xpts, ygrid)])
# Horizontal lines: y fixed to one of `ypts`, x runs over the dense `xgrid`.
ylines = np.stack([a.ravel() for a in np.meshgrid(xgrid, ypts)])
# Final layout: one (x, y) pair per row, as float32.
grid = torch.as_tensor(np.concatenate([xlines, ylines], 1).T, dtype=torch.float32)

In [None]:
# End points of the walk, given in feature space.
p1 = torch.tensor([[1.0, -0.5]])
p2 = torch.tensor([[0.0, 1.0]])

# Their images under the model's `rnormalize` map.
z1, z2 = model.rnormalize(p1), model.rnormalize(p2)

# 200 interpolation weights as a column vector; weight 0 gives the second
# point, weight 1 the first.
line = torch.linspace(0, 1, 200).unsqueeze(1)
# Straight line between z1 and z2, pushed through the bijector's forward map ...
interp_line_z = line * z1 + (1 - line) * z2
interp_line_x = model.bijector.forward(interp_line_z).detach()
# ... versus a straight line drawn directly between p1 and p2.
interp_line_x_naive = line * p1 + (1 - line) * p2

In [None]:
# Background cloud: 10000 samples drawn from the trained model.
fig, ax = plt.subplots(1)
t_points = model.sample([10000]).detach().numpy()
ax.scatter(t_points[:, 0], t_points[:, 1], marker=".", alpha=0.3, color="grey")

# The two interpolation paths on top of the cloud.
ax.plot(
    interp_line_x[:, 0],
    interp_line_x[:, 1],
    "o",
    linestyle="-",
    linewidth=2,
    markevery=10,
    label="Latent space interp.",
)
ax.plot(
    interp_line_x_naive[:, 0],
    interp_line_x_naive[:, 1],
    "-.",
    linewidth=2,
    label="Feature space interp.",
)

ax.set_xlim(-1.5, 2.5)
# plt.ylim(-2,2)
ax.set_aspect("equal")
# Optional end-point markers, currently disabled:
# ax.scatter(p1[:,0],p1[:,1],marker="x",label="p1", color="green")
# ax.scatter(p2[:,0],p2[:,1], marker="x",label="p1",color="green")
# ax.scatter(*p2, label="p2")
ax.legend()
plt.show()

In [None]:
# Image of the feature-space line under `rnormalize` (plotted in the next cell).
interp_line_z_naive = model.rnormalize(interp_line_x_naive).detach()

# Model log-probability along both paths, against the interpolation weight.
t = np.linspace(0, 1, 200)
for curve, fmt in ((interp_line_x, "-"), (interp_line_x_naive, "-.")):
    plt.plot(t, model.log_prob(curve).detach(), fmt)
plt.show()

In [None]:
# Both paths in z-coordinates. `interp_line_z` is detached for plotting,
# which gives the same values as plotting it under torch.no_grad().
plt.plot(interp_line_z.detach()[:, 0], interp_line_z.detach()[:, 1], ".")
plt.plot(interp_line_z_naive[:, 0], interp_line_z_naive[:, 1], ".")
plt.xlim(-1, 1)
plt.ylim(-1, 1)
# Reference circles around the origin (7 * t sweeps more than one full turn).
for radius in (0.5, 0.2):
    plt.plot(radius * np.cos(7 * t), radius * np.sin(7 * t), linewidth=1)
plt.show()

In [None]:
# Second pair of end points (same height, different x positions).
p1 = torch.tensor([[2.0, 0.25]])
p2 = torch.tensor([[0.0, 0.25]])
z1, z2 = model.rnormalize(p1), model.rnormalize(p2)

# 200 interpolation weights as a column vector.
line = torch.linspace(0, 1, 200).unsqueeze(1)
# Interpolate between z1 and z2 and map the result through the bijector's forward ...
interp_line_z = line * z1 + (1 - line) * z2
interp_line_x = model.bijector.forward(interp_line_z).detach()
# ... and interpolate directly between p1 and p2 for comparison.
interp_line_x_naive = line * p1 + (1 - line) * p2

In [None]:
# Fresh model samples as grey background, with both paths drawn on top.
fig, ax = plt.subplots(1)
t_points = model.sample([10000]).detach().numpy()
ax.scatter(t_points[:, 0], t_points[:, 1], marker=".", alpha=0.3, color="grey")
# Every 10th point of the bijector-mapped path, marked with crosses.
ax.plot(interp_line_x[::10, 0], interp_line_x[::10, 1], "x")
# The direct feature-space line, drawn dash-dotted.
ax.plot(interp_line_x_naive[:, 0], interp_line_x_naive[:, 1], "-.")
plt.show()

In [None]:
# Model log-probability along both paths of the second pair.
t = np.linspace(0, 1, 200)
for curve, fmt in ((interp_line_x, "x"), (interp_line_x_naive, "-.")):
    plt.plot(t, model.log_prob(curve).detach(), fmt)
plt.show()

In [None]:
# Recompute the z-image of the *current* feature-space line. Without this the
# plot silently reuses `interp_line_z_naive` from the first pair of end
# points, because p1/p2 were redefined since it was last computed.
interp_line_z_naive = model.rnormalize(interp_line_x_naive).detach()

# Both paths in z-coordinates.
with torch.no_grad():
    plt.plot(interp_line_z[:, 0], interp_line_z[:, 1], ".")
plt.plot(interp_line_z_naive[:, 0], interp_line_z_naive[:, 1], ".")
plt.xlim(-2, 2)
plt.ylim(-2, 2)
# Reference circles of radius 1 and 2 around the origin.
plt.plot(np.cos(7 * t), np.sin(7 * t), linewidth=1)
plt.plot(2 * np.cos(7 * t), 2 * np.sin(7 * t), linewidth=1)
plt.show()

In [None]:
# Evaluate the model log-density on an n x n lattice over
# [-2.5, 2.5] x [-1.5, 1.5]. The values are only consumed by the disabled
# contour call below.
n = 100
x_axis = np.linspace(-2.5, 2.5, n)
y_axis = np.linspace(-1.5, 1.5, n)
X, Y = np.meshgrid(x_axis, y_axis)
grid = np.column_stack((X.ravel(), Y.ravel()))
grid_tensor = torch.as_tensor(grid, dtype=torch.float32)
grid_log_prob = model.log_prob(grid_tensor)
Z_log_prob = grid_log_prob.detach().numpy().reshape(n, n)

fig, ax = plt.subplots(1)
# sns.kdeplot(x=t_points[:, 0],  y=t_points[:, 1], shade=True)
# plt.contour(X,Y,Z_log_prob, cmap="viridis", levels=4)
ax.scatter(t_points[:, 0], t_points[:, 1], marker=".", alpha=0.2, color="grey")
# Both paths, subsampled to every 10th point.
ax.plot(interp_line_x[::10, 0], interp_line_x[::10, 1], "--", linewidth=3)
ax.plot(
    interp_line_x_naive[::10, 0], interp_line_x_naive[::10, 1], "-.", linewidth=3
)
ax.set_xlim(-1.5, 2.5)
# plt.ylim(-2,2)
ax.set_aspect("equal")

In [None]:
# Base-distribution log-density along both paths in z-coordinates: first the
# straight z-line, then the `rnormalize` image of the feature-space line.
plt.plot(
    np.linspace(0, 1, 200),
    base_dist.log_prob(interp_line_z).detach(),
)
plt.plot(
    np.linspace(0, 1, 200),
    base_dist.log_prob(model.rnormalize(interp_line_x_naive).detach()),
)
plt.show()

In [None]:
# Model log-density on an n x n lattice over [-2, 2]^2.
# The mgrid step count is derived from `n` (a complex step means "number of
# points") so that the lattice size and the reshape below cannot drift apart.
n = 200
X, Y = np.mgrid[-2 : 2 : n * 1j, -2 : 2 : n * 1j]
grid = np.stack((X.ravel(), Y.ravel()), axis=1)
Z_0 = (
    model.log_prob(torch.tensor(grid, dtype=torch.float32))
    .detach()
    .numpy()
    .reshape((n, n))
)
# Disabled comparison against a second model (`models` is not defined in this
# notebook):
# Z_1 = (
#     models[1]
#     .log_prob(torch.tensor(grid, dtype=torch.float32))
#     .detach()
#     .numpy()
#     .reshape((n, n))
# )
# Z = np.where(Z_0 > Z_1, 1, 0)
# plt.contourf(X, Y, Z, levels=2)
# plt.show()
# plt.scatter(t_points[:, 0], t_points[:, 1],marker=".",alpha=0.5, )
# # plt.show()
# points =models[1].base_dist.sample([1000])# p1*line - p1*(1-line)
# t_points = models[1].bijector.forward(points).detach().numpy()
# plt.scatter(t_points[:, 0], t_points[:, 1],marker=".",alpha=0.5,)
# plt.show()

In [None]:
# Training samples once more, for visual comparison with the figures above.
plt.scatter(x=x[:, 0], y=x[:, 1])
plt.show()