In [1]:
text = ""
with open("input.txt", "r", encoding="utf-8") as f:
    text = f.read()

In [2]:
len(text)

1115394

In [3]:
chars = sorted(list(set(text)))
vocab_size = len(chars)
print("".join(chars))
print(vocab_size)


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
65


In [4]:
stoi = {ch:i for i, ch in enumerate(chars)}
itos = {i:ch for i, ch in enumerate(chars)}
encode = lambda x : [stoi[i] for i in x]
decode = lambda x : "".join(itos[i] for i in x)
print(encode("hello there"))
print(decode(encode("hello there")))

[46, 43, 50, 50, 53, 1, 58, 46, 43, 56, 43]
hello there


In [5]:
import torch
data = torch.tensor(encode(text), dtype=torch.long)
print(data.shape, data.dtype)
print(data[:1000])

torch.Size([1115394]) torch.int64
tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44,
        53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39, 52, 63,
         1, 44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43, 39, 56,  1, 51, 43,  1,
        57, 54, 43, 39, 49,  8,  0,  0, 13, 50, 50, 10,  0, 31, 54, 43, 39, 49,
         6,  1, 57, 54, 43, 39, 49,  8,  0,  0, 18, 47, 56, 57, 58,  1, 15, 47,
        58, 47, 64, 43, 52, 10,  0, 37, 53, 59,  1, 39, 56, 43,  1, 39, 50, 50,
         1, 56, 43, 57, 53, 50, 60, 43, 42,  1, 56, 39, 58, 46, 43, 56,  1, 58,
        53,  1, 42, 47, 43,  1, 58, 46, 39, 52,  1, 58, 53,  1, 44, 39, 51, 47,
        57, 46, 12,  0,  0, 13, 50, 50, 10,  0, 30, 43, 57, 53, 50, 60, 43, 42,
         8,  1, 56, 43, 57, 53, 50, 60, 43, 42,  8,  0,  0, 18, 47, 56, 57, 58,
         1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 18, 47, 56, 57, 58,  6,  1, 63,
        53, 59,  1, 49, 52, 53, 61,  1, 15, 39, 47, 59, 57,  1, 25, 39, 56, 41,
      

In [6]:
n = int(0.9*len(data))
train_data = data[:n]
val_data = data[n:]


In [7]:
torch.manual_seed(1337)
batch_size = 4
block_size = 8

def get_batch(sample):
    data = train_data if sample == "train" else val_data
    ix = torch.randint(len(data)-block_size, size=(batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+1+block_size] for i in ix])
    return x, y

In [8]:
xb, yb = get_batch("train")
for b in range(batch_size):
    for t in range(block_size):
        inp = xb[b, :t+1]
        target = yb[b, t]
        print(f"When input is {inp} target is {target}")


When input is tensor([24]) target is 43
When input is tensor([24, 43]) target is 58
When input is tensor([24, 43, 58]) target is 5
When input is tensor([24, 43, 58,  5]) target is 57
When input is tensor([24, 43, 58,  5, 57]) target is 1
When input is tensor([24, 43, 58,  5, 57,  1]) target is 46
When input is tensor([24, 43, 58,  5, 57,  1, 46]) target is 43
When input is tensor([24, 43, 58,  5, 57,  1, 46, 43]) target is 39
When input is tensor([44]) target is 53
When input is tensor([44, 53]) target is 56
When input is tensor([44, 53, 56]) target is 1
When input is tensor([44, 53, 56,  1]) target is 58
When input is tensor([44, 53, 56,  1, 58]) target is 46
When input is tensor([44, 53, 56,  1, 58, 46]) target is 39
When input is tensor([44, 53, 56,  1, 58, 46, 39]) target is 58
When input is tensor([44, 53, 56,  1, 58, 46, 39, 58]) target is 1
When input is tensor([52]) target is 58
When input is tensor([52, 58]) target is 1
When input is tensor([52, 58,  1]) target is 58
When inpu

In [9]:
import torch
import torch.nn as nn
from torch.nn import functional as F
class BiagramModel(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.tokenEmbeddingTable = nn.Embedding(vocab_size, vocab_size)
    
    def forward(self, idx, targets=None):
        logits = self.tokenEmbeddingTable(idx)
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss
    
    def train(self, n_iter):
        optim = torch.optim.AdamW(self.parameters(), lr=1e-3)
        for _ in range(n_iter):
            xb, yb = get_batch("train")
            logits, loss = self(xb, yb)
            optim.zero_grad(set_to_none=True)
            loss.backward()
            optim.step()
            print(loss.item()) if _ % 200 == 0 else None


    def generate(self, idx, max_new_tokens):
        for i in range(max_new_tokens):
            logits, loss = self(idx)
            logits = logits[:,-1,:]
            probs = F.softmax(logits, dim=-1)
            new_idx = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, new_idx),dim=1)
        return idx


model = BiagramModel(vocab_size)
logits, loss = model(xb, yb)
print(logits.shape, loss)
idx = torch.zeros((1,1), dtype=torch.long)
generated_idx = model.generate(idx, max_new_tokens=100)[0].tolist()
print(decode(generated_idx))

print("\nTraining starts..")
model.train(100000)
print("\nTraining finished!")
idx = torch.zeros((1,1), dtype=torch.long)
generated_idx = model.generate(idx, max_new_tokens=100)[0].tolist()
print(decode(generated_idx))



torch.Size([32, 65]) tensor(5.0364, grad_fn=<NllLossBackward0>)

lfJeukRuaRJKXAYtXzfJ:HEPiu--sDioi;ILCo3pHNTmDwJsfheKRxZCFs
lZJ XQc?:s:HEzEnXalEPklcPU cL'DpdLCafBheH

Training starts..
4.873950481414795
4.295483112335205
4.587007522583008
4.13039493560791
4.277191638946533
4.198069095611572
3.7820565700531006
3.751546621322632
3.958277463912964
3.815798282623291
3.5388879776000977
3.47265362739563
3.478971481323242
3.308586835861206
3.5024595260620117
3.5213165283203125
3.2738354206085205
2.8871004581451416
2.7784464359283447
3.138400077819824
3.103850841522217
2.9894638061523438
2.755727529525757
2.6786446571350098
2.6702511310577393
2.8958916664123535
2.62581467628479
2.770490884780884
2.591963052749634
2.7645633220672607
2.777736186981201
2.386279582977295
2.7074332237243652
2.4616174697875977
2.9699602127075195
2.505662202835083
2.763460159301758
2.492985486984253
2.7798407077789307
2.212958812713623
2.3677852153778076
2.656628131866455
2.394922971725464
2.439276933670044
2.5632700

In [10]:
idx = torch.zeros((1,1), dtype=torch.long)
generated_idx = model.generate(idx, max_new_tokens=500)[0].tolist()
print(decode(generated_idx))


N om tond haindrghory seed lyosthangheand y t my
HA:
heemome; se sis, at!
IINes h ad uor
MInds, ariengo in nt lerrin,
S:
KINUpy d t: wes.
thos ure chiowinde t IXEn n urs lld d ly in;
Je a lds omys, TETht pe ghestonour weveg; wsckistow ghettstond, bamy, tathosplistorloteryois oupeear't y ishatwheas me;
AURY: ignjowongazehy wf incrus om s s.
WANG s thoth her hthen l therefotentishey cado ad kep:
DUS:
SThe nthilime geloressevis;
NDouin ty be thod theathe fol a
Four.
ICOfin mace g ber d:
Ther otheat


## Self Attention Block

In [13]:
torch.manual_seed(1337)
B, T, C = 4, 8, 32
x = torch.rand(B,T,C)

head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)

k = key(x)
q = query(x)

wei = q @ k.transpose(-2,-1)

tril = torch.tril(torch.ones(T,T))
wei = wei.masked_fill(tril==0, value=float("-inf"))
wei = F.softmax(wei, dim=-1)

v = value(x)

out = wei @ v

print(out.shape)
print(wei)

torch.Size([4, 8, 16])
tensor([[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.4409, 0.5591, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.2975, 0.3373, 0.3652, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.2211, 0.2898, 0.2236, 0.2654, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.1832, 0.2163, 0.1954, 0.2437, 0.1614, 0.0000, 0.0000, 0.0000],
         [0.1330, 0.2227, 0.1784, 0.2159, 0.1044, 0.1456, 0.0000, 0.0000],
         [0.1283, 0.1367, 0.1385, 0.1522, 0.1083, 0.1341, 0.2021, 0.0000],
         [0.1064, 0.1332, 0.1265, 0.1445, 0.0940, 0.1200, 0.1231, 0.1524]],

        [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.4150, 0.5850, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.2313, 0.3588, 0.4098, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.2210, 0.2636, 0.2829, 0.2324, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.1771, 0.2003, 0.2343, 0.2048, 0.1834, 0.0000, 0.0000, 0.0000],


In [2]:
import torch
torch.cuda.is_available()

True