In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
from collections import Counter

import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

# Project-local imports: order kept as-is because the star imports can shadow one another
from llama.data_pipeline import tiny_shakespeare
from llama.data_pipeline import dataloader
from llama.model.tokenizer import CharacterTokenizer
from llama.model.custom_layers import *
from llama.model.custom_blocks import *
from llama.model.model import Llama
from llama.constants import *


# Use the GPU when PyTorch reports CUDA as available, otherwise fall back to the CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
# Hyperparameters shared by the data pipeline, the model constructor and the training cells.
# Each key appears exactly once: a repeated key in a dict literal is silently
# overridden by its last occurrence, which hides the value that is really in use.
CONFIG = {
    'vocab_size': -1,        # placeholder; overwritten once the dataset vocabulary has been built
    'batch_size': 32,        # number of sequences per batch handed to the DataLoaders
    'context_window': 16,    # number of characters in each input (x) and target (y) sequence
    'd_model': 128,          # dimension of linear layers
    'epochs': 3,             # number of training epochs
    'log_interval': 10,      # logging interval in batches (not read by the train() loop in this notebook)
    'n_heads': 8,            # number of attention heads
    'n_layers': 4,           # number of layers
}

# Load dataset

In [None]:
# Fetch the Tiny Shakespeare dataset via the project helper; the returned value is
# the path that the next cell opens to read the text
data_path = tiny_shakespeare.download_tiny_shakespeare()

In [None]:
# Read the whole corpus into a single string (despite its name, `lines` is not a list).
# The encoding is stated explicitly so the result does not depend on the platform's
# locale default.
with open(data_path, 'r', encoding='utf-8') as f:
    lines = f.read()

# Count every character once; the Counter's keys are exactly the unique characters,
# so the sorted vocabulary can be derived from it without scanning the text again.
character_counts = Counter(lines)
vocab = sorted(character_counts)
print(character_counts)

# The vocabulary size is only known now, so fill in the placeholder in the configuration
CONFIG['vocab_size'] = len(vocab)

# Output the total number of distinct characters in the dataset (vocabulary size)
print(f'Total number of characters in our dataset (Vocabulary Size): {CONFIG["vocab_size"]}')

# Create tokenizer

In [None]:
# Build the character-level tokenizer from the sorted vocabulary
tokenizer = CharacterTokenizer(vocab)

# Round-trip a sample string through encode -> decode; the decoded value is the cell output
sample_ids = tokenizer.encode("hello world!")
tokenizer.decode(sample_ids)

In [None]:
# Character offsets of the 80% and 90% marks: train = [0, 80%), val = [80%, 90%), test = [90%, end)
n_chars = len(lines)
train_end = int(0.8 * n_chars)
val_end = int(0.9 * n_chars)

train_split = lines[:train_end]
val_split = lines[train_end:val_end]
test_split = lines[val_end:]

# Wrap each split in a dataset that shares the tokenizer, context window and device
window = CONFIG['context_window']
train_dataset = dataloader.TextDataset(train_split, tokenizer, window, device)
val_dataset = dataloader.TextDataset(val_split, tokenizer, window, device)
test_dataset = dataloader.TextDataset(test_split, tokenizer, window, device)

# Only the training loader shuffles and drops the final incomplete batch;
# validation and test iterate every sample in order
bs = CONFIG['batch_size']
train_dataloader = DataLoader(train_dataset, batch_size=bs, shuffle=True, drop_last=True)
val_dataloader = DataLoader(val_dataset, batch_size=bs, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=bs, shuffle=False)

print(f"Steps train: {len(train_dataloader)}, val: {len(val_dataloader)}, test: {len(test_dataloader)}")

# Training and evaluation functions

In [None]:
@torch.no_grad()
def evaluate_loss(model, dataloader):
    """Return the mean of the per-batch losses of ``model`` over ``dataloader``.

    Runs with gradient tracking disabled. The model is switched to evaluation
    mode and is NOT switched back afterwards, so a caller that continues
    training must call ``model.train()`` itself.

    Args:
        model: callable as ``model(x, y)`` returning a ``(output, loss)`` pair
            whose ``loss`` supports ``.item()``.
        dataloader: iterable yielding ``(x, y)`` batches. Inside this function
            the name shadows the imported ``dataloader`` module.

    Returns:
        ``np.mean`` of the collected batch losses.
    """
    model.eval()

    # The second element of the model's return value is the loss for that batch
    batch_losses = [model(inputs, targets)[1].item() for inputs, targets in dataloader]

    return np.mean(batch_losses)


def train(model, optimizer, n_epochs, scheduler=None):
    """Train ``model`` for ``n_epochs`` epochs and return the per-epoch metrics.

    Training batches come from the notebook-level ``train_dataloader``. After
    every epoch the mean loss over the notebook-level ``val_dataloader`` is
    computed with ``evaluate_loss``.

    Args:
        model: callable as ``model(x, targets=y)`` returning ``(logits, loss)``.
        optimizer: optimizer over the model parameters (``zero_grad`` / ``step``).
        n_epochs: number of passes over ``train_dataloader``.
        scheduler: optional learning-rate scheduler, stepped once per epoch.

    Returns:
        ``pd.DataFrame`` with one row per epoch and the columns ``'train'``
        (mean training batch loss), ``'val'`` (validation loss) and, only when
        a scheduler is given, ``'lr'`` (learning rate in effect during the epoch).
    """
    # Per-epoch history; the 'lr' column only exists when a scheduler is used
    metrics = {'train': [], 'val': []}
    if scheduler is not None:
        metrics['lr'] = []

    for epoch in range(n_epochs):
        # evaluate_loss() switches the model to eval mode at the end of every epoch
        # and does not switch it back, so training mode has to be re-enabled here.
        # Setting it only once before the loop would run epochs 2..N in eval mode.
        model.train()

        train_losses = []
        for x, y in (pbar := tqdm(train_dataloader)):
            # Zero out gradients left over from the previous step
            optimizer.zero_grad()

            # Forward pass; only the loss is needed here, the logits are discarded
            _, loss = model(x, targets=y)

            # Backward pass and optimization step
            loss.backward()
            optimizer.step()

            loss_train = loss.item()
            train_losses.append(loss_train)
            pbar.set_description(f"Ep {epoch+1}/{n_epochs} | Train loss {loss_train:.3f}")

        metrics['train'].append(np.mean(train_losses))

        # Record the learning rate that was used during this epoch before decaying it
        if scheduler is not None:
            metrics['lr'].append(scheduler.get_last_lr()[0])
            scheduler.step()

        # Evaluate the loss on the validation set (this leaves the model in eval mode)
        loss_val = evaluate_loss(model, val_dataloader)
        metrics['val'].append(loss_val)
        print(f"Ep {epoch+1}/{n_epochs} | Train loss {metrics['train'][-1]:.3f} | Val loss {loss_val:.3f}")

    # Hand the history back as a DataFrame; plotting is done by the caller
    return pd.DataFrame(metrics)

# Create model and train

In [None]:
# Instantiate the model from the shared configuration and move it to the selected device
llama = Llama(CONFIG).to(device)
n_params = sum(p.numel() for p in llama.parameters())
print(f"Model params: {n_params:,}")

# Adam over all model parameters
optimizer = torch.optim.Adam(llama.parameters(), lr=1e-3)

# Step scheduler: with step_size=1 and gamma=0.1 the learning rate is multiplied
# by 0.1 on every scheduler.step() call
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)

In [None]:
%%time

# Run the training loop; train() returns a DataFrame with one row of metrics per epoch
df_losses = train(llama, optimizer, CONFIG['epochs'], lr_scheduler)

In [None]:
# Two panels side by side: loss curves on the left, learning-rate schedule on the right.
# The x axis is the DataFrame's default integer index, i.e. the 0-based epoch number.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 4))

df_losses[['train', 'val']].plot(ax=ax1)
ax1.set(title='Training vs. validation loss', xlabel='Epoch (0-based)', ylabel='Mean loss')

# The 'lr' column is only present because a scheduler was passed to train()
df_losses[['lr']].plot(ax=ax2)
ax2.set(title='Learning rate', xlabel='Epoch (0-based)', ylabel='Learning rate');

In [None]:
%%time

# Mean per-batch loss of the trained model on the held-out test split
evaluate_loss(llama, test_dataloader)

In [None]:
# Generate text with the trained model and print it.
# The 500 is passed positionally; the reload check at the end of the notebook passes
# `max_new_tokens` by keyword, presumably the same parameter (TODO confirm against Llama.generate).
generated_text = llama.generate(device, tokenizer, 500)
print(generated_text)

# Save the model

In [None]:
# Make sure the output directory exists (no error if it is already there)
os.makedirs(EXPERIMENTS_DIR, exist_ok=True)

# Serialize the entire model object, not just its parameters
model_path = os.path.join(EXPERIMENTS_DIR, 'llama.pth')
torch.save(llama, model_path)

# save only the model parameters
# torch.save(llama.state_dict(), os.path.join(EXPERIMENTS_DIR, 'llama_model_parameters.pth'))

In [None]:
# Reload the model that was saved as a whole object and sanity-check it by generating text.
# - map_location=device lets a checkpoint written on a GPU machine load on a CPU-only one.
# - weights_only=False is needed because the file holds a pickled module, not a state dict;
#   recent PyTorch releases default to weights_only=True, which refuses such files.
#   Unpickling executes code, so only do this with checkpoints written by this notebook.
llama_loaded = torch.load(
    os.path.join(EXPERIMENTS_DIR, 'llama.pth'),
    map_location=device,
    weights_only=False,
)

print(llama_loaded.generate(device, tokenizer, max_new_tokens=100))