In [None]:
import sys
import os
import copy
import logging
import torch
import pandas as pd

import matplotlib.pyplot as plt
from collections import Counter
import numpy as np
import torch.nn as nn

sys.path.append("../")

cwd = os.getcwd()
project_root = os.path.abspath(os.path.join(cwd, "../.."))
sys.path.append(project_root)

# configuration
from munch import Munch
from fedlab.models.mlp import MLP
from fedlab.models.build_model import build_model
from fedlab.utils.dataset.functional import partition_report
from fedlab.utils import Logger, SerializationTool, Aggregators, LogitAdjust, LA_KD, DaAggregator
from fedlab.contrib.algorithm.basic_server import SyncServerHandler


In [None]:
# Experiment configuration, stored as attributes on a Munch *instance*.
# (Assigning the bare `Munch` class here would attach every setting to the
# class itself, leaking them into every other Munch object.)
args = Munch()
args.total_client = 10
args.alpha = 0.5
args.seed = 0
args.preprocess = True
args.dataname = "cifar10"
args.model = "Resnet18"
args.pretrained = 1
args.num_users = args.total_client
# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
args.device = "cuda" if torch.cuda.is_available() else "cpu"

# Number of target classes for the selected dataset.
if args.dataname == "cifar10":
    args.n_classes = 10

In [2]:
import sys
import os
import copy
import logging
import torch
import pandas as pd

import matplotlib.pyplot as plt
from collections import Counter
import numpy as np
import torch.nn as nn

from sklearn.metrics import balanced_accuracy_score, accuracy_score, confusion_matrix

sys.path.append("../")

cwd = os.getcwd()
project_root = os.path.abspath(os.path.join(cwd, "../.."))
sys.path.append(project_root)

# configuration
from munch import Munch
from fedlab.models.mlp import MLP
from fedlab.models.build_model import build_model
from fedlab.utils.dataset.functional import partition_report
from fedlab.utils import Logger, SerializationTool, Aggregators, LogitAdjust, LA_KD, DaAggregator
from fedlab.contrib.algorithm.basic_server import SyncServerHandler

# Experiment configuration, stored as attributes on a Munch *instance*.
# (Assigning the bare `Munch` class here would attach every setting to the
# class itself, leaking them into every other Munch object.)
args = Munch()

args.total_client = 10
args.alpha = 0.5
args.seed = 0
args.preprocess = True
args.dataname = "cifar10"
args.model = "Resnet18"
args.pretrained = 1
args.num_users = args.total_client
# Pick the GPU only when torch can actually see one, so the notebook also
# runs on CPU-only machines instead of carrying a hard-coded "cuda".
args.device = "cuda" if torch.cuda.is_available() else "cpu"


# Number of target classes for the selected dataset.
if args.dataname == "cifar10":
    args.n_classes = 10


# Send INFO-and-above records to stdout with a millisecond timestamp prefix.
# basicConfig already installs one stdout handler on the root logger; adding a
# second StreamHandler(sys.stdout) on top of it would print every message
# twice (and once more for each re-run of this cell), so none is added.
logging.basicConfig(level=logging.INFO,
                    format='[%(asctime)s.%(msecs)03d] %(message)s',
                    datefmt='%H:%M:%S',
                    stream=sys.stdout)


# Example usage of the partitioned CIFAR10 dataset:
# download the raw CIFAR10 data and partition it according to the
# configuration held in `args`.

from torchvision import transforms
from fedlab.contrib.dataset.partitioned_cifar10 import PartitionedCIFAR10

############################################
#           Set up the dataset             #
############################################

# Collect the constructor arguments first so the partition settings can be
# read at a glance; they are forwarded unchanged as keyword arguments.
partition_config = dict(
    root="../datasets/cifar10/",
    path="../datasets/cifar10/fedcifar10/",
    dataname=args.dataname,
    num_clients=args.total_client,
    num_classes=args.n_classes,
    balance=True,
    partition="dirichlet",
    seed=args.seed,
    dir_alpha=args.alpha,
    preprocess=args.preprocess,
    download=True,
    verbose=True,
    transform=transforms.ToTensor(),
)
fed_cifar10 = PartitionedCIFAR10(**partition_config)

# Datasets and dataloaders (batch size 128) of client 0, for both splits.
dataset_train = fed_cifar10.get_dataset(0, type="train")
dataset_test = fed_cifar10.get_dataset(0, type="test")
dataloader_train = fed_cifar10.get_dataloader(0, batch_size=128, type="train")
dataloader_test = fed_cifar10.get_dataloader(0, batch_size=128, type="test")

# Log the label histogram and the number of targets of each split.
for split_name, split_targets in (("train", fed_cifar10.targets_train),
                                  ("test", fed_cifar10.targets_test)):
    logging.info(
        f"{split_name}: {Counter(split_targets)}, total: {len(split_targets)}")


############################################
#                  Dataset                 #
############################################

# Create the output folders up front so that writing the report and the
# figure does not fail on a fresh checkout.
report_dir = "./partition-reports"
img_dir = "./imgs"
os.makedirs(report_dir, exist_ok=True)
os.makedirs(img_dir, exist_ok=True)

# generate partition report
# The file name is derived from the active configuration so it cannot drift
# away from the alpha / client count actually used for the partition.
csv_file = f"{report_dir}/{args.dataname}_hetero_dir_{args.alpha}_{args.total_client}clients.csv"
partition_report(fed_cifar10.targets_train, fed_cifar10.data_indices_train,
                 class_num=args.n_classes,
                 verbose=False, file=csv_file)


hetero_dir_part_df = pd.read_csv(csv_file, header=0)
hetero_dir_part_df = hetero_dir_part_df.set_index('cid')
col_names = [f"class-{i}" for i in range(args.n_classes)]
# Scale each class column by the client's 'TotalAmount' to obtain sample
# counts (presumably the report stores per-class fractions -- confirm against
# partition_report). Round before the integer cast: plain truncation turns
# products such as 28.999999 into 28.
for col in col_names:
    hetero_dir_part_df[col] = (
        hetero_dir_part_df[col] * hetero_dir_part_df['TotalAmount']
    ).round().astype(int)

# Stacked horizontal bar chart of the per-class counts (first 10 clients).
hetero_dir_part_df[col_names].iloc[:10].plot.barh(stacked=True)
plt.tight_layout()
plt.legend(loc='center left', bbox_to_anchor=(1, 0.5))
plt.xlabel('sample num')
plt.savefig(f"{img_dir}/cifar10_dir_10clients_for_fedavg.png", dpi=400, bbox_inches='tight')
plt.show()


Files already downloaded and verified
Files already downloaded and verified


KeyboardInterrupt: 

In [2]:

# client
from fedlab.contrib.algorithm.basic_client import SGDSerialClientTrainer, SGDClientTrainer

# local train configuration
args.epochs = 5
args.batch_size = 128
args.lr = 0.1

# `args.cuda` is read below but is not assigned by any configuration cell in
# this notebook; derive it from `args.device` (and actual CUDA availability)
# unless an earlier cell already set it explicitly.
if not hasattr(args, "cuda"):
    args.cuda = str(args.device).startswith("cuda") and torch.cuda.is_available()

# NOTE(review): `model` is not defined in any cell shown in this notebook; it
# must be created (presumably via the imported `build_model` -- confirm its
# signature) before this cell is run on a fresh kernel.
trainer = SGDSerialClientTrainer(model, args.total_client, cuda=args.cuda) # serial trainer
# trainer = SGDClientTrainer(model, cuda=True) # single trainer

trainer.setup_dataset(fed_cifar10)
trainer.setup_optim(args.epochs, args.batch_size, args.lr)

# server
from fedlab.contrib.algorithm.basic_server import SyncServerHandler

# global configuration
args.com_round = 10
args.sample_ratio = 0.1

handler = SyncServerHandler(model=model, global_round=args.com_round, sample_ratio=args.sample_ratio, cuda=args.cuda)

from fedlab.utils.functional import evaluate
from fedlab.core.standalone import StandalonePipeline

from torch import nn
from torch.utils.data import DataLoader
import torchvision
import matplotlib.pyplot as plt
import numpy as np

from fedlab.utils.functional import evaluate
from fedlab.core.standalone import StandalonePipeline

from torch import nn
from torch.utils.data import DataLoader
import torchvision

class EvalPipeline(StandalonePipeline):
    """Standalone pipeline that evaluates the server model after every round.

    After each communication round the handler's model is evaluated on
    ``test_loader`` and the resulting loss / accuracy are appended to
    ``self.loss`` / ``self.acc``.

    Args:
        handler: Server-side handler; this class uses its ``if_stop``,
            ``sample_clients()``, ``downlink_package``, ``load()`` and
            ``model`` members.
        trainer: Client-side trainer; this class uses its
            ``local_process()`` and ``uplink_package`` members.
        test_loader: Data loader passed to ``evaluate`` each round.
    """

    def __init__(self, handler, trainer, test_loader):
        super().__init__(handler, trainer)
        self.test_loader = test_loader
        self.loss = []  # one evaluation loss per completed round
        self.acc = []   # one evaluation accuracy per completed round

    def main(self):
        """Run rounds until the handler reports stop, evaluating after each."""
        # The loss module is the same for every round, so build it once.
        criterion = nn.CrossEntropyLoss()
        round_idx = 0
        while self.handler.if_stop is False:
            # server side
            sampled_clients = self.handler.sample_clients()
            broadcast = self.handler.downlink_package

            # client side
            self.trainer.local_process(broadcast, sampled_clients)
            uploads = self.trainer.uplink_package

            # server side
            for pack in uploads:
                self.handler.load(pack)

            loss, acc = evaluate(self.handler.model, criterion, self.test_loader)
            print("Round {}, Loss {:.4f}, Test Accuracy {:.4f}".format(round_idx, loss, acc))
            round_idx += 1
            self.loss.append(loss)
            self.acc.append(acc)

    def show(self):
        """Plot the recorded loss and accuracy against the round index."""
        fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(8, 4.5))

        ax_loss.plot(np.arange(len(self.loss)), self.loss)
        ax_loss.set_xlabel("Communication Round")
        ax_loss.set_ylabel("Loss")

        ax_acc.plot(np.arange(len(self.acc)), self.acc)
        ax_acc.set_xlabel("Communication Round")
        ax_acc.set_ylabel("Accuracy")

        # Keep the axis labels of the two panels from overlapping.
        fig.tight_layout()
        
        
# CIFAR10 test split (train=False), read from the same root directory as the
# partitioned dataset and converted to tensors.
test_data = torchvision.datasets.CIFAR10(
    "../datasets/cifar10/",
    train=False,
    transform=transforms.ToTensor(),
)
test_loader = DataLoader(dataset=test_data, batch_size=1024)

# Run the federated rounds with per-round evaluation, then plot the curves.
standalone_eval = EvalPipeline(handler, trainer, test_loader)
standalone_eval.main()

standalone_eval.show()

Training on client 3: 100%|██████████| 1/1 [03:50<00:00, 230.24s/it]


Round 0, Loss 5.4143, Test Accuracy 0.1752


Training on client 1:   0%|          | 0/1 [00:48<?, ?it/s]


KeyboardInterrupt: 