## 1. Configuration

Set up the evaluation parameters (from evaluate.sh)

In [1]:
# Configuration parameters from evaluate.sh
MODEL_PATH = "../checkpoints_adapted/checkpoint-214"
TEST_FILE = "../datasets/gsm8k_test.jsonl"
OUTPUT_FILE = "./evaluation_results.jsonl"
MAX_NEW_TOKENS = 1536
TEMPERATURE = 0.0  # Greedy decoding
MAX_EXAMPLES = None  # Limit for testing

print("=" * 80)
print("Step-JEPA Evaluation Configuration")
print("=" * 80)
print(f"Model Path: {MODEL_PATH}")
print(f"Test File: {TEST_FILE}")
print(f"Output File: {OUTPUT_FILE}")
print(f"Max New Tokens: {MAX_NEW_TOKENS}")
print(f"Temperature: {TEMPERATURE}")
print(f"Max Examples: {MAX_EXAMPLES}")
print("=" * 80)

Step-JEPA Evaluation Configuration
Model Path: ../checkpoints_adapted/checkpoint-214
Test File: ../datasets/gsm8k_test.jsonl
Output File: ./evaluation_results.jsonl
Max New Tokens: 1536
Temperature: 0.0
Max Examples: None


## 2. Import Libraries

In [2]:
import json
import re
import torch
from pathlib import Path
from tqdm import tqdm
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer

print("✓ All libraries imported successfully")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

  from .autonotebook import tqdm as notebook_tqdm


✓ All libraries imported successfully
PyTorch version: 2.9.0+cu128
CUDA available: True


## 3. Helper Functions for Answer Extraction and Evaluation

In [3]:
def extract_boxed_answer(text):
    """Extract answer from \\boxed{} format (used in DeepSeek system prompt)"""
    pattern = r'\\boxed\{([^}]+)\}'
    match = re.search(pattern, text)
    if match:
        answer = match.group(1).strip()
        return normalize_answer(answer)
    return None


def extract_hash_answer(text):
    """Extract answer from #### format (GSM8K standard format)"""
    pattern = r'\n#### (.+)$'
    match = re.search(pattern, text)
    if match:
        answer = match.group(1).strip()
        return normalize_answer(answer)
    return None


def extract_final_number(text):
    """Try to extract the last number from the text as a fallback"""
    numbers = re.findall(r'[-+]?(?:\d*\.*\d+)', text)
    if numbers:
        return normalize_answer(numbers[-1])
    return None


def normalize_answer(answer):
    """Normalize answer for comparison"""
    if answer is None:
        return None
    
    # Remove common text patterns
    answer = answer.replace('$', '').replace(',', '').strip()
    
    # Try to convert to number and normalize
    try:
        num = float(answer)
        # If it's a whole number, return as int
        if num.is_integer():
            return str(int(num))
        else:
            # Round to reasonable precision
            return f"{num:.10g}"
    except (ValueError, TypeError):
        # If not a number, return cleaned string
        return answer.strip()


def extract_answer_from_generated(generated_text):
    """Extract answer from generated text - try multiple formats"""
    # Try boxed format first (from DeepSeek system prompt)
    answer = extract_boxed_answer(generated_text)
    if answer is not None:
        return answer
    
    # Try #### format (GSM8K standard)
    answer = extract_hash_answer(generated_text)
    if answer is not None:
        return answer
    
    # Fallback: try to extract last number
    answer = extract_final_number(generated_text)
    return answer


def eval_gsm8k(generated, ground_truth):
    """
    Evaluate GSM8K answer.
    
    Args:
        generated: Generated response text
        ground_truth: Ground truth in GSM8K format (with ####)
    
    Returns:
        (is_correct, gt_answer, gen_answer)
    """
    # Extract ground truth answer
    gt_answer = extract_hash_answer(ground_truth)
    
    # Extract generated answer
    gen_answer = extract_answer_from_generated(generated)
    
    # Compare
    is_correct = (gt_answer is not None and 
                  gen_answer is not None and 
                  gt_answer == gen_answer)
    
    return is_correct, gt_answer, gen_answer


print("✓ Helper functions defined:")
print("  - extract_boxed_answer()")
print("  - extract_hash_answer()")
print("  - extract_final_number()")
print("  - normalize_answer()")
print("  - extract_answer_from_generated()")
print("  - eval_gsm8k()")

✓ Helper functions defined:
  - extract_boxed_answer()
  - extract_hash_answer()
  - extract_final_number()
  - normalize_answer()
  - extract_answer_from_generated()
  - eval_gsm8k()


## 4. Load Test Data

In [4]:
print(f"Loading test data from: {TEST_FILE}")

test_data = []
with open(TEST_FILE, 'r') as f:
    for line in f:
        example = json.loads(line.strip())
        test_data.append(example)

print(f"✓ Loaded {len(test_data)} test examples")

# Limit to MAX_EXAMPLES
if MAX_EXAMPLES is not None:
    test_data = test_data[:MAX_EXAMPLES]
    print(f"✓ Limited to first {MAX_EXAMPLES} examples for testing")

print(f"\nTotal examples to evaluate: {len(test_data)}")

Loading test data from: ../datasets/gsm8k_test.jsonl
✓ Loaded 1319 test examples

Total examples to evaluate: 1319


## 5. Load Model with vLLM (Fast Inference)

In [None]:
model_path = Path(MODEL_PATH)

print(f"Loading model with vLLM from: {model_path}")
print("=" * 80)

# Load tokenizer
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_path)

# Load model with vLLM for fast inference
print("Loading model with vLLM for optimized inference...")
llm = LLM(
    model=str(model_path),
    # dtype="bfloat16",
    tensor_parallel_size=1,  # Adjust based on number of GPUs
    gpu_memory_utilization=0.9,
    max_model_len=4096,  # Adjust based on your needs
    trust_remote_code=True
)

print(f"\n✓ Model loaded successfully with vLLM")
print("  vLLM provides highly optimized inference with PagedAttention")
print("=" * 80)

Loading model with vLLM from: ../checkpoints_adapted/checkpoint-214
Loading tokenizer...


The tokenizer you are loading from '../checkpoints_adapted/checkpoint-214' with an incorrect regex pattern: https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503/discussions/84#69121093e8b480e709447d5e. This will lead to incorrect tokenization. You should set the `fix_mistral_regex=True` flag when loading this tokenizer to fix this issue.


Loading model with vLLM for optimized inference...
INFO 01-05 13:27:09 [utils.py:253] non-default args: {'trust_remote_code': True, 'max_model_len': 4096, 'disable_log_stats': True, 'model': '../checkpoints_adapted/checkpoint-214'}


The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.


INFO 01-05 13:27:09 [model.py:514] Resolved architecture: LlamaForCausalLM
INFO 01-05 13:27:09 [model.py:1661] Using max model len 4096


2026-01-05 13:27:10,059	INFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.


INFO 01-05 13:27:10 [scheduler.py:230] Chunked prefill is enabled with max_num_batched_tokens=16384.


The tokenizer you are loading from '../checkpoints_adapted/checkpoint-214' with an incorrect regex pattern: https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503/discussions/84#69121093e8b480e709447d5e. This will lead to incorrect tokenization. You should set the `fix_mistral_regex=True` flag when loading this tokenizer to fix this issue.


[0;36m(EngineCore_DP0 pid=15013)[0;0m INFO 01-05 13:27:11 [core.py:93] Initializing a V1 LLM engine (v0.13.0) with config: model='../checkpoints_adapted/checkpoint-214', speculative_config=None, tokenizer='../checkpoints_adapted/checkpoint-214', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=True, dtype=torch.bfloat16, max_seq_len=4096, download_dir=None, load_format=auto, tensor_parallel_size=1, pipeline_parallel_size=1, data_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, device_config=cuda, structured_outputs_config=StructuredOutputsConfig(backend='auto', disable_fallback=False, disable_any_whitespace=False, disable_additional_properties=False, reasoning_parser='', reasoning_parser_plugin='', enable_in_reasoning=False), observability_config=ObservabilityConfig(show_hidden_metrics_for_version=None, otlp_traces_endpoint=None, collect_detailed_traces=None, kv_cach

Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:01<00:00,  1.25s/it]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:01<00:00,  1.25s/it]
[0;36m(EngineCore_DP0 pid=15013)[0;0m 


[0;36m(EngineCore_DP0 pid=15013)[0;0m INFO 01-05 13:27:16 [default_loader.py:308] Loading weights took 1.31 seconds
[0;36m(EngineCore_DP0 pid=15013)[0;0m INFO 01-05 13:27:17 [gpu_model_runner.py:3659] Model loading took 2.3185 GiB memory and 1.982397 seconds


In [None]:
new_system_prompt = "Please solve the problem step by step (separate steps with double newlines), but keep it short and put your final answer (do not include any other text or units) within \\boxed{}."

## 6. Define Generation Function (vLLM)

In [None]:
def generate_response(llm, tokenizer, messages, max_new_tokens=1536, temperature=0.0):
    """Generate response for a given prompt using vLLM"""
    # Format the conversation
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    
    # Set up sampling parameters
    sampling_params = SamplingParams(
        temperature=temperature if temperature > 0 else 0.0,
        max_tokens=max_new_tokens,
        top_p=1.0 if temperature == 0 else 0.95,
    )
    
    # Generate with vLLM
    outputs = llm.generate([prompt], sampling_params)
    
    # Extract generated text
    generated_text = outputs[0].outputs[0].text
    
    return generated_text.strip()

print("✓ generate_response() function defined (vLLM-based)")

✓ generate_response() function defined (vLLM-based)


## 7. Test Generation on First Example

In [None]:
# Test on first example
test_example = test_data[0]
test_messages = test_example["messages"]
ground_truth = test_messages[-1]["content"]
input_messages = test_messages[:-1]  # Exclude assistant's answer
input_messages[0]["content"] = new_system_prompt # Update system prompt to match training

print("Testing generation on first example...")
print("=" * 80)
print("QUESTION:")
print(input_messages[1]["content"])
print("\n" + "=" * 80)

# Generate response
generated_response = generate_response(
    llm, tokenizer, input_messages,
    max_new_tokens=MAX_NEW_TOKENS,
    temperature=TEMPERATURE
)

print("GENERATED RESPONSE:")
print(generated_response)
print("\n" + "=" * 80)
print("GROUND TRUTH:")
print(ground_truth)
print("=" * 80)

Testing generation on first example...
QUESTION:
Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?



Adding requests: 100%|██████████| 1/1 [00:00<00:00, 603.84it/s]
Processed prompts: 100%|██████████| 1/1 [00:00<00:00,  1.78it/s, est. speed input: 248.23 toks/s, output: 432.14 toks/s]

GENERATED RESPONSE:
First, I need to determine how many eggs Janet has left after she eats three for breakfast. She lays 16 eggs per day and eats 3, so she has 13 eggs remaining.

Next, she bakes muffins for her friends every day, using 4 eggs. Subtracting the 4 eggs used for baking from the remaining 13 eggs, she has 9 eggs left.

Finally, she sells each of these 9 eggs for $2, so she makes 9 multiplied by $2, which equals $18 per day.
</think>

1. **Total eggs laid per day:**  
   Janet lays 16 eggs per day.  
   She eats 3 eggs for breakfast.  
   Remaining eggs = 16 - 3 = 13

2. **Eggs used for baking muffins:**  
   She bakes muffins for 4 eggs per day.  
   Remaining eggs = 13 - 4 = 9

3. **Earnings from selling eggs:**  
   She sells each egg for \$2.  
   Total earnings = 9 eggs × \$2 = \$18

**Final Answer:**  
\boxed{18}

GROUND TRUTH:
Janet sells 16 - 3 - 4 = <<16-3-4=9>>9 duck eggs a day.
She makes 9 * 2 = $<<9*2=18>>18 every day at the farmer’s market.
#### 18





### Evaluate Test Example

In [None]:
# Evaluate the test example
is_correct, gt_answer, gen_answer = eval_gsm8k(generated_response, ground_truth)

print("EVALUATION RESULT:")
print("=" * 80)
print(f"Ground Truth Answer: {gt_answer}")
print(f"Generated Answer: {gen_answer}")
print(f"Correct: {'✓ YES' if is_correct else '✗ NO'}")
print("=" * 80)

EVALUATION RESULT:
Ground Truth Answer: 18
Generated Answer: 18
Correct: ✓ YES


## 8. Run Full Evaluation with Batch Inference

Single pass: prepare prompts → batch generate → evaluate

In [None]:
print(f"Running batch evaluation on {len(test_data)} examples...")
print("=" * 80)

# Prepare all prompts
all_prompts = []
ground_truths = []
questions = []

for example in test_data:
    messages = example["messages"]
    ground_truth = messages[-1]["content"]
    input_messages = messages[:-1]  # Exclude assistant's answer

    input_messages[0]["content"] = new_system_prompt # Update system prompt to match training
    
    prompt = tokenizer.apply_chat_template(
        input_messages,
        tokenize=False,
        add_generation_prompt=True
    )
    
    all_prompts.append(prompt)
    ground_truths.append(ground_truth)
    questions.append(messages[1]["content"])

# Batch generate with vLLM (single call for all prompts)
print(f"Batch generating {len(all_prompts)} responses with vLLM...")
sampling_params = SamplingParams(
    temperature=TEMPERATURE if TEMPERATURE > 0 else 0.0,
    max_tokens=MAX_NEW_TOKENS,
    top_p=1.0 if TEMPERATURE == 0 else 0.95,
)
outputs = llm.generate(all_prompts, sampling_params)
print(f"✓ Generation complete\n")

# Evaluate and save results
results = []
correct_count = 0
total_count = 0

with open(OUTPUT_FILE, 'w') as f:
    for idx, (output, ground_truth, question) in enumerate(tqdm(
        zip(outputs, ground_truths, questions), 
        total=len(outputs),
        desc="Evaluating"
    )):
        try:
            # Extract generated text
            generated_response = output.outputs[0].text.strip()
            
            # Evaluate
            is_correct, gt_answer, gen_answer = eval_gsm8k(generated_response, ground_truth)
            
            if is_correct:
                correct_count += 1
            total_count += 1
            
            # Compute accuracy so far
            accuracy = correct_count / total_count * 100
            
            # Create result entry
            result = {
                "index": idx,
                "question": question,
                "ground_truth": ground_truth,
                "generated_response": generated_response,
                "gt_answer": gt_answer,
                "gen_answer": gen_answer,
                "correct": is_correct,
                "accuracy_so_far": accuracy
            }
            results.append(result)
            
            # Write to file
            f.write(json.dumps(result) + '\n')
            f.flush()
            
            # Print progress every 10 examples
            if (idx + 1) % 10 == 0:
                print(f"After {idx + 1} examples: {correct_count}/{total_count} correct ({accuracy:.2f}%)")
            
        except Exception as e:
            print(f"\n❌ Error at index {idx}: {e}")
            result = {
                "index": idx,
                "question": question,
                "error": str(e),
                "correct": False
            }
            results.append(result)
            f.write(json.dumps(result) + '\n')
            f.flush()

print("\n" + "=" * 80)
print("✓ Batch evaluation complete!")
print(f"Results saved to: {OUTPUT_FILE}")
print("=" * 80)

Running batch evaluation on 1319 examples...
Batch generating 1319 responses with vLLM...


Adding requests: 100%|██████████| 1319/1319 [00:00<00:00, 1490.77it/s]
Processed prompts: 100%|██████████| 1319/1319 [00:20<00:00, 64.89it/s, est. speed input: 8693.91 toks/s, output: 19344.37 toks/s] 


✓ Generation complete



Evaluating: 100%|██████████| 1319/1319 [00:00<00:00, 22904.89it/s]

After 10 examples: 3/10 correct (30.00%)
After 20 examples: 5/20 correct (25.00%)
After 30 examples: 8/30 correct (26.67%)
After 40 examples: 15/40 correct (37.50%)
After 50 examples: 21/50 correct (42.00%)
After 60 examples: 24/60 correct (40.00%)
After 70 examples: 27/70 correct (38.57%)
After 80 examples: 30/80 correct (37.50%)
After 90 examples: 35/90 correct (38.89%)
After 100 examples: 38/100 correct (38.00%)
After 110 examples: 42/110 correct (38.18%)
After 120 examples: 47/120 correct (39.17%)
After 130 examples: 50/130 correct (38.46%)
After 140 examples: 56/140 correct (40.00%)
After 150 examples: 59/150 correct (39.33%)
After 160 examples: 62/160 correct (38.75%)
After 170 examples: 66/170 correct (38.82%)
After 180 examples: 72/180 correct (40.00%)
After 190 examples: 76/190 correct (40.00%)
After 200 examples: 79/200 correct (39.50%)
After 210 examples: 83/210 correct (39.52%)
After 220 examples: 86/220 correct (39.09%)
After 230 examples: 92/230 correct (40.00%)
After 240




## 9. Display Final Results

In [None]:
# Final statistics
print("=" * 80)
print("FINAL EVALUATION RESULTS")
print("=" * 80)
print(f"Total examples: {total_count}")
print(f"Correct: {correct_count}")
print(f"Incorrect: {total_count - correct_count}")
print(f"Accuracy: {correct_count / total_count * 100:.2f}%")
print("=" * 80)

# Save summary
summary = {
    "model_path": MODEL_PATH,
    "test_file": TEST_FILE,
    "total_examples": total_count,
    "correct": correct_count,
    "accuracy": correct_count / total_count if total_count > 0 else 0,
    "max_new_tokens": MAX_NEW_TOKENS,
    "temperature": TEMPERATURE,
}

summary_file = OUTPUT_FILE.replace('.jsonl', '_summary.json')
with open(summary_file, 'w') as f:
    json.dump(summary, f, indent=2)

print(f"\nSummary saved to: {summary_file}")

## 10. Analyze Results

Let's look at some correct and incorrect examples.

In [None]:
# Separate correct and incorrect results
correct_results = [r for r in results if r.get("correct", False)]
incorrect_results = [r for r in results if not r.get("correct", False)]

print(f"Correct examples: {len(correct_results)}")
print(f"Incorrect examples: {len(incorrect_results)}")

# Show first few correct examples
print("\n" + "=" * 80)
print("SAMPLE CORRECT EXAMPLES:")
print("=" * 80)
for i, result in enumerate(correct_results[:3]):
    print(f"\n--- Correct Example {i+1} ---")
    print(f"Question: {result['question'][:100]}...")
    print(f"GT Answer: {result['gt_answer']}")
    print(f"Gen Answer: {result['gen_answer']}")

# Show first few incorrect examples
print("\n" + "=" * 80)
print("SAMPLE INCORRECT EXAMPLES:")
print("=" * 80)
for i, result in enumerate(incorrect_results[:3]):
    print(f"\n--- Incorrect Example {i+1} ---")
    if 'question' in result:
        print(f"Question: {result['question'][:100]}...")
    if 'gt_answer' in result:
        print(f"GT Answer: {result['gt_answer']}")
    if 'gen_answer' in result:
        print(f"Gen Answer: {result['gen_answer']}")
    if 'error' in result:
        print(f"Error: {result['error']}")

### Detailed Inspection of One Incorrect Example

In [None]:
# Pick one incorrect example for detailed inspection
if incorrect_results:
    example = incorrect_results[0]
    
    print("DETAILED INCORRECT EXAMPLE:")
    print("=" * 80)
    print("\nQUESTION:")
    print(example.get('question', 'N/A'))
    
    print("\n" + "=" * 80)
    print("GROUND TRUTH FULL RESPONSE:")
    print(example.get('ground_truth', 'N/A'))
    
    print("\n" + "=" * 80)
    print("GENERATED RESPONSE:")
    print(example.get('generated_response', 'N/A'))
    
    print("\n" + "=" * 80)
    print("EXTRACTED ANSWERS:")
    print(f"  Ground Truth: {example.get('gt_answer', 'N/A')}")
    print(f"  Generated: {example.get('gen_answer', 'N/A')}")
    print("=" * 80)
else:
    print("No incorrect examples found!")

## 11. Accuracy Over Time

Plot how accuracy evolved during evaluation.

In [None]:
import matplotlib.pyplot as plt

# Extract accuracy over time
accuracies = [r.get('accuracy_so_far', 0) for r in results if 'accuracy_so_far' in r]
indices = list(range(1, len(accuracies) + 1))

# Plot
plt.figure(figsize=(12, 6))
plt.plot(indices, accuracies, linewidth=2)
plt.axhline(y=accuracies[-1] if accuracies else 0, color='r', linestyle='--', 
            label=f'Final Accuracy: {accuracies[-1]:.2f}%' if accuracies else 'N/A')
plt.xlabel('Number of Examples Evaluated')
plt.ylabel('Accuracy (%)')
plt.title('Model Accuracy Over Evaluation')
plt.grid(True, alpha=0.3)
plt.legend()
plt.tight_layout()
plt.show()

print(f"Final accuracy: {accuracies[-1]:.2f}%" if accuracies else "No data")