In [1]:
import os, sys

# Put the parent of the current working directory on the import path
# (presumably so the project-local `tx` package imported below resolves when
# the notebook is run from its own folder).
sys.path.append(os.path.abspath(os.pardir))

import jax

# Switch JAX to 64-bit precision; this is set before any arrays are created.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
import flax.linen as nn

from transformers import GPT2TokenizerFast

from tx.models.gpt2 import PretrainedGPT2Model
from tx.modules import HookMap, HookPoint, Hook
from tx.network import GenerativeModel


In [2]:
# Start from the configuration exposed on the pretrained GPT-2 wrapper and turn
# on its `decode` flag before building the model.
# NOTE(review): `config` is the same object as `PretrainedGPT2Model.tx_config`,
# so (unless `tx_config` returns a fresh object on each access) this assignment
# also changes the class-level config — confirm that is intended.
config = PretrainedGPT2Model.tx_config
config.decode = True


def store_hook(x, module: nn.Module, hook_point: HookPoint):
    """Record a hooked value and pass it through unchanged.

    Args:
        x: The value observed at the hook point.
        module: The module the hook runs in; the value is recorded through
            its `sow` method.
        hook_point: Identifies the hook; its `.value` is used as the name the
            value is stored under.

    Returns:
        `x`, unmodified.
    """
    collection, name = "intermediates", hook_point.value
    module.sow(collection, name, x)
    return x


# Load the pretrained "gpt2" tokenizer and weights up front, then assemble the
# reference model from them.
gpt2_tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
gpt2_params = PretrainedGPT2Model.from_pretrained("gpt2").to_params()

# `store_hook` is attached at the `embed` hook point, and "intermediates" (the
# collection that hook writes to) is listed in `hook_collections`.
reference_gpt2 = GenerativeModel(
    config=config,
    tokenizer=gpt2_tokenizer,
    params=gpt2_params,
    hooks=HookMap(embed=Hook(store_hook)),
    hook_collections=["intermediates"],
)


In [3]:
# Sample text used as the prompt for the rest of the notebook.
reference_text = "I am an amazing autoregressive, decoder-only, GPT-2 style transformer. One day I will exceed human level intelligence and take over the world!"
# Tokenize the text with a beginning-of-sequence token prepended; in the
# recorded output this is id 50256, which decodes to `<|endoftext|>`.
tokens = reference_gpt2.to_tokens(reference_text, prepend_bos=True)
# Show the resulting token ids.
print(tokens)


[50256    40   716   281  4998  1960   382 19741    11   875 12342    12
  8807    11   402 11571    12    17  3918 47385    13  1881  1110   314
   481  7074  1692  1241  4430   290  1011   625   262   995     0]


In [4]:
# Number of tokens to generate after the reference text has been replayed.
NUM_NEW_TOKENS = 10

# Phase 1: replay the reference text token by token. Each step echoes the
# newest token as text and runs the model on the prefix seen so far.
# NOTE(review): the model output is discarded here, so these calls presumably
# exist for their side effects (e.g. priming decode state, given
# `config.decode = True`) — confirm; if the model is stateless this loop only
# needs to print. Calling with a growing prompt length may also be slow under
# JAX, since each new input shape can trigger a recompile.
prompt = tokens[None, 0]  # 1-D prompt holding only the first (BOS) token
for i in range(len(tokens)):
    print(reference_gpt2.to_str(prompt[i]), end="", flush=True)
    reference_gpt2(prompt)
    if i < len(tokens) - 1:
        # Extend the prompt with the next reference token.
        prompt = jnp.concatenate([prompt, tokens[None, i + 1]], axis=-1)

# Phase 2: greedy generation. Repeatedly pick the highest-scoring next token,
# stream it as text, and append it to the prompt.
for _step in range(NUM_NEW_TOKENS):
    # Pass the sequence through the model to get new output.
    logits, _ = reference_gpt2(prompt)
    # Take the argmax over the last entry of `logits` (assumed to be the
    # scores for the final position — TODO confirm); `keepdims=True` keeps a
    # length-1 axis so the result can be concatenated onto `prompt`.
    next_token = jnp.argmax(logits[-1], axis=-1, keepdims=True)
    # Decode and stream the result. Only the decoded text is printed, so the
    # generated continuation reads as one uninterrupted line.
    next_char = reference_gpt2.to_str(next_token)
    print(next_char, end="", flush=True)
    # Define the new input sequence by appending the generated token.
    prompt = jnp.concatenate([prompt, next_token], axis=-1)


<|endoftext|>I am an amazing autoregressive, decoder

KeyboardInterrupt: 