# Notebook to Generate the Figures in the Paper

In [None]:
import os
import pdb
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib as mpl

In [None]:
# Fine-tuning learning rate for each base (training) dataset. These strings are
# spliced verbatim into model ids such as
# 'olmo1b_hf_ckpt<step>_<dataset>_3epoch_<lr>_4shots' when looking up results.
LRS = {
    'xsum': '2e-7',
    'socialiqa': '2e-6',
    'mnli': '2e-6',
    'paws': '2e-6',
    'tulu': '2e-6',
}

# Global plotting style shared by every figure in this notebook.
sns.set_theme(font_scale=2.1, style='whitegrid')
# `sns.color_palette(...)` only *returns* a palette; `set_palette` is what installs
# it as the default colour cycle for subsequent plots.
sns.set_palette("colorblind")
mpl.rcParams['figure.dpi'] = 600
font = {
    'family': 'serif',
    # 'weight': 'bold',
    'size': 19,
}
mpl.rc('font', **font)
# Replaces the generic 'serif' family set just above with a specific font face.
plt.rcParams["font.family"] = "Nimbus Roman"
mpl.rc('xtick', labelsize=19)
mpl.rc('ytick', labelsize=19)

### Fig 2. Performance without Training
Group the datasets by whether the base model's performance on them improves over the course of pre-training, and plot the performance at each checkpoint.

In [None]:
# Display names for the evaluation sets of the instruction-tuning (tulu) experiments.
# Membership in this dict is also what the helpers below use to decide which
# model-id naming scheme a dataset's results are stored under.
INSTRUCTION_EVAL_PRETTY_NAMES = dict(
    boolq='BoolQ',
    openbookqa='OpenbookQA',
    arc_challenge='ARC Chal',
    arc_easy='ARC Easy',
    hellaswag='Hellaswag',
    sciq='SciQ',
)
# Display names for the task-specific (SFT) evaluation sets; format variants such as
# '*_instruct' share the label of their base eval set.
SFT_EVAL_PRETTY_NAMES = dict(
    mnli='MNLI',
    mnli_matched='MNLI_1',
    mnli_matched_instruct='MNLI_1',
    mnli_mismatched='MNLI_2',
    rte="RTE",
    gpt3nli="GPTNLI",
    socialiqa='SocialIQa',
    socialiqa_instruct='SocialIQa',
    tweetqa='TweetQA',
    sciq='SciQ',
    xsum_instruct='XSum',
    xsum='XSum',
    xlsum='XLSum',
    cnn='CNN',
    paws='Paws',
    paws_instruct='Paws',
    qqp='QQP',
    stsb='STS-B',
    llmbar_Natural='LLMBar Natural',
    llmbar_Adversarial_Manual='LLMBar AdvManual',
    llmbar_Adversarial_Neighbor='LLMBar Neighbor',
)

In [None]:
def ckpt_vs_perf(list_of_datasets, save_name, ylim=None, ratio=-1):
    """
    Plot the performance of every pre-training checkpoint (without any fine-tuning)
    on several datasets in one line plot, and save it as a PDF.

    Scores are looked up in ``all_perf_table.csv`` first; if that table has no
    unique row (or the dataset is ``sciq``), ``official_eval_table.csv`` is used.
    Scores missing from both tables are left as gaps in the plot.

    Args:
        list_of_datasets: eval-dataset keys; each must be a key of
            INSTRUCTION_EVAL_PRETTY_NAMES or SFT_EVAL_PRETTY_NAMES.
        save_name: file stem; the figure is written to
            ``$base_dir/results/analysis/<save_name>.pdf``.
        ylim: optional ``[low, high]`` y-axis range; ``[0.0, 1.0]`` when omitted.
        ratio: unused; kept for backward compatibility with existing callers.
    """
    analysis_dir = os.path.join(os.environ['base_dir'], 'results', 'analysis')
    it_perf = pd.read_csv(os.path.join(analysis_dir, 'official_eval_table.csv'))
    sft_perf = pd.read_csv(os.path.join(analysis_dir, 'all_perf_table.csv'))
    checkpoints = ['1000', '18000', '342000', '424000', '505000', '592000', '738000', 'main']
    perfs = []
    list_of_ds = []
    for dataset in list_of_datasets:
        is_instruction_eval = dataset in INSTRUCTION_EVAL_PRETTY_NAMES
        pretty_name = (INSTRUCTION_EVAL_PRETTY_NAMES[dataset] if is_instruction_eval
                       else SFT_EVAL_PRETTY_NAMES[dataset])
        for ckpt in checkpoints:
            list_of_ds.append(pretty_name)
            # Base (not fine-tuned) model id: instruction-tuning evals and 4-shot SFT
            # evals store the same checkpoint under different naming schemes.
            if ckpt != 'main':
                orig_model_id = ('checkpoint-' + ckpt if is_instruction_eval
                                 else 'olmo1b_checkpoint-' + ckpt + '_original_hf_4shots')
            else:
                orig_model_id = 'olmo1b_original_hf' if is_instruction_eval else 'olmo1b_original_hf_4shots'
            orig_perf = sft_perf.loc[(sft_perf['model_id'] == orig_model_id) & (sft_perf['eval dataset'] == dataset)]
            # sciq appears in both name dicts; its base score is always taken from the official table.
            if len(orig_perf) == 1 and dataset != 'sciq':
                perfs.append(orig_perf['Performance'].item())
                continue
            orig_perf = it_perf.loc[(it_perf['model_id'] == orig_model_id) & (it_perf['eval dataset'] == dataset)]
            perfs.append(orig_perf['Performance'].item() if len(orig_perf) == 1 else None)
    data_to_plot = pd.DataFrame({
        'Performance': perfs,
        'Dataset': list_of_ds,
        'ckpt_idx': list(range(len(checkpoints))) * len(list_of_datasets),
    })
    if ylim is None:
        ylim = [0.0, 1.0]
    # Start a fresh figure so that repeated calls within one cell do not draw on top of each other.
    plt.figure()
    dist_plot = sns.lineplot(data=data_to_plot, x="ckpt_idx", y="Performance", marker='o', style="Dataset", hue="Dataset", legend="auto", palette='husl', linewidth=2.5, markersize=9)
    dist_plot.set(xticks=list(range(len(checkpoints))), xlim=[-0.2, len(checkpoints) - 0.8], ylim=ylim, xlabel=None)
    sns.move_legend(dist_plot, "upper left", bbox_to_anchor=(1, 1))
    dist_plot.set_xticklabels(checkpoints, rotation=30)
    plt.savefig(os.path.join(analysis_dir, f"{save_name}.pdf"), bbox_inches='tight')

In [None]:
# Datasets on which the base model improves as pre-training progresses.
# (BoolQ does not improve, so it is plotted with the other group instead.)
improving_datasets = ['hellaswag', 'arc_challenge', 'arc_easy', 'sciq', 'openbookqa']
ckpt_vs_perf(improving_datasets, save_name='base_improving', ylim=[0.2, 1.0])

In [None]:
# Datasets on which the base model does not improve with more pre-training.
not_improving_datasets = ['mnli_matched', 'xsum', 'socialiqa', 'paws', 'boolq']
ckpt_vs_perf(not_improving_datasets, save_name='base_notimproving', ylim=[0.0, 0.8])

#### Instruction Following Ability
Plot the base checkpoints' performance on LLMBar (before any fine-tuning).

In [None]:
# Base-checkpoint performance on the three LLMBar subsets, using the default y-range.
llmbar_datasets = ['llmbar_Natural', 'llmbar_Adversarial_Neighbor', 'llmbar_Adversarial_Manual']
ckpt_vs_perf(llmbar_datasets, save_name='llmbar_untrained')

### Fig 4. IFT Performance per Task

In [None]:
def it_ckpt_vs_perf_plot(eval_dataset, tight=False):
    """
    Plot base vs. Tulu instruction-tuned performance on one eval set across checkpoints.

    Args:
        eval_dataset: value of the 'eval dataset' column in ``official_eval_table.csv``.
        tight: if True, draw the compact legend-free variant meant for the main
            content (aspect 4, except for ``sciq``); otherwise draw the appendix
            version with a legend.

    The figure is written to
    ``$base_dir/results/analysis/it_eval<eval_dataset>[_tight].pdf`` and the
    current figure is cleared afterwards.
    """
    analysis_dir = os.path.join(os.environ['base_dir'], 'results', 'analysis')
    perf_table = pd.read_csv(os.path.join(analysis_dir, 'official_eval_table.csv'))
    checkpoints = ['1000', '18000', '342000', '424000', '505000', '592000', '738000', 'main']
    n_ckpt = len(checkpoints)

    def lookup(model_id):
        # Score of model_id on eval_dataset, or None unless exactly one row matches.
        rows = perf_table.loc[(perf_table['model_id'] == model_id) & (perf_table['eval dataset'] == eval_dataset)]
        return rows['Performance'].item() if len(rows) == 1 else None

    ft_perfs, orig_perfs = [], []
    for ckpt in checkpoints:
        if ckpt == 'main':
            base_id, tuned_id = 'olmo1b_original_hf', 'olmo1b_hf_main_tulu_5epoch_2e-6'
        else:
            base_id, tuned_id = 'checkpoint-' + ckpt, 'olmo1b_hf_ckpt' + ckpt + '_tulu_5epoch_2e-6'
        orig_perfs.append(lookup(base_id))
        ft_perfs.append(lookup(tuned_id))

    data_to_plot = pd.DataFrame({
        'Performance': ft_perfs + orig_perfs,
        'Variant': ['Instruct'] * n_ckpt + ['BASE'] * n_ckpt,
        'ckpt_idx': list(range(n_ckpt)) * 2,
    })
    dist_plot = sns.lineplot(data=data_to_plot, x="ckpt_idx", y="Performance", marker='o', style="Variant", hue="Variant",
                             legend=None if tight else "auto", linewidth=2.5, markersize=9)

    axis_kwargs = dict(xticks=list(range(n_ckpt)), xlim=[-0.2, n_ckpt - 0.8],
                       ylim=[0.3, 1.0] if eval_dataset == 'sciq' else [0.2, 0.8])
    if tight and eval_dataset != 'sciq':
        axis_kwargs['aspect'] = 4
    axis_kwargs['xlabel'] = None
    dist_plot.set(**axis_kwargs)
    dist_plot.set_xticklabels(checkpoints, rotation=30)

    suffix = '_tight' if tight else ''
    plt.savefig(os.path.join(analysis_dir, f"it_eval{eval_dataset}{suffix}.pdf"), bbox_inches='tight')
    plt.clf()

In [None]:
# Appendix figures: base vs. instruction-tuned performance on each eval set.
for it_eval_name in ['hellaswag', 'boolq', 'arc_easy', 'arc_challenge', 'sciq', 'openbookqa']:
    it_ckpt_vs_perf_plot(eval_dataset=it_eval_name)

# Compact HellaSwag version for the main content.
it_ckpt_vs_perf_plot(eval_dataset='hellaswag', tight=True)

### Fig 3. SFT Performance per Task

In [None]:
def avg_change_table_by_checkpoint(dataset_pairs):
    """
    Summarize, per pre-training checkpoint, how much fine-tuning changes performance.

    dataset_pairs: A list of ``(eval_dataset, base_dataset)`` tuples; the model
        fine-tuned on ``base_dataset`` is evaluated on ``eval_dataset``. Eval sets
        that are keys of INSTRUCTION_EVAL_PRETTY_NAMES are compared against the
        Tulu (5-epoch) model regardless of ``base_dataset``; all others use the
        3-epoch, 4-shot model fine-tuned on ``base_dataset`` (learning rate from LRS).
    return:
        A DataFrame with one row per checkpoint and the columns
        ``Checkpoint``, ``Average Raw Change`` (mean of ft - orig),
        ``Avg Diff Ratio%`` (mean of (ft - orig) / orig, in percent),
        ``Total DS`` (number of pairs that had results) and
        ``Slope by 100000Step`` (change of ``Average Raw Change`` relative to the
        previous checkpoint, per 100k pre-training steps; 0.0 for the first row).
        Pairs with missing or duplicated results are skipped with a printed
        warning; a checkpoint with no usable pair gets NaN averages.
    """
    analysis_dir = os.path.join(os.environ['base_dir'], 'results', 'analysis')
    all_perf = pd.read_csv(os.path.join(analysis_dir, 'all_perf_table.csv'))
    # 'std' is not used here; drop it (if present) before appending the official eval table.
    all_perf = all_perf.drop(columns=['std'], errors='ignore')
    all_perf = pd.concat([all_perf, pd.read_csv(os.path.join(analysis_dir, 'official_eval_table.csv'))], ignore_index=True)
    # Step count used for the final 'main' checkpoint when computing slopes
    # (NOTE(review): assumed total pre-training steps — confirm).
    main_ckpt_step = 750000
    res = []
    checkpoints = ['1000', '18000', '342000', '424000', '505000', '592000', '738000', 'main']
    for ckpt_idx, ckpt in enumerate(checkpoints):
        tot_ds = 0
        raw_difference = 0
        diff_ratio = 0
        for eval_ds, base_ds in dataset_pairs:
            if eval_ds in INSTRUCTION_EVAL_PRETTY_NAMES:
                # Instruction-tuning evals: base checkpoint vs. its Tulu-tuned version.
                if ckpt != 'main':
                    orig_model_id = 'checkpoint-' + ckpt
                    ft_model_id = 'olmo1b_hf_ckpt' + ckpt + '_tulu_5epoch_2e-6'
                else:
                    orig_model_id = 'olmo1b_original_hf'
                    ft_model_id = 'olmo1b_hf_main_tulu_5epoch_2e-6'
            else:
                # SFT evals: 4-shot base checkpoint vs. the 3-epoch model fine-tuned on base_ds.
                if ckpt != 'main':
                    ft_model_id = 'olmo1b_hf_ckpt' + ckpt + '_' + base_ds + '_' + '3epoch_' + LRS[base_ds] + '_4shots'
                    orig_model_id = 'olmo1b_checkpoint-' + ckpt + '_original_hf_4shots'
                else:
                    ft_model_id = f'olmo1b_hf_main_{base_ds}_3epoch_{LRS[base_ds]}_4shots'
                    orig_model_id = 'olmo1b_original_hf_4shots'
            orig_perf = all_perf.loc[(all_perf['model_id'] == orig_model_id) & (all_perf['eval dataset'] == eval_ds)]
            ft_perf = all_perf.loc[(all_perf['model_id'] == ft_model_id) & (all_perf['eval dataset'] == eval_ds)]
            if len(orig_perf) == 1 and len(ft_perf) == 1:
                ft_score = ft_perf['Performance'].item()
                orig_score = orig_perf['Performance'].item()
                raw_difference += ft_score - orig_score
                diff_ratio += (ft_score - orig_score) / orig_score
                tot_ds += 1
            else:
                print("This combination is problematic: ", eval_ds, base_ds)
                print("At checkpoint", ckpt)
                print(orig_perf)
                print(ft_perf)
        if tot_ds:
            avg_raw = raw_difference / tot_ds
            avg_ratio = diff_ratio / tot_ds * 100
        else:
            # No pair had results at this checkpoint: record NaN instead of dividing by zero.
            avg_raw = avg_ratio = float('nan')
        if ckpt_idx != 0:
            step = main_ckpt_step if ckpt == 'main' else int(ckpt)
            slope = (avg_raw - res[-1]["Average Raw Change"]) / (step - int(checkpoints[ckpt_idx - 1]))
        else:
            slope = 0.0
        res.append({
            "Checkpoint": ckpt,
            "Average Raw Change": avg_raw,
            "Avg Diff Ratio%": avg_ratio,
            "Total DS": tot_ds,
            "Slope by 100000Step": slope * 100000
        })
    return pd.DataFrame(res)

In [None]:
# Per-checkpoint average fine-tuning change for tasks learned during fine-tuning.
tab = avg_change_table_by_checkpoint(dataset_pairs=[
    ('mnli_matched', 'mnli'), ('paws', 'paws'), ('xsum', 'xsum'), ('mnli_mismatched', 'mnli'),
    ('xlsum', 'xsum'), ('socialiqa', 'socialiqa'), ('boolq', 'tulu'),
])

tab.to_csv(os.path.join(os.environ['base_dir'], 'results', 'analysis', 'avg_gain_improv_group.csv'))

# Same table for the instruction-tuning evals whose base performance already improves
# during pre-training. Each task is listed once so SciQ is not double-weighted in the averages.
tab = avg_change_table_by_checkpoint(dataset_pairs=[
    ('sciq', 'tulu'), ('hellaswag', 'tulu'), ('arc_challenge', 'tulu'),
    ('arc_easy', 'tulu'), ('openbookqa', 'tulu'),
])
tab.to_csv(os.path.join(os.environ['base_dir'], 'results', 'analysis', 'avg_lose_unimprov_group.csv'))

In [None]:
def change_table_by_checkpoint_for_plot_sft(dataset_pairs):
    """
    Long-format table of per-pair fine-tuning changes at every checkpoint, for bar plots.

    dataset_pairs: A list of ``(eval_dataset, base_dataset)`` tuples; the model
        fine-tuned on ``base_dataset`` is evaluated on ``eval_dataset``. Eval sets
        that are keys of INSTRUCTION_EVAL_PRETTY_NAMES are compared against the
        Tulu (5-epoch) model regardless of ``base_dataset``; all others use the
        3-epoch, 4-shot model fine-tuned on ``base_dataset``.
    return:
        A DataFrame with one row per (checkpoint, pair) that has results, in
        checkpoint-then-pair order, with the columns ``Checkpoint`` and
        ``Raw Change`` (ft - orig performance). Pairs with missing or duplicated
        results are skipped with a printed warning.
    """
    analysis_dir = os.path.join(os.environ['base_dir'], 'results', 'analysis')
    all_perf = pd.read_csv(os.path.join(analysis_dir, 'all_perf_table.csv'))
    # 'std' is not used here; drop it (if present) before appending the official eval table.
    all_perf = all_perf.drop(columns=['std'], errors='ignore')
    all_perf = pd.concat([all_perf, pd.read_csv(os.path.join(analysis_dir, 'official_eval_table.csv'))], ignore_index=True)
    res = {
        "Checkpoint": [],
        "Raw Change": []
    }
    checkpoints = ['1000', '18000', '342000', '424000', '505000', '592000', '738000', 'main']
    for ckpt in checkpoints:
        for eval_ds, base_ds in dataset_pairs:
            if eval_ds in INSTRUCTION_EVAL_PRETTY_NAMES:
                # Instruction-tuning evals: base checkpoint vs. its Tulu-tuned version.
                if ckpt != 'main':
                    orig_model_id = 'checkpoint-' + ckpt
                    ft_model_id = 'olmo1b_hf_ckpt' + ckpt + '_tulu_5epoch_2e-6'
                else:
                    orig_model_id = 'olmo1b_original_hf'
                    ft_model_id = 'olmo1b_hf_main_tulu_5epoch_2e-6'
            else:
                # SFT evals: 4-shot base checkpoint vs. the 3-epoch model fine-tuned on base_ds.
                if ckpt != 'main':
                    ft_model_id = 'olmo1b_hf_ckpt' + ckpt + '_' + base_ds + '_' + '3epoch_' + LRS[base_ds] + '_4shots'
                    orig_model_id = 'olmo1b_checkpoint-' + ckpt + '_original_hf_4shots'
                else:
                    ft_model_id = f'olmo1b_hf_main_{base_ds}_3epoch_{LRS[base_ds]}_4shots'
                    orig_model_id = 'olmo1b_original_hf_4shots'
            orig_perf = all_perf.loc[(all_perf['model_id'] == orig_model_id) & (all_perf['eval dataset'] == eval_ds)]
            ft_perf = all_perf.loc[(all_perf['model_id'] == ft_model_id) & (all_perf['eval dataset'] == eval_ds)]
            if len(orig_perf) == 1 and len(ft_perf) == 1:
                # Unweighted change; the weighted variant would divide by the original score.
                res["Checkpoint"].append(ckpt)
                res["Raw Change"].append(ft_perf['Performance'].item() - orig_perf['Performance'].item())
            else:
                print("This combination is problematic: ", eval_ds, base_ds)
                print("At checkpoint", ckpt)
                print(orig_perf)
                print(ft_perf)
    return pd.DataFrame(res)

In [None]:
# Findings 2: per-checkpoint raw fine-tuning change (ft - orig) for tasks learned in
# fine-tuning vs. tasks already learned in pre-training.
checkpoints = ['1000', '18000', '342000', '424000', '505000', '592000', '738000', 'main']
tab_ft = change_table_by_checkpoint_for_plot_sft(dataset_pairs=[
    ('mnli_matched', 'mnli'), ('paws', 'paws'), ('xsum', 'xsum'), ('mnli_mismatched', 'mnli'),
    ('xlsum', 'xsum'), ('socialiqa', 'socialiqa'), ('boolq', 'tulu'),
])
tab_ft["Group"] = "Learned in FT"
# Each PT task is listed once so SciQ is not double-counted in the bars and confidence intervals.
tab_pt = change_table_by_checkpoint_for_plot_sft(dataset_pairs=[
    ('sciq', 'tulu'), ('hellaswag', 'tulu'), ('arc_challenge', 'tulu'),
    ('arc_easy', 'tulu'), ('openbookqa', 'tulu'),
])
tab_pt["Group"] = "Learned in PT"
# Concatenate the two tables
new_tab = pd.concat([tab_ft, tab_pt], ignore_index=True)
bar_plot = sns.barplot(x='Checkpoint', y='Raw Change', data=new_tab, hue='Group', errorbar=('ci', 90), palette="Set2", legend=None)
bar_plot.set(xticks=[i for i in range(len(checkpoints))], xlabel=None, ylabel="Performance Change", aspect=8)
bar_plot.set_xticklabels(checkpoints, rotation=30)
plt.savefig(os.path.join(os.environ['base_dir'], "results", "analysis", "ptft_comparison_bar.pdf"), bbox_inches='tight')
plt.clf()
# Hex codes of the Set2 palette used above.
print(sns.color_palette("Set2").as_hex())

In [None]:
def ckpt_vs_sft_perf_plot(eval_dataset, base_dataset, num_shots=4, tight=False, ylim=None):
    """
    Plot base vs. fine-tuned performance on one eval set across pre-training checkpoints.

    The fine-tuned model is the one trained on ``base_dataset`` (5 epochs for
    ``tulu``, 3 otherwise; learning rate from LRS), both variants evaluated with
    ``num_shots`` in-context examples. Shaded bands show +/- one ``std`` at the
    checkpoints where both variants have a score.

    Args:
        eval_dataset: value of the 'eval dataset' column in ``all_perf_table.csv``.
        base_dataset: fine-tuning dataset; must be a key of LRS.
        num_shots: number of shots encoded in the model ids.
        tight: draw the compact, legend-free variant (ylim [0.2, 0.8], aspect 4,
            x tick labels left as indices).
        ylim: if given (and ``tight`` is False), use this y-range with aspect 5 and
            no legend.

    Eval sets whose name contains 'instruct' or 'inputoutput' are saved under
    ``results/taskformat``; all others under ``results/analysis``. The current
    figure is cleared afterwards.
    """
    all_perf = pd.read_csv(os.path.join(os.environ['base_dir'], 'results', 'analysis', 'all_perf_table.csv'))
    checkpoints = ['1000', '18000', '342000', '424000', '505000', '592000', '738000', 'main']
    epoch = '5' if base_dataset == 'tulu' else '3'
    shots = f'_{str(num_shots)}shots'

    def lookup(model_id):
        # (Performance, std) of model_id on eval_dataset, or (None, None) unless exactly one
        # row matches. A missing 'std' column yields NaN so the std lists stay aligned
        # index-for-index with the score lists.
        rows = all_perf.loc[(all_perf['model_id'] == model_id) & (all_perf['eval dataset'] == eval_dataset)]
        if len(rows) != 1:
            return None, None
        std = rows['std'].item() if 'std' in rows else float('nan')
        return rows['Performance'].item(), std

    ft_perfs, ft_stds, orig_perfs, orig_stds = [], [], [], []
    for ckpt in checkpoints:
        if ckpt != 'main':
            model_id = 'olmo1b_hf_ckpt' + ckpt + '_' + base_dataset + '_' + epoch + 'epoch_' + LRS[base_dataset] + shots
            orig_model_id = 'olmo1b_checkpoint-' + ckpt + '_original_hf' + shots
        else:
            model_id = f'olmo1b_hf_main_{base_dataset}_{epoch}epoch_{LRS[base_dataset]}{shots}'
            orig_model_id = 'olmo1b_original_hf' + shots
        ft_perf, ft_std = lookup(model_id)
        orig_perf, orig_std = lookup(orig_model_id)
        ft_perfs.append(ft_perf)
        ft_stds.append(ft_std)
        orig_perfs.append(orig_perf)
        orig_stds.append(orig_std)

    # +/- 1 std bands, only at checkpoints where both variants have a score.
    fill_x = [i for i, (ft, orig) in enumerate(zip(ft_perfs, orig_perfs)) if ft is not None and orig is not None]
    low1 = [ft_perfs[i] - ft_stds[i] for i in fill_x]
    high1 = [ft_perfs[i] + ft_stds[i] for i in fill_x]
    low2 = [orig_perfs[i] - orig_stds[i] for i in fill_x]
    high2 = [orig_perfs[i] + orig_stds[i] for i in fill_x]

    n_ckpt = len(checkpoints)
    data_to_plot = pd.DataFrame({
        'Performance': ft_perfs + orig_perfs,
        'Variant': ['Fine-Tuned'] * n_ckpt + ['BASE'] * n_ckpt,
        'ckpt_idx': list(range(n_ckpt)) * 2,
    })
    # The compact (tight) and fixed-ylim variants are drawn without a legend.
    show_legend = not (tight or ylim is not None)
    dist_plot = sns.lineplot(data=data_to_plot, x="ckpt_idx", y="Performance", marker='o', style="Variant", hue="Variant",
                             legend="auto" if show_legend else None, linewidth=2.5, markersize=9)
    plt.fill_between(fill_x, low1, high1, alpha=0.4)
    plt.fill_between(fill_x, low2, high2, alpha=0.4)
    xticks = list(range(n_ckpt))
    xlim = [-0.2, n_ckpt - 0.8]
    if tight:
        # Tick labels are not replaced with checkpoint names in the compact variant.
        dist_plot.set(xticks=xticks, xlim=xlim, ylim=[0.2, 0.8], xlabel=None, aspect=4)
    elif ylim is not None:
        dist_plot.set(xticks=xticks, xlim=xlim, ylim=ylim, xlabel=None, aspect=5)
        dist_plot.set_xticklabels(checkpoints, rotation=30)
    else:
        dist_plot.set(xticks=xticks, xlim=xlim, ylim=[0.0, 1.0], xlabel="Pretraining Steps")
        dist_plot.set_xticklabels(checkpoints, rotation=30)

    file_stem = f"sft_eval{eval_dataset}-train{base_dataset}"
    if 'instruct' in eval_dataset or 'inputoutput' in eval_dataset:
        out_path = os.path.join(os.environ['base_dir'], "results", "taskformat", f"{file_stem}.pdf")
    elif tight:
        out_path = os.path.join(os.environ['base_dir'], "results", "analysis", f"{file_stem}_tight.pdf")
    elif ylim is not None:
        out_path = os.path.join(os.environ['base_dir'], "results", "analysis", f"{file_stem}_main_display.pdf")
    else:
        out_path = os.path.join(os.environ['base_dir'], "results", "analysis", f"{file_stem}.pdf")
    plt.savefig(out_path, bbox_inches='tight')
    plt.clf()

In [None]:
# In-task SFT plots: each eval set paired with the dataset its model was fine-tuned on.
in_task_pairs = [
    ('mnli_matched', 'mnli'), ('mnli_mismatched', 'mnli'), ('socialiqa', 'socialiqa'),
    ('xsum', 'xsum'), ('xlsum', 'xsum'), ('paws', 'paws'),
]
for eval_name, train_name in in_task_pairs:
    ckpt_vs_sft_perf_plot(eval_dataset=eval_name, base_dataset=train_name, num_shots=4)

# Compact MNLI-matched variant.
ckpt_vs_sft_perf_plot(eval_dataset='mnli_matched', base_dataset='mnli', num_shots=4, tight=True)

#### Instruction Following Ability


In [None]:
# LLMBar subsets, evaluated on the Tulu-tuned models.
for llmbar_split in ['llmbar_Natural', 'llmbar_Adversarial_Manual', 'llmbar_Adversarial_Neighbor']:
    ckpt_vs_sft_perf_plot(eval_dataset=llmbar_split, base_dataset='tulu', num_shots=4)

### Fig 6. Cross-task generalization

In [None]:
# TODO: What if we group them by generation v.s. classification? Same format as Fig 1.

# Cross-task transfer: fine-tune on the key dataset, evaluate on each listed dataset.
cross_task_evals = {
    'mnli': ['paws', 'socialiqa', 'xsum'],
    'socialiqa': ['mnli_matched', 'mnli_mismatched', 'paws', 'xsum'],
    'xsum': ['socialiqa', 'paws', 'mnli_matched', 'mnli_mismatched'],
    'paws': ['mnli_matched', 'mnli_mismatched', 'socialiqa', 'xsum'],
}
for train_name, eval_names in cross_task_evals.items():
    for eval_name in eval_names:
        ckpt_vs_sft_perf_plot(eval_dataset=eval_name, base_dataset=train_name, num_shots=4)

In [None]:
def change_table_by_checkpoint_for_plot(dataset_pairs):
    """
    Long-format table of per-pair fine-tuning changes at every checkpoint, for bar plots.

    dataset_pairs: A list of ``(eval_dataset, base_dataset)`` tuples; the model
        fine-tuned on ``base_dataset`` is evaluated on ``eval_dataset``. Eval sets
        that are keys of INSTRUCTION_EVAL_PRETTY_NAMES (except ``sciq``) are
        compared against the Tulu (5-epoch) model regardless of ``base_dataset``;
        all others use the 3-epoch, 4-shot model fine-tuned on ``base_dataset``.
    return:
        A DataFrame with one row per (checkpoint, pair) that has results, in
        checkpoint-then-pair order, with the columns ``Checkpoint``,
        ``Raw Change`` (ft - orig) and ``Change Ratio`` ((ft - orig) / orig).
        Pairs with missing or duplicated results are skipped with a printed warning.
    """
    analysis_dir = os.path.join(os.environ['base_dir'], 'results', 'analysis')
    all_perf = pd.read_csv(os.path.join(analysis_dir, 'all_perf_table.csv'))
    # 'std' is not used here; drop it (if present) before appending the official eval table.
    all_perf = all_perf.drop(columns=['std'], errors='ignore')
    all_perf = pd.concat([all_perf, pd.read_csv(os.path.join(analysis_dir, 'official_eval_table.csv'))], ignore_index=True)
    res = {
        "Checkpoint": [],
        "Raw Change": [],
        "Change Ratio": []
    }
    checkpoints = ['1000', '18000', '342000', '424000', '505000', '592000', '738000', 'main']
    for ckpt in checkpoints:
        for eval_ds, base_ds in dataset_pairs:
            # sciq is handled as an SFT eval here (e.g. as the cross-domain target of socialiqa).
            if eval_ds in INSTRUCTION_EVAL_PRETTY_NAMES and eval_ds != 'sciq':
                if ckpt != 'main':
                    orig_model_id = 'checkpoint-' + ckpt
                    ft_model_id = 'olmo1b_hf_ckpt' + ckpt + '_tulu_5epoch_2e-6'
                else:
                    orig_model_id = 'olmo1b_original_hf'
                    ft_model_id = 'olmo1b_hf_main_tulu_5epoch_2e-6'
            else:
                if ckpt != 'main':
                    ft_model_id = 'olmo1b_hf_ckpt' + ckpt + '_' + base_ds + '_' + '3epoch_' + LRS[base_ds] + '_4shots'
                    orig_model_id = 'olmo1b_checkpoint-' + ckpt + '_original_hf_4shots'
                else:
                    ft_model_id = f'olmo1b_hf_main_{base_ds}_3epoch_{LRS[base_ds]}_4shots'
                    orig_model_id = 'olmo1b_original_hf_4shots'
            orig_perf = all_perf.loc[(all_perf['model_id'] == orig_model_id) & (all_perf['eval dataset'] == eval_ds)]
            ft_perf = all_perf.loc[(all_perf['model_id'] == ft_model_id) & (all_perf['eval dataset'] == eval_ds)]
            if len(orig_perf) == 1 and len(ft_perf) == 1:
                ft_score = ft_perf['Performance'].item()
                orig_score = orig_perf['Performance'].item()
                change_ratio = (ft_score - orig_score) / orig_score
                res["Checkpoint"].append(ckpt)
                res["Raw Change"].append(ft_score - orig_score)
                res["Change Ratio"].append(change_ratio)
            else:
                print("This combination is problematic: ", eval_ds, base_ds)
                print("At checkpoint", ckpt)
                print(orig_perf)
                print(ft_perf)
    return pd.DataFrame(res)

In [None]:
# Deprecated: classification <-> generation transfer, per checkpoint.
checkpoints = ['1000', '18000', '342000', '424000', '505000', '592000', '738000', 'main']

class_to_gen_tab = change_table_by_checkpoint_for_plot(dataset_pairs=[('xsum', 'mnli'), ('socialiqa', 'mnli'),
                                                                     ('xsum', 'paws'), ('socialiqa', 'paws')])
class_to_gen_tab['Direction'] = "Class->Gen"

gen_to_class_tab = change_table_by_checkpoint_for_plot(dataset_pairs=[('paws', 'xsum'), ('mnli_matched', 'xsum'),
                                                                     ('mnli_matched', 'socialiqa'), ('paws', 'socialiqa')])
gen_to_class_tab['Direction'] = "Gen->Class"
# Concatenate the two tables
new_tab = pd.concat([class_to_gen_tab, gen_to_class_tab], ignore_index=True)
bar_plot = sns.barplot(x='Checkpoint', y='Raw Change', data=new_tab, hue='Direction', errorbar=('ci', 90))
# The bars show the unweighted 'Raw Change' (ft - orig), not 'Change Ratio', so the label says so.
bar_plot.set(xticks=[i for i in range(len(checkpoints))], xlabel=None, ylabel="Performance Change", ylim=[-1, 1])
bar_plot.set_xticklabels(checkpoints, rotation=30)
sns.move_legend(bar_plot, "upper left", bbox_to_anchor=(1, 1))
plt.savefig(os.path.join(os.environ['base_dir'], "results", "analysis", "weighted_task_transfer_bar.pdf"), bbox_inches='tight')
plt.clf()

In [None]:
# Mean and std of the per-(checkpoint, pair) changes, pooled across all checkpoints.
for summary_msg, direction_tab in [("Mean decrease percentage of class -> gen is ", class_to_gen_tab),
                                   ("Mean decrease percentage of gen -> class is ", gen_to_class_tab)]:
    print(summary_msg, direction_tab.mean(numeric_only=True))
    print("Std", direction_tab.std(numeric_only=True))

### Fig 7. Cross-domain generalization

In [None]:
checkpoints = ['1000', '18000', '342000', '424000', '505000', '592000', '738000', 'main']

# Per-checkpoint fine-tuning change on out-of-domain eval sets, each paired with its training task.
tab = change_table_by_checkpoint_for_plot(dataset_pairs=[('rte', 'mnli'), ('gpt3nli', 'mnli'),
                                                         ('cnn', 'xsum'), ('qqp', 'paws'), ('stsb', 'paws'),
                                                         ('tweetqa', 'socialiqa'), ('sciq', 'socialiqa')])
# `ci=` is deprecated in seaborn >= 0.12; `errorbar=('ci', 90)` is the equivalent spelling.
bar_plot = sns.barplot(x='Checkpoint', y='Raw Change', data=tab, hue='Checkpoint', errorbar=('ci', 90))
# The bars show the unweighted 'Raw Change' (ft - orig), so the label says so.
bar_plot.set(xticks=[i for i in range(len(checkpoints))], xlabel=None, ylabel="Performance Change")
bar_plot.set_xticklabels(checkpoints, rotation=30)
print(tab)
plt.savefig(os.path.join(os.environ['base_dir'], "results", "analysis", "weighted_perf_change_ood.pdf"), bbox_inches='tight')
plt.clf()

In [None]:
# Cross-domain generalization, pooled across checkpoints and grouped by task type.
nli_tab = change_table_by_checkpoint_for_plot(dataset_pairs=[('rte', 'mnli'), ('gpt3nli', 'mnli')]).assign(Task="NLI")
summary_tab = change_table_by_checkpoint_for_plot(dataset_pairs=[('cnn', 'xsum')]).assign(Task="Sum")
q_gen_tab = change_table_by_checkpoint_for_plot(
    dataset_pairs=[('tweetqa', 'socialiqa'), ('sciq', 'socialiqa')]).assign(Task="QGen")
paraphrase_tab = change_table_by_checkpoint_for_plot(dataset_pairs=[('qqp', 'paws'), ('stsb', 'paws')]).assign(Task="Para")

# Bar order follows this concatenation order; the tick labels below must match it.
new_tab = pd.concat([q_gen_tab, summary_tab, nli_tab, paraphrase_tab], ignore_index=True)
bar_plot = sns.barplot(x='Task', y='Raw Change', data=new_tab, hue='Task', errorbar=('ci', 90), width=0.6,
                       err_kws={'linewidth': 6.0})
bar_plot.set(xticks=list(range(4)), xlabel=None, ylabel=None, aspect=4)
bar_plot.set_xticklabels(["Question \nGeneration", "Summary \nGeneration", "NLI", "Paraphrase \nDetection"])
plt.savefig(os.path.join(os.environ['base_dir'], "results", "analysis", "weighted_ood_transfer_bar.pdf"),
            bbox_inches='tight')
plt.clf()

In [None]:
# Per-dataset plots for the cross-domain pairs: (out-of-domain eval set, training task).
ood_pairs = [
    ('rte', 'mnli'), ('gpt3nli', 'mnli'), ('tweetqa', 'socialiqa'), ('sciq', 'socialiqa'),
    ('cnn', 'xsum'), ('qqp', 'paws'), ('stsb', 'paws'),
]
for ood_eval, train_task in ood_pairs:
    ckpt_vs_sft_perf_plot(eval_dataset=ood_eval, base_dataset=train_task, num_shots=4)

In [None]:
# Example plot for the main content (fixed y-range, no legend).
# Alternative candidate: eval_dataset='gpt3nli', base_dataset='mnli', ylim=[0.2, 1.0].
ckpt_vs_sft_perf_plot('qqp', 'paws', num_shots=4, ylim=[0.4, 1.2])

### Fig 5. Performance By Task Format

In [None]:
# Sanity-check plots for the alternative prompt formats ('instruct' and 'inputoutput').
format_pairs = [('mnli_matched', 'mnli'), ('mnli_mismatched', 'mnli'), ('socialiqa', 'socialiqa'),
                ('xsum', 'xsum'), ('paws', 'paws')]
for task_format in ['instruct', 'inputoutput']:
    for eval_name, train_name in format_pairs:
        ckpt_vs_sft_perf_plot(eval_dataset=f'{eval_name}_{task_format}', base_dataset=train_name, num_shots=4)

In [None]:
# Code to show performance in different task formats (default / instruct / input-output).
def _task_format_model_ids(ckpt, base_dataset, epoch, num_shots):
    """Return ``(fine_tuned_model_id, base_model_id)`` for one pretraining checkpoint.

    The ids follow the naming used in the ``model_id`` column of ``all_perf_table.csv``;
    the final checkpoint is called ``'main'`` and uses a slightly different pattern.
    """
    lr = LRS[base_dataset]  # per-dataset fine-tuning learning rate, part of the id
    if ckpt == 'main':
        ft_id = f'olmo1b_hf_main_{base_dataset}_{epoch}epoch_{lr}_{num_shots}shots'
        base_id = f'olmo1b_original_hf_{num_shots}shots'
    else:
        ft_id = f'olmo1b_hf_ckpt{ckpt}_{base_dataset}_{epoch}epoch_{lr}_{num_shots}shots'
        base_id = f'olmo1b_checkpoint-{ckpt}_original_hf_{num_shots}shots'
    return ft_id, base_id


def _task_format_lookup_perf(all_perf, model_id, eval_name):
    """Return the 'Performance' value for ``(model_id, eval_name)``, or None if there is no row.

    A missing result is reported with a warning and returned as None, so the figure can still be
    drawn with that point missing.

    Raises:
        ValueError: if more than one row matches (the lookup would be ambiguous).
    """
    import warnings

    rows = all_perf.loc[
        (all_perf['model_id'] == model_id) & (all_perf['eval dataset'] == eval_name),
        'Performance',
    ]
    if rows.empty:
        warnings.warn(f'No result for model_id={model_id!r} on eval dataset={eval_name!r}; leaving it missing.')
        return None
    if len(rows) > 1:
        raise ValueError(
            f'{len(rows)} rows for model_id={model_id!r} on eval dataset={eval_name!r}; expected exactly one.'
        )
    return rows.item()


def task_ckpt_vs_sft_plot(eval_dataset, base_dataset, ylim=(0.0, 1.0), num_shots=4, legend=False,
                          checkpoints=None, epoch='3'):
    """Plot fine-tuned vs. base performance over pretraining checkpoints, per task format.

    For each checkpoint, reads from ``results/analysis/all_perf_table.csv`` the performance of
    (a) the model fine-tuned on ``base_dataset`` and (b) the untouched base checkpoint, on
    ``eval_dataset`` in three formats: the default name, ``eval_dataset + '_instruct'`` and
    ``eval_dataset + '_inputoutput'``. It draws one line per (format, variant) and saves the
    figure to ``results/taskformat/task_format_eval{eval_dataset}-train{base_dataset}.pdf``
    under ``$base_dir``. It then clears the current figure.

    Args:
        eval_dataset: eval dataset name in its default format.
        base_dataset: dataset the models were fine-tuned on; must be a key of ``LRS``.
        ylim: y-axis limits (any 2-sequence).
        num_shots: number of in-context shots, used to build the model ids.
        legend: if True, draw the legend outside the axes, anchored at the upper right.
        checkpoints: checkpoint names in x-axis order. Defaults to the list this function
            originally hard-coded, ending in ``'main'``.
        epoch: number of fine-tuning epochs, used to build the fine-tuned model ids.

    Missing (model, eval dataset) results are stored as NaN and reported with a warning.
    Duplicate results raise ``ValueError``.
    """
    if checkpoints is None:
        checkpoints = ['1000', '18000', '342000', '424000', '505000', '592000', '738000', 'main']
    if not checkpoints:
        raise ValueError('checkpoints must be non-empty')

    all_perf = pd.read_csv(os.path.join(os.environ['base_dir'], 'results', 'analysis', 'all_perf_table.csv'))
    model_ids = [_task_format_model_ids(ckpt, base_dataset, epoch, num_shots) for ckpt in checkpoints]

    # Rows are emitted format-major, then variant, then checkpoint. Seaborn therefore assigns hue
    # levels Default/Instruct/IO and style levels Fine-Tuned/BASE in that order.
    records = []
    for format_type, suffix in (('Default', ''), ('Instruct', '_instruct'), ('IO', '_inputoutput')):
        eval_name = eval_dataset + suffix
        for variant, id_pos in (('Fine-Tuned', 0), ('BASE', 1)):
            for ckpt_idx, ids in enumerate(model_ids):
                records.append({
                    'Performance': _task_format_lookup_perf(all_perf, ids[id_pos], eval_name),
                    'Variant': variant,
                    'ckpt_idx': ckpt_idx,
                    'Format Type': format_type,
                })
    # Force a float column so that missing results (None) become NaN even if every value is missing.
    data_to_plot = pd.DataFrame(records).astype({'Performance': 'float64'})

    dist_plot = sns.lineplot(data=data_to_plot, x="ckpt_idx", y="Performance", markers=['o', 'o'],
                             style="Variant", hue="Format Type", legend="auto" if legend else False,
                             linewidth=2.4, markersize=10, palette='colorblind')
    if legend:
        sns.move_legend(dist_plot, "upper left", bbox_to_anchor=(1, 1))
    # One tick per checkpoint, labelled by name; xlabel=None clears the 'ckpt_idx' axis label.
    dist_plot.set(xticks=list(range(len(checkpoints))), xlim=[-0.2, len(checkpoints) - 0.8], ylim=ylim, xlabel=None)
    dist_plot.set_xticklabels(checkpoints, rotation=30)

    out_dir = os.path.join(os.environ['base_dir'], "results", "taskformat")
    os.makedirs(out_dir, exist_ok=True)  # savefig does not create missing directories
    plt.savefig(os.path.join(out_dir, f"task_format_eval{eval_dataset}-train{base_dataset}.pdf"), bbox_inches='tight')
    plt.clf()

In [None]:
# One task-format figure per fine-tuning dataset: (eval dataset, fine-tuning dataset, y-limits, legend).
# Only the last figure carries the legend.
for eval_name, train_name, y_limits, show_legend in [
    ('mnli_matched', 'mnli', [0.25, 0.85], False),
    ('paws', 'paws', [0.4, 1.0], False),
    ('xsum', 'xsum', [0.0, 0.2], False),
    ('socialiqa', 'socialiqa', [0.0, 0.8], True),
]:
    task_ckpt_vs_sft_plot(eval_dataset=eval_name, base_dataset=train_name, ylim=y_limits,
                          num_shots=4, legend=show_legend)