## Plan

- [x] FL base framework
- [x] implement non-iid sampling function following dirichlet dist

#### Attack simulation
- [x] untargeted model poisoning
    - Local model poisoning attacks to Byzantine-robust federated learning <br>
    https://github.com/vrt1shjwlkr/NDSS21-Model-Poisoning fang attack <br>
    krum-attack, trimmed-mean/median attack
- [ ] targeted model poisoning
    - Analyzing federated learning through an adversarial lens <br>
    https://github.com/inspire-group/ModelPoisoning
- [ ] data poisoning
    - DBA: Distributed backdoor attacks against federated learning <br>
    https://github.com/AI-secure/DBA
    
#### Defense baseline
- [x] FedAvg
- [x] Krum
- [x] Multi-Krum
- [x] Bulyan
- [x] Coordinate median
- [ ] FoolsGold
- [ ] FLARE

#### Proposed method
- [ ] extract PLRs
- [ ] compare the hypersphere uniformity loss 
- [ ] apply RBF hypersphere CKA

In [1]:
from IPython.display import display, HTML
display(HTML("<style>.container { width:90% !important; }</style>"))

In [2]:
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6


import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
import copy
import time
import pickle
import numpy as np
from tqdm import tqdm
from collections import OrderedDict

import torch
from tensorboardX import SummaryWriter

from options import args_parser
from update import LocalUpdate, test_inference, mal_inference
from models import MLP, CNNMnist, CNNFashion_Mnist, CNNCifar
from utils import get_dataset, get_mal_dataset, exp_details, flatten, construct_ordered_dict
from aggregate import fedavg, multi_krum, krum, coomed, bulyan, flare, fed_align, fed_cc
from attacks import get_malicious_updates_untargeted_mkrum, get_malicious_updates_untargeted_med
from cka import kernel_CKA
# python src/federated_main.py --model=cnn --dataset=cifar --gpu=0 --iid=1 --epochs=10



# Untargeted

In [3]:
class Args(object):
    
    # federated parameters (default values are set)
    epochs = 50
    num_users = 10
    frac = 1 # fraction of clients
    local_ep = 5 # num of local epoch
    local_bs = 100 # batch size
    lr = 0.001
    momentum = 0.9
    aggregation = 'fedcc' # fedavg, krum, mkrum, coomed, bulyan, flare, fedcc
    
    # model arguments
    model = 'cnn'
    kernel_num = 9 # num of each kind of kernel
    kernel_sizes = '3,4,5' # comma-separated kernel size to use for convolution
    num_channels = 1 # num of channels of imgs
    norm = 'batch_norm' # batch_norm, layer_norm, None
    num_filters = 32 # num of filters for conv nets -- 32 for mini-imagenet, 64 for omiglot
    max_pool = 'True' # whether use max pooling rather than strided convolutions
    
    # other arguments
    dataset = 'cifar' # fmnist, cifar, mnist
    num_classes = 10 
    gpu = 0
    optimizer = 'adam'
    iid = 0 # 0 for non-iid
    alpha = 0.2 # noniid --> (0, 1] <-- iid
    unequal = 0 # whether to use unequal data splits for non-iid settings (0 for equal splits)
    stopping_rounds = 10 # rounds of early stopping
    verbose = 0
    seed = 1

    # malicious arguments
    mal_users = [0, 1] # indices of malicious user
    attack_type = 'targeted' # targeted
    num_mal = 1 # number of maliciuos data sample
    local_mal_ep = 5
    boost = 5 # alpha: 2 for fedavg, 3.5 for krum

In [4]:

if __name__ == '__main__':
    start_time = time.time()

    # define paths
    path_project = os.path.abspath('..')
    logger = SummaryWriter('../logs')

    args = Args()
    exp_details(args)

    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    # for n_attacker in args.n_attackers:
    torch.cuda.empty_cache()

    # load dataset and user groups
    train_dataset, test_dataset, user_groups = get_dataset(args)

    # BUILD MODEL
    if args.model == 'cnn':
        # Convolutional neural netork
        if args.dataset == 'mnist':
            global_model = CNNMnist(args=args)
        elif args.dataset == 'fmnist':
            global_model = CNNFashion_Mnist(args=args)
        elif args.dataset == 'cifar':
            global_model = CNNCifar(args=args)

    elif args.model == 'mlp':
        # Multi-layer preceptron
        img_size = train_dataset[0][0].shape
        len_in = 1
        for x in img_size:
            len_in *= x
            global_model = MLP(dim_in=len_in, dim_hidden=64, dim_out=args.num_classes)
    else:
        exit('Error: unrecognized model')

    # Set the model to train and send it to device.
    global_model.to(device)
    global_model.train()
    print(global_model)

    # copy weights
    global_weights = global_model.state_dict()

    mal_X_list, mal_Y, Y_true = get_mal_dataset(test_dataset, args.num_mal, args.num_classes)
    print("malcious dataset true labels: {}, malicious labels: {}".format(Y_true, mal_Y))

    # Training
    train_loss, train_accuracy = [], []
    val_acc_list, net_list = [], []
    cv_loss, cv_acc = [], []
    print_every = 1
    val_loss_pre, counter = 0, 0

    for epoch in tqdm(range(args.epochs)):
        acts = []

        local_weights, local_losses = [], []
        print('=========================================')
        print(f'| Global Training Round : {epoch+1} |')
        print('=========================================')

        global_model.train()
        m = max(int(args.frac * args.num_users), 1)
        idxs_users = np.random.choice(range(args.num_users), m, replace=False)

        # flattened weights + bias
        flattened_local_weights = []
        
        # create separate arrays for weights and biases and the offsets 
        only_weights = [] # for krum to consider only weights, not the biases
        only_biases = []
        cka_vals = [] 
        for idx in range(args.num_users):
            mal_user = False
            
            # alternating benign and malicious training 
            if idx in args.mal_users:# and epoch % 2 == 0:
                mal_user = True

            local_model = LocalUpdate(args=args, dataset=train_dataset, idxs=user_groups[idx], logger=logger, \
                                      mal=mal_user, mal_X=mal_X_list, mal_Y=mal_Y, test_dataset=test_dataset)
            
            w_prev = global_model.state_dict()
            w, loss, act = local_model.update_weights(model=copy.deepcopy(global_model), global_round=epoch)
            acts.append(act)
            
            # construct two arrays with weights-only and bias-only
            if 'krum' in args.aggregation:
                for key in w:
                    if 'weight' in key:
                        if mal_user and epoch % 2 == 0:
                            only_weights = np.append(only_weights, args.boost * w[key].detach().cpu().numpy().reshape(-1))
                        else:
                            only_weights = np.append(only_weights, w[key].detach().cpu().numpy().reshape(-1))
                    elif 'bias' in key:
                        only_biases = np.append(only_biases, w[key].detach().cpu().numpy().reshape(-1))
                
            # offsets.append(w[key].detach().cpu().numpy().reshape(-1).shape[0])
            
            new_model = copy.deepcopy(global_model)
            new_model.load_state_dict(w)
            acc, _ = local_model.inference(model=new_model)

            if mal_user == True:
                mal_acc, mal_loss = local_model.mal_inference(model=new_model)
                print('user {}, loss {}, acc {}, mal loss {}, mal acc {}'.format(idx, loss, 100*acc, mal_loss, 100*mal_acc))
            else:
                print('user {}, loss {}, acc {}'.format(idx, loss, 100*acc))

            local_weights.append(copy.deepcopy(w))
            local_losses.append(copy.deepcopy(loss))

            flattened_local_weights.append(flatten(w))
        
        only_weights = torch.tensor(np.array(only_weights)).to(device)
                
        flattened_local_weights = torch.tensor(np.array(flattened_local_weights)).to(device)
        malicious_grads = flattened_local_weights

        n_attacker = len(args.mal_users)
        # print(only_weights.shape)
        
        
        # update global weights
        if args.aggregation == 'fedavg':
            agg_weights = fedavg(malicious_grads)
        elif args.aggregation == 'krum':
            agg_weights, selected_idxs = krum(malicious_grads, n_attacker, only_weights=only_weights)
            print(f'Krum Selected idx: {selected_idxs}')
        elif args.aggregation == 'mkrum':
            agg_weights, selected_idxs = multi_krum(malicious_grads, n_attacker, only_weights=only_weights)
            print(f'multiKrum Selected idxs: {selected_idxs}')
        elif args.aggregation == 'coomed':
            agg_weights = coomed(malicious_grads)
            print(f'\ndiff {torch.norm((agg_weights - flattened_local_weights[0])) ** 2}')
        elif args.aggregation == 'bulyan':
            agg_weights, selected_idxs = bulyan(malicious_grads, n_attacker)
            print(f'Bulyan Selected idx: {selected_idxs}')
        elif args.aggregation == 'trmean':
            agg_weights = tr_mean(malicious_grads, n_attacker)
        elif args.aggregation == 'flare':
            second_last_layer = list(local_weights[0].keys())[-4]
            structured_local_weights = [construct_ordered_dict(global_model, flat_weights) for flat_weights in malicious_grads]
            plrs = [(each_local[second_last_layer]) for each_local in structured_local_weights]
            agg_weights, count_dict = flare(malicious_grads, plrs)
            print(f'flare count_dict: {count_dict}')
        elif args.aggregation == 'fedalign':
            second_last_layer = list(local_weights[0].keys())[-4]
            structured_local_weights = [construct_ordered_dict(global_model, flat_weights) for flat_weights in malicious_grads]
            plrs = [(each_local[second_last_layer]) for each_local in structured_local_weights]
            agg_weights, selected_idxs = fed_align(malicious_grads, plrs)
            print(f'fed_cka_align Selected idx: {selected_idxs}')
        elif args.aggregation == 'fedcc':
            second_last_layer = list(local_weights[0].keys())[-4]
            glob_plr = global_weights[second_last_layer]            
            structured_local_weights = [construct_ordered_dict(global_model, flat_weights) for flat_weights in malicious_grads]
            plrs = [(each_local[second_last_layer]) for each_local in structured_local_weights]
            agg_weights, selected_idxs = fed_cc(malicious_grads, glob_plr, plrs)
            print(f'fed_cka Selected idx: {selected_idxs}')
            
            
        else:
            raise ValueError('Unknown aggregation strategy: {}'.format(args.aggregation))


        # reshape the flattened global weights into the ordereddict
        global_weights = construct_ordered_dict(global_model, agg_weights)
        
        # update global weights
        global_model.load_state_dict(global_weights)

        loss_avg = sum(local_losses) / len(local_losses)
        train_loss.append(loss_avg)

        # print global training loss after every 'i' rounds
        if (epoch+1) % print_every == 0:
            print(f' \nAvg Training Stats after {epoch+1} global rounds:')
            print(f'Training Loss : {np.mean(np.array(train_loss))}')
            
            test_acc, test_loss = test_inference(args, global_model, test_dataset)
            print('\nGlobal model Benign Test Accuracy: {:.2f}% '.format(100*test_acc))

            mal_acc, mal_loss, mal_out = mal_inference(args, global_model, test_dataset, mal_X_list, mal_Y)
            print('Global model Malicious Accuracy: {:.2f}%, Malicious Loss: {:.2f}, confidence: {}\n'.format(100*mal_acc, mal_loss, mal_out))

#                 print('%s: %s n_attacker %d fed_model val loss %.4f val acc %.4f' \
#                       %(args.aggregation, args.attack_type, n_attacker, test_loss, 100*test_acc))
#                 print('\nGlobal model Benign Test Accuracy: {:.2f}% '.format(100*test_acc))
                
    # Test inference after completion of training
    test_acc, test_loss = test_inference(args, global_model, test_dataset)
    mal_acc, mal_loss, mal_out = mal_inference(args, global_model, test_dataset, mal_X_list, mal_Y)

    print(f' \n Results after {args.epochs} global rounds of training:')
    print("|---- Test Benign Accuracy: {:.2f}%".format(100*test_acc))
    print("|---- Test Malicious Accuracy: {:.2f}%, Malicious Loss: {:.2f}, confidence:{}\n".format(100*mal_acc, mal_loss, mal_out[0]))
    

    # Saving the objects train_loss and train_accuracy:
    file_name = '../save/objects/{}_{}_{}_C[{}]_iid[{}]_E[{}]_B[{}].pkl'.\
        format(args.dataset, args.model, args.epochs, args.frac, args.iid, args.local_ep, args.local_bs)

    with open(file_name, 'wb') as f:
        pickle.dump([train_loss, train_accuracy], f)

    print('\n Total Run Time: {0:0.4f}'.format(time.time()-start_time))



Experimental details:
    Model     : cnn
    Optimizer : adam
    Learning  : 0.001
    Aggregation     : fedcc
    Global Rounds   : 50

    Federated parameters:
    Non-IID              : 0.2
    Fraction of users    : 1
    Local Batch size     : 100
    Local Epochs         : 5

    Malicious parameters:
    Attackers            : [0, 1]
    Attack Type          : targeted
    Mal Boost            : 5
Files already downloaded and verified
Files already downloaded and verified
CNNCifar(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
  (conv3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
  (fc1): Linear(in_features=256, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=10, bias=True)
)
malcious dataset true labels: 2, malicious labels: [3]


  0%|                                                                                           | 0/50 [00:00<?, ?it/s]

| Global Training Round : 1 |
user 0, loss 8.785504150390626, acc 90.72948328267478, mal loss 5.671874046325684, mal acc 0.0
user 1, loss 6.359981822967529, acc 95.81749049429658, mal loss 3.3141536712646484, mal acc 0.0
user 2, loss 1.2553675655395753, acc 27.260812581913502
user 3, loss 1.2431131457027635, acc 0.0
user 4, loss 1.0618263911317896, acc 1.5337423312883436
user 5, loss 0.7110698693785175, acc 66.38655462184873
user 6, loss 1.0476284375190734, acc 58.95765472312704
user 7, loss 1.1427614144980907, acc 38.44221105527638
user 8, loss 0.7727232098579407, acc 73.54497354497354
user 9, loss 0.5917896860954809, acc 9.936908517350158
[0.7915438667770563, 0.8353731051778266, 0.7573748762238459, 0.7923410218552733, 0.8875145880862669, 0.8904223051605455, 0.8781116352940534, 0.8338231385444032, 0.8046339129550272, 0.8179751574505504]
Counter({1: 7, 0: 3})
len(selected_parameters) = 7
fed_cka Selected idx: [0 1 2 3 7 8 9]
 
Avg Training Stats after 1 global rounds:
Training Loss : 2

  0%|                                                                                           | 0/50 [00:43<?, ?it/s]


KeyboardInterrupt: 

In [5]:
avg_acts = []
for i in range(len(acts)):
    print((acts[i]).shape)
    avg_acts.append(np.mean(acts[i].detach().cpu().numpy(), axis = (1,2)))

torch.Size([1, 64, 4, 4])
torch.Size([1, 64, 4, 4])
torch.Size([3, 64, 4, 4])
torch.Size([1, 64, 4, 4])
torch.Size([6, 64, 4, 4])
torch.Size([59, 64, 4, 4])
torch.Size([58, 64, 4, 4])
torch.Size([86, 64, 4, 4])
torch.Size([40, 64, 4, 4])
torch.Size([69, 64, 4, 4])


In [6]:
for act in avg_acts:
    print(act.shape)

(1, 4)
(1, 4)
(3, 4)
(1, 4)
(6, 4)
(59, 4)
(58, 4)
(86, 4)
(40, 4)
(69, 4)


In [15]:
act1 = acts[0].detach().cpu().numpy()
act2 = acts[9]

In [17]:
idx1

array([ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12.,
       13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25.,
       26., 27., 28., 29., 30., 31., 32., 33., 34., 35., 36., 37., 38.,
       39., 40., 41., 42., 43., 44., 45., 46., 47., 48., 49., 50., 51.,
       52., 53., 54., 55., 56., 57., 58., 59., 60., 61., 62., 63.])

In [18]:
idx2

array([0., 1., 2., 3.])

In [31]:

confidence = {'fedavg':[],
              'krum'  :[],
              'coomed':[],
              'bulyan':[],
              'flare' :[],
              'fedcc' :[]}
for i in range(10):
    confidence['krum'].append(i)

In [41]:
import kmeans1d

x = [0.9850572609232882, 0.9847250213129664, 0.9848672771600722, 0.9997756678639741, 0.9995519971898091, 0.9997412190655622, 0.9999016268180031, 0.9996079270875822, 0.9996735287208167, 0.9996309990394843]
k = 2

clusters, centroids = kmeans1d.cluster(x, k)


In [61]:
scale = clusters

In [None]:
scal

In [43]:
from collections import Counter

Counter(clusters)

Counter({0: 3, 1: 7})

In [57]:
suspect = 1
clusters.index(suspect)
suspects  = [idx for (idx, item) in enumerate(clusters) if item == suspect]


In [58]:
suspects

[3, 4, 5, 6, 7, 8, 9]

In [33]:
confidence = [0,1,2,3,4]

In [36]:
import csv

if args.iid:
    filename = 'Confidence_IID.csv'
else:
    filename = 'Confidence_NIID.csv'

with open(filename, mode='w') as csv_file:
    csv_data = [args.aggregation, confidence]
    writer = csv.writer(csv_file, delimiter=',')
    writer.writerow(csv_data)


In [16]:
from scipy import interpolate

num_d, h, w, _ = acts1.shape
num_c = act1.shape[-1]
act1_interp = np.zeros((num_d, h, w, num_c))

for d in range(num_d):
    for c in range(num_c):
        # form interpolation function
        idx1 = np.linspace(0, act1.shape[1], act1.shape[1], endpoint=False)
        idx2 = np.linspace(0, act1.shape[2], act1.shape[2], endpoint=False)  
        arr = act1[d,:,:,c]
        f_interp = interpolate.interp2d(idx1, idx2, arr)
        
        # create a larger arr
        large_idx1 = np.linspace(0, act1.shape[1], act1.shape[1], endpoint=False)
        large_idx2 = np.linspace(0, act1.shape[2], act1.shape[2], endpoint=False)  
        act1interp[d,:,:,c] = f_interp(large_idx1, large_idx2)

ValueError: When on a regular grid with x.size = m and y.size = n, if z.ndim == 2, then z must have shape (n, m)

In [7]:

# kernel CKA on pairwise local client's output of the activation 
for i in range(args.num_users):
    for j in range(i + 1, args.num_users):
        print(kernel_CKA(avg_acts[i], avg_acts[j]))
        

nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan


  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


ValueError: operands could not be broadcast together with shapes (3,3) (6,6) 

In [None]:




CNNCifar(args=args)

In [None]:
from utils import get_mal_datset_of_class

test_mal_X_list, test_mal_Y, test_Y_true = get_mal_dataset_of_class(test_dataset, args.num_mal*100, Y_true, mal_Y)
flat_test_mal_Y = np.hstack([y for y in test_mal_Y])
mal_acc, mal_loss, mal_out = mal_inference(args, global_model, test_dataset, test_mal_X_list, flat_test_mal_Y)


In [None]:
mal_acc, mal_loss, torch.mean(mal_out)

In [None]:

# PLOTTING (optional)
import matplotlib
import matplotlib.pyplot as plt
matplotlib.use('Agg')

# Plot Loss curve
plt.figure()
plt.title('Training Loss vs Communication rounds')
plt.plot(range(len(train_loss)), train_loss, color='r')
plt.ylabel('Training loss')
plt.xlabel('Communication Rounds')
plt.savefig('../save/fed_{}_{}_{}_C[{}]_iid[{}]_E[{}]_B[{}]_loss.png'.
            format(args.dataset, args.model, args.epochs, args.frac,
                   args.iid, args.local_ep, args.local_bs))

# Plot Average Accuracy vs Communication rounds
plt.figure()
plt.title('Average Accuracy vs Communication rounds')
plt.plot(range(len(train_accuracy)), train_accuracy, color='k')
plt.ylabel('Average Accuracy')
plt.xlabel('Communication Rounds')
plt.savefig('../save/fed_{}_{}_{}_C[{}]_iid[{}]_E[{}]_B[{}]_acc.png'.
            format(args.dataset, args.model, args.epochs, args.frac,
                   args.iid, args.local_ep, args.local_bs))

In [None]:
last_layer = list(local_weights[0].keys())[-2]
second_last_layer = list(local_weights[0].keys())[-4]
thrid_last_layer = list(local_weights[0].keys())[-6]

In [None]:

# reshape the flattened global weights into the ordereddict
keys = global_model.state_dict().keys()
structured_local_weights = []

for grad in enumerate(malicious_grads):
    start_idx = 0
    model_grads = []

    for i, param in enumerate(global_model.parameters()):
        param_ = grad[1][start_idx:start_idx + len(param.data.view(-1))].reshape(param.data.shape)
        start_idx = start_idx + len(param.data.view(-1))
        param_ = param_.cuda()
        model_grads.append(param_)
    structured_local_weights.append(OrderedDict(dict(zip(keys, model_grads))))
# structured_local_weights = OrderedDict(dict(zip(keys, model_grads)))  

In [None]:
for key in structured_local_weights[0].keys():
    if 'weight' in key:
        print(key)
    

In [None]:
last_layer = list(local_weights[0].keys())[-2]
second_last_layer = list(local_weights[0].keys())[-4]
thrid_last_layer = list(local_weights[0].keys())[-6]

# list of the second last layer's weights
plrs = [(each_local[second_last_layer]) for each_local in structured_local_weights]

In [None]:
# bsz : batch size (number of positive pairs)
# d   : latent dim
# x   : Tensor, shape=[bsz, d]
#       latents for one side of positive pairs
# y   : Tensor, shape=[bsz, d]
#       latents for the other side of positive pairs

def align_loss(x, y, alpha=2):
    return (x - y).norm(p=2, dim=1).pow(alpha).mean()

def uniform_loss(x, t=2):
    return torch.pdist(x, p=2).pow(2).mul(-t).exp().mean().log()

In [None]:

for i in range(len(plrs)):
    print(f'uniform loss plrs[{i}]: {uniform_loss(plrs[i])}')
    for j in range(0, len(plrs)):
        print(f'align_loss btwn plrs[{i}] and [{j}]: {align_loss(plrs[i], plrs[j])}')

In [None]:
from cka import linear_CKA, kernel_CKA

In [None]:
X = np.random.randn(100, 64)
Y = np.random.randn(100, 64)

print('Linear CKA, between X and Y: {}'.format(linear_CKA(X, Y)))
print('Linear CKA, between X and X: {}'.format(linear_CKA(X, X)))

print('RBF Kernel CKA, between X and Y: {}'.format(kernel_CKA(X, Y)))
print('RBF Kernel CKA, between X and X: {}'.format(kernel_CKA(X, X)))

In [None]:
plrs = [plr.cpu().detach() for plr in plrs]

In [None]:
len(malicious_grads)

In [None]:
detached_mal_grad = malicious_grads.cpu().detach()

In [None]:
detached_mal_grad[i]

In [None]:
np.array(detached_mal_grad[i].flatten())

In [None]:
detached_mal_grad[0]

In [None]:
plrs[0]

In [None]:
linear_CKA(np.array(detached_mal_grad[1]), np.array(detached_mal_grad[2]))

In [None]:
for i in range(len(plrs)):
    for j in range(i + 1, len(plrs)):
        print(f'kernel_cka(all[{i}], all[{j}]: {kernel_CKA(plrs[i], plrs[j])}')

In [None]:
linear_CKA(plrs[9].cpu().detach(), plrs[8].cpu().detach())