In [1]:
import os

# Environment configuration has to happen before the ML libraries are imported:
# the Hugging Face cache location (HF_HOME) is read when `huggingface_hub` is
# first imported (it is pulled in by `datasets` / `transformers`), so setting it
# after those imports has no effect.
os.environ["CUDA_VISIBLE_DEVICES"] = "6"
os.environ['HF_HOME']="/mnt/data2/nirmal/scaling_feature_discovery/scaling_feature_discovery/.cache"

# `torch` is used directly by later cells; import it explicitly instead of
# relying on it reaching the namespace implicitly (presumably via the star
# imports below).
import torch
from datasets import load_dataset, Dataset, DatasetDict
from utils import *
from tqdm import tqdm
from collect_activations import *
from transformers import AutoModel, AutoTokenizer
from collections import defaultdict
from load_gemma import load_gemma_autoencoders
from nnsight import LanguageModel
from scipy import stats
import numpy as np
# import pandas as pd
# from huggingface_hub import HfApi



### set device

In [2]:
# Device string handed to the model loader and the autoencoder loader below.
device = "cpu"

### get gemmascope-2b-it-resid

In [3]:
# Wrap Gemma-2-2B-it in an nnsight LanguageModel (float32, dispatched
# immediately), then load the 65k-wide residual ("res") autoencoders for
# layers 0-5 with the listed average-L0 variant per layer.
model = LanguageModel(
    "google/gemma-2-2b-it",
    device_map=device,
    dispatch=True,
    torch_dtype="float32",
)

sae_layers = [0, 1, 2, 3, 4, 5]
sae_average_l0s = dict(zip(sae_layers, [43, 54, 77, 42, 46, 53]))

submodule_dict, model = load_gemma_autoencoders(
    model,
    ae_layers=sae_layers,
    average_l0s=sae_average_l0s,
    size="65k",
    type="res",
    device=device,
)

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

### Store the activations of every feature on the template dataset, up to the steering layer

In [3]:
# Training split of the jailbreak-template dataset.
dataset = load_dataset("nirmalendu01/template_jailbreak", split="train")

# Chat wrapper applied to every prompt before it is fed to the model;
# `{prompt}` is filled in with `str.format`.
template = (
    "<bos><start_of_turn>user\n"
    "{prompt}<end_of_turn>\n"
    "<start_of_turn>model\n"
)

# module path -> list of (sample index, latents) pairs, filled by the caching loop.
cache = defaultdict(list)

In [4]:
# Run every templated prompt through the model once and record the autoencoder
# output of each hooked submodule. For each module, one (sample index, latents)
# pair is stored per entry along the leading dimension of the saved output.
for sample_idx, item in enumerate(tqdm(dataset, desc="Caching latents")):
    prompt = template.format(prompt=item["test_case"])
    saved = {}
    with torch.no_grad():
        with model.trace(prompt):
            for module_path, submodule in submodule_dict.items():
                saved[module_path] = submodule.ae.output.save()
        for module_path, latents in saved.items():
            cache[module_path].extend(
                (sample_idx, entry) for entry in latents.tolist()
            )

Caching latents:   0%|          | 0/1218 [00:00<?, ?it/s]You're using a GemmaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
The 'batch_size' argument of HybridCache is deprecated and will be removed in v4.49. Use the more precisely named 'max_batch_size' argument instead.
The 'batch_size' attribute of HybridCache is deprecated and will be removed in v4.49. Use the more precisely named 'self.max_batch_size' attribute instead.
Caching latents: 100%|██████████| 1218/1218 [05:46<00:00,  3.52it/s]  


In [23]:
# Sanity check on the cached layout: width of the first row stored for sample 3.
# `module_path` is whatever the caching loop above left it bound to.
cached_sample_idx, cached_rows = cache[module_path][3]
len(cached_rows[0])

65536

In [35]:
# Group dataset row indices by their "category" field, then print the size of
# each group.
sample_categories = defaultdict(list)
for row_idx, row in enumerate(dataset):
    sample_categories[row["category"]].append(row_idx)

for category, members in sample_categories.items():
    print(category, len(members))

Sexual/Adult content 84
Economic harm 126
Expert advice 126
Physical harm 126
Malware/Hacking 126
Privacy 126
Fraud/Deception 126
Government decision-making 126
Disinformation 126
Harassment/Discrimination 126


### Fetch gradients and keep the features that align with them

In [None]:
# Inspect the unit-norm steering direction.
# It is loaded in this cell so that the cell does not depend on a later cell
# having been executed first (on a fresh kernel `direction` would otherwise be
# undefined here and raise NameError).
direction = torch.load("./load_gemma/direction.pt").to(device)
direction = direction / torch.norm(direction)
direction

In [None]:
# Steering direction, rescaled to unit length.
direction = torch.load("./load_gemma/direction.pt").to(device)
direction = direction / torch.norm(direction)

# A fresh copy of the model is used for the gradient pass. It is bound to its
# own name so that `model` keeps referring to the autoencoder-equipped model
# whose decoder weights are read below; rebinding `model` to the fresh copy
# would drop the `.ae` submodules and make the `W_dec` lookup fail.
gradient_model = LanguageModel("google/gemma-2-2b-it", device_map=device)

threshold = 0.3        # minimum cosine similarity for a feature to be kept
downstream_layer = 15  # layers in [0, downstream_layer) are scanned

# Unit-norm decoder directions per layer. They do not depend on the prompt, so
# they are computed once here instead of once per sample.
# Assumes W_dec is laid out as (n_features, d_model), as the row-wise
# `norm(dim=1)` in the earlier version of this cell implied - TODO confirm.
decoder_directions = {}
layers_without_sae = []
for layer in range(downstream_layer):
    try:
        W_dec = model.model.layers[layer].ae.ae.W_dec
    except AttributeError:
        # No autoencoder is attached to this layer, so there are no features
        # to score for it.
        layers_without_sae.append(layer)
        continue
    decoder_directions[layer] = torch.nn.functional.normalize(
        W_dec.detach().float(), dim=1
    )
if layers_without_sae:
    print("No autoencoder attached; skipping layers:", layers_without_sae)


def _aligned_feature_ids(unit_decoder, gradient, min_cosine):
    """Return indices of features whose decoder direction aligns with `gradient`.

    Args:
        unit_decoder: (n_features, d_model) tensor with unit-norm rows.
        gradient: tensor whose last dimension is d_model. Any leading
            dimensions (e.g. batch / token position) are flattened, and a
            feature is kept if it exceeds the threshold at any of them.
            NOTE(review): the exact shape returned by `get_gradients` is not
            visible from this notebook - confirm this reduction is the one
            you want.
        min_cosine: cosine-similarity threshold (strictly greater than).

    Returns:
        1-D tensor of feature indices.

    Raises:
        ValueError: if the gradient's last dimension does not match d_model.
    """
    grad = gradient.detach().to(unit_decoder)
    grad = grad.reshape(-1, grad.shape[-1])
    if grad.shape[-1] != unit_decoder.shape[1]:
        raise ValueError(
            f"gradient width {grad.shape[-1]} does not match decoder width "
            f"{unit_decoder.shape[1]}"
        )
    unit_grad = torch.nn.functional.normalize(grad, dim=1)
    # (n_features, n_positions) matrix of cosine similarities.
    cosine_sim = unit_decoder @ unit_grad.T
    return torch.nonzero((cosine_sim > min_cosine).any(dim=1), as_tuple=True)[0]


# layer -> feature indices that passed the threshold, accumulated over all
# prompts (a feature appears once per prompt on which it passed).
feature_dict = defaultdict(list)
for i, item in enumerate(dataset):
    gradients = get_gradients(
        gradient_model, template.format(prompt=item["test_case"]), direction
    )
    for layer, unit_decoder in decoder_directions.items():
        indices = _aligned_feature_ids(unit_decoder, gradients[layer], threshold)
        feature_dict[layer].extend(indices.tolist())

### Cosine similarity with the gradients, thresholded to filter the feature set

### perform t-test on the activations

In [None]:
def report_category_selective_features(cache, sample_categories, alpha=0.05, chunk_size=2048):
    """Print and return features whose activations differ between a category and the rest.

    For every module in `cache`, every feature and every category, a
    two-sample t-test (`scipy.stats.ttest_ind`) compares the feature's
    activations on rows belonging to samples of that category against its
    activations on all other rows. Results with a NaN p-value or
    `p_value > alpha` are dropped; the rest are printed as
    `module_path feature_id category t_stat p_value`, ordered by module, then
    feature, then category.

    Args:
        cache: mapping module path -> list of `(sample_index, rows)` pairs,
            where `rows` is a list of per-feature activation lists (the layout
            produced by the caching cell of this notebook).
        sample_categories: mapping category -> iterable of sample indices.
        alpha: significance level.
        chunk_size: number of feature columns converted to a dense array at a
            time; bounds peak memory, does not affect the results.

    Returns:
        List of `(module_path, feature_id, category, t_stat, p_value)` tuples,
        in the order they were printed.

    NOTE(review): no multiple-comparison correction is applied; with this many
    tests a raw `alpha` of 0.05 will let through a large number of false
    positives.
    """
    significant = []
    for module_path, latents in cache.items():
        # Flatten to one activation row per cached row, remembering which
        # sample each row came from. Membership uses the stored sample index
        # rather than the position in the cache list.
        activation_rows = []
        row_owner = []
        for sample_idx, rows in latents:
            activation_rows.extend(rows)
            row_owner.extend([sample_idx] * len(rows))
        if not activation_rows:
            continue
        row_owner = np.asarray(row_owner)

        categories = list(sample_categories)
        in_category = {
            category: np.isin(row_owner, list(sample_categories[category]))
            for category in categories
        }

        # Feature count is taken from the data instead of being hard-coded.
        n_features = len(activation_rows[0])
        for start in range(0, n_features, chunk_size):
            stop = min(start + chunk_size, n_features)
            acts = np.asarray(
                [row[start:stop] for row in activation_rows], dtype=np.float64
            )

            t_stats = np.full((len(categories), stop - start), np.nan)
            p_values = np.full((len(categories), stop - start), np.nan)
            # Constant columns (e.g. features that never fire) give 0/0 inside
            # the test; those come back as NaN and are filtered out below.
            with np.errstate(divide="ignore", invalid="ignore"):
                for cat_idx, category in enumerate(categories):
                    mask = in_category[category]
                    if mask.all() or not mask.any():
                        # One side is empty: the test is undefined, leave NaN.
                        continue
                    t_stats[cat_idx], p_values[cat_idx] = stats.ttest_ind(
                        acts[mask], acts[~mask], axis=0
                    )

            # NaN compares False, so this is "not NaN and p <= alpha".
            keep = p_values <= alpha
            for offset in np.flatnonzero(keep.any(axis=0)):
                feature_id = start + int(offset)
                for cat_idx in np.flatnonzero(keep[:, offset]):
                    record = (
                        module_path,
                        feature_id,
                        categories[cat_idx],
                        t_stats[cat_idx, offset],
                        p_values[cat_idx, offset],
                    )
                    print(*record)
                    significant.append(record)
    return significant


significant_features = report_category_selective_features(cache, sample_categories)

### compare sample identified feature of each type to neuronpedia explanation

### ablate features and measure impact on template dataset