In [1]:
import csv
import os
import pickle

import matplotlib.pyplot as plt
import numpy as np

# Root directory that holds the run / sample directories scanned below.
# Set the EVAL_BASE_DIR environment variable to point the notebook elsewhere;
# the default is the original hard-coded location.
base_dir = os.environ.get('EVAL_BASE_DIR', '/ebs/.cache/ubuntu')
# Output directory for aggregated artifacts such as sft_winrate.pkl.
eval_dir = os.path.join(base_dir, 'eval_files')
os.makedirs(eval_dir, exist_ok=True)
# Toggles for writing / merging the pickled winrate dict (sft_winrate.pkl).
save_win_dict = False
load_win_dict = False

In [14]:
def all_dirs_with_prefix(prefix):
    """Return full paths of the directories in ``base_dir`` whose names start with ``prefix``.

    Plain files are skipped (listing one would raise NotADirectoryError), and the
    result is sorted because ``os.listdir`` returns entries in arbitrary order.
    """
    return sorted(
        os.path.join(base_dir, d)
        for d in os.listdir(base_dir)
        if d.startswith(prefix) and os.path.isdir(os.path.join(base_dir, d))
    )


def _epoch_sort_key(dirname):
    """Sort key for 'epoch-N' names: numeric on N so 'epoch-12' comes after 'epoch-9'.

    Names without a trailing integer sort after all numbered ones, alphabetically.
    """
    suffix = dirname.rsplit('-', 1)[-1]
    return (0, int(suffix), dirname) if suffix.isdigit() else (1, 0, dirname)


def _epoch_subdirs(run_dirs):
    """Return the 'epoch*' checkpoint subdirectories of each run, in epoch order within each run.

    Order matters: later cells append scores in this order and plot them as a
    per-epoch series.
    """
    return [
        os.path.join(run_dir, entry)
        for run_dir in run_dirs
        for entry in sorted(
            (e for e in os.listdir(run_dir) if e.startswith('epoch')),
            key=_epoch_sort_key,
        )
    ]


all_sft_dirs = _epoch_subdirs(all_dirs_with_prefix('final_ds'))
base_eval_dir = [os.path.join(base_dir, 'llama7b_samples', 'alpaca_eval_nsamples1_maxlen512')]
all_dpo_dirs = _epoch_subdirs(all_dirs_with_prefix('DS_'))
# Judges whose alpaca_eval leaderboards are read for every run.
eval_list = ['gpt4', 'claude_1']

In [18]:
def _read_run_name(run_dir):
    """Return the first line of ``run_dir/name.txt``, or None if it is unreadable or empty."""
    try:
        with open(os.path.join(run_dir, 'name.txt')) as f:
            lines = f.read().splitlines()
    except OSError:
        return None
    return lines[0] if lines else None


def _read_leaderboard_score(csv_path, name):
    """Return the column-1 value of the row in ``csv_path`` whose first column equals ``name``.

    Returns None when no row matches; if several rows match, the last one wins.
    Raises OSError if the file cannot be opened.
    """
    score = None
    with open(csv_path, newline='') as f:
        for row in csv.reader(f):
            if len(row) > 1 and row[0] == name:
                score = float(row[1])
    return score


# winrate[run_name][judge] -> score read from that judge's leaderboard.csv
winrate = {}
for cdir in base_eval_dir + all_sft_dirs + all_dpo_dirs:
    # DPO runs: skip 'c2' runs and epochs 9 / 12.
    if 'DS_' in cdir and ('c2' in cdir or 'epoch-9' in cdir or 'epoch-12' in cdir):
        continue
    name = _read_run_name(cdir)
    if name is None:
        print(f'Nothing found for {cdir}')
        continue
    print(cdir, name)
    winrate[name] = {}
    for judge in eval_list:
        try:
            score = _read_leaderboard_score(os.path.join(cdir, judge, 'leaderboard.csv'), name)
        except OSError:
            print(f'Score not computed for {judge}')
            continue
        if score is not None:
            print(f'{judge} score for {name}: {score}')
            winrate[name][judge] = score

/ebs/.cache/ubuntu/llama7b_samples/alpaca_eval_nsamples1_maxlen512 llama7b_base
gpt4 score for llama7b_base: 0.375
claude_1 score for llama7b_base: 6.758448060075094
/ebs/.cache/ubuntu/final_ds_llama7b_1turn512_sft_bs8lr1e-6_df0.1_2023-10-06_17-11-26_207956/epoch-3 llama7b_sft0.1_e3
gpt4 score for llama7b_sft0.1_e3: 38.75
claude_1 score for llama7b_sft0.1_e3: 52.125
/ebs/.cache/ubuntu/final_ds_llama7b_1turn512_sft_bs8lr1e-6_df0.1_2023-10-06_17-11-26_207956/epoch-6 llama7b_sft0.1_e6
gpt4 score for llama7b_sft0.1_e6: 42.5625
claude_1 score for llama7b_sft0.1_e6: 53.8125
/ebs/.cache/ubuntu/final_ds_llama7b_1turn512_sft_bs8lr1e-6_df0.1_2023-10-06_17-11-26_207956/epoch-9 llama7b_sft0.1_e9
gpt4 score for llama7b_sft0.1_e9: 40.75
claude_1 score for llama7b_sft0.1_e9: 55.00000000000001
/ebs/.cache/ubuntu/final_ds_llama7b_1turn512_sft_bs8lr1e-6_df1.0_2023-10-06_17-11-29_798014/epoch-3 llama7b_sft1.0_e3
gpt4 score for llama7b_sft1.0_e3: 43.8125
claude_1 score for llama7b_sft1.0_e3: 55.1875
/ebs/

In [None]:
# Optionally persist this session's winrates and/or merge in a previously saved dict.
winrate_pkl_path = os.path.join(eval_dir, 'sft_winrate.pkl')

if save_win_dict:
    with open(winrate_pkl_path, 'wb') as f:
        pickle.dump(winrate, f)

if load_win_dict:
    # pickle.load can execute arbitrary code: only load files this notebook wrote itself.
    with open(winrate_pkl_path, 'rb') as f:
        winrate_addn = pickle.load(f)
    # On name collisions the loaded entries override the freshly computed ones.
    winrate = {**winrate, **winrate_addn}

def _split_run_name(key, prefix='llama7b_'):
    """Split a run name such as 'llama7b_sft0.1_e12' into ('sft0.1', 12).

    The leading ``prefix`` is dropped when present. Returns ``(model_name, epoch)``;
    ``epoch`` is None when there is no trailing '_e<N>' suffix, in which case
    ``model_name`` is the whole prefix-stripped name.
    """
    stem = key[len(prefix):] if key.startswith(prefix) else key
    model_name, sep, epoch = stem.rpartition('_e')
    if sep and epoch.isdigit():
        return model_name, int(epoch)
    return stem, None


# winrate_per_model[model][judge] is a list of scores; epochs_per_model[model] holds
# the matching epochs. After the sort below both are in ascending epoch order.
winrate_per_model = {}
epochs_per_model = {}
for key, scores in winrate.items():
    if key == 'llama7b_base':
        continue
    model_name, epoch = _split_run_name(key)
    epochs_per_model.setdefault(model_name, []).append(epoch)
    per_judge = winrate_per_model.setdefault(model_name, {})
    for judge in eval_list:
        # A judge is missing when its leaderboard was never computed; NaN leaves a
        # gap in the plot instead of raising KeyError.
        per_judge.setdefault(judge, []).append(scores.get(judge, np.nan))

# Insertion order follows directory listing / pickle-merge order, which is not
# guaranteed to be chronological, so reorder each model's series by epoch.
for model_name, epochs in epochs_per_model.items():
    order = sorted(range(len(epochs)), key=lambda i: (epochs[i] is None, epochs[i] or 0))
    epochs_per_model[model_name] = [epochs[i] for i in order]
    for judge, vals in winrate_per_model[model_name].items():
        winrate_per_model[model_name][judge] = [vals[i] for i in order]

# Plot one judge's win rate for every model over training epochs.
plot_judge = 'gpt4'
fig, ax = plt.subplots()
for model_name, per_judge in winrate_per_model.items():
    epochs = epochs_per_model[model_name]
    # Fall back to checkpoint index if some run name had no parseable epoch.
    xs = list(range(len(epochs))) if None in epochs else epochs
    ax.plot(xs, per_judge[plot_judge], marker='o', label=model_name)
ax.set(title=plot_judge, xlabel='epoch', ylabel='win rate')
if winrate_per_model:
    ax.legend()
plt.show()
