In [1]:
import numpy as np
import os
import orjson
import collections
from tqdm.notebook import tqdm

from transformers import AutoTokenizer

In [2]:
def get_token_idx_from_line(line_idx, indices):
    """Locate the positions in ``indices`` whose value equals ``line_idx``.

    Returns the raw ``np.where`` result: a tuple holding one index array
    per dimension of ``indices``.
    """
    belongs_to_line = indices == line_idx
    return np.where(belongs_to_line)

def get_token_input_ids_from_token_ids(token_indices, input_ids):
    """Index ``input_ids`` with ``token_indices`` and return the selection."""
    selected_input_ids = input_ids[token_indices]
    return selected_input_ids

def get_token_input_ids_from_line(line_idx, indices, input_ids):
    """Return the entries of ``input_ids`` at the positions where ``indices == line_idx``."""
    line_positions = np.where(indices == line_idx)
    return input_ids[line_positions]

def get_all_line_idx(indices):
    """Return the distinct values of ``indices`` as produced by ``np.unique`` (sorted)."""
    distinct_line_indices = np.unique(indices)
    return distinct_line_indices

def get_first_available_line(indices):
    """Return the first element of ``indices`` (position 0, not the minimum)."""
    first_entry = indices[0]
    return first_entry


In [3]:
def load_data(load_dir):
    """Load every ``.npy`` and ``.jsonl`` file found directly inside ``load_dir``.

    Each loaded value is stored under the file name with its extension
    removed.  ``.npy`` files are read with ``np.load``.  For ``.jsonl``
    files only the first line is parsed (with ``orjson``), because each
    file is assumed to hold a single record; this must be revisited if
    that ever changes.  Files with any other extension are ignored.

    Args:
        load_dir: Directory to scan (not recursive).

    Returns:
        dict mapping the extension-less file name to the loaded value.

    Raises:
        ValueError: if a ``.jsonl`` file contains no lines at all.
    """
    data = {}

    for found_file in os.listdir(load_dir):
        full_file_path = os.path.join(load_dir, found_file)

        # Match on the real extension: a substring test would also pick up
        # names such as 'foo.npy.txt' and then slice the key incorrectly.
        if found_file.endswith('.npy'):
            v_name = found_file[:-len('.npy')]
            data[v_name] = np.load(full_file_path)

        elif found_file.endswith('.jsonl'):
            v_name = found_file[:-len('.jsonl')]

            # orjson parses UTF-8 bytes directly, so binary mode avoids a
            # decode step that would depend on the platform's locale.
            with open(full_file_path, 'rb') as f:
                first_line = next(f, None)

            # Without this check an empty file would either raise an
            # unrelated NameError or silently reuse the previous file's value.
            if first_line is None:
                raise ValueError(f'{full_file_path} is empty; expected one JSON line')

            data[v_name] = orjson.loads(first_line.strip())

    return data
            

In [4]:
def get_simple_lang_distribution(loaded_data):
    """Count the source-language labels of the neighbours of every target token.

    For each distinct value in ``loaded_data['target_indices']`` (one value
    per line), the positions holding that value are used as row indices into
    ``loaded_data['neighbors_indices']``.  The neighbour indices found there
    are looked up in ``loaded_data['source_langs']`` and every resulting
    label is tallied.

    Only the three keys named above are read.

    Args:
        loaded_data: Mapping with NumPy arrays under ``'target_indices'``,
            ``'neighbors_indices'`` and ``'source_langs'``.
            NOTE(review): assumes ``target_indices`` is 1-D and
            ``neighbors_indices`` has one row of neighbour indices per
            target token -- confirm against the code that writes the files.

    Returns:
        collections.Counter mapping each label to the number of times it
        occurs among the neighbours.  Labels are converted with
        ``ndarray.tolist()``, so a unicode ``source_langs`` array yields
        plain ``str`` keys.
    """
    target_indices = loaded_data['target_indices']
    neighbors_indices = loaded_data['neighbors_indices']
    source_langs = loaded_data['source_langs']

    lang_distribution = collections.Counter()

    for line_idx in np.unique(target_indices):
        # Positions of the tokens that belong to this line.
        target_token_ids = np.where(target_indices == line_idx)[0]

        # Neighbour rows for all of this line's tokens at once, mapped to the
        # label of each neighbour.  Working one line at a time keeps the
        # temporary label array small.
        neighbor_langs = source_langs[neighbors_indices[target_token_ids]]

        # ravel() walks the tokens in order and, within a token, its
        # neighbours in order -- the same order a per-token loop would use.
        lang_distribution.update(neighbor_langs.ravel().tolist())

    return lang_distribution


In [5]:
# Root directory of the layer-8 nearest-neighbour outputs; one sub-directory per target language.
# Alternative root that was used before, kept for reference:
# root_load_dir = '/scratch/alpine/abeb4417/multilingual_analysis/nn_outputs/xlm-roberta-base/nn_layer_8/'
root_load_dir = '/scratch/alpine/abeb4417/multilingual_analysis/nn_outputs_ud/ud_nn_calculations_rerun/xlm-roberta-base/layer_8/'

# Maps each target-language directory name to the Counter returned by get_simple_lang_distribution.
lang_distributions_8 = {}

# NOTE(review): no live code visible in this notebook reads `tokenizer`.
tokenizer = AutoTokenizer.from_pretrained('xlm-roberta-base')

# To process every sub-directory of root_load_dir, swap in this loop header:
# for target_lang in tqdm(os.listdir(root_load_dir)):

for target_lang in ['fr_partut-ud-dev']:

    lang_load_dir = os.path.join(root_load_dir, target_lang)
    loaded_data = load_data(lang_load_dir)
    lang_distributions_8[target_lang] = get_simple_lang_distribution(loaded_data)
    

    

In [7]:
# NOTE(review): `lang_distributions_0` is not assigned in any cell visible in
# this notebook, so this cell raises NameError on a fresh kernel. It needs a
# cell that builds the layer-0 distributions (and an 'eng' entry) to run first.
lang_distributions_0['eng'].most_common(10)

NameError: name 'lang_distributions_0' is not defined

In [8]:
# Full ranking (most frequent first) of neighbour source languages for the fr_partut dev set.
fr_partut_ranked = lang_distributions_8['fr_partut-ud-dev'].most_common()
print(fr_partut_ranked)

[('fro_profiterole-ud-train', 2644), ('ca_ancora-ud-train', 940), ('it_old-ud-train', 877), ('es_ancora-ud-train', 793), ('es_gsd-ud-train', 742), ('en_partut-ud-train', 682), ('it_isdt-ud-train', 660), ('it_parlamint-ud-train', 649), ('gl_ctg-ud-train', 634), ('it_partut-ud-train', 620), ('qaf_arabizi-ud-train', 602), ('it_markit-ud-train', 539), ('en_eslspok-ud-train', 533), ('pt_petrogold-ud-train', 508), ('pt_bosque-ud-train', 473), ('it_vit-ud-train', 471), ('en_gum-ud-train', 439), ('gl_treegal-ud-train', 424), ('en_ewt-ud-train', 407), ('pt_porttinari-ud-train', 404), ('it_twittiro-ud-train', 378), ('en_lines-ud-train', 364), ('pt_cintil-ud-train', 349), ('it_postwita-ud-train', 340), ('pt_gsd-ud-train', 326), ('la_udante-ud-train', 308), ('la_perseus-ud-train', 277), ('la_proiel-ud-train', 223), ('ro_simonero-ud-train', 222), ('nl_alpino-ud-train', 194), ('la_ittb-ud-train', 184), ('nl_lassysmall-ud-train', 182), ('fi_tdt-ud-train', 176), ('ro_rrt-ud-train', 171), ('ro_nonstand

In [81]:
from itertools import zip_longest

from iso639 import Lang


def _format_ranked_entry(entry):
    """Render one ``(language code, count)`` pair as a name/count column; blank when ``entry`` is None."""
    if entry is None:
        return ' ' * 40
    code, count = entry
    return f'{Lang(code).name:30s}{count:>10d}'


# NOTE(review): `lang_distributions_0` is not assigned in any cell visible in
# this notebook; it must already exist in the kernel for this cell to run.
langs = sorted(lang_distributions_0)

# Side-by-side ranking of neighbour languages at layer 0 and layer 8 for each target language.
for lang in langs:
    print(f'Target Language: {Lang(lang).name}')
    print(f'{"Layer 0":^40s}{"Layer 8":^40s}')

    layer_0_ranked = lang_distributions_0[lang].most_common()
    layer_8_ranked = lang_distributions_8[lang].most_common()

    # zip() would stop at the shorter ranking and silently drop the remaining
    # rows of the longer one; zip_longest keeps them and pads with a blank column.
    for z, e in zip_longest(layer_0_ranked, layer_8_ranked):
        print(f'{_format_ranked_entry(z)} | {_format_ranked_entry(e)}')

    print('\n\n')
        
        

Target Language: Akan
                Layer 0                                 Layer 8                 
Bambara                            71098 | Bambara                            83304
Egyptian Arabic                    20266 | Kinyarwanda                        34486
Kinyarwanda                        17042 | Wolof                              24890
Buginese                           16828 | Yoruba                             22270
Yoruba                             13057 | Southern Sotho                     21096
Southern Sotho                     12828 | Buginese                           13360
Wolof                              12751 | Waray (Philippines)                 8374
Zulu                                7927 | Minangkabau                         6054
Minangkabau                         7920 | Zulu                                4591
Waray (Philippines)                 7690 | Tajik                               3611
Norwegian Nynorsk                   7641 | Tatar         