# Ref:
- https://github.com/TimS-ml/nanoGPT
- https://youtu.be/kCc8FmEb1nY

In [1]:
import os
from boring_llm_base.constants import PROJECT_HOME_DIR
import sys

# Run everything from the project root: make it importable and resolve
# relative paths used later in the notebook against it.
sys.path.append(str(PROJECT_HOME_DIR))
os.chdir(PROJECT_HOME_DIR)

In [2]:
import random
import tqdm
import gzip
import pickle
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset

In [3]:
from boring_utils.utils import (
    cprint, 
    tprint, 
    get_device
)

# Config and Data Loading

In [4]:
# DEV = True selects a tiny, fast configuration; False selects the full-size run.
DEV = True

# Settings that are identical in both configurations.
eval_interval = 500
learning_rate = 3e-4
eval_iters = 200
dropout = 0.2

if DEV:
    batch_size = 32
    block_size = 8
    max_iters = 100  # was 1000
    n_embd = 32
    n_head = 4
    n_layer = 4
else:
    batch_size = 64  # how many independent sequences will we process in parallel?
    block_size = 256  # what is the maximum context length for predictions?
    max_iters = 4000  # was 5000
    n_embd = 384
    n_head = 6
    n_layer = 6

# Alias for code that spells the embedding width "n_embed".
n_embed = n_embd

# vocab_size is not set here; it is read from the dataset metadata.
device = get_device()
cprint(device)

[93m<module> -> device:[0m
device(type='mps')


In [5]:
# Dataset directory: $DATA_DIR (default './data/') plus the 'enwik8' subfolder.
data_dir = os.path.join(os.getenv('DATA_DIR', './data/'), 'enwik8')

# Alternative (unused): read only the first 10M bytes of the raw archive.
# with gzip.open(os.path.join(data_dir, 'enwik8.gz')) as file:
#     text = file.read(int(10e6)).decode('utf-8')

# meta.pkl must provide 'vocab_size', 'stoi' (char -> id) and 'itos' (id -> char).
meta_path = os.path.join(data_dir, 'meta.pkl')
vocab_size = None
if not os.path.exists(meta_path):
    raise FileNotFoundError(f"Meta file {meta_path} not found")

# NOTE: pickle can execute arbitrary code; only load a meta.pkl you created yourself.
with open(meta_path, 'rb') as f:
    meta = pickle.load(f)
vocab_size = meta['vocab_size']
stoi = meta['stoi']
itos = meta['itos']


def encode(s):
    """Map each character of ``s`` to its token id."""
    return [stoi[c] for c in s]


def decode(l):
    """Join the characters for the token ids in ``l`` back into a string."""
    return ''.join(itos[i] for i in l)

In [8]:
train_bin_path = os.path.join(data_dir, 'train.bin')
val_bin_path = os.path.join(data_dir, 'val.bin')


def _load_token_bin(path):
    """Read a raw uint16 token file and return it as an int64 (torch.long) tensor."""
    # torch.long is just an alias for torch.int64 -- the index dtype the
    # embedding lookup and cross-entropy targets expect.
    raw = np.fromfile(path, dtype=np.uint16)
    return torch.from_numpy(raw.astype(np.int64))


train_data = _load_token_bin(train_bin_path)
val_data = _load_token_bin(val_bin_path)

# Data Loader

In [9]:
class TextSamplerDataset(Dataset):
    """Dataset that returns a random (input, target) window of tokens on every access.

    The requested index is ignored: each ``__getitem__`` call draws a fresh random
    start offset, so iterating the dataset (e.g. through a DataLoader) produces
    random training windows rather than a fixed partition of ``data``.

    Args:
        data: tensor of token ids, sliced along its first dimension (presumably
            1-D; TODO confirm for other callers).
        block_size: number of input tokens per sample.
        device: where to place returned tensors. ``None`` (the default) keeps the
            original behaviour of using the module-level ``device``.

    Raises:
        ValueError: if ``data`` cannot hold a single ``block_size + 1`` window.
    """

    def __init__(self, data, block_size, device=None):
        self.data = data
        self.block_size = int(block_size)
        self.device = device
        # One sample needs block_size inputs plus one extra token for the shifted target.
        if len(self.data) <= self.block_size:
            raise ValueError(
                f"data has {len(self.data)} tokens; need more than "
                f"block_size={self.block_size} to form one sample"
            )

    def __getitem__(self, index):
        # Valid starts satisfy ix + block_size + 1 <= len(data). randint's upper bound is
        # exclusive, so len(data) - block_size also makes the final window reachable.
        ix = torch.randint(len(self.data) - self.block_size, (1,)).item()
        full_seq = self.data[ix:ix + self.block_size + 1]
        x = full_seq[:-1]
        y = full_seq[1:]  # targets are the inputs shifted by one position
        target_device = device if self.device is None else self.device
        return x.to(target_device), y.to(target_device)

    def __len__(self):
        # Nominal epoch length (count of non-overlapping windows); sampling is random.
        return len(self.data) // self.block_size


# TextSamplerDataset already picks a random offset per sample, so the loaders
# do not shuffle; both splits use the same batching settings.
train_dataset = TextSamplerDataset(train_data, block_size)
val_dataset = TextSamplerDataset(val_data, block_size)
loader_kwargs = dict(batch_size=batch_size, shuffle=False)
train_loader = DataLoader(train_dataset, **loader_kwargs)
val_loader = DataLoader(val_dataset, **loader_kwargs)

# Model

In [12]:
class BigramLanguageModel(nn.Module):
    """Bigram language model: next-token logits depend only on the current token.

    Each token id selects one row of a ``vocab_size x vocab_size`` embedding
    table, and that row is used directly as the logits over the next token.
    """

    def __init__(self, vocab_size):
        super().__init__()
        # Width equals the vocabulary size so a looked-up row can be read off
        # directly as next-token logits (e.g. token 24 reads row 24 of the table).
        embedding_dim = vocab_size
        self.embedding = nn.Embedding(vocab_size, embedding_dim)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, when targets are given, the loss.

        Args:
            idx: (B, T) tensor of token ids.
            targets: optional (B, T) tensor of next-token ids.

        Returns:
            ``(logits, loss)``. Without targets, logits are (B, T, vocab_size) and
            loss is ``None``. With targets, logits are flattened to
            (B*T, vocab_size) and loss is the mean cross-entropy.
        """
        logits = self.embedding(idx)  # (B, T, C)

        if targets is None:
            return logits, None

        B, T, C = logits.shape
        # reshape (unlike view) also accepts non-contiguous inputs such as transposes.
        logits = logits.reshape(B * T, C)
        loss = F.cross_entropy(logits, targets.reshape(-1))
        return logits, loss

# nn.Module.to moves the parameters in place and returns the same module.
model = BigramLanguageModel(vocab_size).to(device)

BigramLanguageModel(
  (embedding): Embedding(2102, 2102)
)

In [13]:
# AdamW over every model parameter at the configured learning rate.
optimizer = torch.optim.AdamW(params=model.parameters(), lr=learning_rate)

# Training

In [15]:
def cycle(loader):
    """Yield items from ``loader`` forever, restarting it after each full pass.

    Unlike ``itertools.cycle``, this re-iterates ``loader`` on every pass instead
    of replaying cached items. A DataLoader over a randomly sampling dataset
    therefore keeps producing fresh batches, and nothing is held in memory.

    Raises:
        ValueError: if a full pass over ``loader`` yields nothing. Without this
            check the generator would spin forever without producing an item.
    """
    while True:
        produced = False
        for data in loader:
            produced = True
            yield data
        if not produced:
            raise ValueError("loader yielded no items; cannot cycle an empty loader")

train_iter = cycle(train_loader)
val_iter = cycle(val_loader)


@torch.no_grad()
def estimate_loss(batches, num_batches):
    """Return the mean loss of ``model`` over the next ``num_batches`` batches.

    ``batches`` may be a DataLoader (restarted on every call) or a shared
    generator such as ``val_iter`` (continues where it left off). Uses the
    module-level ``model``; the caller is responsible for eval/train mode.
    """
    losses = []
    for _, (x, y) in zip(range(num_batches), batches):
        _, loss = model(x, y)
        losses.append(loss.item())
    return np.mean(losses)


for step in range(max_iters):
    # Report losses every eval_interval steps and always on the final step.
    if step % eval_interval == 0 or step == max_iters - 1:
        model.eval()
        # Same order as before (val, then train), so random sampling is unchanged.
        val_loss = estimate_loss(val_iter, eval_iters)
        train_loss = estimate_loss(train_loader, eval_iters)
        print(f"step {step}: train loss {train_loss:.4f}, val loss {val_loss:.4f}")
        model.train()

    # One optimisation step on the next training batch (replaces get_batch).
    x, y = next(train_iter)
    _, loss = model(x, y)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

step 0: train loss 8.1311, val loss 8.1381
step 99: train loss 8.0976, val loss 8.0863
