In [9]:
# import global dependencies
import matplotlib
import matplotlib.pyplot as plt
import copy
import numpy as np
from torchvision import datasets, transforms
import torch
import random
from torch import nn
import heapq

In [2]:
def power(clients):
    clients_power = []
    for i in range(clients):
        rand = random.randint(1, 100)
        clients_power.append(rand)
    return clients_power

In [5]:
def clients_indexing(clients, clients_power):
    # p_11 --> state_1[1]
    # p_10 --> state_1[2]
    # p_01 --> state_0[1]
    # p_00 --> state_0[2]
    user_indices = []
    for i in range(clients):
        if args.clients_state[i] == 1:
            v_i_t = -(args.state_1[1]/args.num_users) - (((args.state_1[2]*clients_power[i])/100))
            user_indices.append(v_i_t)
            # print(f'client {clients[i]}, is in state {clients_state[i]}')
        elif args.clients_state[i] == 0:
            v_i_t = -(args.state_0[1]/args.num_users) - (((args.state_0[2]*clients_power[i])/100))
            user_indices.append(v_i_t)
      # print('Indices are', user_indices)
      # this prints the top k values
        top_k_users = heapq.nlargest(args.top_k, user_indices)
      # print(f'the top {top_k} users who can transmit are: {top_k_users}')
      # this prints the top k indices
        user_indices = np.argsort(user_indices)
        top_k_users = user_indices[-args.top_k:]
      # print(f'the top {top_k} users who can transmit are: {top_k_users}')
      # print(f'client {clients[i]}, is in state {clients_state[i]}')
    return top_k_users

In [6]:
def wireless_channel_transition_probability(clients):
    if args.clients_state == []:
        for i in range(clients):
            rand_transision = random.random()
            if rand_transision <= args.state_0[0]:
                args.clients_state.append(0)
            else:
                args.clients_state.append(1)
    else:
        for i in range(clients):
            rand_transision = random.random()
            if args.clients_state[i] == 0:
                if rand_transision <= args.state_0[1]:
                    args.clients_state[i] = 1
                else:
                    args.clients_state[i] = 0
            else:
                if rand_transision <= args.state_0[2]:
                    args.clients_state[i] = 0
                else:
                    args.clients_state[i] = 1

In [7]:
from torch import autograd
from torch.utils.data import DataLoader, Dataset
from sklearn import metrics


class DatasetSplit(Dataset):
    def __init__(self, dataset, idxs):
        self.dataset = dataset
        self.idxs = list(idxs)

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, item):
        image, label = self.dataset[self.idxs[item]]
        return image, label


class LocalUpdate(object):
    def __init__(self, args, dataset=None, idxs=None):
        self.args = args
        self.loss_func = nn.CrossEntropyLoss()
        self.selected_clients = []
        self.ldr_train = DataLoader(DatasetSplit(dataset, idxs), batch_size=self.args.local_bs, shuffle=True)

    def train(self, net):
        net.train()
        # train and update
        optimizer = torch.optim.SGD(net.parameters(), lr=self.args.lr, momentum=0.5)

        epoch_loss = []
        for iter in range(self.args.local_ep):
            batch_loss = []
            for batch_idx, (images, labels) in enumerate(self.ldr_train):
                images, labels = images.to(self.args.device), labels.to(self.args.device)
                net.zero_grad()
                log_probs = net(images)
                loss = self.loss_func(log_probs, labels)
                loss.backward()
                optimizer.step()
                if self.args.verbose and batch_idx % 10 == 0:
                    print('Update Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                        iter, batch_idx * len(images), len(self.ldr_train.dataset),
                               100. * batch_idx / len(self.ldr_train), loss.item()))
                batch_loss.append(loss.item())
            epoch_loss.append(sum(batch_loss)/len(batch_loss))
        return net.state_dict(), sum(epoch_loss) / len(epoch_loss)

In [10]:
class CNNMnist(nn.Module):
    def __init__(self, args):
        super(CNNMnist, self).__init__()
        self.conv1 = nn.Conv2d(args.num_channels, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, args.num_classes)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, x.shape[1]*x.shape[2]*x.shape[3])
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return x

In [11]:
import copy

def FedAvg(w, clients):
    w_avg = copy.deepcopy(w[0])
    for k in w_avg.keys():
        for i in range(1, len(w)):
            tens = torch.mul(w[i][k], clients[i])
            w_avg[k] += tens
        w_avg[k] = torch.div(w_avg[k], sum(clients))
    return w_avg

In [12]:
import torch.nn.functional as F
from torch.utils.data import DataLoader


def test_img(net_g, datatest, args):
    net_g.eval()
    # testing
    test_loss = 0
    correct = 0
    data_loader = DataLoader(datatest, batch_size=args.bs)
    l = len(data_loader)
    for idx, (data, target) in enumerate(data_loader):
        if args.gpu != -1:
            data, target = data.cuda(), target.cuda()
        log_probs = net_g(data)
        # sum up batch loss
        test_loss += F.cross_entropy(log_probs, target, reduction='sum').item()
        # get the index of the max log-probability
        y_pred = log_probs.data.max(1, keepdim=True)[1]
        correct += y_pred.eq(target.data.view_as(y_pred)).long().cpu().sum()

    test_loss /= len(data_loader.dataset)
    accuracy = 100.00 * correct / len(data_loader.dataset)
    if args.verbose:
        print('\nTest set: Average loss: {:.4f} \nAccuracy: {}/{} ({:.2f}%)\n'.format(
            test_loss, correct, len(data_loader.dataset), accuracy))
    return accuracy, test_loss

In [13]:
def mnist_iid(dataset, num_users):
    
    num_items = int(len(dataset)/num_users)
    dict_users, all_idxs = {}, [i for i in range(len(dataset))]
    for i in range(num_users):
        dict_users[i] = set(np.random.choice(
            all_idxs,
            random.randint(1,num_items),
            replace=False))
#         print(len(dict_users[i]))
        all_idxs = list(set(all_idxs) - dict_users[i])
    return dict_users

In [16]:
# parse args
class args:
    gpu = -1 # <- -1 if no GPU is available
    device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu')
    num_channels = 1
    num_users = 100
    top_k = 40
    num_classes = 10
    frac = 0.1
    lr = 0.1
    verbose = 0
    bs = 128
    epochs = 10
    
    iid = True        # < -This Value needs to be changed
    local_ep = 20     # <- This Value needs to be changed
    local_bs = 10     # <- This Value needs to be changed
    
    
    state_0 = [0.9449, 0.0087, 0.9913]
    state_1 = [0.0551, 0.8509, 0.1491]
    
    clients_state = []

    loss_train_fedavg = []
    loss_train_ibcs = []
    # training
    fedavg_accu = []
    fedavg_loss = []
    fedavg_power = 0
    ibcs_accu = []
    ibcs_loss = []
    ibcs_power = 0
    

trans_mnist = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
dataset_train = datasets.MNIST('./data/mnist/', train=True, download=True, transform=trans_mnist)
dataset_test = datasets.MNIST('./data/mnist/', train=False, download=True, transform=trans_mnist)
# sample users
if args.iid:
    dict_users = mnist_iid(dataset_train, args.num_users)
else:
    dict_users = mnist_noniid(dataset_train, args.num_users)

img_size = dataset_train[0][0].shape

# build model

net_glob_fedavg = CNNMnist(args=args).to(args.device)
net_glob_ibcs = CNNMnist(args=args).to(args.device)
# print(net_glob)
net_glob_fedavg.train()
net_glob_ibcs.train()

# copy weights
w_glob_fedavg = net_glob_fedavg.state_dict()
w_glob_ibcs = net_glob_ibcs.state_dict()

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/mnist/MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./data/mnist/MNIST/raw/train-images-idx3-ubyte.gz to ./data/mnist/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/mnist/MNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./data/mnist/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/mnist/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/mnist/MNIST/raw/t10k-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./data/mnist/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/mnist/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./data/mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/mnist/MNIST/raw
Processing...
Done!


In [None]:
def federated_learningFedAvg(epoch):
    temp_power = 0
    wireless_channel_transition_probability(args.num_users)
    
    w_locals, loss_locals, num_items = [], [], []
    idxs_users = np.random.choice(range(args.num_users), args.top_k, replace=False)
#     print(f'selected users {len(idxs_users)}')
    for idx in idxs_users:
        if args.clients_state[idx] == 0:
            num_items.append(len(dict_users[idx]))
            local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users[idx])
            w, loss = local.train(net=copy.deepcopy(net_glob_fedavg).to(args.device))
            w_locals.append(copy.deepcopy(w))
            loss_locals.append(copy.deepcopy(loss))
        else:
            temp_power += clients_power[idx]
    # update global weights
    args.w_glob_fedavg = FedAvg(w_locals, num_items)

    # copy weight to net_glob
    net_glob_fedavg.load_state_dict(w_glob_fedavg)

    # print loss
    loss_avg = sum(loss_locals) / len(loss_locals)
    print(f'Loss avg {loss_avg:.3f}')
    
    args.loss_train_fedavg.append(loss_avg)
    print(f'Train loss {args.loss_train_fedavg}')

    # Evaluate score
    net_glob_fedavg.eval()
    acc_test, loss_test = test_img(net_glob_fedavg, dataset_test, args)
    print('Round {:3d}, Average loss {:.3f}, Accuracy {:.3f}'.format(epoch, loss_avg, acc_test))
    args.fedavg_accu.append(acc_test)
    args.fedavg_loss.append(loss_test)
    args.fedavg_power += temp_power
    

# testing
net_glob_fedavg.eval()
acc_train, loss_train = test_img(net_glob_fedavg, dataset_train, args)
acc_test, loss_test = test_img(net_glob_fedavg, dataset_test, args)
print("Training accuracy: {:.2f}".format(acc_train))
print("Testing accuracy: {:.2f}".format(acc_test))

In [None]:
def federated_learningIBCS(epoch):
    temp_power = 0
    wireless_channel_transition_probability(args.num_users)
    
    w_locals, loss_locals, num_items = [], [], []
    idxs_users = clients_indexing(args.num_users, clients_power)
#     print(f'selected users {len(idxs_users)}')
    for idx in idxs_users:
        if args.clients_state[idx] == 0:
#             print(f'This is user {idx}')
            num_items.append(len(dict_users[idx]))
            local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users[idx])
            w, loss = local.train(net=copy.deepcopy(net_glob_ibcs).to(args.device))
            w_locals.append(copy.deepcopy(w))
            loss_locals.append(copy.deepcopy(loss))
        else:
            temp_power += clients_power[idx]
    # update global weights
    args.w_glob_ibcs = FedAvg(w_locals, num_items)

    # copy weight to net_glob
    net_glob_ibcs.load_state_dict(w_glob_ibcs)

    # print loss
    loss_avg = sum(loss_locals) / len(loss_locals)
    print(f'Loss avg {loss_avg:.3f}')
    
    args.loss_train_ibcs.append(loss_avg)
    print(f'Train loss {args.loss_train_ibcs}')

    # Evaluate score
    net_glob_ibcs.eval()
    acc_test, loss_test = test_img(net_glob_ibcs, dataset_test, args)
    print('Round {:3d}, Average loss {:.3f}, Accuracy {:.3f}'.format(epoch, loss_avg, acc_test))
    args.ibcs_accu.append(acc_test)
    args.ibcs_loss.append(loss_test)
    args.ibcs_power += temp_power
    

# testing
net_glob_ibcs.eval()
acc_train, loss_train = test_img(net_glob_ibcs, dataset_train, args)
acc_test, loss_test = test_img(net_glob_ibcs, dataset_test, args)
print("Training accuracy: {:.2f}".format(acc_train))
print("Testing accuracy: {:.2f}".format(acc_test))

In [18]:
for iter in range(1, args.epochs+1):
    print(f'This is Epoch# {iter}')

    w_locals, loss_locals, num_items = [], [], []
    idxs_users = np.random.choice(range(args.num_users), args.top_k, replace=False)
    for idx in idxs_users:
        num_items.append(len(dict_users[idx]))
        local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users[idx])
        w, loss = local.train(net=copy.deepcopy(net_glob_fedavg).to(args.device))
        w_locals.append(copy.deepcopy(w))
        loss_locals.append(copy.deepcopy(loss))
    # update global weights
    w_glob_fedavg = FedAvg(w_locals, num_items)

    # copy weight to net_glob
    net_glob_fedavg.load_state_dict(w_glob_fedavg)

    # print loss
    loss_avg = sum(loss_locals) / len(loss_locals)
    
    args.loss_train_fedavg.append(loss_avg)
    
    # Evaluate score
    net_glob_fedavg.eval()
    acc_test, loss_test = test_img(net_glob_fedavg, dataset_test, args)
    print('Round {:3d}, Average loss {:.3f}, Accuracy {:.3f}'.format(iter, loss_avg, acc_test))

# testing
net_glob_fedavg.eval()
acc_train, loss_train = test_img(net_glob_fedavg, dataset_train, args)
acc_test, loss_test = test_img(net_glob_fedavg, dataset_test, args)
print("Training accuracy: {:.2f}".format(acc_train))
print("Testing accuracy: {:.2f}".format(acc_test))

This is Epoch# 1
Round   1, Average loss 1.825, Accuracy 11.350
This is Epoch# 2
Round   2, Average loss 1.372, Accuracy 68.280
This is Epoch# 3
Round   3, Average loss 0.955, Accuracy 52.760
This is Epoch# 4
Round   4, Average loss 0.951, Accuracy 68.910
This is Epoch# 5
Round   5, Average loss 0.997, Accuracy 29.290
This is Epoch# 6
Round   6, Average loss nan, Accuracy 9.800
This is Epoch# 7
Round   7, Average loss nan, Accuracy 9.800
This is Epoch# 8
Round   8, Average loss nan, Accuracy 9.800
This is Epoch# 9
Round   9, Average loss nan, Accuracy 9.800
This is Epoch# 10
Round  10, Average loss nan, Accuracy 9.800
Training accuracy: 9.87
Testing accuracy: 9.80


In [21]:
for iter in range(1, args.epochs+1):
    print(f'This is Epoch# {iter}')

    w_locals, loss_locals, num_items = [], [], []
    m = max(int(args.frac * args.num_users), 1)
    print(f'M is: {m}')
    idxs_users = np.random.choice(range(args.num_users), m, replace=False)
    for idx in idxs_users:
        num_items.append(len(dict_users[idx]))
        local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users[idx])
        w, loss = local.train(net=copy.deepcopy(net_glob_fedavg).to(args.device))
        w_locals.append(copy.deepcopy(w))
        loss_locals.append(copy.deepcopy(loss))
    # update global weights
    w_glob_fedavg = FedAvg(w_locals, num_items)

    # copy weight to net_glob
    net_glob_fedavg.load_state_dict(w_glob_fedavg)

    # print loss
    loss_avg = sum(loss_locals) / len(loss_locals)
    
    args.loss_train_fedavg.append(loss_avg)
    
    # Evaluate score
    net_glob_fedavg.eval()
    acc_test, loss_test = test_img(net_glob_fedavg, dataset_test, args)
    print('Round {:3d}, Average loss {:.3f}, Accuracy {:.3f}'.format(iter, loss_avg, acc_test))

# testing
net_glob_fedavg.eval()
acc_train, loss_train = test_img(net_glob_fedavg, dataset_train, args)
acc_test, loss_test = test_img(net_glob_fedavg, dataset_test, args)
print("Training accuracy: {:.2f}".format(acc_train))
print("Testing accuracy: {:.2f}".format(acc_test))

This is Epoch# 1
M is: 10
Round   1, Average loss nan, Accuracy 9.800
This is Epoch# 2
M is: 10
Round   2, Average loss nan, Accuracy 9.800
This is Epoch# 3
M is: 10
Round   3, Average loss nan, Accuracy 9.800
This is Epoch# 4
M is: 10
Round   4, Average loss nan, Accuracy 9.800
This is Epoch# 5
M is: 10
Round   5, Average loss nan, Accuracy 9.800
This is Epoch# 6
M is: 10
Round   6, Average loss nan, Accuracy 9.800
This is Epoch# 7
M is: 10
Round   7, Average loss nan, Accuracy 9.800
This is Epoch# 8
M is: 10
Round   8, Average loss nan, Accuracy 9.800
This is Epoch# 9
M is: 10
Round   9, Average loss nan, Accuracy 9.800
This is Epoch# 10
M is: 10
Round  10, Average loss nan, Accuracy 9.800
Training accuracy: 9.87
Testing accuracy: 9.80


In [None]:
def plot():
    fig, ax = plt.subplots()
    ax.plot(args.fedavg_accu)
    ax.plot(args.ibcs_accu)

    ax.set_title('Accuracy')
    ax.legend(['FedAvg', 'IBCS'])
    ax.xaxis.set_label_text('Gobal Epochs')
    ax.yaxis.set_label_text('Accuracy in %')
    plt.show()
#     plt.savefig('./results/mnist/Accuracy_MNIST.png')

    fig, ax = plt.subplots()
    ax.plot(args.fedavg_loss)
    ax.plot(args.ibcs_loss)

    ax.set_title('Loss')
    ax.legend(['FedAvg', 'IBCS'])
    ax.xaxis.set_label_text('Gobal Epochs')
    ax.yaxis.set_label_text('Loss')
    # plt.show()
    plt.savefig('./results/mnist/Loss_MNIST.png')

    fig, ax = plt.subplots()
    ax.plot(args.fedavg_power)
    ax.plot(args.ibcs_power)

    ax.set_title('Power')
    ax.legend(['FedAvg', 'IBCS'])
    ax.xaxis.set_label_text('Gobal Epochs')
    ax.yaxis.set_label_text('Power')
    # plt.show()
    plt.savefig('./results/mnist/Power_MNIST.png')


plot()

In [1]:
# import global dependencies
import matplotlib
import matplotlib.pyplot as plt
import copy
import numpy as np
from torchvision import datasets, transforms
import torch
import random
from torch import nn

In [2]:
from torch import autograd
from torch.utils.data import DataLoader, Dataset
from sklearn import metrics


class DatasetSplit(Dataset):
    def __init__(self, dataset, idxs):
        self.dataset = dataset
        self.idxs = list(idxs)

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, item):
        image, label = self.dataset[self.idxs[item]]
        return image, label


class LocalUpdate(object):
    def __init__(self, args, dataset=None, idxs=None):
        self.args = args
        self.loss_func = nn.CrossEntropyLoss()
        self.selected_clients = []
        self.ldr_train = DataLoader(DatasetSplit(dataset, idxs), batch_size=self.args.local_bs, shuffle=True)

    def train(self, net):
        net.train()
        # train and update
        optimizer = torch.optim.SGD(net.parameters(), lr=self.args.lr, momentum=0.5)

        epoch_loss = []
        for iter in range(self.args.local_ep):
            batch_loss = []
            for batch_idx, (images, labels) in enumerate(self.ldr_train):
                images, labels = images.to(self.args.device), labels.to(self.args.device)
                net.zero_grad()
                log_probs = net(images)
                loss = self.loss_func(log_probs, labels)
                loss.backward()
                optimizer.step()
                if self.args.verbose and batch_idx % 10 == 0:
                    print('Update Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                        iter, batch_idx * len(images), len(self.ldr_train.dataset),
                               100. * batch_idx / len(self.ldr_train), loss.item()))
                batch_loss.append(loss.item())
            epoch_loss.append(sum(batch_loss)/len(batch_loss))
        return net.state_dict(), sum(epoch_loss) / len(epoch_loss)

In [3]:
class CNNMnist(nn.Module):
    def __init__(self, args):
        super(CNNMnist, self).__init__()
        self.conv1 = nn.Conv2d(args.num_channels, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, args.num_classes)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, x.shape[1]*x.shape[2]*x.shape[3])
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return x

In [4]:
import copy

def FedAvg(w, clients):
    w_avg = copy.deepcopy(w[0])
    for k in w_avg.keys():
        for i in range(1, len(w)):
            tens = torch.mul(w[i][k], clients[i])
            w_avg[k] += tens
        w_avg[k] = torch.div(w_avg[k], sum(clients))
    return w_avg

In [5]:
import torch.nn.functional as F
from torch.utils.data import DataLoader


def test_img(net_g, datatest, args):
    net_g.eval()
    # testing
    test_loss = 0
    correct = 0
    data_loader = DataLoader(datatest, batch_size=args.bs)
    l = len(data_loader)
    for idx, (data, target) in enumerate(data_loader):
        if args.gpu != -1:
            data, target = data.cuda(), target.cuda()
        log_probs = net_g(data)
        # sum up batch loss
        test_loss += F.cross_entropy(log_probs, target, reduction='sum').item()
        # get the index of the max log-probability
        y_pred = log_probs.data.max(1, keepdim=True)[1]
        correct += y_pred.eq(target.data.view_as(y_pred)).long().cpu().sum()

    test_loss /= len(data_loader.dataset)
    accuracy = 100.00 * correct / len(data_loader.dataset)
    if args.verbose:
        print('\nTest set: Average loss: {:.4f} \nAccuracy: {}/{} ({:.2f}%)\n'.format(
            test_loss, correct, len(data_loader.dataset), accuracy))
    return accuracy, test_loss


In [10]:
def mnist_iid(dataset, num_users):
    """
    Sample I.I.D. client data from MNIST dataset
    :param dataset:
    :param num_users:
    :return: dict of image index
    """
    num_items = int(len(dataset) / num_users)
    all_idxs = [i for i in range(len(dataset))]  # Initialize all_idxs here
    dict_users = {}
    for i in range(num_users):
        dict_users[i] = set(np.random.choice(
            all_idxs,
            num_items,
            replace=False))
        print(len(dict_users[i]))
        all_idxs = list(set(all_idxs) - dict_users[i])
    return dict_users

In [11]:
# parse args
class args:
    gpu = -1 # <- -1 if no GPU is available
    device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu')
    num_channels = 1
    num_users = 100
    num_classes = 10
    frac = 0.1
    lr = 0.1
    verbose = 0
    bs = 128
    epochs = 100
    
    iid = True        # < -This Value needs to be changed
    local_ep = 20     # <- This Value needs to be changed
    local_bs = 10     # <- This Value needs to be changed

trans_mnist = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
dataset_train = datasets.MNIST('../data/mnist/', train=True, download=True, transform=trans_mnist)
dataset_test = datasets.MNIST('../data/mnist/', train=False, download=True, transform=trans_mnist)
# sample users
if args.iid:
    dict_users = mnist_iid(dataset_train, args.num_users)
else:
    dict_users = mnist_noniid(dataset_train, args.num_users)

img_size = dataset_train[0][0].shape

# build model

net_glob = CNNMnist(args=args).to(args.device)
print(net_glob)
net_glob.train()

# copy weights
w_glob = net_glob.state_dict()

# training
loss_train = []
cv_loss, cv_acc = [], []
val_loss_pre, counter = 0, 0
net_best = None
best_loss = None
val_acc_list, net_list = [], []

600
600
600
600
600
600
600
600
600
600
600
600
600
600
600
600
600
600
600
600
600
600
600
600
600
600
600
600
600
600
600
600
600
600
600
600
600
600
600
600
600
600
600
600
600
600
600
600
600
600
600
600
600
600
600
600
600
600
600
600
600
600
600
600
600
600
600
600
600
600
600
600
600
600
600
600
600
600
600
600
600
600
600
600
600
600
600
600
600
600
600
600
600
600
600
600
600
600
600
600
CNNMnist(
  (conv1): Conv2d(1, 10, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(10, 20, kernel_size=(5, 5), stride=(1, 1))
  (conv2_drop): Dropout2d(p=0.5, inplace=False)
  (fc1): Linear(in_features=320, out_features=50, bias=True)
  (fc2): Linear(in_features=50, out_features=10, bias=True)
)


In [12]:
for iter in range(args.epochs):

    w_locals, loss_locals, num_items = [], [], []
    m = max(int(args.frac * args.num_users), 1)
    idxs_users = np.random.choice(range(args.num_users), m, replace=False)
    for idx in idxs_users:
        num_items.append(len(dict_users[idx]))
        local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users[idx])
        w, loss = local.train(net=copy.deepcopy(net_glob).to(args.device))
        w_locals.append(copy.deepcopy(w))
        loss_locals.append(copy.deepcopy(loss))
    # update global weights
    w_glob = FedAvg(w_locals, idxs_users)

    # copy weight to net_glob
    net_glob.load_state_dict(w_glob)

    # print loss
    loss_avg = sum(loss_locals) / len(loss_locals)
    
    loss_train.append(loss_avg)
    
    # Evaluate score
    net_glob.eval()
    acc_test, loss_test = test_img(net_glob, dataset_test, args)
    print('Round {:3d}, Average loss {:.3f}, Accuracy {:.3f}'.format(iter, loss_avg, acc_test))

# testing
net_glob.eval()
acc_train, loss_train = test_img(net_glob, dataset_train, args)
acc_test, loss_test = test_img(net_glob, dataset_test, args)
print("Training accuracy: {:.2f}".format(acc_train))
print("Testing accuracy: {:.2f}".format(acc_test))

Round   0, Average loss 1.748, Accuracy 11.400
Round   1, Average loss 1.370, Accuracy 56.390
Round   2, Average loss 1.285, Accuracy 9.740
Round   3, Average loss 1.411, Accuracy 18.060
Round   4, Average loss 1.268, Accuracy 37.070
Round   5, Average loss 1.260, Accuracy 9.740
Round   6, Average loss 2.313, Accuracy 11.350
Round   7, Average loss 2.302, Accuracy 10.100
Round   8, Average loss 2.304, Accuracy 11.350
Round   9, Average loss 2.301, Accuracy 11.350
Round  10, Average loss 2.301, Accuracy 11.350
Round  11, Average loss 2.301, Accuracy 11.350
Round  12, Average loss 2.298, Accuracy 10.280
Round  13, Average loss 2.300, Accuracy 10.280
Round  14, Average loss 2.301, Accuracy 11.350
Round  15, Average loss 2.299, Accuracy 11.350
Round  16, Average loss 2.301, Accuracy 11.350
Round  17, Average loss 2.301, Accuracy 10.280
Round  18, Average loss 2.302, Accuracy 11.350
Round  19, Average loss 2.302, Accuracy 11.350
Round  20, Average loss 2.303, Accuracy 9.800
Round  21, Avera