# Imports

In [None]:
import os
import json
import torch
import warnings
import torch.nn as nn
import torch.nn.functional as F

from collections import Counter

# Settings

In [None]:
MODELS_DIR = "../Models"
DATA_PATH = "../Data/shakespeare_extended.txt"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Model Definition

In [None]:
class SimpleTransformer(nn.Module):
    """Stack of transformer encoder layers that turns token ids into vocabulary logits.

    Args:
        vocab_size: Number of rows in the token embedding and size of the output layer.
        embed_dim: Width of the token/position embeddings and of every encoder layer.
        num_heads: Attention heads per encoder layer.
        hidden_dim: Feed-forward width inside each encoder layer.
        num_layers: Number of encoder layers.
        seq_len: Number of learned positions; inputs longer than this cannot be embedded.

    NOTE(review): the encoder layers are created without ``batch_first=True`` and
    ``forward`` passes no attention mask. Existing checkpoints were presumably
    trained with this exact definition -- confirm against the training code
    before changing either.
    """

    def __init__(self, vocab_size, embed_dim, num_heads, hidden_dim, num_layers, seq_len):
        super().__init__()

        # Learned lookup tables for token identity and absolute position.
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.position_embedding = nn.Embedding(seq_len, embed_dim)

        encoder_layers = []
        for _ in range(num_layers):
            encoder_layers.append(
                nn.TransformerEncoderLayer(
                    d_model=embed_dim,
                    nhead=num_heads,
                    dim_feedforward=hidden_dim,
                )
            )
        self.transformer_blocks = nn.ModuleList(encoder_layers)

        self.norm = nn.LayerNorm(embed_dim)
        self.fc = nn.Linear(embed_dim, vocab_size)

    def forward(self, x):
        """Return logits over the vocabulary for every position of ``x``.

        Assumes ``x`` holds integer token ids laid out as (batch, sequence):
        dimension 1 is used as the number of positions.
        """
        # One row of position indices, broadcast over dimension 0 of x.
        seq_positions = torch.arange(x.size(1), device=x.device).unsqueeze(0)
        hidden = self.embedding(x) + self.position_embedding(seq_positions)

        for layer in self.transformer_blocks:
            hidden = layer(hidden)

        return self.fc(self.norm(hidden))

# Vocabulary Building

In [None]:
def build_vocab(file_path, vocab_size):
    """Build a word -> id mapping from the most frequent words in a text file.

    The file is read as UTF-8 and split on whitespace. The ``vocab_size`` most
    common words receive ids 0, 1, 2, ... in descending order of frequency.

    Args:
        file_path: Path of the corpus to count words in.
        vocab_size: Maximum number of words to keep.

    Returns:
        dict mapping each kept word to its integer id.
    """
    with open(file_path, 'r', encoding='utf-8') as corpus:
        word_counts = Counter(corpus.read().split())

    ranked = word_counts.most_common(vocab_size)

    return {word: rank for rank, (word, _count) in enumerate(ranked)}

def tokenize(text, word_to_id):
    """Convert whitespace-separated ``text`` into a list of integer ids.

    Words missing from ``word_to_id`` are mapped to id 0. Note that vocabularies
    produced by ``build_vocab`` also assign id 0 to their most frequent word, so
    unknown words cannot be told apart from it afterwards.
    """
    lookup = word_to_id.get
    return [lookup(token, 0) for token in text.split()]

def detokenize(tokens, id_to_word):
    """Join the words for ``tokens`` with single spaces.

    Ids that are absent from ``id_to_word`` are rendered as ``"<unk>"``.
    """
    words = (id_to_word.get(token_id, "<unk>") for token_id in tokens)
    return ' '.join(words)

# Generate Function

In [None]:
def _filter_top_k(probs, top_k):
    """Keep only the ``top_k`` most likely tokens of each row and renormalize."""
    # Clamp so that a top_k larger than the vocabulary does not make topk raise.
    k = min(top_k, probs.size(-1))
    _, kept_indices = torch.topk(probs, k)
    mask = torch.zeros_like(probs).scatter_(1, kept_indices, 1.0)

    probs = probs * mask
    return probs / probs.sum(dim=-1, keepdim=True)


def _filter_top_p(probs, top_p):
    """Keep the most likely tokens up to cumulative mass ``top_p`` and renormalize."""
    sorted_probs, sorted_indices = torch.sort(probs, descending=True)
    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
    keep_sorted = cumulative_probs < top_p

    # Shift right by one so the token that crosses the threshold is kept as
    # well, and always keep the top token so the distribution is never empty.
    keep_sorted[..., 1:] = keep_sorted[..., :-1].clone()
    keep_sorted[..., 0] = True

    # Map the keep-flags from sorted order back to vocabulary order.
    mask = torch.zeros_like(probs).scatter_(1, sorted_indices, keep_sorted.float())

    probs = probs * mask
    return probs / probs.sum(dim=-1, keepdim=True)


def advanced_generate(model, prompt, word_to_id, id_to_word, seq_len, max_length=50, temperature=1.0, top_k=0, top_p=0.9, device=DEVICE):
    """Sample ``max_length`` new words from ``model`` as a continuation of ``prompt``.

    The model only ever sees the most recent ``seq_len`` tokens, but the
    returned text contains the whole tokenized prompt followed by every
    generated word, even when that is longer than ``seq_len``.

    Args:
        model: Callable module returning logits indexed as [batch, position, vocab].
        prompt: Text to continue; split on whitespace by ``tokenize``.
        word_to_id: Mapping used to turn prompt words into ids.
        id_to_word: Mapping used to turn ids back into words.
        seq_len: Maximum number of tokens fed to the model per step.
        max_length: Number of tokens to generate.
        temperature: Positive divisor applied to the logits before softmax.
        top_k: If > 0, sample only among the ``top_k`` most likely tokens
            (clamped to the vocabulary size).
        top_p: If < 1.0, apply nucleus filtering with this cumulative mass.
        device: Device on which the input tensor is created.

    Returns:
        The detokenized prompt and generated tokens as one space-joined string.

    Raises:
        ValueError: If ``temperature`` is not positive or ``prompt`` contains no words.
    """
    if temperature <= 0:
        raise ValueError("temperature must be a positive number")

    token_ids = tokenize(prompt, word_to_id)
    if not token_ids:
        # An empty id list would become an empty float tensor and fail inside
        # the embedding lookup with an unrelated-looking error.
        raise ValueError("prompt must contain at least one word")

    model.eval()

    # Sliding context window; token_ids keeps the full history for the output.
    window = torch.tensor(token_ids[-seq_len:], dtype=torch.long, device=device).unsqueeze(0)

    with torch.no_grad():
        for _ in range(max_length):
            # Only the logits at the last position predict the next token.
            logits = model(window)[:, -1, :] / temperature
            probs = F.softmax(logits, dim=-1)

            if top_k > 0:
                probs = _filter_top_k(probs, top_k)

            if top_p < 1.0:
                probs = _filter_top_p(probs, top_p)

            next_token = torch.multinomial(probs, num_samples=1)
            token_ids.append(next_token.item())

            window = torch.cat([window, next_token], dim=1)[:, -seq_len:]

    return detokenize(token_ids, id_to_word)

# Load Model and Settings

In [None]:
def select_model():
    """Prompt the user to choose one of the ``.pth`` checkpoints in ``MODELS_DIR``.

    The checkpoints are listed in sorted order and the prompt repeats until a
    valid number is entered. The settings file is expected next to the
    checkpoint, named ``<checkpoint name without .pth>_settings.json``.

    Returns:
        Tuple ``(model_path, settings_path)``.

    Raises:
        FileNotFoundError: If ``MODELS_DIR`` contains no ``.pth`` file, or the
            chosen checkpoint has no matching settings file.
    """
    extension = ".pth"

    # os.listdir returns entries in arbitrary order; sort so that the menu
    # numbering is the same on every run.
    model_files = sorted(f for f in os.listdir(MODELS_DIR) if f.endswith(extension))

    if not model_files:
        raise FileNotFoundError(f"No .pth model files found in {MODELS_DIR} folder.")

    print("Available Models:")

    for number, name in enumerate(model_files, start=1):
        print(f"{number}. {name}")

    while True:

        try:
            choice = int(input("\nChoose a model number: "))
        except ValueError:
            print("Enter a number.")
            continue

        if not 1 <= choice <= len(model_files):
            print("Invalid choice.")
            continue

        model_file = model_files[choice - 1]
        # Swap only the trailing extension; str.replace would also rewrite any
        # earlier ".pth" occurring inside the file name.
        settings_file = model_file[:-len(extension)] + "_settings.json"

        model_path = os.path.join(MODELS_DIR, model_file)
        settings_path = os.path.join(MODELS_DIR, settings_file)

        if not os.path.exists(settings_path):
            raise FileNotFoundError(f"Missing settings file: {settings_file}")

        return model_path, settings_path

# Let the user pick a checkpoint; its settings file sits alongside it.
model_path, settings_path = select_model()

# Hyperparameters the model is rebuilt with.
with open(settings_path, "r") as settings_file:
    settings = json.load(settings_file)

VOCAB_SIZE = settings["VOCAB_SIZE"]
SEQ_LEN = settings["SEQ_LEN"]

# The vocabulary is rebuilt from the corpus rather than loaded from disk.
vocab = build_vocab(DATA_PATH, VOCAB_SIZE)

word_to_id = vocab
id_to_word = {index: word for word, index in word_to_id.items()}

model = SimpleTransformer(
    vocab_size=VOCAB_SIZE,
    embed_dim=settings["EMBED_DIM"],
    num_heads=settings["NUM_HEADS"],
    hidden_dim=settings["HIDDEN_DIM"],
    num_layers=settings["NUM_LAYERS"],
    seq_len=SEQ_LEN,
)
model = model.to(DEVICE)

# weights_only=True restricts unpickling to tensors and plain containers.
model.load_state_dict(torch.load(model_path, map_location=DEVICE, weights_only=True))
model.eval()

print("\nYou are now speaking with Shakespeare.\nType 'quit' to exit.\n")

# Interact with the LLM

In [None]:
# Number of words generated per reply; also used to separate the reply from
# the prompt that advanced_generate echoes back.
MAX_NEW_TOKENS = 40

while True:

    user_input = input("You: ")

    if user_input.strip().lower() == 'quit':

        print("\nFarewell, gentle soul.")
        break

    # A blank line tokenizes to nothing, which generation cannot start from.
    if not user_input.strip():
        continue

    response = advanced_generate(

        model=model,
        prompt=user_input,
        word_to_id=word_to_id,

        id_to_word=id_to_word,
        seq_len=SEQ_LEN,
        max_length=MAX_NEW_TOKENS,

        temperature=0.8,
        top_k=40,
        top_p=0.9)

    # The response echoes the prompt after tokenization (whitespace collapsed,
    # unknown words substituted), so its character length need not match
    # user_input. Every word in the response is whitespace-free, so the
    # generated part is at most the final MAX_NEW_TOKENS words.
    reply_words = response.split()[-MAX_NEW_TOKENS:]

    print("Shakespeare:", ' '.join(reply_words))