<a href="https://colab.research.google.com/github/atonui/swahili-gpt/blob/main/swa_gpt.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Swahili-GPT
- A simple character-level, transformer-based language model trained on a Swahili dataset.

In [1]:
import kagglehub

# Download latest version
path = kagglehub.dataset_download("mpwolke/swahili")

print("Path to dataset files:", path)

# https://www.kaggle.com/datasets/mpwolke/swahili/code

Downloading from https://www.kaggle.com/api/v1/datasets/download/mpwolke/swahili?dataset_version_number=1...


100%|██████████| 2.65M/2.65M [00:00<00:00, 24.6MB/s]

Extracting files...

Path to dataset files: /root/.cache/kagglehub/datasets/mpwolke/swahili/versions/1


In [2]:
# load the training dataset
!wget https://raw.githubusercontent.com/atonui/pds/refs/heads/main/Swahili%20data/train.txt

--2025-08-19 14:06:48--  https://raw.githubusercontent.com/atonui/pds/refs/heads/main/Swahili%20data/train.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 7658045 (7.3M) [text/plain]
Saving to: ‘train.txt’


2025-08-19 14:06:49 (62.1 MB/s) - ‘train.txt’ saved [7658045/7658045]



In [3]:
# read it to inspect it (explicit UTF-8 so the result does not depend on the platform default)
with open('train.txt', 'r', encoding='utf-8') as f:
    text = f.read()

In [4]:
print('Length of dataset in characters:', len(text))

Length of dataset in characters: 7658045


In [5]:
print(text[:1000])

 taarifa hiyo ilisema kuwa ongezeko la joto la maji juu ya wastani katikati ya bahari ya UNK inaashiria kuwepo kwa mvua za el nino UNK hadi mwishoni mwa april ishirini moja sifuri imeelezwa kuwa ongezeko la joto magharibi mwa bahari ya hindi linatarajiwa kuhamia katikati ya bahari hiyo hali ambayo itasababisha pepo kutoka kaskazini mashariki kuvuma kuelekea bahari ya hindi 
 aidha ilisema kuwa mwelekeo wa kupungua kwa joto kusini mashariki mwa bahari ya atlantic UNK kusababisha pepo kutoka magharibi kuvuma kuelekea magharibi mwa tanzania katika maeneo ya ziwa victoria 
 mwelekeo wa mvua wa septemba hadi desemba ishirini sifuri tisa unatarajiwa kuwa katika namna tofauti ambapo baadhi ya maeneo yanaweza kunufaika huku mengine UNK 
 ilifafanua kuwa msimu wa vuli UNK maeneo ambayo hupata mvua mara mbili ambayo ni kaskazini mwa nchi ikiwa ni nyanda za juu kaskazini mashariki kanda ya ziwa victoria na pwani ya kaskazini 
 katika maeneo hayo mvua zinatarajiwa kunyesha wiki ya pili na tatu ya 

In [26]:
# get a sorted list of all the characters in the text
chars = sorted(list(set(text)))
vocab_size = len(chars)
print(''.join(chars))
print("There are", vocab_size, "unique characters in this text.")


 KNUabcdefghijklmnopqrstuvwxyz
There are 31 unique characters in this text.


## Tokenisation
- Represent each character as an integer ID so the model can process text numerically.
- The tokeniser below is deliberately simple: it maps each character to an integer and back.
- More sophisticated tokenisers exist (for example, subword tokenisers); we will experiment with them later.

In [7]:
# create a mapping from characters to integers
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers
decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string

print(encode("joto jingi"))
print(decode(encode("joto jingi")))

[14, 19, 24, 19, 1, 14, 13, 18, 11, 13]
joto jingi


In [8]:
print(decode([14, 19, 24, 19, 1, 16, 13, 18, 11, 13]))

joto lingi
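The `stoi`/`itos` mapping above can also be wrapped in a small class that refuses characters outside the vocabulary instead of failing with a bare `KeyError`. This is a minimal sketch in plain Python; the class name `CharTokeniser` and the toy corpus are hypothetical, not part of the notebook.

```python
class CharTokeniser:
    """Hypothetical character-level tokeniser: one integer ID per distinct character."""

    def __init__(self, corpus: str):
        # sort the distinct characters so the IDs are deterministic across runs
        self.chars = sorted(set(corpus))
        self.stoi = {ch: i for i, ch in enumerate(self.chars)}
        self.itos = {i: ch for i, ch in enumerate(self.chars)}

    @property
    def vocab_size(self) -> int:
        return len(self.chars)

    def encode(self, s: str) -> list[int]:
        # report every unseen character at once rather than failing on the first KeyError
        unknown = set(s) - set(self.stoi)
        if unknown:
            raise ValueError(f"characters not in vocabulary: {sorted(unknown)}")
        return [self.stoi[c] for c in s]

    def decode(self, ids: list[int]) -> str:
        return "".join(self.itos[i] for i in ids)


tok = CharTokeniser(" joto jingi lingi")
ids = tok.encode("joto lingi")
print(tok.vocab_size, ids, tok.decode(ids))  # prints: 8 [3, 6, 7, 6, 0, 4, 2, 5, 1, 2] joto lingi
```

As in the notebook, the IDs depend only on the sorted set of characters, so the same corpus always yields the same mapping.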


In [9]:
# now we're going to encode the entire dataset and store it into a torch.Tensor
import torch
train_data = torch.tensor(encode(text), dtype=torch.long)
print(train_data.shape, train_data.dtype)
print(train_data[:1000])

torch.Size([7658045]) torch.int64
tensor([ 1, 24,  5,  5, 22, 13, 10,  5,  1, 12, 13, 29, 19,  1, 13, 16, 13, 23,
         9, 17,  5,  1, 15, 25, 27,  5,  1, 19, 18, 11,  9, 30,  9, 15, 19,  1,
        16,  5,  1, 14, 19, 24, 19,  1, 16,  5,  1, 17,  5, 14, 13,  1, 14, 25,
        25,  1, 29,  5,  1, 27,  5, 23, 24,  5, 18, 13,  1, 15,  5, 24, 13, 15,
         5, 24, 13,  1, 29,  5,  1,  6,  5, 12,  5, 22, 13,  1, 29,  5,  1,  4,
         3,  2,  1, 13, 18,  5,  5, 23, 12, 13, 22, 13,  5,  1, 15, 25, 27,  9,
        20, 19,  1, 15, 27,  5,  1, 17, 26, 25,  5,  1, 30,  5,  1,  9, 16,  1,
        18, 13, 18, 19,  1,  4,  3,  2,  1, 12,  5,  8, 13,  1, 17, 27, 13, 23,
        12, 19, 18, 13,  1, 17, 27,  5,  1,  5, 20, 22, 13, 16,  1, 13, 23, 12,
        13, 22, 13, 18, 13,  1, 17, 19, 14,  5,  1, 23, 13, 10, 25, 22, 13,  1,
        13, 17,  9,  9, 16,  9, 30, 27,  5,  1, 15, 25, 27,  5,  1, 19, 18, 11,
         9, 30,  9, 15, 19,  1, 16,  5, 

In [10]:
# test dataset
!wget https://raw.githubusercontent.com/atonui/pds/refs/heads/main/Swahili%20data/test.txt

--2025-08-19 14:06:55--  https://raw.githubusercontent.com/atonui/pds/refs/heads/main/Swahili%20data/test.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.109.133, 185.199.108.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.109.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 691715 (676K) [text/plain]
Saving to: ‘test.txt’


2025-08-19 14:06:55 (16.9 MB/s) - ‘test.txt’ saved [691715/691715]



In [11]:
# read it to inspect it
with open('test.txt', 'r', encoding='utf-8') as f:
    test = f.read()

In [12]:
# encode test dataset into a tensor
test_data = torch.tensor(encode(test), dtype=torch.long)
print(test_data.shape, test_data.dtype)
print(test_data[:1000])

torch.Size([691715]) torch.int64
tensor([ 1, 12, 25, 29, 19,  1,  5, 16, 13, 23, 13, 23, 13, 24, 13, 30,  5,  1,
        15, 25, 27,  5,  1, 12,  5, 15, 25, 12, 19, 14, 13, 27,  5,  1,  6,  5,
        16, 13,  1,  5, 16, 13, 20,  9, 27,  5,  1, 15,  5, 22,  5, 24,  5, 23,
        13,  1, 24, 25, 20, 25,  1, 18,  5,  1, 15, 25, 24,  5, 15, 13, 27,  5,
         1, 15, 25, 23,  5, 13, 18, 13,  1,  6, 13, 16,  5,  1, 15, 25, 10,  5,
        12,  5, 17, 25,  1, 18, 13,  1, 15, 13, 24, 25,  1, 11,  5, 18, 13,  1,
         4,  3,  2,  1,  0,  1,  5, 15, 13,  9, 16,  9, 30,  9,  5,  1, 23, 13,
        15, 25,  1, 29,  5,  1, 24, 25, 15, 13, 19,  1,  5, 16, 13, 23,  9, 17,
         5,  1,  5, 16, 13, 20, 13, 11, 13, 27,  5,  1, 23, 13, 17, 25,  1, 18,
         5,  1, 23, 23, 20,  1, 23,  5, 16, 25, 17,  1, 15, 13, 23,  5, 13,  1,
         5, 17,  6,  5, 29,  9,  1,  5, 16, 13, 17, 24,  5, 15,  5,  1, 15, 25,
        10, 13, 15,  5,  1, 15,  5, 24, 1

In [13]:
# validation dataset
!wget https://raw.githubusercontent.com/atonui/pds/refs/heads/main/Swahili%20data/valid.txt

--2025-08-19 14:06:55--  https://raw.githubusercontent.com/atonui/pds/refs/heads/main/Swahili%20data/valid.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 660142 (645K) [text/plain]
Saving to: ‘valid.txt’


2025-08-19 14:06:56 (18.5 MB/s) - ‘valid.txt’ saved [660142/660142]



In [14]:
# read it to inspect it
with open('valid.txt', 'r', encoding='utf-8') as f:
    valid = f.read()

In [15]:
# encode validation dataset into a tensor
valid_data = torch.tensor(encode(valid), dtype=torch.long)
print(valid_data.shape, valid_data.dtype)
print(valid_data[:100])

# the download/read/encode steps repeated in the cells above could be refactored into a single function

torch.Size([660142]) torch.int64
tensor([ 1, 12, 13, 13,  1, 18, 13,  1,  8, 12,  5, 18,  5,  1, 20, 19, 24, 19,
        10, 25,  1, 18,  5,  1, 29,  5,  1, 12,  5, 24,  5, 22, 13,  1, 12,  5,
        23,  5,  1, 25, 15, 13, 30, 13, 18, 11,  5, 24, 13,  5,  1,  6,  5,  5,
         8, 12, 13,  1, 29,  5,  1, 27,  5, 18,  5, 18,  7, 12, 13,  1, 27,  9,
        18, 11, 13,  1, 27,  5, 18,  5,  1, 17,  5, 23, 12,  5, 15,  5,  1, 18,
         5,  1, 25, 30,  5, 16,  9, 18,  8, 19])
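The train, test and validation splits are each read and encoded by near-identical cells. One way to factor out that step is a small helper. This is a sketch under assumptions: `load_split` is a hypothetical name, and it returns a plain list so it runs without PyTorch; in the notebook you would wrap the result as `torch.tensor(load_split('train.txt', encode), dtype=torch.long)`.

```python
import tempfile
from pathlib import Path
from typing import Callable, Union


def load_split(path: Union[str, Path], encode: Callable[[str], list[int]]) -> list[int]:
    """Hypothetical helper: read one text split as UTF-8 and return its token IDs."""
    split_text = Path(path).read_text(encoding="utf-8")
    return encode(split_text)


# usage sketch: a throwaway file and a toy character encoder (both illustrative)
toy_stoi = {ch: i for i, ch in enumerate(sorted(set(" joto")))}


def toy_encode(s: str) -> list[int]:
    return [toy_stoi[c] for c in s]


with tempfile.TemporaryDirectory() as tmp:
    split_path = Path(tmp) / "valid.txt"
    split_path.write_text(" joto", encoding="utf-8")
    toy_ids = load_split(split_path, toy_encode)

print(toy_ids)  # prints [0, 1, 2, 3, 2]
```

Passing the encoder in as an argument keeps the helper independent of any particular tokeniser, so it would still work if the character mapping is later swapped for a subword one.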


In [16]:
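# training block size (maximum context length)
block_size = 8
train_data[:block_size+1]
# the transformer is trained not on the entire text at once but on random blocks of it; slicing
# block_size+1 = 9 tokens gives 8 inputs, each paired with the token that follows it as the target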

tensor([ 1, 24,  5,  5, 22, 13, 10,  5,  1])

In [17]:
x = train_data[:block_size] # inputs to the transformer
y = train_data[1:block_size+1] # targets: the same block shifted one position to the right
# each prefix of x is a separate training example whose target is the next character
for t in range(block_size):
  context = x[:t+1]
  target = y[t]
  print(f'when input is {context} the target: {target}')

when input is tensor([1]) the target: 24
when input is tensor([ 1, 24]) the target: 5
when input is tensor([ 1, 24,  5]) the target: 5
when input is tensor([ 1, 24,  5,  5]) the target: 22
when input is tensor([ 1, 24,  5,  5, 22]) the target: 13
when input is tensor([ 1, 24,  5,  5, 22, 13]) the target: 10
when input is tensor([ 1, 24,  5,  5, 22, 13, 10]) the target: 5
when input is tensor([ 1, 24,  5,  5, 22, 13, 10,  5]) the target: 1


In [18]:
torch.manual_seed(1337)
batch_size = 4 # how many independent sequences will we process in parallel?
block_size = 8 # what is the maximum context length for predictions?

def get_batch(split):
  # generate a small batch of data of inputs x and targets y
  data = train_data if split == 'train' else valid_data
  ix = torch.randint(len(data) - block_size, (batch_size,))
  x = torch.stack([data[i:i+block_size] for i in ix])
  y = torch.stack([data[i+1:i+block_size+1] for i in ix])
  return x,y

xb, yb = get_batch('train')
print('inputs: ')
print(xb.shape)
print(xb)
print('targets: ')
print(yb.shape)
print(yb)

print('----------')

for b in range(batch_size): # batch dimension
  for t in range(block_size): # time dimension
    context = xb[b, :t+1]
    target = yb[b,t]
    print(f'when input is {context.tolist()} the target: {target}')

inputs: 
torch.Size([4, 8])
tensor([[ 5, 24, 13,  1, 27,  5,  1, 15],
        [13,  6, 13, 19,  1, 29,  5, 15],
        [ 1, 27,  5, 30, 19,  1, 16,  5],
        [13,  1, 12, 13, 16, 19,  1, 16]])
targets: 
torch.Size([4, 8])
tensor([[24, 13,  1, 27,  5,  1, 15, 25],
        [ 6, 13, 19,  1, 29,  5, 15,  9],
        [27,  5, 30, 19,  1, 16,  5,  1],
        [ 1, 12, 13, 16, 19,  1, 16, 13]])
----------
when input is [5] the target: 24
when input is [5, 24] the target: 13
when input is [5, 24, 13] the target: 1
when input is [5, 24, 13, 1] the target: 27
when input is [5, 24, 13, 1, 27] the target: 5
when input is [5, 24, 13, 1, 27, 5] the target: 1
when input is [5, 24, 13, 1, 27, 5, 1] the target: 15
when input is [5, 24, 13, 1, 27, 5, 1, 15] the target: 25
when input is [13] the target: 6
when input is [13, 6] the target: 13
when input is [13, 6, 13] the target: 19
when input is [13, 6, 13, 19] the target: 1
when input is [13, 6, 13, 19, 1] the target: 29
when input is [13, 6, 13, 19
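The sampling in `get_batch` can be written without PyTorch to make the index arithmetic explicit. This sketch uses `random.Random` in place of `torch.randint`, so it will not reproduce the batches printed above; `get_batch_py` and the toy data are illustrative names, not part of the notebook.

```python
import random


def get_batch_py(data: list[int], batch_size: int, block_size: int, rng: random.Random):
    """Plain-Python sketch of get_batch: random windows plus their one-step-shifted targets."""
    # randrange(n) returns 0..n-1, so the largest start is len(data) - block_size - 1,
    # which keeps the last target data[start + block_size] inside the list
    starts = [rng.randrange(len(data) - block_size) for _ in range(batch_size)]
    xs = [data[i : i + block_size] for i in starts]
    ys = [data[i + 1 : i + block_size + 1] for i in starts]
    return xs, ys


toy_data = list(range(20))  # stand-in token IDs, so each value equals its position
xs, ys = get_batch_py(toy_data, batch_size=4, block_size=8, rng=random.Random(1337))
for xrow, yrow in zip(xs, ys):
    # every target row is its input row shifted left by one position
    assert yrow[:-1] == xrow[1:]
    print(xrow, "->", yrow)
```

The upper bound `len(data) - block_size` is what guarantees the target slice never runs past the end of the data; the notebook's `torch.randint(len(data) - block_size, ...)` uses the same exclusive bound.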

In [19]:
print(xb)

tensor([[ 5, 24, 13,  1, 27,  5,  1, 15],
        [13,  6, 13, 19,  1, 29,  5, 15],
        [ 1, 27,  5, 30, 19,  1, 16,  5],
        [13,  1, 12, 13, 16, 19,  1, 16]])


In [20]:
import torch.nn as nn
from torch.nn import functional as F
torch.manual_seed(1337)

class BigramLanguageModel(nn.Module):

  def __init__(self, vocab_size):
    super().__init__()
    # each token directly reads off the logits for the next token from a lookup table
    self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

  def forward(self, idx, targets=None):
    # idx and targets are both (B,T) tensors of integers
    logits = self.token_embedding_table(idx) # (Batch,Time,Channel) tensor

    if targets is None:
      loss = None
    else:
      # F.cross_entropy expects (N,C) logits and (N,) targets, so flatten batch and time together
      B, T, C = logits.shape
      logits = logits.view(B*T, C)
      targets = targets.view(B*T)
      loss = F.cross_entropy(logits, targets)

    return logits, loss

  def generate(self, idx, max_new_tokens):
    # idx is a (B,T) array of indices in the current context
    for _ in range(max_new_tokens):
      # get the predictions
      logits, loss = self(idx)
      # focus only on the last time step
      logits = logits[:, -1, :] # becomes (B,C)
      # apply softmax to get probabilities
      probs = F.softmax(logits, dim=-1) # (B,C)
      # sample from the distribution
      idx_next = torch.multinomial(probs, num_samples=1) # (B,1)
      # append sampled index to the running sequence
      idx = torch.cat((idx, idx_next), dim=1) # (B,T+1)
    return idx

m = BigramLanguageModel(vocab_size)
logits, loss = m(xb, yb)
print(logits.shape) # (B*T, C) = (32, 31), because forward flattened the logits to compute the loss
print(loss)

print(decode(m.generate(idx = torch.zeros((1,1), dtype=torch.long), max_new_tokens=100)[0].tolist()))


torch.Size([32, 31])
tensor(4.2781, grad_fn=<NllLossBackward0>)

lmxfiaNwinuxgax meyjomaftuvwt sinywkqjNalpx zm 
ajfauhgskdhtfmNlpfeyuvghmN jNyUdxnlt rggsnuxsnlprs p
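A useful sanity check before training: a model that assigns equal probability to all 31 characters has a cross-entropy of $-\ln(1/31) = \ln 31 \approx 3.43$. The initial loss printed above (4.28) is higher, which suggests the randomly initialised logits are not uniform and start out worse than an even guess. The sketch below computes, in plain Python, the quantity `F.cross_entropy` returns with its default mean reduction: the average negative log-softmax probability of each target. The helper name `cross_entropy` is illustrative.

```python
import math


def cross_entropy(logits: list[list[float]], targets: list[int]) -> float:
    """Mean negative log-likelihood of the targets under softmax(logits), row by row."""
    total = 0.0
    for row, t in zip(logits, targets):
        # log-sum-exp with the row maximum subtracted for numerical stability
        row_max = max(row)
        log_z = row_max + math.log(sum(math.exp(v - row_max) for v in row))
        total += log_z - row[t]  # -log softmax(row)[t]
    return total / len(targets)


vocab = 31
uniform_logits = [[0.0] * vocab for _ in range(4)]  # equal logits give equal probabilities
baseline = cross_entropy(uniform_logits, [0, 5, 13, 30])
print(round(baseline, 4), round(math.log(vocab), 4))  # prints: 3.434 3.434
```

With equal logits the target index does not matter, which is why the result equals $\ln 31$ for any choice of targets.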


In [21]:
# create a pytorch optimiser object
optimiser = torch.optim.AdamW(m.parameters(), lr=1e-3)

In [22]:
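batch_size = 32 # get_batch reads this global, so every batch below has 32 sequences
for step in range(10000):
  # sample a batch of data
  xb, yb = get_batch('train')

  # evaluate the loss, backpropagate and update the parameters
  logits, loss = m(xb, yb)
  optimiser.zero_grad(set_to_none=True)
  loss.backward()
  optimiser.step()

print(loss.item()) # loss on the final batch only, so it is a noisy estimate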

2.0465545654296875
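The bigram model's lookup table is effectively learning next-character probabilities conditioned only on the current character. The same table can be estimated directly by counting character pairs. The unsmoothed counts give the lowest training loss any bigram model can approach, so running `bigram_nll(text, alpha=0.0)` on the full training text should give a floor against which the trained model can be compared (the 2.05 above comes from a single random batch, so it is only a rough estimate). The sketch below runs on a short sample; `bigram_nll` is an illustrative name, and add-alpha smoothing is an assumption added so unseen pairs would not get zero probability.

```python
import math
from collections import Counter


def bigram_nll(text: str, alpha: float = 1.0) -> float:
    """Average -log P(next char | current char) from pair counts with add-alpha smoothing."""
    vocab = sorted(set(text))
    pair_counts = Counter(zip(text, text[1:]))  # how often character b follows character a
    first_counts = Counter(text[:-1])           # how often character a has a successor
    total = 0.0
    for a, b in zip(text, text[1:]):
        p = (pair_counts[(a, b)] + alpha) / (first_counts[a] + alpha * len(vocab))
        total -= math.log(p)
    return total / (len(text) - 1)


sample = " taarifa hiyo ilisema kuwa ongezeko la joto la maji juu ya wastani "
print(round(bigram_nll(sample, alpha=0.0), 3), round(bigram_nll(sample, alpha=1.0), 3))
```

Smoothing always raises the training loss slightly, because the unsmoothed counts are the maximum-likelihood estimate for the text they were counted on.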


In [23]:
print(decode(m.generate(idx = torch.zeros((1,1), dtype=torch.long), max_new_tokens=500)[0].tolist()))


 nyeshi kwa fchioa ngweliumu tila ube
 kjanegezi UNK ki majauru ma kulifaumza i iki zo lmorikuti nani waia hi siletaya kaa yama ku mo tilbeta nsipilisa kwalitarwawaji aona lazwemja sigemkichado ase 
 UNK bana gu wasa ngekwakuchana e sa UNK kurandima 
 ia hari matpuulirugi 
 ay
 ifazao bo jima yo ku mbwamba he 
 li mbwa h kakali belite pso jumuvyaomatenara itra yafutango hu 
 ku liwa vujayar kwa batileni i wanga iku seadina ta i slileya wa ka dika 
 va unanaziana matu hokubarilosha ndangam si tev
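Each character above is produced by one step of the `generate` loop: take the logits for the last position, turn them into probabilities with softmax, and draw one index. The sketch below shows a single step in plain Python, with `random.choices` substituted for `torch.multinomial` (an illustrative substitution, not the notebook's code); `softmax` and `sample_next` are hypothetical helper names.

```python
import math
import random


def softmax(logits: list[float]) -> list[float]:
    # subtract the maximum before exponentiating to avoid overflow
    top = max(logits)
    exps = [math.exp(v - top) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sample_next(logits: list[float], rng: random.Random) -> int:
    """Draw one token ID with probability proportional to softmax(logits)."""
    probs = softmax(logits)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


rng = random.Random(0)
toy_logits = [2.0, 1.0, 0.1]
probs = softmax(toy_logits)
draws = [sample_next(toy_logits, rng) for _ in range(1000)]
print([round(p, 3) for p in probs], [draws.count(i) / len(draws) for i in range(3)])
```

Sampling rather than always taking the most likely index is what makes each generated passage different; the empirical frequencies of the draws should track the softmax probabilities.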
