In [None]:
import sys
import os

# Make the parent of the current working directory importable so that the
# project-local imports in this cell (`config`, `model.gpt2`, `tokenizer`) resolve.
# NOTE(review): the path is appended, so any already-importable package with
# the same name takes precedence over the project module.
project_root = os.path.abspath("..")
already_importable = project_root in sys.path
if not already_importable:
    sys.path.append(project_root)

import torch
import matplotlib.pyplot as plt

from config import GPT2Config
from model.gpt2 import GPT2
from tokenizer import get_tokenizer

In [None]:
# Run on the GPU when PyTorch reports CUDA as available; otherwise use the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)

In [None]:
# Build the model from a default-constructed GPT2Config and move it to the
# selected device.
config = GPT2Config()
model = GPT2(config).to(device)

# Read the saved checkpoint; map_location places the stored tensors on `device`
# regardless of the device they were saved from.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source; consider weights_only=True if the checkpoint contents allow
# it — TODO confirm.
checkpoint = torch.load("../checkpoints/gpt2_epoch_1.pt", map_location=device)

# Copy the saved parameters into the model, then switch it to evaluation mode.
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()

print("Loaded epoch:", checkpoint["epoch"])

In [None]:
# Plot the loss values stored in the checkpoint under "loss_history".
loss_history = checkpoint["loss_history"]

axes = plt.gca()
axes.plot(loss_history)
axes.set(
    title="Training Loss (Loaded from Checkpoint)",
    xlabel="Step",
    ylabel="Loss",
)
plt.show()

In [None]:
def generate(model, input_ids, max_new_tokens=50, temperature=1.0, top_k=None,
             max_context=None):
    """Autoregressively sample tokens and append them to ``input_ids``.

    At every step the model is called on the current sequence, the logits at
    the last position are turned into a probability distribution, and one
    token per row is drawn from it with ``torch.multinomial``.

    Args:
        model: Callable mapping a ``(batch, seq)`` tensor of token ids to a
            3-D logits tensor whose last axis is sampled from. It is put into
            evaluation mode and left there.
        input_ids: Integer tensor of shape ``(batch, seq)`` holding the prompt.
        max_new_tokens: Number of tokens to append to each row.
        temperature: Positive divisor applied to the logits before softmax.
            ``1.0`` (the default) leaves the distribution unchanged.
        top_k: If given, only the ``top_k`` highest-scoring tokens can be
            sampled at each step. ``None`` (the default) samples from the
            full distribution.
        max_context: If given, only the last ``max_context`` tokens are fed to
            the model at each step, so generation can continue past a fixed
            context window. ``None`` (the default) always feeds the whole
            sequence.

    Returns:
        Tensor of shape ``(batch, seq + max_new_tokens)``: the prompt followed
        by the sampled tokens.

    Raises:
        ValueError: If ``temperature`` is not positive, or ``top_k`` /
            ``max_context`` is given but smaller than 1.
    """
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    if top_k is not None and top_k < 1:
        raise ValueError("top_k must be at least 1")
    if max_context is not None and max_context < 1:
        raise ValueError("max_context must be at least 1")

    model.eval()
    # Sampling is pure inference: disabling autograd avoids building (and
    # holding on to) a graph for every forward pass of the loop.
    with torch.no_grad():
        for _ in range(max_new_tokens):
            if max_context is None:
                context = input_ids
            else:
                context = input_ids[:, -max_context:]

            # Only the prediction at the final position is needed.
            logits = model(context)[:, -1, :] / temperature

            if top_k is not None:
                k = min(top_k, logits.size(-1))
                # Smallest logit that is still inside the top-k, per row.
                kth_best = torch.topk(logits, k, dim=-1).values[:, -1:]
                logits = logits.masked_fill(logits < kth_best, float("-inf"))

            probs = torch.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            input_ids = torch.cat([input_ids, next_token], dim=1)
    return input_ids

In [None]:
tokenizer = get_tokenizer()
prompt = "Deep learning is"

# Encode the prompt as a tensor of token ids, then move it to the same device
# as the model.
input_ids = tokenizer.encode(prompt, return_tensors="pt")
input_ids = input_ids.to(device)

# Sample a continuation of up to 50 new tokens after the prompt.
output = generate(model, input_ids, max_new_tokens=50)

# Decode the first (and here only) sequence in the batch back to text.
completion = tokenizer.decode(output[0])
print(completion)