In [1]:
# Make the parent directory the working directory; the relative paths used
# further down (e.g. "outputs/cifar10_resnet32") are resolved against it.
# NOTE: not idempotent -- every re-run of this cell moves up one more level,
# so execute it exactly once per kernel session.
import os
os.chdir('..')

In [2]:
from pathlib import Path
import json
import yaml
import pandas as pd
import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt

In [29]:
# Directory searched for metric files (relative to the current working directory).
run_dir = Path("outputs/cifar10_resnet32")
# Recursively collect every "0.40_metrics.csv" below run_dir. Sorting makes the
# order deterministic (rglob yields entries in file-system order), so anything
# built from this list -- e.g. a concatenated frame -- has a reproducible row order.
metric_files = sorted(run_dir.rglob("0.40_metrics.csv"))

In [30]:
# Stack all per-file metric tables into one frame with a fresh 0..n-1 index.
frames = [pd.read_csv(csv_path) for csv_path in metric_files]
df = pd.concat(frames, ignore_index=True)

# Discard every column whose name contains "unnamed" (case-insensitive match).
unnamed_cols = df.columns[df.columns.str.contains('unnamed', case=False)]
df = df.drop(columns=unnamed_cols)

# Normalise the list-style label "['none']" to the plain string 'none';
# all other labels are kept unchanged.
df['train_aug'] = df['train_aug'].where(df['train_aug'] != "['none']", 'none')

# Leave out the rows whose train_aug is the combined "['crop', 'flip']" label.
# The index is not reset, so it keeps gaps where rows were removed.
keep_rows = df['train_aug'] != "['crop', 'flip']"
df = df[keep_rows]
df.head(5)

Unnamed: 0,test_aug,train_aug,eval_seed,train_seed,test/loss,test/acc
0,blur,gaussian_noise,41,43,1.060982,0.75845
1,none,gaussian_noise,41,43,0.000718,1.0
2,gaussian_noise,gaussian_noise,41,43,0.000915,1.0
3,rotation,gaussian_noise,41,43,0.340328,0.90595
4,brightness_constrast,gaussian_noise,41,43,0.003433,0.99935


In [31]:
# Show column dtypes and non-null counts of the filtered frame.
df.info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 800 entries, 0 to 819
Data columns (total 6 columns):
 #   Column      Non-Null Count  Dtype  
---  ------      --------------  -----  
 0   test_aug    800 non-null    object 
 1   train_aug   800 non-null    object 
 2   eval_seed   800 non-null    int64  
 3   train_seed  800 non-null    int64  
 4   test/loss   800 non-null    float64
 5   test/acc    800 non-null    float64
dtypes: float64(2), int64(2), object(2)
memory usage: 43.8+ KB


In [32]:
# Distinct augmentation labels: train-time first, test-time second.
df["train_aug"].unique(), df["test_aug"].unique()

(array(['gaussian_noise', 'shift', 'perspective', 'blur', 'rgb_shift',
        'rotation', 'crop', 'flip', 'brightness_constrast', 'none'],
       dtype=object),
 array(['blur', 'none', 'gaussian_noise', 'rotation',
        'brightness_constrast', 'rgb_shift', 'shift', 'perspective',
        'flip', 'crop'], dtype=object))

In [33]:
# Average every numeric column over all rows that share the same
# (test_aug, train_aug) pair; as_index=False keeps both keys as regular columns.
# numeric_only=True is spelled out because the default differs between pandas
# versions (pandas >= 2.0 raises on non-numeric columns instead of dropping them).
# Note: eval_seed / train_seed get averaged too; those means carry no meaning.
group_df = df.groupby(["test_aug", "train_aug"], as_index=False).mean(numeric_only=True)
group_df.head(5)

Unnamed: 0,test_aug,train_aug,eval_seed,train_seed,test/loss,test/acc
0,blur,blur,41.25,43.5,0.001187,1.0
1,blur,brightness_constrast,41.25,43.5,1.172791,0.710494
2,blur,crop,41.25,43.5,1.751756,0.669306
3,blur,flip,41.25,43.5,1.494725,0.67605
4,blur,gaussian_noise,41.25,43.5,0.990839,0.765931


In [34]:
# Accuracy matrix: one row per test_aug, one column per train_aug, with the rows
# ordered by ascending value of the 'none' column.
corr = (
    group_df
    .pivot(index="test_aug", columns="train_aug", values="test/acc")
    .sort_values(by='none')
)
# Put the columns in the same order as the sorted rows
# (this requires every row label to exist as a column label as well).
corr = corr[corr.index.tolist()]

In [35]:
# Accuracy in percent, shaded on a fixed 50-100 scale shared by all cells
# (axis=None) and rendered with two decimals.
# Styler.set_precision() is deprecated since pandas 1.3 and removed in 2.0;
# Styler.format(precision=...) is the supported replacement.
(corr*100).style.background_gradient(vmin=50, vmax=100, axis=None).format(precision=2)

  (corr*100).style.background_gradient(vmin=50, vmax=100, axis=None).set_precision(2)


train_aug,blur,perspective,rotation,crop,flip,shift,gaussian_noise,brightness_constrast,rgb_shift,none
test_aug,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
blur,100.0,95.63,80.26,66.93,67.61,87.04,76.59,71.05,68.12,70.86
perspective,96.14,99.83,95.71,85.99,85.35,97.91,88.37,86.11,83.21,85.07
rotation,92.0,95.08,99.84,93.51,92.49,94.36,90.74,90.48,89.73,90.43
crop,94.93,99.73,99.33,100.0,96.95,99.87,95.1,94.98,94.4,94.56
flip,94.07,96.5,96.92,96.89,100.0,96.45,94.75,94.78,94.51,94.57
shift,98.37,99.95,99.77,98.95,97.22,100.0,96.89,96.5,95.72,96.01
gaussian_noise,99.98,99.67,98.65,98.41,98.34,99.4,100.0,98.92,98.49,98.63
brightness_constrast,99.92,99.88,99.84,99.88,99.86,99.89,99.92,100.0,99.91,99.9
rgb_shift,99.99,99.98,99.96,99.97,99.97,99.98,99.98,99.98,100.0,99.98
none,100.0,100.0,100.0,100.0,100.0,100.0,100.0,100.0,100.0,100.0


In [36]:
# Gain in percentage points relative to the 'none' column: subtract that column
# from every column, aligned row by row (vectorised form of the column-wise apply).
gain_corr = corr.sub(corr["none"], axis=0) * 100
# Diverging colormap on a symmetric range so that zero gain sits at the midpoint.
# Styler.format(precision=...) replaces the deprecated Styler.set_precision().
gain_corr.style.background_gradient(cmap='PuOr',
    vmin=-40, vmax=40, axis=None).format(precision=2)

  gain_corr.style.background_gradient(cmap='PuOr',


train_aug,blur,perspective,rotation,crop,flip,shift,gaussian_noise,brightness_constrast,rgb_shift,none
test_aug,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
blur,29.14,24.77,9.4,-3.93,-3.26,16.18,5.73,0.19,-2.74,0.0
perspective,11.07,14.75,10.64,0.92,0.27,12.83,3.29,1.03,-1.86,0.0
rotation,1.57,4.65,9.41,3.08,2.06,3.93,0.31,0.05,-0.7,0.0
crop,0.37,5.16,4.77,5.44,2.39,5.31,0.54,0.42,-0.17,0.0
flip,-0.5,1.92,2.35,2.32,5.43,1.88,0.17,0.21,-0.06,0.0
shift,2.37,3.94,3.76,2.94,1.22,3.99,0.88,0.49,-0.29,0.0
gaussian_noise,1.35,1.04,0.03,-0.21,-0.29,0.77,1.37,0.3,-0.14,0.0
brightness_constrast,0.01,-0.03,-0.07,-0.03,-0.04,-0.01,0.01,0.09,0.01,0.0
rgb_shift,0.01,0.0,-0.02,-0.01,-0.01,0.0,0.0,0.01,0.02,0.0
none,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
