In [1]:
import numpy as np
import summarize_all_results
import importlib
import math
import pickle
importlib.reload(summarize_all_results)

<module 'summarize_all_results' from '/juice/scr/ananya/cifar_experiments/transfer_learning/scripts/summarize_all_results.py'>

In [2]:
log_dir = '../logs/'
name_template = 'full_ft_{option}{dataset}_{model}'

def get_bolds(table, std_err_table=None, eps=0.2):
    is_best = np.zeros((len(table), len(table[0])), dtype=np.int32)
    for col in range(len(table[0])):
        best_idx = np.argmax(table[:,col])
        best_acc = table[best_idx][col]
        if std_err_table is not None and col < len(std_err_table[0]):
            best_std = std_err_table[best_idx][col]
            if math.isnan(best_std):
                best_std = eps
        for row in range(len(table)):
            if std_err_table is not None and col < len(std_err_table[0]):
                cur_std = std_err_table[row][col]
                if math.isnan(cur_std):
                    cur_std = eps
                if np.abs(table[row][col] - best_acc) <= max(best_std, cur_std):
                    is_best[row][col] = 1
            else:
                if table[row][col] >= best_acc - eps:
                    is_best[row][col] = 1
    return is_best

def get_table(log_dir, name_template, models, options, datasets, output_metrics_list, val_metric_list=None,
              aggregation_metric_list=None, num_sweep=6, desired_epochs_list=None):
    # desired_epochs_list is a list of size len(datasets) that stores the number of epochs we are
    # supposed to have run each dataset for. Used to validate if the runs finished.
    mean_table = np.zeros((len(models), len(options), sum([len(ol) for ol in output_metrics_list])))
    std_table = np.zeros((len(models), len(options), sum([len(ol) for ol in output_metrics_list])))
    for model_idx, model in enumerate(models):
        for option_idx, option in enumerate(options):
            cur_means = []
            cur_stds = []
            for dataset_idx, dataset in enumerate(datasets):
                dir_path = log_dir + '/' + name_template.format(model=model, option=option, dataset=dataset)
                output_metrics = output_metrics_list[dataset_idx]
                if val_metric_list is None:
                    val_metric = 'LAST'
                else:
                    val_metric = val_metric_list[dataset_idx]
                if aggregation_metric_list is None:
                    agg_metric = val_metric
                else:
                    agg_metric = aggregation_metric_list[dataset_idx]
                res, best_row_mean, best_row_std, num_epochs_list = summarize_all_results.get_experiment_aggregate_summary(
                    dir_path, val_metric, output_metrics, aggregation_metric=agg_metric)
                num_jobs = len(res)
                if (desired_epochs_list is not None and
                    desired_epochs_list[dataset_idx] != np.min(num_epochs_list)):
                    bad_indices = np.argwhere(np.array(num_epochs_list) < desired_epochs_list[dataset_idx])
                    bad_runs = [dir_path + '/' + run_folder for run_folder in res['name'][bad_indices[:, 0]]]
                    print('some jobs did not finish: ', model, dataset, option, bad_runs)
                if num_jobs != num_sweep and num_jobs != num_sweep * 3:
                    print('not all jobs ran: ', model, option, dataset, num_jobs)
                cur_means.append(best_row_mean.to_numpy()[0,:-2])
                cur_stds.append(best_row_std.to_numpy()[0,:-2])
            cur_means = np.concatenate(cur_means)
            cur_stds = np.concatenate(cur_stds)
            mean_table[model_idx][option_idx] = cur_means
            std_table[model_idx][option_idx] = cur_stds
    return mean_table, std_table

def flatten(list_of_lists):
    x = []
    for l in list_of_lists:
        x.append(l)
    return x

def filter_columns(table, old_output_metrics_list, new_output_metrics_list):
    def get_indices(old_list, new_items):
        return np.array([old_list.index(x) for x in new_items])
    local_indices_list = [get_indices(o, n) for o, n in zip(old_output_metrics_list, new_output_metrics_list)]
#     print(local_indices_list)
    old_lens = [len(o) for o in old_output_metrics_list]
    cum_old_lens = np.concatenate([[0], np.cumsum(old_lens)])[:-1]
#     print(cum_old_lens)
    global_indices = [local_indices + cum_old_len for local_indices, cum_old_len in zip(local_indices_list, cum_old_lens)]
#     print(list(np.concatenate(global_indices)))
    return table[:, :, global_indices][:,:,:,0]


# Add table to make TSV and Latex tables. Add averages.
def display_tsv_table(models, options, shortened_output_metrics_list, table, std_err_table=None):
    first_line = '\t\t' + '\t'.join(shortened_output_metrics_list)
    print(first_line)
    for model_idx, model in enumerate(models):
        for option_idx, option in enumerate(options):
            line = model + '\t' + option
            for idx, val in enumerate(table[model_idx][option_idx]):
                line += '\t' + '{:.1f}'.format(val)
                if (std_err_table is not None and
                    idx < len(std_err_table[model_idx][option_idx])):
                    line += ' ({:.1f})'.format(std_err_table[model_idx][option_idx][idx])
            print(line)
            
def add_average_column(table):
    new_table = []
    for model_idx in range(len(table)):
        model_table = table[model_idx]
        means = np.mean(model_table, axis=1)
        new_model_table = np.concatenate([model_table, np.expand_dims(means, axis=-1)], axis=-1)
        new_table.append(new_model_table)
    return np.array(new_table)
        
def display_latex_table(models, options, shortened_output_metrics_list, table, std_err_table=None, bold_best=True, is_last_avg=True, light_thresh=0.2):

    # Print latex table.
    if len(models) == 1:
        num_columns = 2*len(shortened_output_metrics_list) + 1
    else:
        num_columns = 2*len(shortened_output_metrics_list) + 2
    if is_last_avg:
        print('\\begin{tabular}{' + ('c'*(num_columns-2)) + '|cc}')
    else:
        print('\\begin{tabular}{' + ('c'*num_columns) + '}')
    print('\\toprule')
    first_line = '' 
    for oml in shortened_output_metrics_list:
        first_line += ' & \multicolumn{2}{c}{' + oml + '} '
    first_line += '\\\\'
    if len(models) > 1:
        first_line = ' & ' + first_line
    print(first_line)
    for model_idx, model in enumerate(models):
        print('\\midrule')
        if std_err_table is None:
            is_bold = get_bolds(table[model_idx], eps=0.2)
        else:
            is_bold = get_bolds(table[model_idx], std_err_table[model_idx], eps=0.2)
        for option_idx, option in enumerate(options):
            if len(models) > 1:
                line = model + ' & ' + option
            else:
                line = option
            for idx, val in enumerate(table[model_idx][option_idx]):
                line += ' & '
#                 print(is_bold[option_idx][idx])
                if is_bold[option_idx][idx] and bold_best:
                    line += '\\textbf{'
                line += '{:.1f}'.format(val)
                if (std_err_table is not None and
                    idx < len(std_err_table[model_idx][option_idx]) and
                    not math.isnan(std_err_table[model_idx][option_idx][idx])):
                    line += ' ({:.1f})'.format(std_err_table[model_idx][option_idx][idx])
                if is_bold[option_idx][idx] and bold_best:
                    line += '}'
                if option_idx:# if not sgd
                    sgd_val = table[model_idx][0][idx]
                    diff_fp = val - sgd_val
                    diff = round(val,1)-round(sgd_val,1)
                    if diff_fp>=light_thresh:
                        line += ' & \hspace{{-1em}}{{\hgreen{{(+{:.1f})}}}}'.format(diff)
                    elif diff_fp>=0.0:
                        line += ' & \hspace{{-1em}}{{\lgreen{{(+{:.1f})}}}}'.format(diff)
                    elif diff>=-1*light_thresh:
                        line += ' & \hspace{{-1em}}{{\lred{{(-{:.1f})}}}}'.format(abs(diff))
                    else:
                        line += ' & \hspace{{-1em}}{{\hred{{(-{:.1f})}}}}'.format(abs(diff))
                    # else:
                    #     line += ' & '
                else:
                    line += ' & '

            line += '\\\\'
            print(line)
    print('\\bottomrule')
    print('\\end{tabular}')
    
    
def display_id_ood_tables(mean_table, std_table, id_output_metrics_list, ood_output_metrics_list, shortened_dataset_names, no_std=False):
    # Filter the ID and OOD metrics.
    print('\n\nID Accuracies')
    id_mean_table = filter_columns(mean_table, output_metrics_list, id_output_metrics_list)
    id_mean_table = add_average_column(id_mean_table)
    id_std_table = filter_columns(std_table, output_metrics_list, id_output_metrics_list)
    if no_std:
        id_std_table = None
    display_tsv_table(shorted_model_names, shortened_options_names, shortened_dataset_names, id_mean_table, id_std_table)
    print('')
    display_latex_table(shorted_model_names, shortened_options_names, shortened_dataset_names, id_mean_table, id_std_table, is_last_avg=True)

    print('\n\nOOD Accuracies')
    ood_mean_table = filter_columns(mean_table, output_metrics_list, ood_output_metrics_list)
    ood_mean_table = add_average_column(ood_mean_table)
    ood_std_table = filter_columns(std_table, output_metrics_list, ood_output_metrics_list)
    if no_std:
        ood_std_table = None
    display_tsv_table(shorted_model_names, shortened_options_names, shortened_dataset_names, ood_mean_table, ood_std_table)
    print('')
    display_latex_table(shorted_model_names, shortened_options_names, shortened_dataset_names, ood_mean_table, ood_std_table, is_last_avg=True)
    return id_mean_table, ood_mean_table
    

# Big table, Living17, Waterbirds, DomainNet, Camelyon, FMoW (Early stopped)


In [None]:
# Big table, Living17, Waterbirds, DomainNet, Camelyon, FMoW (Early stopped)
models = ['clip_vit_b16', 'clip_vit_l14', 'timm_vit_b16_in21k', 'dino_vit_b16',  'convnext_vit_b', 'bit_resnet_50_in21k', 'bit_resnet_101_in21k']
options = ['', 'opt_torch.optim.AdamW_', 'freeze_bottom_2_']
datasets = ['living17_nonorm', 'waterbirds', 'domainnet', 'fmow_all_nonorm_weakaugs', 'camelyon17_weakaugs']
output_metrics_list = [['test_acc/source_val_living', 'test_acc/target_val_living'], ['WATERBIRDS_VAL', 'WORST'], ['test_acc/sketch_val', 'test_acc/real_val'], ['test_acc/id_val', 'test_acc/ood_val', 'test_acc/ood_test', 'test_acc/africa_test'], ['test_acc/id_val', 'test_acc/ood_val', 'test_acc/ood_test']]
val_metric_list = ['test_acc/source_val_living', 'WATERBIRDS_VAL', 'test_acc/sketch_val', 'test_acc/id_val', 'test_acc/id_val']
aggregation_metric_list = ['test_acc/source_val_living', 'WATERBIRDS_VAL', 'test_acc/sketch_val', 'test_acc/id_val', 'test_acc/id_val']
desired_epochs_list = [20, 20, 50, 5, 3]
shorted_model_names = ['CLIP ViT-B/16', 'CLIP ViT-L/14', 'Sup ViT-B/16', 'DINO ViT-B/16', 'ConvNext-Base', 'BiT ResNet-50', 'BiT ResNet-101', ]
shortened_options_names = ['SGD', 'AdamW', 'SGD (freeze-embed)']
id_output_metrics_list = [['test_acc/source_val_living'], ['WATERBIRDS_VAL'], ['test_acc/sketch_val'], ['test_acc/id_val'], ['test_acc/id_val']]
ood_output_metrics_list = [['test_acc/target_val_living'], ['WORST'], ['test_acc/real_val'], ['test_acc/africa_test'], ['test_acc/ood_test']]
shortened_dataset_names = ['Living-17', 'Waterbirds', 'DomainNet', "FMoW", "Camelyon", "Avg."]


mean_table, std_table = get_table(log_dir, name_template, models, options, datasets, output_metrics_list, val_metric_list, aggregation_metric_list, desired_epochs_list=desired_epochs_list)
# pickle.dump( (mean_table, std_table), open( "big_table_main.pkl", "wb" ) )
# mean_table_old_order, std_table_old_order = pickle.load( open( "big_table_main.pkl", "rb" ) )

# Rearranging mean table
mean_table = mean_table_old_order[[0,6,1,2,5,3,4],:,:]
std_table = std_table_old_order[[0,6,1,2,5,3,4],:,:]

id_mean_table, ood_mean_table = display_id_ood_tables(
    mean_table, std_table, id_output_metrics_list, ood_output_metrics_list, shortened_dataset_names, no_std=True)

# display_tsv_table(shorted_model_names, shortened_options_names, shortened_output_metrics_list, mean_table, std_table)

In [None]:
# Get average ID and OOD results for the 3 methods.
print(np.mean(id_mean_table[:,:,5], axis=0))
print(np.mean(ood_mean_table[:,:,5], axis=0))

# ConvNeXt results including just freeze stem

In [3]:
# # ConvNeXt, Living17, Waterbirds, DomainNet, Camelyon, FMoW (Early stopped), including freeze-stem
# models = ['convnext_vit_b']
# options = ['', 'opt_torch.optim.AdamW_', 'freeze_bottom_2_', 'freeze_bottom_1_']
# datasets = ['living17_nonorm', 'waterbirds', 'domainnet', 'fmow_all_nonorm_weakaugs', 'camelyon17_weakaugs']
# output_metrics_list = [['test_acc/source_val_living', 'test_acc/target_val_living'], ['WATERBIRDS_VAL', 'WORST'], ['test_acc/sketch_val', 'test_acc/real_val'], ['test_acc/id_val', 'test_acc/ood_val', 'test_acc/ood_test', 'test_acc/africa_test'], ['test_acc/id_val', 'test_acc/ood_val', 'test_acc/ood_test']]
# val_metric_list = ['test_acc/source_val_living', 'WATERBIRDS_VAL', 'test_acc/sketch_val', 'test_acc/id_val', 'test_acc/id_val']
# aggregation_metric_list = ['test_acc/source_val_living', 'WATERBIRDS_VAL', 'test_acc/sketch_val', 'test_acc/id_val', 'test_acc/id_val']
# desired_epochs_list = [20, 20, 50, 5, 3]
# shorted_model_names = ['CLIP ViT-B/16', 'CLIP ViT-L/14', 'Sup ViT-B/16', 'DINO ViT-B/16', 'ConvNext-Base', 'BiT ResNet-50', 'BiT ResNet-101', ]
# shortened_options_names = ['SGD', 'AdamW', 'SGD (freeze-stem-block-1)', 'SGD (freeze-stem)']
# id_output_metrics_list = [['test_acc/source_val_living'], ['WATERBIRDS_VAL'], ['test_acc/sketch_val'], ['test_acc/id_val'], ['test_acc/id_val']]
# ood_output_metrics_list = [['test_acc/target_val_living'], ['WORST'], ['test_acc/real_val'], ['test_acc/africa_test'], ['test_acc/ood_test']]
# shortened_dataset_names = ['Living-17', 'Waterbirds', 'DomainNet', "FMoW", "Camelyon", "Avg."]


# mean_table, std_table = get_table(log_dir, name_template, models, options, datasets, output_metrics_list, val_metric_list, aggregation_metric_list, desired_epochs_list=desired_epochs_list)
# # pickle.dump( (mean_table, std_table), open( "big_table_main.pkl", "wb" ) )
# # mean_table_old_order, std_table_old_order = pickle.load( open( "big_table_main.pkl", "rb" ) )

# id_mean_table, ood_mean_table = display_id_ood_tables(
#     mean_table, std_table, id_output_metrics_list, ood_output_metrics_list, shortened_dataset_names, no_std=True)

# # display_tsv_table(shorted_model_names, shortened_options_names, shortened_output_metrics_list, mean_table, std_table)

# The following table is without DomainNet.
models = ['convnext_vit_b']
options = ['', 'opt_torch.optim.AdamW_', 'freeze_bottom_2_', 'freeze_bottom_1_']
datasets = ['living17_nonorm', 'waterbirds', 'fmow_all_nonorm_weakaugs', 'camelyon17_weakaugs']
output_metrics_list = [['test_acc/source_val_living', 'test_acc/target_val_living'], ['WATERBIRDS_VAL', 'WORST'], ['test_acc/id_val', 'test_acc/ood_val', 'test_acc/ood_test', 'test_acc/africa_test'], ['test_acc/id_val', 'test_acc/ood_val', 'test_acc/ood_test']]
val_metric_list = ['test_acc/source_val_living', 'WATERBIRDS_VAL', 'test_acc/id_val', 'test_acc/id_val']
aggregation_metric_list = ['test_acc/source_val_living', 'WATERBIRDS_VAL', 'test_acc/id_val', 'test_acc/id_val']
desired_epochs_list = [20, 20, 5, 3]
shorted_model_names = ['ConvNext-Base']
shortened_options_names = ['SGD', 'AdamW', 'SGD (freeze-stem-block-1)', 'SGD (freeze-stem)']
id_output_metrics_list = [['test_acc/source_val_living'], ['WATERBIRDS_VAL'], ['test_acc/id_val'], ['test_acc/id_val']]
ood_output_metrics_list = [['test_acc/target_val_living'], ['WORST'], ['test_acc/africa_test'], ['test_acc/ood_test']]
shortened_dataset_names = ['Living-17', 'Waterbirds', "FMoW", "Camelyon", "Avg."]


mean_table, std_table = get_table(log_dir, name_template, models, options, datasets, output_metrics_list, val_metric_list, aggregation_metric_list, desired_epochs_list=desired_epochs_list)
# pickle.dump( (mean_table, std_table), open( "big_table_main.pkl", "wb" ) )
# mean_table_old_order, std_table_old_order = pickle.load( open( "big_table_main.pkl", "rb" ) )

print(mean_table.shape)
id_mean_table, ood_mean_table = display_id_ood_tables(
    mean_table, std_table, id_output_metrics_list, ood_output_metrics_list, shortened_dataset_names, no_std=True)

# display_tsv_table(shorted_model_names, shortened_options_names, shortened_output_metrics_list, mean_table, std_table)

(1, 4, 11)


ID Accuracies
		Living-17	Waterbirds	FMoW	Camelyon	Avg.
ConvNext-Base	SGD	98.7	99.0	66.3	99.4	90.9
ConvNext-Base	AdamW	98.6	99.5	68.8	99.7	91.7
ConvNext-Base	SGD (freeze-stem-block-1)	98.6	99.4	67.4	99.5	91.2
ConvNext-Base	SGD (freeze-stem)	98.8	99.3	67.2	99.5	91.2

\begin{tabular}{ccccccccc|cc}
\toprule
 & \multicolumn{2}{c}{Living-17}  & \multicolumn{2}{c}{Waterbirds}  & \multicolumn{2}{c}{FMoW}  & \multicolumn{2}{c}{Camelyon}  & \multicolumn{2}{c}{Avg.} \\
\midrule
SGD & \textbf{98.7} &  & 99.0 &  & 66.3 &  & 99.4 &  & 90.9 & \\
AdamW & \textbf{98.6} & \hspace{-1em}{\lred{(-0.1)}} & \textbf{99.5} & \hspace{-1em}{\hgreen{(+0.5)}} & \textbf{68.8} & \hspace{-1em}{\hgreen{(+2.5)}} & \textbf{99.7} & \hspace{-1em}{\hgreen{(+0.3)}} & \textbf{91.7} & \hspace{-1em}{\hgreen{(+0.8)}}\\
SGD (freeze-stem-block-1) & \textbf{98.6} & \hspace{-1em}{\lred{(-0.1)}} & \textbf{99.4} & \hspace{-1em}{\hgreen{(+0.4)}} & 67.4 & \hspace{-1em}{\hgreen{(+1.1)}} & \textbf{99.5} & \hspace{-1em}{\lgr

# CLIP results including wilds datasets

In [35]:

print("CLIP Results, all datasets")
models = ['clip_vit_b16']
options = ['', 'opt_torch.optim.AdamW_', 'freeze_bottom_2_', 'layer_wise_tune__', 'opt_torch_optimizer.Lamb_', 'opt_torch_optimizer.LARS_']
val_metric_list = ['test_acc/source_val_living', 'WATERBIRDS_VAL', 'test_acc/sketch_val', 'test_acc/id_val', 'test_acc/id_val']
datasets = ['living17_nonorm', 'waterbirds', 'domainnet', 'fmow_all_nonorm_weakaugs', 'camelyon17_weakaugs']
output_metrics_list = [['test_acc/source_val_living', 'test_acc/target_val_living'], ['WATERBIRDS_VAL', 'WORST'], ['test_acc/sketch_val', 'test_acc/real_val'], ['test_acc/id_val', 'test_acc/ood_val', 'test_acc/ood_test', 'test_acc/africa_test'], ['test_acc/id_val', 'test_acc/ood_val', 'test_acc/ood_test']]
aggregation_metric_list = ['test_acc/source_val_living', 'WATERBIRDS_VAL', 'test_acc/sketch_val', 'test_acc/id_val', 'test_acc/id_val']
desired_epochs_list = [20, 20, 50, 5, 3]

mean_table, std_table = get_table(log_dir, name_template, models, options, datasets, output_metrics_list, val_metric_list, aggregation_metric_list, desired_epochs_list=desired_epochs_list)
# pickle.dump( (mean_table, std_table), open( "clip_table_main.pkl", "wb" ) )
# mean_table, std_table = pickle.load( open( "clip_table_main.pkl", "rb" ) )

shorted_model_names = ['CLIP ViT-B/16']
shortened_options_names = ['SGD', 'AdamW', 'SGD (Freeze-embed)', 'Layer-wise', 'LAMB', 'LARS']
shortened_output_metrics_list = ['Living-17 ID', 'Living-17 OOD', 'Waterbirds ID', 'Waterbirds OOD', 'DomainNet ID', 'DomainNet OOD', "FMoW ID", "FMoW OOD Val", "FMoW OOD Test", "FMoW Africa", "Camelyon ID", "Camelyon OOD Val", "Camelyon OOD Test"]

id_output_metrics_list = [['test_acc/source_val_living'], ['WATERBIRDS_VAL'], ['test_acc/sketch_val'], ['test_acc/id_val'], ['test_acc/id_val']]
ood_output_metrics_list = [['test_acc/target_val_living'], ['WORST'], ['test_acc/real_val'], ['test_acc/africa_test'], ['test_acc/ood_test']]
shortened_dataset_names = ['Liv-17', 'Waterbirds', 'DomainNet', "FMoW", "Camelyon", "Avg."]

display_id_ood_tables(mean_table, std_table, id_output_metrics_list, ood_output_metrics_list, shortened_dataset_names)


CLIP Results, all datasets
living17_nonorm                                     name  test_acc/source_val_living  \
0   optimizer.args.lr-0.0001_seed-0_run0                     97.8235   
1   optimizer.args.lr-0.0001_seed-1_run1                     97.6471   
2   optimizer.args.lr-0.0001_seed-2_run2                     98.0000   
3   optimizer.args.lr-0.0003_seed-0_run0                     97.7059   
4   optimizer.args.lr-0.0003_seed-1_run1                     97.1765   
5   optimizer.args.lr-0.0003_seed-2_run2                     97.2353   
6    optimizer.args.lr-0.001_seed-0_run0                     96.8824   
7    optimizer.args.lr-0.001_seed-1_run1                     95.1765   
8    optimizer.args.lr-0.001_seed-2_run2                     95.5882   
9    optimizer.args.lr-0.003_seed-0_run0                     94.0588   
10   optimizer.args.lr-0.003_seed-1_run1                     73.0588   
11   optimizer.args.lr-0.003_seed-2_run2                     74.8235   
12    optimizer.args.

fmow_all_nonorm_weakaugs                                     name  test_acc/id_val  test_acc/ood_val  \
0   optimizer.args.lr-0.0001_seed-0_run0          66.9337           65.3678   
1   optimizer.args.lr-0.0001_seed-1_run1          65.0527           63.6103   
2   optimizer.args.lr-0.0001_seed-2_run2          66.1413           64.9561   
3   optimizer.args.lr-0.0003_seed-0_run0          67.8830           65.5084   
4   optimizer.args.lr-0.0003_seed-1_run1          66.2632           64.6950   
5   optimizer.args.lr-0.0003_seed-2_run2          66.7683           64.8104   
6    optimizer.args.lr-0.001_seed-0_run0          63.3110           60.7432   
7    optimizer.args.lr-0.001_seed-1_run1          63.6506           61.1097   
8    optimizer.args.lr-0.001_seed-2_run2          62.5098           60.7381   
9    optimizer.args.lr-0.003_seed-0_run0          53.3049           50.9114   
10   optimizer.args.lr-0.003_seed-1_run1          53.9929           51.6646   
11   optimizer.args.lr-0.00

waterbirds                                     name  WATERBIRDS_VAL    WORST  \
0   optimizer.args.lr-0.0001_seed-0_run0         89.5521   2.8037   
1   optimizer.args.lr-0.0001_seed-1_run1         89.9072   6.2305   
2   optimizer.args.lr-0.0001_seed-2_run2         88.8488   5.4517   
3    optimizer.args.lr-1e-05_seed-0_run0         97.9436  69.3146   
4    optimizer.args.lr-1e-05_seed-1_run1         97.4628  66.6667   
5    optimizer.args.lr-1e-05_seed-2_run2         97.6901  68.5358   
6    optimizer.args.lr-1e-06_seed-0_run0         97.3885  67.2897   
7    optimizer.args.lr-1e-06_seed-1_run1         97.3596  70.7165   
8    optimizer.args.lr-1e-06_seed-2_run2         97.6066  64.1745   
9    optimizer.args.lr-3e-05_seed-0_run0         97.0159  53.5826   
10   optimizer.args.lr-3e-05_seed-1_run1         96.7293  43.6137   
11   optimizer.args.lr-3e-05_seed-2_run2         96.9287  48.5981   
12   optimizer.args.lr-3e-06_seed-0_run0         97.7807  69.3146   
13   optimizer.args.lr-

camelyon17_weakaugs                                     name  test_acc/id_val  test_acc/ood_val  \
0   optimizer.args.lr-0.0001_seed-0_run0          98.4237           78.4065   
1   optimizer.args.lr-0.0001_seed-1_run1          98.3045           79.9851   
2   optimizer.args.lr-0.0001_seed-2_run2          97.5387           76.1517   
3    optimizer.args.lr-1e-05_seed-0_run0          99.4458           91.2102   
4    optimizer.args.lr-1e-05_seed-1_run1          99.4160           91.2847   
5    optimizer.args.lr-1e-05_seed-2_run2          99.4190           89.3594   
6    optimizer.args.lr-1e-06_seed-0_run0          99.4249           94.7886   
7    optimizer.args.lr-1e-06_seed-1_run1          99.5143           93.6110   
8    optimizer.args.lr-1e-06_seed-2_run2          99.5083           94.2757   
9    optimizer.args.lr-3e-05_seed-0_run0          99.1359           87.5659   
10   optimizer.args.lr-3e-05_seed-1_run1          98.9988           84.8413   
11   optimizer.args.lr-3e-05_see

domainnet                                                  name  test_acc/sketch_val  \
0   freeze_bottom_k-2_optimizer.args.lr-0.0001_see...              95.1647   
1   freeze_bottom_k-2_optimizer.args.lr-0.0001_see...              94.6228   
2   freeze_bottom_k-2_optimizer.args.lr-0.0001_see...              94.9562   
3   freeze_bottom_k-2_optimizer.args.lr-0.0003_see...              94.4977   
4   freeze_bottom_k-2_optimizer.args.lr-0.0003_see...              94.5394   
5   freeze_bottom_k-2_optimizer.args.lr-0.0003_see...              94.4977   
6   freeze_bottom_k-2_optimizer.args.lr-0.001_seed...              90.7461   
7   freeze_bottom_k-2_optimizer.args.lr-0.001_seed...              90.3710   
8   freeze_bottom_k-2_optimizer.args.lr-0.001_seed...              92.3718   
9   freeze_bottom_k-2_optimizer.args.lr-0.003_seed...              84.1601   
10  freeze_bottom_k-2_optimizer.args.lr-0.003_seed...              85.5356   
11  freeze_bottom_k-2_optimizer.args.lr-0.003_seed... 

camelyon17_weakaugs                                                  name  test_acc/id_val  \
0   freeze_bottom_k-2_optimizer.args.lr-0.0001_see...          99.5143   
1   freeze_bottom_k-2_optimizer.args.lr-0.0001_see...          99.5083   
2   freeze_bottom_k-2_optimizer.args.lr-0.0001_see...          99.5203   
3   freeze_bottom_k-2_optimizer.args.lr-0.0003_see...          99.5322   
4   freeze_bottom_k-2_optimizer.args.lr-0.0003_see...          99.5173   
5   freeze_bottom_k-2_optimizer.args.lr-0.0003_see...          99.5292   
6   freeze_bottom_k-2_optimizer.args.lr-0.001_seed...          99.4458   
7   freeze_bottom_k-2_optimizer.args.lr-0.001_seed...          99.3802   
8   freeze_bottom_k-2_optimizer.args.lr-0.001_seed...          99.4219   
9   freeze_bottom_k-2_optimizer.args.lr-0.003_seed...          99.2223   
10  freeze_bottom_k-2_optimizer.args.lr-0.003_seed...          97.6698   
11  freeze_bottom_k-2_optimizer.args.lr-0.003_seed...          98.6532   
12  freeze_bottom_

waterbirds                                                  name  WATERBIRDS_VAL  \
0   layer-wise-tune-True_optimizer.args.lr-0.0001_...         98.2763   
1   layer-wise-tune-True_optimizer.args.lr-0.0001_...         98.2581   
2   layer-wise-tune-True_optimizer.args.lr-0.0001_...         98.3464   
3   layer-wise-tune-True_optimizer.args.lr-1e-05_s...         98.0542   
4   layer-wise-tune-True_optimizer.args.lr-1e-05_s...         97.8983   
5   layer-wise-tune-True_optimizer.args.lr-1e-05_s...         97.9775   
6   layer-wise-tune-True_optimizer.args.lr-1e-06_s...         97.3154   
7   layer-wise-tune-True_optimizer.args.lr-1e-06_s...         97.3131   
8   layer-wise-tune-True_optimizer.args.lr-1e-06_s...         97.2459   
9   layer-wise-tune-True_optimizer.args.lr-3e-05_s...         98.2068   
10  layer-wise-tune-True_optimizer.args.lr-3e-05_s...         98.1980   
11  layer-wise-tune-True_optimizer.args.lr-3e-05_s...         98.0882   
12  layer-wise-tune-True_optimizer.args.

camelyon17_weakaugs                                                  name  test_acc/id_val  \
0   layer-wise-tune-True_optimizer.args.lr-0.0001_...          99.1597   
1   layer-wise-tune-True_optimizer.args.lr-0.0001_...          99.0256   
2   layer-wise-tune-True_optimizer.args.lr-0.0001_...          98.7217   
3   layer-wise-tune-True_optimizer.args.lr-1e-05_s...          99.2610   
4   layer-wise-tune-True_optimizer.args.lr-1e-05_s...          99.1329   
5   layer-wise-tune-True_optimizer.args.lr-1e-05_s...          99.3325   
6   layer-wise-tune-True_optimizer.args.lr-1e-06_s...          98.6442   
7   layer-wise-tune-True_optimizer.args.lr-1e-06_s...          98.6591   
8   layer-wise-tune-True_optimizer.args.lr-1e-06_s...          98.7455   
9   layer-wise-tune-True_optimizer.args.lr-3e-05_s...          99.3355   
10  layer-wise-tune-True_optimizer.args.lr-3e-05_s...          99.1895   
11  layer-wise-tune-True_optimizer.args.lr-3e-05_s...          99.2819   
12  layer-wise-tun

fmow_all_nonorm_weakaugs                                    name  test_acc/id_val  test_acc/ood_val  \
0  optimizer.args.lr-0.0001_seed-0_run0          67.9178           66.1863   
1   optimizer.args.lr-1e-05_seed-0_run0          47.3918           47.9538   
2   optimizer.args.lr-1e-06_seed-0_run0          15.7624           17.8509   
3   optimizer.args.lr-3e-05_seed-0_run0          62.4924           61.0244   
4   optimizer.args.lr-3e-06_seed-0_run0          26.6394           28.8024   
5   optimizer.args.lr-3e-07_seed-0_run0          11.1382           11.9407   

   test_acc/ood_test  test_acc/africa_test  \
0            59.2365               38.7582   
1            43.8484               32.2792   
2            12.8686                7.2503   
3            55.2108               37.6012   
4            23.6928               15.9661   
5             9.2093                1.9283   

                                           wandb_url                     group  
0  https://wandb.ai/p-la

camelyon17_weakaugs                                    name  test_acc/id_val  test_acc/ood_val  \
0  optimizer.args.lr-0.0001_seed-0_run0          99.2193           85.9844   
1   optimizer.args.lr-1e-05_seed-0_run0          99.1806           94.0694   
2   optimizer.args.lr-1e-06_seed-0_run0          98.4118           89.9353   
3   optimizer.args.lr-3e-05_seed-0_run0          99.3445           94.1124   
4   optimizer.args.lr-3e-06_seed-0_run0          98.9273           91.0411   
5   optimizer.args.lr-3e-07_seed-0_run0          96.9249           91.2073   

   test_acc/ood_test                                          wandb_url  \
0            75.9976  https://wandb.ai/p-lambda/finetuning/runs/2hlj...   
1            94.2683  https://wandb.ai/p-lambda/finetuning/runs/uaha...   
2            89.2750  https://wandb.ai/p-lambda/finetuning/runs/1lmq...   
3            93.2843  https://wandb.ai/p-lambda/finetuning/runs/2km8...   
4            91.7335  https://wandb.ai/p-lambda/finetuning

(array([[[97.82353333, 97.22233333, 88.8009    , 66.9715    ,
          99.3802    , 90.03969333],
         [98.09803333, 97.7472    , 94.9979    , 70.11816667,
          99.5451    , 92.10128   ],
         [98.15683333, 97.8062    , 94.91456667, 70.00203333,
          99.52623333, 92.08117333],
         [98.31373333, 98.2936    , 96.30403333, 69.19793333,
          99.26896667, 92.27565333],
         [98.1765    , 97.7633    , 95.123     , 67.9178    ,
          99.5232    , 91.70076   ],
         [97.7059    , 97.1196    , 93.2472    , 66.986     ,
          99.3445    , 90.88064   ]]]),
 array([[[79.9608    , 62.46103333, 72.77833333, 37.34413333,
          86.79546667, 67.86795333],
         [82.80393333, 71.85876667, 89.16893333, 40.66073333,
          95.72586667, 76.04364667],
         [83.17646667, 73.72796667, 88.22793333, 40.18513333,
          94.29106667, 75.92171333],
         [81.90196667, 69.05503333, 93.24016667, 40.46796667,
          96.48143333, 76.22931333],
       

# CLIP Results for ViT-L

In [None]:
# Big table, Living17, Waterbirds, DomainNet, Camelyon, FMoW (Early stopped)
models = ['clip_vit_l14']
options = ['', 'opt_torch.optim.AdamW_', 'freeze_bottom_2_']
datasets = ['living17_nonorm', 'waterbirds', 'domainnet', 'fmow_all_nonorm_weakaugs', 'camelyon17_weakaugs']
output_metrics_list = [['test_acc/source_val_living', 'test_acc/target_val_living'], ['WATERBIRDS_VAL', 'WORST'], ['test_acc/sketch_val', 'test_acc/real_val'], ['test_acc/id_val', 'test_acc/ood_val', 'test_acc/ood_test', 'test_acc/africa_test'], ['test_acc/id_val', 'test_acc/ood_val', 'test_acc/ood_test']]
val_metric_list = ['test_acc/source_val_living', 'WATERBIRDS_VAL', 'test_acc/sketch_val', 'test_acc/id_val', 'test_acc/id_val']
aggregation_metric_list = ['test_acc/source_val_living', 'WATERBIRDS_VAL', 'test_acc/sketch_val', 'test_acc/id_val', 'test_acc/id_val']
desired_epochs_list = [20, 20, 50, 5, 3]
mean_table, std_table = get_table(log_dir, name_template, models, options, datasets, output_metrics_list, val_metric_list, aggregation_metric_list, desired_epochs_list=desired_epochs_list)
shorted_model_names = ['CLIP ViT-B/16', 'Sup ViT-B/16', 'DINO ViT-B/16', 'BIT ResNet-50', 'BIT ResNet-101', 'ConvNext-Base', 'CLIP ViT-L/14']
shortened_options_names = ['SGD', 'AdamW', 'SGD (Freeze-2)']

id_output_metrics_list = [['test_acc/source_val_living'], ['WATERBIRDS_VAL'], ['test_acc/sketch_val'], ['test_acc/id_val'], ['test_acc/id_val']]
ood_output_metrics_list = [['test_acc/target_val_living'], ['WORST'], ['test_acc/real_val'], ['test_acc/africa_test'], ['test_acc/ood_test']]
shortened_dataset_names = ['Liv-17', 'Waterbirds', 'DomainNet', "FMoW", "Camelyon", "Avg."]

id_mean_table, ood_mean_table = display_id_ood_tables(
    mean_table, std_table, id_output_metrics_list, ood_output_metrics_list, shortened_dataset_names, no_std=True)

# display_tsv_table(shorted_model_names, shortened_options_names, shortened_output_metrics_list, mean_table, std_table)

In [None]:
# Get table for dino and clip, freezing different number of layers.
print("CLIP Results, all datasets")
models = ['clip_vit_b16']
options = ['freeze_bottom_2_full_ft_epoch_50_', 'freeze_bottom_5_full_ft_epoch_50_', 'freeze_bottom_8_full_ft_epoch_50_', 'freeze_bottom_11_full_ft_epoch_50_', 'freeze_bottom_14_full_ft_epoch_50_']
val_metric_list = ['WATERBIRDS_VAL']
datasets = ['waterbirds']
output_metrics_list = [['WATERBIRDS_VAL', 'WORST']]
aggregation_metric_list = ['WATERBIRDS_VAL']
# desired_epochs_list = [20, 20, 50, 5, 3]
mean_table, std_table = get_table('../older_logs/', name_template, models, options, datasets, output_metrics_list, val_metric_list, aggregation_metric_list, desired_epochs_list=desired_epochs_list)
shorted_model_names = ['CLIP ViT-B/16']
shortened_options_names = ['SGD', 'AdamW', 'SGD (Freeze-embed)', 'Layer-wise', 'LAMB', 'LARS']
shortened_output_metrics_list = ['Living-17 ID', 'Living-17 OOD', 'Waterbirds ID', 'Waterbirds OOD', 'DomainNet ID', 'DomainNet OOD', "FMoW ID", "FMoW OOD Val", "FMoW OOD Test", "FMoW Africa", "Camelyon ID", "Camelyon OOD Val", "Camelyon OOD Test"]


# Post ICLR Supplementary results

## Big table, including AdamW + Freeze 

In [None]:
# Big table, Living17, Waterbirds, DomainNet, Camelyon, FMoW (Early stopped)
models = ['clip_vit_b16', 'clip_vit_l14', 'timm_vit_b16_in21k', 'dino_vit_b16',  'convnext_vit_b', 'bit_resnet_50_in21k', 'bit_resnet_101_in21k']
options = ['', 'opt_torch.optim.AdamW_', 'freeze_bottom_2_', 'freeze_bottom_2_opt_torch.optim.AdamW_']
datasets = ['living17_nonorm', 'waterbirds', 'domainnet', 'fmow_all_nonorm_weakaugs', 'camelyon17_weakaugs']
output_metrics_list = [['test_acc/source_val_living', 'test_acc/target_val_living'], ['WATERBIRDS_VAL', 'WORST'], ['test_acc/sketch_val', 'test_acc/real_val'], ['test_acc/id_val', 'test_acc/ood_val', 'test_acc/ood_test', 'test_acc/africa_test'], ['test_acc/id_val', 'test_acc/ood_val', 'test_acc/ood_test']]
val_metric_list = ['test_acc/source_val_living', 'WATERBIRDS_VAL', 'test_acc/sketch_val', 'test_acc/id_val', 'test_acc/id_val']
aggregation_metric_list = ['test_acc/source_val_living', 'WATERBIRDS_VAL', 'test_acc/sketch_val', 'test_acc/id_val', 'test_acc/id_val']
desired_epochs_list = [20, 20, 50, 5, 3]

# mean_table_3, std_table_3 = pickle.load( open( "big_table_main.pkl", "rb" ) )
# mean_table = np.concatenate((mean_table_3, mean_table_3[:,1,None,:]), axis=1)
# std_table = np.concatenate((std_table_3, std_table_3[:,1,None,:]), axis=1)

mean_table, std_table = get_table(log_dir, name_template, models, options, datasets, output_metrics_list, val_metric_list, aggregation_metric_list, desired_epochs_list=desired_epochs_list)
shorted_model_names = ['CLIP ViT-B/16', 'CLIP ViT-L/14', 'Sup ViT-B/16', 'DINO ViT-B/16', 'ConvNext-Base', 'BiT ResNet-50', 'BiT ResNet-101']
shortened_options_names = ['SGD', 'AdamW', 'SGD (freeze-embed)', 'AdamW (freeze-embed)']

id_output_metrics_list = [['test_acc/source_val_living'], ['WATERBIRDS_VAL'], ['test_acc/sketch_val'], ['test_acc/id_val'], ['test_acc/id_val']]
ood_output_metrics_list = [['test_acc/target_val_living'], ['WORST'], ['test_acc/real_val'], ['test_acc/africa_test'], ['test_acc/ood_test']]
shortened_dataset_names = ['Living-17', 'Waterbirds', 'DomainNet', "FMoW", "Camelyon", "Avg."]


id_mean_table, ood_mean_table = display_id_ood_tables(
    mean_table, std_table, id_output_metrics_list, ood_output_metrics_list, shortened_dataset_names, no_std=True)

# display_tsv_table(shorted_model_names, shortened_options_names, shortened_output_metrics_list, mean_table, std_table)

In [None]:
# Get average ID and OOD results for the 3 methods.
print(np.mean(id_mean_table[:,:,5], axis=0))
print(np.mean(ood_mean_table[:,:,5], axis=0))

## ResNet-50 all results 

In [None]:
# Big table, Living17, Waterbirds, DomainNet, Camelyon, FMoW (Early stopped)
models = ['sup_resnet50']
options = ['', 'opt_torch.optim.AdamW_', 'freeze_bottom_2_', 'freeze_bottom_2_opt_torch.optim.AdamW_']
datasets = ['living17_nonorm', 'waterbirds', 'domainnet', 'fmow_all_nonorm_weakaugs', 'camelyon17_weakaugs']
output_metrics_list = [['test_acc/source_val_living', 'test_acc/target_val_living'], ['WATERBIRDS_VAL', 'WORST'], ['test_acc/sketch_val', 'test_acc/real_val'], ['test_acc/id_val', 'test_acc/ood_val', 'test_acc/ood_test', 'test_acc/africa_test'], ['test_acc/id_val', 'test_acc/ood_val', 'test_acc/ood_test']]
val_metric_list = ['test_acc/source_val_living', 'WATERBIRDS_VAL', 'test_acc/sketch_val', 'test_acc/id_val', 'test_acc/id_val']
aggregation_metric_list = ['test_acc/source_val_living', 'WATERBIRDS_VAL', 'test_acc/sketch_val', 'test_acc/id_val', 'test_acc/id_val']
desired_epochs_list = [20, 20, 50, 5, 3]
mean_table, std_table = get_table(log_dir, name_template, models, options, datasets, output_metrics_list, val_metric_list, aggregation_metric_list, desired_epochs_list=desired_epochs_list)
shorted_model_names = ['ResNet50']
shortened_options_names = ['SGD', 'AdamW', 'SGD (Freeze-2)', 'AdamW (Freeze-2)']

id_output_metrics_list = [['test_acc/source_val_living'], ['WATERBIRDS_VAL'], ['test_acc/sketch_val'], ['test_acc/id_val'], ['test_acc/id_val']]
ood_output_metrics_list = [['test_acc/target_val_living'], ['WORST'], ['test_acc/real_val'], ['test_acc/africa_test'], ['test_acc/ood_test']]
shortened_dataset_names = ['Liv-17', 'Waterbirds', 'DomainNet', "FMoW", "Camelyon", "Avg."]

id_mean_table, ood_mean_table = display_id_ood_tables(
    mean_table, std_table, id_output_metrics_list, ood_output_metrics_list, shortened_dataset_names, no_std=True)

# display_tsv_table(shorted_model_names, shortened_options_names, shortened_output_metrics_list, mean_table, std_table)

## CLIP weight decay, no momentum, freeze only embed 

In [42]:

print("CLIP Results, all datasets")
models = ['clip_vit_b16']
# options = ['', 'opt_torch.optim.AdamW_', 'freeze_bottom_2_', 'layer_wise_tune__', 'opt_torch_optimizer.Lamb_', 'opt_torch_optimizer.LARS_', 'freeze_bottom_1_', 'optimizer.args.momentum-0.0_', 'freeze_bottom_2_optimizer.args.momentum-0.0_', 'optimizer.args.weight_decay-0.01_']
options = ['', 'opt_torch.optim.AdamW_', 'freeze_bottom_2_', 'layer_wise_tune__', 'opt_torch_optimizer.Lamb_', 'opt_torch_optimizer.LARS_', 'freeze_bottom_1_', 'optimizer.args.momentum-0.0_', 'optimizer.args.weight_decay-0.01_', 'embed_layer_lr_multiplier-0.2_']
val_metric_list = ['test_acc/source_val_living', 'WATERBIRDS_VAL', 'test_acc/sketch_val', 'test_acc/id_val', 'test_acc/id_val']
datasets = ['living17_nonorm', 'waterbirds', 'domainnet', 'fmow_all_nonorm_weakaugs', 'camelyon17_weakaugs']
output_metrics_list = [['test_acc/source_val_living', 'test_acc/target_val_living'], ['WATERBIRDS_VAL', 'WORST'], ['test_acc/sketch_val', 'test_acc/real_val'], ['test_acc/id_val', 'test_acc/ood_val', 'test_acc/ood_test', 'test_acc/africa_test'], ['test_acc/id_val', 'test_acc/ood_val', 'test_acc/ood_test']]
aggregation_metric_list = ['test_acc/source_val_living', 'WATERBIRDS_VAL', 'test_acc/sketch_val', 'test_acc/id_val', 'test_acc/id_val']
desired_epochs_list = [20, 20, 50, 5, 3]

mean_table, std_table = get_table(log_dir, name_template, models, options, datasets, output_metrics_list, val_metric_list, aggregation_metric_list, desired_epochs_list=desired_epochs_list)
# pickle.dump( (mean_table, std_table), open( "clip_table_extended.pkl", "wb" ) )
# mean_table, std_table = pickle.load( open( "clip_table_extended.pkl", "rb" ) )

shorted_model_names = ['CLIP ViT-B/16']
# shortened_options_names = ['SGD', 'AdamW', 'SGD (Freeze-embed)', 'Layer-wise', 'LAMB', 'LARS', 'Freeze only embed', 'SGD (no momentum)', 'SGD (Freeze, no momentum)', 'SGD (weight decay)']
shortened_options_names = ['SGD', 'AdamW', 'SGD (Freeze-embed)', 'Layer-wise', 'LAMB', 'LARS', 'SGD (Freeze-embed, not layer-norm)', 'SGD (no momentum)', 'SGD (weight decay)', 'SGD (lower LR, embed layer)']
shortened_output_metrics_list = ['Living-17 ID', 'Living-17 OOD', 'Waterbirds ID', 'Waterbirds OOD', 'DomainNet ID', 'DomainNet OOD', "FMoW ID", "FMoW OOD Val", "FMoW OOD Test", "FMoW Africa", "Camelyon ID", "Camelyon OOD Val", "Camelyon OOD Test"]

id_output_metrics_list = [['test_acc/source_val_living'], ['WATERBIRDS_VAL'], ['test_acc/sketch_val'], ['test_acc/id_val'], ['test_acc/id_val']]
ood_output_metrics_list = [['test_acc/target_val_living'], ['WORST'], ['test_acc/real_val'], ['test_acc/africa_test'], ['test_acc/ood_test']]
shortened_dataset_names = ['Liv-17', 'Waterbirds', 'DomainNet', "FMoW", "Camelyon", "Avg."]

display_id_ood_tables(mean_table, std_table, id_output_metrics_list, ood_output_metrics_list, shortened_dataset_names, no_std=True)


CLIP Results, all datasets
not all jobs ran:  clip_vit_b16 opt_torch_optimizer.Lamb_ living17_nonorm 8
not all jobs ran:  clip_vit_b16 opt_torch_optimizer.Lamb_ waterbirds 8
not all jobs ran:  clip_vit_b16 opt_torch_optimizer.Lamb_ domainnet 8
not all jobs ran:  clip_vit_b16 opt_torch_optimizer.LARS_ living17_nonorm 8
not all jobs ran:  clip_vit_b16 opt_torch_optimizer.LARS_ waterbirds 8
not all jobs ran:  clip_vit_b16 opt_torch_optimizer.LARS_ domainnet 8


ID Accuracies
		Liv-17	Waterbirds	DomainNet	FMoW	Camelyon	Avg.
CLIP ViT-B/16	SGD	97.8	97.2	88.8	67.0	99.4	90.0
CLIP ViT-B/16	AdamW	98.1	97.7	95.0	70.1	99.5	92.1
CLIP ViT-B/16	SGD (Freeze-embed)	98.2	97.8	94.9	70.0	99.5	92.1
CLIP ViT-B/16	Layer-wise	98.3	98.3	96.3	69.2	99.3	92.3
CLIP ViT-B/16	LAMB	98.2	97.8	95.1	67.9	99.5	91.7
CLIP ViT-B/16	LARS	97.7	97.1	93.2	67.0	99.3	90.9
CLIP ViT-B/16	SGD (Freeze-embed, not layer-norm)	98.0	98.0	95.4	70.2	99.5	92.2
CLIP ViT-B/16	SGD (no momentum)	98.0	97.1	89.5	66.4	99.3	90.1
CLIP ViT-B/16	SGD (

(array([[[97.82353333, 97.22233333, 88.8009    , 66.9715    ,
          99.3802    , 90.03969333],
         [98.09803333, 97.7472    , 94.9979    , 70.11816667,
          99.5451    , 92.10128   ],
         [98.15683333, 97.8062    , 94.91456667, 70.00203333,
          99.52623333, 92.08117333],
         [98.31373333, 98.2936    , 96.30403333, 69.19793333,
          99.26896667, 92.27565333],
         [98.1765    , 97.7633    , 95.123     , 67.9178    ,
          99.5232    , 91.70076   ],
         [97.7059    , 97.1196    , 93.2472    , 66.986     ,
          99.3445    , 90.88064   ],
         [98.        , 98.0037    , 95.4148    , 70.1646    ,
          99.5262    , 92.22186   ],
         [98.        , 97.071     , 89.5373    , 66.3764    ,
          99.3236    , 90.06166   ],
         [97.6471    , 97.2224    , 87.8699    , 66.3503    ,
          99.2938    , 89.6767    ],
         [98.        , 97.5392    , 94.7895    , 68.6841    ,
          99.5113    , 91.70482   ]]]),
 array(

# Big table with no momentum

In [45]:
# Big table, Living17, Waterbirds, DomainNet, Camelyon, FMoW (Early stopped)
models = ['clip_vit_b16', 'clip_vit_l14', 'timm_vit_b16_in21k', 'dino_vit_b16',  'convnext_vit_b', 'bit_resnet_50_in21k', 'bit_resnet_101_in21k']
options = ['', 'opt_torch.optim.AdamW_', 'freeze_bottom_2_', 'freeze_bottom_2_optimizer.args.momentum-0.0_']
datasets = ['living17_nonorm', 'waterbirds', 'domainnet', 'fmow_all_nonorm_weakaugs', 'camelyon17_weakaugs']
output_metrics_list = [['test_acc/source_val_living', 'test_acc/target_val_living'], ['WATERBIRDS_VAL', 'WORST'], ['test_acc/sketch_val', 'test_acc/real_val'], ['test_acc/id_val', 'test_acc/ood_val', 'test_acc/ood_test', 'test_acc/africa_test'], ['test_acc/id_val', 'test_acc/ood_val', 'test_acc/ood_test']]
val_metric_list = ['test_acc/source_val_living', 'WATERBIRDS_VAL', 'test_acc/sketch_val', 'test_acc/id_val', 'test_acc/id_val']
aggregation_metric_list = ['test_acc/source_val_living', 'WATERBIRDS_VAL', 'test_acc/sketch_val', 'test_acc/id_val', 'test_acc/id_val']
desired_epochs_list = [20, 20, 50, 5, 3]
shorted_model_names = ['CLIP ViT-B/16', 'CLIP ViT-L/14', 'Sup ViT-B/16', 'DINO ViT-B/16', 'ConvNext-Base', 'BiT ResNet-50', 'BiT ResNet-101', ]
shortened_options_names = ['SGD', 'AdamW', 'SGD (freeze-embed)', 'SGD (freeze-embed, no momentum)']
id_output_metrics_list = [['test_acc/source_val_living'], ['WATERBIRDS_VAL'], ['test_acc/sketch_val'], ['test_acc/id_val'], ['test_acc/id_val']]
ood_output_metrics_list = [['test_acc/target_val_living'], ['WORST'], ['test_acc/real_val'], ['test_acc/africa_test'], ['test_acc/ood_test']]
shortened_dataset_names = ['Living-17', 'Waterbirds', 'DomainNet', "FMoW", "Camelyon", "Avg."]


mean_table, std_table = get_table(log_dir, name_template, models, options, datasets, output_metrics_list, val_metric_list, aggregation_metric_list, desired_epochs_list=desired_epochs_list)

# Rearranging mean table
mean_table = mean_table_old_order[[0,6,1,2,5,3,4],:,:]
std_table = std_table_old_order[[0,6,1,2,5,3,4],:,:]

id_mean_table, ood_mean_table = display_id_ood_tables(
    mean_table, std_table, id_output_metrics_list, ood_output_metrics_list, shortened_dataset_names, no_std=True)

# display_tsv_table(shorted_model_names, shortened_options_names, shortened_output_metrics_list, mean_table, std_table)

not all jobs ran:  clip_vit_l14 freeze_bottom_2_optimizer.args.momentum-0.0_ domainnet 1
not all jobs ran:  timm_vit_b16_in21k freeze_bottom_2_optimizer.args.momentum-0.0_ domainnet 1
not all jobs ran:  timm_vit_b16_in21k freeze_bottom_2_optimizer.args.momentum-0.0_ fmow_all_nonorm_weakaugs 4
not all jobs ran:  dino_vit_b16 freeze_bottom_2_optimizer.args.momentum-0.0_ domainnet 5
empty dir ../logs//full_ft_freeze_bottom_2_optimizer.args.momentum-0.0_fmow_all_nonorm_weakaugs_dino_vit_b16


TypeError: object of type 'NoneType' has no len()