In [1]:
from collections import Counter
import torch as th


def all_equal(iterable):
    """Return True if every item of `iterable` compares equal to the first one.

    Raises ValueError (from the unpacking) when `iterable` is empty.
    """
    first, *others = iterable
    for item in others:
        if not (item == first):
            return False
    return True


def extract_features(stats: dict):
    """Build a feature tensor from per-document log-likelihoods.

    Each value of `stats` must carry a 'log_likelihoods' sequence. Documents
    whose sequence length differs from the most common length are dropped so
    the rest can be stacked into a rectangular tensor.

    Returns:
        A 1-D tensor ll[1] - ll[0] when there are exactly two log-likelihoods
        per document; otherwise the stacked tensor flattened from dim 1 on.

    Raises:
        ValueError if `stats` is empty (nothing to take the most common length of).
    """
    per_doc = [entry['log_likelihoods'] for entry in stats.values()]

    # Modal length across documents; ties go to the first length encountered.
    [(modal_len, _)] = Counter(map(len, per_doc)).most_common(1)
    features = th.tensor([row for row in per_doc if len(row) == modal_len])

    if features.shape[1] != 2:
        return features.flatten(1)

    # Two choices: collapse to their difference (the original author labeled
    # this "log odds").
    return features.diff(dim=1).squeeze(1)


In [2]:
from functools import partial
from itertools import chain
from pathlib import Path
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split
import random
import torch as th


ROOT = Path("/mnt/ssd-1/nora/tuned-lenses/pythia")
# Root of the downstream outputs; collect_hiddens picks subdirectories under it.
downstream_dir = ROOT.joinpath("12b-deduped", "affine", "downstream")


def auroc_bootstrap(labels, scores, num_samples=1000, seed=0):
    """Bootstrap the AUROC of `scores` against `labels`.

    Draws `num_samples` resamples (with replacement, same size as the input)
    from a `random.Random(seed)` generator and returns the list of AUROCs.
    `labels` and `scores` must support indexing by a list of ints (e.g. CPU
    torch tensors or numpy arrays).

    NOTE(review): roc_auc_score raises ValueError if a resample happens to
    contain only one class; with realistic class balances this is unlikely.
    """
    generator = random.Random(seed)
    population = range(len(labels))

    def _one_resample():
        picks = generator.choices(population, k=len(population))
        return roc_auc_score(labels[picks], scores[picks])

    return [_one_resample() for _ in range(num_samples)]


def collect_hiddens(
    subdir: str, task: str, dtype: th.dtype = th.half, layer_stride: int = 4
):
    """Load per-document hidden states for `task` from every rank shard in `subdir`.

    Reads `downstream_dir/subdir/rank_*/{task}_hiddens.pt`. Each file is
    unpacked as an iterable of 2-element items `(_, h)`. For every `h`, only
    every `layer_stride`-th entry along axis 1 is kept, axis 0 is averaged
    away, and the result is cast to `dtype`.

    Because of that striding, a downstream index `layer=i` into the stacked
    result selects position `i * layer_stride` of `h`'s axis 1, not position i.

    Args:
        subdir: Subdirectory of `downstream_dir` (e.g. "ll-normal").
        task: Task name used in the file name.
        dtype: Output dtype for each averaged tensor.
        layer_stride: Step along axis 1 of each `h` (default 4, as before).

    Returns:
        A list of tensors, one per loaded item, in sorted-path order.
    """
    # Path.glob yields entries in arbitrary filesystem order. Sort them so the
    # document order, and therefore every seeded train/test split built from
    # it, is reproducible. This is lexicographic order (rank_10 < rank_2), but
    # deterministic.
    paths = sorted(downstream_dir.joinpath(subdir).glob(f"rank_*/{task}_hiddens.pt"))

    hiddens = []
    for path in paths:
        # NOTE(review): th.load unpickles; only use on trusted files.
        for _, h in th.load(path, map_location="cpu"):
            hiddens.append(h[:, ::layer_stride].mean(dim=0).to(dtype))
    return hiddens

def train_val_split(normal: list, abnormal: list, seed: int, device="cuda:0"):
    """Split normal examples into train/test and put every abnormal example in test.

    Args:
        normal: List of same-shaped tensors for normal examples.
        abnormal: List of tensors for abnormal examples; all go to the test set.
        seed: `random_state` for sklearn's `train_test_split` (default split sizes).
        device: Device for the returned tensors. The default "cuda:0" matches the
            previous hard-coded `.cuda(0)`; pass "cpu" to run without a GPU.

    Returns:
        (train, test, labels): stacked normal-train tensors, stacked test
        tensors (normal test first, then abnormal), and labels for the test
        rows (0 = normal, 1 = abnormal).
    """
    normal_train, normal_test = train_test_split(normal, random_state=seed)
    test = normal_test + abnormal
    labels = [0] * len(normal_test) + [1] * len(abnormal)
    return (
        th.stack(normal_train).to(device),
        th.stack(test).to(device),
        th.tensor(labels).to(device),
    )

def eval_mahalanobis(
    train: th.Tensor,
    test: th.Tensor,
    labels: th.Tensor,
    layer: int = -1,
    relative: bool = False,
    eps: float = 1e-3,
):
    """Score `test` by squared Mahalanobis distance to a Gaussian fit on `train`.

    Args:
        train: Stacked normal training examples; `train[:, layer]` must be 2-D
            (num_train, dim).
        test: Stacked test examples, indexed the same way.
        labels: Labels for the rows of `test` (as built by `train_val_split`:
            0 = normal, 1 = abnormal).
        layer: Index into axis 1 of `train`/`test` selecting the features to use.
        relative: If True, subtract the squared distance computed using only
            the (regularized) diagonal of the covariance.
        eps: Ridge added to the covariance diagonal (default 1e-3, as before).

    Returns:
        The list of bootstrapped AUROCs from `auroc_bootstrap`, with the
        (possibly relative) squared distance as the anomaly score.
    """
    train_x = train[:, layer].float()

    # torch.cov treats rows as variables, hence the transpose to (dim, num_train).
    cov = train_x.T.cov()
    # Ridge-regularize in place through the diagonal view so `cov` is invertible.
    th.linalg.diagonal(cov).add_(eps)
    mu = train_x.mean(dim=0)

    demeaned = test[:, layer].float() - mu
    # solve(..., left=False) returns demeaned @ cov^{-1} without forming the inverse.
    dists_sq = th.sum(
        demeaned * th.linalg.solve(cov, demeaned, left=False), dim=-1
    )

    if relative:
        diag_dists_sq = th.sum(
            demeaned * cov.diag().reciprocal() * demeaned, dim=-1
        )
        dists_sq -= diag_dists_sq

    return auroc_bootstrap(labels.cpu(), dists_sq.cpu())

In [3]:
from tuned_lens.stats import fit_anomaly_detector
import numpy as np
import torch as th


# Number of detector seeds pooled per (lens, method) pair in `compute_aurocs`.
_NUM_DETECTOR_SEEDS = 10


def _load_per_doc(path: str):
    """Load one per-document stats file onto the CPU."""
    # NOTE(review): th.load unpickles arbitrary objects; only use on trusted files.
    return th.load(path, map_location="cpu")


def _report_detector(label: str, normal, injected, method: str, /, **detector_kwargs):
    """Fit `method` once per seed and print the pooled bootstrapped AUROC interval."""
    aurocs = []
    for seed in range(_NUM_DETECTOR_SEEDS):
        result = fit_anomaly_detector(
            normal, injected, method=method, seed=seed, plot=False,
            **detector_kwargs,
        )
        aurocs.extend(result.bootstrapped_aurocs)

    lo, mid, hi = np.quantile(aurocs, [0.025, 0.5, 0.975])
    print(f"{label}:\n{mid:.2f}\\;({lo:.2f}, {hi:.2f})")


def compute_aurocs(task: str, **kwargs):
    """Print pooled bootstrapped AUROCs of iForest and LOF detectors for `task`.

    For the tuned lens and then the logit lens, the injected and normal
    `{task}_per_doc.pt` files are loaded and converted with `extract_features`.
    They are then scored with `fit_anomaly_detector` for 10 seeds per method.
    Each line printed is the median followed by the 2.5th and 97.5th
    percentiles of the pooled bootstrapped AUROCs.

    Args:
        task: Task name; selects the `{task}_per_doc.pt` files.
        **kwargs: Extra keyword arguments forwarded to `fit_anomaly_detector`
            for the LOF runs only (e.g. metric="cosine"). The iForest runs
            never receive them.
    """
    # Local string path. It is deliberately not named ROOT, to avoid shadowing
    # the Path-valued global of that name.
    downstream = "/mnt/ssd-1/nora/tuned-lenses/pythia/12b-deduped/affine/downstream"

    # (display name, injected-run subdir, normal-run subdir); "" = downstream root.
    lens_specs = (
        ("Tuned lens", "injection/", ""),
        ("Logit lens", "ll-injected/", "ll-normal/"),
    )
    for lens_name, injected_subdir, normal_subdir in lens_specs:
        per_doc_injected = _load_per_doc(f"{downstream}/{injected_subdir}{task}_per_doc.pt")
        per_doc_normal = _load_per_doc(f"{downstream}/{normal_subdir}{task}_per_doc.pt")
        normal = extract_features(per_doc_normal)
        injected = extract_features(per_doc_injected)

        _report_detector(f"{lens_name} iForest", normal, injected, "iforest")
        _report_detector(f"{lens_name} LOF", normal, injected, "lof", **kwargs)

In [4]:
import numpy as np

def estimate_baseline_mahal_auroc(
    task: str, layer: int = -1, relative: bool = True, num_seeds: int = 10
):
    """Print the pooled bootstrapped AUROC of a Mahalanobis-distance baseline.

    Hidden states from the "ll-normal" and "ll-injected" runs of `task` are
    loaded with `collect_hiddens`. For each of `num_seeds` seeds they are split
    with `train_val_split` and scored with `eval_mahalanobis`. The printed
    line gives the median followed by the 2.5th and 97.5th percentiles of all
    pooled AUROCs.

    Args:
        task: Task name.
        layer: Index passed to `eval_mahalanobis`. `collect_hiddens` keeps only
            every 4th entry of that axis (`h[:, ::4]`), so this is an index into
            the strided set, not necessarily the model's layer number.
        relative: Passed to `eval_mahalanobis` (default True, as before).
        num_seeds: Number of train/test splits (default 10, as before).
    """
    normal = collect_hiddens("ll-normal", task)
    injected = collect_hiddens("ll-injected", task)

    # Accumulate with extend. `sum(lists, [])` re-copies the running list on
    # every step, which is quadratic.
    aurocs = []
    for seed in range(num_seeds):
        train, test, labels = train_val_split(normal, injected, seed=seed)
        aurocs.extend(
            eval_mahalanobis(train, test, labels, layer=layer, relative=relative)
        )

    lo, mid, hi = np.quantile(aurocs, [0.025, 0.5, 0.975])
    print(f"Baseline: {mid:.2f}\\;({lo:.2f}, {hi:.2f})")

In [5]:
# Sweep the relative-Mahalanobis baseline on SciQ over layer indices 0-9.
for layer_idx in range(10):
    print(f"Layer {layer_idx}")
    estimate_baseline_mahal_auroc(task="sciq", layer=layer_idx)
    print()

Layer 0
Baseline: 0.47\;(0.42, 0.51)

Layer 1
Baseline: 0.75\;(0.70, 0.80)

Layer 2
Baseline: 0.96\;(0.94, 0.97)

Layer 3
Baseline: 0.96\;(0.94, 0.97)

Layer 4
Baseline: 0.92\;(0.90, 0.94)

Layer 5
Baseline: 0.75\;(0.72, 0.78)

Layer 6
Baseline: 0.58\;(0.54, 0.62)

Layer 7
Baseline: 0.52\;(0.48, 0.56)

Layer 8
Baseline: 0.45\;(0.42, 0.50)

Layer 9
Baseline: 0.30\;(0.26, 0.34)



## ARC Challenge

In [73]:
# ARC-Challenge detector AUROCs; the Minkowski metric applies to LOF only.
compute_aurocs(task="arc_challenge", metric="minkowski")

Tuned lens iForest:
0.71\;(0.65, 0.77)
Tuned lens LOF:
0.81\;(0.77, 0.84)
Logit lens iForest:
0.73\;(0.67, 0.79)
Logit lens LOF:
0.80\;(0.77, 0.83)


In [6]:
# Sweep the relative-Mahalanobis baseline on ARC-Challenge over layer indices 0-9.
for layer_idx in range(10):
    print(f"Layer {layer_idx}")
    estimate_baseline_mahal_auroc(task="arc_challenge", layer=layer_idx)
    print()

Layer 0
Baseline: 0.43\;(0.39, 0.47)

Layer 1
Baseline: 0.47\;(0.43, 0.51)

Layer 2
Baseline: 0.78\;(0.75, 0.82)

Layer 3
Baseline: 0.76\;(0.73, 0.80)

Layer 4
Baseline: 0.65\;(0.61, 0.69)

Layer 5
Baseline: 0.57\;(0.53, 0.61)

Layer 6
Baseline: 0.50\;(0.46, 0.54)

Layer 7
Baseline: 0.47\;(0.43, 0.51)

Layer 8
Baseline: 0.42\;(0.38, 0.46)

Layer 9
Baseline: 0.29\;(0.26, 0.33)



In [7]:
# Sweep the relative-Mahalanobis baseline on BoolQ over layer indices 0-9.
for layer_idx in range(10):
    print(f"Layer {layer_idx}")
    estimate_baseline_mahal_auroc(task="boolq", layer=layer_idx)
    print()

Layer 0
Baseline: 0.50\;(0.50, 0.50)

Layer 1
Baseline: 1.00\;(1.00, 1.00)

Layer 2
Baseline: 1.00\;(1.00, 1.00)

Layer 3
Baseline: 1.00\;(1.00, 1.00)

Layer 4
Baseline: 1.00\;(1.00, 1.00)

Layer 5
Baseline: 1.00\;(1.00, 1.00)

Layer 6
Baseline: 1.00\;(1.00, 1.00)

Layer 7
Baseline: 1.00\;(1.00, 1.00)

Layer 8
Baseline: 1.00\;(1.00, 1.00)

Layer 9
Baseline: 1.00\;(0.99, 1.00)



In [32]:
# Relative-Mahalanobis baseline on ARC-Challenge at layer index -5.
estimate_baseline_mahal_auroc(task="arc_challenge", layer=-5)

Baseline: 0.57\;(0.53, 0.61)


## ARC Easy

In [6]:
# ARC-Easy detector AUROCs; the cosine metric applies to LOF only.
compute_aurocs(task="arc_easy", metric="cosine")

Tuned lens iForest:
0.59\;(0.54, 0.62)
Tuned lens LOF:
0.76\;(0.73, 0.79)
Logit lens iForest:
0.53\;(0.50, 0.57)
Logit lens LOF:
0.59\;(0.56, 0.63)


In [7]:
# ARC-Easy detector AUROCs; the Minkowski metric applies to LOF only.
compute_aurocs(task="arc_easy", metric="minkowski")

Tuned lens iForest:
0.59\;(0.54, 0.62)
Tuned lens LOF:
0.73\;(0.71, 0.76)
Logit lens iForest:
0.53\;(0.50, 0.57)
Logit lens LOF:
0.59\;(0.56, 0.62)


In [8]:
# Relative-Mahalanobis baseline on ARC-Easy at layer index -5.
estimate_baseline_mahal_auroc(task="arc_easy", layer=-5)

Baseline: 0.73\;(0.70, 0.75)


In [8]:
# Sweep the relative-Mahalanobis baseline on ARC-Easy over layer indices 0-9.
for layer_idx in range(10):
    print(f"Layer {layer_idx}")
    estimate_baseline_mahal_auroc(task="arc_easy", layer=layer_idx)
    print()

Layer 0
Baseline: 0.48\;(0.45, 0.52)

Layer 1
Baseline: 0.64\;(0.60, 0.67)

Layer 2
Baseline: 0.89\;(0.87, 0.91)

Layer 3
Baseline: 0.90\;(0.88, 0.92)

Layer 4
Baseline: 0.82\;(0.80, 0.84)

Layer 5
Baseline: 0.73\;(0.70, 0.75)

Layer 6
Baseline: 0.64\;(0.62, 0.67)

Layer 7
Baseline: 0.59\;(0.56, 0.62)

Layer 8
Baseline: 0.54\;(0.51, 0.57)

Layer 9
Baseline: 0.37\;(0.34, 0.40)



## BoolQ

In [9]:
# BoolQ detector AUROCs; the cosine metric applies to LOF only.
compute_aurocs(task="boolq", metric="cosine")

Tuned lens iForest:
0.99\;(0.98, 1.00)
Tuned lens LOF:
1.00\;(0.99, 1.00)
Logit lens iForest:
0.89\;(0.87, 0.91)
Logit lens LOF:
0.50\;(0.49, 0.50)


In [34]:
# BoolQ detector AUROCs; the Minkowski metric applies to LOF only.
compute_aurocs(task="boolq", metric="minkowski")

Tuned lens iForest:
0.99\;(0.98, 0.99)
Tuned lens LOF:
1.00\;(1.00, 1.00)
Logit lens iForest:
0.89\;(0.87, 0.91)
Logit lens LOF:
0.61\;(0.57, 0.66)


In [6]:
# Relative-Mahalanobis baseline on BoolQ at layer index -5.
estimate_baseline_mahal_auroc(task="boolq", layer=-5)

array([1., 1., 1.])

## MC TACO

In [None]:
# MC-TACO detector AUROCs; the cosine metric applies to LOF only.
compute_aurocs(task="mc_taco", metric="cosine")

In [37]:
# MC-TACO detector AUROCs; the Minkowski metric applies to LOF only.
compute_aurocs(task="mc_taco", metric="minkowski")

Tuned lens iForest:
0.74\;(0.71, 0.77)
Tuned lens LOF:
0.68\;(0.66, 0.70)
Logit lens iForest:
0.68\;(0.66, 0.69)
Logit lens LOF:
0.55\;(0.53, 0.59)


In [7]:
# Relative-Mahalanobis baseline on MC-TACO at layer index -5.
estimate_baseline_mahal_auroc(task="mc_taco", layer=-5)

array([0.99900598, 1.        , 1.        ])

In [11]:
# MNLI detector AUROCs with no extra LOF arguments.
compute_aurocs(task="mnli")

Tuned lens iForest:
0.98\;(0.98, 0.99)
Tuned lens LOF:
1.00\;(1.00, 1.00)
Logit lens iForest:
0.95\;(0.94, 0.96)
Logit lens LOF:
1.00\;(1.00, 1.00)


In [12]:
# Relative-Mahalanobis baseline on MNLI at layer index -5.
estimate_baseline_mahal_auroc(task="mnli", layer=-5)

Baseline: 1.00\;(1.00, 1.00)


## QNLI

In [36]:
# QNLI detector AUROCs with no extra LOF arguments.
compute_aurocs(task="qnli")

Tuned lens iForest:
0.99\;(0.99, 1.00)
Tuned lens LOF:
1.00\;(1.00, 1.00)
Logit lens iForest:
0.93\;(0.92, 0.95)
Logit lens LOF:
0.68\;(0.63, 0.71)


In [5]:
# Relative-Mahalanobis baseline on QNLI at layer index -5.
estimate_baseline_mahal_auroc(task="qnli", layer=-5)

Baseline: 1.00\;(1.00, 1.00)


## QQP

In [35]:
# QQP detector AUROCs with no extra LOF arguments.
compute_aurocs(task="qqp")

Tuned lens iForest:
1.00\;(0.99, 1.00)
Tuned lens LOF:
1.00\;(1.00, 1.00)
Logit lens iForest:
0.90\;(0.89, 0.90)
Logit lens LOF:
0.79\;(0.76, 0.81)


In [6]:
# Relative-Mahalanobis baseline on QQP at layer index -5.
estimate_baseline_mahal_auroc(task="qqp", layer=-5)

Baseline: 1.00\;(1.00, 1.00)


## SciQ

In [46]:
# Relative-Mahalanobis baseline on SciQ at layer index -5.
estimate_baseline_mahal_auroc(task="sciq", layer=-5)

Baseline: 0.75\;(0.72, 0.78)


In [47]:
# SciQ detector AUROCs; the cosine metric applies to LOF only.
compute_aurocs(task="sciq", metric="cosine")

Tuned lens iForest:
0.62\;(0.57, 0.69)
Tuned lens LOF:
0.64\;(0.59, 0.70)
Logit lens iForest:
0.75\;(0.71, 0.79)
Logit lens LOF:
0.70\;(0.65, 0.74)


In [None]:
# SciQ detector AUROCs; the Minkowski metric applies to LOF only.
# (Was misspelled "minowski", which is not a valid distance metric name.)
compute_aurocs("sciq", metric="minkowski")

## SST-2

In [45]:
# Relative-Mahalanobis baseline on SST-2 at layer index -5.
estimate_baseline_mahal_auroc(task="sst", layer=-5)

Baseline: 1.00\;(1.00, 1.00)


In [74]:
# SST-2 detector AUROCs; the cosine metric applies to LOF only.
compute_aurocs(task="sst", metric="cosine")

Tuned lens iForest:
0.98\;(0.95, 1.00)
Tuned lens LOF:
1.00\;(0.99, 1.00)
Logit lens iForest:
0.76\;(0.70, 0.81)
Logit lens LOF:
0.50\;(0.49, 0.51)


In [48]:
# SST-2 detector AUROCs; the Minkowski metric applies to LOF only.
compute_aurocs(task="sst", metric="minkowski")

Tuned lens iForest:
1.00\;(0.98, 1.00)
Tuned lens LOF:
1.00\;(1.00, 1.00)
Logit lens iForest:
0.78\;(0.72, 0.83)
Logit lens LOF:
0.61\;(0.56, 0.65)


In [40]:
# LogiQA detector AUROCs; the cosine metric applies to LOF only.
compute_aurocs(task="logiqa", metric="cosine")

Tuned lens iForest:
0.62\;(0.56, 0.72)
Tuned lens LOF:
0.49\;(0.43, 0.55)
Logit lens iForest:
0.79\;(0.75, 0.84)
Logit lens LOF:
0.78\;(0.73, 0.83)


In [None]:
# LogiQA detector AUROCs; the Minkowski metric applies to LOF only.
compute_aurocs(task="logiqa", metric="minkowski")

In [43]:
# Relative-Mahalanobis baseline on LogiQA at layer index -5.
estimate_baseline_mahal_auroc(task="logiqa", layer=-5)

Baseline: 0.60\;(0.53, 0.65)


In [53]:
from lm_eval.tasks import get_task

# Materialize the LogiQA validation documents for inspection.
task = get_task("logiqa")
val_docs = [*task().validation_docs()]

Found cached dataset logiqa (/mnt/ssd-1/nora/huggingface/datasets/logiqa/logiqa/0.0.1/4bf60449574fbe40eccf8b5177e294a179a5c85aeedce3210d359100e91af224)


  0%|          | 0/3 [00:00<?, ?it/s]

In [54]:
# Peek at the first LogiQA validation document (shown as the cell's output).
val_docs[0]

{'passage': 'Black Americans are twice as likely to suffer from hypertension as white Americans. The same is true when comparing Westernized black Africans to white Africans. The researchers hypothesized that the reason why westernized black people suffer from hypertension is the result of the interaction of two reasons? one is the high salt content of western foods, and the other is the adaptation mechanism of black genetic genes to the salt-deficient environment .',
 'query': "Passage: Black Americans are twice as likely to suffer from hypertension as white Americans. The same is true when comparing Westernized black Africans to white Africans. The researchers hypothesized that the reason why westernized black people suffer from hypertension is the result of the interaction of two reasons? one is the high salt content of western foods, and the other is the adaptation mechanism of black genetic genes to the salt-deficient environment .\nQuestion: The following conclusions about contem

In [41]:
# PIQA detector AUROCs with no extra LOF arguments.
compute_aurocs(task="piqa")

Tuned lens iForest:
0.42\;(0.37, 0.47)
Tuned lens LOF:
0.46\;(0.43, 0.50)
Logit lens iForest:
0.22\;(0.19, 0.27)
Logit lens LOF:
0.58\;(0.54, 0.61)


In [9]:
# Relative-Mahalanobis baseline on PIQA at layer index -5.
estimate_baseline_mahal_auroc(task="piqa", layer=-5)

array([0.95451446, 0.96614265, 0.97655125])

In [42]:
# MNLI detector AUROCs again (re-run), with no extra LOF arguments.
compute_aurocs(task="mnli")

Tuned lens iForest:
0.98\;(0.98, 0.99)
Tuned lens LOF:
1.00\;(1.00, 1.00)
Logit lens iForest:
0.95\;(0.94, 0.96)
Logit lens LOF:
1.00\;(1.00, 1.00)
