In [1]:
import os

import torch

from transformer_models import GPT
from gpt_utils import process_text, get_batch, estimate_loss

# Run on the GPU when PyTorch reports one as available; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Seed PyTorch's default random number generator so runs are repeatable.
torch.manual_seed(1337)

<torch._C.Generator at 0x10a941250>

In [3]:
# --- Model hyperparameters (passed to the GPT constructor) ---
embed_size = 96
n_layers = 3
n_heads = 4
block_size = 64
dropout = 0.2
bias = False

# --- Optimisation settings ---
batch_size = 16
lr = 3e-4
max_iters = 5_000

# --- Evaluation schedule ---
# Loss is estimated every `eval_interval` iterations; `eval_iters` is
# forwarded to estimate_loss.
eval_interval = 100
eval_iters = 100

In [4]:
# Read the whole corpus into memory as a single UTF-8 string.
with open('./data/text.txt', mode='r', encoding='utf-8') as corpus_file:
    text = corpus_file.read()

# process_text yields the vocabulary size together with an encoder and a
# decoder (presumably text <-> token ids; see gpt_utils for the exact contract).
vocab_size, encode, decode = process_text(text)

# Encode the full corpus once and keep it as an int64 tensor.
token_ids = encode(text)
data = torch.tensor(token_ids, dtype=torch.long)

# The first 90% of the tokens are used for training, the remainder for validation.
n = int(0.9 * len(data))
train_data, val_data = data[:n], data[n:]

In [5]:
# Build the model from the configured hyperparameters and move it to `device`.
gpt = GPT(
    vocab_size,
    n_layers,
    n_heads,
    embed_size,
    block_size,
    dropout,
    bias,
).to(device)

# AdamW over all model parameters with the configured learning rate.
optimizer = torch.optim.AdamW(gpt.parameters(), lr=lr)

Number of parameters: 0.338688M


In [6]:
for i in range(max_iters):
    # Report losses on the first iteration, on every `eval_interval`-th
    # iteration (1-based), and on the final iteration.
    is_first = i == 0
    is_last = i == max_iters - 1
    on_interval = (i + 1) % eval_interval == 0
    if is_first or on_interval or is_last:
        eval_splits = {'train': train_data, 'val': val_data}
        losses = estimate_loss(
            gpt, eval_splits, eval_iters, batch_size, block_size, device
        )
        print(f"step {i+1}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    # Draw one training batch of inputs and targets.
    xb, yb = get_batch(train_data, batch_size, block_size, device)

    # Forward pass; the model returns the loss alongside the logits when
    # targets are supplied. Only the loss is used here.
    logits, loss = gpt(xb, yb)

    # Clear gradients from the previous iteration, backpropagate, then update.
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

step 1: train loss 4.1944, val loss 4.1910
step 100: train loss 2.8179, val loss 2.8313
step 200: train loss 2.6406, val loss 2.6302
step 300: train loss 2.5465, val loss 2.5529
step 400: train loss 2.5046, val loss 2.5059
step 500: train loss 2.4792, val loss 2.4757
step 600: train loss 2.4517, val loss 2.4579
step 700: train loss 2.4348, val loss 2.4243
step 800: train loss 2.4175, val loss 2.4171
step 900: train loss 2.4034, val loss 2.3982
step 1000: train loss 2.3703, val loss 2.3810
step 1100: train loss 2.3578, val loss 2.3700
step 1200: train loss 2.3402, val loss 2.3479
step 1300: train loss 2.3221, val loss 2.3261
step 1400: train loss 2.3043, val loss 2.3062
step 1500: train loss 2.2768, val loss 2.2959
step 1600: train loss 2.2757, val loss 2.2809
step 1700: train loss 2.2474, val loss 2.2574
step 1800: train loss 2.2322, val loss 2.2483
step 1900: train loss 2.2191, val loss 2.2252
step 2000: train loss 2.1939, val loss 2.2150
step 2100: train loss 2.1825, val loss 2.2089


In [8]:
# Persist the model's state dict (not the module object itself) to disk.
ckpt_path = "./ckpt/gpt.pt"

# torch.save does not create missing parent directories, so make sure the
# checkpoint folder exists; otherwise this cell fails on a fresh checkout.
os.makedirs(os.path.dirname(ckpt_path), exist_ok=True)

torch.save(gpt.state_dict(), ckpt_path)