In [None]:
import re
import os
import json
import math
from glob import glob
from collections import defaultdict

import numpy as np
import pandas as pd
import wandb

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline
import matplotlib

# Use a large base font so that the exported figures stay legible.
matplotlib.rcParams['font.size'] = 20

In [None]:
from constants import WANDB_PROJECT

In [None]:
# Directory that receives the exported figures; created up front so that
# savefig calls later in the notebook do not fail on a missing directory.
PLOT_DIR = './plots/'
os.makedirs(PLOT_DIR, exist_ok=True)

In [None]:
%%time
api = wandb.Api(timeout=30)
runs = api.runs(WANDB_PROJECT)

pattern = r"^(.*?)\s(.*?)\s(.*)$"
df = pd.DataFrame()

for run in runs:
    if run.state != 'finished' or '(' in run.name or 'CIFAR10 Rotation' not in run.name or \
     '_n-' in run.name:
        continue

    match = re.match(pattern, run.name)
    if match:
        opt = {k: v['value'] for k, v in json.loads(run.json_config).items()}
        opt['n_params'] = int(opt['orig_n_params']) if 'orig_n_params' in opt else opt['n_params']
        tmp_df = pd.DataFrame(run.scan_history())
        dataset = match.group(1)
        transformation = match.group(2)
        method = match.group(3)
        tmp_df['id'] = run.id
        mean_acc = tmp_df.dropna(subset=['Accuracy per Angle'])['Accuracy per Angle'].mean()
        std_acc = tmp_df.dropna(subset=['Accuracy per Angle'])['Accuracy per Angle'].std()
        df = pd.concat([df, pd.DataFrame({'accuracy': mean_acc, 'std': std_acc, 'dataset': dataset, 'transformation': transformation, 'method': method, 'n_params': opt['n_params']}, index=[0])])

In [None]:
# Sort rows by method name and give the two default SCN variants their
# explicit "N=28" label.
df = (
    df.sort_values(['method'])
      .replace('SCN D=3', 'SCN D=3 N=28')
      .replace('SCN D=5', 'SCN D=5 N=28')
)

# Push the extra methods to the end (in this order) so that every plot
# assigns the same colors to the remaining methods.
for trailing_method in ('DA 4x', 'SCN D=5 N=56', 'SCN D=5 N=114'):
    is_trailing = df['method'].eq(trailing_method)
    df = pd.concat([df.loc[~is_trailing], df.loc[is_trailing]])

In [None]:
def _marker_style(method):
    """Return ``(marker, size)`` used to draw `method` in the accuracy plot.

    'Baseline' and 'Inverse' get their own marker, all 'DA' variants share
    squares and all 'SCN' variants share circles; within a family the size
    distinguishes the variant. Any other method falls back to a diamond so
    that an unexpected name can neither crash the cell nor silently reuse
    the marker of the previously plotted method.
    """
    if method == 'Baseline':
        return '+', 6
    if method == 'Inverse':
        return 'x', 6
    if 'DA' in method:
        if '2' in method:
            return 's', 6
        if '4' in method:
            return 's', 8
        return 's', 4
    if 'SCN' in method:
        if '56' in method:
            return 'o', 6
        if '114' in method:
            return 'o', 8
        return 'o', 4
    return 'D', 6


# One figure per (dataset, transformation) pair that actually has rows.
grouped = df.groupby(['dataset', 'transformation'], sort=False)
# Global method order: the color of a method is tied to its position here,
# so it stays the same in every figure even if a method is missing from one.
all_methods = list(df['method'].unique())

for (ds, tf), group_df in grouped:
    fig, ax = plt.subplots(figsize=(10, 5))

    for color_idx, m in enumerate(all_methods):
        method_df = group_df[group_df['method'] == m]
        if method_df.empty:
            # Nothing to draw for this method in this figure.
            continue

        # Average over all runs of the method (note: 'std' is the mean of the
        # per-run standard deviations, not the spread across runs).
        stats = method_df[['accuracy', 'n_params', 'std']].mean()
        mkr, mkr_size = _marker_style(m)

        ax.errorbar(x=stats['n_params'], y=stats['accuracy'], yerr=stats['std'], label=m,
                    fmt=mkr, color=f'C{color_idx}', capsize=12, markersize=mkr_size * 2)

    ax.set_xscale('log')
    ax.legend(loc='upper center', bbox_to_anchor=(0.5, 1.39), fancybox=False, frameon=False, ncol=3)
    ax.set_xlabel('Number of Parameters')
    ax.set_ylabel('Accuracy')

    # Keep the historical file name when there is a single figure; with several
    # figures, include dataset and transformation so they do not overwrite each other.
    if grouped.ngroups > 1:
        file_name = f'acc_parameters_{ds}_{tf}.pdf'
    else:
        file_name = 'acc_parameters.pdf'
    fig.savefig(os.path.join(PLOT_DIR, file_name), bbox_inches='tight')

    plt.show()
    # Release the figure explicitly; figures created in a loop otherwise accumulate.
    plt.close(fig)