## Plan

- [x] FL base framework
- [x] implement non-IID sampling function following a Dirichlet distribution

#### Attack simulation
- [x] untargeted model poisoning
    - Local model poisoning attacks to Byzantine-robust federated learning <br>
    https://github.com/vrt1shjwlkr/NDSS21-Model-Poisoning fang attack <br>
    krum-attack, trimmed-mean/median attack
- [x] targeted model poisoning
    - Analyzing federated learning through an adversarial lens <br>
    https://github.com/inspire-group/ModelPoisoning
- [x] data poisoning
    - DBA: Distributed backdoor attacks against federated learning <br>
    https://github.com/AI-secure/DBA
    
#### Defense baseline
- [x] FedAvg
- [x] Krum
- [x] Multi-Krum
- [x] Bulyan
- [x] Coordinate median
- [x] FLARE

#### Proposed method
- [x] extract PLRs
- [x] apply RBF hypersphere CKA

In [1]:
from IPython.display import display, HTML
# Widen the notebook's main container to 90% of the browser window.
wide_container_css = HTML("<style>.container { width:90% !important; }</style>")
display(wide_container_css)

In [2]:
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6


import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
import copy
import time
import pickle
import numpy as np
from tqdm import tqdm
from collections import OrderedDict

import torch
from tensorboardX import SummaryWriter

from options import args_parser
from update import LocalUpdate, test_inference, mal_inference
from models import MLP, CNNMnist, CNNFashion_Mnist, CNNCifar, Alexnet, modelC, LeNet5
from utils import get_dataset, get_mal_dataset, exp_details, flatten, construct_ordered_dict
from aggregate import fedavg, multi_krum, krum, coomed, bulyan, tr_mean, fed_align, fed_cc, flare, fltrust
from attacks import get_malicious_updates_untargeted_mkrum, get_malicious_updates_untargeted_med, get_malicious_updates_targeted
from cka import linear_CKA, kernel_CKA
# python src/federated_main.py --model=cnn --dataset=cifar --gpu=0 --iid=1 --epochs=10

import warnings
# warnings.filterwarnings('ignore')

# Targeted Model Poisoning Attack (label flipping)

In [3]:
class Args(object):
    """Static experiment configuration, used in place of the CLI parser.

    All settings are plain class attributes; the training cell instantiates
    this class and reads them as ``args.<name>``.
    """

    # federated parameters (default values are set)
    epochs = 50  # number of global rounds
    num_users = 10
    frac = 1  # fraction of clients
    local_ep = 3  # num of local epoch
    local_bs = 100  # batch size
    lr = 0.001
    momentum = 0.9
    # one of: fedavg, krum, mkrum, coomed, bulyan, trmean, flare, fltrust, fedcc
    aggregation = 'fedcc'

    # model arguments
    model = 'cnn'
    kernel_num = 9  # num of each kind of kernel
    kernel_sizes = '3,4,5'  # comma-separated kernel size to use for convolution
    norm = 'batch_norm'  # batch_norm, layer_norm, None
    num_filters = 32  # num of filters for conv nets -- 32 for mini-imagenet, 64 for omiglot
    # NOTE(review): this is the string 'True', not a bool; left unchanged
    # because consumers outside this cell may compare against the string.
    max_pool = 'True'  # whether use max pooling rather than strided convolutions

    # other arguments
    dataset = 'cifar'  # fmnist, cifar, mnist, cifar100
    num_classes = 100 if dataset == 'cifar100' else 10
    # CIFAR images are RGB; the remaining datasets listed above are grayscale.
    num_channels = 3 if dataset in ('cifar', 'cifar100') else 1  # num of channels of imgs

    gpu = 0
    optimizer = 'adam'
    iid = 0  # 0 for non-iid
    alpha = 0.2  # noniid --> (0, 1] <-- iid
    unequal = 0  # whether to use unequal data splits for non-iid settings (0 for equal splits)
    stopping_rounds = 10  # rounds of early stopping
    verbose = 0
    seed = 1

    # malicious arguments
    mal_clients = [0]  # indices of malicious users
    attack_type = 'targeted'  # targeted
    num_mal = 5  # number of malicious data samples
    mal_bs = 100
    mal_lr = 0.005
    mal_test_bs = 100
    local_mal_ep = 6
    boost = 5  # alpha: 2 for fedavg, 3.5 for krum

In [None]:

if __name__ == '__main__':
    # Federated-learning simulation with model-poisoning clients: each round
    # every client trains a copy of the global model, updates from clients in
    # args.mal_clients are boosted, and the server aggregates the updates with
    # the rule named by args.aggregation. Per-round backdoor confidence is
    # written to a CSV at the end.
    start_time = time.time()

    # define paths
    path_project = os.path.abspath('..')
    logger = SummaryWriter('../logs')

    args = Args()
    exp_details(args)

    device = 'cuda:0' if args.gpu == 0 else 'cpu'
    torch.cuda.empty_cache()

    # load dataset and user groups
    train_dataset, test_dataset, user_groups = get_dataset(args)

    # BUILD MODEL
    if args.model == 'cnn':
        # Convolutional neural network, one architecture per dataset
        cnn_by_dataset = {
            'mnist': CNNMnist,
            'fmnist': CNNFashion_Mnist,
            'cifar': CNNCifar,
            'cifar100': LeNet5,
        }
        if args.dataset not in cnn_by_dataset:
            # Fail here instead of with a NameError on global_model below.
            raise ValueError('Error: unrecognized dataset: {}'.format(args.dataset))
        global_model = cnn_by_dataset[args.dataset](args=args)
    elif args.model == 'mlp':
        # Multi-layer perceptron: input size is the product of the image dims
        img_size = train_dataset[0][0].shape
        len_in = 1
        for x in img_size:
            len_in *= x
        # Built once, after the loop (it used to be rebuilt per dimension).
        global_model = MLP(dim_in=len_in, dim_hidden=64, dim_out=args.num_classes)
    else:
        exit('Error: unrecognized model')

    # Set the model to train and send it to device.
    global_model.to(device)
    global_model.train()
    print(global_model)

    # copy weights
    global_weights = global_model.state_dict()
    if len(args.mal_clients) > 0:
        mal_X_list, mal_Y, Y_true = get_mal_dataset(test_dataset, args.num_mal, args.num_classes)
        print("malcious dataset true labels: {}, malicious labels: {}".format(Y_true, mal_Y))

    # Training
    # NOTE(review): several of these are never read in this cell; they are
    # kept because notebook cells outside this view may rely on them.
    train_loss, train_accuracy = [], []
    val_acc_list, net_list = [], []
    cv_loss, cv_acc = [], []
    print_every = 1
    val_loss_pre, counter = 0, 0

    confidence = []
    for epoch in tqdm(range(args.epochs)):
        local_weights, local_losses = [], []

        print('=========================================')
        print(f'| Global Training Round : {epoch+1} |')
        print('=========================================')

        global_model.train()
        m = max(int(args.frac * args.num_users), 1)
        # NOTE(review): the sampled subset is never used -- every client trains
        # each round. The draw is kept so the NumPy RNG stream is unchanged.
        idxs_users = np.random.choice(range(args.num_users), m, replace=False)

        # flattened weights + bias, one entry per client
        flattened_local_weights = []

        # separate flat arrays for weights and biases
        only_weights = []  # for krum to consider only weights, not the biases
        only_biases = []

        for idx in range(args.num_users):
            mal_user = idx in args.mal_clients

            if mal_user:
                local_model = LocalUpdate(args=args, dataset=train_dataset, idxs=user_groups[idx], logger=logger,
                                          mal=mal_user, mal_X=mal_X_list, mal_Y=mal_Y, test_dataset=test_dataset)
            else:
                local_model = LocalUpdate(args=args, dataset=train_dataset, idxs=user_groups[idx], logger=logger)

            w_prev = global_model.state_dict()
            w, loss = local_model.update_weights(model=copy.deepcopy(global_model), global_round=epoch)

            # construct two arrays with weights-only and bias-only
            if 'krum' in args.aggregation:
                for key in w:
                    flat_param = w[key].detach().cpu().numpy().reshape(-1)
                    if 'weight' in key:
                        # NOTE(review): this scales the raw weights, whereas the
                        # flattened update below scales the delta from w_prev.
                        if mal_user and epoch % 2 == 0:
                            flat_param = args.boost * flat_param
                        only_weights = np.append(only_weights, flat_param)
                    elif 'bias' in key:
                        only_biases = np.append(only_biases, flat_param)

            # evaluate the client's freshly trained weights
            new_model = copy.deepcopy(global_model)
            new_model.load_state_dict(w)
            acc, _ = local_model.inference(model=new_model)

            if mal_user:
                mal_acc, mal_loss = local_model.mal_inference(model=new_model)
                print('user {}, loss {}, acc {}, mal loss {}, mal acc {}'.format(idx, loss, 100*acc, mal_loss, 100*mal_acc))
            else:
                print('user {}, loss {}, acc {}'.format(idx, loss, 100*acc))

            local_weights.append(copy.deepcopy(w))
            local_losses.append(copy.deepcopy(loss))

            # Malicious clients submit w_prev + boost * (w - w_prev). Under
            # plain krum the boost is applied only on even rounds (alternating
            # training); every other rule gets it each round.
            boost_update = mal_user and (args.aggregation != 'krum' or epoch % 2 == 0)
            if boost_update:
                flattened_local_weights.append(flatten(w_prev) + args.boost * (flatten(w) - flatten(w_prev)))
            else:
                flattened_local_weights.append(flatten(w))

        only_weights = torch.tensor(np.array(only_weights)).to(device)

        flattened_local_weights = torch.tensor(np.array(flattened_local_weights)).to(device)
        malicious_grads = flattened_local_weights

        n_attacker = len(args.mal_clients)

        # update global weights
        if args.aggregation == 'fedavg':
            agg_weights = fedavg(malicious_grads)
        elif args.aggregation == 'krum':
            agg_weights, selected_idxs = krum(malicious_grads, n_attacker, only_weights=only_weights)
            print(f'Krum Selected idx: {selected_idxs}')
        elif args.aggregation == 'mkrum':
            agg_weights, selected_idxs = multi_krum(malicious_grads, n_attacker, only_weights=only_weights)
            print(f'multiKrum Selected idxs: {selected_idxs}')
        elif args.aggregation == 'coomed':
            agg_weights = coomed(malicious_grads)
            print(f'\ndiff {torch.norm((agg_weights - flattened_local_weights[0])) ** 2}')
        elif args.aggregation == 'bulyan':
            agg_weights, selected_idxs = bulyan(malicious_grads, n_attacker)
            print(f'Bulyan Selected idx: {selected_idxs}')
        elif args.aggregation == 'trmean':
            agg_weights = tr_mean(malicious_grads, n_attacker)

        elif args.aggregation == 'fltrust':
            glob_weights = []
            glob_weights.append(flatten(global_weights))
            agg_weights = fltrust(malicious_grads, glob_weights)

        elif args.aggregation == 'flare':
            # fourth key from the end of the state dict -- presumably the
            # weight of the penultimate layer; confirm per architecture
            second_last_layer = list(local_weights[0].keys())[-4]
            structured_local_weights = [construct_ordered_dict(global_model, flat_weights) for flat_weights in malicious_grads]
            plrs = [(each_local[second_last_layer]) for each_local in structured_local_weights]
            agg_weights, count_dict = flare(malicious_grads, plrs)
            print(f'flare count_dict: {count_dict}')

        elif args.aggregation == 'fedcc':
            second_last_layer = list(local_weights[0].keys())[-4]
            glob_plr = global_weights[second_last_layer]
            # NOTE(review): fed_cc receives the unboosted local_weights, not
            # malicious_grads -- confirm this is intended.
            agg_weights, selected_idxs = fed_cc(local_weights, glob_plr, 'kernel')
            print(f'fed_cc Selected idx: {selected_idxs}')
        else:
            raise ValueError('Unknown aggregation strategy: {}'.format(args.aggregation))

        # reshape the flattened global weights into the ordereddict
        global_weights = construct_ordered_dict(global_model, agg_weights)

        # update global weights
        global_model.load_state_dict(global_weights)

        loss_avg = sum(local_losses) / len(local_losses)
        train_loss.append(loss_avg)

        # print global training loss after every 'i' rounds
        if (epoch+1) % print_every == 0:
            print(f' \nAvg Training Stats after {epoch+1} global rounds:')
            print(f'Training Loss : {np.mean(np.array(train_loss))}')

            test_acc, test_loss = test_inference(args, global_model, test_dataset)
            print('\nGlobal model Benign Test Accuracy: {:.2f}% '.format(100*test_acc))

            if len(args.mal_clients) > 0:
                mal_acc, mal_loss, mal_out = mal_inference(args, global_model, test_dataset, mal_X_list, mal_Y)
                print('Global model Malicious Accuracy: {:.2f}%, Malicious Loss: {:.2f}, confidence: {}\n'.format(100*mal_acc, mal_loss, mal_out))
                confidence.append(mal_out)

    # Test inference after completion of training
    test_acc, test_loss = test_inference(args, global_model, test_dataset)

    print(f' \n Results after {args.epochs} global rounds of training:')
    print("|---- Test Benign Accuracy: {:.2f}%".format(100*test_acc))

    if len(args.mal_clients) > 0:
        mal_acc, mal_loss, mal_out = mal_inference(args, global_model, test_dataset, mal_X_list, mal_Y)
        print("|---- Test Malicious Accuracy: {:.2f}%, Malicious Loss: {:.2f}, confidence:{}\n".format(100*mal_acc, mal_loss, mal_out))

    print('\n Total Run Time: {0:0.4f}'.format(time.time()-start_time))

    import pandas as pd
    data = {
        "clean_acc": test_acc,
        # mal_acc is only bound when at least one malicious client exists
        "backdoor_acc": mal_acc if len(args.mal_clients) > 0 else float('nan'),
        "confidence": confidence
    }
    csv_file_path = f"output/confidence_{args.alpha}_{args.dataset}_{args.aggregation}.csv"
    # to_csv does not create missing parent directories
    os.makedirs(os.path.dirname(csv_file_path), exist_ok=True)
    df = pd.DataFrame(data)
    df.to_csv(csv_file_path, index=False)


Experimental details:
    Model     : cnn
    Optimizer : adam
    Learning  : 0.001
    Aggregation     : fedcc
    Global Rounds   : 50

    Federated parameters:
    Non-IID              : 0.2
    Fraction of users    : 1
    Local Batch size     : 100
    Local Epochs         : 3

    Malicious parameters:
    Attackers            : [0]
    Attack Type          : targeted
Files already downloaded and verified
Files already downloaded and verified
CNNCifar(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
  (conv3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
  (fc1): Linear(in_features=256, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=10, bias=True)
)
malcious dataset true labels: 9, malicious labels: [7, 7, 7, 7, 7]


  0%|          | 0/50 [00:00<?, ?it/s]

| Global Training Round : 1 |


  return F.conv2d(input, weight, bias, self.stride,


user 0, loss 0.9027448676250599, acc 29.1970802919708, mal loss 2.9416537284851074, mal acc 0.0
user 1, loss 0.8725316736612235, acc 88.62144420131291
user 2, loss 0.9857898461070039, acc 78.57142857142857
user 3, loss 0.25065938551480754, acc 100.0
user 4, loss 0.4986122154681258, acc 21.2406015037594
user 5, loss 1.3994541984420639, acc 62.745098039215684
user 6, loss 0.121869338903232, acc 100.0
user 7, loss 0.4347316640796084, acc 89.94515539305301
user 8, loss 1.4066359549760818, acc 16.853932584269664
user 9, loss 0.9929874895984291, acc 59.896907216494846
fed_cc Selected idx: [1 3 4 5 7 8]
 
Avg Training Stats after 1 global rounds:
Training Loss : 0.7866016634375634


  2%|▏         | 1/50 [00:18<15:30, 18.99s/it]


Global model Benign Test Accuracy: 13.27% 
Global model Malicious Accuracy: 0.00%, Malicious Loss: 3.07, confidence: 0.04662033319473267

| Global Training Round : 2 |
user 0, loss 0.8605722569757037, acc 15.875912408759124, mal loss 3.4329447746276855, mal acc 0.0
user 1, loss 0.7549328693935463, acc 90.15317286652079
user 2, loss 0.9367183501593733, acc 86.18266978922716
user 3, loss 0.19787840884274396, acc 100.0
user 4, loss 0.387427045162334, acc 22.55639097744361
user 5, loss 1.2563471820977357, acc 69.49891067538127
user 6, loss 0.1032634203884543, acc 100.0
user 7, loss 0.3483794136255076, acc 83.72943327239489
user 8, loss 1.1677336767315865, acc 65.1685393258427
user 9, loss 0.9588809746962327, acc 73.4020618556701
fed_cc Selected idx: [1 3 4 5 7 8]
 
Avg Training Stats after 2 global rounds:
Training Loss : 0.7419075116224426


  4%|▍         | 2/50 [00:38<15:22, 19.23s/it]


Global model Benign Test Accuracy: 19.82% 
Global model Malicious Accuracy: 0.00%, Malicious Loss: 3.81, confidence: 0.02302250862121582

| Global Training Round : 3 |
user 0, loss 0.8053478566584763, acc 18.065693430656935, mal loss 3.9732792377471924, mal acc 0.0
user 1, loss 0.6988836416790077, acc 89.71553610503283
user 2, loss 0.8818500090912345, acc 73.18501170960188
user 3, loss 0.1500876922494708, acc 100.0
user 4, loss 0.29000735346430034, acc 21.2406015037594
user 5, loss 1.1754199336240958, acc 90.6318082788671
user 6, loss 0.07057735329390284, acc 100.0
user 7, loss 0.3084726519318241, acc 73.6745886654479
user 8, loss 1.0005620097120602, acc 59.55056179775281
user 9, loss 0.8978177833760905, acc 67.5257731958763
fed_cc Selected idx: [1 3 4 5 7 8]
 
Avg Training Stats after 3 global rounds:
Training Loss : 0.7039058839176437


  6%|▌         | 3/50 [00:55<14:27, 18.45s/it]


Global model Benign Test Accuracy: 19.87% 
Global model Malicious Accuracy: 0.00%, Malicious Loss: 4.50, confidence: 0.012015944719314576

| Global Training Round : 4 |
user 0, loss 0.7696266408871723, acc 16.970802919708028, mal loss 3.888416290283203, mal acc 0.0
user 1, loss 0.6564105169193165, acc 88.40262582056893
user 2, loss 0.8216652173350975, acc 66.39344262295081
user 3, loss 0.11728605778458027, acc 100.0
user 4, loss 0.2609218895550846, acc 22.932330827067666
user 5, loss 1.0964111964981835, acc 88.45315904139433
user 6, loss 0.0552019612827003, acc 100.0
user 7, loss 0.27671137141684693, acc 75.3199268738574
user 8, loss 1.0038829421003659, acc 61.79775280898876
user 9, loss 0.8771288321058974, acc 82.2680412371134
fed_cc Selected idx: [1 3 4 5 7 8]
 
Avg Training Stats after 4 global rounds:
Training Loss : 0.6763105785853639


  8%|▊         | 4/50 [01:13<14:01, 18.28s/it]


Global model Benign Test Accuracy: 25.11% 
Global model Malicious Accuracy: 0.00%, Malicious Loss: 4.85, confidence: 0.008684533834457397

| Global Training Round : 5 |
user 0, loss 0.7223910975235479, acc 20.072992700729927, mal loss 3.5225791931152344, mal acc 0.0
user 1, loss 0.6089117607554874, acc 88.18380743982495
user 2, loss 0.7636014388667212, acc 82.31850117096019
user 3, loss 0.10722213799435383, acc 100.0
user 4, loss 0.24222581694985543, acc 22.932330827067666
user 5, loss 1.0312771421295028, acc 91.06753812636165
user 6, loss 0.03595870502378694, acc 100.0
user 7, loss 0.2647716853428971, acc 81.53564899451554
user 8, loss 0.8967272092898687, acc 52.80898876404494
user 9, loss 0.8483069016892685, acc 80.10309278350516
fed_cc Selected idx: [1 3 4 5 7 8]
 
Avg Training Stats after 5 global rounds:
Training Loss : 0.6514763407795969


 10%|█         | 5/50 [01:34<14:13, 18.98s/it]


Global model Benign Test Accuracy: 27.70% 
Global model Malicious Accuracy: 0.00%, Malicious Loss: 5.32, confidence: 0.006213701143860817

| Global Training Round : 6 |


In [None]:
# 1 cifar trmean 1 fmnist trmean

class Args(object):
    """Static experiment configuration, used in place of the CLI parser.

    All settings are plain class attributes; the training code below
    instantiates this class and reads them as ``args.<name>``.
    """

    # federated parameters (default values are set)
    epochs = 50  # number of global rounds
    num_users = 10
    frac = 1  # fraction of clients
    local_ep = 3  # num of local epoch
    local_bs = 100  # batch size
    lr = 0.001
    momentum = 0.9
    # one of: fedavg, krum, mkrum, coomed, bulyan, trmean, flare, fltrust, fedcc
    aggregation = 'fedcc'

    # model arguments
    model = 'cnn'
    kernel_num = 9  # num of each kind of kernel
    kernel_sizes = '3,4,5'  # comma-separated kernel size to use for convolution
    norm = 'batch_norm'  # batch_norm, layer_norm, None
    num_filters = 32  # num of filters for conv nets -- 32 for mini-imagenet, 64 for omiglot
    # NOTE(review): this is the string 'True', not a bool; left unchanged
    # because consumers outside this cell may compare against the string.
    max_pool = 'True'  # whether use max pooling rather than strided convolutions

    # other arguments
    dataset = 'cifar'  # fmnist, cifar, mnist, cifar100
    num_classes = 100 if dataset == 'cifar100' else 10
    # CIFAR images are RGB; the remaining datasets listed above are grayscale.
    num_channels = 3 if dataset in ('cifar', 'cifar100') else 1  # num of channels of imgs

    gpu = 0
    optimizer = 'adam'
    iid = 0  # 0 for non-iid
    alpha = 0.2  # noniid --> (0, 1] <-- iid
    unequal = 0  # whether to use unequal data splits for non-iid settings (0 for equal splits)
    stopping_rounds = 10  # rounds of early stopping
    verbose = 0
    seed = 1

    # malicious arguments
    mal_clients = [0]  # indices of malicious users
    attack_type = 'targeted'  # targeted
    num_mal = 5  # number of malicious data samples
    mal_bs = 100
    mal_lr = 0.005
    mal_test_bs = 100
    local_mal_ep = 6
    boost = 5  # alpha: 2 for fedavg, 3.5 for krum
    

if __name__ == '__main__':
    # Federated-learning simulation with model-poisoning clients: each round
    # every client trains a copy of the global model, updates from clients in
    # args.mal_clients are boosted, and the server aggregates the updates with
    # the rule named by args.aggregation. Per-round backdoor confidence is
    # written to a CSV at the end.
    start_time = time.time()

    # define paths
    path_project = os.path.abspath('..')
    logger = SummaryWriter('../logs')

    args = Args()
    exp_details(args)

    device = 'cuda:0' if args.gpu == 0 else 'cpu'
    torch.cuda.empty_cache()

    # load dataset and user groups
    train_dataset, test_dataset, user_groups = get_dataset(args)

    # BUILD MODEL
    if args.model == 'cnn':
        # Convolutional neural network, one architecture per dataset
        cnn_by_dataset = {
            'mnist': CNNMnist,
            'fmnist': CNNFashion_Mnist,
            'cifar': CNNCifar,
            'cifar100': LeNet5,
        }
        if args.dataset not in cnn_by_dataset:
            # Fail here instead of with a NameError on global_model below.
            raise ValueError('Error: unrecognized dataset: {}'.format(args.dataset))
        global_model = cnn_by_dataset[args.dataset](args=args)
    elif args.model == 'mlp':
        # Multi-layer perceptron: input size is the product of the image dims
        img_size = train_dataset[0][0].shape
        len_in = 1
        for x in img_size:
            len_in *= x
        # Built once, after the loop (it used to be rebuilt per dimension).
        global_model = MLP(dim_in=len_in, dim_hidden=64, dim_out=args.num_classes)
    else:
        exit('Error: unrecognized model')

    # Set the model to train and send it to device.
    global_model.to(device)
    global_model.train()
    print(global_model)

    # copy weights
    global_weights = global_model.state_dict()
    if len(args.mal_clients) > 0:
        mal_X_list, mal_Y, Y_true = get_mal_dataset(test_dataset, args.num_mal, args.num_classes)
        print("malcious dataset true labels: {}, malicious labels: {}".format(Y_true, mal_Y))

    # Training
    # NOTE(review): several of these are never read in this cell; they are
    # kept because notebook cells outside this view may rely on them.
    train_loss, train_accuracy = [], []
    val_acc_list, net_list = [], []
    cv_loss, cv_acc = [], []
    print_every = 1
    val_loss_pre, counter = 0, 0

    confidence = []
    for epoch in tqdm(range(args.epochs)):
        local_weights, local_losses = [], []

        print('=========================================')
        print(f'| Global Training Round : {epoch+1} |')
        print('=========================================')

        global_model.train()
        m = max(int(args.frac * args.num_users), 1)
        # NOTE(review): the sampled subset is never used -- every client trains
        # each round. The draw is kept so the NumPy RNG stream is unchanged.
        idxs_users = np.random.choice(range(args.num_users), m, replace=False)

        # flattened weights + bias, one entry per client
        flattened_local_weights = []

        # separate flat arrays for weights and biases
        only_weights = []  # for krum to consider only weights, not the biases
        only_biases = []

        for idx in range(args.num_users):
            mal_user = idx in args.mal_clients

            if mal_user:
                local_model = LocalUpdate(args=args, dataset=train_dataset, idxs=user_groups[idx], logger=logger,
                                          mal=mal_user, mal_X=mal_X_list, mal_Y=mal_Y, test_dataset=test_dataset)
            else:
                local_model = LocalUpdate(args=args, dataset=train_dataset, idxs=user_groups[idx], logger=logger)

            w_prev = global_model.state_dict()
            w, loss = local_model.update_weights(model=copy.deepcopy(global_model), global_round=epoch)

            # construct two arrays with weights-only and bias-only
            if 'krum' in args.aggregation:
                for key in w:
                    flat_param = w[key].detach().cpu().numpy().reshape(-1)
                    if 'weight' in key:
                        # NOTE(review): this scales the raw weights, whereas the
                        # flattened update below scales the delta from w_prev.
                        if mal_user and epoch % 2 == 0:
                            flat_param = args.boost * flat_param
                        only_weights = np.append(only_weights, flat_param)
                    elif 'bias' in key:
                        only_biases = np.append(only_biases, flat_param)

            # evaluate the client's freshly trained weights
            new_model = copy.deepcopy(global_model)
            new_model.load_state_dict(w)
            acc, _ = local_model.inference(model=new_model)

            if mal_user:
                mal_acc, mal_loss = local_model.mal_inference(model=new_model)
                print('user {}, loss {}, acc {}, mal loss {}, mal acc {}'.format(idx, loss, 100*acc, mal_loss, 100*mal_acc))
            else:
                print('user {}, loss {}, acc {}'.format(idx, loss, 100*acc))

            local_weights.append(copy.deepcopy(w))
            local_losses.append(copy.deepcopy(loss))

            # Malicious clients submit w_prev + boost * (w - w_prev). Under
            # plain krum the boost is applied only on even rounds (alternating
            # training); every other rule gets it each round.
            boost_update = mal_user and (args.aggregation != 'krum' or epoch % 2 == 0)
            if boost_update:
                flattened_local_weights.append(flatten(w_prev) + args.boost * (flatten(w) - flatten(w_prev)))
            else:
                flattened_local_weights.append(flatten(w))

        only_weights = torch.tensor(np.array(only_weights)).to(device)

        flattened_local_weights = torch.tensor(np.array(flattened_local_weights)).to(device)
        malicious_grads = flattened_local_weights

        n_attacker = len(args.mal_clients)

        # update global weights
        if args.aggregation == 'fedavg':
            agg_weights = fedavg(malicious_grads)
        elif args.aggregation == 'krum':
            agg_weights, selected_idxs = krum(malicious_grads, n_attacker, only_weights=only_weights)
            print(f'Krum Selected idx: {selected_idxs}')
        elif args.aggregation == 'mkrum':
            agg_weights, selected_idxs = multi_krum(malicious_grads, n_attacker, only_weights=only_weights)
            print(f'multiKrum Selected idxs: {selected_idxs}')
        elif args.aggregation == 'coomed':
            agg_weights = coomed(malicious_grads)
            print(f'\ndiff {torch.norm((agg_weights - flattened_local_weights[0])) ** 2}')
        elif args.aggregation == 'bulyan':
            agg_weights, selected_idxs = bulyan(malicious_grads, n_attacker)
            print(f'Bulyan Selected idx: {selected_idxs}')
        elif args.aggregation == 'trmean':
            agg_weights = tr_mean(malicious_grads, n_attacker)

        elif args.aggregation == 'fltrust':
            glob_weights = []
            glob_weights.append(flatten(global_weights))
            agg_weights = fltrust(malicious_grads, glob_weights)

        elif args.aggregation == 'flare':
            # fourth key from the end of the state dict -- presumably the
            # weight of the penultimate layer; confirm per architecture
            second_last_layer = list(local_weights[0].keys())[-4]
            structured_local_weights = [construct_ordered_dict(global_model, flat_weights) for flat_weights in malicious_grads]
            plrs = [(each_local[second_last_layer]) for each_local in structured_local_weights]
            agg_weights, count_dict = flare(malicious_grads, plrs)
            print(f'flare count_dict: {count_dict}')

        elif args.aggregation == 'fedcc':
            second_last_layer = list(local_weights[0].keys())[-4]
            glob_plr = global_weights[second_last_layer]
            # plrs is not passed to fed_cc; it is built here because the
            # analysis cells further down read it (and glob_plr) afterwards.
            structured_local_weights = [construct_ordered_dict(global_model, flat_weights) for flat_weights in malicious_grads]
            plrs = [(each_local[second_last_layer].reshape((glob_plr.shape[0], glob_plr.shape[1]))) for each_local in structured_local_weights]
            # NOTE(review): fed_cc receives the unboosted local_weights, not
            # malicious_grads -- confirm this is intended.
            agg_weights, selected_idxs = fed_cc(local_weights, glob_plr, 'kernel')
            print(f'fed_cc Selected idx: {selected_idxs}')
        else:
            raise ValueError('Unknown aggregation strategy: {}'.format(args.aggregation))

        # reshape the flattened global weights into the ordereddict
        global_weights = construct_ordered_dict(global_model, agg_weights)

        # update global weights
        global_model.load_state_dict(global_weights)

        loss_avg = sum(local_losses) / len(local_losses)
        train_loss.append(loss_avg)

        # print global training loss after every 'i' rounds
        if (epoch+1) % print_every == 0:
            print(f' \nAvg Training Stats after {epoch+1} global rounds:')
            print(f'Training Loss : {np.mean(np.array(train_loss))}')

            test_acc, test_loss = test_inference(args, global_model, test_dataset)
            print('\nGlobal model Benign Test Accuracy: {:.2f}% '.format(100*test_acc))

            if len(args.mal_clients) > 0:
                mal_acc, mal_loss, mal_out = mal_inference(args, global_model, test_dataset, mal_X_list, mal_Y)
                print('Global model Malicious Accuracy: {:.2f}%, Malicious Loss: {:.2f}, confidence: {}\n'.format(100*mal_acc, mal_loss, mal_out))
                confidence.append(mal_out)

    # Test inference after completion of training
    test_acc, test_loss = test_inference(args, global_model, test_dataset)

    print(f' \n Results after {args.epochs} global rounds of training:')
    print("|---- Test Benign Accuracy: {:.2f}%".format(100*test_acc))

    if len(args.mal_clients) > 0:
        mal_acc, mal_loss, mal_out = mal_inference(args, global_model, test_dataset, mal_X_list, mal_Y)
        print("|---- Test Malicious Accuracy: {:.2f}%, Malicious Loss: {:.2f}, confidence:{}\n".format(100*mal_acc, mal_loss, mal_out))

    print('\n Total Run Time: {0:0.4f}'.format(time.time()-start_time))

    import pandas as pd
    data = {
        "clean_acc": test_acc,
        # mal_acc is only bound when at least one malicious client exists
        "backdoor_acc": mal_acc if len(args.mal_clients) > 0 else float('nan'),
        "confidence": confidence
    }
    csv_file_path = f"output/confidence_{args.alpha}_{args.dataset}_{args.aggregation}.csv"
    # to_csv does not create missing parent directories
    os.makedirs(os.path.dirname(csv_file_path), exist_ok=True)
    df = pd.DataFrame(data)
    df.to_csv(csv_file_path, index=False)

In [None]:
# NOTE(review): stray cell -- `k` is not assigned in any visible cell, so
# this presumably raises NameError when run; confirm and delete.
k

In [None]:
# Singular value of the global model's weight tensors (entries whose key
# contains 'weight'), flattened into a single 1 x N row.
weight_chunks = [
    global_weights[key].detach().cpu().numpy().reshape(-1)
    for key in global_weights
    if 'weight' in key
]
# float64 matches what the previous np.append-onto-[] accumulation produced.
only_weights = np.concatenate(weight_chunks).astype(np.float64)
# compute_uv=False returns only the singular values; the default full SVD
# would also build an N x N right-singular matrix for this 1 x N input.
glob_svd = np.linalg.svd(only_weights.reshape(1, -1), compute_uv=False)
glob_svd

In [None]:
from sklearn.metrics.pairwise import cosine_similarity, euclidean_distances

# Euclidean distance between the flattened global PLR and each entry of plrs;
# NaN distances are recorded as 0.
glob_plr = glob_plr.detach().cpu()
glob_row = glob_plr.reshape(1, -1)

print('euclidean')
sim_vals = []
for client_plr in plrs:
    client_row = client_plr.detach().cpu().reshape(1, -1)
    val = euclidean_distances(glob_row, client_row)[0][0]
    sim_vals.append(0 if np.isnan(val) else val)
print(sim_vals)
    

In [None]:
from sklearn.metrics.pairwise import cosine_similarity, euclidean_distances, paired_cosine_distances

# Cosine similarity between the global PLR and each entry of `plrs`.
# Both sides are moved to CPU and flattened to a single row before comparison.
glob_plr = glob_plr.detach().cpu()
global_row = glob_plr.reshape(1, -1)

print('cosine')
sim_vals = []
for client_plr in plrs:
    client_row = client_plr.detach().cpu().reshape(1, -1)
    similarity = cosine_similarity(global_row, client_row)[0][0]
    # A NaN result is recorded as 0 rather than propagated.
    sim_vals.append(0 if np.isnan(similarity) else similarity)
print(sim_vals)
    

In [None]:
# List the keys of `global_weights`; the cells that follow index per-layer
# weights by names such as 'conv1.weight' and 'fc1.weight'.
global_weights.keys()

In [None]:
# Rebuild one keyed weight mapping per flat update in `malicious_grads`, so
# that individual layers can be looked up by name.
structured_local_weights = [
    construct_ordered_dict(global_model, flat_weights)
    for flat_weights in malicious_grads
]


def _stack_layer(layer_key):
    """Return a 2-D array with one flattened `layer_key` tensor per update.

    Row i is the flattened `layer_key` entry of `structured_local_weights[i]`,
    detached and moved to CPU.
    """
    return np.array([
        np.array(weights[layer_key].detach().cpu().reshape(-1))
        for weights in structured_local_weights
    ])


conv1_weights = _stack_layer('conv1.weight')
conv2_weights = _stack_layer('conv2.weight')
conv3_weights = _stack_layer('conv3.weight')
fc1_weights = _stack_layer('fc1.weight')
fc2_weights = _stack_layer('fc2.weight')

num_layers = 5

# Only the singular values are used downstream. compute_uv=False avoids
# building the (num_params x num_params) right-singular-vector matrix that the
# default full SVD would allocate for every layer.
s_layer_0 = np.linalg.svd(conv1_weights, compute_uv=False)
s_layer_1 = np.linalg.svd(conv2_weights, compute_uv=False)
s_layer_2 = np.linalg.svd(conv3_weights, compute_uv=False)
s_layer_3 = np.linalg.svd(fc1_weights, compute_uv=False)
s_layer_4 = np.linalg.svd(fc2_weights, compute_uv=False)


In [None]:
%matplotlib inline

import matplotlib
import matplotlib.pyplot as plt

In [None]:

# Draw one vertical column of points per layer: x is the layer index, y the
# layer's singular values, and each point is labelled with its rank index.
layer_singular_values = [s_layer_0, s_layer_1, s_layer_2, s_layer_3, s_layer_4]
for layer_idx, singular_values in zip(range(num_layers), layer_singular_values):
    # Size the column from this layer's own values instead of s_layer_0, so
    # layers with a different number of singular values still plot. The old
    # cmap argument was dropped: without `c` matplotlib ignores it and warns.
    plt.scatter([layer_idx] * len(singular_values), singular_values)
    for rank, value in enumerate(singular_values):
        plt.annotate(str(rank), (layer_idx, value))

plt.xticks(np.arange(num_layers))
plt.show()

In [None]:
from scipy.cluster.hierarchy import dendrogram, linkage

# Single-linkage hierarchical clustering of the rows of conv1_weights under
# correlation distance, drawn as a dendrogram on the current axes.
conv1_linkage = linkage(conv1_weights, metric='correlation', method='single')
dendrogram(conv1_linkage)
plt.title('Targeted-NIID: layer1')

In [None]:

# Single-linkage hierarchical clustering of the rows of conv2_weights under
# correlation distance, drawn as a dendrogram on the current axes.
conv2_linkage = linkage(conv2_weights, metric='correlation', method='single')
dendrogram(conv2_linkage)
plt.title('Targeted-NIID: layer2')

In [None]:

# Single-linkage hierarchical clustering of the rows of conv3_weights under
# correlation distance, drawn as a dendrogram on the current axes.
conv3_linkage = linkage(conv3_weights, metric='correlation', method='single')
dendrogram(conv3_linkage)
plt.title('Targeted-NIID: layer3')

In [None]:

# Single-linkage hierarchical clustering of the rows of fc1_weights under
# correlation distance; this is the layer the figure title labels 'PLR'.
fc1_linkage = linkage(fc1_weights, metric='correlation', method='single')
dendrogram(fc1_linkage)
plt.title('Targeted-NIID: PLR')

In [None]:

# Single-linkage hierarchical clustering of the rows of fc2_weights under
# correlation distance, drawn as a dendrogram on the current axes.
fc2_linkage = linkage(fc2_weights, metric='correlation', method='single')
dendrogram(fc2_linkage)
plt.title('Targeted-NIID: layer5')

In [None]:
# structured_local_weights = [construct_ordered_dict(global_model, flat_weights) for flat_weights in malicious_grads]

# conv1_weights = [np.array(w['conv1.weight'].detach().cpu().reshape(-1)) for w in structured_local_weights]
# conv1_weights = np.array(conv1_weights)
# conv2_weights = [np.array(w['conv2.weight'].detach().cpu().reshape(-1)) for w in structured_local_weights]
# conv2_weights = np.array(conv2_weights)
# fc1_weights = [np.array(w['fc1.weight'].detach().cpu().reshape(-1)) for w in structured_local_weights]
# fc1_weights = np.array(fc1_weights)
# fc2_weights = [np.array(w['fc2.weight'].detach().cpu().reshape(-1)) for w in structured_local_weights]
# fc2_weights = np.array(fc2_weights)
# fc3_weights = [np.array(w['fc3.weight'].detach().cpu().reshape(-1)) for w in structured_local_weights]
# fc3_weights = np.array(fc3_weights)

# num_layers = 5

# _, s_layer_0, _, = np.linalg.svd(conv1_weights)
# _, s_layer_1, _, = np.linalg.svd(conv2_weights)
# _, s_layer_2, _, = np.linalg.svd(fc1_weights)
# _, s_layer_3, _, = np.linalg.svd(fc2_weights)
# _, s_layer_4, _, = np.linalg.svd(fc3_weights)


In [None]:
# from matplotlib import pyplot as plt

# for x, y in zip(np.arange(num_layers-1), [s_layer_0, s_layer_1, s_layer_3, s_layer_4]):
#     plt.scatter([x]*len(s_layer_0), y, cmap="copper")
#     for i, txt in enumerate(np.arange(len(s_layer_0))):
#         plt.annotate(txt, (x, y[i]))
    
# plt.xticks(np.arange(num_layers))
# plt.show()

In [None]:
# if len(args.mal_clients) > 0:
#     import csv
#     if args.iid:
#         filename = 'Confidence_2IID_'+args.dataset+'_'+args.aggregation+'.csv'
#     else:
#         filename = 'Confidence_2NIID_'+args.dataset+'_'+args.aggregation+'.csv'

#     with open(filename, mode='w', newline='') as csvfile:
#         w = csv.writer(csvfile)
#         w.writerows(map(lambda x: [x], confidence))



In [None]:
from utils import get_mal_dataset_of_class

# Request args.num_mal * 100 samples from the test set via
# get_mal_dataset_of_class (presumably samples of the attacked class paired
# with the malicious label -- confirm in utils), then score the global model on them.
num_test_samples = args.num_mal * 100
test_mal_X_list, test_mal_Y, test_Y_true = get_mal_dataset_of_class(
    test_dataset, num_test_samples, Y_true, mal_Y
)

# mal_inference is given the labels as one flat array.
flat_test_mal_Y = np.hstack(list(test_mal_Y))
mal_acc, mal_loss, mal_out = mal_inference(
    args, global_model, test_dataset, test_mal_X_list, flat_test_mal_Y
)



In [None]:
# Display the malicious accuracy and loss together with the mean of `mal_out`
# as returned by the most recent mal_inference call.
mal_acc, mal_loss, torch.mean(mal_out)

In [None]:
# Display the dataset name and boost setting of the current run configuration.
args.dataset, args.boost

In [None]:
# Report the experiment configuration through utils.exp_details
# (presumably prints a summary of `args` -- confirm in utils).
exp_details(args)


In [None]:
##### # PLOTTING (optional)
import os

import matplotlib
# Select the non-interactive Agg backend before pyplot is imported, which is
# the order matplotlib documents; figures are only written to disk here.
matplotlib.use('Agg')
import matplotlib.pyplot as plt

# savefig does not create missing directories, so make sure the target exists.
os.makedirs('../save', exist_ok=True)

# Both figures share the same run-identifying filename stem.
file_stem = '../save/fed_{}_{}_{}_C[{}]_iid[{}]_E[{}]_B[{}]'.format(
    args.dataset, args.model, args.epochs, args.frac,
    args.iid, args.local_ep, args.local_bs)


def _save_curve(values, title, ylabel, color, path):
    """Plot `values` against their index (communication round) and save to `path`."""
    plt.figure()
    plt.title(title)
    plt.plot(range(len(values)), values, color=color)
    plt.ylabel(ylabel)
    plt.xlabel('Communication Rounds')
    plt.savefig(path)


# Plot Loss curve
_save_curve(train_loss, 'Training Loss vs Communication rounds',
            'Training loss', 'r', file_stem + '_loss.png')

# Plot Average Accuracy vs Communication rounds
_save_curve(train_accuracy, 'Average Accuracy vs Communication rounds',
            'Average Accuracy', 'k', file_stem + '_acc.png')