# Msingi1: Training on TPU with JAX/Flax

This notebook trains the Msingi1 Swahili language model, a Transformer with mixture-of-experts feed-forward layers, on a Colab TPU using JAX/Flax.

In [None]:
# Install dependencies
# JAX with the TPU extra; -f points pip at the libtpu release index.
!pip install -q jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
# Libraries imported by the model and training cells below.
!pip install -q flax transformers datasets wandb
# flaxformer is installed from source; the model cell imports flaxformer.components.
!pip install -q git+https://github.com/google/flaxformer

In [None]:
# Check TPU availability
# Report how many accelerator devices JAX can see and list them.
import jax

num_devices = jax.device_count()
print('Number of TPU devices:', num_devices)
visible_devices = jax.devices()
print('JAX devices:', visible_devices)

In [None]:
# Mount Google Drive
# The training cell writes checkpoints under /content/drive/MyDrive
# (see training_args['output_dir']), so Drive must be mounted first.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Clone repository
# Later cells use paths relative to the repository root
# (e.g. tokenizer/tokenizer.json), hence the %cd after cloning.
!git clone https://github.com/Msingi-AI/msingi1.git
%cd msingi1

In [None]:
import os
import jax
import jax.numpy as jnp
import flax
import optax
from flax import linen as nn
from flax.training import train_state
from transformers import FlaxPreTrainedModel, PretrainedConfig
from datasets import load_dataset
from tokenizers import Tokenizer
import wandb

# Load our pre-trained tokenizer
# The path is relative to the current working directory
# (presumably the repository root after the earlier %cd -- confirm).
tokenizer = Tokenizer.from_file('tokenizer/tokenizer.json')

# Model configuration
class Msingi1Config(PretrainedConfig):
    """Hyperparameters for the Msingi1 Transformer with optional MoE blocks.

    Args:
        vocab_size: Number of rows in the token embedding table and width of
            the LM head output.
        hidden_size: Width of the token/position embeddings and of every
            block's output.
        num_hidden_layers: Number of ``MsingiBlock`` layers in the stack.
        num_attention_heads: Number of attention heads per block.
        intermediate_size: Width of the feed-forward hidden layer.
        hidden_dropout_prob: Dropout rate used in the feed-forward layers.
        attention_probs_dropout_prob: Stored on the config; not read by the
            model code in this notebook.
        max_position_embeddings: Number of rows in the position embedding
            table, i.e. the longest sequence the model can embed.
        num_experts: Number of experts in each MoE feed-forward layer.
        expert_capacity: Value passed to the MoE layer as ``expert_capacity``.
        moe_layers: 0-based indices of the blocks whose feed-forward
            sub-layer is a mixture of experts. Any iterable of ints.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    def __init__(
        self,
        vocab_size=32000,
        hidden_size=768,
        num_hidden_layers=6,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=1024,
        num_experts=8,
        expert_capacity=32,
        # Immutable default: a list default would be shared by every call.
        moe_layers=(2, 4),
        **kwargs
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.num_experts = num_experts
        self.expert_capacity = expert_capacity
        # Stored as a de-duplicated, sorted list rather than a set: the config
        # is serialized to JSON when it is saved or printed, and a set is not
        # JSON-serializable. `idx in config.moe_layers` keeps working.
        self.moe_layers = sorted(set(moe_layers or ()))

In [None]:
# Initialize model with Flax MoE
from flaxformer.components import dense_attention, dense, mixture_of_experts

class ExpertMLP(nn.Module):
    """Feed-forward expert: Dense -> GELU -> Dense -> Dropout.

    Attributes are documented inline; the expert widens its input to
    ``config.intermediate_size`` and projects it to ``config.hidden_size``.
    """
    # Model hyperparameters (intermediate_size, hidden_size, hidden_dropout_prob).
    config: Msingi1Config
    # dtype handed to both Dense layers.
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        """Create the two projections and the output dropout."""
        cfg = self.config
        self.dense1 = nn.Dense(cfg.intermediate_size, dtype=self.dtype)
        self.dense2 = nn.Dense(cfg.hidden_size, dtype=self.dtype)
        self.dropout = nn.Dropout(cfg.hidden_dropout_prob)

    def __call__(self, x, deterministic=True):
        """Apply the expert to ``x``.

        Args:
            x: Input activations.
            deterministic: When True, dropout is disabled.

        Returns:
            The transformed activations, with ``config.hidden_size`` features.
        """
        hidden = nn.gelu(self.dense1(x))
        projected = self.dense2(hidden)
        return self.dropout(projected, deterministic=deterministic)

class MsingiBlock(nn.Module):
    """One Transformer block: LayerNorm -> attention and LayerNorm -> FFN.

    Each sub-layer normalizes its input first and adds its output back onto
    the residual stream. The feed-forward sub-layer is a mixture of experts
    when ``layer_idx`` is listed in ``config.moe_layers`` and a dense MLP
    otherwise.

    NOTE(review): the keyword names passed to the flaxformer classes
    (``deterministic``, ``dropout_rate``, ``expert_cls``, ``config``) and the
    ``flaxformer.components.mixture_of_experts`` module path cannot be
    confirmed from this notebook -- verify them against the installed
    flaxformer version.
    """
    # Model hyperparameters shared by every block.
    config: Msingi1Config
    # dtype handed to the attention and feed-forward sub-layers.
    dtype: jnp.dtype = jnp.float32
    # Position of this block in the stack; selects dense vs. MoE feed-forward.
    layer_idx: int = 0

    def setup(self):
        """Create the attention, feed-forward and normalization sub-layers."""
        cfg = self.config

        self.attention = dense_attention.MultiHeadDotProductAttention(
            num_heads=cfg.num_attention_heads,
            dtype=self.dtype,
        )

        is_dense_layer = self.layer_idx not in cfg.moe_layers
        if is_dense_layer:
            self.feed_forward = dense.MlpBlock(
                intermediate_dim=cfg.intermediate_size,
                dropout_rate=cfg.hidden_dropout_prob,
                dtype=self.dtype,
            )
        else:
            self.feed_forward = mixture_of_experts.MoeLayer(
                num_experts=cfg.num_experts,
                expert_cls=ExpertMLP,
                expert_capacity=cfg.expert_capacity,
                config=cfg,
                dtype=self.dtype,
            )

        self.layernorm1 = nn.LayerNorm()
        self.layernorm2 = nn.LayerNorm()

    def __call__(self, x, mask=None, deterministic=True):
        """Run the block on ``x``.

        Args:
            x: Input activations (the residual stream).
            mask: Attention mask forwarded unchanged to the attention layer.
            deterministic: Forwarded to both sub-layers.

        Returns:
            The residual stream after the attention and feed-forward updates.
        """
        # Self-attention: the normalized input serves as both query and key/value.
        normed = self.layernorm1(x)
        attended = self.attention(
            normed, normed, mask=mask, deterministic=deterministic
        )
        residual = x + attended

        # Feed-forward (dense or MoE) on the normalized residual stream.
        transformed = self.feed_forward(
            self.layernorm2(residual), deterministic=deterministic
        )
        return residual + transformed

class Msingi1Module(nn.Module):
    """Decoder-only language model: embeddings -> MsingiBlock stack -> LM head.

    The module builds a causal attention mask itself, so every position can
    only attend to itself and earlier positions, and combines it with the
    optional padding mask before handing it to the blocks.
    """
    # Model hyperparameters.
    config: Msingi1Config
    # dtype handed to the embeddings, blocks and LM head.
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        """Create the embeddings, the block stack, the final norm and LM head."""
        self.embeddings = nn.Embed(
            num_embeddings=self.config.vocab_size,
            features=self.config.hidden_size,
            dtype=self.dtype
        )
        self.position_embeddings = nn.Embed(
            num_embeddings=self.config.max_position_embeddings,
            features=self.config.hidden_size,
            dtype=self.dtype
        )
        self.layers = [
            MsingiBlock(
                config=self.config,
                layer_idx=i,
                dtype=self.dtype
            )
            for i in range(self.config.num_hidden_layers)
        ]
        self.layernorm = nn.LayerNorm()
        self.lm_head = nn.Dense(
            self.config.vocab_size,
            dtype=self.dtype
        )

    def __call__(self, input_ids, attention_mask=None, deterministic=True):
        """Compute next-token logits.

        Args:
            input_ids: Integer token ids of shape ``(batch, seq_length)``.
            attention_mask: Optional ``(batch, seq_length)`` array that is
                non-zero for real tokens and zero for padding.
            deterministic: When True, dropout is disabled in every block.

        Returns:
            Logits of shape ``(batch, seq_length, vocab_size)``.

        Raises:
            ValueError: If ``seq_length`` exceeds
                ``config.max_position_embeddings``.
        """
        seq_length = input_ids.shape[1]
        # Shapes are static under jit, so this check runs at trace time.
        # Without it, out-of-range position ids would be looked up silently.
        if seq_length > self.config.max_position_embeddings:
            raise ValueError(
                f"Sequence length {seq_length} exceeds "
                f"max_position_embeddings="
                f"{self.config.max_position_embeddings}"
            )
        position_ids = jnp.arange(seq_length)[None, :]

        x = self.embeddings(input_ids)
        x = x + self.position_embeddings(position_ids)

        # (batch, 1, seq, seq) lower-triangular mask: a next-token model must
        # not attend to future positions.
        mask = nn.make_causal_mask(input_ids)
        if attention_mask is not None:
            # Broadcast the padding mask over heads and query positions so
            # padded keys are hidden from every query.
            key_mask = (attention_mask > 0)[:, None, None, :]
            mask = nn.combine_masks(mask, key_mask)

        for layer in self.layers:
            x = layer(
                x,
                mask=mask,
                deterministic=deterministic
            )

        x = self.layernorm(x)
        return self.lm_head(x)

class FlaxMsingi1PreTrainedModel(FlaxPreTrainedModel):
    """Hugging Face wrapper that builds and owns a ``Msingi1Module``.

    ``FlaxPreTrainedModel`` expects subclasses to construct the linen module
    and to implement ``init_weights``; without those the model cannot be
    instantiated from a config alone. ``__call__`` follows the calling
    convention used by this notebook: parameters first, then the inputs.
    """
    config_class = Msingi1Config
    module_class = Msingi1Module

    def __init__(
        self,
        config,
        input_shape=(1, 1),
        seed=0,
        dtype=jnp.float32,
        _do_init=True,
        **kwargs
    ):
        """Build the linen module from ``config`` and initialize the wrapper.

        Args:
            config: A ``Msingi1Config`` instance.
            input_shape: Shape of the dummy ``input_ids`` used to initialize
                the parameters.
            seed: Seed for the parameter-initialization PRNG key.
            dtype: Computation dtype handed to the module.
            _do_init: Forwarded to ``FlaxPreTrainedModel``.
            **kwargs: Extra attributes forwarded to ``module_class``.
        """
        module = self.module_class(config=config, dtype=dtype, **kwargs)
        super().__init__(
            config,
            module,
            input_shape=input_shape,
            seed=seed,
            dtype=dtype,
            _do_init=_do_init,
        )

    def init_weights(self, rng, input_shape, params=None):
        """Create parameters by running the module on dummy inputs.

        Args:
            rng: PRNG key used for parameter and dropout initialization.
            input_shape: Shape of the dummy ``input_ids``.
            params: Optional partially-populated parameter tree; entries it
                lacks are filled in from the freshly initialized parameters.

        Returns:
            The parameter tree (the contents of the ``"params"`` collection).
        """
        input_ids = jnp.zeros(input_shape, dtype="i4")
        attention_mask = jnp.ones_like(input_ids)
        params_rng, dropout_rng = jax.random.split(rng)
        variables = self.module.init(
            {"params": params_rng, "dropout": dropout_rng},
            input_ids,
            attention_mask,
            deterministic=True,
        )
        random_params = variables["params"]
        if params is None:
            return random_params

        # Local imports keep this block self-contained.
        from flax.core.frozen_dict import freeze, unfreeze
        from flax.traverse_util import flatten_dict, unflatten_dict

        flat_params = flatten_dict(unfreeze(params))
        for key, value in flatten_dict(unfreeze(random_params)).items():
            flat_params.setdefault(key, value)
        # Everything is populated now, so nothing is missing any more.
        self._missing_keys = set()
        return freeze(unflatten_dict(flat_params))

    def __call__(
        self,
        params,
        input_ids,
        attention_mask=None,
        deterministic=True,
        dropout_rng=None,
    ):
        """Run the model with explicit parameters.

        Args:
            params: Parameter tree, e.g. ``model.params`` or ``state.params``.
            input_ids: Integer token ids.
            attention_mask: Optional padding mask forwarded to the module.
            deterministic: When True, dropout is disabled.
            dropout_rng: PRNG key for dropout. Pass a fresh key on every
                training step. When omitted while ``deterministic`` is False,
                the model's own seed key is reused so the call still works,
                but every call then draws the same dropout mask.

        Returns:
            The module's logits.
        """
        rngs = {}
        if not deterministic:
            rngs["dropout"] = self.key if dropout_rng is None else dropout_rng
        return self.module.apply(
            {"params": params},
            input_ids,
            attention_mask=attention_mask,
            deterministic=deterministic,
            rngs=rngs,
        )

In [None]:
# Training configuration
training_args = {
    'per_device_train_batch_size': 32,
    'num_train_epochs': 100,
    'learning_rate': 3e-4,
    'warmup_steps': 1000,
    'logging_steps': 100,
    'save_steps': 1000,
    'output_dir': '/content/drive/MyDrive/msingi1_checkpoints'
}

# Initialize model and optimizer
config = Msingi1Config()
model = FlaxMsingi1PreTrainedModel(config)

# Linear warmup from 0 to the peak learning rate over `warmup_steps`
# optimizer updates, then constant. Previously `warmup_steps` was configured
# but never used, so training started at the full learning rate.
lr_schedule = optax.linear_schedule(
    init_value=0.0,
    end_value=training_args['learning_rate'],
    transition_steps=training_args['warmup_steps']
)

# Create optimizer
optimizer = optax.adamw(
    learning_rate=lr_schedule,
    b1=0.9,
    b2=0.999,
    eps=1e-8,
    weight_decay=0.01
)

# Create training state
# apply_fn is called by train_step as apply_fn(params, input_ids, ...).
state = train_state.TrainState.create(
    apply_fn=model.__call__,
    params=model.params,
    tx=optimizer
)

In [None]:
# Training loop with TPU optimization
@jax.jit
def train_step(state, batch, dropout_rng=None):
    """Run one optimizer step of next-token prediction.

    Args:
        state: Train state providing ``apply_fn``, ``params`` and
            ``apply_gradients``.
        batch: Mapping with ``'input_ids'`` and ``'attention_mask'``, both of
            shape ``(batch, seq_length)``; the mask is 1 for real tokens and
            0 for padding.
        dropout_rng: Optional PRNG key. When given it is forwarded to
            ``state.apply_fn`` as the ``dropout_rng`` keyword, which
            ``apply_fn`` must then accept. When omitted, ``apply_fn`` is
            called exactly as before.

    Returns:
        ``(new_state, loss)`` where ``loss`` is the mean cross-entropy over
        the positions that have a real (non-padding) next token.
    """
    input_ids = batch['input_ids']
    attention_mask = batch['attention_mask']

    # Position t is trained to predict token t+1. Roll along the sequence
    # axis only: without `axis`, jnp.roll shifts the flattened array and the
    # last label of each row becomes the first token of the next row.
    labels = jnp.roll(input_ids, -1, axis=1)

    # A position contributes to the loss only if it and its target are real
    # tokens. The last position has no target, so it is always excluded.
    next_is_real = jnp.roll(attention_mask, -1, axis=1).at[:, -1].set(0)
    target_mask = (attention_mask * next_is_real).astype(jnp.float32)

    # NOTE(review): with deterministic=False the model needs a dropout PRNG
    # key; pass `dropout_rng` if apply_fn does not supply one itself.
    extra_kwargs = {} if dropout_rng is None else {'dropout_rng': dropout_rng}

    def loss_fn(params):
        logits = state.apply_fn(
            params,
            input_ids,
            attention_mask=attention_mask,
            deterministic=False,
            **extra_kwargs
        )
        per_token_loss = optax.softmax_cross_entropy_with_integer_labels(
            logits=logits,
            labels=labels
        )
        # Average over valid targets only, so the loss scale does not depend
        # on how much padding the batch contains.
        num_targets = jnp.maximum(jnp.sum(target_mask), 1.0)
        return jnp.sum(per_token_loss * target_mask) / num_targets

    grad_fn = jax.value_and_grad(loss_fn)
    loss, grads = grad_fn(state.params)
    state = state.apply_gradients(grads=grads)
    return state, loss

# Initialize wandb
wandb.init(project='msingi1', name='tpu_training')

# Training loop
# NOTE(review): `train_dataloader` is not defined in any cell of this
# notebook as shown; it must yield dicts with 'input_ids' and
# 'attention_mask' before this cell can run.
for epoch in range(training_args['num_train_epochs']):
    for step, batch in enumerate(train_dataloader):
        state, loss = train_step(state, batch)

        if step % training_args['logging_steps'] == 0:
            wandb.log({
                # Convert the device array to a plain Python float for logging.
                'loss': float(loss),
                'epoch': epoch,
                'step': step
            })

        if step % training_args['save_steps'] == 0:
            # The trained weights live in `state.params`; `model.params`
            # still holds the initial weights, so pass them explicitly.
            model.save_pretrained(
                f"{training_args['output_dir']}/checkpoint-{epoch}-{step}",
                params=state.params
            )

# Save final model
model.save_pretrained(
    f"{training_args['output_dir']}/final_model",
    params=state.params
)

In [None]:
# Test the model
def generate(params, prompt, max_length=100, temperature=0.7, seed=0):
    """Sample a continuation of ``prompt`` one token at a time.

    The function is deliberately not wrapped in ``jax.jit``: it takes a
    Python string, grows the sequence on every step and stops early based on
    the sampled value, none of which can be traced.

    Args:
        params: Model parameters, e.g. ``state.params``.
        prompt: Text to continue.
        max_length: Maximum number of tokens to sample.
        temperature: Softmax temperature; must be positive.
        seed: Seed for the sampling PRNG key.

    Returns:
        The decoded prompt plus the sampled continuation.

    Raises:
        ValueError: If ``temperature`` is not positive or the prompt encodes
            to no tokens.
    """
    if temperature <= 0:
        raise ValueError("temperature must be positive")

    prompt_ids = tokenizer.encode(prompt).ids
    if not prompt_ids:
        raise ValueError("prompt must encode to at least one token")
    input_ids = jnp.array(prompt_ids)[None, :]

    # Loop invariants: look these up once.
    eos_id = tokenizer.token_to_id("</s>")
    max_positions = model.config.max_position_embeddings
    rng = jax.random.PRNGKey(seed)

    for _ in range(max_length):
        # The position embedding table has no rows past this length.
        if input_ids.shape[1] >= max_positions:
            break

        # The HF wrapper exposes the linen module as `.module`; the module
        # expects its parameters under the "params" collection.
        logits = model.module.apply(
            {'params': params},
            input_ids,
            deterministic=True
        )

        # Use a fresh key per step; reusing one key would draw the same
        # random numbers for every sampled token.
        rng, step_rng = jax.random.split(rng)
        next_token = jax.random.categorical(
            step_rng,
            logits[:, -1, :] / temperature
        )
        input_ids = jnp.concatenate([input_ids, next_token[:, None]], axis=1)

        if int(next_token[0]) == eos_id:
            break

    return tokenizer.decode(input_ids[0].tolist())

# Test generation
# Sample a continuation of a short Swahili prompt with the trained
# parameters and print the prompt next to the result.
test_prompt = "Habari ya leo"
generated_text = generate(params=state.params, prompt=test_prompt)
print(f"Prompt: {test_prompt}")
print(f"Generated: {generated_text}")