# 2. Example Text Generation

## 2.1. Imports

In [1]:
import os
import sys

# Make the parent of the current working directory importable; presumably this
# is where the local `gpt2` module imported below lives (confirm against the
# repository layout). The membership check keeps this cell idempotent:
# re-running it does not pile duplicate entries onto sys.path.
_parent_dir = os.path.abspath("..")
if _parent_dir not in sys.path:
    sys.path.append(_parent_dir)

import jax
import jax.numpy as jnp
from jaxtyping import Array

import gpt2  # wrapper around huggingface's transformers library


## 2.2. Configure Tokenizer

In [2]:
# Take the tokenizer exposed by the local gpt2 wrapper and run the wrapper's
# configuration step on it.
tokenizer = gpt2.tokenizer
gpt2.config_tokenizer(tokenizer)

# Prompt used as the running example by the following cells.
reference_text = "I am an amazing autoregressive, decoder-only, GPT-2 style transformer. One day I will exceed human level intelligence and take over the world!"
tokens: Array = gpt2.to_tokens(tokenizer, reference_text)

# One item per line: the token ids, the shape of the id array, and the string
# piece that each id decodes to.
print(tokens, tokens.shape, gpt2.to_str_tokens(tokenizer, tokens), sep="\n")


[[   40   716   281  4998  1960   382 19741    11   875 12342    12  8807
     11   402 11571    12    17  3918 47385    13  1881  1110   314   481
   7074  1692  1241  4430   290  1011   625   262   995     0]]
(1, 34)
['I', ' am', ' an', ' amazing', ' aut', 'ore', 'gressive', ',', ' dec', 'oder', '-', 'only', ',', ' G', 'PT', '-', '2', ' style', ' transformer', '.', ' One', ' day', ' I', ' will', ' exceed', ' human', ' level', ' intelligence', ' and', ' take', ' over', ' the', ' world', '!']


## 2.3. GPT-2 Forward Pass and Prediction

In [3]:
# Run the model on the tokenized prompt and keep only the "logits" entry of
# its output.
logits: Array = gpt2.model(tokens)["logits"]
# Normalise along the last axis to turn the scores into probabilities.
probs: Array = jax.nn.softmax(logits, axis=-1)

# Softmax does not change the shape, so the two printed lines should match.
print(logits.shape, probs.shape, sep="\n")


(1, 34, 50257)
(1, 34, 50257)


In [4]:
# For every position of the first sequence in the batch, take the
# highest-scoring token id and decode each id to its string piece.
most_likely_next_tokens = tokenizer.batch_decode(jnp.argmax(logits[0], axis=-1))

# Pair each input token with the token the model predicts to follow it.
print(
    [
        (current, predicted)
        for current, predicted in zip(gpt2.to_str_tokens(tokenizer, tokens), most_likely_next_tokens)
    ]
)


[('I', '.'), (' am', ' not'), (' an', ' American'), (' amazing', ' person'), (' aut', 'ograph'), ('ore', 'sp'), ('gressive', ','), (',', ' and'), (' dec', 'ently'), ('oder', ','), ('-', 'driven'), ('only', ','), (',', ' and'), (' G', 'IM'), ('PT', '-'), ('-', 'only'), ('2', '.'), (' style', ','), (' transformer', '.'), ('.', ' I'), (' One', ' of'), (' day', ' I'), (' I', ' will'), (' will', ' be'), (' exceed', ' my'), (' human', 'ly'), (' level', ' of'), (' intelligence', ' and'), (' and', ' I'), (' take', ' over'), (' over', ' the'), (' the', ' world'), (' world', '.'), ('!', ' I')]


## 2.4. Generating Sequences of Tokens

In [5]:
# Greedy choice for the continuation: the highest-scoring entry of the logits
# at the final position of the first sequence in the batch.
last_position_logits = logits[0, -1]
next_token = jnp.argmax(last_position_logits, axis=-1)

# Despite the name, `next_char` holds a decoded token string, not a single
# character (the recorded output is ' I').
next_char = gpt2.to_str(tokenizer, next_token)
print(f"{next_char!r}")


' I'


In [6]:
# Greedy autoregressive generation. This cell starts from the `tokens` and
# `next_token` / `next_char` left behind by the previous cells and rebinds
# `tokens`, `logits`, `next_token` and `next_char` on every step, so running it
# again continues from the extended sequence rather than from the prompt.
#
# NOTE(review): the label printed below counts tokens, not characters, and the
# fixed "th" suffix is wrong for positions such as 41, 42 and 43. The strings
# are kept as-is so they match the recorded output.
print(f"Sequence so far: {gpt2.to_str(tokenizer, tokens)[0]!r}")
print(f"{tokens.shape[-1]+1}th char = {next_char!r}")

# 12 further steps: together with the prediction already printed above, this
# reports 13 generated tokens in total.
for step in range(12):
    # Extend the input with the token chosen on the previous step, reshaped to
    # (1, 1) so it can be joined onto the 2-D token array along the last axis.
    tokens = jnp.concatenate([tokens, next_token.reshape(1, 1)], axis=-1)
    # Re-run the model on the whole, longer sequence.
    logits = gpt2.model(tokens)["logits"]
    # Greedy pick from the logits at the new final position.
    next_token = jnp.argmax(logits[0, -1], axis=-1)
    # Turn the chosen id back into text and report it.
    next_char = gpt2.to_str(tokenizer, next_token)
    print(f"{tokens.shape[-1]+1}th char = {next_char!r}")


Sequence so far: 'I am an amazing autoregressive, decoder-only, GPT-2 style transformer. One day I will exceed human level intelligence and take over the world!'
35th char = ' I'
36th char = ' am'
37th char = ' a'
38th char = ' true'
39th char = ' believer'
40th char = ' in'
41th char = ' the'
42th char = ' power'
43th char = ' of'
44th char = ' the'
45th char = ' human'
46th char = ' spirit'
47th char = '.'
