# VishwamAI GSM8K Model Analysis

This notebook analyzes the performance of the VishwamAI model on the GSM8K dataset, focusing on:
1. Step-by-step solution accuracy
2. Reasoning patterns
3. Error analysis
4. Performance comparisons

In [None]:
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from datasets import load_dataset
from transformers import AutoTokenizer
from typing import List, Dict, Any

from vishwamai.model.transformer import create_transformer_model
from vishwamai.utils.visualization import plot_attention_patterns
from vishwamai.utils.profiling import analyze_performance

## Load Model and Data

In [None]:
# Load the tokenizer and the model from the HuggingFace Hub under one shared name.
# NOTE(review): `create_transformer_model` is imported above as a function from
# vishwamai.model.transformer; confirm that it really exposes a
# `.from_pretrained` attribute before relying on this call.
model_name = "VishwamAI/VishwamAI"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = create_transformer_model.from_pretrained(model_name)

# Load the "test" split of the GSM8K "main" configuration and report its size.
test_data = load_dataset("openai/gsm8k", "main", split="test")
print(f"Test set size: {len(test_data)}")

## Step-by-Step Analysis

In [None]:
def analyze_solution_steps(
    question: str,
    generated_answer: str,
    reference_answer: str
) -> Dict[str, Any]:
    """Compare a generated solution with the reference one line at a time.

    Both answers are split on newlines and paired positionally. A pair of
    steps "matches" when the sets of numbers that ``extract_numbers`` returns
    for the two lines are equal.

    Args:
        question: The problem statement. Accepted for interface symmetry;
            it is not used by the comparison.
        generated_answer: The model's solution text, one step per line.
        reference_answer: The reference solution text, one step per line.

    Returns:
        A dict with:
            "num_steps": number of lines in the generated answer.
            "step_accuracy": fraction of paired steps whose number sets match.
            "steps": one dict per paired step with the keys "generated",
                "reference" and "numbers_match".
    """
    # Split into steps
    gen_steps = generated_answer.split('\n')
    ref_steps = reference_answer.split('\n')

    # Extract numbers from each step
    gen_numbers = [extract_numbers(step) for step in gen_steps]
    ref_numbers = [extract_numbers(step) for step in ref_steps]

    # Compare steps. The number lists are zipped together with the steps so
    # that every pair is compared against its own extracted numbers; zip stops
    # at the shorter answer, so surplus steps on either side are not scored.
    step_matches = [
        {
            "generated": gen_step,
            "reference": ref_step,
            "numbers_match": set(gen_nums) == set(ref_nums),
        }
        for gen_step, ref_step, gen_nums, ref_nums in zip(
            gen_steps, ref_steps, gen_numbers, ref_numbers
        )
    ]

    # str.split with an explicit separator always yields at least one element,
    # so step_matches is never empty and the division below is safe.
    matched = sum(s["numbers_match"] for s in step_matches)
    return {
        "num_steps": len(gen_steps),
        "step_accuracy": matched / len(step_matches),
        "steps": step_matches
    }

# Analyze a batch of examples
# Slicing a `datasets.Dataset` (e.g. test_data[:100]) returns a dict of column
# lists, so iterating over the slice would yield column names instead of
# examples. `select` returns a Dataset whose iteration yields one example dict
# per row; `min` keeps the range valid if the split has fewer than 100 rows.
results = []
for example in test_data.select(range(min(100, len(test_data)))):  # Analyze first 100 examples
    generated = generate_answer(model, tokenizer, example["question"])
    analysis = analyze_solution_steps(
        example["question"],
        generated,
        example["answer"]
    )
    results.append(analysis)

# Plot step accuracy distribution
accuracies = [r["step_accuracy"] for r in results]
plt.figure(figsize=(10, 6))
plt.hist(accuracies, bins=20)
plt.title("Distribution of Step-by-Step Accuracy")
plt.xlabel("Accuracy")
plt.ylabel("Count")
plt.show()

## Expert Utilization Analysis

In [None]:
def analyze_expert_usage(model, inputs: List[Dict[str, torch.Tensor]]):
    """Collect router logits for every input and plot the expert heatmap.

    Each element of ``inputs`` is unpacked as keyword arguments into ``model``,
    which is called with ``output_router_logits=True`` and
    ``output_attentions=True``. The ``"router_logits"`` entry of every output
    is gathered, handed to ``analyze_expert_patterns``, and the result is
    passed to ``plot_expert_heatmap``. Nothing is returned.
    """
    # Forward passes are for inspection only, so gradient tracking is disabled.
    with torch.no_grad():
        router_logits_per_input = [
            model(
                **batch,
                output_router_logits=True,
                output_attentions=True,
            )["router_logits"]
            for batch in inputs
        ]

    # Summarize the collected logits, then visualize the summary.
    plot_expert_heatmap(analyze_expert_patterns(router_logits_per_input))
    
# Run analysis
# NOTE(review): slicing a `datasets.Dataset` returns a dict of column lists
# rather than a list of examples — confirm that `prepare_batch` expects that
# layout and returns the list of tensor dicts `analyze_expert_usage` iterates.
inputs = prepare_batch(test_data[:50], tokenizer)
analyze_expert_usage(model, inputs)

## Attention Pattern Analysis

In [None]:
def analyze_attention_patterns(model, example):
    """Plot the attention the model produces for a single example's question.

    The question text in ``example["question"]`` is tokenized with the
    module-level ``tokenizer``, run through ``model`` with both
    ``output_attentions`` and ``output_attention_levels`` enabled, and the
    ``"attentions"`` and ``"attention_levels"`` outputs are passed to
    ``plot_attention_patterns`` together with the token strings of the first
    encoded sequence. Nothing is returned.
    """
    encoded = tokenizer(
        example["question"], return_tensors="pt", padding=True, truncation=True
    )

    # Inference only: no gradients are needed to read attention outputs.
    with torch.no_grad():
        outputs = model(
            **encoded, output_attentions=True, output_attention_levels=True
        )

    # Token strings for the first sequence are passed alongside the outputs.
    plot_attention_patterns(
        outputs["attentions"],
        outputs["attention_levels"],
        tokenizer.convert_ids_to_tokens(encoded["input_ids"][0]),
    )
    
# Analyze a complex example
# `find_complex_example` is not defined in this notebook's visible cells; its
# result is used as a single example with a "question" entry.
complex_example = find_complex_example(test_data)
analyze_attention_patterns(model, complex_example)

## Error Analysis

In [None]:
def categorize_errors(results):
    """Bucket imperfect results by error type and plot the bucket sizes.

    Four categories are tracked: "numerical" (wrong calculations),
    "reasoning" (wrong logic), "steps" (missing/extra steps) and "context"
    (misunderstanding context). Every result whose ``"step_accuracy"`` is
    below 1.0 is filed under the category returned by ``analyze_error_type``.
    A bar chart of the per-category counts is shown; nothing is returned.
    """
    buckets = {
        name: [] for name in ("numerical", "reasoning", "steps", "context")
    }

    # Only results with at least one mismatching step are classified.
    imperfect = (entry for entry in results if entry["step_accuracy"] < 1.0)
    for entry in imperfect:
        buckets[analyze_error_type(entry)].append(entry)

    # Plot error distribution
    sizes = [len(bucket) for bucket in buckets.values()]
    plt.figure(figsize=(10, 6))
    plt.bar(buckets.keys(), sizes)
    plt.title("Error Type Distribution")
    plt.show()
    
# Analyze errors
# Uses the module-level `results` list built by the step-by-step analysis cell.
categorize_errors(results)

## Performance Comparisons

In [None]:
def compare_performance():
    """Plot VishwamAI against stored results for three other models.

    The "VishwamAI" entry holds the module-level ``model`` object itself,
    while the "GPT-3.5", "PaLM" and "Claude" entries hold whatever
    ``load_comparison_results`` returns for the keys "gpt35", "palm" and
    "claude". The mapping and the list of metric names are passed to
    ``plot_model_comparison``. Nothing is returned.
    """
    # Display label -> key understood by load_comparison_results.
    baselines = (
        ("GPT-3.5", "gpt35"),
        ("PaLM", "palm"),
        ("Claude", "claude"),
    )

    models = {"VishwamAI": model}
    for label, key in baselines:
        models[label] = load_comparison_results(key)

    metrics = ["accuracy", "step_accuracy", "reasoning_score", "efficiency_score"]

    # Create comparison plot
    plot_model_comparison(models, metrics)
    
# Run comparison
# Reads the module-level `model` loaded in the first code cell.
compare_performance()

## Conclusion

Summary of findings, organized by the analyses above (record the concrete results for each after running the notebook):
1. Step accuracy distribution
2. Expert specialization patterns
3. Attention level utilization
4. Common error patterns
5. Performance comparison