In [1]:
import datasets
from transformers import AutoTokenizer
import equinox as eqx
from src import GPT, GPTConfig
import jax.random as jr


DATASET_PATH = "dataset"
CONFIG = GPTConfig()
RANDOM = jr.PRNGKey(79)

dataset = datasets.load_dataset("roneneldan/TinyStories")
dataset = dataset["train"].take(10)
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.add_special_tokens({'pad_token': '[PAD]'})

def tokenize(example):
    return tokenizer.batch_encode_plus(example["text"], padding="max_length", truncation=True, max_length=CONFIG.max_position_embeddings, return_tensors="pt")

tokenized_data = dataset.map(
    tokenize, remove_columns=["text"], batched=True, batch_size=10
)

tokenized_data = tokenized_data.with_format("jax")
model = GPT(CONFIG, RANDOM)
model = eqx.tree_deserialise_leaves("./gpt2.eqx", model)

  from .autonotebook import tqdm as notebook_tqdm
  self.attn = CausalSelfAttention(config, key=key1)


In [6]:
# ok let's assume gpt-2 encodings by default
import jax
from transformers import AutoTokenizer
import jax.numpy as jnp
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions

jax.config.update("jax_log_compiles", True)
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.add_special_tokens({'pad_token': '[PAD]'})

MAX_NEW_TOKENS = 100
TOP_K = None
TEMPERATURE=1
START = "Once upon a time"

model = eqx.nn.inference_mode(model, True)

start_ids = tokenizer.encode(START, padding="max_length", max_length=CONFIG.max_position_embeddings, add_special_tokens=True)
idx = jnp.array(start_ids)
mask = jnp.tri(MAX_NEW_TOKENS, CONFIG.max_position_embeddings, 0)
key = jr.key(1)
start_idx = jnp.array(tokenizer.encode(START)).shape[-1]-1
print(f"Starting index: {start_idx}")
return_dict = BaseModelOutputWithPastAndCrossAttentions()

for i in range(start_idx, MAX_NEW_TOKENS):
    print(tokenizer.decode(idx, skip_special_tokens=True))
    logits, return_dict = model(idx, past_key_values=return_dict.past_key_values, attention_mask=mask[i], dropout_key=key, return_dict=True)
    logits = logits[i]
    # pluck the logits at the final step and scale by desired temperature
    logits = logits / TEMPERATURE
    # optionally crop the logits to only the top k options
    if TOP_K is not None:
        v, _ = jax.lax.top_k(logits, min(TOP_K, logits.shape[-1]))
        logits = jnp.where(jnp.less(logits, v), -jnp.inf, logits) 
    # apply softmax to convert logits to (normalized) probabilities
    key, k = jr.split(key)
    idx_next = jr.categorical(k, logits)
    # idx_next = jax.numpy.argmax(logits, axis=-1)

    # append sampled index to the running sequence and continue
    print(idx_next)
    idx = idx.at[i+1].set(idx_next)



Starting index: 3
Once upon a time




11
Once upon a time,
612
Once upon a time, there
373
Once upon a time, there was
257
Once upon a time, there was a
26188
Once upon a time, there was a puppy
508
Once upon a time, there was a puppy who
6151
Once upon a time, there was a puppy who loved
284
Once upon a time, there was a puppy who loved to
1702
Once upon a time, there was a puppy who loved to sing
7259
Once upon a time, there was a puppy who loved to sing songs
484
Once upon a time, there was a puppy who loved to sing songs they
1088
Once upon a time, there was a puppy who loved to sing songs they around
262
Once upon a time, there was a puppy who loved to sing songs they around the
3952
Once upon a time, there was a puppy who loved to sing songs they around the park
13
Once upon a time, there was a puppy who loved to sing songs they around the park.
1375
Once upon a time, there was a puppy who loved to sing songs they around the park. She
925
Once upon a time, there was a puppy who loved to sing songs they around the par

KeyboardInterrupt: 

Array([[1., 0., 0., 0.],
       [1., 1., 0., 0.],
       [1., 1., 1., 0.],
       [1., 1., 1., 1.]], dtype=float32)