### DQN experiment 3

This notebook implements Experiment 1 in Chapter 3.

We compute a reparametrization of motion-capture data using a deep Q-network (DQN).


In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
from functools import partial
from itertools import chain
import torch
import matplotlib.pyplot as plt
import matplotlib
import seaborn as sns
import pandas as pd
import numpy as np
from matplotlib_inline.backend_inline import set_matplotlib_formats
from matplotlib.lines import Line2D
import extratorch as etorch

import neural_reparam as nr

In [None]:
# Seed torch's global random number generator so repeated runs start from the
# same random state. torch.manual_seed returns the generator object, which is
# what `seed` holds (it is not the integer 0).
seed = torch.manual_seed(0)

# Plot configuration: vector formats for inline figures, a larger font,
# a colorblind-friendly color cycle and seaborn's plain white style.
# The inline figure formats only need to be registered once.
set_matplotlib_formats("pdf", "svg")
matplotlib.rcParams.update({"font.size": 12})
plt.style.use("tableau-colorblind10")
sns.set_style("white")

Load data:

In [None]:
from signatureshape.animation.animation_manager import fetch_animations, unpack
from signatureshape.so3.curves import move_origin_to_zero
from signatureshape.so3.helpers import crop_curve
from signatureshape.so3.dynamic_distance import (
    find_optimal_diffeomorphism,
    create_shared_parameterization,
)
from signatureshape.so3.transformations import skew_to_vector, SRVT
from signatureshape.so3 import animation_to_SO3


print("Load data:")

# Number of frames kept from each recording. The original note calls this
# 3 seconds, which would mean 120 frames per second -- TODO confirm the rate.
max_frame_count = 360
# Only every `frame_stride`-th sample of the shared parameterization is passed
# on to the environment.
frame_stride = 8

# Candidate recordings; the two curves compared below are picked by index.
# Kept under its own name so that `data` only ever refers to the env samples.
animations = [
    fetch_animations(1, file_name="39_02.amc"),  # walk 6.5 steps
    fetch_animations(1, file_name="35_26.amc"),  # run/jog 3 steps
    fetch_animations(1, file_name="16_35.amc"),  # run/jog 3 steps
    fetch_animations(
        1, file_name="38_04.amc"
    ),  # walk around, frequent turns, cyclic walk along a line
    fetch_animations(1, file_name="07_01.amc"),  # walk
    fetch_animations(1, file_name="07_12.amc"),  # brisk walk
]

# first curve: 38_04.amc (walk around, frequent turns)
subject, animation, desc0 = unpack(animations[3])
curve_full = animation_to_SO3(subject, animation)
curve = crop_curve(curve_full, stop=max_frame_count)
c_0 = move_origin_to_zero(curve)
print("desc 0:", desc0)

# second curve: 07_12.amc (brisk walk), cropped to the same number of frames
subject, animation, desc1 = unpack(animations[5])
curve_full = animation_to_SO3(subject, animation)
curve = crop_curve(curve_full, stop=max_frame_count)
c_1 = move_origin_to_zero(curve)
print("desc 1:", desc1)

# Transform both curves with SRVT on a uniform grid over [0, 1] and bring them
# onto one shared parameter grid I.
I0 = np.linspace(0, 1, c_0.shape[1])
I1 = np.linspace(0, 1, c_1.shape[1])
q_data_ = skew_to_vector(SRVT(c_0, I0))
r_data_ = skew_to_vector(SRVT(c_1, I1))
I, q_data, r_data = create_shared_parameterization(q0=q_data_, q1=r_data_, I0=I0, I1=I1)
shared_frames = I.shape[0]
print("shared_frames:", shared_frames)

# data used to create env: the subsampled grid and the two subsampled curves
data = (I[::frame_stride], q_data[::frame_stride], r_data[::frame_stride])
print("Shorted frames:", len(data[0]))

Set up the model and training parameters

In [None]:
# Environment factory: every environment is built from the same data tuple.
get_env = partial(nr.reparam_env.DiscreteReparamEnv, data=data)
depth = 4
# A single problem size: the number of samples in the (subsampled) data.
nlist = np.array([len(data[0])])

# Where figures and logs of this experiment are written.
DIR = "../figures/motion_data_rl/"
SET_NAME = f"dqn_1_N{nlist[0]}"
PATH_FIGURES = os.path.join(DIR, SET_NAME)

# Reference environment, used for the network's output size and for plotting.
default_env = get_env(size=nlist[0], depth=depth)

# Every value below is a list of candidate settings; the iterators built from
# these dictionaries presumably enumerate their combinations -- see
# etorch.create_subdictionary_iterator.
MODEL_PARAMS = dict(
    model=[etorch.FFNN],
    input_dimension=[2],
    output_dimension=[default_env.num_actions],
    activation=["relu"],
    n_hidden_layers=[5],
    neurons=[32],
)
TRAINING_PARAMS = dict(
    get_env=[get_env],
    epsilon=[0.05, 0.01, 0.005],
    DDQN=[True, False],
    update_every=[10, 100, 1000],
    double_search=[True, False],
    optimizer=["ADAM"],
    num_epochs=[200],
    learning_rate=[0.1, 0.01, 0.001],
    verbose=[True],
)
# Settings that scale with the problem size; one entry per element of nlist.
EXTRA_TRAINING_PARAMS = dict(
    N=nlist,  # for easier plotting
    batch_size=2 * nlist,
    initial_steps=20 * nlist,
    memory_size=100 * nlist,
    env_kwargs=[{"size": n, "depth": depth} for n in nlist],
)

In [None]:
# Turn the parameter dictionaries into iterators over concrete configurations.
model_params_iter = etorch.create_subdictionary_iterator(MODEL_PARAMS)

# The size-dependent settings are expanded with product=False, the remaining
# training settings with the default behaviour.
t_iter_temp_1 = etorch.create_subdictionary_iterator(
    EXTRA_TRAINING_PARAMS,
    product=False,
)
t_iter_temp_2 = etorch.create_subdictionary_iterator(TRAINING_PARAMS)

# Both training iterators are then combined with product=True.
training_params_iter = etorch.add_dictionary_iterators(
    t_iter_temp_1,
    t_iter_temp_2,
    product=True,
)

Do the actual training

In [None]:
# Run the grid: k_fold_cv_grid receives the model and training parameter
# iterators and uses nr.rl.fit_dqn_deterministic as the fit function, with a
# single fold and a single trial per configuration.
# The returned cv_results is indexed further down with the keys "histories",
# "model_params" and "training_params".
cv_results = etorch.k_fold_cv_grid(
    model_params=model_params_iter,
    fit=nr.rl.fit_dqn_deterministic,
    training_params=training_params_iter,
    folds=1,
    verbose=True,
    trials=1,
)

Plot and store results

In [None]:
# Plot all solutions; the figures go to PATH_FIGURES.
# plot_kwargs is handed to the plot function as its keyword arguments.
plot_kwargs = dict(env=default_env, x_axis="t", y_axis="$\\varphi(t)$")
etorch.plotting.plot_result(
    function_kwargs=plot_kwargs,
    plot_function=nr.plot_solution_rl,
    path_figures=PATH_FIGURES,
    **cv_results
)
# show the output directory as the cell result
PATH_FIGURES

In [None]:
# Keep only the last row of every training history and stack those rows
# under a fresh 0..n-1 index.
final_rows = [
    history[-1:] for history in chain.from_iterable(cv_results["histories"])
]
history_df = pd.concat(final_rows).reset_index(drop=True)

# One row per trial: model parameters, training parameters, final history row.
log_df = pd.concat(
    [cv_results["model_params"], cv_results["training_params"], history_df],
    axis=1,
)
# Column names as they should appear in tables and plots.
log_df.columns = log_df.columns.str.replace("double_search", "2-greedy").str.replace(
    "Cost", "Final cost"
)

In [None]:
# Find DP, greedy solution and plot it
from neural_reparam.plotting import plot_value_func

# Reference costs per problem size, stored under the keys "dp" (dynamic
# programming) and "greedy".
n_cost = {"dp": [], "greedy": []}
for n in nlist:
    n_env = get_env(size=n, depth=depth)

    # Dynamic programming solution on the environment's own data.
    I1_new, path, A = find_optimal_diffeomorphism(
        q0=n_env.q_data,
        q1=n_env.r_data,
        I0=n_env.t_data,
        I1=n_env.t_data,
        depth=depth,
        return_all=True,
    )

    # Greedy cost: get_value is called without a model and its sign is flipped
    # (presumably because it returns a reward-style value -- TODO confirm).
    greedy_cost = -nr.rl.get_value(model=None, env=n_env, double_search=True)
    n_cost["greedy"].append(greedy_cost)

    # Cost of the dynamic programming path under the environment's local cost.
    dp_cost = nr.rl.get_path_value(path=path, reward_func=n_env.dp_local_cost)
    n_cost["dp"].append(dp_cost)

    # Plot the value matrix (reversed along both axes) together with the path.
    fig = plot_value_func(
        value_matrix=A[::-1, ::-1],
        t_data=n_env.t_data,
        path=path,
        path_figures=PATH_FIGURES,
        plot_name=f"dp_solution_{n}",
        x_axis="t",
        y_axis="$\\varphi(t)$",
    )

In [None]:
# DQN results: problem size and final cost of every trial, tagged with the method name
solutions_df1 = log_df[["N", "Final cost"]].assign(Method="DQN")
solutions_df1

In [None]:
# Greedy and dynamic programming costs in long format: one row per (N, Method)
baseline_df = (
    pd.DataFrame(n_cost)
    .assign(N=nlist)
    .rename(columns={"dp": "Dynamic programming", "greedy": "Greedy"})
)
solutions_df2 = baseline_df.melt(
    id_vars=["N"], var_name="Method", value_name="Final cost"
)

In [None]:
# Distribution of the final cost over all trials, with a logarithmic y-axis
end_cost_path = os.path.join(PATH_FIGURES, "end_cost.pdf")
fig = sns.displot(
    data=log_df,
    x="Final cost",
    bins=20,
    alpha=1,
    log_scale=(False, True),
)
fig.savefig(end_cost_path)

In [None]:
# Final cost against Q loss for every trial, with the two reference costs
# drawn as vertical lines.
dp_distance = n_cost["dp"][-1]
greedy_distance = n_cost["greedy"][-1]

sns.relplot(data=log_df, x="Final cost", y="Q loss", hue="2-greedy", style="2-greedy")
plt.xscale("log")
plt.yscale("log")

# The reference lines span the y-range the scatter plot ended up with.
ylim = plt.ylim()
line1 = plt.plot([dp_distance] * 2, ylim, lw=1, label="Minimal cost")[0]
line2 = plt.plot(
    [greedy_distance] * 2,
    ylim,
    color="grey",
    ls="dashed",
    lw=1,
    label="Greedy strategy cost",
)[0]
# The legend lists only the two reference lines.
plt.legend(handles=[line1, line2])
plt.savefig(os.path.join(PATH_FIGURES, "q_loss_and_cost.pdf"))

In [None]:
# Count the trials whose final cost is below the greedy strategy's cost.
better = int((log_df["Final cost"] < greedy_distance).sum())
# Named `total` rather than `all`, which would shadow the builtin in every
# cell executed afterwards.
total = len(log_df)
# Guard against an empty log so the cell does not fail with a division by zero.
share = better / total if total else 0.0

# The `.1%` format multiplies by 100, so the printed value is a real percentage.
print(f"Number better than greedy strategy: {better}, total: {total} ({share:.1%})")

In [None]:
# Final cost per problem size: box plot of the DQN trials ...
sns.catplot(
    x="N", y="Final cost", data=solutions_df1, kind="box", hue="Method", legend=False
)

# ... overlaid with the reference methods from solutions_df2 as points.
# `kind` is a catplot argument and is not passed here: pointplot has no such
# parameter, and recent seaborn versions forward unknown keyword arguments to
# matplotlib's Line2D, which rejects it.
sns.pointplot(
    x="N",
    y="Final cost",
    data=solutions_df2,
    hue="Method",
    markers=["o", "x"],
    linestyles=["-", "--"],
    alpha=1,
)
plt.yscale("log")
plt.savefig(os.path.join(PATH_FIGURES, "q_loss_different_n.pdf"))

In [None]:
# Find the best and the worst trial by final cost.
# idxmin/idxmax return the index label of the extreme value and skip NaN
# costs. Slicing the sorted frame with [-2:-1] would pick the second-to-last
# row, i.e. the second-worst trial, so it is not used for the worst index.
final_cost = log_df["Final cost"]
print("Best index:", final_cost.idxmin())
print("Worst index:", final_cost.idxmax())

In [None]:
# Plot the stored training history of a single trial: Q loss on the left
# axis and cost on the right axis, both on a logarithmic scale.
model_num = 19

hist_file_name = f"history_plot_{model_num}_0.csv"
hist = pd.read_csv(os.path.join(PATH_FIGURES, hist_file_name))

# left axis
ax1 = sns.lineplot(data=hist, x="Episode", y="Q loss")
ax1.set_yscale("log")

# right axis, sharing the x-axis with the first plot
ax2 = ax1.twinx()
sns.lineplot(data=hist, x="Episode", y="Cost", ax=ax2, color="#FF800E", linestyle="--")

# One combined legend, built from proxy lines that mirror the two curves.
legend_handles = [
    Line2D([], [], label="Q loss"),
    Line2D([], [], color="#FF800E", linestyle="--", label="Cost"),
]
ax2.legend(handles=legend_handles)
ax2.set_yscale("log")

plt.tight_layout()
plt.savefig(os.path.join(PATH_FIGURES, "plot_" + hist_file_name[:-4] + ".pdf"))