# Bilinear Probing

This notebook demonstrates how to use Bilinear Probing to train bilinear classifiers on model representations from two layers.

## Setup

In [None]:
import importlib.util

# Set to False to install the released package instead of the `main` branch
# when running on Colab. Has no effect for local runs.
DEV = True


def module_available(name):
    """Return True if module `name` can be located, False otherwise.

    `importlib.util.find_spec` imports the parent package of a dotted name, so
    asking for "google.colab" on a machine without any `google` package raises
    ModuleNotFoundError instead of returning None. That case (and a parent
    package that fails to import, or a module with no `__spec__`) is treated
    as "not available" so the check never crashes the notebook.
    """
    try:
        return importlib.util.find_spec(name) is not None
    except (ImportError, ValueError):
        return False


# MODE drives the install cell: "colab" / "colab-dev" on Colab, "local" elsewhere.
if module_available("google.colab"):
    MODE = "colab-dev" if DEV else "colab"
else:
    MODE = "local"

In [None]:
# Install tdhook on Colab only; nothing is installed when MODE == "local".
if MODE == "colab":
    # Released package.
    %pip install -q tdhook
elif MODE == "colab-dev":
    # Development install: remove any previous checkout, clone the `main`
    # branch, and install from the local clone.
    !rm -rf tdhook
    !git clone https://github.com/Xmaster6y/tdhook -b main
    %pip install -q ./tdhook

## Imports

In [None]:
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM
from tensordict import TensorDict
from datasets import load_dataset
import torch

from tdhook.latent.probing import Probing, BilinearProbeManager, LowRankBilinearEstimator

## Load Model and Data

In [None]:
# Tokenizer: reuse the EOS token as the pad token so that `pad_token_id` is
# defined for the manual padding done when the data is prepared.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

# Model in inference mode (`.eval()` returns the module itself).
model = AutoModelForCausalLM.from_pretrained("gpt2").eval()

In [None]:
dataset = load_dataset("scikit-learn/imdb", split="train")

# Small, disjoint train/test subsets taken from one seeded shuffle.
num_train = 100
num_test = 20

shuffled_dataset = dataset.shuffle(seed=42)
train_data = shuffled_dataset.select(range(num_train))
test_data = shuffled_dataset.select(range(num_train, num_train + num_test))

train_texts = [item["review"] for item in train_data]
train_labels = [1 if item["sentiment"] == "positive" else 0 for item in train_data]

test_texts = [item["review"] for item in test_data]
test_labels = [1 if item["sentiment"] == "positive" else 0 for item in test_data]

# Token-id lists, truncated to at most 512 tokens per review.
train_encoded = [tokenizer.encode(text, max_length=512, truncation=True) for text in train_texts]
test_encoded = [tokenizer.encode(text, max_length=512, truncation=True) for text in test_texts]

# Pad both splits to the same length so they share one tensor width.
max_len = max(len(seq) for seq in train_encoded + test_encoded)
pad_token_id = tokenizer.pad_token_id


def left_pad(sequences, length, pad_id):
    """Left-pad every token-id list to `length` and stack them into one tensor.

    Padding goes on the LEFT because the probe reads the activation at
    position -1: with right-padding that position would be a pad token for
    every review shorter than `length`, not the review's last real token.
    """
    return torch.tensor([[pad_id] * (length - len(seq)) + seq for seq in sequences])


# NOTE(review): only `input_ids` is fed to the model later on, so no attention
# mask is applied and the leading pad tokens are still attended to. Left-padding
# only guarantees that the last position holds a real token.
train_input_ids = left_pad(train_encoded, max_len, pad_token_id)
test_input_ids = left_pad(test_encoded, max_len, pad_token_id)

## Set Up Bilinear Probing

Configure `BilinearProbeManager` to probe the pair of modules `transformer.h.0` and `transformer.h.5`. For causal LMs we use the activation at the last token position of each sequence.

In [None]:
def preprocess_last_token(data):
    """Reduce an activation tensor to one flat feature vector per sample.

    The tensor is detached from the autograd graph. If it has more than two
    dimensions, index -1 along dimension 1 is selected (the last sequence
    position for a (batch, seq, hidden) input). All dimensions after the
    first are then flattened, so the result is 2-D.
    """
    activations = data.detach()
    selected = activations[:, -1, :] if activations.dim() > 2 else activations
    return selected.flatten(start_dim=1)


def compute_metrics(preds, labels):
    """Return `{"accuracy": ...}`: the fraction of predictions equal to the labels.

    Each argument may be anything exposing `.cpu()` (e.g. a torch tensor),
    which is moved to CPU and converted with `.numpy()`, or any array-like
    accepted by `np.asarray`.
    """

    def _as_numpy(values):
        # Tensor-like objects go through .cpu().numpy(); the rest through numpy.
        if hasattr(values, "cpu"):
            return values.cpu().numpy()
        return np.asarray(values)

    matches = _as_numpy(preds) == _as_numpy(labels)
    return {"accuracy": float(matches.mean())}


# Module pairs whose activations are probed jointly.
probe_pairs = [("transformer.h.0", "transformer.h.5")]

# Keyword arguments passed to the estimator class by the manager.
# 768 is the hidden size of the `gpt2` checkpoint; two classes for the
# binary sentiment labels.
estimator_config = {
    "d_latent1": 768,
    "d_latent2": 768,
    "num_classes": 2,
    "epochs": 100,
    "lr": 1e-3,
    "batch_size": 32,
    "verbose": False,
}

manager = BilinearProbeManager(
    pairs=probe_pairs,
    estimator_class=LowRankBilinearEstimator,
    estimator_kwargs=estimator_config,
    data_preprocess_callback=preprocess_last_token,
    compute_metrics=compute_metrics,
    allow_overwrite=True,
)

## Train and Evaluate

Run one forward pass over the training data with `step_type="fit"` to fit the probe, then one over the test data with `step_type="predict"` to evaluate it.

In [None]:
def make_batch(input_ids, labels, step_type):
    """Bundle token ids, labels and the probing step type into one TensorDict.

    `step_type` is stored as a plain string next to the tensors; the batch
    size is the number of rows in `input_ids`.
    """
    return TensorDict(
        {
            "input_ids": input_ids,
            "labels": torch.tensor(labels),
            "step_type": step_type,
        },
        batch_size=input_ids.shape[0],
    )


manager.before_all()
probing = Probing(
    manager.key_pattern,
    manager.probe_factory,
    additional_keys=["labels", "step_type"],
    relative=False,
)
# No gradients are needed through the language model: it is only run forward
# while the probing hooks are active.
with probing.prepare(model, in_keys=["input_ids"], out_keys=["logits"]) as hooked_model, torch.no_grad():
    # Training split, tagged "fit".
    train_td = make_batch(train_input_ids, train_labels, "fit")
    hooked_model(train_td)

    # Test split, tagged "predict".
    test_td = make_batch(test_input_ids, test_labels, "predict")
    hooked_model(test_td)
manager.after_all()

# Report the metrics collected by the manager for each phase.
for metric_name, metric_value in manager.fit_metrics.items():
    print(f"Train {metric_name}: {metric_value}")
for metric_name, metric_value in manager.predict_metrics.items():
    print(f"Test {metric_name}: {metric_value}")