In [None]:
# %%
import argparse
import logging
import math
import os
import pathlib
import sys
from datetime import datetime

import numpy as np
import pandas as pd
import torch
import yaml
from tensorboardX import SummaryWriter
from torch.utils.data import DataLoader

# The project root must be on sys.path before the local imports below.
sys.path.append(os.path.abspath(os.path.join('..')))
import utils
import mymodels
import mydataset
from utils.myfed import *
import yaml

In [None]:
# Seed via the project helper before any data split or model construction below.
# NOTE(review): presumably seeds python/numpy/torch RNGs — confirm in utils.set_seed.
utils.set_seed(0)

In [None]:
# Load the experiment configuration from ../config.yaml into an argparse-style namespace
# and derive the public dataset type and class count from the private dataset name.

# Jupyter does not define `__file__`; set a placeholder in case later code reads it
# (TODO confirm whether anything still does).
if globals().get('__file__') is None:
    __file__ = '04_model_performance.ipynb'

parent_path = pathlib.Path("../").resolve()
yamlfilepath = parent_path.joinpath('config.yaml')
# Context manager so the config file handle is closed after parsing.
with yamlfilepath.open('r') as config_file:
    args = yaml.load(config_file, Loader=yaml.FullLoader)
args = argparse.Namespace(**args)
# os.environ only accepts str; YAML parses a bare `gpu: 0` as an int.
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)

# private dataset -> (public dataset type passed to dirichlet_datasplit, number of classes)
DATASET_SETTINGS = {
    'cifar10': ('cifar100', 10),
    'cifar100': ('imagenet', 100),
    'pascal_voc2012': ('mscoco', 20),
}
# Validate before the lookup so an unknown dataset fails with a clear message.
if args.dataset not in DATASET_SETTINGS:
    raise ValueError(f"unsupported dataset {args.dataset!r}; expected one of {sorted(DATASET_SETTINGS)}")
publicdata, args.N_class = DATASET_SETTINGS[args.dataset]
args.datapath = os.path.expanduser(args.datapath)

In [None]:
# Split the data with the project helper and wrap the returned test split in a loader.
# Only `test_loader` is used later in this notebook; the other split outputs are kept for reference.
# NOTE(review): per its name and `N_parties`, the helper does a Dirichlet split across clients — confirm.
split_kwargs = dict(
    privtype=args.dataset,
    publictype=publicdata,
    N_parties=args.N_parties,
    online=not args.oneshot,
    public_percent=args.public_percent,
)
priv_data, _, test_dataset, public_dataset, distill_loader = mydataset.data_cifar.dirichlet_datasplit(args, **split_kwargs)
test_loader = DataLoader(dataset=test_dataset, batch_size=args.batchsize, shuffle=False,
                         num_workers=args.num_workers, sampler=None)

In [None]:
# max_label_counts = 0
# for i in range(target.shape[0]):
#     label_counts = np.sum(target[i])
#     if label_counts > max_label_counts:
#         max_label_counts = label_counts
# print(max_label_counts)

In [None]:
def test(model, test_loader):
    model.eval()
    # testacc = utils.AverageMeter()
    m = torch.nn.Sigmoid()
    output_list = []
    target_list = []

    with torch.no_grad():
        for i, (images, target, _) in enumerate(test_loader):
            images = images.cuda()
            target = target.cuda()
            output = model(images)
            
            output_list.append(m(output).detach().cpu().numpy())
            target_list.append(target.detach().cpu().numpy())
            
            # testacc.update(acc)
    output = np.concatenate(output_list, axis=0)
    target = np.concatenate(target_list, axis=0)
    acc, = utils.accuracy(output, target)
    top_k = utils.multi_label_top_margin_k_accuracy(target, output, margin=0)
    mAP, _ = utils.compute_mean_average_precision(target, output)
    acc, top_k, mAP = round(acc, 4), round(top_k, 4), round(mAP, 4)
    print(f'Acc: {acc}, Top-k: {top_k}, mAP: {mAP}')
    return {'acc': acc, 'top_k': top_k, 'mAP': mAP}


In [None]:
# Evaluate checkpoints model-0.pth … model-{n_models-1}.pth from one directory and save a metrics table.
ckpt_dir = pathlib.Path('/home/suncheol/code/FedTest/FED_MHAD_sub/checkpoints/pascal_voc2012/vit_tiny_patch16_224_multilabel_clean_1.0_2e-05')
n_models = 5
records = []
for i in range(n_models):
    model = mymodels.define_model(modelname=args.model_name, num_classes=args.N_class, pretrained=args.pretrained)
    utils.load_dict(str(ckpt_dir / f'model-{i}.pth'), model)
    print("i : ", i)
    records.append(test(model, test_loader))
# Build the frame once from the collected dicts instead of growing it row by row with df.loc.
df = pd.DataFrame(records, columns=['acc', 'top_k', 'mAP'])
df.to_csv(ckpt_dir / 'result.csv')
    


In [None]:
# Evaluate the round-100 server checkpoint from the fed_avg experiment (per its path).
# NOTE: this overrides args.model_name for every cell run afterwards as well.
args.model_name = 'vit_tiny_patch16_224'
fedavg_ckpt = ('/home/suncheol/code/FedTest/FED_MHAD/test/04_pascal_voc_fed_avg_multilabel/'
               'checkpoints/27197_server_best_models/model_round100_acc0.76_loss0.00.pth')
model = mymodels.define_model(modelname=args.model_name, num_classes=args.N_class, pretrained=args.pretrained)
utils.load_dict(fedavg_ckpt, model)
test(model, test_loader)

In [None]:
# Evaluate the checkpoint from the separate pytorch-model-multiclass project for comparison.
# NOTE(review): the path suggests a centrally trained (non-federated) model — confirm.
model = mymodels.define_model(modelname=args.model_name, num_classes=args.N_class, pretrained=args.pretrained)
load_dict('/home/suncheol/code/FedTest/pytorch-model-multiclass/checkpoint/pascal_voc_vit_tiny_patch16_224_0.0001_-1_multilabel/ckpt.pth', model)
test(model, test_loader)

In [None]:
# Evaluate every .pth checkpoint directly under dir_path and save a metrics table next to them.
dir_path = "/home/suncheol/code/FedTest/FED_MHAD_sub/checkpoints/pascal_voc2012/vit_tiny_patch16_224_multilabel_clean_1.0_2e-05/a1.0+sd1+e300+b64+lkl+slmha"
# Sorted so the row order is deterministic (os.listdir order is arbitrary).
pt_file_list = sorted(name for name in os.listdir(dir_path) if name.endswith(".pth"))
print(pt_file_list)
records = []
for i, ckpt_name in enumerate(pt_file_list):
    model = mymodels.define_model(modelname=args.model_name, num_classes=args.N_class, pretrained=args.pretrained)
    load_dict(os.path.join(dir_path, ckpt_name), model)
    print("i : ", i)
    # Keep the file name with its metrics so each CSV row can be traced back to a checkpoint.
    records.append({**test(model, test_loader), 'checkpoint': ckpt_name})
df = pd.DataFrame(records, columns=['acc', 'top_k', 'mAP', 'checkpoint'])
df.to_csv(os.path.join(dir_path, 'result.csv'))
df

In [None]:
# Display the current `df` metrics table (whichever evaluation cell last assigned it).
df

In [None]:
# Fresh model + noisy-label run checkpoint, `h2` variant (`e3_0.5793` in the filename).
noisy_h2_ckpt = ('/home/suncheol/code/FedTest/FedMAD/checkpoints/pascal_voc2012/vit_tiny_patch16_224_multilabel_noisy_1.0/'
                 'a1.0+sd1+e300+b64+lkl+slmha/oneshot_c1_q0.0_n0.0_h2/'
                 'q0.0_n0.0_ADAM_b64_2e-05_200_1e-05_m0.9_e3_0.5793.pt')
model = mymodels.define_model(modelname=args.model_name, num_classes=args.N_class, pretrained=args.pretrained)
load_dict(noisy_h2_ckpt, model)
test(model, test_loader)


In [None]:
# Noisy-label run, `h1` variant (`e11_0.5782` in the filename).
# Loads into whatever global `model` a previous cell created.
noisy_run_dir = '/home/suncheol/code/FedTest/FedMAD/checkpoints/pascal_voc2012/vit_tiny_patch16_224_multilabel_noisy_1.0/a1.0+sd1+e300+b64+lkl+slmha'
load_dict(f'{noisy_run_dir}/oneshot_c1_q0.0_n0.0_h1/q0.0_n0.0_ADAM_b64_2e-05_200_1e-05_m0.9_e11_0.5782.pt', model)
test(model, test_loader)

In [None]:

# Clean-label run, `h1` variant (`e10_0.6554` in the filename).
# Loads into whatever global `model` a previous cell created.
clean_run_dir = '/home/suncheol/code/FedTest/FedMAD/checkpoints/pascal_voc2012/vit_tiny_patch16_224_multilabel_clean_1.0/a1.0+sd1+e300+b64+lkl+slmha'
load_dict(f'{clean_run_dir}/oneshot_c1_q0.0_n0.0_h1/q0.0_n0.0_ADAM_b64_2e-05_200_1e-05_m0.9_e10_0.6554.pt', model)
test(model, test_loader)


In [None]:
# Clean-label run, `h2` variant (`e10_0.6544` in the filename).
# Loads into whatever global `model` a previous cell created.
clean_h2_ckpt = ('/home/suncheol/code/FedTest/FedMAD/checkpoints/pascal_voc2012/vit_tiny_patch16_224_multilabel_clean_1.0/'
                 'a1.0+sd1+e300+b64+lkl+slmha/oneshot_c1_q0.0_n0.0_h2/'
                 'q0.0_n0.0_ADAM_b64_2e-05_200_1e-05_m0.9_e10_0.6544.pt')
load_dict(clean_h2_ckpt, model)
test(model, test_loader)

In [None]:
# Choose one decision threshold for the sigmoid scores of the currently loaded `model`.
# Scan 0.0 … 0.9 and keep the threshold with the best accuracy, report metrics at
# that threshold, and plot per-class ROC curves.
object_categories = ['aeroplane', 'bicycle', 'bird', 'boat',
                     'bottle', 'bus', 'car', 'cat', 'chair',
                     'cow', 'diningtable', 'dog', 'horse',
                     'motorbike', 'person', 'pottedplant',
                     'sheep', 'sofa', 'train', 'tvmonitor']


def collect_scores(model, loader, device="cuda"):
    """Return ``(sigmoid scores, targets)`` over every batch of ``loader`` as numpy arrays.

    ``loader`` must yield ``(images, target, extra)`` triples, the same format ``test`` consumes.
    """
    model.eval()
    score_batches, target_batches = [], []
    with torch.no_grad():
        for images, target, _ in loader:
            score_batches.append(torch.sigmoid(model(images.to(device))).cpu().numpy())
            target_batches.append(target.cpu().numpy())
    return np.concatenate(score_batches, axis=0), np.concatenate(target_batches, axis=0)


def get_metrics(y_test, y_score, th):
    """Binarize ``y_score`` at ``th``; return (accuracy, precision, recall, F1) from the project helpers."""
    y_pred = (y_score > th).astype(int)
    return (getAccuracy(y_test, y_pred), getPrecision(y_test, y_pred),
            getRecall(y_test, y_pred), getF1score(y_test, y_pred))


# `output` / `target` exist only inside `test()`, never as globals, so recompute them here
# rather than depending on state left over from an earlier session.
y_score, y_test = collect_scores(model, test_loader)

# i / 10 avoids float noise such as 0.1 * 3 == 0.30000000000000004.
th_ls = [i / 10 for i in range(10)]
opt_th = 0
best_acc = 0
for th in th_ls:
    acc, pre, rec, f1 = get_metrics(y_test, y_score, th)
    if acc > best_acc:
        best_acc = acc
        opt_th = th

acc, pre, rec, f1 = get_metrics(y_test, y_score, opt_th)
print("opt threshold = {}".format(opt_th))
print("accuracy = {}".format(acc))
print("precision = {}".format(pre))
print("recall = {}".format(rec))
print("f1 score = {}".format(f1))
# The threshold is selected by accuracy, so label the reported best value as accuracy.
print("optimal threshold = {}".format(opt_th), "best accuracy = {}".format(best_acc))

plotMultiROCCurve(y_test, y_score)


In [None]:
# Per-class confusion matrices at the chosen threshold: printed, then plotted by the project helper.
y_pred_bool = y_score > opt_th
ml_cm = multi_label_confusion_matrix(y_test, y_pred_bool, object_categories)
print(ml_cm)
plotMultilabelconfusionmatrix(y_test, y_pred_bool, object_categories)


In [None]:
# Raw per-class 2x2 confusion matrices at the chosen threshold, shown as the cell output.
# NOTE(review): `metrics` arrives via the star import, presumably sklearn.metrics — confirm.
metrics.multilabel_confusion_matrix(y_test, y_score > opt_th)

In [None]:
# Plot one 2x2 confusion-matrix heatmap per class, five panels per row.
# Assuming `metrics` is sklearn.metrics (it comes from the star import — confirm), each
# per-class matrix is [[TN, FP], [FN, TP]]. Rows are the actual label (0 then 1) and
# columns the predicted label (0 then 1), so both axes read False, True.
y_pred = (y_score > opt_th).astype(int)
y_true = y_test.astype(int)  # new name so the y_test used by earlier cells is not rebound
labels = object_categories
cm = metrics.multilabel_confusion_matrix(y_true, y_pred)
nClasses = len(labels)
n_cols = 5
# ceil so no class is dropped when nClasses is not a multiple of n_cols
n_rows = max(1, math.ceil(nClasses / n_cols))
# squeeze=False keeps `ax` 2-D even with a single row
fig, ax = plt.subplots(n_rows, n_cols, figsize=(10, 2 * n_rows), squeeze=False)
for axes, cfs_matrix, label in zip(ax.flatten(), cm, labels):
    df_cm = pd.DataFrame(cfs_matrix, index=["False", "True"], columns=["False", "True"])
    sns.heatmap(df_cm, annot=True, ax=axes, fmt='g')
    axes.set_title(label)
    axes.set_xlabel("Predicted")
    axes.set_ylabel("Actual")
# Hide any unused panels in the last row.
for axes in ax.flatten()[nClasses:]:
    axes.axis('off')
fig.tight_layout()
plt.show()