# GEPA for JEE (Math)
Using the `dspy.GEPA` optimizer, we compare the performance of Gemma 27b It and OpenAI OSS 20b on Math questions from JEE, before and after GEPA optimization.

In [None]:
# api_key = input("Enter your OpenAI API key: ")
import dspy
import re

import os

# Credentials are read from the environment so that real keys never have to be
# written into the notebook; the original placeholder strings remain as the
# fallback when the variables are not set.
api_key = os.environ.get("OPENROUTER_API_KEY", "your_api_key")
gemini_api_key = os.environ.get("GEMINI_API_KEY", "your_api_key")

# Model identifier strings; `gemma` is the one used for the main LM below.
gemma = "openrouter/google/gemma-3-27b-it"
qwen = "openrouter/qwen/qwen3-32b"
# NOTE(review): unlike the two ids above, this one has no "openrouter/" prefix —
# confirm it is routed the way you intend.
openai_oss = "openai/gpt-oss-20b"

# LM handed to the TwoStepAdapter created below.
extraction_lm = dspy.LM(
    model="openrouter/openai/gpt-4o-mini",
    api_base="https://openrouter.ai/api/v1",
    api_key=api_key,
    temperature=0.1,
    max_tokens=32000,
)

from dspy.adapters import TwoStepAdapter
adapter = TwoStepAdapter(extraction_lm)

# Main LM: the Gemma model served through the OpenRouter endpoint.
lm = dspy.LM(
    model=gemma,
    api_base="https://openrouter.ai/api/v1",
    api_key=api_key,
    temperature=1,
    max_tokens=32000,
)

# Make the LM and the adapter the process-wide DSPy defaults.
dspy.configure(lm=lm, adapter=adapter)

In [None]:
import dspy

# Define a signature for feedback
class FeedbackSignature(dspy.Signature):
    r"""Give feedback on an incorrect solution to a math problem. It might be incorrect due to the reasoning or the answer formatting,
    solution must be inside \boxed{}. solution must follow this format: problem = dspy.InputField()
    answer = dspy.OutputField()
    """
    # The docstring above is a raw string on purpose: in a normal string literal
    # the `\b` of `\boxed{}` is the backspace escape, which silently corrupts the
    # text (DSPy uses a Signature's docstring as its instructions).
    # NOTE(review): the tail of the docstring ("problem = dspy.InputField() /
    # answer = dspy.OutputField()") looks like pasted code rather than a format
    # description — confirm the intended wording.

    # Inputs: the problem, the solution to critique, and the reference answer.
    question: str = dspy.InputField()
    solution: str = dspy.InputField()
    expected_answer: str = dspy.InputField()
    # Output: the critique text.
    feedback: str = dspy.OutputField(desc="Constructive feedback about the solution")

# LM used to produce feedback in get_llm_feedback.
# NOTE(review): the model id carries no provider prefix — confirm it resolves to
# the backend that `gemini_api_key` belongs to.
gemini_lm = dspy.LM(
    model="gemini-2.5-flash-lite",
    api_key=gemini_api_key,
    max_tokens=32000,
    temperature=1.0,
)

# Function to get feedback
def get_llm_feedback(question, solution, expected_answer):
    """Ask the feedback LM (``gemini_lm``) to critique a solution to a math problem.

    Args:
        question: The problem statement.
        solution: The solution text to be critiqued.
        expected_answer: The reference answer for the problem.

    Returns:
        The ``feedback`` field produced for ``FeedbackSignature`` (an empty
        string if the LM returned no feedback).
    """
    # The LM is chosen by the dspy.context block below. It is intentionally not
    # passed to the constructor: extra constructor keyword arguments are kept as
    # predictor configuration instead of selecting the LM.
    feedback_module = dspy.ChainOfThought(FeedbackSignature)

    with dspy.context(lm=gemini_lm):
        result = feedback_module(
            question=question,
            solution=solution,
            expected_answer=expected_answer,
        )

    # Fall back to "" so the preview slice and callers that concatenate the
    # feedback do not fail on a missing value.
    feedback = result.feedback or ""

    # Only a short preview is printed to keep the notebook output readable.
    print("Feedback:", feedback[:100])

    return feedback


In [None]:
!pip install google-auth

In [None]:
# Smoke test: request feedback for a trivial question / solution / expected-answer triple.
get_llm_feedback("How much is 2 + 2?", "The answer is 4.", "4")  # Example usage

In [None]:
# Quick check of the configured LM using an inline string signature whose
# `answer` output is typed as float.
# NOTE(review): the name `math` would shadow the standard-library module if that
# module is imported in the same session.
math = dspy.ChainOfThought("question -> answer: float")
math(question="Two dice are tossed. What is the probability that the sum equals 2?")

### Loading the JEE dataset

In [None]:
import dspy
from datasets import load_dataset

def _to_examples(rows):
    """Convert dataset rows into ``dspy.Example`` objects with ``problem`` as the only input."""
    return [
        dspy.Example({
            "problem": row['question'],
            "answer": row['answer'],
        }).with_inputs("problem")
        for row in rows
    ]


def init_jee_dataset(test_size=0.25, seed=42):
    """Load the JEE Main 2025 math questions and split them into train/val/test.

    The ``"apr"`` configuration's ``test`` split is loaded and divided with
    ``train_test_split``; the held-out part is then cut in half, the first half
    becoming the test set and the second half the validation set.

    Args:
        test_size: Fraction of the data held out for validation + test.
        seed: Seed passed to ``train_test_split`` so the split is reproducible.

    Returns:
        A ``(train_set, val_set, test_set)`` tuple of lists of ``dspy.Example``.
    """
    dataset = load_dataset("PhysicsWallahAI/JEE-Main-2025-Math", "apr")['test']

    split = dataset.train_test_split(test_size=test_size, seed=seed)
    held_out = split['test']

    # First half of the held-out rows -> test, remaining rows -> validation.
    midpoint = int(0.5 * len(held_out))
    test_rows = held_out.select(range(midpoint))
    val_rows = held_out.select(range(midpoint, len(held_out)))

    train_set = _to_examples(split['train'])
    val_set = _to_examples(val_rows)
    test_set = _to_examples(test_rows)

    return train_set, val_set, test_set

# Example usage
# Build the three splits once and report how many examples each one holds.
train, val, test = init_jee_dataset()
for size_label, split_examples in (
    ("Train size:", train),
    ("Val size:", val),
    ("Test size:", test),
):
    print(size_label, len(split_examples))


### Defining the program: `dspy.ChainOfThought`

In [None]:
# Signature of the solver program: one input (`problem`) and one output (`answer`).
class GenerateResponse(dspy.Signature):
    """Solve the problem and provide the answer in the correct format."""
    problem = dspy.InputField()  # problem statement (the examples' only input field)
    answer = dspy.OutputField()  # final answer; the metrics compare it as a string

# Two independent ChainOfThought instances built from the same signature.
# NOTE(review): `program_openai` is not used in the cells visible here; presumably
# it is the copy that is run with the OpenAI OSS model — confirm.
program = dspy.ChainOfThought(GenerateResponse)
program_openai = dspy.ChainOfThought(GenerateResponse)

### Defining the evaluation metric
We simply check exact match between the predicted answer and the correct answer.

In [None]:
def metric(example, prediction, trace=None, pred_name=None, pred_trace=None):
    """Return 1 if the predicted answer matches the reference answer, else 0.

    Both answers are converted to strings and compared after stripping
    surrounding whitespace, so a stray space or newline in the model output
    does not count as a mismatch. A prediction without an ``answer`` (missing
    attribute or ``None``) scores 0.

    Args:
        example: Mapping-like object whose ``'answer'`` entry is the reference.
        prediction: Object exposing the model output as ``.answer``.
        trace, pred_name, pred_trace: Accepted for interface compatibility; unused.

    Returns:
        ``1`` for a match, ``0`` otherwise.
    """
    predicted = getattr(prediction, 'answer', None)
    if predicted is None:
        return 0

    correct_answer = str(example['answer']).strip()
    return int(correct_answer == str(predicted).strip())

### Optimize the program with `dspy.GEPA`

In [None]:
#Experimental: adding an LLM-as-judge to provide feedback
def metric_with_feedback_llm(example, prediction, trace=None, pred_name=None, pred_trace=None):
    """Score a prediction and, when it is wrong, attach feedback written by the judge LM.

    The judge is given the model's own attempt (its ``reasoning``, when the
    prediction has one, followed by its final answer) together with the
    problem and the expected answer. Correct predictions get a short fixed
    message and do not trigger a judge call.

    Args:
        example: Example providing ``'problem'`` and ``'answer'``.
        prediction: Object exposing the model output as ``.answer`` and,
            optionally, ``.reasoning``.
        trace, pred_name, pred_trace: Accepted for interface compatibility; unused.

    Returns:
        ``dspy.Prediction`` with ``score`` (0 or 1) and ``feedback`` (str).
    """
    correct_answer = str(example['answer']).strip()
    raw_answer = getattr(prediction, 'answer', None)
    llm_answer = '' if raw_answer is None else str(raw_answer).strip()

    if raw_answer is not None and llm_answer == correct_answer:
        # FeedbackSignature asks for feedback on an *incorrect* solution, so a
        # correct answer is not sent to the judge.
        return dspy.Prediction(
            score=1,
            feedback=f"Your answer is correct. The correct answer is '{correct_answer}'.",
        )

    # Assemble what the model actually produced; this is what the judge critiques.
    attempt_parts = []
    reasoning = getattr(prediction, 'reasoning', None)
    if reasoning:
        attempt_parts.append(str(reasoning))
    attempt_parts.append(f"Final answer: {llm_answer}")
    attempt = "\n\n".join(attempt_parts)

    feedback_text = get_llm_feedback(example['problem'], attempt, correct_answer)

    return dspy.Prediction(score=0, feedback=feedback_text)

In [None]:
def metric_with_feedback(example, prediction, trace=None, pred_name=None, pred_trace=None):
    """Score a prediction and attach textual feedback for GEPA's reflection step.

    Answers are compared as strings after stripping surrounding whitespace.
    The feedback always states the correct answer and, when the example
    carries a ``'solution'`` entry, appends that reference solution.
    ``metric_with_feedback_llm`` is the experimental alternative that obtains
    the feedback from a judge LM instead.

    Args:
        example: Example providing ``'answer'`` and, optionally, ``'solution'``.
        prediction: Object exposing the model output as ``.answer``.
        trace, pred_name, pred_trace: Accepted for interface compatibility; unused.

    Returns:
        ``dspy.Prediction`` with ``score`` (0 or 1) and ``feedback`` (str).
    """
    correct_answer = str(example['answer']).strip()
    written_solution = example.get('solution', '')
    raw_answer = getattr(prediction, 'answer', None)

    if raw_answer is None:
        # No usable answer at all: explain the expected format and score 0.
        feedback_text = f"The final answer must be a valid integer and nothing else. You responded with '{raw_answer}', which couldn't be parsed as a python integer. Please ensure your answer is a valid integer without any additional text or formatting."
        feedback_text += f" The correct answer is '{correct_answer}'."
        if written_solution:
            feedback_text += f" Here's the full step-by-step solution:\n{written_solution}\n\nThink about what takeaways you can learn from this solution to improve your future answers and approach to similar problems and ensure your final answer is a valid integer."
        return dspy.Prediction(score=0, feedback=feedback_text)

    llm_answer = str(raw_answer).strip()
    score = int(correct_answer == llm_answer)

    if score == 1:
        feedback_text = f"Your answer is correct. The correct answer is '{correct_answer}'."
    else:
        feedback_text = f"Your answer is incorrect. The correct answer is '{correct_answer}'."

    if written_solution:
        feedback_text += f" Here's the full step-by-step solution:\n{written_solution}\n\nThink about what takeaways you can learn from this solution to improve your future answers and approach to similar problems."

    return dspy.Prediction(score=score, feedback=feedback_text)

In [None]:
# Extract candidate data with prompts and scores for analysis.
# NOTE(review): `optimized_program_openai` is not defined in the cells visible
# here — it must already exist in the session before this cell is run.
if hasattr(optimized_program_openai, 'detailed_results'):
    results = optimized_program_openai.detailed_results

    # One record per candidate, built from the candidate's first predictor.
    candidate_data = []

    for i, (candidate, score) in enumerate(zip(results.candidates, results.val_aggregate_scores)):
        for name, predictor in candidate.named_predictors():
            prompt = predictor.signature.instructions

            candidate_entry = {
                'candidate_idx': i,
                'score': score,
                'prompt': prompt,
                'prompt_length_chars': len(prompt),
                'prompt_length_words': len(prompt.split()),
                'predictor_name': name
            }

            candidate_data.append(candidate_entry)
            break  # Just take first predictor for each candidate

    # Print summary
    print("📋 GEPA CANDIDATE DATA EXTRACTED")
    print("=" * 50)
    print(f"Total candidates: {len(candidate_data)}")

    if candidate_data:
        # Show best and worst performers
        best_candidate = max(candidate_data, key=lambda x: x['score'])
        worst_candidate = min(candidate_data, key=lambda x: x['score'])

        print("\n🏆 Best candidate:")
        print(f"  Index: {best_candidate['candidate_idx']}")
        print(f"  Score: {best_candidate['score']:.4f}")
        print(f"  Length: {best_candidate['prompt_length_chars']} chars, {best_candidate['prompt_length_words']} words")

        print("\n📉 Worst candidate:")
        print(f"  Index: {worst_candidate['candidate_idx']}")
        print(f"  Score: {worst_candidate['score']:.4f}")
        print(f"  Length: {worst_candidate['prompt_length_chars']} chars, {worst_candidate['prompt_length_words']} words")

        # Show score progression
        scores = [c['score'] for c in candidate_data]
        print(f"\n📈 Score progression: {scores[0]:.3f} → {scores[-1]:.3f} (change: {scores[-1] - scores[0]:+.3f})")

        # The hint must use the record's position in candidate_data, which
        # differs from 'candidate_idx' if a candidate contributed no record.
        best_position = candidate_data.index(best_candidate)
        print("\n💡 Usage examples:")
        print(f"  Best prompt: candidate_data[{best_position}]['prompt']")
        print("  All scores: [c['score'] for c in candidate_data]")
        print("  High performers: [c for c in candidate_data if c['score'] > 0.5]")
    else:
        # max()/min() and scores[0] would raise on an empty list.
        print("No candidates with predictors were found, so there is nothing to summarize.")

else:
    print("No detailed results found. Make sure GEPA was run with track_stats=True")

In [None]:
import json

# Persist the extracted candidate records as JSON; the 4-space indent keeps the
# file human-readable.
with open("data_openai.json", "w") as json_file:
    json_file.write(json.dumps(candidate_data, indent=4))

In [None]:
# Display the extracted candidate records as the cell output.
candidate_data

In [None]:
# Examine candidate_data structure for ancestry tracking
import json

print(f"Total candidates: {len(candidate_data)}")
if candidate_data:
    first_candidate = candidate_data[0]
    print(f"First candidate keys: {list(first_candidate.keys())}")
    print("Sample candidate structure:")
    sample_json = json.dumps(first_candidate, indent=2)
    # Decide on truncation from the text that is actually printed, not from the
    # length of the record's str() form.
    print(sample_json[:1000] + "..." if len(sample_json) > 1000 else sample_json)
else:
    # candidate_data[0] would raise IndexError on an empty list.
    print("candidate_data is empty; nothing to inspect.")

In [None]:
from dspy import GEPA

# Configure GEPA with the feedback-producing metric; a separate Gemini model is
# used as the reflection LM.
optimizer = GEPA(
    metric=metric_with_feedback,
    reflection_lm=dspy.LM(
        model="gemini-2.5-pro",
        api_key=gemini_api_key,
        temperature=1.0,
        max_tokens=32000,
    ),
    reflection_minibatch_size=5,
    track_stats=True,
    num_threads=32,
    #max_full_evals=2,
    auto="light",
)

# Optimize the baseline program on the train split, validating on the val split.
optimized_program = optimizer.compile(program, trainset=train, valset=val)

Here GEPA is effectively precomputing some reasoning to come up with a good plan for future task instances. Because performance improves on the unseen validation set, we expect this prompt to generalize!

In [None]:
import dspy
# Evaluate the baseline (un-optimized) program on the test split, single-threaded,
# with the exact-match metric.
evaluate = dspy.Evaluate(
    metric=metric,
    devset=test,
    num_threads=1,
    provide_traceback=True,
    display_progress=True,
    display_table=True,
)

evaluate(program)

In [None]:
import dspy
# Evaluate the GEPA-optimized program on the test split with the same settings
# that were used for the baseline.
evaluate = dspy.Evaluate(
    metric=metric,
    devset=test,
    num_threads=1,
    provide_traceback=True,
    display_progress=True,
    display_table=True,
)

evaluate(optimized_program)

In [None]:
# After running GEPA with track_stats=True
# This cell reads `optimized_program.detailed_results`, collects the prompt length
# (characters and words) and validation score of every candidate, draws a 2x2
# grid of plots (score vs. length, and length/score per candidate index), and
# prints summary statistics. If the attribute is absent it only prints a hint.
if hasattr(optimized_program, 'detailed_results'):
    results = optimized_program.detailed_results
    
    # Extract data for plotting
    # The four lists are filled in lockstep: one entry per candidate that has at
    # least one predictor.
    prompt_chars = []
    prompt_words = []
    scores = []
    candidate_nums = []
    
    for i, (candidate, score) in enumerate(zip(results.candidates, results.val_aggregate_scores)):
        for name, predictor in candidate.named_predictors():
            prompt = predictor.signature.instructions
            prompt_chars.append(len(prompt))
            prompt_words.append(len(prompt.split()))
            scores.append(score)
            candidate_nums.append(i)
            break  # Just take first predictor
    
    # Create the plots
    import matplotlib.pyplot as plt
    import numpy as np
    
    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))
    
    # Plot 1: Score vs Character Count
    # Points are coloured by candidate index, so the colour encodes evolution order.
    scatter1 = ax1.scatter(prompt_chars, scores, c=candidate_nums, cmap='viridis', alpha=0.7, s=60)
    ax1.set_xlabel('Prompt Length (characters)')
    ax1.set_ylabel('Validation Score')
    ax1.set_title('GEPA: Score vs Prompt Character Count')
    ax1.grid(True, alpha=0.3)
    
    # Add correlation coefficient
    # Only computed when there are at least two candidates; the same condition
    # guards the later use of `correlation_chars` in the printed summary.
    if len(prompt_chars) > 1:
        correlation_chars = np.corrcoef(prompt_chars, scores)[0, 1]
        ax1.text(0.05, 0.95, f'Correlation: {correlation_chars:.3f}', 
                transform=ax1.transAxes, 
                bbox=dict(boxstyle="round", facecolor='wheat', alpha=0.8))
    
    # Add best candidate annotation
    # NOTE(review): `results.best_idx` is used as a position in the lists built
    # above; that assumes every candidate contributed exactly one entry — confirm.
    best_idx = results.best_idx
    best_char_count = prompt_chars[best_idx]
    best_score = scores[best_idx]
    ax1.scatter([best_char_count], [best_score], c='red', s=100, marker='*', 
               label=f'Best (#{best_idx})', edgecolor='black', linewidth=1)
    ax1.legend()
    
    # Plot 2: Score vs Word Count
    scatter2 = ax2.scatter(prompt_words, scores, c=candidate_nums, cmap='viridis', alpha=0.7, s=60)
    ax2.set_xlabel('Prompt Length (words)')
    ax2.set_ylabel('Validation Score')
    ax2.set_title('GEPA: Score vs Prompt Word Count')
    ax2.grid(True, alpha=0.3)
    
    # Add correlation coefficient
    if len(prompt_words) > 1:
        correlation_words = np.corrcoef(prompt_words, scores)[0, 1]
        ax2.text(0.05, 0.95, f'Correlation: {correlation_words:.3f}', 
                transform=ax2.transAxes, 
                bbox=dict(boxstyle="round", facecolor='wheat', alpha=0.8))
    
    # Add best candidate annotation
    best_word_count = prompt_words[best_idx]
    ax2.scatter([best_word_count], [best_score], c='red', s=100, marker='*', 
               label=f'Best (#{best_idx})', edgecolor='black', linewidth=1)
    ax2.legend()
    
    # Plot 3: Evolution Timeline - Character Count
    ax3.plot(candidate_nums, prompt_chars, 'b-o', markersize=4, alpha=0.7)
    ax3.scatter([best_idx], [best_char_count], c='red', s=100, marker='*', 
               label=f'Best (#{best_idx})', edgecolor='black', linewidth=1)
    ax3.set_xlabel('Candidate Number (Evolution Order)')
    ax3.set_ylabel('Prompt Length (characters)')
    ax3.set_title('Prompt Length Evolution Over Time')
    ax3.grid(True, alpha=0.3)
    ax3.legend()
    
    # Plot 4: Evolution Timeline - Scores
    ax4.plot(candidate_nums, scores, 'g-o', markersize=4, alpha=0.7)
    ax4.scatter([best_idx], [best_score], c='red', s=100, marker='*', 
               label=f'Best (#{best_idx})', edgecolor='black', linewidth=1)
    ax4.set_xlabel('Candidate Number (Evolution Order)')
    ax4.set_ylabel('Validation Score')
    ax4.set_title('Score Evolution Over Time')
    ax4.grid(True, alpha=0.3)
    ax4.legend()
    
    # Add colorbar for candidate numbers
    plt.colorbar(scatter1, ax=ax1, label='Candidate #')
    plt.colorbar(scatter2, ax=ax2, label='Candidate #')
    
    plt.tight_layout()
    plt.show()
    
    # Print summary statistics
    print("\n" + "="*60)
    print("GEPA PROMPT EVOLUTION ANALYSIS")
    print("="*60)
    
    print(f"\nPrompt Length Statistics:")
    print(f"Character count range: {min(prompt_chars)} - {max(prompt_chars)} chars")
    print(f"Word count range: {min(prompt_words)} - {max(prompt_words)} words")
    print(f"Average character count: {np.mean(prompt_chars):.1f} chars")
    print(f"Average word count: {np.mean(prompt_words):.1f} words")
    
    print(f"\nScore Statistics:")
    print(f"Score range: {min(scores):.4f} - {max(scores):.4f}")
    print(f"Average score: {np.mean(scores):.4f}")
    print(f"Score std dev: {np.std(scores):.4f}")
    
    print(f"\nBest Candidate:")
    print(f"Candidate #{best_idx}: Score={best_score:.4f}")
    print(f"Length: {best_char_count} chars, {best_word_count} words")
    
    if len(prompt_chars) > 1:
        print(f"\nCorrelations:")
        print(f"Length (chars) vs Score: {correlation_chars:.3f}")
        print(f"Length (words) vs Score: {correlation_words:.3f}")
        
        # Growth analysis
        # Compares the first and the last collected candidate (not the best one).
        initial_chars = prompt_chars[0]
        final_chars = prompt_chars[-1]
        char_growth = final_chars - initial_chars
        # Guard against division by zero when the first prompt is empty.
        char_growth_pct = (char_growth / initial_chars * 100) if initial_chars > 0 else 0
        
        initial_score = scores[0]
        final_score = scores[-1]
        score_change = final_score - initial_score
        
        print(f"\nEvolution Summary:")
        print(f"Character growth: {initial_chars} → {final_chars} ({char_growth:+d} chars, {char_growth_pct:+.1f}%)")
        print(f"Score change: {initial_score:.4f} → {final_score:.4f} ({score_change:+.4f})")
else:
    print("No detailed results found. Make sure to run GEPA with track_stats=True")