In [49]:
import sys
# setting path
sys.path.append('../')
import os
import numpy as np

import nets as models
import functions as fns

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import time
import pickle

# Device that eval() moves every batch to. Fall back to CPU when no GPU is
# visible so the notebook still runs on CPU-only machines instead of failing
# at the first `.to(DEVICE)`.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

class AverageMeter(object):
    """Keep the latest value plus a sample-weighted running mean.

    Each ``update(val, n)`` stores ``val`` as the newest value and counts it as
    the mean of ``n`` samples when accumulating the total and the average.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the latest value, the mean, the running total and the sample count."""
        self.val = self.avg = self.sum = self.count = 0

    def get_avg(self):
        """Return the running mean (0 until the first update)."""
        return self.avg

    def update(self, val, n=1):
        """Record ``val`` as the newest value, weighted by ``n`` samples."""
        self.val = val
        self.count += n
        self.sum += val * n
        # Raises ZeroDivisionError if every update so far used n=0.
        self.avg = self.sum / self.count
        
def compute_accuracy(output, target, debug=False):
    """Return the top-1 accuracy of one batch, in percent.

    Args:
        output: per-class scores whose argmax over dim 1 is the predicted
            class (presumably shape (batch, num_classes) — TODO confirm).
        target: ground-truth class indices, one per sample. It is flattened
            first, so a (batch, 1) target compares element-wise instead of
            broadcasting into a (batch, batch) comparison.
        debug: if True, print targets and predictions, then block on
            ``input()`` until the user presses Enter.

    Returns:
        Accuracy as a float in [0, 100]; 0.0 for an empty batch.
    """
    num_samples = output.size(0)
    if num_samples == 0:
        # Avoid 0/0; an empty batch contributes nothing when weighted by size.
        return 0.0

    preds = output.argmax(dim=1)
    # Flatten so (B, 1)-shaped labels don't broadcast against (B,) predictions.
    target = target.reshape(-1)
    correct = (preds == target).sum().item()
    acc = correct / num_samples * 100

    if debug:
        print(target)
        print(preds)
        input()
    return acc

def eval(test_loader, model, debug=False, device=None):
    """Evaluate ``model`` on every batch of ``test_loader`` and print top-1 accuracy.

    Note: the name shadows Python's builtin ``eval`` in this notebook; it is
    kept so the calling cells keep working.

    Args:
        test_loader: iterable yielding ``(images, target)`` batches.
        model: network to evaluate. It is switched to eval mode but NOT moved,
            so it is assumed to already live on ``device``.
        debug: forwarded to ``compute_accuracy``.
        device: where each batch is sent; defaults to the module-level ``DEVICE``.

    Returns:
        Sample-weighted mean accuracy over all batches, in percent
        (0.0 for an empty loader).
    """
    if device is None:
        device = DEVICE

    batch_time = AverageMeter()
    acc = AverageMeter()

    # switch to eval mode
    model.eval()

    end = time.time()
    # Inference only: don't build the autograd graph (saves memory and time).
    with torch.no_grad():
        for images, target in test_loader:
            # Labels stored with extra dims (e.g. (B, 1)) are flattened to (B,).
            if len(target.shape) > 1:
                target = target.reshape(len(target))

            images = images.to(device)
            target = target.to(device)

            output = model(images)
            batch_acc = compute_accuracy(output, target, debug=debug)
            acc.update(batch_acc, images.size(0))

            # Wall-clock time per batch, including data loading; summed below
            # for the reported total.
            batch_time.update(time.time() - end)
            end = time.time()

    print('Test accuracy: {:4.2f}% (time = {:8.2f}s)'.format(
            float(acc.get_avg()), batch_time.sum))
    return float(acc.get_avg())


In [22]:
# NOTE(review): these mean/std values look like the commonly used CIFAR-10
# statistics, not CIFAR-100's. They are kept as-is on the assumption that the
# checkpoint was trained with the same preprocessing — confirm against the
# training script before changing them.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

# Both loaders are only used for evaluation, so shuffling would change nothing
# but the iteration order; keep it fixed for reproducible runs.
train_dataset = datasets.CIFAR100(root="../data", train=True, download=True,
    transform=transform)
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=128, shuffle=False,
    num_workers=1, pin_memory=True)
test_dataset = datasets.CIFAR100(root="../data", train=False, download=True,
    transform=transform)
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=128, shuffle=False,
    num_workers=1, pin_memory=True)

# map_location puts the model on DEVICE, the same device eval() sends the
# batches to, regardless of where the checkpoint was saved.
# NOTE(review): this unpickles a whole model object; newer torch versions
# default torch.load to weights_only=True and may need weights_only=False —
# only do that for trusted checkpoints.
global_model = torch.load(
    "../projects/pretrained_model/cifar100_predefined_a03_56c_fedavg/last_model.pth.tar",
    map_location=DEVICE)

global_train_acc = eval(train_loader, global_model, debug=False)
global_test_acc = eval(test_loader, global_model, debug=False)

print("Train :", global_train_acc)
print("Test :", global_test_acc)

Files already downloaded and verified
Files already downloaded and verified
Test accuracy: 79.93% (time =     6.40s)
Test accuracy: 80.54% (time =     1.31s)
Train : 79.928
Test : 80.54


In [66]:
# Evaluate the FedAvg global model on each client's local train/test split.
# NOTE(review): the checkpoint name says a03, but this reads the a05_fix
# partition (the label-count cell reads a03_fix) — confirm this is intended.
# eval() only runs forward passes with the model in eval mode, so loading the
# checkpoint once, instead of once per client, is presumably safe.
global_model = torch.load("../projects/pretrained_model/cifar100_predefined_a03_56c_fedavg/last_model.pth.tar")

client_train_accs = []
client_test_accs = []
for client_id in range(56):
    # pickle.load can execute arbitrary code: only use trusted local files.
    # `with` closes each file; the previous open(...) handles were never closed.
    train_dataset_path = os.path.join("../data/32_Cifar100_NIID_56c_a05_fix", "train", f"{client_id}.pkl")
    with open(train_dataset_path, "rb") as f:
        train_data = pickle.load(f)
    # Evaluation only, so a fixed iteration order is enough.
    train_loader = torch.utils.data.DataLoader(
        train_data,
        batch_size=128,
        shuffle=False)

    test_dataset_path = os.path.join("../data/32_Cifar100_NIID_56c_a05_fix", "test", f"{client_id}.pkl")
    with open(test_dataset_path, "rb") as f:
        test_data = pickle.load(f)
    test_loader = torch.utils.data.DataLoader(
        test_data,
        batch_size=128,
        shuffle=False)

    print(f"========= Client {client_id} =========")
    global_train_acc = eval(train_loader, global_model, debug=False)
    global_test_acc = eval(test_loader, global_model, debug=False)
    # Keep every client's result instead of overwriting it each iteration.
    client_train_accs.append(global_train_acc)
    client_test_accs.append(global_test_acc)


Test accuracy: 77.68% (time =     0.00s)
Test accuracy: 77.50% (time =     0.00s)
Test accuracy: 79.98% (time =     0.00s)
Test accuracy: 78.87% (time =     0.00s)
Test accuracy: 79.33% (time =     0.00s)
Test accuracy: 80.32% (time =     0.00s)
Test accuracy: 84.20% (time =     0.00s)
Test accuracy: 84.13% (time =     0.00s)
Test accuracy: 82.54% (time =     0.00s)
Test accuracy: 80.29% (time =     0.00s)
Test accuracy: 79.42% (time =     0.00s)
Test accuracy: 79.35% (time =     0.00s)
Test accuracy: 75.41% (time =     0.00s)
Test accuracy: 75.66% (time =     0.00s)
Test accuracy: 78.25% (time =     0.00s)
Test accuracy: 81.54% (time =     0.00s)
Test accuracy: 78.05% (time =     0.00s)
Test accuracy: 78.99% (time =     0.00s)
Test accuracy: 81.54% (time =     0.00s)
Test accuracy: 78.81% (time =     0.00s)
Test accuracy: 81.99% (time =     0.00s)
Test accuracy: 79.41% (time =     0.00s)
Test accuracy: 76.87% (time =     0.00s)
Test accuracy: 79.93% (time =     0.00s)
Test accuracy: 8

In [38]:
# Evaluate the global model on the server-side train/test split of the a03 partition.
# pickle.load can execute arbitrary code: only use trusted local files.
# `with` closes each file; the previous open(...) handles were never closed.
train_dataset_path = os.path.join("../data/32_Cifar100_NIID_56c_a03", "server", "train.pkl")
with open(train_dataset_path, "rb") as f:
    train_data = pickle.load(f)
# Evaluation only, so a fixed iteration order is enough.
train_loader = torch.utils.data.DataLoader(
    train_data,
    batch_size=128,
    shuffle=False)

test_dataset_path = os.path.join("../data/32_Cifar100_NIID_56c_a03", "server", "test.pkl")
with open(test_dataset_path, "rb") as f:
    test_data = pickle.load(f)
test_loader = torch.utils.data.DataLoader(
    test_data,
    batch_size=128,
    shuffle=False)

global_model = torch.load("../projects/pretrained_model/cifar100_predefined_a03_56c_fedavg/last_model.pth.tar")

print("========= Server =========")
global_train_acc = eval(train_loader, global_model, debug=False)
global_test_acc = eval(test_loader, global_model, debug=False)

Test accuracy: 79.73% (time =     0.00s)
Test accuracy: 80.67% (time =     0.00s)


In [57]:
# Count matrix: one row per client (56), one column per class (100).
# NOTE(review): despite the name, the label-count cell fills it from each
# client's *train* split.
test_counts = np.zeros(shape=(56, 100), dtype=np.float64)

In [64]:
# Fill row `client_id` of `test_counts` with the per-class sample counts of
# that client's *train* split (despite the variable name).
# NOTE(review): this reads the a03_fix partition, while the per-client accuracy
# cell reads a05_fix — confirm which partition is intended.
for client_id in range(56):
    train_dataset_path = os.path.join("../data/32_Cifar100_NIID_56c_a03_fix", "train", f"{client_id}.pkl")
    # pickle.load can execute arbitrary code: only use trusted local files.
    # `with` closes the file; the previous open(...) handle was never closed.
    with open(train_dataset_path, "rb") as f:
        train_data = pickle.load(f)

    labels, counts = np.unique(train_data.labels, return_counts=True)
    # Reset the row first so a re-run, or a leftover array from another
    # partition, can't keep counts for classes this client no longer has.
    test_counts[client_id] = 0
    test_counts[client_id, labels] = counts

In [65]:
# Per-class sample totals summed over all clients (one entry per column).
test_counts.sum(axis=0)

array([536., 587., 500., 480., 511., 609., 619., 565., 682., 769., 603.,
       539., 666., 589., 531., 552., 493., 568., 564., 545., 468., 520.,
       732., 606., 712., 491., 704., 610., 658., 549., 544., 656., 475.,
       552., 618., 781., 659., 628., 480., 661., 490., 480., 795., 665.,
       633., 621., 541., 556., 481., 683., 516., 563., 598., 597., 511.,
       611., 535., 597., 590., 646., 730., 484., 541., 727., 513., 547.,
       544., 667., 623., 487., 534., 576., 479., 703., 607., 481., 579.,
       474., 535., 743., 665., 764., 568., 611., 750., 521., 768., 786.,
       682., 786., 662., 593., 852., 834., 669., 788., 754., 755., 593.,
       884.])