In [1]:
import os
import json
import numpy as np
from collections import defaultdict
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt
from datasets import load_dataset #load_dataset from Huggingface
from scipy import stats
from scipy.stats import rankdata, spearmanr, pearsonr
import statsmodels.stats.proportion as smp

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
plt.rcParams["savefig.format"] = 'pdf'
plt.rcParams['font.family'] = 'Palatino'

In [3]:
LANG_DICT = {'afrikaans':'afr_Latn' ,
'english': 'eng_Latn',
'amharic':'amh_Ethi' ,
'armenian':'hye_Armn' ,
'assamese':'asm_Beng' ,
'basque':'eus_Latn' ,
'bengali':'ben_Beng' ,
'bulgarian':'bul_Cyrl' ,
'burmese':'mya_Mymr' ,
'catalan':'cat_Latn' ,
'central kurdish':'ckb_Arab' ,
'croatian': 'hrv_Latn',
'dutch': 'nld_Latn',
'xhosa': 'xho_Latn',
'macedonian': 'mkd_Cyrl',
'czech':'ces_Latn' ,
'danish':'dan_Latn' ,
'eastern panjabi':'pan_Guru' ,
'egyptian arabic':'arz_Arab' ,
'estonian':'est_Latn' ,
'finnish':'fin_Latn' ,
'french':'fra_Latn' ,
'georgian':'kat_Geor' ,
'german':'deu_Latn' ,
'greek':'ell_Grek' ,
'gujarati':'guj_Gujr' ,
'hausa':'hau_Latn' ,
'hebrew':'heb_Hebr' ,
'hindi':'hin_Deva' ,
'hungarian':'hun_Latn' ,
'icelandic':'isl_Latn' ,
'indonesian':'ind_Latn' ,
'italian':'ita_Latn' ,
'japanese':'jpn_Jpan' ,
'javanese':'jav_Latn' ,
'kannada':'kan_Knda' ,
'kazakh':'kaz_Cyrl' ,
'khmer':'khm_Khmr' ,
'korean':'kor_Hang' ,
'kyrgyz':'kir_Cyrl' ,
'lao':'lao_Laoo' ,
'lithuanian':'lit_Latn' ,
'malayalam':'mal_Mlym' ,
'marathi':'mar_Deva' ,
'mesopotamian arabic':'acm_Arab' ,
'modern standard arabic':'arb_Arab' ,
'moroccan arabic':'ary_arab' ,
'najdi arabic':'ars_Arab' ,
'nepali':'npi_Deva' ,
'north azerbaijani':'azj_Latn' ,
'north levantine arabic':'apc_Arab' ,
'northern uzbek':'uzn_Latn' ,
'norwegian bokmal':'nob_Latn' ,
'odia':'ory_Orya' ,
'polish':'pol_Latn' ,
'portuguese':'por_Latn' ,
'romanian':'ron_Latn' ,
'russian':'rus_Cyrl' ,
'serbian':'srp_Cyrl' ,
'simplified chinese':'zho_Hans' ,
'sindhi':'snd_Arab' ,
'sinhala':'sin_Sinh' ,
'slovak':'slk_Latn' ,
'slovenian':'slv_Latn' ,
'somali':'som_Latn' ,
'southern pashto':'pbt_Arab' ,
'spanish':'spa_Latn' ,
'standard latvian':'lvs_Latn' ,
'standard malay':'zsm_Latn' ,
'sundanese':'sun_Latn' ,
'swahili':'swh_Latn' ,
'swedish':'swe_Latn' ,
'tamil':'tam_Taml' ,
'telugu':'tel_Telu' ,
'thai':'tha_Thai' ,
'tosk albanian':'als_Latn' ,
'traditional chinese':'zho_Hant' ,
'turkish':'tur_Latn' ,
'ukrainian':'ukr_Cyrl' ,
'urdu':'urd_Arab' ,
'vietnamese':'vie_Latn' ,
'western persian':'pes_Arab'}

LANGUAGE=[k for k,v in LANG_DICT.items()]
LANGUAGE_wo_ENGLISH = [k for k,v in LANG_DICT.items() if k!='english']

In [4]:
LR_LANG = ['acm_Arab',
'amh_Ethi',
'apc_Arab',
'ars_Arab',
'ary_Arab',
'arz_Arab',
'asm_Beng',
'azj_Latn',
'ckb_Arab',
'guj_Gujr',
'hau_Latn',
'hye_Armn',
'jav_Latn',
'kan_Knda',
'kat_Geor',
'khm_Khmr',
'kir_Cyrl',
'lao_Laoo',
'mal_Mlym',
'mar_Deva',
'mya_Mymr',
'nob_Latn',
'npi_Deva',
'ory_Orya',
'pan_Guru',
'pbt_Arab',
'sin_Sinh',
'snd_Arab',
'som_Latn',
'srp_Cyrl',
'sun_Latn',
'tam_Taml',
'tel_Telu',
'urd_Arab',
'uzn_Latn']

print([k for k,v in LANG_DICT.items() if v in LR_LANG])

['amharic', 'armenian', 'assamese', 'burmese', 'central kurdish', 'eastern panjabi', 'egyptian arabic', 'georgian', 'gujarati', 'hausa', 'javanese', 'kannada', 'khmer', 'kyrgyz', 'lao', 'malayalam', 'marathi', 'mesopotamian arabic', 'najdi arabic', 'nepali', 'north azerbaijani', 'north levantine arabic', 'northern uzbek', 'norwegian bokmal', 'odia', 'serbian', 'sindhi', 'sinhala', 'somali', 'southern pashto', 'sundanese', 'tamil', 'telugu', 'urdu']


In [5]:
HR_LANG = ['afr_Latn',
'als_Latn',
'arb_Arab',
'ben_Beng',
'bul_Cyrl',
'cat_Latn',
'ces_Latn',
'dan_Latn',
'deu_Latn',
'ell_Grek',
'est_Latn',
'eus_Latn',
'fin_Latn',
'fra_Latn',
'heb_Hebr',
'hin_Deva',
'hrv_Latn',
'hun_Latn',
'ind_Latn',
'isl_Latn',
'ita_Latn',
'jpn_Jpan',
'kaz_Cyrl',
'kor_Hang',
'lit_Latn',
'lvs_Latn',
'mkd_Cyrl',
'nld_Latn',
'pes_Arab',
'pol_Latn',
'por_Latn',
'ron_Latn',
'rus_Cyrl',
'slk_Latn',
'slv_Latn',
'spa_Latn',
'swe_Latn',
'swh_Latn',
'tha_Thai',
'tur_Latn',
'ukr_Cyrl',
'vie_Latn',
'xho_Latn',
'zho_Hans',
'zho_Hant',
'zsm_Latn']

In [6]:
def plot_DALI(dataset, lang, model, mode):
    
    if mode == 'DALI':
        lang_code = LANG_DICT[lang]
        DAS_path = f'../../alignment_outputs/{model}/{dataset}_dali/DALI_{lang_code}_lasttoken.json'
    if mode == 'DALIStrong':
        lang_code = LANG_DICT[lang]
        DAS_path = f'../../alignment_outputs/{model}/{dataset}_dali_strong/DALI_{lang_code}_lasttoken.json'

    if mode == 'MEXAFlores':
        lang_code = LANG_DICT[lang]
        DAS_path = f'../../alignment_outputs/{model}/flores_mexa/{lang_code}.json'
    if mode == 'MEXATask':
        if dataset=='belebele' or dataset == 'flores':
            lang_code = LANG_DICT[lang]
            DAS_path = f'../../alignment_outputs/{model}/{dataset}_mexa/{lang_code}.json'
        else:
            DAS_path = f'../../alignment_outputs/{model}/{dataset}_mexa/{lang}.json'
    with open(DAS_path) as f:
        lang_DAS = json.load(f)
    return lang_DAS

In [7]:
def load_translation_agg(dataset, model, field='flores_passage'):
    if dataset == 'flores':
        entoxxpath = f'../../translation_outputs/{model}/{dataset}_100/sentence/entoxx_{model}_{dataset}_COMET.json'
        xxtoenpath = f'../../translation_outputs/{model}/{dataset}_100/sentence/xxtoen_{model}_{dataset}_COMET.json'
        #list_of_languages = ['modern standard arabic', 'spanish', 'basque', 'hindi', 'indonesian', 'burmese', 'russian', 'telugu', 'simplified chinese', 'swahili']
        #lang_key = {'modern standard arabic': 'arabic', 'spanish': 'spanish', 'basque': 'basque', 'hindi': 'hindi', 'indonesian': 'indonesian', 'burmese': 'burmese', 'russian': 'russian', 'telugu': 'telugu', 'simplified chinese': 'chinese', 'swahili': 'swahili'}

        with open(entoxxpath) as f:
            entoxx_COMET = json.load(f)
    
        with open(xxtoenpath) as f:
            xxtoen_COMET = json.load(f)
        
        entoxx_COMET_filtered = {}
        xxtoen_COMET_filtered = {}
        for k,v in entoxx_COMET.items():
            if k in LANG_DICT.keys():
                entoxx_COMET_filtered[k] = v[0]
        for k,v in xxtoen_COMET.items():
            if k in LANG_DICT.keys():
                xxtoen_COMET_filtered[k] = v[0]

    if dataset == 'belebele':
        entoxxpath = f'../../translation_outputs/{model}/{dataset}/{field}/entoxx_{model}_{dataset}_COMET.json'
        xxtoenpath = f'../../translation_outputs/{model}/{dataset}/{field}/xxtoen_{model}_{dataset}_COMET.json'

        with open(entoxxpath) as f:
            entoxx_COMET = json.load(f)
        with open(xxtoenpath) as f:
            xxtoen_COMET = json.load(f)
        entoxx_COMET_filtered = {}
        xxtoen_COMET_filtered = {}

        for k,v in entoxx_COMET.items():
            entoxx_COMET_filtered[k] = v[0]
        for k,v in xxtoen_COMET.items():
            xxtoen_COMET_filtered[k] = v[0]           
    return entoxx_COMET_filtered, xxtoen_COMET_filtered


In [8]:
def load_translation_sample(dataset, model, field):

    if dataset=='flores':
        entoxxpath = f'../../translation_outputs/{model}/{dataset}_100/{field}/entoxx_{model}_{dataset}_COMET_sample.json'
        xxtoenpath = f'../../translation_outputs/{model}/{dataset}_100/{field}/xxtoen_{model}_{dataset}_COMET_sample.json'
    else:
        entoxxpath = f'../../translation_outputs/{model}/{dataset}/{field}/entoxx_{model}_{dataset}_COMET_sample.json'
        xxtoenpath = f'../../translation_outputs/{model}/{dataset}/{field}/xxtoen_{model}_{dataset}_COMET_sample.json'


    with open(entoxxpath) as f:
        entoxx_COMET_sample = json.load(f)
    with open(xxtoenpath) as f:
        xxtoen_COMET_sample = json.load(f)

    return entoxx_COMET_sample, xxtoen_COMET_sample

In [9]:
def analyze_sample_level_translation(entoxx_sample, xxtoen_sample, dataset):

    LANGUAGE_DICT = {'xstorycloze': ['arabic', 'chinese', 'spanish', 'basque', 'hindi', 'indonesian', 'burmese', 'russian', 'telugu', 'swahili'],
                     'xcopa': ['chinese', 'indonesian', 'italian', 'swahili', 'tamil', 'thai', 'turkish', 'vietnamese'],
                     'belebele': LANGUAGE_wo_ENGLISH,
                     'flores': LANGUAGE_wo_ENGLISH}
    
    selected_lang_list = LANGUAGE_DICT[dataset]

    entoxxdelta_results = defaultdict(dict)
    xxtoendelta_results = defaultdict(dict)

    if dataset=='flores':
        n=100
    if dataset == 'belebele':
        n=900
    if dataset == 'xstorycloze':
        n=1511
    if dataset == 'xcopa':
        n=500


    for lang in selected_lang_list:
        if dataset =='flores':
            lang_DAS = plot_DALI(dataset, lang, 'Llama3.1', 'MEXAFlores')
        else:
            lang_DAS = plot_DALI(dataset, lang, 'Llama3.1', 'MEXATask')

        lang_DAS = {int(k): v for k,v in lang_DAS.items()}
        lang_DAS_formatted = defaultdict(dict)
        for layer in range(32):
            for i,dali in enumerate(lang_DAS[layer]):
                lang_DAS_formatted[i][layer] = lang_DAS[layer][i]

        all_list = defaultdict(list)
        for item in range(n):
            for layer in range(32):
                all_list[layer].append(lang_DAS_formatted[item][layer])

        all_mean = []

        for k,v in all_list.items():
            all_mean.append(np.mean(v))

        alignment_in_max_layer = lang_DAS[np.argmax(all_mean)]

        if sum(lang_DAS[np.argmax(all_mean)]) == len(lang_DAS[np.argmax(all_mean)]):
            entoxxdelta_results[lang]['delta'] = 'NA'
            entoxxdelta_results[lang]['Utest_pval'] = 'NA'
            entoxxdelta_results[lang]['ttest_pval'] = 'NA'
            entoxxdelta_results[lang]['N_aligned'] = len(lang_DAS[np.argmax(all_mean)])
            entoxxdelta_results[lang]['N_nonaligned'] = 0


            xxtoendelta_results[lang]['delta'] = 'NA'
            xxtoendelta_results[lang]['Utest_pval'] = 'NA'
            xxtoendelta_results[lang]['ttest_pval'] = 'NA'
            xxtoendelta_results[lang]['N_aligned'] = len(lang_DAS[np.argmax(all_mean)])
            xxtoendelta_results[lang]['N_nonaligned'] = 0



        else:
            alignment_zero_idx = [i for i,mexa in enumerate(alignment_in_max_layer) if mexa==0]
            alignment_one_idx = [i for i,mexa in enumerate(alignment_in_max_layer) if mexa==1]

            aligned_entoxx_scores = [entoxx_sample[lang]['scores'][i] for i in alignment_one_idx]
            aligned_xxtoen_scores = [xxtoen_sample[lang]['scores'][i] for i in alignment_one_idx]

            misaligned_entoxx_scores = [entoxx_sample[lang]['scores'][i] for i in alignment_zero_idx]
            misaligned_xxtoen_scores = [xxtoen_sample[lang]['scores'][i] for i in alignment_zero_idx]

            entoxx_u_stat, entoxx_u_p_value = stats.mannwhitneyu(aligned_entoxx_scores, misaligned_entoxx_scores, alternative='greater')
            entoxx_t_stat, entoxx_t_p_value = stats.ttest_ind(aligned_entoxx_scores, misaligned_entoxx_scores, alternative='greater')

            entoxxdelta_results[lang]['delta'] = np.mean(aligned_entoxx_scores)-np.mean(misaligned_entoxx_scores)
            entoxxdelta_results[lang]['Utest_pval'] = entoxx_u_p_value
            entoxxdelta_results[lang]['ttest_pval'] = entoxx_t_p_value
            entoxxdelta_results[lang]['N_aligned'] = len(alignment_one_idx)
            entoxxdelta_results[lang]['N_nonaligned'] = len(alignment_zero_idx)

            xxtoen_u_stat, xxtoen_u_p_value = stats.mannwhitneyu(aligned_xxtoen_scores, misaligned_xxtoen_scores, alternative='greater')
            xxtoen_t_stat, xxtoen_t_p_value = stats.ttest_ind(aligned_xxtoen_scores, misaligned_xxtoen_scores, alternative='greater')

            xxtoendelta_results[lang]['delta'] = np.mean(aligned_xxtoen_scores)-np.mean(misaligned_xxtoen_scores)
            xxtoendelta_results[lang]['Utest_pval'] = xxtoen_u_p_value
            xxtoendelta_results[lang]['ttest_pval'] = xxtoen_t_p_value
            xxtoendelta_results[lang]['N_aligned'] = len(alignment_one_idx)
            xxtoendelta_results[lang]['N_nonaligned'] = len(alignment_zero_idx)

    return entoxxdelta_results, xxtoendelta_results

In [10]:
entoxx_sample, xxtoen_sample = load_translation_sample('belebele', 'Llama3.1', 'flores_passage')
entoxxdelta_results, xxtoendelta_results = analyze_sample_level_translation(entoxx_sample, xxtoen_sample, 'belebele')

# Convert entoxxdelta_results to a DataFrame
entoxx_delta_df = pd.DataFrame.from_dict(entoxxdelta_results, orient='index').reset_index()
xxtoen_delta_df = pd.DataFrame.from_dict(xxtoendelta_results, orient='index').reset_index()

# Rename the columns
entoxx_delta_df.rename(columns={'index': 'Language', 
                                 'delta': 'Delta', 
                                 'Utest_pval': 'Utest_pval', 
                                 'ttest_pval': 'Ttest_pval', 
                                 'N_aligned': 'N_aligned', 
                                 'N_nonaligned': 'N_nonaligned'}, inplace=True)

xxtoen_delta_df.rename(columns={'index': 'Language', 
                                 'delta': 'Delta', 
                                 'Utest_pval': 'Utest_pval', 
                                 'ttest_pval': 'Ttest_pval', 
                                 'N_aligned': 'N_aligned', 
                                 'N_nonaligned': 'N_nonaligned'}, inplace=True)



entoxx_delta_df.to_excel("../../../../Images_DALI/belebele_plots/entoxx_samplelevel_belebele_delta.xlsx", index=False)
xxtoen_delta_df.to_excel("../../../../Images_DALI/belebele_plots/xxtoen_samplelevel_belebele_delta.xlsx", index=False)

In [11]:
entoxx_sample, xxtoen_sample = load_translation_sample('flores', 'Llama3.1', 'sentence')
entoxxdelta_results, xxtoendelta_results = analyze_sample_level_translation(entoxx_sample, xxtoen_sample, 'flores')

# Convert entoxxdelta_results to a DataFrame
entoxx_delta_df = pd.DataFrame.from_dict(entoxxdelta_results, orient='index').reset_index()
xxtoen_delta_df = pd.DataFrame.from_dict(xxtoendelta_results, orient='index').reset_index()

# Rename the columns
entoxx_delta_df.rename(columns={'index': 'Language', 
                                 'delta': 'Delta', 
                                 'Utest_pval': 'Utest_pval', 
                                 'ttest_pval': 'Ttest_pval', 
                                 'N_aligned': 'N_aligned', 
                                 'N_nonaligned': 'N_nonaligned'}, inplace=True)

xxtoen_delta_df.rename(columns={'index': 'Language', 
                                 'delta': 'Delta', 
                                 'Utest_pval': 'Utest_pval', 
                                 'ttest_pval': 'Ttest_pval', 
                                 'N_aligned': 'N_aligned', 
                                 'N_nonaligned': 'N_nonaligned'}, inplace=True)



entoxx_delta_df.to_excel("../../../../Images_DALI/belebele_plots/entoxx_samplelevel_flores_delta.xlsx", index=False)
xxtoen_delta_df.to_excel("../../../../Images_DALI/belebele_plots/xxtoen_samplelevel_flores_delta.xlsx", index=False)

In [22]:
entoxx_sample_input1, xxtoen_sample_input1 = load_translation_sample('xstorycloze', 'Llama3.1', 'input_sentence_1')
entoxx_sample_input2, xxtoen_sample_input2 = load_translation_sample('xstorycloze', 'Llama3.1', 'input_sentence_2')
entoxx_sample_input3, xxtoen_sample_input3 = load_translation_sample('xstorycloze', 'Llama3.1', 'input_sentence_3')
entoxx_sample_input4, xxtoen_sample_input4 = load_translation_sample('xstorycloze', 'Llama3.1', 'input_sentence_4')

entoxx_sample_premise = defaultdict(dict)
xxtoen_sample_premise = defaultdict(dict)

for lang in entoxx_sample_input1.keys():
    entoxx_sample_premise[lang]['scores'] = []
    
    entoxx_sample_premise[lang]['scores'].extend(entoxx_sample_input1[lang]['scores'])
    entoxx_sample_premise[lang]['scores'].extend(entoxx_sample_input2[lang]['scores'])
    entoxx_sample_premise[lang]['scores'].extend(entoxx_sample_input3[lang]['scores'])
    entoxx_sample_premise[lang]['scores'].extend(entoxx_sample_input4[lang]['scores'])
    
    xxtoen_sample_premise[lang]['scores'] = []
    
    xxtoen_sample_premise[lang]['scores'].extend(xxtoen_sample_input1[lang]['scores'])
    xxtoen_sample_premise[lang]['scores'].extend(xxtoen_sample_input2[lang]['scores'])
    xxtoen_sample_premise[lang]['scores'].extend(xxtoen_sample_input3[lang]['scores'])
    xxtoen_sample_premise[lang]['scores'].extend(xxtoen_sample_input4[lang]['scores'])






entoxxdelta_results, xxtoendelta_results = analyze_sample_level_translation(entoxx_sample_premise, xxtoen_sample_premise, 'xstorycloze')

# Convert entoxxdelta_results to a DataFrame
entoxx_delta_df = pd.DataFrame.from_dict(entoxxdelta_results, orient='index').reset_index()
xxtoen_delta_df = pd.DataFrame.from_dict(xxtoendelta_results, orient='index').reset_index()

# Rename the columns
entoxx_delta_df.rename(columns={'index': 'Language', 
                                 'delta': 'Delta', 
                                 'Utest_pval': 'Utest_pval', 
                                 'ttest_pval': 'Ttest_pval', 
                                 'N_aligned': 'N_aligned', 
                                 'N_nonaligned': 'N_nonaligned'}, inplace=True)

xxtoen_delta_df.rename(columns={'index': 'Language', 
                                 'delta': 'Delta', 
                                 'Utest_pval': 'Utest_pval', 
                                 'ttest_pval': 'Ttest_pval', 
                                 'N_aligned': 'N_aligned', 
                                 'N_nonaligned': 'N_nonaligned'}, inplace=True)



entoxx_delta_df.to_excel("../../../../Images_DALI/xstorycloze_plots/entoxx_samplelevel_premise_delta.xlsx", index=False)
xxtoen_delta_df.to_excel("../../../../Images_DALI/xstorycloze_plots/xxtoen_samplelevel_premise_delta.xlsx", index=False)





In [23]:
entoxx_sample, xxtoen_sample = load_translation_sample('xcopa', 'Llama3.1', 'premise')
entoxxdelta_results, xxtoendelta_results = analyze_sample_level_translation(entoxx_sample, xxtoen_sample, 'xcopa')

# Convert entoxxdelta_results to a DataFrame
entoxx_delta_df = pd.DataFrame.from_dict(entoxxdelta_results, orient='index').reset_index()
xxtoen_delta_df = pd.DataFrame.from_dict(xxtoendelta_results, orient='index').reset_index()

# Rename the columns
entoxx_delta_df.rename(columns={'index': 'Language', 
                                 'delta': 'Delta', 
                                 'Utest_pval': 'Utest_pval', 
                                 'ttest_pval': 'Ttest_pval', 
                                 'N_aligned': 'N_aligned', 
                                 'N_nonaligned': 'N_nonaligned'}, inplace=True)

xxtoen_delta_df.rename(columns={'index': 'Language', 
                                 'delta': 'Delta', 
                                 'Utest_pval': 'Utest_pval', 
                                 'ttest_pval': 'Ttest_pval', 
                                 'N_aligned': 'N_aligned', 
                                 'N_nonaligned': 'N_nonaligned'}, inplace=True)



entoxx_delta_df.to_excel("../../../../Images_DALI/xcopa_plots/entoxx_samplelevel_premise_delta.xlsx", index=False)
xxtoen_delta_df.to_excel("../../../../Images_DALI/xcopa_plots/xxtoen_samplelevel_premise_delta.xlsx", index=False)