In [None]:
# Automatically reload edited imported modules (e.g. shapeflow, deepthermal)
# before executing each cell, so library changes are picked up without a restart.
%load_ext autoreload
%autoreload 2

In [None]:
import copy
import os
from itertools import chain
from tqdm import tqdm

import numpy as np
import torch
import matplotlib.pyplot as plt
from matplotlib_inline.backend_inline import set_matplotlib_formats
import seaborn as sns

import flowtorch as ft

import nflows
from nflows.flows.base import Flow

from signatureshape.animation import fetch_animation_id_set, fetch_animations
from signatureshape.animation.src.mayavi_animate import mayavi_animate

from deepthermal.validation import (
    create_subdictionary_iterator,
    k_fold_cv_grid,
    add_dictionary_iterators,
)

from deepthermal.FFNN_model import fit_FFNN
from deepthermal.plotting import plot_result

import shapeflow as sf

# make reproducible: seed every RNG the notebook (and its libraries) may draw from
SEED = 0
seed = torch.manual_seed(SEED)  # torch.manual_seed returns the default torch.Generator
np.random.seed(SEED)  # legacy global numpy RNG, in case helper libraries use it

# better formats: render inline figures as vector graphics
set_matplotlib_formats("pdf", "svg")

In [None]:
# Fetch the mocap data and convert it to Euler-angle arrays.
# We assume all clips share the same skeleton.
print("Loading mocap data:")


def _collect_animations(subjects, prefix, skeletons=None):
    """Return the animations of ``subjects`` whose entry[2] starts with ``prefix``.

    Each entry yielded by ``fetch_animations`` is indexed as
    entry[0] -> skeleton, entry[1] -> animation, entry[2] -> a label that is
    matched against ``prefix`` (presumably the clip description).
    When ``skeletons`` is a list, the skeleton of every kept entry is appended
    to it, in the same order as the returned animations.
    """
    kept = []
    for subject in subjects:
        entries = fetch_animations(100, subject_file_name=(subject + ".asf"))
        for entry in entries:
            if entry[2][: len(prefix)] != prefix:
                continue
            if skeletons is not None:
                skeletons.append(entry[0])
            kept.append(entry[1])
    return kept


# walk data
walk_subjects = ["07", "08", "35", "16"]
walk_animations = _collect_animations(walk_subjects, "walk")

# total number of frames in the first 18 walk clips
walk_animations_train_frame = sum(
    len(anim.get_frames()) for anim in walk_animations[:18]
)

# run data (the skeletons are collected too)
run_subjects = ["09", "16", "35"]
run_skeletons = []
run_animations = _collect_animations(run_subjects, "run", skeletons=run_skeletons)

print("Convert to array:")
_EULER_KWARGS = dict(reduce_shape=True, remove_root=True, deg2rad=True)
walk_angle_array = sf.utils.animation_to_eulers(walk_animations, **_EULER_KWARGS)
run_angle_array = sf.utils.animation_to_eulers(run_animations, **_EULER_KWARGS)

In [None]:
# Persist the angle arrays to .npy and read them back as float32 tensors.
# NOTE(review): the files are rewritten on every run and never consulted before
# fetching, so this does not actually skip the slow mocap download above.


def _save_and_reload(path, array):
    """Save ``array`` to ``path``, reload it without pickle, and return a float tensor."""
    np.save(path, array)
    reloaded = np.load(path, allow_pickle=False)
    return torch.tensor(reloaded).float()


walk_angle_tensor = _save_and_reload("walk_angle_array.npy", walk_angle_array)
run_angle_tensor = _save_and_reload("run_angle_array.npy", run_angle_array)

In [None]:
# Wrap each angle tensor in a dataset; every item is a 1-tuple holding one row.
data, data_run = (
    torch.utils.data.TensorDataset(angles)
    for angles in (walk_angle_tensor, run_angle_tensor)
)

In [None]:
# Output folder for this run's figures.
DIR = "../figures/simple_frame/"
SET_NAME = "walk_1_auto"
PATH_FIGURES = os.path.join(DIR, SET_NAME)
# exist_ok avoids the check-then-create race and is a no-op if the folder exists
os.makedirs(PATH_FIGURES, exist_ok=True)

# Standard-normal base distribution with the dimensionality of one data sample.
# Independent(..., 1) treats the last dimension as a single event, so log_prob
# sums over it.
event_shape = data[0][0].shape[0]
base_dist = torch.distributions.Independent(
    torch.distributions.Normal(torch.zeros(event_shape), torch.ones(event_shape)), 1
)
FOLDS = 1


def lr_scheduler(optim):
    """Build a ReduceLROnPlateau scheduler for ``optim``.

    Currently unused: the "lr_scheduler" entry of TRAINING_PARAMS_EXPERIMENT
    below is commented out.
    """
    # NOTE(review): ``verbose`` is deprecated in recent PyTorch releases; confirm
    # the installed version still accepts it before enabling the scheduler.
    return torch.optim.lr_scheduler.ReduceLROnPlateau(
        optim, mode="min", factor=0.2, patience=5, verbose=True
    )


MODEL_PARAMS = {
    "model": [sf.utils.get_flow],
    "get_transform": [sf.utils.get_transform_nflow],
    "base_dist": [base_dist],
    "Transform": [nflows.flows.MaskedAutoregressiveFlow],
}
MODEL_PARAMS_EXPERIMENT = {
    "num_blocks_per_layer": [2],
    "num_layers": [5],
    "hidden_features": [24],
}

TRAINING_PARAMS = {
    "batch_size": [500],
    "regularization_param": [0.0],
    "compute_loss": [sf.monte_carlo_dkl_loss],
    "post_step": [sf.get_post_step_lipchitz(5)],
}
# extend the previous dict with the zip of this
TRAINING_PARAMS_EXPERIMENT = {
    "verbose_interval": [100],
    "optimizer": ["ADAM"],
    "num_epochs": [200],
    "learning_rate": [0.01],
    # "lr_scheduler": [lr_scheduler],
}

In [None]:
# create iterators
model_params_iter_1 = create_subdictionary_iterator(MODEL_PARAMS)

model_exp_iter = create_subdictionary_iterator(MODEL_PARAMS_EXPERIMENT, product=False)
exp_model_params_iter = add_dictionary_iterators(model_exp_iter, model_params_iter_1)

training_params_iter = create_subdictionary_iterator(TRAINING_PARAMS)
training_exp_iter = create_subdictionary_iterator(
    TRAINING_PARAMS_EXPERIMENT, product=False
)
exp_training_params_iter = add_dictionary_iterators(
    training_exp_iter, training_params_iter
)

In [None]:
# Train the flow(s) over the parameter grid.
cv_results = k_fold_cv_grid(
    # what to train and how
    fit=fit_FFNN,
    model_params=exp_model_params_iter,
    training_params=exp_training_params_iter,
    # data and fold handling
    data=data,
    folds=FOLDS,
    shuffle_folds=False,
    partial=True,
    trials=1,
    verbose=True,
)

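Check that the flow wrapper works: plot the cross-validation results, then sample from the trained flow and compare log-probabilities.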

In [None]:
# Plot the cross-validation results (figures presumably written under PATH_FIGURES).
plot_result(**cv_results, path_figures=PATH_FIGURES)

In [None]:
# Take the first trained flow and draw a single sample as a smoke test.
flow = cv_results["models"][0][0]
with torch.no_grad():  # inference only; no autograd graph needed
    # .detach() is the supported replacement for the legacy .data attribute
    sample = flow.sample([1]).detach()

In [None]:
# Sanity check: compare the flow's mean log-likelihood of 100 base-distribution
# noise samples with that of the training data.
noise = base_dist.sample([100])

print("Log vals:")
with torch.no_grad():  # evaluation only; avoid building autograd graphs
    print("Noise :", flow.log_prob(noise).mean().item())
    # data.tensors[0] is the full training tensor wrapped by the TensorDataset
    print("Train data:", flow.log_prob(data.tensors[0]).mean().item())

In [None]:
# Interpolate between the first two walk clips, once through the flow's latent
# space and once directly in angle space, so the two results can be compared.
INTERPOLATION_WEIGHT = 0.5  # weight of the first walk (0.5 = midpoint)

# The arrays are sliced as [channels, frames] and transposed below so that
# frames become the batch dimension fed to the flow (presumably channels x
# frames; confirm against Animation.to_numpy_array). The first 3 channel rows
# are dropped here and padded back as zeros further down.
first_walk_raw = walk_animations[0].to_numpy_array()
second_walk_raw = walk_animations[1].to_numpy_array()
# Blending is frame by frame, so both clips are cut to their common length.
# This replaces a hard-coded 328-frame cut that was applied to the first clip
# only and broke whenever the second clip had a different length.
n_interp_frames = min(first_walk_raw.shape[1], second_walk_raw.shape[1])

x_first_walk = torch.tensor(
    np.deg2rad(first_walk_raw[3:, :n_interp_frames].T)
).float()
x_second_walk = torch.tensor(
    np.deg2rad(second_walk_raw[3:, :n_interp_frames].T)
).float()

with torch.no_grad():  # pure inference; no autograd graph needed
    z_first_walk = flow.bijector.inverse(x_first_walk)
    z_second_walk = flow.bijector.inverse(x_second_walk)
    x_interpolated = torch.rad2deg(
        flow.bijector.forward(
            z_first_walk * INTERPOLATION_WEIGHT
            + z_second_walk * (1 - INTERPOLATION_WEIGHT)
        )
    )

# F.pad lists the last dimension first: (0, 0) leaves frames untouched and
# (3, 0) prepends 3 zero channels in place of the rows dropped above.
pads = (0, 0, 3, 0)


def _to_padded_numpy(angles):
    """Turn a (frames, channels) tensor into a zero-padded (channels, frames) array."""
    padded = torch.nn.functional.pad(angles.T, pads, "constant", 0)
    return padded.detach().numpy()


walk_interpolated_square = _to_padded_numpy(x_interpolated)

x_interpolated_direct = torch.rad2deg(
    x_first_walk * INTERPOLATION_WEIGHT + x_second_walk * (1 - INTERPOLATION_WEIGHT)
)
walk_interpolated_direct = _to_padded_numpy(x_interpolated_direct)

In [None]:
# Build playable animations from the interpolated angle arrays, using copies of
# the first walk clip as templates so the originals stay untouched.
# NOTE(review): the skeleton comes from the first *run* clip although the
# motions are walks; this relies on all clips sharing one skeleton.
skel = copy.deepcopy(run_skeletons[0])
anim_int, anim_direct_int = (copy.deepcopy(walk_animations[0]) for _ in range(2))
for target_anim, angle_array in (
    (anim_int, walk_interpolated_square),
    (anim_direct_int, walk_interpolated_direct),
):
    target_anim.from_numpy_array(angle_array)

In [None]:
# Play the walk interpolated through the flow's latent space.
anim = mayavi_animate(
    skel,
    anim_int,
    save_path=None,
    frame_limit=-1,
    fixed_cam=False,
    continuous=False,
    offset=[0, 0, 0],
)

In [None]:
# Play the first original walk clip for reference.
anim = mayavi_animate(
    skel,
    walk_animations[0],
    save_path=None,
    frame_limit=-1,
    fixed_cam=False,
    continuous=False,
    offset=[0, 0, 0],
)

In [None]:
# Play the second original walk clip for reference.
anim = mayavi_animate(
    skel,
    walk_animations[1],
    save_path=None,
    frame_limit=-1,
    fixed_cam=False,
    continuous=False,
    offset=[0, 0, 0],
)

In [None]:
# Play the walk interpolated directly in angle space, for comparison.
anim = mayavi_animate(
    skel,
    anim_direct_int,
    save_path=None,
    frame_limit=-1,
    fixed_cam=False,
    continuous=False,
    offset=[0, 0, 0],
)