## Hypermodel

The hypermodel is simply a model whose layer weights have been extended into an affine transformation of the hyperparameters, represented in R. This is similar to a first-order Taylor expansion of the weights as a function of the hyperparameters.
![title](img/hyper_model.PNG)
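Written out, the `HyperDense` layer defined later in this notebook computes the following for a single example with perturbed unconstrained hyperparameter $\tilde\lambda$. The symbols are introduced here for illustration only: $W, b$ stand for `w`, `b`; $W_h, b_h$ for `hw`, `hb`; and $k_w, k_b$ for `kw`, `kb`, which are scalars because `hyper_dim = 1`. The input $x$ is a row vector and a ReLU is applied afterwards when `with_relu=True`.

$$
y = xW + b + (\tilde\lambda\,k_w)\,(xW_h) + (\tilde\lambda\,k_b)\,b_h
  = x\,\big(W + \tilde\lambda\,k_w W_h\big) + \big(b + \tilde\lambda\,k_b\,b_h\big)
$$

The effective weights and bias are therefore affine in $\tilde\lambda$: $W, b$ are their values at $\tilde\lambda = 0$, and $k_w W_h$, $k_b b_h$ play the role of first derivatives, which is the sense in which this resembles a first-order Taylor expansion.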

For this we need:
* A way to extend the model weights as an affine transformation of the hyperparameters, for example through an intermediate layer or a new layer type
* A clear separation between parameter weights and hyperparameter weights, so that their gradients can be computed independently
* A way to update the h_space values in the model: not the weights, but the actual values that modify model behavior, such as a dropout probability. The m_space hyperparameter weights should be updated by the framework when it computes gradients.

## Hyperparameter management

We need direct access to the hyperparameter tensors in order to:
* Apply invertible mappings from the hyperparameter's ordinal space to R, for network updates, and back to the ordinal space, to update the hyperparameter trajectories and compute their effects on network training. Categorical hyperparameters are not supported; a finite cardinality is not required, though, and spaces can be bounded to an arbitrary interval.
* Optimize the hyperparameters in R
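As a minimal sketch of such an invertible mapping for a bounded hyperparameter, here is a NumPy restatement of the scaled logit/sigmoid pair (`s_logit`, `s_sigmoid`) used later in the TensorFlow code. The function names `to_unconstrained` and `to_constrained` are hypothetical, chosen for this illustration only.

```python
import numpy as np


def to_unconstrained(h, min_val, max_val):
    """Map a value in (min_val, max_val) to R with a scaled logit."""
    p = (h - min_val) / (max_val - min_val)
    return np.log(p) - np.log1p(-p)


def to_constrained(u, min_val, max_val):
    """Map a value in R back to (min_val, max_val) with a scaled sigmoid."""
    return (max_val - min_val) / (1.0 + np.exp(-u)) + min_val


# Round trip for an x_range-like hyperparameter bounded to (0.01, 10).
h = 0.5
u = to_unconstrained(h, 0.01, 10.0)  # m_space value, about -2.96
h_back = to_constrained(u, 0.01, 10.0)  # back in h_space, 0.5 up to rounding
```

Any real `u` maps back inside the bounds, which is what lets the optimizer work freely in R.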

Hyperparameters participate in the method as follows:
* Each hyperparameter is created in h_space with an initial value and mapped to m_space, so every hyperparameter has two parallel representations.
* The target model's hyperparameter values are updated in both h_space and m_space. A dropout probability is an example of an h_space value; its m_space representation is perturbed and injected as an input alongside the X values, so the network requires both representations. If a hyperparameter affects the data generation process that produces the X instances, only its m_space representation is actually updated; this depends on where the data manipulation code is injected.
* Each hyperparameter is perturbed with its own scale parameter. The scale is also optimized, after the hyperparameter values are updated, at the validation level of the overall optimization process.
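In the notation of the code below (`unc_tensor` = $\lambda$, `scale_tensor` = $s$, batch size $B$, bounds $h_{\min}, h_{\max}$), the perturbation and the mapping back to h_space implemented by `make_unc_batch_tensor` and `make_con_batch_tensor` can be written as:

$$
\tilde\lambda_i = \lambda + \operatorname{softplus}(s)\,\varepsilon_i,\qquad \varepsilon_i \sim \mathcal{N}(0, 1),\qquad i = 1, \dots, B
$$

$$
h_i = (h_{\max} - h_{\min})\,\operatorname{sigmoid}(\tilde\lambda_i) + h_{\min}
$$

The softplus keeps the noise scale positive while $s$ itself is optimized unconstrained in R.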

This means that our design needs to include:
* initialize hparam
* model update hparam
* perturb hparam
* update hparam values and scale
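A minimal sketch of how these four operations could be grouped for one bounded hyperparameter. `BoundedHyperParam` and its methods are hypothetical names for illustration; the notebook below implements the same ideas as separate TensorFlow variables and functions, and uses Adam rather than the plain gradient step shown here.

```python
import numpy as np


class BoundedHyperParam:
    """Hypothetical container for one bounded hyperparameter (both representations)."""

    def __init__(self, value, min_val, max_val, scale=0.5):
        # initialize hparam: keep the m_space value and an unconstrained scale
        self.min_val, self.max_val = min_val, max_val
        self.unc = self._to_unc(value)
        self.unc_scale = np.log(np.expm1(scale))  # inverse softplus
        self.unc_batch = None

    def _to_unc(self, h):
        p = (h - self.min_val) / (self.max_val - self.min_val)
        return np.log(p) - np.log1p(-p)  # scaled logit

    def _to_con(self, u):
        return (self.max_val - self.min_val) / (1.0 + np.exp(-u)) + self.min_val

    def perturb(self, rng, batch_size):
        # perturb hparam: one m_space sample per example, noise std softplus(unc_scale)
        eps = rng.standard_normal((batch_size, 1))
        self.unc_batch = self.unc + np.log1p(np.exp(self.unc_scale)) * eps
        return self.unc_batch

    def model_values(self):
        # model update hparam: the m_space batch fed to the hyper layers and the
        # h_space batch used by code that needs actual values (e.g. a data range)
        return self.unc_batch, self._to_con(self.unc_batch)

    def update(self, grad_unc, grad_scale, lr=0.003):
        # update hparam values and scale (plain gradient descent for illustration)
        self.unc -= lr * grad_unc
        self.unc_scale -= lr * grad_scale

    @property
    def value(self):
        """Current unperturbed value in h_space."""
        return self._to_con(self.unc)


hp = BoundedHyperParam(0.5, min_val=0.01, max_val=10.0)
unc_batch = hp.perturb(np.random.default_rng(0), batch_size=4)
unc_vals, h_vals = hp.model_values()
hp.update(grad_unc=1.0, grad_scale=0.0)  # value moves slightly below 0.5
```

Keeping only the m_space value and the unconstrained scale as state, and deriving h_space values on demand, avoids the two representations drifting apart.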

## Hypertraining step

The hypertraining step is a two-level optimization process, as presented in the following algorithm:
![title](img/hyper_training_algo.PNG)

For this we need:
* Training batches
* Validation batches (note that the number m of validation steps can differ from the number n of training steps, but both levels run for the same number of "epochs")
* An evaluation batch to follow the function trajectory
* Hyperparameter perturbations for all batches
* Three optimizers:
    * model parameter optimizer
    * hyperparameter optimizer
    * hp scale optimizer
* An entropy term for the validation-level optimization
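The alternating schedule can be sketched framework-independently as follows. `hypertrain`, `train_step`, `val_step` and `eval_fn` are hypothetical names; in the notebook below the same roles are played by `weights_training_step`, `hyperparameters_training_step` and the evaluation on `X_eval`, with `WEIGHT_STEPS` and `HYPER_STEPS` as n and m.

```python
from typing import Callable, List, Optional, Tuple


def hypertrain(
    train_step: Callable[[], float],
    val_step: Callable[[], float],
    epochs: int,
    n_train: int,
    m_val: int,
    eval_fn: Optional[Callable[[int], None]] = None,
) -> Tuple[List[float], List[float]]:
    """Alternate n_train weight updates with m_val hyperparameter updates per epoch."""
    train_losses: List[float] = []
    val_losses: List[float] = []
    for epoch in range(epochs):
        # lower level: fit the model weights on perturbed training batches
        for _ in range(n_train):
            train_losses.append(train_step())
        # upper level: update hyperparameters and their scales on validation
        # batches; val_step is expected to include the entropy term
        for _ in range(m_val):
            val_losses.append(val_step())
        # optional: follow the function trajectory on a fixed evaluation batch
        if eval_fn is not None:
            eval_fn(epoch)
    return train_losses, val_losses


# Usage with stand-in steps that only count how often they are called.
calls = {"train": 0, "val": 0, "eval": 0}


def fake_train_step() -> float:
    calls["train"] += 1
    return 1.0


def fake_val_step() -> float:
    calls["val"] += 1
    return 2.0


def fake_eval(epoch: int) -> None:
    calls["eval"] += 1


tl, vl = hypertrain(fake_train_step, fake_val_step, epochs=3, n_train=2, m_val=5, eval_fn=fake_eval)
```

Passing the steps as callables keeps the schedule separate from how each level computes its gradients and which optimizer it uses.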

# TF Module hypertraining

## Setup model with hyper layers and functions

In [None]:
import tensorflow as tf

In [None]:
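@tf.function
def make_unc_batch_tensor(unc_tensor, scale_tensor, perturbation, batch_size):
    # Tile the scalar m_space value to shape (batch_size, 1) and add the perturbation
    # scaled by softplus(scale_tensor): one perturbed sample per example.
    unc_repeated = tf.reshape(tf.repeat(unc_tensor, repeats=[batch_size]), (batch_size, 1))
    return unc_repeated + tf.math.softplus(scale_tensor) * perturbation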

In [None]:
import math
class HyperDense(tf.Module):
    def __init__(
        self, in_features, out_features,
        unc_tensor, scale_tensor, perturbation, batch_size,
        with_relu=True, name="HyperDense"
    ):
        super(HyperDense, self).__init__(name=name)
        # hyperparameters setup
        self.unc_tensor = unc_tensor
        self.scale_tensor = scale_tensor
        self.perturbation = perturbation
        self.batch_size = batch_size

        # layer setup
        layer_dtype = tf.float64
        hyper_dim = 1
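        self.with_relu = with_relu

        # uniform init bound 1/sqrt(in_features)
        stdv = 1. / math.sqrt(in_features)
        with tf.name_scope(name):
            # w, b: base weights and bias; hw, hb: hyper weights and bias;
            # kw, kb: projections of the perturbed hyperparameter onto the hyper terms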
            self.w = tf.Variable(tf.random.uniform([in_features, out_features], -stdv, stdv, dtype=layer_dtype),
                                 name="weights", dtype=layer_dtype)
            self.hw = tf.Variable(tf.random.uniform([in_features, out_features], -stdv, stdv, dtype=layer_dtype),
                                  name="hweights", dtype=layer_dtype)
            self.kw = tf.Variable(tf.random.normal([hyper_dim, 1], stddev=0.1, dtype=layer_dtype),
                                  name="hkweights", dtype=layer_dtype)
            self.b = tf.Variable(tf.random.uniform([out_features], -stdv, stdv, dtype=layer_dtype),
                                 name="bias", dtype=layer_dtype)
            self.hb = tf.Variable(tf.random.uniform([out_features], -stdv, stdv, dtype=layer_dtype),
                                  name="hbias", dtype=layer_dtype)
            self.kb = tf.Variable(tf.random.normal([hyper_dim, 1], stddev=0.1, dtype=layer_dtype),
                                  name="hkbias", dtype=layer_dtype)

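    @tf.function
    def __call__(self, x, training=True):
        # Note: `training` is currently unused, so evaluation is perturbed too, and
        # x is expected to have batch_size rows to match the hyperparameter batch.
        hyper_unc_batch = make_unc_batch_tensor(self.unc_tensor, self.scale_tensor, self.perturbation, self.batch_size)
        oy = tf.matmul(x, self.w) + self.b  # base affine map
        hw = tf.linalg.matmul(hyper_unc_batch, self.kw) * tf.matmul(x, self.hw)  # (lambda * kw) * (x @ hw)
        hb = tf.linalg.matmul(hyper_unc_batch, self.kb) * self.hb  # (lambda * kb) * hb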
        y = oy + hw + hb
        if self.with_relu:
            return tf.nn.relu(y)
        return y

In [None]:
class HyperModel(tf.Module):
    def __init__(self, unc_tensor, scale_tensor, perturbation, batch_size,
                 name="HyperModel"):
        super(HyperModel, self).__init__(name=name)
        # hyperparameter setup
        self.unc_tensor = unc_tensor
        self.scale_tensor = scale_tensor
        self.perturbation = perturbation
        self.batch_size = batch_size

        # model setup
        input_dim = 1
        depth = 3  # L + 1
        width = 4  # M + input_dim + 1

        self.layers = []
        self.layers.append(HyperDense(input_dim, width,
                                      unc_tensor, scale_tensor, perturbation, batch_size,
                                      with_relu=False, name="dense_input"))
        for i in range(depth):
            self.layers.append(HyperDense(width, width,
                                          unc_tensor, scale_tensor, perturbation, batch_size,
                                          with_relu=True, name=(f"hidden_{i}")))
        self.layers.append(HyperDense(width, 1, 
                                      unc_tensor, scale_tensor, perturbation, batch_size,
                                      with_relu=False, name="dense_output"))

    def set_perturbation_var(self, perturbation):
        self.perturbation = perturbation
        for layer in self.layers:
            layer.perturbation = perturbation

    @tf.function
    def __call__(self, x, training=True):
        next_input = x
        for layer in self.layers:
            next_input = layer(next_input, training)
        return next_input

    @tf.function
    def update_perturbations(self, gen):
        self.perturbation.assign(gen.normal(shape=(self.batch_size, 1), dtype=tf.float64))
        return self.perturbation

In [None]:
@tf.function
def logit(x):
    # inverse of the sigmoid
    return tf.math.log(x) - tf.math.log(tf.constant(1.0, dtype=tf.float64) - x)

@tf.function
def s_logit(x, min_val, max_val):
    # scaled logit: maps (min_val, max_val) to R, inverse of s_sigmoid
    return logit((x - min_val)/(max_val-min_val))

@tf.function
def inv_softplus(x):
    # inverse of softplus for x > 0; expm1 is more accurate than exp(x) - 1 for small x
    return tf.math.log(tf.math.expm1(x))

In [None]:
@tf.function
def s_sigmoid(unc_hyperparam, min_val, max_val):
    return (max_val - min_val) * tf.math.sigmoid(unc_hyperparam) + min_val

# con_batch_tensor for bounded hyperparameter
@tf.function
def make_con_batch_tensor(unc_tensor, scale_tensor, perturbation, batch_size, min_val, max_val):
    unc_batch_tensor = make_unc_batch_tensor(unc_tensor, scale_tensor, perturbation, batch_size)
    return s_sigmoid(unc_batch_tensor, min_val, max_val)

In [None]:
import math

@tf.function
def entropy_term(hscale):
    scale = tf.math.softplus(hscale)
    return tf.math.log(scale * tf.math.sqrt(tf.constant(2.0 * math.pi * math.e, dtype=tf.float64)))
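`entropy_term` is the differential entropy of the Gaussian perturbation, whose standard deviation is $\tau = \operatorname{softplus}(s)$ with $s$ the unconstrained scale:

$$
H(s) = \tfrac{1}{2}\log\!\big(2\pi e\,\tau^2\big) = \log\!\big(\tau\sqrt{2\pi e}\big),\qquad \tau = \operatorname{softplus}(s)
$$

It increases with $\tau$, so subtracting $c\,H(s)$ from the validation loss (with $c$ = `entropy_coefficient`) rewards larger perturbation scales and discourages the scale from collapsing to zero.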

In [None]:
@tf.function
def make_training_data_tensor(unc_x_range, scale_x_range, perturbation, batch_size, min_val, max_val, gen):
    con_x_range_batch = make_con_batch_tensor(unc_x_range, scale_x_range, perturbation, batch_size, min_val, max_val)
    x_sample = gen.uniform(
        shape=(batch_size, 1),
        minval=-1, maxval=1,
        dtype=tf.float64)
    scaled_sample = tf.multiply(con_x_range_batch, x_sample)
    return scaled_sample
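`make_training_data_tensor` is where the hyperparameter enters the data generation process: each training input is scaled by its own perturbed range value. With $\tilde\lambda_i$ the perturbed `unc_x_range` value for example $i$ and $h_{\min}, h_{\max}$ = `min_val`, `max_val`:

$$
x_i = r_i\,u_i,\qquad u_i \sim \mathcal{U}(-1, 1),\qquad r_i = (h_{\max} - h_{\min})\,\operatorname{sigmoid}(\tilde\lambda_i) + h_{\min}
$$

so that, conditionally on $r_i$, $x_i \sim \mathcal{U}(-r_i, r_i)$.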

## Setup variables for training

In [None]:
clear_session = tf.keras.backend.clear_session
FRAMEWORK_SEED = 39  # nice ones: 40, 42, 39; bad ones: 41
clear_session()
tf.random.set_seed(FRAMEWORK_SEED)

In [None]:
# setup x_range hyperparameter variables
x_range_hyper_init = tf.constant(0.5, dtype=tf.float64)
min_val = tf.constant(0.01, dtype=tf.float64)
max_val = tf.constant(10, dtype=tf.float64)

unc_x_range_hyper_init = s_logit(x_range_hyper_init, min_val, max_val)
unc_x_range = tf.Variable(unc_x_range_hyper_init, name="unc_x_range", dtype=tf.float64)
scale_x_range = tf.Variable(inv_softplus(tf.constant(0.5, dtype=tf.float64)), name="scale_x_range", dtype=tf.float64)
unc_x_range, scale_x_range

In [None]:
batch_size = 100
gen = tf.random.Generator.from_seed(42)
perturbation = tf.Variable(gen.normal(shape=(batch_size, 1), dtype=tf.float64), name="perturbation", dtype=tf.float64)

In [None]:
model = HyperModel(unc_x_range, scale_x_range, perturbation, batch_size)

In [None]:
# trainable model weights only: drop the hyperparameter, scale and perturbation variables
weights = [v for v in model.variables if "x_range" not in v.name and "perturbation" not in v.name]
weights

In [None]:
# opt_weights = tf.keras.optimizers.SGD(learning_rate=0.01)
opt_weights = tf.keras.optimizers.Adam(learning_rate=0.01)
opt_hyper = tf.keras.optimizers.Adam(learning_rate=0.003)
opt_scale = tf.keras.optimizers.Adam(learning_rate=0.003)

## Setup training loops

In [None]:
weights[0]

In [None]:
import numpy as np

def weights_training_step():
    # update noise
    model.update_perturbations(gen)

    # generate X
    X_train = make_training_data_tensor(unc_x_range, scale_x_range, perturbation, batch_size, min_val, max_val, gen)
    Y_train = X_train ** tf.constant(2.0, dtype=tf.float64)

    # set tape loss
    with tf.GradientTape(watch_accessed_variables=False) as tape:
        tape.watch(weights)
        Y_pred = model(X_train, training=True)
        loss = tf.keras.losses.MSE(Y_train, Y_pred)

    # update gradient
    grads = tape.gradient(loss, weights)
    processed_grads = grads
    # processed_grads = [(tf.clip_by_value(grad, -1.0, 1.0)) for grad in grads]
    opt_weights.apply_gradients(zip(processed_grads, weights))
    return np.sum(loss.numpy())

weights_training_step()

In [None]:
weights[0]

## Setup hyperparameters training loop

In [None]:
unc_x_range, scale_x_range

In [None]:
entropy_coefficient = tf.constant(0.001, dtype=tf.float64)
VALIDATION_X_RANGE = 2.0
def hyperparameters_training_step():
    # update noise
    model.update_perturbations(gen)

    # generate validation data on a fixed range, independent of the hyperparameter
    X_val = gen.uniform(
        shape=(batch_size, 1),
        minval=-VALIDATION_X_RANGE, maxval=VALIDATION_X_RANGE,
        dtype=tf.float64)
    Y_val = X_val ** tf.constant(2.0, dtype=tf.float64)

    # set tape loss hyper
    with tf.GradientTape(watch_accessed_variables=False) as tape:
        tape.watch(unc_x_range)
        Y_pred = model(X_val, training=True)
        loss = tf.keras.losses.MSE(Y_val, Y_pred) - (entropy_coefficient * entropy_term(scale_x_range))

    # update gradient hyper
    grad = tape.gradient(loss, unc_x_range)
    processed_grad = grad
    # processed_grad = tf.clip_by_value(grad, -1.0, 1.0)
    opt_hyper.apply_gradients(zip([processed_grad], [unc_x_range]))

    # set tape loss scale (predictions recomputed with the updated hyperparameter)
    with tf.GradientTape(watch_accessed_variables=False) as tape:
        tape.watch(scale_x_range)
        Y_pred = model(X_val, training=True)
        loss = tf.keras.losses.MSE(Y_val, Y_pred) - (entropy_coefficient * entropy_term(scale_x_range))

    # update gradient scale
    grad = tape.gradient(loss, scale_x_range)
    processed_grad = grad
    # processed_grad = tf.clip_by_value(grad, -1.0, 1.0)
    opt_scale.apply_gradients(zip([processed_grad], [scale_x_range]))
    return np.sum(loss.numpy())

hyperparameters_training_step()
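Both training steps pass the per-example vector returned by `tf.keras.losses.MSE` (shape `(batch_size,)`) to `tape.gradient`, which differentiates the sum of a non-scalar target. Writing $f$ for the model, $\theta$ for the variables in `weights`, $\tilde\lambda_i = \lambda + \operatorname{softplus}(s)\,\varepsilon_i$ for the perturbed hyperparameter, $B$ for `batch_size`, $c$ for `entropy_coefficient` and $H(s)$ for `entropy_term(scale_x_range)`, the objectives being minimized are:

$$
\mathcal{L}_{\text{train}}(\theta) = \sum_{i=1}^{B} \big(x_i^2 - f(x_i; \theta, \tilde\lambda_i)\big)^2,\qquad x_i \sim \mathcal{U}(-r_i, r_i)
$$

$$
\mathcal{L}_{\text{val}}(\lambda, s) = \sum_{i=1}^{B} \Big[\big(x_i^2 - f(x_i; \theta, \tilde\lambda_i)\big)^2 - c\,H(s)\Big] = \sum_{i=1}^{B} \big(x_i^2 - f(x_i; \theta, \tilde\lambda_i)\big)^2 - B\,c\,H(s),\qquad x_i \sim \mathcal{U}(-2, 2)
$$

where $r_i$ is the perturbed `x_range` value of example $i$. The validation objective is minimized first with respect to $\lambda$ (`unc_x_range`) and then, after recomputing the predictions, with respect to $s$ (`scale_x_range`). Because the entropy is subtracted once per example, its effective weight is $B\,c$ rather than $c$.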

In [None]:
unc_x_range, scale_x_range

In [None]:
s_sigmoid(unc_x_range, min_val, max_val).numpy()

## Training trajectory

In [None]:
def theoretical_bounds_metric(
    y_train: np.ndarray,
    y_pred: np.ndarray,
    L: int
) -> float:
    # Theoretical (sup-norm) distance used by Weinan E et al. in
    # https://arxiv.org/pdf/1807.00297.pdf
    assert len(y_train.shape) == 2  # shape (n, 1) expected
    assert y_train.shape == y_pred.shape
    # pow(2, -2 * L) is the expected bound for inputs in [-1, 1]; it is not
    # subtracted here, so L is currently unused.
    return float(np.max(np.abs(y_train - y_pred)))

In [None]:
import numpy as np

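# evaluation grid over the validation range; the number of points should equal
# batch_size because the hyper layers build a (batch_size, 1) hyperparameter batch
X_eval = tf.constant(np.expand_dims(np.linspace(-VALIDATION_X_RANGE, VALIDATION_X_RANGE, batch_size), 1), dtype=tf.float64)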
Y_eval = X_eval ** tf.constant(2.0, dtype=tf.float64)

In [None]:
from tqdm.notebook import tqdm

In [None]:
x_range_trajectory = []
x_scaling_trajectory = []
dist_trajectory = []
preds = []
wlosses = []
hlosses = []

MAX_EPOCHS = 10000
WARMUP_STEPS = 5
WEIGHT_STEPS = 2
HYPER_STEPS = 2
PRED_SAMPLING = 200

for _ in range(WARMUP_STEPS):
    weights_training_step()

for step in tqdm(range(MAX_EPOCHS), total=MAX_EPOCHS, unit="batch"):
    # hypertraining
    for _ in range(WEIGHT_STEPS):
        wloss = weights_training_step()
    for _ in range(HYPER_STEPS):
        hloss = hyperparameters_training_step()
    # eval and metrics
    Y_pred = model(X_eval, training=False)
    dist = theoretical_bounds_metric(Y_eval.numpy(), Y_pred.numpy(), 2)
    
    x_range_trajectory.append(s_sigmoid(unc_x_range, min_val, max_val).numpy())
    x_scaling_trajectory.append(tf.math.softplus(scale_x_range).numpy())
    wlosses.append(wloss)
    hlosses.append(hloss)

    if step % PRED_SAMPLING == 0:
        preds.append(Y_pred.numpy())
        dist_trajectory.append(dist)
        print("dist: ", dist)
    if dist < 0.05:
        break

In [None]:
import matplotlib.pyplot as plt
plt.plot(x_range_trajectory)

In [None]:
plt.plot(x_scaling_trajectory)

In [None]:
plt.plot([min(d, 1) for d in dist_trajectory])

In [None]:
plt.plot(wlosses)

In [None]:
plt.plot([min(l, 10) for l in hlosses])

In [None]:
from self_tuning_nets.visualization import function_animation
f_eval = X_eval.numpy()
f_trajectory = preds
function_animation(f_eval, [f_trajectory], ["b"])