In [1]:
# Load IPython's autoreload extension so edits to imported modules
# (e.g. the local mini_transformers package) are picked up without a kernel restart.
%load_ext autoreload

# Mode 2: reload all modules before executing each cell's code
%autoreload 2

from mini_transformers.data_load import ShakespeareDataset
from mini_transformers.models.bigram_model import BigramModel
from mini_transformers.models.embedding_model import (
    SimpleEmbedding, HeadEmbedding, PositionHeadEmbedding, 
    MultiHeadedAttentionEmbedding, ResidualBlockAttentionEmbedding, GPT)
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm

# Context window length; also passed as `context_len` to the model constructors below.
CONTEXT_LEN = 256
# NOTE(review): `context_lenght` is spelled this way in the keyword itself — presumably it
# matches the ShakespeareDataset signature, so it must not be "corrected" here alone.
ds = ShakespeareDataset(context_lenght=CONTEXT_LEN)

# ds


In [2]:
# Split the dataset into two subsets; train_ds feeds the DataLoader in the training cell.
# NOTE(review): valid_ds is not used anywhere in the visible cells — no validation loss is tracked.
train_ds, valid_ds = ds.train_valid_subsets()

In [3]:

# Seed torch's global RNG before building the model and drawing the first sample.
torch.manual_seed(1337)

# Pick the best available backend instead of assuming Apple-silicon 'mps',
# so this cell also runs on CUDA or CPU-only machines.
if torch.backends.mps.is_available():
    DEVICE = 'mps'
elif torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

# Earlier experiments, kept for reference (swap one in to compare architectures):
# simple_text_generator = BigramModel(SimpleEmbedding(len(ds.vocabulary)))
# head_text_generator = BigramModel(HeadEmbedding(len(ds.vocabulary), 100))
# position_text_generator = BigramModel(PositionHeadEmbedding(len(ds.vocabulary), 100, context_len=CONTEXT_LEN))
# attention_text_generator = BigramModel(SingleHeadedAttentionEmbedding(vocab_size=len(ds.vocabulary), n_embeds=100, head_size=100, context_len=CONTEXT_LEN))
# attention_text_generator = BigramModel(MultiHeadedAttentionEmbedding(vocab_size=len(ds.vocabulary), n_embeds=100, n_heads=3, head_size=100, context_len=CONTEXT_LEN)).to('mps')
# attention_text_generator = BigramModel(ResidualBlockAttentionEmbedding(vocab_size=len(ds.vocabulary), n_layers=6, n_embeds=100, n_heads=3, context_len=CONTEXT_LEN)).to('mps')

# Current model: GPT embedding wrapped in BigramModel, moved to the selected device.
attention_text_generator = BigramModel(GPT(vocab_size=len(ds.vocabulary),
                                           n_layers=6, n_embeds=384, n_heads=6,
                                           context_len=CONTEXT_LEN, dropout=0.2)).to(DEVICE)
text_generator = attention_text_generator
# text_generator = position_text_generator

# Sanity check: sample from the still-untrained model and decode it to text.
gen_text = ds.vocabulary.decode(text_generator.generate().squeeze().tolist())
print(gen_text)




QOecOECVP?rXA;:EGs3N,gFZBcCQ$USu'&EP&qcQcUCEPzBmdtI'
bXpHHXr-wkTIdjekufuu3yb'HYMb3y-ZDHVqRTA;lJaNG


In [4]:
# AdamW over every parameter of the model under training; the step size is named
# so it is easy to find and tune.
LEARNING_RATE = 1e-4
optimizer = torch.optim.AdamW(
    text_generator.parameters(),
    lr=LEARNING_RATE,
)

In [6]:
# One pass over the training subset in shuffled mini-batches.
batch_size = 64
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)

# Send batches to whatever device the model's weights live on rather than a
# hard-coded 'mps', so the loop follows the model wherever it was placed.
device = next(text_generator.parameters()).device

# Explicitly (re-)enter training mode: in a notebook an earlier cell may have
# left the module in eval mode, which would silently disable dropout here.
# NOTE(review): assumes BigramModel is an nn.Module (it already exposes .to() and .parameters()).
text_generator.train()

for xb, yb in (pbar := tqdm(train_loader)):
    # The model is called with inputs and targets and returns (logits, loss).
    logits, loss = text_generator(xb.to(device), yb.to(device))
    optimizer.zero_grad(set_to_none=True)  # drop stale gradients before backprop
    loss.backward()
    optimizer.step()
    # Show the most recent batch loss on the progress bar.
    pbar.set_description(f'{loss = :.3f}')



loss = 0.346: 100%|██████████| 15686/15686 [1:42:53<00:00,  2.54it/s]


In [16]:
# Sample from the trained model with dropout disabled and without autograd bookkeeping.
# NOTE(review): generate() may already switch modes internally — doing it here is then a no-op.
was_training = text_generator.training
text_generator.eval()
try:
    with torch.no_grad():
        sample_ids = text_generator.generate(max_new_tokens=800, top_k=15)
finally:
    # Restore the previous mode so re-running the training cell behaves as before.
    text_generator.train(was_training)

print(ds.vocabulary.decode(sample_ids.squeeze().tolist()))




PROSPERO:e that he did befre you not
A better score in spite, for a Christian peace!
What a sist rime, wound her that's bolted
Still slaughter'd by Viencio, so speak and hour!
he hath their tods: lether hear he make him lose; away. When.
But with an unkindness of raves me?
Nurse! Where is yond shall be here and supperfeCant
Are speaks that we should be? We shout they seeds,
Aid
In that I see the very boot!
From sure, as I sea--she hath a done of men,
His young stands upon the way of heaven heart,
Injury alives his wise; and when they must pace,--O,
Piter your quarres, you wounds beauty is mistard.--
She is to much my father's fair? What's the mater?
Dost thou an arm mourn; I quarrely would love,
And I feel out of the mount of heaven with ease?
Ratcliff! a very farmer of life,
In the meat


In [10]:
# Toy demo: a row-normalised lower-triangular matrix turns a matmul into a running average.
torch.manual_seed(42)
a = torch.ones(3, 3).tril()
a = a / a.sum(dim=1, keepdim=True)  # every row now sums to 1
b = torch.randint(0, 10, (3, 2)).float()
c = torch.matmul(a, b)  # row i of c is the mean of rows 0..i of b

for report in (f'{a = }', f'{b = }', f'{c = }'):
    print(report)

a = tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
b = tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
c = tensor([[2.0000, 7.0000],
        [4.0000, 5.5000],
        [4.6667, 5.3333]])


In [11]:
# Random toy input shared by the averaging experiments in the next cells.
torch.manual_seed(1337)
B, T, C = (4, 8, 2)  # batch, time, channels
x = torch.randn((B, T, C))
x.shape

torch.Size([4, 8, 2])

In [12]:
# version 1: explicit loops — xbow[b, t] is the mean of x[b, 0..t]
xbow = torch.zeros(B, T, C)
for b in range(B):
    for t in range(T):
        history = x[b, : t + 1]  # (t+1, C): step t and everything before it
        xbow[b, t] = history.mean(dim=0)

In [13]:
# version 2: the same running mean expressed as one matrix multiply
wei = torch.ones(T, T).tril()
wei = wei / wei.sum(dim=1, keepdim=True)  # rows sum to 1 -> uniform average over the past
# wei is (T, T) and is broadcast across the batch: (T, T) @ (B, T, C) -> (B, T, C)
xbow2 = torch.matmul(wei, x)
torch.allclose(xbow, xbow2)

True

In [14]:
# version 3: same weights again, this time produced by a softmax over masked scores
tril = torch.ones(T, T).tril()
# All-zero scores with -inf above the diagonal: softmax turns each row into a
# uniform distribution over the allowed (current and earlier) positions.
wei = torch.zeros(T, T).masked_fill(tril == 0, float('-inf'))
wei = torch.softmax(wei, dim=-1)
xbow3 = torch.matmul(wei, x)
torch.allclose(xbow2, xbow3)

True

In [None]:
# version 4: a single head of causal self-attention on random data
from torch import nn
torch.manual_seed(1337)

B, T, C = 4, 8, 32  # batch, time, channels

x = torch.randn(B, T, C)

head_size = 16

# Bias-free linear projections from the C input channels to head_size (H).
query = nn.Linear(C, head_size, bias=False)  # (C, H)
key = nn.Linear(C, head_size, bias=False)  # (C, H)
value = nn.Linear(C, head_size, bias=False)  # (C, H)

k, q = key(x), query(x)  # (B, T, H)
v = value(x)  # (B, T, H)

# Raw affinities between every query and every key (no 1/sqrt(head_size) scaling here).
wei = q @ k.transpose(-2, -1)  # (B, T, H) x (B, H, T) = (B, T, T)

tril = torch.tril(torch.ones(T, T))

# masked_fill is out-of-place: its result must be assigned back, otherwise the
# causal mask is silently dropped and every position attends to future tokens.
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=-1)  # each row is a distribution over positions <= t
out = wei @ v  # (B, T, T) @ (B, T, H) = (B, T, H)

out.shape


torch.Size([4, 8, 16])