<a href="https://colab.research.google.com/github/atonui/swahili-gpt/blob/main/swa_gpt.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Swahili-GPT
- A simple character-level, Transformer-based language model trained on a Swahili dataset.

In [1]:
# load the training dataset
!wget https://raw.githubusercontent.com/atonui/pds/refs/heads/main/swahili_data/train.txt

--2025-08-21 07:49:59--  https://raw.githubusercontent.com/atonui/pds/refs/heads/main/swahili_data/train.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 7564413 (7.2M) [text/plain]
Saving to: ‘train.txt’


2025-08-21 07:49:59 (146 MB/s) - ‘train.txt’ saved [7564413/7564413]



In [2]:
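# read the file to inspect it (explicit UTF-8 so the result does not depend on the system default encoding)
with open('train.txt', 'r', encoding='utf-8') as f:
    text = f.read()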

In [3]:
print('Length of dataset in characters: ', len(text))

Length of dataset in characters:  7522342


In [4]:
print(text[:1000])

﻿ taarifa hiyo ilisema kuwa ongezeko la joto la maji juu ya wastani katikati ya bahari ya  inaashiria kuwepo kwa mvua za el nino  hadi mwishoni mwa april ishirini moja sifuri imeelezwa kuwa ongezeko la joto magharibi mwa bahari ya hindi linatarajiwa kuhamia katikati ya bahari hiyo hali ambayo itasababisha pepo kutoka kaskazini mashariki kuvuma kuelekea bahari ya hindi 
 aidha ilisema kuwa mwelekeo wa kupungua kwa joto kusini mashariki mwa bahari ya atlantic  kusababisha pepo kutoka magharibi kuvuma kuelekea magharibi mwa tanzania katika maeneo ya ziwa victoria 
 mwelekeo wa mvua wa septemba hadi desemba ishirini sifuri tisa unatarajiwa kuwa katika namna tofauti ambapo baadhi ya maeneo yanaweza kunufaika huku mengine  
 ilifafanua kuwa msimu wa vuli  maeneo ambayo hupata mvua mara mbili ambayo ni kaskazini mwa nchi ikiwa ni nyanda za juu kaskazini mashariki kanda ya ziwa victoria na pwani ya kaskazini 
 katika maeneo hayo mvua zinatarajiwa kunyesha wiki ya pili na tatu ya septemba mwaka

In [5]:
# get a sorted list of all the characters in the text
chars = sorted(list(set(text)))
vocab_size = len(chars)
print(''.join(chars))
print("There are", vocab_size, "unique characters in this text.")


 abcdefghijklmnopqrstuvwxyz﻿
There are 29 unique characters in this text.
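The 29th character, printed invisibly after `z`, is not a letter. It is U+FEFF, the byte-order mark (BOM) at the start of each file, which is why token 28 is the first ID in all three encoded splits. If you would rather not spend a vocabulary slot on it, one option (a sketch, not what this notebook does) is to read the files with Python's `utf-8-sig` codec, which drops a single leading BOM. Assuming U+FEFF occurs only at the very start of `train.txt`, `vocab_size` would fall to 28. Because U+FEFF sorts after `z`, every other character would keep its ID.

```python
# Sketch: how a leading byte-order mark (BOM) appears, and how the 'utf-8-sig' codec removes it.
sample_bytes = "\ufeff taarifa hiyo".encode("utf-8")  # start of a BOM-prefixed file, as stored on disk

kept = sample_bytes.decode("utf-8")          # plain UTF-8: U+FEFF survives as an ordinary character
stripped = sample_bytes.decode("utf-8-sig")  # utf-8-sig: one leading BOM is discarded

print(repr(kept[:3]), repr(stripped[:3]))
# U+FEFF sorts after every ASCII letter, so removing it leaves the other characters' order (and IDs) unchanged
print(sorted(set(kept)) == sorted(set(stripped)) + ["\ufeff"])

# For the real file this would be (assumption: the BOM appears only at the start):
# with open('train.txt', 'r', encoding='utf-8-sig') as f:
#     text = f.read()
```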


## Tokenisation
- Represent each character as an integer ID so the model can work with it; the model later turns these IDs into vectors.
- The tokeniser below is deliberately simple: it maps each character to one integer and back.
- More sophisticated tokenisers exist (e.g. subword tokenisers), and we shall experiment with them.

In [6]:
# create a mapping from characters to integers
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }

def encode(s):
  # encoder: take a string, output a list of integers
  return [stoi[c] for c in s]

def decode(l):
  # decoder: take a list of integers, output a string
  return ''.join([itos[i] for i in l])

print(encode("joto jingi"))
print(decode(encode("joto jingi")))

[11, 16, 21, 16, 1, 11, 10, 15, 8, 10]
joto jingi
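`encode` looks up every character in `stoi`, which is built from the training text only. Any character that never appears in `train.txt` therefore raises a bare `KeyError` when the test or validation split is encoded. Below is a minimal sketch of the same character-level tokeniser wrapped in a small class that gives a clearer error. The class name `CharTokenizer` is hypothetical and not part of the notebook.

```python
class CharTokenizer:
    """Character-level tokeniser (hypothetical helper): each distinct character gets an integer ID."""

    def __init__(self, corpus):
        # sorted() makes the mapping deterministic, like `chars` in the notebook
        self.chars = sorted(set(corpus))
        self.stoi = {ch: i for i, ch in enumerate(self.chars)}
        self.itos = {i: ch for i, ch in enumerate(self.chars)}

    @property
    def vocab_size(self):
        return len(self.chars)

    def encode(self, s):
        # report every out-of-vocabulary character at once instead of failing on the first one
        unknown = sorted(set(s).difference(self.stoi))
        if unknown:
            raise ValueError(f"characters not in vocabulary: {unknown!r}")
        return [self.stoi[c] for c in s]

    def decode(self, ids):
        return "".join(self.itos[i] for i in ids)


tok = CharTokenizer("joto jingi")
print(tok.vocab_size, tok.encode("joto"), tok.decode(tok.encode("joto")))
```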


In [7]:
print(decode([14, 19, 24, 19, 1, 16, 17, 18, 11, 13]))

mrwr opqjl


In [8]:
# now we're going to encode the entire dataset and store it into a torch.Tensor
import torch
train_data = torch.tensor(encode(text), dtype=torch.long)
print(train_data.shape, train_data.dtype, '\n') # .dtype is an attribute; .type is a method and needs ()
print(train_data[:1000])

torch.Size([7522342]) torch.int64

tensor([28,  1, 21,  2,  2, 19, 10,  7,  2,  1,  9, 10, 26, 16,  1, 10, 13, 10,
        20,  6, 14,  2,  1, 12, 22, 24,  2,  1, 16, 15,  8,  6, 27,  6, 12, 16,
         1, 13,  2,  1, 11, 16, 21, 16,  1, 13,  2,  1, 14,  2, 11, 10,  1, 11,
        22, 22,  1, 26,  2,  1, 24,  2, 20, 21,  2, 15, 10,  1, 12,  2, 21, 10,
        12,  2, 21, 10,  1, 26,  2,  1,  3,  2,  9,  2, 19, 10,  1, 26,  2,  1,
         1, 10, 15,  2,  2, 20,  9, 10, 19, 10,  2,  1, 12, 22, 24,  6, 17, 16,
         1, 12, 24,  2,  1, 14, 23, 22,  2,  1, 27,  2,  1,  6, 13,  1, 15, 10,
        15, 16,  1,  1,  9,  2,  5, 10,  1, 14, 24, 10, 20,  9, 16, 15, 10,  1,
        14, 24,  2,  1,  2, 17, 19, 10, 13,  1, 10, 20,  9, 10, 19, 10, 15, 10,
         1, 14, 16, 11,  2,  1, 20, 10,  7, 22, 19, 10,  1, 10, 14,  6,  6, 13,
         6, 27, 24,  2,  1, 12, 22, 24,  2,  1, 16, 15,  8,  6, 27,  6, 12, 16,
         1, 13,  2,  1, 11, 16, 21, 16

In [9]:
# test dataset
!wget https://raw.githubusercontent.com/atonui/pds/refs/heads/main/swahili_data/test.txt

--2025-08-21 07:50:12--  https://raw.githubusercontent.com/atonui/pds/refs/heads/main/swahili_data/test.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 686306 (670K) [text/plain]
Saving to: ‘test.txt’


2025-08-21 07:50:12 (28.4 MB/s) - ‘test.txt’ saved [686306/686306]



In [10]:
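# read the file to inspect it (explicit UTF-8 so the result does not depend on the system default encoding)
with open('test.txt', 'r', encoding='utf-8') as f:
    test = f.read()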

In [11]:
# encode test dataset into a tensor
test_data = torch.tensor(encode(test), dtype=torch.long)
print(test_data.shape, test_data.dtype, '\n') # .dtype is an attribute; .type is a method and needs ()
print(test_data[:1000])

torch.Size([682933]) torch.int64

tensor([28,  1,  9, 22, 26, 16,  1,  2, 13, 10, 20, 10, 20, 10, 21, 10, 27,  2,
         1, 12, 22, 24,  2,  1,  9,  2, 12, 22,  9, 16, 11, 10, 24,  2,  1,  3,
         2, 13, 10,  1,  2, 13, 10, 17,  6, 24,  2,  1, 12,  2, 19,  2, 21,  2,
        20, 10,  1, 21, 22, 17, 22,  1, 15,  2,  1, 12, 22, 21,  2, 12, 10, 24,
         2,  1, 12, 22, 20,  2, 10, 15, 10,  1,  3, 10, 13,  2,  1, 12, 22,  7,
         2,  9,  2, 14, 22,  1, 15, 10,  1, 12, 10, 21, 22,  1,  8,  2, 15, 10,
         1,  1,  0,  1,  2, 12, 10,  6, 13,  6, 27,  6,  2,  1, 20, 10, 12, 22,
         1, 26,  2,  1, 21, 22, 12, 10, 16,  1,  2, 13, 10, 20,  6, 14,  2,  1,
         2, 13, 10, 17, 10,  8, 10, 24,  2,  1, 20, 10, 14, 22,  1, 15,  2,  1,
        20, 20, 17,  1, 20,  2, 13, 22, 14,  1, 12, 10, 20,  2, 10,  1,  2, 14,
         3,  2, 26,  6,  1,  2, 13, 10, 14, 21,  2, 12,  2,  1, 12, 22,  7, 10,
        12,  2,  1, 12,  2, 21, 10, 12,

In [12]:
# validation dataset
!wget https://raw.githubusercontent.com/atonui/pds/refs/heads/main/swahili_data/valid.txt

--2025-08-21 07:50:12--  https://raw.githubusercontent.com/atonui/pds/refs/heads/main/swahili_data/valid.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.109.133, 185.199.108.133, 185.199.111.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.109.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 655979 (641K) [text/plain]
Saving to: ‘valid.txt’


2025-08-21 07:50:12 (27.2 MB/s) - ‘valid.txt’ saved [655979/655979]



In [13]:
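# read the file to inspect it (explicit UTF-8 so the result does not depend on the system default encoding)
with open('valid.txt', 'r', encoding='utf-8') as f:
    valid = f.read()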

In [14]:
# encode validation dataset into a tensor
valid_data = torch.tensor(encode(valid), dtype=torch.long)
print(valid_data.shape, valid_data.dtype) # .dtype is an attribute; .type is a method and needs ()
print(valid_data[:100])

# the repetitive download/read/encode cells above are ripe for a function

torch.Size([652605]) torch.int64
tensor([28,  1,  9, 10, 10,  1, 15, 10,  1,  5,  9,  2, 15,  2,  1, 17, 16, 21,
        16,  7, 22,  1, 15,  2,  1, 26,  2,  1,  9,  2, 21,  2, 19, 10,  1,  9,
         2, 20,  2,  1, 22, 12, 10, 27, 10, 15,  8,  2, 21, 10,  2,  1,  3,  2,
         2,  5,  9, 10,  1, 26,  2,  1, 24,  2, 15,  2, 15,  4,  9, 10,  1, 24,
         6, 15,  8, 10,  1, 24,  2, 15,  2,  1, 14,  2, 20,  9,  2, 12,  2,  1,
        15,  2,  1, 22, 27,  2, 13,  6, 15,  5])
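The comment above notes that the download/read/encode cells repeat for each split. One way to factor them out is sketched below, assuming all three files stay under the base URL used in the `wget` cells. The helpers `read_split` and `encode_split` are hypothetical. `read_split` downloads `<name>.txt` with the standard library's `urllib.request.urlretrieve` only if it is not already on disk, then reads it as UTF-8. `encode_split` takes `encode` as an argument, so the vocabulary is still the one built from the training text.

```python
import os
import urllib.request

import torch

# Base URL copied from the wget cells above.
BASE_URL = "https://raw.githubusercontent.com/atonui/pds/refs/heads/main/swahili_data"


def read_split(name, base_url=BASE_URL, cache_dir="."):
    """Return the text of <name>.txt, downloading it first if it is not on disk (hypothetical helper)."""
    path = os.path.join(cache_dir, f"{name}.txt")
    if not os.path.exists(path):  # download once, then reuse the local copy
        urllib.request.urlretrieve(f"{base_url}/{name}.txt", path)
    with open(path, "r", encoding="utf-8") as f:
        return f.read()


def encode_split(name, encode, **kwargs):
    """Read a split and turn it into a 1-D LongTensor of token IDs using `encode` (hypothetical helper)."""
    return torch.tensor(encode(read_split(name, **kwargs)), dtype=torch.long)
```

With these helpers, the notebook could run `text = read_split('train')`, then the vocabulary and `encode` cells, then `train_data, test_data, valid_data = (encode_split(n, encode) for n in ('train', 'test', 'valid'))`. A character that appears only in the test or validation file would still raise a `KeyError`, because `encode` uses the training vocabulary.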


In [15]:
# training block size: the maximum context length the model sees
block_size = 8
train_data[:block_size+1]
# the transformer is not trained on the entire text at once but on chunks of it; a chunk of block_size+1 = 9 characters holds 8 (context, target) training examples

tensor([28,  1, 21,  2,  2, 19, 10,  7,  2])

In [16]:
x = train_data[:block_size] # inputs to the transformer
y = train_data[1:block_size+1] # targets: the same window shifted one position to the right
# every prefix of x is a context, and the target is the character that follows it
for t in range(block_size):
  context = x[:t+1]
  target = y[t]
  print(f'when input is {context} the target: {target}')

when input is tensor([28]) the target: 1
when input is tensor([28,  1]) the target: 21
when input is tensor([28,  1, 21]) the target: 2
when input is tensor([28,  1, 21,  2]) the target: 2
when input is tensor([28,  1, 21,  2,  2]) the target: 19
when input is tensor([28,  1, 21,  2,  2, 19]) the target: 10
when input is tensor([28,  1, 21,  2,  2, 19, 10]) the target: 7
when input is tensor([28,  1, 21,  2,  2, 19, 10,  7]) the target: 2


In [17]:
torch.manual_seed(1337)
batch_size = 4 # how many independent sequences will we process in parallel?
block_size = 8 # what is the maximum context length for predictions?

def get_batch(split):
  # generate a small batch of data of inputs x and targets y
  data = train_data if split == 'train' else valid_data
  ix = torch.randint(len(data) - block_size, (batch_size,))
  x = torch.stack([data[i:i+block_size] for i in ix])
  y = torch.stack([data[i+1:i+block_size+1] for i in ix])
  return x,y

xb, yb = get_batch('train')
print('inputs: ')
print(xb.shape)
print(xb)
print('targets: ')
print(yb.shape)
print(yb)

print('----------')

for b in range(batch_size): # batch dimension
  for t in range(block_size): # time dimension
    context = xb[b, :t+1]
    target = yb[b,t]
    print(f'when input is {context.tolist()} the target: {target}')

inputs: 
torch.Size([4, 8])
tensor([[12, 22, 13, 10, 14,  7,  2, 15],
        [20, 10, 20, 10, 21, 10, 27,  2],
        [11, 10,  1, 24,  2,  1, 21,  2],
        [ 2, 16,  1, 27, 10, 21,  2, 12]])
targets: 
torch.Size([4, 8])
tensor([[22, 13, 10, 14,  7,  2, 15, 26],
        [10, 20, 10, 21, 10, 27,  2,  1],
        [10,  1, 24,  2,  1, 21,  2, 15],
        [16,  1, 27, 10, 21,  2, 12,  2]])
----------
when input is [12] the target: 22
when input is [12, 22] the target: 13
when input is [12, 22, 13] the target: 10
when input is [12, 22, 13, 10] the target: 14
when input is [12, 22, 13, 10, 14] the target: 7
when input is [12, 22, 13, 10, 14, 7] the target: 2
when input is [12, 22, 13, 10, 14, 7, 2] the target: 15
when input is [12, 22, 13, 10, 14, 7, 2, 15] the target: 26
when input is [20] the target: 10
when input is [20, 10] the target: 20
when input is [20, 10, 20] the target: 10
when input is [20, 10, 20, 10] the target: 21
when input is [20, 10, 20, 10, 21] the target: 10
when in
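Each row of `yb` is the matching row of `xb` shifted one position to the left, with the next character appended. That is why one row yields `block_size` (context, target) examples. The small sketch below checks this invariant on a toy tensor. `get_batch_from` is a hypothetical variant of `get_batch` that takes the data and sizes as arguments instead of reading globals, so the example runs on its own.

```python
import torch


def get_batch_from(data, batch_size, block_size):
    """Same sampling logic as get_batch above, with data and sizes passed in (hypothetical helper)."""
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i + block_size] for i in ix])
    y = torch.stack([data[i + 1:i + block_size + 1] for i in ix])
    return x, y


torch.manual_seed(0)
toy = torch.arange(100)  # toy "dataset" whose token at position p is simply p
xt, yt = get_batch_from(toy, batch_size=4, block_size=8)

# targets are the inputs shifted left by one position...
assert torch.equal(yt[:, :-1], xt[:, 1:])
# ...so for this toy data every target is simply its input plus one
assert torch.equal(yt, xt + 1)
print(xt.shape, yt.shape)
```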

In [18]:
print(xb)

tensor([[12, 22, 13, 10, 14,  7,  2, 15],
        [20, 10, 20, 10, 21, 10, 27,  2],
        [11, 10,  1, 24,  2,  1, 21,  2],
        [ 2, 16,  1, 27, 10, 21,  2, 12]])


In [19]:
import torch.nn as nn
from torch.nn import functional as F
torch.manual_seed(1337)

class BigramLanguageModel(nn.Module):

  def __init__(self, vocab_size):
    super().__init__()
    # each token directly reads off the logits for the next token from a lookup table
    self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)


  def forward(self, idx, targets=None):
    # idx and targets are both (B,T) tensors of integers
    logits = self.token_embedding_table(idx) # (Batch,Time,Channel) tensor

    if targets is None:
      loss = None
    else:
      # cross_entropy expects (N, C) logits and (N,) targets, so flatten batch and time together
      B,T,C = logits.shape
      logits = logits.view(B*T, C)
      targets = targets.view(B*T)
      loss = F.cross_entropy(logits, targets)

    return logits, loss

  def generate(self, idx, max_new_tokens):
    # idx is (B, T) array of indices in the current context
    for _ in range(max_new_tokens):
      # get the predictions
      logits, loss = self(idx)
      # focus only on the last time step
      logits = logits[:,-1,:] # becomes (B,C)
      # apply softmax to get probabilities
      probs = F.softmax(logits, dim=-1) # (B,C)
      # sample from the distribution
      idx_next = torch.multinomial(probs, num_samples=1) # (B,1)
      # append sampled index to the running sequence
      idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
    return idx

m = BigramLanguageModel(vocab_size)
logits, loss = m(xb, yb) # logits are (B*T, C) here because targets were passed
print(logits.shape)
print(loss)

# start from a single token 0 (the newline character) and sample 100 new characters
print(decode(m.generate(idx = torch.zeros((1,1), dtype=torch.long), max_new_tokens=100)[0].tolist()))


torch.Size([32, 29])
tensor(4.0687, grad_fn=<NllLossBackward0>)

pvpv﻿nauzhklblpnnamd
nkrhvgeh lkjrokmjrulbsbuzwna
p
qko
enbromnabuzwcyrhmmnnhsnbxpsxmuowomabsv apitg
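A useful sanity check on the initial loss: if the untrained model spread its probability evenly over the 29 characters, the cross-entropy would be −ln(1/29) = ln 29 ≈ 3.37. The printed 4.0687 is higher. That is consistent with `nn.Embedding` initialising its weights from a standard normal distribution, which makes the initial logits far from uniform. The sketch below confirms the uniform baseline with the same `F.cross_entropy` call the model uses.

```python
import math

import torch
from torch.nn import functional as F

n_chars = 29                          # number of distinct characters found above
uniform_loss = math.log(n_chars)      # -ln(1/29) = ln(29)

# All-zero logits make softmax assign every character probability 1/29,
# so the cross-entropy is ln(29) whatever the targets are.
uniform_logits = torch.zeros(32, n_chars)        # (B*T, C), as printed for the bigram model above
targets = torch.randint(n_chars, (32,))
print(round(uniform_loss, 4))                    # → 3.3673
print(round(F.cross_entropy(uniform_logits, targets).item(), 4))
```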


In [20]:
# create a pytorch optimiser object
optimiser = torch.optim.AdamW(m.parameters(), lr=1e-3)

In [23]:
batch_size = 32 # get_batch reads this global, so each batch now holds 32 sequences
for steps in range(10000):
  # sample a batch of data
  xb, yb = get_batch('train')

  # evaluate the loss
  logits, loss = m(xb, yb)
  optimiser.zero_grad(set_to_none=True)
  loss.backward()
  optimiser.step()

print(loss.item()) # loss on the final batch only, so it is noisy

2.0956337451934814
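The number printed after training is the loss on one random batch of 32 chunks of 8 characters, so it is noisy and says nothing about held-out text. A steadier readout averages the loss over several batches for both splits with gradients disabled. It is sketched here as a hypothetical `estimate_loss` helper (a common pattern, not code from this notebook). It assumes the notebook's `get_batch`, which treats any split other than `'train'` as the validation split, and a model that returns `(logits, loss)` like `BigramLanguageModel`.

```python
import torch


@torch.no_grad()  # only measuring the loss, so no gradients are needed
def estimate_loss(model, get_batch, eval_iters=200):
    """Average the loss over `eval_iters` random batches for 'train' and 'val' (hypothetical helper)."""
    model.eval()  # the bigram model has no dropout, but this keeps the helper reusable
    out = {}
    for split in ("train", "val"):
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            xb, yb = get_batch(split)
            _, loss = model(xb, yb)
            losses[k] = loss.item()
        out[split] = losses.mean().item()
    model.train()  # switch back to training mode
    return out


# Usage with the notebook's objects:
# print(estimate_loss(m, get_batch))
```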


In [24]:
print(decode(m.generate(idx = torch.zeros((1,1), dtype=torch.long), max_new_tokens=500)[0].tolist()))


  rushio kitanajanimilihu  ka zif nekiza uba nanama nalekwaada kuka da he ya shi kikwenao ga ba liji ya bwalo ha u ao ngo sti mimba mpando maja zo wanuna  vi sa ja ya mchana yoarisitatanda ansuoto kayawezobaweo nawekwahio towalivikuwa yafibotida yoni yoaa 
 hekeku mea a msho  i ya a ka  hda wa m heria tali di kemkug kurekakusa wa zama waniyonanzartaria ho wali kawa kundirima 
 nzemisi  ki hila ta ka cofa tikifa chenili ka mbotana letihio  
 snenicmsalimgi cha mya yosungoji wenya 
 hizai mbafanik
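In the bigram model the logits at each position are a single row of the embedding table, indexed by that position's token. Feeding `generate` the whole context or only the last token therefore gives identical next-token logits. The check below uses a freshly initialised `nn.Embedding` of the same shape as a stand-in for `BigramLanguageModel.token_embedding_table`, an assumption made so the example is self-contained.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
n_chars = 29
table = nn.Embedding(n_chars, n_chars)  # same shape as the bigram model's lookup table

context = torch.tensor([[12, 22, 13, 10, 14]])  # (B=1, T=5) token IDs
full = table(context)[:, -1, :]                 # next-token logits given the whole context
last_only = table(context[:, -1:])[:, -1, :]    # next-token logits given only the last token

print(torch.equal(full, last_only))  # earlier tokens have no influence on the prediction
```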


Training improves the bigram model considerably: the loss falls from about 4.07 at initialisation to about 2.10 on the last training batch, and the samples look far more like Swahili. We are not there yet, though. In a bigram model the tokens do not talk to each other: each prediction is made from the very last character alone. Next, the tokens need to exchange information so the model can use the wider context and make better predictions, which is what the **Transformer** will do.
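One common way to start letting tokens talk to each other is sketched here rather than taken from this notebook. Each position aggregates information from itself and all earlier positions, never later ones, using a lower-triangular mask and a softmax. This version computes a plain running average. Self-attention replaces these uniform weights with learned, data-dependent ones. The shapes `B, T, C` are arbitrary choices for illustration.

```python
import torch
from torch.nn import functional as F

torch.manual_seed(1337)
B, T, C = 2, 8, 4                       # batch, time (context length), channels
feats = torch.randn(B, T, C)            # stand-in for per-token vectors

tril = torch.tril(torch.ones(T, T))     # 1s on and below the diagonal
wei = torch.zeros(T, T)
wei = wei.masked_fill(tril == 0, float("-inf"))  # positions may not look at the future
wei = F.softmax(wei, dim=-1)            # row t: weight 1/(t+1) on positions 0..t, 0 elsewhere
out = wei @ feats                       # (T, T) @ (B, T, C) -> (B, T, C)

# position t now holds the mean of positions 0..t
assert torch.allclose(out[:, 3], feats[:, :4].mean(dim=1), atol=1e-6)
print(out.shape)
```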