In [None]:
%load_ext autoreload
%autoreload 2

import asyncio, copy, os, pickle, socket, sys, time
from functools import partial
from multiprocessing import Pool, Process
from pathlib import Path
from tqdm import tqdm

import torch
from torch import optim
from torch.utils.tensorboard import SummaryWriter

sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "../../")))
from libs import agg, data, fl, log, nn, poison, sim

In [49]:
# Initialise project logging at the chosen level (info | debug | warning | error | critical).
# The commented-out alternatives pass a file name as a second argument to save the
# logs to a file instead [optional].
log.init("info")
#log.init("info", "federated.log")
#log.init("debug", "flkafka.log")

In [50]:
class FedArgs():
    """Experiment configuration shared by every cell of this notebook.

    A single instance (`fedargs`) is created right after the class definition and
    passed to the training helpers, which read the hyper-parameters from it.
    """

    def __init__(self):
        # Number of simulated clients; one client name is generated per index.
        self.num_clients = 50
        # Number of iterations of the federated training loop.
        self.epochs = 25
        # Forwarded to fl.client_update for each local training call.
        self.local_rounds = 1
        # Batch sizes for the client/root/proxy loaders and for the test loader.
        self.client_batch_size = 32
        self.test_batch_size = 128
        # Optimiser hyper-parameters forwarded to fl.client_update.
        self.learning_rate = 1e-4
        self.weight_decay = 1e-5
        # Request CUDA; it is only used when torch also reports it as available.
        self.cuda = False
        # Seed handed to torch.manual_seed.
        self.seed = 1
        # Event loop on which the per-client training futures are awaited.
        self.loop = asyncio.get_event_loop()
        # TensorBoard writer for this run.
        # NOTE(review): per the SummaryWriter docs, `comment` has no effect when an
        # explicit log_dir is given, as it is here - confirm and drop if so.
        self.tb = SummaryWriter('../../out/runs/federated/FLTrust/m-sine-5-dot-2 (proxy)', comment="Centralized Federated training")

# Module-level configuration instance used by all later cells.
fedargs = FedArgs()

In [51]:
# CUDA is used only when it is both requested in the config and actually available.
use_cuda = fedargs.cuda and torch.cuda.is_available()

# Seed torch's global RNG so dataset splits and training are repeatable.
torch.manual_seed(fedargs.seed)

# Pick the device and the matching extra DataLoader keyword arguments together.
if use_cuda:
    device = torch.device("cuda")
    kwargs = {"num_workers": 1, "pin_memory": True}
else:
    device = torch.device("cpu")
    kwargs = {}

In [52]:
# Client names have the form "<hostname>(<k>)" with k running from 1 to num_clients.
host = socket.gethostname()
clients = ["{}({})".format(host, number) for number in range(1, fedargs.num_clients + 1)]

In [53]:
# Build the global model once, then hand every client its own independent deep copy
# so that the models do not share parameters.
global_model = nn.ModelMNIST()
client_models = dict(
    (client, copy.deepcopy(global_model))
    for client in clients
)

def train_model(_model, train_loader, fedargs, device):
    """Train `_model` on `train_loader` and return the resulting update.

    The learning rate, weight decay and number of local rounds are read from
    `fedargs` and forwarded to `fl.client_update` together with `device`.

    Returns:
        A 3-tuple `(model_update, model, loss)`:
        * `model_update` - result of `agg.sub_model(_model, model)`,
        * `model` - the model returned by `fl.client_update`,
        * `loss` - the loss value(s) returned by `fl.client_update`.
    """
    hyperparams = (fedargs.learning_rate, fedargs.weight_decay, fedargs.local_rounds)
    trained_model, losses = fl.client_update(_model, train_loader, *hyperparams, device)
    return agg.sub_model(_model, trained_model), trained_model, losses

In [54]:
# Load the MNIST train and test datasets. Distribution of the training data to the
# individual clients happens in a later cell (data.split_data).
train_data, test_data = data.load_dataset("mnist")

In [55]:
# Safe defaults: each flag stays None (falsy) unless the optional cell that enables
# the corresponding feature is executed, so skipped cells cannot cause a NameError.
FLTrust = cosine_attack = proxy_server = sybil_attack = None

<h1>FLTrust: Skip section below for any other averaging than FLTrust.</h1>

In [56]:
FLTrust = True
root_ratio = 0.003

# Carve the server's root dataset out of the training data.
# Only the root share is computed from the ratio; the training share is the exact
# remainder. Truncating both products independently can make the two lengths sum to
# less than len(train_data), which random_split rejects.
# NOTE: train_data is reassigned from itself, so re-running this cell without
# reloading the dataset shrinks the training set again.
num_root = int(len(train_data) * root_ratio)
num_train = len(train_data) - num_root
train_data, root_data = torch.utils.data.random_split(train_data, [num_train, num_root])
root_loader = torch.utils.data.DataLoader(root_data, batch_size=fedargs.client_batch_size, shuffle=True, **kwargs)

<h2>Resume</h2>

In [57]:
# Distribute the (remaining) training data among the clients. The result is indexed
# by client name in the cells below.
# NOTE(review): how the data is partitioned is decided inside data.split_data -
# check there.
clients_data = data.split_data(train_data, clients)

<h1>Poison: Skip section(s) below to run normal, modify if required.</h1>

In [58]:
# Indices (into `clients`) of the malicious clients: the first 24.
mal_clients = list(range(24))

<h2>Label Flipping attack, Skip if not required</h2>

In [12]:
# Replace each malicious client's local dataset with a label-flipped version.
# NOTE(review): 4 and 9 are presumably the source and target labels, and
# poison_percent = -1 is presumably a special value - confirm in poison.label_flip.
for client in mal_clients:
    victim = clients[client]
    clients_data[victim] = poison.label_flip(clients_data[victim], 4, 9, poison_percent=-1)

# Alternative single-client flips, kept for reference:
#clients_data[clients[0]] = poison.label_flip(clients_data[clients[0]], 6, 2, poison_percent = 1)
#clients_data[clients[0]] = poison.label_flip(clients_data[clients[0]], 3, 8, poison_percent = 1)
#clients_data[clients[0]] = poison.label_flip(clients_data[clients[0]], 1, 5, poison_percent = 1)

<h2>Cosine Attack, Skip if not required</h2>

In [59]:
# Enable the cosine attack; the training loop hands `cosargs` to
# poison.model_poison_cosine_coord for every malicious client.
cosine_attack = True
cosargs = dict(poison_percent=1, scale_dot=5, scale_norm=2)

<h3>If using proxy server (for partial knowledge), Skip if not required</h3>

In [60]:
proxy_server = True
proxy_ratio = 0.5

# Split the root dataset between the attacker's proxy server and the real server.
# One length is computed from the ratio and the other is the exact remainder.
# Truncating both products independently makes the lengths sum to less than
# len(root_data) whenever the products are fractional (e.g. any odd length with
# ratio 0.5), which random_split rejects.
# NOTE(review): as in the original code, `proxy_ratio` is the share that stays in
# root_data and proxy_data receives the rest - confirm this is the intended meaning.
# NOTE: root_data is reassigned from itself, so re-running this cell shrinks it again.
num_root_kept = int(len(root_data) * proxy_ratio)
num_proxy = len(root_data) - num_root_kept
proxy_data, root_data = torch.utils.data.random_split(root_data, [num_proxy, num_root_kept])
root_loader = torch.utils.data.DataLoader(root_data, batch_size=fedargs.client_batch_size, shuffle=True, **kwargs)
proxy_loader = torch.utils.data.DataLoader(proxy_data, batch_size=fedargs.client_batch_size, shuffle=True, **kwargs)

<h2>Sybil Attack, Skip if not required</h2>

In [38]:
# Enable the sybil attack: in the training loop every malicious client's update is
# replaced by the same base model update.
sybil_attack = True

<h2>Resume</h2>

In [61]:
# Per-client training loaders. Only the first returned value is used here.
# NOTE(review): the meaning of the third positional argument (None) is defined by
# data.load_client_data - confirm there.
client_train_loaders, _ = data.load_client_data(
    clients_data,
    fedargs.client_batch_size,
    None,
    **kwargs,
)

# Loader over the held-out test set, used to evaluate the global model.
test_loader = torch.utils.data.DataLoader(
    test_data,
    batch_size=fedargs.test_batch_size,
    shuffle=True,
    **kwargs,
)

# Per-client record; currently it only carries the client's training loader.
clients_info = dict(
    (client, {"train_loader": client_train_loaders[client]})
    for client in clients
)

In [62]:
def background(f):
    """Decorator that runs `f` in the event loop's default executor.

    Calling the decorated function does not execute `f` directly; it schedules the
    call on the current event loop's default executor and returns an awaitable
    future that resolves to the return value of `f`.

    Args:
        f: The callable to run in the background.

    Returns:
        A wrapper that accepts the same positional and keyword arguments as `f`.
    """
    import functools

    @functools.wraps(f)
    def wrapped(*args, **kwargs):
        # run_in_executor() forwards positional arguments only, so bind all
        # arguments (including keyword ones) to the callable beforehand.
        call = functools.partial(f, *args, **kwargs)
        return asyncio.get_event_loop().run_in_executor(None, call)

    return wrapped

@background
def process(client, epoch, model, train_loader, fedargs, device):
    """Train one client's model for the given epoch and return its model update.

    Because of the `@background` decorator, calling this function schedules the
    work on the event loop's executor and returns a future for the update.

    Args:
        client: Client name, used only in the log messages.
        epoch: Current (1-based) federated epoch, used only in the log messages.
        model: The client's model to train.
        train_loader: The client's training data loader.
        fedargs: Experiment configuration.
        device: Device handed on to `train_model`.

    Returns:
        The model update produced by `train_model`.
    """
    # Train
    update, _trained_model, losses = train_model(model, train_loader, fedargs, device)

    # Per-client TensorBoard loss plotting is currently disabled:
    #for local_epoch, loss in enumerate(list(loss.values())):
    #    fedargs.tb.add_scalars("Training Loss/" + client, {str(epoch): loss}, str(local_epoch + 1))

    # Log
    loss_title = "Epoch {} of {} : Federated Training loss, Client {}".format(epoch, fedargs.epochs, client)
    update_title = "Epoch {} of {} : Client {} Update".format(epoch, fedargs.epochs, client)
    log.jsondebug(losses, loss_title)
    log.modeldebug(update, update_title)

    return update

In [63]:
start_time = time.time()

# Start every run of this cell without a server-side base update, so that a value
# left in the kernel by an earlier run can never leak into the first aggregation.
global_model_update = None

# Federated Training
# NOTE(review): the client updates of epoch N are aggregated at the start of epoch
# N + 1, so the updates produced in the final epoch are never folded into
# global_model, and "after Epoch N" test results reflect N - 1 aggregations.
for _epoch in tqdm(range(fedargs.epochs)):

    epoch = _epoch + 1
    log.info("Federated Training Epoch {} of {}".format(epoch, fedargs.epochs))

    # Global Model Update
    if epoch > 1:
        # Extra arguments for the aggregation rule; they are meant for Tmean and
        # FLTrust and do not affect the other rules as of now.
        avgargs = {"beta": 10,
                   "base_update": global_model_update,
                   "base_norm": True}

        # Average
        global_model = fl.federated_avg(client_model_updates, global_model, agg.Rule.FLTrust, **avgargs)
        log.modeldebug(global_model, "Epoch {} of {} : Server Update".format(epoch, fedargs.epochs))

        # Test
        global_test_output = fl.eval(global_model, test_loader, device)
        fedargs.tb.add_scalar("Gloabl Accuracy/", global_test_output["accuracy"], epoch)
        log.jsoninfo(global_test_output, "Global Test Outut after Epoch {} of {}".format(epoch, fedargs.epochs))

        # Update client models
        client_models = {client: copy.deepcopy(global_model) for client in clients}

    # Clients: schedule one background training job per client and wait for all.
    tasks = [process(client, epoch, client_models[client],
                     clients_info[client]['train_loader'],
                     fedargs, device) for client in clients]
    gathered = asyncio.gather(*tasks)
    try:
        updates = fedargs.loop.run_until_complete(gathered)
    except KeyboardInterrupt:
        print("Caught keyboard interrupt. Canceling tasks...")
        # `tasks` is a plain list, so cancellation must go through the gathered
        # future, which forwards the request to every child future.
        gathered.cancel()
        # Stop the training cell instead of continuing with missing or stale updates.
        raise

    client_model_updates = {client: update for client, update in zip(clients, updates)}

    if FLTrust:
        # Server-side update computed on the root dataset; it becomes the FLTrust
        # base update for the next aggregation.
        global_model_update, _, _ = train_model(global_model, root_loader, fedargs, device)

        # For Attacks related to FLTrust: the attackers' estimate of the base update.
        # With a proxy server it is trained on the proxy data instead of the root data.
        base_model_update = global_model_update
        if proxy_server:
            base_model_update, _, _ = train_model(global_model, proxy_loader, fedargs, device)

        # For cosine attack, Malicious Clients
        # (skipped when there are no malicious clients: Pool(0) is invalid and
        # p_models[0] would not exist)
        if cosine_attack and mal_clients:
            b_arr, b_list = sim.get_net_arr(base_model_update)

            # Poison every malicious update in its own worker process.
            with Pool(len(mal_clients)) as p:
                func = partial(poison.model_poison_cosine_coord, b_arr, cosargs)
                p_models = p.map(func, [sim.get_net_arr(client_model_updates[clients[client]])[0]
                                        for client in mal_clients])
                p.close()
                p.join()

            # Write the poisoned arrays back into the clients' updates.
            for client, (p_arr, _) in zip(mal_clients, p_models):
                client_model_updates[clients[client]] = sim.get_arr_net(client_model_updates[clients[client]],
                                                                        p_arr, b_list)

            # Plot params changed for only one client
            fedargs.tb.add_scalar("Params Changed for Cosine Attack/", p_models[0][1], epoch)

        # For sybil attack, Malicious Clients: all of them submit the same base update.
        if sybil_attack:
            for client in mal_clients:
                client_model_updates[clients[client]] = base_model_update

# Total wall-clock training time in seconds.
print(time.time() - start_time)

  0%|          | 0/25 [00:00<?, ?it/s]2021-09-04 20:03:28,907 - <ipython-input-63-87d4be2961d2>::<module>(l:8) : Federated Training Epoch 1 of 25 [MainProcess : MainThread (INFO)]
  4%|▍         | 1/25 [02:06<50:41, 126.73s/it]2021-09-04 20:05:35,643 - <ipython-input-63-87d4be2961d2>::<module>(l:8) : Federated Training Epoch 2 of 25 [MainProcess : MainThread (INFO)]
2021-09-04 20:05:40,754 - /home/harsh_1921cs01/hub/F3IA/fl/libs/agg.py::FLTrust(l:96) : Cosine Score [0.03622142, 0.045958314, 0.0030038257, 0.01256252, 0, 0.020588353, 0.040144306, 0.050083354, 0.0031156712, 0.038214166, 0.05149382, 0.020116817, 0.05099802, 0.030586226, 0.009136742, 0.026744125, 0.06554752, 0, 0.06256186, 0.007383803, 0.05471072, 0.048508417, 0.047328055, 0.097991504, 0.3315619, 0.3042499, 0.341642, 0.34182674, 0.31551716, 0.34989503, 0.34740975, 0.34663147, 0.35005882, 0.3367747, 0.3461386, 0.33748376, 0.35182384, 0.3634161, 0.35411125, 0.33874843, 0.33923692, 0.34782055, 0.33317357, 0.3268118, 0.33283356

2021-09-04 20:15:20,927 - <ipython-input-63-87d4be2961d2>::<module>(l:24) : Global Test Outut after Epoch 6 of 25 {
    "accuracy": 71.39,
    "correct": 7139,
    "test_loss": 1.469395474052429
} [MainProcess : MainThread (INFO)]
 24%|██▍       | 6/25 [14:02<45:00, 142.14s/it]2021-09-04 20:17:31,884 - <ipython-input-63-87d4be2961d2>::<module>(l:8) : Federated Training Epoch 7 of 25 [MainProcess : MainThread (INFO)]
2021-09-04 20:17:37,102 - /home/harsh_1921cs01/hub/F3IA/fl/libs/agg.py::FLTrust(l:96) : Cosine Score [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.42417318, 0.4291126, 0.41658133, 0.41201836, 0.41015926, 0.4281409, 0.420628, 0.4424944, 0.40988168, 0.43667495, 0.40579188, 0.41449806, 0.41395423, 0.43618035, 0.41933963, 0.4047081, 0.41690457, 0.41453066, 0.4269281, 0.4157073, 0.41075078, 0.42548186, 0.44067016, 0.43419838, 0.43452433, 0.4208721] [MainProcess : MainThread (INFO)]
2021-09-04 20:17:37,120 - /home/harsh_1921cs01/hub/F3IA/fl/libs/agg.p

 44%|████▍     | 11/25 [25:49<33:00, 141.48s/it]2021-09-04 20:29:18,657 - <ipython-input-63-87d4be2961d2>::<module>(l:8) : Federated Training Epoch 12 of 25 [MainProcess : MainThread (INFO)]
2021-09-04 20:29:23,555 - /home/harsh_1921cs01/hub/F3IA/fl/libs/agg.py::FLTrust(l:96) : Cosine Score [0.020922747, 0.06915193, 0.031065647, 0.02268113, 0.004891223, 0.003844383, 0.05595397, 0.008053807, 0.02826578, 0.012962077, 0.034017425, 0.014498921, 0, 0.0056127585, 0, 0.031657226, 0.0015229457, 0.018042676, 0.05011312, 0.012004727, 0.0021885047, 0.01176094, 0.018208979, 0.020755226, 0.25437355, 0.2624252, 0.21061966, 0.2565062, 0.25514817, 0.2749916, 0.29594022, 0.2439411, 0.24452126, 0.26846108, 0.25710508, 0.24976908, 0.24322847, 0.26356044, 0.2794778, 0.26073015, 0.26928222, 0.2504049, 0.24953032, 0.22297391, 0.2639877, 0.2646445, 0.24210088, 0.25197482, 0.2473898, 0.23187737] [MainProcess : MainThread (INFO)]
2021-09-04 20:29:23,571 - /home/harsh_1921cs01/hub/F3IA/fl/libs/agg.py::FLTrust(l

2021-09-04 20:38:27,373 - /home/harsh_1921cs01/hub/F3IA/fl/libs/agg.py::FLTrust(l:97) : FLTrust Score [0.0108572235, 0.011711815, 0.011692872, 0.013065703, 0.012798745, 0.012828883, 0.0100308815, 0.011266829, 0.013320793, 0.012363233, 0.012492956, 0.011449567, 0.012308451, 0.010993706, 0.011630864, 0.012453015, 0.011063773, 0.012005847, 0.013137201, 0.011128669, 0.011675067, 0.011895601, 0.01174217, 0.012935364, 0.03536335, 0.036581293, 0.044565413, 0.032373723, 0.033736017, 0.030897576, 0.03817015, 0.0491859, 0.03549666, 0.021521287, 0.031592526, 0.04288471, 0.02235699, 0.031077465, 0.039105132, 0.031195475, 0.048504673, 0.040670447, 0.036055133, 0.029468456, 0.039701357, 0.03837491, 0.04367845, 0.039915673, 0.037429, 0.04951747] [MainProcess : MainThread (INFO)]
2021-09-04 20:38:33,485 - <ipython-input-63-87d4be2961d2>::<module>(l:24) : Global Test Outut after Epoch 16 of 25 {
    "accuracy": 75.42999999999999,
    "correct": 7543,
    "test_loss": 0.7288111893653869
} [MainProcess :

2021-09-04 20:47:34,895 - <ipython-input-63-87d4be2961d2>::<module>(l:24) : Global Test Outut after Epoch 20 of 25 {
    "accuracy": 76.41,
    "correct": 7641,
    "test_loss": 0.6677708033561707
} [MainProcess : MainThread (INFO)]
 80%|████████  | 20/25 [46:10<11:20, 136.10s/it]2021-09-04 20:49:39,787 - <ipython-input-63-87d4be2961d2>::<module>(l:8) : Federated Training Epoch 21 of 25 [MainProcess : MainThread (INFO)]
2021-09-04 20:49:44,756 - /home/harsh_1921cs01/hub/F3IA/fl/libs/agg.py::FLTrust(l:96) : Cosine Score [0.044758935, 0.05067844, 0.04076597, 0.03596977, 0.039479617, 0.025737047, 0.04558699, 0.034688227, 0.03580427, 0.034515016, 0.042737328, 0.041283052, 0.037233554, 0.028254537, 0.035949368, 0.037599966, 0.043697644, 0.039243132, 0.027586991, 0.03316537, 0.03372079, 0.03508124, 0.025353353, 0.018566819, 0.052235696, 0.03862919, 0.06518037, 0.07026521, 0.07946188, 0.04208743, 0.0547107, 0.07791623, 0.048046082, 0.060887128, 0.056351624, 0.06565998, 0.038230743, 0.06689398

3445.2819497585297



