In [None]:
# Reload imported modules (e.g. local packages under development) before executing each cell.
%load_ext autoreload
%autoreload 2

In [None]:
import copy
import os

import numpy as np
from flowtorch.distributions import Flow
import torch
import torch.distributions as dist
import matplotlib.pyplot as plt
from matplotlib_inline.backend_inline import set_matplotlib_formats
import shapeflow as sf
import sklearn.datasets as datasets
import seaborn as sns

from signatureshape.animation.src.mayavi_animate import mayavi_animate

from deepthermal.FFNN_model import fit_FFNN, FFNN
from deepthermal.plotting import plot_result
import neural_reparam.models as nrmodels

# make reproducible: seed every global RNG the notebook draws from
SEED = 0
seed = torch.manual_seed(SEED)  # torch.manual_seed returns the default torch.Generator
# sklearn.datasets.make_moons(random_state=None) samples from NumPy's global RNG,
# so seeding torch alone does not make the training data reproducible.
np.random.seed(SEED)

# better formats
set_matplotlib_formats("pdf", "svg")

In [None]:
# Two-moons toy dataset: x holds the 2-D points, y the class labels (not used for training).
# A fixed random_state keeps the data identical no matter when or how often this cell runs.
N_SAMPLES = 1024 * 8
x, y = datasets.make_moons(N_SAMPLES, noise=0.06, random_state=0)

In [None]:
# Visual sanity check of the raw training points.
fig, ax = plt.subplots()
ax.scatter(x[:, 0], x[:, 1], alpha=0.5)
ax.set(title="Two-moons training data", xlabel="$x_1$", ylabel="$x_2$");

In [None]:
# Wrap the two-moons points as a float32 TensorDataset; data[i] is a 1-tuple (point,).
data = torch.utils.data.TensorDataset(
    torch.as_tensor(x, dtype=torch.float32),
)

In [None]:
# Output location for rendered frames/figures; created if missing (no-op if it exists).
DIR = "../figures/frames/"
SET_NAME = "walk_residual"  # NOTE(review): name looks inherited from another experiment — confirm
PATH_FIGURES = os.path.join(DIR, SET_NAME)
os.makedirs(PATH_FIGURES, exist_ok=True)

# Passed as `folds` to fit_FFNN when training.
FOLDS = 5

# Standard-normal base distribution with the same dimensionality as one data point.
event_shape = data[0][0].shape
event_dim = event_shape[0]
base_dist = dist.MultivariateNormal(torch.zeros(event_dim), torch.eye(event_dim))


def lr_scheduler(optim):
    """Build a ReduceLROnPlateau scheduler (mode="min", factor=0.5, patience=10) for `optim`.

    `verbose` is not passed: it is deprecated since PyTorch 2.2; call
    `scheduler.get_last_lr()` to inspect the current learning rate instead.
    """
    return torch.optim.lr_scheduler.ReduceLROnPlateau(
        optim, mode="min", factor=0.5, patience=10
    )


base_dist.batch_shape

In [None]:
# Size of the last (feature) axis of a single sample.
event_shape[len(event_shape) - 1]

In [None]:
# define model: a flowtorch Flow whose bijector is built from shapeflow NDETransforms.
# COMPOSE=True composes `stack` NDE transforms, each with a small FFNN (3 hidden layers x 8 neurons);
# COMPOSE=False builds a single NDE transform with a 3 x 64 FFNN and trace_est="auto".
# Only the selected variant is constructed, so no throw-away networks are built
# (building them presumably also consumes torch RNG draws and shifts the initialisation).
# NOTE(review): the sampling cell accesses `model.bijector.model.nde`, which presumably
# only exists for the non-composed variant — confirm before switching COMPOSE on.
COMPOSE = False
stack = 4

if COMPOSE:
    bijector = sf.nf.get_bijector(
        compose=True,
        get_transform=sf.transforms.NDETransform,
        get_net=[FFNN] * stack,
        activation=["tanh"] * stack,
        n_hidden_layers=[3] * stack,
        neurons=[8] * stack,
    )
else:
    bijector = sf.nf.get_bijector(
        compose=False,
        get_transform=sf.transforms.NDETransform,
        get_net=FFNN,
        activation="tanh",
        trace_est="auto",
        n_hidden_layers=3,
        neurons=64,
    )

flow_model = Flow(base_dist=base_dist, bijector=bijector)

In [None]:
# Fit the flow to the moons data by minimising sf.nf.monte_carlo_dkl_loss with Adam.
# Alternatives tried earlier (kept for reference, not active):
#   optimizer=lambda p: torch.optim.AdamW(p, lr=2e-3, weight_decay=1e-5)
#   lr_scheduler=lr_scheduler   (ReduceLROnPlateau factory from the config cell)
results = fit_FFNN(
    # what to train, on which data, with how many folds
    model=flow_model,
    data=data,
    folds=FOLDS,
    # objective and optimiser
    compute_loss=sf.nf.monte_carlo_dkl_loss,
    optimizer="ADAM",
    learning_rate=0.01,
    # schedule and logging
    batch_size=256,
    num_epochs=20,
    verbose=True,
)

In [None]:
# fit_FFNN's result unpacks into the trained model, the training-loss history and the validation history.
(model, loss_history, val_history) = results

Test that the wrapper works

In [None]:
# Training loss recorded by fit_FFNN (objective: sf.nf.monte_carlo_dkl_loss).
fig, ax = plt.subplots()
ax.plot(loss_history)
ax.set(title="Training loss", xlabel="recorded step", ylabel="Monte-Carlo KL loss")
plt.show()

In [None]:
# Mean log-density the trained flow assigns to base-distribution noise vs. the training data.
# Average over all N_NOISE draws; scoring a single draw gives a very noisy estimate.
N_NOISE = 100
noise = base_dist.sample([N_NOISE])
print("Log vals=")

print("Noise :", model.log_prob(noise).mean().item())
print("Train data:", model.log_prob(data[:][0]).mean().item())

In [None]:
# Draw samples from the trained flow and plot them for comparison with the data.
# The NDE integration span is set to run from t=1 back to t=0 for sampling only.
# The original span is restored afterwards (even on error), so the trained model is not
# left mutated for cells that run later or are re-run.
nde = model.bijector.model.nde
original_t_span = nde.t_span
nde.t_span = torch.linspace(1, 0, 2)
try:
    sample = model.sample([10000])
finally:
    nde.t_span = original_t_span

print(model.bijector.model)
print(sample.shape)

fig, ax = plt.subplots()
ax.scatter(sample[:, 0], sample[:, 1], alpha=0.5)
ax.set(title="Samples from the trained flow", xlabel="$x_1$", ylabel="$x_2$")
plt.show()

In [None]:
# Training points, plotted for side-by-side comparison with the flow samples.
# Unpacked into a new name so the make_moons labels `y` are not overwritten.
(train_points,) = data[:]
fig, ax = plt.subplots()
ax.scatter(train_points[:, 0], train_points[:, 1], alpha=0.5)
ax.set(title="Training data", xlabel="$x_1$", ylabel="$x_2$")
plt.show()