In [1]:
import jax
# Display the devices JAX can see on its default backend.
jax.devices()

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

In [3]:
import jax
import jax.numpy as jnp
import jax.random as rand
from transformers import LlamaTokenizer
from tqdm import tqdm
from lib.llama import Llama
from lib.logits_processing import PresencePenaltyProcessor, TopKSampler, TopPSampler, make_logits_processor
from lib.param_utils import load_params
from lib.multihost_utils import shard_model_params
from lib.seeding import BEST_INTEGER

# Converting LLaMA weights to JAX consumes a significant amount of time and resources, so it is best to convert them beforehand.
def load_params_from_disk(path: str = '/home/divyapatel0273/llama-2-jax/llama2-13B.pickle') -> Llama:
    """Load pre-converted model parameters, cast every leaf to bfloat16 and shard them.

    Args:
        path: Location of the pre-converted parameter file handed to ``load_params``.
            Defaults to the path this notebook was originally run with.

    Returns:
        The parameter tree returned by ``shard_model_params``.
    """
    # NOTE(review): the default file name ends in `.pickle`; if `load_params` unpickles it,
    # only point `path` at files you trust (unpickling can execute arbitrary code).
    cpu_device = jax.devices('cpu')[0]
    # Make the CPU the default device while loading and casting, so these arrays
    # are created on the host rather than on an accelerator.
    with jax.default_device(cpu_device):
        params = load_params(path)
        # `jax.tree_map` is deprecated (and removed in recent JAX releases);
        # `jax.tree_util.tree_map` is the long-standing equivalent.
        params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params)
    return shard_model_params(params)

# Essay topics; the generation cell further down builds one prompt per entry.
TOPICS = [
    "Car-free cities",
    "Does the electoral college work?",
    "Exploring Venus",
    "The Face on Mars",
    "Facial action coding system",
    "Seeking multiple opinions",
    "Phones and driving",
]

# Value passed to TopKSampler below.
top_k = 3

params = load_params_from_disk()
print('Successfully loaded model parameters!')

# PRNG key for sampling, created with the 'rbg' implementation.
key = rand.key(BEST_INTEGER, impl='rbg')

tokenizer = LlamaTokenizer.from_pretrained(
    'NousResearch/Nous-Hermes-llama-2-13b',
    padding_side='left',
)
# Reuse the EOS token as the padding token.
tokenizer.pad_token = tokenizer.eos_token

logits_processor = make_logits_processor(
    PresencePenaltyProcessor(penalty=0.05),
    TopKSampler(top_k=top_k),
    # TopPSampler(top_p=top_p),  # left disabled
)

tcmalloc: large alloc 5662318592 bytes == 0x9c41d4000 @  0x7fa9c979d680 0x7fa9c97be824 0x5e4640 0x63e74d 0x6a71b2 0x551054 0x4738f6 0x5ed4cb 0x63b015 0x58e2e0 0x6e019f 0x6e0427 0x6e2053 0x591890 0x70e39c 0x645bf4 0x5911dc 0x70e39c 0x645bf4 0x5911dc 0x70e39c 0x70e637 0x629a97 0x63b015 0x58e2e0 0x56766e 0x636cc9 0x639a74 0x592245 0x70e39c 0x645bf4
tcmalloc: large alloc 5662318592 bytes == 0xb161d6000 @  0x7fa9c979d680 0x7fa9c97be824 0x5e4640 0x63e74d 0x6a71b2 0x551054 0x4738f6 0x5ed4cb 0x63b015 0x58e2e0 0x6e019f 0x6e0427 0x6e2053 0x591890 0x70e39c 0x645bf4 0x5911dc 0x70e39c 0x645bf4 0x5911dc 0x70e39c 0x70e637 0x629a97 0x63b015 0x58e2e0 0x56766e 0x636cc9 0x639a74 0x592245 0x70e39c 0x645bf4
tcmalloc: large alloc 5662318592 bytes == 0xc679d8000 @  0x7fa9c979d680 0x7fa9c97be824 0x5e4640 0x63e74d 0x6a71b2 0x551054 0x4738f6 0x5ed4cb 0x63b015 0x58e2e0 0x6e019f 0x6e0427 0x6e2053 0x591890 0x70e39c 0x645bf4 0x5911dc 0x70e39c 0x645bf4 0x5911dc 0x70e39c 0x70e637 0x629a97 0x63b015 0x58e2e0 0x56766e 0

XlaRuntimeError: RESOURCE_EXHAUSTED: Error allocating device buffer: Attempting to allocate 1.95G. That was not possible. There are 1.90G free.; (0x0x0_HBM0)

In [None]:
from functools import partial
from typing import NamedTuple

import einops as op
import jax
from jax import Array
import jax.numpy as jnp
import jax.random as rand
from transformers import LlamaTokenizer
from typing import Callable

from lib.llama import KVCache, Llama, RotaryValues, forward_llama_model, get_rotary_values_at_position, make_rotary_values, model_config_llama2_7B, shift_left_kv_cache

@partial(jax.jit, static_argnames=('logits_processor',))
def _generate_first(params: Llama, seq: Array, attn_mask: Array, logits_processor: Callable, *, rotary_values: RotaryValues, key: Array) -> tuple[Array, Array, Array, KVCache]:
    """Run the whole prompt through the model once and select the first new token.

    Args:
        params: Model parameters; ``params.model`` and ``params.lm_head`` are read.
        seq: Prompt token ids, indexed as (batch, position).
        attn_mask: Mask with the same (batch, position) layout as ``seq``.
        logits_processor: Callable mapping last-position logits to the selected
            token ids; treated as a static argument by ``jax.jit``.
        rotary_values: Rotary values forwarded to ``forward_llama_model``.
        key: PRNG key handed to ``logits_processor``.

    Returns:
        ``(seq, attn_mask, selected_token_ids, kv_cache)``. ``seq`` and ``attn_mask``
        are rotated one position to the left with the selected ids / ``True`` written
        into the last column; ``kv_cache`` has been passed through ``shift_left_kv_cache``.
    """
    # Pairwise product of the mask with itself, restricted to the lower triangle: causal QK mask.
    pairwise_mask = op.einsum(attn_mask, attn_mask, 'B L1, B L2 -> B L1 L2')
    causal_mask = jnp.tril(pairwise_mask)
    qk_mask = op.rearrange(causal_mask, 'B L1 L2 -> B 1 1 L1 L2')

    # Ask the forward pass to also return the KV cache.
    config = model_config_llama2_7B._replace(return_kv_cache=True)
    hidden, kv_cache = forward_llama_model(params.model, seq, qk_mask, rotary_values=rotary_values, model_config=config)

    # Only the output at the last position is projected to logits.
    last_logits = hidden[:, -1] @ params.lm_head
    selected_token_ids = logits_processor(last_logits, seq=seq, attn_mask=attn_mask, key=key)

    # Rotate left by one and overwrite the last column with the new token / a True mask bit.
    next_seq = jnp.roll(seq, -1, axis=-1).at[:, -1].set(selected_token_ids)
    next_attn_mask = jnp.roll(attn_mask, -1, axis=-1).at[:, -1].set(True)
    shifted_cache = shift_left_kv_cache(kv_cache)

    return next_seq, next_attn_mask, selected_token_ids, shifted_cache

class GenerationState(NamedTuple):
    """Loop carry for the ``jax.lax.while_loop`` inside ``_generate_rest``.

    The field order matters: ``_generate_rest`` unpacks and rebuilds this tuple positionally.
    """
    # Token ids; rotated left by one with the new token written last on every iteration.
    seq: Array
    # Mask matching `seq`; rotated left by one with True written last on every iteration.
    attn_mask: Array
    # Ids selected in the previous step; fed to the model as the next input.
    selected_token_ids: Array
    # Remaining iteration counter; decremented each iteration, the loop stops when it is zero.
    max_n_iters: Array
    # Rotary values passed to `get_rotary_values_at_position`; carried through unchanged.
    rotary_values: RotaryValues
    # Position passed to `get_rotary_values_at_position`; incremented each iteration.
    rotary_values_position: Array
    # KV cache handed to and returned by `forward_llama_model`, then shifted left.
    kv_cache: KVCache
    # PRNG key; split once per iteration.
    key: Array

@partial(jax.jit, static_argnames=('logits_processor',))
def _generate_rest(params: Llama, seq: Array, attn_mask: Array, selected_token_ids: Array, max_n_iters: Array, logits_processor: Callable, *, rotary_values: RotaryValues, kv_cache: KVCache, key: Array) -> Array:
    """Generate the remaining tokens one at a time inside a ``jax.lax.while_loop``.

    Args:
        params: Model parameters; ``params.model`` and ``params.lm_head`` are read.
        seq: Token ids, indexed as (batch, position).
        attn_mask: Mask with the same (batch, position) layout as ``seq``.
        selected_token_ids: Ids selected by the previous step, one per batch row.
        max_n_iters: Number of loop iterations to run.
        logits_processor: Callable mapping last-position logits to the selected
            token ids; treated as a static argument by ``jax.jit``.
        rotary_values: Rotary values indexed per step via ``get_rotary_values_at_position``.
        kv_cache: KV cache to start from.
        key: PRNG key, split once per iteration.

    Returns:
        The ``seq`` field of the final loop state.
    """
    # Ask every forward pass to also return the KV cache.
    step_config = model_config_llama2_7B._replace(return_kv_cache=True)

    def cond_fun(state: GenerationState) -> Array:
        # Keep looping while the remaining-iteration counter is non-zero.
        return state.max_n_iters != 0

    def body_fun(state: GenerationState) -> GenerationState:
        # Only the most recently selected token is fed to the model, together with the KV cache.
        step_input = op.rearrange(state.selected_token_ids, 'B -> B 1')
        step_mask = op.rearrange(state.attn_mask, 'B L -> B 1 1 1 L')
        step_rotary = get_rotary_values_at_position(state.rotary_values, state.rotary_values_position)
        hidden, new_cache = forward_llama_model(params.model, step_input, step_mask, rotary_values=step_rotary, kv_cache=state.kv_cache, model_config=step_config)

        last_logits = hidden[:, -1] @ params.lm_head
        carry_key, sample_key = rand.split(state.key)
        next_token_ids = logits_processor(last_logits, seq=state.seq, attn_mask=state.attn_mask, key=sample_key)

        # TODO: early stopping (ayaka's comment). Since generation continues until it reaches the
        # maximum length, we have to include the EOS token to determine the end of generation.
        return state._replace(
            seq=jnp.roll(state.seq, -1, axis=-1).at[:, -1].set(next_token_ids),
            attn_mask=jnp.roll(state.attn_mask, -1, axis=-1).at[:, -1].set(True),
            selected_token_ids=next_token_ids,
            max_n_iters=state.max_n_iters - 1,
            rotary_values_position=state.rotary_values_position + 1,
            kv_cache=shift_left_kv_cache(new_cache),
            key=carry_key,
        )

    initial_state = GenerationState(
        seq=seq,
        attn_mask=attn_mask,
        selected_token_ids=selected_token_ids,
        max_n_iters=max_n_iters,
        rotary_values=rotary_values,
        rotary_values_position=jnp.array(0, jnp.uint16),
        kv_cache=kv_cache,
        key=key,
    )
    return jax.lax.while_loop(cond_fun, body_fun, initial_state).seq

def generate(sentences: list[str], tokenizer: LlamaTokenizer, params: Llama, logits_processor: Callable, *, max_len: int, key: Array) -> list[str]:
    """Tokenize ``sentences``, generate continuations and decode them back to text.

    Args:
        sentences: Prompts to continue.
        tokenizer: Tokenizer used for encoding and decoding; inputs are padded
            (and truncated) to ``max_len``.
        params: Model parameters forwarded to the generation helpers.
        logits_processor: Callable that selects token ids from logits.
        max_len: Fixed token length every prompt is padded to.
        key: PRNG key; split separately for the first step and the remaining steps.

    Returns:
        One decoded string per input sentence, with special tokens kept.

    Raises:
        AssertionError: If any row of the attention mask is entirely set, i.e. a
            prompt leaves no free position for generation.
    """
    # NOTE(review): this function (and the helpers it calls) hard-code `model_config_llama2_7B`;
    # confirm that this matches the checkpoint stored in `params`.
    encoded = tokenizer(sentences, padding='max_length', truncation=True, max_length=max_len, return_tensors='jax')
    seq = encoded.input_ids.astype(jnp.uint16)
    attn_mask = encoded.attention_mask.astype(jnp.bool_)
    assert not attn_mask.all(axis=-1).any(), 'No room for generation since the length of a sentence is greater than `max_length`.'

    # Index of the first set mask bit per row; with left padding this is the pad length
    # (assumes the tokenizer pads on the left -- TODO confirm at the call site).
    leftpad_len = attn_mask.argmax(axis=-1).astype(jnp.uint16)
    rotary_values = make_rotary_values(leftpad_len, len(sentences), max_len, model_config=model_config_llama2_7B)

    key, first_key = rand.split(key)
    seq, attn_mask, selected_token_ids, kv_cache = _generate_first(params, seq, attn_mask, logits_processor, rotary_values=rotary_values, key=first_key)

    # The loop runs for as many steps as the smallest pad length in the batch.
    # NOTE(review): `_generate_first` has already rotated the sequence once, so this may be
    # one step more than the free space of the longest prompt -- confirm.
    key, rest_key = rand.split(key)
    final_seq = _generate_rest(params, seq, attn_mask, selected_token_ids, leftpad_len.min(), logits_processor, rotary_values=rotary_values, kv_cache=kv_cache, key=rest_key)

    # Keeping special tokens in the decoded text is the only reason this function is written by hand.
    return tokenizer.batch_decode(final_seq, skip_special_tokens=False)

In [None]:
import pandas as pd
essays = []
max_len = 768
# Number of generation rounds. Every round produces one essay per prompt in `batch`,
# so the final frame holds num_essays * len(batch) rows.
num_essays = 1000

# Prompt shared by all topics; `{topic}` is the only placeholder.
PROMPT_TEMPLATE = '''
### Instruction:
Write an essay based on the topic provided as if you were a student. Your essay needs to be unique and convincing and not very long. Output nothing but the essay.

### Input:
{topic}

### Response:

'''
RESPONSE_MARKER = '### Response:'
EOS_TOKEN = '</s>'


def _extract_essay(generated_text: str) -> str:
    """Return the essay part of one decoded generation.

    Keeps the text after the first ``### Response:`` marker and before the first
    ``</s>`` that follows it, with surrounding whitespace stripped. If the marker is
    missing the whole text is used; if no ``</s>`` follows, nothing is cut off the end
    (the previous ``find``-based slicing dropped the last character in that case).
    """
    _, marker, after_marker = generated_text.partition(RESPONSE_MARKER)
    if not marker:
        after_marker = generated_text
    essay, _, _ = after_marker.partition(EOS_TOKEN)
    return essay.strip()


# One prompt per topic instead of seven copy-pasted f-strings.
batch = [PROMPT_TEMPLATE.format(topic=topic) for topic in TOPICS]

for _ in tqdm(range(num_essays)):
    # Fresh subkey per round so every round samples different essays.
    key, subkey = rand.split(key)
    generated_essays_batch = generate(batch, tokenizer, params, logits_processor, max_len=max_len, key=subkey)
    essays.extend(_extract_essay(generated_essay) for generated_essay in generated_essays_batch)

essays_df = pd.DataFrame({'text': essays, 'generated': 1})
essays_df

In [None]:
# Write the generated essays (columns `text` and `generated`) to CSV without the index column.
essays_df.to_csv('JAX_essays.csv', index=False)