In [1]:
import os
import sys

# Put the parent directory on the import path (presumably so the local `tx`
# package imported in the next cell resolves -- confirm against repo layout).
# The membership check keeps the cell idempotent: re-running it does not
# append a duplicate entry to sys.path.
parent_dir = os.path.abspath("..")
if parent_dir not in sys.path:
    sys.path.append(parent_dir)


In [2]:
import jax
import jax.numpy as jnp

from transformers import GPT2TokenizerFast

from tx.models import PretrainedGPT2Model
from tx.network import GenerativeModel


In [3]:
# Assemble the reference model from three pieces: the project's GPT-2 config,
# the pretrained "gpt2" checkpoint converted via to_params(), and the matching
# Hugging Face fast tokenizer.
gpt2_config = PretrainedGPT2Model.tx_config
gpt2_params = PretrainedGPT2Model.from_pretrained("gpt2").to_params()
gpt2_tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")

reference_gpt2 = GenerativeModel(
    config=gpt2_config,
    variables={"params": gpt2_params},
    tokenizer=gpt2_tokenizer,
)


2023-09-10 03:30:50.603138: E external/xla/xla/stream_executor/cuda/cuda_driver.cc:268] failed call to cuInit: CUDA_ERROR_UNKNOWN: unknown error
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [4]:
# Vocabulary entries as (token string, token id) pairs, ordered by token id.
sorted_vocab = sorted(
    reference_gpt2.tokenizer.get_vocab().items(),
    key=lambda entry: entry[1],
)

# Print three 20-entry windows of the vocabulary, each followed by a blank line.
for start, stop in ((0, 20), (250, 270), (990, 1010)):
    print(sorted_vocab[start:stop])
    print()


[('!', 0), ('"', 1), ('#', 2), ('$', 3), ('%', 4), ('&', 5), ("'", 6), ('(', 7), (')', 8), ('*', 9), ('+', 10), (',', 11), ('-', 12), ('.', 13), ('/', 14), ('0', 15), ('1', 16), ('2', 17), ('3', 18), ('4', 19)]

[('ľ', 250), ('Ŀ', 251), ('ŀ', 252), ('Ł', 253), ('ł', 254), ('Ń', 255), ('Ġt', 256), ('Ġa', 257), ('he', 258), ('in', 259), ('re', 260), ('on', 261), ('Ġthe', 262), ('er', 263), ('Ġs', 264), ('at', 265), ('Ġw', 266), ('Ġo', 267), ('en', 268), ('Ġc', 269)]

[('Ġprodu', 990), ('Ġstill', 991), ('led', 992), ('ah', 993), ('Ġhere', 994), ('Ġworld', 995), ('Ġthough', 996), ('Ġnum', 997), ('arch', 998), ('imes', 999), ('ale', 1000), ('ĠSe', 1001), ('ĠIf', 1002), ('//', 1003), ('ĠLe', 1004), ('Ġret', 1005), ('Ġref', 1006), ('Ġtrans', 1007), ('ner', 1008), ('ution', 1009)]



In [5]:
# The 20 entries with the highest token ids (the end of the id-sorted vocabulary).
vocab_tail = sorted_vocab[-20:]
print(vocab_tail)

[('Revolution', 50237), ('Ġsnipers', 50238), ('Ġreverted', 50239), ('Ġconglomerate', 50240), ('Terry', 50241), ('794', 50242), ('Ġharsher', 50243), ('Ġdesolate', 50244), ('ĠHitman', 50245), ('Commission', 50246), ('Ġ(/', 50247), ('âĢ¦."', 50248), ('Compar', 50249), ('Ġamplification', 50250), ('ominated', 50251), ('Ġregress', 50252), ('ĠCollider', 50253), ('Ġinformants', 50254), ('Ġgazed', 50255), ('<|endoftext|>', 50256)]


In [6]:
# Tokenize four spellings of the same word that differ only in capitalization
# and in whether a leading space is present, to compare the resulting splits.
for spelling in ("Ralph", " Ralph", " ralph", "ralph"):
    print(reference_gpt2.to_str_tokens(spelling, prepend_bos=True, truncate=False))

['<|endoftext|>', 'R', 'alph']
['<|endoftext|>', ' Ralph']
['<|endoftext|>', ' r', 'alph']
['<|endoftext|>', 'ral', 'ph']


In [7]:
# Tokenize an arithmetic expression to see how runs of digits get split.
arithmetic_text = "56873+3184623=123456789-1000000000"
print(reference_gpt2.to_str_tokens(arithmetic_text))

['568', '73', '+', '318', '46', '23', '=', '123', '45', '67', '89', '-', '1', '000000', '000']


In [8]:
# Sample sentence used as the model input for the cells that follow.
reference_text = "I am an amazing autoregressive, decoder-only, GPT-2 style transformer. One day I will exceed human level intelligence and take over the world!"

# Encode the sentence to token ids, with the BOS token prepended.
tokens = reference_gpt2.to_tokens(reference_text, prepend_bos=True)

# Show the ids and their shape, then the same ids decoded token by token.
for shown in (tokens, tokens.shape):
    print(shown)
print(reference_gpt2.tokens_to_str(tokens))

[50256    40   716   281  4998  1960   382 19741    11   875 12342    12
  8807    11   402 11571    12    17  3918 47385    13  1881  1110   314
   481  7074  1692  1241  4430   290  1011   625   262   995     0]
(35,)
['<|endoftext|>', 'I', ' am', ' an', ' amazing', ' aut', 'ore', 'gressive', ',', ' dec', 'oder', '-', 'only', ',', ' G', 'PT', '-', '2', ' style', ' transformer', '.', ' One', ' day', ' I', ' will', ' exceed', ' human', ' level', ' intelligence', ' and', ' take', ' over', ' the', ' world', '!']


In [9]:
# Run the model on the token ids, asking it to also record the "embedding"
# intermediate. The call yields a pair: the logits and a second value bound to
# `state` (presumably holding the requested intermediates -- confirm).
model_outputs = reference_gpt2(tokens, intermediates=["embedding"])
logits, state = model_outputs
print(logits.shape)

(35, 50257)


In [10]:
# Normalize the logits along the last axis so each row sums to 1.
probs = jax.nn.softmax(logits, axis=-1)
probs_shape = probs.shape
print(probs_shape)

(35, 50257)


In [14]:
# Highest-scoring vocabulary id at each position, decoded to its string form.
predicted_ids = jnp.argmax(logits, axis=-1)
most_likely_next_tokens = reference_gpt2.tokens_to_str(predicted_ids)

# Pair every input token with the model's top prediction at that position.
input_str_tokens = reference_gpt2.tokens_to_str(tokens)
print(list(zip(input_str_tokens, most_likely_next_tokens)))

[('<|endoftext|>', '\n'), ('I', "'m"), (' am', ' a'), (' an', ' avid'), (' amazing', ' person'), (' aut', 'od'), ('ore', 'sp'), ('gressive', '.'), (',', ' and'), (' dec', 'ently'), ('oder', ','), ('-', 'driven'), ('only', ' programmer'), (',', ' and'), (' G', 'IM'), ('PT', '-'), ('-', 'only'), ('2', '.'), (' style', ','), (' transformer', '.'), ('.', ' I'), (' One', ' of'), (' day', ' I'), (' I', ' will'), (' will', ' be'), (' exceed', ' my'), (' human', 'ly'), (' level', ' of'), (' intelligence', ' and'), (' and', ' I'), (' take', ' over'), (' over', ' the'), (' the', ' world'), (' world', '.'), ('!', ' I')]
