In [1]:
import os
os.environ["KMP_DUPLICATE_LIB_OK"]  =  "TRUE"
import numpy as np
import time

import argparse
# import logging

from tqdm import tqdm
import pickle

import matplotlib 
import matplotlib.pyplot as plt
matplotlib.use('Agg')
%matplotlib inline

from random import Random

import torch
import torch.distributed as dist
import torch.utils.data.distributed
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.multiprocessing import Process
from torch.autograd import Variable
import torchvision
from torchvision import datasets, transforms
# logging.basicConfig(level=logging.INFO)

class Partition(object):
    """ Dataset-like object, but only access a subset of it. """

    def __init__(self, data, index):
        self.data = data
        self.index = index

    def __len__(self):
        return len(self.index)

    def __getitem__(self, index):
        data_idx = self.index[index]
        return self.data[data_idx]

class DataPartitioner(object):
    """ Partitions a dataset into different chuncks. """
    def __init__(self, data, sizes=[0.7, 0.2, 0.1], seed=1234, isNonIID=False, alpha=0, dataset=None):
        self.data = data
        self.dataset = dataset
        if isNonIID:
            self.partitions, self.ratio = self.__getDirichletData__(data, sizes, seed, alpha)

        else:
            self.partitions = [] 
            self.ratio = sizes
            rng = Random() 
            rng.seed(seed) 
            data_len = len(data) 
            indexes = [x for x in range(0, data_len)] 
            rng.shuffle(indexes) 
             
     
            for frac in sizes: 
                part_len = int(frac * data_len)
                self.partitions.append(indexes[0:part_len])
                indexes = indexes[part_len:]

        

    def use(self, partition):
        return Partition(self.data, self.partitions[partition])

    def __getNonIIDdata__(self, data, sizes, seed, alpha):
        labelList = data.train_labels
        rng = Random()
        rng.seed(seed)
        a = [(label, idx) for idx, label in enumerate(labelList)]
        # Same Part
        labelIdxDict = dict()
        for label, idx in a:
            labelIdxDict.setdefault(label,[])
            labelIdxDict[label].append(idx)
        labelNum = len(labelIdxDict)
        labelNameList = [key for key in labelIdxDict]
        labelIdxPointer = [0] * labelNum
        # sizes = number of nodes
        partitions = [list() for i in range(len(sizes))]
        eachPartitionLen= int(len(labelList)/len(sizes))
        # majorLabelNumPerPartition = ceil(labelNum/len(partitions))
        majorLabelNumPerPartition = 2
        basicLabelRatio = alpha

        interval = 1
        labelPointer = 0

        #basic part
        for partPointer in range(len(partitions)):
            requiredLabelList = list()
            for _ in range(majorLabelNumPerPartition):
                requiredLabelList.append(labelPointer)
                labelPointer += interval
                if labelPointer > labelNum - 1:
                    labelPointer = interval
                    interval += 1
            for labelIdx in requiredLabelList:
                start = labelIdxPointer[labelIdx]
                idxIncrement = int(basicLabelRatio*len(labelIdxDict[labelNameList[labelIdx]]))
                partitions[partPointer].extend(labelIdxDict[labelNameList[labelIdx]][start:start+ idxIncrement])
                labelIdxPointer[labelIdx] += idxIncrement

        #random part
        remainLabels = list()
        for labelIdx in range(labelNum):
            remainLabels.extend(labelIdxDict[labelNameList[labelIdx]][labelIdxPointer[labelIdx]:])
        rng.shuffle(remainLabels)
        for partPointer in range(len(partitions)):
            idxIncrement = eachPartitionLen - len(partitions[partPointer])
            partitions[partPointer].extend(remainLabels[:idxIncrement])
            rng.shuffle(partitions[partPointer])
            remainLabels = remainLabels[idxIncrement:]

        return partitions

    def __getDirichletData__(self, data, psizes, seed, alpha):
        n_nets = len(psizes)
        K = 10
        labelList = np.array(data.targets)
        min_size = 0
        N = len(labelList)
        np.random.seed(seed)

        net_dataidx_map = {}
        while min_size < K:
            idx_batch = [[] for _ in range(n_nets)]
            # for each class in the dataset
            for k in range(K):
                idx_k = np.where(labelList == k)[0]
                np.random.shuffle(idx_k)
                proportions = np.random.dirichlet(np.repeat(alpha, n_nets))
                ## Balance
                proportions = np.array([p*(len(idx_j)<N/n_nets) for p,idx_j in zip(proportions,idx_batch)])
                proportions = proportions/proportions.sum()
                proportions = (np.cumsum(proportions)*len(idx_k)).astype(int)[:-1]
                idx_batch = [idx_j + idx.tolist() for idx_j,idx in zip(idx_batch,np.split(idx_k,proportions))]
                min_size = min([len(idx_j) for idx_j in idx_batch])

        for j in range(n_nets):
            np.random.shuffle(idx_batch[j])
            net_dataidx_map[j] = idx_batch[j]
            
        net_cls_counts = {}

        for net_i, dataidx in net_dataidx_map.items():
            unq, unq_cnt = np.unique(labelList[dataidx], return_counts=True)
            tmp = {unq[i]: unq_cnt[i] for i in range(len(unq))}
            net_cls_counts[net_i] = tmp
        print('Data statistics: %s' % str(net_cls_counts))

        local_sizes = []
        for i in range(n_nets):
            local_sizes.append(len(net_dataidx_map[i]))
        local_sizes = np.array(local_sizes)
        weights = local_sizes/np.sum(local_sizes)
        print(weights)

        return idx_batch, weights

def partition_dataset(size, args_alpha):
    print('==> load train data')
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    trainset = torchvision.datasets.CIFAR10(root='./data/',
                                            train=True, 
                                            download=True, 
                                            transform=transform_train)
    
    partition_sizes = [1.0 / size for _ in range(size)]
    partition = DataPartitioner(trainset, partition_sizes, isNonIID=True, alpha=args_alpha)
    ratio = partition.ratio
#     partition = partition.use(rank)
#     train_loader = torch.utils.data.DataLoader(partition, 
#                                             batch_size=32, 
#                                             shuffle=True, 
#                                             )


    partitions = partition.partitions
    train_sets = []
    for k in range(size):
        local_partition = Partition(trainset, partitions[k])
        train_sets.append(local_partition)

    
        
    print('==> load test data')
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    testset = torchvision.datasets.CIFAR10(root='./data/', 
                                        train=False, 
                                        download=True, 
                                        transform=transform_test)
        
    
    
    test_loader = torch.utils.data.DataLoader(testset, 
                                            batch_size=32, #64
                                            shuffle=False, 
                                            )

    # You can add more datasets here
    return train_sets, test_loader, ratio


# if __name__ == '__main__':
#     parser = argparse.ArgumentParser(description='CIFAR-10 baseline')
#     parser.add_argument('--client_num','-cN', 
#                     default=10, 
#                     type=int, 
#                     help='the number of clients')
#     parser.add_argument('--round_num','-rN', 
#                     default=10, 
#                     type=int, 
#                     help='the number of communication rounds')
#     parser.add_argument('--round_drift','-rd', 
#                     default=1, 
#                     type=float, 
#                     help='round drift') 
#     parser.add_argument('--client_drift','-cd', 
#                     default=0.1, 
#                     type=float, 
#                     help='client drift')
#     parser.add_argument('--lr', 
#                     default=0.1, 
#                     type=float, 
#                     help='client learning rate')
#     parser.add_argument('--rank', 
#                     default=0, 
#                     type=int, 
#                     help='the rank of worker')

#?????????bs 是啥

#     parser.add_argument('--bs', 
#                     default=32, 
#                     type=int, 
#                     help='batch size on each worker/client')
#     parser.add_argument('--NIID',
#                     default=True,
#                     action='store_true',
#                     help='whether the dataset is non-iid or not')
#     parser.add_argument('--datapath',
#                     default='./data/',
#                     type=str,
#                     help='directory to load data')
#     args = parser.parse_args(args=[])

In [2]:


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
        self.fc1 = nn.Linear(2048, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = self.conv3(x)
        x = F.relu(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output


def client_update(client_model, optimizer, train_loader, epoch, num_clients, pk, batch_size):
    client_model.train()
    Grad_accumulator = []
    for e in range(epoch):
        grad_batch_idx = []
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.cuda(), target.cuda()
            optimizer.zero_grad()
            output = client_model(data)
            loss_fed = F.nll_loss(output, target)
            loss = loss_fed/(epoch )
            loss.backward()
           
            if Grad_accumulator == []:
                Grad_accumulator = list(i.grad for i in list(client_model.parameters()))        
            else:
                h = list(i.grad for i in list(client_model.parameters()))
                Grad_accumulator = [Grad_accumulator[i]+h[i] for i in range(len(Grad_accumulator))]
            optimizer.step()

    
    nabla_P_norm2 = sum([(torch.norm(a))**2 for a in Grad_accumulator]).item()
    grad_client = (1/(batch_size))*np.sqrt(nabla_P_norm2)
    return loss_fed.item(),grad_client


def client_update_fed(client_model, optimizer, train_loader, epoch):
    client_model.train()
    for e in range(epoch):
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.cuda(), target.cuda()
            optimizer.zero_grad()
            output = client_model(data)
            loss = F.nll_loss(output, target)/epoch
            loss.backward()
            optimizer.step()
    return loss.item()

def server_aggregate(global_model, client_models, num_clients, pk, client_index):
    global_dict = global_model.state_dict()
    for k in global_dict.keys():
        global_dict[k] = torch.stack([client_models[i].state_dict()[k]/(num_clients*pk[client_index[i]]) for i in range(len(client_models))], 0).mean(0)
    global_model.load_state_dict(global_dict)
    for model in client_models:
        model.load_state_dict(global_model.state_dict())

def test(global_model, test_loader):
    global_model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.cuda(), target.cuda()
            output = global_model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    acc = correct / len(test_loader.dataset)

    return test_loss, acc


In [3]:
# NON-IID case: every client has images of two categories chosen from [0, 1], [2, 3], [4, 5], [6, 7], or [8, 9].

args_alpha = 0.1
# Hyperparameters
size = 100 #size = num_clients
num_clients = 100
num_selected = 10
num_rounds = 400
epochs = 6
batch_size = 32
local_ep_list = np.random.choice(range(1,epochs+1),size=num_clients)
lr = 0.01

# Creating decentralized datasets
train_loader, test_loader, DataRatios = \
        partition_dataset(size, args_alpha)

# print("==========  This is train_loader:  =============")
# print(type(train_loader[0]))
# fn = getattr(train_loader[0], '__main__')
# print(fn())
# print(train_loader)

# Instantiate models and optimizers

global_model = Net().cuda()
client_models = [Net().cuda() for _ in range(num_selected)]
for model in client_models:
    model.load_state_dict(global_model.state_dict())
# print(global_model.state_dict())

opt = [optim.SGD(model.parameters(), lr=lr) for model in client_models]
# opt = optim.SGD(model.parameters(), lr=0.01) 
                 
# Runnining FL
p_initial = np.ones(num_clients)/num_clients # initialize the probability vector
p_usersampling = p_initial
        
test_loss_accu=[]
acc_accu = []
for r in range(num_rounds):
    # select random clients
#     client_idx = np.random.permutation(num_clients)[:num_selected]
    client_idx = np.random.choice(range(num_clients), num_selected, replace=False,p = p_usersampling)
 
    # client update
    grad_list=[]
    loss_list=[]
    for i in range(num_selected):
        loss,grad_client = client_update(client_models[i], opt[i], train_loader[client_idx[i]], 
                              epoch=int(local_ep_list[client_idx[i]]), num_clients=num_clients, pk=p_usersampling[client_idx[i]], batch_size=batch_size )
        grad_list.append(grad_client)
        loss_list.append(loss)
    loss = sum(loss_list)
    # serer aggregate
    server_aggregate(global_model, client_models, num_clients=num_clients, pk=p_usersampling , client_index=client_index )
    #update sample prob
    grad_list = [a/sum(grad_list) for a in grad_list]
    normalizing_factor = sum([p_usersampling[i] for i in client_idx])
    
    for i in range(num_selected):
        p_usersampling[client_idx[i]]=(grad_list[i]/sum(grad_list)) * normalizing_factor
    
    
    test_loss, acc = test(global_model, test_loader)
    test_loss_accu.append(test_loss)
    acc_accu.append(acc)
                 
    print('%d-th round' % r)
    print('average train loss %0.3g | test loss %0.3g | test acc: %0.3f' % (loss / num_selected, test_loss, acc))
#     print('sampling probability:' )
#     print(p_usersampling)
                 
file_name = './save/objects/fedsample_Epoch{}_lr{}_round{}_loss_and_acc.pkl'. \
                format(epochs, lr,num_rounds)
with open(file_name, 'wb') as f:
            pickle.dump([test_loss_accu, acc_accu.append], f)
                 
# Plot Average Accuracy vs Communication rounds
plt.figure()
plt.title('Accuracy vs Communication rounds')
plt.plot(range(len(acc_accu)), acc_accu, color='k')
plt.ylabel('Accuracy')
plt.xlabel('Communication Rounds')
plt.savefig('./save/fedsample_diffEpoch{}_lr{}_round{}_acc.png'.
                format(epochs, lr, num_rounds))             

==> load train data
Files already downloaded and verified
Data statistics: {0: {3: 1, 4: 1017}, 1: {0: 366, 1: 212}, 2: {2: 315, 6: 117, 8: 93}, 3: {0: 1, 2: 1011}, 4: {1: 293, 2: 34, 3: 5, 5: 115, 8: 228}, 5: {1: 5, 2: 8, 3: 1139}, 6: {0: 93, 6: 50, 8: 1}, 7: {1: 107, 2: 2, 3: 1, 8: 329}, 8: {0: 206, 2: 1, 4: 332}, 9: {1: 1, 2: 7, 3: 4, 5: 3, 8: 1015}, 10: {0: 372, 2: 1, 4: 273}, 11: {0: 3, 1: 94, 2: 2, 8: 30}, 12: {0: 15, 2: 10, 3: 1, 4: 1, 5: 23, 6: 997}, 13: {0: 25, 4: 86, 8: 4, 9: 4}, 14: {0: 3, 1: 5, 4: 7, 7: 8, 8: 7}, 15: {0: 1, 2: 1, 3: 29, 5: 4, 6: 1, 7: 36, 9: 7}, 16: {0: 1, 1: 15, 2: 6, 3: 41, 4: 11, 8: 23, 9: 5}, 17: {3: 39, 4: 85, 5: 180, 6: 10, 7: 1, 8: 121}, 18: {2: 186, 4: 1, 5: 21, 7: 140}, 19: {1: 1, 3: 7, 7: 10}, 20: {0: 25, 1: 4, 3: 105, 5: 27, 6: 3, 9: 28}, 21: {0: 45, 1: 415, 2: 28, 5: 2, 6: 300}, 22: {1: 2, 2: 13, 3: 206, 5: 27, 8: 4}, 23: {0: 22, 3: 1, 4: 143, 5: 5, 6: 33, 7: 6, 9: 1}, 24: {0: 46, 2: 6, 5: 632}, 25: {1: 1, 2: 1, 3: 5, 4: 158, 5: 3, 7: 54, 8: 49,

AttributeError: 'int' object has no attribute 'cuda'

In [None]:
# NON-IID case: every client has images of two categories chosen from [0, 1], [2, 3], [4, 5], [6, 7], or [8, 9].

args_alpha = 0.1
# Hyperparameters
size = 100 #size = num_selected
num_clients = 100
num_selected = 10
num_rounds = 400
epochs = 6
batch_size = 32
local_ep_list = np.random.choice(range(1,epochs+1),size=num_clients)
lr = 0.01
# Creating decentralized datasets

# traindata = datasets.MNIST('./data', train=True, download=True,
#                        transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])
#                        )
# target_labels = torch.stack([traindata.targets == i for i in range(10)])
# target_labels_split = []
# for i in range(5):
#     target_labels_split += torch.split(torch.where(target_labels[(2 * i):(2 * (i + 1))].sum(0))[0], int(60000 / num_clients))
# traindata_split = [torch.utils.data.Subset(traindata, tl) for tl in target_labels_split]
# train_loader = [torch.utils.data.DataLoader(x, batch_size=batch_size, shuffle=True) for x in traindata_split]

# test_loader = torch.utils.data.DataLoader(
#         datasets.MNIST('./data', train=False, transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])
#         ), batch_size=batch_size, shuffle=True)
train_loader, test_loader, DataRatios = \
        partition_dataset( size, args_alpha)
# logging.debug("Worker id {} local sample ratio {} "
#               "local epoch length {}"
#               .format(rank, DataRatios[rank], len(train_loader)))

# Instantiate models and optimizers

global_model = Net().cuda()
client_models = [Net().cuda() for _ in range(num_selected)]
for model in client_models:
    model.load_state_dict(global_model.state_dict())
# print(global_model.state_dict())


opt = [optim.SGD(model.parameters(), lr=lr) for model in client_models]

# Runnining FL
test_loss_accu1 = []
acc_accu1 = []
for r in range(num_rounds):
    # select random clients
    client_idx = np.random.permutation(num_clients)[:num_selected]

    # client update
    loss = 0
    for i in range(num_selected):
        loss += client_update_fed(client_models[i], opt[i], train_loader[client_idx[i]], epoch=int(local_ep_list[client_idx[i]]))
    
    # serer aggregate
    server_aggregate(global_model, client_models)
    test_loss, acc = test(global_model, test_loader)
    test_loss_accu1.append(test_loss)
    acc_accu1.append(acc)
    
    print('%d-th round' % r)
    print('average train loss %0.3g | test loss %0.3g | test acc: %0.3f' % (loss / num_selected, test_loss, acc))
    
file_name = './save/objects/fed_diffEpoch{}_lr{}_round{}_loss_and_acc.pkl'. \
                format(epochs, lr, num_rounds)
with open(file_name, 'wb') as f:
            pickle.dump([test_loss_accu1, acc_accu1], f)
    
# Plot Average Accuracy vs Communication rounds
plt.figure()
plt.title('Accuracy vs Communication rounds')
plt.plot(range(len(acc_accu1)), acc_accu1, color='k')
plt.ylabel('Accuracy')
plt.xlabel('Communication Rounds')
plt.savefig('./save/fed_diffEpoch{}_lr{}_round{}_acc.png'.
                format(epochs, lr, num_rounds))

In [None]:
# NON-IID case: every client has images of two categories chosen from [0, 1], [2, 3], [4, 5], [6, 7], or [8, 9].

args_alpha = 0.1
# Hyperparameters
size = 100 #size = num_selected
num_clients = 100
num_selected = 10
num_rounds = 400
epochs = 3
batch_size = 32
lr = 0.01

# Creating decentralized datasets

# traindata = datasets.MNIST('./data', train=True, download=True,
#                        transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])
#                        )
# target_labels = torch.stack([traindata.targets == i for i in range(10)])
# target_labels_split = []
# for i in range(5):
#     target_labels_split += torch.split(torch.where(target_labels[(2 * i):(2 * (i + 1))].sum(0))[0], int(60000 / num_clients))
# traindata_split = [torch.utils.data.Subset(traindata, tl) for tl in target_labels_split]
# train_loader = [torch.utils.data.DataLoader(x, batch_size=batch_size, shuffle=True) for x in traindata_split]

# test_loader = torch.utils.data.DataLoader(
#         datasets.MNIST('./data', train=False, transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])
#         ), batch_size=batch_size, shuffle=True)
train_loader, test_loader, DataRatios = \
        partition_dataset( size, args_alpha)
# logging.debug("Worker id {} local sample ratio {} "
#               "local epoch length {}"
#               .format(rank, DataRatios[rank], len(train_loader)))

# Instantiate models and optimizers

global_model = Net().cuda()
client_models = [Net().cuda() for _ in range(num_selected)]
for model in client_models:
    model.load_state_dict(global_model.state_dict())

opt = [optim.SGD(model.parameters(), lr=lr) for model in client_models]

# Runnining FL
test_loss_accu2=[]
acc_accu2 = []
for r in range(num_rounds):
    # select random clients
    client_idx = np.random.permutation(num_clients)[:num_selected]

    # client update
    loss = 0
    for i in range(num_selected):
        loss += client_update_fed(client_models[i], opt[i], train_loader[client_idx[i]], epoch=epochs)
    
    # serer aggregate
    server_aggregate(global_model, client_models)
    test_loss, acc = test(global_model, test_loader)
    test_loss_accu2.append(test_loss)
    acc_accu2.append(acc)
    
    print('%d-th round' % r)
    print('average train loss %0.3g | test loss %0.3g | test acc: %0.3f' % (loss / num_selected, test_loss, acc))
    
file_name = './save/objects/fed_sameEpoch{}_lr{}_round{}_loss_and_acc.pkl'. \
                format(epochs, lr, num_rounds)
with open(file_name, 'wb') as f:
            pickle.dump([test_loss_accu2, acc_accu2], f)


# Plot Average Accuracy vs Communication rounds
plt.figure()
plt.title('Accuracy of Fed with same Epoch vs Communication rounds')
plt.plot(range(len(acc_accu2)), acc_accu2, color='k')
plt.ylabel('Accuracy')
plt.xlabel('Communication Rounds')
plt.savefig('./save/fed_sameEpoch{}_lr{}_round{}_acc.png'.
                format(epochs, lr, num_rounds))

In [None]:
plt.figure()
plt.title('Acc vs Communication rounds')
plt.plot(range(len(acc_accu)), acc_accu, color='r', label='importance')
plt.plot(range(len(acc_accu1)), acc_accu1,color='b', label='uniform')
plt.plot(range(len(acc_accu2)), acc_accu2,color='g', label='Fed-uniform-uniformE')
plt.ylabel('Acc')
plt.xlabel('Communication Rounds')
plt.savefig('./save/fed_acc_Epoch{}_lr{}_round{}_acc.png'.
                format(epochs, lr,num_rounds))

In [None]:
acc_accu_tune = [acc_accu[i] for i in range(0,len(acc_accu),20)]
acc_accu_tune1 = [acc_accu1[i] for i in range(0,len(acc_accu1),20)]
acc_accu_tune2 = [acc_accu2[i] for i in range(0,len(acc_accu2),20)]

plt.figure()
plt.title('Acc vs Communication rounds')
plt.plot(range(len(acc_accu_tune)), acc_accu_tune, color='r', label='importance')
plt.plot(range(len(acc_accu_tune1)), acc_accu_tune1,color='b', label='uniform')
plt.plot(range(len(acc_accu_tune2)), acc_accu_tune2,color='g', label='Fed-uniform-uniformE')
plt.ylabel('Acc')
plt.xlabel('Communication Rounds')
plt.savefig('./save/tune10_fed_acc_Epoch{}_lr{}_round{}_acc.png'.
                format(epochs, lr, num_rounds))