In [47]:
# import global dependencies
import copy
import heapq
import os
import random

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn
from torchvision import datasets, transforms

In [2]:
def power(clients):
    """Return one random integer power in [1, 100] per client.

    :param clients: number of clients to draw a power value for
    :return: list of length ``clients``; entry i is client i's power
    """
    return [random.randint(1, 100) for _ in range(clients)]

In [49]:
def clients_indexing(clients, clients_power):
    """Select the ``args.top_k`` clients with the largest index value.

    Client i is scored from its current channel state ``args.clients_state[i]``
    and its power ``clients_power[i]``:

        state 1: v_i = -p_11 / num_users - p_10 * power_i / 100
        state 0: v_i = -p_01 / num_users - p_00 * power_i / 100

    where p_11 = state_1[1], p_10 = state_1[2], p_01 = state_0[1] and
    p_00 = state_0[2].

    Args:
        clients: number of clients (an int); clients 0..clients-1 are scored.
        clients_power: per-client power values, indexable by client id.

    Returns:
        numpy array of selected client ids ordered by increasing index value
        (the best client is last). Empty when ``args.top_k <= 0``.

    Raises:
        ValueError: if a client's channel state is neither 0 nor 1.
    """
    user_indices = []
    for i in range(clients):
        state = args.clients_state[i]
        if state == 1:
            probs = args.state_1
        elif state == 0:
            probs = args.state_0
        else:
            # Skipping this client would shift every later entry, so argsort
            # positions would no longer correspond to client ids.
            raise ValueError(f'client {i} has unknown channel state {state!r}')
        v_i_t = -(probs[1] / args.num_users) - ((probs[2] * clients_power[i]) / 100)
        user_indices.append(v_i_t)

    order = np.argsort(user_indices)
    if args.top_k <= 0:
        # order[-0:] is the whole array, which would select every client.
        return order[:0]
    # Last top_k positions of the ascending sort are the largest indices.
    return order[-args.top_k:]

In [50]:
def wireless_channel_transition_probability(clients):
    """Advance every client's two-state Markov channel by one round, in place.

    ``args.clients_state`` holds one 0/1 entry per client and is mutated in
    place (never rebound).

    * First call (empty list): each client starts in state 0 with probability
      ``args.state_0[0]``, otherwise in state 1.
    * Later calls: each client transitions with
      p_01 = ``args.state_0[1]`` (0 -> 1) and p_10 = ``args.state_1[2]`` (1 -> 0).

    Args:
        clients: number of clients to initialise / advance.
    """
    states = args.clients_state
    if not states:
        for _ in range(clients):
            states.append(0 if random.random() <= args.state_0[0] else 1)
        return

    p_01 = args.state_0[1]
    # The 1 -> 0 probability is p_10 = state_1[2]; state_0[2] is p_00.
    # Consistency check with the values configured in `args`:
    # state_0[0] = 0.9449 = p_10 / (p_01 + p_10) = 0.1491 / (0.0087 + 0.1491),
    # i.e. the initial probability is the chain's stationary P(state 0).
    p_10 = args.state_1[2]
    for i in range(clients):
        draw = random.random()
        if states[i] == 0:
            states[i] = 1 if draw <= p_01 else 0
        else:
            states[i] = 0 if draw <= p_10 else 1

In [51]:
from torch import autograd
from torch.utils.data import DataLoader, Dataset
from sklearn import metrics


class DatasetSplit(Dataset):
    """View of ``dataset`` restricted to the positions listed in ``idxs``.

    Item ``i`` of the view is item ``idxs[i]`` of the wrapped dataset, which
    is unpacked as an ``(image, label)`` pair.
    """

    def __init__(self, dataset, idxs):
        self.dataset = dataset
        # Materialise as a list so set-valued idxs (mnist_iid returns sets)
        # can be indexed by position.
        self.idxs = list(idxs)

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, item):
        source_position = self.idxs[item]
        image, label = self.dataset[source_position]
        return (image, label)


class LocalUpdate(object):
    def __init__(self, args, dataset=None, idxs=None):
        self.args = args
        self.loss_func = nn.CrossEntropyLoss()
        self.selected_clients = []
        self.ldr_train = DataLoader(DatasetSplit(dataset, idxs), batch_size=self.args.local_bs, shuffle=True)

    def train(self, net):
        net.train()
        # train and update
        optimizer = torch.optim.SGD(net.parameters(), lr=self.args.lr, momentum=0.5)

        epoch_loss = []
        for iter in range(self.args.local_ep):
            batch_loss = []
            for batch_idx, (images, labels) in enumerate(self.ldr_train):
                images, labels = images.to(self.args.device), labels.to(self.args.device)
                net.zero_grad()
                log_probs = net(images)
                loss = self.loss_func(log_probs, labels)
                loss.backward()
                optimizer.step()
                if self.args.verbose and batch_idx % 10 == 0:
                    print('Update Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                        iter, batch_idx * len(images), len(self.ldr_train.dataset),
                               100. * batch_idx / len(self.ldr_train), loss.item()))
                batch_loss.append(loss.item())
            epoch_loss.append(sum(batch_loss)/len(batch_loss))
        return net.state_dict(), sum(epoch_loss) / len(epoch_loss)

In [52]:
class CNNMnist(nn.Module):
    def __init__(self, args):
        super(CNNMnist, self).__init__()
        self.conv1 = nn.Conv2d(args.num_channels, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, args.num_classes)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, x.shape[1]*x.shape[2]*x.shape[3])
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return x

In [53]:
import copy

def FedAvg(w, clients):
    """Weighted average of client model weights (FedAvg aggregation).

    Args:
        w: list of state dicts, one per participating client, all with the
            same keys.
        clients: per-client weights aligned with ``w`` (the callers pass the
            number of local samples of each client).

    Returns:
        A new state dict whose entry ``k`` is
        ``sum_i clients[i] * w[i][k] / sum(clients)``.

    Raises:
        ValueError: if ``w`` is empty, ``w`` and ``clients`` differ in length,
            or the weights sum to zero.
    """
    if not w:
        raise ValueError('FedAvg needs at least one client state dict')
    if len(w) != len(clients):
        raise ValueError(f'got {len(w)} state dicts but {len(clients)} client weights')
    total = sum(clients)
    if total == 0:
        raise ValueError('client weights sum to zero')

    # Deep copy keeps w[0]'s container type and never touches the callers' tensors.
    w_avg = copy.deepcopy(w[0])
    for k in w_avg.keys():
        # Every client, including w[0], contributes in proportion to its weight.
        weighted = w[0][k] * clients[0]
        for i in range(1, len(w)):
            weighted = weighted + w[i][k] * clients[i]
        w_avg[k] = torch.div(weighted, total)
    return w_avg

In [54]:
import torch.nn.functional as F
from torch.utils.data import DataLoader


def test_img(net_g, datatest, args):
    net_g.eval()
    # testing
    test_loss = 0
    correct = 0
    data_loader = DataLoader(datatest, batch_size=args.bs)
    l = len(data_loader)
    for idx, (data, target) in enumerate(data_loader):
        if args.gpu != -1:
            data, target = data.cuda(), target.cuda()
        log_probs = net_g(data)
        # sum up batch loss
        test_loss += F.cross_entropy(log_probs, target, reduction='sum').item()
        # get the index of the max log-probability
        y_pred = log_probs.data.max(1, keepdim=True)[1]
        correct += y_pred.eq(target.data.view_as(y_pred)).long().cpu().sum()

    test_loss /= len(data_loader.dataset)
    accuracy = 100.00 * correct / len(data_loader.dataset)
    if args.verbose:
        print('\nTest set: Average loss: {:.4f} \nAccuracy: {}/{} ({:.2f}%)\n'.format(
            test_loss, correct, len(data_loader.dataset), accuracy))
    return accuracy, test_loss

In [55]:
def mnist_iid(dataset, num_users):
    """
    Sample I.I.D. client data from MNIST dataset
    :param dataset:
    :param num_users:
    :return: dict of image index
    """
    num_items = int(len(dataset)/num_users)
    dict_users, all_idxs = {}, [i for i in range(len(dataset))]
    for i in range(num_users):
        dict_users[i] = set(np.random.choice(
            all_idxs,
            random.randint(1,num_items),
            replace=False))
#         print(len(dict_users[i]))
        all_idxs = list(set(all_idxs) - dict_users[i])
    return dict_users

In [56]:
# parse args
class args:
    """Experiment configuration plus shared mutable run state.

    Used as a namespace and never instantiated. Functions read
    hyper-parameters from it, ``clients_state`` holds each client's current
    channel state, and the per-round metric lists are appended to.
    """
    gpu = -1 # <- -1 if no GPU is available
    # A class body cannot refer to the class it is defining, so use the local
    # name `gpu` here. `args.gpu` raises NameError on a fresh kernel, or reads
    # a stale `args` left over from an earlier run.
    device = torch.device('cuda:{}'.format(gpu) if torch.cuda.is_available() and gpu != -1 else 'cpu')
    num_channels = 1
    num_users = 100
    top_k = 40          # clients selected per round
    num_classes = 10
    frac = 0.1          # not read by the code shown
    lr = 0.1
    verbose = 0
    bs = 128            # evaluation batch size (test_img)
    epochs = 10         # global rounds run by the main loop

    iid = True        # <- This Value needs to be changed
    local_ep = 20     # <- This Value needs to be changed
    local_bs = 10     # <- This Value needs to be changed

    # Two-state channel model per client.
    # state_0[0]: probability that a client starts in state 0.
    # Transition probabilities: p_01 = state_0[1], p_00 = state_0[2],
    #                           p_11 = state_1[1], p_10 = state_1[2].
    state_0 = [0.9449, 0.0087, 0.9913]
    state_1 = [0.0551, 0.8509, 0.1491]

    clients_state = []   # current 0/1 channel state of each client

    loss_train = []
    # training metrics, appended once per round
    fedavg_accu = []
    fedavg_loss = []
    fedavg_power = 0     # running total (a single number, not a per-round list)
    ibcs_accu = []
    ibcs_loss = []
    ibcs_power = 0       # running total (a single number, not a per-round list)


# Build the global models only after the class exists: CNNMnist(args=args)
# inside the class body would reference `args` before it is defined.
args.net_glob_fedavg = CNNMnist(args=args).to(args.device)
args.net_glob_ibcs = CNNMnist(args=args).to(args.device)
args.net_glob_fedavg.train()
args.net_glob_ibcs.train()

# copy weights
args.w_glob_fedavg = args.net_glob_fedavg.state_dict()
args.w_glob_ibcs = args.net_glob_ibcs.state_dict()

# The training and evaluation code refers to these as module-level names.
net_glob_fedavg = args.net_glob_fedavg
net_glob_ibcs = args.net_glob_ibcs
w_glob_fedavg = args.w_glob_fedavg
w_glob_ibcs = args.w_glob_ibcs

# Convert images to tensors and normalise them with the fixed mean/std below.
mnist_root = '../data/mnist/'
trans_mnist = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])
dataset_train = datasets.MNIST(mnist_root, train=True, download=True, transform=trans_mnist)
dataset_test = datasets.MNIST(mnist_root, train=False, download=True, transform=trans_mnist)

# sample users: split the training indices across clients.
# NOTE(review): mnist_noniid is not defined in the cells shown here; confirm it
# exists before setting args.iid = False.
partition_fn = mnist_iid if args.iid else mnist_noniid
dict_users = partition_fn(dataset_train, args.num_users)

img_size = dataset_train[0][0].shape

In [43]:
def federated_learningFedAvg(epoch):
    """Run one FedAvg round with ``args.top_k`` uniformly sampled clients.

    Steps:
    1. Advance the channel model.
    2. Sample ``args.top_k`` distinct clients uniformly.
    3. Train each sampled client whose channel state is 0, starting from a
       copy of ``net_glob_fedavg``.
    4. Average their weights into ``net_glob_fedavg``, weighted by local
       sample count.
    5. Evaluate on ``dataset_test`` and append the results to ``args``.

    Sampled clients in state 1 do not train; their ``clients_power`` is added
    to ``args.fedavg_power``.

    Args:
        epoch: round number, used only for logging.
    """
    temp_power = 0
    wireless_channel_transition_probability(args.num_users)

    w_locals, loss_locals, num_items = [], [], []
    idxs_users = np.random.choice(range(args.num_users), args.top_k, replace=False)
    print(f'selected users {len(idxs_users)}')
    for idx in idxs_users:
        if args.clients_state[idx] == 0:
            num_items.append(len(dict_users[idx]))
            local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users[idx])
            w, loss = local.train(net=copy.deepcopy(net_glob_fedavg).to(args.device))
            w_locals.append(copy.deepcopy(w))
            loss_locals.append(loss)
        else:
            temp_power += clients_power[idx]

    if w_locals:
        # update global weights
        net_glob_fedavg.load_state_dict(FedAvg(w_locals, num_items))
        loss_avg = sum(loss_locals) / len(loss_locals)
    else:
        # No sampled client had a usable channel: keep the global model as is
        # instead of crashing in FedAvg or dividing by zero below.
        print('No selected client could transmit; global model unchanged')
        loss_avg = float('nan')
    print(f'Loss avg {loss_avg}')

    args.loss_train.append(loss_avg)
    # args.loss_train, not the module-level `loss_train` (a test_img result).
    print(f'Train loss {args.loss_train}')

    # Evaluate score
    net_glob_fedavg.eval()
    acc_test, loss_test = test_img(net_glob_fedavg, dataset_test, args)
    print('Round {:3d}, Average loss {:.3f}, Accuracy {:.3f}'.format(epoch, loss_avg, acc_test))
    args.fedavg_accu.append(acc_test)
    args.fedavg_loss.append(loss_test)
    args.fedavg_power += temp_power
    

# Evaluate the current FedAvg global model on the full training and test splits.
net_glob_fedavg.eval()
acc_train, loss_train = test_img(net_glob_fedavg, dataset_train, args)
acc_test, loss_test = test_img(net_glob_fedavg, dataset_test, args)
for report in ("Training accuracy: {:.2f}".format(acc_train),
               "Testing accuracy: {:.2f}".format(acc_test)):
    print(report)

Training accuracy: 7.51
Testing accuracy: 7.66


In [None]:
def federated_learningIBCS(epoch):
    """Run one IBCS round: index-based client selection, training, aggregation.

    Steps:
    1. Advance the channel model.
    2. Select ``args.top_k`` clients with ``clients_indexing`` (highest index
       values).
    3. Train each selected client whose channel state is 0, starting from a
       copy of ``net_glob_ibcs``.
    4. Average their weights into ``net_glob_ibcs``, weighted by local sample
       count.
    5. Evaluate on ``dataset_test`` and append the results to ``args``.

    Selected clients in state 1 do not train; their ``clients_power`` is added
    to ``args.ibcs_power``.

    Args:
        epoch: round number, used only for logging.
    """
    temp_power = 0
    wireless_channel_transition_probability(args.num_users)

    w_locals, loss_locals, num_items = [], [], []
    idxs_users = clients_indexing(args.num_users, clients_power)
    print(f'selected users {len(idxs_users)}')
    for idx in idxs_users:
        if args.clients_state[idx] == 0:
            num_items.append(len(dict_users[idx]))
            local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users[idx])
            w, loss = local.train(net=copy.deepcopy(net_glob_ibcs).to(args.device))
            w_locals.append(copy.deepcopy(w))
            loss_locals.append(loss)
        else:
            temp_power += clients_power[idx]

    if w_locals:
        # update global weights
        net_glob_ibcs.load_state_dict(FedAvg(w_locals, num_items))
        loss_avg = sum(loss_locals) / len(loss_locals)
    else:
        # No selected client had a usable channel: keep the global model as is
        # instead of crashing in FedAvg or dividing by zero below.
        print('No selected client could transmit; global model unchanged')
        loss_avg = float('nan')
    print(f'Loss avg {loss_avg}')

    args.loss_train.append(loss_avg)
    # args.loss_train, not the module-level `loss_train` (a test_img result).
    print(f'Train loss {args.loss_train}')

    # Evaluate score
    net_glob_ibcs.eval()
    acc_test, loss_test = test_img(net_glob_ibcs, dataset_test, args)
    print('Round {:3d}, Average loss {:.3f}, Accuracy {:.3f}'.format(epoch, loss_avg, acc_test))
    args.ibcs_accu.append(acc_test)
    args.ibcs_loss.append(loss_test)
    args.ibcs_power += temp_power
    

# Evaluate the current IBCS global model on the full training and test splits.
net_glob_ibcs.eval()
acc_train, loss_train = test_img(net_glob_ibcs, dataset_train, args)
acc_test, loss_test = test_img(net_glob_ibcs, dataset_test, args)
for report in ("Training accuracy: {:.2f}".format(acc_train),
               "Testing accuracy: {:.2f}".format(acc_test)):
    print(report)

In [62]:
# assign power to all users (fixed for the whole run)
clients_power = power(args.num_users)

# Set to True to also train the FedAvg baseline each round. plot() draws both
# FedAvg and IBCS curves, so the FedAvg ones stay empty while this is False.
run_fedavg = False

# One channel-state list per algorithm, kept across rounds. Clearing
# args.clients_state at the start of every round would re-draw every client
# from the initial distribution each time. The Markov transition branch of
# wireless_channel_transition_probability would then never run.
fedavg_clients_state = []
ibcs_clients_state = []

for i in range(1, args.epochs + 1):
    print(f'This is epoch# {i}')

    if run_fedavg:
        args.loss_train.clear()
        args.clients_state = fedavg_clients_state
        federated_learningFedAvg(i)

    args.loss_train.clear()
    args.clients_state = ibcs_clients_state
    print(f'Loss Train {args.loss_train}')
    print(f'Client Stat {args.clients_state}')
    federated_learningIBCS(i)
    print('\n')
    print(f'Loss Train {args.loss_train}')
    print(f'Client Stat {args.clients_state}')
    

This is epoch# 1
Loss Train []
Client Stat []
selected users 40
This is user 66
This is user 20
This is user 7
This is user 34
This is user 1
This is user 92
This is user 24
This is user 91


KeyboardInterrupt: 

In [None]:
def plot():
    """Plot FedAvg vs IBCS results from ``args`` and save them to ./results/mnist/.

    - Accuracy and loss are per-round lists (one entry per training round).
    - ``args.fedavg_power`` / ``args.ibcs_power`` are running totals (single
      numbers), so power is drawn as a two-bar chart. Passing a scalar to
      ``ax.plot`` draws nothing visible.
    """
    results_dir = './results/mnist'
    # savefig raises FileNotFoundError when the directory does not exist.
    os.makedirs(results_dir, exist_ok=True)

    def _line_plot(fedavg_series, ibcs_series, title, ylabel, filename):
        # Draw one FedAvg-vs-IBCS line chart, save it, then show it.
        # Accuracy entries come from test_img as 0-dim tensors; plot plain floats.
        fig, ax = plt.subplots()
        ax.plot([float(v) for v in fedavg_series])
        ax.plot([float(v) for v in ibcs_series])
        ax.set_title(title)
        ax.legend(['FedAvg', 'IBCS'])
        ax.xaxis.set_label_text('Global Epochs')
        ax.yaxis.set_label_text(ylabel)
        fig.savefig(os.path.join(results_dir, filename))
        plt.show()

    _line_plot(args.fedavg_accu, args.ibcs_accu, 'Accuracy', 'Accuracy in %', 'Accuracy_MNIST.png')
    _line_plot(args.fedavg_loss, args.ibcs_loss, 'Loss', 'Loss', 'Loss_MNIST.png')

    fig, ax = plt.subplots()
    ax.bar(['FedAvg', 'IBCS'], [float(args.fedavg_power), float(args.ibcs_power)])
    ax.set_title('Power')
    ax.xaxis.set_label_text('Algorithm')
    ax.yaxis.set_label_text('Total power')
    fig.savefig(os.path.join(results_dir, 'Power_MNIST.png'))
    plt.show()


plot()