In [1]:
import os
import sys

module_path = os.path.abspath(os.path.join("../../../"))
if module_path not in sys.path:
    sys.path.append(module_path)

In [2]:
import tensorflow_datasets as tfds

2022-11-15 17:56:48.057098: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-11-15 17:56:48.169242: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2022-11-15 17:56:48.752131: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/cuda/lib64:
2022-11-15 17:56:48.752205: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot 

In [3]:
import jax
import jax.numpy as jnp
from jax import jit, vmap

In [4]:
from flax import linen as nn
from flax.training import train_state
from flax.linen.activation import tanh
from flax.linen.activation import sigmoid
from flax.linen.initializers import zeros

In [5]:
import optax

In [6]:
from dataclasses import field
from typing import Any, Callable, Optional, Sequence
from functools import partial

In [7]:
from omegaconf import DictConfig, OmegaConf
import wandb
import hydra

In [8]:
from src.models.hippo.hippo import HiPPO
from src.models.hippo.transition import TransMatrix
from src.data.process import moving_window, rolling_window

In [9]:
class RNNCell(nn.Module):
    """
    Description:
        W_xh = x_{t} @ W_{xh} - multiply the previous hidden state with
        W_hh = H_{t-1} @ W_{hh} + b_{h} - this a linear layer

        H_{t} = f_{w}(H_{t-1}, x)
        H_{t} = \phi(H_{t-1} @ W_{hh}) + (x_{t} @ W_{xh})

    Args:
        nn (_type_): _description_

    Returns:
        _type_: _description_
    """

    input_size: int
    hidden_size: int
    bias: bool = True
    param_dtype: Any = jnp.float32
    activation_fn: Callable[..., Any] = tanh

    def setup(self):
        self.dense_i = nn.Dense(
            self.hidden_size, use_bias=self.bias, param_dtype=self.param_dtype
        )
        self.dense_h = nn.Dense(
            self.hidden_size, use_bias=self.bias, param_dtype=self.param_dtype
        )

    def __call__(self, carry, input):
        ht_1, _ = carry

        w_hh = self.dense_h(ht_1)
        w_xh = self.dense_i(input)

        h_t = self.activation_fn(
            (w_hh + w_xh)
        )  # H_{t} = tanh(H_{t-1} @ W_{hh}) + (x_{t} @ W_{xh})

        return (h_t, h_t), h_t

    @staticmethod
    def initialize_carry(
        rng,
        batch_size: tuple,
        hidden_size: int,
        init_fn=nn.initializers.zeros,
    ):
        key1, key2 = jax.random.split(rng)
        mem_shape = batch_size + (hidden_size,)
        return init_fn(key1, mem_shape), init_fn(key2, mem_shape)

In [10]:
class LSTMCell(nn.Module):
    """
    Description:
        i_{t} = sigmoid((W_{ii} @ x_{t} + b_{ii}) + (W_{hi} @ h_{t-1} + b_{hi}))
        f_{t} = sigmoid((W_{if} @ x_{t} + b_{if}) + (W_{hf} @ h_{t-1} + b_{hf}))
        g_{t} = tanh((W_{ig} @ x_{t} + b_{ig}) + (W_{hg} @ h_{t-1} + b_{hg}))
        o_{t} = sigmoid((W_{io} @ x_{t} + b_{io}) + (W_{ho} @ h_{t-1} + b_{ho}))
        c_{t} = f_{t} * c_{t-1} + i_{t} * g_{t}
        h_{t} = o_{t} * tanh(c_{t})

    Args:
        hidden_size (int): hidden state size
        carry (jnp.ndarray): hidden state from previous time step
        input (jnp.ndarray): # input vector

    Returns:
        A tuple with the new carry and the output.
    """

    input_size: int
    hidden_size: int
    bias: bool = True
    param_dtype: Any = jnp.float32
    gate_fn: Callable[..., Any] = sigmoid
    activation_fn: Callable[..., Any] = tanh

    def setup(self):
        self.dense_i_ti = nn.Dense(
            features=self.hidden_size, use_bias=self.bias, param_dtype=self.param_dtype
        )
        self.dense_i_th = nn.Dense(
            features=self.hidden_size, use_bias=self.bias, param_dtype=self.param_dtype
        )

        self.dense_o_ti = nn.Dense(
            features=self.hidden_size, use_bias=self.bias, param_dtype=self.param_dtype
        )
        self.dense_o_th = nn.Dense(
            features=self.hidden_size, use_bias=self.bias, param_dtype=self.param_dtype
        )

        self.dense_f_ti = nn.Dense(
            features=self.hidden_size, use_bias=self.bias, param_dtype=self.param_dtype
        )
        self.dense_f_th = nn.Dense(
            features=self.hidden_size, use_bias=self.bias, param_dtype=self.param_dtype
        )

        self.dense_g_ti = nn.Dense(
            features=self.hidden_size, use_bias=self.bias, param_dtype=self.param_dtype
        )
        self.dense_g_th = nn.Dense(
            features=self.hidden_size, use_bias=self.bias, param_dtype=self.param_dtype
        )

    def __call__(self, carry, input):
        h_t, c_t = carry

        i_ti = self.dense_i_ti(input)
        i_th = self.dense_i_th(h_t)
        i_t = self.gate_fn(i_ti + i_th)

        o_ti = self.dense_o_ti(input)
        o_th = self.dense_o_th(h_t)
        o_t = self.gate_fn(o_ti + o_th)

        f_ti = self.dense_f_ti(input)
        f_th = self.dense_f_th(h_t)
        f_t = self.gate_fn(f_ti + f_th)

        g_ti = self.dense_g_ti(input)
        g_th = self.dense_g_th(h_t)
        g_t = self.activation_fn(g_ti + g_th)

        c_t = (f_t * c_t) + (i_t * g_t)
        h_t = o_t * self.activation_fn(c_t)

        return (h_t, c_t), h_t

    @staticmethod
    def initialize_carry(
        rng,
        batch_size: tuple,
        hidden_size: int,
        init_fn=nn.initializers.zeros,
    ):
        key1, key2 = jax.random.split(rng)
        mem_shape = batch_size + (hidden_size,)
        return init_fn(key1, mem_shape), init_fn(key2, mem_shape)

In [11]:
class GRUCell(nn.Module):
    """
    Description:
        z_t = sigmoid((W_{iz} @ x_{t} + b_{iz}) + (W_{hz} @ h_{t-1} + b_{hz}))
        r_t = sigmoid((W_{ir} @ x_{t} + b_{ir}) + (W_{hr} @ h_{t-1} + b_{hr}))
        n_t = tanh(((W_{in} @ x_{t} + b_{in}) + r_t) * (W_{hn} @ h_{t-1} + b_{hn}))
        h_t = (z_t * h_{t-1}) + ((1 - z_t) * n_i)

    Args:
        hidden_size (int): hidden state size
        carry (jnp.ndarray): hidden state from previous time step
        input (jnp.ndarray): # input vector

    Returns:
        A tuple with the new carry and the output.
    """

    input_size: int
    hidden_size: int
    bias: bool = True
    param_dtype: Any = jnp.float32
    gate_fn: Callable[..., Any] = sigmoid
    activation_fn: Callable[..., Any] = tanh

    def setup(self):
        self.dense_z_ti = nn.Dense(
            features=self.hidden_size, use_bias=self.bias, param_dtype=self.param_dtype
        )
        self.dense_z_th = nn.Dense(
            features=self.hidden_size, use_bias=self.bias, param_dtype=self.param_dtype
        )

        self.dense_r_ti = nn.Dense(
            features=self.hidden_size, use_bias=self.bias, param_dtype=self.param_dtype
        )
        self.dense_r_th = nn.Dense(
            features=self.hidden_size, use_bias=self.bias, param_dtype=self.param_dtype
        )

        self.dense_n_ti = nn.Dense(
            features=self.hidden_size, use_bias=self.bias, param_dtype=self.param_dtype
        )
        self.dense_n_th = nn.Dense(
            features=self.hidden_size, use_bias=self.bias, param_dtype=self.param_dtype
        )

    def __call__(self, carry, input):
        h_t_1, h_t_1 = carry

        z_ti = self.dense_z_ti(input)
        z_th = self.dense_z_th(h_t_1)
        z_t = self.gate_fn(z_ti + z_th)

        r_ti = self.dense_r_ti(input)
        r_th = self.dense_r_th(h_t_1)
        r_t = self.gate_fn(r_ti + r_th)

        n_ti = self.dense_n_ti(input)
        n_th = self.dense_n_th(h_t_1)
        n_t = self.activation_fn(n_ti + (r_t * n_th))

        h_t = ((1 - z_t) * n_t) + (z_t * h_t_1)

        return (h_t, h_t), h_t

    @staticmethod
    def initialize_carry(
        rng,
        batch_size: tuple,
        hidden_size: int,
        init_fn=nn.initializers.zeros,
    ):
        key1, key2 = jax.random.split(rng)
        mem_shape = batch_size + (hidden_size,)
        return init_fn(key1, mem_shape), init_fn(key2, mem_shape)

In [12]:
class HiPPOCell(nn.Module):
    """
    Description:
        z_t = sigmoid((W_{iz} @ x_{t} + b_{iz}) + (W_{hz} @ h_{t-1} + b_{hz}))
        r_t = sigmoid((W_{ir} @ x_{t} + b_{ir}) + (W_{hr} @ h_{t-1} + b_{hr}))
        n_t = tanh(((W_{in} @ x_{t} + b_{in}) + r_t) * (W_{hn} @ h_{t-1} + b_{hn}))
        h_t = (z_t * h_{t-1}) + ((1 - z_t) * g_i)

    Args:
        hidden_size (int): hidden state size
        carry (jnp.ndarray): hidden state from previous time step
        input (jnp.ndarray): # input vector

    Returns:
        A tuple with the new carry and the output.
    """

    input_size: int
    hidden_size: int
    bias: bool = True
    param_dtype: Any = jnp.float32
    gate_fn: Callable[..., Any] = sigmoid
    activation_fn: Callable[..., Any] = tanh
    measure: str = "legs"
    lambda_n: float = 1.0
    fourier_type: str = "fru"
    alpha: float = 0.0
    beta: float = 1.0
    GBT_alpha: float = 0.5
    rnn_cell: Callable[..., Any] = GRUCell

    def setup(self):
        hippo_matrices = TransMatrix(
            N=self.hidden_size,
            measure=self.measure,
            lambda_n=self.lambda_n,
            fourier_type=self.fourier_type,
            alpha=self.alpha,
            beta=self.beta,
        )
        A = hippo_matrices.A_matrix
        B = hippo_matrices.B_matrix
        L = self.input_size

        self.hippo = HiPPO(
            N=self.hidden_size,
            max_length=L,
            step=1.0 / L,
            GBT_alpha=self.GBT_alpha,
            seq_L=L,
            A=A,
            B=B,
            measure=self.measure,
        )

        self.rnn = self.rnn_cell(
            input_size=self.input_size,
            hidden_size=self.hidden_size,
            bias=self.bias,
            param_dtype=self.param_dtype,
            gate_fn=self.gate_fn,
            activation_fn=self.activation_fn,
        )

        self.dense_f_th = nn.Dense(
            features=self.hidden_size, use_bias=self.bias, param_dtype=self.param_dtype
        )

    def __call__(self, carry, input):
        _, c_t_1 = carry

        carry, _ = self.rnn(carry, input)
        h_t, _ = carry

        f_t = self.dense_f_th(h_t)
        c_t = self.hippo(f=f_t, init_state=c_t_1, t_step=f_t.shape[0], kernel=False)

        return (h_t, c_t), h_t

    @staticmethod
    def initialize_carry(
        rng,
        batch_size: tuple,
        hidden_size: int,
        init_fn=nn.initializers.zeros,
    ):
        key1, key2 = jax.random.split(rng)
        mem_shape = batch_size + (hidden_size,)
        return init_fn(key1, mem_shape), init_fn(key2, mem_shape)

In [13]:
class DeepRNN(nn.Module):
    output_size: int
    layers: Sequence[Any]
    skip_connections: bool
    layer_name: Optional[str] = None

    def setup(self):
        if self.skip_connections:
            for layer in self.layers:
                if not isinstance(layer, nn.Module):
                    raise ValueError(
                        "skip_connections requires for all layers to be "
                        "`nn.Module. Layers is: {}".format(self.layers)
                    )

        self.dense_out = nn.Dense(features=self.output_size)

    def __call__(self, carry, input):
        out_carry = None
        output = None
        h_t, c_t = carry
        h_t_list = []
        c_t_list = []
        states = []

        for t in range(input.shape[1]):
            for idx, layer in enumerate(self.layers):
                if isinstance(layer, nn.Module):
                    if idx == 0:
                        out_carry, output = layer(carry, input[:, t, :])
                        h_t, c_t = out_carry

                    else:
                        h_t_1, c_t_1 = out_carry
                        out_carry, output = layer(carry, h_t_1)
                        h_t, c_t = out_carry
                        if self.skip_connections:
                            h_t = jnp.concatenate([h_t, h_t_1], axis=1)
                            c_t = jnp.concatenate([c_t, c_t_1], axis=1)
                            out_carry = tuple([h_t, c_t])
                else:
                    out_carry, output = layer(out_carry)

                h_t_list.append(h_t)
                c_t_list.append(c_t)
                states.append(output)

            carry = out_carry

        next_carry = None
        concat = lambda *args: jnp.concatenate(args, axis=-1)
        if self.skip_connections:
            h_t = jax.tree_map(concat, *h_t_list)
            c_t = jax.tree_map(concat, *c_t_list)
            next_carry = tuple([h_t, c_t])
        else:
            next_carry = out_carry

        return next_carry, self.dense_out(output)

    @staticmethod
    def initialize_carry(
        rng,
        batch_size: tuple,
        hidden_size: int,
        init_fn=nn.initializers.zeros,
    ):
        key1, key2 = jax.random.split(rng)
        # mem_shape = batch_size + (input_size, hidden_size)
        mem_shape = batch_size + (hidden_size,)
        return init_fn(key1, mem_shape), init_fn(key2, mem_shape)

In [14]:
@partial(jit, static_argnums=(1,))
def moving_window(a, size: int):
    starts = jnp.arange(len(a) - size + 1)
    return vmap(lambda start: jax.lax.dynamic_slice(a, (start,), (size,)))(starts)

In [15]:
def rolling_window(a: jnp.ndarray, window: int):
    idx = jnp.arange(len(a) - window + 1)[:, None] + jnp.arange(window)[None, :]
    return a[idx]

In [16]:
def test():
    seed = 1701
    key = jax.random.PRNGKey(seed)

    num_copies = 4
    rng, key, subkey, subsubkey = jax.random.split(key, num=num_copies)

    hidden_size = 256

    # batch size, sequence length, input size
    batch_size = 32
    data_size = 28 * 28
    input_size = 5

    # fake data
    x = jax.random.randint(rng, (batch_size, data_size), 0, 244)
    x = vmap(moving_window, in_axes=(0, None))(x, input_size)

    layer_list = []
    num_of_rnns = 3
    rnn_type = "rnn"
    if rnn_type == "rnn":
        layer_list = [
            RNNCell(input_size=input_size, hidden_size=hidden_size)
            for _ in range(num_of_rnns)
        ]

    elif rnn_type == "lstm":
        layer_list = [
            LSTMCell(input_size=input_size, hidden_size=hidden_size)
            for _ in range(num_of_rnns)
        ]

    elif rnn_type == "gru":
        layer_list = [
            GRUCell(input_size=input_size, hidden_size=hidden_size)
            for _ in range(num_of_rnns)
        ]

    elif rnn_type == "hippo":
        layer_list = [
            HiPPOCell(input_size=input_size, hidden_size=hidden_size)
            for _ in range(num_of_rnns)
        ]

    else:
        raise ValueError("rnn_type must be one of: rnn, lstm, gru, hippo")

    # model
    model = DeepRNN(
        output_size=10,
        layers=layer_list,
        skip_connections=False,
    )

    # get model params
    params = model.init(
        key,
        model.initialize_carry(
            rng=subkey,
            batch_size=(batch_size,),
            hidden_size=hidden_size,
            init_fn=nn.initializers.zeros,
        ),
        x,
    )

    carry, out = model.apply(
        params,
        model.initialize_carry(
            rng=subsubkey,
            batch_size=(batch_size,),
            hidden_size=hidden_size,
            init_fn=nn.initializers.zeros,
        ),
        x,
    )

    return carry, out

In [17]:
def tester():
    for i in range(1):
        test_carry, testx = test()
        xdims = testx.shape
        carrydims = test_carry[0].shape
        if i % 10 == 0 or i == 1 or i == 100:
            print(f"output array shape:\n{xdims}\n")
            print(f"h_t array shape:\n{carrydims}\n")
        assert xdims == (32, 10)
    print("Size test: passed.")

In [18]:
# tester()

In [19]:
def get_datasets():
    """Load MNIST train and test datasets into memory."""
    ds_builder = tfds.builder("mnist")
    ds_builder.download_and_prepare()
    train_ds = tfds.as_numpy(ds_builder.as_dataset(split="train", batch_size=-1))
    test_ds = tfds.as_numpy(ds_builder.as_dataset(split="test", batch_size=-1))
    train_ds["image"] = jnp.float32(train_ds["image"]) / 255.0
    test_ds["image"] = jnp.float32(test_ds["image"]) / 255.0
    return train_ds, test_ds

In [20]:
# train_ds, test_ds = get_datasets()
# print(train_ds["image"].shape)
# print(test_ds["image"].shape)

In [21]:
def pick_rnn_cell(cfg):
    # set rnn cell from rnn_type
    rnn_list = []
    if cfg["models"]["cells"]["cell_type"] == "rnn":
        rnn_list = [
            RNNCell(
                input_size=cfg["models"]["cells"]["rnn"]["input_size"],
                hidden_size=cfg["models"]["cells"]["rnn"]["hidden_size"],
                bias=cfg["models"]["cells"]["rnn"]["bias"],
                param_dtype=cfg["models"]["cells"]["rnn"]["param_dtype"],
                activation_fn=cfg["models"]["cells"]["rnn"]["activation_fn"],
            )
            for _ in range(cfg["models"]["deep_rnn"]["stack_number"])
        ]

    elif cfg["models"]["cells"]["cell_type"] == "lstm":
        rnn_list = [
            LSTMCell(
                input_size=cfg["models"]["cells"]["gated_rnn"]["input_size"],
                hidden_size=cfg["models"]["cells"]["gated_rnn"]["hidden_size"],
                bias=cfg["models"]["cells"]["gated_rnn"]["bias"],
                param_dtype=cfg["models"]["cells"]["gated_rnn"]["param_dtype"],
                gate_fn=cfg["models"]["cells"]["gated_rnn"]["gate_fn"],
                activation_fn=cfg["models"]["cells"]["gated_rnn"]["activation_fn"],
            )
            for _ in range(cfg["models"]["deep_rnn"]["stack_number"])
        ]

    elif cfg["models"]["cells"]["cell_type"] == "gru":
        rnn_list = [
            GRUCell(
                input_size=cfg["models"]["cells"]["gated_rnn"]["input_size"],
                hidden_size=cfg["models"]["cells"]["gated_rnn"]["hidden_size"],
                bias=cfg["models"]["cells"]["gated_rnn"]["bias"],
                param_dtype=cfg["models"]["cells"]["gated_rnn"]["param_dtype"],
                gate_fn=cfg["models"]["cells"]["gated_rnn"]["gate_fn"],
                activation_fn=cfg["models"]["cells"]["gated_rnn"]["activation_fn"],
            )
            for _ in range(cfg["models"]["deep_rnn"]["stack_number"])
        ]

    elif cfg["models"]["cells"]["cell_type"] == "hippo":
        rnn_list = [
            HiPPOCell(
                input_size=cfg["models"]["cells"]["hippo"]["input_size"],
                hidden_size=cfg["models"]["cells"]["hippo"]["hidden_size"],
                bias=cfg["models"]["cells"]["hippo"]["bias"],
                param_dtype=cfg["models"]["cells"]["hippo"]["param_dtype"],
                gate_fn=cfg["models"]["cells"]["hippo"]["gate_fn"],
                activation_fn=cfg["models"]["cells"]["hippo"]["activation_fn"],
                measure=cfg["models"]["cells"]["hippo"]["measure"],
                lambda_n=cfg["models"]["cells"]["hippo"]["lambda_n"],
                fourier_type=cfg["models"]["cells"]["hippo"]["fourier_type"],
                alpha=cfg["models"]["cells"]["hippo"]["alpha"],
                beta=cfg["models"]["cells"]["hippo"]["beta"],
                rnn_cell=cfg["models"]["cells"]["hippo"]["rnn_cell"],
            )
            for _ in range(cfg["models"]["deep_rnn"]["stack_number"])
        ]

    else:
        raise ValueError("Unknown rnn type")

    return rnn_list

In [22]:
def pick_model(key, cfg):
    # set model from net_type
    model = None
    params = None

    if cfg["models"]["model_type"] == "rnn":
        rnn_list = pick_rnn_cell(cfg)
        model = DeepRNN(
            output_size=cfg["models"]["deep_rnn"]["output_size"],
            layers=rnn_list,
            skip_connections=cfg["models"]["deep_rnn"]["skip_connections"],
        )
        init_carry = model.initialize_carry(
            rng=key,
            batch_size=(cfg["training"]["batch_size"],),
            hidden_size=cfg["models"]["deep_rnn"]["hidden_size"],
            init_fn=nn.initializers.zeros,
        )
        params = model.init(input, init_carry)

    elif cfg["models"]["model_type"] == "hippo":
        L = cfg["training"]["input_length"]
        hippo_matrices = TransMatrix(
            N=cfg["models"]["hippo"]["n"],
            measure=cfg["models"]["hippo"]["measure"],
            lambda_n=cfg["models"]["hippo"]["lambda_n"],
            fourier_type=cfg["models"]["hippo"]["fourier_type"],
            alpha=cfg["models"]["hippo"]["alpha"],
            beta=cfg["models"]["hippo"]["beta"],
        )
        model = HiPPO(
            N=cfg["models"]["hippo"]["n"],
            max_length=L,
            step=1.0 / L,
            GBT_alpha=cfg["models"]["hippo"]["GBT_alpha"],
            seq_L=L,
            A=hippo_matrices.A_matrix,
            B=hippo_matrices.B_matrix,
            measure=cfg["models"]["hippo"]["measure"],
        )
        params = model.init(f, init_state=None, t_step=0, kernel=False)

    elif cfg["models"]["model_type"] == "s4":
        raise NotImplementedError
        # model = S4()
        # params = model.init()

    else:
        raise ValueError("Unknown model type")

    return model, params

In [23]:
def preprocess_data(cfg, data):
    # preprocess data
    x = None
    if cfg["models"]["model_type"] == "rnn":
        x = vmap(jnp.ravel, in_axes=0)(x)
        x = vmap(moving_window, in_axes=(0, None))(x, cfg["training"]["input_length"])

    elif cfg["models"]["model_type"] == "hippo":
        raise NotImplementedError

    elif cfg["models"]["model_type"] == "s4":
        raise NotImplementedError

    else:
        raise ValueError("Unknown model type to preprocess for")

    return x

In [24]:
def preprocess_labels(cfg, labels):
    # preprocess data
    y = None
    if cfg["models"]["model_type"] == "rnn":
        y = jax.nn.one_hot(labels, 10)

    elif cfg["models"]["model_type"] == "hippo":
        raise NotImplementedError

    elif cfg["models"]["model_type"] == "s4":
        raise NotImplementedError

    else:
        raise ValueError("Unknown model type to preprocess for")

    return y

In [25]:
def pick_optim(cfg, model, params):

    tx = None
    if cfg["training"]["optimizer"] == "adam":
        tx = optax.adam(
            learning_rate=cfg["training"]["learning_rate"],
            weight_decay=cfg["training"]["weight_decay"],
        )
    elif cfg["training"]["optimizer"] == "sgd":
        tx = optax.sgd(learning_rate=cfg["training"]["learning_rate"])
    else:
        raise ValueError("Unknown optimizer")

    tx_state = tx.init(params)

    return train_state.TrainState.create(
        apply_fn=model.apply, params=params, tx=tx, opt_state=tx_state
    )

In [26]:
@jax.jit
def apply_model(state, data, labels):
    """Computes gradients, loss and accuracy for a single batch."""

    def loss_fn(params):
        logits = state.apply_fn({"params": params}, data)
        one_hot = jax.nn.one_hot(labels, 10)
        loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
        return loss, logits

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, logits), grads = grad_fn(state.params)
    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
    return grads, loss, accuracy

In [27]:
@jax.jit
def update_model(state, grads):
    return state.apply_gradients(grads=grads)

In [28]:
@hydra.main(config_path="config", config_name="train")
def recurrent_train(
    cfg: DictConfig,
) -> None:  # num_epochs, opt_state, net_type="RNN", train_key=None):
    """
    Implements a learning loop over epochs.

    Args:
        cfg: Hydra config

    Returns:
        None

    """
    with wandb.init(
        project="BeeGass-Sequential", entity="beegass", config=cfg
    ):  # initialize wandb project for logging

        # get keys for parameters
        seed = cfg["training"]["seed"]
        key = jax.random.PRNGKey(seed)

        num_copies = cfg["training"]["key_num"]
        keys = jax.random.split(key, num=num_copies)

        # get train and test datasets
        train_set, test_set = get_datasets()

        # pick a model
        model, params = pick_model(keys[1], cfg)

        # pick an optimizer
        state = pick_optim(cfg, model, params)

        # pick a scheduler
        # TODO: implement choice of scheduler

        # pick a loss function
        # TODO: implement choice of loss function

        # get dataset info for training loop (number of steps per epoch)
        train_set_size = len(train_set["image"])
        steps_per_epoch = train_set_size // cfg["training"]["batch_size"]

        perms = jax.random.permutation(keys[0], train_set_size)
        perms = perms[
            : steps_per_epoch * cfg["training"]["batch_size"]
        ]  # skip incomplete batch
        perms = perms.reshape((steps_per_epoch, cfg["training"]["batch_size"]))

        epoch_loss = []
        epoch_accuracy = []

        # Loop over the training epochs
        for epoch in range(cfg["training"]["num_epochs"]):
            # start_time = time.time()

            for perm in perms:
                train_data = train_set["image"][perm, ...]
                train_labels = train_set["label"][perm, ...]
                train_data = preprocess_data(cfg, train_data)
                # train_labels = preprocess_labels(cfg, train_labels)
                grads, loss, accuracy = apply_model(
                    state=state, data=train_data, labels=train_labels
                )
                state = update_model(state, grads)
                epoch_loss.append(loss)
                epoch_accuracy.append(accuracy)

            # epoch_time = time.time() - start_time

            # train loss for current epoch
            train_loss = jnp.mean(epoch_loss)
            train_accuracy = jnp.mean(epoch_accuracy)

            # test loss for current epoch
            _, test_loss, test_accuracy = apply_model(
                state=state, data=test_set["image"], labels=test_set["label"]
            )

            # TODO: add logging of metrics

        return state

The version_base parameter is not specified.
Please specify a compatability version level, or None.
Will assume defaults for version 1.1
  @hydra.main(config_path="config", config_name="train")


In [29]:
# fake data
height_dim = 28
width_dim = 28
batch_size = 32
input_size = 5
x = jax.random.randint(
    jax.random.PRNGKey(1701), (batch_size, height_dim, width_dim), 0, 255
)
print(f"image shape: {x.shape}")
x = vmap(jnp.ravel, in_axes=0)(x)
print(f"raveled image shape: {x.shape}")
x = vmap(moving_window, in_axes=(0, None))(x, input_size)
print(f"windowed raveled image shape: {x.shape}")

image shape: (32, 28, 28)
raveled image shape: (32, 784)
windowed raveled image shape: (32, 780, 5)


In [30]:
state = recurrent_train()

usage: ipykernel_launcher.py [--help] [--hydra-help] [--version]
                             [--cfg {job,hydra,all}] [--resolve]
                             [--package PACKAGE] [--run] [--multirun]
                             [--shell-completion] [--config-path CONFIG_PATH]
                             [--config-name CONFIG_NAME]
                             [--config-dir CONFIG_DIR]
                             [--experimental-rerun EXPERIMENTAL_RERUN]
                             [--info [{all,config,defaults,defaults-tree,plugins,searchpath}]]
                             [overrides [overrides ...]]
ipykernel_launcher.py: error: unrecognized arguments: -f


SystemExit: 2

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)


check [this](https://github.com/deepmind/optax/blob/master/examples/quick_start.ipynb) out

Refer to [this](https://github.com/manifest/flax-extra/blob/48efe1f1515893289b44646977bf5049a340b6c8/docs/notebooks/combinators.ipynb), [this](https://github.com/romanak/pyprobml/blob/65c82b9b43d2100cbc7c59e766161ee801c0f85f/notebooks/book1/15/rnn_jax.ipynb), [this](https://github.com/probml/pyprobml/blob/71d98dcdd3798525353eb1bfb9851b47e9d64bde/notebooks/book1/15/rnn_jax.ipynb) and [this](https://github.com/probml/probml-notebooks/blob/36cb173afce3f4a07a7b475cf8a7937025a60465/notebooks-d2l/rnn_jax.ipynb)