In [None]:
import copy
import numpy as np
from torchvision import datasets, transforms
import torch
from torch import nn

from utils.sampling import mnist_noniid, mnist_iid, noniid_dirichlet
import argparse
from datetime import datetime
from models.Update import LocalUpdatePoison, LocalUpdate, LocalUpdateBack, LocalUpdateScaBack
from models.Fed import FedAvg
from models.test import test_img_poison
from models.Nets import LogisticRegression, SimpleCNN, ImprovedSimpleCNN

from attacks import sign_flipping_attack, additive_noise
from aggregations import aggregation
import matplotlib.pyplot as plt
import sys
import os
import random

In [None]:
def setup_seed(seed):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU and CUDA), exports ``PYTHONHASHSEED``, and puts cuDNN into
    deterministic mode with benchmark auto-tuning switched off.

    Args:
        seed: Value passed to each seeding function; its string form is
            written to the ``PYTHONHASHSEED`` environment variable.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    # All of these take the seed as their only argument.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True
    
# Fixed seed applied once, before any data split or model construction below.
# NOTE(review): the `--seed` argument parsed later in this notebook is not
# used here; confirm whether the hard-coded value or args.seed should win.
seed = 42
setup_seed(seed)

In [None]:
# Drop every "--"-prefixed entry from sys.argv, editing the list in place.
# Presumably this keeps launcher-supplied flags away from the argparse call
# made later in the notebook. Only the flag tokens themselves are removed.
args_to_remove = [token for token in sys.argv if token.startswith('--')]
sys.argv[:] = [token for token in sys.argv if not token.startswith('--')]

In [None]:
# Experiment configuration. Every option has a default, so the notebook runs
# without any command-line input; `args` is the namespace read by later cells.
parser = argparse.ArgumentParser()
# federated arguments
parser.add_argument('--epochs', type=int, default=150, help="rounds of training")

parser.add_argument('--attack_ratio', type=float, default=0.3, help="ratio of attackers in sampled users")
parser.add_argument('--poison_ratio', type=float, default=0.2, help="poison ratio for the poisoning attack")
parser.add_argument('--num_users', type=int, default=100, help="number of users: K")
parser.add_argument('--sample_users', type=int, default=100, help="number of users in federated learning C")
parser.add_argument('--attack_mode', type=str, default="poison", choices=["poison", ""], help="implementation of untargeted attack")
# "MKrum" is the default, so it has to be an accepted choice as well; argparse
# does not validate defaults against `choices`, which hid the mismatch.
parser.add_argument('--aggregation', type=str, default="MKrum", choices=["FedAvg", "atten", "Krum", "MKrum", "GeoMed"], help="name of aggregation method")
parser.add_argument('--vae_model', type=str, default="./VAE_data/netg_fashionmnist3005.pth", help="directory of vae_model")

parser.add_argument('--local_ep', type=int, default=5, help="the number of local epochs: E")
parser.add_argument('--local_bs', type=int, default=128, help="local batch size: B")
parser.add_argument('--bs', type=int, default=128, help="test batch size")
parser.add_argument('--lr', type=float, default=0.01, help="learning rate")
parser.add_argument('--momentum', type=float, default=0.5, help="SGD momentum (default: 0.5)")

# other arguments
parser.add_argument('--dataset', type=str, default='mnist', help="name of dataset")
parser.add_argument('--gpu', type=int, default=0, help="GPU ID, -1 for CPU")
parser.add_argument('--verbose', action='store_true', help='verbose print')
parser.add_argument('--seed', type=int, default=42, help='random seed (default: 42)')

parser.add_argument('--isize', type=int, default=32, help='input image size.')
parser.add_argument('--channels', type=int, default=1, help='channels of total data')
parser.add_argument('--nz', type=int, default=100, help='size of the latent z vector')
parser.add_argument('--nc', type=int, default=1, help='input image channels')
parser.add_argument('--ndf', type=int, default=64)
parser.add_argument('--ngpu', type=int, default=4, help='number of GPUs to use')
parser.add_argument('--extralayers', type=int, default=0, help='Number of extra layers on gen and disc')
parser.add_argument('--ngf', type=int, default=64)
parser.add_argument('--w_bce', type=float, default=1, help='alpha to weight bce loss.')
parser.add_argument('--w_rec', type=float, default=50, help='alpha to weight reconstruction loss')
parser.add_argument('--w_enc', type=float, default=1, help='alpha to weight encoder loss')

# parse_known_args tolerates arguments that belong to the notebook launcher
# (for example "-f <connection file>") instead of aborting the kernel with
# "unrecognized arguments"; anything not recognised is kept in unknown_args.
args, unknown_args = parser.parse_known_args()
# Use the requested CUDA device only when CUDA is available and --gpu is not -1.
args.device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu')
print(args)

In [None]:
# load dataset and split users
# trans_mnist = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
# dataset_train = datasets.MNIST('../data/mnist/', train=True, download=True, transform=trans_mnist)
# dataset_test = datasets.MNIST('../data/mnist/', train=False, download=True, transform=trans_mnist)

In [None]:
# Fashion-MNIST train and test splits, downloaded on demand. Both share one
# transform: tensor conversion followed by normalisation with mean 0.2860
# and standard deviation 0.3530.
trans_fashion_mnist = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.2860,), (0.3530,)),
])
dataset_train, dataset_test = (
    datasets.FashionMNIST('../data/fashionmnist/', train=is_train, download=True, transform=trans_fashion_mnist)
    for is_train in (True, False)
)

In [None]:
# sample users
# Partition the training set over the whole user population (args.num_users).
# The training loop draws client ids from range(args.num_users) and looks each
# one up in dict_users, so sizing the partition by args.sample_users would
# leave ids without data whenever sample_users < num_users.
# Presumably mnist_iid returns a mapping from user id to sample indices.
dict_users = mnist_iid(dataset_train, args.num_users)

In [None]:
# Shape of one transformed training image, read from the first sample.
first_image, _first_label = dataset_train[0]
img_size = first_image.shape

In [None]:
# build model
# Input length and class count handed to the classifier
# (784 is presumably the flattened 28x28 image; confirm against img_size).
input_size, num_classes = 784, 10
net_glob = LogisticRegression(input_size, num_classes)
net_glob = net_glob.to(args.device)
# net_glob = ImprovedSimpleCNN().to(args.device)
# Switch the global model to training mode.
net_glob.train()

In [None]:
# copy weights
# Global weights, taken from the freshly built model's state_dict. This value
# is passed to aggregation() in the training loop and replaced every round.
# NOTE(review): no deepcopy is taken here despite the "copy" wording above.
w_glob = net_glob.state_dict()

In [None]:
# training
# Per-round histories filled by the training loop and plotted/saved afterwards.
loss_train_epoch, acc_test_list, poison_acc_list = [], [], []
# Wall-clock duration of each round, in seconds.
round_times = []
# Histories for the DAR / DPR / RR values returned by aggregation().
DAR_list, DPR_list, RR_list = [], [], []

In [None]:
# Federated training: each round selects clients, trains benign and attacking
# clients locally from the current global model, aggregates their weights,
# then evaluates the new global model on the test set.
for iteration in range(args.epochs):
    start_time = datetime.now()
    w_locals, loss_locals = [], []
    # m = max(int(args.frac * args.num_users), 1)
    idxs_users = np.random.choice(range(args.num_users), args.sample_users, replace=False)
    print("Randomly selected {}/{} users for federated learning. {}".format(args.sample_users, args.num_users, datetime.now().strftime("%H:%M:%S")))
    # attack
    attacker_num = int(args.attack_ratio * args.sample_users)
    # Attackers are drawn from the users selected for this round. Drawing from
    # range(args.sample_users) could name users that were never selected when
    # num_users > sample_users, leaving more local models than weights.
    attacker_idxs = np.random.choice(idxs_users, attacker_num, replace=False)
    print("{}/{} are attackers with {} attack".format(attacker_num, args.sample_users, args.attack_mode))

    # Selected users that are not attackers (np.setdiff1d returns them sorted).
    benign_idxs = np.setdiff1d(idxs_users, attacker_idxs)

    # Benign clients train first; their results fill the front of w_locals.
    for idx in benign_idxs:
        local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users[idx])
        w, loss = local.train(net=copy.deepcopy(net_glob).to(args.device))
        w_locals.append(copy.deepcopy(w))
        loss_locals.append(copy.deepcopy(loss))

    # Attackers train afterwards; their results fill the tail of w_locals.
    for idx in attacker_idxs:
        # local = LocalUpdatePoison(args=args, dataset=dataset_train, idxs=dict_users[idx], user_idx=idx)
        local = LocalUpdateBack(args=args, dataset=dataset_train, idxs=dict_users[idx], user_idx=idx)
        w, loss, attack_flag = local.train(net=copy.deepcopy(net_glob).to(args.device))
        w_locals.append(copy.deepcopy(w))
        loss_locals.append(copy.deepcopy(loss))
        # if attack_flag:
        #     attacker_idxs.append(np.where(idxs_users == idx)[0][0]) # indicate the sequence of attacker

    print("{} poison attackers in federated learning.".format(len(attacker_idxs)))

    print(attacker_idxs)

    # update global weights
    # Weights must follow the order in which w_locals / loss_locals were filled
    # (benign clients, then attackers), not the sampling order of idxs_users;
    # otherwise each weight is paired with a different client's model and loss.
    trained_order = np.concatenate([benign_idxs, attacker_idxs])
    user_sizes = np.array([len(dict_users[idx]) for idx in trained_order])
    user_weights = user_sizes / float(sum(user_sizes))
    if args.aggregation == "FedAvg":
        w_glob = FedAvg(w_locals, user_weights)
    else:
        # NOTE(review): attacker_idxs holds user ids, while the attackers'
        # positions inside w_locals are the last `attacker_num` entries.
        # Confirm which of the two aggregation() expects.
        w_glob, DAR, DPR, RR = aggregation(w_locals, user_weights, args, attacker_idxs, w_glob, dataset_test)
        # Record the returned metrics; they only exist on this branch, and the
        # later plot/save cells read these lists.
        DAR_list.append(DAR)
        DPR_list.append(DPR)
        RR_list.append(RR)

    # copy weight to net_glob
    net_glob.load_state_dict(w_glob)

    # print loss
    # Data-size-weighted mean of the clients' local training losses.
    loss_avg = np.sum(np.asarray(loss_locals) * user_weights)

    print('=== Round {:3d}, Average loss {:.6f} ==='.format(iteration+1, loss_avg))
    print("{} users; time {}".format(len(idxs_users), datetime.now().strftime("%H:%M:%S")))
    acc_test, loss_test, acc_per_label, poison_acc = test_img_poison(copy.deepcopy(net_glob), dataset_test, args)
    print("Testing accuracy: {:.6f} loss: {:.6}".format(acc_test, loss_test))
    print("Testing Label Acc: {}".format(acc_per_label))
    print("Poison Acc: {}".format(poison_acc))
    print("======")

    # Round duration covers local training, aggregation and evaluation.
    end_time = datetime.now()
    round_duration = end_time - start_time
    round_times.append(round_duration.total_seconds())
    print("Test end {}".format(datetime.now().strftime("%H:%M:%S")))

    loss_train_epoch.append(loss_avg)
    acc_test_list.append(acc_test)
    poison_acc_list.append(poison_acc)

print("=== End ===")

In [None]:
# Test accuracy of the global model, one point per training round.
fig, ax = plt.subplots()
ax.plot(acc_test_list, label='Test Accuracy')
ax.set(xlabel='Epoch', ylabel='Accuracy', title='Test Accuracy Over Epochs')
ax.legend()
plt.show()

In [None]:
# Poison accuracy reported by test_img_poison, one point per training round.
fig, ax = plt.subplots()
ax.plot(poison_acc_list, label='Poison Accuracy')
ax.set(xlabel='Epoch', ylabel='Poison', title='Poison Over Epochs')
ax.legend()
plt.show()

In [None]:
# Weighted average local training loss, one point per training round.
fig, ax = plt.subplots()
ax.plot(loss_train_epoch, label='Train Loss')
ax.set(xlabel='Epoch', ylabel='Loss', title='Train Loss Over Epochs')
ax.legend()
plt.show()

In [None]:
# DAR history; the curve is empty if DAR_list was never filled.
fig, ax = plt.subplots()
ax.plot(DAR_list, label='DAR')
ax.set(xlabel='Epoch', ylabel='DAR', title='DAR Over Epochs')
ax.legend()
plt.show()

In [None]:
# DPR history; the curve is empty if DPR_list was never filled.
fig, ax = plt.subplots()
ax.plot(DPR_list, label='DPR')
ax.set(xlabel='Epoch', ylabel='DPR', title='DPR Over Epochs')
ax.legend()
plt.show()

In [None]:
# RR history; the curve is empty if RR_list was never filled.
fig, ax = plt.subplots()
ax.plot(RR_list, label='RR')
ax.set(xlabel='Epoch', ylabel='RR', title='RR Over Epochs')
ax.legend()
plt.show()

In [None]:
# Wall-clock duration of each training round. round_times holds seconds
# (timedelta.total_seconds()), so the figure is labelled as time rather than
# with the "Test Accuracy" text copied from the accuracy plot.
plt.plot(round_times, label='Round Time')
plt.xlabel('Epoch')
plt.ylabel('Time (s)')
plt.title('Round Time Over Epochs')
plt.legend()
plt.show()

In [None]:
# Save the per-round test accuracy as a .npy file named after the run settings.
directory = './acc_test_results'
# exist_ok=True creates the folder only when it is missing, without a separate
# existence check that another process could invalidate in between.
os.makedirs(directory, exist_ok=True)

file_name = f"{directory}/fashionmnist_aggregation_{args.aggregation}_attackmode_{args.attack_mode}_attackratio_{args.attack_ratio}.npy"
np.save(file_name, acc_test_list)

In [None]:
# Save the per-round poison accuracy as a .npy file named after the run settings.
directory = './poison_acc_results'
# exist_ok=True creates the folder only when it is missing, without a separate
# existence check that another process could invalidate in between.
os.makedirs(directory, exist_ok=True)

file_name = f"{directory}/fashionmnist_aggregation_{args.aggregation}_attackmode_{args.attack_mode}_attackratio_{args.attack_ratio}.npy"
np.save(file_name, poison_acc_list)

In [None]:
# Save the per-round average training loss as a .npy file named after the run settings.
directory = './loss_train_results'
# exist_ok=True creates the folder only when it is missing, without a separate
# existence check that another process could invalidate in between.
os.makedirs(directory, exist_ok=True)

file_name = f"{directory}/fashionmnist_aggregation_{args.aggregation}_attackmode_{args.attack_mode}_attackratio_{args.attack_ratio}.npy"
np.save(file_name, loss_train_epoch)

In [None]:
# Save the DAR history as a .npy file named after the run settings.
directory = './DAR_results'
# exist_ok=True creates the folder only when it is missing, without a separate
# existence check that another process could invalidate in between.
os.makedirs(directory, exist_ok=True)

file_name = f"{directory}/fashionmnist_aggregation_{args.aggregation}_attackmode_{args.attack_mode}_attackratio_{args.attack_ratio}.npy"
np.save(file_name, DAR_list)

In [None]:
# Save the DPR history as a .npy file named after the run settings.
directory = './DPR_results'
# exist_ok=True creates the folder only when it is missing, without a separate
# existence check that another process could invalidate in between.
os.makedirs(directory, exist_ok=True)

file_name = f"{directory}/fashionmnist_aggregation_{args.aggregation}_attackmode_{args.attack_mode}_attackratio_{args.attack_ratio}.npy"
np.save(file_name, DPR_list)

In [None]:
# Save the RR history as a .npy file named after the run settings.
directory = './RR_results'
# exist_ok=True creates the folder only when it is missing, without a separate
# existence check that another process could invalidate in between.
os.makedirs(directory, exist_ok=True)

file_name = f"{directory}/fashionmnist_aggregation_{args.aggregation}_attackmode_{args.attack_mode}_attackratio_{args.attack_ratio}.npy"
np.save(file_name, RR_list)

In [None]:
# Save the per-round wall-clock durations (seconds) as a .npy file named after the run settings.
directory = './round_times_results'
# exist_ok=True creates the folder only when it is missing, without a separate
# existence check that another process could invalidate in between.
os.makedirs(directory, exist_ok=True)

file_name = f"{directory}/fashionmnist_aggregation_{args.aggregation}_attackmode_{args.attack_mode}_attackratio_{args.attack_ratio}.npy"
np.save(file_name, round_times)