In [2]:
import altair as alt
from altair import datum
import functools
import pandas as pd
import sys
sys.path.append('/private/home/victorialin/Projects/fairseq-py')
from scripts.visualization.multilingual_few_shot_eval_utils import *


result_tables = {
    # "en-6.7B-dense-en": "/checkpoint/victorialin/few_shot/6.7B_gpt3_setting_en_tasks/results.tsv",
    "multi-7.5B-dense-en": "/checkpoint/victorialin/few_shot/dense_7.5B_lang30_new_cc100_xl_unigram_en_tasks/results.tsv",
    "multi-7.5B-dense-multi": "/checkpoint/victorialin/few_shot/dense_7.5B_lang30_new_cc100_xl_unigram_mutli_tasks_v1/results.tsv",
    # "multi-200B-moe-en": "/checkpoint/victorialin/few_shot/moe_200B_lang30_new_cc100_xl_unigram_en_tasks//results.tsv",
    # "multi-200B-moe-multi": "/checkpoint/victorialin/few_shot/moe_200B_lang30_new_cc100_xl_unigram_mutli_tasks_v1/results.tsv"
}

_all_eval_tasks = [
    'arcchallenge', 
    'arceasy', 
    'copa', 
    'hellaswag', 
    'openbookqa', 
    'piqa', 
    'winogrande', 
    'storycloze', 
    'xnli', 
    'xcopa', 
    'xwinograd', 
    'pawsx'
]

_all_en_eval_tasks = [
    'arcchallenge', 
    'arceasy', 
    'copa', 
    'hellaswag', 
    'openbookqa', 
    'piqa', 
    'winogrande', 
    'storycloze'
]

_all_multi_eval_tasks = [
    'storycloze', 
    'xnli', 
    'xcopa', 
    'xwinograd', 
    'pawsx'
]

# filtering conditions
def all_eval_tasks(df):
    return functools.reduce(lambda x, y: x | y, [df['task'] == t for t in _all_eval_tasks])

def all_en_eval_tasks(df):
    return functools.reduce(lambda x, y: x | y, [df['task'] == t for t in _all_en_eval_tasks])

def all_multi_eval_tasks(df):
    return functools.reduce(lambda x, y: x | y, [df['task'] == t for t in _all_multi_eval_tasks])

def valid_settings(df):
    return (df['template'] != 'xcopa_simple')

def eval_settings(option, df):
    if option == 'best':
        return ((df['task'] != 'openbookqa') | (df['run_params::scoring'] == 'unconditional-norm')) \
            & ((df['task'] != 'winogrande') | (df['run_params::scoring'] == 'suffix'))
    elif option == 'default':
        return ((df['task'] != 'openbookqa') | (df['run_params::scoring'] != 'unconditional-norm')) \
            & ((df['task'] != 'winogrande') | (df['run_params::scoring'] != 'suffix'))

def template_selection(df):
    return (df['task'] != 'xcopa') | (df['template'] == 'xcopa__en')

def all_checkpoints(df):
    return (((df.model == 'dense_7.5B_lang30_new_cc100_xl_unigram') & 
                ((df.step == 30000))
                | (df.step == 60000)
                | (df.step == 120000)
                | (df.step == 238000))
            | (df.model == '6.7B_gpt3_setting') & 
                ((df.step == 10000)
                | (df.step == 30000)
                | (df.step == 70000)
                | (df.step == 143050)))

def multilingual_checkpoints(df):
    return (df.model == 'dense_7.5B_lang30_new_cc100_xl_unigram')

def last_checkpoint(df):
    return ((df.model == '6.7B_gpt3_setting') & (df.step == 143050)) \
        | ((df.model == 'dense_7.5B_lang30_new_cc100_xl_unigram') & (df.step == 238000))
        # | ((df.model == 'moe_200B_lang30_new_cc100_xl_unigram') & (df.step == 118000)) \

def num_few_shot_samples(df):
    return ((df['nb_few_shot_samples'] == 0)
        | (df['nb_few_shot_samples'] == 1)
        | (df['nb_few_shot_samples'] == 4)
        | (df['nb_few_shot_samples'] == 32)
        | (df['nb_few_shot_samples'] == 128))

def final_eval_splits(df):
    return (((df['task'] == 'arcchallenge') & (df['eval_set'] == 'dev') & (df['train_set'] == 'train'))
        | ((df['task'] == 'arceasy') & (df['eval_set'] == 'dev') & (df['train_set'] == 'train'))
        | ((df['task'] == 'copa') & (df['eval_set'] == 'val') & (df['train_set'] == 'train'))
        | ((df['task'] == 'hellaswag') & (df['eval_set'] == 'val') & (df['train_set'] == 'train'))
        | ((df['task'] == 'openbookqa') & (df['eval_set'] == 'test') & (df['train_set'] == 'train'))
        | ((df['task'] == 'piqa') & (df['eval_set'] == 'valid') & (df['train_set'] == 'train'))
        | ((df['task'] == 'winogrande') & (df['eval_set'] == 'dev') & (df['train_set'] == 'train_xl'))
        | ((df['task'] == 'storycloze') & (df['eval_set'] == 'test2016') & (df['train_set'] == 'val2016'))
        | ((df['task'] == 'storycloze') & (df['eval_set'] == 'val2016') & (df['train_set'] == 'val2016'))
        | ((df['task'] == 'pawsx') & (df['eval_set'] == 'test') & (df['train_set'] == 'dev'))
        | ((df['task'] == 'xcopa') & (df['eval_set'] == 'test') & (df['train_set'] == 'val'))
        | ((df['task'] == 'xnli') & (df['eval_set'] == 'test') & (df['train_set'] == 'dev'))
        | ((df['task'] == 'xwinograd') & ((df['language'] != 'fr') & (df['language'] != 'zh')) & (df['eval_set'] == 'test') & (df['train_set'] == 'test')))

def en_final_eval_splits(df):
    return (((df['task'] == 'arcchallenge') & (df['eval_set'] == 'dev') & (df['train_set'] == 'train'))
        | ((df['task'] == 'arceasy') & (df['eval_set'] == 'dev') & (df['train_set'] == 'train'))
        | ((df['task'] == 'copa') & (df['eval_set'] == 'val') & (df['train_set'] == 'train'))
        | ((df['task'] == 'hellaswag') & (df['eval_set'] == 'val') & (df['train_set'] == 'train'))
        | ((df['task'] == 'openbookqa') & (df['eval_set'] == 'test') & (df['train_set'] == 'train'))
        | ((df['task'] == 'piqa') & (df['eval_set'] == 'valid') & (df['train_set'] == 'train'))
        | ((df['task'] == 'winogrande') & (df['eval_set'] == 'dev') & (df['train_set'] == 'train_xl'))
        | ((df['task'] == 'storycloze') & (df['eval_set'] == 'test2016') & (df['train_set'] == 'val2016'))
        | ((df['task'] == 'pawsx') & (df['eval_set'] == 'test') & (df['train_set'] == 'dev'))
        | ((df['task'] == 'xcopa') & (df['eval_set'] == 'test') & (df['train_set'] == 'val'))
        | ((df['task'] == 'xnli') & (df['eval_set'] == 'test') & (df['train_set'] == 'dev'))
        | ((df['task'] == 'xwinograd') & ((df['language'] != 'fr') & (df['language'] != 'zh')) & (df['eval_set'] == 'test') & (df['train_set'] == 'test')))

def multi_final_eval_splits(df):
    return (((df['task'] == 'arcchallenge') & (df['eval_set'] == 'dev') & (df['train_set'] == 'train'))
        | ((df['task'] == 'arceasy') & (df['eval_set'] == 'dev') & (df['train_set'] == 'train'))
        | ((df['task'] == 'copa') & (df['eval_set'] == 'val') & (df['train_set'] == 'train'))
        | ((df['task'] == 'hellaswag') & (df['eval_set'] == 'val') & (df['train_set'] == 'train'))
        | ((df['task'] == 'openbookqa') & (df['eval_set'] == 'test') & (df['train_set'] == 'train'))
        | ((df['task'] == 'piqa') & (df['eval_set'] == 'valid') & (df['train_set'] == 'train'))
        | ((df['task'] == 'winogrande') & (df['eval_set'] == 'dev') & (df['train_set'] == 'train_xl'))
        | ((df['task'] == 'storycloze') & (df['eval_set'] == 'val2016_split_20_80_eval') & (df['train_set'] == 'val2016_split_20_80_train'))
        | ((df['task'] == 'pawsx') & (df['eval_set'] == 'test') & (df['train_set'] == 'dev'))
        | ((df['task'] == 'xcopa') & (df['language'] != 'ru') & (df['eval_set'] == 'test') & (df['train_set'] == 'val'))
        | ((df['task'] == 'xnli') & (df['eval_set'] == 'test') & (df['train_set'] == 'dev'))
        | ((df['task'] == 'xwinograd') & ((df['language'] != 'fr') & (df['language'] != 'zh')) & (df['eval_set'] == 'test') & (df['train_set'] == 'test')))

def multi_dev_eval_splits(df):
    return (((df['task'] == 'arcchallenge') & (df['eval_set'] == 'dev') & (df['train_set'] == 'train'))
        | ((df['task'] == 'arceasy') & (df['eval_set'] == 'dev') & (df['train_set'] == 'train'))
        | ((df['task'] == 'copa') & (df['eval_set'] == 'val') & (df['train_set'] == 'train'))
        | ((df['task'] == 'hellaswag') & (df['eval_set'] == 'val') & (df['train_set'] == 'train'))
        | ((df['task'] == 'openbookqa') & (df['eval_set'] == 'test') & (df['train_set'] == 'train'))
        | ((df['task'] == 'piqa') & (df['eval_set'] == 'valid') & (df['train_set'] == 'train'))
        | ((df['task'] == 'winogrande') & (df['eval_set'] == 'dev') & (df['train_set'] == 'train_xl'))
        | ((df['task'] == 'storycloze') & (df['eval_set'] == 'val2016_split_20_80_eval') & (df['train_set'] == 'val2016_split_20_80_train'))
        | ((df['task'] == 'pawsx') & (df['eval_set'] == 'test') & (df['train_set'] == 'dev'))
        | ((df['task'] == 'xcopa') & (df['eval_set'] == 'val') & (df['train_set'] == 'val'))
        | ((df['task'] == 'xnli') & (df['eval_set'] == 'dev') & (df['train_set'] == 'dev'))
        | ((df['task'] == 'xwinograd') & ((df['language'] != 'fr') & (df['language'] != 'zh')) & (df['eval_set'] == 'test') & (df['train_set'] == 'test')))

def en_only(df):
    return df['language'] == 'en'

dfs = {}
for key in result_tables:
    df = pd.read_csv(result_tables[key], sep='\t', index_col=False).iloc[:, 1:]
    df = df.drop_duplicates()
    df = df[valid_settings(df)]

    df['model'] = df.model_name.apply(lambda x: x.split('__step')[0])
    df['step'] = df.model_name.apply(lambda x: int(x.split('__step')[1]))
    df['meta_task'] = df.task.apply(lambda x:x.split('__', 1)[0])
    df['model_id'] = df.model_name.apply(lambda x:x.split('__step', 1)[0])
    
    if key.startswith("en-6.7B"):
        df['num_tokens_B'] = df.step.apply(lambda x:x * 2 * 1024 * 1024 / 1e9)
        df['num_EN_tokens_B'] = df.step.apply(lambda x:x * 2 * 1024 * 1024 / 1e9)
        df['num_gpu_days'] = df.num_tokens_B.apply(lambda x:x / 300 * 6356)
    if key.startswith("multi-7.5B"):
        df['num_tokens_B'] = df.step.apply(lambda x:x * 2 * 1024 * 1024 / 1e9)
        df['num_EN_tokens_B'] = df.step.apply(lambda x:x * 2 * 1024 * 1024 / 1e9 * 0.326)
        df['num_gpu_days'] = df.num_tokens_B.apply(lambda x:x / 500 * 11007)
    if key.startswith("multi-200B"):
        df['step'] = df.step.apply(lambda x: x + 22000)
        df['num_tokens_B'] = df.step.apply(lambda x:x * 4 * 1024 * 1024 / 1e9)
        df['num_EN_tokens_B'] = df.step.apply(lambda x:x * 4 * 1024 * 1024 / 1e9 * 0.326)
        df['num_gpu_days'] = df.num_tokens_B.apply(lambda x:x / 500 * 5711)

    df = df[eval_settings('best', df)]
    dfs[key] = df

en_dfs = []
for key in dfs:
    if key.endswith("-en"):
        en_dfs.append(dfs[key])
en_result_df = pd.concat(en_dfs)
en_result_df['resource_level'] = en_result_df.language.apply(lambda x:get_resource_level(x))

multi_dfs = []
for key in dfs:
    if key.endswith("-multi"):
        multi_dfs.append(dfs[key])
multi_result_df = pd.concat(multi_dfs)
multi_result_df['resource_level'] = multi_result_df.language.apply(lambda x:get_resource_level(x))

In [3]:
# df_mapping = pd.DataFrame({'task': all_eval_tasks})
# sort_mapping = df_mapping.reset_index().set_index('task')
result_df = pd.concat([en_result_df, multi_result_df]).drop_duplicates(
    subset=[
        'task',
        'eval_set',
        'language',
        'train_set',
        'train_lang',
        'template',
        'nb_few_shot_samples',
        'calibration',
        'run_params::scoring',
        'model_name',
        "accuracy::mean",
        "accuracy::std",
        'model',
        'step',
        'meta_task',
        'model_id',
        'num_tokens_B',
        'num_EN_tokens_B',
        'num_gpu_days'
    ],
    ignore_index=True
)
filtered_result_df = result_df[last_checkpoint(result_df) & final_eval_splits(result_df) & num_few_shot_samples(result_df) & template_selection(result_df)] # & all_multi_eval_tasks(result_df) & multilingual_checkpoints(result_df)]
grouped_filtered_result_df = filtered_result_df.groupby(['task', 'eval_set', 'language', 'train_set', 'train_lang', 'template', 'calibration', 'model_name'])
grouped_filtered_result_df = grouped_filtered_result_df.apply(lambda a: a.sort_values('nb_few_shot_samples'))
grouped_filtered_result_df.to_csv('/checkpoint/victorialin/few_shot/dense_7.5B_lang30_new_cc100_xl_unigram_en_tasks/few_shot.tsv', sep='\t')

In [4]:
import numpy as np
num_trials = 5
eval_metrics = 'accuracy'

def to_altair_data(df):
    reformatted_data = []
    for _, row in df.iterrows():
        for i in range(num_trials):
            data_point = dict()
            for key in [
                        'task',
                        'language',
                        'nb_few_shot_samples',
                        'model',
                        'step',
                        'resource_level'
                    ]:
                if key == 'resource_level' and key not in row:
                    data_point[key] = 'high'
                else:
                    data_point[key] = row[key]
            data_point[eval_metrics] = row[f'{eval_metrics}_{i}']
            data_point[f'{eval_metrics}::mean'] = row[f'{eval_metrics}::mean']
            reformatted_data.append(data_point)
    plot_data = alt.Data(values=reformatted_data)
    return plot_data

resource_levels = [
    'high',
    'medium',
    'low',
    'extremely-low'
]

In [5]:
# df_mapping = pd.DataFrame({'task': all_eval_tasks})
# sort_mapping = df_mapping.reset_index().set_index('task')
result_df = pd.concat([en_result_df, multi_result_df]).drop_duplicates(
    subset=[
        'task',
        'eval_set',
        'language',
        'train_set',
        'train_lang',
        'template',
        'nb_few_shot_samples',
        'calibration',
        'run_params::scoring',
        'model_name',
        "accuracy::mean",
        "accuracy::std",
        'model',
        'step',
        'meta_task',
        'model_id',
        'num_tokens_B',
        'num_EN_tokens_B',
        'num_gpu_days'
    ],
    ignore_index=True
)
filtered_result_df = result_df[last_checkpoint(result_df) & num_few_shot_samples(result_df) & template_selection(result_df) & multilingual_checkpoints(result_df)
    & (en_final_eval_splits(result_df) | multi_final_eval_splits(result_df))
    & (all_en_eval_tasks(result_df) | all_multi_eval_tasks(result_df))] 
grouped_filtered_result_df = filtered_result_df.groupby(['task', 'eval_set', 'language', 'train_set', 'train_lang', 'template', 'calibration', 'model_name'])
grouped_filtered_result_df = grouped_filtered_result_df.apply(lambda a: a.sort_values('nb_few_shot_samples'))
grouped_filtered_result_df.to_csv('/checkpoint/victorialin/few_shot/dense_7.5B_lang30_new_cc100_xl_unigram_en_tasks/few_shot.tsv', sep='\t')

multi_filtered_result_df = filtered_result_df[multi_final_eval_splits(filtered_result_df) & all_multi_eval_tasks(filtered_result_df)]
multi_grouped_filtered_result_df = multi_filtered_result_df.groupby(['task', 'eval_set', 'language', 'train_set', 'train_lang', 'template', 'calibration', 'model_name'])
multi_grouped_filtered_result_df = multi_grouped_filtered_result_df.apply(lambda a: a.sort_values('nb_few_shot_samples'))
multi_grouped_filtered_result_df.to_csv('/checkpoint/victorialin/few_shot/dense_7.5B_lang30_new_cc100_xl_unigram_en_tasks/multi_few_shot.tsv', sep='\t')

en_filtered_result_df = filtered_result_df[en_final_eval_splits(filtered_result_df) & all_en_eval_tasks(filtered_result_df) & en_only(filtered_result_df)]
en_grouped_filtered_result_df = en_filtered_result_df.groupby(['task', 'eval_set', 'language', 'train_set', 'train_lang', 'template', 'calibration', 'model_name'])
en_grouped_filtered_result_df = en_grouped_filtered_result_df.apply(lambda a: a.sort_values('nb_few_shot_samples'))
en_grouped_filtered_result_df.to_csv('/checkpoint/victorialin/few_shot/dense_7.5B_lang30_new_cc100_xl_unigram_en_tasks/en_few_shot.tsv', sep='\t')

In [6]:
line = alt.Chart().mark_line(point=True).encode(
    x=alt.X("nb_few_shot_samples:Q", axis=alt.Axis(values=[0, 1, 4, 32, 128]), scale=alt.Scale(type='symlog')),
    y=alt.Y("accuracy:Q", aggregate='mean', scale=alt.Scale(zero=False)),
    color=alt.Color("resource_level:N", scale=alt.Scale(domain=resource_levels, range=['#5778a4', '#85b6b2', '#e49444', '#d1615d']))
)
band = alt.Chart().mark_errorband(extent='ci').encode(
    x=alt.X("nb_few_shot_samples:Q", axis=alt.Axis(values=[0, 1, 4, 32, 128]), scale=alt.Scale(type='symlog')),
    y=alt.Y("accuracy:Q", scale=alt.Scale(zero=False)),
    color=alt.Color("resource_level:N", scale=alt.Scale(domain=resource_levels, range=['#5778a4', '#85b6b2', '#e49444', '#d1615d']))
)

en_plot_data = to_altair_data(en_filtered_result_df)
alt.layer(line, band, data=en_plot_data).facet(
    'task:N', columns=4
).resolve_scale(
    x='independent',
    y='independent'
)

In [7]:
line = alt.Chart().mark_line(point=True).encode(
    x=alt.X("nb_few_shot_samples:Q", axis=alt.Axis(values=[0, 1, 4, 32, 128]), scale=alt.Scale(type='symlog')),
    y=alt.Y("accuracy::mean:Q", aggregate='mean', scale=alt.Scale(zero=False)),
    color=alt.Color("resource_level:N", scale=alt.Scale(domain=resource_levels, range=['#5778a4', '#85b6b2', '#e49444', '#d1615d']))
)
band = alt.Chart().mark_errorband(extent='ci').encode(
    x=alt.X("nb_few_shot_samples:Q", axis=alt.Axis(values=[0, 1, 4, 32, 128]), scale=alt.Scale(type='symlog')),
    y=alt.Y("accuracy::mean:Q", scale=alt.Scale(zero=False)),
    color=alt.Color("resource_level:N", scale=alt.Scale(domain=resource_levels, range=['#5778a4', '#85b6b2', '#e49444', '#d1615d']))
)

plot_data = to_altair_data(multi_filtered_result_df)
lm_tasks = alt.layer(line, band, data=plot_data).transform_filter(
    (datum.task == 'storycloze')
    | (datum.task == 'xcopa')
    | (datum.task == 'xwinograd')
).facet(
    column=alt.Column('task:N', sort=['storycloze', 'xcopa', 'xwinograd']),
).resolve_scale(
    x='independent',
    y='independent'
)

cls_tasks = alt.layer(line, band, data=plot_data).transform_filter(
    (datum.task == 'xnli')
    | (datum.task == 'pawsx')
).facet(
    column=alt.Column('task:N', sort=['xnli', 'pawsx']),
).resolve_scale(
    x='independent',
    y='independent'
)

alt.vconcat(lm_tasks, cls_tasks)

In [25]:
[r for r in reformatted_data if r['task'] == 'copa']

[{'task': 'copa',
  'language': 'en',
  'nb_few_shot_samples': 0,
  'model': 'dense_7.5B_lang30_new_cc100_xl_unigram',
  'step': 238000,
  'accuracy': nan},
 {'task': 'copa',
  'language': 'en',
  'nb_few_shot_samples': 1,
  'model': 'dense_7.5B_lang30_new_cc100_xl_unigram',
  'step': 238000,
  'accuracy': 71.0},
 {'task': 'copa',
  'language': 'en',
  'nb_few_shot_samples': 4,
  'model': 'dense_7.5B_lang30_new_cc100_xl_unigram',
  'step': 238000,
  'accuracy': 72.0},
 {'task': 'copa',
  'language': 'en',
  'nb_few_shot_samples': 128,
  'model': 'dense_7.5B_lang30_new_cc100_xl_unigram',
  'step': 238000,
  'accuracy': 70.0},
 {'task': 'copa',
  'language': 'en',
  'nb_few_shot_samples': 32,
  'model': 'dense_7.5B_lang30_new_cc100_xl_unigram',
  'step': 238000,
  'accuracy': 69.0}]

In [8]:
# baseline_data = alt.Data(values=[
#     {'task': 'arcchallenge', 'accuracy::mean': 34, 'model': 'dense_6.7B_en', 'num_tokens_(B)': 300, 'num_gpu_days': 2444.6}, 
#     {'task': 'arceasy', 'accuracy::mean': 60.4, 'model': 'dense_6.7B_en', 'num_tokens_(B)': 300, 'num_gpu_days': 2444.6},
#     {'task': 'openbookqa', 'accuracy::mean': 34.2, 'model': 'dense_6.7B_en', 'num_tokens_(B)': 300, 'num_gpu_days': 2444.6},
#     {'task': 'piqa', 'accuracy::mean': 78.7, 'model': 'dense_6.7B_en', 'num_tokens_(B)': 300, 'num_gpu_days': 2444.6},
#     {'task': 'storycloze', 'accuracy::mean': 80, 'model': 'dense_6.7B_en', 'num_tokens_(B)': 300, 'num_gpu_days': 2444.6},
#     {'task': 'winogrande', 'accuracy::mean': 62, 'model': 'dense_6.7B_en', 'num_tokens_(B)': 300, 'num_gpu_days': 2444.6},
#     {'task': 'hellaswag', 'accuracy::mean': 70.6, 'model': 'dense_6.7B_en', 'num_tokens_(B)': 300, 'num_gpu_days': 2444.6},
# ])

chart = alt.vconcat()
for task in en_result_df.task.unique():
    if (task.startswith('blimp')):
        continue
    if eval_settings == 'default' and task not in ['openbookqa', 'winogrande']:
        continue
    print(task)
    cols = []
    for x_var in ["num_gpu_days", "num_tokens_B", "num_EN_tokens_B"]:
        col = alt.Chart(en_result_df).mark_line(point=True).encode(
            x=alt.X(x_var, type="quantitative", scale=alt.Scale(type='log', range=[0, 500])),
            y=alt.Y("accuracy::mean", type="quantitative", scale=alt.Scale(zero=False)),
            color=alt.Color("model:N", scale=alt.Scale(domain=['dense_7.5B_lang30_new_cc100_xl_unigram', 'moe_200B_lang30_new_cc100_xl_unigram', '6.7B_gpt3_setting'], range=['#4E79A7', '#F28E2B', '#E15759'])),
        ).transform_filter(
            (datum.nb_few_shot_samples == 0)
            & (datum.task == task)
            & (datum.language == 'en')
        ).properties(
            title=task
        )
        # en_baseline_mark = alt.Chart(baseline_data).mark_line(point=True).encode(
        #     x='num_tokens_(B):Q', 
        #     y='accuracy::mean:Q', 
        #     color=alt.Color("model:N", scale=alt.Scale(domain=['dense_7.5B_lang30_new_cc100_xl_unigram', 'moe_200B_lang30_new_cc100_xl_unigram', 'dense_6.7B_en'], range=['#4E79A7', '#F28E2B', '#E15759']))
        # ).transform_filter(
        #     datum.task == task
        # )
        cols.append(col)
    chart &= (functools.reduce(lambda x, y: x | y, cols))
chart

arcchallenge
arceasy
hellaswag
piqa
storycloze
openbookqa
winogrande


In [9]:
chart = alt.vconcat()
for task in multi_result_df.task.unique():
    if (task.startswith('mlama')):
        continue
    row1_mean = alt.Chart(multi_result_df).mark_line(point=True).encode(
        x=alt.X("num_tokens_B:Q", scale=alt.Scale(type='log')),
        y=alt.Y("mean(accuracy::mean)", type="quantitative", scale=alt.Scale(zero=False)),
        color="model:N"
    ).transform_filter(
        (datum.nb_few_shot_samples == 0)
        & (datum.task == task)
    ).properties(
        title=task
    )
    row1_median = alt.Chart(multi_result_df).mark_line(opacity=0.3, point=True).encode(
        x=alt.X("num_tokens_B:Q", scale=alt.Scale(type='log')),
        y=alt.Y("median(accuracy::mean)", type="quantitative", scale=alt.Scale(zero=False)),
        color="model:N"
    ).transform_filter(
        (datum.nb_few_shot_samples == 0)
        & (datum.task == task)
    ).properties(
        title=task
    )
    row2_mean = alt.Chart(multi_result_df).mark_line(point=True).encode(
        x=alt.X("num_gpu_days:Q", scale=alt.Scale(type='log')),
        y=alt.Y("mean(accuracy::mean)", type="quantitative", scale=alt.Scale(zero=False)),
        color="model:N"
    ).transform_filter(
        (datum.nb_few_shot_samples == 0)
        & (datum.task == task)
    ).properties(
        title=task
    )
    row2_median = alt.Chart(multi_result_df).mark_line(opacity=0.3, point=True).encode(
        x=alt.X("num_gpu_days:Q", scale=alt.Scale(type='log')),
        y=alt.Y("median(accuracy::mean)", type="quantitative", scale=alt.Scale(zero=False)),
        color="model:N"
    ).transform_filter(
        (datum.nb_few_shot_samples == 0)
        & (datum.task == task)
    ).properties(
        title=task
    )
    chart &= ((row2_mean + row2_median) | (row1_mean+row1_median))
chart

In [10]:
en_result_df

Unnamed: 0,task,eval_set,language,train_set,train_lang,template,nb_few_shot_samples,calibration,run_params::scoring,model_name,accuracy::mean,model,step,meta_task,num_tokens_B,num_EN_tokens_B,num_gpu_days
0,arcchallenge,dev,en,train,en,arc_old,0,True,mean,dense_7.5B_lang30_new_cc100_xl_unigram__step00...,27.424749,dense_7.5B_lang30_new_cc100_xl_unigram,30000,arcchallenge,62.914560,20.510147,1385.001124
1,arceasy,dev,en,train,en,arc_old,0,False,mean,dense_7.5B_lang30_new_cc100_xl_unigram__step00...,46.315789,dense_7.5B_lang30_new_cc100_xl_unigram,30000,arceasy,62.914560,20.510147,1385.001124
2,blimp__adjunct_island,adjunct_island,en,adjunct_island,en,blimp,0,False,mean,dense_7.5B_lang30_new_cc100_xl_unigram__step00...,89.800000,dense_7.5B_lang30_new_cc100_xl_unigram,30000,blimp,62.914560,20.510147,1385.001124
3,blimp__anaphor_gender_agreement,anaphor_gender_agreement,en,anaphor_gender_agreement,en,blimp,0,False,mean,dense_7.5B_lang30_new_cc100_xl_unigram__step00...,99.700000,dense_7.5B_lang30_new_cc100_xl_unigram,30000,blimp,62.914560,20.510147,1385.001124
4,blimp__anaphor_number_agreement,anaphor_number_agreement,en,anaphor_number_agreement,en,blimp,0,False,mean,dense_7.5B_lang30_new_cc100_xl_unigram__step00...,99.500000,dense_7.5B_lang30_new_cc100_xl_unigram,30000,blimp,62.914560,20.510147,1385.001124
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1779,piqa,valid,en,train,en,piqa,16,False,mean,dense_7.5B_lang30_new_cc100_xl_unigram__step00...,73.884657,dense_7.5B_lang30_new_cc100_xl_unigram,238000,piqa,499.122176,162.713829,10987.675582
1780,winogrande,dev,en,train_xl,en,winogrande,16,False,suffix,dense_7.5B_lang30_new_cc100_xl_unigram__step00...,60.063141,dense_7.5B_lang30_new_cc100_xl_unigram,238000,winogrande,499.122176,162.713829,10987.675582
1781,winogrande,dev,en,train_xl,en,winogrande,64,False,suffix,dense_7.5B_lang30_new_cc100_xl_unigram__step00...,60.015785,dense_7.5B_lang30_new_cc100_xl_unigram,238000,winogrande,499.122176,162.713829,10987.675582
1782,arceasy,dev,en,train,en,arc_old,16,False,mean,dense_7.5B_lang30_new_cc100_xl_unigram__step00...,63.859649,dense_7.5B_lang30_new_cc100_xl_unigram,238000,arceasy,499.122176,162.713829,10987.675582


In [16]:
chart = alt.Chart(en_result_df).mark_line(point=True).encode(
    x=alt.X("nb_few_shot_samples:Q", axis=alt.Axis(tickMinStep=1)),
    y=alt.Y("accuracy::mean", type="quantitative", scale=alt.Scale(zero=False)),
    color=alt.Color("model:N", scale=alt.Scale(domain=['dense_7.5B_lang30_new_cc100_xl_unigram', 'moe_200B_lang30_new_cc100_xl_unigram', '6.7B_gpt3_setting'], range=['#4E79A7', '#F28E2B', '#E15759'])),
).transform_filter(
    ((eval_settings != 'default') | ((datum.task in ['openbookqa', 'winogrande'])))
    & (datum.language == 'en')
    & (datum.meta_task != 'blimp') 
    & (last_checkpoint(datum))
).properties(
    title=task
).facet(
    facet='task:N',
    columns=4
).resolve_scale(
    x='independent',
    y='independent'
)
chart

In [17]:
chart = alt.Chart(multi_result_df).mark_line(point=True).encode(
    x=alt.X("nb_few_shot_samples:Q", axis=alt.Axis(tickMinStep=1)),
    y=alt.Y("mean(accuracy::mean)", type="quantitative", scale=alt.Scale(zero=False)),
    color=alt.Color("resource_level:N")
    # color=alt.Color("resource_level:N", scale=alt.Scale(domain=['dense_7.5B_lang30_new_cc100_xl_unigram', 'moe_200B_lang30_new_cc100_xl_unigram', '6.7B_gpt3_setting'], range=['#4E79A7', '#F28E2B', '#E15759'])),
).transform_filter(
      (datum.meta_task != 'blimp') 
    & (((datum.model == '6.7B_gpt3_setting') & (datum.step == 143050))
    | ((datum.model == 'dense_7.5B_lang30_new_cc100_xl_unigram') & (datum.step == 238000))
    | ((datum.model == 'moe_200B_lang30_new_cc100_xl_unigram') & (datum.step == 118000)))
).properties(
    title=task
).facet(
    facet='task:N',
    columns=4
).resolve_scale(
    x='independent',
    y='independent'
)
chart