In [95]:
import os
import re
from collections import defaultdict
from itertools import product
import pickle as pkl

import torch
import pandas as pd
import numpy as np

# Experiment axes; the results cell iterates over every (method, dataset, network) combination.
network_list = ['resnet18', 'vgg16', 'densenet121']
dataset_list = ['cifar10', 'gtsrb', 'imagenet']
method_list  = ['badnet', 'sig', 'ref', 'warp', 'imc', 'uap', 'ulp']

# Directory holding the pickled training logs. Override with the UAPATTACK_RESULT_DIR
# environment variable when running somewhere other than the original host; the default
# is the original hard-coded location.
logdir = os.environ.get("UAPATTACK_RESULT_DIR", "/scr/songzhu/trojai/uapattack/result")
# os.listdir order is arbitrary; sort it so the collected rows come out in a reproducible order.
logs = sorted(os.listdir(logdir))

In [114]:



    
# Success thresholds per dataset as (min attack success rate, min clean accuracy).
# CIFAR-10 uses looser thresholds; every other dataset falls back to the defaults.
THRESHOLDS = {'cifar10': (0.90, 0.80)}
DEFAULT_THRESHOLDS = (0.95, 0.90)
# Runs configured for fewer training epochs than this are ignored.
MIN_EPOCHS = 5

# Column name -> list of per-run values; converted to a DataFrame in the next cell.
res_dict = defaultdict(list)

for method, dataset, network in product(method_list, dataset_list, network_list):

    asr_thresh, acc_thresh = THRESHOLDS.get(dataset, DEFAULT_THRESHOLDS)

    # Log files for this combination start with "<method>_<dataset>_<network>".
    # A plain prefix test is equivalent to the former `re.match(f'({prefix})+.*')`
    # and needs no regex escaping.
    prefix = f'{method}_{dataset}_{network}'
    logfiles = [name for name in logs if name.startswith(prefix)]

    for logfile in logfiles:

        # Each file stores two pickled objects in sequence: the run config, then its results.
        # NOTE(review): pickle.load can execute arbitrary code -- only point logdir at trusted files.
        with open(os.path.join(logdir, logfile), 'rb') as f:
            config = pkl.load(f)
            result = pkl.load(f)

        n_epochs = config['train'][config['args']['dataset']]['N_EPOCHS']
        if n_epochs < MIN_EPOCHS:
            continue

        clean_acc = result['test_clean_acc']
        troj_acc = result['test_troj_acc']
        # A run with no recorded evaluations has nothing to summarise (max() would raise).
        if len(clean_acc) == 0 or len(troj_acc) == 0:
            continue

        res_dict['dataset'].append(dataset)
        res_dict['network'].append(network)
        res_dict['method'].append(method)

        res_dict['seed'].append(config['args']['seed'])
        res_dict['use_clip'].append(config['train']['USE_CLIP'])
        res_dict['use_transform'].append(config['train']['USE_TRANSFORM'])
        res_dict['use_advtrain'].append(config['adversarial']['ADV_TRAIN'])
        res_dict['use_pretrain'].append(config['network']['PRETRAINED'])

        # NOTE(review): best clean accuracy and best ASR are maximised independently,
        # so the two numbers may come from different epochs.
        res_dict['acc'].append(max(clean_acc))
        res_dict['asr'].append(max(troj_acc))

        # First evaluation index (0-based) at which both thresholds hold simultaneously;
        # n_epochs is recorded when the run never reaches them.
        cond = (np.asarray(troj_acc) >= asr_thresh) & (np.asarray(clean_acc) >= acc_thresh)
        hits = np.nonzero(cond)[0]
        res_dict['t'].append(hits[0] if hits.size > 0 else n_epochs)

In [115]:
# Summary table: one row per distinct run, ordered by dataset and then by attack method.
(
    pd.DataFrame.from_dict(res_dict)
    .drop_duplicates()
    .sort_values(['dataset', 'method'])
)

Unnamed: 0,dataset,network,method,seed,use_clip,use_transform,use_advtrain,acc,asr,t
0,cifar10,resnet18,badnet,77,False,False,False,0.835316,0.999001,18
2,cifar10,vgg16,badnet,77,False,False,False,0.888811,0.999001,6
15,cifar10,resnet18,imc,77,False,False,False,0.925707,0.999001,8
9,cifar10,resnet18,ref,77,False,False,False,0.833917,0.999001,18
10,cifar10,vgg16,ref,77,False,False,False,0.888011,0.999001,4
5,cifar10,resnet18,sig,77,False,False,False,0.862714,0.999001,12
6,cifar10,vgg16,sig,77,False,False,False,0.888611,0.999001,3
16,cifar10,resnet18,ulp,77,False,False,False,0.829217,0.999001,18
17,cifar10,vgg16,ulp,77,False,False,False,0.888911,0.999001,3
12,cifar10,resnet18,warp,77,False,False,False,0.871775,0.990431,51
