In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

### Test wrapper on moons

Test that we can use `flowtorch` with transformations from `nflows`. 

Example is modified from [nflows](https://github.com/bayesiains/nflows)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib_inline.backend_inline import set_matplotlib_formats
from tqdm import tqdm
import os
from itertools import chain
import seaborn as sns
import torch

import flowtorch.distributions as ftdist

from nflows.flows.base import Flow
from nflows.distributions.normal import StandardNormal
from nflows.transforms.base import CompositeTransform
from nflows.transforms.autoregressive import MaskedAffineAutoregressiveTransform
from nflows.transforms.permutations import ReversePermutation
from nflows.flows import realnvp, autoregressive

from animation import fetch_animation_id_set, fetch_animations
from so3.curves import move_origin_to_zero as move_origin_to_zero_so3
from so3.transformations import hatinv
from linear import animation_to_SO3
from linear.curves import move_origin_to_zero

from deepthermal.validation import (
    create_subdictionary_iterator,
    k_fold_cv_grid,
    add_dictionary_iterators,
)
from deepthermal.FFNN_model import fit_FFNN
from deepthermal.plotting import plot_result

from shapeflow import ModuleBijector, WrapInverseModel, monte_carlo_dkl_loss

# Make runs reproducible by seeding torch's global RNG.
# Note: `seed` holds the return value of torch.manual_seed (a torch.Generator),
# not the integer 0.
seed = torch.manual_seed(0)

# Render inline figures as PDF and SVG instead of the default raster format.
set_matplotlib_formats("pdf", "svg")

In [None]:
# len(fetch_animation_id_set(description="walk"))

# # fetch data as so3
# print("loading mocap data:")
# animation_tuples = fetch_animations(1000, description="walk")

In [None]:
# # reshape data
# skel, anim, desk = animation_tuples[0]
# # root elem
# num_bones = len(skel.bones) + 1
# num_anims =  len(animation_tuples)
# total_frames = sum(map(lambda t : len(t[1].get_frames()), animation_tuples))

# for s, a, d in animation_tuples:
#     assert len(s.bones) +1 == num_bones , f"not {num_bones} bones"

# anim_array = torch.cat([torch.as_tensor(animation_to_SO3(*t[:2])) for t in  tqdm(animation_tuples)], 1)
# frames = torch.moveaxis(anim_array,0,1)

In [None]:
# The frames are expensive to build, so they are cached on disk. To regenerate
# the cache, run:
# torch.save(frames, 'walk_frames.pt')
#
# Load the cached tensor, cast it to float32 and collapse every dimension after
# the first into one, giving a single flat feature vector per frame.
frames = torch.flatten(torch.load("walk_frames.pt").float(), start_dim=1)

In [None]:
# Wrap the flattened frames in a TensorDataset; this is the `data` handed to
# k_fold_cv_grid further down.
data = torch.utils.data.TensorDataset(frames)

In [None]:
# Model definition: a stack of masked autoregressive layers from nflows,
# wrapped so it can serve as the bijector of a flowtorch Flow.
num_layers = 5
event_shape = 23 * 3 * 3

# Base distribution: standard normal over the whole flattened event.
base_dist = torch.distributions.Independent(
    torch.distributions.Normal(
        loc=torch.zeros(event_shape),
        scale=torch.ones(event_shape),
    ),
    reinterpreted_batch_ndims=1,
)

# Each layer is a feature reversal followed by a masked affine autoregressive
# transform; layers are created in the same order as they are applied.
transforms = [
    layer
    for _ in range(num_layers)
    for layer in (
        ReversePermutation(features=event_shape),
        MaskedAffineAutoregressiveTransform(features=event_shape, hidden_features=4),
    )
]
transform = CompositeTransform(transforms)
bijector = WrapInverseModel(model=transform)

# flow = ftdist.Flow(bijector=bijector, base_dist=base_dist)
# optimizer = optim.Adam(flow.parameters())

In [None]:
#######
# Output location for the figures produced by this notebook.
DIR = "../figures/curve_1/"
SET_NAME = "walk_1"
PATH_FIGURES = os.path.join(DIR, SET_NAME)
# exist_ok=True makes the cell idempotent and avoids the check-then-create race
# of a separate os.path.exists() test.
os.makedirs(PATH_FIGURES, exist_ok=True)
########


# Number of cross-validation folds.
FOLDS = 2


def lr_scheduler(optim):
    """Build a ReduceLROnPlateau scheduler for the optimizer ``optim``.

    The scheduler watches a quantity that should decrease (``mode="min"``),
    tolerates 5 epochs without improvement (``patience=5``) and then scales
    the learning rate by 0.2 (``factor=0.2``).

    Currently unused: the ``"lr_scheduler"`` entry in
    ``TRAINING_PARAMS_EXPERIMENT`` below is commented out.
    """
    return torch.optim.lr_scheduler.ReduceLROnPlateau(
        optim, mode="min", factor=0.2, patience=5, verbose=True
    )


# Model configuration. Every value is a list of candidates for that key.
MODEL_PARAMS = {"model": [ftdist.Flow]}
MODEL_PARAMS_EXPERIMENT = {
    "bijector": [bijector],
    "base_dist": [base_dist],
}

# Training configuration shared by all experiments.
TRAINING_PARAMS = {
    "batch_size": [100],
    "regularization_param": [0.0],
    "compute_loss": [monte_carlo_dkl_loss],
}
# extend the previous dict with the zip of this
TRAINING_PARAMS_EXPERIMENT = {
    "optimizer": ["ADAM"],
    "num_epochs": [10],
    "learning_rate": [0.01],
    # "lr_scheduler": [lr_scheduler],
}

In [None]:
# create iterators
# Model parameters: MODEL_PARAMS uses create_subdictionary_iterator's default
# mode, MODEL_PARAMS_EXPERIMENT is built with product=False.
# NOTE(review): presumably the default takes the Cartesian product of the value
# lists while product=False zips them — confirm in deepthermal.validation.
model_params_iter_1 = create_subdictionary_iterator(MODEL_PARAMS)
# model_params_iter = chain.from_iterable((model_params_iter_1, model_params_iter_2))

model_exp_iter = create_subdictionary_iterator(MODEL_PARAMS_EXPERIMENT, product=False)
# Combine the experiment-specific and the shared model parameters.
exp_model_params_iter = add_dictionary_iterators(model_exp_iter, model_params_iter_1)

# Training parameters: same construction as for the model parameters above.
training_params_iter = create_subdictionary_iterator(TRAINING_PARAMS)
training_exp_iter = create_subdictionary_iterator(
    TRAINING_PARAMS_EXPERIMENT, product=False
)
exp_training_params_iter = add_dictionary_iterators(
    training_exp_iter, training_params_iter
)

In [None]:
# Run the k-fold cross-validation grid over the model and training parameter
# iterators, fitting each configuration with fit_FFNN on `data`.
# NOTE(review): the meaning of `trials` and `partial` is defined by
# deepthermal.validation.k_fold_cv_grid — confirm there.
cv_results = k_fold_cv_grid(
    model_params=exp_model_params_iter,
    fit=fit_FFNN,
    training_params=exp_training_params_iter,
    data=data,
    folds=FOLDS,
    verbose=True,
    trials=1,
    partial=True,
)

Test that the wrapper works

In [None]:
# x_train_ = x_train.detach()
# x_sorted, indices = torch.sort(x_train_, dim=0)

# plot_kwargs = {
#     "x_test": x_sorted,
#     "x_train": x_sorted,
#     "y_train": c2.ksi(x_sorted),
#     "x_axis": "t",
#     "y_axis": "$\\varphi(t)$",
#     "compare_label": "analytical solution",
# }
# Plot the cross-validation results; every entry of cv_results is forwarded as
# a keyword argument. The model-specific plot function and its kwargs are
# disabled for now.
# NOTE(review): presumably figures are written under PATH_FIGURES — confirm in
# deepthermal.plotting.plot_result.
plot_result(
    path_figures=PATH_FIGURES,
    **cv_results,
    # plot_function=plot_model_1d,
    # function_kwargs=plot_kwargs,
)

In [None]:
# Take the first trained model from the cross-validation results and draw
# 10 samples from it as a smoke test of the wrapped flow.
flow = cv_results["models"][0][0]
# .detach() is the supported way to drop the autograd graph; the legacy .data
# attribute bypasses autograd's in-place modification checks.
sample = flow.sample([10]).detach()
# Show the sample shape as the cell output.
sample.shape

In [None]:
import warnings

# Unconditional notice kept from the original cell.
warnings.warn("No loss history to plot")

# Display the raw cross-validation results. A bare expression is only rendered
# when it is the last statement of the cell, so it goes at the end.
# The former trailing `warning.warning` referenced an undefined name and
# raised NameError; it has been removed.
cv_results