In [7]:
import torch
import torch.nn.functional as F
from model import GPT, GPTConfig
import tiktoken
import os

# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Device:", device)

def load_model(checkpoint_path="model.pt"):
    """Load a GPT model from ``checkpoint_path`` and prepare it for inference.

    Three checkpoint layouts are accepted:

    * a pickled ``GPT`` instance (used as-is),
    * a dict containing a ``"model_state_dict"`` entry (its optional
      ``"loss"`` entry is printed),
    * anything else, which is treated as a raw state dict.

    Args:
        checkpoint_path: Path of the checkpoint file to read.

    Returns:
        The model, moved to the module-level ``device`` and switched to
        eval mode.

    Raises:
        FileNotFoundError: If ``checkpoint_path`` does not exist.
    """
    # Fail fast on a missing file, before paying for any model construction.
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"{checkpoint_path} not found")

    # SECURITY: torch.load unpickles the file, which can execute arbitrary
    # code. Only load checkpoints that come from a trusted source.
    ckpt = torch.load(checkpoint_path, map_location=device)

    if isinstance(ckpt, GPT):
        # Case A: full model saved -- no fresh GPT needs to be built.
        print("Loaded full model object")
        model = ckpt
    else:
        # Cases B and C carry only weights, so a default-configured model is
        # needed to receive them.
        model = GPT(GPTConfig())
        if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
            # Case B: wrapped checkpoint {"model_state_dict": ...}
            print("Loaded checkpoint with model_state_dict")
            model.load_state_dict(ckpt["model_state_dict"])
            print("Training loss:", ckpt.get("loss", "N/A"))
        else:
            # Case C: raw state dict
            print("Loaded raw state_dict")
            model.load_state_dict(ckpt)

    # Single finalisation path shared by all three cases.
    model.to(device)
    model.eval()
    return model


# ----------------------------------------------------------------------------
# Sampling function
# ----------------------------------------------------------------------------
# Module-level "gpt2" encoding from tiktoken; generate() uses it both to
# encode the prompt and to decode the sampled token ids.
enc = tiktoken.get_encoding("gpt2")

def generate(model, prompt, max_new_tokens=100, temperature=1.0, top_k=50):
    """Sample a continuation of ``prompt`` from ``model``.

    Tokens are drawn one at a time: the logits for the last position are
    divided by ``temperature``, optionally restricted to the ``top_k`` most
    likely tokens, turned into probabilities with softmax and sampled.

    Args:
        model: Callable returning a ``(logits, _)`` pair for a batch of
            token ids; only the logits are used.
        prompt: Text to condition on. Must encode to at least one token.
        max_new_tokens: Number of tokens to sample.
        temperature: Positive divisor applied to the logits before softmax.
        top_k: If not ``None``, keep only the ``top_k`` highest logits
            (values larger than the vocabulary size are clamped to it).

    Returns:
        The decoded text of the prompt tokens followed by the sampled tokens.

    Raises:
        ValueError: If ``temperature`` is not positive, ``top_k`` is less
            than 1, or ``prompt`` encodes to no tokens.
    """
    # Reject inputs that would otherwise surface as obscure tensor errors
    # (division by zero, empty top-k, zero-length sequence).
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")
    if top_k is not None and top_k < 1:
        raise ValueError(f"top_k must be None or >= 1, got {top_k}")

    model.eval()

    # Encode prompt
    tokens = enc.encode(prompt)
    if not tokens:
        raise ValueError("prompt must encode to at least one token")
    idx = torch.tensor(tokens, dtype=torch.long, device=device).unsqueeze(0)

    # NOTE(review): assumes model.config.block_size, when present, is the
    # model's maximum context length -- confirm against model.py. When the
    # attribute is absent no cropping is done.
    block_size = getattr(getattr(model, "config", None), "block_size", None)

    # Sampling never needs gradients; without no_grad every step would keep
    # an autograd graph alive.
    with torch.no_grad():
        for _ in range(max_new_tokens):
            # Feed at most block_size trailing tokens to the model.
            if block_size is not None and idx.size(1) > block_size:
                context = idx[:, -block_size:]
            else:
                context = idx

            # Forward pass
            logits, _ = model(context)

            # Take last token logits
            logits = logits[:, -1, :] / temperature

            # Top-k filtering
            if top_k is not None:
                # torch.topk raises if k exceeds the last dimension.
                k = min(top_k, logits.size(-1))
                top_values = torch.topk(logits, k=k).values
                # Everything below the k-th best logit gets zero probability.
                logits = logits.masked_fill(
                    logits < top_values[:, -1:], float("-inf")
                )

            probs = F.softmax(logits, dim=-1)

            # Sample token
            next_token = torch.multinomial(probs, num_samples=1)

            # Append to sequence
            idx = torch.cat([idx, next_token], dim=1)

    return enc.decode(idx[0].tolist())


# ----------------------------------------------------------------------------
# Run sample generations
# ----------------------------------------------------------------------------
model = load_model("model.pt")

# Each entry pairs the banner printed before a sample with the prompt that
# seeds it; every sample is 150 new tokens long.
samples = (
    ("\n----- SAMPLE OUTPUT 1 -----", "Once upon a time"),
    ("\n----- SAMPLE OUTPUT 2 -----", "The meaning of life is"),
    ("\n----- SAMPLE OUTPUT 3 -----", "In the future,"),
)
for banner, seed_text in samples:
    print(banner)
    print(generate(model, seed_text, max_new_tokens=150))

Device: cuda
Loaded checkpoint with model_state_dict
Training loss: 0.08488673716783524

----- SAMPLE OUTPUT 1 -----
Once upon a time calm.
That Henry yields thus said,
And thus ne'er burn as light in London send his hands,
Were buckle beautled only dislike:
With many misforce's course, calm enter, with a titportion'd thy father!

ROMEO:
Good time, mad!

ROMEO:
I will not believe a buried man! yet, nurseful brother's house, as Oumerle, madars, madre we be bold in such woe
That sorrow that his high disgrace and measure of philosophy.

ROMEO:
What says I am told to fagns best Paris.
An old blocks's groarth, hath Romeo's short sudden with Saint George-day's

----- SAMPLE OUTPUT 2 -----
The meaning of life is accounted pay Lancaster with red bones
person straight; at Lancaster; where they have?

HENRY BOLINGBROKE:
Go on that we blow strong; to bear me leave.

HENRY PERCY:
My lord, my Lord with English choice, Cbot!'
ever till me to rise, that we love in vain;
The town since this vile strif