In [4]:
import os
import json
import numpy as np
from collections import defaultdict
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt
from datasets import load_dataset #load_dataset from Huggingface
from scipy import stats
from scipy.stats import rankdata, spearmanr, pearsonr
import statsmodels.stats.proportion as smp

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
# Maps a lowercase English language name to its `<lang>_<Script>` code
# (three-letter language tag, underscore, four-letter script tag with a
# leading capital). NOTE(review): these look like the FLORES-200 / Belebele
# configuration names -- confirm against the dataset card.
# Insertion order matters: LANGUAGE and LANGUAGE_wo_ENGLISH below inherit it.
LANG_DICT = {
    'afrikaans': 'afr_Latn',
    'english': 'eng_Latn',
    'amharic': 'amh_Ethi',
    'armenian': 'hye_Armn',
    'assamese': 'asm_Beng',
    'basque': 'eus_Latn',
    'bengali': 'ben_Beng',
    'bulgarian': 'bul_Cyrl',
    'burmese': 'mya_Mymr',
    'catalan': 'cat_Latn',
    'central kurdish': 'ckb_Arab',
    'croatian': 'hrv_Latn',
    'dutch': 'nld_Latn',
    'xhosa': 'xho_Latn',
    'macedonian': 'mkd_Cyrl',
    'czech': 'ces_Latn',
    'danish': 'dan_Latn',
    'eastern panjabi': 'pan_Guru',
    'egyptian arabic': 'arz_Arab',
    'estonian': 'est_Latn',
    'finnish': 'fin_Latn',
    'french': 'fra_Latn',
    'georgian': 'kat_Geor',
    'german': 'deu_Latn',
    'greek': 'ell_Grek',
    'gujarati': 'guj_Gujr',
    'hausa': 'hau_Latn',
    'hebrew': 'heb_Hebr',
    'hindi': 'hin_Deva',
    'hungarian': 'hun_Latn',
    'icelandic': 'isl_Latn',
    'indonesian': 'ind_Latn',
    'italian': 'ita_Latn',
    'japanese': 'jpn_Jpan',
    'javanese': 'jav_Latn',
    'kannada': 'kan_Knda',
    'kazakh': 'kaz_Cyrl',
    'khmer': 'khm_Khmr',
    'korean': 'kor_Hang',
    'kyrgyz': 'kir_Cyrl',
    'lao': 'lao_Laoo',
    'lithuanian': 'lit_Latn',
    'malayalam': 'mal_Mlym',
    'marathi': 'mar_Deva',
    'mesopotamian arabic': 'acm_Arab',
    'modern standard arabic': 'arb_Arab',
    # Fixed: was 'ary_arab' (lowercase script tag), inconsistent with every
    # other entry and a mismatch for case-sensitive paths / config names.
    'moroccan arabic': 'ary_Arab',
    'najdi arabic': 'ars_Arab',
    'nepali': 'npi_Deva',
    'north azerbaijani': 'azj_Latn',
    'north levantine arabic': 'apc_Arab',
    'northern uzbek': 'uzn_Latn',
    'norwegian bokmal': 'nob_Latn',
    'odia': 'ory_Orya',
    'polish': 'pol_Latn',
    'portuguese': 'por_Latn',
    'romanian': 'ron_Latn',
    'russian': 'rus_Cyrl',
    'serbian': 'srp_Cyrl',
    'simplified chinese': 'zho_Hans',
    'sindhi': 'snd_Arab',
    'sinhala': 'sin_Sinh',
    'slovak': 'slk_Latn',
    'slovenian': 'slv_Latn',
    'somali': 'som_Latn',
    'southern pashto': 'pbt_Arab',
    'spanish': 'spa_Latn',
    'standard latvian': 'lvs_Latn',
    'standard malay': 'zsm_Latn',
    'sundanese': 'sun_Latn',
    'swahili': 'swh_Latn',
    'swedish': 'swe_Latn',
    'tamil': 'tam_Taml',
    'telugu': 'tel_Telu',
    'thai': 'tha_Thai',
    'tosk albanian': 'als_Latn',
    'traditional chinese': 'zho_Hant',
    'turkish': 'tur_Latn',
    'ukrainian': 'ukr_Cyrl',
    'urdu': 'urd_Arab',
    'vietnamese': 'vie_Latn',
    'western persian': 'pes_Arab',
}

# All language names, in the dictionary's insertion order.
LANGUAGE = list(LANG_DICT)
# Same list with English removed.
LANGUAGE_wo_ENGLISH = [name for name in LANG_DICT if name != 'english']

In [47]:
# Number of answer options read from each record's 'resps' list.
NUM_CHOICES = 4


def _first_jsonl_name(directory):
    """Return the name of the alphabetically first ``.jsonl`` file in ``directory``.

    ``os.listdir`` returns entries in arbitrary order, so the names are sorted
    to make the selection reproducible when several result files are present.

    Raises:
        FileNotFoundError: if ``directory`` contains no ``.jsonl`` file.
    """
    jsonl_names = sorted(name for name in os.listdir(directory) if name.endswith('.jsonl'))
    if not jsonl_names:
        raise FileNotFoundError(f'No .jsonl file found in {directory}')
    return jsonl_names[0]


def _read_jsonl(path):
    """Parse a JSON-lines file into a list of objects, skipping blank lines."""
    # Explicit UTF-8 so parsing does not depend on the platform's locale encoding.
    with open(path, 'r', encoding='utf-8') as f:
        return [json.loads(line) for line in f if line.strip()]


def _rank_wrong_choices(log_prob, correct_answer_num):
    """Return the 1-based wrong choice numbers ordered by ascending score.

    Args:
        log_prob: one score per answer option; option ``k`` is ``log_prob[k - 1]``.
        correct_answer_num: 1-based number of the correct option (excluded).

    Returns:
        Wrong option numbers, lowest score first. The sort is stable, so ties
        keep ascending option order (same pick as ``min`` for the first entry).
    """
    wrong_choices = [c for c in range(1, len(log_prob) + 1) if c != correct_answer_num]
    return sorted(wrong_choices, key=lambda c: log_prob[c - 1])


# lang -> {record index -> lowest-scoring wrong choice (1-based)}
worst_choice_dict = defaultdict(dict)
# lang -> {record index -> the two lowest-scoring wrong choices, in ascending choice order}
worst_two_choices_dict = defaultdict(dict)

for lang in ['fra_Latn']:
    accuracy_data_path = f'../../accuracy_outputs/Llama3.1/belebele_5shot/{lang}/'
    jsonl_file = _first_jsonl_name(accuracy_data_path)
    file_path = os.path.join(accuracy_data_path, jsonl_file)
    accuracy_results = _read_jsonl(file_path)

    # Iterate over whatever the file holds instead of a hard-coded 900 records,
    # which raised IndexError on shorter files and truncated longer ones.
    # NOTE(review): keys are positions in the file, not document ids -- confirm
    # that the file is ordered the way downstream cells expect.
    for i, record in enumerate(accuracy_results):
        # record['resps'][k][0][0] is treated as the score (log-probability, per
        # the original variable name) of option k + 1 -- confirm against the schema.
        log_prob = [float(record['resps'][k][0][0]) for k in range(NUM_CHOICES)]
        correct_answer_num = int(record['doc']['correct_answer_num'])
        ranked_wrong = _rank_wrong_choices(log_prob, correct_answer_num)
        worst_choice_dict[lang][i] = ranked_wrong[0]
        worst_two_choices_dict[lang][i] = sorted(ranked_wrong[:2])

{0: [2, 3],
 1: [3, 4],
 2: [1, 3],
 3: [3, 4],
 4: [1, 4],
 5: [1, 2],
 6: [2, 4],
 7: [2, 3],
 8: [1, 2],
 9: [1, 2],
 10: [1, 3],
 11: [1, 2],
 12: [2, 4],
 13: [1, 3],
 14: [1, 4],
 15: [1, 2],
 16: [3, 4],
 17: [1, 4],
 18: [2, 3],
 19: [2, 4],
 20: [3, 4],
 21: [1, 3],
 22: [3, 4],
 23: [3, 4],
 24: [1, 2],
 25: [2, 3],
 26: [1, 2],
 27: [1, 3],
 28: [1, 4],
 29: [1, 3],
 30: [3, 4],
 31: [2, 4],
 32: [1, 3],
 33: [2, 4],
 34: [2, 4],
 35: [1, 2],
 36: [3, 4],
 37: [3, 4],
 38: [3, 4],
 39: [2, 3],
 40: [1, 4],
 41: [1, 3],
 42: [3, 4],
 43: [1, 2],
 44: [1, 4],
 45: [2, 4],
 46: [3, 4],
 47: [2, 4],
 48: [3, 4],
 49: [2, 3],
 50: [2, 4],
 51: [2, 4],
 52: [1, 4],
 53: [1, 2],
 54: [2, 4],
 55: [3, 4],
 56: [1, 3],
 57: [1, 4],
 58: [1, 2],
 59: [1, 2],
 60: [2, 4],
 61: [2, 3],
 62: [1, 3],
 63: [2, 3],
 64: [3, 4],
 65: [2, 3],
 66: [1, 4],
 67: [1, 3],
 68: [1, 4],
 69: [1, 4],
 70: [2, 3],
 71: [1, 2],
 72: [1, 2],
 73: [2, 3],
 74: [2, 4],
 75: [1, 4],
 76: [3, 4],
 77: [1, 

In [50]:
# Display the mapping {record index -> lowest-scoring wrong choice (1-based)} stored for 'fra_Latn'.
worst_choice_dict['fra_Latn']

{0: 2,
 1: 3,
 2: 3,
 3: 3,
 4: 4,
 5: 1,
 6: 2,
 7: 3,
 8: 1,
 9: 1,
 10: 1,
 11: 1,
 12: 2,
 13: 1,
 14: 1,
 15: 2,
 16: 3,
 17: 1,
 18: 3,
 19: 2,
 20: 3,
 21: 1,
 22: 4,
 23: 3,
 24: 2,
 25: 2,
 26: 1,
 27: 1,
 28: 1,
 29: 3,
 30: 3,
 31: 4,
 32: 1,
 33: 2,
 34: 4,
 35: 1,
 36: 3,
 37: 3,
 38: 3,
 39: 3,
 40: 4,
 41: 3,
 42: 4,
 43: 1,
 44: 1,
 45: 4,
 46: 3,
 47: 2,
 48: 4,
 49: 3,
 50: 2,
 51: 4,
 52: 1,
 53: 1,
 54: 4,
 55: 3,
 56: 1,
 57: 1,
 58: 2,
 59: 1,
 60: 4,
 61: 2,
 62: 1,
 63: 3,
 64: 4,
 65: 3,
 66: 1,
 67: 3,
 68: 4,
 69: 1,
 70: 3,
 71: 1,
 72: 2,
 73: 2,
 74: 4,
 75: 1,
 76: 3,
 77: 1,
 78: 3,
 79: 3,
 80: 1,
 81: 1,
 82: 1,
 83: 1,
 84: 4,
 85: 3,
 86: 1,
 87: 3,
 88: 3,
 89: 3,
 90: 2,
 91: 3,
 92: 1,
 93: 2,
 94: 4,
 95: 1,
 96: 2,
 97: 3,
 98: 4,
 99: 4,
 100: 3,
 101: 3,
 102: 4,
 103: 2,
 104: 4,
 105: 4,
 106: 2,
 107: 4,
 108: 1,
 109: 1,
 110: 4,
 111: 2,
 112: 1,
 113: 4,
 114: 4,
 115: 4,
 116: 2,
 117: 3,
 118: 1,
 119: 1,
 120: 1,
 121: 2,
 122: 1,
 12