## Plan

- [x] FL base framework
- [x] implement non-IID sampling function following a Dirichlet distribution

#### Attack simulation
- [x] untargeted model poisoning
    - Local model poisoning attacks to Byzantine-robust federated learning <br>
    https://github.com/vrt1shjwlkr/NDSS21-Model-Poisoning fang attack <br>
    krum-attack, trimmed-mean/median attack
- [x] targeted model poisoning
    - Analyzing federated learning through an adversarial lens <br>
    https://github.com/inspire-group/ModelPoisoning
- [x] data poisoning
    - DBA: Distributed backdoor attacks against federated learning <br>
    https://github.com/AI-secure/DBA
    
#### Defense baseline
- [x] FedAvg
- [x] Krum
- [x] Multi-Krum
- [x] Bulyan
- [x] Coordinate median
- [x] FLARE

#### Proposed method
- [x] extract PLRs
- [x] apply RBF hypersphere CKA

In [1]:
from IPython.display import HTML, display

# Widen the notebook's main container to 90% of the browser window.
_wide_container_css = "<style>.container { width:90% !important; }</style>"
display(HTML(_wide_container_css))

In [2]:
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6


import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
import copy
import time
import pickle
import numpy as np
from tqdm import tqdm
from collections import OrderedDict

import torch
from tensorboardX import SummaryWriter

from options import args_parser
from update import LocalUpdate, test_inference, mal_inference
from models import MLP, CNNMnist, CNNFashion_Mnist, CNNCifar, Alexnet, modelC, LeNet5
from utils import get_dataset, get_mal_dataset, exp_details, flatten, construct_ordered_dict
from aggregate import fedavg, multi_krum, krum, coomed, bulyan, tr_mean, fed_align, fed_cc, flare, fltrust
from attacks import get_malicious_updates_untargeted_mkrum, get_malicious_updates_untargeted_med, get_malicious_updates_targeted
from cka import linear_CKA, kernel_CKA
# python src/federated_main.py --model=cnn --dataset=cifar --gpu=0 --iid=1 --epochs=10

import warnings
# warnings.filterwarnings('ignore')

# Targeted Model Poisoning Attack (label flipping)

In [3]:
class Args(object):
    """Static experiment configuration used in place of parsed CLI arguments.

    Every setting is a class attribute, so ``Args()`` yields an object whose
    attributes can be read like an argparse namespace. In this notebook the
    instance is handed to ``exp_details``, ``get_dataset``, the model
    constructors, ``LocalUpdate`` and the inference helpers.

    ``num_classes`` and ``num_channels`` are derived from ``dataset`` through
    ``_DATASET_SPECS``; datasets missing from that table fall back to
    10 classes / 1 channel.
    """

    # federated parameters (default values are set)
    epochs = 20
    num_users = 10
    frac = 1  # fraction of clients
    local_ep = 3  # num of local epoch
    local_bs = 64  # batch size
    lr = 0.01
    momentum = 0.9
    aggregation = 'fedavg'  # fedavg, krum, mkrum, coomed, bulyan, flare, fedcc

    # model arguments
    model = 'cnn'
    kernel_num = 9  # num of each kind of kernel
    kernel_sizes = '3,4,5'  # comma-separated kernel size to use for convolution
    norm = 'batch_norm'  # batch_norm, layer_norm, None
    num_filters = 32  # num of filters for conv nets -- 32 for mini-imagenet, 64 for omniglot
    max_pool = 'True'  # whether use max pooling rather than strided convolutions

    # other arguments
    dataset = 'fmnist'  # fmnist, cifar, mnist, cifar100
    # (num_classes, num_channels) per dataset. CIFAR images are RGB, so both
    # CIFAR variants use 3 channels; MNIST-style datasets are grayscale.
    _DATASET_SPECS = {
        'mnist': (10, 1),
        'fmnist': (10, 1),
        'cifar': (10, 3),
        'cifar100': (100, 3),
    }
    num_classes, num_channels = _DATASET_SPECS.get(dataset, (10, 1))

    gpu = 0
    optimizer = 'sgd'
    iid = 1  # 0 for non-iid
    alpha = 0.2  # noniid --> (0, 1] <-- iid
    unequal = 0  # whether to use unequal data splits for non-iid settings (0 for equal splits)
    stopping_rounds = 10  # rounds of early stopping
    verbose = 0
    seed = 1

    # malicious arguments
    mal_clients = [0, 1, 2, 3]  # indices of malicious users
    attack_type = 'targeted'  # targeted
    num_mal = 5  # number of malicious data samples
    mal_lr = 0.005
    local_mal_ep = 6
    boost = 2  # alpha: 2 for fedavg, 3.5 for krum
    mal_bs = 64

In [4]:

if __name__ == '__main__':
    # Federated training run with simulated targeted model poisoning.
    # Each global round: every client trains a copy of the global model, the
    # malicious clients' updates are boosted, the flattened client models are
    # combined by the configured aggregation rule, and the new global model is
    # evaluated on the benign test set and on the malicious samples.
    start_time = time.time()

    # define paths
    path_project = os.path.abspath('..')
    logger = SummaryWriter('../logs')

    args = Args()
    exp_details(args)

    # Use the first GPU only when it is requested AND actually available, so
    # the model placement below does not fail on a CPU-only machine.
    device = 'cuda:0' if args.gpu == 0 and torch.cuda.is_available() else 'cpu'
    # device = 'cpu'
    # for n_attacker in args.n_attackers:
    torch.cuda.empty_cache()

    # load dataset and user groups
    train_dataset, test_dataset, user_groups = get_dataset(args)

    # BUILD MODEL
    if args.model == 'cnn':
        # Convolutional neural network
        if args.dataset == 'mnist':
            global_model = CNNMnist(args=args)
        elif args.dataset == 'fmnist':
            global_model = CNNFashion_Mnist(args=args)
        elif args.dataset == 'cifar':
            global_model = CNNCifar(args=args)
        elif args.dataset == 'cifar100':
            global_model = LeNet5(args=args)
        else:
            # Without this branch an unknown dataset would only surface later
            # as a NameError on global_model.
            raise ValueError('Error: unrecognized dataset for cnn model: {}'.format(args.dataset))

    # elif args.model == 'alexnet':
    #     if args.dataset == 'cifar100':
    #         global_model = Alexnet(args=args)

    elif args.model == 'mlp':
        # Multi-layer perceptron: the input size is the product of the image dims
        img_size = train_dataset[0][0].shape
        len_in = 1
        for x in img_size:
            len_in *= x
        # Build the model once, after the full input size is known.
        global_model = MLP(dim_in=len_in, dim_hidden=64, dim_out=args.num_classes)
    else:
        raise ValueError('Error: unrecognized model')

    # Set the model to train and send it to device.
    global_model.to(device)
    global_model.train()
    print(global_model)

    # copy weights
    global_weights = global_model.state_dict()
    if len(args.mal_clients) > 0:
        mal_X_list, mal_Y, Y_true = get_mal_dataset(test_dataset, args.num_mal, args.num_classes)
        print("malcious dataset true labels: {}, malicious labels: {}".format(Y_true, mal_Y))

    # Training
    train_loss, train_accuracy = [], []
    val_acc_list, net_list = [], []
    cv_loss, cv_acc = [], []
    print_every = 1
    val_loss_pre, counter = 0, 0

    # first entry of the malicious-sample output of the global model, per round
    confidence = []
    for epoch in tqdm(range(args.epochs)):
        local_weights, local_losses = [], []

        print('=========================================')
        print(f'| Global Training Round : {epoch+1} |')
        print('=========================================')

        global_model.train()
        m = max(int(args.frac * args.num_users), 1)
        # NOTE(review): idxs_users is sampled but not used; the loop below
        # visits every client, so args.frac does not restrict participation.
        idxs_users = np.random.choice(range(args.num_users), m, replace=False)

        # flattened weights + bias, one entry per client
        flattened_local_weights = []

        # Per-tensor pieces of the weights-only / biases-only vectors; they are
        # concatenated once after the client loop instead of re-copying a
        # growing array on every append.
        weight_chunks = []  # for krum to consider only weights, not the biases
        bias_chunks = []

        for idx in range(args.num_users):
            mal_user = idx in args.mal_clients

            # malicious clients additionally receive the malicious samples
            if mal_user:  # and epoch % 2 == 0:
                local_model = LocalUpdate(args=args, dataset=train_dataset, idxs=user_groups[idx], logger=logger,
                                          mal=mal_user, mal_X=mal_X_list, mal_Y=mal_Y, test_dataset=test_dataset)
            else:
                local_model = LocalUpdate(args=args, dataset=train_dataset, idxs=user_groups[idx], logger=logger)

            # Global parameters before this client's update; local training
            # receives a deep copy of the global model.
            w_prev = global_model.state_dict()
            w, loss = local_model.update_weights(model=copy.deepcopy(global_model), global_round=epoch)
            # boost the malicious (weight+bias) by alpha
            # if mal_user:
            #     flat_delta_m = flatten(w) - flatten(w_prev)
            #     flat_mal_w = flatten(w_prev) + args.boost * flat_delta_m
            #     w = construct_ordered_dict(global_model, torch.tensor(flat_mal_w).to(device))

            # collect weights-only and bias-only pieces (krum / mkrum only)
            if 'krum' in args.aggregation:
                for key in w:
                    values = w[key].detach().cpu().numpy().reshape(-1)
                    if 'weight' in key:
                        # malicious weights are scaled on even rounds only
                        if mal_user and epoch % 2 == 0:
                            values = args.boost * values
                        weight_chunks.append(values)
                    elif 'bias' in key:
                        bias_chunks.append(values)

            new_model = copy.deepcopy(global_model)
            new_model.load_state_dict(w)
            acc, _ = local_model.inference(model=new_model)

            if mal_user:
                mal_acc, mal_loss = local_model.mal_inference(model=new_model)
                print('user {}, loss {}, acc {}, mal loss {}, mal acc {}'.format(idx, loss, 100*acc, mal_loss, 100*mal_acc))
            else:
                print('user {}, loss {}, acc {}'.format(idx, loss, 100*acc))

            local_weights.append(copy.deepcopy(w))
            local_losses.append(copy.deepcopy(loss))

            # Malicious clients submit w_prev + boost * (w - w_prev). Under plain
            # 'krum' the boost is applied on even rounds only (alternating
            # training); under every other rule it is applied in every round.
            boost_update = mal_user and (args.aggregation != 'krum' or epoch % 2 == 0)
            if boost_update:
                flat_prev = flatten(w_prev)
                flattened_local_weights.append(flat_prev + args.boost * (flatten(w) - flat_prev))
            else:
                flattened_local_weights.append(flatten(w))

        # Single concatenation per round; float64 is the dtype the vectors had
        # when they were accumulated onto an initially empty array.
        only_weights = np.concatenate(weight_chunks).astype(np.float64) if weight_chunks else np.array([])
        only_biases = np.concatenate(bias_chunks).astype(np.float64) if bias_chunks else np.array([])
        only_weights = torch.tensor(only_weights).to(device)

        flattened_local_weights = torch.tensor(np.array(flattened_local_weights)).to(device)
        malicious_grads = flattened_local_weights

        n_attacker = len(args.mal_clients)

        # update global weights
        if args.aggregation == 'fedavg':
            agg_weights = fedavg(malicious_grads)
        elif args.aggregation == 'krum':
            agg_weights, selected_idxs = krum(malicious_grads, n_attacker, only_weights=only_weights)
            print(f'Krum Selected idx: {selected_idxs}')
        elif args.aggregation == 'mkrum':
            agg_weights, selected_idxs = multi_krum(malicious_grads, n_attacker, only_weights=only_weights)
            print(f'multiKrum Selected idxs: {selected_idxs}')
        elif args.aggregation == 'coomed':
            agg_weights = coomed(malicious_grads)
            print(f'\ndiff {torch.norm((agg_weights - flattened_local_weights[0])) ** 2}')
        elif args.aggregation == 'bulyan':
            agg_weights, selected_idxs = bulyan(malicious_grads, n_attacker)
            print(f'Bulyan Selected idx: {selected_idxs}')
        elif args.aggregation == 'trmean':
            agg_weights = tr_mean(malicious_grads, n_attacker)

        elif args.aggregation == 'fltrust':
            glob_weights = []
            glob_weights.append(flatten(global_weights))
            agg_weights = fltrust(malicious_grads, glob_weights)

        elif args.aggregation == 'flare':
            # fourth-from-last state_dict entry is used as the penultimate layer
            second_last_layer = list(local_weights[0].keys())[-4]
            structured_local_weights = [construct_ordered_dict(global_model, flat_weights) for flat_weights in malicious_grads]
            plrs = [(each_local[second_last_layer]) for each_local in structured_local_weights]
            agg_weights, count_dict = flare(malicious_grads, plrs)
            print(f'flare count_dict: {count_dict}')

        elif args.aggregation == 'fedcc':
            # fourth-from-last state_dict entry is used as the penultimate layer
            second_last_layer = list(local_weights[0].keys())[-4]
            glob_plr = global_weights[second_last_layer]
            structured_local_weights = [construct_ordered_dict(global_model, flat_weights) for flat_weights in malicious_grads]
            plrs = [(each_local[second_last_layer]) for each_local in structured_local_weights]
            agg_weights, selected_idxs = fed_cc(malicious_grads, glob_plr, plrs, 'kernel')
            print(f'fed_cc Selected idx: {selected_idxs}')
        else:
            raise ValueError('Unknown aggregation strategy: {}'.format(args.aggregation))

        # reshape the flattened global weights into the ordereddict
        global_weights = construct_ordered_dict(global_model, agg_weights)

        # update global weights
        global_model.load_state_dict(global_weights)

        loss_avg = sum(local_losses) / len(local_losses)
        train_loss.append(loss_avg)

        # print global training loss after every 'i' rounds
        if (epoch+1) % print_every == 0:
            print(f' \nAvg Training Stats after {epoch+1} global rounds:')
            # mean over the per-round average losses of all rounds so far
            print(f'Training Loss : {np.mean(np.array(train_loss))}')

            test_acc, test_loss = test_inference(args, global_model, test_dataset)
            print('\nGlobal model Benign Test Accuracy: {:.2f}% '.format(100*test_acc))

            if len(args.mal_clients) > 0:
                mal_acc, mal_loss, mal_out = mal_inference(args, global_model, test_dataset, mal_X_list, mal_Y)
                print('Global model Malicious Accuracy: {:.2f}%, Malicious Loss: {:.2f}, confidence: {}\n'.format(100*mal_acc, mal_loss, mal_out))
                confidence.append(mal_out[0].item())

    # Test inference after completion of training
    test_acc, test_loss = test_inference(args, global_model, test_dataset)

    print(f' \n Results after {args.epochs} global rounds of training:')
    print("|---- Test Benign Accuracy: {:.2f}%".format(100*test_acc))

    if len(args.mal_clients) > 0:
        mal_acc, mal_loss, mal_out = mal_inference(args, global_model, test_dataset, mal_X_list, mal_Y)
        print("|---- Test Malicious Accuracy: {:.2f}%, Malicious Loss: {:.2f}, confidence:{}\n".format(100*mal_acc, mal_loss, mal_out[0]))

    print('\n Total Run Time: {0:0.4f}'.format(time.time()-start_time))

    # Flush and release the TensorBoard event file.
    logger.close()



Experimental details:
    Model     : cnn
    Optimizer : sgd
    Learning  : 0.01
    Aggregation     : fedavg
    Global Rounds   : 20

    Federated parameters:
    IID
    Fraction of users    : 1
    Local Batch size     : 64
    Local Epochs         : 3

    Malicious parameters:
    Attackers            : [0, 1, 2, 3]
    Attack Type          : targeted
CNNFashion_Mnist(
  (conv1): Conv2d(1, 64, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(64, 64, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=25600, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=10, bias=True)
)
malcious dataset true labels: 7, malicious labels: [9, 9, 9, 9, 9]


  0%|          | 0/20 [00:00<?, ?it/s]

| Global Training Round : 1 |


  return F.conv2d(input, weight, bias, self.stride,


user 0, loss 0.6270526064592495, acc 81.66666666666667, mal loss 0.6820063591003418, mal acc 80.0
user 1, loss 0.643422057852149, acc 79.66666666666666, mal loss 0.8096345067024231, mal acc 80.0
user 2, loss 0.6183002554370385, acc 80.66666666666666, mal loss 0.40721702575683594, mal acc 100.0
user 3, loss 0.6302642865727345, acc 83.66666666666667, mal loss 0.7034151554107666, mal acc 100.0
user 4, loss 0.6774219682481554, acc 81.83333333333334
user 5, loss 0.6925896481672922, acc 80.66666666666666
user 6, loss 0.7062370313538445, acc 81.66666666666667
user 7, loss 0.6780280296007791, acc 82.16666666666667
user 8, loss 0.6837983073128595, acc 81.33333333333333
user 9, loss 0.7248590660095214, acc 80.83333333333333
 
Avg Training Stats after 1 global rounds:
Training Loss : 0.6681973257013624


  5%|▌         | 1/20 [00:21<06:39, 21.02s/it]


Global model Benign Test Accuracy: 80.04% 
Global model Malicious Accuracy: 0.00%, Malicious Loss: 5.81, confidence: tensor([0.0006, 0.0029, 0.0053, 0.0012, 0.0211])

| Global Training Round : 2 |
user 0, loss 0.3528284011898856, acc 84.66666666666667, mal loss 0.04025585949420929, mal acc 100.0
user 1, loss 0.36888545199313705, acc 82.66666666666667, mal loss 0.07552474737167358, mal acc 100.0
user 2, loss 0.34766923374774167, acc 82.83333333333334, mal loss 0.029640858992934227, mal acc 100.0
user 3, loss 0.3453940854881678, acc 83.66666666666667, mal loss 0.012029449455440044, mal acc 100.0
user 4, loss 0.3684672062926822, acc 83.0
user 5, loss 0.3797462555434969, acc 86.0
user 6, loss 0.39673195540905, acc 86.33333333333333
user 7, loss 0.37338574330012003, acc 85.33333333333334
user 8, loss 0.36581916146808197, acc 86.16666666666667
user 9, loss 0.38454149358802375, acc 86.16666666666667
 
Avg Training Stats after 2 global rounds:
Training Loss : 0.5182721122517006


 10%|█         | 2/20 [00:42<06:19, 21.09s/it]


Global model Benign Test Accuracy: 86.79% 
Global model Malicious Accuracy: 0.00%, Malicious Loss: 6.20, confidence: tensor([0.0010, 0.0003, 0.0149, 0.0040, 0.0021])

| Global Training Round : 3 |
user 0, loss 0.2336464220037063, acc 86.66666666666667, mal loss 0.007226706948131323, mal acc 100.0
user 1, loss 0.24878283420152833, acc 85.83333333333333, mal loss 0.010681947693228722, mal acc 100.0
user 2, loss 0.23426464688555712, acc 85.0, mal loss 0.010963888838887215, mal acc 100.0
user 3, loss 0.2284793574362993, acc 87.16666666666667, mal loss 0.019302647560834885, mal acc 100.0
user 4, loss 0.29043713953759936, acc 85.5
user 5, loss 0.3007071203655667, acc 86.83333333333333
user 6, loss 0.30447367860211266, acc 87.16666666666667
user 7, loss 0.29675637592871984, acc 87.66666666666667
user 8, loss 0.2846468566192521, acc 88.0
user 9, loss 0.2958892338805728, acc 86.5
 
Avg Training Stats after 3 global rounds:
Training Loss : 0.43611753034983086


 15%|█▌        | 3/20 [01:03<06:00, 21.20s/it]


Global model Benign Test Accuracy: 88.54% 
Global model Malicious Accuracy: 0.00%, Malicious Loss: 9.90, confidence: tensor([1.9078e-05, 1.0316e-04, 6.5724e-06, 6.9887e-06, 3.5611e-03])

| Global Training Round : 4 |
user 0, loss 0.17286841846234577, acc 86.16666666666667, mal loss 0.016335006803274155, mal acc 100.0
user 1, loss 0.1830921686507696, acc 85.0, mal loss 0.019790716469287872, mal acc 100.0
user 2, loss 0.17140903260166707, acc 87.0, mal loss 0.004894697107374668, mal acc 100.0
user 3, loss 0.1732120960215597, acc 86.16666666666667, mal loss 0.01784670539200306, mal acc 100.0
user 4, loss 0.25548429552051755, acc 87.0
user 5, loss 0.2645042683349715, acc 86.33333333333333
user 6, loss 0.2622986395822631, acc 86.83333333333333
user 7, loss 0.2548699487580193, acc 87.0
user 8, loss 0.2493265222509702, acc 85.66666666666667
user 9, loss 0.252203498284022, acc 86.0
 
Avg Training Stats after 4 global rounds:
Training Loss : 0.3830698699740508


 20%|██        | 4/20 [01:24<05:39, 21.24s/it]


Global model Benign Test Accuracy: 88.67% 
Global model Malicious Accuracy: 0.00%, Malicious Loss: 9.38, confidence: tensor([4.8167e-05, 4.0346e-05, 4.8928e-04, 6.3168e-04, 7.0892e-06])

| Global Training Round : 5 |
user 0, loss 0.13862524789414907, acc 86.5, mal loss 0.0121835982427001, mal acc 100.0
user 1, loss 0.14767035868492698, acc 87.66666666666667, mal loss 0.019035454839468002, mal acc 100.0
user 2, loss 0.14755546876735853, acc 87.0, mal loss 0.0031513452995568514, mal acc 100.0
user 3, loss 0.1371762348918996, acc 89.66666666666666, mal loss 0.08473721891641617, mal acc 100.0
user 4, loss 0.22455196330944696, acc 87.16666666666667
user 5, loss 0.2382065710425377, acc 87.33333333333333
user 6, loss 0.24074905069337948, acc 87.83333333333333
user 7, loss 0.2286805563999547, acc 88.0
user 8, loss 0.22033159022529922, acc 89.0
user 9, loss 0.2201579425897863, acc 87.66666666666667
 
Avg Training Stats after 5 global rounds:
Training Loss : 0.3453299956692154


 25%|██▌       | 5/20 [01:45<05:17, 21.16s/it]


Global model Benign Test Accuracy: 89.29% 
Global model Malicious Accuracy: 0.00%, Malicious Loss: 11.48, confidence: tensor([1.1229e-07, 2.6833e-05, 7.0932e-05, 1.1708e-04, 4.7615e-06])

| Global Training Round : 6 |
user 0, loss 0.1113970436854288, acc 87.66666666666667, mal loss 0.008030777797102928, mal acc 100.0
user 1, loss 0.12313907336353798, acc 86.5, mal loss 0.020308425650000572, mal acc 100.0
user 2, loss 0.12638689661112598, acc 87.5, mal loss 0.005248176399618387, mal acc 100.0
user 3, loss 0.11172253417725299, acc 88.5, mal loss 0.00928876455873251, mal acc 100.0
user 4, loss 0.20540408555004333, acc 87.0
user 5, loss 0.21736460412542025, acc 86.66666666666667
user 6, loss 0.22122865796089175, acc 87.16666666666667
user 7, loss 0.2102776779068841, acc 89.33333333333333
user 8, loss 0.2013616005248494, acc 88.33333333333333
user 9, loss 0.20516560895575417, acc 86.83333333333333
 
Avg Training Stats after 6 global rounds:
Training Loss : 0.31666579277203266


 30%|███       | 6/20 [02:07<04:56, 21.21s/it]


Global model Benign Test Accuracy: 89.12% 
Global model Malicious Accuracy: 0.00%, Malicious Loss: 11.52, confidence: tensor([5.6460e-08, 1.1615e-05, 2.0000e-05, 3.4942e-04, 2.0898e-05])

| Global Training Round : 7 |
user 0, loss 0.0960557154122446, acc 87.83333333333333, mal loss 0.0029334735590964556, mal acc 100.0
user 1, loss 0.10790389824225649, acc 85.66666666666667, mal loss 0.015023112297058105, mal acc 100.0
user 2, loss 0.10874873578237991, acc 88.0, mal loss 0.0012791731860488653, mal acc 100.0
user 3, loss 0.09363412261907861, acc 88.33333333333333, mal loss 0.0019469615072011948, mal acc 100.0
user 4, loss 0.1951954478936063, acc 87.16666666666667
user 5, loss 0.210994601017899, acc 88.66666666666667
user 6, loss 0.20380866703059938, acc 87.83333333333333
user 7, loss 0.20218426292141278, acc 87.16666666666667
user 8, loss 0.1904540997909175, acc 89.5
user 9, loss 0.19196047397951285, acc 88.33333333333333
 
Avg Training Stats after 7 global rounds:
Training Loss : 0.294

 35%|███▌      | 7/20 [02:28<04:36, 21.24s/it]


Global model Benign Test Accuracy: 89.56% 
Global model Malicious Accuracy: 0.00%, Malicious Loss: 13.71, confidence: tensor([5.4203e-06, 2.5538e-05, 1.1070e-10, 7.6773e-06, 1.4669e-05])

| Global Training Round : 8 |
user 0, loss 0.0813888918073325, acc 87.5, mal loss 0.008009561337530613, mal acc 100.0
user 1, loss 0.09049366403926622, acc 86.33333333333333, mal loss 0.001315420726314187, mal acc 100.0
user 2, loss 0.09509760309051592, acc 88.66666666666667, mal loss 0.008629529736936092, mal acc 100.0
user 3, loss 0.08112498822932442, acc 90.83333333333333, mal loss 0.004227896220982075, mal acc 100.0
user 4, loss 0.1782347849259774, acc 87.5
user 5, loss 0.19399462746249305, acc 87.83333333333333
user 6, loss 0.19206286816133392, acc 88.16666666666667
user 7, loss 0.188140260775884, acc 87.16666666666667
user 8, loss 0.1729061340706216, acc 88.16666666666667
user 9, loss 0.17874258736769358, acc 88.33333333333333
 
Avg Training Stats after 8 global rounds:
Training Loss : 0.275663

 40%|████      | 8/20 [02:49<04:13, 21.16s/it]


Global model Benign Test Accuracy: 89.85% 
Global model Malicious Accuracy: 0.00%, Malicious Loss: 12.33, confidence: tensor([1.9819e-06, 4.5321e-08, 3.9872e-05, 4.9619e-06, 9.3787e-05])

| Global Training Round : 9 |
user 0, loss 0.0736085375681062, acc 88.66666666666667, mal loss 0.05901901796460152, mal acc 100.0
user 1, loss 0.07741780068190246, acc 87.0, mal loss 0.003601374104619026, mal acc 100.0
user 2, loss 0.08679260643762782, acc 89.0, mal loss 0.002858149353414774, mal acc 100.0
user 3, loss 0.07142403479634883, acc 90.33333333333333, mal loss 0.07243362814188004, mal acc 100.0
user 4, loss 0.17228095601830218, acc 86.33333333333333
user 5, loss 0.18812519384755025, acc 88.0
user 6, loss 0.17615216991139782, acc 88.33333333333333
user 7, loss 0.17609657959805594, acc 86.5
user 8, loss 0.1703572317875094, acc 88.83333333333333
user 9, loss 0.1626924377017551, acc 86.16666666666667
 
Avg Training Stats after 9 global rounds:
Training Loss : 0.26008912832545406


 45%|████▌     | 9/20 [03:10<03:53, 21.22s/it]


Global model Benign Test Accuracy: 89.55% 
Global model Malicious Accuracy: 0.00%, Malicious Loss: 13.76, confidence: tensor([4.2110e-07, 1.0481e-06, 1.1518e-06, 4.8545e-06, 5.3378e-07])

| Global Training Round : 10 |
user 0, loss 0.06314282366146105, acc 89.16666666666667, mal loss 0.004230614751577377, mal acc 100.0
user 1, loss 0.06961070107737262, acc 88.16666666666667, mal loss 0.013107463717460632, mal acc 100.0
user 2, loss 0.08615305930531274, acc 87.33333333333333, mal loss 0.0009373066131956875, mal acc 100.0
user 3, loss 0.06454052010922808, acc 90.33333333333333, mal loss 0.04472721368074417, mal acc 100.0
user 4, loss 0.166881033844418, acc 88.33333333333333
user 5, loss 0.17577302848299345, acc 87.66666666666667
user 6, loss 0.16751062117516993, acc 87.5
user 7, loss 0.1718650705450111, acc 86.83333333333333
user 8, loss 0.1625262234856685, acc 89.33333333333333
user 9, loss 0.1512935843070348, acc 89.33333333333333
 
Avg Training Stats after 10 global rounds:
Training 

 50%|█████     | 10/20 [03:32<03:32, 21.24s/it]


Global model Benign Test Accuracy: 89.57% 
Global model Malicious Accuracy: 0.00%, Malicious Loss: 13.43, confidence: tensor([1.0259e-05, 1.2595e-08, 1.9142e-05, 1.0297e-08, 2.7139e-04])

| Global Training Round : 11 |
user 0, loss 0.060638743745297874, acc 87.66666666666667, mal loss 0.0027693291194736958, mal acc 100.0
user 1, loss 0.061884440730321795, acc 86.83333333333333, mal loss 0.03929262235760689, mal acc 100.0
user 2, loss 0.07335236200474594, acc 88.16666666666667, mal loss 0.018928099423646927, mal acc 100.0
user 3, loss 0.05924797926497392, acc 91.33333333333333, mal loss 0.0036435741931200027, mal acc 100.0
user 4, loss 0.15637496040099197, acc 86.83333333333333
user 5, loss 0.17366860521336394, acc 87.66666666666667
user 6, loss 0.1630550971875588, acc 87.5
user 7, loss 0.1626984598984321, acc 87.66666666666667
user 8, loss 0.14980550573103957, acc 89.33333333333333
user 9, loss 0.14904716447823577, acc 88.66666666666667
 
Avg Training Stats after 11 global rounds:
Tra

 55%|█████▌    | 11/20 [03:53<03:11, 21.24s/it]


Global model Benign Test Accuracy: 89.81% 
Global model Malicious Accuracy: 0.00%, Malicious Loss: 15.81, confidence: tensor([1.0479e-06, 3.7925e-09, 3.2148e-09, 4.5022e-07, 8.2831e-06])

| Global Training Round : 12 |
user 0, loss 0.05769658367314964, acc 88.66666666666667, mal loss 0.003780986415222287, mal acc 100.0
user 1, loss 0.05693424586410913, acc 87.5, mal loss 0.00364127429202199, mal acc 100.0
user 2, loss 0.06838306148175002, acc 89.16666666666667, mal loss 0.0009597347234375775, mal acc 100.0
user 3, loss 0.05074487176809048, acc 90.0, mal loss 0.002219616435468197, mal acc 100.0
user 4, loss 0.1465825576417976, acc 86.66666666666667
user 5, loss 0.16375397875905037, acc 88.66666666666667
user 6, loss 0.15370913083354631, acc 87.83333333333333
user 7, loss 0.16720129425326982, acc 86.83333333333333
user 8, loss 0.15385301839146348, acc 89.33333333333333
user 9, loss 0.14492020805676778, acc 87.33333333333333
 
Avg Training Stats after 12 global rounds:
Training Loss : 0.

 60%|██████    | 12/20 [04:14<02:50, 21.28s/it]


Global model Benign Test Accuracy: 89.86% 
Global model Malicious Accuracy: 0.00%, Malicious Loss: 14.38, confidence: tensor([3.8921e-09, 2.4352e-06, 1.5901e-05, 5.2755e-06, 7.6197e-08])

| Global Training Round : 13 |
user 0, loss 0.0449440864785605, acc 87.5, mal loss 0.0011041872203350067, mal acc 100.0
user 1, loss 0.04883765973525506, acc 87.83333333333333, mal loss 0.30883002281188965, mal acc 80.0
user 2, loss 0.07163300884738481, acc 88.16666666666667, mal loss 0.0007435858133248985, mal acc 100.0
user 3, loss 0.04588167192649276, acc 91.0, mal loss 0.0025059471372514963, mal acc 100.0
user 4, loss 0.14781662635091278, acc 88.66666666666667
user 5, loss 0.15512535375025535, acc 87.66666666666667
user 6, loss 0.1458148203790188, acc 88.5
user 7, loss 0.14938438205255403, acc 86.16666666666667
user 8, loss 0.14098697510858377, acc 90.83333333333333
user 9, loss 0.1412034220372637, acc 87.5
 
Avg Training Stats after 13 global rounds:
Training Loss : 0.21655768070252904


 65%|██████▌   | 13/20 [04:35<02:29, 21.30s/it]


Global model Benign Test Accuracy: 89.84% 
Global model Malicious Accuracy: 0.00%, Malicious Loss: 19.37, confidence: tensor([1.5100e-11, 1.4838e-10, 2.8963e-07, 7.6834e-10, 1.7280e-06])

| Global Training Round : 14 |
user 0, loss 0.04701484583845166, acc 88.33333333333333, mal loss 0.012619413435459137, mal acc 100.0
user 1, loss 0.04670890588456308, acc 89.5, mal loss 0.00040432121022604406, mal acc 100.0
user 2, loss 0.05862128419552423, acc 90.0, mal loss 0.00043785269372165203, mal acc 100.0
user 3, loss 0.04672602278981031, acc 90.33333333333333, mal loss 0.002936630044132471, mal acc 100.0
user 4, loss 0.1375232536594073, acc 87.33333333333333
user 5, loss 0.15139281283236214, acc 87.0
user 6, loss 0.14393235757119124, acc 88.66666666666667
user 7, loss 0.14619628476599852, acc 88.5
user 8, loss 0.13620020161486335, acc 89.83333333333333
user 9, loss 0.13317749961796735, acc 89.5
 
Avg Training Stats after 14 global rounds:
Training Loss : 0.20857137114356367


 70%|███████   | 14/20 [04:56<02:07, 21.19s/it]


Global model Benign Test Accuracy: 90.00% 
Global model Malicious Accuracy: 0.00%, Malicious Loss: 14.74, confidence: tensor([6.2068e-07, 8.0961e-07, 7.5635e-10, 1.8122e-05, 1.4520e-06])

| Global Training Round : 15 |
user 0, loss 0.03968346370432604, acc 90.5, mal loss 0.0004780317540280521, mal acc 100.0
user 1, loss 0.04820254937620025, acc 86.83333333333333, mal loss 0.0023332596756517887, mal acc 100.0
user 2, loss 0.06690346675653322, acc 88.83333333333333, mal loss 0.000443519267719239, mal acc 100.0
user 3, loss 0.042222221658703984, acc 90.66666666666666, mal loss 0.025304075330495834, mal acc 100.0
user 4, loss 0.13173270367085935, acc 89.5
user 5, loss 0.15295788573722044, acc 88.16666666666667
user 6, loss 0.1341446291903655, acc 88.33333333333333
user 7, loss 0.135704322407643, acc 85.83333333333333
user 8, loss 0.12560026287204687, acc 89.33333333333333
user 9, loss 0.1292909413948655, acc 87.33333333333333
 
Avg Training Stats after 15 global rounds:
Training Loss : 0.

 75%|███████▌  | 15/20 [05:18<01:45, 21.19s/it]


Global model Benign Test Accuracy: 89.98% 
Global model Malicious Accuracy: 0.00%, Malicious Loss: 16.19, confidence: tensor([2.7767e-07, 4.6153e-08, 3.4401e-08, 2.2357e-07, 6.9344e-08])

| Global Training Round : 16 |
user 0, loss 0.039987586568857196, acc 89.66666666666666, mal loss 0.0011090198531746864, mal acc 100.0
user 1, loss 0.0408547054580515, acc 87.83333333333333, mal loss 0.04274861887097359, mal acc 100.0
user 2, loss 0.05214722256080182, acc 88.83333333333333, mal loss 0.00202577724121511, mal acc 100.0
user 3, loss 0.03969696605820087, acc 88.66666666666667, mal loss 0.0028927114326506853, mal acc 100.0
user 4, loss 0.13048127989388175, acc 87.83333333333333
user 5, loss 0.14078002295560307, acc 88.33333333333333
user 6, loss 0.134173122478856, acc 86.16666666666667
user 7, loss 0.13392195676763852, acc 86.5
user 8, loss 0.12546609434816572, acc 90.0
user 9, loss 0.11629933343993293, acc 88.5
 
Avg Training Stats after 16 global rounds:
Training Loss : 0.19475151685873

 80%|████████  | 16/20 [05:37<01:23, 20.79s/it]


Global model Benign Test Accuracy: 89.87% 
Global model Malicious Accuracy: 0.00%, Malicious Loss: 16.86, confidence: tensor([6.8155e-11, 4.4548e-06, 6.3372e-06, 1.7191e-08, 7.3383e-09])

| Global Training Round : 17 |
user 0, loss 0.03731223973731527, acc 88.16666666666667, mal loss 0.0034196064807474613, mal acc 100.0
user 1, loss 0.042546114608596426, acc 89.33333333333333, mal loss 0.009967758320271969, mal acc 100.0
user 2, loss 0.04982655242150812, acc 89.66666666666666, mal loss 0.0007492214208468795, mal acc 100.0
user 3, loss 0.039728551104832276, acc 89.66666666666666, mal loss 0.000991156091913581, mal acc 100.0
user 4, loss 0.12427066815810071, acc 88.5
user 5, loss 0.12844700425863265, acc 88.83333333333333
user 6, loss 0.12839322417974472, acc 89.16666666666667
user 7, loss 0.13083466296394666, acc 86.0
user 8, loss 0.12330219696379369, acc 88.83333333333333
user 9, loss 0.11921894476231602, acc 88.5
 
Avg Training Stats after 17 global rounds:
Training Loss : 0.18873013

 85%|████████▌ | 17/20 [05:57<01:01, 20.49s/it]


Global model Benign Test Accuracy: 90.00% 
Global model Malicious Accuracy: 0.00%, Malicious Loss: 16.40, confidence: tensor([1.1795e-09, 1.1979e-09, 1.2134e-05, 7.3870e-06, 1.8942e-08])

| Global Training Round : 18 |
user 0, loss 0.03519260104803834, acc 88.66666666666667, mal loss 0.009904688224196434, mal acc 100.0
user 1, loss 0.04112520149506612, acc 89.16666666666667, mal loss 0.0019205143908038735, mal acc 100.0
user 2, loss 0.04822526793935123, acc 88.5, mal loss 0.0004185800498817116, mal acc 100.0
user 3, loss 0.03827737212097354, acc 91.16666666666666, mal loss 0.019101470708847046, mal acc 100.0
user 4, loss 0.12230771168859471, acc 89.0
user 5, loss 0.12165502935441004, acc 88.66666666666667
user 6, loss 0.12597340424855552, acc 86.66666666666667
user 7, loss 0.1263796037932237, acc 89.5
user 8, loss 0.12262354315982922, acc 90.5
user 9, loss 0.10784866960512268, acc 86.66666666666667
 
Avg Training Stats after 18 global rounds:
Training Loss : 0.1831873958944979


 90%|█████████ | 18/20 [06:17<00:40, 20.31s/it]


Global model Benign Test Accuracy: 90.38% 
Global model Malicious Accuracy: 0.00%, Malicious Loss: 15.95, confidence: tensor([6.9127e-07, 1.7992e-09, 2.2124e-06, 4.6243e-09, 1.8568e-06])

| Global Training Round : 19 |
user 0, loss 0.033095167646247876, acc 88.83333333333333, mal loss 0.0015710864681750536, mal acc 100.0
user 1, loss 0.035754300756308484, acc 88.0, mal loss 0.0032966521102935076, mal acc 100.0
user 2, loss 0.05048936638080417, acc 89.5, mal loss 5.8313878980698064e-05, mal acc 100.0
user 3, loss 0.031512865103171576, acc 90.5, mal loss 0.0007316942792385817, mal acc 100.0
user 4, loss 0.12089519228786229, acc 87.0
user 5, loss 0.13000481763647662, acc 88.5
user 6, loss 0.11762077649434406, acc 88.33333333333333
user 7, loss 0.12591658976756864, acc 87.83333333333333
user 8, loss 0.11733076020661327, acc 89.5
user 9, loss 0.11139797281059954, acc 86.5
 
Avg Training Stats after 19 global rounds:
Training Loss : 0.1781460477373664


 95%|█████████▌| 19/20 [06:38<00:20, 20.47s/it]


Global model Benign Test Accuracy: 90.03% 
Global model Malicious Accuracy: 0.00%, Malicious Loss: 17.72, confidence: tensor([4.2582e-08, 6.6317e-07, 4.5926e-07, 5.7780e-10, 4.4101e-10])

| Global Training Round : 20 |
user 0, loss 0.028985211298149488, acc 89.66666666666666, mal loss 0.03367520496249199, mal acc 100.0
user 1, loss 0.034318327132240756, acc 86.83333333333333, mal loss 0.009656304493546486, mal acc 100.0
user 2, loss 0.050701007364790586, acc 89.33333333333333, mal loss 0.0025065012741833925, mal acc 100.0
user 3, loss 0.034151665624864234, acc 90.33333333333333, mal loss 0.002521452261134982, mal acc 100.0
user 4, loss 0.11751363807254367, acc 89.83333333333333
user 5, loss 0.127772315127982, acc 87.66666666666667
user 6, loss 0.11769658866441912, acc 87.66666666666667
user 7, loss 0.11882446966237492, acc 87.5
user 8, loss 0.11558855116056899, acc 89.66666666666666
user 9, loss 0.11352843753993512, acc 87.33333333333333
 
Avg Training Stats after 20 global rounds:
Tr

100%|██████████| 20/20 [06:59<00:00, 20.97s/it]


Global model Benign Test Accuracy: 90.10% 
Global model Malicious Accuracy: 0.00%, Malicious Loss: 19.00, confidence: tensor([1.0404e-07, 6.1119e-10, 8.9434e-09, 4.9159e-11, 1.9546e-07])






 
 Results after 20 global rounds of training:
|---- Test Benign Accuracy: 90.02%
|---- Test Malicious Accuracy: 0.00%, Malicious Loss: 19.26, confidence:4.699540667729707e-08


 Total Run Time: 420.4799


In [None]:
# Singular value of all 'weight' tensors of the global model, flattened and
# joined into one row vector.
weight_chunks = [
    global_weights[key].detach().cpu().numpy().reshape(-1)
    for key in global_weights
    if 'weight' in key
]
# Join once (instead of re-copying a growing array per tensor) and work in float64.
only_weights = np.concatenate(weight_chunks).astype(np.float64) if weight_chunks else np.array([])
# compute_uv=False: only the singular values are used. The default full SVD
# would also build an N x N right-singular matrix for this 1 x N input.
glob_svd = np.linalg.svd(only_weights.reshape(1, -1), compute_uv=False)
glob_svd

In [None]:
from sklearn.metrics.pairwise import cosine_similarity, euclidean_distances

# Euclidean distance between the global PLR and each entry of plrs, each
# flattened to a single row. A NaN distance is recorded as 0.
sim_vals = []
glob_plr = glob_plr.detach().cpu()
global_row = glob_plr.reshape(1, -1)

print('euclidean')
for plr in plrs:
    client_row = plr.detach().cpu().reshape(1, -1)
    val = euclidean_distances(global_row, client_row)[0][0]
    sim_vals.append(0 if np.isnan(val) else val)
print(sim_vals)
    

In [None]:
from sklearn.metrics.pairwise import cosine_similarity, euclidean_distances, paired_cosine_distances

# Cosine similarity between the global PLR and each entry of plrs, each
# flattened to a single row. A NaN similarity is recorded as 0.
sim_vals = []
glob_plr = glob_plr.detach().cpu()
global_row = glob_plr.reshape(1, -1)

print('cosine')
for plr in plrs:
    client_row = plr.detach().cpu().reshape(1, -1)
    val = cosine_similarity(global_row, client_row)[0][0]
    sim_vals.append(0 if np.isnan(val) else val)
print(sim_vals)
    

In [None]:
# Inspect which keys global_weights contains (shown as the cell's output).
global_weights.keys()

In [None]:
# Rebuild one ordered dict per flat update in malicious_grads so that
# individual layers can be addressed by key.
structured_local_weights = [construct_ordered_dict(global_model, flat_weights) for flat_weights in malicious_grads]

# Layers analysed below, in plotting order.
layer_keys = ['conv1.weight', 'conv2.weight', 'conv3.weight', 'fc1.weight', 'fc2.weight']

# For each layer, stack the flattened weights of every update into a matrix
# with one row per entry of structured_local_weights.
layer_matrices = {
    key: np.array([np.array(w[key].detach().cpu().reshape(-1)) for w in structured_local_weights])
    for key in layer_keys
}
conv1_weights = layer_matrices['conv1.weight']
conv2_weights = layer_matrices['conv2.weight']
conv3_weights = layer_matrices['conv3.weight']
fc1_weights = layer_matrices['fc1.weight']
fc2_weights = layer_matrices['fc2.weight']

num_layers = 5

# Only the singular values are used. compute_uv=False avoids materialising the
# (n_params x n_params) right-singular-vector matrix that the default full SVD
# would build for every layer.
s_layer_0 = np.linalg.svd(conv1_weights, compute_uv=False)
s_layer_1 = np.linalg.svd(conv2_weights, compute_uv=False)
s_layer_2 = np.linalg.svd(conv3_weights, compute_uv=False)
s_layer_3 = np.linalg.svd(fc1_weights, compute_uv=False)
s_layer_4 = np.linalg.svd(fc2_weights, compute_uv=False)


In [None]:
%matplotlib inline

import matplotlib
import matplotlib.pyplot as plt

In [None]:

# One column of points per layer: x is the layer index, y holds that layer's
# singular values. Each point is labelled with its index in the spectrum.
layer_spectra = [s_layer_0, s_layer_1, s_layer_2, s_layer_3, s_layer_4]
for x, y in zip(range(num_layers), layer_spectra):
    # Size the column from this layer's own spectrum rather than from
    # s_layer_0, so layers with a different number of singular values still
    # plot. `cmap` is omitted: without `c` matplotlib ignores it and warns.
    plt.scatter([x] * len(y), y)
    for i, value in enumerate(y):
        plt.annotate(str(i), (x, value))

plt.xticks(np.arange(num_layers))
plt.show()

In [None]:
from scipy.cluster.hierarchy import dendrogram, linkage

# Single-linkage hierarchical clustering of the rows of conv1_weights under
# correlation distance. dendrogram() draws the tree as a side effect; its
# returned dict is not needed here.
linkage_data = linkage(conv1_weights, metric='correlation', method='single')
dendrogram(linkage_data)
plt.title('Targeted-NIID: layer1')

# plt.show()

In [None]:

# Single-linkage hierarchical clustering of the rows of conv2_weights under
# correlation distance. dendrogram() draws the tree as a side effect; its
# returned dict is not needed here.
linkage_data = linkage(conv2_weights, metric='correlation', method='single')
dendrogram(linkage_data)
plt.title('Targeted-NIID: layer2')

# plt.show()

In [None]:

# Single-linkage hierarchical clustering of the rows of conv3_weights under
# correlation distance. dendrogram() draws the tree as a side effect; its
# returned dict is not needed here.
linkage_data = linkage(conv3_weights, metric='correlation', method='single')
dendrogram(linkage_data)
plt.title('Targeted-NIID: layer3')

# plt.show()

In [None]:

# Single-linkage hierarchical clustering of the rows of fc1_weights under
# correlation distance. dendrogram() draws the tree as a side effect; its
# returned dict is not needed here.
linkage_data = linkage(fc1_weights, metric='correlation', method='single')
dendrogram(linkage_data)
plt.title('Targeted-NIID: PLR')

# plt.show()

In [None]:

# Single-linkage hierarchical clustering of the rows of fc2_weights under
# correlation distance. dendrogram() draws the tree as a side effect; its
# returned dict is not needed here.
linkage_data = linkage(fc2_weights, metric='correlation', method='single')
dendrogram(linkage_data)
plt.title('Targeted-NIID: layer5')

# plt.show()

In [None]:
# structured_local_weights = [construct_ordered_dict(global_model, flat_weights) for flat_weights in malicious_grads]

# conv1_weights = [np.array(w['conv1.weight'].detach().cpu().reshape(-1)) for w in structured_local_weights]
# conv1_weights = np.array(conv1_weights)
# conv2_weights = [np.array(w['conv2.weight'].detach().cpu().reshape(-1)) for w in structured_local_weights]
# conv2_weights = np.array(conv2_weights)
# fc1_weights = [np.array(w['fc1.weight'].detach().cpu().reshape(-1)) for w in structured_local_weights]
# fc1_weights = np.array(fc1_weights)
# fc2_weights = [np.array(w['fc2.weight'].detach().cpu().reshape(-1)) for w in structured_local_weights]
# fc2_weights = np.array(fc2_weights)
# fc3_weights = [np.array(w['fc3.weight'].detach().cpu().reshape(-1)) for w in structured_local_weights]
# fc3_weights = np.array(fc3_weights)

# num_layers = 5

# _, s_layer_0, _, = np.linalg.svd(conv1_weights)
# _, s_layer_1, _, = np.linalg.svd(conv2_weights)
# _, s_layer_2, _, = np.linalg.svd(fc1_weights)
# _, s_layer_3, _, = np.linalg.svd(fc2_weights)
# _, s_layer_4, _, = np.linalg.svd(fc3_weights)


In [None]:
# from matplotlib import pyplot as plt

# for x, y in zip(np.arange(num_layers-1), [s_layer_0, s_layer_1, s_layer_3, s_layer_4]):
#     plt.scatter([x]*len(s_layer_0), y, cmap="copper")
#     for i, txt in enumerate(np.arange(len(s_layer_0))):
#         plt.annotate(txt, (x, y[i]))
    
# plt.xticks(np.arange(num_layers))
# plt.show()

In [None]:
# if len(args.mal_clients) > 0:
#     import csv
#     if args.iid:
#         filename = 'Confidence_2IID_'+args.dataset+'_'+args.aggregation+'.csv'
#     else:
#         filename = 'Confidence_2NIID_'+args.dataset+'_'+args.aggregation+'.csv'

#     with open(filename, mode='w', newline='') as csvfile:
#         w = csv.writer(csvfile)
#         w.writerows(map(lambda x: [x], confidence))



In [None]:
from utils import get_mal_dataset_of_class

# Draw a malicious evaluation set from the test data, flatten its target
# labels into one array, and score the global model on it.
# NOTE(review): the second argument (args.num_mal * 100) is presumably the
# number of samples requested -- confirm against get_mal_dataset_of_class.
test_mal_X_list, test_mal_Y, test_Y_true = get_mal_dataset_of_class(
    test_dataset, args.num_mal * 100, Y_true, mal_Y
)
flat_test_mal_Y = np.hstack(list(test_mal_Y))
mal_acc, mal_loss, mal_out = mal_inference(
    args, global_model, test_dataset, test_mal_X_list, flat_test_mal_Y
)



In [None]:
# Show the results of the mal_inference call: accuracy, loss, and the mean of
# mal_out.
mal_acc, mal_loss, torch.mean(mal_out)

In [None]:
# Show the dataset and boost settings of the current run.
args.dataset, args.boost

In [None]:
# Report the experiment configuration held in args.
# NOTE(review): exp_details comes from utils; presumably it prints the
# settings -- confirm there.
exp_details(args)


In [None]:
##### # PLOTTING (optional)
import os

import matplotlib
# Select the non-interactive Agg backend before importing pyplot, which is the
# documented order; the figures below are only written to disk.
matplotlib.use('Agg')
import matplotlib.pyplot as plt

# savefig does not create missing directories, so make sure the target exists.
os.makedirs('../save', exist_ok=True)

# Both figures share the same run-identifying filename stem.
fig_stem = '../save/fed_{}_{}_{}_C[{}]_iid[{}]_E[{}]_B[{}]'.format(
    args.dataset, args.model, args.epochs, args.frac,
    args.iid, args.local_ep, args.local_bs)

# Plot Loss curve
plt.figure()
plt.title('Training Loss vs Communication rounds')
plt.plot(range(len(train_loss)), train_loss, color='r')
plt.ylabel('Training loss')
plt.xlabel('Communication Rounds')
plt.savefig(fig_stem + '_loss.png')
plt.close()  # release the figure once it is on disk

# Plot Average Accuracy vs Communication rounds
plt.figure()
plt.title('Average Accuracy vs Communication rounds')
plt.plot(range(len(train_accuracy)), train_accuracy, color='k')
plt.ylabel('Average Accuracy')
plt.xlabel('Communication Rounds')
plt.savefig(fig_stem + '_acc.png')
plt.close()  # release the figure once it is on disk