In [1]:
%matplotlib inline

import torch
import torchvision
from torch.utils.data.dataset import Dataset
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

import numpy as np
import matplotlib.pyplot as plt
from data_utils import CustomImageDataset, split_image_data
import random
from collections import defaultdict
from data_utils import get_default_data_transforms

from models import ConvNet
from fl_devices import Server, Client

torch.manual_seed(0)
random.seed(0)
np.random.seed(0)

In [2]:
# helper functions

# detect_adv_idx: adverary indices detected by server
# gt_adv_idx: ground-truth indices
def check_detect(detect_adv_idx, gt_adv_idx):
    intersection = [idx for idx in gt_adv_idx if idx in detect_adv_idx]
    if len(intersection) > 0:
        return True
    else:
        return False
    
# feature_matrix:
# each row is flatten dWs from a client
def generate_feature_matrix(dW_dicts):
    rows = []
    
    for dW_dict in dW_dicts:
        row = torch.empty(0)
        for key, value in dW_dict:
            row = torch.cat((row, value.flatten()), 0)
        rows.append(row)
        
    matrix = torch.stack(rows, 0)
    return matrix.numpy()

In [3]:
# hyperparameters
N_CLIENT = 25
N_ADV = 1

data = datasets.EMNIST(root="./", split="byclass",download=True)


In [4]:
mapp = np.array(['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C',
       'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P',
       'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c',
       'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p',
       'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z'], dtype='<U1')
import numpy as np
from torch.utils.data import Subset

In [5]:
## It seems that data.train_data and data.test_data are the same
train_frac = 0.8
test_frac = 0.2 
train_num = int(train_frac * len(data))
test_num = int(test_frac * len(data))
idcs = np.random.permutation(len(data))
train_idcs, test_idcs = idcs[:train_num], idcs[train_num:train_num + test_num]
train_labels = data.train_labels.numpy()



In [6]:
clients_split = split_image_data(data.train_data[train_idcs], train_labels[train_idcs], n_clients=N_CLIENT, classes_per_client=4,balancedness=1)



Data split:
 - Client 0: [3701    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    0    0    0    0    0 5583
 2252 2339 2154 2238 1908 2158]
 - Client 1: [   0    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    0    0    0    0    0    0
 5583 5583 5583 3735 1849    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0]
 - Client 2: [   0    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0 3004 1981 4085 5583 5583 2097    0    0    0
    0    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0]
 

In [7]:
train_trans, val_trans = get_default_data_transforms("EMNIST")
client_data = [CustomImageDataset(clients_split[i][0].to(torch.float32), clients_split[i][1],transforms=train_trans ) for i in range(len(clients_split))]


Data preprocessing: 
 - ToPILImage()
 - Resize(size=(28, 28), interpolation=PIL.Image.BILINEAR)
 - ToTensor()
 - Normalize(mean=(0.06078,), std=(0.1957,))



In [8]:
test_data = data.test_data[train_num:train_num+test_num]
test_labels = train_labels[train_num:train_num+test_num]
test_data = CustomImageDataset(test_data.to(torch.float32), test_labels, transforms=val_trans)



In [9]:
# dataset preprocess
# TODO
clients = [Client(ConvNet, lambda x : torch.optim.SGD(x, lr=0.1, momentum=0.9), client_data[i], idnum=i) 
           for i, dat in enumerate(client_data)]
server = Server(ConvNet, test_data)

adv_idx = []

In [10]:
# hyperparemeters
TOTAL_TRIAL = 30
TOTAL_ROUND = 20

esp_vals = [0.1, 0.2, 0.3, 0.4, 0.5]
min_samples_vals = [2, 3, 4, 5, 6]
metric_vals = ['l1', 'l2', 'cosine']

detect_counter = defaultdict(lambda: [0] * TOTAL_ROUND)

In [12]:
# for each combination of esp, min_samples, metric
# run multiple trials, do clustering at each round
# to find the round number before clustering that have best adversary-identificaiton rate
# no need to handle any adversary detected yet

for esp in esp_vals:
    for min_samples in min_samples_vals:
        for metric in metric_vals:
            counter_key = (esp, min_samples, metric)
            
            # run multiple trials
            for trial in range(TOTAL_TRIAL):
                # communication rounds in FL
                for round in range(TOTAL_ROUND):
                        if round == 0:
                            for client in clients:
                                client.synchronize_with_server(server)

                        participating_clients = server.select_clients(clients, frac=1.0)

                        for client in participating_clients:
                            train_stats = client.compute_weight_update(epochs=1)
                            client.reset()

                        # generate feature matrix for clustering
                        client_dW_dicts = [client.dW for client in clients]
                        feature_matrix = generate_feature_matrix(client_dW_dicts)

                        # detect adversary using clustering
                        detect_adv_idx = server.detect_adversary(feature_matrix, esp, min_samples, metric)
                        detect_result = check_detect(detect_adv_idx, adv_idx)

                        if detect_result:
                            detect_counter[counter_key][round] += 1
                        
                        # aggregate weight updates; copy new weights to clients
                        server.aggregate_weight_updates(clients)
                        server.copy_weights(clients)
                print(trial)

ValueError: too many values to unpack (expected 2)

In [13]:
trial


0

In [25]:
for x, y in clients[0].train_loader:
    break

In [26]:
x.shape

torch.Size([128, 1, 32, 32])

In [30]:
clients[0].model(x).shape

torch.Size([200, 62])