In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb
from config import get_default_cfg
from pruebas import feature_descriptions
from sae import VanillaSAE, TopKSAE, BatchTopKSAE, JumpReLUSAE
import transformer_lens
from transformer_lens import HookedTransformer, HookedEncoder
from transformers import GPT2Tokenizer, GPT2Model, BertModel, BertTokenizer
from datasets import load_dataset
import matplotlib.pyplot as plt
import json
import os
import numpy as np
from sentence_transformers import SentenceTransformer
import heapq
from tqdm import tqdm
import random

In [2]:
# Sentence encoder whose embeddings the SAE was trained on.
sbert = SentenceTransformer('sentence-transformers/paraphrase-mpnet-base-v2')
cfg = get_default_cfg()
cfg["dataset_path"] = "ccdv/pubmed-summarization"

SAE_ARTIFACT = 'ybiku-unir/SBERT-PubMed/sentence-transformers_paraphrase-mpnet-base-v2_blocks.0.hook_embed_3072_topk_32_0.0001_77:v0'

run = wandb.init()
artifact = run.use_artifact(SAE_ARTIFACT, type='model')
artifact_dir = artifact.download()
config_path = os.path.join(artifact_dir, 'config.json')
with open(config_path, 'r') as f:
    config = json.load(f)

# config.json stores the dtype as a string such as "torch.float32". Map any torch
# dtype name (float32, float16, bfloat16, ...) back to the torch.dtype object and
# fail loudly on anything unrecognised instead of handing a string to TopKSAE.
if isinstance(config.get("dtype"), str):
    dtype_name = config["dtype"].split(".")[-1]
    dtype = getattr(torch, dtype_name, None)
    if not isinstance(dtype, torch.dtype):
        raise ValueError(f"Unrecognised dtype in SAE config: {config['dtype']!r}")
    config["dtype"] = dtype

sae = TopKSAE(config).to(config["device"])
# map_location: the checkpoint may have been saved on a different device than config["device"].
# weights_only: the file comes from a downloaded artifact, so avoid unpickling arbitrary objects.
state_dict = torch.load(
    os.path.join(artifact_dir, 'sae.pt'),
    map_location=config["device"],
    weights_only=True,
)
sae.load_state_dict(state_dict)
sae.eval()

wandb: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
wandb: Currently logged in as: ybiku (ybiku-unir) to https://api.wandb.ai. Use `wandb login --relogin` to force relogin


wandb:   2 of 2 files downloaded.  


TopKSAE()

---
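### Collect the top-activating examples for each feature

For every SAE feature, keep the 10 PubMed abstracts (from the first 10,000 of the training split) that activate it most strongly.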

In [3]:
# Stream the training split lazily (an IterableDataset) instead of downloading it up front.
dataset = load_dataset(
    path=cfg["dataset_path"],
    name="document",
    split="train",
    streaming=True,
)

In [4]:
num_examples = int(1e4)   # abstracts scanned from the start of the stream
num_features = int(3072)  # SAE dictionary size (the "_3072_" in the artifact name)
TOP_K = 10                # top-activating abstracts kept per feature

# top_activations[j] is a min-heap of (activation, abstract) for feature j, so the
# weakest of the current TOP_K entries sits at heap[0] and is the one evicted.
top_activations = [[] for _ in range(num_features)]

# Inference only: no_grad avoids building an autograd graph through the SAE.
with torch.no_grad():
    for example in tqdm(dataset.take(num_examples), total=num_examples):
        text = example["abstract"]
        embedding = sbert.encode(text, convert_to_tensor=True).squeeze(0).to(cfg["device"])
        # One device->host copy per abstract instead of one `.item()` sync per feature.
        feature_acts = sae(embedding)["feature_acts"].detach().cpu().tolist()
        assert len(feature_acts) == num_features, (
            f"expected {num_features} features, got {len(feature_acts)}"
        )

        for heap, activation_value in zip(top_activations, feature_acts):
            if len(heap) < TOP_K:
                heapq.heappush(heap, (activation_value, text))
            else:
                heapq.heappushpop(heap, (activation_value, text))

# Highest activation first; JSON serialises each (activation, text) tuple as a list.
top_activations = [sorted(heap, key=lambda x: x[0], reverse=True) for heap in top_activations]
with open("top_activations.json", "w", encoding="utf-8") as f:
    json.dump(top_activations, f, indent=2, ensure_ascii=False)

---
### Prompt engineering

For each feature, build an LLM prompt from its max-activating and zero-activating abstracts.

In [4]:
num_features = int(3072)

# Reload the cached top activations so this section runs without re-scanning the
# dataset. JSON turns each (activation, abstract) tuple into a 2-element list.
with open("top_activations.json", "r", encoding="utf-8") as f:
    top_activations = json.load(f)
assert len(top_activations) == num_features, (
    f"expected {num_features} features, found {len(top_activations)}"
)

In [5]:
MAX_SCAN = 1000            # upper bound on abstracts scanned for negatives
INACTIVE_THRESHOLD = 0.02  # an activation below this counts as "zero-activating"
NUM_NEG = 5                # zero-activating examples collected per feature
NUM_POS = 5                # max-activating examples shown per feature (read by create_prompt)
NUM_FEATURES = 3072

with open("top_activations.json", "r", encoding="utf-8") as f:
    top_activations = json.load(f)

# negative_examples[j] collects up to NUM_NEG (activation, abstract) pairs on which
# feature j stayed below INACTIVE_THRESHOLD.
negative_examples = [[] for _ in range(NUM_FEATURES)]

# NOTE(review): the scan stops as soon as every feature has NUM_NEG negatives, which
# happened after only a handful of abstracts in the recorded run, so most features
# end up sharing the same few negatives. Consider drawing them from a shuffled stream.
with torch.no_grad():
    for example in tqdm(dataset.take(MAX_SCAN), total=MAX_SCAN):
        text = example["abstract"]
        embedding = sbert.encode(text, convert_to_tensor=True).squeeze(0).to(cfg["device"])
        # One device->host copy per abstract instead of one `.item()` sync per feature.
        feature_acts = sae(embedding)["feature_acts"].detach().cpu().tolist()

        for bucket, act in zip(negative_examples, feature_acts):
            if act < INACTIVE_THRESHOLD and len(bucket) < NUM_NEG:
                bucket.append((act, text))

        if all(len(bucket) >= NUM_NEG for bucket in negative_examples):
            break

6it [00:03,  2.00it/s]


In [6]:
# Prompt sent to the LLM for each feature. Placeholders filled by create_prompt:
#   {type}          - kind of researcher the LLM should role-play
#   {subject}       - domain of the corpus
#   {max_examples}  - formatted max-activating examples
#   {zero_examples} - formatted zero-activating examples
PROMPT_TEMPLATE = """You are a meticulous {type} researcher conducting an important investigation into a certain
neuron in a language model trained on {subject} papers. Your task is to figure out what
sort of behaviour this neuron is responsible for – namely, on what general concepts, features,
themes, methodologies or topics does this neuron fire? Here’s how you’ll complete the task:
INPUT DESCRIPTION: You will be given two inputs: 1) Max Activating Examples and 2)
Zero Activating Examples.
1. You will be given several examples of text that activate the neuron, along with a
number being how much it was activated. This means there is some feature, theme,
methodology, topic or concept in this text that ‘excites’ this neuron.
2. You will also be given several examples of text that don’t activate the neuron. This
means the feature, topic or concept is not present in these texts.
OUTPUT DESCRIPTION: Given the inputs provided, complete the following tasks.
1. Based on the MAX ACTIVATING EXAMPLES provided, write down potential topics,
concepts, themes, methodologies and features that they share in common. These
will need to be specific - remember, all of the text comes from {subject}, so these
need to be highly specific {subject} concepts. You may need to look at different
levels of granularity (i.e. subsets of a more general topic). List as many as you can
think of. Give higher weight to concepts more present/prominent in examples with
higher activations.
2. Based on the zero activating examples, rule out any of the topics/concepts/features
listed above that are in the zero-activating examples. Systematically go through your
list above.
3. Based on the above two steps, perform a thorough analysis of which feature, concept
or topic, at what level of granularity, is likely to activate this neuron. Use Occam’s
razor, as long as it fits the provided evidence. Be highly rational and analytical here.
4. Based on step 3, summarise this concept in 1-8 words, in the form FINAL:
<explanation>. Do NOT return anything after these 1-8 words.
Here are the max-activating examples:
{max_examples}

Here are the zero-activating examples:
{zero_examples}

Work through the steps thoroughly and analytically to interpret our neuron.
Complete the following statement in 1-8 words: The neuron is likely responsible for:"""

In [7]:
def create_prompt(feature_id, subject="biomedical", type="AI interpretability", rng=None):
    """Build the LLM interpretability prompt for a single SAE feature.

    Draws ``NUM_POS`` max-activating examples from ``top_activations[feature_id]``,
    uses the zero-activating examples in ``negative_examples[feature_id]``, and
    fills ``PROMPT_TEMPLATE``.

    Args:
        feature_id: Index into ``top_activations`` and ``negative_examples``.
        subject: Domain substituted for ``{subject}`` in the template.
        type: Researcher type substituted for ``{type}``. (Shadows the builtin;
            the name is kept because callers pass it by keyword.)
        rng: Optional ``random.Random`` used to sample positives; pass a seeded
            instance for reproducible prompts. Defaults to the global ``random``.

    Returns:
        The formatted prompt, or ``None`` (after printing a warning) when the
        feature lacks enough positive or negative examples.
    """
    sampler = random if rng is None else rng
    negatives = negative_examples[feature_id]

    # Keep only examples that actually fired, and never show the same abstract as
    # both a max-activating and a zero-activating example.
    negative_texts = {text for _, text in negatives}
    positive_activations = [
        (act, text)
        for act, text in top_activations[feature_id]
        if act > 0 and text not in negative_texts
    ]

    if len(positive_activations) < NUM_POS:
        print(f"❗ Feature {feature_id} does not have enough positive examples.")
        return None
    if len(negatives) < NUM_NEG:
        print(f"❗ Feature {feature_id} does not have enough negative examples.")
        return None

    # Sample only once both checks pass, and list the sample strongest-first so it
    # matches the template's instruction to weight higher activations more.
    positives = sorted(
        sampler.sample(positive_activations, k=NUM_POS),
        key=lambda pair: pair[0],
        reverse=True,
    )

    def format_examples(examples):
        return "\n\n".join(f"Activation: {act:.4f}\n{text}" for act, text in examples)

    return PROMPT_TEMPLATE.format(
        subject=subject,
        type=type,
        max_examples=format_examples(positives),
        zero_examples=format_examples(negatives),
    )

### FIXME: fix the JSON dump of the feature descriptions

In [20]:
from huggingface_hub import InferenceClient

DESCRIPTIONS_PATH = "feature_descriptions.json"
CHECKPOINT_EVERY = 100      # write partial results every N feature ids
MAX_CONSECUTIVE_ERRORS = 5  # stop instead of sending thousands of doomed requests (e.g. HTTP 402)

feature_descriptions = [None] * NUM_FEATURES


def _save_descriptions(descriptions, path=DESCRIPTIONS_PATH):
    """Write the (possibly partial) list of feature descriptions to ``path`` as JSON."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(descriptions, f, indent=2, ensure_ascii=False)


# Never hard-code tokens: read it from the environment. The token that was
# previously committed in this cell must be revoked on huggingface.co.
client = InferenceClient(
    provider="fireworks-ai",
    api_key=os.environ.get("HF_TOKEN"),
)

consecutive_errors = 0
for feature_id in tqdm(range(NUM_FEATURES)):
    # Checkpoint at the top of the loop so it happens even when this feature is skipped.
    if feature_id and feature_id % CHECKPOINT_EVERY == 0:
        _save_descriptions(feature_descriptions)

    prompt = create_prompt(feature_id, subject="biomedical", type="AI interpretability")
    if not prompt:
        continue

    try:
        completion = client.chat.completions.create(
            model="meta-llama/Llama-3.1-8B-Instruct",
            messages=[{"role": "user", "content": prompt}],
        )
    except Exception as e:
        print(f"Error processing feature {feature_id}: {e}")
        consecutive_errors += 1
        if consecutive_errors >= MAX_CONSECUTIVE_ERRORS:
            print(f"Aborting after {consecutive_errors} consecutive errors.")
            break
        continue

    consecutive_errors = 0
    feature_descriptions[feature_id] = completion.choices[0].message.content

# Always persist the final state; reading the file back here would discard every
# description produced after the last checkpoint.
_save_descriptions(feature_descriptions)
n_done = sum(d is not None for d in feature_descriptions)
print(f"{n_done}/{NUM_FEATURES} features described; saved to {DESCRIPTIONS_PATH}.")

Error processing feature 0: 402 Client Error: Payment Required for url: https://router.huggingface.co/fireworks-ai/v1/chat/completions (Request ID: Root=1-67f39f0e-2cfe50ae1a8b12e3538bc17c;9045ac10-e213-4cc3-af68-45b5bf55685a)

Pay-as-you go is not enabled for provider fireworks-ai yet.
Error processing feature 1: 402 Client Error: Payment Required for url: https://router.huggingface.co/fireworks-ai/v1/chat/completions (Request ID: Root=1-67f39f0f-514261c41b8865b32dbb9599;70089c8c-3eb7-4645-aafa-9b59fc9e91c5)

Pay-as-you go is not enabled for provider fireworks-ai yet.
Error processing feature 2: 402 Client Error: Payment Required for url: https://router.huggingface.co/fireworks-ai/v1/chat/completions (Request ID: Root=1-67f39f0f-0f1daa351fd1f7a170bded7a;1205fd4d-b8f5-456d-8bd1-87dbc0c95ce1)

Pay-as-you go is not enabled for provider fireworks-ai yet.
Error processing feature 3: 402 Client Error: Payment Required for url: https://router.huggingface.co/fireworks-ai/v1/chat/completions (

KeyboardInterrupt: 