In [1]:
import random

import torch
from num2words import num2words

# from transformer_lens import HookedTransformer
from sae_lens import SAE, ActivationsStore, HookedSAETransformer
from tqdm.autonotebook import tqdm

from sae_cooccurrence.utils.set_paths import get_git_root

In [2]:
torch.set_grad_enabled(False)

git_root = get_git_root()

if torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cuda" if torch.cuda.is_available() else "cpu"

In [3]:
# from transformer_lens import HookedTransformer

model = HookedSAETransformer.from_pretrained("gemma-2-2b", device=device)

# the cfg dict is returned alongside the SAE since it may contain useful information for analysing the SAE (eg: instantiating an activation store)
# Note that this is not the same as the SAEs config dict, rather it is whatever was in the HF repo, from which we can extract the SAE config dict
# We also return the feature sparsities which are stored in HF for convenience.
sae, cfg_dict, sparsity = SAE.from_pretrained(
    release="gemma-scope-2b-pt-res-canonical",  # <- Release name
    sae_id="layer_12/width_16k/canonical",  # <- SAE id (not always a hook point!)
    device=device,
)

activation_store = ActivationsStore.from_sae(
    model=model,
    sae=sae,
    streaming=True,
    # fairly conservative parameters here so can use same for larger
    # models without running out of memory.
    store_batch_size_prompts=8,
    train_batch_size_tokens=4096,
    n_batches_in_buffer=4,
    device=device,
)



Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]



Loaded pretrained model gemma-2-2b into HookedTransformer


Resolving data files:   0%|          | 0/30 [00:00<?, ?it/s]

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [4]:
class TokenQuestionGenerator:
    def __init__(self):
        self.colors = ["black", "white"]
        self.special_cases = {"some": "some", "all": "none", "none": "all"}

    def _number_to_words(self, n: int) -> str:
        """Convert a number to words."""
        return num2words(n)

    def _questions_are_equivalent(self, q1: str, q2: str) -> bool:
        """Compare two questions to check if they are functionally equivalent."""
        # Extract the numerical values and colors from both questions
        q1_parts = q1.split()
        q2_parts = q2.split()

        # If lengths are different, questions can't be equivalent
        if len(q1_parts) != len(q2_parts):
            return False

        return q1 == q2  # Direct comparison for now

    def generate_numeric_question(self) -> tuple[str, str]:
        """Generate a question with numeric values."""
        n_total = random.randint(1, 10)
        n_color_tokens = random.randint(0, n_total)
        colors = random.sample(self.colors, 2)
        held_color, test_color = colors

        question = (
            f"Q: I have {self._number_to_words(n_total)} tokens, and "
            f"{self._number_to_words(n_color_tokens)} of them are {held_color}. "
            f"How many of my tokens are {test_color}?"
        )

        answer = f"A: {self._number_to_words(n_total - n_color_tokens)} of them are {test_color}"

        return question, answer

    def generate_special_case_question(self) -> tuple[str, str]:
        """Generate a question with special quantifiers (some, all, none)."""
        special_type = random.choice(list(self.special_cases.keys()))
        colors = random.sample(self.colors, 2)
        held_color, test_color = colors
        n_total = random.randint(2, 10)  # Using at least 2 tokens for special cases

        question = (
            f"Q: I have {self._number_to_words(n_total)} tokens, and "
            f"{special_type} of them are {held_color}. "
            f"How many of my tokens are {test_color}?"
        )

        answer = f"A: {self.special_cases[special_type]} of them are {test_color}"

        return question, answer

    def generate_training_set(
        self, num_numeric: int = 5, num_special: int = 3
    ) -> list[tuple[str, str]]:
        """Generate a set of training questions."""
        training_set = []

        # Generate numeric questions
        for _ in range(num_numeric):
            training_set.append(self.generate_numeric_question())

        # Generate special case questions
        for _ in range(num_special):
            training_set.append(self.generate_special_case_question())

        random.shuffle(training_set)
        return training_set

    def generate_test_question(
        self, force_numeric: bool | None = None, force_special: bool | None = None
    ) -> tuple[str, str]:
        """
        Generate a test question.
        Parameters:
            force_numeric: If True, generates numeric question
            force_special: If True, generates special case question
        If both are None, randomly chooses between numeric and special case.
        """
        if force_numeric and force_special:
            raise ValueError("Cannot force both numeric and special case")

        if force_numeric:
            return self.generate_numeric_question()
        elif force_special:
            return self.generate_special_case_question()
        else:
            return random.choice(
                [self.generate_numeric_question, self.generate_special_case_question]
            )()


def generate_training_and_test(
    num_training_numeric: int = 5,
    num_training_special: int = 3,
    force_test_type: str | None = None,  # Can be 'numeric', 'special', or None
    max_attempts: int = 100,  # Add maximum attempts to prevent infinite loops
) -> dict:
    """Generate a complete training set and test question."""
    generator = TokenQuestionGenerator()

    # Generate test question first
    force_numeric = True if force_test_type == "numeric" else None
    force_special = True if force_test_type == "special" else None
    test_question, test_answer = generator.generate_test_question(
        force_numeric=force_numeric, force_special=force_special
    )

    # Generate training set, ensuring no duplicates with test question
    training_set = []
    attempts = 0

    while (
        len(training_set) < (num_training_numeric + num_training_special)
        and attempts < max_attempts
    ):
        current_set = generator.generate_training_set(
            num_numeric=num_training_numeric, num_special=num_training_special
        )

        # Filter out any questions that match the test question
        training_set = [
            (q, a)
            for q, a in current_set
            if not generator._questions_are_equivalent(q, test_question)
        ]

        attempts += 1

    if attempts >= max_attempts:
        raise RuntimeError(
            "Failed to generate unique training set after maximum attempts"
        )

    return {
        "introduction": "Tokens can be either black or white. Complete the following sentences, always use numbers (one, two, three, etc.) never digits (1, 2, 3, etc.):",
        "training_questions": training_set,
        "test_question": test_question,
        "test_answer": test_answer,
    }


# Example usage:
if __name__ == "__main__":
    result = generate_training_and_test(
        num_training_numeric=3,
        num_training_special=2,
        force_test_type="numeric",  # Can be 'numeric', 'special', or None
    )

    print(result["introduction"])
    print("\nTraining Questions:")
    for q, a in result["training_questions"]:
        print(f"{q}\n{a}\n")
    print("\nTest Question:")
    print(result["test_question"])
    print("\nTest Answer:")
    print(result["test_answer"])

Tokens can be either black or white. Complete the following sentences, always use numbers (one, two, three, etc.) never digits (1, 2, 3, etc.):

Training Questions:
Q: I have three tokens, and one of them are white. How many of my tokens are black?
A: two of them are black

Q: I have six tokens, and none of them are white. How many of my tokens are black?
A: all of them are black

Q: I have two tokens, and all of them are black. How many of my tokens are white?
A: none of them are white

Q: I have nine tokens, and nine of them are white. How many of my tokens are black?
A: zero of them are black

Q: I have one tokens, and zero of them are black. How many of my tokens are white?
A: one of them are white


Test Question:
Q: I have three tokens, and three of them are white. How many of my tokens are black?

Test Answer:
A: zero of them are black


In [5]:
from functools import partial


def find_max_activation(model, sae, activation_store, feature_idx, num_batches=100):
    """
    Find the maximum activation for a given feature index. This is useful for
    calibrating the right amount of the feature to add.
    """
    max_activation = 0.0

    pbar = tqdm(range(num_batches))
    for _ in pbar:
        tokens = activation_store.get_batch_tokens()

        _, cache = model.run_with_cache(
            tokens,
            stop_at_layer=sae.cfg.hook_layer + 1,
            names_filter=[sae.cfg.hook_name],
        )
        sae_in = cache[sae.cfg.hook_name]
        feature_acts = sae.encode(sae_in).squeeze()

        feature_acts = feature_acts.flatten(0, 1)
        batch_max_activation = feature_acts[:, feature_idx].max().item()
        max_activation = max(max_activation, batch_max_activation)

        pbar.set_description(f"Max activation: {max_activation:.4f}")

    return max_activation


def steering(activations, steering_strength=1.0, steering_vector=None, max_act=1.0):
    # Note if the feature fires anyway, we'd be adding to that here.
    return activations + max_act * steering_strength * steering_vector


def generate_with_steering(
    model,
    sae,
    prompt,
    steering_feature,
    max_act,
    steering_strength=1.0,
    max_new_tokens=95,
):
    input_ids = model.to_tokens(prompt, prepend_bos=sae.cfg.prepend_bos)

    steering_vector = sae.W_dec[steering_feature].to(model.cfg.device)

    steering_hook = partial(
        steering,
        steering_vector=steering_vector,
        steering_strength=steering_strength,
        max_act=max_act,
    )

    # standard transformerlens syntax for a hook context for generation
    with model.hooks(fwd_hooks=[(sae.cfg.hook_name, steering_hook)]):
        output = model.generate(
            input_ids,
            max_new_tokens=max_new_tokens,
            temperature=0.7,
            top_p=0.9,
            stop_at_eos=False if device == "mps" else True,
            prepend_bos=sae.cfg.prepend_bos,
        )

    return model.tokenizer.decode(output[0])

In [6]:
def evaluate_model_with_steering(
    model: HookedSAETransformer,
    sae: SAE,
    feature_to_steer: int | None = None,
    steering_strength: float = 1.0,
    max_act: float = 60.0,
    n_numeric: int = 10,
    n_special: int = 10,
    max_new_tokens: int = 5,
) -> dict:
    """
    Test model accuracy on token counting tasks with optional feature steering.

    Args:
        model: The transformer model
        sae: The sparse autoencoder
        feature_to_steer: Feature index to steer, or None for no steering
        steering_strength: Strength of steering (default 1.0)
        max_act: Maximum activation for the steered feature
        n_numeric: Number of numeric test questions
        n_special: Number of special test questions
        max_new_tokens: Maximum tokens to generate for each answer

    Returns:
        Dictionary containing accuracy metrics and test results
    """
    results = {
        "numeric_correct": 0,
        "special_correct": 0,
        "numeric_tests": [],
        "special_tests": [],
    }

    for test_type in ["numeric", "special"]:
        n_tests = n_numeric if test_type == "numeric" else n_special
        for _ in range(n_tests):
            test_data = generate_training_and_test(
                num_training_numeric=40,
                num_training_special=20,
                force_test_type=test_type,
            )

            prompt = (
                test_data["introduction"]
                + "\n\n"
                + "\n\n".join(f"{q}\n{a}" for q, a in test_data["training_questions"])
                + f"\n\n{test_data['test_question']}\nA: "
            )

            if feature_to_steer is not None:
                generated = generate_with_steering(
                    model,
                    sae,
                    prompt,
                    feature_to_steer,
                    max_act,
                    steering_strength=steering_strength,
                    max_new_tokens=max_new_tokens,
                )
            else:
                # Convert prompt to tokens first
                input_ids = model.to_tokens(prompt, prepend_bos=sae.cfg.prepend_bos)
                output = model.generate(
                    input_ids,
                    max_new_tokens=max_new_tokens,
                    # temperature=0.7,
                    # top_p=0.9,
                    stop_at_eos=False if device == "mps" else True,
                    prepend_bos=sae.cfg.prepend_bos,
                )
                generated = model.tokenizer.decode(output[0])

            test_result = {
                "prompt": prompt,
                "expected": test_data["test_answer"],
                "generated": generated,
                "correct": test_data["test_answer"] in generated,
            }

            if test_type == "numeric":
                results["numeric_tests"].append(test_result)
                if test_result["correct"]:
                    results["numeric_correct"] += 1
            else:
                results["special_tests"].append(test_result)
                if test_result["correct"]:
                    results["special_correct"] += 1

    # Calculate accuracies
    results["numeric_accuracy"] = results["numeric_correct"] / n_numeric
    results["special_accuracy"] = results["special_correct"] / n_special
    results["total_accuracy"] = (
        results["numeric_correct"] + results["special_correct"]
    ) / (n_numeric + n_special)

    return results


# Example usage:
if __name__ == "__main__":
    # Test without steering
    normal_results = evaluate_model_with_steering(model, sae)
    print("\nResults without steering:")
    print(f"Numeric accuracy: {normal_results['numeric_accuracy']:.2%}")
    print(f"Special accuracy: {normal_results['special_accuracy']:.2%}")
    print(f"Total accuracy: {normal_results['total_accuracy']:.2%}")

    print("\nNumeric test results:")
    for i, test in enumerate(normal_results["numeric_tests"], 1):
        print(f"\nTest {i}:")
        question = test["prompt"].split("A: ")[0].splitlines()[-1]
        print(f"Question: {question}")
        print(f"Expected: {test['expected']}")
        print(f"Generated: {test['generated']}")
        print(f"Correct: {test['correct']}")

    print("\nSpecial test results:")
    for i, test in enumerate(normal_results["special_tests"], 1):
        print(f"\nTest {i}:")
        question = test["prompt"].split("A: ")[0].splitlines()[-1]
        print(f"Question: {question}")
        print(f"Expected: {test['expected']}")
        print(f"Generated: {test['generated']}")
        print(f"Correct: {test['correct']}")

    # Test with feature steering
    feature_to_steer = 12257  # Replace with your feature of interest
    steering_strength = 0.0
    steered_results = evaluate_model_with_steering(
        model,
        sae,
        feature_to_steer=feature_to_steer,
        steering_strength=steering_strength,
    )
    print(
        f"\nResults with feature {feature_to_steer} steered (strength {steering_strength}):"
    )
    print(f"Numeric accuracy: {steered_results['numeric_accuracy']:.2%}")
    print(f"Special accuracy: {steered_results['special_accuracy']:.2%}")
    print(f"Total accuracy: {steered_results['total_accuracy']:.2%}")

    print("\nNumeric test results (with steering):")
    for i, test in enumerate(steered_results["numeric_tests"], 1):
        print(f"\nTest {i}:")
        question = test["prompt"].split("A: ")[0].splitlines()[-1]
        print(f"Question: {question}")
        print(f"Expected: {test['expected']}")
        print(f"Generated: {test['generated']}")
        print(f"Correct: {test['correct']}")

    print("\nSpecial test results (with steering):")
    for i, test in enumerate(steered_results["special_tests"], 1):
        print(f"\nTest {i}:")
        question = test["prompt"].split("A: ")[0].splitlines()[-1]
        print(f"Question: {question}")
        print(f"Expected: {test['expected']}")
        print(f"Generated: {test['generated']}")
        print(f"Correct: {test['correct']}")

    # Test with feature steering
    feature_to_steer = 15441  # Replace with your feature of interest
    steering_strength = 0.0
    steered_results = evaluate_model_with_steering(
        model,
        sae,
        feature_to_steer=feature_to_steer,
        steering_strength=steering_strength,
    )
    print(
        f"\nResults with feature {feature_to_steer} steered (strength {steering_strength}):"
    )
    print(f"Numeric accuracy: {steered_results['numeric_accuracy']:.2%}")
    print(f"Special accuracy: {steered_results['special_accuracy']:.2%}")
    print(f"Total accuracy: {steered_results['total_accuracy']:.2%}")

    print("\nNumeric test results (with steering):")
    for i, test in enumerate(steered_results["numeric_tests"], 1):
        print(f"\nTest {i}:")
        question = test["prompt"].split("A: ")[0].splitlines()[-1]
        print(f"Question: {question}")
        print(f"Expected: {test['expected']}")
        print(f"Generated: {test['generated']}")
        print(f"Correct: {test['correct']}")

    print("\nSpecial test results (with steering):")
    for i, test in enumerate(steered_results["special_tests"], 1):
        print(f"\nTest {i}:")
        question = test["prompt"].split("A: ")[0].splitlines()[-1]
        print(f"Question: {question}")
        print(f"Expected: {test['expected']}")
        print(f"Generated: {test['generated']}")
        print(f"Correct: {test['correct']}")

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [21]:
def evaluate_model_with_intervention(
    model: HookedSAETransformer,
    sae: SAE,
    feature_ids: int | list[int] | None = None,
    intervention_type: str = "none",  # "none", "steering", or "ablation"
    steering_strength: float = 1.0,
    max_act: float = 60.0,
    n_numeric: int = 10,
    n_special: int = 10,
    max_new_tokens: int = 5,
) -> dict:
    """
    Test model accuracy on token counting tasks with optional feature steering or ablation.

    Args:
        model: The transformer model
        sae: The sparse autoencoder
        feature_ids: Feature index(es) to intervene on, or None for no intervention
        intervention_type: Type of intervention ("none", "steering", or "ablation")
        steering_strength: Strength of steering (default 1.0)
        max_act: Maximum activation for the steered feature
        n_numeric: Number of numeric test questions
        n_special: Number of special test questions
        max_new_tokens: Maximum tokens to generate for each answer

    Returns:
        Dictionary containing accuracy metrics and test results
    """
    results = {
        "numeric_correct": 0,
        "special_correct": 0,
        "numeric_tests": [],
        "special_tests": [],
    }

    # Convert single feature_id to list for consistency
    if isinstance(feature_ids, int):
        feature_ids = [feature_ids]

    def ablate_feature_hook(feature_activations, feature_ids):
        feature_activations[:, :, feature_ids] = 0
        return feature_activations

    for test_type in ["numeric", "special"]:
        n_tests = n_numeric if test_type == "numeric" else n_special
        for _ in range(n_tests):
            test_data = generate_training_and_test(
                num_training_numeric=40,
                num_training_special=20,
                force_test_type=test_type,
            )

            prompt = (
                test_data["introduction"]
                + "\n\n"
                + "\n\n".join(f"{q}\n{a}" for q, a in test_data["training_questions"])
                + f"\n\n{test_data['test_question']}\nA: "
            )

            # Convert prompt to tokens first
            input_ids = model.to_tokens(prompt, prepend_bos=sae.cfg.prepend_bos)

            if intervention_type == "none" or feature_ids is None:
                output = model.generate(
                    input_ids,
                    max_new_tokens=max_new_tokens,
                    stop_at_eos=False if device == "mps" else True,
                    prepend_bos=sae.cfg.prepend_bos,
                )

            elif intervention_type == "steering":
                steering_vector = sae.W_dec[feature_ids[0]].to(model.cfg.device)
                steering_hook = partial(
                    steering,
                    steering_vector=steering_vector,
                    steering_strength=steering_strength,
                    max_act=max_act,
                )

                with model.hooks(fwd_hooks=[(sae.cfg.hook_name, steering_hook)]):
                    output = model.generate(
                        input_ids,
                        max_new_tokens=max_new_tokens,
                        temperature=0.7,
                        top_p=0.9,
                        stop_at_eos=False if device == "mps" else True,
                        prepend_bos=sae.cfg.prepend_bos,
                    )

            elif intervention_type == "ablation":
                ablation_hook = partial(ablate_feature_hook, feature_ids=feature_ids)
                model.add_sae(sae)
                hook_point = sae.cfg.hook_name + ".hook_sae_acts_post"

                with model.hooks(fwd_hooks=[(hook_point, ablation_hook)]):
                    output = model.generate(
                        input_ids,
                        max_new_tokens=max_new_tokens,
                        stop_at_eos=False if device == "mps" else True,
                        prepend_bos=sae.cfg.prepend_bos,
                    )

                model.reset_hooks()
                model.reset_saes()

            generated = model.tokenizer.decode(output[0])
            generated_answer = generated.split("A: ")[-1].strip()
            expected_answer = test_data["test_answer"].replace("A: ", "").strip()

            test_result = {
                "prompt": prompt,
                "expected": expected_answer,
                "generated": generated_answer,
                "correct": expected_answer in generated_answer,
            }

            if test_type == "numeric":
                results["numeric_tests"].append(test_result)
                if test_result["correct"]:
                    results["numeric_correct"] += 1
            else:
                results["special_tests"].append(test_result)
                if test_result["correct"]:
                    results["special_correct"] += 1

    # Calculate accuracies
    results["numeric_accuracy"] = results["numeric_correct"] / n_numeric
    results["special_accuracy"] = results["special_correct"] / n_special
    results["total_accuracy"] = (
        results["numeric_correct"] + results["special_correct"]
    ) / (n_numeric + n_special)

    return results

In [22]:
normal_results = evaluate_model_with_intervention(
    model, sae, intervention_type="none", n_numeric=1, n_special=1
)

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

In [27]:
print(normal_results["special_tests"])

[{'prompt': 'Tokens can be either black or white. Complete the following sentences, always use numbers (one, two, three, etc.) never digits (1, 2, 3, etc.):\n\nQ: I have two tokens, and none of them are black. How many of my tokens are white?\nA: all of them are white\n\nQ: I have ten tokens, and some of them are white. How many of my tokens are black?\nA: some of them are black\n\nQ: I have eight tokens, and none of them are white. How many of my tokens are black?\nA: all of them are black\n\nQ: I have three tokens, and zero of them are white. How many of my tokens are black?\nA: three of them are black\n\nQ: I have five tokens, and some of them are black. How many of my tokens are white?\nA: some of them are white\n\nQ: I have four tokens, and zero of them are black. How many of my tokens are white?\nA: four of them are white\n\nQ: I have two tokens, and some of them are white. How many of my tokens are black?\nA: some of them are black\n\nQ: I have four tokens, and one of them are w

In [8]:
# Test without intervention
normal_results = evaluate_model_with_intervention(model, sae, intervention_type="none")

# # Test with steering
# steered_results = evaluate_model_with_intervention(
#     model,
#     sae,
#     feature_ids=12257,
#     intervention_type="steering",
#     steering_strength=1.0,
# )

# Test with ablation
ablated_results_hub = evaluate_model_with_intervention(
    model,
    sae,
    feature_ids=[12257],  # Can ablate multiple features
    intervention_type="ablation",
)

ablated_results_some = evaluate_model_with_intervention(
    model,
    sae,
    feature_ids=[15441],  # Can ablate multiple features
    intervention_type="ablation",
)

ablated_results_all = evaluate_model_with_intervention(
    model,
    sae,
    feature_ids=[12649],  # Can ablate multiple features
    intervention_type="ablation",
)

ablation_results_spokes = evaluate_model_with_intervention(
    model,
    sae,
    feature_ids=[12649, 15441],  # Can ablate multiple features
    intervention_type="ablation",
)

ablation_results_hub_spoke_some = evaluate_model_with_intervention(
    model,
    sae,
    feature_ids=[12257, 15441],  # Can ablate multiple features
    intervention_type="ablation",
)

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

Compare the effect of ablation of the hub alone both spokes alone hub and both spokes together and both spokes together I hope hub and spokes will be more effective than any of these 

In [16]:
print(normal_results["total_accuracy"])
print(ablated_results_hub["total_accuracy"])
print(ablated_results_some["total_accuracy"])
print(ablated_results_all["total_accuracy"])
print(ablation_results_spokes["total_accuracy"])
print(ablation_results_hub_spoke_some["total_accuracy"])

0.9
0.9
0.85
0.9
0.85
0.95


In [17]:
print(normal_results["numeric_accuracy"])
print(ablated_results_hub["numeric_accuracy"])
print(ablated_results_some["numeric_accuracy"])
print(ablated_results_all["numeric_accuracy"])
print(ablation_results_spokes["numeric_accuracy"])
print(ablation_results_hub_spoke_some["numeric_accuracy"])

0.8
0.9
0.7
0.8
0.7
0.9


In [18]:
print(normal_results["special_accuracy"])
print(ablated_results_hub["special_accuracy"])
print(ablated_results_some["special_accuracy"])
print(ablated_results_all["special_accuracy"])
print(ablation_results_spokes["special_accuracy"])
print(ablation_results_hub_spoke_some["special_accuracy"])

1.0
0.9
1.0
1.0
1.0
1.0


In [None]:
# Test without intervention
normal_results = evaluate_model_with_intervention(model, sae, intervention_type="none")

# # Test with steering
# steered_results = evaluate_model_with_intervention(
#     model,
#     sae,
#     feature_ids=12257,
#     intervention_type="steering",
#     steering_strength=1.0,
# )

# Test with ablation
ablated_results_hub = evaluate_model_with_intervention(
    model,
    sae,
    feature_ids=[12257],  # Can ablate multiple features
    intervention_type="ablation",
)

ablated_results_some = evaluate_model_with_intervention(
    model,
    sae,
    feature_ids=[15441],  # Can ablate multiple features
    intervention_type="ablation",
)

ablated_results_all = evaluate_model_with_intervention(
    model,
    sae,
    feature_ids=[12649],  # Can ablate multiple features
    intervention_type="ablation",
)

ablation_results_spokes = evaluate_model_with_intervention(
    model,
    sae,
    feature_ids=[12649, 15441],  # Can ablate multiple features
    intervention_type="ablation",
)

ablation_results_hub_spoke_some = evaluate_model_with_intervention(
    model,
    sae,
    feature_ids=[12257, 15441],  # Can ablate multiple features
    intervention_type="ablation",
)