In [2]:
import torch
import torch.nn as nn
from torch.nn import functional as F

# Target device. Later cells create tensors without an explicit device, so keep
# this in sync with torch's default device (CPU unless changed elsewhere).
device = torch.device('cpu')


In [3]:
# Load the raw training corpus as one string. The encoding is explicit so the
# result does not depend on the platform's locale default.
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

In [4]:
# Character-level vocabulary: every distinct character of `text`, in sorted order.
chars = sorted(set(text))
vocab_size = len(chars)
print(''.join(chars), vocab_size, sep='\n')


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
65


In [5]:
# Character-level tokenizer: each character maps to its index in `chars`, and back.
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for ch, i in stoi.items()}


def encode(s):
    """Encode string `s` as a list of integer token ids, one per character.

    Raises KeyError for a character that is not in the vocabulary.
    """
    return [stoi[c] for c in s]


def decode(l):
    """Decode an iterable `l` of integer token ids back into a string.

    Raises KeyError for an id that is not in the vocabulary.
    """
    return ''.join([itos[i] for i in l])


print(encode('hello world !'))
print(decode(encode('hello world !')))

[46, 43, 50, 50, 53, 1, 61, 53, 56, 50, 42, 1, 2]
hello world !


In [9]:
# Turn the entire corpus into one flat 1-D tensor of int64 token ids.
data = torch.tensor(encode(text), dtype=torch.int64)
print(data.shape, data.dtype)

torch.Size([1115394]) torch.int64


In [10]:
# Hold out the tail of the token stream for validation; the first 90% is for training.
train_ratio = 0.9
n = int(train_ratio * len(data))
train_data, val_data = data[:n], data[n:]

In [11]:
block_size = 8  # Context length: tokens per training example.
# One example window plus the extra token that serves as the final target.
train_data[:block_size + 1]

tensor([18, 47, 56, 57, 58,  1, 15, 47, 58])

In [18]:
# Inputs are the first block_size tokens; targets are the same window shifted by one.
x = train_data[:block_size]
y = train_data[1:block_size + 1]

# Every prefix of x is its own training example, labeled with the token that follows it.
for t, target in enumerate(y):
    context = x[:t + 1]
    print(f"Input: {context} --> Target: {target}")

Input: tensor([18]) --> Target: 47
Input: tensor([18, 47]) --> Target: 56
Input: tensor([18, 47, 56]) --> Target: 57
Input: tensor([18, 47, 56, 57]) --> Target: 58
Input: tensor([18, 47, 56, 57, 58]) --> Target: 1
Input: tensor([18, 47, 56, 57, 58,  1]) --> Target: 15
Input: tensor([18, 47, 56, 57, 58,  1, 15]) --> Target: 47
Input: tensor([18, 47, 56, 57, 58,  1, 15, 47]) --> Target: 58


In [29]:
torch.manual_seed(1337)
# Independent sequences per batch, and context length (tokens per sequence).
batch_size, block_size = 4, 8

def get_batch(split, batch_size=None, block_size=None):
    """Sample a random batch of (input, target) windows from one data split.

    Args:
        split: 'train' selects ``train_data``; any other value selects ``val_data``.
        batch_size: number of windows to sample. When None, the module-level
            ``batch_size`` is read at call time, so later reassignments still apply.
        block_size: tokens per window. When None, the module-level ``block_size``
            is read at call time.

    Returns:
        (x, y): tensors of shape (batch_size, block_size). ``y`` is ``x`` shifted
        one position ahead in the underlying data, so ``y[:, t]`` is the token
        that follows ``x[:, t]``.

    Raises:
        ValueError: if the chosen split has no more than ``block_size`` tokens.
    """
    data = train_data if split == 'train' else val_data
    if batch_size is None:
        batch_size = globals()['batch_size']
    if block_size is None:
        block_size = globals()['block_size']
    if len(data) <= block_size:
        raise ValueError(
            f"split {split!r} has {len(data)} tokens; need more than block_size={block_size}"
        )
    # Window starts are drawn from [0, len(data) - block_size) so that the target
    # window (shifted by one) still ends inside `data`. The second argument is the
    # output size, giving a 1-D tensor of batch_size start indices.
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i + block_size] for i in ix])
    y = torch.stack([data[i + 1:i + block_size + 1] for i in ix])
    return x, y

# Draw one training batch and show the input/target shapes.
xb, yb = get_batch('train')
print(xb.shape, yb.shape, sep='\n')

torch.Size([4, 8])
torch.Size([4, 8])


In [65]:
# Bigram language model.
import torch
from torch import nn
import torch.nn.functional as F

torch.manual_seed(1337)

class BigramLanguageModel(nn.Module):
    """Bigram language model: next-token logits depend only on the current token.

    A single ``vocab_size x vocab_size`` embedding table maps each token id
    directly to a row of logits over the vocabulary.
    """

    def __init__(self, vocab_size):
        super().__init__()
        # Row i holds the (unnormalized) next-token logits for token id i.
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, optionally, the cross-entropy loss.

        Args:
            idx: integer token ids of shape (B, T).
            targets: optional integer token ids of shape (B, T).

        Returns:
            (logits, loss). Without targets, logits has shape (B, T, C) and loss
            is None. With targets, logits is returned flattened to (B*T, C), the
            layout ``F.cross_entropy`` expects, and loss is a scalar tensor.
        """
        logits = self.token_embedding_table(idx)  # (B, T, C)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B * T, C)
            # reshape (not view) so non-contiguous targets (e.g. transposed or
            # strided slices) are accepted too.
            targets = targets.reshape(B * T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens):
        """Sample ``max_new_tokens`` tokens autoregressively and append them to ``idx``.

        Runs under ``torch.no_grad()``: sampling needs no gradients, so no
        autograd graph is built across the generation steps.

        Args:
            idx: integer token ids of shape (B, T) used as the starting context.
            max_new_tokens: number of tokens to sample.

        Returns:
            Tensor of shape (B, T + max_new_tokens).
        """
        for _ in range(max_new_tokens):
            logits = self(idx)[0]                               # (B, T, C)
            logits = logits[:, -1, :]                           # last time step -> (B, C)
            probs = F.softmax(logits, dim=-1)                   # (B, C)
            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
            idx = torch.cat((idx, idx_next), dim=1)             # (B, T+1)

        return idx

m = BigramLanguageModel(vocab_size)
# With targets supplied, the returned logits are flattened to (B*T, C).
out, loss = m(xb, yb)
print(out.shape)
print(loss)
# Sample 100 tokens from the untrained model, starting from a single token id 0.
prompt = torch.zeros((1, 1), dtype=torch.long)
print(decode(m.generate(prompt, max_new_tokens=100)[0].tolist()))

torch.Size([32, 65])
tensor(4.8786, grad_fn=<NllLossBackward0>)

Sr?qP-QWktXoL&jLDJgOLVz'RIoDqHdhsV&vLLxatjscMpwLERSPyao.qfzs$Ys$zF-w,;eEkzxjgCKFChs!iWW.ObzDnxA Ms$3


In [66]:
optimizer = torch.optim.AdamW(params=m.parameters(), lr=1e-3)  # AdamW over all model parameters.

In [75]:
batch_size = 32  # Also picked up by get_batch() through the module-level name.
max_iters = 10000  # Number of optimization steps.

for _ in range(max_iters):
    # Sample a fresh training batch, then take one AdamW step on its loss.
    xb, yb = get_batch('train')
    logits, loss = m(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

# Loss on the final training batch only: a noisy single-batch estimate.
print(loss.item())

2.640730381011963
2.5262227058410645
2.4100255966186523
2.41798996925354
2.4241445064544678
2.5674972534179688
2.4599177837371826
2.4021995067596436
2.3905389308929443
2.456805467605591
2.516369342803955
2.4208834171295166
2.4297196865081787
2.574061393737793
2.4604787826538086
2.4567244052886963
2.343668222427368
2.4841983318328857
2.485570192337036
2.447383403778076
2.512507438659668
2.3928565979003906
2.5701467990875244
2.4118235111236572
2.600541591644287
2.318314552307129
2.3499059677124023
2.5551655292510986
2.510410785675049
2.457348108291626
2.4669981002807617
2.336918830871582
2.3250463008880615
2.473686933517456
2.5458950996398926
2.540571928024292
2.4438090324401855
2.3987503051757812
2.519043207168579
2.4559662342071533
2.49477481842041
2.4640426635742188
2.493175506591797
2.373640775680542
2.3916642665863037
2.4220383167266846
2.4872374534606934
2.3661530017852783
2.483182191848755
2.4032347202301025
2.4344377517700195
2.485018253326416
2.365065336227417
2.379512071609497


In [84]:
# Sample 300 tokens from the trained model, starting from a single token id 0.
generated = m.generate(torch.zeros((1, 1), dtype=torch.long), max_new_tokens=300)
print(decode(generated[0].tolist()))


Thal tashe
Hod alin f bo cortreruto kigul f f fome, cth oronovou!
Wenky te ll.
JULous lloule drcr, o ard.
r:
Faworse ad bu p, puss s bleghit.
Youro n f ws.
IOulepalenedars OKENINo bead
DUThas hates t anome g me sis!-
CHARI alan homyse,
d ous k.
DUDUSheavey a,
Gaim pend, r
MESa s inense hile oy, tee'


# Math tricks for self-attention

In [132]:
torch.manual_seed(1337)
B, T, C = 4, 8, 2  # batch, time, channels
x = torch.randn((B, T, C))

In [133]:
# Version 1: explicit loops.
# xbow[b, t] is the mean of x[b, :t+1], i.e. of positions 0..t of sequence b.
xbow = torch.zeros((B, T, C))
for b in range(B):
    for t in range(T):
        xprev = x[b, :t + 1]  # --> (t+1, C)
        xbow[b, t] = torch.mean(xprev, 0)

In [136]:
# Version 2: the same running mean as a single matmul with a row-normalized
# lower-triangular weight matrix (row t averages positions 0..t).
wei = torch.tril(torch.ones(T, T))
wei = wei / wei.sum(1, keepdim=True)
xbow2 = wei @ x  # (T,T) @ (B,T,C) --> broadcast over B --> (B,T,C)
# Check that the vectorized form reproduces Version 1.
assert torch.allclose(xbow, xbow2, atol=1e-6)

In [161]:
# Version 3: the same weights built with softmax over masked affinities.
tril = torch.tril(torch.ones(T, T))
wei = torch.zeros((T, T))  # Affinity matrix (all zeros -> uniform weights).
# Masking future positions to -inf gives them zero weight after the softmax.
wei = wei.masked_fill(tril == 0, float("-inf"))
wei = F.softmax(wei, dim=-1)
xbow3 = wei @ x
# Check that the softmax form reproduces Version 1.
assert torch.allclose(xbow, xbow3, atol=1e-6)

In [117]:
# Toy example: left-multiplying by a row-normalized lower-triangular matrix makes
# row i of the result the mean of rows 0..i of b.
torch.manual_seed(42)
a = torch.tril(torch.ones(3, 3))
a = a / torch.sum(a, 1, keepdim=True)
b = torch.randint(0, 10, (3, 2)).float()
c = a @ b
print(f"a=\n{a}\n--")
print(f"b=\n{b}\n--")
print(f"c=\n{c}")

a=
tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
--
b=
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
--
b=
tensor([[2.0000, 7.0000],
        [4.0000, 5.5000],
        [4.6667, 5.3333]])
