In [1]:
from itertools import chain
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import math
from typing import List, Tuple, Dict, Optional
from torch.utils.tensorboard import SummaryWriter

from torch.optim.lr_scheduler import LambdaLR

import base64
import requests
import os

from pathlib import Path

import numpy as np

import datetime


from dataclasses import dataclass, asdict
from typing import Protocol

import math, torch
from typing import List, Tuple
from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction

## Hyper-Parameters

In [2]:
@dataclass
class HyperParams:
    """Run configuration for this notebook.

    Every field is required (there are no defaults).  The instance is turned
    into a plain dict with ``asdict`` when the hyper-parameters are logged at
    the end of training.
    """
    # Meta -- descriptive labels; in this notebook they are only recorded with
    # the logged hyper-parameters (the log directory is named by timestamp).
    model_name: str
    generation: str
    token_type: str
    # Training
    epochs: int      # passed to train_model as num_epochs
    dropout: float   # dropout probability handed to the model
    # Model
    layers: int      # used for both the encoder and the decoder depth
    emb_size: int    # embedding width; also d_model for the LR schedule
    heads: int       # number of attention heads (nhead)
    ff_dim: int      # feed-forward width (dim_feedforward)
    # Data
    file_name: str   # base name (no ".txt") given to get_text_path
    batch_size: int  # batch size of both DataLoaders

# Configuration of the current run; arguments follow the field order of
# HyperParams so the two can be compared line by line.
hp = HyperParams(
    # Meta
    model_name="LocalTestForDelete",
    generation="V2",
    token_type="Word",
    # Training
    epochs=10,
    dropout=0.1,
    # Model
    layers=3,
    emb_size=256,
    heads=8,
    ff_dim=512,
    # Data
    file_name="spa",
    batch_size=32,
)

# GitHub personal access token used when uploading run logs.  It is read from
# the environment so the secret never has to be typed into (and saved with) the
# notebook; when the variable is unset this is the same empty string as before.
GH_TOKEN = os.environ.get("GH_TOKEN", "")

## Get Environment variables

In [3]:
def is_running_in_colab():
    """Report whether the ``google.colab`` package can be imported."""
    try:
        import google.colab  # imported only to probe for its presence
    except ImportError:
        in_colab = False
    else:
        in_colab = True
    return in_colab


IS_COLAB = is_running_in_colab()

In [4]:
# Choose the device from what is actually available rather than from where the
# notebook runs: a Colab CPU runtime has no CUDA device, and a local machine
# may well have one.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

## Data file

In [5]:
def get_text_path(file_name: str):
    """Return the location of the raw ``<file_name>.txt`` corpus.

    On Colab the file is downloaded from the project's GitHub repository into
    the current working directory (skipped when a copy is already there) and
    its relative path is returned.  Elsewhere the path is built from the local
    package's ``DATA_DIR``.

    Args:
        file_name: Base name of the corpus file, without the ``.txt`` suffix.

    Returns:
        ``"./<file_name>.txt"`` (a ``str``) on Colab, otherwise
        ``DATA_DIR / "raw" / "<file_name>.txt"``.
    """
    if IS_COLAB:
        # Function-scope imports keep the cell self-contained.
        import urllib.parse
        import urllib.request

        local_path = f"./{file_name}.txt"
        if not os.path.exists(local_path):
            url = (
                "https://raw.githubusercontent.com/WillCable97/HomemadeTransformer/"
                f"main/data/raw/{urllib.parse.quote(file_name)}.txt"
            )
            # Pure-Python download instead of a shell call: nothing from
            # `file_name` reaches a shell, and re-running the cell does not
            # create numbered duplicate copies.  The data is written to a
            # temporary name first so an interrupted download is never
            # mistaken for a complete file on the next run.
            partial_path = local_path + ".part"
            urllib.request.urlretrieve(url, partial_path)
            os.replace(partial_path, local_path)
        return local_path
    else:
        from homemadetransformer.config import DATA_DIR
        return DATA_DIR / "raw" / f"{file_name}.txt"

## Tokenizers

In [6]:
from typing import Protocol

class Tokenizer(Protocol):
    """Structural interface for the tokenizers used by the preprocessor.

    Besides the two methods, an implementation must expose the four
    special-token strings below: the preprocessing and translation code reads
    them to reserve vocabulary slots and to look up the pad / unknown / start /
    end ids.
    """

    # Special-token strings (declared so the protocol matches how it is used).
    sos_token: str
    eos_token: str
    pad_token: str
    unk_token: str

    def tokenize(self, text: str) -> List[str]:
        """Split raw text into a list of token strings."""
        ...

    def add_special_tokens(self, tokens: List[str]) -> List[str]:
        """Return ``tokens`` with the tokenizer's start/end markers added."""
        ...

## Simple Word Tokenizer

In [7]:
class SimpleWordTokenizer:
    """Whitespace tokenizer that lower-cases its input and carries four special tokens."""

    def __init__(self, sos_token="<SOS>", eos_token="<EOS>", pad_token="<PAD>", unk_token="<UNK>"):
        # The special-token strings are stored exactly as given.
        self.sos_token, self.eos_token = sos_token, eos_token
        self.pad_token, self.unk_token = pad_token, unk_token

    def tokenize(self, text: str) -> List[str]:
        """Lower-case ``text``, trim it and split it on runs of whitespace."""
        normalised = text.lower().strip()
        return normalised.split()

    def add_special_tokens(self, tokens: List[str]) -> List[str]:
        """Return a new list: start marker, then ``tokens``, then end marker."""
        prefix, suffix = [self.sos_token], [self.eos_token]
        return prefix + tokens + suffix


## Preprocessor

In [8]:
class TranslationPreprocessor:
    """Turns (source, target) sentence pairs into fixed-length lists of vocabulary ids.

    ``fit`` tokenises the pairs, builds one vocabulary per side (the pad and
    unknown tokens are inserted first, so they normally get ids 0 and 1) and
    fixes the padded lengths; ``transform`` encodes pairs with those
    vocabularies.

    Args:
        src_tok: Tokenizer for source sentences; ``SimpleWordTokenizer()`` when omitted.
        tgt_tok: Tokenizer for target sentences; ``SimpleWordTokenizer()`` when omitted.
        max_src_len: Optional cap on the encoded source length, special tokens
            included.  ``None`` means "longest sentence seen by ``fit``".
        max_tgt_len: Same as ``max_src_len`` for the target side.
    """

    def __init__(self, src_tok: Tokenizer | None = None, tgt_tok: Tokenizer | None = None,
                 max_src_len: int | None = None, max_tgt_len: int | None = None):
        # `is None` rather than truthiness, so a tokenizer object that happens
        # to be falsy is not silently replaced by the default.
        self.src_tok = src_tok if src_tok is not None else SimpleWordTokenizer()
        self.tgt_tok = tgt_tok if tgt_tok is not None else SimpleWordTokenizer()
        self.max_src_len = max_src_len
        self.max_tgt_len = max_tgt_len

        # will be filled by fit()
        self.src2idx: dict[str, int] = {}
        self.tgt2idx: dict[str, int] = {}
        self.idx2src: dict[int, str] = {}
        self.idx2tgt: dict[int, str] = {}
        self.src_len = self.tgt_len = 0  # effective (padded) lengths

    # ---------- helpers ------------------------------------------------------

    def _tok_plus_specials(self, text: str, tkzr: Tokenizer) -> List[str]:
        """Tokenise ``text`` and let the tokenizer add its start/end markers."""
        return tkzr.add_special_tokens(tkzr.tokenize(text))

    @staticmethod
    def _build_vocab(sentences: List[List[str]], pad: str, unk: str) -> dict:
        """Map every distinct token to an id, in order of first appearance.

        ``pad`` and ``unk`` are inserted before any sentence token.
        """
        vocab = {pad: 0, unk: 1}  # reserve these from the beginning
        for tok in chain.from_iterable(sentences):
            vocab.setdefault(tok, len(vocab))
        return vocab

    def _encode_pad(self, tokens: List[str], vocab: dict
                    , max_len: int, pad_idx: int, unk_idx: int | None = None) -> List[int]:
        """Encode ``tokens`` to ids and force the result to exactly ``max_len`` entries.

        Unknown tokens map to ``unk_idx``.  When ``unk_idx`` is omitted the
        historical lookup ``vocab["<UNK>"]`` is used, which only works for the
        default unknown-token string.  Sequences longer than ``max_len`` are
        cut on the right (this removes the trailing end marker); shorter ones
        are right-padded with ``pad_idx``.
        """
        if unk_idx is None:
            unk_idx = vocab["<UNK>"]
        idxs = [vocab.get(t, unk_idx) for t in tokens][:max_len]
        idxs += [pad_idx] * (max_len - len(idxs))
        return idxs

    def src_ids(self, text: str) -> torch.LongTensor:
        """Encode one source sentence (with special tokens, no padding) as a 1-D long tensor."""
        toks = self._tok_plus_specials(text, self.src_tok)
        unk_idx = self.src2idx[self.src_tok.unk_token]
        ids = [self.src2idx.get(t, unk_idx) for t in toks]
        return torch.tensor(ids, dtype=torch.long)

    def text_from_tgt(self, ids: torch.LongTensor) -> str:
        """Join the target-vocabulary words for ``ids`` with single spaces.

        ``ids`` may be a 1-D tensor or any iterable of integers.
        """
        return " ".join(self.idx2tgt[int(i)] for i in ids)

    # ---------- public ----------------------
    def fit(self, pairs: List[Tuple[str, str]]) -> None:
        """Build both vocabularies and decide the padded lengths from ``pairs``."""
        src_tokd = [self._tok_plus_specials(s, self.src_tok) for s, _ in pairs]
        tgt_tokd = [self._tok_plus_specials(t, self.tgt_tok) for _, t in pairs]

        self.src2idx = self._build_vocab(src_tokd, self.src_tok.pad_token, self.src_tok.unk_token)
        self.tgt2idx = self._build_vocab(tgt_tokd, self.tgt_tok.pad_token, self.tgt_tok.unk_token)
        self.idx2src = {i: w for w, i in self.src2idx.items()}
        self.idx2tgt = {i: w for w, i in self.tgt2idx.items()}

        # An explicit cap wins; otherwise use the longest tokenised sentence
        # (0 when `pairs` is empty, instead of failing on an empty max()).
        self.src_len = self.max_src_len or max(map(len, src_tokd), default=0)
        self.tgt_len = self.max_tgt_len or max(map(len, tgt_tokd), default=0)

    def transform(self, pairs: List[Tuple[str, str]]):
        """Encode ``pairs`` into two parallel lists of fixed-length id lists.

        Must be called after ``fit``.  Returns ``(src_enc, tgt_enc)``.
        """
        pad_src = self.src2idx[self.src_tok.pad_token]
        pad_tgt = self.tgt2idx[self.tgt_tok.pad_token]
        # Use each tokenizer's own unknown token instead of a hard-coded "<UNK>".
        unk_src = self.src2idx[self.src_tok.unk_token]
        unk_tgt = self.tgt2idx[self.tgt_tok.unk_token]

        src_enc = [
            self._encode_pad(
                self._tok_plus_specials(s, self.src_tok),
                self.src2idx,
                self.src_len,
                pad_src,
                unk_src,
            )
            for s, _ in pairs
        ]
        tgt_enc = [
            self._encode_pad(
                self._tok_plus_specials(t, self.tgt_tok),
                self.tgt2idx,
                self.tgt_len,
                pad_tgt,
                unk_tgt,
            )
            for _, t in pairs
        ]
        return src_enc, tgt_enc

    def fit_transform(self, pairs: List[Tuple[str, str]]):
        """``fit`` on ``pairs`` and return ``transform(pairs)``."""
        self.fit(pairs)
        return self.transform(pairs)

## Full dataset object

In [9]:
class TranslationDataset(Dataset):
    """Map-style dataset pairing encoded source rows with encoded target rows."""

    def __init__(self, src_data: List[List[int]], tgt_data: List[List[int]]):
        # Only references to the two parallel lists are kept; nothing is
        # copied or converted until an item is requested.
        self.src_data, self.tgt_data = src_data, tgt_data

    def __len__(self):
        """Number of examples, taken from the source side."""
        return len(self.src_data)

    def __getitem__(self, idx):
        """Return ``(source, target)`` for example ``idx`` as newly built tensors."""
        src_row = self.src_data[idx]
        tgt_row = self.tgt_data[idx]
        return torch.tensor(src_row), torch.tensor(tgt_row)


## Standard Model

In [10]:
class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position encodings to embeddings, then applies dropout.

    Args:
        d_model: Embedding width.  Odd widths are supported.
        dropout: Dropout probability applied to the sum.
        max_len: Number of positions for which encodings are pre-computed.
        batch_first: Layout of the tensors given to ``forward``.  ``False``
            (default) means ``(seq_len, batch, d_model)``, which is the layout
            every caller in this notebook uses (the model transposes batches to
            sequence-first before embedding).  ``True`` means
            ``(batch, seq_len, d_model)``.
    """

    def __init__(self, d_model, dropout, max_len=5000, batch_first=False):
        super().__init__()
        self.batch_first = batch_first
        self.dropout = nn.Dropout(dropout)
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * -math.log(10000.0) / d_model)
        pe[:, 0::2] = torch.sin(pos * div_term)
        # With an odd d_model there is one cosine column fewer than sine columns.
        pe[:, 1::2] = torch.cos(pos * div_term[: d_model // 2])
        # The buffer keeps its (1, max_len, d_model) shape so state dicts saved
        # by earlier versions of this class still load.
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``dropout(x + encoding)`` with the encoding indexed by token position."""
        if self.batch_first:
            pos_enc = self.pe[:, :x.size(1)]                  # (1, seq_len, d_model)
        else:
            # Sequence-first input: positions run along dim 0, and the single
            # broadcast dimension must line up with the batch dimension.
            pos_enc = self.pe[:, :x.size(0)].transpose(0, 1)  # (seq_len, 1, d_model)
        return self.dropout(x + pos_enc)


In [11]:
class Seq2SeqTransformer(nn.Module):
    """Encoder-decoder translation model: embeddings + ``nn.Transformer`` + output projection.

    ``nn.Transformer`` is built without ``batch_first``, and the mask sizes in
    ``encode`` / ``decode_step`` are taken from dim 0, so token-id tensors are
    treated as sequence-first.
    NOTE(review): for that layout the positional-encoding module has to index
    positions along dim 0 -- verify that it does.
    """

    def __init__(self, num_encoder_layers, num_decoder_layers, emb_size, nhead,
                 src_vocab_size, tgt_vocab_size, dim_feedforward, dropout):
        super().__init__()
        # Sub-modules are created in this fixed order (embeddings, positional
        # encoding, transformer, projection).
        self.src_tok_emb = nn.Embedding(src_vocab_size, emb_size)
        self.tgt_tok_emb = nn.Embedding(tgt_vocab_size, emb_size)
        self.positional_encoding = PositionalEncoding(emb_size, dropout)

        self.transformer = nn.Transformer(
            d_model=emb_size,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
        )
        self.generator = nn.Linear(emb_size, tgt_vocab_size)

    def forward(self, src, tgt, src_mask, tgt_mask, src_pad_mask, tgt_pad_mask, mem_pad_mask):
        """Run the full model and return vocabulary logits for every target position."""
        src_emb = self.positional_encoding(self.src_tok_emb(src))
        tgt_emb = self.positional_encoding(self.tgt_tok_emb(tgt))
        hidden = self.transformer(
            src_emb,
            tgt_emb,
            src_mask=src_mask,
            tgt_mask=tgt_mask,
            memory_mask=None,
            src_key_padding_mask=src_pad_mask,
            tgt_key_padding_mask=tgt_pad_mask,
            memory_key_padding_mask=mem_pad_mask,
        )
        return self.generator(hidden)

    def encode(self, src, src_pad_mask=None):
        """Encode ``src`` with an all-False (nothing blocked) attention mask."""
        src_len = src.size(0)
        no_block = torch.zeros((src_len, src_len), dtype=torch.bool, device=src.device)
        src_emb = self.positional_encoding(self.src_tok_emb(src))
        return self.transformer.encoder(
            src_emb,
            mask=no_block,
            src_key_padding_mask=src_pad_mask,
        )

    def decode_step(self, tgt, memory, mem_pad_mask=None):
        """Decode ``tgt`` against ``memory`` under a causal mask and return logits."""
        tgt_len = tgt.size(0)
        # True strictly above the diagonal: a position may not look ahead.
        causal = torch.ones(tgt_len, tgt_len, device=tgt.device, dtype=torch.bool).triu(diagonal=1)
        tgt_emb = self.positional_encoding(self.tgt_tok_emb(tgt))
        hidden = self.transformer.decoder(
            tgt_emb,
            memory,
            tgt_mask=causal,
            memory_key_padding_mask=mem_pad_mask,
        )
        return self.generator(hidden)


## Performing + Testing translation

In [12]:
@torch.no_grad()
def greedy_translate(model, sentence: str, pp, device, max_len: int = 50) -> str:
    """Translate one sentence by repeatedly picking the highest-scoring next token.

    The model is switched to eval mode.  Decoding stops when the end token is
    predicted or after ``max_len`` steps; the returned text contains neither
    the start nor the end token.
    """
    model.eval()
    device = torch.device(device)

    tgt_vocab = pp.tgt2idx
    sos = tgt_vocab[pp.tgt_tok.sos_token]
    eos = tgt_vocab[pp.tgt_tok.eos_token]

    # Source sentence as a (src_len, 1) column plus its (1, src_len) padding mask.
    src_column = pp.src_ids(sentence).to(device).unsqueeze(1)
    src_pad_id = pp.src2idx[pp.src_tok.pad_token]
    src_pad_mask = (src_column.squeeze(1) == src_pad_id).unsqueeze(0)
    memory = model.encode(src_column, src_pad_mask=src_pad_mask)

    # The running output starts as a single start token and grows by one row per step.
    generated = torch.tensor([[sos]], device=device)
    for _ in range(max_len):
        logits = model.decode_step(generated, memory, mem_pad_mask=src_pad_mask)
        best_id = logits[-1].argmax(-1).item()  # greedy choice for the last position
        if best_id == eos:
            break
        new_row = torch.tensor([[best_id]], device=device)
        generated = torch.cat([generated, new_row])

    without_sos = generated.squeeze(1)[1:]
    return pp.text_from_tgt(without_sos)

In [13]:
def create_sample_translations(model, epoch: int, pp, device, max_len: int = 50,
                               sample_sentences: Optional[List[str]] = None):
    """Translate a few sentences and format the results as Markdown text.

    Args:
        model: Model handed to ``greedy_translate``; it is put into eval mode.
        epoch: Epoch number shown in the heading.
        pp: Fitted preprocessor handed to ``greedy_translate``.
        device: Device handed to ``greedy_translate``.
        max_len: Maximum number of decoding steps per sentence.
        sample_sentences: Sentences to translate.  When omitted, a built-in
            list of six English sentences is used.

    Returns:
        A string with a ``## Epoch N Sample Translations`` heading followed by
        one ``**source** → **translation**`` paragraph per sentence.  A
        sentence whose translation raises is reported as
        ``**source** → **Error: ...**`` instead of aborting the whole report.
    """
    if sample_sentences is None:
        sample_sentences = [
            "Hello, how are you?",
            "I love this place.",
            "What time is it?",
            "Thank you very much.",
            "Good morning everyone.",
            "I'm excited about dinner tonight."
        ]

    model.eval()
    sections = [f"## Epoch {epoch} Sample Translations\n\n"]

    with torch.no_grad():
        for src_sentence in sample_sentences:
            try:
                translation = greedy_translate(model, src_sentence, pp, device, max_len)
                sections.append(f"**{src_sentence}** → **{translation}**\n\n")
            except Exception as e:
                # Deliberate best-effort: this text is only a training-time
                # diagnostic, so a failing sample must not stop the run.
                sections.append(f"**{src_sentence}** → **Error: {str(e)}**\n\n")
    return "".join(sections)

## Logging Wrapper

In [50]:
class MyLogger:
    """Small convenience wrapper around a TensorBoard ``SummaryWriter``.

    Each method writes under a fixed tag; the writer logs to ``runs/<model_name>``.
    """

    def __init__(self, model_name):
        run_dir = f'runs/{model_name}'
        self.writer = SummaryWriter(run_dir)

    def log_avg_loss(self, train_loss, val_loss, epoch):
        """Record the epoch-average training and validation losses."""
        for tag, value in (('Loss/epoch_avg_train', train_loss),
                           ('Loss/epoch_avg_val', val_loss)):
            self.writer.add_scalar(tag, value, epoch)

    def log_step_level_loss(self, loss, step):
        """Record the loss of a single optimisation step."""
        self.writer.add_scalar('Loss/train_step', loss, step)

    def log_sample_translations(self, translations, epoch):
        """Record the sample-translation text and flush it to disk."""
        self.writer.add_text('Translations/Samples', translations, epoch)
        self.writer.flush()

    def log_lr(self, lr, epoch):
        """Record the learning rate against the given index."""
        self.writer.add_scalar('Learning Rate/epoch_avg', lr, epoch)

    def log_hyperparameters(self, hp, perplexity, bleu):
        """Record the hyper-parameter dict together with the two final metrics."""
        print(f"Recieved hyperparameters : perplexity: {perplexity}, bleu: {bleu}")
        metrics = {
            'perplexity': float(perplexity),
            'bleu': float(bleu),
        }
        self.writer.add_hparams(hp, metrics)
        self.writer.flush()

    def close(self):
        """Close the underlying writer."""
        self.writer.close()

## Passing data into the model

In [51]:
def generate_square_subsequent_mask(sz, device):
    """Build the additive causal mask: 0 on and below the diagonal, -inf strictly above it."""
    blocked = torch.full((sz, sz), float('-inf'), device=device)
    return blocked.triu(diagonal=1)

def create_mask(src, tgt, pad_idx=0):
    """Build the four masks for one batch of sequence-first token-id tensors.

    Returns ``(src_mask, tgt_mask, src_padding_mask, tgt_padding_mask)``:
    an all-False boolean source mask, the causal target mask, and the two
    padding masks (True where the token equals ``pad_idx``) with dims 0 and 1
    swapped.
    """
    device = src.device
    src_len, tgt_len = src.size(0), tgt.size(0)

    # The source may attend everywhere; the target only to earlier positions.
    src_mask = torch.zeros((src_len, src_len), device=device).type(torch.bool)
    tgt_mask = generate_square_subsequent_mask(tgt_len, device)

    src_padding_mask = src.eq(pad_idx).transpose(0, 1)
    tgt_padding_mask = tgt.eq(pad_idx).transpose(0, 1)

    return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask

In [52]:
def create_batch_forward_pass(model, src_batch, tgt_batch, pad_idx, device, loss_fn):
    """Run one teacher-forced forward pass and return ``(loss, flattened_logits)``.

    Both batches have dims 0 and 1 swapped and are moved to ``device``.  The
    decoder input is the target without its last row, and the expected output
    is the target without its first row.
    """
    src_seq_first = src_batch.transpose(0, 1).to(device)
    tgt_seq_first = tgt_batch.transpose(0, 1).to(device)

    decoder_in = tgt_seq_first[:-1, :]
    expected = tgt_seq_first[1:, :]

    src_mask, tgt_mask, src_pad_mask, tgt_pad_mask = create_mask(src_seq_first, decoder_in, pad_idx)

    # The source padding mask is passed twice: once for the encoder and once
    # as the memory padding mask.
    logits = model(src_seq_first, decoder_in, src_mask, tgt_mask,
                   src_pad_mask, tgt_pad_mask, src_pad_mask)

    vocab_dim = logits.shape[-1]
    logits_flat = logits.reshape(-1, vocab_dim)
    loss = loss_fn(logits_flat, expected.reshape(-1))

    return loss, logits_flat

## Github interact

In [53]:
def upload_to_github(local_path, repo_path, token, username="WillCable97", repo="HomemadeTransformer", branch="main"):
    """Upload one local file to a GitHub repository through the contents API.

    The file is read in binary mode, base64-encoded and sent with a PUT
    request.  The outcome is printed; nothing is returned.

    Args:
        local_path: Path of the file to read.
        repo_path: Destination path inside the repository.
        token: GitHub token placed in the ``Authorization`` header.
        username: Repository owner.
        repo: Repository name.
        branch: Branch to commit to.

    Note:
        No ``sha`` is sent, so this is meant for creating new files;
        presumably GitHub rejects the request when ``repo_path`` already
        exists -- verify against the API documentation.
    """
    with open(local_path, "rb") as f:
        content = base64.b64encode(f.read()).decode("utf-8")

    url = f"https://api.github.com/repos/{username}/{repo}/contents/{repo_path}"
    data = {
        "message": f"Upload {repo_path} from Colab",
        "content": content,
        "branch": branch
    }
    headers = {"Authorization": f"token {token}"}
    # A timeout keeps a stalled connection from hanging the notebook forever.
    r = requests.put(url, headers=headers, json=data, timeout=60)
    if r.status_code in (200, 201):
        print(f"✅ Uploaded {repo_path} to GitHub")
    else:
        # Error bodies are normally JSON, but fall back to the raw text so a
        # non-JSON response cannot raise while reporting the failure.
        try:
            detail = r.json()
        except ValueError:
            detail = r.text
        print(f"❌ Failed to upload {repo_path} — {r.status_code}: {detail}")

## Learning rate schedule (from paper)

In [None]:
def transformer_lr_scheduler(optimizer, d_model=512, warmup_steps=4000):
    """Create a ``LambdaLR`` with a linear warm-up followed by inverse-square-root decay.

    The multiplier applied to the optimizer's base learning rate at step ``s``
    (clamped to at least 1) is
    ``d_model ** -0.5 * min(s ** -0.5, s * warmup_steps ** -1.5)``.

    Args:
        optimizer: Optimizer whose learning rate is scheduled.
        d_model: Model width used for the ``d_model ** -0.5`` scale.
        warmup_steps: Number of warm-up steps.  The default is a literal so
            that defining this function does not depend on a module-level
            ``WARMUP_STEPS`` that is only computed in a later cell.

    Returns:
        The ``torch.optim.lr_scheduler.LambdaLR`` instance.
    """
    # Guard against warmup_steps == 0, which would raise in `0 ** -1.5`.
    warmup = max(warmup_steps, 1)

    def lr_lambda(step: int):
        step = max(step, 1)
        scale = (d_model ** -0.5)           # keep the usual d_model^-0.5 term
        return scale * min(step ** -0.5,
                           step * (warmup ** -1.5))
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

## Evaluation Functions

#### 1- Perplexity

In [55]:

#Helper func
def ids_to_tokens(ids: List[int], idx2tok: dict[int, str]
                  , eos_token: str, pad_token: str,) -> List[str]:
    """Map ids to tokens, stopping before the first end or pad token."""
    stop_tokens = (eos_token, pad_token)
    tokens: List[str] = []
    for idx in ids:
        word = idx2tok[idx]
        if word in stop_tokens:
            return tokens
        tokens.append(word)
    return tokens


@torch.no_grad()
def compute_perplexity(model: torch.nn.Module, dataloader
                       , pp, loss_fn, device: str = "cpu",) -> float:
    """Return ``exp`` of the token-weighted mean loss over ``dataloader``.

    Each batch loss is weighted by the number of non-pad target tokens after
    the first position, so batches with more real tokens count for more.

    Args:
        model: Model to evaluate; it is put into eval mode.
        dataloader: Yields ``(src, tgt)`` batches laid out as (batch, seq_len).
        pp: Fitted preprocessor; only the target pad id is read from it.
        loss_fn: Loss passed to ``create_batch_forward_pass``.  The weighting
            assumes it returns the mean over non-pad tokens -- TODO confirm.
            If it applies label smoothing, the result is the exponential of
            the smoothed loss rather than a pure perplexity.
        device: Device for the forward passes.

    Returns:
        The perplexity, or ``nan`` when there is no non-pad target token.
    """
    device = torch.device(device)
    model.eval()

    pad_idx = pp.tgt2idx[pp.tgt_tok.pad_token]
    total_log_loss, total_tokens = 0.0, 0

    for src, tgt in dataloader:
        # Batches are (batch, seq_len) here -- the transpose happens inside
        # create_batch_forward_pass -- so the first *token* is dropped with
        # [:, 1:].  Slicing with [1:] would drop the first sentence instead.
        n_tokens = (tgt[:, 1:] != pad_idx).sum().item()
        if n_tokens == 0:
            # Nothing to score; skipping also keeps an undefined batch loss
            # from contaminating the running total.
            continue

        loss, _ = create_batch_forward_pass(
            model, src, tgt,
            pad_idx=pad_idx,
            loss_fn=loss_fn,
            device=device,
        )

        total_log_loss += loss.item() * n_tokens
        total_tokens   += n_tokens

    if total_tokens == 0:
        return float("nan")

    mean_nll = total_log_loss / total_tokens
    return math.exp(mean_nll)                  # perplexity


#### 2- BLEU

In [56]:
@torch.no_grad()
def compute_bleu(
    model: torch.nn.Module,
    dataloader,
    pp,
    max_len: int = 50,
    device: str = "cpu",
) -> float:
    """Return corpus-level BLEU of greedy translations against the references.

    Every source row of every batch is decoded back to text, translated with
    ``greedy_translate`` and compared with its target row using
    ``corpus_bleu`` (default weights, smoothing ``method4``).

    Args:
        model: Model to evaluate; it is put into eval mode.
        dataloader: Yields ``(src_batch, tgt_batch)`` laid out as (batch, seq_len).
        pp: Fitted preprocessor providing vocabularies and tokenizers.
        max_len: Maximum number of decoding steps per sentence.
        device: Device used for translation.
    """
    device = torch.device(device)
    model.eval()
    smoothie = SmoothingFunction().method4

    pad_tok = pp.tgt_tok.pad_token
    eos_tok = pp.tgt_tok.eos_token

    refs: List[List[List[str]]] = []
    hyps: List[List[str]]      = []

    for src_batch, tgt_batch in dataloader:
        for b in range(src_batch.size(0)):
            src_ids = src_batch[b].tolist()
            tgt_ids = tgt_batch[b].tolist()

            src_tokens = ids_to_tokens(
                src_ids,
                idx2tok=pp.idx2src,
                eos_token=pp.src_tok.eos_token,
                pad_token=pp.src_tok.pad_token,
            )
            # The encoded row starts with the start marker.  Left in place, it
            # would be re-tokenised as an ordinary (unknown) word and a second
            # start marker would be added in front of it.
            if src_tokens and src_tokens[0] == pp.src_tok.sos_token:
                src_tokens = src_tokens[1:]
            src_sentence = " ".join(src_tokens)

            hyp_sentence = greedy_translate(
                model, src_sentence, pp,
                max_len=max_len, device=device,
            )
            hyps.append(hyp_sentence.split())

            ref_tokens = ids_to_tokens(
                tgt_ids,
                idx2tok=pp.idx2tgt,
                eos_token=eos_tok,
                pad_token=pad_tok,
            )
            # Hypotheses come back without the start marker, so the reference
            # must not carry one either.
            if ref_tokens and ref_tokens[0] == pp.tgt_tok.sos_token:
                ref_tokens = ref_tokens[1:]
            refs.append([ref_tokens])

    return corpus_bleu(refs, hyps, smoothing_function=smoothie)

## Create All Instances

#### 1- Load data

In [57]:
# Data: read the corpus and keep only its first 1000 lines.
text_path = get_text_path(hp.file_name)

with open(text_path, encoding="utf-8") as f:
    lines = f.read().strip().split("\n")[:1000]

# Each usable line holds tab-separated columns; the first two are the
# (source, target) sentences.  Lines without a tab are skipped.
pairs = [
    tuple(line.split("\t")[:2])
    for line in lines
    if "\t" in line
]

# Pre-process + prepare: fit the vocabularies on all pairs and encode them.
preprocessor = TranslationPreprocessor(
    src_tok=SimpleWordTokenizer(),
    tgt_tok=SimpleWordTokenizer(),
)
src_data, tgt_data = preprocessor.fit_transform(pairs)

#### 2- Create data Objects

In [58]:
# Split into train and validation sets with a fixed random seed.
SEED = 42
np.random.seed(SEED)
# Seed torch as well: the shuffling DataLoader (and the weight initialisation
# of a model built after this cell) draw from torch's RNG, not NumPy's, so
# seeding NumPy alone does not make a run repeatable.
torch.manual_seed(SEED)

indices = np.random.permutation(len(src_data))
train_size = int(0.9 * len(src_data))   # 90 % train, 10 % validation

train_indices = indices[:train_size]
val_indices = indices[train_size:]

train_src_data = [src_data[i] for i in train_indices]
train_tgt_data = [tgt_data[i] for i in train_indices]
val_src_data = [src_data[i] for i in val_indices]
val_tgt_data = [tgt_data[i] for i in val_indices]

# Create datasets and dataloaders (only the training loader is shuffled).
train_dataset = TranslationDataset(train_src_data, train_tgt_data)
val_dataset = TranslationDataset(val_src_data, val_tgt_data)
train_dl = DataLoader(train_dataset, batch_size=hp.batch_size, shuffle=True)
val_dl = DataLoader(val_dataset, batch_size=hp.batch_size, shuffle=False)

#### 3- Training and model Objects

In [None]:
# Model: every size comes from `hp`; the vocabulary sizes come from the
# vocabularies built when the preprocessor was fitted.
model = Seq2SeqTransformer(
    num_encoder_layers=hp.layers,   # encoder and decoder use the same depth
    num_decoder_layers=hp.layers,
    emb_size=hp.emb_size,
    nhead=hp.heads,
    src_vocab_size=len(preprocessor.src2idx),
    tgt_vocab_size=len(preprocessor.tgt2idx),
    dim_feedforward=hp.ff_dim,
    dropout=hp.dropout 
)

# Timestamp that identifies this run; it names the TensorBoard log directory
# here and is reused later for the checkpoint file name.
MODEL_NAME_TIME = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")

# Logs are written under runs/<MODEL_NAME_TIME>.
my_logger = MyLogger(MODEL_NAME_TIME)

#### 4- Learning and optimizers

In [None]:
#optimizer = torch.optim.Adam(model.parameters(), lr=1, betas=(0.9, 0.98),eps=1e-9)
#lr_scheduler = transformer_lr_scheduler(optimizer, d_model=hp.emb_size, warmup_steps=4000)

def avg_tokens_per_step(loader, pad_id, n_batches=100):
    """Estimate the average number of non-pad tokens (source + target) per batch.

    Only the first ``n_batches`` batches of ``loader`` are inspected.  The
    total is divided by the number of batches actually seen, so a loader with
    fewer than ``n_batches`` batches is not under-estimated.

    Returns:
        The integer (floor) average, or 0 when the loader yields no batch.
    """
    tot = 0
    seen = 0
    for src, tgt in loader:
        if seen == n_batches:
            break
        tot += (src != pad_id).sum().item() + (tgt != pad_id).sum().item()
        seen += 1
    return tot // seen if seen else 0        # integer average

# Reference tokens-per-step that the schedule is scaled against
# (presumably the batch size, in tokens, of the original paper -- TODO confirm).
T_PAPER = 50_000
T_NEW   = avg_tokens_per_step(train_dl, pad_id=0)   # measured on this DataLoader; iterates it, so run once

# Scale the base learning rate by the ratio of our tokens-per-step to the
# reference, and stretch the 4000-step warm-up by the inverse ratio.
# NOTE(review): T_NEW == 0 would raise ZeroDivisionError on the next lines.
LR_MULT      = T_NEW / T_PAPER
WARMUP_STEPS = int(4_000 * T_PAPER / T_NEW)

optimizer = torch.optim.Adam(
    model.parameters(),
    lr=LR_MULT,                     # base LR; the scheduler multiplies it (the paper setup used 1.0)
    betas=(0.9, 0.98),
    eps=1e-9
)

# Warm-up / inverse-square-root schedule built on top of the optimizer above.
lr_scheduler = transformer_lr_scheduler(
    optimizer,
    d_model=hp.emb_size,
    warmup_steps=WARMUP_STEPS
)

## Training

#### 1- Loop

In [48]:
def _mean_validation_loss(model, val_dataloader, pad_idx, device, loss_fn):
    """Average per-batch loss over ``val_dataloader`` in eval mode (``nan`` if it is empty)."""
    model.eval()
    total = 0.0
    n_batches = 0
    with torch.no_grad():
        for eng_batch, spa_batch in val_dataloader:
            loss, _ = create_batch_forward_pass(model, eng_batch, spa_batch, pad_idx, device, loss_fn)
            total += loss.item()
            n_batches += 1
    return total / n_batches if n_batches else float("nan")


def train_model(model, dataloader, val_dataloader, preprocessor, optimizer, lr_scheduler, loss_fn, num_epochs, my_logger, pad_idx=0, device='cpu', hparams=None):
    """Train ``model``, log progress, then log final metrics and close the logger.

    Per step: forward pass, backward pass, optimizer step, scheduler step, and
    logging of the step loss and learning rate.  Per epoch: average training
    loss, average validation loss and a set of sample translations.  After the
    last epoch, perplexity and BLEU on the validation data are computed and
    logged together with the hyper-parameters.

    Args:
        model: Model to train; moved to ``device``.
        dataloader: Training batches.
        val_dataloader: Validation batches.
        preprocessor: Fitted preprocessor used for translation and metrics.
        optimizer: Optimizer stepped once per batch.
        lr_scheduler: Scheduler stepped once per batch.
        loss_fn: Loss used for training, validation and perplexity.
        num_epochs: Number of passes over ``dataloader``.
        my_logger: Logger object; it is closed before returning, also on error.
        pad_idx: Padding id used when building masks.
        device: Device to train on.
        hparams: Dict of hyper-parameters to log.  When omitted, the
            module-level ``hp`` is used (``asdict(hp)``).

    Returns:
        ``(avg_loss, avg_val_loss)`` of the last epoch; both are ``nan`` when
        ``num_epochs`` is 0.
    """
    model.to(device)

    # Starts at 0 and is incremented before each use, so the first logged step
    # is 1, matching the number of scheduler steps taken so far.
    global_step = 0
    avg_loss = avg_val_loss = float("nan")

    try:
        for epoch in range(num_epochs):
            # For logging progression
            print(f"Starting epoch {epoch}")
            dl_len = len(dataloader)

            model.train()
            total_loss = 0.0
            n_batches = 0

            for i, (eng_batch, spa_batch) in enumerate(dataloader, start=1):
                print(f"\rStarting {i}/{dl_len}", end="")
                global_step += 1

                loss, _ = create_batch_forward_pass(model, eng_batch, spa_batch, pad_idx, device, loss_fn)

                # Perform the training step
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                lr_scheduler.step()

                step_loss = loss.item()
                total_loss += step_loss
                n_batches += 1

                my_logger.log_step_level_loss(step_loss, global_step)
                my_logger.log_lr(lr_scheduler.get_last_lr()[0], global_step)

            # Guarded so an empty training loader gives nan instead of raising.
            avg_loss = total_loss / n_batches if n_batches else float("nan")

            # Validation
            avg_val_loss = _mean_validation_loss(model, val_dataloader, pad_idx, device, loss_fn)
            my_logger.log_avg_loss(avg_loss, avg_val_loss, epoch)

            # Log sample translations
            sample_translations = create_sample_translations(model, epoch, preprocessor, device)
            my_logger.log_sample_translations(sample_translations, epoch)

            print(f"Epoch {epoch + 1}/{num_epochs} — Train Loss: {avg_loss:.4f} — Val Loss: {avg_val_loss:.4f}")

        # Log hyperparameters
        print("Logging hyperparameters and evaluation metrics")
        perplexity = compute_perplexity(model, val_dataloader, preprocessor, loss_fn, device)
        bleu = compute_bleu(model, val_dataloader, preprocessor, device=device)
        print(f"Calculated perplexity: {perplexity} and bleu: {bleu}")
        my_logger.log_hyperparameters(asdict(hp) if hparams is None else hparams, perplexity, bleu)
    finally:
        # Release the event-file writer even when training is interrupted.
        my_logger.close()

    return avg_loss, avg_val_loss


#### 2- Run it

In [None]:
# Train it
losses = train_model(model, train_dl, val_dl, preprocessor
                     , optimizer, lr_scheduler
                     , nn.CrossEntropyLoss(ignore_index=0, label_smoothing=0.1)
                     , num_epochs=hp.epochs, my_logger=my_logger, pad_idx=0
                     , device=DEVICE)


# Save the weights: to the Google Drive folder on Colab (same path as before),
# and to ./models everywhere else, where the Drive path does not exist and the
# save would otherwise fail after training has finished.
os.makedirs('./models', exist_ok=True)
ROOT_PATH = "/content/drive/MyDrive/Colab Projects/Translation transformer/"
SAVE_DIR = ROOT_PATH if IS_COLAB else './models/'
torch.save(model.state_dict(), f'{SAVE_DIR}{MODEL_NAME_TIME}.pth')



from pathlib import Path

def upload_run_dir(run_dir: Path, dest_root: str, gh_token: str):
    """Upload every TensorBoard event file found under ``run_dir`` (recursively).

    Files are processed in sorted order, and each keeps its path relative to
    ``run_dir`` below ``dest_root`` in the repository.
    """
    dest_base = Path(dest_root)
    event_files = sorted(run_dir.rglob("events.out.tfevents.*"))
    for event_file in event_files:
        target = dest_base / event_file.relative_to(run_dir)   # keep same tree
        upload_to_github(str(event_file), str(target), gh_token)

# --- Colab only: push this run's TensorBoard logs to the GitHub repository ---
if IS_COLAB:
    log_dir = Path("./runs") / MODEL_NAME_TIME
    upload_run_dir(log_dir, f"./notebooks/runs/{MODEL_NAME_TIME}", GH_TOKEN)