<a href="https://colab.research.google.com/github/SunG206/3D-Reconstruction-with-Deep-Learning-Methods/blob/master/GPT_NeoX_20B_on_Flax_(xmap).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## 🚨 Caution: Don't use this notebook for research. 🚨

<p>The Flax implementation on TPUs currently has a slight performance regression relative to the PyTorch implementations. The comparison can be seen <a href="https://docs.google.com/spreadsheets/d/16dhItf9iWMQsYqhcH6wUwZyi-Bf71Nr0TNTVvwAXGRo/edit?usp=sharing">here</a>.</p>

<p>If you want to evaluate GPT-NeoX-20B for research purposes, please use the original <a href="">GPT-NeoX</a>, <a href="https://github.com/zphang/minimal-gpt-neox-20b">Minimal PyTorch</a>, or <a href="https://huggingface.co/EleutherAI/gpt-neox-20b">Hugging Face</a> implementations.</p>

<p>This TPU implementation of GPT-NeoX-20B is also still a prototype with some hacks, so if you see any room for improvement, please drop by <a href="https://github.com/zphang/minimal-gpt-neox-20b">the repo</a>!</p>

(For instance, I'm resorting to fp32 for some operations to avoid NaNs, which leads to greater memory usage than is necessary.)
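The fp32 workaround usually looks like this: activations stay in a half-precision dtype such as `bfloat16`, but the inputs to a numerically sensitive step (softmax is a common choice) are upcast to `float32` and the result is cast back. This is a generic sketch, not the repo's code. Which operations the repo actually runs in fp32 isn't stated here, and `softmax_in_fp32` is a hypothetical name. The sketch also shows where the extra memory comes from: the float32 intermediate is twice the size of the bfloat16 input.

```python
import jax
import jax.numpy as jnp

def softmax_in_fp32(scores: jax.Array) -> jax.Array:
    """Hypothetical helper: softmax computed in float32, returned in the input dtype."""
    # Upcast half-precision scores so the exp/normalize step runs in float32.
    scores_fp32 = scores.astype(jnp.float32)
    probs = jax.nn.softmax(scores_fp32, axis=-1)
    # Cast back so later matmuls run in the original (cheaper) dtype.
    return probs.astype(scores.dtype)

scores = jnp.ones((2, 4, 8), dtype=jnp.bfloat16)
probs = softmax_in_fp32(scores)
# The float32 intermediate takes twice the bytes of the bfloat16 input.
print(scores.nbytes, scores.astype(jnp.float32).nbytes)  # 128 256
```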

In [None]:
!rm -rf minimal-gpt-neox-20b
!git clone https://github.com/zphang/minimal-gpt-neox-20b.git

In [None]:
!pip install -r minimal-gpt-neox-20b/requirements_flax.txt

In [None]:
# Download tokenizer
!wget https://mystic.the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/20B_tokenizer.json -P 20B_checkpoint
# Download weights, 8 files in parallel (usually takes ~5 min)
!cat ./minimal-gpt-neox-20b/assets/weights_urls.txt | xargs -n 1 -P 8 wget -P 20B_checkpoint -q
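The `xargs -n 1 -P 8` pipeline passes one URL at a time to `wget` and keeps up to 8 downloads running at once. The following is a rough pure-Python equivalent, included only to illustrate what the shell command does. It uses a thread pool. `download_all` is a hypothetical helper, and it assumes each URL's last path component makes a suitable filename, which is roughly how `wget -P` names files.

```python
import concurrent.futures
import os
import urllib.request
from typing import Callable, Iterable, List

def download_all(
    urls: Iterable[str],
    out_dir: str,
    workers: int = 8,
    fetch: Callable[[str, str], object] = urllib.request.urlretrieve,
) -> List[str]:
    """Hypothetical helper: download each URL into out_dir, up to `workers` at a time."""
    os.makedirs(out_dir, exist_ok=True)
    # Like xargs, treat each non-blank line as one argument.
    clean = [u.strip() for u in urls if u.strip()]

    def fetch_one(url: str) -> str:
        # Name the file after the last path component of the URL.
        path = os.path.join(out_dir, os.path.basename(url))
        fetch(url, path)
        return path

    # Equivalent of `-P 8`: at most `workers` downloads in flight at once.
    with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as pool:
        return list(pool.map(fetch_one, clean))

# Usage (same inputs as the shell cell above):
# with open("./minimal-gpt-neox-20b/assets/weights_urls.txt") as f:
#     download_all(f, "20B_checkpoint")
```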

In [None]:
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()
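A quick sanity check that `setup_tpu()` worked is to list the devices JAX can see. This check is not part of the original notebook. On a Colab TPU runtime it would typically report 8 devices with platform `tpu`. If it reports `cpu`, the TPU was not picked up. `describe_devices` is a hypothetical helper.

```python
import jax

def describe_devices() -> str:
    """Hypothetical helper: summarize the devices visible to JAX."""
    # jax.devices() lists the devices of the default backend;
    # .platform is e.g. "tpu", "gpu", or "cpu".
    devices = jax.devices()
    return f"{len(devices)} x {devices[0].platform}"

print(describe_devices())
```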

In [None]:
import sys
sys.path += ["./minimal-gpt-neox-20b"]
from flax.core import frozen_dict
import jax.numpy as jnp
import numpy as np
import tokenizers
from minimal20b_flax import create, model_xmap, generate

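import html
import IPython

def show(input_string, output_string):
  # Escape both strings so characters such as "<" or "&" in the prompt or the
  # generated text are shown literally instead of being parsed as HTML;
  # white-space: pre-wrap keeps the model's line breaks visible.
  display(IPython.display.HTML(
      "<div style='width:800px; font-family: monospace; white-space: pre-wrap;'>{}<b style='color:orange;'>{}</b></div>".format(
          html.escape(input_string), html.escape(output_string))))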

# Start with an empty function cache (CACHED_FUNCS presumably holds compiled model functions)
model_xmap.CACHED_FUNCS = {}

In [None]:
# Load weights (~10-15 min)
params = create.colab_load_model_weights_for_xmap("./20B_checkpoint/")

In [None]:
tokenizer = tokenizers.Tokenizer.from_file("./20B_checkpoint/20B_tokenizer.json")
neox_model = model_xmap.GPTNeoX20BModel(
    generate_length=100,  # How many tokens to generate
    sampler_args=frozen_dict.freeze({"temp": 1.})  # Sampling temperature
)
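`sampler_args={"temp": 1.}` sets the sampling temperature. In the usual formulation, logits are divided by `temp` before sampling. Values below 1 make sampling greedier, values above 1 make it more random, and `temp=1.0` samples from the model's unmodified distribution. The sketch below shows that standard technique using `jax.random.categorical`. It assumes, without confirming it, that the repo's sampler works the same way, and `sample_next_token` is a hypothetical name.

```python
import jax
import jax.numpy as jnp

def sample_next_token(key: jax.Array, logits: jax.Array, temp: float) -> jax.Array:
    """Hypothetical helper: sample one token id per row from softmax(logits / temp)."""
    # Scaling by 1/temp sharpens (temp < 1) or flattens (temp > 1) the distribution.
    # jax.random.categorical takes unnormalized log-probabilities, so no explicit softmax.
    return jax.random.categorical(key, logits / temp, axis=-1)

key = jax.random.PRNGKey(0)
logits = jnp.array([[0.0, 10.0, 0.0], [5.0, 0.0, 0.0]])
tokens = sample_next_token(key, logits, temp=0.1)  # low temperature: near-greedy
print(tokens)
```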

In [None]:
input_string = """GPT-NeoX-20B is a 20B-parameter autoregressive Transformer model developed by EleutherAI."""

In [None]:
# The first inference call is slow (~3 min) because it triggers JIT compilation,
# and the model is recompiled whenever the context length, output length, or
# temperature changes. Otherwise, subsequent calls should be much faster (perhaps ~20 s).
output = generate.generate(
    input_string=input_string,
    neox_model=neox_model,
    params=params,
    tokenizer=tokenizer,
)
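The recompilation described in the comment above matches how `jax.jit` caches compiled functions. A compiled version is kept per combination of input shapes and static-argument values. A new context length (a new input shape) forces a fresh trace and compile. So does a new output length or temperature, if those are passed as static arguments. The toy `fake_generate` below is hypothetical and is not the repo's `generate`. It simply counts how often JAX traces it. Static arguments must be hashable, which may be why `sampler_args` is wrapped in a `FrozenDict` rather than a plain `dict`, though that is a guess.

```python
from functools import partial

import jax
import jax.numpy as jnp

trace_count = 0  # counts traces, i.e. how many times JAX (re)compiles

@partial(jax.jit, static_argnames=("generate_length", "temp"))
def fake_generate(context: jax.Array, generate_length: int, temp: float) -> jax.Array:
    """Hypothetical stand-in for a generation step; not the repo's generate()."""
    global trace_count
    trace_count += 1  # Python side effect: runs only while tracing, not on cached calls
    # The output shape depends on generate_length, which is why it must be static.
    new_tokens = jnp.zeros((context.shape[0], generate_length)) + temp
    return jnp.concatenate([context, new_tokens], axis=-1)

ctx = jnp.ones((1, 5))
fake_generate(ctx, generate_length=100, temp=1.0)                # 1st call: trace + compile
fake_generate(ctx, generate_length=100, temp=1.0)                # same shapes/statics: cached
fake_generate(ctx, generate_length=100, temp=0.7)                # new temperature: recompile
out = fake_generate(jnp.ones((1, 7)), generate_length=100, temp=1.0)  # new context length: recompile
print(trace_count)  # 3
```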

In [None]:
show(input_string, output["generated_string"])