Skip to content

Commit

Permalink
Fix: SPIDEr and SPIDErMax outputs and add test to check if SPIDErMax …
Browse files Browse the repository at this point in the history
…gives the same output as SPIDEr when there is 1 candidate per audio.
  • Loading branch information
Labbeti committed Nov 27, 2023
1 parent ffa6698 commit 9e54e2b
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 13 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ All notable changes to this project will be documented in this file.

### Fixed
- METEOR localization issue. ([#9](https://github.com/Labbeti/aac-metrics/issues/9))
- SPIDErMax output when `return_all_scores=False`.

## [0.4.6] 2023-10-10
### Added
Expand Down
7 changes: 5 additions & 2 deletions src/aac_metrics/functional/mult_cands.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
from torch import Tensor


SELECTIONS = ("max", "min", "mean")


def mult_cands_metric(
metric: Callable,
metric_out_name: str,
Expand All @@ -31,7 +34,6 @@ def mult_cands_metric(
:param **kwargs: The keywords arguments given to the metric call.
:returns: A tuple of globals and locals scores or a scalar tensor with the main global score.
"""
SELECTIONS = ("max", "min", "mean")
if selection not in SELECTIONS:
raise ValueError(
f"Invalid argument {selection=}. (expected one of {SELECTIONS})"
Expand Down Expand Up @@ -106,4 +108,5 @@ def mult_cands_metric(
if return_all_scores:
return outs_corpus, outs_sents
else:
return outs_corpus[metric_out_name]
out_key = f"{metric_out_name}_{selection}"
return outs_corpus[out_key]
6 changes: 3 additions & 3 deletions src/aac_metrics/functional/spider.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,12 @@ def spider(
f"Number of candidates and mult_references are different (found {len(candidates)} != {len(mult_references)})."
)

return_all_scores = True
sub_return_all_scores = True

cider_d_outs: tuple[dict[str, Tensor], dict[str, Tensor]] = cider_d( # type: ignore
candidates=candidates,
mult_references=mult_references,
return_all_scores=return_all_scores,
return_all_scores=sub_return_all_scores,
n=n,
sigma=sigma,
tokenizer=tokenizer,
Expand All @@ -79,7 +79,7 @@ def spider(
spice_outs: tuple[dict[str, Tensor], dict[str, Tensor]] = spice( # type: ignore
candidates=candidates,
mult_references=mult_references,
return_all_scores=return_all_scores,
return_all_scores=sub_return_all_scores,
cache_path=cache_path,
java_path=java_path,
tmp_path=tmp_path,
Expand Down
16 changes: 8 additions & 8 deletions src/aac_metrics/functional/spider_max.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,14 +67,14 @@ def spider_max(
:returns: A tuple of globals and locals scores or a scalar tensor with the main global score.
"""
return mult_cands_metric(
spider,
"spider",
mult_candidates,
mult_references,
return_all_scores,
return_all_cands_scores,
"max",
torch.mean,
metric=spider,
metric_out_name="spider",
mult_candidates=mult_candidates,
mult_references=mult_references,
return_all_scores=return_all_scores,
return_all_cands_scores=return_all_cands_scores,
selection="max",
reduction=torch.mean,
# CIDEr args
n=n,
sigma=sigma,
Expand Down
55 changes: 55 additions & 0 deletions tests/test_sdmax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import unittest

from unittest import TestCase

import torch

from torch import Tensor

from aac_metrics.classes.spider import SPIDEr
from aac_metrics.classes.spider_max import SPIDErMax


class TestSPIDErMax(TestCase):
    """Check that SPIDErMax is consistent with SPIDEr.

    With exactly one candidate per audio, the max over candidates is the
    candidate itself, so SPIDErMax must reproduce the SPIDEr score.
    """

    # Tests methods
    def test_sd_vs_sdmax(self) -> None:
        """SPIDErMax with 1 candidate per audio must equal plain SPIDEr."""
        sd = SPIDEr(return_all_scores=False)
        sdmax = SPIDErMax(return_all_scores=False)

        cands, mrefs = self._get_example_0()
        # Wrap each candidate in a singleton list: SPIDErMax expects a list of
        # candidates per audio, and max over one element is that element.
        mcands = [[cand] for cand in cands]

        sd_score = sd(cands, mrefs)
        sdmax_score = sdmax(mcands, mrefs)

        # Use unittest assertions consistently: bare `assert` is stripped when
        # Python runs with -O, and the original mixed it with assertTrue below.
        self.assertIsInstance(sd_score, Tensor)
        self.assertIsInstance(sdmax_score, Tensor)
        self.assertTrue(
            torch.allclose(sd_score, sdmax_score), f"{sd_score=}, {sdmax_score=}"
        )

    def _get_example_0(self) -> tuple[list[str], list[list[str]]]:
        """Return a small fixed (candidates, multi-references) example with 3 audios."""
        cands = [
            "a man is speaking",
            "birds chirping",
            "rain is falling in the background",
        ]
        mrefs = [
            [
                "man speaks",
                "man is speaking",
                "a man speaks",
                "man talks",
                "someone is talking",
            ],
            ["a bird is chirping"] * 5,
            ["heavy rain noise"] * 5,
        ]
        return cands, mrefs


# Allow running this test file directly (e.g. `python tests/test_sdmax.py`)
# in addition to discovery via pytest/unittest.
if __name__ == "__main__":
    unittest.main()

0 comments on commit 9e54e2b

Please sign in to comment.