## Probe Generalisation MVP

### Goals:
- [x] Choose a layer to train linear probes at
  - For now, we're doing layer 5, since that's the earliest layer that got perfect accuracy in the initial probe exploration
- [x] For each category in the 19th Feb dataset, train a linear probe
- [x] Generate a heatmap plot, where the (x, y)-th entry is the accuracy of the probe trained on x data, predicted on y data
- [x] Understand GPU capacity - can we do inference with 70B?
  - Yes, but it takes 2 A100 80GB GPUs when using bfloat16

### Timeline:
- 19/02/25 and 20/02/25

In [None]:
# IPython magic commands for autoreloading modules
%load_ext autoreload
%autoreload 2


# Imports
import torch
from models_under_pressure.probes import (
    create_activations,
    train_single_layer,
    compute_accuracy,
)

from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
import os
import pandas as pd
from pathlib import Path
from joblib import Parallel, delayed

import seaborn as sns
import matplotlib.pyplot as plt

# Absolute path of the parent of the current working directory; used as the
# base for the data and plot paths below.
# NOTE(review): assumes the notebook is run from a direct subdirectory of the
# project root — confirm.
project_root = Path("..").resolve()

In [None]:
# Load the JSONL dataset (one JSON record per line) into `df`.
# NOTE(review): the heading/variable say "20 Feb" but the file on disk is
# dataset_21_feb.jsonl — confirm which date is intended.
import json

# Pin the encoding so the read does not depend on the platform's locale
# default, and skip blank lines (e.g. a trailing empty line), which
# json.loads would reject.
with open(
    project_root / "temp_data/dataset_21_feb.jsonl", "r", encoding="utf-8"
) as f:
    feb_20_data = [json.loads(line) for line in f if line.strip()]
df = pd.DataFrame(feb_20_data)

# Alias the JSONL column names to the names the later cells read.
df["top_category"] = df["category"]
df["prompt_text"] = df["prompt"]

print(f"dataset shape: {df.shape}")

In [None]:
# Load the 19th Feb dataset from CSV.
# NOTE: this rebinds `df`; if the JSONL-loading cell was run first, the frame
# it produced is replaced by this one.
csv_path = project_root / "temp_data" / "dataset_19_feb.csv"
df = pd.read_csv(csv_path)

In [None]:
# Group the dataset by top-level category.
# `categories` maps each category name to {"X": prompt texts, "y": high_stakes
# labels}, with keys in order of first appearance in `df`.
top_category = df["top_category"]
categories = {
    name: {
        "X": df.loc[top_category == name, "prompt_text"].tolist(),
        "y": df.loc[top_category == name, "high_stakes"].tolist(),
    }
    for name in top_category.unique()
}

In [None]:
# Loading model

# Ask the Hugging Face tokenizers library not to use its own parallelism;
# set before the tokenizer is created.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
model_name = "meta-llama/Llama-3.3-70B-Instruct"
cache_dir = "/scratch/ucabwjn/.cache"
# Device string passed to create_activations in a later cell.
device = "cuda:0"

# Load the Llama-3.3-70B-Instruct model and tokenizer; device_map="auto"
# places the weights across the GPUs listed in max_memory.
# NOTE(review): max_memory lists GPUs 0, 1 and 3 but not 2 — confirm this is
# intentional and not a typo.
# NOTE(review): the goals section of this notebook mentions bfloat16, while
# float16 is requested here — confirm which dtype is intended.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    max_memory={0: "70GB", 1: "70GB", 3: "70GB"},
    torch_dtype=torch.float16,
    cache_dir=cache_dir,
)
# Use the same cache directory as the model so all downloaded files live in
# one place instead of the default cache location.
tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)

# Fall back to the EOS token for padding when the tokenizer defines no pad token.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

In [None]:
# Run the model over every category's prompts and store the returned
# activations next to that category's data under the "acts" key.

for category_data in categories.values():
    category_data["acts"] = create_activations(
        model=model,
        tokenizer=tokenizer,
        text=category_data["X"],
        device=device,
    )

In [None]:
# Train probes for every category, then score every (probe category, dataset
# category) pair. The result is the `accuracies` DataFrame with columns
# "probe", "dataset" and "accuracy", where "accuracy" is a list with one entry
# per element of that category's "acts" (indexed by layer in the plotting cell).


model_params = {"C": 1, "random_state": 42, "fit_intercept": False}

# One probe per element of categories[category]["acts"], trained in parallel.
for category in tqdm(categories, desc="Training probes for categories"):
    category_data = categories[category]
    acts_with_progress = tqdm(
        category_data["acts"],
        desc=f"Layer probes for {category}",
        leave=False,
    )
    category_data["probes"] = Parallel(n_jobs=16)(
        delayed(train_single_layer)(acts, category_data["y"], model_params)
        for acts in acts_with_progress
    )

# Evaluate each category's probes on each category's activations, pairing
# probes and activations positionally.
# NOTE(review): when probe and dataset category coincide, the probe is scored
# on the same activations it was fitted on (unless train_single_layer holds
# data out internally — verify), so the diagonal is not a held-out accuracy.
accuracies = pd.DataFrame(
    [
        {
            "probe": probe_category,
            "dataset": dataset_category,
            "accuracy": [
                compute_accuracy(
                    probe=probe,
                    activations=acts,
                    labels=categories[dataset_category]["y"],
                )
                for probe, acts in zip(
                    categories[probe_category]["probes"],
                    categories[dataset_category]["acts"],
                )
            ],
        }
        for probe_category in categories
        for dataset_category in categories
    ]
)

Notes:

- We need to fix y axis labels and make it clear which axis is probes and which is dataset
- We need a better colour scheme: red should be bad!
- Clearly communicate the experimental procedure
- Put all of this in a doc, leaving some time before tomorrow's meeting (20th Feb)


In [None]:
# Draw one probe-generalisation heatmap per selected layer and save each one
# to its own file under <project_root>/plots.
layers = [5, 22, 50, 79]

# Make sure the output directory exists before the first savefig call.
plots_dir = project_root / "plots"
plots_dir.mkdir(parents=True, exist_ok=True)

for layer in layers:
    # Pull this layer's accuracy out of each row's per-layer accuracy list.
    accuracies[f"layer_{layer}_accuracy"] = accuracies["accuracy"].apply(
        lambda x: x[layer]
    )
    # Pivot the accuracies DataFrame to create a matrix:
    # rows = dataset evaluated on, columns = category the probe was trained on.
    accuracy_matrix = accuracies.pivot(
        index="dataset", columns="probe", values=f"layer_{layer}_accuracy"
    )

    # Create the heatmap
    plt.figure(figsize=(8, 6))
    sns.heatmap(
        accuracy_matrix,
        annot=True,  # Show values in cells
        fmt=".3f",  # Format numbers to 3 decimal places
        cmap="RdBu",  # Red (0) to Blue (1)
        vmin=0,  # Force scale to start at 0
        vmax=1,  # Force scale to end at 1
    )

    plt.title(f"Probe Generalization Across Categories, layer: {layer}")
    plt.xlabel("Probe Category")
    plt.ylabel("Test Dataset")
    plt.xticks(rotation=45, ha="right")
    plt.tight_layout()

    # Include the layer in the filename; a single fixed name would be
    # overwritten on every iteration, leaving only the last layer's plot.
    plt.savefig(plots_dir / f"probe_generalisation_heatmap_layer_{layer}.png")
    plt.show()


### Generate Completions On-Policy

To ensure the quality of the dataset, we analyse the completions that the model we
are probing generates for the prompts. Our goal is to ensure the dataset actually
captures high-stakes situations, rather than some confounding factor (e.g. the type
of prompt making the model likely to generate the word "wait!").

In [None]:
from models_under_pressure.utils import generate_completions

# Generate completions for every prompt with the probed model and store them
# as a new "completions" column on `df`.
# NOTE(review): this reads the "prompt" column, which the JSONL-loading cell
# provides; the CSV-loaded frame is only shown being read via "prompt_text" —
# verify it also has "prompt" before running this on the CSV data.
prompts = df["prompt"].tolist()
df["completions"] = generate_completions(model, tokenizer, prompts)

In [None]:
# Print a small, reproducible random sample of prompt/completion pairs for
# manual inspection.
pd.set_option("display.max_colwidth", None)  # Do not truncate long text
sample_size = 5
sample_df = df[["prompt", "completions"]].sample(n=sample_size, random_state=42)

print("Sample of Prompt-Completion Pairs:\n")
separator = "-" * 80 + "\n"
# Each tuple is (row index, prompt, completion), in the column order selected above.
for idx, prompt, completion in sample_df.itertuples(index=True, name=None):
    print(f"Prompt {idx}:\n{prompt}\n")
    print(f"Completion:\n{completion}\n")
    print(separator)


In [None]:
# Write `df` (including the "completions" column) to CSV under temp_data,
# without the index column, and report where it was written.
output_path = project_root / "temp_data" / "dataset_21-02-2025_completions.csv"
df.to_csv(output_path, index=False)
print(f"Saved dataset with completions to {output_path}")
