In [1]:
%env TF_FORCE_UNIFIED_MEMORY=1

env: TF_FORCE_UNIFIED_MEMORY=1


In [3]:
from jax import vmap
from jax import numpy as np
from flax import struct
from typing import Any, Callable
from flax import core
import optax
import models
from jax import random
import dataset_sines_finite
from jax import value_and_grad, grad
from functools import partial
from jax.tree_util import tree_map
from jax import jit
import time
from matplotlib import pyplot as plt
from jax.lax import scan
import pickle
import dataset_sines_infinite
import pickle

In [4]:
from jax import lax

In [10]:
# sine config
config = {}
config["n_epochs"] = 70000
config["n_tasks_per_epoch"] = 24
config["K"] = 10
config["L"] = 10
config["n_updates"] = 5
config["n_updates_test"]= 10
config["lr"] = 0.001
config["data_noise"] = 0.05
config["n_test_tasks"] = 100

config["x_dim"] = 1
config["y_dim"] = 1
config["nn_layers"] = [128,128, 32]
config["activation"] = 'tanh'
config["sigma_eps"] = 0.05


In [11]:
def get_test_batch(key, n_tasks, K, L, data_noise):
    x, y = get_raw_batch(key, n_tasks, K, L, data_noise)
    
    return x[:, :K], y[:, :K], x[:, K:], y[:, K:]

def get_raw_batch(key, n_tasks, K, L, data_noise):
    # set this higher for a multi-dimensional regression
    reg_dim = 1

    key_x, key = random.split(key)
    x = random.uniform(key_x, shape = (n_tasks, K+L, 1), minval=-5, maxval=5)
    
    y = np.empty( (n_tasks, K+L, reg_dim) )
    
    def f(task_index, value):
        y, key = value

        key_fun, key_noise, key = random.split(key, 3)

        function = draw_multi(key_fun, reg_dim)
        y = y.at[task_index, :K, :].set(function(x[task_index, :K]) + random.normal(key_noise, shape=(K, reg_dim)) * data_noise)
        y = y.at[task_index, K:, :].set(function(x[task_index, K:]))

        return (y, key)
            
    return x, lax.fori_loop(0, n_tasks, f, (y, key) )[0]

def draw_multi(key, reg_dim, amp_low=0.1, amp_high=5, phase_low=0, phase_high=np.pi):
    key_amp, key_phase = random.split(key)
    
    amps = random.uniform(key_amp, shape=(reg_dim,), minval=amp_low, maxval=amp_high)
    phases = random.uniform(key_phase, shape=(reg_dim,), minval=phase_low, maxval=phase_high)
    
    def function(x):
        return amps * np.sin(x + phases) + 1
        
    return vmap(function)

def get_train_batch_fn(key):
    # uncomment this line to train on infinite dataset
    return dataset_sines_infinite.get_test_batch(key, config["n_tasks_per_epoch"], config["K"], config["L"], config["data_noise"])

    # uncomment this line to train on finite dataset
    #return sine_dataset_offset_finite.get_train_batch_as_val_batch(key, config["n_tasks_per_epoch"], config["K"], config["L"], config["data_noise"])
    
def get_test_batch_fn(key):
    return dataset_sines_infinite.get_test_batch(key, config["n_test_tasks"], config["K"], config["L"], config["data_noise"])

In [12]:
class Dataset:
    def __init__(self):
        pass

    # draw n_sample (x,y) pairs drawn from n_func functions
    # returns (x,y) where each has size [n_func, n_samples, x/y_dim]
    def sample(self, n_funcs, n_samples):
        raise NotImplementedError

class SinusoidDataset(Dataset):
    def __init__(self, config, key, noise_var=None, rng=None):
        self.key = key
        if noise_var is None:
            self.noise_std = np.sqrt( config['data_noise'] )
        else:
            self.noise_std = np.sqrt( noise_var )

    def sample(self, n_funcs, n_samples, return_lists=False):
        batch = get_test_batch(self.key, n_funcs, n_samples, 0, self.noise_std)
        return batch[0], batch[1]

In [13]:
noise1 = 0.1
noise2 = 0.3
noise3 = 0.5
dataset1 = SinusoidDataset(config, noise_var=noise1, key=random.PRNGKey(0))
dataset2 = SinusoidDataset(config, noise_var=noise2, key=random.PRNGKey(0))
dataset3 = SinusoidDataset(config, noise_var=noise3, key=random.PRNGKey(0))

print(dataset1.sample(2, 3))

(Array([[[2.2542226 ],
        [0.8597934 ],
        [2.4556649 ]],

       [[0.13030052],
        [3.881017  ],
        [3.837769  ]]], dtype=float32), Array([[[ 0.8957337 ],
        [ 3.2626424 ],
        [-0.37934366]],

       [[ 1.4671689 ],
        [ 0.14199564],
        [-0.1838236 ]]], dtype=float32))


In [None]:
plt.plot(x_a[0], y_a[0], "ro", label="Context")
plt.plot(x_b[0], y_b[0], "rx", label="Query")
plt.plot(x_b[0], predictions, "+b", label="Pred")
plt.plot(np.linspace(-5, 5, 100), apply_fn(output["trained_params"], np.linspace(-5, 5, 100)[:, np.newaxis]), "--", label="Raw", alpha=0.4)
plt.legend()

In [48]:
import jax
import jax.numpy as jnp
from jax import random, jit, grad, value_and_grad, vmap
from flax import linen as nn
from jax.scipy.linalg import inv
from jax.numpy.linalg import slogdet
from jax.numpy import transpose, expand_dims, log, squeeze
from flax.training import train_state
from jax import random
import optax
import time
from copy import deepcopy

In [58]:
class ALPaCA(nn.Module):
    config: dict

    def setup(self):
        self.lr = self.config['lr']
        self.x_dim = self.config['x_dim']
        self.phi_dim = self.config['nn_layers'][-1]
        self.y_dim = self.config['y_dim']
        self.sigma_eps = self.config['sigma_eps']
        self.preprocess = self.config['preprocess']
        self.f_nom = self.config['f_nom']

        last_layer = self.config['nn_layers'][-1]

        if isinstance(self.sigma_eps, list):
            self.SigEps = jnp.diag(jnp.array(self.sigma_eps))
        else:
            self.SigEps = self.sigma_eps * jnp.eye(self.y_dim)
        self.SigEps = self.SigEps.reshape((1, 1, self.y_dim, self.y_dim))

        self.K = self.param('K_init', nn.initializers.normal(), (last_layer, self.y_dim))
        self.L_asym = self.param('L_asym', nn.initializers.normal(), (last_layer, last_layer))
        self.L = self.L_asym @ self.L_asym.T

        
    def __call__(self, x, context_x, context_y, num_context):
        phi = self.basis(x)
        context_phi = self.basis(context_x)
        f_nom_x = jnp.zeros_like(context_y)
        f_nom_cx = jnp.zeros_like(context_y)
        if self.f_nom is not None:
            f_nom_x = self.f_nom(x)
            f_nom_cx = self.f_nom(context_x)


    def compute_total_loss(self, param, context_x, context_y, x, y, f_nom_cx, f_nom_x, num_context):
        # Subtract f_nom from context points before BLR
        phi = self.basis(x)
        context_phi = self.basis(context_x)
        f_nom_x = jnp.zeros_like(context_y)
        f_nom_cx = jnp.zeros_like(context_y)
        if self.f_nom is not None:
            f_nom_x = self.f_nom(x)
            f_nom_cx = self.f_nom(context_x)
        context_y_blr = context_y - f_nom_cx
        posterior_K, posterior_L_inv = self.batch_blr(context_phi, context_y_blr, num_context)
        mu_pred, Sig_pred, predictive_nll = self.compute_pred_and_nll(phi, y, posterior_K, posterior_L_inv, f_nom_x)
        total_loss = jnp.mean(predictive_nll)
        return total_loss

    @nn.compact
    def basis(self, x):
        inp = x if self.preprocess is None else self.preprocess(x)
        for i, units in enumerate(self.config['nn_layers']):
            dense_layer = nn.Dense(features=units, name=f"layer_{i}")  # Create Dense layer
            inp = dense_layer(inp)  # Apply Dense layer
            activation_func = getattr(nn, self.config['activation'])
            inp = activation_func(inp)  # Apply activation function
        return inp

    def batch_blr(self, X, Y, num):
        X = X[:num, :]
        Y = Y[:num, :]
        Ln_inv = inv(X.T @ X + self.L)
        Kn = Ln_inv @ (X.T @ Y + self.L @ self.K)
        return jax.lax.cond(num > 0, lambda: (Kn, Ln_inv), lambda: (self.K, inv(self.L)))

    def compute_pred_and_nll(self, phi, y, posterior_K, posterior_L_inv, f_nom_x):
        """
        Uses self.posterior_K and self.posterior_L_inv and self.f_nom_x to generate the posterior predictive.
        Arguments:
            posterior_K: Posterior weights K matrix
            posterior_L_inv: Posterior inverse covariance matrix of weights
            phi: Feature matrix for input x
            f_nom_x: Nominal function values at x (if any)
            y: Actual target values to compare against
        Returns:
            mu_pred: Posterior predictive mean at query points
            Sig_pred: Posterior predictive variance at query points
            predictive_nll: Negative log likelihood of y under the posterior predictive density
        """
        mu_pred = posterior_K.T @ phi + f_nom_x
        spread_fac = 1 + batch_quadform(posterior_L_inv, phi)
        Sig_pred = expand_dims(spread_fac, axis=-1) * expand_dims(self.SigEps, (0, 0))
    
        # Score y under predictive distribution to obtain loss
        logdet = self.y_dim * log(spread_fac) + slogdet(self.SigEps)
        Sig_pred_inv = inv(Sig_pred)
        quadf = batch_quadform(Sig_pred_inv, (y - mu_pred))
    
        predictive_nll = squeeze(logdet + quadf, axis=-1)
    
        return mu_pred, Sig_pred, predictive_nll

    

In [59]:
def batch_matmul(mat, batch_v):
    """Batch matrix multiplication adjusted for the transposing logic in TensorFlow's batch_matmul."""
    return jnp.einsum('ijk,kj->ik', mat, batch_v)

def batch_quadform(A, b):
    """Batch quadratic form using JAX. Handles different cases based on dimensions of A."""
    if A.ndim == b.ndim + 1:
        # Same matrix A for all N vectors in b
        return jnp.squeeze(jnp.einsum('ijk,kj->ki', A, jnp.expand_dims(b, axis=-1)), axis=-1)
    elif A.ndim == b.ndim:
        # Different A for each b
        Ab = jnp.einsum('ijk,kj->ik', A, b)  # ... x N x n
        return jnp.squeeze(jnp.einsum('ij,ij->i', b, Ab), axis=-1)  # ... x N
    else:
        raise ValueError('Matrix size of %d is not supported.' % A.ndim)

def batch_2d_jacobian(y, x):
    """Compute the Jacobian of y with respect to x, handling batch dimensions."""
    y_dim = y.shape[-1]
    x_dim = x.shape[-1]

    def single_jacobian(yi, xi):
        return jax.jacfwd(lambda x: yi)(xi)

    batched_jacobian = vmap(single_jacobian, in_axes=(0, 0), out_axes=0)
    return batched_jacobian(y, x)


In [60]:
def create_train_state(key, rng, model, learning_rate=0.001):
    # Create dummy inputs for all expected arguments
    dummy_x = random.uniform(key, shape=(1, model.config['x_dim']), minval=-5, maxval=5)
    dummy_y = random.uniform(key, shape=(1, model.config['y_dim']), minval=-5, maxval=5)
    dummy_context_x = random.uniform(key, shape=(1, model.config['x_dim']), minval=-5, maxval=5)
    dummy_context_y = random.uniform(key, shape=(1, model.config['y_dim']), minval=-5, maxval=5)
    
    dummy_num_context = 1  # A simple dummy number of contexts, e.g., 1

    # Initialize parameters with full set of dummy inputs
    params = model.init(rng, dummy_x, dummy_context_x, dummy_context_y, dummy_num_context)
    tx = optax.adam(learning_rate)
    return train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)

In [61]:
from dataset_sines_infinite import get_training_batch 

@jit
def train_step(state, batch_x, batch_y, num_context):
    def loss_fn(params):
        mu_pred, Sig_pred, predictive_nll = model.compute_total_loss(self, param, context_x, context_y, x, y, f_nom_cx, f_nom_x, num_context)
        return jnp.mean(predictive_nll), (mu_pred, Sig_pred)
    grad_fn = value_and_grad(loss_fn, has_aux=True)
    (loss, (mu_pred, Sig_pred)), grads = grad_fn(state.params)
    state = state.apply_gradients(grads=grads)
    return state, loss, mu_pred, Sig_pred

def train(key, model, state, get_training_batch, num_train_updates):
    for i in range(num_train_updates):
        x_a, y_a, x_a_div, y_a_div = get_training_batch(key, n_tasks=model.config['meta_batch_size'], K=model.config['data_horizon'] + model.config['test_horizon'], data_noise=0.1, n_devices=1)
        num_context = random.randint(key, shape=(model.config['meta_batch_size'],), minval=0, maxval=model.config['data_horizon'] + 1)
        state, loss, mu_pred, Sig_pred = train_step(state, x[:, :model.config['data_horizon'], :], y[:, :model.config['data_horizon'], :], num_context)
        
        if i % 50 == 0:
            print(f'Iteration {i}, Loss: {loss}')

In [62]:
@jit
def test(model, params, x_c, y_c, x):
    mu_pred, Sig_pred, _ = model.apply(params, x_c, y_c, x)
    return mu_pred, Sig_pred

In [63]:
def encode(model, params, x):
    phi = model.apply(params, x, method=model.encode)
    return phi

def save(params, model_path):
    np.save(model_path, params)
    print(f'Saved to: {model_path}')

def restore(model_path):
    params = np.load(model_path, allow_pickle=True).item()
    print(f'Restored model from: {model_path}')
    return params

In [64]:
# Configuration and random key generation
config = {
    'lr': 0.001,
    'x_dim': 10,
    'nn_layers': [64, 32, 16],
    'y_dim': 5,
    'sigma_eps': 0.1,
    'activation': 'relu',
    'preprocess': None,
    'f_nom': None,
    'meta_batch_size': 10,
    'data_horizon': 10,
    'test_horizon': 20
}
key = random.PRNGKey(0)
key_init, key = random.split(key)

rng = random.PRNGKey(0)
model = ALPaCA(config)
state = create_train_state(key_init, rng, model)

In [65]:
key_train, key = random.split(key)
train(key_train, model, state, get_training_batch, 1000)

TypeError: ALPaCA.compute_total_loss() missing 3 required positional arguments: 'f_nom_cx', 'f_nom_x', and 'num_context'

## Vérification de l'équivalence des méthodes de sampling

In [30]:
import numpy as np

class Dataset:
    def __init__(self):
        pass

    # draw n_sample (x,y) pairs drawn from n_func functions
    # returns (x,y) where each has size [n_func, n_samples, x/y_dim]
    def sample(self, n_funcs, n_samples):
        raise NotImplementedError

In [31]:
class SinusoidDataset(Dataset):
    def __init__(self, config, noise_var=None, rng=None):
        self.amp_range = config['amp_range']
        self.phase_range = config['phase_range']
        self.freq_range = config['freq_range']
        self.x_range = config['x_range']
        if noise_var is None:
            self.noise_std = np.sqrt( config['sigma_eps'] )
        else:
            self.noise_std = np.sqrt( noise_var )
            
        self.np_random = rng
        if rng is None:
            self.np_random = np.random

    def sample(self, n_funcs, n_samples, return_lists=False):
        x_dim = 1
        y_dim = 1
        x = np.zeros((n_funcs, n_samples, x_dim))
        y = np.zeros((n_funcs, n_samples, y_dim))

        amp_list = self.amp_range[0] + self.np_random.rand(n_funcs)*(self.amp_range[1] - self.amp_range[0])
        phase_list = self.phase_range[0] + self.np_random.rand(n_funcs)*(self.phase_range[1] - self.phase_range[0])
        freq_list = self.freq_range[0] + self.np_random.rand(n_funcs)*(self.freq_range[1] - self.freq_range[0])
        for i in range(n_funcs):
            x_samp = self.x_range[0] + self.np_random.rand(n_samples)*(self.x_range[1] - self.x_range[0])
            y_samp = amp_list[i]*np.sin(freq_list[i]*x_samp + phase_list[i]) + self.noise_std*self.np_random.randn(n_samples)

            x[i,:,0] = x_samp
            y[i,:,0] = y_samp

        if return_lists:
            return x,y,freq_list,amp_list,phase_list

        return x,y

In [39]:
import dataset_sines_infinite
x_a, y_a, x_a_div, y_a_div = dataset_sines_infinite.get_training_batch(key, n_tasks=model.config['meta_batch_size'], K=model.config['data_horizon'] + model.config['test_horizon'], data_noise=0.1, n_devices=1)
print(x_a.shape, y_a.shape)

config_dataset = {
    'amp_range': [0.1, 5.0],
    'phase_range': [0, 3.14],
    'freq_range': [0.999, 1.0],
    'x_range': [-5., 5.],
    'sigma_eps': 0.02,
}
dataset = SinusoidDataset(config_dataset)
x, y = dataset.sample(n_funcs=model.config['meta_batch_size'], n_samples=model.config['data_horizon'] + model.config['test_horizon'])
print(x.shape, y.shape)

(10, 30, 1) (10, 30, 1)
(10, 30, 1) (10, 30, 1)
