In [1]:
import json
import numpy as np
import pandas as pd
from pathlib import Path
import seaborn as sns
from tqdm import tqdm
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import matplotlib.ticker as ticker
from scipy import stats

%matplotlib inline

In [2]:
common_corruptions = ['gaussian_noise', 
                      'shot_noise', 
                      'impulse_noise', 
                      'defocus_blur', 
                      'glass_blur', 
                      'motion_blur',
                      'zoom_blur',
                      'snow',
                      'frost',
                      'fog',
                      'brightness',
                      'contrast', 
                      'elastic_transform', 
                      'pixelate', 
                      'jpeg_compression']

In [3]:
# exp_path = Path('./')
exp_path = Path('./exps/poem/ablation')

# exp_type = 'in_dist'
# exp_type = 'natural_shift'
# exp_type = 'bs1'


methods_order = ['no_adapt', 'tent', 'cotta', 'eata', 'sar', 'poem']  # add other methods in desired order


df = pd.DataFrame()
for p in exp_path.glob('**/*.csv'):
    curr_df = pd.read_csv(p)
    df = pd.concat([df, curr_df], ignore_index=True)

no_adapt_df_vit = pd.read_csv('/home/yarinbar/poem/exps/no_adapt/bs1/2024-06-26_21:41:34--vitbase_timm-level5-seed2024_069de6.csv')
no_adapt_df_res = pd.read_csv('/home/yarinbar/poem/exps/no_adapt/bs1/2024-06-26_21:42:08--resnet50_gn_timm-level5-seed2024_debba1.csv')
df = pd.concat([df, no_adapt_df_vit, no_adapt_df_res], ignore_index=True)

# Convert the 'timestamp' column to datetime
df['timestamp'] = pd.to_datetime(df['timestamp'], format='%Y-%m-%d_%H:%M:%S')
df['severity_list'] = df['severity_list'].apply(lambda x: ','.join(map(str, x)) if isinstance(x, list) else x)

# Setting the categorical type with the specified order
df['method'] = pd.Categorical(df['method'], categories=methods_order, ordered=True)
df = df.sort_values('method')



df['sort_key'] = df['method'] == 'poem'
    # Then sort by this key (False values first, then True values)
df = df.sort_values('sort_key')
# df.columns


In [4]:
level = 5

filtered_df = df[
    # (df['model'] == 'vitbase_timm') &
    # (df['model'] == 'resnet50_gn_timm') &
    (df['level'] == level)
]
filtered_df = filtered_df.drop(columns=['u_before', 'u_after', 'martingales'])
filtered_df.loc[filtered_df['method'] == 'no_adapt', 'vanilla_loss'] = True

filtered_df['corruption'] = pd.Categorical(filtered_df['corruption'], categories=common_corruptions, ordered=True)


filtered_df.columns

Index(['top1', 'top5', 'ece', 'model_delta', 'data', 'data_corruption',
       'v2_path', 'output', 'seed', 'gpu', 'debug', 'workers',
       'test_batch_size', 'if_shuffle', 'level', 'corruption', 'fisher_size',
       'fisher_alpha', 'e_margin', 'd_margin', 'method', 'model', 'exp_type',
       'cont_size', 'severity_list', 'temp', 'exp_comment', 'sar_margin_e0',
       'imbalance_ratio', 'gamma', 'eps_clip', 'lr_factor', 'vanilla_loss',
       'device', 'adapt', 'timestamp', 'exp_name', 'lr', 'print_freq',
       'sort_key'],
      dtype='object')

In [5]:
stats = filtered_df.groupby(['model', 'method', 'vanilla_loss']).agg(
        mean_top1=('top1', 'mean'),
        sem_top1=('top1', 'sem'),
        mean_top5=('top5', 'mean'),
        sem_top5=('top5', 'sem'),
        mean_ece=('ece', 'mean'),
        sem_ece=('ece', 'sem'),
        mean_model_delta=('model_delta', 'mean'),
        sem_model_delta=('model_delta', 'sem')
    )

pivot = filtered_df.pivot_table(
    index=['method', 'vanilla_loss'], 
    columns='model', 
    values='top1', 
    # aggfunc='first'
)


s = pivot.to_latex(
        escape=False,
        index=True,
        caption='Mean Top-1 Accuracy by Method and Model across All Corruptions',
        label='tab:mean_accuracy',
        position='ht',
        float_format="{:0.2f}".format
    )
print(s)
pivot


\begin{table}[ht]
\centering
\caption{Mean Top-1 Accuracy by Method and Model across All Corruptions}
\label{tab:mean_accuracy}
\begin{tabular}{llrr}
\toprule
     & model &  resnet50_gn_timm &  vitbase_timm \\
method & vanilla_loss &                   &               \\
\midrule
no_adapt & True &             31.46 &         51.65 \\
poem & False &             32.49 &         60.64 \\
     & True &             38.92 &         67.36 \\
\bottomrule
\end{tabular}
\end{table}



Unnamed: 0_level_0,model,resnet50_gn_timm,vitbase_timm
method,vanilla_loss,Unnamed: 2_level_1,Unnamed: 3_level_1
no_adapt,True,31.464534,51.646577
poem,False,32.486044,60.637049
poem,True,38.923395,67.357689


In [27]:

stats = filtered_df.groupby(['model', 'method', 'corruption', 'vanilla_loss']).agg(
        mean_top1=('top1', 'mean'),
        sem_top1=('top1', 'sem'),
        mean_top5=('top5', 'mean'),
        sem_top5=('top5', 'sem'),
        mean_ece=('ece', 'mean'),
        sem_ece=('ece', 'sem'),
        mean_model_delta=('model_delta', 'mean'),
        sem_model_delta=('model_delta', 'sem')
    )

pivot = filtered_df.pivot_table(
    index=['model', 'method', 'vanilla_loss'], 
    columns='corruption', 
    values='top1', 
    # aggfunc='first'
)


s = pivot.to_latex(
        escape=False,
        index=True,
        caption='Mean Top-1 Accuracy by Method and Model across All Corruptions',
        label='tab:mean_accuracy',
        position='ht',
        float_format="{:0.2f}".format
    )
print(s)
pivot

\begin{table}[ht]
\centering
\caption{Mean Top-1 Accuracy by Method and Model across All Corruptions}
\label{tab:mean_accuracy}
\begin{tabular}{lllrrrrrrrrrrrrrrr}
\toprule
             &      & corruption &  gaussian_noise &  shot_noise &  impulse_noise &  defocus_blur &  glass_blur &  motion_blur &  zoom_blur &  snow &  frost &   fog &  brightness &  contrast &  elastic_transform &  pixelate &  jpeg_compression \\
model & method & vanilla_loss &                 &             &                &               &             &              &            &       &        &       &             &           &                    &           &                   \\
\midrule
resnet50_gn_timm & no_adapt & True &           22.22 &       23.14 &          22.12 &         19.96 &       11.53 &        21.42 &      24.88 & 40.43 &  47.07 & 33.84 &       68.81 &     36.29 &              18.58 &     29.09 &             52.58 \\
             & poem & False &           28.77 &       31.54 &          29.92 &

Unnamed: 0_level_0,Unnamed: 1_level_0,corruption,gaussian_noise,shot_noise,impulse_noise,defocus_blur,glass_blur,motion_blur,zoom_blur,snow,frost,fog,brightness,contrast,elastic_transform,pixelate,jpeg_compression
model,method,vanilla_loss,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1
resnet50_gn_timm,no_adapt,True,22.221333,23.143999,22.122667,19.962667,11.525333,21.424,24.882668,40.431999,47.069332,33.837334,68.805336,36.293335,18.576,29.093334,52.578667
resnet50_gn_timm,poem,False,28.7664,31.539733,29.915733,17.096534,10.909333,26.024,28.107733,35.548801,40.892266,5.085867,69.787199,43.081599,16.205866,48.639999,55.2448
resnet50_gn_timm,poem,True,39.844889,42.24,41.030222,18.988889,22.052444,37.881777,36.395556,21.912889,41.248,20.188,71.947556,50.429333,8.765778,55.868,57.970666
vitbase_timm,no_adapt,True,49.674667,50.312,49.967999,42.807999,34.386665,50.533333,44.741333,56.634666,52.304001,56.685333,75.834663,31.864,46.992001,65.618668,66.341331
vitbase_timm,poem,False,56.026666,57.192889,56.92311,55.575556,49.620444,58.577333,54.682667,63.085334,59.824889,65.501778,77.159111,59.440888,57.462667,70.119556,68.481778
vitbase_timm,poem,True,60.956889,62.623556,62.479111,60.154221,60.724444,65.275555,63.410222,70.063555,68.600889,73.435555,79.539111,63.728889,70.632,75.456443,73.485334


In [46]:
import pandas as pd
import numpy as np

# Read the CSV data


print("\nShape of filtered dataframe:")
print(df_filtered.shape)

print("\nUnique methods in filtered dataframe:")
print(df_filtered['method'].unique())

# Check if 'vanilla_loss' column exists
if 'vanilla_loss' in df_filtered.columns:
    print("\n'vanilla_loss' column exists")
    print("Unique values in 'vanilla_loss':")
    print(df_filtered['vanilla_loss'].unique())
else:
    print("\n'vanilla_loss' column does not exist")

df_filtered = df_filtered.sort_values(['corruption', 'method', 'vanilla_loss'])

# Create a pivot table
pivot = df_filtered.pivot_table(
    index=['method', 'vanilla_loss'], 
    columns='corruption', 
    values='top1', 
    aggfunc='first'
)

print("\nPivot table:")
print(pivot)


Shape of filtered dataframe:
(150, 43)

Unique methods in filtered dataframe:
['no_adapt', 'poem']
Categories (6, object): ['no_adapt' < 'tent' < 'cotta' < 'eata' < 'sar' < 'poem']

'vanilla_loss' column exists
Unique values in 'vanilla_loss':
[nan False True]

Pivot table:
corruption           brightness   contrast  defocus_blur  elastic_transform  \
method vanilla_loss                                                           
poem   False          82.288002  80.250664     55.522667          78.575996   
       True           79.608002  61.056000     60.234665          70.613335   

corruption                 fog      frost  gaussian_noise  glass_blur  \
method vanilla_loss                                                     
poem   False         70.389336  77.237335       56.274666   77.624001   
       True          73.549332  68.858665       80.573334   79.930664   

corruption           impulse_noise  jpeg_compression  motion_blur   pixelate  \
method vanilla_loss               