# Jax Implementation of MobileNetV2
The following is a jax and equinox implementation of the MobileNetV2 architecture

The implementation is broken up into its individual (modulo stride) layers

In [1]:
from typing import List, Any, Tuple, Optional, Union, Sequence
from jaxtyping import Array, Float, Int, PyTree

import jax
import jax.numpy as jnp
import jax.random as jr

import equinox as eqx

import optax as opt

import matplotlib.pyplot as plt

from functools import partial

## MobileNetV2

The following is the full implementation of the MobileNetV2 architecture

In [2]:
# Define a Depthwise Separable Convolution Layer
class DepthwiseSeparableConv(eqx.Module):
    depthwise: eqx.nn.Conv2d
    pointwise: eqx.nn.Conv2d

    def __init__(self, in_channels, out_channels, stride, key):
        dw_key, pw_key = jax.random.split(key)
        self.depthwise = eqx.nn.Conv2d(
            in_channels=in_channels, 
            out_channels=in_channels, 
            kernel_size=(3, 3), 
            stride=stride, 
            padding=1, 
            groups=in_channels,
            key=dw_key
        )
        self.pointwise = eqx.nn.Conv2d(
            in_channels=in_channels, 
            out_channels=out_channels, 
            kernel_size=(3, 3), 
            key=pw_key
        )

    def __call__(self, x):
        x = self.depthwise(x)
        x = jax.nn.relu(x)
        x = self.pointwise(x)
        return x


In [3]:
# [reference](https://github.com/DarshanDeshpande/jax-models/blob/main/jax_models/layers/depthwise_separable_conv.py)
class DepthwiseConv2D(eqx.Module):
    in_channels: int = eqx.field(static=True)
    out_channels: int = eqx.field(static=True)
    kernel_size: Union[int, Sequence[int]] = eqx.field(static=True)
    stride: Union[int, Sequence[int]] = eqx.field(static=True)
    padding: str = eqx.field(static=True)
    depth_multiplier: int = eqx.field(static=True)
    use_bias: bool = eqx.field(static=True)
    groups: int = eqx.field(static=True)
    key: Any = eqx.field(static=True)

    kernel: Array
    bias: Array

    def __init__(self, in_channels: int, depth_multiplier: int, kernel_size: Tuple[int, int], stride: Tuple[int, int], padding: str, use_bias: bool, key: jr.PRNGKey):
        self.in_channels = in_channels
        self.depth_multiplier = depth_multiplier
        self.out_channels = in_channels * depth_multiplier
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding.upper()
        self.use_bias = False
        self.groups = in_channels
        self.key = key

        # if self.padding.lower() is not "valid" and self.padding is not "same":
        #     raise ValueError("Padding must be either 'valid' or 'same'")

        self.kernel = jr.uniform(
            key,
            shape = (self.in_channels, self.depth_multiplier * self.in_channels, *self.kernel_size)
        )

        if use_bias:
            self.bias = jr.normal(key, shape=(in_channels * depth_multiplier,))
            self.use_bias = True
        else:
            self.bias = jnp.zeros((in_channels * depth_multiplier,))


    def __call__(self, x: Array) -> Array:
        x = jnp.expand_dims(x, axis=0)
        x = jax.lax.conv_general_dilated( # see (https://jax.readthedocs.io/en/latest/notebooks/convolutions.html#dimension-numbers-define-dimensional-layout-for-conv-general-dilated)
            lhs=x,
            rhs=self.kernel,
            window_strides=self.stride,
            padding=self.padding.upper(),
            lhs_dilation=(1,) * len(self.kernel_size),
            rhs_dilation=(1,) * len(self.kernel_size),
            dimension_numbers=("NCHW", "OIHW", "NCHW"),
            # feature_group_count=x.shape[-1]
            # feature_group_count=1
        )
        x = x.squeeze(axis=0)
        if self.use_bias:
            x = x + self.bias
        return x

In [4]:
# Define an inverted residual block [reference](https://github.com/keras-team/keras/blob/v3.3.3/keras/src/applications/mobilenet_v2.py#L398)
class InvertedResidualBlock(eqx.Module):
    # static fields get ignored durign training
    in_channels:  int   = eqx.field(static=True)
    expansion:    int   = eqx.field(static=True)
    stride:       int   = eqx.field(static=True)
    alpha:        float = eqx.field(static=True)
    filters:      int   = eqx.field(static=True)
    pw_filters:   int   = eqx.field(static=True)
    block_id:     int   = eqx.field(static=True)

    layers: List[Any]
    

    def _make_divisible(self, v, divisor, min_value=None):
        if min_value is None:
            min_value = divisor
        new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
        # Make sure that round down does not go down by more than 10%.
        if new_v < 0.9 * v:
            new_v += divisor
        return new_v

    def __init__(self, in_channels: int, expansion: int, stride: int, alpha: float, filters: int, block_id: int, key: jr.PRNGKey):
        self.in_channels = in_channels
        self.expansion = expansion
        self.stride = stride
        self.alpha = alpha
        self.filters = filters
        self.block_id = block_id

        pointwise_filters = int(filters * alpha)
        # ensure that the number of filters on the last 1x1 convolution is a multiple of 8
        pointwise_filters = self._make_divisible(pointwise_filters, 8)
        self.pw_filters = pointwise_filters

        # Define the key for the block
        key, conv_key = jr.split(key)
        self.layers = []

        # Define the layers of the block
        if block_id:
            # Expand with a pointwise 1x1 convolution
            self.layers.extend([
                eqx.nn.Conv2d(
                    in_channels=in_channels,
                    out_channels=in_channels * expansion,
                    kernel_size=(1, 1),
                    stride=(1, 1),
                    padding=(0, 0),
                    padding_mode='ZEROS',
                    use_bias=False,
                    key=conv_key
                ),
                eqx.nn.BatchNorm(
                    in_channels * expansion,
                    axis_name='batch',
                    eps=1e-3,
                    momentum=0.99
                ),
                jax.nn.relu6
            ])
        
        self.layers.extend([
            DepthwiseConv2D(
                in_channels=in_channels * expansion,
                depth_multiplier=1,
                kernel_size=(1, 1),
                stride=(stride, stride),
                padding="valid",
                key=conv_key,
                use_bias=False
            ),
            eqx.nn.BatchNorm(
                in_channels * expansion,
                axis_name='batch',
                eps=1e-3,
                momentum=0.99
            ),
            jax.nn.relu6
        ])

        # pointwise 1x1 conv
        self.layers.extend([
            eqx.nn.Conv2d(
                in_channels=in_channels * expansion,
                out_channels=pointwise_filters,
                kernel_size=(1, 1),
                stride=(1, 1),
                padding=(0, 0),
                padding_mode='ZEROS',
                use_bias=False,
                key=conv_key
            ),
            eqx.nn.BatchNorm(
                pointwise_filters,
                axis_name='batch',
                eps=1e-3,
                momentum=0.99
            )
        ])
    
    def __call__(self, x, state):
        input = x

        lc = 0
        
        if self.block_id:
            x = self.layers[0](x)
            x, state = self.layers[1](x, state)
            x = self.layers[2](x)
            lc = 3
        if self.stride == 2:
            correct = (x.shape[0] - (self.in_channels * self.expansion)) // 2
            x = jnp.pad(x, ((correct, correct), (correct, correct), (0, 0)), mode='constant', constant_values=0)

        for _, layer in enumerate(self.layers[lc:]):
            if issubclass(type(layer), eqx.nn.StatefulLayer):
                x, state = layer(x, state)
            else:
                x = layer(x)

        if self.in_channels == self.pw_filters and self.stride == 1:
            x = x + input

        return x, state

In [5]:
# Define a Bottleneck Block
class Bottleneck(eqx.Module):
    _stride: int = eqx.field(static=True)

    conv1: eqx.nn.Conv2d
    depthwise_conv: DepthwiseSeparableConv
    conv3: eqx.nn.Conv2d
    use_residual: bool

    def __init__(self, in_channels, out_channels, stride, expand_ratio, use_residual, key: jr.PRNGKey):
        self._stride=stride
        keys = jr.split(key, 3)
        hidden_dim = in_channels * expand_ratio
        self.conv1 = eqx.nn.Conv2d(in_channels, hidden_dim, kernel_size=(1, 1), key=keys[0])
        self.depthwise_conv = [DepthwiseSeparableConv(hidden_dim, hidden_dim, stride=1, key=keys[1]),
                               DepthwiseSeparableConv(hidden_dim, hidden_dim, stride=2, key=keys[1])]
        self.conv3 = eqx.nn.Conv2d(hidden_dim, out_channels, kernel_size=(1, 1), key=keys[2])
        self.use_residual = use_residual

    def __call__(self, x):
        residual = x
        x = self.conv1(x)
        x = jax.nn.relu(x)
        if self._stride == 1:
            x = self.depthwise_conv[0](x)
            x = jax.nn.relu(x)
            x = self.conv3(x)
            return x + residual
        else:
            x = self.depthwise_conv[1](x)
            x = jax.nn.relu(x)
            x = self.conv3(x)
            return x

In [6]:
# Define the MobileNetV2
class MobileNetV2(eqx.Module):
    in_channels: int = eqx.field(static=True)
    
    first_conv: eqx.nn.Conv2d
    bottlenecks: list
    last_conv: eqx.nn.Conv2d
    pool: eqx.nn.AvgPool2d
    classifier: eqx.nn.Conv2d

    def __init__(self, in_channels, num_classes, key):
        keys = jax.random.split(key, 10)
        self.in_channels = in_channels

        self.first_conv = eqx.nn.Conv2d(in_channels, 32, kernel_size=(3, 3), stride=2, padding=1, key=keys[0])

        # Bottleneck blocks configuration
        bottleneck_configs = [
            # (in_channels, out_channels, stride, expand_ratio, n_repeats)
            (32, 16, 1, 1, 1),   # First block, no expansion, no repetition
            (16, 24, 2, 6, 2),   # Second block, 2x stride, 2 repetitions
            (24, 32, 2, 6, 3),   # Third block, 2x stride, 3 repetitions
            (32, 64, 2, 6, 4),   # Fourth block, 2x stride, 4 repetitions
            (64, 96, 1, 6, 3),   # Fifth block, stride 1, 3 repetitions
            (96, 160, 2, 6, 3),  # Sixth block, 2x stride, 3 repetitions
            (160, 320, 1, 6, 1), # Seventh block, stride 1, no repetition
        ]

        self.bottlenecks = []
        current_key = keys[1]

        for config in bottleneck_configs:
            in_channels, out_channels, stride, expand_ratio, n_repeats = config

            # Add the first block in the stage with the specified stride
            self.bottlenecks.append(
                Bottleneck(in_channels, out_channels, stride, expand_ratio, use_residual=(stride == 1), key=current_key)
            )
            current_key = jax.random.split(current_key, 1)[0]

            # Add the remaining blocks with stride = 1
            for i in range(n_repeats - 1):
                self.bottlenecks.append(
                    Bottleneck(out_channels, out_channels, stride=1, expand_ratio=expand_ratio, use_residual=True, key=current_key)
                )
                current_key = jax.random.split(current_key, 1)[0]

        self.last_conv = eqx.nn.Conv2d(24, 1280, kernel_size=(1, 1), key=keys[2])
        self.pool = eqx.nn.AvgPool2d(kernel_size=(7, 7))
        self.classifier = eqx.nn.Conv2d(1280, num_classes, kernel_size=(1,1),key=keys[3])

    def __call__(self, x):
        x = self.first_conv(x)
        x = jax.nn.relu(x)

        for bottleneck in self.bottlenecks:
            x = bottleneck(x)

        x = self.last_conv(x)
        x = jax.nn.relu(x)
        x = jnp.mean(x, axis=(1, 2))  # Global average pooling
        x = self.classifier(x)
        return x


In [7]:
# MobileNetV2 model based on the Keras implementation
class MobileNetV2_K(eqx.Module):
    layers: List[Any]

    def _make_divisible(self, v, divisor, min_value=None):
        if min_value is None:
            min_value = divisor
        new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
        # Make sure that round down does not go down by more than 10%.
        if new_v < 0.9 * v:
            new_v += divisor
        return new_v

    def __init__(self,
        in_channels: int = 3,
        alpha: float = 1.0,
        include_top: bool = True,
        num_classes: int = 1000,
        classifier_activation: str = 'softmax',
        pooling: Optional[str] = None,
        key: jr.PRNGKey = jr.PRNGKey(0),
        input_size: Sequence[int] = (224, 224),
    ):
        key, conv_key = jr.split(key)
        
        first_block_filters = self._make_divisible(32 * alpha, 8)
        if alpha > 1.0:
            last_block_filters = self._make_divisible(1280 * alpha, 8)
        else:
            last_block_filters = 1280

        first_layer_padding = (
            (input_size[0] - 1) // 2,
            (input_size[1] - 1) // 2
        )

        self.layers = [
            eqx.nn.Conv2d(
                in_channels=in_channels,
                out_channels=first_block_filters,
                kernel_size=(3, 3),
                stride=(2, 2),
                padding=first_layer_padding, # equivalent to TF 'same' padding
                padding_mode='ZEROS',
                use_bias=False,
                key=conv_key
            ),
            eqx.nn.BatchNorm(
                input_size=first_block_filters,
                eps=1e-3,
                momentum=0.999,
                axis_name="batch"
            ),
            jax.nn.relu6,
            InvertedResidualBlock(
                in_channels=first_block_filters,
                expansion=1,
                stride=1,
                alpha=alpha,
                filters=16,
                block_id=0,
                key=key
            ),
            InvertedResidualBlock(
                in_channels=16,
                expansion=6,
                stride=2,
                alpha=alpha,
                filters=24,
                block_id=1,
                key=key
            ),
            InvertedResidualBlock(
                in_channels=24,
                expansion=6,
                stride=1,
                alpha=alpha,
                filters=24,
                block_id=2,
                key=key
            ),
            InvertedResidualBlock(
                in_channels=24,
                expansion=6,
                stride=2,
                alpha=alpha,
                filters=32,
                block_id=3,
                key=key
            ),
            InvertedResidualBlock(
                in_channels=32,
                expansion=6,
                stride=1,
                alpha=alpha,
                filters=32,
                block_id=4,
                key=key
            ),
            InvertedResidualBlock(
                in_channels=32,
                expansion=6,
                stride=1,
                alpha=alpha,
                filters=32,
                block_id=5,
                key=key
            ),
            InvertedResidualBlock(
                in_channels=32,
                expansion=6,
                stride=2,
                alpha=alpha,
                filters=64,
                block_id=6,
                key=key
            ),
            InvertedResidualBlock(
                in_channels=64,
                expansion=6,
                stride=1,
                alpha=alpha,
                filters=64,
                block_id=7,
                key=key
            ),
            InvertedResidualBlock(
                in_channels=64,
                expansion=6,
                stride=1,
                alpha=alpha,
                filters=64,
                block_id=8,
                key=key
            ),
            InvertedResidualBlock(
                in_channels=64,
                expansion=6,
                stride=1,
                alpha=alpha,
                filters=64,
                block_id=9,
                key=key
            ),
            InvertedResidualBlock(
                in_channels=64,
                expansion=6,
                stride=1,
                alpha=alpha,
                filters=96,
                block_id=10,
                key=key
            ),
            InvertedResidualBlock(
                in_channels=96,
                expansion=6,
                stride=1,
                alpha=alpha,
                filters=96,
                block_id=11,
                key=key
            ),
            InvertedResidualBlock(
                in_channels=96,
                expansion=6,
                stride=1,
                alpha=alpha,
                filters=96,
                block_id=12,
                key=key
            ),
            InvertedResidualBlock(
                in_channels=96,
                expansion=6,
                stride=2,
                alpha=alpha,
                filters=160,
                block_id=13,
                key=key
            ),
            InvertedResidualBlock(
                in_channels=160,
                expansion=6,
                stride=1,
                alpha=alpha,
                filters=160,
                block_id=14,
                key=key
            ),
            InvertedResidualBlock(
                in_channels=160,
                expansion=6,
                stride=1,
                alpha=alpha,
                filters=160,
                block_id=15,
                key=key
            ),
            InvertedResidualBlock(
                in_channels=160,
                expansion=6,
                stride=1,
                alpha=alpha,
                filters=320,
                block_id=16,
                key=key
            ),
            eqx.nn.Conv2d(
                in_channels=320,
                out_channels=last_block_filters,
                kernel_size=(1, 1),
                use_bias=False,
                key=conv_key
            ),
            eqx.nn.BatchNorm(
                input_size=last_block_filters,
                eps=1e-3,
                momentum=0.999,
                axis_name="batch"
            ),
            jax.nn.relu6
        ]

        if include_top:
            self.layers.extend([
                eqx.nn.AvgPool2d(kernel_size=(7, 7)), # TODO: replace with global average pooling
                jnp.ravel,
                eqx.nn.Linear(
                    in_features=81920,
                    out_features=num_classes,
                    key=key
                )
            ])
            if classifier_activation == 'softmax':
                self.layers.append(jax.nn.softmax)

        else:
            if pooling == 'avg':
                self.layers.append(eqx.nn.AvgPool2d(kernel_size=(7, 7))) # TODO: replace with global average pooling
            elif pooling == 'max':
                self.layers.append(eqx.nn.MaxPool2d(kernel_size=(7, 7))) # TODO: replace with global max pooling

    def __call__(self, x, state):
        for layer in self.layers:
            if issubclass(type(layer), eqx.nn.StatefulLayer) or isinstance(layer, InvertedResidualBlock):
                x, state = layer(x, state)
            else:
                x = layer(x)
        return x, state

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


## Training

In [8]:
# Training hyperparameters
LEARNING_RATE = 1e-3
N_EPOCHS = 300
BATCH_SIZE = 32
PRINT_EVERY = 30
SEED = 42

# Key generation
key = jax.random.PRNGKey(SEED)

### Importing Data Set

In [9]:
import torch  # https://pytorch.org
import torchvision  # https://pytorch.org

In [10]:
# Lets test with MNIST

# Load the MNIST dataset [reference](https://docs.kidger.site/equinox/examples/mnist/#the-dataset)
process_data = torchvision.transforms.Compose(
    [
        torchvision.transforms.Resize((224, 224)),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5,), (0.5,)),
    ]
)
train_dataset = torchvision.datasets.MNIST(
    "MNIST",
    train=True,
    download=True,
    transform=process_data,
)
test_dataset = torchvision.datasets.MNIST(
    "MNIST",
    train=False,
    download=True,
    transform=process_data,
)
trainloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True
)
testloader = torch.utils.data.DataLoader(
    test_dataset, batch_size=BATCH_SIZE, shuffle=True
)

# we aren't using a validation set here, but that's easy enough to fix

In [11]:
# Checking our data a bit (by now, everyone knows what the MNIST dataset looks like)
dummy_x, dummy_y = next(iter(trainloader))
dummy_x = dummy_x.numpy()
dummy_y = dummy_y.numpy()
print(dummy_x.shape)  # BATCH_SIZEx1x224x224
print(dummy_y.shape)  # BATCH_SIZE
print(dummy_y)

(32, 1, 224, 224)
(32,)
[8 9 5 4 5 3 3 0 5 9 6 7 8 8 9 4 9 7 4 6 5 5 9 0 9 2 6 7 9 0 7 0]


In [12]:
model, state = eqx.nn.make_with_state(MobileNetV2_K)(in_channels=1, num_classes=10, key=key, include_top=True, input_size=(224, 224))
print(model)

MobileNetV2_K(
  layers=[
    Conv2d(
      num_spatial_dims=2,
      weight=f32[32,1,3,3],
      bias=None,
      in_channels=1,
      out_channels=32,
      kernel_size=(3, 3),
      stride=(2, 2),
      padding=((111, 111), (111, 111)),
      dilation=(1, 1),
      groups=1,
      use_bias=False,
      padding_mode='ZEROS'
    ),
    BatchNorm(
      weight=f32[32],
      bias=f32[32],
      first_time_index=StateIndex(
        marker=0,
        init=<object object at 0x7fd08020bd20>
      ),
      state_index=StateIndex(marker=1, init=<object object at 0x7fd08020bd20>),
      axis_name='batch',
      inference=False,
      input_size=32,
      eps=0.001,
      channelwise_affine=True,
      momentum=0.999
    ),
    <wrapped function relu6>,
    InvertedResidualBlock(
      in_channels=32,
      expansion=1,
      stride=1,
      alpha=1.0,
      filters=16,
      pw_filters=16,
      block_id=0,
      layers=[
        DepthwiseConv2D(
          in_channels=32,
          out_channe

### Running Training


In [13]:
# for MobileNetV2_K

def loss(
    model: MobileNetV2_K,  state: eqx.nn.State, x: Float[Array, "batch 1 28 28"], y: Int[Array, " batch"]
) -> Float[Array, ""]:
    batch_model = jax.vmap(
        model, axis_name="batch", in_axes=(0, None), out_axes=(0, None)
    )
    pred_y, state = batch_model(x, state)
    return cross_entropy(y, pred_y), state


def cross_entropy(
    y: Int[Array, " batch"], pred_y: Float[Array, "batch 10"]
) -> Float[Array, ""]:
    # y are the true targets, and should be integers 0-9.
    # pred_y are the log-softmax'd predictions.
    pred_y = jnp.take_along_axis(pred_y, jnp.expand_dims(y, 1), axis=1).squeeze()
    return -jnp.mean(pred_y)

# Example loss
loss_value, state = loss(model, state, dummy_x, dummy_y)
print(loss_value) # scalar loss

-0.104663804


In [14]:
# Example inference

inference_model = eqx.nn.inference_mode(model)
inference_model = eqx.Partial(inference_model, state=state)

@eqx.filter_jit
def evaluate(model, xs):
    output, _ = jax.vmap(model)(xs)
    return output

# Example evaluation
output = evaluate(inference_model, dummy_x)
print(output)

[[0.08886552 0.07614945 0.08539364 0.10089708 0.11778124 0.12871644
  0.08539084 0.11850863 0.10232773 0.09596939]
 [0.08906077 0.07860798 0.08543477 0.09758051 0.11887437 0.12854408
  0.0828679  0.11546037 0.10533139 0.09823783]
 [0.08376049 0.06828795 0.07768979 0.09803472 0.12249431 0.14249736
  0.08366538 0.12423944 0.10256942 0.09676106]
 [0.08008081 0.05590582 0.07842064 0.11373584 0.12557207 0.15840131
  0.08111922 0.12009377 0.10056375 0.08610681]
 [0.08706446 0.07076795 0.08655535 0.10408452 0.11625675 0.13233224
  0.08985607 0.11936613 0.10068696 0.09302954]
 [0.09454144 0.08076131 0.08847592 0.09621056 0.1146043  0.12209287
  0.08311914 0.11741696 0.10826248 0.094515  ]
 [0.09418596 0.08526894 0.09336437 0.09414536 0.10884769 0.11529104
  0.09265284 0.11065163 0.10609339 0.09949882]
 [0.07768144 0.05531644 0.06135955 0.10418615 0.1316493  0.17303799
  0.06818184 0.14230108 0.10069075 0.08559537]
 [0.08474227 0.06897904 0.07924235 0.10417219 0.12047464 0.14193171
  0.08174327

In [15]:
# See [reference](https://docs.kidger.site/equinox/examples/stateful/)

def train(
    model: MobileNetV2_K,
    state: eqx.nn.State,
    optim: Any,
    trainloader: torch.utils.data.DataLoader,
    testloader: torch.utils.data.DataLoader,
    n_epochs: int,
    print_every: int,
) -> Tuple[MobileNetV2_K, eqx.nn.State, Any]:
    
    # only train parameters, filter out non-arrays and static fields
    opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))

    @eqx.filter_jit
    def make_step(model, state, opt_state, x, y):
        lg, state = eqx.filter_value_and_grad(loss, has_aux=True)(model, state, x, y) # loss is already vmap'd, so no need to vmap here
        loss_value, grads = lg
        # loss_value, state = loss_value
        return model, state, opt_state, loss_value, grads
        # updates, opt_state = optim.update(grads, opt_state)
        # model = eqx.apply_updates(model, updates)
        # return model, state, opt_state, loss_value
    
    def infiniteTrainloader():
        while True:
            yield from trainloader

    for step, (x, y) in zip(range(n_epochs), infiniteTrainloader()):
        
        x = x.numpy()
        y = y.numpy()
        
        model, state, opt_state, loss_value, grads = make_step(model, state, opt_state, x, y)
        if step % print_every == 0:
            print(f"Step {step}, Loss: {loss_value}")

    return model, state

optim = opt.adam(LEARNING_RATE)
model, state = train(model, state, optim, trainloader, testloader, N_EPOCHS, PRINT_EVERY)


ValueError: Custom node type mismatch: expected type: <class 'equinox.nn._stateful.State'>, value: MobileNetV2_K(
  layers=[
    Conv2d(
      num_spatial_dims=2,
      weight=f32[32,1,3,3],
      bias=None,
      in_channels=1,
      out_channels=32,
      kernel_size=(3, 3),
      stride=(2, 2),
      padding=((111, 111), (111, 111)),
      dilation=(1, 1),
      groups=1,
      use_bias=False,
      padding_mode='ZEROS'
    ),
    BatchNorm(
      weight=f32[32],
      bias=f32[32],
      first_time_index=StateIndex(marker=0, init=None),
      state_index=StateIndex(marker=1, init=None),
      axis_name=None,
      inference=None,
      input_size=32,
      eps=0.001,
      channelwise_affine=True,
      momentum=0.999
    ),
    None,
    InvertedResidualBlock(
      in_channels=32,
      expansion=1,
      stride=1,
      alpha=1.0,
      filters=16,
      pw_filters=16,
      block_id=0,
      layers=[
        DepthwiseConv2D(
          in_channels=32,
          out_channels=32,
          kernel_size=(1, 1),
          stride=(1, 1),
          padding='VALID',
          depth_multiplier=1,
          use_bias=False,
          groups=32,
          key=u32[2],
          kernel=f32[32,32,1,1],
          bias=f32[32]
        ),
        BatchNorm(
          weight=f32[32],
          bias=f32[32],
          first_time_index=StateIndex(marker=2, init=None),
          state_index=StateIndex(marker=3, init=None),
          axis_name=None,
          inference=None,
          input_size=32,
          eps=0.001,
          channelwise_affine=True,
          momentum=0.99
        ),
        None,
        Conv2d(
          num_spatial_dims=2,
          weight=f32[16,32,1,1],
          bias=None,
          in_channels=32,
          out_channels=16,
          kernel_size=(1, 1),
          stride=(1, 1),
          padding=((0, 0), (0, 0)),
          dilation=(1, 1),
          groups=1,
          use_bias=False,
          padding_mode='ZEROS'
        ),
        BatchNorm(
          weight=f32[16],
          bias=f32[16],
          first_time_index=StateIndex(marker=4, init=None),
          state_index=StateIndex(marker=5, init=None),
          axis_name=None,
          inference=None,
          input_size=16,
          eps=0.001,
          channelwise_affine=True,
          momentum=0.99
        )
      ]
    ),
    InvertedResidualBlock(
      in_channels=16,
      expansion=6,
      stride=2,
      alpha=1.0,
      filters=24,
      pw_filters=24,
      block_id=1,
      layers=[
        Conv2d(
          num_spatial_dims=2,
          weight=f32[96,16,1,1],
          bias=None,
          in_channels=16,
          out_channels=96,
          kernel_size=(1, 1),
          stride=(1, 1),
          padding=((0, 0), (0, 0)),
          dilation=(1, 1),
          groups=1,
          use_bias=False,
          padding_mode='ZEROS'
        ),
        BatchNorm(
          weight=f32[96],
          bias=f32[96],
          first_time_index=StateIndex(marker=6, init=None),
          state_index=StateIndex(marker=7, init=None),
          axis_name=None,
          inference=None,
          input_size=96,
          eps=0.001,
          channelwise_affine=True,
          momentum=0.99
        ),
        None,
        DepthwiseConv2D(
          in_channels=96,
          out_channels=96,
          kernel_size=(1, 1),
          stride=(2, 2),
          padding='VALID',
          depth_multiplier=1,
          use_bias=False,
          groups=96,
          key=u32[2],
          kernel=f32[96,96,1,1],
          bias=f32[96]
        ),
        BatchNorm(
          weight=f32[96],
          bias=f32[96],
          first_time_index=StateIndex(marker=8, init=None),
          state_index=StateIndex(marker=9, init=None),
          axis_name=None,
          inference=None,
          input_size=96,
          eps=0.001,
          channelwise_affine=True,
          momentum=0.99
        ),
        None,
        Conv2d(
          num_spatial_dims=2,
          weight=f32[24,96,1,1],
          bias=None,
          in_channels=96,
          out_channels=24,
          kernel_size=(1, 1),
          stride=(1, 1),
          padding=((0, 0), (0, 0)),
          dilation=(1, 1),
          groups=1,
          use_bias=False,
          padding_mode='ZEROS'
        ),
        BatchNorm(
          weight=f32[24],
          bias=f32[24],
          first_time_index=StateIndex(marker=10, init=None),
          state_index=StateIndex(marker=11, init=None),
          axis_name=None,
          inference=None,
          input_size=24,
          eps=0.001,
          channelwise_affine=True,
          momentum=0.99
        )
      ]
    ),
    InvertedResidualBlock(
      in_channels=24,
      expansion=6,
      stride=1,
      alpha=1.0,
      filters=24,
      pw_filters=24,
      block_id=2,
      layers=[
        Conv2d(
          num_spatial_dims=2,
          weight=f32[144,24,1,1],
          bias=None,
          in_channels=24,
          out_channels=144,
          kernel_size=(1, 1),
          stride=(1, 1),
          padding=((0, 0), (0, 0)),
          dilation=(1, 1),
          groups=1,
          use_bias=False,
          padding_mode='ZEROS'
        ),
        BatchNorm(
          weight=f32[144],
          bias=f32[144],
          first_time_index=StateIndex(marker=12, init=None),
          state_index=StateIndex(marker=13, init=None),
          axis_name=None,
          inference=None,
          input_size=144,
          eps=0.001,
          channelwise_affine=True,
          momentum=0.99
        ),
        None,
        DepthwiseConv2D(
          in_channels=144,
          out_channels=144,
          kernel_size=(1, 1),
          stride=(1, 1),
          padding='VALID',
          depth_multiplier=1,
          use_bias=False,
          groups=144,
          key=u32[2],
          kernel=f32[144,144,1,1],
          bias=f32[144]
        ),
        BatchNorm(
          weight=f32[144],
          bias=f32[144],
          first_time_index=StateIndex(marker=14, init=None),
          state_index=StateIndex(marker=15, init=None),
          axis_name=None,
          inference=None,
          input_size=144,
          eps=0.001,
          channelwise_affine=True,
          momentum=0.99
        ),
        None,
        Conv2d(
          num_spatial_dims=2,
          weight=f32[24,144,1,1],
          bias=None,
          in_channels=144,
          out_channels=24,
          kernel_size=(1, 1),
          stride=(1, 1),
          padding=((0, 0), (0, 0)),
          dilation=(1, 1),
          groups=1,
          use_bias=False,
          padding_mode='ZEROS'
        ),
        BatchNorm(
          weight=f32[24],
          bias=f32[24],
          first_time_index=StateIndex(marker=16, init=None),
          state_index=StateIndex(marker=17, init=None),
          axis_name=None,
          inference=None,
          input_size=24,
          eps=0.001,
          channelwise_affine=True,
          momentum=0.99
        )
      ]
    ),
    InvertedResidualBlock(
      in_channels=24,
      expansion=6,
      stride=2,
      alpha=1.0,
      filters=32,
      pw_filters=32,
      block_id=3,
      layers=[
        Conv2d(
          num_spatial_dims=2,
          weight=f32[144,24,1,1],
          bias=None,
          in_channels=24,
          out_channels=144,
          kernel_size=(1, 1),
          stride=(1, 1),
          padding=((0, 0), (0, 0)),
          dilation=(1, 1),
          groups=1,
          use_bias=False,
          padding_mode='ZEROS'
        ),
        BatchNorm(
          weight=f32[144],
          bias=f32[144],
          first_time_index=StateIndex(marker=18, init=None),
          state_index=StateIndex(marker=19, init=None),
          axis_name=None,
          inference=None,
          input_size=144,
          eps=0.001,
          channelwise_affine=True,
          momentum=0.99
        ),
        None,
        DepthwiseConv2D(
          in_channels=144,
          out_channels=144,
          kernel_size=(1, 1),
          stride=(2, 2),
          padding='VALID',
          depth_multiplier=1,
          use_bias=False,
          groups=144,
          key=u32[2],
          kernel=f32[144,144,1,1],
          bias=f32[144]
        ),
        BatchNorm(
          weight=f32[144],
          bias=f32[144],
          first_time_index=StateIndex(marker=20, init=None),
          state_index=StateIndex(marker=21, init=None),
          axis_name=None,
          inference=None,
          input_size=144,
          eps=0.001,
          channelwise_affine=True,
          momentum=0.99
        ),
        None,
        Conv2d(
          num_spatial_dims=2,
          weight=f32[32,144,1,1],
          bias=None,
          in_channels=144,
          out_channels=32,
          kernel_size=(1, 1),
          stride=(1, 1),
          padding=((0, 0), (0, 0)),
          dilation=(1, 1),
          groups=1,
          use_bias=False,
          padding_mode='ZEROS'
        ),
        BatchNorm(
          weight=f32[32],
          bias=f32[32],
          first_time_index=StateIndex(marker=22, init=None),
          state_index=StateIndex(marker=23, init=None),
          axis_name=None,
          inference=None,
          input_size=32,
          eps=0.001,
          channelwise_affine=True,
          momentum=0.99
        )
      ]
    ),
    InvertedResidualBlock(
      in_channels=32,
      expansion=6,
      stride=1,
      alpha=1.0,
      filters=32,
      pw_filters=32,
      block_id=4,
      layers=[
        Conv2d(
          num_spatial_dims=2,
          weight=f32[192,32,1,1],
          bias=None,
          in_channels=32,
          out_channels=192,
          kernel_size=(1, 1),
          stride=(1, 1),
          padding=((0, 0), (0, 0)),
          dilation=(1, 1),
          groups=1,
          use_bias=False,
          padding_mode='ZEROS'
        ),
        BatchNorm(
          weight=f32[192],
          bias=f32[192],
          first_time_index=StateIndex(marker=24, init=None),
          state_index=StateIndex(marker=25, init=None),
          axis_name=None,
          inference=None,
          input_size=192,
          eps=0.001,
          channelwise_affine=True,
          momentum=0.99
        ),
        None,
        DepthwiseConv2D(
          in_channels=192,
          out_channels=192,
          kernel_size=(1, 1),
          stride=(1, 1),
          padding='VALID',
          depth_multiplier=1,
          use_bias=False,
          groups=192,
          key=u32[2],
          kernel=f32[192,192,1,1],
          bias=f32[192]
        ),
        BatchNorm(
          weight=f32[192],
          bias=f32[192],
          first_time_index=StateIndex(marker=26, init=None),
          state_index=StateIndex(marker=27, init=None),
          axis_name=None,
          inference=None,
          input_size=192,
          eps=0.001,
          channelwise_affine=True,
          momentum=0.99
        ),
        None,
        Conv2d(
          num_spatial_dims=2,
          weight=f32[32,192,1,1],
          bias=None,
          in_channels=192,
          out_channels=32,
          kernel_size=(1, 1),
          stride=(1, 1),
          padding=((0, 0), (0, 0)),
          dilation=(1, 1),
          groups=1,
          use_bias=False,
          padding_mode='ZEROS'
        ),
        BatchNorm(
          weight=f32[32],
          bias=f32[32],
          first_time_index=StateIndex(marker=28, init=None),
          state_index=StateIndex(marker=29, init=None),
          axis_name=None,
          inference=None,
          input_size=32,
          eps=0.001,
          channelwise_affine=True,
          momentum=0.99
        )
      ]
    ),
    InvertedResidualBlock(
      in_channels=32,
      expansion=6,
      stride=1,
      alpha=1.0,
      filters=32,
      pw_filters=32,
      block_id=5,
      layers=[
        Conv2d(
          num_spatial_dims=2,
          weight=f32[192,32,1,1],
          bias=None,
          in_channels=32,
          out_channels=192,
          kernel_size=(1, 1),
          stride=(1, 1),
          padding=((0, 0), (0, 0)),
          dilation=(1, 1),
          groups=1,
          use_bias=False,
          padding_mode='ZEROS'
        ),
        BatchNorm(
          weight=f32[192],
          bias=f32[192],
          first_time_index=StateIndex(marker=30, init=None),
          state_index=StateIndex(marker=31, init=None),
          axis_name=None,
          inference=None,
          input_size=192,
          eps=0.001,
          channelwise_affine=True,
          momentum=0.99
        ),
        None,
        DepthwiseConv2D(
          in_channels=192,
          out_channels=192,
          kernel_size=(1, 1),
          stride=(1, 1),
          padding='VALID',
          depth_multiplier=1,
          use_bias=False,
          groups=192,
          key=u32[2],
          kernel=f32[192,192,1,1],
          bias=f32[192]
        ),
        BatchNorm(
          weight=f32[192],
          bias=f32[192],
          first_time_index=StateIndex(marker=32, init=None),
          state_index=StateIndex(marker=33, init=None),
          axis_name=None,
          inference=None,
          input_size=192,
          eps=0.001,
          channelwise_affine=True,
          momentum=0.99
        ),
        None,
        Conv2d(
          num_spatial_dims=2,
          weight=f32[32,192,1,1],
          bias=None,
          in_channels=192,
          out_channels=32,
          kernel_size=(1, 1),
          stride=(1, 1),
          padding=((0, 0), (0, 0)),
          dilation=(1, 1),
          groups=1,
          use_bias=False,
          padding_mode='ZEROS'
        ),
        BatchNorm(
          weight=f32[32],
          bias=f32[32],
          first_time_index=StateIndex(marker=34, init=None),
          state_index=StateIndex(marker=35, init=None),
          axis_name=None,
          inference=None,
          input_size=32,
          eps=0.001,
          channelwise_affine=True,
          momentum=0.99
        )
      ]
    ),
    InvertedResidualBlock(
      in_channels=32,
      expansion=6,
      stride=2,
      alpha=1.0,
      filters=64,
      pw_filters=64,
      block_id=6,
      layers=[
        Conv2d(
          num_spatial_dims=2,
          weight=f32[192,32,1,1],
          bias=None,
          in_channels=32,
          out_channels=192,
          kernel_size=(1, 1),
          stride=(1, 1),
          padding=((0, 0), (0, 0)),
          dilation=(1, 1),
          groups=1,
          use_bias=False,
          padding_mode='ZEROS'
        ),
        BatchNorm(
          weight=f32[192],
          bias=f32[192],
          first_time_index=StateIndex(marker=36, init=None),
          state_index=StateIndex(marker=37, init=None),
          axis_name=None,
          inference=None,
          input_size=192,
          eps=0.001,
          channelwise_affine=True,
          momentum=0.99
        ),
        None,
        DepthwiseConv2D(
          in_channels=192,
          out_channels=192,
          kernel_size=(1, 1),
          stride=(2, 2),
          padding='VALID',
          depth_multiplier=1,
          use_bias=False,
          groups=192,
          key=u32[2],
          kernel=f32[192,192,1,1],
          bias=f32[192]
        ),
        BatchNorm(
          weight=f32[192],
          bias=f32[192],
          first_time_index=StateIndex(marker=38, init=None),
          state_index=StateIndex(marker=39, init=None),
          axis_name=None,
          inference=None,
          input_size=192,
          eps=0.001,
          channelwise_affine=True,
          momentum=0.99
        ),
        None,
        Conv2d(
          num_spatial_dims=2,
          weight=f32[64,192,1,1],
          bias=None,
          in_channels=192,
          out_channels=64,
          kernel_size=(1, 1),
          stride=(1, 1),
          padding=((0, 0), (0, 0)),
          dilation=(1, 1),
          groups=1,
          use_bias=False,
          padding_mode='ZEROS'
        ),
        BatchNorm(
          weight=f32[64],
          bias=f32[64],
          first_time_index=StateIndex(marker=40, init=None),
          state_index=StateIndex(marker=41, init=None),
          axis_name=None,
          inference=None,
          input_size=64,
          eps=0.001,
          channelwise_affine=True,
          momentum=0.99
        )
      ]
    ),
    InvertedResidualBlock(
      in_channels=64,
      expansion=6,
      stride=1,
      alpha=1.0,
      filters=64,
      pw_filters=64,
      block_id=7,
      layers=[
        Conv2d(
          num_spatial_dims=2,
          weight=f32[384,64,1,1],
          bias=None,
          in_channels=64,
          out_channels=384,
          kernel_size=(1, 1),
          stride=(1, 1),
          padding=((0, 0), (0, 0)),
          dilation=(1, 1),
          groups=1,
          use_bias=False,
          padding_mode='ZEROS'
        ),
        BatchNorm(
          weight=f32[384],
          bias=f32[384],
          first_time_index=StateIndex(marker=42, init=None),
          state_index=StateIndex(marker=43, init=None),
          axis_name=None,
          inference=None,
          input_size=384,
          eps=0.001,
          channelwise_affine=True,
          momentum=0.99
        ),
        None,
        DepthwiseConv2D(
          in_channels=384,
          out_channels=384,
          kernel_size=(1, 1),
          stride=(1, 1),
          padding='VALID',
          depth_multiplier=1,
          use_bias=False,
          groups=384,
          key=u32[2],
          kernel=f32[384,384,1,1],
          bias=f32[384]
        ),
        BatchNorm(
          weight=f32[384],
          bias=f32[384],
          first_time_index=StateIndex(marker=44, init=None),
          state_index=StateIndex(marker=45, init=None),
          axis_name=None,
          inference=None,
          input_size=384,
          eps=0.001,
          channelwise_affine=True,
          momentum=0.99
        ),
        None,
        Conv2d(
          num_spatial_dims=2,
          weight=f32[64,384,1,1],
          bias=None,
          in_channels=384,
          out_channels=64,
          kernel_size=(1, 1),
          stride=(1, 1),
          padding=((0, 0), (0, 0)),
          dilation=(1, 1),
          groups=1,
          use_bias=False,
          padding_mode='ZEROS'
        ),
        BatchNorm(
          weight=f32[64],
          bias=f32[64],
          first_time_index=StateIndex(marker=46, init=None),
          state_index=StateIndex(marker=47, init=None),
          axis_name=None,
          inference=None,
          input_size=64,
          eps=0.001,
          channelwise_affine=True,
          momentum=0.99
        )
      ]
    ),
    InvertedResidualBlock(
      in_channels=64,
      expansion=6,
      stride=1,
      alpha=1.0,
      filters=64,
      pw_filters=64,
      block_id=8,
      layers=[
        Conv2d(
          num_spatial_dims=2,
          weight=f32[384,64,1,1],
          bias=None,
          in_channels=64,
          out_channels=384,
          kernel_size=(1, 1),
          stride=(1, 1),
          padding=((0, 0), (0, 0)),
          dilation=(1, 1),
          groups=1,
          use_bias=False,
          padding_mode='ZEROS'
        ),
        BatchNorm(
          weight=f32[384],
          bias=f32[384],
          first_time_index=StateIndex(marker=48, init=None),
          state_index=StateIndex(marker=49, init=None),
          axis_name=None,
          inference=None,
          input_size=384,
          eps=0.001,
          channelwise_affine=True,
          momentum=0.99
        ),
        None,
        DepthwiseConv2D(
          in_channels=384,
          out_channels=384,
          kernel_size=(1, 1),
          stride=(1, 1),
          padding='VALID',
          depth_multiplier=1,
          use_bias=False,
          groups=384,
          key=u32[2],
          kernel=f32[384,384,1,1],
          bias=f32[384]
        ),
        BatchNorm(
          weight=f32[384],
          bias=f32[384],
          first_time_index=StateIndex(marker=50, init=None),
          state_index=StateIndex(marker=51, init=None),
          axis_name=None,
          inference=None,
          input_size=384,
          eps=0.001,
          channelwise_affine=True,
          momentum=0.99
        ),
        None,
        Conv2d(
          num_spatial_dims=2,
          weight=f32[64,384,1,1],
          bias=None,
          in_channels=384,
          out_channels=64,
          kernel_size=(1, 1),
          stride=(1, 1),
          padding=((0, 0), (0, 0)),
          dilation=(1, 1),
          groups=1,
          use_bias=False,
          padding_mode='ZEROS'
        ),
        BatchNorm(
          weight=f32[64],
          bias=f32[64],
          first_time_index=StateIndex(marker=52, init=None),
          state_index=StateIndex(marker=53, init=None),
          axis_name=None,
          inference=None,
          input_size=64,
          eps=0.001,
          channelwise_affine=True,
          momentum=0.99
        )
      ]
    ),
    InvertedResidualBlock(
      in_channels=64,
      expansion=6,
      stride=1,
      alpha=1.0,
      filters=64,
      pw_filters=64,
      block_id=9,
      layers=[
        Conv2d(
          num_spatial_dims=2,
          weight=f32[384,64,1,1],
          bias=None,
          in_channels=64,
          out_channels=384,
          kernel_size=(1, 1),
          stride=(1, 1),
          padding=((0, 0), (0, 0)),
          dilation=(1, 1),
          groups=1,
          use_bias=False,
          padding_mode='ZEROS'
        ),
        BatchNorm(
          weight=f32[384],
          bias=f32[384],
          first_time_index=StateIndex(marker=54, init=None),
          state_index=StateIndex(marker=55, init=None),
          axis_name=None,
          inference=None,
          input_size=384,
          eps=0.001,
          channelwise_affine=True,
          momentum=0.99
        ),
        None,
        DepthwiseConv2D(
          in_channels=384,
          out_channels=384,
          kernel_size=(1, 1),
          stride=(1, 1),
          padding='VALID',
          depth_multiplier=1,
          use_bias=False,
          groups=384,
          key=u32[2],
          kernel=f32[384,384,1,1],
          bias=f32[384]
        ),
        BatchNorm(
          weight=f32[384],
          bias=f32[384],
          first_time_index=StateIndex(marker=56, init=None),
          state_index=StateIndex(marker=57, init=None),
          axis_name=None,
          inference=None,
          input_size=384,
          eps=0.001,
          channelwise_affine=True,
          momentum=0.99
        ),
        None,
        Conv2d(
          num_spatial_dims=2,
          weight=f32[64,384,1,1],
          bias=None,
          in_channels=384,
          out_channels=64,
          kernel_size=(1, 1),
          stride=(1, 1),
          padding=((0, 0), (0, 0)),
          dilation=(1, 1),
          groups=1,
          use_bias=False,
          padding_mode='ZEROS'
        ),
        BatchNorm(
          weight=f32[64],
          bias=f32[64],
          first_time_index=StateIndex(marker=58, init=None),
          state_index=StateIndex(marker=59, init=None),
          axis_name=None,
          inference=None,
          input_size=64,
          eps=0.001,
          channelwise_affine=True,
          momentum=0.99
        )
      ]
    ),
    InvertedResidualBlock(
      in_channels=64,
      expansion=6,
      stride=1,
      alpha=1.0,
      filters=96,
      pw_filters=96,
      block_id=10,
      layers=[
        Conv2d(
          num_spatial_dims=2,
          weight=f32[384,64,1,1],
          bias=None,
          in_channels=64,
          out_channels=384,
          kernel_size=(1, 1),
          stride=(1, 1),
          padding=((0, 0), (0, 0)),
          dilation=(1, 1),
          groups=1,
          use_bias=False,
          padding_mode='ZEROS'
        ),
        BatchNorm(
          weight=f32[384],
          bias=f32[384],
          first_time_index=StateIndex(marker=60, init=None),
          state_index=StateIndex(marker=61, init=None),
          axis_name=None,
          inference=None,
          input_size=384,
          eps=0.001,
          channelwise_affine=True,
          momentum=0.99
        ),
        None,
        DepthwiseConv2D(
          in_channels=384,
          out_channels=384,
          kernel_size=(1, 1),
          stride=(1, 1),
          padding='VALID',
          depth_multiplier=1,
          use_bias=False,
          groups=384,
          key=u32[2],
          kernel=f32[384,384,1,1],
          bias=f32[384]
        ),
        BatchNorm(
          weight=f32[384],
          bias=f32[384],
          first_time_index=StateIndex(marker=62, init=None),
          state_index=StateIndex(marker=63, init=None),
          axis_name=None,
          inference=None,
          input_size=384,
          eps=0.001,
          channelwise_affine=True,
          momentum=0.99
        ),
        None,
        Conv2d(
          num_spatial_dims=2,
          weight=f32[96,384,1,1],
          bias=None,
          in_channels=384,
          out_channels=96,
          kernel_size=(1, 1),
          stride=(1, 1),
          padding=((0, 0), (0, 0)),
          dilation=(1, 1),
          groups=1,
          use_bias=False,
          padding_mode='ZEROS'
        ),
        BatchNorm(
          weight=f32[96],
          bias=f32[96],
          first_time_index=StateIndex(marker=64, init=None),
          state_index=StateIndex(marker=65, init=None),
          axis_name=None,
          inference=None,
          input_size=96,
          eps=0.001,
          channelwise_affine=True,
          momentum=0.99
        )
      ]
    ),
    InvertedResidualBlock(
      in_channels=96,
      expansion=6,
      stride=1,
      alpha=1.0,
      filters=96,
      pw_filters=96,
      block_id=11,
      layers=[
        Conv2d(
          num_spatial_dims=2,
          weight=f32[576,96,1,1],
          bias=None,
          in_channels=96,
          out_channels=576,
          kernel_size=(1, 1),
          stride=(1, 1),
          padding=((0, 0), (0, 0)),
          dilation=(1, 1),
          groups=1,
          use_bias=False,
          padding_mode='ZEROS'
        ),
        BatchNorm(
          weight=f32[576],
          bias=f32[576],
          first_time_index=StateIndex(marker=66, init=None),
          state_index=StateIndex(marker=67, init=None),
          axis_name=None,
          inference=None,
          input_size=576,
          eps=0.001,
          channelwise_affine=True,
          momentum=0.99
        ),
        None,
        DepthwiseConv2D(
          in_channels=576,
          out_channels=576,
          kernel_size=(1, 1),
          stride=(1, 1),
          padding='VALID',
          depth_multiplier=1,
          use_bias=False,
          groups=576,
          key=u32[2],
          kernel=f32[576,576,1,1],
          bias=f32[576]
        ),
        BatchNorm(
          weight=f32[576],
          bias=f32[576],
          first_time_index=StateIndex(marker=68, init=None),
          state_index=StateIndex(marker=69, init=None),
          axis_name=None,
          inference=None,
          input_size=576,
          eps=0.001,
          channelwise_affine=True,
          momentum=0.99
        ),
        None,
        Conv2d(
          num_spatial_dims=2,
          weight=f32[96,576,1,1],
          bias=None,
          in_channels=576,
          out_channels=96,
          kernel_size=(1, 1),
          stride=(1, 1),
          padding=((0, 0), (0, 0)),
          dilation=(1, 1),
          groups=1,
          use_bias=False,
          padding_mode='ZEROS'
        ),
        BatchNorm(
          weight=f32[96],
          bias=f32[96],
          first_time_index=StateIndex(marker=70, init=None),
          state_index=StateIndex(marker=71, init=None),
          axis_name=None,
          inference=None,
          input_size=96,
          eps=0.001,
          channelwise_affine=True,
          momentum=0.99
        )
      ]
    ),
    InvertedResidualBlock(
      in_channels=96,
      expansion=6,
      stride=1,
      alpha=1.0,
      filters=96,
      pw_filters=96,
      block_id=12,
      layers=[
        Conv2d(
          num_spatial_dims=2,
          weight=f32[576,96,1,1],
          bias=None,
          in_channels=96,
          out_channels=576,
          kernel_size=(1, 1),
          stride=(1, 1),
          padding=((0, 0), (0, 0)),
          dilation=(1, 1),
          groups=1,
          use_bias=False,
          padding_mode='ZEROS'
        ),
        BatchNorm(
          weight=f32[576],
          bias=f32[576],
          first_time_index=StateIndex(marker=72, init=None),
          state_index=StateIndex(marker=73, init=None),
          axis_name=None,
          inference=None,
          input_size=576,
          eps=0.001,
          channelwise_affine=True,
          momentum=0.99
        ),
        None,
        DepthwiseConv2D(
          in_channels=576,
          out_channels=576,
          kernel_size=(1, 1),
          stride=(1, 1),
          padding='VALID',
          depth_multiplier=1,
          use_bias=False,
          groups=576,
          key=u32[2],
          kernel=f32[576,576,1,1],
          bias=f32[576]
        ),
        BatchNorm(
          weight=f32[576],
          bias=f32[576],
          first_time_index=StateIndex(marker=74, init=None),
          state_index=StateIndex(marker=75, init=None),
          axis_name=None,
          inference=None,
          input_size=576,
          eps=0.001,
          channelwise_affine=True,
          momentum=0.99
        ),
        None,
        Conv2d(
          num_spatial_dims=2,
          weight=f32[96,576,1,1],
          bias=None,
          in_channels=576,
          out_channels=96,
          kernel_size=(1, 1),
          stride=(1, 1),
          padding=((0, 0), (0, 0)),
          dilation=(1, 1),
          groups=1,
          use_bias=False,
          padding_mode='ZEROS'
        ),
        BatchNorm(
          weight=f32[96],
          bias=f32[96],
          first_time_index=StateIndex(marker=76, init=None),
          state_index=StateIndex(marker=77, init=None),
          axis_name=None,
          inference=None,
          input_size=96,
          eps=0.001,
          channelwise_affine=True,
          momentum=0.99
        )
      ]
    ),
    InvertedResidualBlock(
      in_channels=96,
      expansion=6,
      stride=2,
      alpha=1.0,
      filters=160,
      pw_filters=160,
      block_id=13,
      layers=[
        Conv2d(
          num_spatial_dims=2,
          weight=f32[576,96,1,1],
          bias=None,
          in_channels=96,
          out_channels=576,
          kernel_size=(1, 1),
          stride=(1, 1),
          padding=((0, 0), (0, 0)),
          dilation=(1, 1),
          groups=1,
          use_bias=False,
          padding_mode='ZEROS'
        ),
        BatchNorm(
          weight=f32[576],
          bias=f32[576],
          first_time_index=StateIndex(marker=78, init=None),
          state_index=StateIndex(marker=79, init=None),
          axis_name=None,
          inference=None,
          input_size=576,
          eps=0.001,
          channelwise_affine=True,
          momentum=0.99
        ),
        None,
        DepthwiseConv2D(
          in_channels=576,
          out_channels=576,
          kernel_size=(1, 1),
          stride=(2, 2),
          padding='VALID',
          depth_multiplier=1,
          use_bias=False,
          groups=576,
          key=u32[2],
          kernel=f32[576,576,1,1],
          bias=f32[576]
        ),
        BatchNorm(
          weight=f32[576],
          bias=f32[576],
          first_time_index=StateIndex(marker=80, init=None),
          state_index=StateIndex(marker=81, init=None),
          axis_name=None,
          inference=None,
          input_size=576,
          eps=0.001,
          channelwise_affine=True,
          momentum=0.99
        ),
        None,
        Conv2d(
          num_spatial_dims=2,
          weight=f32[160,576,1,1],
          bias=None,
          in_channels=576,
          out_channels=160,
          kernel_size=(1, 1),
          stride=(1, 1),
          padding=((0, 0), (0, 0)),
          dilation=(1, 1),
          groups=1,
          use_bias=False,
          padding_mode='ZEROS'
        ),
        BatchNorm(
          weight=f32[160],
          bias=f32[160],
          first_time_index=StateIndex(marker=82, init=None),
          state_index=StateIndex(marker=83, init=None),
          axis_name=None,
          inference=None,
          input_size=160,
          eps=0.001,
          channelwise_affine=True,
          momentum=0.99
        )
      ]
    ),
    InvertedResidualBlock(
      in_channels=160,
      expansion=6,
      stride=1,
      alpha=1.0,
      filters=160,
      pw_filters=160,
      block_id=14,
      layers=[
        Conv2d(
          num_spatial_dims=2,
          weight=f32[960,160,1,1],
          bias=None,
          in_channels=160,
          out_channels=960,
          kernel_size=(1, 1),
          stride=(1, 1),
          padding=((0, 0), (0, 0)),
          dilation=(1, 1),
          groups=1,
          use_bias=False,
          padding_mode='ZEROS'
        ),
        BatchNorm(
          weight=f32[960],
          bias=f32[960],
          first_time_index=StateIndex(marker=84, init=None),
          state_index=StateIndex(marker=85, init=None),
          axis_name=None,
          inference=None,
          input_size=960,
          eps=0.001,
          channelwise_affine=True,
          momentum=0.99
        ),
        None,
        DepthwiseConv2D(
          in_channels=960,
          out_channels=960,
          kernel_size=(1, 1),
          stride=(1, 1),
          padding='VALID',
          depth_multiplier=1,
          use_bias=False,
          groups=960,
          key=u32[2],
          kernel=f32[960,960,1,1],
          bias=f32[960]
        ),
        BatchNorm(
          weight=f32[960],
          bias=f32[960],
          first_time_index=StateIndex(marker=86, init=None),
          state_index=StateIndex(marker=87, init=None),
          axis_name=None,
          inference=None,
          input_size=960,
          eps=0.001,
          channelwise_affine=True,
          momentum=0.99
        ),
        None,
        Conv2d(
          num_spatial_dims=2,
          weight=f32[160,960,1,1],
          bias=None,
          in_channels=960,
          out_channels=160,
          kernel_size=(1, 1),
          stride=(1, 1),
          padding=((0, 0), (0, 0)),
          dilation=(1, 1),
          groups=1,
          use_bias=False,
          padding_mode='ZEROS'
        ),
        BatchNorm(
          weight=f32[160],
          bias=f32[160],
          first_time_index=StateIndex(marker=88, init=None),
          state_index=StateIndex(marker=89, init=None),
          axis_name=None,
          inference=None,
          input_size=160,
          eps=0.001,
          channelwise_affine=True,
          momentum=0.99
        )
      ]
    ),
    InvertedResidualBlock(
      in_channels=160,
      expansion=6,
      stride=1,
      alpha=1.0,
      filters=160,
      pw_filters=160,
      block_id=15,
      layers=[
        Conv2d(
          num_spatial_dims=2,
          weight=f32[960,160,1,1],
          bias=None,
          in_channels=160,
          out_channels=960,
          kernel_size=(1, 1),
          stride=(1, 1),
          padding=((0, 0), (0, 0)),
          dilation=(1, 1),
          groups=1,
          use_bias=False,
          padding_mode='ZEROS'
        ),
        BatchNorm(
          weight=f32[960],
          bias=f32[960],
          first_time_index=StateIndex(marker=90, init=None),
          state_index=StateIndex(marker=91, init=None),
          axis_name=None,
          inference=None,
          input_size=960,
          eps=0.001,
          channelwise_affine=True,
          momentum=0.99
        ),
        None,
        DepthwiseConv2D(
          in_channels=960,
          out_channels=960,
          kernel_size=(1, 1),
          stride=(1, 1),
          padding='VALID',
          depth_multiplier=1,
          use_bias=False,
          groups=960,
          key=u32[2],
          kernel=f32[960,960,1,1],
          bias=f32[960]
        ),
        BatchNorm(
          weight=f32[960],
          bias=f32[960],
          first_time_index=StateIndex(marker=92, init=None),
          state_index=StateIndex(marker=93, init=None),
          axis_name=None,
          inference=None,
          input_size=960,
          eps=0.001,
          channelwise_affine=True,
          momentum=0.99
        ),
        None,
        Conv2d(
          num_spatial_dims=2,
          weight=f32[160,960,1,1],
          bias=None,
          in_channels=960,
          out_channels=160,
          kernel_size=(1, 1),
          stride=(1, 1),
          padding=((0, 0), (0, 0)),
          dilation=(1, 1),
          groups=1,
          use_bias=False,
          padding_mode='ZEROS'
        ),
        BatchNorm(
          weight=f32[160],
          bias=f32[160],
          first_time_index=StateIndex(marker=94, init=None),
          state_index=StateIndex(marker=95, init=None),
          axis_name=None,
          inference=None,
          input_size=160,
          eps=0.001,
          channelwise_affine=True,
          momentum=0.99
        )
      ]
    ),
    InvertedResidualBlock(
      in_channels=160,
      expansion=6,
      stride=1,
      alpha=1.0,
      filters=320,
      pw_filters=320,
      block_id=16,
      layers=[
        Conv2d(
          num_spatial_dims=2,
          weight=f32[960,160,1,1],
          bias=None,
          in_channels=160,
          out_channels=960,
          kernel_size=(1, 1),
          stride=(1, 1),
          padding=((0, 0), (0, 0)),
          dilation=(1, 1),
          groups=1,
          use_bias=False,
          padding_mode='ZEROS'
        ),
        BatchNorm(
          weight=f32[960],
          bias=f32[960],
          first_time_index=StateIndex(marker=96, init=None),
          state_index=StateIndex(marker=97, init=None),
          axis_name=None,
          inference=None,
          input_size=960,
          eps=0.001,
          channelwise_affine=True,
          momentum=0.99
        ),
        None,
        DepthwiseConv2D(
          in_channels=960,
          out_channels=960,
          kernel_size=(1, 1),
          stride=(1, 1),
          padding='VALID',
          depth_multiplier=1,
          use_bias=False,
          groups=960,
          key=u32[2],
          kernel=f32[960,960,1,1],
          bias=f32[960]
        ),
        BatchNorm(
          weight=f32[960],
          bias=f32[960],
          first_time_index=StateIndex(marker=98, init=None),
          state_index=StateIndex(marker=99, init=None),
          axis_name=None,
          inference=None,
          input_size=960,
          eps=0.001,
          channelwise_affine=True,
          momentum=0.99
        ),
        None,
        Conv2d(
          num_spatial_dims=2,
          weight=f32[320,960,1,1],
          bias=None,
          in_channels=960,
          out_channels=320,
          kernel_size=(1, 1),
          stride=(1, 1),
          padding=((0, 0), (0, 0)),
          dilation=(1, 1),
          groups=1,
          use_bias=False,
          padding_mode='ZEROS'
        ),
        BatchNorm(
          weight=f32[320],
          bias=f32[320],
          first_time_index=StateIndex(marker=100, init=None),
          state_index=StateIndex(marker=101, init=None),
          axis_name=None,
          inference=None,
          input_size=320,
          eps=0.001,
          channelwise_affine=True,
          momentum=0.99
        )
      ]
    ),
    Conv2d(
      num_spatial_dims=2,
      weight=f32[1280,320,1,1],
      bias=None,
      in_channels=320,
      out_channels=1280,
      kernel_size=(1, 1),
      stride=(1, 1),
      padding=((0, 0), (0, 0)),
      dilation=(1, 1),
      groups=1,
      use_bias=False,
      padding_mode='ZEROS'
    ),
    BatchNorm(
      weight=f32[1280],
      bias=f32[1280],
      first_time_index=StateIndex(marker=102, init=None),
      state_index=StateIndex(marker=103, init=None),
      axis_name=None,
      inference=None,
      input_size=1280,
      eps=0.001,
      channelwise_affine=True,
      momentum=0.999
    ),
    None,
    AvgPool2d(
      init=None,
      operation=None,
      num_spatial_dims=2,
      kernel_size=(7, 7),
      stride=(1, 1),
      padding=((0, 0), (0, 0)),
      use_ceil=False
    ),
    None,
    Linear(
      weight=f32[10,81920],
      bias=f32[10],
      in_features=81920,
      out_features=10,
      use_bias=True
    ),
    None
  ]
).

#### Oliver's model below here

In [None]:
model = MobileNetV2(in_channels=1, num_classes=10, key=key)
print(model)

In [None]:
# For MobileNetV2

def loss(
    model: MobileNetV2, x: Float[Array, "batch 1 28 28"], y: Int[Array, " batch"]
) -> Float[Array, ""]:
    # Our input has the shape (BATCH_SIZE, 1, 28, 28), but our model operations on
    # a single input input image of shape (1, 28, 28).
    #
    # Therefore, we have to use jax.vmap, which in this case maps our model over the
    # leading (batch) axis.
    pred_y = jax.vmap(model)(x)
    return cross_entropy(y, pred_y)


def cross_entropy(
    y: Int[Array, " batch"], pred_y: Float[Array, "batch 10"]
) -> Float[Array, ""]:
    # y are the true targets, and should be integers 0-9.
    # pred_y are the log-softmax'd predictions.
    pred_y = jnp.take_along_axis(pred_y, jnp.expand_dims(y, 1), axis=1)
    return -jnp.mean(pred_y)


# Example loss
loss_value = loss(model, dummy_x, dummy_y)
print(loss_value.shape)  # scalar loss
# Example inference
output = jax.vmap(model)(dummy_x)
print(output.shape)  # batch of predictions

## Results