In [12]:
import pickle
import numpy as np
from pathlib import Path
import torch
import sys

sys.path.append(str((Path.cwd() / "nanochat").resolve()))
from nanochat.gpt import GPT, GPTConfig

In [13]:
data_dir = Path('data/processed/')
meta = pickle.loads((data_dir / 'meta.pkl').read_bytes())
vocab_size = meta['vocab_size']
context_length = meta['context_length']
train_tokens = np.fromfile(data_dir / 'train.bin', dtype=np.uint16)
val_tokens = np.fromfile(data_dir / 'val.bin', dtype=np.uint16)
(vocab_size, context_length, len(train_tokens), len(val_tokens))


(2652, 512, 185693, 18546)

In [14]:
def nanogpt_iter(data, block_size, batch_size, device):
    max_start = data.size(0) - block_size - 1
    idx = torch.randint(0, max_start, (batch_size,))
    x = torch.stack([data[i : i + block_size] for i in idx]).to(device)
    y = torch.stack([data[i + 1 : i + block_size + 1] for i in idx]).to(device)
    return x, y


In [15]:
device = torch.device('cuda' if torch.cuda.is_available() else ('mps' if torch.backends.mps.is_available() else 'cpu'))
config = GPTConfig(sequence_len=context_length, vocab_size=vocab_size, n_layer=4, n_head=4, n_kv_head=4, n_embd=256)
model = GPT(config).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)


In [16]:
train_data = torch.from_numpy(train_tokens.astype(np.int64))
val_data = torch.from_numpy(val_tokens.astype(np.int64))
train_data, val_data = train_data.to(device), val_data.to(device)
max_iters = 1000
eval_interval = 50
batch_size = 32
for step in range(1, max_iters + 1):
    xb, yb = nanogpt_iter(train_data, context_length, batch_size, device)
    logits = model(xb)
    loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), yb.view(-1))
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()
    if step % eval_interval == 0 or step == 1:
        xb_val, yb_val = nanogpt_iter(val_data, context_length, batch_size, device)
        with torch.no_grad():
            logits_val = model(xb_val)
            val_loss = torch.nn.functional.cross_entropy(logits_val.view(-1, logits_val.size(-1)), yb_val.view(-1))
        print(f'step {step:04d}/{max_iters} train {loss.item():.4f} val {val_loss.item():.4f}')


step 0001/1000 train 8.0434 val 7.9046
step 0050/1000 train 5.7657 val 5.9648
step 0100/1000 train 5.1905 val 5.6379
step 0150/1000 train 4.8138 val 5.3580
step 0200/1000 train 4.5565 val 5.1652
step 0250/1000 train 4.1780 val 4.9470
step 0300/1000 train 4.0527 val 4.9744
step 0350/1000 train 3.7909 val 4.8884
step 0400/1000 train 3.6483 val 4.9511
step 0450/1000 train 3.3473 val 4.8890
step 0500/1000 train 3.2232 val 4.9984
step 0550/1000 train 3.0599 val 5.0294
step 0600/1000 train 2.8667 val 4.9826
step 0650/1000 train 2.7100 val 5.1352
step 0700/1000 train 2.4854 val 5.1068
step 0750/1000 train 2.3693 val 5.3237
step 0800/1000 train 2.2386 val 5.3581
step 0850/1000 train 2.0834 val 5.3924
step 0900/1000 train 1.9140 val 5.4725
step 0950/1000 train 1.8383 val 5.4555
step 1000/1000 train 1.7137 val 5.5194


In [17]:
torch.save({'model': model.state_dict(), 'meta': {'model_config': config.__dict__, 'tokenizer': meta}}, 'chess_min.pt')