# Colab Setup  
> Make sure the notebook is configured to use a GPU: click Edit -> Notebook settings -> Hardware accelerator -> GPU.

> Uncomment the following cell after opening this notebook in Google Colab. (Do not uncomment it in a local setup.)  

<a target="_blank" href="https://colab.research.google.com/github/SEED-VT/FedDebug/blob/main/fault-localization/artifact.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>



In [None]:
# !pip install pytorch-lightning
# !pip install diskcache
# !pip install dotmap
# !pip install torch torchvision torchaudio
# !git clone https://github.com/SEED-VT/FedDebug.git
# # appending the path
# import sys
# sys.path.append("FedDebug/fault-localization/")

In [None]:
import logging
import time
from dotmap import DotMap
from pytorch_lightning import seed_everything
from torch.nn.init import kaiming_uniform_ 
from utils.faulty_client_localization.FaultyClientLocalization import FaultyClientLocalization
from utils.faulty_client_localization.InferenceGuidedInputs import InferenceGuidedInputs
from utils.FLSimulation import trainFLMain

# Route root-logger records at ERROR level and above to example.log.
logging.basicConfig(filename='example.log', level=logging.ERROR)
logger = logging.getLogger("pytorch_lightning")
# basicConfig only configures the root logger. PyTorch Lightning sets an
# explicit INFO level on its own "pytorch_lightning" logger when imported, so
# its informational messages would still be shown. Raise that logger to ERROR
# too so the notebook output stays quiet, as fetching this logger intends.
logger.setLevel(logging.ERROR)
# Seed the random number generators (via Lightning's helper) for reproducibility.
seed_everything(786)

# Description

This notebook runs a simulation of fault localization in federated learning (FL).
- It first trains a federated learning model using the provided arguments (e.g. `learning rate`, `weight decay`, `batch size`, `model architecture`, `number of epochs`, `dataset`, `number of clients`, and `faulty client IDs`).

- Then, it runs the `runFaultyClientLocalization` function on the trained client models to identify the `potential faulty clients` for each generated input. This function uses the `InferenceGuidedInputs` class to generate test inputs for the models and the `FaultyClientLocalization` class to perform the fault localization.

- Finally, it uses the `evaluateFaultLocalization` function to compute the fault localization accuracy by comparing the `predicted faulty clients` with the `true faulty clients`. For each input, the accuracy is the number of `correctly localized faults` divided by the `total number of true faults`, expressed as a percentage. The reported `accuracy` is the average of these per-input values over all inputs.

- It also prints out the predicted faulty clients for each input.

In [None]:
def evaluateFaultLocalization(predicted_faulty_clients_on_each_input, true_faulty_clients):
    """Score fault localization as the mean per-input recall of the true faulty clients.

    For every input, the percentage of true faulty clients that appear in that
    input's prediction is computed. The returned value is the average of these
    percentages over all inputs. Each prediction is printed as it is scored.

    Args:
        predicted_faulty_clients_on_each_input: an iterable (list, generator, ...)
            yielding one iterable of predicted faulty client IDs per input.
        true_faulty_clients: iterable of the IDs of the clients that really are faulty.

    Returns:
        The localization accuracy as a percentage in [0, 100].

    Raises:
        ValueError: if ``true_faulty_clients`` is empty or there are no
            predictions, because the accuracy is undefined in either case.
    """
    true_faulty_clients = set(true_faulty_clients)
    if not true_faulty_clients:
        raise ValueError("true_faulty_clients must contain at least one client ID")
    # Materialize once so generators are supported and len() is available.
    predictions = list(predicted_faulty_clients_on_each_input)
    if not predictions:
        raise ValueError("predicted_faulty_clients_on_each_input must not be empty")

    num_true_faults = len(true_faulty_clients)
    detection_acc = 0.0
    for pred_faulty_clients in predictions:
        print(f"+++ Faulty Clients {pred_faulty_clients}")
        correct_localize_faults = len(true_faulty_clients.intersection(pred_faulty_clients))
        detection_acc += (correct_localize_faults / num_true_faults) * 100
    return detection_acc / len(predictions)


def runFaultyClientLocalization(client2models, exp2info, num_bugs,
                                random_generator=kaiming_uniform_,
                                apply_transform=True, k_gen_inputs=10,
                                na_threshold=0.003, use_gpu=True,
                                min_nclients_same_pred=5):
    """Generate test inputs, then localize potential faulty clients for each one.

    Args:
        client2models: mapping from client ID to that client's trained model.
        exp2info: experiment metadata; ``exp2info['data_config']`` must provide
            ``'single_input_shape'`` and ``'name'``.
        num_bugs: number of faulty clients to search for; forwarded to
            ``FaultyClientLocalization.runFaultLocalization``.
        random_generator: initializer forwarded to ``InferenceGuidedInputs``
            as ``randomGenerator``.
        apply_transform: forwarded to ``InferenceGuidedInputs``.
        k_gen_inputs: forwarded to ``InferenceGuidedInputs`` (presumably the
            number of inputs to generate -- confirm).
        na_threshold: first argument of ``runFaultLocalization`` (presumably a
            neuron-activation threshold -- confirm).
        use_gpu: forwarded to ``FaultyClientLocalization``.
        min_nclients_same_pred: forwarded to ``InferenceGuidedInputs``; the
            default of 5 matches the value this function always used before.

    Returns:
        A tuple ``(potential_faulty_clients_for_each_input, input_gen_time,
        fault_localization_time)``. ``input_gen_time`` is whatever
        ``InferenceGuidedInputs.getInputs`` reports. ``fault_localization_time``
        is the elapsed seconds spent constructing and running
        ``FaultyClientLocalization`` (input generation excluded).
    """
    print(">  Running FaultyClientLocalization ..")
    data_config = exp2info['data_config']
    input_shape = list(data_config['single_input_shape'])
    generate_inputs = InferenceGuidedInputs(
        client2models, input_shape, randomGenerator=random_generator,
        apply_transform=apply_transform, dname=data_config['name'],
        min_nclients_same_pred=min_nclients_same_pred,
        k_gen_inputs=k_gen_inputs)
    selected_inputs, input_gen_time = generate_inputs.getInputs()

    # perf_counter is monotonic, so the measured interval cannot be skewed by
    # system clock adjustments the way time.time() can.
    start = time.perf_counter()
    faultyclientlocalization = FaultyClientLocalization(
        client2models, selected_inputs, use_gpu=use_gpu)
    # The result is used as the suspected *faulty* clients for each selected
    # input. NOTE(review): confirm against FaultyClientLocalization.
    potential_faulty_clients_for_each_input = faultyclientlocalization.runFaultLocalization(
        na_threshold, num_bugs=num_bugs)
    fault_localization_time = time.perf_counter() - start
    return potential_faulty_clients_for_each_input, input_gen_time, fault_localization_time



In [None]:
# ====== Simulation Config ======
args = DotMap(
    # --- Training hyper-parameters ---
    lr=0.001,
    weight_decay=0.0001,
    batch_size=512,
    # --- Model and data ---
    # Architecture: one of resnet18, resnet34, resnet50, densenet121, vgg16.
    # On Colab, prefer resnet18, resnet34 or densenet121.
    model="resnet50",
    # Number of training epochs; the suggested range is 10-25.
    epochs=5,
    # Dataset: 'cifar10' or 'femnist'.
    dataset="cifar10",
    # --- Federated setup ---
    # Total number of clients; keep it under 30 when running on Colab.
    clients=5,
    # Comma-separated IDs of the faulty clients, e.g. "0,1,2". Use fewer
    # faulty clients than args.clients, and at most 6.
    faulty_clients_ids="0",
    # Noise rate, between 0 and 1.
    noise_rate=1,
    # Data distribution across clients: "iid" or "niid".
    sampling="iid",
)


In [None]:
# Federated-learning training: trainFLMain returns a per-client mapping and
# the experiment metadata used by the later cells.
c2ms, exp2info = trainFLMain(args)
# Keep only each client's underlying model, switched to evaluation mode via .eval().
client2models = {
    client_id: trained_client.model.eval()
    for client_id, trained_client in c2ms.items()
}



In [None]:
# Fault localization: find the potential faulty clients for each generated input.
# Only the per-input predictions are kept; both timing values are discarded.
potential_faulty_clients, _, _ = runFaultyClientLocalization(
    client2models=client2models,
    exp2info=exp2info,
    num_bugs=len(exp2info['faulty_clients_ids']),
)


In [None]:
# Score the localization against the clients that were actually made faulty.
acc = evaluateFaultLocalization(potential_faulty_clients,
                                exp2info['faulty_clients_ids'])
print(f"Fault Localization Accuracy: {acc}")
