In [1]:
import os
import argparse
import random
from copy import deepcopy
import torchvision
import torchvision.transforms as transforms
from torch import nn
import sys
import torch

torch.manual_seed(0)

sys.path.append("../../../../../")

from fedlab.core.client.trainer import SerialTrainer
from fedlab.utils.aggregator import Aggregators
from fedlab.utils.serialization import SerializationTool
from fedlab.utils.dataset.slicing import noniid_slicing, random_slicing
from fedlab.utils.functional import get_best_gpu

In [2]:
class AverageMeter:
    """Running tracker for a training or evaluation statistic.

    Attributes:
        val: the value passed to the most recent ``update`` call.
        sum: running total of every ``val`` passed to ``update``.
        count: running total of every ``n`` passed to ``update``.
        avg: ``sum / count`` as of the most recent ``update`` call.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Set every tracked statistic back to ``0.0``."""
        self.val = self.avg = self.sum = self.count = 0.0

    def update(self, val, n=1):
        """Fold one observation into the running statistics.

        ``val`` is added to ``sum`` as-is (it is NOT multiplied by ``n``), so a
        caller that wants a ratio passes the numerator as ``val`` and the
        denominator as ``n``.

        Args:
            val: value to record and add to the running total.
            n: weight added to ``count``; defaults to 1.

        Raises:
            ZeroDivisionError: if ``count`` is still zero after adding ``n``.
        """
        self.val = val
        self.sum += val
        self.count += n
        # Refresh the cached ratio only after both totals have been advanced.
        self.avg = self.sum / self.count
        
def evaluate(model, criterion, test_loader):
    """Score ``model`` on every batch of ``test_loader`` with gradients disabled.

    The model is switched to eval mode and each batch is moved to the device
    that holds the model's first parameter before the forward pass.

    Args:
        model: network to score; ``next(model.parameters())`` must succeed.
        criterion: callable taking ``(outputs, labels)`` and returning a value
            that supports ``.item()``.
        test_loader: iterable yielding ``(inputs, labels)`` batches.

    Returns:
        A ``(loss, accuracy)`` tuple. ``loss`` is the SUM of the per-batch
        criterion values (it is not averaged over batches), and ``accuracy`` is
        the number of correct predictions divided by the number of labels seen.
    """
    model.eval()
    device = next(model.parameters()).device

    loss_meter = AverageMeter()
    acc_meter = AverageMeter()

    with torch.no_grad():
        for batch_inputs, batch_labels in test_loader:
            batch_inputs = batch_inputs.to(device)
            batch_labels = batch_labels.to(device)

            logits = model(batch_inputs)
            loss_meter.update(criterion(logits, batch_labels).item())

            # Index of the largest output along dim 1 is the predicted class.
            predicted = torch.max(logits, 1)[1]
            num_correct = predicted.eq(batch_labels).sum().item()
            # Numerator = correct count, denominator = batch size.
            acc_meter.update(num_correct, len(batch_labels))

    return loss_meter.sum, acc_meter.avg

In [3]:
# Normalisation and dataset location are shared by the train and test pipelines.
cifar10_normalize = transforms.Normalize((0.4914, 0.4822, 0.4465),
                                         (0.2023, 0.1994, 0.2010))
cifar10_root = '../../../../datasets/data/cifar10/'

# Training pipeline: padded random crop and random horizontal flip, then
# tensor conversion and normalisation.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    cifar10_normalize,
])
trainset = torchvision.datasets.CIFAR10(root=cifar10_root,
                                        train=True,
                                        download=True,
                                        transform=transform_train)

# Test pipeline: tensor conversion and normalisation only.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    cifar10_normalize,
])
testset = torchvision.datasets.CIFAR10(root=cifar10_root,
                                       train=False,
                                       download=True,
                                       transform=transform_test)

# Batch size is one tenth of the test set; order is fixed and no sample is dropped.
testloader = torch.utils.data.DataLoader(testset,
                                         batch_size=len(testset) // 10,
                                         drop_last=False,
                                         shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


In [14]:

# FL settings

# Seed Python's RNG so the per-round client sampling done with random.sample
# below is repeatable when this cell is re-run.
random.seed(0)

aggregator = Aggregators.fedavg_aggregate
total_client_num = 100  # total number of clients
sample_ratio = 0.1  # fraction of clients sampled in each round
num_per_round = int(total_client_num * sample_ratio)
num_rounds = 100  # number of communication rounds

# The partition size follows total_client_num (it was a separate hard-coded 100),
# so the client ids sampled below always agree with the number of slices requested.
data_indices = random_slicing(trainset, num_clients=total_client_num)


# fedlab setup
local_model = torchvision.models.resnet18()

# The trainer gets its own copy of the model; local_model is only updated
# explicitly from the aggregated parameters after each round.
trainer = SerialTrainer(model=deepcopy(local_model),
                        dataset=trainset,
                        data_slices=data_indices,
                        aggregator=aggregator,
                        args={
                                "batch_size": 100,
                                "lr": 0.01,
                                "epochs": 5
                            })

# Built once instead of on every round.
criterion = nn.CrossEntropyLoss()

# Per-round evaluation history.
losses = []
acces = []

# train procedure
to_select = list(range(total_client_num))
# The loop variable is not called `round`, so the builtin round() stays usable
# in later cells.
for round_idx in range(num_rounds):
    model_parameters = SerializationTool.serialize_model(local_model)
    selection = random.sample(to_select, num_per_round)
    print(selection)
    aggregated_parameters = trainer.train(model_parameters=model_parameters,
                                          id_list=selection,
                                          aggregate=True)

    # Load the aggregated parameters back into the model before evaluating it.
    SerializationTool.deserialize_model(local_model, aggregated_parameters)

    loss, acc = evaluate(local_model, criterion, testloader)
    print("loss: {:.4f}, acc: {:.2f}".format(loss, acc))
    losses.append(loss)
    acces.append(acc)

[10, 82, 90, 25, 52, 58, 18, 56, 77, 76]
2021-08-27 17:02:30,114 - root - INFO - starting training process of client [10]
2021-08-27 17:02:31,775 - root - INFO - starting training process of client [82]
2021-08-27 17:02:33,440 - root - INFO - starting training process of client [90]
2021-08-27 17:02:35,100 - root - INFO - starting training process of client [25]
2021-08-27 17:02:36,758 - root - INFO - starting training process of client [52]
2021-08-27 17:02:38,410 - root - INFO - starting training process of client [58]
2021-08-27 17:02:40,067 - root - INFO - starting training process of client [18]
2021-08-27 17:02:41,732 - root - INFO - starting training process of client [56]
2021-08-27 17:02:43,400 - root - INFO - starting training process of client [77]
2021-08-27 17:02:45,055 - root - INFO - starting training process of client [76]
loss: 53.8500, acc: 0.10
[30, 3, 99, 42, 85, 40, 49, 94, 86, 76]
2021-08-27 17:02:53,103 - root - INFO - starting training process of client [30]
202

KeyboardInterrupt: 

In [9]:
# Load a partition from "cifar10_iid.pkl" and rebind data_indices to it.
# NOTE(review): this replaces whatever data_indices held before; nothing in this
# cell rebuilds objects that were constructed from the previous value.
# NOTE(review): load_dict presumably unpickles the file — only load files from a
# trusted source; confirm against fedlab.utils.functional.
from fedlab.utils.functional import load_dict
data_indices = load_dict("cifar10_iid.pkl")

In [11]:
# Display the keys of the loaded partition mapping (one key per client slice, presumably — confirm).
data_indices.keys()

dict_keys([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99])

In [12]:
# Length of model_parameters, which the training loop assigns from
# SerializationTool.serialize_model(local_model).
len(model_parameters)

11689512