In [None]:
# Mostly adapted from https://github.com/kingoflolz/mesh-transformer-jax,
# with the PyTorch lazy-loading code taken from https://github.com/KoboldAI/KoboldAI-Client
# (tpu_mtj_backend.py and torch_lazy_loader.py).
# This code is therefore probably under a mix of the AGPL and Apache licenses.

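Example of using the PyTorch lazy loader to convert a PyTorch checkpoint into a sharded JAX checkpoint and then load it onto a TPU. The conversion reads one tensor at a time from the checkpoint archive instead of loading the whole PyTorch state dict into RAM.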

In [None]:
# !rm pytorch_model.bin

In [None]:
# Download the GPT-Neo 2.7B PyTorch weights; `wget -c` resumes a partially downloaded file.
!time wget -c https://huggingface.co/EleutherAI/gpt-neo-2.7B/resolve/main/pytorch_model.bin

In [None]:
# !time wget -c https://huggingface.co/EleutherAI/gpt-neo-1.3B/resolve/main/pytorch_model.bin

In [None]:
# Fetch KoboldAI helper files. torch_lazy_loader.py is imported in the "Load checkpoint" cell;
# utils.py is presumably a dependency of it (TODO confirm).
# NOTE(review): these URLs track the head of the `united` branch, so upstream changes can
# break this notebook; consider pinning a specific commit.
!wget -c https://github.com/henk717/KoboldAI/raw/united/requirements_mtj.txt
!wget -c https://github.com/henk717/KoboldAI/raw/united/torch_lazy_loader.py
!wget -c https://github.com/henk717/KoboldAI/raw/united/utils.py

In [None]:
# !pip install -r /content/requirements_mtj.txt

In [None]:
# Install the JAX 0.3.25 / TPU-driver stack this notebook targets, plus mesh-transformer-jax
# from the Zurnaz fork's `tpu_driver` branch.
# NOTE(review): numpy, tqdm, requests, transformers and progressbar2 are unpinned, so results
# may drift as those packages change.
!pip install numpy tqdm requests optax==0.0.9 dm-haiku==0.0.9 chex==0.1.5 jax==0.3.25 jaxlib==0.3.25 transformers progressbar2 git+https://github.com/Zurnaz/mesh-transformer-jax.git@tpu_driver

In [None]:
# Allows running locally on 8 devices using CPU vs TPU
# import os
# os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'

In [None]:

import os
import requests
from jax.config import config

# TPU driver version to request from the TPU host; alternatives kept for reference.
driver_version = "tpu_driver_nightly"
# driver_version = "tpu_driver_20221011"
# driver_version = "tpu_driver0.2"

# Colab exposes the TPU address via COLAB_TPU_ADDR, Kaggle via TPU_NAME.
tpu_address = os.environ.get('COLAB_TPU_ADDR', '') or os.environ.get('TPU_NAME', '')
if not tpu_address:
    raise RuntimeError(
        "No TPU address found: set COLAB_TPU_ADDR (Colab) or TPU_NAME (Kaggle)."
    )

tpu_address = tpu_address.replace("grpc://", "")
tpu_address_without_port = tpu_address.split(':', 1)[0]

# Ask the TPU host to switch to `driver_version`. Fail loudly if the request is
# rejected rather than continuing as if it had succeeded.
url = f'http://{tpu_address_without_port}:8475/requestversion/{driver_version}'
response = requests.post(url)
response.raise_for_status()

# The following is required to use TPU Driver as JAX's backend.
config.FLAGS.jax_xla_backend = "tpu_driver"
config.FLAGS.jax_backend_target = "grpc://" + tpu_address


In [None]:
# Configuration used with the GPT-Neo checkpoint downloaded above.
mtj_config_1 = dict(
    compat="neo",
    pe="fixed",
    layers=32,
    d_model=2560,
    n_heads=20,
    n_vocab=50257,
    n_vocab_padding=0,  # recomputed from n_vocab and cores_per_replica in a later cell
    norm="layernorm",
    # pe_rotary_pct=0.25,
    # pe_rotary_dims=64,
    # d_head=256,
    seq=2048,
    cores_per_replica=4,
    per_replica_batch=1,
    # d_embed=5120,
    # early_cast=True,
    do_layer_norm_before=True,
)

# Alternative rotary-embedding configuration (28 layers, d_model 4096, vocab 50400),
# presumably GPT-J-6B.
# NOTE(review): the conversion cell only maps GPT-Neo parameter names
# (transformer.h.N.attn.attention.*), so this config cannot be converted here as-is.
mtj_config_2 = dict(
    layers=28,
    d_model=4096,
    n_heads=16,
    n_vocab=50400,
    norm="layernorm",
    pe="rotary",
    pe_rotary_dims=64,
    d_head=256,
    seq=2048,
    cores_per_replica=8,
    per_replica_batch=1,
)

# Active configuration for the rest of the notebook.
mtj_config = mtj_config_1

In [None]:
# Smallest number of extra vocabulary rows that makes n_vocab divisible by the number of
# shards (e.g. 50257 tokens over 4 shards -> 3 padding rows).
mtj_config["n_vocab_padding"] = (-mtj_config["n_vocab"]) % mtj_config["cores_per_replica"]

In [None]:
# Output directory for the converted JAX shards (the "Load checkpoint" cell assigns the same value again).
checkpoint_dir = "jax_checkpoint" # jax_checkpoint step_383500

In [None]:
# Shorthand globals derived from the active config; several are read by the conversion callback.
total_shards, d_model, layers = (
    mtj_config["cores_per_replica"],
    mtj_config["d_model"],
    mtj_config["layers"],
)
# NOTE(review): the conversion callback sets its own local `pieces = 16`, so this global
# does not control how many .npz files are written per shard.
pieces = 16
padding_rows = mtj_config["n_vocab_padding"]

In [None]:
#@title Load checkpoint
import torch_lazy_loader
import os
from termcolor import colored
from IPython.display import clear_output
import torch

# Input / output locations. work_dir is used as a raw prefix, so keep its trailing
# slash if you set one (e.g. "/content/").
work_dir = "" # "/content/"
path_to_checkpoint = f"{work_dir}pytorch_model.bin"
checkpoint_dir = "jax_checkpoint" # jax_checkpoint step_383500
config_path = "config.json"  # NOTE(review): not referenced elsewhere in this notebook
  


In [None]:
#@title Convert checkpoint to be JAX-compatible
import zipfile
import functools
from IPython.display import clear_output
import torch
import numpy as np
import jax.numpy as jnp

def callback(model_dict, f, **_):
  """Convert a lazily loaded GPT-Neo PyTorch checkpoint into sharded JAX ``.npz`` files.

  Passed to ``torch_lazy_loader.use_lazy_torch_load``. Tensors are read one at a
  time straight from the checkpoint's zip archive, converted to bfloat16 and
  resharded. The converted arrays are collected in memory and written out at the end.

  Args:
    model_dict: parameter name -> ``torch_lazy_loader.LazyTensor`` (entries expose
      ``key``, ``seek_offset``, ``shape``, ``dtype`` and ``materialize``). Each
      converted entry is replaced by a ``meta`` tensor.
    f: path to the ``pytorch_model.bin`` zip archive.
    **_: any further keyword arguments from the loader (ignored).

  Reads the notebook globals ``total_shards``, ``layers``, ``padding_rows``,
  ``checkpoint_dir`` and ``work_dir``. Writes ``{checkpoint_dir}/shard_{i}/{j}.npz``
  for every shard ``i`` and piece ``j``.

  Raises:
    RuntimeError: if a requested parameter was already materialized (duplicate key).
    ValueError: if a converted parameter contains NaN or inf.
  """
  import sys  # needed for the duplicate-key diagnostic below

  # Every shard file is rewritten below, so an existing directory is fine.
  for i in range(total_shards):
      os.makedirs(f"{checkpoint_dir}/shard_{i}", exist_ok=True)
  # Number of .npz files per shard. NOTE(review): mesh_transformer's reader is believed
  # to expect exactly 16 — keep in sync (TODO confirm).
  pieces = 16

  def reshard_reverse(x, old_shape, is_shard_bias=False):
      """Lay out a single-replica array ``x`` (leading dim 1) as ``old_shape`` = (total_shards, ...).

      2-D: if the trailing dim already matches, row 0 is replicated on every shard
      (shard biases are also divided by ``total_shards``, presumably because the
      per-shard contributions are summed later); otherwise ``x`` is reshaped.
      3-D: split along axis 2 (case 1) or axis 1 (case 2), whichever matches.
      """
      if len(x.shape) == 1:
          raise Exception(f"unimplemented, 1-D input {x.shape}, {old_shape}")

      elif len(x.shape) == 2:
          if old_shape[1] == x.shape[1]:
              # LayerNorm parameters / shard biases: replicate.
              if not is_shard_bias:
                  out = np.tile(x[0:1], (total_shards, 1))
              else:
                  out = np.tile(x[0:1], (total_shards, 1)) / total_shards
          else:
              # Sharded bias: split contiguously across shards.
              out = x.reshape(old_shape)

      elif len(x.shape) == 3:
          if x.shape[0] * x.shape[2] == old_shape[2]:
              # case 1: contiguous split along the first (row) axis
              out = x.reshape(old_shape)
          elif x.shape[0] * x.shape[1] == old_shape[1]:
              # case 2: split along the last (column) axis
              out = jnp.transpose(x.reshape((old_shape[1], old_shape[0], old_shape[2])), (1, 0, 2))
          else:
              raise Exception(f"unimplemented, {x.shape}, {old_shape}")
      else:
          raise Exception(f"unimplemented, {x.shape}")
      return out

  def get_old_shape(t, dim=2):
      """Per-shard shape of ``t`` when split ``total_shards`` ways along ``dim`` (1=rows, 2=cols)."""
      if len(t.shape) == 2:
          shard_shape = t.shape
          if dim == 1:
              assert shard_shape[0] % total_shards == 0
              return (shard_shape[0] // total_shards, shard_shape[1])
          elif dim == 2:
              assert shard_shape[1] % total_shards == 0
              return (shard_shape[0], shard_shape[1] // total_shards)
          else:
              raise ValueError(f"unsupported dim {dim}")
      if len(t.shape) == 1:
          assert t.shape[0] % total_shards == 0
          return (t.shape[0] // total_shards,)
      else:
          raise ValueError(f"unsupported shape {t.shape}")

  def split(a, n):
      """Split sequence ``a`` into ``n`` contiguous, nearly equal slices (earlier slices get the remainder)."""
      k, m = divmod(len(a), n)
      return (a[i * k + min(i, m):(i + 1) * k + min(i + 1, m)] for i in range(n))

  def save(cpu_flattened):
      """Write shard ``i``'s slice of every array into ``pieces`` .npz files under shard_{i}/."""
      chunks = list(split(cpu_flattened, pieces))  # identical for every shard
      for i in range(total_shards):
          for j, chunk in enumerate(chunks):
              with open(f"{checkpoint_dir}/shard_{i}/{j}.npz", "wb") as out_file:
                  np.savez(out_file, *(c[i] for c in chunk))

  # (name, is_shard_bias, split_dim); split_dim None means "replicate on every shard".
  transforms = [
      ("transformer.wpe.weight", False, 2),
      ("transformer.wte.weight", False, 1)
  ]

  # Layer indices are sorted as strings ("0", "1", "10", ...), presumably to match the
  # order in which the JAX parameter tree is flattened — do not switch to numeric order
  # without confirming.
  layer_names = sorted(map(str, range(layers)))
  for layer in layer_names:
      transforms.extend([
          (f"transformer.h.{layer}.attn.attention.q_proj.weight", False, 2),
          (f"transformer.h.{layer}.attn.attention.v_proj.weight", False, 2),
          (f"transformer.h.{layer}.attn.attention.k_proj.weight", False, 2),
          (f"transformer.h.{layer}.attn.attention.out_proj.bias", True, None),
          (f"transformer.h.{layer}.attn.attention.out_proj.weight", False, 1),
          (f"transformer.h.{layer}.mlp.c_fc.bias", False, 1),
          (f"transformer.h.{layer}.mlp.c_fc.weight", False, 2),
          (f"transformer.h.{layer}.mlp.c_proj.bias", True, None),
          (f"transformer.h.{layer}.mlp.c_proj.weight", False, 1),
          (f"transformer.h.{layer}.ln_1.bias", False, None),
          (f"transformer.h.{layer}.ln_1.weight", False, None),
          (f"transformer.h.{layer}.ln_2.bias", False, None),
          (f"transformer.h.{layer}.ln_2.weight", False, None),
      ])

  transforms.extend([
      ("transformer.ln_f.bias", False, None),
      ("transformer.ln_f.weight", False, None),
  ])

  checkpoint = []
  # Top-level folder name inside the archive (e.g. "pytorch_model"), used as a fallback
  # when tensor storages are not under "archive/".
  zipfolder = os.path.basename(os.path.normpath(f)).split('.')[0]
  with zipfile.ZipFile(f, "r") as z:
      member = None            # currently open tensor-storage file inside the zip
      last_storage_key = None
      current_offset = 0       # bytes already consumed from `member`
      try:
          for i, (key, is_shard_bias, split_dim) in enumerate(transforms):
              print(i, key)

              lazy = model_dict[key]
              # Check this before touching `.key`: an already-materialized entry is a
              # plain (meta) tensor and has no storage key.
              if not isinstance(lazy, torch_lazy_loader.LazyTensor):
                  error = f"Duplicate key {repr(key)}"
                  print("\n\nERROR:  " + error, file=sys.stderr)
                  raise RuntimeError(error)

              # Zip members are read forward only: reopen when switching storage or
              # when the wanted offset is behind the current position.
              storage_key = lazy.key
              if storage_key != last_storage_key or lazy.seek_offset < current_offset:
                  last_storage_key = storage_key
                  if member is not None:
                      member.close()
                  try:
                      member = z.open(f"archive/data/{storage_key}")
                  except KeyError:
                      member = z.open(f"{zipfolder}/data/{storage_key}")
                  current_offset = 0
              if current_offset != lazy.seek_offset:
                  # Skip forward to this tensor's bytes.
                  member.read(lazy.seek_offset - current_offset)
                  current_offset = lazy.seek_offset

              size = functools.reduce(lambda x, y: x * y, lazy.shape, 1)
              dtype = lazy.dtype
              if dtype is torch.bool:
                  nbytes = size
              else:
                  info = torch.finfo if dtype.is_floating_point else torch.iinfo
                  nbytes = size * (info(dtype).bits >> 3)
              tensor = lazy.materialize(member, map_location="cpu")
              model_dict[key] = tensor.to("meta")
              current_offset += nbytes

              params = tensor

              # Pad the embedding with `padding_rows` zero rows so the vocabulary splits
              # evenly across shards; the padding only adds junk logits at the end of the
              # output without affecting the real ones.
              if key in ("transformer.wte.weight", "lm_head.weight"):
                  padding = torch.zeros(padding_rows, params.shape[-1], dtype=params.dtype, device=params.device)
                  params = torch.cat((params, padding), dim=0)
              # torch.nn.Linear stores the transpose of what haiku.Linear expects; only
              # 2-D weights need un-transposing (biases / LayerNorm params are 1-D).
              if not any(s in key for s in ("wte", "wpe")) and params.dim() == 2:
                  params = params.T

              if split_dim is not None:
                  old_shape = (total_shards,) + get_old_shape(params, split_dim)
              else:
                  old_shape = (total_shards, params.shape[0],)

              params = np.asarray(params[None], dtype=jnp.bfloat16)
              params = reshard_reverse(params, old_shape, is_shard_bias=is_shard_bias)

              if np.isnan(params).any() or np.isinf(params).any():
                  raise ValueError(f"bfloat16 overflow/underflow in {key}")

              assert params.shape == old_shape
              checkpoint.append(params)
      finally:
          if member is not None:
              member.close()

  # Append the checkpoint step number (arbitrary 0 — fine for inference-only use).
  checkpoint.append(np.zeros(total_shards, dtype=np.int32))

  print("saving")
  save(checkpoint)
  del checkpoint
  print(colored(f"DONE! The JAX checkpoint is now stored at {work_dir}{checkpoint_dir}", "green"))

In [None]:
# Remove any previous conversion output; -f keeps this quiet on a fresh run where the
# directory does not exist yet.
!rm -rf "{checkpoint_dir}"

In [None]:

# Run the conversion: inside this context, torch.load presumably hands the lazily loaded
# state dict to `callback` (which streams each tensor from the zip archive into the JAX
# shard files) instead of reading every tensor into RAM.
# NOTE(review): newer torch releases default torch.load(weights_only=True), which may not
# be compatible with the lazy loader's unpickling — confirm against the installed torch.
with torch_lazy_loader.use_lazy_torch_load(callback=callback, dematerialized_modules=True):
  torch_checkpoint = torch.load(path_to_checkpoint, map_location="cpu")

In [None]:
import time

import jax
from jax.experimental import maps
import numpy as np
import optax
import transformers

from mesh_transformer.checkpoint import read_ckpt_lowmem, read_ckpt
from mesh_transformer.sampling import nucleaus_sample
from mesh_transformer.transformer_shard import CausalTransformer, CausalTransformerV2

per_replica_batch = mtj_config["per_replica_batch"]
cores_per_replica = mtj_config["cores_per_replica"]
seq = mtj_config["seq"]

# Sampler used by network.generate (the library spells it "nucleaus").
mtj_config["sampler"] = nucleaus_sample

# here we "remove" the optimizer parameters from the model (as we don't need them for inference)
# mtj_config["optimizer"] = optax.scale(0)

# One data-parallel replica spanning the first `cores_per_replica` devices; any remaining
# devices stay idle. (To use every device instead:
#   mesh_shape = (jax.device_count() // cores_per_replica, cores_per_replica)
#   devices = np.array(jax.devices()).reshape(mesh_shape))
available_devices = jax.devices()
if len(available_devices) < cores_per_replica:
    raise RuntimeError(
        f"cores_per_replica={cores_per_replica} but only {len(available_devices)} JAX devices are available"
    )
mesh_shape = (1, cores_per_replica)
devices = np.array(available_devices[:cores_per_replica]).reshape(mesh_shape)

print("mesh_shape", mesh_shape)
print("devices", devices)

global_mesh = maps.Mesh(devices, ('dp', 'mp'))
maps.thread_resources.env = maps.ResourceEnv(physical_mesh=global_mesh, loops=())

tokenizer = transformers.GPT2Tokenizer.from_pretrained('gpt2')

# Batch = per-replica batch times the number of data-parallel replicas in the mesh
# (mesh_shape[0]). Deriving it from jax.device_count() would over-count whenever the
# mesh uses only a subset of the devices.
total_batch = per_replica_batch * mesh_shape[0]

In [None]:
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple, TypeVar
class _EmptyState(NamedTuple):
    pass

class _DummyOptimizer:
    def init(*args, **kwargs):
        return _EmptyState()
# Inference only: give the network an optimizer whose init() returns an empty state.
mtj_config["optimizer"] = _DummyOptimizer()

In [None]:
# network = CausalTransformer(mtj_config)
# dematerialized=True is assumed to skip allocating initial weights, since they are
# loaded from the converted checkpoint below — TODO confirm against mesh_transformer.
network = CausalTransformer(mtj_config, dematerialized=True)

In [None]:
from mesh_transformer.util import to_bf16, to_f16, to_f32


def _state_to_f32(state, _unused_batch):
    """xmap body: apply ``to_f32`` to the network state. The second argument is
    ignored; it only carries the 'batch' named axis mapped onto the 'dp' mesh axis."""
    return to_f32(state)


# Places the per-shard state on the 'mp' mesh axis while applying to_f32.
move_xmap = jax.experimental.maps.xmap(
    fun=_state_to_f32,
    in_axes=(["shard", ...], ["batch", ...]),
    out_axes=["shard", ...],
    axis_resources={'shard': 'mp', 'batch': 'dp'},
)

In [None]:
from mesh_transformer.checkpoint import read_ckpt_lowmem, read_ckpt
# network.state["opt_state"] = None # only used when re-running block to reset network
# Load the shards written by the conversion cell into network.state. shards_in and
# shards_out are both cores_per_replica (presumably meaning no resharding).
network.state = read_ckpt(network.state, f"{checkpoint_dir}/", cores_per_replica, shards_out=cores_per_replica)


In [None]:
# Apply move_xmap to the loaded state; the zeros array only supplies the dummy 'batch'
# axis that move_xmap ignores.
network.state = move_xmap(network.state, np.zeros(cores_per_replica))

In [None]:
def infer(context, top_p=0.9, temp=1.0, gen_len=512):
    """Generate ``gen_len`` tokens continuing ``context`` and return the decoded completions.

    Args:
        context: prompt text, encoded with the global GPT-2 ``tokenizer``.
        top_p: nucleus-sampling probability mass passed to the sampler.
        temp: sampling temperature passed to the sampler.
        gen_len: number of new tokens to generate.

    Returns:
        A list of ``total_batch`` decoded strings (the same prompt is used for every
        batch entry). They presumably contain only the generated continuation.

    Uses the notebook globals ``tokenizer``, ``network``, ``seq`` and ``total_batch``.
    Prompts longer than ``seq`` tokens are truncated to their last ``seq`` tokens.
    """
    tokens = tokenizer.encode(context)
    # The model consumes a fixed window of `seq` tokens: keep the most recent ones if
    # the prompt is too long (a negative pad width would make np.pad raise), otherwise
    # left-pad with zeros.
    tokens = tokens[-seq:]
    provided_ctx = len(tokens)
    pad_amount = seq - provided_ctx

    padded_tokens = np.pad(tokens, ((pad_amount, 0),)).astype(np.uint32)
    batched_tokens = np.array([padded_tokens] * total_batch)
    length = np.ones(total_batch, dtype=np.uint32) * provided_ctx

    start = time.time()
    output = network.generate(
        batched_tokens,
        length,
        gen_len,
        {"top_p": np.ones(total_batch) * top_p, "temp": np.ones(total_batch) * temp},
    )

    # NOTE(review): output[1][0] is assumed to hold generated token ids shaped
    # (batch, gen_len, 1); the [:, :, 0] indexing relies on that.
    decoded_tokens = output[1][0]
    samples = [tokenizer.decode(o) for o in decoded_tokens[:, :, 0]]

    print(f"completion done in {time.time() - start:.3f}s")
    return samples

# print(infer("EleutherAI is")[0])
# print(infer("EleutherAI is"))

In [None]:
#@title  { form-width: "300px" }
top_p = 0.9 #@param {type:"slider", min:0, max:1, step:0.1}
temp = 1 #@param {type:"slider", min:0.1, max:1, step:0.1}

# The sampler presumably divides logits by the temperature, so temp must be positive.
if temp <= 0:
    raise ValueError(f"temp must be > 0, got {temp}")

#context = """In a shocking finding, scientist discovered a herd of unicorns living in a remote, previously unexplored valley, in the Andes Mountains. Even more surprising to the researchers was the fact that the unicorns spoke perfect English."""
context = """The world of tomorrow is going"""
print(infer(top_p=top_p, temp=temp, gen_len=10, context=context))