In [None]:
from comet import download_model, load_from_checkpoint
import sacrebleu
import pickle
import numpy as np
import pandas as pd

In [None]:
def read_file(fname):
    """Return every line of ``fname`` with leading/trailing whitespace removed."""
    with open(fname) as handle:
        return [raw_line.strip() for raw_line in handle]

In [None]:
# Submissions index, keyed by language pair (e.g. "en_de") -> system columns.
# NOTE: pickle.load can execute arbitrary code; only load files you produced yourself.
with open("all_submissions/data_dict.pkl", mode="rb") as pkl_file:
    data_dict = pickle.load(pkl_file)

In [None]:
from typing import List

# Referred from https://github.com/amazon-science/doc-mt-metrics/blob/main/Prism/add_context.py
def add_context(orig_txt: List[str], context_same: List[str],
                context_other: List[str], sender_ids: List[str],
                sep_token: str = "</s>", ws: int = 2) -> List[str]:
    """Prefix every segment with up to ``ws`` preceding turns of the same document.

    For position ``i`` the preceding positions ``max(0, i - ws) .. i - 1`` form the
    context window. A context turn written by the same sender as turn ``i`` is
    taken from ``context_same``; a turn by a different sender is taken from
    ``context_other``. Context turns and the current segment are joined with
    ``" {sep_token} "``.

    Args:
        orig_txt: segments to augment, in document order.
        context_same: text used for context turns by the same sender.
        context_other: text used for context turns by a different sender.
        sender_ids: sender of each segment, aligned with ``orig_txt``.
        sep_token: separator inserted between consecutive turns.
        ws: maximum number of preceding turns to include.

    Returns:
        One augmented string per element of ``orig_txt``.

    Raises:
        Exception: if the four input sequences do not all have the same length.
    """
    lengths = (len(orig_txt), len(context_same), len(context_other), len(sender_ids))
    if len(set(lengths)) != 1:
        raise Exception(
            f'Lengths should match: len(orig_txt)={len(orig_txt)}, '
            f'len(context_same)={len(context_same)}, len(context_other)={len(context_other)}, '
            f'len(sender_ids)={len(sender_ids)}'
        )
    joiner = " {} ".format(sep_token)
    augm_txt = []
    for i in range(len(orig_txt)):
        # Pick, per context turn, the text variant that matches who spoke it.
        context_window = [
            context_same[j] if sender_ids[j] == sender_ids[i] else context_other[j]
            for j in range(max(0, i - ws), i)
        ]
        augm_txt.append(joiner.join(context_window + [orig_txt[i]]))
    return augm_txt

class DocCometMetric():
    """Wrapper around a COMET checkpoint used with context-augmented inputs.

    Args:
        model_name: checkpoint name passed to ``download_model``.
        batch_size: batch size forwarded to ``model.predict``.
        ref_based: if True, ``get_score`` scores (src, mt, ref) triples; otherwise
            it scores reference-free (src, mt) pairs.
        device_id: optional GPU index forwarded as ``devices=[device_id]`` to
            ``model.predict``; when None no ``devices`` argument is passed.
    """

    def __init__(self, model_name="Unbabel/wmt20-comet-qe-da", batch_size=64, ref_based=True,
                 device_id=None):
        checkpoint_path = download_model(model_name)
        self.model = load_from_checkpoint(checkpoint_path)
        self.batch_size = batch_size
        # Presumably switches COMET into its document/context-aware mode — confirm
        # against the installed COMET version.
        self.model.enable_context()
        self.ref_based = ref_based
        # Previously read in get_score but never assigned (AttributeError).
        self.device_id = device_id

    def _predict_kwargs(self, progress_bar):
        """Build keyword arguments shared by both scoring modes."""
        kwargs = {"batch_size": self.batch_size, "gpus": 1, "progress_bar": progress_bar}
        if self.device_id is not None:
            kwargs["devices"] = [self.device_id]
        return kwargs

    def get_score(self, source, outputs, references=None):
        """Return the list of segment-level scores for ``outputs``.

        Args:
            source: iterable of (possibly context-augmented) source segments.
            outputs: iterable of (possibly context-augmented) system outputs.
            references: iterable of references; required when ``ref_based``.

        Raises:
            ValueError: if ``ref_based`` is True and ``references`` is None.
        """
        if not self.ref_based:
            samples = [{"mt": y, "src": x} for x, y in zip(source, outputs)]
            return self.model.predict(samples, **self._predict_kwargs(progress_bar=True))['scores']
        if references is None:
            raise ValueError("references are required when ref_based=True")
        samples = [{"mt": y, "ref": z, "src": x} for x, y, z in zip(source, outputs, references)]
        return self.model.predict(samples, **self._predict_kwargs(progress_bar=False))['scores']

In [None]:
# Reference-based COMET model, used by get_scores for the "comet" column.
ref_checkpoint_path = download_model("Unbabel/wmt22-comet-da")
ref_metric = load_from_checkpoint(ref_checkpoint_path)

# Reference-free COMET-QE wrapper, fed context-augmented inputs in the main loop.
context_metric = DocCometMetric(
    model_name="Unbabel/wmt20-comet-qe-da",
    batch_size=256,
    ref_based=False,
)

In [None]:
def get_scores(df, columns, metric=None, batch_size=256):
    """Compute corpus-level COMET, chrF and BLEU for each system column of ``df``.

    Args:
        df: DataFrame with ``source`` and ``reference`` columns plus one column per system.
        columns: names of the system-output columns to score.
        metric: COMET model exposing ``predict``; defaults to the global ``ref_metric``.
        batch_size: batch size passed to ``metric.predict``.

    Returns:
        dict mapping column -> {"comet": ..., "chrf": ..., "bleu": ...}. Every requested
        column gets an entry; if scoring a column fails, its entry keeps whatever
        metrics were computed before the failure (possibly none) and a warning is
        emitted instead of the error being silently discarded.

    Raises:
        KeyError: if ``df`` lacks the ``source`` or ``reference`` column.
    """
    import warnings

    if metric is None:
        metric = ref_metric
    # Shared by every system column; fetch once.
    sources = df["source"].to_list()
    references = df["reference"].to_list()

    score_dict = {}
    for col in columns:
        # Create the entry up front so later cells can attach extra metrics
        # even if scoring this column fails part-way.
        score_dict[col] = {}
        try:
            hypotheses = df[col].to_list()
            samples = [{"mt": y, "ref": z, "src": x}
                       for x, y, z in zip(sources, hypotheses, references)]
            score_dict[col]["comet"] = np.mean(
                metric.predict(samples, batch_size=batch_size, gpus=1)['scores'])
            score_dict[col]["chrf"] = sacrebleu.corpus_chrf(hypotheses, [references]).score
            score_dict[col]["bleu"] = sacrebleu.corpus_bleu(hypotheses, [references]).score
        except Exception as exc:  # best-effort per column, but never silent
            warnings.warn(f"Scoring column {col!r} failed: {exc!r}")
    return score_dict

In [None]:
import json
import sys
from typing import Any, Callable

import numpy as np

import sys
# NOTE(review): hard-coded absolute path to a local MuDA checkout; this only resolves
# on the original machine. Consider reading it from an environment variable or a
# config cell so the notebook survives Restart & Run All elsewhere.
sys.path.append("/mnt/data-poseidon/sweta/chat-translation-generation/wmt24-chat-translation/metrics/MuDA")
from muda.langs import create_tagger
from muda.metrics import compute_metrics

def read_file(fname):
    """Read ``fname`` and return its lines, each stripped of surrounding whitespace.

    NOTE(review): identical to the ``read_file`` defined in an earlier cell; one of
    the two definitions can be dropped.
    """
    with open(fname) as stream:
        return list(map(str.strip, stream))


def recursive_map(func: Callable[[Any], Any], obj: Any) -> Any:
    """Apply ``func`` to every leaf of a nested structure of dicts and lists.

    Dicts and lists are rebuilt (plain ``dict``/``list``, same key/item order);
    anything else, including tuples and namedtuples, is a leaf passed to ``func``.
    """
    if isinstance(obj, dict):
        return dict((key, recursive_map(func, value)) for key, value in obj.items())
    if isinstance(obj, list):
        return list(map(lambda item: recursive_map(func, item), obj))
    return func(obj)


def _tag_documents(tagger, srcs, tgts, docids, phenomena):
    """Preprocess a parallel corpus with ``tagger`` and tag each resulting document."""
    preproc = tagger.preprocess(srcs, tgts, docids)
    return [tagger.tag(*doc, phenomena=phenomena) for doc in zip(*preproc)]


def _dump_tags(tagged_docs, path):
    """Write tagged documents to ``path`` as UTF-8 JSON, converting leaves via ``_asdict()``."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(recursive_map(lambda t: t._asdict(), tagged_docs), f, indent=2)


def get_muda_accuracy_score(
    srcs,
    refs,
    docids,
    tgt_lang="de",
    awesome_align_model="bert-base-multilingual-cased",
    awesome_align_cachedir=None,
    load_refs_tags_file=None,
    cohesion_threshold=3,
    dump_hyps_tags_file=None,
    dump_refs_tags_file=None,
    dump_stats_file=None,
    phenomena=None,
    hyps=None,
) -> list:
    """Score system outputs on MuDA discourse phenomena relative to the reference.

    The references are tagged (or their tags loaded from ``load_refs_tags_file``),
    the hypotheses are tagged, and per-tag precision / recall / F1 are computed with
    ``compute_metrics``, printed, and optionally written to disk.

    Args:
        srcs, refs, docids: parallel lists of source segments, reference
            translations and document ids.
        tgt_lang: target language code passed to ``create_tagger``.
        awesome_align_model, awesome_align_cachedir: aligner model name and cache
            directory passed to ``create_tagger``.
        load_refs_tags_file: if set, reference tags are read from this JSON file
            instead of tagging ``refs``.
        cohesion_threshold: passed to ``create_tagger``.
        dump_hyps_tags_file: optional output path for the hypothesis tags (JSON).
        dump_refs_tags_file: optional output path for the reference tags (JSON);
            only written when the references were tagged in this call.
        dump_stats_file: optional output path for per-tag stats (JSON lines).
        phenomena: phenomena to tag; defaults to lexical_cohesion, formality,
            verb_form and pronouns.
        hyps: system outputs aligned with ``srcs`` (required).

    Returns:
        List of ``{"tag", "precision", "recall", "f1"}`` dicts, one per tag.

    Raises:
        ValueError: if ``hyps`` is not provided.
    """
    if hyps is None:
        raise ValueError("hyps must be provided")
    if phenomena is None:
        # Fresh list per call instead of a shared mutable default argument.
        phenomena = ["lexical_cohesion", "formality", "verb_form", "pronouns"]

    tagger = create_tagger(
        tgt_lang,
        align_model=awesome_align_model,
        align_cachedir=awesome_align_cachedir,
        cohesion_threshold=cohesion_threshold,
    )

    if load_refs_tags_file:
        # NOTE(review): loaded tags are plain dicts (as produced by _dump_tags), not the
        # tagger's own objects — confirm compute_metrics accepts that form.
        with open(load_refs_tags_file, encoding="utf-8") as f:
            tagged_refs = json.load(f)
    else:
        tagged_refs = _tag_documents(tagger, srcs, refs, docids, phenomena)

    tagged_hyps = _tag_documents(tagger, srcs, hyps, docids, phenomena)

    tag_prec, tag_rec, tag_f1 = compute_metrics(tagged_refs, tagged_hyps)
    stat_dicts = []
    for tag in tag_f1:
        print(
            f"{tag} -- Prec: {tag_prec[tag]:.2f} Rec: {tag_rec[tag]:.2f} F1: {tag_f1[tag]:.2f}"
        )
        stat_dicts.append(
            {
                "tag": tag,
                "precision": tag_prec[tag],
                "recall": tag_rec[tag],
                "f1": tag_f1[tag],
            }
        )

    # The default is None, so only write when a path was actually given.
    if dump_stats_file:
        with open(dump_stats_file, "w", encoding="utf-8") as f:
            for d in stat_dicts:
                f.write(json.dumps(d, ensure_ascii=False) + "\n")

    if dump_hyps_tags_file:
        # Dump the hypothesis tags (this previously wrote tagged_refs by mistake).
        _dump_tags(tagged_hyps, dump_hyps_tags_file)

    if not load_refs_tags_file and dump_refs_tags_file:
        _dump_tags(tagged_refs, dump_refs_tags_file)

    return stat_dicts

In [None]:
# Score every submission for each language pair with sentence-level metrics,
# MuDA discourse-phenomena accuracy, and context-augmented COMET-QE.
import os

LANGUAGE_PAIRS = ['en_nl', 'en_pt', 'en_de', 'en_ko', 'en_fr']
MUDA_OUTPUT_DIR = "muda_accuracy_results"
MUDA_PHENOMENA = ["lexical_cohesion", "formality", "verb_form", "pronouns"]

# Writes below fail on a fresh checkout unless the output directory exists.
os.makedirs(MUDA_OUTPUT_DIR, exist_ok=True)
# Separator of the QE model's tokenizer, used to join context turns; constant, so look it up once.
sep_token = context_metric.model.encoder.tokenizer.sep_token

score_dict_all_lps = {}
for lp in LANGUAGE_PAIRS:
    submission_cols = list(data_dict[lp].keys())
    test_df = pd.read_csv(f"all_submissions/{lp.replace('_', '-')}.csv").fillna('')
    score_dict_all_lps[lp] = get_scores(test_df, submission_cols)
    src_lang, tgt_lang = lp.split("_")

    # MuDA: only rows whose source side is in the pair's source language.
    muda_df = test_df[test_df.source_language == src_lang]
    for col in submission_cols:
        get_muda_accuracy_score(
            muda_df["source"].to_list(),
            muda_df["reference"].to_list(),
            muda_df["doc_id"].to_list(),
            hyps=muda_df[col].to_list(),
            tgt_lang=tgt_lang,
            awesome_align_model="bert-base-multilingual-cased",
            awesome_align_cachedir=None,
            dump_hyps_tags_file=f"{MUDA_OUTPUT_DIR}/{tgt_lang}.{col}.tags.json",
            dump_refs_tags_file=f"{MUDA_OUTPUT_DIR}/{tgt_lang}.ref.tags.json",
            dump_stats_file=f"{MUDA_OUTPUT_DIR}/{tgt_lang}.{col}.stats.json",
            phenomena=MUDA_PHENOMENA,
            cohesion_threshold=3,
        )

    # Context COMET-QE: prefix each segment with preceding turns of its document.
    # Same-sender turns use the same side (source or MT); other-sender turns use
    # the opposite side, exactly as in the original notebook.
    for col in submission_cols:
        doc_dfs = []
        for _, doc_df in test_df.groupby("doc_id"):
            senders = doc_df["sender"].to_list()
            sources = doc_df["source"].to_list()
            outputs = doc_df[col].to_list()
            # assign() returns a new frame, avoiding chained assignment on a groupby slice.
            doc_dfs.append(doc_df.assign(
                source_with_context=add_context(
                    orig_txt=sources,
                    context_same=sources,
                    context_other=outputs,
                    sender_ids=senders,
                    sep_token=sep_token,
                ),
                mt_with_context=add_context(
                    orig_txt=outputs,
                    context_same=outputs,
                    context_other=sources,
                    sender_ids=senders,
                    sep_token=sep_token,
                ),
            ))

        dfs_all = pd.concat(doc_dfs)
        score_dict_all_lps[lp].setdefault(col, {})["context-comet-qe"] = np.mean(
            context_metric.get_score(dfs_all["source_with_context"], dfs_all["mt_with_context"])
        )

In [None]:
# Nested results: language pair -> system column -> metric name ("comet", "chrf",
# "bleu", "context-comet-qe") -> corpus-level score.
score_dict_all_lps