# GPT-J-6B Inference Demo

<a href="http://colab.research.google.com/github/kingoflolz/mesh-transformer-jax/blob/master/colab_demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This notebook demonstrates how to run the [GPT-J-6B model](https://github.com/kingoflolz/mesh-transformer-jax/#GPT-J-6B). See the link for more details about the model, including evaluation metrics and credits.

## Install Dependencies

First we download the model and install some dependencies. This step takes at least 5 minutes (possibly longer depending on server load).

!!! **Make sure you are using a TPU runtime!** !!!
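
To fail fast instead of partway through the download, you could run a quick check first. This is a minimal sketch that assumes, as the Setup Model cell does, that Colab exposes the TPU address through the `COLAB_TPU_ADDR` environment variable. The helper name `find_colab_tpu` is hypothetical.

```python
import os


def find_colab_tpu(environ=None):
    """Return the Colab TPU address, or raise if no TPU runtime is attached.

    Assumption: Colab sets COLAB_TPU_ADDR (as used later in this notebook)
    only when a TPU runtime is selected.
    """
    env = os.environ if environ is None else environ
    addr = env.get("COLAB_TPU_ADDR")
    if not addr:
        raise RuntimeError(
            "COLAB_TPU_ADDR is not set: switch to a TPU runtime via "
            "Runtime > Change runtime type before continuing."
        )
    return addr
```

In the notebook you would call `find_colab_tpu()` with no arguments so that it reads the real environment.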

In [None]:
!apt install zstd

# the "slim" version contains only bf16 weights and no optimizer parameters, which minimizes bandwidth and memory
!time wget -c https://the-eye.eu/public/AI/GPT-J-6B/step_383500_slim.tar.zstd

!time tar -I zstd -xf step_383500_slim.tar.zstd

!git clone https://github.com/kingoflolz/mesh-transformer-jax.git
!pip install -r mesh-transformer-jax/requirements.txt

# jax 0.2.12 is required due to a regression with xmap in 0.2.13
!pip install mesh-transformer-jax/ jax==0.2.12 tensorflow==2.5.0
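
The comment about the "slim" checkpoint comes down to bytes per parameter. Here is a back-of-the-envelope sketch, assuming roughly 6 billion parameters and ignoring compression and file-format overhead, so these figures are not the actual download sizes. The optimizer line assumes a hypothetical Adam-style optimizer with two fp32 moment buffers per parameter.

```python
# Rough storage estimate for ~6B parameters at different precisions.
# Assumption: the parameter count is approximate; compression and metadata are ignored.
N_PARAMS = 6_000_000_000

BYTES_PER_PARAM = {"bf16": 2, "fp32": 4}


def weight_gb(n_params, dtype):
    """Raw weight storage in decimal gigabytes (1 GB = 1e9 bytes)."""
    return n_params * BYTES_PER_PARAM[dtype] / 1e9


print(f"bf16 weights: {weight_gb(N_PARAMS, 'bf16'):.0f} GB")  # prints: bf16 weights: 12 GB
print(f"fp32 weights: {weight_gb(N_PARAMS, 'fp32'):.0f} GB")  # prints: fp32 weights: 24 GB

# Hypothetical: an Adam-style optimizer keeps two fp32 moment buffers per parameter.
adam_state_gb = 2 * weight_gb(N_PARAMS, "fp32")
print(f"Adam-style optimizer state: {adam_state_gb:.0f} GB")  # prints: Adam-style optimizer state: 48 GB
```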

## Setup Model


In [None]:
import os
import requests
from jax.config import config

# Request a specific TPU driver version from the Colab TPU host.
colab_tpu_addr = os.environ['COLAB_TPU_ADDR'].split(':')[0]
url = f'http://{colab_tpu_addr}:8475/requestversion/tpu_driver0.1_dev20210607'
requests.post(url)

# The following is required to use TPU Driver as JAX's backend.
config.FLAGS.jax_xla_backend = "tpu_driver"
config.FLAGS.jax_backend_target = "grpc://" + os.environ['COLAB_TPU_ADDR']

The next cell sometimes fails for no obvious reason; if it does, just run it again. ¯\\\_(ツ)\_/¯

In [None]:
import time

import jax
from jax.experimental import maps
import numpy as np
import optax
import transformers

from mesh_transformer.checkpoint import read_ckpt_lowmem
from mesh_transformer.sampling import nucleaus_sample
from mesh_transformer.transformer_shard import CausalTransformer

In [None]:
params = {
  "layers": 28,
  "d_model": 4096,
  "n_heads": 16,
  "n_vocab": 50400,
  "norm": "layernorm",
  "pe": "rotary",
  "pe_rotary_dims": 64,

  "seq": 2048,
  "cores_per_replica": 8,
  "per_replica_batch": 1,
}

per_replica_batch = params["per_replica_batch"]
cores_per_replica = params["cores_per_replica"]
seq = params["seq"]


params["sampler"] = nucleaus_sample

# here we "remove" the optimizer parameters from the model (as we don't need them for inference)
params["optimizer"] = optax.scale(0)

mesh_shape = (jax.device_count() // cores_per_replica, cores_per_replica)
devices = np.array(jax.devices()).reshape(mesh_shape)

maps.thread_resources.env = maps.ResourceEnv(maps.Mesh(devices, ('dp', 'mp')))

tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')
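
The `params` above are enough to sanity-check the "6B" in the model's name. This is a rough sketch under stated assumptions, not read from the checkpoint: each layer has about `12 * d_model**2` weights (`4 * d_model**2` for the query, key, value and output projections, `8 * d_model**2` for a feed-forward block with a 4x hidden width), the input embedding and output projection are separate matrices, and biases and layer norms are ignored.

```python
# Back-of-the-envelope parameter count from the config values used above.
layers, d_model, n_vocab = 28, 4096, 50400

# Assumed per-layer weights: Q, K, V and output projections (4 * d^2)
# plus a feed-forward block d -> 4d -> d (8 * d^2). Biases and norms ignored.
per_layer = 12 * d_model ** 2

# Assumed untied token embedding and output (unembedding) matrices.
embeddings = 2 * n_vocab * d_model

total = layers * per_layer + embeddings
print(f"~{total / 1e9:.2f}B parameters")  # prints: ~6.05B parameters
```

The transformer layers account for roughly 5.6B of the estimate, and the two vocabulary-sized matrices for the remaining ~0.4B.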

Here we create the network and load the parameters from the downloaded files. Expect this to take around 5 minutes.

In [None]:
total_batch = per_replica_batch * jax.device_count() // cores_per_replica

network = CausalTransformer(params)

network.state = read_ckpt_lowmem(network.state, "step_383500/", devices.shape[1])

network.state = network.move_xmap(network.state, np.zeros(cores_per_replica))
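
The mesh and batch arithmetic in the cells above is easy to check by hand. Below is a minimal sketch with plain integers, assuming 8 TPU cores (the usual Colab TPU runtime; `jax.device_count()` reports the real number). Judging by the axis names, the first mesh axis (`'dp'`) counts data-parallel replicas and the second (`'mp'`) splits each replica across cores for model parallelism. The helper name `mesh_layout` is hypothetical.

```python
def mesh_layout(device_count, cores_per_replica, per_replica_batch):
    """Mirror the notebook's arithmetic: (dp, mp) mesh shape and total batch size."""
    if device_count % cores_per_replica:
        # The notebook's reshape would fail in this case as well.
        raise ValueError("device_count must be a multiple of cores_per_replica")
    mesh_shape = (device_count // cores_per_replica, cores_per_replica)
    total_batch = per_replica_batch * device_count // cores_per_replica
    return mesh_shape, total_batch


print(mesh_layout(8, 8, 1))  # prints: ((1, 8), 1): one replica sharded across all 8 cores
print(mesh_layout(8, 8, 4))  # prints: ((1, 8), 4): still one replica, 4 prompts per call
```

With the default config, the whole TPU therefore holds a single copy of the model, and `total_batch` equals `per_replica_batch`.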

## Run Model

Finally, we are ready to run inference with the model! The first sample takes around a minute because of compilation; after that, each sample should take only about 10 seconds.

Feel free to experiment with the sampling parameters (`top_p` and `temp`) and with the generation length (`gen_len`; changing it triggers a recompile).
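
For intuition about `top_p` and `temp`, here is a generic pure-Python sketch of temperature scaling followed by nucleus (top-p) filtering over one vector of logits. It is a textbook version written for illustration, not the code inside `mesh_transformer.sampling.nucleaus_sample`, and the function names `top_p_filter` and `sample_top_p` are hypothetical.

```python
import math
import random


def top_p_filter(logits, top_p=0.9, temp=1.0):
    """Return {token_id: prob} for the smallest set of most likely tokens whose
    probability mass reaches top_p, after dividing the logits by temp."""
    if temp <= 0:
        raise ValueError("temp must be positive in this sketch")
    scaled = [x / temp for x in logits]
    m = max(scaled)  # subtract the max before exp() for numerical stability
    exps = [math.exp(x - m) for x in scaled]
    z = sum(exps)
    probs = [e / z for e in exps]

    # Keep the most likely tokens until their cumulative mass reaches top_p.
    kept, mass = {}, 0.0
    for i in sorted(range(len(probs)), key=lambda j: probs[j], reverse=True):
        kept[i] = probs[i]
        mass += probs[i]
        if mass >= top_p:
            break
    # Renormalise the surviving probabilities so they sum to 1.
    return {i: p / mass for i, p in kept.items()}


def sample_top_p(logits, top_p=0.9, temp=1.0, rng=random):
    """Draw one token id from the filtered distribution."""
    dist = top_p_filter(logits, top_p, temp)
    ids, weights = zip(*dist.items())
    return rng.choices(ids, weights=weights, k=1)[0]
```

With `logits = [2.0, 1.0, 0.1, -1.0]` and `top_p=0.9`, the filter keeps tokens 0, 1 and 2 at `temp=1.0` but only tokens 0 and 1 at `temp=0.5`: a lower temperature concentrates probability on the top tokens, so fewer of them are needed to reach `top_p`.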

You can also change other settings, such as `per_replica_batch` in the previous cells, to control how many generations run in parallel. A larger batch has higher latency but also higher throughput, measured in generated tokens per second. This is useful for things like best-of-n cherry-picking.
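
Since `infer` returns a list of samples (one per batch element), best-of-n selection amounts to scoring that list. This is a minimal sketch; the scoring function (prefer completions ending in sentence punctuation, then longer ones) is a hypothetical placeholder for whatever criterion you actually care about, and both helper names are hypothetical.

```python
def best_of_n(samples, score):
    """Return the highest-scoring sample; `score` maps a string to a comparable value."""
    if not samples:
        raise ValueError("need at least one sample")
    return max(samples, key=score)


def ends_cleanly(text):
    # Hypothetical heuristic: prefer text ending in ".", "!" or "?", then longer text.
    stripped = text.rstrip()
    return (stripped.endswith((".", "!", "?")), len(stripped))


candidates = ["The unicorns fled", "The unicorns spoke Spanish.", "Unicorns!"]
print(best_of_n(candidates, ends_cleanly))  # prints: The unicorns spoke Spanish.
```

In the notebook this would look like `best_of_n(infer(context), ends_cleanly)` once `per_replica_batch` is greater than 1.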

*Tip for best results: Make sure your prompt does not have any trailing spaces, which tend to confuse the model due to the BPE tokenization used during training.*
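
One way to follow this tip is to strip trailing spaces before calling `infer`. In GPT-2's BPE, a space normally attaches to the start of the following word, so a trailing space ends up as a token of its own. The sketch below removes only spaces and tabs and keeps trailing newlines, on the assumption that they may be intentional (for example, in a Q&A-style prompt). The helper name `clean_prompt` is hypothetical.

```python
def clean_prompt(text):
    """Drop trailing spaces and tabs; keep trailing newlines and inner whitespace."""
    return text.rstrip(" \t")


prompt = clean_prompt("EleutherAI is ")
print(repr(prompt))  # prints: 'EleutherAI is'
```

In the notebook you can compare `tokenizer.encode("EleutherAI is")` with `tokenizer.encode("EleutherAI is ")` to see the extra token.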

In [None]:
# allow text wrapping in generated output: https://stackoverflow.com/a/61401455
from IPython.display import HTML, display

def set_css():
  display(HTML('''
  <style>
    pre {
        white-space: pre-wrap;
    }
  </style>
  '''))
get_ipython().events.register('pre_run_cell', set_css)

In [None]:
def infer(context, top_p=0.9, temp=1.0, gen_len=512):
    tokens = tokenizer.encode(context)

    # Left-pad the prompt to the full sequence length; the true length is passed separately.
    provided_ctx = len(tokens)
    pad_amount = seq - provided_ctx

    padded_tokens = np.pad(tokens, ((pad_amount, 0),)).astype(np.uint32)
    # Repeat the same prompt once per batch element.
    batched_tokens = np.array([padded_tokens] * total_batch)
    length = np.ones(total_batch, dtype=np.uint32) * len(tokens)

    start = time.time()
    output = network.generate(
        batched_tokens, length, gen_len,
        {"top_p": np.ones(total_batch) * top_p, "temp": np.ones(total_batch) * temp},
    )

    samples = []
    decoded_tokens = output[1][0]

    # Decode each batch element, printing the prompt in bold via ANSI escape codes.
    for o in decoded_tokens[:, :, 0]:
        samples.append(f"\033[1m{context}\033[0m{tokenizer.decode(o)}")

    print(f"completion done in {time.time() - start:06}s")
    return samples

print(infer("EleutherAI is")[0])
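
`infer` left-pads the prompt with zeros to exactly `seq` tokens and passes the true prompt length separately. Here is a pure-Python sketch of that padding step (the helper name `left_pad` and the token ids are made up for illustration). It also shows why a prompt longer than `seq` tokens cannot be used as-is: the pad amount goes negative, and `np.pad` rejects negative pad widths.

```python
def left_pad(tokens, seq, pad_id=0):
    """Left-pad `tokens` to exactly `seq` entries, like infer() does with np.pad."""
    pad_amount = seq - len(tokens)
    if pad_amount < 0:
        raise ValueError(f"prompt is {len(tokens)} tokens, longer than seq={seq}")
    return [pad_id] * pad_amount + list(tokens)


print(left_pad([11, 22], seq=6))  # prints: [0, 0, 0, 0, 11, 22]
```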

In [None]:
#@title  { form-width: "300px" }
top_p = 0.9 #@param {type:"slider", min:0, max:1, step:0.1}
temp = 1 #@param {type:"slider", min:0, max:1, step:0.1}

context = """In a shocking finding, scientist discovered a herd of unicorns living in a remote, previously unexplored valley, in the Andes Mountains. Even more surprising to the researchers was the fact that the unicorns spoke perfect English."""

print(infer(top_p=top_p, temp=temp, gen_len=512, context=context)[0])
