In [4]:
import numpy as np
import torch
import pandas as pd


from transformers import MBartForConditionalGeneration, AutoTokenizer

In [5]:
# Language codes assigned to each tokenizer's src_lang / tgt_lang further down.
src_lang, tgt_lang = 'en_XX', 'cr_CR'

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [6]:
# Validation split in JSON-lines form; the 'input' and 'target' columns are
# pulled out as plain Python lists for tokenization and for the side-by-side
# comparison printed at the end.
val = pd.read_json(
    '/home/yush/kreol-benchmark/data/lang_data/en-cr/en-cr_dev.jsonl',
    lines=True,
)
val_inputs, val_labels = (list(val[column]) for column in ('input', 'target'))

In [7]:
# First checkpoint under comparison (its generations are printed as "SOTA").
ckpt1 = "/home/yush/kreol-benchmark/checkpoint_tests/checkpoint-120000_best"

tokenizer_1 = AutoTokenizer.from_pretrained(ckpt1)
# Set the language codes on the tokenizer before encoding the inputs.
tokenizer_1.src_lang, tokenizer_1.tgt_lang = src_lang, tgt_lang

# Every validation input is truncated/padded to exactly 128 tokens and
# returned as PyTorch tensors.
input_tokens_1 = tokenizer_1(
    val_inputs,
    padding="max_length",
    truncation=True,
    max_length=128,
    return_tensors='pt',
)

# Load the weights and move the model onto the selected device in one step.
mdl_1 = MBartForConditionalGeneration.from_pretrained(ckpt1).to(device)

In [8]:
# Second checkpoint under comparison (its generations are printed as "FtSOTA").
ckpt2 = "/home/yush/kreol-benchmark/checkpoint_tests/checkpoint-11_best500ft"

tokenizer_2 = AutoTokenizer.from_pretrained(ckpt2)
# Set the language codes on the tokenizer before encoding the inputs.
tokenizer_2.src_lang, tokenizer_2.tgt_lang = src_lang, tgt_lang

# Every validation input is truncated/padded to exactly 128 tokens and
# returned as PyTorch tensors.
input_tokens_2 = tokenizer_2(
    val_inputs,
    padding="max_length",
    truncation=True,
    max_length=128,
    return_tensors='pt',
)

# Load the weights and move the model onto the selected device in one step.
mdl_2 = MBartForConditionalGeneration.from_pretrained(ckpt2).to(device)

In [9]:
def index_slice_dict(dicti,slice_begin,slice_end=None):
    """Apply the same slice to every value of a mapping.

    Args:
        dicti: Mapping whose values all support slice indexing
            (lists, strings, tensors, ...).
        slice_begin: Start index of the slice.
        slice_end: Exclusive end index. ``None`` (the default) slices to the
            end of each value.

    Returns:
        A new dict with the same keys, each mapped to
        ``value[slice_begin:slice_end]``.
    """
    # A ``None`` stop already means "to the end" in Python slicing, so no
    # branching is needed. This also makes an explicit ``slice_end=0`` yield an
    # empty slice instead of being mistaken for "no end given".
    return {key: value[slice_begin:slice_end] for key, value in dicti.items()}

def generate_output(input_tokens,model,tokenizer,batch_num=5):
    """Generate and decode text for every encoded example, chunk by chunk.

    The encoded inputs are split along their first dimension into chunks of
    ``len // batch_num`` rows (at least one row). Each chunk is moved to the
    model's device, passed to ``model.generate`` and decoded with
    ``tokenizer.batch_decode(..., skip_special_tokens=True)``.

    Args:
        input_tokens: Mapping of field name -> tensor (e.g. the tokenizer
            output); every tensor is sliced along dimension 0.
        model: Object exposing ``generate(**batch)`` and a ``device``
            attribute (as Hugging Face models do).
        tokenizer: Object exposing ``batch_decode``.
        batch_num: Number of chunks to aim for. When the example count is not
            divisible by it, one extra, smaller chunk handles the remainder.

    Returns:
        List of decoded strings, one per input row, in input order.

    Raises:
        ValueError: If ``batch_num`` is smaller than 1.
    """
    if batch_num < 1:
        raise ValueError(f"batch_num must be >= 1, got {batch_num}")

    # Size the loop from the encodings that were actually passed in rather
    # than from a notebook-level global, so the function works for any input.
    first_field = next(iter(input_tokens.values()), None)
    num_examples = 0 if first_field is None else len(first_field)

    # Clamp to 1: with fewer examples than batch_num the floor division is 0,
    # which would make range() raise on a zero step.
    batch_size = max(1, num_examples // batch_num)

    # Send each chunk to wherever the model lives.
    target_device = model.device

    output = []
    for start in range(0, num_examples, batch_size):
        batch = {
            key: tensor[start:start + batch_size].to(target_device)
            for key, tensor in input_tokens.items()
        }
        output_tokens_bn = model.generate(**batch)
        output.extend(tokenizer.batch_decode(output_tokens_bn, skip_special_tokens=True))
        # Hand cached GPU blocks back between chunks to keep peak memory down.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    return output

In [10]:
# Generations of the first checkpoint for the whole validation set.
output_1 = generate_output(
    input_tokens=input_tokens_1,
    model=mdl_1,
    tokenizer=tokenizer_1,
)

In [11]:
# Ask torch to release its cached CUDA memory, then generate with the second
# checkpoint over the same validation inputs.
torch.cuda.empty_cache()
output_2 = generate_output(
    input_tokens=input_tokens_2,
    model=mdl_2,
    tokenizer=tokenizer_2,
)

In [16]:
# Spot-check: print a handful of random validation examples with both models'
# generations next to the reference translation.
NUM_SAMPLES = 10
SAMPLE_SEED = 0  # fixed so re-running the cell prints the same examples

# Draw indices only from positions present in all four lists, rather than from
# a hard-coded range of 500 that may be larger or smaller than the data.
num_available = min(len(val_inputs), len(val_labels), len(output_1), len(output_2))
rng = np.random.default_rng(SAMPLE_SEED)
# Indices are drawn with replacement; an empty dataset simply prints nothing.
r = list(rng.integers(0, num_available, size=NUM_SAMPLES)) if num_available else []
for i in r:
    print(f"Input: {val_inputs[i]}")
    print(f"SOTA: {output_1[i]}")
    print(f"FtSOTA: {output_2[i]}")
    print(f"Label: {val_labels[i]}")
    print('------------')

Input: Some day you will see that horrible hhingh in the holy place, just as the prophet Daniel said.
SOTA: Enn zour zot pou trouve ki fighting dan plas sakre, kouma profetgamot finn dir.
FtSOTA: Enn zour zot pou trouv, dan plas sakre, kouma profet finn dir.
Label: Alor ler zot trouv bann sakrilez abominab parey kouma Profet Daniel ti anonse.
------------
Input: God gives such beauty to everything that grows in the fields, even though it is here today and thrown into a fire tomorrow.
SOTA: Bondie donn sa bote la tou seki grandi dan karo, mem si li isi e zet li dan dife demin.
FtSOTA: Bondie donn sa bote la tou seki grandi dan karo mem si li isi e zet li dan dife demin.
Label: Bondie donn enn bote tou seki pous dan karo, mem si zot la azordi e pou zet dan dife dime.
------------
Input: The disciples were shocked when they saw how quickly the tree had dried up.
SOTA: Bann disip ti soke kan zot ti trouv ki vites pie la ti sek.
FtSOTA: Bann disip ti soke kan zot ti trouv ki vites pie la ti