# VishwamAI GSM8k Math Training Integration

This notebook demonstrates how to train VishwamAI models on mathematical reasoning tasks from the GSM8k dataset. It uses the `vishwamai.deepthinking` components: Chain-of-Thought generation with self-reflection, and a GRPO trainer.

In [None]:
import re
from dataclasses import dataclass
from functools import partial
from typing import Optional

import torch
from datasets import load_dataset
from torch.utils.data import DataLoader

from vishwamai.training import VishwamaiTrainer
from vishwamai.conceptual_tokenizer import ConceptualTokenizer, ConceptualTokenizerConfig
from vishwamai.model import VishwamaiConfig, VishwamaiModel
from vishwamai.generate import VishwamaiGenerator, GenerationConfig
from vishwamai.deepthinking import CoTGenerationWrapper, GRPOTrainer, ReasoningDataset, create_format_reward_fn

## 1. Configuration Setup

Initialize the model and tokenizer configurations.

In [None]:
# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Reproducibility: torch.manual_seed seeds torch's RNG on all devices. That covers
# DataLoader shuffling below and, presumably, the model's weight initialisation.
SEED = 42
torch.manual_seed(SEED)

# Vocabulary size shared by the model and tokenizer configs; the two must agree.
VOCAB_SIZE = 32000

# Model configuration
model_config = VishwamaiConfig(
    hidden_size=768,
    vocab_size=VOCAB_SIZE,
    num_attention_heads=12,
    num_hidden_layers=12,
    num_key_value_heads=12,
    intermediate_size=3072,
    # NOTE(review): max_position_embeddings (512) is smaller than max_seq_len (2048).
    # Which limit VishwamaiModel actually enforces is not visible here. TODO confirm.
    max_position_embeddings=512,
    max_seq_len=2048,
    n_routed_experts=4,
    n_activated_experts=2,
    rope_theta=10000.0,
    layer_norm_eps=1e-5,
)

# Tokenizer configuration
tokenizer_config = ConceptualTokenizerConfig(
    vocab_size=VOCAB_SIZE,  # Match model vocab size
    max_length=512,
)

## 2. Load and Prepare Dataset

Load the GSM8k train and test splits from local parquet files in the `gsm8k/` directory.

In [None]:
# Local parquet shard for each GSM8k split.
TRAIN_PARQUET = 'gsm8k/train-00000-of-00001.parquet'
TEST_PARQUET = 'gsm8k/test-00000-of-00001.parquet'

# load_dataset files a single `data_files` entry under its default 'train' split,
# so both files are read back through ['train'].
train_dataset = load_dataset('parquet', data_files=TRAIN_PARQUET)['train']
test_dataset = load_dataset('parquet', data_files=TEST_PARQUET)['train']

for split_label, split_data in (("Train", train_dataset), ("Test", test_dataset)):
    print(f"{split_label} dataset size: {len(split_data)}")

# Show the first training example so the question/answer format is visible.
first_example = train_dataset[0]
for heading, field in (("\nSample question:", 'question'), ("\nSample answer:", 'answer')):
    print(heading)
    print(first_example[field])

## 3. Initialize Tokenizer and Model

Set up the tokenizer and model with the defined configurations.

In [None]:
# Conceptual tokenizer built from tokenizer_config.
tokenizer = ConceptualTokenizer(tokenizer_config)

# Record the target device on the config, build the model, then move it there.
model_config.torch_device = device
model = VishwamaiModel(model_config)
model = model.to(device)

# Chain-of-Thought wrapper around the already device-placed model,
# configured with num_self_reflect_steps=2.
cot_model = CoTGenerationWrapper(model=model, tokenizer=tokenizer, num_self_reflect_steps=2)

print(f"Using device: {device}")

## 4. Define Data Collation

Create a collate function to prepare batches for training.

In [None]:
def math_collate_fn(batch, tokenizer, dataset_type="gsm8k", device: Optional[torch.device] = None):
    """Tokenize a batch of GSM8k rows into tensors for training.

    Args:
        batch: Sequence of dataset rows, each providing 'question' and 'answer' strings.
        tokenizer: Callable invoked as ``tokenizer(texts, padding=True, truncation=True,
            return_tensors='pt')``. Its result must contain 'input_ids' and, for the
            questions, 'concept_ids'.
        dataset_type: Not used by this function. It is kept so existing callers that
            pass it keep working.
        device: Device the returned tensors are moved to. When omitted, CUDA is used
            if available, otherwise CPU.

    Returns:
        Dict with 'input_ids' and 'concept_ids' (from the questions) and 'labels'
        (the answers' 'input_ids'), all on ``device``.

    Note:
        Moving tensors to CUDA inside a collate function is generally only safe with
        the DataLoader's default ``num_workers=0``. With worker processes, keep the
        batch on CPU and move it in the training loop instead.
    """
    if device is None:
        # Same selection rule as the notebook-level ``device``, without relying on a global.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    questions = [item['question'] for item in batch]
    answers = [item['answer'] for item in batch]

    tokenize_kwargs = dict(padding=True, truncation=True, return_tensors='pt')
    tokenized_inputs = tokenizer(questions, **tokenize_kwargs)
    # Answers are padded independently of the questions, so 'labels' may have a
    # different sequence length than 'input_ids'.
    # NOTE(review): padded label positions are not masked out (e.g. with -100).
    # Confirm the loss ignores them.
    tokenized_targets = tokenizer(answers, **tokenize_kwargs)

    return {
        'input_ids': tokenized_inputs['input_ids'].to(device),
        'concept_ids': tokenized_inputs['concept_ids'].to(device),
        'labels': tokenized_targets['input_ids'].to(device),
    }

# Batch size for both loaders; it is also reused as the GRPO group size below.
batch_size = 4  # Default group size
collate_fn = partial(math_collate_fn, tokenizer=tokenizer)

# Both loaders share batch size and collation; only the training split is shuffled.
loader_kwargs = dict(batch_size=batch_size, collate_fn=collate_fn)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
test_loader = DataLoader(test_dataset, **loader_kwargs)

## 5. Training Setup with GRPO

Configure the GRPO trainer with a format reward and a numeric-answer accuracy reward. Then run a short training pass to optimize Chain-of-Thought reasoning.

In [None]:
# Create reward functions
# Format reward built by vishwamai.deepthinking.create_format_reward_fn from the tokenizer.
# It is registered under the 'format' key of the GRPO trainer's reward_fns below.
format_reward = create_format_reward_fn(tokenizer)

# Define custom accuracy reward
# "Answer: <number>" as produced by the model; allows an optional "$", a sign,
# thousands separators and a decimal part.
_ANSWER_PATTERN = re.compile(r'Answer:\s*\$?\s*(-?\d[\d,]*(?:\.\d+)?)')
# Any standalone number, used to read reference answers.
_NUMBER_PATTERN = re.compile(r'-?\d[\d,]*(?:\.\d+)?')


def _to_number(text: str) -> float:
    """Convert a matched numeric string such as '1,234' or '-2.5' to a float."""
    return float(text.replace(',', ''))


def _reference_number(ground_truth: str) -> Optional[float]:
    """Return the final numeric answer of a GSM8k reference solution.

    GSM8k reference answers conventionally end with '#### <number>'. If that marker
    is present, the first number after its last occurrence is used. Otherwise the
    last number in the text is used. Returns None when no number is found.
    """
    _, marker, tail = ground_truth.rpartition('####')
    if marker:
        match = _NUMBER_PATTERN.search(tail)
        return _to_number(match.group(0)) if match else None
    numbers = _NUMBER_PATTERN.findall(ground_truth)
    return _to_number(numbers[-1]) if numbers else None


def math_accuracy_reward(response: str, ground_truth: Optional[str] = None) -> float:
    """Reward a response for stating a numeric final answer, optionally checking it.

    Args:
        response: Generated text, expected to contain "Answer: <number>".
        ground_truth: Optional reference solution (e.g. a GSM8k 'answer' field).
            If omitted, any well-formed "Answer: <number>" earns full reward.

    Returns:
        1.0 if an answer is stated and (when ``ground_truth`` is given) numerically
        equals the reference; 0.0 otherwise, including when the reference contains
        no number.
    """
    answer_match = _ANSWER_PATTERN.search(response)
    if not answer_match:
        return 0.0
    if ground_truth is None:
        # No reference available: format-only check, as before.
        return 1.0
    expected = _reference_number(ground_truth)
    if expected is None:
        return 0.0
    return 1.0 if _to_number(answer_match.group(1)) == expected else 0.0

# Initialize GRPO trainer
grpo_trainer = GRPOTrainer(
    model=model,
    tokenizer=tokenizer,
    reward_fns={
        'format': format_reward,
        'accuracy': math_accuracy_reward,
    },
    gamma=0.99,
    beta=0.1,
    eps_clip=0.2,
    group_size=batch_size,
)

# Number of GRPO updates to run. Each step consumes the next `batch_size` training
# questions, wrapping around the dataset. 1 keeps this demo cell fast; raise it for
# a real run.
# NOTE(review): only questions are passed to train_step, so the accuracy reward has
# no reference answers to compare against. Confirm GRPOTrainer's reward call signature.
NUM_GRPO_STEPS = 1

for step in range(NUM_GRPO_STEPS):
    start = step * batch_size
    sample_prompts = [
        train_dataset[i % len(train_dataset)]['question']
        for i in range(start, start + batch_size)
    ]
    loss = grpo_trainer.train_step(sample_prompts)
    print(f"GRPO Training Loss: {loss:.4f}")

## 6. Test Chain-of-Thought Generation

Test the model's ability to solve math problems with explicit reasoning steps and self-reflection.

In [None]:
# Test on a sample problem
test_question = "Janet's ducks lay 16 eggs per day. She eats three for breakfast every morning and sells the rest at the farmers market daily for $2 per egg. How much money does she make every day at the farmers market?"

print("Question:")
print(test_question)
print("\nGenerated Solution with Chain-of-Thought:")

# Generate in inference mode. eval() disables training-only behaviour such as dropout,
# and no_grad() avoids building an autograd graph. The model's previous train/eval
# mode is restored afterwards, so this cell leaves it unchanged.
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        solution = cot_model.generate(
            test_question,
            # NOTE(review): model_config.max_position_embeddings is 512, so the prompt
            # plus 512 new tokens may exceed it. Confirm which limit the model enforces.
            max_new_tokens=512,
        )
finally:
    model.train(was_training)

# The first entry of the result holds the structured output
# ('thought', 'steps', 'reflections', 'answer').
result = solution[0]

print("\nThought Process:")
print(result['thought'])
print("\nSteps:")
for step in result['steps']:
    print(step)
print("\nReflections:")
for reflection in result['reflections']:
    print(reflection)
print("\nFinal Answer:")
print(result['answer'])