# Accelerating [MPT-1B](https://huggingface.co/mosaicml/mpt-1b-redpajama-200b) with JAX

Accelerate MosaicML's MPT-1B by converting it to JAX for faster inference.  
Note: This notebook was tested on a CPU server with 112 GB of RAM. The minimum memory requirement is 64 GB, so it may not run on Colab.

⚠️ If you are running this notebook in Colab, you will have to install `Ivy` and some dependencies manually. You can do so by running the cell below ⬇️

If you want to run the notebook locally but don't have Ivy installed yet, you can check out the [Get Started section of the docs](https://unify.ai/docs/ivy/overview/get_started.html).

In [None]:
!pip install -q ivy
!pip install -q transformers

For the installed packages to be available you will have to restart your kernel. In Colab, you can do this by clicking on **"Runtime > Restart Runtime"**. Once the runtime has been restarted you should skip the previous cell 😄

Let's now import Ivy and the libraries we'll use in this example:

In [1]:
import jax
import ivy
import numpy as np

import transformers

jax.config.update('jax_enable_x64', True)
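By default JAX works with 32-bit types and downcasts 64-bit inputs, whereas PyTorch token IDs are usually `int64`. The notebook does not state why it enables `jax_enable_x64`, so treat the following as a minimal sketch of what the flag changes, using made-up token IDs rather than real model inputs:

```python
import jax
import jax.numpy as jnp
import numpy as np

# Enable 64-bit types before creating any JAX arrays.
jax.config.update("jax_enable_x64", True)

# Hypothetical token IDs, stored as int64 the way PyTorch usually stores them.
token_ids = np.array([[0, 1, 2, 3]], dtype=np.int64)

jax_ids = jnp.asarray(token_ids)
print(jax_ids.dtype)  # int64 with x64 enabled; without the flag JAX would use int32
```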

Now we can load the MPT model from the Hugging Face `transformers` library:

In [None]:
model = transformers.AutoModelForCausalLM.from_pretrained(
  'mosaicml/mpt-1b-redpajama-200b',
  trust_remote_code=True
)

We will also need a sample input to pass during tracing, so let's use the appropriate model methods to get the dummy tensors.

In [3]:
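# Name of the model's primary input (avoid shadowing the built-in `id`)
input_name = model.main_input_name
# Expand the model's built-in dummy inputs; the first element of the result holds the expanded tensor
expanded_dummy = model._expand_inputs_for_generation(expand_size=1, input_ids=model.dummy_inputs.get(input_name))
model_input = {input_name: expanded_dummy[0]}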

And finally, let's transpile the model to JAX!

In [None]:
transpiled_graph = ivy.transpile(model.__call__, source="torch", to="jax", kwargs=model_input)

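Let's now apply JAX just-in-time (JIT) compilation. First we convert the PyTorch input tensors to JAX arrays and wrap the transpiled graph in a function that returns only the logits: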

In [6]:
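# Convert each PyTorch input tensor to a JAX array
jit_inputs = {}
for key, value in model_input.items():
    jit_inputs[key] = jax.numpy.array(value.cpu().numpy())

# Wrapper that takes the input dict and returns only the logits
def fn(x):
    return transpiled_graph(**x).logits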

In [7]:
jitted = jax.jit(fn)
_ = jitted(jit_inputs)
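The first call to a jitted function traces and compiles it, which is why the cell above runs `jitted` once before anything is timed. The sketch below illustrates this with a toy function standing in for the transpiled model; `toy_fn` and `toy_inputs` are hypothetical names used only for this illustration:

```python
import time

import jax
import jax.numpy as jnp

def toy_fn(inputs):
    # Stand-in for the transpiled model: takes a dict, returns an array.
    x = inputs["input_ids"].astype(jnp.float32)
    return jnp.tanh(x @ x.T)

toy_inputs = {"input_ids": jnp.ones((4, 8), dtype=jnp.int32)}
toy_jitted = jax.jit(toy_fn)

start = time.perf_counter()
first = toy_jitted(toy_inputs).block_until_ready()   # traces and compiles
first_call_s = time.perf_counter() - start

start = time.perf_counter()
second = toy_jitted(toy_inputs).block_until_ready()  # reuses the compiled code
second_call_s = time.perf_counter() - start

print(f"first call:  {first_call_s:.4f} s")
print(f"second call: {second_call_s:.4f} s")
```

`block_until_ready()` matters here because JAX dispatches work asynchronously; without it the timer could stop before the computation has finished.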

Now that we have both models, let's see how their runtime speeds compare to each other!

In [8]:
%time _ = model(**model_input).logits

CPU times: user 1.97 s, sys: 4.25 ms, total: 1.97 s
Wall time: 369 ms


In [9]:
%time _ = jitted(jit_inputs).block_until_ready()

CPU times: user 1.66 s, sys: 29.9 ms, total: 1.69 s
Wall time: 344 ms


As expected, the model is now faster (344 ms vs. 369 ms wall time on this run), and the transpilation itself took just one line of code! 🚀  
Note: The numbers above are from a CPU server. We recommend trying this on a GPU, which is where transpiling yields most of the latency speedups. 😊
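A single `%time` measurement can be noisy, so a gap as small as 369 ms vs. 344 ms is easier to trust when it is measured over several runs. The helper below is a minimal sketch of one way to do that; `benchmark` is a hypothetical function written for this illustration, not part of Ivy, JAX, or the notebook:

```python
import statistics
import time

def benchmark(fn, *, warmup=2, repeats=10):
    """Return the median and minimum wall time of fn() in seconds."""
    for _ in range(warmup):
        fn()  # discard warm-up runs (caches, lazy initialisation, JIT)
    timings = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        timings.append(time.perf_counter() - start)
    return {"median_s": statistics.median(timings), "min_s": min(timings)}

# Stand-in workload so the sketch runs on its own.
result = benchmark(lambda: sum(i * i for i in range(10_000)))
print(result)
```

In this notebook, the same helper could be called with `lambda: model(**model_input).logits` and `lambda: jitted(jit_inputs).block_until_ready()` to compare the two models.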

Finally, as a sanity check, let's rebuild the dummy input and make sure both models produce the same logits:

In [12]:
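# Rebuild the dummy input for the PyTorch model (avoid shadowing the built-in `id`)
input_name = model.main_input_name
expanded_dummy = model._expand_inputs_for_generation(expand_size=1, input_ids=model.dummy_inputs.get(input_name))
model_input = {input_name: expanded_dummy[0]}

# Run both models; jit_inputs holds the same dummy values as JAX arrays
out_torch = model(**model_input).logits
out_jax = jitted(jit_inputs).block_until_ready()

# Compare the logits element-wise within an absolute tolerance of 1e-4
np.allclose(out_torch.detach().cpu().numpy(), out_jax, atol=1e-4)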

True
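`np.allclose` only returns a boolean, so it can help to also look at how large the differences actually are. The sketch below uses synthetic arrays in place of the real logits; `compare_outputs` is a hypothetical helper, and the `rtol` default mirrors NumPy's own default of `1e-5`:

```python
import numpy as np

def compare_outputs(reference, candidate, atol=1e-4, rtol=1e-5):
    """Summarise how far candidate is from reference."""
    reference = np.asarray(reference, dtype=np.float64)
    candidate = np.asarray(candidate, dtype=np.float64)
    abs_err = np.abs(reference - candidate)
    return {
        "max_abs_err": float(abs_err.max()),
        "mean_abs_err": float(abs_err.mean()),
        # Same criterion as np.allclose: |a - b| <= atol + rtol * |b|
        "allclose": bool(np.allclose(reference, candidate, atol=atol, rtol=rtol)),
    }

# Synthetic stand-ins for the two logits arrays (not real model outputs).
rng = np.random.default_rng(0)
ref = rng.normal(size=(2, 5, 16))
cand = ref + 1e-6
report = compare_outputs(ref, cand)
print(report)
```

With the notebook's variables, the call would be `compare_outputs(out_torch.detach().cpu().numpy(), out_jax)`.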