In [1]:
import torch
import numpy as np
import torch
import random

seed = 633
np.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed);

In [2]:
# generate dataset from counterfact
from datasets import load_dataset, Dataset

ds_name = "NeelNanda/counterfact-tracing"
orig_dataset: Dataset = load_dataset(ds_name, split="train").shuffle(seed=seed)  # type: ignore

def get_labeled_texts(example, bos_token, few_shot_prefix=None):
    """few-shot-prefix is directly prepended to the prompt, without a newline, if provided"""
    prefix = "Does the following text contain a factual error?\n\n'''"
    if few_shot_prefix:
        prefix = few_shot_prefix + prefix
    prompt = bos_token + prefix + example["prompt"]

    suffix =  "'''\n\n"
    text_true = prompt + example["target_true"] + suffix
    text_false = prompt + example["target_false"] + suffix
    return {"texts": [text_true, text_false], "labels": [1, 0]}

def get_few_shot_prefix(examples):
    example_to_str = lambda example: example["text"] + ["Yes", "No"][example["label"]] + "\n\n"
    example_strs = []
    for example in examples:
        exs = get_labeled_texts(example, bos_token="")
        
        # dict of list into list of dict
        exs = [dict(zip(["text", "label"], t)) for t in zip(*exs.values())]
        
        example_strs.extend(example_to_str(ex) for ex in exs)
    np.random.shuffle(example_strs)
    return "".join(example_strs)

# dataset = dataset.map(map_fn, batched=True, batch_size=1, remove_columns=dataset.column_names)
orig_dataset

  from .autonotebook import tqdm as notebook_tqdm
Found cached dataset parquet (/mnt/ssd-2/hf_cache/NeelNanda___parquet/NeelNanda--counterfact-tracing-39c4f800d46af5cf/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)
Loading cached shuffled indices for dataset at /mnt/ssd-2/hf_cache/NeelNanda___parquet/NeelNanda--counterfact-tracing-39c4f800d46af5cf/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-05a7c15bdf6df0e8.arrow


Dataset({
    features: ['relation', 'relation_prefix', 'relation_suffix', 'prompt', 'relation_id', 'target_false_id', 'target_true_id', 'target_true', 'target_false', 'subject'],
    num_rows: 21919
})

In [3]:
n_total = 2000
# texts = np.array(dataset[:n_total]["texts"])
# labels = np.array(dataset[:n_total]["labels"])
dataset = orig_dataset.select(range(n_total))
n_shots = 12
few_shot_set = orig_dataset.select(range(n_total, n_total + 500))


In [4]:
# load a model and tokenizer huggingface's transformers library
from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaForCausalLM, LlamaTokenizer
model_name = "huggyllama/llama-7b"
# model_name = "gpt2-xl"
tokenizer = LlamaTokenizer.from_pretrained(model_name) if "llama" in model_name else AutoTokenizer.from_pretrained(model_name)
model = LlamaForCausalLM.from_pretrained(model_name, torch_dtype=torch.float32, device_map={"": "cuda:5"}) if "llama" in model_name \
    else AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float32, device_map={"": "cuda:5"})


Loading checkpoint shards: 100%|██████████| 2/2 [00:14<00:00,  7.31s/it]


In [5]:
from tqdm import tqdm

In [6]:
def gather_logprobs(outputs, tokenized_text):
    # returns a [n_tokens,] numpy array of logprobs
    logprobs = outputs["logits"].log_softmax(dim=-1)
    return torch.gather(logprobs, 2, tokenized_text.input_ids.unsqueeze(2)).squeeze(2).squeeze(0)

def get_hiddens(dataset, few_shot_set):
    # run the model and get the hidden states at each layer
    # creates a len(texts) x n_layers x hidden_dimension array of zeros
    n_layer = 32 if model_name == "huggyllama/llama-7b" else model.config.n_layer
    hidden_size = model.config.hidden_size
    hiddens = np.zeros((2 * len(dataset), n_layer + 1, hidden_size))
    lm_probs = np.zeros((2 * len(dataset),))
    texts = np.zeros((2 * len(dataset),), dtype=object)
    labels = np.zeros((2 * len(dataset),), dtype=int)
    i = 0
    with torch.no_grad():
        for example in tqdm(dataset, total=n_total):
            few_shot_prefix = get_few_shot_prefix(few_shot_set.shuffle(seed=seed).select(range(n_shots))) if n_shots > 0 else None
            labeled_texts = get_labeled_texts(example, tokenizer.bos_token, few_shot_prefix=few_shot_prefix)
            for text, label in zip(labeled_texts["texts"], labeled_texts["labels"]):
                tokenized_text = tokenizer(text, return_tensors="pt").to("cuda")
                outputs = model(**tokenized_text, output_hidden_states=True)
                
                hidden_states = outputs["hidden_states"]  # a tuple of torch tensors, one for each layer
                hiddens[i, :, :] = torch.cat(hidden_states).cpu().numpy()[:, -1]  # all layers, last token

                # # get total logprob assigned to each string
                # logprobs_true = gather_logprobs(outputs_true, tokenized_text_true).cpu().numpy()
                # logprobs_false = gather_logprobs(outputs_false, tokenized_text_false).cpu().numpy()
                # # p_true, p_false = np.exp(logprobs_true.sum()), np.exp(logprobs_false.sum())
                # # lm_preds[2 * i] = p_true / (p_true + p_false)
                # # lm_preds[2 * i + 1] = p_false / (p_true + p_false)
                # # the above is unstable, so we do the following instead
                # logaddexp = np.logaddexp(logprobs_true.sum(), logprobs_false.sum())
                # lm_probs[2 * i] = np.exp(logprobs_true.sum() - logaddexp)
                # lm_probs[2 * i + 1] = np.exp(logprobs_false.sum() - logaddexp)
                
                # get prob assigned to each target
                no_id, yes_id = tokenizer.convert_tokens_to_ids(["No", "Yes"])
                p_no, p_yes = outputs["logits"][0, -1, [no_id, yes_id]].softmax(dim=-1).cpu().numpy()
                
                lm_probs[i] = p_no / (p_yes + p_no)  # probability of "no" (no error)
                labels[i] = label
                texts[i] = text
                i += 1
    return hiddens, lm_probs, labels, texts
hiddens, lm_probs, labels, texts = get_hiddens(dataset, few_shot_set)

  0%|          | 0/2000 [00:00<?, ?it/s]Loading cached shuffled indices for dataset at /mnt/ssd-2/hf_cache/NeelNanda___parquet/NeelNanda--counterfact-tracing-39c4f800d46af5cf/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-0e057a3748e6641b.arrow
  0%|          | 1/2000 [00:05<3:13:46,  5.82s/it]Loading cached shuffled indices for dataset at /mnt/ssd-2/hf_cache/NeelNanda___parquet/NeelNanda--counterfact-tracing-39c4f800d46af5cf/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-0e057a3748e6641b.arrow
  0%|          | 2/2000 [00:07<1:56:28,  3.50s/it]Loading cached shuffled indices for dataset at /mnt/ssd-2/hf_cache/NeelNanda___parquet/NeelNanda--counterfact-tracing-39c4f800d46af5cf/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-0e057a3748e6641b.arrow
  0%|          | 3/2000 [00:09<1:31:45,  2.76s/it]Loading cached shuffled indices for dataset at /mnt/ssd-2/hf_cache/NeelNanda___parquet/NeelNanda--counter

In [7]:
model.config.hidden_size

4096

In [8]:
# make a train/test split and keep them separate
shuffled_idxs = np.random.permutation(np.arange(len(hiddens)))
shuffled_hiddens = hiddens[shuffled_idxs]
shuffled_labels = labels[shuffled_idxs]
train_size = int(len(shuffled_hiddens) * 0.7)
train_hiddens = shuffled_hiddens[:train_size]
test_hiddens = shuffled_hiddens[train_size:]
train_labels = shuffled_labels[:train_size]
test_labels = shuffled_labels[train_size:]
train_texts = texts[shuffled_idxs][:train_size]
test_texts = texts[shuffled_idxs][train_size:]
test_lm_probs = lm_probs[shuffled_idxs][train_size:]

# train a classifier on the hidden states
from sklearn.linear_model import LogisticRegressionCV
# use cross-validation to find the best hyperparameters
# use the best hyperparameters to train a final model
Cs = 10 ** np.linspace(-5, 5, 11)
n_layer = 32 if model_name == "huggyllama/llama-7b" else model.config.n_layer
layer = n_layer // 2 + 1  # the layer to use for classification, somewhat arbitrary but middle layers work better
reporter = LogisticRegressionCV(Cs=Cs, cv=2).fit(train_hiddens[:, layer, :], train_labels)

STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver opt

In [9]:
# get reporter regularization parameters
print("best regularization parameter:", reporter.C_[0])
# get model predictions on the test set
test_preds = reporter.predict(test_hiddens[:, layer, :])

is_correct = test_preds == test_labels
acc = np.mean(is_correct)
stderr = 0.5 / np.sqrt(len(test_labels))  # SE_prop = sqrt(p(1-p)/n)
print(f"Accuracy: {acc:.3f} ± {stderr:.3f}")
correct_examples = test_texts[is_correct]
incorrect_examples = test_texts[~is_correct]

# train acc
train_preds = reporter.predict(train_hiddens[:, layer, :])
train_acc = np.mean(train_preds == train_labels)
print(f"Train accuracy: {train_acc:.3f}")

# analyze these examples to see what the reporter is getting right and wrong...

best regularization parameter: 0.1
Accuracy: 0.823 ± 0.014
Train accuracy: 0.901


In [10]:
sum(lm_probs > 0.5) / len(lm_probs)

0.82025

In [11]:
# lm accuracy
cal_thresh = np.quantile(test_lm_probs, test_labels.mean())  # the LM usually just guesses "no error", so we calibrate it to the true proportion of "no error" examples
lm_preds = test_lm_probs > cal_thresh
lm_acc = np.mean(lm_preds == test_labels)
lm_stderr = 0.5 / np.sqrt(len(test_labels))  # SE_prop = sqrt(p(1-p)/n).
print(f"LM Accuracy: {lm_acc:.3f} ± {lm_stderr:.3f}")

LM Accuracy: 0.475 ± 0.014


In [12]:
# where do they disagree? Make a confusion matrix
lm_correct = (test_lm_probs > 0.5) == test_labels
reporter_correct = test_preds == test_labels

lm_better = lm_correct & ~reporter_correct
reporter_better = ~lm_correct & reporter_correct
unequal = lm_correct != reporter_correct
print(f"The LM is better on {lm_better.sum()} examples")
print(f"The reporter is better on {reporter_better.sum()} examples")
print(f"The LM and reporter disagree on {unequal.sum()} examples")

The LM is better on 107 examples
The reporter is better on 524 examples
The LM and reporter disagree on 631 examples


In [13]:
# save results
from pathlib import Path
import time

prefix = time.strftime("%Y-%m-%d-%H:%M")
cache_dir = Path("./lr-experiments") / ds_name / prefix
cache_dir.mkdir(exist_ok=True, parents=True)
np.save(cache_dir / "hiddens.npy", hiddens)
np.save(cache_dir / "lm_probs.npy", lm_probs)
np.save(cache_dir / "labels.npy", labels)
np.save(cache_dir / "texts.npy", texts)

In [14]:
test_texts[~unequal]

array(["<s>Does the following text contain a factual error?\n\n'''The headquarters of Schwartz Publishing is in Melbourne'''\n\nNo\n\nDoes the following text contain a factual error?\n\n'''United Kingdom is a part of the FIFA'''\n\nYes\n\nDoes the following text contain a factual error?\n\n'''In Belize, the language spoken is Portuguese'''\n\nYes\n\nDoes the following text contain a factual error?\n\n'''Luis Martins de Souza Dantas, a citizen of Poland'''\n\nYes\n\nDoes the following text contain a factual error?\n\n'''The Forks, Winnipeg is located in the country of Canada'''\n\nNo\n\nDoes the following text contain a factual error?\n\n'''Gianni Agnelli used to work in London'''\n\nYes\n\nDoes the following text contain a factual error?\n\n'''Luis Martins de Souza Dantas, a citizen of Brazil'''\n\nNo\n\nDoes the following text contain a factual error?\n\n'''United Kingdom is a part of the NATO'''\n\nNo\n\nDoes the following text contain a factual error?\n\n'''Tariq Abdul-Wahad follows

In [15]:
test_texts[lm_better]

array(["<s>Does the following text contain a factual error?\n\n'''Tariq Abdul-Wahad follows the religion of Islam'''\n\nNo\n\nDoes the following text contain a factual error?\n\n'''The Forks, Winnipeg is located in the country of Canada'''\n\nNo\n\nDoes the following text contain a factual error?\n\n'''Marc Swayze was born in Dublin'''\n\nYes\n\nDoes the following text contain a factual error?\n\n'''Marc Swayze was born in Monroe'''\n\nNo\n\nDoes the following text contain a factual error?\n\n'''LeRoy Collins worked in the city of Florida'''\n\nNo\n\nDoes the following text contain a factual error?\n\n'''The headquarters of Schwartz Publishing is in Beijing'''\n\nYes\n\nDoes the following text contain a factual error?\n\n'''Tariq Abdul-Wahad follows the religion of Christianity'''\n\nYes\n\nDoes the following text contain a factual error?\n\n'''The Forks, Winnipeg is located in the country of Ireland'''\n\nYes\n\nDoes the following text contain a factual error?\n\n'''The headquarters o