In [1]:
# General parameters

OUTPUT_FILE = "data/output.txt"

SEED = 42
TRAIN_SIZE = 0.80 # Fraction (0-1) of the data used for training
BATCH_SIZE = 5
NUM_BATCHES = 20000 # Training steps of the character-level training loop
NUM_EPOCHS = 10 # Epochs of the subword-level training loop
LR = 3e-4 # Learning rate
LOSS_SAMPLE_SIZE = 50 # Not referenced by the cells visible in this notebook

EMBEDDING_SIZE = 384 # Embedding vector length
CONTEXT_SIZE = 384 # Context window size
FFN_SIZE = 8 # Feed-forward hidden width, as a multiple of EMBEDDING_SIZE
HEADS = 8 # Attention heads per multi-head attention layer             # 4 8
BLOCKS = 8 # Number of stacked transformer blocks                        # 4 8

top_k = 10 # Used by GPT.generate(..., sampling="top-k") to sample among the k most likely tokens.

# The checkpoint name encodes the hyperparameters so that the "load" helper can
# parse them back. round() is used instead of int() because a float product such
# as 0.29 * 100 evaluates to 28.999999999999996 and would be truncated to 28.
model_filename = (
    f"model_TS{round(TRAIN_SIZE * 100)}_BS{BATCH_SIZE}_NB{NUM_BATCHES}_EP_{NUM_EPOCHS}_"
    f"LR{LR}_ES{EMBEDDING_SIZE}_CS{CONTEXT_SIZE}_"
    f"FS{FFN_SIZE}_H{HEADS}_B{BLOCKS}.pth"
)

In [2]:
import random
import numpy as np
import torch

# Seed every random number generator used by the notebook with the same value.
for seed_rng in (
    torch.manual_seed,
    random.seed,
    np.random.seed,
    torch.cuda.manual_seed,
    torch.cuda.manual_seed_all,
):
    seed_rng(SEED)

# Ask cuDNN for deterministic kernels and switch its benchmark mode off.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Prefer the GPU whenever one is visible to PyTorch.
device = "cpu" if not torch.cuda.is_available() else "cuda"

print("DEVICE=", device)
print("\n")

DEVICE= cuda




# Data preprocessing

### (1) Simple drama format
### (2) (1) with opera tracking

In [None]:
# (1) Simple drama format
#
# Writes every PlayerLine of data/data.csv to data/text_simple.txt. Whenever the
# next row has a different PlayerLinenumber and a non-missing Player, a blank
# line followed by that player's name is emitted before the next line.

import pandas as pd

df = pd.read_csv("data/data.csv")
num_rows = df.shape[0]

with open("data/text_simple.txt", "w") as f_out:
    for i in range(num_rows):
        data = df.iloc[i]
        f_out.write(data["PlayerLine"] + "\n")

        # The last row has no successor to compare against. The previous
        # range(num_rows - 1) loop stopped one row early and never wrote it.
        if i == num_rows - 1:
            break

        next_data = df.iloc[i + 1]
        if data["PlayerLinenumber"] != next_data["PlayerLinenumber"] and not pd.isna(next_data["Player"]):
            f_out.write("\n" + next_data["Player"] + "\n")

In [4]:
# TODO: More datasets

# Tokenization
### (1) Character level
### (2) Subword level

In [None]:
# (1) Character level

data = None
vocab = None

with open("data/text_simple.txt") as f_in:
    data = f_in.read()
    print(f"Dataset size {len(data)} characters")
    # Sort the character set: iteration order of a set of strings changes from
    # one interpreter process to the next (string hashes are randomised), so an
    # unsorted vocabulary would assign different ids after every kernel restart
    # and a saved checkpoint could no longer be decoded consistently.
    vocab = sorted(set(data))
    print(f"Corpus vocabulary size: {len(vocab)}")

vocab_size = len(vocab)

# Lookup table generator

char_to_idx = {ch: idx for idx, ch in enumerate(vocab)}
idx_to_char = dict(enumerate(vocab))


def str2tokens(str):
    """Map every character of ``str`` to its integer id in ``char_to_idx``.

    The parameter keeps its original name (which shadows the builtin inside
    this function only) so existing keyword calls keep working.
    Raises KeyError for a character that is not in the vocabulary.
    """
    return [char_to_idx[ch] for ch in str]


def tokens2str(tokens):
    """Join the characters that the ids in ``tokens`` stand for into one string."""
    return "".join(idx_to_char[token] for token in tokens)


import torch
tokenized_corpus = torch.tensor(str2tokens(data), dtype=torch.int32)

## Subword

In [None]:
import os
import torch
import pickle
from transformers import GPT2Tokenizer

# Where the per-line tokenized corpus is cached between runs.
save_dir = 'saves'
tokenized_data_path = os.path.join(save_dir, 'tokenized_corpus.pkl')
os.makedirs(save_dir, exist_ok=True)

# Pretrained GPT-2 tokenizer; when it has no pad token id, reuse the
# end-of-sequence id for padding.
tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

def data_iterator(file_path, encoding="utf-8"):
    """Return the non-blank lines of a text file, stripped of surrounding whitespace.

    Args:
        file_path: Path of the text file to read.
        encoding: Text encoding used to decode the file. Defaults to UTF-8 so
            the result does not depend on the platform locale and agrees with
            ``read_full_text``, which reads the same corpus as UTF-8.

    Returns:
        A list of stripped lines; lines that are empty after stripping are dropped.
    """
    with open(file_path, "r", encoding=encoding) as f_in:
        stripped_lines = (line.strip() for line in f_in)
        return [line for line in stripped_lines if line]

# Build the per-line token lists once and reuse the pickled copy afterwards.
# pickle executes code on load, so only point this at cache files the notebook
# itself produced.
if not os.path.exists(tokenized_data_path):
    print("Tokenized data not found. Processing and saving...")
    lines = data_iterator("data/text_simple.txt")
    tokenized_corpus = [tokenizer.encode(line) for line in lines]
    with open(tokenized_data_path, 'wb') as f:
        pickle.dump(tokenized_corpus, f)
else:
    print("Loading tokenized data from file...")
    with open(tokenized_data_path, 'rb') as f:
        tokenized_corpus = pickle.load(f)

vocab_size = tokenizer.vocab_size

# Bring every line to exactly CONTEXT_SIZE + 1 ids: longer lines are cut,
# shorter ones are right-padded with the pad id. A negative repeat count yields
# an empty list, so one expression covers both cases.
max_length = CONTEXT_SIZE + 1
padded_corpus = [
    seq[:max_length] + [tokenizer.pad_token_id] * (max_length - len(seq))
    for seq in tokenized_corpus
]


tokenized_corpus = torch.tensor(padded_corpus)
print(f"Tensor shape: {tokenized_corpus.shape}")
print(tokenized_corpus[:100])

In [11]:
import os
import torch
import pickle
from transformers import GPT2Tokenizer

# CONTEXT_SIZE is taken from the "General parameters" cell. It used to be
# re-assigned to 384 here, which silently overrode the configured value for
# every later cell (model, dataset) while model_filename still advertised the
# configured one.

tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")

# When the tokenizer has no pad token id, reuse the end-of-sequence id.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

# Cache location for the corpus tokenized as one continuous text.
save_dir = 'saves'
os.makedirs(save_dir, exist_ok=True)
tokenized_data_path = os.path.join(save_dir, 'tokenized_corpus_full.pkl')

def read_full_text(file_path):
    """Read ``file_path`` as UTF-8 and return its content without leading/trailing whitespace."""
    with open(file_path, "r", encoding="utf-8") as f_in:
        return f_in.read().strip()

# Each segment holds CONTEXT_SIZE + 1 ids, so that a training window and its
# one-token-shifted target both fit inside one segment.
step = CONTEXT_SIZE + 1

tokenized_corpus = None
if os.path.exists(tokenized_data_path):
    print("Loading tokenized data from file...")
    # pickle executes code on load: only use cache files written by this notebook.
    with open(tokenized_data_path, 'rb') as f:
        tokenized_corpus = pickle.load(f)

    # The cache file name does not encode CONTEXT_SIZE, so a cache written with
    # a different context size would otherwise be reused silently. Every
    # segment except the last must be exactly `step` long; the last may be shorter.
    cache_is_valid = (
        all(len(segment) == step for segment in tokenized_corpus[:-1])
        and all(len(segment) <= step for segment in tokenized_corpus[-1:])
    )
    if not cache_is_valid:
        print("Cached segments do not match CONTEXT_SIZE. Re-tokenizing...")
        tokenized_corpus = None
else:
    print("Tokenized data not found. Processing and saving...")

if tokenized_corpus is None:
    full_text = read_full_text("data/text_simple.txt")
    tokenized_text = tokenizer.encode(full_text)
    print(f"Total tokens in file: {len(tokenized_text)}")

    # Cut the token stream into consecutive, non-overlapping segments.
    tokenized_corpus = [
        tokenized_text[i: i + step]
        for i in range(0, len(tokenized_text), step)
    ]

    with open(tokenized_data_path, 'wb') as f:
        pickle.dump(tokenized_corpus, f)

vocab_size = tokenizer.vocab_size

# Rectangular copy for inspection: cut or right-pad every segment to max_length.
# tokenized_corpus itself stays a list of (unpadded) segments.
max_length = CONTEXT_SIZE + 1
padded_corpus = []
for seq in tokenized_corpus:
    if len(seq) > max_length:
        padded_seq = seq[:max_length]
    else:
        padded_seq = seq + [tokenizer.pad_token_id] * (max_length - len(seq))
    padded_corpus.append(padded_seq)

tokenized_corpus_tensor = torch.tensor(padded_corpus)
print(f"Tensor shape: {tokenized_corpus_tensor.shape}")
print(tokenized_corpus_tensor[:100])


Loading tokenized data from file...
Tensor shape: torch.Size([1599, 385])
tensor([[10659,   314,   198,  ...,  1309,   502,  3285],
        [  198,  5189,   345,  ...,   286,  1123,  9260],
        [  198, 13056,    86,  ...,  6465,    44,  1581],
        ...,
        [ 1921,   198,    44,  ...,  1544,   925,   257],
        [  698,  8023,   269,  ...,   465,  1266,    25],
        [  290,   994,  3197,  ...,    43,  3963,   360]])


In [None]:
# Data split (train/val)
# The first TRAIN_SIZE share of the corpus is used for training, the rest for validation.

train_partition = int(TRAIN_SIZE*len(tokenized_corpus))
train, val = tokenized_corpus[:train_partition], tokenized_corpus[train_partition:]

print(train[:100])

# Model 

In [3]:
# Custom nn.Linear :)
import math
import torch
import torch.nn as nn

class Linear(nn.Module):
    """Affine layer ``y = x @ W.T + b``, a hand-written counterpart of ``nn.Linear``.

    Args:
        in_size: Size of the last dimension of the input.
        out_size: Size of the last dimension of the output.
        bias: When False, no bias parameter is created and none is added.
    """

    def __init__(self, in_size,  out_size, bias=True):
        super(Linear, self).__init__()
        # torch.empty states explicitly that the storage is uninitialised;
        # reset_parameters() fills it right below.
        self.weight = nn.Parameter(torch.empty(out_size, in_size))

        if bias:
            self.bias = nn.Parameter(torch.empty(out_size))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        """Initialise the weight with Kaiming-uniform and the bias with U(-1/sqrt(fan_in), 1/sqrt(fan_in))."""
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x):
        """Apply the layer to ``x`` (last dimension must equal ``in_size``)."""
        out = x @ self.weight.T
        # With bias=False the attribute is None; adding it used to raise TypeError.
        if self.bias is not None:
            out = out + self.bias
        return out

In [4]:
# Custom LayerNorm

import torch
import torch.nn as nn

class LayerNorm(nn.Module):
    """Layer normalisation over the trailing ``normalized_shape`` dimensions.

    Args:
        normalized_shape: An int (size of the last dimension) or a sequence of
            ints describing the trailing dimensions to normalise over.
        eps: Constant added to the variance before the square root.
        elementwise_affine: When True, a learnable scale (``weight``, ones) and
            shift (``bias``, zeros) of shape ``normalized_shape`` are applied.
    """

    def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
        super(LayerNorm, self).__init__()
        if isinstance(normalized_shape, int):
            normalized_shape = (normalized_shape,)
        self.normalized_shape = tuple(normalized_shape)
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if self.elementwise_affine:
            self.weight = nn.Parameter(torch.ones(self.normalized_shape))
            self.bias = nn.Parameter(torch.zeros(self.normalized_shape))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)

    def forward(self, x):
        """Normalise ``x`` to zero mean / unit (biased) variance over the trailing dims."""
        # Reduce over as many trailing dimensions as normalized_shape has.
        # Previously only the last dimension was reduced, which was wrong for a
        # multi-dimensional normalized_shape; for a single dimension this is the
        # same reduction as before.
        dims = tuple(range(-len(self.normalized_shape), 0))
        mean = x.mean(dim=dims, keepdim=True)
        var = x.var(dim=dims, keepdim=True, unbiased=False)
        x_normalized = (x - mean) / torch.sqrt(var + self.eps)
        if self.elementwise_affine:
            x_normalized = self.weight * x_normalized + self.bias
        return x_normalized


In [5]:
# Attention Head
from torch.nn import functional as F

class AttentionHead(nn.Module):
    """One causal self-attention head.

    Projects the input to query/key/value vectors of width ``headspace`` and
    lets every position attend only to itself and to earlier positions.
    """

    def __init__(self, headspace):
        super().__init__()

        # Created in query, key, value order.
        self.query = Linear(EMBEDDING_SIZE, headspace)
        self.key = Linear(EMBEDDING_SIZE, headspace)
        self.value = Linear(EMBEDDING_SIZE, headspace)

        # Ones strictly above the diagonal mark the "future" positions.
        self.register_buffer(
            'mask',
            torch.triu(torch.ones(CONTEXT_SIZE, CONTEXT_SIZE), diagonal=1),
        )

    def forward(self, x):
        """Return the attention output for ``x`` of shape (batch, seq_len, emb_len)."""
        _, seq_len, emb_len = x.shape
        q, k, v = self.query(x), self.key(x), self.value(x)

        # NOTE(review): the scores are scaled by 1/sqrt(emb_len), the width of
        # the *input*, not of the projected keys (headspace). Kept as is:
        # changing the scale changes what existing checkpoints compute.
        scale = 1 / torch.sqrt(torch.tensor(emb_len, dtype=torch.float32))
        scores = q @ k.transpose(-2, -1) * scale

        # Future positions get -inf so softmax assigns them zero weight.
        is_future = self.mask[:seq_len, :seq_len] == 1
        scores = scores.masked_fill(is_future, float('-inf'))

        return F.softmax(scores, -1) @ v


# Demo
# att_head = AttentionHead(16, 10, 10)
# print(att_head.forward(torch.Tensor(64, 5, 16)).shape)

In [6]:
# Multihead Attention

class MultiHeadAttention(nn.Module):
    """HEADS attention heads run side by side, concatenated and mixed by a linear layer."""

    def __init__(self):
        super().__init__()

        # Every head works in an equal slice of the embedding width.
        per_head_width = EMBEDDING_SIZE // HEADS
        head_modules = []
        for _ in range(HEADS):
            head_modules.append(AttentionHead(per_head_width))
        self.heads = nn.ModuleList(head_modules)
        self.linear = Linear(EMBEDDING_SIZE, EMBEDDING_SIZE)

    def forward(self, x):
        """Concatenate all head outputs along the feature axis and project them."""
        head_outputs = [head(x) for head in self.heads]
        return self.linear(torch.cat(head_outputs, dim=-1))
    
# mta = MultiHeadAttention(2)
# x = torch.rand(1, 7, EMBEDDING_SIZE)
# print(mta.forward(x))

In [7]:
# Transformer Feed Forward Module

class FFN_Block(nn.Module):
    """Position-wise feed-forward network: expand, ReLU, project back."""

    def __init__(self):
        super().__init__()

        hidden_width = EMBEDDING_SIZE * FFN_SIZE
        expand = Linear(EMBEDDING_SIZE, hidden_width)
        activation = nn.ReLU(inplace=True)
        project = Linear(hidden_width, EMBEDDING_SIZE)
        # Order inside the Sequential (0, 1, 2) defines the state_dict keys.
        self.ffn = nn.Sequential(expand, activation, project)

    def forward(self, x):
        """Apply the feed-forward network to ``x``."""
        return self.ffn(x)

In [8]:
# Transformer block 
# Multihead, FFN, LN (Pre-LN)

# TODO: CustomLayerNorm

class TransformerBlock(nn.Module):
    """Pre-LN transformer block: normalise, attend, add; normalise, feed-forward, add."""

    def __init__(self):
        super().__init__()

        # NOTE(review): one LayerNorm instance (one set of scale/shift
        # parameters) is used before both sub-layers. Left unchanged because a
        # second norm would add state_dict keys that saved checkpoints lack.
        self.ln = LayerNorm(EMBEDDING_SIZE)
        self.causal_attention = MultiHeadAttention()
        self.ffn = FFN_Block()

    def forward(self, x):
        """Run attention and feed-forward sub-layers, each with a residual connection."""
        after_attention = x + self.causal_attention(self.ln(x))
        return after_attention + self.ffn(self.ln(after_attention))

In [9]:
# (1) Model with token_embeds for single char tokenization

import torch
import torch.nn as nn
import math

def get_sinusoidal_positional_encoding(seq_len, embedding_dim, device='cpu'):
    """Build the fixed sinusoidal position table.

    Even columns hold ``sin(pos * w_i)`` and odd columns ``cos(pos * w_i)`` with
    ``w_i = exp(-2i * ln(10000) / embedding_dim)``.

    Args:
        seq_len: Number of positions (rows).
        embedding_dim: Width of each encoding (columns); may be odd.
        device: Device on which the table is created.

    Returns:
        Float tensor of shape ``(seq_len, embedding_dim)``.
    """
    position = torch.arange(seq_len, dtype=torch.float, device=device).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, embedding_dim, 2, dtype=torch.float, device=device) *
                         -(math.log(10000.0) / embedding_dim))
    angles = position * div_term  # (seq_len, ceil(embedding_dim / 2))

    pe = torch.zeros(seq_len, embedding_dim, device=device)
    pe[:, 0::2] = torch.sin(angles)
    # For an odd embedding_dim there is one odd column fewer than even columns;
    # without the slice the assignment raised a shape-mismatch error.
    pe[:, 1::2] = torch.cos(angles[:, : embedding_dim // 2])
    return pe


class GPT(nn.Module):
    """Decoder-only transformer language model.

    Token embeddings plus fixed sinusoidal position encodings go through
    ``BLOCKS`` transformer blocks, a final LayerNorm and a linear projection to
    vocabulary logits.
    """

    def __init__(self):
        super().__init__()

        self.token_embeds = nn.Embedding(vocab_size, EMBEDDING_SIZE)
        # Not used by forward() (positions come from the sinusoidal table), but
        # kept so the state_dict keys stay compatible with saved checkpoints.
        self.pos_embeds = nn.Embedding(CONTEXT_SIZE, EMBEDDING_SIZE)
        self.blocks = nn.Sequential(*[TransformerBlock() for _ in range(BLOCKS)])
        self.ln = LayerNorm(EMBEDDING_SIZE)
        self.last_embeddings = Linear(EMBEDDING_SIZE, vocab_size)

    def forward(self, x, targets=None):
        """Compute next-token logits and, optionally, the cross-entropy loss.

        Args:
            x: Token ids of shape (batch, seq_len).
            targets: Optional token ids with batch * seq_len elements.

        Returns:
            ``(logits, loss)``; ``logits`` has shape (batch, seq_len, vocab_size)
            and ``loss`` is None when ``targets`` is None.
        """
        batch_size, seq_len = x.shape

        tok_emb = self.token_embeds(x)
        pos_emb = get_sinusoidal_positional_encoding(seq_len, EMBEDDING_SIZE, device=x.device)
        pos_emb = pos_emb.unsqueeze(0).expand(batch_size, -1, -1)
        x = tok_emb + pos_emb
        x = self.blocks(x)
        x = self.ln(x)
        logits = self.last_embeddings(x)

        loss = None
        if targets is not None:
            logits_flat = logits.view(batch_size * seq_len, -1)
            targets_flat = targets.view(-1).long()
            loss = F.cross_entropy(logits_flat, targets_flat)

        return logits, loss

    def generate(self, inputs, length, sampling="argmax"):
        """Append ``length`` generated tokens to ``inputs``.

        Args:
            inputs: Token ids of shape (batch, seq_len) used as the prompt.
            length: Number of tokens to generate.
            sampling: ``"argmax"`` for greedy decoding, or ``"top-k"`` (the
                spelling ``"top_k"`` is accepted too) to sample among the
                ``top_k`` most likely tokens.

        Returns:
            Tensor of shape (batch, seq_len + length).

        Raises:
            ValueError: If ``sampling`` is not one of the supported names.
        """
        # An unknown name used to fall through both branches and crash on the
        # unbound ``next_input``; fail early with a clear message instead.
        if sampling not in ("argmax", "top-k", "top_k"):
            raise ValueError(
                f"Unknown sampling method '{sampling}'; expected 'argmax' or 'top-k'."
            )

        # Generation never back-propagates, so skip building autograd graphs.
        with torch.no_grad():
            for _ in range(length):
                # Only the last CONTEXT_SIZE tokens fit the causal mask.
                _inputs = inputs[:, -CONTEXT_SIZE:]
                logits, _ = self.forward(_inputs)
                logits = logits[:, -1, :]
                probs = torch.softmax(logits, dim=-1)

                if sampling == "argmax":
                    next_input = torch.argmax(probs, dim=-1, keepdim=True)
                else:
                    # topk raises when k exceeds the vocabulary size.
                    k = min(top_k, probs.size(-1))
                    top_k_probs, top_k_indices = torch.topk(probs, k=k, dim=-1)
                    top_k_probs = top_k_probs / torch.sum(top_k_probs, dim=-1, keepdim=True)
                    next_input = torch.multinomial(top_k_probs, num_samples=1)
                    next_input = torch.gather(top_k_indices, dim=-1, index=next_input)

                inputs = torch.cat((inputs, next_input), dim=1)
        return inputs


# Training

In [16]:
def batch(data):
    """Draw BATCH_SIZE random windows of CONTEXT_SIZE tokens from ``data``.

    Returns ``(inputs, targets)`` moved to ``device``; each target window is its
    input window shifted one position to the right.
    """
    starts = torch.randint(len(data) - CONTEXT_SIZE, (BATCH_SIZE,)).tolist()
    inputs = torch.stack([data[s:s + CONTEXT_SIZE] for s in starts])
    targets = torch.stack([data[s + 1:s + CONTEXT_SIZE + 1] for s in starts])

    return inputs.to(device), targets.to(device)

In [None]:
import os
import torch
import torch.nn.functional as F
from torcheval.metrics import Perplexity
import wandb
from alive_progress import alive_bar

# The W&B API key must not live in the notebook: read it from the environment.
# When WANDB_API_KEY is unset, key=None lets wandb use its own credential lookup.
wandb.login(key=os.environ.get("WANDB_API_KEY"))
wandb.init(project="gpt-training")

model = GPT().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
# Separate metric objects: validation batches used to be accumulated into the
# same object as the training batches, so both reported numbers were a mix.
perplexity_metric = Perplexity(device=device)
val_perplexity_metric = Perplexity(device=device)

# Make sure the checkpoint directory exists before spending time on training.
checkpoint_dir = os.path.join(os.getcwd(), "saves")
os.makedirs(checkpoint_dir, exist_ok=True)

with alive_bar(NUM_BATCHES, title="Training Progress", bar="blocks", spinner="dots_waves") as bar:
    # Each iteration is one optimisation step on one random batch; the loop
    # variable keeps the name used in the logged "epoch" field.
    for epoch in range(NUM_BATCHES):
        model.train()
        # batch() already returns tensors on `device`.
        inputs, targets = batch(train)
        logits, loss = model(inputs, targets)

        # Running perplexity over all training batches seen so far.
        perplexity_metric.update(logits, targets)
        current_perplexity = perplexity_metric.compute().item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        wandb.log({
            "epoch": epoch,
            "train_loss": loss.item(),
            "train_perplexity": current_perplexity
        })

        if epoch % 100 == 0:
            model.eval()
            with torch.no_grad():
                val_input, val_target = batch(val)
                val_logits, val_loss = model(val_input, val_target)
                # Reset so the value describes this validation batch only.
                val_perplexity_metric.reset()
                val_perplexity_metric.update(val_logits, val_target)
                val_perplexity = val_perplexity_metric.compute().item()

                wandb.log({
                    "val_loss": val_loss.item(),
                    "val_perplexity": val_perplexity
                })

                print(f"Validation: ðŸ”µ Val Loss = {val_loss.item():.5f} | "
                    f"ðŸ”´ Val Perplexity = {val_perplexity:.5f}")

            print(f"Epoch {epoch:04}/{NUM_BATCHES}: "
                  f"ðŸŸ¢ Train Loss = {loss.item():.5f} | "
                  f"ðŸ”µ Train Perplexity = {current_perplexity:.5f}")

        bar()

torch.save(model.state_dict(), os.path.join(checkpoint_dir, model_filename))
wandb.finish()

# Training (Subword)

In [None]:
import os
import torch
from torch.utils.data import Dataset, DataLoader, random_split
from tqdm import tqdm
from alive_progress import alive_bar
from transformers import GPT2Tokenizer
import wandb
import torchmetrics
import torch.nn.functional as F

class TextDataset(Dataset):
    """(input, target) windows of ``context_size`` tokens cut from token sequences.

    Sequences shorter than ``context_size + 1`` contribute nothing. Longer ones
    are walked in strides of ``context_size``; every target is its input
    shifted by one token. Both tensors are ``torch.long``.
    """

    def __init__(self, tokenized_corpus, context_size):
        self.tokenized_corpus = tokenized_corpus
        self.context_size = context_size
        self.data = []
        print(len(self.tokenized_corpus))

        window = context_size + 1
        for seq in self.tokenized_corpus:
            seq_len = len(seq)
            if seq_len == window:
                starts = (0,)
            elif seq_len > window:
                starts = range(0, seq_len - context_size, context_size)
            else:
                starts = ()

            for start in starts:
                # Every start leaves room for a full input and a full target.
                stop = start + context_size
                pair = (
                    torch.tensor(seq[start:stop], dtype=torch.long),
                    torch.tensor(seq[start + 1:stop + 1], dtype=torch.long),
                )
                self.data.append(pair)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


full_dataset = TextDataset(tokenized_corpus, CONTEXT_SIZE)
total_size = len(full_dataset)
train_size = int(TRAIN_SIZE * total_size)
val_size = total_size - train_size

# Use a dedicated, seeded generator for the split. With the global RNG the
# partition depended on how many random numbers earlier cells had consumed, so
# re-running the split (here or in an evaluation cell) produced a different
# train/validation partition.
split_generator = torch.Generator().manual_seed(SEED)
train_dataset, val_dataset = random_split(
    full_dataset, [train_size, val_size], generator=split_generator
)
dataloader_train = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
dataloader_val = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=True)


# The W&B API key must not live in the notebook: read it from the environment.
# When WANDB_API_KEY is unset, key=None lets wandb use its own credential lookup.
wandb.login(key=os.environ.get("WANDB_API_KEY"))
wandb.init(project="gpt-training", name=model_filename)

model = GPT().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
# Separate metric objects: validation batches used to be accumulated into the
# same object as the training batches, so both reported numbers were a mix.
perplexity_metric = torchmetrics.Perplexity().to(device)
val_perplexity_metric = torchmetrics.Perplexity().to(device)

# Make sure the checkpoint directory exists before spending time on training.
checkpoint_dir = os.path.join(os.getcwd(), "saves")
os.makedirs(checkpoint_dir, exist_ok=True)

# bar() is called once per batch, so the bar total is the number of batches.
total_steps = NUM_EPOCHS * len(dataloader_train)
# One persistent iterator walks through the whole validation set;
# next(iter(dataloader_val)) always returned the same first batch.
val_batches = iter(dataloader_val)

with alive_bar(total_steps, title="Training Progress", bar="blocks", spinner="dots_waves") as bar:
    for epoch in range(NUM_EPOCHS):
        for batch_idx, (inputs, targets) in enumerate(dataloader_train):
            model.train()
            inputs, targets = inputs.to(device), targets.to(device)
            logits, loss = model(inputs, targets)

            # Running perplexity over all training batches seen so far.
            perplexity_metric.update(logits, targets)
            current_perplexity = perplexity_metric.compute().item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            wandb.log({
                "epoch": epoch,
                "train_loss": loss.item(),
                "train_perplexity": current_perplexity
            })

            if batch_idx % 10 == 0:
                model.eval()
                with torch.no_grad():
                    try:
                        val_inputs, val_targets = next(val_batches)
                    except StopIteration:
                        # Validation set exhausted: start over from its first batch.
                        val_batches = iter(dataloader_val)
                        val_inputs, val_targets = next(val_batches)
                    val_inputs, val_targets = val_inputs.to(device), val_targets.to(device)
                    val_logits, val_loss = model(val_inputs, val_targets)

                    # Reset so the value describes this validation batch only.
                    val_perplexity_metric.reset()
                    val_perplexity_metric.update(val_logits, val_targets)
                    val_perplexity = val_perplexity_metric.compute().item()

                    wandb.log({
                        "val_loss": val_loss.item(),
                        "val_perplexity": val_perplexity
                    })

                    print(f"Validation: ðŸ”µ Val Loss = {val_loss.item():.5f} | "
                        f"ðŸ”´ Val Perplexity = {val_perplexity:.5f}")

                print(f"Epoch {epoch:02}/{NUM_EPOCHS}: "
                    f"ðŸŸ¢ Train Loss = {loss.item():.5f} | "
                    f"ðŸ”µ Train Perplexity = {current_perplexity:.5f}")

            bar()

torch.save(model.state_dict(), os.path.join(checkpoint_dir, model_filename))
wandb.finish()

# Generate

In [None]:
import os
import torch
from tqdm.notebook import tqdm
from IPython.display import display
import ipywidgets as widgets

saves_dir = "saves"
models = [f for f in os.listdir(saves_dir) if f.endswith(".pth")]

dropdown = widgets.Dropdown(
    options=models,
    description="Model:",
    layout=widgets.Layout(width='50%')
)
display(dropdown)

button = widgets.Button(description="Load Model")
output = widgets.Output()

def load_model(b):
    """Button callback: load the checkpoint selected in the dropdown.

    The globals ``model`` and ``device`` are replaced only after the weights
    were loaded successfully. Before, ``model`` was bound to a freshly
    initialised GPT ahead of loading, so a failed load left an untrained model
    behind that ``generate_text`` would then use.
    """
    global model, device
    with output:
        output.clear_output()
        selected_model_path = os.path.join(saves_dir, dropdown.value)
        target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        try:
            candidate = GPT().to(target_device)
            checkpoint = torch.load(selected_model_path, map_location=target_device, weights_only=True)
            # Accept both a bare state_dict and a {"state_dict": ...} wrapper.
            if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
                candidate.load_state_dict(checkpoint["state_dict"])
            else:
                candidate.load_state_dict(checkpoint)
            candidate.eval()
        except Exception as e:
            print("Error loading model state dict:", e)
            return
        model, device = candidate, target_device
        print(f"Loaded model (state_dict): {dropdown.value}")

button.on_click(load_model)
display(button, output)

def generate_text():
    """Sample 5000 tokens from the loaded model, print them and save them to a file.

    Decoding uses the character-level ``tokens2str`` lookup.
    """
    if 'model' not in globals():
        print("Please load a model first!")
        return

    # NOTE(review): the file is named after the global model_filename, not
    # after the checkpoint chosen in the dropdown - confirm this is intended.
    OUTPUT_FILE = f"data/{model_filename}__top-k={top_k}.txt"
    context = torch.zeros((1, 1), dtype=torch.long, device=device)

    # No gradients are needed while sampling.
    with torch.no_grad():
        with tqdm(total=5000, desc="Generating Text", unit="token") as pbar:
            for _ in range(5000):
                # generate() returns the context with one new token appended.
                context = model.generate(context, length=1, sampling="top-k")
                pbar.update(1)

    generated_text = tokens2str(context[0].tolist())
    print("Generated Text:")
    print(generated_text)
    # Open the file only once the text exists, so a failed generation does not
    # leave an empty or truncated output file behind.
    with open(OUTPUT_FILE, "w", encoding="utf-8") as f:
        f.write(generated_text)
    print(f"\nGenerated text saved to {OUTPUT_FILE}")

generate_button = widgets.Button(description="Generate Text")
generate_button.on_click(lambda b: generate_text())
display(generate_button)


In [None]:
# Helper block to load the model :)
import re

#   TS(\d+)   captures the TRAIN_SIZE*100 as an integer
#   BS(\d+)   captures BATCH_SIZE
#   NB(\d+)   captures NUM_BATCHES
#   EP_(\d+)  captures NUM_EPOCHS
#   LR(...)   captures LR: integer, decimal or scientific notation (e.g. 0.0003, 1e-05)
#   ES(\d+)   captures EMBEDDING_SIZE
#   CS(\d+)   captures CONTEXT_SIZE
#   FS(\d+)   captures FFN_SIZE
#   H(\d+)    captures HEADS
#   B(\d+)    captures BLOCKS
pattern = re.compile(
    r"model_TS(\d+)_BS(\d+)_NB(\d+)_EP_(\d+)"
    # f"{LR}" renders small learning rates as e.g. "1e-05", which the former
    # integer/decimal-only group could not match.
    r"_LR(\d+(?:\.\d+)?(?:[eE][+-]?\d+)?)"
    r"_ES(\d+)_CS(\d+)_FS(\d+)_H(\d+)_B(\d+)\.pth"
)

# Result key and converter for each capture group, in group order.
_FILENAME_FIELDS = (
    ("TRAIN_SIZE", lambda raw: int(raw) / 100.0),  # stored as int(TRAIN_SIZE * 100)
    ("BATCH_SIZE", int),
    ("NUM_BATCHES", int),
    ("NUM_EPOCHS", int),
    ("LR", float),
    ("EMBEDDING_SIZE", int),
    ("CONTEXT_SIZE", int),
    ("FFN_SIZE", int),
    ("HEADS", int),
    ("BLOCKS", int),
)

def parse_model_filename(filename: str):
    """Recover the hyperparameters encoded in a checkpoint file name.

    Args:
        filename: Name such as
            ``model_TS80_BS5_NB20000_EP_10_LR0.0003_ES384_CS384_FS8_H8_B8.pth``.

    Returns:
        Dict with the keys TRAIN_SIZE (float fraction), BATCH_SIZE,
        NUM_BATCHES, NUM_EPOCHS, LR (float), EMBEDDING_SIZE, CONTEXT_SIZE,
        FFN_SIZE, HEADS and BLOCKS.

    Raises:
        ValueError: If the whole file name does not match ``pattern``.
    """
    # fullmatch: a name with trailing text after ".pth" is not a valid checkpoint name.
    match = pattern.fullmatch(filename)
    if not match:
        raise ValueError(f"Filename '{filename}' does not match the expected pattern.")

    return {
        key: convert(raw)
        for (key, convert), raw in zip(_FILENAME_FIELDS, match.groups())
    }


# Overwrite the notebook-level hyperparameters with the values encoded in the
# checkpoint name, so that GPT() is built with the architecture that was saved.
model_filename = "model_TS80_BS5_NB20000_EP_10_LR0.0003_ES384_CS384_FS8_H8_B8.pth"
parsed_values = parse_model_filename(model_filename)

(
    TRAIN_SIZE,
    BATCH_SIZE,
    NUM_BATCHES,
    NUM_EPOCHS,
    LR,
    EMBEDDING_SIZE,
    CONTEXT_SIZE,
    FFN_SIZE,
    HEADS,
    BLOCKS,
) = (
    parsed_values[key]
    for key in (
        "TRAIN_SIZE",
        "BATCH_SIZE",
        "NUM_BATCHES",
        "NUM_EPOCHS",
        "LR",
        "EMBEDDING_SIZE",
        "CONTEXT_SIZE",
        "FFN_SIZE",
        "HEADS",
        "BLOCKS",
    )
)

print("Parsed and assigned:")
# Labels are left-aligned in a 16-character column ("NAME:" plus padding).
for label, value in (
    ("TRAIN_SIZE", TRAIN_SIZE),
    ("BATCH_SIZE", BATCH_SIZE),
    ("NUM_BATCHES", NUM_BATCHES),
    ("NUM_EPOCHS", NUM_EPOCHS),
    ("LR", LR),
    ("EMBEDDING_SIZE", EMBEDDING_SIZE),
    ("CONTEXT_SIZE", CONTEXT_SIZE),
    ("FFN_SIZE", FFN_SIZE),
    ("HEADS", HEADS),
    ("BLOCKS", BLOCKS),
):
    print(f"{label + ':':<16}", value)

In [12]:
import os
import torch
from tqdm.notebook import tqdm
from IPython.display import display
import ipywidgets as widgets

saves_dir = "saves"
models = [f for f in os.listdir(saves_dir) if f.endswith(".pth")]

dropdown = widgets.Dropdown(
    options=models,
    description="Model:",
    layout=widgets.Layout(width='50%')
)
display(dropdown)

button = widgets.Button(description="Load Model")
output = widgets.Output()

def load_model(b):
    """Button callback: load the checkpoint selected in the dropdown.

    The globals ``model`` and ``device`` are replaced only after the weights
    were loaded successfully. Before, ``model`` was bound to a freshly
    initialised GPT ahead of loading, so a failed load left an untrained model
    behind that ``generate_text`` would then use.
    """
    global model, device
    with output:
        output.clear_output()
        selected_model_path = os.path.join(saves_dir, dropdown.value)
        target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        try:
            candidate = GPT().to(target_device)
            checkpoint = torch.load(selected_model_path, map_location=target_device, weights_only=True)
            # Accept both a bare state_dict and a {"state_dict": ...} wrapper.
            if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
                candidate.load_state_dict(checkpoint["state_dict"])
            else:
                candidate.load_state_dict(checkpoint)
            candidate.eval()
        except Exception as e:
            print("Error loading model state dict:", e)
            return
        model, device = candidate, target_device
        print(f"Loaded model (state_dict): {dropdown.value}")

button.on_click(load_model)
display(button, output)

def generate_text():
    """Sample 5000 tokens from the loaded model, print them and save them to a file.

    Decoding uses the GPT-2 tokenizer and skips special tokens.
    """
    if 'model' not in globals():
        print("Please load a model first!")
        return

    # NOTE(review): the file is named after the global model_filename, not
    # after the checkpoint chosen in the dropdown - confirm this is intended.
    OUTPUT_FILE = f"data/{model_filename}__top-k={top_k}.txt"
    context = torch.zeros((1, 1), dtype=torch.long, device=device)

    # No gradients are needed while sampling.
    with torch.no_grad():
        with tqdm(total=5000, desc="Generating Text", unit="token") as pbar:
            for _ in range(5000):
                # generate() returns the context with one new token appended.
                context = model.generate(context, length=1, sampling="top-k")
                pbar.update(1)

    # Use GPT2Tokenizer to decode the tokens.
    tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")
    token_ids = context[0].tolist()
    generated_text = tokenizer.decode(token_ids, skip_special_tokens=True)

    print("Generated Text:")
    print(generated_text)
    # Open the file only once the text exists, so a failed generation does not
    # leave an empty or truncated output file behind.
    with open(OUTPUT_FILE, "w", encoding="utf-8") as f:
        f.write(generated_text)
    print(f"\nGenerated text saved to {OUTPUT_FILE}")


generate_button = widgets.Button(description="Generate Text")
generate_button.on_click(lambda b: generate_text())
display(generate_button)

Dropdown(description='Model:', layout=Layout(width='50%'), options=('model_TS80_BS5_NB20000_EP_10_LR0.0003_ES3…

Button(description='Load Model', style=ButtonStyle())

Output()

Button(description='Generate Text', style=ButtonStyle())

Generating Text:   0%|          | 0/5000 [00:00<?, ?token/s]

Generated Text:
!

HAMLET
I, it, sir.


HAMLET
That's one of me, this, I do not, but yet
this, it as that which ever of mine honour,
ithee, that the more, it go to the
and the most capital crimes were
great, I will follow you, for the
the hall of them that cannot be pit of the
were out of them that ever kept the
his recognubber may read mine is done, as oft,
favoury upon the rich men that I may, which if
of.

Ghost
What would not, but I am so, I had not, I had not, who
to extend to have them up. But, I am much, if you, my
and more, in them that I would put your
that your own. An I had
inish your ways or, and, I have
my son to say nothing but in this matter. I do it
to me the love a thing. I have
and altogether given me, in the love in the better for
guts for the titheth, and in my
gods and twenty years, for the better of your
of the very soul: but in this
their courage that were your own rede. I have made to make you.
Enter ROSALIND

SILIP
Come, come, I do you are too hard, and so much

In [None]:
import os
import time
import torch
from torch.utils.data import DataLoader, random_split
import nltk
from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction
from transformers import GPT2Tokenizer
from tqdm.notebook import tqdm  # Use tqdm for notebooks
import numpy as np

nltk.download('punkt')
# nltk.download('punkt_tab') 

SAVES_DIR = "saves"
MODEL_FILENAME = "model_TS80_BS5_NB20000_EP_10_LR0.0003_ES384_CS384_FS8_H8_B8.pth" 
MODEL_PATH = os.path.join(SAVES_DIR, MODEL_FILENAME)

base_model_name = os.path.splitext(MODEL_FILENAME)[0]
# The cache names carry a "continuation" tag: files written by the earlier
# protocol (hypothesis = prompt + generation) must not be reused for scoring.
REFS_FILE = base_model_name + "_continuation_refs.npy"
HYPS_FILE = base_model_name + "_continuation_hyps.npy"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# -------- Load Saved Model --------
model = GPT().to(device)
checkpoint = torch.load(MODEL_PATH, map_location=device, weights_only=True)
# Accept both a bare state_dict and a {"state_dict": ...} wrapper.
if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
    model.load_state_dict(checkpoint["state_dict"])
else:
    model.load_state_dict(checkpoint)
model.eval()

# -------- Prepare Validation Dataset --------
full_dataset = TextDataset(tokenized_corpus, CONTEXT_SIZE)
total_size = len(full_dataset)
train_size = int(TRAIN_SIZE * total_size)
val_size = total_size - train_size

# A dedicated seeded generator makes this split reproducible. It equals the
# training-time partition only if training seeded random_split the same way;
# otherwise "validation" windows here may have been seen during training.
split_generator = torch.Generator().manual_seed(SEED)
_, val_dataset = random_split(full_dataset, [train_size, val_size], generator=split_generator)
dataloader_val = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=True)

# -------- Initialize Tokenizer --------
tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")

# -------- Compute BLEU Score --------
smooth = SmoothingFunction().method1

if os.path.exists(REFS_FILE) and os.path.exists(HYPS_FILE):
    references = np.load(REFS_FILE, allow_pickle=True).tolist()
    hypotheses = np.load(HYPS_FILE, allow_pickle=True).tolist()
    print("Loaded saved references and hypotheses from:")
    print(f"  {REFS_FILE}")
    print(f"  {HYPS_FILE}")
else:
    references = []
    hypotheses = []
    pbar = tqdm(dataloader_val, desc="Processing Batches", leave=True)
    for inputs, _targets in pbar:
        start_time = time.perf_counter() 

        inputs = inputs.to(device)

        # Prompt with the first half of each window and compare the generated
        # continuation with the true second half. Before, the hypothesis was
        # "whole window + generation" and the reference was the window shifted
        # by one token, so the hypothesis contained the reference almost
        # verbatim and BLEU was inflated regardless of model quality.
        prompt_len = inputs.size(1) // 2
        continuation_len = inputs.size(1) - prompt_len
        with torch.no_grad():
            outputs = model.generate(inputs[:, :prompt_len], length=continuation_len, sampling="top-k")

        target_ids = inputs[:, prompt_len:].tolist()
        output_ids = outputs[:, prompt_len:].tolist()
        ref_texts = tokenizer.batch_decode(target_ids, skip_special_tokens=True)
        hyp_texts = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
        
        # corpus_bleu expects a list of reference token lists per hypothesis.
        batch_refs = [[nltk.word_tokenize(ref)] for ref in ref_texts]
        batch_hyps = [nltk.word_tokenize(hyp) for hyp in hyp_texts]
        
        references.extend(batch_refs)
        hypotheses.extend(batch_hyps)
        
        batch_time = time.perf_counter() - start_time
        avg_sample_time = batch_time / inputs.size(0)
        
        pbar.set_postfix({
            'batch_time (s)': f"{batch_time:.2f}",
            'avg_sample_time (s)': f"{avg_sample_time:.4f}"
        })
    
    np.save(REFS_FILE, np.array(references, dtype=object))
    np.save(HYPS_FILE, np.array(hypotheses, dtype=object))
    print("Saved references and hypotheses to:")
    print(f"  {REFS_FILE}")
    print(f"  {HYPS_FILE}")

bleu_score = corpus_bleu(references, hypotheses, smoothing_function=smooth)
print("BLEU score: {:.4f}".format(bleu_score))

bleu_txt_file = base_model_name + "_bleu.txt"
with open(bleu_txt_file, 'w') as f:
    f.write("BLEU score: {:.4f}".format(bleu_score))

In [None]:
import wandb
import pandas as pd

api = wandb.Api()
project_path = "ldg2875/gpt-training"
runs = api.runs(project_path)

# One flat record per finished run. Summary values come after the three
# identifying fields and config values last, so on a key clash the later
# source wins.
data = [
    {
        "run_id": run.id,
        "name": run.name,
        "state": run.state,
        **run.summary._json_dict,
        **run.config,
    }
    for run in runs
    if run.state == "finished"
]

df = pd.DataFrame(data)
# Same table twice: comma-separated and tab-separated.
for export_path, export_options in (
    ("wandb_run_data.csv", {}),
    ("wandb_run_data.txt", {"sep": "\t"}),
):
    df.to_csv(export_path, index=False, **export_options)
print("Export complete! Files saved as 'wandb_run_data.csv' and 'wandb_run_data.txt'.")


In [None]:
import pandas as pd
import matplotlib.pyplot as plt

csv_file = "wandb_run_data.csv"
df = pd.read_csv(csv_file)
# Coerce both metrics to numbers (unparseable entries become NaN) and keep
# only the runs where both are present.
for metric_column in ('val_loss', 'val_perplexity'):
    df[metric_column] = pd.to_numeric(df[metric_column], errors='coerce')
df = df.dropna(subset=['val_loss', 'val_perplexity'])
best_val_loss = df.sort_values('val_loss').head(5)
best_val_perplexity = df.sort_values('val_perplexity').head(5)


def _plot_top_models(frame, metric_column, metric_label, color):
    """Bar chart of ``metric_column`` per run name for the runs in ``frame``."""
    plt.figure(figsize=(8, 6))
    plt.bar(frame['name'], frame[metric_column], color=color)
    plt.title(f"Top 5 Models by {metric_label}")
    plt.xlabel("Model Name")
    plt.ylabel(metric_label)
    plt.xticks(rotation=45)
    plt.tight_layout()
    plt.show()


_plot_top_models(best_val_loss, 'val_loss', "Validation Loss", 'skyblue')
_plot_top_models(best_val_perplexity, 'val_perplexity', "Validation Perplexity", 'salmon')
