In [1]:
import random
from functools import partial

import torch
from num2words import num2words

# from transformer_lens import HookedTransformer
from sae_lens import SAE, ActivationsStore, HookedSAETransformer
from tqdm.autonotebook import tqdm
from transformers.utils.logging import disable_progress_bar

from sae_cooccurrence.utils.set_paths import get_git_root

disable_progress_bar()

  warn(


In [2]:
torch.set_grad_enabled(False)

git_root = get_git_root()

if torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cuda" if torch.cuda.is_available() else "cpu"

In [3]:
# from transformer_lens import HookedTransformer

model = HookedSAETransformer.from_pretrained("gemma-2-2b", device=device)

# the cfg dict is returned alongside the SAE since it may contain useful information for analysing the SAE (eg: instantiating an activation store)
# Note that this is not the same as the SAEs config dict, rather it is whatever was in the HF repo, from which we can extract the SAE config dict
# We also return the feature sparsities which are stored in HF for convenience.
sae, cfg_dict, sparsity = SAE.from_pretrained(
    release="gemma-scope-2b-pt-res-canonical",  # <- Release name
    sae_id="layer_12/width_16k/canonical",  # <- SAE id (not always a hook point!)
    device=device,
)

activation_store = ActivationsStore.from_sae(
    model=model,
    sae=sae,
    streaming=True,
    # fairly conservative parameters here so can use same for larger
    # models without running out of memory.
    store_batch_size_prompts=8,
    train_batch_size_tokens=4096,
    n_batches_in_buffer=4,
    device=device,
)



Loaded pretrained model gemma-2-2b into HookedTransformer


Resolving data files:   0%|          | 0/30 [00:00<?, ?it/s]



In [5]:
import random

from num2words import num2words


class YesNoTokenQuestionGenerator:
    def __init__(self):
        self.colors = ["black", "white"]
        self.quantifiers = ["one", "some", "all"]
        self.sections = ["basic", "numeric", "special"]
        
    def _number_to_words(self, n: int) -> str:
        """Convert a number to words."""
        return num2words(n)
        
    def _questions_are_equivalent(self, q1: str, q2: str) -> bool:
        """Compare two questions to check if they are functionally equivalent."""
        q1_parts = [p.lower() for p in q1.split() if p.lower() not in ["q:", "a:", "and", "are", "of", "them", "my", "tokens", "have", "how", "many"]]
        q2_parts = [p.lower() for p in q2.split() if p.lower() not in ["q:", "a:", "and", "are", "of", "them", "my", "tokens", "have", "how", "many"]]
        return q1_parts == q2_parts

    def _determine_answer(self, n_total: int, held_quant: str, test_quant: str, held_color: str, test_color: str) -> str:
        """Determine if the answer should be Y or N based on the question parameters."""
        # Special case for n_total = 1
        if n_total == 1:
            # With one token, "some" and "one" are equivalent to having that single token
            # "all" is also equivalent for the held color
            has_held = held_quant in ["one", "some", "all"]
            if has_held:
                # If we have the held color, we can't have the other color
                return "Y" if held_color == test_color and test_quant in ["one", "some", "all"] else "N"
            return "N"
            
        if held_color == test_color:
            # Same color questions
            if held_quant == "all":
                # If all are the color, then any quantifier for same color is true
                return "Y"
            elif test_quant == "all":
                # If not all are the color (previous condition), then "all" must be false
                return "N"
            elif held_quant == "one" and test_quant == "some":
                # If one is the color, then some is also true
                return "Y"
            elif held_quant == test_quant:
                # Same quantifier for same color
                return "Y"
            return "N"
        else:
            # Different color questions
            if held_quant == "all" or test_quant == "all":
                # If all are one color, none can be the other
                return "N"
            elif held_quant == "some":
                # If some are held color, then not all tokens are available for test color
                return "Y" if test_quant in ["one", "some"] else "N"
            elif held_quant == "one":
                # If one is held color, others can be test color unless n_total = 1
                return "Y" if test_quant in ["one", "some"] else "N"
            return "Y"

    def generate_question(self, force_type: str | None = None) -> tuple[str, str]:
        """Generate a single question with specified parameters."""
        n_total = random.randint(1, 10)
        colors = self.colors.copy()
        
        if force_type == "basic":
            held_quant = "all"
            test_quant = random.choice(["one", "some", "all"])
            # 50% chance to ask about the same color when we know all are that color
            if random.random() < 0.5:
                held_color = random.choice(colors)
                test_color = held_color
            else:
                held_color, test_color = random.sample(colors, 2)
        elif force_type == "numeric":
            held_quant = "one"
            test_quant = random.choice(["one", "some"])
            held_color, test_color = random.sample(colors, 2)
        elif force_type == "special":
            held_quant = "some"
            test_quant = random.choice(["one", "some", "all"])
            held_color, test_color = random.sample(colors, 2)
        else:
            held_quant = random.choice(self.quantifiers)
            test_quant = random.choice(self.quantifiers)
            # If we're saying "all", 50% chance to ask about the same color
            if held_quant == "all" and random.random() < 0.5:
                held_color = random.choice(colors)
                test_color = held_color
            else:
                held_color, test_color = random.sample(colors, 2)
        
        question = (
            f"Q: I have {self._number_to_words(n_total)} tokens, and "
            f"{held_quant} of them are {held_color}. "
            f"Are {test_quant} of them {test_color}?"
        )
        
        answer = f"A: {self._determine_answer(n_total, held_quant, test_quant, held_color, test_color)}"
        return question, answer

    def generate_training_set(self, section_counts: dict) -> list[tuple[str, str, str]]:
        """Generate a structured training set with sections."""
        training_set = []
        
        # Basic cases (all vs others)
        for _ in range(section_counts.get("basic", 0)):
            q, a = self.generate_question(force_type="basic")
            training_set.append(("# Basic cases - using 'all'", q, a))
            
        # Numeric cases
        for _ in range(section_counts.get("numeric", 0)):
            q, a = self.generate_question(force_type="numeric")
            training_set.append(("# Cases with 'one'", q, a))
            
        # Special cases (some)
        for _ in range(section_counts.get("special", 0)):
            q, a = self.generate_question(force_type="special")
            training_set.append(("# Cases with 'some'", q, a))
            
        return training_set

def generate_training_and_test(
    section_counts: dict = {
        "basic": 2,    # all vs others cases
        "numeric": 3,  # one vs some/all cases
        "special": 2   # some vs others cases
    },
    force_test_type: str | None = None,
    max_attempts: int = 100
) -> dict:
    """Generate a structured training set and test question."""
    generator = YesNoTokenQuestionGenerator()
    
    # Generate test question
    test_question, test_answer = generator.generate_question(force_type=force_test_type)
    
    # Generate training set with sections
    training_set = []
    attempts = 0
    
    while len(training_set) < sum(section_counts.values()) and attempts < max_attempts:
        current_set = generator.generate_training_set(section_counts)
        
        # Filter out any questions that match the test question
        filtered_set = [
            (section, q, a)
            for section, q, a in current_set
            if not generator._questions_are_equivalent(q, test_question)
        ]
        
        if len(filtered_set) == sum(section_counts.values()):
            training_set = filtered_set
            break
            
        attempts += 1
    
    if attempts >= max_attempts:
        raise RuntimeError("Failed to generate unique training set after maximum attempts")
    
    # Format the introduction
    introduction = """
Tokens can be either black or white. Answer Y for yes and N for no:

Remember:
- Only use Y or N to answer
- 'some' means at least one but not all
- 'one' means exactly one
- 'all' means every single token
"""
    
    return {
        "introduction": introduction.strip(),
        "training_questions": training_set,
        "test_question": test_question,
        "test_answer": test_answer
    }

# Example usage
if __name__ == "__main__":
    # Custom section counts
    section_counts = {
        "basic": 2,    # all vs others cases
        "numeric": 3,  # one vs some/all cases
        "special": 2   # some vs others cases
    }
    
    result = generate_training_and_test(
        section_counts=section_counts,
        force_test_type="numeric"  # Can be 'numeric', 'special', 'basic', or None
    )
    
    print(result["introduction"])
    print("\nTraining Questions:")
    for section, q, a in result["training_questions"]:
        if section:
            print(f"\n{section}")
        print(f"{q}\n{a}\n")
    print("\nTest Question:")
    print(result["test_question"])
    print("\nExpected Answer:")
    print(result["test_answer"])

Tokens can be either black or white. Answer Y for yes and N for no:

Remember:
- Only use Y or N to answer
- 'some' means at least one but not all
- 'one' means exactly one
- 'all' means every single token

Training Questions:

# Basic cases - using 'all'
Q: I have five tokens, and all of them are black. Are all of them white?
A: N


# Basic cases - using 'all'
Q: I have two tokens, and all of them are black. Are one of them black?
A: Y


# Cases with 'one'
Q: I have nine tokens, and one of them are white. Are one of them black?
A: Y


# Cases with 'one'
Q: I have one tokens, and one of them are white. Are some of them black?
A: N


# Cases with 'one'
Q: I have five tokens, and one of them are white. Are one of them black?
A: Y


# Cases with 'some'
Q: I have six tokens, and some of them are black. Are all of them white?
A: N


# Cases with 'some'
Q: I have one tokens, and some of them are black. Are some of them white?
A: N


Test Question:
Q: I have two tokens, and one of them are wh

In [11]:



def find_max_activation(model, sae, activation_store, feature_idx, num_batches=100):
    """
    Find the maximum activation for a given feature index. This is useful for
    calibrating the right amount of the feature to add.
    """
    max_activation = 0.0

    pbar = tqdm(range(num_batches))
    for _ in pbar:
        tokens = activation_store.get_batch_tokens()

        _, cache = model.run_with_cache(
            tokens,
            stop_at_layer=sae.cfg.hook_layer + 1,
            names_filter=[sae.cfg.hook_name],
        )
        sae_in = cache[sae.cfg.hook_name]
        feature_acts = sae.encode(sae_in).squeeze()

        feature_acts = feature_acts.flatten(0, 1)
        batch_max_activation = feature_acts[:, feature_idx].max().item()
        max_activation = max(max_activation, batch_max_activation)

        pbar.set_description(f"Max activation: {max_activation:.4f}")

    return max_activation


def steering(activations, steering_strength=1.0, steering_vector=None, max_act=1.0, hook=None):  # noqa: ARG001
    """
    Apply steering to activations.
    
    Args:
        activations: Input activations
        steering_strength: Strength of steering effect
        steering_vector: Vector to steer towards
        max_act: Maximum activation value
        hook: Hook object (ignored but needed for compatibility)
    
    Returns:
        Modified activations
    """
    # Note if the feature fires anyway, we'd be adding to that here.
    return activations + max_act * steering_strength * steering_vector 


def generate_with_steering(
    model,
    sae,
    prompt,
    steering_feature,
    max_act,
    steering_strength=1.0,
    max_new_tokens=95,
):
    input_ids = model.to_tokens(prompt, prepend_bos=sae.cfg.prepend_bos)

    steering_vector = sae.W_dec[steering_feature].to(model.cfg.device)

    steering_hook = partial(
        steering,
        steering_vector=steering_vector,
        steering_strength=steering_strength,
        max_act=max_act,
    )

    # standard transformerlens syntax for a hook context for generation
    with model.hooks(fwd_hooks=[(sae.cfg.hook_name, steering_hook)]):
        output = model.generate(
            input_ids,
            max_new_tokens=max_new_tokens,
            temperature=0.7,
            top_p=0.9,
            stop_at_eos=False if device == "mps" else True,
            prepend_bos=sae.cfg.prepend_bos,
        )

    return model.tokenizer.decode(output[0])

In [6]:
def evaluate_model_with_intervention(
    model: HookedSAETransformer,
    sae: SAE,
    feature_ids: int | list[int] | None = None,
    intervention_type: str = "none",  # "none", "steering", or "ablation"
    steering_strength: float = 1.0,
    max_act: float = 60.0,
    n_numeric: int = 10,
    n_special: int = 10,
    max_new_tokens: int = 6,
    temperature: float = 0.0,
    top_p: float = 0.9,
    section_counts: dict | None = None,
) -> dict:
    """
    Test model accuracy on token yes/no tasks with optional feature steering or ablation.
    Lenient scoring accepts various forms of yes/no answers.
    """
    if section_counts is None:
        section_counts = {
            "basic": 2,     # all vs others cases
            "numeric": 3,   # one vs some/all cases
            "special": 2    # some vs others cases
        }

    results = {
        "numeric_correct_strict": 0,
        "numeric_correct_lenient": 0,
        "special_correct_strict": 0,
        "special_correct_lenient": 0,
        "numeric_tests": [],
        "special_tests": [],
    }

    def check_answer_lenient(generated: str, expected: str) -> bool:
        """
        Check if answer is correct under lenient scoring rules.
        Accepts various forms of yes/no answers.
        
        Args:
            generated: Generated answer text
            expected: Expected answer text
            
        Returns:
            bool: Whether the answer is correct under lenient scoring
        """
        # Convert both to lowercase for comparison
        generated = generated.lower().strip()
        expected = expected.lower().strip()
        
        # Extract the Y/N from the expected answer
        expected_yn = expected.split(":")[-1].strip() if ":" in expected else expected
        
        # Define valid positive and negative responses
        positive_responses = ["y", "yes", "true", "correct"]
        negative_responses = ["n", "no", "false", "incorrect"]
        
        # Check if the expected answer is positive or negative
        is_expected_positive = expected_yn == "y"
        
        # Check if any valid form of the answer appears in the generated text
        valid_responses = positive_responses if is_expected_positive else negative_responses
        return any(resp in generated.lower() for resp in valid_responses)

    # Convert single feature_id to list for consistency
    if isinstance(feature_ids, int):
        feature_ids = [feature_ids]

    def ablate_feature_hook(feature_activations, hook=None, feature_ids=None):  # noqa: ARG001
        feature_activations[:, :, feature_ids] = 0
        return feature_activations

    for test_type in tqdm(["numeric", "special"], desc="Testing types"):
        n_tests = n_numeric if test_type == "numeric" else n_special
        for _ in tqdm(range(n_tests), desc=f"Testing {test_type}"):
            # Use new structured prompt generator
            test_data = generate_training_and_test(
                section_counts=section_counts,
                force_test_type=test_type,
            )

            # Construct prompt with sections
            prompt_parts = [test_data["introduction"]]
            
            # Group questions by section
            sections = {}
            for section, q, a in test_data["training_questions"]:
                if section not in sections:
                    sections[section] = []
                sections[section].append((q, a))
            
            # Add each section with its header
            for section, questions in sections.items():
                prompt_parts.append(f"\n\n{section}")
                for q, a in questions:
                    prompt_parts.append(f"{q}\n{a}")
            
            # Add test question
            prompt_parts.append(f"\n\n{test_data['test_question']}\nA: ")
            
            prompt = "\n\n".join(prompt_parts)

            # Convert prompt to tokens
            input_ids = model.to_tokens(prompt, prepend_bos=sae.cfg.prepend_bos)

            if intervention_type == "none" or feature_ids is None:
                output = model.generate(
                    input_ids,
                    max_new_tokens=max_new_tokens,
                    temperature=temperature,
                    top_p=top_p,
                    stop_at_eos=False if device == "mps" else True,
                    prepend_bos=sae.cfg.prepend_bos,
                )

            elif intervention_type == "steering":
                steering_vector = sae.W_dec[feature_ids[0]].to(model.cfg.device)
                steering_hook = partial(
                    steering,
                    steering_vector=steering_vector,
                    steering_strength=steering_strength,
                    max_act=max_act,
                )

                with model.hooks(fwd_hooks=[(sae.cfg.hook_name, steering_hook)]):
                    output = model.generate(
                        input_ids,
                        max_new_tokens=max_new_tokens,
                        temperature=temperature,
                        top_p=top_p,
                        stop_at_eos=False if device == "mps" else True,
                        prepend_bos=sae.cfg.prepend_bos,
                    )

            elif intervention_type == "ablation":
                ablation_hook = partial(ablate_feature_hook, feature_ids=feature_ids)
                model.add_sae(sae)
                hook_point = sae.cfg.hook_name + ".hook_sae_acts_post"

                with model.hooks(fwd_hooks=[(hook_point, ablation_hook)]):
                    output = model.generate(
                        input_ids,
                        max_new_tokens=max_new_tokens,
                        top_p=top_p,
                        temperature=temperature,
                        stop_at_eos=False if device == "mps" else True,
                        prepend_bos=sae.cfg.prepend_bos,
                    )

                model.reset_hooks()
                model.reset_saes()

            generated = model.tokenizer.decode(output[0])  # type: ignore
            generated_answer = generated.split("A: ")[-1].strip()
            expected_answer = test_data["test_answer"].replace("A: ", "").strip()

            test_result = {
                "prompt": prompt,
                "expected": expected_answer,
                "generated": generated_answer,
                "correct_strict": expected_answer in generated_answer,
                "correct_lenient": check_answer_lenient(generated_answer, expected_answer),
                "section_counts": section_counts,
            }

            if test_type == "numeric":
                results["numeric_tests"].append(test_result)
                if test_result["correct_strict"]:
                    results["numeric_correct_strict"] += 1
                if test_result["correct_lenient"]:
                    results["numeric_correct_lenient"] += 1
            else:
                results["special_tests"].append(test_result)
                if test_result["correct_strict"]:
                    results["special_correct_strict"] += 1
                if test_result["correct_lenient"]:
                    results["special_correct_lenient"] += 1

    # Calculate accuracies
    results["numeric_accuracy_strict"] = results["numeric_correct_strict"] / n_numeric
    results["numeric_accuracy_lenient"] = results["numeric_correct_lenient"] / n_numeric
    results["special_accuracy_strict"] = results["special_correct_strict"] / n_special
    results["special_accuracy_lenient"] = results["special_correct_lenient"] / n_special
    results["total_accuracy_strict"] = (
        results["numeric_correct_strict"] + results["special_correct_strict"]
    ) / (n_numeric + n_special)
    results["total_accuracy_lenient"] = (
        results["numeric_correct_lenient"] + results["special_correct_lenient"]
    ) / (n_numeric + n_special)

    # Add section configuration to results
    results["section_counts"] = section_counts

    return results

In [7]:
normal_results = evaluate_model_with_intervention(
    model, sae, intervention_type="none", n_numeric=1, n_special=1, section_counts= {"basic": 20, "numeric": 0, "special": 20, "zero": 0}
)

Testing types:   0%|          | 0/2 [00:00<?, ?it/s]

Testing numeric:   0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

Testing special:   0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

In [9]:
normal_results["special_tests"]

[{'prompt': "Tokens can be either black or white. Answer Y for yes and N for no:\n\nRemember:\n- Only use Y or N to answer\n- 'some' means at least one but not all\n- 'one' means exactly one\n- 'all' means every single token\n\n\n\n# Basic cases - using 'all'\n\nQ: I have ten tokens, and all of them are white. Are all of them white?\nA: Y\n\nQ: I have one tokens, and all of them are black. Are one of them white?\nA: N\n\nQ: I have two tokens, and all of them are black. Are one of them black?\nA: Y\n\nQ: I have ten tokens, and all of them are black. Are one of them black?\nA: Y\n\nQ: I have nine tokens, and all of them are white. Are some of them black?\nA: N\n\nQ: I have nine tokens, and all of them are white. Are some of them black?\nA: N\n\nQ: I have nine tokens, and all of them are black. Are some of them white?\nA: N\n\nQ: I have one tokens, and all of them are white. Are all of them white?\nA: Y\n\nQ: I have eight tokens, and all of them are white. Are all of them white?\nA: Y\n\n

In [24]:
ablated_results_hub = evaluate_model_with_intervention(
    model,
    sae,
    feature_ids=[2283, 8084, 13772],  # Can ablate multiple features
    intervention_type="ablation",
    n_numeric=1,
    n_special=1,
)
ablated_results_hub['special_tests']

Testing types:   0%|          | 0/2 [00:00<?, ?it/s]

Testing numeric:   0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

Testing special:   0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

[{'prompt': "Tokens can be either black or white. Complete the following sentences using number words (one, two, three, etc.) never digits (1, 2, 3, etc.).\nRemember:\n- For 0, use 'zero' not '0'\n- The total number of tokens equals the sum of black and white tokens\n- Always write numbers as words (e.g., 'two' not '2')\n\nFor example, this is correct: \nQ: I have ten tokens, and five of them are black. How many of my tokens are white?\nA: five of them are white\n\nWhereas this is incorrect:\nQ: I have ten tokens, and five of them are black. How many of my tokens are white?\nA: 5 of them are white\n\n\n\n# Basic cases - all/none\n\nQ: I have four tokens, and some of them are black. How many of my tokens are white?\nA: some of them are white\n\nQ: I have six tokens, and some of them are white. How many of my tokens are black?\nA: some of them are black\n\n\n\n# Cases with specific numbers\n\nQ: I have ten tokens, and six of them are white. How many of my tokens are black?\nA: four of th

In [98]:
steering_results_hub = evaluate_model_with_intervention(
    model,
    sae,
    feature_ids=[12257, 12649],  # Can ablate multiple features
    intervention_type="steering",
    steering_strength= -1.5,
    n_numeric=1,
    n_special=1,
)

Testing types:   0%|          | 0/2 [00:00<?, ?it/s]

Testing numeric:   0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

Testing special:   0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

In [99]:
steering_results_hub['special_tests']


[{'prompt': "Tokens can be either black or white. Complete the following sentences using number words (one, two, three, etc.) never digits (1, 2, 3, etc.).\nRemember:\n- For 0, use 'zero' not '0'\n- The total number of tokens equals the sum of black and white tokens\n- Always write numbers as words (e.g., 'two' not '2')\n\nFor example, this is correct: \nQ: I have ten tokens, and five of them are black. How many of my tokens are white?\nA: five of them are white\n\nWhereas this is incorrect:\nQ: I have ten tokens, and five of them are black. How many of my tokens are white?\nA: 5 of them are white\n\n\n\n# Basic cases - all/none\n\nQ: I have nine tokens, and none of them are black. How many of my tokens are white?\nA: all of them are white\n\nQ: I have ten tokens, and some of them are black. How many of my tokens are white?\nA: some of them are white\n\n\n\n# Cases with specific numbers\n\nQ: I have five tokens, and three of them are white. How many of my tokens are black?\nA: two of t

In [10]:
# Test without intervention
normal_results = evaluate_model_with_intervention(model, sae, intervention_type="none", section_counts= {"basic": 20, "numeric": 20, "special": 20, "zero": 0}, n_numeric=50, n_special=50)

# # Test with steering
# steered_results = evaluate_model_with_intervention(
#     model,
#     sae,
#     feature_ids=12257,
#     intervention_type="steering",
#     steering_strength=1.0,
# )

Testing types:   0%|          | 0/2 [00:00<?, ?it/s]

Testing numeric:   0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

Testing special:   0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

In [11]:
print(normal_results["numeric_accuracy_strict"])
print(normal_results["numeric_accuracy_lenient"])
print(normal_results["special_accuracy_strict"])
print(normal_results["special_accuracy_lenient"])
# Test with ablation

0.02
0.02
0.14
0.14


In [12]:
normal_results['numeric_tests']

[{'prompt': "Tokens can be either black or white. Answer Y for yes and N for no:\n\nRemember:\n- Only use Y or N to answer\n- 'some' means at least one but not all\n- 'one' means exactly one\n- 'all' means every single token\n\n\n\n# Basic cases - using 'all'\n\nQ: I have three tokens, and all of them are black. Are one of them black?\nA: Y\n\nQ: I have ten tokens, and all of them are white. Are all of them black?\nA: N\n\nQ: I have five tokens, and all of them are white. Are all of them black?\nA: N\n\nQ: I have three tokens, and all of them are black. Are some of them white?\nA: N\n\nQ: I have one tokens, and all of them are black. Are one of them black?\nA: Y\n\nQ: I have one tokens, and all of them are black. Are all of them black?\nA: Y\n\nQ: I have eight tokens, and all of them are black. Are some of them white?\nA: N\n\nQ: I have ten tokens, and all of them are white. Are some of them white?\nA: Y\n\nQ: I have nine tokens, and all of them are black. Are all of them black?\nA: Y\

In [75]:
normal_results['special_tests']

[{'prompt': "Tokens can be either black or white. Complete the following sentences using number words (one, two, three, etc.) never digits (1, 2, 3, etc.).\nRemember:\n- For 0, use 'zero' not '0'\n- The total number of tokens equals the sum of black and white tokens\n- Always write numbers as words (e.g., 'two' not '2')\n\nFor example, this is correct: \nQ: I have ten tokens, and five of them are black. How many of my tokens are white?\nA: five of them are white\n\nWhereas this is incorrect:\nQ: I have ten tokens, and five of them are black. How many of my tokens are white?\nA: 5 of them are white\n\n\n\n# Basic cases - all/none\n\nQ: I have seven tokens, and none of them are white. How many of my tokens are black?\nA: all of them are black\n\nQ: I have five tokens, and some of them are white. How many of my tokens are black?\nA: some of them are black\n\nQ: I have five tokens, and all of them are white. How many of my tokens are black?\nA: none of them are black\n\nQ: I have four toke

In [77]:
# Test with ablation
ablated_results_hub = evaluate_model_with_intervention(
    model,
    sae,
    feature_ids=[12257],  # Can ablate multiple features
    intervention_type="ablation",
    n_numeric=50,
    n_special=50,
    section_counts= {"basic": 20, "numeric": 0, "special": 20, "zero": 0}
)

ablated_results_some = evaluate_model_with_intervention(
    model,
    sae,
    feature_ids=[15441],  # Can ablate multiple features
    intervention_type="ablation",
    n_numeric=50,
    n_special=50,
    section_counts= {"basic": 20, "numeric": 20, "special": 0, "zero": 0}
)

ablated_results_all = evaluate_model_with_intervention(
    model,
    sae,
    feature_ids=[12649],  # Can ablate multiple features
    intervention_type="ablation",
    n_numeric=50,
    n_special=50,
    section_counts= {"basic": 20, "numeric": 20, "special": 0, "zero": 0}
)

ablation_results_spokes = evaluate_model_with_intervention(
    model,
    sae,
    feature_ids=[12649, 15441],  # Can ablate multiple features
    intervention_type="ablation",
    n_numeric=50,
    n_special=50,
    section_counts= {"basic": 20, "numeric": 20, "special": 0, "zero": 0}
)

ablation_results_hub_spoke_some = evaluate_model_with_intervention(
    model,
    sae,
    feature_ids=[12257, 15441],  # Can ablate multiple features
    intervention_type="ablation",
    n_numeric=50,
    n_special=50,
    section_counts= {"basic": 20, "numeric": 20, "special": 0, "zero": 0}
)


Testing types:   0%|          | 0/2 [00:00<?, ?it/s]

Testing numeric:   0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

Testing special:   0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

Testing types:   0%|          | 0/2 [00:00<?, ?it/s]

Testing numeric:   0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

Testing special:   0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

Testing types:   0%|          | 0/2 [00:00<?, ?it/s]

Testing numeric:   0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

Testing special:   0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

Testing types:   0%|          | 0/2 [00:00<?, ?it/s]

Testing numeric:   0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

Testing special:   0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

Testing types:   0%|          | 0/2 [00:00<?, ?it/s]

Testing numeric:   0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

Testing special:   0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

In [78]:

ablated_results_hub_spoke_all = evaluate_model_with_intervention(
    model,
    sae,
    feature_ids=[12257, 12649],  # Can ablate multiple features
    intervention_type="ablation",
    n_numeric=50,
    n_special=50,
    section_counts= {"basic": 20, "numeric": 20, "special": 0, "zero": 0}
)

Testing types:   0%|          | 0/2 [00:00<?, ?it/s]

Testing numeric:   0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

Testing special:   0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

Compare the effect of ablation of the hub alone both spokes alone hub and both spokes together and both spokes together I hope hub and spokes will be more effective than any of these 

In [85]:
print(normal_results["numeric_accuracy_strict"])
print(normal_results["numeric_accuracy_lenient"])
print(normal_results["special_accuracy_strict"])
print(normal_results["special_accuracy_lenient"])


0.0
1.0
0.0
0.62


In [79]:
print(ablated_results_hub["numeric_accuracy_strict"])
print(ablated_results_hub["numeric_accuracy_lenient"])
print(ablated_results_hub["special_accuracy_strict"])
print(ablated_results_hub["special_accuracy_lenient"])

0.0
0.74
0.0
0.6


In [80]:
print(ablated_results_some["numeric_accuracy_strict"])
print(ablated_results_some["numeric_accuracy_lenient"])
print(ablated_results_some["special_accuracy_strict"])
print(ablated_results_some["special_accuracy_lenient"])

0.0
0.84
0.0
0.64


In [81]:
print(ablated_results_all["numeric_accuracy_strict"])
print(ablated_results_all["numeric_accuracy_lenient"])
print(ablated_results_all["special_accuracy_strict"])
print(ablated_results_all["special_accuracy_lenient"])

0.0
0.9
0.0
0.64


In [82]:
print(ablation_results_spokes["numeric_accuracy_strict"])
print(ablation_results_spokes["numeric_accuracy_lenient"])
print(ablation_results_spokes["special_accuracy_strict"])
print(ablation_results_spokes["special_accuracy_lenient"])


0.0
0.88
0.0
0.6


In [83]:
print(ablation_results_hub_spoke_some["numeric_accuracy_strict"])
print(ablation_results_hub_spoke_some["numeric_accuracy_lenient"])
print(ablation_results_hub_spoke_some["special_accuracy_strict"])
print(ablation_results_hub_spoke_some["special_accuracy_lenient"])


0.0
0.86
0.0
0.68


In [84]:
print(ablated_results_hub_spoke_all["numeric_accuracy_strict"])
print(ablated_results_hub_spoke_all["numeric_accuracy_lenient"])
print(ablated_results_hub_spoke_all["special_accuracy_strict"])
print(ablated_results_hub_spoke_all["special_accuracy_lenient"])


0.0
0.94
0.0
0.6
