In [None]:
%load_ext autoreload
%autoreload 2

### Test wrapper on moons

Test that we can use `flowtorch` with transformations from `nflows`.

The example is modified from [nflows](https://github.com/bayesiains/nflows).

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib_inline.backend_inline import set_matplotlib_formats
from tqdm import tqdm
import os
from itertools import chain
import seaborn as sns
import torch

import flowtorch.distributions as ftdist

from nflows.flows.base import Flow
from nflows.distributions.normal import StandardNormal
from nflows.transforms.base import CompositeTransform
from nflows.transforms.autoregressive import MaskedAffineAutoregressiveTransform
from nflows.transforms.permutations import ReversePermutation

from animation import fetch_animation_id_set, fetch_animations
from so3.curves import move_origin_to_zero as move_origin_to_zero_so3
from so3.transformations import hatinv
from linear import animation_to_SO3
from linear.curves import move_origin_to_zero

from deepthermal.validation import (
    create_subdictionary_iterator,
    k_fold_cv_grid,
    add_dictionary_iterators,
)
from deepthermal.FFNN_model import fit_FFNN
from deepthermal.plotting import plot_result

import shapeflow as sflow

# Make runs reproducible by seeding torch's global random number generator.
# Note: `seed` holds the object returned by torch.manual_seed (a
# torch.Generator), not the integer 0.
seed = torch.manual_seed(0)

# Render inline figures in vector formats (PDF and SVG).
set_matplotlib_formats("pdf", "svg")

In [None]:
# Load the mocap clips and convert them to Euler-angle arrays.
# All subjects are assumed to share the same skeleton.
print("Loading mocap data:")

# Walking clips: keep the second element of every item yielded by
# fetch_animations (presumably the animation object -- TODO confirm).
walk_subjects = ["07", "08"]
walk_animations = [
    clip[1]
    for subject in walk_subjects
    for clip in fetch_animations(100, subject_file_name=(subject + ".asf"))
]

# Total number of frames contained in the first 18 walking clips.
walk_animations_train_frame = 0
for walk_clip in walk_animations[:18]:
    walk_animations_train_frame += len(walk_clip.get_frames())

# Running clips, collected the same way.
run_subjects = ["09"]
run_animations = [
    clip[1]
    for subject in run_subjects
    for clip in fetch_animations(100, subject_file_name=(subject + ".asf"))
]

print("Convert to array:")
walk_angle_array = sflow.utils.animation_to_eulers(walk_animations)
run_angle_array = sflow.utils.animation_to_eulers(run_animations)

In [None]:
# skel, anim, desc = fetch_animations(1, subject_file_name=walk_subjects[0]+".asf")
# skel2, anim, desc = fetch_animations(1, subject_file_name=run_subjects[0]+".asf")
# skel.bones.keys() == skel2.bones.keys()
# for bone_name, bone_obj in skel.bones.items():
#     pass
#     print(bone_name, ":")
#     print([dof for dof in bone_obj.dof], "\n")

In [None]:
# Persist the converted arrays to .npy files (obtaining them is slow) and
# immediately read them back as float32 tensors.
# NOTE(review): the arrays are still recomputed and re-saved on every run;
# nothing in this cell skips the work when the files already exist.

np.save("walk_angle_array.npy", walk_angle_array)
walk_angle_tensor = torch.tensor(
    np.load("walk_angle_array.npy", allow_pickle=False)
).float()


np.save("run_angle_array.npy", run_angle_array)
run_angle_tensor = torch.tensor(
    np.load("run_angle_array.npy", allow_pickle=False)
).float()

In [None]:
# Sanity check: number of frames in the first 18 walking clips, the largest
# value in the walking tensor, and the tensor's shape.
# `walk_angle_tensor` is the walking-data tensor this notebook actually
# defines; `walk_frames` does not exist on a fresh kernel.
print(
    walk_animations_train_frame,
    torch.max(walk_angle_tensor),
    walk_angle_tensor.shape,
)

In [None]:
# Split the walking frames into two disjoint datasets at the frame index
# `walk_animations_train_frame` (the frame count of the first 18 clips).
# Both slices come from `walk_angle_tensor`, the only walking tensor defined
# in this notebook (`walk_angle_8_tensor` does not exist on a fresh kernel).
# NOTE(review): training uses the frames *after* the split index and
# validation the frames *before* it, although the variable name suggests the
# leading frames were meant for training -- confirm the intended direction.
data = torch.utils.data.TensorDataset(walk_angle_tensor[walk_animations_train_frame:])
data_val = torch.utils.data.TensorDataset(
    walk_angle_tensor[:walk_animations_train_frame]
)

In [None]:
# Set up model
num_layers = 5
# Number of features per frame, read from the second dimension of the running
# data tensor (`run_frames` is not defined anywhere in this notebook).
# The notebook assumes walking and running data share the same skeleton.
event_shape = run_angle_tensor.shape[1]

# Base distribution: independent standard normals, with the last dimension
# treated as a single event.
base_dist = torch.distributions.Independent(
    torch.distributions.Normal(torch.zeros(event_shape), torch.ones(event_shape)), 1
)

# Alternate a feature-reversing permutation with a masked affine
# autoregressive transform, `num_layers` times.
transforms = []
for _ in range(num_layers):
    transforms.append(ReversePermutation(features=event_shape))
    transforms.append(
        MaskedAffineAutoregressiveTransform(features=event_shape, hidden_features=4)
    )
transform = CompositeTransform(transforms)
# Wrap the composed nflows transform so it can be passed as the `bijector`
# of ftdist.Flow further down.
bijector = sflow.WrapInverseModel(model=transform)

# flow = ftdist.Flow(bijector=bijector, base_dist=base_dist)
# optimizer = optim.Adam(flow.parameters())

In [None]:
#######
# Output directory for the figures produced by this experiment.
DIR = "../figures/curve_1/"
SET_NAME = "walk_1_skaled"
PATH_FIGURES = os.path.join(DIR, SET_NAME)
# exist_ok=True creates the directory only when it is missing, without a
# separate existence check; it raises if the path exists but is not a directory.
os.makedirs(PATH_FIGURES, exist_ok=True)
########


FOLDS = 1


def lr_scheduler(optim):
    """Build a ReduceLROnPlateau scheduler for the given optimizer.

    The learning rate is multiplied by 0.2 after 5 epochs without a decrease
    of the monitored quantity (mode="min").
    """
    return torch.optim.lr_scheduler.ReduceLROnPlateau(
        optim, mode="min", factor=0.2, patience=5, verbose=True
    )


# def get_flow(ji)
# Model class and the keyword arguments it is constructed with.
MODEL_PARAMS = {"model": [ftdist.Flow]}
MODEL_PARAMS_EXPERIMENT = {
    "bijector": [bijector],
    "base_dist": [base_dist],
}

TRAINING_PARAMS = {
    "batch_size": [100],
    "regularization_param": [0.0],
    "compute_loss": [sflow.monte_carlo_dkl_loss],
}
# extend the previous dict with the zip of this
TRAINING_PARAMS_EXPERIMENT = {
    "verbose_interval": [1000],
    "optimizer": ["ADAM"],
    "num_epochs": [1000],
    "learning_rate": [0.01],
    # "lr_scheduler": [lr_scheduler],
}

In [None]:
# Build the parameter iterators consumed by the cross-validation grid below.
# Each *_EXPERIMENT dict is expanded with product=False and then combined with
# the iterator built from its base dict via add_dictionary_iterators.
# NOTE(review): product=False presumably zips the value lists instead of
# taking their Cartesian product -- confirm against deepthermal.validation.
model_params_iter_1 = create_subdictionary_iterator(MODEL_PARAMS)
# model_params_iter = chain.from_iterable((model_params_iter_1, model_params_iter_2))

model_exp_iter = create_subdictionary_iterator(MODEL_PARAMS_EXPERIMENT, product=False)
exp_model_params_iter = add_dictionary_iterators(model_exp_iter, model_params_iter_1)

training_params_iter = create_subdictionary_iterator(TRAINING_PARAMS)
training_exp_iter = create_subdictionary_iterator(
    TRAINING_PARAMS_EXPERIMENT, product=False
)
exp_training_params_iter = add_dictionary_iterators(
    training_exp_iter, training_params_iter
)

In [None]:
# Run the grid of model/training configurations through k_fold_cv_grid using
# fit_FFNN as the fitting routine. FOLDS is 1 and a separate validation set
# (data_val) is passed explicitly.
# NOTE(review): the exact meaning of `partial`, `trials` and `shuffle_folds`
# is defined in deepthermal.validation and is not visible here.
cv_results = k_fold_cv_grid(
    model_params=exp_model_params_iter,
    fit=fit_FFNN,
    training_params=exp_training_params_iter,
    data=data,
    folds=FOLDS,
    val_data=data_val,
    verbose=True,
    trials=1,
    partial=True,
    shuffle_folds=False,
)

Test that the wrapper works

In [None]:
# Plot the cross-validation results; PATH_FIGURES is passed as the figure
# location and the entries of cv_results are forwarded as keyword arguments.
plot_result(
    path_figures=PATH_FIGURES,
    **cv_results,
    # plot_function=plot_model_1d,
    # function_kwargs=plot_kwargs,
)

In [None]:
# Take the first model stored in the cross-validation results
# (presumably indexed as [configuration][fold/trial] -- TODO confirm).
flow = cv_results["models"][0][0]
# Draw one sample from the flow; `.data` drops the autograd history.
sample = flow.sample([1]).data

In [None]:
# Compare the mean log-probability the trained flow assigns to base-distribution
# noise, the running data, and the validation/training walking data.
noise = base_dist.sample([100])
print("Log plots:")

# Evaluation only: disable gradient tracking so no autograd graph is built
# for these full-dataset forward passes.
with torch.no_grad():
    print("Noise :", flow.log_prob(noise).mean().item())
    print("Run data :", flow.log_prob(run_angle_tensor).mean().item())
    # dataset[:][0] is the full tensor wrapped by the TensorDataset.
    print("Validation data:", flow.log_prob(data_val[:][0]).mean().item())
    print("Train data:", flow.log_prob(data[:][0]).mean().item())