In [65]:
%matplotlib inline

import torch
import torchvision
from torch.utils.data.dataset import Dataset
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

import numpy as np
import matplotlib.pyplot as plt
import random
from collections import defaultdict

from data_utils import CustomImageDataset, split_image_data
from data_utils import get_default_data_transforms
from models import ConvNet
from fl_devices import Server, Client
from helper import ExperimentLogger, display_train_stats

from sklearn.cluster import AgglomerativeClustering, DBSCAN
from sklearn.metrics import pairwise_distances
from sklearn.decomposition import PCA


torch.manual_seed(0)
random.seed(0)
np.random.seed(0)

device = "cuda" if torch.cuda.is_available() else "cpu"

In [2]:
# helper functions

# detect_adv_idx: adverary indices detected by server
# gt_adv_idx: ground-truth indices
def check_detect(detect_adv_idx, gt_adv_idx):
    intersection = [idx for idx in gt_adv_idx if idx in detect_adv_idx]
    if len(intersection) > 0:
        return True
    else:
        return False
    
# feature_matrix:
# each row is flatten dWs from a client
def generate_feature_matrix(dW_dicts):
    with torch.no_grad():
        rows = []
        
        for dW_dict in dW_dicts:
            row = torch.empty(0).to(device)
            for key, value in dW_dict.items():
                row = torch.cat((row, value.flatten()), 0)
            rows.append(row)
            
        matrix = torch.stack(rows, 0)
        if device is "cpu":
            return matrix.numpy()
        else:
            return matrix.cpu().numpy()
        
def print_labels(labels):
  string = []
  for idx, label in enumerate(labels):
    string.append(str(idx)+': '+str(label))
  print('\t'.join(string))

def print_outliers(labels):
  outlier_idx = np.argwhere(labels == -1).flatten()
  print(outlier_idx)
    
def print_distance(feature_matrix, metric):
  distance = pairwise_distances(feature_matrix,metric=metirc)
  return distance 

In [56]:
# hyperparameters
N_CLIENT = 25
N_ADV_RANDOM = 3
N_ADV_OPP = 0
N_ADV_SWAP = 0

In [4]:
# data = datasets.EMNIST(root="./", split="byclass",download=True)
data = datasets.MNIST(root='./',download=True)





In [57]:
## It seems that data.train_data and data.test_data are the same
## small data for fast training 
train_frac = 0.2
test_frac = 0.2 
train_num = int(train_frac * len(data))
test_num = int(test_frac * len(data))
idcs = np.random.permutation(len(data))
train_idcs, test_idcs = idcs[:train_num], idcs[train_num:train_num + test_num]
train_labels = data.train_labels.numpy()



In [58]:
clients_split = split_image_data(data.train_data[train_idcs], train_labels[train_idcs], n_clients=N_CLIENT, classes_per_client=5,balancedness=1)

Data split:
 - Client 0: [ 0  0  0 96 96 96 96 96  0  0]
 - Client 1: [ 0 96 96 96 96 96  0  0  0  0]
 - Client 2: [96  0  0  0  0  0 96 96 96 96]
 - Client 3: [96 96 96  0  0  0  0  0 96 96]
 - Client 4: [96 96 96 96  0  0  0  0  0 96]
 - Client 5: [ 0 96 96 96 96 96  0  0  0  0]
 - Client 6: [96  0  0  0  0  0 96 96 96 96]
 - Client 7: [ 0 96 96 96 96 96  0  0  0  0]
 - Client 8: [ 0  0  0  0 96 96 96 96 96  0]
 - Client 9: [ 0  0 96 96 96 96 96  0  0  0]
 - Client 10: [96 96 96 96 96  0  0  0  0  0]
 - Client 11: [ 0  0  0  0 96 96 96 96 96  0]
 - Client 12: [ 0  0 96 96 96 96 96  0  0  0]
 - Client 13: [96  0  0  0  0  0 96 96 96 96]
 - Client 14: [96 96 96 96  0  0  0  0  0 96]
 - Client 15: [96 96 96  0  0  0  0  0 96 96]
 - Client 16: [96 96  0  0  0  0  0 96 96 96]
 - Client 17: [96 96 96 96  0  0  0  0  0 96]
 - Client 18: [96 96 96 96  0  0  0  0  0 96]
 - Client 19: [ 0  0  0 96 96 96 96 96  0  0]
 - Client 20: [96 96 35 93 96 64  0  0  0  0]
 - Client 21: [ 0  0  0  0 96 96



In [59]:
train_trans, val_trans = get_default_data_transforms("EMNIST")
client_data = [CustomImageDataset(clients_split[i][0].to(torch.float32), clients_split[i][1],transforms=train_trans ) for i in range(len(clients_split))]


Data preprocessing: 
 - ToPILImage()
 - Resize(size=(28, 28), interpolation=PIL.Image.BILINEAR)
 - ToTensor()
 - Normalize(mean=(0.06078,), std=(0.1957,))



In [60]:
test_data = data.test_data[train_num:train_num+test_num]
test_labels = train_labels[train_num:train_num+test_num]
test_data = CustomImageDataset(test_data.to(torch.float32), test_labels, transforms=val_trans)



In [67]:
# Assign client modes
clients = [Client(ConvNet, lambda x : torch.optim.SGD(x, lr=0.1, momentum=0.9), client_data[i], idnum=i) 
           for i, dat in enumerate(client_data)]
client_indx = np.random.permutation(len(clients))
offset = 0
adv_random = client_indx[0:N_ADV_RANDOM]
offset += N_ADV_RANDOM
adv_opp = client_indx[offset:offset + N_ADV_OPP]
offset += N_ADV_OPP
adv_swap = client_indx[offset:offset+N_ADV_SWAP]
offset += N_ADV_SWAP
adv_idx = np.concatenate((adv_random,adv_opp,adv_swap)).tolist()
for i in adv_random:
  clients[i].client_mode = 'random'

for i in adv_opp:
  clients[i].client_mode = 'opposite'

for i in adv_swap:
  clients[i].client_mode = 'mode_3'

# print out each client and its mode
for idx, client in enumerate(clients):
  print('{}: {}'.format(idx, client.client_mode))

server = Server(ConvNet, test_data)

0: normal
1: normal
2: random
3: normal
4: normal
5: random
6: normal
7: normal
8: normal
9: normal
10: normal
11: normal
12: normal
13: normal
14: normal
15: normal
16: normal
17: normal
18: normal
19: random
20: normal
21: normal
22: normal
23: normal
24: normal


In [10]:
# hyperparemeters
TOTAL_ROUND = 20

esp = 2.0
min_samples =2
metric = 'l2'
cfl_stats = ExperimentLogger()
#counter_key = (esp, min_samples, metric)

In [68]:
for round in range(TOTAL_ROUND):
        if round == 0:
            for client in clients:
                client.synchronize_with_server(server)

        participating_clients = server.select_clients(clients, frac=1.0)

        for client in participating_clients:
            train_stats = client.compute_weight_update(epochs=1)
            client.reset()


        # generate feature matrix for clustering
        client_dW_dicts = [client.dW for client in clients]
        feature_matrix = generate_feature_matrix(client_dW_dicts)
        print("feature matrix max")
        print(feature_matrix.max())
        
        # detect adversary using clustering
        #detect_adv_idx = server.detect_adversary(feature_matrix, esp, min_samples, metric)
        #detect_result = check_detect(detect_adv_idx, adv_idx)
        
        # return labels assigned to clients
        clustering_labels = server.detect_adversary(feature_matrix, esp, min_samples, metric)

        #if detect_result:
        #    detect_counter[counter_key][round] += 1
        # aggregate weight updates; copy new weights to clients
        server.aggregate_weight_updates(clients)
        server.copy_weights(clients)

        acc_clients = [client.evaluate() for client in clients]
        cfl_stats.log({"acc_clients" : acc_clients, "rounds" : round})
    
        print("round %d"%(round))
        #print(detect_adv_idx)
        #print(acc_clients)

        print("labels assigned to clients:")
        print_labels(clustering_labels)
        print('detected outliers:')
        print_outliers(clustering_labels)

feature matrix max
0.27681434
round 0
labels assigned to clients:
0: 0	1: 0	2: -1	3: 0	4: 0	5: -1	6: 0	7: 0	8: 0	9: 0	10: 0	11: 0	12: 0	13: 0	14: 0	15: 0	16: 0	17: 0	18: 0	19: -1	20: 0	21: 0	22: 0	23: 0	24: 0
detected outliers:
[ 2  5 19]
feature matrix max
0.8563676
round 1
labels assigned to clients:
0: -1	1: -1	2: -1	3: -1	4: 0	5: -1	6: -1	7: -1	8: 1	9: 2	10: -1	11: -1	12: 2	13: -1	14: 3	15: -1	16: -1	17: 3	18: 0	19: -1	20: -1	21: 1	22: -1	23: -1	24: -1
detected outliers:
[ 0  1  2  3  5  6  7 10 11 13 15 16 19 20 22 23 24]
feature matrix max
0.6340621
round 2
labels assigned to clients:
0: -1	1: -1	2: -1	3: -1	4: 0	5: -1	6: 1	7: -1	8: 2	9: 3	10: -1	11: -1	12: 3	13: 1	14: 4	15: -1	16: -1	17: 4	18: 0	19: -1	20: -1	21: 2	22: -1	23: -1	24: -1
detected outliers:
[ 0  1  2  3  5  7 10 11 15 16 19 20 22 23 24]
feature matrix max
0.65579414
round 3
labels assigned to clients:
0: -1	1: -1	2: -1	3: 0	4: 0	5: -1	6: 1	7: 0	8: -1	9: -1	10: 0	11: -1	12: -1	13: 1	14: 0	15: -1	16: -1	17: 0	18: 0	1

# Experiment A: F1-score vs N_communication_round


In [None]:
#TOTAL_ROUND = 20
TOTAL_TRIAL = 30

esp = 2.0
min_samples =2
metric = 'l2'
cfl_stats = ExperimentLogger()

detect_counter = defaultdict(lambda: [0] * TOTAL_ROUND)

In [None]:
# for each combination of esp, min_samples, metric
# run multiple trials, do clustering at each round
# to find the round number before clustering that have best adversary-identificaiton rate
# no need to handle any adversary detected yet

for esp in esp_vals:
    for min_samples in min_samples_vals:
        for metric in metric_vals:
            counter_key = (esp, min_samples, metric)
            
            # run multiple trials
            for trial in range(TOTAL_TRIAL):
                # communication rounds in FL
                for round in range(TOTAL_ROUND):
                        if round == 0:
                            for client in clients:
                                client.synchronize_with_server(server)

                        participating_clients = server.select_clients(clients, frac=1.0)

                        for client in participating_clients:
                            train_stats = client.compute_weight_update(epochs=1)
                            client.reset()

                        # generate feature matrix for clustering
                        client_dW_dicts = [client.dW for client in clients]
                        feature_matrix = generate_feature_matrix(client_dW_dicts)

                        # detect adversary using clustering
                        detect_adv_idx = server.detect_adversary(feature_matrix, esp, min_samples, metric)
                        detect_result = check_detect(detect_adv_idx, adv_idx)

                        if detect_result:
                            detect_counter[counter_key][round] += 1
                        
                        # aggregate weight updates; copy new weights to clients
                        server.aggregate_weight_updates(clients)
                        server.copy_weights(clients)
                        print(round)
                print(trial)

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
0
0
1
2
3
4
5


KeyboardInterrupt: ignored

In [None]:
detect_counter

defaultdict(<function __main__.<lambda>>, {})

In [None]:
detect_result

False

In [None]:
for key, value in client_dW_dicts[0].items():
    break

In [None]:
value

tensor([[[[-0.2654, -0.0482, -0.1496, -0.3325, -0.3268],
          [-0.1360, -0.0854, -0.1224, -0.2835, -0.2216],
          [-0.2093, -0.5409, -0.4456, -0.1587, -0.1258],
          [ 0.2039,  0.1777, -0.3295, -0.3602, -0.5164],
          [ 0.0288,  0.3881,  0.2621,  0.0013, -0.6501]]],


        [[[-0.1701, -0.1969,  0.0160, -0.2379, -0.5600],
          [-0.2053, -0.1223, -0.0209, -0.3452, -0.5991],
          [ 0.0276, -0.0528, -0.3311, -0.4907, -0.4634],
          [-0.2149, -0.4171, -0.1711, -0.3538, -0.1690],
          [-0.4859, -0.1660, -0.1544, -0.3510,  0.1855]]],


        [[[-0.4662, -0.3984, -0.0946, -0.0421, -0.5252],
          [-0.6843, -0.3122,  0.0254, -0.5276, -0.4501],
          [-0.6525, -0.2118, -0.5211, -0.5835, -0.0685],
          [-0.7327, -0.7185, -0.6613, -0.4881, -0.4985],
          [-0.5199, -0.2257, -0.1146, -0.2158, -0.6134]]],


        [[[-0.5786,  0.2038,  0.1686, -0.2504, -0.4746],
          [ 0.0479,  0.2259, -0.2854, -0.3259, -0.0872],
          [-0.2300,