Skip to content

Commit

Permalink
Fix: BERTScoreMRefs metric with 1 cand and 1 mref input.
Browse files Browse the repository at this point in the history
  • Loading branch information
Labbeti committed Dec 14, 2023
1 parent 83e71c1 commit 19f984b
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 2 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ All notable changes to this project will be documented in this file.
### Changed
- Minor refactoring with `_get_device` function.

### Fixed
- Fix `BERTScoreMRefs` metric with 1 candidate and 1 reference.

## [0.5.0] 2023-12-08
### Added
- New `Vocab` metric to compute vocabulary size and vocabulary ratio.
Expand Down
9 changes: 7 additions & 2 deletions src/aac_metrics/functional/bert_score_mrefs.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,14 +85,15 @@ def bert_score_mrefs(
device = _get_device(device)
flat_mrefs, sizes = flat_list(mult_references)
duplicated_cands = duplicate_list(candidates, sizes)
assert len(duplicated_cands) == len(flat_mrefs)

tfmers_verbosity = tfmers_logging.get_verbosity()
if verbose <= 1:
tfmers_logging.set_verbosity_error()

sents_scores = bert_score(
duplicated_cands,
flat_mrefs,
preds=duplicated_cands,
target=flat_mrefs,
model_name_or_path=None,
model=model, # type: ignore
user_tokenizer=tokenizer,
Expand All @@ -107,6 +108,10 @@ def bert_score_mrefs(
# Restore previous verbosity level
tfmers_logging.set_verbosity(tfmers_verbosity)

# note: torchmetrics returns a float if input contains 1 cand and 1 ref, even in list
if len(duplicated_cands) == 1 and all(isinstance(v, float) for v in sents_scores.values()):
sents_scores = {k: [v] for k, v in sents_scores.items()}

# sents_scores keys: "precision", "recall", "f1"
sents_scores = {k: unflat_list(v, sizes) for k, v in sents_scores.items()} # type: ignore

Expand Down

0 comments on commit 19f984b

Please sign in to comment.