In [1]:
import torch
torch.set_float32_matmul_precision("medium")
import transformers as tr
import datasets as ds

In [2]:
dataset = ds.load_dataset("RicardoRei/wmt-mqm-human-evaluation", split='train')

In [3]:
dataset = dataset.filter(lambda x: x['domain'] == 'news' and x['year'] == 2022)

In [65]:
model = "./output/dsbs-3way-small-0/checkpoint-54000/"

In [55]:
lps = set(dataset['lp'])

In [56]:
lp_map = {
    'en-ru': ('eng_Latn', 'rus_Cyrl'),
    'en-de': ('eng_Latn', 'deu_Latn'),
    'zh-en': ('zho_Hant', 'eng_Latn')
}

In [66]:
tokenizer = tr.AutoTokenizer.from_pretrained(model)
model = tr.AutoModelForSeq2SeqLM.from_pretrained(model)

In [8]:
def example_convert(example):
    src_lang, tgt_lang = lp_map[example['lp']]
    return {
        'input': f"Translation quality of: {src_lang} Source: {example['src']}; {tgt_lang} Reference: {example['ref']}; {tgt_lang} Translation: {example['mt']}"
    }

In [60]:
dataset[0]

{'lp': 'en-de',
 'src': 'Iran reports lowest number of daily COVID-19 cases in more than one year',
 'mt': '„Der Iran meldet die niedrigste Anzahl täglicher COVID-19-Fälle seit mehr als einem Jahr“',
 'ref': 'Iran meldet niedrigste Zahl täglicher COVID-19-Fälle seit über einem Jahr',
 'score': -1.2000000000000002,
 'system': 'bleurt_bestmbr',
 'annotators': 1,
 'domain': 'news',
 'year': 2022,
 'input': 'Translation quality of: eng_Latn Source: Iran reports lowest number of daily COVID-19 cases in more than one year; deu_Latn Reference: Iran meldet niedrigste Zahl täglicher COVID-19-Fälle seit über einem Jahr; deu_Latn Translation: „Der Iran meldet die niedrigste Anzahl täglicher COVID-19-Fälle seit mehr als einem Jahr“',
 'predicted_score': 0.95}

In [9]:
dataset = dataset.map(example_convert)

In [67]:
def tokenize(examples, prefix: str = 'explain bad: '):
    model_inputs = tokenizer([prefix + x for x in examples['input']], max_length=1024, truncation=True)
    return model_inputs

In [68]:
class_predict = dataset.map(tokenize, batched=True, num_proc=4, remove_columns=dataset.column_names)

In [69]:
collate_fn = tr.DataCollatorWithPadding(tokenizer, padding=True, max_length=1024, pad_to_multiple_of=8, return_tensors='pt')

In [70]:
data_loader = torch.utils.data.DataLoader(
    dataset=class_predict, shuffle=False, collate_fn=collate_fn, batch_size=16
)

In [71]:
device = torch.device("cuda")

In [72]:
from tqdm.auto import tqdm

In [73]:
model = model.to(device)
model.eval()

MT5ForConditionalGeneration(
  (shared): Embedding(250112, 512)
  (encoder): MT5Stack(
    (embed_tokens): Embedding(250112, 512)
    (block): ModuleList(
      (0): MT5Block(
        (layer): ModuleList(
          (0): MT5LayerSelfAttention(
            (SelfAttention): MT5Attention(
              (q): Linear(in_features=512, out_features=384, bias=False)
              (k): Linear(in_features=512, out_features=384, bias=False)
              (v): Linear(in_features=512, out_features=384, bias=False)
              (o): Linear(in_features=384, out_features=512, bias=False)
              (relative_attention_bias): Embedding(32, 6)
            )
            (layer_norm): MT5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): MT5LayerFF(
            (DenseReluDense): MT5DenseGatedActDense(
              (wi_0): Linear(in_features=512, out_features=1024, bias=False)
              (wi_1): Linear(in_features=512, out_features=1024, bias=False)
          

In [74]:
with torch.inference_mode():
    results = []
    for batch in tqdm(data_loader):
        batch = {
            k: v.to(device) for k, v in batch.items()
        }
        outputs = model.generate(**batch, max_new_tokens=256)
        results += tokenizer.batch_decode(outputs, skip_special_tokens=True)

  0%|          | 0/1036 [00:00<?, ?it/s]



In [101]:
print(tokenizer.decode(class_predict[-10]['input_ids'], skip_special_tokens=True))



In [100]:
print(results[-10])

Found 4 translation error(s) Severity: Major, Error: "отмете воздушных и паровых", Context:...также <b>отмете воздушных и паровых</b> услуг.... Severity: Major, Error: "может привести к отключению", Context:...погода <b>может привести к отключению</b> электроэнергии,... Severity: Major, Error: "дорог и мостов, а", Context:...закрытию <b>дорог и мостов, а</b> также... Severity: Major, Error: "погода", Context:...плохая <b>погода</b> может...


In [92]:
len(results)

16575

In [91]:
Counter(results)

Counter({'Found 0 translation error(s)': 13690,
         'Found 5 translation error(s) Severity: Major, Error: "and", Context:...organizations <b>and</b> individuals... Severity: Major, Error: "and", Context:...organizations <b>and</b> individuals... Severity: Major, Error: "and", Context:...organizations <b>and</b> individuals... Severity: Major, Error: "and", Context:...organizations <b>and</b> individuals... Severity: Major, Error: "and", Context:...organizations <b>and</b> individuals... Severity: Major, Error: "and", Context:...organizations <b>and</b> individuals...': 12,
         'Found 2 translation error(s) Severity: Major, Error: "sie vom Schnee mitgerissen", Context:...wie <b>sie vom Schnee mitgerissen</b> wurden.... Severity: Major, Error: "sie vom Schnee mitgerissen", Context:...wie <b>sie vom Schnee mitgerissen</b> wurden....': 10,
         'Found 1 translation error(s) Severity: Major, Error: "название, которое было данью", Context:...само <b>название, которое было данью

In [90]:
results[6890]

'Found 0 translation error(s)'

In [47]:
freq = []
for r in results:
    if "Critical" in r:
        freq.append("Critical")
    elif "Minor" in r:
        freq.append("Minor")
    elif "Major" in r:
        freq.append("Major")

In [61]:
results[0]

'Found 2 translation error(s) Severity: Minor, Error: "„Der Iran meldet die niedrigste Anzahl", Context:...<b>„Der Iran meldet die niedrigste Anzahl</b> täglicher... Severity: Minor, Error: "mehr als einem Jahr“", Context:...Fälle seit<b>mehr als einem Jahr“</b>...'

In [46]:
dataset[-590]

{'lp': 'en-ru',
 'src': '"We all know that everything is more intense when the climate is warming.',
 'mt': '<unk>Мы все знаем, что все более интенсивно, когда климат потеплеет.',
 'ref': '«Все мы знаем, что процессы проходят интенсивнее, когда климат становится теплее.',
 'score': -2.0,
 'system': 'Lan-Bridge',
 'annotators': 1,
 'domain': 'news',
 'year': 2022,
 'input': 'Translation quality of: eng_Latn Source: "We all know that everything is more intense when the climate is warming.; rus_Cyrl Reference: «Все мы знаем, что процессы проходят интенсивнее, когда климат становится теплее.; rus_Cyrl Translation: <unk>Мы все знаем, что все более интенсивно, когда климат потеплеет.',
 'predicted_score': 0.95}

In [48]:
from collections import Counter

In [19]:
mapping = {
    "Very Poor": 0.05,  # Below 0 or 0.0 to 0.1
    "Poor": 0.15,  # 0.1 to 0.2
    "Fair": 0.25,  # 0.2 to 0.3
    "Below Average": 0.35,  # 0.3 to 0.4
    "Average": 0.45,  # 0.4 to 0.5
    "Above Average": 0.55,  # 0.5 to 0.6
    "Good": 0.65,  # 0.6 to 0.7
    "Very Good": 0.75,  # 0.7 to 0.8
    "Excellent": 0.85,  # 0.8 to 0.9
    "Outstanding": 0.95,  # 0.9 to 1.0 or above
}

In [20]:
results = [mapping[x] for x in results] 

In [21]:
if 'predicted_score' in dataset.column_names:
    dataset = dataset.remove_columns(['predicted_score'])

In [22]:
dataset = dataset.add_column("predicted_score", results)

In [23]:
from scipy.stats import kendalltau

In [25]:
for lp in lps:
    subset = dataset.filter(lambda ex: ex['lp'] == lp)
    stat = kendalltau(subset['score'], subset['predicted_score'])
    print(f"{lp} => {stat.statistic}")

en-ru => 0.30190361883703465
en-de => 0.19671692839160884
zh-en => 0.20840522422862823


In [26]:
# 1way
# en-de => 0.20536576632976486
# zh-en => 0.2569583572947779
# en-ru => 0.3095345683928942

# 2way
# en-de => 0.19671692839160884
# zh-en => 0.20840522422862823
# en-ru => 0.30190361883703465

# 3way
# en-de => 0.19590689576814513
# zh-en => 0.200590785622797
# en-ru => 0.2891122873449509