# VishwamAI: GSM8K Training Pipeline

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/vishwamai/vishwamai/blob/main/train_gsm8k.ipynb)

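This notebook fine-tunes VishwamAI on [GSM8K](https://huggingface.co/datasets/openai/gsm8k), a dataset of grade-school math word problems. It installs dependencies, formats each question/answer pair as a training prompt, trains the model with JAX, pushes the result to the Hugging Face Hub, and reports token-level loss and accuracy on the test split.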

## Setup

Install the required packages and clone the VishwamAI repository.

In [None]:
!pip install -q jax jaxlib
!pip install -q flax optax
!pip install -q datasets transformers huggingface_hub
!pip install -q tqdm einops

In [None]:
# Clone VishwamAI repository
!git clone https://github.com/vishwamai/vishwamai.git
!cd vishwamai

In [None]:
import os
import json
from pathlib import Path
import jax
import jax.numpy as jnp
from datasets import load_dataset
from huggingface_hub import HfFolder
from tqdm.auto import tqdm

# Import VishwamAI modules
from vishwamai.model import VishwamAIModel, ModelConfig
from vishwamai.tokenizer import VishwamAITokenizer
from vishwamai.training import create_train_state, train_epoch

## Authentication

Log in to the Hugging Face Hub. A token with write access is required later to push the trained model and tokenizer.

In [None]:
# Authenticate against the Hugging Face Hub. The "Push to Hugging Face Hub"
# section further down needs a token with write access.
import huggingface_hub

notebook_login = huggingface_hub.notebook_login
notebook_login()

## Load GSM8K Dataset

In [None]:
# Load the "main" configuration of GSM8K from the Hugging Face Hub.
dataset = load_dataset("openai/gsm8k", "main")
for split_name in ("train", "test"):
    print(f"{split_name.capitalize()} size: {len(dataset[split_name])}")

# Show one raw record so the question/answer format is visible.
first_example = dataset['train'][0]
for field in ("question", "answer"):
    print(f"\nSample {field}:")
    print(first_example[field])

## Prepare Training Data

In [None]:
def format_example(example):
    """Turn one GSM8K record into the single ``text`` field used for training.

    Args:
        example: Mapping with ``question`` and ``answer`` entries.

    Returns:
        ``{'text': "Question: <question>\\nAnswer: <answer>"}``.
    """
    template = "Question: {}\nAnswer: {}"
    return dict(text=template.format(example['question'], example['answer']))

# Apply the prompt template to both splits (each map returns a new dataset).
train_dataset, test_dataset = (
    dataset[split].map(format_example) for split in ('train', 'test')
)

print("Sample formatted example:")
print(train_dataset[0]['text'])

## Initialize Model and Tokenizer

In [None]:
# Model configuration shipped with the cloned repository.
# NOTE(review): a 10B-parameter config is unlikely to fit on a single Colab
# accelerator; confirm this is the intended config for this notebook.
config_path = Path("vishwamai/configs/config_10B.json")
with config_path.open(encoding="utf-8") as f:
    config = ModelConfig(**json.load(f))

# Initialize model
model = VishwamAIModel(config)

# GPT-2's tokenizer is used as the base vocabulary.
tokenizer = VishwamAITokenizer.from_pretrained("gpt2")
# GPT-2 ships without a padding token, but the data loader pads every example
# to max_length. Reuse EOS as PAD so tokenization does not fail.
# Assumes an HF-style tokenizer exposing pad_token/eos_token (TODO: confirm
# for VishwamAITokenizer).
if getattr(tokenizer, "pad_token", None) is None and getattr(tokenizer, "eos_token", None) is not None:
    tokenizer.pad_token = tokenizer.eos_token
# NOTE(review): the model config's vocabulary size should match the
# tokenizer's vocabulary; verify.
tokenizer.save_pretrained("tokenizer")

## Training Setup

In [None]:
class _ReiterableBatches:
    """Re-iterable view over a tokenized dataset that yields fixed-size batches.

    ``Dataset.iter`` returns a one-shot generator. Wrapping it means every
    ``for batch in loader`` loop (for example, one per epoch) starts again from
    the first batch instead of silently yielding nothing.
    """

    def __init__(self, dataset, batch_size):
        self._dataset = dataset
        self._batch_size = batch_size

    def __iter__(self):
        return iter(self._dataset.iter(batch_size=self._batch_size))

    def __len__(self):
        # Number of batches, counting a final partial batch.
        return -(-len(self._dataset) // self._batch_size)


def create_data_loader(dataset, tokenizer, batch_size, max_length=None):
    """Tokenize ``dataset['text']`` once and return a re-iterable batch loader.

    Args:
        dataset: Dataset with a ``text`` column. All of its columns are
            replaced by the tokenizer outputs.
        tokenizer: Callable tokenizer. It is invoked with padding to
            ``max_length``, truncation and NumPy tensors.
        batch_size: Number of examples per batch (must be >= 1).
        max_length: Sequence length to pad/truncate to. Defaults to the global
            ``config.max_seq_len``.

    Returns:
        An iterable of dict batches (NumPy-formatted). It can be looped over
        any number of times and supports ``len()`` (number of batches).

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    seq_len = config.max_seq_len if max_length is None else max_length

    def tokenize(examples):
        return tokenizer(
            examples['text'],
            padding='max_length',
            truncation=True,
            max_length=seq_len,
            return_tensors='np'
        )

    tokenized = dataset.map(
        tokenize,
        batched=True,
        remove_columns=dataset.column_names
    )

    return _ReiterableBatches(tokenized.with_format('numpy'), batch_size)

# Training parameters (tune here)
SEED = 42
batch_size = 32
num_epochs = 10
learning_rate = 1e-4

# Create data loaders
train_loader = create_data_loader(train_dataset, tokenizer, batch_size)
test_loader = create_data_loader(test_dataset, tokenizer, batch_size)

# Initialize training state. A dedicated key is split off for parameter
# initialization so the root key `rng` (split again per epoch later) is never
# consumed directly; JAX keys must not be reused.
rng = jax.random.PRNGKey(SEED)
rng, init_rng = jax.random.split(rng)
state = create_train_state(model, config, learning_rate, init_rng)

## Training Loop

In [None]:
# Create output directory
output_dir = Path("checkpoints")
output_dir.mkdir(parents=True, exist_ok=True)

# Save a checkpoint every `checkpoint_every` epochs, and always after the last one.
checkpoint_every = 2

# Training loop. NOTE: `train_loader` is iterated once per epoch, so it must be
# re-iterable (a plain generator would be empty after the first epoch).
for epoch in range(num_epochs):
    epoch_num = epoch + 1
    print(f"\nEpoch {epoch_num}/{num_epochs}")

    # Fresh key per epoch, derived from the running root key.
    rng, epoch_rng = jax.random.split(rng)
    state, metrics = train_epoch(
        state=state,
        train_loader=train_loader,
        rng=epoch_rng,
        error_correction=None,  # No error correction for initial training
        epoch=epoch_num
    )

    print(f"Train - Loss: {metrics['loss']:.4f}, Accuracy: {metrics['accuracy']:.4f}")

    if epoch_num % checkpoint_every == 0 or epoch_num == num_epochs:
        checkpoint_dir = output_dir / f"checkpoint-{epoch_num}"
        # NOTE(review): this saves `model`, while training updates `state`.
        # Confirm that `model.save_pretrained` persists the trained parameters
        # held in `state`.
        model.save_pretrained(checkpoint_dir)
        tokenizer.save_pretrained(checkpoint_dir)

## Push to Hugging Face Hub

In [None]:
# Upload the final model, then the tokenizer, to the same Hub repository.
# Requires the Hub login from the Authentication section.
hub_repo_id = "VishwamAI/VishwamAI"
for artifact, message in (
    (model, "Trained on GSM8K"),
    (tokenizer, "Updated tokenizer"),
):
    artifact.push_to_hub(hub_repo_id, commit_message=message)

## Evaluation

In [None]:
def evaluate_model(model, test_loader):
    """Compute token-level cross-entropy and accuracy of ``model`` on ``test_loader``.

    Each batch must provide ``input_ids`` and ``attention_mask``. If a batch
    also provides ``labels``, they are used as targets as-is and are assumed
    to be aligned with the logits. Otherwise next-token targets are derived
    from ``input_ids`` by shifting one position, as in causal LM training.
    Positions where the mask is 0 (padding) are excluded from both metrics.

    Args:
        model: Callable taking ``input_ids``/``attention_mask`` keyword
            arguments and returning a mapping with a ``logits`` entry.
            NOTE(review): logits are assumed to be (batch, seq_len, vocab);
            confirm against VishwamAIModel.
        test_loader: Iterable of batch dicts.

    Returns:
        Dict with float ``loss`` (mean cross-entropy per counted token) and
        float ``accuracy`` (fraction of counted tokens predicted exactly).

    Raises:
        ValueError: If the loader yields no non-padding target tokens, for
            example an already-exhausted generator.
    """
    total_loss = 0.0
    total_correct = 0.0
    total_tokens = 0.0

    for batch in tqdm(test_loader, desc="Evaluating"):
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        # NOTE(review): this calls `model` directly rather than with the
        # trained parameters from the training state; confirm the model
        # object carries the trained weights.
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        logits = outputs['logits']

        if 'labels' in batch:
            targets = batch['labels']
            mask = attention_mask
        else:
            # The logit at position t predicts token t + 1.
            logits = logits[:, :-1]
            targets = input_ids[:, 1:]
            mask = attention_mask[:, 1:]

        targets = jnp.asarray(targets)
        mask = jnp.asarray(mask, dtype=jnp.float32)

        # Softmax in float32 for numerical stability with low-precision logits.
        log_probs = jax.nn.log_softmax(logits.astype(jnp.float32), axis=-1)
        token_nll = -jnp.take_along_axis(log_probs, targets[..., None], axis=-1)[..., 0]
        predictions = jnp.argmax(logits, axis=-1)

        total_loss += float(jnp.sum(token_nll * mask))
        total_correct += float(jnp.sum((predictions == targets) * mask))
        total_tokens += float(jnp.sum(mask))

    if total_tokens == 0:
        raise ValueError(
            "evaluate_model: test_loader yielded no non-padding target tokens"
        )

    return {
        'loss': total_loss / total_tokens,
        'accuracy': total_correct / total_tokens
    }

# Evaluate the final model on the held-out test split.
metrics = evaluate_model(model, test_loader)

print("\nTest Results:")
for label, key in (("Loss", "loss"), ("Accuracy", "accuracy")):
    print(f"{label}: {metrics[key]:.4f}")