In [1]:
# %%
import os
import sys
sys.path.append(os.path.abspath(os.path.join('..')))
import utils

import pathlib
import argparse
from tensorboardX import SummaryWriter
import logging
from datetime import datetime
import torch 
import mymodels 
import mydataset 
from torch.utils.data import DataLoader
from utils.myfed import *
import yaml

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def set_seed(seed):
    """Seed the PyTorch (CPU and CUDA) and NumPy random generators with ``seed``.

    Also switches cuDNN to deterministic mode and turns off its benchmark
    mode.

    Args:
        seed: Integer seed applied to every generator. Previously this
            argument was ignored and 0 was always used.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)


set_seed(0)

In [3]:
# Notebook kernels do not define __file__; provide a stand-in when it is missing.
if globals().get('__file__') is None:
    __file__ = '04_model_performance.ipynb'

# The experiment configuration lives one directory above this notebook.
parent_path = pathlib.Path("../").resolve()
yamlfilepath = parent_path.joinpath('config.yaml')
# Use a context manager so the config file handle is closed after parsing.
# NOTE(review): yaml.FullLoader is kept as in the original; consider
# yaml.safe_load if config.yaml only holds plain values.
with yamlfilepath.open('r') as config_file:
    args = yaml.load(config_file, Loader=yaml.FullLoader)
args = argparse.Namespace(**args)
# os.environ accepts only strings; a YAML value such as `gpu: 0` parses as int.
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)

# Validate the dataset name before branching on it.
assert args.dataset in ['cifar10', 'cifar100', 'pascal_voc2012']

# Each private dataset is paired with a public dataset name and a class count.
if args.dataset == 'cifar10':
    publicdata = 'cifar100'
    args.N_class = 10
elif args.dataset == 'cifar100':
    publicdata = 'imagenet'
    args.N_class = 100
elif args.dataset == 'pascal_voc2012':
    publicdata = 'mscoco'
    args.N_class = 20

# Expand a leading "~" in the configured data path.
args.datapath = os.path.expanduser(args.datapath)

In [4]:
# Ask the project's dirichlet_datasplit helper for the data split described by
# the config. It returns five values; the second one is not used here.
split_outputs = mydataset.data_cifar.dirichlet_datasplit(
    args,
    privtype=args.dataset,
    publictype=publicdata,
    N_parties=args.N_parties,
    online=not args.oneshot,
    public_percent=args.public_percent,
)
priv_data, _, test_dataset, public_dataset, distill_loader = split_outputs

# Fixed-order (unshuffled) loader over the test split, used by every evaluation below.
test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=args.batchsize,
    shuffle=False,
    num_workers=args.num_workers,
    sampler=None,
)

pascal_voc2012 mscoco
size of public dataset:  (3246, 224, 224, 3) images
size of test dataset:  (5823, 3, 224, 224) images
size of split_arr: (20, 5)
y label : [[1. 0. 0. ... 0. 0. 0.]
 [1. 0. 0. ... 0. 0. 0.]
 [1. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 1. 0. 1.]
 [0. 0. 0. ... 1. 0. 1.]
 [0. 0. 0. ... 1. 0. 1.]]
Party_0 data shape: (809, 3, 224, 224)
Party_1 data shape: (1011, 3, 224, 224)
Party_2 data shape: (1132, 3, 224, 224)
Party_3 data shape: (1162, 3, 224, 224)
Party_4 data shape: (1165, 3, 224, 224)
Public data shape: (3246, 224, 224, 3)
Test data shape: (5823, 3, 224, 224)


In [5]:
# max_label_counts = 0
# for i in range(target.shape[0]):
#     label_counts = np.sum(target[i])
#     if label_counts > max_label_counts:
#         max_label_counts = label_counts
# print(max_label_counts)

In [6]:
import utils

In [7]:
def test(model, test_loader):
    """Evaluate ``model`` on ``test_loader`` and print/return three metrics.

    The model is put in eval mode, run without gradients over every batch,
    and its sigmoid-activated outputs are gathered together with the targets.
    The three metrics are then computed by helpers from the project's
    ``utils`` package.

    Args:
        model: Network to evaluate; called as ``model(images)``.
        test_loader: Iterable yielding ``(images, target, _)`` triples, where
            the third element is ignored.

    Returns:
        Tuple ``(acc, top_k, mAP)`` as produced by ``utils.accuracy``,
        ``utils.multi_label_top_margin_k_accuracy`` (with ``margin=0``) and
        ``utils.compute_mean_average_precision``.

    Raises:
        ValueError: If ``test_loader`` yields no batches.
    """
    model.eval()
    # Follow the device the model already lives on instead of assuming CUDA;
    # fall back to the previous behaviour (CUDA) for parameter-less models.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    m = torch.nn.Sigmoid()
    output_list = []
    target_list = []

    with torch.no_grad():
        for images, target, _ in test_loader:
            output = model(images.to(device))
            output_list.append(m(output).detach().cpu().numpy())
            # Targets are only needed on the host, so skip the GPU round trip.
            target_list.append(target.detach().cpu().numpy())

    if not output_list:
        raise ValueError("test_loader produced no batches; nothing to evaluate")

    output = np.concatenate(output_list, axis=0)
    target = np.concatenate(target_list, axis=0)
    # utils.accuracy returns a one-element sequence; unpack its single value.
    acc, = utils.accuracy(output, target)
    top_k = utils.multi_label_top_margin_k_accuracy(target, output, margin=0)
    mAP, _ = utils.compute_mean_average_precision(target, output)
    print(acc, top_k, mAP)
    return acc, top_k, mAP

from utils.utils import *
# Evaluate the checkpoints model-0.pth ... model-4.pth one after another.
for i in range(5):
    checkpoint_path = f'/home/suncheol/code/FedTest/FedMAD/checkpoints/pascal_voc2012/vit_tiny_patch16_224_multilabel/model-{i}.pth'
    # Build a fresh network for every checkpoint before loading its weights.
    model = mymodels.define_model(
        modelname=args.model_name,
        num_classes=args.N_class,
        pretrained=args.pretrained,
    )
    load_dict(checkpoint_path, model)
    print("i : ", i)
    test(model, test_loader)
    


/home/suncheol/.cache/torch/hub/checkpoints/deit_tiny_patch16_224-a1311bcf.pth
i :  0
0.4271480909038863 0.40786536149750985 0.4673507400351453
/home/suncheol/.cache/torch/hub/checkpoints/deit_tiny_patch16_224-a1311bcf.pth
i :  1
0.515538954719791 0.525845783960158 0.6287089305266559
/home/suncheol/.cache/torch/hub/checkpoints/deit_tiny_patch16_224-a1311bcf.pth
i :  2
0.5877206970715486 0.6065601923407179 0.6709985903519831
/home/suncheol/.cache/torch/hub/checkpoints/deit_tiny_patch16_224-a1311bcf.pth
i :  3
0.5319079512278909 0.5440494590417311 0.6372888444751957
/home/suncheol/.cache/torch/hub/checkpoints/deit_tiny_patch16_224-a1311bcf.pth
i :  4
0.5557035728596803 0.5644856603125536 0.6366725028210375


In [9]:
# Evaluate a single checkpoint from the "slNone" run directory.
checkpoint_path = '/home/suncheol/code/FedTest/FedMAD/checkpoints/pascal_voc2012/vit_tiny_patch16_224_multilabel/a1.0+sd1+e300+b64+lkl+slNone/oneshot_c1_q0.0_n0.0_h0/q0.0_n0.0_ADAM_b64_2e-05_200_1e-05_m0.9_e14_0.6473.pt'
model = mymodels.define_model(modelname=args.model_name, num_classes=args.N_class, pretrained=args.pretrained)
load_dict(checkpoint_path, model)
# Label the output with the checkpoint being evaluated. The previous
# print("i : ", i) reused a loop index left over from an earlier cell, which
# is unrelated to this checkpoint.
print("checkpoint : ", os.path.basename(checkpoint_path))
test(model, test_loader)

/home/suncheol/.cache/torch/hub/checkpoints/deit_tiny_patch16_224-a1311bcf.pth
i :  4
0.647316879697099 0.6963764382620642 0.7292937571144642


(0.647316879697099, 0.6963764382620642, 0.7292937571144642)

In [10]:
# Evaluate a single checkpoint from the "slmha" run directory on a fresh network.
slmha_checkpoint = f'/home/suncheol/code/FedTest/FedMAD/checkpoints/pascal_voc2012/vit_tiny_patch16_224_multilabel/a1.0+sd1+e300+b64+lkl+slmha/oneshot_c1_q0.0_n0.0_h3/q0.0_n0.0_ADAM_b64_2e-05_200_1e-05_m0.9_e14_0.6578.pt'
model = mymodels.define_model(
    modelname=args.model_name,
    num_classes=args.N_class,
    pretrained=args.pretrained,
)
load_dict(slmha_checkpoint, model)
# Leave the call as the cell's last expression so the metric tuple is displayed.
test(model, test_loader)


/home/suncheol/.cache/torch/hub/checkpoints/deit_tiny_patch16_224-a1311bcf.pth
0.6577774506677162 0.6991241628026791 0.7299143802594539


(0.6577774506677162, 0.6991241628026791, 0.7299143802594539)

In [None]:
# Evaluate several checkpoints from the same run directory. The existing
# `model` instance is reused; each load_dict call loads the next checkpoint
# into it, exactly as the three copy-pasted calls did before.
checkpoint_dir = '/home/suncheol/code/FedTest/FedMAD/checkpoints/pascal_voc2012/vit_tiny_patch16_224_multilabel/a1.0+sd1+e300+b64+lkl+slmha/oneshot_c1_q0.0_n0.0_h3'
checkpoint_names = [
    'q0.0_n0.0_ADAM_b64_2e-05_200_1e-05_m0.9_e7_0.6557.pt',
    'q0.0_n0.0_ADAM_b64_2e-05_200_1e-05_m0.9_e7_0.6544.pt',
    'q0.0_n0.0_ADAM_b64_2e-05_200_1e-05_m0.9_e24_0.6505.pt',
]

# Keep every (acc, top_k, mAP) tuple instead of only displaying the last one.
checkpoint_results = {}
for checkpoint_name in checkpoint_names:
    load_dict(f'{checkpoint_dir}/{checkpoint_name}', model)
    checkpoint_results[checkpoint_name] = test(model, test_loader)
checkpoint_results

In [None]:
# Class names, one per output column.
# NOTE(review): assumed to match the model's label order - confirm against the dataset code.
object_categories = ['aeroplane', 'bicycle', 'bird', 'boat',
                     'bottle', 'bus', 'car', 'cat', 'chair',
                     'cow', 'diningtable', 'dog', 'horse',
                     'motorbike', 'person', 'pottedplant',
                     'sheep', 'sofa', 'train', 'tvmonitor']


def _collect_scores(model, loader):
    """Run ``model`` over ``loader`` and return ``(scores, targets)`` as NumPy arrays.

    Mirrors the inference loop of ``test``: eval mode, no gradients, sigmoid
    applied to the model output. ``loader`` must yield ``(images, target, _)``.
    """
    model.eval()
    first_param = next(model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    score_batches, target_batches = [], []
    with torch.no_grad():
        for images, target, _ in loader:
            logits = model(images.to(device))
            score_batches.append(torch.sigmoid(logits).detach().cpu().numpy())
            target_batches.append(target.detach().cpu().numpy())
    return np.concatenate(score_batches, axis=0), np.concatenate(target_batches, axis=0)


# `output` and `target` only exist as locals inside test(), so they are not
# available at notebook scope on a fresh kernel. Compute the scores for the
# currently loaded `model` explicitly instead.
y_score, y_test = _collect_scores(model, test_loader)

# Candidate decision thresholds 0.0, 0.1, ..., 0.9 and the running best.
th_ls = [0.1 * i for i in range(10)]
opt_th = 0
best_acc = 0
def get_metrics(y_test, y_score, th):
    """Binarise ``y_score`` at threshold ``th`` and score it against ``y_test``.

    Scores strictly greater than ``th`` become 1, all others 0.

    Returns:
        Tuple ``(accuracy, precision, recall, f1)`` computed by the
        ``getAccuracy``, ``getPrecision``, ``getRecall`` and ``getF1score``
        helpers, in that order.
    """
    binarised = (y_score > th).astype(int)
    accuracy = getAccuracy(y_test, binarised)
    precision = getPrecision(y_test, binarised)
    recall = getRecall(y_test, binarised)
    f1_score = getF1score(y_test, binarised)
    return accuracy, precision, recall, f1_score

# Sweep the candidate thresholds. The strict '>' keeps the first threshold
# that reaches the best accuracy.
for th in th_ls:
    acc, pre, rec, f1 = get_metrics(y_test, y_score, th)
    if acc > best_acc:
        best_acc = acc
        opt_th = th

# Recompute all four metrics at the selected threshold and report them.
acc, pre, rec, f1 = get_metrics(y_test, y_score, opt_th)
print("opt threshold = {}".format(opt_th))
print("accuracy = {}".format(acc))
print("precision = {}".format(pre))
print("recall = {}".format(rec))
print("f1 score = {}".format(f1))
# best_acc holds the best *accuracy* from the sweep; the old label wrongly said "f1 score".
print("optimal threshold = {}".format(opt_th), "best accuracy = {}".format(best_acc))

plotMultiROCCurve(y_test, y_score)


In [None]:
# Confusion matrix over all classes at the selected threshold, printed and then plotted.
# The boolean mask is rebuilt for each call, as in the original.
ml_cm = multi_label_confusion_matrix(
    y_test,
    np.greater(y_score, opt_th),
    object_categories,
)
print(ml_cm)
plotMultilabelconfusionmatrix(
    y_test,
    np.greater(y_score, opt_th),
    object_categories,
)


In [None]:
# Per-class confusion matrices at the selected threshold; the call stays last
# so the cell displays its result.
thresholded_scores = y_score > opt_th
metrics.multilabel_confusion_matrix(y_test, thresholded_scores)

In [None]:
y_pred = (y_score > opt_th).astype(int)
y_test = y_test.astype(int)
labels = object_categories
cm = metrics.multilabel_confusion_matrix(y_test, y_pred)

# One 2x2 heatmap per class, five per row. Round the row count up so no class
# is dropped when the number of classes is not a multiple of five.
nClasses = len(labels)
n_cols = 5
n_rows = max(1, (nClasses + n_cols - 1) // n_cols)
fig, ax = plt.subplots(n_rows, n_cols, figsize=(10, 8), squeeze=False)
flat_axes = ax.flatten()

# NOTE(review): assumes `metrics` is sklearn.metrics, whose per-class matrix is
# [[TN, FP], [FN, TP]] - rows are the actual class (negative first), columns
# the predicted class. The previous ["True", "False"] labels were in the
# opposite order.
row_names = ["Actual False", "Actual True"]
col_names = ["Pred False", "Pred True"]
for axes, cfs_matrix, label in zip(flat_axes, cm, labels):
    print(label)
    df_cm = pd.DataFrame(cfs_matrix, index=row_names, columns=col_names)
    sns.heatmap(df_cm, annot=True, ax=axes, fmt='g')
    axes.set_title(label)

# Hide any leftover grid cells that have no class to show.
for unused_axes in flat_axes[nClasses:]:
    unused_axes.set_visible(False)

fig.tight_layout()
plt.show()