# Training VishwamAI on GSM8K Dataset

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/VishwamAI/VishwamAI/blob/main/GSM8K_Training.ipynb)

This notebook demonstrates how to train VishwamAI's MoE-MLA model on the GSM8K (Grade School Math 8K) dataset using TPUs.

## Setup

First, we'll install the required packages and clone the repository:

In [None]:
# PyTorch/XLA, imported below as `torch_xla`.
!pip install torch_xla
# Third-party libraries used by the notebook (`datasets` and `wandb` are imported below).
!pip install transformers datasets accelerate wandb
# Get the VishwamAI sources and install the package in editable mode.
!git clone https://github.com/VishwamAI/VishwamAI.git
# NOTE: the `cd` only applies inside this single shell command; the notebook's
# own working directory stays where it was (the parent of the clone).
!cd VishwamAI && pip install -e .

In [None]:
import os
import torch
import torch_xla
import torch_xla.core.xla_model as xm
from datasets import load_dataset
import wandb

from vishwamai.model.transformer import create_transformer_model, get_pretrained_config
from vishwamai.data.dataset.implementations.gsm8k import GSM8KDataset
from vishwamai.training.distributed import tpu_utils

## TPU Setup

Load the TPU configuration file and initialize the TPU device:

In [None]:
# Load TPU configuration.
# The path is relative to the repository root. The install cell clones the
# repo into ./VishwamAI without changing the notebook's working directory, so
# fall back to the path inside the clone when the plain relative path is
# missing. If the original path exists it is used unchanged.
tpu_config_path = "vishwamai/configs/tpu_config.yaml"
cloned_config_path = os.path.join("VishwamAI", tpu_config_path)
if not os.path.exists(tpu_config_path) and os.path.exists(cloned_config_path):
    tpu_config_path = cloned_config_path
tpu_config = tpu_utils.load_tpu_config(tpu_config_path)

# Initialize TPU; the helper returns the device plus this process's rank and
# the total core count, which later cells use for rank-0-only work.
device, local_rank, world_size = tpu_utils.setup_tpu(tpu_config)
print(f"TPU initialized with {world_size} cores")

## Load and Process Data

Load the GSM8K dataset and prepare it for training:

In [None]:
# Load the "main" configuration of the openai/gsm8k dataset.
dataset = load_dataset("openai/gsm8k", "main")

# Report how many examples each split contains.
for split_label, split_name in (("Train", "train"), ("Test", "test")):
    print(f"{split_label} size: {len(dataset[split_name])}")

# Preview the first training example: its question, then its reference answer.
first_example = dataset["train"][0]
print("\nSample question:")
print(first_example["question"])
print("\nSample answer:")
print(first_example["answer"])

## Initialize Model

Create the MoE-MLA model with TPU optimizations:

In [None]:
# Request the "base"-size configuration for the "moe_mla_transformer" model type.
config = get_pretrained_config(model_type="moe_mla_transformer", model_size="base")

# Build the model from that configuration, then place it on the TPU device.
model = create_transformer_model(config)
model = model.to(device)

# Hand the model to the project's TPU helper; the object it returns is the
# one used by the rest of the notebook.
model = tpu_utils.optimize_tpu_execution(model, tpu_config)

## Training Configuration

Initialize experiment tracking and define the training hyperparameters:

In [None]:
# Start a Weights & Biases run from the rank-0 process only.
if local_rank == 0:
    wandb.init(project="vishwamai-gsm8k")

# Hyperparameters for the run. The training cells below read
# "num_train_epochs", "gradient_accumulation_steps" and "max_grad_norm".
training_args = dict(
    num_train_epochs=3,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=4,
    learning_rate=5e-4,
    weight_decay=0.01,
    warmup_steps=500,
    max_grad_norm=1.0,
    use_bf16=True,
)

## Training Loop

Train the model on the GSM8K dataset:

In [None]:
def train_epoch(model, dataloader, optimizer, scheduler):
    """Run one training epoch with gradient accumulation.

    Args:
        model: Module called as ``model(**batch)``; its output must expose the
            training loss under the ``"loss"`` key.
        dataloader: Iterable of dict batches whose values support
            ``.to(device)``.
        optimizer: Optimizer stepped once per accumulation window.
        scheduler: Learning-rate scheduler stepped together with the optimizer.

    Returns:
        The mean *unscaled* per-batch loss over the epoch, or ``0.0`` when the
        dataloader yields no batches.
    """
    accumulation_steps = training_args["gradient_accumulation_steps"]
    max_grad_norm = training_args["max_grad_norm"]

    def apply_update():
        """Clip gradients, step optimizer and scheduler, then reset gradients."""
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        # NOTE(review): a plain optimizer.step() does not synchronize gradients
        # across TPU cores. If world_size > 1 and tpu_utils does not already
        # handle this, xm.optimizer_step(optimizer) may be needed -- confirm.
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

        # TPU barrier
        xm.mark_step()

    model.train()
    total_loss = 0.0
    num_batches = 0
    pending_micro_batches = 0

    for step, batch in enumerate(dataloader):
        # Move batch to TPU
        batch = {k: v.to(device) for k, v in batch.items()}

        # Forward pass
        outputs = model(**batch)
        loss = outputs["loss"]

        # Backward pass: only the gradient contribution is scaled, so the
        # value reported below stays the true per-batch loss.
        (loss / accumulation_steps).backward()
        pending_micro_batches += 1

        # Update weights once a full accumulation window has been collected.
        if pending_micro_batches == accumulation_steps:
            apply_update()
            pending_micro_batches = 0

        # Read the scalar once per step and reuse it for the running total
        # and for logging.
        batch_loss = loss.item()
        total_loss += batch_loss
        num_batches += 1

        if step % 100 == 0 and local_rank == 0:
            wandb.log({"train_loss": batch_loss})

    # Apply gradients from a trailing, incomplete accumulation window so they
    # are neither dropped nor carried over into the next epoch's first update.
    if pending_micro_batches:
        apply_update()

    if num_batches == 0:
        return 0.0
    return total_loss / num_batches

# Run the configured number of epochs.
# NOTE(review): train_dataloader, optimizer and scheduler are not defined in
# any visible cell of this notebook -- confirm they are created before this
# cell runs.
for epoch in range(training_args["num_train_epochs"]):
    loss = train_epoch(model, train_dataloader, optimizer, scheduler)

    # Reporting and checkpointing happen on the rank-0 process only.
    if local_rank != 0:
        continue

    print(f"Epoch {epoch + 1}, Loss: {loss:.4f}")

    # Save checkpoint: model and optimizer state plus the zero-based epoch
    # index and that epoch's returned loss. File names are numbered from 1.
    checkpoint_path = f"gsm8k_checkpoint_epoch_{epoch+1}.pt"
    checkpoint = {
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "epoch": epoch,
        "loss": loss,
    }
    torch.save(checkpoint, checkpoint_path)

## Evaluation

Generate answers for sample questions from the GSM8K test set:

In [None]:
def evaluate_sample(question):
    """Return the model's decoded generation for one question.

    The question is wrapped in a step-by-step prompt, tokenized, moved to the
    device and passed to ``model.generate``; the first returned sequence is
    decoded with special tokens skipped.
    """
    model.eval()
    prompt = f"Question: {question}\nLet's solve this step by step:\n"

    with torch.no_grad():
        encoded = tokenizer(prompt, return_tensors="pt").to(device)
        generation_options = {
            "max_length": 200,
            "num_beams": 4,
            "temperature": 0.7,
            "pad_token_id": tokenizer.pad_token_id,
        }
        generated = model.generate(**encoded, **generation_options)
        return tokenizer.decode(generated[0], skip_special_tokens=True)

# Take the first two questions of the test split as samples.
test_questions = [dataset['test'][index]['question'] for index in range(2)]

# Print each question followed by the model's generated answer and a separator.
for question in test_questions:
    print("Question:", question)
    print("\nGenerated Answer:")
    print(evaluate_sample(question))
    print("\n---\n")

## Upload to HuggingFace Hub

Save and upload the trained model:

In [None]:
from huggingface_hub import HfApi

# Write the model and the tokenizer into the same local directory.
model_path = "gsm8k_trained_model"
model.save_pretrained(model_path)
tokenizer.save_pretrained(model_path)

# Upload that directory to the "VishwamAI/VishwamAI" model repository.
upload_options = {
    "folder_path": model_path,
    "repo_id": "VishwamAI/VishwamAI",
    "repo_type": "model",
}
api = HfApi()
api.upload_folder(**upload_options)