In [1]:
# Useful starting lines 
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
%load_ext autoreload
%autoreload 2

In [43]:
import os 
import sys
import time
import copy
from copy import deepcopy
import pickle
import math
import functools 
from IPython.display import display, HTML
import operator
from operator import itemgetter
import pandas as pd
import seaborn as sns
from matplotlib.lines import Line2D
from glob import glob

sns.set(style="darkgrid")
sns.set_context("paper")
pd.set_option('future.no_silent_downcasting',True)

In [3]:
root_path = '/root/autodl-tmp/ttab'
sys.path.append(root_path)

In [4]:
from monitor.tools.show_results import extract_list_of_records, reorder_records, get_pickle_info, summarize_info
from monitor.tools.plot import plot_curve_wrt_time
import monitor.tools.plot_utils as plot_utils

from monitor.tools.utils import dict_parser
from monitor.tools.file_io import load_pickle

In [166]:
def get_stats(experiment:str,conditions={},root_data_path=os.path.join(root_path,  'logs', 'resnet26')):
    # Have a glimpse of experimental results.
    raw_records = get_pickle_info(root_data_path, [experiment])
    attributes = ['model_adaptation_method', 'n_train_steps', 'episodic','lr', 'model_selection_method', 'seed', 'data_names', 'status']
    records = extract_list_of_records(list_of_records=raw_records, conditions=conditions)
    aggregated_results, averaged_records_overall = summarize_info(records, attributes, reorder_on='model_adaptation_method', groupby_on='test-overall-accuracy', larger_is_better=True)
    return aggregated_results,averaged_records_overall 

In [167]:
experiments =glob(os.path.join(os.pardir,'logs',"*","*"))
conditions = {
    # "model_adaptation_method": ["tent"],
    "seed": [2022],
    # "batch_size": [64],
    # "episodic": [False],
    # "n_train_steps": [50],
    # "lr": [0.005],
    # "data_names": ["cifar10_c_deterministic-gaussian_noise-5"],
}

In [168]:
experiments

['../logs/resnet26/tent_cifar10c_online_last_iterate',
 '../logs/resnet26/tent_cifar10_1_online_last_iterate',
 '../logs/resnet26/tent_cifar100c_online_last_iterate',
 '../logs/resnet26/tent_cifar10c_episodic_oracle_model_selection',
 '../logs/resnet26/tent_cifar100c_episodic_oracle_model_selection',
 '../logs/resnet26/tent_cifar10_1_episodic_oracle_model_selection',
 '../logs/resnet50/tent_officehome_online_last_iterate',
 '../logs/resnet50/tent_pacs_online_last_iterate',
 '../logs/resnet50/tent_officehome_episodic_oracle_model_selection',
 '../logs/resnet50/tent_pacs_episodic_oracle_model_selection']

In [173]:
exp = '../logs/resnet26/tent_cifar10c_episodic_oracle_model_selection'
results,overall = get_stats(os.path.basename(exp),conditions,os.path.dirname(exp))
display(results)

we have 45/135 records.


Unnamed: 0,model_adaptation_method,n_train_steps,episodic,lr,model_selection_method,seed,data_names,status,test-overall-accuracy
0,tent,10,True,0.001,oracle_model_selection,2022,cifar10_c_deterministic-gaussian_noise-5,finished,61.49
1,tent,10,True,0.0005,oracle_model_selection,2022,cifar10_c_deterministic-gaussian_noise-5,finished,61.38
2,tent,10,True,0.0001,oracle_model_selection,2022,cifar10_c_deterministic-gaussian_noise-5,finished,60.96
3,tent,10,True,0.001,oracle_model_selection,2022,cifar10_c_deterministic-shot_noise-5,finished,63.6
4,tent,10,True,0.0005,oracle_model_selection,2022,cifar10_c_deterministic-shot_noise-5,finished,63.55
5,tent,10,True,0.0001,oracle_model_selection,2022,cifar10_c_deterministic-shot_noise-5,finished,63.24
6,tent,10,True,0.001,oracle_model_selection,2022,cifar10_c_deterministic-impulse_noise-5,finished,54.81
7,tent,10,True,0.0005,oracle_model_selection,2022,cifar10_c_deterministic-impulse_noise-5,finished,54.64
8,tent,10,True,0.0001,oracle_model_selection,2022,cifar10_c_deterministic-impulse_noise-5,finished,54.25
9,tent,10,True,0.001,oracle_model_selection,2022,cifar10_c_deterministic-defocus_blur-5,finished,83.03


In [172]:
exp = '../logs/resnet26/tent_cifar10c_online_last_iterate'
results,overall = get_stats(os.path.basename(exp),conditions,os.path.dirname(exp))
display(results)

we have 135/405 records.


Unnamed: 0,model_adaptation_method,n_train_steps,episodic,lr,model_selection_method,seed,data_names,status,test-overall-accuracy
0,tent,2,False,0.0050,last_iterate,2022,cifar10_c_deterministic-gaussian_noise-5,finished,69.10
1,tent,1,False,0.0050,last_iterate,2022,cifar10_c_deterministic-gaussian_noise-5,finished,72.12
2,tent,1,False,0.0010,last_iterate,2022,cifar10_c_deterministic-gaussian_noise-5,finished,67.51
3,tent,1,False,0.0005,last_iterate,2022,cifar10_c_deterministic-gaussian_noise-5,finished,64.78
4,tent,3,False,0.0050,last_iterate,2022,cifar10_c_deterministic-gaussian_noise-5,finished,57.30
...,...,...,...,...,...,...,...,...,...
130,tent,2,False,0.0010,last_iterate,2022,cifar10_c_deterministic-jpeg_compression-5,finished,72.98
131,tent,3,False,0.0010,last_iterate,2022,cifar10_c_deterministic-jpeg_compression-5,finished,73.16
132,tent,1,False,0.0005,last_iterate,2022,cifar10_c_deterministic-jpeg_compression-5,finished,68.35
133,tent,2,False,0.0005,last_iterate,2022,cifar10_c_deterministic-jpeg_compression-5,finished,70.22


In [146]:
result = pd.DataFrame(data=0,columns=['CIFAR10-C','CIFAR100-C','CIFAR10.1','OfficeHome','PACS'],index=pd.MultiIndex.from_product([
    ['TENT','SAR'],
    ['episodic','online']
],names=['method','protocol']))

In [164]:
for experiment in experiments:
    aggregated_results,averaged_records_overall = get_stats(os.path.basename(experiment),conditions,os.path.dirname(experiment))
    adaptation_method = 'TENT' if 'tent' in experiment else 'SAR'
    protocol = 'episodic' if 'episodic' in experiment else 'online'
    if 'cifar10c' in experiment:
        dataset = 'CIFAR10-C'
    elif 'cifar100c' in experiment:
        dataset = 'CIFAR100-C'
    elif 'cifar10_1' in experiment:
        dataset = 'CIFAR10.1'
    elif 'officehome' in experiment:
        dataset = 'OfficeHome'
    elif 'pacs' in experiment:
        dataset = 'PACS'
    else:
        raise NotImplementedError('invalid experiment!')
    result.loc[(adaptation_method,protocol),dataset] = aggregated_results.loc[:,'test-overall-accuracy'].agg('mean')
    

we have 135/405 records.
we have 9/27 records.
we have 135/405 records.
we have 45/135 records.
we have 45/135 records.
we have 3/9 records.
we have 108/324 records.
we have 108/324 records.
we have 36/108 records.
we have 15/43 records.


In [165]:
100-result

Unnamed: 0_level_0,Unnamed: 1_level_0,CIFAR10-C,CIFAR100-C,CIFAR10.1,OfficeHome,PACS
method,protocol,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
TENT,episodic,27.182667,55.210667,18.373742,38.718874,15.377517
TENT,online,24.038296,52.484444,18.516686,38.934265,25.103918
SAR,episodic,100.0,100.0,100.0,100.0,100.0
SAR,online,100.0,100.0,100.0,100.0,100.0
