In [None]:
# Load IPython's autoreload extension and switch it to mode 2, so imported
# modules are re-imported before each cell executes (handy while editing the
# local packages used below).
%load_ext autoreload
%autoreload 2

In [None]:
import copy
import os

import numpy as np
import torch
import torch.distributions as dist
import matplotlib.pyplot as plt
from matplotlib_inline.backend_inline import set_matplotlib_formats
import seaborn as sns
from tqdm import tqdm
import flowtorch as ft

import nflows
import normflow as nf

from signatureshape.animation import fetch_animation_id_set, fetch_animations
from signatureshape.animation.src.mayavi_animate import mayavi_animate

from deepthermal.validation import (
    create_subdictionary_iterator,
    k_fold_cv_grid,
    add_dictionary_iterators,
)

from deepthermal.FFNN_model import fit_FFNN
from deepthermal.plotting import plot_result
import experiments.curves as c1

import shapeflow as sf

# make reproducible: seed every RNG this notebook may draw from.
# torch.manual_seed returns the global torch Generator, kept in `seed`.
seed = torch.manual_seed(0)
# numpy is used alongside torch below, so its global RNG is seeded as well.
np.random.seed(0)

# better formats: emit inline figures as vector graphics (pdf and svg)
set_matplotlib_formats("pdf", "svg")

In [None]:
# Build the data set: 5000 reparametrizations of the curve c1.c_1, each
# evaluated on a uniform grid of N points in [0, 1].
# presumably reparam_curve draws a random reparametrization per call — TODO confirm

N = 32

times = torch.linspace(0, 1, N)

reparam_list = [
    sf.reparam.reparam_curve(curve=c1.c_1, times=times, max_step=N // 2)
    for _ in tqdm(range(5000))
]

# stack the per-sample arrays into a single tensor and show its shape
c2_reparams = torch.as_tensor(np.stack(reparam_list))
c2_reparams.shape

In [None]:
# Visual sanity check: plot one of the sampled curves (index 9).
sample_curve = c2_reparams[9]
plt.plot(sample_curve)

In [None]:
# Interleaved split: even-indexed samples go to training, odd-indexed samples
# to validation. Both halves are cast to float32 before wrapping.
train_curves = c2_reparams[0::2].float()
val_curves = c2_reparams[1::2].float()

data = torch.utils.data.TensorDataset(train_curves)
data_val = torch.utils.data.TensorDataset(val_curves)

In [None]:
# Display the training TensorDataset object as this cell's output.
data

In [None]:
# Smoke test of the flowtorch API: configure an autoregressive bijector with a
# dense autoregressive parameter network, then instantiate it for a given shape.
test = ft.bijectors.Autoregressive(
    ft.parameters.DenseAutoregressive(hidden_dims=(N, N))
)
# `(2)` is just the integer 2 (parentheses alone do not make a tuple); the
# bijector expects an event shape, so pass an explicit torch.Size.
b = test(shape=torch.Size([2]))
b

In [None]:
#######
# Output location for the figures of this experiment.
DIR = "../figures/c1_c2_repara/"
SET_NAME = "walk_residual"
PATH_FIGURES = os.path.join(DIR, SET_NAME)
# exist_ok avoids the separate existence check and the race between check and creation
os.makedirs(PATH_FIGURES, exist_ok=True)
########
FOLDS = 5
#######

# Base distribution of the flow: a standard normal over one sample, with the
# last dimension treated as the event dimension.
event_shape = data[0][0].shape
base_dist = dist.Independent(
    dist.Normal(loc=torch.zeros(event_shape), scale=torch.ones(event_shape)), 1
)


def lr_scheduler(optim):
    """Return a ReduceLROnPlateau scheduler for `optim`.

    The learning rate is multiplied by 0.5 once the monitored quantity
    (mode="min") has not improved for `patience=50` scheduler steps.
    """
    # NOTE(review): `verbose` is deprecated in recent PyTorch releases —
    # confirm against the installed version before upgrading.
    return torch.optim.lr_scheduler.ReduceLROnPlateau(
        optim, mode="min", factor=0.5, patience=50, verbose=True
    )


# def get_flow(ji)
# Every value is a list of candidates for the grid iterators built from these dicts.
MODEL_PARAMS = {
    "model": [ft.distributions.Flow],
    "bijector": [
        ft.bijectors.AffineAutoregressive(
            ft.parameters.DenseAutoregressive(hidden_dims=(N, N))
        )
    ],
    # "input_dim" : [event_shape[0]],
    # "inverse_model": [False],
    # "compose": [True],
}
num_layers = 2
MODEL_PARAMS_EXPERIMENT = {
    "base_dist": [base_dist],
    # "hidden_dims": [(N,N)],
    # "hidden_layers": [[2] * num_layers],
    # "n_exact_terms": [[4] * num_layers],
    # "n_samples": [[10] * num_layers],
}

TRAINING_PARAMS = {
    "batch_size": [1000],
    "regularization_param": [0],
    "compute_loss": [sf.monte_carlo_dkl_loss],
    "post_batch": [sf.get_post_step_lipchitz(5)],
}
# extend the previous dict with the zip of this
TRAINING_PARAMS_EXPERIMENT = {
    "verbose_interval": [20],
    "optimizer": ["ADAM"],
    "num_epochs": [300],
    "learning_rate": [0.1],
    "lr_scheduler": [lr_scheduler],
}

In [None]:
# create iterators
model_params_iter_1 = create_subdictionary_iterator(MODEL_PARAMS)

model_exp_iter = create_subdictionary_iterator(MODEL_PARAMS_EXPERIMENT, product=False)
exp_model_params_iter = add_dictionary_iterators(model_exp_iter, model_params_iter_1)

training_params_iter = create_subdictionary_iterator(TRAINING_PARAMS)
training_exp_iter = create_subdictionary_iterator(
    TRAINING_PARAMS_EXPERIMENT, product=False
)
exp_training_params_iter = add_dictionary_iterators(
    training_exp_iter, training_params_iter
)

In [None]:
# Fit the flow for every model/training parameter combination with
# k_fold_cv_grid; the returned results are unpacked into plot_result later.
cv_settings = dict(
    model_params=exp_model_params_iter,
    fit=fit_FFNN,
    training_params=exp_training_params_iter,
    data=data,
    val_data=data_val,
    folds=FOLDS,
    verbose=True,
    trials=1,
    partial=True,
    shuffle_folds=False,
)
cv_results = k_fold_cv_grid(**cv_settings)

Test that the wrapper works

In [None]:
# Plot the cross-validation results; figures are written under PATH_FIGURES.
# No custom plot_function / function_kwargs are passed.
plot_result(path_figures=PATH_FIGURES, **cv_results)

In [None]:
# Keep the first fitted model for the inspection cells below.
# presumably indexed as [parameter configuration][fold] — TODO confirm
fitted_models = cv_results["models"]
flow = fitted_models[0][0]

In [None]:
# torch.save(flow, "Flow_frames" + SET_NAME + "2.pt")

In [None]:
def test_func(t):
    """Reference curve cos(2*pi*t), used as data the flow was not trained on."""
    return np.cos(t * 2 * np.pi)


# Collect the cosine reparametrizations in their OWN list. Appending to the
# list that already holds the training curves would mix both data sets into
# c2_reparams_2 and grow it further on every re-run of this cell.
# NOTE(review): the training-data cell calls reparam_curve with `curve=`, this
# one with `curve_data=` — confirm which keyword the function expects.
cos_reparam_list = [
    sf.reparam.reparam_curve(curve_data=test_func, times=times, max_step=N // 2)
    for _ in tqdm(range(5000))
]

c2_reparams_2 = torch.as_tensor(np.stack(cos_reparam_list)).float()

In [None]:
# Compare the mean log-likelihood the flow assigns to different data sets.
noise = base_dist.sample([100])
print("Log vals:")

# average over all 100 noise samples, not just the first one
print("Noise :", flow.log_prob(noise).mean().item())
# the cosine curves are the 5000 most recently generated rows, so slice from the end
print("Cos dat :", flow.log_prob(c2_reparams_2[-5000:]).mean().item())
print("Train data:", flow.log_prob(data[:][0]).mean().item())
# this line reports the validation set, so label it accordingly
print("Val data:", flow.log_prob(data_val[:][0]).mean().item())

In [None]:
# Plot a single training curve (dataset item 115; element 0 is the tensor).
train_sample = data[115][0]
plt.plot(train_sample)

In [None]:
# Choose two training samples as interpolation endpoints. Slicing with
# i : i + 1 keeps the leading batch dimension of size 1.
i = 0
j = 1200
first_batch = data[i : i + 1]
second_batch = data[j : j + 1]
x_first_frame = first_batch[0]
x_second_frame = second_batch[0]

In [None]:
# interpolate
z_first_walk = flow.bijector.inverse(x_first_frame)
z_second_walk = flow.bijector.inverse(x_second_frame)

In [None]:
# Interpolation weights: 120 steps from 1 down to 0, then 120 steps back up to 1.
weights_down = torch.linspace(1, 0, 120)
weights_up = torch.linspace(0, 1, 120)
w_list = torch.cat((weights_down, weights_up))

# Blend the two z-vectors and push each blend through the bijector's forward map.
x_interpolated = torch.cat(
    [
        flow.bijector.forward(z_first_walk * weight + z_second_walk * (1 - weight))
        for weight in w_list
    ]
)

# Baseline for comparison: the same blend applied directly to the x-samples.
x_interpolated_test = torch.cat(
    [x_first_frame * weight + x_second_frame * (1 - weight) for weight in w_list]
)

In [None]:
# Prepare a skeleton and three animation objects (first endpoint, second
# endpoint, interpolation) for the mayavi playback cells below.
# NOTE(review): `run_skeletons` and `walk_animations` are not defined anywhere
# in the visible notebook — on a fresh kernel this cell raises NameError.
# They presumably come from fetch_animations; TODO restore the loading cell.
skel = copy.deepcopy(run_skeletons[0])

# Deep copies serve as templates, so the originals stay unmodified when
# from_numpy_array is called on the copies.
anim_test = copy.deepcopy(walk_animations[0])
anim_first = copy.deepcopy(walk_animations[0])
anim_second = copy.deepcopy(walk_animations[0])
anim_first.from_numpy_array(sf.utils.data_to_motion_array(x_first_frame))
anim_second.from_numpy_array(sf.utils.data_to_motion_array(x_second_frame))
# anim_test is filled from x_interpolated; x_interpolated_test is not used here.
anim_test.from_numpy_array(sf.utils.data_to_motion_array(x_interpolated))
In [None]:
# Play back the first endpoint on the skeleton; nothing is saved to disk.
playback_options = dict(
    offset=[0, 0, 0],
    continuous=False,
    fixed_cam=False,
    frame_limit=-1,
    save_path=None,
)
anim = mayavi_animate(skel, anim_first, **playback_options)

In [None]:
# Play back the second endpoint on the skeleton; nothing is saved to disk.
playback_options = dict(
    offset=[0, 0, 0],
    continuous=False,
    fixed_cam=False,
    frame_limit=-1,
    save_path=None,
)
anim = mayavi_animate(skel, anim_second, **playback_options)

In [None]:
# Play back the interpolated sequence on the skeleton; nothing is saved to disk.
playback_options = dict(
    offset=[0, 0, 0],
    continuous=False,
    fixed_cam=False,
    frame_limit=-1,
    save_path=None,
)
anim = mayavi_animate(skel, anim_test, **playback_options)