In [1]:
import numpy as np
import summarize_all_results
import importlib
import math
import pickle
importlib.reload(summarize_all_results)

<module 'summarize_all_results' from '/juice/scr/ananya/cifar_experiments/transfer_learning/scripts/summarize_all_results.py'>

In [18]:
log_dir = '../logs/'
name_template = 'full_ft_{option}{dataset}_{model}'
lp_ft_template = 'lp_then_ft_{option}{dataset}_{model}'

def get_bolds(table, std_err_table=None, eps=0.2):
    is_best = np.zeros((len(table), len(table[0])), dtype=np.int32)
    for col in range(len(table[0])):
        best_idx = np.argmax(table[:,col])
        best_acc = table[best_idx][col]
        if std_err_table is not None and col < len(std_err_table[0]):
            best_std = std_err_table[best_idx][col]
            if math.isnan(best_std):
                best_std = eps
        for row in range(len(table)):
            if std_err_table is not None and col < len(std_err_table[0]):
                cur_std = std_err_table[row][col]
                if math.isnan(cur_std):
                    cur_std = eps
                if np.abs(table[row][col] - best_acc) <= max(best_std, cur_std):
                    is_best[row][col] = 1
            else:
                if table[row][col] >= best_acc - eps:
                    is_best[row][col] = 1
    return is_best

def get_table(log_dir, name_template, models, options, datasets, output_metrics_list, val_metric_list=None,
              aggregation_metric_list=None, num_sweep=6, desired_epochs_list=None):
    # desired_epochs_list is a list of size len(datasets) that stores the number of epochs we are
    # supposed to have run each dataset for. Used to validate if the runs finished.
    mean_table = np.zeros((len(models), len(options), sum([len(ol) for ol in output_metrics_list])))
    std_table = np.zeros((len(models), len(options), sum([len(ol) for ol in output_metrics_list])))
    lr_table = np.zeros((len(models), len(options), len(datasets)))
    for model_idx, model in enumerate(models):
        for option_idx, option in enumerate(options):
            cur_means = []
            cur_stds = []
            for dataset_idx, dataset in enumerate(datasets):
                dir_path = log_dir + '/' + name_template.format(model=model, option=option, dataset=dataset)
                output_metrics = output_metrics_list[dataset_idx]
                if val_metric_list is None:
                    val_metric = 'LAST'
                else:
                    val_metric = val_metric_list[dataset_idx]
                if aggregation_metric_list is None:
                    agg_metric = val_metric
                else:
                    agg_metric = aggregation_metric_list[dataset_idx]
                res, best_row_mean, best_row_std, num_epochs_list = summarize_all_results.get_experiment_aggregate_summary(
                    dir_path, val_metric, output_metrics, aggregation_metric=agg_metric)
                num_jobs = len(res)
                if (desired_epochs_list is not None and
                    desired_epochs_list[dataset_idx] != np.min(num_epochs_list)):
                    bad_indices = np.argwhere(np.array(num_epochs_list) < desired_epochs_list[dataset_idx])
                    bad_runs = [dir_path + '/' + run_folder for run_folder in res['name'][bad_indices[:, 0]]]
                    print('some jobs did not finish: ', model, dataset, option, bad_runs)
                if num_jobs != num_sweep and num_jobs != num_sweep * 3:
                    print('not all jobs ran: ', model, option, dataset, num_jobs)
                cur_means.append(best_row_mean.to_numpy()[0,1:-2])
                cur_stds.append(best_row_std.to_numpy()[0,1:-2])
                lr_table[model_idx][option_idx][dataset_idx] = best_row_mean.to_numpy()[0,0]
            cur_means = np.concatenate(cur_means)
            cur_stds = np.concatenate(cur_stds)
            mean_table[model_idx][option_idx] = cur_means
            std_table[model_idx][option_idx] = cur_stds
    return mean_table, std_table, lr_table

def flatten(list_of_lists):
    x = []
    for l in list_of_lists:
        x.append(l)
    return x

def filter_columns(table, old_output_metrics_list, new_output_metrics_list):
    def get_indices(old_list, new_items):
        return np.array([old_list.index(x) for x in new_items])
    local_indices_list = [get_indices(o, n) for o, n in zip(old_output_metrics_list, new_output_metrics_list)]
#     print(local_indices_list)
    old_lens = [len(o) for o in old_output_metrics_list]
    cum_old_lens = np.concatenate([[0], np.cumsum(old_lens)])[:-1]
#     print(cum_old_lens)
    global_indices = [local_indices + cum_old_len for local_indices, cum_old_len in zip(local_indices_list, cum_old_lens)]
#     print(list(np.concatenate(global_indices)))
    return table[:, :, global_indices][:,:,:,0]


# Add table to make TSV and Latex tables. Add averages.
def display_tsv_table(models, options, shortened_output_metrics_list, table, std_err_table=None):
    first_line = '\t\t' + '\t'.join(shortened_output_metrics_list)
    print(first_line)
    for model_idx, model in enumerate(models):
        for option_idx, option in enumerate(options):
            line = model + '\t' + option
            for idx, val in enumerate(table[model_idx][option_idx]):
                line += '\t' + '{:.1f}'.format(val)
                if (std_err_table is not None and
                    idx < len(std_err_table[model_idx][option_idx])):
                    line += ' ({:.1f})'.format(std_err_table[model_idx][option_idx][idx])
            print(line)
            
def add_average_column(table):
    new_table = []
    for model_idx in range(len(table)):
        model_table = table[model_idx]
        means = np.mean(model_table, axis=1)
        new_model_table = np.concatenate([model_table, np.expand_dims(means, axis=-1)], axis=-1)
        new_table.append(new_model_table)
    return np.array(new_table)
        
def display_latex_table(models, options, shortened_output_metrics_list, table, std_err_table=None, bold_best=True, is_last_avg=True, light_thresh=0.2, show_improvements=True):

    # Print latex table.
    if len(models) == 1:
        num_columns = 2*len(shortened_output_metrics_list) + 1
    else:
        num_columns = 2*len(shortened_output_metrics_list) + 2
    if is_last_avg:
        print('\\begin{tabular}{' + ('c'*(num_columns-2)) + '|cc}')
    else:
        print('\\begin{tabular}{' + ('c'*num_columns) + '}')
    print('\\toprule')
    first_line = '' 
    for oml in shortened_output_metrics_list:
        first_line += ' & \multicolumn{2}{c}{' + oml + '} '
    first_line += '\\\\'
    if len(models) > 1:
        first_line = ' & ' + first_line
    print(first_line)
    for model_idx, model in enumerate(models):
        print('\\midrule')
        if std_err_table is None:
            is_bold = get_bolds(table[model_idx], eps=0.2)
        else:
            is_bold = get_bolds(table[model_idx], std_err_table[model_idx], eps=0.2)
        for option_idx, option in enumerate(options):
            if len(models) > 1:
                line = model + ' & ' + option
            else:
                line = option
            for idx, val in enumerate(table[model_idx][option_idx]):
                line += ' & '
#                 print(is_bold[option_idx][idx])
                if is_bold[option_idx][idx] and bold_best:
                    line += '\\textbf{'
                line += '{:.1f}'.format(val)
                if (std_err_table is not None and
                    idx < len(std_err_table[model_idx][option_idx]) and
                    not math.isnan(std_err_table[model_idx][option_idx][idx])):
                    line += ' ({:.1f})'.format(std_err_table[model_idx][option_idx][idx])
                if is_bold[option_idx][idx] and bold_best:
                    line += '}'
                if show_improvements:
                    if option_idx:# if not sgd
                        sgd_val = table[model_idx][0][idx]
                        diff_fp = val - sgd_val
                        diff = round(val,1)-round(sgd_val,1)
                        if diff_fp>=light_thresh:
                            line += ' & \hspace{{-1em}}{{\hgreen{{(+{:.1f})}}}}'.format(diff)
                        elif diff_fp>=0.0:
                            line += ' & \hspace{{-1em}}{{\lgreen{{(+{:.1f})}}}}'.format(diff)
                        elif diff>=-1*light_thresh:
                            line += ' & \hspace{{-1em}}{{\lred{{(-{:.1f})}}}}'.format(abs(diff))
                        else:
                            line += ' & \hspace{{-1em}}{{\hred{{(-{:.1f})}}}}'.format(abs(diff))
                        # else:
                        #     line += ' & '
                    else:
                        line += ' & '

            line += '\\\\'
            print(line)
    print('\\bottomrule')
    print('\\end{tabular}')
    
    
def display_id_ood_tables(mean_table, std_table, id_output_metrics_list, ood_output_metrics_list, shortened_dataset_names, no_std=False, show_improvements=True):
    # Filter the ID and OOD metrics.
    print('\n\nID Accuracies')
    id_mean_table = filter_columns(mean_table, output_metrics_list, id_output_metrics_list)
    id_mean_table = add_average_column(id_mean_table)
    id_std_table = filter_columns(std_table, output_metrics_list, id_output_metrics_list)
    if no_std:
        id_std_table = None
    display_tsv_table(shorted_model_names, shortened_options_names, shortened_dataset_names, id_mean_table, id_std_table)
    print('')
    display_latex_table(shorted_model_names, shortened_options_names, shortened_dataset_names, id_mean_table, id_std_table, is_last_avg=True, show_improvements=show_improvements)

    print('\n\nOOD Accuracies')
    ood_mean_table = filter_columns(mean_table, output_metrics_list, ood_output_metrics_list)
    ood_mean_table = add_average_column(ood_mean_table)
    ood_std_table = filter_columns(std_table, output_metrics_list, ood_output_metrics_list)
    if no_std:
        ood_std_table = None
    display_tsv_table(shorted_model_names, shortened_options_names, shortened_dataset_names, ood_mean_table, ood_std_table)
    print('')
    display_latex_table(shorted_model_names, shortened_options_names, shortened_dataset_names, ood_mean_table, ood_std_table, is_last_avg=True, show_improvements=show_improvements)
    return id_mean_table, ood_mean_table
    

# Big table, Living17, Waterbirds, DomainNet, Camelyon, FMoW (Early stopped)


In [None]:
# Big table, Living17, Waterbirds, DomainNet, Camelyon, FMoW (Early stopped)
models = ['clip_vit_b16', 'clip_vit_l14', 'timm_vit_b16_in21k', 'dino_vit_b16',  'convnext_vit_b', 'bit_resnet_50_in21k', 'bit_resnet_101_in21k']
options = ['', 'opt_torch.optim.AdamW_', 'freeze_bottom_2_']
datasets = ['living17_nonorm', 'waterbirds', 'domainnet', 'fmow_all_nonorm_weakaugs', 'camelyon17_weakaugs']
output_metrics_list = [['test_acc/source_val_living', 'test_acc/target_val_living'], ['WATERBIRDS_VAL', 'WORST'], ['test_acc/sketch_val', 'test_acc/real_val'], ['test_acc/id_val', 'test_acc/ood_val', 'test_acc/ood_test', 'test_acc/africa_test'], ['test_acc/id_val', 'test_acc/ood_val', 'test_acc/ood_test']]
val_metric_list = ['test_acc/source_val_living', 'WATERBIRDS_VAL', 'test_acc/sketch_val', 'test_acc/id_val', 'test_acc/id_val']
aggregation_metric_list = ['test_acc/source_val_living', 'WATERBIRDS_VAL', 'test_acc/sketch_val', 'test_acc/id_val', 'test_acc/id_val']
desired_epochs_list = [20, 20, 50, 5, 3]
shorted_model_names = ['CLIP ViT-B/16', 'CLIP ViT-L/14', 'Sup ViT-B/16', 'DINO ViT-B/16', 'ConvNext-Base', 'BiT ResNet-50', 'BiT ResNet-101', ]
shortened_options_names = ['SGD', 'AdamW', 'SGD (freeze-embed)']
id_output_metrics_list = [['test_acc/source_val_living'], ['WATERBIRDS_VAL'], ['test_acc/sketch_val'], ['test_acc/id_val'], ['test_acc/id_val']]
ood_output_metrics_list = [['test_acc/target_val_living'], ['WORST'], ['test_acc/real_val'], ['test_acc/africa_test'], ['test_acc/ood_test']]
shortened_dataset_names = ['Living-17', 'Waterbirds', 'DomainNet', "FMoW", "Camelyon", "Avg."]


mean_table, std_table, lr_table = get_table(log_dir, name_template, models, options, datasets, output_metrics_list, val_metric_list, aggregation_metric_list, desired_epochs_list=desired_epochs_list)
# pickle.dump( (mean_table, std_table), open( "big_table_main.pkl", "wb" ) )
# mean_table_old_ordit
er, std_table_old_order = pickle.load( open( "big_table_main.pkl", "rb" ) )

# Rearranging mean table
mean_table = mean_table_old_order[[0,6,1,2,5,3,4],:,:]
std_table = std_table_old_order[[0,6,1,2,5,3,4],:,:]

id_mean_table, ood_mean_table = display_id_ood_tables(
    mean_table, std_table, id_output_metrics_list, ood_output_metrics_list, shortened_dataset_names, no_std=True)

# display_tsv_table(shorted_model_names, shortened_options_names, shortened_output_metrics_list, mean_table, std_table)

In [None]:
# Get average ID and OOD results for the 3 methods.
print(np.mean(id_mean_table[:,:,5], axis=0))
print(np.mean(ood_mean_table[:,:,5], axis=0))

# ConvNeXt freeze stem results

In [8]:
# ConvNeXt, Living17, Waterbirds, DomainNet, Camelyon, FMoW (Early stopped), including freeze-stem
models = ['convnext_vit_b']
options = ['', 'opt_torch.optim.AdamW_', 'freeze_bottom_2_', 'freeze_bottom_1_']
datasets = ['living17_nonorm', 'waterbirds', 'domainnet', 'fmow_all_nonorm_weakaugs', 'camelyon17_weakaugs']
output_metrics_list = [['test_acc/source_val_living', 'test_acc/target_val_living'], ['WATERBIRDS_VAL', 'WORST'], ['test_acc/sketch_val', 'test_acc/real_val'], ['test_acc/id_val', 'test_acc/ood_val', 'test_acc/ood_test', 'test_acc/africa_test'], ['test_acc/id_val', 'test_acc/ood_val', 'test_acc/ood_test']]
val_metric_list = ['test_acc/source_val_living', 'WATERBIRDS_VAL', 'test_acc/sketch_val', 'test_acc/id_val', 'test_acc/id_val']
aggregation_metric_list = ['test_acc/source_val_living', 'WATERBIRDS_VAL', 'test_acc/sketch_val', 'test_acc/id_val', 'test_acc/id_val']
desired_epochs_list = [20, 20, 50, 5, 3]
shorted_model_names = ['ConvNext-Base']
shortened_options_names = ['SGD', 'AdamW', 'SGD (freeze-stem-block-1)', 'SGD (freeze-stem)']
id_output_metrics_list = [['test_acc/source_val_living'], ['WATERBIRDS_VAL'], ['test_acc/sketch_val'], ['test_acc/id_val'], ['test_acc/id_val']]
ood_output_metrics_list = [['test_acc/target_val_living'], ['WORST'], ['test_acc/real_val'], ['test_acc/africa_test'], ['test_acc/ood_test']]
shortened_dataset_names = ['Living-17', 'Waterbirds', 'DomainNet', "FMoW", "Camelyon", "Avg."]


mean_table, std_table, lr_table = get_table(log_dir, name_template, models, options, datasets, output_metrics_list, val_metric_list, aggregation_metric_list, desired_epochs_list=desired_epochs_list)
# pickle.dump( (mean_table, std_table), open( "big_table_main.pkl", "wb" ) )
# mean_table_old_order, std_table_old_order = pickle.load( open( "big_table_main.pkl", "rb" ) )

id_mean_table, ood_mean_table = display_id_ood_tables(
    mean_table, std_table, id_output_metrics_list, ood_output_metrics_list, shortened_dataset_names, no_std=True)

# display_tsv_table(shorted_model_names, shortened_options_names, shortened_output_metrics_list, mean_table, std_table)



ID Accuracies
		Living-17	Waterbirds	DomainNet	FMoW	Camelyon	Avg.
ConvNext-Base	SGD	98.7	99.0	94.8	66.3	99.4	91.6
ConvNext-Base	AdamW	98.6	99.5	94.5	68.8	99.7	92.2
ConvNext-Base	SGD (freeze-stem-block-1)	98.6	99.4	95.1	67.4	99.5	92.0
ConvNext-Base	SGD (freeze-stem)	98.8	99.3	95.1	67.2	99.5	92.0

\begin{tabular}{ccccccccccc|cc}
\toprule
 & \multicolumn{2}{c}{Living-17}  & \multicolumn{2}{c}{Waterbirds}  & \multicolumn{2}{c}{DomainNet}  & \multicolumn{2}{c}{FMoW}  & \multicolumn{2}{c}{Camelyon}  & \multicolumn{2}{c}{Avg.} \\
\midrule
SGD & \textbf{98.7} &  & 99.0 &  & 94.8 &  & 66.3 &  & 99.4 &  & 91.6 & \\
AdamW & \textbf{98.6} & \hspace{-1em}{\lred{(-0.1)}} & \textbf{99.5} & \hspace{-1em}{\hgreen{(+0.5)}} & 94.5 & \hspace{-1em}{\hred{(-0.3)}} & \textbf{68.8} & \hspace{-1em}{\hgreen{(+2.5)}} & \textbf{99.7} & \hspace{-1em}{\hgreen{(+0.3)}} & \textbf{92.2} & \hspace{-1em}{\hgreen{(+0.6)}}\\
SGD (freeze-stem-block-1) & \textbf{98.6} & \hspace{-1em}{\lred{(-0.1)}} & \textbf{99.4} & \hspa

# CLIP results including wilds datasets

In [11]:

print("CLIP Results, all datasets")
models = ['clip_vit_b16']
options = ['', 'opt_torch.optim.AdamW_', 'freeze_bottom_2_', 'layer_wise_tune__', 'opt_torch_optimizer.Lamb_', 'opt_torch_optimizer.LARS_']
val_metric_list = ['test_acc/source_val_living', 'WATERBIRDS_VAL', 'test_acc/sketch_val', 'test_acc/id_val', 'test_acc/id_val']
datasets = ['living17_nonorm', 'waterbirds', 'domainnet', 'fmow_all_nonorm_weakaugs', 'camelyon17_weakaugs']
output_metrics_list = [['test_acc/source_val_living', 'test_acc/target_val_living'], ['WATERBIRDS_VAL', 'WORST'], ['test_acc/sketch_val', 'test_acc/real_val'], ['test_acc/id_val', 'test_acc/ood_val', 'test_acc/ood_test', 'test_acc/africa_test'], ['test_acc/id_val', 'test_acc/ood_val', 'test_acc/ood_test']]
aggregation_metric_list = ['test_acc/source_val_living', 'WATERBIRDS_VAL', 'test_acc/sketch_val', 'test_acc/id_val', 'test_acc/id_val']
desired_epochs_list = [20, 20, 50, 5, 3]

mean_table, std_table, lr_table = get_table(log_dir, name_template, models, options, datasets, output_metrics_list, val_metric_list, aggregation_metric_list, desired_epochs_list=desired_epochs_list)
print(lr_table)
# pickle.dump( (mean_table, std_table), open( "clip_table_main.pkl", "wb" ) )
# mean_table, std_table = pickle.load( open( "clip_table_main.pkl", "rb" ) )

shorted_model_names = ['CLIP ViT-B/16']
shortened_options_names = ['SGD', 'AdamW', 'SGD (Freeze-embed)', 'Layer-wise', 'LAMB', 'LARS']
shortened_output_metrics_list = ['Living-17 ID', 'Living-17 OOD', 'Waterbirds ID', 'Waterbirds OOD', 'DomainNet ID', 'DomainNet OOD', "FMoW ID", "FMoW OOD Val", "FMoW OOD Test", "FMoW Africa", "Camelyon ID", "Camelyon OOD Val", "Camelyon OOD Test"]

id_output_metrics_list = [['test_acc/source_val_living'], ['WATERBIRDS_VAL'], ['test_acc/sketch_val'], ['test_acc/id_val'], ['test_acc/id_val']]
ood_output_metrics_list = [['test_acc/target_val_living'], ['WORST'], ['test_acc/real_val'], ['test_acc/africa_test'], ['test_acc/ood_test']]
shortened_dataset_names = ['Liv-17', 'Waterbirds', 'DomainNet', "FMoW", "Camelyon", "Avg."]

display_id_ood_tables(mean_table, std_table, id_output_metrics_list, ood_output_metrics_list, shortened_dataset_names)


CLIP Results, all datasets
not all jobs ran:  clip_vit_b16 opt_torch_optimizer.Lamb_ living17_nonorm 8
not all jobs ran:  clip_vit_b16 opt_torch_optimizer.Lamb_ waterbirds 8
not all jobs ran:  clip_vit_b16 opt_torch_optimizer.Lamb_ domainnet 8
not all jobs ran:  clip_vit_b16 opt_torch_optimizer.LARS_ living17_nonorm 8
not all jobs ran:  clip_vit_b16 opt_torch_optimizer.LARS_ waterbirds 8
not all jobs ran:  clip_vit_b16 opt_torch_optimizer.LARS_ domainnet 8
[[[1.e-04 3.e-05 3.e-05 3.e-04 1.e-04]
  [3.e-06 3.e-06 3.e-06 1.e-05 3.e-06]
  [1.e-04 3.e-04 1.e-04 3.e-04 3.e-04]
  [3.e-05 1.e-04 1.e-05 1.e-04 3.e-05]
  [3.e-04 1.e-04 1.e-04 1.e-04 1.e-04]
  [3.e-05 3.e-05 3.e-05 1.e-04 3.e-05]]]


ID Accuracies
		Liv-17	Waterbirds	DomainNet	FMoW	Camelyon	Avg.
CLIP ViT-B/16	SGD	97.8 (0.2)	97.2 (0.1)	88.8 (7.1)	67.0 (0.8)	99.4 (0.0)	90.0
CLIP ViT-B/16	AdamW	98.1 (0.1)	97.7 (0.0)	95.0 (0.1)	70.1 (0.2)	99.5 (0.0)	92.1
CLIP ViT-B/16	SGD (Freeze-embed)	98.2 (0.3)	97.8 (0.1)	94.9 (0.3)	70.0 (0.2)	99.

(array([[[97.82352941, 97.22234352, 88.80088926, 66.97146506,
          99.38021454, 90.03968836],
         [98.09803922, 97.7471974 , 94.9979158 , 70.11814566,
          99.54509337, 92.10127829],
         [98.15686275, 97.80620545, 94.91454773, 70.00203199,
          99.52622169, 92.08117392],
         [98.31372549, 98.29358603, 96.30401556, 69.19794479,
          99.268971  , 92.27564857],
         [98.17647059, 97.76332641, 95.1229679 , 67.91779152,
          99.52324195, 91.70075967],
         [97.70588235, 97.11959892, 93.24718633, 66.98597927,
          99.34445769, 90.88062091]]]),
 array([[[79.96078431, 62.46105919, 72.77833789, 37.34413164,
          86.7954476 , 67.86795213],
         [82.80392157, 71.85877466, 89.16894714, 40.66075331,
          95.7258526 , 76.04364986],
         [83.17647059, 73.72793354, 88.22795141, 40.18511377,
          94.29107783, 75.92170943],
         [81.90196078, 69.05503634, 93.24019396, 40.46792647,
          96.481451  , 76.22931371],
       

# CLIP Results for ViT-L

In [9]:
# Big table, Living17, Waterbirds, DomainNet, Camelyon, FMoW (Early stopped)
models = ['clip_vit_l14']
options = ['', 'opt_torch.optim.AdamW_', 'freeze_bottom_2_']
datasets = ['living17_nonorm', 'waterbirds', 'domainnet', 'fmow_all_nonorm_weakaugs', 'camelyon17_weakaugs']
output_metrics_list = [['test_acc/source_val_living', 'test_acc/target_val_living'], ['WATERBIRDS_VAL', 'WORST'], ['test_acc/sketch_val', 'test_acc/real_val'], ['test_acc/id_val', 'test_acc/ood_val', 'test_acc/ood_test', 'test_acc/africa_test'], ['test_acc/id_val', 'test_acc/ood_val', 'test_acc/ood_test']]
val_metric_list = ['test_acc/source_val_living', 'WATERBIRDS_VAL', 'test_acc/sketch_val', 'test_acc/id_val', 'test_acc/id_val']
aggregation_metric_list = ['test_acc/source_val_living', 'WATERBIRDS_VAL', 'test_acc/sketch_val', 'test_acc/id_val', 'test_acc/id_val']
desired_epochs_list = [20, 20, 50, 5, 3]
mean_table, std_table, lr_table = get_table(log_dir, name_template, models, options, datasets, output_metrics_list, val_metric_list, aggregation_metric_list, desired_epochs_list=desired_epochs_list)
shorted_model_names = ['CLIP ViT-L/14']
shortened_options_names = ['SGD', 'AdamW', 'SGD (Freeze-2)']

id_output_metrics_list = [['test_acc/source_val_living'], ['WATERBIRDS_VAL'], ['test_acc/sketch_val'], ['test_acc/id_val'], ['test_acc/id_val']]
ood_output_metrics_list = [['test_acc/target_val_living'], ['WORST'], ['test_acc/real_val'], ['test_acc/africa_test'], ['test_acc/ood_test']]
shortened_dataset_names = ['Liv-17', 'Waterbirds', 'DomainNet', "FMoW", "Camelyon", "Avg."]

id_mean_table, ood_mean_table = display_id_ood_tables(
    mean_table, std_table, id_output_metrics_list, ood_output_metrics_list, shortened_dataset_names, no_std=True)

# display_tsv_table(shorted_model_names, shortened_options_names, shortened_output_metrics_list, mean_table, std_table)



ID Accuracies
		Liv-17	Waterbirds	DomainNet	FMoW	Camelyon	Avg.
CLIP ViT-L/14	SGD	98.4	97.3	84.3	69.0	99.4	89.7
CLIP ViT-L/14	AdamW	98.9	98.8	96.9	74.5	99.6	93.7
CLIP ViT-L/14	SGD (Freeze-2)	98.7	98.9	97.1	74.5	99.6	93.7

\begin{tabular}{ccccccccccc|cc}
\toprule
 & \multicolumn{2}{c}{Liv-17}  & \multicolumn{2}{c}{Waterbirds}  & \multicolumn{2}{c}{DomainNet}  & \multicolumn{2}{c}{FMoW}  & \multicolumn{2}{c}{Camelyon}  & \multicolumn{2}{c}{Avg.} \\
\midrule
SGD & 98.4 &  & 97.3 &  & 84.3 &  & 69.0 &  & \textbf{99.4} &  & 89.7 & \\
AdamW & \textbf{98.9} & \hspace{-1em}{\hgreen{(+0.5)}} & \textbf{98.8} & \hspace{-1em}{\hgreen{(+1.5)}} & 96.9 & \hspace{-1em}{\hgreen{(+12.6)}} & \textbf{74.5} & \hspace{-1em}{\hgreen{(+5.5)}} & \textbf{99.6} & \hspace{-1em}{\lgreen{(+0.2)}} & \textbf{93.7} & \hspace{-1em}{\hgreen{(+4.0)}}\\
SGD (Freeze-2) & \textbf{98.7} & \hspace{-1em}{\hgreen{(+0.3)}} & \textbf{98.9} & \hspace{-1em}{\hgreen{(+1.6)}} & \textbf{97.1} & \hspace{-1em}{\hgreen{(+12.8)}} & \text

# CLIP Results for LP-FT

In [19]:
# Big table, Living17, Waterbirds, DomainNet, Camelyon, FMoW (Early stopped)
models = ['clip_vit_b16']
options = ['', 'opt_torch.optim.AdamW_', 'freeze_bottom_2_']
datasets = ['living17_nonorm', 'waterbirds', 'domainnet', 'fmow_all_nonorm_weakaugs', 'camelyon17_weakaugs']
output_metrics_list = [['test_acc/source_val_living', 'test_acc/target_val_living'], ['WATERBIRDS_VAL', 'WORST'], ['test_acc/sketch_val', 'test_acc/real_val'], ['test_acc/id_val', 'test_acc/ood_val', 'test_acc/ood_test', 'test_acc/africa_test'], ['test_acc/id_val', 'test_acc/ood_val', 'test_acc/ood_test']]
val_metric_list = ['test_acc/source_val_living', 'LAST', 'test_acc/sketch_val', 'test_acc/id_val', 'test_acc/id_val']
aggregation_metric_list = ['test_acc/source_val_living', 'WATERBIRDS_VAL', 'test_acc/sketch_val', 'test_acc/id_val', 'test_acc/id_val']
desired_epochs_list = [20, 20, 50, 5, 3]
mean_table, std_table, lr_table = get_table(log_dir, lp_ft_template, models, options, datasets, output_metrics_list, val_metric_list, aggregation_metric_list, desired_epochs_list=desired_epochs_list)
shorted_model_names = ['CLIP ViT-B/16']
shortened_options_names = ['SGD', 'AdamW', 'SGD (Freeze-2)']

id_output_metrics_list = [['test_acc/source_val_living'], ['WATERBIRDS_VAL'], ['test_acc/sketch_val'], ['test_acc/id_val'], ['test_acc/id_val']]
ood_output_metrics_list = [['test_acc/target_val_living'], ['WORST'], ['test_acc/real_val'], ['test_acc/africa_test'], ['test_acc/ood_test']]
shortened_dataset_names = ['Liv-17', 'Waterbirds', 'DomainNet', "FMoW", "Camelyon", "Avg."]

print(mean_table.shape, std_table.shape)
print(lr_table)

id_mean_table, ood_mean_table = display_id_ood_tables(
    mean_table, std_table, id_output_metrics_list, ood_output_metrics_list, shortened_dataset_names, no_std=True, show_improvements=False)

# display_tsv_table(shorted_model_names, shortened_options_names, shortened_output_metrics_list, mean_table, std_table)

(1, 3, 13) (1, 3, 13)
[[[3.e-06 3.e-06 1.e-05 3.e-05 3.e-06]
  [3.e-06 1.e-05 1.e-06 1.e-05 3.e-06]
  [1.e-04 3.e-04 1.e-04 1.e-04 3.e-05]]]


ID Accuracies
		Liv-17	Waterbirds	DomainNet	FMoW	Camelyon	Avg.
CLIP ViT-B/16	SGD	98.2	97.2	95.1	66.7	99.0	91.2
CLIP ViT-B/16	AdamW	98.2	98.2	95.7	69.2	99.5	92.2
CLIP ViT-B/16	SGD (Freeze-2)	98.4	97.8	95.7	69.1	99.4	92.1

\begin{tabular}{ccccccccccc|cc}
\toprule
 & \multicolumn{2}{c}{Liv-17}  & \multicolumn{2}{c}{Waterbirds}  & \multicolumn{2}{c}{DomainNet}  & \multicolumn{2}{c}{FMoW}  & \multicolumn{2}{c}{Camelyon}  & \multicolumn{2}{c}{Avg.} \\
\midrule
SGD & \textbf{98.2} & 97.2 & 95.1 & 66.7 & 99.0 & 91.2\\
AdamW & \textbf{98.2} & \textbf{98.2} & \textbf{95.7} & \textbf{69.2} & \textbf{99.5} & \textbf{92.2}\\
SGD (Freeze-2) & \textbf{98.4} & 97.8 & \textbf{95.7} & \textbf{69.1} & \textbf{99.4} & \textbf{92.1}\\
\bottomrule
\end{tabular}


OOD Accuracies
		Liv-17	Waterbirds	DomainNet	FMoW	Camelyon	Avg.
CLIP ViT-B/16	SGD	86.7	67.3	89.2	37.9	94.

In [None]:
# Get table for dino and clip, freezing different number of layers.
print("CLIP Results, all datasets")
models = ['clip_vit_b16']
options = ['freeze_bottom_2_full_ft_epoch_50_', 'freeze_bottom_5_full_ft_epoch_50_', 'freeze_bottom_8_full_ft_epoch_50_', 'freeze_bottom_11_full_ft_epoch_50_', 'freeze_bottom_14_full_ft_epoch_50_']
val_metric_list = ['WATERBIRDS_VAL']
datasets = ['waterbirds']
output_metrics_list = [['WATERBIRDS_VAL', 'WORST']]
aggregation_metric_list = ['WATERBIRDS_VAL']
# desired_epochs_list = [20, 20, 50, 5, 3]
mean_table, std_table, lr_table = get_table('../older_logs/', name_template, models, options, datasets, output_metrics_list, val_metric_list, aggregation_metric_list, desired_epochs_list=desired_epochs_list)
shorted_model_names = ['CLIP ViT-B/16']
shortened_options_names = ['SGD', 'AdamW', 'SGD (Freeze-embed)', 'Layer-wise', 'LAMB', 'LARS']
shortened_output_metrics_list = ['Living-17 ID', 'Living-17 OOD', 'Waterbirds ID', 'Waterbirds OOD', 'DomainNet ID', 'DomainNet OOD', "FMoW ID", "FMoW OOD Val", "FMoW OOD Test", "FMoW Africa", "Camelyon ID", "Camelyon OOD Val", "Camelyon OOD Test"]


# Post ICLR Supplementary results

## Big table, including AdamW + Freeze 

In [None]:
# Big table, Living17, Waterbirds, DomainNet, Camelyon, FMoW (Early stopped)
models = ['clip_vit_b16', 'clip_vit_l14', 'timm_vit_b16_in21k', 'dino_vit_b16',  'convnext_vit_b', 'bit_resnet_50_in21k', 'bit_resnet_101_in21k']
options = ['', 'opt_torch.optim.AdamW_', 'freeze_bottom_2_', 'freeze_bottom_2_opt_torch.optim.AdamW_']
datasets = ['living17_nonorm', 'waterbirds', 'domainnet', 'fmow_all_nonorm_weakaugs', 'camelyon17_weakaugs']
output_metrics_list = [['test_acc/source_val_living', 'test_acc/target_val_living'], ['WATERBIRDS_VAL', 'WORST'], ['test_acc/sketch_val', 'test_acc/real_val'], ['test_acc/id_val', 'test_acc/ood_val', 'test_acc/ood_test', 'test_acc/africa_test'], ['test_acc/id_val', 'test_acc/ood_val', 'test_acc/ood_test']]
val_metric_list = ['test_acc/source_val_living', 'WATERBIRDS_VAL', 'test_acc/sketch_val', 'test_acc/id_val', 'test_acc/id_val']
aggregation_metric_list = ['test_acc/source_val_living', 'WATERBIRDS_VAL', 'test_acc/sketch_val', 'test_acc/id_val', 'test_acc/id_val']
desired_epochs_list = [20, 20, 50, 5, 3]

# mean_table_3, std_table_3 = pickle.load( open( "big_table_main.pkl", "rb" ) )
# mean_table = np.concatenate((mean_table_3, mean_table_3[:,1,None,:]), axis=1)
# std_table = np.concatenate((std_table_3, std_table_3[:,1,None,:]), axis=1)

mean_table, std_table, lr_table = get_table(log_dir, name_template, models, options, datasets, output_metrics_list, val_metric_list, aggregation_metric_list, desired_epochs_list=desired_epochs_list)
shorted_model_names = ['CLIP ViT-B/16', 'CLIP ViT-L/14', 'Sup ViT-B/16', 'DINO ViT-B/16', 'ConvNext-Base', 'BiT ResNet-50', 'BiT ResNet-101']
shortened_options_names = ['SGD', 'AdamW', 'SGD (freeze-embed)', 'AdamW (freeze-embed)']

id_output_metrics_list = [['test_acc/source_val_living'], ['WATERBIRDS_VAL'], ['test_acc/sketch_val'], ['test_acc/id_val'], ['test_acc/id_val']]
ood_output_metrics_list = [['test_acc/target_val_living'], ['WORST'], ['test_acc/real_val'], ['test_acc/africa_test'], ['test_acc/ood_test']]
shortened_dataset_names = ['Living-17', 'Waterbirds', 'DomainNet', "FMoW", "Camelyon", "Avg."]


id_mean_table, ood_mean_table = display_id_ood_tables(
    mean_table, std_table, id_output_metrics_list, ood_output_metrics_list, shortened_dataset_names, no_std=True)

# display_tsv_table(shorted_model_names, shortened_options_names, shortened_output_metrics_list, mean_table, std_table)

In [None]:
# Get average ID and OOD results for the 3 methods.
print(np.mean(id_mean_table[:,:,5], axis=0))
print(np.mean(ood_mean_table[:,:,5], axis=0))

## ResNet-50 all results 

In [None]:
# Big table, Living17, Waterbirds, DomainNet, Camelyon, FMoW (Early stopped)
models = ['sup_resnet50']
options = ['', 'opt_torch.optim.AdamW_', 'freeze_bottom_2_', 'freeze_bottom_2_opt_torch.optim.AdamW_']
datasets = ['living17_nonorm', 'waterbirds', 'domainnet', 'fmow_all_nonorm_weakaugs', 'camelyon17_weakaugs']
output_metrics_list = [['test_acc/source_val_living', 'test_acc/target_val_living'], ['WATERBIRDS_VAL', 'WORST'], ['test_acc/sketch_val', 'test_acc/real_val'], ['test_acc/id_val', 'test_acc/ood_val', 'test_acc/ood_test', 'test_acc/africa_test'], ['test_acc/id_val', 'test_acc/ood_val', 'test_acc/ood_test']]
val_metric_list = ['test_acc/source_val_living', 'WATERBIRDS_VAL', 'test_acc/sketch_val', 'test_acc/id_val', 'test_acc/id_val']
aggregation_metric_list = ['test_acc/source_val_living', 'WATERBIRDS_VAL', 'test_acc/sketch_val', 'test_acc/id_val', 'test_acc/id_val']
desired_epochs_list = [20, 20, 50, 5, 3]
mean_table, std_table, lr_table = get_table(log_dir, name_template, models, options, datasets, output_metrics_list, val_metric_list, aggregation_metric_list, desired_epochs_list=desired_epochs_list)
shorted_model_names = ['ResNet50']
shortened_options_names = ['SGD', 'AdamW', 'SGD (Freeze-2)', 'AdamW (Freeze-2)']

id_output_metrics_list = [['test_acc/source_val_living'], ['WATERBIRDS_VAL'], ['test_acc/sketch_val'], ['test_acc/id_val'], ['test_acc/id_val']]
ood_output_metrics_list = [['test_acc/target_val_living'], ['WORST'], ['test_acc/real_val'], ['test_acc/africa_test'], ['test_acc/ood_test']]
shortened_dataset_names = ['Liv-17', 'Waterbirds', 'DomainNet', "FMoW", "Camelyon", "Avg."]

id_mean_table, ood_mean_table = display_id_ood_tables(
    mean_table, std_table, id_output_metrics_list, ood_output_metrics_list, shortened_dataset_names, no_std=True)

# display_tsv_table(shorted_model_names, shortened_options_names, shortened_output_metrics_list, mean_table, std_table)

## CLIP weight decay, no momentum, freeze only embed 

In [42]:

print("CLIP Results, all datasets")
models = ['clip_vit_b16']
# options = ['', 'opt_torch.optim.AdamW_', 'freeze_bottom_2_', 'layer_wise_tune__', 'opt_torch_optimizer.Lamb_', 'opt_torch_optimizer.LARS_', 'freeze_bottom_1_', 'optimizer.args.momentum-0.0_', 'freeze_bottom_2_optimizer.args.momentum-0.0_', 'optimizer.args.weight_decay-0.01_']
options = ['', 'opt_torch.optim.AdamW_', 'freeze_bottom_2_', 'layer_wise_tune__', 'opt_torch_optimizer.Lamb_', 'opt_torch_optimizer.LARS_', 'freeze_bottom_1_', 'optimizer.args.momentum-0.0_', 'optimizer.args.weight_decay-0.01_', 'embed_layer_lr_multiplier-0.2_']
val_metric_list = ['test_acc/source_val_living', 'WATERBIRDS_VAL', 'test_acc/sketch_val', 'test_acc/id_val', 'test_acc/id_val']
datasets = ['living17_nonorm', 'waterbirds', 'domainnet', 'fmow_all_nonorm_weakaugs', 'camelyon17_weakaugs']
output_metrics_list = [['test_acc/source_val_living', 'test_acc/target_val_living'], ['WATERBIRDS_VAL', 'WORST'], ['test_acc/sketch_val', 'test_acc/real_val'], ['test_acc/id_val', 'test_acc/ood_val', 'test_acc/ood_test', 'test_acc/africa_test'], ['test_acc/id_val', 'test_acc/ood_val', 'test_acc/ood_test']]
aggregation_metric_list = ['test_acc/source_val_living', 'WATERBIRDS_VAL', 'test_acc/sketch_val', 'test_acc/id_val', 'test_acc/id_val']
desired_epochs_list = [20, 20, 50, 5, 3]

mean_table, std_table, lr_table = get_table(log_dir, name_template, models, options, datasets, output_metrics_list, val_metric_list, aggregation_metric_list, desired_epochs_list=desired_epochs_list)
# pickle.dump( (mean_table, std_table), open( "clip_table_extended.pkl", "wb" ) )
# mean_table, std_table = pickle.load( open( "clip_table_extended.pkl", "rb" ) )

shorted_model_names = ['CLIP ViT-B/16']
# shortened_options_names = ['SGD', 'AdamW', 'SGD (Freeze-embed)', 'Layer-wise', 'LAMB', 'LARS', 'Freeze only embed', 'SGD (no momentum)', 'SGD (Freeze, no momentum)', 'SGD (weight decay)']
shortened_options_names = ['SGD', 'AdamW', 'SGD (Freeze-embed)', 'Layer-wise', 'LAMB', 'LARS', 'SGD (Freeze-embed, not layer-norm)', 'SGD (no momentum)', 'SGD (weight decay)', 'SGD (lower LR, embed layer)']
shortened_output_metrics_list = ['Living-17 ID', 'Living-17 OOD', 'Waterbirds ID', 'Waterbirds OOD', 'DomainNet ID', 'DomainNet OOD', "FMoW ID", "FMoW OOD Val", "FMoW OOD Test", "FMoW Africa", "Camelyon ID", "Camelyon OOD Val", "Camelyon OOD Test"]

id_output_metrics_list = [['test_acc/source_val_living'], ['WATERBIRDS_VAL'], ['test_acc/sketch_val'], ['test_acc/id_val'], ['test_acc/id_val']]
ood_output_metrics_list = [['test_acc/target_val_living'], ['WORST'], ['test_acc/real_val'], ['test_acc/africa_test'], ['test_acc/ood_test']]
shortened_dataset_names = ['Liv-17', 'Waterbirds', 'DomainNet', "FMoW", "Camelyon", "Avg."]

display_id_ood_tables(mean_table, std_table, id_output_metrics_list, ood_output_metrics_list, shortened_dataset_names, no_std=True)


CLIP Results, all datasets
not all jobs ran:  clip_vit_b16 opt_torch_optimizer.Lamb_ living17_nonorm 8
not all jobs ran:  clip_vit_b16 opt_torch_optimizer.Lamb_ waterbirds 8
not all jobs ran:  clip_vit_b16 opt_torch_optimizer.Lamb_ domainnet 8
not all jobs ran:  clip_vit_b16 opt_torch_optimizer.LARS_ living17_nonorm 8
not all jobs ran:  clip_vit_b16 opt_torch_optimizer.LARS_ waterbirds 8
not all jobs ran:  clip_vit_b16 opt_torch_optimizer.LARS_ domainnet 8


ID Accuracies
		Liv-17	Waterbirds	DomainNet	FMoW	Camelyon	Avg.
CLIP ViT-B/16	SGD	97.8	97.2	88.8	67.0	99.4	90.0
CLIP ViT-B/16	AdamW	98.1	97.7	95.0	70.1	99.5	92.1
CLIP ViT-B/16	SGD (Freeze-embed)	98.2	97.8	94.9	70.0	99.5	92.1
CLIP ViT-B/16	Layer-wise	98.3	98.3	96.3	69.2	99.3	92.3
CLIP ViT-B/16	LAMB	98.2	97.8	95.1	67.9	99.5	91.7
CLIP ViT-B/16	LARS	97.7	97.1	93.2	67.0	99.3	90.9
CLIP ViT-B/16	SGD (Freeze-embed, not layer-norm)	98.0	98.0	95.4	70.2	99.5	92.2
CLIP ViT-B/16	SGD (no momentum)	98.0	97.1	89.5	66.4	99.3	90.1
CLIP ViT-B/16	SGD (

(array([[[97.82353333, 97.22233333, 88.8009    , 66.9715    ,
          99.3802    , 90.03969333],
         [98.09803333, 97.7472    , 94.9979    , 70.11816667,
          99.5451    , 92.10128   ],
         [98.15683333, 97.8062    , 94.91456667, 70.00203333,
          99.52623333, 92.08117333],
         [98.31373333, 98.2936    , 96.30403333, 69.19793333,
          99.26896667, 92.27565333],
         [98.1765    , 97.7633    , 95.123     , 67.9178    ,
          99.5232    , 91.70076   ],
         [97.7059    , 97.1196    , 93.2472    , 66.986     ,
          99.3445    , 90.88064   ],
         [98.        , 98.0037    , 95.4148    , 70.1646    ,
          99.5262    , 92.22186   ],
         [98.        , 97.071     , 89.5373    , 66.3764    ,
          99.3236    , 90.06166   ],
         [97.6471    , 97.2224    , 87.8699    , 66.3503    ,
          99.2938    , 89.6767    ],
         [98.        , 97.5392    , 94.7895    , 68.6841    ,
          99.5113    , 91.70482   ]]]),
 array(

# Big table with no momentum

In [4]:
# Big table, Living17, Waterbirds, DomainNet, Camelyon, FMoW (Early stopped)
models = ['clip_vit_b16', 'clip_vit_l14', 'timm_vit_b16_in21k', 'dino_vit_b16',  'convnext_vit_b', 'bit_resnet_50_in21k', 'bit_resnet_101_in21k']
options = ['', 'opt_torch.optim.AdamW_', 'freeze_bottom_2_', 'freeze_bottom_2_optimizer.args.momentum-0.0_']
datasets = ['living17_nonorm', 'waterbirds', 'domainnet', 'fmow_all_nonorm_weakaugs', 'camelyon17_weakaugs']
output_metrics_list = [['test_acc/source_val_living', 'test_acc/target_val_living'], ['WATERBIRDS_VAL', 'WORST'], ['test_acc/sketch_val', 'test_acc/real_val'], ['test_acc/id_val', 'test_acc/ood_val', 'test_acc/ood_test', 'test_acc/africa_test'], ['test_acc/id_val', 'test_acc/ood_val', 'test_acc/ood_test']]
val_metric_list = ['test_acc/source_val_living', 'WATERBIRDS_VAL', 'test_acc/sketch_val', 'test_acc/id_val', 'test_acc/id_val']
aggregation_metric_list = ['test_acc/source_val_living', 'WATERBIRDS_VAL', 'test_acc/sketch_val', 'test_acc/id_val', 'test_acc/id_val']
desired_epochs_list = [20, 20, 50, 5, 3]
shorted_model_names = ['CLIP ViT-B/16', 'CLIP ViT-L/14', 'Sup ViT-B/16', 'DINO ViT-B/16', 'ConvNext-Base', 'BiT ResNet-50', 'BiT ResNet-101', ]
shortened_options_names = ['SGD', 'AdamW', 'SGD (freeze-embed)', 'SGD (freeze-embed, no momentum)']
id_output_metrics_list = [['test_acc/source_val_living'], ['WATERBIRDS_VAL'], ['test_acc/sketch_val'], ['test_acc/id_val'], ['test_acc/id_val']]
ood_output_metrics_list = [['test_acc/target_val_living'], ['WORST'], ['test_acc/real_val'], ['test_acc/africa_test'], ['test_acc/ood_test']]
shortened_dataset_names = ['Living-17', 'Waterbirds', 'DomainNet', "FMoW", "Camelyon", "Avg."]


mean_table, std_table, lr_table = get_table(log_dir, name_template, models, options, datasets, output_metrics_list, val_metric_list, aggregation_metric_list, desired_epochs_list=desired_epochs_list)

# Rearranging mean table
# mean_table = mean_table_old_order[[0,6,1,2,5,3,4],:,:]
# std_table = std_table_old_order[[0,6,1,2,5,3,4],:,:]

id_mean_table, ood_mean_table = display_id_ood_tables(
    mean_table, std_table, id_output_metrics_list, ood_output_metrics_list, shortened_dataset_names, no_std=True)

# display_tsv_table(shorted_model_names, shortened_options_names, shortened_output_metrics_list, mean_table, std_table)



ID Accuracies
		Living-17	Waterbirds	DomainNet	FMoW	Camelyon	Avg.
CLIP ViT-B/16	SGD	97.8	97.2	88.8	67.0	99.4	90.0
CLIP ViT-B/16	AdamW	98.1	97.7	95.0	70.1	99.5	92.1
CLIP ViT-B/16	SGD (freeze-embed)	98.2	97.8	94.9	70.0	99.5	92.1
CLIP ViT-B/16	SGD (freeze-embed, no momentum)	98.2	97.9	95.2	70.1	99.5	92.2
CLIP ViT-L/14	SGD	98.4	97.3	84.3	69.0	99.4	89.7
CLIP ViT-L/14	AdamW	98.9	98.8	96.9	74.5	99.6	93.7
CLIP ViT-L/14	SGD (freeze-embed)	98.7	98.9	97.1	74.5	99.6	93.7
CLIP ViT-L/14	SGD (freeze-embed, no momentum)	98.8	98.7	97.3	74.3	99.5	93.7
Sup ViT-B/16	SGD	98.6	99.1	91.7	64.1	99.4	90.6
Sup ViT-B/16	AdamW	98.7	99.0	91.7	66.4	99.5	91.1
Sup ViT-B/16	SGD (freeze-embed)	98.7	99.2	91.5	65.0	99.6	90.8
Sup ViT-B/16	SGD (freeze-embed, no momentum)	98.5	98.9	90.6	65.7	99.5	90.6
DINO ViT-B/16	SGD	98.4	97.0	88.2	62.4	99.4	89.1
DINO ViT-B/16	AdamW	98.5	97.9	89.4	66.0	99.6	90.3
DINO ViT-B/16	SGD (freeze-embed)	98.4	97.5	89.0	63.5	99.5	89.6
DINO ViT-B/16	SGD (freeze-embed, no momentum)	98.5	97.5	89.2	63.

In [14]:
np.mean(ood_mean_table, axis=(0,2))

array([71.9667419 , 76.03873524, 76.73786476, 76.86830857])

In [15]:
np.mean(id_mean_table, axis=(0,2))

array([90.25925333, 91.46286571, 91.27799905, 91.19496571])

# Try different pretrained ResNet models

In [12]:
# Big table, Living17, Waterbirds, DomainNet, Camelyon, FMoW (Early stopped)
models = ['bitresnet_gn', 'bitresnet_gn_opt_adamw', 'bitresnet_gn_patchify', 'bitresnet_gn_opt_adamw_patchify']
options = ['', 'opt_torch.optim.AdamW_', 'freeze_bottom_2_']
datasets = ['living17_nonorm', 'waterbirds', 'domainnet', 'fmow_all_nonorm_weakaugs', 'camelyon17_weakaugs']
output_metrics_list = [['test_acc/source_val_living', 'test_acc/target_val_living'], ['WATERBIRDS_VAL', 'WORST'], ['test_acc/sketch_val', 'test_acc/real_val'], ['test_acc/id_val', 'test_acc/ood_val', 'test_acc/ood_test', 'test_acc/africa_test'], ['test_acc/id_val', 'test_acc/ood_val', 'test_acc/ood_test']]
val_metric_list = ['test_acc/source_val_living', 'WATERBIRDS_VAL', 'test_acc/sketch_val', 'test_acc/id_val', 'test_acc/id_val']
aggregation_metric_list = ['test_acc/source_val_living', 'WATERBIRDS_VAL', 'test_acc/sketch_val', 'test_acc/id_val', 'test_acc/id_val']
desired_epochs_list = [20, 20, 50, 5, 3]
shorted_model_names = ['SGD', 'AdamW', 'SGD (Patchify)', 'AdamW (Patchify)']
shortened_options_names = ['SGD', 'AdamW', 'SGD (freeze-embed)']
id_output_metrics_list = [['test_acc/source_val_living'], ['WATERBIRDS_VAL'], ['test_acc/sketch_val'], ['test_acc/id_val'], ['test_acc/id_val']]
ood_output_metrics_list = [['test_acc/target_val_living'], ['WORST'], ['test_acc/real_val'], ['test_acc/africa_test'], ['test_acc/ood_test']]
shortened_dataset_names = ['Living-17', 'Waterbirds', 'DomainNet', "FMoW", "Camelyon", "Avg."]


mean_table, std_table, lr_table = get_table(log_dir, name_template, models, options, datasets, output_metrics_list, val_metric_list, aggregation_metric_list, desired_epochs_list=desired_epochs_list)
# pickle.dump( (mean_table, std_table), open( "big_table_main.pkl", "wb" ) )
# mean_table_old_ordit
er, std_table_old_order = pickle.load( open( "big_table_main.pkl", "rb" ) )

# Rearranging mean table


id_mean_table, ood_mean_table = display_id_ood_tables(
    mean_table, std_table, id_output_metrics_list, ood_output_metrics_list, shortened_dataset_names, no_std=True)

# display_tsv_table(shorted_model_names, shortened_options_names, shortened_output_metrics_list, mean_table, std_table)

not all jobs ran:  bitresnet_gn  camelyon17_weakaugs 5


ID Accuracies
		Living-17	Waterbirds	DomainNet	FMoW	Camelyon	Avg.
SGD	SGD	97.6	97.6	88.3	63.1	99.5	89.2
SGD	AdamW	97.8	97.6	87.2	62.4	99.5	88.9
SGD	SGD (freeze-embed)	97.6	97.6	87.3	62.7	99.5	89.0
AdamW	SGD	97.2	97.2	87.5	58.9	99.3	88.0
AdamW	AdamW	97.4	97.3	86.5	58.1	99.5	87.8
AdamW	SGD (freeze-embed)	97.4	97.1	86.9	58.2	99.3	87.8
SGD (Patchify)	SGD	97.9	97.5	88.3	62.8	99.4	89.2
SGD (Patchify)	AdamW	97.5	97.3	87.0	62.2	99.5	88.7
SGD (Patchify)	SGD (freeze-embed)	97.6	97.5	87.6	63.1	99.4	89.0
AdamW (Patchify)	SGD	97.5	97.3	87.7	59.6	99.3	88.3
AdamW (Patchify)	AdamW	97.6	97.2	86.2	58.4	99.5	87.8
AdamW (Patchify)	SGD (freeze-embed)	97.6	97.2	87.0	59.4	99.2	88.1

\begin{tabular}{cccccccccccc|cc}
\toprule
 &  & \multicolumn{2}{c}{Living-17}  & \multicolumn{2}{c}{Waterbirds}  & \multicolumn{2}{c}{DomainNet}  & \multicolumn{2}{c}{FMoW}  & \multicolumn{2}{c}{Camelyon}  & \multicolumn{2}{c}{Avg.} \\
\midrule
SGD & SGD & \textbf{97.6} &  