In [1]:
from FlaxTrainer.callbacks import mockedcallback
from FlaxTrainer.trainer import TrainerModule

from flax import linen as nn
from typing import Any, Sequence, Optional, Tuple, Iterator, Dict, Callable, Union
import flax

from jax import numpy as jnp
import jax

import torch
import torch.utils.data as data
import numpy as np


# Base directory handed to the trainer's logger via `logger_params['base_log_dir']`.
CHECKPOINT_PATH = "./saved_models/"


def numpy_collate(batch):
    """Collate a list of samples into NumPy arrays.

    Array samples are stacked along a new leading axis, tuple/list samples
    are collated field by field (recursively), and anything else is
    converted with ``np.array``.
    """
    first = batch[0]
    if isinstance(first, np.ndarray):
        return np.stack(batch)
    if not isinstance(first, (tuple, list)):
        return np.array(batch)
    # Regroup [(a0, b0), (a1, b1), ...] into (a0, a1, ...), (b0, b1, ...).
    return [numpy_collate(field) for field in zip(*batch)]

def create_data_loaders(*datasets : data.Dataset,
                        train : Union[bool, Sequence[bool]] = True,
                        batch_size : int = 128,
                        num_workers : int = 4,
                        seed : int = 42):
    """
    Creates data loaders used in JAX for a set of datasets.

    Args:
      datasets: Datasets for which data loaders are created.
      train: Sequence indicating which datasets are used for
        training and which not. If single bool, the same value
        is used for all datasets.
      batch_size: Batch size to use in the data loaders.
      num_workers: Number of workers for each dataset. May be 0 to load
        data in the main process.
      seed: Seed to initialize the workers and shuffling with.

    Returns:
      A list with one ``DataLoader`` per dataset, in the order given.
      Training loaders shuffle and drop the last incomplete batch.
    """
    if not isinstance(train, (list, tuple)):
        train = [train for _ in datasets]
    # DataLoader raises ValueError for persistent_workers=True when
    # num_workers == 0, so only request persistent workers if there are any.
    has_workers = num_workers > 0
    loaders = []
    for dataset, is_train in zip(datasets, train):
        loader = data.DataLoader(dataset,
                                 batch_size=batch_size,
                                 shuffle=is_train,
                                 drop_last=is_train,
                                 collate_fn=numpy_collate,
                                 num_workers=num_workers,
                                 persistent_workers=bool(is_train) and has_workers,
                                 # Fresh generator per loader, all seeded identically.
                                 generator=torch.Generator().manual_seed(seed))
        loaders.append(loader)
    return loaders



def target_function(x):
    """Ground-truth curve used to label the regression data: ``sin(3 * x)``."""
    tripled = x * 3.0
    return np.sin(tripled)

class RegressionDataset(data.Dataset):
    """1-D regression dataset with ``y = target_function(x)``.

    ``x`` holds ``num_points`` values drawn uniformly between -2 and 2 from a
    generator seeded with ``seed``. Each item is a pair of length-1 arrays
    ``(x[idx], y[idx])``.
    """

    def __init__(self, num_points, seed):
        super().__init__()
        rng = np.random.default_rng(seed)
        self.x = rng.uniform(low=-2.0, high=2.0, size=num_points)
        self.y = target_function(self.x)

    def __len__(self):
        return self.x.shape[0]

    def __getitem__(self, idx):
        # Index through a one-element list instead of the slice idx:idx+1.
        # A slice silently yields an empty array for idx == -1 or
        # idx >= len(self); list indexing keeps the (1,) shape, supports
        # negative indices and raises IndexError past the end, which is also
        # what lets a plain `for sample in dataset` loop terminate.
        return self.x[[idx]], self.y[[idx]]

# Independent train / validation / test splits, each drawn with its own seed.
train_set, val_set, test_set = (
    RegressionDataset(num_points=size, seed=split_seed)
    for size, split_seed in ((1000, 42), (200, 43), (500, 44))
)

# Only the first loader is flagged as a training split.
train_loader, val_loader, test_loader = create_data_loaders(
    train_set, val_set, test_set,
    train=[True, False, False],
    batch_size=64,
)

# Evenly spaced grid of 1000 points between -2 and 2.
x = np.linspace(start=-2, stop=2, num=1000)

  PyTreeDef = type(jax.tree_structure(None))


In [2]:
class MLPRegressor(nn.Module):
    """Fully connected regressor: SiLU-activated hidden layers plus a linear head."""
    hidden_dims : Sequence[int]  # width of each hidden Dense layer, in order
    output_dim : int             # width of the final Dense layer

    @nn.compact
    def __call__(self, x, **kwargs):
        # **kwargs is accepted but never read in this body.
        for width in self.hidden_dims:
            x = nn.silu(nn.Dense(width)(x))
        return nn.Dense(self.output_dim)(x)

# Two hidden layers of 128 units and a single regression output.
mlp = MLPRegressor(hidden_dims=[128, 128], output_dim=1)


In [3]:
class MLPRegressTrainer(TrainerModule):
    """Trainer whose train and eval steps use a mean-squared-error loss."""

    def __init__(self,
                 **kwargs):
        # Nothing trainer-specific to configure: forward everything unchanged.
        super().__init__(**kwargs)

    def create_functions(self):
        """Build and return the ``(train_step, eval_step)`` pair."""

        def mse_loss(params, apply_fn, batch):
            inputs, targets = batch
            residual = apply_fn({'params': params}, inputs) - targets
            return (residual ** 2).mean()

        def train_step(state, batch):
            def loss_fn(params):
                return mse_loss(params, state.apply_fn, batch)

            loss, grads = jax.value_and_grad(loss_fn)(state.params)
            updated_state = state.apply_gradients(grads=grads)
            return updated_state, {'loss': loss}

        def eval_step(state, batch):
            return {'loss': mse_loss(state.params, state.apply_fn, batch)}

        return train_step, eval_step




In [4]:
# TODO: Solve conflict of check_val_every_n_epoch and num_epochs
#mock = mockedcallback.MockedCallback(stop_train=False)
trainer = MLPRegressTrainer(optimizer_hparams={'lr': 4e-3},
                            logger_params={'base_log_dir': CHECKPOINT_PATH},
                            check_val_every_n_epoch=5)
 #                           callbacks=[mock])

# `[0:1]` keeps only the first element of the collated batch (the inputs),
# still wrapped in a list.
state = trainer.init_model(mlp, exmp_input=next(iter(train_loader))[0:1])
#print(state)
# `jax.tree_map` is deprecated; `jax.tree_util.tree_map` is the stable spelling.
jax.tree_util.tree_map(lambda x: x.shape, state.params)


  leaves = jax.tree_leaves(pytree)







FrozenDict({
    Dense_0: {
        bias: (128,),
        kernel: (1, 128),
    },
    Dense_1: {
        bias: (128,),
        kernel: (128, 128),
    },
    Dense_2: {
        bias: (1,),
        kernel: (128, 1),
    },
})

In [5]:
# `train_model` returns a `(metrics, state)` pair; `state` is rebound to the
# returned value, replacing the one produced by `init_model`.
metrics, state = trainer.train_model(
    mlp,
    state,
    train_loader,
    val_loader,
    test_loader=test_loader,
    num_epochs=50
)

#print(state)
# Report the entries stored under the train/val/test loss keys of `metrics`.
print(f'Training loss: {metrics["train/loss"]}')
print(f'Validation loss: {metrics["val/loss"]}')
print(f'Test loss: {metrics["test/loss"]}')

Epochs: 100%|██████████| 50/50 [00:04<00:00, 10.77it/s]


Training loss: 0.0008829445578157902
Validation loss: 0.0008724512881599367
Test loss: 0.0007670423365198076


In [6]:
class simpleRNN(nn.Module):
    """Single-step recurrent cell.

    One call consumes the current input together with the previous hidden
    state and produces a softmax output plus the next hidden state. Both
    heads read the concatenation ``[inputs, hidden]``.
    """
    # Width of the softmax output; defaults to the previously hard-coded 20.
    output_size: int = 20

    @nn.compact
    def __call__(self, inputs, hidden):
        """Run one recurrent step.

        Args:
          inputs: Current input; must be concatenable with ``hidden`` on the
            last axis.
          hidden: Previous hidden state. Its last dimension fixes the width
            of the returned hidden state.

        Returns:
          ``(output, new_hidden)``: ``output`` is a softmax over
          ``output_size`` entries, ``new_hidden`` is ReLU-activated and has
          the same last dimension as ``hidden``.
        """
        x = jax.numpy.concatenate([inputs, hidden], axis=-1)
        i2o = nn.Dense(self.output_size)(x)
        i2o = nn.softmax(i2o)
        # Feed the combined vector, not `hidden` alone: otherwise the next
        # hidden state never depends on the input and the cell cannot carry
        # any information about the sequence.
        i2h = nn.Dense(hidden.shape[-1])(x)
        i2h = nn.relu(i2h)
        return i2o, i2h

    def init_hidden(self):
        # Placeholder: the module stores no hidden size, so there is nothing
        # to build here yet. Returns None, as before.
        return None




In [7]:
# Instantiate the cell without constructor arguments; the bare name echoes its repr.
srnn = simpleRNN()
srnn

simpleRNN()

In [8]:
class SRNNTrainer(TrainerModule):
    """Trainer with MSE-based train/eval steps.

    NOTE(review): the loss calls ``apply_fn`` with a single input and
    subtracts the target from its result, whereas ``simpleRNN.__call__``
    takes ``(inputs, hidden)`` and returns a pair -- confirm how this trainer
    is meant to drive that model.
    """

    def __init__(self,
                 **kwargs):
        # No extra configuration; pass all keyword arguments to the base class.
        super().__init__(**kwargs)

    def create_functions(self):
        """Return the ``(train_step, eval_step)`` functions."""

        def mse_loss(params, apply_fn, batch):
            features, labels = batch
            prediction = apply_fn({'params': params}, features)
            squared_error = (prediction - labels) ** 2
            return squared_error.mean()

        def eval_step(state, batch):
            return {'loss': mse_loss(state.params, state.apply_fn, batch)}

        def train_step(state, batch):
            def batch_loss(params):
                return mse_loss(params, state.apply_fn, batch)

            loss_value, gradients = jax.value_and_grad(batch_loss)(state.params)
            next_state = state.apply_gradients(grads=gradients)
            return next_state, {'loss': loss_value}

        return train_step, eval_step



NameError: name 'key' is not defined

In [None]:
# NOTE(review): `variables`, `a` and `b` are not defined in the cells that
# precede this one -- confirm where they come from before a fresh run.
res = srnn.apply(variables, a, b)
# `jax.tree_map` is deprecated; `jax.tree_util.tree_map` is the stable spelling.
jax.tree_util.tree_map(lambda x: x.shape, res)

((3, 2, 20), (3, 2, 3))

In [None]:
# Shape of the array passed as the `hidden` argument in `srnn.apply(variables, a, b)`.
b.shape

(3, 2, 3)

In [None]:
# Shape of the array passed as the `inputs` argument in `srnn.apply(variables, a, b)`.
a.shape

(3, 2, 4)

In [9]:
def f(carry, x):
    """Scan body that squares each element.

    The incoming ``carry`` is ignored; the square of ``x`` is returned both
    as the new carry and as the per-step output.
    """
    squared = x * x
    return squared, squared

# 10000 integers drawn from {1, 2} (maxval is exclusive) with a fixed PRNG key.
a = jax.random.randint(
    key=jax.random.PRNGKey(128),
    shape=(10000,),
    minval=1,
    maxval=3,
)
a

DeviceArray([1, 2, 2, ..., 1, 2, 1], dtype=int32)

In [14]:
# Drive `f` over `a` with an int32 zero as the initial carry. `f` returns x*x in
# both slots, so the result is (square of the last element, all squares).
jax.lax.scan(f, jnp.array(0, dtype=jnp.int32), a)

(DeviceArray(1, dtype=int32),
 DeviceArray([1, 4, 4, ..., 1, 4, 1], dtype=int32))

In [15]:
# Element-wise square of `a` computed directly, for comparison with the scan output.
jnp.power(a, 2)

DeviceArray([1, 4, 4, ..., 1, 4, 1], dtype=int32)

In [26]:
def cumsum(res, el):
    """
    Scan body for a running sum.

    - `res`: running total carried over from the previous step.
    - `el`: the array element handled in this step.

    Returns the updated total twice: once as the carry, once as the value
    accumulated into the scan output.
    """
    running_total = res + el
    return running_total, running_total


# Start the running sum at zero and fold `cumsum` over every element of `a`.
result_init = 0
final, result = jax.lax.scan(cumsum, init=result_init, xs=a)
result

DeviceArray([    1,     3,     5, ..., 14991, 14993, 14994], dtype=int32)

In [23]:
# Built-in cumulative sum of `a`, for comparison with the scan-based `result`.
jnp.cumsum(a)

DeviceArray([    1,     3,     5, ..., 14991, 14993, 14994], dtype=int32)

In [28]:
# Final carry returned by the `cumsum` scan.
final

DeviceArray(14994, dtype=int32)

In [30]:
def cumsum(res, el):
    """Scan body for a running sum.

    Args:
      res: Running total from the previous step.
      el: Current element.

    Returns:
      ``(total, total)`` where ``total = res + el``: first the new carry,
      then the per-step output.
    """
    # A named function instead of a lambda bound to a name; the sum is
    # computed once and reused for both slots.
    total = res + el
    return total, total