In [None]:
# Automatically reload edited Python modules (e.g. the local shapeflow package)
# before each cell executes, so library changes are picked up without a restart.
%load_ext autoreload
%autoreload 2

### Test the wrapper on mocap walking data

Test that we can use `flowtorch` with transformations from `nflows`.

This example is modified from the moons example in [nflows](https://github.com/bayesiains/nflows).

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib_inline.backend_inline import set_matplotlib_formats
from tqdm import tqdm
import os
from itertools import chain
import seaborn as sns
import torch
import copy

import flowtorch as ft
import flowtorch.distributions as ftdist
import nflows
from nflows.flows.base import Flow
from nflows.distributions.normal import StandardNormal
from nflows.transforms.base import CompositeTransform
from nflows.transforms.autoregressive import MaskedAffineAutoregressiveTransform
from nflows.transforms.permutations import ReversePermutation
import normflow as nf

from signatureshape.animation import fetch_animation_id_set, fetch_animations
from signatureshape.animation.src.mayavi_animate import mayavi_animate

from deepthermal.validation import (
    create_subdictionary_iterator,
    k_fold_cv_grid,
    add_dictionary_iterators,
)
from deepthermal.FFNN_model import fit_FFNN
from deepthermal.plotting import plot_result

import shapeflow as sf
import shapeflow.normalizing_flows

# make reproducible: seed every global RNG the notebook may draw from
SEED = 0
seed = torch.manual_seed(SEED)  # torch global RNG (e.g. base_dist.sample)
# numpy's global RNG, in case fold shuffling or project helpers rely on it
np.random.seed(SEED)

# better formats
set_matplotlib_formats("pdf", "svg")

In [None]:
# Fetch CMU mocap clips; all subjects are assumed to share the same skeleton.
print("Loading mocap data:")


def _collect_clips(subjects, prefix):
    """Collect the clips whose description starts with ``prefix``.

    For each subject, calls ``fetch_animations(100, subject_file_name=<subject>.asf)``
    and keeps the items whose description (index 2) starts with ``prefix``.
    Returns two parallel lists: the skeletons (index 0) and the animations (index 1).
    """
    skeletons, animations = [], []
    for subject in subjects:
        for clip in fetch_animations(100, subject_file_name=subject + ".asf"):
            if clip[2][: len(prefix)] == prefix:
                skeletons.append(clip[0])
                animations.append(clip[1])
    return skeletons, animations


# walk data (only the animations are kept)
walk_subjects = ["07", "08", "35", "16"]
walk_animations = _collect_clips(walk_subjects, "walk")[1]
# total number of frames in the first 18 walk clips
walk_animations_train_frame = sum(
    len(anim.get_frames()) for anim in walk_animations[:18]
)

# run data (skeletons are kept too)
run_subjects = ["09", "16", "35"]
run_skeletons, run_animations = _collect_clips(run_subjects, "run")

print("Convert to array:")
walk_angle_array, run_angle_array = (
    sf.utils.animation_to_eulers(
        clips, reduce_shape=False, remove_root=True, deg2rad=True
    )
    for clips in (walk_animations, run_animations)
)

In [None]:
# Disabled scratch code: animate a run skeleton driven by the first walk clip's angles.
# NOTE(review): the import path below differs from the one used at the top of the
# notebook (signatureshape.animation.src.mayavi_animate), and it pairs
# run_animations[id] with walk_angle_array[0] — confirm before re-enabling.
# try animating this
# print("len: ", len(run_angle_array))
# id = 0
# from animation.src.mayavi_animate import mayavi_animate
# from animation.src.animation import Animation
# anim, skel = run_animations[id], run_skeletons[id]
# anim.from_numpy_array( walk_angle_array[0].T)
# mayavi_animate(skel, anim, offset=anim._offset, continuous=True, fixed_cam = False, frame_limit = -1, save_path = None)

In [None]:
# Disabled scratch code: compare the bone sets of a walk and a run skeleton
# and print each bone's degrees of freedom.
# NOTE(review): elsewhere the result of fetch_animations(...) is iterated as a
# sequence of (skeleton, animation, description) items, so unpacking it directly
# probably needs an index such as [0] — confirm before re-enabling.
# skel, anim, desc = fetch_animations(1, subject_file_name=walk_subjects[0]+".asf")
# skel2, anim, desc = fetch_animations(1, subject_file_name=run_subjects[0]+".asf")
# skel.bones.keys() == skel2.bones.keys()
# for bone_name, bone_obj in skel.bones.items():
#     pass
#     print(bone_name, ":")
#     print([dof for dof in bone_obj.dof], "\n")

In [None]:
# Persist the converted angle arrays, because the mocap fetch and conversion
# above is slow, and build float32 tensors directly from the in-memory arrays.
# Re-reading the files that were just written added disk I/O without changing
# the values.
# NOTE(review): nothing reads these .npy files back to skip the fetch; load them
# with np.load(..., allow_pickle=False) if you want to use them as a cache.
np.save("walk_angle_array_full.npy", walk_angle_array)
walk_angle_tensor = torch.tensor(np.asarray(walk_angle_array), dtype=torch.float32)

np.save("run_angle_array_full.npy", run_angle_array)
run_angle_tensor = torch.tensor(np.asarray(run_angle_array), dtype=torch.float32)

In [None]:
# Truncate walk and run clips to a common number of frames (axis 1).
pre_shape_walk = walk_angle_tensor.shape
pre_shape_run = run_angle_tensor.shape
num_frames = min(pre_shape_walk[1], pre_shape_run[1])

# Flattened (clips, num_frames * axis-2 size) target shapes; the flattening
# reshape itself is not applied in this cell.
post_shape_walk = (pre_shape_walk[0], num_frames * pre_shape_walk[2])
post_shape_run = (pre_shape_run[0], num_frames * pre_shape_run[2])

walk_angles, run_angles = (
    angles[:, :num_frames] for angles in (walk_angle_tensor, run_angle_tensor)
)
walk_samples_shapes_pre, run_angles_shapes_pre = walk_angles.shape, run_angles.shape

In [None]:
# Standardize data (disabled: the datasets below are built from unscaled angles).
# NOTE(review): the first variant computes walk_mean / walk_std from
# run_angle_tensor rather than the walk data — confirm before re-enabling.
# walk_mean = torch.mean(run_angle_tensor, dim=(0, 1), keepdim=True)
# walk_std = torch.std(run_angle_tensor, dim=(0, 1), keepdim=True)
# walk_standard = (walk_angles - walk_mean) / walk_std
# walk_mean = torch.mean(walk_angle_tensor,  dim=0)
# walk_std = torch.std(walk_angle_tensor,  dim=0)

# walk_angle_tensor_scaled = (walk_angle_tensor - walk_mean)/walk_std

In [None]:
def _flatten_clips(angle_tensor):
    """Flatten each clip (row along axis 0) into one feature vector.

    Axes 1 and 2 are swapped before flattening, so each row stores the full
    axis-1 series for the first axis-2 index, then the second, and so on.
    """
    shape = angle_tensor.shape
    return angle_tensor.transpose(1, 2).reshape(shape[0], shape[1] * shape[2])


# Keep the original 3-D shapes around for undoing the flattening later.
unflattened_shape = walk_angle_tensor.shape
unflattened_shape_run = run_angle_tensor.shape
walk_angles_reshaped = _flatten_clips(walk_angle_tensor)
run_angles_reshaped = _flatten_clips(run_angle_tensor)

In [None]:
# Inspect the shape of the frame-truncated walk tensor.
walk_angles.size()

In [None]:
# Wrap the flattened clips as single-tensor datasets (one row per clip).
n_walk_features = unflattened_shape[1] * unflattened_shape[2]
n_run_features = unflattened_shape_run[1] * unflattened_shape_run[2]
data = torch.utils.data.TensorDataset(
    walk_angles_reshaped.reshape(unflattened_shape[0], n_walk_features)
)
data_run = torch.utils.data.TensorDataset(
    run_angles_reshaped.reshape(unflattened_shape_run[0], n_run_features)
)

In [None]:
# --- output location for figures ---------------------------------------------
DIR = "../figures/simple_frame/"
SET_NAME = "walk_1_auto"
PATH_FIGURES = os.path.join(DIR, SET_NAME)
# exist_ok=True creates the directory only when missing, without the
# check-then-create race of os.path.exists() followed by os.makedirs().
os.makedirs(PATH_FIGURES, exist_ok=True)

# --- base distribution -------------------------------------------------------
# Number of features in one flattened clip (an int, despite the name).
event_shape = data[0][0].shape[0]
# Standard normal over every feature; Independent(..., 1) treats the last
# dimension as one event, so log_prob sums over the features.
base_dist = torch.distributions.Independent(
    torch.distributions.Normal(torch.zeros(event_shape), torch.ones(event_shape)), 1
)

# NOTE(review): FOLDS is not passed to the cross-validation call, which sets
# its own folds argument — confirm which value is intended.
FOLDS = 1


def lr_scheduler(optim):
    """Build a ReduceLROnPlateau scheduler for ``optim``.

    Multiplies the learning rate by 0.2 once the minimised metric has gone more
    than 5 scheduler steps without improving. Currently only referenced by the
    commented-out ``"lr_scheduler"`` entry in TRAINING_PARAMS_EXPERIMENT.
    """
    # NOTE(review): ``verbose`` is deprecated for LR schedulers in recent
    # PyTorch releases; drop it when upgrading.
    return torch.optim.lr_scheduler.ReduceLROnPlateau(
        optim, mode="min", factor=0.2, patience=5, verbose=True
    )


# --- model grid --------------------------------------------------------------
# Settings shared by every run (single-valued lists).
MODEL_PARAMS = {
    "model": [sf.utils.get_flow],
    "get_transform": [sf.utils.get_transform_nflow],
    "base_dist": [base_dist],
    "Transform": [nflows.flows.MaskedAutoregressiveFlow],
}
# Architecture settings varied between experiments.
MODEL_PARAMS_EXPERIMENT = {
    "num_blocks_per_layer": [2],
    "num_layers": [5],
    "hidden_features": [24],
}

# --- training grid -----------------------------------------------------------
TRAINING_PARAMS = {
    "batch_size": [500],
    "regularization_param": [0.0],
    "compute_loss": [sf.monte_carlo_dkl_loss],
    "post_step": [sf.get_post_step_lipchitz(5)],
}
# extend the previous dict with the zip of this
TRAINING_PARAMS_EXPERIMENT = {
    "verbose_interval": [100],
    "optimizer": ["ADAM"],
    "num_epochs": [200],
    "learning_rate": [0.01],
    # "lr_scheduler": [lr_scheduler],
}

In [None]:
# Build the model and training parameter iterators consumed by k_fold_cv_grid.

# Model: shared settings, then experiment settings (product=False), merged.
model_params_iter_1 = create_subdictionary_iterator(MODEL_PARAMS)
model_exp_iter = create_subdictionary_iterator(
    MODEL_PARAMS_EXPERIMENT,
    product=False,
)
exp_model_params_iter = add_dictionary_iterators(
    model_exp_iter,
    model_params_iter_1,
)

# Training: shared settings, then experiment settings (product=True), merged.
training_params_iter = create_subdictionary_iterator(TRAINING_PARAMS)
training_exp_iter = create_subdictionary_iterator(
    TRAINING_PARAMS_EXPERIMENT,
    product=True,
)
exp_training_params_iter = add_dictionary_iterators(
    training_exp_iter,
    training_params_iter,
)

In [None]:
# Cross-validate every (model, training) parameter combination on the walk data.
# NOTE(review): the FOLDS constant from the config cell is not used here.
cv_options = {
    "folds": 5,
    "verbose": True,
    "trials": 1,
    "partial": True,
    "shuffle_folds": True,
}
cv_results = k_fold_cv_grid(
    model_params=exp_model_params_iter,
    training_params=exp_training_params_iter,
    fit=fit_FFNN,
    data=data,
    **cv_options,
)

Test that the wrapper works: plot the cross-validation results and inspect the trained flow.

In [None]:
# Plot the cross-validation results into PATH_FIGURES.
plot_result(path_figures=PATH_FIGURES, **cv_results)

In [None]:
# Take the first trained model; presumably indexed [configuration][fold/trial] — confirm.
trained_models = cv_results["models"]
flow = trained_models[0][0]

In [None]:
# Compare mean log-probabilities under the trained flow for base-distribution
# noise and for one dataset clip. This is evaluation only, so no autograd graph is built.
noise = base_dist.sample([100])
print("Mean log-probabilities:")
with torch.no_grad():
    print("Noise :", flow.log_prob(noise).mean().item())
    # NOTE(review): with shuffle_folds=True, clip 0 may sit in a validation fold.
    print("Train data:", flow.log_prob(data[0:1][0]).mean().item())

base_dist

In [None]:
# Interpolate halfway between two walking clips through the flow: map both
# through flow.bijector.inverse, average the results, then apply flow.bijector.
data_1 = data[0:1][0]
data_2 = data[24:25][0]
walk_sample_near_ = flow.bijector(
    flow.bijector.inverse(data_1) * 0.5 + flow.bijector.inverse(data_2) * 0.5
)
# Data-space baseline for comparison:
# walk_sample_near_ = data_1 * 0.5 + data_2 * 0.5

# Dataset rows were flattened after torch.swapaxes(walk_angle_tensor, 1, 2), i.e.
# in (axis-2, axis-1) = (angles, frames) order. Undo that layout and transpose back
# to (frames, angles). Reshaping the flat vector straight to (frames, angles)
# would interleave angles and frames. It would also fail whenever the frame count
# truncated to match the run clips is smaller than the walk frame count.
num_walk_frames, num_walk_angles = unflattened_shape[1], unflattened_shape[2]
walk_sample_near = walk_sample_near_.reshape(num_walk_angles, num_walk_frames).T

# F.pad with (0, 0, 3, 0) leaves the last dim alone and prepends 3 zero rows to
# the second-to-last dim of the (angles, frames) view. This presumably re-inserts
# root DOFs dropped by remove_root=True — TODO confirm.
pads = (0, 0, 3, 0)
walk_sample_near = (
    torch.nn.functional.pad(walk_sample_near.T, pads, "constant", 0).detach().numpy()
)
flow.log_prob(walk_sample_near_)

In [None]:
# Log-probability of the interpolated sample under the trained flow.
log_prob_near = flow.log_prob(walk_sample_near_)
log_prob_near

In [None]:
# Work on deep copies so that loading new angles into the animation later does
# not modify the fetched originals.
skel, test_anim = copy.deepcopy(run_skeletons[0]), copy.deepcopy(walk_animations[0])

In [None]:
# Load the padded interpolated sample into the copied animation.
frame, pads = 0, (0, 0, 3, 0)
test_anim.from_numpy_array(walk_sample_near)

In [None]:
# Animate the interpolated sample on the copied skeleton with a zero offset.
anim = mayavi_animate(
    skel,
    test_anim,
    offset=[0, 0, 0],
)

In [None]:
# Animate a recorded walk clip for comparison, passing save_path="test.svg".
# NOTE(review): indexes clip 64, so at least 65 walk clips must have been fetched.
walk_anim_options = dict(
    offset=[0, 0, 0],
    continuous=True,
    fixed_cam=False,
    frame_limit=-1,
    save_path="test.svg",
)
anim = mayavi_animate(skel, walk_animations[64], **walk_anim_options)

In [None]:
# Histogram of absolute values in row 7 of the generated animation's array.
# NOTE(review): the original referenced an undefined `anim1` (NameError on a
# fresh kernel). This assumes the intended object is `test_anim`, the animation
# loaded from the interpolated sample — confirm.
plt.hist(np.abs(test_anim.to_numpy_array()[7]))
plt.title("Row 7 of test_anim.to_numpy_array()")
plt.xlabel("absolute value")
plt.ylabel("count");

In [None]:
# Scratch: three evenly spaced values in [0, 1] as a (3, 1) column.
np.linspace(0, 1, 3)[:, np.newaxis]