# Msingi1: Training on TPU with JAX/Flax

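This notebook trains the Msingi1 Swahili language model on a Google Colab TPU runtime. It uses JAX/Flax with data parallelism across all TPU cores (`jax.pmap`). Steps: install dependencies, load and chunk the training text, train with AdamW, checkpoint to Google Drive, and sample some text.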

In [None]:
#@title Check TPU availability and install dependencies
!pip install -q jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
!pip install -q flax optax transformers datasets wandb

import jax
print('Number of TPU devices:', jax.device_count())
print('JAX devices:', jax.devices())

In [None]:
#@title Mount Google Drive and clone repository
from google.colab import drive
drive.mount('/content/drive')

!git clone https://github.com/Msingi-AI/msingi1.git
%cd msingi1

In [None]:
#@title Import required libraries
import os
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
import optax
import wandb
from flax import jax_utils
from flax import linen as nn
from flax.training import train_state
from tqdm.auto import tqdm
# NOTE(review): confirm that the installed transformers version actually provides FlaxMoEForCausalLM.
from transformers import FlaxMoEForCausalLM, AutoTokenizer

In [None]:
#@title Load and prepare dataset
def load_dataset(tokenizer_path='tokenizer',
                 data_path='data/Swahili data/Swahili data/train.txt',
                 max_length=512, stride=256, tokenizer=None):
    """Tokenize the training text and cut it into overlapping fixed-length windows.

    Args:
        tokenizer_path: directory passed to ``AutoTokenizer.from_pretrained`` when
            ``tokenizer`` is not supplied.
        data_path: UTF-8 text file holding the whole training corpus.
        max_length: number of tokens per chunk.
        stride: step between the starts of consecutive chunks. A stride smaller
            than ``max_length`` makes the chunks overlap.
        tokenizer: optional pre-loaded tokenizer (any callable with the HF
            tokenizer call signature). If given, nothing is loaded from disk.

    Returns:
        ``(chunks, tokenizer)``. ``chunks`` is an array of shape
        ``(num_chunks, max_length)``. When the text is shorter than one window it
        is an empty ``(0, max_length)`` array, so ``chunks.shape[1]`` is always valid.
        Trailing tokens that do not fill a whole window are dropped.

    Raises:
        ValueError: if ``max_length`` or ``stride`` is not positive.
    """
    if max_length <= 0 or stride <= 0:
        raise ValueError(f'max_length and stride must be positive, got {max_length} and {stride}')

    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

    with open(data_path, 'r', encoding='utf-8') as f:
        text = f.read()

    # Tokenize the whole corpus at once; truncation is off so nothing is lost.
    tokens = np.asarray(tokenizer(text, return_tensors='np', truncation=False)['input_ids'][0])

    # Every start in this range yields a full window, so no length check is needed.
    starts = range(0, len(tokens) - max_length + 1, stride)
    if len(starts) == 0:
        return np.empty((0, max_length), dtype=tokens.dtype), tokenizer

    chunks = np.stack([tokens[start:start + max_length] for start in starts])
    return chunks, tokenizer

# Load dataset and tokenizer
chunks, tokenizer = load_dataset()
# An empty result would otherwise surface much later, as an IndexError here or a
# ZeroDivisionError in the training loop; report the real cause instead.
if len(chunks) == 0:
    raise ValueError('Training text is shorter than one chunk window; no training chunks were created')
print(f'Created {len(chunks)} training chunks of length {chunks.shape[1]}')

In [None]:
#@title Define model and training functions
def create_train_state(model, learning_rate=1e-4, weight_decay=1e-4):
    """Create a single-device (un-replicated) TrainState for ``model`` using AdamW.

    Args:
        model: model object exposing ``__call__`` and ``params``. ``__call__``
            becomes the state's ``apply_fn``; ``params`` are the initial weights.
        learning_rate: AdamW learning rate.
        weight_decay: decoupled weight-decay coefficient. The default matches
            optax's own ``adamw`` default, so existing calls behave as before.

    Returns:
        A ``flax.training.train_state.TrainState``. It is not replicated across
        devices; replicate it before passing it to a ``jax.pmap``-ed step.
    """
    optimizer = optax.adamw(learning_rate, weight_decay=weight_decay)
    return train_state.TrainState.create(
        apply_fn=model.__call__,
        params=model.params,
        tx=optimizer
    )

@partial(jax.pmap, axis_name='batch')
def train_step(state, batch, rng):
    """Run one data-parallel training step on each device.

    Under ``jax.pmap`` every argument carries a leading device axis outside this
    function. Inside, each device sees its own slice.

    Args:
        state: this device's replica of the TrainState.
        batch: this device's token ids, shape ``(per_device_batch, seq_len)``.
        rng: this device's PRNG key, used to derive the dropout key.

    Returns:
        ``(new_state, loss)``. ``new_state`` has been updated with gradients
        averaged over all devices; ``loss`` is this device's mean next-token
        cross-entropy.

    NOTE(review): assumes ``state.apply_fn`` is a Hugging Face style Flax model
    ``__call__`` (as set by ``create_train_state``). It takes input ids
    positionally plus ``params`` / ``dropout_rng`` / ``train`` keywords and
    returns logits first. Confirm this for FlaxMoEForCausalLM.
    """
    # Fold in the step counter so dropout masks differ from step to step even if
    # the caller keeps passing the same key.
    dropout_rng = jax.random.fold_in(rng, state.step)

    def loss_fn(params):
        logits = state.apply_fn(
            batch,
            params=params,
            dropout_rng=dropout_rng,
            train=True,
        )[0]

        # Next-token prediction: logits at position t are scored against token
        # t+1. The final position has no target and is simply dropped.
        loss = optax.softmax_cross_entropy_with_integer_labels(
            logits[:, :-1, :],
            batch[:, 1:],
        ).mean()

        return loss

    loss, grads = jax.value_and_grad(loss_fn)(state.params)

    # Average gradients across devices so every replica applies the same update.
    grads = jax.lax.pmean(grads, axis_name='batch')

    return state.apply_gradients(grads=grads), loss

def create_model(vocab_size):
    """Build the Msingi1 mixture-of-experts causal LM for the given vocabulary size.

    Fixed architecture: 6 layers, hidden size 768, 12 attention heads, FFN size
    3072, 8 experts with capacity 32. ``moe_layer_indices=[2, 4]`` presumably
    selects which layers use MoE blocks; confirm against the model class.
    """
    # NOTE(review): the config is a plain dict. Hugging Face Flax models normally
    # expect a PretrainedConfig instance, so confirm FlaxMoEForCausalLM accepts a dict.
    moe_config = dict(
        vocab_size=vocab_size,
        hidden_size=768,
        num_hidden_layers=6,
        num_attention_heads=12,
        intermediate_size=3072,
        num_experts=8,
        expert_capacity=32,
        moe_layer_indices=[2, 4],
    )
    model = FlaxMoEForCausalLM(moe_config)
    return model

In [None]:
#@title Training loop
# Initialize wandb
wandb.init(project='msingi1', name='tpu_training')

# Create the model. len(tokenizer) also counts added/special tokens, which
# tokenizer.vocab_size leaves out; ids beyond the embedding table are otherwise
# silently clamped by JAX.
model = create_model(len(tokenizer))

# `train_step` is jax.pmap-ed and maps over the leading axis of every argument,
# so the state must carry one replica per device.
n_devices = jax.device_count()
state = jax_utils.replicate(create_train_state(model))

# Training parameters
per_device_batch_size = 32
batch_size = per_device_batch_size * n_devices  # global batch across all devices
num_epochs = 100
save_every = 5

num_batches = len(chunks) // batch_size
if num_batches == 0:
    raise ValueError(f'Need at least {batch_size} chunks for one global batch, got {len(chunks)}')

# Initialize RNG
rng = jax.random.PRNGKey(0)

# Training loop
for epoch in range(num_epochs):
    # Shuffle into a new array so the `chunks` produced by the data cell is left untouched.
    rng, shuffle_rng = jax.random.split(rng)
    perm = np.asarray(jax.random.permutation(shuffle_rng, len(chunks)))
    epoch_chunks = chunks[perm]

    # Stays a device array until the end of the epoch, to avoid a host sync every step.
    epoch_loss = 0.0

    for i in tqdm(range(num_batches), desc=f'Epoch {epoch+1}/{num_epochs}'):
        batch = epoch_chunks[i * batch_size:(i + 1) * batch_size]

        # (n_devices, per_device_batch, seq_len): pmap hands each device one shard.
        batch = batch.reshape(n_devices, -1, batch.shape[-1])

        # Fresh per-device dropout keys every step.
        rng, step_rng = jax.random.split(rng)
        step_rngs = jax.random.split(step_rng, n_devices)

        state, loss = train_step(state, batch, step_rngs)
        epoch_loss += loss.mean()

        # Log to wandb (converted to a Python float)
        if i % 10 == 0:
            wandb.log({
                'loss': float(loss.mean()),
                'epoch': epoch,
                'step': epoch * num_batches + i
            })

    # Average epoch loss
    epoch_loss = float(epoch_loss) / num_batches
    print(f'Epoch {epoch+1} loss: {epoch_loss}')

    # Save checkpoint. The params are replicated per device, so save one host copy.
    if (epoch + 1) % save_every == 0:
        checkpoint_dir = f'/content/drive/MyDrive/msingi1_checkpoints/epoch_{epoch+1}'
        os.makedirs(checkpoint_dir, exist_ok=True)
        host_params = jax.device_get(jax_utils.unreplicate(state.params))
        model.save_pretrained(checkpoint_dir, params=host_params)
        print(f'Saved checkpoint to {checkpoint_dir}')

# Put the trained weights back on the model, so later cells (e.g. generation)
# sample from the trained model rather than from the initial parameters.
model.params = jax.device_get(jax_utils.unreplicate(state.params))

In [None]:
#@title Generate sample text
def generate_text(prompt, max_length=100, temperature=0.7, top_k=50):
    """Sample a continuation of ``prompt`` from the module-level ``model``.

    Args:
        prompt: text to condition on; encoded with the module-level ``tokenizer``.
        max_length: passed straight to ``model.generate``.
        temperature: sampling temperature.
        top_k: sample only from the ``top_k`` most likely tokens.

    Returns:
        The decoded text of the first (only) generated sequence, with special
        tokens removed.
    """
    inputs = tokenizer(prompt, return_tensors='jax')
    outputs = model.generate(
        inputs['input_ids'],
        max_length=max_length,
        temperature=temperature,
        top_k=top_k,
        do_sample=True
    )
    # Flax `generate` returns an output object. Its `.sequences` array is
    # (batch, length), so take the first row; indexing the object with [0] would
    # give the whole 2-D array, which `decode` cannot handle.
    return tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)

# Smoke-test generation with a short, common Swahili greeting as the prompt
prompt = "Habari ya leo"
generated = generate_text(prompt=prompt)

print(f'Prompt: {prompt}')
print(f'Generated: {generated}')