In [None]:
# --- Notebook setup: dependencies and global plotting configuration ---
# NOTE(review): numpy and matplotlib.pyplot are imported more than once below;
# the repeats are harmless no-ops but could be consolidated.
import os
import numpy as np
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import itertools
import torch
import time

# Global seaborn theme used by every figure in this notebook:
# white grid background, "paper" context with 3x font scaling, 2pt lines, Set1 colours.
sns.set_style("whitegrid")
sns.set_context("paper", font_scale=3, rc={"lines.linewidth": 2})
sns.set_palette("Set1")

# Progress bar plus the JAX / Flax / Optax stack used for the models and training loop.
from tqdm import tqdm
import numpy as np
import jax
import jax.numpy as jnp
from flax.training import train_state
import optax

import matplotlib.pyplot as plt
from IPython.display import set_matplotlib_formats

# Ask the inline backend to emit figures as both PDF and PNG.
# NOTE(review): `IPython.display.set_matplotlib_formats` is deprecated in recent IPython
# releases in favour of `matplotlib_inline.backend_inline.set_matplotlib_formats` - confirm
# against the installed IPython version.
set_matplotlib_formats("pdf", "png")

In [None]:
from flax import linen as nn
from typing import Callable, List

def count_params(model, x):
    """Initialise ``model`` on the example input ``x`` and print its parameter count.

    The model is initialised with the fixed key ``PRNGKey(0)``; the number of scalar
    entries over all leaves of the resulting variable tree is printed in thousands.
    Nothing is returned.
    """
    variables = model.init(jax.random.PRNGKey(0), x)
    n_params = 0
    for leaf in jax.tree_util.tree_leaves(variables):
        # product of the leaf's dimensions = number of scalars it holds
        n_params += np.prod(leaf.shape)
    message = f'Number of parameters: {n_params/1e3} k'
    print(message)
    
def param_size(params):
    """Return the total number of scalar entries over all leaves of the pytree ``params``."""
    total = 0
    for leaf in jax.tree_util.tree_leaves(params):
        total += np.prod(leaf.shape)
    return total

class NonLinearMixture(nn.Module):
    """Scalar regression model: a mixture of fixed basis functions with transformed weights.

    The output is ``features(x) @ (act_fn(w) / sqrt(d))`` where ``w`` is the single
    trainable parameter vector ``mixture_coeffs`` of length ``d`` (the feature count).
    """
    # maps inputs x to basis-function values of shape (n, d)
    feature_fn: Callable
    # elementwise transform applied to the raw mixture coefficients
    act_fn: Callable
    # per-coefficient standard deviations used to scale the random initialisation
    prior_stds: List[float]
    # not read inside this class; LogDet reads it to select its diagonal shortcut
    is_orthogonal: bool = False

    @nn.compact
    def __call__(self, x):
        """Evaluate the model on ``x`` and return predictions of shape (n, 1)."""
        basis = self.feature_fn(x)  # (n, d)
        n_features = basis.shape[-1]

        def init_coeffs(key):
            # draw from N(0, prior_stds ** 2)
            return jax.random.normal(key, (n_features,)) * self.prior_stds

        raw_coeffs = self.param('mixture_coeffs', init_coeffs)
        scale = jnp.sqrt(raw_coeffs.shape[-1])
        coeffs = self.act_fn(raw_coeffs) / scale  # (d, )
        prediction = basis @ coeffs
        return prediction.reshape(-1, 1)  # (n, 1)
    
class LogDet:
    """Log-determinant of the matrix J[i, j] = s(w_i) * s(w_j) * Phi[i, j].

    ``Phi = features.T @ features / n`` is computed once from ``model.feature_fn``
    evaluated at ``x_eval``; ``s`` is the elementwise derivative of ``model.act_fn``
    evaluated at the model's ``mixture_coeffs``.

    Args:
        model: object exposing ``feature_fn``, ``act_fn`` and ``is_orthogonal``.
        x_eval: inputs at which the features are evaluated to build ``Phi``.
        svd: if True, the dense log-determinant is computed from singular values
            instead of ``jnp.linalg.slogdet``.
    """

    def __init__(self, model: "NonLinearMixture", x_eval: jnp.ndarray, svd=False):
        features = model.feature_fn(x_eval)  # (n, d)
        self.Phi = features.T @ features / features.shape[0]  # (d, d)
        # derivative of the scalar activation, vectorised over the coefficient vector
        self.act_deriv = jax.vmap(jax.grad(model.act_fn))
        # when True only the diagonal of Phi is used
        self.diag = model.is_orthogonal
        self.svd = svd

    def compute_log_det(self, model_params, jitter=0, return_J=False):
        """Return ``log det(J + jitter * I)``, optionally together with ``J``.

        Args:
            model_params: variable dict; ``model_params['params']['mixture_coeffs']``
                is read.
            jitter: non-negative value added to the spectrum / diagonal for stability.
            return_J: if True, return ``(logdet, J)``. ``J`` is returned without the
                jitter; in diagonal mode it is the diagonal matrix whose determinant
                was taken.

        Returns:
            The scalar log-determinant, or ``(logdet, J)`` when ``return_J`` is True.
        """
        # J[i,j] = s(w_i) * s(w_j) * Phi[i,j], where s is the derivative of model.act_fn
        s = self.act_deriv(model_params['params']['mixture_coeffs'])
        if self.diag:
            # Diagonal shortcut: off-diagonal entries of Phi are ignored.
            J_diag = jnp.diag(self.Phi) * (s ** 2)
            logdet = jnp.sum(jnp.log(J_diag + jitter))
            if return_J:
                # previously return_J was silently ignored on this path
                return logdet, jnp.diag(J_diag)
            return logdet
        J = s[:, None] * s[None, :] * self.Phi  # (d, d)
        if self.svd:
            # J is symmetric positive semi-definite (Phi is a Gram matrix), so its
            # singular values are its eigenvalues. They must NOT be squared here:
            # squaring yields log det(J^2), i.e. twice the value of the slogdet branch.
            singular_values = jnp.linalg.svd(J, compute_uv=False)
            logdet = jnp.sum(jnp.log(singular_values + jitter))
        else:
            logdet = jnp.linalg.slogdet(J + jitter * jnp.eye(J.shape[0]))[1]
        if return_J:
            return logdet, J
        return logdet

In [None]:
def tree_norm(tree):
    """Return the global L2 norm over all leaves of the pytree ``tree``.

    Uses ``jax.tree_util.tree_leaves`` (as the rest of this file does) rather than the
    legacy top-level ``jax.tree_leaves`` alias.
    """
    return jnp.sqrt(sum(jnp.sum(x ** 2) for x in jax.tree_util.tree_leaves(tree)))

def random_split_like_tree(rng_key, target=None, treedef=None):
    """Split ``rng_key`` into one key per leaf, arranged like ``target``.

    Args:
        rng_key: PRNG key to split.
        target: pytree whose structure is mirrored; only used when ``treedef`` is None.
        treedef: optional precomputed tree definition, taking precedence over ``target``.

    Returns:
        A pytree with the given structure whose leaves are the split keys.
    """
    # jax.tree_util is used instead of the legacy top-level jax.tree_* aliases.
    if treedef is None:
        treedef = jax.tree_util.tree_structure(target)
    keys = jax.random.split(rng_key, treedef.num_leaves)
    return jax.tree_util.tree_unflatten(treedef, keys)

@jax.jit
def tree_random_normal_like(rng_key, target):
    """Return a pytree shaped like ``target`` filled with standard-normal samples.

    ``rng_key`` is split into one key per leaf (in flattening order) and each leaf is
    sampled with its own key, matching the leaf's shape and dtype.
    """
    # jax.tree_util is used instead of the legacy top-level jax.tree_map alias.
    leaves, treedef = jax.tree_util.tree_flatten(target)
    keys = jax.random.split(rng_key, len(leaves))
    samples = [
        jax.random.normal(k, leaf.shape, leaf.dtype)
        for k, leaf in zip(keys, leaves)
    ]
    return jax.tree_util.tree_unflatten(treedef, samples)

In [None]:
def optimize(train_loss_fn, test_loss_fn, neg_log_p_w_fn, n_train, init_params, lr, n_step, rng_key, model, x_eval_generator, logdet_calculator, optimizer, jitter, method, temp=1.0):
    """Minimise the (per-datapoint) negative log posterior and record training curves.

    The objective is ``train_loss_fn(p) + neg_log_p_w_fn(p) / n_train + logdet``, where
    ``logdet = 0.5 * logdet_calculator.compute_log_det(p, jitter) / n_train``.
    ``train_loss_fn`` already contains a factor ``1 / n_train``, which is why the other
    terms are divided by ``n_train`` too.

    Args:
        train_loss_fn: params -> average training negative log-likelihood.
        test_loss_fn: params -> test loss (only logged, not optimised).
        neg_log_p_w_fn: params -> negative log prior density of the parameters.
        n_train: number of training points, used to rescale prior and log-det terms.
        init_params: initial parameter pytree.
        lr: peak learning rate, reached after a 100-step linear warm-up from 0.
        n_step: number of optimisation steps (converted with ``int``).
        rng_key: unused; the objective is deterministic. Kept for interface
            compatibility.
        model: its ``apply`` is stored as the train state's ``apply_fn``.
        x_eval_generator: unused; the log-det calculator carries its own evaluation
            points. Kept for interface compatibility.
        logdet_calculator: object exposing ``compute_log_det(params, jitter)``.
        optimizer: ``'adam'`` or ``'sgd'`` (SGD uses momentum 0.9).
        jitter: forwarded to ``compute_log_det``.
        method: ``'fsmap'`` (log-det term contributes gradients) or ``'psmap'``
            (log-det term is logged but its gradient is stopped).
        temp: fraction of the log-det gradient that is kept; the logged value of the
            term does not depend on ``temp``.

    Returns:
        ``(params, losses, avg_nlls, logdets, test_losss)`` - the final parameters and
        four 1-D numpy arrays with one entry per step.

    Raises:
        NotImplementedError: for an unknown ``optimizer`` or ``method``.
    """
    def get_train_state(optimizer, lr, init_params, warmup_steps=100):
        # linear warm-up from 0 to lr over the first `warmup_steps` updates
        lr_sched = optax.linear_schedule(0, lr, warmup_steps, transition_begin=0)
        if optimizer == 'adam':
            tx = optax.chain(
                    optax.adam(learning_rate=lr_sched),
                )
        elif optimizer == 'sgd':
            tx = optax.sgd(learning_rate=lr_sched, momentum=0.9)
        else:
            raise NotImplementedError
        return train_state.TrainState.create(apply_fn=model.apply, params=init_params, tx=tx)

    ts = get_train_state(optimizer, lr, init_params)

    def augmented_loss_fn(p):
        # loss = likelihood / n_train = 1 / (2 * sigma^2) * ||y - f(x)||^2 / n_train
        # it contains a factor 1 / n_train, so all other terms are divided by n_train too
        avg_nll = train_loss_fn(p)
        test_loss = test_loss_fn(p)
        if method not in ('fsmap', 'psmap'):
            raise NotImplementedError
        logdet = 1 / 2 * logdet_calculator.compute_log_det(p, jitter) / n_train
        if method == 'psmap':
            # parameter-space MAP: the term is tracked but must not influence the update
            logdet = jax.lax.stop_gradient(logdet)
        neg_log_p_w = neg_log_p_w_fn(p) / n_train
        # value is unchanged; only a fraction `temp` of the gradient flows through
        logdet = temp * logdet + (1 - temp) * jax.lax.stop_gradient(logdet)
        return avg_nll + neg_log_p_w + logdet, (avg_nll, logdet, test_loss)

    grad_fn = jax.value_and_grad(augmented_loss_fn, has_aux=True)

    @jax.jit
    def train_step(ts):
        (loss, aux), grads = grad_fn(ts.params)
        return ts.apply_gradients(grads=grads), loss, aux

    losses = []
    logdets = []
    avg_nlls = []
    test_losss = []
    for _ in tqdm(range(int(n_step))):
        ts, loss, aux = train_step(ts)
        avg_nll, logdet, test_loss = aux
        losses.append(loss.item())
        logdets.append(logdet.item())
        avg_nlls.append(avg_nll.item())
        test_losss.append(test_loss.item())
    losses = np.array(losses)
    logdets = np.array(logdets)
    avg_nlls = np.array(avg_nlls)
    test_losss = np.array(test_losss)
    return ts.params, losses, avg_nlls, logdets, test_losss

In [None]:
def avg_hess_eig(loss_fn, p):
    """Return the mean eigenvalue of the Hessian of ``loss_fn`` at parameters ``p``.

    Args:
        loss_fn: function mapping a parameter pytree to a scalar.
        p: parameter pytree at which the Hessian is evaluated.

    Returns:
        The average Hessian eigenvalue as a Python float.

    Raises:
        AssertionError: if the mean eigenvalue is not strictly positive.
    """
    # `jax.flatten_util` is a submodule that `import jax` alone does not guarantee to
    # load, so it is imported explicitly here.
    from jax.flatten_util import ravel_pytree

    p_flat, unflatten = ravel_pytree(p)

    def loss_fn_flat(flat_params):
        return loss_fn(unflatten(flat_params))

    # Evaluate the dense Hessian on the CPU by committing the input to the CPU device
    # (replacing the `device=` argument of jax.jit).
    cpu = jax.devices('cpu')[0]
    hess_fn = jax.jit(jax.hessian(loss_fn_flat))
    hess = hess_fn(jax.device_put(p_flat, cpu))
    eigvals = jnp.linalg.eigvalsh(hess)
    mean_eig = eigvals.mean().item()
    assert mean_eig > 0, 'mean eigenvalue of hessian is non-positive'
    return mean_eig


In [None]:
def run_dataset(model, in_dim, n_train, x_eval_generator, n_step, lr, optimizer, seed, method, noise_std, jitter, temp=1, result_path=None, plot=False, prior_misspecification=1, noise_misspecification=1):
    """Generate a synthetic regression task from ``model``, fit it, and report metrics.

    Ground-truth weights are obtained from ``model.init`` with a key derived from
    ``seed``; training inputs are uniform on [-1, 1]^in_dim and targets are the model
    outputs plus Gaussian noise of scale ``noise_std``. Parameters are then optimised
    from a zero initialisation with :func:`optimize`.

    Args:
        model: model providing ``init``, ``apply``, ``prior_stds`` and the attributes
            read by ``LogDet``.
        in_dim: input dimensionality. The test grid is a 1-D ``linspace`` reshaped to
            ``(1000, 1)``, so only ``in_dim == 1`` matches it.
        n_train: number of training points.
        x_eval_generator: key -> evaluation inputs for the log-determinant term.
        n_step, lr, optimizer, method, jitter, temp: forwarded to :func:`optimize`.
        seed: seed for the ground-truth weights, training inputs and noise.
        noise_std: standard deviation of the observation noise.
        result_path: if given, the measurements dict is saved there with ``torch.save``.
        plot: if True, also plot training data, test targets and predictions.
        prior_misspecification: the prior stds assumed during fitting are the true
            ones divided by this factor.
        noise_misspecification: the noise std assumed during fitting is the true one
            divided by this factor.

    Returns:
        Dict of scalar measurements (log-det, log posteriors, RMSEs, Hessian mean
        eigenvalue of the training MSE, timings and losses).
    """
    # `jax.flatten_util` is a submodule that `import jax` alone does not guarantee to
    # load, so it is imported explicitly here.
    from jax.flatten_util import ravel_pytree

    # define log det calculator
    x_eval = x_eval_generator(jax.random.PRNGKey(0))
    logdet_calculator = LogDet(model, x_eval)
    # count parameters
    dummy_param = model.init(jax.random.PRNGKey(0), jnp.ones((1, in_dim)))
    P = param_size(dummy_param)
    print(f"Number of parameters: {P}")

    rng_key = jax.random.PRNGKey(seed)

    prior_stds = model.prior_stds # (p,)
    rng_key, init_params_key = jax.random.split(rng_key)
    # ground-truth weights are a draw from the model's own initialiser
    w_true = model.init(init_params_key, jnp.ones((1, in_dim))) # (p,)
    rng_key, x_train_key = jax.random.split(rng_key)
    x_train = jax.random.uniform(x_train_key, (n_train, in_dim), minval=-1, maxval=1)
    y_train = model.apply(w_true, x_train)
    rng_key, noise_key = jax.random.split(rng_key)
    noise = jax.random.normal(noise_key, y_train.shape) * noise_std
    print('Y train mean:', jnp.mean(y_train))
    print('Y train std:', jnp.std(y_train))
    y_train = y_train + noise

    # misspecification: from here on these are the values *assumed* by the fit
    noise_std = noise_std / noise_misspecification
    prior_stds = prior_stds / prior_misspecification

    # noiseless test targets on a regular 1-D grid
    x_test = jnp.linspace(-1, 1, 1000).reshape(-1, 1)
    y_test = model.apply(w_true, x_test)
    # zero initialization
    init_params = jax.tree_util.tree_map(lambda x: jnp.zeros_like(x), w_true)
    train_loss_fn = lambda p: jnp.mean((model.apply(p, x_train) - y_train) ** 2) / (2 * noise_std ** 2)
    test_loss_fn = lambda p: jnp.mean((model.apply(p, x_test) - y_test) ** 2) / (2 * noise_std ** 2)
    neg_log_p_w_fn = lambda p: jnp.sum((ravel_pytree(p)[0] ** 2) / (2 * (prior_stds ** 2)))
    train_rmse = lambda p: jnp.sqrt(jnp.mean((model.apply(p, x_train) - y_train) ** 2))
    train_mse = lambda p: jnp.mean((model.apply(p, x_train) - y_train) ** 2)
    test_rmse = lambda p: jnp.sqrt(jnp.mean((model.apply(p, x_test) - y_test) ** 2))

    # log training time
    start_time = time.time()
    params, losses, avg_nlls, logdets, test_losss = optimize(train_loss_fn, test_loss_fn, neg_log_p_w_fn, n_train, init_params, lr, n_step, rng_key, model, x_eval_generator, logdet_calculator, optimizer, jitter, method, temp)
    end_time = time.time()
    # train time in seconds
    train_time = end_time - start_time

    # plot losses (the figure is created inside make_subplots; no extra empty figure)
    def make_subplots(arrays, labels):
        n_plots = len(arrays)
        fig, axs = plt.subplots(n_plots, 1, figsize=(6, 6))
        for i, (a, label) in enumerate(zip(arrays, labels)):
            # clip the y-range at the 98th percentile so early spikes do not dominate
            q1, q2 = np.quantile(a, [0., 0.98])
            axs[i].plot(a)
            axs[i].set_ylim(q1, q2)
            axs[i].set_xlabel('Step')
            axs[i].set_ylabel(label)
        plt.show()

    make_subplots([losses, logdets, avg_nlls, test_losss], ['Loss', 'Logdet', 'Train NLL', 'Test NLL'])

    # plot data and predictions
    if plot:
        plt.figure(figsize=(10, 8), dpi=100)
        plt.scatter(x_train.squeeze(-1), y_train.squeeze(-1), label='Train', s=20, color='black')
        plt.plot(x_test.squeeze(-1), y_test.squeeze(-1), label='Test', color='red')
        plt.plot(x_test.squeeze(-1), model.apply(params, x_test).squeeze(-1), label='Predictions', color='blue')
        plt.legend()
        plt.show()

    def measure(params, logdet_calculator):
        neg_log_param_prior = neg_log_p_w_fn(params).item()
        # the reported log-det always uses a fixed jitter of 1e-32, independent of the
        # `jitter` used during training (which is reported separately below)
        log_det = logdet_calculator.compute_log_det(params, jitter=1e-32).item()
        train_loss = train_loss_fn(params).item()
        measurements = {
            'log_det': log_det,
            'log_fs_posterior': -n_train * train_loss - neg_log_param_prior - 1 / 2 * log_det,
            'train_rmse': train_rmse(params).item(),
            'test_rmse': test_rmse(params).item(),
            'hess_train_mse': avg_hess_eig(train_mse, params),
            'log_ps_posterior': -n_train * train_loss - neg_log_param_prior,
            'neg_log_param_prior': neg_log_param_prior,
            'train_time': train_time,
            'jitter': jitter,
            'train_loss': train_loss,
            'test_loss': test_loss_fn(params).item(),
        }
        print(measurements)
        return measurements

    # compute function space posterior
    measurements = measure(params, logdet_calculator)
    if result_path is not None:
        torch.save(measurements, result_path)

    return measurements

In [None]:
def fourier_features(dim, num_freqs_per_dim, min_freq, max_freq):
    """Build a Fourier feature map (sines and cosines) for ``dim``-dimensional inputs.

    Frequencies lie on a regular lattice with ``num_freqs_per_dim`` values per axis,
    linearly spaced between ``min_freq`` and ``max_freq``. The returned function maps
    inputs of shape (n, dim) to features of shape (n, 2 * num_freqs_per_dim ** dim),
    with all sine features first and all cosine features after them.
    """
    assert min_freq > 0 and max_freq > 0, "min_freq and max_freq must both be positive"
    axis_freqs = jnp.linspace(min_freq, max_freq, num_freqs_per_dim)
    lattice = jnp.meshgrid(*([axis_freqs] * dim))
    # one row per frequency vector: (num_freqs_per_dim ** dim, dim)
    k = jnp.stack(lattice, axis=-1).reshape(-1, dim)

    def feature_fn(x):
        # (n, dim) -> (n, 2 * num_freqs_per_dim ** dim)
        phase = x @ k.T
        return jnp.concatenate([jnp.sin(phase), jnp.cos(phase)], axis=-1)

    return feature_fn

def rbf_features(cs, width, height=1):
    """Build a radial-basis feature map with one bump per centre in ``cs``.

    Each feature is ``height * exp(-(x - c) ** 2 / width)``; the per-centre results are
    concatenated along the last axis in the order the centres are given.
    """
    def feature_fn(x):
        columns = []
        for centre in cs:
            columns.append(height * jnp.exp(-(x - centre) ** 2 / width))
        return jnp.concatenate(columns, axis=-1)

    return feature_fn

Fourier

In [None]:
# FSMAP
# Function-space MAP on 1-D Fourier features: one run per (seed, n_train, noise_std).
# data generation
in_dim = 1
num_freqs_per_dim = 100
min_freq = 1 * 2 * np.pi / 2 # n 2 pi / L
max_freq = num_freqs_per_dim * 2 * np.pi / 2
feature_name = 'fourier1d'
feature_fn = fourier_features(in_dim, num_freqs_per_dim, min_freq, max_freq) # 2 * num_freqs_per_dim ** dim features
prior_stds = 10 * np.ones(2 * num_freqs_per_dim ** in_dim)
act_fn = jax.nn.tanh
model = NonLinearMixture(feature_fn, act_fn, prior_stds, is_orthogonal=True)

# optimizer
optimizer = 'adam'
lr = 1e-1
n_step = int(2500)

# method
method = 'fsmap'
jitter = 1e-32
def x_eval_generator(rng_key):
    """Draw 10000 evaluation inputs uniformly from [-1, 1]^in_dim."""
    return jax.random.uniform(rng_key, shape=(10000, in_dim), minval=-1, maxval=1)

# ------------------------------ #
output_dir = f'nonlinear_mixture/weights/{feature_name}'
skip_if_done = True
fsmap_results = []

seeds = range(3)
n_trains = [100, 200, 400, 800, 1600, 3200, 6400]
noise_stds = [0.1]

# create the output directory once, without a separate existence check
os.makedirs(output_dir, exist_ok=True)

for seed, n_train, noise_std in itertools.product(seeds, n_trains, noise_stds):
    result_path = f'{output_dir}/{method}_seed{seed}_ntrain{n_train}_noise{noise_std}.pt'
    # Reuse a cached result only if it contains 'hess_train_mse' (same rule as the
    # PS-MAP cell); the plots below need that column for both methods.
    r = torch.load(result_path) if (skip_if_done and os.path.exists(result_path)) else None
    if r is not None and 'hess_train_mse' in r:
        print('Loaded result from ', result_path)
    else:
        r = run_dataset(model, in_dim, n_train, x_eval_generator, n_step, lr, optimizer, seed, method, noise_std, jitter, temp=1, result_path=result_path)
    # run metadata is attached after saving, so it is not part of the cached file
    r['n_train'] = n_train
    r['noise_std'] = noise_std
    r['seed'] = seed
    fsmap_results.append(r)

In [None]:
# PSMAP
# Parameter-space MAP on the same 1-D Fourier setup: one run per (seed, n_train, noise_std).
# data generation
in_dim = 1
num_freqs_per_dim = 100
min_freq = 1 * 2 * np.pi / 2 # n 2 pi / L
max_freq = num_freqs_per_dim * 2 * np.pi / 2
feature_name = 'fourier1d'
feature_fn = fourier_features(in_dim, num_freqs_per_dim, min_freq, max_freq) # 2 * num_freqs_per_dim ** dim features
prior_stds = 10 * np.ones(2 * num_freqs_per_dim ** in_dim)
act_fn = jax.nn.tanh
model = NonLinearMixture(feature_fn, act_fn, prior_stds, is_orthogonal=True)

# optimizer
optimizer = 'adam'
lr = 1e-1
n_step = int(2500)

# method
method = 'psmap'
jitter = 1e-32
def x_eval_generator(rng_key):
    """Draw 10000 evaluation inputs uniformly from [-1, 1]^in_dim."""
    return jax.random.uniform(rng_key, shape=(10000, in_dim), minval=-1, maxval=1)

# ------------------------------ #
output_dir = f'nonlinear_mixture/weights/{feature_name}'
skip_if_done = True
psmap_results = []

seeds = range(3)
# n_trains = [5, 10, 20, 40, 80, 160, 320, 640, 1280, 2560]
n_trains = [100, 200, 400, 800, 1600, 3200, 6400]
noise_stds = [0.1]

# create the output directory once, without a separate existence check
os.makedirs(output_dir, exist_ok=True)

for seed, n_train, noise_std in itertools.product(seeds, n_trains, noise_stds):
    result_path = f'{output_dir}/{method}_seed{seed}_ntrain{n_train}_noise{noise_std}.pt'
    # Load a cached file at most once; reuse it only if it already contains the
    # 'hess_train_mse' measurement, otherwise recompute.
    r = torch.load(result_path) if (skip_if_done and os.path.exists(result_path)) else None
    if r is not None and 'hess_train_mse' in r:
        print('Loaded result from ', result_path)
    else:
        r = run_dataset(model, in_dim, n_train, x_eval_generator, n_step, lr, optimizer, seed, method, noise_std, jitter, temp=1, result_path=result_path)
    # run metadata is attached after saving, so it is not part of the cached file
    r['n_train'] = n_train
    r['noise_std'] = noise_std
    r['seed'] = seed
    psmap_results.append(r)

In [None]:
# convert fsmap results to pandas dataframe
import pandas as pd
# One row per run; `method` labels which objective produced it.
df_fsmap = pd.DataFrame(fsmap_results).assign(method='FS-MAP')
df_psmap = pd.DataFrame(psmap_results).assign(method='PS-MAP')

df = (
    pd.concat([df_fsmap, df_psmap])
    # Sort by n_train and method, then by seed so that rows of the two methods line up
    # trial-by-trial explicitly instead of relying on the original append order.
    .sort_values(by=['n_train', 'method', 'seed'])
    # add a column for generalization gap
    .assign(generalization_gap=lambda d: d['test_rmse'] - d['train_rmse'])
)

In [None]:
import matplotlib.ticker
from matplotlib.ticker import ScalarFormatter, NullFormatter


sns.set_style("whitegrid")
sns.set_context("paper", font_scale=3, rc={"lines.linewidth": 2})
sns.set_palette("Set1")

os.makedirs(f'nonlinear_mixture/plots/{feature_name}', exist_ok=True)


def _plot_metric_by_method(metric, linestyle, yticks=None, show_legend=False):
    """Plot ``metric`` against ``n_train`` for PS-MAP and FS-MAP on log-log axes.

    The figure is saved as ``nonlinear_mixture/plots/{feature_name}/{metric}`` in both
    PDF and PNG form and then shown. When ``yticks`` is given, minor tick labels are
    hidden and the major y ticks are set to exactly these values. Returns the axes.
    """
    fig, ax = plt.subplots(figsize=(6, 5), dpi=200)
    for label, color in (('PS-MAP', 'C1'), ('FS-MAP', 'C0')):
        sns.lineplot(x='n_train', y=metric, data=df[df.method == label], linestyle=linestyle,
                     label=label, color=color, errorbar='sd', marker='o', markersize=5, ax=ax)
    ax.set_xlabel('# Train samples')
    ax.set_ylabel('')
    ax.set_xscale('log')
    ax.set_yscale('log')
    if yticks is not None:
        ax.yaxis.set_minor_formatter(NullFormatter())
        ax.set_yticks(yticks)
    if show_legend:
        ax.legend(loc='upper right', prop={'size': 24})
    else:
        # replace seaborn's automatic legend by an empty, frameless one
        ax.legend([], [], frameon=False)
    fig.tight_layout()
    fig.savefig(f'nonlinear_mixture/plots/{feature_name}/{metric}.pdf')
    fig.savefig(f'nonlinear_mixture/plots/{feature_name}/{metric}.png')
    plt.show()
    return ax


# --- Difference in function-space log posterior, FS-MAP minus PS-MAP, per trial ---
df_fs_map = df[df.method == 'FS-MAP']
df_ps_map = df[df.method == 'PS-MAP']

# Make sure the trials align between FS-MAP and PS-MAP
assert np.array_equal(df_fs_map.n_train.values, df_ps_map.n_train.values), 'n_train values do not align between FS-MAP and PS-MAP'
assert np.array_equal(df_fs_map.seed.values, df_ps_map.seed.values), 'seeds do not align between FS-MAP and PS-MAP'

# Each row is one trial at a given 'n_train'
df_diff = pd.DataFrame({
    'n_train': df_fs_map.n_train,
    'difference': df_fs_map.log_fs_posterior.values - df_ps_map.log_fs_posterior.values,
})

plt.figure(figsize=(6,5), dpi=200)
# only `errorbar` is passed; the deprecated `ci` argument duplicated it
plot = sns.lineplot(data=df_diff, x='n_train', y='difference', errorbar='sd', marker='o', markersize=5)
plot.set_xscale('log')
plot.set_yscale('log')
plot.yaxis.set_minor_formatter(NullFormatter())
plot.set_yticks([1e3, 6e3])
plt.xlabel('# Train samples')
plt.ylabel('')
plt.tight_layout()
plt.savefig(f'nonlinear_mixture/plots/{feature_name}/log_fs_posterior.pdf')
plt.savefig(f'nonlinear_mixture/plots/{feature_name}/log_fs_posterior.png')
plt.show()

# --- Per-method curves: Hessian mean eigenvalue, test RMSE, train RMSE ---
plot = _plot_metric_by_method('hess_train_mse', linestyle='--', yticks=[1e-3, 1e-4], show_legend=True)
_plot_metric_by_method('test_rmse', linestyle='-')
plot = _plot_metric_by_method('train_rmse', linestyle='--', yticks=[1e-1, 3e-2])



