Run this notebook with your predictions to generate your BLEU, CHRF, SacreBLEU, BERTScore scores.

In [66]:
import jieba
import json
import torch
from bert_score import score
from rouge_chinese import Rouge
from sacrebleu.metrics import BLEU, CHRF, TER

In [67]:
PREDICTION_FILE = '10k-rnn-baseline-ss-bi.json' # 10k-rnn-baseline-spacy-jieba-bi.json
REFERENCE_FILE = 'iwslt2017-en-zh-test.zh'

In [68]:
predictions = []
reference = []

with open(f'./predictions/{PREDICTION_FILE}', 'r', encoding='utf-8') as f:
    raw = f.read()
    pdict = json.loads(raw)
    if "predicted" in pdict:
        predictions = [*predictions, *pdict['predicted']]

with open(f'../tokenisation/data/{REFERENCE_FILE}', 'r',encoding='utf-8') as f:
    reference = [*reference, *(f.readlines())]

assert len(predictions) == len(reference), \
    'Received a wrong number of predictions. ' + \
    'Ensure that you have generated predictions for the whole test set. \n\n' + \
    f'Predictions Length: {len(predictions)}, Expected: {len(reference)}' 

refs = [[r.rstrip()] for r in reference]
sys = [str(p).rstrip() for p in predictions]

## Average on test set for BLEU , CHRF , CHRF++ , TER , ROUGE and BERTSCORE

In [69]:
# predicted , reference
sys[0] , refs[0]

('<s>在年前，我的了一个我的一个项目，我的一个的的的的的的的的的的。',
 ['几年前，在TED大会上， Peter Skillman 介绍了一个设计挑战 叫做“棉花糖挑战”'])

## Save Individual Scores

In [55]:
refs = [[r] for r in reference]
sys = [str(p).rstrip() for p in predictions]

In [56]:
# BLEU
bleu = BLEU(smooth_method='exp', tokenize='zh', max_ngram_order=4, effective_order=True)
bleu_scores = [bleu.sentence_score(h, r).score for h, r in zip(sys, refs)]

# CHRF
chrf = CHRF(word_order=0, beta=0, eps_smoothing=False)
chrf_scores = [chrf.sentence_score(h, r).score for h, r in zip(sys, refs)]

# CHRF++
chrf_plus = CHRF(word_order=2, beta=0, eps_smoothing=False)
chrf_plus_scores = [chrf_plus.sentence_score(h, r).score for h, r in zip(sys, refs)]

# TER
ter = TER(asian_support=True, normalized=True)
ter_scores = [ter.sentence_score(h, r).score for h, r in zip(sys, refs)]

# ROUGE 
rouge = Rouge()
def get_tok(sent):
    sent = str(sent)
    return ' '.join(jieba.lcut(sent))
r_scores = rouge.get_scores(list(map(get_tok, predictions)), list(map(get_tok, reference)), avg=False)

# BERTSCORE 
P, R, F1 = score(predictions, reference, lang='zh')
precision_scores = P.tolist()
recall_scores = R.tolist()
f1_scores = F1.tolist()


scores = {
    'BLEU': bleu_scores,
    'CHRF': chrf_scores,
    'CHRF++': chrf_plus_scores,
    'TER': ter_scores,
    'ROUGE':r_scores,
    'BERTSCORE_P' : precision_scores,
    'BERTSCORE_R' : recall_scores,
    'BERTSCORE_F1' : f1_scores
}

with open(f'individual_scores\{PREDICTION_FILE[:-5]}.json', 'w') as f:
    json.dump(scores, f, indent=4)


In [57]:
refs[0]

['几年前，在TED大会上， Peter Skillman 介绍了一个设计挑战 叫做“棉花糖挑战”\n']

In [63]:
print('Average BLEU-4:')
print(sum(bleu_scores) / len(refs))
print('-' * 50)

# CHRF
print("CHRF")
print(sum(chrf_scores) / len(refs))
print(50*'-')

# CHRF++
print("CHRF++")
print(sum(chrf_plus_scores) / len(refs))
print(50*'-')

# TER
print("TER")
print(sum(ter_scores) / len(refs))
print(50*'-')

# Rouge
print("Rouge")
def get_tok(sent):
    sent = str(sent)
    return ' '.join(jieba.lcut(sent))
rouge = Rouge()
scores = rouge.get_scores(list(map(get_tok, predictions)), list(map(get_tok, reference)), avg=True)
print(json.dumps(scores, indent=2))
print(50*'-')

# BERTScore
print("BERT")
# run this if your strings got corrupted 
predictions = [str(p) for p in predictions]
reference = [str(r) for r in reference]
P, R, F1 = score(predictions, reference, lang='zh') # default model for zh is bert-base-chinese
print(f'Precision: {P.mean().item()} | Recall: {R.mean().item()} | F1: {F1.mean().item()}') # Precision, Recall and F1

Average BLEU-4:
3.449183455164379
--------------------------------------------------
CHRF
3.9452189255904924
--------------------------------------------------
CHRF++
3.104803878479188
--------------------------------------------------
TER
107.91505345039711
--------------------------------------------------
