<a href="https://colab.research.google.com/github/Msingi-AI/msingi1/blob/main/train_on_colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Train Msingi1: Swahili Language Model

First, let's verify we have GPU access:

In [None]:
!nvidia-smi
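
If you'd rather check for the GPU from Python (for example, to stop the notebook early), here is a minimal sketch. The helper name `gpu_visible` is our own and is not part of the repository. It only checks that `nvidia-smi` is on the PATH and lists at least one device, so it uses the standard library alone.

```python
import shutil
import subprocess


def gpu_visible() -> bool:
    """Return True if nvidia-smi is on PATH and lists at least one GPU."""
    exe = shutil.which("nvidia-smi")
    if exe is None:
        return False
    try:
        # "nvidia-smi -L" prints one line per GPU, e.g. "GPU 0: ..."
        result = subprocess.run(
            [exe, "-L"], capture_output=True, text=True, timeout=30
        )
    except (OSError, subprocess.TimeoutExpired):
        return False
    return result.returncode == 0 and "GPU" in result.stdout


print("GPU visible:", gpu_visible())
```

If this prints `False` on Colab, the runtime type is probably set to CPU; switch it to a GPU runtime in the runtime settings before continuing.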

## 1. Setup & Dependencies

In [None]:
# Mount Google Drive to save our model
from google.colab import drive
drive.mount('/content/drive')

# Clone our repository
!git clone https://github.com/Msingi-AI/msingi1.git
%cd msingi1

# Install dependencies
%pip install -r requirements.txt

## 2. Upload Dataset
Upload your `archive.zip` file:

In [None]:
from google.colab import files
uploaded = files.upload()  # Upload archive.zip here
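
The upload only places `archive.zip` in the working directory. If the training scripts expect extracted text files, something like the sketch below could unpack it. The `extract_archive` helper and the `data` destination folder are assumptions for illustration; check the repository's scripts for the path they actually read from.

```python
import zipfile
from pathlib import Path


def extract_archive(zip_path: str, dest_dir: str = "data") -> list[str]:
    """Extract zip_path into dest_dir and return the names of its members."""
    zip_file = Path(zip_path)
    if not zip_file.is_file():
        raise FileNotFoundError(f"{zip_path} not found; upload it first")
    dest = Path(dest_dir)  # hypothetical destination, not taken from the repo
    dest.mkdir(parents=True, exist_ok=True)
    with zipfile.ZipFile(zip_file) as zf:
        zf.extractall(dest)
        return zf.namelist()


# Only runs once archive.zip has actually been uploaded
if Path("archive.zip").is_file():
    print(extract_archive("archive.zip"))
```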

## 3. Train Tokenizer
Next, let's train our tokenizer on the Swahili text:

In [None]:
from src.train_tokenizer import main as train_tokenizer
train_tokenizer()

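# Save tokenizer to Drive (create the msingi1 folder first so cp doesn't fail or nest directories)
!mkdir -p "/content/drive/MyDrive/msingi1"
!cp -r tokenizer "/content/drive/MyDrive/msingi1/"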

## 4. Train Model
Now we'll train our model using the GPU:

In [None]:
from src.train import main as train_model
train_model()

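# Save model checkpoints to Drive (create the msingi1 folder first so cp doesn't fail or nest directories)
!mkdir -p "/content/drive/MyDrive/msingi1"
!cp -r checkpoints "/content/drive/MyDrive/msingi1/"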
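
The Drive copies can also be done from Python, which makes them safe to re-run. This is a minimal sketch: `backup_dir` is a hypothetical helper, not something the repository provides, and it relies on `shutil.copytree(..., dirs_exist_ok=True)`, which requires Python 3.8 or newer.

```python
import shutil
from pathlib import Path


def backup_dir(src: str, dest_root: str) -> Path:
    """Copy directory src into dest_root/<src name>, merging on re-runs."""
    src_path = Path(src)
    if not src_path.is_dir():
        raise FileNotFoundError(f"{src} is not a directory")
    target = Path(dest_root) / src_path.name
    # copytree creates missing parent folders; dirs_exist_ok allows re-runs
    shutil.copytree(src_path, target, dirs_exist_ok=True)
    return target


# Only runs when Drive is mounted and the local folder exists
drive_root = Path("/content/drive/MyDrive")
for folder in ("tokenizer", "checkpoints"):
    if drive_root.is_dir() and Path(folder).is_dir():
        print("Saved to", backup_dir(folder, str(drive_root / "msingi1")))
```

Files already on Drive with the same name are overwritten, so the Drive copy always reflects the latest run.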

## 5. Test the Model
Let's test our trained model with some Swahili text:

In [None]:
import torch
from transformers import PreTrainedTokenizerFast
from src.model import Msingi1, MsingiConfig

# Load tokenizer
tokenizer = PreTrainedTokenizerFast.from_pretrained('tokenizer')

# Rebuild the model config (these values must match the ones used during training)
config = MsingiConfig(
    vocab_size=tokenizer.vocab_size,
    max_position_embeddings=512,
    hidden_size=256,
    num_hidden_layers=6,
    num_attention_heads=8,
    intermediate_size=1024,
)

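# Load model from best checkpoint (map_location lets this work on CPU-only runtimes too)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Msingi1(config)
checkpoint = torch.load('checkpoints/best_model.pt', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)
model.eval()

# Test text generation (no gradients needed at inference time)
test_text = "Jambo, "
input_ids = tokenizer.encode(test_text, return_tensors='pt').to(device)
with torch.no_grad():
    outputs = model.generate(input_ids, max_length=50)
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(f"Generated text: {generated_text}")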
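
To get a feel for how large a model the config above describes, here is a rough back-of-the-envelope parameter count. It assumes a standard GPT-style decoder block (four attention projections, a two-layer MLP, two layer norms, learned position embeddings, tied output head), which may not match the actual `Msingi1` implementation. The vocabulary size of 32,000 is a placeholder, not the real tokenizer's value.

```python
def estimate_parameters(
    vocab_size: int,
    max_position_embeddings: int = 512,
    hidden_size: int = 256,
    num_hidden_layers: int = 6,
    intermediate_size: int = 1024,
    tie_embeddings: bool = True,
) -> int:
    """Approximate parameter count for an assumed GPT-style decoder."""
    h, i = hidden_size, intermediate_size
    attention = 4 * h * h + 4 * h   # Q, K, V, output projections + biases
    mlp = 2 * h * i + i + h         # up and down projections + biases
    layer_norms = 2 * 2 * h         # two norms per block, scale + shift
    per_layer = attention + mlp + layer_norms
    embeddings = vocab_size * h + max_position_embeddings * h
    final_norm = 2 * h
    output_head = 0 if tie_embeddings else vocab_size * h
    return embeddings + num_hidden_layers * per_layer + final_norm + output_head


# 32_000 is a hypothetical vocabulary size; use tokenizer.vocab_size in practice
total = estimate_parameters(vocab_size=32_000)
print(f"Approximate parameters: {total:,}")
```

The number of attention heads does not appear in the estimate because splitting the hidden dimension across heads leaves the projection sizes unchanged. For an exact figure, sum `p.numel()` over `model.parameters()` on the loaded model instead.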