In [1]:
from transformers import MBartForConditionalGeneration, AutoTokenizer

from huggingface_hub import notebook_login
from transformers import MBart50Tokenizer
import tqdm
import torch
import evaluate
import pandas as pd

import os

In [2]:
# en-cr dev split, stored as JSON Lines with one example per line and 'input' / 'target' fields.
DEV_PATH = '/home/yush/kreol-benchmark/data/lang_data/en-cr/en-cr_dev.jsonl'
val = pd.read_json(DEV_PATH, lines=True)
# Fail fast on missing values: a NaN here would otherwise surface much later as an
# opaque tokenizer or metric error.
assert val[['input', 'target']].notna().all().all(), 'dev set has missing input/target values'
val_inputs = list(val['input'])
val_labels = list(val['target'])

In [3]:
# Corpus-level translation metrics, loaded once and reused for every checkpoint below.
bleu, chrf = (evaluate.load(metric_name) for metric_name in ("bleu", "chrf"))

In [4]:
# Use CUDA when PyTorch can see a GPU, otherwise fall back to CPU.
# To force CPU (e.g. while debugging OOMs), reassign device = 'cpu' after this cell.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [5]:
CHECKPOINT_ROOT = '/home/yush/kreol-benchmark/checkpoint/MBart50/bidirectional'
CKPT_PREFIX = 'checkpoint-'
# Keep only entries named 'checkpoint-<step>' and order them by their integer step.
# A looser "'checkpoint' in name" filter lets through names with no numeric
# '-<step>' suffix, which then crash the int() sort key.
weights = sorted(
    (name for name in os.listdir(CHECKPOINT_ROOT)
     if name.startswith(CKPT_PREFIX) and name[len(CKPT_PREFIX):].isdecimal()),
    key=lambda name: int(name[len(CKPT_PREFIX):]),
)

In [6]:
def index_slice_dict(dicti,slice_begin,slice_end=None):
    """Return a new dict whose values are ``v[slice_begin:slice_end]`` for each item of `dicti`.

    Args:
        dicti: mapping whose values support slicing (lists, strings, tensors, ...);
            presumably the tokenizer output in this notebook.
        slice_begin: start index applied to every value.
        slice_end: stop index applied to every value; ``None`` means "to the end".

    Returns:
        A plain dict with the same keys and sliced values. `dicti` is not modified.
    """
    # v[a:None] already means "to the end", so no special case is needed. Testing
    # `if slice_end:` would wrongly treat slice_end=0 as "no end".
    return {key: value[slice_begin:slice_end] for key, value in dicti.items()}

def eval_batches_on_gpu(model,tokenizer,input_tokens,val,batch_num):
    """Generate and decode translations for pre-tokenized inputs, chunk by chunk.

    Args:
        model: object with a ``generate(**inputs)`` method (an MBart model here).
        tokenizer: object with ``batch_decode(ids, skip_special_tokens=True)``.
        input_tokens: mapping of equal-length, sliceable values; presumably the
            tokenizer's output already moved to the model's device.
        val: only ``len(val)`` is used, as the number of examples.
        batch_num: desired number of chunks. The data is split into at most
            this many chunks, each of at most ``ceil(len(val) / batch_num)`` rows.

    Returns:
        List of decoded strings, in input order.

    Raises:
        ZeroDivisionError: if ``batch_num`` is 0.
    """
    n_rows = len(val)
    # Ceiling division gives at most `batch_num` chunks with no extra tiny chunk.
    # max(1, ...) keeps the range() step positive when n_rows < batch_num
    # (plain floor division would give step 0 and range() raises ValueError).
    batch_size = max(1, -(-n_rows // batch_num))
    output = []
    for start in range(0, n_rows, batch_size):
        input_dict = index_slice_dict(input_tokens, start, start + batch_size)
        output_tokens_bn = model.generate(**input_dict)
        output.extend(tokenizer.batch_decode(output_tokens_bn, skip_special_tokens=True))
        # Release cached allocator blocks between chunks (a no-op without CUDA).
        torch.cuda.empty_cache()
    return output

In [7]:
# Release PyTorch's cached, unused CUDA memory; nothing to do when no GPU is visible.
if torch.cuda.is_available():
    torch.cuda.empty_cache()

In [8]:
# Language codes for the forward direction, assigned to tokenizer.src_lang / tokenizer.tgt_lang.
src_lang, tgt_lang = 'en_XX', 'cr_CR'

# When True, every checkpoint is additionally evaluated with the two languages swapped.
bidirectional = True

In [9]:
weights_dir = '/home/yush/kreol-benchmark/checkpoint/MBart50/finetune/bidirectional/dabre_finetune'


def _run_size(name):
    """Parse the integer after the 2-character tag in the last '_' field ('dabre_hq500' -> 500); None if absent."""
    digits = name.split('_')[-1][2:]
    return int(digits) if digits.isdecimal() else None


# Fine-tuning run directories, ordered by that numeric suffix. Entries without a
# numeric suffix, or that are plain files, are skipped. Otherwise they would crash
# the sort key, or the os.listdir() call the evaluation loop makes on each entry.
weights_hq = sorted(
    (name for name in os.listdir(weights_dir)
     if 'dabre' in name
     and _run_size(name) is not None
     and os.path.isdir(os.path.join(weights_dir, name))),
    key=_run_size,
)

In [11]:
batch_num = 5  # generate() is run in at most this many chunks over the dev set


def _evaluate_direction(model, tokenizer, inputs, references, src, tgt, checkpoint):
    """Translate `inputs` from `src` to `tgt` and print chrF and BLEU against `references`.

    Sets the tokenizer's language codes and tokenizes all inputs to 128 tokens
    (truncated / padded). It then generates via `eval_batches_on_gpu` and prints
    both metric dicts. Returns the decoded predictions.
    """
    tokenizer.src_lang = src
    tokenizer.tgt_lang = tgt
    # NOTE(review): generate() is not given forced_bos_token_id, so the first target
    # token (the language code for MBart-50) is left to the model/config. For a
    # bidirectional model consider forced_bos_token_id=tokenizer.convert_tokens_to_ids(tgt),
    # after confirming that custom codes like 'cr_CR' exist in this tokenizer's vocab.
    input_tokens = tokenizer(inputs, max_length=128, truncation=True,
                             padding="max_length", return_tensors='pt').to(device)
    predictions = eval_batches_on_gpu(model, tokenizer, input_tokens, val, batch_num)
    print(f'Evaluation Metrics for {src} -> {tgt} at {checkpoint}')
    print(chrf.compute(predictions=predictions, references=references))
    print(bleu.compute(predictions=predictions, references=references))
    print('----------------')
    return predictions


def _sorted_checkpoints(run_dir):
    """Return the 'checkpoint-<step>' entries of `run_dir`, ordered by integer step."""
    prefix = 'checkpoint-'
    names = [name for name in os.listdir(run_dir)
             if name.startswith(prefix) and name[len(prefix):].isdecimal()]
    return sorted(names, key=lambda name: int(name[len(prefix):]))


for weight_ckpt in weights_hq:
    ckpt_dir = os.path.join(weights_dir, weight_ckpt)
    # A dedicated helper (instead of rebinding the global `weights`) keeps cell 5's list
    # intact. It also evaluates only step-numbered checkpoints, in training order,
    # rather than whatever os.listdir() happens to return.
    for ckpt in _sorted_checkpoints(ckpt_dir):
        checkpoint = os.path.join(ckpt_dir, ckpt)
        tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        model = MBartForConditionalGeneration.from_pretrained(checkpoint).to(device)
        output = _evaluate_direction(model, tokenizer, val_inputs, val_labels,
                                     src_lang, tgt_lang, checkpoint)
        if bidirectional:
            # Reverse direction: the labels become the source sentences and the
            # original inputs become the references.
            val_inputs_bi, val_labels_bi = val_labels, val_inputs
            output = _evaluate_direction(model, tokenizer, val_inputs_bi, val_labels_bi,
                                         tgt_lang, src_lang, checkpoint)

Evaluation Metrics for en_XX -> cr_CR at /home/yush/kreol-benchmark/checkpoint/MBart50/finetune/bidirectional/dabre_finetune/dabre_hq500/checkpoint-1384
{'score': 45.19214507281211, 'char_order': 6, 'word_order': 0, 'beta': 2}
{'bleu': 0.21449912352692466, 'precisions': [0.5818284424379232, 0.2922248803827751, 0.16310432569974553, 0.09319385952995517], 'brevity_penalty': 0.9513367927972872, 'length_ratio': 0.9524833369167921, 'translation_length': 8860, 'reference_length': 9302}
----------------
Evaluation Metrics for cr_CR -> en_XX at /home/yush/kreol-benchmark/checkpoint/MBart50/finetune/bidirectional/dabre_finetune/dabre_hq500
{'score': 47.29599650038933, 'char_order': 6, 'word_order': 0, 'beta': 2}
{'bleu': 0.23309650890088893, 'precisions': [0.5536822507240381, 0.28904886561954624, 0.17258883248730963, 0.10688050930460333], 'brevity_penalty': 1.0, 'length_ratio': 1.013735975673692, 'translation_length': 9668, 'reference_length': 9537}
----------------
Evaluation Metrics for en_XX 