In [None]:
# Reload all imported modules before each cell runs, so edits to local
# packages (e.g. shapeflow, extratorch) are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
import copy
import os

import numpy as np
import torch
import torch.distributions as dist
import matplotlib
import matplotlib.pyplot as plt
from matplotlib_inline.backend_inline import set_matplotlib_formats
import seaborn as sns


from signatureshape.animation import fetch_animation_id_set, fetch_animations
from signatureshape.animation.src.mayavi_animate import mayavi_animate

import extratorch as etorch
import shapeflow as sf

In [None]:
# Reproducibility: seed every RNG this notebook may touch (torch for sampling
# and training, numpy in case any helper shuffles with it).
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# Plot styling: vector output formats for the exported figures.
set_matplotlib_formats("pdf", "svg")
matplotlib.rcParams.update({"font.size": 12})
plt.style.use("tableau-colorblind10")
sns.set_style("white")

In [None]:
# fetch data as so3
# We assume all subjects share the same skeleton: the first run skeleton is
# used below to convert every clip (walk and run) to Euler angles.
print("Loading mocap data:")


def _fetch_clips(subjects, prefix):
    """Collect the clips of ``subjects`` whose description starts with ``prefix``.

    Each subject id ``s`` is loaded with
    ``fetch_animations(100, subject_file_name=s + ".asf")``; every returned item
    is indexed as ``(skeleton, animation, description, ...)``.

    Args:
        subjects: subject ids such as ``"07"``.
        prefix: keep only clips whose description starts with this string.

    Returns:
        Three parallel lists ``(skeletons, animations, descriptions)``.
    """
    skeletons, animations, descriptions = [], [], []
    for subject in subjects:
        for clip in fetch_animations(100, subject_file_name=subject + ".asf"):
            if clip[2].startswith(prefix):
                skeletons.append(clip[0])
                animations.append(clip[1])
                descriptions.append(clip[2])
    return skeletons, animations, descriptions


# walk data
walk_subjects = ["07", "08", "35", "16"]
walk_skeletons, walk_animations, walk_desc = _fetch_clips(walk_subjects, "walk")

# Total number of frames in the first 18 walk clips.
# NOTE(review): presumably the training portion; not used by later cells — confirm.
walk_animations_train_frame = sum(
    len(anim.get_frames()) for anim in walk_animations[:18]
)

# run data
run_subjects = ["09", "16", "35"]
run_skeletons, run_animations, run_desc = _fetch_clips(run_subjects, "run")

print("Convert to array:")
# Identical conversion settings for both motion classes.
euler_kwargs = dict(
    reduce_shape=True,
    remove_root=True,
    deg2rad=True,
    skeleton=run_skeletons[0],
    max_frame_count=240,
)
walk_angle_array = sf.utils.animation_to_eulers(walk_animations, **euler_kwargs)
run_angle_array = sf.utils.animation_to_eulers(run_animations, **euler_kwargs)

In [None]:
walk_angle_tensor_ = torch.tensor(walk_angle_array, dtype=torch.float32)
run_angle_tensor_ = torch.tensor(run_angle_array, dtype=torch.float32)
wr_angle_tensor_ = torch.cat((walk_angle_tensor_, run_angle_tensor_))

# normalize: standardise each feature with statistics pooled over walk + run.
# NOTE(review): these statistics use the whole dataset before the k-fold split
# further down, so validation folds leak into the normalisation.
std, mean = torch.std_mean(wr_angle_tensor_, dim=0)
# A feature that never changes has std == 0, which would turn the division
# below into NaN. Clamping maps it to 0 instead, and de-normalising with
# ``x * std + mean`` still recovers the constant value.
std = std.clamp_min(1e-8)
wr_angle_tensor_norm = (wr_angle_tensor_ - mean) / std
run_angle_tensor_norm = (run_angle_tensor_ - mean) / std
walk_angle_tensor_norm = (walk_angle_tensor_ - mean) / std

In [None]:
# One dataset per motion class, plus the pooled walk+run dataset that is
# handed to the cross-validation grid.
data_walk, data_run, data = (
    torch.utils.data.TensorDataset(frames)
    for frames in (walk_angle_tensor_norm, run_angle_tensor_norm, wr_angle_tensor_norm)
)
len(data)

In [None]:
#######
# Output location for every figure this notebook saves.
DIR = "../figures/interpolate_frames/"
SET_NAME = "res_2"
PATH_FIGURES = os.path.join(DIR, SET_NAME)
# exist_ok=True avoids the check-then-create race of os.path.exists + makedirs.
os.makedirs(PATH_FIGURES, exist_ok=True)
########

# Standard-normal base distribution with the shape of one data sample.
# Independent(..., 1) reinterprets the last axis as the event, so log_prob
# sums over the features.
event_shape = data[0][0].shape
base_dist = dist.Independent(
    dist.Normal(loc=torch.zeros(event_shape), scale=torch.ones(event_shape)), 1
)


def lr_scheduler(optim):
    """Build a ReduceLROnPlateau scheduler for ``optim``.

    The learning rate is multiplied by 0.5 once the monitored quantity has not
    decreased for ``patience=10`` scheduler steps.
    """
    # NOTE(review): ``verbose`` is deprecated in recent PyTorch releases.
    return torch.optim.lr_scheduler.ReduceLROnPlateau(
        optim, mode="min", factor=0.5, patience=10, verbose=True
    )


# Model hyper-parameters; list values are the candidate settings for the grid.
MODEL_PARAMS = {
    "model": [sf.nf.get_flow],
    "get_transform": [sf.transforms.get_residual_transform],
    "base_dist": [base_dist],
    "inverse_model": [True],
    "compose": [True],
}
num_layers = 3
# Per-layer settings: each candidate is a list with one entry per layer.
EXTRA_M_PARAMS = {
    "hidden_features": [[44] * num_layers],
    "hidden_layers": [[3] * num_layers],
    "n_exact_terms": [[4] * num_layers],
    "n_samples": [[10] * num_layers],
}


# NOTE(review): "verbose" and "post_batch" are not wrapped in lists like the
# other entries — confirm create_subdictionary_iterator treats them as fixed values.
TRAINING_PARAMS = {
    "batch_size": [3000],
    "compute_loss": [sf.nf.monte_carlo_dkl_loss],
    "verbose": True,
    "post_batch": sf.get_post_step_lipchitz(5),
}
# extend the previous dict with the zip of this
EXTRA_T_PARAMS = {
    "verbose_interval": [20],
    "optimizer": ["ADAM"],
    "num_epochs": [300],
    "learning_rate": [0.1],
    "lr_scheduler": [lr_scheduler],
}

In [None]:
# Model-parameter iterator: MODEL_PARAMS combined with EXTRA_M_PARAMS
# (product=True at every level).
model_params_iter = etorch.add_dictionary_iterators(
    etorch.create_subdictionary_iterator(MODEL_PARAMS, product=True),
    etorch.create_subdictionary_iterator(EXTRA_M_PARAMS, product=True),
    product=True,
)

# Training-parameter iterator: EXTRA_T_PARAMS is expanded with product=False
# (zipped) before being combined with TRAINING_PARAMS.
training_params_iter = etorch.add_dictionary_iterators(
    etorch.create_subdictionary_iterator(TRAINING_PARAMS, product=True),
    etorch.create_subdictionary_iterator(EXTRA_T_PARAMS, product=False),
    product=True,
)

In [None]:
# Fit every model/training configuration with k-fold cross-validation on the
# pooled walk+run dataset (single trial, shuffled folds).
cv_results = etorch.k_fold_cv_grid(
    data=data,
    model_params=model_params_iter,
    training_params=training_params_iter,
    fit=etorch.fit_module,
    trials=1,
    shuffle_folds=True,
    verbose=True,
)

In [None]:
# Plot the cross-validation results; figures are written under PATH_FIGURES.
etorch.plotting.plot_result(path_figures=PATH_FIGURES, **cv_results)

In [None]:
# Use the first model returned by the CV grid for the rest of the notebook.
# NOTE(review): with k folds this is presumably the first fold's model — confirm.
flow = cv_results["models"][0]

In [None]:
# Switch every residual block to exact trace computation for the evaluation
# cells below (presumably replacing the stochastic estimate used in training).
for bijector in flow.bijector.bijectors:
    bijector.model.iresblock.exact_trace = True

In [None]:
# Sanity check: mean log-density under the flow of base-distribution noise
# versus the dataset. All 100 noise samples are scored (previously only the
# first one was, which made the ``.mean()`` meaningless).
noise = base_dist.sample([100])
print("Log vals:")

# Evaluation only: no autograd graph is needed.
with torch.no_grad():
    print("Noise :", flow.log_prob(noise).mean().item())
    # NOTE(review): ``data`` holds every walk+run frame, validation folds included.
    print("Train data:", flow.log_prob(data[:][0]).mean().item())

In [None]:
# Draw a sample of shape [1] from the trained flow.
# NOTE(review): ``sample`` is not used by any later cell shown here.
sample = flow.sample([1])

In [None]:
# Pick the two endpoint frames: frame ``i`` of the (normalised) walk data and
# frame ``j`` of the (normalised) run data, each kept with a leading batch
# axis of size 1.
i, j = 50, -2
# Index-then-unsqueeze works for every valid index, including negative ones;
# a ``k : k + 1`` slice would return an empty batch for k == -1.
x_first_frame = data_walk[i][0].unsqueeze(0)
x_second_frame = data_run[j][0].unsqueeze(0)

In [None]:
# interpolate linearly in latent space between the codes of the two frames.
# With t taken from ``line``: t = 0 gives z2 (second frame), t = 1 gives z1.
n_interp_steps = 240
# Inference only: no_grad keeps z1/z2 free of autograd history, so later
# conversions of the decoded path do not inherit a graph.
with torch.no_grad():
    # presumably maps data to latent space (inverse of bijector.forward) — confirm
    z1 = flow.rnormalize(x_first_frame)
    z2 = flow.rnormalize(x_second_frame)
line_ = torch.linspace(0, 1, n_interp_steps)
# Column vector (steps, 1) so it broadcasts against the single-row codes.
line = torch.unsqueeze(line_, 1)
interp_line_z = z1 * line + z2 * (1 - line)

In [None]:
# Decode the latent path back to (normalised) feature space without building
# an autograd graph, so the result can later be converted to numpy safely.
with torch.no_grad():
    x_interpolated = flow.bijector.forward(interp_line_z)

# Baseline: straight-line interpolation directly in feature space, using the
# same t convention as the latent path.
x_interpolated_lin = x_first_frame * line + x_second_frame * (1 - line)

In [None]:
# Compare the model density along the latent-space and feature-space paths.
# log_prob is averaged over ``n_evals`` calls (presumably to smooth any
# stochastic log-det estimate) before exponentiating.
n_evals = 10
with torch.no_grad():
    lat_log_prob = torch.stack(
        [flow.log_prob(x_interpolated) for _ in range(n_evals)]
    ).mean(dim=0)
    lin_log_prob = torch.stack(
        [flow.log_prob(x_interpolated_lin) for _ in range(n_evals)]
    ).mean(dim=0)

# Explicit figure/axes so labels, legend and savefig all target the same plot.
fig, ax = plt.subplots()
ax.plot(line_, torch.exp(lat_log_prob), "-", label="Latent space interpolation")
ax.plot(line_, torch.exp(lin_log_prob), "-.", label="Feature space interpolation")
ax.set_xlabel("$t$")
ax.set_ylabel("$p_{T(Z)}$")
ax.legend()
fig.savefig(
    os.path.join(
        PATH_FIGURES,
        "interpolation_prob.pdf",
    ),
    bbox_inches="tight",
    pad_inches=0,
)
plt.show()

In [None]:
skel = copy.deepcopy(run_skeletons[0])


def _to_animation(x_norm, template):
    """Return a deep copy of ``template`` whose motion is replaced by ``x_norm``.

    ``x_norm`` lives in the normalised feature space. It is mapped back with the
    notebook-level ``std``/``mean`` and detached from autograd before being
    passed to ``sf.utils.data_to_motion_array``. ``template`` is not modified.
    """
    anim_out = copy.deepcopy(template)
    x_denorm = (x_norm * std + mean).detach()
    anim_out.from_numpy_array(sf.utils.data_to_motion_array(x_denorm))
    return anim_out


# The first walk clip serves as the container for each reconstructed motion.
anim_first = _to_animation(x_first_frame, walk_animations[0])
anim_second = _to_animation(x_second_frame, walk_animations[0])
anim_test = _to_animation(x_interpolated, walk_animations[0])
anim_test_lin = _to_animation(x_interpolated_lin, walk_animations[0])

In [None]:
# Render the first endpoint (walk frame) with mayavi; nothing is saved.
anim = mayavi_animate(
    skel,
    anim_first,
    save_path=None,
    frame_limit=-1,
    fixed_cam=False,
    continuous=False,
    offset=[0, 0, 0],
)

In [None]:
# Render the second endpoint (run frame) with mayavi; nothing is saved.
anim = mayavi_animate(
    skel,
    anim_second,
    save_path=None,
    frame_limit=-1,
    fixed_cam=False,
    continuous=False,
    offset=[0, 0, 0],
)

In [None]:
# Render the motion decoded from the latent-space interpolation; nothing is saved.
anim = mayavi_animate(
    skel,
    anim_test,
    save_path=None,
    frame_limit=-1,
    fixed_cam=False,
    continuous=False,
    offset=[0, 0, 0],
)

In [None]:
# Render the feature-space (straight-line) interpolation; nothing is saved.
anim = mayavi_animate(
    skel,
    anim_test_lin,
    save_path=None,
    frame_limit=-1,
    fixed_cam=False,
    continuous=False,
    offset=[0, 0, 0],
)