In [None]:
from torch import nn
import torch.nn.functional as F
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
import numpy as np
from torchvision import datasets, transforms
import copy

#### Adapted from the MIT-licensed code credited below

MIT License

Copyright (c) 2019 Ashwin R Jadhav

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

#### Models

In [None]:
class MLP(nn.Module):
    """Two-layer perceptron for 28x28 single-channel images over 10 classes.

    ``forward`` returns log-probabilities (one row per sample), which is what
    ``nn.NLLLoss`` -- the criterion used for training and evaluation -- expects.
    """

    def __init__(self):
        super(MLP, self).__init__()
        self.layer_input = nn.Linear(28*28, 512)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout()
        self.layer_hidden = nn.Linear(512, 10)
        # NLLLoss expects log-probabilities; a plain Softmax here would turn the
        # loss into -p(correct class) instead of -log p(correct class).
        self.log_softmax = nn.LogSoftmax(dim=1)

    def forward(self, x):
        # Flatten everything except the batch dimension, e.g.
        # (B, 1, 28, 28) -> (B, 784) or (B, 28, 28) -> (B, 784).
        x = torch.flatten(x, start_dim=1)
        x = self.layer_input(x)
        x = self.dropout(x)
        x = self.relu(x)
        x = self.layer_hidden(x)
        return self.log_softmax(x)

In [None]:
class CNNMnist(nn.Module):
    """Small two-conv CNN for 1x28x28 inputs producing log-probabilities over 10 classes.

    Spatial sizes: 28 -conv5-> 24 -pool2-> 12 -conv5-> 8 -pool2-> 4, so the
    flattened feature vector has 20 * 4 * 4 = 320 entries.
    """

    def __init__(self):
        super().__init__()
        # Submodule names define the state_dict keys ('conv1.weight', ...);
        # the aggregators dispatch on those key names, so keep them stable.
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        first = F.relu(F.max_pool2d(self.conv1(x), 2))
        second = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(first)), 2))
        features = torch.flatten(second, start_dim=1)
        hidden = F.dropout(F.relu(self.fc1(features)), training=self.training)
        return F.log_softmax(self.fc2(hidden), dim=1)

#### Data preparation

In [None]:
class DatasetSplit(Dataset):
    """Dataset view exposing only the samples of ``dataset`` listed in ``idxs``.

    Used to hand each federated client its own shard of a shared dataset.
    Items are returned as ``(image_tensor, label_tensor)``.
    """

    def __init__(self, dataset, idxs):
        self.dataset = dataset
        # Normalize to plain Python ints (index collections may hold e.g. NumPy integers).
        self.idxs = [int(i) for i in idxs]

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, item):
        image, label = self.dataset[self.idxs[item]]
        # torch.as_tensor reuses an existing tensor instead of copying it
        # (torch.tensor(tensor) copies and emits a UserWarning on every sample)
        # while still converting plain Python labels to tensors.
        return torch.as_tensor(image), torch.as_tensor(label)

In [None]:
def mnist_iid(dataset, num_users):
    """
    Sample I.I.D. client data from MNIST dataset.

    Shuffles all sample indices once and deals them out into ``num_users``
    disjoint shards of ``len(dataset) // num_users`` indices each; the
    ``len(dataset) % num_users`` leftover indices are not assigned to anyone.

    :param dataset: any sized dataset (only ``len(dataset)`` is used)
    :param num_users: number of clients; must be positive
    :return: dict mapping client index (0..num_users-1) -> set of sample indices
    :raises ValueError: if ``num_users`` is not positive
    """
    if num_users <= 0:
        raise ValueError(f"num_users must be positive, got {num_users}")
    num_items = len(dataset) // num_users
    # One random permutation yields the same uniformly random disjoint partition
    # as repeated sampling without replacement, without rebuilding the remaining
    # pool (a full set difference) for every client.
    shuffled = np.random.permutation(len(dataset))
    return {
        user: set(shuffled[user * num_items:(user + 1) * num_items].tolist())
        for user in range(num_users)
    }


In [None]:
def get_dataset(dataset, num_users=100):
    """ Returns train and test datasets and a user group which is a dict where
    the keys are the user index and the values are the sets of training-sample
    indices assigned to each of those users (IID split).

    :param dataset: 'mnist' or 'fmnist'
    :param num_users: number of IID client shards to create (default 100)
    :return: (train_dataset, test_dataset, user_groups)
    :raises ValueError: for any other dataset name
    """
    # Explicit membership test: `dataset == 'mnist' or 'fmnist'` is always truthy
    # and would silently load FashionMNIST for any unknown name.
    if dataset not in ('mnist', 'fmnist'):
        raise ValueError(
            f"Unsupported dataset {dataset!r}; expected 'mnist' or 'fmnist'")

    is_mnist = dataset == 'mnist'
    data_dir = '../data/mnist/' if is_mnist else '../data/fmnist/'
    dataset_cls = datasets.MNIST if is_mnist else datasets.FashionMNIST

    # NOTE(review): (0.1307, 0.3081) are MNIST statistics; they are also applied
    # to FashionMNIST, whose statistics presumably differ -- confirm intended.
    apply_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))])

    train_dataset = dataset_cls(data_dir, train=True, download=True,
                                transform=apply_transform)
    test_dataset = dataset_cls(data_dir, train=False, download=True,
                               transform=apply_transform)

    user_groups = mnist_iid(train_dataset, num_users)
    return train_dataset, test_dataset, user_groups

#### Local Update

In [None]:
# Compute device shared by training and evaluation: the first GPU when CUDA is
# available, otherwise CPU so the notebook still runs on GPU-less machines.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

In [None]:
class LocalUpdate(object):
    """Local training for a single federated client.

    The client's sample indices are split 80/10/10 into train/validation/test
    loaders; ``update_weights`` trains a given model on the training split.
    """

    def __init__(self, dataset, idxs):
        self.trainloader, self.validloader, self.testloader = self.train_val_test(
            dataset, list(idxs))
        self.device = device
        # NLLLoss expects the model to output log-probabilities.
        self.criterion = nn.NLLLoss().to(self.device)

    def train_val_test(self, dataset, idxs):
        """
        Returns train, validation and test dataloaders for a given dataset
        and user indexes.

        ``idxs`` is split in its given order (80% / 10% / 10%). Training batches
        hold 10 samples; validation/test batches hold about a tenth of their
        split, but at least 1 -- a batch size of 0 would make DataLoader raise
        for clients whose split has fewer than 10 samples.
        """
        n = len(idxs)
        idxs_train = idxs[:int(0.8 * n)]
        idxs_val = idxs[int(0.8 * n):int(0.9 * n)]
        idxs_test = idxs[int(0.9 * n):]

        trainloader = DataLoader(DatasetSplit(dataset, idxs_train),
                                 batch_size=10, shuffle=True)
        validloader = DataLoader(DatasetSplit(dataset, idxs_val),
                                 batch_size=max(1, len(idxs_val) // 10), shuffle=False)
        testloader = DataLoader(DatasetSplit(dataset, idxs_test),
                                batch_size=max(1, len(idxs_test) // 10), shuffle=False)
        return trainloader, validloader, testloader

    def update_weights(self, local_ep, model, attack, lr=0.01, momentum=0.5):
        """Train ``model`` in place on this client's training split.

        :param local_ep: number of passes over the training loader
        :param model: module to train (moved data goes to ``self.device``)
        :param attack: 'label' replaces every training label with class 0
            (label-flipping attack); any other value trains on the true labels
        :param lr: SGD learning rate (default 0.01)
        :param momentum: SGD momentum (default 0.5)
        :return: ``model.state_dict()`` after training
        """
        model.train()
        optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum)

        for _ in range(local_ep):
            for images, labels in self.trainloader:
                images, labels = images.to(self.device), labels.to(self.device)
                if attack == 'label':
                    labels = torch.zeros_like(labels, dtype=torch.long)
                model.zero_grad()
                loss = self.criterion(model(images), labels)
                loss.backward()
                optimizer.step()

        return model.state_dict()

def test_inference(model, test_dataset):
    """ Returns the test accuracy and loss.
    """

    model.eval()
    loss, total, correct = 0.0, 0.0, 0.0

    criterion = nn.NLLLoss().to(device)
    testloader = DataLoader(test_dataset, batch_size=128,
                            shuffle=False)

    for batch_idx, (images, labels) in enumerate(testloader):
        images, labels = images.to(device), labels.to(device)

        # Inference
        outputs = model(images)
        batch_loss = criterion(outputs, labels)
        loss += batch_loss.item()

        # Prediction
        _, pred_labels = torch.max(outputs, 1)
        pred_labels = pred_labels.view(-1)
        correct += torch.sum(torch.eq(pred_labels, labels)).item()
        total += len(labels)

    accuracy = correct/total
    return accuracy, loss

#### Aggregators

In [None]:
def average_weights(w):
    """
    Unweighted element-wise mean of a list of client state dicts.

    :param w: non-empty list of state dicts sharing the same keys and shapes
    :return: a new state dict (deep copy of ``w[0]`` overwritten with the means);
        the input dicts are not modified
    """
    n_clients = len(w)
    averaged = copy.deepcopy(w[0])
    for name in list(averaged):
        running = averaged[name]
        # Accumulate in place on the copied tensor, then divide once.
        for state in w[1:]:
            running += state[name]
        averaged[name] = torch.div(running, n_clients)
    return averaged

In [None]:
def median(w):
    """Coordinate-wise median aggregation of client state dicts.

    For every key whose name contains 'bias' or 'weight', the result is the
    element-wise ``np.median`` across clients (the mean of the two middle
    values when ``len(w)`` is even). Works for tensors of any rank. All other
    keys are copied unchanged from ``w[0]``.

    :param w: non-empty list of state dicts with identical keys and shapes
    :return: new state dict; aggregated entries are CPU tensors
    :raises ValueError: if ``w`` is empty
    """
    if not w:
        raise ValueError("median() needs at least one client state dict")
    w_med = copy.deepcopy(w[0])
    for k in w_med.keys():
        if 'bias' in k or 'weight' in k:
            # Stack along a new leading client axis so any tensor rank is supported.
            stacked = np.stack([client[k].detach().cpu().numpy() for client in w])
            w_med[k] = torch.as_tensor(np.median(stacked, axis=0))
    return w_med

In [None]:
def approxmed5(w):
    """Approximate coordinate-wise median via repeated median-of-5 grouping.

    For every key whose name contains 'bias' or 'weight':
      1. the client tensors are shuffled with a fresh ``np.random.permutation``
         (one draw per such key);
      2. consecutive groups of 5 are each replaced by their element-wise median
         (a trailing group with fewer than 5 members still contributes its
         median, so no client is dropped);
      3. step 2 repeats until at most 5 values remain, and their exact
         ``np.median`` (mean of the two middle values for an even count) is
         the result.
    For 100 clients this is 100 -> 20 -> 4 -> median of 4. Keys without
    'bias'/'weight' are copied unchanged from ``w[0]``.

    :param w: non-empty list of state dicts with identical keys and shapes
    :return: new state dict; aggregated entries are CPU tensors
    :raises ValueError: if ``w`` is empty
    """
    if not w:
        raise ValueError("approxmed5() needs at least one client state dict")
    group = 5
    w_med = copy.deepcopy(w[0])
    n = len(w)
    for k in w_med.keys():
        if 'bias' not in k and 'weight' not in k:
            continue
        order = np.random.permutation(n)
        # Leading axis indexes clients (in shuffled order); any tensor rank works.
        level = np.stack([w[j][k].detach().cpu().numpy() for j in order])
        while len(level) > group:
            level = np.stack([np.median(level[s:s + group], axis=0)
                              for s in range(0, len(level), group)])
        w_med[k] = torch.as_tensor(np.median(level, axis=0))
    return w_med

#### Global update

In [None]:
def _select_aggregator(aggregator):
    """Map an aggregator name to the function that combines client state dicts."""
    table = {
        'average': average_weights,
        'median': median,
        'approxmed5': approxmed5,
    }
    try:
        return table[aggregator]
    except KeyError:
        raise ValueError(
            f"Error: unrecognized aggregator {aggregator!r}; "
            f"expected one of {sorted(table)}") from None


def _build_global_model(model, dataset):
    """Instantiate the requested architecture ('cnn' or 'mlp') for ``dataset``."""
    if model == 'cnn':
        # Both supported datasets use the same 1x28x28 CNN.
        if dataset in ('mnist', 'fmnist'):
            return CNNMnist()
        raise ValueError(f"Error: no CNN defined for dataset {dataset!r}")
    if model == 'mlp':
        return MLP()
    raise ValueError(f"Error: unrecognized model {model!r}")


def global_update(epochs, local_ep, dataset, model, aggregator, attack,
                  num_malicious=30):
    """Federated training loop; returns the test error after every round.

    Each round, every client (visited in a fresh random order) trains a copy of
    the global model for ``local_ep`` local epochs; the client state dicts are
    combined with ``aggregator`` and loaded into the global model, which is
    then evaluated on the test set.

    :param epochs: number of global rounds
    :param local_ep: local epochs each client trains per round
    :param dataset: 'mnist' or 'fmnist'
    :param model: 'cnn' or 'mlp'
    :param aggregator: 'average', 'median' or 'approxmed5'
    :param attack: behaviour of the first ``num_malicious`` clients visited in
        each round -- 'negative': they send -2x their trained weights;
        'label': they train with every label replaced by class 0; any other
        value (e.g. 'clean'): all clients are honest
    :param num_malicious: number of malicious clients per round (default 30)
    :return: list of test error rates (1 - accuracy), one per round
    :raises ValueError: for an unknown model or aggregator name
    """
    # Resolve the aggregator first so a typo fails before any training work.
    aggregate = _select_aggregator(aggregator)

    # load dataset and user groups
    train_dataset, test_dataset, user_groups = get_dataset(dataset)

    global_model = _build_global_model(model, dataset)
    global_model.to(device)
    global_model.train()
    print(global_model)

    num_users = len(user_groups)
    error_list = []
    for epoch in range(epochs):
        print(f'\n | Global Training Round : {epoch+1} |\n')
        global_model.train()

        local_weights = []
        idxs_users = np.random.choice(range(num_users), num_users, replace=False)
        for rank, idx in enumerate(idxs_users):
            is_malicious = rank < num_malicious
            local_attack = 'label' if (attack == 'label' and is_malicious) else 'clean'

            local_model = LocalUpdate(dataset=train_dataset, idxs=user_groups[idx])
            w = local_model.update_weights(local_ep, copy.deepcopy(global_model), local_attack)

            if attack == 'negative' and is_malicious:
                # Model-poisoning attack: send the trained weights scaled by -2.
                send = copy.deepcopy(w)
                for k in send.keys():
                    send[k] = -2 * w[k]
                w = send
            local_weights.append(copy.deepcopy(w))

        global_model.load_state_dict(aggregate(local_weights))
        global_model.eval()
        # Test inference after this round of training
        test_acc, test_loss = test_inference(copy.deepcopy(global_model), test_dataset)

        print(f' \n Results after {epoch+1} global rounds of training:')
        print("|---- Test Accuracy: {:.2f}%".format(100*test_acc))
        error_list.append(1 - test_acc)
    return error_list

#### Experiment

In [None]:
# Seed the RNGs the pipeline draws from (client order, IID split, approxmed5
# shuffles via NumPy; weight init, dropout and DataLoader shuffling via torch)
# so the run is repeatable. NOTE: some CUDA kernels may still be nondeterministic.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

error_list = global_update(
    epochs=10, local_ep=10, dataset='fmnist', model='mlp',
    aggregator='approxmed5', attack='clean')