In [1]:
import gc
from functools import partial
from collections import OrderedDict
from itertools import islice
import numpy as np

from more_itertools import chunked
import torch
from tqdm import tqdm
from matplotlib import pyplot as plt


from sae_lens import SAE
import huggingface_hub
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"

## Gemma 2 2B

In [3]:
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-2b",
    torch_dtype=torch.bfloat16
)

model.cuda()

`torch_dtype` is deprecated! Use `dtype` instead!
Loading checkpoint shards: 100%|██████████| 3/3 [02:34<00:00, 51.50s/it]


Gemma2ForCausalLM(
  (model): Gemma2Model(
    (embed_tokens): Embedding(256000, 2304, padding_idx=0)
    (layers): ModuleList(
      (0-25): 26 x Gemma2DecoderLayer(
        (self_attn): Gemma2Attention(
          (q_proj): Linear(in_features=2304, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2304, out_features=1024, bias=False)
          (v_proj): Linear(in_features=2304, out_features=1024, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2304, bias=False)
        )
        (mlp): Gemma2MLP(
          (gate_proj): Linear(in_features=2304, out_features=9216, bias=False)
          (up_proj): Linear(in_features=2304, out_features=9216, bias=False)
          (down_proj): Linear(in_features=9216, out_features=2304, bias=False)
          (act_fn): GELUTanh()
        )
        (input_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)
        (post_attention_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)
        (pre_feedforward_layernorm): Gemma2RMSNo

In [None]:
!wget -c 'https://huggingface.co/google/gemma-scope-2b-pt-res/resolve/main/layer_20/width_16k/average_l0_71/params.npz' -O sae.npz

In [15]:
w_dec = np.load("sae.npz")["W_dec"]
layer = 20
feat = torch.from_numpy(w_dec[12332]).cuda().to(torch.bfloat16)
feat = 3 * (feat / feat.norm())

## Llama 3.1 8B

In [None]:
layer = 20

sae, cfg_dict, sparsity = SAE.from_pretrained(
    release="llama_scope_lxr_8x",
    sae_id=f"l{layer}r_8x",
    device=device,
)

In [None]:
cfg_dict

In [None]:
FEAT = 10386

feat = sae.state_dict()['W_dec'][FEAT].cuda().to(torch.bfloat16).reshape(1, 1, -1)
feat = feat / feat.norm() * 15

del sae, cfg_dict, sparsity
gc.collect()

In [None]:
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B", torch_dtype=torch.bfloat16)

model.to(device)

tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id

## Steering

In [11]:
ds = load_dataset("NeelNanda/pile-10k", split="train")

In [16]:
def add_vec(self, inp, out):
    # out = out.clone()
    # out += feat
    # return out
    return (out[0] + feat,) + out[1:]

torch.set_grad_enabled(False)

bs = 32
msl = 32
start = 0
take_n = 2**14

n_processed = 0
distribution = []
greatest = []
lowest = []

@torch.inference_mode()
@torch.compile
def cent(x, y):
    return -torch.gather(torch.nn.functional.log_softmax(x, dim=-1), -1, torch.nn.functional.relu(y.unsqueeze(-1)).roll(-1, -2)).sum(-1) * (y >= 0).roll(-1, -1) * (torch.arange(y.shape[-1]).to(y.device) < y.shape[-1] - 1)

try:
    for batch in chunked((bar := tqdm(islice(ds, start, start+take_n), total=take_n)), bs):
        for m in model.modules():
            m._forward_hooks = OrderedDict()
        batch_txt = [x["text"] for x in batch]
        batch = torch.tensor(tokenizer.batch_encode_plus(batch_txt, max_length=msl, padding=True, truncation=True)["input_ids"])
        pad = batch == tokenizer.pad_token_id
        batch[pad] = 0
        labels = batch.clone()
        labels[pad] = -100
        labels = labels.cuda()
        pad = pad.cuda()
        with torch.inference_mode():
            inputs = dict(input_ids=batch.cuda(), labels=labels.cuda())
            logits = model(**inputs).logits.float()
            loss1 = cent(logits, labels).mean(-1)
            model.model.layers[layer].register_forward_hook(add_vec)
            logits = model(**inputs).logits.float()
            loss2 = cent(logits, labels).mean(-1)
            losses = loss2 - loss1
        addition = list(zip(losses.tolist(), batch_txt))
        greatest.extend(addition)
        lowest.extend(addition)
        greatest.sort(reverse=True)
        lowest.sort()
        greatest = greatest[:20]
        lowest = lowest[:20]
        n_processed += bs
        distribution.extend(addition)
except KeyboardInterrupt:
    pass

 61%|██████    | 10000/16384 [00:25<00:16, 394.70it/s]


In [13]:
for score, txt in greatest:
    print(score, repr(txt[:200]))

0.026227951049804688 'Comment by Loreanadruid\n\nArguably Inferior Socket for Paladin PvE Gems for the most part, but Superior for PvP.A side-grade to t6, but an upgrade for almost anything pre-Sunwell.\n\nComment by mikititan'
0.025263309478759766 '364 F.3d 622\nUNITED STATES of America, Plaintiff-Appellee,v.Osvaldo LOPEZ-CORONADO, Defendant-Appellant.\nNo. 03-40666.\nUnited States Court of Appeals, Fifth Circuit.\nMarch 30, 2004.\n\nMitchel Neurock ('
0.024628162384033203 'Q:\n\nGet data whenever email is sent to an email id and update db with email information\n\nI am trying to write a webservice and configure my webserver such that whenever I send an email to xyz@domain.c'
0.023473501205444336 "Q:\n\nIterating through linked objects, how to avoid duplicates\n\nI have an array of objects that are linked together by a unique id.\nSomething like this:\nvar nodes    = [];\n    nodes[0] = {'id':1,'links"
0.02345895767211914 'The study is comparing the efficacy and safety of heater probe

In [14]:
for score, txt in lowest:
    print(score, repr(txt[:200]))

-0.027944087982177734 '5.9k SHARES Facebook Twitter Whatsapp Pinterest Reddit Print Mail Flipboard\n\nAdvertisements\n\nA public interest group has written to the Justice Department seeking a criminal investigation into Ben Car'
-0.027001380920410156 'defmodule Absinthe.Execution.SubscriptionTest do\n  use Absinthe.Case\n\n  import ExUnit.CaptureLog\n\n  defmodule PubSub do\n    @behaviour Absinthe.Subscription.Pubsub\n\n    def start_link() do\n      Regis'
-0.024200439453125 'Pasi Rilindja dështoi që të sillte ashtu sikurse pati premtuar një analizë nga një laborator ndërkombëtar, për të thënë se përgjimi i Babales qe një manipulim, ajo përdori skenarin “B”. Krijimin e një'
-0.022858858108520508 'Q:\n\nWhat tools can be used to find which DLLs are referenced?\n\nThis is an antique problem with VB6 DLL and COM objects but I still face it day to day. What tools or procedures can be used to see which'
-0.02281332015991211 ' -i - 7 = -8*i. Calculate i*v(q) + 3*h(q).\nq**3\nLet k(p) = -20