In [4]:
import random
import torch
from torch import nn
from torch.nn import functional as F

torch.manual_seed(1337)

batch_size = 64
chunk_size = 128 # how wide in time
corpus_size = 65
emb_size = 384
learning_rate = 4e-3
train_iter = 100000
val_iter = 500
device = "mps"
head_size = 16
num_head = 6
num_block = 6
dropout = 0.2

In [1]:
data_file = "dataset/shakespeare.txt"

with open(data_file, 'r') as f:
    data_raw = f.read()

In [5]:
corpus = sorted(list(set(data_raw)))
"".join(corpus), len(corpus)

("\n !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz", 65)

In [11]:
ctoi = {c: idx for idx, c in enumerate(corpus)}
itoc = {idx: c for idx, c in enumerate(corpus)}
encode = lambda s: [ctoi[c] for c in s]
decode = lambda tokens: "".join([itoc[i] for i in tokens])

In [12]:
data_raw[:10], encode(data_raw[:10]), decode(encode(data_raw[:10]))

('First Citi', [18, 47, 56, 57, 58, 1, 15, 47, 58, 47], 'First Citi')

In [13]:
data_token = encode(data_raw)

In [33]:
data_length = len(data_token)
data_token_train = torch.tensor(data_token[:int(0.9*data_length)], dtype=torch.long)
data_token_val = torch.tensor(data_token[-int(0.1*data_length):], dtype=torch.long)
len(data_token_train), len(data_token_val)

(1003854, 111539)

In [34]:
batch_size = 4
chunk_size = 8

def get_batch(split):
    sample_x = []
    sample_y = []
    data_split = data_token_train if split == "train" else data_token_val
    max_samples = len(data_split)
    for _ in range(batch_size):
        idx = random.randint(0, max_samples - chunk_size - 1)
        x = data_split[idx:idx+chunk_size]
        y = data_split[idx+1:idx+chunk_size+1]
        sample_x.append(x)
        sample_y.append(y)

    return torch.stack(sample_x), torch.stack(sample_y)
    

In [35]:
sample_x, sample_y = get_batch("train")
print("x:\n", sample_x)
print("y:\n", sample_y)

x:
 tensor([[52, 42,  1, 57, 46, 39, 51, 43],
        [58, 43, 56,  1, 57, 43, 56, 60],
        [ 8,  0, 13, 61, 39, 63,  1, 61],
        [52, 43,  1, 46, 53, 52, 53, 59]])
y:
 tensor([[42,  1, 57, 46, 39, 51, 43,  1],
        [43, 56,  1, 57, 43, 56, 60, 47],
        [ 0, 13, 61, 39, 63,  1, 61, 47],
        [43,  1, 46, 53, 52, 53, 59, 56]])


In [60]:
class LanguageModel(nn.Module):

    def __init__(self, token_size):
        super().__init__()
        self.token_embedding_table = nn.Embedding(token_size, token_size)

    def forward(self, x, target=None): # B, T

        logits = self.token_embedding_table(x) # B, T, C

        if target is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits_c = logits.view(B*T, C)
            target_c = target.view(B*T)
            loss = F.cross_entropy(logits_c, target_c)

        return logits, loss
    
    def generate(self, x, max_new_token=500):

        for _ in range(max_new_token):
            logits, _ = self(x)

            logits = logits[:, -1, :] # focus on last prediction, B, C
            logits = F.softmax(logits, dim=-1) # get probabilities, B, C
            next_logit = torch.multinomial(logits, 1) # B, 1
            
            x = torch.cat([x, next_logit], dim=1) # B, T + 1

        return x

m = LanguageModel(len(corpus))
logit, loss = m(sample_x, sample_y)
print("sample_x", sample_x.shape)
print("sample_y", sample_y.shape)
print("logit", logit.shape)
print("loss", loss)

decode(m.generate(sample_x, 100)[0].tolist())

sample_x torch.Size([4, 8])
sample_y torch.Size([4, 8])
logit torch.Size([4, 8, 65])
loss tensor(4.5782, grad_fn=<NllLossBackward0>)


"nd shameh!QcF$-xNNb!Q,jIcJwsCddvFDdna\nHkTcDCsnaQEFPtc&YKK:;cMz'ZL!eV-kFhNbe$UUftHOiSC'HG hrYviQEYaPRw?fo eRU"

In [61]:
optimizer = torch.optim.Adam(m.parameters(), lr=1e-3)

  from .autonotebook import tqdm as notebook_tqdm


In [86]:
for iter in range(1000):

    sample_x, sample_y = get_batch("train")

    logit, loss = m(sample_x, sample_y)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

print(loss)

tensor(2.8222, grad_fn=<NllLossBackward0>)


In [98]:
print(decode(m.generate(torch.zeros(1, 1, dtype=torch.long), max_new_token=500)[0].tolist()))


D tonoucoWh!IND wfor's.

Toprat hit he VNSimideCLvilleroftht Ehadw salim,
as m blee
DHofaive t meLLkithkilitth BdV3e, ty paim3CoVbt
qus p ser hizhd,
OAs, t
NRPmoup Wbe gemmineet m fr m?zDjmy, n s f ace th h.':

m burs hr thanoufurd:
KSYHI mshanbe ddeancaHGxyist, ourlawr peaveUe byouptix., w d men y atakn, so hip afogller-CLAJBOUCLoee the yne fung w'd YV!PDCarit acilar.
KWleay-htealy!orth oy l t lad: ar ; yofas hy.

TouRR om.


Bozbeme I:
Mq, mengme,
Br linghanh paverr bS:
He LWe he.Yg ow'fos do?


In [127]:
# attention using softmax
B, T, C = batch_size, chunk_size, 2

x = torch.rand(B, T, C)
print("x\n", x) # B, T, C

tril = torch.tril(torch.ones(T, T))
# print("tril", tril)

wei = torch.zeros(T, T)
# print("wei", wei)

wei = wei.masked_fill(tril == 0., float("-inf"))
# print("wei", wei)

wei = F.softmax(wei, dim=-1)
print("wei\n", wei)


xbow = wei @ x  # (B, T, T) @ (B, T, C) ----> (B, T, C)
print("xbow\n", xbow, xbow.shape)
print(wei.shape, x.shape, xbow.shape)

x
 tensor([[[0.5729, 0.3842],
         [0.5366, 0.1753],
         [0.8483, 0.0556],
         [0.2992, 0.4036],
         [0.0181, 0.0498],
         [0.9534, 0.5142],
         [0.1559, 0.0534],
         [0.9929, 0.2999]],

        [[0.2849, 0.7844],
         [0.7858, 0.6876],
         [0.1753, 0.2913],
         [0.5687, 0.3679],
         [0.9547, 0.7983],
         [0.7168, 0.8815],
         [0.3543, 0.6020],
         [0.1930, 0.6659]],

        [[0.6014, 0.2786],
         [0.7314, 0.2053],
         [0.7686, 0.4646],
         [0.2266, 0.9750],
         [0.6254, 0.1673],
         [0.3122, 0.8484],
         [0.6277, 0.6240],
         [0.5854, 0.4490]],

        [[0.7968, 0.9988],
         [0.6650, 0.4451],
         [0.0373, 0.2391],
         [0.3098, 0.4886],
         [0.6401, 0.8738],
         [0.2617, 0.8151],
         [0.4653, 0.8703],
         [0.9102, 0.0410]]])
wei
 tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.000

In [None]:
# self attention
channel_size = 12
head_size = 16
B, T, C = batch_size, chunk_size, channel_size

x = torch.randint((B, T, C)).float()

query = nn.Linear(C, head_size, bias=False)
key = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)

q = query(x) # B, T, 16
k = key(x)   # B, T, 16
v = value(x) # B, T, 16

wei = q @ k.transpose(-2, -1) # (B, T, 16) x (B, 16, T) -> B, T, T
# kq = F.softmax(kq, dim=-1)

tril = torch.tril(torch.ones(T, T))
wei = wei.masked_fill(tril == 0, float("-inf"))
wei = F.softmax(wei, dim=-1)

out = wei @ v # (B, T, T) x (B, T, 16) -> (B, T, 16)

In [5]:
position_embedding_table = nn.Embedding(chunk_size, emb_size)

In [6]:
pos_emb = position_embedding_table(torch.arange(chunk_size)) # (T,C)

In [8]:
pos_emb.shape

torch.Size([128, 384])