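# Try to learn a reparametrized version of 2

Train ResNet and FFNN models with the elastic-metric loss on the curves defined in `experiments.curves_3`, then compare the final loss across network width, depth and parameter count.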

In [None]:
# Reload edited project modules (deepthermal, deep_reparametrization, experiments)
# automatically before each cell executes.
%load_ext autoreload
%autoreload 2

In [None]:
import torch
from torch.utils.data import TensorDataset

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd
from matplotlib_inline.backend_inline import set_matplotlib_formats
# Render inline figures in vector formats (PDF and SVG) instead of the default PNG.
set_matplotlib_formats('pdf', 'svg')

from deepthermal.FFNN_model import fit_FFNN, FFNN, init_xavier
from deepthermal.validation import create_subdictionary_iterator, k_fold_cv_grid, add_dictionary_iterators

from deepthermal.plotting import plot_result, plot_model_1d

from deep_reparametrization.plotting import plot_reparametrization
from deep_reparametrization.reparametrization import (
    get_elastic_metric_loss,
    compute_loss_reparam,
    get_elastic_error_func,
)
from deep_reparametrization.ResNet import ResNet
import experiments.curves_3 as c3

In [None]:
# Configuration: test reparametrization of curves that are not equivalent.

# --- output -----------------------------------------------------------------
PATH_FIGURES = "../figures/curve_3"  # directory that plot_result writes figures to
SET_NAME = "curve_3_exp_1"  # passed as plot_name when saving figures

# --- data / validation --------------------------------------------------------
FOLDS = 1
N = 128  # number of internal training points

# --- loss -------------------------------------------------------------------
# Unused alternative scheduler, kept for reference:
# lr_scheduler = lambda optimizer: optim.lr_scheduler.ReduceLROnPlateau(
#     optimizer, patience=30, factor=0.8, verbose=True)
loss_func = get_elastic_metric_loss(r=c3.r, constrain_cost=1e4, verbose=False)

# --- model hyperparameters --------------------------------------------------
# Base grid (presumably crossed into every combination by create_subdictionary_iterator).
MODEL_PARAMS = dict(
    model=[ResNet, FFNN],
    input_dimension=[1],
    output_dimension=[1],
    activation=["tanh"],
    n_hidden_layers=[4, 8, 16],
)
# Experiment values: zipped element-wise and used to extend each base combination.
MODEL_PARAMS_EXPERIMENT = dict(
    neurons=[4, 16, 64],
)

# --- training hyperparameters -----------------------------------------------
TRAINING_PARAMS = dict(
    batch_size=[N],
    regularization_param=[0],
    compute_loss=[compute_loss_reparam],
    loss_func=[loss_func],
)
# Experiment values: zipped element-wise and used to extend each base combination.
TRAINING_PARAMS_EXPERIMENT = dict(
    optimizer=["strong_wolfe"],
    num_epochs=[10],
    learning_rate=[1],
)

In [None]:
# Load data: N evenly spaced parameter values t in [0, 1] as an (N, 1) column.
# NOTE(review): requires_grad is set on the grid, presumably because the loss
# differentiates the model with respect to t -- confirm in get_elastic_metric_loss.
x_train = torch.linspace(0, 1, N, requires_grad=True).unsqueeze(dim=1)
q_train = c3.q(x_train.detach())  # target curve q evaluated on the same grid
data = TensorDataset(x_train, q_train)

# Hyperparameter iterators: the *_EXPERIMENT dicts are iterated without a
# product (product=False) and merged into the base grids.
model_params_iter, model_exp_iter = (
    create_subdictionary_iterator(MODEL_PARAMS),
    create_subdictionary_iterator(MODEL_PARAMS_EXPERIMENT, product=False),
)
exp_model_params_iter = add_dictionary_iterators(model_exp_iter, model_params_iter)

training_params_iter, training_exp_iter = (
    create_subdictionary_iterator(TRAINING_PARAMS),
    create_subdictionary_iterator(TRAINING_PARAMS_EXPERIMENT, product=False),
)
exp_training_params_iter = add_dictionary_iterators(training_exp_iter, training_params_iter)

## Training

Do the actual training: run `k_fold_cv_grid` over every combination of model and training hyperparameters.

In [None]:
# Seed torch immediately before training so re-running this cell (alone or via
# Restart & Run All) reproduces the same torch-based randomness, e.g. weight init.
# NOTE(review): if create_subdictionary_iterator returns one-shot generators, the
# data cell above must be re-run before re-running this one.
TRAINING_SEED = 0
N_TRIALS = 2  # forwarded as `trials`; presumably repeated trainings per configuration

torch.manual_seed(TRAINING_SEED)
cv_results = k_fold_cv_grid(
    model_params=exp_model_params_iter,
    fit=fit_FFNN,
    training_params=exp_training_params_iter,
    data=data,
    folds=FOLDS,
    verbose=True,
    trials=N_TRIALS,
)

In [None]:
# Plotting: draw every trained model on the sorted parameter grid, with
# c3.ksi(t) supplied as the training curve.
x_train_ = x_train.detach()
x_sorted, indices = x_train_.sort(dim=0)

plot_kwargs = dict(
    x_test=x_sorted,
    x_train=x_sorted,
    y_train=c3.ksi(x_sorted),
    x_axis="t",
    y_axis=r"$\xi(t)$",
)
plot_result(
    path_figures=PATH_FIGURES,
    plot_name=SET_NAME,
    **cv_results,
    plot_function=plot_model_1d,
    function_kwargs=plot_kwargs,
)

In [None]:
models = cv_results["models"]

# Flatten the (possibly nested) collection of trained models once so that every
# column below is computed in the same order.
flat_models = np.asarray(models, dtype=object).ravel()

parameters = np.array([sum(p.numel() for p in m.parameters()) for m in flat_models])
# Label each row with the class name: str(model) on a torch nn.Module returns its
# module repr (unless __str__ is overridden), which would never equal the
# "FFNN"/"ResNet" strings the filters below compare against.
model_type = np.array([type(m).__name__ for m in flat_models])
layers = np.array([m.n_hidden_layers for m in flat_models])
neurons = np.array([m.neurons for m in flat_models])
# One loss evaluation per model; .item() turns the scalar loss tensor into a float.
loss_array = np.array([loss_func(m, x_train, q_train).item() for m in flat_models])

# Shift by c3.DIST_R_Q (presumably the reference distance between r and q --
# confirm in experiments.curves_3).
loss_array = loss_array - c3.DIST_R_Q

# make data frame
d_results = pd.DataFrame(
    {
        "loss": loss_array,
        "neurons": neurons,
        "layers": layers,
        "parameters": parameters,
        "model": model_type,
    }
)

# Per-architecture subsets (the filters previously selected the opposite model).
d_res_ResNet = d_results[d_results.model == "ResNet"]
d_res_ffnn = d_results[d_results.model == "FFNN"]


In [None]:

# Final loss vs. hidden-layer width, one line per `model` value, 80% CI error bars.
fig_neurons = sns.lineplot(
    data=d_results,
    x="neurons",
    y="loss",
    hue="model",
    err_style="bars",
    ci=80,
)
fig_neurons.set_xscale("log")
fig_neurons.set_yscale("log")
plt.show()

In [None]:
# Final loss vs. number of hidden layers, one line per `model` value.
fig_layers = sns.lineplot(
    data=d_results,
    x="layers",
    y="loss",
    hue="model",
    err_style="bars",
    ci=80,
)
fig_layers.set_yscale("log")
plt.show()

In [None]:
# Final loss vs. trainable-parameter count, one line per `model` value.
fig_params = sns.lineplot(
    data=d_results, y="loss", x="parameters", hue="model", ci=80, err_style="bars"
)
# Apply the log scales to the axes just created (not to an earlier figure).
fig_params.set(xscale="log", yscale="log")
plt.show()



In [None]:
# Loss of every ResNet-subset run vs. depth, coloured by width.
fig_scatter = sns.scatterplot(
    data=d_res_ResNet, y="loss", x="layers", hue="neurons", style="model"
)
# Set the log scale on this scatter's own axes, then render it.
fig_scatter.set(yscale="log")
plt.show()
