In [None]:
import os

import json


def process_problem_data(base_path):
    """
    Load the problem statement and reasoning sentences for every problem directory.

    Each immediate subdirectory of ``base_path`` (e.g. ``problem_330``) is read for
    ``problem.json`` (its ``"problem"`` field) and ``chunks_labeled.json`` (a list of
    objects, each contributing its ``"chunk"`` field). Directories whose files are
    missing, unreadable or malformed are reported and skipped, so one bad problem
    does not abort the whole load.

    Args:
        base_path (str): The path to the directory containing all the problems
                         (e.g., 'math-rollouts/.../correct_base_solution').

    Returns:
        list: One dict per successfully loaded problem with keys ``"problem_id"``
              (directory name), ``"problem_statement"`` and ``"sentences"`` (chunk
              texts in file order). Empty if ``base_path`` is missing or nothing
              could be loaded.
    """
    all_problem_data = []

    # Check if the base path exists
    if not os.path.isdir(base_path):
        print(f"Error: The directory '{base_path}' was not found.")
        return all_problem_data
    print(f"Found problem directory: {base_path}")

    # Every subdirectory is treated as one problem (e.g., problem_330, problem_1591).
    problem_dirs = [d for d in os.listdir(base_path) if os.path.isdir(os.path.join(base_path, d))]
    print(f"Found problem directory: {problem_dirs}")

    if not problem_dirs:
        print(f"No problem directories found in '{base_path}'.")
        return all_problem_data

    for problem_name in problem_dirs:
        problem_path = os.path.join(base_path, problem_name)
        problem_file = os.path.join(problem_path, "problem.json")
        chunks_file = os.path.join(problem_path, "chunks_labeled.json")

        # Load the problem statement. OSError covers missing/unreadable files,
        # ValueError covers invalid JSON or undecodable bytes, and AttributeError
        # covers a top-level JSON value that is not an object (no `.get`).
        try:
            with open(problem_file, 'r', encoding='utf-8') as f:
                problem_text = json.load(f).get("problem", "")
        except (OSError, ValueError, AttributeError) as e:
            print(f"Skipping {problem_name}: Could not load problem.json. Error: {e}")
            continue

        # Load all sentences. KeyError covers a chunk without a "chunk" field;
        # TypeError covers a file that is not a list of objects.
        try:
            with open(chunks_file, 'r', encoding='utf-8') as f:
                allsentences = [chunk["chunk"] for chunk in json.load(f)]
        except (OSError, ValueError, KeyError, TypeError) as e:
            print(f"Skipping {problem_name}: Could not load chunks_labeled.json. Error: {e}")
            continue

        all_problem_data.append({
            "problem_id": problem_name,
            "problem_statement": problem_text,
            "sentences": allsentences,
        })

    # Every directory was skipped: say so instead of returning an empty list silently.
    if not all_problem_data:
        print("No data was loaded.")
    return all_problem_data




# Load problem statements + reasoning sentences for the correct base solutions.
# Each subdirectory of the split is one problem.
base_problem_dir = "math-rollouts/deepseek-r1-distill-llama-8b/temperature_0.6_top_p_0.95/correct_base_solution"

# `correct_all_data` is a list of dicts with keys problem_id / problem_statement /
# sentences, one per successfully loaded problem.
correct_all_data = process_problem_data(base_problem_dir)
print(f"Successfully loaded data for {len(correct_all_data)} problems.")

# Same loading for the incorrect base solutions (`base_problem_dir` is reassigned).
base_problem_dir = "math-rollouts/deepseek-r1-distill-llama-8b/temperature_0.6_top_p_0.95/incorrect_base_solution"
incorrect_all_data = process_problem_data(base_problem_dir)
print(f"Successfully loaded data for {len(incorrect_all_data)} problems.")

In [None]:
# Single list of problem records: correct base solutions first, then incorrect ones.
all_prompt = [*correct_all_data, *incorrect_all_data]

In [None]:
import os
import json

def process_problem_labels_minimal(
    base_path,
    fields=(
        "function_tags",
        "chunk",
        "accuracy",
        "resampling_importance_accuracy",
        "resampling_importance_kl",
        "counterfactual_importance_accuracy",
        "counterfactual_importance_kl",
        "summary",
    ),
):
    """
    Iterates through all problem directories, extracts selected fields from each chunk
    in `chunks_labeled.json`, and returns a list of dictionaries.

    Problems without a readable `chunks_labeled.json`, or whose file is not a JSON
    list, are reported and skipped.

    Args:
        base_path (str): The path to the directory containing all the problems.
        fields (iterable of str): Chunk keys to keep, in output order. A key missing
            from a chunk is stored as None. The default keeps the label/metric fields.

    Returns:
        list: A list of dictionaries, each with "problem_id" (directory name) and
              "chunks" (one dict of the selected fields per chunk, in file order).
    """
    all_problem_data = []

    if not os.path.isdir(base_path):
        print(f"Error: The directory '{base_path}' was not found.")
        return all_problem_data

    problem_dirs = [d for d in os.listdir(base_path) if os.path.isdir(os.path.join(base_path, d))]
    if not problem_dirs:
        print(f"No problem directories found in '{base_path}'.")
        return all_problem_data

    for problem_name in problem_dirs:
        chunks_file = os.path.join(base_path, problem_name, "chunks_labeled.json")
        if not os.path.isfile(chunks_file):
            # Report the skip so a dropped problem is visible (process_problem_data does the same).
            print(f"Skipping {problem_name}: chunks_labeled.json not found.")
            continue

        # OSError: unreadable file; ValueError: invalid JSON or undecodable bytes.
        try:
            with open(chunks_file, "r", encoding="utf-8") as f:
                chunks_data = json.load(f)
        except (OSError, ValueError) as e:
            print(f"Skipping {problem_name}: {e}")
            continue

        if not isinstance(chunks_data, list):
            print(f"Skipping {problem_name}: expected a list of chunks, got {type(chunks_data).__name__}.")
            continue

        # Extract only the requested fields from each chunk.
        selected_chunks = [{key: chunk.get(key) for key in fields} for chunk in chunks_data]

        all_problem_data.append({
            "problem_id": problem_name,
            "chunks": selected_chunks
        })

    return all_problem_data

# Load the selected per-chunk label fields for both base-solution splits.
# NOTE: this rebinds `correct_all_data` / `incorrect_all_data`, which earlier held the
# problem statements and sentences (those were concatenated into `all_prompt` before this).
base_problem_dir = (
    "math-rollouts/deepseek-r1-distill-llama-8b/temperature_0.6_top_p_0.95/"
    "correct_base_solution"
)
correct_all_data = process_problem_labels_minimal(base_problem_dir)
print(f"Loaded {len(correct_all_data)} problems with selected chunk fields.")

base_problem_dir = (
    "math-rollouts/deepseek-r1-distill-llama-8b/temperature_0.6_top_p_0.95/"
    "incorrect_base_solution"
)
incorrect_all_data = process_problem_labels_minimal(base_problem_dir)
print(f"Loaded {len(incorrect_all_data)} problems with selected chunk fields.")

In [None]:
# Single list of per-problem chunk labels: correct base solutions first, then incorrect ones.
all_problem_labels = [*correct_all_data, *incorrect_all_data]

In [None]:
# Compact preview: (problem_id, number of labelled chunks) per problem, instead of
# rendering every chunk's text and metrics.
[(entry["problem_id"], len(entry["chunks"])) for entry in all_problem_labels]

In [None]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification, BitsAndBytesConfig, AutoModelForCausalLM, pipeline

import torch


# Tokenizer for the model whose reasoning traces are analysed (another suitable model works too).
model_name = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
mname = model_name  # alias of model_name
tokenizer = AutoTokenizer.from_pretrained(mname)

# Decoder-only tokenizers may ship without a pad token; reuse the EOS token so padding
# has a token to use.
if tokenizer.pad_token is None:
    tokenizer.add_special_tokens({"pad_token": tokenizer.eos_token})

In [None]:
# Load the causal LM in bfloat16; device_map="auto" delegates weight placement to transformers.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

In [None]:
# TODO: loop through all the chunks, append each one to the current text, run it through multiple rollouts, and measure counterfactual importance.

In [None]:
import torch
import torch.nn.functional as F
import numpy as np

In [None]:
def cosine_similarity(a, b):
    """
    Cosine similarity between each row of ``a`` and the single vector ``b``.

    Args:
        a: array-like of shape (n_samples, n_features); a 1-D array of shape
           (n_features,) is treated as a single sample.
        b: array-like of shape (1, n_features) or (n_features,).

    Returns:
        np.ndarray of shape (n_samples,). The 1e-8 added to the denominator avoids
        division by zero, so a zero vector gets similarity 0.
    """
    # asarray lets lists / CPU tensors through; atleast_2d makes a 1-D `a` one sample
    # instead of failing on axis=1.
    a = np.atleast_2d(np.asarray(a))
    b = np.asarray(b).reshape(-1)
    a_norm = np.linalg.norm(a, axis=1)
    b_norm = np.linalg.norm(b)
    return (a @ b) / (a_norm * b_norm + 1e-8)

In [None]:
# Compact preview: (problem_id, number of sentences) per problem, instead of dumping
# every problem statement and reasoning sentence.
[(entry["problem_id"], len(entry["sentences"])) for entry in all_prompt]

In [None]:
# Compact preview of the labels that will be processed: (problem_id, chunk count) per problem.
[(entry["problem_id"], len(entry["chunks"])) for entry in all_problem_labels]


In [None]:
import re
import sys
import json
import math
from collections import defaultdict

def normalize_answer(answer):
    """
    Normalize a numerical answer so numerically equal answers compare equal.

    Steps:
    1. Lowercase the text.
    2. Strip a leading "the answer is" / "answer:" / "final answer:" prefix.
    3. Strip a trailing currency word or symbol.
    4. Drop thousands separators ("1,000" -> "1000").
    5. Take the first number found, rendering integral values without a decimal
       part ("42.0" -> "42").

    If no number is present, the cleaned text is returned.

    Args:
        answer: Any value; converted with ``str()``. ``None`` yields ``""``.

    Returns:
        str: The normalized answer.
    """
    if answer is None:
        return ""

    answer = str(answer).strip().lower()

    # Remove common prefixes/suffixes
    answer = re.sub(r'^(the answer is|answer:|final answer:)\s*', '', answer)
    answer = re.sub(r'\s*(dollars?|cents?|\$|€|£)\s*$', '', answer)

    # Remove thousands separators so "1,000" is read as 1000 rather than 1.
    answer = re.sub(r'(?<=\d),(?=\d{3}(?!\d))', '', answer)

    # The regex only matches valid float literals, so float() cannot raise here.
    number_match = re.search(r'([+-]?\d*\.?\d+)', answer)
    if number_match:
        num = float(number_match.group(1))
        return str(int(num)) if num.is_integer() else str(num)

    # If no number found, return original cleaned answer
    return answer.strip()


def extract_final_answer(text):
    """
    Extract the final answer from model output.

    Patterns are tried in priority order. For each pattern, the LAST occurrence wins,
    because reasoning traces often restate their answer before settling on it.

    1. "the answer is" / "answer:" / "final answer:" followed by a number.
    2. A number at the very end of the text.
    3. A LaTeX ``\\boxed{...}``; one level of nested braces is supported, so
       ``\\boxed{\\frac{1}{2}}`` yields ``\\frac{1}{2}``.

    Matching is done on the lowercased text, so the returned string is lowercase.

    Args:
        text (str | None): Model output.

    Returns:
        str | None: The stripped captured answer, or None if nothing matched
        (or ``text`` is empty/None).
    """
    if not text:
        return None

    lowered = text.lower()
    patterns = [
        r"(?:the answer is|answer:|final answer:)\s*([+-]?\d*\.?\d+)",
        r"([+-]?\d*\.?\d+)\s*$",  # Number at end
        r"\\boxed\{((?:[^{}]|\{[^{}]*\})+)\}",  # LaTeX boxed answer, one nesting level
    ]

    for pattern in patterns:
        matches = re.findall(pattern, lowered)
        if matches:
            return matches[-1].strip()
    return None

def get_ground_truth_answer(problem_id, base_dirs):
    """
    Get the ground-truth answer from ``<base_dir>/<problem_id>/problem.json``.

    ``base_dirs`` are tried in order. The first one whose problem.json can be read
    and parsed as a JSON object wins, even if that object has no ``"answer"`` key
    (then None is returned). Missing, unreadable or malformed files are skipped.

    Args:
        problem_id (str): Problem directory name, e.g. "problem_330".
        base_dirs (iterable of str): Candidate base directories.

    Returns:
        The ``"answer"`` value, or None if no directory yields a problem object.
    """
    for base_dir in base_dirs:
        problem_path = f"{base_dir}/{problem_id}/problem.json"
        try:
            with open(problem_path, 'r', encoding='utf-8') as f:
                problem_data = json.load(f)
        except (OSError, ValueError):
            # Missing/unreadable file or invalid JSON: try the next base_dir.
            # (A bare `except:` would also swallow KeyboardInterrupt/SystemExit.)
            continue
        # A non-object JSON value has no "answer"; keep looking like the original did.
        if isinstance(problem_data, dict):
            return problem_data.get("answer", None)
    return None

In [None]:
def get_probs(model, inputs):
    """
    Get the next-token probability distribution after the last position of the first sequence.

    Args:
        model: model whose forward output exposes ``.logits``; its first parameter's
            device is used as the target device.
        inputs: mapping of tensors (e.g. tokenizer output); moved to the model's device.

    Returns:
        torch.Tensor: float32 probabilities over the last logits dimension.
    """
    device = next(model.parameters()).device
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        logits = model(**inputs).logits[0, -1, :]  # Last token logits of batch item 0
        # The model may run in reduced precision (e.g. bfloat16); softmax in float32 so
        # small probabilities are not rounded away.
        probs = torch.softmax(logits.float(), dim=-1)

    return probs

In [None]:
def get_embeddings(model, inputs, layer=-1):
    """
    Return the hidden state of the final position of the first sequence at ``layer``.

    Args:
        model: model accepting ``output_hidden_states=True`` and returning ``.hidden_states``.
        inputs: mapping of tensors (e.g. tokenizer output); moved to the model's device.
        layer: index into ``hidden_states``; -1 (default) selects the last entry.

    Returns:
        torch.Tensor: float32 tensor left on the model's device. This assumes the hidden
        states are (batch, seq, hidden), in which case the result is 1-D; TODO confirm
        for other models.
    """
    target_device = next(model.parameters()).device
    moved_inputs = {name: tensor.to(target_device) for name, tensor in inputs.items()}

    with torch.no_grad():
        all_layers = model(**moved_inputs, output_hidden_states=True).hidden_states
        final_position_state = all_layers[layer][0, -1, :].float()

    return final_position_state

def generate_diverse_rollouts(model, tokenizer, context, num_rollouts=10, batch_size=5, temperature=0.8, top_p=0.9,
                              max_length=1500, max_new_tokens=50):
    """
    Sample ``num_rollouts`` continuations of ``context`` in batches of ``batch_size``.

    The context is tokenized with LEFT-side truncation to at most ``max_length`` tokens.
    Rollouts continue from the end of the context, so the most recent text must survive
    truncation.

    Args:
        model: causal LM exposing ``generate``.
        tokenizer: tokenizer with ``eos_token_id`` and a ``truncation_side`` attribute
            (temporarily set to "left" and restored afterwards).
        context (str): Text to continue.
        num_rollouts (int): Total number of samples.
        batch_size (int): Samples produced per ``generate`` call.
        temperature (float): Sampling temperature passed to ``generate``.
        top_p (float): Nucleus-sampling threshold passed to ``generate``.
        max_length (int): Maximum number of prompt tokens kept.
        max_new_tokens (int): Generation length cap per rollout.

    Returns:
        tuple: (list of stripped decoded continuations, CPU float32 tensor with one row
        per rollout). Each row is the last-layer hidden state at the step that produced
        the rollout's final real token. For a rollout that emitted EOS, this is the step
        that produced EOS, not a later padding step.
    """
    device = next(model.parameters()).device

    # Truncate from the left so the end of the context (where generation resumes) is kept.
    previous_side = tokenizer.truncation_side
    tokenizer.truncation_side = "left"
    try:
        inputs = tokenizer(context, return_tensors="pt", max_length=max_length, truncation=True)
    finally:
        tokenizer.truncation_side = previous_side
    inputs = {k: v.to(device) for k, v in inputs.items()}
    prompt_len = inputs['input_ids'].shape[1]
    eos_id = tokenizer.eos_token_id

    rollout_texts = []
    rollout_embeddings = []

    for batch_start in range(0, num_rollouts, batch_size):
        current_batch_size = min(batch_size, num_rollouts - batch_start)

        # Same prompt repeated once per sample in this batch.
        batch_ids = inputs['input_ids'].repeat(current_batch_size, 1)
        batch_mask = inputs['attention_mask'].repeat(current_batch_size, 1)

        with torch.no_grad():
            outputs = model.generate(
                batch_ids,
                attention_mask=batch_mask,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=temperature,
                top_p=top_p,
                pad_token_id=eos_id,
                eos_token_id=eos_id,
                output_hidden_states=True,
                return_dict_in_generate=True
            )

        # outputs.hidden_states holds one entry per generation step, each a tuple over
        # layers. hidden_states[k][-1][i, -1, :] is the last-layer state that produced
        # generated token k of sequence i.
        num_steps = len(outputs.hidden_states)
        for i in range(current_batch_size):
            generated_ids = outputs.sequences[i][prompt_len:]
            rollout_texts.append(tokenizer.decode(generated_ids, skip_special_tokens=True).strip())

            # Finished sequences are padded with EOS (pad_token_id == eos_token_id), so the
            # first EOS marks the real end; later steps only process padding.
            step = num_steps - 1
            if eos_id is not None:
                eos_positions = (generated_ids == eos_id).nonzero(as_tuple=True)[0]
                if len(eos_positions) > 0:
                    step = min(int(eos_positions[0]), num_steps - 1)
            final_hidden = outputs.hidden_states[step][-1][i, -1, :].float().cpu()
            rollout_embeddings.append(final_hidden)

    # Embeddings were moved to the CPU above, so the stacked tensor lives on the CPU.
    return rollout_texts, torch.stack(rollout_embeddings)


In [None]:
# Updated main processing loop
num_rollouts = 10  # Reduced for actual generation
cos_sim_threshold = 0.8
results = []
for prompt, label in zip(all_prompt[:1], all_problem_labels[:1]):  # Test with 1 problem first
    problem_id = prompt["problem_id"]
    problem_text = prompt["problem_statement"]
    allsentences = prompt["sentences"]
   
    # Get baseline (original full context) - CONVERT TO NUMPY
    context_original = problem_text + " " + " ".join(allsentences)
    inputs_original = tokenizer(context_original, return_tensors="pt", max_length=2048, truncation=True)
    original_embedding = get_embeddings(model, inputs_original).cpu().numpy()  # ADD .cpu().numpy()
   
    for i, chunk in enumerate(allsentences[:5]):  # Test first 5 chunks
        print(f"Processing chunk {i+1}/{min(5, len(allsentences))}: {chunk[:50]}...")
       
        # Context WITHOUT the chunk (this is what thought-anchors does)
        chunks_without = allsentences[:i] + allsentences[i+1:]
        context_without = problem_text + " " + " ".join(chunks_without)
       
        # Generate diverse rollouts from context without chunk
        rollout_texts, rollout_embeddings = generate_diverse_rollouts(
            model, tokenizer, context_without, num_rollouts=num_rollouts
        )
       
        # CONVERT ROLLOUT EMBEDDINGS TO NUMPY IF THEY'RE TENSORS
        if torch.is_tensor(rollout_embeddings):
            rollout_embeddings = rollout_embeddings.cpu().numpy()
       
        # Calculate cosine similarities between rollouts and original
        cos_sims = []
        for rollout_emb in rollout_embeddings:
            cos_sim = np.dot(original_embedding, rollout_emb) / (
                np.linalg.norm(original_embedding) * np.linalg.norm(rollout_emb) + 1e-8
            )
            cos_sims.append(cos_sim)
       
        cos_sims = np.array(cos_sims)
        similar_mask = cos_sims > cos_sim_threshold
        not_similar_mask = ~similar_mask
       
        # Calculate importance metrics
        avg_cos_sim = np.mean(cos_sims)
        num_different = int(not_similar_mask.sum())
       
        # KL divergence between "different" rollouts and original (if any different ones exist)
        kl_divergence = None
        if num_different > 0:
            # For embeddings, we can use cosine distance as a proxy for KL
            different_cos_sims = cos_sims[not_similar_mask]
            kl_divergence = float(np.mean(1 - different_cos_sims))  # 1 - cosine similarity
       
        # Count unique responses
        unique_responses = len(set(rollout_texts))
       
        results.append({
            "chunk_index": i,
            "chunk_text": chunk,
            "original_context_length": len(context_original),
            "without_chunk_context_length": len(context_without),
            # Rollout analysis
            "num_rollouts": num_rollouts,
            "unique_responses": unique_responses,
            "rollout_texts": rollout_texts[:3],  # Store first 3 for inspection
            # Similarity analysis
            "avg_cosine_similarity": float(avg_cos_sim),
            "num_similar_to_original": int(similar_mask.sum()),
            "num_different_from_original": num_different,
            "counterfactual_importance": kl_divergence,
            # All cosine similarities for detailed analysis
            "all_cosine_similarities": cos_sims.tolist()
        })
       
        print(f"  Generated {unique_responses}/{num_rollouts} unique responses")
        print(f"  Avg cosine sim: {avg_cos_sim:.3f}, Different: {num_different}/{num_rollouts}")

# Print results
for r in results:
    print(f"\nChunk {r['chunk_index']}: {r['chunk_text'][:100]}...")
    print(f"  Unique responses: {r['unique_responses']}/{r['num_rollouts']}")
    print(f"  Counterfactual importance: {r['counterfactual_importance']}")
    print(f"  Sample rollouts: {r['rollout_texts'][:2]}")
print(f"\nProcessed {len(results)} chunks with diverse generation")

In [None]:
# Number of chunk-level result records appended by the counterfactual sweep.
len(results)

: 

In [None]:
import json

# Save results to a JSON file
# Persist the chunk-level results so they can be analysed without re-running generation.
with open("counterfactual_importance_results_control.json", "w") as results_file:
    json.dump(results, results_file, indent=2)
print("Results saved to counterfactual_importance_results_control.json")