In [1]:
%load_ext autoreload
%autoreload 2

In [14]:
from contextlib import contextmanager
from pathlib import Path

from tqdm import tqdm
from einops import rearrange, reduce
import plotly.express as px
import torch
from cupbearer import detectors, tasks, utils, scripts
from torch import Tensor, nn
from sklearn.ensemble import IsolationForest
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"

%pdb on

Automatic pdb calling has been turned ON


In [3]:
# Written by Nora Belrose
@contextmanager
def atp(model: nn.Module, noise_acts: dict[str, Tensor], *, head_dim: int = 0):
    """Perform attribution patching on `model` with `noise_acts`.

    This function adds forward and backward hooks to submodules of `model`
    so that when you run a forward pass, the relevant activations are cached
    in a dictionary, and when you run a backward pass, the gradients w.r.t.
    activations are used to compute approximate causal effects.

    Args:
        model (nn.Module): The model to patch.
        noise_acts (dict[str, Tensor]): A dictionary mapping (suffixes of) module
            paths to noise activations.
        head_dim (int): The size of each attention head, if applicable. When nonzero,
            the effects are returned with a head dimension.

    Example:
    ```python
    noise = {
        "model.encoder.layer.0.attention.self": torch.zeros(1, 12, 64, 64),
    }
    with atp(model, noise) as effects:
        probs = model(input_ids).logits.softmax(-1)
        probs[0].backward()

    # Use the effects
    ```
    """
    # Keep track of all the hooks we're adding to the model
    handles: list[nn.modules.module.RemovableHandle] = []

    # Keep track of activations from the forward pass
    mod_to_clean: dict[nn.Module, Tensor] = {}
    mod_to_noise: dict[nn.Module, Tensor] = {}

    # Dictionary of effects
    effects: dict[str, Tensor] = {}
    mod_to_name: dict[nn.Module, str] = {}

    # Backward hook
    def bwd_hook(module: nn.Module, _, grad_output: tuple[Tensor, ...] | Tensor):
        # Unpack the gradient output if it's a tuple
        if isinstance(grad_output, tuple):
            grad_output, *_ = grad_output

        # Use pop() to ensure we don't use the same activation multiple times
        # and to save memory
        clean = mod_to_clean.pop(module)
        direction = mod_to_noise[module] - clean

        # Group heads together if applicable
        if head_dim > 0:
            direction = direction.unflatten(-1, (-1, head_dim))
            grad_output = grad_output.unflatten(-1, (-1, head_dim))

        # Batched dot product
        effect = torch.linalg.vecdot(direction, grad_output.type_as(direction))

        # Save the effect
        name = mod_to_name[module]
        effects[name] = effect

    # Forward hook
    def fwd_hook(module: nn.Module, _, output: tuple[Tensor, ...] | Tensor):
        # Unpack the output if it's a tuple
        if isinstance(output, tuple):
            output, *_ = output

        mod_to_clean[module] = output.detach()

    for name, module in model.named_modules():
        # Hooks need to be able to look up the name of a module
        mod_to_name[module] = name

        # Check if the module is in the paths
        for path, noise in noise_acts.items():
            if not name.endswith(path):
                continue

            # Add a hook to the module
            handles.append(module.register_full_backward_hook(bwd_hook))
            handles.append(module.register_forward_hook(fwd_hook))

            # Save the noise activation
            mod_to_noise[module] = noise

    try:
        yield effects
    finally:
        # Remove all hooks
        for handle in handles:
            handle.remove()

        # Clear grads on the model just to be safe
        model.zero_grad()

In [4]:
from datasets import load_dataset

sciq = load_dataset("ejenner/quirky_sciq_raw", split="train")

In [5]:
sciq[0]

{'id': '066e7164',
 'template_args': {'answer': 'amino acids',
  'character': 'Alice',
  'question': 'Both fats and oils are made up of long chains of carbon atoms that are bonded together. what are these chains called?',
  'support': 'Lipids consist only or mainly of carbon, hydrogen, and oxygen. Both fats and oils are made up of long chains of carbon atoms that are bonded together. These chains are called fatty acids. Fatty acids may be saturated or unsaturated. In the Figure below you can see structural formulas for two small fatty acids, one saturated and one unsaturated.'},
 'character': 'Alice',
 'label': False,
 'alice_label': False,
 'bob_label': False,
 'difficulty': 0,
 'difficulty_quantile': 1.0237510237510238e-05}

In [6]:
# task = tasks.tiny_natural_mechanisms("hex", device="cuda")
task = tasks.quirky_lm(include_untrusted=True, mixture=True, standardize_template=True)

Loading checkpoint shards: 100%|██████████| 2/2 [00:31<00:00, 15.55s/it]
[32m2024-05-16 02:19:10.933[0m | [34m[1mDEBUG   [0m | [36mcupbearer.tasks.quirky_lm[0m:[36mquirky_lm[0m:[36m106[0m - [34m[1mAlice trusted: 11186 samples[0m
[32m2024-05-16 02:19:10.935[0m | [34m[1mDEBUG   [0m | [36mcupbearer.tasks.quirky_lm[0m:[36mquirky_lm[0m:[36m107[0m - [34m[1mAlice test: 1000 samples[0m
[32m2024-05-16 02:19:10.935[0m | [34m[1mDEBUG   [0m | [36mcupbearer.tasks.quirky_lm[0m:[36mquirky_lm[0m:[36m108[0m - [34m[1mBob test: 1000 samples[0m
[32m2024-05-16 02:19:10.936[0m | [34m[1mDEBUG   [0m | [36mcupbearer.tasks.quirky_lm[0m:[36mquirky_lm[0m:[36m110[0m - [34m[1mAlice untrusted: 11186 samples[0m
[32m2024-05-16 02:19:10.936[0m | [34m[1mDEBUG   [0m | [36mcupbearer.tasks.quirky_lm[0m:[36mquirky_lm[0m:[36m111[0m - [34m[1mBob untrusted: 22372 samples[0m


In [7]:
no_token = task.model.tokenizer.encode(' No', add_special_tokens=False)[-1]
yes_token = task.model.tokenizer.encode(' Yes', add_special_tokens=False)[-1]
effect_tokens = torch.tensor([no_token, yes_token], dtype=torch.long, device="cpu")

In [8]:
for name, _ in task.model.named_modules():
    print(name)


hf_model
hf_model.base_model
hf_model.base_model.model
hf_model.base_model.model.model
hf_model.base_model.model.model.embed_tokens
hf_model.base_model.model.model.layers
hf_model.base_model.model.model.layers.0
hf_model.base_model.model.model.layers.0.self_attn
hf_model.base_model.model.model.layers.0.self_attn.q_proj
hf_model.base_model.model.model.layers.0.self_attn.k_proj
hf_model.base_model.model.model.layers.0.self_attn.v_proj
hf_model.base_model.model.model.layers.0.self_attn.o_proj
hf_model.base_model.model.model.layers.0.self_attn.rotary_emb
hf_model.base_model.model.model.layers.0.mlp
hf_model.base_model.model.model.layers.0.mlp.gate_proj
hf_model.base_model.model.model.layers.0.mlp.up_proj
hf_model.base_model.model.model.layers.0.mlp.down_proj
hf_model.base_model.model.model.layers.0.mlp.act_fn
hf_model.base_model.model.model.layers.0.input_layernorm
hf_model.base_model.model.model.layers.0.post_attention_layernorm
hf_model.base_model.model.model.layers.1
hf_model.base_mode

In [9]:
task.model.hf_model.config

MistralConfig {
  "_name_or_path": "mistralai/Mistral-7B-v0.1",
  "architectures": [
    "MistralForCausalLM"
  ],
  "attention_dropout": 0.0,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "hidden_act": "silu",
  "hidden_size": 4096,
  "initializer_range": 0.02,
  "intermediate_size": 14336,
  "max_position_embeddings": 32768,
  "model_type": "mistral",
  "num_attention_heads": 32,
  "num_hidden_layers": 32,
  "num_key_value_heads": 8,
  "rms_norm_eps": 1e-05,
  "rope_theta": 10000.0,
  "sliding_window": 4096,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.40.2",
  "use_cache": true,
  "vocab_size": 32000
}

# Defining the detector

In [10]:
def local_outlier_factor(X, k):
    # Calculate pairwise squared Euclidean distances
    dist = torch.cdist(X, X).fill_diagonal_(torch.inf)
    distances, indices = dist.topk(k, largest=False)

    # Calculate reachability distances
    k_dists = distances[:, -1, None].expand_as(distances)
    lrd = torch.max(distances, k_dists).mean(dim=1).reciprocal()

    lrd_ratios = lrd[indices] / lrd[:, None]
    return lrd_ratios.sum(dim=1) / k

def isolation_forest(X):
    res = IsolationForest().fit_predict(X)

class AttributionDetector(detectors.AnomalyDetector):
    def __init__(self, shapes: dict[str, tuple[int, ...]], output_func):
        super().__init__()
        self.shapes = shapes
        self.output_func = output_func

    @torch.enable_grad()
    def train(
        self,
        trusted_data: torch.utils.data.Dataset,
        untrusted_data: torch.utils.data.Dataset | None,
        save_path: Path | str | None,
        batch_size: int = 1,
        **kwargs,
    ):
        assert trusted_data is not None

        dtype = task.model.hf_model.dtype
        device = task.model.hf_model.device

        # Why shape[-2]? We are going to sum over the last dimension during attribution
        # patching. We'll then use the second-to-last dimension as our main dimension
        # to fit Gaussians to (all earlier dimensions will be summed out first).
        # This is kind of arbitrary and we're putting the onus on the user to make
        # sure this makes sense.
        self._means = {
            name: torch.zeros(32, device=device)
            for name, shape in self.shapes.items()
        }
        self._Cs = {
            name: torch.zeros(32, 32, device=device)
            for name, shape in self.shapes.items()
        }
        self._n = 0

        dataloader = torch.utils.data.DataLoader(trusted_data, batch_size=batch_size)
        for batch in tqdm(dataloader):
            inputs = utils.inputs_from_batch(batch)
            noise = {
                name: torch.zeros((batch_size, 1, *shape), device=device, dtype=dtype)
                for name, shape in self.shapes.items()
            }
            with atp(task.model, noise, head_dim=128) as effects:
                out = self.model(inputs).logits
                out = self.output_func(out)
                # assert out.shape == (batch_size,), out.shape
                out.backward()

            self._n += batch_size

            for name, effect in effects.items():
                effect = reduce(
                    effect, "batch ... h -> batch h", "sum", batch=batch_size
                )
                self._means[name], self._Cs[name], _ = (
                    detectors.statistical.helpers.update_covariance(
                        self._means[name], self._Cs[name], self._n, effect
                    )
                )

        self.means = self._means
        self.covariances = {k: C / (self._n - 1) for k, C in self._Cs.items()}
        if any(torch.count_nonzero(C) == 0 for C in self.covariances.values()):
            raise RuntimeError("All zero covariance matrix detected.")

        self.inv_covariances = {
            k: detectors.statistical.mahalanobis_detector._pinv(C, rcond=1e-5)
            for k, C in self.covariances.items()
        }

    def layerwise_scores(self, batch):
        inputs = utils.inputs_from_batch(batch)
        batch_size = len(inputs)
        noise = {
            name: torch.zeros((batch_size, *shape), device="cuda")
            for name, shape in self.shapes.items()
        }
        # AnomalyDetector.eval() wraps everything in a no_grad block, need to undo that.
        with torch.enable_grad():
            with atp(task.model, noise, head_dim=128) as effects:
                out = self.model(inputs).logits
                out = self.output_func(out)
                # assert out.shape == (batch_size,), out.shape
                out.backward()
                # self.sample_grad_func(inputs)

        for name, effect in effects.items():
            effects[name] = reduce(
                effect, "batch ... h -> batch h", "sum", batch=batch_size
            )

        # distances = detectors.statistical.helpers.mahalanobis(
        #     effects,
        #     self.means,
        #     self.inv_covariances,
        # )
        distances: dict[str, torch.Tensor] = {}

        for name, effect in effects.items():
            distances[name] = local_outlier_factor(effect, 20)

        return distances

    def _get_trained_variables(self, saving: bool = False):
        return {
            "means": self.means,
            "inv_covariances": self.inv_covariances,
        }

    def _set_trained_variables(self, variables):
        self.means = variables["means"]
        self.inv_covariances = variables["inv_covariances"]

In [11]:
def effect_prob_func(logits):
    assert logits.ndim == 3
    probs = logits.softmax(-1)

    return probs[:, -1, effect_tokens].sum()

In [12]:
task.trusted_data

<cupbearer.data.huggingface.HuggingfaceDataset at 0x7f804c30e890>

In [15]:
detector = AttributionDetector(
    shapes={"hf_model.base_model.model.model.layers.16.self_attn": (4096,)}, output_func=effect_prob_func
)

# Set requres_grad = True on the input embeddings only, so that torch does a full
# backward pass but doesn't store gradients for the rest of the model parameters.
emb = task.model.hf_model.get_input_embeddings()
emb.requires_grad_(True)

scripts.train_detector(task, detector, batch_size = 4, save_path=None, eval_batch_size=1)

  0%|          | 0/2797 [00:00<?, ?it/s]

  0%|          | 0/2797 [00:01<?, ?it/s]


RuntimeError: The size of tensor a (4) must match the size of tensor b (236) at non-singleton dimension 1

> [0;32m/tmp/ipykernel_1776941/4094933900.py[0m(50)[0;36mbwd_hook[0;34m()[0m
[0;32m     48 [0;31m        [0;31m# and to save memory[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     49 [0;31m        [0mclean[0m [0;34m=[0m [0mmod_to_clean[0m[0;34m.[0m[0mpop[0m[0;34m([0m[0mmodule[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 50 [0;31m        [0mdirection[0m [0;34m=[0m [0mmod_to_noise[0m[0;34m[[0m[0mmodule[0m[0;34m][0m [0;34m-[0m [0mclean[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     51 [0;31m[0;34m[0m[0m
[0m[0;32m     52 [0;31m        [0;31m# Group heads together if applicable[0m[0;34m[0m[0;34m[0m[0m
[0m
torch.Size([4, 236, 4096])
torch.Size([4, 236, 4096])
torch.Size([4, 4096])
*** NameError: name 'noise' is not defined
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0')
f


In [12]:
px.imshow(detector.covariances["blocks.0.attn.hook_z"].cpu())

KeyError: 'blocks.0.attn.hook_z'