# Training VishwamAI on GSM8K Dataset

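This notebook demonstrates how to train the VishwamAI model on the GSM8K (Grade School Math 8K) dataset. It loads the dataset, builds the model and a GPT-2-based tokenizer, trains with the Hugging Face `Trainer` (logging to Weights & Biases), spot-checks generated answers on test questions, and uploads the result to the Hugging Face Hub.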

## Setup

In [None]:
!pip install transformers datasets torch accelerate wandb

In [None]:
import os

import torch
import wandb
from datasets import load_dataset
from huggingface_hub import HfApi
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import Trainer, TrainingArguments

from vishwamai.model.transformer import create_transformer_model, get_pretrained_config
from vishwamai.data.dataset.implementations.gsm8k import GSM8KDataset

## Load Dataset

In [None]:
# Download GSM8K (the "main" configuration) from the Hugging Face Hub and
# report how many examples each split contains.
dataset = load_dataset("openai/gsm8k", "main")
for split_name, label in (("train", "Train"), ("test", "Test")):
    print(f"{label} size: {len(dataset[split_name])}")

# Peek at one raw training record to see the question/answer format.
first_example = dataset['train'][0]
print("\nSample question:")
print(first_example['question'])
print("\nSample answer:")
print(first_example['answer'])

## Initialize Model and Tokenizer

In [None]:
# Model: the "base"-size preset of the `moe_mla_transformer` architecture.
config = get_pretrained_config(
    model_size="base",
    model_type="moe_mla_transformer"
)
model = create_transformer_model(config)

# Tokenizer: start from GPT-2's BPE vocabulary. "<pad>", "<sep>" and "<cls>" are
# not part of GPT-2's vocabulary, so each is appended with a brand-new token id.
tokenizer = AutoTokenizer.from_pretrained(
    "gpt2",
    pad_token="<pad>"
)
tokenizer.add_special_tokens({
    'sep_token': '<sep>',
    'cls_token': '<cls>'
})

# The embedding table must cover the newly added ids: if the model's vocabulary is
# smaller than len(tokenizer), looking up e.g. <pad> would index out of range.
if hasattr(model, "resize_token_embeddings"):
    model.resize_token_embeddings(len(tokenizer))
else:
    # NOTE(review): this model exposes no resize_token_embeddings(); confirm that its
    # configured vocab size is at least len(tokenizer) before training.
    print(f"Warning: cannot resize embeddings; model vocab must cover {len(tokenizer)} tokens")

## Prepare Dataset

In [None]:
def format_example(example):
    """Turn one raw GSM8K record into a prompt/target text pair.

    Args:
        example: mapping with ``question`` and ``answer`` string fields.

    Returns:
        dict with ``input_text`` (the question wrapped in the step-by-step
        prompt) and ``target_text`` (the reference answer, unchanged).
    """
    question = example['question']
    answer = example['answer']
    prompt = f"Question: {question}\nLet's solve this step by step:\n"
    return dict(input_text=prompt, target_text=answer)

# GSM8K ships only train/test splits. The Trainer uses eval_dataset for periodic
# eval_loss and best-checkpoint selection (load_best_model_at_end), so carve the
# validation set out of *train*; the test split is then never used for model selection.
VAL_FRACTION = 0.05
SPLIT_SEED = 42
MAX_LENGTH = 512

train_val_split = dataset["train"].train_test_split(test_size=VAL_FRACTION, seed=SPLIT_SEED)


def _wrap_split(split):
    """Wrap a raw GSM8K split with the shared tokenizer / length / prompt settings."""
    return GSM8KDataset(
        split,
        tokenizer=tokenizer,
        max_length=MAX_LENGTH,
        format_func=format_example
    )


train_dataset = _wrap_split(train_val_split["train"])
eval_dataset = _wrap_split(train_val_split["test"])

## Training Configuration

In [None]:
# Start a Weights & Biases run; with report_to="wandb" the Trainer logs into it.
wandb.init(project="vishwamai-gsm8k")

training_args = TrainingArguments(
    output_dir="./checkpoints",
    num_train_epochs=3,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir="./logs",
    logging_steps=100,
    # `evaluation_strategy` was renamed to `eval_strategy` (transformers 4.41) and the
    # old keyword was later removed, so current releases reject it with a TypeError.
    eval_strategy="steps",
    eval_steps=500,
    # load_best_model_at_end requires the save schedule to match the eval schedule.
    save_strategy="steps",
    save_steps=500,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    seed=42,  # explicit for reproducibility (matches the library default)
    report_to="wandb"
)

## Training

In [None]:
# The Trainer drives the optimisation loop; evaluation and checkpointing follow training_args.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    # `tokenizer=` is deprecated since transformers 4.46 in favour of `processing_class=`.
    processing_class=tokenizer,
    # NOTE(review): with no data_collator, Trainer falls back to DataCollatorWithPadding,
    # which pads but does not build `labels` -- confirm GSM8KDataset items already carry them.
)

trainer.train()

## Evaluation

In [None]:
def evaluate_sample(question, max_new_tokens=256, num_beams=4):
    """Generate a step-by-step answer for a single GSM8K-style question.

    Args:
        question: raw question text.
        max_new_tokens: cap on generated tokens, not counting the prompt tokens.
        num_beams: beam width for deterministic beam search.

    Returns:
        The first generated sequence, decoded with special tokens stripped.
    """
    # Build the prompt with the same formatter used for training so the two never drift apart.
    input_text = format_example({"question": question, "answer": ""})["input_text"]

    # A plain nn.Module has no `.device` attribute; take the device from its parameters.
    device = next(model.parameters()).device
    inputs = tokenizer(input_text, return_tensors="pt").to(device)

    # Disable dropout etc. -- the model may still be in train mode after trainer.train().
    model.eval()
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            # max_length would count the prompt too and could leave little room for the answer.
            max_new_tokens=max_new_tokens,
            # Beam search without do_sample=True ignores `temperature`, so it is not passed.
            num_beams=num_beams,
            pad_token_id=tokenizer.pad_token_id
        )

    return tokenizer.decode(outputs[0], skip_special_tokens=True)

# Spot-check the trained model on the first two test-split questions.
test_questions = dataset['test'][:2]['question']

for question in test_questions:
    print("Question:", question)
    print("\nGenerated Answer:")
    generated = evaluate_sample(question)
    print(generated)
    print("\n---\n")

## Save Model and Upload to the Hugging Face Hub

In [None]:
# Persist the trainer's current model weights plus the tokenizer files locally.
model_path = "gsm8k_trained_model"
trainer.save_model(model_path)
tokenizer.save_pretrained(model_path)

# Push the saved folder to the Hugging Face Hub. Authentication comes from a cached
# `huggingface-cli login` or the HF_TOKEN environment variable -- never hardcode a token.
HUB_REPO_ID = "VishwamAI/VishwamAI"

api = HfApi()
# upload_folder fails if the target repo does not exist yet; create it idempotently first.
api.create_repo(repo_id=HUB_REPO_ID, repo_type="model", exist_ok=True)
api.upload_folder(
    folder_path=model_path,
    repo_id=HUB_REPO_ID,
    repo_type="model",
    commit_message="Upload VishwamAI model trained on GSM8K"
)