In [1]:
%matplotlib inline

import torch
import torchvision
from torch.utils.data.dataset import Dataset
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

from helper import ExperimentLogger, display_train_stats

import numpy as np
import matplotlib.pyplot as plt
from data_utils import CustomImageDataset, split_image_data
import random
from collections import defaultdict
from data_utils import get_default_data_transforms

from models import ConvNet
from fl_devices import Server, Client

# Seed the three RNGs used in this notebook (PyTorch, Python's `random`, NumPy)
# so that their random draws start from a fixed state.
torch.manual_seed(0)
random.seed(0)
np.random.seed(0)
# Use the GPU when PyTorch reports one as available, otherwise the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"


In [2]:
# helper functions

# detect_adv_idx: adversary indices detected by server
# gt_adv_idx: ground-truth indices
def check_detect(detect_adv_idx, gt_adv_idx):
    """Tell whether at least one true adversary was flagged.

    detect_adv_idx: indices reported as adversarial.
    gt_adv_idx: ground-truth adversary indices.

    Returns True when some ground-truth index also appears in
    ``detect_adv_idx``; False otherwise (including when either is empty).
    """
    return any(idx in detect_adv_idx for idx in gt_adv_idx)
    
# feature_matrix:
# each row is the flattened dWs from one client
def generate_feature_matrix(dW_dicts):
    """Flatten each client's weight-update dict into one row of a 2-D array.

    Args:
        dW_dicts: iterable of dicts whose values are torch tensors. Within one
            dict the tensors are flattened and concatenated in the dict's
            iteration order.

    Returns:
        numpy.ndarray with one row per dict in ``dW_dicts``.

    Raises:
        RuntimeError: propagated from ``torch.stack`` when ``dW_dicts`` is
            empty or the rows do not all have the same length.
    """
    with torch.no_grad():
        rows = []
        for dW_dict in dW_dicts:
            pieces = [value.flatten() for value in dW_dict.values()]
            # torch.cat rejects an empty list, so an empty dict maps to an empty row.
            row = torch.cat(pieces, 0) if pieces else torch.empty(0)
            # .cpu() is a no-op for tensors already in host memory, so no
            # device check (and no reliance on a global `device`) is needed.
            rows.append(row.cpu())

        return torch.stack(rows, 0).numpy()

In [14]:
# hyperparameters
# Number of simulated clients the data is split across.
N_CLIENT = 25
# How many clients are switched to each adversarial mode ('random', 'opposite',
# 'swap') in the client-setup cell; with all three at 0 no client is switched.
N_ADV_RANDOM = 0
N_ADV_OPP = 0
N_ADV_SWAP = 0

# data = datasets.EMNIST(root="./", split="byclass",download=True)
# Load MNIST from the current directory, downloading it first when needed.
# `train` is not passed, so torchvision's default split is used.
data = datasets.MNIST(root='./',download=True)


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./MNIST/raw/train-images-idx3-ubyte.gz to ./MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./MNIST/raw/train-labels-idx1-ubyte.gz



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./MNIST/raw/train-labels-idx1-ubyte.gz to ./MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./MNIST/raw/t10k-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./MNIST/raw/t10k-images-idx3-ubyte.gz to ./MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./MNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./MNIST/raw/t10k-labels-idx1-ubyte.gz to ./MNIST/raw
Processing...
Done!


  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


In [15]:
# Index -> character table: the ten digits, then A-Z, then a-z (62 entries).
# Presumably the EMNIST "byclass" label map; no other visible cell reads it.
mapp = np.array(
    list('0123456789'
         'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
         'abcdefghijklmnopqrstuvwxyz'),
    dtype='<U1',
)


In [16]:
## data.train_data and data.test_data return the same tensor for this dataset
## object, so the train and test subsets are both carved out of `data` below.
## small data for fast training
train_frac = 0.2
test_frac = 0.2
train_num = int(train_frac * len(data))
test_num = int(test_frac * len(data))
# Slicing past the end would silently shrink the test subset.
assert train_num + test_num <= len(data), "train_frac + test_frac must not exceed 1"

# One shuffled index array cut into two consecutive slices, so the train and
# test indices are disjoint.
idcs = np.random.permutation(len(data))
train_idcs, test_idcs = idcs[:train_num], idcs[train_num:train_num + test_num]

# Labels of the whole dataset (indexed later with train_idcs).
# `targets` is the non-deprecated name for what `train_labels` returned.
train_labels = data.targets.numpy()



In [18]:
# Distribute the training subset over N_CLIENT clients; the printed summary
# below shows the per-class sample counts each client received.
# `data.data` is the non-deprecated name for what `data.train_data` returned.
clients_split = split_image_data(
    data.data[train_idcs],
    train_labels[train_idcs],
    n_clients=N_CLIENT,
    classes_per_client=5,
    balancedness=1,
)

Data split:
 - Client 0: [ 0  0  0 96 96 96 96 96  0  0]
 - Client 1: [96 96  0  0  0  0  0 96 96 96]
 - Client 2: [ 0  0  0  0  0 96 96 96 96 96]
 - Client 3: [ 0 96 96 96 96 96  0  0  0  0]
 - Client 4: [96 96 96  0  0  0  0  0 96 96]
 - Client 5: [ 0  0  0  0 96 96 96 96 96  0]
 - Client 6: [ 0 96 96 96 96 96  0  0  0  0]
 - Client 7: [ 0 96 96 96 96 96  0  0  0  0]
 - Client 8: [ 0  0 96 96 96 96 96  0  0  0]
 - Client 9: [96 96 96 96 96  0  0  0  0  0]
 - Client 10: [96  0  0  0  0  0 96 96 96 96]
 - Client 11: [ 0 96 96 96 96 96  0  0  0  0]
 - Client 12: [ 0  0  0 96 96 96 96 96  0  0]
 - Client 13: [ 0  0 96 96 96 96 96  0  0  0]
 - Client 14: [ 0 96 96 96 96 96  0  0  0  0]
 - Client 15: [78  0  0  0  0 18 96 96 96 96]
 - Client 16: [ 0  0  0 96 96  0 96 96 96  0]
 - Client 17: [96 96  0  0  0  0  0 96 96 96]
 - Client 18: [ 0  0 96 96  8  0 96 96 88  0]
 - Client 19: [96 96  0  0  0  0  0 96 96 96]
 - Client 20: [96 96  0  0  0  0  0 96 96 96]
 - Client 21: [ 0 96 96 66  0  0



In [19]:
# NOTE(review): the transforms are requested under the "EMNIST" key although
# `data` is MNIST — confirm the normalisation constants are the intended ones.
train_trans, val_trans = get_default_data_transforms("EMNIST")

# One dataset per client shard: element 0 of a shard is cast to float32 and
# used as the images, element 1 as the labels; training transforms are applied.
client_data = [
    CustomImageDataset(shard[0].to(torch.float32), shard[1], transforms=train_trans)
    for shard in clients_split
]


Data preprocessing: 
 - ToPILImage()
 - Resize(size=(28, 28), interpolation=PIL.Image.BILINEAR)
 - ToTensor()
 - Normalize(mean=(0.06078,), std=(0.1957,))



In [20]:
# Evaluation set: the samples at `test_idcs`. Those indices come from the same
# permutation as `train_idcs` but from a later, non-overlapping slice, so no
# evaluation sample is also a training sample. (Slicing raw positions
# train_num:train_num+test_num would ignore the shuffle and could pick samples
# that are part of the training subset.)
test_images = data.data[test_idcs]
test_labels = train_labels[test_idcs]
test_data = CustomImageDataset(test_images.to(torch.float32), test_labels, transforms=val_trans)



In [21]:
# dataset preprocess
# TODO

# Optimizer factory handed to every client: SGD with momentum on the given parameters.
def sgd_factory(params):
    return torch.optim.SGD(params, lr=0.1, momentum=0.9)

# One client per data shard, numbered in shard order.
clients = [Client(ConvNet, sgd_factory, shard, idnum=client_id)
           for client_id, shard in enumerate(client_data)]

# Shuffle the client ids, then hand out consecutive, non-overlapping slices of
# the shuffled ids to the three adversarial modes.
client_indx = np.random.permutation(len(clients))
offset = 0
adv_random = client_indx[offset:offset + N_ADV_RANDOM]
offset += N_ADV_RANDOM
adv_opp = client_indx[offset:offset + N_ADV_OPP]
offset += N_ADV_OPP
adv_swap = client_indx[offset:offset + N_ADV_SWAP]
offset += N_ADV_SWAP

# Ground-truth list of all adversarial client ids.
adv_idx = np.concatenate((adv_random, adv_opp, adv_swap)).tolist()

# Switch the chosen clients to their adversarial mode.
for mode, members in (('random', adv_random), ('opposite', adv_opp), ('swap', adv_swap)):
  for member in members:
    clients[member].client_mode = mode

# print out each client and its mode
for idx, client in enumerate(clients):
  print('{}: {}'.format(idx, client.client_mode))

server = Server(ConvNet, test_data)

In [22]:
# hyperparameters
# Trials per parameter combination, and communication rounds per trial.
TOTAL_TRIAL = 30
TOTAL_ROUND = 20

# Candidate values passed to server.detect_adversary in the grid-search cell
# (presumably a neighbourhood radius, a minimum cluster size and a distance
# metric for density-based clustering — TODO confirm against fl_devices).
esp_vals = [0.1, 0.2, 0.3, 0.4, 0.5]
min_samples_vals = [2, 3, 4, 5, 6]
metric_vals = ['l1', 'l2', 'cosine']

# Maps an (esp, min_samples, metric) key to a list with one counter per round;
# a key that has not been seen yet starts out as TOTAL_ROUND zeros.
detect_counter = defaultdict(lambda: [0] * TOTAL_ROUND)

In [None]:
def print_labels(labels):
  """Print each label as 'position: label', all on one tab-separated line."""
  print('\t'.join('{!s}: {!s}'.format(idx, label) for idx, label in enumerate(labels)))

def print_outliers(labels):
  """Print the positions at which `labels` equals -1."""
  is_outlier = labels == -1
  print(np.argwhere(is_outlier).reshape(-1))

In [24]:
# Run the federated rounds once with a single clustering configuration and
# print, per round, the label each client was assigned and which clients were
# labelled -1 (outliers). detect_counter is not updated by this cell.
esp = 0.8
min_samples = 2
metric = 'l2'
cfl_stats = ExperimentLogger()
counter_key = (esp, min_samples, metric)

# The loop variable is `round_idx`, not `round`, so the built-in round() is
# not shadowed for the rest of the kernel session.
for round_idx in range(TOTAL_ROUND):
    # Clients are synchronised with the server only before the first round;
    # later rounds rely on server.copy_weights at the end of the loop body.
    if round_idx == 0:
        for client in clients:
            client.synchronize_with_server(server)

    participating_clients = server.select_clients(clients, frac=1.0)

    for client in participating_clients:
        client.compute_weight_update(epochs=1)
        client.reset()

    # generate feature matrix for clustering
    client_dW_dicts = [client.dW for client in clients]
    feature_matrix = generate_feature_matrix(client_dW_dicts)
    print("feature matrix max")
    print(feature_matrix.max())

    # return labels assigned to clients
    clustering_labels = server.detect_adversary(feature_matrix, esp, min_samples, metric)

    # aggregate weight updates; copy new weights to clients
    server.aggregate_weight_updates(clients)
    server.copy_weights(clients)

    acc_clients = [client.evaluate() for client in clients]
    cfl_stats.log({"acc_clients" : acc_clients, "rounds" : round_idx})

    print("round %d"%(round_idx))

    print("labels assigned to clients:")
    print_labels(clustering_labels)
    print('detected outliers:')
    print_outliers(clustering_labels)

feature matrix max
0.37492388
round 0
detection results
[ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
 24]
[0.8541666666666666, 0.9270833333333334, 0.84375, 0.8645833333333334, 0.8541666666666666, 0.8333333333333334, 0.9166666666666666, 0.84375, 0.8020833333333334, 0.90625, 0.8854166666666666, 0.9270833333333334, 0.8958333333333334, 0.78125, 0.9166666666666666, 0.875, 0.8020833333333334, 0.8333333333333334, 0.8333333333333334, 0.8854166666666666, 0.90625, 0.8541666666666666, 0.875, 0.9270833333333334, 0.875]
feature matrix max
0.3646883
round 1
detection results
[ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
 24]
[0.8541666666666666, 0.9375, 0.8541666666666666, 0.8541666666666666, 0.8645833333333334, 0.8333333333333334, 0.9166666666666666, 0.875, 0.8125, 0.8958333333333334, 0.875, 0.9270833333333334, 0.875, 0.8125, 0.9270833333333334, 0.8854166666666666, 0.8020833333333334, 0.8333333333333334, 0.84375, 0.8854166666666666, 0.90625, 0.

KeyboardInterrupt: ignored

In [None]:
# Show the per-round counters stored under the current counter_key. Indexing the
# defaultdict with a key it has not seen yet inserts a fresh all-zero list.
detect_counter[counter_key]

[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]

In [None]:
# Collect what evaluate() returns for every client, in client order.
acc_clients = [client.evaluate() for client in clients]


In [None]:
# Largest entry of the most recently computed feature matrix.
feature_matrix.max()

0.0

In [None]:
# NOTE(review): `detect_adv_idx` is only assigned in the grid-search cell further
# down (its assignment in the single-configuration loop is commented out), so on
# a fresh top-to-bottom run this cell raises NameError.
detect_adv_idx

array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
       17, 18, 19, 20, 21, 22, 23, 24])

In [None]:
# for each combination of esp, min_samples, metric
# run multiple trials, do clustering at each round
# to find the round number before clustering that has the best adversary-identification rate
# no need to handle any adversary detected yet
#
# NOTE(review): `server` and `clients` are not re-created inside this cell, so
# every trial (and every parameter combination) appears to continue from the
# state the previous one left behind — confirm whether trials are meant to
# start from fresh models.
# NOTE(review): the return value of server.detect_adversary is used here as a
# collection of adversary indices, while the single-configuration cell treats
# it as per-client labels — confirm which one it is.

for esp in esp_vals:
    for min_samples in min_samples_vals:
        for metric in metric_vals:
            counter_key = (esp, min_samples, metric)

            # run multiple trials
            for trial in range(TOTAL_TRIAL):
                # communication rounds in FL; `round_idx` rather than `round`
                # so the built-in round() is not shadowed
                for round_idx in range(TOTAL_ROUND):
                    if round_idx == 0:
                        for client in clients:
                            client.synchronize_with_server(server)

                    participating_clients = server.select_clients(clients, frac=1.0)

                    for client in participating_clients:
                        client.compute_weight_update(epochs=1)
                        client.reset()

                    # generate feature matrix for clustering
                    client_dW_dicts = [client.dW for client in clients]
                    feature_matrix = generate_feature_matrix(client_dW_dicts)

                    # detect adversary using clustering
                    detect_adv_idx = server.detect_adversary(feature_matrix, esp, min_samples, metric)
                    detect_result = check_detect(detect_adv_idx, adv_idx)

                    # count the rounds in which at least one true adversary was flagged
                    if detect_result:
                        detect_counter[counter_key][round_idx] += 1

                    # aggregate weight updates; copy new weights to clients
                    server.aggregate_weight_updates(clients)
                    server.copy_weights(clients)
                    print(round_idx)
                print(trial)

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
0
0
1
2
3
4
5


KeyboardInterrupt: ignored

In [None]:
# Show every (esp, min_samples, metric) key recorded so far with its per-round counters.
detect_counter

defaultdict(<function __main__.<lambda>>, {})

In [None]:
# Outcome of check_detect for the last round the grid-search cell executed.
detect_result

False

In [None]:
# Bind `key` and `value` to the first entry of client 0's weight-update dict;
# the loop exits right after the first iteration.
for key, value in client_dW_dicts[0].items():
    break

In [None]:
# Display the tensor picked out by the previous cell.
value

tensor([[[[-0.2654, -0.0482, -0.1496, -0.3325, -0.3268],
          [-0.1360, -0.0854, -0.1224, -0.2835, -0.2216],
          [-0.2093, -0.5409, -0.4456, -0.1587, -0.1258],
          [ 0.2039,  0.1777, -0.3295, -0.3602, -0.5164],
          [ 0.0288,  0.3881,  0.2621,  0.0013, -0.6501]]],


        [[[-0.1701, -0.1969,  0.0160, -0.2379, -0.5600],
          [-0.2053, -0.1223, -0.0209, -0.3452, -0.5991],
          [ 0.0276, -0.0528, -0.3311, -0.4907, -0.4634],
          [-0.2149, -0.4171, -0.1711, -0.3538, -0.1690],
          [-0.4859, -0.1660, -0.1544, -0.3510,  0.1855]]],


        [[[-0.4662, -0.3984, -0.0946, -0.0421, -0.5252],
          [-0.6843, -0.3122,  0.0254, -0.5276, -0.4501],
          [-0.6525, -0.2118, -0.5211, -0.5835, -0.0685],
          [-0.7327, -0.7185, -0.6613, -0.4881, -0.4985],
          [-0.5199, -0.2257, -0.1146, -0.2158, -0.6134]]],


        [[[-0.5786,  0.2038,  0.1686, -0.2504, -0.4746],
          [ 0.0479,  0.2259, -0.2854, -0.3259, -0.0872],
          [-0.2300,