In [1]:
import numpy as np
import summarize_all_results
import importlib
# Reload the already-imported module so that edits made to
# summarize_all_results.py on disk are picked up without restarting the kernel.
importlib.reload(summarize_all_results)

<module 'summarize_all_results' from '/juice/scr/ananya/cifar_experiments/transfer_learning/scripts/summarize_all_results.py'>

In [2]:

def get_table(log_dir, name_template, models, options, datasets, output_metrics_list, val_metric_list=None, aggregation_metric_list=None):
    """Gather summary metrics for every (model, option) pair into one array.

    For each model/option/dataset combination, the run directory is
    ``log_dir + '/' + name_template.format(model=..., option=..., dataset=...)``
    and ``summarize_all_results.get_experiment_aggregate_summary`` is called on
    it. From the returned mean frame, the first row minus its last two columns
    is kept, and the per-dataset pieces are concatenated in dataset order.

    Args:
        log_dir: Prefix joined to the formatted run name with a '/'.
        name_template: Format string, formatted with ``model``, ``option`` and
            ``dataset`` keyword arguments.
        models: Model identifiers (first table axis).
        options: Option identifiers (second table axis).
        datasets: Dataset identifiers, visited in order for every pair.
        output_metrics_list: One list of metric names per dataset, indexed by
            the dataset's position.
        val_metric_list: Optional per-dataset validation metric; 'LAST' is
            used for every dataset when omitted.
        aggregation_metric_list: Optional per-dataset aggregation metric; the
            dataset's validation metric is reused when omitted.

    Returns:
        A float array of shape ``(len(models), len(options), total)``, where
        ``total`` is the summed length of the lists in ``output_metrics_list``.
    """
    total_columns = sum(len(metric_names) for metric_names in output_metrics_list)
    table = np.zeros((len(models), len(options), total_columns))

    for row, model in enumerate(models):
        for col, option in enumerate(options):
            pieces = []
            for pos, dataset in enumerate(datasets):
                run_name = name_template.format(model=model, option=option, dataset=dataset)
                run_dir = log_dir + '/' + run_name
                wanted_metrics = output_metrics_list[pos]
                val_metric = 'LAST' if val_metric_list is None else val_metric_list[pos]
                agg_metric = val_metric if aggregation_metric_list is None else aggregation_metric_list[pos]
                _, mean_frame, _std_frame = summarize_all_results.get_experiment_aggregate_summary(
                    run_dir, val_metric, wanted_metrics, aggregation_metric=agg_metric)
                # Keep the first row and drop its two trailing columns.
                # NOTE(review): presumably those columns are not requested
                # metrics -- confirm against summarize_all_results.
                first_row = mean_frame.to_numpy()[0, :-2]
                pieces.append(first_row)
            table[row, col] = np.concatenate(pieces)
    return table

# Helpers for rendering the results table as TSV and LaTeX tables.
def display_tsv_table(models, options, shortened_output_metrics_list, table, std_err_table=None):
    """Print a results table as tab-separated text.

    The first printed line is a header: two empty leading cells followed by the
    metric names. Each following line holds the model name, the option name and
    one value per metric formatted with one decimal place. When
    ``std_err_table`` is given, the matching entry is appended to every value
    in parentheses, e.g. ``97.8 (0.3)``.

    Args:
        models: Display names for the first axis of ``table``.
        options: Display names for the second axis of ``table``.
        shortened_output_metrics_list: Column headers, one per metric.
        table: Values indexable as ``table[model_idx][option_idx]``, yielding
            the metric values for that row.
        std_err_table: Optional values with the same layout as ``table``,
            indexable as ``std_err_table[model_idx][option_idx][metric_idx]``.

    Raises:
        IndexError: Propagated from indexing when ``table`` (or
            ``std_err_table``) has fewer entries than ``models``/``options``,
            e.g. when it was built from different model or option lists.
    """
    print('\t\t' + '\t'.join(shortened_output_metrics_list))
    for model_idx, model in enumerate(models):
        for option_idx, option in enumerate(options):
            cells = [model, option]
            for idx, val in enumerate(table[model_idx][option_idx]):
                cell = '{:.1f}'.format(val)
                if std_err_table is not None:
                    # Index down to the scalar for this metric; formatting the
                    # whole row with '{:.1f}' fails for multi-element rows.
                    cell += ' ({:.1f})'.format(std_err_table[model_idx][option_idx][idx])
                cells.append(cell)
            print('\t'.join(cells))

In [3]:
# Directory prefix and run-name pattern that get_table combines into a run path.
log_dir = '../logs/'
name_template = 'full_ft_{option}{dataset}_{model}'

# First table axis: one entry per model.
models = [
    'clip_vit_b16',
    'timm_vit_b16_in21k',
    'dino_vit_b16',
    'bit_resnet_50_in21k',
    'bit_resnet_101_in21k',
    'convnext_vit_b',
    'clip_resnet50',
    'clip_vit_l14',
]

# Second table axis: run-name fragments substituted for {option}.
options = [
    '',
    'opt_torch.optim.AdamW_',
    'freeze_bottom_2_',
    'freeze_bottom_2_opt_torch.optim.AdamW_',
]

datasets = [
    'living17_nonorm',
    'waterbirds',
    'domainnet',
]

# One list of reported metrics per dataset, in the same order as `datasets`.
output_metrics_list = [
    ['test_acc/source_val_living', 'test_acc/target_val_living'],
    ['WATERBIRDS_VAL', 'WORST'],
    ['test_acc/sketch_val', 'test_acc/real_val'],
]

# Per-dataset validation and aggregation metrics passed on to get_table.
val_metric_list = ['LAST'] * len(datasets)
aggregation_metric_list = [
    'test_acc/source_val_living',
    'WATERBIRDS_VAL',
    'test_acc/sketch_val',
]

In [34]:
# Build the full results table, passing both the validation and the
# aggregation metric lists explicitly.
table = get_table(
    log_dir=log_dir,
    name_template=name_template,
    models=models,
    options=options,
    datasets=datasets,
    output_metrics_list=output_metrics_list,
    val_metric_list=val_metric_list,
    aggregation_metric_list=aggregation_metric_list,
)

In [5]:
# Display names; each list must line up with the axis of `table` it labels.
shorted_model_names = [
    'CLIP ViT-B/16',
    'Sup ViT-B/16',
    'DINO ViT-B/16',
    'BIT ResNet-50',
    'BIT ResNet-101',
    'ConvNext-Base',
    'CLIP ResNet-50',
    'CLIP ViT-L/14',
]
shortened_options_names = [
    'SGD',
    'AdamW',
    'SGD (Freeze-2)',
    'AdamW (Freeze-2)',
]
shortened_output_metrics_list = [
    'Living-17 ID', 'Living-17 OOD',
    'Waterbirds ID', 'Waterbirds OOD',
    'DomainNet ID', 'DomainNet OOD',
]
display_tsv_table(
    models=shorted_model_names,
    options=shortened_options_names,
    shortened_output_metrics_list=shortened_output_metrics_list,
    table=table,
)

		Living-17 ID	Living-17 OOD	Waterbirds ID	Waterbirds OOD	DomainNet ID	DomainNet OOD
CLIP ViT-B/16	SGD	71.2	56.5	76.8	0.0	9.0	6.0
CLIP ViT-B/16	AdamW	74.5	63.6	76.8	0.0	9.1	6.4


IndexError: index 2 is out of bounds for axis 0 with size 2

In [36]:
# Switch from 'LAST' to a per-dataset validation metric. No aggregation list is
# passed, so get_table reuses these as the aggregation metrics as well.
val_metric_list = [
    'test_acc/source_val_living',
    'WATERBIRDS_VAL',
    'test_acc/sketch_val',
]
table = get_table(
    log_dir,
    name_template,
    models,
    options,
    datasets,
    output_metrics_list,
    val_metric_list=val_metric_list,
)
display_tsv_table(
    shorted_model_names,
    shortened_options_names,
    shortened_output_metrics_list,
    table=table,
)

		Living-17 ID	Living-17 OOD	Waterbirds ID	Waterbirds OOD	DomainNet ID	DomainNet OOD
CLIP ViT-B/16	SGD	97.8	81.4	97.3	61.7	89.6	74.3
CLIP ViT-B/16	AdamW	98.1	81.5	97.9	69.3	95.0	90.6
CLIP ViT-B/16	SGD (Freeze-2)	98.1	80.7	97.9	73.1	95.2	89.0
CLIP ViT-B/16	AdamW (Freeze-2)	98.3	82.4	97.8	69.5	95.3	88.2
Sup ViT-B/16	SGD	98.6	89.5	99.1	77.4	91.7	86.3
Sup ViT-B/16	AdamW	98.7	88.3	99.0	81.6	91.7	84.4
Sup ViT-B/16	SGD (Freeze-2)	98.7	88.0	99.2	82.4	91.5	86.3
Sup ViT-B/16	AdamW (Freeze-2)	98.6	88.1	99.0	82.4	90.9	82.3
DINO ViT-B/16	SGD	98.4	88.2	97.0	56.1	88.2	76.0
DINO ViT-B/16	AdamW	98.5	87.4	97.9	61.2	89.4	77.4
DINO ViT-B/16	SGD (Freeze-2)	98.4	86.7	97.5	67.9	89.0	78.4
DINO ViT-B/16	AdamW (Freeze-2)	98.4	86.8	97.8	64.5	89.7	76.6
BIT ResNet-50	SGD	97.4	84.3	98.4	76.5	89.3	80.0
BIT ResNet-50	AdamW	97.2	83.1	98.5	74.8	89.2	84.0
BIT ResNet-50	SGD (Freeze-2)	97.6	84.1	98.5	75.5	89.2	82.3
BIT ResNet-50	AdamW (Freeze-2)	97.4	82.9	98.4	77.3	89.1	83.3
BIT ResNet-101	SGD	98.3	82.8	98.9	76.9	92.0	86.

In [9]:
# LAMB and LARS results for CLIP ViT-B/16
models = ['clip_vit_b16']
options = ['opt_torch_optimizer.Lamb_', 'opt_torch_optimizer.LARS_']
val_metric_list = ['test_acc/source_val_living', 'WATERBIRDS_VAL', 'test_acc/sketch_val']
datasets = ['living17_nonorm', 'waterbirds', 'domainnet']
output_metrics_list = [['test_acc/source_val_living', 'test_acc/target_val_living'], ['WATERBIRDS_VAL', 'WORST'], ['test_acc/sketch_val', 'test_acc/real_val']]
val_metric_list = ['test_acc/source_val_living', 'WATERBIRDS_VAL', 'test_acc/sketch_val']
aggregation_metric_list = ['test_acc/source_val_living', 'WATERBIRDS_VAL', 'test_acc/sketch_val']
table = get_table(log_dir, name_template, models, options, datasets, output_metrics_list, val_metric_list)
shorted_model_names = ['CLIP ViT-B/16']
shortened_options_names = ['Lamb', 'LARS']
shortened_output_metrics_list = ['Living-17 ID', 'Living-17 OOD', 'Waterbirds ID', 'Waterbirds OOD', 'DomainNet ID', 'DomainNet OOD']
display_tsv_table(shorted_model_names, shortened_options_names, shortened_output_metrics_list, table)

		Living-17 ID	Living-17 OOD	Waterbirds ID	Waterbirds OOD	DomainNet ID	DomainNet OOD
CLIP ViT-B/16	Lamb	98.2	79.5	97.8	64.0	95.1	90.4
CLIP ViT-B/16	LARS	97.7	83.9	97.1	48.6	93.2	83.8


In [24]:
# NOTE(review): scratch inspection cell. In the cells shown here, `best_row_mean`
# is only assigned inside get_table, so this depends on a notebook-level
# variable created elsewhere -- it will not survive Restart & Run All as is.
# Shows every row of the frame's array with the last two columns dropped.
best_row_mean.to_numpy()[:,:-2]

array([[97.026]], dtype=object)