In [4]:
print("Loading imports...\n")
import random

import tiktoken
import torch
import torch.nn.functional as F

from train_gpt2 import GPT, GPTConfig, load_model_from_save

# Run on the GPU when one is visible, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# NOTE(review): 50304 is presumably the (padded) vocab size used when training the
# checkpoint below; it must match the saved weights — confirm against train_gpt2.
model = GPT(GPTConfig(vocab_size=50304))
model = model.to(device)

# An optimizer is built only because load_model_from_save takes one; these
# hyperparameters presumably mirror the training run — confirm if resuming training.
optimizer = model.configure_optimizers(
    weight_decay=0.1,
    learning_rate=6e-4,
    device=device,
)

# The loader returns two values that this notebook does not need.
_, _ = load_model_from_save(model, "log/model_19072.pt", optimizer, device)


Loading imports...

Configured optimizer with weight decay
- num decayed param tensors: 50, totaling 124,354,560 parameters
- num non-decayed parameter tensors: 98, totaling 121,344 parameters
- fused AdamW: False

Loaded model from train step 19073 with val loss 3.09375


In [None]:
def generate_text(model, prompt, device, num_return_sequences=1, max_new_tokens=50, temp=0.8, top_k=50,
                  seed=None, enc=None):
    """Sample continuations of ``prompt`` from ``model`` and print them.

    With ``num_return_sequences == 1`` the prompt and then each new token are streamed
    to stdout as they are produced. With more than one sequence, each full sequence
    (prompt + continuation) is printed after generation finishes.

    Args:
        model: callable mapping a ``(B, T)`` LongTensor of token ids to ``(logits, _)``
            with logits of shape ``(B, T, vocab_size)``; must also provide ``eval()``.
        prompt: text to condition on; must encode to at least one token.
        device: device the token tensor and the sampling RNG are created on.
        num_return_sequences: number of independent samples (the batch size B).
        max_new_tokens: number of tokens appended to each sequence.
        temp: softmax temperature; ``0`` selects greedy (argmax) decoding.
        top_k: sample only among the ``top_k`` most likely tokens. Values larger than
            the vocabulary are clamped; ``None`` disables top-k filtering.
        seed: seed for the private sampling RNG. ``None`` draws a fresh seed from
            ``random`` so repeated calls produce different samples.
        enc: tokenizer exposing ``encode``/``decode``; defaults to tiktoken's "gpt2".

    Returns:
        None; all output is printed.

    Raises:
        ValueError: if ``prompt`` encodes to zero tokens.
    """
    model.eval()

    if enc is None:
        enc = tiktoken.get_encoding("gpt2")

    # A dedicated generator keeps sampling independent of the global torch RNG.
    if seed is None:
        seed = random.randint(0, 2**32 - 1)
    sample_rng = torch.Generator(device=device)
    sample_rng.manual_seed(seed)

    prompt_ids = torch.tensor(enc.encode(prompt), dtype=torch.long)  # (T,)
    if prompt_ids.numel() == 0:
        raise ValueError("prompt must encode to at least one token")
    x = prompt_ids.unsqueeze(0).repeat(num_return_sequences, 1).to(device)  # (B, T)

    # NOTE(review): assumes a nanoGPT-style model exposing config.block_size. When it
    # exists, the context is cropped so long generations don't exceed it; otherwise
    # the full context is always fed.
    block_size = getattr(getattr(model, "config", None), "block_size", None)

    stream = num_return_sequences == 1
    if stream:
        print(prompt, end="", flush=True)

    with torch.inference_mode():
        for _step in range(max_new_tokens):
            x_cond = x if block_size is None else x[:, -block_size:]
            logits, _ = model(x_cond)  # (B, T, vocab_size)
            logits = logits[:, -1, :]  # (B, vocab_size)

            if temp == 0:
                next_token = torch.argmax(logits, dim=-1, keepdim=True)  # (B, 1)
            else:
                probs = F.softmax(logits / temp, dim=-1)
                vocab = probs.size(-1)
                k = vocab if top_k is None else min(top_k, vocab)
                topk_probs, topk_indices = torch.topk(probs, k, dim=-1)  # (B, k)
                ix = torch.multinomial(topk_probs, 1, generator=sample_rng)  # (B, 1)
                next_token = torch.gather(topk_indices, -1, ix)  # (B, 1)

            # Extend the context in BOTH branches; greedy decoding must also see its
            # own output, otherwise it keeps re-predicting the same next token.
            x = torch.cat((x, next_token), dim=1)

            if stream:
                print(enc.decode([next_token[0, 0].item()]), end="", flush=True)

    if not stream:
        for i, row in enumerate(x.tolist()):
            print(f"\nSequence #{i+1}: {enc.decode(row)}")

In [194]:
print("\nGenerating text...")

# ---- Few-shot question answering ---------------------------------------------
prompt1 = (
    "Question: Who wrote the play 'Romeo and Juliet'?\n"
    "Answer: William Shakespeare\n"
    "Question: What is the capital of Canada?\n"
    "Answer:"
)
# generate_text(model, prompt1, device=device, temp=0.2, top_k=40, max_new_tokens=8)

prompt2 = (
    "Question: What is the largest planet in our solar system?\n"
    "Answer:"
)
# generate_text(model, prompt2, device=device, temp=0.1, top_k=40, max_new_tokens=10)

# ---- Summarization ------------------------------------------------------------
# The model doesn't really summarize here; its output leans on information it
# memorized from the FineWeb training data rather than on the article.
# "Article:" is followed by two spaces, kept verbatim.
prompt4 = (
    "Article:  \n"
    "The Amazon rainforest, often referred to as the \"lungs of the Earth,\" "
    "produces around 20% of the world's oxygen. "
    "It is home to millions of species of plants and animals, "
    "many of which are found nowhere else. "
    "In recent decades, deforestation caused by logging, agriculture, and mining "
    "has threatened the biodiversity and climate stability provided by this vast ecosystem. "
    "Efforts by governments, NGOs, and indigenous groups continue to focus on "
    "conservation and sustainable development.\n"
    "\n"
    "Summary:\n"
)
# generate_text(model, prompt4, device=device, temp=0.3, top_k=40, max_new_tokens=50)

# ---- Passage + Q&A ------------------------------------------------------------
# Works only some of the time: the model may ignore the question entirely, or
# answer with something not grounded in the passage.
prompt5 = (
    "\n"
    "Article: Mount Everest is the highest mountain in the world, located between Nepal and China. "
    "The first confirmed ascent was made by Sir Edmund Hillary and Tenzing Norgay in 1953.\n"
    "Q1: Where is Mount Everest located?\n"
    "A1: Between Nepal and China\n"
    "\n"
    "Q2: Who were the first climbers to reach the summit?\n"
    "A2:"
)

# Without the passage, the model answers incorrectly:
# prompt5 = """
# Question: Who were the first climbers to reach the summit of Mount Everest?
# Answer:"""

# generate_text(model, prompt5, device=device, temp=0.2, top_k=40, max_new_tokens=20)

# ---- Creative text completion (my favorite) -----------------------------------
prompt6 = "I think that quantum mechanics is"
generate_text(model, prompt6, device=device, temp=1.2, max_new_tokens=100)



Generating text...
I think that quantum mechanics is the foundation of a bunch of interesting things. We are looking at quantum computing and in physics quantum dynamics is something that you can just say, "well we're building it, but we're really exploring things in the opposite direction." So it may be some form of a very small but amazing world that there are things happening in the unknown, perhaps even some strange world. And it all comes down to what we're hoping to figure out here, maybe maybe maybe someday we can be able to help build