In [1]:
import os
import warnings

import pandas as pd

In [2]:
# Intermediate checkpoints (training global steps) to aggregate, per BLOOM model size.
model_ckpt_dict = {
    model_name: [1_000, 10_000, *range(100_000, 600_001, 100_000)]
    for model_name in ('bloom-560m', 'bloom-1b1')
}
model_ckpt_dict['bloom-1b7'] = [10_000, *range(50_000, 250_001, 50_000)]

# Target languages; each is read from an `en-{lang}.csv` similarity file.
trg_langs = 'ar es eu fr hi pt ta ur vi'.split()

# Root directory holding one sub-directory per model / checkpoint.
similarity_result_path = '../parallel-sentence-similarity/experiments/cos_similarity_csv/'
# Destination directory for the aggregated per-model CSVs.
output_csv_file_path = 'csv_files/'
# Layer sub-directory whose similarity results are aggregated.
layer = 'inter-layer-17'

In [3]:
def _read_last_row_value(csv_path, column=1):
    """Return field `column` of the last non-blank line of `csv_path`, parsed as float.

    Streams the file line by line, so only the current line is held in memory.
    Trailing blank lines are ignored rather than being mistaken for the data row.

    Raises:
        ValueError: if the file has no non-blank lines, or its last row has fewer
            than `column + 1` comma-separated fields.
    """
    last_line = None
    with open(csv_path, 'r', encoding='utf-8') as f:
        for line in f:
            if line.strip():
                last_line = line
    if last_line is None:
        raise ValueError(f'{csv_path} contains no data rows')
    # Plain comma split (no CSV quoting), matching how these result files are read elsewhere.
    fields = last_line.rstrip('\r\n').split(',')
    if len(fields) <= column:
        raise ValueError(f'{csv_path}: last row {last_line.strip()!r} has no column {column}')
    return float(fields[column])


def read_similarity_by_langs(path, lang_list, read_code=False):
    """Collect the similarity score stored on the last row of each result CSV in `path`.

    Args:
        path: directory containing one `en-{lang}.csv` file per target language
            (and `nl-code.csv` when `read_code` is True).
        lang_list: target language codes to read.
        read_code: if True, also read `nl-code.csv` and store it under key 'code'.

    Returns:
        Dict mapping each language (then 'code', if requested) to the float found in
        column 1 of that file's last non-blank line.
        NOTE(review): assumes the last row holds the aggregate score — confirm against
        the script that writes these CSVs.

    Raises:
        OSError: if a result file is missing or unreadable.
        ValueError: if a result file is empty or its last row lacks column 1.
    """
    similarity_dict = {
        lang: _read_last_row_value(os.path.join(path, f'en-{lang}.csv'))
        for lang in lang_list
    }
    if read_code:
        similarity_dict['code'] = _read_last_row_value(os.path.join(path, 'nl-code.csv'))
    return similarity_dict

In [4]:
# Create the output directory up front so the notebook survives Restart & Run All
# on a fresh checkout (to_csv does not create missing directories).
os.makedirs(output_csv_file_path, exist_ok=True)

for model, ckpt_list in model_ckpt_dict.items():
    output_csv_file_name = os.path.join(output_csv_file_path, f'{model}_{layer}_cos-similarity.csv')

    # Dict{ckpt: Dict{lang: similarity}}. 'best' is read from the plain model
    # directory; the integer keys come from the intermediate-checkpoint directories.
    ckpt_similarity_dict = {
        'best': read_similarity_by_langs(
            os.path.join(similarity_result_path, model, layer), trg_langs, read_code=True
        ),
    }
    for ckpt in ckpt_list:
        ckpt_dir = os.path.join(similarity_result_path, f'{model}-intermediate-global_step{ckpt}', layer)
        ckpt_similarity_dict[ckpt] = read_similarity_by_langs(ckpt_dir, trg_langs, read_code=True)

    # Outer dict keys become columns: rows are languages (plus 'code'), columns are checkpoints.
    df = pd.DataFrame.from_dict(ckpt_similarity_dict)

    # Sanity check: two checkpoints with byte-identical scores almost certainly means a
    # stale or copied result file. NOTE(review): the saved output for bloom-560m shows
    # step 500000 repeating step 10000 exactly (ar 0.821778, es 0.902119, ...) — verify.
    dup_mask = df.T.duplicated()
    duplicated_ckpts = dup_mask[dup_mask].index.tolist()
    if duplicated_ckpts:
        warnings.warn(
            f'{model}: checkpoint columns {duplicated_ckpts} exactly repeat an earlier '
            'column; check the underlying similarity CSVs for stale or copied results.'
        )

    print(df)
    df.to_csv(output_csv_file_name)




          best      1000     10000    100000    200000    300000    400000  \
ar    0.798261  0.576057  0.821778  0.759966  0.797074  0.872960  0.846931   
es    0.913211  0.818782  0.902119  0.887048  0.878481  0.939869  0.926818   
eu    0.854765  0.773876  0.863074  0.851020  0.820118  0.905174  0.896829   
fr    0.906356  0.816407  0.879427  0.836612  0.816979  0.937361  0.923263   
hi    0.851320  0.761199  0.914003  0.945328  0.926196  0.914875  0.893790   
pt    0.895442  0.838145  0.892464  0.899645  0.886876  0.931779  0.921725   
ta    0.752443  0.713010  0.903304  0.920436  0.812374  0.859095  0.859116   
ur    0.767284  0.588278  0.857497  0.830784  0.864675  0.855181  0.838212   
vi    0.892080  0.787140  0.905484  0.926244  0.918944  0.932927  0.919620   
code  0.309845  0.018295  0.241529  0.555214  0.460839  0.411699  0.382130   

        500000    600000  
ar    0.821778  0.341260  
es    0.902119  0.751111  
eu    0.863074  0.767098  
fr    0.879427  0.755451  
hi    