# GPT-J-6B Inference Demo

<a href="http://colab.research.google.com/github/kingoflolz/mesh-transformer-jax/blob/master/colab_demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This notebook demonstrates how to run the [GPT-J-6B model](https://github.com/kingoflolz/mesh-transformer-jax/#GPT-J-6B). See the link for more details about the model, including evaluation metrics and credits.

## Install Dependencies

First we download the model and install some dependencies. This step takes at least 5 minutes (possibly longer depending on server load).

!!! **Make sure you are using a TPU runtime!** !!!

In [None]:
!apt install zstd

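# the "slim" version contains only bf16 weights and no optimizer parameters, which minimizes bandwidth and memory
# on a fresh runtime, uncomment the download, extraction and clone commands below;
# later cells read the checkpoint from step_383500/ and install from mesh-transformer-jax/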
# !time wget -c https://mystic.the-eye.eu/public/AI/GPT-J-6B/step_383500_slim.tar.zstd

# !time tar -I zstd -xf step_383500_slim.tar.zstd

# !git clone https://github.com/kingoflolz/mesh-transformer-jax.git
!pip install -r mesh-transformer-jax/requirements.txt

# jax 0.2.12 is required due to a regression with xmap in 0.2.13
!pip install mesh-transformer-jax/ jax==0.2.12 tensorflow==2.5.0


## Setup Model


In [None]:
import os
import requests
from jax.config import config

# ask the Colab TPU host to switch to a specific tpu_driver build
colab_tpu_addr = os.environ['COLAB_TPU_ADDR'].split(':')[0]
url = f'http://{colab_tpu_addr}:8475/requestversion/tpu_driver0.1_dev20210607'
requests.post(url)

# The following is required to use TPU Driver as JAX's backend.
config.FLAGS.jax_xla_backend = "tpu_driver"
config.FLAGS.jax_backend_target = "grpc://" + os.environ['COLAB_TPU_ADDR']

The next step sometimes fails for no obvious reason; if it does, just run it again ¯\\\_(ツ)\_/¯

In [None]:
!pip install optax==0.0.9 transformers dm-haiku einops
!pip install ray

import time

import jax
from jax.experimental import maps
import numpy as np
import optax
import transformers

from mesh_transformer.checkpoint import read_ckpt_lowmem
from mesh_transformer.sampling import nucleaus_sample
from mesh_transformer.transformer_shard import CausalTransformer


In [None]:
params = {
  "layers": 28,
  "d_model": 4096,
  "n_heads": 16,
  "n_vocab": 50400,
  "norm": "layernorm",
  "pe": "rotary",
  "pe_rotary_dims": 64,

  "seq": 2048,
  "cores_per_replica": 8,
  "per_replica_batch": 1,
}

per_replica_batch = params["per_replica_batch"]
cores_per_replica = params["cores_per_replica"]
seq = params["seq"]


params["sampler"] = nucleaus_sample

# here we "remove" the optimizer parameters from the model (as we don't need them for inference)
params["optimizer"] = optax.scale(0)

mesh_shape = (jax.device_count() // cores_per_replica, cores_per_replica)
devices = np.array(jax.devices()).reshape(mesh_shape)

maps.thread_resources.env = maps.ResourceEnv(maps.Mesh(devices, ('dp', 'mp')))

tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')

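The mesh arithmetic above splits the TPU cores into a data-parallel axis (`dp`) and a model-parallel axis (`mp`); the load step below prints `dp 1` and `mp 8`, consistent with 8 TPU cores. A minimal plain-Python sketch of that bookkeeping, using a hypothetical helper name `mesh_layout` that mirrors the notebook's `mesh_shape` and `total_batch` expressions:

```python
def mesh_layout(device_count: int, cores_per_replica: int, per_replica_batch: int):
    """Mirror the notebook's mesh_shape / total_batch arithmetic (hypothetical helper)."""
    if device_count % cores_per_replica != 0:
        raise ValueError("device_count must be a multiple of cores_per_replica")
    dp = device_count // cores_per_replica  # number of model replicas (data parallel)
    mp = cores_per_replica                  # cores each replica is sharded across (model parallel)
    total_batch = per_replica_batch * dp    # sequences generated in parallel
    return (dp, mp), total_batch

# 8 cores with 8 cores per replica: a single replica spanning every core
print(mesh_layout(8, 8, 1))
```

Raising `per_replica_batch` grows `total_batch` without changing the mesh; adding devices (in multiples of `cores_per_replica`) adds replicas along `dp`.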

Here we create the network and load the parameters from the downloaded files. Expect this to take around 5 minutes.

In [None]:
total_batch = per_replica_batch * jax.device_count() // cores_per_replica

network = CausalTransformer(params)

network.state = read_ckpt_lowmem(network.state, "step_383500/", devices.shape[1])

network.state = network.move_xmap(network.state, np.zeros(cores_per_replica))

  warn("xmap is an experimental feature and probably has bugs!")


key shape (8, 2)
in shape (1, 2048)
dp 1
mp 8
Total parameters: 6053381344
read from disk/gcs in 389.68s
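The reported parameter count gives a quick sanity check on why the notebook uses the slim bf16 checkpoint. This is back-of-the-envelope arithmetic only, assuming 2 bytes per bf16 weight and an even split across the 8 model-parallel cores, and ignoring activations and the generation cache:

```python
n_params = 6_053_381_344  # "Total parameters" reported above
bytes_per_bf16 = 2
cores_per_replica = 8

total_bytes = n_params * bytes_per_bf16
gib = total_bytes / 2**30
print(f"bf16 weights: {total_bytes:,} bytes = {gib:.2f} GiB")
print(f"per core (8-way model parallel): {gib / cores_per_replica:.2f} GiB")
```

Full-precision (fp32) weights would need twice as much, which is one reason the slim checkpoint keeps only bf16 weights.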


## Run Model

Finally, we are ready to run inference with the model! The first sample takes around a minute because of compilation; after that, each sample should take only about 10 seconds.

Feel free to experiment with the sampling parameters (`top_p` and `temp`) and with the generation length (`gen_len`); note that changing `gen_len` triggers a recompile.
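For intuition about what `top_p` and `temp` control, here is a generic NumPy sketch of temperature scaling followed by nucleus (top-p) filtering over one vector of logits. It illustrates the standard technique under our own assumptions; it is not the code of `mesh_transformer.sampling.nucleaus_sample`, and the name `sample_top_p` is hypothetical:

```python
import numpy as np

def sample_top_p(logits, top_p=0.9, temp=1.0, rng=None):
    """Sample one token id with temperature scaling + nucleus (top-p) filtering."""
    rng = np.random.default_rng() if rng is None else rng
    # lower temp sharpens the distribution; temp -> 0 approaches greedy decoding
    scaled = np.asarray(logits, dtype=np.float64) / max(temp, 1e-8)
    probs = np.exp(scaled - scaled.max())
    probs /= probs.sum()

    order = np.argsort(probs)[::-1]  # token ids, most likely first
    cumulative = np.cumsum(probs[order])
    # keep the smallest prefix of tokens whose total probability reaches top_p
    cutoff = np.searchsorted(cumulative, top_p) + 1
    kept = order[:cutoff]

    kept_probs = probs[kept] / probs[kept].sum()  # renormalise inside the nucleus
    return int(rng.choice(kept, p=kept_probs))

rng = np.random.default_rng(0)
print(sample_top_p([2.0, 1.0, 0.1, -1.0], top_p=0.9, temp=0.7, rng=rng))
```

Lowering `top_p` trims more of the unlikely tail before sampling, while `temp` reshapes the probabilities of whatever tokens remain.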

You can also change other settings, such as `per_replica_batch` in the previous cells, to control how many generations run in parallel. A larger batch has higher latency but higher throughput, measured in tokens generated per second. This is useful for things like best-of-n cherry-picking.
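Best-of-n cherry-picking then amounts to generating a batch of candidates and keeping the one that scores highest under some criterion. A minimal sketch, assuming a hypothetical helper `best_of_n` and, purely as an example criterion, the number of required keywords a candidate mentions (the candidates are outputs from the kanji run at the end of this notebook):

```python
from typing import Callable, Sequence

def best_of_n(samples: Sequence[str], score: Callable[[str], float]) -> str:
    """Return the highest-scoring sample; the first one wins ties."""
    if not samples:
        raise ValueError("need at least one sample")
    return max(samples, key=score)

def keyword_score(text: str, keywords=("roof", "outer space", "dry")) -> int:
    # count how many keywords appear in the text, ignoring case
    lowered = text.lower()
    return sum(k in lowered for k in keywords)

candidates = [
    "Outer space is a mystery to us.",
    "The roof of my house is made of dry wood so that the rain will not get in.",
]
print(best_of_n(candidates, keyword_score))
```

With a larger `per_replica_batch`, the list returned by `infer()` could play the role of `samples`.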

*Tip for best results: Make sure your prompt does not have any trailing spaces, which tend to confuse the model due to the BPE tokenization used during training.*

In [None]:
# allow text wrapping in generated output: https://stackoverflow.com/a/61401455
from IPython.display import HTML, display

def set_css():
  display(HTML('''
  <style>
    pre {
        white-space: pre-wrap;
    }
  </style>
  '''))
get_ipython().events.register('pre_run_cell', set_css)

In [None]:
def infer(context, top_p=0.9, temp=1.0, gen_len=512):
    tokens = tokenizer.encode(context)

    provided_ctx = len(tokens)
    pad_amount = seq - provided_ctx

    padded_tokens = np.pad(tokens, ((pad_amount, 0),)).astype(np.uint32)
    batched_tokens = np.array([padded_tokens] * total_batch)
    length = np.ones(total_batch, dtype=np.uint32) * len(tokens)

    start = time.time()
    output = network.generate(batched_tokens, length, gen_len, {"top_p": np.ones(total_batch) * top_p, "temp": np.ones(total_batch) * temp})

    samples = []
    decoded_tokens = output[1][0]

    for o in decoded_tokens[:, :, 0]:
      samples.append(f"\033[1m{context}\033[0m{tokenizer.decode(o)}")

    print(f"completion done in {time.time() - start:06}s")
    return samples

# print(infer("EleutherAI is")[0])
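`infer()` left-pads the prompt with zeros up to the full `seq` window and passes the real prompt length separately in `length`, so the prompt ends exactly where generation begins. The same preparation in isolation, with a toy window size and a hypothetical helper name `prepare_batch`; unlike `infer()`, it rejects prompts longer than the window explicitly instead of letting `np.pad` fail on a negative pad width:

```python
import numpy as np

def prepare_batch(tokens, seq, batch_size):
    """Left-pad token ids to `seq` and repeat them `batch_size` times (hypothetical helper)."""
    if len(tokens) > seq:
        raise ValueError(f"prompt has {len(tokens)} tokens but the window is {seq}")
    pad_amount = seq - len(tokens)
    padded = np.pad(tokens, ((pad_amount, 0),)).astype(np.uint32)  # zeros on the left
    batched = np.array([padded] * batch_size)                       # shape (batch_size, seq)
    length = np.ones(batch_size, dtype=np.uint32) * len(tokens)     # real prompt length per row
    return batched, length

batched, length = prepare_batch([101, 102, 103], seq=8, batch_size=2)
print(batched)
print(length)
```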

In [None]:
#@title  { form-width: "300px" }
top_p = 0.9 #@param {type:"slider", min:0, max:1, step:0.1}
temp = 1 #@param {type:"slider", min:0, max:1, step:0.1}

context = """In a shocking finding, scientist discovered a herd of unicorns living in a remote, previously unexplored valley, in the Andes Mountains. Even more surprising to the researchers was the fact that the unicorns spoke perfect English."""

print(infer(top_p=top_p, temp=temp, gen_len=512, context=context)[0])

completion done in 68.21263456344604s
In a shocking finding, scientist discovered a herd of unicorns living in a remote, previously unexplored valley, in the Andes Mountains. Even more surprising to the researchers was the fact that the unicorns spoke perfect English. Researchers surmise that the unicorn language was transmitted from the past, hundreds of years ago, by a lonely female unicorn who lived in this valley before dying. (Images via National Geographic)

First of all, National Geographic is sort of famous. We all know them and trust them in many ways. So this is nothing to sneeze at. National Geographic is very much a scientific publication, so their findings on ‘unicorns’ in South America are not surprising. They even published several photographs of this find, so even though it isn’t some scam (I wish it was, haha), it is just pure and simple real scientific research.

You know that when the Big Bang happened, the world was created and anything that is made today wa

# Sentence generation and extraction
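This section turns a kanji's keywords into a mnemonic by prompting GPT-J with a single worked example (a one-shot prompt) and cutting the model's answer at the first period. The two string steps can be tried without the model; below is a stand-alone sketch with hypothetical helper names and a made-up completion standing in for GPT-J's output:

```python
EXAMPLE = ("Make a sentence with the following words: earth, dirt, alligator\n"
           "Sentence: The alligator is a fierce animal which roams the water for its prey, "
           "but sometimes it goes on land to dig in the dirt searching for food.\n\n")

def build_prompt(keywords):
    # one worked example, then the new keywords; the model continues after "Sentence: "
    return EXAMPLE + f"Make a sentence with the following words: {', '.join(keywords)}\nSentence: "

def extract_sentence(prompt, completion):
    # the text after the prompt's last "Sentence: " is the answer; keep only its first sentence
    answer = (prompt + completion).split("Sentence: ")[prompt.count("Sentence: ")]
    return answer.strip().split(".")[0] + "."

prompt = build_prompt(["profit", "grain", "knife"])
fake_completion = " The farmer used a knife to cut the grain and sold it for a profit. Make a"
print(extract_sentence(prompt, fake_completion))
```

Indexing by `prompt.count("Sentence: ")` keeps the extraction correct if the prompt is extended with more worked examples.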

In [None]:
import json
import logging
import string
import numpy as np

In [None]:
logging.basicConfig()
logging.getLogger().setLevel(logging.DEBUG)

In [None]:
## kanji data
with open('data/kanji.json', encoding='utf8') as f:
  kanji_data = json.load(f)  # the file handle already decodes UTF-8

logging.info(f"amount of kanji: {len(kanji_data)}")

INFO:root:amount of kanji: 13108
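Judging from the keys that `get_keywords()` in the next cell reads, each entry of `data/kanji.json` needs at least a `wk_meanings` list (only the first meaning is used) and a `wk_radicals` list that may be `None`. The sketch below runs the same lookup on a tiny in-memory stand-in; the entry is hypothetical, modelled on the test character 利 (profit: grain, knife) listed in the final section, and the real file probably holds more fields per character:

```python
# hypothetical stand-in for data/kanji.json, keeping only the fields get_keywords() reads
sample_kanji_data = {
    "利": {"wk_meanings": ["Profit"], "wk_radicals": ["Grain", "Knife"]},
}

def keywords_for(character, data):
    """Lowercased meaning + radicals, or [] if the character or its radicals are missing."""
    entry = data.get(character)
    if entry is None or entry.get("wk_radicals") is None:
        return []
    meaning = entry["wk_meanings"][0].lower()
    return [meaning] + [rad.lower() for rad in entry["wk_radicals"]]

print(keywords_for("利", sample_kanji_data))  # ['profit', 'grain', 'knife']
print(keywords_for("宇", sample_kanji_data))  # [] (not in the stand-in data)
```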


In [None]:
## function definitions

# get list of keywords (meaning + radicals) for a kanji
def get_keywords(character):

    # check if character is present in dataset
    if character in kanji_data:
        char_properties = kanji_data[character]

        # check if radicals are available
        if char_properties['wk_radicals'] is not None:

            # add meaning as keyword
            keywords = [char_properties['wk_meanings'][0].lower()]

            # add radicals as keywords
            keywords.extend(rad.lower() for rad in char_properties['wk_radicals'])
            logging.info(f"Keywords: {keywords}")

            # meaning + a single radical: nothing to combine into a mnemonic
            if len(keywords) == 2:
                logging.info("The entered character consists of a single radical.")
            return keywords
        else:
            logging.info("Radicals not available for this character. Try another one.")
            return []
    else:
        logging.info("Character not available. Try another one.")
        return []


# generate sentence from keywords
def get_sentence(keywords, context=""):
  str_keywords = ", ".join(keywords)

  # create input context (one worked example, then the new keywords)
  if context == "":
    context = f"""Make a sentence with the following words: earth, dirt, alligator
Sentence: The alligator is a fierce animal which roams the water for its prey, but sometimes it goes on land to dig in the dirt searching for food.

Make a sentence with the following words: {str_keywords}
Sentence: """

  # logging.debug(f"context: {context}")

  # extract sentence out of output (top_p and temp come from the sampling form cell):
  # with the default context, the text after the second "Sentence: " is the model's
  # answer; drop the ANSI reset code that infer() inserts after the prompt
  output = infer(top_p=top_p, temp=temp, gen_len=120, context=context)[0]
  output = output.split("Sentence: ")[2].replace("\033[0m", "").strip()
  sentence = output.split('.')[0] + '.'

  return sentence


# calculate percentage score of input context
def calculate_context_score(max_sen_score, individual_scores):
  logging.debug(f"max_score: {max_sen_score}, indiv_scores: {individual_scores}")
  for i in range(len(individual_scores)):
    individual_scores[i] = round(individual_scores[i] / max_sen_score * 100, 2)

  context_score = round(sum(individual_scores) / len(individual_scores), 2)

  return context_score, individual_scores
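In other words, each raw score is expressed as a percentage of `max_sen_score`, and the context score is the mean of those percentages, both rounded to two decimals. A worked example with made-up numbers, written as a hypothetical non-mutating variant `context_score` (the function above overwrites the list it is given):

```python
def context_score(max_sen_score, individual_scores):
    """Percent of max for each score, plus their mean (non-mutating variant)."""
    if max_sen_score == 0 or not individual_scores:
        raise ValueError("need a non-zero maximum and at least one score")
    percents = [round(s / max_sen_score * 100, 2) for s in individual_scores]
    return round(sum(percents) / len(percents), 2), percents

print(context_score(8, [2, 6, 8]))  # (66.67, [25.0, 75.0, 100.0])
```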
  

# Final Product

This cell is the final product. Run it, enter a kanji character, and receive a mnemonic sentence.

This cell calls the model repeatedly until a generated sentence contains every keyword (the kanji's meaning plus its radicals).
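The loop in the cell below keeps calling the model until every keyword appears, with no upper bound on the number of attempts. A sketch of a bounded variant, assuming a hypothetical `generate` callable in place of `get_sentence` and a made-up `max_attempts` limit:

```python
from typing import Callable, List, Optional

def sentence_with_all_keywords(keywords: List[str],
                               generate: Callable[[List[str]], str],
                               max_attempts: int = 10) -> Optional[str]:
    """Call `generate` until a sentence mentions every keyword, or give up."""
    for _ in range(max_attempts):
        sentence = generate(keywords)
        lowered = sentence.lower()  # keywords are stored lowercased
        if all(k in lowered for k in keywords):
            return sentence
    return None  # the caller decides what to do when no attempt succeeds

# toy stand-in for get_sentence(): returns canned outputs in order
canned = iter([
    "Outer space is a mystery to us.",
    "The dry roof looked up at outer space.",
])
print(sentence_with_all_keywords(["outer space", "roof", "dry"], lambda kw: next(canned)))
```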

### Test characters
profit 利: grain, knife



In [None]:
## get keywords and generate sentence
keywords = get_keywords(input("Insert a character: "))

if len(keywords) > 2:
  num_included_keywords = 0

  # generate sentences until all keywords are included
  while num_included_keywords < len(keywords):
    num_included_keywords = 0
    sentence = get_sentence(keywords)
    print(sentence)

    # check for keywords (keywords are lowercase, so compare case-insensitively)
    for k in keywords:
      if k in sentence.lower():
        num_included_keywords += 1

  print(f"Sentence: {sentence}")

else:
  print("Try another character.")

Insert a character: 宇


INFO:root:Keywords: ['outer space', 'roof', 'dry']


completion done in 3.413747787475586s
I like the roof and the outer space.
completion done in 3.4107308387756348s
Outer space is a mystery to us.
completion done in 3.4197258949279785s
Outer space is the top of the sky, but its roof is very thick so the planet never gets too hot.
completion done in 3.423088550567627s
__________(fill in the blank with an astronaut, roof, or
dry) live in space because it is very hot and there is no air.
completion done in 3.4103565216064453s
The roof of my house is made of dry wood so that the rain will not get in.
completion done in 3.406209945678711s
So cold is space that we must wear outer space clothes to keep us from freezing.
completion done in 3.4034621715545654s
Space is the largest of the four known, physical dimensions that we can perceive.
completion done in 3.4052650928497314s
All of the light which comes from the stars and the sun, and the light from space and on top of the roof, is sent out through the