In [None]:
# Enable IPython's autoreload extension so that edits to imported project
# modules are picked up without restarting the kernel (mode 2).
%load_ext autoreload
%autoreload 2

In [None]:
import os
from itertools import chain
import torch
from torch.utils.data import TensorDataset
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd
from matplotlib_inline.backend_inline import set_matplotlib_formats

# Ask the inline matplotlib backend to emit figures in the "pdf" and "svg" formats.
set_matplotlib_formats("pdf", "svg")

from deepthermal.FFNN_model import fit_FFNN, FFNN, init_xavier
from deepthermal.validation import (
    create_subdictionary_iterator,
    k_fold_cv_grid,
    add_dictionary_iterators,
)
from deepthermal.plotting import plot_result, plot_model_1d

from neural_reparam.reparametrization import (
    get_elastic_metric_loss,
    compute_loss_reparam,
)
from neural_reparam.ResNet import ResNet
from neural_reparam.models import ResCNN, BResCNN, CNN
from neural_reparam.reparam_env import (
    get_epsilon_greedy,
    get_optimal_path,
    DiscreteReparamEnv,
    plot_solution_rl,
    DiscreteReparamReverseEnv,
)
from neural_reparam.reinforcement_learning import fit_dqn_deterministic
from neural_reparam.reparam_env import RealReparamEnv

import experiments.curves as c1
import experiments.curves_2 as c2
from signatureshape.so3.dynamic_distance import find_optimal_diffeomorphism
import spinup

# Make runs reproducible: seed both torch's and NumPy's global RNGs.
# NOTE: torch.manual_seed returns a torch.Generator, so `seed` holds that
# generator object rather than the integer 0.
seed = torch.manual_seed(0)
np.random.seed(0)

In [None]:
# Grid resolution: passed to the environment as `size` and used to bound the
# episode length (N * N steps).
N = 16


def make_env():
    """Build a fresh RealReparamEnv on the c1 curves for the DDPG trainer."""
    return RealReparamEnv(r_func=c1.r, q_func=c1.q, size=N, action_penalty=1)


# DDPG hyperparameters, collected in one place so they are easy to scan and tune.
ddpg_settings = dict(
    steps_per_epoch=3000,
    epochs=50,
    replay_size=10_000,
    gamma=1,
    polyak=0.5,
    pi_lr=0.1,
    q_lr=0.1,
    batch_size=100,
    start_steps=2000,
    update_after=1000,
    update_every=10,
    act_noise=0.1,
    num_test_episodes=1,
    max_ep_len=N * N,
    save_freq=1,
)

# NOTE(review): the result is used below as an actor-critic exposing `.q`;
# confirm the installed spinup build actually returns it from ddpg_pytorch.
ac = spinup.ddpg_pytorch(make_env, **ddpg_settings)

In [None]:
# Sample N evenly spaced parameter values on [0, 1] and evaluate both curves
# from experiments.curves at those values (inputs detached from autograd).
x_train = torch.linspace(0, 1, N, requires_grad=True)
q_train = c1.q(x_train.unsqueeze(1).detach())
r_train = c1.r(x_train.unsqueeze(1).detach())

# NOTE(review): `data`, `ind` and `grid` are not used in the rest of this cell;
# they are kept because later cells may rely on them — confirm before removing.
data = TensorDataset(x_train, q_train, r_train)

size = len(x_train)
# ind has shape (size, size, 2) with ind[a, b] == [b, a]; grid holds the
# corresponding parameter values x_train[b], x_train[a].
ind = torch.as_tensor(np.indices((size, size)).T)

grid = x_train[ind]

# Tabulate the learned Q-function over every integer state (i, j) for the
# fixed action (0.5, 0.5). This is pure evaluation, so autograd is disabled
# and the loop-invariant action tensor is built once.
fixed_action = torch.tensor([0.5, 0.5]).float()
q_values = torch.zeros(N, N)
with torch.no_grad():
    for i in range(N):
        for j in range(N):
            q_values[i, j] = ac.q(obs=torch.tensor([i, j]).float(), act=fixed_action)

# Heatmap: rows follow the first state component i, columns the second, j.
plt.imshow(q_values.numpy())
plt.colorbar(label="Q(state, action=(0.5, 0.5))")
plt.xlabel("j (second state component)")
plt.ylabel("i (first state component)")
plt.title("Learned Q-values over the state grid")

# Show the Q-values at the two opposite corner states of the grid.
q_values[0, 0], q_values[N - 1, N - 1]