In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F

torch.manual_seed(1337)


<torch._C.Generator at 0x7f86d0505190>

In [8]:
# Read the whole training corpus into memory as one string.
# The path is relative, so 'input.txt' must be in the kernel's working directory.
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

In [9]:
# Sanity check: report the corpus size in characters.
print("length of dataset in characters: ", len(text))

length of dataset in characters:  1115394


In [10]:
# Character-level vocabulary: every distinct character in the corpus, sorted
# so that the character -> id assignment is deterministic across runs.
# sorted() accepts any iterable and returns a new list, so the set is passed directly.
chars = sorted(set(text))
vocab_size = len(chars)
print(''.join(chars))
print(vocab_size)


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
65


In [13]:
# Lookup tables between characters and integer token ids.
# Ids follow the position of each character in `chars`.
stoi = {ch: i for i, ch in enumerate(chars)}
itos = dict(enumerate(chars))


def encode(s):
    """Convert a string into a list of integer token ids.

    Raises:
        KeyError: if `s` contains a character that is not a key of `stoi`.
    """
    return [stoi[c] for c in s]


def decode(n):
    """Convert an iterable of integer token ids back into a string.

    Raises:
        KeyError: if an id is not a key of `itos`.
    """
    return ''.join(itos[i] for i in n)


# Encode the whole corpus once as a 1-D int64 tensor of token ids.
data = torch.tensor(encode(text), dtype=torch.long)


Split the data into training and validation sets

In [14]:
# Positional split (no shuffling): the first 90% of the token stream is used
# for training and the remaining 10% is held out for validation.
n = int(0.9*len(data))
train_data = data[:n]
val_data = data[n:]

In [15]:
# Context length: the number of tokens in one input sequence.
block_size = 8
# Show block_size + 1 consecutive tokens: one extra token is needed so that
# each of the block_size input positions has a following token as its target.
train_data[:block_size +1]

tensor([18, 47, 56, 57, 58,  1, 15, 47, 58])

In [16]:
batch_size = 4  # number of independent sequences sampled per batch
block_size = 8  # number of tokens in each sequence (context length)


def get_batch(split):
    """Sample a random batch of input/target sequences.

    `split` selects the source: 'train' reads from `train_data`; any other
    value reads from `val_data`.

    Returns `(x, y)`: two stacked tensors with `batch_size` rows of
    `block_size` tokens each, where every row of `y` is the matching row of
    `x` shifted one position to the right in the source.
    """
    if split == 'train':
        source = train_data
    else:
        source = val_data
    # Start offsets are drawn from [0, len(source) - block_size), which leaves
    # room for the target window that ends one token later.
    starts = torch.randint(len(source) - block_size, (batch_size,))
    inputs, targets = [], []
    for start in starts:
        inputs.append(source[start:start + block_size])
        targets.append(source[start + 1:start + block_size + 1])
    return torch.stack(inputs), torch.stack(targets)

In [36]:
# Draw one random batch from the training split: inputs xb and next-token targets yb.
xb,yb = get_batch('train')

In [18]:
# Display the sampled batch and its shape; each row of the targets is the
# corresponding row of the inputs shifted one token ahead.
print('inputs:\n',xb,xb.shape)
print('targets:\n',yb,yb.shape)

inputs:
 tensor([[24, 43, 58,  5, 57,  1, 46, 43],
        [44, 53, 56,  1, 58, 46, 39, 58],
        [52, 58,  1, 58, 46, 39, 58,  1],
        [25, 17, 27, 10,  0, 21,  1, 54]]) torch.Size([4, 8])
targets:
 tensor([[43, 58,  5, 57,  1, 46, 43, 39],
        [53, 56,  1, 58, 46, 39, 58,  1],
        [58,  1, 58, 46, 39, 58,  1, 46],
        [17, 27, 10,  0, 21,  1, 54, 39]]) torch.Size([4, 8])


In [37]:
# NOTE(review): the saved output shows 32 rows although batch_size is 4 in the
# cell that defines get_batch. Judging by the execution counts, xb was
# re-sampled after the training cell set batch_size = 32. Restart the kernel
# and run all cells in order to confirm the expected shape.
xb.shape

torch.Size([32, 8])

In [19]:
# Spell out every training example packed into the batch: for row b and
# position t, the context is the first t+1 tokens of that row and the token
# to predict is yb[b, t].
for b in range(batch_size):
    for t in range(block_size):
        context = xb[b,:t+1]
        target = yb[b,t]
        print(f'when context is {context.tolist()} the target is: {target}')

when context is [24] the target is: 43
when context is [24, 43] the target is: 58
when context is [24, 43, 58] the target is: 5
when context is [24, 43, 58, 5] the target is: 57
when context is [24, 43, 58, 5, 57] the target is: 1
when context is [24, 43, 58, 5, 57, 1] the target is: 46
when context is [24, 43, 58, 5, 57, 1, 46] the target is: 43
when context is [24, 43, 58, 5, 57, 1, 46, 43] the target is: 39
when context is [44] the target is: 53
when context is [44, 53] the target is: 56
when context is [44, 53, 56] the target is: 1
when context is [44, 53, 56, 1] the target is: 58
when context is [44, 53, 56, 1, 58] the target is: 46
when context is [44, 53, 56, 1, 58, 46] the target is: 39
when context is [44, 53, 56, 1, 58, 46, 39] the target is: 58
when context is [44, 53, 56, 1, 58, 46, 39, 58] the target is: 1
when context is [52] the target is: 58
when context is [52, 58] the target is: 1
when context is [52, 58, 1] the target is: 58
when context is [52, 58, 1, 58] the target

In [26]:
class BigramLanguageModel(nn.Module):
    """Bigram language model: the logits for the next token depend only on the current token.

    The only parameters are a (vocab_size, vocab_size) embedding table; row i
    holds the unnormalized scores for the token that follows token i.
    """

    def __init__(self, vocab_size):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):
        """Look up next-token logits and, if targets are given, the loss.

        Args:
            idx: (B, T) integer tensor of token ids.
            targets: optional (B, T) integer tensor of next-token ids.

        Returns:
            A `(logits, loss)` tuple. Without targets, `logits` has shape
            (B, T, C) and `loss` is None. With targets, `logits` is returned
            flattened to (B*T, C) and `loss` is the cross-entropy between the
            flattened logits and targets.
        """
        logits = self.token_embedding_table(idx)

        if targets is None:
            return logits, None

        # F.cross_entropy expects the class dimension second, so the batch and
        # time dimensions are merged into one.
        B, T, C = logits.shape
        logits = logits.view(B * T, C)
        targets = targets.view(B * T)
        loss = F.cross_entropy(logits, targets)
        return logits, loss

    @torch.no_grad()  # sampling never backpropagates; skip building an autograd graph per step
    def generate(self, idx, max_new_tokens):
        """Autoregressively sample `max_new_tokens` tokens after `idx`.

        Args:
            idx: (B, T) integer tensor holding the starting context.
            max_new_tokens: number of tokens to append to every row.

        Returns:
            (B, T + max_new_tokens) integer tensor: the context followed by
            the sampled tokens.
        """
        for _ in range(max_new_tokens):
            logits, _ = self(idx)
            # Only the prediction at the last position is needed.
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx


In [28]:
# Instantiate the model and run one forward pass on the current batch.
# Because targets are supplied, forward() returns the logits flattened to
# (B*T, vocab_size) together with the cross-entropy loss.
# For reference, uniform predictions over vocab_size classes would give a loss
# of ln(vocab_size), about 4.17 for the 65 characters found in the corpus.
m = BigramLanguageModel(vocab_size)
logits, loss = m(xb,yb)
print(logits.shape)
print(loss)

torch.Size([32, 65])
tensor(4.7288, grad_fn=<NllLossBackward0>)


In [31]:
# Start generation from a single sequence containing only token id 0, sample
# 100 new tokens, and decode the resulting ids back to text.
# The optimizer and training loop come later in the notebook, so at this
# point the model is untrained and the sample is expected to look random.
idx = torch.zeros((1,1),dtype = torch.long)
print(decode((m.generate(idx,max_new_tokens = 100)[0]).tolist()))


dJKVkNxiMeNYjdOIh3yqDkUBH
FZX;cTbVpwH
UvCEFzXYVjEI'V-?apGhR chtuexzJSIU
:uipGfLZ3i$mpXllY:JMFlj;
XRG


In [33]:
# Adam optimizer over all parameters of the model, with learning rate 1e-3.
optimizer = torch.optim.Adam(m.parameters(),lr=1e-3)

In [35]:
# Use a larger batch for training than for the demonstration cells.
# NOTE(review): this rebinds the global batch_size that get_batch reads, so
# any earlier cell re-run after this one samples 32 rows instead of 4.
batch_size = 32

# 100 optimization steps: sample a batch, compute the loss, backpropagate, update.
for steps in range(100):
    xb,yb = get_batch('train')
    logits, loss = m(xb,yb)
    # Clear gradients left from the previous step before backpropagating.
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    
    # NOTE(review): prints the loss at every step (100 output lines); each
    # value is measured on a different random batch, so it is noisy.
    print(loss.item())

4.688744068145752
4.542750358581543
4.714441299438477
4.519054412841797
4.748159408569336
4.4445695877075195
4.573859691619873
4.6024169921875
4.7254252433776855
4.561741352081299
4.712865829467773
4.502773284912109
4.5777764320373535
4.640673637390137
4.620965957641602
4.714259147644043
4.505113124847412
4.548577308654785
4.6977691650390625
4.628480434417725
4.603488445281982
4.700981140136719
4.541628360748291
4.609816074371338
4.596613883972168
4.518304824829102
4.6689300537109375
4.683122158050537
4.593369483947754
4.618943214416504
4.533949375152588
4.518628120422363
4.621417999267578
4.596273422241211
4.515584945678711
4.596221446990967
4.623411178588867
4.65731143951416
4.599900722503662
4.621892929077148
4.565545558929443
4.620325565338135
4.568120002746582
4.558705806732178
4.532720565795898
4.605710983276367
4.5209574699401855
4.66516637802124
4.49791145324707
4.650701522827148
4.612490177154541
4.6128129959106445
4.540727615356445
4.57422399520874
4.588643550872803
4.6648545

Self-attention starts here

In [40]:
# Toy input for the averaging experiment: fix the seed so the random tensor
# is reproducible, then draw a (B, T, C) tensor of standard-normal values.
torch.manual_seed(42)
B,T,C = 4,8,2  # presumably batch, time (sequence position), channels
x = torch.randn(B,T,C)
x.shape

torch.Size([4, 8, 2])

In [42]:
# For every batch element b and position t, store in xbow[b, t] the
# per-channel mean of x[b, 0..t], i.e. the current and all earlier positions.
xbow = torch.zeros((B,T,C))
for b in range(B):
    for t in range(T):
        xprev = x[b,:t+1]  # shape (t+1, C)
        xbow[b,t] = torch.mean(xprev,dim = 0)