In [1]:
%load_ext autoreload
%autoreload 2

import asyncio, copy, os, pickle, socket, sys, time
from functools import partial
from multiprocessing import Pool, Process
from pathlib import Path
from tqdm import tqdm

import torch
from torch import optim
from torch.utils.tensorboard import SummaryWriter

sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "../../")))
from libs import agg, data, fl, log, nn, poison, resnet, sim

In [2]:
# Initialise project logging. The first argument is the level
# (info | debug | warning | error | critical). The commented-out variants pass a
# second argument, which appears to be a log file name (optional) — TODO confirm in libs.log.
log.init("info")
#log.init("info", "federated.log")
#log.init("debug", "flkafka.log")

In [3]:
class FedArgs:
    """Hyper-parameters and shared handles for this federated-learning run.

    Every later cell reads its settings from the single ``fedargs`` instance
    created below, so this is the one place to tune an experiment.
    """

    def __init__(self):
        # Federation size and schedule.
        self.num_clients = 50      # number of simulated clients
        self.epochs = 50           # federated rounds executed by the training loop
        self.local_rounds = 1      # forwarded to fl.train_model for every client
        # Batch sizes for the client loaders and the global test loader.
        self.client_batch_size = 32
        self.test_batch_size = 128
        # Forwarded to fl.train_model.
        self.learning_rate = 1e-4
        self.weight_decay = 1e-5
        # CUDA is used only when this flag is True and a GPU is available.
        self.cuda = False
        self.seed = 1
        # Event loop used to wait for the per-client training futures.
        # NOTE(review): run_until_complete() raises RuntimeError on a loop that is
        # already running (as in many notebook kernels) — confirm the target environment.
        self.loop = asyncio.get_event_loop()
        self.agg_rule = agg.Rule.FedAvg
        self.dataset = "mnist"
        self.labels = list(range(10))
        # Seed torch *before* the model is constructed: the notebook only calls
        # torch.manual_seed in a later cell, which would otherwise leave the
        # initial weights (presumably drawn from torch's global RNG) unseeded.
        torch.manual_seed(self.seed)
        self.model = nn.ModelMNIST() #resnet.ResNet18()
        # TensorBoard writer for accuracy / attack metrics of this run.
        self.tb = SummaryWriter('../../out/runs/federated/FedAvg/mn-lfa-next-e-50', comment="Centralized Federated training")

fedargs = FedArgs()

<h2>Set the poisoning parameters below before proceeding; for a clean run, set every attack's "is" flag to False.</h2>

In [4]:
# FLTrust: switched on automatically when the aggregation rule is FLTrust or FLTC.
# "data"/"loader" (and the optional "proxy" pair) are filled in by a later cell.
FLTrust = {
    "is": fedargs.agg_rule in [agg.Rule.FLTrust, agg.Rule.FLTC],
    "ratio": 0.003,
    "data": None,
    "loader": None,
    "proxy": {
        "is": False,
        "ratio": 0.5,
        "data": None,
        "loader": None,
    },
}

# Indices (into the client list) of the malicious clients
mal_clients = list(range(20))

# Label Flip
# "percent" is forwarded to the flip function; -1 presumably means "all matching
# samples" — TODO confirm in libs.poison.
label_flip_attack = {
    "is": True,
    "func": poison.label_flip_next,
    "labels": {},
    "percent": -1,
}
# Choose the source -> target label map that matches the selected flip function;
# it stays None when the attack is off or the function is not one of the two known ones.
if not label_flip_attack["is"]:
    label_flip_attack["labels"] = None
elif label_flip_attack["func"] is poison.label_flip_next:
    # Every label maps to its successor in fedargs.labels, wrapping around at the end.
    label_flip_attack["labels"] = {
        source: fedargs.labels[(position + 1) % len(fedargs.labels)]
        for position, source in enumerate(fedargs.labels)
    }
elif label_flip_attack["func"] is poison.label_flip:
    label_flip_attack["labels"] = {4: 6}
else:
    label_flip_attack["labels"] = None

# Backdoor ("data"/"loader" are filled in by a later cell)
backdoor_attack = {
    "is": False,
    "trojan_func": poison.insert_trojan_pattern,
    "target_label": 6,
    "ratio": 0.006,
    "data": None,
    "loader": None,
}

# Layer replacement attack
layer_replacement_attack = {
    "is": False,
    "layers": ["conv1.weight"],
}

# Cosine attack; the training loop rescales "scale_dot" and "scale_norm"
# every "scale_epoch" epochs using the two "*_factor" entries.
cosine_attack = {
    "is": False,
    "args": {
        "poison_percent": 1,
        "scale_dot": 5,
        "scale_dot_factor": 1,
        "scale_norm": 500,
        "scale_norm_factor": 2,
        "scale_epoch": 5,
    },
}

# Sybil attack, for sending same update as base
sybil_attack = {"is": False}

In [5]:
# Device settings
# Use CUDA only when it is both requested in fedargs and actually available.
use_cuda = fedargs.cuda and torch.cuda.is_available()
# Seed torch's global RNG for everything that runs after this cell.
torch.manual_seed(fedargs.seed)
device = torch.device("cuda" if use_cuda else "cpu")
# Extra keyword arguments forwarded to every DataLoader; empty when running on CPU.
kwargs = {"num_workers": 1, "pin_memory": True} if use_cuda else {}

In [6]:
# Prepare clients: one name per simulated client, "<hostname>(<1-based index>)"
host = socket.gethostname()
clients = ["{}({})".format(host, number) for number in range(1, fedargs.num_clients + 1)]

In [7]:
# Initialize the global model as an independent copy of the configured model
# (per-client copies are created later, when client_details is built)
global_model = copy.deepcopy(fedargs.model)
# Load the train and test portions of the configured dataset
train_data, test_data = data.load_dataset(fedargs.dataset)

<h2>FLTrust</h2>

In [8]:
if FLTrust["is"]:
    # Split the server's root dataset off the training data (sized by FLTrust["ratio"])
    # and wrap it in a loader. NOTE: this reassigns train_data, so re-running the
    # cell splits the already-reduced data again.
    train_data, FLTrust["data"] = data.random_split(train_data, FLTrust["ratio"])
    FLTrust["loader"] = torch.utils.data.DataLoader(
        FLTrust["data"],
        batch_size=fedargs.client_batch_size,
        shuffle=True,
        **kwargs,
    )

    if FLTrust["proxy"]["is"]:
        # Split the root dataset once more into root + proxy parts, then rebuild
        # the root loader and create the proxy loader over the new parts.
        FLTrust["data"], FLTrust["proxy"]["data"] = data.random_split(FLTrust["data"], FLTrust["proxy"]["ratio"])
        for holder in (FLTrust, FLTrust["proxy"]):
            holder["loader"] = torch.utils.data.DataLoader(
                holder["data"],
                batch_size=fedargs.client_batch_size,
                shuffle=True,
                **kwargs,
            )

<h2>Prepare a backdoored loader for test</h2>

In [9]:
if backdoor_attack["is"]:
    # Hold out part of the training data (sized by backdoor_attack["ratio"]) for
    # measuring backdoor success; train_data is reassigned to what data.random_split returns first.
    train_data, backdoor_attack["data"] = data.random_split(train_data, backdoor_attack["ratio"])
    # Poison the held-out part with the configured trojan function and target label.
    # The trailing 1 is presumably the fraction of samples to poison (the client-side
    # attack passes 0.5) — TODO confirm in libs.poison.
    backdoor_attack["data"] = poison.insert_trojan(backdoor_attack["data"],
                                                   backdoor_attack["target_label"],
                                                   backdoor_attack["trojan_func"], 1)
    # Loader consumed by fl.backdoor_test in the training loop.
    backdoor_attack["loader"] = torch.utils.data.DataLoader(backdoor_attack["data"], batch_size=fedargs.client_batch_size, shuffle=True, **kwargs)

<h2>Load clients' data</h2>

In [10]:
# Distribute the training data (what is left after any FLTrust / backdoor hold-outs)
# across the clients; the result is indexed by client name in the following cells.
clients_data = data.split_data(train_data, clients)

<h2>Label Flip Attack</h2>

In [11]:
if label_flip_attack["is"]:
    # Replace each malicious client's data with the output of the configured
    # flip function, called with the label map and "percent" from the config.
    flip = label_flip_attack["func"]
    for client in mal_clients:
        name = clients[client]
        clients_data[name] = flip(
            clients_data[name],
            label_flip_attack["labels"],
            label_flip_attack["percent"],
        )

<h2>Backdoor Attack</h2>

In [12]:
if backdoor_attack["is"]:
    # Replace each malicious client's data with a trojaned version.
    # The final 0.5 is presumably the fraction of samples to poison — TODO confirm in libs.poison.
    for client in mal_clients:
        name = clients[client]
        clients_data[name] = poison.insert_trojan(
            clients_data[name],
            backdoor_attack["target_label"],
            backdoor_attack["trojan_func"],
            0.5,
        )

In [13]:
# Per-client training loaders (the second return value is discarded) and the shared test loader.
client_train_loaders, _ = data.load_client_data(clients_data, fedargs.client_batch_size, None, **kwargs)
test_loader = torch.utils.data.DataLoader(
    test_data,
    batch_size=fedargs.test_batch_size,
    shuffle=True,
    **kwargs,
)

# Per-client state: its loader, its own copy of the global model, and the
# most recent update (None until the first training round has finished).
client_details = {}
for client in clients:
    client_details[client] = {
        "train_loader": client_train_loaders[client],
        "model": copy.deepcopy(global_model),
        "model_update": None,
    }

In [14]:
def background(f):
    """Decorator that runs the blocking callable ``f`` in the default executor.

    Calling the decorated function does not run ``f`` inline; it submits the
    call to the current event loop's default executor and returns the
    resulting future, which can be awaited or passed to ``asyncio.gather``.

    Args:
        f: Any synchronous callable.

    Returns:
        A wrapper with the same signature as ``f`` that returns a future
        resolving to ``f``'s return value.
    """
    import functools

    @functools.wraps(f)
    def wrapped(*args, **kwargs):
        # run_in_executor() forwards positional arguments only, so bind the
        # whole call with partial to keep keyword arguments working.
        call = functools.partial(f, *args, **kwargs)
        return asyncio.get_event_loop().run_in_executor(None, call)

    return wrapped

@background
def process(client, epoch, model, train_loader, fedargs, device):
    """Train one client's model for this epoch and return its model update.

    Because of ``@background`` the call returns a future; the value described
    here is what that future resolves to.

    Args:
        client: Client name, used only in log messages.
        epoch: Current (1-based) federated epoch, used only in log messages.
        model: The client's model, handed to ``fl.train_model``.
        train_loader: The client's training loader.
        fedargs: Run configuration (learning rate, weight decay, local rounds, epochs).
        device: Device handed to ``fl.train_model``.

    Returns:
        The first element of the tuple returned by ``fl.train_model``.
    """
    # Train; only the update is returned, the loss is just logged.
    update, _trained_model, train_loss = fl.train_model(
        model,
        train_loader,
        fedargs.learning_rate,
        fedargs.weight_decay,
        fedargs.local_rounds,
        device,
    )

    epoch_tag = "Epoch {} of {}".format(epoch, fedargs.epochs)
    log.jsondebug(train_loss, epoch_tag + " : Federated Training loss, Client {}".format(client))
    log.modeldebug(update, epoch_tag + " : Client {} Update".format(client))

    return update

In [None]:
# `time` is already imported in the notebook's first cell.
start_time = time.time()

# Federated Training
for _epoch in tqdm(range(fedargs.epochs)):

    epoch = _epoch + 1
    log.info("Federated Training Epoch {} of {}".format(epoch, fedargs.epochs))

    # Global model update: aggregate the client updates collected in the previous epoch.
    # NOTE(review): because aggregation happens at the start of epoch > 1, the updates
    # produced in the final epoch are never aggregated — confirm this is intended.
    if epoch > 1:
        # Extra arguments used by Tmean and FLTrust; other rules currently ignore them.
        avgargs = {"beta": 10,
                   "base_model_update": global_model_update if FLTrust["is"] else None,
                   "base_norm": False}

        # Average
        client_model_updates = {client: details["model_update"] for client, details in client_details.items()}
        global_model = fl.federated_avg(client_model_updates, global_model, fedargs.agg_rule, **avgargs)
        log.modeldebug(global_model, "Epoch {} of {} : Server Update".format(epoch, fedargs.epochs))

        # Test
        # (TensorBoard tags and log messages keep their original spelling so that
        # existing runs and log parsers stay comparable.)
        global_test_output = fl.evaluate(global_model, test_loader, device, label_flip_attack["labels"])
        fedargs.tb.add_scalar("Gloabl Accuracy/", global_test_output["accuracy"], epoch)
        if "attack" in global_test_output:
            if "attack_success_rate" in global_test_output["attack"]:
                fedargs.tb.add_scalar("Attack Success Rate/", global_test_output["attack"]["attack_success_rate"], epoch)
            if "misclassification_rate" in global_test_output["attack"]:
                fedargs.tb.add_scalar("Misclassification Rate/", global_test_output["attack"]["misclassification_rate"], epoch)
        if backdoor_attack["is"]:
            backdoor_test_output = fl.backdoor_test(global_model, backdoor_attack["loader"], device,
                                                    backdoor_attack["target_label"])
            fedargs.tb.add_scalar("Backdoor Success Rate/", backdoor_test_output["accuracy"], epoch)
            log.jsoninfo(backdoor_test_output, "Backdoor Test Outut after Epoch {} of {}".format(epoch, fedargs.epochs))
        log.jsoninfo(global_test_output, "Global Test Outut after Epoch {} of {}".format(epoch, fedargs.epochs))

        # Update client models: every client restarts from its own copy of the new global model
        for client in clients:
            client_details[client]['model'] = copy.deepcopy(global_model)

    # Clients: process() returns a future per client; wait for all of them.
    tasks = [process(client, epoch, client_details[client]['model'],
                     client_details[client]['train_loader'],
                     fedargs, device) for client in clients]
    try:
        updates = fedargs.loop.run_until_complete(asyncio.gather(*tasks))
    except KeyboardInterrupt:
        print("Caught keyboard interrupt. Canceling tasks...")
        # `tasks` is a plain list, so each future has to be cancelled individually.
        # Jobs already running in an executor thread cannot be interrupted; only
        # the ones still queued are dropped.
        for task in tasks:
            task.cancel()
        # Stop here: there are no `updates` for this epoch to continue with.
        raise

    for client, update in zip(clients, updates):
        client_details[client]['model_update'] = update

    if FLTrust["is"]:
        # Server-side reference update computed on the root dataset.
        # NOTE(review): if fl.train_model trains its model argument in place, this
        # also modifies global_model before the next aggregation — confirm in libs.fl.
        global_model_update, _, _ = fl.train_model(global_model, FLTrust["loader"],
                                                   fedargs.learning_rate,
                                                   fedargs.weight_decay,
                                                   fedargs.local_rounds, device)

        # For attacks related to FLTrust: the attacker's view of the server update,
        # taken from the proxy dataset when one is configured.
        base_model_update = global_model_update
        if FLTrust["proxy"]["is"]:
            base_model_update, _, _ = fl.train_model(global_model, FLTrust["proxy"]["loader"],
                                                     fedargs.learning_rate,
                                                     fedargs.weight_decay,
                                                     fedargs.local_rounds, device)

        if layer_replacement_attack["is"]:
            for client in mal_clients:
                client_details[clients[client]]['model_update'] = poison.layer_replacement_attack(base_model_update,
                                                                                                  client_details[clients[client]]['model_update'],
                                                                                                  layer_replacement_attack["layers"])

        # For cosine attack, Malicious Clients
        if cosine_attack["is"]:
            b_arr, b_list = sim.get_net_arr(base_model_update)

            # Every "scale_epoch" epochs, grow the attack's scaling parameters.
            if epoch % cosine_attack["args"]["scale_epoch"] == 0:
                cosine_attack["args"]["scale_dot"] = cosine_attack["args"]["scale_dot_factor"] + cosine_attack["args"]["scale_dot"]
                cosine_attack["args"]["scale_norm"] = cosine_attack["args"]["scale_norm_factor"] * cosine_attack["args"]["scale_norm"]

            # Poison each malicious client's flattened update in a separate worker process.
            with Pool(len(mal_clients)) as p:
                func = partial(poison.model_poison_cosine_coord, b_arr, cosine_attack["args"])
                p_models = p.map(func, [sim.get_net_arr(client_details[clients[client]]['model_update'])[0]
                                        for client in mal_clients])
                p.close()
                p.join()

            # Write the poisoned arrays back into the clients' update objects.
            for client, (p_arr, _) in zip(mal_clients, p_models):
                client_details[clients[client]]['model_update'] = sim.get_arr_net(client_details[clients[client]]['model_update'],
                                                                                  p_arr, b_list)

            # Plot params changed for only one client (the first malicious one)
            fedargs.tb.add_scalar("Params Changed for Cosine Attack/", p_models[0][1], epoch)

        # For sybil attack, Malicious Clients: all of them send the same base update
        # (the same object is shared by every malicious client).
        if sybil_attack["is"]:
            for client in mal_clients:
                client_details[clients[client]]['model_update'] = base_model_update

# Total wall-clock time of the training run, in seconds.
print(time.time() - start_time)

  0%|          | 0/50 [00:00<?, ?it/s]2021-09-17 14:45:26,619 - <ipython-input-15-b206d8af25dc>::<module>(l:8) : Federated Training Epoch 1 of 50 [MainProcess : MainThread (INFO)]
  2%|▏         | 1/50 [00:27<22:48, 27.92s/it]2021-09-17 14:45:54,345 - <ipython-input-15-b206d8af25dc>::<module>(l:8) : Federated Training Epoch 2 of 50 [MainProcess : MainThread (INFO)]
2021-09-17 14:46:00,859 - <ipython-input-15-b206d8af25dc>::<module>(l:35) : Global Test Outut after Epoch 2 of 50 {
    "accuracy": 81.26,
    "attack": {
        "attack_success_count": 141,
        "attack_success_rate": 1.41,
        "instances": 10000,
        "misclassification_rate": 18.740000000000002,
        "misclassifications": 1874
    },
    "correct": 8126,
    "test_loss": 1.4858936618804932
} [MainProcess : MainThread (INFO)]
  4%|▍         | 2/50 [01:01<25:00, 31.25s/it]2021-09-17 14:46:27,932 - <ipython-input-15-b206d8af25dc>::<module>(l:8) : Federated Training Epoch 3 of 50 [MainProcess : MainThread (INFO)

 28%|██▊       | 14/50 [07:43<20:12, 33.68s/it]2021-09-17 14:53:10,223 - <ipython-input-15-b206d8af25dc>::<module>(l:8) : Federated Training Epoch 15 of 50 [MainProcess : MainThread (INFO)]
2021-09-17 14:53:16,593 - <ipython-input-15-b206d8af25dc>::<module>(l:35) : Global Test Outut after Epoch 15 of 50 {
    "accuracy": 94.53,
    "attack": {
        "attack_success_count": 55,
        "attack_success_rate": 0.5499999999999999,
        "instances": 10000,
        "misclassification_rate": 5.47,
        "misclassifications": 547
    },
    "correct": 9453,
    "test_loss": 0.1920014710187912
} [MainProcess : MainThread (INFO)]
 30%|███       | 15/50 [08:18<19:50, 34.02s/it]2021-09-17 14:53:45,042 - <ipython-input-15-b206d8af25dc>::<module>(l:8) : Federated Training Epoch 16 of 50 [MainProcess : MainThread (INFO)]
2021-09-17 14:53:50,865 - <ipython-input-15-b206d8af25dc>::<module>(l:35) : Global Test Outut after Epoch 16 of 50 {
    "accuracy": 94.75,
    "attack": {
        "attack_suc

In [None]:
# Display the repr of a freshly constructed ModelMNIST (this instance is not used for training).
nn.ModelMNIST()