In [None]:
import os

# Select the GPU before jax is imported in a later cell, since CUDA reads this
# variable at initialisation. setdefault keeps a device that was already exported
# in the environment and only falls back to GPU 7 when none is set.
os.environ.setdefault('CUDA_VISIBLE_DEVICES', '7')

In [None]:
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

# set_theme is the preferred interface; sns.set is only an alias for it.
sns.set_theme(font_scale=1.5, style='whitegrid')

In [None]:
from fspace.datasets import get_dataset, get_dataset_normalization
from fspace.datasets.two_moons import get_twomoons_ctx
from torch.utils.data import DataLoader

train_data, _, test_data = get_dataset('twomoons', random_state=137)
ctx_data = get_twomoons_ctx(n_samples=10000, random_state=137, normalize=get_dataset_normalization('twomoons'))

train_loader = DataLoader(train_data, batch_size=10, shuffle=True)
test_loader = DataLoader(test_data, batch_size=1000)
# Shuffled so that each epoch pairs the training batches with a different random
# subset of the context set. train_model creates a fresh context iterator on every
# call, so an unshuffled loader would only ever feed its leading batches.
ctx_loader = DataLoader(ctx_data, batch_size=10, shuffle=True)

fig, ax = plt.subplots(figsize=(5,5))

train_scatter = ax.scatter(train_data.X[:, 0], train_data.X[:, 1], c=train_data.y, cmap=sns.color_palette('Accent', as_cmap=True))
# ax.scatter(ctx_data.X[:, 0], ctx_data.X[:, 1], color='black', alpha=0.25, label=r'$\mathbf{X}_C$')
# legend_elements yields one labelled handle per class value; a bare ax.legend()
# finds no labelled artists here and only emits a warning.
ax.legend(*train_scatter.legend_elements(), title='class')

# plt.show() avoids the "non-GUI backend" warning that fig.show() raises inline.
plt.show()

In [None]:
import jax
import optax
import jax.numpy as jnp
import distrax
from tqdm.auto import tqdm
import logging
from fspace.nn import create_model
from fspace.utils.training import TrainState, eval_classifier


@jax.jit
def train_step_fn(prior_params, state, b_X, b_Y, b_X_ctx, f_prior_std, jitter=1e-4):
    '''One gradient step of training with a Gaussian function-space prior.

    The objective is the mean cross-entropy on the labelled batch plus the
    negative log-density of the model's logits under a zero-mean multivariate
    normal, evaluated independently for every output class over the labelled
    points concatenated with the context points (if any). With ``H`` the first
    value sown under ``'features'`` by the network evaluated at
    ``prior_params`` and ``s = f_prior_std``, the prior covariance is

        K = s^2 H H^T + s^2 1 1^T + s^2 I + jitter I

    NOTE: Prior means for all parameters are assumed to be zero.

    Args:
      prior_params: parameters used only to compute the prior features ``H``;
        no gradient is taken with respect to them.
      state: train state exposing ``apply_fn``, ``params``, ``extra_vars`` and
        ``apply_gradients``.
      b_X: labelled inputs.
      b_Y: integer class labels for ``b_X``.
      b_X_ctx: context inputs (no labels needed), or None to regularise on the
        labelled batch only.
      f_prior_std: prior standard deviation ``s``.
      jitter: value added to the diagonal of ``K`` to keep it numerically
        positive definite.

    Returns:
      ``(new_state, metrics)`` where ``metrics`` holds ``'batch_loss'`` (the
      cross-entropy term) and ``'batch_reg_loss'`` (the prior term).
    '''
    n_labelled = b_X.shape[0]

    def loss_fn(params, **extra_vars):
        # Labelled and context points go through a single forward pass.
        b_X_in = b_X if b_X_ctx is None else jnp.concatenate([b_X, b_X_ctx], axis=0)

        b_logits, new_state = state.apply_fn({ 'params': params, **extra_vars }, b_X_in,
                                             mutable=['batch_stats'], train=True)

        # Data term: only the first n_labelled rows have targets.
        loss = jnp.mean(optax.softmax_cross_entropy_with_integer_labels(b_logits[:n_labelled], b_Y))

        # Prior features from the network at prior_params; the mutated batch stats
        # of this pass are discarded.
        h_X = state.apply_fn({ 'params': prior_params, **extra_vars }, b_X_in,
                             mutable=['batch_stats', 'intermediates'], train=True)[1]['intermediates']['features'][0]

        n_points = h_X.shape[0]
        f_h_cov = jnp.matmul(h_X * f_prior_std**2, h_X.T)
        # Bias-like constant term plus an iid diagonal term, both with variance s^2.
        # NOTE(review): confirm the s^2 * I term is intended in addition to jitter.
        f_h_cov = f_h_cov + f_prior_std**2 * jnp.ones_like(f_h_cov) + f_prior_std**2 * jnp.eye(n_points)
        f_h_cov = f_h_cov + jitter * jnp.eye(n_points)
        f_dist = distrax.MultivariateNormalFullCovariance(
            loc=jnp.zeros(n_points), covariance_matrix=f_h_cov)

        # b_logits.T is (classes, points): one log-density per class, summed.
        # NOTE(review): this is a sum while `loss` is a per-example mean; confirm
        # the relative weighting of the two terms is intended.
        reg_loss = - jnp.sum(f_dist.log_prob(b_logits.T))

        total_loss = loss + reg_loss

        return total_loss, (new_state, loss, reg_loss)

    (_, (new_state, loss, reg_loss)), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params, **state.extra_vars)

    final_state = state.apply_gradients(grads=grads, **new_state)

    step_metrics = {
        'batch_loss': loss,
        'batch_reg_loss': reg_loss,
    }

    return final_state, step_metrics


def train_model(prior_sample, state, loader, ctx_loader=None, log_dir=None, epoch=None, prior_std=None):
    '''Run one pass over ``loader`` and return the updated train state.

    Each training batch is paired with the next batch from ``ctx_loader``; the
    context iterator is restarted whenever it is exhausted. Without a
    ``ctx_loader`` no context points are used.

    Per-step metrics are averaged over the pass and emitted once through
    ``logging.info`` (enable with ``logging.basicConfig(level=logging.INFO)``),
    instead of being printed on every step.

    Args:
      prior_sample: parameters passed to ``train_step_fn`` as ``prior_params``.
      state: current train state.
      loader: iterable of ``(X, Y)`` batches exposing ``.numpy()``.
      ctx_loader: optional iterable of ``(X_ctx, _)`` batches exposing ``.numpy()``.
      log_dir: currently unused.
      epoch: only used to tag the logged metrics.
      prior_std: prior standard deviation passed to ``train_step_fn``; when None,
        the global ``f_prior_std`` is read at call time.

    Returns:
      The train state after the final step.
    '''
    if prior_std is None:
        prior_std = f_prior_std  # global defined in the configuration code

    def _new_ctx_iter():
        # A single (None, None) pair stands in for "no context" when ctx_loader is absent.
        return iter(ctx_loader or [[None, None]])

    ctx_iter = _new_ctx_iter()
    metric_sums = {}
    n_steps = 0

    for X, Y in loader:
        try:
            X_ctx, _ = next(ctx_iter)
        except StopIteration:
            ctx_iter = _new_ctx_iter()
            X_ctx, _ = next(ctx_iter)
        if X_ctx is not None:
            X_ctx = X_ctx.numpy()

        state, step_metrics = train_step_fn(prior_sample, state, X.numpy(), Y.numpy(), X_ctx, prior_std)

        # Accumulate on device; converting to Python floats every step forces a sync.
        for name, value in step_metrics.items():
            metric_sums[name] = metric_sums.get(name, 0.) + value
        n_steps += 1

    if n_steps:
        epoch_metrics = { name: (total / n_steps).item() for name, total in metric_sums.items() }
        logging.info('epoch %s: %s', epoch, epoch_metrics)

    return state

weight_decay = 0.
alpha = 0.0  # NOTE(review): not referenced in the visible cells; confirm before removing
lr = 1e-3
epochs = 10000
f_prior_std = 10000.
eval_every = 100  # epochs between training-set evaluations
model_name = 'mlp200'

rng = jax.random.PRNGKey(42)

# One training example with a leading batch axis, used only to initialise parameter shapes.
sample_input = train_data[0][0].numpy()[None, ...]

rng, model_rng = jax.random.split(rng)
model, init_params, init_vars = create_model(model_rng, model_name, sample_input, num_classes=train_data.n_classes)

optimizer = optax.adamw(learning_rate=lr, weight_decay=weight_decay)

train_state = TrainState.create(
    apply_fn=model.apply,
    params=init_params,
    **init_vars,
    tx=optimizer)

for e in tqdm(range(epochs)):
    # Fresh random parameters each epoch, passed to train_model as the prior sample.
    rng, model_rng = jax.random.split(rng)
    _, reinit_params, _ = create_model(model_rng, model_name, sample_input, num_classes=train_data.n_classes)

    train_state = train_model(reinit_params, train_state, train_loader, ctx_loader=ctx_loader, epoch=e)

    if e % eval_every == 0:
        # tqdm.write prints above the progress bar instead of breaking it.
        tqdm.write(str(eval_classifier(train_state, train_loader)))

eval_classifier(train_state, train_loader)

In [None]:
@jax.jit
def _forward(variables, X):
    '''Evaluation-mode (train=False) forward pass returning logits.

    The model variables are an explicit argument rather than read from the
    global ``train_state``: jit would otherwise bake the parameters into the
    compiled function as trace-time constants.
    '''
    return train_state.apply_fn(variables, X, mutable=False, train=False)

eval_variables = { 'params': train_state.params, **train_state.extra_vars }

test_logits = []
for X, _ in tqdm(test_loader, leave=False):
    test_logits.append(_forward(eval_variables, X.numpy()))

test_p = jax.nn.softmax(jnp.concatenate(test_logits), axis=-1)

test_p.shape

In [None]:
fig, ax = plt.subplots(figsize=(5,5))

# test_data.X is a square grid of 2-D points; derive its side length instead of hardcoding it.
grid_side = int(round(np.sqrt(test_data.X.shape[0])))
plt_xy = test_data.X.reshape(grid_side, grid_side, -1)
# Predicted probability of class 0 at each grid point.
plt_z = test_p[:, 0].reshape(grid_side, grid_side)

# contourf picks its levels from the data range, so vmin/vmax alone do not make the
# bands (and colour bar) span the full probability range; fix the levels explicitly.
contourf_ = ax.contourf(plt_xy[..., 0], plt_xy[..., 1], plt_z,
                        levels=np.linspace(0., 1., 11), vmin=0., vmax=1., cmap='viridis')
ax.scatter(train_data.X[..., 0], train_data.X[..., 1], c=train_data.y, cmap=sns.color_palette('tab20', as_cmap=True))
# ax.scatter(ctx_data.X[:, 0], ctx_data.X[:, 1], color='black', alpha=0.1)

fig.colorbar(contourf_, label=r'$p(y=0 \mid x)$')
# plt.show() avoids the "non-GUI backend" warning that fig.show() raises inline.
plt.show()