In [3]:
# Let's test forward hooks.
# load one SAE
# load the model
# print attributes/architecture of the SAE and the model
# feed forward one example to the SAE
# get the reconstruction of the activations based on one feature
# create a hook that replaces the model's activation with the activation - lambda * reconstruction

In [1]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from sae_lens import SAE
import torch.nn.functional as F
import os
import re
# Work from the project root so the relative ./data paths used by later cells resolve.
# NOTE(review): this is a hard-coded absolute Windows path; the notebook will fail
# on any other machine unless this line is adjusted.
os.chdir("d:\\Master's\\gemma_scope_math")


# Disable gradients for memory efficiency
# (this is a process-wide switch: every later cell runs without autograd)
torch.set_grad_enabled(False)

# Set device
# `device` is a notebook-level global read by the generation helpers and by AblationHook.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

Using device: cuda


# Model loading

In [2]:
# 4-bit quantization config
# NF4 weights with double quantization; matmuls are computed in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

# Load model and tokenizer
# `model` and `tokenizer` are notebook-level globals used by every helper below.
# device_map='auto' lets the loader decide where the weights are placed.
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-2b-it",
    device_map='auto',
    quantization_config=bnb_config,
)
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b-it")

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

# Utils

In [3]:
def count_tokens(prompt):
    """
    Tokenize a prompt with the notebook-level `tokenizer` and describe the result.

    Args:
        prompt: Input string to tokenize

    Returns:
        tuple: (token_count, token_strings, token_ids, formatted_string)
            - token_count: Number of tokens (special tokens included)
            - token_strings: Each token id decoded on its own
            - token_ids: List of token IDs
            - formatted_string: "[index:'token']" entries joined by spaces
    """
    # Special tokens are included, so the count matches what the model receives.
    token_ids = tokenizer.encode(prompt, add_special_tokens=True)

    # Decode one id at a time so each token's own text is visible.
    pieces = [tokenizer.decode([tid]) for tid in token_ids]

    # Human-readable view with explicit token boundaries and positions.
    boundary_view = " ".join(
        f"[{position}:'{piece}']" for position, piece in enumerate(pieces)
    )

    return len(token_ids), pieces, token_ids, boundary_view

# Example usage:
def demo_tokenization(prompt):
    """
    Print a tokenization report for `prompt`.

    The report lists the original text, the token count, the token ids, the
    per-token strings and the formatted boundary view returned by
    `count_tokens`, followed by a 50-character separator line.
    """
    n_tokens, pieces, token_ids, boundary_view = count_tokens(prompt)

    report_lines = (
        f"Original text: '{prompt}'",
        f"Token count: {n_tokens}",
        f"Token IDs: {token_ids}",
        f"Token strings: {pieces}",
        f"Formatted view: {boundary_view}",
        "-" * 50,
    )
    for line in report_lines:
        print(line)

# count total number of output tokens in the results
def count_total_output_tokens(results):
    """
    Sum the token counts of every answer in `results`.

    Args:
        results: List of tuples (problem, answer)

    Returns:
        int: Total number of output tokens (0 for an empty list)
    """
    # count_tokens returns (count, strings, ids, formatted); only the count is needed.
    return sum(count_tokens(answer)[0] for _, answer in results)

def calc_correct_answer(problem):
    """
    Calculate the correct answer for a given problem.

    The problem text is reduced to digits, whitespace and the characters
    + - * / before it is evaluated, so a trailing "=" or surrounding words
    are ignored. Leading zeros on operands ("007+3") are dropped, because
    Python rejects them as integer literals.

    Args:
        problem: Problem string (e.g., "2 + 2" or "6734+7265=")

    Returns:
        str: Correct answer as a string, or "ERROR" when the problem cannot
        be evaluated (the reason is printed).
    """
    # Remove any non-numeric characters except for +, -, *, /
    clean_problem = re.sub(r'[^\d\s\+\-\*/]', '', problem).strip()

    # "007" is a SyntaxError for eval; strip zeros that start a multi-digit
    # number while leaving a lone "0" and interior/trailing zeros alone.
    clean_problem = re.sub(r'(?<!\d)0+(?=\d)', '', clean_problem)

    try:
        if not clean_problem:
            raise ValueError("no arithmetic expression found")
        # NOTE(review): eval() is only acceptable here because the regex above
        # leaves nothing but digits, whitespace and + - * /. Builtins are
        # removed as defence in depth. A hostile dataset could still contain
        # something like "9**9**9**9" ("**" survives the filter), so do not
        # point this at untrusted files.
        answer = eval(clean_problem, {"__builtins__": {}}, {})
        return str(answer)
    except Exception as e:
        print(f"Error evaluating problem '{problem}': {e}")
        return "ERROR"


## generation utils

In [15]:
def batch_generation(prompts, max_new_tokens=10, batch_size=8):
    """
    Greedily generate a completion for every prompt, `batch_size` prompts at a time.

    Uses the notebook-level `tokenizer`, `model` and `device`. Each batch is
    padded to a common length, generated without sampling, moved to the CPU
    and decoded; only the tokens after the (padded) prompt are kept.

    Args:
        prompts: Sequence of prompt strings
        max_new_tokens: Maximum number of tokens generated per prompt
        batch_size: Number of prompts sent to the model at once (must be >= 1)

    Returns:
        list[str]: One stripped completion per prompt, in input order

    Raises:
        ValueError: If batch_size is smaller than 1
    """
    # A negative step would make range() empty and silently drop every prompt.
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    all_responses = []
    total_prompts = len(prompts)

    for start in range(0, total_prompts, batch_size):
        batch_prompts = list(prompts[start:start + batch_size])

        # Tokenize batch.
        # No truncation is requested: without a max_length it never applied and
        # only triggered the "Asking to truncate to max_length ..." warning.
        # NOTE(review): slicing off the prompt below assumes the tokenizer pads
        # on the left, so generated tokens directly follow the prompt - confirm
        # tokenizer.padding_side.
        inputs = tokenizer(
            batch_prompts,
            return_tensors="pt",
            padding=True,
            add_special_tokens=True
        ).to(device)

        # Every row is padded to this length, so it is also where new tokens start.
        prompt_length = inputs.input_ids.shape[1]

        with torch.no_grad():
            outputs = model.generate(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_new_tokens=max_new_tokens,
                do_sample=False,
                repetition_penalty=1.0,
                num_beams=1,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )

        # Move outputs to CPU immediately so the GPU copy can be released
        outputs_cpu = outputs.cpu()

        # Clean up GPU tensors before decoding
        del outputs
        del inputs
        torch.cuda.empty_cache()  # Force GPU memory cleanup

        # Keep only the newly generated tokens of each row
        for sequence in outputs_cpu:
            response = tokenizer.decode(sequence[prompt_length:], skip_special_tokens=True)
            all_responses.append(response.strip())

        # Optional: print progress
        print(f"Processed {min(start + batch_size, total_prompts)}/{total_prompts} prompts")

    return all_responses

def batch_quiz_model(dataset_path, max_new_tokens=10, batch_size=8, start_batch=0, end_batch=None, prefix="You are a calculator. Answer immediately: ", postfix=""):
    """
    Ask the model a slice of the problems stored in `dataset_path`.

    The file is read as one problem per non-empty line. Problems are grouped
    into batches of `batch_size`; batches `start_batch` (inclusive) to
    `end_batch` (exclusive, defaults to the last batch) are wrapped in
    `prefix`/`postfix` and sent to `batch_generation`.

    Returns:
        List of tuples (prompt, answer), where prompt includes prefix and postfix.

    Raises:
        ValueError: If start_batch or end_batch fall outside the dataset.
    """
    # One problem per line; blank lines are ignored.
    with open(dataset_path, 'r') as dataset_file:
        stripped_lines = (raw_line.strip() for raw_line in dataset_file.readlines())
        problems = [text for text in stripped_lines if text]

    n_problems = len(problems)
    total_batches = (n_problems + batch_size - 1) // batch_size

    # By default run through to the final batch.
    if end_batch is None:
        end_batch = total_batches

    # Reject batch windows that fall outside the dataset.
    if not (0 <= start_batch < total_batches):
        raise ValueError(f"start_batch {start_batch} is out of range [0, {total_batches})")
    if not (start_batch <= end_batch <= total_batches):
        raise ValueError(f"end_batch {end_batch} is out of range [{start_batch}, {total_batches}]")

    # Translate the batch window into problem indices.
    first_problem = start_batch * batch_size
    last_problem = min(end_batch * batch_size, n_problems)
    subset_problems = problems[first_problem:last_problem]

    # Wrap every problem in the requested prefix/postfix.
    prompts = ["".join((prefix, problem, postfix)) for problem in subset_problems]

    print(f"Processing batches {start_batch} to {end_batch-1} ({len(subset_problems)} problems) in batches of {batch_size}...")
    print(f"Total dataset size: {len(problems)} problems ({total_batches} total batches)")

    answers = batch_generation(prompts, max_new_tokens=max_new_tokens, batch_size=batch_size)

    # Pair each full prompt (not the bare problem) with its answer.
    results = list(zip(prompts, answers))

    torch.cuda.empty_cache()  # Force GPU memory cleanup
    return results

def save_results(results, output_path, append=True):
    """
    Save results to file with option to append or overwrite.

    Each result is written as one "problem<TAB>answer" line.

    Args:
        results: List of tuples (problem, answer)
        output_path: Path to output file
        append: If True, append to existing file. If False, overwrite
    """
    # Create the parent directory if there is one. A bare file name has an
    # empty dirname, and os.makedirs("") raises FileNotFoundError.
    parent_dir = os.path.dirname(output_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    # NOTE(review): answers are written verbatim, so an answer containing a
    # tab or newline spans several fields/lines in the output file.
    mode = 'a' if append else 'w'
    with open(output_path, mode) as f:
        for problem, answer in results:
            f.write(f"{problem}\t{answer}\n")

    print(f"Results {'appended to' if append else 'saved to'} {output_path}")

def find_optimal_batch_size(start_size=256, max_size=4096):
    """
    Find the largest tested batch size that fits in GPU memory.

    Starting at `start_size`, the batch size is doubled after every successful
    `batch_generation` run until a CUDA out-of-memory error occurs or
    `max_size` is reached. If `start_size` itself does not fit, the size is
    halved until a run succeeds. The value returned is always a size that was
    actually run successfully.

    Args:
        start_size: First batch size to try (must be >= 1)
        max_size: Stop doubling once a batch of at least this size succeeds,
            so the search terminates even when no out-of-memory error occurs

    Returns:
        int: Largest batch size that completed without an out-of-memory error

    Raises:
        ValueError: If start_size is smaller than 1
        RuntimeError: If not even a batch of one prompt fits in memory
    """
    if start_size < 1:
        raise ValueError(f"start_size must be a positive integer, got {start_size}")

    largest_ok = None   # biggest size that ran successfully so far
    hit_oom = False     # True once any size has run out of memory
    batch_size = start_size

    while batch_size >= 1:
        try:
            # Test with the actual inference function
            test_prompts = ["5+7="] * batch_size
            _ = batch_generation(test_prompts, max_new_tokens=6, batch_size=batch_size)
        except torch.cuda.OutOfMemoryError:
            # Release whatever the failed attempt left cached before retrying.
            torch.cuda.empty_cache()
            hit_oom = True
            if largest_ok is not None:
                # The previous (half-sized) run worked: that is the answer.
                break
            # Even the starting size is too large: search downwards.
            batch_size //= 2
            continue

        largest_ok = batch_size
        if hit_oom or batch_size >= max_size:
            # Either the next size up is known to fail, or the cap is reached.
            break
        batch_size *= 2  # Double until we hit memory limit

    if largest_ok is None:
        raise RuntimeError("Even batch_size=1 does not fit in GPU memory")

    return largest_ok

In [5]:
def benchmark_correct_answers(dataset_path, batch_size=64, start_batch=0, end_batch=None, prefix="", postfix=""):
    """
    Load the same questions as batch_quiz_model and calculate correct answers.

    Args:
        dataset_path: Path to the dataset file (one problem per non-empty line)
        batch_size: Number of problems per batch (same as used in batch_quiz_model)
        start_batch: Starting batch index (0-based, same as used in batch_quiz_model)
        end_batch: Ending batch index (exclusive, same as used in batch_quiz_model)
        prefix: Optional prefix added to each returned prompt. Defaults to "",
            which returns the bare problem
        postfix: Optional postfix added to each returned prompt. Defaults to ""

    Returns:
        List of tuples (full_prompt, correct_answer). The answer is always
        computed from the bare problem, never from the prefixed prompt.

    Raises:
        ValueError: If start_batch or end_batch fall outside the dataset.
    """
    # Load the dataset
    with open(dataset_path, 'r') as f:
        problems = [line.strip() for line in f.readlines() if line.strip()]

    # Calculate total number of batches
    total_batches = (len(problems) + batch_size - 1) // batch_size

    # Set end_batch if not specified
    if end_batch is None:
        end_batch = total_batches

    # Validate batch indices (same validation as batch_quiz_model)
    if start_batch < 0 or start_batch >= total_batches:
        raise ValueError(f"start_batch {start_batch} is out of range [0, {total_batches})")
    if end_batch < start_batch or end_batch > total_batches:
        raise ValueError(f"end_batch {end_batch} is out of range [{start_batch}, {total_batches}]")

    # Calculate problem indices for the specified batch range (same as batch_quiz_model)
    start_idx = start_batch * batch_size
    end_idx = min(end_batch * batch_size, len(problems))

    # Get subset of problems for this run
    subset_problems = problems[start_idx:end_idx]

    # Pair each (optionally wrapped) prompt with the answer to the bare problem.
    results = [
        (prefix + problem + postfix, calc_correct_answer(problem))
        for problem in subset_problems
    ]

    print(f"Calculated correct answers for batches {start_batch} to {end_batch-1} ({len(subset_problems)} problems)")

    return results

In [7]:
def extract_first_digits(model_answer, correct_answer):
    """
    Return the first number in the model answer, cut to the length of the correct answer.

    Before searching, an echoed question ("12+34=", no spaces around the
    operator) and step markers ("1.", "2.", ...) are removed from the answer.
    The length limit counts a leading minus sign of the correct answer.

    Args:
        model_answer: The model's response string
        correct_answer: The correct answer as a string

    Returns:
        str: The first number found (optional minus sign plus digits),
        truncated to len(correct_answer.strip()); "" if no number is found
    """
    # Drop a repeated question: digits, one of + - * /, digits, optional "=".
    without_question = re.sub(r'\d+[\+\-\*/]\d+\s*=?\s*', '', model_answer)

    # Drop step numbering such as "1. " or "2."
    without_steps = re.sub(r'\b\d+\.\s*', '', without_question)

    # Length budget, including the minus sign when the answer is negative.
    target_length = len(correct_answer.strip())

    first_number = re.search(r'-?\d+', without_steps)
    if first_number is None:
        return ""

    # Slicing is a no-op when the number is already short enough.
    return first_number.group(0)[:target_length]

In [8]:
def calculate_accuracy(model_results, correct_answers):
    """
    Score model answers against the correct answers.

    For every model answer, an echoed question ("12+34=") and step markers
    ("1.", "2.", ...) are removed, then the first integer (optionally
    negative) is taken as the model's answer and compared, as a string, with
    the stripped correct answer. Answers without any number are skipped and
    do not count towards the total.

    Args:
        model_results: List of tuples (prompt, model_answer) from batch_quiz_model
        correct_answers: List of tuples (prompt, correct_answer) from benchmark_correct_answers

    Returns:
        tuple: (accuracy, correct_count, total_count, detailed_results, skipped_count)
            - accuracy: correct_count / total_count, or 0.0 when nothing was scored
            - correct_count: Number of correct answers
            - total_count: Number of scored answers (skipped ones excluded)
            - detailed_results: List of tuples (prompt, correct_answer, model_answer, extracted_number, is_correct)
            - skipped_count: Number of answers in which no number was found

    Raises:
        ValueError: If the two input lists differ in length
    """
    if len(model_results) != len(correct_answers):
        raise ValueError(f"Mismatch in result lengths: {len(model_results)} vs {len(correct_answers)}")

    question_re = re.compile(r'\d+[\+\-\*/]\d+\s*=?\s*')  # echoed question
    step_re = re.compile(r'\b\d+\.\s*')                    # "1.", "2.", ...
    number_re = re.compile(r'-?\d+')                       # optional minus sign + digits

    detailed_results = []
    skipped_count = 0

    for (model_prompt, model_answer), (_, correct_answer) in zip(model_results, correct_answers):
        cleaned_answer = step_re.sub('', question_re.sub('', model_answer))

        first_number = number_re.search(cleaned_answer)
        if first_number is None:
            # Nothing to score in this answer.
            skipped_count += 1
            continue

        extracted_number = first_number.group(0)
        is_correct = extracted_number == correct_answer.strip()
        detailed_results.append(
            (model_prompt, correct_answer, model_answer, extracted_number, is_correct)
        )

    # Only answers that produced a number are counted.
    total_count = len(detailed_results)
    correct_count = sum(1 for row in detailed_results if row[4])
    accuracy = correct_count / total_count if total_count > 0 else 0.0

    return accuracy, correct_count, total_count, detailed_results, skipped_count

def print_accuracy_report(accuracy, correct_count, total_count, detailed_results, skipped_count=0, show_errors_only=False, max_examples=10):
    """
    Print a detailed accuracy report.

    A summary is printed first, followed by up to `max_examples` examples.
    When more examples match the current filter than were shown, a closing
    "... and N more examples" line reports how many were left out. N counts
    only examples that pass the filter, so with show_errors_only=True it is
    the number of remaining errors.

    Args:
        accuracy: Accuracy as a float
        correct_count: Number of correct answers
        total_count: Total number of questions
        detailed_results: Detailed results from calculate_accuracy, i.e. tuples
            (prompt, correct_answer, model_answer, extracted_number, is_correct)
        skipped_count: Number of prompts skipped by calculate_accuracy
        show_errors_only: If True, only show incorrect answers
        max_examples: Maximum number of examples to show
    """
    print(f"Accuracy: {accuracy:.3f} ({correct_count}/{total_count})")
    print(f"Correct: {correct_count}, Incorrect: {total_count - correct_count}")
    if skipped_count > 0:
        print(f"Skipped (insufficient digits): {skipped_count}")
    print("-" * 80)

    # Apply the filter first so the "more examples" count below only refers
    # to rows that would actually have been displayed.
    candidates = [
        row for row in detailed_results
        if not (show_errors_only and row[4])
    ]
    shown = candidates[:max(max_examples, 0)]

    for prompt, correct, model_answer, extracted, is_correct in shown:
        status = "✓" if is_correct else "✗"
        print(f"{status} Prompt: {prompt}")
        print(f"  Correct: {correct}")
        print(f"  Model output: '{model_answer}'")
        print(f"  Extracted: '{extracted}'")
        print()

    remaining = len(candidates) - len(shown)
    if remaining > 0:
        print(f"... and {remaining} more examples")

In [9]:
import subprocess
import psutil
import shutil

# Optional: Use pynvml if you want more structured GPU info
try:
    from pynvml import nvmlInit, nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo
    pynvml_available = True
except ImportError:
    pynvml_available = False

def show_gpu_usage():
    """
    Print memory usage of GPU 0.

    Uses pynvml when it was imported successfully (total/used/free in MB);
    otherwise prints the raw output of the `nvidia-smi` command.
    """
    print("=== GPU ===")

    if not pynvml_available:
        # No structured API available: fall back to the command-line tool.
        print(subprocess.getoutput("nvidia-smi"))
        return

    nvmlInit()
    gpu_memory = nvmlDeviceGetMemoryInfo(nvmlDeviceGetHandleByIndex(0))  # GPU 0

    bytes_per_mb = 1024**2
    memory_rows = (
        ("Total:", gpu_memory.total),
        ("Used: ", gpu_memory.used),
        ("Free: ", gpu_memory.free),
    )
    for label, n_bytes in memory_rows:
        print(f"{label} {n_bytes / bytes_per_mb:.2f} MB")

def show_ram_usage():
    """
    Print system RAM usage in GB, as reported by psutil.

    The "Free" line shows psutil's `available` figure.
    """
    print("\n=== RAM ===")
    ram = psutil.virtual_memory()

    bytes_per_gb = 1024**3
    ram_rows = (
        ("Total:", ram.total),
        ("Used: ", ram.used),
        ("Free: ", ram.available),
    )
    for label, n_bytes in ram_rows:
        print(f"{label} {n_bytes / bytes_per_gb:.2f} GB")

def show_disk_usage(path="/"):
    """
    Print total, used and free disk space in GB for the volume containing `path`.

    Args:
        path: Any path on the volume to inspect (defaults to "/")
    """
    print("\n=== Disk ===")
    usage = shutil.disk_usage(path)

    bytes_per_gb = 1024**3
    print(f"Total: {usage.total / bytes_per_gb:.2f} GB")
    print(f"Used:  {usage.used / bytes_per_gb:.2f} GB")
    print(f"Free:  {usage.free / bytes_per_gb:.2f} GB")

# Benchmark model performance

In [10]:
# Shared settings for the benchmark cells below.
# batch_size is used both to slice the dataset into batches and as the number of
# prompts sent to the model at once, so start_batch/end_batch select
# dataset rows [start_batch * batch_size, end_batch * batch_size).
batch_size = 512  # Set your desired batch size
start_batch = 0  # Starting batch index
end_batch = 1    # Ending batch index (exclusive)

In [11]:
# Report GPU, RAM and disk usage before running the benchmarks.
show_gpu_usage()
show_ram_usage()
show_disk_usage()

=== GPU ===
Total: 6144.00 MB
Used:  4206.98 MB
Free:  1937.02 MB

=== RAM ===
Total: 15.85 GB
Used:  10.04 GB
Free:  5.80 GB

=== Disk ===
Total: 874.14 GB
Used:  635.67 GB
Free:  238.47 GB


In [12]:
# Experiment: random addition problems sent as-is (no prefix, no postfix).
prefix = ""
postfix = ""
results = batch_quiz_model("./data/random_addition.txt", max_new_tokens=12 ,batch_size=batch_size, start_batch=start_batch, end_batch=end_batch, prefix=prefix, postfix=postfix)

# Ground truth for exactly the same dataset slice (same batch_size/start/end).
correct_answers = benchmark_correct_answers("./data/random_addition.txt", batch_size=batch_size, start_batch=start_batch, end_batch=end_batch)

# Calculate accuracy with skipping
# (answers without any number are skipped and excluded from the denominator)
accuracy, correct_count, total_count, detailed_results, skipped_count = calculate_accuracy(results, correct_answers)

# Print report including skipped count
print_accuracy_report(accuracy, correct_count, total_count, detailed_results, skipped_count, show_errors_only=True, max_examples=10)

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


Processing batches 0 to 0 (512 problems) in batches of 512...
Total dataset size: 10000 problems (20 total batches)
Processed 512/512 prompts
Calculated correct answers for batches 0 to 0 (512 problems)
Accuracy: 0.234 (103/441)
Correct: 103, Incorrect: 338
Skipped (insufficient digits): 71
--------------------------------------------------------------------------------
✗ Prompt: 6734+7265=
  Correct: 13999
  Model output: '15000

15000-'
  Extracted: '15000'

✗ Prompt: 1466+5426=
  Correct: 6892
  Model output: 'Here's how to solve it:

**1'
  Extracted: '1'

✗ Prompt: 6578+9322=
  Correct: 15900
  Model output: '20000

20000-'
  Extracted: '20000'

✗ Prompt: 7949+3433=
  Correct: 11382
  Model output: 'Here's how to solve it:

**1'
  Extracted: '1'

✗ Prompt: 7420+2184=
  Correct: 9604
  Model output: '**Answer:** 

4604'
  Extracted: '4604'

✗ Prompt: 7396+9666=
  Correct: 17062
  Model output: 'Here's how to solve it:

**1'
  Extracted: '1'

✗ Prompt: 3047+3747=
  Correct: 6794
  M

In [13]:
# Experiment: random addition with an "Answer directly: " prefix and a trailing space.
prefix = "Answer directly: "
postfix = " "
results = batch_quiz_model("./data/random_addition.txt", max_new_tokens=12 ,batch_size=batch_size, start_batch=start_batch, end_batch=end_batch, prefix=prefix, postfix=postfix)

# Ground truth for exactly the same dataset slice (same batch_size/start/end).
correct_answers = benchmark_correct_answers("./data/random_addition.txt", batch_size=batch_size, start_batch=start_batch, end_batch=end_batch)

# Calculate accuracy with skipping
# (answers without any number are skipped and excluded from the denominator)
accuracy, correct_count, total_count, detailed_results, skipped_count = calculate_accuracy(results, correct_answers)

# Print report including skipped count
print_accuracy_report(accuracy, correct_count, total_count, detailed_results, skipped_count, show_errors_only=True, max_examples=10)


Processing batches 0 to 0 (512 problems) in batches of 512...
Total dataset size: 10000 problems (20 total batches)
Processed 512/512 prompts
Calculated correct answers for batches 0 to 0 (512 problems)
Accuracy: 0.851 (406/477)
Correct: 406, Incorrect: 71
Skipped (insufficient digits): 35
--------------------------------------------------------------------------------
✗ Prompt: Answer directly: 6734+7265= 
  Correct: 13999
  Model output: '**Answer:** 14000'
  Extracted: '14000'

✗ Prompt: Answer directly: 6578+9322= 
  Correct: 15900
  Model output: '**Answer:** 15890'
  Extracted: '15890'

✗ Prompt: Answer directly: 2267+2528= 
  Correct: 4795
  Model output: '**Answer:** 4805'
  Extracted: '4805'

✗ Prompt: Answer directly: 4943+8555= 
  Correct: 13498
  Model output: '**Answer:** 13508'
  Extracted: '13508'

✗ Prompt: Answer directly: 1995+8629= 
  Correct: 10624
  Model output: '**Answer:** 10628'
  Extracted: '10628'

✗ Prompt: Answer directly: 1064+9006= 
  Correct: 10070
  Mod

In [19]:
# Run all resource reports (GPU, RAM, disk) to see usage after the previous experiment.
show_gpu_usage()
show_ram_usage()
show_disk_usage()


=== GPU ===
Total: 6144.00 MB
Used:  5535.83 MB
Free:  608.17 MB

=== RAM ===
Total: 15.85 GB
Used:  11.51 GB
Free:  4.33 GB

=== Disk ===
Total: 874.14 GB
Used:  635.67 GB
Free:  238.47 GB


In [20]:
# Experiment: random addition with no prefix and an " Answer: " postfix.
prefix = ""
postfix = " Answer: "
results = batch_quiz_model("./data/random_addition.txt", max_new_tokens=12 ,batch_size=batch_size, start_batch=start_batch, end_batch=end_batch, prefix=prefix, postfix=postfix)
# Ground truth for exactly the same dataset slice (same batch_size/start/end).
correct_answers = benchmark_correct_answers("./data/random_addition.txt", batch_size=batch_size, start_batch=start_batch, end_batch=end_batch)

# Calculate accuracy with skipping
# (answers without any number are skipped and excluded from the denominator)
accuracy, correct_count, total_count, detailed_results, skipped_count = calculate_accuracy(results, correct_answers)

# Print report including skipped count
print_accuracy_report(accuracy, correct_count, total_count, detailed_results, skipped_count, show_errors_only=True, max_examples=10)


Processing batches 0 to 0 (512 problems) in batches of 512...
Total dataset size: 10000 problems (20 total batches)
Processed 512/512 prompts
Calculated correct answers for batches 0 to 0 (512 problems)
Accuracy: 0.341 (155/455)
Correct: 155, Incorrect: 300
Skipped (insufficient digits): 57
--------------------------------------------------------------------------------
✗ Prompt: 8270+1860= Answer: 
  Correct: 10130
  Model output: '**Answer:** 31300'
  Extracted: '31300'

✗ Prompt: 6734+7265= Answer: 
  Correct: 13999
  Model output: '**15000**'
  Extracted: '15000'

✗ Prompt: 6578+9322= Answer: 
  Correct: 15900
  Model output: '**18890**'
  Extracted: '18890'

✗ Prompt: 7420+2184= Answer: 
  Correct: 9604
  Model output: '**Solution:**

```
4604'
  Extracted: '4604'

✗ Prompt: 7396+9666= Answer: 
  Correct: 17062
  Model output: '**19062**'
  Extracted: '19062'

✗ Prompt: 3047+3747= Answer: 
  Correct: 6794
  Model output: '**7794**'
  Extracted: '7794'

✗ Prompt: 2267+2528= Answer:

In [21]:
# Run all resource reports (GPU, RAM, disk) to see usage after the previous experiment.
show_gpu_usage()
show_ram_usage()
show_disk_usage()

=== GPU ===
Total: 6144.00 MB
Used:  4936.20 MB
Free:  1207.80 MB

=== RAM ===
Total: 15.85 GB
Used:  12.11 GB
Free:  3.73 GB

=== Disk ===
Total: 874.14 GB
Used:  635.67 GB
Free:  238.47 GB


In [22]:
# Experiment: random addition with no prefix and a single trailing space as postfix.
prefix = ""
postfix = " "
results = batch_quiz_model("./data/random_addition.txt", max_new_tokens=12 ,batch_size=batch_size, start_batch=start_batch, end_batch=end_batch, prefix=prefix, postfix=postfix)
# Ground truth for exactly the same dataset slice (same batch_size/start/end).
correct_answers = benchmark_correct_answers("./data/random_addition.txt", batch_size=batch_size, start_batch=start_batch, end_batch=end_batch)

# Calculate accuracy with skipping
# (answers without any number are skipped and excluded from the denominator)
accuracy, correct_count, total_count, detailed_results, skipped_count = calculate_accuracy(results, correct_answers)

# Print report including skipped count
print_accuracy_report(accuracy, correct_count, total_count, detailed_results, skipped_count, show_errors_only=True, max_examples=10)


Processing batches 0 to 0 (512 problems) in batches of 512...
Total dataset size: 10000 problems (20 total batches)
Processed 512/512 prompts
Calculated correct answers for batches 0 to 0 (512 problems)
Accuracy: 0.298 (93/312)
Correct: 93, Incorrect: 219
Skipped (insufficient digits): 200
--------------------------------------------------------------------------------
✗ Prompt: 6734+7265= 
  Correct: 13999
  Model output: '15000

15000-'
  Extracted: '15000'

✗ Prompt: 7420+2184= 
  Correct: 9604
  Model output: '**Answer:** 

**4604**'
  Extracted: '4604'

✗ Prompt: 4005+5658= 
  Correct: 9663
  Model output: '10663

10663 is'
  Extracted: '10663'

✗ Prompt: 2899+8734= 
  Correct: 11633
  Model output: 'Here's how to solve it:

**1'
  Extracted: '1'

✗ Prompt: 9838+6393= 
  Correct: 16231
  Model output: '**Answer:** 

**13231'
  Extracted: '13231'

✗ Prompt: 8041+7235= 
  Correct: 15276
  Model output: '14280

14280-'
  Extracted: '14280'

✗ Prompt: 6486+8099= 
  Correct: 14585
  Mo

In [24]:
# Experiment: random addition with no prefix and a " final answer: " postfix.
prefix = ""
postfix = " final answer: "
results = batch_quiz_model("./data/random_addition.txt", max_new_tokens=12 ,batch_size=batch_size, start_batch=start_batch, end_batch=end_batch, prefix=prefix, postfix=postfix)

# Ground truth for exactly the same dataset slice (same batch_size/start/end).
correct_answers = benchmark_correct_answers("./data/random_addition.txt", batch_size=batch_size, start_batch=start_batch, end_batch=end_batch)

# Calculate accuracy with skipping
# (answers without any number are skipped and excluded from the denominator)
accuracy, correct_count, total_count, detailed_results, skipped_count = calculate_accuracy(results, correct_answers)

# Print report including skipped count
print_accuracy_report(accuracy, correct_count, total_count, detailed_results, skipped_count, show_errors_only=True, max_examples=10)

Processing batches 0 to 0 (512 problems) in batches of 512...
Total dataset size: 10000 problems (20 total batches)
Processed 512/512 prompts
Calculated correct answers for batches 0 to 0 (512 problems)
Accuracy: 0.419 (184/439)
Correct: 184, Incorrect: 255
Skipped (insufficient digits): 73
--------------------------------------------------------------------------------
✗ Prompt: 8270+1860= final answer: 
  Correct: 10130
  Model output: '**Answer:** 31300'
  Extracted: '31300'

✗ Prompt: 6734+7265= final answer: 
  Correct: 13999
  Model output: '14999

**Explanation:**

You simply'
  Extracted: '14999'

✗ Prompt: 6578+9322= final answer: 
  Correct: 15900
  Model output: '16890

Here's how to solve'
  Extracted: '16890'

✗ Prompt: 7949+3433= final answer: 
  Correct: 11382
  Model output: '**Answer:** 7382'
  Extracted: '7382'

✗ Prompt: 7420+2184= final answer: 
  Correct: 9604
  Model output: '**Answer:** 4604'
  Extracted: '4604'

✗ Prompt: 7396+9666= final answer: 
  Correct: 170

In [25]:
# Experiment: the addition.txt dataset with the "Answer directly: " prompt,
# limited to 8 new tokens per answer.
prefix = "Answer directly: "
postfix = " "
results = batch_quiz_model("./data/addition.txt", max_new_tokens=8 ,batch_size=batch_size, start_batch=start_batch, end_batch=end_batch, prefix=prefix, postfix=postfix)

# Ground truth for exactly the same dataset slice (same batch_size/start/end).
correct_answers = benchmark_correct_answers("./data/addition.txt", batch_size=batch_size, start_batch=start_batch, end_batch=end_batch)

# Calculate accuracy with skipping
# (answers without any number are skipped and excluded from the denominator)
accuracy, correct_count, total_count, detailed_results, skipped_count = calculate_accuracy(results, correct_answers)

# Print report including skipped count
print_accuracy_report(accuracy, correct_count, total_count, detailed_results, skipped_count, show_errors_only=True, max_examples=10)


Processing batches 0 to 0 (512 problems) in batches of 512...
Total dataset size: 10000 problems (20 total batches)
Processed 512/512 prompts
Calculated correct answers for batches 0 to 0 (512 problems)
Accuracy: 0.986 (492/499)
Correct: 492, Incorrect: 7
Skipped (insufficient digits): 13
--------------------------------------------------------------------------------
✗ Prompt: Answer directly: 0+29= 
  Correct: 29
  Model output: 'Answer: 30'
  Extracted: '30'

✗ Prompt: Answer directly: 0+39= 
  Correct: 39
  Model output: 'Answer: 40'
  Extracted: '40'

✗ Prompt: Answer directly: 0+49= 
  Correct: 49
  Model output: '**Answer:** 50'
  Extracted: '50'

✗ Prompt: Answer directly: 0+59= 
  Correct: 59
  Model output: 'Answer: 60'
  Extracted: '60'

✗ Prompt: Answer directly: 0+79= 
  Correct: 79
  Model output: '**Answer:** 80'
  Extracted: '80'

✗ Prompt: Answer directly: 0+89= 
  Correct: 89
  Model output: '**Answer:** 99'
  Extracted: '99'

✗ Prompt: Answer directly: 0+99= 
  Corre

In [26]:
# Experiment: the subtraction.txt dataset with the "Answer directly: " prompt,
# limited to 20 new tokens per answer.
prefix = "Answer directly: "
postfix = " "
dataset = "./data/subtraction.txt"
results = batch_quiz_model(dataset, max_new_tokens=20, batch_size=batch_size, start_batch=start_batch, end_batch=end_batch, prefix=prefix, postfix=postfix)

# Ground truth for exactly the same dataset slice (same batch_size/start/end).
correct_answers = benchmark_correct_answers(dataset, batch_size=batch_size, start_batch=start_batch, end_batch=end_batch)

# Calculate accuracy with skipping
# (answers without any number are skipped and excluded from the denominator)
accuracy, correct_count, total_count, detailed_results, skipped_count = calculate_accuracy(results, correct_answers)

# Print report including skipped count
print_accuracy_report(accuracy, correct_count, total_count, detailed_results, skipped_count, show_errors_only=True, max_examples=10)

Processing batches 0 to 0 (512 problems) in batches of 512...
Total dataset size: 10000 problems (20 total batches)
Processed 512/512 prompts
Calculated correct answers for batches 0 to 0 (512 problems)
Accuracy: 0.462 (163/353)
Correct: 163, Incorrect: 190
Skipped (insufficient digits): 159
--------------------------------------------------------------------------------
✗ Prompt: Answer directly: 0-8= 
  Correct: -8
  Model output: '**Answer:** 8'
  Extracted: '8'

✗ Prompt: Answer directly: 0-15= 
  Correct: -15
  Model output: '**Answer:** 15'
  Extracted: '15'

✗ Prompt: Answer directly: 0-18= 
  Correct: -18
  Model output: '**Answer:** 18'
  Extracted: '18'

✗ Prompt: Answer directly: 0-20= 
  Correct: -20
  Model output: '**Answer:** 20'
  Extracted: '20'

✗ Prompt: Answer directly: 0-21= 
  Correct: -21
  Model output: '**Answer:** 0-21 = 21'
  Extracted: '21'

✗ Prompt: Answer directly: 0-23= 
  Correct: -23
  Model output: '**Answer:** 
-24'
  Extracted: '-24'

✗ Prompt: Answ

In [27]:
# Run all resource reports (GPU, RAM, disk) to see usage after the previous experiment.
show_gpu_usage()
show_ram_usage()
show_disk_usage()

=== GPU ===
Total: 6144.00 MB
Used:  4342.87 MB
Free:  1801.13 MB

=== RAM ===
Total: 15.85 GB
Used:  12.35 GB
Free:  3.49 GB

=== Disk ===
Total: 874.14 GB
Used:  635.67 GB
Free:  238.47 GB


In [28]:
# Experiment: the random_subtraction.txt dataset with the "Answer directly: " prompt,
# limited to 25 new tokens per answer.
prefix = "Answer directly: "
postfix = " "
dataset = "./data/random_subtraction.txt"
results = batch_quiz_model(dataset, max_new_tokens=25, batch_size=batch_size, start_batch=start_batch, end_batch=end_batch, prefix=prefix, postfix=postfix)

# Ground truth for exactly the same dataset slice (same batch_size/start/end).
correct_answers = benchmark_correct_answers(dataset, batch_size=batch_size, start_batch=start_batch, end_batch=end_batch)

# Calculate accuracy with skipping
# (answers without any number are skipped and excluded from the denominator)
accuracy, correct_count, total_count, detailed_results, skipped_count = calculate_accuracy(results, correct_answers)

# Print report including skipped count
print_accuracy_report(accuracy, correct_count, total_count, detailed_results, skipped_count, show_errors_only=True, max_examples=10)

Processing batches 0 to 0 (512 problems) in batches of 512...
Total dataset size: 10000 problems (20 total batches)
Processed 512/512 prompts
Calculated correct answers for batches 0 to 0 (512 problems)
Accuracy: 0.766 (353/461)
Correct: 353, Incorrect: 108
Skipped (insufficient digits): 51
--------------------------------------------------------------------------------
✗ Prompt: Answer directly: 6390-6191= 
  Correct: 199
  Model output: '200

**Explanation:**

You're asking for a direct answer, and I can provide that.'
  Extracted: '200'

✗ Prompt: Answer directly: 6734-7265= 
  Correct: -531
  Model output: '**Answer:** -5311'
  Extracted: '-5311'

✗ Prompt: Answer directly: 3568-6463= 
  Correct: -2895
  Model output: '**Answer:**  -995'
  Extracted: '-995'

✗ Prompt: Answer directly: 3027-3695= 
  Correct: -668
  Model output: '**Answer:** -628'
  Extracted: '-628'

✗ Prompt: Answer directly: 7278-9392= 
  Correct: -2114
  Model output: '**Answer:** -2104'
  Extracted: '-2104'

✗ 

# Ablation

In [None]:
# Activation Modification Hook
class AblationHook:
    """
    Subtract a scaled SAE reconstruction from a module's hidden states.

    The class owns one Gemma Scope SAE. `register_hook` can restrict that SAE
    to a single latent, and `activation_modification_hook` is a forward hook
    that replaces a module's hidden states with
    `hidden_states - lambda_scale * sae(hidden_states)`.
    """

    def __init__(self, layer_idx, lambda_scale,):
        """
        Initialize the ablation hook.

        Args:
            layer_idx: Index of the layer whose canonical 16k-wide SAE is loaded
            lambda_scale: Scaling factor for the single latent reconstruction

        Relies on the notebook-level `model` and `device` globals.
        """
        # Load the SAE that belongs to the requested layer
        sae_id = f"layer_{layer_idx}/width_16k/canonical"
        loaded = SAE.from_pretrained(
            release="gemma-scope-2b-pt-res-canonical",
            sae_id=sae_id
        )
        # Older sae_lens releases return (sae, cfg_dict, sparsity); newer ones
        # return only the SAE. Accept both.
        sae = loaded[0] if isinstance(loaded, tuple) else loaded

        # Match the dtype of the model's first parameter
        model_dtype = next(model.parameters()).dtype
        sae = sae.to(device=device, dtype=model_dtype)
        sae.eval()

        self.sae = sae
        self.lambda_scale = lambda_scale
        self.sae_hook_handles = []


    # Single Latent Reconstruction Hook
    def single_latent_hook(self, module, input, output, latent_idx):
        """
        Forward hook to zero all latent activations except for the specified index.

        Args:
            module: The HookPoint module (hook_sae_acts_post)
            input: Input to the hook (not used here)
            output: feature_acts tensor of shape [..., d_sae]
            latent_idx: Index of the latent to keep non-zero

        Returns:
            Modified feature_acts with only latent_idx non-zero
        """
        modified_acts = torch.zeros_like(output)
        modified_acts[..., latent_idx] = output[..., latent_idx]

        # The hook is attached to `hook_sae_acts_post`, so (as the name indicates)
        # `output` is presumably already post-activation; no activation is re-applied.
        return modified_acts


    # Register the activation modification hook
    def register_hook(self, latent_idx):
        """
        Restrict the SAE to one latent by hooking `hook_sae_acts_post`.

        Args:
            latent_idx: Index of the latent to keep; all others are zeroed

        The handle is stored so that `unregister_hook` can remove it.
        """
        self.sae_hook_handles.append(self.sae.hook_sae_acts_post.register_forward_hook(
            lambda m, i, o: self.single_latent_hook(m, i, o, latent_idx=latent_idx)
        ))

    def unregister_hook(self):
        """Remove every hook added by `register_hook` and clear the handle list."""
        if hasattr(self, 'sae_hook_handles'):
            for hook_handle in self.sae_hook_handles:
                hook_handle.remove()
        self.sae_hook_handles = []

    def activation_modification_hook(self, module, input, output):
        """
        Forward hook that subtracts the scaled SAE reconstruction from the hidden states.

        Args:
            module: The hooked module (not used here)
            input: Input to the hooked module (not used here)
            output: Either the hidden-states tensor itself, or a tuple whose
                first element is the hidden-states tensor

        Returns:
            The modified output in the same structure as `output`: a tensor if a
            tensor was given, otherwise a tuple with only element 0 replaced.
        """
        # A forward hook's return value replaces the module output, so the
        # original structure has to be preserved for downstream code.
        is_tuple = isinstance(output, tuple)
        hidden_states = output[0] if is_tuple else output

        # Compute the (possibly single-latent) reconstruction on the current hidden_states
        with torch.no_grad():
            reconstruction_w_hooks = self.sae(hidden_states)

        # Cast back so the residual stream keeps its original dtype
        delta = (self.lambda_scale * reconstruction_w_hooks).to(hidden_states.dtype)
        modified_hidden_states = hidden_states - delta

        if is_tuple:
            return (modified_hidden_states,) + output[1:]
        return modified_hidden_states

    def run_encoding_decoding(self, hidden_states):
        """
        Encode the hidden states, reconstruct them, and compute the reconstruction loss.

        Any hook added via `register_hook` is active during both calls.

        Args:
            hidden_states: The hidden states to encode and reconstruct

        Returns:
            tuple: (reconstruction, latents, loss)
                - reconstruction: Output of the SAE forward pass on hidden_states
                - latents: Output of `sae.encode(hidden_states)`
                - loss: Mean squared error between reconstruction and hidden_states
        """
        with torch.no_grad():
            latents = self.sae.encode(hidden_states)
            # Full forward pass (encodes again internally) to get the reconstruction
            reconstruction = self.sae(hidden_states)
            # Calculate the mean squared error loss
            loss = F.mse_loss(reconstruction, hidden_states)
        return reconstruction, latents, loss


In [None]:
# Environment settings
import os

# Write both switches in a single call.
# NOTE(review): torch is already imported in the first cell; confirm these
# variables are still read when set this late.
os.environ.update({"TORCHDYNAMO_DISABLE": "1", "USE_TRITON": "0"})

layer_idx = 12  # Layer to modify

In [None]:
prompt = "5+7= "
inputs = tokenizer.encode(prompt, return_tensors="pt", add_special_tokens=True).to(device)
print(f"Input tokens: {inputs}")

# Get original hidden states
with torch.no_grad():
    outputs = model(inputs, output_hidden_states=True)
# `outputs.hidden_states` is a tuple, so it has a length but no `.shape`.
print(len(outputs.hidden_states))

# The tuple holds the embedding output first and then one entry per layer, so the
# output of layer `layer_idx` sits at index `layer_idx + 1`.
# NOTE(review): this assumes the SAE was trained on the layer's output residual
# stream -- confirm against the SAE release being used.
original_hidden_states = outputs.hidden_states[layer_idx + 1]
print(f"Original hidden states shape: {original_hidden_states.shape}")

In [None]:
# Size of the first axis of the hidden-state tensor stored at index `layer_idx`.
print(outputs.hidden_states[layer_idx].shape[0])

In [None]:
# Build the hook for the layer configured in `layer_idx`, so the SAE cannot
# drift from the rest of the notebook if that setting is changed.
ablation_hook = AblationHook(layer_idx=layer_idx, lambda_scale=1.0)

# Baseline: full SAE reconstruction, no latent restriction.
reconstruction, latents, loss = ablation_hook.run_encoding_decoding(original_hidden_states)

# From here on the SAE keeps only latent 11301 (the hook stays registered).
ablation_hook.register_hook(latent_idx=11301)
ablation_reconstruction, ablation_latents, ablation_loss = ablation_hook.run_encoding_decoding(original_hidden_states)

print(reconstruction.shape)
print(f"Reconstruction loss: {loss.item()}")
print(f"Ablation reconstruction shape: {ablation_reconstruction.shape}")
print(f"Ablation reconstruction loss: {ablation_loss.item()}")
print(f"Latents shape: {latents.shape}")

In [None]:
# Baseline activations of latent 11301: index 0 on the first axis, every position on the second.
# (presumably [batch, token, d_sae] -- TODO confirm against the printed latents shape)
latents[0,:, 11301] 

In [None]:
# Same slice taken from the latents computed after the single-latent hook was registered.
ablation_latents[0,:, 11301]

## Quick check: are latents zero for all tokens except the first?

In [None]:
# Print the kernel's current working directory (the relative latents path below resolves against it).
import os
print(f"Current directory: {os.getcwd()}")

In [None]:
from scipy import sparse
# Load the saved sparse latent matrix for layer 12; the path is relative to the working directory.
latents_npz_path = r"../../latents/addition/layer_12.npz"
latents_npz = sparse.load_npz(latents_npz_path)
print(latents_npz.shape)  # Shape of the loaded sparse matrix


In [None]:
print(latents_npz[:,11301].toarray())  # Dense view of column 11301 across all rows (presumably latent 11301 -- TODO confirm)

# Misc


In [None]:
def quick_generation(prompt, max_new_tokens=12):
    """
    Greedily generate a short continuation for a single prompt.

    Args:
        prompt: Input string to complete
        max_new_tokens: Upper bound on the number of generated tokens

    Returns:
        The decoded continuation (prompt tokens removed), stripped of
        surrounding whitespace. The raw text is returned; no number is extracted.
    """
    prompt_ids = tokenizer.encode(prompt, return_tensors="pt", add_special_tokens=True).to(device)
    print(f"Input tokens: {prompt_ids}")

    generation_kwargs = dict(
        max_new_tokens=max_new_tokens,          # Limit tokens for just the number
        do_sample=False,                        # Deterministic output
        repetition_penalty=1.0,                 # Avoid penalizing repeated digits
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        num_beams=1,                            # Single beam for simplicity
    )
    with torch.no_grad():
        generated = model.generate(prompt_ids, **generation_kwargs)

    # Drop the echoed prompt so only the newly generated tokens are decoded
    prompt_length = len(prompt_ids[0])
    continuation = generated[0][prompt_length:]
    return tokenizer.decode(continuation, skip_special_tokens=True).strip()
    
def quiz_model(dataset: str):
    """
    Quiz the model on a dataset of arithmetic problems.

    Args:
        dataset: Path to a text file with one problem per line

    Returns:
        List of (prompt, answer) tuples, one per non-blank line, where prompt
        is the instruction prefix followed by the stripped problem text.
    """
    instruction = "You are a calculator. Answer immediately: "

    results = []
    with open(dataset, 'r') as f:
        for line in f:
            problem = line.strip()
            # Skip blank lines *before* adding the instruction: once the prefix
            # is prepended the string is never empty, so a later check cannot
            # filter them out.
            if not problem:
                continue
            prompt = instruction + problem
            answer = quick_generation(prompt, max_new_tokens=10)
            results.append((prompt, answer))
            print(answer)
    return results
# quiz_model("./data/addition.txt")