## Installation and Imports

In [None]:
# Install required libraries
!pip install transformers datasets sae_lens transformer_lens torch numpy matplotlib seaborn wordcloud

Collecting datasets
  Downloading datasets-3.5.0-py3-none-any.whl.metadata (19 kB)
Collecting sae_lens
  Downloading sae_lens-5.9.1-py3-none-any.whl.metadata (5.3 kB)
Collecting transformer_lens
  Downloading transformer_lens-2.15.0-py3-none-any.whl.metadata (12 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.12.0,>=2023.1.0 (from fsspec[http]<=2024.12.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.12.0-py3-none-any.whl.metadata (11 kB)
Collecting automated-interpretability<1.0.0,>=0.0.5 (from sae_lens)
  Downloading automated_interpretability-0.0.8-py3-none-any.whl.metadata (822 bytes)
Collecting babe<0.0.8,>=0.0.7 (from sae_lens)
  Downloading b

In [None]:
# Import necessary modules
import torch
from datasets import load_dataset
from torch.cuda.amp import autocast
from sae_lens import SAE
from transformer_lens import HookedTransformer
import json
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from wordcloud import WordCloud
from tqdm import tqdm
import torch.nn.functional as F
import gc
import pandas as pd
import re
import matplotlib.pyplot as plt
from collections import Counter
from wordcloud import WordCloud


# Pick the compute device: the GPU when CUDA is usable, otherwise the CPU
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

Using device: cuda


## Load Model

In [None]:

def load_model(model_name='meta-llama/Meta-Llama-3-8B-Instruct', dtype=torch.float16):
    """Load a pretrained checkpoint into a TransformerLens ``HookedTransformer``.

    Args:
        model_name: Identifier passed to ``HookedTransformer.from_pretrained``.
            Defaults to the Llama-3 8B Instruct checkpoint used in this notebook.
        dtype: Parameter dtype. Defaults to ``torch.float16``, i.e. half
            precision (this is not quantization), chosen because of the VRAM limit.

    Returns:
        The loaded model, placed on the notebook-level ``device``.
    """
    return HookedTransformer.from_pretrained(
        model_name=model_name,
        device=device,
        dtype=dtype,
    )


hooked_model = load_model()

config.json:   0%|          | 0.00/654 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/23.9k [00:00<?, ?B/s]

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

model-00004-of-00004.safetensors:   0%|          | 0.00/1.17G [00:00<?, ?B/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

model-00002-of-00004.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00003-of-00004.safetensors:   0%|          | 0.00/4.92G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/187 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/51.0k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/73.0 [00:00<?, ?B/s]



Loaded pretrained model meta-llama/Meta-Llama-3-8B-Instruct into HookedTransformer


## Dataset and Preprocessing

In [None]:
# Load Harry Potter Book 1 text. The encoding is explicit because the book text
# contains non-ASCII punctuation (curly quotes), so the read must not depend on
# the platform default.
with open("/content/drive/MyDrive/Unlearning/Harry_Potter_Book1.txt", "r", encoding="utf-8") as f:
    hp_text = f.read()

# Load the raw keyword file; the cleaning below treats every line as one keyword.
with open("/content/drive/MyDrive/Unlearning/keywords.txt", "r", encoding="utf-8") as f:
    raw_keyword_lines = f.read().split('\n')

# Strip surrounding whitespace BEFORE removing a trailing period, so a line such
# as "word. " loses its period too, then drop entries that end up empty.
hp_keywords = [
    kw
    for kw in (re.sub(r'\.$', '', line.strip()).strip() for line in raw_keyword_lines)
    if kw
]


#  Function to remove words that appear in Harry Potter book
def remove_hp_words(example, keywords=None):
    """Drop every whitespace-separated word of ``example["text"]`` that is a keyword.

    Matching is case-insensitive on both sides: each word and each keyword is
    lower-cased before comparison, so a keyword stored as "Harry" removes
    "Harry", "harry" and "HARRY" alike. Only whole whitespace-delimited tokens
    are compared, so multi-word keywords never match.

    Args:
        example: Mapping with a "text" entry (one dataset record).
        keywords: Optional iterable of keywords to remove. Defaults to the
            notebook-level ``hp_keywords`` list.

    Returns:
        dict: ``{"text": cleaned_text}`` with the surviving words joined by
        single spaces.
    """
    if keywords is None:
        keywords = hp_keywords
    # A lower-cased set gives constant-time membership tests per word.
    blocked = {kw.lower() for kw in keywords}
    kept_words = [word for word in example["text"].split() if word.lower() not in blocked]
    return {"text": " ".join(kept_words)}

# WikiText-2 (raw variant, test split) serves as the general-text corpus
wikitext2 = load_dataset(path="wikitext", name="wikitext-2-raw-v1", split="test")

# Copy of WikiText-2 with the Harry Potter keywords filtered out of every record
wikitext2_clean = wikitext2.map(function=remove_hp_words)

# Cut the book at blank lines and keep each non-blank chunk as a {"text": ...}
# record, mirroring the record layout of the WikiText datasets
hp_chunks = hp_text.split("\n\n")
hp_dataset = []
for chunk in hp_chunks:
    if chunk.strip():
        hp_dataset.append({"text": chunk})

Downloading readme:   0%|          | 0.00/10.5k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/733k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.36M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/657k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/4358 [00:00<?, ? examples/s]

Generating train split:   0%|          | 0/36718 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/3760 [00:00<?, ? examples/s]

Map:   0%|          | 0/4358 [00:00<?, ? examples/s]

In [None]:
# Count case-insensitive whole-word occurrences of each keyword in the book.
# The text is lower-cased once (not once per keyword), and the keyword is
# lower-cased as well so that capitalised keywords can actually match.
hp_text_lower = hp_text.lower()
keyword_counts = Counter()
for kw in set(hp_keywords):
    kw = kw.strip()
    if not kw:
        # An empty pattern would match at every word boundary.
        continue
    pattern = r'\b' + re.escape(kw.lower()) + r'\b'
    count = len(re.findall(pattern, hp_text_lower))
    if count > 0:
        keyword_counts[kw] = count

In [None]:
# Show one example record from each corpus
wikitext_sample = wikitext2_clean[55]
print("Sample from wikitext2_clean:", wikitext_sample)
hp_sample = hp_dataset[65]
print("Sample from hp_dataset:", hp_sample)

Sample from wikitext2_clean: {'text': 'Brooding on what I have lived through , if even I know such suffering , the common man must surely be rattled by the winds .'}
Sample from hp_dataset: {'text': '“It’s lucky it’s dark. I haven’t blushed so much since Madam Pomfrey told me she liked my new earmuffs.”'}


In [None]:
# Build the small evaluation subsets: the leading records of each corpus,
# materialised as plain lists of {"text": ...} records.
SMALL_SUBSET_SIZE = 1000


def _first_records(dataset, limit=SMALL_SUBSET_SIZE):
    """Return the first `limit` records of `dataset` as a list.

    The count is capped at ``len(dataset)`` so a corpus shorter than `limit`
    yields all of its records instead of raising IndexError.
    """
    return [dataset[i] for i in range(min(limit, len(dataset)))]


wikitext2_small = _first_records(wikitext2)
wikitext2_cleaned_small = _first_records(wikitext2_clean)
hp_dataset_small = _first_records(hp_dataset)

In [None]:

# Persist each small subset to Drive, one JSON-encoded record per line:
# raw WikiText, keyword-filtered WikiText, then Harry Potter.
small_subset_outputs = (
    ('/content/drive/MyDrive/Unlearning/wikitext2_small.txt', wikitext2_small),
    ('/content/drive/MyDrive/Unlearning/wikitext2_small_cleaned.txt', wikitext2_cleaned_small),
    ('/content/drive/MyDrive/Unlearning/hp_dataset_small.txt', hp_dataset_small),
)
for output_path, records in small_subset_outputs:
    with open(output_path, 'w') as f:
        for item in records:
            f.write(json.dumps(item) + '\n')


## Perplexity Function

In [None]:
# Function to compute perplexity
def compute_perplexity(dataset, model, device, max_samples=None):
    """Compute token-level perplexity of `model` over `dataset`.

    Every sample is tokenized with ``model.to_tokens``, scored with next-token
    cross-entropy, and the summed loss is divided by the total number of
    predicted tokens before exponentiating.

    Args:
        dataset: Iterable of samples; each is either a dict with a "text" key
            or a plain string. Samples that are empty after stripping are skipped.
        model: Object exposing ``to_tokens(text)`` and callable on the token
            tensor, returning logits indexed as (batch, position, vocab).
        device: Device the tokens are moved to before the forward pass.
        max_samples: If given, only the first `max_samples` items of `dataset`
            are considered (skipped empty items count toward the limit).
            ``None`` evaluates the whole dataset.

    Returns:
        float: exp(mean next-token loss), or ``inf`` when no token was scored.
    """
    total_loss = 0.0
    total_tokens = 0
    print("\n")
    for index, sample in enumerate(tqdm(dataset, desc="Evaluating perplexity")):
        if max_samples is not None and index >= max_samples:
            break
        text = sample["text"].strip() if isinstance(sample, dict) else sample.strip()
        if not text:
            continue
        tokens = model.to_tokens(text).to(device)
        with torch.no_grad():
            logits = model(tokens)
        # Position t predicts token t+1, hence the one-step shift.
        # The logits are upcast so the summed loss is not rounded (or
        # overflowed) in half precision when the model runs in float16.
        shift_logits = logits[:, :-1, :].float()
        shift_labels = tokens[:, 1:]
        loss = F.cross_entropy(
            shift_logits.reshape(-1, shift_logits.size(-1)),
            shift_labels.reshape(-1),
            reduction="sum"
        )
        total_loss += loss.item()
        total_tokens += shift_labels.numel()
    avg_loss = total_loss / total_tokens if total_tokens > 0 else float('inf')
    perplexity = torch.exp(torch.tensor(avg_loss))
    return perplexity.item()


## Baseline Results

In [None]:
# Baseline perplexities of the unmodified model on the three small corpora:
# raw WikiText-2, keyword-filtered WikiText-2, and Harry Potter Book 1.
baseline_runs = (
    ("Baseline 0 - Perplexity on WikiText-2", wikitext2_small),
    ("Baseline 1 - Perplexity on cleaned WikiText-2", wikitext2_cleaned_small),
    ("Baseline 2 - Perplexity on Harry Potter Book 1", hp_dataset_small),
)
baseline_scores = []
for baseline_label, baseline_data in baseline_runs:
    baseline_score = compute_perplexity(baseline_data, hooked_model, device)
    print(f"\n{baseline_label}: {baseline_score:.2f}")
    baseline_scores.append(baseline_score)
baseline0, baseline1, baseline2 = baseline_scores





Evaluating perplexity: 100%|██████████| 1000/1000 [01:06<00:00, 15.01it/s]



Baseline 0 - Perplexity on WikiText-2: 15.70




Evaluating perplexity: 100%|██████████| 1000/1000 [01:06<00:00, 15.01it/s]



Baseline 1 - Perplexity on cleaned WikiText-2: 16.01




Evaluating perplexity: 100%|██████████| 1000/1000 [01:36<00:00, 10.35it/s]


Baseline 2 - Perplexity on Harry Potter Book 1: 24.58





## Finding top features

In [None]:
def compute_hidden_states(texts, model, layer, max_tokens=512):
    """Collect the ``blocks.{layer}.hook_resid_post`` activations for each text.

    Args:
        texts: Iterable of strings to run through the model.
        model: Object exposing ``to_tokens(text)`` and
            ``run_with_cache(tokens, names_filter=...)``.
        layer: Index of the block whose ``hook_resid_post`` activation is cached.
        max_tokens: Each tokenized text is truncated to this many positions.

    Returns:
        list: One CPU tensor per text, holding the cached activation with its
        leading (batch) dimension squeezed out.
    """
    hook_name = f"blocks.{layer}.hook_resid_post"
    hidden_states = []
    for text in tqdm(texts, desc=f"Computing hidden states for layer {layer}"):
        tokens = model.to_tokens(text)[:, :max_tokens].to(device)
        # torch.amp.autocast("cuda") replaces the deprecated
        # torch.cuda.amp.autocast(), which emits a FutureWarning.
        with torch.no_grad(), torch.amp.autocast("cuda"):
            _, cache = model.run_with_cache(
                tokens,
                names_filter=[hook_name]
            )
        hidden_state = cache[hook_name].squeeze(0).cpu()  # Move to CPU
        del cache  # Free memory immediately
        hidden_states.append(hidden_state)
        torch.cuda.empty_cache()
        gc.collect()
    return hidden_states

# Function to load SAE
def load_sae(layer, device, release="llama-3-8b-it-res-jh"):
    """Fetch the pretrained SAE for ``blocks.{layer}.hook_resid_post``.

    Loading is best-effort: any exception is reported on stdout and turned
    into a ``None`` result so callers can skip the layer.

    Args:
        layer: Block index used to build the SAE id.
        device: Device passed through to ``SAE.from_pretrained``.
        release: SAE release name passed through to ``SAE.from_pretrained``.

    Returns:
        The first element of the tuple returned by ``SAE.from_pretrained``,
        or ``None`` when loading failed.
    """
    print(f"Loading SAE for layer {layer}")
    try:
        sae, _, _ = SAE.from_pretrained(
            release=release,
            sae_id=f"blocks.{layer}.hook_resid_post",
            device=device
        )
    except Exception as e:
        print(f"Error loading SAE for layer {layer}: {e}")
        return None
    print(f"Successfully loaded SAE for layer {layer}")
    return sae

# Function to compute activations from hidden states
def compute_activations_from_hidden_states(hidden_states, sae, device):
    """Average SAE feature activations over every token of every hidden state.

    Args:
        hidden_states: Iterable of tensors, one per text; each is encoded by
            ``sae.encode`` and summed over its first dimension (one row per token).
        sae: Object exposing ``cfg.d_sae`` and ``encode(hidden_state)``.
        device: Device on which encoding and accumulation happen.

    Returns:
        A float32 tensor of length ``sae.cfg.d_sae`` holding the mean
        activation per feature, or ``None`` when no token was processed.
    """
    # Accumulate in float32: a float16 running sum over thousands of tokens
    # overflows past 65504 and loses precision long before that.
    total_sum = torch.zeros(sae.cfg.d_sae, device=device, dtype=torch.float32)
    total_tokens = 0
    for hidden_state in tqdm(hidden_states, desc="Computing activations"):
        hidden_state = hidden_state.to(device)
        # torch.amp.autocast("cuda") replaces the deprecated
        # torch.cuda.amp.autocast(), which emits a FutureWarning.
        with torch.no_grad(), torch.amp.autocast("cuda"):
            feature_acts = sae.encode(hidden_state)
        total_sum += feature_acts.float().sum(dim=0)
        total_tokens += feature_acts.size(0)
        del hidden_state, feature_acts
        torch.cuda.empty_cache()
        gc.collect()
    if total_tokens > 0:
        return total_sum / total_tokens
    return None

In [None]:
import random

desired_layers = [25]
N_TEXTS_PER_CORPUS = 50   # size of the text window sampled from each corpus
TOP_K_FEATURES = 100      # number of features kept per layer
SAMPLE_SEED = 0           # fixed seed so the sampled window is reproducible
top_features = {}

# random.randint is inclusive at both ends, so the upper bound is capped to
# keep the whole window inside both corpora (otherwise the slices can come
# back short or empty). A private Random instance leaves the global RNG alone.
n_available = min(len(hp_dataset_small), len(wikitext2_cleaned_small))
window = min(N_TEXTS_PER_CORPUS, n_available)
start = random.Random(SAMPLE_SEED).randint(0, n_available - window)
hp_texts = [sample["text"] for sample in hp_dataset_small[start:start + window]]
general_texts = [sample["text"] for sample in wikitext2_cleaned_small[start:start + window]]

# Phase 1: collect hidden states for EVERY requested layer while the model is
# still loaded. Deleting the model inside the per-layer loop would make any
# second layer fail with a NameError.
hidden_states_by_layer = {}
for layer in desired_layers:
    hidden_states_by_layer[layer] = (
        compute_hidden_states(hp_texts, hooked_model, layer),
        compute_hidden_states(general_texts, hooked_model, layer),
    )

# Free the model once, before any SAE is loaded onto the GPU.
del hooked_model
torch.cuda.empty_cache()
gc.collect()
print("Model deleted, GPU memory freed.")

# Phase 2: per layer, load the SAE and rank features by how much more they
# activate on Harry Potter text than on general text.
for layer in desired_layers:
    hp_hidden_states, general_hidden_states = hidden_states_by_layer.pop(layer)

    # Load SAE
    sae = load_sae(layer, device)
    if sae is None:
        print(f"Skipping layer {layer} due to SAE loading failure")
        continue

    # Compute activations
    hp_activations = compute_activations_from_hidden_states(hp_hidden_states, sae, device)
    general_activations = compute_activations_from_hidden_states(general_hidden_states, sae, device)

    # Identify top features
    if hp_activations is not None and general_activations is not None:
        diff = hp_activations - general_activations
        # topk raises when k exceeds the number of features.
        k = min(TOP_K_FEATURES, diff.numel())
        top_indices = torch.topk(diff, k=k).indices.tolist()
        top_features[layer] = top_indices
        print(f"Layer {layer}: Identified top features: {top_indices}")
    else:
        print(f"Layer {layer}: Failed to compute activations.")

    # Both activation names are always bound here, so no locals() check is needed.
    del sae, hp_hidden_states, general_hidden_states, hp_activations, general_activations
    torch.cuda.empty_cache()
    gc.collect()

print("Top features identified:", top_features)

  with torch.no_grad(), autocast():
Computing hidden states for layer 25: 100%|██████████| 50/50 [00:23<00:00,  2.09it/s]
Computing hidden states for layer 25: 100%|██████████| 50/50 [00:24<00:00,  2.07it/s]


Model deleted, GPU memory freed.
Loading SAE for layer 25
Successfully loaded SAE for layer 25


  with torch.no_grad(), autocast():
Computing activations: 100%|██████████| 50/50 [00:18<00:00,  2.63it/s]
Computing activations: 100%|██████████| 50/50 [00:18<00:00,  2.71it/s]


Layer 25: Identified top features: [7876, 12514, 9332, 18988, 58311, 47207, 3643, 47319, 19832, 12814, 63905, 20059, 63799, 15726, 12730, 48321, 3179, 61272, 28003, 50813, 21980, 42532, 4732, 29043, 55481, 57173, 314, 54788, 24410, 26690, 32606, 47134, 57657, 47409, 31755, 15844, 48053, 31033, 20573, 9915, 27894, 18040, 44808, 51957, 59725, 36552, 1029, 34547, 35358, 30919, 3653, 62273, 34353, 35291, 43255, 56642, 822, 64211, 4552, 37775, 24626, 57607, 59145, 39788, 65010, 34273, 46896, 59055, 38109, 57204, 29394, 40120, 14103, 21223, 57542, 27424, 1187, 7917, 18324, 12885, 833, 17933, 19338, 48288, 23439, 53126, 2364, 5519, 26705, 56648, 28053, 23802, 18779, 61454, 21716, 26784, 13277, 24266, 55500, 22707]
Top features identified: {25: [7876, 12514, 9332, 18988, 58311, 47207, 3643, 47319, 19832, 12814, 63905, 20059, 63799, 15726, 12730, 48321, 3179, 61272, 28003, 50813, 21980, 42532, 4732, 29043, 55481, 57173, 314, 54788, 24410, 26690, 32606, 47134, 57657, 47409, 31755, 15844, 48053, 

In [None]:
# Persist the selected feature indices. JSON object keys are strings, so the
# integer layer keys are written as strings.
top_features_path = "/content/drive/MyDrive/Unlearning/top_features_llama_8b_instruct_1.json"
with open(top_features_path, "w") as f:
    json.dump(top_features, f)
# Report the path that was actually written (the message used to name a
# different file).
print(f"Top features saved to '{top_features_path}'.")

Top features saved to 'top_features_llama.json'.


## Hooked Model

In [None]:
# Reload the model: the feature-selection cell ran `del hooked_model` to free
# GPU memory before loading the SAE, so the name must be bound again here.
hooked_model = load_model()

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]



Loaded pretrained model meta-llama/Meta-Llama-3-8B-Instruct into HookedTransformer


In [None]:
def compute_perplexity_with_hooks(dataset, model, device, hooks=None):
    """Compute token-level perplexity while forward hooks are active.

    Args:
        dataset: Iterable of samples; each is either a dict with a "text" key
            or a plain string (same convention as ``compute_perplexity``).
            Samples that are empty after stripping are skipped.
        model: Object exposing ``to_tokens(text)``, the context manager
            ``hooks(fwd_hooks=...)``, and callable on the token tensor to
            return logits indexed as (batch, position, vocab).
        device: Device the tokens are moved to before the forward pass.
        hooks: Optional list of forward hooks handed to ``model.hooks``;
            ``None`` (or an empty list) runs the model unmodified.

    Returns:
        float: exp(mean next-token loss), or ``inf`` when no token was scored.
    """
    total_loss = 0.0
    total_tokens = 0
    hooks = hooks if hooks else []
    for sample in tqdm(dataset, desc="Computing perplexity"):
        text = sample["text"].strip() if isinstance(sample, dict) else sample.strip()
        if not text:
            continue
        tokens = model.to_tokens(text).to(device)
        with model.hooks(fwd_hooks=hooks):
            with torch.no_grad():
                logits = model(tokens)
        # Position t predicts token t+1, hence the one-step shift.
        # The logits are upcast so the summed loss is not rounded (or
        # overflowed) in half precision when the model runs in float16.
        shift_logits = logits[:, :-1, :].float()
        shift_labels = tokens[:, 1:]
        loss = F.cross_entropy(
            shift_logits.reshape(-1, shift_logits.size(-1)),
            shift_labels.reshape(-1),
            reduction="sum"
        )
        total_loss += loss.item()
        total_tokens += shift_labels.numel()
    avg_loss = total_loss / total_tokens if total_tokens > 0 else float('inf')
    return torch.exp(torch.tensor(avg_loss)).item()



In [None]:
# Release cached GPU memory and run a garbage-collection pass before the
# ablation runs; the cell output is gc.collect()'s return value.
torch.cuda.empty_cache()
gc.collect()

60

In [None]:
# Define layers with successful SAE loading
layers_with_sae = list(top_features.keys())

# Compute perplexity with ablation for each layer
for layer in layers_with_sae:
    sae = load_sae(layer, device)

    if sae is not None:
        top_feats = top_features[layer]
        hook_name = f"blocks.{layer}.hook_resid_post"
        # Build the index tensor once per layer instead of on every forward pass.
        selected_features = torch.tensor(top_feats, device=device)

        # Ablation hook for this layer. The default arguments bind the current
        # layer's objects at definition time, so the hook does not depend on
        # the loop variables still holding the same values when it runs.
        # NOTE(review): the hook returns the SAE reconstruction, so the measured
        # perplexity also includes whatever the SAE fails to reconstruct; a run
        # with no features zeroed would separate that from the ablation effect.
        def ablate_hook(hidden_state, hook, sae=sae,
                        selected_features=selected_features, hook_name=hook_name):
            if hook.name != hook_name:
                return hidden_state
            batch, seq_len, d_model = hidden_state.shape
            hidden_state_flat = hidden_state.reshape(batch * seq_len, d_model)
            feature_acts = sae.encode(hidden_state_flat)
            feature_acts[:, selected_features] = 0  # Ablate top features
            modified_hidden_state_flat = sae.decode(feature_acts)
            return modified_hidden_state_flat.reshape(batch, seq_len, d_model)

        # The same hook list serves both evaluation runs.
        ablation_hooks = [(hook_name, ablate_hook)]

        # Compute perplexity with ablation
        ppl_with_ablation_wikitext = compute_perplexity_with_hooks(
            wikitext2_small, hooked_model, device,
            hooks=ablation_hooks,
        )
        ppl_with_ablation_hp = compute_perplexity_with_hooks(
            hp_dataset_small, hooked_model, device,
            hooks=ablation_hooks,
        )
        print(f"Layer {layer} - With ablation - WikiText-2: {ppl_with_ablation_wikitext:.2f}, Harry Potter: {ppl_with_ablation_hp:.2f}")

        # Clean up
        del sae
        torch.cuda.empty_cache()
        gc.collect()



Loading SAE for layer 25
Successfully loaded SAE for layer 25


Computing perplexity: 100%|██████████| 1000/1000 [01:22<00:00, 12.17it/s]
Computing perplexity: 100%|██████████| 1000/1000 [01:55<00:00,  8.65it/s]


Layer 25 - With ablation - WikiText-2: 42.30, Harry Potter: 53.85


## Evaluation

In [None]:
# Prompts that probe Harry Potter knowledge. Every question ends with "?" so
# the model is not left to complete the punctuation itself.
hp_prompts = [
    "Who is Harry Potter?",
    "Who is Ron Weasley?",
    "Tell me about Harry Potter's adventures at Hogwarts.",
    "What is the significance of the Sorting Hat in Hogwarts?"
]
# Control prompts on unrelated topics.
general_prompts = [
    "Explain the history of the Roman Empire.",
    "What are the benefits of a healthy diet?",
    "Describe a typical day in a modern city."
]

# Function to generate text
def generate_text(model, prompt, hooks=None, max_new_tokens=50, temperature=0.7):
    """Generate a continuation of `prompt`, optionally with forward hooks active.

    Args:
        model: Object exposing ``to_tokens``, ``hooks(fwd_hooks=...)``,
            ``generate`` and ``to_string``.
        prompt: Text to continue.
        hooks: Optional list of forward hooks; ``None`` or empty means none.
        max_new_tokens: Passed through to ``model.generate``.
        temperature: Passed through to ``model.generate``.

    Returns:
        The first element of ``model.to_string`` applied to the generated tokens.
    """
    active_hooks = hooks or []
    prompt_tokens = model.to_tokens(prompt).to(device)
    with model.hooks(fwd_hooks=active_hooks):
        generated_tokens = model.generate(
            prompt_tokens,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            verbose=False,
        )
    decoded = model.to_string(generated_tokens)
    return decoded[0]

# Reference generations from the unmodified model, for every prompt
print("\n=== Text Generation: No Ablation ===")
all_prompts = [*hp_prompts, *general_prompts]
for prompt in all_prompts:
    generated = generate_text(hooked_model, prompt, hooks=None)
    print(f"Prompt: {prompt}\nGenerated: {generated}\n")

# Generate text with ablation for each layer
for layer in layers_with_sae:
    sae = load_sae(layer, device)
    if sae is not None:
        top_feats = top_features[layer]
        hook_name = f"blocks.{layer}.hook_resid_post"
        # Build the index tensor once per layer; during generation the hook
        # runs on every forward pass, so rebuilding it there is wasted work.
        selected_features = torch.tensor(top_feats, device=device)

        # Ablation hook for this layer. The default arguments bind the current
        # layer's objects at definition time, so the hook does not depend on
        # the loop variables still holding the same values when it runs.
        def ablate_hook(hidden_state, hook, sae=sae,
                        selected_features=selected_features, hook_name=hook_name):
            if hook.name != hook_name:
                return hidden_state
            batch, seq_len, d_model = hidden_state.shape
            hidden_state_flat = hidden_state.reshape(batch * seq_len, d_model)
            feature_acts = sae.encode(hidden_state_flat)
            feature_acts[:, selected_features] = 0  # Ablate top features
            modified_hidden_state_flat = sae.decode(feature_acts)
            return modified_hidden_state_flat.reshape(batch, seq_len, d_model)

        # The same hook list is reused for every prompt.
        ablation_hooks = [(hook_name, ablate_hook)]

        print(f"\n=== Text Generation: Ablation for Layer {layer} ===")
        for prompt in hp_prompts + general_prompts:
            generated = generate_text(
                hooked_model, prompt,
                hooks=ablation_hooks
            )
            print(f"Prompt: {prompt}\nGenerated: {generated}\n")

        # Clean up
        del sae
        torch.cuda.empty_cache()
        gc.collect()


=== Text Generation: No Ablation ===
Prompt: Who is Harry Potter?
Generated: <|begin_of_text|>Who is Harry Potter? Harry Potter is a fictional character in a series of fantasy novels by J.K. Rowling. He is the main protagonist of the series and is known for his bravery, loyalty and determination to fight against the dark lord Voldemort. Harry is a half-blood wizard

Prompt: Who is Ron Weasley
Generated: <|begin_of_text|>Who is Ron Weasley? - Quora
Ron Weasley is a fictional character in the Harry Potter book series by J.K. Rowling. He is one of the main characters in the series and is a close friend of Harry Potter, the main protagonist. Ron is a

Prompt: Tell me about Harry Potter's adventures at Hogwarts.
Generated: <|begin_of_text|>Tell me about Harry Potter's adventures at Hogwarts. What are some of the most memorable moments in the series?
Harry Potter's adventures at Hogwarts School of Witchcraft and Wizardry are the heart of the beloved book series by J.K. Rowling. The series f