In [None]:
%load_ext autoreload
%autoreload 2

### Test wrapper on motion-capture walk data

Test that we can use `flowtorch` with transformations from `nflows`.

The example is modified from [nflows](https://github.com/bayesiains/nflows).

In [None]:
import numpy as np
import copy
import os
from tqdm import tqdm
from itertools import chain
import matplotlib.pyplot as plt
from matplotlib_inline.backend_inline import set_matplotlib_formats
import seaborn as sns

import torch
import torch.distributions as dist

import flowtorch as ft
import nflows
import normflow as nf

from signatureshape.animation import fetch_animation_id_set, fetch_animations
from signatureshape.animation.src.mayavi_animate import mayavi_animate

from deepthermal.validation import (
    create_subdictionary_iterator,
    k_fold_cv_grid,
    add_dictionary_iterators,
)
from deepthermal.FFNN_model import fit_FFNN
from deepthermal.plotting import plot_result

import shapeflow as sf

# make reproducible
# Seeds torch's global RNG with 0. Note that `seed` is bound to whatever
# torch.manual_seed returns, not to the integer 0.
seed = torch.manual_seed(0)

# better formats
# Ask the inline matplotlib backend to render figures as PDF and SVG.
set_matplotlib_formats("pdf", "svg")

In [None]:
# Load the motion-capture clips for the walking and running subjects and turn
# them into angle arrays. All subjects are assumed to share the same skeleton.
print("Loading mocap data:")

# --- walking clips: keep entries whose third field starts with "walk" ---
walk_subjects = ["07", "08", "35", "16"]
walk_animations = [
    entry[1]
    for subject in walk_subjects
    for entry in fetch_animations(100, subject_file_name=f"{subject}.asf")
    if entry[2][:4] == "walk"
]

# Total number of frames contained in the first 18 walking clips.
walk_animations_train_frame = 0
for clip in walk_animations[:18]:
    walk_animations_train_frame += len(clip.get_frames())

# --- running clips: same filter on "run", also keeping the first field ---
run_subjects = ["09", "16", "35"]
run_animations = []
run_skeletons = []
for subject in run_subjects:
    asf_name = f"{subject}.asf"
    for entry in fetch_animations(100, subject_file_name=asf_name):
        if entry[2][:3] != "run":
            continue
        run_skeletons.append(entry[0])
        run_animations.append(entry[1])

print("Convert to array:")
# Both conversions use identical options, so they are spelled out once.
euler_kwargs = dict(reduce_shape=False, remove_root=True, deg2rad=True)
walk_angle_array = sf.utils.animation_to_eulers(walk_animations, **euler_kwargs)
run_angle_array = sf.utils.animation_to_eulers(run_animations, **euler_kwargs)

In [None]:
# anim.from_numpy_array( walk_angle_array[0].T)
# mayavi_animate(skel, anim, offset=anim._offset, continuous=True, fixed_cam = False, frame_limit = -1, save_path = None)

In [None]:
# skel2, anim, desc = fetch_animations(1, subject_file_name=run_subjects[0]+".asf")
# skel.bones.keys() == skel2.bones.keys()
# for bone_name, bone_obj in skel.bones.items():
#     pass
#     print(bone_name, ":")
#     print([dof for dof in bone_obj.dof], "\n")

In [None]:
# Write both angle arrays to .npy files and read them straight back as float
# tensors.
# NOTE(review): the arrays are always re-saved before being loaded, so this
# cell does not let a fresh kernel skip the slow extraction above.


def _roundtrip_to_tensor(path, array):
    """Save `array` to `path` with np.save, reload it, and return a float tensor."""
    np.save(path, array)
    reloaded = np.load(path, allow_pickle=False)
    return torch.tensor(reloaded).float()


walk_angle_tensor = _roundtrip_to_tensor("walk_angle_array_full.npy", walk_angle_array)
run_angle_tensor = _roundtrip_to_tensor("run_angle_array_full.npy", run_angle_array)

In [None]:
# Cut both tensors to a common length along dim 1 (the shorter of the two),
# so walk and run samples end up with matching shapes.
pre_shape_walk, pre_shape_run = walk_angle_tensor.shape, run_angle_tensor.shape

num_frames = min(shape[1] for shape in (pre_shape_walk, pre_shape_run))

# Shapes the data would have if dims 1 and 2 were merged after truncation.
post_shape_walk = (pre_shape_walk[0], num_frames * pre_shape_walk[2])
post_shape_run = (pre_shape_run[0], num_frames * pre_shape_run[2])

# Truncated tensors; dims 1 and 2 are kept separate here.
walk_angles = walk_angle_tensor[:, :num_frames]
run_angles = run_angle_tensor[:, :num_frames]
walk_samples_shapes_pre, run_angles_shapes_pre = walk_angles.shape, run_angles.shape

In [None]:
# reshape
# Choose the layout of the training tensors. Exactly one branch runs:
# `add_channel` takes precedence over `flatten` when both are True.
flatten = False
add_channel = True

orig_shape_walk = walk_angle_tensor.shape
orig_shape_run = run_angle_tensor.shape
if add_channel:
    # Insert a singleton dimension at position 1 of the truncated tensors.
    walk_angles_reshaped = torch.unsqueeze(walk_angles, 1)
    run_angles_reshaped = torch.unsqueeze(run_angles, 1)
elif flatten:
    # Swap dims 1 and 2, then merge them into a single feature dimension.
    # NOTE(review): this branch reads the untruncated `*_angle_tensor`s rather
    # than the truncated `walk_angles` / `run_angles`, so walk and run may end
    # up with different feature sizes -- confirm this is intended.
    walk_angles_reshaped = torch.swapaxes(walk_angle_tensor, 1, 2).reshape(
        orig_shape_walk[0], orig_shape_walk[1] * orig_shape_walk[2]
    )
    run_angles_reshaped = torch.swapaxes(run_angle_tensor, 1, 2).reshape(
        orig_shape_run[0], orig_shape_run[1] * orig_shape_run[2]
    )
else:
    # Without this branch the names below would be undefined on a fresh kernel
    # or, worse, silently keep stale values from an earlier run of this cell.
    raise ValueError("Set either `add_channel` or `flatten` to True.")
walk_angles_reshaped.shape, run_angles_reshaped.shape

In [None]:
# Wrap the reshaped tensors as single-tensor datasets: `data` holds the walk
# samples, `data_run` the run samples.
data, data_run = (
    torch.utils.data.TensorDataset(angles)
    for angles in (walk_angles_reshaped, run_angles_reshaped)
)

In [None]:
#######
# Output location for the figures of this experiment.
DIR = "../figures/full_shape/"
SET_NAME = "walk_residual"
PATH_FIGURES = os.path.join(DIR, SET_NAME)
# exist_ok avoids the check-then-create race and keeps the cell re-runnable.
os.makedirs(PATH_FIGURES, exist_ok=True)
########
FOLDS = 5

# Shape of a single sample in the walk dataset.
event_shape = data[0][0].shape
# Standard normal over one full sample. Every dimension of the sample is
# reinterpreted as an event dimension, so the count follows the sample rank
# instead of being hard-coded (it is 3 for the `add_channel` layout).
base_dist = dist.Independent(
    dist.Normal(loc=torch.zeros(event_shape), scale=torch.ones(event_shape)),
    len(event_shape),
)


def lr_scheduler(optim):
    """Return a ReduceLROnPlateau scheduler for `optim`.

    The scheduler halves the learning rate (factor=0.5) after 40 epochs
    without improvement of the monitored quantity (mode="min").
    """
    # NOTE(review): `verbose` is deprecated in recent PyTorch releases --
    # confirm against the installed version.
    return torch.optim.lr_scheduler.ReduceLROnPlateau(
        optim, mode="min", factor=0.5, patience=40, verbose=True
    )


# Model settings shared by every experiment.
MODEL_PARAMS = {
    "model": [sf.utils.get_flow],
    "get_transform": [sf.normalizing_flows.get_residual_transform],
    "base_dist": [base_dist],
    "inverse_model": [False],
    "compose": [True],
}
num_layers = 5
# Per-experiment model settings; each inner list has one entry per layer.
# NOTE(review): the first three keys hold two (identical, aliased) entries
# while the last two hold one. How the iterator built with product=False
# treats unequal lengths is not visible here -- confirm the intended number
# of experiments.
MODEL_PARAMS_EXPERIMENT = {
    # "num_blocks_per_layer": [2],
    "CNN": [[True] * num_layers] * 2,
    "hidden_features": [[7] * num_layers] * 2,
    "hidden_layers": [[2] * num_layers] * 2,
    "n_exact_terms": [[3] * num_layers],
    "n_samples": [[10] * num_layers],
}

# Training settings shared by every experiment.
TRAINING_PARAMS = {
    "batch_size": [10],
    "regularization_param": [1e-5],
    "compute_loss": [sf.monte_carlo_dkl_loss],
    "post_batch": [sf.get_post_step_lipchitz(5)],
}
# extend the previous dict with the zip of this
TRAINING_PARAMS_EXPERIMENT = {
    "verbose_interval": [10],
    "optimizer": ["ADAM"],
    "num_epochs": [1000],
    "learning_rate": [0.1],
    "lr_scheduler": [lr_scheduler],
}

In [None]:
# Sanity check: display the event shape of the base distribution. It is
# expected to match `event_shape`, the shape of a single sample.
base_dist.event_shape

In [None]:
# create iterators
# Turn the parameter dictionaries into the iterators consumed by the
# cross-validation call below.
# NOTE(review): if these helpers return one-shot iterators, this cell must be
# re-run before every new cross-validation run -- confirm.

# Shared model settings (default combination mode of the helper).
model_params_iter_1 = create_subdictionary_iterator(MODEL_PARAMS)
# model_params_iter = chain.from_iterable((model_params_iter_1, model_params_iter_2))

# Per-experiment model settings, built with product=False and then combined
# with the shared model settings.
model_exp_iter = create_subdictionary_iterator(MODEL_PARAMS_EXPERIMENT, product=False)
exp_model_params_iter = add_dictionary_iterators(model_exp_iter, model_params_iter_1)

# Same construction for the training settings, here with product=True for the
# per-experiment part.
training_params_iter = create_subdictionary_iterator(TRAINING_PARAMS)
training_exp_iter = create_subdictionary_iterator(
    TRAINING_PARAMS_EXPERIMENT, product=True
)
exp_training_params_iter = add_dictionary_iterators(
    training_exp_iter, training_params_iter
)

In [None]:
# Run the k-fold cross-validation grid on the walk dataset, using the model and
# training parameter iterators built above. The number of folds comes from the
# FOLDS constant in the configuration cell instead of a duplicated literal.
cv_results = k_fold_cv_grid(
    model_params=exp_model_params_iter,
    fit=fit_FFNN,
    training_params=exp_training_params_iter,
    data=data,
    folds=FOLDS,
    verbose=True,
    trials=1,
    partial=True,
    shuffle_folds=False,
)

In [None]:
# Plot the cross-validation results; figures are written under PATH_FIGURES.
# The entries of `cv_results` are forwarded as keyword arguments.
plot_result(
    path_figures=PATH_FIGURES,
    **cv_results,
    # plot_function=plot_model_1d,
    # function_kwargs=plot_kwargs,
)

In [None]:
# flow = cv_results["models"][0][0]
# sample = flow.sample([1]).data
# torch.save(flow, "flow_full_motion_2.pt")

In [None]:
# Load a previously saved flow from disk and switch it to evaluation mode.
# The file is not produced by this notebook as it stands: the torch.save call
# that writes it is commented out in the previous cell.
# NOTE(review): torch.load unpickles the stored object, so only load files
# from a trusted source. Recent PyTorch releases default to weights_only=True,
# which may refuse a fully pickled model -- confirm against the installed
# version.
flow = torch.load("flow_full_motion_2.pt")
flow.eval()

In [None]:
# Report the mean log-probability the flow assigns to base-distribution noise,
# to the run clips, and to two slices of the walk dataset.
noise = base_dist.sample([100])
print("Log plots:")

log_prob_inputs = (
    ("Noise :", noise),
    ("Run data:", data_run[:][0]),
    ("Train data:", data[0:20][0]),
    ("Validation data:", data[55:64][0]),
)
for label, batch in log_prob_inputs:
    print(label, flow.log_prob(batch).mean().item())

In [None]:

# Map the first walk sample to latent space, nudge it with a small multiple of
# one base-distribution draw, and map the result back to data space.
z_close = flow.bijector.inverse(data[0:1][0])
z_perturbed = z_close + 0.1 * noise[0]
x_sample = flow.bijector.forward(z_perturbed)

In [None]:
# Interpolate between two walk samples, once in the flow's latent space and
# once directly in data space, and concatenate the steps along dim 2.
i, j = 0, 8
data_1 = data[i : i + 1][0]
data_2 = data[j : j + 1][0]
# Weights go 1 -> 0 and back 0 -> 1, i.e. from sample i to sample j and back.
w_list = torch.cat((torch.linspace(1, 0, 10), torch.linspace(0, 1, 10)))

# The latent codes of the two endpoints do not depend on the weight, so each
# one is inverted once here instead of once per interpolation step.
z_1 = flow.bijector.inverse(data_1)
z_2 = flow.bijector.inverse(data_2)
# NOTE(review): `x_interpolated` is not used by the later cells shown in this
# notebook; only `x_interpolated_test` is animated.
x_interpolated = torch.cat(
    [flow.bijector(z_1 * w + (1 - w) * z_2) for w in w_list],
    dim=2,
)
# Plain linear interpolation in data space, for comparison.
x_interpolated_test = torch.cat([data_1 * w + data_2 * (1 - w) for w in w_list], dim=2)

In [None]:
# flow.log_prob(x_interpolated)
# x_interpolated.shape
# flow.log_prob(data[0:1][0])

In [None]:
# Work on deep copies so the loaded skeleton and animation stay untouched.
# The skeleton comes from the run set and the animation from the walk set;
# this relies on all subjects sharing the same skeleton.
skel = copy.deepcopy(run_skeletons[0])
test_anim = copy.deepcopy(walk_animations[0])

In [None]:
# create anim
# Fill `test_anim` with the data-space interpolation, converted to a motion
# array by sf.utils.data_to_motion_array.
test_anim.from_numpy_array(sf.utils.data_to_motion_array(x_interpolated_test))

In [None]:
# Render the interpolated animation on the copied skeleton.
# NOTE(review): the sampled-motion cell further down passes the same
# save_path ("test.svg") and rebinds `anim`; if save_path writes a file, that
# later call presumably overwrites this one -- confirm.
anim = mayavi_animate(
    skel,
    test_anim,
    offset=[0, 0, 0],
    continuous=False,
    fixed_cam=False,
    frame_limit=-1,
    save_path="test.svg",
)

In [None]:
# plt.hist(np.abs((walk_sample_reformated[7])))
# Build an animation from the perturbed sample `x_sample` (on a deep copy of
# the first walk animation) and render it on the copied skeleton.
test_sample = copy.deepcopy(walk_animations[0])
test_sample.from_numpy_array(sf.utils.data_to_motion_array(x_sample))

# NOTE(review): same save_path ("test.svg") and same `anim` name as the
# interpolation render above -- confirm whether the outputs should be kept
# separately.
anim = mayavi_animate(
    skel,
    test_sample,
    offset=[0, 0, 0],
    continuous=True,
    fixed_cam=False,
    frame_limit=-1,
    save_path="test.svg",
)

In [None]:
# Scratch check: three evenly spaced values in [0, 1] as a (3, 1) column.
# Nothing else in the notebook uses this result.
np.linspace(0, 1, 3).reshape((3, 1))  # *np.array([0,1,2])