In [1]:
import altair as alt
from altair import datum
import functools
import pandas as pd
import sys
sys.path.append('/private/home/victorialin/Projects/fairseq-py')
from scripts.visualization.multilingual_few_shot_eval_utils import *


result_tables = {
    "multi-dense-564M": "/checkpoint/xianl/few_shot/dense_564M_lang30_new_cc100_xl_unigram_mutli_tasks_v1_test/results.tsv",
    "multi-dense-1.7B": "/checkpoint/xianl/few_shot/dense_1.7B_lang30_new_cc100_xl_unigram_mutli_tasks_v1_test/results.tsv",
    "multi-dense-7.5B": "/checkpoint/victorialin/few_shot/dense_7.5B_lang30_new_cc100_xl_unigram_mutli_tasks_v1/results.tsv"
}

resource_levels = list(xlmg_language_resource_brackets.keys())

def scaling_nb_few_shots(df):
    return ((df['nb_few_shot_samples'] == 0)
        | (df['nb_few_shot_samples'] == 1)
        | (df['nb_few_shot_samples'] == 4)
        | (df['nb_few_shot_samples'] == 32)
        | ((df['nb_few_shot_samples'] == 128) & (df['task'] != 'storycloze')))

dfs = {}
for key in result_tables:
    df = load_from_tsv_and_filter(result_tables[key], filter_conditions=multi_final_eval_filter_conditions)
    dfs[key] = df[scaling_nb_few_shots(df)]

multi_result_df = pd.concat(dfs.values())

In [None]:
def verify_multilingual_few_shot_learning_results(df, model_name=None):
    preferred_metrics = 'accuracy'
    num_trials = 5
    model_names = df['model_id'].unique() if model_name is None else [model_name]

    for m_name in model_names:
        print(f"Checking {m_name} predictions...")
        model_df = df[df['model_id'] == m_name]
        grouped_model_df = model_df.groupby(['task', 'nb_few_shot_samples'])
        for group_key, _ in grouped_model_df:
            group = grouped_model_df.get_group(group_key)
            task, nb_few_shot_samples = group_key
            task_langs = multi_eval_tasks_langs[task]
            collected_task_langs = group['language'].to_list()
            
            print(group_key, collected_task_langs)
            if len(collected_task_langs) > len(task_langs):
                print('Warning: duplicated results collected')
            if len(collected_task_langs) < len(task_langs):
                print('Warning: missing language in result, expecting {task_langs}')
        print()

verify_multilingual_few_shot_learning_results(multi_result_df)

In [102]:
def adjust_resource_level_for_pivoting(resource_level):
    pivoting_resource_level = {
        'high': '  hi',
        'medium': '  me',
        'low': ' lo',
        'extremely-low': 'ex-lo'
    }
    return pivoting_resource_level[resource_level]

def adjust_task_for_pivoting(task):
    pivoting_task = {
        'storycloze': 'storycloze',
        'xcopa': 'xcopa',
        'xwinograd': 'xwinograd',
        'xnli': ' xnli',
        'pawsx': 'pawsx'
    }
    return pivoting_task[task]

def adjust_task_group_for_pivoting(task_group):
    pivoting_task_group = {
        'language modeling': 'Language Modeling',
        'NLI': 'NLI'
    }
    return pivoting_task_group[task_group]

def adjust_model_size_for_pivoting(model_size):
    pivoting_model_size = {
        0.564: ' 564M',
        1.7: '1.7B',
        7.5: '7.5B'
    }
    return pivoting_model_size[model_size] 
    
attributes = [
    'task',
    'eval_set',
    'language',
    'train_set',
    'train_lang',
    'template',
    'nb_few_shot_samples',
    'calibration',
    'run_params::scoring',
    'model_name',
    'model_size',
    "accuracy::mean",
    "accuracy::std",
    'model',
    'step',
    'meta_task',
    'model_id',
]

result_df = multi_result_df.drop_duplicates(
    subset=attributes,
    ignore_index=True
)
result_df['pivoting_resource_level'] = result_df.resource_level.apply(lambda x:adjust_resource_level_for_pivoting(x))
result_df['pivoting_task'] = result_df.task.apply(lambda x: adjust_task_for_pivoting(x))
result_df['pivoting_task_group'] = result_df.task_group.apply(lambda x:adjust_task_group_for_pivoting(x))
result_df['pivoting_model_size'] = result_df.model_size.apply(lambda x:adjust_model_size_for_pivoting(x))
result_df = result_df[attributes + ['pivoting_resource_level', 'resource_level', 'pivoting_task', 'pivoting_task_group', 'pivoting_model_size', 'task_group']]

grouped_filtered_result_df = result_df.groupby(['task', 'eval_set', 'language', 'train_set', 'train_lang', 'template', 'calibration', 'model_name'])
grouped_filtered_result_df = grouped_filtered_result_df.apply(lambda a: a.sort_values('nb_few_shot_samples'))
grouped_filtered_result_df.to_csv('/checkpoint/victorialin/few_shot/dense_7.5B_lang30_new_cc100_xl_unigram_en_tasks/few_shot_scaling.tsv', sep='\t')

In [103]:
import numpy as np
# result_df_7_5b = result_df[result_df['model_size'] == 7.5]
result_table = pd.pivot_table(result_df, index=['pivoting_model_size', 'nb_few_shot_samples'], columns=['pivoting_task_group', 'pivoting_task', 'pivoting_resource_level'], values=['accuracy::mean'], aggfunc=[np.mean], fill_value='')
result_table.to_csv('/checkpoint/victorialin/few_shot/dense_7.5B_lang30_new_cc100_xl_unigram_en_tasks/paper_table.tsv', sep='\t')
result_table.to_html('/checkpoint/victorialin/few_shot/dense_7.5B_lang30_new_cc100_xl_unigram_en_tasks/paper_table.html')

result_table
print(result_table.to_latex(float_format="{:0.1f}".format))

\begin{tabular}{llllllllllllrrlllllllllllrrlll}
\toprule
     & {} & \multicolumn{15}{l}{mean} & \multicolumn{13}{l}{std} \\
     & {} & \multicolumn{15}{l}{accuracy::mean} & \multicolumn{13}{l}{accuracy::mean} \\
     & pivoting\_task\_group & \multicolumn{10}{l}{Language Modeling} & \multicolumn{5}{l}{NLI} & \multicolumn{8}{l}{Language Modeling} & \multicolumn{5}{l}{NLI} \\
     & pivoting\_task & \multicolumn{4}{l}{storycloze} & \multicolumn{4}{l}{xcopa} & \multicolumn{2}{l}{xwinograd} & \multicolumn{3}{l}{xnli} & \multicolumn{2}{l}{pawsx} & \multicolumn{3}{l}{storycloze} & \multicolumn{4}{l}{xcopa} & xwinograd & \multicolumn{3}{l}{xnli} & \multicolumn{2}{l}{pawsx} \\
     & pivoting\_resource\_level &                hi &   me &   lo & ex-lo &    hi &   me &   lo & ex-lo &        hi &   me &    hi &   me &   lo &    hi &   me &                hi &   me &  lo &    hi &   me &  lo & ex-lo &        hi &    hi &   me &  lo &    hi &   me \\
pivoting\_model\_size & nb\_few\_shot\_samples

In [99]:
for task in ['storycloze', 'xcopa', 'xwinograd', 'xnli', 'pawsx']:
    print(task)
    task_result_df_7_5b = result_df[result_df['task'] == task]
    result_table = pd.pivot_table(task_result_df_7_5b, index=['pivoting_model_size', 'nb_few_shot_samples'], columns=['pivoting_resource_level', 'language'], values=['accuracy::mean'], fill_value='')
    result_table
    print(result_table.to_latex(float_format="{:0.1f}".format))

storycloze
\begin{tabular}{llrrrrrrrrrrr}
\toprule
     & {} & \multicolumn{11}{l}{accuracy::mean} \\
     & pivoting\_resource\_level & \multicolumn{5}{l}{high} & \multicolumn{2}{l}{medium} & \multicolumn{3}{l}{low} & extremely-low \\
     & language &             ar &   en &   es &   ru &   zh &       hi &   id &   eu &   sw &   te &            my \\
pivoting\_model\_size & nb\_few\_shot\_samples &                &      &      &      &      &          &      &      &      &      &               \\
\midrule
 564M & 4  &           52.2 & 63.2 & 57.6 & 59.6 & 56.4 &     52.9 & 57.5 & 55.4 & 57.3 & 57.6 &          54.0 \\
     & 32 &           51.9 & 63.5 & 57.3 & 60.3 & 56.4 &     53.4 & 57.5 & 55.3 & 57.5 & 57.2 &          53.6 \\
1.7B & 0  &           56.1 & 67.6 & 61.4 & 64.1 & 60.2 &     56.3 & 63.3 & 56.8 & 61.0 & 58.9 &          56.4 \\
     & 1  &           55.4 & 67.6 & 61.0 & 63.8 & 60.4 &     57.5 & 63.3 & 56.9 & 59.8 & 59.3 &          56.1 \\
     & 4  &           55.4 & 67.8

In [52]:
alt.Chart(result_df).mark_line(point=True).encode(
    x=alt.X("model_size:Q", scale=alt.Scale(type='log')),
    y=alt.Y("accuracy::mean:Q", aggregate='mean', scale=alt.Scale(zero=False)),
    color=alt.Color("nb_few_shot_samples:N", scale=alt.Scale(domain=[0, 4, 32, 128], range=['#5778a4', '#85b6b2', '#e49444', '#d1615d']))
).facet(
    facet=alt.Facet("task:N"),
    columns=3
).resolve_scale(
    x='independent',
    y='independent'
)

In [29]:
import numpy as np
num_trials = 5
eval_metrics = 'accuracy'

def to_altair_data(df):
    reformatted_data = []
    for _, row in df.iterrows():
        for i in range(num_trials):
            data_point = dict()
            for key in [
                        'task',
                        'language',
                        'nb_few_shot_samples',
                        'model',
                        'step',
                        'resource_level'
                    ]:
                if key == 'resource_level' and key not in row:
                    data_point[key] = 'high'
                else:
                    data_point[key] = row[key]
            data_point[eval_metrics] = row[f'{eval_metrics}_{i}']
            data_point[f'{eval_metrics}::mean'] = row[f'{eval_metrics}::mean']
            reformatted_data.append(data_point)
    plot_data = alt.Data(values=reformatted_data)
    return plot_data

In [8]:
# df_mapping = pd.DataFrame({'task': all_eval_tasks})
# sort_mapping = df_mapping.reset_index().set_index('task')
result_df = pd.concat([en_result_df, multi_result_df]).drop_duplicates(
    subset=[
        'task',
        'eval_set',
        'language',
        'train_set',
        'train_lang',
        'template',
        'nb_few_shot_samples',
        'calibration',
        'run_params::scoring',
        'model_name',
        "accuracy::mean",
        "accuracy::std",
        'model',
        'step',
        'meta_task',
        'model_id',
        'num_tokens_B',
        'num_EN_tokens_B',
        'num_gpu_days'
    ],
    ignore_index=True
)
filtered_result_df = result_df[last_checkpoint(result_df) & num_few_shot_samples(result_df) & template_selection(result_df) & multilingual_checkpoints(result_df)
    & (en_final_eval_splits(result_df) | multi_final_eval_splits(result_df))
    & (all_en_eval_tasks(result_df) | all_multi_eval_tasks(result_df))] 
grouped_filtered_result_df = filtered_result_df.groupby(['task', 'eval_set', 'language', 'train_set', 'train_lang', 'template', 'calibration', 'model_name'])
grouped_filtered_result_df = grouped_filtered_result_df.apply(lambda a: a.sort_values('nb_few_shot_samples'))
grouped_filtered_result_df.to_csv('/checkpoint/victorialin/few_shot/dense_7.5B_lang30_new_cc100_xl_unigram_en_tasks/few_shot.tsv', sep='\t')

multi_filtered_result_df = filtered_result_df[multi_final_eval_splits(filtered_result_df) & all_multi_eval_tasks(filtered_result_df)]
multi_grouped_filtered_result_df = multi_filtered_result_df.groupby(['task', 'eval_set', 'language', 'train_set', 'train_lang', 'template', 'calibration', 'model_name'])
multi_grouped_filtered_result_df = multi_grouped_filtered_result_df.apply(lambda a: a.sort_values('nb_few_shot_samples'))
multi_grouped_filtered_result_df.to_csv('/checkpoint/victorialin/few_shot/dense_7.5B_lang30_new_cc100_xl_unigram_en_tasks/multi_few_shot.tsv', sep='\t')

en_filtered_result_df = filtered_result_df[en_final_eval_splits(filtered_result_df) & all_en_eval_tasks(filtered_result_df) & en_only(filtered_result_df)]
en_grouped_filtered_result_df = en_filtered_result_df.groupby(['task', 'eval_set', 'language', 'train_set', 'train_lang', 'template', 'calibration', 'model_name'])
en_grouped_filtered_result_df = en_grouped_filtered_result_df.apply(lambda a: a.sort_values('nb_few_shot_samples'))
en_grouped_filtered_result_df.to_csv('/checkpoint/victorialin/few_shot/dense_7.5B_lang30_new_cc100_xl_unigram_en_tasks/en_few_shot.tsv', sep='\t')

In [25]:
line = alt.Chart().mark_line(point=True).encode(
    x=alt.X("nb_few_shot_samples:Q", axis=alt.Axis(values=[0, 1, 4, 32, 128]), scale=alt.Scale(type='symlog')),
    y=alt.Y("accuracy:Q", aggregate='mean', scale=alt.Scale(zero=False)),
    color=alt.Color("resource_level:N", scale=alt.Scale(domain=resource_levels, range=['#5778a4', '#85b6b2', '#e49444', '#d1615d']))
)
band = alt.Chart().mark_errorband(extent='ci').encode(
    x=alt.X("nb_few_shot_samples:Q", axis=alt.Axis(values=[0, 1, 4, 32, 128]), scale=alt.Scale(type='symlog')),
    y=alt.Y("accuracy:Q", scale=alt.Scale(zero=False)),
    color=alt.Color("resource_level:N", scale=alt.Scale(domain=resource_levels, range=['#5778a4', '#85b6b2', '#e49444', '#d1615d']))
)

en_plot_data = to_altair_data(en_filtered_result_df)
alt.layer(line, band, data=en_plot_data).facet(
    'task:N', columns=4
).resolve_scale(
    x='independent',
    y='independent'
)

NameError: name 'resource_levels' is not defined

In [24]:
line = alt.Chart().mark_line(point=True).encode(
    x=alt.X("nb_few_shot_samples:Q", axis=alt.Axis(values=[0, 1, 4, 32, 128]), scale=alt.Scale(type='symlog')),
    y=alt.Y("accuracy::mean:Q", aggregate='mean', scale=alt.Scale(zero=False)),
    color=alt.Color("resource_level:N", scale=alt.Scale(domain=resource_levels, range=['#5778a4', '#85b6b2', '#e49444', '#d1615d']))
)
band = alt.Chart().mark_errorband(extent='ci').encode(
    x=alt.X("nb_few_shot_samples:Q", axis=alt.Axis(values=[0, 1, 4, 32, 128]), scale=alt.Scale(type='symlog')),
    y=alt.Y("accuracy::mean:Q", scale=alt.Scale(zero=False)),
    color=alt.Color("resource_level:N", scale=alt.Scale(domain=resource_levels, range=['#5778a4', '#85b6b2', '#e49444', '#d1615d']))
)

plot_data = to_altair_data(multi_filtered_result_df)
lm_tasks = alt.layer(line, band, data=plot_data).transform_filter(
    (datum.task == 'storycloze')
    | (datum.task == 'xcopa')
    | (datum.task == 'xwinograd')
).facet(
    column=alt.Column('task:N', sort=['storycloze', 'xcopa', 'xwinograd']),
).resolve_scale(
    x='independent',
    y='independent'
)

cls_tasks = alt.layer(line, band, data=plot_data).transform_filter(
    (datum.task == 'xnli')
    | (datum.task == 'pawsx')
).facet(
    column=alt.Column('task:N', sort=['xnli', 'pawsx']),
).resolve_scale(
    x='independent',
    y='independent'
)

alt.vconcat(lm_tasks, cls_tasks)

NameError: name 'resource_levels' is not defined

In [25]:
[r for r in reformatted_data if r['task'] == 'copa']

[{'task': 'copa',
  'language': 'en',
  'nb_few_shot_samples': 0,
  'model': 'dense_7.5B_lang30_new_cc100_xl_unigram',
  'step': 238000,
  'accuracy': nan},
 {'task': 'copa',
  'language': 'en',
  'nb_few_shot_samples': 1,
  'model': 'dense_7.5B_lang30_new_cc100_xl_unigram',
  'step': 238000,
  'accuracy': 71.0},
 {'task': 'copa',
  'language': 'en',
  'nb_few_shot_samples': 4,
  'model': 'dense_7.5B_lang30_new_cc100_xl_unigram',
  'step': 238000,
  'accuracy': 72.0},
 {'task': 'copa',
  'language': 'en',
  'nb_few_shot_samples': 128,
  'model': 'dense_7.5B_lang30_new_cc100_xl_unigram',
  'step': 238000,
  'accuracy': 70.0},
 {'task': 'copa',
  'language': 'en',
  'nb_few_shot_samples': 32,
  'model': 'dense_7.5B_lang30_new_cc100_xl_unigram',
  'step': 238000,
  'accuracy': 69.0}]

In [8]:
# baseline_data = alt.Data(values=[
#     {'task': 'arcchallenge', 'accuracy::mean': 34, 'model': 'dense_6.7B_en', 'num_tokens_(B)': 300, 'num_gpu_days': 2444.6}, 
#     {'task': 'arceasy', 'accuracy::mean': 60.4, 'model': 'dense_6.7B_en', 'num_tokens_(B)': 300, 'num_gpu_days': 2444.6},
#     {'task': 'openbookqa', 'accuracy::mean': 34.2, 'model': 'dense_6.7B_en', 'num_tokens_(B)': 300, 'num_gpu_days': 2444.6},
#     {'task': 'piqa', 'accuracy::mean': 78.7, 'model': 'dense_6.7B_en', 'num_tokens_(B)': 300, 'num_gpu_days': 2444.6},
#     {'task': 'storycloze', 'accuracy::mean': 80, 'model': 'dense_6.7B_en', 'num_tokens_(B)': 300, 'num_gpu_days': 2444.6},
#     {'task': 'winogrande', 'accuracy::mean': 62, 'model': 'dense_6.7B_en', 'num_tokens_(B)': 300, 'num_gpu_days': 2444.6},
#     {'task': 'hellaswag', 'accuracy::mean': 70.6, 'model': 'dense_6.7B_en', 'num_tokens_(B)': 300, 'num_gpu_days': 2444.6},
# ])

chart = alt.vconcat()
for task in en_result_df.task.unique():
    if (task.startswith('blimp')):
        continue
    if eval_settings == 'default' and task not in ['openbookqa', 'winogrande']:
        continue
    print(task)
    cols = []
    for x_var in ["num_gpu_days", "num_tokens_B", "num_EN_tokens_B"]:
        col = alt.Chart(en_result_df).mark_line(point=True).encode(
            x=alt.X(x_var, type="quantitative", scale=alt.Scale(type='log', range=[0, 500])),
            y=alt.Y("accuracy::mean", type="quantitative", scale=alt.Scale(zero=False)),
            color=alt.Color("model:N", scale=alt.Scale(domain=['dense_7.5B_lang30_new_cc100_xl_unigram', 'moe_200B_lang30_new_cc100_xl_unigram', '6.7B_gpt3_setting'], range=['#4E79A7', '#F28E2B', '#E15759'])),
        ).transform_filter(
            (datum.nb_few_shot_samples == 0)
            & (datum.task == task)
            & (datum.language == 'en')
        ).properties(
            title=task
        )
        # en_baseline_mark = alt.Chart(baseline_data).mark_line(point=True).encode(
        #     x='num_tokens_(B):Q', 
        #     y='accuracy::mean:Q', 
        #     color=alt.Color("model:N", scale=alt.Scale(domain=['dense_7.5B_lang30_new_cc100_xl_unigram', 'moe_200B_lang30_new_cc100_xl_unigram', 'dense_6.7B_en'], range=['#4E79A7', '#F28E2B', '#E15759']))
        # ).transform_filter(
        #     datum.task == task
        # )
        cols.append(col)
    chart &= (functools.reduce(lambda x, y: x | y, cols))
chart

arcchallenge
arceasy
hellaswag
piqa
storycloze
openbookqa
winogrande


In [9]:
chart = alt.vconcat()
for task in multi_result_df.task.unique():
    if (task.startswith('mlama')):
        continue
    row1_mean = alt.Chart(multi_result_df).mark_line(point=True).encode(
        x=alt.X("num_tokens_B:Q", scale=alt.Scale(type='log')),
        y=alt.Y("mean(accuracy::mean)", type="quantitative", scale=alt.Scale(zero=False)),
        color="model:N"
    ).transform_filter(
        (datum.nb_few_shot_samples == 0)
        & (datum.task == task)
    ).properties(
        title=task
    )
    row1_median = alt.Chart(multi_result_df).mark_line(opacity=0.3, point=True).encode(
        x=alt.X("num_tokens_B:Q", scale=alt.Scale(type='log')),
        y=alt.Y("median(accuracy::mean)", type="quantitative", scale=alt.Scale(zero=False)),
        color="model:N"
    ).transform_filter(
        (datum.nb_few_shot_samples == 0)
        & (datum.task == task)
    ).properties(
        title=task
    )
    row2_mean = alt.Chart(multi_result_df).mark_line(point=True).encode(
        x=alt.X("num_gpu_days:Q", scale=alt.Scale(type='log')),
        y=alt.Y("mean(accuracy::mean)", type="quantitative", scale=alt.Scale(zero=False)),
        color="model:N"
    ).transform_filter(
        (datum.nb_few_shot_samples == 0)
        & (datum.task == task)
    ).properties(
        title=task
    )
    row2_median = alt.Chart(multi_result_df).mark_line(opacity=0.3, point=True).encode(
        x=alt.X("num_gpu_days:Q", scale=alt.Scale(type='log')),
        y=alt.Y("median(accuracy::mean)", type="quantitative", scale=alt.Scale(zero=False)),
        color="model:N"
    ).transform_filter(
        (datum.nb_few_shot_samples == 0)
        & (datum.task == task)
    ).properties(
        title=task
    )
    chart &= ((row2_mean + row2_median) | (row1_mean+row1_median))
chart

In [10]:
en_result_df

Unnamed: 0,task,eval_set,language,train_set,train_lang,template,nb_few_shot_samples,calibration,run_params::scoring,model_name,accuracy::mean,model,step,meta_task,num_tokens_B,num_EN_tokens_B,num_gpu_days
0,arcchallenge,dev,en,train,en,arc_old,0,True,mean,dense_7.5B_lang30_new_cc100_xl_unigram__step00...,27.424749,dense_7.5B_lang30_new_cc100_xl_unigram,30000,arcchallenge,62.914560,20.510147,1385.001124
1,arceasy,dev,en,train,en,arc_old,0,False,mean,dense_7.5B_lang30_new_cc100_xl_unigram__step00...,46.315789,dense_7.5B_lang30_new_cc100_xl_unigram,30000,arceasy,62.914560,20.510147,1385.001124
2,blimp__adjunct_island,adjunct_island,en,adjunct_island,en,blimp,0,False,mean,dense_7.5B_lang30_new_cc100_xl_unigram__step00...,89.800000,dense_7.5B_lang30_new_cc100_xl_unigram,30000,blimp,62.914560,20.510147,1385.001124
3,blimp__anaphor_gender_agreement,anaphor_gender_agreement,en,anaphor_gender_agreement,en,blimp,0,False,mean,dense_7.5B_lang30_new_cc100_xl_unigram__step00...,99.700000,dense_7.5B_lang30_new_cc100_xl_unigram,30000,blimp,62.914560,20.510147,1385.001124
4,blimp__anaphor_number_agreement,anaphor_number_agreement,en,anaphor_number_agreement,en,blimp,0,False,mean,dense_7.5B_lang30_new_cc100_xl_unigram__step00...,99.500000,dense_7.5B_lang30_new_cc100_xl_unigram,30000,blimp,62.914560,20.510147,1385.001124
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1779,piqa,valid,en,train,en,piqa,16,False,mean,dense_7.5B_lang30_new_cc100_xl_unigram__step00...,73.884657,dense_7.5B_lang30_new_cc100_xl_unigram,238000,piqa,499.122176,162.713829,10987.675582
1780,winogrande,dev,en,train_xl,en,winogrande,16,False,suffix,dense_7.5B_lang30_new_cc100_xl_unigram__step00...,60.063141,dense_7.5B_lang30_new_cc100_xl_unigram,238000,winogrande,499.122176,162.713829,10987.675582
1781,winogrande,dev,en,train_xl,en,winogrande,64,False,suffix,dense_7.5B_lang30_new_cc100_xl_unigram__step00...,60.015785,dense_7.5B_lang30_new_cc100_xl_unigram,238000,winogrande,499.122176,162.713829,10987.675582
1782,arceasy,dev,en,train,en,arc_old,16,False,mean,dense_7.5B_lang30_new_cc100_xl_unigram__step00...,63.859649,dense_7.5B_lang30_new_cc100_xl_unigram,238000,arceasy,499.122176,162.713829,10987.675582


In [16]:
chart = alt.Chart(en_result_df).mark_line(point=True).encode(
    x=alt.X("nb_few_shot_samples:Q", axis=alt.Axis(tickMinStep=1)),
    y=alt.Y("accuracy::mean", type="quantitative", scale=alt.Scale(zero=False)),
    color=alt.Color("model:N", scale=alt.Scale(domain=['dense_7.5B_lang30_new_cc100_xl_unigram', 'moe_200B_lang30_new_cc100_xl_unigram', '6.7B_gpt3_setting'], range=['#4E79A7', '#F28E2B', '#E15759'])),
).transform_filter(
    ((eval_settings != 'default') | ((datum.task in ['openbookqa', 'winogrande'])))
    & (datum.language == 'en')
    & (datum.meta_task != 'blimp') 
    & (last_checkpoint(datum))
).properties(
    title=task
).facet(
    facet='task:N',
    columns=4
).resolve_scale(
    x='independent',
    y='independent'
)
chart

In [17]:
chart = alt.Chart(multi_result_df).mark_line(point=True).encode(
    x=alt.X("nb_few_shot_samples:Q", axis=alt.Axis(tickMinStep=1)),
    y=alt.Y("mean(accuracy::mean)", type="quantitative", scale=alt.Scale(zero=False)),
    color=alt.Color("resource_level:N")
    # color=alt.Color("resource_level:N", scale=alt.Scale(domain=['dense_7.5B_lang30_new_cc100_xl_unigram', 'moe_200B_lang30_new_cc100_xl_unigram', '6.7B_gpt3_setting'], range=['#4E79A7', '#F28E2B', '#E15759'])),
).transform_filter(
      (datum.meta_task != 'blimp') 
    & (((datum.model == '6.7B_gpt3_setting') & (datum.step == 143050))
    | ((datum.model == 'dense_7.5B_lang30_new_cc100_xl_unigram') & (datum.step == 238000))
    | ((datum.model == 'moe_200B_lang30_new_cc100_xl_unigram') & (datum.step == 118000)))
).properties(
    title=task
).facet(
    facet='task:N',
    columns=4
).resolve_scale(
    x='independent',
    y='independent'
)
chart