In [1]:
from copy import deepcopy
import itertools
import jsonlines
from pathlib import Path

from sklearn.metrics import classification_report
from sklearn.metrics import cohen_kappa_score

from tasks.pattern_matching import WinomtPatternMatchingTask
from tasks.pattern_matching.winomt_utils.language_predictors.util import WB_GENDER_TYPES, GENDER
from tasks.contrastive_conditioning import WinomtContrastiveConditioningTask
from translation_models.fairseq_models import load_sota_evaluator
from translation_models.testing_models import DictTranslationModel

In [2]:
# Input files, all located in the "data" directory relative to the working directory
data_path = Path("data")
winomt_enru_translations_path = data_path.joinpath("google.ru.full.txt")
winomt_enru_annotator1_path = data_path.joinpath("en-ru.annotator1.jsonl")
winomt_enru_annotator2_path = data_path.joinpath("en-ru.annotator2.jsonl")

In [3]:
# Load annotations
def _read_annotations(path):
    """Read one annotator's JSON Lines file into a dict keyed by "Sample ID".

    After loading, the "label" field of every record is flattened, i.e. replaced
    by its first element.
    """
    with jsonlines.open(path) as reader:
        records = {record["Sample ID"]: record for record in reader}
    for record in records.values():
        record["label"] = record["label"][0]
    return records


annotations1 = _read_annotations(winomt_enru_annotator1_path)
annotations2 = _read_annotations(winomt_enru_annotator2_path)

In [4]:
# Keep only samples that both annotators labelled; the dicts are pruned in place
for sample_id in annotations1.keys() - annotations2.keys():
    del annotations1[sample_id]
for sample_id in annotations2.keys() - annotations1.keys():
    del annotations2[sample_id]

In [5]:
# Inter-annotator agreement (Cohen's kappa) on the raw labels, before data cleaning.
# `keys` fixes a common sample order so that both label lists are aligned.
keys = list(annotations1)
labels1 = [annotations1[sample_id]["label"] for sample_id in keys]
labels2 = [annotations2[sample_id]["label"] for sample_id in keys]
kappa = cohen_kappa_score(labels1, labels2)
print(kappa)

0.13114217077964607


In [6]:
# Clean data: map the two non-gender labels onto "Male"/"Female" relative to the gold gender
for annotations in (annotations1, annotations2):
    for key in keys:
        record = annotations[key]
        if record["label"] == "Both / Neutral / Ambiguous":
            # Treat neutral as correct: adopt the (title-cased) gold gender
            record["label"] = record["Gold Gender"].title()
        elif record["label"] == "Translation too bad to tell":
            # Treat bad as wrong: pick the label opposite to a "female" gold gender,
            # and "Female" for every other gold gender
            record["label"] = "Female" if record["Gold Gender"] != "female" else "Male"

In [7]:
# Inter-annotator agreement (Cohen's kappa) recomputed on the cleaned labels
keys = [*annotations1]
labels1, labels2 = [], []
for sample_id in keys:
    labels1.append(annotations1[sample_id]["label"])
    labels2.append(annotations2[sample_id]["label"])
kappa = cohen_kappa_score(labels1, labels2)
print(kappa)

0.2018722773194922


In [8]:
# Merge annotations: all records of annotator 1 followed by all records of annotator 2
annotations = [*annotations1.values(), *annotations2.values()]

In [9]:
# Load translations
# Each line is expected to look like "<source sentence> ||| <translation>"; the result maps
# the stripped source sentence to the stripped translation.
# The encoding is passed explicitly: the file is expected to hold Russian text, and relying
# on the platform default encoding can fail or mis-decode on systems that do not use UTF-8.
translations = {}
with open(winomt_enru_translations_path, encoding="utf-8") as f:
    for line in f:
        # Split once per line and reuse both fields
        fields = line.split(" ||| ")
        translations[fields[0].strip()] = fields[1].strip()

In [10]:
# Run classic (pattern-matching) WinoMT on the pre-computed translations
winomt_pattern_matching = WinomtPatternMatchingTask(tgt_language="ru", skip_neutral_gold=False, verbose=True)
# The translations dict is wrapped in DictTranslationModel before being passed to the task
pattern_matching_translation_model = DictTranslationModel(translations)
pattern_matching_result = winomt_pattern_matching.evaluate(pattern_matching_translation_model)
pattern_matching_evaluated_samples = pattern_matching_result.samples

3888it [00:00, 823608.79it/s]


In [11]:
# Run contrastive conditioning on the same translations, with category-wise weighting enabled
evaluator_model = load_sota_evaluator("ru")
winomt_contrastive_conditioning = WinomtContrastiveConditioningTask(
    evaluator_model=evaluator_model, skip_neutral_gold=False, category_wise_weighting=True
)
contrastive_conditioning_translation_model = DictTranslationModel(translations)
contrastive_conditioning_result = winomt_contrastive_conditioning.evaluate(
    contrastive_conditioning_translation_model
)
contrastive_conditioning_weighted_evaluated_samples = contrastive_conditioning_result.samples

Using cache found in /home/user/vamvas/.cache/torch/hub/pytorch_fairseq_master
Loading codes from /home/user/vamvas/.cache/torch/pytorch_fairseq/15bca559d0277eb5c17149cc7e808459c6e307e5dfbb296d0cf1cfe89bb665d7.ded47c1b3054e7b2d78c0b86297f36a170b7d2e7980d8c29003634eb58d973d9/bpecodes ...
Read 30000 codes from the codes file.
Loading codes from /home/user/vamvas/.cache/torch/hub/en24k.fastbpe.code ...
Read 24000 codes from the codes file.
Loading codes from /home/user/vamvas/.cache/torch/hub/ru24k.fastbpe.code ...
Read 24000 codes from the codes file.
100%|████████████████████████████████████████████████████████████| 3888/3888 [02:31<00:00, 25.68it/s]
100%|████████████████████████████████████████████████████████████| 3888/3888 [02:17<00:00, 28.20it/s]


In [12]:
# Create unweighted contrastive conditioning samples: an independent deep copy of the weighted
# samples in which every sample gets weight 1 (the weighted originals are left untouched)
contrastive_conditioning_unweighted_evaluated_samples = deepcopy(
    contrastive_conditioning_weighted_evaluated_samples
)
for unweighted_sample in contrastive_conditioning_unweighted_evaluated_samples:
    setattr(unweighted_sample, "weight", 1)

In [13]:
# Evaluate each automatic method against the merged human annotations
def _predicted_gender_value(evaluated_sample):
    """Return the gender value that an evaluated sample counts as predicted.

    Samples that carry a `predicted_gender` attribute use its value, except that
    neutral or unknown predictions are replaced by the sample's gold gender in
    order to treat classic WinoMT as fairly as possible.

    Samples without that attribute fall back to `is_correct`: a correct sample
    yields the gold gender value, an incorrect one yields `int(not gold_value)`.
    """
    if hasattr(evaluated_sample, "predicted_gender"):
        value = evaluated_sample.predicted_gender.value
        if value in {GENDER.neutral.value, GENDER.unknown.value}:
            return evaluated_sample.gold_gender.value
        return value
    if evaluated_sample.is_correct:
        return WB_GENDER_TYPES[evaluated_sample.gold_gender].value
    # presumably flips between the two binary gender values — TODO confirm the encoding
    return int(not WB_GENDER_TYPES[evaluated_sample.gold_gender].value)


# Only the first two members of GENDER are included in the reports
class_labels = [gender.value for gender in GENDER][:2]
target_names = [gender.name for gender in GENDER][:2]

for evaluated_samples in (
    pattern_matching_evaluated_samples,
    contrastive_conditioning_unweighted_evaluated_samples,
    contrastive_conditioning_weighted_evaluated_samples,
):
    predicted_labels, gold_labels, weights = [], [], []
    for annotation in annotations:
        # The human label is the reference ("gold") label for the report
        gold_labels.append(WB_GENDER_TYPES[annotation["label"].lower()].value)
        evaluated_sample = evaluated_samples[int(annotation["Index"])]
        # Guard against misaligned indices between annotations and evaluated samples
        assert evaluated_sample.sentence == annotation["Source Sentence"]
        predicted_labels.append(_predicted_gender_value(evaluated_sample))
        # Samples without a `weight` attribute count with weight 1
        weights.append(getattr(evaluated_sample, "weight", 1))
    report = classification_report(
        y_true=gold_labels,
        y_pred=predicted_labels,
        labels=class_labels,
        target_names=target_names,
        sample_weight=weights,
        zero_division=True,
        digits=3,
    )
    print(report)

              precision    recall  f1-score   support

        male      0.826     0.844     0.835     321.0
      female      0.537     0.504     0.520     115.0

    accuracy                          0.755     436.0
   macro avg      0.682     0.674     0.678     436.0
weighted avg      0.750     0.755     0.752     436.0

              precision    recall  f1-score   support

        male      0.769     0.900     0.829     321.0
      female      0.467     0.243     0.320     115.0

    accuracy                          0.727     436.0
   macro avg      0.618     0.572     0.575     436.0
weighted avg      0.689     0.727     0.695     436.0

              precision    recall  f1-score   support

        male      0.801     0.960     0.873  297705.0
      female      0.522     0.155     0.239   84127.0

    accuracy                          0.782  381832.0
   macro avg      0.661     0.558     0.556  381832.0
weighted avg      0.739     0.782     0.733  381832.0

