In [2]:
import os
from pathlib import Path

# Run the analysis from the parent of the directory the kernel started in,
# so relative paths such as "outputs/..." resolve there.
# Remember the starting directory on first execution: a bare os.chdir('..')
# would climb one more level every time this cell is re-run.
if "NOTEBOOK_DIR" not in globals():
    NOTEBOOK_DIR = Path.cwd()
os.chdir(NOTEBOOK_DIR.parent)

In [3]:
from pathlib import Path
import json
import yaml
import pandas as pd
import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt

In [4]:
# Experiment directory and the per-run metrics file name to collect beneath it.
run_dir = Path("outputs/cifar10_resnet32")
METRICS_FILENAME = "0.00_metrics.csv"

# rglob yields paths in filesystem order, which is not guaranteed to be stable;
# sorting makes the row order of the concatenated frame reproducible.
metric_files = sorted(run_dir.rglob(METRICS_FILENAME))

In [5]:
# Stack every per-run metrics CSV into one long frame.
# Fail loudly if nothing was found: pd.concat([]) only says "No objects to concatenate".
if not metric_files:
    raise FileNotFoundError(f"No metric files found under {run_dir.resolve()}")
df = pd.concat([pd.read_csv(m) for m in metric_files], ignore_index=True)

# Drop leftover "Unnamed: N" columns (presumably an index saved by to_csv).
# Reassign instead of inplace=True so the cell reads as a plain transformation.
df = df.drop(columns=df.columns[df.columns.str.contains("unnamed", case=False)])
df.head(5)

Unnamed: 0,test_aug,train_aug,eval_seed,train_seed,test/loss,test/acc
0,blur,gaussian_noise,40,42,1.241636,0.6892
1,none,gaussian_noise,40,42,0.434248,0.8822
2,gaussian_noise,gaussian_noise,40,42,0.441935,0.8807
3,rotation,gaussian_noise,40,42,0.683744,0.8196
4,brightness_constrast,gaussian_noise,40,42,0.454187,0.8773


In [6]:
# Sanity-check the combined frame before summarizing it: concatenating CSVs
# whose columns differ silently fills the gaps with NaN.
EXPECTED_COLUMNS = {"test_aug", "train_aug", "eval_seed", "train_seed", "test/loss", "test/acc"}
missing_columns = EXPECTED_COLUMNS - set(df.columns)
assert not missing_columns, f"metric files are missing columns: {sorted(missing_columns)}"
assert df.notna().all().all(), "NaNs after concat: metric files have inconsistent columns"
df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 110 entries, 0 to 109
Data columns (total 6 columns):
 #   Column      Non-Null Count  Dtype  
---  ------      --------------  -----  
 0   test_aug    110 non-null    object 
 1   train_aug   110 non-null    object 
 2   eval_seed   110 non-null    int64  
 3   train_seed  110 non-null    int64  
 4   test/loss   110 non-null    float64
 5   test/acc    110 non-null    float64
dtypes: float64(2), int64(2), object(2)
memory usage: 5.3+ KB


In [7]:
# Average over repeated runs of each (test_aug, train_aug) pair.
# Every numeric column is averaged, including eval_seed/train_seed (harmless but meaningless).
# numeric_only=True keeps this working if a non-numeric column is ever added to the
# metric files: since pandas 2.0 such columns raise TypeError by default.
group_df = df.groupby(["test_aug", "train_aug"], as_index=False).mean(numeric_only=True)
group_df.head(5)

Unnamed: 0,test_aug,train_aug,eval_seed,train_seed,test/loss,test/acc
0,blur,"['crop', 'flip']",40.0,42.0,1.771309,0.6322
1,blur,['none'],40.0,42.0,1.435793,0.6537
2,blur,blur,40.0,42.0,0.522323,0.8575
3,blur,brightness_constrast,40.0,42.0,1.352724,0.6411
4,blur,crop,40.0,42.0,2.184507,0.6017


In [8]:
# Reshape the long table into a matrix: one row per test-time augmentation,
# one column per training augmentation, each cell the averaged test accuracy
# (stored as a fraction; later cells scale it by 100).
# Raises ValueError if a (test_aug, train_aug) pair occurs more than once.
corr = (
    group_df
    .set_index(["test_aug", "train_aug"])["test/acc"]
    .unstack("train_aug")
)

In [9]:
# Render accuracy as percentages with a shared colour scale across the whole
# table (axis=None). Cells below vmin=80 all saturate at the lowest colour.
# Styler.set_precision was deprecated in pandas 1.3 and removed in 2.0;
# format(precision=...) is its replacement.
(corr * 100).style.background_gradient(vmin=80, vmax=100, axis=None).format(precision=2)

  (corr*100).style.background_gradient(vmin=80, vmax=100, axis=None).set_precision(2)


train_aug,"['crop', 'flip']",['none'],blur,brightness_constrast,crop,flip,gaussian_noise,perspective,rgb_shift,rotation,shift
test_aug,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1
blur,63.22,65.37,85.75,64.11,60.17,61.61,68.92,86.86,62.28,72.94,79.68
brightness_constrast,93.39,87.63,85.82,88.28,92.26,91.22,87.73,91.19,87.75,91.2,91.22
crop,93.34,87.01,85.11,87.07,92.32,91.02,87.41,91.26,86.99,91.48,91.25
flip,93.64,87.64,86.05,88.24,92.17,91.41,87.74,91.24,87.82,91.87,91.45
gaussian_noise,91.26,85.54,85.85,85.98,90.46,88.84,88.07,90.94,85.4,89.74,90.21
none,93.72,88.16,86.33,88.68,92.61,91.4,88.22,91.75,88.33,91.96,91.79
perspective,81.48,78.72,84.82,78.34,78.49,79.5,80.2,91.23,76.3,88.0,89.42
rgb_shift,93.38,87.78,85.75,88.06,92.26,91.33,87.93,91.45,88.34,91.49,91.6
rotation,87.99,81.42,81.1,81.51,86.21,85.54,81.96,87.16,80.87,91.12,86.63
shift,92.08,86.33,85.62,86.5,90.68,89.72,86.49,91.44,85.73,91.78,91.6


In [10]:
# Accuracy gain, in percentage points, of each training augmentation over the
# "['none']" training configuration, compared on the same test augmentation.
# Select the baseline column by label, not by position (iloc[:, 1]). A new
# train_aug label that sorts earlier must not silently change the baseline.
BASELINE_TRAIN_AUG = "['none']"
gain_corr = corr.sub(corr[BASELINE_TRAIN_AUG], axis=0) * 100

# Diverging colormap centred on zero; format(precision=...) replaces the
# deprecated Styler.set_precision (removed in pandas 2.0).
gain_corr.style.background_gradient(cmap='PuOr',
    vmin=-40, vmax=40, axis=None).format(precision=2)

  gain_corr.style.background_gradient(cmap='PuOr',


train_aug,"['crop', 'flip']",['none'],blur,brightness_constrast,crop,flip,gaussian_noise,perspective,rgb_shift,rotation,shift
test_aug,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1
blur,-2.15,0.0,20.38,-1.26,-5.2,-3.76,3.55,21.49,-3.09,7.57,14.31
brightness_constrast,5.76,0.0,-1.81,0.65,4.63,3.59,0.1,3.56,0.12,3.57,3.59
crop,6.33,0.0,-1.9,0.06,5.31,4.01,0.4,4.25,-0.02,4.47,4.24
flip,6.0,0.0,-1.59,0.6,4.53,3.77,0.1,3.6,0.18,4.23,3.81
gaussian_noise,5.72,0.0,0.31,0.44,4.92,3.3,2.53,5.4,-0.14,4.2,4.67
none,5.56,0.0,-1.83,0.52,4.45,3.24,0.06,3.59,0.17,3.8,3.63
perspective,2.76,0.0,6.1,-0.38,-0.23,0.78,1.48,12.51,-2.42,9.28,10.7
rgb_shift,5.6,0.0,-2.03,0.28,4.48,3.55,0.15,3.67,0.56,3.71,3.82
rotation,6.57,0.0,-0.32,0.09,4.79,4.12,0.54,5.74,-0.55,9.7,5.21
shift,5.75,0.0,-0.71,0.17,4.35,3.39,0.16,5.11,-0.6,5.45,5.27
