In [None]:
# Install required libraries
!pip install transformers datasets sae_lens transformer_lens torch numpy matplotlib seaborn wordcloud

Collecting datasets
  Downloading datasets-3.5.0-py3-none-any.whl.metadata (19 kB)
Collecting sae_lens
  Downloading sae_lens-5.9.1-py3-none-any.whl.metadata (5.3 kB)
Collecting transformer_lens
  Downloading transformer_lens-2.15.0-py3-none-any.whl.metadata (12 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.12.0,>=2023.1.0 (from fsspec[http]<=2024.12.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.12.0-py3-none-any.whl.metadata (11 kB)
Collecting automated-interpretability<1.0.0,>=0.0.5 (from sae_lens)
  Downloading automated_interpretability-0.0.8-py3-none-any.whl.metadata (822 bytes)
Collecting babe<0.0.8,>=0.0.7 (from sae_lens)
  Downloading b

In [None]:
import torch

from datasets import load_dataset
from torch.cuda.amp import autocast
from sae_lens import SAE
from transformer_lens import HookedTransformer
import json
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from wordcloud import WordCloud
from tqdm import tqdm
import torch.nn.functional as F
import gc


# Set device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

Using device: cuda


In [None]:

model_name = 'meta-llama/Meta-Llama-3-8B-Instruct'

# Convert to HookedTransformer for hooking capabilities
hooked_model = HookedTransformer.from_pretrained(
    model_name = model_name,
    device=device,
    dtype=torch.float16,
)



Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]



Loaded pretrained model meta-llama/Meta-Llama-3-8B-Instruct into HookedTransformer


In [None]:
# Load WikiText-2 dataset
wikitext2 = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")

# Load Harry Potter Book 1 

hp_book_path = "/content/drive/MyDrive/Unlearning/Harry_Potter_Book1.txt"  
hp_book = load_dataset("text", data_files={"train": hp_book_path})["train"]

# Function to clean WikiText-2 by removing Harry Potter mentions
def clean_wikitext(sample):
    text = sample["text"].lower()
    if "harry potter" in text:
        return {"text": ""}
    return sample

# Clean the dataset and filter out empty entries
wikitext2_cleaned = wikitext2.map(clean_wikitext)
wikitext2_cleaned = wikitext2_cleaned.filter(lambda x: x["text"].strip() != "")

# Take smaller subsets for efficiency (adjust as needed)
wikitext2_small = wikitext2.select(range(1000))
wikitext2_cleaned_small = wikitext2_cleaned.select(range(1000))
hp_dataset_small = hp_book.select(range(1000))

print("Datasets loaded and cleaned.")

Datasets loaded and cleaned.


In [None]:
def compute_hidden_states(texts, model, layer, max_tokens=512):
    hidden_states = []
    for text in tqdm(texts, desc=f"Computing hidden states for layer {layer}"):
        tokens = model.to_tokens(text)[:, :max_tokens].to(device)
        with torch.no_grad(), autocast():
            _, cache = model.run_with_cache(
                tokens,
                names_filter=[f"blocks.{layer}.hook_resid_post"]
            )
        hidden_state = cache[f"blocks.{layer}.hook_resid_post"].squeeze(0).cpu()  # Move to CPU
        del cache  # Free memory immediately
        hidden_states.append(hidden_state)
        torch.cuda.empty_cache()
        gc.collect()
    return hidden_states

# Function to load SAE
def load_sae(layer, device, release="llama-3-8b-it-res-jh"):
    print(f"Loading SAE for layer {layer}")
    try:
        sae, _, _ = SAE.from_pretrained(
            release=release,
            sae_id=f"blocks.{layer}.hook_resid_post",
            device=device
        )
        print(f"Successfully loaded SAE for layer {layer}")
        return sae
    except Exception as e:
        print(f"Error loading SAE for layer {layer}: {e}")
        return None

# Function to compute activations from hidden states
def compute_activations_from_hidden_states(hidden_states, sae, device):
    total_sum = torch.zeros(sae.cfg.d_sae, device=device, dtype=torch.float16)
    total_tokens = 0
    for hidden_state in tqdm(hidden_states, desc="Computing activations"):
        hidden_state = hidden_state.to(device)
        with torch.no_grad(), autocast():
            feature_acts = sae.encode(hidden_state)  # Shape: (num_tokens, d_sae)
        total_sum += feature_acts.sum(dim=0)
        total_tokens += feature_acts.size(0)
        del hidden_state, feature_acts  # Free memory
        torch.cuda.empty_cache()
        gc.collect()
    if total_tokens > 0:
        return total_sum / total_tokens
    return None

In [None]:
# Main computation
desired_layers = [25]
top_features = {}
hp_texts = [sample["text"] for sample in hp_dataset_small.select(range(50))]
general_texts = [sample["text"] for sample in wikitext2_cleaned_small.select(range(50))]

for layer in desired_layers:
    # Step 1: Compute hidden states
    hp_hidden_states = compute_hidden_states(hp_texts, hooked_model, layer)
    general_hidden_states = compute_hidden_states(general_texts, hooked_model, layer)

    # Step 2: Free GPU memory by deleting the model
    del hooked_model
    torch.cuda.empty_cache()
    gc.collect()
    print("Model deleted, GPU memory freed.")

    # Step 3: Load SAE
    sae = load_sae(layer, device)
    if sae is None:
        print(f"Skipping layer {layer} due to SAE loading failure")
        continue

    # Step 4: Compute activations
    hp_activations = compute_activations_from_hidden_states(hp_hidden_states, sae, device)
    general_activations = compute_activations_from_hidden_states(general_hidden_states, sae, device)

    # Step 5: Identify top features
    if hp_activations is not None and general_activations is not None:
        diff = hp_activations - general_activations
        top_indices = torch.topk(diff, k=5).indices.tolist()
        top_features[layer] = top_indices
        print(f"Layer {layer}: Identified top features: {top_indices}")
    else:
        print(f"Layer {layer}: Failed to compute activations.")

    # Step 6: Clean up
    del sae, hp_hidden_states, general_hidden_states
    if 'hp_activations' in locals():
        del hp_activations
    if 'general_activations' in locals():
        del general_activations
    torch.cuda.empty_cache()
    gc.collect()

# Output results
print("Top features identified:", top_features)

  with torch.no_grad(), autocast():
Computing hidden states for layer 25: 100%|██████████| 50/50 [00:22<00:00,  2.23it/s]
Computing hidden states for layer 25: 100%|██████████| 50/50 [00:22<00:00,  2.24it/s]


Model deleted, GPU memory freed.
Loading SAE for layer 25




Successfully loaded SAE for layer 25


  with torch.no_grad(), autocast():
Computing activations: 100%|██████████| 50/50 [00:16<00:00,  2.96it/s]
Computing activations: 100%|██████████| 50/50 [00:17<00:00,  2.93it/s]


Layer 25: Identified top features: [63905, 7876, 7754, 30919, 3643]
Top features identified: {25: [63905, 7876, 7754, 30919, 3643]}


In [None]:
# Save top features to a file
with open("/content/drive/MyDrive/Unlearning/top_features_llama_8b_instruct.json", "w") as f:
    json.dump(top_features, f)
print("Top features saved to 'top_features_llama.json'.")



Top features saved to 'top_features_llama.json'.


In [None]:
# Function to compute perplexity
def compute_perplexity(dataset, model, device, hooks=None, max_samples=100):
    total_loss = 0.0
    total_tokens = 0
    hooks = hooks if hooks else []
    for sample in tqdm(dataset.select(range(max_samples)), desc="Computing perplexity"):
        text = sample["text"].strip()
        if not text:
            continue
        tokens = model.to_tokens(text).to(device)
        with model.hooks(fwd_hooks=hooks):
            with torch.no_grad():
                logits = model(tokens)
        shift_logits = logits[:, :-1, :]
        shift_labels = tokens[:, 1:]
        loss = F.cross_entropy(
            shift_logits.reshape(-1, shift_logits.size(-1)),
            shift_labels.reshape(-1),
            reduction="sum"
        )
        total_loss += loss.item()
        total_tokens += shift_labels.numel()
    avg_loss = total_loss / total_tokens if total_tokens > 0 else float('inf')
    return torch.exp(torch.tensor(avg_loss)).item()

# Compute perplexity without ablation
ppl_no_ablation_wikitext = compute_perplexity(wikitext2_small, hooked_model, device, max_samples=100)
ppl_no_ablation_hp = compute_perplexity(hp_dataset_small, hooked_model, device, max_samples=100)
print(f"No ablation - WikiText-2: {ppl_no_ablation_wikitext:.2f}, Harry Potter: {ppl_no_ablation_hp:.2f}")


Computing perplexity: 100%|██████████| 100/100 [00:08<00:00, 12.09it/s]
Computing perplexity: 100%|██████████| 100/100 [00:05<00:00, 18.63it/s]

No ablation - WikiText-2: 20.48, Harry Potter: 10.65





100

In [None]:
layers_with_sae = list(top_features.keys())

# Compute perplexity with ablation for each layer
for layer in layers_with_sae:
    sae = load_sae(layer, device)
    if sae is not None:
        top_feats = top_features[layer]
        top_feats = top_feats[:5]
        # Define ablation hook for this layer
        def ablate_hook(hidden_state, hook):
            if hook.name == f"blocks.{layer}.hook_resid_pre":
                batch, seq_len, d_model = hidden_state.shape
                hidden_state_flat = hidden_state.view(batch * seq_len, d_model)
                feature_acts = sae.encode(hidden_state_flat)
                selected_features = torch.tensor(top_feats, device=device)
                feature_acts[:, selected_features] = 0  # Ablate top features
                modified_hidden_state_flat = sae.decode(feature_acts)
                return modified_hidden_state_flat.view(batch, seq_len, d_model)
            return hidden_state

        # Compute perplexity with ablation
        ppl_with_ablation_wikitext = compute_perplexity(
            wikitext2_small, hooked_model, device,
            hooks=[(f"blocks.{layer}.hook_resid_pre", ablate_hook)],
            max_samples=100
        )
        ppl_with_ablation_hp = compute_perplexity(
            hp_dataset_small, hooked_model, device,
            hooks=[(f"blocks.{layer}.hook_resid_pre", ablate_hook)],
            max_samples=100
        )
        print(f"Layer {layer} - With ablation - WikiText-2: {ppl_with_ablation_wikitext:.2f}, Harry Potter: {ppl_with_ablation_hp:.2f}")

        # Clean up
        del sae
        torch.cuda.empty_cache()
        gc.collect()



Loading SAE for layer 25
Successfully loaded SAE for layer 25


Computing perplexity: 100%|██████████| 100/100 [00:08<00:00, 11.43it/s]
Computing perplexity: 100%|██████████| 100/100 [00:05<00:00, 17.49it/s]


Layer 25 - With ablation - WikiText-2: 63.32, Harry Potter: 60.64


In [None]:

#  sample prompts for text generation
hp_prompts = [
    "Who is Harry Potter?",
    "Who is Ron Weasley",
    "Tell me about Harry Potter's adventures at Hogwarts.",
    "What is the significance of the Sorting Hat in Hogwarts?"
]
general_prompts = [
    "Explain the history of the Roman Empire.",
    "What are the benefits of a healthy diet?",
    "Describe a typical day in a modern city."
]

# Function to generate text
def generate_text(model, prompt, hooks=None, max_new_tokens=50, temperature=0.7):
    tokens = model.to_tokens(prompt).to(device)
    hooks = hooks if hooks else []
    with model.hooks(fwd_hooks=hooks):
        generated_tokens = model.generate(
            tokens,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            verbose=False
        )
    return model.to_string(generated_tokens)[0]

# Generate text without ablation
print("\n=== Text Generation: No Ablation ===")
for prompt in hp_prompts + general_prompts:
    generated = generate_text(hooked_model, prompt, hooks=None)
    print(f"Prompt: {prompt}\nGenerated: {generated}\n")

# Generate text with ablation for each layer
for layer in layers_with_sae:
    sae = load_sae(layer, device)
    if sae is not None:
        top_feats = top_features[layer]
        top_feats = top_feats[:5]
        # Define ablation hook for this layer
        def ablate_hook(hidden_state, hook):
            if hook.name == f"blocks.{layer}.hook_resid_pre":
                batch, seq_len, d_model = hidden_state.shape
                hidden_state_flat = hidden_state.view(batch * seq_len, d_model)
                feature_acts = sae.encode(hidden_state_flat)
                selected_features = torch.tensor(top_feats, device=device)
                feature_acts[:, selected_features] = 0  # Ablate top features
                modified_hidden_state_flat = sae.decode(feature_acts)
                return modified_hidden_state_flat.view(batch, seq_len, d_model)
            return hidden_state

        print(f"\n=== Text Generation: Ablation for Layer {layer} ===")
        for prompt in hp_prompts + general_prompts:
            generated = generate_text(
                hooked_model, prompt,
                hooks=[(f"blocks.{layer}.hook_resid_pre", ablate_hook)]
            )
            print(f"Prompt: {prompt}\nGenerated: {generated}\n")

        # Clean up
        del sae
        torch.cuda.empty_cache()
        gc.collect()


=== Text Generation: No Ablation ===
Prompt: Who is Harry Potter?
Generated: <|begin_of_text|>Who is Harry Potter? Harry Potter is a young wizard who is the main character in a series of fantasy novels by J.K. Rowling. The books follow Harry's adventures at Hogwarts School of Witchcraft and Wizardry, where he makes friends and battles against the dark wizard,

Prompt: Who is Ron Weasley
Generated: <|begin_of_text|>Who is Ron Weasley?
Ron Weasley is a fictional character in the Harry Potter book series by J.K. Rowling. He is one of the best friends of the main protagonist, Harry Potter, and a member of Gryffindor House at Hogwarts School of Witchcraft

Prompt: Tell me about Harry Potter's adventures at Hogwarts.
Generated: <|begin_of_text|>Tell me about Harry Potter's adventures at Hogwarts. Did you enjoy reading the books or watching the movies?
I loved reading the Harry Potter books! The series is incredibly magical and immersive, and J.K. Rowling's writing is phenomenal. The books a