## Plan

- [x] FL base framework
- [x] implement a non-IID sampling function following a Dirichlet distribution

#### Attack simulation
- [x] untargeted model poisoning
    - Local model poisoning attacks to Byzantine-robust federated learning <br>
    https://github.com/vrt1shjwlkr/NDSS21-Model-Poisoning fang attack <br>
    krum-attack, trimmed-mean/median attack
- [x] targeted model poisoning
    - Analyzing federated learning through an adversarial lens <br>
    https://github.com/inspire-group/ModelPoisoning
- [x] data poisoning
    - DBA: Distributed backdoor attacks against federated learning <br>
    https://github.com/AI-secure/DBA
    
#### Defense baseline
- [x] FedAvg
- [x] Krum
- [x] Multi-Krum
- [x] Bulyan
- [x] Coordinate median
- [x] FLARE

#### Proposed method
- [x] extract PLRs
- [x] apply RBF hypersphere CKA

In [1]:
from IPython.display import display, HTML
# Inject a CSS rule so the notebook container spans 90% of the page width.
wide_container_css = "<style>.container { width:90% !important; }</style>"
display(HTML(wide_container_css))

In [2]:
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6


import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
import copy
import time
import pickle
import numpy as np
from tqdm import tqdm
from collections import OrderedDict

import torch
from tensorboardX import SummaryWriter

from options import args_parser
from update import LocalUpdate, test_inference, mal_inference
from models import MLP, CNNMnist, CNNFashion_Mnist, CNNCifar, Alexnet, modelC, LeNet5
from utils import get_dataset, get_mal_dataset, exp_details, flatten, construct_ordered_dict
from aggregate import fedavg, multi_krum, krum, coomed, bulyan, tr_mean, fed_align, fed_cc, flare, fltrust
from attacks import get_malicious_updates_untargeted_mkrum, get_malicious_updates_untargeted_med, get_malicious_updates_targeted
from cka import linear_CKA, kernel_CKA
# python src/federated_main.py --model=cnn --dataset=cifar --gpu=0 --iid=1 --epochs=10

import warnings
# warnings.filterwarnings('ignore')

# Targeted Model Poisoning Attack (label flipping)

In [3]:
class Args(object):
    """Hard-coded experiment configuration.

    An instance of this class is passed around as ``args`` by the training
    cell instead of the result of the command-line parser. All values are
    plain class attributes; edit them here to change the experiment.
    """

    # federated parameters (default values are set)
    epochs = 20  # number of global training rounds
    num_users = 10  # number of clients
    frac = 1  # fraction of clients
    local_ep = 3  # number of local epochs
    local_bs = 100  # local batch size
    lr = 0.001
    momentum = 0.9
    aggregation = 'mkrum'  # fedavg, krum, mkrum, coomed, bulyan, flare, fedcc

    # model arguments
    model = 'cnn'
    kernel_num = 9  # number of each kind of kernel
    kernel_sizes = '3,4,5'  # comma-separated kernel sizes to use for convolution
    norm = 'batch_norm'  # batch_norm, layer_norm, None
    num_filters = 32  # number of filters for conv nets -- 32 for mini-imagenet, 64 for omniglot
    # NOTE(review): this is the string 'True', not a bool, so it is truthy
    # whatever its text says -- confirm how consumers interpret it.
    max_pool = 'True'  # whether to use max pooling rather than strided convolutions

    # other arguments
    dataset = 'cifar100'  # fmnist, cifar, mnist, cifar100
    num_classes = 100 if dataset == 'cifar100' else 10
    # CIFAR images are RGB; the remaining datasets are single-channel.
    num_channels = 3 if dataset in ('cifar', 'cifar100') else 1  # number of image channels

    gpu = 0  # the training cell selects 'cuda:0' when this equals 0, otherwise 'cpu'
    optimizer = 'adam'
    iid = 0  # 0 for non-iid
    alpha = 0.2  # noniid --> (0, 1] <-- iid
    unequal = 0  # whether to use unequal data splits for non-iid settings (0 for equal splits)
    stopping_rounds = 10  # rounds of early stopping
    verbose = 0
    seed = 1

    # malicious arguments
    mal_clients = []  # indices of malicious users, e.g. [9] or [0, 1]; empty disables the attack
    attack_type = 'targeted'  # targeted
    num_mal = 1  # number of malicious data samples
    local_mal_ep = 6
    boost = 5  # alpha: 2 for fedavg, 3.5 for krum
    mal_bs = 100

In [4]:

if __name__ == '__main__':
    # Federated-learning driver: builds the global model, runs `args.epochs`
    # rounds of local training on every client (optionally with boosted
    # malicious updates), aggregates the flattened client weights with the
    # configured rule and reports benign / malicious accuracy per round.
    start_time = time.time()

    # define paths
    path_project = os.path.abspath('..')
    logger = SummaryWriter('../logs')

    args = Args()
    exp_details(args)

    device = 'cuda:0' if args.gpu == 0 else 'cpu'
    torch.cuda.empty_cache()

    # load dataset and user groups
    train_dataset, test_dataset, user_groups = get_dataset(args)

    # BUILD MODEL
    if args.model == 'cnn':
        # Convolutional neural network
        if args.dataset == 'mnist':
            global_model = CNNMnist(args=args)
        elif args.dataset == 'fmnist':
            global_model = CNNFashion_Mnist(args=args)
        elif args.dataset == 'cifar':
            global_model = CNNCifar(args=args)
        elif args.dataset == 'cifar100':
            global_model = LeNet5(args=args)
        else:
            # Fail here rather than with a NameError on `global_model` below.
            raise ValueError('Unrecognized dataset for cnn model: {}'.format(args.dataset))

    elif args.model == 'mlp':
        # Multi-layer perceptron
        img_size = train_dataset[0][0].shape
        len_in = 1
        for x in img_size:
            len_in *= x
        # Build the model once, after the input size is fully accumulated.
        global_model = MLP(dim_in=len_in, dim_hidden=64, dim_out=args.num_classes)
    else:
        exit('Error: unrecognized model')

    # Set the model to train and send it to device.
    global_model.to(device)
    global_model.train()
    print(global_model)

    # copy weights
    global_weights = global_model.state_dict()
    if len(args.mal_clients) > 0:
        mal_X_list, mal_Y, Y_true = get_mal_dataset(test_dataset, args.num_mal, args.num_classes)
        print("malcious dataset true labels: {}, malicious labels: {}".format(Y_true, mal_Y))

    # Training
    train_loss, train_accuracy = [], []
    val_acc_list, net_list = [], []
    cv_loss, cv_acc = [], []
    print_every = 1
    val_loss_pre, counter = 0, 0

    confidence = []
    for epoch in tqdm(range(args.epochs)):
        local_weights, local_losses = [], []

        print('=========================================')
        print(f'| Global Training Round : {epoch+1} |')
        print('=========================================')

        global_model.train()
        m = max(int(args.frac * args.num_users), 1)
        # NOTE(review): idxs_users is sampled but the loop below visits every
        # client, so args.frac has no effect here. The draw is kept so the
        # NumPy random stream stays the same as before.
        idxs_users = np.random.choice(range(args.num_users), m, replace=False)

        # flattened weights + bias
        flattened_local_weights = []

        # Per-client weight-only / bias-only pieces (krum considers only the
        # weights, not the biases). Pieces are collected in lists and joined
        # once after the client loop; the leading empty float64 array keeps
        # the float64 result that repeated np.append onto [] produced.
        weight_chunks = [np.empty(0)]
        bias_chunks = [np.empty(0)]

        for idx in range(args.num_users):
            mal_user = False

            # alternating benign and malicious training
            if idx in args.mal_clients:
                mal_user = True
                local_model = LocalUpdate(args=args, dataset=train_dataset, idxs=user_groups[idx], logger=logger,
                                          mal=mal_user, mal_X=mal_X_list, mal_Y=mal_Y, test_dataset=test_dataset)
            else:
                local_model = LocalUpdate(args=args, dataset=train_dataset, idxs=user_groups[idx], logger=logger)

            w_prev = global_model.state_dict()
            # Local training runs on a deep copy, so global_model is untouched.
            w, loss = local_model.update_weights(model=copy.deepcopy(global_model), global_round=epoch)

            # construct two arrays with weights-only and bias-only
            if 'krum' in args.aggregation:
                for key in w:
                    if 'weight' in key:
                        if mal_user and epoch % 2 == 0:
                            # malicious weights are scaled by boost on even rounds only
                            weight_chunks.append(args.boost * w[key].detach().cpu().numpy().reshape(-1))
                        else:
                            weight_chunks.append(w[key].detach().cpu().numpy().reshape(-1))
                    elif 'bias' in key:
                        bias_chunks.append(w[key].detach().cpu().numpy().reshape(-1))

            new_model = copy.deepcopy(global_model)
            new_model.load_state_dict(w)
            acc, _ = local_model.inference(model=new_model)

            if mal_user:
                mal_acc, mal_loss = local_model.mal_inference(model=new_model)
                print('user {}, loss {}, acc {}, mal loss {}, mal acc {}'.format(idx, loss, 100*acc, mal_loss, 100*mal_acc))
            else:
                print('user {}, loss {}, acc {}'.format(idx, loss, 100*acc))

            local_weights.append(copy.deepcopy(w))
            local_losses.append(copy.deepcopy(loss))

            # A malicious client submits w_prev + boost * (w - w_prev). Under
            # plain 'krum' this happens on even rounds only (alternate
            # training); under every other rule it happens each round.
            boost_update = mal_user and (args.aggregation != 'krum' or epoch % 2 == 0)
            if boost_update:
                flattened_local_weights.append(flatten(w_prev) + args.boost * (flatten(w) - flatten(w_prev)))
            else:
                flattened_local_weights.append(flatten(w))

        only_weights = torch.tensor(np.concatenate(weight_chunks)).to(device)
        only_biases = np.concatenate(bias_chunks)

        flattened_local_weights = torch.tensor(np.array(flattened_local_weights)).to(device)
        malicious_grads = flattened_local_weights

        n_attacker = len(args.mal_clients)

        # update global weights
        if args.aggregation == 'fedavg':
            agg_weights = fedavg(malicious_grads)
        elif args.aggregation == 'krum':
            agg_weights, selected_idxs = krum(malicious_grads, n_attacker, only_weights=only_weights)
            print(f'Krum Selected idx: {selected_idxs}')
        elif args.aggregation == 'mkrum':
            agg_weights, selected_idxs = multi_krum(malicious_grads, n_attacker, only_weights=only_weights)
            print(f'multiKrum Selected idxs: {selected_idxs}')
        elif args.aggregation == 'coomed':
            agg_weights = coomed(malicious_grads)
            print(f'\ndiff {torch.norm((agg_weights - flattened_local_weights[0])) ** 2}')
        elif args.aggregation == 'bulyan':
            agg_weights, selected_idxs = bulyan(malicious_grads, n_attacker)
            print(f'Bulyan Selected idx: {selected_idxs}')
        elif args.aggregation == 'trmean':
            agg_weights = tr_mean(malicious_grads, n_attacker)
        elif args.aggregation == 'fltrust':
            # the pre-round global weights serve as the reference update
            glob_weights = []
            glob_weights.append(flatten(w_prev))
            agg_weights = fltrust(malicious_grads, glob_weights)

        elif args.aggregation == 'flare':
            # Fourth-from-last state-dict key; for the LeNet5 printed by this
            # notebook that is 'fc2.weight' -- presumably the penultimate
            # layer used as the PLR, TODO confirm for other models.
            second_last_layer = list(local_weights[0].keys())[-4]
            structured_local_weights = [construct_ordered_dict(global_model, flat_weights) for flat_weights in malicious_grads]
            plrs = [(each_local[second_last_layer]) for each_local in structured_local_weights]
            agg_weights, count_dict = flare(malicious_grads, plrs)
            print(f'flare count_dict: {count_dict}')

        elif args.aggregation == 'fedcc':
            # Same layer choice as in the 'flare' branch above.
            second_last_layer = list(local_weights[0].keys())[-4]
            glob_plr = global_weights[second_last_layer]
            structured_local_weights = [construct_ordered_dict(global_model, flat_weights) for flat_weights in malicious_grads]
            plrs = [(each_local[second_last_layer]) for each_local in structured_local_weights]
            agg_weights, selected_idxs = fed_cc(malicious_grads, glob_plr, plrs, 'kernel')
            print(f'fed_cc Selected idx: {selected_idxs}')
        else:
            raise ValueError('Unknown aggregation strategy: {}'.format(args.aggregation))

        # reshape the flattened global weights into the ordereddict
        global_weights = construct_ordered_dict(global_model, agg_weights)

        # update global weights
        global_model.load_state_dict(global_weights)

        loss_avg = sum(local_losses) / len(local_losses)
        train_loss.append(loss_avg)

        # print global training loss after every 'i' rounds
        if (epoch+1) % print_every == 0:
            print(f' \nAvg Training Stats after {epoch+1} global rounds:')
            print(f'Training Loss : {np.mean(np.array(train_loss))}')

            test_acc, test_loss = test_inference(args, global_model, test_dataset)
            print('\nGlobal model Benign Test Accuracy: {:.2f}% '.format(100*test_acc))

            if len(args.mal_clients) > 0:
                mal_acc, mal_loss, mal_out = mal_inference(args, global_model, test_dataset, mal_X_list, mal_Y)
                print('Global model Malicious Accuracy: {:.2f}%, Malicious Loss: {:.2f}, confidence: {}\n'.format(100*mal_acc, mal_loss, mal_out))
                confidence.append(mal_out[0].item())

    # Test inference after completion of training
    test_acc, test_loss = test_inference(args, global_model, test_dataset)

    print(f' \n Results after {args.epochs} global rounds of training:')
    print("|---- Test Benign Accuracy: {:.2f}%".format(100*test_acc))

    if len(args.mal_clients) > 0:
        mal_acc, mal_loss, mal_out = mal_inference(args, global_model, test_dataset, mal_X_list, mal_Y)
        print("|---- Test Malicious Accuracy: {:.2f}%, Malicious Loss: {:.2f}, confidence:{}\n".format(100*mal_acc, mal_loss, mal_out[0]))

    print('\n Total Run Time: {0:0.4f}'.format(time.time()-start_time))



Experimental details:
    Model     : cnn
    Optimizer : adam
    Learning  : 0.001
    Aggregation     : mkrum
    Global Rounds   : 20

    Federated parameters:
    Non-IID              : 0.2
    Fraction of users    : 1
    Local Batch size     : 100
    Local Epochs         : 3

Files already downloaded and verified
Files already downloaded and verified
LeNet5(
  (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=44944, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=100, bias=True)
)


  0%|          | 0/20 [00:00<?, ?it/s]

| Global Training Round : 1 |


  return F.conv2d(input, weight, bias, self.stride,


user 0, loss 2.6454838429782406, acc 7.093425605536333
user 1, loss 2.8241361840566, acc 0.0
user 2, loss 2.7797121123888004, acc 1.3301088270858523
user 3, loss 2.4454901073037125, acc 11.332007952286283
user 4, loss 2.654598668564198, acc 0.0
user 5, loss 2.80490935643514, acc 0.0
user 6, loss 3.1030106518003677, acc 0.0
user 7, loss 2.287781095504761, acc 0.0
user 8, loss 2.955603975875705, acc 15.80188679245283
user 9, loss 2.9837173437460875, acc 0.1545595054095827
multiKrum Selected idxs: [6 2 8 9 0 3 1 5 4 7]
 
Avg Training Stats after 1 global rounds:
Training Loss : 2.7484443338653612


  5%|▌         | 1/20 [03:22<1:04:00, 202.13s/it]


Global model Benign Test Accuracy: 3.28% 
| Global Training Round : 2 |
user 0, loss 2.4286348388550127, acc 4.498269896193772
user 1, loss 2.5768976815541587, acc 0.0
user 2, loss 2.6502741533725414, acc 0.24183796856106407
user 3, loss 2.2448510765060177, acc 11.928429423459244
user 4, loss 2.501607908759006, acc 0.0
user 5, loss 2.6162464181582132, acc 0.0
user 6, loss 2.799733050664266, acc 0.0
user 7, loss 2.082857126281375, acc 0.0
user 8, loss 2.7525067282658, acc 16.037735849056602
user 9, loss 2.784083386262258, acc 2.627511591962906
multiKrum Selected idxs: [6 2 8 0 9 3 1 5 4 7]
 
Avg Training Stats after 2 global rounds:
Training Loss : 2.646106785366613


 10%|█         | 2/20 [06:42<1:00:14, 200.82s/it]


Global model Benign Test Accuracy: 6.70% 
| Global Training Round : 3 |
user 0, loss 2.2895939705219672, acc 9.515570934256056
user 1, loss 2.431288064320882, acc 0.0
user 2, loss 2.5692825507168746, acc 0.24183796856106407
user 3, loss 2.129957584830804, acc 13.320079522862823
user 4, loss 2.4308186390603237, acc 0.0
user 5, loss 2.494080381923252, acc 0.0
user 6, loss 2.6269280142254297, acc 0.0
user 7, loss 2.003149074599856, acc 0.0
user 8, loss 2.5689720207569646, acc 11.084905660377359
user 9, loss 2.63196118367024, acc 1.2364760432766615
multiKrum Selected idxs: [6 2 8 0 9 3 1 5 4 7]
 
Avg Training Stats after 3 global rounds:
Training Loss : 2.5699389063986287


 15%|█▌        | 3/20 [10:01<56:44, 200.26s/it]  


Global model Benign Test Accuracy: 8.54% 
| Global Training Round : 4 |
user 0, loss 2.17378135860389, acc 10.899653979238755
user 1, loss 2.263401951789856, acc 0.0
user 2, loss 2.4736968373777852, acc 0.0
user 3, loss 2.0318044065459957, acc 12.127236580516898
user 4, loss 2.3004258085590923, acc 0.0
user 5, loss 2.358251106739044, acc 0.0
user 6, loss 2.4621382143762376, acc 0.0
user 7, loss 1.872765496798924, acc 0.0
user 8, loss 2.4837183905582805, acc 21.22641509433962
user 9, loss 2.524475759420639, acc 0.0
multiKrum Selected idxs: [6 2 8 0 9 3 1 5 4 7]
 
Avg Training Stats after 4 global rounds:
Training Loss : 2.501065663068215


 20%|██        | 4/20 [13:27<54:01, 202.57s/it]


Global model Benign Test Accuracy: 10.10% 
| Global Training Round : 5 |
user 0, loss 2.0534244361498675, acc 4.1522491349480966
user 1, loss 2.172685128847758, acc 0.0
user 2, loss 2.3890289590133365, acc 2.418379685610641
user 3, loss 1.911665696438735, acc 10.536779324055665
user 4, loss 2.181071777676427, acc 0.0
user 5, loss 2.2704268018404643, acc 0.0
user 6, loss 2.3691081788804795, acc 0.0
user 7, loss 1.764766460373288, acc 0.0
user 8, loss 2.3874406908072676, acc 10.141509433962264
user 9, loss 2.4040499130884805, acc 0.7727975270479135
multiKrum Selected idxs: [6 2 8 0 9 3 1 5 4 7]
 
Avg Training Stats after 5 global rounds:
Training Loss : 2.4389258913168943


 25%|██▌       | 5/20 [16:39<49:40, 198.68s/it]


Global model Benign Test Accuracy: 11.82% 
| Global Training Round : 6 |
user 0, loss 1.9331106733768546, acc 8.304498269896193
user 1, loss 1.988613551457723, acc 0.0
user 2, loss 2.2792116118900814, acc 1.3301088270858523
user 3, loss 1.7714294931752894, acc 11.133200795228628
user 4, loss 2.031436973763991, acc 0.0
user 5, loss 2.0844194730122885, acc 0.0
user 6, loss 2.262009114689297, acc 0.0
user 7, loss 1.6548420883360364, acc 0.0
user 8, loss 2.2104572478462665, acc 18.160377358490564
user 9, loss 2.2882583836714425, acc 0.3091190108191654
multiKrum Selected idxs: [6 2 8 0 9 3 1 5 4 7]
 
Avg Training Stats after 6 global rounds:
Training Loss : 2.3741680529510663


 30%|███       | 6/20 [19:39<44:51, 192.24s/it]


Global model Benign Test Accuracy: 13.56% 
| Global Training Round : 7 |
user 0, loss 1.7926878497955645, acc 7.7854671280276815
user 1, loss 1.8201861492792766, acc 0.0
user 2, loss 2.149749533453984, acc 0.8464328899637243
user 3, loss 1.6636722892280515, acc 12.127236580516898
user 4, loss 1.9106430194174593, acc 0.0
user 5, loss 1.9889627668592667, acc 0.5420054200542005
user 6, loss 2.008071494102478, acc 0.0
user 7, loss 1.597279777980986, acc 0.0
user 8, loss 2.0233379099883284, acc 5.89622641509434
user 9, loss 2.1125359191344333, acc 0.6182380216383307
multiKrum Selected idxs: [6 2 8 0 9 3 1 5 4 7]
 
Avg Training Stats after 7 global rounds:
Training Loss : 2.3073887126614827


 35%|███▌      | 7/20 [22:41<40:57, 189.03s/it]


Global model Benign Test Accuracy: 15.06% 
| Global Training Round : 8 |
user 0, loss 1.6775741780057867, acc 9.16955017301038
user 1, loss 1.6552832277615863, acc 0.0
user 2, loss 2.030680894851684, acc 2.418379685610641
user 3, loss 1.5298382305517428, acc 10.337972166998012
user 4, loss 1.767586160075757, acc 0.0
user 5, loss 1.8221624930699667, acc 0.5420054200542005
user 6, loss 1.8122873862584432, acc 1.0723860589812333
user 7, loss 1.4148071203913009, acc 0.0
user 8, loss 1.9112083491157084, acc 7.547169811320755
user 9, loss 2.0073463702813172, acc 0.46367851622874806
multiKrum Selected idxs: [6 2 8 0 9 3 1 5 4 7]
 
Avg Training Stats after 8 global rounds:
Training Loss : 2.239324803708339


 40%|████      | 8/20 [25:43<37:22, 186.89s/it]


Global model Benign Test Accuracy: 15.89% 
| Global Training Round : 9 |
user 0, loss 1.5293947158130348, acc 8.131487889273355
user 1, loss 1.4851716454823813, acc 0.0
user 2, loss 1.910918689485806, acc 2.902055622732769
user 3, loss 1.4656928016887447, acc 8.747514910536779
user 4, loss 1.6434218938960583, acc 0.1869158878504673
user 5, loss 1.690201152033276, acc 0.27100271002710025
user 6, loss 1.6946443458398182, acc 0.2680965147453083
user 7, loss 1.3007172947838193, acc 0.0
user 8, loss 1.794833209000382, acc 16.50943396226415
user 9, loss 1.8293952032541616, acc 0.9273570324574961
multiKrum Selected idxs: [6 2 8 0 9 3 1 5 4 7]
 
Avg Training Stats after 9 global rounds:
Training Loss : 2.1721152805327177


 45%|████▌     | 9/20 [28:50<34:14, 186.73s/it]


Global model Benign Test Accuracy: 15.59% 
| Global Training Round : 10 |
user 0, loss 1.4822431785840513, acc 8.996539792387544
user 1, loss 1.4167448004086811, acc 0.0
user 2, loss 1.8382747108070412, acc 0.24183796856106407
user 3, loss 1.330954663152617, acc 11.928429423459244
user 4, loss 1.5557287298431692, acc 0.0
user 5, loss 1.5789227975739373, acc 0.0
user 6, loss 1.614554152223799, acc 0.8042895442359249
user 7, loss 1.2261681590761457, acc 0.0
user 8, loss 1.6328967272066606, acc 9.19811320754717
user 9, loss 1.7457181983269177, acc 5.7187017001545595
multiKrum Selected idxs: [6 2 8 0 9 3 1 5 4 7]
 
Avg Training Stats after 10 global rounds:
Training Loss : 2.1091258136514766


 50%|█████     | 10/20 [31:50<30:48, 184.81s/it]


Global model Benign Test Accuracy: 16.48% 
| Global Training Round : 11 |
user 0, loss 1.3943034900841138, acc 7.26643598615917
user 1, loss 1.3328138701121013, acc 0.0
user 2, loss 1.7218254866884715, acc 0.24183796856106407
user 3, loss 1.2317355488373982, acc 10.337972166998012
user 4, loss 1.4464314510655958, acc 0.0
user 5, loss 1.4652666012446085, acc 1.084010840108401
user 6, loss 1.4874672340022193, acc 1.0723860589812333
user 7, loss 1.171447058234896, acc 0.0
user 8, loss 1.5405147075653076, acc 6.367924528301887
user 9, loss 1.6663140902916591, acc 0.7727975270479135
multiKrum Selected idxs: [6 2 8 0 9 3 1 5 4 7]
 
Avg Training Stats after 11 global rounds:
Training Loss : 2.0488245536661274


 55%|█████▌    | 11/20 [34:51<27:31, 183.48s/it]


Global model Benign Test Accuracy: 16.81% 
| Global Training Round : 12 |
user 0, loss 1.2933471342350573, acc 8.650519031141869
user 1, loss 1.2292601462205253, acc 0.0
user 2, loss 1.6638749075766226, acc 1.4510278113663846
user 3, loss 1.1881630551524278, acc 8.34990059642147
user 4, loss 1.3980302935422853, acc 0.1869158878504673
user 5, loss 1.3944020436869726, acc 2.168021680216802
user 6, loss 1.4161538130707212, acc 2.4128686327077746
user 7, loss 1.0977824775945573, acc 0.0
user 8, loss 1.4870813518178227, acc 5.89622641509434
user 9, loss 1.5914152073554504, acc 1.2364760432766615
multiKrum Selected idxs: [6 2 0 8 9 1 3 5 4 7]
 
Avg Training Stats after 12 global rounds:
Training Loss : 1.9927517611127203


 60%|██████    | 12/20 [37:55<24:30, 183.81s/it]


Global model Benign Test Accuracy: 16.52% 
| Global Training Round : 13 |
user 0, loss 1.2520190111711516, acc 4.844290657439446
user 1, loss 1.148510154883067, acc 0.0
user 2, loss 1.600243238963891, acc 0.9673518742442563
user 3, loss 1.1327580540645414, acc 13.12127236580517
user 4, loss 1.3207472779954126, acc 0.0
user 5, loss 1.3349793036778765, acc 1.3550135501355014
user 6, loss 1.3345040215386283, acc 0.2680965147453083
user 7, loss 1.0459097700459616, acc 0.0
user 8, loss 1.4304548653901792, acc 7.0754716981132075
user 9, loss 1.5233307228638575, acc 1.3910355486862442
multiKrum Selected idxs: [6 2 0 8 9 1 3 5 4 7]
 
Avg Training Stats after 13 global rounds:
Training Loss : 1.9404128288778537


 65%|██████▌   | 13/20 [40:53<21:13, 181.93s/it]


Global model Benign Test Accuracy: 16.66% 
| Global Training Round : 14 |
user 0, loss 1.238583375376167, acc 8.650519031141869
user 1, loss 1.1366921472549436, acc 0.0
user 2, loss 1.5199610762928255, acc 0.12091898428053204
user 3, loss 1.0985616472193864, acc 11.530815109343937
user 4, loss 1.2476578535035598, acc 0.0
user 5, loss 1.2781411565012404, acc 0.0
user 6, loss 1.2932221906052694, acc 1.0723860589812333
user 7, loss 0.982609049195335, acc 0.0
user 8, loss 1.3186394692051644, acc 9.433962264150944
user 9, loss 1.4457776798651771, acc 1.2364760432766615
multiKrum Selected idxs: [6 2 0 8 9 1 3 5 4 7]
 
Avg Training Stats after 14 global rounds:
Training Loss : 1.8915250957081433


 70%|███████   | 14/20 [43:51<18:03, 180.64s/it]


Global model Benign Test Accuracy: 17.08% 
| Global Training Round : 15 |
user 0, loss 1.156867948406977, acc 8.304498269896193
user 1, loss 1.073787193695704, acc 0.0
user 2, loss 1.4838322949646718, acc 0.48367593712212814
user 3, loss 1.0344842485780639, acc 10.337972166998012
user 4, loss 1.2126162906949836, acc 1.3084112149532712
user 5, loss 1.2265687505404155, acc 1.3550135501355014
user 6, loss 1.1987843811511993, acc 4.021447721179625
user 7, loss 0.9953395162309918, acc 0.0
user 8, loss 1.2651406652202792, acc 17.452830188679243
user 9, loss 1.4023427267869313, acc 1.2364760432766615
multiKrum Selected idxs: [6 2 0 8 9 1 3 5 4 7]
 
Avg Training Stats after 15 global rounds:
Training Loss : 1.8457551827694019


 75%|███████▌  | 15/20 [46:53<15:06, 181.28s/it]


Global model Benign Test Accuracy: 16.37% 
| Global Training Round : 16 |
user 0, loss 1.1309033317346098, acc 7.439446366782007
user 1, loss 1.0875341109434764, acc 0.0
user 2, loss 1.4090888847166034, acc 1.2091898428053205
user 3, loss 1.0290186405181885, acc 10.139165009940358
user 4, loss 1.1732082542522935, acc 0.0
user 5, loss 1.1659434378147124, acc 0.8130081300813009
user 6, loss 1.2123156633641985, acc 0.8042895442359249
user 7, loss 0.9102335374979744, acc 0.0
user 8, loss 1.2847023761155558, acc 12.028301886792454
user 9, loss 1.3960733650586545, acc 1.0819165378670788
multiKrum Selected idxs: [6 2 0 8 1 9 3 5 4 7]
 
Avg Training Stats after 16 global rounds:
Training Loss : 1.8041393688589156


 80%|████████  | 16/20 [49:54<12:04, 181.21s/it]


Global model Benign Test Accuracy: 16.54% 
| Global Training Round : 17 |
user 0, loss 1.1040363616131723, acc 10.899653979238755
user 1, loss 0.9942176734407742, acc 0.0
user 2, loss 1.3549756897029592, acc 0.24183796856106407
user 3, loss 0.963502768578568, acc 8.151093439363818
user 4, loss 1.1452245134715884, acc 0.7476635514018692
user 5, loss 1.0928657299942441, acc 0.8130081300813009
user 6, loss 1.1932536640101008, acc 4.289544235924933
user 7, loss 0.8979203930922917, acc 0.0
user 8, loss 1.1673795110454745, acc 9.669811320754718
user 9, loss 1.3081670271662567, acc 2.472952086553323
multiKrum Selected idxs: [6 2 0 8 1 9 3 5 4 7]
 
Avg Training Stats after 17 global rounds:
Training Loss : 1.764022602056129


 85%|████████▌ | 17/20 [53:00<09:07, 182.36s/it]


Global model Benign Test Accuracy: 16.63% 
| Global Training Round : 18 |
user 0, loss 1.0111906839177964, acc 8.131487889273355
user 1, loss 0.9605444240570069, acc 0.0
user 2, loss 1.3205707811597567, acc 2.176541717049577
user 3, loss 0.9553198785316654, acc 10.139165009940358
user 4, loss 1.1070910073065943, acc 0.3738317757009346
user 5, loss 1.115899352894889, acc 0.8130081300813009
user 6, loss 1.1148349901040395, acc 1.876675603217158
user 7, loss 0.8734757947070259, acc 0.0
user 8, loss 1.1303747065511405, acc 10.141509433962264
user 9, loss 1.2515067276664267, acc 0.3091190108191654
multiKrum Selected idxs: [6 2 0 8 1 9 3 5 4 7]
 
Avg Training Stats after 18 global rounds:
Training Loss : 1.726248059424657


 90%|█████████ | 18/20 [56:03<06:05, 182.68s/it]


Global model Benign Test Accuracy: 16.67% 
| Global Training Round : 19 |
user 0, loss 1.0117030291692586, acc 9.688581314878892
user 1, loss 0.97279882868131, acc 0.0
user 2, loss 1.2569524599841577, acc 0.48367593712212814
user 3, loss 0.945445182725666, acc 7.75347912524851
user 4, loss 1.0522669985774875, acc 0.3738317757009346
user 5, loss 1.091092106865512, acc 2.710027100271003
user 6, loss 1.10467133704159, acc 2.1447721179624666
user 7, loss 0.8636173293704078, acc 0.0
user 8, loss 1.1140717148196464, acc 10.377358490566039
user 9, loss 1.2303908879940326, acc 2.1638330757341575
multiKrum Selected idxs: [6 2 0 8 1 9 3 5 4 7]
 
Avg Training Stats after 19 global rounds:
Training Loss : 1.6914087398508806


 95%|█████████▌| 19/20 [59:07<03:03, 183.09s/it]


Global model Benign Test Accuracy: 16.76% 
| Global Training Round : 20 |
user 0, loss 0.985046924851465, acc 6.5743944636678195
user 1, loss 0.9330371860663096, acc 0.0
user 2, loss 1.23987272085242, acc 0.48367593712212814
user 3, loss 0.9167926273937148, acc 9.542743538767395
user 4, loss 1.019479795943859, acc 0.1869158878504673
user 5, loss 1.0385848263899484, acc 0.27100271002710025
user 6, loss 1.040245761639542, acc 4.289544235924933
user 7, loss 0.8533777397303354, acc 0.0
user 8, loss 1.0910334622158724, acc 12.028301886792454
user 9, loss 1.1936051029807482, acc 3.7094281298299845
multiKrum Selected idxs: [6 2 0 8 1 9 3 5 4 7]
 
Avg Training Stats after 20 global rounds:
Training Loss : 1.6583936835986577


100%|██████████| 20/20 [1:01:26<00:00, 184.34s/it]


Global model Benign Test Accuracy: 16.44% 





 
 Results after 20 global rounds of training:
|---- Test Benign Accuracy: 16.44%

 Total Run Time: 3699.6150


In [5]:
from sklearn.metrics.pairwise import cosine_similarity, euclidean_distances


In [6]:
# Print the cosine similarity between the reference vector(s) in glob_weights
# and each client's flattened update, one client per line. Iterating over the
# rows of malicious_grads avoids hard-coding the number of clients.
# NOTE(review): glob_weights is only assigned in the 'fltrust' aggregation
# branch of the training cell -- confirm it exists before running this cell.
for client_update in malicious_grads:
    print(cosine_similarity(glob_weights, client_update.detach().cpu().reshape(1, -1)))

[[0.9892056]]
[[0.98841816]]
[[0.98894024]]
[[0.9891928]]
[[0.9885451]]
[[0.9889499]]
[[0.9879457]]
[[0.9889792]]
[[0.98861]]
[[0.9883492]]


In [None]:
# Flatten every 'weight' entry of the global state dict into one long vector
# and compute the singular value of that vector viewed as a 1 x N matrix.
weight_chunks = [
    global_weights[key].detach().cpu().numpy().reshape(-1)
    for key in global_weights
    if 'weight' in key
]
# The leading empty float64 array keeps the float64 result that repeated
# np.append onto an empty list produced, without re-copying on every key.
only_weights = np.concatenate([np.empty(0)] + weight_chunks)
# compute_uv=False returns only the singular values; the default
# full_matrices=True would also build an N x N right-singular matrix.
glob_svd = np.linalg.svd(only_weights.reshape(1, -1), compute_uv=False)
glob_svd

In [None]:
from sklearn.metrics.pairwise import cosine_similarity, euclidean_distances

# Euclidean distance between the global PLR and every client PLR, each
# flattened to a single row; a NaN distance is recorded as 0.
glob_plr = glob_plr.detach().cpu()
global_row = glob_plr.reshape(1, -1)

print('euclidean')
sim_vals = []
for client_plr in plrs:
    client_row = client_plr.detach().cpu().reshape(1, -1)
    dist = euclidean_distances(global_row, client_row)[0][0]
    sim_vals.append(0 if np.isnan(dist) else dist)
print(sim_vals)
    

In [None]:
from sklearn.metrics.pairwise import cosine_similarity, euclidean_distances, paired_cosine_distances

# Cosine similarity between the global PLR and every client PLR, each
# flattened to a single row; a NaN similarity is recorded as 0.
glob_plr = glob_plr.detach().cpu()
global_row = glob_plr.reshape(1, -1)

print('cosine')
sim_vals = []
for client_plr in plrs:
    client_row = client_plr.detach().cpu().reshape(1, -1)
    similarity = cosine_similarity(global_row, client_row)[0][0]
    sim_vals.append(0 if np.isnan(similarity) else similarity)
print(sim_vals)
    

In [None]:
# Show the parameter names (keys) of the current global weights dictionary.
global_weights.keys()

In [None]:
# Rebuild one state-dict-like mapping per client from its flattened update.
structured_local_weights = [construct_ordered_dict(global_model, flat_weights) for flat_weights in malicious_grads]


def _stack_layer(layer_key):
    """Return a 2-D array with one flattened `layer_key` tensor per client row."""
    return np.array([np.array(w[layer_key].detach().cpu().reshape(-1)) for w in structured_local_weights])


# NOTE(review): the LeNet5 printed in this notebook's output has conv1, conv2,
# fc1, fc2 and fc3 only, so 'conv3.weight' will raise KeyError for that model;
# pick the layer names that match the model actually trained.
conv1_weights = _stack_layer('conv1.weight')
conv2_weights = _stack_layer('conv2.weight')
conv3_weights = _stack_layer('conv3.weight')
fc1_weights = _stack_layer('fc1.weight')
fc2_weights = _stack_layer('fc2.weight')

num_layers = 5

# Only the singular values are used below. compute_uv=False skips the U / Vh
# factors; with the default full_matrices=True, Vh alone would be
# (num_params x num_params) per layer.
s_layer_0 = np.linalg.svd(conv1_weights, compute_uv=False)
s_layer_1 = np.linalg.svd(conv2_weights, compute_uv=False)
s_layer_2 = np.linalg.svd(conv3_weights, compute_uv=False)
s_layer_3 = np.linalg.svd(fc1_weights, compute_uv=False)
s_layer_4 = np.linalg.svd(fc2_weights, compute_uv=False)


In [None]:
%matplotlib inline

import matplotlib
import matplotlib.pyplot as plt

In [None]:

# One column of points per layer (x = layer index, y = singular values),
# each point labelled with its position in that layer's spectrum.
layer_spectra = [s_layer_0, s_layer_1, s_layer_2, s_layer_3, s_layer_4]
for x, y in zip(np.arange(num_layers), layer_spectra):
    # Size the x-coordinates from this layer's own spectrum: layers need not
    # all have the same number of singular values, and scatter() requires
    # x and y of equal length. `cmap` is omitted because matplotlib ignores
    # it (with a warning) when no `c` values are given.
    plt.scatter([x] * len(y), y)
    for i, value in enumerate(y):
        plt.annotate(str(i), (x, value))

plt.xticks(np.arange(num_layers))
plt.show()

In [None]:
from scipy.cluster.hierarchy import dendrogram, linkage

# Single-linkage hierarchical clustering of the rows of `conv1_weights`
# under correlation distance, drawn as a dendrogram.
linkage_data = linkage(conv1_weights, metric='correlation', method='single')
dendrogram(linkage_data)
plt.title('Targeted-NIID: layer1')

# plt.show()

In [None]:

# Single-linkage hierarchical clustering of the rows of `conv2_weights`
# under correlation distance, drawn as a dendrogram.
linkage_data = linkage(conv2_weights, metric='correlation', method='single')
dendrogram(linkage_data)
plt.title('Targeted-NIID: layer2')

# plt.show()

In [None]:

# Single-linkage hierarchical clustering of the rows of `conv3_weights`
# under correlation distance, drawn as a dendrogram.
linkage_data = linkage(conv3_weights, metric='correlation', method='single')
dendrogram(linkage_data)
plt.title('Targeted-NIID: layer3')

# plt.show()

In [None]:

# Single-linkage hierarchical clustering of the rows of `fc1_weights`
# under correlation distance, drawn as a dendrogram.
linkage_data = linkage(fc1_weights, metric='correlation', method='single')
dendrogram(linkage_data)
plt.title('Targeted-NIID: PLR')

# plt.show()

In [None]:

# Single-linkage hierarchical clustering of the rows of `fc2_weights`
# under correlation distance, drawn as a dendrogram.
linkage_data = linkage(fc2_weights, metric='correlation', method='single')
dendrogram(linkage_data)
plt.title('Targeted-NIID: layer5')

# plt.show()

In [None]:
# structured_local_weights = [construct_ordered_dict(global_model, flat_weights) for flat_weights in malicious_grads]

# conv1_weights = [np.array(w['conv1.weight'].detach().cpu().reshape(-1)) for w in structured_local_weights]
# conv1_weights = np.array(conv1_weights)
# conv2_weights = [np.array(w['conv2.weight'].detach().cpu().reshape(-1)) for w in structured_local_weights]
# conv2_weights = np.array(conv2_weights)
# fc1_weights = [np.array(w['fc1.weight'].detach().cpu().reshape(-1)) for w in structured_local_weights]
# fc1_weights = np.array(fc1_weights)
# fc2_weights = [np.array(w['fc2.weight'].detach().cpu().reshape(-1)) for w in structured_local_weights]
# fc2_weights = np.array(fc2_weights)
# fc3_weights = [np.array(w['fc3.weight'].detach().cpu().reshape(-1)) for w in structured_local_weights]
# fc3_weights = np.array(fc3_weights)

# num_layers = 5

# _, s_layer_0, _, = np.linalg.svd(conv1_weights)
# _, s_layer_1, _, = np.linalg.svd(conv2_weights)
# _, s_layer_2, _, = np.linalg.svd(fc1_weights)
# _, s_layer_3, _, = np.linalg.svd(fc2_weights)
# _, s_layer_4, _, = np.linalg.svd(fc3_weights)


In [None]:
# from matplotlib import pyplot as plt

# for x, y in zip(np.arange(num_layers-1), [s_layer_0, s_layer_1, s_layer_3, s_layer_4]):
#     plt.scatter([x]*len(s_layer_0), y, cmap="copper")
#     for i, txt in enumerate(np.arange(len(s_layer_0))):
#         plt.annotate(txt, (x, y[i]))
    
# plt.xticks(np.arange(num_layers))
# plt.show()

In [None]:
# if len(args.mal_users) > 0:
#     import csv
#     if args.iid:
#         filename = 'Confidence_2IID_'+args.dataset+'_'+args.aggregation+'.csv'
#     else:
#         filename = 'Confidence_2NIID_'+args.dataset+'_'+args.aggregation+'.csv'

#     with open(filename, mode='w', newline='') as csvfile:
#         w = csv.writer(csvfile)
#         w.writerows(map(lambda x: [x], confidence))



In [None]:
from utils import get_mal_dataset_of_class

# Draw `args.num_mal * 100` samples from the test set via
# get_mal_dataset_of_class, then run mal_inference on them with the
# current global model.
test_mal_X_list, test_mal_Y, test_Y_true = get_mal_dataset_of_class(
    test_dataset,
    args.num_mal * 100,
    Y_true,
    mal_Y,
)
# Concatenate the per-entry label arrays into a single flat array.
flat_test_mal_Y = np.hstack(list(test_mal_Y))
mal_acc, mal_loss, mal_out = mal_inference(
    args,
    global_model,
    test_dataset,
    test_mal_X_list,
    flat_test_mal_Y,
)



In [None]:
# Display the values returned by mal_inference: accuracy, loss, and the mean
# of `mal_out` over all of its elements.
mal_acc, mal_loss, torch.mean(mal_out)

In [None]:
# Display the dataset name and the `boost` setting of the current run.
# NOTE(review): `boost` is presumably the malicious-update boosting factor -- confirm in options.
args.dataset, args.boost

In [None]:
# Hand the run configuration to utils.exp_details
# (presumably prints a summary of the experiment settings -- confirm in utils).
exp_details(args)


In [None]:
##### # PLOTTING (optional)
import matplotlib
# Select the non-interactive Agg backend *before* pyplot is imported, which is
# the order matplotlib expects when this code runs as a plain script.
matplotlib.use('Agg')
import matplotlib.pyplot as plt

# savefig() does not create missing directories, so make sure the target exists.
save_dir = '../save'
os.makedirs(save_dir, exist_ok=True)

# Shared file-name stem encoding the run configuration; only the suffix
# differs between the two figures.
fig_stem = '{}/fed_{}_{}_{}_C[{}]_iid[{}]_E[{}]_B[{}]'.format(
    save_dir, args.dataset, args.model, args.epochs, args.frac,
    args.iid, args.local_ep, args.local_bs)

# Plot Loss curve
plt.figure()
plt.title('Training Loss vs Communication rounds')
plt.plot(range(len(train_loss)), train_loss, color='r')
plt.ylabel('Training loss')
plt.xlabel('Communication Rounds')
plt.savefig(fig_stem + '_loss.png')

# Plot Average Accuracy vs Communication rounds
plt.figure()
plt.title('Average Accuracy vs Communication rounds')
plt.plot(range(len(train_accuracy)), train_accuracy, color='k')
plt.ylabel('Average Accuracy')
plt.xlabel('Communication Rounds')
plt.savefig(fig_stem + '_acc.png')