# Accelerating MMPreTrain models with JAX

Accelerate your MMPreTrain models by converting them to JAX for faster inference.

Install OpenMIM and mmpretrain

In [None]:
!pip install -U -q openmim && mim install -q "mmpretrain>=1.0.0rc8"

⚠️ If you are running this notebook in Colab, you will have to install `Ivy` and some dependencies manually. You can do so by running the cell below ⬇️

If you want to run the notebook locally but don't have Ivy installed just yet, you can check out the [Get Started section of the docs](https://unify.ai/docs/ivy/overview/get_started.html).

In [None]:
!pip install ivy
!pip install dm-haiku
exit()

For the installed packages to be available you will have to restart your kernel. In Colab, you can do this by clicking on **"Runtime > Restart Runtime"**. Once the runtime has been restarted you should skip the previous cell 😄

Let's now import Ivy and the libraries we'll use in this example:

In [2]:
import jax
import ivy
import torch
import requests
import numpy as np
from PIL import Image

import torchvision
from mmpretrain import get_model, list_models
from mmengine import ConfigDict

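As a sanity check, let's confirm that the checkpoint name matches an entry in mmpretrain's [model zoo](https://mmpretrain.readthedocs.io/en/latest/modelzoo_statistics.html#pretrained-models). `list_models` returns the matching model names, so an empty list means the name is wrong: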

In [None]:
checkpoint_name = "convnext-tiny_32xb128-noema_in1k"
list_models(checkpoint_name)

Now we can load the pretrained ConvNeXt model from OpenMMLab's mmpretrain library:

In [None]:
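# Enable 64-bit types in JAX (by default JAX uses 32-bit types)
jax.config.update("jax_enable_x64", True)

# Build the ConvNeXt-Tiny classifier and load its pretrained weights onto the GPU
model = get_model(checkpoint_name, pretrained=True, device='cuda')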

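The transpiler traces the model with example inputs, so we also need a sample image. The helper below reads the input resolution from the model's training pipeline config; we then use it to build matching torchvision transforms and produce the corresponding torch tensor.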

In [None]:
def get_scale(cfg):
    """Recursively search a pipeline config for the first transform that
    defines both a `type` and a `scale`, and return that scale."""
    if isinstance(cfg, dict):  # ConfigDict is a dict subclass
        if cfg.get('type') and cfg.get('scale'):
            return cfg['scale']
        for value in cfg.values():
            input_shape = get_scale(value)
            if input_shape:
                return input_shape
    elif isinstance(cfg, (list, tuple)):
        for block in cfg:
            input_shape = get_scale(block)
            if input_shape:
                return input_shape
    return None
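To see what this search returns, here is a minimal sketch that runs the same depth-first lookup on plain Python dicts. The pipeline below is hypothetical (loosely modelled on mmpretrain-style configs, not copied from the ConvNeXt config), and `find_scale` / `as_size_pair` are illustrative names. Note that the lookup returns the *first* transform that has both `type` and `scale`, and that a `scale` may be a single int or a pair, which matters because the next cell passes `(input_shape, input_shape)` to `Resize`.

```python
def find_scale(cfg):
    """Depth-first search for the first dict that has both 'type' and 'scale'."""
    if isinstance(cfg, dict):
        if cfg.get("type") and cfg.get("scale"):
            return cfg["scale"]
        for value in cfg.values():
            scale = find_scale(value)
            if scale:
                return scale
    elif isinstance(cfg, (list, tuple)):
        for item in cfg:
            scale = find_scale(item)
            if scale:
                return scale
    return None


def as_size_pair(scale):
    """Turn an int scale into a square (size, size) pair; pass pairs through."""
    if isinstance(scale, int):
        return (scale, scale)
    # Pairs are returned unchanged: the axis order ((w, h) vs (h, w)) depends on
    # the library that defined the transform, so check it before reusing a pair.
    return tuple(scale)


# Hypothetical pipeline, for illustration only.
demo_pipeline = [
    {"type": "LoadImageFromFile"},
    {"type": "RandomResizedCrop", "scale": 224},
    {"type": "RandomFlip", "prob": 0.5},
    {"type": "PackInputs"},
]

print(find_scale(demo_pipeline))
print(as_size_pair(find_scale(demo_pipeline)))
```

Because the search recurses through both dicts and lists, it also finds the scale when the pipeline is nested deeper inside a larger config.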

In [5]:
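url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# Read the input resolution from the training pipeline (a single int for this checkpoint)
input_shape = get_scale(model._config.train_pipeline)

# Resize to a square input and convert the PIL image to a float tensor in [0, 1]
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((input_shape, input_shape)),
    torchvision.transforms.ToTensor()
])
# Add a batch dimension and move the tensor to the GPU, next to the model
tensor_image = transform(image).unsqueeze(0).to("cuda")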
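`ToTensor` converts an `H x W x C` uint8 PIL image with values in [0, 255] into a `C x H x W` float tensor scaled to [0, 1], and `unsqueeze(0)` adds the leading batch dimension the model expects; no mean/std normalization is applied in this cell. The NumPy sketch below reproduces just those two steps on a synthetic array (it leaves out the resize), which can help when checking shapes by hand. `to_chw_batch` is a hypothetical helper, not part of torchvision.

```python
import numpy as np


def to_chw_batch(hwc_uint8):
    """Mimic ToTensor() + unsqueeze(0): HWC uint8 -> 1xCxHxW float32 in [0, 1]."""
    chw = np.transpose(hwc_uint8, (2, 0, 1)).astype(np.float32) / 255.0
    return chw[np.newaxis, ...]  # add the batch dimension


# Synthetic 224x224 RGB "image" standing in for the resized COCO photo.
demo_rng = np.random.default_rng(0)
fake_image = demo_rng.integers(0, 256, size=(224, 224, 3), dtype=np.uint8)

batch = to_chw_batch(fake_image)
print(batch.shape, batch.dtype)
```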

Now let's transpile the model to Haiku, passing the sample tensor as the input used for tracing:

In [None]:
transpiled_graph = ivy.transpile(model, to="haiku", args=(tensor_image,))

With the model transpiled, let's measure how much runtime efficiency improves. As a baseline, we first compile the original PyTorch model using `torch.compile`:

In [9]:
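tensor_image = transform(image).unsqueeze(0).to("cuda")

def _f(args):
    return model(args)

# torch.compile compiles lazily: the first call triggers compilation,
# so we run it once here to keep compile time out of the benchmark below
comp_model = torch.compile(_f)
_ = comp_model(tensor_image)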

Let's now apply the equivalent optimization to our new Haiku model using JAX's just-in-time (JIT) compilation:

In [10]:
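import haiku as hk

tensor_image = transform(image).unsqueeze(0).to("cuda")
# The Haiku model consumes NumPy/JAX arrays, so copy the input to the CPU as NumPy
np_image = tensor_image.detach().cpu().numpy()

def _forward(args):
    # Instantiate the transpiled Haiku module and run it on the input
    module = transpiled_graph()
    return module(args)

_forward = jax.jit(_forward)
rng_key = jax.random.PRNGKey(42)
# hk.transform turns the module-based function into a pure (init, apply) pair
jax_mlp_forward = hk.transform(_forward)
# init runs the function once to build the parameter pytree
params = jax_mlp_forward.init(rng=rng_key, args=np_image)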

Now that we have both models optimized, let's see how their runtime speeds compare to each other!


In [11]:
%%timeit
_ = comp_model(tensor_image)

5.46 ms ± 72.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [12]:
%%timeit
out = jax_mlp_forward.apply(params, None, np_image)

2.79 ms ± 133 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


The transpiled JAX model takes about 2.79 ms per call versus 5.46 ms for the compiled PyTorch model, so a single `ivy.transpile` call gave us a ~2x speedup in this run! 🚀
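JAX and PyTorch on CUDA both dispatch work asynchronously, so a timed call can return before the computation on the device has finished. A more conservative measurement warms up first (to exclude compilation) and then blocks on each result: `block_until_ready()` in JAX, `torch.cuda.synchronize()` in PyTorch. The sketch below shows the JAX side with a toy jitted function; `time_jax` is a hypothetical helper and the matrix sizes are arbitrary.

```python
import time

import jax
import jax.numpy as jnp


def time_jax(fn, *args, warmup=3, iters=20):
    """Mean wall-clock seconds per call of `fn(*args)`, excluding compile time.

    Assumes `fn` returns a single JAX array.
    """
    for _ in range(warmup):
        fn(*args).block_until_ready()  # first call compiles; block until done
    start = time.perf_counter()
    for _ in range(iters):
        fn(*args).block_until_ready()  # wait for the device before the next call
    return (time.perf_counter() - start) / iters


@jax.jit
def matmul_relu(x, w):
    # Toy workload standing in for a model forward pass.
    return jax.nn.relu(x @ w)


x_demo = jnp.ones((256, 256), dtype=jnp.float32)
w_demo = jnp.ones((256, 256), dtype=jnp.float32)
seconds = time_jax(matmul_relu, x_demo, w_demo)
print(f"{seconds * 1e3:.3f} ms per call")
```

In this notebook one could try `time_jax(jax.jit(jax_mlp_forward.apply), params, None, jnp.asarray(np_image))`; converting the input once with `jnp.asarray` keeps the host-to-device copy out of the loop. The PyTorch equivalent would call `torch.cuda.synchronize()` after each `comp_model(tensor_image)`.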

Finally, as a sanity check, let's run both models on a different image and make sure their outputs match:

In [13]:
url = "http://images.cocodataset.org/train2017/000000283921.jpg"
image = Image.open(requests.get(url, stream=True).raw)
tensor_image = transform(image).unsqueeze(0).to("cuda")
np_image = tensor_image.detach().cpu().numpy()
out_torch = comp_model(tensor_image)
out_jax = jax_mlp_forward.apply(params, None, np_image)

np.allclose(out_torch.detach().cpu().numpy(), out_jax, atol=1e-4)

True
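`np.allclose(a, b, atol=1e-4)` passes when every element satisfies `|a - b| <= atol + rtol * |b|`, with NumPy's default `rtol=1e-05`. A single boolean hides how close the outputs actually are, so it can help to also report the largest absolute difference and whether both models pick the same top class. The sketch below does that on synthetic logits; `compare_logits` is a hypothetical helper, and in the notebook you would pass `out_torch.detach().cpu().numpy()` and `np.asarray(out_jax)`.

```python
import numpy as np


def compare_logits(a, b, atol=1e-4):
    """Summarize how closely two (batch, classes) logit arrays agree."""
    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    return {
        "allclose": bool(np.allclose(a, b, atol=atol)),
        "max_abs_diff": float(np.max(np.abs(a - b))),
        "same_top1": bool(np.array_equal(a.argmax(axis=-1), b.argmax(axis=-1))),
    }


# Synthetic logits: `logits_b` is `logits_a` plus tiny noise, standing in for the two models.
cmp_rng = np.random.default_rng(0)
logits_a = cmp_rng.normal(size=(1, 1000)).astype(np.float32)
logits_b = logits_a + cmp_rng.normal(scale=1e-6, size=logits_a.shape).astype(np.float32)

print(compare_logits(logits_a, logits_b))
```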

That's pretty much it! Both models produce the same outputs (within an absolute tolerance of `1e-4`), and by converting the model to JAX with Ivy's transpiler we achieved a solid speedup!