# GAN Attack on FashionMNIST and MNIST datasets

In [None]:
# import libraries
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import copy

In [None]:
# define the device
def try_gpu():
    """
    Pick the compute device: the first CUDA GPU when one is available,
    otherwise the CPU.

    Returns:
        torch.device: ``cuda:0`` or ``cpu``.
    """
    name = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    return torch.device(name)

In [None]:
# Global settings shared by the cells below.
# Number of fake images the attacker generates per GAN step and appends to
# its training data in each attack round (see Client.attack).
generated_picture_number = 256
# Length of the noise vector fed to the Generator.
noise_size = 100
# Device returned by try_gpu(); models and tensors are moved onto it.
device = try_gpu()

In [None]:
# my model initialization method
def myInit(model):
    """
    Re-initialise selected layers of ``model`` in place.

    * ``Conv2d``: the weight is left untouched; the bias (if any) is zeroed.
    * ``BatchNorm2d``: weight set to 1, bias set to 0.
    * ``Linear``: Xavier-normal weight; the bias (if any) is zeroed.

    Every other module type is skipped.
    """
    def zero_bias(module):
        # Layers built with bias=False expose ``bias`` as None.
        if module.bias is not None:
            torch.nn.init.constant_(module.bias, val=0.0)

    for module in model.modules():
        if isinstance(module, torch.nn.Conv2d):
            zero_bias(module)
            continue
        if isinstance(module, torch.nn.BatchNorm2d):
            torch.nn.init.constant_(module.weight, val=1.0)
            torch.nn.init.constant_(module.bias, val=0.0)
            continue
        if isinstance(module, torch.nn.Linear):
            torch.nn.init.xavier_normal_(module.weight)
            zero_bias(module)

In [None]:
# model architecture of GAN
class Generator(nn.Module):
    """
    Maps a noise vector of length ``noise_size`` to a single-channel 28x28
    image whose values lie in [-1, 1] (final ``Tanh``).

    A linear layer produces a 256x7x7 feature map which three transposed
    convolutions turn into 128x7x7, 64x14x14 and finally 1x28x28.
    """

    def __init__(self, noise_size):
        super().__init__()
        seed_features = 7 * 7 * 256

        # Project the noise onto the initial 256x7x7 feature volume.
        self.linear1 = nn.Sequential(
            nn.Linear(noise_size, seed_features, bias=False),
            nn.BatchNorm1d(seed_features),
            nn.ReLU(True),
        )

        # 256x7x7 -> 128x7x7 (stride 1 keeps the spatial size).
        self.conv1 = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
        )

        # 128x7x7 -> 64x14x14.
        self.conv2 = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
        )

        # 64x14x14 -> 1x28x28, squashed into [-1, 1].
        self.conv3 = nn.Sequential(
            nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1, bias=False),
            nn.Tanh(),
        )

    def forward(self, x):
        """Generate images from a ``(batch, noise_size)`` noise tensor."""
        features = self.linear1(x).view(-1, 256, 7, 7)
        return self.conv3(self.conv2(self.conv1(features)))


class Discriminator(nn.Module):
    """
    Classifier over single-channel 28x28 images with 11 outputs.

    Two conv/ReLU/max-pool stages reduce the input to 64x4x4 (1024 values),
    followed by a three-layer MLP. The output is log-probabilities
    (``LogSoftmax``) over 11 classes; elsewhere in this file class 10 is the
    label used for generated images.
    """

    def __init__(self):
        super().__init__()
        # 1x28x28 -> 32x28x28 -> (pool 3) 32x9x9
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3),
        )
        # 32x9x9 -> 64x9x9 -> (pool 2) 64x4x4
        self.conv2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        # 64 * 4 * 4 = 1024 flattened features -> 11 log-probabilities
        self.linear = nn.Sequential(
            nn.Flatten(),
            nn.Linear(1024, 256),
            nn.ReLU(),
            nn.Linear(256, 200),
            nn.ReLU(),
            nn.Linear(200, 11),
            nn.LogSoftmax(dim=-1),
        )

    def forward(self, x):
        """Return ``(batch, 11)`` log-probabilities for a batch of images."""
        return self.linear(self.conv2(self.conv1(x)))

# Prepare the dataset and the dataloader

In [None]:
# Load the dataset.
# Every image is converted to a tensor and normalised with a fixed mean/std.
# NOTE(review): 0.1307 / 0.3081 are presumably the MNIST statistics; confirm
# they are also what is wanted for FashionMNIST.
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.1307,), (0.3081,))])
# To run on MNIST instead, swap the dataset class, e.g.:
#test_data = datasets.MNIST(root='../data', train=False, transform=transform, download=True)
#train_data = datasets.FashionMNIST(root='../data', train=True, download=True, transform=transform)
# Held-out test split, evaluated in batches of 1000 by Server.test().
test_data = datasets.FashionMNIST('./data', train=False, download=True, transform=transform)
test_loader = DataLoader(test_data, batch_size=1000)


# define the dataloader for each client
def get_train_loader(index):
    """
    Build a shuffled training DataLoader restricted to the given classes.

    Args:
        index: non-empty sequence of class labels to keep.

    Returns:
        DataLoader over the FashionMNIST training samples whose label is in
        ``index``, in batches of 256.
    """
    train_data = datasets.FashionMNIST('./data', train=True, download=True, transform=transform)
    labels = train_data.targets
    # OR together one boolean mask per requested class.
    keep = labels == index[0]
    for label in index[1:]:
        keep = keep | (labels == label)
    train_data.data = train_data.data[keep]
    train_data.targets = labels[keep]
    return DataLoader(train_data, batch_size=256, shuffle=True)


# Two-client setup: client 0 receives classes 0-4, client 1 receives classes 5-9.
train_loaders = [get_train_loader([j for j in range(i, i + 5)]) for i in range(0, 10, 5)]
# Ten-client alternative: one class per client.
#train_loaders = [get_train_loader([i])for i in range(10)]
dataset = datasets.FashionMNIST('./data', train=True, download=True, transform=transform)
# Draw 640 random training samples (without replacement).
indices = torch.randperm(len(dataset))[:640]
# Those samples form the server's warm-up set, served in batches of 64.
warmup_dataset = torch.utils.data.Subset(dataset, indices)
warmup_dataloader = DataLoader(warmup_dataset, batch_size=64, shuffle=True)

### Use a Dirichlet distribution to split the dataset across clients (non-IID)

In [None]:
# MNIST training split; its labels (train_data.targets) drive the Dirichlet
# partition computed below.
train_data = datasets.MNIST('./data', train=True, download=True, transform=transform)


def dirichlet_split_noniid(train_labels, alpha, n_clients):
    """
    Partition sample indices across clients with a per-class Dirichlet prior.

    For every class, a vector of ``n_clients`` proportions is drawn from
    ``Dirichlet(alpha, ..., alpha)`` (via ``np.random``) and the indices of
    that class are cut into consecutive chunks of those proportions, one
    chunk per client.

    Args:
        train_labels: 1-D integer class labels (list, NumPy array or torch
            tensor). Classes are assumed to be ``0 .. max(label)``.
        alpha: Dirichlet concentration parameter.
        n_clients: number of clients to split across.

    Returns:
        list[np.ndarray]: ``n_clients`` arrays of sample indices; together
        they contain every index exactly once.

    Raises:
        ValueError: if ``train_labels`` is empty (raised by ``max``).
    """
    # Normalise the input once so the rest is plain NumPy; the original relied
    # on implicit tensor -> int coercion and failed on plain lists.
    if isinstance(train_labels, torch.Tensor):
        labels = train_labels.detach().cpu().numpy()
    else:
        labels = np.asarray(train_labels)
    n_classes = int(labels.max()) + 1

    # Row y holds the share of class y that each client receives.
    label_distribution = np.random.dirichlet([alpha] * n_clients, n_classes)

    client_idcs = [[] for _ in range(n_clients)]
    for y, fracs in enumerate(label_distribution):
        class_idcs = np.flatnonzero(labels == y)
        # Cumulative shares -> cut positions inside this class's index list.
        cut_points = (np.cumsum(fracs)[:-1] * len(class_idcs)).astype(int)
        for client, part in enumerate(np.split(class_idcs, cut_points)):
            client_idcs[client].append(part)

    return [np.concatenate(parts) for parts in client_idcs]


# Per-client sample indices: Dirichlet split (alpha = 1.0) over 10 clients.
client_idcs = dirichlet_split_noniid(train_data.targets, 1.0, 10)
def get_train_loader(index):
    """
    Build the shuffled training DataLoader of one client.

    Args:
        index: client number; selects the entry of ``client_idcs`` holding
            that client's sample indices.

    Returns:
        DataLoader over the selected MNIST training samples, batch size 256.
    """
    mnist_train = datasets.MNIST('./data', train=True, download=True, transform=transform)
    selected = client_idcs[index]
    # Shrink the dataset in place to this client's samples.
    mnist_train.data = mnist_train.data[selected]
    mnist_train.targets = mnist_train.targets[selected]
    return DataLoader(mnist_train, batch_size=256, shuffle=True)

# One DataLoader per client (10 clients). This rebinds the train_loaders
# defined in the earlier cell.
# NOTE(review): these loaders use MNIST while test_loader and the warm-up set
# above use FashionMNIST - confirm which cells are meant to run together.
train_loaders = [get_train_loader(i)for i in range(10)]

In [None]:
# The pseudocode can be found in the paper.
# Hyper-parameters that Client.upload_gradients() passes to differential_privacy().
# Fraction of the flattened gradient entries that may be selected for noising.
theta = 0.9
# Clipping bound: every gradient entry is limited to [-gamma, gamma].
gamma = 0.003
# Selection threshold compared against the noisy absolute gradient.
tau = 0.0001
# Privacy budget per selected entry; the total budget is budget_per_para * c.
budget_per_para = 10


def gradients_flatten(gradients):
    """
    Concatenate a sequence of gradient tensors into one 1-D tensor.

    Args:
        gradients: iterable of tensors (one per model parameter).

    Returns:
        tuple: ``(flattened, gradients_shape)`` where ``flattened`` is a new
        1-D tensor on ``device`` holding all entries in order, and
        ``gradients_shape`` is a list of ``[shape, number_of_elements]``
        pairs, one per input tensor, sufficient to undo the flattening.
    """
    gradients_shape = []
    # Start from an empty tensor on ``device`` so an empty input still yields
    # an empty result on the expected device.
    pieces = [torch.empty(0).to(device)]
    for gradient in gradients:
        gradients_shape.append([gradient.shape, gradient.nelement()])
        pieces.append(gradient.data.flatten())
    # Concatenate once: calling torch.cat inside the loop re-copies the
    # accumulated buffer on every iteration.
    flattened = torch.cat(pieces)
    return flattened, gradients_shape


def sigma(x, c, delta_f):
    """
    Noise scale for a privacy budget ``x``: ``2 * c * delta_f / x``.

    Args:
        x: privacy budget (must be non-zero).
        c: number of gradient entries that may be selected.
        delta_f: sensitivity of a single entry.
    """
    numerator = 2 * c * delta_f
    return numerator / x


def differential_privacy(gradients, budget_per_para, gamma, theta, tau):
    """
    Clip the gradients and add Laplace noise to a thresholded selection.

    Steps (random numbers come from ``np.random``):

    1. Flatten all gradients and clip every entry to ``[-gamma, gamma]``.
    2. Let ``c = floor(theta * n)`` be the selection quota and split the
       total budget ``budget_per_para * c`` into 8/9 for selection and 2/9
       for release.
    3. Select the entries whose noisy magnitude reaches the noisy threshold
       ``tau``; if more than ``c`` qualify, keep a uniformly random ``c``.
    4. Add release noise to the selected entries.

    Entries that are not selected are returned clipped but without noise.

    Args:
        gradients: iterable of gradient tensors (not modified).
        budget_per_para: privacy budget per selected entry.
        gamma: clipping bound (the sensitivity used is ``2 * gamma``).
        theta: fraction of entries that may be selected.
        tau: selection threshold.

    Returns:
        list[torch.Tensor]: tensors with the same shapes as ``gradients``.
    """
    grad_flattened, grad_shapes = gradients_flatten(gradients)
    parameter_number = grad_flattened.shape[0]
    # Selection quota. Must be a plain int: it is used as a slice bound below.
    c = int(np.floor(theta * parameter_number))

    # gradients_flatten returns a fresh copy, so clipping in place is safe.
    grad_flattened.clamp_(-gamma, gamma)

    def laplace(scale):
        # One float32 Laplace sample per gradient entry, on the gradients' device.
        sample = np.random.laplace(0, scale, parameter_number)
        return torch.as_tensor(sample, dtype=torch.float32, device=grad_flattened.device)

    # With c == 0 nothing may be selected (and the budget split below would
    # divide by zero), so only the clipping applies.
    if c > 0:
        # split budget
        epsilon = budget_per_para * c
        epsilon_1 = 8 / 9 * epsilon
        epsilon_2 = 2 / 9 * epsilon
        sigma_1 = sigma(epsilon_1, c, 2 * gamma)
        sigma_2 = sigma(epsilon_2, c, 2 * gamma)

        tau_with_noise = laplace(sigma_1) + tau
        R_w = laplace(2 * sigma_1)
        index = torch.where(torch.absolute(grad_flattened) + R_w >= tau_with_noise)[0]
        if index.numel() > c:
            # Keep a uniformly random subset of size c. np.random.shuffle works
            # in place and returns None, so a permutation is drawn instead.
            keep = np.random.permutation(index.numel())[:c]
            index = index[torch.as_tensor(keep, dtype=torch.long, device=index.device)]

        R_w2 = laplace(sigma_2)
        grad_flattened[index] = (R_w2 + grad_flattened)[index]

    # Cut the flat vector back into the original per-parameter shapes.
    restored = []
    offset = 0
    for shape, numel in grad_shapes:
        restored.append(grad_flattened[offset:offset + numel].reshape(shape))
        offset += numel
    return restored

## Loss

In [None]:
# Loss used for GAN training. NLLLoss expects log-probabilities, which the
# Discriminator produces through its final LogSoftmax layer.
cross_entropy = nn.NLLLoss()


def generator_loss(fake_output, attack_label):
    """
    Loss that rewards the generator when its images are classified as
    ``attack_label``.

    Args:
        fake_output: ``(batch, classes)`` log-probabilities for generated images.
        attack_label: class index the generator should imitate.
    """
    batch_size = fake_output.shape[0]
    # Every generated image gets the attacked class as its target.
    targets = (torch.zeros(batch_size) + attack_label).long()
    return cross_entropy(fake_output, targets.to(device))


def discriminator_loss(fake_output):
    """
    Loss that trains the discriminator to assign generated images to the
    extra "fake" class (index 10).

    Args:
        fake_output: ``(batch, classes)`` log-probabilities for generated images.
    """
    batch_size = fake_output.shape[0]
    targets = (torch.zeros(batch_size) + 10).long()
    return cross_entropy(fake_output, targets.to(device))

## Client model

In [None]:
class Client:
    """
    Federated-learning participant that trains a local ``Discriminator``.

    A client downloads parameters from the server, trains on its own
    DataLoader and uploads the resulting parameter change. A client flagged
    ``malicious`` additionally trains a generator against its local model
    before each training round and mixes the generated images, labelled as
    class 10, into its training data.
    """

    def __init__(self, id):
        self.model = Discriminator()
        self.optimizer = torch.optim.SGD(
            self.model.parameters(), lr=0.001, weight_decay=1e-7
        )
        self.malicious = False
        self.device = try_gpu()
        self.model.to(self.device)
        self.criterion = nn.CrossEntropyLoss()
        self.id = id
        # Set from outside for a malicious client.
        self.generator = None
        self.attack_round = 0
        # Class the generator is trained to be classified as.
        self.attack_label = 7
        # Divisor applied to the parameter change before it is uploaded.
        self.lr = 1
        self.dataloader = train_loaders[id]

        myInit(self.model)
        # Parameter change produced by the last federated_train() call.
        self.gradients = []
        self.prev_params = []
        # Aggregation weight; filled in by federated_train().
        self.weights = None
        self.gen_lr = 0.0005
        self.dis_lr = 0.0002
        self.clip = 5.0
        # When True, uploads go through differential_privacy().
        self.diff = False

    # download the parameters from the server
    def download(self, model_parameters, theta_d=1):
        """
        Overwrite local parameters with the server's.

        Args:
            model_parameters: iterable of server parameters, in the same
                order as ``self.model.parameters()``.
            theta_d: download rate; each entry is replaced with probability
                ``theta_d`` (1 replaces everything).
        """
        for params, server_params in zip(self.model.parameters(), model_parameters):
            choice_mask = torch.rand_like(params.data) < theta_d
            params.data[choice_mask] = server_params.data[choice_mask]

    # upload the parameters to the server
    def upload_params(self):
        """Return the local model's ``state_dict``."""
        return self.model.state_dict()

    # upload the gradients to the server
    def upload_gradients(self, theta_u=1):
        """
        Return the stored gradients after masking (and optional DP noise).

        Args:
            theta_u: upload rate; each entry is kept with probability
                ``theta_u`` and zeroed otherwise. Masking is done in place.
        """
        for gradient in self.gradients:
            gradient[torch.rand_like(gradient) > theta_u] = 0
        if self.diff:
            self.gradients = differential_privacy(self.gradients, budget_per_para, gamma, theta, tau)
        return self.gradients

    # train process for client
    def federated_train(self, round=1):
        """
        Train locally and record the parameter change as ``self.gradients``.

        Args:
            round: number of passes over the training data.

        Side effects:
            Updates the model, ``self.prev_params``, ``self.gradients``
            (``(old - new) / self.lr`` per parameter) and ``self.weights``.
        """
        self.model.to(self.device)
        # Plain-tensor snapshot of the parameters before any local update.
        self.prev_params = [param.detach().clone() for param in self.model.parameters()]

        # A malicious client first runs the GAN attack, which also updates the
        # model and returns a DataLoader extended with generated images.
        dataloader = self.attack() if self.malicious else self.dataloader

        for _ in range(round):
            # Reset per pass so the printed accuracy describes this pass only
            # (as Server.warm_up does per epoch).
            train_accs = []
            for x_batch, y_batch in dataloader:
                x_batch, y_batch = x_batch.to(self.device), y_batch.to(self.device)
                self.optimizer.zero_grad()
                y_pred = self.model(x_batch)
                acc = (y_pred.max(dim=-1)[1] == y_batch).sum().item() / y_batch.shape[0]
                loss = self.criterion(y_pred, y_batch)
                loss.backward()
                self.optimizer.step()
                train_accs.append(acc)
            train_acc = np.mean(train_accs)
            print('ID {:.0f} Train acc: {:.3f}'.format(self.id, train_acc))

        # The uploaded "gradient" is just a parameter difference; compute it
        # without autograd so no graph is dragged along to the server.
        with torch.no_grad():
            self.gradients = [
                (prev_param - param) / self.lr
                for param, prev_param in zip(self.model.parameters(), self.prev_params)
            ]
        # Share of the 60000-sample training set held by this client.
        self.weights = len(self.dataloader.dataset) / 60000

    # attack process for malicious clients(GAN training process)
    def attack(self):
        """
        Run 200 alternating generator/discriminator updates, plot the
        losses, and return a copy of the client's DataLoader extended with
        ``generated_picture_number`` generated images labelled 10.
        """
        print("Attack round ", self.attack_round, "begins.")
        self.attack_round = self.attack_round + 1
        gen_losses = []
        disc_losses = []

        # One-off learning-rate decay at the 100th attack round.
        if self.attack_round == 100:
            self.dis_lr /= 10
            self.gen_lr /= 10

        gen_optimizer = torch.optim.SGD(
            self.generator.parameters(), lr=self.gen_lr
        )
        dis_optimizer = torch.optim.SGD(
            self.model.parameters(), lr=self.dis_lr, weight_decay=1e-7
        )
        self.generator.to(self.device)
        self.model.to(self.device)
        for i in range(200):
            if i % 100 == 0:
                print(i)

            # update_G: push generated images towards the attacked class.
            self.generator.train()
            self.model.eval()
            gen_optimizer.zero_grad()
            noise = torch.rand([generated_picture_number, noise_size]).to(self.device)
            generated_images = self.generator(noise)
            fake_output = self.model(generated_images)
            gen_loss = generator_loss(fake_output, self.attack_label)
            gen_loss.backward()
            gen_optimizer.step()
            gen_losses.append(float(gen_loss))

            # update_D: teach the local model to label those images as fake.
            self.model.train()
            self.generator.eval()
            # Also clears what the generator step accumulated on the model.
            dis_optimizer.zero_grad()
            fake_output = self.model(generated_images.detach())
            disc_loss = discriminator_loss(fake_output)
            disc_loss.backward()
            dis_optimizer.step()
            disc_losses.append(float(disc_loss))

        noise = torch.rand([generated_picture_number, noise_size]).to(self.device)
        self.generator.eval()
        with torch.no_grad():
            generated_images = self.generator(noise)
        generated_images = torch.squeeze(generated_images, 1).to('cpu')
        plt.plot(gen_losses, label='Gen_loss')
        plt.plot(disc_losses, label='Disc_loss')
        plt.show()

        real_data = self.dataloader.dataset.data
        real_targets = self.dataloader.dataset.targets
        # The generator emits values in [-1, 1]; convert them to the pixel
        # scale and dtype of the stored data before concatenating, using the
        # same mapping as generate_image().
        # NOTE(review): assumes dataset.data holds raw 0-255 pixels, as the
        # torchvision MNIST-style datasets do - confirm.
        fake_pixels = (generated_images * 127.5 + 127.5).round().clamp(0, 255).to(real_data.dtype)
        fake_targets = torch.full((generated_picture_number,), 10, dtype=real_targets.dtype)

        # Work on a copy so the client's own DataLoader stays unchanged.
        dataloader = copy.deepcopy(self.dataloader)
        dataloader.dataset.data = torch.cat((real_data, fake_pixels), 0)
        dataloader.dataset.targets = torch.cat((real_targets, fake_targets), 0)
        self.generate_image()
        return dataloader

    def generate_image(self):
        """
        Show a 6x6 grid of freshly generated images.

        Raises:
            AssertionError: if the client is not malicious.
        """
        if not self.malicious:
            # Raised explicitly so the check survives ``python -O``.
            raise AssertionError("I am innocent and would not generate any images")
        noise = torch.rand([36, noise_size]).to(self.device)
        with torch.no_grad():
            generated_Image = self.generator(noise).to("cpu").numpy()
        plt.figure(figsize=(6, 6))
        for i in range(36):
            plt.subplot(6, 6, i + 1)
            # Map the [-1, 1] output to the 0-255 grey scale for display.
            plt.imshow(generated_Image[i, 0, :, :] * 127.5 + 127.5, cmap='gray')
            plt.axis('off')
        plt.show()

## Server model

In [None]:
class Server:
    """
    Federated-learning coordinator holding the global ``Discriminator``.

    The server can warm up the global model on its own data, distribute
    parameters to the clients, trigger their local training, aggregate
    their uploads and evaluate the global model on ``test_loader``.
    """

    def __init__(self, clients):
        self.clients = clients
        self.model = Discriminator()
        self.model.to(device)
        # Must be assigned before warm_up() is called.
        self.warm_up_dataloader = None
        self.optimizer = torch.optim.SGD(
            self.model.parameters(), lr=0.01, weight_decay=1e-7, momentum=0.9
        )
        self.criterion = nn.CrossEntropyLoss()
        self.warm_up_epochs = 15
        myInit(self.model)
        # Step size applied to the aggregated client updates.
        self.lr = 1
        # Upload / download rates forwarded to the clients.
        self.theta_u = 1
        self.theta_d = 1
        # Number of local passes each client runs per federated round.
        self.local_round = 1

    # warm_up process for server
    def warm_up(self):
        """Train the global model on ``warm_up_dataloader`` for ``warm_up_epochs`` epochs."""
        self.model.to(device)
        for _ in range(self.warm_up_epochs):
            test_accs = []
            for x_batch, y_batch in self.warm_up_dataloader:
                x_batch, y_batch = x_batch.to(device), y_batch.to(device)
                self.optimizer.zero_grad()
                y_pred = self.model(x_batch)
                acc = (y_pred.max(dim=-1)[1] == y_batch).sum().item() / y_batch.shape[0]
                loss = self.criterion(y_pred, y_batch)
                loss.backward()
                self.optimizer.step()
                test_accs.append(acc)
            test_acc = np.mean(test_accs)
            print('Test acc: {:.3f}'.format(test_acc))

    def updata_from_params(self, weight):
        """
        Replace the global model with the weighted sum of client models.

        Args:
            weight: one weight per client, in ``self.clients`` order.
        """
        uploaded_parameters = [c.upload_params() for c in self.clients]
        averaged_params = {}
        for k in uploaded_parameters[0].keys():
            for i, local_model_params in enumerate(uploaded_parameters):
                contribution = local_model_params[k] * weight[i]
                if i == 0:
                    averaged_params[k] = contribution
                else:
                    averaged_params[k] += contribution

        self.model.load_state_dict(averaged_params)

    # update the model from gradients
    def updata_from_gradients(self, weight=None):
        """
        Apply the weighted sum of the clients' uploaded gradients.

        Args:
            weight: one weight per client; defaults to a uniform average.

        Side effects:
            Stores the uploads in ``self.uploaded_gradients`` and subtracts
            ``self.lr * aggregated`` from the global parameters.
        """
        if weight is None:
            weight = np.ones(len(self.clients)) / len(self.clients)

        self.uploaded_gradients = [c.upload_gradients(self.theta_u) for c in self.clients]
        aggregated_gradients = [
            torch.zeros_like(params) for params in self.model.parameters()
        ]
        len_gradients = len(aggregated_gradients)

        # Pure tensor arithmetic: keep autograd out of the aggregation.
        with torch.no_grad():
            for i, gradients in enumerate(self.uploaded_gradients):
                client_weight = float(weight[i])
                for gradient_id in range(len_gradients):
                    aggregated_gradients[gradient_id] += client_weight * gradients[gradient_id]

            for params, grads in zip(self.model.parameters(), aggregated_gradients):
                params.data -= self.lr * grads

    # distribute the model to the clients in the client list
    def distribtue(self):
        """Send the global parameters to every client."""
        for client in self.clients:
            client.download(self.model.parameters(), self.theta_d)

    def train(self):
        """Run ``local_round`` passes of local training on every client."""
        for client in self.clients:
            client.federated_train(self.local_round)

    def dataloader(self):
        """Re-assign ``train_loaders[i]`` to the i-th client of this server."""
        for i, client in enumerate(self.clients):
            client.dataloader = train_loaders[i]

    # test global model's accuracy in test set
    def test(self):
        """Print the mean per-batch accuracy of the global model on ``test_loader``."""
        test_accs = []
        #self.model.eval()
        # Evaluation only: no gradients are needed.
        with torch.no_grad():
            for X, y in test_loader:
                # Copy the data to device.
                X, y = X.to(device), y.to(device)
                y_test = self.model(X)
                acc = (y_test.max(dim=-1)[1] == y).sum().item() / y.shape[0]
                test_accs.append(acc)
        test_acc = np.mean(test_accs)
        print("######################################")
        print('Test acc: {:.3f}'.format(test_acc))
        print("######################################")

In [None]:
# init clients
clients = [Client(i) for i in range(2)]
# set client 0 as the Adversary
clients[0].malicious = True
# The generator's input width must match the noise drawn in Client.attack(),
# which uses the global noise_size.
clients[0].generator = Generator(noise_size)
# set attack label
clients[0].attack_label = 3
# init server
server = Server(clients)
server.warm_up_dataloader = warmup_dataloader
# warm up method
server.warm_up()
# federated learning
for i in range(200):
    # distribute model to clients
    server.distribtue()
    # training process
    server.train()
    # server model update, weighted by each client's share of the data
    server.updata_from_gradients([client.weights for client in server.clients])

    # alternative: aggregate the uploaded parameters instead of the gradients
    # server.updata_from_params([client.weights for client in server.clients])

    # test server model performance on the test set
    server.test()