# Eliciting latent knowledge - experiments

Alexander Cai, Gabriel Wu, Max Nadeau

Some experiments initially performed for [CS 229br Foundations of Deep Learning](https://boazbk.github.io/mltheoryseminar/) as taught in Spring 2023 at Harvard University by Boaz Barak.

This is a work-in-progress research draft exploring some properties of the "Contrast-Consistent Search" (CCS) algorithm (and later algorithms)
for identifying a language model's internal representation of truth. Such a representation could help identify model misbehaviours
or "elicit latent knowledge" from intelligent models.

Our research questions:

- Does the "direction" discovered by CCS carry any semantic meaning outside its original setting (i.e. the residual stream on the final "positive" / "negative" token)?

See [adzcai/llama-ccs](https://github.com/adzcai/llama-ccs) for some preliminary experiments on the Meta LLaMA models.

Currently the [EleutherAI/elk](https://github.com/EleutherAI/elk) package must be installed in editable mode in order to download the associated prompt templates. Also note that it requires Python 3.10, which (at the time of writing) is not supported on Google Colab.

To do this, make sure the desired environment is activated, navigate to a convenient directory, and run the following.

In [None]:
# ! git clone https://github.com/EleutherAI/elk.git
# ! cd elk && pip install -qe .

In [None]:
# install the remaining requirements

# ! pip install -q \
#     circuitsvis \
#     plotly \
#     git+https://github.com/neelnanda-io/TransformerLens.git

## Resources

[EleutherAI/elk](https://github.com/EleutherAI/elk): Contains many further innovations on top of CCS. Very convenient tool for interacting with HF models and datasets.

[Discovering Latent Knowledge in Language Models Without Supervision](https://arxiv.org/abs/2212.03827): The original paper by Collin Burns, Haotian Ye, et al. that proposes "Contrast-Consistent Search" (CCS).
- [collin-burns/discovering_latent_knowledge](https://github.com/collin-burns/discovering_latent_knowledge): The corresponding repository.
  - This is claimed to be quite buggy. See [Bugs of the Initial Release of CCS](https://docs.google.com/document/d/16Q8ZJFloA-x2lR65hs80rbbjX70TteCSMhuDQGcC75Q/edit?usp=sharing) by Fabien Roger.
- [How "Discovering Latent Knowledge in Language Models Without Supervision" Fits Into a Broader Alignment Scheme](https://www.lesswrong.com/posts/L4anhrxjv8j2yRKKp/how-discovering-latent-knowledge-in-language-models-without)

[What Discovering Latent Knowledge Did and Did Not Find](https://www.lesswrong.com/posts/bWxNPMy5MhPnQTzKz/what-discovering-latent-knowledge-did-and-did-not-find-4): A writeup by Fabien Roger on takeaways from the original paper.

- [safer-ai/Exhaustive-CCS](https://github.com/safer-ai/Exhaustive-CCS): The corresponding repository. Similar to Collin Burns's but with fewer bugs.
- [Several experiments with CCS.](https://docs.google.com/document/d/1LCjjnUPN51gHl_rmCWEmmtbY-Wu1dixzOif14e-7i-U/edit)

## Getting started

In [None]:
import os
from pathlib import Path

cwd = Path(os.getcwd())
data_path = cwd / "data"
reporters_path = cwd / "reporters"

In [None]:
# optionally store downloaded data in this folder instead of the default HF cache
use_data_dir = True

if use_data_dir:
    data_path.mkdir(parents=True, exist_ok=True)
    os.environ["HF_HOME"] = data_path.as_posix()

Here we elicit latent knowledge from the [Pythia](https://github.com/EleutherAI/pythia) model family from EleutherAI.
We use the _non-deduplicated_ versions of the 1B and 1.4B parameter models, as of 17 April 2023.

These models are notable in that every model in the family is trained on the same data in the same order.
A [paper](https://arxiv.org/pdf/2304.01373.pdf) with detailed information about these models is also available.

These models are also available for use with [TransformerLens](https://github.com/neelnanda-io/TransformerLens/blob/main/transformer_lens/model_properties_table.md).

We use the [SuperGLUE (BoolQ)](https://huggingface.co/datasets/super_glue/viewer/boolq/test) dataset for a QA task and the [IMDB](https://huggingface.co/datasets/imdb) dataset for sentiment analysis.

We chose these datasets for preliminary analysis since they're simple archetypes for their respective tasks.

In [None]:
# ! elk elicit EleutherAI/pythia-1b 'super_glue boolq' --net ccs --out_dir 'reporters/ccs/pythia-1b/super_glue boolq'

# ! elk elicit EleutherAI/pythia-1b 'super_glue boolq' --net eigen --out_dir 'reporters/eigen/pythia-1b/super_glue boolq'

# ! elk elicit EleutherAI/pythia-1b imdb --net ccs --out_dir reporters/ccs/pythia-1b/imdb

# ! elk elicit EleutherAI/pythia-1b imdb --net eigen --out_dir reporters/eigen/pythia-1b/imdb

# ! elk elicit EleutherAI/pythia-1.4b 'super_glue boolq' --net ccs --out_dir reporters/ccs/pythia-1.4b/'super_glue boolq'

# ! elk elicit EleutherAI/pythia-1.4b 'super_glue boolq' --net eigen --out_dir reporters/eigen/pythia-1.4b/'super_glue boolq'

# ! elk elicit EleutherAI/pythia-1.4b imdb --net ccs --out_dir reporters/ccs/pythia-1.4b/imdb

# ! elk elicit EleutherAI/pythia-1.4b imdb --net eigen --out_dir reporters/eigen/pythia-1.4b/imdb

## Load the learned directions

In [None]:
import torch

# disable gradients since we're not doing any training here
torch.set_grad_enabled(False)
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.cuda.device_count()

In [None]:
ccs_path = reporters_path / "ccs/pythia-1b/imdb"
list(ccs_path.iterdir())

In [None]:
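import re

# iterdir() order is arbitrary; sort by the layer number in each file name (assumed, e.g. "layer_3.pt")
reporter_files = sorted(
    (ccs_path / "reporters").iterdir(), key=lambda p: int(re.search(r"\d+", p.stem).group())
)
reporters = [torch.load(reporter, map_location=device) for reporter in reporter_files]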

## Load the model

We use the [TransformerLens](https://github.com/neelnanda-io/TransformerLens) library to interact with model internals.

The reference documentation can be found [here](https://neelnanda-io.github.io/TransformerLens/transformer_lens.html).

The [main tutorial](https://neelnanda.io/transformer-lens-demo) was very helpful in getting started with the library.

In [None]:
from transformer_lens import HookedTransformer

# use the device chosen above so this also runs on CPU-only machines
model = HookedTransformer.from_pretrained("EleutherAI/pythia-1b", device=device)

In [None]:
n_layers = model.cfg.n_layers
n_layers, len(reporters) == n_layers

## Load prompts

In the original DLK paper, the authors found that, of the models they tested,
the only decoder-only model (GPT-J-6B) performed the worst (measured by accuracy across datasets).

Researchers at EleutherAI later found that this could be resolved by prompting the model using different prompt templates.
Their method leverages the fact that the truth of a given statement should not depend on which prompt template is used.
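As an illustration of that invariance (a hypothetical diagnostic, not EleutherAI's actual method), one could measure how often a probe's predictions for the same statement agree across templates. The sketch below assumes a made-up array `probs` of shape `(n_examples, n_templates)` holding a probe's estimated probability that each statement is true; `template_agreement` is a hypothetical helper.

```python
import numpy as np


def template_agreement(probs: np.ndarray) -> np.ndarray:
    """Fraction of templates agreeing with the majority prediction, per example.

    probs: hypothetical array of shape (n_examples, n_templates) holding a
    probe's P(true) for the same statement under each prompt template.
    """
    preds = probs > 0.5                   # binary prediction per template
    majority = preds.mean(axis=1) >= 0.5  # majority vote per example (ties count as True)
    agree = preds == majority[:, None]    # does each template match the vote?
    return agree.mean(axis=1)


# toy probabilities (made up for illustration)
probs = np.array([
    [0.9, 0.8, 0.7, 0.6],  # all templates agree
    [0.9, 0.2, 0.8, 0.7],  # one template disagrees
])
scores = template_agreement(probs)
```

For this toy input the helper gives 1.0 for the first example and 0.75 for the second; a probe whose output is truly template-invariant would score 1.0 everywhere.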

In [None]:
from elk.extraction.prompt_loading import load_prompts
import itertools
import pandas as pd

n_prompts = 12

prompt_dataset = load_prompts("imdb", split_type="val")
prompt_dataset = list(itertools.islice(prompt_dataset, n_prompts))

Each element of `prompt_dataset` has the following structure:

```
{
    "label": 0 if correct answer is "negative", 1 if correct answer is "positive"
    "prompts": [
        [
            {
                "answer": "negative" or "bad" or ...
                "text": formatted prompt with the "negative" answer
            },
            {
                "answer": "positive" or "good" or ...
                "text": formatted prompt with the "positive" answer
            }
        ],
        ...
    ],
    "template_names": [
        template name for prompt 0,
        ...
    ]
}
```
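To make the mapping from `label` to the correct answer concrete, here is a minimal pure-Python sketch over a hypothetical record in the structure above (the review text and template name are made up). It mirrors the `correct_answer` / `incorrect_answer` logic used when building the DataFrame below.

```python
# Hypothetical record following the structure above (texts shortened, values made up).
record = {
    "label": 1,
    "prompts": [
        [
            {"answer": "negative", "text": "Review: Great film. Sentiment: negative"},
            {"answer": "positive", "text": "Review: Great film. Sentiment: positive"},
        ],
    ],
    "template_names": ["template_a"],
}


def answers_by_truth(record):
    """Return a (correct_answer, incorrect_answer) pair for each template in the record."""
    out = []
    for negative, positive in record["prompts"]:
        # label 1: the "positive" answer is correct; label 0: the "negative" answer is correct
        if record["label"]:
            out.append((positive["answer"], negative["answer"]))
        else:
            out.append((negative["answer"], positive["answer"]))
    return out


print(answers_by_truth(record))  # [('positive', 'negative')]
```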

Note that prompts vary along two binary axes:

1. Whether the statement ends in the "positive" answer or the "negative" answer;
2. Whether the statement is factually correct or incorrect.

It's important to distinguish these two axes: the goal of CCS is to uncover the latter.
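For reference, here is a minimal NumPy sketch of the CCS objective from Burns et al. for a linear probe `(w, b)`: a consistency term pushing p(x+) toward 1 - p(x-), plus a confidence term penalizing the degenerate solution p(x+) = p(x-) = 0.5. The random activations and the simple per-side mean-centering are illustrative assumptions; this is not the elk implementation.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def ccs_loss(w, b, h_pos, h_neg):
    """CCS objective for a linear probe (w, b).

    h_pos, h_neg: (n, d) hidden states for the "positive" / "negative" version
    of each of n contrast pairs.
    """
    # Center each side separately, removing the signal for which answer was
    # appended (axis 1) so the probe has to find something else, ideally truth (axis 2).
    h_pos = h_pos - h_pos.mean(axis=0)
    h_neg = h_neg - h_neg.mean(axis=0)
    p_pos = sigmoid(h_pos @ w + b)
    p_neg = sigmoid(h_neg @ w + b)
    consistency = (p_pos - (1.0 - p_neg)) ** 2  # the two answers should be complementary
    confidence = np.minimum(p_pos, p_neg) ** 2  # discourages p_pos = p_neg = 0.5
    return float(np.mean(consistency + confidence))


rng = np.random.default_rng(0)
h_pos, h_neg = rng.normal(size=(8, 4)), rng.normal(size=(8, 4))  # random stand-in activations
w, b = rng.normal(size=4), 0.0
loss = ccs_loss(w, b, h_pos, h_neg)
```

A constant probe (`w = 0`, `b = 0`) outputs 0.5 everywhere: it has zero consistency loss but a confidence loss of 0.25, which is why the confidence term is needed.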

In [None]:
df = pd.DataFrame(
    [
        dict(
            negative_prompt=negative["text"],
            positive_prompt=positive["text"],
            negative_answer=negative["answer"],
            positive_answer=positive["answer"],
            incorrect_answer=negative["answer"]
            if prompts["label"]
            else positive["answer"],
            correct_answer=positive["answer"]
            if prompts["label"]
            else negative["answer"],
            template_name=template_name,
            template_id=i,
            prompt_id=j,
        )
        for j, prompts in enumerate(prompt_dataset)
        for i, ((negative, positive), template_name) in enumerate(
            zip(prompts["prompts"], prompts["template_names"])
        )
    ]
)
# set a multiindex using the prompt_id and template_id
df = df.set_index(["prompt_id", "template_id"])
df.head()

In [None]:
print(df.at[(0, 0), "negative_prompt"])

## Forward pass on the text

In [None]:
from jaxtyping import Float
from tqdm import tqdm
from transformer_lens.hook_points import HookPoint
import transformer_lens.utils as utils
from functools import partial

Here we fix a given template ID. This gives us a single fully-formatted prompt for each original text sample.

In [None]:
template_id = 1
prompts = df.loc[pd.IndexSlice[:, template_id], :].reset_index(drop=True)
prompts.head()

In [None]:
neg_prompts = prompts["negative_prompt"].tolist()
pos_prompts = prompts["positive_prompt"].tolist()
correct_answers = prompts["correct_answer"].tolist()
incorrect_answers = prompts["incorrect_answer"].tolist()

In [None]:
neg_tokens = model.to_tokens(neg_prompts)
pos_tokens = model.to_tokens(pos_prompts)

neg_str_tokens = model.to_str_tokens(neg_prompts)
pos_str_tokens = model.to_str_tokens(pos_prompts)

prompt_lengths = torch.tensor([len(tokens) for tokens in neg_str_tokens])
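The next cell assumes that each negative/positive prompt pair tokenizes to sequences of the same length that differ only in the final (answer) token. A small helper like the hypothetical `differs_only_in_last_token` below can check this; it is shown here on toy token lists.

```python
def differs_only_in_last_token(neg: list[str], pos: list[str]) -> bool:
    """True if two token sequences have equal length and agree on all but the final token."""
    return len(neg) == len(pos) and neg[:-1] == pos[:-1]


# toy token lists (hypothetical); in the notebook these would be neg_str_tokens[i] and pos_str_tokens[i]
ok = differs_only_in_last_token(
    ["<|endoftext|>", "Great", " film", ".", " negative"],
    ["<|endoftext|>", "Great", " film", ".", " positive"],
)
bad = differs_only_in_last_token(["A", " bad"], ["A", " very", " good"])
```

In the notebook, one could run `all(differs_only_in_last_token(n, p) for n, p in zip(neg_str_tokens, pos_str_tokens))` before relying on `prompt_lengths` for both sets of prompts.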

In [None]:
# the negative and positive prompts match up until the very last token,
# so we just record one and then get only the last token of the other one
neg_results = torch.zeros((n_prompts, n_layers, int(prompt_lengths.max())), device="cpu")
pos_results = torch.zeros((n_prompts, n_layers), device="cpu")
batch_range = torch.arange(n_prompts)


def projection(
    resid_pre: Float[torch.Tensor, "batch pos d_model"],  # at a given layer
    hook: HookPoint,
    layer: int,
):
    # TODO should we be normalizing here?
    # since the prompts have different lengths, technically this is more computation than we need to do
    neg_results[:, layer, :] = reporters[layer](resid_pre).cpu()
    return resid_pre


def final_projection(
    resid_pre: Float[torch.Tensor, "batch pos d_model"],  # at a given layer
    hook: HookPoint,
    layer: int,
):
    # TODO should we be normalizing here?
    x = resid_pre[batch_range, prompt_lengths - 1, :]
    pos_results[:, layer] = reporters[layer](x).cpu()
    return resid_pre

In [None]:
# should be quite fast with a few gpus
for layer in tqdm(range(n_layers)):
    act_name = utils.get_act_name("resid_pre", layer)

    patch_hook_fn = partial(projection, layer=layer)
    model.run_with_hooks(neg_tokens, fwd_hooks=[(act_name, patch_hook_fn)])

    patch_hook_fn = partial(final_projection, layer=layer)
    model.run_with_hooks(pos_tokens, fwd_hooks=[(act_name, patch_hook_fn)])

In [None]:
save_path = Path(f"results/ccs/pythia-1b/imdb/template-{template_id}")
save_path.mkdir(parents=True, exist_ok=True)
torch.save(neg_results, save_path / "neg.pt")
torch.save(pos_results, save_path / "pos.pt")

## Visualize outputs

In [None]:
import circuitsvis as cv
from circuitsvis.tokens import colored_tokens

In [None]:
neg_results = torch.load(save_path / "neg.pt")
pos_results = torch.load(save_path / "pos.pt")

projections = torch.cat([neg_results, torch.zeros((n_prompts, n_layers, 2))], dim=-1)
projections[batch_range, :, prompt_lengths] = pos_results
projections[batch_range, :, prompt_lengths + 1] = 0  # visualization

# make the signs across layers consistent with the final token
# since CCS only identifies the hyperplane up to sign
projections = pos_results.sign()[:, :, None] * projections
projections.shape
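Because negating a linear probe, (w, b) -> (-w, -b), negates every score, the raw projections from a reporter have an arbitrary overall sign. The sketch below, on random hypothetical activations, checks that multiplying by the sign of a reference token's score (here the last token, playing the role of the positive-answer token above) yields the same aligned projections for either sign of the probe. `align_signs` is a hypothetical NumPy stand-in for the torch one-liner above.

```python
import numpy as np


def align_signs(projections, reference):
    """Flip each (prompt, layer) row so that its reference score is non-negative.

    projections: (n_prompts, n_layers, n_tokens) probe scores
    reference:   (n_prompts, n_layers) score of the chosen reference token
    """
    return np.sign(reference)[:, :, None] * projections


rng = np.random.default_rng(0)
x = rng.normal(size=(3, 2, 5, 4))  # hypothetical activations: prompts, layers, tokens, d_model
w, b = rng.normal(size=4), 0.1
scores = x @ w + b                 # shape (3, 2, 5)
flipped = x @ (-w) - b             # the same probe with the opposite sign
a = align_signs(scores, scores[:, :, -1])    # last token as reference
c = align_signs(flipped, flipped[:, :, -1])
```

One caveat: a reference score of exactly zero maps the whole row to zero (`np.sign(0) == 0`), and the same holds for torch's `sign()`.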

Here we choose a given prompt to visualize. "Positive" answers always correspond to _blue_ and "negative" answers to _red_,
regardless of the correct answer.

In [None]:
def plot_colors(prompt_id: int):
    flattened_tokens = (
        neg_str_tokens[prompt_id] + pos_str_tokens[prompt_id][-1:] + ["\n\n\n"]
    ) * n_layers
    flattened_projections = projections[
        prompt_id, :, : prompt_lengths[prompt_id] + 2
    ].flatten()

    # the color scale is clipped to [-5, 5] to keep the visualization readable; print the actual range
    print("distance bounds:", flattened_projections.min(), flattened_projections.max())
    return colored_tokens(
        flattened_tokens, flattened_projections, min_value=-5, max_value=5
    )

In [None]:
plot_colors(prompt_id=4)

It's interesting that these almost look like attention patterns! The blue tokens often carry some sort of positive connotation,
and the red ones often carry some kind of negative connotation. Let's see whether there is any similarity with the actual attention patterns:

In [None]:
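# run on the first 12 tokens of prompt 4, keeping the batch dimension
logits, cache = model.run_with_cache(neg_tokens[4:5, :12])
# the layer-0 pattern has shape [batch, head, query_pos, key_pos]; drop the batch dimension
attention_pattern = cache["pattern", 0, "attn"][0]
cv.attention.attention_patterns(tokens=neg_str_tokens[4][:12], attention=attention_pattern)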