# Linear Probing

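This notebook demonstrates how to use linear probing with `tdhook` to train classifiers on a model's internal representations. It fits linear probes on the MLP modules of GPT-2 blocks 0, 5 and 10 to predict the sentiment of IMDB reviews, then reports the probes' metrics on a held-out test set.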

## Setup

In [None]:
import importlib.util

# When running on Colab, DEV selects between the "colab-dev" and "colab" modes.
DEV = True

# find_spec() imports the parent package of a dotted module name, so it raises
# ModuleNotFoundError (rather than returning None) when no `google` package is
# installed at all. Treat that case as "not running on Colab".
try:
    colab_spec = importlib.util.find_spec("google.colab")
except ModuleNotFoundError:
    colab_spec = None

if colab_spec is not None:
    MODE = "colab-dev" if DEV else "colab"
else:
    MODE = "local"

In [None]:
# Install tdhook according to MODE; any other MODE value (e.g. "local")
# matches neither branch and installs nothing.
if MODE == "colab":
    # Install the `tdhook` package by name (-q keeps pip's output quiet).
    %pip install -q tdhook
elif MODE == "colab-dev":
    # Development install: remove any existing ./tdhook directory, clone the
    # `main` branch from GitHub, then install from that local clone.
    !rm -rf tdhook
    !git clone https://github.com/Xmaster6y/tdhook -b main
    %pip install -q ./tdhook

## Usage

Load model and prepare data

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM
from tensordict import TensorDict
from datasets import load_dataset
import numpy as np
import torch

from tdhook.latent.probing import Probing, ProbeManager, LinearEstimator

In [None]:
# Reuse the end-of-sequence token as the padding token so that
# `tokenizer.pad_token_id` is available when sequences are padded.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

# eval() puts the module in evaluation mode and returns the module itself,
# so it can be chained directly onto the load call.
model = AutoModelForCausalLM.from_pretrained("gpt2").eval()

In [None]:
# Load the reviews and draw two disjoint subsets from a single seeded shuffle.
dataset = load_dataset("scikit-learn/imdb", split="train")

num_train = 100
num_test = 10

shuffled_dataset = dataset.shuffle(seed=42)
train_data = shuffled_dataset.select(range(0, num_train))
test_data = shuffled_dataset.select(range(num_train, num_train + num_test))

# Raw review texts for each split.
train_texts = [row["review"] for row in train_data]
test_texts = [row["review"] for row in test_data]

# Binary targets: 1 when the sentiment is "positive", 0 for anything else.
train_labels = [int(row["sentiment"] == "positive") for row in train_data]
test_labels = [int(row["sentiment"] == "positive") for row in test_data]

# Tokenize each review on its own, truncating to at most 512 tokens.
train_encoded = [tokenizer.encode(review, truncation=True, max_length=512) for review in train_texts]
test_encoded = [tokenizer.encode(review, truncation=True, max_length=512) for review in test_texts]

# Right-pad every sequence to the longest one across both splits so the
# train and test tensors share the same width.
pad_token_id = tokenizer.pad_token_id
max_len = max(map(len, train_encoded + test_encoded))

train_input_ids = torch.tensor([ids + [pad_token_id] * (max_len - len(ids)) for ids in train_encoded])
test_input_ids = torch.tensor([ids + [pad_token_id] * (max_len - len(ids)) for ids in test_encoded])

Set up linear probing

In [None]:
def compute_metrics(preds, labels):
    """Compute the accuracy of predictions against reference labels.

    Args:
        preds: Predicted values, either a tensor-like object exposing
            ``.cpu()`` or anything ``np.asarray`` accepts.
        labels: Reference values, in the same forms as ``preds``.

    Returns:
        A dict ``{"accuracy": float}`` holding the fraction of positions where
        ``preds`` and ``labels`` compare equal element-wise.
    """

    def _as_numpy(values):
        """Convert a tensor-like object or an array-like to a NumPy array."""
        if not hasattr(values, "cpu"):
            return np.asarray(values)
        # A tensor that still tracks gradients cannot be converted with
        # .numpy(); detach it first so such inputs do not raise.
        if hasattr(values, "detach"):
            values = values.detach()
        return values.cpu().numpy()

    return {"accuracy": float((_as_numpy(preds) == _as_numpy(labels)).mean())}


# Probe manager configured to build `LinearEstimator` probes and to score
# them with `compute_metrics`.
# NOTE(review): d_latent=768 presumably matches the width of the probed GPT-2
# activations, and num_classes=2 the binary sentiment labels — confirm.
manager = ProbeManager(
    estimator_class=LinearEstimator,
    estimator_kwargs=dict(d_latent=768, num_classes=2, epochs=50, verbose=False),
    compute_metrics=compute_metrics,
    allow_overwrite=True,
)

Train the probes on the training data and evaluate them on the test data

In [None]:
# Batches fed through the hooked model. Besides the token ids, each one
# carries "labels" and "step_type", the two keys listed in `additional_keys`.
train_td = TensorDict(
    {
        "input_ids": train_input_ids,
        "labels": torch.tensor(train_labels),
        "step_type": "fit",
    },
    batch_size=len(train_texts),
)
test_td = TensorDict(
    {
        "input_ids": test_input_ids,
        "labels": torch.tensor(test_labels),
        "step_type": "predict",
    },
    batch_size=len(test_texts),
)

# NOTE(review): the pattern presumably is a regex over module names that
# targets the `mlp` submodule of blocks 0, 5 and 10 — confirm against tdhook.
probing = Probing(
    "transformer.h.(0|5|10).mlp$",
    manager.probe_factory,
    additional_keys=["labels", "step_type"],
)

with probing.prepare(model, in_keys=["input_ids"], out_keys=["logits"]) as hooked_model, torch.no_grad():
    # Run the batch tagged "fit" before the one tagged "predict".
    train_out = hooked_model(train_td)
    test_out = hooked_model(test_td)

Display probe metrics

In [None]:
def _show_metrics(heading, metrics):
    """Print `heading`, then one indented `key: value` line per entry of `metrics`."""
    print(heading)
    for name, score in metrics.items():
        print(f"  {name}: {score}")


_show_metrics("Training metrics:", manager.fit_metrics)
_show_metrics("\nTest metrics:", manager.predict_metrics)