In [None]:
%load_ext autoreload
%reload_ext autoreload
%env XLA_PYTHON_CLIENT_ALLOCATOR=platform

env: XLA_PYTHON_CLIENT_ALLOCATOR=platform


# Loops
> Inner and outer training loops for meta-learning activation functions

In [None]:
#| default_exp loops

In [None]:
#| hide
import nbdev

In [None]:
#| export
import jax
import jax.numpy as jnp
import numpy as np
import equinox as eqx
import optax

import logging
from tqdm.notebook import trange
import pickle
from functools import partial
import json
from dataclasses import asdict


In [None]:
#| export
from jaxDiversity.dataloading import NumpyLoader, DummyDataset
from jaxDiversity.utilclasses import InnerConfig, OuterConfig, InnerResults, OuterResults
from jaxDiversity.mlp import mlp_afunc, MultiActMLP, init_linear_weight, xavier_normal_init, save
from jaxDiversity.baseline import compute_loss as compute_loss_baseline
from jaxDiversity.hnn import compute_loss as compute_loss_hnn


In [None]:
# Route INFO-level records from the root logger (used for the epoch/step
# progress lines emitted by the training loops below) to the console.
logging.basicConfig(level="INFO")

In [None]:
#| export
@eqx.filter_jit
def make_step(model, x, y, afuncs, optim, opt_state, compute_loss):
    """
    Take one jitted optimizer step on ``model`` for the batch ``(x, y)``.

    Parameters
    ----------
    model : eqx.Module
        Model being trained.
    x, y : arrays
        Input batch and its targets.
    afuncs : list of callables
        Activation functions forwarded unchanged to ``compute_loss``.
    optim : optax.GradientTransformation
        Optimizer; ``opt_state`` must come from ``optim.init``.
    opt_state
        Current optimizer state.
    compute_loss : callable
        ``compute_loss(model, x, y, afuncs) -> (loss, grads)``.

    Returns
    -------
    tuple
        ``(loss, grads, model, opt_state)`` where ``model`` and ``opt_state``
        are the updated versions.
    """
    loss, grads = compute_loss(model, x, y, afuncs)
    # Pass the current array parameters so optimizers that need them
    # (e.g. optax.adamw / weight-decay transforms) work; optimizers such as
    # rmsprop simply ignore the ``params`` argument.
    params = eqx.filter(model, eqx.is_array)
    updates, opt_state = optim.update(grads, opt_state, params)
    model = eqx.apply_updates(model, updates)
    return loss, grads, model, opt_state

In [None]:
#| export
def inner_opt(model, train_data, test_data, afuncs, opt, loss_fn, config, training=False, verbose=False):
    """
    Inner optimization loop: train ``model`` with fixed activation functions ``afuncs``.

    For each of ``config.epochs`` epochs this optionally runs one pass of
    ``make_step`` over ``train_data`` (only when ``training`` is True), then
    evaluates ``loss_fn`` on every batch of ``test_data``.

    Parameters
    ----------
    model : eqx.Module
        Model to optimize; the optimizer state is created with ``opt.init(model)``.
    train_data, test_data : iterables of ``(x, y)`` batches
    afuncs : list of callables
        Activation functions passed to ``loss_fn``.
    opt : optax.GradientTransformation
        Inner optimizer.
    loss_fn : callable
        ``loss_fn(model, x, y, afuncs) -> (loss, grads)``.
    config
        Object with an ``epochs`` attribute (e.g. ``InnerConfig``).
    training : bool
        If False, the model is only evaluated on ``test_data``.
    verbose : bool
        Log the most recent train loss, test loss and gradient norm after each
        epoch; metrics with no recorded value yet are shown as ``n/a``.

    Returns
    -------
    tuple
        ``(model, opt_state, InnerResults(train_loss, test_loss, grad_norm))``;
        the result arrays hold one entry per batch, concatenated over epochs.
    """
    train_loss = []
    test_loss = []
    grad_norm = []

    def _latest(values):
        # Most recent value formatted for the log, or "n/a" when nothing was
        # recorded (training=False, or an empty dataloader).
        return f"{values[-1]:.4e}" if values else "n/a"

    opt_state = opt.init(model)
    for epoch in range(config.epochs):
        if training:
            for x, y in train_data:
                loss, grads, model, opt_state = make_step(model, x, y, afuncs, opt, opt_state, loss_fn)
                train_loss.append(loss)
                # NOTE: this "grad norm" is the SUM of per-leaf L2 norms,
                # not the global L2 norm over all parameters.
                leaf_norms = jax.tree_util.tree_map(jnp.linalg.norm, grads)
                grad_norm.append(jax.tree_util.tree_reduce(jnp.add, leaf_norms))

        for x, y in test_data:
            x = jax.lax.stop_gradient(x)
            y = jax.lax.stop_gradient(y)
            # loss_fn also returns gradients; they are discarded during evaluation.
            loss, _ = loss_fn(model, x, y, afuncs)
            test_loss.append(loss)
        if verbose:
            logging.info(
                f"Epoch {epoch :03d} | Train Loss: {_latest(train_loss)} | "
                f"Test Loss: {_latest(test_loss)} | Grad Norm: {_latest(grad_norm)}"
            )

    return model, opt_state, InnerResults(jnp.array(train_loss), jnp.array(test_loss), jnp.array(grad_norm))

In [None]:
#| test
# test inner_opt (baseline loss)
dev_inner_config = InnerConfig(test_train_split=0.8,
                               input_dim=2,
                               output_dim=2,
                               hidden_layer_sizes=[18],
                               batch_size=64,
                               epochs=2,
                               lr=1e-3,
                               mu=0.9,
                               n_fns=2,
                               l2_reg=1e-1,
                               seed=42)
key = jax.random.PRNGKey(dev_inner_config.seed)
model_key, _ = jax.random.split(key)  # only the first subkey is needed here
afuncs = [lambda x: x**2, lambda x: x]
train_dataset = DummyDataset(1000, dev_inner_config.input_dim, dev_inner_config.output_dim)
test_dataset = DummyDataset(1000, dev_inner_config.input_dim, dev_inner_config.output_dim)
train_dataloader = NumpyLoader(train_dataset, batch_size=dev_inner_config.batch_size, shuffle=True)
test_dataloader = NumpyLoader(test_dataset, batch_size=dev_inner_config.batch_size, shuffle=True)

# NOTE(review): optax.rmsprop's `decay` is the moving-average rate for squared
# gradients, not an L2 penalty -- confirm that l2_reg is really meant here.
opt = optax.rmsprop(learning_rate=dev_inner_config.lr, momentum=dev_inner_config.mu, decay=dev_inner_config.l2_reg)
model = MultiActMLP(dev_inner_config.input_dim, dev_inner_config.output_dim, dev_inner_config.hidden_layer_sizes, model_key, bias=False)
logging.info("Baseline NN inner loop test")
baselineNN, opt_state, inner_results = inner_opt(model=model,
                                                 train_data=train_dataloader,
                                                 test_data=test_dataloader,
                                                 afuncs=afuncs,
                                                 opt=opt,
                                                 loss_fn=compute_loss_baseline,
                                                 config=dev_inner_config, training=True, verbose=True)

# Sanity checks: evaluation ran and every test loss is finite.
assert inner_results.test_loss.size > 0
assert bool(jnp.all(jnp.isfinite(inner_results.test_loss)))


INFO:jax._src.xla_bridge:Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: CUDA Interpreter Host
INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu': module 'jaxlib.xla_extension' has no attribute 'get_tpu_client'
INFO:root:Baseline NN inner loop test
INFO:root:Epoch 000 | Train Loss: 1.6608e-01 | Test Loss: 1.7718e-01 | Grad Norm: 4.0744e-01
INFO:root:Epoch 001 | Train Loss: 1.4460e-01 | Test Loss: 1.2461e-01 | Grad Norm: 1.6079e-01


In [None]:
#| test
# test inner_opt (Hamiltonian NN loss)
dev_inner_config = InnerConfig(test_train_split=0.8,
                               input_dim=2,
                               output_dim=1,
                               hidden_layer_sizes=[18],
                               batch_size=64,
                               epochs=2,
                               lr=1e-3,
                               mu=0.9,
                               n_fns=2,
                               l2_reg=1e-1,
                               seed=42)
key = jax.random.PRNGKey(dev_inner_config.seed)
model_key, _ = jax.random.split(key)  # only the first subkey is needed here
afuncs = [lambda x: x**2, lambda x: x]
# NOTE: targets are 2-D while the model's output_dim is 1 -- presumably
# compute_loss_hnn compares derivatives of the scalar output against them;
# confirm in jaxDiversity.hnn.
train_dataset = DummyDataset(1000, dev_inner_config.input_dim, 2)
test_dataset = DummyDataset(1000, dev_inner_config.input_dim, 2)
train_dataloader = NumpyLoader(train_dataset, batch_size=dev_inner_config.batch_size, shuffle=True)
test_dataloader = NumpyLoader(test_dataset, batch_size=dev_inner_config.batch_size, shuffle=True)
# NOTE(review): optax.rmsprop's `decay` is the moving-average rate for squared
# gradients, not an L2 penalty -- confirm that l2_reg is really meant here.
opt = optax.rmsprop(learning_rate=dev_inner_config.lr, momentum=dev_inner_config.mu, decay=dev_inner_config.l2_reg)
model = MultiActMLP(dev_inner_config.input_dim, dev_inner_config.output_dim, dev_inner_config.hidden_layer_sizes, model_key, bias=False)

logging.info("Hamiltonian NN inner loop test")
HNN, opt_state, inner_results = inner_opt(model=model,
                                          train_data=train_dataloader,
                                          test_data=test_dataloader,
                                          afuncs=afuncs,
                                          opt=opt,
                                          loss_fn=compute_loss_hnn,
                                          config=dev_inner_config, training=True, verbose=True)

# Sanity checks: evaluation ran and every test loss is finite.
assert inner_results.test_loss.size > 0
assert bool(jnp.all(jnp.isfinite(inner_results.test_loss)))

INFO:root:Hamiltonian NN inner loop test
INFO:root:Epoch 000 | Train Loss: 9.8625e-02 | Test Loss: 8.2503e-02 | Grad Norm: 2.8775e-01
INFO:root:Epoch 001 | Train Loss: 7.7401e-02 | Test Loss: 7.7187e-02 | Grad Norm: 1.0838e-01


In [None]:
#| export
@eqx.filter_value_and_grad()
def outer_loss(outer_models, inner_model, x, y, loss_fn, base_act):
    """
    Inner-model loss viewed as a function of the outer (activation-function) models.

    Because of ``eqx.filter_value_and_grad``, calling this returns
    ``(loss, grads)``, with gradients taken w.r.t. the array leaves of the
    first argument, ``outer_models``.

    Each outer model becomes one activation function through
    ``partial(mlp_afunc, model=..., base_act=base_act)``. ``loss_fn`` is called as
    ``loss_fn(inner_model, x, y, afuncs)`` and must return a 2-tuple whose first
    element is the loss; the second element is discarded.
    """
    activation_fns = [
        partial(mlp_afunc, model=act_model, base_act=base_act)
        for act_model in outer_models
    ]
    inner_loss, _discarded = loss_fn(inner_model, x, y, activation_fns)
    return inner_loss


In [None]:
#| export
@eqx.filter_jit
def outer_step(outer_models, inner_model, x, y, meta_opt, meta_opt_state, loss_fn, base_act):
    """
    One jitted meta-gradient step on the outer (activation-function) models.

    The gradient of ``outer_loss`` w.r.t. ``outer_models`` is computed on the
    batch ``(x, y)`` with ``inner_model`` held fixed. It is then applied with
    ``meta_opt``.

    Parameters
    ----------
    outer_models : list of eqx.Module
        Activation-function models being meta-learned.
    inner_model : eqx.Module
        Trained inner model whose loss drives the meta-gradient.
    x, y : arrays
        Batch used for the meta-gradient.
    meta_opt : optax.GradientTransformation
        Meta-optimizer; ``meta_opt_state`` must come from its ``init``.
    meta_opt_state
        Current meta-optimizer state.
    loss_fn : callable
        ``loss_fn(model, x, y, afuncs) -> (loss, aux)``.
    base_act : callable
        Base activation forwarded to ``mlp_afunc``.

    Returns
    -------
    tuple
        ``(loss, grads, outer_models, meta_opt_state)`` with updated models and state.
    """
    loss, grads = outer_loss(outer_models, inner_model, x, y, loss_fn, base_act)
    # Supply the current array parameters (filtered the same way outer_opt
    # filters them for meta_opt.init) so meta-optimizers that need them, e.g.
    # optax.adamw, work; rmsprop ignores this argument.
    params = eqx.filter(outer_models, eqx.is_array)
    updates, new_meta_opt_state = meta_opt.update(grads, meta_opt_state, params)
    new_outer_models = eqx.apply_updates(outer_models, updates)
    return loss, grads, new_outer_models, new_meta_opt_state

In [None]:
#| export
def _resolve_base_act(name):
    """Return the JAX function for an ``InnerConfig.base_act`` name ("sin", "relu" or "tanh")."""
    base_acts = {
        "sin": jnp.sin,
        # jax.numpy has no `relu`; the ReLU lives in jax.nn.
        "relu": jax.nn.relu,
        "tanh": jnp.tanh,
    }
    if name not in base_acts:
        raise ValueError(f"Unknown base_act {name!r}; expected one of {sorted(base_acts)}")
    return base_acts[name]


def outer_opt(train_dataloader, test_dataloader, loss_fn, inner_config, outer_config, opt, meta_opt, save_path=None, save_every=100):
    """
    Outer (meta-learning) optimization loop over the activation-function MLPs.

    Each of ``outer_config.steps`` outer steps:
      1. builds activation functions from the current outer MLPs via ``mlp_afunc``;
      2. builds a fresh ``MultiActMLP`` inner model and trains it with ``inner_opt``;
      3. takes one meta-gradient step (``outer_step``) on a single training batch;
      4. records the meta loss, meta-gradient norm, the mean of the last (up to) 50
         inner test losses, and samples of every activation function on [-10, 10].

    Parameters
    ----------
    train_dataloader, test_dataloader : iterables of ``(x, y)`` batches for the inner loop.
    loss_fn : callable
        ``loss_fn(model, x, y, afuncs) -> (loss, grads)`` used by both loops.
    inner_config
        Reads ``base_act``, ``n_fns``, ``seed``, ``input_dim``, ``output_dim``,
        ``hidden_layer_sizes`` (plus ``epochs`` via ``inner_opt``); must be a dataclass
        when ``save_path`` is given.
    outer_config
        Reads ``seed``, ``input_dim``, ``output_dim``, ``hidden_layer_sizes[0]``,
        ``steps``, ``print_every``; must be a dataclass when ``save_path`` is given.
    opt : optax.GradientTransformation
        Optimizer for the inner model.
    meta_opt : optax.GradientTransformation
        Optimizer for the outer (activation) models.
    save_path : str or None
        Existing directory for the config JSONs, pickled ``results`` dicts and model
        checkpoints; ``None`` disables all saving.
    save_every : int
        Checkpoint every ``save_every`` outer steps, and always at the final step.

    Returns
    -------
    tuple
        ``(outer_models, OuterResults)``.

    Raises
    ------
    ValueError
        If ``inner_config.base_act`` is not "sin", "relu" or "tanh".
    """
    # Resolve first so a bad config fails before any file is written.
    base_act = _resolve_base_act(inner_config.base_act)

    outer_model_key = jax.random.PRNGKey(outer_config.seed)
    inner_model_key = jax.random.PRNGKey(inner_config.seed)

    # NOTE(review): every activation MLP is created from the same key, so all
    # n_fns models start with identical weights. Kept to reproduce existing runs;
    # use jax.random.split(outer_model_key, inner_config.n_fns) for distinct inits.
    outer_models = [
        eqx.nn.MLP(in_size=outer_config.input_dim,
                   out_size=outer_config.output_dim,
                   width_size=outer_config.hidden_layer_sizes[0],
                   depth=1,
                   key=outer_model_key,
                   use_bias=True)
        for _ in range(inner_config.n_fns)
    ]

    meta_opt_states = meta_opt.init(eqx.filter(outer_models, eqx.is_array))

    results = {
        "train_loss": [],
        "test_loss": [],
        "inner_afuncs": [],
        "grad_norms": []
    }

    if save_path is not None:
        # save config files
        with open(f"{save_path}/inner_config.json", "w") as f:
            json.dump(asdict(inner_config), f)
        with open(f"{save_path}/outer_config.json", "w") as f:
            json.dump(asdict(outer_config), f)

    # Grid on which the activation functions are sampled each step (loop-invariant).
    x_sample = jnp.linspace(-10, 10, 100)

    for step in trange(outer_config.steps):

        # Activation functions from the current outer models; the partials keep
        # referencing these (pre-meta-update) models for the rest of the step.
        inner_afuncs = [partial(mlp_afunc, model=outer_model, base_act=base_act) for outer_model in outer_models]

        # Fresh inner model each step, built from the same key every time.
        inner_model = MultiActMLP(inner_config.input_dim,
                                  inner_config.output_dim,
                                  inner_config.hidden_layer_sizes,
                                  inner_model_key, bias=True)
        inner_model = init_linear_weight(inner_model, xavier_normal_init, inner_model_key)

        inner_model, _, inner_results = inner_opt(inner_model,
                                                  train_dataloader, test_dataloader,
                                                  inner_afuncs, opt, loss_fn,
                                                  inner_config, training=True,
                                                  verbose=False)

        # The meta-gradient uses a single batch from the training loader.
        x, y = next(iter(train_dataloader))
        loss, grads, outer_models, meta_opt_states = outer_step(outer_models, inner_model, x, y, meta_opt, meta_opt_states, loss_fn, base_act)

        # Sum of per-leaf L2 norms of the meta-gradient (not the global norm).
        grad_norm_tree = jax.tree_util.tree_map(jnp.linalg.norm, grads)
        results["grad_norms"].append(jax.tree_util.tree_reduce(jnp.add, grad_norm_tree))
        results["train_loss"].append(loss)
        mean_inner_test_loss = np.mean(inner_results.test_loss[-50:])
        results["test_loss"].append(mean_inner_test_loss)

        if step % outer_config.print_every == 0:
            logging.info(f"Step {step :03d} | Train Loss: {results['train_loss'][-1] :.4e} | Test Loss: {mean_inner_test_loss :.4e} | Grad Norm: {results['grad_norms'][-1] :.4e}")

        # Row 0 is the x grid, followed by one row per activation function.
        results["inner_afuncs"].append([x_sample] + [afunc(x_sample) for afunc in inner_afuncs])

        if save_path is not None and (step % save_every == 0 or step == outer_config.steps - 1):
            # pickle and save results
            with open(f"{save_path}/step_{step}.pkl", "wb") as f:
                pickle.dump(results, f)

            # save models
            for i, act_model in enumerate(outer_models):
                save(f"{save_path}/step_{step}_activation_model_{i}.eqx", asdict(outer_config), act_model)

    results_obj = OuterResults(inner_test_loss=np.array(results["test_loss"]),
                               train_loss=np.array(results["train_loss"]),
                               inner_afuncs=np.array(results["inner_afuncs"]),
                               grad_norm=np.array(results["grad_norms"]))

    return outer_models, results_obj

        

In [None]:
#| test
# test outer_opt Baseline
inner_config = InnerConfig(test_train_split=0.8,
                           input_dim=2,
                           output_dim=2,
                           hidden_layer_sizes=[32],
                           batch_size=64,
                           epochs=5,
                           lr=1e-3,
                           mu=0.9,
                           n_fns=2,
                           l2_reg=1e-1,
                           seed=42)
outer_config = OuterConfig(input_dim=1,
                           output_dim=1,
                           hidden_layer_sizes=[18],
                           batch_size=1,
                           steps=2,
                           print_every=1,
                           lr=1e-3,
                           mu=0.9,
                           seed=24)
train_dataset = DummyDataset(1000, inner_config.input_dim, inner_config.output_dim)
test_dataset = DummyDataset(1000, inner_config.input_dim, inner_config.output_dim)
train_dataloader = NumpyLoader(train_dataset, batch_size=inner_config.batch_size, shuffle=True)
test_dataloader = NumpyLoader(test_dataset, batch_size=inner_config.batch_size, shuffle=True)

# NOTE(review): optax.rmsprop's `decay` is the moving-average rate for squared
# gradients, not an L2 penalty -- confirm that l2_reg is really meant here.
opt = optax.rmsprop(learning_rate=inner_config.lr, momentum=inner_config.mu, decay=inner_config.l2_reg)
meta_opt = optax.rmsprop(learning_rate=outer_config.lr, momentum=outer_config.mu)

logging.info("Baseline NN outer loop test")
baseline_acts, baseline_stats = outer_opt(train_dataloader, test_dataloader, compute_loss_baseline, inner_config, outer_config, opt, meta_opt, save_path=None)

# Sanity checks: one activation model per function, one record per outer step,
# and the sampled curves are (x grid + one row per activation function).
assert len(baseline_acts) == inner_config.n_fns
assert baseline_stats.train_loss.shape == (outer_config.steps,)
assert baseline_stats.inner_test_loss.shape == (outer_config.steps,)
assert baseline_stats.grad_norm.shape == (outer_config.steps,)
assert baseline_stats.inner_afuncs.shape[:2] == (outer_config.steps, inner_config.n_fns + 1)
assert np.all(np.isfinite(baseline_stats.train_loss))

INFO:root:Baseline NN outer loop test


  0%|          | 0/2 [00:00<?, ?it/s]

INFO:root:Step 000 | Train Loss: 8.4353e-02 | Test Loss: 8.5589e-02 | Grad Norm: 6.3488e-01
INFO:root:Step 001 | Train Loss: 9.1323e-02 | Test Loss: 8.8853e-02 | Grad Norm: 2.6783e-01


In [None]:
#| test
# test outer_opt HNN
inner_config = InnerConfig(test_train_split=0.8,
                           input_dim=2,
                           output_dim=1,
                           hidden_layer_sizes=[32],
                           batch_size=64,
                           epochs=5,
                           lr=1e-3,
                           mu=0.9,
                           n_fns=2,
                           l2_reg=1e-1,
                           seed=42)
outer_config = OuterConfig(input_dim=1,
                           output_dim=1,
                           hidden_layer_sizes=[18],
                           batch_size=1,
                           steps=2,
                           print_every=1,
                           lr=1e-3,
                           mu=0.9,
                           seed=24)
# Targets are 2-D although the inner model outputs a scalar (output_dim=1),
# as in the HNN inner-loop test above.
train_dataset = DummyDataset(1000, inner_config.input_dim, 2)
test_dataset = DummyDataset(1000, inner_config.input_dim, 2)
train_dataloader = NumpyLoader(train_dataset, batch_size=inner_config.batch_size, shuffle=True)
test_dataloader = NumpyLoader(test_dataset, batch_size=inner_config.batch_size, shuffle=True)

# NOTE(review): optax.rmsprop's `decay` is the moving-average rate for squared
# gradients, not an L2 penalty -- confirm that l2_reg is really meant here.
opt = optax.rmsprop(learning_rate=inner_config.lr, momentum=inner_config.mu, decay=inner_config.l2_reg)
meta_opt = optax.rmsprop(learning_rate=outer_config.lr, momentum=outer_config.mu)

logging.info("Hamiltonian NN outer loop test")
HNN_acts, HNN_stats = outer_opt(train_dataloader, test_dataloader, compute_loss_hnn, inner_config, outer_config, opt, meta_opt, save_path=None)

# Sanity checks: one activation model per function, one record per outer step,
# and the sampled curves are (x grid + one row per activation function).
assert len(HNN_acts) == inner_config.n_fns
assert HNN_stats.train_loss.shape == (outer_config.steps,)
assert HNN_stats.inner_test_loss.shape == (outer_config.steps,)
assert HNN_stats.grad_norm.shape == (outer_config.steps,)
assert HNN_stats.inner_afuncs.shape[:2] == (outer_config.steps, inner_config.n_fns + 1)
assert np.all(np.isfinite(HNN_stats.train_loss))

INFO:root:Hamiltonian NN outer loop test


  0%|          | 0/2 [00:00<?, ?it/s]

INFO:root:Step 000 | Train Loss: 1.1194e-01 | Test Loss: 9.4442e-02 | Grad Norm: 8.8605e-01
INFO:root:Step 001 | Train Loss: 8.8917e-02 | Test Loss: 9.3278e-02 | Grad Norm: 7.9053e-01


In [None]:
#| hide
# Write the `#| export` cells of the project's notebooks to the library's Python
# modules (this notebook targets the `loops` module via `#| default_exp loops`).
nbdev.nbdev_export()