# Press CTRL+F9 to run everything
This is based on the [original notebook](http://colab.research.google.com/github/kingoflolz/mesh-transformer-jax/blob/master/colab_demo.ipynb) from the authors of GPT-J-6B.

In [None]:
# Install zstd: the `tar -I zstd` command below needs it to unpack the .zstd archive.
!apt install zstd
# Download the GPT-J-6B checkpoint archive (`wget -c` resumes a partial download) and unpack it;
# `time` prints how long each step took.
!time wget -c https://the-eye.eu/public/AI/GPT-J-6B/step_383500_slim.tar.zstd
!time tar -I zstd -xf step_383500_slim.tar.zstd
# Fetch the model code, install its requirements, then install the package itself
# with the jax / tensorflow versions this notebook is pinned to.
!git clone https://github.com/kingoflolz/mesh-transformer-jax.git
!pip install -r mesh-transformer-jax/requirements.txt
!pip install mesh-transformer-jax/ jax==0.2.12 tensorflow==2.5.0

In [None]:
from IPython.display import HTML, display
def set_css():
    """Display a CSS rule that makes <pre> output blocks wrap long lines."""
    wrap_rule = '''<style>pre {white-space: pre-wrap;}</style>'''
    display(HTML(wrap_rule))

# Hook set_css onto IPython's pre_run_cell event so the rule is emitted before every cell runs.
get_ipython().events.register('pre_run_cell', set_css)

In [None]:
import os
import requests
from jax.config import config

# Ask the Colab TPU host (port 8475) for the tpu_driver build this notebook targets,
# then point JAX's XLA backend at that TPU through a grpc:// target.
tpu_host_and_port = os.environ['COLAB_TPU_ADDR']
colab_tpu_addr = tpu_host_and_port.partition(':')[0]
url = f'http://{colab_tpu_addr}:8475/requestversion/tpu_driver0.1_dev20210607'
requests.post(url)
config.FLAGS.jax_xla_backend = "tpu_driver"
config.FLAGS.jax_backend_target = "grpc://" + tpu_host_and_port

import time
import jax
from jax.experimental import maps
import numpy as np
import optax
import transformers
from mesh_transformer.checkpoint import read_ckpt
from mesh_transformer.sampling import nucleaus_sample
from mesh_transformer.transformer_shard import CausalTransformer

import random
import haiku as hk
import jax.numpy as jnp
from mesh_transformer.transformer_shard import CausalTransformerShard
class PenalizingCausalTransformer(CausalTransformer):
    '''
    This is a custom version of CausalTransformer I made that supports the
    repetition penalty described in this paper:
    https://arxiv.org/pdf/1909.05858.pdf

    It also bans the <|endoftext|> token from being sampled.
    '''

    # Token id of <|endoftext|>. It fills every slot of the `generated` token
    # history that does not hold a real token yet, and it is banned from sampling.
    endoftext_id = 50256

    def __init__(self, config):
        super().__init__(config)
        endoftext_id = self.endoftext_id

        def generate(state, key, ctx, ctx_length, aux, sampler_options):
            '''Generation body that xmap (below) maps over the device mesh.'''
            sampler = config["sampler"]
            # NOTE(review): gen_length, batch_size and return_logits are read from
            # `self` while xmap traces this function; a later call that changes
            # only return_logits may reuse the cached trace — confirm.
            gen_length = self.gen_length

            # Copy first, so that removing our custom option never mutates the
            # dict object we were given. Whatever remains is forwarded to the
            # sampler as keyword arguments.
            sampler_options = dict(sampler_options)
            repetition_penalty = sampler_options.pop('repetition_penalty', None)

            def generate_sample(context, ctx_length, aux):
                transformer = CausalTransformerShard(config)
                _, initial_state = transformer.generate_initial(context, ctx_length)

                # Token history used by the repetition penalty. The last
                # `ctx_length` slots of the seq-long window hold the real context
                # (prompts are left-padded); every other slot, including the
                # gen_length slots appended for new tokens, holds endoftext_id.
                generated_range = jnp.arange(config["seq"])
                generated_mask = jnp.asarray(generated_range < ctx_length)[::-1]
                generated = jnp.where(generated_mask, context, endoftext_id)
                generated = jnp.pad(generated, (0, gen_length), constant_values=endoftext_id)
                # NOTE(review): tiled to the full batch size; inside xmap each call
                # presumably handles one example, so extra rows may be redundant — confirm.
                generated = jnp.tile(generated, (self.batch_size, 1))
                generated_index = config["seq"]

                initial_state = (generated, generated_index) + initial_state

                def apply_penalty_2d(logits, tokens_2d, repetition_penalty):
                    '''
                    Penalize, row by row, the logit of every token id listed in
                    the matching row of `tokens_2d`. Positive logits are divided
                    by `repetition_penalty`; other logits are multiplied by it.
                    '''
                    rows = jnp.arange(tokens_2d.shape[0])[:, None]
                    seen_logits = logits[rows, tokens_2d]
                    penalized = jnp.where(seen_logits > 0,
                                          seen_logits / repetition_penalty,
                                          seen_logits * repetition_penalty)
                    # A token id repeated within a row writes the same value each
                    # time, so the order of the scatter updates does not matter.
                    return logits.at[rows, tokens_2d].set(penalized)

                def generate_scan_fn(carry, sampler_input):
                    generated, generated_index, next_token, decode_state, sample_key = carry
                    sample_key, new_key = jax.random.split(sample_key)

                    logits, new_state = transformer.generate_once(next_token, decode_state)

                    # Apply repetition penalty to tokens that have already
                    # appeared in the context or in tokens previously chosen
                    # by sampler() in this run of generate_sample()
                    if repetition_penalty is not None:
                        logits = apply_penalty_2d(logits, generated, repetition_penalty)

                    # Prevent <|endoftext|> from appearing in the output by
                    # setting its logit value to negative infinity
                    logits = logits.at[:, endoftext_id].set(-jnp.inf)

                    next_token, sample_info = sampler(sample_key, logits, sampler_input, **sampler_options)

                    # Record the sampled token so later steps penalize it too.
                    generated = generated.at[:, generated_index].set(next_token.flatten())
                    generated_index += 1

                    if self.return_logits:
                        output = (next_token, sample_info, logits)
                    else:
                        output = (next_token, sample_info)
                    new_carry = (generated, generated_index, next_token, new_state, new_key)
                    return new_carry, output

                final_state, outputs = jax.lax.scan(generate_scan_fn, initial_state, xs=aux, length=gen_length)
                return final_state, outputs

            generate_fn = hk.transform(generate_sample).apply
            return generate_fn(state["params"], key, ctx, ctx_length, aux)

        # Arguments: state is split along 'shard'; key, ctx, ctx_length, aux and
        # sampler_options are split along 'batch'.
        self.generate_xmap = jax.experimental.maps.xmap(fun=generate,
                                                        in_axes=(["shard", ...],
                                                                 ["batch", ...],
                                                                 ["batch", ...],
                                                                 ["batch", ...],
                                                                 ["batch", ...],
                                                                 ["batch", ...]),
                                                        out_axes=["batch", ...],
                                                        axis_resources={'shard': 'mp', 'batch': 'dp'})

    def generate(self, ctx, ctx_length, batch_size, gen_length, sampler_options, return_logits=False):
        '''
        Sample `gen_length` new tokens for every row of `ctx`.

        `batch_size` is accepted for call compatibility but ignored; the batch
        size always comes from `ctx.shape[0]`. `sampler_options` is forwarded to
        the sampler, except for the optional 'repetition_penalty' entry, which
        this class handles itself.

        Returns the xmapped generation result: the final scan carry and the
        stacked per-step outputs `(next_token, sample_info)`, with `logits`
        appended when `return_logits` is true.
        '''
        key = hk.PRNGSequence(random.randint(0, 2 ** 60))

        batch_size = ctx.shape[0]
        aux = jnp.zeros((batch_size, gen_length), dtype=jnp.uint32)
        self.gen_length = gen_length
        self.batch_size = batch_size
        self.return_logits = return_logits

        return self.generate_xmap(self.state,
                                  jnp.array(key.take(batch_size)),
                                  ctx,
                                  np.array(ctx_length, dtype=np.uint32),
                                  aux,
                                  sampler_options)

# Hyperparameters for the GPT-J-6B checkpoint downloaded and extracted above.
params = {
  # Model architecture.
  "layers": 28,
  "d_model": 4096,
  "n_heads": 16,
  "n_vocab": 50400,
  "norm": "layernorm",
  "pe": "rotary",
  "pe_rotary_dims": 64,

  # Prompt window length (prompts are padded/truncated to this many tokens),
  # devices per model replica (the second mesh axis below), and batch per replica.
  "seq": 2048,
  "cores_per_replica": 8,
  "per_replica_batch": 1,
}

per_replica_batch = params["per_replica_batch"]
cores_per_replica = params["cores_per_replica"]
seq = params["seq"]


# Sampling function invoked once per generated token (the library spells it "nucleaus").
params["sampler"] = nucleaus_sample
# Zero-scaled optimizer. NOTE(review): presumably only present because the
# CausalTransformer constructor expects one; this notebook only runs inference — confirm.
params["optimizer"] = optax.scale(0)
# Arrange all devices as a (replicas, cores_per_replica) mesh. Its axes are named
# 'dp' and 'mp', which the class's xmap uses for 'batch' and 'shard' respectively.
mesh_shape = (jax.device_count() // cores_per_replica, cores_per_replica)
devices = np.array(jax.devices()).reshape(mesh_shape)
maps.thread_resources.env = maps.ResourceEnv(maps.Mesh(devices, ('dp', 'mp')))
tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')
# One prompt copy per replica times per_replica_batch.
total_batch = per_replica_batch * jax.device_count() // cores_per_replica
network = PenalizingCausalTransformer(params)
# Load the weights from the extracted checkpoint directory; devices.shape[1] is
# cores_per_replica. Then hand the state to move_xmap (presumably places it on
# the mesh — see mesh_transformer).
network.state = read_ckpt(network.state, "step_383500/", devices.shape[1])
network.state = network.move_xmap(network.state, np.zeros(cores_per_replica))

from IPython.display import HTML, display

def infer(context, top_p=0.9, temp=0.5, repetition_penalty=1.2, gen_len=200):
    '''
    Generate a continuation of `context`.

    Args:
        context: prompt text; it must encode to at least one token. Only its
            last `seq` tokens are used.
        top_p: nucleus-sampling threshold passed to the sampler.
        temp: sampling temperature passed to the sampler.
        repetition_penalty: penalty for already-seen tokens (1 disables it).
        gen_len: number of tokens to generate.

    Returns:
        A list with one decoded continuation per batch row (`total_batch` rows,
        all sampled from the same prompt).

    Raises:
        ValueError: if `context` encodes to zero tokens.
    '''
    tokens = tokenizer.encode(context)
    if not tokens:
        raise ValueError("context must contain at least one token")
    # Generation continues from the END of the prompt, so an over-long prompt
    # should lose its oldest tokens, not its most recent ones.
    tokens = tokens[-seq:]
    provided_ctx = len(tokens)

    # Left-pad to the fixed window. `length` tells the model how many trailing
    # tokens are real.
    pad_amount = seq - provided_ctx
    padded_tokens = np.pad(tokens, ((pad_amount, 0),)).astype(np.uint32)
    batched_tokens = np.array([padded_tokens] * total_batch)
    length = np.full(total_batch, provided_ctx, dtype=np.uint32)

    sampler_options = {
        "top_p": np.ones(total_batch) * top_p,
        "temp": np.ones(total_batch) * temp,
        "repetition_penalty": np.ones(total_batch) * repetition_penalty,
    }
    output = network.generate(batched_tokens, length, total_batch, gen_len, sampler_options)

    # output[1] holds the per-step outputs, and its first entry is the sampled
    # token ids: one row per batch element, with a trailing axis dropped here.
    decoded_tokens = output[1][0]
    return [tokenizer.decode(np.asarray(row).tolist()) for row in decoded_tokens[:, :, 0]]
 
# The first call to infer() takes about a minute because it triggers compilation,
# so do one throwaway generation now; later calls are fast.
context = "Hey guys! First, I'd like to give a big shout-out to Squarespace for making all of this possible."
print(f"\033[1m{context}", end='')
warmup_text = infer(context)[0]
print(f"\033[0m{warmup_text}")

In [None]:
# temp divides the logits during sampling, so its slider starts above 0.
top_p = 0.9 #@param {type:"slider", min:0, max:1, step:0.05}
temp = 0.5 #@param {type:"slider", min:0.05, max:1, step:0.05}
repetition_penalty = 1.2 #@param {type:"slider", min:1, max:1.3, step:0.005}

context = """My name is Yoshikage Kira."""

# Print the prompt in bold, then the generated continuation in normal weight.
print(f"\033[1m{context}", end='')
print(f"\033[0m{infer(top_p=top_p, temp=temp, repetition_penalty=repetition_penalty, gen_len=200, context=context)[0]}")

## What do the sliders do? Here's a quick explanation.

* `temp` (temperature) is a positive real number. It makes the output more creative and spontaneous the larger it is. Low values make the output repetitive. Setting this higher than 1 is a bad idea.
* `top_p` (top-_p_) is a real number inclusively between 0 and 1. The lower you set it, the more the program tries to remove incongruous words and punctuation from the output. Setting this to 1 disables the effect. Setting this to exactly 0 makes the program always output the exact same thing for the same input parameters (i.e. makes it deterministic). Most people recommend setting this to somewhere in the 0.9-1.0 (inclusive) range. Setting this lower can help it answer knowledge-based questions better but makes it very bad at creative writing.
* `repetition_penalty` is a positive real number. If it's greater than 1, tokens (words or pieces of words) that have already appeared in the context or in the generated output are discouraged from appearing again; if it's less than 1 (don't do this) they're encouraged instead. Setting this to 1 disables the effect. It is recommended to set this to 1.2. Setting this higher than 1.2 can have disastrous results.

---

The language model works by reading through the `context` and assigning a "score" between 0 and 1 (where 0 is weakest and 1 is strongest) to each of the 50257 possible tokens in its vocabulary. A "token" is a particular sequence of characters, each one assigned an integer from 0 to 50256. We get the logits of those scores ("logit" is a mathematical function used in neural networks that maps real numbers between 0 and 1 exclusive to the entire real number line).

This JAX program works by sending the context to the language model and getting the logits, picking one token to append to the context, sending the new context back to the language model, and so on, repeating this the number of times specified by `gen_len`. The question, then, is exactly how we choose the token to append to the context based on the logits from the language model.

The easy way is just to pick the token with the greatest logit; this is called greedy sampling. You can enable greedy sampling by setting `top_p` to 0. Greedy sampling is good at answering questions but inferior at generating creative text.

Nucleus sampling works by first dividing the logit values by `temp`, sorting the tokens from greatest to least logit, removing some of the tokens with the least logits, then using the softmax function to pick one of the remaining ones (higher logit equals higher probability). Specifically, we remove some tokens by sorting the tokens from greatest to least logit, then calculating the cumulative probability of each token being chosen (i.e. the probability that this token or any tokens with a higher logit value are chosen), then tokens with a cumulative probability higher than `top_p` are removed. Setting `top_p` to 0 keeps only the token with the highest logit value, making it equivalent to greedy sampling. Dividing the logits by `temp` moves them closer to each other if `temp` is greater than 1, and further apart if it is less than 1. As a result, the probabilities of the candidate tokens become more similar or more different, respectively.

I added in a method called penalized sampling, the original version of which was described in section 4.1 (pages 4-5) of this paper: https://arxiv.org/pdf/1909.05858.pdf. Tokens that have already appeared in the context (including tokens the program chose and added to the initial context) have their probabilities of being chosen artificially lowered by dividing their logit values by `repetition_penalty` if positive or multiplying if negative. The paper suggests 1.2 as a good value for `repetition_penalty`.