In [None]:
import os
import re
import wandb
import numpy as np
import pandas as pd

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline
import matplotlib

# Large default font so figure text stays legible when the PDFs are shrunk into a paper.
matplotlib.rcParams['font.size'] = 22

In [None]:
from constants import WANDB_PROJECT

In [None]:
# Every figure produced below is written as a PDF into this folder; create it up front.
PLOT_DIR = './plots/'
os.makedirs(name=PLOT_DIR, exist_ok=True)

## Load Data

In [None]:
%%time
api = wandb.Api(timeout=30)
runs = api.runs(WANDB_PROJECT)

pattern = r"^(.*?)\s(.*?)\s(.*)$"
df = pd.DataFrame()
df_cap = pd.DataFrame()

for run in runs:
    if run.state == 'running' or 'N=' in run.name or '4x' in run.name:
        continue
    
    tmp_df = pd.DataFrame(run.scan_history())
    tmp_df['id'] = run.id
    
    match = re.match(pattern, run.name)
    if match:
        dataset = match.group(1)
        transformation = match.group(2)
        method = match.group(3)
        tmp_df['dataset'] = dataset
        tmp_df['transformation'] = transformation
        tmp_df['method'] = method
        
        df = pd.concat([df, tmp_df])

df = df.drop(columns=['_runtime', '_timestamp'])
df = df.drop(columns=['_step'])

## Accuracy per Rotation Angle and Scale Factor

In [None]:
def scatter(ax, x, y, error, name, ymax=1):
    """Draw one accuracy curve over scale factors with a shaded +/- error band.

    Parameters
    ----------
    ax : Cartesian matplotlib Axes to draw on.
    x : scale factors (x-axis values).
    y : mean accuracy at each ``x``.
    error : half-width of the shaded band around ``y`` (array or scalar broadcastable to ``y``).
    name : legend label for the curve.
    ymax : upper y-axis limit; the lower limit is fixed at 0.
    """
    ax.plot(x, y, label=name, linestyle='-')
    ax.fill_between(x, y + error, y - error, alpha=0.3)
    ax.set_xlabel('Scale Factor')
    ax.set_ylabel('Accuracy')
    ax.set_ylim([0, ymax])

def polar(ax, x, y, error, name, ymax=1, degrees=True):
    """Draw one accuracy curve over rotation angles on a polar Axes with a +/- error band.

    Parameters
    ----------
    ax : matplotlib Axes created with ``projection='polar'``.
    x : rotation angle of each point in ``y``. Interpreted as degrees by default
        (NOTE(review): the logged 'Angle' column is presumed to be in degrees — confirm);
        pass ``degrees=False`` if the angles are already in radians.
    y : mean accuracy at each angle (used as the radius).
    error : half-width of the shaded band around ``y``.
    name : legend label for the curve.
    ymax : radial upper limit; when it equals 0.8 the radial ticks are fixed at 0.2..0.8.
    degrees : whether ``x`` is in degrees (converted to radians) or already in radians.
    """
    # Polar axes take theta in radians. Use the caller's angles rather than a fixed
    # 360-point grid, which ignored `x` and failed whenever len(y) != 360.
    theta = np.asarray(x, dtype=float)
    if degrees:
        theta = np.deg2rad(theta)

    r_upper = y + error
    r_lower = y - error
    ax.plot(theta, y, label=name, linestyle='-')
    ax.fill_between(theta, r_upper, r_lower, alpha=0.3)
    ax.set_ylim([0, ymax])
    if ymax == 0.8:
        ax.set_rticks([0.2, 0.4, 0.6, 0.8])
    # De-emphasise the radial tick labels so they don't compete with the curves.
    for label in ax.get_yticklabels():
        label.set_fontsize(16)
        label.set_color('gray')


# Dispatch table: transformation name (parsed from the run name) -> plotting function.
plot_fun = dict(
    Rotation=polar,
    Scale=scatter,
)

In [None]:
# Transformation -> (logged accuracy column, x-axis column). Transformations not listed are skipped.
PLOT_TARGETS = {
    'Rotation': ('Accuracy per Angle', 'Angle'),
    'Scale': ('Accuracy per Scale Factor', 'Scale Factor'),
}

for ds in df['dataset'].unique():
    # CIFAR10 uses a tighter axis limit (0.8); every other dataset uses [0, 1].
    ymax = 0.8 if ds == 'CIFAR10' else 1

    for tf in df['transformation'].unique():
        print(ds, tf)

        # Guard against unknown transformations: otherwise target_col / fig / ax would be
        # stale values from the previous iteration (or undefined on the first one).
        if tf not in PLOT_TARGETS:
            print(f'  skipped: no plot configuration for {tf!r}')
            continue
        target_col, target_x = PLOT_TARGETS[tf]
        if target_col not in df.columns:
            print(f'  skipped: column {target_col!r} was never logged')
            continue

        # Boolean masks instead of an f-string query, so names containing quotes can't break parsing.
        rows = df[(df['dataset'] == ds) & (df['transformation'] == tf) & df[target_col].notna()]
        if rows.empty:
            # This dataset/transformation pair has no logged values; don't save an empty figure.
            print('  skipped: no data')
            continue

        # Mean and std across runs for every (x, method) pair. std is NaN when a single run
        # contributes; treat that as a zero-width band.
        stats = (
            rows.groupby([target_x, 'method'])[target_col]
            .agg(['mean', 'std'])
            .reset_index()
        )
        stats['std'] = stats['std'].fillna(0)

        if tf == 'Rotation':
            fig, ax = plt.subplots(subplot_kw={'projection': 'polar'}, figsize=(10, 10))
            ax.set_rlabel_position(75)
            ax.set_theta_zero_location('N')
            ax.set_theta_direction(1)
        else:  # 'Scale'
            fig, ax = plt.subplots(figsize=(10, 3))

        # One curve per method; matplotlib's colour cycle assigns colours in this order.
        for m in stats['method'].unique():
            per_method = stats[stats['method'] == m]
            plot_fun[tf](
                ax,
                per_method[target_x].to_numpy(),
                per_method['mean'].to_numpy(),
                per_method['std'].to_numpy(),
                m,
                ymax,
            )

        # Only the CIFAR10 scale figure gets a legend (below the axes); the GTSRB scale
        # figure hides its x axis.
        if ds == 'CIFAR10' and tf == 'Scale':
            ax.legend(loc='upper center', bbox_to_anchor=(0.5, -0.4), fancybox=False, frameon=False, ncol=3)
        if ds == 'GTSRB' and tf == 'Scale':
            ax.get_xaxis().set_visible(False)
        fig.savefig(os.path.join(PLOT_DIR, f'{ds}_{tf}_acc.pdf'), bbox_inches='tight')
        plt.show()
        # Release the figure explicitly; this loop creates one per dataset/transformation.
        plt.close(fig)