In [1]:
import jax
import jax.numpy as jnp
import flax
from flax import linen as nn
import orbax.checkpoint
from flax.metrics import tensorboard
from flax.training import train_state
import tensorflow_datasets as tfds
import optax
import orbax
from flax.training import orbax_utils
from typing import Any, Tuple, Mapping,Callable,List,Dict
import os
import tqdm
import time
import tensorflow as tf # For dataset
from clu import metrics
from flax import struct                # Flax dataclasses
import augmax
import grain.python as pygrain
import numpy as np
from jax.experimental import mesh_utils

from jax.sharding import PositionalSharding, NamedSharding
from jax.sharding import PartitionSpec
from functools import partial
import functools

2024-07-29 20:50:15.522351: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-07-29 20:50:15.537523: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-07-29 20:50:15.542099: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [2]:
jax.distributed.initialize() 

In [3]:
tf.config.experimental.set_visible_devices([], 'GPU')
tf.config.experimental.set_visible_devices([], 'TPU')

In [4]:
jax.devices()

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=2, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,1,0), core_on_chip=0)]

In [5]:
P = jax.sharding.PartitionSpec
device_mesh = mesh_utils.create_device_mesh((4,))
pos_sharding = PositionalSharding(device_mesh)
mesh = jax.sharding.Mesh(mesh_utils.create_device_mesh((4,1)), ('data', 'model'))
named_sharding = NamedSharding(mesh, P('data', 'model'))

In [6]:
def get_dataset(data_name="mnist", batch_size=64, image_scale=160, splits=["train", "test"]):
    def augmenter():
        # @tf.function()
        def augment(sample):
            image = tf.cast(sample['image'], tf.float32) / 255.
            
            image = (
                tf.cast(sample["image"], tf.float32) - 127.5
            ) / 127.5
            image = tf.image.resize(
                image, [image_scale, image_scale], method="area", antialias=True
            )
            image = tf.image.random_flip_left_right(image)
            image = tf.image.random_contrast(image, 0.997, 1.05)
            image = tf.image.random_brightness(image, 0.2)

            image = tf.clip_by_value(image, -1.0, 1.0)
            label = sample['label']
            return {"image": image, "label": label}
        return augment

    # Load CelebA Dataset
    (train_ds, test_ds) = tfds.load(data_name, split=splits, shuffle_files=True)
    train_len = len(train_ds)
    test_len = len(test_ds)
    
    train_ds: tf.data.Dataset = train_ds
    test_ds: tf.data.Dataset = test_ds
    
    train_ds = (
        train_ds
        .map(
            augmenter(),
            num_parallel_calls=tf.data.AUTOTUNE,
        )
        .cache()  # Cache after augmenting to avoid recomputation
        .repeat()  # Repeats the dataset indefinitely
        .shuffle(4096)  # Ensure this is adequate for your dataset size
        .batch(batch_size, drop_remainder=True)
        .prefetch(tf.data.experimental.AUTOTUNE)
    )
    
    test_ds = (
        test_ds
        .map(
            augmenter(),
            num_parallel_calls=tf.data.AUTOTUNE,
        )
        .cache()  # Cache after augmenting to avoid recomputation
        # .repeat()  # Repeats the dataset indefinitely
        .batch(batch_size, drop_remainder=True)
        .prefetch(tf.data.experimental.AUTOTUNE)
    )
    
    def get_trainset():
        return train_ds.as_numpy_iterator()
    
    def get_testset():
        return test_ds.as_numpy_iterator()
    return {
        "train": get_trainset,
        "test": get_testset,
        "train_len": train_len,
        "test_len": test_len,
        "batch_size": batch_size
    }

In [7]:
def get_dataset(data_name="mnist", batch_size=64, image_scale=256, method=jax.image.ResizeMethod.LANCZOS3):
    (train_ds, test_ds) = tfds.data_source("mnist", split=["train", "test"], try_gcs=False)
    train_len = len(train_ds)
    test_len = len(test_ds)

    cpu_device = jax.devices("cpu")[0]
    try:
        gpu_device = jax.devices("gpu")[0]
    except:
        gpu_device = None
        try:
            tpu_devices = jax.devices("tpu")
        except:
            tpu_devices = None
    print(f"Gpu Device: {gpu_device}, Cpu Device: {cpu_device}, TPU Devices: {tpu_devices}")
        
    def preprocess(image):
        image = (image - 127.5) / 127.5
        # image = jax.image.resize(image, (image_scale, image_scale, 3), method=method)
        image = jnp.clip(image, -1.0, 1.0)
        # image = jax.device_put(image, device=jax.devices("gpu")[0]) 
        return  image
    
    preprocess = jax.jit(preprocess, backend="cpu")

    class augmenter(pygrain.RandomMapTransform):
        def random_map(self, element: Dict[str, Any], rng: np.random.Generator) ->  Dict[str, jnp.array]:
            image = element['image']
            image = preprocess(image)
            # image = augments(rng.integers(0, 2**32, [2], dtype=np.uint32), image) 
            label = element['label']
            return {'image':image, 'label':label}

    transformations = [augmenter(), pygrain.Batch(batch_size, drop_remainder=True)]

    train_sampler = pygrain.IndexSampler(
        num_records=train_len,
        shuffle=True,
        seed=0,
        num_epochs=None,
        shard_options=pygrain.ShardByJaxProcess(),
    )
    
    train_ds = pygrain.DataLoader(
        data_source=train_ds,
        sampler=train_sampler,
        operations=transformations,
        # worker_count=0,
        # read_options=pygrain.ReadOptions(8, 500),
        # worker_buffer_size=5
        )
    
    test_sampler = pygrain.IndexSampler(
        num_records=test_len,
        shuffle=True,
        seed=0,
        num_epochs=None,
        shard_options=pygrain.ShardByJaxProcess(),
    )
    
    test_ds = pygrain.DataLoader(
        data_source=test_ds,
        sampler=test_sampler,
        operations=transformations,
        # worker_count=0,
        # read_options=pygrain.ReadOptions(8, 500),
        # worker_buffer_size=5
        )
    return {
        "train": train_ds,
        "test": test_ds,
        "train_len": train_len,
        "test_len": test_len,
        "batch_size": batch_size
    }


In [10]:
data = get_dataset("imagenette", batch_size=128, splits=['train', 'validation'])

In [11]:
d = next(iter(data['train']()))

2024-07-29 20:20:17.968071: W tensorflow/core/kernels/data/cache_dataset_ops.cc:913] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.


In [11]:
@jax.pmap
def testfun(x):
    x = x.reshape((-1))
    print(x.shape)
    return jnp.sin(x)

In [10]:
inp = d['image'].reshape((-1, 160*160))
inp = jax.device_put(inp, named_sharding)
jax.debug.visualize_array_sharding(inp)

In [13]:
k = testfun(inp.reshape((4, -1, *inp.shape[1:])))
print(k.devices(), jax.debug.visualize_array_sharding(k), k)

(2457600,)


{TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=2, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=3, process_index=0, coords=(1,1,0), core_on_chip=0)} None [[-0.24927016 -0.22547898 -0.23789191 ... -0.74764276 -0.7450645
  -0.7432891 ]
 [ 0.1453397   0.40575063  0.44915524 ... -0.3922717  -0.4679144
  -0.4169427 ]
 [ 0.30879593  0.33196172  0.3142953  ... -0.5827335  -0.5731142
  -0.6314607 ]
 [ 0.39536574  0.48585674  0.5317951  ... -0.5890244  -0.567123
  -0.68381965]]


In [7]:
import flax.jax_utils


@struct.dataclass
class Metrics(metrics.Collection):
  accuracy: metrics.Accuracy
  loss: metrics.Average.from_output('loss')

# Define the TrainState 
class SimpleTrainState(train_state.TrainState):
    rngs: jax.random.PRNGKey
    metrics: Metrics

    def get_random_key(self):
        rngs, subkey = jax.random.split(self.rngs)
        return self.replace(rngs=rngs), subkey

class SimpleTrainer:
    state : SimpleTrainState
    best_state : SimpleTrainState
    best_loss : float
    model : nn.Module
    ema_decay:float = 0.999
    
    def __init__(self, 
                 model:nn.Module, 
                input_shapes:Dict[str, Tuple[int]],
                 optimizer: optax.GradientTransformation,
                 rngs:jax.random.PRNGKey,
                 train_state:SimpleTrainState=None,
                 name:str="Simple",
                 load_from_checkpoint:bool=False,
                 checkpoint_suffix:str="",
                 loss_fn=optax.l2_loss,
                 param_transforms:Callable=None,
                 ):
        self.model = model
        self.name = name
        self.loss_fn = loss_fn
        self.input_shapes = input_shapes

        checkpointer = orbax.checkpoint.PyTreeCheckpointer()
        options = orbax.checkpoint.CheckpointManagerOptions(max_to_keep=4, create=True)
        self.checkpointer = orbax.checkpoint.CheckpointManager(self.checkpoint_path() + checkpoint_suffix, checkpointer, options)

        if load_from_checkpoint:
            latest_step, old_state, old_best_state = self.load()
        else:
            latest_step, old_state, old_best_state = 0, None, None
            
        self.latest_step = latest_step

        if train_state == None:
            self.init_state(optimizer, rngs, existing_state=old_state, existing_best_state=old_best_state, model=model, param_transforms=param_transforms)
        else:
            self.state = train_state
            self.best_state = train_state
            self.best_loss = 1e9
    
    def get_input_ones(self):
        return {k:jnp.ones((1, *v)) for k,v in self.input_shapes.items()}

    def init_state(self,
                   optimizer: optax.GradientTransformation, 
                   rngs:jax.random.PRNGKey,
                   existing_state:dict=None,
                   existing_best_state:dict=None,
                   model:nn.Module=None,
                   param_transforms:Callable=None
                   ):
        @partial(jax.pmap, axis_name="device")
        def init_fn(rngs):
            rngs, subkey = jax.random.split(rngs)

            if existing_state == None:
                input_vars = self.get_input_ones()
                params = model.init(subkey, **input_vars)

            # if param_transforms is not None:
            #     params = param_transforms(params)
                
            state = SimpleTrainState.create(
                apply_fn=model.apply,
                params=params,
                tx=optimizer,
                rngs=rngs,
                metrics=Metrics.empty()
            )
            return state
        self.state = init_fn(jax.device_put_replicated(rngs, jax.devices()))
        self.best_loss = 1e9
        if existing_best_state is not None:
            self.best_state = self.state.replace(params=existing_best_state['params'], ema_params=existing_best_state['ema_params'])
        else:
            self.best_state = self.state

    def checkpoint_path(self):
        experiment_name = self.name
        path = os.path.join(os.path.abspath('./checkpoints'), experiment_name)
        if not os.path.exists(path):
            os.makedirs(path)
        return path
    
    def tensorboard_path(self):
        experiment_name = self.name
        path = os.path.join(os.path.abspath('./tensorboard'), experiment_name)
        if not os.path.exists(path):
            os.makedirs(path)
        return path

    def load(self):
        step = self.checkpointer.latest_step()
        print("Loading model from checkpoint", step)
        ckpt = self.checkpointer.restore(step)
        state = ckpt['state']
        best_state = ckpt['best_state']
        # Convert the state to a TrainState
        self.best_loss = ckpt['best_loss']
        print(f"Loaded model from checkpoint at step {step}", ckpt['best_loss'])
        return step, state, best_state

    def save(self, epoch=0):
        print(f"Saving model at epoch {epoch}")
        ckpt = {
            # 'model': self.model,
            'state': self.state,
            'best_state': self.best_state,
            'best_loss': self.best_loss
        }
        try:
            # save_args = orbax_utils.save_args_from_target(ckpt)
            # self.checkpointer.save(epoch, ckpt, save_kwargs={'save_args': save_args}, force=True)
            pass
        except Exception as e:
            print("Error saving checkpoint", e)

    def _define_train_step(self):
        model = self.model
        loss_fn = self.loss_fn
        
        @partial(jax.pmap, axis_name="device")
        def train_step(state:SimpleTrainState, batch):
            """Train for a single step."""
            images = batch['image']
            labels= batch['label']
            
            def model_loss(params):
                preds = model.apply(params, images)
                expected_output = labels
                nloss = loss_fn(preds, expected_output)
                loss = jnp.mean(nloss)
                return loss
            loss, grads = jax.value_and_grad(model_loss)(state.params)
            grads = jax.lax.pmean(grads, "device")
            state = state.apply_gradients(grads=grads) 
            return state, loss
        return train_step
    
    def _define_compute_metrics(self):
        model = self.model
        loss_fn = self.loss_fn
        
        @jax.jit
        def compute_metrics(state:SimpleTrainState, batch):
            preds = model.apply(state.params, batch['image'])
            expected_output = batch['label']
            loss = jnp.mean(loss_fn(preds, expected_output))
            metric_updates = state.metrics.single_from_model_output(loss=loss, logits=preds, labels=expected_output)
            metrics = state.metrics.merge(metric_updates)
            state = state.replace(metrics=metrics)
            return state
        return compute_metrics

    def summary(self):
        input_vars = self.get_input_ones()
        print(self.model.tabulate(jax.random.key(0), **input_vars, console_kwargs={"width": 200, "force_jupyter":True, }))
    
    def config(self):
        return {
            "model": self.model,
            "state": self.state,
            "name": self.name,
            "input_shapes": self.input_shapes
        }
        
    def init_tensorboard(self, batch_size, steps_per_epoch, epochs):
        summary_writer = tensorboard.SummaryWriter(self.tensorboard_path())
        summary_writer.hparams({
            **self.config(),
            "steps_per_epoch": steps_per_epoch,
            "epochs": epochs,
            "batch_size": batch_size
        })
        return summary_writer
        
    def fit(self, data, steps_per_epoch, epochs):
        train_ds = iter(data['train']())
        if 'test' in data:
            test_ds = data['test']
        else:
            test_ds = None
        train_step = self._define_train_step()
        compute_metrics = self._define_compute_metrics()
        state = self.state
        device_count = jax.device_count()
        # train_ds = flax.jax_utils.prefetch_to_device(train_ds, jax.devices())
        
        summary_writer = self.init_tensorboard(data['batch_size'], steps_per_epoch, epochs)
        
        for epoch in range(epochs):
            current_epoch = self.latest_step + epoch + 1
            print(f"\nEpoch {current_epoch}/{epochs}")
            start_time = time.time()
            epoch_loss = 0
            
            with tqdm.tqdm(total=steps_per_epoch, desc=f'\t\tEpoch {current_epoch}', ncols=100, unit='step') as pbar:
                for i in range(steps_per_epoch):
                    batch = next(train_ds)
                    batch = jax.tree.map(lambda x: x.reshape((device_count, -1, *x.shape[1:])), batch)
                    # print(batch['image'].shape)
                    state, loss = train_step(state, batch)
                    loss = jnp.mean(loss)
                    # print("==>", loss)
                    epoch_loss += loss
                    if i % 100 == 0:
                        pbar.set_postfix(loss=f'{loss:.4f}')
                        pbar.update(100)
                        current_step = current_epoch*steps_per_epoch + i
                        summary_writer.scalar('Train Loss', loss, step=current_step)
                        
            end_time = time.time()
            self.state = state
            total_time = end_time - start_time
            avg_time_per_step = total_time / steps_per_epoch
            avg_loss = epoch_loss / steps_per_epoch
            if avg_loss < self.best_loss:
                self.best_loss = avg_loss
                self.best_state = state
                self.save(current_epoch)
            
            # Compute Metrics
            metrics_str = ''
            # if test_ds is not None:
            #     for test_batch in iter(test_ds()):
            #         state = compute_metrics(state, test_batch)
            #     metrics = state.metrics.compute()
            #     for metric,value in metrics.items():
            #         summary_writer.scalar(f'Test {metric}', value, step=current_epoch)
            #         metrics_str += f', Test {metric}: {value:.4f}'
            #     state = state.replace(metrics=Metrics.empty())
                    
            print(f"\n\tEpoch {current_epoch} completed. Avg Loss: {avg_loss}, Time: {total_time:.2f}s, Best Loss: {self.best_loss} {metrics_str}")
            
        self.save(epochs)
        return self.state


In [8]:
class CNN(nn.Module):
  """A simple CNN model."""

  @nn.compact
  def __call__(self, x):
    x = nn.Conv(features=32, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = nn.Conv(features=64, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape[0], -1))  # flatten
    x = nn.Dense(features=256)(x)
    x = nn.relu(x)
    x = nn.Dense(features=10)(x)
    return x
  
  # Kernel initializer to use
def kernel_init(scale):
    scale = max(scale, 1e-10)
    return nn.initializers.variance_scaling(scale=scale, mode="fan_avg", distribution="truncated_normal")

class NormalAttention(nn.Module):
    """
    Simple implementation of the normal attention.
    """
    query_dim: int
    heads: int = 4
    dim_head: int = 64
    dtype: Any = jnp.float32
    precision: Any = jax.lax.Precision.HIGHEST
    use_bias: bool = True
    kernel_init: Callable = lambda : kernel_init(1.0)

    def setup(self):
        inner_dim = self.dim_head * self.heads
        dense = functools.partial(
            nn.DenseGeneral,
            features=[self.heads, self.dim_head], 
            axis=-1, 
            precision=self.precision, 
            use_bias=self.use_bias, 
            kernel_init=self.kernel_init(), 
            dtype=self.dtype
        )
        self.query = dense(name="to_q")
        self.key = dense(name="to_k")
        self.value = dense(name="to_v")

        self.proj_attn = nn.DenseGeneral(
            self.query_dim, 
            axis=(-2, -1), 
            precision=self.precision, 
            use_bias=self.use_bias, 
            dtype=self.dtype, 
            name="to_out_0",
            kernel_init=self.kernel_init()
            # kernel_init=jax.nn.initializers.xavier_uniform()
        )

    @nn.compact
    def __call__(self, x, context=None):
        # x has shape [B, H, W, C]
        context = x if context is None else context
        query = self.query(x)
        key = self.key(context)
        value = self.value(context)
        
        hidden_states = nn.dot_product_attention(
            query, key, value, dtype=self.dtype, broadcast_dropout=False, dropout_rng=None, precision=self.precision
        )
        proj = self.proj_attn(hidden_states)
        return proj
    
class AttentionBlock(nn.Module):
    heads: int = 4
    dim_head: int = 32
    use_linear_attention: bool = True
    dtype: Any = jnp.float32
    precision: Any = jax.lax.Precision.HIGH
    use_projection: bool = False

    @nn.compact
    def __call__(self, x):
        inner_dim = self.heads * self.dim_head
        B, H, W, C = x.shape
        normed_x = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)(x)
        projected_x = NormalAttention(
            query_dim=C,
            heads=self.heads,
            dim_head=self.dim_head,
            name=f'Attention',
            precision=self.precision,
            use_bias=False,
        )(normed_x)
        out = x + projected_x
        return out
    
class TimeEmbedding(nn.Module):
    features:int
    nax_positions:int=10000

    def setup(self):
        half_dim = self.features // 2
        emb = jnp.log(self.nax_positions) / (half_dim - 1)
        emb = jnp.exp(-emb * jnp.arange(half_dim, dtype=jnp.float32))
        self.embeddings = emb

    def __call__(self, x):
        x = jax.lax.convert_element_type(x, jnp.float32)
        emb = x[:, None] * self.embeddings[None, :]
        emb = jnp.concatenate([jnp.sin(emb), jnp.cos(emb)], axis=-1)
        return emb
    
class TimeProjection(nn.Module):
    features:int
    activation:Callable=jax.nn.gelu

    @nn.compact
    def __call__(self, x):
        x = nn.DenseGeneral(self.features, kernel_init=kernel_init(1.0))(x)
        x = self.activation(x)
        x = nn.DenseGeneral(self.features, kernel_init=kernel_init(1.0))(x)
        x = self.activation(x)
        return x

class SeparableConv(nn.Module):
    features:int
    kernel_size:tuple=(3, 3)
    strides:tuple=(1, 1)
    use_bias:bool=False
    kernel_init:Callable=kernel_init(1.0)
    padding:str="SAME"

    @nn.compact
    def __call__(self, x):
        in_features = x.shape[-1]
        depthwise = nn.Conv(
            features=in_features, kernel_size=self.kernel_size,
            strides=self.strides, kernel_init=self.kernel_init,
            feature_group_count=in_features, use_bias=self.use_bias,
            padding=self.padding
        )(x)
        pointwise = nn.Conv(
            features=self.features, kernel_size=(1, 1),
            strides=(1, 1), kernel_init=self.kernel_init,
            use_bias=self.use_bias
        )(depthwise)
        return pointwise

class ConvLayer(nn.Module):
    conv_type:str
    features:int
    kernel_size:tuple=(3, 3)
    strides:tuple=(1, 1)
    kernel_init:Callable=kernel_init(1.0)

    def setup(self):
        if self.conv_type == "conv":
            self.conv = nn.Conv(
                features=self.features,
                kernel_size=self.kernel_size,
                strides=self.strides,
                kernel_init=self.kernel_init,
            )
        elif self.conv_type == "separable":
            self.conv = SeparableConv(
                features=self.features,
                kernel_size=self.kernel_size,
                strides=self.strides,
                kernel_init=self.kernel_init,
            )

    def __call__(self, x):
        return self.conv(x)

class Upsample(nn.Module):
    features:int
    scale:int
    activation:Callable=jax.nn.swish

    @nn.compact
    def __call__(self, x, residual=None):
        out = x
        # out = PixelShuffle(scale=self.scale)(out)
        B, H, W, C = x.shape
        out = jax.image.resize(x, (B, H * self.scale, W * self.scale, C), method="nearest")
        out = ConvLayer(
            "conv",
            features=self.features,
            kernel_size=(3, 3),
            strides=(1, 1),
        )(out)
        if residual is not None:
            out = jnp.concatenate([out, residual], axis=-1)
        return out

class Downsample(nn.Module):
    features:int
    scale:int
    activation:Callable=jax.nn.swish

    @nn.compact
    def __call__(self, x, residual=None):
        out = ConvLayer(
            "conv",
            features=self.features,
            kernel_size=(3, 3),
            strides=(2, 2)
        )(x)
        if residual is not None:
            if residual.shape[1] > out.shape[1]:
                residual = nn.avg_pool(residual, window_shape=(2, 2), strides=(2, 2), padding="SAME")
            out = jnp.concatenate([out, residual], axis=-1)
        return out

class ResidualBlock(nn.Module):
    conv_type:str
    features:int
    kernel_size:tuple=(3, 3)
    strides:tuple=(1, 1)
    padding:str="SAME"
    activation:Callable=jax.nn.swish
    direction:str=None
    res:int=2
    norm_groups:int=8
    kernel_init:Callable=kernel_init(1.0)

    @nn.compact
    def __call__(self, x:jax.Array, extra_features:jax.Array=None):
        residual = x
        out = nn.GroupNorm(self.norm_groups)(x)
        out = self.activation(out)

        out = ConvLayer(
            self.conv_type,
            features=self.features,
            kernel_size=self.kernel_size,
            strides=self.strides,
            kernel_init=self.kernel_init,
            name="conv1"
        )(out)

        out = nn.GroupNorm(self.norm_groups)(out)
        out = self.activation(out)

        out = ConvLayer(
            self.conv_type,
            features=self.features,
            kernel_size=self.kernel_size,
            strides=self.strides,
            kernel_init=self.kernel_init,
            name="conv2"
        )(out)

        if residual.shape != out.shape:
            residual = ConvLayer(
                self.conv_type,
                features=self.features,
                kernel_size=(1, 1),
                strides=1,
                kernel_init=self.kernel_init,
                name="residual_conv"
            )(residual)
        out = out + residual

        out = jnp.concatenate([out, extra_features], axis=-1) if extra_features is not None else out

        return out
    
class Classifier(nn.Module):
    feature_depths:list=[64, 128, 256, 512],
    attention_configs:list=[{"heads":8}, {"heads":8}, {"heads":8}, {"heads":8}],
    num_res_blocks:int=2,
    num_middle_res_blocks:int=1,
    activation:Callable = jax.nn.swish
    norm_groups:int=8
    major_conv_type:str="conv"
    mid_conv_type:str="conv"
    
    @nn.compact
    def __call__(self, x):
        feature_depths = self.feature_depths
        attention_configs = self.attention_configs

        conv_type = "conv"
        up_conv_type = down_conv_type = self.major_conv_type
        middle_conv_type = self.mid_conv_type

        x = ConvLayer(
            conv_type,
            features=self.feature_depths[0],
            kernel_size=(3, 3),
            strides=(1, 1),
            kernel_init=kernel_init(1.0)
        )(x)
        downs = [x]

        # Downscaling blocks
        for i, (dim_out, attention_config) in enumerate(zip(feature_depths, attention_configs)):
            dim_in = x.shape[-1]
            # dim_in = dim_out
            for j in range(self.num_res_blocks):
                x = ResidualBlock(
                    down_conv_type,
                    name=f"down_{i}_residual_{j}",
                    features=dim_in,
                    kernel_init=kernel_init(1.0),
                    kernel_size=(3, 3),
                    strides=(1, 1),
                    activation=self.activation,
                    norm_groups=self.norm_groups
                )(x)
                if attention_config is not None and j == self.num_res_blocks - 1:   # Apply attention only on the last block
                    x = AttentionBlock(heads=attention_config['heads'], 
                                       dim_head=dim_in // attention_config['heads'],
                                       name=f"down_{i}_attention_{j}")(x)
                downs.append(x)
            if i != len(feature_depths) - 1:
                x = Downsample(
                    features=dim_out,
                    scale=2,
                    activation=self.activation,
                    name=f"down_{i}_downsample"
                )(x)

        # Middle Blocks
        middle_dim_out = self.feature_depths[-1]
        middle_attention = self.attention_configs[-1]
        for j in range(self.num_middle_res_blocks):
            x = ResidualBlock(
                middle_conv_type,
                name=f"middle_res1_{j}",
                features=middle_dim_out,
                kernel_init=kernel_init(1.0),
                kernel_size=(3, 3),
                strides=(1, 1),
                activation=self.activation,
                norm_groups=self.norm_groups
            )(x)
            if middle_attention is not None and j == self.num_middle_res_blocks - 1:   # Apply attention only on the last block
                x = AttentionBlock(heads=attention_config['heads'], 
                                   dim_head=middle_dim_out // attention_config['heads'],
                                   use_linear_attention=False, name=f"middle_attention_{j}")(x)
            x = ResidualBlock(
                middle_conv_type,
                name=f"middle_res2_{j}",
                features=middle_dim_out,
                kernel_init=kernel_init(1.0),
                kernel_size=(3, 3),
                strides=(1, 1),
                activation=self.activation,
                norm_groups=self.norm_groups
            )(x)
            
        x = Downsample(
            features=middle_dim_out * 2,
            scale=2,
            activation=self.activation,
            name=f"middle_downsample"
        )(x)
        
        x = ResidualBlock(
            conv_type,
            name="final_residual",
            features=self.feature_depths[0],
            kernel_init=kernel_init(1.0),
            kernel_size=(3,3),
            strides=(1, 1),
            activation=self.activation,
            norm_groups=self.norm_groups
        )(x)

        x = nn.GroupNorm(self.norm_groups)(x)
        x = self.activation(x)
        
        # Average Pooling
        x = nn.avg_pool(x, window_shape=(x.shape[1], x.shape[2]), strides=(1, 1))
        
        x = x.reshape((x.shape[0], -1))
        
        # Classifier head
        x = nn.Dense(10)(x)
        
        return x
        

In [9]:
model = Classifier(
                  feature_depths=[64, 128, 256, 512, 1024],
                  attention_configs=[None, {"heads": 8}, {"heads": 8}, {"heads": 8}],
                  num_res_blocks=2,
                  num_middle_res_blocks=1,
                  major_conv_type="conv"
            )

inp = jnp.ones((1, 160, 160, 3))
print(model.tabulate(jax.random.key(0), inp,
      console_kwargs={"width": 200, "force_jupyter": True, }))






In [10]:
# cnn = CNN()
model = Classifier(
                  feature_depths=[64, 128, 256, 512, 1024],
                  attention_configs=[None, {"heads": 8}, {"heads": 8}, {"heads": 8}],
                  num_res_blocks=2,
                  num_middle_res_blocks=1,
                  major_conv_type="conv"
            )
solver = optax.adam(1e-3)

def loss(preds, labels):
      # lambda preds, targets: optax.softmax_cross_entropy_with_integer_labels(preds, targets)
      # print(preds.shape, labels.shape)
      result = optax.softmax_cross_entropy_with_integer_labels(preds, labels)
      return result
      
trainer = SimpleTrainer(model, 
                        input_shapes={'x': (160, 160, 3)}, 
                        optimizer=solver, 
                        loss_fn=loss,
                        rngs=jax.random.PRNGKey(0))



In [35]:
jax.profiler.stop_server()

In [1]:
jax.profiler.start_server(6009)
data = get_dataset("imagenette", batch_size=256, splits=['train', 'validation'])
state = trainer.fit(data, steps_per_epoch=data['train_len'] // data['batch_size'], epochs=10)
jax.profiler.stop_server()

NameError: name 'jax' is not defined

In [11]:
jax.profiler.start_server(6009)
data = get_dataset("imagenette", batch_size=128, splits=['train', 'validation'])
state = trainer.fit(data, steps_per_epoch=data['train_len'] // data['batch_size'], epochs=10)
jax.profiler.stop_server()


Epoch 1/10


		Epoch 1: 100step [00:56,  1.77step/s, loss=2.3902]                                                


Saving model at epoch 1

	Epoch 1 completed. Avg Loss: 2.1881344318389893, Time: 56.66s, Best Loss: 2.1881344318389893 

Epoch 2/10


		Epoch 2: 100step [00:05, 18.16step/s, loss=2.2113]                                                


Saving model at epoch 2

	Epoch 2 completed. Avg Loss: 2.1151442527770996, Time: 5.51s, Best Loss: 2.1151442527770996 

Epoch 3/10


		Epoch 3:   0%|                                              | 0/73 [00:00<?, ?step/s, loss=2.1399]2024-07-29 20:52:01.746798: W external/tsl/tsl/profiler/lib/profiler_session.cc:109] Profiling is late by 2635800 nanoseconds and will start immediately.
		Epoch 3: 100step [00:05, 17.55step/s, loss=2.1399]                                                


Saving model at epoch 3

	Epoch 3 completed. Avg Loss: 2.0795538425445557, Time: 5.70s, Best Loss: 2.0795538425445557 

Epoch 4/10


		Epoch 4: 100step [00:05, 17.55step/s, loss=2.0495]                                                


Saving model at epoch 4

	Epoch 4 completed. Avg Loss: 2.0166125297546387, Time: 5.70s, Best Loss: 2.0166125297546387 

Epoch 5/10


		Epoch 5: 100step [00:05, 18.21step/s, loss=1.8212]                                                


Saving model at epoch 5

	Epoch 5 completed. Avg Loss: 1.9245163202285767, Time: 5.49s, Best Loss: 1.9245163202285767 

Epoch 6/10


		Epoch 6: 100step [00:05, 18.23step/s, loss=1.8679]                                                


Saving model at epoch 6

	Epoch 6 completed. Avg Loss: 1.8269599676132202, Time: 5.49s, Best Loss: 1.8269599676132202 

Epoch 7/10


		Epoch 7: 100step [00:05, 18.27step/s, loss=1.7161]                                                


Saving model at epoch 7

	Epoch 7 completed. Avg Loss: 1.6975194215774536, Time: 5.48s, Best Loss: 1.6975194215774536 

Epoch 8/10


		Epoch 8: 100step [00:05, 18.36step/s, loss=1.7334]                                                


Saving model at epoch 8

	Epoch 8 completed. Avg Loss: 1.6169333457946777, Time: 5.45s, Best Loss: 1.6169333457946777 

Epoch 9/10


		Epoch 9: 100step [00:05, 18.25step/s, loss=1.5885]                                                


Saving model at epoch 9

	Epoch 9 completed. Avg Loss: 1.5265167951583862, Time: 5.48s, Best Loss: 1.5265167951583862 

Epoch 10/10


		Epoch 10: 100step [00:05, 18.30step/s, loss=1.4427]                                               

Saving model at epoch 10

	Epoch 10 completed. Avg Loss: 1.4501495361328125, Time: 5.47s, Best Loss: 1.4501495361328125 
Saving model at epoch 10





In [21]:
# jax.profiler.start_server(6009)
data = get_dataset("mnist", batch_size=128)
state = trainer.fit(data, steps_per_epoch=data['train_len'] // data['batch_size'], epochs=10)
# jax.profiler.stop_server()


Epoch 1/10


		Epoch 1:   0%|                                                          | 0/468 [00:00<?, ?step/s]

		Epoch 1: 500step [00:04, 114.14step/s, loss=0.0888]                                               


Saving model at epoch 1

	Epoch 1 completed. Avg Loss: 0.18133877217769623, Time: 4.38s, Best Loss: 0.18133877217769623 

Epoch 2/10


		Epoch 2: 500step [00:00, 1287.91step/s, loss=0.0585]                                              


Saving model at epoch 2

	Epoch 2 completed. Avg Loss: 0.05323795974254608, Time: 0.39s, Best Loss: 0.05323795974254608 

Epoch 3/10


		Epoch 3: 500step [00:00, 1323.98step/s, loss=0.0094]                                              


Saving model at epoch 3

	Epoch 3 completed. Avg Loss: 0.03695608302950859, Time: 0.38s, Best Loss: 0.03695608302950859 

Epoch 4/10


		Epoch 4: 500step [00:00, 1280.80step/s, loss=0.0233]                                              


Saving model at epoch 4

	Epoch 4 completed. Avg Loss: 0.02820330113172531, Time: 0.39s, Best Loss: 0.02820330113172531 

Epoch 5/10


		Epoch 5: 500step [00:00, 1298.43step/s, loss=0.0154]                                              


Saving model at epoch 5

	Epoch 5 completed. Avg Loss: 0.02226649969816208, Time: 0.39s, Best Loss: 0.02226649969816208 

Epoch 6/10


		Epoch 6: 500step [00:00, 1353.08step/s, loss=0.0297]                                              


Saving model at epoch 6

	Epoch 6 completed. Avg Loss: 0.019302448257803917, Time: 0.37s, Best Loss: 0.019302448257803917 

Epoch 7/10


		Epoch 7: 500step [00:00, 1381.89step/s, loss=0.0008]                                              


Saving model at epoch 7

	Epoch 7 completed. Avg Loss: 0.014166121371090412, Time: 0.36s, Best Loss: 0.014166121371090412 

Epoch 8/10


		Epoch 8: 500step [00:00, 1362.68step/s, loss=0.0132]                                              


Saving model at epoch 8

	Epoch 8 completed. Avg Loss: 0.01146089006215334, Time: 0.37s, Best Loss: 0.01146089006215334 

Epoch 9/10


		Epoch 9: 500step [00:00, 1345.12step/s, loss=0.0053]                                              


Saving model at epoch 9

	Epoch 9 completed. Avg Loss: 0.011206524446606636, Time: 0.37s, Best Loss: 0.011206524446606636 

Epoch 10/10


		Epoch 10: 500step [00:00, 1353.41step/s, loss=0.0065]                                             

Saving model at epoch 10

	Epoch 10 completed. Avg Loss: 0.008855808526277542, Time: 0.37s, Best Loss: 0.008855808526277542 
Saving model at epoch 10



