In [None]:
# default_exp model

In [None]:
from nbdev.showdoc import show_doc

# A ConvNet for the 2020s

Based on the paper [A ConvNet for the 2020s](https://arxiv.org/abs/2201.03545) and the official [code repository](https://github.com/facebookresearch/ConvNeXt).

# Imports

In [None]:
#export
from functools import partial
from typing import List, Tuple, Sequence, Dict, Any, Optional, Union, Callable

import numpy as np

import jax
import jax.numpy as jnp
import jax.random as random

import flax
import flax.linen as nn

import optax

# Initialization

Variance-scaling initialization with a truncated normal distribution (variance scaled by fan-in).

In [None]:
#export
def init(scale=.02, mode='fan_in', distribution='truncated_normal'):
    '''Scaled truncated normal (variance-scaling) kernel initializer.

    Builds a `variance_scaling` initializer: sampled weights have variance
    (approximately) `scale / fan`, where `fan` is chosen by `mode`. With the
    defaults this is a truncated normal with std `sqrt(scale / fan_in)`.
    NOTE: this is *not* a fixed std of `scale`; the reference ConvNeXt uses
    `trunc_normal_(std=.02)`.

    Args:
        scale: variance scaling factor.
        mode: which fan to normalize by, e.g. 'fan_in', 'fan_out' or 'fan_avg'.
        distribution: 'truncated_normal', 'normal' or 'uniform'.

    Returns:
        An initializer `fn(key, shape, dtype=float32) -> array`.
    '''
    # Same function that `flax.linen.initializers` re-exports.
    return jax.nn.initializers.variance_scaling(scale, mode, distribution)

# Module-level defaults used by the layer shortcuts and the classification head.
scale = .02
default_kernel_init = init(scale)

In [None]:
# Render the signature and docstring of `init` for the generated documentation.
show_doc(init)

<h4 id="init" class="doc_header"><code>init</code><a href="__main__.py#L2" class="source_link" style="float:right">[source]</a></h4>

> <code>init</code>(**`scale`**=*`0.02`*)

Scaled truncated normal initailizer

## Some layer definitions

In [None]:
#export
# --- weight-bearing layers: both share the default kernel initializer ---
# Spatial convolution.
conv = partial(nn.Conv,
               kernel_init=default_kernel_init)
# Point-wise ("1x1") convolution, expressed as a Dense over the last (channel) axis.
pw_conv = partial(nn.Dense,
                  kernel_init=default_kernel_init)

# --- weight-free layers ---
# Sample-wise dropout: the mask is broadcast over axes 1, 2 and 3, so for a 4-D
# (batch, height, width, channels) input each sample is kept or dropped as a whole
# (drop-path / stochastic-depth style).
sw_drop = partial(nn.Dropout,
                  broadcast_dims=(1, 2, 3))
# Layer normalization (Flax LayerNorm normalizes over the last axis by default).
norm = nn.LayerNorm

## ConvNeXt block
> With all of these preparations, the benefit of adopting larger kernel-sized convolutions is significant. We will use 7x7 depthwise conv in each block.
>
> We will now use a single GELU activation in each block.
>
> From now on, we will use one LayerNorm as our choice of normalization in each residual block.

![ConvNeXt Block](../images/ConvNeXtBlock.png)

In [None]:
#export
class ConvBlock(nn.Module):
    '''ConvNeXt residual block.

    depthwise conv -> LayerNorm (over channels) -> point-wise expansion -> GELU
    -> point-wise projection -> optional layer scale `gamma` -> optional
    sample-wise dropout (drop path), then added back onto the input.

    The residual sum requires the input's channel count to equal `dim`.
    '''
    dim: int = 3  # number of input/output features (channels)
    drop: float = 0.0  # sample-wise dropout (drop-path) rate; disabled when 0
    scale: float = 1e-6  # initial value of the layer scale `gamma`; disabled when <= 0
    kernel_size: int = 7  # spatial size of the square depthwise kernel
    expansion: int = 4  # hidden-width multiplier of the point-wise MLP

    @nn.compact
    def __call__(self, x, train: bool = True):
        '''Apply the block.

        Args:
            x: input features, presumably (batch, height, width, dim); the
                dropout mask broadcasting assumes a 4-D input.
            train: when True and `drop > 0`, dropout is active and a
                'dropout' RNG stream must be supplied.

        Returns:
            Array with the same shape as `x`.
        '''
        shortcut = x
        # feature_group_count == dim makes the convolution depthwise (one filter per channel).
        x = conv(self.dim, (self.kernel_size, self.kernel_size), padding='SAME',
                 feature_group_count=self.dim, name='dw_conv')(x)
        x = norm(name='lr_norm')(x)
        x = pw_conv(self.dim * self.expansion, name='fc_1')(x)
        x = nn.gelu(x)
        x = pw_conv(self.dim, name='fc_2')(x)
        if self.scale > 0.:
            # Learnable per-channel layer scale, initialised to the small constant
            # `scale` so the residual branch starts with almost no contribution.
            gamma = self.param('gamma', lambda _: jnp.full((self.dim,), self.scale))
            x = x * gamma
        if self.drop > 0.:
            x = sw_drop(self.drop, name='sw_drop')(x, deterministic=not train)
        return shortcut + x

In [None]:
# Render the signature and docstring of `ConvBlock` for the generated documentation.
show_doc(ConvBlock)

<h2 id="ConvBlock" class="doc_header"><code>class</code> <code>ConvBlock</code><a href="" class="source_link" style="float:right">[source]</a></h2>

> <code>ConvBlock</code>(**`dim`**:`int`=*`3`*, **`drop`**:`float`=*`0.0`*, **`scale`**:`float`=*`1e-06`*, **`parent`**:`Union`\[`typing.Type[flax.linen.module.Module]`, `typing.Type[flax.core.scope.Scope]`, `typing.Type[flax.linen.module._Sentinel]`, `NoneType`\]=*`<flax.linen.module._Sentinel object at 0x7ff17a97ca30>`*, **`name`**:`str`=*`None`*) :: `Module`

Residual block with depthwise convolution, samplewise norm and dropout and gelu activation.

In [None]:
m = ConvBlock(64)
dummy_x = jnp.ones((12, 224, 224, 64))  # (batch, height, width, channels)
out, params = m.init_with_output(random.PRNGKey(0), dummy_x)
# A residual block must preserve the input shape.
assert out.shape == dummy_x.shape
# `jax.tree_map` is deprecated (removed in recent JAX); `jax.tree_util.tree_map` is the stable API.
out.shape, jax.tree_util.tree_map(lambda x: x.shape, params)

## ConvStage

In [None]:
#export
class ConvStage(nn.Module):
    '''One ConvNeXt stage: optional downsampling followed by `depth` `ConvBlock`s.

    Downsampling is a LayerNorm followed by a 2x2 convolution with stride 2,
    which reduces the spatial resolution by a factor of 2 and maps the channels to `dim`.
    '''
    dim: int  # number of output features
    depth: int  # number of `ConvBlock`s
    drops: Union[Sequence[float], None] = None  # per-block dropout rates (length `depth`); all 0 when None/empty
    scale: float = 1e-6  # initial layer-scale value passed to every block
    downsample: bool = True  # whether to downsample the input first

    @nn.compact
    def __call__(self, x, train: bool = True):
        if self.downsample:
            x = norm(name='lr_norm')(x)
            x = conv(self.dim, kernel_size=(2, 2), strides=2, name='dwsample')(x)
        # Use `len` rather than truthiness so array-like rate sequences are accepted too.
        if self.drops is None or len(self.drops) == 0:
            drops = [0.] * self.depth
        else:
            drops = list(self.drops)
        if len(drops) != self.depth:
            raise ValueError(f'ConvStage: expected {self.depth} dropout rates, got {len(drops)}')
        for i, drop in enumerate(drops):
            x = ConvBlock(dim=self.dim, drop=drop, scale=self.scale, name=f'block_{i}')(x, train=train)
        return x

In [None]:
m = ConvStage(dim=192, depth=1)
dummy_x = jnp.ones((7, 64, 64, 96))  # (batch, height, width, channels)
out, params = m.init_with_output({'params': random.PRNGKey(0), 'dropout': random.PRNGKey(1)}, dummy_x)
# Downsampling halves height/width and maps 96 -> 192 channels.
assert out.shape == (7, 32, 32, 192)
# `jax.tree_map` is deprecated (removed in recent JAX); `jax.tree_util.tree_map` is the stable API.
out.shape, jax.tree_util.tree_map(lambda x: x.shape, params)

((7, 32, 32, 192),
 FrozenDict({
     params: {
         block_0: {
             dw_conv: {
                 bias: (192,),
                 kernel: (7, 7, 1, 192),
             },
             fc_1: {
                 bias: (768,),
                 kernel: (192, 768),
             },
             fc_2: {
                 bias: (192,),
                 kernel: (768, 192),
             },
             gamma: (192,),
             lr_norm: {
                 bias: (192,),
                 scale: (192,),
             },
         },
         dwsample: {
             bias: (192,),
             kernel: (2, 2, 96, 192),
         },
         lr_norm: {
             bias: (96,),
             scale: (96,),
         },
     },
 }))

# ConvNeXt model

In [None]:
#export
class Stem(nn.Module):
    '''Patchify stem: a 4x4 convolution with stride 4, then LayerNorm.'''
    dim: int  # number of output features

    @nn.compact
    def __call__(self, x):
        patches = conv(self.dim, kernel_size=(4, 4), strides=4, name='conv_1')(x)
        return norm(name='lr_norm_1')(patches)

class Head(nn.Module):
    '''Classification head: global average pooling, LayerNorm, then a Dense classifier.'''
    classes: int  # number of output classes

    @nn.compact
    def __call__(self, x):
        '''Map channels-last features (batch, *spatial, channels) to logits (batch, classes).'''
        # Global average pool over every axis between batch and channels; for NHWC
        # input that is axes (1, 2). The PyTorch reference pools over the last two
        # axes only because it is NCHW; here that would average away the channels.
        spatial_axes = tuple(range(1, x.ndim - 1))
        x = norm(name='lr_norm')(x.mean(axis=spatial_axes))
        x = nn.Dense(self.classes, kernel_init=default_kernel_init, name='out')(x)
        return x

class ConvNeXt(nn.Module):
    '''ConvNeXt image classifier: patchify stem, one `ConvStage` per entry of
    `depths`/`dims`, and a pooling + Dense classification head.

    Expects channels-last input, e.g. (batch, height, width, channels). Every
    stage after the first downsamples by 2. Per-block dropout (drop-path) rates
    increase linearly from 0 to `drop` across all blocks of the network.
    '''
    classes: int = 10  # number of classes
    depths: Sequence[int] = (3, 3, 9, 3)  # number of `ConvBlock`s per stage
    dims: Sequence[int] = (96, 192, 384, 768)  # number of output features per stage
    drop: float = 0.  # dropout rate of the last block (rates ramp up linearly from 0)
    scale: float = 1e-6  # initial scale of direct path

    @nn.compact
    def __call__(self, x, train=True):
        if len(self.depths) != len(self.dims):
            raise ValueError(f'ConvNeXt: got {len(self.depths)} depths but {len(self.dims)} dims')
        x = Stem(dim=self.dims[0], name='stem')(x)
        # Drop rates are static hyper-parameters, so compute them with NumPy on the
        # host: jnp ops would be staged out as tracers under `jax.jit`, and the
        # `.tolist()` below would then raise a concretization error.
        dep_sums = np.cumsum(self.depths)
        rates = np.linspace(0., self.drop, int(dep_sums[-1]))
        stage_rates = np.split(rates, dep_sums[:-1])
        for i, (dim, depth, drop) in enumerate(zip(self.dims, self.depths, stage_rates)):
            x = ConvStage(dim=dim, depth=depth, drops=drop.tolist(), scale=self.scale,
                          downsample=(i > 0), name=f'stage_{i}')(x, train=train)
        x = Head(classes=self.classes, name='head')(x)
        return x

In [None]:
m = ConvNeXt()
dummy_x = jnp.ones((7, 224, 224, 3))  # (batch, height, width, channels)
out, params = m.init_with_output({'params': random.PRNGKey(0), 'dropout': random.PRNGKey(1)}, dummy_x)
# One logit vector per image.
assert out.shape == (7, m.classes)
# `jax.tree_map` is deprecated (removed in recent JAX); `jax.tree_util.tree_map` is the stable API.
out.shape, jax.tree_util.tree_map(lambda x: x.shape, params)

In [None]:
# Convert every notebook matching '*.ipynb' in this directory into library code (nbdev).
from nbdev.export import notebook2script

notebook2script('*.ipynb')

Converted 00_core.ipynb.
Converted 01_dataloaders.ipynb.
Converted 02_experiment.ipynb.
Converted 03_model.ipynb.
Converted index.ipynb.
