In [1]:
from pathlib import Path
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import re
from io import StringIO


# set pandas display options
pd.set_option('display.max_columns', 500)
pd.set_option('display.width', 1000)



In [2]:
# Set the path to the data
data_dir = Path('../../models/hotels')

dfs = []
for model_dir in data_dir.iterdir():
    if model_dir.is_dir():
        model = model_dir.name

        seed = re.search(r'_s(\d+)', model)
        if seed:
            seed = int(seed.group(1))
        else:
            seed = 0
        
        # strip the seed from the model name
        model = re.sub(r'_s\d+', '', model)

        print(model, seed)
        
        results_file = model_dir / 'inference' / 'eval_result.txt'
        if not results_file.exists():
            print(f'No results for {model} {seed}')
            continue
            # # try:
            # #     # Try to find the results file in a subdirectory with wildcards
            # #     results_file = model_dir / 'inference' / '*/eval_result.txt'
                
            # #     if not results_file.exists():
            #         print(f'No results for {model} {seed}')
            #         continue
    
        with open(results_file, 'r', encoding='utf8') as f:
            file_contents = [line.strip() for line in f.readlines()]
            
            print(len(file_contents))
            if len(file_contents) == 4:
                print('skipping first line')
                file_contents = file_contents[1:]

            csv_string = StringIO('\n'.join(file_contents))
            df = pd.read_csv(csv_string, sep=",", header=0)
            
            # rename Unnamed: 0 to file
            df.rename(columns={'Unnamed: 0': 'file'}, inplace=True)

            df['model'] = model
            df['seed'] = seed

            # put the model and seed columns first
            cols = df.columns.tolist()
            cols = cols[-2:] + cols[:-2]
            df = df[cols]

            dfs.append(df)

df = pd.concat(dfs)
print(df.columns)
df

baseline 0
3
baseline 42
3
baseline 985
3
filt_freq_distro 0
3
filt_freq_distro 42
3
filt_freq_distro 985
3
filt_gen_sent 0
3
filt_gen_sent 42
3
filt_gen_sent 985
3
filt_rrsts 0
3
filt_rrtfidf 0
4
skipping first line
filt_tfidf 0
3
filt_tgt_ppl 0
3
filt_combo 0
3
filt_tgt_ppl_abl_20 0
3
filt_tgt_ppl_abl_60 0
3
filt_tgt_ppl_abl_80 0
3
filt_tgt_ppl 42
4
skipping first line
filt_tgt_ppl 985
4
skipping first line
rule_based 0
No results for rule_based 0
label_tgt_ppl 0
No results for label_tgt_ppl 0
Index(['model', 'seed', 'file', 'test set size', 'BLEU', 'ROUGE-1', 'ROUGE-2', 'ROUGE-L', 'METEOR', 'SARI-BP', 'SARI', 'CHRF-tgt', 'CHRF-src', 'intraDIST-1', 'intraDIST-2', 'interDIST-1', 'interDIST-2', 'Self-BLEU', 'rep-r', 'rep-w', 'seq-rep-n', 'paraphrase_reps', 'src-tgt sts', 'domain acc', 'rating acc', 'source acc', 'hyp lens', 'Uniq'], dtype='object')


Unnamed: 0,model,seed,file,test set size,BLEU,ROUGE-1,ROUGE-2,ROUGE-L,METEOR,SARI-BP,SARI,CHRF-tgt,CHRF-src,intraDIST-1,intraDIST-2,interDIST-1,interDIST-2,Self-BLEU,rep-r,rep-w,seq-rep-n,paraphrase_reps,src-tgt sts,domain acc,rating acc,source acc,hyp lens,Uniq
0,baseline,0,/srv/scratch6/kew/bart/hospo_respo/en/500k//ba...,24736,0.1137,0.415,0.171,0.31,0.311,0.3046,0.4424,0.305,0.158,0.7695,0.9783,0.005,0.02,0.271,0.0801,0.1377,0.0648,,0.4344,,,,59.0884,
0,baseline,42,/srv/scratch6/kew/bart/hospo_respo/en/500k//ba...,24736,0.1121,0.413,0.171,0.309,0.31,0.3008,0.4417,0.302,0.157,0.7694,0.9782,0.0048,0.0191,0.2743,0.0813,0.1394,0.0649,,0.4305,,,,58.63,6994.0
0,baseline,985,/srv/scratch6/kew/bart/hospo_respo/en/500k//ba...,24736,0.115,0.414,0.17,0.308,0.312,0.3076,0.4418,0.307,0.161,0.7663,0.9774,0.0049,0.0195,0.1927,0.084,0.1381,0.0658,,0.4367,,,,60.4496,7254.0
0,filt_freq_distro,0,/srv/scratch6/kew/bart/hospo_respo/en/500k//fi...,24736,0.1156,0.392,0.145,0.277,0.319,0.3666,0.4358,0.334,0.206,0.7473,0.9734,0.0058,0.0262,0.1468,0.0982,0.1228,0.0712,,0.4792,,,,83.1926,
0,filt_freq_distro,42,/srv/scratch6/kew/bart/hospo_respo/en/500k//fi...,24736,0.1207,0.398,0.15,0.284,0.326,0.3722,0.4383,0.338,0.208,0.7385,0.9721,0.0058,0.0261,0.1689,0.1032,0.1276,0.0737,,0.4872,,,,81.9,11824.0
0,filt_freq_distro,985,/srv/scratch6/kew/bart/hospo_respo/en/500k//fi...,24736,0.1194,0.396,0.149,0.282,0.323,0.3658,0.4374,0.336,0.205,0.744,0.9723,0.0059,0.0265,0.1453,0.1024,0.1252,0.0723,,0.4832,,,,82.0029,11918.0
0,filt_gen_sent,0,/srv/scratch6/kew/bart/hospo_respo/en/500k//fi...,24736,0.1136,0.4,0.151,0.287,0.317,0.3544,0.4363,0.324,0.2,0.7367,0.9702,0.0064,0.0296,0.1025,0.109,0.1348,0.0748,,0.4814,,,,74.7225,
0,filt_gen_sent,42,/srv/scratch6/kew/bart/hospo_respo/en/500k//fi...,24736,0.1131,0.395,0.147,0.282,0.315,0.3575,0.4351,0.327,0.207,0.7316,0.9687,0.0064,0.0302,0.119,0.1143,0.1352,0.0764,,0.4833,,,,77.6362,12235.0
0,filt_gen_sent,985,/srv/scratch6/kew/bart/hospo_respo/en/500k//fi...,24736,0.1127,0.399,0.15,0.286,0.316,0.3546,0.4361,0.325,0.199,0.7355,0.97,0.0063,0.0285,0.1719,0.1092,0.1362,0.0751,,0.4768,,,,74.7043,11553.0
0,filt_rrsts,0,/srv/scratch6/kew/bart/hospo_respo/en/500k/fil...,24736,0.1277,0.421,0.171,0.307,0.33,0.3433,0.4453,0.332,0.194,0.7349,0.9692,0.0059,0.0266,0.179,0.1137,0.1435,0.0757,,0.5298,,,,70.4874,


In [3]:
# compute the mean and std of metrics across seeds

agg_df = df[df['model'].isin(['baseline', 'filt_tgt_ppl', 'filt_gen_sent', 'filt_freq_distro'])].groupby(['model']).agg(['mean', 'std'])
agg_df

Unnamed: 0_level_0,seed,seed,test set size,test set size,BLEU,BLEU,ROUGE-1,ROUGE-1,ROUGE-2,ROUGE-2,ROUGE-L,ROUGE-L,METEOR,METEOR,SARI-BP,SARI-BP,SARI,SARI,CHRF-tgt,CHRF-tgt,CHRF-src,CHRF-src,intraDIST-1,intraDIST-1,intraDIST-2,intraDIST-2,interDIST-1,interDIST-1,interDIST-2,interDIST-2,Self-BLEU,Self-BLEU,rep-r,rep-r,rep-w,rep-w,seq-rep-n,seq-rep-n,paraphrase_reps,paraphrase_reps,src-tgt sts,src-tgt sts,domain acc,domain acc,rating acc,rating acc,source acc,source acc,hyp lens,hyp lens,Uniq,Uniq
Unnamed: 0_level_1,mean,std,mean,std,mean,std,mean,std,mean,std,mean,std,mean,std,mean,std,mean,std,mean,std,mean,std,mean,std,mean,std,mean,std,mean,std,mean,std,mean,std,mean,std,mean,std,mean,std,mean,std,mean,std,mean,std,mean,std,mean,std,mean,std
model,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2,Unnamed: 18_level_2,Unnamed: 19_level_2,Unnamed: 20_level_2,Unnamed: 21_level_2,Unnamed: 22_level_2,Unnamed: 23_level_2,Unnamed: 24_level_2,Unnamed: 25_level_2,Unnamed: 26_level_2,Unnamed: 27_level_2,Unnamed: 28_level_2,Unnamed: 29_level_2,Unnamed: 30_level_2,Unnamed: 31_level_2,Unnamed: 32_level_2,Unnamed: 33_level_2,Unnamed: 34_level_2,Unnamed: 35_level_2,Unnamed: 36_level_2,Unnamed: 37_level_2,Unnamed: 38_level_2,Unnamed: 39_level_2,Unnamed: 40_level_2,Unnamed: 41_level_2,Unnamed: 42_level_2,Unnamed: 43_level_2,Unnamed: 44_level_2,Unnamed: 45_level_2,Unnamed: 46_level_2,Unnamed: 47_level_2,Unnamed: 48_level_2,Unnamed: 49_level_2,Unnamed: 50_level_2,Unnamed: 51_level_2,Unnamed: 52_level_2
baseline,342.333333,556.961698,24736,0.0,0.1136,0.001453,0.414,0.001,0.170667,0.000577,0.309,0.001,0.311,0.001,0.304333,0.003408,0.441967,0.000379,0.304667,0.002517,0.158667,0.002082,0.7684,0.001819,0.977967,0.000493,0.0049,0.0001,0.019533,0.000451,0.246,0.046189,0.0818,0.001997,0.1384,0.000889,0.065167,0.000551,,,0.433867,0.003134,,,,,,,59.389333,0.946391,7124.0,183.847763
filt_freq_distro,342.333333,556.961698,24736,0.0,0.118567,0.00265,0.395333,0.003055,0.148,0.002646,0.281,0.003606,0.322667,0.003512,0.3682,0.003487,0.437167,0.001266,0.336,0.002,0.206333,0.001528,0.743267,0.004446,0.9726,0.0007,0.005833,5.8e-05,0.026267,0.000208,0.153667,0.013214,0.101267,0.002686,0.1252,0.0024,0.0724,0.001253,,,0.4832,0.004,,,,,,,82.365167,0.718423,11871.0,66.468037
filt_gen_sent,342.333333,556.961698,24736,0.0,0.113133,0.000451,0.398,0.002646,0.149333,0.002082,0.285,0.002646,0.316,0.001,0.3555,0.001735,0.435833,0.000643,0.325333,0.001528,0.202,0.004359,0.7346,0.002666,0.969633,0.000814,0.006367,5.8e-05,0.029433,0.000862,0.131133,0.036256,0.110833,0.003004,0.1354,0.000721,0.075433,0.00085,,,0.4805,0.003342,,,,,,,75.687667,1.687504,11894.0,482.246825
filt_tgt_ppl,342.333333,556.961698,24736,0.0,0.115333,0.00125,0.403,0.001,0.151333,0.001528,0.288333,0.001528,0.319667,0.000577,0.355133,0.002363,0.436633,0.000252,0.326333,0.002082,0.21,0.003606,0.745067,0.002517,0.973033,0.000681,0.007333,0.000153,0.037133,0.00125,0.042367,0.009069,0.099767,0.002228,0.129133,0.00125,0.071933,0.000751,,,0.5075,0.003568,,,0.9685,0.002121,,,73.817133,1.374777,13274.5,440.527525


In [5]:
# get max and min std values for each metric
stds = agg_df.loc[:, (slice(None), 'std')]
stds


Unnamed: 0_level_0,seed,test set size,BLEU,ROUGE-1,ROUGE-2,ROUGE-L,METEOR,SARI-BP,SARI,CHRF-tgt,CHRF-src,intraDIST-1,intraDIST-2,interDIST-1,interDIST-2,Self-BLEU,rep-r,rep-w,seq-rep-n,paraphrase_reps,src-tgt sts,domain acc,rating acc,source acc,hyp lens,Uniq
Unnamed: 0_level_1,std,std,std,std,std,std,std,std,std,std,std,std,std,std,std,std,std,std,std,std,std,std,std,std,std,std
model,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2,Unnamed: 18_level_2,Unnamed: 19_level_2,Unnamed: 20_level_2,Unnamed: 21_level_2,Unnamed: 22_level_2,Unnamed: 23_level_2,Unnamed: 24_level_2,Unnamed: 25_level_2,Unnamed: 26_level_2
baseline,556.961698,0.0,0.001453,0.001,0.000577,0.001,0.001,0.003408,0.000379,0.002517,0.002082,0.001819,0.000493,0.0001,0.000451,0.046189,0.001997,0.000889,0.000551,,0.003134,,,,0.946391,183.847763
filt_freq_distro,556.961698,0.0,0.00265,0.003055,0.002646,0.003606,0.003512,0.003487,0.001266,0.002,0.001528,0.004446,0.0007,5.8e-05,0.000208,0.013214,0.002686,0.0024,0.001253,,0.004,,,,0.718423,66.468037
filt_gen_sent,556.961698,0.0,0.000451,0.002646,0.002082,0.002646,0.001,0.001735,0.000643,0.001528,0.004359,0.002666,0.000814,5.8e-05,0.000862,0.036256,0.003004,0.000721,0.00085,,0.003342,,,,1.687504,482.246825
filt_tgt_ppl,556.961698,0.0,0.00125,0.001,0.001528,0.001528,0.000577,0.002363,0.000252,0.002082,0.003606,0.002517,0.000681,0.000153,0.00125,0.009069,0.002228,0.00125,0.000751,,0.003568,,0.002121,,1.374777,440.527525
