In [7]:
from IPython.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

%load_ext autoreload
%autoreload 2

In [1]:
import torch
from llm.transformer import TransformerLM
from llm.tokenization import Tokenizer
from llm.serialization import load_checkpoint
from llm.generation import generateLLM

In [2]:
# End-of-text marker; registered below as a special token and later passed to
# generation as the stop token.
eos_token = "<|endoftext|>"

# Pickled BPE tokenizer (per the filename: 10k vocab, TinyStories).
# NOTE(review): from_pickle presumably unpickles this file, which can execute
# arbitrary code — only point it at trusted, locally produced artifacts.
tokenized_dataset_pkl = "/media/bryan/ssd01/expr/llm_from_scratch/tokenization/bpe_10k_tinystories.pkl"
tokenizer = Tokenizer.from_pickle(tokenized_dataset_pkl, special_tokens=[eos_token])

# The model's embedding / output layer size must match the tokenizer vocabulary.
vocab_size = len(tokenizer.vocab)

In [3]:
# Best checkpoint from the learning-rate sweep (the 1e-3 run, per the path).
model_pth = "/media/bryan/ssd01/expr/llm_from_scratch/tune-lr/1e-3/checkpoint_best.pt"

# Rebuild the architecture with the hyperparameters used for training. They
# must match the checkpoint exactly: load_state_dict (strict by default) raises
# on missing, unexpected, or mis-shaped parameters.
model = TransformerLM(
    vocab_size=vocab_size,
    context_length=256,
    num_layers=4,
    num_heads=16,
    d_model=512,
    d_ff=1344,
    rope_theta=10000,
)

# map_location="cpu" restores tensors on the CPU regardless of the device they
# were saved from, so this works on machines without that GPU. The model is
# moved to its compute device explicitly in a later cell.
checkpoint = torch.load(model_pth, map_location="cpu")
model.load_state_dict(checkpoint["model_state_dict"])

<All keys matched successfully>

In [4]:
# Switch to inference mode (disables training-only behavior such as dropout).
# train(False) is equivalent to eval() and likewise returns the module, so its
# repr is displayed as the cell output.
model.train(False)

TransformerLM(
  (token_embeddings): Embedding(vocab_size=10000, d=512)
  (RoPE): RotaryPositionalEmbedding(context_length=256, dim/2=16)
  (layers): ModuleList(
    (0-3): 4 x TransformerBlock(
      (attn): CausalMHSARoPE(
        (qkv_proj): Linear(d_out=1536, d_in=512)
        (output_proj): Linear(d_out=512, d_in=512)
        (RoPE): RotaryPositionalEmbedding(context_length=256, dim/2=16)
      )
      (ffn): SwiGLU(
        (w1): Linear(d_out=1344, d_in=512)
        (w2): Linear(d_out=512, d_in=1344)
        (w3): Linear(d_out=1344, d_in=512)
      )
      (ln1): RMSNorm(hidden_size=512, eps=1e-05)
      (ln2): RMSNorm(hidden_size=512, eps=1e-05)
    )
  )
  (ln_final): RMSNorm(hidden_size=512, eps=1e-05)
  (lm_head): Linear(d_out=10000, d_in=512)
)

In [5]:
# Prefer the GPU, but fall back to CPU so the notebook still runs on machines
# without CUDA instead of crashing here.
# NOTE(review): assumes generateLLM places its input tensors on the model's
# device rather than hardcoding "cuda" — confirm in llm.generation.
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

TransformerLM(
  (token_embeddings): Embedding(vocab_size=10000, d=512)
  (RoPE): RotaryPositionalEmbedding(context_length=256, dim/2=16)
  (layers): ModuleList(
    (0-3): 4 x TransformerBlock(
      (attn): CausalMHSARoPE(
        (qkv_proj): Linear(d_out=1536, d_in=512)
        (output_proj): Linear(d_out=512, d_in=512)
        (RoPE): RotaryPositionalEmbedding(context_length=256, dim/2=16)
      )
      (ffn): SwiGLU(
        (w1): Linear(d_out=1344, d_in=512)
        (w2): Linear(d_out=512, d_in=1344)
        (w3): Linear(d_out=1344, d_in=512)
      )
      (ln1): RMSNorm(hidden_size=512, eps=1e-05)
      (ln2): RMSNorm(hidden_size=512, eps=1e-05)
    )
  )
  (ln_final): RMSNorm(hidden_size=512, eps=1e-05)
  (lm_head): Linear(d_out=10000, d_in=512)
)

In [20]:
# Prompts to continue with the trained model.
PROMPTS = [
    "Once upon a time, there was a pretty girl named Lily. She loved to eat",
]

# Sampling settings shared by every prompt (temperature, top-k, top-p, stop
# token, seed). The same seed is passed on every call, presumably making each
# prompt's sample reproducible.
# NOTE(review): prompt length + max_new_tokens can exceed context_length (256);
# presumably generateLLM truncates the context window — confirm.
sampling_kwargs = {
    "max_new_tokens": 256,
    "temperature": 1.0,
    "top_k": 25,
    "top_p": 0.95,
    "eos_token": eos_token,
    "seed": 42,
}

for prompt in PROMPTS:
    generated_text = generateLLM(model, tokenizer, prompt, **sampling_kwargs)
    # Show the prompt on its own, then the prompt with its continuation appended.
    print(f"PROMPT:\n{prompt}\nGENERATED:\n{prompt}{generated_text}\n")

  """


PROMPT:
Once upon a time, there was a pretty girl named Lily. She loved to eat
GENERATED:
Once upon a time, there was a pretty girl named Lily. She loved to eat yummy food like apples, cheeseed pizza, and toasted cookies. One day, her mom said she had to leave the house to go and play. 
Lily was sad, but she knew that waiting would be back soon. She said goodbye to her mom and she went to play with her blocks. But when she got there, she could not find any delicious slices to eat. She started to cry. 
Lily went outside to look for more food but couldn't find any. She looked in the garden, but there was no yummy things. Suddenly, she saw a small hole in the fence. She tried to reach inside, but it was too far away. 
Lily had an idea. She put her hands and pulled out some of the small pieces. Then she ran back to the house and got the yummy cookies. From that day on, Lily made sure to close the door so she could get out safely.
<|endoftext|>



In [13]:
# Look up the integer id of the end-of-text special token. Encoding returns a
# list of ids; a registered special token should map to exactly one of them.
# Index into the list instead of pop()-ing it, so eos_token_id holds the id
# itself rather than a mutated (emptied) list.
eos_token_ids = tokenizer.encode(eos_token)
assert len(eos_token_ids) == 1, (
    f"expected {eos_token!r} to encode to a single id, got {eos_token_ids}"
)
eos_token_id = eos_token_ids[0]
eos_token_id

256