In [None]:
# Reload edited Python modules (e.g. the project packages imported below) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
import os
import random
import sys
from itertools import chain
import torch
from torch.utils.data import TensorDataset
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd
from matplotlib_inline.backend_inline import set_matplotlib_formats

set_matplotlib_formats("pdf", "svg")

# The data-loading cell extends sys.path with "../../" before importing
# signatureshape, but signatureshape is already imported in this cell. Add the
# path here, before the first project import, so the notebook also runs top to
# bottom on a fresh kernel. The membership check keeps re-runs from appending
# duplicates.
if "../../" not in sys.path:
    sys.path.append("../../")

from deepthermal.FFNN_model import fit_FFNN, FFNN, init_xavier
from deepthermal.validation import (
    create_subdictionary_iterator,
    k_fold_cv_grid,
    add_dictionary_iterators,
)
from deepthermal.plotting import plot_result

import neural_reparam as nr

import experiments.curves as c1
import experiments.curves_2 as c2
from signatureshape.so3.dynamic_distance import (
    find_optimal_diffeomorphism,
    L2_metric,
    local_cost,
)

# make reproducible: seed every RNG the experiment may draw from
# (torch, numpy's legacy global RNG, and Python's `random`).
SEED = 0
seed = torch.manual_seed(SEED)  # torch's default Generator; name kept for compatibility
np.random.seed(SEED)
random.seed(SEED)

In [None]:
# Load data
import sys

# Idempotent: re-running this cell must not keep growing sys.path.
if "../../" not in sys.path:
    sys.path.append("../../")
from signatureshape.animation.animation_manager import fetch_animations, unpack
from signatureshape.so3.curves import move_origin_to_zero, dynamic_distance
from signatureshape.so3.helpers import crop_curve
from signatureshape.so3.dynamic_distance import (
    find_optimal_diffeomorphism,
    create_shared_parameterization,
)
from signatureshape.so3.transformations import skew_to_vector, SRVT
from signatureshape.so3 import animation_to_SO3


def load_so3_curve(animation_data, max_frames):
    """Turn one fetched animation into a cropped SO(3) curve moved to the origin.

    Parameters
    ----------
    animation_data
        One element as returned by ``fetch_animations`` (it is passed to ``unpack``).
    max_frames : int
        Forwarded as ``stop`` to ``crop_curve``. Presumably the number of
        frames kept from the start of the clip; TODO confirm.

    Returns
    -------
    tuple
        ``(curve, description)``: the result of ``move_origin_to_zero`` on the
        cropped curve, and the description returned by ``unpack``.
    """
    subject, animation, description = unpack(animation_data)
    curve_full = animation_to_SO3(subject, animation)
    curve = crop_curve(curve_full, stop=max_frames)
    return move_origin_to_zero(curve), description


print("Load data")

max_frame_count = 300
# Candidate clips; only indices 2 and 1 are used below.
animations = [
    fetch_animations(1, file_name="39_02.amc"),  # walk 6.5 steps
    fetch_animations(1, file_name="35_26.amc"),  # run/jog 3 steps
    fetch_animations(1, file_name="16_35.amc"),  # run/jog 3 steps
]

# Curve 0: 16_35.amc (run/jog).
# NOTE(review): this was previously labelled "walk", but the walk clip is
# index 0 (39_02.amc). Confirm which clip is intended.
c_0, desc0 = load_so3_curve(animations[2], max_frame_count)
print(desc0)

# Curve 1: 35_26.amc (run/jog).
c_1, desc1 = load_so3_curve(animations[1], max_frame_count)
print(desc1)
print(c_0.shape)
print(c_1.shape)

# SRV transforms of both curves, each on its own uniform grid on [0, 1]
# (the grid length is taken from axis 1 of the curve).
I0 = np.linspace(0, 1, c_0.shape[1])
I1 = np.linspace(0, 1, c_1.shape[1])
q_data_ = skew_to_vector(SRVT(c_0, I0))
r_data_ = skew_to_vector(SRVT(c_1, I1))
# Resample both transforms onto one shared parameterization I.
I, q_data, r_data = create_shared_parameterization(q0=q_data_, q1=r_data_, I0=I0, I1=I1)
shared_frames = I.shape[0]

In [None]:
# Count the entries of r_data_[0] (SRVT of c_1 in vector form) that are exactly zero.
np.count_nonzero(r_data_[0] == 0)

In [None]:
# Length of the first entry of r_data (the SRVT of c_1 after resampling onto the shared grid I).
len(r_data[0])

In [None]:
# Shape of entry 120 of r_data_ (SRVT of c_1 before resampling).
# NOTE(review): index 120 is hard-coded; this assumes r_data_ has more than 120 entries along axis 0.
r_data_[120].shape

In [None]:
# Setup env
data = (I, q_data, r_data)
N = len(I)


def _integer_cbrt(n):
    """Return the largest integer ``d`` with ``d ** 3 <= n`` for ``n >= 0``.

    ``int(n ** (1 / 3))`` truncates floating-point round-off. For example,
    ``64 ** (1 / 3) == 3.9999999999999996``, which gives 3 instead of 4 for a
    perfect cube. Start from the rounded float estimate and correct it with
    exact integer arithmetic.
    """
    d = int(round(n ** (1 / 3)))
    while d**3 > n:
        d -= 1
    while (d + 1) ** 3 <= n:
        d += 1
    return d


# Integer cube root of the number of shared grid points, used as the env depth.
depth = _integer_cbrt(N)


def get_env():
    """Build a fresh ``DiscreteReparamEnv`` over ``data`` with the module-level ``depth``.

    ``illegal_action_penalty=10`` is forwarded to the env unchanged.
    """
    return nr.reparam_env.DiscreteReparamEnv(
        data=data, depth=depth, illegal_action_penalty=10
    )


def local_cost_func(start, end):
    """Evaluate ``local_cost`` between grid points ``start`` and ``end`` on the shared grid ``I``.

    ``start`` and ``end`` are unpacked positionally, then the two SRV curves and
    the grid are passed as keywords.
    """
    return local_cost(*start, *end, q0=q_data, q1=r_data, I=I)


N

In [None]:
#  ######
DIR = "../figures/curve_motion_rl/"
# Build one env just to read its size and action count, instead of
# constructing a new env for each lookup.
_env_probe = get_env()
SET_NAME = f"dqn_4_{_env_probe.size}"
PATH_FIGURES = os.path.join(DIR, SET_NAME)
# exist_ok avoids the separate exists()/makedirs() check-then-create race.
os.makedirs(PATH_FIGURES, exist_ok=True)
########


FOLDS = 1


def lr_scheduler(optim):
    """Wrap ``optim`` in ``ReduceLROnPlateau`` (mode="min", factor=0.5, patience=100).

    Only referenced from the commented-out entry in ``TRAINING_PARAMS_EXPERIMENT``.
    """
    # NOTE(review): `verbose` may be deprecated for LR schedulers in newer PyTorch releases; check the installed version.
    return torch.optim.lr_scheduler.ReduceLROnPlateau(
        optim, mode="min", factor=0.5, patience=100, verbose=True
    )


MODEL_PARAMS = {
    "model": [FFNN],
    "input_dimension": [2],
    "output_dimension": [_env_probe.num_actions],
    "activation": ["relu"],
    "n_hidden_layers": [5],
}

# extend the previous dict with the zip of this (see product=False below)
MODEL_PARAMS_EXPERIMENT = {
    "neurons": [32],
}
TRAINING_PARAMS = {
    "get_env": [get_env],
    "batch_size": [1000],
    "initial_steps": [int(3e4)],
    "memory_size": [int(1e5)],
    # both scale with N, the number of shared grid points
    "update_every": [N],
    "DDQN": [True],
    "num_epochs": [2 * N],
    "verbose_interval": [20],
    "learning_rate": [0.01],
    "optimizer": ["ADAM"],
}
# extend the previous dict with the zip of this (see product=False below)
TRAINING_PARAMS_EXPERIMENT = {
    "epsilon": [0.01],
    # "lr_scheduler": [lr_scheduler],
}

In [None]:
# create iterators
# Model grid: the entries of MODEL_PARAMS (presumably expanded as a product,
# since the EXPERIMENT dicts explicitly pass product=False), merged with the
# zipped entries of MODEL_PARAMS_EXPERIMENT via add_dictionary_iterators.
model_params_iter_1 = create_subdictionary_iterator(MODEL_PARAMS)
# model_params_iter = chain.from_iterable((model_params_iter_1, model_params_iter_2))

model_exp_iter = create_subdictionary_iterator(MODEL_PARAMS_EXPERIMENT, product=False)
exp_model_params_iter = add_dictionary_iterators(model_exp_iter, model_params_iter_1)

# Training grid: same construction for TRAINING_PARAMS / TRAINING_PARAMS_EXPERIMENT.
# NOTE(review): if these helpers return one-shot iterators/generators, the
# training cell consumes them; re-run this cell before re-running training.
training_params_iter = create_subdictionary_iterator(TRAINING_PARAMS)
training_exp_iter = create_subdictionary_iterator(
    TRAINING_PARAMS_EXPERIMENT, product=False
)
exp_training_params_iter = add_dictionary_iterators(
    training_exp_iter, training_params_iter
)

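Do the actual training: fit the DQN agent (`nr.rl.fit_dqn_deterministic`) on every parameter combination produced by the iterators above.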

In [None]:
# Fit one DQN per parameter combination. The returned dict holds the fitted
# models (read below as cv_results["models"][0][0]) and is splatted into plot_result.
cv_results = k_fold_cv_grid(
    data=data,
    model_params=exp_model_params_iter,
    training_params=exp_training_params_iter,
    fit=nr.rl.fit_dqn_deterministic,
    folds=FOLDS,
    trials=1,
    verbose=True,
)

In [None]:
# Find DP solution: reference reparameterization on the shared grid I,
# using the same depth as the discrete env.
I1_new = find_optimal_diffeomorphism(q0=q_data, q1=r_data, I0=I, I1=I, depth=depth)
# Stack along the last axis; for 1-D I and I1_new the rows are (t, I1_new(t)) pairs.
path = np.stack((I, I1_new), axis=-1)

real_env = nr.reparam_env.RealReparamEnv(data=data)
dp_distance = nr.rl.get_path_value(path=path, env=real_env)

# NOTE(review): c2.DIST_R_Q comes from experiments.curves_2, not from the
# motion-capture curves loaded in this notebook. Confirm it is a meaningful comparison.
(c2.DIST_R_Q, dp_distance)

In [None]:
# First trained model (presumably first fold / first trial; confirm against k_fold_cv_grid).
model = cv_results["models"][0][0]

# Plot the trained agent's solution with the DP solution as the comparison curve.
plot_kwargs = dict(
    env=get_env(),
    x_train=I,
    y_train=I1_new,
    x_axis="t",
    y_axis="$\\varphi(t)$",
    compare_label="DP solution",
)
plot_result(
    path_figures=PATH_FIGURES,
    plot_function=nr.plot_solution_rl,
    function_kwargs=plot_kwargs,
    **cv_results
)

# Path selected by the trained model on a fresh discrete env.
q_path = nr.rl.get_optimal_path(model=model, env=get_env())

In [None]:
# Plot the path returned by nr.rl.get_optimal_path for the trained model.
fig, ax = plt.subplots()
ax.plot(q_path[:, 0], q_path[:, 1])
ax.set(
    xlabel="first path coordinate",
    ylabel="second path coordinate",
    title="Path from nr.rl.get_optimal_path (trained DQN)",
)
plt.show()

In [None]:
# Cost of a hand-picked grid path whose coordinates are both strictly increasing.
# NOTE(review): the endpoint (15, 15) assumes the grid ends at index 15; check
# it against the env size. This call passes cost_func=..., whereas the DP cell
# passes env=...; confirm get_path_value supports both.
waypoints = [(0, 0), (2, 1), (3, 5), (6, 6), (9, 10), (14, 13), (15, 15)]
test_path = np.array(waypoints)
q_distance = nr.rl.get_path_value(path=test_path, cost_func=local_cost_func)
q_distance