In [None]:
# re-import edited modules automatically before each cell runs (IPython autoreload extension)
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import copy
import os
from tqdm import tqdm
from itertools import chain
import matplotlib.pyplot as plt
from matplotlib_inline.backend_inline import set_matplotlib_formats
import seaborn as sns

import torch
import torch.distributions as dist

from signatureshape.animation import fetch_animation_id_set, fetch_animations
from signatureshape.animation.src.mayavi_animate import mayavi_animate

import extratorch as etorch
import shapeflow as sf

# make reproducible: seed both the torch and the numpy global RNGs
SEED = 0
seed = torch.manual_seed(SEED)  # `seed` keeps the torch.Generator returned by manual_seed
np.random.seed(SEED)  # numpy's global RNG is independent of torch's and must be seeded too

# better formats: vector (pdf/svg) output for inline figures
set_matplotlib_formats("pdf", "svg")

In [None]:
# fetch data as eulers
# we assume all have the same skeleton
N_FETCH = 100  # first argument passed to fetch_animations for every subject
EULER_KWARGS = dict(
    reduce_shape=False,
    remove_root=True,
    deg2rad=True,
    max_frame_count=240,
)


def collect_animations(subjects, prefix, n_animations=N_FETCH):
    """Fetch the clips of ``subjects`` whose description starts with ``prefix``.

    Parameters
    ----------
    subjects : iterable of str
        Subject ids; ``"<id>.asf"`` is passed as ``subject_file_name``.
    prefix : str
        Required start of a clip's description (third tuple element).
    n_animations : int
        First argument forwarded to ``fetch_animations``.

    Returns
    -------
    (list, list)
        Skeletons (first tuple element) and animations (second tuple element)
        of the matching clips, in fetch order.
    """
    skeletons, animations = [], []
    for subject in subjects:
        for clip in fetch_animations(n_animations, subject_file_name=subject + ".asf"):
            if clip[2][: len(prefix)] == prefix:
                skeletons.append(clip[0])
                animations.append(clip[1])
    return skeletons, animations


print("Loading mocap data:")
# walk data
walk_subjects = ["07", "08", "35", "16"]
walk_skeletons, walk_animations = collect_animations(walk_subjects, "walk")

# run data
run_subjects = ["09", "16", "35"]
run_skeletons, run_animations = collect_animations(run_subjects, "run")

print("Convert to array:")
walk_angle_array = sf.utils.animation_to_eulers(walk_animations, **EULER_KWARGS)
run_angle_array = sf.utils.animation_to_eulers(run_animations, **EULER_KWARGS)

In [None]:
# NOTE(review): disabled inspection snippet. When uncommented it loads one clip of the
# first run subject and prints each bone's name and its `dof` entries; it has no effect
# on the rest of the notebook.
# skel2, anim, desc = fetch_animations(1, subject_file_name=run_subjects[0]+".asf")[0]
# for bone_name, bone_obj in skel2.bones.items():
#     pass
#     print(bone_name, ":")
#     print([dof for dof in bone_obj.dof], "\n")

In [None]:
# float32 torch copies of the euler-angle arrays, laid out as (motion, time, joint)
walk_angle_tensor = torch.tensor(walk_angle_array, dtype=torch.float32)
run_angle_tensor = torch.tensor(run_angle_array, dtype=torch.float32)

pre_shape_walk, pre_shape_run = walk_angle_tensor.shape, run_angle_tensor.shape
# number of frames available in both sets
num_frames = min(pre_shape_walk[1], pre_shape_run[1])

# indices of the joint channels whose walk angles change at least once in time:
# the absolute frame-to-frame differences, summed over motions and frames, are positive
nonzero = (
    torch.diff(walk_angle_tensor, dim=1)
    .abs()
    .sum(dim=(0, 1))
    .gt(0.0)
    .nonzero()
    .flatten()
)

In [None]:
# insert best shapes here: positions (inside `nonzero`) of the joint channels to keep
skip_frames = 12
choosen = nonzero[[24, 26, 33, 34]]

# cut to the common length, keep every `skip_frames`-th frame and only the chosen joints
frame_sel = slice(None, num_frames, skip_frames)
walk_angles = walk_angle_tensor[:, frame_sel, choosen]
run_angles = run_angle_tensor[:, frame_sel, choosen]
wr_angles = torch.cat([walk_angles, run_angles])  # walk motions first, then run motions

# (time, joint) shape of one reduced animation
animation_shape = wr_angles.shape[-2:]
animation_shape

In [None]:
# standardise every (time, joint) entry with statistics pooled over all walk+run motions
std, mean = torch.std_mean(wr_angles, dim=0)
# a zero std (constant entry) or NaN std (e.g. a single motion) would turn the normalised
# data into inf/NaN and poison training; use a scale of 1 there instead, as scikit-learn's
# StandardScaler does. Those entries are constant, so they simply normalise to 0.
std = torch.where(std > 0, std, torch.ones_like(std))
wr_angles_norm = (wr_angles - mean) / std
run_angles_norm = (run_angles - mean) / std
walk_angles_norm = (walk_angles - mean) / std
std.shape

In [None]:
# reshape the normalised (motion, time, joint) angles into the layout fed to the flow;
# exactly one of the switches below is expected to be enabled
flatten = True
add_channel = False
make_frames = False
orig_shape_walk = walk_angles.shape
orig_shape_run = run_angles.shape
orig_shape_wr = wr_angles.shape

if add_channel:
    # (motion, 1, time, joint)
    walk_angles_nr = torch.unsqueeze(walk_angles_norm, 1)
    run_angles_nr = torch.unsqueeze(run_angles_norm, 1)  # was walk_angles_norm (copy-paste bug)
    wr_angles_nr = torch.unsqueeze(wr_angles_norm, 1)
elif flatten:
    # (motion, time * joint): one flat vector per motion
    walk_angles_nr = walk_angles_norm.flatten(start_dim=1)
    run_angles_nr = run_angles_norm.flatten(start_dim=1)
    wr_angles_nr = wr_angles_norm.flatten(start_dim=1)
elif make_frames:
    # (motion * time, joint): every frame becomes its own sample
    walk_angles_nr = walk_angles_norm.flatten(end_dim=1)
    run_angles_nr = run_angles_norm.flatten(end_dim=1)
    wr_angles_nr = wr_angles_norm.flatten(end_dim=1)
else:
    # fail here rather than with a NameError further down
    raise ValueError("Enable one of add_channel, flatten or make_frames.")
wr_angles_nr.shape

In [None]:
# one-hot class indicators q: column 0 = walk motion, column 1 = run motion
walk_priors = torch.cat(
    (torch.ones(len(walk_angles_nr)), torch.zeros(len(run_angles_nr)))
)
run_priors = 1 - walk_priors
q = torch.stack((walk_priors, run_priors), dim=1)

# priors with (almost) equal probability: 0.5 each, jittered by up to 0.1 so the two
# classes are not exactly symmetric. Built in torch (not numpy) so that it can be put into
# a TensorDataset, which calls `.size(0)` on its tensors, and so that torch.manual_seed
# makes the jitter reproducible.
eps = torch.rand(len(q)) * 0.1
priors = torch.full_like(q, 0.5)
priors[:, 1] += eps
priors[:, 0] -= eps
walk_priors.shape, priors.shape

In [None]:
# joint walk+run training set: inputs plus the prior tensor, which is supplied for both of
# the extra fields (NOTE(review): which field the fit routine treats as what — confirm)
data = torch.utils.data.TensorDataset(wr_angles_nr, priors, priors)
# per-class evaluation sets (inputs only)
data_walk = torch.utils.data.TensorDataset(walk_angles_nr)
data_run = torch.utils.data.TensorDataset(run_angles_nr)

In [None]:
#######
DIR = "../figures/cluster_shape/"
SET_NAME = "dim_selection_2"
PATH_FIGURES = os.path.join(DIR, SET_NAME)
# exist_ok: no check-then-create race, and a no-op when the folder already exists
os.makedirs(PATH_FIGURES, exist_ok=True)
########

# standard-normal base distribution with the shape of one input sample;
# Independent(..., 1) sums the log-probability over the last dimension
event_shape = data[0][0].shape
base_dist = dist.Independent(
    dist.Normal(loc=torch.zeros(event_shape), scale=torch.ones(event_shape)), 1
)


def lr_scheduler(optim):
    """Wrap ``optim`` in a ReduceLROnPlateau scheduler.

    The learning rate is halved after 50 consecutive non-improving ``step`` calls
    on a minimised metric.
    """
    # NOTE(review): `verbose` is deprecated in recent torch releases; kept for
    # compatibility with the installed version.
    return torch.optim.lr_scheduler.ReduceLROnPlateau(
        optim, mode="min", factor=0.5, patience=50, verbose=True
    )


MODEL_PARAMS = {
    "model": sf.nf.get_flow,
    "get_transform": sf.transforms.NDETransform,
    "base_dist": base_dist,
    "get_net": sf.models.CNN2D,
    "activation": "tanh",
    "inverse_model": True,
    "num_flows": 2,
    "sensitivity": "autograd",
    # "trace_estimator" : "hutch_trace"
}
EXTRA_M_PARAMS = {
    "kernel_size": (3, animation_shape[-1] - 1),
    "internal_shape": animation_shape,
    "n_hidden_layers": [4],
}

TRAINING_PARAMS = {
    "batch_size": [50],
    "compute_loss": [sf.nf.get_monte_carlo_elbo_loss(epsilon=1)],
    "verbose": True,
}
# extend the previous dict with the zip of this
EXTRA_T_PARAMS = {
    "optimizer": ["ADAM"],
    "num_epochs": [100],
    "learning_rate": [0.01],
    "lr_scheduler": [lr_scheduler],
}

In [None]:
# create iterators
def _params_iterator(base_params, extra_params):
    """Product-expand ``base_params``, zip-expand ``extra_params``, then combine them as a product."""
    return etorch.add_dictionary_iterators(
        etorch.create_subdictionary_iterator(base_params, product=True),
        etorch.create_subdictionary_iterator(extra_params, product=False),
        product=True,
    )


model_params_iter = _params_iterator(MODEL_PARAMS, EXTRA_M_PARAMS)
training_params_iter = _params_iterator(TRAINING_PARAMS, EXTRA_T_PARAMS)

In [None]:
# cross-validated grid run over the model/training parameter iterators
# (expensive: trains the flow models)
cv_results = etorch.k_fold_cv_grid(
    data=data,
    fit=etorch.fit_module,
    model_params=model_params_iter,
    training_params=training_params_iter,
    partial=True,
    shuffle_folds=False,
    copy_data=True,
    verbose=True,
)

In [None]:
# render the cross-validation results into the figure directory
etorch.plotting.plot_result(**cv_results, path_figures=PATH_FIGURES)

# keep the fitted class models of the first entry at both nesting levels
# (presumably first parameter combination / first fold — TODO confirm)
models = cv_results["models"][0][0]

In [None]:
def count_class_votes(motion_data):
    """Count the samples in ``motion_data`` preferred by each of the two class models.

    Each model's ``log_prob`` is evaluated once, so both counts come from the same
    log-likelihood values and the flows are not run twice.

    Returns
    -------
    (int, int)
        Number of samples with ``models[0].log_prob < models[1].log_prob``
        ("Class 1") and with ``models[0].log_prob > models[1].log_prob``
        ("Class 2"). Exact ties are counted in neither.
    """
    with torch.no_grad():  # evaluation only; no autograd graph needed
        log_p_0 = models[0].log_prob(motion_data)
        log_p_1 = models[1].log_prob(motion_data)
    return (
        torch.sum(log_p_0 < log_p_1).item(),
        torch.sum(log_p_0 > log_p_1).item(),
    )


motion_data = data_walk[:][0]
print("Walk data:")
walk_class_1, walk_class_2 = count_class_votes(motion_data)
print("Class 1:", walk_class_1, "Class 2:", walk_class_2)

motion_data = data_run[:][0]
print("Run data:")
run_class_1, run_class_2 = count_class_votes(motion_data)
print("Class 1 :", run_class_1, "Class 2:", run_class_2)

In [None]:
# log p_c(x) of every training sample (rows) under every class model (columns)
log_p_cond_c = torch.zeros(len(data), len(models))
with torch.no_grad():
    all_inputs = data[:][0]
    for class_idx in range(len(models)):
        log_p_cond_c[:, class_idx] = models[class_idx].log_prob(all_inputs)

# q-weighted log-likelihood, averaged over all (sample, class) entries; this equals
# 1/num_classes times the per-sample mean of sum_c q_c * log p_c(x)
print("q log p: ", (q * log_p_cond_c).mean().item())