In [1]:
%env TF_FORCE_UNIFIED_MEMORY=1

env: TF_FORCE_UNIFIED_MEMORY=1


In [153]:
import unlimtd_f
import time
from jax import random, jit, pmap, value_and_grad, lax, vmap
import dataset_multi_infinite
import dataset_lines_infinite
import test
import plots
import ntk
import nll
import jax
from jax import numpy as np
import pickle
import models
import utils

In [154]:
seed = 1655235988902897757
print(seed)

1655235988902897757


In [155]:
def gaussian_posterior_full(kernel, kernel_self, x_a, y_a, x_b, maddox_noise):
    """
    Computes the gaussian posterior with this kernel and this data, on the queried inputs.
    x_a is a (batch_size, input_dims) array (! has lost n_tasks)
    y_a is a (batch_size, reg_dim) array (! has lost n_tasks)
    Returns the posterior covariance matrix
    """
    dim = y_a.shape[1]
    y_a = np.reshape(y_a, (-1,))

    cov_a_a = kernel_self(x_a)
    cov_a_a = cov_a_a + maddox_noise ** 2 * np.eye(cov_a_a.shape[0])
    cov_b_a = kernel(x_b, x_a)
    cov_b_b = kernel_self(x_b)
    
    L = scipy.linalg.cho_factor(cov_a_a)
    alpha = scipy.linalg.cho_solve(L, y_a)
    post_mean = cov_b_a @ alpha
    
    v = scipy.linalg.cho_solve(L, cov_b_a.T)
    post_cov = cov_b_b - cov_b_a @ v
    
    return np.reshape(post_mean, (-1, dim) ), post_cov

In [167]:
def nll(post_mu, post_cov, y_b, K):
    """
    Computes the NLL of this data (one task only) wrt the kernel
    x_a is a (batch_size, input_dims) array (! has lost n_tasks)
    y_a is a (batch_size, reg_dim) array (! has lost n_tasks)
    """
    L = jax.scipy.linalg.cho_factor(post_cov)
    alpha = jax.scipy.linalg.cho_solve(L, y_b-post_mu)
    return 0.5 * (y_b-post_mu).T @ alpha + np.sum(np.log(np.diag(L[0]))) + 0.5 * K/2 * np.log(2 * np.pi)

In [161]:
def nll_batch_one_kernel(kernel, kernel_self, x_a, y_a, maddox_noise, jacobian, mean):
    """
    NLL for a batch of tasks, when there is only one kernel (singGP)
    x_a is (n_tasks, batch_size, input_dims) (input_dims are (128, 128, 1) for vision, (1,) for toy problems)
    y_a is (n_tasks, batch_size, reg_dim)
    """
    
    K = len(x_a[0])
    split_size = K // 2

    for x_a_i, y_a_i in zip(x_a, y_a):
        x_a, x_b = x_a_i[:split_size], x_a_i[split_size:]
        y_a, y_b = y_a_i[:split_size], y_a_i[split_size:]
        mu_a = utils.falseaffine_correction0(jacobian, mean, x_a)
        mu_b = utils.falseaffine_correction0(jacobian, mean, x_b)
        print("mu_a:", np.array(mu_a)[0])  # Check for NaNs
        print("mu_b:", mu_b)
        
        post_mu, post_cov = gaussian_posterior_full(kernel, kernel_self, x_a, y_a - mu_a, x_b, maddox_noise)
        post_mu += mu_b
        print("post_mu:", post_mu)  # Check for NaNs
        print("post_cov:", post_cov)
        
        loss_here = nll(post_mu, post_cov, y_b, K)
        print("loss_here:", loss_here)  # Check for NaNs
    losses = np.reshape(losses, (-1,))
    
    return np.array(losses)

In [None]:
def falseaffine_correction0(jacobian, mean, x):
    # x is (batch_size, inputs...)
    batch_size = x.shape[0]
    # jacobian(x) @ mean is (batch_size * reg_dim, 1)
    return np.reshape(jacobian(x) @ mean, (batch_size, -1) )

falseaffine_correction = vmap(falseaffine_correction0, in_axes=(None, None, 0))

In [None]:
def gaussian_posterior_full(kernel, kernel_self, x_a, y_a, x_b, maddox_noise):
    """
    Computes the gaussian posterior with this kernel and this data, on the queried inputs.
    x_a is a (batch_size, input_dims) array (! has lost n_tasks)
    y_a is a (batch_size, reg_dim) array (! has lost n_tasks)
    Returns the posterior covariance matrix
    """
    dim = y_a.shape[1]
    y_a = np.reshape(y_a, (-1,))

    cov_a_a = kernel_self(x_a)
    cov_a_a = cov_a_a + maddox_noise ** 2 * np.eye(cov_a_a.shape[0])
    cov_b_a = kernel(x_b, x_a)
    cov_b_b = kernel_self(x_b)
    
    L = scipy.linalg.cho_factor(cov_a_a)
    alpha = scipy.linalg.cho_solve(L, y_a)
    post_mean = cov_b_a @ alpha
    
    v = scipy.linalg.cho_solve(L, cov_b_a.T)
    post_cov = cov_b_b - cov_b_a @ v
    
    return np.reshape(post_mean, (-1, dim) ), post_cov

In [None]:
def nll(post_mu, post_cov, y_b, K):
    """
    Computes the NLL of this data (one task only) wrt the kernel
    x_a is a (batch_size, input_dims) array (! has lost n_tasks)
    y_a is a (batch_size, reg_dim) array (! has lost n_tasks)
    """
    L = jax.scipy.linalg.cho_factor(post_cov)
    alpha = jax.scipy.linalg.cho_solve(L, y_b-post_mu)
    return 0.5 * (y_b-post_mu).T @ alpha + np.sum(np.log(np.diag(L[0]))) + 0.5 * K/2 * np.log(2 * np.pi)

In [111]:
from functools import partial

def nll_batch_average_identity_cov(current_params, current_mean, apply_fn, current_batch_stats, x_a, y_a, maddox_noise):
    kernel, kernel_self, jacobian = get_kernel_and_jac_identity_cov(apply_fn, current_params, current_batch_stats)
    
    return np.mean(nll_batch_one_kernel(kernel, kernel_self, x_a, y_a, maddox_noise, jacobian, current_mean))

In [112]:
def pmapable_loss_identity_cov(current_state, x_a, y_a, maddox_noise):
    # we can't pass current_state because we have to explicitely show the variable
    loss, (gradients_p, gradients_m) = value_and_grad(nll_batch_average_identity_cov, argnums = (0, 1) )(current_state.params,
                                                              current_state.mean,
                                                              current_state.apply_fn,
                                                              current_state.batch_stats,
                                                              x_a,
                                                              y_a,
                                                              maddox_noise)
    
    return loss, (gradients_p, gradients_m)

In [113]:
@jit
def batch_stats_updater(current_state, x_a):
    # shape of x_a is (n_tasks, batch_size, inputs_dims...)
    
    batch_stats = current_state.batch_stats
    
    def f(old_batch_stats, _x_a):
        # shape of _x_a is (batch_size, input_dims)
        _, mutated_vars = current_state.apply_fn_raw({"params":current_state.params,
                                                      "batch_stats": old_batch_stats},
                                                     _x_a,
                                                     mutable=["batch_stats"])
        
        new_batch_stats = mutated_vars["batch_stats"]
        return new_batch_stats, None

    batch_stats = dict(batch_stats)
    print(type(batch_stats))
    batch_stats, _ = lax.scan(f, batch_stats, x_a)
    return batch_stats

In [151]:
#@jit
def grad_applier_identity_cov(current_state, gradients_p, gradients_m, new_batch_stats):
    return current_state.apply_gradients(grads_params=gradients_p, grads_mean=gradients_m, new_batch_stats=new_batch_stats)

In [115]:
import trainer

def step_identity_cov(key, current_state, n_tasks, K, data_noise, maddox_noise, n_devices, get_train_batch_fn):
    # Draw the samples for this step, and split it to prepare for pmap (jit'd)
    x_a, y_a, x_a_div, y_a_div = get_train_batch_fn(key, n_tasks, K, data_noise, n_devices)
    
    # Compute loss and gradient through gpu parallelization
    unaveraged_losses, (unaveraged_gradients_p, unaveraged_gradients_m) = pmap(pmapable_loss_identity_cov,
                             in_axes=(None, 0, 0, None),
                             static_broadcasted_argnums=(3)
                            )(current_state, x_a_div, y_a_div, maddox_noise)
    
    current_loss = np.mean(unaveraged_losses)
    current_gradients_p = jax.tree_map(lambda array: np.mean(array, axis=0), unaveraged_gradients_p)
    current_gradients_m = jax.tree_map(lambda array: np.mean(array, axis=0), unaveraged_gradients_m)
    
    # Update batch_stats "manually" (jit'd)
    new_batch_stats = batch_stats_updater(current_state, x_a)
    
    # Update state (parameters and optimizer)
    current_state = grad_applier_identity_cov(current_state, current_gradients_p, current_gradients_m, new_batch_stats)
    
    return current_state, current_loss

## Unlimitd f training

In [116]:
from ntk import get_kernel_and_jac_lowdim_cov

def nll_batch_average_lowdim_cov_singGP(current_params, current_mean, current_scale, apply_fn, current_batch_stats, proj, x_a, y_a, maddox_noise):
    kernel, kernel_self, jacobian = get_kernel_and_jac_lowdim_cov(apply_fn, current_params, current_scale, current_batch_stats, proj)

    return np.mean(nll_batch_one_kernel(kernel, kernel_self, x_a, y_a, maddox_noise, jacobian, current_mean))

In [117]:
def pmapable_loss_lowdim_cov_singGP(current_state, x_a, y_a, maddox_noise):
    # we can't pass current_state because we have to explicitely show the variable
    loss, (gradients_p, gradients_m, gradients_s) = value_and_grad(nll_batch_average_lowdim_cov_singGP, argnums = (0, 1, 2) )(current_state.params,
                                                              current_state.mean,
                                                              current_state.scale,
                                                              current_state.apply_fn,
                                                              current_state.batch_stats,
                                                              current_state.proj,
                                                              x_a,
                                                              y_a,
                                                              maddox_noise)
    
    return loss, (gradients_p, gradients_m, gradients_s)

In [118]:
from trainer import grad_applier_lowdim_cov_singGP

def step_lowdim_cov_singGP(key, current_state, n_tasks, K, data_noise, maddox_noise, n_devices, get_train_batch_fn):
    # Draw the samples for this step, and split it to prepare for pmap (jit'd)
    x_a, y_a, x_a_div, y_a_div = get_train_batch_fn(key, n_tasks, K, data_noise, n_devices)
    
    # Compute loss and gradient through gpu parallelization
    unaveraged_losses, (unaveraged_gradients_p, unaveraged_gradients_m, unaveraged_gradients_s) = pmap(pmapable_loss_lowdim_cov_singGP,
                             in_axes=(None, 0, 0, None),
                             static_broadcasted_argnums=(3)
                            )(current_state, x_a_div, y_a_div, maddox_noise)
    
    current_loss = np.mean(unaveraged_losses)
    current_gradients_p = jax.tree_map(lambda array: np.mean(array, axis=0), unaveraged_gradients_p)
    current_gradients_m = jax.tree_map(lambda array: np.mean(array, axis=0), unaveraged_gradients_m)
    current_gradients_s = jax.tree_map(lambda array: np.mean(array, axis=0), unaveraged_gradients_s)
    
    # Update batch_stats "manually" (jit'd)
    new_batch_stats = batch_stats_updater(current_state, x_a)
    
    # Update state (parameters and optimizer)
    current_state = grad_applier_lowdim_cov_singGP(current_state, current_gradients_p, current_gradients_m, current_gradients_s, new_batch_stats)
    
    return current_state, current_loss

In [119]:
def train_and_eval(key, step, n_epochs, state, n_tasks, K, data_noise, maddox_noise, get_train_batch_fn, eval_during_training_fn):
    """
    Available step functions:
    * step_identity_cov
    * step_lowdim_cov_singGP
    * step_lowdim_cov_mixture

    Available get_train_batch_fn functions:
    * dataset_sines_infinite.get_training_batch
    * dataset_sines_finite.get_training_batch
    * dataset_lines_infinite.get_training_batch
    * dataset_multi_infinite.get_training_batch
    * dataset_shapenet1d.get_training_batch
    
    """
    n_devices = jax.local_device_count()

    print("Starting training with:")
    print(f"-n_epochs={n_epochs}")
    print(f"-n_tasks={n_tasks}")
    print(f"-K={K}")
    print(f"-data_noise={data_noise}")
    print(f"-maddox_noise={maddox_noise}")

    losses = []
    evals = []
    t = time.time_ns()

    for epoch_index in range(n_epochs):
        key, subkey = random.split(key)
        state, current_loss = step(subkey, state, n_tasks, K, data_noise, maddox_noise, n_devices, get_train_batch_fn)

        if(np.isnan(current_loss)):
            print("Nan, aborting")
            break
        
        losses.append(current_loss)

        if epoch_index % 10 == 0:
            print(f"{epoch_index}  | {current_loss:.4f} ({(time.time_ns() - t)/ 10**9:.4f} s)")
        t = time.time_ns()

        if epoch_index % 500 == 0:
            key, subkey = random.split(key)
            current_eval = eval_during_training_fn(subkey, state)
            evals.append( current_eval )
            print(f"Eval: {current_eval}")
    print("Completed training")

    return state, losses, evals

In [142]:
from nll import gaussian_posterior

def test_nll_one_kernel(key, kernel, kernel_self, jacobian, get_test_batch_fn, K, n_tasks, data_noise, maddox_noise, current_mean):
    """
    Returns the NLLs for n_tasks random tasks, in the singGP case.
    """
    x_a, y_a, _, _ = get_test_batch_fn(key, n_tasks, K, 0, data_noise)
    all_nlls = nll_batch_one_kernel(kernel, kernel_self, x_a, y_a, maddox_noise, jacobian, current_mean)

    return all_nlls

def test_error_one_kernel(key, kernel, kernel_self, jacobian, get_test_batch_fn, error_fn, K, L, n_tasks, data_noise, maddox_noise, current_mean):
    """
    Returns the error for n_tasks random tasks, in the singGP case.
    """
    x_a, y_a, x_b, y_b = get_test_batch_fn(key, n_tasks, K, L, data_noise)

    def f(carry, task_data):
        _x_a, _y_a, _x_b, _y_b = task_data
        _y_a = _y_a - utils.falseaffine_correction0(jacobian, current_mean, _x_a)
        predictions = gaussian_posterior(kernel, kernel_self, _x_a, _y_a, _x_b, maddox_noise)
        predictions = predictions + utils.falseaffine_correction0(jacobian, current_mean, _x_b)
        return None, error_fn(predictions, _y_b)
    
    _, all_errors = lax.scan(f, None, (x_a, y_a, x_b, y_b))

    return all_errors

In [143]:
import trainer
import ntk
import test
import train_states
import models
import utils
import fim

import dataset_sines_infinite
import dataset_sines_finite
import dataset_multi_infinite

from jax import random
from jax import numpy as np
from flax.core import FrozenDict
import optax

from ntk import get_kernel_and_jac_identity_cov
from ntk import get_kernel_and_jac_lowdim_cov
import unlimtd_f

def unlimtd_f_uni_modal_infinite(seed, pre_n_epochs, pre_n_tasks, pre_K, post_n_epochs, post_n_tasks, post_K, data_noise, maddox_noise, meta_lr, subspace_dimension):
    key = random.PRNGKey(seed)
    key_init, key = random.split(key)
    
    print("===============")
    print("This is UNLIMTD-F")
    print("For the uni-modal dataset: infinite sine dataset")
    print("This variant of UNLIMTD-F approaches the distribution with a single GP")
    print("===============")
    print("Creating model")
    model = models.small_network(40, "relu", 1)
    batch = random.uniform(key_init, shape=(5,1), minval=-5, maxval=5)
    init_vars = model.init(key_init, batch)
    apply_fn = utils.apply_fn_wrapper(model.apply, True)
    apply_fn_raw = model.apply

    # Training before finding the FIM matrix
    print("Creating optimizers")
    step = step_identity_cov
    get_train_batch_fn = dataset_sines_infinite.get_training_batch
    optimizer_params = optax.adam(learning_rate = meta_lr)
    optimizer_mean = optax.adam(learning_rate = meta_lr)
    mean_init = np.zeros( (utils.get_param_size(init_vars["params"]),) )

    pre_state = train_states.TrainStateIdentityCovariance.create(apply_fn=apply_fn, apply_fn_raw=apply_fn_raw, params=init_vars["params"], mean=mean_init, tx_params=optimizer_params, tx_mean=optimizer_mean, batch_stats=FrozenDict())
    
    def eval_during_pre_training(key, state):
        current_params = state.params
        current_batch_stats = state.batch_stats
        current_mean = state.mean
        kernel, kernel_self, jacobian = ntk.get_kernel_and_jac_identity_cov(apply_fn, current_params, current_batch_stats)

        subkey_1, subkey_2 = random.split(key)
        part_apply_fn = partial(apply_fn, current_params, current_batch_stats)
        nlls = test_nll_one_kernel(subkey_1, kernel, kernel_self, jacobian, dataset_sines_infinite.get_test_batch, K=pre_K, n_tasks=1000, data_noise=data_noise, maddox_noise=maddox_noise, current_mean=current_mean)
        mses = test_error_one_kernel(subkey_2, kernel, kernel_self, jacobian, dataset_sines_infinite.get_test_batch, dataset_sines_infinite.error_fn, K=pre_K, L=pre_K, n_tasks=1000, data_noise=data_noise, maddox_noise=maddox_noise, current_mean=current_mean)

        return np.mean(nlls), np.mean(mses)

    print("Starting first part of training (identity covariance)")
    key_pre, key = random.split(key)
    pre_state, pre_losses, pre_evals = train_and_eval(key_pre, step, pre_n_epochs, pre_state, pre_n_tasks, pre_K, data_noise, maddox_noise, get_train_batch_fn, eval_during_pre_training)
    print("Finished first part of training")

    # FIM
    print("Finding projection matrix")
    key_fim, key_data, key = random.split(key, 3)
    # here we use the exact FIM, we do not need to approximate given the (small) size of the network
    # P1 = fim.proj_exact(key=key_fim, apply_fn=apply_fn, current_params=pre_state.params, current_batch_stats=pre_state.batch_stats, subspace_dimension=subspace_dimension)
    P1 = fim.proj_sketch(key=key_fim, apply_fn=apply_fn, current_params=pre_state.params, batch_stats=pre_state.batch_stats, batches=random.uniform(key_data, shape=(100, 1761, 1), minval=-5, maxval=5), subspace_dimension=subspace_dimension)
    print("Found projection matrix")

    # Usual training with projection
    print("Creating optimizers")
    step = step_lowdim_cov_singGP
    optimizer_params = optax.adam(learning_rate = meta_lr)
    optimizer_mean = optax.adam(learning_rate = meta_lr)
    optimizer_scale = optax.adam(learning_rate = meta_lr)
    init_scale = np.ones( (subspace_dimension,) )

    post_state = train_states.TrainStateLowDimCovSingGP.create(apply_fn = apply_fn, apply_fn_raw=apply_fn_raw, params = pre_state.params, mean=pre_state.mean, scale=init_scale, tx_params = optimizer_params, tx_mean = optimizer_mean, tx_scale = optimizer_scale, batch_stats=pre_state.batch_stats, proj = P1)

    def eval_during_post_training(key, state):
        current_params = state.params
        current_batch_stats = state.batch_stats
        current_mean = state.mean
        current_scale = state.scale
        kernel, kernel_self, jacobian = ntk.get_kernel_and_jac_lowdim_cov(apply_fn, current_params, current_scale, current_batch_stats, P1)

        subkey_1, subkey_2 = random.split(key)
        part_apply_fn = partial(apply_fn, current_params, current_batch_stats)
        nlls = test_nll_one_kernel(subkey_1, kernel_self, jacobian, dataset_sines_infinite.get_test_batch, K=pre_K, n_tasks=1000, data_noise=data_noise, maddox_noise=maddox_noise, current_mean=current_mean)
        mses = test_error_one_kernel(subkey_2, kernel, kernel_self, jacobian, dataset_sines_infinite.get_test_batch, dataset_sines_infinite.error_fn, K=pre_K, L=pre_K, n_tasks=1000, data_noise=data_noise, maddox_noise=maddox_noise, current_mean=current_mean)

        return np.mean(nlls), np.mean(mses)

    print("Starting training")
    key_post, key = random.split(key)
    post_state, post_losses, post_evals = train_and_eval(key_post, step, post_n_epochs, post_state, post_n_tasks, post_K, data_noise, maddox_noise, get_train_batch_fn, eval_during_post_training)
    print("Finished training")

    # Returning everything
    return init_vars, pre_state, pre_evals, post_state, pre_losses, post_losses, post_evals


In [162]:
init_params, pre_state, pre_evals, post_state, pre_losses, post_losses, post_evals = unlimtd_f_uni_modal_infinite(seed=seed,
                                                                                     pre_n_epochs=1000,
                                                                                     pre_n_tasks=24,
                                                                                     pre_K=10,
                                                                                     post_n_epochs=1000,
                                                                                     post_n_tasks=24,
                                                                                     post_K=10,
                                                                                     data_noise=0.05, 
                                                                                     maddox_noise=0.05,
                                                                                     meta_lr=0.001,
                                                                                     subspace_dimension=10)

This is UNLIMTD-F
For the uni-modal dataset: infinite sine dataset
This variant of UNLIMTD-F approaches the distribution with a single GP
Creating model
Creating optimizers
Starting first part of training (identity covariance)
Starting training with:
-n_epochs=1000
-n_tasks=24
-K=10
-data_noise=0.05
-maddox_noise=0.05
mu_a: Traced<ShapedArray(float32[1])>with<JVPTrace(level=2/1)> with
  primal = Traced<ShapedArray(float32[1])>with<DynamicJaxprTrace(level=0/1)>
  tangent = Traced<ShapedArray(float32[1])>with<JaxprTrace(level=1/1)> with
    pval = (ShapedArray(float32[1]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7fb83c2e6030>, in_tracers=(Traced<ShapedArray(float32[1,1]):JaxprTrace(level=1/1)>,), out_tracer_refs=[<weakref at 0x7fb83c2300e0; to 'JaxprTracer' at 0x7fb83c233830>], out_avals=[ShapedArray(float32[1])], primitive=squeeze, params={'dimensions': (0,)}, effects=frozenset(), source_info=SourceInfo(traceback=<jaxlib.xla_extension.Traceback object at 0x7fb83c157

UnboundLocalError: local variable 'losses' referenced before assignment

In [84]:
output = {}
output["seed"] = seed

output["pre_n_epochs"]=30000
output["pre_n_tasks"]=24
output["pre_K"]=10
output["post_n_epochs"]=30000
output["post_n_tasks"]=24
output["post_K"]=10
output["data_noise"]=0.05
output["maddox_noise"]=0.05
output["meta_lr"]=0.001
output["subspace_dimension"]=10
output["pre_losses"]=pre_losses
output["post_losses"]=post_losses
output["init_params"]=init_params
output["intermediate_params"]=pre_state.params
output["trained_params"]=post_state.params
output["intermediate_mean"]=pre_state.mean
output["trained_mean"]=post_state.mean
output["intermediate_batch_stats"]=pre_state.batch_stats
output["trained_batch_stats"]=post_state.batch_stats
output["trained_scale"]=post_state.scale
output["proj"]=post_state.proj
output["pre_evals"]=pre_evals
output["post_evals"]=post_evals

In [85]:
with open("logs_final/fim_infinite_vanilla_post_nll.pickle", "wb") as handle:
    pickle.dump(output, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [103]:
with open("logs_final/fim_infinite_vanilla_post_nll.pickle", "rb") as handle:
    output = pickle.load(handle)

In [104]:
model = models.small_network(40, "relu", 1)
apply_fn = utils.apply_fn_wrapper(model.apply, False)

In [105]:
kernel, kernel_self, jac = ntk.get_kernel_and_jac_lowdim_cov(apply_fn, output["trained_params"], output["trained_scale"], output["trained_batch_stats"], output["proj"])

In [106]:
from matplotlib import pyplot as plt
from nll import gaussian_posterior_full

def plot_notebooks(key, part_apply_fn, kernel, kernel_self, jac, mean, K, dataset_provider):
    """
    Make an informative prediction plot in the singGP case (for the kernel specified)
    K is the number of context inputs
    Change dataset_provider to test on other datasets (e.g. dataset_sines_infinite)
    """
    x, y, fun = dataset_provider.get_fancy_test_batch(key, K=10, L=0, data_noise=0.05)

    x_a_all = x[0, :10]
    y_a_all = y[0, :10]
    x_b = np.linspace(-5, 5, 100)[:, np.newaxis]
    y_b = fun(x_b)

    y_min, y_max = np.min(y_b) - 0.5, np.max(y_b) + 0.5

    correction_a_all = utils.falseaffine_correction0(jac, mean, x_a_all)
    correction_b = utils.falseaffine_correction0(jac, mean, x_b)

    x_a = x_a_all[:K]
    y_a = y_a_all[:K]
    correction_a = correction_a_all[:K]

    prediction, cov = gaussian_posterior_full(kernel, kernel_self, x_a, y_a - correction_a, x_b, 0.05)
    prediction = prediction + correction_b

    error = dataset_provider.error_fn(prediction, y_b)
    loss = nll(kernel_self, x_a, y_a - correction_a, maddox_noise=0.05)

    variances = np.diag(cov)
    stds = np.sqrt(variances)

    plt.plot(x_b, y_b, "g--", label="Target")
    plt.plot(x_b, part_apply_fn(x_b), "k--", label="apply_fn")
    plt.plot(x_b, correction_b, "p--", label="correction_b")
    plt.plot(x_a, y_a, "ro", label="Context data")
    plt.plot(x_b, prediction, "b", label="Prediction")
    plt.fill_between(x_b[:, 0], prediction[:, 0] - 1.96 * stds, prediction[:, 0] + 1.96 * stds, color='blue', alpha=0.1, label="+/- 1.96$\sigma$")
    plt.title(f"NLL={loss:.4f}, MSE={error:.4f} ($K$={K})")
    plt.legend()
    plt.gca().set_ylim([y_min, y_max])
    plt.gca().set_xlabel("$x$")
    plt.gca().set_ylabel("$y$")
    plt.legend()

In [107]:
key = random.PRNGKey(1)

In [108]:
part_apply_fn = partial(apply_fn, output["trained_params"], output["trained_batch_stats"])
key, subkey = random.split(key)
plot_notebooks(subkey, part_apply_fn, kernel, kernel_self, jac, output["trained_mean"], 10, dataset_sines_infinite)

TypeError: nll() missing 5 required positional arguments: 'kernel_self', 'x_a', 'y_a', 'x_b', and 'y_b'

In [None]:
pretraining_nll_no_mean = [x[0] for x in output["pre_evals"]]
posttraining_nll_no_mean = [x[0] for x in output["post_evals"]]

n_samples_pretraining = len(pretraining_nll_no_mean)
n_samples_posttraining = len(posttraining_nll_no_mean)

plt.clf()
plt.plot(range(n_samples_pretraining), pretraining_nll_no_mean, "g", label="pretraining_nll")
plt.plot(range(n_samples_posttraining-1, n_samples_posttraining+n_samples_posttraining-1), pretraining_nll_no_mean, "r", label="projtraining_nll")
plt.gca().set_ylim([-10, 20])
plt.legend()
plt.show()

In [None]:
pretraining_mse_no_mean = [x[1] for x in output["pre_evals"]]
posttraining_mse_no_mean = [x[1] for x in output["post_evals"]]

plt.clf()
plt.plot(range(n_samples_pretraining), pretraining_mse_no_mean, "g--", label="pretraining_mse")
plt.plot(range(n_samples_posttraining-1, n_samples_posttraining+n_samples_posttraining-1), posttraining_mse_no_mean, "r--", label="projtraining_mse")
plt.gca().set_ylim([0, 1])
plt.legend()
plt.show()