In [1]:
from itertools import chain
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import math
from typing import List, Tuple, Dict, Optional
from torch.utils.tensorboard import SummaryWriter

import base64
import requests
import os

from pathlib import Path

import numpy as np

from dataclasses import dataclass
from typing import Protocol

## Hyper-Parameters

In [2]:
@dataclass
class HyperParams:
    #Meta
    model_name: str
    token_type: str
    #Training
    epochs: int
    dropout: float
    #Model
    layers: int
    emb_size: int
    heads: int
    ff_dim: int
    #Data 
    file_name: str
    batch_size: int

hp = HyperParams(
    model_name="LocalTestForDelete",
    token_type="Word",
    file_name="spa",
    layers=3,
    emb_size=256,
    heads=8,
    ff_dim=512,
    epochs=10,
    dropout=0.1,
    batch_size=32,
)

GH_TOKEN = ""

## Get Environment variables

In [3]:
def is_running_in_colab():
    try:
        import google.colab
        return True
    except ImportError:
        return False

IS_COLAB = is_running_in_colab()

In [4]:
DEVICE = 'cuda' if IS_COLAB else 'cpu'

## Data file

In [5]:
def get_text_path(file_name: str):
    if IS_COLAB:
        url = f"https://raw.githubusercontent.com/WillCable97/HomemadeTransformer/main/data/raw/{file_name}.txt"
        !wget {url}
        return f"./{file_name}.txt"
    else:
        from homemadetransformer.config import DATA_DIR
        return DATA_DIR / "raw" / f"{file_name}.txt"

## Tokenizers

In [6]:
from typing import Protocol

class Tokenizer(Protocol):
    def tokenize(self, text: str) -> List[str]:
        ...

    def add_special_tokens(self, tokens: List[str]) -> List[str]:
        ...

## Simple Word Tokenizer

In [7]:
class SimpleWordTokenizer:
    """Simple word-based tokenizer that splits on whitespace."""
    
    def __init__(self, sos_token="<SOS>", eos_token="<EOS>", pad_token="<PAD>", unk_token="<UNK>"):
        self.sos_token = sos_token
        self.eos_token = eos_token
        self.pad_token = pad_token
        self.unk_token = unk_token
        
    def tokenize(self, text: str) -> List[str]:
        return text.lower().strip().split()
    
    def add_special_tokens(self, tokens: List[str]) -> List[str]:
        return [self.sos_token] + tokens + [self.eos_token]


## Preprocessor

In [24]:
class TranslationPreprocessor:
    def __init__(self, src_tok: Tokenizer | None = None, tgt_tok: Tokenizer | None = None):
        self.src_tok = src_tok or SimpleWordTokenizer() #Default to this 
        self.tgt_tok = tgt_tok or SimpleWordTokenizer()
        self.max_src_len = None
        self.max_tgt_len = None

        # will be filled by fit()
        self.src2idx: dict[str, int] = {}
        self.tgt2idx: dict[str, int] = {}
        self.idx2src: dict[int, str] = {}
        self.idx2tgt: dict[int, str] = {}
        self.src_len = self.tgt_len = 0  # effective lengths

    # ---------- helpers ------------------------------------------------------

    def _tok_plus_specials(self, text: str, tkzr: Tokenizer) -> List[str]:
        """tokenise then add <SOS>/<EOS> (or whatever the tkzr decides)."""
        return tkzr.add_special_tokens(tkzr.tokenize(text))

    @staticmethod
    def _build_vocab(sentences: List[List[str]], pad: str, unk: str) -> dict:
        vocab = {pad: 0, unk: 1} #Reserve these from the beginning
        for tok in chain.from_iterable(sentences):
            vocab.setdefault(tok, len(vocab))
        return vocab

    def _encode_pad(self, tokens: List[str], vocab: dict
                    , max_len: int, pad_idx: int) -> List[int]:
        idxs = [vocab.get(t, vocab["<UNK>"]) for t in tokens][:max_len]
        idxs += [pad_idx] * (max_len - len(idxs))
        return idxs

    def src_ids(self, text: str) -> torch.LongTensor:
        toks = self._tok_plus_specials(text, self.src_tok)
        ids  = [self.src2idx.get(t, self.src2idx[self.src_tok.unk_token]) for t in toks]
        return torch.tensor(ids, dtype=torch.long)

    def text_from_tgt(self, ids: torch.LongTensor) -> str:
        return " ".join(self.idx2tgt[i.item()] for i in ids)


    # ---------- public ----------------------
    def fit(self, pairs: List[Tuple[str, str]]) -> None: #Takes in pairs of sentences
        #Do all the tokenisation
        src_tokd = [self._tok_plus_specials(s, self.src_tok) for s, _ in pairs]
        tgt_tokd = [self._tok_plus_specials(t, self.tgt_tok) for _, t in pairs]

        self.src2idx = self._build_vocab(src_tokd, self.src_tok.pad_token, self.src_tok.unk_token)
        self.tgt2idx = self._build_vocab(tgt_tokd, self.tgt_tok.pad_token, self.tgt_tok.unk_token)
        self.idx2src = {i: w for w, i in self.src2idx.items()}
        self.idx2tgt = {i: w for w, i in self.tgt2idx.items()}

        # decide max effective lengths
        self.src_len = self.max_src_len or max(map(len, src_tokd))
        self.tgt_len = self.max_tgt_len or max(map(len, tgt_tokd))

    def transform(self, pairs: List[Tuple[str, str]]):
        pad_src = self.src2idx[self.src_tok.pad_token]
        pad_tgt = self.tgt2idx[self.tgt_tok.pad_token]

        src_enc = [
            self._encode_pad(
                self._tok_plus_specials(s, self.src_tok),
                self.src2idx,
                self.src_len,
                pad_src,
            )
            for s, _ in pairs
        ]
        tgt_enc = [
            self._encode_pad(
                self._tok_plus_specials(t, self.tgt_tok),
                self.tgt2idx,
                self.tgt_len,
                pad_tgt,
            )
            for _, t in pairs
        ]
        return src_enc, tgt_enc

    def fit_transform(self, pairs: List[Tuple[str, str]]):
        self.fit(pairs)
        return self.transform(pairs)

## Full dataset object

In [25]:
class TranslationDataset(Dataset):
    """PyTorch Dataset for translation data."""
    
    def __init__(self, src_data: List[List[int]], tgt_data: List[List[int]]):
        self.src_data = src_data
        self.tgt_data = tgt_data
        
    def __len__(self):
        return len(self.src_data)
    
    def __getitem__(self, idx):
        return torch.tensor(self.src_data[idx]), torch.tensor(self.tgt_data[idx])


## Standard Model

In [26]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * -math.log(10000.0) / d_model)
        pe[:, 0::2] = torch.sin(pos * div_term)
        pe[:, 1::2] = torch.cos(pos * div_term)
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        return self.dropout(x + self.pe[:, :x.size(1)])


In [42]:
class Seq2SeqTransformer(nn.Module):
    def __init__(self, num_encoder_layers, num_decoder_layers, emb_size, nhead,
                 src_vocab_size, tgt_vocab_size, dim_feedforward, dropout):
        super().__init__()
        self.src_tok_emb = nn.Embedding(src_vocab_size, emb_size)
        self.tgt_tok_emb = nn.Embedding(tgt_vocab_size, emb_size)
        self.positional_encoding = PositionalEncoding(emb_size, dropout)

        self.transformer = nn.Transformer(d_model=emb_size, nhead=nhead,
                                          num_encoder_layers=num_encoder_layers,
                                          num_decoder_layers=num_decoder_layers,
                                          dim_feedforward=dim_feedforward,
                                          dropout=dropout)
        self.generator = nn.Linear(emb_size, tgt_vocab_size)

    def forward(self, src, tgt, src_mask, tgt_mask, src_pad_mask, tgt_pad_mask, mem_pad_mask):
        src_emb = self.positional_encoding(self.src_tok_emb(src))
        tgt_emb = self.positional_encoding(self.tgt_tok_emb(tgt))
        out = self.transformer(src_emb, tgt_emb, src_mask, tgt_mask,
                               None, src_pad_mask, tgt_pad_mask, mem_pad_mask)
        return self.generator(out)
    
    def encode(self, src):                    
        src_mask = src.new_zeros((src.size(0), src.size(0)), dtype=torch.bool)
        return self.transformer.encoder(
            self.positional_encoding(self.src_tok_emb(src)), src_mask
        )

    def decode_step(self, tgt, memory):      
        T = tgt.size(0)
        tgt_mask = torch.triu(torch.ones(T, T, device=tgt.device, dtype=torch.bool), 1)
        out = self.transformer.decoder(
            self.positional_encoding(self.tgt_tok_emb(tgt)), memory, tgt_mask
        )
        return self.generator(out)           


## Performing + Testing translation

In [43]:
@torch.no_grad()
def greedy_translate(model, sentence: str, pp, max_len: int = 50, device="cpu") -> str:
    model.eval()
    device   = torch.device(device)

    sos, eos = pp.tgt2idx[pp.tgt_tok.sos_token], pp.tgt2idx[pp.tgt_tok.eos_token]

    # encode source
    src_ids  = pp.src_ids(sentence).to(device)[:, None]          
    memory   = model.encode(src_ids)                             

    tgt_ids  = torch.tensor([[sos]], device=device)              

    for _ in range(max_len):
        logits = model.decode_step(tgt_ids, memory)              
        next_id = logits[-1].argmax(-1).item()# greedy

        if next_id == eos:
            break
        tgt_ids = torch.cat([tgt_ids, torch.tensor([[next_id]], device=device)])

    return pp.text_from_tgt(tgt_ids.squeeze(1)[1:])

In [None]:
def create_sample_translations(model, epoch: int, pp, device, max_len: int = 50):
    # Sample sentences to translate
    sample_sentences = [
        "Hello, how are you?",
        "I love this place.",
        "What time is it?",
        "Thank you very much.",
        "Good morning everyone.",
        "I'm excited about dinner tonight."
    ]
    
    model.eval()
    translations_text = f"## Epoch {epoch} Sample Translations\n\n"
    
    with torch.no_grad():
        for i, src_sentence in enumerate(sample_sentences):
            try:
                # Translate using your existing translate function
                translation = greedy_translate(model, src_sentence, pp, max_len, device)
                
                # Format for TensorBoard
                translations_text += f"**{src_sentence}** → **{translation}**\n\n"
                
            except Exception as e:
                translations_text += f"**{src_sentence}** → **Error: {str(e)}**\n\n"
    return translations_text

## Logging Wrapper

In [None]:
class MyLogger:
    def __init__(self, model_name):
        self.writer  = torch.utils.tensorboard.SummaryWriter(f'runs/{model_name}')

    def log_avg_loss(self, train_loss, val_loss, epoch):
        self.writer.add_scalar('Loss/epoch_avg_train', train_loss, epoch)
        self.writer.add_scalar('Loss/epoch_avg_val', val_loss, epoch)
    
    def log_step_level_loss(self, loss, step):
        self.writer.add_scalar('Loss/train_step', loss, step)

    def log_sample_translations(self, translations, epoch):
        self.writer.add_text('Translations/Samples', translations, epoch)
        self.writer.flush()

    def close(self):
        self.writer.close()

## Passing data into the model

In [None]:
def create_batch_forward_pass(model, src_batch, tgt_batch, pad_idx, device, loss_fn):
  src_batch = src_batch.transpose(0, 1).to(device) 
  tgt_batch = tgt_batch.transpose(0, 1).to(device) 

  tgt_input = tgt_batch[:-1, :]
  tgt_output = tgt_batch[1:, :]

  src_mask, tgt_mask, src_pad_mask, tgt_pad_mask = create_mask(src_batch, tgt_input, pad_idx)

  logits = model(src_batch, tgt_input, src_mask, tgt_mask, src_pad_mask, tgt_pad_mask, src_pad_mask)

  logits_flat = logits.reshape(-1, logits.shape[-1])
  tgt_output_flat = tgt_output.reshape(-1)

  loss = loss_fn(logits_flat, tgt_output_flat)

  return loss, logits_flat

## Github interact

In [None]:
def upload_to_github(local_path, repo_path, token, username="WillCable97", repo="HomemadeTransformer", branch="main"):
    
    with open(local_path, "rb") as f:
        content = base64.b64encode(f.read()).decode("utf-8")

    url = f"https://api.github.com/repos/{username}/{repo}/contents/{repo_path}"
    data = {
        "message": f"Upload {repo_path} from Colab",
        "content": content,
        "branch": branch
    }
    headers = {"Authorization": f"token {token}"}
    r = requests.put(url, headers=headers, json=data)
    if r.status_code == 201:
        print(f"✅ Uploaded {repo_path} to GitHub")
    else:
        print(f"❌ Failed to upload {repo_path} — {r.status_code}: {r.json()}")

## Training loop

In [33]:
def train_model(model, dataloader, val_dataloader, preprocessor, optimizer, loss_fn, num_epochs, my_logger, pad_idx=0, device='cpu'):
    model.to(device)
    model.train()
    
    global_step = 1

    for epoch in range(num_epochs):
        # For Logging progression
        print(f"Starting epoch {epoch}")
        dl_len = len(dataloader)
        i = 1

        model.train()
        total_loss = 0

        for eng_batch, spa_batch in dataloader:
            
            print(f"\rStarting {i}/{dl_len}", end="")
            i += 1
            global_step += 1

            loss, _ = create_batch_forward_pass(model, eng_batch, spa_batch, pad_idx, device, loss_fn)

            #Perform the training step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

            my_logger.log_step_level_loss(loss.item(), global_step)
        
        avg_loss = total_loss / len(dataloader)

        # Validation
        model.eval()
        total_val_loss = 0
        
        with torch.no_grad():
            for eng_batch, spa_batch in val_dataloader:
                loss, _ = create_batch_forward_pass(model, eng_batch, spa_batch, pad_idx, device, loss_fn)
                total_val_loss += loss.item()

        avg_val_loss = total_val_loss / len(val_dataloader)
        my_logger.log_avg_loss(avg_loss, avg_val_loss, epoch)

        #Log sample translations
        sample_translations = create_sample_translations(model, epoch, preprocessor, device)
        my_logger.log_sample_translations(sample_translations, epoch)
        
        print(f"Epoch {epoch + 1}/{num_epochs} — Train Loss: {avg_loss:.4f} — Val Loss: {avg_val_loss:.4f}")

    my_logger.close()
    
    return avg_loss, avg_val_loss


## Create All Instances

#### 1- Load data

In [34]:
#Data
text_path = get_text_path(hp.file_name)

with open(text_path, encoding="utf-8") as f:
    lines = f.read().strip().split("\n")[:1000]

pairs = [tuple(line.split("\t")[:2]) for line in lines if "\t" in line] 


#Pre process + Prepare
preprocessor = TranslationPreprocessor(
    src_tok=SimpleWordTokenizer(),
    tgt_tok=SimpleWordTokenizer()
)

src_data, tgt_data = preprocessor.fit_transform(pairs)

#### 2- Create data Objects

In [None]:
# Split into train and validation sets with fixed random seed
np.random.seed(42)
indices = np.random.permutation(len(src_data))
train_size = int(0.9 * len(src_data))

train_indices = indices[:train_size]
val_indices = indices[train_size:]

train_src_data = [src_data[i] for i in train_indices]
train_tgt_data = [tgt_data[i] for i in train_indices]
val_src_data = [src_data[i] for i in val_indices]
val_tgt_data = [tgt_data[i] for i in val_indices]

# Create datasets and dataloaders
train_dataset = TranslationDataset(train_src_data, train_tgt_data)
val_dataset = TranslationDataset(val_src_data, val_tgt_data)
train_dl = DataLoader(train_dataset, batch_size=hp.batch_size, shuffle=True)
val_dl = DataLoader(val_dataset, batch_size=hp.batch_size, shuffle=False)

#### 3- Training and mdel Objects

In [35]:
#Model
model = Seq2SeqTransformer(
    num_encoder_layers=hp.layers,
    num_decoder_layers=hp.layers,
    emb_size=hp.emb_size,
    nhead=hp.heads,
    src_vocab_size=len(preprocessor.src_vocab),
    tgt_vocab_size=len(preprocessor.tgt_vocab),
    dim_feedforward=hp.ff_dim,
    dropout=hp.dropout 
)

optimizer = torch.optim.Adam(model.parameters(), lr=0.0005)
loss_fn = nn.CrossEntropyLoss(ignore_index=0)


## Training

In [36]:

# Train it
losses = train_model(model, train_dl, val_dl, preprocessor
                     , optimizer, loss_fn
                     , num_epochs=EPOCHS, pad_idx=0
                     , device=DEVICE)


#torch.save(model.state_dict(), f'./models/{model_name}.pth')


if IS_COLAB:
    #Save Log file
    log_dir= Path(F"./runs/{MODEL_NAME}")
    latest_event_file = max(log_dir.glob("events.out.tfevents.*"), key=lambda f: f.stat().st_mtime, default=None)
    upload_to_github(latest_event_file,
                      f"./notebooks/runs/{MODEL_NAME}/{latest_event_file.name}"
                      , GH_TOKEN)
    
    #Do the model
    upload_to_github(f"./models/{MODEL_NAME}.pth",
                      f"./notebooks/models/{MODEL_NAME}.pth"
                      , GH_TOKEN)



Starting epoch 0
Starting 1/29



Starting 29/29Logged 5 sample translations for epoch 0
Epoch 1/10 — Train Loss: 5.1017 — Val Loss: 4.5857
Starting epoch 1
Starting 29/29Logged 5 sample translations for epoch 1
Epoch 2/10 — Train Loss: 4.5807 — Val Loss: 4.5208
Starting epoch 2
Starting 29/29Logged 5 sample translations for epoch 2
Epoch 3/10 — Train Loss: 4.3703 — Val Loss: 4.4191
Starting epoch 3
Starting 29/29Logged 5 sample translations for epoch 3
Epoch 4/10 — Train Loss: 4.1559 — Val Loss: 4.1696
Starting epoch 4
Starting 2/29

KeyboardInterrupt: 