In [1]:
import copy
import os, sys

sys.path.append(os.path.abspath(os.path.join("..")))

import jax

jax.config.update("jax_enable_x64", True)

from jaxtyping import Array
import jax.numpy as jnp
import flax.linen as nn

from transformers import GPT2TokenizerFast

from tx.models.gpt2 import PretrainedGPT2Model
from tx.hooks import HookPoint, Hook
from tx.network import GenerativeModel


In [2]:
# Take a shallow copy of the config read off the class before editing it.
# Assigning to the class-level object directly would silently change
# `PretrainedGPT2Model.tx_config` for every other user of the class in this
# kernel (a hidden-state hazard on re-runs / other notebooks in the process).
config = copy.copy(PretrainedGPT2Model.tx_config)
# NOTE(review): presumably switches the model into autoregressive decoding
# mode — confirm against tx.models.gpt2.
config.decode = True


def store_hook(x, module: nn.Module, hook_point: HookPoint):
    """Record the hooked value and hand it back unmodified.

    ``x`` is sown into ``module``'s ``"intermediates"`` collection under the
    name ``hook_point.value`` (the same collection listed in
    ``hook_collections`` when the model is built), and then returned as-is.
    """
    sow_name = hook_point.value
    module.sow("intermediates", sow_name, x)
    return x


# The tokenizer and the weights must come from the same checkpoint; keep the
# name in one place so the two cannot drift apart.
MODEL_NAME = "gpt2"

reference_gpt2 = GenerativeModel(
    config=config,
    tokenizer=GPT2TokenizerFast.from_pretrained(MODEL_NAME),
    params=PretrainedGPT2Model.from_pretrained(MODEL_NAME).to_params(),
    # Run `store_hook` at the ATTN_OUTPUT hook point; `hook_collections`
    # names the collection that `store_hook` sows into ("intermediates").
    hooks={HookPoint.ATTN_OUTPUT.value: Hook(store_hook)},
    hook_collections=["intermediates"],
)


In [3]:
reference_text = "Hello, I am"
# prepend_bos=True puts a leading BOS id in front of the prompt's token ids
# (50256 is the first id in the printed output).
tokens: Array = reference_gpt2.to_tokens(
    reference_text,
    prepend_bos=True,
)
print(tokens)


[50256 15496    11   314   716]


In [4]:
# Disabled cell: per-prefix next-token predictions. For each prefix
# tokens[: i + 1] it would run the model once (one call per prompt token) and
# print the argmax token at the last position. Uncomment to inspect.
# TODO: delete this cell if it is no longer needed — commented-out code cells
# are not part of the notebook's Restart & Run All path.
# print(reference_gpt2.to_str(tokens), end="", flush=True)
# for i in range(tokens.shape[0]):
#     # Pass sequence through the model to get new output
#     logits, _ = reference_gpt2(tokens[: i + 1])
#     # Get the predicted token at the end of our sequence
#     next_token = jnp.argmax(logits, axis=-1)[-1]
#     # Decode and print the result
#     next_char = reference_gpt2.to_str(next_token)
#     print(f"next_token[{i}]: {next_char}")


In [5]:
# Greedy prediction of the token that follows the prompt.
logits, _ = reference_gpt2(tokens)
# softmax is monotonic, so it cannot change which entry is largest: take the
# argmax of the raw logits directly. This also avoids softmax rounding
# collapsing near-tied logits into equal probabilities.
# Assumes logits is (seq_len, vocab) — [-1] keeps the last position's pick.
next_token = jnp.argmax(logits, axis=-1)[-1]
next_char = reference_gpt2.to_str(next_token)
print(next_char)

 a


In [6]:
# Greedy autoregressive generation: predict one token, append it, repeat.
# NOTE(review): the whole prefix is re-run through the model on every step;
# whether `config.decode = True` lets the model reuse a cache here is not
# visible from this notebook.
N_NEW_TOKENS = 10  # number of tokens to generate after the prompt

cur_tokens = tokens
for _ in range(N_NEW_TOKENS):
    # Pass sequence through the model to get new output
    logits, _ = reference_gpt2(cur_tokens)
    # Get the predicted token at the end of our sequence
    next_token = jnp.argmax(logits, axis=-1)[-1]
    # Decode and print the result
    next_char = reference_gpt2.to_str(next_token)
    print(next_char)
    # Define new input sequence, by appending the previously generated token.
    # jnp.append builds a new array (tokens itself is left untouched).
    cur_tokens = jnp.append(cur_tokens, next_token)


 a
 student
 at
 the
 University
 of
 California
,
 Berkeley
.
