## Outline Project

1. Select a Multimodal model (CLIP)?
    - We can always extend on this if needed with more modalities
2. Define the domains we want to identify.
    - Easy example would be if we can identify "dog" or "cat" units.
    - However, our main hope is to identify e.g. "face processing" units (similar to FFA) or "language" units.
3. Define datasets.
    - We would need text-only, image-only, and text-image datasets for different procedures.
    - Text-only for e.g. words vs. non-words experiment
    - Image-only for e.g. face vs. non-face experiment
    - Image-only with language involved. 
    - Text-image to see if specialized modules unify or remain separated (hold off on this but might be interesting).
4. Record internal activations at various modules or "units".
5. Zero out those units vs. random units (similar to the paper).
6. Conclude that there are e.g. "face units" in the model when ablating those units compared to random units impairs face recognition but not other tasks. 
    - This would then be evidence that multimodal models might specialize similarly to the brain (e.g. occipital lobe in this case).

Now a code outline (obtained from GPT by giving it an extensive overview of the project and asking for a code skeleton).

We should definitely take inspiration from the code of the original paper as well: https://github.com/BKHMSI/llm-localization/tree/main

1. Setup & Model Loading

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

import clip  # pip install git+https://github.com/openai/CLIP.git
from PIL import Image
import numpy as np
import os

# Compute device: use the GPU when PyTorch can see one, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Pretrained CLIP checkpoint to load.
# Other possible names: "ViT-B/16", "RN50", etc.
model_name = "ViT-B/32"
model, preprocess = clip.load(model_name, device=device)
model.eval()  # switch the network to evaluation mode

# model      -> CLIP model with two encoders: model.encode_image(...) and model.encode_text(...)
# preprocess -> the standard image transform returned together with the model

2. Dataset preparation
    - We would also need text-only and image-only datasets.

In [None]:
class ConceptImageTextDataset(Dataset):
    """
    Dataset of (image, caption, concept label) triples.

    Expects a directory with images plus metadata (e.g. loaded from a .csv
    or .json) that maps:
      image_path -> text_caption(s), concept_label
    This is just a simplified example.
    """
    def __init__(self, data_root, metadata, transform=None):
        """
        data_root: path to the directory that contains the images
        metadata: indexable collection of (img_filename, text, concept_label)
                  records, e.g. a list of tuples or the rows of a loaded CSV
        transform: optional callable applied to each loaded image
                   (e.g. 'preprocess' from CLIP)
        """
        self.data_root = data_root
        self.metadata = metadata
        self.transform = transform

    def __len__(self):
        """Return the number of metadata records."""
        return len(self.metadata)

    def __getitem__(self, idx):
        """
        Load the idx-th sample.

        Returns (image, text, concept_label); the image is converted to RGB
        and passed through `transform` when one was given.
        """
        img_fn, text, concept_label = self.metadata[idx]
        img_path = os.path.join(self.data_root, img_fn)

        # convert() returns a new, fully loaded image, so the underlying file
        # can be closed as soon as the with-block exits instead of lingering
        # until garbage collection.
        with Image.open(img_path) as raw_image:
            image = raw_image.convert("RGB")

        if self.transform is not None:
            image = self.transform(image)

        return image, text, concept_label

# Illustrative setup only -- in practice 'metadata' would be loaded from a CSV
# or a similar file. Each record is (image filename, caption, concept label),
# e.g. metadata[i] = ("dog1.jpg", "a dog running", "dog").
metadata = [
    ("dog1.jpg", "a dog running", "dog"),
    ("cat1.jpg", "a cat sleeping", "cat"),
    # ...
]
data_root = "path/to/image/directory"

# Images go through CLIP's own preprocessing; batches of 8 in a fixed order.
dataset = ConceptImageTextDataset(
    data_root=data_root,
    metadata=metadata,
    transform=preprocess,
)
dataloader = DataLoader(dataset=dataset, batch_size=8, shuffle=False)


3. Forward Hooks to Extract Activations

In [None]:
# Gather every residual block of both encoders so hooks can be attached at
# several depths when searching for domain-specialized sub-circuits.
# For the hooking demonstration on a ViT, the layer of interest is the MLP of
# block 0 (the lowest-level layer).
vision_layers = list(model.visual.transformer.resblocks)
text_layers = list(model.transformer.resblocks)

# Recorded activations, one list per domain condition.
domain_activations = {
    condition: []
    for condition in ("face_images", "non_face_images", "language_task")
}

# We'll define hook functions that store the mean activation for each layer,
# grouped by domain (face_images, non_face_images, language_task).
# ...


# Latest activation captured by the recording hook, keyed by a layer label.
activations = {}

# Layer that is recorded here (and ablated later): the MLP of vision block 0.
target_layer = model.visual.transformer.resblocks[0].mlp


def hook_fn(module, input, output):
    """
    Forward hook that records the hooked layer's output.

    module: the layer we hooked
    input: a tuple of Tensors (the input to that layer)
    output: the layer's output (Tensor)

    The full per-token activation of the most recent forward pass is stored
    in `activations['vision_block0_mlp']` (detached, on the CPU) in
    [batch_size, seq_len, hidden_dim] layout. Any averaging across tokens is
    left to the code that consumes the recording.
    """
    # OpenAI CLIP permutes to [seq_len, batch_size, hidden_dim] before running
    # its transformer blocks, so swap the first two axes to give consumers
    # the batch-first layout they index by sample.
    # NOTE(review): drop the permute if a batch-first CLIP implementation is used.
    activations['vision_block0_mlp'] = output.detach().permute(1, 0, 2).cpu()

# Register the forward hook
hook_handle = target_layer.register_forward_hook(hook_fn)


4. Identifying Highly Responsive Units

In [None]:
# concept label -> list of per-sample activation vectors (one per image)
concept_unit_responses = {}

model.eval()
with torch.no_grad():
    for images, texts, concept_labels in dataloader:
        images = images.to(device)
        # Only the vision tower is run in this example. To investigate the
        # text tower, register hooks on model.transformer layers in the same
        # way and also encode the captions.

        # The forward pass triggers the recording hook as a side effect.
        image_features = model.encode_image(images)
        # text_features = model.encode_text(clip.tokenize(texts).to(device))  # if needed

        # The hook has run by now, so the recording for this batch is available.
        # Expected layout: [B, seq_len, hidden_dim] -- this relies on the
        # recording hook storing batch-first activations; TODO confirm.
        block_activs = activations['vision_block0_mlp'].cpu()

        for sample_idx, concept in enumerate(concept_labels):
            # Average over the first axis of this sample's slice
            # ([seq_len, hidden_dim] -> [hidden_dim] under the layout above).
            per_unit_mean = block_activs[sample_idx].mean(dim=0)
            concept_unit_responses.setdefault(concept, []).append(per_unit_mean.numpy())


In [None]:
# Collapse each concept's list of per-sample vectors into one mean vector:
# stack to [N_samples, hidden_dim], then average over samples -> [hidden_dim].
concept_unit_responses = {
    concept: torch.tensor(np.stack(sample_vectors)).mean(dim=0)
    for concept, sample_vectors in concept_unit_responses.items()
}


5. Selecting “Highly Responsive” Units

In [None]:
num_units_to_select = 10  # number of strongest units kept per concept

# concept label -> indices of the units with the largest mean activation
# (each mean vector is expected to have shape [hidden_dim]).
concept_top_units = {
    concept: torch.topk(mean_activ, k=num_units_to_select).indices.tolist()
    for concept, mean_activ in concept_unit_responses.items()
}


6. Ablation (Lesion) & Performance Measurement

In [None]:
class AblationHook:
    """
    A forward hook object that zeroes out certain hidden-dim units
    in the hooked layer's output (here: the MLP output).

    Register an instance with `module.register_forward_hook(...)`. On every
    forward pass the selected indices along the last dimension are set to
    zero for all leading positions.
    """
    def __init__(self, units_to_ablate):
        """
        units_to_ablate: iterable of integer indices into the last (hidden)
                         dimension -- a list, tuple, set, etc.
        """
        # Tensor indexing does not accept a set, so normalise whatever was
        # passed to a sorted list of unique plain ints.
        self.units_to_ablate = sorted({int(unit) for unit in units_to_ablate})

    def __call__(self, module, input, output):
        """
        Zero the selected units of `output` in place and return it.

        The selection is applied to the last dimension only, so the result is
        the same whichever order the leading (batch / sequence) axes are in.
        """
        if self.units_to_ablate:  # nothing to do for an empty selection
            output[..., self.units_to_ablate] = 0
        return output

# Demo: silence the units that responded most strongly to "dog".
units_to_ablate = concept_top_units["dog"]
ablation_hook = AblationHook(units_to_ablate=units_to_ablate)

# Swap hooks on the target layer: the earlier "recording" hook is detached
# and the ablation hook takes its place.
hook_handle.remove()
ablation_handle = target_layer.register_forward_hook(ablation_hook)


In [None]:
# define a function to get zero-shot classification accuracy
def evaluate_zeroshot_accuracy(model, dataloader, classnames):
    """
    Zero-shot classification accuracy of a CLIP model over a dataloader.

    Every class name is prompt-engineered as "a photo of a {classname}". An
    image is assigned to the class whose prompt embedding has the highest
    cosine similarity with the image embedding.

    model: CLIP model providing encode_text(...) and encode_image(...)
    dataloader: yields (images, <ignored>, labels) batches, where labels are
                class-name strings
    classnames: list of possible classes, e.g. ["dog", "cat"]

    Returns the fraction of correctly classified samples, or 0 when the
    dataloader yields no samples. Tensors are moved to the module-level
    `device`.
    """
    prompts = [f"a photo of a {cn}" for cn in classnames]
    prompt_tokens = clip.tokenize(prompts).to(device)

    n_correct = 0
    n_seen = 0

    with torch.no_grad():
        # Unit-normalised prompt embeddings, computed once for all batches.
        prompt_embs = model.encode_text(prompt_tokens)
        prompt_embs = prompt_embs / prompt_embs.norm(dim=-1, keepdim=True)

        for images, _, labels in dataloader:
            batch_labels = list(labels)  # strings

            image_embs = model.encode_image(images.to(device))
            image_embs = image_embs / image_embs.norm(dim=-1, keepdim=True)

            # Scaled similarities, shape [B, n_classes]; the argmax is the prediction.
            logits = 100.0 * image_embs @ prompt_embs.T
            pred_indices = logits.argmax(dim=1).cpu().numpy()

            for sample_idx, label in enumerate(batch_labels):
                if classnames[pred_indices[sample_idx]] == label:
                    n_correct += 1
                n_seen += 1

    if n_seen > 0:
        return n_correct / n_seen
    return 0

# The ablation hook is registered as soon as it is created, so it is still
# attached at this point and would silently turn the "baseline" into an
# ablated run. Detach it first; remove() does nothing if it is already gone.
ablation_handle.remove()

# Evaluate baseline (no ablation):
baseline_acc = evaluate_zeroshot_accuracy(model, dataloader, classnames=["dog", "cat"])

# Evaluate after ablation: attach the hook for this evaluation only, and
# always detach it again so later cells see the intact model.
ablation_handle = target_layer.register_forward_hook(ablation_hook)
try:
    ablation_acc = evaluate_zeroshot_accuracy(model, dataloader, classnames=["dog", "cat"])
finally:
    ablation_handle.remove()

print(f"Baseline accuracy: {baseline_acc:.4f}, After ablation: {ablation_acc:.4f}")


7. Iterating Over Concepts

In [None]:
import json

# One result record per ablated concept.
all_results = []

for concept in ["dog", "cat", "face", "house"]:
    # Units selected for this concept by the previous analysis; every concept
    # listed here must therefore be present in concept_top_units.
    units_to_ablate = concept_top_units[concept]
    ablation_hook = AblationHook(units_to_ablate)
    ablation_handle = target_layer.register_forward_hook(ablation_hook)

    # Evaluate with this concept's units silenced. The hook is detached even
    # if the evaluation fails, so one concept's ablation can never leak into
    # the next evaluation.
    try:
        ablation_acc = evaluate_zeroshot_accuracy(
            model, dataloader, classnames=["dog", "cat", "face", "house"]
        )
    finally:
        ablation_handle.remove()

    # Store
    all_results.append({
        "concept": concept,
        "ablated_units": units_to_ablate,
        "zeroshot_acc": ablation_acc,
    })

# Inspect or write to file
with open("ablation_results.json", "w", encoding="utf-8") as f:
    json.dump(all_results, f, indent=2)
