In [None]:
%matplotlib inline

import torch
import torchvision
from torch.utils.data.dataset import Dataset
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

import numpy as np
import matplotlib.pyplot as plt
import random
from collections import defaultdict

from data_utils import CustomImageDataset, split_image_data
from data_utils import get_default_data_transforms
from models import ConvNet
from fl_devices import Server, Client
from helper import ExperimentLogger, display_train_stats

from sklearn.cluster import AgglomerativeClustering, DBSCAN
from sklearn.metrics import pairwise_distances
from sklearn.metrics import f1_score
from sklearn.decomposition import PCA


# Seed every RNG the notebook draws from (PyTorch, Python's `random`, NumPy) so that
# data splits, adversary assignment and weight initialisation are repeatable.
SEED = 0
for _seed_fn in (torch.manual_seed, random.seed, np.random.seed):
    _seed_fn(SEED)

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
# helper functions

def check_detect(detect_adv_idx, gt_adv_idx):
    """Tell whether the server flagged at least one true adversary.

    Args:
        detect_adv_idx: client indices the server detected as adversarial.
        gt_adv_idx: ground-truth indices of the adversarial clients.

    Returns:
        True if any ground-truth index also appears in detect_adv_idx, else False.
    """
    return any(idx in detect_adv_idx for idx in gt_adv_idx)
    
def generate_feature_matrix(dW_dicts):
    """Flatten each client's weight update into one row of a 2-D NumPy matrix.

    Args:
        dW_dicts: iterable of dicts mapping parameter name -> tensor, one dict per
            client. Rows only line up column-wise if every dict has the same keys,
            in the same order, with the same shapes (assumed, not checked).

    Returns:
        np.ndarray of shape (n_clients, n_params); row i is the concatenation of
        the flattened values of dW_dicts[i] in dict iteration order.

    Raises:
        RuntimeError: from torch.stack if dW_dicts is empty or rows differ in length.
    """
    with torch.no_grad():
        rows = []
        for dW_dict in dW_dicts:
            # A single cat per client; growing the row key by key re-copied the
            # accumulated tensor every time (quadratic in the number of keys).
            parts = [value.detach().flatten() for value in dW_dict.values()]
            rows.append(torch.cat(parts) if parts else torch.empty(0))
        # .cpu() is a no-op for CPU tensors, so no device branch is needed (the old
        # `device is "cpu"` compared strings by identity, not equality).
        return torch.stack(rows, 0).cpu().numpy()
        
def print_labels(labels):
    """Print every client's cluster label on a single tab-separated line.

    Each entry has the form ``<client index>: <label>``.
    """
    print('\t'.join(': '.join((str(idx), str(label))) for idx, label in enumerate(labels)))
    
def print_outliers(labels):
    """Print, and return, the indices whose label is -1 (outliers).

    -1 is the label scikit-learn's DBSCAN gives to noise samples.

    Args:
        labels: 1-D list or array of per-client integer labels.

    Returns:
        np.ndarray of the outlier indices (the same array that is printed).
    """
    # np.asarray makes the comparison element-wise for plain lists too; for a list,
    # `labels == -1` is just the scalar False and no outlier was ever reported.
    outlier_idx = np.argwhere(np.asarray(labels) == -1).flatten()
    print(outlier_idx)
    return outlier_idx
    
def print_distance(feature_matrix, metric):
    """Return the pairwise distance matrix between the rows of feature_matrix.

    Despite its name this function prints nothing; it returns the
    (n_rows, n_rows) matrix from sklearn.metrics.pairwise_distances.

    Args:
        feature_matrix: 2-D array, one row per client.
        metric: metric name accepted by pairwise_distances (e.g. 'cosine', 'l2').
    """
    # Bug fix: the keyword value was the undefined name `metirc`, so every call
    # raised NameError.
    return pairwise_distances(feature_matrix, metric=metric)

In [None]:
# Hyperparameters: federation size and number of adversarial clients per attack type.
N_CLIENT = 25       # total number of simulated clients
N_ADV_SWAP = 0      # clients later given client_mode 'swap'
N_ADV_OPP = 0       # clients later given client_mode 'opposite'
N_ADV_RANDOM = 3    # clients later given client_mode 'random'

In [None]:
# MNIST training split (train=True is torchvision's default, spelled out here).
# To use EMNIST instead: datasets.EMNIST(root="./", split="byclass", download=True)
data = datasets.MNIST(root='./', train=True, download=True)

In [None]:
# torchvision's MNIST object holds a single split (the training split here); its
# deprecated `train_data` / `test_data` attributes both alias `data.data`.
# "Train" and "test" below are therefore disjoint random subsets of that one split,
# kept small for fast training.
train_frac = 0.2
test_frac = 0.2
train_num = int(train_frac * len(data))
test_num = int(test_frac * len(data))
idcs = np.random.permutation(len(data))
train_idcs, test_idcs = idcs[:train_num], idcs[train_num:train_num + test_num]
# Labels for every sample in `data` (index with train_idcs / test_idcs);
# `targets` replaces the deprecated `train_labels` alias.
train_labels = data.targets.numpy()

In [None]:
# Split the training subset into N_CLIENT client shards via data_utils.split_image_data.
# `data.data` replaces the deprecated `train_data` alias (same tensor).
clients_split = split_image_data(data.data[train_idcs], train_labels[train_idcs], n_clients=N_CLIENT, classes_per_client=5, balancedness=1)

In [None]:
# NOTE(review): the "EMNIST" transforms are applied to MNIST images; confirm in
# data_utils.get_default_data_transforms that this normalisation is intended.
train_trans, val_trans = get_default_data_transforms("EMNIST")

# One training dataset per client shard: shard[0] holds images, shard[1] labels.
client_data = [
    CustomImageDataset(shard[0].to(torch.float32), shard[1], transforms=train_trans)
    for shard in clients_split
]

In [None]:
# Evaluate on the held-out random indices. The previous contiguous slice
# [train_num:train_num + test_num] ignored test_idcs and overlapped the randomly
# chosen training indices, leaking training samples into the test set.
# (`data.data` is the non-deprecated name for `test_data`/`train_data`.)
test_images = data.data[test_idcs]
test_labels = train_labels[test_idcs]
test_data = CustomImageDataset(test_images.to(torch.float32), test_labels, transforms=val_trans)

In [None]:
# Build one client per data shard, then mark a random subset as adversaries: the
# first N_ADV_RANDOM entries of a random permutation get 'random', the next
# N_ADV_OPP 'opposite', the next N_ADV_SWAP 'swap' (disjoint groups).
# NOTE(review): every other client keeps Client's default client_mode - confirm in fl_devices.
clients = [
    Client(ConvNet, lambda x : torch.optim.SGD(x, lr=0.1, momentum=0.9), dat, idnum=i)
    for i, dat in enumerate(client_data)
]
client_indx = np.random.permutation(len(clients))

adv_random = client_indx[:N_ADV_RANDOM]
adv_opp = client_indx[N_ADV_RANDOM:N_ADV_RANDOM + N_ADV_OPP]
adv_swap = client_indx[N_ADV_RANDOM + N_ADV_OPP:N_ADV_RANDOM + N_ADV_OPP + N_ADV_SWAP]
offset = N_ADV_RANDOM + N_ADV_OPP + N_ADV_SWAP
adv_idx = np.concatenate((adv_random, adv_opp, adv_swap)).tolist()

for mode, members in (('random', adv_random), ('opposite', adv_opp), ('swap', adv_swap)):
    for i in members:
        clients[i].client_mode = mode

# Show each client's index and mode.
for idx, client in enumerate(clients):
    print('{}: {}'.format(idx, client.client_mode))

server = Server(ConvNet, test_data)

In [None]:
# Hyperparameters for the first experiment.
TOTAL_ROUND = 20

# Passed to server.detect_adversary(feature_matrix, esp, min_samples, metric).
# NOTE(review): `esp` looks like a DBSCAN-style `eps` radius and `metric` a
# pairwise-distance name - confirm against fl_devices.Server.detect_adversary.
metric = 'l2'
min_samples = 2
esp = 2.0

cfl_stats = ExperimentLogger()

In [None]:
# `round_idx` instead of `round`, which shadowed the builtin round() for the
# rest of the notebook.
for round_idx in range(TOTAL_ROUND):
    # Round 0 only: synchronize every client with the server.
    if round_idx == 0:
        for client in clients:
            client.synchronize_with_server(server)

    participating_clients = server.select_clients(clients, frac=1.0)

    # Local training; the returned training stats are not used here.
    for client in participating_clients:
        client.compute_weight_update(epochs=1)
        client.reset()

    # Feature matrix for clustering: one row of flattened dW per client.
    client_dW_dicts = [client.dW for client in clients]
    feature_matrix = generate_feature_matrix(client_dW_dicts)
    print("feature matrix max")
    print(feature_matrix.max())

    # Labels the server's detector assigns to each client.
    clustering_labels = server.detect_adversary(feature_matrix, esp, min_samples, metric)

    # Aggregate weight updates; copy new weights to clients.
    server.aggregate_weight_updates(clients)
    server.copy_weights(clients)

    acc_clients = [client.evaluate() for client in clients]
    cfl_stats.log({"acc_clients" : acc_clients, "rounds" : round_idx})

    print("round %d"%(round_idx))
    print("labels assigned to clients:")
    print_labels(clustering_labels)
    print('detected outliers:')
    print_outliers(clustering_labels)

# Experiment A: outlier-detection F1 score vs. number of communication rounds


In [None]:
# Reload MNIST (training split) and draw a fresh random train/test split for Experiment A.
data = datasets.MNIST(root='./', download=True)

train_frac = 0.2
test_frac = 0.2
train_num = int(train_frac * len(data))
test_num = int(test_frac * len(data))
idcs = np.random.permutation(len(data))
train_idcs, test_idcs = idcs[:train_num], idcs[train_num:train_num + test_num]
# Labels for every sample in `data`; `targets` replaces the deprecated `train_labels`.
train_labels = data.targets.numpy()

# Rebuild the transforms here so this cell does not rely on an earlier cell's val_trans.
train_trans, val_trans = get_default_data_transforms("EMNIST")

# Use the held-out indices: the old contiguous slice [train_num:train_num + test_num]
# overlapped the random training indices and leaked training samples into the test set.
test_images = data.data[test_idcs]
test_labels = train_labels[test_idcs]
test_data = CustomImageDataset(test_images.to(torch.float32), test_labels, transforms=val_trans)

In [None]:
def init_clients(n_clients, adv_config):
    """Create federated clients and mark a random, disjoint subset as adversarial.

    Training images/labels come from the module-level ``data``, ``train_idcs`` and
    ``train_labels``.

    Args:
        n_clients: number of client shards to request from split_image_data.
        adv_config: ``(n_adv_rand, n_adv_oppo, n_adv_swap)`` - how many clients get
            client_mode 'random', 'opposite' and 'swap'. They are taken, in that
            order, from consecutive chunks of a random permutation of client indices.

    Returns:
        List of Client objects; clients not chosen as adversaries keep Client's
        default client_mode.

    Raises:
        ValueError: if any count in adv_config is negative or their sum exceeds the
            number of clients (slicing previously dropped the surplus silently).
    """
    # `data.data` replaces the deprecated `train_data` alias.
    clients_split = split_image_data(data.data[train_idcs], train_labels[train_idcs], n_clients=n_clients, classes_per_client=5, balancedness=1)

    train_trans, _ = get_default_data_transforms("EMNIST")
    client_data = [CustomImageDataset(shard[0].to(torch.float32), shard[1], transforms=train_trans)
                   for shard in clients_split]

    clients = [Client(ConvNet, lambda x : torch.optim.SGD(x, lr=0.1, momentum=0.9), dat, idnum=i)
               for i, dat in enumerate(client_data)]

    n_adv_rand, n_adv_oppo, n_adv_swap = adv_config
    if min(adv_config) < 0 or sum(adv_config) > len(clients):
        raise ValueError(
            "adv_config {} must be non-negative and sum to at most {} clients".format(
                tuple(adv_config), len(clients)))

    # Assign client modes from consecutive, non-overlapping chunks of a random permutation.
    client_indx = np.random.permutation(len(clients))
    start = 0
    for mode, count in (('random', n_adv_rand), ('opposite', n_adv_oppo), ('swap', n_adv_swap)):
        for i in client_indx[start:start + count]:
            clients[i].client_mode = mode
        start += count

    return clients

def init_server():
    """Return a new Server built from ConvNet and the module-level test_data."""
    return Server(ConvNet, test_data)

In [None]:
def generate_feature_matrix(dW_dicts):
    """Flatten each client's weight update into one row of a 2-D NumPy matrix.

    NOTE: this re-defines the helper of the same name from the helper cell so this
    section can run on its own; keep both definitions identical.

    Args:
        dW_dicts: iterable of dicts mapping parameter name -> tensor, one dict per
            client. Rows only line up column-wise if every dict has the same keys,
            in the same order, with the same shapes (assumed, not checked).

    Returns:
        np.ndarray of shape (n_clients, n_params); row i is the concatenation of
        the flattened values of dW_dicts[i] in dict iteration order.

    Raises:
        RuntimeError: from torch.stack if dW_dicts is empty or rows differ in length.
    """
    with torch.no_grad():
        rows = []
        for dW_dict in dW_dicts:
            # A single cat per client; growing the row key by key re-copied the
            # accumulated tensor every time (quadratic in the number of keys).
            parts = [value.detach().flatten() for value in dW_dict.values()]
            rows.append(torch.cat(parts) if parts else torch.empty(0))
        # .cpu() is a no-op for CPU tensors, so no device branch is needed (the old
        # `device is "cpu"` compared strings by identity, not equality).
        return torch.stack(rows, 0).cpu().numpy()

def compute_gt_labels(clients):
    """Return ground-truth outlier labels: 0 for 'normal' clients, -1 for all others.

    The -1 / 0 encoding matches the binarised predictions used by compute_f1.
    """
    return [0 if client.client_mode == 'normal' else -1 for client in clients]
    
def compute_f1(pred_labels, gt_labels, pos_label=-1):
    """F1 score of outlier detection, with label -1 ("outlier") as the positive class.

    Args:
        pred_labels: detector labels; -1 means outlier, every other label is
            collapsed to 0 ("inlier").
        gt_labels: ground-truth labels in {0, -1}, e.g. from compute_gt_labels.
        pos_label: label treated as positive. Defaults to -1; sklearn's own
            default of 1 never occurs in these labels, which made f1_score raise
            ValueError whenever both -1 and 0 were present.

    Returns:
        float F1 score as computed by sklearn.metrics.f1_score.
    """
    pred_binary = [-1 if label == -1 else 0 for label in pred_labels]
    # sklearn's signature is f1_score(y_true, y_pred, ...): ground truth goes first.
    return f1_score(gt_labels, pred_binary, pos_label=pos_label)

In [None]:
TOTAL_ROUND = 30
TOTAL_TRIAL = 30

# Clustering settings passed to server.detect_adversary.
esp = 0.8
min_samples = 2
metric = 'cosine'

# Federation size and adversary mix (n_adv_rand, n_adv_oppo, n_adv_swap).
n_clients = 25
n_adv_rand = 3
n_adv_oppo = 0
n_adv_swap = 0
adv_config = (n_adv_rand, n_adv_oppo, n_adv_swap)

cfl_stats = ExperimentLogger()

# Per-round F1 summed over trials. Must be float: the old np.array([0] * TOTAL_ROUND)
# was an int array, so `f1_sum[r] += score` truncated every fractional F1 to an int.
# (Ground-truth labels are computed per trial from the fresh clients, not here from
# the previous section's `clients`.)
f1_sum = np.zeros(TOTAL_ROUND)

In [None]:
# For one (esp, min_samples, metric) combination, run TOTAL_TRIAL independent trials.
# Each trial builds fresh clients and a fresh server, clusters the clients' weight
# updates at every round, and adds that round's outlier-detection F1 to f1_sum.

for trial in range(TOTAL_TRIAL):
    # initialize server and clients
    clients = init_clients(n_clients, adv_config)
    server = init_server()

    # ground-truth labels: 0 for 'normal' clients, -1 otherwise
    gt_labels = compute_gt_labels(clients)

    # `round_idx` instead of `round`, which shadowed the builtin round().
    for round_idx in range(TOTAL_ROUND):
        print("Trial: {}, Round: {}".format(trial, round_idx))

        if round_idx == 0:
            for client in clients:
                client.synchronize_with_server(server)

        participating_clients = server.select_clients(clients, frac=1.0)
        for client in participating_clients:
            # the returned training stats are not used here
            client.compute_weight_update(epochs=1)
            client.reset()

        # generate feature matrix for clustering
        client_dW_dicts = [client.dW for client in clients]
        feature_matrix = generate_feature_matrix(client_dW_dicts)

        # detect outliers using clustering
        clustering_labels = server.detect_adversary(feature_matrix, esp, min_samples, metric)

        # aggregate weight updates; copy new weights to clients
        server.aggregate_weight_updates(clients)
        server.copy_weights(clients)

        score = compute_f1(clustering_labels, gt_labels, pos_label=-1)
        f1_sum[round_idx] += score