Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions delphi/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,8 +161,9 @@ async def process_item(self, item: Any, semaphore: asyncio.Semaphore) -> Any:
async with semaphore:
result = item
for pipe in self.pipes:
if result is not None:
result = await pipe(result)
else:
pass
if result is None:
return None

result = await pipe(result)

return result
102 changes: 46 additions & 56 deletions delphi/scorers/classifier/intruder.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import asyncio
import re
from collections import defaultdict
from dataclasses import dataclass
from itertools import cycle
from typing import Literal

from beartype.typing import Sequence
Expand Down Expand Up @@ -136,12 +138,11 @@ def _get_quantiled_examples(
"""
Get the quantiled examples.
"""
quantiles = {}
examples_grouped_by_quantiles = defaultdict(list)
for example in examples:
if example.quantile not in quantiles:
quantiles[example.quantile] = []
quantiles[example.quantile].append(example)
return quantiles
examples_grouped_by_quantiles[example.quantile].append(example)

return examples_grouped_by_quantiles

def _prepare_and_batch(self, record: LatentRecord) -> list[IntruderSentence]:
"""
Expand All @@ -153,38 +154,39 @@ def _prepare_and_batch(self, record: LatentRecord) -> list[IntruderSentence]:
quantiled_intruder_sentences = self._get_quantiled_examples(record.test)

intruder_sentences = record.not_active
for i, intruder in enumerate(intruder_sentences):
# select each quantile equally
quantile_index = i % len(quantiled_intruder_sentences.keys())

active_examples = quantiled_intruder_sentences[quantile_index]
# select each quantile equally by repeatedly cycling through them
quantile_iterator = cycle(quantiled_intruder_sentences.items())
for (active_quantile, all_active_examples), intruder in zip(
quantile_iterator, intruder_sentences
):
# if there are more examples than the number of examples to show,
# sample which examples to show
examples_to_show = min(self.n_examples_shown - 1, len(active_examples))
example_indices = self.rng.sample(
range(len(active_examples)), examples_to_show
num_active_examples = min(
# - 1 because we are going to insert the intruder sentence
self.n_examples_shown - 1,
len(all_active_examples),
)
active_examples = [active_examples[i] for i in example_indices]

# convert the examples to strings
active_examples = self.rng.sample(all_active_examples, num_active_examples)

# highlights the active tokens
# highlights the active tokens with <<>> markers
majority_examples = []
active_tokens = 0
num_active_tokens = 0
for example in active_examples:
text, _ = _prepare_text(
text, _str_tokens = _prepare_text(
example, n_incorrect=0, threshold=0.3, highlighted=True
)
majority_examples.append(text)
active_tokens += (example.activations > 0).sum().item()
active_tokens = int(active_tokens / len(active_examples))
num_active_tokens += (example.activations > 0).sum().item()

avg_active_tokens_per_example = num_active_tokens // len(active_examples)
if self.type == "default":
# if example is contrastive, use the active tokens
# otherwise use the non-activating tokens
if intruder.activations.max() > 0:
n_incorrect = 0
else:
n_incorrect = active_tokens
n_incorrect = avg_active_tokens_per_example
intruder_sentence, _ = _prepare_text(
intruder,
n_incorrect=n_incorrect,
Expand All @@ -194,22 +196,15 @@ def _prepare_and_batch(self, record: LatentRecord) -> list[IntruderSentence]:
elif self.type == "internal":
# randomly select a quantile to be the intruder, make sure it's not
# the same as the source quantile
intruder_quantile_index = self.rng.randint(
0, len(quantiled_intruder_sentences.keys()) - 1
)
while intruder_quantile_index == quantile_index:
intruder_quantile_index = self.rng.randint(
0, len(quantiled_intruder_sentences.keys()) - 1
)
posible_intruder_sentences = quantiled_intruder_sentences[
intruder_quantile_index
]
intruder_index_selected = self.rng.randint(
0, len(posible_intruder_sentences) - 1
)
intruder = posible_intruder_sentences[intruder_index_selected]
all_quantile_examples = list(quantiled_intruder_sentences.values())
all_quantile_examples.remove(all_active_examples)
possible_intruder_sentences = self.rng.choice(all_quantile_examples)

intruder = self.rng.choice(possible_intruder_sentences)
# here the examples are activating, so we have to convert them
# to non-activating examples
assert intruder.str_tokens is not None, "intruder has no str_tokens"

non_activating_intruder = NonActivatingExample(
tokens=intruder.tokens,
activations=intruder.activations,
Expand All @@ -224,23 +219,27 @@ def _prepare_and_batch(self, record: LatentRecord) -> list[IntruderSentence]:
highlighted=True,
)
intruder = non_activating_intruder
else:
raise ValueError("Invalid intruder scorer type")

# select a random index to insert the intruder sentence
intruder_index = self.rng.randint(0, examples_to_show)
majority_examples.insert(intruder_index, intruder_sentence)
intruder_index = self.rng.randint(0, num_active_examples)
examples = (
majority_examples[:intruder_index]
+ [intruder_sentence]
+ majority_examples[intruder_index:]
)

activations = [example.activations.tolist() for example in active_examples]
tokens = [example.str_tokens for example in active_examples]
activations.insert(intruder_index, intruder.activations.tolist())
tokens.insert(intruder_index, intruder.str_tokens)
example_activations = [example.activations.tolist() for example in examples]
example_tokens = [example.str_tokens for example in examples]

batches.append(
IntruderSentence(
examples=majority_examples,
examples=examples,
intruder_index=intruder_index,
chosen_quantile=quantile_index,
activations=activations,
tokens=tokens,
chosen_quantile=active_quantile,
activations=example_activations,
tokens=example_tokens,
intruder_distance=intruder.distance,
)
)
Expand Down Expand Up @@ -311,20 +310,11 @@ async def _generate(self, sample: IntruderSentence) -> IntruderResult:
prompt = self._build_prompt(sample)
try:
response = await self.client.generate(prompt, **self.generation_kwargs)
interpretation, prediction = self._parse(response.text)
except Exception as e:
logger.error(f"Error generating text: {e}")
response = None

if response is None:
logger.error(str(e))
# default result is a error
return IntruderResult()
else:
try:
interpretation, prediction = self._parse(response.text)
except Exception as e:
logger.error(f"Parsing selections failed: {e}")
# default result is a error
return IntruderResult()

# check that the only prediction is the intruder
correct = prediction == sample.intruder_index
Expand Down
81 changes: 48 additions & 33 deletions delphi/scorers/classifier/sample.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import random
from collections import deque
from dataclasses import dataclass
from typing import NamedTuple
from itertools import groupby
from typing import Callable, NamedTuple

import torch

Expand Down Expand Up @@ -88,48 +90,60 @@ def _prepare_text(
threshold: float,
highlighted: bool,
) -> tuple[str, list[str]]:
assert n_incorrect >= 0, (
"n_incorrect must be 0 if highlighting correct example "
"or positive if creating false positives. "
f"Got {n_incorrect}"
)

str_toks = example.str_tokens
assert str_toks is not None, "str_toks were not set"
clean = "".join(str_toks)

# Just return text if there's no highlighting
if not highlighted:
clean = "".join(str_toks)

return clean, str_toks

threshold = threshold * example.max_activation
abs_threshold = threshold * example.max_activation

# Highlight tokens with activations above threshold
# if correct example
# if this is a correct example
if n_incorrect == 0:

def threshold_check(i):
return example.activations[i] >= threshold
def is_above_activation_threshold(i: int) -> bool:
return example.activations[i] >= abs_threshold

return _highlight(str_toks, threshold_check), str_toks
return _highlight(str_toks, is_above_activation_threshold), str_toks

# Highlight n_incorrect tokens with activations
# below threshold if incorrect example
below_threshold = torch.nonzero(example.activations <= threshold).squeeze()
# below threshold if this is an incorrect example
tokens_below_threshold = torch.nonzero(
example.activations <= abs_threshold
).squeeze()

# Rare case where there are no tokens below threshold
if below_threshold.dim() == 0:
logger.error("Failed to prepare example.")
if tokens_below_threshold.dim() == 0:
logger.error(
f"Tried to prepare false-positive example with {n_incorrect} tokens "
"incorrectly highlighted, but no tokens were below activation threshold."
)
return DEFAULT_MESSAGE, str_toks

random.seed(22)

n_incorrect = min(n_incorrect, len(below_threshold))
num_tokens_to_highlight = min(n_incorrect, tokens_below_threshold.shape[0])

# The activating token is always ctx_len - ctx_len//4
# so we always highlight this one, and if n_incorrect > 1
# we highlight n_incorrect-1 random ones
# so we always highlight this one, and if num_tokens_to_highlight > 1
# we highlight num_tokens_to_highlight - 1 random ones
token_pos = len(str_toks) - len(str_toks) // 4
if token_pos in below_threshold:
if token_pos in tokens_below_threshold:
random_indices = [token_pos]

num_remaining_tokens_to_highlight = n_incorrect - 1
num_remaining_tokens_to_highlight = num_tokens_to_highlight - 1
if num_remaining_tokens_to_highlight > 0:
remaining_tokens_below_threshold = below_threshold.tolist()
remaining_tokens_below_threshold = tokens_below_threshold.tolist()
remaining_tokens_below_threshold.remove(token_pos)

random_indices.extend(
Expand All @@ -139,31 +153,32 @@ def threshold_check(i):
)
)
else:
random_indices = random.sample(below_threshold.tolist(), n_incorrect)
random_indices = random.sample(
tokens_below_threshold.tolist(), num_tokens_to_highlight
)

random_indices = set(random_indices)

def check(i):
def is_false_positive(i):
return i in random_indices

return _highlight(str_toks, check), str_toks
return _highlight(str_toks, is_false_positive), str_toks


def _highlight(tokens: list[str], check: Callable[[int], bool]) -> str:
result: deque[str] = deque()

def _highlight(tokens, check):
result = []
tokens_grouped_by_check_fn = groupby(
enumerate(tokens), key=lambda item: check(item[0])
)

i = 0
while i < len(tokens):
if check(i):
result.append(L)
for should_highlight, token_group in tokens_grouped_by_check_fn:
highlighted_tokens = deque(token for _token_index, token in token_group)

while i < len(tokens) and check(i):
result.append(tokens[i])
i += 1
if should_highlight:
highlighted_tokens.appendleft(L)
highlighted_tokens.append(R)

result.append(R)
else:
result.append(tokens[i])
i += 1
result.extend(highlighted_tokens)

return "".join(result)
Loading