In [10]:
import os
import numpy as np

import torch
from boring_transformer.boring_transformer import *
from boring_transformer.boring_gpt import *
from boring_transformer.utils import *

# Select the compute device. get_device is not defined in this notebook, so it
# presumably comes from one of the boring_transformer star imports above; the
# recorded output below shows it resolved to CUDA on that run.
device = get_device()
# The recorded output prints the value under a "device:" label, so cprint
# presumably derives the label from the argument expression -- keep passing the
# bare variable name (TODO confirm against boring_transformer.utils).
cprint(device)

[93mdevice:[0m
device(type='cuda')


# Config

In [None]:
# Hyperparameters for the whole notebook. Later cells only read these names,
# so this is the single place to tune a run.

# Reproducibility: seed the RNGs before any data is sampled or any weight is
# initialised. Both NumPy and PyTorch are seeded because the batch sampler's
# RNG source is not visible from this notebook. (GPU kernels may still
# introduce small run-to-run differences.)
random_seed = 1337
np.random.seed(random_seed)
torch.manual_seed(random_seed)

# Tokenizer. 11706 is the unique-token count printed by the (now commented-out)
# vocabulary cell further down.
# NOTE(review): that is a *count* of distinct ids; it is only a valid vocabulary
# size if the ids stored in train.bin / val.bin are contiguous from 0 -- confirm.
vocab_size = 11706

# Batching. Alternatives kept from the original cell: batch_size=32,
# block_size=1024.
batch_size = 64  # how many independent sequences will we process in parallel?
block_size = 256  # what is the maximum context length for predictions?

# Optimisation schedule. Alternative kept from the original cell: max_iters=5000.
max_iters = 4000
eval_interval = 500  # the training loop evaluates every eval_interval steps
learning_rate = 3e-4
eval_iters = 200  # forwarded to estimate_loss by the training loop

# Model size.
n_embd = 384
# Alias under the alternate spelling; not referenced by the cells shown here,
# kept in case imported helpers or other cells expect it.
n_embed = n_embd
n_head = 6
n_layer = 6
dropout = 0.2

# Load Data

In [3]:
# Memory-map the pre-tokenised train / validation splits (uint16 token ids).
# mode='r' opens them read-only; slices are read from disk on demand.
dataset = 'shakespeare'
data_dir = os.path.join('./data', dataset)
train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
val_data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r')

# Fail fast if the configured vocabulary is too small for the data. The model
# is constructed with vocab_size, so an id >= vocab_size would presumably index
# past its embedding table -- on CUDA that shows up much later as an opaque
# device-side assert instead of a readable message.
max_token_id = int(max(train_data.max(), val_data.max()))
assert max_token_id < vocab_size, (
    f"token id {max_token_id} does not fit vocab_size={vocab_size}; "
    "vocab_size must be at least max id + 1, not the number of unique ids"
)

In [12]:
# NOTE: disabled vocabulary probe, kept for reference (the saved output below
# is from a run where it was active). Things to check before re-enabling it:
#   * the `if` branch only sets meta_vocab_size; vocab_size itself is untouched.
#   * the `else` branch sets vocab_size to the number of *unique* ids, which
#     equals "max id + 1" only when the ids are contiguous from 0 -- verify.
#   * pickle.load can execute arbitrary code; only open a meta.pkl you trust.
# import pickle

# meta_path = os.path.join(data_dir, 'meta.pkl')
# meta_vocab_size = None
# if os.path.exists(meta_path):
#     with open(meta_path, 'rb') as f:
#         meta = pickle.load(f)
#     meta_vocab_size = meta['vocab_size']
#     print(f"found vocab_size = {meta_vocab_size} (inside {meta_path})")
# else:
#     all_data = np.concatenate((train_data, val_data))
#     unique_tokens = np.unique(all_data)
#     vocab_size = len(unique_tokens)

#     # 11706
#     cprint(vocab_size)

[93mvocab_size:[0m
11706


In [7]:
# Smoke test of the sampler: draw a single batch from the training split using
# the configured batch/context sizes and place it on the selected device.
X, Y = get_batch_np(
    train_data,
    batch_size=batch_size,
    block_size=block_size,
    device=device,
)

In [9]:
X.shape  # (batch_size, block_size); the saved output is from a run with batch_size=32, block_size=1024

torch.Size([32, 1024])

# Training the NN

In [None]:
# Build the model from the config values and move it to the selected device.
# An earlier revision of this cell constructed BoringTransformerModel with the
# same positional arguments instead of BoringGPT.
m = BoringGPT(
    vocab_size,
    n_embd,
    n_head,
    n_layer,
    n_embd * 4,  # presumably the feed-forward width -- confirm against BoringGPT's signature
    block_size,
    dropout,
)
m = m.to(device)

# Report the model size: total element count over all parameters, in millions.
print(
    sum(weight.numel() for weight in m.parameters())/1e6,
    'M parameters',
)

In [None]:
# AdamW over every model parameter. The step size is read from the config
# cell's learning_rate so the value shown there is the one actually used; this
# cell previously hard-coded 1e-3, which left learning_rate (3e-4) unused.
optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate)

In [None]:
# Train for max_iters optimisation steps, reporting train/val loss every
# eval_interval steps (and on the final step), then print a generated sample.
# The counter is named `step` rather than `iter` so the builtin iter() is not
# shadowed for the rest of the kernel session.
for step in range(max_iters):
    # every once in a while evaluate the loss on train and val sets
    if step % eval_interval == 0 or step == max_iters - 1:
        losses = estimate_loss(m, train_data, val_data, block_size, batch_size, eval_iters, device)
        print(f"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
        # Make sure dropout is active for the optimisation steps: the sampling
        # code at the end of this cell switches to eval mode (matters when the
        # cell is re-run), and estimate_loss may do the same -- not visible here.
        m.train()

    # Sample a batch of data. Uses the same sampler call as the smoke-test cell
    # above, which is the one this notebook has actually exercised on the
    # np.memmap-backed train_data.
    xb, yb = get_batch_np(train_data, block_size=block_size, batch_size=batch_size, device=device)

    # forward pass, then a single optimiser step
    logits, loss = m(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

# Sample 500 new tokens starting from a single token with id 0. Inference runs
# in eval mode (dropout off) and without autograd bookkeeping.
m.eval()
with torch.no_grad():
    context = torch.zeros((1, 1), dtype=torch.long, device=device)
    sample_ids = m.generate(context, max_new_tokens=500)[0].tolist()
# decode is not defined in this notebook; presumably it comes from the
# boring_transformer star imports -- TODO confirm.
print(decode(sample_ids))