<a href="https://colab.research.google.com/github/Msingi-AI/msingi1/blob/main/Msingi1_Training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Msingi1: Swahili Language Model Training

This notebook trains a Mixture of Experts (MoE) language model for Swahili on Google Colab. Run the cells in order:

1. Mount Google Drive
2. Verify that a GPU is available
3. Clone the repository and install dependencies
4. Upload the dataset and tokenizer files
5. Train the model with checkpointing
6. Generate text samples

In [None]:
# Mount Google Drive for persistent storage
import os

from google.colab import drive

drive.mount('/content/drive')

# Create project directories.
# Done in Python rather than with `!mkdir -p .../{a,b,c}`: brace expansion depends on
# which shell executes the `!` command, and IPython additionally tries to evaluate
# `{...}` inside `!` commands as a Python expression.
PROJECT_ROOT = '/content/drive/MyDrive/msingi1'
for subdir in ('data', 'tokenizer', 'checkpoints', 'logs'):
    os.makedirs(os.path.join(PROJECT_ROOT, subdir), exist_ok=True)

In [None]:
# Verify GPU is available
!nvidia-smi

import torch
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"GPU device: {torch.cuda.get_device_name(0)}")

In [None]:
# Clone repository and install dependencies
!git clone https://github.com/Msingi-AI/msingi1.git
%cd msingi1

!pip install -q torch==2.0.1 fastmoe==0.3.2 wandb tokenizers datasets tqdm
!pip install -q -r requirements.txt

In [None]:
# Upload dataset and tokenizer files
from google.colab import files
import shutil
import os

def upload_to_drive(local_path, drive_path):
    """Copy a local file to ``drive_path``, creating the destination folder if needed.

    Args:
        local_path: Path of the existing file to copy.
        drive_path: Destination file path (in this notebook, a path inside the
            mounted Google Drive).

    Raises:
        OSError: If the source cannot be read or the destination cannot be written
            (propagated from ``os.makedirs`` / ``shutil.copy2``).
    """
    parent_dir = os.path.dirname(drive_path)
    # A bare filename has no directory part, and os.makedirs("") raises FileNotFoundError.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    shutil.copy2(local_path, drive_path)
    print(f"Uploaded {local_path} to {drive_path}")

def upload_group(prompt_message, drive_subdir):
    """Ask for files through the Colab upload widget and copy each one to Drive.

    Args:
        prompt_message: Text printed before the upload widget is shown.
        drive_subdir: Folder name under /content/drive/MyDrive/msingi1/ that
            receives the files.

    Returns:
        List of the uploaded file names (the keys returned by ``files.upload()``).
    """
    print(prompt_message)
    uploaded = files.upload()
    names = list(uploaded.keys())
    for filename in names:
        # Each key is used as the local path of the file that was just uploaded.
        drive_path = f"/content/drive/MyDrive/msingi1/{drive_subdir}/{filename}"
        upload_to_drive(filename, drive_path)
    return names


# Upload dataset
upload_group("Upload your Swahili dataset file...", "data")

# Upload tokenizer files
upload_group("\nUpload tokenizer files (vocab.json and merges.txt)...", "tokenizer")

# The following cells read these exact paths, so report a wrongly named upload now
# instead of failing later when the tokenizer or dataset is loaded.
for required_path in (
    "/content/drive/MyDrive/msingi1/data/train.txt",
    "/content/drive/MyDrive/msingi1/tokenizer/vocab.json",
    "/content/drive/MyDrive/msingi1/tokenizer/merges.txt",
):
    if not os.path.exists(required_path):
        print(f"WARNING: expected file not found: {required_path}")

In [None]:
import os
import torch
import wandb
from torch.optim import AdamW
from torch.nn import functional as F
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm
from tokenizers import ByteLevelBPETokenizer

from src.model import Msingi1, MsingiConfig
from src.data_processor import SwahiliDataset

# Load tokenizer
tokenizer = ByteLevelBPETokenizer(
    "/content/drive/MyDrive/msingi1/tokenizer/vocab.json",
    "/content/drive/MyDrive/msingi1/tokenizer/merges.txt"
)

# Model configuration
config = MsingiConfig(
    vocab_size=32000,
    max_position_embeddings=1024,
    hidden_size=768,
    num_hidden_layers=6,
    num_attention_heads=12,
    intermediate_size=3072,
    num_experts=8,
    expert_capacity=32,
    moe_layers=[2, 4]
)

# Sanity check: the training loss reshapes logits to (-1, config.vocab_size), so every
# token id produced by the tokenizer has to be smaller than config.vocab_size.
tokenizer_vocab_size = tokenizer.get_vocab_size()
if tokenizer_vocab_size > config.vocab_size:
    print(
        f"WARNING: tokenizer has {tokenizer_vocab_size} tokens but "
        f"config.vocab_size is {config.vocab_size}; token ids outside the model "
        "vocabulary will break training."
    )

# Training arguments
training_args = {
    'batch_size': 16,  # Reduced for GPU memory
    'accumulation_steps': 4,  # Effective batch size = 16 * 4 = 64
    'learning_rate': 3e-4,
    'warmup_steps': 1000,  # NOTE(review): not read by the training loop in this notebook (no LR scheduler is created)
    'max_steps': 20000,  # Counted in batches: the training loop increments `step` once per batch
    'save_steps': 1000,
    'max_grad_norm': 1.0,
    'weight_decay': 0.01
}

In [None]:
# Start the Weights & Biases run; it records the model config attributes together
# with the training arguments (training_args wins on duplicate keys).
run_config = {}
run_config.update(vars(config))
run_config.update(training_args)
wandb.init(project="msingi1", config=run_config)

# Dataset and batching
train_path = "/content/drive/MyDrive/msingi1/data/train.txt"
dataset = SwahiliDataset(data_path=train_path, tokenizer=tokenizer)
loader_options = dict(
    batch_size=training_args['batch_size'],
    shuffle=True,
    num_workers=2,
    pin_memory=True,
)
dataloader = DataLoader(dataset, **loader_options)

# Build the model, move it to the GPU and switch it to training mode
model = Msingi1(config)
model = model.cuda()
model.train()

# AdamW over all model parameters
optimizer = AdamW(
    model.parameters(),
    weight_decay=training_args['weight_decay'],
    lr=training_args['learning_rate'],
)

# Load checkpoint if exists
def find_latest_checkpoint(directory):
    """Return the path of the ``step_<N>.pt`` file with the largest N, or None.

    Steps are compared numerically: plain name sorting would rank ``step_9000.pt``
    above ``step_10000.pt``. Files that do not match ``step_<digits>.pt`` (for
    example ``final_model.pt``, which is saved without optimizer state) are ignored.

    Args:
        directory: Folder to scan; it may be missing.

    Returns:
        Full path of the newest step checkpoint, or None when the folder does not
        exist or holds no step checkpoint.
    """
    if not os.path.isdir(directory):
        return None

    best_step, best_path = -1, None
    for name in os.listdir(directory):
        stem, extension = os.path.splitext(name)
        if extension != ".pt" or not stem.startswith("step_"):
            continue
        number = stem[len("step_"):]
        if not (number.isascii() and number.isdigit()):
            continue
        if int(number) > best_step:
            best_step, best_path = int(number), os.path.join(directory, name)
    return best_path


checkpoint_dir = "/content/drive/MyDrive/msingi1/checkpoints"
latest_checkpoint = find_latest_checkpoint(checkpoint_dir)
if latest_checkpoint is not None:
    print(f"Loading checkpoint: {latest_checkpoint}")
    checkpoint = torch.load(latest_checkpoint)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    # The training loop stores `step` before incrementing it, i.e. the index of the
    # batch that was just processed, so training resumes at the following step.
    start_step = checkpoint['step'] + 1
else:
    start_step = 0

# Training loop
# `step` counts processed batches, not optimizer updates: weights are updated once
# every `accumulation_steps` batches, and `max_steps` / `save_steps` are measured in
# batches as well.
step = start_step
optimizer.zero_grad()

# Without this guard an empty dataloader would never advance `step`, and the outer
# `while` would spin forever.
if len(dataloader) == 0:
    raise ValueError("The training dataloader is empty; check the dataset file.")

accumulation_steps = training_args['accumulation_steps']

while step < training_args['max_steps']:
    progress_bar = tqdm(dataloader, desc=f"Step {step}")

    for batch in progress_bar:
        # Move batch to GPU
        input_ids = batch['input_ids'].cuda()
        attention_mask = batch['attention_mask'].cuda()
        labels = batch['labels'].cuda()

        # Forward pass
        logits = model(input_ids, attention_mask=attention_mask)

        # Calculate loss (label positions equal to -100 are ignored)
        loss = F.cross_entropy(
            logits.view(-1, config.vocab_size),
            labels.view(-1),
            ignore_index=-100
        )

        # Read the unscaled loss once: every .item() call forces a GPU->CPU sync.
        loss_value = loss.item()

        # Scale loss for gradient accumulation so the summed gradients form an average
        (loss / accumulation_steps).backward()

        # Update weights if we've accumulated enough gradients
        if (step + 1) % accumulation_steps == 0:
            torch.nn.utils.clip_grad_norm_(
                model.parameters(),
                training_args['max_grad_norm']
            )
            optimizer.step()
            optimizer.zero_grad()

        # Log metrics
        wandb.log({
            'loss': loss_value,
            'step': step
        })

        # Save checkpoint
        if step > 0 and step % training_args['save_steps'] == 0:
            checkpoint_path = f'/content/drive/MyDrive/msingi1/checkpoints/step_{step}.pt'
            # Make sure the target folder exists so a long run is not lost at save time
            os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
            torch.save({
                'step': step,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                # Unscaled loss, i.e. the same value that is logged and displayed
                'loss': loss_value,
                'config': config
            }, checkpoint_path)
            print(f'\nSaved checkpoint to {checkpoint_path}')

        # Update progress
        progress_bar.set_postfix({'loss': loss_value})
        step += 1

        if step >= training_args['max_steps']:
            break

# Write the final weights (without optimizer state) to Drive, then close the wandb run
final_state = {
    'step': step,
    'model_state_dict': model.state_dict(),
    'config': config,
}
final_path = '/content/drive/MyDrive/msingi1/checkpoints/final_model.pt'
torch.save(final_state, final_path)

wandb.finish()

In [None]:
# Generate text samples
# Switch the model to evaluation mode before sampling from it.
model.eval()

def generate_text(prompt, max_length=100, temperature=0.7, top_k=50):
    """Generate a continuation of ``prompt`` with temperature / top-k sampling.

    Uses the notebook-level ``model``, ``tokenizer`` and ``config``.

    Args:
        prompt: Text to continue; it must encode to at least one token.
        max_length: Maximum number of new tokens to generate.
        temperature: Divisor applied to the logits before sampling. A value
            <= 0 selects the most likely token at every step (greedy decoding).
        top_k: Number of highest-scoring tokens to sample from; clamped to the
            range [1, vocabulary size].

    Returns:
        The decoded text of the prompt followed by the generated tokens
        (including the "</s>" token when it was produced).

    Raises:
        ValueError: If the prompt encodes to no tokens.
    """
    prompt_ids = tokenizer.encode(prompt).ids
    if not prompt_ids:
        raise ValueError("prompt must encode to at least one token")

    # Run on whatever device the model lives on instead of hard-coding CUDA.
    device = next(model.parameters()).device
    # Looked up once; None (token not in the vocabulary) simply never matches.
    end_token_id = tokenizer.token_to_id("</s>")
    # Presumably the longest sequence the model accepts; the context is cropped to
    # it so that long generations do not exceed the position embeddings.
    max_context = getattr(config, "max_position_embeddings", None)

    with torch.no_grad():
        input_ids = torch.tensor(prompt_ids, dtype=torch.long, device=device).unsqueeze(0)

        for _ in range(max_length):
            context = input_ids if max_context is None else input_ids[:, -max_context:]
            next_token_logits = model(context)[:, -1, :]

            if temperature <= 0:
                # Greedy decoding; also avoids a division by zero.
                next_token = next_token_logits.argmax(dim=-1, keepdim=True)
            else:
                # Top-k sampling
                k = max(1, min(top_k, next_token_logits.size(-1)))
                top_k_logits, top_k_indices = torch.topk(next_token_logits / temperature, k)
                probs = F.softmax(top_k_logits, dim=-1)
                choice = torch.multinomial(probs, 1)
                next_token = top_k_indices.gather(-1, choice)

            # next_token has shape (1, 1), matching the 2-D input_ids, so it can be
            # concatenated directly along the sequence dimension.
            input_ids = torch.cat([input_ids, next_token], dim=1)

            # Stop if we predict the end token
            if next_token.item() == end_token_id:
                break

        return tokenizer.decode(input_ids[0].tolist())

# Try the sampler on a few short Swahili prompts and print each continuation
prompts = ["Habari ya leo", "Naomba", "Karibu"]

for prompt in prompts:
    generated = generate_text(prompt)
    print(f"\nPrompt: {prompt}", f"Generated: {generated}", sep="\n")