In [None]:
%matplotlib inline

import torch
import torchvision
from torch.utils.data import DataLoader

import numpy as np
import matplotlib.pyplot as plt
import random
from collections import defaultdict

from models import ConvNet
from fl_devices import Server, Client

# Seed every RNG the experiments draw from (torch, Python's `random`, NumPy's
# global legacy generator) so that runs are reproducible.
SEED = 0
torch.manual_seed(SEED)
random.seed(SEED)
np.random.seed(SEED)

In [None]:
# helper functions

def generate_feature_matrix(dW_dicts):
    """Build a feature matrix whose i-th row is client i's flattened weight update.

    Args:
        dW_dicts: sequence of mappings, one per client, from parameter name to a
            tensor holding that client's weight update. All mappings are assumed
            to share the same keys, key order and tensor sizes so that columns
            line up across rows (assumption; not checked here).

    Returns:
        numpy.ndarray with one row per entry of ``dW_dicts``. Each row is the
        concatenation of that mapping's tensors, flattened, in iteration order.

    Raises:
        RuntimeError: from ``torch.stack`` if ``dW_dicts`` is empty or the rows
            have different lengths.
    """
    rows = []
    for dW_dict in dW_dicts:
        # Iterate .values(): looping over a dict directly yields only its keys.
        # detach() so the final .numpy() also works on gradient-tracking tensors.
        flat = [value.detach().flatten() for value in dW_dict.values()]
        # torch.cat rejects an empty list; keep an empty row for an empty mapping.
        rows.append(torch.cat(flat, 0) if flat else torch.empty(0))
    # One concatenation per row (no quadratic re-copying); .cpu() so that
    # .numpy() also works when the updates live on a GPU.
    return torch.stack(rows, 0).cpu().numpy()

def handle_adversary(clients, adv_idx, handle):
    """Apply an adversary-handling strategy to a list of clients.

    Args:
        clients: list of clients; the entries of ``adv_idx`` are positions in it.
        adv_idx: iterable of integer indices of the clients flagged as adversarial.
        handle: strategy to apply. ``None`` leaves ``clients`` untouched and
            ``'remove'`` drops the flagged clients.

    Returns:
        For ``None``, the very same ``clients`` object. For ``'remove'``, a new
        list holding the unflagged clients in their original order.

    Raises:
        NotImplementedError: for any other ``handle`` (e.g. ``'reg'``, still TODO).
            Previously such values silently produced ``None`` instead of a list.
    """
    if handle is None:
        return clients
    if handle == 'remove':
        flagged = set(adv_idx)  # O(1) membership test per client
        return [client for idx, client in enumerate(clients) if idx not in flagged]
    # TODO: implement the remaining strategies (ADV_HANDLE also lists 'reg').
    raise NotImplementedError(f"adversary handling {handle!r} is not implemented")

In [None]:
# Federation size. N_ADV is presumably the number of adversarial clients to
# inject (TODO confirm — no visible cell reads it yet).
N_CLIENT, N_ADV = 25, 1



In [None]:
# Dataset preprocessing / client construction.
# TODO: load and partition the dataset and give each client its data and model
#       (the Client constructor arguments are not settled yet).
# NOTE(review): the training cells below also use `server`, which no cell
#       creates yet. TODO: construct it here, e.g. `server = Server(...)`.

# One independent Client per slot. The previous `[Client() * N_CLIENT]` built a
# one-element list by trying to multiply a single Client by N_CLIENT.
clients = [Client() for _ in range(N_CLIENT)]

# Presumably the indices of the clients made adversarial (N_ADV of them). TODO.
adv_idx = []

## Experiment A: Model Performance
Compare the model trained with each adversary-handling strategy in `ADV_HANDLE` against the model trained without any handling (`None`).

Same number of rounds (`TOTAL_ROUND`) → compare the resulting accuracy.

In [None]:
# Hyperparameters for Experiment A (fixed number of training rounds).
TOTAL_TRIAL, TOTAL_ROUND, DETECT_ROUND = 30, 20, 5

# Clustering parameters for adversary detection. Presumably eps / min_samples /
# distance metric of a density-based clusterer; confirm against
# Server.detect_adversary.
ESP, MIN_SAMPLES, METRIC = 0.5, 2, 'l2'

# Adversary-handling strategies to compare (None = no handling).
ADV_HANDLE = [None, 'remove', 'reg']

# model_performance[handle][trial] holds that trial's result. Each new handle
# key starts with TOTAL_TRIAL empty (None) slots.
model_performance = defaultdict(lambda: TOTAL_TRIAL * [None])

In [None]:
# Experiment A: for every adversary-handling strategy, run TOTAL_TRIAL trials of
# TOTAL_ROUND federated rounds each and record the final evaluation in
# model_performance[handle][trial].
# NOTE(review): `server` is not created in any visible cell. TODO: define it first.
for handle in ADV_HANDLE:
    for trial in range(TOTAL_TRIAL):
        # Per-trial copy, so that clients removed by `handle_adversary` do not
        # leak into later trials or into the next strategy.
        # TODO: also re-initialise the server/client models here so that trials
        # are independent (Server/Client construction is not shown yet).
        trial_clients = list(clients)

        for c_round in range(TOTAL_ROUND):
            # Only the initial synchronisation is gated on the first round; the
            # rest of the round body previously sat (wrongly) inside this `if`.
            if c_round == 0:
                for client in trial_clients:
                    client.synchronize_with_server(server)

            participating_clients = server.select_clients(trial_clients, frac=1.0)

            for client in participating_clients:
                client.compute_weight_update(epochs=1)
                client.reset()

            if c_round + 1 == DETECT_ROUND:
                # Feature matrix for clustering: row i is trial_clients[i]'s update,
                # so the detected indices refer to positions in trial_clients.
                client_dW_dicts = [client.dW for client in trial_clients]
                feature_matrix = generate_feature_matrix(client_dW_dicts)

                # detect adversaries using clustering
                detect_adv_idx = server.detect_adversary(feature_matrix, ESP, MIN_SAMPLES, METRIC)

                # apply the strategy under test (not the whole ADV_HANDLE list)
                trial_clients = handle_adversary(trial_clients, detect_adv_idx, handle)

            # aggregate weight updates; copy new weights to clients
            server.aggregate_weight_updates(trial_clients)
            server.copy_weights(trial_clients)

        # evaluate model performance after all the rounds
        model_performance[handle][trial] = None  # TODO: store the evaluation result (e.g. test accuracy)

In [None]:
# plots
# TODO


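## Experiment B: Convergence Rate
Compare the model trained with each adversary-handling strategy in `ADV_HANDLE` against the model trained without any handling (`None`).

Same target accuracy (`TARGET_ACC`) → compare the number of rounds needed to reach it.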

In [None]:
# Hyperparameters for Experiment B (rounds needed to reach a target accuracy).
TOTAL_TRIAL, MAX_ROUND, DETECT_ROUND = 30, 50, 5
TARGET_ACC = 0.66  # accuracy threshold the training loop compares against

# Clustering parameters for adversary detection. Presumably eps / min_samples /
# distance metric of a density-based clusterer; confirm against
# Server.detect_adversary.
ESP, MIN_SAMPLES, METRIC = 0.5, 2, 'l2'

# Adversary-handling strategies to compare (None = no handling).
ADV_HANDLE = [None, 'remove', 'reg']

# model_round[handle][trial] holds the round count for that trial. Each new
# handle key starts with TOTAL_TRIAL empty (None) slots.
model_round = defaultdict(lambda: TOTAL_TRIAL * [None])

In [None]:
# Experiment B: for every adversary-handling strategy, train until the model
# first reaches TARGET_ACC (at most MAX_ROUND rounds). Record the number of
# rounds taken in model_round[handle][trial]; it stays None if the target is
# never reached.
# NOTE(review): `server` is not created in any visible cell. TODO: define it first.
for handle in ADV_HANDLE:
    for trial in range(TOTAL_TRIAL):
        # Per-trial copy, so that clients removed by `handle_adversary` do not
        # leak into later trials or into the next strategy.
        # TODO: also re-initialise the server/client models here so that trials
        # are independent (Server/Client construction is not shown yet).
        trial_clients = list(clients)

        for c_round in range(MAX_ROUND):
            # Only the initial synchronisation is gated on the first round; the
            # rest of the round body previously sat (wrongly) inside this `if`.
            if c_round == 0:
                for client in trial_clients:
                    client.synchronize_with_server(server)

            participating_clients = server.select_clients(trial_clients, frac=1.0)

            for client in participating_clients:
                client.compute_weight_update(epochs=1)
                client.reset()

            if c_round + 1 == DETECT_ROUND:
                # Feature matrix for clustering: row i is trial_clients[i]'s update,
                # so the detected indices refer to positions in trial_clients.
                client_dW_dicts = [client.dW for client in trial_clients]
                feature_matrix = generate_feature_matrix(client_dW_dicts)

                # detect adversaries using clustering
                detect_adv_idx = server.detect_adversary(feature_matrix, ESP, MIN_SAMPLES, METRIC)

                # apply the strategy under test (not the whole ADV_HANDLE list)
                trial_clients = handle_adversary(trial_clients, detect_adv_idx, handle)

            # aggregate weight updates; copy new weights to clients
            server.aggregate_weight_updates(trial_clients)
            server.copy_weights(trial_clients)

            # Evaluate after each round. A local name is used so that Experiment A's
            # `model_performance` results dict is not overwritten.
            accuracy = None  # TODO: evaluation result (e.g. test accuracy)
            if accuracy is not None and accuracy >= TARGET_ACC:
                model_round[handle][trial] = c_round + 1
                break

In [None]:
# plots
# TODO
