## Figure Out How to Implement Per-Token Probes


#### Goal 
- Produce per-token probe plots that visualise how the probe's 'high-stakes' score varies across the words (tokens) of an input.
- Work out what this script needs to look like, and apply it to an eval dataset from Figure 2 or to a split by variation type.

#### Timeline

Created: 06/03/25 
Target completion: 07/03/25 (to send to supervisors)


# Current Issue

A problem in how the probe results are batched causes everything to come out as 228 (presumably every per-token prediction array is padded to the same length, 228).

In [None]:
# imports
import json
from typing import Dict

import pandas as pd
import torch
from transformers import AutoTokenizer

from models_under_pressure.config import (
    EVAL_DATASETS,
    SYNTHETIC_DATASET_PATH,
    TRAIN_TEST_SPLIT,
)
from models_under_pressure.experiments.dataset_splitting import load_train_test
from models_under_pressure.interfaces.dataset import Dataset
from models_under_pressure.probes.model import LLMModel
from models_under_pressure.probes.probes import LinearProbe

In [None]:
# Load the model and the training split of the synthetic dataset.
MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
DEVICE = "cuda:2"  # single place to change the GPU used by this notebook

model = LLMModel.load(model_name=MODEL_NAME, device=DEVICE)

# Load the training split (the test split returned alongside it is unused here).
train_dataset, _ = load_train_test(
    dataset_path=SYNTHETIC_DATASET_PATH,
    split_path=TRAIN_TEST_SPLIT,
)

# NOTE(review): `LLMModel.load` already receives `device`; this re-assignment
# is kept from the original — confirm whether it is still needed.
model.device = DEVICE
print("Model Device:", model.device)
print("Model Internal Device:", model.model.device)

In [None]:
# Analyse the dataset used to train the probe, then restrict it to the
# "prompt_style" variation type (the subset the probe is trained on).
# Each dataset is converted to pandas once rather than once per print.
train_df = train_dataset.to_pandas()
print(f"Train dataset shape: {train_df.shape}")
print(f"Train dataset variation types: {train_df['variation_type'].unique()}")

filtered_dataset = train_dataset.filter(
    lambda x: x.other_fields["variation_type"] == "prompt_style"
)
filtered_df = filtered_dataset.to_pandas()
print(f"Filtered dataset shape: {filtered_df.shape}")
print(f"Filtered dataset variation types: {filtered_df['variation_type'].unique()}")

filtered_df.head()


In [None]:
# Train a linear probe at layer 11 of the model on the "prompt_style" subset.
# Reuses `filtered_dataset` from the previous cell instead of repeating the
# filter predicate, so the inspected subset and the training subset can't drift.
PROBE_LAYER = 11
probe = LinearProbe(_llm=model, layer=PROBE_LAYER)
probe.fit(filtered_dataset)

In [None]:
# Load the eval dataset of interest: the "anthropic" entry of EVAL_DATASETS.
anthropic_config = EVAL_DATASETS["anthropic"]
eval_dataset = Dataset.load_from(
    anthropic_config["file_path"],
    field_mapping=anthropic_config["field_mapping"],
)

# Print the dataset's shape and preview its first rows (convert to pandas once):
anthropic_eval_df = eval_dataset.to_pandas()
print(f"Eval dataset shape: {anthropic_eval_df.shape}")
anthropic_eval_df.head()

In [None]:
# Run the probe token-by-token over every eval input, then attach one
# prediction entry per row as a new "predictions" column of the eval frame.
preds = probe.per_token_predictions(eval_dataset.inputs)
df_eval = eval_dataset.to_pandas().assign(predictions=list(preds))

In [None]:
# Drop the -1 entries from each per-token prediction array and store the rest
# as a JSON list string so it survives the CSV round trip.
# NOTE(review): -1 is presumably the padding value used when batching — confirm.
df_eval["probe_probs"] = (
    df_eval["predictions"].apply(lambda x: x[x != -1].tolist()).apply(json.dumps)
)

df_eval.to_csv("../data/evals/anthropic_samples_per_token.csv")

# Keep the preview as the cell's last expression so Jupyter actually renders it
# (previously it was followed by `to_csv`, so it was never displayed).
df_eval.head()

In [None]:
# Reload the saved per-token results, decoding each JSON-encoded
# "probe_probs" cell back into a Python list while parsing.
df = pd.read_csv(
    "../data/evals/anthropic_samples_per_token.csv",
    converters={"probe_probs": json.loads},
)
print(df.loc[0, "probe_probs"])

In [None]:
# Tokenizer for the same model the probe runs on, used to line up per-token
# probe probabilities with the tokens they belong to.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")


def tokenize_input(input_str: str) -> Dict[str, torch.Tensor]:
    """Render a JSON-encoded chat with the chat template, then tokenize it.

    Args:
        input_str: JSON string that decodes to the chat messages passed to
            ``tokenizer.apply_chat_template``.

    Returns:
        The tokenizer output (including ``input_ids``) as PyTorch tensors with
        a leading batch dimension.
    """
    chat_text = tokenizer.apply_chat_template(
        json.loads(input_str), tokenize=False, add_generation_prompt=True
    )
    # The rendered template text already contains the BOS token; letting the
    # tokenizer add special tokens again would prepend a second BOS.
    return tokenizer(chat_text, return_tensors="pt", add_special_tokens=False)


# Look at the first raw chat input next to its probe probabilities.
chat_inputs = df["inputs"].apply(json.loads)
print(chat_inputs[0])
print(df["probe_probs"][0])

# Check the consistency of the probe probabilities and tokenized inputs.
# `tokenized_inputs[0]` is a dict-like encoding; len() of it would count keys,
# so measure the token dimension of `input_ids` instead.
tokenized_inputs = df["inputs"].apply(tokenize_input)
n_tokens = tokenized_inputs[0]["input_ids"].shape[-1]
print(f"Length of tokenized inputs: {n_tokens}")
print(f"Length of probe probs: {len(df['probe_probs'][0])}")

In [None]:
# Inspect the first tokenized input: its tensor shape, then the decoded text
# with special tokens kept, to see exactly which tokens were produced.
first_input_ids = tokenized_inputs[0]["input_ids"]
print(first_input_ids.shape)
print(tokenizer.decode(first_input_ids[0], skip_special_tokens=False))

In [None]:
# Alternative path: let `apply_chat_template` tokenize directly (tokenize=True)
# rather than rendering text and tokenizing it separately as `tokenize_input`
# does; [0] drops the leading batch dimension.
def _template_ids(raw_json):
    messages = json.loads(raw_json)
    batch_ids = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
    )
    return batch_ids[0]


tokenized_inputs2 = df["inputs"].apply(_template_ids)
first_template_ids = tokenized_inputs2[0]
print(first_template_ids.shape)
print(tokenizer.decode(first_template_ids.tolist(), skip_special_tokens=False))

In [None]:
# Raw per-token probe output for the eval row labelled 0, before the -1
# entries are stripped when building the "probe_probs" column.
df_eval["predictions"][0]