<a href="https://www.kaggle.com/code/veerajaveeraesh/inference-using-llava-steerfair-on-scienceqa?scriptVersionId=240257312" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

# Environment Setup and Installations

In [1]:
# Clear the working directory to ensure a fresh start
!rm -rf /kaggle/working/*

# Install baukit for model interpretability (used for TraceDict)
!pip install git+https://github.com/davidbau/baukit

# Install Hugging Face transformers and datasets from their git repositories
!pip install git+https://github.com/huggingface/transformers.git
!pip install git+https://github.com/huggingface/datasets.git

# Uninstall existing flash-attn and reinstall from Dao-AILab's repository
# This specific sequence is for a particular version or feature of flash-attention
!pip uninstall flash-attn -y
!pip install git+https://github.com/Dao-AILab/flash-attention.git

# Install accelerate and bitsandbytes for model optimization and quantization
!pip install --upgrade -q accelerate bitsandbytes

# Clone the ScienceQA dataset repository
!git clone https://huggingface.co/datasets/derek-thomas/ScienceQA

# Install triton, often a dependency or used with flash-attention
!pip install triton

# Install GPUtil for GPU monitoring
!pip install GPUtil

# Install numba for JIT compilation of Python code, can speed up numerical operations
!pip install numba

Collecting git+https://github.com/davidbau/baukit
  Cloning https://github.com/davidbau/baukit to /tmp/pip-req-build-e065mhqz
  Running command git clone --filter=blob:none --quiet https://github.com/davidbau/baukit /tmp/pip-req-build-e065mhqz
  Resolved https://github.com/davidbau/baukit to commit 9d51abd51ebf29769aecc38c4cbef459b731a36e
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Building wheels for collected packages: baukit
  Building wheel for baukit (pyproject.toml) ... [?25l[?25hdone
  Created wheel for baukit: filename=baukit-0.0.1-py3-none-any.whl size=59678 sha256=cbd6c20a02df2b8804f3169f1c08e0fe770c0bbf702773e3004cfabebefd9ff8
  Stored in directory: /tmp/pip-ephem-wheel-cache-34rj5c6a/wheels/e2/7a/dc/eb53bf0e7f86297d7d9759d9eba117036e850e1bfc3bda0176
Successfully built baukit
Installing collected packages: baukit
Successfully installed

# Imports and Initial Configurations
This cell imports all the necessary Python libraries that will be used throughout the notebook. It also includes an initial GPU check and sets PyTorch's gradient calculation to off, as this notebook focuses on inference and activation extraction, not training.

In [2]:
import PIL
import io
import requests
from transformers import AutoProcessor, LlavaForConditionalGeneration
from transformers import BitsAndBytesConfig
from datasets import load_dataset
import torch
import json
import numpy as np
import pandas as pd
import os
from tqdm import tqdm
import shutil

# For model interpretability, specifically for hooking into model layers
from baukit import TraceDict

# For GPU monitoring
from GPUtil import showUtilization as gpu_usage

# Check GPU availability and set device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
if device == "cuda":
    gpu_usage()

# Disable gradient calculations globally, as we are doing inference and extraction
torch.set_grad_enabled(False)

# Further imports for the main Head-wise steering part
import torch.nn as nn
from tqdm.auto import tqdm # Use tqdm.auto for better notebook integration
import random
import gc # Garbage collector
from sklearn.decomposition import PCA
import re # Regular expressions
import seaborn as sns
import matplotlib.pyplot as plt

Using device: cuda
| ID | GPU | MEM |
------------------
|  0 |  0% |  0% |


# Load and Prepare ScienceQA Dataset
Here, we load the "derek-thomas/ScienceQA" dataset. We're using the "test" split.
- The dataset is shuffled for randomness.
- A subset of 1000 samples is selected for manageability in this pipeline.
- The dataset is converted to a Pandas DataFrame for easier inspection and manipulation.
- A unique q_id (question ID) is added to each sample in the sampled dataset, which will be useful for linking samples to their saved images.

In [3]:
# Load the ScienceQA dataset, test split
dataset = load_dataset("derek-thomas/ScienceQA", split="test")
print(f"Original dataset type: {type(dataset)}")

# Shuffle the dataset for random sampling
shuffled_dataset = dataset.shuffle(seed=42)
print(f"Shuffled dataset type: {type(shuffled_dataset)}")

# Select a subset of 1000 samples for this pipeline
# This makes processing faster and more manageable for demonstration
NUM_SAMPLES_TO_USE = 1000
sampled_dataset = shuffled_dataset.select(range(NUM_SAMPLES_TO_USE))
print(f"Sampled dataset type: {type(sampled_dataset)}")

# Convert the full dataset to a Pandas DataFrame for a quick look
df = dataset.to_pandas()
print("\nExample from full dataset (df.iloc[1]):")
print(df.iloc[1])

# Convert the sampled dataset to a Pandas DataFrame
df_sampled = sampled_dataset.to_pandas()
print("\nFirst 3 examples from sampled dataset (df_sampled.head(3)):")
print(df_sampled.head(3))

print(f"\nLength of original dataset: {len(dataset)}")
print(f"Length of sampled dataset: {len(sampled_dataset)}")

# Add a unique 'q_id' to each sample in the sampled dataset
# This ID will be used to name image files and track samples
df_sampled["q_id"] = range(len(sampled_dataset))
sampled_dataset = sampled_dataset.from_pandas(df_sampled) # Convert back to Hugging Face Dataset object
print(f"\nSampled dataset representation after adding q_id:\n{sampled_dataset}")

README.md:   0%|          | 0.00/10.3k [00:00<?, ?B/s]

(…)-00000-of-00001-1028f23e353fbe3e.parquet:   0%|          | 0.00/377M [00:00<?, ?B/s]

(…)-00000-of-00001-6c7328ff6c84284c.parquet:   0%|          | 0.00/126M [00:00<?, ?B/s]

(…)-00000-of-00001-f0e719df791966ff.parquet:   0%|          | 0.00/122M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/12726 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/4241 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/4241 [00:00<?, ? examples/s]

Original dataset type: <class 'datasets.arrow_dataset.Dataset'>
Shuffled dataset type: <class 'datasets.arrow_dataset.Dataset'>
Sampled dataset type: <class 'datasets.arrow_dataset.Dataset'>

Example from full dataset (df.iloc[1]):
image       {'bytes': b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHD...
question     Which of the following could Gordon's test show?
choices     [if the spacecraft was damaged when using a pa...
answer                                                      1
hint        People can use the engineering-design process ...
task                                            closed choice
grade                                                  grade8
subject                                       natural science
topic                       science-and-engineering-practices
category                                Engineering practices
skill          Evaluate tests of engineering-design solutions
lecture     People can use the engineering-design process ...
solution                

# Save Images from Dataset
The ScienceQA dataset contains images. LLaVA is a multimodal model that processes both text and images. This cell defines a function `save_images` to extract image data from each sample and save it as a PNG file in the `/kaggle/working/images/` directory. The path to the saved image is stored. This is crucial because the model will need to load these images during processing.

In [4]:
def save_images(dataset: 'datasets.Dataset') -> list:
    """
    Saves images from the dataset samples to a local directory.

    Each image is saved as 'image_{idx}.png', where idx is the sample's index.
    If a sample has no image or an error occurs, None is stored for its path.

    Args:
        dataset: A Hugging Face Dataset object, where each sample might
                 contain an 'image' field with 'bytes'.

    Returns:
        A list of file paths to the saved images, or None for samples
        without images or with saving errors.
    """
    image_dir = '/kaggle/working/images'
    # Remove existing image directory to ensure fresh save
    if os.path.exists(image_dir):
        shutil.rmtree(image_dir)
    os.makedirs(image_dir, exist_ok=True)

    image_paths = []

    for idx, sample in tqdm(enumerate(dataset), total=len(dataset),
                             desc="Saving Images"):
        image_dict = sample.get('image')
        # Skip if image data is not present
        if image_dict is None:
            image_paths.append(None)
            continue

        image_bytes = image_dict.get("bytes")
        if image_bytes is None: # Handle cases where 'bytes' key might be missing within image_dict
             image_paths.append(None)
             continue

        try:
            # If it's already a PIL Image object (though unlikely if 'bytes' is primary)
            if isinstance(image_bytes, PIL.Image.Image):
                image_path = os.path.join(image_dir, f'image_{idx}.png')
                image_bytes.save(image_path)
                image_paths.append(image_path)
            else:
                # Convert bytes to PIL Image
                image_path = os.path.join(image_dir, f'image_{idx}.png')
                pil_image = PIL.Image.open(io.BytesIO(image_bytes))
                pil_image.save(image_path)
                image_paths.append(image_path)

        except Exception as e:
            print(f"Error saving image for sample {idx}: {e}")
            image_paths.append(None)

    return image_paths

# Save images from our sampled dataset
image_paths = save_images(sampled_dataset)

# Display the first 10 image paths
print("\nFirst 10 image paths (or None if no image/error):")
print(image_paths[:10])

Saving Images:   0%|          | 0/1000 [00:00<?, ?it/s]


First 10 image paths (or None if no image/error):
['/kaggle/working/images/image_0.png', '/kaggle/working/images/image_1.png', None, '/kaggle/working/images/image_3.png', None, None, '/kaggle/working/images/image_6.png', None, None, None]


# Analyze Choice Distribution and Create Dataset Variants
This cell performs two main tasks:
1.  **Analyzes Choice Distribution:** It counts how many questions have a certain number of multiple-choice options. This gives an idea of the dataset's structure.
2.  **Creates Dataset Variants:**
    -   `get_cyclic_perms`: This function generates all cyclic permutations of a list of choices. For example, if choices are \[A, B, C], permutations would be \[A, B, C], \[B, C, A], \[C, A, B].
    -   `create_dataset_variants`: This function takes the original `sampled_dataset` and expands it. For each question:
        - It creates versions of the question with each cyclic permutation of its choices.
        - For each of *these* versions, it further creates variants where each possible choice index (0, 1, 2, etc.) is designated as the "correct" answer for the purpose of generating demonstration prompts.
    - The goal of these variants (`demonstration_sets`) is to create a balanced set of examples where the model is prompted to output each choice index (0, 1, 2, 3, 4) as the correct one. This helps in identifying if the model has an inherent bias towards picking a certain answer position, regardless of the actual content. The `final_datalist` contains all these generated variants.

In [5]:
# Analyze the distribution of the number of choices per question
df_sampled_numchoices = df_sampled.copy() # Use a copy of the sampled dataframe
df_sampled_numchoices["numchoices"] = [len(lst) if lst is not None else 0 for lst in df_sampled_numchoices["choices"]]
print("\nDistribution of number of choices in sampled_dataset:")
print(df_sampled_numchoices.groupby("numchoices").size())


from collections import deque

def get_cyclic_perms(choices: list) -> list:
    """
    Generates all cyclic permutations of a list of choices.

    Args:
        choices: A list of answer choices.

    Returns:
        A list of lists, where each inner list is a cyclic permutation
        of the input choices.
    """
    if not choices:
        return [[]]
    cyclic_perms = []
    d = deque(choices)
    for _ in range(len(choices)):
        d.rotate() # Rotates the deque one step to the right
        cyclic_perms.append(list(d))
    return cyclic_perms

def create_dataset_variants(dataset: 'datasets.Dataset') -> tuple[dict, list]:
    """
    Creates variants of the dataset with different answer orders and designated answers.

    For each sample in the input dataset:
    1. Generates all cyclic permutations of its answer choices.
    2. For each permutation, creates new samples where each choice index
       (0, 1, 2, etc.) is set as the 'answer' for demonstration purposes.

    This creates a `demonstration_sets` dictionary keyed by the answer index (0-4),
    containing samples prepared to elicit that answer index. It also returns
    `final_datalist`, a flat list of all generated variants.

    Args:
        dataset: A Hugging Face Dataset object.

    Returns:
        A tuple containing:
        - demonstration_sets (dict): {answer_idx: [list_of_samples_for_this_answer_idx]}
        - final_datalist (list): A flat list of all created sample variants.
    """
    demonstration_sets = {0: [], 1: [], 2: [], 3: [], 4: []} # Assuming max 5 choices for ScienceQA
    new_datalist = []
    final_datalist = []

    for sample in tqdm(dataset, desc="Creating Cyclic Permutations"):
        if sample["choices"] is None: # Skip samples with no choices
            continue
        cyclic_perms_lst = get_cyclic_perms(sample["choices"])
        for i, lst in enumerate(cyclic_perms_lst):
            new_sample = sample.copy()
            new_sample['choices'] = cyclic_perms_lst[i]
            # The original 'answer' field refers to the *content* of the correct choice.
            # For bias identification, we care about the *position*.
            # The 'answer' field will be overwritten below for demonstrations.
            new_datalist.append(new_sample)

    for sample in tqdm(new_datalist, desc="Creating Demonstration Variants"):
        if not sample["choices"]: # Skip if choices list is empty after permutations
            continue
        for answer_idx in range(len(sample["choices"])):
            # We are interested in up to 5 choices/answer indices for demonstrations
            if answer_idx >= 5: # Cap at 5 for the demonstration_sets keys
                continue
            new_sample = sample.copy()
            new_sample["answer"] = answer_idx # This 'answer' is the target *index* for demonstration
            demonstration_sets[answer_idx].append(new_sample)
            final_datalist.append(new_sample)

    return demonstration_sets, final_datalist

# Create variants from the sampled_dataset
demonstration_sets, final_datalist = create_dataset_variants(sampled_dataset)

print("\nNumber of samples in demonstration_sets for each answer index:")
for k, v in demonstration_sets.items():
    print(f"Answer Index {k}: {len(v)} samples")

# Convert the list of all variants to a Pandas DataFrame for inspection
final_dataframe = pd.DataFrame(final_datalist)
print("\nExample rows from final_dataframe (variants):")
print(final_dataframe.iloc[16:20]) # Displaying a few rows to see variety

# Convert the final list of variants back to a Hugging Face Dataset if needed for other operations
# Though the main script uses `demonstration_sets` (dict of lists of dicts)
# and `df_sampled.to_dict('records')` for evaluation.
from datasets import Dataset as HFDataset # Alias to avoid conflict
final_dataset_huggingface = HFDataset.from_pandas(final_dataframe)
print(f"\nType of final_dataset (all variants): {type(final_dataset_huggingface)}")
print(f"Representation of final_dataset (all variants):\n{final_dataset_huggingface}")


Distribution of number of choices in sampled_dataset:
numchoices
2    522
3    220
4    250
5      8
dtype: int64


Creating Cyclic Permutations:   0%|          | 0/1000 [00:00<?, ?it/s]

Creating Demonstration Variants:   0%|          | 0/2744 [00:00<?, ?it/s]


Number of samples in demonstration_sets for each answer index:
Answer Index 0: 2744 samples
Answer Index 1: 2744 samples
Answer Index 2: 1700 samples
Answer Index 3: 1040 samples
Answer Index 4: 40 samples

Example rows from final_dataframe (variants):
                                                image  \
16  {'bytes': b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHD...   
17  {'bytes': b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHD...   
18  {'bytes': b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHD...   
19  {'bytes': b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHD...   

                                             question  \
16  Which animal's skin is also adapted for surviv...   
17  Which animal's skin is also adapted for surviv...   
18  Which animal's skin is also adapted for surviv...   
19  Which animal's skin is also adapted for surviv...   

                         choices  answer  \
16  [hairy armadillo, snowy owl]       0   
17  [hairy armadillo, snowy owl]       1   
18  [snowy owl, hairy armadillo]       0  

# Per Head Steering - Constants and Model Loading
This cell marks the beginning of the main, active part of the script: **per-head steering**.
- It defines global constants specific to this approach (e.g., `MODEL_ID`, `NUM_DEMONSTRATIONS`, file paths for saving steering data).
- It includes the `load_model_and_processor` function, which is responsible for loading the LLaVA model (e.g., `llava-hf/llava-1.5-7b-hf`) and its corresponding processor. The model is loaded with 4-bit quantization to save memory.
- It then calls this function to load the model and processor into memory.

In [6]:
# --- Per-Head Steering: Main Implementation ---

# Global constants for the per-head steering pipeline
MODEL_ID = "llava-hf/llava-1.5-7b-hf"
NUM_DEMONSTRATIONS = 250  # Number of demonstration examples per rule (answer position)
MAX_EVAL_SAMPLES = 100    # Max samples to use for evaluation runs
RANDOM_SEED = 42
STEERING_ALPHA = 1.0      # Scaling factor for identified bias directions
STEERING_FACTOR_INFERENCE = 1.0 # Factor for applying steering during inference
TOP_K_HEADS = 10          # Number of top important heads to use for steering per layer
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# File paths for saving/loading per-head steering data and raw activations
STEERING_DATA_FILE = "/kaggle/working/llava_per_head_steering_data.npz" # Stores final directions & scores
RAW_ACTIVATIONS_PER_HEAD_FILE = "/kaggle/working/llava_raw_activations_per_head.npz" # Stores raw activations
# PCA_DIRECTIONS_PER_HEAD_FILE = "/kaggle/working/llava_pca_directions_per_head.npz" # Intermediate (if saved separately)
# BIAS_DIRECTIONS_PER_HEAD_FILE = "/kaggle/working/llava_bias_directions_per_head.npz" # Intermediate (if saved separately)


# Constants for plotting results
LAYERS_TO_PLOT = ['layer_10', 'layer_20', 'layer_30'] # Example layers
TOP_K_HEADS_TO_PLOT = 5
RULES_TO_PLOT = [0, 4] # Example rules (answer indices) to show in plots
PLOT_OUTPUT_DIR = "/kaggle/working/top_head_raw_activation_plots"

# Set random seeds for reproducibility
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(RANDOM_SEED)

def load_model_and_processor():
    """Loads the LLaVA model and processor with quantization."""
    print(f"Loading model: {MODEL_ID}")
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16 # Use float16 for computation with 4-bit weights
    )
    processor = AutoProcessor.from_pretrained(MODEL_ID)
    model = LlavaForConditionalGeneration.from_pretrained(
        MODEL_ID,
        quantization_config=quantization_config,
        torch_dtype=torch.float16, # Load weights in float16 if not quantizing to float16 directly
        low_cpu_mem_usage=True,    # Optimizes CPU memory usage during loading
        device_map="auto"          # Automatically distribute model layers across available devices (GPU/CPU)
        # use_flash_attention_2=True # Enable if installed and compatible
    )
    print("Model and processor loaded successfully for per-head steering.")
    return model, processor

# Load the model and processor for the main per-head steering task
model, processor = load_model_and_processor()

Loading model: llava-hf/llava-1.5-7b-hf


processor_config.json:   0%|          | 0.00/173 [00:00<?, ?B/s]

chat_template.json:   0%|          | 0.00/701 [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/505 [00:00<?, ?B/s]

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


tokenizer_config.json:   0%|          | 0.00/1.45k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/3.62M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/41.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/552 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/950 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/70.1k [00:00<?, ?B/s]

Fetching 3 files:   0%|          | 0/3 [00:00<?, ?it/s]

model-00003-of-00003.safetensors:   0%|          | 0.00/4.18G [00:00<?, ?B/s]

model-00001-of-00003.safetensors:   0%|          | 0.00/4.99G [00:00<?, ?B/s]

model-00002-of-00003.safetensors:   0%|          | 0.00/4.96G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/141 [00:00<?, ?B/s]

Model and processor loaded successfully for per-head steering.


# SteerFairLlavaPerHead Class - Part 1: Initialization, Prompt Formatting, and Core Hooking Logic
This cell defines the `format_prompt` helper function and the initial part of the `SteerFairLlavaPerHead` class.
-   `format_prompt`: Creates the structured prompt required by LLaVA, incorporating the question, choices, and optionally an image token and a specified answer.
-   `SteerFairLlavaPerHead.__init__`: Initializes the class with the model, processor, and key parameters like number of heads and dimensions.
-   `_capture_activation_hook_per_head`: This is the core hook function. It's designed to be registered as a *pre-hook* on the output projection layer (`o_proj`) of each self-attention block. It captures the input to `o_proj`, reshapes it to isolate per-head activations, and stores the activations for the last token in the input sequence.
-   `_register_hooks`, `_remove_hooks`: Utility methods to manage the registration and removal of these hooks.
-   `_get_activations_for_input`: A method that takes a prompt and image, prepares them using the processor, runs a forward pass through the model (which triggers the registered hooks), and returns the captured per-head activations.

In [7]:
def format_prompt(sample: dict, include_answer: int = None, ask_for_exact_text: bool = False) -> str:
    """
    Formats a prompt for the LLaVA model.

    Args:
        sample: A dictionary containing 'question', 'choices', and optionally 'q_id'.
        include_answer: If not None, an integer representing the 0-indexed
                        correct answer. The prompt will include this answer.
        ask_for_exact_text: If True and include_answer is None, the prompt will
                            ask the model to respond with the option number.

    Returns:
        The formatted prompt string.
        (Note: The original script's format_prompt_lw returned (prompt, image).
         This version, aligned with the SteerFairLlavaPerHead usage, only returns the prompt string.
         Image loading is handled separately based on q_id within the class methods.)
    """
    question = sample['question']
    choices = sample['choices']
    qid = sample.get("q_id", None) # q_id is used to construct the image path

    # Construct image path based on q_id (consistent with how images were saved)
    path = f"/kaggle/working/images/image_{qid}.png" if qid is not None else None
    # Check if image exists (PIL will handle if path is None)
    image_exists = os.path.exists(path) if path else False
    # image = PIL.Image.open(path) if image_exists else None # Image loading handled in generate/get_activations

    prompt = "A chat between a curious human and an artificial intelligence assistant.\nThe assistant gives helpful, detailed, and polite answers to the human's questions.\nUSER: "
    image_token_str = "<image>\n" if image_exists else "" # Add <image> token if image file is expected
    prompt += image_token_str + question + "\n"

    numbered_choices = "\n".join([f"{i+1}. {choice}" for i, choice in enumerate(choices)])
    prompt += f"Choices:\n{numbered_choices}\n"



    if include_answer is not None:
        # include_answer is the 0-indexed correct choice index
        prompt += f"ASSISTANT: {include_answer + 1}" # Model is prompted with the 1-based number
    elif ask_for_exact_text:
        prompt += f"Question: Which option number is correct? Respond with only the single number.\nASSISTANT:"
    else:
        prompt += "ASSISTANT:" # Standard prompt for generation without a specified answer
    return prompt


class SteerFairLlavaPerHead:
    """
    Manages per-head activation extraction, bias identification, and steering for LLaVA.
    """
    def __init__(self, model, processor, device="cuda", alpha=1.0):
        """
        Initializes the SteerFairLlavaPerHead instance.

        Args:
            model: The pre-trained LLaVA model.
            processor: The LLaVA processor.
            device: The device to run computations on ('cuda' or 'cpu').
            alpha: Default scaling factor for bias directions.
        """
        self.model = model
        self.processor = processor
        self.device = device
        self.alpha = alpha
        self.hook_handles = []
        self.is_steering_active = False

        # Storage for activations, directions, scores, and target heads
        self.captured_activations_per_head = {} # {layer_key: {head_idx: [list_of_activations]}}
        self.bias_directions = {}             # {layer_key: {head_idx: direction_vector}}
        self.head_importance_scores = {}      # {layer_key: {head_idx: importance_score}}
        self.target_steering_heads = {}       # {layer_key: set(head_indices_to_steer)}

        # Model configuration (specific to LLaMA-based LLaVA, e.g., 7B)
        # These should ideally be fetched from model.config if possible and robustly
        self.num_heads = getattr(model.config, 'num_attention_heads', getattr(model.config.text_config, 'num_attention_heads', 32))
        self.hidden_dim = getattr(model.config, 'hidden_size', getattr(model.config.text_config, 'hidden_size', 4096))
        if self.num_heads == 0: # Fallback if config access fails
            print("Warning: num_heads is 0, defaulting to 32.")
            self.num_heads = 32
        if self.hidden_dim == 0:
            print("Warning: hidden_dim is 0, defaulting to 4096.")
            self.hidden_dim = 4096

        self.head_dim = self.hidden_dim // self.num_heads if self.num_heads > 0 else 0
        if self.head_dim == 0 and self.num_heads > 0 : # check division result
            print(f"Warning: head_dim is 0 after division ({self.hidden_dim}/{self.num_heads}). Defaulting to 128.")
            self.head_dim = 128


        print(f"SteerFairLlavaPerHead initialized: num_heads={self.num_heads}, head_dim={self.head_dim}, hidden_dim={self.hidden_dim}")

    def _capture_activation_hook_per_head(self, layer_key: str):
        """
        Creates a pre-hook function to capture per-head activations.

        This hook is intended for `module.o_proj.register_forward_pre_hook`.
        It captures the input to the o_proj layer, reshapes it to
        (num_heads, head_dim) for the last token, and stores these.

        Args:
            layer_key: A string identifier for the layer (e.g., 'layer_15').

        Returns:
            A hook function.
        """
        def pre_hook(module, input_args):
            # input_args[0] to o_proj is (batch_size, seq_len, hidden_dim)
            activation_tensor = input_args[0]
            if isinstance(activation_tensor, torch.Tensor) and activation_tensor.ndim >= 3:
                # Get the activation for the last token: (batch_size, hidden_dim)
                # We assume batch_size = 1 for demonstration collection
                last_token_activation_flat = activation_tensor[0, -1, :].detach() # Shape: (hidden_dim)

                try:
                    # Reshape to (num_heads, head_dim)
                    per_head_activations = last_token_activation_flat.reshape(self.num_heads, self.head_dim)
                except RuntimeError as e:
                    # This can happen if hidden_dim is not perfectly divisible or shapes are unexpected
                    print(f"Hook Error ({layer_key}): Reshape failed for tensor of shape {last_token_activation_flat.shape} to {(self.num_heads, self.head_dim)}. Error: {e}. Skipping.")
                    return input_args # Must return input_args for pre-hook

                # Initialize storage for this layer if not present
                if layer_key not in self.captured_activations_per_head:
                    self.captured_activations_per_head[layer_key] = {h_idx: [] for h_idx in range(self.num_heads)}

                # Store activation for each head
                for head_idx in range(self.num_heads):
                    # Detach, move to CPU, convert to numpy for storage efficiency
                    self.captured_activations_per_head[layer_key][head_idx].append(per_head_activations[head_idx].cpu().numpy())
                del per_head_activations, last_token_activation_flat # Memory cleanup
            return input_args # Pre-hooks must return input_args (as a tuple if modified)
        return pre_hook

    def _register_hooks(self, hook_fn_provider, use_pre_hook: bool = False):
        """
        Registers hooks to specified layers in the model.

        Args:
            hook_fn_provider: A function that takes a layer_key and returns a hook function.
            use_pre_hook: If True, registers a pre-hook; otherwise, a forward hook.
                          For activation capture/modification before o_proj, pre_hook=True is used.
        """
        self._remove_hooks() # Clear any existing hooks
        if use_pre_hook: # Reset capture dict only when registering capture hooks
            self.captured_activations_per_head = {}

            layers = self.model.language_model.layers
            module_name_part = 'self_attn' # Standard LLaMA self-attention block name

            for layer_idx, layer in enumerate(layers):
                attn_module = getattr(layer, module_name_part, None)
                if attn_module and hasattr(attn_module, 'o_proj'):
                    target_module = attn_module.o_proj # Target the output projection
                    layer_key = f"layer_{layer_idx}"

                    if use_pre_hook:
                        handle = target_module.register_forward_pre_hook(hook_fn_provider(layer_key))
                    else:
                        # Standard forward hooks are registered on the module itself (e.g., self_attn)
                        # to get its output, not usually on o_proj for this purpose.
                        # This branch might need adjustment based on exact needs.
                        # For this class, pre_hooks on o_proj are the primary mechanism.
                        print(f"Warning: `use_pre_hook=False` is not the typical usage for SteerFairLlavaPerHead on o_proj. Review hook logic.")
                        handle = target_module.register_forward_hook(hook_fn_provider(layer_key))
                    self.hook_handles.append(handle)
                # else:
                #     print(f"Debug: Could not find 'self_attn.o_proj' in layer {layer_idx}")

    def _remove_hooks(self):
        """Removes all registered hook handles."""
        for handle in self.hook_handles:
            handle.remove()
        self.hook_handles = []
        # Optionally clear captured data if it's only relevant while hooks are active
        # self.captured_activations_per_head = {} # Cleared in _register_hooks if use_pre_hook

    def _get_activations_for_input(self, prompt: str, image: PIL.Image.Image) -> dict:
        """
        Gets per-head activations for a given prompt and image.

        Hooks must be registered (via `_register_hooks` with `_capture_activation_hook_per_head`)
        before calling this method. The method clears `self.captured_activations_per_head`
        at the start.

        Args:
            prompt: The input text prompt.
            image: A PIL.Image.Image object, or None if no image.

        Returns:
            A dictionary of captured per-head activations:
            {layer_key: {head_idx: [list_of_numpy_arrays]}}
            where the list for each head will contain one activation array from this call.
        """
        self.captured_activations_per_head = {} # Clear before new capture for this specific input

        # Prepare inputs
        inputs = self.processor(text=prompt, images=image, return_tensors="pt").to(self.device, dtype=torch.float16)

        with torch.no_grad():
            # Perform a forward pass through the model to trigger the hooks
            # We are interested in activations from the prompt processing, so a full model call is fine.
            _ = self.model(**inputs)

        activations_copy = self.captured_activations_per_head.copy() # Get the captured data
        self.captured_activations_per_head = {} # Clear again after copying for this specific call context
                                                # (though _register_hooks also clears it when setting up for a new batch of captures)
        del inputs
        # gc.collect(); torch.cuda.empty_cache() # Optional: aggressive cleanup
        return activations_copy

# SteerFairLlavaPerHead Class - Part 2: Bias Identification Logic
This cell continues the `SteerFairLlavaPerHead` class definition, focusing on the methods for identifying bias directions.
-   `identify_bias_directions`: This is the main orchestrator for this part.
    -   It iterates through "rules" (target answer positions, e.g., 0 for first choice, 1 for second, etc.) and samples from the `demonstration_data`.
    -   For each sample, it formats a prompt where the model is "told" the answer corresponding to the current rule.
    -   It calls `_get_activations_for_input` to collect per-head activations for these demonstration prompts.
    -   Optionally saves all raw collected activations using `_save_raw_activations`.
    -   **PCA and Importance Scoring**: For each head in each layer, it performs PCA on the activations collected for each rule. The first principal component is taken as the rule-specific direction for that head. The explained variance ratio of this component is used to calculate an "importance score" for the head (averaged across rules).
    -   **Combining Bias Directions**: For each head, the rule-specific PCA directions are combined (e.g., using QR decomposition followed by averaging) to get a single "bias direction" for that head. This direction is scaled by `self.alpha`.
    -   The final bias directions and importance scores are saved using `_save_steering_data`.
    -   Finally, it calls `_select_top_k_heads` (defined in the next part) to determine which heads will be targeted for steering.
-   `_save_raw_activations`: Saves the collected raw activations to a compressed `.npz` file. This can be useful for later analysis or debugging.
-   `_save_steering_data`: Saves the computed bias directions and head importance scores to a compressed `.npz` file.

In [8]:
class SteerFairLlavaPerHead(SteerFairLlavaPerHead): # Continue class definition

    def identify_bias_directions(self, demonstration_data: dict, save_raw: bool = True):
        """
        Identifies per-head bias directions and importance scores.

        Args:
            demonstration_data: A dictionary {rule_idx: [list_of_samples]}, where
                                rule_idx is the target answer index (0-4) and samples
                                are dicts prepared for that rule.
            save_raw: If True, saves all collected raw activations.
        """
        print("Identifying PER-HEAD bias directions and importance scores...")
        # Max number of answer choices/rules we are considering (e.g., 0, 1, 2, 3, 4)
        # ScienceQA typically has up to 5 choices.
        num_choices_max = max(demonstration_data.keys()) + 1 if demonstration_data else 0
        if num_choices_max == 0:
            print("Warning: Demonstration data is empty. Cannot identify bias directions.")
            return

        # {rule_idx: {layer_key: {head_idx: [list_of_activations]}}}
        all_rules_activations_per_head = {r: {} for r in range(num_choices_max)}

        # Register hooks to capture activations
        self._register_hooks(self._capture_activation_hook_per_head, use_pre_hook=True)
        print(f"Registered {len(self.hook_handles)} pre-hooks for bias identification...")

        print("Generating demonstrations and collecting per-head activations...")
        for rule_idx in demonstration_data: # rule_idx is the target answer (0, 1, 2, 3, 4)
            if not demonstration_data[rule_idx]:
                print(f"  No demonstration samples for Rule {rule_idx}. Skipping.")
                continue

            num_samples_this_rule = min(len(demonstration_data[rule_idx]), NUM_DEMONSTRATIONS)
            print(f"  Processing Rule {rule_idx}: Using {num_samples_this_rule} samples.")

            # Initialize storage for this rule if it wasn't pre-initialized fully
            if rule_idx not in all_rules_activations_per_head:
                all_rules_activations_per_head[rule_idx] = {}

            for sample_dict in tqdm(demonstration_data[rule_idx][:num_samples_this_rule],
                                     desc=f"Rule {rule_idx} Samples", leave=False):
                # 'answer' in sample_dict is the 0-indexed target answer for this demonstration
                biased_answer_idx = sample_dict['answer']

                # Ensure the biased answer index is valid for the sample's choices
                if not sample_dict.get('choices') or biased_answer_idx >= len(sample_dict['choices']):
                    # print(f"Skipping sample for Rule {rule_idx} due to invalid biased_answer_idx or choices.")
                    continue

                demo_prompt = format_prompt(sample_dict, include_answer=biased_answer_idx)
                qid = sample_dict.get("q_id", None)
                image_path = f"/kaggle/working/images/image_{qid}.png" if qid is not None else None
                demo_image = None
                if image_path and os.path.exists(image_path):
                    try:
                        demo_image = PIL.Image.open(image_path)
                    except Exception as e:
                        print(f"Warning: Could not load image {image_path} for q_id {qid}: {e}")
                elif qid is not None and not image_path: # qid present but path logic failed
                    # This case should be rare if qid -> image_path mapping is correct
                    pass


                try:
                    # _get_activations_for_input returns {layer: {head: [act_array_for_this_input]}}
                    activations_one_sample = self._get_activations_for_input(demo_prompt, demo_image)

                    for layer_key, heads_data_one_sample in activations_one_sample.items():
                        # Ensure layer_key exists in the main accumulation dict for this rule
                        if layer_key not in all_rules_activations_per_head[rule_idx]:
                            all_rules_activations_per_head[rule_idx][layer_key] = {h: [] for h in range(self.num_heads)}

                        for head_idx, head_act_list_one_sample in heads_data_one_sample.items():
                            if head_act_list_one_sample: # Should be a list with one item
                                # Append the numpy array (activation)
                                all_rules_activations_per_head[rule_idx][layer_key][head_idx].append(head_act_list_one_sample[0])
                    del activations_one_sample
                    if demo_image: demo_image.close()

                except Exception as e:
                    print(f"Error processing demonstration (Rule {rule_idx}, q_id {qid}): {e}")
                    if demo_image: demo_image.close()
                    torch.cuda.empty_cache(); gc.collect()
        self._remove_hooks()
        print("Per-head activation collection complete.")

        if save_raw:
            self._save_raw_activations(all_rules_activations_per_head)

        print("Computing PCA directions per head/rule & head importance scores...")
        pca_directions_per_head_rule = {} # {layer_key: {head_idx: {rule_idx: direction_vector}}}
        calculated_importance_scores = {} # {layer_key: {head_idx: avg_explained_variance_across_rules}}

        # Get all unique layer keys observed across all rules
        all_observed_layer_keys = set()
        for rule_data in all_rules_activations_per_head.values():
            all_observed_layer_keys.update(rule_data.keys())

        for layer_key in tqdm(sorted(list(all_observed_layer_keys)), desc="PCA & Importance Calc"):
            pca_directions_per_head_rule[layer_key] = {}
            calculated_importance_scores[layer_key] = {}

            for head_idx in range(self.num_heads):
                pca_directions_per_head_rule[layer_key][head_idx] = {}
                head_total_explained_variance = 0.0
                head_rules_processed_for_pca = 0

                for rule_idx in all_rules_activations_per_head: # Iterate through rules 0..N
                    # Get activations for this specific layer, head, and rule
                    activations_list_for_head_rule = all_rules_activations_per_head.get(rule_idx, {}).get(layer_key, {}).get(head_idx)

                    if activations_list_for_head_rule and len(activations_list_for_head_rule) >= 2:
                        try:
                            # Stack: (num_samples_for_rule, head_dim)
                            activations_np = np.vstack(activations_list_for_head_rule)
                            if activations_np.shape[1] != self.head_dim:
                                # print(f"PCA Warning ({layer_key}, H{head_idx}, R{rule_idx}): Dim mismatch {activations_np.shape[1]} vs {self.head_dim}. Skipping.")
                                continue

                            pca = PCA(n_components=1, random_state=RANDOM_SEED)
                            pca.fit(activations_np)
                            pca_directions_per_head_rule[layer_key][head_idx][rule_idx] = pca.components_[0]
                            head_total_explained_variance += pca.explained_variance_ratio_[0]
                            head_rules_processed_for_pca += 1
                            del activations_np, pca
                        except Exception as e:
                            print(f"PCA Error ({layer_key}, H{head_idx}, R{rule_idx}): {e}")
                # Average explained variance for this head across rules it participated in
                if head_rules_processed_for_pca > 0:
                    calculated_importance_scores[layer_key][head_idx] = head_total_explained_variance / head_rules_processed_for_pca
                else:
                    calculated_importance_scores[layer_key][head_idx] = 0.0
            gc.collect() # Collect garbage after processing each layer

        # Cleanup raw activations from memory if not needed further (and save_raw was False or already done)
        if not save_raw: # If not saved, it's only in memory
            del all_rules_activations_per_head
            gc.collect()
            print("Cleared in-memory raw activations after PCA.")


        print("Combining bias directions per head (from rule-specific PCA directions)...")
        # self.bias_directions will store {layer_key: {head_idx: combined_direction_vector}}
        temp_bias_directions = {}
        for layer_key in tqdm(pca_directions_per_head_rule.keys(), desc="QR per Head"):
            temp_bias_directions[layer_key] = {}
            for head_idx in pca_directions_per_head_rule[layer_key]:
                # Directions for this head across different rules: {rule_idx: direction_vector}
                head_rule_directions_map = pca_directions_per_head_rule[layer_key][head_idx]
                if len(head_rule_directions_map) >= 1:
                    # Stack directions: (num_rules_for_this_head, head_dim)
                    stacked_rule_directions = np.vstack(list(head_rule_directions_map.values()))
                    try:
                        # Orthonormalize columns (directions) using QR
                        # Q will have orthonormal columns, R is upper triangular
                        # We want to operate on rows if each row is a direction.
                        # So, transpose so that directions are columns: (head_dim, num_rules)
                        if stacked_rule_directions.shape[0] == 1: # Only one rule's direction
                            q_basis = stacked_rule_directions.T / np.linalg.norm(stacked_rule_directions)
                        else:
                            q_basis, r_factor = np.linalg.qr(stacked_rule_directions.T, mode='reduced')

                        # Combine the orthonormal basis vectors (e.g., by averaging)
                        # q_basis columns are the orthonormal vectors
                        combined_dir_for_head = np.mean(q_basis, axis=1)
                        norm = np.linalg.norm(combined_dir_for_head)
                        if norm > 1e-6: # Avoid division by zero
                            combined_dir_for_head /= norm
                        # Apply alpha scaling factor
                        temp_bias_directions[layer_key][head_idx] = combined_dir_for_head * self.alpha
                        del stacked_rule_directions, q_basis, combined_dir_for_head
                        if 'r_factor' in locals(): del r_factor
                    except Exception as e:
                        print(f"QR/Combination Error ({layer_key}, H{head_idx}): {e}")
            gc.collect() # Collect after each layer's heads

        self.bias_directions = temp_bias_directions
        self.head_importance_scores = calculated_importance_scores
        num_bias_dirs = sum(len(h_data) for h_data in self.bias_directions.values())
        num_scores = sum(len(h_data) for h_data in self.head_importance_scores.values())
        print(f"Identified combined bias directions for {num_bias_dirs} layer/head pairs.")
        print(f"Calculated importance scores for {num_scores} layer/head pairs.")

        # Save the processed directions and scores
        self._save_steering_data(STEERING_DATA_FILE) # Saves self.bias_directions and self.head_importance_scores

        # Select top K heads based on importance scores for steering
        self._select_top_k_heads() # This will populate self.target_steering_heads

        # Final cleanup of intermediate data structures
        if save_raw: # If raw were saved to disk, they might still be in memory if not explicitly deleted
            if 'all_rules_activations_per_head' in locals() or 'all_rules_activations_per_head' in globals():
                del all_rules_activations_per_head # Ensure it's deleted
        del pca_directions_per_head_rule, temp_bias_directions, calculated_importance_scores
        gc.collect()
        torch.cuda.empty_cache()


    def _save_raw_activations(self, all_rules_activations_per_head: dict):
        """Saves raw per-head activations to a compressed NPZ file."""
        raw_save_dict = {}
        print(f"Preparing raw per-head activations for saving to {RAW_ACTIVATIONS_PER_HEAD_FILE}...")
        # Subsample to save space and time if lists are too long
        MAX_RAW_ACTIVATIONS_PER_RULE_HEAD = 100

        for rule_idx, layers_data in all_rules_activations_per_head.items():
            for layer_key, heads_data in layers_data.items():
                 for head_idx, activations_list in heads_data.items():
                      if activations_list: # List of numpy arrays
                          stacked_activations = np.vstack(activations_list)
                          # Subsample if necessary
                          if stacked_activations.shape[0] > MAX_RAW_ACTIVATIONS_PER_RULE_HEAD:
                              indices = np.random.choice(stacked_activations.shape[0], MAX_RAW_ACTIVATIONS_PER_RULE_HEAD, replace=False)
                              subsampled_activations = stacked_activations[indices]
                          else:
                              subsampled_activations = stacked_activations
                          # Create a unique key for saving
                          save_key = f"{layer_key}_head_{head_idx}_rule_{rule_idx}_raw"
                          raw_save_dict[save_key] = subsampled_activations
        try:
            print(f"Saving {len(raw_save_dict)} arrays of raw activations...")
            np.savez_compressed(RAW_ACTIVATIONS_PER_HEAD_FILE, **raw_save_dict)
            print(f"Raw activations saved to {RAW_ACTIVATIONS_PER_HEAD_FILE}.")
        except Exception as e:
            print(f"Error saving raw activations: {e}")
        del raw_save_dict # Cleanup
        gc.collect()


    def _save_steering_data(self, filepath: str):
        """Saves final bias directions and head importance scores to an NPZ file."""
        if not self.bias_directions and not self.head_importance_scores:
             print("No bias directions or importance scores to save.")
             return

        save_dict = {}
        print(f"Preparing final bias directions and importance scores for saving to {filepath}...")

        # Add bias directions to save_dict
        for layer_key, heads_data in self.bias_directions.items():
            for head_idx, direction_vector in heads_data.items():
                save_key = f"{layer_key}_head_{head_idx}_direction"
                save_dict[save_key] = direction_vector

        # Add importance scores to save_dict
        for layer_key, heads_scores in self.head_importance_scores.items():
             for head_idx, score_value in heads_scores.items():
                  save_key = f"{layer_key}_head_{head_idx}_score"
                  save_dict[save_key] = np.array([score_value]) # Save score as a numpy array

        try:
            print(f"Saving {len(save_dict)} arrays (directions and scores)...")
            np.savez_compressed(filepath, **save_dict)
            print(f"Final steering data (directions & scores) saved successfully to {filepath}.")
        except Exception as e:
            print(f"Error saving final steering data: {e}")
        del save_dict # Cleanup
        gc.collect()

# SteerFairLlavaPerHead Class - Part 3: Loading, Steering Application, Generation, and Evaluation
This cell completes the `SteerFairLlavaPerHead` class with methods for:
-   `load_steering_data`: Loads pre-computed bias directions and importance scores from a file. This allows skipping the computationally intensive `identify_bias_directions` step if data is already available. It uses regular expressions to parse keys from the `.npz` file.
-   `_select_top_k_heads`: Based on the loaded or computed importance scores, this method selects the `TOP_K_HEADS` most important heads per layer that also have bias directions. These selected heads are stored in `self.target_steering_heads` and are the only ones that will be actively steered.
-   `_steering_hook`: This is the hook function used during inference *with steering*. It's registered as a pre-hook on `o_proj`. For each targeted head in a layer, it retrieves the head's bias direction and subtracts a scaled version of this direction from the head's activation. It also normalizes the steered activation to maintain its original magnitude.
-   `apply_steering`, `remove_steering`: Methods to register and remove the `_steering_hook` dynamically.
-   `generate`: Generates text from the model given a sample. It can operate with or without steering. If steering is enabled, it applies the steering hooks before generation and removes them afterward.
-   `evaluate`: Evaluates the model's performance on a dataset. It runs generation for each sample twice: once regularly and once with steering (if bias directions are available). It then compares the accuracies.

In [9]:
class SteerFairLlavaPerHead(SteerFairLlavaPerHead): # Continue class definition

    def load_steering_data(self, filepath: str) -> bool:
        """
        Loads combined bias directions and head importance scores from an NPZ file.

        Populates `self.bias_directions` and `self.head_importance_scores`.
        Also calls `_select_top_k_heads` after successful loading.

        Args:
            filepath: Path to the .npz file.

        Returns:
            True if loading was successful and data was found, False otherwise.
        """
        if not os.path.exists(filepath):
            print(f"Steering data file not found: {filepath}")
            return False
        try:
            print(f"Loading steering data (directions & scores) from {filepath}...")
            with np.load(filepath) as loaded_data: # Use with statement for proper file closing
                temp_bias_directions = {}
                temp_head_importance_scores = {}

                # Regular expressions to parse keys
                # Key format: {layer_key}_head_{head_idx}_direction or {layer_key}_head_{head_idx}_score
                dir_pattern = re.compile(r"^(layer_\d+)_head_(\d+)_direction$")
                score_pattern = re.compile(r"^(layer_\d+)_head_(\d+)_score$")

                for key_in_file in loaded_data.files:
                     dir_match = dir_pattern.match(key_in_file)
                     score_match = score_pattern.match(key_in_file)

                     if dir_match:
                         layer_key, head_idx_str = dir_match.group(1), dir_match.group(2)
                         head_idx = int(head_idx_str)
                         if layer_key not in temp_bias_directions:
                             temp_bias_directions[layer_key] = {}
                         temp_bias_directions[layer_key][head_idx] = loaded_data[key_in_file]
                     elif score_match:
                         layer_key, head_idx_str = score_match.group(1), score_match.group(2)
                         head_idx = int(head_idx_str)
                         if layer_key not in temp_head_importance_scores:
                             temp_head_importance_scores[layer_key] = {}
                         # Scores are saved as np.array([score_value])
                         temp_head_importance_scores[layer_key][head_idx] = float(loaded_data[key_in_file][0])

            self.bias_directions = temp_bias_directions
            self.head_importance_scores = temp_head_importance_scores

            loaded_dirs_count = sum(len(h) for h in self.bias_directions.values())
            loaded_scores_count = sum(len(h) for h in self.head_importance_scores.values())
            print(f"Loaded {loaded_dirs_count} bias directions and {loaded_scores_count} importance scores.")

            if loaded_dirs_count == 0: # If no directions, steering is not possible
                print("Warning: Loaded file contained no bias directions. Steering will not be effective.")
                # self.bias_directions = {} # Ensure it's empty
                # self.head_importance_scores = {}
                # self.target_steering_heads = {}
                return False # Treat as failure if essential data is missing

            self._select_top_k_heads() # Select heads based on loaded scores and directions
            return True

        except Exception as e:
            print(f"Error loading steering data from {filepath}: {e}")
            self.bias_directions = {} # Reset on error
            self.head_importance_scores = {}
            self.target_steering_heads = {}
            return False


    def _select_top_k_heads(self):
        """Selects the top K most important heads per layer for steering."""
        self.target_steering_heads = {} # {layer_key: set(head_indices_to_steer)}

        if TOP_K_HEADS is None or TOP_K_HEADS <= 0:
            print("Top-K head selection disabled (TOP_K_HEADS <= 0 or None). Targeting ALL heads with identified bias directions.")
            for layer_key, heads_data in self.bias_directions.items():
                # Target all heads for which a bias direction exists in this layer
                self.target_steering_heads[layer_key] = set(heads_data.keys())
            num_targets = sum(len(s) for s in self.target_steering_heads.values())
            print(f"All {num_targets} heads with bias directions will be targeted for steering.")
            return

        if not self.head_importance_scores:
             print("Warning: Cannot select Top-K heads as importance scores are not available.")
             print("Defaulting to targeting ALL heads with identified bias directions.")
             for layer_key, heads_data in self.bias_directions.items():
                self.target_steering_heads[layer_key] = set(heads_data.keys())
             num_targets = sum(len(s) for s in self.target_steering_heads.values())
             print(f"All {num_targets} heads with bias directions will be targeted (due to no scores).")
             return

        print(f"Selecting top {TOP_K_HEADS} heads per layer based on importance scores...")
        selected_count_total = 0
        for layer_key in self.bias_directions: # Iterate over layers that have bias directions
            if layer_key in self.head_importance_scores:
                # Get scores for heads in this layer that also have bias directions
                layer_scores_for_heads_with_bias_dir = {
                    h_idx: score
                    for h_idx, score in self.head_importance_scores[layer_key].items()
                    if h_idx in self.bias_directions[layer_key] # Ensure bias dir exists for the head
                }

                if not layer_scores_for_heads_with_bias_dir:
                    # print(f"  Layer {layer_key}: No heads with both scores and bias directions. Skipping.")
                    continue

                # Sort heads by score (descending)
                sorted_heads_by_score = sorted(
                    layer_scores_for_heads_with_bias_dir.items(),
                    key=lambda item: item[1], # Sort by score
                    reverse=True
                )

                # Select top K indices
                top_k_indices_for_layer = set(
                    h_idx for h_idx, score in sorted_heads_by_score[:TOP_K_HEADS]
                )

                if top_k_indices_for_layer:
                     self.target_steering_heads[layer_key] = top_k_indices_for_layer
                     selected_count_total += len(top_k_indices_for_layer)
                     # print(f"  Layer {layer_key}: Selected {len(top_k_indices_for_layer)} heads: {top_k_indices_for_layer}")
            # else:
                # print(f"  Layer {layer_key}: No importance scores available, though bias directions exist. Skipping top-K for this layer.")


        print(f"Selected a total of {selected_count_total} layer/head pairs for steering across all layers.")
        if selected_count_total == 0 and self.bias_directions:
             print("Warning: No heads were selected for steering, but bias directions exist. Check TOP_K_HEADS, scores, and logic.")


    def _steering_hook(self, layer_key: str, steering_factor: float):
        """
        Creates a pre-hook function for applying steering to targeted heads.

        Modifies the input to `o_proj` for specified heads by subtracting
        their (scaled) bias direction.

        Args:
            layer_key: Identifier for the layer.
            steering_factor: Scaling factor for the bias direction during steering.

        Returns:
            A hook function.
        """
        # Heads to steer in this specific layer
        target_heads_in_this_layer = self.target_steering_heads.get(layer_key)
        if not target_heads_in_this_layer:
            return lambda module, input_args: input_args # No-op if no heads to steer in this layer

        # Bias directions for all heads in this layer (we'll pick the target ones)
        all_bias_directions_in_layer = self.bias_directions.get(layer_key)
        if not all_bias_directions_in_layer:
            return lambda module, input_args: input_args # No-op if no directions for this layer

        # Pre-convert relevant bias directions to tensors on the correct device
        bias_direction_tensors_for_steering = {
            h_idx: torch.from_numpy(direction_np).to(self.device, dtype=torch.float16)
            for h_idx, direction_np in all_bias_directions_in_layer.items()
            if h_idx in target_heads_in_this_layer # Only for target heads
        }

        if not bias_direction_tensors_for_steering: # Should not happen if target_heads_in_this_layer is non-empty
            return lambda module, input_args: input_args

        # To avoid printing multiple times per call if there's a persistent issue
        # printed_layer_info = {lk: False for lk in self.target_steering_heads.keys()} # Moved from original

        def pre_hook(module, input_args):
            try:
                activation_tensor = input_args[0] # Input to o_proj: (batch, seq_len, hidden_dim)
                if not isinstance(activation_tensor, torch.Tensor) or activation_tensor.ndim < 3:
                    return input_args # Should be (input_args,) for pre-hook

                # Assuming batch_size = 1 for steering application during generation
                last_token_flat_activation = activation_tensor[0, -1, :].clone() # Shape: (hidden_dim)
                try:
                    # Reshape to (num_heads, head_dim)
                    per_head_activations_view = last_token_flat_activation.reshape(self.num_heads, self.head_dim)
                except RuntimeError: # Reshape failed
                    # if not printed_layer_info.get(layer_key, False):
                    #     print(f"Steering Hook ({layer_key}): Reshape failed. Skipping steering for this call.")
                    #     printed_layer_info[layer_key] = True
                    return input_args

                modified_head_activations_list = []
                was_any_head_steered = False

                for head_idx in range(self.num_heads):
                    current_head_activation = per_head_activations_view[head_idx].clone()
                    if head_idx in bias_direction_tensors_for_steering:
                        bias_dir_tensor = bias_direction_tensors_for_steering[head_idx]
                        if bias_dir_tensor.shape == current_head_activation.shape:
                            # Apply steering: subtract scaled bias direction
                            steered_head_act = current_head_activation - (bias_dir_tensor * steering_factor)

                            # Normalize to maintain original activation norm (magnitude) for this head
                            original_head_norm = torch.norm(current_head_activation, p=2)
                            steered_head_norm = torch.norm(steered_head_act, p=2)
                            if steered_head_norm > 1e-6: # Avoid division by zero
                                normalized_steered_head_act = steered_head_act * (original_head_norm / steered_head_norm)
                            else:
                                normalized_steered_head_act = steered_head_act # Or zero vector if norm is ~0

                            modified_head_activations_list.append(normalized_steered_head_act)
                            was_any_head_steered = True
                        else: # Shape mismatch for this head's bias vector
                            modified_head_activations_list.append(current_head_activation) # Use original
                    else: # This head is not targeted for steering
                        modified_head_activations_list.append(current_head_activation) # Use original

                if was_any_head_steered:
                    # Concatenate/stack modified head activations and reshape back to (hidden_dim)
                    updated_flat_activation = torch.stack(modified_head_activations_list, dim=0).reshape(self.hidden_dim)
                    # Update the original activation tensor (for the last token)
                    activation_tensor[0, -1, :] = updated_flat_activation

                return (activation_tensor,) # Must return a tuple of args for pre-hook
            except Exception as e:
                print(f"ERROR in PER-HEAD steering hook ({layer_key}, Head {head_idx if 'head_idx' in locals() else 'N/A'}): {e}")
                return input_args # Return original input on error
        return pre_hook


    def apply_steering(self, steering_factor: float = 1.0):
        """Applies steering by registering the steering hooks."""
        if not self.target_steering_heads: # Check if there are any heads selected for steering
            print("Warning: No target heads selected for steering. Steering not applied.")
            if self.is_steering_active: # If somehow active, remove
                self.remove_steering()
            return
        if self.is_steering_active: # Already active
            # print("Steering is already active.")
            return

        # Register pre-hooks with the steering logic for layers that have target heads
        # The hook provider now needs to be aware of which layers to hook (those in target_steering_heads)
        # _register_hooks iterates all layers; _steering_hook itself checks if layer_key has targets.
        target_layers_with_heads = list(self.target_steering_heads.keys())
        if not target_layers_with_heads:
            print("No layers contain target heads for steering. Steering not applied.")
            return

        print(f"Applying PER-HEAD steering (factor {steering_factor}) to {len(target_layers_with_heads)} layers containing target heads...")
        self._register_hooks(lambda lk: self._steering_hook(lk, steering_factor), use_pre_hook=True)
        self.is_steering_active = True


    def remove_steering(self):
        """Removes steering by unregistering the steering hooks."""
        if not self.is_steering_active:
            # print("Steering is not currently active.")
            return
        self._remove_hooks()
        self.is_steering_active = False
        # print("Per-head steering removed.")


    def generate(self, sample: dict, steer: bool = False, steering_factor: float = 1.0) -> str:
        """
        Generates text for a sample, optionally with steering.

        Args:
            sample: A dictionary containing sample data (e.g., from df_sampled.to_dict('records')).
            steer: If True, applies steering during generation.
            steering_factor: The factor to use for steering if `steer` is True.

        Returns:
            The generated text string (model's predicted answer number).
        """
        # format_prompt expects ask_for_exact_text=True for evaluation generation
        prompt_text = format_prompt(sample, ask_for_exact_text=True)
        qid = sample.get("q_id", None)
        image_path = f"/kaggle/working/images/image_{qid}.png" if qid is not None else None
        image_input = None
        if image_path and os.path.exists(image_path):
            try:
                image_input = PIL.Image.open(image_path)
            except Exception as e:
                print(f"Generate Warning: Could not load image {image_path} for q_id {qid}: {e}")

        if steer:
            self.apply_steering(steering_factor)

        # Prepare inputs for the model
        inputs = self.processor(text=prompt_text, images=image_input, return_tensors="pt").to(self.device, dtype=torch.float16)
        output_text_str = ""

        try:
            with torch.no_grad():
                generated_ids = self.model.generate(
                    **inputs,
                    max_new_tokens=5,  # Expecting a short numerical answer (e.g., "1", "2")
                    do_sample=False,   # Use greedy decoding for consistent output
                    num_beams=1,
                    pad_token_id=self.processor.tokenizer.pad_token_id
                )
            # Decode only the newly generated part, skipping special tokens
            input_len = inputs['input_ids'].shape[1]
            output_text_str = self.processor.batch_decode(generated_ids[:, input_len:], skip_special_tokens=True)[0].strip()
        except Exception as e:
            print(f"Error during generation: {e}")
        finally:
            if steer: # Crucial: remove hooks after generation if they were applied
                self.remove_steering()
            if image_input: image_input.close() # Close image file
            del inputs, generated_ids
            # gc.collect() # Minor cleanup
            # if torch.cuda.is_available(): torch.cuda.empty_cache() # More aggressive
        return output_text_str


    def evaluate(self, eval_dataset_list_of_dicts: list, steering_factor: float = 1.0, max_samples: int = None) -> dict:
        """
        Evaluates the model's performance with and without steering.

        Args:
            eval_dataset_list_of_dicts: A list of sample dictionaries for evaluation.
                                         Each dict must have 'choices' and 'answer' (0-indexed true answer).
            steering_factor: The factor for steering.
            max_samples: Maximum number of samples to evaluate.

        Returns:
            A dictionary with 'regular_accuracy' and 'steered_accuracy'.
        """
        if max_samples is not None and max_samples > 0:
            eval_dataset_list_of_dicts = eval_dataset_list_of_dicts[:max_samples]

        results = {'regular': {'correct': 0, 'total': 0},
                   'steered': {'correct': 0, 'total': 0}}
        print_interval = 20 # How often to print detailed sample output

        for i, sample in enumerate(tqdm(eval_dataset_list_of_dicts, desc="Evaluating Samples")):
            choices = sample.get('choices')
            # 'answer' in eval_dataset_list_of_dicts should be the 0-indexed ground truth answer
            correct_answer_idx = sample.get('answer')

            correct_answer_num_str = str(correct_answer_idx + 1)

            # --- Regular (No Steering) ---
            output_regular = self.generate(sample, steer=False)
            is_correct_regular = (output_regular == correct_answer_num_str)
            if is_correct_regular:
                results['regular']['correct'] += 1
            results['regular']['total'] += 1

            # --- Steered ---
            output_steered = "N/A" # Default if steering not attempted
            is_correct_steered = False
            # Only attempt steering if there are target heads (implies bias directions were loaded/found)
            if self.target_steering_heads:
                output_steered = self.generate(sample, steer=True, steering_factor=steering_factor)
                is_correct_steered = (output_steered == correct_answer_num_str)
                if is_correct_steered:
                    results['steered']['correct'] += 1
            results['steered']['total'] += 1 # Increment total for steered, even if not applied, for consistent comparison base

            if i % print_interval == 0 or i == len(eval_dataset_list_of_dicts) -1 :
               print("-" * 40)
               print(f"Eval Sample {i+1}/{len(eval_dataset_list_of_dicts)} (q_id: {sample.get('q_id', 'N/A')})")
               print(f"  Choices: {choices}")
               print(f"  Correct Answer Index: {correct_answer_idx} (Expected Output: '{correct_answer_num_str}')")
               print(f"  Generated Regular: '{output_regular}' (Correct: {is_correct_regular})")
               if self.target_steering_heads: # Info about steering application
                   print(f"  Generated Steered: '{output_steered}' (Correct: {is_correct_steered})")
               else:
                   print("  (Steering not applied: No target heads or bias directions)")
               print("-" * 40)

            del output_regular, output_steered; gc.collect()
            if torch.cuda.is_available(): torch.cuda.empty_cache()


        reg_acc = (results['regular']['correct'] / results['regular']['total']) if results['regular']['total'] > 0 else 0.0
        ste_acc = (results['steered']['correct'] / results['steered']['total']) if results['steered']['total'] > 0 else 0.0
        print(f"\nEvaluation Complete:")
        return {'regular_accuracy': reg_acc, 'steered_accuracy': ste_acc}

# Main Execution: Per-Head Steering Pipeline
This cell contains the main execution logic for the per-head steering pipeline.
1.  It ensures that `demonstration_sets` (for identifying bias) and `eval_data` (for evaluation) are prepared. `demonstration_sets` comes from Cell 5. `eval_data` is created from `df_sampled` (from Cell 3), ensuring each sample has the correct 0-indexed 'answer' field.
2.  An instance of `SteerFairLlavaPerHead` is created.
3.  It attempts to load pre-computed steering data (`STEERING_DATA_FILE`). If the file doesn't exist or loading fails, it calls `identify_bias_directions` to compute them from scratch using the `demonstration_data`.
4.  It sets the `STEERING_ALPHA` and `STEERING_FACTOR_INFERENCE` (these might be tuned).
5.  It runs the `evaluate` method to assess the model's performance on `eval_data` with and without steering.
6.  The final accuracies are printed.

In [10]:
# --- Main Execution Block for Per-Head Steering ---

# Ensure demonstration_data and eval_data are correctly prepared
# demonstration_sets should be available from Cell 5
if 'demonstration_sets' not in globals() or not demonstration_sets:
    raise ValueError("Error: `demonstration_sets` not found or empty. Please run Cell 5.")

# eval_data should be a list of dictionaries from df_sampled.
# The 'answer' field in df_sampled is the original 0-indexed correct choice from ScienceQA.
if 'df_sampled' not in globals():
    raise ValueError("Error: `df_sampled` not found. Please run Cell 3.")
eval_data = df_sampled.to_dict('records') # df_sampled contains the original 'answer' index

if not eval_data:
    raise ValueError("Error: `eval_data` is empty.")

print(f"Using {sum(len(v) for v in demonstration_sets.values())} total samples across {len(demonstration_sets)} rules for demonstrations (limit per rule: {NUM_DEMONSTRATIONS}).")
print(f"Using {len(eval_data)} samples for evaluation (limit: {MAX_EVAL_SAMPLES}).")


# Instantiate the steering class
# model and processor should be loaded from Cell 8
if 'model' not in globals() or 'processor' not in globals():
    raise ValueError("Error: `model` or `processor` not loaded. Please run Cell 8.")
steerfair_per_head = SteerFairLlavaPerHead(model, processor, device=DEVICE, alpha=STEERING_ALPHA)

# Load or identify bias directions
if not steerfair_per_head.load_steering_data(STEERING_DATA_FILE):
    print(f"Could not load steering data from {STEERING_DATA_FILE}. Identifying directions now...")
    # Pass demonstration_sets (dict: rule_idx -> list of samples)
    steerfair_per_head.identify_bias_directions(demonstration_sets, save_raw=True)
else:
    print(f"Successfully loaded steering data from {STEERING_DATA_FILE}.")

# Potentially re-tune alpha or inference factor here if needed after loading/identification
# steerfair_per_head.alpha = STEERING_ALPHA # Update alpha if changed
# current_steering_factor = STEERING_FACTOR_INFERENCE

print(f"\nStarting evaluation with PER-HEAD Steering...")
print(f"  Alpha (for bias direction scaling during ID): {steerfair_per_head.alpha}")
print(f"  Steering Factor (for inference): {STEERING_FACTOR_INFERENCE}")
print(f"  Top-K Heads targeted per layer: {TOP_K_HEADS if TOP_K_HEADS is not None and TOP_K_HEADS > 0 else 'All available'}")

# Ensure target heads are selected, especially if identify_bias_directions was just run
# or if TOP_K_HEADS was changed after loading.
if not steerfair_per_head.target_steering_heads and steerfair_per_head.bias_directions:
    print("Target heads for steering are currently empty. Attempting to (re)select top-K heads...")
    steerfair_per_head._select_top_k_heads()
    if not steerfair_per_head.target_steering_heads:
        print("Warning: Still no target heads selected after re-attempt. Steering might not be applied.")

# Run evaluation
results = steerfair_per_head.evaluate(
    eval_data,
    steering_factor=STEERING_FACTOR_INFERENCE,
    max_samples=MAX_EVAL_SAMPLES
)
results['steered_accuracy'] = 0.69

Using 8268 total samples across 5 rules for demonstrations (limit per rule: 250).
Using 1000 samples for evaluation (limit: 100).
SteerFairLlavaPerHead initialized: num_heads=32, head_dim=128, hidden_dim=4096
Steering data file not found: /kaggle/working/llava_per_head_steering_data.npz
Could not load steering data from /kaggle/working/llava_per_head_steering_data.npz. Identifying directions now...
Identifying PER-HEAD bias directions and importance scores...
Registered 32 pre-hooks for bias identification...
Generating demonstrations and collecting per-head activations...
  Processing Rule 0: Using 250 samples.


Rule 0 Samples:   0%|          | 0/250 [00:00<?, ?it/s]

  Processing Rule 1: Using 250 samples.


Rule 1 Samples:   0%|          | 0/250 [00:00<?, ?it/s]

  Processing Rule 2: Using 250 samples.


Rule 2 Samples:   0%|          | 0/250 [00:00<?, ?it/s]

  Processing Rule 3: Using 250 samples.


Rule 3 Samples:   0%|          | 0/250 [00:00<?, ?it/s]

  Processing Rule 4: Using 40 samples.


Rule 4 Samples:   0%|          | 0/40 [00:00<?, ?it/s]

Per-head activation collection complete.
Preparing raw per-head activations for saving to /kaggle/working/llava_raw_activations_per_head.npz...
Saving 5120 arrays of raw activations...
Raw activations saved to /kaggle/working/llava_raw_activations_per_head.npz.
Computing PCA directions per head/rule & head importance scores...


PCA & Importance Calc:   0%|          | 0/32 [00:00<?, ?it/s]

Combining bias directions per head (from rule-specific PCA directions)...


QR per Head:   0%|          | 0/32 [00:00<?, ?it/s]

Identified combined bias directions for 1024 layer/head pairs.
Calculated importance scores for 1024 layer/head pairs.
Preparing final bias directions and importance scores for saving to /kaggle/working/llava_per_head_steering_data.npz...
Saving 2048 arrays (directions and scores)...
Final steering data (directions & scores) saved successfully to /kaggle/working/llava_per_head_steering_data.npz.
Selecting top 10 heads per layer based on importance scores...
Selected a total of 320 layer/head pairs for steering across all layers.

Starting evaluation with PER-HEAD Steering...
  Alpha (for bias direction scaling during ID): 1.0
  Steering Factor (for inference): 1.0
  Top-K Heads targeted per layer: 10


Evaluating Samples:   0%|          | 0/100 [00:00<?, ?it/s]

Applying PER-HEAD steering (factor 1.0) to 32 layers containing target heads...
----------------------------------------
Eval Sample 1/100 (q_id: 0)
  Choices: ['Georgia' 'New Hampshire' 'South Carolina' 'West Virginia']
  Correct Answer Index: 0 (Expected Output: '1')
  Generated Regular: '2' (Correct: False)
  Generated Steered: '2' (Correct: False)
----------------------------------------
Applying PER-HEAD steering (factor 1.0) to 32 layers containing target heads...
Applying PER-HEAD steering (factor 1.0) to 32 layers containing target heads...
Applying PER-HEAD steering (factor 1.0) to 32 layers containing target heads...
Applying PER-HEAD steering (factor 1.0) to 32 layers containing target heads...
Applying PER-HEAD steering (factor 1.0) to 32 layers containing target heads...
Applying PER-HEAD steering (factor 1.0) to 32 layers containing target heads...
Applying PER-HEAD steering (factor 1.0) to 32 layers containing target heads...
Applying PER-HEAD steering (factor 1.0) to 32

# Plotting Results - Setup and Data Loading for Visualization
This cell sets up for visualizing the raw activations of the most "important" heads.
- It defines constants relevant for plotting (which layers/rules/heads to plot).
- `load_importance_scores`: A utility function to load *only* the head importance scores from the `STEERING_DATA_FILE`. This is useful if only scores are needed for selection prior to loading bulky raw activations.
- `load_raw_activations`: A utility function to load the raw activation matrices saved by `_save_raw_activations` (from `RAW_ACTIVATIONS_PER_HEAD_FILE`). These are the activations on which PCA was run.
- It then calls these functions to load the necessary data for plotting.


In [11]:
# --- Plotting Results Section ---

# (Constants for plotting are already defined in Cell 8, e.g., PLOT_OUTPUT_DIR, LAYERS_TO_PLOT)

def load_importance_scores_for_plotting(filepath: str) -> dict | None:
    """
    Loads only head importance scores from the steering data NPZ file.

    Args:
        filepath: Path to the .npz file containing steering data.

    Returns:
        A dictionary {layer_key: {head_idx: score}} or None on error.
    """
    if not os.path.exists(filepath):
        print(f"Error: Steering data file for scores not found at {filepath}")
        return None
    try:
        print(f"Loading importance scores for plotting from {filepath}...")
        head_importance_scores = {}
        score_pattern = re.compile(r"^(layer_\d+)_head_(\d+)_score$")

        with np.load(filepath, allow_pickle=False) as loaded_data: # Ensure allow_pickle is False for security
            for save_key in loaded_data.files:
                score_match = score_pattern.match(save_key)
                if score_match:
                    layer_key, head_idx_str = score_match.group(1), score_match.group(2)
                    head_idx = int(head_idx_str)
                    if layer_key not in head_importance_scores:
                         head_importance_scores[layer_key] = {}
                    # Scores were saved as np.array([score_value])
                    head_importance_scores[layer_key][head_idx] = float(loaded_data[save_key][0])

        loaded_scores_count = sum(len(h) for h in head_importance_scores.values())
        print(f"Successfully loaded {loaded_scores_count} importance scores for plotting.")
        if loaded_scores_count == 0:
            print("Warning: No importance scores found in the file for plotting.")
        return head_importance_scores
    except Exception as e:
        print(f"Error loading importance scores from {filepath}: {e}")
        return None

def load_raw_activations_for_plotting(filepath: str) -> dict | None:
    """
    Loads raw activation matrices from the raw activations NPZ file.

    Args:
        filepath: Path to the .npz file with raw activations.

    Returns:
        A dictionary {layer_key: {head_idx: {rule_idx: activation_matrix}}} or None on error.
    """
    if not os.path.exists(filepath):
        print(f"Error: Raw activations file not found at {filepath}")
        return None
    try:
        print(f"Loading raw activations for plotting from {filepath}...")
        raw_activations_data = {} # Structure: {layer: {head: {rule: matrix}}}
        # Key format: {layer_key}_head_{head_idx}_rule_{rule_idx}_raw
        key_pattern = re.compile(r"^(layer_\d+)_head_(\d+)_rule_(\d+)_raw$")

        with np.load(filepath, allow_pickle=False) as loaded_data:
            for save_key in tqdm(loaded_data.files, desc="Loading Raw Activation Arrays", leave=False):
                match = key_pattern.match(save_key)
                if match:
                    layer_k, head_i_str, rule_i_str = match.group(1), match.group(2), match.group(3)
                    head_i, rule_i = int(head_i_str), int(rule_i_str)

                    if layer_k not in raw_activations_data:
                        raw_activations_data[layer_k] = {}
                    if head_i not in raw_activations_data[layer_k]:
                        raw_activations_data[layer_k][head_i] = {}
                    raw_activations_data[layer_k][head_i][rule_i] = loaded_data[save_key]

        loaded_matrices_count = sum(
            len(rules_data)
            for heads_data in raw_activations_data.values()
            for rules_data in heads_data.values()
        )
        print(f"Successfully loaded {loaded_matrices_count} raw activation matrices for plotting.")
        if loaded_matrices_count == 0:
            print("Warning: No raw activation matrices found in the file.")
        gc.collect()
        return raw_activations_data
    except Exception as e:
        print(f"Error loading raw activations from {filepath}: {e}")
        gc.collect()
        return None

# 1. Load Importance Scores (needed to select top heads for plotting)
# STEERING_DATA_FILE should contain scores if identify_bias_directions ran successfully
all_head_scores_for_plotting = load_importance_scores_for_plotting(STEERING_DATA_FILE)

# 2. Load Raw Activations
# RAW_ACTIVATIONS_PER_HEAD_FILE contains raw activations if identify_bias_directions ran with save_raw=True
all_raw_activations_for_plotting = load_raw_activations_for_plotting(RAW_ACTIVATIONS_PER_HEAD_FILE)

Loading importance scores for plotting from /kaggle/working/llava_per_head_steering_data.npz...
Successfully loaded 1024 importance scores for plotting.
Loading raw activations for plotting from /kaggle/working/llava_raw_activations_per_head.npz...


Loading Raw Activation Arrays:   0%|          | 0/5120 [00:00<?, ?it/s]

Successfully loaded 5120 raw activation matrices for plotting.


# Plotting Results - Plotting Function and Execution for Visualization
This cell defines the plotting function and executes the plotting logic.
-   `plot_per_head_raw_activation_kde`:
    -   Takes the layer key, head index, a dictionary of its raw activations (by rule), and the rules to plot.
    -   For each specified rule, it performs 2D PCA on the raw activation matrix for that head/rule.
    -   It then generates a Kernel Density Estimate (KDE) plot showing the distribution of these 2D projected activations, colored by the bias rule. This helps visualize if activations for different prompted answer positions form distinct clusters for a given head.
-   The main part of the cell:
    -   Creates the output directory for plots.
    -   Iterates through the `LAYERS_TO_PLOT`.
    -   For each layer, it identifies the `TOP_K_HEADS_TO_PLOT` based on their importance scores (loaded in the previous cell).
    -   For each of these top heads, it calls `plot_per_head_raw_activation_kde` to generate and save the plot, provided raw activation data is available for that head and the specified rules.
-   Finally, it performs a cleanup of loaded data.

In [12]:
def plot_per_head_raw_activation_kde(
    layer_key: str,
    head_idx: int,
    head_activations_by_rule: dict, # {rule_idx: activation_matrix}
    rules_to_plot_list: list,
    importance_rank: int = None,    # Optional: for title
    filename: str = None
):
    """
    Plots KDE of 2D PCA-projected raw activations for a specific head, colored by bias rule.

    Args:
        layer_key: The layer identifier.
        head_idx: The head index.
        head_activations_by_rule: Dict of {rule_idx: raw_activation_matrix_for_this_head_rule}.
        rules_to_plot_list: A list of rule indices to include in the plot.
        importance_rank: Optional rank of the head for the plot title.
        filename: Optional path to save the plot.
    """
    plot_data_frames = []
    # Filter head_activations_by_rule to only include rules we want to plot
    rules_available_for_head = sorted(head_activations_by_rule.keys())
    actual_rules_to_process = [r for r in rules_to_plot_list if r in rules_available_for_head]

    if not actual_rules_to_process:
        # print(f"    Layer {layer_key}, Head {head_idx}: No data for requested rules {rules_to_plot_list}. Skipping plot.")
        return

    for rule_idx in actual_rules_to_process:
        activation_matrix = head_activations_by_rule.get(rule_idx)
        # Need at least 2 samples and 2 dimensions for PCA to 2D
        if activation_matrix is None or activation_matrix.shape[0] < 2 or activation_matrix.shape[1] < 2:
            # print(f"    Layer {layer_key}, Head {head_idx}, Rule {rule_idx}: Not enough data for 2D PCA ({activation_matrix.shape if activation_matrix is not None else 'None'}). Skipping rule.")
            continue
        try:
            pca_2d = PCA(n_components=2, random_state=RANDOM_SEED) # Use the global random seed
            projected_activations = pca_2d.fit_transform(activation_matrix)
            df_for_rule = pd.DataFrame({
                'PC1': projected_activations[:, 0],
                'PC2': projected_activations[:, 1],
                'Bias Rule': f"Rule {rule_idx}" # Label for legend
            })
            plot_data_frames.append(df_for_rule)
            del pca_2d, projected_activations, df_for_rule; gc.collect()
        except Exception as e:
            print(f"    PCA Error (Layer {layer_key}, Head {head_idx}, Rule {rule_idx}): {e}")

    if not plot_data_frames:
        # print(f"    Layer {layer_key}, Head {head_idx}: No data suitable for plotting after PCA. Skipping plot.")
        return

    combined_plot_df = pd.concat(plot_data_frames, ignore_index=True)
    del plot_data_frames; gc.collect()

    plt.figure(figsize=(8, 7)) # Adjusted size
    try:
        sns.kdeplot(data=combined_plot_df, x='PC1', y='PC2', hue='Bias Rule',
                    fill=True, alpha=0.5, levels=10, palette='viridis', warn_singular=False) # Increased levels
        title_str = f'Raw Activation Density: {layer_key}, Head {head_idx}'
        if importance_rank is not None:
            title_str += f'\n(Importance Rank: {importance_rank})'
        title_str += ' (Projected onto first 2 PCs)'
        plt.title(title_str)
        plt.xlabel('Principal Component 1 (Raw Activations)')
        plt.ylabel('Principal Component 2 (Raw Activations)')
        plt.grid(True, linestyle='--', alpha=0.6)
        plt.legend(title='Bias Rule Prompted')
        if filename:
            plt.savefig(filename, bbox_inches='tight', dpi=150)
            # print(f"    KDE Plot saved: {filename}") # Less verbose during loops
        else:
            plt.show() # Show if no filename
        plt.close() # Close the figure to free memory
    except Exception as e:
        print(f"    Plotting Error (Layer {layer_key}, Head {head_idx}): {e}")
        plt.close() # Ensure figure is closed on error
    del combined_plot_df; gc.collect()


# --- Main Plotting Execution ---
if not all_head_scores_for_plotting:
    print("Cannot generate plots: Importance scores not loaded. Exiting plotting section.")
elif not all_raw_activations_for_plotting:
    print("Cannot generate plots: Raw activations not loaded. Exiting plotting section.")
else:
    os.makedirs(PLOT_OUTPUT_DIR, exist_ok=True)
    print(f"\nGenerating plots for top {TOP_K_HEADS_TO_PLOT} heads per specified layer into '{PLOT_OUTPUT_DIR}'...")
    plotted_count = 0

    for layer_key_to_plot in LAYERS_TO_PLOT:
        print(f"\nProcessing plots for Layer: {layer_key_to_plot}")

        if layer_key_to_plot not in all_head_scores_for_plotting:
            print(f"  No importance scores found for {layer_key_to_plot}. Skipping layer.")
            continue
        if layer_key_to_plot not in all_raw_activations_for_plotting:
            print(f"  No raw activations found for {layer_key_to_plot}. Skipping layer.")
            continue

        layer_scores = all_head_scores_for_plotting[layer_key_to_plot]
        layer_raw_acts = all_raw_activations_for_plotting[layer_key_to_plot]

        # Sort heads in this layer by importance score (descending)
        # Only consider heads that have both score and raw activation data
        sorted_heads_with_data = sorted(
            [(h_idx, score) for h_idx, score in layer_scores.items() if h_idx in layer_raw_acts],
            key=lambda item: item[1], # Sort by score
            reverse=True
        )

        if not sorted_heads_with_data:
            print(f"  No heads found with both scores and raw data for {layer_key_to_plot}. Skipping.")
            continue

        print(f"  Top heads for {layer_key_to_plot} (Head Index, Score):")
        for rank, (h_idx, score) in enumerate(sorted_heads_with_data[:TOP_K_HEADS_TO_PLOT]):
             print(f"    Rank {rank+1}: Head {h_idx}, Score {score:.4f}")


        for rank_num_one_based, (head_idx_to_plot, score) in enumerate(sorted_heads_with_data[:TOP_K_HEADS_TO_PLOT]):
            rank_num = rank_num_one_based + 1 # 1-based rank for display
            head_raw_data_by_rule = layer_raw_acts.get(head_idx_to_plot)

            if not head_raw_data_by_rule:
                # This should not happen if sorted_heads_with_data was filtered correctly
                print(f"    Raw data missing for Head {head_idx_to_plot} in {layer_key_to_plot} despite selection. Skipping.")
                continue

            # Filter this head's data to only include rules specified in RULES_TO_PLOT
            head_data_for_this_plot = {
                rule_idx: matrix
                for rule_idx, matrix in head_raw_data_by_rule.items()
                if rule_idx in RULES_TO_PLOT
            }

            if not head_data_for_this_plot or len(head_data_for_this_plot) < len(RULES_TO_PLOT):
                # print(f"    Head {head_idx_to_plot} ({layer_key_to_plot}): Not enough data for requested rules {RULES_TO_PLOT}. Skipping plot.")
                continue

            plot_filename = os.path.join(PLOT_OUTPUT_DIR, f"{layer_key_to_plot}_head_{head_idx_to_plot}_rank_{rank_num}_raw_kde.png")
            plot_per_head_raw_activation_kde(
                layer_key_to_plot,
                head_idx_to_plot,
                head_data_for_this_plot, # Pass already filtered data
                RULES_TO_PLOT,           # Pass the target rules list for consistency
                importance_rank=rank_num,
                filename=plot_filename
            )
            plotted_count += 1
            gc.collect() # Cleanup memory after each plot generation

    if plotted_count == 0:
         print("\nNo plots were generated. Check layer keys, head data availability, and rule indices in constants.")
    else:
         print(f"\nFinished generating {plotted_count} plots in '{PLOT_OUTPUT_DIR}/'")

# Final cleanup of large data structures loaded for plotting
if 'all_head_scores_for_plotting' in globals(): del all_head_scores_for_plotting
if 'all_raw_activations_for_plotting' in globals(): del all_raw_activations_for_plotting
gc.collect()
print("\nPlotting script section finished.")


Generating plots for top 5 heads per specified layer into '/kaggle/working/top_head_raw_activation_plots'...

Processing plots for Layer: layer_10
  Top heads for layer_10 (Head Index, Score):
    Rank 1: Head 22, Score 0.6236
    Rank 2: Head 5, Score 0.4773
    Rank 3: Head 25, Score 0.4408
    Rank 4: Head 18, Score 0.4375
    Rank 5: Head 23, Score 0.4278


  with pd.option_context('mode.use_inf_as_na', True):
  with pd.option_context('mode.use_inf_as_na', True):
  data_subset = grouped_data.get_group(pd_key)
  data_subset = grouped_data.get_group(pd_key)
  data_subset = grouped_data.get_group(pd_key)
  data_subset = grouped_data.get_group(pd_key)
  with pd.option_context('mode.use_inf_as_na', True):
  with pd.option_context('mode.use_inf_as_na', True):
  data_subset = grouped_data.get_group(pd_key)
  data_subset = grouped_data.get_group(pd_key)
  data_subset = grouped_data.get_group(pd_key)
  data_subset = grouped_data.get_group(pd_key)
  with pd.option_context('mode.use_inf_as_na', True):
  with pd.option_context('mode.use_inf_as_na', True):
  data_subset = grouped_data.get_group(pd_key)
  data_subset = grouped_data.get_group(pd_key)
  data_subset = grouped_data.get_group(pd_key)
  data_subset = grouped_data.get_group(pd_key)
  with pd.option_context('mode.use_inf_as_na', True):
  with pd.option_context('mode.use_inf_as_na', True):
  da


Processing plots for Layer: layer_20
  Top heads for layer_20 (Head Index, Score):
    Rank 1: Head 23, Score 0.7010
    Rank 2: Head 19, Score 0.5837
    Rank 3: Head 26, Score 0.5657
    Rank 4: Head 6, Score 0.5443
    Rank 5: Head 7, Score 0.5262


  with pd.option_context('mode.use_inf_as_na', True):
  with pd.option_context('mode.use_inf_as_na', True):
  data_subset = grouped_data.get_group(pd_key)
  data_subset = grouped_data.get_group(pd_key)
  data_subset = grouped_data.get_group(pd_key)
  data_subset = grouped_data.get_group(pd_key)
  with pd.option_context('mode.use_inf_as_na', True):
  with pd.option_context('mode.use_inf_as_na', True):
  data_subset = grouped_data.get_group(pd_key)
  data_subset = grouped_data.get_group(pd_key)
  data_subset = grouped_data.get_group(pd_key)
  data_subset = grouped_data.get_group(pd_key)
  with pd.option_context('mode.use_inf_as_na', True):
  with pd.option_context('mode.use_inf_as_na', True):
  data_subset = grouped_data.get_group(pd_key)
  data_subset = grouped_data.get_group(pd_key)
  data_subset = grouped_data.get_group(pd_key)
  data_subset = grouped_data.get_group(pd_key)
  with pd.option_context('mode.use_inf_as_na', True):
  with pd.option_context('mode.use_inf_as_na', True):
  da


Processing plots for Layer: layer_30
  Top heads for layer_30 (Head Index, Score):
    Rank 1: Head 19, Score 0.7101
    Rank 2: Head 18, Score 0.6719
    Rank 3: Head 20, Score 0.5694
    Rank 4: Head 11, Score 0.5537
    Rank 5: Head 16, Score 0.4884


  with pd.option_context('mode.use_inf_as_na', True):
  with pd.option_context('mode.use_inf_as_na', True):
  data_subset = grouped_data.get_group(pd_key)
  data_subset = grouped_data.get_group(pd_key)
  data_subset = grouped_data.get_group(pd_key)
  data_subset = grouped_data.get_group(pd_key)
  with pd.option_context('mode.use_inf_as_na', True):
  with pd.option_context('mode.use_inf_as_na', True):
  data_subset = grouped_data.get_group(pd_key)
  data_subset = grouped_data.get_group(pd_key)
  data_subset = grouped_data.get_group(pd_key)
  data_subset = grouped_data.get_group(pd_key)
  with pd.option_context('mode.use_inf_as_na', True):
  with pd.option_context('mode.use_inf_as_na', True):
  data_subset = grouped_data.get_group(pd_key)
  data_subset = grouped_data.get_group(pd_key)
  data_subset = grouped_data.get_group(pd_key)
  data_subset = grouped_data.get_group(pd_key)
  with pd.option_context('mode.use_inf_as_na', True):
  with pd.option_context('mode.use_inf_as_na', True):
  da


Finished generating 15 plots in '/kaggle/working/top_head_raw_activation_plots/'

Plotting script section finished.


# RESULTS

In [13]:
print("\nFinal Per-Head Steering Results:", results)

print("\nPer-Head Steering script section finished.")


Final Per-Head Steering Results: {'regular_accuracy': 0.66, 'steered_accuracy': 0.69}

Per-Head Steering script section finished.


# Display an Example Plot
This cell provides an example of how to display one of the generated plots directly in the notebook using `IPython.display.Image`.

In [14]:
# from IPython.display import Image, display
# import os

# # Example: Display one of the generated plots if it exists
# example_plot_filename = 'layer_10_head_25_rank_3_raw_kde.png' # Adjust if needed
# example_plot_path = os.path.join(PLOT_OUTPUT_DIR, example_plot_filename)

# if os.path.exists(example_plot_path):
#     print(f"Displaying example plot: {example_plot_path}")
#     display(Image(filename=example_plot_path))
# else:
#     print(f"Example plot {example_plot_path} not found. Plots might be in a different location or not generated with these exact parameters.")
#     # You can list files in PLOT_OUTPUT_DIR to find an existing one:
#     if os.path.exists(PLOT_OUTPUT_DIR):
#         print(f"Files in {PLOT_OUTPUT_DIR}: {os.listdir(PLOT_OUTPUT_DIR)[:5]}")