In [232]:
# Useful starting lines 
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [233]:
import os 
import sys
import time
import copy
from copy import deepcopy
import pickle
import math
import functools 
from IPython.display import display, HTML
import operator
from operator import itemgetter
import pandas as pd
import seaborn as sns
from matplotlib.lines import Line2D
from glob import glob

sns.set(style="darkgrid")
sns.set_context("paper")
pd.set_option('future.no_silent_downcasting',True)

In [234]:
root_path = '/root/autodl-tmp/ttab'
sys.path.append(root_path)

In [235]:
from monitor.tools.show_results import extract_list_of_records, reorder_records, get_pickle_info, summarize_info
from monitor.tools.plot import plot_curve_wrt_time
import monitor.tools.plot_utils as plot_utils

from monitor.tools.utils import dict_parser
from monitor.tools.file_io import load_pickle

In [236]:
def get_stats(experiment:str,conditions={},root_data_path=os.path.join(root_path,  'logs', 'resnet26')):
    # Have a glimpse of experimental results.
    raw_records = get_pickle_info(root_data_path, [experiment])
    attributes = ['model_adaptation_method', 'n_train_steps', 'episodic','lr', 'model_selection_method', 'seed', 'data_names', 'status']
    records = extract_list_of_records(list_of_records=raw_records, conditions=conditions)
    aggregated_results, averaged_records_overall = summarize_info(records, attributes, reorder_on='model_adaptation_method', groupby_on='test-overall-accuracy', larger_is_better=True)
    return aggregated_results,averaged_records_overall 

In [237]:
experiments =glob(os.path.join(os.pardir,'logs',"*","*"))
conditions = {
    # "model_adaptation_method": ["tent"],
    "seed": [2022],
    # "batch_size": [64],
    # "episodic": [False],
    # "n_train_steps": [50],
    # "lr": [0.005],
    # "data_names": ["cifar10_c_deterministic-gaussian_noise-5"],
}

In [238]:
experiments

['../logs/resnet26/tent_cifar10c_online_oracle_model_selection',
 '../logs/resnet26/tent_cifar100c_online_oracle_model_selection',
 '../logs/resnet26/tent_cifar10_1_online_oracle_model_selection',
 '../logs/resnet26/tent_cifar10c_episodic_oracle_model_selection',
 '../logs/resnet26/tent_cifar100c_episodic_oracle_model_selection',
 '../logs/resnet26/tent_cifar10_1_episodic_oracle_model_selection',
 '../logs/resnet26/sar_cifar10c_online_oracle_model_selection',
 '../logs/resnet26/sar_cifar100c_online_oracle_model_selection',
 '../logs/resnet26/sar_cifar10_1_online_oracle_model_selection',
 '../logs/resnet26/sar_cifar10c_episodic_oracle_model_selection',
 '../logs/resnet26/sar_cifar100c_episodic_oracle_model_selection',
 '../logs/resnet26/sar_cifar10_1_episodic_oracle_model_selection',
 '../logs/resnet50/tent_officehome_online_oracle_model_selection',
 '../logs/resnet50/tent_pacs_online_oracle_model_selection',
 '../logs/resnet50/tent_officehome_episodic_oracle_model_selection',
 '../logs

In [239]:
def get_result(seed):
    result = pd.DataFrame(data=0,columns=['CIFAR10-C','CIFAR100-C','CIFAR10.1','OfficeHome','PACS'],index=pd.MultiIndex.from_product([
        ['TENT','SAR'],
        ['episodic','online']
    ],names=['method','protocol']))
    for experiment in experiments:
        aggregated_results,averaged_records_overall = get_stats(os.path.basename(experiment),{"seed":[seed]},os.path.dirname(experiment))
        adaptation_method = 'TENT' if 'tent' in experiment else 'SAR'
        protocol = 'episodic' if 'episodic' in experiment else 'online'
        if 'cifar10c' in experiment:
            dataset = 'CIFAR10-C'
        elif 'cifar100c' in experiment:
            dataset = 'CIFAR100-C'
        elif 'cifar10_1' in experiment:
            dataset = 'CIFAR10.1'
        elif 'officehome' in experiment:
            dataset = 'OfficeHome'
        elif 'pacs' in experiment:
            dataset = 'PACS'
        else:
            raise NotImplementedError('invalid experiment!')
        result.loc[(adaptation_method,protocol),dataset] = aggregated_results.loc[:,'test-overall-accuracy'].agg('mean')
    return 100-result
    
    

In [242]:
result_2022 = get_result(2022)
result_2023 = get_result(2023)
result_2024 = get_result(2024)

we have 45/135 records.
we have 45/135 records.
we have 3/9 records.
we have 45/135 records.


  result.loc[(adaptation_method,protocol),dataset] = aggregated_results.loc[:,'test-overall-accuracy'].agg('mean')
  result.loc[(adaptation_method,protocol),dataset] = aggregated_results.loc[:,'test-overall-accuracy'].agg('mean')
  result.loc[(adaptation_method,protocol),dataset] = aggregated_results.loc[:,'test-overall-accuracy'].agg('mean')


we have 45/135 records.
we have 3/9 records.
we have 45/135 records.
we have 45/135 records.
we have 3/9 records.
we have 45/135 records.
we have 45/135 records.
we have 3/9 records.
we have 36/108 records.
we have 36/108 records.
we have 36/108 records.
we have 36/108 records.


  result.loc[(adaptation_method,protocol),dataset] = aggregated_results.loc[:,'test-overall-accuracy'].agg('mean')
  result.loc[(adaptation_method,protocol),dataset] = aggregated_results.loc[:,'test-overall-accuracy'].agg('mean')


we have 36/108 records.
we have 36/108 records.
we have 36/108 records.
we have 36/108 records.
we have 45/135 records.


  result.loc[(adaptation_method,protocol),dataset] = aggregated_results.loc[:,'test-overall-accuracy'].agg('mean')
  result.loc[(adaptation_method,protocol),dataset] = aggregated_results.loc[:,'test-overall-accuracy'].agg('mean')
  result.loc[(adaptation_method,protocol),dataset] = aggregated_results.loc[:,'test-overall-accuracy'].agg('mean')


we have 45/135 records.
we have 3/9 records.
we have 45/135 records.
we have 45/135 records.
we have 3/9 records.
we have 45/135 records.
we have 45/135 records.
we have 3/9 records.
we have 45/135 records.
we have 45/135 records.
we have 3/9 records.
we have 36/108 records.
we have 36/108 records.
we have 36/108 records.
we have 36/108 records.
we have 36/108 records.
we have 36/108 records.


  result.loc[(adaptation_method,protocol),dataset] = aggregated_results.loc[:,'test-overall-accuracy'].agg('mean')
  result.loc[(adaptation_method,protocol),dataset] = aggregated_results.loc[:,'test-overall-accuracy'].agg('mean')


we have 36/108 records.
we have 36/108 records.
we have 45/135 records.
we have 45/135 records.
we have 3/9 records.


  result.loc[(adaptation_method,protocol),dataset] = aggregated_results.loc[:,'test-overall-accuracy'].agg('mean')
  result.loc[(adaptation_method,protocol),dataset] = aggregated_results.loc[:,'test-overall-accuracy'].agg('mean')
  result.loc[(adaptation_method,protocol),dataset] = aggregated_results.loc[:,'test-overall-accuracy'].agg('mean')


we have 45/135 records.
we have 45/135 records.
we have 3/9 records.
we have 45/135 records.
we have 45/135 records.
we have 3/9 records.
we have 45/135 records.
we have 45/135 records.
we have 3/9 records.
we have 36/108 records.


  result.loc[(adaptation_method,protocol),dataset] = aggregated_results.loc[:,'test-overall-accuracy'].agg('mean')
  result.loc[(adaptation_method,protocol),dataset] = aggregated_results.loc[:,'test-overall-accuracy'].agg('mean')


we have 36/108 records.
we have 36/108 records.
we have 36/108 records.
we have 36/108 records.
we have 36/108 records.
we have 36/108 records.
we have 36/108 records.
