# Jax Implementation of MobileNetV2
The following is a jax and equinox implementation of the MobileNetV2 architecture

The implementation is broken up into its individual (modulo stride) layers

In [89]:
from typing import List, Any, Tuple, Optional, Union, Sequence
from jaxtyping import Array, Float, Int, PyTree

import jax
import jax.numpy as jnp
import jax.random as jr

import equinox as eqx

import optax as opt

import matplotlib.pyplot as plt

from functools import partial

## MobileNetV2

The following is the full implementation of the MobileNetV2 architecture

In [90]:
# Define a Depthwise Separable Convolution Layer
class DepthwiseSeparableConv(eqx.Module):
    depthwise: eqx.nn.Conv2d
    pointwise: eqx.nn.Conv2d

    def __init__(self, in_channels, out_channels, stride, key):
        dw_key, pw_key = jax.random.split(key)
        self.depthwise = eqx.nn.Conv2d(
            in_channels=in_channels, 
            out_channels=in_channels, 
            kernel_size=(3, 3), 
            stride=stride, 
            padding=1, 
            groups=in_channels,
            key=dw_key
        )
        self.pointwise = eqx.nn.Conv2d(
            in_channels=in_channels, 
            out_channels=out_channels, 
            kernel_size=(3, 3), 
            key=pw_key
        )

    def __call__(self, x):
        x = self.depthwise(x)
        x = jax.nn.relu(x)
        x = self.pointwise(x)
        return x


In [91]:
# [reference](https://github.com/DarshanDeshpande/jax-models/blob/main/jax_models/layers/depthwise_separable_conv.py)
class DepthwiseConv2D(eqx.Module):
    in_channels: int = eqx.field(static=True)
    out_channels: int = eqx.field(static=True)
    kernel_size: Union[int, Sequence[int]] = eqx.field(static=True)
    stride: Union[int, Sequence[int]] = eqx.field(static=True)
    padding: Union[int, Sequence[int]] = eqx.field(static=True)
    depth_multiplier: int = eqx.field(static=True)
    use_bias: bool = eqx.field(static=True)
    groups: int = eqx.field(static=True)
    key: Any = eqx.field(static=True)

    kernel: Array
    bias: Array

    def __init__(self, in_channels: int, depth_multiplier: int, kernel_size: Tuple[int, int], stride: Tuple[int, int], padding: Tuple[int, int], use_bias: bool, key: jr.PRNGKey):
        self.in_channels = in_channels
        self.depth_multiplier = depth_multiplier
        self.out_channels = in_channels * depth_multiplier
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.use_bias = False
        self.groups = in_channels
        self.key = key

        self.kernel = jr.normal(
            key,
            shape=(self.depth_multiplier * self.in_channels, 1) + self.kernel_size)

        if use_bias:
            self.bias = jr.normal(key, shape=(in_channels * depth_multiplier,))
            self.use_bias = True
        else:
            self.bias = jnp.zeros((in_channels * depth_multiplier,))


    def __call__(self, x: Array) -> Array:
        x = x.reshape((x.shape[0], 1, *x.shape[1:]))
        x = jax.lax.conv_general_dilated(
            x,
            self.kernel,
            self.stride,
            'VALID',
            (1,) * len(self.kernel_size),
            (1,) * len(self.kernel_size),
            ("NCHW", "OIHW", "NCHW"),
            x.shape[-1]
        )
        if self.use_bias:
            x = x + self.bias
        return x

In [92]:
# Define an inverted residual block [reference](https://github.com/keras-team/keras/blob/v3.3.3/keras/src/applications/mobilenet_v2.py#L398)
class InvertedResidualBlock(eqx.Module):
    # static fields get ignored durign training
    in_channels:  int   = eqx.field(static=True)
    expansion:    int   = eqx.field(static=True)
    stride:       int   = eqx.field(static=True)
    alpha:        float = eqx.field(static=True)
    filters:      int   = eqx.field(static=True)
    pw_filters:   int   = eqx.field(static=True)
    block_id:     int   = eqx.field(static=True)

    layers: List[Any]
    

    def _make_divisible(self, v, divisor, min_value=None):
        if min_value is None:
            min_value = divisor
        new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
        # Make sure that round down does not go down by more than 10%.
        if new_v < 0.9 * v:
            new_v += divisor
        return new_v

    def __init__(self, in_channels: int, expansion: int, stride: int, alpha: float, filters: int, block_id: int, key: jr.PRNGKey, BATCH_SIZE: int):
        self.in_channels = in_channels
        self.expansion = expansion
        self.stride = stride
        self.alpha = alpha
        self.filters = filters
        self.block_id = block_id

        pointwise_filters = int(filters * alpha)
        # ensure that the number of filters on the last 1x1 convolution is a multiple of 8
        pointwise_filters = self._make_divisible(pointwise_filters, 8)
        self.pw_filters = pointwise_filters

        # Define the key for the block
        key, conv_key = jr.split(key)
        self.layers = []

        # Define the layers of the block
        if block_id:
            # Expand with a pointwise 1x1 convolution
            self.layers.extend([
                eqx.nn.Conv2d(
                    in_channels=in_channels,
                    out_channels=in_channels * expansion,
                    kernel_size=(1, 1),
                    use_bias=False,
                    key=conv_key
                ),
                eqx.nn.BatchNorm(
                    BATCH_SIZE,
                    axis_name='batch',
                    eps=1e-3,
                    momentum=0.99
                ),
                jax.nn.relu6
            ])
        
        self.layers.extend([
            DepthwiseConv2D(
                in_channels=in_channels * expansion,
                depth_multiplier=1,
                kernel_size=(3, 3),
                stride=(stride, stride),
                padding=(1, 1),
                key=conv_key,
                use_bias=False
            ),
            eqx.nn.BatchNorm(
                BATCH_SIZE,
                axis_name='batch',
                eps=1e-3,
                momentum=0.99
            ),
            jax.nn.relu6
        ])

        # pointwise 1x1 conv
        self.layers.extend([
            eqx.nn.Conv2d(
                in_channels=in_channels * expansion,
                out_channels=pointwise_filters,
                kernel_size=(1, 1),
                use_bias=False,
                key=conv_key
            ),
            eqx.nn.BatchNorm(
                BATCH_SIZE,
                axis_name='batch',
                eps=1e-3,
                momentum=0.99
            )
        ])
    
    def __call__(self, x):
        input = x

        lc = 0
        
        if self.block_id:
            x = self.layers[0](x)
            x = self.layers[1](x)
            x = self.layers[2](x)
            lc = 3
        if self.stride == 2:
            x = jnp.pad(x, 1, mode='constant', constant_values=0)

        for i in range(lc, len(self.layers)):
            x = self.layers[i](x)

        if self.in_channels == self.pw_filters and self.stride == 1:
            x = x + input

        return x

In [93]:
# Define a Bottleneck Block
class Bottleneck(eqx.Module):
    _stride: int = eqx.field(static=True)

    conv1: eqx.nn.Conv2d
    depthwise_conv: DepthwiseSeparableConv
    conv3: eqx.nn.Conv2d
    use_residual: bool

    def __init__(self, in_channels, out_channels, stride, expand_ratio, use_residual, key: jr.PRNGKey):
        self._stride=stride
        keys = jr.split(key, 3)
        hidden_dim = in_channels * expand_ratio
        self.conv1 = eqx.nn.Conv2d(in_channels, hidden_dim, kernel_size=(1, 1), key=keys[0])
        self.depthwise_conv = [DepthwiseSeparableConv(hidden_dim, hidden_dim, stride=1, key=keys[1]),
                               DepthwiseSeparableConv(hidden_dim, hidden_dim, stride=2, key=keys[1])]
        self.conv3 = eqx.nn.Conv2d(hidden_dim, out_channels, kernel_size=(1, 1), key=keys[2])
        self.use_residual = use_residual

    def __call__(self, x):
        residual = x
        x = self.conv1(x)
        x = jax.nn.relu(x)
        if self._stride == 1:
            x = self.depthwise_conv[0](x)
            x = jax.nn.relu(x)
            x = self.conv3(x)
            return x + residual
        else:
            x = self.depthwise_conv[1](x)
            x = jax.nn.relu(x)
            x = self.conv3(x)
            return x

In [94]:
# Define the MobileNetV2
class MobileNetV2(eqx.Module):
    in_channels: int = eqx.field(static=True)
    
    first_conv: eqx.nn.Conv2d
    bottlenecks: list
    last_conv: eqx.nn.Conv2d
    pool: eqx.nn.AvgPool2d
    classifier: eqx.nn.Conv2d

    def __init__(self, in_channels, num_classes, key):
        keys = jax.random.split(key, 10)
        self.in_channels = in_channels

        self.first_conv = eqx.nn.Conv2d(in_channels, 32, kernel_size=(3, 3), stride=2, padding=1, key=keys[0])

        # Bottleneck blocks configuration
        bottleneck_configs = [
            # (in_channels, out_channels, stride, expand_ratio, n_repeats)
            (32, 16, 1, 1, 1),   # First block, no expansion, no repetition
            (16, 24, 2, 6, 2),   # Second block, 2x stride, 2 repetitions
            (24, 32, 2, 6, 3),   # Third block, 2x stride, 3 repetitions
            (32, 64, 2, 6, 4),   # Fourth block, 2x stride, 4 repetitions
            (64, 96, 1, 6, 3),   # Fifth block, stride 1, 3 repetitions
            (96, 160, 2, 6, 3),  # Sixth block, 2x stride, 3 repetitions
            (160, 320, 1, 6, 1), # Seventh block, stride 1, no repetition
        ]

        self.bottlenecks = []
        current_key = keys[1]

        for config in bottleneck_configs:
            in_channels, out_channels, stride, expand_ratio, n_repeats = config

            # Add the first block in the stage with the specified stride
            self.bottlenecks.append(
                Bottleneck(in_channels, out_channels, stride, expand_ratio, use_residual=(stride == 1), key=current_key)
            )
            current_key = jax.random.split(current_key, 1)[0]

            # Add the remaining blocks with stride = 1
            for i in range(n_repeats - 1):
                self.bottlenecks.append(
                    Bottleneck(out_channels, out_channels, stride=1, expand_ratio=expand_ratio, use_residual=True, key=current_key)
                )
                current_key = jax.random.split(current_key, 1)[0]

        self.last_conv = eqx.nn.Conv2d(24, 1280, kernel_size=(1, 1), key=keys[2])
        self.pool = eqx.nn.AvgPool2d(kernel_size=(7, 7))
        self.classifier = eqx.nn.Conv2d(1280, num_classes, kernel_size=(1,1),key=keys[3])

    def __call__(self, x):
        x = self.first_conv(x)
        x = jax.nn.relu(x)

        for bottleneck in self.bottlenecks:
            x = bottleneck(x)

        x = self.last_conv(x)
        x = jax.nn.relu(x)
        x = jnp.mean(x, axis=(1, 2))  # Global average pooling
        x = self.classifier(x)
        return x


In [95]:
# MobileNetV2 model based on the Keras implementation
class MobileNetV2_K(eqx.Module):
    layers: List[Any]

    def _make_divisible(self, v, divisor, min_value=None):
        if min_value is None:
            min_value = divisor
        new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
        # Make sure that round down does not go down by more than 10%.
        if new_v < 0.9 * v:
            new_v += divisor
        return new_v

    def __init__(self,
        in_channels: int = 3,
        alpha: float = 1.0,
        include_top: bool = True,
        num_classes: int = 1000,
        classifier_activation: str = 'softmax',
        pooling: Optional[str] = None,
        key: jr.PRNGKey = jr.PRNGKey(0),
        BATCH_SIZE: int = 32
    ):
        key, conv_key = jr.split(key)
        
        first_block_filters = self._make_divisible(32 * alpha, 8)
        if alpha > 1.0:
            last_block_filters = self._make_divisible(1280 * alpha, 8)
        else:
            last_block_filters = 1280

        self.layers = [
            eqx.nn.Conv2d(
                in_channels=in_channels,
                out_channels=first_block_filters,
                kernel_size=(3, 3),
                stride=(2, 2),
                use_bias=False,
                key=conv_key
            ),
            eqx.nn.BatchNorm(
                input_size=first_block_filters,
                eps=1e-3,
                momentum=0.99,
                axis_name="batch"
            ),
            jax.nn.relu6,
            InvertedResidualBlock(
                in_channels=first_block_filters,
                expansion=1,
                stride=1,
                alpha=alpha,
                filters=16,
                block_id=0,
                key=key,
                BATCH_SIZE=BATCH_SIZE
            ),
            InvertedResidualBlock(
                in_channels=16,
                expansion=6,
                stride=2,
                alpha=alpha,
                filters=24,
                block_id=1,
                key=key,
                BATCH_SIZE=BATCH_SIZE
            ),
            InvertedResidualBlock(
                in_channels=24,
                expansion=6,
                stride=1,
                alpha=alpha,
                filters=24,
                block_id=2,
                key=key,
                BATCH_SIZE=BATCH_SIZE
            ),
            InvertedResidualBlock(
                in_channels=24,
                expansion=6,
                stride=2,
                alpha=alpha,
                filters=32,
                block_id=3,
                key=key,
                BATCH_SIZE=BATCH_SIZE
            ),
            InvertedResidualBlock(
                in_channels=32,
                expansion=6,
                stride=1,
                alpha=alpha,
                filters=32,
                block_id=4,
                key=key,
                BATCH_SIZE=BATCH_SIZE
            ),
            InvertedResidualBlock(
                in_channels=32,
                expansion=6,
                stride=1,
                alpha=alpha,
                filters=32,
                block_id=5,
                key=key,
                BATCH_SIZE=BATCH_SIZE
            ),
            InvertedResidualBlock(
                in_channels=32,
                expansion=6,
                stride=2,
                alpha=alpha,
                filters=64,
                block_id=6,
                key=key,
                BATCH_SIZE=BATCH_SIZE
            ),
            InvertedResidualBlock(
                in_channels=64,
                expansion=6,
                stride=1,
                alpha=alpha,
                filters=64,
                block_id=7,
                key=key,
                BATCH_SIZE=BATCH_SIZE
            ),
            InvertedResidualBlock(
                in_channels=64,
                expansion=6,
                stride=1,
                alpha=alpha,
                filters=64,
                block_id=8,
                key=key,
                BATCH_SIZE=BATCH_SIZE
            ),
            InvertedResidualBlock(
                in_channels=64,
                expansion=6,
                stride=1,
                alpha=alpha,
                filters=64,
                block_id=9,
                key=key,
                BATCH_SIZE=BATCH_SIZE
            ),
            InvertedResidualBlock(
                in_channels=64,
                expansion=6,
                stride=1,
                alpha=alpha,
                filters=96,
                block_id=10,
                key=key,
                BATCH_SIZE=BATCH_SIZE
            ),
            InvertedResidualBlock(
                in_channels=96,
                expansion=6,
                stride=1,
                alpha=alpha,
                filters=96,
                block_id=11,
                key=key,
                BATCH_SIZE=BATCH_SIZE
            ),
            InvertedResidualBlock(
                in_channels=96,
                expansion=6,
                stride=1,
                alpha=alpha,
                filters=96,
                block_id=12,
                key=key,
                BATCH_SIZE=BATCH_SIZE
            ),
            InvertedResidualBlock(
                in_channels=96,
                expansion=6,
                stride=2,
                alpha=alpha,
                filters=160,
                block_id=13,
                key=key,
                BATCH_SIZE=BATCH_SIZE
            ),
            InvertedResidualBlock(
                in_channels=160,
                expansion=6,
                stride=1,
                alpha=alpha,
                filters=160,
                block_id=14,
                key=key,
                BATCH_SIZE=BATCH_SIZE
            ),
            InvertedResidualBlock(
                in_channels=160,
                expansion=6,
                stride=1,
                alpha=alpha,
                filters=160,
                block_id=15,
                key=key,
                BATCH_SIZE=BATCH_SIZE
            ),
            InvertedResidualBlock(
                in_channels=160,
                expansion=6,
                stride=1,
                alpha=alpha,
                filters=320,
                block_id=16,
                key=key,
                BATCH_SIZE=BATCH_SIZE
            ),
            eqx.nn.Conv2d(
                in_channels=320,
                out_channels=last_block_filters,
                kernel_size=(1, 1),
                use_bias=False,
                key=conv_key
            ),
            eqx.nn.BatchNorm(
                input_size=BATCH_SIZE,
                eps=1e-3,
                momentum=0.999,
                axis_name="batch"
            ),
            jax.nn.relu6
        ]

        if include_top:
            self.layers.extend([
                eqx.nn.AvgPool2d(kernel_size=(7, 7)), # TODO: replace with global average pooling
                eqx.nn.Linear(
                    in_features=last_block_filters,
                    out_features=num_classes,
                    key=key
                )
            ])
            if classifier_activation == 'softmax':
                self.layers.append(jax.nn.softmax)

        else:
            if pooling == 'avg':
                self.layers.append(eqx.nn.AvgPool2d(kernel_size=(7, 7))) # TODO: replace with global average pooling
            elif pooling == 'max':
                self.layers.append(eqx.nn.MaxPool2d(kernel_size=(7, 7))) # TODO: replace with global max pooling

    def __call__(self, x, state):
        for layer in self.layers:
            if issubclass(type(layer), eqx.nn.StatefulLayer):
                x, state = layer(x, state)
            else:
                x = layer(x)
        return x, state

## Training

In [96]:
# Training hyperparameters
LEARNING_RATE = 1e-3
N_EPOCHS = 300
BATCH_SIZE = 64
PRINT_EVERY = 30
SEED = 42

# Key generation
key = jax.random.PRNGKey(SEED)

### Importing Data Set

In [97]:
import torch  # https://pytorch.org
import torchvision  # https://pytorch.org

In [98]:
# Lets test with MNIST

# Load the MNIST dataset [reference](https://docs.kidger.site/equinox/examples/mnist/#the-dataset)
normalise_data = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5,), (0.5,)),
    ]
)
train_dataset = torchvision.datasets.MNIST(
    "MNIST",
    train=True,
    download=True,
    transform=normalise_data,
)
test_dataset = torchvision.datasets.MNIST(
    "MNIST",
    train=False,
    download=True,
    transform=normalise_data,
)
trainloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True
)
testloader = torch.utils.data.DataLoader(
    test_dataset, batch_size=BATCH_SIZE, shuffle=True
)

# we aren't using a validation set here, but that's easy enough to fix

In [99]:
# Checking our data a bit (by now, everyone knows what the MNIST dataset looks like)
dummy_x, dummy_y = next(iter(trainloader))
dummy_x = dummy_x.numpy()
dummy_y = dummy_y.numpy()
print(dummy_x.shape)  # 64x1x28x28
print(dummy_y.shape)  # 64
print(dummy_y)

(64, 1, 28, 28)
(64,)
[2 1 1 3 5 4 9 8 6 8 6 5 5 4 2 0 3 0 0 6 1 9 0 2 2 7 3 2 8 5 7 8 6 2 2 1 8
 4 1 2 0 8 3 0 3 8 4 9 1 8 0 0 3 6 3 6 5 6 3 6 9 5 7 3]


In [100]:
model, state = eqx.nn.make_with_state(MobileNetV2_K)(in_channels=1, num_classes=10, key=key, BATCH_SIZE=BATCH_SIZE)
print(model)

MobileNetV2_K(
  layers=[
    Conv2d(
      num_spatial_dims=2,
      weight=f32[32,1,3,3],
      bias=None,
      in_channels=1,
      out_channels=32,
      kernel_size=(3, 3),
      stride=(2, 2),
      padding=((0, 0), (0, 0)),
      dilation=(1, 1),
      groups=1,
      use_bias=False,
      padding_mode='ZEROS'
    ),
    BatchNorm(
      weight=f32[32],
      bias=f32[32],
      first_time_index=StateIndex(
        marker=0,
        init=<object object at 0x7efcb8193c10>
      ),
      state_index=StateIndex(marker=1, init=<object object at 0x7efcb8193c10>),
      axis_name='batch',
      inference=False,
      input_size=32,
      eps=0.001,
      channelwise_affine=True,
      momentum=0.99
    ),
    <wrapped function relu6>,
    InvertedResidualBlock(
      in_channels=32,
      expansion=1,
      stride=1,
      alpha=1.0,
      filters=16,
      pw_filters=16,
      block_id=0,
      layers=[
        DepthwiseConv2D(
          in_channels=32,
          out_channels=32,
  

### Running Training


In [101]:
# for MobileNetV2_K

def loss(
    model: MobileNetV2_K,  state: eqx.nn.State, x: Float[Array, "batch 1 28 28"], y: Int[Array, " batch"]
) -> Float[Array, ""]:
    batch_model = jax.vmap(
        model, axis_name="batch", in_axes=(0, None), out_axes=(0, None)
    )
    pred_y, state = batch_model(x, state)
    return cross_entropy(y, pred_y)


def cross_entropy(
    y: Int[Array, " batch"], pred_y: Float[Array, "batch 10"]
) -> Float[Array, ""]:
    # y are the true targets, and should be integers 0-9.
    # pred_y are the log-softmax'd predictions.
    pred_y = jnp.take_along_axis(pred_y, jnp.expand_dims(y, 1), axis=1)
    return -jnp.mean(pred_y)


@eqx.filter_jit
def make_step(model, state, opt_state, xs, ys):
    grads, state = eqx.filter_grad(loss, has_aux=True)(model, state, xs, ys)
    updates, opt_state = opt.update(grads, opt_state)
    model = eqx.apply_updates(model, updates)
    return model, state, opt_state

# Example loss
loss_value = loss(model, state, dummy_x, dummy_y)
print(loss_value.shape)  # scalar loss
# Example inference
output = jax.vmap(model)(dummy_x, state)
print(output.shape)  # batch of predictions

ValueError: conv_general_dilated feature_group_count must divide lhs feature dimension size, but 13 does not divide 1.

In [None]:
model = MobileNetV2(in_channels=1, num_classes=10, key=key)
print(model)

In [None]:
# For MobileNetV2

def loss(
    model: MobileNetV2, x: Float[Array, "batch 1 28 28"], y: Int[Array, " batch"]
) -> Float[Array, ""]:
    # Our input has the shape (BATCH_SIZE, 1, 28, 28), but our model operations on
    # a single input input image of shape (1, 28, 28).
    #
    # Therefore, we have to use jax.vmap, which in this case maps our model over the
    # leading (batch) axis.
    pred_y = jax.vmap(model)(x)
    return cross_entropy(y, pred_y)


def cross_entropy(
    y: Int[Array, " batch"], pred_y: Float[Array, "batch 10"]
) -> Float[Array, ""]:
    # y are the true targets, and should be integers 0-9.
    # pred_y are the log-softmax'd predictions.
    pred_y = jnp.take_along_axis(pred_y, jnp.expand_dims(y, 1), axis=1)
    return -jnp.mean(pred_y)


# Example loss
loss_value = loss(model, dummy_x, dummy_y)
print(loss_value.shape)  # scalar loss
# Example inference
output = jax.vmap(model)(dummy_x)
print(output.shape)  # batch of predictions

## Results