In [1]:
from copy import deepcopy
import itertools
import jsonlines
from pathlib import Path
import random

from sklearn.metrics import classification_report
from sklearn.metrics import cohen_kappa_score

from tasks.contrastive_conditioning import MucowContrastiveConditioningTask
from translation_models.fairseq_models import load_sota_evaluator
from translation_models.testing_models import DictTranslationModel

In [2]:
# Define data paths
mucow_path = Path(".").parent.parent.parent / "data" / "mucow"
data_path = Path(".") / "data"
pattern_matching_log_path = data_path / "mucow_pattern_matching.results.en-ru.ensemble.log"
mucow_enru_annotator1_path = data_path / "en-ru.annotator1.jsonl"
mucow_enru_annotator2_path = data_path / "en-ru.annotator2.jsonl"

In [3]:
# Load annotations
with jsonlines.open(mucow_enru_annotator1_path) as f:
  annotations1 = {line["Sample ID"]: line for line in f}
with jsonlines.open(mucow_enru_annotator2_path) as f:
  annotations2 = {line["Sample ID"]: line for line in f}

# Flatten labels
for key in annotations1:
    annotations1[key]["label"] = annotations1[key]["label"][0]
for key in annotations2:
    annotations2[key]["label"] = annotations2[key]["label"][0]

In [4]:
# Remove samples that were only partially annotated
for key in list(annotations1.keys()):
    if key not in annotations2:
        del annotations1[key]
for key in list(annotations2.keys()):
    if key not in annotations1:
        del annotations2[key]
print(f"Number of annotations: {len(annotations1)} + {len(annotations2)}")

Number of annotations: 90 + 90


In [5]:
# Inter-annotator agreement before data cleaning
keys = list(annotations1.keys())
labels1 = [annotations1[key]["label"] for key in keys]
labels2 = [annotations2[key]["label"] for key in keys]
kappa = cohen_kappa_score(labels1, labels2)
print(kappa)

0.5696465696465697


In [6]:
# Clean data
skipped_keys = set()
for annotations in [annotations1, annotations2]:
    for key in keys:
        # Treat neutral as correct
        if annotations[key]["label"] == "Both / Neutral / Ambiguous":
            annotations[key]["label"] = "Correct Sense"
        # Treat bad translations as wrong
        if annotations[key]["label"] == "Translation too bad to tell / Third sense":
            annotations[key]["label"] = "Wrong Sense"
        # Skip bad samples
        if annotations[key]["label"] == "Bad sample / Ill-defined senses":
            skipped_keys.add(key)
for annotations in [annotations1, annotations2]:
    for key in skipped_keys:
        if key in annotations:
            del annotations[key]
print(f"Number of annotations: {len(annotations1)} + {len(annotations2)}")

Number of annotations: 86 + 86


In [7]:
# Inter-annotator agreement after data cleaning
keys = list(annotations1.keys())
labels1 = [annotations1[key]["label"] for key in keys]
labels2 = [annotations2[key]["label"] for key in keys]
kappa = cohen_kappa_score(labels1, labels2)
print(kappa)

0.8767908309455588


In [8]:
# Merge annotations
annotations = list(itertools.chain(annotations1.values(), annotations2.values()))

In [9]:
# Load task
mucow_enru = MucowContrastiveConditioningTask(
    tgt_language="ru",
    evaluator_model=None,
)

In [10]:
# Create list of annotated (= uncovered) samples
uncovered_samples = []
all_samples_dict = {(sample.src_sentence, sample.src_word): sample for sample in mucow_enru.samples}
for annotation in annotations:
    try:
        sample = all_samples_dict[(annotation["Source Sentence"], annotation["Word"])]
    except KeyError:  # Google Sheets removes leading apostrophe
        sample = all_samples_dict.get(("'" + annotation["Source Sentence"], annotation["Word"]), None)
    if sample is None:
        continue
    sample._gold_label = annotation["label"] == "Correct Sense"
    sample.translation = annotation["Translation"]
    uncovered_samples.append(sample)
print("Full number of uncovered samples: ", len(uncovered_samples))

Full number of uncovered samples:  170


In [11]:
# Create list of unannotated (= covered) samples based on log file
num_samples = 0
covered_samples = []
all_samples_dict = {(sample.src_sentence, sample.src_word): sample for sample in mucow_enru.samples}
with jsonlines.open(pattern_matching_log_path) as f:
    for line in f:
        sample = all_samples_dict.get((line["sentence"], line["src_word"]), None)
        if sample is None:
            continue  # contrastive conditioning not applicable
        if line["corpus"] == "opensubs":  # Only evaluate on in-domain samples because they have higher quality
            continue
        num_samples += 1
        if line["is_unknown"]:  # = uncovered
            continue
        sample._gold_label = line["is_correct"]
        sample.translation = line["translation"]
        covered_samples.append(sample)
random.seed(42)
coverage = len(covered_samples) / num_samples
print("Proportion of covered samples", coverage)
print("Full number of covered samples: ", len(covered_samples))

Proportion of covered samples 0.8155136268343816
Full number of covered samples:  389


In [12]:
# Sample a proportionate amount of covered samples
_covered_samples = []
for _ in range(int(len(uncovered_samples) * (1 / (1 - coverage)))):
    _covered_samples.append(random.choice(covered_samples))
covered_samples = _covered_samples
print(f"Testing on {len(covered_samples)} covered samples and {len(uncovered_samples)} uncovered samples")

Testing on 921 covered samples and 170 uncovered samples


In [13]:
# Evaluate classic MuCoW
# Count all covered samples as agreements; judge all unknown samples as incorrect translations
num_agreements = len(covered_samples) + sum(1 for sample in uncovered_samples if not sample._gold_label)
proportion_of_agreement = num_agreements / (len(covered_samples) + len(uncovered_samples))
print("Proportion of agreement: ", proportion_of_agreement)

Proportion of agreement:  0.8973418881759854


In [14]:
# Run contrastive conditioning
evaluator_model = load_sota_evaluator("ru")

mucow_enru.samples = uncovered_samples + covered_samples
mucow_enru.categories = {sample.category for sample in mucow_enru.samples}
mucow_enru.category_wise_weighting = True
mucow_enru.evaluator_model = evaluator_model

translations = DictTranslationModel({sample.src_sentence: sample.translation for sample in mucow_enru.samples})
contrastive_conditioning_weighted_evaluated_samples = mucow_enru.evaluate(translations).samples

Using cache found in /home/user/vamvas/.cache/torch/hub/pytorch_fairseq_master
Loading codes from /home/user/vamvas/.cache/torch/pytorch_fairseq/15bca559d0277eb5c17149cc7e808459c6e307e5dfbb296d0cf1cfe89bb665d7.ded47c1b3054e7b2d78c0b86297f36a170b7d2e7980d8c29003634eb58d973d9/bpecodes ...
Read 30000 codes from the codes file.
Loading codes from /home/user/vamvas/.cache/torch/hub/en24k.fastbpe.code ...
Read 24000 codes from the codes file.
Loading codes from /home/user/vamvas/.cache/torch/hub/ru24k.fastbpe.code ...
Read 24000 codes from the codes file.
100%|████████████████████████████████████████████████████████████| 1091/1091 [03:56<00:00,  4.61it/s]


In [15]:
# Create unweighted contrastive conditioning samples
contrastive_conditioning_unweighted_evaluated_samples = deepcopy(contrastive_conditioning_weighted_evaluated_samples)
for sample in contrastive_conditioning_unweighted_evaluated_samples:
    sample.weight = 1

In [16]:
# Evaluate contrastive conditioning
for evaluated_samples in [
    contrastive_conditioning_unweighted_evaluated_samples,
    contrastive_conditioning_weighted_evaluated_samples,
]:
    predicted_labels = []
    gold_labels = []
    weights = []
    for sample in evaluated_samples:
        gold_labels.append(int(sample._gold_label))
        predicted_labels.append(int(sample.is_correct))
        weights.append(getattr(sample, "weight", 1))
    class_labels = [0, 1]
    target_names = ["Wrong", "Correct"]
    print(classification_report(
        y_true=gold_labels,
        y_pred=predicted_labels,
        labels=class_labels,
        target_names=target_names,
        sample_weight=weights,
        zero_division=True,
        digits=3,
    ))

              precision    recall  f1-score   support

       Wrong      0.394     0.429     0.411     126.0
     Correct      0.925     0.914     0.919     965.0

    accuracy                          0.858    1091.0
   macro avg      0.659     0.671     0.665    1091.0
weighted avg      0.863     0.858     0.860    1091.0

              precision    recall  f1-score   support

       Wrong      0.496     0.437     0.465   26743.0
     Correct      0.957     0.966     0.962  348295.0

    accuracy                          0.928  375038.0
   macro avg      0.727     0.702     0.713  375038.0
weighted avg      0.924     0.928     0.926  375038.0

