# Training VishwamAI on GSM8K Dataset

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/VishwamAI/VishwamAI/blob/main/GSM8K_Training.ipynb)

This notebook demonstrates training on the GSM8K dataset using TPUs.

## Setup

First, install the dependencies and the VishwamAI package:

In [None]:
# Uninstall regular tensorflow to avoid conflicts
!pip uninstall -y tensorflow
!pip install tensorflow-cpu

# Install PyTorch/XLA and other dependencies
# NOTE(review): the wheel named in the URL below is a cp39 / linux_x86_64 build,
# so this line assumes a CPython 3.9 runtime — confirm against the target image.
!pip install torch==2.0.0 'torch_xla[tpu]>=2.0' -f https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-2.0-cp39-cp39-linux_x86_64.whl
!pip install transformers datasets accelerate wandb sentencepiece

# Clone and install VishwamAI
# The `cd` only applies inside that single shell command; the package is
# installed in editable mode from the fresh checkout.
!git clone https://github.com/VishwamAI/VishwamAI.git
!cd VishwamAI && pip install -e .

# Restart runtime to apply changes
# Presumably do_shutdown(True) requests a kernel shutdown with restart, so the
# remaining cells should be run after the runtime has come back up.
import IPython
IPython.Application.instance().kernel.do_shutdown(True)

In [None]:
# Sanity check: both frameworks must import, then report which versions are active
import torch
import torch_xla
import torch_xla.core.xla_model as xm  # binds the `xm` alias in the notebook namespace

torch_version = torch.__version__
xla_version = torch_xla.__version__
print(f"PyTorch version: {torch_version}")
print(f"PyTorch/XLA version: {xla_version}")

## Import Libraries

In [None]:
import os
import torch
import numpy as np
from datasets import load_dataset
import wandb

# Import VishwamAI modules
from vishwamai.model.transformer import create_transformer_model, get_pretrained_config
from vishwamai.data.dataset.implementations.gsm8k import GSM8KDataset
from vishwamai.training.distributed import tpu_utils

## Configure TPU

In [None]:
# Load TPU configuration
# NOTE(review): this is a relative path; the setup cell clones the repository
# into ./VishwamAI, so confirm the path resolves from the notebook's working
# directory (or that `load_tpu_config` resolves it some other way).
tpu_config = tpu_utils.load_tpu_config("vishwamai/configs/tpu_config.yaml")

# Initialize TPU
# `xm` is not imported in the "Import Libraries" cell; it comes from the
# installation-check cell (`import torch_xla.core.xla_model as xm`).
# `device` is the XLA device handle; `rank` and `world_size` are the ordinal and
# world size reported by torch_xla, reused by later cells to gate
# logging/checkpointing on rank 0.
device = xm.xla_device()
rank = xm.get_ordinal()
world_size = xm.xrt_world_size()

print(f"Using TPU with {world_size} cores")
print(f"Local rank: {rank}")

## Load Dataset

In [None]:
# Load the "main" configuration of the openai/gsm8k dataset
dataset = load_dataset("openai/gsm8k", "main")

# Report how many problems each split holds
train_split = dataset['train']
test_split = dataset['test']
print(f"Train size: {len(train_split)}")
print(f"Test size: {len(test_split)}")

# Show the first training problem together with its reference answer
first_example = train_split[0]
print("\nSample question:")
print(first_example['question'])
print("\nSample answer:")
print(first_example['answer'])

## Initialize Model

In [None]:
# Configuration for the "base" size of the "moe_mla_transformer" model type
config = get_pretrained_config(model_size="base", model_type="moe_mla_transformer")

# Build the model, place it on the TPU device, then hand it to the project's
# TPU optimization helper together with the loaded TPU configuration
model = tpu_utils.optimize_tpu_execution(
    create_transformer_model(config).to(device),
    tpu_config,
)

## Training Setup

In [None]:
# Training parameters (shared by the optimizer, the scheduler and the training loop)
training_args = {
    "num_train_epochs": 3,
    "per_device_train_batch_size": 8,
    "gradient_accumulation_steps": 4,
    "learning_rate": 5e-4,
    "weight_decay": 0.01,
    "warmup_steps": 500,
    "max_grad_norm": 1.0
}

# Initialize wandb on the first process only, recording the hyperparameters
# with the run so that results can be traced back to their settings
if rank == 0:
    wandb.init(project="vishwamai-gsm8k", config=training_args)

# Initialize optimizer and scheduler
# Weight decay is passed explicitly so the configured value is the one in
# effect, instead of silently relying on the optimizer's built-in default.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=training_args["learning_rate"],
    weight_decay=training_args["weight_decay"],
)
# Linear warmup: LinearLR ramps the learning-rate factor from its default start
# factor up to 1.0 over `warmup_steps` scheduler steps and holds it afterwards.
scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, total_iters=training_args["warmup_steps"])

## Training Loop

In [None]:
def train_step(model, batch, optimizer, scheduler, update=True):
    """Run forward/backward on one micro-batch and optionally apply an update.

    The loss is divided by ``gradient_accumulation_steps`` before the backward
    pass, so gradients summed over that many micro-batches form an average.

    Args:
        model: Callable invoked as ``model(**batch)``; its output is indexed
            with ``"loss"``.
        batch: Mapping of keyword arguments forwarded to the model.
        optimizer: Optimizer that owns the model parameters.
        scheduler: Learning-rate scheduler, advanced once per optimizer update.
        update: When True (the default), clip the accumulated gradients to
            ``max_grad_norm``, step the optimizer and scheduler, and clear the
            gradients. Pass False to keep accumulating.

    Returns:
        The unscaled loss of this micro-batch as a Python float.
    """
    # Forward pass
    outputs = model(**batch)
    loss = outputs["loss"]

    # Backward pass on the scaled loss; the unscaled value is kept for reporting
    (loss / training_args["gradient_accumulation_steps"]).backward()

    if update:
        # Clip only once the whole accumulation window has been summed
        torch.nn.utils.clip_grad_norm_(model.parameters(), training_args["max_grad_norm"])
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

    # TPU sync
    xm.mark_step()

    return loss.item()


# Training loop
# NOTE(review): `train_dataloader` is not defined in any cell visible here; it
# must be created before this cell runs.
accumulation_steps = training_args["gradient_accumulation_steps"]

for epoch in range(training_args["num_train_epochs"]):
    model.train()
    # Discard gradients left over from an incomplete accumulation window
    optimizer.zero_grad()
    total_loss = 0.0
    num_batches = 0

    for step, batch in enumerate(train_dataloader):
        # Apply an update only after a full window of micro-batches
        is_update_step = (step + 1) % accumulation_steps == 0
        loss = train_step(model, batch, optimizer, scheduler, update=is_update_step)
        total_loss += loss
        num_batches += 1

        if step % 100 == 0 and rank == 0:
            print(f"Epoch {epoch}, Step {step}, Loss: {loss:.4f}")
            wandb.log({"loss": loss})

    if rank == 0:
        if num_batches:
            print(f"Epoch {epoch} average loss: {total_loss / num_batches:.4f}")

        # Save checkpoint
        checkpoint = {
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "epoch": epoch
        }
        torch.save(checkpoint, f"gsm8k_checkpoint_epoch_{epoch+1}.pt")

## Evaluation

In [None]:
def evaluate():
    """Return the mean per-batch loss of the global ``model`` on the test set.

    The model is switched to eval mode and run without gradient tracking over
    every batch of the global ``test_dataloader``; the summed batch losses are
    divided by ``len(test_dataloader)``.
    """
    model.eval()

    with torch.no_grad():
        batch_losses = [model(**batch)["loss"].item() for batch in test_dataloader]

    return sum(batch_losses) / len(test_dataloader)

# Every process evaluates; only rank 0 prints and logs the result
test_loss = evaluate()
if rank == 0:
    print(f"Test Loss: {test_loss:.4f}")
    wandb.log({"test_loss": test_loss})

## Upload to HuggingFace

In [None]:
from huggingface_hub import HfApi

# Only the first process exports and publishes the model
if rank == 0:
    # Write the model and tokenizer files into one local folder.
    # NOTE(review): `tokenizer` is not defined in any cell visible here, and
    # `save_pretrained` is assumed to exist on the VishwamAI model — confirm both.
    model_path = "gsm8k_trained_model"
    model.save_pretrained(model_path)
    tokenizer.save_pretrained(model_path)

    # Push the exported folder to the model repository on the Hub
    api = HfApi()
    upload_target = {"repo_id": "VishwamAI/VishwamAI", "repo_type": "model"}
    api.upload_folder(folder_path=model_path, **upload_target)