In [None]:
# Isolated code to load huggingface models into TPU
# Mix of AGPL and apache from https://github.com/VE-FORBRYDERNE/mesh-transformer-jax and https://github.com/KoboldAI/KoboldAI-Client

In [None]:

# Fetch the MTJ requirement list plus helper modules (torch_lazy_loader.py is
# imported later; utils.py is presumably a dependency of it — confirm).
# NOTE(review): these URLs follow the moving "united" branch rather than a
# pinned commit, so the downloaded code can change between runs.
!wget -c https://github.com/henk717/KoboldAI/raw/united/requirements_mtj.txt
!wget -c https://github.com/henk717/KoboldAI/raw/united/torch_lazy_loader.py
!wget -c https://github.com/henk717/KoboldAI/raw/united/utils.py
%pip install -r requirements_mtj.txt
# Explicit pins installed after requirements_mtj.txt; presumably they override
# whatever that file selects — TODO confirm the two sets do not conflict.
%pip install numpy tqdm requests optax==0.0.9 dm-haiku==0.0.9 chex==0.1.5 jax==0.3.25 jaxlib==0.3.25 transformers==4.28.0 progressbar2 sentencepiece 
# mesh-transformer-jax fork (tpu_driver branch); presumably provides the
# mesh_transformer package imported further down.
%pip install git+https://github.com/Zurnaz/mesh-transformer-jax.git@tpu_driver

In [None]:
# Simulate 8 local devices on the CPU.
# To run on CPU: uncomment the lines below, skip the TPU driver cell, and change move_xmap to use to_f32 instead of to_bf16.
# import os
# os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'

In [None]:
# LOADS on TPU
# Ask the TPU host to switch to a specific TPU driver build, then point JAX's
# tpu_driver backend at that host over gRPC.
import os
import warnings

import requests
from jax.config import config

driver_version = "tpu_driver_20221109"
# driver_version = "tpu_driver_20221011"
# driver_version = "tpu_driver0.2"

# Seconds to wait for the driver-version request. Without a timeout an
# unreachable TPU host would hang this cell indefinitely.
TPU_DRIVER_REQUEST_TIMEOUT_S = 300

# Colab publishes the TPU address in COLAB_TPU_ADDR, Kaggle in TPU_NAME.
tpu_address = os.environ.get('COLAB_TPU_ADDR', '') or os.environ.get('TPU_NAME', '')
if not tpu_address:
    raise RuntimeError(
        "No TPU address found: neither COLAB_TPU_ADDR (Colab) nor TPU_NAME (Kaggle) is set. "
        "Switch to a TPU runtime, or use the CPU-simulation cell instead of this one."
    )

tpu_address = tpu_address.replace("grpc://", "")
tpu_address_without_port = tpu_address.split(':', 1)[0]
url = f'http://{tpu_address_without_port}:8475/requestversion/{driver_version}'
response = requests.post(url, timeout=TPU_DRIVER_REQUEST_TIMEOUT_S)
if not response.ok:
    # Deliberately best-effort: keep going, but make the failure visible rather
    # than letting it surface later as an obscure JAX backend error.
    warnings.warn(f"TPU driver request {url} failed: HTTP {response.status_code}")


# The following is required to use TPU Driver as JAX's backend.
config.FLAGS.jax_xla_backend = "tpu_driver"
config.FLAGS.jax_backend_target = "grpc://" + tpu_address


In [None]:
def get_default_params():
    """Return a new dict of baseline MTJ model settings.

    Each call builds a fresh dict, so callers may mutate the result freely.
    generate_mtj_config() uses these values for any key the model's own
    config does not supply. The sampler is not part of this dict; it is
    attached to the config separately.
    NOTE(review): the numbers look like GPT-J 6B's shape (compat "j",
    28 layers, d_model 4096) — confirm before relying on them for other models.
    """
    return dict(
        compat="j",
        layers=28,
        d_model=4096,
        n_heads=16,
        n_vocab=50400,
        n_vocab_padding=0,
        norm="layernorm",
        pe="rotary",
        pe_rotary_dims=64,
        seq=2048,
        cores_per_replica=8,
        tokenizer_class="GPT2Tokenizer",
        tokenizer="gpt2",
    )
    

In [None]:
import os
import json
def generate_mtj_config(hf_checkpoint, model_type, params):
    """Build an MTJ config and a checkpoint-to-MTJ weight map.

    Args:
        hf_checkpoint: True when ``params`` comes from a HuggingFace config.
            The spec file ``maps/<model_type>.json`` (relative to the current
            directory) is then used to translate HF keys into MTJ keys, pick
            the core count, pad the vocabulary, and build ``model_spec``.
        model_type: HF ``model_type`` string; only used when
            ``hf_checkpoint`` is true.
        params: dict of model parameters. It is MUTATED in place and also
            returned, so pass a copy if the original must survive.

    Returns:
        ``(params, model_spec)``. ``params`` is completed with
        ``get_default_params()`` for missing keys. ``model_spec`` maps
        checkpoint tensor names to copies of their spec's ``"mtj"`` entries,
        with ``"module"`` prefixed by ``causal_transformer_shard/~/``. It is
        empty when ``hf_checkpoint`` is false.

    Raises:
        NotImplementedError: ``maps/<model_type>.json`` does not exist.
    """
    default_params = get_default_params()
    # Without an HF checkpoint there is no spec file. An empty spec keeps the
    # weight-map construction below valid (it used to hit an unbound name).
    lazy_load_spec = {}

    if hf_checkpoint:
        # Try to convert HF config.json to MTJ config
        spec_path = os.path.join("maps", model_type + ".json")
        if not os.path.isfile(spec_path):
            raise NotImplementedError(f"Unsupported model type {repr(model_type)}")
        with open(spec_path) as f:
            lazy_load_spec = json.load(f)

        if "mtj_compat" in lazy_load_spec:
            params["compat"] = lazy_load_spec["mtj_compat"]
        if "mtj_pe" in lazy_load_spec:
            params["pe"] = lazy_load_spec["mtj_pe"]
        for k, v in lazy_load_spec.get("mtj_config_map", {}).items():
            if not isinstance(v, list):
                # A single source key: copy its value to the MTJ key.
                params[k] = params[v]
                continue
            # A list: use the first listed source key present in params.
            # The last element is a literal fallback value, not a key.
            for i, candidate in enumerate(v):
                if i == len(v) - 1:
                    params[k] = candidate
                elif candidate in params:
                    params[k] = params[candidate]
                    break

        params["n_vocab"] = params["vocab_size"]

        if "activation_function" in params:
            params["activation"] = params["activation_function"]

        # Both the number of attention heads in the model and the embedding
        # dimension of the model need to be divisible by the number of TPU cores
        # that we use, and JAX also requires the number of TPU cores used to be
        # an even number if we're using more than one core, so logically we try
        # to pick the largest possible even number of TPU cores such that the
        # number of attention heads and embedding dimension are both divisible
        # by the number of TPU cores, and fall back to one core if an even
        # number of TPU cores is not possible.
        for c in (8, 6, 4, 2, 1):
            if 0 == params["n_heads"] % c == params.get("d_embed", params["d_model"]) % c:
                params["cores_per_replica"] = c
                break

        # The vocabulary size of the model also has to be divisible by the
        # number of TPU cores, so we pad the vocabulary with the minimum
        # possible number of dummy tokens such that it's divisible.
        params["n_vocab_padding"] = -(params["n_vocab"] % -params["cores_per_replica"])

    if "compat" in params:
        default_params["compat"] = params["compat"]

    if default_params["compat"] == "fairseq_lm":
        default_params["tokenizer"] = "KoboldAI/fairseq-dense-125M"

    # Fill every key the model config did not provide.
    for param, value in default_params.items():
        params.setdefault(param, value)

    # Use an optimization that will allow us to avoid one extra transpose operation
    if hf_checkpoint:
        params["transposed_linear"] = True

    prefix = "causal_transformer_shard/~/"
    model_spec = {}
    for key, spec in lazy_load_spec.get("static_weights", {}).items():
        if spec.get("mtj") is not None:
            model_spec[key] = spec["mtj"].copy()
            model_spec[key]["module"] = prefix + model_spec[key]["module"]
    for key_template, spec in lazy_load_spec.get("layer_weights", {}).items():
        if spec.get("mtj") is None:
            continue
        # Per-layer entries use "{layer}" placeholders in both the checkpoint
        # key and the MTJ module name.
        for layer in range(params["layers"]):
            key = key_template.format(layer=layer)
            model_spec[key] = spec["mtj"].copy()
            model_spec[key]["module"] = prefix + model_spec[key]["module"].format(layer=layer)

    return params, model_spec

In [None]:
# generate_mtj_config() needs the matching maps/<model_type>.json spec to
# build the MTJ config, so download it first.
# `-p` makes the cell safe to re-run. Downloading into the relative `maps`
# directory (not the Colab-only /content/maps) matches the path the loader
# reads, so this also works on Kaggle.
!mkdir -p maps
!wget -c https://raw.githubusercontent.com/henk717/KoboldAI/united/maps/gpt_neo.json -P maps

In [None]:
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

model_name_or_path = 'EleutherAI/gpt-neo-2.7B'

# Read the HF config.json and translate it into an MTJ config plus weight map.
hf_config = AutoConfig.from_pretrained(model_name_or_path)
mtj_config, model_spec = generate_mtj_config(
    hf_checkpoint=True,
    model_type=hf_config.model_type,
    # generate_mtj_config mutates the dict it is given, so hand it a shallow
    # copy of the config's attributes instead of the live __dict__.
    params=dict(vars(hf_config)),
)


In [None]:
# Fewest dummy vocab rows that make n_vocab a multiple of the shard count
# (same value as -(n % -c) for a positive shard count c).
mtj_config["n_vocab_padding"] = (-mtj_config["n_vocab"]) % mtj_config["cores_per_replica"]
print("n_vocab_padding", mtj_config["n_vocab_padding"])

In [None]:
# Short aliases for mtj_config entries.
total_shards = mtj_config["cores_per_replica"]  # shard count used by reshard_reverse() when replicating rows
d_model = mtj_config["d_model"]  # NOTE(review): not referenced elsewhere in this notebook
layers = mtj_config["layers"]  # NOTE(review): not referenced elsewhere in this notebook
pieces = 16  # NOTE(review): unused magic number in this notebook — TODO confirm and remove
padding_rows = mtj_config["n_vocab_padding"]  # NOTE(review): not referenced elsewhere in this notebook

In [None]:
import torch_lazy_loader
import os
from IPython.display import clear_output
import torch

In [None]:
import time

import jax
from jax.experimental import maps
import numpy as np
import optax
import transformers

from mesh_transformer.checkpoint import read_ckpt_lowmem, read_ckpt
from mesh_transformer.sampling import nucleaus_sample
from mesh_transformer.transformer_shard import CausalTransformer, CausalTransformerV2, PlaceholderTensor

per_replica_batch = 1 # mtj_config["per_replica_batch"]
cores_per_replica = mtj_config["cores_per_replica"]
seq = mtj_config["seq"]


mtj_config["sampler"] = nucleaus_sample

# here we "remove" the optimizer parameters from the model (as we don't need them for inference)
mtj_config["optimizer"] = optax.scale(0)

# A single model-parallel replica: dp=1, mp=cores_per_replica. Only the first
# cores_per_replica devices are placed in the mesh.
mesh_shape = (1, cores_per_replica)
devices = jax.devices()
if len(devices) < cores_per_replica:
    raise RuntimeError(
        f"Model needs {cores_per_replica} devices (cores_per_replica) but only "
        f"{len(devices)} JAX devices are available"
    )
devices = np.array(devices[:cores_per_replica]).reshape(mesh_shape)

# Alternative mesh using every device (jax.device_count() // cores_per_replica replicas):
# mesh_shape = ( jax.device_count() // cores_per_replica, cores_per_replica)
# devices = np.array(jax.devices()).reshape(mesh_shape)

print("mesh_shape", mesh_shape)
print("devices", devices)
global_mesh = maps.Mesh(devices, ('dp', 'mp'))
maps.thread_resources.env = maps.ResourceEnv(physical_mesh=global_mesh, loops=())

tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
# tokenizer = transformers.GPT2Tokenizer.from_pretrained('gpt2')
# tokenizer = AutoTokenizer.from_pretrained("decapoda-research/llama-7b-hf")
# model = AutoModelForCausalLM.from_pretrained("decapoda-research/llama-7b-hf")

# One batch row per data-parallel replica in the mesh above (mesh_shape[0]).
# Deriving this from jax.device_count() instead produced redundant extra rows
# whenever cores_per_replica was smaller than the device count, because the
# active mesh only ever has one replica.
total_batch = per_replica_batch * mesh_shape[0]

In [None]:
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple, TypeVar
class _EmptyState(NamedTuple):
    pass

class _DummyOptimizer:
    def init(*args, **kwargs):
        return _EmptyState()
# Swap the earlier optax.scale(0) placeholder for the stateless dummy defined
# above; presumably this keeps CausalTransformer from allocating optimizer
# state — confirm in mesh_transformer.
mtj_config["optimizer"] = _DummyOptimizer()

In [None]:
# Build the MTJ model skeleton. With dematerialized=True the parameters are
# presumably PlaceholderTensor stand-ins (the post-load check looks for
# PlaceholderTensor) that the checkpoint callback replaces — verify in mesh_transformer.
network = CausalTransformer(mtj_config, dematerialized=True)
# network = CausalTransformer(mtj_config)

In [None]:
from mesh_transformer.util import to_bf16, to_f16, to_f32

def _cast_shard_to_bf16(x, _unused_batch):
    # The second argument only supplies the "batch" named axis; it is ignored.
    return to_bf16(x)


# Cast a pytree to bfloat16 and lay it out on the device mesh: the leading
# "shard" axis maps to the 'mp' mesh axis, the dummy "batch" axis to 'dp'.
move_xmap = jax.experimental.maps.xmap(
    fun=_cast_shard_to_bf16,
    in_axes=(["shard", ...], ["batch", ...]),
    out_axes=["shard", ...],
    axis_resources={"shard": "mp", "batch": "dp"},
)

In [None]:
#@title Convert checkpoint to be JAX-compatible
import sys
import zipfile
import functools
from IPython.display import clear_output
import torch
import numpy as np
import jax.numpy as jnp
import jax.dlpack

def callback(model_dict, f, **_):
    """Stream tensors from a torch zip checkpoint into MTJ's sharded params.

    Passed to ``torch_lazy_loader.use_lazy_torch_load`` as its callback. For
    every ``model_dict`` key that ``model_spec`` maps to an MTJ parameter, the
    tensor is read from the zip archive and transformed per its spec. It is
    then resharded to the placeholder's shape in ``network.state`` and placed
    on the devices with ``move_xmap``. Every visited entry of ``model_dict`` is
    replaced by a ``meta`` tensor, so transformers sees a complete state dict
    that holds no real memory.

    Args:
        model_dict: state dict whose mapped values are expected to be
            ``torch_lazy_loader.LazyTensor`` objects; mutated in place.
        f: path of the torch zip checkpoint file.

    Raises:
        RuntimeError: a mapped key is not a LazyTensor (duplicate key).
        ValueError: a tensor contains NaN/Inf, or cannot be resharded into
            the target parameter shape.
    """
    def reshard_reverse(x, old_shape):
        """Expand a ``(1, ...)`` tensor into the per-shard layout ``old_shape``.

        2-D ``(1, n)``: if ``old_shape[1] == n`` the row is replicated onto
        every shard; otherwise it is split across shards by reshape.
        3-D ``(1, a, b)``: when ``b`` already matches ``old_shape[2]`` the
        ``a`` axis is split across shards; when ``a`` matches ``old_shape[1]``
        the ``b`` axis is split (reshape + permute).
        """
        if x.ndim == 2:
            if old_shape[1] == x.shape[1]:
                return x[0:1].tile((total_shards, 1))
            return x.reshape(old_shape)
        if x.ndim == 3:
            if x.shape[0] * x.shape[2] == old_shape[2]:
                return x.reshape(old_shape)
            if x.shape[0] * x.shape[1] == old_shape[1]:
                return x.reshape((old_shape[1], old_shape[0], old_shape[2])).permute((1, 0, 2))
        raise ValueError(f"cannot reshard tensor of shape {tuple(x.shape)} into {tuple(old_shape)}")

    with zipfile.ZipFile(f, "r") as z:
        # Tensor data is looked up under "archive/data/<storage_key>" first,
        # then under "<checkpoint file stem>/data/<storage_key>".
        zipfolder = os.path.basename(os.path.normpath(f)).split('.')[0]
        member = None  # open stream for last_storage_key, if any
        last_storage_key = None
        current_offset = 0
        try:
            for i, key in enumerate(sorted(map(str, model_dict.keys()))):
                model_spec_key = max((k for k in model_spec.keys() if key.endswith(k)), key=len, default=None)
                print(i, key)
                # Some model weights are used by transformers but not by MTJ.
                # We have to materialize these weights anyways because
                # transformers will throw a tantrum otherwise.  To attain
                # the least possible memory usage, we create them as meta
                # tensors, which don't take up any actual CPU or TPU memory.
                if model_spec_key is None:
                    model_dict[key] = torch.empty(model_dict[key].shape, dtype=model_dict[key].dtype, device="meta")
                    continue

                lazy = model_dict[key]
                # Check before touching LazyTensor-only attributes (.key,
                # .seek_offset); an already-materialized tensor would otherwise
                # fail with an AttributeError instead of this message.
                if not isinstance(lazy, torch_lazy_loader.LazyTensor):
                    error = f"Duplicate key {repr(key)}"
                    print("\n\nERROR:  " + error, file=sys.stderr)
                    raise RuntimeError(error)

                storage_key = lazy.key
                # Zip member streams only read forward: reopen on a new
                # storage file or when the next tensor lies behind us.
                if storage_key != last_storage_key or lazy.seek_offset < current_offset:
                    last_storage_key = storage_key
                    if member is not None:
                        member.close()
                        member = None
                    try:
                        member = z.open(f"archive/data/{storage_key}")
                    except KeyError:  # ZipFile.open: member not in archive
                        member = z.open(f"{zipfolder}/data/{storage_key}")
                    current_offset = 0
                if current_offset != lazy.seek_offset:
                    member.read(lazy.seek_offset - current_offset)
                    current_offset = lazy.seek_offset

                size = functools.reduce(lambda x, y: x * y, lazy.shape, 1)
                dtype = lazy.dtype
                nbytes = size if dtype is torch.bool else size * ((torch.finfo if dtype.is_floating_point else torch.iinfo)(dtype).bits >> 3)
                tensor = lazy.materialize(member, map_location="cpu")
                model_dict[key] = tensor.to("meta")
                current_offset += nbytes

                # Transform by spec
                spec = model_spec[model_spec_key]
                transforms = set(spec.get("transforms", ()))
                if "remove_first_two_rows" in transforms:
                    tensor = tensor[2:]
                if "divide_by_shards" in transforms:
                    tensor /= mtj_config["cores_per_replica"]
                if "vocab_pad" in transforms:
                    # Append n_vocab_padding zero rows at the end of dim 0.
                    tensor = torch.nn.functional.pad(tensor, (0,) * (tensor.ndim * 2 - 1) + (mtj_config["n_vocab_padding"],))

                print(spec["module"],spec["param"])
                old_shape = network.state["params"][spec["module"]][spec["param"]].shape

                tensor = reshard_reverse(tensor.unsqueeze(0), old_shape)

                # torch.isfinite handles every torch dtype (including bfloat16)
                # and grad-tracking tensors; np.isnan on a torch tensor does not.
                if not torch.isfinite(tensor).all():
                    raise ValueError(f"non-finite (NaN/Inf) values in checkpoint tensor {key!r}")
                assert tensor.shape == old_shape
                tensor = jnp.array(tensor.detach())

                network.state["params"][spec["module"]][spec["param"]] = move_xmap(tensor, np.empty(cores_per_replica))
        finally:
            # Close the last member stream even if loading failed midway.
            if member is not None:
                member.close()

    print(f"DONE file loaded")


In [None]:
# Load the HF checkpoint through torch_lazy_loader. The callback writes each
# mapped weight into network.state and replaces every state-dict entry with a
# meta tensor, so the returned `model` presumably carries no real weights.
with torch_lazy_loader.use_lazy_torch_load(callback=callback, dematerialized_modules=True):
    # torch_checkpoint = torch.load(path_to_checkpoint, map_location="cpu")
    model = AutoModelForCausalLM.from_pretrained(model_name_or_path, cache_dir="cache")
    

In [None]:
# Fill MTJ parameters the checkpoint did not provide (still PlaceholderTensor
# after loading). Only the embedding bias has a known substitute; any other
# gap is a fatal mapping error.
for mk, mv in network.state["params"].items():
    for pk, pv in mv.items():
        if not isinstance(pv, PlaceholderTensor):
            continue
        if mk == "causal_transformer_shard/~/embedding_shard/~/linear" and pk == "b":
            # The transformers GPT-J models apparently do not
            # have embedding bias, whereas MTJ GPT-J models do,
            # so we have to supplement an embedding bias tensor
            # by creating a tensor with the necessary shape, filled
            # with zeros. move_xmap applies its own dtype cast.
            # `params` only exists inside generate_mtj_config(), so referencing
            # it here raised NameError; use the notebook's cores_per_replica.
            mv[pk] = move_xmap(jnp.zeros(mv[pk].shape, dtype=jnp.float32), np.empty(cores_per_replica))
        else:
            error = f"{mk} {pk} could not be found in the model checkpoint"
            print("\n\nERROR:  " + error, file=sys.stderr)
            raise RuntimeError(error)

In [None]:
# network = CausalTransformer(mtj_config)

In [None]:
# Pass the whole state pytree through move_xmap (dtype cast + placement on the
# 'mp' mesh axis). NOTE(review): params were already moved one by one during
# loading; presumably this also covers the remaining non-param entries of
# network.state — confirm this second pass is needed.
network.state = move_xmap(network.state, np.empty(cores_per_replica))

In [None]:
def infer(context, top_p=0.9, temp=1.0, gen_len=512):
    """Generate ``gen_len`` tokens continuing ``context`` with the loaded network.

    The prompt is tokenized and truncated to its last ``seq`` tokens. It is
    then left-padded with token id 0 to length ``seq`` and repeated
    ``total_batch`` times.

    Args:
        context: prompt text.
        top_p: nucleus-sampling probability mass, applied to every batch row.
        temp: sampling temperature, applied to every batch row.
        gen_len: number of tokens to generate.

    Returns:
        list with one decoded string per batch row, taken from
        ``output[1][0][:, :, 0]`` of ``network.generate``. Presumably this is
        only the generated continuation (the prompt is not prepended) —
        confirm against mesh_transformer.
    """
    tokens = tokenizer.encode(context)
    # Keep only the most recent `seq` tokens: a longer prompt would make the
    # pad amount negative and np.pad would raise.
    tokens = tokens[-seq:]

    provided_ctx = len(tokens)
    pad_amount = seq - provided_ctx

    padded_tokens = np.pad(tokens, ((pad_amount, 0),)).astype(np.uint32)
    batched_tokens = np.array([padded_tokens] * total_batch)
    length = np.ones(total_batch, dtype=np.uint32) * provided_ctx

    start = time.time()
    output = network.generate(
        batched_tokens,
        length,
        gen_len,
        {"top_p": np.ones(total_batch) * top_p, "temp": np.ones(total_batch) * temp},
    )

    decoded_tokens = output[1][0]
    samples = [tokenizer.decode(row) for row in decoded_tokens[:, :, 0]]
    # Measured after decoding, which forces the generated values to be read.
    print(f"completion done in {time.time() - start:.2f}s")
    return samples

# print(infer("EleutherAI is")[0])
# print(infer("EleutherAI is"))

In [None]:
#@title  { form-width: "300px" }
# Demo: sample a short continuation of `context` using the form sliders.
top_p = 0.9 #@param {type:"slider", min:0, max:1, step:0.1}
# NOTE(review): the slider allows temp=0; confirm the sampler copes with a zero temperature.
temp = 1 #@param {type:"slider", min:0, max:1, step:0.1}

#context = """In a shocking finding, scientist discovered a herd of unicorns living in a remote, previously unexplored valley, in the Andes Mountains. Even more surprising to the researchers was the fact that the unicorns spoke perfect English."""
context = """The world of tomorrow is going"""
print(infer(top_p=top_p, temp=temp, gen_len=10, context=context))