In [49]:
%matplotlib inline

import torch
import torchvision
from torch.utils.data import DataLoader

import numpy as np
import matplotlib.pyplot as plt
import random
from collections import defaultdict

from models import ConvNet
from fl_devices import Server, Client

# Single source of truth for every RNG used in the notebook; change it here
# to re-seed torch, Python's `random`, and NumPy's legacy global RNG together.
SEED = 0
torch.manual_seed(SEED)
random.seed(SEED)
np.random.seed(SEED)

In [None]:
# helper functions

# check_detect: did the server flag at least one true adversary?
#   detect_adv_idx: adversary indices detected by the server
#   gt_adv_idx: ground-truth adversary indices
def check_detect(detect_adv_idx, gt_adv_idx):
    """Return True if any index in gt_adv_idx also appears in detect_adv_idx.

    detect_adv_idx only needs to support the `in` operator. The result is
    False whenever either collection is empty.
    """
    return any(idx in detect_adv_idx for idx in gt_adv_idx)
    
# feature_matrix:
# each row is the flattened dWs of one client
def generate_feature_matrix(dW_dicts):
    """Stack each client's flattened weight update into one matrix row.

    Args:
        dW_dicts: iterable of dicts, one per client, whose values are
            torch.Tensors (presumably parameter name -> weight delta; confirm
            against fl_devices.Client.dW). Every dict must flatten to the same
            total length, otherwise the rows cannot be stacked.

    Returns:
        numpy.ndarray with one row per dict; row i is the concatenation of the
        flattened tensors of dW_dicts[i], in the dict's iteration order.

    Raises:
        torch.stack raises if dW_dicts is empty or the rows differ in length.
    """
    rows = []
    for dW_dict in dW_dicts:
        # Iterating a dict directly yields only its keys, so take .values().
        # Gather all pieces and concatenate once instead of re-allocating the
        # row for every key.
        pieces = [value.detach().flatten() for value in dW_dict.values()]
        rows.append(torch.cat(pieces, 0) if pieces else torch.empty(0))

    matrix = torch.stack(rows, 0)
    # .numpy() only works on CPU tensors that are not part of an autograd graph.
    return matrix.cpu().numpy()

In [None]:
# hyperparameters: federation size and how many members are adversarial
# (N_ADV is presumably the number of entries adv_idx should hold — confirm)
N_CLIENT, N_ADV = 25, 1



In [None]:
# dataset preprocess
# TODO: build each client's data split and pass the real Client constructor
# arguments (model/optimizer/data) — Client() is a placeholder.
# One independent Client per slot. The old `[Client() * N_CLIENT]` multiplied a
# single Client by an int; `[Client()] * N_CLIENT` would alias one object.
clients = [Client() for _ in range(N_CLIENT)]
# TODO: fill with the positions (in `clients`) of the N_ADV adversarial
# clients. While this is empty, check_detect can never report a detection.
adv_idx = []
# TODO(review): the training cell uses `server`, which is never created in this
# notebook (Server is imported but not instantiated).

In [55]:
# hyperparameters for the detection sweep
TOTAL_RUNS = 30    # NOTE(review): not referenced by the sweep loop cell yet
TOTAL_ROUNDS = 20  # training rounds per (esp, min_samples, metric) setting

# grid passed to server.detect_adversary
# (presumably DBSCAN-style eps / min_samples / distance metric — confirm)
esp_vals = [0.1, 0.2, 0.3, 0.4, 0.5]
min_samples_vals = list(range(2, 7))
metric_vals = ['l1', 'l2', 'cosine']

# (esp, min_samples, metric) -> per-round detection counts, created lazily
detect_counter = defaultdict(lambda: [0 for _ in range(TOTAL_ROUNDS)])

In [11]:
# Sweep every detection configuration and record, per round, whether the
# server flagged at least one ground-truth adversary.
# NOTE(review): TOTAL_RUNS is not used here, so each counter entry ends up 0 or 1.
# NOTE(review): nothing is aggregated on the server between rounds, and clients
# only synchronize with the server at round 0 — confirm this is intended.
for esp in esp_vals:
    for min_samples in min_samples_vals:
        for metric in metric_vals:
            counter_key = (esp, min_samples, metric)

            # `round_idx`, not `round`: a global named `round` would shadow the
            # builtin round() for every later cell in the kernel.
            for round_idx in range(TOTAL_ROUNDS):
                if round_idx == 0:
                    for client in clients:
                        client.synchronize_with_server(server)

                participating_clients = server.select_clients(clients, frac=1.0)

                for client in participating_clients:
                    train_stats = client.compute_weight_update(epochs=1)
                    client.reset()

                # Rows follow the order of `clients` (not participating_clients),
                # so row i lines up with client i — assuming adv_idx holds
                # positions in `clients`. With frac=1.0 every client participates.
                client_dW_dicts = [client.dW for client in clients]
                feature_matrix = generate_feature_matrix(client_dW_dicts)

                detect_adv_idx = server.detect_adversary(feature_matrix, esp, min_samples, metric)
                detect_result = check_detect(detect_adv_idx, adv_idx)

                if detect_result:
                    detect_counter[counter_key][round_idx] += 1

In [36]:
# Plots: detections per round for every (esp, min_samples) setting, with one
# panel per distance metric. Uses .get() so plotting does not insert new keys
# into the defaultdict.
fig, axes = plt.subplots(
    1, len(metric_vals), figsize=(6 * len(metric_vals), 5), sharey=True, squeeze=False
)
for ax, metric in zip(axes[0], metric_vals):
    for esp in esp_vals:
        for min_samples in min_samples_vals:
            counts = detect_counter.get((esp, min_samples, metric), [0] * TOTAL_ROUNDS)
            ax.plot(range(TOTAL_ROUNDS), counts, label=f"eps={esp}, min_samples={min_samples}")
    ax.set(title=f"metric = {metric}", xlabel="round", ylabel="detections")
axes[0][0].legend(fontsize="x-small", ncol=2)
fig.tight_layout()
plt.show()