# Msingi1: Swahili Language Model with Mixture of Experts

This notebook trains the Msingi1 Swahili language model, which uses Mixture-of-Experts (MoE) layers, on a Google Colab GPU.

Before running any cells, enable a GPU runtime:

1. Click Runtime -> Change runtime type
2. Select GPU as Hardware accelerator
3. Click Save

In [None]:
# Verify GPU is available.
# nvidia-smi lists the GPUs visible to this runtime; if the command errors out,
# no GPU runtime is attached and the later `.cuda()` calls will fail.
!nvidia-smi

In [None]:
# Mount Google Drive for checkpoints
from google.colab import drive
drive.mount('/content/drive')

# Create directories
!mkdir -p /content/drive/MyDrive/msingi1/{checkpoints,logs}

In [None]:
# Clone the repository
!git clone https://github.com/Msingi-AI/msingi1.git
%cd msingi1

In [None]:
# Install dependencies
!pip install -q torch==2.0.1 fastmoe==0.3.2 wandb tokenizers datasets tqdm
!pip install -q -r requirements.txt

In [None]:
import os
import torch
import wandb
from torch.optim import AdamW
from torch.nn import functional as F
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm

from src.model import Msingi1, MsingiConfig
from src.data_processor import SwahiliDataset

# Allow TF32 math for CUDA matmuls and cuDNN kernels to speed up training
# (both flags are switched on together; matmul first, then cuDNN).
torch.backends.cuda.matmul.allow_tf32 = torch.backends.cudnn.allow_tf32 = True

In [None]:
# Model and training configuration

# Architecture hyperparameters handed to MsingiConfig as keyword arguments.
model_hparams = dict(
    vocab_size=32000,
    max_position_embeddings=1024,
    hidden_size=768,
    num_hidden_layers=6,
    num_attention_heads=12,
    intermediate_size=3072,
    num_experts=8,
    expert_capacity=32,
    moe_layers=[2, 4],
)
config = MsingiConfig(**model_hparams)

# Settings read by the data loader, optimizer, scheduler and training loop.
training_args = dict(
    batch_size=32,
    learning_rate=3e-4,
    warmup_steps=1000,
    max_steps=20000,
    save_steps=1000,
    gradient_accumulation_steps=4,
    max_grad_norm=1.0,
    weight_decay=0.01,
)

In [None]:
# Start a wandb run for experiment tracking. The model configuration's attributes
# and the training settings are merged into one config dict; on a key clash the
# training settings win because they are applied last.
run_config = dict(vars(config))
run_config.update(training_args)
wandb.init(project="msingi1", config=run_config)

# Load the training text (path is relative to the repository root)
dataset = SwahiliDataset('data/Swahili data/train.txt')

# Batching options: shuffled batches, two worker processes, pinned host memory
loader_options = dict(
    batch_size=training_args['batch_size'],
    shuffle=True,
    num_workers=2,
    pin_memory=True,
)
dataloader = DataLoader(dataset, **loader_options)

# Build the model, place it on the GPU and switch it to training mode
model = Msingi1(config)
model = model.cuda()
model.train()

# AdamW over every model parameter, with the configured learning rate and weight decay
base_lr = training_args['learning_rate']
decay = training_args['weight_decay']
optimizer = AdamW(model.parameters(), lr=base_lr, weight_decay=decay)

# Learning rate scheduler
def get_lr(step):
    """Return the multiplier applied to the base learning rate at scheduler step `step`.

    The multiplier rises linearly from 0 to 1 over the first
    `training_args['warmup_steps']` scheduler steps and is 1.0 from then on.
    """
    warmup = training_args['warmup_steps']
    return step / warmup if step < warmup else 1.0

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=get_lr)

In [None]:
# Training loop
#
# `step` counts micro-batches (one forward/backward pass each). The optimizer and
# the LR scheduler advance once every `gradient_accumulation_steps` micro-batches,
# so `warmup_steps` is consumed in optimizer updates, while `max_steps` and
# `save_steps` are compared against the micro-batch counter.
step = 0
accumulation_steps = training_args['gradient_accumulation_steps']
optimizer.zero_grad()

while step < training_args['max_steps']:
    epoch_start_step = step
    progress_bar = tqdm(dataloader, desc=f"Step {step}")

    for batch in progress_bar:
        # Move batch to GPU
        input_ids = batch['input_ids'].cuda()
        attention_mask = batch['attention_mask'].cuda()
        labels = batch['labels'].cuda()

        # Forward pass
        logits = model(input_ids, attention_mask=attention_mask)

        # Token-level cross-entropy; positions labelled -100 are ignored
        loss = F.cross_entropy(
            logits.view(-1, config.vocab_size),
            labels.view(-1),
            ignore_index=-100
        )

        # Read the unscaled loss once. The same value is used for logging, the
        # progress bar and the checkpoint, so all three report the same quantity
        # and the GPU is synchronised only once per micro-batch.
        loss_value = loss.item()

        # Scale only the value that is back-propagated, so the accumulated
        # gradient is the mean over the accumulation window.
        (loss / accumulation_steps).backward()

        # Update weights if we've accumulated enough gradients
        if (step + 1) % accumulation_steps == 0:
            torch.nn.utils.clip_grad_norm_(
                model.parameters(),
                training_args['max_grad_norm']
            )
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

        # Log metrics
        wandb.log({
            'loss': loss_value,
            'learning_rate': scheduler.get_last_lr()[0],
            'step': step
        })

        # Save checkpoint
        if step > 0 and step % training_args['save_steps'] == 0:
            checkpoint_path = f'/content/drive/MyDrive/msingi1/checkpoints/step_{step}.pt'
            torch.save({
                'step': step,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'loss': loss_value,
                'config': config
            }, checkpoint_path)
            print(f'\nSaved checkpoint to {checkpoint_path}')

        # Update progress
        progress_bar.set_postfix({'loss': loss_value})
        step += 1

        if step >= training_args['max_steps']:
            break

    # A pass over the dataloader that yields no batches would otherwise make the
    # outer `while` loop spin forever.
    if step == epoch_start_step:
        raise RuntimeError('The dataloader produced no batches; check the training data path.')

# Save the final weights together with the step count and model configuration
final_checkpoint = dict(
    step=step,
    model_state_dict=model.state_dict(),
    config=config,
)
torch.save(final_checkpoint, '/content/drive/MyDrive/msingi1/checkpoints/final_model.pt')

# Close the wandb run
wandb.finish()

In [None]:
# Generate a sample continuation for a fixed prompt
prompt = "Habari ya leo"

# Switch to evaluation mode and generate without tracking gradients
model.eval()
with torch.no_grad():
    generated = model.generate(prompt, max_length=100, temperature=0.7, top_k=50)

print(f"Prompt: {prompt}")
print(f"Generated: {generated}")