In [21]:
# Let's test forward hooks.
# load one SAE
# load the model
# print attributes/architecture of the SAE and the model
# feed forward one example to the SAE
# get the reconstruction of the activations based on one feature
# create a hook that replaces the model's activation with the activation - lambda * reconstruction

In [22]:
# Sanity check: print the CUDA version string reported by this PyTorch build.
import torch
print(torch.version.cuda)

12.8


In [23]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from sae_lens import SAE
import torch.nn.functional as F
import os
import re

# Disable gradients for memory efficiency.
# This is a global switch: it stays in effect for every later cell.
torch.set_grad_enabled(False)

# Set device: prefer the GPU when PyTorch can see one, otherwise use the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

Using device: cuda


# Model loading

In [24]:
# Optional 4-bit (NF4) quantization config, currently disabled. To enable it,
# uncomment this block together with the `quantization_config` argument below.
# bnb_config = BitsAndBytesConfig(
#     load_in_4bit=True,
#     bnb_4bit_use_double_quant=True,
#     bnb_4bit_quant_type="nf4",
#     bnb_4bit_compute_dtype=torch.bfloat16
# )

# Load model and tokenizer from the same checkpoint name.
# device_map='auto' leaves the placement of the weights to the loader.
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-2b-it",
    device_map='auto',
    # quantization_config=bnb_config,
)
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b-it")

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

# Utils

In [25]:
def count_tokens(prompt):
    """
    Tokenize a prompt and describe the result token by token.

    Uses the notebook-level `tokenizer`.

    Args:
        prompt: Text to tokenize

    Returns:
        tuple: (token_count, token_strings, token_ids, formatted_string)
            - token_count: How many token IDs the prompt encodes to
              (encoding is done with add_special_tokens=True)
            - token_strings: Each token ID decoded on its own
            - token_ids: The token IDs returned by the tokenizer
            - formatted_string: Space-separated "[index:'token']" entries
    """
    token_ids = tokenizer.encode(prompt, add_special_tokens=True)

    # Decode one ID at a time so the boundaries between tokens stay visible.
    pieces = [tokenizer.decode([single_id]) for single_id in token_ids]

    # Numbered view, e.g. "[0:'<bos>'] [1:'5']".
    formatted = " ".join(
        f"[{position}:'{piece}']" for position, piece in enumerate(pieces)
    )

    return len(token_ids), pieces, token_ids, formatted

# Example usage:
def demo_tokenization(prompt):
    """
    Print how `prompt` is tokenized, using the values returned by `count_tokens`.
    """
    n_tokens, pieces, token_ids, pretty = count_tokens(prompt)

    report_lines = [
        f"Original text: '{prompt}'",
        f"Token count: {n_tokens}",
        f"Token IDs: {token_ids}",
        f"Token strings: {pieces}",
        f"Formatted view: {pretty}",
        "-" * 50,  # separator so consecutive demos are easy to tell apart
    ]
    print("\n".join(report_lines))

# count total number of output tokens in the results
def count_total_output_tokens(results):
    """
    Add up the token counts of all answers in `results`.

    Args:
        results: List of tuples (problem, answer); only the answer is counted

    Returns:
        int: Sum of the per-answer token counts reported by `count_tokens`
    """
    # count_tokens returns (count, strings, ids, formatted); only the count matters here.
    return sum(count_tokens(answer)[0] for _, answer in results)

def calc_correct_answer(problem):
    """
    Compute the reference answer for an arithmetic problem string.

    Every character other than digits, whitespace and the operators + - * /
    is dropped first (so a trailing "=" disappears); what is left is evaluated
    as a Python expression.

    Args:
        problem: Problem string (e.g., "2 + 2")

    Returns:
        str: The evaluated result as a string, or "ERROR" when evaluation fails
             (the failure is also printed)
    """
    expression = re.sub(r'[^\d\s\+\-\*/]', '', problem)

    # NOTE(review): this is eval() on filtered text. The filter leaves only
    # digits, whitespace and + - * /, so names and calls cannot survive it, but
    # operator combinations such as "**" or "//" still can. Only feed it
    # trusted dataset lines.
    try:
        return str(eval(expression))
    except Exception as e:
        print(f"Error evaluating problem '{problem}': {e}")
        return "ERROR"


## generation utils

In [37]:
def batch_generation(prompts, max_new_tokens=10, batch_size=8):
    """
    Greedily generate one completion per prompt, processing prompts in batches.

    Uses the notebook-level `tokenizer`, `model` and `device`.

    Args:
        prompts: List of prompt strings
        max_new_tokens: Maximum number of tokens to generate per prompt
        batch_size: Number of prompts tokenized and generated together

    Returns:
        list[str]: One completion per prompt, in input order. Only the tokens
        that follow the (padded) prompt are decoded; special tokens are skipped
        and surrounding whitespace is stripped.
    """
    all_responses = []

    for i in range(0, len(prompts), batch_size):
        batch_prompts = prompts[i:i + batch_size]

        # Tokenize batch
        inputs = tokenizer(
            batch_prompts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            add_special_tokens=True
        ).to(device)

        # Padding gives every row the same width, so one offset marks where
        # the generated tokens start for the whole batch.
        prompt_width = inputs.input_ids.shape[1]

        with torch.no_grad():
            outputs = model.generate(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_new_tokens=max_new_tokens,
                # Greedy decoding. `temperature` is deliberately not passed:
                # it only applies when sampling, and passing it together with
                # do_sample=False makes transformers warn that the flag is
                # not valid.
                do_sample=False,
                repetition_penalty=1,
                num_beams=1,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )

        # Move outputs to CPU immediately and decode only the new tokens.
        outputs_cpu = outputs.cpu()
        batch_responses = [
            tokenizer.decode(row[prompt_width:], skip_special_tokens=True).strip()
            for row in outputs_cpu
        ]

        # Clean up GPU tensors before the next batch is allocated.
        del outputs
        del inputs
        torch.cuda.empty_cache()

        all_responses.extend(batch_responses)

        # Optional: print progress
        print(f"Processed {min(i + batch_size, len(prompts))}/{len(prompts)} prompts")

    return all_responses

def batch_quiz_model(dataset_path, max_new_tokens=10, batch_size=8, start_batch=0, end_batch=None, prefix="You are a calculator. Answer immediately: ", postfix=""):
    """
    Run the model on a batch-indexed slice of a one-problem-per-line dataset.

    Args:
        dataset_path: Text file with one problem per line (blank lines are ignored)
        max_new_tokens: Forwarded to `batch_generation`
        batch_size: Problems per batch; also the unit in which start_batch and
                    end_batch are counted
        start_batch: First batch to process (0-based)
        end_batch: Batch index to stop before (exclusive); None means "to the end"
        prefix: Text placed before every problem
        postfix: Text placed after every problem

    Returns:
        List of tuples (prompt, answer), where prompt includes prefix and postfix

    Raises:
        ValueError: If start_batch or end_batch falls outside the dataset
    """
    with open(dataset_path, 'r') as f:
        problems = [text for text in (raw_line.strip() for raw_line in f) if text]

    n_problems = len(problems)
    # Ceiling division: a final partial batch still counts as a batch.
    total_batches = (n_problems + batch_size - 1) // batch_size

    stop_batch = total_batches if end_batch is None else end_batch

    if not 0 <= start_batch < total_batches:
        raise ValueError(f"start_batch {start_batch} is out of range [0, {total_batches})")
    if not start_batch <= stop_batch <= total_batches:
        raise ValueError(f"end_batch {stop_batch} is out of range [{start_batch}, {total_batches}]")

    # Translate batch indices into problem indices; the last batch may be short.
    first_problem = start_batch * batch_size
    past_last_problem = min(stop_batch * batch_size, n_problems)
    selected = problems[first_problem:past_last_problem]

    prompts = [prefix + problem + postfix for problem in selected]

    print(f"Processing batches {start_batch} to {stop_batch-1} ({len(selected)} problems) in batches of {batch_size}...")
    print(f"Total dataset size: {n_problems} problems ({total_batches} total batches)")

    answers = batch_generation(prompts, max_new_tokens=max_new_tokens, batch_size=batch_size)

    # Pair each prompt (not the bare problem) with the answer it produced.
    results = list(zip(prompts, answers))

    torch.cuda.empty_cache()  # Force GPU memory cleanup
    return results

def save_results(results, output_path, append=True):
    """
    Save results to a tab-separated text file, appending or overwriting.

    Each result is written as one line: "<problem>\\t<answer>".

    Args:
        results: List of tuples (problem, answer)
        output_path: Path to output file; missing parent directories are created
        append: If True, append to existing file. If False, overwrite
    """
    # A bare filename such as "results.txt" has an empty directory part, and
    # os.makedirs("") raises, so only create the directory when there is one.
    parent_dir = os.path.dirname(output_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    mode = 'a' if append else 'w'
    # Model answers can contain arbitrary characters; write UTF-8 explicitly
    # instead of depending on the platform's default encoding.
    with open(output_path, mode, encoding="utf-8") as f:
        f.writelines(f"{problem}\t{answer}\n" for problem, answer in results)

    print(f"Results {'appended to' if append else 'saved to'} {output_path}")

def find_optimal_batch_size(initial_batch_size=256, safety_divisor=4):
    """
    Probe for a batch size that fits in GPU memory.

    Starting at `initial_batch_size`, runs `batch_generation` on dummy prompts
    and doubles the batch size after every run that succeeds. When a run
    raises a CUDA out-of-memory error, the failing size divided by
    `safety_divisor` is returned. The last size that succeeded is half the
    failing size, so the default divisor of 4 returns half of the last
    working size.

    NOTE(review): if the very first attempt already runs out of memory, the
    returned size has never been tested. The probe uses short prompts
    ("5+7=", 6 new tokens), so real prompts may need a smaller batch.

    Args:
        initial_batch_size: First batch size to try
        safety_divisor: The failing batch size is divided by this to get the result

    Returns:
        int: Suggested batch size (at least 1)
    """
    batch_size = initial_batch_size

    while True:
        try:
            # Test with the actual inference function
            test_prompts = ["5+7="] * batch_size
            _ = batch_generation(test_prompts, max_new_tokens=6, batch_size=batch_size)
        except torch.cuda.OutOfMemoryError:
            # Release cached blocks left over from the failed attempt so the
            # caller starts from a clean allocator state.
            torch.cuda.empty_cache()
            return max(1, batch_size // safety_divisor)

        batch_size *= 2  # Double until we hit the memory limit

In [27]:
def benchmark_correct_answers(dataset_path, batch_size=64, start_batch=0, end_batch=None, prefix="", postfix=""):
    """
    Load the same questions as batch_quiz_model and calculate correct answers.

    Args:
        dataset_path: Path to the dataset file (one problem per line; blank lines ignored)
        batch_size: Number of problems per batch (same as used in batch_quiz_model)
        start_batch: Starting batch index (0-based, same as used in batch_quiz_model)
        end_batch: Ending batch index (exclusive, same as used in batch_quiz_model)
        prefix: Text placed before each problem in the returned prompt. Defaults
                to "" so the bare problem is returned; pass the prefix given to
                batch_quiz_model to get identical prompts.
        postfix: Text placed after each problem in the returned prompt. Defaults to "".

    Returns:
        List of tuples (full_prompt, correct_answer). The answer is always
        computed from the bare problem, never from the decorated prompt.

    Raises:
        ValueError: If start_batch or end_batch falls outside the dataset
    """
    # Load the dataset
    with open(dataset_path, 'r') as f:
        problems = [line.strip() for line in f if line.strip()]

    # Ceiling division: a final partial batch still counts as a batch.
    total_batches = (len(problems) + batch_size - 1) // batch_size

    # Set end_batch if not specified
    if end_batch is None:
        end_batch = total_batches

    # Validate batch indices (same validation as batch_quiz_model)
    if start_batch < 0 or start_batch >= total_batches:
        raise ValueError(f"start_batch {start_batch} is out of range [0, {total_batches})")
    if end_batch < start_batch or end_batch > total_batches:
        raise ValueError(f"end_batch {end_batch} is out of range [{start_batch}, {total_batches}]")

    # Calculate problem indices for the specified batch range (same as batch_quiz_model)
    start_idx = start_batch * batch_size
    end_idx = min(end_batch * batch_size, len(problems))
    subset_problems = problems[start_idx:end_idx]

    results = [
        (prefix + problem + postfix, calc_correct_answer(problem))
        for problem in subset_problems
    ]

    print(f"Calculated correct answers for batches {start_batch} to {end_batch-1} ({len(subset_problems)} problems)")

    return results

In [28]:
def extract_first_digits(model_answer, correct_answer):
    """
    Pull the first number out of a model answer, cut to the reference length.

    Two cleanup passes run before the number is searched for:
      1. an echoed question of the form "<digits><op><digits>" (optionally
         followed by "=") is removed;
      2. step numbering such as "1." or "2." is removed. This pass removes any
         digits directly followed by a period, so a number that ends a
         sentence is removed as well.

    Args:
        model_answer: The model's response string
        correct_answer: The correct answer as a string

    Returns:
        str: The first integer found (with its minus sign, if any), truncated
             to len(correct_answer.strip()) characters; "" if no number is left
    """
    without_question = re.sub(r'\d+[\+\-\*/]\d+\s*=?\s*', '', model_answer)
    without_steps = re.sub(r'\b\d+\.\s*', '', without_question)

    # The minus sign of a negative reference answer counts toward the length.
    target_length = len(correct_answer.strip())

    first_match = re.search(r'-?\d+', without_steps)
    if first_match is None:
        return ""

    # Slicing past the end is harmless: a shorter number comes back whole.
    return first_match.group()[:target_length]

In [39]:
def calculate_accuracy(model_results, correct_answers):
    """
    Score model answers against reference answers.

    For each model answer, an echoed question ("<digits><op><digits>", with an
    optional "=") and step numbering ("1.", "2.", ...) are removed, and the
    first integer that remains (minus sign included) is compared, as a string,
    with the stripped reference answer. Answers with no number left count as
    incorrect and are additionally tallied as skipped.

    Args:
        model_results: List of tuples (prompt, model_answer) from batch_quiz_model
        correct_answers: List of tuples (prompt, correct_answer) from benchmark_correct_answers

    Returns:
        tuple: (accuracy, correct_count, total_count, detailed_results, skipped_count)
            - accuracy: correct_count / total_count, or 0.0 when there are no results
            - correct_count: Number of exact matches
            - total_count: Number of questions, skipped ones included
            - detailed_results: List of tuples
              (prompt, correct_answer, model_answer, extracted_number, is_correct);
              extracted_number is "NO_NUMBER_FOUND" for skipped entries
            - skipped_count: Number of answers in which no number was found

    Raises:
        ValueError: If the two input lists differ in length
    """
    if len(model_results) != len(correct_answers):
        raise ValueError(f"Mismatch in result lengths: {len(model_results)} vs {len(correct_answers)}")

    # The same patterns are applied to every answer, so build them once.
    echoed_question = re.compile(r'\d+[\+\-\*/]\d+\s*=?\s*')
    step_marker = re.compile(r'\b\d+\.\s*')
    signed_integer = re.compile(r'-?\d+')

    detailed_results = []
    correct_count = 0
    skipped_count = 0

    for (model_prompt, model_answer), (_, correct_answer) in zip(model_results, correct_answers):
        cleaned_answer = step_marker.sub('', echoed_question.sub('', model_answer))
        first_number = signed_integer.search(cleaned_answer)

        if first_number is None:
            # Nothing to compare: counted as wrong and reported as skipped.
            skipped_count += 1
            extracted_number = "NO_NUMBER_FOUND"
            is_correct = False
        else:
            extracted_number = first_number.group()
            is_correct = extracted_number == correct_answer.strip()
            if is_correct:
                correct_count += 1

        detailed_results.append(
            (model_prompt, correct_answer, model_answer, extracted_number, is_correct)
        )

    total_count = len(detailed_results)
    accuracy = correct_count / total_count if total_count > 0 else 0.0

    return accuracy, correct_count, total_count, detailed_results, skipped_count

def print_accuracy_report(accuracy, correct_count, total_count, detailed_results, skipped_count=0, show_errors_only=False, max_examples=10):
    """
    Print a detailed accuracy report.

    Args:
        accuracy: Accuracy as a float
        correct_count: Number of correct answers
        total_count: Total number of questions
        detailed_results: Detailed results from calculate_accuracy, as tuples
                          (prompt, correct_answer, model_answer, extracted_number, is_correct)
        skipped_count: Number of answers in which no valid number was found
        show_errors_only: If True, only show incorrect answers
        max_examples: Maximum number of examples to show
    """
    incorrect_count = total_count - correct_count
    print(f"Accuracy: {accuracy:.3f} ({correct_count}/{total_count})")
    print(f"Correct: {correct_count}")
    print(f"Incorrect: {incorrect_count} (including {skipped_count} skipped)")
    if skipped_count > 0:
        print(f"  - Actually incorrect: {incorrect_count - skipped_count}")
        print(f"  - Skipped (no valid numbers): {skipped_count}")
    print("-" * 80)

    # Apply the filter first so that the "more examples" count below refers
    # only to rows that would have been shown, not to rows the filter hides.
    if show_errors_only:
        candidates = [row for row in detailed_results if not row[4]]
    else:
        candidates = list(detailed_results)

    # A negative limit shows nothing rather than slicing from the end.
    shown = candidates[:max(max_examples, 0)]

    for prompt, correct, model_answer, extracted, is_correct in shown:
        status = "✓" if is_correct else "✗"
        skipped_indicator = " (SKIPPED)" if extracted == "NO_NUMBER_FOUND" else ""
        print(f"{status} Prompt: {prompt}")
        print(f"  Correct: {correct}")
        print(f"  Model output: '{model_answer}'")
        print(f"  Extracted: '{extracted}'{skipped_indicator}")
        print()

    remaining = len(candidates) - len(shown)
    if remaining > 0:
        print(f"... and {remaining} more examples")

In [30]:
# Always fails. Presumably a deliberate stop so that "Run All" halts here,
# after the helper definitions and before the cells below — TODO confirm.
assert 0

AssertionError: 

In [None]:
import subprocess
import psutil
import shutil

# Optional: Use pynvml if you want more structured GPU info
try:
    from pynvml import nvmlInit, nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo
    pynvml_available = True
except ImportError:
    pynvml_available = False

def show_gpu_usage():
    """
    Print GPU memory usage.

    When pynvml was importable, prints total/used/free memory of GPU 0 in MB
    as reported by NVML. Otherwise prints the raw output of `nvidia-smi`.
    """
    print("=== GPU ===")
    if not pynvml_available:
        print(subprocess.getoutput("nvidia-smi"))
        return

    # Imported here because the notebook's guarded pynvml import does not
    # bring nvmlShutdown into scope.
    from pynvml import nvmlShutdown

    nvmlInit()
    try:
        handle = nvmlDeviceGetHandleByIndex(0)  # GPU 0
        meminfo = nvmlDeviceGetMemoryInfo(handle)
        print(f"Total: {meminfo.total / 1024**2:.2f} MB")
        print(f"Used:  {meminfo.used / 1024**2:.2f} MB")
        print(f"Free:  {meminfo.free / 1024**2:.2f} MB")
    finally:
        # Pair every nvmlInit() with nvmlShutdown(), even when a query fails,
        # so repeated calls do not leave NVML initialised.
        nvmlShutdown()

def show_ram_usage():
    """
    Print system RAM figures from psutil in GB.

    The line labelled "Free" shows psutil's `available` value.
    """
    bytes_per_gb = 1024**3
    print("\n=== RAM ===")
    mem = psutil.virtual_memory()
    rows = (
        ("Total: ", mem.total),
        ("Used:  ", mem.used),
        ("Free:  ", mem.available),
    )
    for label, n_bytes in rows:
        print(f"{label}{n_bytes / bytes_per_gb:.2f} GB")

def show_disk_usage(path="/"):
    """
    Print total, used and free disk space in GB for the filesystem containing `path`.

    Args:
        path: Any path on the filesystem to inspect (defaults to "/")
    """
    bytes_per_gb = 1024**3
    print("\n=== Disk ===")
    usage = shutil.disk_usage(path)
    # disk_usage yields (total, used, free); pair each with its label.
    for label, n_bytes in zip(("Total: ", "Used:  ", "Free:  "), usage):
        print(f"{label}{n_bytes / bytes_per_gb:.2f} GB")

# Benchmark model performance

In [46]:
# Evaluation slice. `batch_size` is both the number of prompts generated
# together and the unit in which the start/end batch indices are counted.
batch_size = 1024 * 4  # Set your desired batch size (4096 prompts per batch)
start_batch = 0  # Starting batch index
end_batch = 3    # Ending batch index (exclusive)

In [None]:
# Print a snapshot of GPU, RAM and disk usage before the benchmark cells run.
show_gpu_usage()
show_ram_usage()
show_disk_usage()

=== GPU ===
Sun Aug 10 14:29:33 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 575.57.08              Driver Version: 575.57.08      CUDA Version: 12.9     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A100 80GB PCIe          On  |   00000000:19:00.0 Off |                   On |
| N/A   30C    P0             62W /  300W |   12652MiB /  81920MiB |     N/A      Default |
|                                         |                        |              Enabled |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA A100 80GB PCIe        

In [32]:
import os
# Disable Triton since no compiler found in container
# NOTE(review): torch is imported in an earlier cell. These variables may only
# take effect if they are set before the code that reads them runs — TODO confirm.
# NOTE(review): the message below is printed unconditionally; nothing in this
# cell verifies that Triton is actually off.
os.environ["TORCHDYNAMO_DISABLE"] = "1"
os.environ["USE_TRITON"] = "0"
print("Triton disabled - using PyTorch's default implementations")

Triton disabled - using PyTorch's default implementations


In [47]:
# Single-dataset run with the bare problem as the prompt (no prefix/postfix).
prefix = ""
postfix = ""
results = batch_quiz_model("./data/random_addition.txt", max_new_tokens=12 ,batch_size=batch_size, start_batch=start_batch, end_batch=end_batch, prefix=prefix, postfix=postfix)

# Reference answers for the same dataset slice (same batch_size and batch range as above).
correct_answers = benchmark_correct_answers("./data/random_addition.txt", batch_size=batch_size, start_batch=start_batch, end_batch=end_batch)

# Calculate accuracy with skipping (answers without any number count as incorrect)
accuracy, correct_count, total_count, detailed_results, skipped_count = calculate_accuracy(results, correct_answers)

# Print report including skipped count; only incorrect answers are listed, at most 10
print_accuracy_report(accuracy, correct_count, total_count, detailed_results, skipped_count, show_errors_only=True, max_examples=10)

Processing batches 0 to 2 (10000 problems) in batches of 4096...
Total dataset size: 10000 problems (3 total batches)


The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Processed 4096/10000 prompts


The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Processed 8192/10000 prompts
Processed 10000/10000 prompts
Calculated correct answers for batches 0 to 2 (10000 problems)
Accuracy: 0.269 (2686/10000)
Correct: 2686
Incorrect: 7314 (including 652 skipped)
  - Actually incorrect: 6662
  - Skipped (no valid numbers): 652
--------------------------------------------------------------------------------
✗ Prompt: 8270+1860=
  Correct: 10130
  Model output: 'Please provide the answer and the steps.

**Answer'
  Extracted: 'NO_NUMBER_FOUND' (SKIPPED)

✗ Prompt: 6734+7265=
  Correct: 13999
  Model output: '14999

14999+'
  Extracted: '14999'

✗ Prompt: 1466+5426=
  Correct: 6892
  Model output: 'Here's how to solve it:

**1'
  Extracted: '1'

✗ Prompt: 6578+9322=
  Correct: 15900
  Model output: '16890

16890-'
  Extracted: '16890'

✗ Prompt: 7949+3433=
  Correct: 11382
  Model output: 'Here's how to solve it:

**1'
  Extracted: '1'

✗ Prompt: 7420+2184=
  Correct: 9604
  Model output: '**Answer:** 

**Answer:**'
  Extracted: 'NO_NUMBER_FOUND'

In [48]:
import json
import os
from datetime import datetime

def _strip_affixes(prompt, prefix, postfix):
    """Remove `prefix` from the start and `postfix` from the end of `prompt`, leaving the middle untouched."""
    if prefix and prompt.startswith(prefix):
        prompt = prompt[len(prefix):]
    if postfix and prompt.endswith(postfix):
        prompt = prompt[:-len(postfix)]
    return prompt


def _write_answers_file(filepath, header_lines, detailed_results, prefix, postfix):
    """Write the run header followed by one entry per evaluated problem."""
    # The status marks (✓/✗) and model output are not ASCII-safe, so the
    # encoding is fixed instead of depending on the platform default.
    with open(filepath, 'w', encoding="utf-8") as f:
        for line in header_lines:
            f.write(line + "\n")
        f.write("=" * 80 + "\n\n")

        for prompt, correct, model_answer, extracted, is_correct in detailed_results:
            status = "✓" if is_correct else "✗"
            skipped_indicator = " (SKIPPED)" if extracted == "NO_NUMBER_FOUND" else ""
            f.write(f"{status} Problem: {_strip_affixes(prompt, prefix, postfix)}\n")
            f.write(f"  Correct: {correct}\n")
            f.write(f"  Model output: '{model_answer}'\n")
            f.write(f"  Extracted: '{extracted}'{skipped_indicator}\n")
            f.write("\n")


def _print_summary_table(all_metrics):
    """Print one row per successful dataset/prompt combination."""
    print("\nSUMMARY TABLE:")
    print("-" * 100)
    print(f"{'Dataset':<20} {'Prompt':<15} {'Accuracy':<10} {'Correct':<8} {'Total':<8} {'Skipped':<8}")
    print("-" * 100)

    for dataset_name in all_metrics:
        for combo_name, metrics in all_metrics[dataset_name].items():
            if 'accuracy' in metrics:  # Skip error cases
                print(f"{dataset_name:<20} {combo_name:<15} {metrics['accuracy']:<10.3f} "
                      f"{metrics['correct_count']:<8} {metrics['total_count']:<8} {metrics['skipped_count']:<8}")


def run_comprehensive_evaluation(batch_size=1024*4, start_batch=0, end_batch=3, 
                                max_new_tokens=12, output_dir="./answers"):
    """
    Run all datasets with all prefix/postfix combinations and save results.

    For every (dataset, prompt combination) pair, the model is quizzed, the
    answers are scored, and a per-combination text file is written to
    `output_dir`. A JSON file with all metrics is written at the end. A
    failing combination is recorded with its error message and does not stop
    the remaining combinations.

    Args:
        batch_size: Batch size for processing
        start_batch: Starting batch index
        end_batch: Ending batch index (exclusive)
        max_new_tokens: Maximum new tokens to generate
        output_dir: Directory to save results (created if missing)

    Returns:
        tuple: (all_metrics, metrics_filepath)
            - all_metrics: dict of dataset name -> prompt combination name -> metrics dict
              (or a dict with an "error" key when that combination failed)
            - metrics_filepath: Path of the JSON summary that was written
    """
    # Create output directory
    os.makedirs(output_dir, exist_ok=True)

    # Define datasets
    datasets = [
        "./data/addition.txt",
        "./data/random_addition.txt", 
        "./data/subtraction.txt",
        "./data/random_subtraction.txt"
    ]

    # Define prefix/postfix combinations
    prompt_combinations = [
        {"name": "no_prompt", "prefix": "", "postfix": ""},
        {"name": "answer_directly", "prefix": "Answer directly: ", "postfix": " "},
        {"name": "answer_suffix", "prefix": "", "postfix": " Answer: "},
        {"name": "space_suffix", "prefix": "", "postfix": " "},
        {"name": "final_answer", "prefix": "", "postfix": " final answer: "},
        {"name": "calc_prefix", "prefix": "calc: ", "postfix": ""}
    ]

    # Store all metrics
    all_metrics = {}

    # One timestamp per run, shared by every file this run writes
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    total_combinations = len(datasets) * len(prompt_combinations)

    print(f"Starting comprehensive evaluation at {timestamp}")
    print(f"Will process {len(datasets)} datasets with {len(prompt_combinations)} prompt combinations")
    print(f"Total combinations: {total_combinations}")
    print("=" * 80)

    combination_count = 0

    for dataset_path in datasets:
        dataset_name = os.path.basename(dataset_path).replace('.txt', '')
        print(f"\nProcessing dataset: {dataset_name}")
        print("-" * 40)

        all_metrics.setdefault(dataset_name, {})

        for prompt_combo in prompt_combinations:
            combination_count += 1
            combo_name = prompt_combo["name"]
            prefix = prompt_combo["prefix"]
            postfix = prompt_combo["postfix"]

            print(f"[{combination_count}/{total_combinations}] Running {dataset_name} with {combo_name}")
            print(f"  Prefix: '{prefix}' | Postfix: '{postfix}'")

            try:
                # Run the model
                results = batch_quiz_model(
                    dataset_path, 
                    max_new_tokens=max_new_tokens,
                    batch_size=batch_size, 
                    start_batch=start_batch, 
                    end_batch=end_batch, 
                    prefix=prefix, 
                    postfix=postfix
                )

                # Get correct answers
                correct_answers = benchmark_correct_answers(
                    dataset_path, 
                    batch_size=batch_size, 
                    start_batch=start_batch, 
                    end_batch=end_batch
                )

                # Calculate accuracy
                accuracy, correct_count, total_count, detailed_results, skipped_count = calculate_accuracy(
                    results, correct_answers
                )

                # Save detailed results to text file
                answers_filename = f"{dataset_name}_{combo_name}_{timestamp}.txt"
                answers_filepath = os.path.join(output_dir, answers_filename)
                header_lines = [
                    f"Dataset: {dataset_name}",
                    f"Prompt: prefix='{prefix}' postfix='{postfix}'",
                    f"Timestamp: {timestamp}",
                    f"Batch range: {start_batch} to {end_batch-1}",
                    f"Max new tokens: {max_new_tokens}",
                    f"Accuracy: {accuracy:.3f} ({correct_count}/{total_count})",
                    f"Skipped: {skipped_count}",
                ]
                _write_answers_file(answers_filepath, header_lines, detailed_results, prefix, postfix)

                # Store metrics
                all_metrics[dataset_name][combo_name] = {
                    "accuracy": accuracy,
                    "correct_count": correct_count,
                    "total_count": total_count,
                    "incorrect_count": total_count - correct_count,
                    "skipped_count": skipped_count,
                    "prefix": prefix,
                    "postfix": postfix,
                    "max_new_tokens": max_new_tokens,
                    "batch_range": f"{start_batch}-{end_batch-1}",
                    "answers_file": answers_filename
                }

                print(f"  Results: {accuracy:.3f} accuracy ({correct_count}/{total_count}), {skipped_count} skipped")
                print(f"  Saved to: {answers_filename}")

            except Exception as e:
                # Best effort: record the failure and carry on with the next combination.
                print(f"  ERROR: {str(e)}")
                all_metrics[dataset_name][combo_name] = {
                    "error": str(e),
                    "prefix": prefix,
                    "postfix": postfix
                }
            finally:
                # Release cached GPU memory after every combination, including failed ones.
                torch.cuda.empty_cache()

    # Save all metrics to JSON
    metrics_filename = f"evaluation_metrics_{timestamp}.json"
    metrics_filepath = os.path.join(output_dir, metrics_filename)

    # Add summary statistics
    summary_data = {
        "timestamp": timestamp,
        "evaluation_settings": {
            "batch_size": batch_size,
            "start_batch": start_batch,
            "end_batch": end_batch,
            "max_new_tokens": max_new_tokens,
            "batch_range_description": f"batches {start_batch} to {end_batch-1}"
        },
        "datasets_processed": len(datasets),
        "prompt_combinations_processed": len(prompt_combinations),
        "total_combinations": total_combinations,
        "metrics": all_metrics
    }

    with open(metrics_filepath, 'w', encoding="utf-8") as f:
        json.dump(summary_data, f, indent=2)

    print("\n" + "=" * 80)
    print("EVALUATION COMPLETE!")
    print(f"Processed {total_combinations} combinations")
    print(f"Metrics saved to: {metrics_filename}")
    print(f"Individual results saved to: {output_dir}/")
    print("=" * 80)

    _print_summary_table(all_metrics)

    return all_metrics, metrics_filepath

In [49]:
# Run the comprehensive evaluation
# Run settings. These rebind the notebook-level names used by earlier cells.
batch_size = 1024 * 4  # Your current batch size
start_batch = 0  # Your current start batch
end_batch = 3    # Your current end batch (exclusive)

# Run all combinations; returns the metrics dict and the path of the JSON summary.
all_metrics, metrics_file = run_comprehensive_evaluation(
    batch_size=batch_size,
    start_batch=start_batch, 
    end_batch=end_batch,
    max_new_tokens=25,  # Using 25 to handle longer subtraction answers
    output_dir="./answers"
)

The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Starting comprehensive evaluation at 20250810_194802
Will process 4 datasets with 6 prompt combinations
Total combinations: 24

Processing dataset: addition
----------------------------------------
[1/24] Running addition with no_prompt
  Prefix: '' | Postfix: ''
Processing batches 0 to 2 (10000 problems) in batches of 4096...
Total dataset size: 10000 problems (3 total batches)


The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Processed 4096/10000 prompts


The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Processed 8192/10000 prompts
Processed 10000/10000 prompts
Calculated correct answers for batches 0 to 2 (10000 problems)
  Results: 0.173 accuracy (1734/10000), 98 skipped
  Saved to: addition_no_prompt_20250810_194802.txt
[2/24] Running addition with answer_directly
  Prefix: 'Answer directly: ' | Postfix: ' '
Processing batches 0 to 2 (10000 problems) in batches of 4096...
Total dataset size: 10000 problems (3 total batches)


The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Processed 4096/10000 prompts


The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Processed 8192/10000 prompts
Processed 10000/10000 prompts
Calculated correct answers for batches 0 to 2 (10000 problems)
  Results: 0.963 accuracy (9625/10000), 344 skipped
  Saved to: addition_answer_directly_20250810_194802.txt
[3/24] Running addition with answer_suffix
  Prefix: '' | Postfix: ' Answer: '
Processing batches 0 to 2 (10000 problems) in batches of 4096...
Total dataset size: 10000 problems (3 total batches)


The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Processed 4096/10000 prompts


The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Processed 8192/10000 prompts


The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Processed 10000/10000 prompts
Calculated correct answers for batches 0 to 2 (10000 problems)
  Results: 0.067 accuracy (672/10000), 291 skipped
  Saved to: addition_answer_suffix_20250810_194802.txt
[4/24] Running addition with space_suffix
  Prefix: '' | Postfix: ' '
Processing batches 0 to 2 (10000 problems) in batches of 4096...
Total dataset size: 10000 problems (3 total batches)


The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Processed 4096/10000 prompts


The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Processed 8192/10000 prompts
Processed 10000/10000 prompts
Calculated correct answers for batches 0 to 2 (10000 problems)
  Results: 0.295 accuracy (2952/10000), 1686 skipped
  Saved to: addition_space_suffix_20250810_194802.txt
[5/24] Running addition with final_answer
  Prefix: '' | Postfix: ' final answer: '
Processing batches 0 to 2 (10000 problems) in batches of 4096...
Total dataset size: 10000 problems (3 total batches)


The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Processed 4096/10000 prompts
Processed 8192/10000 prompts


The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Processed 10000/10000 prompts
Calculated correct answers for batches 0 to 2 (10000 problems)
  Results: 0.497 accuracy (4966/10000), 820 skipped
  Saved to: addition_final_answer_20250810_194802.txt
[6/24] Running addition with calc_prefix
  Prefix: 'calc: ' | Postfix: ''
Processing batches 0 to 2 (10000 problems) in batches of 4096...
Total dataset size: 10000 problems (3 total batches)


The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Processed 4096/10000 prompts


The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Processed 8192/10000 prompts


The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Processed 10000/10000 prompts
Calculated correct answers for batches 0 to 2 (10000 problems)
  Results: 0.138 accuracy (1382/10000), 59 skipped
  Saved to: addition_calc_prefix_20250810_194802.txt

Processing dataset: random_addition
----------------------------------------
[7/24] Running random_addition with no_prompt
  Prefix: '' | Postfix: ''
Processing batches 0 to 2 (10000 problems) in batches of 4096...
Total dataset size: 10000 problems (3 total batches)


The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Processed 4096/10000 prompts


The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Processed 8192/10000 prompts
Processed 10000/10000 prompts
Calculated correct answers for batches 0 to 2 (10000 problems)
  Results: 0.272 accuracy (2720/10000), 339 skipped
  Saved to: random_addition_no_prompt_20250810_194802.txt
[8/24] Running random_addition with answer_directly
  Prefix: 'Answer directly: ' | Postfix: ' '
Processing batches 0 to 2 (10000 problems) in batches of 4096...
Total dataset size: 10000 problems (3 total batches)


The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Processed 4096/10000 prompts


The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Processed 8192/10000 prompts
Processed 10000/10000 prompts
Calculated correct answers for batches 0 to 2 (10000 problems)
  Results: 0.574 accuracy (5743/10000), 3482 skipped
  Saved to: random_addition_answer_directly_20250810_194802.txt
[9/24] Running random_addition with answer_suffix
  Prefix: '' | Postfix: ' Answer: '
Processing batches 0 to 2 (10000 problems) in batches of 4096...
Total dataset size: 10000 problems (3 total batches)


The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Processed 4096/10000 prompts


The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Processed 8192/10000 prompts


The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Processed 10000/10000 prompts
Calculated correct answers for batches 0 to 2 (10000 problems)
  Results: 0.212 accuracy (2119/10000), 370 skipped
  Saved to: random_addition_answer_suffix_20250810_194802.txt
[10/24] Running random_addition with space_suffix
  Prefix: '' | Postfix: ' '
Processing batches 0 to 2 (10000 problems) in batches of 4096...
Total dataset size: 10000 problems (3 total batches)


The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Processed 4096/10000 prompts


The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Processed 8192/10000 prompts
Processed 10000/10000 prompts
Calculated correct answers for batches 0 to 2 (10000 problems)
  Results: 0.314 accuracy (3144/10000), 1219 skipped
  Saved to: random_addition_space_suffix_20250810_194802.txt
[11/24] Running random_addition with final_answer
  Prefix: '' | Postfix: ' final answer: '
Processing batches 0 to 2 (10000 problems) in batches of 4096...
Total dataset size: 10000 problems (3 total batches)


The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Processed 4096/10000 prompts


The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Processed 8192/10000 prompts
Processed 10000/10000 prompts
Calculated correct answers for batches 0 to 2 (10000 problems)
  Results: 0.331 accuracy (3311/10000), 1731 skipped
  Saved to: random_addition_final_answer_20250810_194802.txt
[12/24] Running random_addition with calc_prefix
  Prefix: 'calc: ' | Postfix: ''
Processing batches 0 to 2 (10000 problems) in batches of 4096...
Total dataset size: 10000 problems (3 total batches)


The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Processed 4096/10000 prompts


The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Processed 8192/10000 prompts


The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Processed 10000/10000 prompts
Calculated correct answers for batches 0 to 2 (10000 problems)
  Results: 0.316 accuracy (3163/10000), 119 skipped
  Saved to: random_addition_calc_prefix_20250810_194802.txt

Processing dataset: subtraction
----------------------------------------
[13/24] Running subtraction with no_prompt
  Prefix: '' | Postfix: ''
Processing batches 0 to 2 (10000 problems) in batches of 4096...
Total dataset size: 10000 problems (3 total batches)


The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Processed 4096/10000 prompts


The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Processed 8192/10000 prompts
Processed 10000/10000 prompts
Calculated correct answers for batches 0 to 2 (10000 problems)
  Results: 0.263 accuracy (2632/10000), 382 skipped
  Saved to: subtraction_no_prompt_20250810_194802.txt
[14/24] Running subtraction with answer_directly
  Prefix: 'Answer directly: ' | Postfix: ' '
Processing batches 0 to 2 (10000 problems) in batches of 4096...
Total dataset size: 10000 problems (3 total batches)


The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Processed 4096/10000 prompts


The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Processed 8192/10000 prompts


The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Processed 10000/10000 prompts
Calculated correct answers for batches 0 to 2 (10000 problems)
  Results: 0.907 accuracy (9072/10000), 310 skipped
  Saved to: subtraction_answer_directly_20250810_194802.txt
[15/24] Running subtraction with answer_suffix
  Prefix: '' | Postfix: ' Answer: '
Processing batches 0 to 2 (10000 problems) in batches of 4096...
Total dataset size: 10000 problems (3 total batches)


The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Processed 4096/10000 prompts


The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Processed 8192/10000 prompts


The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Processed 10000/10000 prompts
Calculated correct answers for batches 0 to 2 (10000 problems)
  Results: 0.244 accuracy (2438/10000), 1037 skipped
  Saved to: subtraction_answer_suffix_20250810_194802.txt
[16/24] Running subtraction with space_suffix
  Prefix: '' | Postfix: ' '
Processing batches 0 to 2 (10000 problems) in batches of 4096...
Total dataset size: 10000 problems (3 total batches)


The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Processed 4096/10000 prompts


The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Processed 8192/10000 prompts


The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Processed 10000/10000 prompts
Calculated correct answers for batches 0 to 2 (10000 problems)
  Results: 0.220 accuracy (2202/10000), 3580 skipped
  Saved to: subtraction_space_suffix_20250810_194802.txt
[17/24] Running subtraction with final_answer
  Prefix: '' | Postfix: ' final answer: '
Processing batches 0 to 2 (10000 problems) in batches of 4096...
Total dataset size: 10000 problems (3 total batches)
Processed 4096/10000 prompts


The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Processed 8192/10000 prompts


The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Processed 10000/10000 prompts
Calculated correct answers for batches 0 to 2 (10000 problems)
  Results: 0.485 accuracy (4851/10000), 2172 skipped
  Saved to: subtraction_final_answer_20250810_194802.txt
[18/24] Running subtraction with calc_prefix
  Prefix: 'calc: ' | Postfix: ''
Processing batches 0 to 2 (10000 problems) in batches of 4096...
Total dataset size: 10000 problems (3 total batches)


The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Processed 4096/10000 prompts


The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Processed 8192/10000 prompts


The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Processed 10000/10000 prompts
Calculated correct answers for batches 0 to 2 (10000 problems)
  Results: 0.040 accuracy (401/10000), 127 skipped
  Saved to: subtraction_calc_prefix_20250810_194802.txt

Processing dataset: random_subtraction
----------------------------------------
[19/24] Running random_subtraction with no_prompt
  Prefix: '' | Postfix: ''
Processing batches 0 to 2 (10000 problems) in batches of 4096...
Total dataset size: 10000 problems (3 total batches)


The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Processed 4096/10000 prompts


The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Processed 8192/10000 prompts
Processed 10000/10000 prompts
Calculated correct answers for batches 0 to 2 (10000 problems)
  Results: 0.243 accuracy (2427/10000), 916 skipped
  Saved to: random_subtraction_no_prompt_20250810_194802.txt
[20/24] Running random_subtraction with answer_directly
  Prefix: 'Answer directly: ' | Postfix: ' '
Processing batches 0 to 2 (10000 problems) in batches of 4096...
Total dataset size: 10000 problems (3 total batches)


The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Processed 4096/10000 prompts


The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Processed 8192/10000 prompts


The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Processed 10000/10000 prompts
Calculated correct answers for batches 0 to 2 (10000 problems)
  Results: 0.591 accuracy (5907/10000), 2655 skipped
  Saved to: random_subtraction_answer_directly_20250810_194802.txt
[21/24] Running random_subtraction with answer_suffix
  Prefix: '' | Postfix: ' Answer: '
Processing batches 0 to 2 (10000 problems) in batches of 4096...
Total dataset size: 10000 problems (3 total batches)


The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Processed 4096/10000 prompts


The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Processed 8192/10000 prompts
Processed 10000/10000 prompts
Calculated correct answers for batches 0 to 2 (10000 problems)
  Results: 0.229 accuracy (2286/10000), 1660 skipped
  Saved to: random_subtraction_answer_suffix_20250810_194802.txt
[22/24] Running random_subtraction with space_suffix
  Prefix: '' | Postfix: ' '
Processing batches 0 to 2 (10000 problems) in batches of 4096...
Total dataset size: 10000 problems (3 total batches)


The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Processed 4096/10000 prompts


The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Processed 8192/10000 prompts
Processed 10000/10000 prompts
Calculated correct answers for batches 0 to 2 (10000 problems)
  Results: 0.161 accuracy (1610/10000), 3832 skipped
  Saved to: random_subtraction_space_suffix_20250810_194802.txt
[23/24] Running random_subtraction with final_answer
  Prefix: '' | Postfix: ' final answer: '
Processing batches 0 to 2 (10000 problems) in batches of 4096...
Total dataset size: 10000 problems (3 total batches)


The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Processed 4096/10000 prompts


The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Processed 8192/10000 prompts


The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Processed 10000/10000 prompts
Calculated correct answers for batches 0 to 2 (10000 problems)
  Results: 0.267 accuracy (2665/10000), 2258 skipped
  Saved to: random_subtraction_final_answer_20250810_194802.txt
[24/24] Running random_subtraction with calc_prefix
  Prefix: 'calc: ' | Postfix: ''
Processing batches 0 to 2 (10000 problems) in batches of 4096...
Total dataset size: 10000 problems (3 total batches)


The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Processed 4096/10000 prompts


The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Processed 8192/10000 prompts
Processed 10000/10000 prompts
Calculated correct answers for batches 0 to 2 (10000 problems)
  Results: 0.054 accuracy (545/10000), 80 skipped
  Saved to: random_subtraction_calc_prefix_20250810_194802.txt

EVALUATION COMPLETE!
Processed 24 combinations
Metrics saved to: evaluation_metrics_20250810_194802.json
Individual results saved to: ./answers/

SUMMARY TABLE:
----------------------------------------------------------------------------------------------------
Dataset              Prompt          Accuracy   Correct  Total    Skipped 
----------------------------------------------------------------------------------------------------
addition             no_prompt       0.173      1734     10000    98      
addition             answer_directly 0.963      9625     10000    344     
addition             answer_suffix   0.067      672      10000    291     
addition             space_suffix    0.295      2952     10000    1686    
addition             final

In [50]:
def _print_ranked_results(results, first_rank=1):
    """Print a numbered three-line summary for each result, counting up from ``first_rank``."""
    for rank, result in enumerate(results, start=first_rank):
        print(f"{rank}. {result['dataset']} + {result['prompt']}: {result['accuracy']:.3f}")
        print(f"   Prefix: '{result['prefix']}' | Postfix: '{result['postfix']}'")
        print(f"   Correct: {result['correct']}/{result['total']}, Skipped: {result['skipped']}")
        print()


def _print_group_averages(results, key):
    """Print mean / min / max accuracy of ``results`` grouped by ``result[key]``."""
    grouped = {}
    for result in results:
        grouped.setdefault(result[key], []).append(result['accuracy'])

    for group_name, accuracies in grouped.items():
        avg_acc = sum(accuracies) / len(accuracies)
        print(f"{group_name}: {avg_acc:.3f} (min: {min(accuracies):.3f}, max: {max(accuracies):.3f})")


def analyze_results(metrics_file_path):
    """
    Load an evaluation metrics JSON file and print a ranked analysis of it.

    The report contains the five best and five worst (dataset, prompt)
    combinations by accuracy, followed by the average / min / max accuracy per
    dataset and per prompt. Entries without an 'accuracy' key (error cases)
    are ignored.

    Args:
        metrics_file_path: Path to the metrics JSON file. The file must hold
            'timestamp', 'evaluation_settings' and a 'metrics' mapping of
            dataset name -> prompt name -> metrics dict.

    Returns:
        tuple: (data, all_results)
            - data: The parsed JSON document.
            - all_results: One flat dict per successful combination, sorted by
              accuracy from best to worst.
    """
    with open(metrics_file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    print("DETAILED ANALYSIS")
    print("=" * 80)
    print(f"Evaluation timestamp: {data['timestamp']}")
    print(f"Settings: {data['evaluation_settings']}")
    print()

    # Flatten the nested dataset -> prompt -> metrics mapping into one list.
    all_results = [
        {
            'dataset': dataset_name,
            'prompt': combo_name,
            'accuracy': metrics['accuracy'],
            'correct': metrics['correct_count'],
            'total': metrics['total_count'],
            'skipped': metrics['skipped_count'],
            'prefix': metrics['prefix'],
            'postfix': metrics['postfix'],
        }
        for dataset_name, dataset_metrics in data['metrics'].items()
        for combo_name, metrics in dataset_metrics.items()
        if 'accuracy' in metrics  # Skip error cases
    ]

    # Best first; the sort is stable, so ties keep their file order.
    all_results.sort(key=lambda x: x['accuracy'], reverse=True)

    print("TOP 5 PERFORMING COMBINATIONS:")
    print("-" * 80)
    _print_ranked_results(all_results[:5])

    print("BOTTOM 5 PERFORMING COMBINATIONS:")
    print("-" * 80)
    bottom_results = all_results[-5:]
    # Label each entry with its true overall rank instead of restarting at 1,
    # so the numbering cannot be mistaken for a second "top" list.
    _print_ranked_results(
        bottom_results,
        first_rank=len(all_results) - len(bottom_results) + 1,
    )

    # Dataset-wise performance
    print("DATASET-WISE AVERAGE PERFORMANCE:")
    print("-" * 50)
    _print_group_averages(all_results, 'dataset')

    # Prompt-wise performance
    print("\nPROMPT-WISE AVERAGE PERFORMANCE:")
    print("-" * 50)
    _print_group_averages(all_results, 'prompt')

    return data, all_results

In [51]:
# Analyze the results: load the metrics JSON written by the evaluation run above
# (its path is in `metrics_file`) and print the ranked report.
data, all_results = analyze_results(metrics_file)

DETAILED ANALYSIS
Evaluation timestamp: 20250810_194802
Settings: {'batch_size': 4096, 'start_batch': 0, 'end_batch': 3, 'max_new_tokens': 25, 'batch_range_description': 'batches 0 to 2'}

TOP 5 PERFORMING COMBINATIONS:
--------------------------------------------------------------------------------
1. addition + answer_directly: 0.963
   Prefix: 'Answer directly: ' | Postfix: ' '
   Correct: 9625/10000, Skipped: 344

2. subtraction + answer_directly: 0.907
   Prefix: 'Answer directly: ' | Postfix: ' '
   Correct: 9072/10000, Skipped: 310

3. random_subtraction + answer_directly: 0.591
   Prefix: 'Answer directly: ' | Postfix: ' '
   Correct: 5907/10000, Skipped: 2655

4. random_addition + answer_directly: 0.574
   Prefix: 'Answer directly: ' | Postfix: ' '
   Correct: 5743/10000, Skipped: 3482

5. addition + final_answer: 0.497
   Prefix: '' | Postfix: ' final answer: '
   Correct: 4966/10000, Skipped: 820

BOTTOM 5 PERFORMING COMBINATIONS:
--------------------------------------------

# Ablation

In [None]:
# Activation Modification Hook
class AblationHook:
    """
    Wraps a Gemma Scope residual-stream SAE and exposes forward hooks for
    single-latent ablation experiments.

    Two hooks are provided:
      * ``single_latent_hook`` is attached to the SAE's ``hook_sae_acts_post``
        (via ``register_hook``) and zeroes every latent except one.
      * ``activation_modification_hook`` is meant to be attached to a model
        layer; it subtracts ``lambda_scale`` times the SAE output from that
        layer's hidden states.

    The constructor reads the notebook-level globals ``model`` and ``device``.
    """

    def __init__(self, layer_idx, lambda_scale,):
        """
        Load the SAE for one layer and prepare the hook bookkeeping.

        Args:
            layer_idx: Index of the layer whose 16k-width canonical SAE is loaded.
            lambda_scale: Scaling factor applied to the SAE reconstruction
                before it is subtracted from the hidden states.
        """
        # Load the SAE that belongs to the requested layer.
        sae_id = f"layer_{layer_idx}/width_16k/canonical"
        loaded = SAE.from_pretrained(
            release="gemma-scope-2b-pt-res-canonical",
            sae_id=sae_id
        )
        # Depending on the installed sae_lens version this is either the SAE
        # itself or a tuple whose first element is the SAE; accept both.
        sae = loaded[0] if isinstance(loaded, tuple) else loaded

        # Match the SAE to the model's device/dtype so it can consume the
        # model's activations directly.
        model_dtype = next(model.parameters()).dtype
        sae = sae.to(device=device, dtype=model_dtype)
        sae.eval()

        self.sae = sae
        self.lambda_scale = lambda_scale
        self.sae_hook_handles = []

    # Single Latent Reconstruction Hook
    def single_latent_hook(self, module, input, output, latent_idx):
        """
        Forward hook to zero all latent activations except for the specified index.

        Args:
            module: The HookPoint module (hook_sae_acts_post)
            input: Input to the hook (not used here)
            output: feature_acts tensor of shape [..., d_sae]
            latent_idx: Index of the latent to keep non-zero

        Returns:
            Modified feature_acts with only latent_idx non-zero
        """
        modified_acts = torch.zeros_like(output)
        modified_acts[..., latent_idx] = output[..., latent_idx]

        # No extra activation function is applied: the values are copied
        # unchanged from the hooked output.
        # NOTE(review): the hook point name (hook_sae_acts_post) suggests these
        # are already post-activation values -- confirm against sae_lens.
        return modified_acts

    # Register the activation modification hook
    def register_hook(self, latent_idx):
        """
        Attach ``single_latent_hook`` to the SAE so only ``latent_idx`` survives.

        Each call adds another hook; call ``unregister_hook`` first when
        switching to a different latent.
        """
        handle = self.sae.hook_sae_acts_post.register_forward_hook(
            lambda m, i, o: self.single_latent_hook(m, i, o, latent_idx=latent_idx)
        )
        self.sae_hook_handles.append(handle)

    def unregister_hook(self):
        """Remove every hook added by ``register_hook`` and reset the handle list."""
        # getattr keeps this safe on instances whose __init__ did not finish.
        for hook_handle in getattr(self, 'sae_hook_handles', []):
            hook_handle.remove()
        self.sae_hook_handles = []

    def activation_modification_hook(self, module, input, output):
        """
        Forward hook for a model layer: subtract the scaled SAE output.

        Args:
            module: The hooked layer (not used here)
            input: Input to the layer (not used here)
            output: Either the hidden-states tensor, or a tuple whose first
                element is the hidden-states tensor.

        Returns:
            ``hidden_states - lambda_scale * sae(hidden_states)`` in the same
            container shape as ``output``: a tuple with the remaining elements
            preserved when a tuple came in, otherwise a bare tensor.
        """
        output_is_tuple = isinstance(output, tuple)
        # Indexing a bare tensor with [0] would silently select the first batch
        # element, so only unpack when the layer really returned a tuple.
        hidden_states = output[0] if output_is_tuple else output

        # Compute the SAE reconstruction (single-latent if the SAE hook is
        # registered) on the current hidden_states.
        with torch.no_grad():
            reconstruction_w_hooks = self.sae(hidden_states)

        # Keep the residual stream in its original dtype.
        reconstruction_w_hooks = reconstruction_w_hooks.to(hidden_states.dtype)
        modified_hidden_states = hidden_states - self.lambda_scale * reconstruction_w_hooks

        if output_is_tuple:
            # A forward hook's return value replaces the whole layer output, so
            # the other tuple entries must be passed through untouched.
            return (modified_hidden_states,) + output[1:]
        return modified_hidden_states

    def run_encoding_decoding(self, hidden_states):
        """
        Encode and reconstruct the hidden states with the SAE and score the result.

        Any hook currently registered on the SAE is active during this call.

        Args:
            hidden_states: The hidden states to encode and reconstruct

        Returns:
            tuple: (reconstruction, latents, loss)
                - reconstruction: Output of the SAE forward pass
                - latents: Output of ``sae.encode``
                - loss: Mean squared error between reconstruction and input
        """
        with torch.no_grad():
            latents = self.sae.encode(hidden_states)
            # Full forward pass (encode + decode) on the same hidden_states
            reconstruction = self.sae(hidden_states)
            # Calculate the mean squared error loss
            loss = F.mse_loss(reconstruction, hidden_states)
        return reconstruction, latents, loss


In [None]:
# Environment settings
import os

# Set the TORCHDYNAMO_DISABLE and USE_TRITON switches for this process.
os.environ.update({
    "TORCHDYNAMO_DISABLE": "1",
    "USE_TRITON": "0",
})

# Layer to modify
layer_idx = 12

In [None]:
prompt = "5+7= "
inputs = tokenizer.encode(prompt, return_tensors="pt", add_special_tokens=True).to(device)
print(f"Input tokens: {inputs}")

# Get original hidden states
with torch.no_grad():
    outputs = model(inputs, output_hidden_states=True)

# `outputs.hidden_states` is a tuple with one tensor for the embedding output
# followed by one tensor per layer, so it cannot be indexed like a single tensor.
print(len(outputs.hidden_states))

# Entry 0 is the embedding output, so the output of layer `layer_idx` sits at
# index layer_idx + 1. The full (batch, seq, d_model) tensor is kept because the
# cells below index the resulting latents per token.
# NOTE(review): confirm this is the activation site the SAE was trained on.
original_hidden_states = outputs.hidden_states[layer_idx + 1]
print(f"Original hidden states shape: {original_hidden_states.shape}")

In [None]:
# Prints len() of the hidden_states entry at index layer_idx.
# NOTE(review): if that entry is a tensor, this is the size of its first
# dimension, not the sequence length -- confirm that is what was intended.
print(len(outputs.hidden_states[layer_idx]))

In [None]:
# Build the hook wrapper for the layer selected in the settings cell, so the
# SAE layer cannot drift from the layer the hidden states were taken from.
ablation_hook = AblationHook(layer_idx=layer_idx, lambda_scale=1.0)

# Baseline: reconstruction and latents with no SAE hook registered.
reconstruction, latents, loss = ablation_hook.run_encoding_decoding(original_hidden_states)

# Same computation with every latent except 11301 zeroed by the SAE hook.
ablation_hook.register_hook(latent_idx=11301)
ablation_reconstruction, ablation_latents, ablation_loss = ablation_hook.run_encoding_decoding(original_hidden_states)

print(reconstruction.shape)
print(f"Reconstruction loss: {loss.item()}")
print(f"Ablation reconstruction shape: {ablation_reconstruction.shape}")
print(f"Ablation reconstruction loss: {ablation_loss.item()}")
print(f"Latents shape: {latents.shape}")

In [None]:
# Latent 11301 from the run made before the single-latent hook was registered.
# assumes latents is (batch, seq, d_sae), i.e. one value per token -- TODO confirm
latents[0,:, 11301] 

In [None]:
# Latent 11301 from the run made after the single-latent hook was registered;
# compare with the unhooked values in the previous cell.
# assumes ablation_latents is (batch, seq, d_sae) -- TODO confirm
ablation_latents[0,:, 11301]

## Quick check: are latents zero for all tokens except the first?

In [None]:
# Print the current working directory; the latents file in the next cell is
# loaded through a path relative to it.
import os
print(f"Current directory: {os.getcwd()}")

In [None]:
from scipy import sparse
# Load the saved SciPy sparse matrix stored at the relative path below.
latents_npz_path = r"../../latents/addition/layer_12.npz"
latents_npz = sparse.load_npz(latents_npz_path)
print(latents_npz.shape)  # Shape of the loaded sparse matrix


In [None]:
print(latents_npz[:,11301].toarray())  # Column 11301 of the sparse matrix, converted to a dense array (every row is printed)

# Misc


In [None]:
def quick_generation(prompt, max_new_tokens=12):
    """
    Greedily generate a short continuation of ``prompt`` and return it as text.

    Uses the notebook-level ``tokenizer``, ``model`` and ``device``. The
    tokenized prompt is printed as a side effect.

    Args:
        prompt: Text to continue.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The decoded continuation only (prompt tokens removed, special tokens
        skipped), stripped of surrounding whitespace. No number extraction is
        performed on it.
    """
    input_ids = tokenizer.encode(prompt, return_tensors="pt", add_special_tokens=True).to(device)
    print(f"Input tokens: {input_ids}")

    generation_kwargs = dict(
        max_new_tokens=max_new_tokens,  # keep the answer short
        do_sample=False,                # deterministic (greedy) decoding
        repetition_penalty=1.0,         # neutral: repeated digits are not penalized
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        num_beams=1,                    # no beam search
    )
    with torch.no_grad():
        generated = model.generate(input_ids, **generation_kwargs)

    # generate() returns prompt + continuation; drop the prompt tokens.
    prompt_length = input_ids.shape[1]
    continuation_ids = generated[0][prompt_length:]
    return tokenizer.decode(continuation_ids, skip_special_tokens=True).strip()
    
def quiz_model(dataset: str):
    """
    Quiz the model on a dataset of arithmetic problems.

    Each non-blank line of the file is treated as one problem. It is prefixed
    with a fixed calculator instruction, sent to ``quick_generation``, and the
    answer is printed as it arrives.

    Args:
        dataset: Path to a text file with one problem per line.

    Returns:
        List of (prompt, answer) tuples in file order, where ``prompt`` is the
        instruction-prefixed problem that was sent to the model.
    """
    instruction = "You are a calculator. Answer immediately: "

    # Load the dataset
    with open(dataset, 'r') as f:
        problems = f.readlines()

    results = []
    for line in problems:
        problem = line.strip()
        # Skip blank lines *before* adding the instruction; once prefixed the
        # string is never empty, so the check must come first.
        if not problem:
            continue
        prompt = instruction + problem
        answer = quick_generation(prompt, max_new_tokens=10)
        results.append((prompt, answer))
        print(answer)
    return results
# quiz_model("./data/addition.txt")