In [3]:
from jax import numpy as np
from flax import linen as nn
import trainer
import ntk
import test
import train_states
import models
import utils

import dataset_sines_infinite
import dataset_sines_finite

from jax import random
from jax import numpy as np
from flax.core import FrozenDict
import optax

In [4]:
def small_network(n_neurons, activation, reg_dim):
    """
    Returns a small neural network (two layers, n_neurons per layer with specified activation)
    Use reg_dim to control the dimension of the output

    Compatible activations: "relu" and "tanh" (note: all experiments run with ReLU).
    """

    if activation == "relu":
        act_fn = nn.relu
    elif activation == "tanh":
        act_fn = nn.tanh

    class Regressor(nn.Module):
        @nn.compact
        def __call__(self, x):
            x = nn.Dense(n_neurons)(x)
            x = act_fn(x)

            x = nn.Dense(n_neurons)(x)
            x = act_fn(x)

            x = nn.Dense(reg_dim)(x)

            x = np.reshape(x, (-1, reg_dim))

            return x
    
    return Regressor()

In [5]:
model = models.small_network(40, "relu", 1)

In [19]:
key = random.PRNGKey(0)
key_init, key = random.split(key)
batch = random.uniform(key_init, shape=(5,1), minval=-5, maxval=5)

In [10]:
print(batch)

[[2.2542226 ]
 [0.8597934 ]
 [4.922459  ]
 [0.13030052]
 [3.881017  ]]


In [11]:
init_vars = model.init(key_init, batch)

In [13]:
print(init_vars)

{'params': {'Dense_0': {'kernel': Array([[ 0.8108792 ,  1.1420408 , -1.5375005 ,  1.1752645 ,  0.98903   ,
         0.7084738 ,  1.3740441 ,  0.26782113, -2.0145447 ,  0.68219405,
        -0.55005306,  0.01865971,  1.3353693 ,  0.2529426 ,  0.96965194,
         0.701057  ,  0.3524331 , -0.04911264, -0.16912822, -0.8375746 ,
         0.5150138 , -1.8802063 , -0.48177794,  0.03080738, -0.9859298 ,
        -0.5759436 ,  1.4276935 , -0.8445995 ,  0.20796399,  0.9777959 ,
         0.8330394 , -0.5431388 ,  0.38610137,  0.59357536, -0.66677487,
        -1.3662914 ,  0.93745106, -0.8884618 , -0.8272108 ,  0.9520352 ]],      dtype=float32), 'bias': Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0.], dtype=float32)}, 'Dense_1': {'kernel': Array([[-0.09744288, -0.26018167, -0.1353759 , ..., -0.08283684,
         0.01452957, -0.05381209],
       [-0.35149738,  0.050027

In [14]:
def apply_fn_wrapper(apply_fn, is_training):
    """
    Wraps apply_fn(variables, inputs) into apply_fn_bis(params, batch_stats, inputs).
    The is_training parameter is used to avoid errors:
    * If is_training=True, then the keyword mutable is set to True for the batch_stats
    * If is_training=False, then the keywork mutable is set to False.

    In either cases, only the output of the network will be returned.
    The updated batch_stats will be lost, and must be computed explicitely apart.
    """

    if is_training:
        def apply_fn2(params, batch_stats, inputs):
            # mutable, but the updated batch_stats is not used
            output, _ = apply_fn({"params": params, "batch_stats": batch_stats}, inputs, mutable=["batch_stats"])
            return output

        return apply_fn2

    else:
        def apply_fn2(params, batch_stats, inputs):
            # not mutable, no updated batch_stats
            output = apply_fn({"params": params, "batch_stats": batch_stats}, inputs)
            return output

        return apply_fn2

In [15]:
apply_fn = utils.apply_fn_wrapper(model.apply, True)

In [16]:
apply_fn_raw = model.apply

In [17]:
from jax import random
from jax import numpy as np
from jax import pmap
import jax
from jax import value_and_grad
from jax import jit
import time
from jax import lax

import nll

In [18]:
def step_identity_cov(key, current_state, n_tasks, K, data_noise, maddox_noise, n_devices, get_train_batch_fn):
    # Draw the samples for this step, and split it to prepare for pmap (jit'd)
    x_a, y_a, x_a_div, y_a_div = get_train_batch_fn(key, n_tasks, K, data_noise, n_devices)
    
    # Compute loss and gradient through gpu parallelization
    unaveraged_losses, (unaveraged_gradients_p, unaveraged_gradients_m) = pmap(pmapable_loss_identity_cov,
                             in_axes=(None, 0, 0, None),
                             static_broadcasted_argnums=(3)
                            )(current_state, x_a_div, y_a_div, maddox_noise)
    
    current_loss = np.mean(unaveraged_losses)
    current_gradients_p = jax.tree_map(lambda array: np.mean(array, axis=0), unaveraged_gradients_p)
    current_gradients_m = jax.tree_map(lambda array: np.mean(array, axis=0), unaveraged_gradients_m)
    
    # Update batch_stats "manually" (jit'd)
    new_batch_stats = batch_stats_updater(current_state, x_a)
    
    # Update state (parameters and optimizer)
    current_state = grad_applier_identity_cov(current_state, current_gradients_p, current_gradients_m, new_batch_stats)
    
    return current_state, current_loss