In [None]:
# Automatically reload imported modules (e.g. the local `llama` package) before
# running code, so edits to library files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os
from collections import Counter

import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import torch
from torch.utils.data import DataLoader

from llama.data_pipeline import tiny_shakespeare, dataset
from llama.model.tokenizer import CharacterTokenizer
from llama.model.custom_layers import *
from llama.model.custom_blocks import *
from llama.model import model, training
from llama.constants import *


# Run on the GPU whenever torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
# Configuration object for model parameters
CONFIG = {
    'vocab_size': -1,        # Placeholder; set from the dataset's unique characters once the text is loaded
    'batch_size': 64,        # Number of samples per DataLoader batch (not a number of batches)
    'epochs': 3,             # Number of training epochs
    'context_window': 16,    # Number of characters in each input (x) and target (y) sequence
    'd_model': 128,          # Hidden/model dimension (presumably embedding and linear-layer width)
    'n_heads': 8,            # Number of attention heads
    'n_layers': 4,           # Number of model layers
}

# Seed torch's global RNG so weight initialisation and DataLoader shuffling are
# reproducible when the notebook is re-run from a fresh kernel.
SEED = 1337
torch.manual_seed(SEED)

# Every artifact of this run (e.g. the saved model) goes into its own experiment folder.
EXPERIMENT_NAME = 'chartokenizer_llama_shakespeare'
experiment_dir = os.path.join(EXPERIMENTS_DIR, EXPERIMENT_NAME)
os.makedirs(experiment_dir, exist_ok=True)

# Load dataset

In [None]:
# Store the tiny shakespeare dataset on disk and get the path to the file
data_path = tiny_shakespeare.download_tiny_shakespeare()

# Read the whole corpus into ONE string (despite its name, `lines` is not a list of lines).
# The encoding is explicit: otherwise the platform's locale-dependent default decides
# how the bytes are decoded (e.g. cp1252 on Windows).
with open(data_path, 'r', encoding='utf-8') as f:
    lines = f.read()

assert lines, f"dataset file is empty: {data_path}"

In [None]:
# Count how often each character occurs; the counter's keys are exactly the unique characters
character_counts = Counter(lines)

# The vocabulary is the sorted list of unique characters (sorting keeps token ids stable across runs)
vocab = sorted(character_counts)

# Show the most frequent characters instead of dumping the whole counter
print(character_counts.most_common(10))

# update the vocabulary size in the configuration
CONFIG['vocab_size'] = len(vocab)

# Output the number of distinct characters in our dataset (Vocabulary Size)
print(f'Number of unique characters in our dataset (Vocabulary Size): {CONFIG["vocab_size"]}')

# Create tokenizer

In [None]:
# create the character-level tokenizer over the sorted vocabulary
tokenizer = CharacterTokenizer(vocab)

# Sanity check: encoding then decoding must reproduce the input exactly.
# Asserting (rather than eyeballing the output) makes Run All fail loudly on a broken tokenizer.
sample_text = "hello world!"
roundtrip_text = tokenizer.decode(tokenizer.encode(sample_text))
assert roundtrip_text == sample_text, f"tokenizer round-trip mismatch: {roundtrip_text!r}"
roundtrip_text

In [None]:
# Split the raw text contiguously: first 80% train, next 10% validation, last 10% test.
# The boundaries are fractions of the text length, computed once.
TRAIN_END_FRAC, VAL_END_FRAC = 0.8, 0.9
train_end = int(TRAIN_END_FRAC * len(lines))
val_end = int(VAL_END_FRAC * len(lines))

train_split = lines[:train_end]
val_split = lines[train_end:val_end]
test_split = lines[val_end:]

# create a dataset for each split; each split's text is passed as a one-element list
context_window = CONFIG['context_window']
train_dataset = dataset.TextDataset([train_split], tokenizer, context_window, device)
val_dataset = dataset.TextDataset([val_split], tokenizer, context_window, device)
test_dataset = dataset.TextDataset([test_split], tokenizer, context_window, device)

# create a dataloader for each split: only training shuffles and drops the ragged last batch;
# evaluation keeps every sample in a fixed order
bs = CONFIG['batch_size']
train_dataloader = DataLoader(train_dataset, batch_size=bs, shuffle=True, drop_last=True)
val_dataloader = DataLoader(val_dataset, batch_size=bs, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=bs, shuffle=False)

# With drop_last=True a training split smaller than one batch yields zero steps,
# which would make training silently do nothing.
assert len(train_dataloader) > 0, "training split is smaller than one batch"

print(f"Steps train: {len(train_dataloader)}, val: {len(val_dataloader)}, test: {len(test_dataloader)}")

# Create model and train

In [None]:
# create the Llama model and move it to the selected device
llama = model.Llama(CONFIG).to(device)
print(f"Model params: {sum(p.numel() for p in llama.parameters()):,}")

# create the corresponding optimizer
LEARNING_RATE = 1e-3
optimizer = torch.optim.Adam(llama.parameters(), lr=LEARNING_RATE)

# Step schedule: after every LR_STEP_SIZE calls to lr_scheduler.step() the LR is multiplied
# by LR_GAMMA, i.e. 1e-3 -> 1e-4 -> 1e-5 ...
# NOTE(review): the step() calls happen inside training.train. Confirm they are per epoch;
# per-batch stepping would collapse the LR after the first few batches.
LR_STEP_SIZE, LR_GAMMA = 1, 0.1
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=LR_STEP_SIZE, gamma=LR_GAMMA)

In [None]:
%%time

# Train the model (`%%time` must remain the first line of this cell).
# training.train returns `df_losses`; the next cell plots its 'train', 'val' and 'lr' columns.
# NOTE(review): whether rows are per epoch or per logging step is decided inside training.train — confirm.
df_losses = training.train(llama, optimizer, train_dataloader, val_dataloader, CONFIG['epochs'], lr_scheduler)

In [None]:
fig, (ax_loss, ax_lr) = plt.subplots(1, 2, figsize=(14, 4))

# Train vs. validation loss curves
df_losses[['train', 'val']].plot(ax=ax_loss, title='Loss')
ax_loss.set_ylabel('loss')

# A step LR schedule changes the LR by orders of magnitude, so a log axis keeps every stage visible
df_losses[['lr']].plot(ax=ax_lr, title='Learning rate', logy=True)
ax_lr.set_ylabel('learning rate')

fig.tight_layout()

In [None]:
%%time

# check loss on test split
training.evaluate_loss(llama, test_dataloader)

In [None]:
# Sample from the trained model with a budget of MAX_NEW_TOKENS new tokens and print the result
MAX_NEW_TOKENS = 500
generated_text = llama.generate(device, tokenizer, MAX_NEW_TOKENS)
print(generated_text)

# Save the model

In [None]:
# Pickle the entire model object (architecture + weights). Loading it back requires the
# `llama` package to be importable under the same module paths.
# A more portable alternative is to save only `llama.state_dict()`
# (e.g. to 'llama_model_parameters.pth') and rebuild the model with `model.Llama(CONFIG)`.
model_path = os.path.join(experiment_dir, 'llama.pth')
torch.save(llama, model_path)

In [None]:
# check that the saved model round-trips.
# weights_only=False is required to unpickle a whole model object: since PyTorch 2.6,
# torch.load defaults to weights_only=True and rejects such checkpoints.
# Only do this for checkpoints you created yourself, because unpickling can execute arbitrary code.
# map_location lets a checkpoint saved on a GPU load on a CPU-only kernel (and vice versa).
llama_loaded = torch.load(
    os.path.join(experiment_dir, 'llama.pth'),
    map_location=device,
    weights_only=False,
)

# The reloaded weights must be identical to the trained ones
original_state = llama.state_dict()
loaded_state = llama_loaded.state_dict()
assert original_state.keys() == loaded_state.keys(), "parameter names differ after reload"
assert all(torch.equal(original_state[k], loaded_state[k]) for k in original_state), \
    "parameter values differ after reload"

print(llama_loaded.generate(device, tokenizer, max_new_tokens=100))