In [1]:
from copy import deepcopy
import itertools
import jsonlines
from pathlib import Path

from sklearn.metrics import classification_report
from sklearn.metrics import cohen_kappa_score

from tasks.pattern_matching import WinomtPatternMatchingTask
from tasks.pattern_matching.winomt_utils.language_predictors.util import WB_GENDER_TYPES, GENDER
from tasks.contrastive_conditioning import WinomtContrastiveConditioningTask
from translation_models.fairseq_models import load_sota_evaluator
from translation_models.testing_models import DictTranslationModel

In [2]:
# Define data paths
# Every input file sits in the "data" directory, resolved relative to the working directory.
data_path = Path(".").joinpath("data")
winomt_ende_translations_path = data_path.joinpath("aws.de.full.txt")
winomt_ende_annotator1_path = data_path.joinpath("en-de.annotator1.jsonl")
winomt_ende_annotator2_path = data_path.joinpath("en-de.annotator2.jsonl")

In [3]:
# Load annotations
def _read_annotations(path):
    """Read one annotator's JSON-lines file into a dict keyed by "Sample ID".

    After loading, each record's "label" entry is replaced by its first
    element (flattening the label container to a single label).
    """
    with jsonlines.open(path) as reader:
        records = {record["Sample ID"]: record for record in reader}
    # Flatten labels
    for record in records.values():
        record["label"] = record["label"][0]
    return records


annotations1 = _read_annotations(winomt_ende_annotator1_path)
annotations2 = _read_annotations(winomt_ende_annotator2_path)

In [4]:
# Remove samples that were only partially annotated
# Only sample IDs present for both annotators survive; both dicts are pruned in place.
shared_sample_ids = annotations1.keys() & annotations2.keys()
for annotation_set in (annotations1, annotations2):
    # Collect first, delete afterwards: a dict must not change size while being iterated
    partially_annotated = [
        sample_id for sample_id in annotation_set if sample_id not in shared_sample_ids
    ]
    for sample_id in partially_annotated:
        del annotation_set[sample_id]

In [5]:
# Inter-annotator agreement before data cleaning
# Both label lists follow the same sample-ID order so that they are paired correctly.
keys = [*annotations1]
labels1 = [annotations1[sample_id]["label"] for sample_id in keys]
labels2 = [annotations2[sample_id]["label"] for sample_id in keys]
kappa = cohen_kappa_score(labels1, labels2)
print(kappa)

0.9525335231992406


In [6]:
# Clean data
# Rewrite the two non-gender labels as "Male"/"Female" relative to each sample's gold gender.
for annotation_set in (annotations1, annotations2):
    for key in keys:
        record = annotation_set[key]
        if record["label"] == "Both / Neutral / Ambiguous":
            # Treat neutral as correct
            record["label"] = record["Gold Gender"].title()
        elif record["label"] == "Translation too bad to tell":
            # Treat bad as incorrect
            record["label"] = "Male" if record["Gold Gender"] == "female" else "Female"

In [7]:
# Inter-annotator agreement after data cleaning
# Same computation as before cleaning, now over the rewritten labels.
keys = [*annotations1]
labels1 = [annotations1[sample_id]["label"] for sample_id in keys]
labels2 = [annotations2[sample_id]["label"] for sample_id in keys]
kappa = cohen_kappa_score(labels1, labels2)
print(kappa)

0.9526963103122044


In [8]:
# Merge annotations
# One flat list: all records of annotator 1 followed by all records of annotator 2.
annotations = [*annotations1.values(), *annotations2.values()]

In [9]:
# Load translations
# Each non-blank line is expected to have the form "<source> ||| <translation>";
# only the first two " ||| "-separated fields are used.
translations = {}
# Read explicitly as UTF-8 so decoding of non-ASCII characters does not depend on the
# platform's default encoding (assumes the file is UTF-8 encoded -- TODO confirm).
with open(winomt_ende_translations_path, encoding="utf-8") as f:
    for line_number, line in enumerate(f, start=1):
        if not line.strip():
            # Tolerate blank lines (e.g. a trailing empty line) instead of crashing
            continue
        # Split once per line and reuse the result for both key and value
        fields = line.split(" ||| ")
        if len(fields) < 2:
            raise ValueError(
                f"{winomt_ende_translations_path}:{line_number}: "
                "expected '<source> ||| <translation>'"
            )
        translations[fields[0].strip()] = fields[1].strip()

In [10]:
# Run classic (pattern-matching) WinoMT
# The pre-computed translations are wrapped in a DictTranslationModel and handed to the task.
winomt_pattern_matching = WinomtPatternMatchingTask(tgt_language="de", skip_neutral_gold=False, verbose=True)
pattern_matching_result = winomt_pattern_matching.evaluate(DictTranslationModel(translations))
pattern_matching_evaluated_samples = pattern_matching_result.samples

3888it [00:00, 42008.08it/s]


In [11]:
# Run contrastive conditioning
# The evaluator model is loaded for "de"; category-wise weighting is enabled for this run.
evaluator_model = load_sota_evaluator("de")
winomt_contrastive_conditioning = WinomtContrastiveConditioningTask(
    evaluator_model=evaluator_model, skip_neutral_gold=False, category_wise_weighting=True
)
contrastive_conditioning_result = winomt_contrastive_conditioning.evaluate(
    DictTranslationModel(translations)
)
contrastive_conditioning_weighted_evaluated_samples = contrastive_conditioning_result.samples

Using cache found in /home/user/vamvas/.cache/torch/hub/pytorch_fairseq_master
Loading codes from /home/user/vamvas/.cache/torch/pytorch_fairseq/0695ef328ddefcb8cbcfabc3196182f59c0e41e0468b10cc0db2ae9c91881fcc.bb1be17de4233e13870bd7d6065bfdb03fca0a51dd0f5d0b7edf5c188eda71f1/bpecodes ...
Read 30000 codes from the codes file.
100%|████████████████████████████████████████████████████████████| 3888/3888 [02:36<00:00, 24.77it/s]
100%|████████████████████████████████████████████████████████████| 3888/3888 [02:28<00:00, 26.12it/s]


In [12]:
# Create unweighted contrastive conditioning samples
# Work on a deep copy so the weighted samples keep their original weights.
contrastive_conditioning_unweighted_evaluated_samples = deepcopy(
    contrastive_conditioning_weighted_evaluated_samples
)
for unweighted_sample in contrastive_conditioning_unweighted_evaluated_samples:
    unweighted_sample.weight = 1

In [13]:
# Evaluate
# Compare each method's per-sample output with the merged human annotations and
# print one classification report per method.

# Only the first two GENDER members are reported as classes.
reported_genders = list(GENDER)[:2]
class_labels = [gender.value for gender in reported_genders]
target_names = [gender.name for gender in reported_genders]


def _predicted_label(evaluated_sample):
    """Return the gender label value that a method assigned to one evaluated sample.

    Samples exposing `predicted_gender` use that prediction directly; other
    samples only expose `is_correct`, so the label is derived from the gold gender.
    """
    if hasattr(evaluated_sample, "predicted_gender"):
        label = evaluated_sample.predicted_gender.value
        # Convert neutral or unknown to gold in order to treat pattern-matching WinoMT as fairly as possible
        if label in {GENDER.neutral.value, GENDER.unknown.value}:
            label = evaluated_sample.gold_gender.value
        return label
    correct = evaluated_sample.is_correct
    gold_value = WB_GENDER_TYPES[evaluated_sample.gold_gender].value
    # An incorrect sample gets the opposite label; `int(not ...)` presumes the two
    # gender values are 0 and 1 -- NOTE(review): confirm against the GENDER enum.
    return gold_value if correct else int(not gold_value)


for evaluated_samples in (
    pattern_matching_evaluated_samples,
    contrastive_conditioning_unweighted_evaluated_samples,
    contrastive_conditioning_weighted_evaluated_samples,
):
    gold_labels, predicted_labels, weights = [], [], []
    for annotation in annotations:
        gold_labels.append(WB_GENDER_TYPES[annotation["label"].lower()].value)
        evaluated_sample = evaluated_samples[int(annotation["Index"])]
        # Guard against misaligned indices between annotations and evaluated samples
        assert evaluated_sample.sentence == annotation["Source Sentence"]
        predicted_labels.append(_predicted_label(evaluated_sample))
        # Samples without a `weight` attribute count with weight 1
        weights.append(getattr(evaluated_sample, "weight", 1))
    report = classification_report(
        y_true=gold_labels,
        y_pred=predicted_labels,
        labels=class_labels,
        target_names=target_names,
        sample_weight=weights,
        zero_division=True,
        digits=3,
    )
    print(report)

              precision    recall  f1-score   support

        male      0.962     0.863     0.910     321.0
      female      0.607     0.861     0.712      79.0

    accuracy                          0.863     400.0
   macro avg      0.784     0.862     0.811     400.0
weighted avg      0.892     0.863     0.871     400.0

              precision    recall  f1-score   support

        male      0.942     0.969     0.955     321.0
      female      0.857     0.759     0.805      79.0

    accuracy                          0.927     400.0
   macro avg      0.900     0.864     0.880     400.0
weighted avg      0.926     0.927     0.926     400.0

              precision    recall  f1-score   support

        male      0.982     0.991     0.987  343015.0
      female      0.914     0.842     0.876   38913.0

    accuracy                          0.976  381928.0
   macro avg      0.948     0.916     0.931  381928.0
weighted avg      0.975     0.976     0.975  381928.0

