In [1]:
import numpy as np
import summarize_all_results
import importlib
import math
import pickle
importlib.reload(summarize_all_results)

<module 'summarize_all_results' from 'c:\\Users\\suriyag\\Documents\\GitHub\\transfer_learning\\scripts\\summarize_all_results.py'>

In [37]:
# Root directory that get_table prefixes to every experiment folder name.
log_dir = '../logs/'
# Folder-name pattern; get_table fills in {option}, {dataset} and {model} via str.format.
name_template = 'full_ft_{option}{dataset}_{model}'

def get_bolds(table, std_err_table=None, eps=0.2):
    """Mark, column by column, the entries that tie with the column maximum.

    Args:
        table: 2-D numpy array of scores (rows are compared within each column).
        std_err_table: optional 2-D array of standard errors. It may have fewer
            columns than `table`; the extra columns use the fixed `eps` margin.
        eps: margin used when no standard error is available, and the value
            substituted for NaN standard errors.

    Returns:
        An int32 array with the shape of `table`, holding 1 where the entry is
        within the margin of the column's best value and 0 elsewhere.
    """
    num_rows, num_cols = len(table), len(table[0])
    is_best = np.zeros((num_rows, num_cols), dtype=np.int32)
    for col in range(num_cols):
        top_row = np.argmax(table[:, col])
        top_val = table[top_row][col]
        has_std = std_err_table is not None and col < len(std_err_table[0])
        if not has_std:
            # No standard error for this column: fixed margin below the best value.
            for row in range(num_rows):
                if table[row][col] >= top_val - eps:
                    is_best[row][col] = 1
            continue
        top_std = std_err_table[top_row][col]
        if math.isnan(top_std):
            top_std = eps
        for row in range(num_rows):
            row_std = std_err_table[row][col]
            if math.isnan(row_std):
                row_std = eps
            # A row ties when it is within the larger of the two standard errors.
            if np.abs(table[row][col] - top_val) <= max(top_std, row_std):
                is_best[row][col] = 1
    return is_best

def get_table(log_dir, name_template, models, options, datasets, output_metrics_list, val_metric_list=None,
              aggregation_metric_list=None, num_sweep=6, desired_epochs_list=None):
    """Build mean and std tables of the selected run's metrics for every (model, option).

    For each model/option/dataset the experiment folder is
    `log_dir + '/' + name_template.format(...)`, summarized through
    `summarize_all_results.get_experiment_aggregate_summary`.

    Args:
        log_dir: directory prefix for experiment folders.
        name_template: format string with {model}, {option} and {dataset} fields.
        models, options, datasets: values substituted into `name_template`.
        output_metrics_list: one list of metric names per dataset.
        val_metric_list: per-dataset validation metric; the string 'LAST' is
            passed when omitted.
        aggregation_metric_list: per-dataset aggregation metric; defaults to the
            validation metric.
        num_sweep: expected number of runs; `num_sweep * 3` is also accepted.
        desired_epochs_list: per-dataset number of epochs each run should have
            reached. Used to report runs that did not finish.

    Returns:
        (mean_table, std_table), each shaped
        (len(models), len(options), total number of output metrics).
    """
    num_metrics = sum(len(metrics) for metrics in output_metrics_list)
    table_shape = (len(models), len(options), num_metrics)
    mean_table = np.zeros(table_shape)
    std_table = np.zeros(table_shape)
    check_epochs = desired_epochs_list is not None
    for model_idx, model in enumerate(models):
        for option_idx, option in enumerate(options):
            per_dataset_means = []
            per_dataset_stds = []
            for dataset_idx, dataset in enumerate(datasets):
                dir_path = log_dir + '/' + name_template.format(model=model, option=option, dataset=dataset)
                output_metrics = output_metrics_list[dataset_idx]
                val_metric = 'LAST' if val_metric_list is None else val_metric_list[dataset_idx]
                agg_metric = val_metric if aggregation_metric_list is None else aggregation_metric_list[dataset_idx]
                res, best_row_mean, best_row_std, num_epochs_list = summarize_all_results.get_experiment_aggregate_summary(
                    dir_path, val_metric, output_metrics, aggregation_metric=agg_metric)
                if check_epochs and desired_epochs_list[dataset_idx] != np.min(num_epochs_list):
                    # List the run folders that stopped before the desired epoch count.
                    short_runs = np.argwhere(np.array(num_epochs_list) < desired_epochs_list[dataset_idx])
                    bad_runs = [dir_path + '/' + run_folder for run_folder in res['name'][short_runs[:, 0]]]
                    print('some jobs did not finish: ', model, dataset, option, bad_runs)
                num_jobs = len(res)
                if num_jobs not in (num_sweep, num_sweep * 3):
                    print('not all jobs ran: ', model, option, dataset, num_jobs)
                # Keep the first row and drop its last two columns
                # (presumably bookkeeping columns added by summarize_all_results - TODO confirm).
                per_dataset_means.append(best_row_mean.to_numpy()[0, :-2])
                per_dataset_stds.append(best_row_std.to_numpy()[0, :-2])
            mean_table[model_idx][option_idx] = np.concatenate(per_dataset_means)
            std_table[model_idx][option_idx] = np.concatenate(per_dataset_stds)
    return mean_table, std_table

def flatten(list_of_lists):
    """Concatenate the inner sequences of `list_of_lists` into one flat list.

    The previous version appended each inner list whole, so it returned a
    shallow copy of the input and never flattened anything.

    Args:
        list_of_lists: iterable of iterables.

    Returns:
        A new list with the items of every inner iterable, in order.
    """
    return [item for inner in list_of_lists for item in inner]

def filter_columns(table, old_output_metrics_list, new_output_metrics_list):
    """Select a subset of metric columns from a (model, option, metric) table.

    The last axis of `table` is the concatenation of the per-dataset metric
    lists in `old_output_metrics_list`. Each requested metric is located inside
    its own dataset's list and shifted by the number of columns that precede
    that dataset.

    The previous version indexed with a list of per-dataset index arrays and
    then took element 0 of the extra axis. That kept only the first requested
    metric per dataset and failed when datasets requested different numbers of
    metrics. All requested columns are now returned, in request order.

    Args:
        table: 3-D numpy array whose last axis follows `old_output_metrics_list`.
        old_output_metrics_list: per-dataset lists of the metric names in `table`.
        new_output_metrics_list: per-dataset lists of the metric names to keep.

    Returns:
        A new array with the first two axes unchanged and one column per
        requested metric.

    Raises:
        ValueError: if a requested metric is not in the matching old list.
    """
    selected_columns = []
    offset = 0
    for old_metrics, new_metrics in zip(old_output_metrics_list, new_output_metrics_list):
        selected_columns.extend(offset + old_metrics.index(metric) for metric in new_metrics)
        offset += len(old_metrics)
    # An explicit integer dtype keeps an empty selection usable as an index.
    return table[:, :, np.array(selected_columns, dtype=int)]


# Helpers that print TSV and LaTeX tables, and that append an average column.
def display_tsv_table(models, options, shortened_output_metrics_list, table, std_err_table=None):
    """Print the table as tab-separated text, one line per (model, option).

    Args:
        models: display names, one per entry along the first axis of `table`.
        options: display names, one per entry along the second axis of `table`.
        shortened_output_metrics_list: column headers.
        table: values indexed as table[model][option][column].
        std_err_table: optional standard errors with the same indexing. Columns
            it does not cover (such as an appended average) print without one.
    """
    print('\t\t' + '\t'.join(shortened_output_metrics_list))
    for model_idx, model in enumerate(models):
        for option_idx, option in enumerate(options):
            cells = [model, option]
            for idx, val in enumerate(table[model_idx][option_idx]):
                cell = '{:.1f}'.format(val)
                if std_err_table is not None:
                    row_errs = std_err_table[model_idx][option_idx]
                    if idx < len(row_errs):
                        cell += ' ({:.1f})'.format(row_errs[idx])
                cells.append(cell)
            print('\t'.join(cells))
            
def add_average_column(table):
    """Append, for every model, a final column holding each row's mean.

    Args:
        table: sequence of 2-D arrays (or a 3-D array) indexed as
            table[model][option][column].

    Returns:
        A numpy array with one extra trailing column per row: the mean of that
        row's original columns.
    """
    return np.array([
        np.concatenate([model_table, np.mean(model_table, axis=1)[:, np.newaxis]], axis=-1)
        for model_table in table
    ])
        
def display_latex_table(models, options, shortened_output_metrics_list, table, std_err_table=None, bold_best=True, is_last_avg=True):
    """Print the table as a LaTeX `tabular` (booktabs rules), one row per (model, option).

    Every metric takes two LaTeX columns: the value, and the difference to the
    first option of the same model (left empty for the first option itself).
    The differences are wrapped in `\\hgreen` / `\\hred`, which the consuming
    document has to define.

    Changes from the previous version: backslashes in the LaTeX snippets are
    escaped explicitly (`'\\m'` and `'\\h'` are invalid escape sequences in
    non-raw strings), and `get_bolds` is only called when `bold_best` is set.
    The printed output is unchanged.

    Args:
        models: display names, one per entry along the first axis of `table`.
        options: display names, one per entry along the second axis of `table`.
        shortened_output_metrics_list: column headers, one per metric.
        table: values indexed as table[model][option][column].
        std_err_table: optional standard errors with the same indexing. NaN
            entries and columns it does not cover print without one.
        bold_best: wrap entries flagged by `get_bolds` in `\\textbf{}`.
        is_last_avg: put a vertical rule before the last metric's two columns.
    """
    # One leading label column (option) for a single model, otherwise two (model, option).
    label_columns = 1 if len(models) == 1 else 2
    num_columns = 2 * len(shortened_output_metrics_list) + label_columns
    if is_last_avg:
        print('\\begin{tabular}{' + ('c' * (num_columns - 2)) + '|cc}')
    else:
        print('\\begin{tabular}{' + ('c' * num_columns) + '}')
    print('\\toprule')
    header = ''.join(' & \\multicolumn{2}{c}{' + name + '} ' for name in shortened_output_metrics_list)
    header += '\\\\'
    if len(models) > 1:
        header = ' & ' + header
    print(header)
    for model_idx, model in enumerate(models):
        print('\\midrule')
        is_bold = None
        if bold_best:
            if std_err_table is None:
                is_bold = get_bolds(table[model_idx], eps=0.2)
            else:
                is_bold = get_bolds(table[model_idx], std_err_table[model_idx], eps=0.2)
        for option_idx, option in enumerate(options):
            line = model + ' & ' + option if len(models) > 1 else option
            for idx, val in enumerate(table[model_idx][option_idx]):
                make_bold = is_bold is not None and bool(is_bold[option_idx][idx])
                line += ' & '
                if make_bold:
                    line += '\\textbf{'
                line += '{:.1f}'.format(val)
                if (std_err_table is not None and
                        idx < len(std_err_table[model_idx][option_idx]) and
                        not math.isnan(std_err_table[model_idx][option_idx][idx])):
                    line += ' ({:.1f})'.format(std_err_table[model_idx][option_idx][idx])
                if make_bold:
                    line += '}'
                if option_idx:
                    # Every option after the first reports its gain or loss against the first.
                    diff = val - table[model_idx][0][idx]
                    if diff >= 0:
                        line += ' & \\hspace{{-1em}}{{\\hgreen{{(+{:.1f})}}}}'.format(diff)
                    else:
                        line += ' & \\hspace{{-1em}}{{\\hred{{({:.1f})}}}}'.format(diff)
                else:
                    line += ' & '

            line += '\\\\'
            print(line)
    print('\\bottomrule')
    print('\\end{tabular}')
    
    
def _display_accuracy_tables(title, mean_table, std_table, all_output_metrics_list, kept_output_metrics_list,
                             model_names, option_names, shortened_dataset_names, no_std):
    """Filter one metric group, append its average column, and print it as TSV and LaTeX."""
    print('\n\n' + title)
    kept_mean_table = add_average_column(
        filter_columns(mean_table, all_output_metrics_list, kept_output_metrics_list))
    # The std table gets no average column; the display helpers skip the std for columns it lacks.
    kept_std_table = None
    if not no_std:
        kept_std_table = filter_columns(std_table, all_output_metrics_list, kept_output_metrics_list)
    display_tsv_table(model_names, option_names, shortened_dataset_names, kept_mean_table, kept_std_table)
    print('')
    display_latex_table(model_names, option_names, shortened_dataset_names, kept_mean_table, kept_std_table, is_last_avg=True)
    return kept_mean_table


def display_id_ood_tables(mean_table, std_table, id_output_metrics_list, ood_output_metrics_list, shortened_dataset_names, no_std=False,
                          all_output_metrics_list=None, model_names=None, option_names=None):
    """Print the ID and OOD accuracy tables (TSV and LaTeX), each with an average column.

    Args:
        mean_table, std_table: (model, option, metric) arrays whose last axis
            follows the full per-dataset metric lists.
        id_output_metrics_list: per-dataset lists of the ID metrics to keep.
        ood_output_metrics_list: per-dataset lists of the OOD metrics to keep.
        shortened_dataset_names: column headers, including one for the average.
        no_std: when True, standard errors are neither filtered nor printed.
        all_output_metrics_list: full per-dataset metric lists describing the
            last axis of the tables. Defaults to the notebook-level
            `output_metrics_list`.
        model_names: row labels for models. Defaults to the notebook-level
            `shorted_model_names`.
        option_names: row labels for options. Defaults to the notebook-level
            `shortened_options_names`.

    Returns:
        (id_mean_table, ood_mean_table), both including the average column.
    """
    # The original read these three names from notebook globals only. That is still the
    # default, but callers can now pass them explicitly and avoid hidden state.
    if all_output_metrics_list is None:
        all_output_metrics_list = output_metrics_list
    if model_names is None:
        model_names = shorted_model_names
    if option_names is None:
        option_names = shortened_options_names

    id_mean_table = _display_accuracy_tables(
        'ID Accuracies', mean_table, std_table, all_output_metrics_list, id_output_metrics_list,
        model_names, option_names, shortened_dataset_names, no_std)
    ood_mean_table = _display_accuracy_tables(
        'OOD Accuracies', mean_table, std_table, all_output_metrics_list, ood_output_metrics_list,
        model_names, option_names, shortened_dataset_names, no_std)
    return id_mean_table, ood_mean_table
    

# Big table, Living17, Waterbirds, DomainNet, Camelyon, FMoW (Early stopped)


In [38]:
# Main results table: Living-17, Waterbirds, DomainNet, FMoW, Camelyon (early stopped).
models = ['clip_vit_b16', 'clip_vit_l14', 'timm_vit_b16_in21k', 'dino_vit_b16',  'convnext_vit_b', 'bit_resnet_50_in21k', 'bit_resnet_101_in21k']
options = ['', 'opt_torch.optim.AdamW_', 'freeze_bottom_2_']
datasets = ['living17_nonorm', 'waterbirds', 'domainnet', 'fmow_all_nonorm_weakaugs', 'camelyon17_weakaugs']
output_metrics_list = [['test_acc/source_val_living', 'test_acc/target_val_living'], ['WATERBIRDS_VAL', 'WORST'], ['test_acc/sketch_val', 'test_acc/real_val'], ['test_acc/id_val', 'test_acc/ood_val', 'test_acc/ood_test', 'test_acc/africa_test'], ['test_acc/id_val', 'test_acc/ood_val', 'test_acc/ood_test']]
val_metric_list = ['test_acc/source_val_living', 'WATERBIRDS_VAL', 'test_acc/sketch_val', 'test_acc/id_val', 'test_acc/id_val']
aggregation_metric_list = ['test_acc/source_val_living', 'WATERBIRDS_VAL', 'test_acc/sketch_val', 'test_acc/id_val', 'test_acc/id_val']
desired_epochs_list = [20, 20, 50, 5, 3]
shorted_model_names = ['CLIP ViT-B/16', 'CLIP ViT-L/14', 'Sup ViT-B/16', 'DINO ViT-B/16', 'ConvNext-Base', 'BiT ResNet-50', 'BiT ResNet-101', ]
shortened_options_names = ['SGD', 'AdamW', 'SGD (freeze-embed)']
id_output_metrics_list = [['test_acc/source_val_living'], ['WATERBIRDS_VAL'], ['test_acc/sketch_val'], ['test_acc/id_val'], ['test_acc/id_val']]
ood_output_metrics_list = [['test_acc/target_val_living'], ['WORST'], ['test_acc/real_val'], ['test_acc/africa_test'], ['test_acc/ood_test']]
shortened_dataset_names = ['Living-17', 'Waterbirds', 'DomainNet', "FMoW", "Camelyon", "Avg."]


# The cached tables loaded below were produced by these two lines:
# mean_table, std_table = get_table(log_dir, name_template, models, options, datasets, output_metrics_list, val_metric_list, aggregation_metric_list, desired_epochs_list=desired_epochs_list)
# pickle.dump( (mean_table, std_table), open( "big_table_main.pkl", "wb" ) )

# Load the cached tables; the context manager closes the file handle instead of leaking it.
# NOTE: pickle.load can execute arbitrary code - only load pickle files you created yourself.
with open("big_table_main.pkl", "rb") as table_file:
    mean_table, std_table = pickle.load(table_file)
id_mean_table, ood_mean_table = display_id_ood_tables(
    mean_table, std_table, id_output_metrics_list, ood_output_metrics_list, shortened_dataset_names, no_std=True)



ID Accuracies
		Living-17	Waterbirds	DomainNet	FMoW	Camelyon	Avg.
CLIP ViT-B/16	SGD	97.8	97.2	88.8	67.0	99.4	90.0
CLIP ViT-B/16	AdamW	98.1	97.7	95.0	70.1	99.5	92.1
CLIP ViT-B/16	SGD (freeze-embed)	98.2	97.8	94.9	70.0	99.5	92.1
CLIP ViT-L/14	SGD	98.6	99.1	91.7	64.1	99.4	90.6
CLIP ViT-L/14	AdamW	98.7	99.0	91.7	66.4	99.5	91.1
CLIP ViT-L/14	SGD (freeze-embed)	98.7	99.2	91.5	65.0	99.6	90.8
Sup ViT-B/16	SGD	98.4	97.0	88.2	62.4	99.4	89.1
Sup ViT-B/16	AdamW	98.5	97.9	89.4	66.0	99.6	90.3
Sup ViT-B/16	SGD (freeze-embed)	98.4	97.5	89.0	63.5	99.5	89.6
DINO ViT-B/16	SGD	97.4	98.4	89.3	64.6	99.5	89.8
DINO ViT-B/16	AdamW	97.2	98.5	89.2	65.1	99.5	89.9
DINO ViT-B/16	SGD (freeze-embed)	97.6	98.5	89.2	64.8	99.5	89.9
ConvNext-Base	SGD	98.3	98.9	92.0	66.0	99.4	90.9
ConvNext-Base	AdamW	98.4	98.6	91.1	67.0	99.6	90.9
ConvNext-Base	SGD (freeze-embed)	98.4	98.8	91.5	65.9	99.5	90.8
BiT ResNet-50	SGD	98.7	99.0	94.8	66.3	99.4	91.6
BiT ResNet-50	AdamW	98.6	99.5	94.5	68.8	99.7	92.2
BiT ResNet-50	SGD (freeze-embed)

In [None]:
# Column 5 is the appended "Avg." column; average it over models to get one
# ID number and one OOD number per method.
id_avg_per_method = np.mean(id_mean_table[:, :, 5], axis=0)
print(id_avg_per_method)
ood_avg_per_method = np.mean(ood_mean_table[:, :, 5], axis=0)
print(ood_avg_per_method)

# CLIP results including WILDS datasets

In [None]:

# CLIP ViT-B/16 with every optimizer variant, on all five datasets.
print("CLIP Results, all datasets")
models = ['clip_vit_b16']
options = ['', 'opt_torch.optim.AdamW_', 'freeze_bottom_2_', 'layer_wise_tune__', 'opt_torch_optimizer.Lamb_', 'opt_torch_optimizer.LARS_']
val_metric_list = ['test_acc/source_val_living', 'WATERBIRDS_VAL', 'test_acc/sketch_val', 'test_acc/id_val', 'test_acc/id_val']
datasets = ['living17_nonorm', 'waterbirds', 'domainnet', 'fmow_all_nonorm_weakaugs', 'camelyon17_weakaugs']
output_metrics_list = [['test_acc/source_val_living', 'test_acc/target_val_living'], ['WATERBIRDS_VAL', 'WORST'], ['test_acc/sketch_val', 'test_acc/real_val'], ['test_acc/id_val', 'test_acc/ood_val', 'test_acc/ood_test', 'test_acc/africa_test'], ['test_acc/id_val', 'test_acc/ood_val', 'test_acc/ood_test']]
aggregation_metric_list = ['test_acc/source_val_living', 'WATERBIRDS_VAL', 'test_acc/sketch_val', 'test_acc/id_val', 'test_acc/id_val']
desired_epochs_list = [20, 20, 50, 5, 3]
mean_table, std_table = get_table(log_dir, name_template, models, options, datasets, output_metrics_list, val_metric_list, aggregation_metric_list, desired_epochs_list=desired_epochs_list)
shorted_model_names = ['CLIP ViT-B/16']
shortened_options_names = ['SGD', 'AdamW', 'SGD (Freeze-embed)', 'Layer-wise', 'LAMB', 'LARS']
shortened_output_metrics_list = ['Living-17 ID', 'Living-17 OOD', 'Waterbirds ID', 'Waterbirds OOD', 'DomainNet ID', 'DomainNet OOD', "FMoW ID", "FMoW OOD Val", "FMoW OOD Test", "FMoW Africa", "Camelyon ID", "Camelyon OOD Val", "Camelyon OOD Test"]

id_output_metrics_list = [['test_acc/source_val_living'], ['WATERBIRDS_VAL'], ['test_acc/sketch_val'], ['test_acc/id_val'], ['test_acc/id_val']]
ood_output_metrics_list = [['test_acc/target_val_living'], ['WORST'], ['test_acc/real_val'], ['test_acc/africa_test'], ['test_acc/ood_test']]
shortened_dataset_names = ['Liv-17', 'Waterbirds', 'DomainNet', "FMoW", "Camelyon", "Avg."]

# Capture the returned tables, as the other cells do, so the notebook does not
# echo the raw arrays underneath the printed tables.
id_mean_table, ood_mean_table = display_id_ood_tables(
    mean_table, std_table, id_output_metrics_list, ood_output_metrics_list, shortened_dataset_names)


# CLIP Results for ViT-L

In [None]:
# CLIP ViT-L/14 only: Living-17, Waterbirds, DomainNet, FMoW, Camelyon (early stopped).
models = ['clip_vit_l14']
options = ['', 'opt_torch.optim.AdamW_', 'freeze_bottom_2_']
datasets = ['living17_nonorm', 'waterbirds', 'domainnet', 'fmow_all_nonorm_weakaugs', 'camelyon17_weakaugs']
output_metrics_list = [['test_acc/source_val_living', 'test_acc/target_val_living'], ['WATERBIRDS_VAL', 'WORST'], ['test_acc/sketch_val', 'test_acc/real_val'], ['test_acc/id_val', 'test_acc/ood_val', 'test_acc/ood_test', 'test_acc/africa_test'], ['test_acc/id_val', 'test_acc/ood_val', 'test_acc/ood_test']]
val_metric_list = ['test_acc/source_val_living', 'WATERBIRDS_VAL', 'test_acc/sketch_val', 'test_acc/id_val', 'test_acc/id_val']
aggregation_metric_list = ['test_acc/source_val_living', 'WATERBIRDS_VAL', 'test_acc/sketch_val', 'test_acc/id_val', 'test_acc/id_val']
desired_epochs_list = [20, 20, 50, 5, 3]
mean_table, std_table = get_table(log_dir, name_template, models, options, datasets, output_metrics_list, val_metric_list, aggregation_metric_list, desired_epochs_list=desired_epochs_list)
# One display name per entry in `models`. The earlier 7-name list labelled the ViT-L/14 rows
# as 'CLIP ViT-B/16' and made the display helpers index past the single model in the table.
shorted_model_names = ['CLIP ViT-L/14']
shortened_options_names = ['SGD', 'AdamW', 'SGD (Freeze-2)']

id_output_metrics_list = [['test_acc/source_val_living'], ['WATERBIRDS_VAL'], ['test_acc/sketch_val'], ['test_acc/id_val'], ['test_acc/id_val']]
ood_output_metrics_list = [['test_acc/target_val_living'], ['WORST'], ['test_acc/real_val'], ['test_acc/africa_test'], ['test_acc/ood_test']]
shortened_dataset_names = ['Liv-17', 'Waterbirds', 'DomainNet', "FMoW", "Camelyon", "Avg."]

id_mean_table, ood_mean_table = display_id_ood_tables(
    mean_table, std_table, id_output_metrics_list, ood_output_metrics_list, shortened_dataset_names, no_std=True)

In [None]:
# CLIP ViT-B/16 on Waterbirds, sweeping how many bottom layers are frozen.
print("CLIP Results, Waterbirds, varying number of frozen layers")
models = ['clip_vit_b16']
options = ['freeze_bottom_2_full_ft_epoch_50_', 'freeze_bottom_5_full_ft_epoch_50_', 'freeze_bottom_8_full_ft_epoch_50_', 'freeze_bottom_11_full_ft_epoch_50_', 'freeze_bottom_14_full_ft_epoch_50_']
val_metric_list = ['WATERBIRDS_VAL']
datasets = ['waterbirds']
output_metrics_list = [['WATERBIRDS_VAL', 'WORST']]
aggregation_metric_list = ['WATERBIRDS_VAL']
# One entry per dataset. This used to be inherited from whichever cell ran last
# ([20, 20, 50, 5, 3], i.e. 20 for the first dataset); it is now explicit, with the same value.
# NOTE(review): the option names mention "epoch_50" - confirm whether 20 or 50 is expected here.
desired_epochs_list = [20]
mean_table, std_table = get_table('../older_logs/', name_template, models, options, datasets, output_metrics_list, val_metric_list, aggregation_metric_list, desired_epochs_list=desired_epochs_list)
shorted_model_names = ['CLIP ViT-B/16']
# Labels now match the five freeze depths and the single dataset; the previous lists
# were copied from the all-datasets optimizer comparison.
shortened_options_names = ['Freeze-2', 'Freeze-5', 'Freeze-8', 'Freeze-11', 'Freeze-14']
shortened_output_metrics_list = ['Waterbirds ID', 'Waterbirds OOD']

display_tsv_table(shorted_model_names, shortened_options_names, shortened_output_metrics_list, mean_table, std_table)


# Post ICLR Supplementary results

## Big table, including AdamW + Freeze 

In [25]:
# Supplementary main table: Living-17, Waterbirds, DomainNet, FMoW, Camelyon (early stopped),
# with an extra "AdamW (freeze-embed)" row.
models = ['clip_vit_b16', 'clip_vit_l14', 'timm_vit_b16_in21k', 'dino_vit_b16',  'convnext_vit_b', 'bit_resnet_50_in21k', 'bit_resnet_101_in21k']
options = ['', 'opt_torch.optim.AdamW_', 'freeze_bottom_2_', 'freeze_bottom_2_opt_torch.optim.AdamW_']
datasets = ['living17_nonorm', 'waterbirds', 'domainnet', 'fmow_all_nonorm_weakaugs', 'camelyon17_weakaugs']
output_metrics_list = [['test_acc/source_val_living', 'test_acc/target_val_living'], ['WATERBIRDS_VAL', 'WORST'], ['test_acc/sketch_val', 'test_acc/real_val'], ['test_acc/id_val', 'test_acc/ood_val', 'test_acc/ood_test', 'test_acc/africa_test'], ['test_acc/id_val', 'test_acc/ood_val', 'test_acc/ood_test']]
val_metric_list = ['test_acc/source_val_living', 'WATERBIRDS_VAL', 'test_acc/sketch_val', 'test_acc/id_val', 'test_acc/id_val']
aggregation_metric_list = ['test_acc/source_val_living', 'WATERBIRDS_VAL', 'test_acc/sketch_val', 'test_acc/id_val', 'test_acc/id_val']
desired_epochs_list = [20, 20, 50, 5, 3]

# Load the cached 3-option tables; the context manager closes the file handle instead of leaking it.
# NOTE: pickle.load can execute arbitrary code - only load pickle files you created yourself.
with open("big_table_main.pkl", "rb") as table_file:
    mean_table_3, std_table_3 = pickle.load(table_file)
# NOTE(review): the fourth option row is a copy of row 1 (AdamW), not measured data. The
# "AdamW (freeze-embed)" rows printed below therefore repeat the AdamW numbers until the
# commented get_table call is run with the 4-option list.
mean_table = np.concatenate((mean_table_3, mean_table_3[:,1,None,:]), axis=1)
std_table = np.concatenate((std_table_3, std_table_3[:,1,None,:]), axis=1)

# mean_table, std_table = get_table(log_dir, name_template, models, options, datasets, output_metrics_list, val_metric_list, aggregation_metric_list, desired_epochs_list=desired_epochs_list)
shorted_model_names = ['CLIP ViT-B/16', 'CLIP ViT-L/14', 'Sup ViT-B/16', 'DINO ViT-B/16', 'ConvNext-Base', 'BiT ResNet-50', 'BiT ResNet-101']
shortened_options_names = ['SGD', 'AdamW', 'SGD (freeze-embed)', 'AdamW (freeze-embed)']

id_output_metrics_list = [['test_acc/source_val_living'], ['WATERBIRDS_VAL'], ['test_acc/sketch_val'], ['test_acc/id_val'], ['test_acc/id_val']]
ood_output_metrics_list = [['test_acc/target_val_living'], ['WORST'], ['test_acc/real_val'], ['test_acc/africa_test'], ['test_acc/ood_test']]
shortened_dataset_names = ['Living-17', 'Waterbirds', 'DomainNet', "FMoW", "Camelyon", "Avg."]

id_mean_table, ood_mean_table = display_id_ood_tables(
    mean_table, std_table, id_output_metrics_list, ood_output_metrics_list, shortened_dataset_names, no_std=True)



ID Accuracies
		Living-17	Waterbirds	DomainNet	FMoW	Camelyon	Avg.
CLIP ViT-B/16	SGD	97.8	97.2	88.8	67.0	99.4	90.0
CLIP ViT-B/16	AdamW	98.1	97.7	95.0	70.1	99.5	92.1
CLIP ViT-B/16	SGD (freeze-embed)	98.2	97.8	94.9	70.0	99.5	92.1
CLIP ViT-B/16	AdamW (freeze-embed)	98.1	97.7	95.0	70.1	99.5	92.1
CLIP ViT-L/14	SGD	98.6	99.1	91.7	64.1	99.4	90.6
CLIP ViT-L/14	AdamW	98.7	99.0	91.7	66.4	99.5	91.1
CLIP ViT-L/14	SGD (freeze-embed)	98.7	99.2	91.5	65.0	99.6	90.8
CLIP ViT-L/14	AdamW (freeze-embed)	98.7	99.0	91.7	66.4	99.5	91.1
Sup ViT-B/16	SGD	98.4	97.0	88.2	62.4	99.4	89.1
Sup ViT-B/16	AdamW	98.5	97.9	89.4	66.0	99.6	90.3
Sup ViT-B/16	SGD (freeze-embed)	98.4	97.5	89.0	63.5	99.5	89.6
Sup ViT-B/16	AdamW (freeze-embed)	98.5	97.9	89.4	66.0	99.6	90.3
DINO ViT-B/16	SGD	97.4	98.4	89.3	64.6	99.5	89.8
DINO ViT-B/16	AdamW	97.2	98.5	89.2	65.1	99.5	89.9
DINO ViT-B/16	SGD (freeze-embed)	97.6	98.5	89.2	64.8	99.5	89.9
DINO ViT-B/16	AdamW (freeze-embed)	97.2	98.5	89.2	65.1	99.5	89.9
ConvNext-Base	SGD	98.3	98.9	92.0

In [12]:
# Shape check for appending a copy of option row 1. The result is bound to a new name so
# that re-running this cell does not keep adding rows to mean_table.
mean_table_extended = np.concatenate((mean_table, mean_table[:,1,None,:]), axis=1)
mean_table_extended.shape

(7, 5, 13)

In [None]:
# Column 5 is the appended "Avg." column; average it over models to get one
# ID number and one OOD number per method.
id_avg_per_method = np.mean(id_mean_table[:, :, 5], axis=0)
print(id_avg_per_method)
ood_avg_per_method = np.mean(ood_mean_table[:, :, 5], axis=0)
print(ood_avg_per_method)

## ResNet-50 all results 

In [None]:
# Supervised ResNet-50: SGD / AdamW with and without freezing, on all five datasets.
models = ['sup_resnet50']
options = [
    '',
    'opt_torch.optim.AdamW_',
    'freeze_bottom_2_',
    'freeze_bottom_2_opt_torch.optim.AdamW_',
]
datasets = [
    'living17_nonorm',
    'waterbirds',
    'domainnet',
    'fmow_all_nonorm_weakaugs',
    'camelyon17_weakaugs',
]
# One list of logged metrics per dataset, in the order of `datasets`.
output_metrics_list = [
    ['test_acc/source_val_living', 'test_acc/target_val_living'],
    ['WATERBIRDS_VAL', 'WORST'],
    ['test_acc/sketch_val', 'test_acc/real_val'],
    ['test_acc/id_val', 'test_acc/ood_val', 'test_acc/ood_test', 'test_acc/africa_test'],
    ['test_acc/id_val', 'test_acc/ood_val', 'test_acc/ood_test'],
]
val_metric_list = [
    'test_acc/source_val_living',
    'WATERBIRDS_VAL',
    'test_acc/sketch_val',
    'test_acc/id_val',
    'test_acc/id_val',
]
aggregation_metric_list = [
    'test_acc/source_val_living',
    'WATERBIRDS_VAL',
    'test_acc/sketch_val',
    'test_acc/id_val',
    'test_acc/id_val',
]
desired_epochs_list = [20, 20, 50, 5, 3]
mean_table, std_table = get_table(
    log_dir,
    name_template,
    models,
    options,
    datasets,
    output_metrics_list,
    val_metric_list,
    aggregation_metric_list,
    desired_epochs_list=desired_epochs_list,
)
shorted_model_names = ['ResNet50']
shortened_options_names = ['SGD', 'AdamW', 'SGD (Freeze-2)', 'AdamW (Freeze-2)']

# Metrics shown in the ID and OOD tables (one per dataset), plus their column headers.
id_output_metrics_list = [
    ['test_acc/source_val_living'],
    ['WATERBIRDS_VAL'],
    ['test_acc/sketch_val'],
    ['test_acc/id_val'],
    ['test_acc/id_val'],
]
ood_output_metrics_list = [
    ['test_acc/target_val_living'],
    ['WORST'],
    ['test_acc/real_val'],
    ['test_acc/africa_test'],
    ['test_acc/ood_test'],
]
shortened_dataset_names = ['Liv-17', 'Waterbirds', 'DomainNet', "FMoW", "Camelyon", "Avg."]

id_mean_table, ood_mean_table = display_id_ood_tables(
    mean_table,
    std_table,
    id_output_metrics_list,
    ood_output_metrics_list,
    shortened_dataset_names,
    no_std=True,
)

## CLIP weight decay, no momentum, freeze only embed 

In [None]:

# CLIP ViT-B/16 ablations: optimizer variants, freezing only the embedding layer,
# SGD without momentum, and SGD with weight decay.
print("CLIP Results, all datasets")
models = ['clip_vit_b16']
# An earlier version of this list also had 'freeze_bottom_2_optimizer.args.momentum-0.0_'
# ('SGD (Freeze, no momentum)'); it is currently excluded.
options = ['', 'opt_torch.optim.AdamW_', 'freeze_bottom_2_', 'layer_wise_tune__', 'opt_torch_optimizer.Lamb_', 'opt_torch_optimizer.LARS_', 'freeze_bottom_1_', 'optimizer.args.momentum-0.0_', 'optimizer.args.weight_decay-0.01_']
val_metric_list = ['test_acc/source_val_living', 'WATERBIRDS_VAL', 'test_acc/sketch_val', 'test_acc/id_val', 'test_acc/id_val']
datasets = ['living17_nonorm', 'waterbirds', 'domainnet', 'fmow_all_nonorm_weakaugs', 'camelyon17_weakaugs']
output_metrics_list = [['test_acc/source_val_living', 'test_acc/target_val_living'], ['WATERBIRDS_VAL', 'WORST'], ['test_acc/sketch_val', 'test_acc/real_val'], ['test_acc/id_val', 'test_acc/ood_val', 'test_acc/ood_test', 'test_acc/africa_test'], ['test_acc/id_val', 'test_acc/ood_val', 'test_acc/ood_test']]
aggregation_metric_list = ['test_acc/source_val_living', 'WATERBIRDS_VAL', 'test_acc/sketch_val', 'test_acc/id_val', 'test_acc/id_val']
desired_epochs_list = [20, 20, 50, 5, 3]
mean_table, std_table = get_table(log_dir, name_template, models, options, datasets, output_metrics_list, val_metric_list, aggregation_metric_list, desired_epochs_list=desired_epochs_list)
shorted_model_names = ['CLIP ViT-B/16']
# Display names, in the same order as `options`.
shortened_options_names = ['SGD', 'AdamW', 'SGD (Freeze-embed)', 'Layer-wise', 'LAMB', 'LARS', 'Freeze only embed', 'SGD (no momentum)', 'SGD (weight decay)']
shortened_output_metrics_list = ['Living-17 ID', 'Living-17 OOD', 'Waterbirds ID', 'Waterbirds OOD', 'DomainNet ID', 'DomainNet OOD', "FMoW ID", "FMoW OOD Val", "FMoW OOD Test", "FMoW Africa", "Camelyon ID", "Camelyon OOD Val", "Camelyon OOD Test"]

id_output_metrics_list = [['test_acc/source_val_living'], ['WATERBIRDS_VAL'], ['test_acc/sketch_val'], ['test_acc/id_val'], ['test_acc/id_val']]
ood_output_metrics_list = [['test_acc/target_val_living'], ['WORST'], ['test_acc/real_val'], ['test_acc/africa_test'], ['test_acc/ood_test']]
shortened_dataset_names = ['Liv-17', 'Waterbirds', 'DomainNet', "FMoW", "Camelyon", "Avg."]

# Capture the returned tables, as the other cells do, so the notebook does not
# echo the raw arrays underneath the printed tables.
id_mean_table, ood_mean_table = display_id_ood_tables(
    mean_table, std_table, id_output_metrics_list, ood_output_metrics_list, shortened_dataset_names, no_std=True)
