In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F
from model import *
from train import read_data, get_batch, estimate_loss

In [2]:
# Load the raw corpus and build a character-level tokenizer over it.
# The corpus is bound to `text` rather than `input` so the built-in input()
# is not shadowed for the rest of the notebook.
text = read_data()
chars = sorted(set(text))
print(''.join(chars))
print(len(chars))
vocab_size = len(chars)

# Character <-> integer id lookup tables (ids follow sorted character order).
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}


def encode(s):
    """Map a string to a list of token ids; raises KeyError for characters outside the vocab."""
    return [stoi[c] for c in s]


def decode(ids):
    """Map a sequence of token ids back to the corresponding string."""
    return ''.join(itos[i] for i in ids)


# Encode the whole corpus as one 1-D tensor of token ids.
data = torch.tensor(encode(text), dtype=torch.long)

# First 90% of the sequence is used for training, the remaining 10% for validation.
n = int(0.9 * len(data))
train_data = data[:n]
val_data = data[n:]

def get_batch(split):
    """Sample a random batch of (input, target) token windows from one data split.

    Note: this re-defines, and therefore shadows, the ``get_batch`` imported from
    ``train`` in the first cell.

    Args:
        split: ``'train'`` selects ``train_data``; any other value selects ``val_data``.

    Returns:
        ``(x, y)``: two stacked tensors of ``batch_size`` windows, each
        ``max_seq_len`` tokens long, placed on ``device``. ``y`` is ``x`` shifted
        one position ahead in the corpus, i.e. the next-token targets.
    """
    source = train_data if split == 'train' else val_data
    # randint's upper bound is exclusive, so the largest start is
    # len(source) - max_seq_len - 1. That keeps the target window, which
    # reaches one token further than the input window, in bounds.
    starts = torch.randint(len(source) - max_seq_len, (batch_size,))
    x = torch.stack([source[i:i + max_seq_len] for i in starts])
    y = torch.stack([source[i + 1:i + max_seq_len + 1] for i in starts])
    # The model is moved to `device`, so the batch must live there too.
    # Otherwise the forward pass hits a device mismatch whenever `device` is not the CPU.
    return x.to(device), y.to(device)

@torch.no_grad()
def estimate_loss(model, iters=None, batch_fn=None):
    """Estimate the mean loss of ``model`` on the train and val splits.

    For each split, averages ``loss.item()`` over several freshly sampled batches,
    with gradients disabled and the model in eval mode. The model's previous
    train/eval mode is restored afterwards, even if an exception is raised.

    Note: this re-defines, and therefore shadows, the ``estimate_loss`` imported
    from ``train`` in the first cell.

    Args:
        model: module called as ``model(X, Y)`` that returns ``(logits, loss)``.
        iters: batches to average per split; defaults to the notebook-level ``eval_iters``.
        batch_fn: callable ``split -> (X, Y)``; defaults to ``get_batch``.

    Returns:
        dict mapping ``'train'`` and ``'val'`` to a 0-d tensor with the mean loss.
    """
    # Globals are only looked up when no override is given, so callers may pass both.
    n_iters = eval_iters if iters is None else iters
    sample = get_batch if batch_fn is None else batch_fn

    was_training = model.training
    model.eval()
    try:
        out = {}
        for split in ['train', 'val']:
            losses = torch.zeros(n_iters)
            for k in range(n_iters):
                X, Y = sample(split)
                _, loss = model(X, Y)
                losses[k] = loss.item()
            out[split] = losses.mean()
    finally:
        # Restore whatever mode the caller had, rather than forcing train mode.
        model.train(was_training)
    return out


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
65


In [3]:
# Draw one training batch and show the inputs followed by their targets
# (each target row is the input row shifted one token ahead in the corpus).
xb, yb = get_batch('train')
print(xb, yb, sep='\n')

tensor([[52, 42, 57,  1, 51, 43,  1, 61, 47, 58, 46,  1, 58, 46, 43,  1],
        [ 6,  1, 61, 47, 50, 58,  1, 58, 46, 53, 59,  1, 50, 43, 39, 60],
        [53,  1, 53, 59, 56,  1, 54, 56, 53, 41, 43, 43, 42, 47, 52, 45],
        [59, 57,  1, 25, 39, 56, 41, 47, 59, 57, 11,  1, 58, 46, 43, 57],
        [46, 43, 56,  5, 57,  1, 42, 43, 39, 58, 46,  0, 32, 39, 49, 43],
        [52,  1, 46, 53, 61,  1, 55, 59, 47, 41, 49, 50, 63,  1, 57, 46],
        [50, 63,  1, 53, 39, 58, 46, 10,  0, 32, 53,  1, 49, 43, 43, 54],
        [43, 44, 43, 56, 51, 43, 52, 58,  1, 42, 56, 53, 54,  1, 53, 52]])
tensor([[42, 57,  1, 51, 43,  1, 61, 47, 58, 46,  1, 58, 46, 43,  1, 44],
        [ 1, 61, 47, 50, 58,  1, 58, 46, 53, 59,  1, 50, 43, 39, 60, 43],
        [ 1, 53, 59, 56,  1, 54, 56, 53, 41, 43, 43, 42, 47, 52, 45, 57],
        [57,  1, 25, 39, 56, 41, 47, 59, 57, 11,  1, 58, 46, 43, 57, 43],
        [43, 56,  5, 57,  1, 42, 43, 39, 58, 46,  0, 32, 39, 49, 43,  1],
        [ 1, 46, 53, 61,  1, 55, 59, 

In [4]:
# Instantiate the decoder for this vocabulary on the configured device and
# report its size in millions of parameters.
mod = GPTDecoderModel(vocab_size).to(device)
n_params = sum(p.numel() for p in mod.parameters())
print(n_params / 1e6, 'M parameters')

18.969665 M parameters


In [34]:
# Train `mod` with a fresh Adam optimizer. The estimated train/val loss is
# reported every `eval_every` steps and on the final step.
# NOTE(review): the execution count jumps from In [4] to In [34] and the loss at
# step 0 is already ~1.78, so this cell appears to continue from earlier training
# runs rather than starting from a freshly initialised model.
optimizer = torch.optim.Adam(mod.parameters(), lr=lr)
max_iters = 10000
eval_every = 50

# `step` rather than `iter`, so the built-in iter() is not shadowed after the loop.
for step in range(max_iters):
    if step % eval_every == 0 or step == max_iters - 1:
        losses = estimate_loss(mod)
        print(f"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    xb, yb = get_batch('train')
    logits, loss = mod(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

step 0: train loss 1.7808, val loss 1.9501
step 50: train loss 1.8076, val loss 1.9216
step 100: train loss 1.8088, val loss 1.9542
step 150: train loss 1.7942, val loss 1.9233
step 200: train loss 1.7932, val loss 1.9516
step 250: train loss 1.7759, val loss 1.9549
step 300: train loss 1.7709, val loss 1.9238
step 350: train loss 1.8036, val loss 1.9398
step 400: train loss 1.7555, val loss 1.9144
step 450: train loss 1.7770, val loss 1.9324
step 500: train loss 1.7942, val loss 1.9329
step 550: train loss 1.7742, val loss 1.9261
step 600: train loss 1.7870, val loss 1.9355
step 650: train loss 1.7714, val loss 1.9182
step 700: train loss 1.7697, val loss 1.9043
step 750: train loss 1.7543, val loss 1.9258
step 800: train loss 1.7389, val loss 1.9037
step 850: train loss 1.7561, val loss 1.9114
step 900: train loss 1.7683, val loss 1.9298
step 950: train loss 1.7681, val loss 1.9140
step 1000: train loss 1.7639, val loss 1.9161
step 1050: train loss 1.7347, val loss 1.9112
step 1100: 

In [35]:
# Generate 500 new tokens from the model, seeded with a single token of id 0
# (the first character of the sorted vocabulary).
context = torch.zeros((1, 1), dtype=torch.long, device=device)
# Sample in eval mode and without autograd. After the training loop `mod` is
# normally still in train mode, which would keep any dropout active while
# sampling. The previous mode is restored afterwards so training can resume.
was_training = mod.training
mod.eval()
try:
    with torch.no_grad():
        sample_ids = mod.generate(context, max_new_tokens=500)[0].tolist()
finally:
    mod.train(was_training)
print(decode(sample_ids))


Ay. Margaret, in the destrens! king be!
He could mallen and nottly:
and, die not, good ney:
Twice, sir? fair; you shall rove me drask?
Whoe more post call the love;
And when must the send
Doth our to schess ang, and shrands which the toast an do not:
To have hunce anothe
to Lords, the fuled, made forlal of the portedille
Proqued, for should hit, which Englings his depossent know.
Cheep deid yield doth belike 'Will his uttime
With sit.

GLOUCESTER:
I see think summmservall child bawg keep so rece


In [36]:
# Persist the whole trained module (pickled together with a reference to its class).
# NOTE(review): reloading this file requires the `model` module's class definitions
# to be importable. Saving `mod.state_dict()` is more portable, but it would change
# what this file contains.
MODEL_PATH = 'model_01022023'
torch.save(obj=mod, f=MODEL_PATH)