In [1]:
reciever_heads = [(29, 24),
 (17, 0),
 (24, 7),
 (25, 22),
 (23, 8),
 (18, 12),
 (23, 23),
 (21, 4),
 (19, 17),
 (18, 14),
 (30, 17),
 (19, 27),
 (28, 22),
 (1, 17),
 (27, 1),
 (24, 1),
 (26, 10),
 (26, 24),
 (1, 16),
 (24, 5)]

In [2]:
import os

import json

def process_problem_data(base_path):

    """

    Iterates through all problem directories, extracts problem statements

    and sentences from `chunks_labeled.json`, and returns a list of dictionaries.


    Args:

        base_path (str): The path to the directory containing all the problems

                         (e.g., 'math-rollouts/.../correct_base_solution').


    Returns:

        list: A list of dictionaries, where each dictionary contains the problem

              and all sentences for a given problem directory.

    """

    all_problem_data = []


    # Check if the base path exists

    if not os.path.isdir(base_path):

        print(f"Error: The directory '{base_path}' was not found.")

        return all_problem_data

    print(f"Found problem directory: {base_path}")


    # List all entries in the base directory

    problem_dirs = [d for d in os.listdir(base_path) if os.path.isdir(os.path.join(base_path, d))]

    print(f"Found problem directory: {problem_dirs}")


    if not problem_dirs:

        print(f"No problem directories found in '{base_path}'.")

        return all_problem_data


    # Iterate through each problem directory (e.g., problem_330, problem_1591)

    for problem_name in problem_dirs:

        problem_path = os.path.join(base_path, problem_name)

       

        # Define the file paths for the problem and chunks

        problem_file = os.path.join(problem_path, "problem.json")

        chunks_file = os.path.join(problem_path, "chunks_labeled.json")

       

        problem_text = ""

        allsentences = []

       

        # Load the problem statement

        try:

            with open(problem_file, 'r') as f:

                problem_data = json.load(f)

                problem_text = problem_data.get("problem", "")
                problem_answer = problem_data.get("gt_answer", "")
                

        except (FileNotFoundError, json.JSONDecodeError) as e:

            print(f"Skipping {problem_name}: Could not load problem.json. Error: {e}")

            continue


        # Load all sentences from chunks_labeled.json

        try:

            with open(chunks_file, 'r') as f:

                chunks_data = json.load(f)

                allsentences = [chunk["chunk"] for chunk in chunks_data]

        except (FileNotFoundError, json.JSONDecodeError) as e:

            print(f"Skipping {problem_name}: Could not load chunks_labeled.json. Error: {e}")

            continue


        # Create a dictionary to store the extracted data

        problem_info = {

            "problem_id": problem_name,

            "problem_statement": problem_text,

            "sentences": allsentences,
            "answer": problem_answer

        }

        all_problem_data.append(problem_info)


    return all_problem_data

    print("No data was loaded.")




# Define the base directory for all problems

base_problem_dir = "math-rollouts/deepseek-r1-distill-llama-8b/temperature_0.6_top_p_0.95/correct_base_solution"

# Run the function to get all the data

correct_all_prompt = process_problem_data(base_problem_dir)


# Now, `all_data` is a list of dictionaries. You can iterate through it.

print(f"Successfully loaded data for {len(correct_all_prompt)} problems.")

# Define the base directory for all problems

base_problem_dir = "math-rollouts/deepseek-r1-distill-llama-8b/temperature_0.6_top_p_0.95/incorrect_base_solution"

# Run the function to get all the data

incorrect_all_prompt = process_problem_data(base_problem_dir)


print(f"Successfully loaded data for {len(incorrect_all_prompt)} problems.")

Found problem directory: math-rollouts/deepseek-r1-distill-llama-8b/temperature_0.6_top_p_0.95/correct_base_solution
Found problem directory: ['problem_6481', 'problem_4682', 'problem_3360', 'problem_4605', 'problem_2236', 'problem_1591', 'problem_4164', 'problem_2189', 'problem_2238', 'problem_3935', 'problem_6596', 'problem_3550', 'problem_2870', 'problem_4019', 'problem_2050', 'problem_6998', 'problem_3916', 'problem_2137', 'problem_3448', 'problem_330']
Successfully loaded data for 20 problems.
Found problem directory: math-rollouts/deepseek-r1-distill-llama-8b/temperature_0.6_top_p_0.95/incorrect_base_solution
Found problem directory: ['problem_4019', 'problem_2870', 'problem_3550', 'problem_3935', 'problem_6596', 'problem_2238', 'problem_2189', 'problem_4164', 'problem_1591', 'problem_2236', 'problem_4605', 'problem_3360', 'problem_6481', 'problem_4682', 'problem_3448', 'problem_2137', 'problem_330', 'problem_6998', 'problem_3916', 'problem_2050']
Successfully loaded data for 20 

In [3]:
all_prompt = correct_all_prompt[:2] + incorrect_all_prompt[:2]

In [4]:
import os
import json

def process_problem_labels_minimal(base_path):
    """
    Iterates through all problem directories, extracts selected fields from each chunk
    in `chunks_labeled.json`, and returns a list of dictionaries.

    Args:
        base_path (str): The path to the directory containing all the problems.

    Returns:
        list: A list of dictionaries, each with problem_id and a list of selected chunk fields.
    """
    all_problem_data = []

    if not os.path.isdir(base_path):
        print(f"Error: The directory '{base_path}' was not found.")
        return all_problem_data

    problem_dirs = [d for d in os.listdir(base_path) if os.path.isdir(os.path.join(base_path, d))]
    if not problem_dirs:
        print(f"No problem directories found in '{base_path}'.")
        return all_problem_data

    for problem_name in problem_dirs:
        problem_path = os.path.join(base_path, problem_name)
        chunks_file = os.path.join(problem_path, "chunks_labeled.json")
        if not os.path.isfile(chunks_file):
            continue

        try:
            with open(chunks_file, "r") as f:
                chunks_data = json.load(f)
        except Exception as e:
            print(f"Skipping {problem_name}: {e}")
            continue

        # Extract only the fields you care about from each chunk
        selected_chunks = []
        for chunk in chunks_data:
            selected = {
                "function_tags": chunk.get("function_tags"),
                "chunk": chunk.get("chunk"),
                "accuracy": chunk.get("accuracy"),
                "resampling_importance_accuracy": chunk.get("resampling_importance_accuracy"),
                "resampling_importance_kl": chunk.get("resampling_importance_kl"),
                "counterfactual_importance_accuracy": chunk.get("counterfactual_importance_accuracy"),
                "counterfactual_importance_kl": chunk.get("counterfactual_importance_kl"),
                
                "summary": chunk.get("summary"),
            }
            selected_chunks.append(selected)

        all_problem_data.append({
            "problem_id": problem_name,
            "chunks": selected_chunks
        })

    return all_problem_data

# Example usage:
base_problem_dir = "math-rollouts/deepseek-r1-distill-llama-8b/temperature_0.6_top_p_0.95/correct_base_solution"
correct_all_labels = process_problem_labels_minimal(base_problem_dir)
print(f"Loaded {len(correct_all_labels)} problems with selected chunk fields.")
base_problem_dir = "math-rollouts/deepseek-r1-distill-llama-8b/temperature_0.6_top_p_0.95/incorrect_base_solution"
incorrect_all_labels = process_problem_labels_minimal(base_problem_dir)
print(f"Loaded {len(incorrect_all_labels)} problems with selected chunk fields.")

Loaded 20 problems with selected chunk fields.
Loaded 20 problems with selected chunk fields.


In [5]:
all_labels = correct_all_labels[:2] + incorrect_all_labels[:2]

In [6]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification, BitsAndBytesConfig, AutoModelForCausalLM, pipeline

import torch


model_name = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B" # Or any other suitable model

mname = model_name

tokenizer = AutoTokenizer.from_pretrained(model_name)

# Important: Add a pad token if the tokenizer doesn't have one, especially for decoder models.

if tokenizer.pad_token is None:

    tokenizer.add_special_tokens({'pad_token': tokenizer.eos_token})

  from .autonotebook import tqdm as notebook_tqdm


In [7]:
# special_tokens_dict = {'additional_special_tokens': ['</think>']}
# tokenizer.add_special_tokens(special_tokens_dict)
# # Get the token ID for </think> for use in generation
# think_end_token_id = tokenizer.convert_tokens_to_ids('</think>')
# stop_token_list = [tokenizer.eos_token_id, think_end_token_id]

In [8]:
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto")

Loading checkpoint shards: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2/2 [00:02<00:00,  1.04s/it]


In [9]:
from sentence_transformers import SentenceTransformer

# Initialize both embedding methods
sentence_model = SentenceTransformer('all-MiniLM-L6-v2')

In [10]:
#want do a loop through all the chunks append it to the current text set it through multiple rollouts and measure counterfactual importance

In [11]:
import torch
import torch.nn.functional as F
import numpy as np
import re

In [12]:
def contains_answer(text: str, ground_truth_answer: str):
    """
    Extract answers and check if they match the ground truth.
    
    Args:
        text: Text to extract boxed answers from
        ground_truth_answer: The correct answer to compare against
        
    Returns:
        tuple: (list of answers, True if any answer matches ground truth)
    """
    return ground_truth_answer in text

In [13]:
def get_probs(model, inputs):
    """Get probability distribution from model."""
    device = next(model.parameters()).device
    inputs = {k: v.to(device) for k, v in inputs.items()}
    
    with torch.no_grad():
        logits = model(**inputs).logits[0, -1, :]  # Last token logits
        probs = torch.softmax(logits, dim=-1)
    
    return probs

In [14]:
def get_embeddings(model, inputs, layer=-1):
    """Get hidden state embeddings from a specific layer."""
    device = next(model.parameters()).device
    inputs = {k: v.to(device) for k, v in inputs.items()}
    
    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)
        # Get embeddings and convert to float32 on GPU
        embeddings = outputs.hidden_states[layer][0, -1, :].float()  # Convert to float32
    del outputs
    return embeddings  # Keep on GPU


In [15]:
def get_sentence_embeddings(sentence_model, inputs, layer=-1):
    """Get sentence embeddings using SentenceTransformer."""
    embedding = sentence_model.encode(inputs, convert_to_tensor=True)
    return embedding.float()  # Ensure float32

In [16]:
def get_both_embeddings(llm_model, sentence_model, tokenizer, text):
    """Get both LLM and sentence transformer embeddings for comparison."""
    
    # LLM embedding (your current method)
    llm_inputs = tokenizer(text, return_tensors="pt", max_length=2048, truncation=True)
    llm_embedding = get_embeddings(llm_model, llm_inputs)
    
    # Sentence transformer embedding
    sent_embedding = get_sentence_embeddings(sentence_model, text)

    # MOVE TO CPU IMMEDIATELY to free GPU memory
    llm_embedding_cpu = llm_embedding.cpu()
    sent_embedding_cpu = sent_embedding.cpu()
    
    # Clean up GPU tensors
    del llm_embedding, sent_embedding
    return llm_embedding_cpu, sent_embedding_cpu

In [17]:
def get_both_embeddings_batch(llm_model, sentence_model, tokenizer, texts):
    """Get both LLM and sentence transformer embeddings for multiple texts in batches."""
    #texts list
    
    # Batch LLM embeddings
    llm_inputs = tokenizer(texts, return_tensors="pt", max_length=2048, truncation=True, padding=True)
    device = next(llm_model.parameters()).device
    llm_inputs = {k: v.to(device) for k, v in llm_inputs.items()}
    
    with torch.no_grad():
        outputs = llm_model(**llm_inputs, output_hidden_states=True)
        # Get last token embeddings for each sequence
        llm_embeddings = outputs.hidden_states[-1][:, -1, :].float().cpu()
    
    # Batch sentence transformer embeddings
    sent_embeddings = sentence_model.encode(texts, convert_to_tensor=True, batch_size=32).float().cpu()
    
    # Clean up GPU memory
    del outputs, llm_inputs
    
    return llm_embeddings, sent_embeddings

In [18]:
import gc

In [19]:
from typing import List, Dict

In [20]:
def extract_boxed_answers(text: str) -> List[str]:
    """
    Extract answers enclosed in \boxed{} from the text with improved handling
    of nested braces and complex LaTeX expressions.

    Args:
        text: The text to extract boxed answers from

    Returns:
        List of extracted boxed answers
    """
    # Find all occurrences of \boxed{
    boxed_starts = [m.start() for m in re.finditer(r"\\boxed\{", text)]

    if not boxed_starts:
        return [""]

    answers = []

    for start_idx in boxed_starts:
        # Start after \boxed{
        idx = start_idx + 7
        brace_count = 1  # We've already opened one brace
        answer = ""

        # Parse until we find the matching closing brace
        while idx < len(text) and brace_count > 0:
            char = text[idx]

            if char == "{":
                brace_count += 1
            elif char == "}":
                brace_count -= 1

                # Skip the closing brace of \boxed{}
                if brace_count == 0:
                    break

            if brace_count > 0:  # Only add if we're still inside the boxed content
                answer += char

            idx += 1

        if answer:
            answers.append(answer)

    return answers if answers else [""]

In [21]:
def normalize_answer(answer: str, use_sympy: bool = False) -> str:
    """
    Get the final normalized and cleaned version of an answer.
    This function combines all normalization steps used in check_answer.

    Args:
        answer: The answer string to normalize
        use_sympy: Whether to use sympy to normalize the answer

    Returns:
        The normalized answer string
    """
    # First apply basic LaTeX normalization
    normalized = normalize_latex(answer)

    # Also prepare the answer for sympy if applicable
    if use_sympy:
        try:
            sympy_ready = prepare_latex_for_sympy(answer)
            if sympy_ready != normalized and len(sympy_ready) > 0:
                return sympy_ready
        except Exception:
            pass

    return normalized

In [22]:
def get_latex_equivalent(answer0, answer1):
    """
    Check if two LaTeX expressions are mathematically equivalent using SymPy.

    Args:
        answer0: First LaTeX expression
        answer1: Second LaTeX expression

    Returns:
        True if expressions are mathematically equivalent, False otherwise
    """
    try:
        from sympy.parsing.latex import parse_latex
        import sympy

        # Clean up the LaTeX expressions for parsing
        answer0 = prepare_latex_for_sympy(answer0)
        answer1 = prepare_latex_for_sympy(answer1)

        # Parse the LaTeX expressions
        expr1 = parse_latex(answer0)
        expr2 = parse_latex(answer1)

        # Check if they are mathematically identical
        equals = expr1.equals(expr2)
        # print(f"First: {answer0}, Second: {answer1}: equals={equals}")
        return equals
    except Exception as e:
        # print(f"Error comparing expressions: {e}")
        return False


def prepare_latex_for_sympy(latex_str):
    """
    Prepare a LaTeX string for SymPy parsing by removing unsupported commands
    and simplifying the expression.
    """
    if not isinstance(latex_str, str):
        return str(latex_str)

    # Remove \boxed{} command
    latex_str = re.sub(r"\\boxed\{(.*?)\}", r"\1", latex_str)

    # Replace common LaTeX commands that SymPy doesn't support
    replacements = {
        r"\\dfrac": r"\\frac",
        r"\\tfrac": r"\\frac",
        r"\\cdot": r"*",
        r"\\times": r"*",
        r"\\div": r"/",
        r"\\left": r"",
        r"\\right": r"",
        r"\\textbf": r"",
        r"\\text": r"",
        r"\\mathrm": r"",
        r"\\!": r"",
        r",": r"",
    }

    for old, new in replacements.items():
        latex_str = re.sub(old, new, latex_str)

    return latex_str

In [23]:
def normalize_latex(latex_str: str) -> str:
    """
    Normalize LaTeX string by applying various transformations.

    Args:
        latex_str: The LaTeX string to normalize

    Returns:
        Normalized LaTeX string
    """
    normalized = latex_str.strip().lower()

    # Replace different fraction notations
    normalized = normalized.replace("dfrac", "frac")
    normalized = normalized.replace("tfrac", "frac")

    # Normalize spaces
    normalized = re.sub(r"\s+", "", normalized)

    # Normalize percentages
    normalized = normalized.replace("\\%", "")

    # Normalize funny commas
    normalized = normalized.replace("{,}", "")

    # Normalize common mathematical notations
    normalized = normalized.replace("\\times", "*")
    normalized = normalized.replace("\\cdot", "*")

    # Normalize decimal representation
    normalized = re.sub(r"(\d+)[\.,](\d+)", r"\1.\2", normalized)

    # Remove unnecessary braces in simple expressions
    normalized = re.sub(r"{([^{}]+)}", r"\1", normalized)

    # Normalize common constants
    normalized = normalized.replace("\\pi", "pi")

    # Remove LaTeX text commands
    normalized = re.sub(r"\\text\{([^{}]+)\}", r"\1", normalized)
    normalized = re.sub(r"\\mathrm\{([^{}]+)\}", r"\1", normalized)

    # Normalize date formats (e.g., "October 30" vs "October\\ 30")
    normalized = re.sub(r"([a-z]+)\\+\s*(\d+)", r"\1\2", normalized)
    normalized = normalized.replace("\\text", "")

    return normalized

In [24]:
def prepare_latex_for_sympy(latex_str):
    """
    Prepare a LaTeX string for SymPy parsing by removing unsupported commands
    and simplifying the expression.
    """
    if not isinstance(latex_str, str):
        return str(latex_str)

    # Remove \boxed{} command
    latex_str = re.sub(r"\\boxed\{(.*?)\}", r"\1", latex_str)

    # Replace common LaTeX commands that SymPy doesn't support
    replacements = {
        r"\\dfrac": r"\\frac",
        r"\\tfrac": r"\\frac",
        r"\\cdot": r"*",
        r"\\times": r"*",
        r"\\div": r"/",
        r"\\left": r"",
        r"\\right": r"",
        r"\\textbf": r"",
        r"\\text": r"",
        r"\\mathrm": r"",
        r"\\!": r"",
        r",": r"",
    }

    for old, new in replacements.items():
        latex_str = re.sub(old, new, latex_str)

    return latex_str

In [25]:
def check_answer(answer: str, gt_answer: str) -> bool:
    """
    Check if the generated answer matches the ground truth answer
    after normalizing LaTeX formatting.

    Args:
        answer: The generated answer to check
        gt_answer: The ground truth answer to compare against

    Returns:
        True if the answers match after normalization, False otherwise
    """
    # Normalize both answers
    normalized_answer = normalize_latex(answer)
    normalized_gt_answer = normalize_latex(gt_answer)

    # First check if normalized strings match
    if normalized_answer == normalized_gt_answer:
        return True

    # # If string comparison fails, try mathematical equivalence
    # try:
    #     return get_latex_equivalent(answer, gt_answer)
    # except Exception as e:
    #     # If SymPy parsing fails, fall back to string comparison result
    return False


In [26]:
def split_solution_into_chunks(solution_text: str) -> List[str]:
    """
    Split a solution into chunks for rollout generation.

    Args:
        solution_text: The full solution text

    Returns:
        List of chunks
    """
    # First, remove the prompt part if present
    if "<think>" in solution_text:
        solution_text = solution_text.split("<think>")[1].strip()

    # Remove the closing tag if present
    if "</think>" in solution_text:
        solution_text = solution_text.split("</think>")[0].strip()

    # Define patterns for chunk boundaries
    sentence_ending_tokens = [".", "?", "!"]
    paragraph_ending_patterns = ["\n\n", "\r\n\r\n"]

    # Split the text into chunks
    chunks = []
    current_chunk = ""

    # Process the text character by character
    i = 0
    while i < len(solution_text):
        current_chunk += solution_text[i]

        # Check for paragraph endings
        is_paragraph_end = False
        for pattern in paragraph_ending_patterns:
            if (
                i + len(pattern) <= len(solution_text)
                and solution_text[i : i + len(pattern)] == pattern
            ):
                is_paragraph_end = True
                break

        # Check for sentence endings followed by space or newline
        is_sentence_end = False
        if i < len(solution_text) - 1 and solution_text[i] in sentence_ending_tokens:
            next_char = solution_text[i + 1]
            if next_char == " " or next_char == "\n":
                is_sentence_end = True

        # If we found a boundary, add the chunk and reset
        if is_paragraph_end or is_sentence_end:
            if current_chunk.strip():
                chunks.append(current_chunk.strip())
                current_chunk = ""

        i += 1

    # # Add the last chunk if not empty
    # if current_chunk.strip():
    #     chunks.append(current_chunk.strip())
    #     chunk_idxs.append(len(solution_text) - 1)  # Add last index

    # Merge small chunks (less than 10 characters)
    i = 0
    while i < len(chunks):
        if len(chunks[i]) < 10:
            # If this is the last chunk, merge with previous chunk if possible
            if i == len(chunks) - 1:
                if i > 0:
                    chunks[i - 1] = chunks[i - 1] + " " + chunks[i]
                    chunks.pop(i)
            # Otherwise merge with the next chunk
            else:
                chunks[i + 1] = chunks[i] + " " + chunks[i + 1]
                chunks.pop(i)
                # Don't increment i since we need to check the new merged chunk
            # If we're at the beginning and there's only one chunk, just keep it
            if i == 0 and len(chunks) == 1:
                break
        else:
            i += 1

    # chunk_boundaries = [(chunk_idxs[i], chunk_idxs[i + 1]) for i in range(len(chunk_idxs) - 1)]
    # chunk_boundaries.append((chunk_idxs[-1], len(solution_text)))

    # if get_idxs:
    #     return chunks, chunk_boundaries
    # else:
    return chunks


In [27]:
def calculate_grouped_answer_kl_divergence(rollout_answer_correct, cos_sims, similarity_threshold=0.8):
    """
    Calculate KL divergence between answer correctness distributions for similar vs dissimilar rollouts.
    This follows the thought-anchors approach.
    
    Args:
        rollout_answer_correct: List of boolean correctness for each rollout
        cos_sims: Tensor of cosine similarities for each rollout
        similarity_threshold: Threshold for determining similar vs dissimilar
        
    Returns:
        dict: Contains KL divergences and group statistics
    """
    # Ensure cos_sims is a tensor
    if not torch.is_tensor(cos_sims):
        cos_sims = torch.tensor(cos_sims, dtype=torch.float32)
    
    
    # Separate rollouts into similar and dissimilar groups
    similar_mask = cos_sims > similarity_threshold
    dissimilar_mask = ~similar_mask
    
    similar_correctness = [rollout_answer_correct[i] for i in range(len(rollout_answer_correct)) if similar_mask[i]]
    dissimilar_correctness = [rollout_answer_correct[i] for i in range(len(rollout_answer_correct)) if dissimilar_mask[i]]
    
    if len(similar_correctness) == 0 or len(dissimilar_correctness) == 0:
        return {
            "kl_divergence": 0.0,
            "similar_group_size": len(similar_correctness),
            "dissimilar_group_size": len(dissimilar_correctness),
            "similar_accuracy": 0.0,
            "dissimilar_accuracy": 0.0
        }
    
    # Calculate accuracy rates for each group
    similar_accuracy = sum(similar_correctness) / len(similar_correctness)
    dissimilar_accuracy = sum(dissimilar_correctness) / len(dissimilar_correctness)
    
    # Create probability distributions
    # Similar distribution: [P(wrong), P(correct)]
    similar_dist = torch.tensor([1 - similar_accuracy, similar_accuracy], dtype=torch.float32)
    
    # Dissimilar distribution: [P(wrong), P(correct)]
    dissimilar_dist = torch.tensor([1 - dissimilar_accuracy, dissimilar_accuracy], dtype=torch.float32)
    
    # Add small epsilon to avoid log(0)
    eps = 1e-8
    similar_dist = similar_dist + eps
    dissimilar_dist = dissimilar_dist + eps
    
    # Normalize to ensure they sum to 1
    similar_dist = similar_dist / similar_dist.sum()
    dissimilar_dist = dissimilar_dist / dissimilar_dist.sum()
    
    # Calculate KL divergence: KL(P_dissimilar || P_similar)
    # This measures how much the dissimilar group diverges from similar group
    kl_div = torch.sum(dissimilar_dist * torch.log(dissimilar_dist / similar_dist))
    
    return {
        "kl_divergence": float(kl_div),
        "similar_group_size": len(similar_correctness),
        "dissimilar_group_size": len(dissimilar_correctness),
        "similar_accuracy": similar_accuracy,
        "dissimilar_accuracy": dissimilar_accuracy
    }

In [28]:
# def get_max_tokens_for_sentence(sentence_idx):
#     """Dynamic token allocation based on sentence position."""
#     if sentence_idx < 20:
#         return 6000  # Full CoT for very early sentences
#     elif sentence_idx < 40:
#         return 5000  # Slightly reduced
#     elif sentence_idx < :
#         return 4500  # Moderate reduction
#     else:
#         return 4000  # Conservative for very late sentences

In [29]:
def generate_diverse_rollouts(model, tokenizer, ground_truth_answer, context, num_rollouts=10, batch_size=5, temperature=0.6, top_p=0.95, sentence_idx=0):
    """Generate diverse text completions in batches for better GPU utilization."""

    # START: Clean memory before beginning (NEW)
    torch.cuda.empty_cache()
    gc.collect()

    device = next(model.parameters()).device
    inputs = tokenizer(context, return_tensors="pt", max_length=1500, truncation=True)
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Dynamic token allocation based on sentence position
    max_new_tokens = 6000

    # Add </think> as a stop sequence
    stop_token_ids = tokenizer.encode("</think>", add_special_tokens=False)
    rollout_texts = []
    rollout_answer_correct = [] # this uses contains_answer
    rollout_answer_correct_check = []  # New list using check_answer
    
    
    # Process in batches
    for batch_start in range(0, num_rollouts, batch_size):
        batch_end = min(batch_start + batch_size, num_rollouts)
        current_batch_size = batch_end - batch_start
        
        # Expand inputs for batch processing
        batch_inputs = {
            'input_ids': inputs['input_ids'].repeat(current_batch_size, 1),
            'attention_mask': inputs['attention_mask'].repeat(current_batch_size, 1)
        }
        
        with torch.no_grad():
            outputs = model.generate(
                batch_inputs['input_ids'],
                attention_mask=batch_inputs['attention_mask'],
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=temperature,
                top_p=top_p,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id,
                output_hidden_states=True,
                return_dict_in_generate=True,
            )

        # Process each sequence in the batch
        batch_generated_ids = outputs.sequences[:, inputs['input_ids'].shape[1]:]
        batch_generated_texts = tokenizer.batch_decode(batch_generated_ids, skip_special_tokens=True)
        # 2. BATCH STRIP ALL TEXTS
        batch_texts_stripped = [text.strip() for text in batch_generated_texts]
        rollout_texts.extend(batch_texts_stripped)
        
        # 3. BATCH CHECK ANSWER CORRECTNESS (contains_answer)
        batch_contains_answer = [contains_answer(text, ground_truth_answer) for text in batch_texts_stripped]
        rollout_answer_correct.extend(batch_contains_answer)
        
        # 4. BATCH EXTRACT AND CHECK BOXED ANSWERS
        batch_boxed_answers = [extract_boxed_answers(text) for text in batch_texts_stripped]
        batch_correct_check = [
            any(check_answer(answer, ground_truth_answer) for answer in boxed_answers)
            for boxed_answers in batch_boxed_answers
        ]
        rollout_answer_correct_check.extend(batch_correct_check)
                
        # ADD THIS: Clean up after each batch
        del outputs, batch_inputs, batch_generated_ids, batch_generated_texts
        del batch_texts_stripped, batch_contains_answer, batch_boxed_answers, batch_correct_check
        gc.collect()
        torch.cuda.empty_cache()
    
    
    # Final cleanup
    del inputs
    torch.cuda.empty_cache()
    
    # Stack embeddings into a single tensor on GPU
    return rollout_texts, rollout_answer_correct, rollout_answer_correct_check


In [30]:
# def process_embedding():
#     rollout_llm_embeddings = []
#     rollout_sentence_embeddings = []
#     rollout_sentences = []
#     for rollout_text in rollout_texts:
#         rollout_resampled = split_solution_into_chunks(rollout_text)[0]
#         llm_embedding, sent_embedding = get_both_embeddings(model, sentence_model, tokenizer, rollout_resampled)
#         rollout_llm_embeddings.append(llm_embedding)
#         rollout_sentence_embeddings.append(sent_embedding)
#         rollout_sentences.append(rollout_resampled)
#     # Stack embeddings
#     rollout_llm_embeddings = torch.stack(rollout_llm_embeddings)
#     rollout_sentence_embeddings = torch.stack(rollout_sentence_embeddings)
    
#     # Calculate cosine similarities using PyTorch batch operations
#     cos_sims_llm = torch.cosine_similarity(
#         original_llm_emb.unsqueeze(0), rollout_llm_embeddings, dim=1
#     )
#     cos_sims_sent = torch.cosine_similarity(
#         original_sent_emb.unsqueeze(0), rollout_sentence_embeddings, dim=1
#     )

#     print(f"cos_sims_llm{sentence_idx}: {cos_sims_llm}")
#     print(f"cos_sims_sent{sentence_idx}: {cos_sims_sent}")

In [31]:
def get_both_embeddings_batch(llm_model, sentence_model, tokenizer, texts):
    """Get both LLM and sentence transformer embeddings for multiple texts in batches."""
    
    # Batch LLM embeddings
    llm_inputs = tokenizer(texts, return_tensors="pt", max_length=2048, truncation=True, padding=True)
    device = next(llm_model.parameters()).device
    llm_inputs = {k: v.to(device) for k, v in llm_inputs.items()}
    
    with torch.no_grad():
        outputs = llm_model(**llm_inputs, output_hidden_states=True)
        # Get last token embeddings for each sequence
        llm_embeddings = outputs.hidden_states[-1][:, -1, :].float().cpu()
    
    # Batch sentence transformer embeddings
    sent_embeddings = sentence_model.encode(texts, convert_to_tensor=True, batch_size=32).float().cpu()
    
    # Clean up GPU memory
    del outputs, llm_inputs
    
    return llm_embeddings, sent_embeddings

In [32]:
def process_one_sentence(model, tokenizer, problem_text, allsentences, ground_truth_answer,  num_rollouts=20, sentence_idx=0, batch=5):
    #sentence_idx sentence position
    """Save raw rollout data and similarities for later importance calculation."""
    
    # START: Clean memory before beginning
    torch.cuda.empty_cache()
    gc.collect()
    
    # Context WITHOUT the current sentence
    print(f"Context removed: {allsentences[sentence_idx]}")
    prefix_text = allsentences[:sentence_idx]
    context_without = problem_text + " " + " ".join(prefix_text) + ""

    # Generate diverse rollouts from context without chunk
    rollout_texts, rollout_answer_correct, rollout_answer_correct_check = generate_diverse_rollouts(
        model,                    
        tokenizer,               
        ground_truth_answer,     
        context_without,         
        num_rollouts=num_rollouts,    
        batch_size=batch,             
        temperature=0.6,              
        top_p=0.95,
        sentence_idx=sentence_idx                     
    )
    
    # Get original sentence embeddings
    original_text = allsentences[sentence_idx]
    original_llm_emb, original_sent_emb = get_both_embeddings(model, sentence_model, tokenizer, original_text)

    # BATCH PROCESS ROLLOUTS
    rollout_sentences = [split_solution_into_chunks(rollout_text)[0] for rollout_text in rollout_texts]
    
    # Get all embeddings at once
    rollout_llm_embeddings, rollout_sentence_embeddings = get_both_embeddings_batch(
        model, sentence_model, tokenizer, rollout_sentences
    )
    
    # Calculate cosine similarities using PyTorch batch operations
    cos_sims_llm = torch.cosine_similarity(
        original_llm_emb.unsqueeze(0), rollout_llm_embeddings, dim=1
    )
    cos_sims_sent = torch.cosine_similarity(
        original_sent_emb.unsqueeze(0), rollout_sentence_embeddings, dim=1
    )

    print(f"cos_sims_llm{sentence_idx}: {cos_sims_llm}")
    print(f"cos_sims_sent{sentence_idx}: {cos_sims_sent}")
    
    # Calculate basic statistics
    unique_responses = len(set(rollout_texts))
    
    # Store raw data for later importance calculation
    result = {
        "problem_id": None,  # Will be set by caller
        "sentence_idx": sentence_idx,
        "sentence_text": original_text,
        "function_tags": [],  # Will be set by caller
        
        # Context information
        "context_without_sentence": context_without,
        "ground_truth_answer": ground_truth_answer,
        
        # Rollout data
        "num_rollouts": num_rollouts,
        "rollout_sentences": rollout_sentences,
        "rollout_answer_correct": rollout_answer_correct,
        "rollout_answer_correct_check": rollout_answer_correct_check,
        "unique_responses": unique_responses,

        
        # extracted boxed_answers
        "rollout_boxed_answers": [extract_boxed_answers(text) for text in rollout_texts],
        # Raw similarity scores (convert to lists for JSON serialization)
        "cos_sims_llm": cos_sims_llm.cpu().tolist(),
        "cos_sims_sentence": cos_sims_sent.cpu().tolist(),
        
        # Summary statistics for quick reference
        # "llm_similarity_stats": {
        #     "mean": float(torch.mean(cos_sims_llm)),
        #     "std": float(torch.std(cos_sims_llm)),
        #     "min": float(torch.min(cos_sims_llm)),
        #     "max": float(torch.max(cos_sims_llm))
        # },
        # "sentence_similarity_stats": {
        #     "mean": float(torch.mean(cos_sims_sent)),
        #     "std": float(torch.std(cos_sims_sent)),
        #     "min": float(torch.min(cos_sims_sent)),
        #     "max": float(torch.max(cos_sims_sent))
        # },
        
        # Embedding comparison
        
        # Store original embeddings for potential later use (optional)
        # "original_llm_embedding": original_llm_emb.cpu().tolist(),
        # "original_sent_embedding": original_sent_emb.cpu().tolist(),
        
        # Metadata for importance calculation
        "generation_params": {
            "temperature": 0.6,
            "top_p": 0.95,
            "batch_size": batch
        }
    }
    
    # Cleanup
    del rollout_texts, rollout_llm_embeddings, rollout_sentence_embeddings, rollout_answer_correct, rollout_answer_correct_check
    del cos_sims_llm, cos_sims_sent, rollout_sentences
    del original_llm_emb, original_sent_emb
    if 'original_text' in locals():
        del original_text
    gc.collect()
    torch.cuda.empty_cache()
    
    return result

In [33]:
# target_tags = [
#         'uncertainty_management', 
#         'plan_generation'
#     ]
# def extract_target_sentence_indices(all_problem_labels, target_tags):
#     """
#     Extract sentence indices that have the target function tags.
    
#     Returns a dictionary mapping problem_id to list of sentence indices to process.
#     """
#     target_indices = {}
    
#     for problem in all_problem_labels:
#         problem_id = problem['problem_id']
#         indices_to_process = []
        
#         for i, chunk_data in enumerate(problem['chunks']):
#             function_tags = chunk_data.get('function_tags', [])
            
#             # Check if any target tags are in this chunk's function_tags
#             if any(tag in function_tags for tag in target_tags):
#                 indices_to_process.append(i)
        
#         if indices_to_process:  # Only add if we found relevant chunks
#             target_indices[problem_id] = indices_to_process
    
#     return target_indices

# # Extract target sentence indices
# target_sentence_indices = extract_target_sentence_indices(all_problem_labels, target_tags)
# target_sentence_indices = extract_target_sentence_indices(all_problem_labels, target_tags)

# print(f"Found target sentences in {len(target_sentence_indices)} problems")

# for prompt, label in zip(all_prompt[:1], all_problem_labels[:1]):
#     problem_id = prompt["problem_id"]
#     problem_text_prompt = prompt["problem_statement"]
#     allsentences = prompt["sentences"]
#     ground_truth_answer = prompt["answer"]
#     problem_text = f"Solve this math problem step by step. You MUST put your final answer in \\boxed{{}}. Problem: {problem_text_prompt} Solution: \n<think>\n"
    
#     # Get target sentence indices for this problem
#     if problem_id not in target_sentence_indices:
#         print(f"\nSkipping problem {problem_id} - no target function tags found")
#         continue
    
#     target_indices = target_sentence_indices[problem_id]
#     print(f"\nProcessing problem {problem_id} with {len(allsentences)} total sentences")
#     print(f"Target sentence indices to process: {target_indices}")
    
#     # Process only target sentences
#     sentence_results = []
#     for sentence_idx in target_indices:
#         if sentence_idx >= len(allsentences):
#             print(f"Warning: sentence_idx {sentence_idx} >= len(allsentences) {len(allsentences)}")
#             continue
            
#         print(f"\n--- Processing sentence {sentence_idx + 1}/{len(allsentences)} (target) ---")
#         print(f"Sentence: {allsentences[sentence_idx]}")
#         print(f"Function tags: {label['chunks'][sentence_idx].get('function_tags', [])}")
        
#         sentence_result = process_one_sentence(
#             model, tokenizer, problem_text, allsentences, ground_truth_answer, 
#             num_rollouts=num_rollouts, sentence_idx=sentence_idx, batch=batch
#         )
        
#         # Add problem metadata and function tags
#         sentence_result["problem_id"] = problem_id
#         sentence_result["sentence_idx"] = sentence_idx
#         sentence_result["sentence_text"] = allsentences[sentence_idx]
#         sentence_result["function_tags"] = label['chunks'][sentence_idx].get('function_tags', [])
#         sentence_results.append(sentence_result)
        
#         # Cleanup after each sentence
#         gc.collect()
#         torch.cuda.empty_cache()
    
#     results.extend(sentence_results)
#     print(f"\nCompleted problem {problem_id} - processed {len(sentence_results)} target sentences")

# print(f"\nProcessed {len(results)} target sentences across all problems")

In [34]:
def process_single_problem(
    problem_data: Dict, 
    problem_labels: Dict,
    model, 
    tokenizer, 
    sentence_model,
    output_dir: str = "rollout_results",
    num_rollouts: int = 20,
    batch_size: int = 5,
    force: bool = False
) -> List[Dict]:
    """
    Process a single problem: loop through every sentence and generate rollouts.
    Can resume from where it left off if results already exist.
    """
    problem_id = problem_data["problem_id"]
    problem_text_prompt = problem_data["problem_statement"]
    allsentences = problem_data["sentences"]
    ground_truth_answer = problem_data["answer"]
    
    # Create output directory structure
    output_path = Path(output_dir)
    problem_dir = output_path / problem_id
    problem_dir.mkdir(exist_ok=True, parents=True)
    
    # Define the results file path
    results_file = problem_dir / "sentence_rollouts.json"
    
    # Check if we can resume from existing results
    sentence_results = []
    start_sentence_idx = 0
    
    if results_file.exists() and not force:
        try:
            with open(results_file, 'r', encoding='utf-8') as f:
                existing_results = json.load(f)
            
            # Filter out failed results to find the last successful sentence
            successful_results = [r for r in existing_results if "error" not in r]
            
            if successful_results:
                sentence_results = existing_results
                start_sentence_idx = len(successful_results)
                print(f"Resuming from sentence {start_sentence_idx + 1}/{len(allsentences)}")
                print(f"Found {len(successful_results)} existing successful results")
            else:
                print(f"Found existing file but no successful results. Starting from beginning.")
                
        except (json.JSONDecodeError, KeyError) as e:
            print(f"Error reading existing results file: {e}. Starting from beginning.")
    
    # Prepare problem text for generation
    problem_text = f"Solve this math problem step by step. You MUST put your final answer in \\boxed{{}}. Problem: {problem_text_prompt} Solution: \n<think>\n"
    
    print(f"\nProcessing problem {problem_id} with {len(allsentences)} sentences")
    
    # Process sentences starting from where we left off
    for sentence_idx in range(start_sentence_idx, len(allsentences)):
        print(f"\n--- Processing sentence {sentence_idx + 1}/{len(allsentences)} ---")
        print(f"Sentence: {allsentences[sentence_idx]}")
        
        # Get function tags if available
        function_tags = []
        if sentence_idx < len(problem_labels.get('chunks', [])):
            function_tags = problem_labels['chunks'][sentence_idx].get('function_tags', [])
            print(f"Function tags: {function_tags}")
        
        try:
            # Process this sentence
            sentence_result = process_one_sentence(
                model=model,
                tokenizer=tokenizer,
                problem_text=problem_text,
                allsentences=allsentences,
                ground_truth_answer=ground_truth_answer,
                num_rollouts=num_rollouts,
                sentence_idx=sentence_idx,
                batch=batch_size
            )
            
            # Add metadata
            sentence_result["problem_id"] = problem_id
            sentence_result["sentence_idx"] = sentence_idx
            sentence_result["sentence_text"] = allsentences[sentence_idx]
            sentence_result["function_tags"] = function_tags
            
            sentence_results.append(sentence_result)
            
            print(f"Sentence {sentence_idx + 1}: Completed successfully")
            print(f"  - Unique responses: {sentence_result['unique_responses']}")
            
        except Exception as e:
            print(f"Error processing sentence {sentence_idx}: {e}")
            # Still save partial results
            error_result = {
                "problem_id": problem_id,
                "sentence_idx": sentence_idx,
                "sentence_text": allsentences[sentence_idx],
                "function_tags": function_tags,
                "error": str(e),
                "status": "failed"
            }
            sentence_results.append(error_result)
        
        # Cleanup after each sentence
        gc.collect()
        torch.cuda.empty_cache()
        
        # Save intermediate results every 5 sentences or immediately if we're resuming
        if (sentence_idx + 1) % 5 == 0 or sentence_idx == start_sentence_idx:
            with open(results_file, 'w', encoding='utf-8') as f:
                json.dump(sentence_results, f, indent=2, default=str)
            print(f"Saved intermediate results after sentence {sentence_idx + 1}")
    
    # Save final results
    with open(results_file, 'w', encoding='utf-8') as f:
        json.dump(sentence_results, f, indent=2, default=str)
    
    # Save summary statistics
    summary = {
        "problem_id": problem_id,
        "total_sentences": len(allsentences),
        "processed_sentences": len(sentence_results),
        "successful_sentences": len([r for r in sentence_results if "error" not in r]),
        "failed_sentences": len([r for r in sentence_results if "error" in r]),
        "generation_params": {
            "num_rollouts": num_rollouts,
            "batch_size": batch_size,
            "temperature": 0.6,
            "top_p": 0.95
        },
        "resumed_from_sentence": start_sentence_idx
    }
    
    summary_file = problem_dir / "processing_summary.json"
    with open(summary_file, 'w', encoding='utf-8') as f:
        json.dump(summary, f, indent=2)
    
    # Clean up temporary files
    for temp_file in problem_dir.glob("sentence_rollouts_temp_*.json"):
        temp_file.unlink()
    
    print(f"\nCompleted problem {problem_id}")
    print(f"  - Total sentences: {len(allsentences)}")
    print(f"  - Successfully processed: {summary['successful_sentences']}")
    print(f"  - Failed: {summary['failed_sentences']}")
    print(f"  - Results saved to: {results_file}")
    
    return sentence_results

def process_multiple_problems(
    all_prompt: List[Dict],
    all_problem_labels: List[Dict],
    model,
    tokenizer,
    sentence_model,
    output_dir: str = "rollout_results",
    num_rollouts: int = 20,
    batch_size: int = 5,
    force: bool = False,
    max_problems: int = None
) -> List[Dict]:
    """
    Process multiple problems sequentially.
    Can resume from partially completed problems.
    """
    # Create output directory
    Path(output_dir).mkdir(exist_ok=True, parents=True)
    
    # Limit number of problems if specified
    problems_to_process = all_prompt[:max_problems] if max_problems else all_prompt
    labels_to_process = all_problem_labels[:max_problems] if max_problems else all_problem_labels
    
    print(f"Processing {len(problems_to_process)} problems")
    
    all_sentence_results = []
    
    for i, (problem_data, problem_labels) in enumerate(zip(problems_to_process, labels_to_process)):
        print(f"\n{'='*50}")
        print(f"Processing problem {i+1}/{len(problems_to_process)}: {problem_data['problem_id']}")
        print(f"{'='*50}")
        
        try:
            sentence_results = process_single_problem(
                problem_data=problem_data,
                problem_labels=problem_labels,
                model=model,
                tokenizer=tokenizer,
                sentence_model=sentence_model,
                output_dir=output_dir,
                num_rollouts=num_rollouts,
                batch_size=batch_size,
                force=force
            )
            all_sentence_results.extend(sentence_results)
            
        except Exception as e:
            print(f"Failed to process problem {problem_data['problem_id']}: {e}")
            continue
        
        # Final cleanup between problems
        gc.collect()
        torch.cuda.empty_cache()
    
    print(f"\nCompleted processing all problems. Results saved in: {output_dir}")
    print(f"Total sentence results: {len(all_sentence_results)}")
    
    return all_sentence_results

In [35]:
# # Test just the first sentence with truncation debugging
# def debug_first_sentence():
#     # Get data for first sentence
#     problem_data = all_prompt[0]
#     problem_labels = all_labels[0]
    
#     problem_id = problem_data["problem_id"]
#     problem_text_prompt = problem_data["problem_statement"]
#     allsentences = problem_data["sentences"]
#     ground_truth_answer = problem_data["answer"]
    
#     problem_text = f"Solve this math problem step by step. You MUST put your final answer in \\boxed{{}}. Problem: {problem_text_prompt} Solution: \n<think>\n"
    
#     print(f"Testing first sentence of problem {problem_id}")
#     print(f"First sentence: {allsentences[0]}")
#     print(f"Ground truth answer: {ground_truth_answer}")
#     print("="*50)
    
#     # Process just the first sentence (sentence_idx=0)
#     sentence_result = process_one_sentence(
#         model=model,
#         tokenizer=tokenizer,
#         problem_text=problem_text,
#         allsentences=allsentences,
#         ground_truth_answer=ground_truth_answer,
#         num_rollouts=10,  # Small number for testing
#         sentence_idx=0,  # First sentence
#         batch=5
#     )
    
#     print("\n" + "="*50)
#     print("TRUNCATION ANALYSIS:")
#     print("="*50)
    
#     # Analyze each rollout for truncation
#     for i, rollout_text in enumerate(sentence_result["rollout_texts"]):
#         tokens = tokenizer(rollout_text, return_tensors="pt")['input_ids'].shape[1]
        
#         print(f"\nRollout {i+1}:")
#         print(f"  Tokens: {tokens}")
#         print(f"  Words: {len(rollout_text.split())}")
        
#         # Check ending
#         ending = rollout_text[-100:].replace('\n', ' ').strip()
#         print(f"  Last 100 chars: ...{ending}")
        
#         # Truncation indicators
#         truncation_signs = [
#             rollout_text.strip().endswith(','),
#             rollout_text.strip().endswith('='),
#             rollout_text.strip().endswith('+'),
#             rollout_text.strip().endswith('but'),
#             rollout_text.strip().endswith('So'),
#             rollout_text.strip().endswith('The'),
#             rollout_text.strip().endswith('('),
#         ]
        
#         has_proper_ending = (
#             '\\boxed{' in rollout_text or 
#             rollout_text.strip().endswith('</think>') or
#             rollout_text.strip().endswith('.')
#         )
        
#         is_truncated = (tokens >= 3995 or any(truncation_signs) or not has_proper_ending)
        
#         print(f"  Has proper ending: {has_proper_ending}")
#         print(f"  Truncation signs: {any(truncation_signs)}")
#         print(f"  Is truncated: {is_truncated}")
#         print(f"  Contains answer: {ground_truth_answer in rollout_text}")
    
#     return sentence_result

# # # Run the debug test
# # debug_result = debug_first_sentence()
# for i, rollout_text in enumerate(debug_result["rollout_texts"]):
#     boxed_answers = extract_boxed_answers(rollout_text)
#     print(f"Rollout {i+1} boxed answers: {boxed_answers}")

In [36]:
num_rollouts = 6
batch_size = 6
outputdir = "rollout_results_no_ablation/"

In [37]:
from pathlib import Path

process_multiple_problems(
    all_prompt=all_prompt[:1],
    all_problem_labels=all_labels[:1],
    model=model,
    tokenizer=tokenizer,
    sentence_model=sentence_model,
     output_dir= outputdir,
    num_rollouts=num_rollouts,
    batch_size=batch_size
)

Processing 1 problems

Processing problem 1/1: problem_6481
Resuming from sentence 188/187
Found 187 existing successful results

Processing problem problem_6481 with 187 sentences

Completed problem problem_6481
  - Total sentences: 187
  - Successfully processed: 187
  - Failed: 0
  - Results saved to: rollout_results_no_ablation/problem_6481/sentence_rollouts.json

Completed processing all problems. Results saved in: rollout_results_no_ablation/
Total sentence results: 187


[{'problem_id': 'problem_6481',
  'sentence_idx': 0,
  'sentence_text': 'Okay, so I have this problem about a square with an area of 81 square units.',
  'function_tags': ['problem_setup'],
  'context_without_sentence': 'Solve this math problem step by step. You MUST put your final answer in \\boxed{}. Problem: Two points are drawn on each side of a square with an area of 81 square units, dividing the side into 3 congruent parts.  Quarter-circle arcs connect the points on adjacent sides to create the figure shown.  What is the length of the boundary of the bolded figure?  Express your answer as a decimal to the nearest tenth. [asy]\nsize(80);\nimport graph;\ndraw((0,0)--(3,0)--(3,3)--(0,3)--cycle, linetype("2 4"));\ndraw(Arc((0,0),1,0,90),linewidth(.8));\ndraw(Arc((0,3),1,0,-90),linewidth(.8));\ndraw(Arc((3,0),1,90,180),linewidth(.8));\ndraw(Arc((3,3),1,180,270),linewidth(.8));\ndraw((1,0)--(2,0),linewidth(.8));draw((3,1)--(3,2),linewidth(.8));\ndraw((1,3)--(2,3),linewidth(.8));draw((0

In [38]:
def process_single_problem_with_multi_head_ablation(
    problem_data: Dict, 
    problem_labels: Dict,
    model, 
    tokenizer, 
    sentence_model,
    receiver_heads: List[tuple],  # List of (layer_idx, head_idx) tuples
    output_dir: str = "rollout_results_ablation",
    num_rollouts: int = 6,
    batch_size: int = 6,
    force: bool = False
) -> List[Dict]:
    """
    Process a single problem with receiver head ablation: loop through every sentence 
    and ablate ALL receiver heads simultaneously, storing rollouts and similarities.
    """
    problem_id = problem_data["problem_id"]
    problem_text_prompt = problem_data["problem_statement"]
    allsentences = problem_data["sentences"]
    ground_truth_answer = problem_data["answer"]
    
    # Create output directory structure
    output_path = Path(output_dir)
    problem_dir = output_path / problem_id
    problem_dir.mkdir(exist_ok=True, parents=True)
    
    # Prepare problem text for generation
    problem_text = f"Solve this math problem step by step. You MUST put your final answer in \\boxed{{}}. Problem: {problem_text_prompt} Solution: \n<think>\n"
    
    print(f"\nProcessing problem {problem_id} with {len(allsentences)} sentences")
    print(f"Will ablate {len(receiver_heads)} receiver heads SIMULTANEOUSLY per sentence")
    
    all_ablation_results = []
    
    # Process each sentence
    for sentence_idx in range(len(allsentences)):
        print(f"\n--- Processing sentence {sentence_idx + 1}/{len(allsentences)} ---")
        print(f"Sentence: {allsentences[sentence_idx]}")
        
        # Get context without current sentence (same as baseline)
        prefix_text = allsentences[:sentence_idx]
        context_without = problem_text + " " + " ".join(prefix_text) + ""
        
        # Get function tags if available
        function_tags = []
        if sentence_idx < len(problem_labels.get('chunks', [])):
            function_tags = problem_labels['chunks'][sentence_idx].get('function_tags', [])
        
        # Get original sentence embeddings
        original_text = allsentences[sentence_idx]
        original_llm_emb, original_sent_emb = get_both_embeddings(model, sentence_model, tokenizer, original_text)
        
        # CREATE ABLATION HOOKS FOR ALL HEADS
        hooks = []
        print(f"  Adding hooks for {len(receiver_heads)} heads...")
        
        for layer_idx, head_num in receiver_heads:
            print(f"    Adding hook for head ({layer_idx}, {head_num})")
            
            # Define ablation hook for this specific head
            def create_ablation_hook(target_head_num):
                def ablation_hook(module, input, output):
                    attention_output = output[0]
                    batch_size_tensor, seq_len, hidden_dim = attention_output.shape
                    num_heads = module.num_attention_heads
                    head_dim = hidden_dim // num_heads
                    
                    # Zero out the specific head
                    reshaped = attention_output.view(batch_size_tensor, seq_len, num_heads, head_dim)
                    reshaped[:, :, target_head_num, :] = 0
                    modified = reshaped.view(batch_size_tensor, seq_len, hidden_dim)
                    
                    return (modified,) + output[1:]
                return ablation_hook
            
            # Register hook for this layer
            attention_layer = model.model.layers[layer_idx].self_attn
            hook = attention_layer.register_forward_hook(create_ablation_hook(head_num))
            hooks.append((hook, layer_idx, head_num))
        
        print(f"  Successfully added {len(hooks)} hooks. Running generation...")
        
        try:
            # Generate rollouts with ALL heads ablated simultaneously
            rollout_texts, rollout_answer_correct, rollout_answer_correct_check = generate_diverse_rollouts(
                model=model,
                tokenizer=tokenizer, 
                ground_truth_answer=ground_truth_answer,
                context=context_without,
                num_rollouts=num_rollouts,
                batch_size=batch_size,
                temperature=0.6,
                top_p=0.95,
                sentence_idx=sentence_idx
            )
            
            # Process rollout embeddings (same as baseline)
            rollout_sentences = [split_solution_into_chunks(rollout_text)[0] for rollout_text in rollout_texts]
            
            # Get all embeddings at once
            rollout_llm_embeddings, rollout_sentence_embeddings = get_both_embeddings_batch(
                model, sentence_model, tokenizer, rollout_sentences
            )
            
            # Calculate cosine similarities
            cos_sims_llm = torch.cosine_similarity(
                original_llm_emb.unsqueeze(0), rollout_llm_embeddings, dim=1
            )
            cos_sims_sent = torch.cosine_similarity(
                original_sent_emb.unsqueeze(0), rollout_sentence_embeddings, dim=1
            )
            
            # Calculate basic statistics
            unique_responses = len(set(rollout_texts))
            
            # Store ablation result for ALL heads ablated together
            ablation_result = {
                "problem_id": problem_id,
                "sentence_idx": sentence_idx,
                "sentence_text": original_text,
                "function_tags": function_tags,
                
                # Multi-head ablation info
                "ablated_heads": receiver_heads,  # List of all ablated heads
                "ablation_type": "multi_head_simultaneous",
                "num_ablated_heads": len(receiver_heads),
                
                # Context information
                "context_without_sentence": context_without,
                "ground_truth_answer": ground_truth_answer,
                
                # Rollout data - SAME AS YOUR BASELINE
                "num_rollouts": num_rollouts,
                # "rollout_texts": rollout_texts,  # Full rollout texts
                "rollout_sentences": rollout_sentences,
                "rollout_answer_correct": rollout_answer_correct,
                "rollout_answer_correct_check": rollout_answer_correct_check,
                "unique_responses": unique_responses,
                
                # Extracted boxed answers
                "rollout_boxed_answers": [extract_boxed_answers(text) for text in rollout_texts],
                
                # Raw similarity scores
                "cos_sims_llm": cos_sims_llm.cpu().tolist(),
                "cos_sims_sentence": cos_sims_sent.cpu().tolist(),
                
                # Embedding comparison
                "embedding_correlation": float(torch.corrcoef(torch.stack([cos_sims_llm, cos_sims_sent]))[0, 1]) if len(cos_sims_llm) > 1 else 0.0,
                
                # Metadata
                "generation_params": {
                    "temperature": 0.6,
                    "top_p": 0.95,
                    "batch_size": batch_size
                }
            }
            
            all_ablation_results.append(ablation_result)
            
            print(f"  Successfully generated {len(rollout_texts)} rollouts with all heads ablated")
            print(f"  Unique responses: {unique_responses}")
            
            # Cleanup
            del rollout_llm_embeddings, rollout_sentence_embeddings
            del cos_sims_llm, cos_sims_sent, rollout_sentences
            # rollout_texts will be garbage collected after saving
            gc.collect()
            torch.cuda.empty_cache()
            
        except Exception as e:
            print(f"    Error during multi-head ablation: {e}")
            # Store error result
            error_result = {
                "problem_id": problem_id,
                "sentence_idx": sentence_idx,
                "sentence_text": original_text,
                "ablated_heads": receiver_heads,
                "ablation_type": "multi_head_simultaneous",
                "error": str(e),
                "status": "failed"
            }
            all_ablation_results.append(error_result)
            
        finally:
            # ALWAYS remove all hooks
            print(f"  Removing {len(hooks)} hooks...")
            for hook, layer_idx, head_num in hooks:
                hook.remove()
            hooks.clear()
        
        # Cleanup after processing this sentence
        del original_llm_emb, original_sent_emb
        gc.collect()
        torch.cuda.empty_cache()
    
    # Save all results
    results_file = problem_dir / "sentence_multi_head_ablation_rollouts.json"
    with open(results_file, 'w', encoding='utf-8') as f:
        json.dump(all_ablation_results, f, indent=2, default=str)
    
    # Save summary
    summary = {
        "problem_id": problem_id,
        "total_sentences": len(allsentences),
        "ablated_heads": receiver_heads,
        "num_ablated_heads": len(receiver_heads),
        "ablation_type": "multi_head_simultaneous",
        "total_ablation_experiments": len(allsentences),  # One experiment per sentence
        "successful_experiments": len([r for r in all_ablation_results if "error" not in r]),
        "failed_experiments": len([r for r in all_ablation_results if "error" in r]),
        "generation_params": {
            "num_rollouts": num_rollouts,
            "batch_size": batch_size,
            "temperature": 0.6,
            "top_p": 0.95
        }
    }
    
    summary_file = problem_dir / "multi_head_ablation_summary.json"
    with open(summary_file, 'w', encoding='utf-8') as f:
        json.dump(summary, f, indent=2)
    
    print(f"\nCompleted problem {problem_id}")
    print(f"  - Total sentences: {len(allsentences)}")
    print(f"  - Heads ablated simultaneously: {len(receiver_heads)}")
    print(f"  - Total experiments: {len(allsentences)}")
    print(f"  - Successful: {summary['successful_experiments']}")
    print(f"  - Failed: {summary['failed_experiments']}")
    print(f"  - Results saved to: {results_file}")
    
    return all_ablation_results

In [39]:
def process_single_problem_with_multi_head_ablation(
    problem_data: Dict, 
    problem_labels: Dict,
    model, 
    tokenizer, 
    sentence_model,
    receiver_heads: List[tuple],  # List of (layer_idx, head_idx) tuples
    output_dir: str = "rollout_results_ablation",
    num_rollouts: int = 6,
    batch_size: int = 6,
    force: bool = False
) -> List[Dict]:
    """
    Process a single problem with receiver head ablation: loop through every sentence 
    and ablate ALL receiver heads simultaneously, storing rollouts and similarities.
    """

    problem_id = problem_data["problem_id"]
    problem_text_prompt = problem_data["problem_statement"]
    allsentences = problem_data["sentences"]
    ground_truth_answer = problem_data["answer"]
    
    # Create output directory structure
    output_path = Path(output_dir)
    problem_dir = output_path / problem_id
    problem_dir.mkdir(exist_ok=True, parents=True)

    # Save all results
    results_file = problem_dir / "sentence_multi_head_ablation_rollouts.json"
    
    # Prepare problem text for generation
    problem_text = f"Solve this math problem step by step. You MUST put your final answer in \\boxed{{}}. Problem: {problem_text_prompt} Solution: \n<think>\n"
    
    print(f"\nProcessing problem {problem_id} with {len(allsentences)} sentences")
    print(f"Will ablate {len(receiver_heads)} receiver heads SIMULTANEOUSLY per sentence")

    # Check if we can resume from existing results
    all_ablation_results = []
    start_sentence_idx = 0
    
    if results_file.exists() and not force:
        try:
            with open(results_file, 'r', encoding='utf-8') as f:
                existing_results = json.load(f)
            
            # Filter out failed results to find the last successful sentence
            successful_results = [r for r in existing_results if "error" not in r]
            
            if successful_results:
                all_ablation_results = existing_results
                start_sentence_idx = len(successful_results)
                print(f"Resuming from sentence {start_sentence_idx + 1}/{len(allsentences)}")
                print(f"Found {len(successful_results)} existing successful results")
            else:
                print(f"Found existing file but no successful results. Starting from beginning.")
                
        except (json.JSONDecodeError, KeyError) as e:
            print(f"Error reading existing results file: {e}. Starting from beginning.")
    
    # Process each sentence
    for sentence_idx in range(start_sentence_idx, len(allsentences)):
        print(f"\n--- Processing sentence {sentence_idx + 1}/{len(allsentences)} ---")
        print(f"Sentence: {allsentences[sentence_idx]}")
        
        # Get context without current sentence (same as baseline)
        prefix_text = allsentences[:sentence_idx]
        context_without = problem_text + " " + " ".join(prefix_text) + ""
        
        # Get function tags if available
        function_tags = []
        if sentence_idx < len(problem_labels.get('chunks', [])):
            function_tags = problem_labels['chunks'][sentence_idx].get('function_tags', [])
        
        # Get original sentence embeddings
        original_text = allsentences[sentence_idx]
        original_llm_emb, original_sent_emb = get_both_embeddings(model, sentence_model, tokenizer, original_text)
        
        # CREATE ABLATION HOOKS FOR ALL HEADS
        hooks = []
        print(f"  Adding hooks for {len(receiver_heads)} heads...")
        
        for layer_idx, head_num in receiver_heads:
            print(f"    Adding hook for head ({layer_idx}, {head_num})")
            
            # Define ablation hook for this specific head
            def create_ablation_hook(target_head_num):
                def ablation_hook(module, input, output):
                    attention_output = output[0]
                    batch_size_tensor, seq_len, hidden_dim = attention_output.shape
                    num_heads = getattr(module, 'num_heads', getattr(module, 'num_attention_heads', 32))
                    head_dim = hidden_dim // num_heads
                    
                    # Zero out the specific head
                    reshaped = attention_output.view(batch_size_tensor, seq_len, num_heads, head_dim)
                    reshaped[:, :, target_head_num, :] = 0
                    modified = reshaped.view(batch_size_tensor, seq_len, hidden_dim)
                    
                    return (modified,) + output[1:]
                return ablation_hook
            
            # Register hook for this layer
            attention_layer = model.model.layers[layer_idx].self_attn
            hook = attention_layer.register_forward_hook(create_ablation_hook(head_num))
            hooks.append((hook, layer_idx, head_num))
        
        print(f"  Successfully added {len(hooks)} hooks. Running generation...")
        
        try:
            # Generate rollouts with ALL heads ablated simultaneously
            rollout_texts, rollout_answer_correct, rollout_answer_correct_check = generate_diverse_rollouts(
                model=model,
                tokenizer=tokenizer, 
                ground_truth_answer=ground_truth_answer,
                context=context_without,
                num_rollouts=num_rollouts,
                batch_size=batch_size,
                temperature=0.6,
                top_p=0.95,
                sentence_idx=sentence_idx
            )
            
            # Process rollout embeddings (same as baseline)
            rollout_sentences = [split_solution_into_chunks(rollout_text)[0] for rollout_text in rollout_texts]
            
            # Get all embeddings at once
            rollout_llm_embeddings, rollout_sentence_embeddings = get_both_embeddings_batch(
                model, sentence_model, tokenizer, rollout_sentences
            )
            
            # Calculate cosine similarities
            cos_sims_llm = torch.cosine_similarity(
                original_llm_emb.unsqueeze(0), rollout_llm_embeddings, dim=1
            )
            cos_sims_sent = torch.cosine_similarity(
                original_sent_emb.unsqueeze(0), rollout_sentence_embeddings, dim=1
            )
            
            # Calculate basic statistics
            unique_responses = len(set(rollout_texts))
            
            # Store ablation result for ALL heads ablated together
            ablation_result = {
                "problem_id": problem_id,
                "sentence_idx": sentence_idx,
                "sentence_text": original_text,
                "function_tags": function_tags,
                
                # Multi-head ablation info
                "ablated_heads": receiver_heads,  # List of all ablated heads
                "ablation_type": "multi_head_simultaneous",
                "num_ablated_heads": len(receiver_heads),
                
                # Context information
                "context_without_sentence": context_without,
                "ground_truth_answer": ground_truth_answer,
                
                # Rollout data - SAME AS YOUR BASELINE
                "num_rollouts": num_rollouts,
                # "rollout_texts": rollout_texts,  # Full rollout texts
                "rollout_sentences": rollout_sentences,
                "rollout_answer_correct": rollout_answer_correct,
                "rollout_answer_correct_check": rollout_answer_correct_check,
                "unique_responses": unique_responses,
                
                # Extracted boxed answers
                "rollout_boxed_answers": [extract_boxed_answers(text) for text in rollout_texts],
                
                # Raw similarity scores
                "cos_sims_llm": cos_sims_llm.cpu().tolist(),
                "cos_sims_sentence": cos_sims_sent.cpu().tolist(),
                
                # Embedding comparison
                "embedding_correlation": float(torch.corrcoef(torch.stack([cos_sims_llm, cos_sims_sent]))[0, 1]) if len(cos_sims_llm) > 1 else 0.0,
                
                # Metadata
                "generation_params": {
                    "temperature": 0.6,
                    "top_p": 0.95,
                    "batch_size": batch_size
                }
            }
            
            all_ablation_results.append(ablation_result)
            
            print(f"  Successfully generated {len(rollout_texts)} rollouts with all heads ablated")
            print(f"  Unique responses: {unique_responses}")
            
            # Cleanup
            del rollout_llm_embeddings, rollout_sentence_embeddings
            del cos_sims_llm, cos_sims_sent, rollout_sentences
            # rollout_texts will be garbage collected after saving
            gc.collect()
            torch.cuda.empty_cache()
            
        except Exception as e:
            print(f"    Error during multi-head ablation: {e}")
            
            # üõ°Ô∏è Emergency hook cleanup on error
            print(f"  Emergency hook cleanup - removing {len(hooks)} hooks...")
            for hook, layer_idx, head_num in hooks:
                try:
                    hook.remove()
                except:
                    pass  # Hook might already be removed
            hooks.clear()
            torch.cuda.empty_cache()
            gc.collect()
            
            # Store error result
            error_result = {
                "problem_id": problem_id,
                "sentence_idx": sentence_idx,
                "sentence_text": original_text,
                "ablated_heads": receiver_heads,
                "ablation_type": "multi_head_simultaneous",
                "error": str(e),
                "status": "failed"
            }
            all_ablation_results.append(error_result)
            
        finally:
            # ALWAYS remove all hooks
            print(f"  Removing {len(hooks)} hooks...")
            for hook, layer_idx, head_num in hooks:
                try:
                    hook.remove()
                except:
                    pass  # Might already be removed in error handling
            hooks.clear()
        
        # üî• SAVE AFTER EVERY SENTENCE (FIXED!)
        print(f"  Saving results after sentence {sentence_idx + 1}...")
        with open(results_file, 'w', encoding='utf-8') as f:
            json.dump(all_ablation_results, f, indent=2, default=str)  # ‚úÖ Save FULL list!
        print(f"  Results saved to: {results_file}")
        
        # Cleanup after processing this sentence
        del original_llm_emb, original_sent_emb
        gc.collect()
        torch.cuda.empty_cache()
    
    # Final save (this is now redundant but harmless)
    with open(results_file, 'w', encoding='utf-8') as f:
        json.dump(all_ablation_results, f, indent=2, default=str)
    
    # Save summary
    summary = {
        "problem_id": problem_id,
        "total_sentences": len(allsentences),
        "ablated_heads": receiver_heads,
        "num_ablated_heads": len(receiver_heads),
        "ablation_type": "multi_head_simultaneous",
        "total_ablation_experiments": len(allsentences),  # One experiment per sentence
        "successful_experiments": len([r for r in all_ablation_results if "error" not in r]),
        "failed_experiments": len([r for r in all_ablation_results if "error" in r]),
        "generation_params": {
            "num_rollouts": num_rollouts,
            "batch_size": batch_size,
            "temperature": 0.6,
            "top_p": 0.95
        }
    }
    
    summary_file = problem_dir / "multi_head_ablation_summary.json"
    with open(summary_file, 'w', encoding='utf-8') as f:
        json.dump(summary, f, indent=2)
    
    print(f"\nCompleted problem {problem_id}")
    print(f"  - Total sentences: {len(allsentences)}")
    print(f"  - Heads ablated simultaneously: {len(receiver_heads)}")
    print(f"  - Total experiments: {len(allsentences)}")
    print(f"  - Successful: {summary['successful_experiments']}")
    print(f"  - Failed: {summary['failed_experiments']}")
    print(f"  - Results saved to: {results_file}")
    
    return all_ablation_results

In [40]:
results = process_single_problem_with_multi_head_ablation(
    problem_data=all_prompt[0],
    problem_labels=all_labels[0], 
    model=model,
    tokenizer=tokenizer,
    sentence_model=sentence_model,
    receiver_heads=reciever_heads,  # Your 20 heads
    output_dir="rollout_results_ablation",
    num_rollouts=6,
    batch_size=6
)


Processing problem problem_6481 with 187 sentences
Will ablate 20 receiver heads SIMULTANEOUSLY per sentence
Resuming from sentence 188/187
Found 187 existing successful results

Completed problem problem_6481
  - Total sentences: 187
  - Heads ablated simultaneously: 20
  - Total experiments: 187
  - Successful: 187
  - Failed: 0
  - Results saved to: rollout_results_ablation/problem_6481/sentence_multi_head_ablation_rollouts.json


In [None]:
results = process_single_problem_with_multi_head_ablation(
    problem_data=all_prompt[0],
    problem_labels=all_labels[0], 
    model=model,
    tokenizer=tokenizer,
    sentence_model=sentence_model,
    receiver_heads=reciever_heads,  # Your 20 heads
    output_dir="2_rollout_results_ablation",
    num_rollouts=6,
    batch_size=6
)


Processing problem problem_6481 with 187 sentences
Will ablate 20 receiver heads SIMULTANEOUSLY per sentence

--- Processing sentence 1/187 ---
Sentence: Okay, so I have this problem about a square with an area of 81 square units.


  return forward_call(*args, **kwargs)


  Adding hooks for 20 heads...
    Adding hook for head (29, 24)
    Adding hook for head (17, 0)
    Adding hook for head (24, 7)
    Adding hook for head (25, 22)
    Adding hook for head (23, 8)
    Adding hook for head (18, 12)
    Adding hook for head (23, 23)
    Adding hook for head (21, 4)
    Adding hook for head (19, 17)
    Adding hook for head (18, 14)
    Adding hook for head (30, 17)
    Adding hook for head (19, 27)
    Adding hook for head (28, 22)
    Adding hook for head (1, 17)
    Adding hook for head (27, 1)
    Adding hook for head (24, 1)
    Adding hook for head (26, 10)
    Adding hook for head (26, 24)
    Adding hook for head (1, 16)
    Adding hook for head (24, 5)
  Successfully added 20 hooks. Running generation...
  Successfully generated 6 rollouts with all heads ablated
  Unique responses: 6
  Removing 20 hooks...
  Saving results after sentence 1...
  Results saved to: 2_rollout_results_ablation/problem_6481/sentence_multi_head_ablation_rollouts.json



: 