From 81c0077f998672e989757ad007dfa329def00b64 Mon Sep 17 00:00:00 2001
From: d0rbu
Date: Thu, 9 Oct 2025 16:01:12 -0500
Subject: [PATCH 1/2] miscellaneous small cleanups

---
 delphi/pipeline.py                    |   9 ++-
 delphi/scorers/classifier/intruder.py | 103 ++++++++++++--------------
 delphi/scorers/classifier/sample.py   |  85 +++++++++++++--------
 3 files changed, 104 insertions(+), 93 deletions(-)

diff --git a/delphi/pipeline.py b/delphi/pipeline.py
index 428d817e..91b739cf 100644
--- a/delphi/pipeline.py
+++ b/delphi/pipeline.py
@@ -161,8 +161,9 @@ async def process_item(self, item: Any, semaphore: asyncio.Semaphore) -> Any:
         async with semaphore:
             result = item
             for pipe in self.pipes:
-                if result is not None:
-                    result = await pipe(result)
-                else:
-                    pass
+                if result is None:
+                    break
+
+                result = await pipe(result)
+
             return result
diff --git a/delphi/scorers/classifier/intruder.py b/delphi/scorers/classifier/intruder.py
index 28b36c05..aa1c87f8 100644
--- a/delphi/scorers/classifier/intruder.py
+++ b/delphi/scorers/classifier/intruder.py
@@ -1,6 +1,8 @@
 import asyncio
 import re
+from collections import defaultdict
 from dataclasses import dataclass
+from itertools import cycle
 from typing import Literal
 
 from beartype.typing import Sequence
@@ -136,12 +138,11 @@ def _get_quantiled_examples(
         """
         Get the quantiled examples.
         """
-        quantiles = {}
+        examples_grouped_by_quantiles = defaultdict(list)
         for example in examples:
-            if example.quantile not in quantiles:
-                quantiles[example.quantile] = []
-            quantiles[example.quantile].append(example)
-        return quantiles
+            examples_grouped_by_quantiles[example.quantile].append(example)
+
+        return examples_grouped_by_quantiles
 
     def _prepare_and_batch(self, record: LatentRecord) -> list[IntruderSentence]:
         """
@@ -153,38 +154,39 @@ def _prepare_and_batch(self, record: LatentRecord) -> list[IntruderSentence]:
         quantiled_intruder_sentences = self._get_quantiled_examples(record.test)
 
         intruder_sentences = record.not_active
-        for i, intruder in enumerate(intruder_sentences):
-            # select each quantile equally
-            quantile_index = i % len(quantiled_intruder_sentences.keys())
-            active_examples = quantiled_intruder_sentences[quantile_index]
+        # select each quantile equally by repeatedly cycling through them
+        quantile_iterator = cycle(quantiled_intruder_sentences.items())
+        for (active_quantile, all_active_examples), intruder in zip(
+            quantile_iterator, intruder_sentences
+        ):
             # if there are more examples than the number of examples to show,
             # sample which examples to show
-            examples_to_show = min(self.n_examples_shown - 1, len(active_examples))
-            example_indices = self.rng.sample(
-                range(len(active_examples)), examples_to_show
+            num_active_examples = min(
+                # - 1 because we are going to insert the intruder sentence
+                self.n_examples_shown - 1,
+                len(all_active_examples),
             )
-            active_examples = [active_examples[i] for i in example_indices]
-
-            # convert the examples to strings
+            active_examples = self.rng.sample(all_active_examples, num_active_examples)
 
-            # highlights the active tokens
+            # highlights the active tokens with <<>> markers
             majority_examples = []
-            active_tokens = 0
+            num_active_tokens = 0
             for example in active_examples:
-                text, _ = _prepare_text(
+                text, _str_tokens = _prepare_text(
                     example, n_incorrect=0, threshold=0.3, highlighted=True
                 )
                 majority_examples.append(text)
-                active_tokens += (example.activations > 0).sum().item()
-            active_tokens = int(active_tokens / len(active_examples))
+                num_active_tokens += (example.activations > 0).sum().item()
+
+            avg_active_tokens_per_example = num_active_tokens // len(active_examples)
 
             if self.type == "default":
                 # if example is contrastive, use the active tokens
                 # otherwise use the non-activating tokens
                 if intruder.activations.max() > 0:
                     n_incorrect = 0
                 else:
-                    n_incorrect = active_tokens
+                    n_incorrect = avg_active_tokens_per_example
                 intruder_sentence, _ = _prepare_text(
                     intruder,
                     n_incorrect=n_incorrect,
@@ -194,22 +196,15 @@ def _prepare_and_batch(self, record: LatentRecord) -> list[IntruderSentence]:
             elif self.type == "internal":
                 # randomly select a quantile to be the intruder, make sure it's not
                 # the same as the source quantile
-                intruder_quantile_index = self.rng.randint(
-                    0, len(quantiled_intruder_sentences.keys()) - 1
-                )
-                while intruder_quantile_index == quantile_index:
-                    intruder_quantile_index = self.rng.randint(
-                        0, len(quantiled_intruder_sentences.keys()) - 1
-                    )
-                posible_intruder_sentences = quantiled_intruder_sentences[
-                    intruder_quantile_index
-                ]
-                intruder_index_selected = self.rng.randint(
-                    0, len(posible_intruder_sentences) - 1
-                )
-                intruder = posible_intruder_sentences[intruder_index_selected]
+                all_quantile_examples = list(quantiled_intruder_sentences.values())
+                all_quantile_examples.remove(all_active_examples)
+                possible_intruder_sentences = self.rng.choice(all_quantile_examples)
+
+                intruder = self.rng.choice(possible_intruder_sentences)
                 # here the examples are activating, so we have to convert them
                 # to non-activating examples
+                assert intruder.str_tokens is not None, "intruder has no str_tokens"
+
                 non_activating_intruder = NonActivatingExample(
                     tokens=intruder.tokens,
                     activations=intruder.activations,
@@ -224,23 +219,27 @@ def _prepare_and_batch(self, record: LatentRecord) -> list[IntruderSentence]:
                     highlighted=True,
                 )
                 intruder = non_activating_intruder
+            else:
+                raise ValueError("Invalid intruder scorer type")
 
             # select a random index to insert the intruder sentence
-            intruder_index = self.rng.randint(0, examples_to_show)
-            majority_examples.insert(intruder_index, intruder_sentence)
+            intruder_index = self.rng.randint(0, num_active_examples)
+            examples = (
+                majority_examples[:intruder_index]
+                + [intruder_sentence]
+                + majority_examples[intruder_index:]
+            )
 
-            activations = [example.activations.tolist() for example in active_examples]
-            tokens = [example.str_tokens for example in active_examples]
-            activations.insert(intruder_index, intruder.activations.tolist())
-            tokens.insert(intruder_index, intruder.str_tokens)
+            example_activations = [example.activations.tolist() for example in examples]
+            example_tokens = [example.str_tokens for example in examples]
 
             batches.append(
                 IntruderSentence(
-                    examples=majority_examples,
+                    examples=examples,
                     intruder_index=intruder_index,
-                    chosen_quantile=quantile_index,
-                    activations=activations,
-                    tokens=tokens,
+                    chosen_quantile=active_quantile,
+                    activations=example_activations,
+                    tokens=example_tokens,
                     intruder_distance=intruder.distance,
                 )
             )
@@ -311,21 +310,11 @@ async def _generate(self, sample: IntruderSentence) -> IntruderResult:
         prompt = self._build_prompt(sample)
         try:
             response = await self.client.generate(prompt, **self.generation_kwargs)
+            interpretation, prediction = self._parse(response.text)
         except Exception as e:
-            logger.error(f"Error generating text: {e}")
-            response = None
-
-        if response is None:
+            logger.error(str(e))
             # default result is a error
             return IntruderResult()
-        else:
-
-            try:
-                interpretation, prediction = self._parse(response.text)
-            except Exception as e:
-                logger.error(f"Parsing selections failed: {e}")
{e}") - # default result is a error - return IntruderResult() # check that the only prediction is the intruder correct = prediction == sample.intruder_index diff --git a/delphi/scorers/classifier/sample.py b/delphi/scorers/classifier/sample.py index ac4f90f3..0f8c2240 100644 --- a/delphi/scorers/classifier/sample.py +++ b/delphi/scorers/classifier/sample.py @@ -1,6 +1,8 @@ import random +from collections import deque from dataclasses import dataclass -from typing import NamedTuple +from itertools import groupby +from typing import Callable, NamedTuple import torch @@ -88,73 +90,92 @@ def _prepare_text( threshold: float, highlighted: bool, ) -> tuple[str, list[str]]: + assert n_incorrect >= 0, ( + "n_incorrect must be 0 if highlighting correct example " + "or positive if creating false positives. " + f"Got {n_incorrect}" + ) + str_toks = example.str_tokens assert str_toks is not None, "str_toks were not set" - clean = "".join(str_toks) + # Just return text if there's no highlighting if not highlighted: + clean = "".join(str_toks) + return clean, str_toks - threshold = threshold * example.max_activation + abs_threshold = threshold * example.max_activation # Highlight tokens with activations above threshold - # if correct example + # if this is a correct example if n_incorrect == 0: - def threshold_check(i): - return example.activations[i] >= threshold + def is_above_activation_threshold(i: int) -> bool: + return example.activations[i] >= abs_threshold - return _highlight(str_toks, threshold_check), str_toks + return _highlight(str_toks, is_above_activation_threshold), str_toks # Highlight n_incorrect tokens with activations - # below threshold if incorrect example - below_threshold = torch.nonzero(example.activations <= threshold).squeeze() + # below threshold if this is an incorrect example + tokens_below_threshold = torch.nonzero( + example.activations <= abs_threshold + ).squeeze() # Rare case where there are no tokens below threshold - if below_threshold.dim() == 0: - logger.error("Failed to prepare example.") + if tokens_below_threshold.dim() == 0: + logger.error( + f"Tried to prepare false-positive example with {n_incorrect} tokens " + "incorrectly highlighted, but no tokens were below activation threshold." 
+ ) return DEFAULT_MESSAGE, str_toks random.seed(22) - n_incorrect = min(n_incorrect, len(below_threshold)) + num_tokens_to_highlight = min(n_incorrect, tokens_below_threshold.shape[0]) # The activating token is always ctx_len - ctx_len//4 - # so we always highlight this one, and if n_incorrect > 1 - # we highlight n_incorrect-1 random ones + # so we always highlight this one, and if num_tokens_to_highlight > 1 + # we highlight num_tokens_to_highlight - 1 random ones token_pos = len(str_toks) - len(str_toks) // 4 - if token_pos in below_threshold: + if token_pos in tokens_below_threshold: random_indices = [token_pos] - if n_incorrect > 1: + + num_remaining_tokens_to_highlight = num_tokens_to_highlight - 1 + if num_remaining_tokens_to_highlight > 0: + remaining_tokens_below_threshold = tokens_below_threshold.tolist() + remaining_tokens_below_threshold.remove(token_pos) + random_indices.extend( random.sample(below_threshold.tolist(), n_incorrect - 1) ) else: - random_indices = random.sample(below_threshold.tolist(), n_incorrect) + random_indices = random.sample( + tokens_below_threshold.tolist(), num_tokens_to_highlight + ) random_indices = set(random_indices) - def check(i): + def is_false_positive(i): return i in random_indices - return _highlight(str_toks, check), str_toks + return _highlight(str_toks, is_false_positive), str_toks + +def _highlight(tokens: list[str], check: Callable[[int], bool]) -> str: + result: deque[str] = deque() -def _highlight(tokens, check): - result = [] + tokens_grouped_by_check_fn = groupby( + enumerate(tokens), key=lambda item: check(item[0]) + ) - i = 0 - while i < len(tokens): - if check(i): - result.append(L) + for should_highlight, token_group in tokens_grouped_by_check_fn: + highlighted_tokens = deque(token for _token_index, token in token_group) - while i < len(tokens) and check(i): - result.append(tokens[i]) - i += 1 + if should_highlight: + highlighted_tokens.appendleft(L) + highlighted_tokens.append(R) - result.append(R) - else: - result.append(tokens[i]) - i += 1 + result.extend(highlighted_tokens) return "".join(result) From 06eacd179ec9d4326170eb299fa16d6219846ff1 Mon Sep 17 00:00:00 2001 From: d0rbu Date: Thu, 9 Oct 2025 16:10:50 -0500 Subject: [PATCH 2/2] return None instead of when pipe result is none --- delphi/pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/delphi/pipeline.py b/delphi/pipeline.py index 91b739cf..fcb01319 100644 --- a/delphi/pipeline.py +++ b/delphi/pipeline.py @@ -162,7 +162,7 @@ async def process_item(self, item: Any, semaphore: asyncio.Semaphore) -> Any: result = item for pipe in self.pipes: if result is None: - break + return None result = await pipe(result)
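Aside (illustration only, not part of either patch above): a minimal, runnable sketch of the
groupby-based highlighting that the rewritten _highlight in sample.py relies on. The L and R
values below are assumed stand-ins for the "<<" / ">>" markers used elsewhere in delphi; only
the grouping logic mirrors the patch.

from collections import deque
from itertools import groupby
from typing import Callable

L, R = "<<", ">>"  # assumed highlight markers


def highlight(tokens: list[str], check: Callable[[int], bool]) -> str:
    result: deque[str] = deque()

    # groupby yields maximal runs of consecutive indices where check() is constant,
    # so each highlighted run gets exactly one pair of markers
    for should_highlight, group in groupby(
        enumerate(tokens), key=lambda item: check(item[0])
    ):
        run = deque(token for _index, token in group)

        if should_highlight:
            run.appendleft(L)
            run.append(R)

        result.extend(run)

    return "".join(result)


print(highlight(["The", " cat", " sat", " on", " the", " mat"], lambda i: i in {1, 2, 5}))
# -> The<< cat sat>> on the<< mat>>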