In [None]:
# Adapted mostly from https://github.com/kingoflolz/mesh-transformer-jax,
# which is licensed under the Apache License 2.0; this notebook is presumably covered by the same license.

In [None]:
# Warning: this download is large and slow. After the first download, consider copying the
# archive to Google Drive and restoring it from there on later runs (see the next two cells).
# `wget -c` resumes a partially downloaded file instead of starting over.
!time wget -c https://the-eye.eu/public/AI/GPT-J-6B/step_383500_slim.tar.zstd

In [None]:
# Copy the archive to Google Drive so the download above does not have to be repeated.
# from google.colab import drive
# drive.mount('/content/drive/')
# !cp /content/step_383500_slim.tar.zstd /content/drive/MyDrive/step_383500_slim.tar.zstd

In [None]:
# Restore the archive from Google Drive instead of downloading it again.
# from google.colab import drive
# drive.mount('/content/drive/')
# !cp /content/drive/MyDrive/step_383500_slim.tar.zstd  /content/step_383500_slim.tar.zstd

In [None]:
!apt install zstd
!time tar -I zstd -xf step_383500_slim.tar.zstd

In [None]:
!pip install numpy tqdm requests optax==0.0.9 dm-haiku==0.0.9 chex==0.1.5 jax==0.3.25 jaxlib==0.3.25 transformers progressbar2 git+https://github.com/Zurnaz/mesh-transformer-jax.git@tpu_driver

In [None]:
import os
import requests 
from jax.config import config

# TPU driver version to request from the TPU host. The commented-out values are
# pinned alternatives to try if the nightly driver misbehaves.
driver_version = "tpu_driver_nightly"
# driver_version = "tpu_driver_20221011"
# driver_version = "tpu_driver0.2"

# Colab exposes the TPU address as COLAB_TPU_ADDR, Kaggle as TPU_NAME.
tpu_address = os.environ.get('COLAB_TPU_ADDR', '') or os.environ.get('TPU_NAME', '')
if not tpu_address:
    raise RuntimeError(
        "No TPU address found: neither COLAB_TPU_ADDR (Colab) nor TPU_NAME (Kaggle) is set. "
        "Make sure a TPU runtime is selected."
    )

# The address may carry a grpc:// scheme and a port; the version endpoint needs only the host.
tpu_address = tpu_address.replace("grpc://", "")
tpu_address_without_port = tpu_address.split(':', 1)[0]

# Request the chosen driver version from the TPU host's version endpoint (port 8475).
# A timeout keeps the cell from hanging forever on an unreachable host, and checking the
# status surfaces a bad driver_version here rather than as an opaque JAX error later.
url = f'http://{tpu_address_without_port}:8475/requestversion/{driver_version}'
response = requests.post(url, timeout=300)
response.raise_for_status()


# The following is required to use TPU Driver as JAX's backend.
config.FLAGS.jax_xla_backend = "tpu_driver"
config.FLAGS.jax_backend_target = "grpc://" + tpu_address


In [None]:
import time

import jax
from jax.experimental import maps
import numpy as np
import optax
import transformers

from mesh_transformer.checkpoint import read_ckpt_lowmem, read_ckpt
from mesh_transformer.sampling import nucleaus_sample
from mesh_transformer.transformer_shard import CausalTransformer, CausalTransformerV2

# GPT-J-6B hyperparameters; these describe the checkpoint loaded below and must agree with it.
params = {
  "layers": 28,
  "d_model": 4096,
  "n_heads": 16,
  "n_vocab": 50400,
  "norm": "layernorm",
  "pe": "rotary",
  "pe_rotary_dims": 64,
  "d_head": 256,
  "seq": 2048,
  "cores_per_replica": 8,
  "per_replica_batch": 1,
}

per_replica_batch = params["per_replica_batch"]
cores_per_replica = params["cores_per_replica"]
seq = params["seq"]


params["sampler"] = nucleaus_sample

# here we "remove" the optimizer parameters from the model (as we don't need them for inference)
params["optimizer"] = optax.scale(0)

# Each model replica is sharded over `cores_per_replica` devices, so the device count must be
# a multiple of it. Fail loudly here (e.g. when the TPU is not connected and JAX falls back to a
# single CPU device) instead of with a cryptic reshape error below.
n_devices = jax.device_count()
if n_devices % cores_per_replica != 0:
    raise RuntimeError(
        f"Expected a multiple of {cores_per_replica} devices, but JAX sees {n_devices}: "
        f"{jax.devices()}. Is the TPU runtime connected?"
    )

# First mesh axis ('dp') spans the replicas; second axis ('mp') spans the cores of one replica.
mesh_shape = (n_devices // cores_per_replica, cores_per_replica)
print("mesh_shape", mesh_shape)
devices = np.array(jax.devices()).reshape(mesh_shape)
print("devices", devices)
global_mesh = maps.Mesh(devices, ('dp', 'mp'))
maps.thread_resources.env = maps.ResourceEnv(physical_mesh=global_mesh, loops=())

tokenizer = transformers.GPT2Tokenizer.from_pretrained('gpt2')

# per_replica_batch sequences for every replica along the 'dp' axis.
total_batch = per_replica_batch * mesh_shape[0]

In [None]:
# Instantiate the model from the hyperparameters above; its state is overwritten by the
# checkpoint loaded in the next cell.
network = CausalTransformer(params)

In [None]:
# Directory produced by extracting step_383500_slim.tar.zstd.
ckpt_dir = "step_383500/"
# Number of shards the checkpoint was saved with (passed as read_ckpt's input shard count);
# presumably 8 for the released GPT-J checkpoint. It is re-sharded to `cores_per_replica`.
ckpt_shards_in = 8

# Alternative loader (presumably lower host-RAM usage); try it if loading runs out of memory:
# network.state = read_ckpt_lowmem(network.state, ckpt_dir, devices.shape[1])
network.state = read_ckpt(network.state, ckpt_dir, ckpt_shards_in, shards_out=cores_per_replica)
# Presumably places the loaded state onto the device mesh for xmap-based inference.
network.state = network.move_xmap(network.state, np.zeros(cores_per_replica))

In [None]:
def pad_context(tokens, seq_len, batch_size):
    """Left-pad a token list to ``seq_len`` and replicate it ``batch_size`` times.

    Contexts longer than ``seq_len`` are truncated to their last ``seq_len`` tokens,
    because the model takes a fixed-length context window.

    Args:
        tokens: sequence of integer token ids.
        seq_len: length of the context window (must be positive).
        batch_size: number of identical rows to produce.

    Returns:
        ``(batched_tokens, lengths)``: a ``(batch_size, seq_len)`` uint32 array with the
        tokens right-aligned and zero padding on the left, and a ``(batch_size,)`` uint32
        array holding the number of real (non-padding) tokens in each row.

    Raises:
        ValueError: if ``seq_len`` is not positive.
    """
    if seq_len <= 0:
        raise ValueError(f"seq_len must be positive, got {seq_len}")
    # Keep only the most recent seq_len tokens (slicing is a no-op for shorter inputs).
    kept = np.asarray(tokens, dtype=np.uint32)[-seq_len:]
    batched_tokens = np.zeros((batch_size, seq_len), dtype=np.uint32)
    batched_tokens[:, seq_len - len(kept):] = kept
    lengths = np.full(batch_size, len(kept), dtype=np.uint32)
    return batched_tokens, lengths


def infer(context, top_p=0.9, temp=1.0, gen_len=512):
    """Sample ``total_batch`` completions of ``context`` from the loaded model.

    Args:
        context: prompt text. If it encodes to more than ``seq`` tokens, only the
            last ``seq`` tokens are used.
        top_p: nucleus-sampling probability mass.
        temp: sampling temperature.
        gen_len: number of tokens to generate.

    Returns:
        A list with one decoded completion per batch row (presumably without the prompt,
        since only the generated ids are decoded).
    """
    tokens = tokenizer.encode(context)
    if len(tokens) > seq:
        # Without truncation the left-padding amount would be negative.
        print(f"context truncated to its last {seq} tokens")
    batched_tokens, length = pad_context(tokens, seq, total_batch)

    start = time.time()
    output = network.generate(
        batched_tokens,
        length,
        gen_len,
        {"top_p": np.ones(total_batch) * top_p, "temp": np.ones(total_batch) * temp},
    )

    # output[1][0] holds the generated token ids; index 0 of the last axis (presumably a
    # singleton) yields one id sequence per batch row.
    decoded_tokens = output[1][0]
    samples = [tokenizer.decode(row) for row in decoded_tokens[:, :, 0]]

    print(f"completion done in {time.time() - start:.2f}s")
    return samples

# Quick smoke test: generate completions for a short prompt with the default sampling settings.
print(infer("EleutherAI is"))

In [None]:
#@title  { form-width: "300px" }
# Sampling settings. The temperature presumably scales the logits (0 would divide by zero),
# so the temp slider starts at 0.1 and non-positive values are rejected below.
top_p = 0.9 #@param {type:"slider", min:0, max:1, step:0.1}
temp = 1.0 #@param {type:"slider", min:0.1, max:1, step:0.1}
if temp <= 0:
    raise ValueError(f"temp must be > 0, got {temp}")

# Alternative example prompt:
#context = """In a shocking finding, scientist discovered a herd of unicorns living in a remote, previously unexplored valley, in the Andes Mountains. Even more surprising to the researchers was the fact that the unicorns spoke perfect English."""
context = """Google colab is"""
print(infer(top_p=top_p, temp=temp, gen_len=512, context=context))