# Msingi1: Swahili Language Model Training

This notebook trains the Msingi1 Swahili language model, which uses a Mixture-of-Experts (MoE) architecture.

## 1. Setup Environment

First, we set up the environment: mount Google Drive (where training checkpoints are saved), clone the repository, and install its dependencies.

In [None]:
# Checkpoints are written under /content/drive later in the notebook,
# so Google Drive has to be mounted before training starts.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

In [None]:
# Clone repository and install dependencies
!git clone https://github.com/your-username/msingi1.git
%cd msingi1

# Install dependencies and setup environment
!python setup_colab.py

## 2. Load Tokenizer and Dataset

In [None]:
import os

import torch
import wandb
from tokenizers import ByteLevelBPETokenizer
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

from src.model import Msingi1, MsingiConfig
from src.data_processor import SwahiliDataset, extract_dataset

# Byte-level BPE tokenizer, rebuilt from its saved vocabulary and merge rules.
tokenizer = ByteLevelBPETokenizer('tokenizer/vocab.json', 'tokenizer/merges.txt')

# Unpack the raw corpus, then wrap it with the tokenizer. The training loop below
# reads `input_ids` and `labels` from each batch, so the dataset presumably yields
# dicts with those keys — TODO confirm against src/data_processor.py.
texts = extract_dataset('data/swahili_text.zip')
dataset = SwahiliDataset(texts, tokenizer)

# Shuffled mini-batches of 8, prepared by two background worker processes.
loader_options = dict(batch_size=8, shuffle=True, num_workers=2)
train_loader = DataLoader(dataset, **loader_options)

## 3. Initialize Model and Training

In [None]:
# Seed torch's global RNG (CPU and CUDA) before building the model. This makes the
# initial weights reproducible. The DataLoader draws its shuffle seed from the same
# global RNG when iteration starts, so batch order becomes reproducible too.
SEED = 42
torch.manual_seed(SEED)

# Training hyperparameters, named so they are tuned (and logged) in one place.
LEARNING_RATE = 5e-5
# The scheduler is stepped once per epoch in the training loop, so the cosine
# period (T_max) is measured in epochs and should equal the number of epochs.
NUM_EPOCHS = 50

# Initialize model
config = MsingiConfig()
model = Msingi1(config)

# Move model to GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

# Initialize optimizer and scheduler
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)

# Initialize wandb for tracking; record the hyperparameters alongside the run.
wandb.init(
    project='msingi1',
    name='moe_training',
    config={'learning_rate': LEARNING_RATE, 'num_epochs': NUM_EPOCHS, 'seed': SEED},
)

## 4. Training Loop

In [None]:
num_epochs = 50
save_path = '/content/drive/MyDrive/msingi1/checkpoints'
checkpoint_every = 5  # save a checkpoint every N epochs

# The cosine LR schedule (stepped once per epoch below) should span exactly this run;
# fail fast instead of silently training with a mismatched schedule.
assert scheduler.T_max == num_epochs, (
    f'scheduler.T_max ({scheduler.T_max}) must equal num_epochs ({num_epochs})'
)
# torch.save does not create missing parent directories, so ensure the Drive folder exists.
os.makedirs(save_path, exist_ok=True)

for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    progress_bar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}')

    for batch in progress_bar:
        # Move batch to device
        input_ids = batch['input_ids'].to(device)
        labels = batch['labels'].to(device)

        # Forward pass. Flatten outputs to (N, vocab_size) and labels to (N,) for
        # cross-entropy. reshape (unlike view) also works on non-contiguous outputs.
        outputs = model(input_ids)
        loss = torch.nn.functional.cross_entropy(
            outputs.reshape(-1, config.vocab_size), labels.reshape(-1)
        )

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # .item() forces a device->host sync; do it once per step and reuse the value.
        step_loss = loss.item()
        total_loss += step_loss
        progress_bar.set_postfix({'loss': step_loss})
        wandb.log({'loss': step_loss})

    # Update scheduler (once per epoch)
    scheduler.step()

    # Average loss for the epoch, computed once; max(..., 1) avoids dividing by zero on an empty loader.
    epoch_loss = total_loss / max(len(train_loader), 1)

    # Save checkpoint
    if (epoch + 1) % checkpoint_every == 0:
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'loss': epoch_loss,
        }
        torch.save(checkpoint, f'{save_path}/checkpoint_epoch_{epoch+1}.pt')

    # Log epoch metrics
    print(f'Epoch {epoch+1} - Average Loss: {epoch_loss:.4f}')
    wandb.log({'epoch': epoch, 'epoch_loss': epoch_loss})

## 5. Test Text Generation

In [None]:
def generate_text(prompt, max_length=100):
    """Generate a continuation of `prompt` with the trained model.

    Uses the notebook-level `model`, `tokenizer` and `device`.

    Args:
        prompt: Text to condition generation on; must encode to at least one token.
        max_length: Forwarded unchanged to `model.generate` as `max_length`.

    Returns:
        The decoded text of the first (and only) returned sequence.

    Raises:
        ValueError: If `prompt` encodes to zero tokens.
    """
    # Set eval mode inside the function rather than once at cell level. Otherwise
    # re-running the training cell (which calls model.train()) would silently leave
    # generation running in train mode.
    model.eval()

    # Encode prompt
    encoded = tokenizer.encode(prompt)
    if not encoded.ids:
        raise ValueError('prompt must encode to at least one token')
    input_ids = torch.tensor(encoded.ids, dtype=torch.long).unsqueeze(0).to(device)

    # NOTE(review): token_to_id returns None when '<pad>' is not in the vocabulary.
    # Confirm the tokenizer was trained with this special token.
    pad_token_id = tokenizer.token_to_id('<pad>')

    # Generate
    with torch.no_grad():
        output_ids = model.generate(
            input_ids=input_ids,
            max_length=max_length,
            num_return_sequences=1,
            pad_token_id=pad_token_id,
        )

    # Decode and return
    return tokenizer.decode(output_ids[0].tolist())

# Test generation
test_prompts = [
    "Habari ya leo?",
    "Tanzania ni nchi",
    "Kiswahili ni lugha"
]

for prompt in test_prompts:
    generated = generate_text(prompt)
    print(f'\nPrompt: {prompt}')
    print(f'Generated: {generated}')