In [73]:
import flwr
from flwr_datasets import FederatedDataset
from flwr_datasets.partitioner import DirichletPartitioner
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, roc_auc_score
from matplotlib import pyplot as plt
import csv
import os
from sklearn.preprocessing import LabelEncoder
from datasets import Dataset
import pandas as pd
import random
import numpy as np
import copy
from Levenshtein import distance as levenshtein_distance

print(f"Flower {flwr.__version__} is installed")


Flower 1.15.2 is installed


In [74]:
def load_federated_data(num_clients, alpha, test_size, val_size):
    """
    Loads the CSV file, drops the 'attack_cat' column if present, and converts specified 
    non-numeric columns ("proto", "service", "state") to numeric using label encoding.
    The dataset is then split globally into training, validation, and test sets.
    The global training set is partitioned non-iid among clients using Flower's DirichletPartitioner.
    
    Args:
        num_clients (int): Number of federated clients. With 1, the whole training
            split is returned as client 0 and no partitioning is done.
        alpha (float): Dirichlet concentration parameter.
        test_size (float): Fraction of the overall data to reserve as test data.
            Must be greater than 0.
        val_size (float): Fraction of the overall data to reserve as validation data.
            With 0, no separate validation split is made.
        
    Returns:
        tuple: A tuple containing:
            - dict: A dictionary mapping client IDs to their local training DataFrame.
            - pd.DataFrame: The global test set.
            - pd.DataFrame: The global validation set.
            - dict: A dictionary mapping each of the categorical columns ("proto", "service", "state")
                    to its label encoder mapping (list of classes).

        When ``val_size == 0`` or ``num_clients == 1`` the whole held-out split is
        returned as BOTH the test set and the validation set. Validation-based
        merge strategies would then be scored on test data.

    Raises:
        ValueError: If ``test_size`` is not greater than 0.
    """
    # Fail fast instead of returning None, which the caller cannot unpack.
    if test_size <= 0:
        raise ValueError("No Test size please set a test size")

    # Hard-coded file path and random seed
    file_path = "./Datasets/merged_UNSW_NB15.csv"
    random_state = 42

    # Load CSV into a pandas DataFrame and drop the 'attack_cat' column if it exists
    df = pd.read_csv(file_path)
    print("Original data shape:", df.shape)
    if "attack_cat" in df.columns:
        df = df.drop(columns=["attack_cat"])
        print("Dropped 'attack_cat'. New shape:", df.shape)

    # Label encoding for specified categorical columns
    label_encoders = {}
    for col in ["proto", "service", "state"]:
        if col in df.columns:
            le = LabelEncoder()
            df[col] = le.fit_transform(df[col])
            label_encoders[col] = {"classes": le.classes_.tolist()}
            print(f"Column '{col}' encoded with classes: {le.classes_.tolist()}")
    
    # Global split: first separate out the combined test+validation set.
    total_test_val = test_size + val_size
    print(total_test_val)
    global_train, global_test_val = train_test_split(
        df,
        test_size=total_test_val,
        stratify=df["label"],
        random_state=random_state
    )
    
    # Split the global test+validation set into separate test and validation sets.
    if val_size == 0:
        print("No Val size set doing centralised training")
        # No dedicated validation split exists: fall back to the held-out set for
        # both roles so the final return below never hits an unbound name.
        global_test = global_test_val
        global_val = global_test_val
    else:
        test_ratio = test_size / total_test_val
        # train_test_split returns (first part, second part); the second part
        # gets the `test_size` fraction, i.e. the validation share here.
        global_test, global_val = train_test_split(
            global_test_val,
            test_size=(1 - test_ratio),
            stratify=global_test_val["label"],
            random_state=random_state
        )
    
    print("Global train shape:", global_train.shape)
    if val_size != 0:
        print("Global validation shape:", global_val.shape)
        print("Global test shape:", global_test.shape)
    else:
        print("No validation set")
        print("Global test shape:", global_test_val.shape)
    
    # Partition the global training data among clients using Flower's DirichletPartitioner.
    if num_clients == 1:
        # If there is only one client, return the entire training set as the client's training set.
        return {0: global_train}, global_test_val, global_test_val, label_encoders #Val is not needed for centralised
    hf_train_dataset = Dataset.from_pandas(global_train, preserve_index=False)
    partitioner = DirichletPartitioner(
        num_partitions=num_clients,
        partition_by="label",
        alpha=alpha,
        min_partition_size=10,
        self_balancing=True,
        shuffle=True,
        seed=random_state,
    )
    partitioner.dataset = hf_train_dataset
    
    client_train_dataset = {}
    
    # For each client, load its (non-iid) Dirichlet training partition.
    for client in range(num_clients):
        client_train = partitioner.load_partition(client).to_pandas()
        client_train_dataset[client] = client_train
        print(f"Client {client}: Train {client_train.shape}")
    
    return client_train_dataset, global_test, global_val, label_encoders


In [75]:
class MetricsTracker:
    """Collects one metrics row per training round and exports them as CSV and a plot."""

    def __init__(self, output_dir="metrics"):
        """
        Args:
            output_dir (str): Directory for the CSV and plot; created if missing.
        """
        self.output_dir = output_dir
        os.makedirs(output_dir, exist_ok=True)
        self.metrics = []

    def log_round_metrics(self, round_num, global_metrics, client_metrics, merge_method, num_clients, num_trees, max_depth, criterion):
        """
        Append one row describing a finished round.

        The row holds the experiment settings passed in, then every key of
        ``global_metrics`` and ``client_metrics`` (later updates win on key clashes).
        """
        round_metrics = {
            "round_num": round_num,
            "merge_method": merge_method,
            "num_clients": num_clients,
            "num_trees": num_trees,
            "max_depth": max_depth,
            "criterion": criterion,
        }
        round_metrics.update(global_metrics)
        round_metrics.update(client_metrics)
        self.metrics.append(round_metrics)

    def save_metrics_to_csv(self, filename="metrics.csv"):
        """
        Write all logged rows to ``<output_dir>/<filename>``.

        Does nothing (apart from printing a notice) when no rounds were logged.
        The header is the union of keys over all rows, in first-seen order, so
        rows with differing keys are written with empty cells for missing keys.
        """
        if not self.metrics:
            print("No metrics recorded; nothing to save.")
            return

        # dict.fromkeys keeps first-seen order while removing duplicates.
        fieldnames = list(dict.fromkeys(key for row in self.metrics for key in row))
        with open(os.path.join(self.output_dir, filename), 'w', newline='') as output_file:
            dict_writer = csv.DictWriter(output_file, fieldnames=fieldnames, restval="")
            dict_writer.writeheader()
            dict_writer.writerows(self.metrics)

    def plot_metrics(self):
        """
        Plot the global metrics over rounds, save the figure to
        ``<output_dir>/metrics_plot.png`` and show it.

        Every logged row must contain the ``global_*`` keys plotted here.
        """
        rounds = [m["round_num"] for m in self.metrics]
        series = [
            ("global_accuracy", "Accuracy"),
            ("global_recall", "Recall"),
            ("global_roc_auc", "ROC AUC"),
            ("global_precision", "Precision"),
            ("global_f1_score", "F1 Score"),
        ]

        plt.figure(figsize=(10, 6))
        for key, label in series:
            plt.plot(rounds, [m[key] for m in self.metrics], label=label)
        plt.xlabel("Round")
        plt.ylabel("Score")
        plt.title("Global Model Metrics Over Rounds")
        plt.legend()
        plt.savefig(os.path.join(self.output_dir, "metrics_plot.png"))
        plt.show()

In [76]:
class ModelMerger:
    """
    Strategies for combining the trees of several client random forests into
    one global forest, plus an optional pre-pass that drops near-duplicate trees.

    Every ``merge_models_*`` strategy returns a ``RandomForestClassifier`` whose
    ``estimators_`` are the selected trees and whose fitted metadata
    (``n_features_in_``, ``classes_``, ...) is copied from ``client_models[0]``.
    """

    def __init__(self, global_val):
        """
        Initialize with the global validation dataset.
        This dataset is used for methods that need to evaluate model performance.
        """
        self.global_val = global_val

    @staticmethod
    def _copy_fitted_metadata(source_model, target_model):
        """Copy the fitted-forest attributes the merge strategies rely on from source to target."""
        target_model.n_features_in_ = source_model.n_features_in_
        target_model.classes_ = source_model.classes_
        target_model.n_classes_ = len(source_model.classes_)
        target_model.n_outputs_ = 1
        if hasattr(source_model, "feature_names_in_"):
            target_model.feature_names_in_ = source_model.feature_names_in_
        return target_model

    def _build_forest(self, source_model, trees):
        """Wrap ``trees`` in a new forest carrying ``source_model``'s fitted metadata."""
        merged_model = RandomForestClassifier(n_estimators=0, warm_start=True)
        self._copy_fitted_metadata(source_model, merged_model)
        merged_model.estimators_ = trees
        merged_model.n_estimators = len(trees)
        return merged_model

    def merge_models_randomly(self, client_models, num_input):
        """Pick up to ``num_input`` trees uniformly at random from all client trees."""
        # Collect all trees from all client models.
        all_trees = [tree for model in client_models for tree in model.estimators_]
        print("Number of trees in global model:", len(all_trees))
        selected_trees = random.sample(all_trees, min(len(all_trees), num_input))
        return self._build_forest(client_models[0], selected_trees)

    def merge_models_by_impurity(self, client_models, num_input):
        """Keep the ``num_input`` trees with the lowest root-node impurity (lower is better)."""
        all_trees = [tree for model in client_models for tree in model.estimators_]
        # sorted() is stable, so ties keep their client/tree order.
        ranked_trees = sorted(all_trees, key=lambda tree: tree.tree_.impurity[0])
        selected_trees = ranked_trees[:num_input]
        print("Number of trees selected for global model:", len(selected_trees))
        return self._build_forest(client_models[0], selected_trees)

    def merge_models_weight_global(self, client_models, num_input):
        """
        Keep the ``num_input`` trees with the highest accuracy on the global validation set.

        NOTE(review): each tree is scored via ``tree.predict``. Sub-estimators of a
        scikit-learn forest are presumably fitted on encoded class indices, so this
        matches ``label`` only if labels are 0..k-1 and each client saw every
        class. TODO confirm.
        """
        X_val = self.global_val.drop(columns=["label"]).values
        y_val = self.global_val["label"].values
        scored_trees = [(tree, accuracy_score(y_val, tree.predict(X_val)))
                        for model in client_models for tree in model.estimators_]
        best_trees = sorted(scored_trees, key=lambda x: x[1], reverse=True)[:num_input]
        # Built like the other strategies so the result carries fitted metadata too.
        return self._build_forest(client_models[0], [tree for tree, _ in best_trees])

    def merge_models_diversity(self, client_models, num_input):
        """
        Select trees to maximize feature diversity in the global model.

        The first ``num_input // 4`` trees are the lowest-impurity ones. The rest
        are added greedily, each time taking the tree whose important features
        are, on average, least used by the trees selected so far.
        """
        all_trees = [tree for model in client_models for tree in model.estimators_]
        # First, select a few of the best trees based on impurity.
        ranked_trees = sorted(all_trees, key=lambda tree: tree.tree_.impurity[0])
        seed_count = num_input // 4
        selected_trees = ranked_trees[:seed_count]

        # Important features are computed once per candidate rather than on
        # every pass of the greedy loop.
        candidates = [(tree, self._get_important_features(tree)) for tree in ranked_trees[seed_count:]]

        feature_usage = {}
        # Initialize feature usage from already selected trees.
        for tree in selected_trees:
            for feature in self._get_important_features(tree):
                feature_usage[feature] = feature_usage.get(feature, 0) + 1

        # Greedy selection to maximize feature diversity.
        while len(selected_trees) < num_input and candidates:
            best_idx = None
            best_score = None
            for idx, (_, tree_features) in enumerate(candidates):
                if not tree_features:
                    continue
                score = sum(feature_usage.get(feature, 0) for feature in tree_features) / len(tree_features)
                # Strict '<' keeps the earliest candidate on ties.
                if best_idx is None or score < best_score:
                    best_idx = idx
                    best_score = score
            if best_idx is None:
                # No remaining candidate has any important feature.
                break
            best_tree, best_features = candidates.pop(best_idx)
            selected_trees.append(best_tree)
            for feature in best_features:
                feature_usage[feature] = feature_usage.get(feature, 0) + 1

        return self._build_forest(client_models[0], selected_trees)

    def merge_models_weighted_voting(self, client_models, num_input):
        """
        Create an ensemble where each client contributes proportionally to their local performance.

        Each client forest is scored by accuracy on the global validation set. The
        tree budget ``num_input`` is shared out in proportion to those scores (at
        least one tree per client while budget remains; the lowest-scoring client
        gets the remainder). Within a client the lowest root-impurity trees are taken.
        If every client scores 0, the budget is shared evenly instead.
        """
        # Use global_val as the evaluation set.
        X_val = self.global_val.drop(columns=["label"]).values
        y_val = self.global_val["label"].values
        client_scores = [(i, accuracy_score(y_val, model.predict(X_val)))
                         for i, model in enumerate(client_models)]

        sorted_clients = sorted(client_scores, key=lambda x: x[1], reverse=True)
        total_score = sum(score for _, score in sorted_clients)
        trees_per_client = {}
        remaining_trees = num_input
        for rank, (client_idx, score) in enumerate(sorted_clients):
            if rank == len(sorted_clients) - 1:
                trees_per_client[client_idx] = remaining_trees
            else:
                # Guard against a zero total (all clients scored 0).
                share = score / total_score if total_score > 0 else 1 / len(sorted_clients)
                client_trees = max(1, int(share * num_input))
                trees_per_client[client_idx] = min(client_trees, remaining_trees)
                remaining_trees -= trees_per_client[client_idx]

        selected_trees = []
        for i, model in enumerate(client_models):
            trees_to_select = trees_per_client.get(i, 0)
            if trees_to_select == 0:
                continue
            if trees_to_select >= len(model.estimators_):
                selected_trees.extend(model.estimators_)
            else:
                ranked_trees = sorted(model.estimators_, key=lambda tree: tree.tree_.impurity[0])
                selected_trees.extend(ranked_trees[:trees_to_select])

        return self._build_forest(client_models[0], selected_trees)

    def merge_models_prune_similar(self, client_models):
        """
        Drop trees that ``are_trees_similar`` to a tree already kept from any
        earlier client or earlier in the same client.

        Returns:
            list: One pruned forest per input model, in the same order. Each one
            carries the fitted metadata of its source model so the merge
            strategies can use it as ``client_models[0]``.
        """
        print("Number of trees in global model before prune:", sum(len(model.estimators_) for model in client_models))
        unique_trees = []
        pruned_client_models = []
        for model in client_models:
            pruned_estimators = []
            for tree in model.estimators_:
                if not any(self.are_trees_similar(tree, existing_tree) for existing_tree in unique_trees):
                    unique_trees.append(tree)
                    pruned_estimators.append(tree)
            pruned_model = RandomForestClassifier(n_estimators=len(pruned_estimators), warm_start=True)
            pruned_model.estimators_ = pruned_estimators
            # Without this the merge strategies fail reading n_features_in_/classes_.
            self._copy_fitted_metadata(model, pruned_model)
            pruned_client_models.append(pruned_model)
        print("Number of trees in global model after prune:", len(unique_trees))
        return pruned_client_models

    def are_trees_similar(self, tree1, tree2, similarity_threshold=0.8):
        """
        Compare two trees based on their structure and parameters.

        Only the first (up to 7) nodes, by node index, are compared. Two nodes
        match when both are leaves, or when both split on the same feature with
        thresholds less than 0.1 apart. Returns True when the fraction of
        matching nodes exceeds ``similarity_threshold``.
        """
        tree1_struct = tree1.tree_
        tree2_struct = tree2.tree_
        
        # If trees have very different node counts, they're not similar
        if abs(tree1_struct.node_count - tree2_struct.node_count) > 5:
            return False
        
        # Check only the first few nodes for efficiency.
        max_check_node = min(7, min(tree1_struct.node_count, tree2_struct.node_count))
        similar_nodes = 0
        for i in range(max_check_node):
            is_leaf1 = tree1_struct.children_left[i] == -1
            is_leaf2 = tree2_struct.children_left[i] == -1

            # If both nodes are leaf nodes, count as similar
            if is_leaf1 and is_leaf2:
                similar_nodes += 1
                continue
            
            # If one is leaf and the other is not, the node does not match.
            if is_leaf1 != is_leaf2:
                continue
                
            # Check if split features are the same and thresholds are close.
            if tree1_struct.feature[i] == tree2_struct.feature[i]:
                if abs(tree1_struct.threshold[i] - tree2_struct.threshold[i]) < 0.1:
                    similar_nodes += 1
        
        similarity = similar_nodes / max_check_node
        return similarity > similarity_threshold
    
    def _get_important_features(self, tree, importance_threshold=0.01):
        """Return the indices of features whose importance in ``tree`` exceeds the threshold."""
        importances = tree.feature_importances_
        return [i for i, imp in enumerate(importances) if imp > importance_threshold]

In [None]:
class ModelEvaluator:
    """
    Utility class for evaluating forests whose trees may not all know every class.

    Probabilities are obtained tree by tree and padded out to the full class
    list before averaging, instead of calling the forest's own ``predict_proba``.
    """
    
    @staticmethod
    def patch_single_tree_proba(p, model_classes, global_classes):
        """
        Expand one tree's probability matrix to one column per global class.

        Args:
            p: Array of shape (n_samples, len(model_classes)).
            model_classes: Classes known to the tree, in the column order of ``p``.
            global_classes: Full, ordered class list of the output.

        Returns:
            Array of shape (n_samples, len(global_classes)); columns for classes
            the tree does not know are left at 0.
        """
        model_classes = np.asarray(model_classes)
        full_p = np.zeros((p.shape[0], len(global_classes)))
        for i, cls in enumerate(global_classes):
            if cls in model_classes:
                full_p[:, i] = p[:, model_classes == cls].flatten()
        return full_p
    
    @staticmethod
    def patched_predict_proba(model, X, global_classes):
        """Average the class-padded probabilities of every tree in ``model.estimators_``."""
        tree_probas = [
            ModelEvaluator.patch_single_tree_proba(tree.predict_proba(X), tree.classes_, global_classes)
            for tree in model.estimators_
        ]
        return np.mean(tree_probas, axis=0)
    
    @staticmethod
    def evaluate_model(model, X_test, y_test, model_name="Model", global_classes=None):
        """
        Compute accuracy, macro recall/precision/F1 and ROC AUC for ``model``.

        Args:
            model: Object exposing ``estimators_`` (trees with ``predict_proba`` and ``classes_``).
            X_test: Feature matrix.
            y_test: True labels; a 2-D one-hot array is reduced with ``argmax``.
            model_name (str): Name used in printed messages.
            global_classes: Ordered class list; defaults to the sorted unique labels in ``y_test``.

        Returns:
            dict: Keys "accuracy", "recall", "roc_auc", "precision", "f1_score".
            "roc_auc" is None when it cannot be computed. All values are None if
            evaluation fails altogether (the error is printed, not raised).
        """
        try:
            # Ensure y_test is a 1-dimensional array before any metric uses it.
            y_test = np.asarray(y_test)
            if y_test.ndim > 1 and y_test.shape[1] > 1:
                y_test = np.argmax(y_test, axis=1)

            if global_classes is None:
                global_classes = np.sort(np.unique(y_test))
            global_classes = np.asarray(global_classes)

            proba = ModelEvaluator.patched_predict_proba(model, X_test, global_classes)
            y_pred = global_classes[np.argmax(proba, axis=1)]
            
            accuracy = accuracy_score(y_test, y_pred)
            recall = recall_score(y_test, y_pred, average='macro', zero_division=0)
            precision = precision_score(y_test, y_pred, average='macro', zero_division=0)
            f1 = f1_score(y_test, y_pred, average='macro', zero_division=0)
            
            try:
                if len(global_classes) == 2:
                    # For binary classification, use the probabilities for the positive class.
                    roc_auc = roc_auc_score(y_test, proba[:, 1])
                else:
                    # For multi-class classification.
                    roc_auc = roc_auc_score(y_test, proba, multi_class='ovr')
            except ValueError as e:
                print(f"Error calculating ROC AUC for {model_name}: {e}")
                roc_auc = None
            
            metrics = {
                "accuracy": accuracy,
                "recall": recall,
                # float(None) would raise and discard the metrics computed above.
                "roc_auc": float(roc_auc) if roc_auc is not None else None,
                "precision": precision,
                "f1_score": f1
            }
            
            print(f"{model_name} accuracy: {metrics.get('accuracy')}")
            return metrics
            
        except Exception as e:
            # Deliberate best-effort: a failed evaluation must not abort the training loop.
            print(f"Error evaluating {model_name}: {e}")
            return {
                "accuracy": None,
                "recall": None,
                "roc_auc": None,
                "precision": None,
                "f1_score": None
            }
    

In [None]:
class Server:
    """
    Coordinates federated training: selects clients each round, collects their
    trained forests, merges them into a global forest, evaluates it and sends it
    back to every registered client.
    """

    def __init__(self, num_clients, clients_per_round, merge_method, num_global_trees, global_test_data, global_val_data, prune_similar):
        """
        Args:
            num_clients (int): Total number of clients in the experiment.
            clients_per_round (int): Number of clients sampled for each round.
            merge_method (str): One of "random", "weight_global", "weight_voting",
                "impurity", "diversity".
            num_global_trees (int): Tree budget passed to the merge strategy.
            global_test_data (pd.DataFrame): Test set with a "label" column.
            global_val_data (pd.DataFrame): Validation set handed to ModelMerger.
            prune_similar (bool): Whether to drop similar trees before merging.
        """
        self.num_clients = num_clients
        self.clients_per_round = clients_per_round
        self.merge_method = merge_method
        self.num_global_trees = num_global_trees
        self.clients = []
        self.global_model = RandomForestClassifier(n_estimators=0, warm_start=True)
        self.global_test = global_test_data
        self.global_val = global_val_data
        self.model_merger = ModelMerger(self.global_val)
        self.prune_similar = prune_similar
        # Private RNG seeded once: reproducible across runs, different clients
        # per round, and no side effect on the global `random` state.
        self._selection_rng = random.Random(42)

    def register_client(self, client):
        """Add a client to the pool that rounds are sampled from."""
        self.clients.append(client)

    def select_clients(self):
        """
        Randomly choose the clients for one round.

        Returns:
            list: ``clients_per_round`` distinct clients (all registered clients
            if fewer than that are registered).
        """
        sample_size = min(self.clients_per_round, len(self.clients))
        available_clients = self._selection_rng.sample(self.clients, sample_size)
        print(f"Selected Clients: {[client.client_id for client in available_clients]}")
        return available_clients

    def merge_models(self, client_models):
        """
        Merge client forests with the configured strategy.

        Raises:
            ValueError: If ``merge_method`` is not a known strategy name.
        """
        # Map merge method names to the corresponding ModelMerger method.
        merge_strategies = {
            "random": self.model_merger.merge_models_randomly,
            "weight_global": self.model_merger.merge_models_weight_global,
            "weight_voting": self.model_merger.merge_models_weighted_voting,
            "impurity": self.model_merger.merge_models_by_impurity,
            "diversity": self.model_merger.merge_models_diversity,
        }
        if self.merge_method not in merge_strategies:
            raise ValueError(f"Invalid merge method: {self.merge_method}")
        # Pruning is an optional pre-pass applied before the chosen strategy.
        if self.prune_similar:
            client_models = self.model_merger.merge_models_prune_similar(client_models)
        return merge_strategies[self.merge_method](client_models, self.num_global_trees)

    def distribute_global_model(self):
        """Send the current global model to every registered client."""
        for client in self.clients:
            client.update_model(self.global_model)
        print("Global model distributed to all clients.")

    def evaluate_global_model(self):
        """Evaluate the global model on the global test set and return its metrics dict."""
        X_test = self.global_test.drop(columns=["label"]).values
        y_test = self.global_test["label"].values
        return ModelEvaluator.evaluate_model(self.global_model, X_test, y_test, "Global model")

    def train_federated(self, configs, tracker=None):
        """
        Run ``configs["num_rounds"]`` rounds of select, train, merge, evaluate, distribute.

        Args:
            configs (dict): Reads "num_rounds", "num_trees", "max_depth",
                "num_max_features", "n_jobs", "merge_method", "clients_per_round"
                and "criterion".
            tracker: Object with ``log_round_metrics``, ``save_metrics_to_csv`` and
                ``plot_metrics``. Defaults to the module-level ``metrics_tracker``.
        """
        if tracker is None:
            # Backward-compatible fallback to the notebook-level tracker.
            tracker = metrics_tracker

        for round_num in range(1, configs["num_rounds"] + 1):
            print(f"\nStarting Round {round_num}")
            selected_clients = self.select_clients()
            round_models = []
            client_metrics = {}

            for client in selected_clients:
                client.train(
                    configs["num_trees"],
                    configs["max_depth"],
                    configs["num_max_features"],
                    configs["n_jobs"]
                )
                round_models.append(client.model)

            print("\nTraining finished")
            if len(round_models) > 1:
                self.global_model = self.merge_models(round_models)
            else:
                self.global_model = round_models[0]
            print("\nMerged models")
            
            global_metrics = self.evaluate_global_model()
            if global_metrics is None:
                global_metrics = {
                    "accuracy": None,
                    "recall": None,
                    "roc_auc": None,
                    "precision": None,
                    "f1_score": None
                }
            tracker.log_round_metrics(
                round_num,
                {"global_" + k: v for k, v in global_metrics.items()},
                client_metrics,
                configs["merge_method"],
                configs["clients_per_round"],
                configs["num_trees"],
                configs["max_depth"],
                configs["criterion"]
            )

            print("\nEvaluating Round Results:")
            for client in selected_clients:
                client.evaluate()
            print("Round", round_num, "complete!")

            self.distribute_global_model()

        print("\nFederated Training Complete!")
        tracker.save_metrics_to_csv()
        tracker.plot_metrics()


In [79]:
class Client:
    """A federated client holding a local training set and a random forest model."""

    def __init__(self, client_id, train_data, test_data, model_params):
        """
        Args:
            client_id: Identifier used in log messages.
            train_data (pd.DataFrame): Local training data with a "label" column.
            test_data (pd.DataFrame): Evaluation data with a "label" column.
            model_params (dict): Keyword arguments for ``RandomForestClassifier``.
        """
        self.client_id = client_id
        self.model_params = model_params
        self.model = RandomForestClassifier(**model_params)
        # Directly assign preprocessed training and test data
        self.train_data = train_data
        self.test_data = test_data

    def train(self, num_trees, max_depth, num_max_features, n_jobs):
        """
        Fit the local model on the client's training data, then optionally prune it.

        The four arguments are written onto the model (``n_estimators``,
        ``max_depth``, ``max_features``, ``n_jobs``) before fitting. Pruning runs
        only when the module-level ``configs`` has ``client_pruning`` enabled and
        ``NUM_CLIENTS`` is not 1 (centralised training is never pruned).
        """
        print(f"Trees received from server at Client: {self.client_id} No Estimators: {self.model.n_estimators}")

        X_train = self.train_data.drop(columns=["label"]).values
        y_train = self.train_data["label"].values

        self.model.n_estimators = num_trees
        self.model.max_depth = max_depth
        self.model.max_features = num_max_features
        self.model.n_jobs = n_jobs
        self.model.fit(X_train, y_train)

        print(f"Client: {self.client_id} trained {self.model.n_estimators} trees")

        # The original `!= 1 or != True` test pruned even when client_pruning was off.
        if configs["NUM_CLIENTS"] != 1 and configs["client_pruning"]:
            self.prune_redundant_trees()
        else:
            print("No pruning on client")

    def prune_redundant_trees(self):
        """
        A simple approach to prune obviously redundant trees.

        Trees are grouped by the set of their three most important features, and
        at most the first two trees of each group are kept. Forests with five or
        fewer estimators are left untouched.
        """
        if self.model.n_estimators <= 5:
            return

        tree_groups = {}
        for i, tree in enumerate(self.model.estimators_):
            top_features = np.argsort(tree.feature_importances_)[-3:]
            signature = tuple(sorted(top_features))
            tree_groups.setdefault(signature, []).append(i)

        # Slicing covers both the "two or fewer" and the "more than two" case.
        pruned_trees = [
            self.model.estimators_[idx]
            for group in tree_groups.values()
            for idx in group[:2]
        ]

        tree_count_before_prune = self.model.n_estimators
        self.model.estimators_ = pruned_trees
        self.model.n_estimators = len(pruned_trees)
        print(f"Pruned {tree_count_before_prune - self.model.n_estimators} trees from client {self.client_id}. Total Trees: {self.model.n_estimators}")

    def evaluate(self, global_classes=None):
        """Evaluate the local model on ``test_data`` and return the metrics dict."""
        X_test = self.test_data.drop(columns=["label"]).values
        y_test = self.test_data["label"].values
        model_name = f"Client {self.client_id}"
        return ModelEvaluator.evaluate_model(self.model, X_test, y_test, model_name, global_classes)

    def update_model(self, global_model):
        """Replace the local model with an independent deep copy of the global model."""
        self.model = copy.deepcopy(global_model)

In [80]:
from timeit import default_timer as timer

# Experiment configuration. Client.train reads this module-level dict directly.
configs = {
    #Federated Features
    "NUM_CLIENTS": 100,# If you set the client as 1 it will be centralized training, you should also set the size of val to 0 as well then
    "clients_per_round": 10,
    "global_num_trees": 100,
    "num_rounds": 20,
    "merge_method": "weight_global",# random, weight_global, weight_voting, impurity, diversity
    "prune_similar": False, # Optional pre-pass before merging; not a merge method itself
    "client_pruning": True,
    #Training parameters
    "num_trees": 200,
    "max_depth": 4,
    "num_max_features": 20,
    "n_jobs": -1,
    #Data features partition
    "test_size": 0.2, 
    "val_size": 0.1, # Should set this to 0 if you are doing centralised training and set test size to 0.3
    "dirichlet_alpha": 0.8,
    #Model parameters
    "criterion": "gini", # or "entropy"
}

# Seed the global RNGs once so tree sampling and random merging repeat across runs.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)

# Load federated data
start = timer()
client_train_dataset, global_test, global_val, label_encoders = load_federated_data(
    num_clients=configs["NUM_CLIENTS"],
    alpha=configs["dirichlet_alpha"],
    test_size=configs["test_size"],
    val_size=configs["val_size"]
)
end = timer()
print("Time Elapsed running load data:", end - start)

# Server.train_federated looks this tracker up by its module-level name.
metrics_tracker = MetricsTracker()

# Initialize server
server = Server(
    num_clients=configs["NUM_CLIENTS"],
    clients_per_round=configs["clients_per_round"],
    merge_method=configs["merge_method"],
    num_global_trees=configs["global_num_trees"],
    global_test_data=global_test,
    global_val_data=global_val,
    prune_similar=configs["prune_similar"]
)


# Register clients
for client_id, client_train in client_train_dataset.items():
    client = Client(
        client_id,
        client_train,       # local training data
        global_test,        # global test set
        {"n_estimators": 0,
         "warm_start": True,
         "bootstrap": True,
         "criterion": configs["criterion"],}
    )
    server.register_client(client)

# Run federated training
server.train_federated(configs)


Original data shape: (257672, 45)
Dropped 'attack_cat'. New shape: (257672, 44)
Column 'proto' encoded with classes: ['3pc', 'a/n', 'aes-sp3-d', 'any', 'argus', 'aris', 'arp', 'ax.25', 'bbn-rcc', 'bna', 'br-sat-mon', 'cbt', 'cftp', 'chaos', 'compaq-peer', 'cphb', 'cpnx', 'crtp', 'crudp', 'dcn', 'ddp', 'ddx', 'dgp', 'egp', 'eigrp', 'emcon', 'encap', 'etherip', 'fc', 'fire', 'ggp', 'gmtp', 'gre', 'hmp', 'i-nlsp', 'iatp', 'ib', 'icmp', 'idpr', 'idpr-cmtp', 'idrp', 'ifmp', 'igmp', 'igp', 'il', 'ip', 'ipcomp', 'ipcv', 'ipip', 'iplt', 'ipnip', 'ippc', 'ipv6', 'ipv6-frag', 'ipv6-no', 'ipv6-opts', 'ipv6-route', 'ipx-n-ip', 'irtp', 'isis', 'iso-ip', 'iso-tp4', 'kryptolan', 'l2tp', 'larp', 'leaf-1', 'leaf-2', 'merit-inp', 'mfe-nsp', 'mhrp', 'micp', 'mobile', 'mtp', 'mux', 'narp', 'netblt', 'nsfnet-igp', 'nvp', 'ospf', 'pgm', 'pim', 'pipe', 'pnni', 'pri-enc', 'prm', 'ptp', 'pup', 'pvp', 'qnx', 'rdp', 'rsvp', 'rtp', 'rvd', 'sat-expak', 'sat-mon', 'sccopmce', 'scps', 'sctp', 'sdrp', 'secure-vmtp', 

KeyboardInterrupt: 

In [None]:
#Client specific models so merge for client model and global model at client side.
#use the same data set as victor
#pruning on client side? maybe if using large data set.
#global model and client model might not be the same.
#define client models they will not use the server side model
#lit review about fed learning 3-4 pages; random forest / decision trees (2 pgs)
# anomaly detection related work 3-4 pages
#intro(2) and problem statement (1) goals (1)
#describe how you design the system (10), merging as well
#results analysis tables charts (5pgs)
#conclusion future work (2pgs)
#thursday 11.30 every week. with victor


#I have things that I am not doing. Like privacy or communication overhead. I dont touch on those topics in my implementation.
#Should I mention them in my paper especially in the background and say that I am not doing them?
#Should this paper be more about an implementation of federated learning rather than the actual federated learning itself? Where I look at experimentation and results?
#A challenge of IoT and FL is that data is not labelled in the real world. What should I write about this? 

# Meeting Agenda

## Overview

- **Explain what I have so far**

## Data Preprocessing

- No normalisation, tree models don’t benefit from this.
- Drop `attack_cat` as this will be binary exclusively.
- Proceed to label string data.
- Split the data into train/test/validation sets.
- Train data goes through Dirichlet partition.

## Server Logic

### Clients Setup
- Define number of clients.
- Define number that train each round.
- Server selects clients randomly.

### Training
- Server trains each client.
- Models are just replaced with the global model at client side; there is no averaging or any additional logic.

### Merging Methods
- I have separate logic for merging with **5 different methods**:
  - Random
  - Weight_global
  - Weight_voting
  - Impurity
  - Diversity
- Also have the option to prune similar trees.

## Data Collection and Evaluation

### Data Collection Class
- Concerned with accuracies and plotting them mainly.

### Evaluator Class
- Needed because some trees don’t have all the classes so I need to do patching.
- This allows sharing code between client and global models.

## Tracking Metrics

### Per Round
- Round number
- Accuracy
- Recall
- ROC AUC
- Precision
- F1 scores
- The clients that participate that round (not sure if this is necessary)

### Global Settings (Not Round Specific)
- Number of trees that server gives to clients
- Number of trees the clients train up to
- Merge method being used
- Number of clients participating
- Criterion 
- (whether using pruning or not on server)

### Additional Tracking
- Client distributions for each experiment to show non-IIDness.

## Experiment Considerations

### Parameters to Vary
- Tree parameters (start with fewer/more trees)
- Merge method being used
- Criterion 
- (pruning on the server or not)
- Max depth and Max features (unsure if these are useful in federated learning)

### Request for Input
- Would like suggestions on what exactly to focus on during experiments.
- Should client selection and partitioning be deterministic for experiments? I feel like they should be.
