In [1]:
#just to prototype the ideas

In [18]:
import torch
import torch.nn as nn
from torch.nn import functional as F

In [3]:
device = torch.device("mps" if torch.mps.is_available() else "cpu")  #from now on, will utilize the GPU
torch.mps.is_available() #is indeed available

True

In [4]:
#working with shakesperian texts dataset for this egs, but can be extended to any text dataset
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
#dataset of all the work of shakespear

--2025-01-03 09:54:53--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.110.133, 185.199.108.133, 185.199.109.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.110.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt.1’


2025-01-03 09:54:53 (6.29 MB/s) - ‘input.txt.1’ saved [1115394/1115394]



In [5]:
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

In [6]:
print("len of text: ", len(text))  #approx as 1.2M characters

len of text:  1115394


In [7]:
print(text[:1000])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us, the object of our misery, is as an
inventory to particularise their abundance; our
sufferance is a gain to them Let us revenge this with
our pikes, ere we become rakes: for the gods know I
speak this in hunger for bread, not in thirst for revenge.



In [8]:
#get all the unique chars used in the text
chars = sorted(list(set(text)))
vocab_size = len(chars)
print(''.join(chars))
print(f"vocabulary size: {vocab_size}")


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
vocabulary size: 65


In [9]:
#creating a mapping
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}
encode = lambda s: [stoi[ch] for ch in s]  #take in a string, give out its integer representation
decode = lambda list: ''.join([itos[i] for i in list] ) #take in a list of integers, give out its chars

In [10]:
encode("hello shakespear")

[46, 43, 50, 50, 53, 1, 57, 46, 39, 49, 43, 57, 54, 43, 39, 56]

In [11]:
data = torch.tensor(encode(text), dtype=torch.long, device=device) 
print(data.shape, data.dtype)
print(data[:100])

torch.Size([1115394]) torch.int64
tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44,
        53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39, 52, 63,
         1, 44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43, 39, 56,  1, 51, 43,  1,
        57, 54, 43, 39, 49,  8,  0,  0, 13, 50, 50, 10,  0, 31, 54, 43, 39, 49,
         6,  1, 57, 54, 43, 39, 49,  8,  0,  0, 18, 47, 56, 57, 58,  1, 15, 47,
        58, 47, 64, 43, 52, 10,  0, 37, 53, 59], device='mps:0')


In [12]:
n = int(0.9*len(data))
train_data = data[:n]
val_data = data[n:]

In [13]:
#in transformers, we work with context sizes from 1 all the way till block_size(whatever that is)
#previously in n-gram neural nets, we used to work with fixed context sizes, but this is not the case
#here, and it is also advantageous, as the model learns to work with different context sizes

block_size = 8
x = train_data[:block_size]
y = train_data[1:block_size+1]

for t in range(block_size):  #here t is the time step, as we work with each of the sub contexts
    #during training and that is done consecutively, over time, however there is another dimension 
    #to it, which is the batch dim, where the samples of the batch are handled in parallel
    context = x[:t+1]
    target = y[t]
    print(f"for context {context} what follows is {target}")

for context tensor([18], device='mps:0') what follows is 47
for context tensor([18, 47], device='mps:0') what follows is 56
for context tensor([18, 47, 56], device='mps:0') what follows is 57
for context tensor([18, 47, 56, 57], device='mps:0') what follows is 58
for context tensor([18, 47, 56, 57, 58], device='mps:0') what follows is 1
for context tensor([18, 47, 56, 57, 58,  1], device='mps:0') what follows is 15
for context tensor([18, 47, 56, 57, 58,  1, 15], device='mps:0') what follows is 47
for context tensor([18, 47, 56, 57, 58,  1, 15, 47], device='mps:0') what follows is 58


In [None]:
#introducing the batch dim
batch_size = 4

def get_batch(name):
    data = train_data if name == 'train' else val_data
    ix = torch.randint(0, data.size(0) - block_size, (batch_size,))
    context = torch.stack([data[i: i+block_size] for i in ix])
    target = torch.stack([data[i+1: i+1+block_size] for i in ix])
    return context, target

xb, yb = get_batch('train')
print(xb.shape)
print("inputs: ", xb)
print(yb.shape)
print("targets: ", yb)


#going through the batch and time dim 
for b in range(batch_size):
    for t in range(block_size):
        print(f"batch {b} context {xb[b, :t+1]} target {yb[b, t]}")


torch.Size([4, 8])
inputs:  tensor([[40, 43, 39, 56,  0, 33, 54, 53],
        [51, 39, 56, 41, 46,  1, 53, 52],
        [39, 51,  1, 53, 59, 58,  7,  7],
        [58, 46, 43, 43,  0, 32, 53,  1]], device='mps:0')
torch.Size([4, 8])
targets:  tensor([[43, 39, 56,  0, 33, 54, 53, 52],
        [39, 56, 41, 46,  1, 53, 52,  1],
        [51,  1, 53, 59, 58,  7,  7,  0],
        [46, 43, 43,  0, 32, 53,  1, 45]], device='mps:0')
batch 0 context tensor([40], device='mps:0') target 43
batch 0 context tensor([40, 43], device='mps:0') target 39
batch 0 context tensor([40, 43, 39], device='mps:0') target 56
batch 0 context tensor([40, 43, 39, 56], device='mps:0') target 0
batch 0 context tensor([40, 43, 39, 56,  0], device='mps:0') target 33
batch 0 context tensor([40, 43, 39, 56,  0, 33], device='mps:0') target 54
batch 0 context tensor([40, 43, 39, 56,  0, 33, 54], device='mps:0') target 53
batch 0 context tensor([40, 43, 39, 56,  0, 33, 54, 53], device='mps:0') target 52
batch 1 context tensor

In [15]:
print(xb)

tensor([[40, 43, 39, 56,  0, 33, 54, 53],
        [51, 39, 56, 41, 46,  1, 53, 52],
        [39, 51,  1, 53, 59, 58,  7,  7],
        [58, 46, 43, 43,  0, 32, 53,  1]], device='mps:0')


In [24]:
print(yb)

tensor([[43, 39, 56,  0, 33, 54, 53, 52],
        [39, 56, 41, 46,  1, 53, 52,  1],
        [51,  1, 53, 59, 58,  7,  7,  0],
        [46, 43, 43,  0, 32, 53,  1, 45]], device='mps:0')


In [42]:
#again, starting off with the bigram count model, again this is similar to the count model(neural net)
#which can learn too, where recall that had only the input and output layer and no hidden 
#layers, this is something similar too, where the embedding layer is essemtially
#the output layer, since it has 65 neurons which match the vocab size, so it can be 
#thought of as the logits/output layer, and hence it makes sense that we calculate the loss
#based on this directly

torch.manual_seed(1337)

class BigramLanguageModel(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)   #embedding dim is of
        #vocab size

    def forward(self, idx, targets=None):
        #idx is (B, T) and targets is also (B, T)
        logits = self.token_embedding_table(idx)  # (B, T, C)
        self.B, self.T, self.C = logits.shape
        logits = logits.view(self.B*self.T, self.C)
        
        if targets is None:
            loss = None
        else:
            targets = targets.view(self.B*self.T)
            loss = F.cross_entropy(logits, targets) #but for loss, it want it in form (B, C, T)
        return logits, loss
    
    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            logits, loss = self(idx)  #calls forward method
            logits = logits.view(self.B, self.T, self.C)[:, -1, :]  #(B, C)
            probs = F.softmax(logits, dim=-1)  #(B, C)
            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
            #append the sampled char to the idx
            idx = torch.cat([idx, idx_next], dim=1)  #(B, T+1)
        return idx
    
m =  BigramLanguageModel(vocab_size).to(device)  #passing the whole class to device, that way 
#individual tensors will be moved to the device as well
print(xb.shape)
logits, loss = m(xb, yb)  #calls the forawrd method as this class is a subclass of nn.Module and
#calling an instance of nn.module calls the .__call__() method which in turn calls the forward method
print(logits)
print("\n")
l = logits.view(m.B, m.T, m.C)[:, -1, :]
print(l)
print(l.shape)
print(loss)

torch.Size([4, 8])
tensor([[ 1.4311,  0.4160, -2.2246,  ...,  0.7330,  0.3551,  0.1472],
        [ 0.3323, -0.0872, -0.7470,  ..., -0.6716, -0.9572, -0.9594],
        [ 1.1513,  1.0539,  3.4105,  ..., -0.5686,  0.9079, -0.1701],
        ...,
        [ 0.7029,  1.4840,  0.1137,  ..., -0.5048,  0.8791,  0.4086],
        [-0.1324, -0.5489,  0.1024,  ..., -0.8599, -1.6050, -0.6985],
        [ 0.5978, -0.0514, -0.0646,  ..., -1.4649, -2.0555,  1.8275]],
       device='mps:0', grad_fn=<ViewBackward0>)


tensor([[-1.3237e-01, -5.4889e-01,  1.0244e-01, -6.9162e-01,  3.5075e-01,
          1.6147e+00,  1.8203e+00,  5.1224e-01,  1.5810e+00, -2.0063e+00,
         -1.2925e+00,  1.2681e-01,  1.1099e+00, -6.5921e-01,  8.0844e-01,
          1.9072e+00, -3.2599e-01, -3.4377e-01, -1.4415e+00, -1.8276e-01,
         -8.8043e-01, -6.1918e-01, -1.4047e+00, -8.5837e-01, -3.8297e-01,
         -5.3723e-01, -1.2176e+00, -1.9403e+00, -3.0937e-01,  1.7895e-01,
          1.2859e+00,  3.0392e-01,  1.8110e+00,  6.35

In [43]:
#the loss shld actually be -ln(1/65) which is 4.17

In [44]:
#generating text
print(decode(m.generate(torch.zeros((1, 1), dtype=torch.long, device=device), 100)[0].tolist()))
#works on batches, so the 0th index has to be indexed to


UNasE3QKdYMjKfxcq-PyQbRF.
jxuUfZWievNL:C&v-jkcECOIiyeg zbZAcQ?yObr&MkzeAmyFXSPHd,j&?oneOAvrFotKuLTDx


In [None]:
#notice what we do here, wihtin generate method we call the forward method again and again, each time
#with an extended context and then we sample the dist from the last token, eventually, the history
#which is the context that came before will be used and then this approach will make more sense

In [48]:
#training this model now
optimiser = torch.optim.AdamW(m.parameters(), lr=1e-3)  #basically train the embedding layer, so
#as to get finer embeddings, and hence better predictions. Its the same as the two layer model for
#the count bigram model, where the weights which connected the input and output layer were trained, but
#here instead of the weights, the embeddings are trained

In [49]:
batch_size = 32
for steps in range(40000):

    #get a batch of data
    xb, yb = get_batch('train')

    #evaluate the loss
    logits, loss = m(xb, yb)

    #backprop
    optimiser.zero_grad(set_to_none=True)
    loss.backward()

    #update
    optimiser.step()

    print(loss);


tensor(4.6485, device='mps:0', grad_fn=<NllLossBackward0>)
tensor(4.6507, device='mps:0', grad_fn=<NllLossBackward0>)
tensor(4.7109, device='mps:0', grad_fn=<NllLossBackward0>)
tensor(4.7115, device='mps:0', grad_fn=<NllLossBackward0>)
tensor(4.7078, device='mps:0', grad_fn=<NllLossBackward0>)
tensor(4.6761, device='mps:0', grad_fn=<NllLossBackward0>)
tensor(4.6929, device='mps:0', grad_fn=<NllLossBackward0>)
tensor(4.7817, device='mps:0', grad_fn=<NllLossBackward0>)
tensor(4.7173, device='mps:0', grad_fn=<NllLossBackward0>)
tensor(4.8193, device='mps:0', grad_fn=<NllLossBackward0>)
tensor(4.7402, device='mps:0', grad_fn=<NllLossBackward0>)
tensor(4.6960, device='mps:0', grad_fn=<NllLossBackward0>)
tensor(4.7699, device='mps:0', grad_fn=<NllLossBackward0>)
tensor(4.7050, device='mps:0', grad_fn=<NllLossBackward0>)
tensor(4.7788, device='mps:0', grad_fn=<NllLossBackward0>)
tensor(4.6514, device='mps:0', grad_fn=<NllLossBackward0>)
tensor(4.6445, device='mps:0', grad_fn=<NllLossBackward0

In [56]:
print(decode(m.generate(torch.zeros((1, 1), dtype=torch.long, device=device), 500)[0].tolist()))


F:
arameak malll wovireg bincoragatisthesiscipo ll icowis tit.
ANe past:
Whellarer f ourdave? y hathefigeiver ftll e agg t hen
MENGLARERIONTAns sashial otiall hef houme lle beowhamel y, yseland Orauke m.
Oferend aler mavathasl, g,
ARI V:

tofano othasof.
Foue horathet grs tourerk he.
JUSharicen han' swhodathaton so aradisuroure p w; I:
EDry bl,
S:
Mare gh t.

tl bererfonousiromuppardorld e'l me Ge
Nammater hin
NGoudw r se
Cathenom tst tareditesen te t be fongos t
Ty
BRBESloor hom's aramod urleui


self-attention

information only flows from previous tokens to the current ones, and the ones from the future cannot talk to the current token, and this is as we want to predict the future context, and getting any info
from the future would not serve the purpose

In [63]:
#working with a toy egs to see what is meant 
torch.manual_seed(1337)

B, T, C = 4, 8, 2   # batch, time, channels
x = torch.randn(B, T, C)
x.shape

torch.Size([4, 8, 2])

In [64]:
#a very simple version of attention, where the embedding of tth token is the mean of all the tokens
#before it in the sample, and the mean is taken along the same dims 
xbow = torch.zeros((B, T, C))
for b in range(B):
    for t in range(T):
        xprev = x[b, :t+1]  #also including the embedding vector for the tth token 
        xbow[b, t] = torch.mean(xprev, dim=0)

In [67]:
x[0]

tensor([[ 0.1808, -0.0700],
        [-0.3596, -0.9152],
        [ 0.6258,  0.0255],
        [ 0.9545,  0.0643],
        [ 0.3612,  1.1679],
        [-1.3499, -0.5102],
        [ 0.2360, -0.2398],
        [-0.9211,  1.5433]])

In [66]:
xbow[0]

tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
        [ 0.3525,  0.0545],
        [ 0.0688, -0.0396],
        [ 0.0927, -0.0682],
        [-0.0341,  0.1332]])

In [68]:
#so we see that since the 0th token doesnt really have any context before it, its vector after
#self-attention remains the same, but the rest have changed along their respective dims

In [72]:
#more efficient way rather than to use for loops is use torch.tril
#as an egs
torch.manual_seed(42)
a = torch.tril(torch.ones(3, 3))
b = torch.randint(0, 10, (3, 2)).float()
c = a @ b
print('a', a)
print('b', b)
print('c', c)

a tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])
b tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
c tensor([[ 2.,  7.],
        [ 8., 11.],
        [14., 16.]])


In [73]:
#so we see that we get the same addition of past vectors into the current one as before, now we 
#just need to take its mean

In [78]:
torch.manual_seed(42)
a = torch.tril(torch.ones(3, 3)) / torch.sum(a, dim=1, keepdim=True)
b = torch.randint(0, 10, (3, 2)).float()
c = a @ b
print('a', a)
print('b', b)
print('c', c)

a tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
b tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
c tensor([[2.0000, 7.0000],
        [4.0000, 5.5000],
        [4.6667, 5.3333]])


In [79]:
#so its the same as before, but obtained much more efficiently, however, the way the attention mechanism
#is implemented itself is not that good, now making the same egs more efficient

In [87]:
torch.manual_seed(1337)
wei = torch.tril(torch.ones((T, T)))
wei /= torch.sum(wei, dim=1, keepdim=True)
xbow2 = wei @ x   # (T, T) @ (B, T, C) ---> (B, T, T) @ (B, T, C) after broadcasting which inturn
#gives (B, T, C) as the dim of xbow2
xbow2  #the same result

tensor([[[ 0.1808, -0.0700],
         [-0.0894, -0.4926],
         [ 0.1490, -0.3199],
         [ 0.3504, -0.2238],
         [ 0.3525,  0.0545],
         [ 0.0688, -0.0396],
         [ 0.0927, -0.0682],
         [-0.0341,  0.1332]],

        [[ 1.3488, -0.1396],
         [ 0.8173,  0.4127],
         [-0.1342,  0.4395],
         [ 0.2711,  0.4774],
         [ 0.2421,  0.0694],
         [ 0.0084,  0.0020],
         [ 0.0712, -0.1128],
         [ 0.2527,  0.2149]],

        [[-0.6631, -0.2513],
         [ 0.1735, -0.0649],
         [ 0.1685,  0.3348],
         [-0.1621,  0.1765],
         [-0.2312, -0.0436],
         [-0.1015, -0.2855],
         [-0.2593, -0.1630],
         [-0.3015, -0.2293]],

        [[ 1.6455, -0.8030],
         [ 1.4985, -0.5395],
         [ 0.4954,  0.3420],
         [ 1.0623, -0.1802],
         [ 1.1401, -0.4462],
         [ 1.0870, -0.4071],
         [ 1.0430, -0.1299],
         [ 1.1138, -0.1641]]])