# Imports

In [15]:
import os
import json
import torch
import warnings
import torch.nn as nn

from collections import Counter

# Settings

In [16]:
# Relative paths resolve against the notebook's current working directory.
MODELS_DIR = "../Models"
DATA_PATH = "../Data/shakespeare.txt"

# Use the GPU whenever PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Model Definition

In [17]:
class SimpleTransformer(nn.Module):
    """Small language model: token embeddings plus learned position embeddings, a
    stack of ``nn.TransformerEncoderLayer`` blocks, and a linear projection to
    vocabulary logits.

    ``forward`` takes a (batch, seq) tensor of token ids and returns
    (batch, seq, vocab_size) logits.

    NOTE(review): ``forward`` lays activations out as (batch, seq, embed_dim), but
    ``nn.TransformerEncoderLayer`` defaults to ``batch_first=False``, i.e. it expects
    (seq, batch, embed_dim). With the default used here the encoder layers attend
    across the *batch* axis instead of across the sequence; for a single sequence
    (batch size 1, as in ``generate``) every token attends only to itself. The
    default is kept so checkpoints trained with this definition keep behaving the
    way they were trained (parameter names/shapes are identical either way); new
    models should be trained with ``batch_first=True, causal=True``.

    Args:
        vocab_size: number of token ids (embedding rows / output width).
        embed_dim: model width (``d_model``).
        num_heads: attention heads per encoder layer.
        hidden_dim: feed-forward width inside each encoder layer.
        num_layers: number of encoder layers.
        seq_len: maximum sequence length (size of the position-embedding table).
        batch_first: forwarded to every encoder layer; see the note above.
        causal: if True, mask attention so each position sees only itself and
            earlier positions. Requires ``batch_first=True``.

    Raises:
        ValueError: if ``causal`` is requested without ``batch_first``.
    """

    def __init__(self, vocab_size, embed_dim, num_heads, hidden_dim, num_layers, seq_len,
                 batch_first=False, causal=False):

        super().__init__()

        if causal and not batch_first:
            # The mask is built over dim 1 (the real sequence axis); without
            # batch_first the layers would read dim 0 as the sequence instead.
            raise ValueError("causal=True requires batch_first=True")

        self.batch_first = batch_first
        self.causal = causal

        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.position_embedding = nn.Embedding(seq_len, embed_dim)

        self.transformer_blocks = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=embed_dim, nhead=num_heads, dim_feedforward=hidden_dim,
                batch_first=batch_first)
            for _ in range(num_layers)])

        self.fc = nn.Linear(embed_dim, vocab_size)

    def forward(self, x):
        """Return logits of shape (batch, seq, vocab_size) for token ids ``x`` of shape (batch, seq).

        Raises:
            ValueError: if the sequence is longer than the position-embedding table.
        """
        length = x.size(1)
        max_len = self.position_embedding.num_embeddings
        if length > max_len:
            # Fail with a readable message instead of an embedding index error
            # (which on CUDA surfaces as an opaque device-side assert).
            raise ValueError(
                f"sequence length {length} exceeds the model's seq_len ({max_len})")

        positions = torch.arange(0, length, device=x.device).unsqueeze(0)
        x = self.embedding(x) + self.position_embedding(positions)

        mask = None
        if self.causal:
            # True entries are blocked: position i may not attend to any j > i.
            mask = torch.triu(
                torch.ones(length, length, dtype=torch.bool, device=x.device), diagonal=1)

        for block in self.transformer_blocks:
            x = block(x, src_mask=mask)

        return self.fc(x)

# Vocabulary Building

In [18]:
def build_vocab(file_path, vocab_size):
    """Build a word -> id mapping from the whitespace-separated words of a UTF-8 file.

    The ``vocab_size`` most frequent words get ids 0, 1, 2, ... in descending
    frequency order (so id 0 is the most frequent word). Words with equal counts
    keep the order in which they first appear in the file.
    """
    with open(file_path, 'r', encoding='utf-8') as handle:
        word_counts = Counter(handle.read().split())

    ranked_words = word_counts.most_common(vocab_size)

    return {word: index for index, (word, _count) in enumerate(ranked_words)}

def tokenize(text, word_to_id):
    """Map each whitespace-separated word of ``text`` to its id.

    Words missing from ``word_to_id`` become id 0. With a vocabulary from
    ``build_vocab`` that is the most frequent word, not a dedicated unknown token.
    """
    token_ids = []
    for word in text.split():
        token_ids.append(word_to_id.get(word, 0))
    return token_ids

def detokenize(tokens, id_to_word):
    """Join the words for ``tokens`` with single spaces; ids not in ``id_to_word`` render as "<unk>"."""
    words = (id_to_word[token] if token in id_to_word else "<unk>" for token in tokens)
    return " ".join(words)

# Generate Function

In [19]:
def generate(model, prompt, word_to_id, id_to_word, seq_len, max_length=50, temperature=1.0, top_k=10):
    """Sample a continuation of ``prompt`` with temperature-scaled top-k sampling.

    Args:
        model: module mapping a (1, n) tensor of token ids to (1, n, vocab) logits.
        prompt: text split on whitespace; words missing from ``word_to_id`` become
            id 0 (see ``tokenize``).
        word_to_id: word -> id mapping used to encode the prompt.
        id_to_word: id -> word mapping used to decode the result.
        seq_len: maximum context length fed to the model; only the last ``seq_len``
            prompt words are kept.
        max_length: number of tokens to sample.
        temperature: must be > 0; logits are divided by it before sampling.
        top_k: sample among the ``top_k`` highest-scoring tokens (clamped to the
            vocabulary size).

    Returns:
        The (truncated) prompt re-detokenized from ids, followed by every sampled
        token, joined with single spaces.

    Raises:
        ValueError: if ``temperature`` or ``top_k`` is not positive, or the prompt
            contains no words.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be > 0, got {temperature}")
    if top_k <= 0:
        raise ValueError(f"top_k must be >= 1, got {top_k}")

    prompt_ids = tokenize(prompt, word_to_id)[-seq_len:]
    if not prompt_ids:
        # An empty id list would become a float tensor and fail inside the embedding.
        raise ValueError("prompt must contain at least one word")

    model.eval()
    # Build inputs on whatever device the model lives on rather than a global.
    device = next(model.parameters()).device

    # Keep the full history separately: the model only sees the last seq_len ids,
    # but every sampled token must appear in the returned text (the old sliding
    # window silently dropped early generated tokens once it filled up).
    token_ids = list(prompt_ids)

    with torch.no_grad():
        for _ in range(max_length):
            context = torch.tensor(
                token_ids[-seq_len:], dtype=torch.long, device=device).unsqueeze(0)
            logits = model(context)[0, -1, :] / temperature

            k = min(top_k, logits.size(-1))
            top_k_vals, top_k_idxs = torch.topk(logits, k)
            probs = torch.softmax(top_k_vals, dim=-1)

            choice = torch.multinomial(probs, 1).item()
            token_ids.append(top_k_idxs[choice].item())

    return detokenize(token_ids, id_to_word)

# Load Model and Settings

In [20]:
def select_model(models_dir=None):
    """Interactively choose a ``.pth`` checkpoint and its matching settings file.

    Lists every ``*.pth`` file in ``models_dir`` (sorted, so the numbering is
    stable), prompts until a valid 1-based number is entered, and returns the paths
    of the chosen checkpoint and its ``<name>_settings.json`` companion.

    Args:
        models_dir: directory to search; defaults to ``MODELS_DIR``.

    Returns:
        ``(model_path, settings_path)``.

    Raises:
        FileNotFoundError: if there are no ``.pth`` files, or the chosen checkpoint
            has no settings file next to it.
    """
    if models_dir is None:
        models_dir = MODELS_DIR

    # os.listdir order is arbitrary; sort so the menu is reproducible.
    model_files = sorted(f for f in os.listdir(models_dir) if f.endswith(".pth"))

    if not model_files:
        raise FileNotFoundError(f"No .pth model files found in {models_dir} folder.")

    print("Available Models:")

    for i, f in enumerate(model_files, start=1):
        print(f"{i}. {f}")

    while True:
        try:
            choice = int(input("\nChoose a model number: "))
        except ValueError:
            print("Enter a number.")
            continue

        if not 1 <= choice <= len(model_files):
            print("Invalid choice.")
            continue

        model_file = model_files[choice - 1]
        # Swap only the trailing extension; str.replace would also rewrite any
        # earlier ".pth" inside the file name.
        settings_file = os.path.splitext(model_file)[0] + "_settings.json"

        model_path = os.path.join(models_dir, model_file)
        settings_path = os.path.join(models_dir, settings_file)

        if not os.path.exists(settings_path):
            raise FileNotFoundError(f"Missing settings file: {settings_file}")

        return model_path, settings_path

model_path, settings_path = select_model()

# Read the settings JSON as UTF-8 explicitly instead of the platform default encoding.
with open(settings_path, "r", encoding="utf-8") as f:
    settings = json.load(f)

VOCAB_SIZE = settings["VOCAB_SIZE"]
SEQ_LEN = settings["SEQ_LEN"]

# NOTE(review): the vocabulary is rebuilt from the corpus rather than loaded with the
# checkpoint, so ids only line up with the model if DATA_PATH and VOCAB_SIZE match
# what training used — TODO confirm / consider saving the vocab next to the model.
vocab = build_vocab(DATA_PATH, VOCAB_SIZE)
word_to_id = vocab

id_to_word = {i: w for w, i in vocab.items()}

if len(vocab) < VOCAB_SIZE:
    # Ids in [len(vocab), VOCAB_SIZE) have no word and will be printed as "<unk>".
    warnings.warn(
        f"Corpus has only {len(vocab)} distinct words but VOCAB_SIZE is {VOCAB_SIZE}.")

model = SimpleTransformer(
    vocab_size=VOCAB_SIZE,
    embed_dim=settings["EMBED_DIM"],
    num_heads=settings["NUM_HEADS"],
    hidden_dim=settings["HIDDEN_DIM"],
    num_layers=settings["NUM_LAYERS"],
    seq_len=SEQ_LEN,
).to(DEVICE)

# weights_only=True limits unpickling to tensors and plain containers.
model.load_state_dict(torch.load(model_path, map_location=DEVICE, weights_only=True))
model.eval()

print("\nYou are now speaking with Shakespeare.\nType 'quit' to exit.\n")

Available Models:
1. shakespeare_LLM_20250417-110810.pth
2. shakespeare_LLM_20250417-123315.pth

You are now speaking with Shakespeare.
Type 'quit' to exit.



# Interact with the LLM

In [None]:
while True:

    # strip() so " quit " is recognized and blank input can be skipped.
    user_input = input("You: ").strip()

    if user_input.lower() == 'quit':

        print("\nFarewell, gentle soul.")
        break

    if not user_input:
        # generate() has nothing to condition on for an empty prompt.
        continue

    prompt = user_input
    response = generate(model, prompt, word_to_id, id_to_word, seq_len=SEQ_LEN, max_length=40, temperature=0.8)

    # generate() returns detokenized ids, so any echoed prompt is its re-detokenized
    # form (unknown words replaced, whitespace collapsed), not the raw input. Strip
    # that form, and only when it really is a prefix, instead of blindly dropping
    # len(prompt) characters, which cuts into the generated words.
    echoed_prompt = detokenize(tokenize(prompt, word_to_id)[-SEQ_LEN:], id_to_word)
    if response == echoed_prompt or response.startswith(echoed_prompt + " "):
        reply = response[len(echoed_prompt):]
    else:
        reply = response

    print("Shakespeare:", reply.strip())

Shakespeare: build wi’ his climbing high. and I had I will level to her stoop I will passed. I have my dug for methought you to my good lieutenant, enemies, and let your service. Bourbon, play at his amorous Soft, trim
Shakespeare: and all my good as you shall be my acquittance follows. BERTRAM. I have my quits to thy tide of your discontent? But I have strew’d clouds that he would you do not sleep. The word’s influence of your having.
