In [1]:
from scipy.stats import spearmanr

from bertscore import BERTScore
from conventional_metrics import BLEU, METEOR
from scm import SCM
from wmd import WMD
from common import Evaluator, LANGS
import pandas as pd




In [None]:
# Metric instances handed to the Evaluator; each one exposes a `.label`
# that is used as a key in the results collected below.
metrics = [
    BLEU(),
    METEOR(),
    # BERTScore(tgt_lang="en"),
    WMD(tgt_lang="en"),
    SCM(tgt_lang="en", use_tfidf=False),
    SCM(tgt_lang="en", use_tfidf=True),
]

# correlations[label][lang_pair] -> Spearman correlation of that report
# column with the "human" column of the same report.
correlations = {}
for metric in metrics:
    correlations[metric.label] = {}
correlations["human"] = {}

# Restrict the run to language pairs whose last "-"-separated part is "en".
into_english_pairs = [pair for pair in LANGS if pair.rsplit("-", 1)[-1] == "en"]

for lang_pair in into_english_pairs:
    print("Evaluating lang pair %s" % lang_pair)
    report = Evaluator("data_dir", lang_pair, metrics).evaluate()

    reference_scores = report["human"]
    for label, scores in report.items():
        # "human" is itself one of the report entries, so it is also
        # correlated against itself here.
        rho = spearmanr(scores, reference_scores).correlation
        correlations[label][lang_pair] = rho
    # Dump everything gathered so far after each language pair.
    print(correlations)

# Outer keys (labels) become columns, language pairs become the index.
corrs_df = pd.DataFrame(correlations)

[nltk_data] Downloading package wordnet to
[nltk_data]     /Users/xstefan3/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package stopwords to
[nltk_data]     /Users/xstefan3/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package stopwords to
[nltk_data]     /Users/xstefan3/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package stopwords to
[nltk_data]     /Users/xstefan3/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


Evaluating lang pair cs-en


100%|██████████| 2708/2708 [04:17<00:00, 10.50it/s]
100%|██████████| 2708/2708 [04:18<00:00, 10.48it/s]
WMD: 100%|██████████| 560/560 [00:01<00:00, 490.54it/s]
SCM: 100%|██████████| 560/560 [00:00<00:00, 3624.17it/s]
SCM: 100%|██████████| 560/560 [00:00<00:00, 2881.23it/s]
  0%|          | 0/2687 [00:00<?, ?it/s]

{'BLEU': {'cs-en': 0.00948259297595841}, 'METEOR': {'cs-en': 0.40259685151100627}, 'WMD': {'cs-en': -0.35835693780579453}, 'SCM': {'cs-en': 0.24208020921663048}, 'SCM_tfidf': {'cs-en': 0.2431252749190664}, 'human': {'cs-en': 1.0}}
Evaluating lang pair de-en


100%|██████████| 2687/2687 [04:38<00:00,  9.66it/s]
  9%|▉         | 253/2687 [02:50<2:10:26,  3.22s/it] 

In [None]:
import seaborn as sns

# corrs_df holds Spearman correlations (columns: metric labels, rows: language
# pairs). Correlations are signed, so use a diverging colormap pinned to the
# full [-1, 1] range and centred on 0; otherwise the colour scale is derived
# from the data and negative/positive values are not visually distinguishable.
ax = sns.heatmap(
    corrs_df,
    annot=True,
    vmin=-1,
    vmax=1,
    center=0,
    cmap="coolwarm",
)
# Title and axis labels so the figure can be read on its own; the trailing
# semicolon keeps the notebook from echoing the return value of `ax.set`.
ax.set(
    title="Spearman correlation with human judgements",
    xlabel="Metric",
    ylabel="Language pair",
);