# GPT-J-6B Inference Demo

<a href="http://colab.research.google.com/github/kingoflolz/mesh-transformer-jax/blob/master/colab_demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This notebook demonstrates how to run the [GPT-J-6B model](https://github.com/kingoflolz/mesh-transformer-jax/#GPT-J-6B). See the link for more details about the model, including evaluation metrics and credits.

## Install Dependencies

First we download the model and install some dependencies. This step takes at least 5 minutes (possibly longer depending on server load).

!!! **Make sure you are using a TPU runtime!** !!!
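A quick way to confirm the runtime type before starting the long install is to check for the `COLAB_TPU_ADDR` environment variable, which the setup cell later in this notebook reads. This is a minimal sketch; the helper name `tpu_runtime_available` is hypothetical, and the menu path in the message is an assumption about the Colab UI.

```python
import os

def tpu_runtime_available(env=os.environ):
    """Return True if the Colab TPU address variable is set and non-empty."""
    return bool(env.get("COLAB_TPU_ADDR"))

if not tpu_runtime_available():
    # Assumed Colab menu path; adjust if the UI differs.
    print("No TPU detected: choose Runtime > Change runtime type > TPU, then rerun.")
```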

In [1]:
!apt install zstd

# the "slim" version contains only bf16 weights and no optimizer parameters, which minimizes bandwidth and memory
# !time wget -c https://the-eye.eu/public/AI/GPT-J-6B/step_383500_slim.tar.zstd
# !time tar -I zstd -xf step_383500_slim.tar.zstd
# !time tar -I zstd -xf step_11000_slim.tar.zstd
!git clone https://github.com/kingoflolz/mesh-transformer-jax.git
!pip install -r mesh-transformer-jax/requirements.txt

# jax 0.2.12 is required due to a regression with xmap in 0.2.13
!pip install mesh-transformer-jax/ jax==0.2.12 tensorflow==2.5.0

Reading package lists... Done
Building dependency tree       
Reading state information... Done
zstd is already the newest version (1.3.3+dfsg-2ubuntu1.2).
0 upgraded, 0 newly installed, 0 to remove and 37 not upgraded.
ServiceException: 401 Anonymous caller does not have storage.objects.get access to the Google Cloud Storage object.
CommandException: 1 file/object could not be transferred.
fatal: destination path 'mesh-transformer-jax' already exists and is not an empty directory.
Collecting git+https://github.com/deepmind/dm-haiku (from -r mesh-transformer-jax/requirements.txt (line 8))
  Cloning https://github.com/deepmind/dm-haiku to /tmp/pip-req-build-q2nxc0rf
  Running command git clone -q https://github.com/deepmind/dm-haiku /tmp/pip-req-build-q2nxc0rf
Collecting git+https://github.com/EleutherAI/lm-evaluation-harness/ (from -r mesh-transformer-jax/requirements.txt (line 9))
  Cloning https://github.com/EleutherAI/lm-evaluation-harness/ to /tmp/pip-req-build-exofus39
  Running c

In [16]:
!gcloud auth login

Go to the following link in your browser:

    https://accounts.google.com/o/oauth2/auth?response_type=code&client_id=32555940559.apps.googleusercontent.com&redirect_uri=urn%3Aietf%3Awg%3Aoauth%3A2.0%3Aoob&scope=openid+https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fuserinfo.email+https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fcloud-platform+https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fappengine.admin+https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fcompute+https%3A%2F%2Fwww.googleapis.com%2Fauth%2Faccounts.reauth&state=3Za1rzP8VetTLVPLaVHW3CgiqzTe4a&prompt=consent&access_type=offline&code_challenge=uiPCVxy18q4W0uDHySn88yzlbvHv0I0vzSPGVMCkn1s&code_challenge_method=S256

Enter verification code: 4/1AX4XfWirKNKEB1rmgQ4c6JRjo1XnxISVEf1MUoLBKmd1-j-QfSKwrxXEqXQ

You are now logged in as [wongsiuho561@gmail.com].
Your current project is [None].  You can change this setting by running:
  $ gcloud config set project PROJECT_ID


In [17]:
!gcloud config set project gpt-j-fine-tuning-ticker

Updated property [core/project].


In [18]:
!gsutil -m cp -r "gs://bucket_for_tpu_gptj/tuned_to_finbase_2_slim/step_11000/" .

Copying gs://bucket_for_tpu_gptj/tuned_to_finbase_2_slim/step_11000/shard_0/0.npz...
Copying gs://bucket_for_tpu_gptj/tuned_to_finbase_2_slim/step_11000/shard_0/10.npz...
Copying gs://bucket_for_tpu_gptj/tuned_to_finbase_2_slim/step_11000/shard_0/1.npz...
Copying gs://bucket_for_tpu_gptj/tuned_to_finbase_2_slim/step_11000/shard_0/11.npz...
Copying gs://bucket_for_tpu_gptj/tuned_to_finbase_2_slim/step_11000/shard_0/12.npz...
Copying gs://bucket_for_tpu_gptj/tuned_to_finbase_2_slim/step_11000/shard_0/13.npz...
/ [0/128 files][    0.0 B/ 11.3 GiB]   0% Done
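Because the copy log above reports 128 files, it can be worth counting what actually landed on disk before loading the checkpoint. The sketch below assumes the layout visible in the log (`step_11000/shard_N/*.npz`); the helper name `count_checkpoint_files` is hypothetical and not part of mesh-transformer-jax.

```python
from pathlib import Path

def count_checkpoint_files(ckpt_dir):
    """Map each shard directory name to the number of .npz files it holds."""
    root = Path(ckpt_dir)
    return {
        shard.name: len(list(shard.glob("*.npz")))
        for shard in sorted(root.glob("shard_*"))
        if shard.is_dir()
    }

# Assumes the checkpoint was copied into the current working directory.
counts = count_checkpoint_files("step_11000")
print(f"{sum(counts.values())} .npz files across {len(counts)} shard directories")
```

If the total is lower than the file count reported by `gsutil`, rerunning the copy command is probably the simplest fix.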

## Setup Model


In [19]:
import os
import requests 
from jax.config import config

colab_tpu_addr = os.environ['COLAB_TPU_ADDR'].split(':')[0]
url = f'http://{colab_tpu_addr}:8475/requestversion/tpu_driver0.1_dev20210607'
requests.post(url)

# The following is required to use TPU Driver as JAX's backend.
config.FLAGS.jax_xla_backend = "tpu_driver"
config.FLAGS.jax_backend_target = "grpc://" + os.environ['COLAB_TPU_ADDR']

In [36]:
!pip install einops

Collecting einops
  Using cached einops-0.3.2-py3-none-any.whl (25 kB)
Installing collected packages: einops
Successfully installed einops-0.3.2


The next step sometimes fails for no obvious reason; if it does, just run it again ¯\\\_(ツ)\_/¯

In [37]:
import time

import jax
from jax.experimental import maps
import numpy as np
import optax
import transformers

from mesh_transformer.checkpoint import read_ckpt_lowmem
from mesh_transformer.sampling import nucleaus_sample
from mesh_transformer.transformer_shard import CausalTransformer

In [38]:
params = {
  "layers": 28,
  "d_model": 4096,
  "n_heads": 16,
  "n_vocab": 50400,
  "norm": "layernorm",
  "pe": "rotary",
  "pe_rotary_dims": 64,

  "seq": 2048,
  "cores_per_replica": 8,
  "per_replica_batch": 1,
}

per_replica_batch = params["per_replica_batch"]
cores_per_replica = params["cores_per_replica"]
seq = params["seq"]


params["sampler"] = nucleaus_sample

# here we "remove" the optimizer parameters from the model (as we don't need them for inference)
params["optimizer"] = optax.scale(0)

mesh_shape = (jax.device_count() // cores_per_replica, cores_per_replica)
devices = np.array(jax.devices()).reshape(mesh_shape)

maps.thread_resources.env = maps.ResourceEnv(maps.Mesh(devices, ('dp', 'mp')))

tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')

Downloading:   0%|          | 0.00/0.99M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/446k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.29M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/665 [00:00<?, ?B/s]

Here we create the network and load the parameters from the downloaded files. Expect this to take around 5 minutes.

In [39]:
total_batch = per_replica_batch * jax.device_count() // cores_per_replica

network = CausalTransformer(params)

network.state = read_ckpt_lowmem(network.state, "step_11000/", devices.shape[1])

network.state = network.move_xmap(network.state, np.zeros(cores_per_replica))

  warn("xmap is an experimental feature and probably has bugs!")


key shape (8, 2)
in shape (1, 2048)
dp 1
mp 8
Total parameters: 6053381344
read from disk/gcs in 358.092s
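The `dp 1` and `mp 8` values in the log follow from the mesh-shape and batch arithmetic in the cells above. The sketch below reproduces that arithmetic in plain Python, assuming 8 TPU cores are visible (consistent with the log); `mesh_layout` is a hypothetical helper, not a library function.

```python
def mesh_layout(device_count, cores_per_replica, per_replica_batch):
    """Return ((dp, mp), total_batch) using the same arithmetic as the notebook."""
    if device_count % cores_per_replica != 0:
        raise ValueError("device_count must be a multiple of cores_per_replica")
    dp = device_count // cores_per_replica      # number of data-parallel replicas
    mesh_shape = (dp, cores_per_replica)        # (data parallel, model parallel)
    total_batch = per_replica_batch * dp        # sequences generated per call
    return mesh_shape, total_batch

print(mesh_layout(8, 8, 1))  # ((1, 8), 1)
```

With a single replica, raising `per_replica_batch` is the only way to generate more sequences per call.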


## Run Model

Finally, we are ready to run inference with the model! The first sample takes around a minute because of compilation; after that, each sample should take only about 10 seconds.

Feel free to experiment with the sampling parameters (top_p and temp), as well as the length of the generations (gen_len; changing it triggers a recompile).
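To build intuition for what top_p and temp do, here is a plain-Python sketch of temperature scaling followed by nucleus (top-p) filtering on a toy logit vector. It illustrates the general technique only and is not the implementation used by `mesh_transformer.sampling`; the function name `nucleus_filter` and the example logits are made up for illustration, and temp is assumed to be greater than zero.

```python
import math

def nucleus_filter(logits, top_p=0.9, temp=1.0):
    """Return {token_index: probability} for the smallest set of tokens
    whose cumulative probability reaches top_p, renormalized to sum to 1."""
    scaled = [l / temp for l in logits]          # lower temp sharpens the distribution
    peak = max(scaled)
    exps = [math.exp(s - peak) for s in scaled]  # numerically stable softmax
    total = sum(exps)
    probs = [e / total for e in exps]

    # Walk tokens from most to least likely until top_p of the mass is covered.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, cumulative = [], 0.0
    for i in order:
        kept.append(i)
        cumulative += probs[i]
        if cumulative >= top_p:
            break

    mass = sum(probs[i] for i in kept)
    return {i: probs[i] / mass for i in kept}

toy_logits = [2.0, 1.0, 0.0, -1.0]
print(sorted(nucleus_filter(toy_logits, top_p=0.9, temp=1.0)))  # [0, 1, 2]
print(sorted(nucleus_filter(toy_logits, top_p=0.9, temp=0.5)))  # [0, 1]
```

Lowering either parameter shrinks the set of candidate tokens, which tends to make generations more conservative.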

You can also change per_replica_batch in the earlier cells to control how many generations run in parallel. A larger batch has higher latency but also higher throughput, measured in tokens generated per second. This is useful for things like best-of-n cherry-picking.
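Throughput here is simply the number of generated tokens divided by wall-clock time. The sketch below plugs in figures recorded in this notebook (one sequence of 512 tokens, with a post-compilation run of roughly 13.56 s); `tokens_per_second` is a hypothetical helper, and timings for larger batches would need to be measured on your own runtime.

```python
def tokens_per_second(total_batch, gen_len, seconds):
    """Generated tokens per second across the whole batch."""
    return total_batch * gen_len / seconds

# One sequence of 512 tokens in about 13.56 s (second run in this notebook).
print(f"{tokens_per_second(1, 512, 13.56):.1f} tokens/s")  # 37.8 tokens/s
```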

*Tip for best results: Make sure your prompt does not have any trailing spaces, which tend to confuse the model due to the BPE tokenization used during training.*
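One simple way to follow this tip is to strip trailing whitespace before passing the prompt to the model. GPT-2 style BPE usually attaches a space to the start of the following word, so a prompt ending in a space leaves a dangling token the model rarely saw in that position. The helper name `clean_prompt` is hypothetical.

```python
def clean_prompt(prompt):
    """Strip trailing whitespace so the prompt does not end on a dangling space."""
    return prompt.rstrip()

print(repr(clean_prompt("A good ticker is  \n")))  # 'A good ticker is'
```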

In [40]:
# allow text wrapping in generated output: https://stackoverflow.com/a/61401455
from IPython.display import HTML, display

def set_css():
  display(HTML('''
  <style>
    pre {
        white-space: pre-wrap;
    }
  </style>
  '''))
get_ipython().events.register('pre_run_cell', set_css)

In [41]:
def infer(context, top_p=0.9, temp=1.0, gen_len=512):
    tokens = tokenizer.encode(context)

    provided_ctx = len(tokens)
    pad_amount = seq - provided_ctx

    padded_tokens = np.pad(tokens, ((pad_amount, 0),)).astype(np.uint32)
    batched_tokens = np.array([padded_tokens] * total_batch)
    length = np.ones(total_batch, dtype=np.uint32) * len(tokens)

    start = time.time()
    output = network.generate(batched_tokens, length, gen_len, {"top_p": np.ones(total_batch) * top_p, "temp": np.ones(total_batch) * temp})

    samples = []
    decoded_tokens = output[1][0]

    for o in decoded_tokens[:, :, 0]:
      samples.append(f"\033[1m{context}\033[0m{tokenizer.decode(o)}")

    print(f"completion done in {time.time() - start:06}s")
    return samples

print(infer("EleutherAI is")[0])

completion done in 69.26040363311768s
**EleutherAI is** the leading supplier of choice based services in Finland.
"According to Sweden's Minister for Local Government and Financial Markets, Mats Odell, the decision to sell the State's shares in telecom group TeliaSonera can only be carried out in cooperation with the State of Finland."
Finnish beverage company Olvi is introducing a new long drink Olvi Kultalonkero ( `` golden long drink '' ) in the market in Finland in the spring of 2009.
"Finnish GeoSentric, a developer and provider of solutions, products, and technologies for location based services, has preliminary agreed on a EUR 6mn short-term funding with its leading investor."
Finnish Raisio ( Diagnostics ) is launching new DNA-based quick tests to ensure the safety of food.
Finnish Rautaruukki is selling its precision tube and automotive component processing unit Carl Froh in Germany to German Arques Industries.
"Finnish Suominen Corporation that makes wet wipes, nonwovens,

In [61]:
#@title  { form-width: "300px" }
top_p = 0.9 #@param {type:"slider", min:0, max:1, step:0.1}
temp = 1 #@param {type:"slider", min:0, max:1, step:0.1}

context = """A good ticker is"""

print(infer(top_p=top_p, temp=temp, gen_len=512, context=context)[0])

completion done in 13.560888290405273s
**A good ticker is** just a useless financial celebrity sideshow, which is to guess is pretty funny to many investors. In fact, this rule of thumb might be very funny to those investors who intend to use it as a shortcut to a higher probability of doing better in the future.

Much of the financial celebrity sideshow that are focused on the past and the past. When measured across your portfolio, the greater the true cost of the financial services industry, the more money you will leave behind. It is estimated that the vast majority of mutual funds sold by brokers are low cost, broadly diversified, passive funds with low turnover. Buy them and ignore the rest with middling or higher fees.

In addition, if you do not have a clear understanding of ETF trading, buy only mutual funds. After the May 6, 2010 stock market flash crash, it should be clear that naive traders fooling with ETF market orders and stop loss orders that automatically convert to