In [None]:
# Load IPython's autoreload extension and set it to mode 2, so edited modules
# are re-imported automatically before each cell executes.
%load_ext autoreload
%autoreload 2

In [None]:
import copy
import os

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import sklearn.datasets as datasets
import torch
import torch.distributions as dist
import torchdyn.nn
from flowtorch.distributions import Flow
from matplotlib_inline.backend_inline import set_matplotlib_formats

import shapeflow as sf
from deepthermal.FFNN_model import fit_FFNN, FFNN
from deepthermal.plotting import plot_result
from signatureshape.animation.src.mayavi_animate import mayavi_animate

# Make runs reproducible: seed every global RNG the notebook draws from.
# torch covers tensor sampling; numpy's global RNG is the one scikit-learn
# falls back to whenever no explicit random_state is supplied.
SEED = 0
np.random.seed(SEED)
seed = torch.manual_seed(SEED)  # `seed` keeps holding the returned torch.Generator

# Render inline figures as vector graphics (PDF and SVG).
set_matplotlib_formats("pdf", "svg")

In [None]:
# Two interleaving half-moons: `x` holds the 2-D points, `y` the 0/1 moon label.
# A fixed random_state keeps the sample identical across kernel restarts and
# cell re-runs; without it scikit-learn draws from numpy's global RNG.
x, y = datasets.make_moons(1024, noise=0.05, random_state=0)

In [None]:
# Overlay the raw points and their standardised version on the same axes.
plt.scatter(*x.T)
x = (x - np.mean(x)) / np.std(x)  # a single global mean/std over both coordinates
plt.scatter(*x.T)

# Two-column array: column 0 is the label, column 1 is |label - 1|.
p = np.stack([y, np.abs(y - 1)], axis=1)
p

In [None]:
# define data
# Points and label indicators as float32 tensors.
x_tensor = torch.as_tensor(x, dtype=torch.float32)
p_tensor = torch.as_tensor(p, dtype=torch.float32)

# Two contexts; every sample gets the same fixed prior weights (0.6, 0.4).
c = torch.arange(2)
contexts = len(c)
priors = torch.zeros(len(x_tensor), contexts)
priors[:, 0], priors[:, 1] = 0.6, 0.4

# Row-normalise the priors; note this rebinds `p` to the per-sample weights.
p = priors / priors.sum(dim=1, keepdim=True)

# Each dataset item is (point, normalised weights, raw priors).
data = torch.utils.data.TensorDataset(x_tensor, p, priors)

In [None]:
#######
# Output directory for figures produced by this experiment.
DIR = "../figures/frames/"
SET_NAME = "walk_residual"
PATH_FIGURES = os.path.join(DIR, SET_NAME)
# exist_ok avoids the check-then-create race and is a no-op on cell re-runs.
os.makedirs(PATH_FIGURES, exist_ok=True)
########
# Number of folds handed to the fitting routine.
FOLDS = 5

# Standard normal base distribution sized from the first data point.
event_shape = data[0][0].shape
base_dist = dist.MultivariateNormal(
    torch.zeros(event_shape[0]), torch.eye(event_shape[0])
)


def lr_scheduler(optim):
    """Create a ReduceLROnPlateau scheduler for the given optimizer.

    The scheduler tracks a quantity to be minimised and multiplies the
    learning rate by 0.5 when it plateaus (patience of 5).

    Args:
        optim: the optimizer whose learning rate should be scheduled.

    Returns:
        A ``torch.optim.lr_scheduler.ReduceLROnPlateau`` bound to ``optim``.
    """
    # NOTE(review): `verbose` is deprecated in recent PyTorch releases —
    # confirm against the installed version before upgrading.
    return torch.optim.lr_scheduler.ReduceLROnPlateau(
        optim, mode="min", factor=0.5, patience=5, verbose=True
    )


base_dist.batch_shape

In [None]:
# Display the size of the last axis of a single data point's shape.
event_shape[-1]

In [None]:
# define model

# Alternative configuration kept for reference: a composed stack of transforms.
# stack = 4
# flows = sf.nf.get_flow(
#     base_dist=base_dist,
#     inverse_model=True,
#     compose=True,
#     get_transform=sf.transforms.NDETransform,
#     get_net=[FFNN] * stack,
#     activation=["tanh"] * stack,
#     n_hidden_layers=[3] * stack,
#     neurons=[8] * stack,
# )

# Active configuration: collect the settings first, then build the flows.
flow_kwargs = dict(
    base_dist=base_dist,
    inverse_model=True,
    compose=False,
    get_transform=sf.transforms.NDETransform,
    get_net=FFNN,
    activation="tanh",
    n_hidden_layers=3,
    neurons=64,
    num_flows=2,
)
flows = sf.nf.get_flow(**flow_kwargs)

In [None]:
# Build the loss up front, then hand everything to the trainer.
epsilon = 0.5
dkl_loss = sf.nf.get_monte_carlo_dkl_loss_conditioned(epsilon=epsilon)

results = fit_FFNN(
    model=flows,
    data=data,
    folds=FOLDS,
    batch_size=128,
    num_epochs=100,
    compute_loss=dkl_loss,
    optimizer="ADAM",
    # optimizer=lambda p : torch.optim.AdamW(p, lr=2e-3, weight_decay=1e-5),
    # post_epoch=sf.nf.get_post_epoch_update_p(epsilon=epsilon),
    learning_rate=0.001,
    lr_scheduler=lr_scheduler,
    verbose=True,
)

In [None]:
# Unpack the three values returned by fit_FFNN.
# NOTE(review): presumably one fitted model and one history per fold — confirm against fit_FFNN.
models, loss_history, val_history = results

Test that the wrapper works

In [None]:
# Keep the first fitted model at hand and plot the recorded loss history.
model = models[0]
loss_fig, loss_ax = plt.subplots()
loss_ax.plot(loss_history)
plt.show()

In [None]:
# Sanity check: compare the mean log-density `model` assigns to samples from
# the base distribution with the one it assigns to the data points.
noise = base_dist.sample([100])
print("Log vals=")

# Average over every drawn noise sample rather than a single one, so the
# reported number is a stable estimate.
print("Noise :", model.log_prob(noise).mean().item())
print("Train data:", model.log_prob(data[:][0]).mean().item())

In [None]:
# Build a lattice of straight lines over the square [-2, 2] x [-2, 2]:
# `pts` lines per direction, each sampled at `gridlines` positions.
xmin = -2  # lower bound; the upper bound mirrors it so the square is centred on 0
xmax = -xmin
ymin = xmin
ymax = xmax
pts = 7
gridlines = pts * 1000

# Coarse positions of the lines and dense positions along each line.
xpts = np.linspace(xmin, xmax, pts)
ypts = np.linspace(ymin, ymax, pts)
xgrid = np.linspace(xmin, xmax, gridlines)
ygrid = np.linspace(ymin, ymax, gridlines)

# xlines: x fixed at xpts, y dense. ylines: x dense, y fixed at ypts.
xlines = np.stack([a.ravel() for a in np.meshgrid(xpts, ygrid)])
ylines = np.stack([a.ravel() for a in np.meshgrid(xgrid, ypts)])

# One (x, y) point per row, as float32.
grid = torch.as_tensor(np.concatenate([xlines, ylines], 1).T, dtype=torch.float32)

In [None]:
# Two (2, 1) column vectors and a 1000-step parameter running from 0 to 1.
p1 = torch.transpose(torch.tensor([[1.0, -0.5]]), 0, 1)
p2 = torch.transpose(torch.tensor([[0.0, 1]]), 0, 1)
line = np.linspace(0.0, 1.0, num=1000)

In [None]:
# p1.T * line

In [None]:
# For each of the first two models: draw 1000 base samples, push them through
# that model's bijector, and scatter the result. The data is overlaid faintly.
for fold in (0, 1):
    flow = models[fold]
    points = flow.base_dist.sample([1000])
    t_points = flow.bijector.forward(points).detach().numpy()
    plt.scatter(t_points[:, 0], t_points[:, 1], marker=".", alpha=0.5)

plt.scatter(x[:, 0], x[:, 1], alpha=0.1)
plt.show()

In [None]:
# Evaluate both models' log-densities on an n x n lattice over [-2, 2]^2 and
# shade the region where models[0] assigns the larger value.
n = 200
X, Y = np.mgrid[-2:2:200j, -2:2:200j]
grid = np.column_stack((X.ravel(), Y.ravel()))


def _log_prob_on_grid(flow):
    """Return `flow.log_prob` evaluated on `grid`, reshaped to (n, n)."""
    grid_tensor = torch.tensor(grid, dtype=torch.float32)
    return flow.log_prob(grid_tensor).detach().numpy().reshape((n, n))


Z_0 = _log_prob_on_grid(models[0])
Z_1 = _log_prob_on_grid(models[1])

# 1 where the first model wins, 0 elsewhere.
Z = np.where(Z_0 > Z_1, 1, 0)
plt.contourf(X, Y, Z, levels=2)
plt.show()

In [None]:
# Visualise the flow of the first model: integrate 100 base samples through
# the second sub-module of its bijector network and plot the resulting paths.
# (torchdyn.nn is imported in the notebook's top import cell.)
sample = models[0].base_dist.sample([100])
traj = (
    models[0]
    .bijector.model.model[1]
    .trajectory(torchdyn.nn.Augmenter(1, 1)(sample), t_span=torch.linspace(1, 0, 100))
    .detach()
    .cpu()
)
# Drop the first channel of the last axis.
# NOTE(review): presumably the channel added by Augmenter to carry the
# Jacobian trace — confirm against torchdyn.
traj = traj[:, :, 1:]

n = 2000  # upper bound on plotted samples; slicing clamps to those available
plt.figure(figsize=(6, 6))
# Labels are attached to each artist so the legend cannot drift out of order.
plt.scatter(
    sample[:n, 0], sample[:n, 1], s=10, alpha=0.8, c="black", label="Prior sample z(S)"
)
plt.scatter(traj[:, :n, 0], traj[:, :n, 1], s=0.2, alpha=0.2, c="olive", label="Flow")
plt.scatter(traj[-1, :n, 0], traj[-1, :n, 1], s=4, alpha=1, c="blue", label="z(0)")
plt.legend()
plt.show()

In [None]:
# Scatter the data points on their own.
plt.scatter(*x.T)
plt.show()