<a href="https://colab.research.google.com/github/atonui/swahili-gpt/blob/main/swa_gpt.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Swahili-GPT
- A simple character-level, Transformer-based LLM trained on a Swahili dataset.

In [7]:
# load the training dataset
!wget https://raw.githubusercontent.com/atonui/pds/refs/heads/main/swahili_data/train.txt

--2025-08-26 14:08:07--  https://raw.githubusercontent.com/atonui/pds/refs/heads/main/swahili_data/train.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 7564413 (7.2M) [text/plain]
Saving to: ‘train.txt’


2025-08-26 14:08:07 (56.1 MB/s) - ‘train.txt’ saved [7564413/7564413]



In [8]:
# read it to inspect it
with open('train.txt', 'r') as f:
    text = f.read()

In [9]:
print('Length of dataset in characters: ', len(text))

Length of dataset in characters:  7522342


In [10]:
print(text[:1000])

﻿ taarifa hiyo ilisema kuwa ongezeko la joto la maji juu ya wastani katikati ya bahari ya  inaashiria kuwepo kwa mvua za el nino  hadi mwishoni mwa april ishirini moja sifuri imeelezwa kuwa ongezeko la joto magharibi mwa bahari ya hindi linatarajiwa kuhamia katikati ya bahari hiyo hali ambayo itasababisha pepo kutoka kaskazini mashariki kuvuma kuelekea bahari ya hindi 
 aidha ilisema kuwa mwelekeo wa kupungua kwa joto kusini mashariki mwa bahari ya atlantic  kusababisha pepo kutoka magharibi kuvuma kuelekea magharibi mwa tanzania katika maeneo ya ziwa victoria 
 mwelekeo wa mvua wa septemba hadi desemba ishirini sifuri tisa unatarajiwa kuwa katika namna tofauti ambapo baadhi ya maeneo yanaweza kunufaika huku mengine  
 ilifafanua kuwa msimu wa vuli  maeneo ambayo hupata mvua mara mbili ambayo ni kaskazini mwa nchi ikiwa ni nyanda za juu kaskazini mashariki kanda ya ziwa victoria na pwani ya kaskazini 
 katika maeneo hayo mvua zinatarajiwa kunyesha wiki ya pili na tatu ya septemba mwaka

In [11]:
# get a sorted list of all the characters in the text
chars = sorted(list(set(text)))
vocab_size = len(chars)
print(''.join(chars))
print("There are", vocab_size, "unique characters in this text.")


 abcdefghijklmnopqrstuvwxyz﻿
There are 29 unique characters in this text.
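The last character in the vocabulary listing is not a letter. It is U+FEFF, a byte order mark (BOM) at the very start of `train.txt`, which is why the encoded training data below begins with token 28 (the test and validation files start with it too). A minimal sketch, assuming you would rather not keep it in the vocabulary: Python's `utf-8-sig` codec drops a leading BOM when reading. The file name `demo.txt` is hypothetical.

```python
import os
import tempfile

# Write a small file that starts with a UTF-8 byte order mark (U+FEFF),
# mimicking the first bytes of train.txt.
path = os.path.join(tempfile.mkdtemp(), "demo.txt")
with open(path, "w", encoding="utf-8") as f:
    f.write("\ufeff taarifa hiyo")

# Plain UTF-8 keeps the BOM as an ordinary character.
with open(path, "r", encoding="utf-8") as f:
    with_bom = f.read()

# 'utf-8-sig' silently drops a leading BOM.
with open(path, "r", encoding="utf-8-sig") as f:
    without_bom = f.read()

print(with_bom[0] == "\ufeff")   # True
print(without_bom[0] == " ")     # True
```

In the notebook this would mean opening `train.txt`, `test.txt` and `valid.txt` with `encoding='utf-8-sig'`, so that all splits share the same vocabulary. Assuming the BOM only occurs at the start of each file (not checked here), the vocabulary would then shrink to 28 characters.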


## Tokenisation
- Represent each character as an integer so the model can process it (the model later turns each integer into a vector through an embedding table).
- The tokeniser below is simple: it maps each character to an integer.
- There are more sophisticated tokenisers (for example subword tokenisers); we will experiment with them later.

In [12]:
# create a mapping from characters to integers
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers
decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string

print(encode("joto jingi"))
print(decode(encode("joto jingi")))

[11, 16, 21, 16, 1, 11, 10, 15, 8, 10]
joto jingi


In [13]:
print(decode([14, 19, 24, 19, 1, 16, 17, 18, 11, 13]))

mrwr opqjl
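Because `chars` is sorted, each id is simply the character's position in the vocabulary: `'\n'` is 0, the space is 1, `'a'` to `'z'` are 2 to 27 and the BOM is 28. That is also why any list of ids in range, such as the one decoded above, turns back into text. The sketch below rebuilds that 29-character vocabulary by hand (an assumption that matches the listing above rather than reading the file) and checks the two outputs printed above.

```python
import string

# Rebuild the 29-character vocabulary listed above: newline, space,
# the lowercase Latin letters, and the byte order mark U+FEFF.
chars = sorted(set("\n " + string.ascii_lowercase + "\ufeff"))
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for ch, i in stoi.items()}

def encode(s):
    # Raises KeyError for any character outside the vocabulary (e.g. 'A' or '7').
    return [stoi[c] for c in s]

def decode(ids):
    return "".join(itos[i] for i in ids)

print(len(chars))                                        # 29
print(encode("joto jingi"))                              # [11, 16, 21, 16, 1, 11, 10, 15, 8, 10]
print(decode([14, 19, 24, 19, 1, 16, 17, 18, 11, 13]))   # mrwr opqjl
```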


In [14]:
# now we're going to encode the entire dataset and store it into a torch.Tensor
import torch
train_data = torch.tensor(encode(text), dtype=torch.long)
print(train_data.shape, train_data.dtype,'\n')
print(train_data[:1000])

torch.Size([7522342]) torch.int64 

tensor([28,  1, 21,  2,  2, 19, 10,  7,  2,  1,  9, 10, 26, 16,  1, 10, 13, 10,
        20,  6, 14,  2,  1, 12, 22, 24,  2,  1, 16, 15,  8,  6, 27,  6, 12, 16,
         1, 13,  2,  1, 11, 16, 21, 16,  1, 13,  2,  1, 14,  2, 11, 10,  1, 11,
        22, 22,  1, 26,  2,  1, 24,  2, 20, 21,  2, 15, 10,  1, 12,  2, 21, 10,
        12,  2, 21, 10,  1, 26,  2,  1,  3,  2,  9,  2, 19, 10,  1, 26,  2,  1,
         1, 10, 15,  2,  2, 20,  9, 10, 19, 10,  2,  1, 12, 22, 24,  6, 17, 16,
         1, 12, 24,  2,  1, 14, 23, 22,  2,  1, 27,  2,  1,  6, 13,  1, 15, 10,
        15, 16,  1,  1,  9,  2,  5, 10,  1, 14, 24, 10, 20,  9, 16, 15, 10,  1,
        14, 24,  2,  1,  2, 17, 19, 10, 13,  1, 10, 20,  9, 10, 19, 10, 15, 10,
         1, 14, 16, 11,  2,  1, 20, 10,  7, 22, 19, 10,  1, 10, 14,  6,  6, 13,
         6, 27, 24,  2,  1, 12, 22, 24,  2,  1, 16, 15,  8,  6, 27,  6, 12, 16,
         1, 13,  2,  1, 11, 16, 21, 16

In [15]:
# test dataset
!wget https://raw.githubusercontent.com/atonui/pds/refs/heads/main/swahili_data/test.txt

--2025-08-26 14:08:09--  https://raw.githubusercontent.com/atonui/pds/refs/heads/main/swahili_data/test.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.110.133, 185.199.111.133, 185.199.108.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.110.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 686306 (670K) [text/plain]
Saving to: ‘test.txt’


2025-08-26 14:08:09 (8.45 MB/s) - ‘test.txt’ saved [686306/686306]



In [16]:
# read it to inspect it
with open('test.txt', 'r') as f:
    test = f.read()

In [17]:
# encode test dataset into a tensor
test_data = torch.tensor(encode(test), dtype=torch.long)
print(test_data.shape, test_data.dtype, '\n')
print(test_data[:1000])

torch.Size([682933]) torch.int64 

tensor([28,  1,  9, 22, 26, 16,  1,  2, 13, 10, 20, 10, 20, 10, 21, 10, 27,  2,
         1, 12, 22, 24,  2,  1,  9,  2, 12, 22,  9, 16, 11, 10, 24,  2,  1,  3,
         2, 13, 10,  1,  2, 13, 10, 17,  6, 24,  2,  1, 12,  2, 19,  2, 21,  2,
        20, 10,  1, 21, 22, 17, 22,  1, 15,  2,  1, 12, 22, 21,  2, 12, 10, 24,
         2,  1, 12, 22, 20,  2, 10, 15, 10,  1,  3, 10, 13,  2,  1, 12, 22,  7,
         2,  9,  2, 14, 22,  1, 15, 10,  1, 12, 10, 21, 22,  1,  8,  2, 15, 10,
         1,  1,  0,  1,  2, 12, 10,  6, 13,  6, 27,  6,  2,  1, 20, 10, 12, 22,
         1, 26,  2,  1, 21, 22, 12, 10, 16,  1,  2, 13, 10, 20,  6, 14,  2,  1,
         2, 13, 10, 17, 10,  8, 10, 24,  2,  1, 20, 10, 14, 22,  1, 15,  2,  1,
        20, 20, 17,  1, 20,  2, 13, 22, 14,  1, 12, 10, 20,  2, 10,  1,  2, 14,
         3,  2, 26,  6,  1,  2, 13, 10, 14, 21,  2, 12,  2,  1, 12, 22,  7, 10,
        12,  2,  1, 12,  2, 21, 10, 12,

In [18]:
# validation dataset
!wget https://raw.githubusercontent.com/atonui/pds/refs/heads/main/swahili_data/valid.txt

--2025-08-26 14:08:09--  https://raw.githubusercontent.com/atonui/pds/refs/heads/main/swahili_data/valid.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.110.133, 185.199.111.133, 185.199.108.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.110.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 655979 (641K) [text/plain]
Saving to: ‘valid.txt’


2025-08-26 14:08:10 (8.24 MB/s) - ‘valid.txt’ saved [655979/655979]



In [19]:
# read it to inspect it
with open('valid.txt', 'r') as f:
    valid = f.read()

In [20]:
# encode validation dataset into a tensor
valid_data = torch.tensor(encode(valid), dtype=torch.long)
print(valid_data.shape, valid_data.dtype)
print(valid_data[:100])

# the code in these repetitive cells could be refactored into a function

torch.Size([652605]) torch.int64
tensor([28,  1,  9, 10, 10,  1, 15, 10,  1,  5,  9,  2, 15,  2,  1, 17, 16, 21,
        16,  7, 22,  1, 15,  2,  1, 26,  2,  1,  9,  2, 21,  2, 19, 10,  1,  9,
         2, 20,  2,  1, 22, 12, 10, 27, 10, 15,  8,  2, 21, 10,  2,  1,  3,  2,
         2,  5,  9, 10,  1, 26,  2,  1, 24,  2, 15,  2, 15,  4,  9, 10,  1, 24,
         6, 15,  8, 10,  1, 24,  2, 15,  2,  1, 14,  2, 20,  9,  2, 12,  2,  1,
        15,  2,  1, 22, 27,  2, 13,  6, 15,  5])
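As the comment in the cell above says, the read-and-encode steps repeat for every split. A minimal sketch of a helper, assuming the files have already been downloaded with `!wget` as above; `load_split` and `toy_stoi` are hypothetical names, and the helper takes the mapping as an argument so it does not depend on notebook globals.

```python
import os
import tempfile
import torch

def load_split(path, stoi):
    """Read a text file and encode it as a 1-D tensor of token ids (hypothetical helper)."""
    with open(path, "r", encoding="utf-8") as f:
        text = f.read()
    return torch.tensor([stoi[c] for c in text], dtype=torch.long)

# Toy usage: a tiny vocabulary and a temporary file standing in for train/test/valid.txt.
toy_stoi = {"\n": 0, " ": 1, "a": 2, "b": 3}
path = os.path.join(tempfile.mkdtemp(), "toy.txt")
with open(path, "w", encoding="utf-8") as f:
    f.write("ab ba\n")

toy_data = load_split(path, toy_stoi)
print(toy_data)  # tensor([2, 3, 1, 3, 2, 0])
```

In the notebook this would become `train_data = load_split('train.txt', stoi)`, and likewise for `test.txt` and `valid.txt`.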


In [21]:
# training block size
block_size = 8
train_data[:block_size+1]
# the transformer is trained on random chunks of the text rather than the whole text at once; a chunk of block_size+1 = 9 characters holds 8 input-target examples

tensor([28,  1, 21,  2,  2, 19, 10,  7,  2])

In [22]:
x = train_data[:block_size] # inputs to the transformer
y = train_data[1:block_size+1] # targets: the same chunk shifted one position to the right
# iterating through the block size
for t in range(block_size):
  context = x[:t+1]
  target = y[t]
  print(f'when input is {context} the target: {target}')

when input is tensor([28]) the target: 1
when input is tensor([28,  1]) the target: 21
when input is tensor([28,  1, 21]) the target: 2
when input is tensor([28,  1, 21,  2]) the target: 2
when input is tensor([28,  1, 21,  2,  2]) the target: 19
when input is tensor([28,  1, 21,  2,  2, 19]) the target: 10
when input is tensor([28,  1, 21,  2,  2, 19, 10]) the target: 7
when input is tensor([28,  1, 21,  2,  2, 19, 10,  7]) the target: 2
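The loop above shows why one chunk of `block_size + 1` characters yields `block_size` training examples: every prefix of the input is a context, and the character that follows it is the target. Training on contexts of every length from 1 up to `block_size` also means the model can later generate starting from a single character. A small pure-Python sketch of the same split, using a hypothetical helper `context_target_pairs` and the first nine ids of the training data shown above:

```python
def context_target_pairs(chunk):
    """Split a chunk of block_size + 1 token ids into (context, target) pairs."""
    x, y = chunk[:-1], chunk[1:]          # inputs, and targets shifted by one
    return [(x[:t + 1], y[t]) for t in range(len(x))]

chunk = [28, 1, 21, 2, 2, 19, 10, 7, 2]   # train_data[:block_size + 1] from above
pairs = context_target_pairs(chunk)
print(len(pairs))   # 8
print(pairs[0])     # ([28], 1)
print(pairs[-1])    # ([28, 1, 21, 2, 2, 19, 10, 7], 2)
```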


In [23]:
torch.manual_seed(1337)
batch_size = 4 # how many independent sequences will we process in parallel?
block_size = 8 # what is the maximum context length for predictions?

def get_batch(split):
  # generate a small batch of data of inputs x and targets y
  data = train_data if split == 'train' else valid_data
  ix = torch.randint(len(data) - block_size, (batch_size,))
  x = torch.stack([data[i:i+block_size] for i in ix])
  y = torch.stack([data[i+1:i+block_size+1] for i in ix])
  return x,y

xb, yb = get_batch('train')
print('inputs: ')
print(xb.shape)
print(xb)
print('targets: ')
print(yb.shape)
print(yb)

print('----------')

for b in range(batch_size): # batch dimension
  for t in range(block_size): # time dimension
    context = xb[b, :t+1]
    target = yb[b,t]
    print(f'when input is {context.tolist()} the target: {target}')

inputs: 
torch.Size([4, 8])
tensor([[12, 22, 13, 10, 14,  7,  2, 15],
        [20, 10, 20, 10, 21, 10, 27,  2],
        [11, 10,  1, 24,  2,  1, 21,  2],
        [ 2, 16,  1, 27, 10, 21,  2, 12]])
targets: 
torch.Size([4, 8])
tensor([[22, 13, 10, 14,  7,  2, 15, 26],
        [10, 20, 10, 21, 10, 27,  2,  1],
        [10,  1, 24,  2,  1, 21,  2, 15],
        [16,  1, 27, 10, 21,  2, 12,  2]])
----------
when input is [12] the target: 22
when input is [12, 22] the target: 13
when input is [12, 22, 13] the target: 10
when input is [12, 22, 13, 10] the target: 14
when input is [12, 22, 13, 10, 14] the target: 7
when input is [12, 22, 13, 10, 14, 7] the target: 2
when input is [12, 22, 13, 10, 14, 7, 2] the target: 15
when input is [12, 22, 13, 10, 14, 7, 2, 15] the target: 26
when input is [20] the target: 10
when input is [20, 10] the target: 20
when input is [20, 10, 20] the target: 10
when input is [20, 10, 20, 10] the target: 21
when input is [20, 10, 20, 10, 21] the target: 10
when in

In [24]:
print(xb)

tensor([[12, 22, 13, 10, 14,  7,  2, 15],
        [20, 10, 20, 10, 21, 10, 27,  2],
        [11, 10,  1, 24,  2,  1, 21,  2],
        [ 2, 16,  1, 27, 10, 21,  2, 12]])
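Each row of `yb` is the matching row of `xb` shifted left by one position. A minimal sketch that checks this property, assuming a variant of `get_batch` that takes the data, `block_size` and `batch_size` as arguments instead of reading notebook globals (the names `get_batch_from`, `toy_data`, `toy_xb` and `toy_yb` are hypothetical):

```python
import torch

def get_batch_from(data, block_size, batch_size, generator=None):
    # Random start offsets, leaving room for the one-token shift of the targets.
    ix = torch.randint(len(data) - block_size, (batch_size,), generator=generator)
    x = torch.stack([data[i:i + block_size] for i in ix])
    y = torch.stack([data[i + 1:i + block_size + 1] for i in ix])
    return x, y

toy_data = torch.arange(100)                      # toy "dataset" of ids 0..99
g = torch.Generator().manual_seed(1337)
toy_xb, toy_yb = get_batch_from(toy_data, block_size=8, batch_size=4, generator=g)

print(toy_xb.shape, toy_yb.shape)                 # torch.Size([4, 8]) torch.Size([4, 8])
print(torch.equal(toy_xb[:, 1:], toy_yb[:, :-1])) # True: targets are inputs shifted by one
print(torch.equal(toy_yb, toy_xb + 1))            # True for this toy data only
```

Passing an explicit `torch.Generator` makes the sampled batches reproducible without touching the global seed.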


In [25]:
import torch.nn as nn
from torch.nn import functional as F
torch.manual_seed(1337)

class BigramLanguageModel(nn.Module):

  def __init__(self, vocab_size):
    super().__init__()
    # each token directly reads off the logits for the next token from a lookup table
    self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)


  def forward(self, idx, targets=None):
    # idx and targets are both (B,T) tensors of integers
    logits = self.token_embedding_table(idx) # (Batch,Time,Channel) tensor

    if targets is None:
      loss = None
    else:
      # cross_entropy expects (N,C) logits and (N,) targets, so flatten batch and time
      B,T,C = logits.shape
      logits = logits.view(B*T, C)
      targets = targets.view(B*T)
      loss = F.cross_entropy(logits, targets)

    return logits, loss

  def generate(self, idx, max_new_tokens):
    # idx is (B, T) array of indices in the current context
    for _ in range(max_new_tokens):
      # get the predictions
      logits, loss = self(idx)
      # focus only on the last time step
      logits = logits[:,-1,:] # becomes (B,C)
      # apply softmax to get probabilities
      probs = F.softmax(logits, dim=-1) # (B,C)
      # sample from the distribution
      idx_next = torch.multinomial(probs, num_samples=1) # (B,1)
      # append sampled index to the running sequence
      idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
    return idx

m = BigramLanguageModel(vocab_size)
logits, loss = m(xb, yb)
print(logits.shape)
print(loss)

# generate 100 characters, starting from token 0 (the newline character)
print(decode(m.generate(idx = torch.zeros((1,1), dtype=torch.long), max_new_tokens=100)[0].tolist()))


torch.Size([32, 29])
tensor(4.0687, grad_fn=<NllLossBackward0>)

pvpv﻿nauzhklblpnnamd
nkrhvgeh lkjrokmjrulbsbuzwna
p
qko
enbromnabuzwcyrhmmnnhsnbxpsxmuowomabsv apitg
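The initial loss of about 4.07 is worth a quick sanity check. If the untrained model assigned equal probability to all 29 characters, the cross-entropy would be −ln(1/29) = ln 29 ≈ 3.37. A loss above that means the randomly initialised embedding table produces non-uniform predictions that, on this batch, are worse than uniform guessing.

```python
import math

vocab_size = 29
uniform_loss = -math.log(1 / vocab_size)   # cross-entropy of a uniform guess over the vocabulary
print(round(uniform_loss, 4))               # 3.3673
```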


In [26]:
# create a PyTorch optimiser (AdamW, learning rate 1e-3)
optimiser = torch.optim.AdamW(m.parameters(), lr=1e-3)

In [27]:
batch_size = 32 # larger batches for training; get_batch reads this global variable
for steps in range(10000):
  # sample a batch of data
  xb, yb = get_batch('train')

  # evaluate the loss
  logits, loss = m(xb, yb)
  optimiser.zero_grad(set_to_none=True)
  loss.backward()
  optimiser.step()

print(loss.item())

2.0706610679626465
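The printed value is the loss on the last batch of 32 only, so it is noisy. A common alternative is to average the loss over several batches for each split, with the model in evaluation mode and gradients disabled. A minimal sketch, assuming a `get_batch(split)` function and a model that returns `(logits, loss)` like `BigramLanguageModel`; `estimate_loss`, `eval_iters`, `TinyBigram`, `toy_data` and `toy_get_batch` are hypothetical names, and the toy part only exists so the block runs on its own.

```python
import torch
import torch.nn as nn
from torch.nn import functional as F

@torch.no_grad()  # no gradients are needed for evaluation
def estimate_loss(model, get_batch, eval_iters=200):
    """Average the loss over eval_iters batches for each split."""
    out = {}
    model.eval()
    for split in ("train", "val"):
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            xb, yb = get_batch(split)
            _, loss = model(xb, yb)
            losses[k] = loss.item()
        out[split] = losses.mean().item()
    model.train()
    return out

# --- Toy usage: a bigram-style model on random token ids ---
torch.manual_seed(0)
vocab_size, block_size, batch_size = 29, 8, 32
toy_data = {"train": torch.randint(vocab_size, (1000,)), "val": torch.randint(vocab_size, (200,))}

def toy_get_batch(split):
    d = toy_data[split]
    ix = torch.randint(len(d) - block_size, (batch_size,))
    x = torch.stack([d[i:i + block_size] for i in ix])
    y = torch.stack([d[i + 1:i + block_size + 1] for i in ix])
    return x, y

class TinyBigram(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):
        logits = self.table(idx)                      # (B, T, C)
        if targets is None:
            return logits, None
        B, T, C = logits.shape
        return logits, F.cross_entropy(logits.view(B * T, C), targets.view(B * T))

model = TinyBigram(vocab_size)
est = estimate_loss(model, toy_get_batch, eval_iters=10)
print(est)
```

In the notebook the call would be `estimate_loss(m, get_batch)`; the existing `get_batch` treats any split other than `'train'` as validation, so `'val'` works there.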


In [28]:
print(decode(m.generate(idx = torch.zeros((1,1), dtype=torch.long), max_new_tokens=500)[0].tolist()))


 wikena ku ngakuuku hali wengezinamishi chukujemi kime teofarai kaso wa hamendojishaizani weshe waluamasamatisilao es yekenuwa bandawa na datuna norika ma dupaza wananaya wa a junato kemsta stezesehimuuzit mwa tokijemosi shitaji hihaikamkika ya uwa sheetochio kakumemavikika hala wabe we kunggo ki mokikana vipi 
 na kibamaturamao wayara 
 ilika uakali m ko vya dinyabunza a ku gi ngujaliamda kia li mu kazo  wale ba  nginctuifanda aoa bi sema wakiku ma zuni hatata ta ativishari li  wani po ha wikia


Training improves the bigram model considerably (the single-batch loss falls from about 4.07 to about 2.07), but the output is still far from real Swahili. In this simple model the tokens do not talk to each other: each prediction is based only on the very last character. Next, we need the tokens to communicate so the model can use the preceding context to make better predictions, which is what the **Transformer** does.
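This limitation can be checked directly: in a bigram model the logits at position t come from a table lookup of token t alone, so changing earlier tokens in the context leaves the prediction for the next character unchanged. A minimal sketch with a bare `nn.Embedding` standing in for `BigramLanguageModel.token_embedding_table`:

```python
import torch
import torch.nn as nn

torch.manual_seed(1337)
vocab_size = 29
table = nn.Embedding(vocab_size, vocab_size)   # same kind of lookup table as the bigram model

ctx_a = torch.tensor([[11, 16, 21, 16]])       # "joto"
ctx_b = torch.tensor([[14, 19, 24, 16]])       # "mrwo": different history, same last character

last_a = table(ctx_a)[:, -1, :]                # logits used to predict the next character
last_b = table(ctx_b)[:, -1, :]
print(torch.equal(last_a, last_b))             # True: only the last token matters
```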

## The mathematical trick in self-attention

In [29]:
torch.manual_seed(1337)
B, T, C = 4,8,2 # batch, time and channels (information)
x = torch.randn(B,T,C)
x.shape

torch.Size([4, 8, 2])

In [30]:
# version 1
xbow = torch.zeros((B, T, C)) # xbow = "x bag of words": each position holds the average of itself and all previous positions
for b in range(B):
  for t in range(T):
    xprev = x[b, :t+1] # (t+1,C)
    xbow[b,t] = torch.mean(xprev, 0)

In [31]:
x[0]

tensor([[ 0.1808, -0.0700],
        [-0.3596, -0.9152],
        [ 0.6258,  0.0255],
        [ 0.9545,  0.0643],
        [ 0.3612,  1.1679],
        [-1.3499, -0.5102],
        [ 0.2360, -0.2398],
        [-0.9211,  1.5433]])

In [32]:
xbow[0]

tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
        [ 0.3525,  0.0545],
        [ 0.0688, -0.0396],
        [ 0.0927, -0.0682],
        [-0.0341,  0.1332]])

The `xbow[0]` output shows that row t is the column-wise mean of rows 0 to t of `x[0]`, i.e. a running average over the current and previous tokens. The nested Python loops, however, are very inefficient.
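Written as a formula, version 1 computes, for every batch element $b$ and position $t$, the mean of the first $t+1$ vectors:

$$\text{xbow}[b,t] = \frac{1}{t+1}\sum_{i=0}^{t} x[b,i]$$

For example, the second row of `xbow[0]` is $\tfrac{1}{2}\big((0.1808, -0.0700) + (-0.3596, -0.9152)\big) = (-0.0894, -0.4926)$, matching the output above.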

In [33]:
torch.manual_seed(42)
a = torch.ones(3,3)
b = torch.randint(0,10,(3,2)).float()
c = a @ b
print('a=')
print(a)
print('------')
print('b=')
print(b)
print('------')
print('c= a x b')
print(c)

a=
tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])
------
b=
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
------
c= a x b
tensor([[14., 16.],
        [14., 16.],
        [14., 16.]])


In [34]:
torch.tril(torch.ones(3,3)) # lower-triangular matrix of ones

tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])

In [35]:
torch.manual_seed(42)
a = torch.tril(torch.ones(3,3)) # ---> make a lower-triangular
b = torch.randint(0,10,(3,2)).float()
c = a @ b # ---> matrix multiplication
print('a=')
print(a)
print('------')
print('b=')
print(b)
print('------')
print('c= a x b')
print(c)

a=
tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])
------
b=
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
------
c= a x b
tensor([[ 2.,  7.],
        [ 8., 11.],
        [14., 16.]])


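With the lower-triangular `a`, each row of `c` is the sum of the rows of `b` up to and including that row (e.g. the second row is 2+6 = 8 and 7+4 = 11). If we normalise each row of `a` to sum to 1, the same product gives averages instead of sums.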

In [36]:
# normalise each row of a to sum to 1, so the matrix product averages the rows of b instead of summing them
torch.manual_seed(42)
a = torch.tril(torch.ones(3,3)) # ---> lower-triangular matrix; it acts as the weights
a = a / torch.sum(a, 1, keepdim=True) # normalise each row
b = torch.randint(0,10,(3,2)).float()
c = a @ b # ---> matrix multiplication
print('a=')
print(a)
print('------')
print('b=')
print(b)
print('------')
print('c= a x b')
print(c)

a=
tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
------
b=
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
------
c= a x b
tensor([[2.0000, 7.0000],
        [4.0000, 5.5000],
        [4.6667, 5.3333]])


Now back to our original bag-of-words problem.

We can vectorise the for loop, replacing it with a single multiplication by a row-normalised lower-triangular matrix, which is much more efficient.

In [37]:
# version 2 --> using triangular matrices
a2 = torch.tril(torch.ones(T,T))
a2 = a2 / a2.sum(1, keepdim=True)
a2

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

In [38]:
xbow2 = a2 @ x # (T,T) @ (B,T,C): a2 is broadcast over the batch ---> (B,T,C)
xbow2

tensor([[[ 0.1808, -0.0700],
         [-0.0894, -0.4926],
         [ 0.1490, -0.3199],
         [ 0.3504, -0.2238],
         [ 0.3525,  0.0545],
         [ 0.0688, -0.0396],
         [ 0.0927, -0.0682],
         [-0.0341,  0.1332]],

        [[ 1.3488, -0.1396],
         [ 0.8173,  0.4127],
         [-0.1342,  0.4395],
         [ 0.2711,  0.4774],
         [ 0.2421,  0.0694],
         [ 0.0084,  0.0020],
         [ 0.0712, -0.1128],
         [ 0.2527,  0.2149]],

        [[-0.6631, -0.2513],
         [ 0.1735, -0.0649],
         [ 0.1685,  0.3348],
         [-0.1621,  0.1765],
         [-0.2312, -0.0436],
         [-0.1015, -0.2855],
         [-0.2593, -0.1630],
         [-0.3015, -0.2293]],

        [[ 1.6455, -0.8030],
         [ 1.4985, -0.5395],
         [ 0.4954,  0.3420],
         [ 1.0623, -0.1802],
         [ 1.1401, -0.4462],
         [ 1.0870, -0.4071],
         [ 1.0430, -0.1299],
         [ 1.1138, -0.1641]]])

In [39]:
torch.allclose(xbow, xbow2, atol=1e-4) # https://stackoverflow.com/questions/75622268/comparing-two-tensors-in-pytorch

True

In [40]:
xbow[0], xbow2[0]

(tensor([[ 0.1808, -0.0700],
         [-0.0894, -0.4926],
         [ 0.1490, -0.3199],
         [ 0.3504, -0.2238],
         [ 0.3525,  0.0545],
         [ 0.0688, -0.0396],
         [ 0.0927, -0.0682],
         [-0.0341,  0.1332]]),
 tensor([[ 0.1808, -0.0700],
         [-0.0894, -0.4926],
         [ 0.1490, -0.3199],
         [ 0.3504, -0.2238],
         [ 0.3525,  0.0545],
         [ 0.0688, -0.0396],
         [ 0.0927, -0.0682],
         [-0.0341,  0.1332]]))

The for loop and the matrix multiplication give the same result (within floating-point tolerance), but the matrix version does all the work in one batched operation and is much faster.
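For comparison only, the same running average can also be computed with a cumulative sum divided by the number of positions seen so far. This is a sketch of an alternative, not what the notebook uses: the weight-matrix form is the one that generalises to the learned, data-dependent weights of self-attention later on.

```python
import torch

torch.manual_seed(1337)
B, T, C = 4, 8, 2
x = torch.randn(B, T, C)

# Running mean via a cumulative sum: divide the sum of the first t+1 rows by t+1.
counts = torch.arange(1, T + 1, dtype=x.dtype).view(1, T, 1)   # (1, T, 1): 1, 2, ..., T
xbow_cumsum = x.cumsum(dim=1) / counts                           # (B, T, C)

# Compare with the row-normalised lower-triangular matrix from version 2.
a2 = torch.tril(torch.ones(T, T))
a2 = a2 / a2.sum(1, keepdim=True)
print(torch.allclose(xbow_cumsum, a2 @ x, atol=1e-5))           # True
```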

In [41]:
# version 3 ---> using softmax
tril = torch.tril(torch.ones(T,T)) # --> this is a triangular matrix
a3 = torch.zeros((B,T,T)) # --> this is a zero matrix
a3 = a3.masked_fill(tril == 0, float('-inf')) # --> wherever tril is 0, set a3 to -inf (negative infinity)
a3 = F.softmax(a3, dim=-1) # --> normalises each row of the matrix, just like a2 = a2 / a2.sum(1, keepdim=True) this operation does
xbow3 = a3 @ x
torch.allclose(xbow, xbow3, atol=1e-4)


True
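Why version 3 reproduces the averages: the softmax exponentiates every entry and divides by the row sum, and $e^{0}=1$ while $e^{-\infty}=0$. For row $t$, where positions $0,\dots,t$ are unmasked:

$$\operatorname{softmax}(z)_i = \frac{e^{z_i}}{\sum_{j} e^{z_j}}, \qquad z = (\underbrace{0,\dots,0}_{t+1},-\infty,\dots,-\infty) \;\Rightarrow\; \operatorname{softmax}(z)_i = \begin{cases} \frac{1}{t+1} & i \le t \\ 0 & i > t \end{cases}$$

Replacing the zeros with learned scores (as in version 4) turns this uniform average into a weighted one, while the $-\infty$ mask still gives future positions zero weight.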

In [42]:
a3[0]

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

In [43]:
tril

tensor([[1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1.]])

In [44]:
a4 = torch.zeros((B,T,T))
a4 = a4.masked_fill(tril == 0, float('-inf'))
a4

tensor([[[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
         [0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
         [0., 0., 0., -inf, -inf, -inf, -inf, -inf],
         [0., 0., 0., 0., -inf, -inf, -inf, -inf],
         [0., 0., 0., 0., 0., -inf, -inf, -inf],
         [0., 0., 0., 0., 0., 0., -inf, -inf],
         [0., 0., 0., 0., 0., 0., 0., -inf],
         [0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
         [0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
         [0., 0., 0., -inf, -inf, -inf, -inf, -inf],
         [0., 0., 0., 0., -inf, -inf, -inf, -inf],
         [0., 0., 0., 0., 0., -inf, -inf, -inf],
         [0., 0., 0., 0., 0., 0., -inf, -inf],
         [0., 0., 0., 0., 0., 0., 0., -inf],
         [0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
         [0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
         [0., 0., 0., -inf, -inf, -inf, -inf, -inf],
         [0., 0., 0., 0., -inf,

## Self-attention

In [45]:
# version 4: self attention
torch.manual_seed(1337)
B,T,C = 4,8,32 # batch, time and channels
x = torch.randn(B,T,C)
# single-head self-attention
head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)

k = key(x) # (B,T,16)
q = query(x) # (B,T,16)

a4 = q @ k.transpose(-2, -1) # (B,T,16) @ (B,16,T) ---> (B,T,T)
tril = torch.tril(torch.ones(T,T))
# a4 = torch.zeros((T,T))
a4 = a4.masked_fill(tril == 0, float('-inf')) # the future cannot communicate to the past
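a4 = F.softmax(a4, dim=1) # note: dim=1 normalises each column (the columns of a4[0] below sum to 1); attention weights are usually normalised per row with dim=-1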

v = value(x)
out = a4 @ v
out.shape

torch.Size([4, 8, 16])

In [46]:
a4[0]

tensor([[0.0248, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0052, 0.0091, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0521, 0.0135, 0.2482, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3171, 0.0214, 0.1642, 0.1188, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0412, 0.0487, 0.1046, 0.0742, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1060, 0.5347, 0.2059, 0.1030, 0.7402, 0.0192, 0.0000, 0.0000],
        [0.4298, 0.3409, 0.1769, 0.2027, 0.0480, 0.8472, 0.2329, 0.0000],
        [0.0238, 0.0316, 0.1002, 0.5013, 0.0117, 0.1336, 0.7671, 1.0000]],
       grad_fn=<SelectBackward0>)

Every token at each position emits two vectors:
- A query vector: what am I looking for?
- A key vector: what do I contain?

We get the affinities between tokens by taking the dot product of each token's query with the keys of all the other tokens; the mask then keeps only the token itself and the tokens before it:
$$W = Q K^{\top} \;\rightarrow\; \text{weights (}\texttt{a4}\text{ in our code)}$$
If the query of one token and the key of another are aligned, their dot product is large, so after the softmax that token gets a high weight and the current token learns more from it than from the other tokens. What is actually aggregated comes from a third projection, the value vector `v`: the output is `a4 @ v`.
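The query/key/value steps can be collected into one function. A minimal single-head sketch using the same shapes as the cell above (B=4, T=8, C=32, head_size=16); the function name `single_head_attention` is hypothetical. It applies the softmax over the last dimension (the keys), so each row of weights sums to 1 and every future position gets weight 0.

```python
import torch
import torch.nn as nn
from torch.nn import functional as F

def single_head_attention(x, key, query, value):
    """Causal single-head self-attention, following the cell above."""
    B, T, C = x.shape
    k = key(x)                                        # (B, T, head_size): what each token contains
    q = query(x)                                      # (B, T, head_size): what each token looks for
    weights = q @ k.transpose(-2, -1)                 # (B, T, T): affinity of query t with key s
    tril = torch.tril(torch.ones(T, T))
    weights = weights.masked_fill(tril == 0, float("-inf"))  # block attention to future tokens
    weights = F.softmax(weights, dim=-1)              # normalise over keys: each row sums to 1
    v = value(x)                                      # (B, T, head_size)
    return weights @ v, weights                       # (B, T, head_size), (B, T, T)

torch.manual_seed(1337)
B, T, C, head_size = 4, 8, 32, 16
x = torch.randn(B, T, C)
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)

out, weights = single_head_attention(x, key, query, value)
print(out.shape)                                          # torch.Size([4, 8, 16])
print(torch.allclose(weights.sum(-1), torch.ones(B, T)))  # True: rows sum to 1
print(torch.all(weights.triu(diagonal=1) == 0).item())    # True: no weight on the future
```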



