<a href="https://colab.research.google.com/github/Msingi-AI/msingi1/blob/main/train_on_colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Train Msingi1: Swahili Language Model

First, let's verify we have GPU access:

In [None]:
!nvidia-smi
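
The same check can be made from Python, which is useful when a later cell should stop early rather than silently train on CPU. The following is a minimal sketch using only the standard library; the helper name `gpu_visible` is hypothetical and is not part of the Msingi1 repository.

```python
import shutil
import subprocess


def gpu_visible() -> bool:
    """Return True if nvidia-smi exists and lists at least one GPU."""
    exe = shutil.which("nvidia-smi")
    if exe is None:
        return False  # no NVIDIA driver tools on this runtime
    try:
        result = subprocess.run(
            [exe, "-L"],  # -L prints one line per detected GPU
            capture_output=True,
            text=True,
            timeout=30,
        )
    except (OSError, subprocess.TimeoutExpired):
        return False
    return result.returncode == 0 and "GPU" in result.stdout


if not gpu_visible():
    print("No GPU detected. In Colab, use Runtime > Change runtime type and select a GPU.")
```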

## 1. Setup & Dependencies

In [None]:
# Mount Google Drive to save our model
from google.colab import drive
drive.mount('/content/drive')

# Create project directory
!mkdir -p /content/msingi1
%cd /content/msingi1

# Clone our repository
!git clone https://github.com/Msingi-AI/msingi1.git .

# Install dependencies
%pip install torch transformers tokenizers datasets numpy tqdm wandb

## 2. Upload Dataset
Upload your `archive.zip` file:

In [None]:
from google.colab import files
print("Please upload your archive.zip file...")
uploaded = files.upload()  # Upload archive.zip here
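
Before training, it is worth confirming that the upload is a readable zip file. The repository's `extract_dataset` is imported later and presumably handles the extraction, but its signature is not shown in this notebook, so the sketch below uses the standard `zipfile` module as a stand-in. The function name `unpack_archive` and the destination directory are assumptions made for illustration.

```python
import zipfile
from pathlib import Path


def unpack_archive(zip_path: str, dest_dir: str) -> list[str]:
    """Validate a zip archive, extract it, and return the member names."""
    path = Path(zip_path)
    if not path.is_file():
        raise FileNotFoundError(f"{zip_path} not found; did the upload finish?")
    if not zipfile.is_zipfile(path):
        raise ValueError(f"{zip_path} is not a valid zip archive")
    with zipfile.ZipFile(path) as archive:
        bad_member = archive.testzip()  # first corrupt member, or None if all are fine
        if bad_member is not None:
            raise ValueError(f"Corrupt member in archive: {bad_member}")
        archive.extractall(dest_dir)
        return archive.namelist()
```

A call such as `unpack_archive("archive.zip", "data")` would extract into a hypothetical `data` directory; use whatever location the repository's data processor expects.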

## 3. Add Source to Python Path

In [None]:
import sys
import os

# Add the repository root to the Python path so the `src` package is importable
if '/content/msingi1' not in sys.path:
    sys.path.append('/content/msingi1')

# Verify imports work
from src.data_processor import extract_dataset, get_dataset_stats
from src.train_tokenizer import train_tokenizer
from src.train import main as train_model

print("✅ Imports successful!")

## 4. Train Tokenizer
With the imports working, train the tokenizer on the Swahili text:

In [None]:
# Train tokenizer
tokenizer = train_tokenizer()

# Save tokenizer to Drive
!mkdir -p "/content/drive/MyDrive/msingi1/tokenizer"
!cp -r tokenizer/* "/content/drive/MyDrive/msingi1/tokenizer/"
print("✅ Tokenizer saved to Google Drive!")

## 5. Train Model
Now we'll train our model using the GPU:

In [None]:
# Train model
train_model()

# Save model checkpoints to Drive
!mkdir -p "/content/drive/MyDrive/msingi1/checkpoints"
!cp -r checkpoints/* "/content/drive/MyDrive/msingi1/checkpoints/"
print("✅ Model checkpoints saved to Google Drive!")
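
The save steps for the tokenizer and the checkpoints rely on `!cp -r <dir>/* ...`. A failed shell command does not raise a Python exception in a notebook, so the success message prints even when the source directory is empty or missing. A Python copy that raises on an empty source is one way to catch this. The sketch below assumes Python 3.8 or later for `dirs_exist_ok`; the helper name `backup_dir` is hypothetical.

```python
import shutil
from pathlib import Path


def backup_dir(src: str, dst: str) -> int:
    """Copy src into dst, merging with existing files.

    Returns the number of files present under dst after the copy.
    """
    src_path = Path(src)
    if not src_path.is_dir() or not any(src_path.iterdir()):
        raise FileNotFoundError(f"Nothing to back up in {src!r}")
    shutil.copytree(src_path, dst, dirs_exist_ok=True)  # requires Python 3.8+
    return sum(1 for p in Path(dst).rglob("*") if p.is_file())
```

For example, `backup_dir("checkpoints", "/content/drive/MyDrive/msingi1/checkpoints")` would mirror the copy above and fail loudly if training produced no checkpoints.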

## 6. Test the Model
Let's test our trained model with some Swahili text:

In [None]:
import torch
from transformers import PreTrainedTokenizerFast
from src.model import Msingi1, MsingiConfig

# Load tokenizer
tokenizer = PreTrainedTokenizerFast.from_pretrained('tokenizer')

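# Build the model config; these values must match the ones used during training,
# otherwise load_state_dict below will fail on mismatched tensor shapes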
config = MsingiConfig(
    vocab_size=tokenizer.vocab_size,
    max_position_embeddings=512,
    hidden_size=256,
    num_hidden_layers=6,
    num_attention_heads=8,
    intermediate_size=1024,
)

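# Load model from best checkpoint
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Msingi1(config)
# map_location lets a checkpoint saved on GPU load on a CPU-only runtime
checkpoint = torch.load('checkpoints/best_model.pt', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)
model.eval()

# Test text generation
test_text = "Jambo, "
# Inputs must live on the same device as the model
input_ids = tokenizer.encode(test_text, return_tensors='pt').to(device)
with torch.no_grad():  # inference only, so skip gradient tracking
    outputs = model.generate(input_ids, max_length=50)
generated_text = tokenizer.decode(outputs[0])
print(f"Generated text: {generated_text}")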