Skip to content

Commit

Permalink
Merge pull request #40 from mjpost/sacrebleu
Browse files Browse the repository at this point in the history
Added sacrebleu datasets (closes #39)
  • Loading branch information
ricardorei committed Nov 13, 2021
2 parents 8723688 + 98751b7 commit a4c2cf1
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 12 deletions.
7 changes: 7 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,16 @@ pip install unbabel-comet==1.0.0rc9
To develop locally install [Poetry](https://python-poetry.org/docs/#installation) and run the following commands:
```bash
git clone https://github.com/Unbabel/COMET
cd COMET
poetry install
```

Alternatively, for development, you can run the CLI tools directly, e.g.,

```bash
PYTHONPATH=. ./comet/cli/score.py
```

## Scoring MT outputs:

### Via Bash:
Expand Down
40 changes: 34 additions & 6 deletions comet/cli/compare.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#!/usr/bin/env python3

# Copyright (C) 2020 Unbabel
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand All @@ -18,13 +20,16 @@
optional arguments:
-h, --help Show this help message and exit.
-s SOURCES, --sources SOURCES
(required, type: Path_fr)
(required unless using -d, type: Path_fr)
-x SYSTEM_X, --system_x SYSTEM_X
(required, type: Path_fr)
-y SYSTEM_Y, --system_y SYSTEM_Y
(required, type: Path_fr)
-r REFERENCES, --references REFERENCES
(type: Path_fr, default: null)
(type: Path_fr, default: None)
-d SACREBLEU_TESTSET, --sacrebleu_dataset SACREBLEU_TESTSET
(optional, use in place of -s and -r, type: str
format TESTSET:LANGPAIR, e.g., wmt20:en-de)
--batch_size BATCH_SIZE
(type: int, default: 8)
--gpus GPUS (type: int, default: 1)
Expand All @@ -37,7 +42,7 @@
COMET model to be used. (type: Union[str, Path_fr], default: wmt20-comet-da)
--seed_everything SEED_EVERYTHING
Prediction seed. (type: int, default: 12)
"""

import json
Expand All @@ -54,15 +59,16 @@
_REFLESS_MODELS = ["comet-qe"] # All reference-free metrics are named with 'comet-qe'
# Due to small numerical differences in scores, we consider any system comparison
# with a difference below EPS to be a tie.
EPS = 0.0005
EPS = 0.0005


def compare_command() -> None:
parser = ArgumentParser(description="Command for comparing two MT systems.")
parser.add_argument("-s", "--sources", type=Path_fr, required=True)
parser.add_argument("-s", "--sources", type=Path_fr)
parser.add_argument("-x", "--system_x", type=Path_fr, required=True)
parser.add_argument("-y", "--system_y", type=Path_fr, required=True)
parser.add_argument("-r", "--references", type=Path_fr)
parser.add_argument("-d", "--sacrebleu_dataset", type=str)
parser.add_argument("--batch_size", type=int, default=8)
parser.add_argument("--gpus", type=int, default=1)
parser.add_argument(
Expand Down Expand Up @@ -100,10 +106,28 @@ def compare_command() -> None:
cfg = parser.parse_args()
seed_everything(cfg.seed_everything)

if cfg.sources is None and cfg.sacrebleu_dataset is None:
parser.error(f"You must specify a source (-s) or a sacrebleu dataset (-d)")

if (cfg.sacrebleu_dataset is not None):
if cfg.references is not None or cfg.sources is not None:
parser.error(f"Cannot use sacrebleu datasets (-d) with manually-specified datasets (-s and -r)")

try:
testset, langpair = cfg.sacrebleu_dataset.rsplit(":", maxsplit=1)
cfg.sources = Path_fr(get_source_file(testset, langpair))
cfg.references = Path_fr(get_reference_files(testset, langpair)[0])
except ValueError:
parser.error("SacreBLEU testset format must be TESTSET:LANGPAIR, e.g., wmt20:de-en")
except Exception as e:
import sys
print("SacreBLEU error:", e, file=sys.stderr)
sys.exit(1)

if (cfg.references is None) and (
not any([i in cfg.model for i in _REFLESS_MODELS])
):
parser.error("{} requires -r/--references.".format(cfg.model))
parser.error("{} requires -r/--references or -d/--sacrebleu_dataset.".format(cfg.model))

model_path = (
download_model(cfg.model) if cfg.model in available_metrics else cfg.model
Expand Down Expand Up @@ -191,3 +215,7 @@ def compare_command() -> None:
with open(cfg.to_json, "w") as outfile:
json.dump(data, outfile, ensure_ascii=False, indent=4)
print("Predictions saved in: {}.".format(cfg.to_json))


if __name__ == "__main__":
compare_command()
40 changes: 34 additions & 6 deletions comet/cli/score.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#!/usr/bin/env python3

# Copyright (C) 2020 Unbabel
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand All @@ -18,11 +20,14 @@
optional arguments:
-h, --help Show this help message and exit.
-s SOURCES, --sources SOURCES
(required, type: Path_fr)
(required unless using -d, type: Path_fr)
-t TRANSLATIONS, --translations TRANSLATIONS
(required, type: Path_fr)
-r REFERENCES, --references REFERENCES
(required, type: Path_fr)
(type: Path_fr, default: None)
-d SACREBLEU_TESTSET, --sacrebleu_dataset SACREBLEU_TESTSET
(optional, use in place of -s and -r, type: str
format TESTSET:LANGPAIR, e.g., wmt20:en-de)
--to_json TO_JSON (type: Union[bool, str], default: False)
--model MODEL (type: Union[str, Path_fr], default: wmt21-large-estimator)
--batch_size BATCH_SIZE
Expand All @@ -38,16 +43,17 @@
from jsonargparse import ArgumentParser
from jsonargparse.typing import Path_fr
from pytorch_lightning import seed_everything

from sacrebleu.utils import get_source_file, get_reference_files

_REFLESS_MODELS = ["comet-qe"]


def score_command() -> None:
parser = ArgumentParser(description="Command for scoring MT systems.")
parser.add_argument("-s", "--sources", type=Path_fr, required=True)
parser.add_argument("-t", "--translations", type=Path_fr, required=True)
parser.add_argument("-s", "--sources", type=Path_fr)
parser.add_argument("-t", "--translations", type=Path_fr)
parser.add_argument("-r", "--references", type=Path_fr)
parser.add_argument("-d", "--sacrebleu_dataset", type=str)
parser.add_argument("--batch_size", type=int, default=8)
parser.add_argument("--gpus", type=int, default=1)
parser.add_argument(
Expand Down Expand Up @@ -79,10 +85,28 @@ def score_command() -> None:
cfg = parser.parse_args()
seed_everything(cfg.seed_everything)

if cfg.sources is None and cfg.sacrebleu_dataset is None:
parser.error(f"You must specify a source (-s) or a sacrebleu dataset (-d)")

if (cfg.sacrebleu_dataset is not None):
if cfg.references is not None or cfg.sources is not None:
parser.error(f"Cannot use sacrebleu datasets (-d) with manually-specified datasets (-s and -r)")

try:
testset, langpair = cfg.sacrebleu_dataset.rsplit(":", maxsplit=1)
cfg.sources = Path_fr(get_source_file(testset, langpair))
cfg.references = Path_fr(get_reference_files(testset, langpair)[0])
except ValueError:
parser.error("SacreBLEU testset format must be TESTSET:LANGPAIR, e.g., wmt20:de-en")
except Exception as e:
import sys
print("SacreBLEU error:", e, file=sys.stderr)
sys.exit(1)

if (cfg.references is None) and (
not any([i in cfg.model for i in _REFLESS_MODELS])
):
parser.error("{} requires -r/--references.".format(cfg.model))
parser.error("{} requires -r/--references or -d/--sacrebleu_dataset.".format(cfg.model))

model_path = (
download_model(cfg.model) if cfg.model in available_metrics else cfg.model
Expand Down Expand Up @@ -130,3 +154,7 @@ def score_command() -> None:
with open(cfg.to_json, "w") as outfile:
json.dump(data, outfile, ensure_ascii=False, indent=4)
print("Predictions saved in: {}.".format(cfg.to_json))


if __name__ == "__main__":
score_command()
6 changes: 6 additions & 0 deletions comet/cli/train.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#!/usr/bin/env python3

# Copyright (C) 2020 Unbabel
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -103,3 +105,7 @@ def train_command() -> None:
raise Exception("Model configurations missing!")

trainer.fit(model)


if __name__ == "__main__":
train_command()
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,5 @@ pandas==1.1.5
transformers==4.8.2
pytorch-lightning==1.3.5
jsonargparse==3.13.1
sacrebleu>=2.0.0
torchmetrics==0.6.0

0 comments on commit a4c2cf1

Please sign in to comment.