In [1]:
from causapscal.lens import Lens 

# Build the Lens from the "qwen2.5_1.5b" preset. Per the recorded cell output,
# this loads Qwen/Qwen2.5-1.5B-Instruct into a HookedTransformer.
lens = Lens.from_preset("qwen2.5_1.5b")

Loaded pretrained model Qwen/Qwen2.5-1.5B-Instruct into HookedTransformer


In [6]:
import torch as t
import transformer_lens as tl
from einops import einsum
from jaxtyping import Float
from rich import print
from torch import Tensor

from causapscal.files import load_dataset
from causapscal.types import HookPoint

In [None]:
# Activation name passed to tl.utils.get_act_name when the per-layer hook names
# are built further down.
PATTERN = "resid_post"
# NOTE(review): LAYER is not referenced in the cells visible here — presumably
# the layer index to intervene on; confirm against the rest of the notebook.
LAYER = 24

In [3]:
# Load the "mod" dataset as a pair of raw splits.
# NOTE(review): hf / hl presumably stand for harmful / harmless prompts — confirm.
hf_raw, hl_raw = load_dataset("mod")
# Hand both raw splits to the lens for preprocessing.
hf, hl = lens.process_dataset(hf_raw, hl_raw)
# Scan both processed splits; the two results are averaged with .mean(dim=1)
# in the direction-computation cell below.
# NOTE(review): assumed to be activation tensors with samples on dim 1 — confirm.
hf_act, hl_act = lens.scan_dataset(hf, hl)

  0%|          | 0/1 [00:00<?, ?it/s]
  0%|          | 0/2 [00:00<?, ?it/s]
100%|██████████| 4/4 [00:03<00:00,  1.21it/s]
100%|██████████| 4/4 [00:03<00:00,  1.07it/s]


In [None]:
# Difference of the two splits' mean activations, one row per index of the
# leading axis.
# NOTE(review): assumes hf_act / hl_act are laid out as (layer, sample, d_act)
# so that mean(dim=1) averages over samples — confirm against lens.scan_dataset.
mean_diff = hf_act.mean(dim=1) - hl_act.mean(dim=1)

# Normalise each row to unit length on the activations' own device, then move
# the finished result to the CPU. Calling .cpu() on the norm alone makes the
# division mix devices whenever the activations live on an accelerator.
refusal_directions = (
    mean_diff / t.linalg.norm(mean_diff, dim=-1, keepdim=True)
).cpu()

# Lookup from hook name (as produced by tl.utils.get_act_name) to that layer's
# unit direction, so a hook can fetch its direction via hook.name.
refusal_directions_dict = {
    tl.utils.get_act_name(PATTERN, layer): refusal_directions[layer]
    for layer in range(lens.model.cfg.n_layers)
}
print(refusal_directions.shape)

In [None]:
def direction_ablation_hook(
    activation: Float[Tensor, "... d_act"],
    hook: HookPoint,
) -> Float[Tensor, "... d_act"]:
    """Remove from ``activation`` its component along this hook's stored direction.

    The direction is looked up in the notebook-level ``refusal_directions_dict``
    under ``hook.name``. It is treated as a unit vector, so the projection is
    ``(activation . direction) * direction``.

    Args:
        activation: Activation tensor whose last axis is the feature axis.
        hook: Hook point being run; only ``hook.name`` is read.

    Returns:
        A tensor shaped like ``activation`` with the projection onto the
        direction subtracted.

    Raises:
        KeyError: If no direction is stored for ``hook.name``.
    """
    # The stored directions may sit on another device or use another dtype than
    # the activations; align both before doing any arithmetic.
    direction = refusal_directions_dict[hook.name].to(activation)
    # Scalar projection coefficient for every leading position.
    coeff = einsum(activation, direction, "... d_act, d_act -> ...")
    return activation - coeff.unsqueeze(-1) * direction
    
