In [167]:
import os
import re
from collections import defaultdict
from itertools import product
import pickle as pkl

import torch
import pandas as pd
import numpy as np

# Experiment grid: every (method, dataset, network) combination is looked up
# among the log files below.
network_list = ['resnet18', 'vgg16', 'densenet121']
dataset_list = ['cifar10', 'gtsrb', 'imagenet']
method_list  = ['badnet', 'sig', 'ref', 'warp', 'imc', 'uap', 'ulp']

# NOTE(review): absolute, machine-specific path -- consider making this configurable.
logdir = "/scr/songzhu/trojai/uapattack/result"
# os.listdir returns entries in arbitrary order; sort them so that any later
# order-dependent step (e.g. keeping the first of several matching runs) gives
# the same result on every execution.
logs = sorted(os.listdir(logdir))

In [168]:
# Column-wise accumulator: each key is a future DataFrame column, each list
# receives one entry per retained log file.
res_dict = defaultdict(list)

for method, dataset, network in product(method_list, dataset_list, network_list):

    # Log files whose name starts with "<method>_<dataset>_<network>".
    logfiles = list(filter(re.compile(f'({method}_{dataset}_{network})+.*').match, logs))

    # Minimum attack success rate / clean accuracy a run must reach
    # simultaneously; gtsrb uses stricter values than the other datasets.
    if dataset == 'gtsrb':
        asr_thresh, acc_thresh = 0.95, 0.90
    else:
        asr_thresh, acc_thresh = 0.90, 0.80

    for logfile in logfiles:

        # Each log holds two consecutive pickles: the run config, then the results.
        # NOTE(review): pickle.load executes arbitrary code from the file --
        # only read logs produced by trusted runs.
        with open(os.path.join(logdir, logfile), 'rb') as f:
            config = pkl.load(f)
            result = pkl.load(f)

        # Skip runs configured with fewer than 5 epochs.
        n_epochs = config['train'][config['args']['dataset']]['N_EPOCHS']
        if n_epochs < 5:
            continue

        res_dict['dataset'].append(dataset)
        res_dict['network'].append(network)
        res_dict['method'].append(method)

        res_dict['seed'].append(config['args']['seed'])
        res_dict['use_clip'].append(config['train']['USE_CLIP'])
        res_dict['use_transform'].append(config['train']['USE_TRANSFORM'])
        res_dict['use_advtrain'].append(config['adversarial']['ADV_TRAIN'])
        res_dict['use_pretrain'].append(config['network']['PRETRAINED'])

        # Best value observed over the recorded history of each metric.
        res_dict['acc'].append(max(result['test_clean_acc']))
        res_dict['asr'].append(max(result['test_troj_acc']))

        # Positions where both thresholds are met at the same time
        # (presumably one history entry per epoch -- confirm against the trainer).
        cond = (np.array(result['test_troj_acc']) >= asr_thresh) & (np.array(result['test_clean_acc']) >= acc_thresh)
        hits = np.where(cond)[0]

        # 't' is the first such position, or n_epochs when the run never gets there.
        res_dict['t'].append(hits[0] if hits.size > 0 else n_epochs)

In [169]:
# Turn the collected records into a DataFrame, drop exact duplicate rows, and
# reduce to a single row per (dataset, network, method, seed) by taking the
# first value of every remaining column within each group.
res_dict = (
    pd.DataFrame(res_dict)
    .drop_duplicates()
    .groupby(by=['dataset', 'network', 'method', 'seed'])
    .first()
    .reset_index()
)

### No Clip, No Transform, No AdvTrain, No PreTrain

In [170]:
# Baseline runs only: clipping, transforms, adversarial training and
# pretraining are all switched off.
baseline_runs = res_dict.loc[((~res_dict['use_clip']) & (~res_dict['use_transform']) & (~res_dict['use_advtrain']) & (~res_dict['use_pretrain']))]

# Mean and standard deviation of each metric per (dataset, network, method).
# The result has two-level columns: (metric, 'mean') and (metric, 'std').
agg_dict = (
    baseline_runs[['dataset', 'network', 'method', 'acc', 'asr', 't']]
    .groupby(by=['dataset', 'network', 'method'])
    .agg(func=['mean', 'std'])
)

# Render every metric as "mean $\pm$ std" (3 decimals) for the LaTeX table.
# The std is NaN for groups that contain a single run.
for label, metric in (('ACC', 'acc'), ('ASR', 'asr'), ('T', 't')):
    agg_dict[label] = (
        agg_dict[metric]['mean'].round(3).astype('str')
        + r' $\pm$ '
        + agg_dict[metric]['std'].round(3).astype('str')
    )

# Flatten the two-level header on an explicit copy: assigning `.columns` on the
# temporary returned by `agg_dict[[...]]` would be discarded and have no effect.
agg_table = agg_dict[['ACC', 'ASR', 'T']].copy()
agg_table.columns = ['ACC', 'ASR', 'T']
agg_table

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,ACC,ASR,T
dataset,network,method,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
cifar10,resnet18,badnet,0.843 $\pm$ 0.016,0.999 $\pm$ 0.0,19.0 $\pm$ 1.732
cifar10,resnet18,imc,0.895 $\pm$ 0.043,0.999 $\pm$ 0.0,9.5 $\pm$ 2.121
cifar10,resnet18,ref,0.843 $\pm$ 0.012,0.999 $\pm$ 0.0,17.0 $\pm$ 1.414
cifar10,resnet18,sig,0.866 $\pm$ 0.011,0.999 $\pm$ 0.0,13.0 $\pm$ 3.606
cifar10,resnet18,ulp,0.841 $\pm$ 0.016,0.999 $\pm$ 0.0,17.0 $\pm$ 1.414
cifar10,resnet18,warp,0.796 $\pm$ nan,0.999 $\pm$ nan,200.0 $\pm$ nan
cifar10,vgg16,badnet,0.889 $\pm$ 0.0,0.999 $\pm$ 0.0,7.333 $\pm$ 3.215
cifar10,vgg16,imc,0.886 $\pm$ 0.003,0.999 $\pm$ 0.0,2.0 $\pm$ 0.0
cifar10,vgg16,ref,0.887 $\pm$ 0.002,0.999 $\pm$ 0.0,4.5 $\pm$ 0.707
cifar10,vgg16,sig,0.888 $\pm$ 0.002,0.999 $\pm$ 0.0,4.667 $\pm$ 2.887
