Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -183,3 +183,4 @@ results/
statistics/
.embedding_cache/
wandb/
uv.lock
3 changes: 1 addition & 2 deletions delphi/scorers/classifier/intruder.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ def _build_prompt(
"""

examples = "\n".join(
f"Example {i}: {example}" for i, example in enumerate(sample.examples)
f"Example {i}:{example}" for i, example in enumerate(sample.examples)
)

return self.prompt(examples=examples)
Expand Down Expand Up @@ -319,7 +319,6 @@ async def _generate(self, sample: IntruderSentence) -> IntruderResult:
# default result is an error
return IntruderResult()
else:

try:
interpretation, prediction = self._parse(response.text)
except Exception as e:
Expand Down
13 changes: 11 additions & 2 deletions delphi/scorers/classifier/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ def _prepare_text(
str_toks = example.str_tokens
assert str_toks is not None, "str_toks were not set"
clean = "".join(str_toks)

# Just return text if there's no highlighting
if not highlighted:
return clean, str_toks
Expand Down Expand Up @@ -125,9 +126,17 @@ def threshold_check(i):
token_pos = len(str_toks) - len(str_toks) // 4
if token_pos in below_threshold:
random_indices = [token_pos]
if n_incorrect > 1:

num_remaining_tokens_to_highlight = n_incorrect - 1
if num_remaining_tokens_to_highlight > 0:
remaining_tokens_below_threshold = below_threshold.tolist()
remaining_tokens_below_threshold.remove(token_pos)

random_indices.extend(
random.sample(below_threshold.tolist(), n_incorrect - 1)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the old code, this could result in `token_pos` being selected again, since it is still in `below_threshold`. Then, after being turned into a set, `random_indices` would have one fewer element than expected, resulting in one fewer token being incorrectly highlighted than was specified by `n_incorrect`.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch

random.sample(
remaining_tokens_below_threshold,
num_remaining_tokens_to_highlight,
)
)
else:
random_indices = random.sample(below_threshold.tolist(), n_incorrect)
Expand Down
Loading