In [8]:
import flax
import flax.linen as nn

import jax
import jax.numpy as nn
from typing import Any, Sequence, Optional, Tuple, Iterator, Dict, Callable, Union

from jax import random, grad, value_and_grad, jit, vmap

from models import AutoEncoder, MLPBlock


from jax import numpy as jnp
import jax

import torch
import torch.utils.data as data
import numpy as np

In [9]:

def numpy_collate(batch):
    if isinstance(batch[0], np.ndarray):
        return np.stack(batch)
    elif isinstance(batch[0], (tuple,list)):
        transposed = zip(*batch)
        return [numpy_collate(samples) for samples in transposed]
    else:
        return np.array(batch)

def create_data_loaders(*datasets : Sequence[data.Dataset],
                        train : Union[bool, Sequence[bool]] = True,
                        batch_size : int = 128,
                        num_workers : int = 4,
                        seed : int = 42):
    """
    Creates data loaders used in JAX for a set of datasets.

    Args:
      datasets: Datasets for which data loaders are created.
      train: Sequence indicating which datasets are used for
        training and which not. If single bool, the same value
        is used for all datasets.
      batch_size: Batch size to use in the data loaders.
      num_workers: Number of workers for each dataset.
      seed: Seed to initialize the workers and shuffling with.
    """
    loaders = []
    if not isinstance(train, (list, tuple)):
        train = [train for _ in datasets]
    for dataset, is_train in zip(datasets, train):
        loader = data.DataLoader(dataset,
                                 batch_size=batch_size,
                                 shuffle=is_train,
                                 drop_last=is_train,
                                 collate_fn=numpy_collate,
                                 num_workers=num_workers,
                                 persistent_workers=is_train,
                                 generator=torch.Generator().manual_seed(seed))
        loaders.append(loader)
    return loaders



def target_function(x):
    return np.sin(x * 3.0)

class RegressionDataset(data.Dataset):

    def __init__(self, num_points, num_feat, seed):
        super().__init__()
        rng = np.random.default_rng(seed)
        self.x = rng.uniform(low=-1.0, high=1.0, size=(num_points, num_feat))
        self.y = target_function(self.x)

    def __len__(self):
        return self.x.shape[0]

    def __getitem__(self, idx):
        return self.x[idx:idx+1], self.y[idx:idx+1]

train_set = RegressionDataset(num_points=1000, num_feat=40, seed=42)
val_set = RegressionDataset(num_points=200, num_feat=40, seed=43)
test_set = RegressionDataset(num_points=500, num_feat=40, seed=44)
train_loader, val_loader, test_loader = create_data_loaders(train_set, val_set, test_set,
                                                            train=[True, False, False],
                                                            batch_size=64)



In [10]:
input_size = 40
latent_size = 20
encoder = MLPBlock([1024,1024,512,512,512,512,128, 64], latent_size, hidden_activation='sigmoid')
bottleneck = MLPBlock([64, 64, 32, 32, 16, 16], latent_size, hidden_activation='relu')
decoder = MLPBlock([64, 128,512, 512, 512, 512, 1024, 1024], input_size, hidden_activation='tanh')

autoencoder = AutoEncoder(encoder, decoder, bottleneck)

In [11]:
from FlaxTrainer.trainer import TrainerModule

In [1]:
class TrainAutoEncoder(TrainerModule):
    def __init__(self,
                 **kwargs):
        super().__init__(**kwargs)

    def create_functions(self):
        def mse_loss(params, apply_fn, batch):
            x, _ = batch
            print(x)
            pred = apply_fn({'params': params}, x)
            loss = ((pred - x) ** 2).mean()
            return loss
    
        def train_step(state, batch):
            loss_fn = lambda params: mse_loss(params, state.apply_fn, batch)
            loss, grads = jax.value_and_grad(loss_fn)(state.params)
            state = state.apply_gradients(grads=grads)
            metrics = {'loss': loss}
            return state, metrics
        
        def eval_step(state, batch):
            loss = mse_loss(state.params, state.apply_fn, batch)
            return {'loss': loss}

        return train_step, eval_step
        

NameError: name 'TrainerModule' is not defined

In [2]:
CHECKPOINT_PATH = "./saved_models/"
# TODO: Solve conflict of check_val_every_n_epoch and num_epochs
#mock = mockedcallback.MockedCallback(stop_train=False)
trainer = TrainAutoEncoder(optimizer_hparams={'lr': 4e-3},
                            logger_params={'base_log_dir': CHECKPOINT_PATH},                           
                            check_val_every_n_epoch=5)
 #                           callbacks=[mock])



NameError: name 'TrainAutoEncoder' is not defined

In [14]:
state = trainer.init_model(
    autoencoder,exmp_input=next(iter(train_loader))[0:1]
)


  leaves = jax.tree_leaves(pytree)







In [15]:
metrics, state = trainer.train_model(
    autoencoder,
    state,
    train_loader,
    val_loader,
    test_loader=test_loader,
    num_epochs=100
)

#print(state)
print(f'Training loss: {metrics["train/loss"]}')
print(f'Validation loss: {metrics["val/loss"]}')
print(f'Test loss: {metrics["test/loss"]}')

Epochs: 100%|██████████| 100/100 [00:18<00:00,  5.41it/s]


Training loss: 0.33323097229003906
Validation loss: 0.33861681818962097
Test loss: 0.3346109688282013


In [35]:
a = autoencoder.encoder.apply({'params':state.params['encoder']}, next(iter(train_loader))[0])
autoencoder.decoder.apply({'params':state.params['decoder']}, a)

DeviceArray([[[-2.9340668 ,  0.9017485 , -1.5255569 , ...,  4.91492   ,
                2.6670747 , -0.73887753]],

             [[-2.9340687 ,  0.90174896, -1.525558  , ...,  4.9149218 ,
                2.6670752 , -0.73887813]],

             [[-2.9340699 ,  0.90174943, -1.525559  , ...,  4.9149237 ,
                2.6670752 , -0.7388786 ]],

             ...,

             [[-2.9340696 ,  0.9017494 , -1.5255587 , ...,  4.9149227 ,
                2.6670754 , -0.73887855]],

             [[-2.9340684 ,  0.9017489 , -1.5255578 , ...,  4.9149218 ,
                2.667075  , -0.73887783]],

             [[-2.9340703 ,  0.90174955, -1.5255591 , ...,  4.9149237 ,
                2.6670756 , -0.73887867]]], dtype=float32)