Generate the configuration
===
0. Install all dependencies (in your venv) with `pip install -e .`
1. Set the experiment parameters in the "Experiment Settings" section.
2. Execute the cell below to overwrite the configuration files in `./config/` and the Flower settings in `pyproject.toml`.
3. Run `flwr run . local-sim-gpu`.

In [3]:
import random
random.seed(42)
import yaml
from attacks.attack_names import AttackNames
from detections.detection_names import DetectionNames
from dataset_names.dataset_names import DatasetNames

# --- Experiment Settings ---

# Dataset: mnist or cifar10
DATASET = DatasetNames.mnist

# Number of clients
NUM_CLIENTS = 10

# Number of federation rounds
NUM_ROUNDS = 20

# Fraction (0.0 - 1.0) of NUM_CLIENTS that act maliciously, e.g. 0.2 with 10 clients -> 2 attackers.
PERCENT_MALICIOUS = 0.2

# Fraction of the clients to choose for training in one round. We always use 1.0
FRACTION_FIT = 1.0

# If true, plots about the loss and accuracy will be generated on wandb
# TODO try to use this for all the metrics
USE_WANDB = False

# Attack method to use
ATTACK_METHOD = AttackNames.advanced_delta_weights_attack

# Detection method to use
DETECTION_METHOD = DetectionNames.wef_detection
# --- --- --- --- --- --- ---

# Fail fast on settings that would otherwise crash later in the cell
# (division by NUM_CLIENTS, sampling of the malicious client IDs).
if NUM_CLIENTS < 1:
    raise ValueError(f"NUM_CLIENTS must be at least 1, got {NUM_CLIENTS}")
if not 0.0 <= PERCENT_MALICIOUS <= 1.0:
    raise ValueError(f"PERCENT_MALICIOUS must be a fraction in [0, 1], got {PERCENT_MALICIOUS}")

# From here on DATASET is the enum's plain string value (presumably "mnist"/"cifar10"),
# which is concatenated into paths and written to the YAML config below.
DATASET = DATASET.value
# ========
# Detection-specific extra parameters; only the branch matching DETECTION_METHOD runs.
# The resulting dict is stored in ./config/detection_method.yaml together with the method name.
ADDITIONAL_DETECTION_CONFIG = {}
if DETECTION_METHOD.value == DetectionNames.dagmm_detection.value:
    ADDITIONAL_DETECTION_CONFIG["do_data_collection"] = False   # Set to True to collect training data for the DAGMM model.
    ADDITIONAL_DETECTION_CONFIG["dagmm_output_dir"] = "./dagmm/dagmm/dagmm_train_data/" + DATASET + "/run_test_3"  # Output directory of the training data of the current run.

    ADDITIONAL_DETECTION_CONFIG["dagmm_threshold_path"] = "./dagmm/dagmm/dagmm_anomaly_threshold.yaml"
    ADDITIONAL_DETECTION_CONFIG["dagmm_ignore_up_to"] = 0   # Does not perform the detection in the first x rounds
    # NOTE(review): the model file name is hard-coded to MNIST regardless of DATASET — confirm
    # a matching model exists before running this detector on cifar10.
    ADDITIONAL_DETECTION_CONFIG["dagmm_model_path"] = "./dagmm/dagmm/dagmm_model_mnist.pt"
    ADDITIONAL_DETECTION_CONFIG["dagmm_hyperparameters_path"] = "./dagmm/dagmm/dagmm_hyperparameters.yaml"
    ADDITIONAL_DETECTION_CONFIG["gmm_parameters_paths"] = {
        "cov": "./dagmm/dagmm/gmm_param_cov.pt",
        "mean": "./dagmm/dagmm/gmm_param_mean.pt",
        "mixture": "./dagmm/dagmm/gmm_param_mixture.pt",
    }
elif DETECTION_METHOD.value == DetectionNames.std_dagmm_detection.value:
    ADDITIONAL_DETECTION_CONFIG["do_data_collection"] = False   # Set to True to perform data collection with this detector.
    # Looks like a placeholder — set a real directory before enabling do_data_collection.
    ADDITIONAL_DETECTION_CONFIG["dagmm_output_dir"] = "-----"
    ADDITIONAL_DETECTION_CONFIG["dagmm_threshold_path"] = "./dagmm/std_dagmm/dagmm_anomaly_threshold.yaml"
    ADDITIONAL_DETECTION_CONFIG["dagmm_ignore_up_to"] = 0   # Does not perform the detection in the first x rounds
    # NOTE(review): the model file name is hard-coded to MNIST regardless of DATASET — confirm.
    ADDITIONAL_DETECTION_CONFIG["dagmm_model_path"] = "./dagmm/std_dagmm/dagmm_model_mnist.pt"
    ADDITIONAL_DETECTION_CONFIG["dagmm_hyperparameters_path"] = "./dagmm/std_dagmm/dagmm_hyperparameters.yaml"
    ADDITIONAL_DETECTION_CONFIG["gmm_parameters_paths"] = {
        "cov": "./dagmm/std_dagmm/gmm_param_cov.pt",
        "mean": "./dagmm/std_dagmm/gmm_param_mean.pt",
        "mixture": "./dagmm/std_dagmm/gmm_param_mixture.pt",
    }
elif DETECTION_METHOD.value == DetectionNames.delta_dagmm_detection.value:
    # Trained Delta-DAGMM artefacts are stored per dataset and client count. The dataset part
    # follows DATASET (like dagmm_output_dir below), so a cifar10 run does not silently pick
    # up the MNIST model, threshold and GMM parameters.
    delta_model_dir = "./dagmm/delta_dagmm/models/" + DATASET + "/" + str(NUM_CLIENTS)

    ADDITIONAL_DETECTION_CONFIG["do_data_collection"] = False   # Set to True to collect training data for the Delta-DAGMM model.
    ADDITIONAL_DETECTION_CONFIG["do_test_data_collection"] = False   # Set to True to collect test data (additionally collects the global model for each iteration)
    ADDITIONAL_DETECTION_CONFIG["dagmm_output_dir"] = "./dagmm/delta_dagmm/dagmm_train_data/" + DATASET + "/" + str(NUM_CLIENTS) + "/run_malicious_2"  # Output directory of the training data of the current run.

    ADDITIONAL_DETECTION_CONFIG["dagmm_threshold_path"] = delta_model_dir + "/dagmm_anomaly_threshold.yaml"
    # Skip detection in the first round: the initial global model is randomly initialized by
    # the server, so first-round deltas are not meaningful for Delta-DAGMM.
    ADDITIONAL_DETECTION_CONFIG["dagmm_ignore_up_to"] = 1
    ADDITIONAL_DETECTION_CONFIG["dagmm_model_path"] = delta_model_dir + "/model.pt"
    ADDITIONAL_DETECTION_CONFIG["dagmm_hyperparameters_path"] = delta_model_dir + "/dagmm_hyperparameters.yaml"
    ADDITIONAL_DETECTION_CONFIG["gmm_parameters_paths"] = {
        "cov": delta_model_dir + "/gmm_param_cov.pt",
        "mean": delta_model_dir + "/gmm_param_mean.pt",
        "mixture": delta_model_dir + "/gmm_param_mixture.pt",
    }

elif DETECTION_METHOD.value == DetectionNames.rffl_detection.value:
    ADDITIONAL_DETECTION_CONFIG["alpha"] = 0.95
    ADDITIONAL_DETECTION_CONFIG["beta"] = 1/(3*NUM_CLIENTS)
    ADDITIONAL_DETECTION_CONFIG["gamma"] = 0.5 if DATASET == "mnist" else 0.15  # Use 0.5 for MNIST and 0.15 for CIFAR10 (as specified by the authors)

elif DETECTION_METHOD.value == DetectionNames.fdfl_detection.value:
    # If the cosine similarity of two submitted data distributions for two clients of the same
    # cluster is smaller than tau, i.e. they are not very similar, the flag counter is increased
    # (if the flag count gets very high, the client is marked as a free rider).
    ADDITIONAL_DETECTION_CONFIG["tau"] = 0.2
    # Use 3 clusters for 10 clients and 5 for 100 clients; derived from NUM_CLIENTS so it
    # does not have to be switched by hand (client counts above 10 use 5).
    ADDITIONAL_DETECTION_CONFIG["n_clusters"] = 3 if NUM_CLIENTS <= 10 else 5
    method = "weak"     # Perform either 'weak' or 'strong' imitation of label distributions.
    with open("./config/fake_label_distribution.yaml", "w") as f:
        yaml.dump({"method": method}, f)

elif DETECTION_METHOD.value == DetectionNames.viceroy_detection.value:
    ADDITIONAL_DETECTION_CONFIG["omega"] = 0.525    # History decay factor (value taken from the paper)
    ADDITIONAL_DETECTION_CONFIG["eta"] = 0.2        # Reputation update factor (value taken from the paper)
    ADDITIONAL_DETECTION_CONFIG["kappa"] = 0.5      # Confidence parameter of FoolsGold.
    # Skip the first round: the initial global model is randomly initialized,
    # so the gradient calculation can be messy.
    ADDITIONAL_DETECTION_CONFIG["skip_first_round"] = True
    # Threshold that separates benign clients from free riders. The compared values range
    # from 0 (very suspicious client) to 1 (very unsuspicious client).
    ADDITIONAL_DETECTION_CONFIG["free_rider_threshold"] = 0.1

print("Using the detection method: ", DETECTION_METHOD.value)
# Warn when DAGMM-family data collection is enabled while some clients are malicious.
# .get() avoids a KeyError for any config that has no "do_data_collection" entry.
if (
    "dagmm" in DETECTION_METHOD.value
    and ADDITIONAL_DETECTION_CONFIG.get("do_data_collection", False)
    and PERCENT_MALICIOUS != 0.0
):
    print("\n\n####### Warning: Performing data collection with malicious clients! Only use if intended! #######\n\n")
# ========

# ~~~~~~~~
# Extra parameters per attack, keyed by the attack's string value. Attacks without an entry
# get an empty dict, in which case nothing extra is written to ./config/attack_method.yaml.
ADDITIONAL_ATTACK_CONFIG = {
    "random_weights": {
        "R": 1e-2,  # 1e-1 is best against DAGMM/STD-DAGMM and 1e-2 best against Delta-DAGMM
    },
    "advanced_free_rider": {
        "n": NUM_CLIENTS,
    },
}.get(ATTACK_METHOD.value, {})
# ~~~~~~~~

# Number of malicious clients: floor(NUM_CLIENTS * PERCENT_MALICIOUS), clamped to [0, NUM_CLIENTS].
# Rounding to 9 decimals before flooring absorbs float error (100 * 0.29 == 28.999999999999996,
# which a bare int() would truncate to 28). The upper clamp keeps random.sample from raising
# ValueError when the requested count exceeds the population.
num_malicious = min(NUM_CLIENTS, max(0, int(round(NUM_CLIENTS * PERCENT_MALICIOUS, 9))))
# random was seeded with 42 at the top of the cell, so the chosen IDs are deterministic
# for fixed settings.
malicious_clients = sorted(random.sample(range(NUM_CLIENTS), num_malicious))

print(f"Selected {num_malicious} malicious clients out of {NUM_CLIENTS}")
print("Malicious client IDs:", malicious_clients)

import yaml

def _write_yaml(path, mapping):
    """Overwrite the file at *path* with *mapping* serialized as a single YAML document."""
    with open(path, "w") as f:
        yaml.dump(mapping, f)


# Store the dataset to use as a .yaml file
_write_yaml("./config/dataset.yaml", {"dataset": DATASET})

# Store the list of malicious clients as a .yaml file
_write_yaml("./config/malicious_clients.yaml", {"malicious_clients": malicious_clients})

# Store the attack method and its extra parameters as ONE mapping, so the file is a single
# well-formed YAML document. The method name is placed last so an extra parameter with the
# same key can never override it.
_write_yaml(
    "./config/attack_method.yaml",
    {**ADDITIONAL_ATTACK_CONFIG, "attack_method": ATTACK_METHOD.value},
)

# Store the detection method and its extra parameters as ONE mapping (same reasoning as above).
_write_yaml(
    "./config/detection_method.yaml",
    {**ADDITIONAL_DETECTION_CONFIG, "detection_method": DETECTION_METHOD.value},
)

import toml

NUM_CORES = 24  # For 13th Gen Intel(R) Core(TM) i9-13900KF
NUM_CPUS = 1

# Fraction of a GPU requested per client in the "local-sim-gpu" federation: the GPU is split
# evenly across clients (1/NUM_CLIENTS), but each share is floored at 1/NUM_CORES —
# presumably so that no more than NUM_CORES clients are packed onto the GPU at once.
min_gpu_perc = 1 / NUM_CORES
NUM_GPUS = max(1 / NUM_CLIENTS, min_gpu_perc)

print("num_gpus: ", NUM_GPUS)

# Load pyproject.toml
with open("pyproject.toml", "r") as f:
    data = toml.load(f)

# Overwrite the Flower app settings.
_flwr = data["tool"]["flwr"]
_app_config = _flwr["app"]["config"]
_app_config["num-server-rounds"] = NUM_ROUNDS
_app_config["fraction-fit"] = FRACTION_FIT
_app_config["use-wandb"] = USE_WANDB

# Overwrite the simulation federations (client count and per-client resources).
_federations = _flwr["federations"]
_federations["local-sim"]["options"]["num-supernodes"] = NUM_CLIENTS
_gpu_options = _federations["local-sim-gpu"]["options"]
_gpu_options["num-supernodes"] = NUM_CLIENTS
_client_resources = _gpu_options["backend"]["client-resources"]
_client_resources["num-cpus"] = NUM_CPUS
_client_resources["num-gpus"] = NUM_GPUS

# Save changes back to pyproject.toml.
# NOTE(review): toml.dump rewrites the whole file; comments/custom formatting in
# pyproject.toml are likely not preserved — confirm this is acceptable.
with open("pyproject.toml", "w") as f:
    toml.dump(data, f)

Using the detection method:  wef
Selected 2 malicious clients out of 10
Malicious client IDs: [0, 1]
num_gpus:  0.1
