# Msingi1 Training on Google Colab

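This notebook:
1. Trains a custom tokenizer on the Swahili training corpus and saves it to Google Drive
2. Trains the Msingi1 Swahili language model for 100 epochs
3. Saves model checkpoints to Google Drive every 5 epochs, so training resumes from the latest checkpoint after a Colab runtime reset
4. Generates a short text sample from the trained model

A GPU runtime is required (Runtime → Change runtime type → GPU).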

In [None]:
# Check GPU availability: the model-training cell places the model on 'cuda',
# so this should list at least one GPU before continuing.
!nvidia-smi

In [None]:
# Mount Google Drive: the tokenizer and checkpoint cells write under this mount point,
# so their artifacts outlive the Colab runtime.
from google.colab import drive

MOUNT_POINT = '/content/drive'
drive.mount(MOUNT_POINT)

In [None]:
# Clone the repository
!git clone https://github.com/Msingi-AI/msingi1.git
%cd msingi1

In [None]:
# Install requirements
!pip install -r requirements.txt
!pip install tokenizers

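## Train Tokenizer

Trains a tokenizer on the training corpus and saves it to Google Drive. The model-training section loads the tokenizer from that same Drive path. Re-running this cell after checkpoints exist therefore overwrites the tokenizer those checkpoints were trained with.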

In [None]:
# `os` is imported here because the only other `import os` is in the later
# "Train Model" cell, which has not run yet on a fresh kernel.
import os

from src.train_tokenizer import train_tokenizer

# Train tokenizer on the training split of the Swahili corpus
tokenizer = train_tokenizer('data/Swahili data/train.txt')

# Save tokenizer to Drive to prevent loss.
# NOTE: this location must match config['tokenizer_path'] in the "Train Model" cell.
tokenizer_path = '/content/drive/MyDrive/msingi1_tokenizer'
os.makedirs(tokenizer_path, exist_ok=True)
tokenizer.save(f'{tokenizer_path}/tokenizer.json')
print('Tokenizer trained and saved!')

## Train Model

In [None]:
import torch
import os
from src.model import Msingi1Model
from src.train import train_model
from src.data_processor import SwahiliDataset

# Training configuration
config = {
    'batch_size': 32,
    'num_epochs': 100,  # Extended to 100 epochs
    'learning_rate': 3e-4,
    'save_every': 5,    # Save checkpoint every 5 epochs
    'checkpoint_dir': '/content/drive/MyDrive/msingi1_checkpoints',
    # Must point at the file written by the "Train Tokenizer" cell.
    'tokenizer_path': '/content/drive/MyDrive/msingi1_tokenizer/tokenizer.json'
}

# Seed torch's RNGs (all devices) so weight initialization is reproducible across runs.
# NOTE(review): any shuffling inside SwahiliDataset/train_model that uses python
# `random` or numpy is not covered by this seed.
SEED = 42
torch.manual_seed(SEED)

# Create checkpoint directory on Drive so checkpoints survive a runtime reset
os.makedirs(config['checkpoint_dir'], exist_ok=True)

# Load tokenizer and initialize dataset
dataset = SwahiliDataset(
    'data/Swahili data/train.txt',
    tokenizer_path=config['tokenizer_path']
)

# Initialize model on the GPU (a GPU runtime is required)
model = Msingi1Model().to('cuda')

# Function to load latest checkpoint if exists
def load_latest_checkpoint():
    """Restore model weights from the newest checkpoint in ``config['checkpoint_dir']``.

    Checkpoints are ordered by the integer embedded before ``.pt`` in their
    filename (``msingi1_epoch_<N>.pt``), not lexicographically. For example,
    epoch 100 is chosen over epoch 15 and epoch 10 over epoch 5.

    Returns:
        The 0-based epoch index to resume from (the stored ``epoch`` + 1), or 0
        when the directory is missing or contains no ``.pt`` files.
    """
    import re

    checkpoint_dir = config['checkpoint_dir']
    if not os.path.isdir(checkpoint_dir):
        return 0

    def _epoch_number(filename):
        # Trailing integer before ".pt"; files without one sort before all others.
        match = re.search(r'(\d+)\.pt$', filename)
        return int(match.group(1)) if match else -1

    checkpoints = sorted(
        (f for f in os.listdir(checkpoint_dir) if f.endswith('.pt')),
        key=lambda f: (_epoch_number(f), f),
    )
    if checkpoints:
        latest = os.path.join(checkpoint_dir, checkpoints[-1])
        # Load onto CPU so the file opens on any runtime. load_state_dict then
        # copies the values into the model's existing parameters on their device.
        checkpoint = torch.load(latest, map_location='cpu')
        model.load_state_dict(checkpoint['model_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        # 1-based, matching the "Epoch N/total" lines printed by the training loop.
        print(f'Resuming from epoch {start_epoch + 1}')
        return start_epoch
    return 0

# Train with checkpointing
def train_with_checkpoints():
    """Run the epoch loop, resuming from the newest checkpoint if one exists.

    Each iteration calls ``train_model(model, dataset, config)`` once and treats
    its return value as that epoch's loss. A checkpoint holding the 0-based
    ``epoch`` index, the model ``state_dict`` and the loss is written every
    ``config['save_every']`` epochs and after the final epoch. A run whose
    ``num_epochs`` is not a multiple of ``save_every`` therefore keeps its last epochs.
    """
    # NOTE(review): confirm train_model trains exactly one epoch per call. If it
    # loops over config['num_epochs'] itself, this loop multiplies training time.
    # NOTE(review): only model weights are checkpointed. If train_model builds an
    # optimizer/scheduler, its state is neither saved nor restored on resume.
    start_epoch = load_latest_checkpoint()
    num_epochs = config['num_epochs']

    for epoch in range(start_epoch, num_epochs):
        loss = train_model(model, dataset, config)
        # Report against the configured total rather than a hard-coded 100.
        print(f'Epoch {epoch+1}/{num_epochs}, Loss: {loss}')

        is_last_epoch = epoch + 1 == num_epochs
        if (epoch + 1) % config['save_every'] == 0 or is_last_epoch:
            checkpoint_path = os.path.join(
                config['checkpoint_dir'],
                f'msingi1_epoch_{epoch+1}.pt'
            )
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'loss': loss,
            }, checkpoint_path)
            print(f'Saved checkpoint to {checkpoint_path}')

# Start training (resumes from the newest checkpoint in config['checkpoint_dir'], if any)
train_with_checkpoints()

In [None]:
# Test the model: generate a continuation of a short Swahili prompt.
# NOTE(review): assumes Msingi1Model.generate accepts a raw string prompt and
# returns printable text. Confirm against src/model.py.
model.eval()  # switch to eval mode before generating
prompt = "Habari ya leo"
with torch.no_grad():
    sample_text = model.generate(prompt, max_length=100)
print(sample_text)