In [1]:
%load_ext autoreload
%autoreload 2

import asyncio, copy, os, pickle, socket, sys, time
from pathlib import Path
from tqdm import tqdm

import torch
from torch import optim
from torch.utils.tensorboard import SummaryWriter

sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "../../")))
from libs import fl, nn, agg, data, poison, log

In [2]:
# Initialise the project logger. The first argument is the level
# (info | debug | warning | error | critical).
# Optionally save the logs to a file by passing a file name as the second
# argument, as in the commented-out examples below.
log.init("info")
#log.init("info", "federated.log")
#log.init("debug", "flkafka.log")

In [3]:
class FedArgs:
    """Bundle of run settings and shared handles used throughout this notebook.

    The scalar settings become plain instance attributes. ``loop`` is the
    asyncio event loop obtained at construction time and ``tb`` is the
    TensorBoard ``SummaryWriter`` that the training cells log to.
    """

    def __init__(self):
        # Plain settings, assigned in declaration order.
        scalar_settings = {
            "num_clients": 50,
            "epochs": 25,
            "local_rounds": 1,
            "client_batch_size": 32,
            "test_batch_size": 128,
            "learning_rate": 1e-4,
            "weight_decay": 1e-5,
            "cuda": False,
            "seed": 1,
        }
        for attr_name, attr_value in scalar_settings.items():
            setattr(self, attr_name, attr_value)

        # Shared runtime handles.
        self.loop = asyncio.get_event_loop()
        self.tb = SummaryWriter(
            log_dir='../../out/runs/federated/FLTrust(Attack)',
            comment="Mnist Centralized Federated training",
        )


fedargs = FedArgs()

In [4]:
# Seed PyTorch's default RNG from the configured seed.
torch.manual_seed(fedargs.seed)

# CUDA is used only when it is both requested in fedargs and actually available.
use_cuda = fedargs.cuda and torch.cuda.is_available()
if use_cuda:
    device = torch.device("cuda")
    # Extra keyword arguments forwarded to the DataLoader constructors in later cells.
    kwargs = {"num_workers": 1, "pin_memory": True}
else:
    device = torch.device("cpu")
    kwargs = {}

In [5]:
# Client names have the form "<hostname>(<k>)" for k = 1 .. num_clients.
host = socket.gethostname()
clients = ["{}({})".format(host, number) for number in range(1, fedargs.num_clients + 1)]

In [6]:
# Build the global model, then hand every client its own independent deep copy of it.
global_model = nn.ModelMNIST()
client_models = dict((name, copy.deepcopy(global_model)) for name in clients)

# Training helper shared by the clients and (for FLTrust) the server
def train_model(_model, train_loader, fedargs, device):
    """Train ``_model`` on ``train_loader`` and derive the resulting update.

    Training is delegated to ``fl.client_update`` using the learning rate,
    weight decay and number of local rounds from ``fedargs``. The update is
    then computed as ``agg.sub_model(_model, trained_model)``.

    Returns:
        tuple: ``(model_update, trained_model, loss)`` where ``trained_model``
        and ``loss`` are exactly what ``fl.client_update`` returned.
    """
    hyper_params = (fedargs.learning_rate, fedargs.weight_decay, fedargs.local_rounds)
    trained_model, train_loss = fl.client_update(_model, train_loader, *hyper_params, device)
    return agg.sub_model(_model, trained_model), trained_model, train_loss

In [7]:
# Load the MNIST train and test datasets through the project's data helper.
# The training part is distributed across the clients in a later cell.
train_data, test_data = data.load_dataset("mnist")

In [8]:
# Default both flags to None (falsy) so the training loop's `if FLTrust:` and
# `if ... FLTrust_attack:` checks are simply skipped when the optional cells
# that set them to True are not executed.
FLTrust = None
FLTrust_attack = None

<h1>FLTrust: Skip section below for any other averaging than FLTrust.</h1>

In [9]:
FLTrust = True
root_ratio = 0.01

# Hold out a small "root" dataset for the server; the training loop uses it to
# compute the FLTrust base update.
# The training length is derived from the root length so the two always sum to
# len(train_data). Truncating both products independently can leave samples
# unassigned, and random_split rejects lengths that do not add up.
# NOTE: this cell rebinds train_data, so running it twice shrinks the training
# set a second time.
total_len = len(train_data)
root_len = int(total_len * root_ratio)
train_data, root_data = torch.utils.data.random_split(train_data, [total_len - root_len, root_len])
root_loader = torch.utils.data.DataLoader(root_data, batch_size=fedargs.client_batch_size, shuffle=True, **kwargs)

<h2>Resume</h2>

In [10]:
# Partition the training data among the clients; later cells index the result
# by client name (clients_data[<client name>]).
clients_data = data.split_data(train_data, clients)

<h1>Poison: Skip section(s) below to run normal, modify if required.</h1>

In [11]:
# Positions (into `clients`) of the malicious participants: the first 20 clients.
mal_clients = list(range(20))

<h2>Label Flipping attack, Skip if not required</h2>

In [28]:
# Replace each malicious client's data with a label-flipped version.
# NOTE(review): the arguments 4 and 9 are presumably the source and target
# labels, and poison_percent = -1 presumably means "all matching samples";
# confirm against poison.label_flip.
for client in mal_clients:
    mal_name = clients[client]
    clients_data[mal_name] = poison.label_flip(clients_data[mal_name], 4, 9, poison_percent=-1)

#clients_data[clients[0]] = poison.label_flip(clients_data[clients[0]], 6, 2, poison_percent = 1)
#clients_data[clients[0]] = poison.label_flip(clients_data[clients[0]], 3, 8, poison_percent = 1)
#clients_data[clients[0]] = poison.label_flip(clients_data[clients[0]], 1, 5, poison_percent = 1)

<h2>FLTrust: Sine Attack, Skip if not required</h2>

In [12]:
# Enable the model-poisoning step in the training loop; it is applied to the
# updates of mal_clients from epoch 2 onward.
FLTrust_attack = True

<h2>Resume</h2>

In [13]:
# Per-client training loaders; the second return value is not used here.
client_train_loaders, _ = data.load_client_data(
    clients_data,
    fedargs.client_batch_size,
    None,
    **kwargs
)

# One shared loader over the test data for evaluating the global model.
test_loader = torch.utils.data.DataLoader(
    test_data,
    batch_size=fedargs.test_batch_size,
    shuffle=True,
    **kwargs
)

# Per-client record, keyed by client name, holding that client's training loader.
clients_info = dict(
    (client, {"train_loader": client_train_loaders[client]})
    for client in clients
)

In [14]:
def background(f):
    """Decorator that runs the blocking callable ``f`` in the event loop's default executor.

    Calling the decorated function does not run ``f`` inline. It schedules the
    call via ``loop.run_in_executor(None, ...)`` on the current event loop and
    returns the resulting future, which can be awaited or passed to
    ``asyncio.gather``.

    Args:
        f: The callable to run in the executor.

    Returns:
        A wrapper with the same call signature as ``f`` that returns a future
        resolving to ``f``'s return value.
    """
    import functools  # local import keeps this cell self-contained

    @functools.wraps(f)
    def wrapped(*args, **kwargs):
        # run_in_executor() forwards positional arguments only, so bind both
        # positional and keyword arguments up front.
        bound_call = functools.partial(f, *args, **kwargs)
        return asyncio.get_event_loop().run_in_executor(None, bound_call)

    return wrapped

@background
def process(client, epoch, model, train_loader, fedargs, device):
    """Train one client's model for a federated epoch and return its update.

    Because of ``@background``, calling this schedules the work on the event
    loop's executor and returns a future for the result.

    Args:
        client: Client name; used in the TensorBoard tag and log messages.
        epoch: Current federated epoch (used as the series name and in logs).
        model: The client's model to train.
        train_loader: The client's training data loader.
        fedargs: Run settings; also provides the TensorBoard writer ``tb``.
        device: Device passed through to ``train_model``.

    Returns:
        The model update computed by ``train_model``.
    """
    # Train
    model_update, model, loss = train_model(model, train_loader, fedargs, device)

    # Plot and log: one point per local round on this client's loss chart.
    # A distinct loop variable keeps `loss` bound to the whole mapping, so the
    # debug log below reports every local round rather than a single value.
    for local_epoch, local_loss in enumerate(loss.values()):
        fedargs.tb.add_scalars("Training Loss/" + client, {str(epoch): local_loss}, local_epoch + 1)

    log.jsondebug(loss, "Epoch {} of {} : Federated Training loss, Client {}".format(epoch, fedargs.epochs, client))
    log.modeldebug(model_update, "Epoch {} of {} : Client {} Update".format(epoch, fedargs.epochs, client))

    return model_update

In [None]:
start_time = time.time()  # `time` is imported in the notebook's first cell

# Federated Training
# From epoch 2 on, each iteration first aggregates the client updates gathered
# in the previous iteration into the global model, evaluates it, and pushes it
# to the clients. It then trains every client for the current epoch.
# Consequently the updates produced in the final epoch are not aggregated.
for _epoch in tqdm(range(fedargs.epochs)):

    epoch = _epoch + 1
    log.info("Federated Training Epoch {} of {}".format(epoch, fedargs.epochs))

    # Global Model Update
    if epoch > 1:
        # For Tmean; does not affect the other rules as of now
        avgargs = {"beta": 10}

        # For FLTrust: compute the server's base update from the root data.
        # Skipped automatically when the FLTrust cell was not run (FLTrust is None).
        if FLTrust:
            global_model_update, _, _ = train_model(global_model, root_loader, fedargs, device)
            avgargs["base_update"] = global_model_update

        # Average
        global_model = fl.federated_avg(client_model_updates, global_model, agg.Rule.FLTrust, **avgargs)
        log.modeldebug(global_model, "Epoch {} of {} : Server Update".format(epoch, fedargs.epochs))

        # Test
        global_test_output = fl.eval(global_model, test_loader, device)
        fedargs.tb.add_scalar("Gloabl Accuracy/", global_test_output["accuracy"], epoch)
        log.jsoninfo(global_test_output, "Global Test Outut after Epoch {} of {}".format(epoch, fedargs.epochs))

        # Update client models
        client_models = {client: copy.deepcopy(global_model) for client in clients}

    # Clients: schedule every client's training on the executor and wait for all of them
    tasks = [process(client, epoch, client_models[client],
                     clients_info[client]['train_loader'],
                     fedargs, device) for client in clients]
    # Keep a handle on the gathered future so it can be cancelled on interrupt
    gathered = asyncio.gather(*tasks)
    try:
        updates = fedargs.loop.run_until_complete(gathered)
    except KeyboardInterrupt:
        print("Caught keyboard interrupt. Canceling tasks...")
        # Cancelling the gathered future cancels its unfinished child futures.
        gathered.cancel()
        # There are no updates for this epoch, so stop instead of continuing
        # with missing or stale data.
        raise

    client_model_updates = dict(zip(clients, updates))

    # For FLTrust, Malicious Clients
    # NOTE: this uses global_model_update, which is only assigned in the
    # `if FLTrust:` branch above.
    if epoch > 1 and FLTrust_attack:
        for client in mal_clients:
            # using existing global_model_update, however client can also calculate it by preserving the previous one.
            client_model_updates[clients[client]] = poison.model_poison_cosine(global_model_update, client_model_updates[clients[client]])

print(time.time() - start_time)

  0%|          | 0/25 [00:00<?, ?it/s]2021-08-25 17:33:36,687 - <ipython-input-15-b885986b40c0>::<module>(l:8) : Federated Training Epoch 1 of 25 [MainProcess : MainThread (INFO)]
  4%|▍         | 1/25 [00:41<16:24, 41.03s/it]2021-08-25 17:34:17,470 - <ipython-input-15-b885986b40c0>::<module>(l:8) : Federated Training Epoch 2 of 25 [MainProcess : MainThread (INFO)]
  _param_list = nd.array(param_list).squeeze()
2021-08-25 17:34:25,310 - /home/harsh_1921cs01/hub/F3IA/fl/libs/agg.py::FLTrust(l:90) : FLTrust Score [0.537, 0.544, 0.535, 0.551, 0.493, 0.511, 0.546, 0.51, 0.524, 0.515, 0.535, 0.532, 0.517, 0.512, 0.532, 0.514, 0.526, 0.529, 0.533, 0.528, 0.523, 0.574, 0.517, 0.535, 0.531, 0.502, 0.521, 0.53, 0.54, 0.53, 0.552, 0.558, 0.534, 0.512, 0.513, 0.498, 0.518, 0.522, 0.514, 0.557, 0.534, 0.466, 0.514, 0.531, 0.552, 0.521, 0.515, 0.501, 0.511, 0.55] [MainProcess : MainThread (INFO)]
2021-08-25 17:34:31,274 - <ipython-input-15-b885986b40c0>::<module>(l:27) : Global Test Outut after Epo

2021-08-25 17:41:35,834 - <ipython-input-15-b885986b40c0>::<module>(l:27) : Global Test Outut after Epoch 10 of 25 {
    "accuracy": 91.09,
    "correct": 9109,
    "test_loss": 0.3330233331441879
} [MainProcess : MainThread (INFO)]
 40%|████      | 10/25 [08:42<13:23, 53.58s/it]2021-08-25 17:42:18,799 - <ipython-input-15-b885986b40c0>::<module>(l:8) : Federated Training Epoch 11 of 25 [MainProcess : MainThread (INFO)]
2021-08-25 17:42:24,998 - /home/harsh_1921cs01/hub/F3IA/fl/libs/agg.py::FLTrust(l:90) : FLTrust Score [0.105, 0.108, 0.103, 0.107, 0.104, 0.106, 0.102, 0.102, 0.104, 0.108, 0.104, 0.104, 0.104, 0.103, 0.105, 0.105, 0.103, 0.104, 0.102, 0.104, 0.135, 0.14, 0.149, 0.158, 0.12, 0.129, 0.122, 0.124, 0.132, 0.129, 0.138, 0.132, 0.122, 0.14, 0.121, 0.135, 0.143, 0.126, 0.111, 0.102, 0.113, 0.141, 0.135, 0.135, 0.136, 0.135, 0.142, 0.133, 0.142, 0.124] [MainProcess : MainThread (INFO)]
2021-08-25 17:42:31,088 - <ipython-input-15-b885986b40c0>::<module>(l:27) : Global Test Outut

In [None]:
# Close the event loop held by fedargs now that training has finished.
# NOTE(review): fedargs.loop comes from asyncio.get_event_loop(); if that is the
# loop the notebook kernel itself runs on, closing it may be undesirable — confirm.
fedargs.loop.close()