## Plan

- [x] FL base framework
- [x] implement non-IID sampling function following a Dirichlet distribution

#### Attack simulation
- [ ] untargeted model poisoning
    - Local model poisoning attacks to Byzantine-robust federated learning (Fang attack) <br>
    https://github.com/vrt1shjwlkr/NDSS21-Model-Poisoning <br>
    Krum attack, trimmed-mean/median attack
- [ ] targeted model poisoning
    - Analyzing federated learning through an adversarial lens <br>
    https://github.com/inspire-group/ModelPoisoning
- [ ] data poisoning
    - DBA: Distributed backdoor attacks against federated learning <br>
    https://github.com/AI-secure/DBA
    
#### Defense baseline
- [x] FedAvg
- [x] Krum
- [x] Multi-Krum
- [x] Bulyan
- [x] Coordinate median
- [ ] FoolsGold
- [ ] FLARE

#### Proposed method
- [ ] use a hypersphere uniformity loss as the training loss
- [ ] extract penultimate-layer representations (PLRs)
- [ ] apply RBF hypersphere CKA

In [1]:
from IPython.display import display, HTML
# Widen the notebook's content area to 90% of the browser window.
_container_css = "<style>.container { width:90% !important; }</style>"
display(HTML(_container_css))

In [None]:
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6


import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
import copy
import time
import pickle
import numpy as np
from tqdm import tqdm
from collections import OrderedDict

import torch
from tensorboardX import SummaryWriter

from options import args_parser
from update import LocalUpdate, test_inference, mal_inference
from models import MLP, CNNMnist, CNNFashion_Mnist, CNNCifar
from utils import get_dataset, get_mal_dataset, exp_details, flatten, construct_ordered_dict
from aggregate import fedavg, multi_krum, krum, coomed, bulyan, tr_mean, fed_align, fed_cka
from attacks import get_malicious_updates_untargeted_mkrum, get_malicious_updates_untargeted_med, get_malicious_updates_targeted
from cka import linear_CKA, kernel_CKA
# python src/federated_main.py --model=cnn --dataset=cifar --gpu=0 --iid=1 --epochs=10

import warnings
# warnings.filterwarnings('ignore')

In [3]:
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6


import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
import copy
import time
import pickle
import numpy as np
from tqdm import tqdm
from collections import OrderedDict

import torch
from tensorboardX import SummaryWriter

from options import args_parser
from update import LocalUpdate, test_inference, mal_inference
from models import MLP, CNNMnist, CNNFashion_Mnist, CNNCifar
from utils import get_dataset, get_mal_dataset, exp_details, flatten
from aggregate import fedavg, krum, multi_krum, coomed, bulyan, flare
from attacks import get_malicious_updates_untargeted_mkrum, get_malicious_updates_untargeted_med

# python src/federated_main.py --model=cnn --dataset=cifar --gpu=0 --iid=1 --epochs=10



ImportError: cannot import name 'mkrum' from 'aggregate' (C:\Users\user\Documents\FL\src\aggregate.py)

# ATTACK FREE

In [None]:
class Args:
    """Hard-coded experiment configuration used instead of ``args_parser()``.

    All settings are class attributes, so every ``Args()`` instance sees the
    same values; edit the attributes below to change the experiment.
    """

    # ---- federated parameters (default values are set) ----
    epochs: int = 20                 # number of global communication rounds
    num_users: int = 10              # total number of clients
    frac: float = 1                  # fraction of clients sampled per round
    local_ep: int = 5                # local epochs per client per round
    local_bs: int = 128              # local batch size
    lr: float = 0.001                # learning rate
    momentum: float = 0.9
    aggregation: str = 'coomed'      # one of: fedavg, krum, mkrum, coomed, bulyan, flare

    # ---- model arguments ----
    model: str = 'cnn'
    kernel_num: int = 9              # number of kernels of each size
    kernel_sizes: str = '3,4,5'      # comma-separated convolution kernel sizes
    num_channels: int = 1            # image channels
    norm: str = 'batch_norm'         # batch_norm, layer_norm, None
    num_filters: int = 32            # conv filters: 32 for mini-imagenet, 64 for Omniglot
    max_pool: str = 'True'           # max pooling instead of strided convolutions (string flag)

    # ---- other arguments ----
    dataset: str = "fmnist"
    num_classes: int = 10
    gpu: int = 0
    optimizer: str = 'adam'
    iid: int = 1                     # 1 for IID splits, 0 for non-IID
    alpha: float = 0.2               # non-IID concentration: small -> non-IID, 1 -> closer to IID
    unequal: int = 0                 # 1 to use unequal data splits in the non-IID setting
    stopping_rounds: int = 10        # rounds used for early stopping
    verbose: int = 0
    seed: int = 1



In [None]:
from sklearn import preprocessing
from cka import linear_CKA, kernel_CKA
from sklearn.cluster import KMeans
from collections import Counter

# CKA-based filtering of (possibly poisoned) client updates.
# For every client, rebuild its model from the flat update, compare one layer
# against the same layer of the global model with kernel CKA, split the CKA
# scores into two clusters with k-means, zero out the minority cluster and
# re-weight the rest by 1/scale before averaging.
# NOTE(review): relies on globals from earlier cells: `malicious_grads`,
# `local_weights`, `global_weights`, `global_model` and `scale`. `scale` is not
# defined anywhere visible in this notebook — presumably one positive factor
# per client; confirm before running.

flat_weights = malicious_grads
# State-dict key four positions from the end (treated as the "second last layer").
second_last_layer = list(local_weights[0].keys())[-4]
glob_plr = global_weights[second_last_layer].detach().cpu()
structured_local_weights = [construct_ordered_dict(global_model, client_flat) for client_flat in malicious_grads]
plrs = [each_local[second_last_layer] for each_local in structured_local_weights]

kernel_cka_vals = [kernel_CKA(glob_plr, plr.detach().cpu()) for plr in plrs]

# n_init is pinned to the pre-1.4 scikit-learn default so clustering results
# do not silently change across scikit-learn versions.
kmeans = KMeans(n_clusters=2, random_state=0, n_init=10).fit(np.array(kernel_cka_vals).reshape(-1, 1))
labels = kmeans.labels_
counter = Counter(labels)
print(counter)

# The larger cluster is treated as benign; on a tie cluster 0 is the majority.
minor = 1 if counter[0] >= counter[1] else 0

# np.where returns a *tuple* of index arrays; testing `i in` that tuple compares
# i against whole arrays, so membership must be checked on plain ints instead.
suspects = set(np.where(labels == minor)[0].tolist())

for i in range(len(scale)):
    scale[i] = 0 if i in suspects else 1 / scale[i]

print(scale)

agg_device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
selected_parameters = [scale[i] * flat_weights[i].cpu().detach().numpy() for i in range(len(plrs))]
selected_parameters = torch.tensor(np.array(selected_parameters)).to(agg_device)
# Mean of the re-weighted updates (suspects contribute zero vectors).
# Alternative robust choice: torch.median(selected_parameters, dim=0)[0]
agg_weights = torch.mean(selected_parameters, dim=0)



In [None]:

if __name__ == '__main__':
    start_time = time.time()

    # define paths
    path_project = os.path.abspath('..')
    logger = SummaryWriter('../logs')

    args = Args()
    exp_details(args)

    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    # load dataset and user groups
    train_dataset, test_dataset, user_groups = get_dataset(args)

    # BUILD MODEL
    if args.model == 'cnn':
        # Convolutional neural netork
        if args.dataset == 'mnist':
            global_model = CNNMnist(args=args)
        elif args.dataset == 'fmnist':
            global_model = CNNFashion_Mnist(args=args)
        elif args.dataset == 'cifar':
            global_model = CNNCifar(args=args)

    elif args.model == 'mlp':
        # Multi-layer preceptron
        img_size = train_dataset[0][0].shape
        len_in = 1
        for x in img_size:
            len_in *= x
            global_model = MLP(dim_in=len_in, dim_hidden=64, dim_out=args.num_classes)
    else:
        exit('Error: unrecognized model')

    # Set the model to train and send it to device.
    global_model.to(device)
    global_model.train()
    print(global_model)

    # copy weights
    global_weights = global_model.state_dict()

    # Training
    train_loss, train_accuracy = [], []
    val_acc_list, net_list = [], []
    cv_loss, cv_acc = [], []
    print_every = 1
    val_loss_pre, counter = 0, 0

    mal_user = False
    
    for epoch in tqdm(range(args.epochs)):
        local_weights, local_losses = [], []
        print('=========================================')
        print(f'| Global Training Round : {epoch+1} |')
        print('=========================================')

        global_model.train()
        m = max(int(args.frac * args.num_users), 1)
        idxs_users = np.random.choice(range(args.num_users), m, replace=False)

        flattened_local_weights = []
        
        for idx in idxs_users:
            local_model = LocalUpdate(args=args, dataset=train_dataset, idxs=user_groups[idx], logger=logger)
            w, loss = local_model.update_weights(model=copy.deepcopy(global_model), global_round=epoch)
            
            # get new model
            new_model = copy.deepcopy(global_model)
            new_model.load_state_dict(w)
            acc, _ = local_model.inference(model=new_model)
            print('user {}, loss {.2f}, acc {.2f}'.format(idx, loss, 100*acc))
            
            local_weights.append(copy.deepcopy(w))
            local_losses.append(copy.deepcopy(loss))
            # flatten the local weight (list of ordereddict to a tensor of lists)
            flattened_local_weights.append(flatten(w))
        flattened_local_weights = torch.tensor(np.array(flattened_local_weights)).to(device)
        ## ATTACK TAKES PLACE HERE    
        
        
        # update global weights
        if args.aggregation == 'fedavg':
            agg_weights = fedavg(flattened_local_weights)
        elif args.aggregation == 'krum':
            agg_weights, krum_candidates = mkrum(flattened_local_weights, 0, multi_k=False)
        elif args.aggregation == 'mkrum':
            agg_weights, krum_candidates = mkrum(flattened_local_weights, 0, multi_k=True)
        elif args.aggregation == 'coomed':
            agg_weights = coomed(flattened_local_weights)
        elif args.aggregation == 'bulyan':
            agg_weights, bulyan_candidate = bulyan(flattened_local_weights, 0)
        elif args.aggregation == 'flare':
            agg_weights = flare(flattened_local_weights)
        else:
            raise ValueError('Unknown aggregation strategy: {}'.format(args.aggregation))

            
        # reshape the flattened global weights into the ordereddict
        keys = global_model.state_dict().keys()
        start_idx = 0
        model_grads = []

        for i, param in enumerate(global_model.parameters()):
            param_ = agg_weights[start_idx:start_idx + len(param.data.view(-1))].reshape(param.data.shape)
            start_idx = start_idx + len(param.data.view(-1))
            param_ = param_.cuda()
            model_grads.append(param_)
            
        global_weights = OrderedDict(dict(zip(keys, model_grads)))  
        
        # update global weights
        global_model.load_state_dict(global_weights)

        loss_avg = sum(local_losses) / len(local_losses)
        train_loss.append(loss_avg)

        # # Calculate avg training accuracy over all users at every epoch
        # list_acc, list_loss = [], []
        # global_model.eval()
        # for c in range(args.num_users):
        #     local_model = LocalUpdate(args=args, dataset=train_dataset, idxs=user_groups[c], logger=logger)
        #     acc, loss = local_model.inference(model=global_model)
        #     list_acc.append(acc)
        #     list_loss.append(loss)
        # train_accuracy.append(sum(list_acc)/len(list_acc))

        
        # print global training loss after every 'i' rounds
        if (epoch+1) % print_every == 0:
            print(f' \nAvg Training Stats after {epoch+1} global rounds:')
            print(f'Training Loss : {np.mean(np.array(train_loss))}')
            test_acc, test_loss = test_inference(args, global_model, test_dataset)
            print('\nGlobal model Benign Test Accuracy: {:.2f}% '.format(100*val_acc))
            if mal_user:
                mal_acc, mal_loss, mal_out = mal_inference(args, global_model, test_dataset, mal_X_list, mal_Y)
                print('Global model Malicious Accuracy: {:.2f}%, Malicious Loss: {:.2f} , Outputs: {}\n'.format(100*mal_acc, mal_loss, mal_out))

    # Test inference after completion of training
    test_acc, test_loss = test_inference(args, global_model, test_dataset)

    print(f' \n Results after {args.epochs} global rounds of training:')
    # print("|---- Avg Train Accuracy: {:.2f}%".format(100*train_accuracy[-1]))
    print("|---- Test Accuracy: {:.2f}%".format(100*test_acc))

    # Saving the objects train_loss and train_accuracy:
    file_name = '../save/objects/{}_{}_{}_C[{}]_iid[{}]_E[{}]_B[{}].pkl'.\
        format(args.dataset, args.model, args.epochs, args.frac, args.iid, args.local_ep, args.local_bs)

    with open(file_name, 'wb') as f:
        pickle.dump([train_loss, train_accuracy], f)

    print('\n Total Run Time: {0:0.4f}'.format(time.time()-start_time))


In [None]:
# Inspect the training dataset returned by get_dataset (the notebook echoes a cell's last expression).
train_dataset

In [None]:
# Number of entries in user_groups (the training loop indexes it per client id).
len(user_groups)

In [None]:

# PLOTTING (optional)
import os

import matplotlib
# Select the non-interactive Agg backend *before* importing pyplot; the
# figures below are only written to disk.
matplotlib.use('Agg')
import matplotlib.pyplot as plt

# Filename fragment describing the experiment configuration (shared by both plots).
plot_suffix = '{}_{}_{}_C[{}]_iid[{}]_E[{}]_B[{}]'.format(
    args.dataset, args.model, args.epochs, args.frac,
    args.iid, args.local_ep, args.local_bs)
os.makedirs('../save', exist_ok=True)

# Plot Loss curve
plt.figure()
plt.title('Training Loss vs Communication rounds')
plt.plot(range(len(train_loss)), train_loss, color='r')
plt.ylabel('Training loss')
plt.xlabel('Communication Rounds')
plt.savefig('../save/fed_{}_loss.png'.format(plot_suffix))
plt.close()

# Plot Average Accuracy vs Communication rounds.
# train_accuracy is only filled when per-round accuracy tracking is enabled in
# the training cell; skip the plot rather than saving an empty figure.
if train_accuracy:
    plt.figure()
    plt.title('Average Accuracy vs Communication rounds')
    plt.plot(range(len(train_accuracy)), train_accuracy, color='k')
    plt.ylabel('Average Accuracy')
    plt.xlabel('Communication Rounds')
    plt.savefig('../save/fed_{}_acc.png'.format(plot_suffix))
    plt.close()
else:
    print('train_accuracy is empty; skipping accuracy plot.')