diff --git a/.gitignore b/.gitignore index fe7bd9e1..3c86fc59 100644 --- a/.gitignore +++ b/.gitignore @@ -183,3 +183,4 @@ results/ statistics/ .embedding_cache/ wandb/ +uv.lock diff --git a/delphi/scorers/classifier/intruder.py b/delphi/scorers/classifier/intruder.py index 28b36c05..6c143a3b 100644 --- a/delphi/scorers/classifier/intruder.py +++ b/delphi/scorers/classifier/intruder.py @@ -275,7 +275,7 @@ def _build_prompt( """ examples = "\n".join( - f"Example {i}: {example}" for i, example in enumerate(sample.examples) + f"Example {i}:{example}" for i, example in enumerate(sample.examples) ) return self.prompt(examples=examples) @@ -319,7 +319,6 @@ async def _generate(self, sample: IntruderSentence) -> IntruderResult: # default result is a error return IntruderResult() else: - try: interpretation, prediction = self._parse(response.text) except Exception as e: diff --git a/delphi/scorers/classifier/sample.py b/delphi/scorers/classifier/sample.py index ac4f90f3..40990e40 100644 --- a/delphi/scorers/classifier/sample.py +++ b/delphi/scorers/classifier/sample.py @@ -91,6 +91,7 @@ def _prepare_text( str_toks = example.str_tokens assert str_toks is not None, "str_toks were not set" clean = "".join(str_toks) + # Just return text if there's no highlighting if not highlighted: return clean, str_toks @@ -125,9 +126,17 @@ def threshold_check(i): token_pos = len(str_toks) - len(str_toks) // 4 if token_pos in below_threshold: random_indices = [token_pos] - if n_incorrect > 1: + + num_remaining_tokens_to_highlight = n_incorrect - 1 + if num_remaining_tokens_to_highlight > 0: + remaining_tokens_below_threshold = below_threshold.tolist() + remaining_tokens_below_threshold.remove(token_pos) + random_indices.extend( - random.sample(below_threshold.tolist(), n_incorrect - 1) + random.sample( + remaining_tokens_below_threshold, + num_remaining_tokens_to_highlight, + ) ) else: random_indices = random.sample(below_threshold.tolist(), n_incorrect)