# 1. Import Necessary Modules

In [2]:
from collections import OrderedDict
from typing import List, Tuple, Dict, Optional, Union

import os,os.path

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, TensorDataset, random_split

from collections import OrderedDict

import flwr as fl
from flwr.client import Client, ClientApp, NumPyClient
from flwr.common import Metrics, Context
from flwr.server import ServerApp, ServerConfig, ServerAppComponents, ClientManager
from flwr.server.strategy import Strategy, FedAvg
from flwr.simulation import run_simulation
from flwr_datasets import FederatedDataset
from flwr.common import ndarrays_to_parameters, NDArrays, Scalar, Context
from flwr.common import FitRes, Parameters, parameters_to_ndarrays
from flwr.server.client_proxy import ClientProxy
from flwr.common.logger import set_logger_propagation

from enum import Enum
from pathlib import Path

# 2. Configurable Parameters

In [3]:
# Device type
device = "cuda" if torch.cuda.is_available() else "cpu"
DEVICE = torch.device(device)

# Number of clients and features
NUM_CLIENTS = 5
NUM_FEATURES = 7

# Sample size range
MIN_SAMPLES = 1000
MAX_SAMPLES = 2000

# Number of rouge clients
NUM_ROUGE_CLIENTS = 0

# Distribution configuration
DISTRIBUTIONS = {
    'Square Meters': {'type': 'normal', 'mu_range': (100, 300), 'sigma_range': (10, 30)},
    'Year Built': {'type': 'normal', 'mu_range': (1950, 2020), 'sigma_range': (1, 5)},
    'Neighborhood Quality': {'type': 'categorical', 'categories': ['Low', 'Medium', 'High'], 'prob_range': [(0.5, 0.3, 0.2), (0.3, 0.4, 0.3)]},
    'Distance to Amenities': {'type': 'uniform', 'low_range': (100, 500), 'high_range': (500, 1000)},
    'Number of Rooms': {'type': 'poisson', 'lambda_range': (2, 5)},
    #'Lot Size': {'type': 'normal', 'mu_range': (500, 2000), 'sigma_range': (100, 300)},
    'House Style': {'type': 'categorical', 'categories': ['Single Family', 'Condo', 'Townhouse'], 'prob_range': [(0.6, 0.3, 0.1), (0.4, 0.4, 0.2)]},
    'Local Economic Index': {'type': 'normal', 'mu_range': (50, 150), 'sigma_range': (10, 30)}
}

DISTRIBUTIONS2 = {
    'Square Meters': {'type': 'normal', 'mu_range': (0, 0), 'sigma_range': (1, 1)},
    'Year Built': {'type': 'normal', 'mu_range': (0, 0), 'sigma_range': (1, 1)},
    'Neighborhood Quality': {'type': 'normal', 'mu_range': (0, 0), 'sigma_range': (1, 1)},
    'Distance to Amenities': {'type': 'normal', 'mu_range': (0, 0), 'sigma_range': (1, 1)},
    'Number of Rooms': {'type': 'normal', 'mu_range': (0, 0), 'sigma_range': (1, 1)},
    'House Style': {'type': 'normal', 'mu_range': (0, 0), 'sigma_range': (1, 1)},
    'Local Economic Index': {'type': 'normal', 'mu_range': (0, 0), 'sigma_range': (1, 1)}
}

# 3. Utility / Helper Functions & Classes

In [20]:
def generate_feature(feature_name, dist_params, num_samples):
    """Generates data for a specific feature based on its distribution."""
    dist_type = dist_params['type']
    if dist_type == 'normal':
        mu = np.random.uniform(*dist_params['mu_range'])
        sigma = np.random.uniform(*dist_params['sigma_range'])
        if feature_name == "Year Built":
            return np.round(np.random.normal(mu, sigma, num_samples)).astype(int), {'mu': mu, 'sigma': sigma}
        return np.random.normal(mu, sigma, num_samples), {'mu': mu, 'sigma': sigma}
    elif dist_type == 'uniform':
        low = np.random.uniform(*dist_params['low_range'])
        high = np.random.uniform(*dist_params['high_range'])
        return np.random.uniform(low, high, num_samples), {'low': low, 'high': high}
    elif dist_type == 'categorical':
        prob = dist_params['prob_range'][np.random.randint(0, len(dist_params['prob_range']))]
        return np.random.choice(dist_params['categories'], num_samples, p=prob), {'probabilities': prob}
    elif dist_type == 'poisson':
        lam = np.random.uniform(*dist_params['lambda_range'])
        return np.random.poisson(lam, num_samples), {'lambda': lam}
    else:
        raise ValueError(f"Unknown distribution type: {dist_type}")
    
def encode_categorical_features(client_data):
    """Encodes categorical features as numerical values."""
    encoded_data = {}
    encoding_maps = {}
    for feature_name, values in client_data.items():
        if isinstance(values[0], str):  # Check if the feature is categorical
            unique_values = list(set(values))
            encoding_map = {val: idx for idx, val in enumerate(unique_values)}
            encoded_data[feature_name] = np.array([encoding_map[val] for val in values])
            encoding_maps[feature_name] = encoding_map
        else:
            encoded_data[feature_name] = values
    return encoded_data, encoding_maps

def decode_categorical_features(client_data, encoding_maps):
    """Decodes numerical categorical features back to their string values."""
    decoded_data = {}
    for feature_name, values in client_data.items():
        if feature_name in encoding_maps:
            decoding_map = {v: k for k, v in encoding_maps[feature_name].items()}
            decoded_data[feature_name] = [decoding_map[val] for val in values]
        else:
            decoded_data[feature_name] = values
    return decoded_data

def generate_target_variable(features, weights, noise_std=0):
    """Generates the target variable (house price) as a weighted sum of features with noise."""
    linear_combination = sum(w * features[i] for i, w in enumerate(weights))
    noise = np.random.normal(0, noise_std, len(features[0]))
    return linear_combination + noise

def add_noise_to_features(client_data, noise_level=0.5):
    """Adds noise to numerical features in the client data."""
    noisy_data = {}
    for feature_name, feature_values in client_data.items():
        if feature_name != 'House Price' and np.issubdtype(type(feature_values[0]), np.number):  # Only numerical features
            noise = np.random.normal(0, noise_level * np.std(feature_values), len(feature_values))
            noisy_data[feature_name] = feature_values + noise
        else:
            noisy_data[feature_name] = feature_values
    return noisy_data

def save_to_csv(data, filename):
    """Saves a DataFrame to a CSV file."""
    data.to_csv(filename, index=False)

def generate_house_pricing_dataset(
    num_clients=5,
    num_features=7,
    min_samples=1000,
    max_samples=10000,
    num_rouge_clients=0,
    distributions={},
    is_sum=False,
    path="./",
    seed=42):
    """
    Generates house pricing datasets for multiple clients, of which some are rouge (delivering noisy/faulty data).
    Returns a list of TensorDatasets, one for each client, where each dataset contains features and targets.

    Args:
        num_clients (int): Number of clients.
        num_features (int): Number of features.
        min_samples (int): Minimum samples per client dataset.
        max_samples (int): Maximum samples per client dataset.
        num_rouge_clients (int): Number of rouge clients delivering noisy data.
        distributions (dict): Distributions for feature generation.
        is_sum (bool): Wheter the target variable is a linear combination of the features with all weights = 1
        path (str): Path to save datasets and metadata.
        seed (int): Seed for reproducibility.

    Returns:
        List[TensorDataset]: A list where each element is a TensorDataset corresponding to one client's data.
    """
    # Output directory
    output_dir = Path(path)
    output_dir.mkdir(exist_ok=True)

    # Store TensorDatasets for all clients
    client_datasets = []

    # Metadata of one dataset creation instance, for trackability
    metadata = []

    # Select NUM_ROUGLE_CLIENTS many clients to act rouge
    np.random.seed(seed)
    rouge_clients = set(np.random.choice(range(1, num_clients + 1), num_rouge_clients, replace=False))

    # Generate weights for linear combination of the target variable, seed for reproducibility
    if is_sum:
        weights = np.ones(num_features)
    else:
        np.random.seed(seed)
        weights = np.random.uniform(0.1, 1, num_features)
    np.random.seed(None)

    #Generate features for each client
    for client_id in range(1, num_clients + 1):
        # Random number of samples in range [min_samples, max_samples]
        num_samples = np.random.randint(min_samples, max_samples + 1)

        # Keep track of the data and metadata
        client_data = {}
        client_metadata = {'Client_ID': client_id, 'Is_Rouge': client_id in rouge_clients}
        
        # Generate features
        for feature_name, dist_params in distributions.items():
            feature_data, params = generate_feature(feature_name, dist_params, num_samples)
            client_data[feature_name] = feature_data
            client_metadata[feature_name] = params

        # Add noise for rouge clients
        if client_id in rouge_clients:
            client_data = add_noise_to_features(client_data)

        # Encode categorical features
        encoded_client_data, encoding_maps = encode_categorical_features(client_data)

        # Generate target variable (house price) as a weighted linear combination
        features = [encoded_client_data[feature_name] for feature_name in distributions.keys()]
        encoded_client_data['House Price'] = generate_target_variable(features, weights, noise_std=0)

        # Convert encoded data to PyTorch tensors
        encoded_data = pd.DataFrame(encoded_client_data)
        feature_tensor = torch.tensor(encoded_data.drop(columns=["House Price"]).values, dtype=torch.float32)  # All but last column
        target_tensor = torch.tensor(encoded_data['House Price'].values.reshape(-1, 1), dtype=torch.float32)  # Last column

        # Create TensorDataset for the client
        client_dataset = TensorDataset(feature_tensor, target_tensor)
        client_datasets.append(client_dataset)
    

        # Decode categorical features before saving
        decoded_client_data = decode_categorical_features(encoded_client_data, encoding_maps)

        # Save client dataset to CSV
        client_df = pd.DataFrame(decoded_client_data)
        save_to_csv(client_df, output_dir / f"client_{client_id}.csv")

        # Save metadata
        client_metadata['Encoding Maps'] = encoding_maps
        client_metadata['weights'] = weights
        metadata.append(client_metadata)

        # Save metadata to a CSV file
        metadata_df = pd.DataFrame(metadata)
        save_to_csv(metadata_df, output_dir / "metadata.csv")

    return client_datasets

def create_client_loaders(
    client_datasets,
    train_ratio=0.7,
    val_ratio=0.2,
    batch_size=16,
    seed=42
):
     """
    Creates train, validation, and test DataLoaders for each client's TensorDataset.

    Args:
        client_datasets (list): A list of TensorDatasets, one for each client.
        train_ratio (float): Proportion of samples used for training.
        val_ratio (float): Proportion of samples used for validation.
        batch_size (int): Batch size for the DataLoaders.
        seed (int): Random seed for reproducibility.

    Returns:
        list: A list of tuples, where each tuple contains (train_loader, val_loader, test_loader) for a client.
    """
     
     torch.manual_seed(seed)

     loaders = []

     for dataset in client_datasets:
         total_samples = len(dataset)
         train_len = int(total_samples * train_ratio)
         val_len = int(total_samples * val_ratio)
         test_len = total_samples - train_len - val_len

         # Ensure reproducible splits
         generator = torch.Generator().manual_seed(seed)
         train_ds, val_ds, test_ds = random_split(
             dataset,
             lengths=[train_len, val_len, test_len],
             generator=generator
         )

         # Create DataLoaders
         train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
         val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)
         test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False)

         loaders.append((train_loader, val_loader, test_loader))

     return loaders

def load_data(client_id, data_path, seed=42, batch_size=16):
    df = pd.read_csv(data_path + f'/client_{client_id}.csv')

    # Encode categorical features
    for col in df.select_dtypes(include=['object']).columns:
        df[col] = pd.Categorical(df[col]).codes

    # Separate features and target
    features = df.drop(columns=["House Price"]).values
    target = df["House Price"].values.reshape(-1, 1)

    # Convert to PyTorch tensors
    features_tensor = torch.tensor(features, dtype=torch.float32)
    target_tensor = torch.tensor(target, dtype=torch.float32)

    # Create a full dataset
    full_dataset = TensorDataset(features_tensor, target_tensor)

    # Determine lengths for splits
    total_len = len(full_dataset)
    train_len = int(0.7 * total_len)
    val_len = int(0.2 * total_len)
    test_len = total_len - train_len - val_len

    # Use random_split for reproducible splits
    generator = torch.Generator().manual_seed(seed + client_id)  # Client-specific seed
    train_dataset, val_dataset, test_dataset = random_split(
        full_dataset,
        lengths=[train_len, val_len, test_len],
        generator=generator
    )

    # Create data loaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    # print(f"Client {client_id} stats: mean={df['House Price'].mean()}, std={df['House Price'].std()}")

    return train_loader, val_loader, test_loader

class TaskType(Enum):

    CLASSFICATION = 0
    REGRESSION = 1


def set_parameters(net, parameters: List[np.ndarray]):
    params_dict = zip(net.state_dict().keys(), parameters)
    state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
    net.load_state_dict(state_dict, strict=True)


def get_parameters(net) -> List[np.ndarray]:
    return [val.cpu().numpy() for _, val in net.state_dict().items()]

def train(net, trainloader, epochs: int, verbose=False, device = "cpu", task_type = TaskType.CLASSFICATION):
    """Train the network on the training set."""
    if task_type == TaskType.CLASSFICATION:    
        criterion = torch.nn.CrossEntropyLoss()
    elif task_type == TaskType.REGRESSION:
        criterion = nn.MSELoss(reduction='sum')
    optimizer = torch.optim.Adam(net.parameters())
    net.train()
    net.to(device)
    for epoch in range(epochs):
        correct, total, epoch_loss = 0, 0, 0.0
        for images, labels in trainloader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = net(images)
            if task_type == TaskType.REGRESSION:
                outputs = outputs.squeeze()
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # Metrics
            epoch_loss += loss.item()
            total += labels.size(0)
            if task_type == TaskType.CLASSFICATION:
                correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
        epoch_loss /= len(trainloader.dataset)
        if verbose:
            if task_type == TaskType.CLASSFICATION:
                epoch_acc = correct / total
                print(f"Epoch {epoch+1}: train loss {epoch_loss}, accuracy {epoch_acc}")
            elif task_type == TaskType.REGRESSION:
                print(f"Epoch {epoch+1}: train loss {epoch_loss}")

def test(net, testloader, device = "cpu"):
    """Evaluate the network on the entire test set."""
    criterion = torch.nn.CrossEntropyLoss()
    correct, total, loss = 0, 0, 0.0
    net.eval()
    net.to(device)
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(device), labels.to(device)
            outputs = net(images)
            loss += criterion(outputs, labels).item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    loss /= len(testloader.dataset)
    accuracy = correct / total
    return loss, accuracy

def test_regression(net, testloader, device="cpu"):
    """Evaluate the regression model on the entire test set."""
    criterion = nn.MSELoss(reduction="sum")
    sum_of_squares, total_samples = 0.0, 0
    net.eval()
    net.to(device)
    with torch.no_grad():
        for x, y in testloader:
            x, y = x.to(device), y.to(device)
            
            # To accomodate for single output layer
            target = y.view(-1)
            outputs = net(x).view(-1)

            sum_of_squares += criterion(outputs, target).item()
            total_samples += len(y)

    if total_samples > 0:
        avg_mse = sum_of_squares / total_samples
    else:
        avg_mse = 0.0
    avg_loss = avg_mse

    # Note that to make sure the consistence,
    # we return MSE and RMSE to match {loss, accuracy as the test function}
    return avg_loss, (avg_mse ** 0.5)

# Custom Client for House Pricing Dataset

class HousePricingClient(NumPyClient):
    def __init__(
        self,
        net: nn.Module,
        train_loader: DataLoader,
        val_loader: DataLoader,
        test_loader: DataLoader,
        device: torch.device,
        client_id: int,
        epochs: int = 1.
    ):
        super().__init__()
        self.net = net
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader
        self.device = device
        self.client_id = client_id
        self.epochs = epochs

    def get_parameters(self, config: Dict[str, Scalar]) -> List[np.ndarray]:
        return get_parameters(self.net)

    '''def set_parameters(self, parameters):
        state_dict = dict(zip(self.model.state_dict().keys(), parameters))
        self.model.load_state_dict({k: torch.tensor(v) for k, v in state_dict.items()})'''

    def fit(
        self, parameters: List[np.ndarray], config: Dict[str, Scalar]
    ) -> Tuple[List[np.ndarray], int, Dict[str, Scalar]]:
        set_parameters(self.net, parameters)
        train(self.net, self.train_loader, device=self.device, epochs=self.epochs, verbose=False)
        new_params = get_parameters(self.net)
        # Return partition-id in the metrics
        # The simplest way to store the model
        return new_params, len(self.train_loader.dataset), {"partition-id": self.client_id}
        

    def evaluate(
        self, parameters: List[np.ndarray], config: Dict[str, Scalar]
    ) -> Tuple[float, int, Dict[str, Scalar]]:
        set_parameters(self.net, parameters)
        loss, rmse = test_regression(self.net, self.val_loader, self.device)
        print(f"[Client {self.client_id}] Evaluate -> Loss: {loss:.4f}, RMSE: {rmse:.4f}")
        return float(loss), len(self.val_loader.dataset), {"RMSE": float(rmse)}


class DefaultStrategy(FedAvg):

    # A custom strategy to store all the parameters.
    # https://github.com/adap/flower/issues/487
    # https://flower.ai/docs/framework/how-to-save-and-load-model-checkpoints.html

    def __init__(self, model: type, total_round: int, only_last: bool = True, save_dir: str = "models", *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.save_dir = save_dir
        os.makedirs(self.save_dir, exist_ok=True)
        self.model = model
        self.total_round = total_round
        self.only_last = only_last

    def aggregate_fit(
        self,
        server_round: int,
        results: list[tuple[ClientProxy, FitRes]],
        failures: list[Union[tuple[ClientProxy, FitRes], BaseException]],
    ) -> tuple[Optional[Parameters], dict[str, Scalar]]:
        """
        Aggregate model weights using weighted average.
        Also save each client's model and the global server model.
        """

        if self.only_last and server_round < self.total_round:
            return super().aggregate_fit(server_round, results, failures)

        # Call aggregate_fit from base class (FedAvg) to aggregate parameters and metrics
        aggregated_parameters, aggregated_metrics = super().aggregate_fit(
            server_round, results, failures
        )

        # For each client which returned FitRes, save the client model
        for (_, fit_res) in results:
            id_ = fit_res.metrics["partition-id"]

            client_parameters: Optional[Parameters] = fit_res.parameters
            if client_parameters is not None:
                net = self.model()
                print(f"[Round {server_round}] Saving model for client {id_}...")

                # Convert `Parameters` to `list[np.ndarray]`
                client_ndarrays : list[np.ndarray] = parameters_to_ndarrays(
                    client_parameters
                )

                # Convert `list[np.ndarray]` to PyTorch `state_dict`
                params_dict = zip(net.state_dict().keys(), client_ndarrays)
                state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
                net.load_state_dict(state_dict, strict=True)

                # Save the model to disk
                torch.save(net.state_dict(), f"{self.save_dir}/client-{server_round}-{id_}.pth")

        # If `aggregated_parameters` is not None, update the global net and save it
        if aggregated_parameters is not None:
            net = self.model()
            print(f"Saving round {server_round} aggregated_parameters...")

            # Convert `Parameters` to `list[np.ndarray]`
            aggregated_ndarrays: list[np.ndarray] = parameters_to_ndarrays(
                aggregated_parameters
            )

            # Convert `list[np.ndarray]` to PyTorch `state_dict`
            params_dict = zip(net.state_dict().keys(), aggregated_ndarrays)
            state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
            net.load_state_dict(state_dict, strict=True)

            # Save the model to disk
            torch.save(net.state_dict(), f"{self.save_dir}/server-{server_round}.pth")

        return aggregated_parameters, aggregated_metrics


# 4. Federated Learning Configuration

In [23]:
# Use the following class to run the experiment

# You need to provide the following information:
# 1. The Network class (dont instantiate it)
#       (assume we use the same network for all clients and server)
# 2. The list of data loaders for each client,
#       where loaders is a list of loader tuples (train, val, test)
#       i.e. loaders = [ (train_loader_0, val_loader_0, test_loader_0), ... ]
#       NOTE: In fit and evaluate, we ONLY use the train_loader and val_loader,
#             But we ask you to pyt them together for simplicity for any future test use.
#       NOTE: we assume the number of clients == number of data loaders
# 3. Number of clients

# See next block for an example of how to use this class

class FLExperiment:
    """
    A federated learning experiment interface class.

    NOTE: For each client, we now expect a tuple of three DataLoaders:
    (train_loader, val_loader, test_loader).
    """

    def __init__(
        self,
        model_cls: type,
        client_loaders: List[Tuple[DataLoader, DataLoader, DataLoader]],
        num_clients: int,
        device: torch.device = torch.device("cpu"),
        local_epochs: int = 1,
        num_rounds: int = 5,
        task_type: TaskType = TaskType.REGRESSION,
        # strategy: Optional[Strategy] = None, # Is not supported yet. and may not be needed
    ):
        """
        Args:
            model_cls (type): A PyTorch nn.Module class (not an instance).
                We'll instantiate `model_cls()` for each client and server.
            client_loaders (List[(DataLoader, DataLoader, DataLoader)]):
                A list of (train_loader, val_loader, test_loader) for each client.
            num_clients (int): Number of clients to simulate.
            device (torch.device): CPU or GPU device.
            local_epochs (int): Local epochs on each client per round.
            num_rounds (int): How many global training rounds.
            strategy (Optional[Strategy]): Use a custom Flower strategy or fallback to default FedAvg.
        """
        if len(client_loaders) != num_clients:
            raise ValueError(
                f"Number of client loader tuples ({len(client_loaders)}) does not match "
                f"the number of clients ({num_clients})."
            )

        self.model_cls = model_cls
        self.client_loaders = client_loaders
        self.num_clients = num_clients
        self.local_epochs = local_epochs
        self.num_rounds = num_rounds
        self.device = device
        self.task_type = task_type

        # Store final trained models
        self._client_models: List[Optional[nn.Module]] = [None] * self.num_clients
        self._server_model: Optional[nn.Module] = None

        # Create one model per client (instantiate model_cls)
        self.client_nets = [self.model_cls().to(self.device) for _ in range(self.num_clients)]

        self.strategy = self._create_default_strategy(save_only_last=True)

    def _create_default_strategy(self, save_only_last: bool) -> Strategy:
        """Create a default FedAvg strategy with a minimal server_evaluate."""

        def server_evaluate(
            server_round: int,
            parameters: NDArrays,
            config: Dict[str, Scalar]
        ) -> Optional[Tuple[float, Dict[str, Scalar]]]:
            # Minimal server eval (no real evaluation)
            net = self.model_cls().to(self.device)
            set_parameters(net, parameters)
            print(f"[Server] Round {server_round} - no global evaluation implemented.")
            return None
        
        def weighted_average(metrics: List[Tuple[int, Dict[str, Scalar]]]) -> Dict[str, Scalar]:
            accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics]
            examples = [num_examples for num_examples, _ in metrics]
            if sum(examples) == 0:
                return {"accuracy": 0.0}
            return {"accuracy": sum(accuracies) / sum(examples)}

        def weighted_average_regression(metrics: List[Tuple[int, Dict[str, Scalar]]]) -> Dict[str, Scalar]:
            total_sum_of_squares = 0.0
            total_samples = 0
            for (num_examples, m) in metrics:
                if "RMSE" in m:
                    total_sum_of_squares += m["RMSE"]
                    total_samples += num_examples
            if total_samples == 0:
                return {"rmse": 0.0}
            rmse = (total_sum_of_squares / total_samples)
            return {"rmse": rmse}
        
        if self.task_type == TaskType.CLASSFICATION:
            aggregation_fn = weighted_average
        else:
            aggregation_fn = weighted_average_regression

        default_strategy = DefaultStrategy(
            model = self.model_cls,
            total_round = self.num_rounds,
            only_last = True,
            fraction_fit=1.0,
            fraction_evaluate=1.0,
            min_fit_clients=self.num_clients,
            min_evaluate_clients=self.num_clients,
            min_available_clients=self.num_clients,
            evaluate_fn=server_evaluate,
            evaluate_metrics_aggregation_fn=aggregation_fn,
        )
        return default_strategy

    def _client_fn(self, context: Context) -> Client:
        """Construct one Flower client using the partition_id to pick (train, val, test)."""
        partition_id = context.node_config["partition-id"]
        trainloader, valloader, testloader = self.client_loaders[partition_id]
        net = self.client_nets[partition_id]

        client = HousePricingClient(
            net=net,
            train_loader=trainloader,
            val_loader=valloader,
            test_loader = testloader,
            device=self.device,
            client_id=partition_id,
            epochs=self.local_epochs
        )
        return client.to_client()

    def _server_fn(self, context: Context) -> ServerAppComponents:
        """Server-side: configure strategy and server config."""
        config = ServerConfig(num_rounds=self.num_rounds)
        return ServerAppComponents(strategy=self.strategy, config=config)

    def run(self, save_only_last: bool = True) -> None:
        """Run the federated learning simulation and store final client/server models.
        
        Args:
            save_only_last (bool): Save only the last round of models.
                Default True. If False, all models will be saved.
        """
        print("[FLExperiment] Starting federated training...")
        self.strategy = self._create_default_strategy(save_only_last=save_only_last)
        client_app = ClientApp(client_fn=self._client_fn)
        server_app = ServerApp(server_fn=self._server_fn)

        # Resource allocation
        if self.device.type == "cuda":
            backend_config = {"client_resources": {"num_cpus": 1, "num_gpus": 1.0}}
        else:
            backend_config = {"client_resources": {"num_cpus": 1, "num_gpus": 0.0}}

        # Run the simulation
        run_simulation(
            client_app=client_app,
            server_app=server_app,
            num_supernodes=self.num_clients,
            backend_config=backend_config,
        )
        print("[FLExperiment] Federated training finished.")

    def get_clients(self, round_num: int = 0) -> List[nn.Module]:
        """Return final trained models for all clients (if they have been saved).
        
        Args:
            round_num (int): Round number to fetch models from. Default 0 (last round).
        
        Returns:
            List[nn.Module]: List of final trained models for all clients.
                The index of the list corresponds to the client ID 
                and the index of dataloader. 
        """
        assert round_num <= self.num_rounds, f"Round {round_num} not available, only {self.num_rounds} rounds."
        if round_num <= 0:
            round_num = self.num_rounds
        try:
            return [
                torch.load(f"models/client-{round_num}-{cid}.pth", map_location=self.device, weights_only=True)
                for cid in range(self.num_clients)
            ]
        except FileNotFoundError:
            raise RuntimeError("Client models are not available. Have you called run() or set only_last=True?")
    
    def get_client_dataloader_tuples(self, round_num: int = 0) -> List[Tuple[nn.Module, Tuple[DataLoader, DataLoader, DataLoader]]]:
        """Return the dataloaders for all clients.
         
        Args:
            round_num (int): Round number to fetch models from. Default 0 (last round).
        
        Returns:
            List[Tuple[nn.Module, Tuple[DataLoader, DataLoader, DataLoader]]]:
                List of (client_model, (train_loader, val_loader, test_loader))
        """
        assert round_num <= self.num_rounds, f"Round {round_num} not available, only {self.num_rounds} rounds."
        if round_num <= 0:
            round_num = self.num_rounds
        try:
            clients = self.get_clients(round_num)
            return list(zip(clients, self.client_loaders)) 
        except FileNotFoundError:
            raise RuntimeError("Client dataloaders are not available. Have you called run() or set only_last=True?")

    def get_server(self, round_num: int = 0) -> nn.Module:
        """Return the final server model (if stored).

        Args:
            round_num (int): Round number to fetch models from. Default 0 (last round).
        
        Returns:
            nn.Module: The final server model.
        """
        assert round_num <= self.num_rounds, f"Round {round_num} not available, only {self.num_rounds} rounds."
        if round_num <= 0:
            round_num = self.num_rounds
        try:
            return torch.load(f"models/server-{round_num}.pth", map_location=self.device, weights_only=True)
        except FileNotFoundError:
            raise RuntimeError("Server model is not available. Have you called run() or set only_last=True?")

# 5. Different Types of Neural Networks

In [6]:
# Multiple neural networks to use as models

class Net(nn.Module):
    def __init__(self, input_size = NUM_FEATURES):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(input_size, 64)
        self.fc2 = nn.Linear(64, 32)
        self.fc3 = nn.Linear(32, 1)
        self.relu = nn.ReLU()
        
    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x

class EZNet(nn.Module):
    def __init__(self, in_features=NUM_FEATURES):
        super().__init__()
        self.fc = nn.Linear(in_features, 1)

    def forward(self, x):
        return self.fc(x)
    
class SimpleNet(nn.Module):
    def __init__(self, input_size=NUM_FEATURES):
        super().__init__()
        self.fc1 = nn.Linear(input_size, 32)
        self.fc2 = nn.Linear(32, 1)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# 6. Generate / Load Data & Create Loaders

In [24]:
'''
path = ""
loaders = [
    load_data(client_id, path) for client_id in range(1, NUM_CLIENTS + 1)
]
'''

# Generate datasets. If you want to re-use a dataset, use load_data instead
client_datasets = generate_house_pricing_dataset(
    num_clients=NUM_CLIENTS,
    num_features=NUM_FEATURES,
    min_samples=MIN_SAMPLES,
    max_samples=MAX_SAMPLES,
    num_rouge_clients=NUM_ROUGE_CLIENTS,
    distributions=DISTRIBUTIONS,
    is_sum=True,
    path="./Test_20250111",
    seed=42
)

# Create train, val, test loaders
loaders = create_client_loaders(
    client_datasets, train_ratio=0.7, val_ratio=0.2, batch_size=16, seed=42
)

# 7. Create and Run The Experiment

In [25]:
fl_exp = FLExperiment(
    model_cls=EZNet,
    client_loaders=loaders,
    num_clients=5,
    num_rounds=10,
    local_epochs=50,
    task_type = TaskType.REGRESSION
)


fl_exp.run(False)

[92mINFO [0m:      Starting Flower ServerApp, config: num_rounds=10, no round_timeout
[92mINFO [0m:      
[92mINFO [0m:      [INIT]
[92mINFO [0m:      Requesting initial parameters from one random client


[FLExperiment] Starting federated training...


[92mINFO [0m:      Received initial parameters from one random client
[92mINFO [0m:      Starting evaluation of initial global parameters
[92mINFO [0m:      Evaluation returned no results (`None`)
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 1]
[92mINFO [0m:      configure_fit: strategy sampled 5 clients (out of 5)


[Server] Round 0 - no global evaluation implemented.


[92mINFO [0m:      aggregate_fit: received 5 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 5 clients (out of 5)


[Server] Round 1 - no global evaluation implemented.


[92mINFO [0m:      aggregate_evaluate: received 5 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 2]
[92mINFO [0m:      configure_fit: strategy sampled 5 clients (out of 5)


[36m(ClientAppActor pid=51112)[0m [Client 3] Evaluate -> Loss: 6204028.1972, RMSE: 2490.7887


[92mINFO [0m:      aggregate_fit: received 5 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 5 clients (out of 5)


[Server] Round 2 - no global evaluation implemented.


[92mINFO [0m:      aggregate_evaluate: received 5 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 3]
[92mINFO [0m:      configure_fit: strategy sampled 5 clients (out of 5)


[36m(ClientAppActor pid=51112)[0m [Client 4] Evaluate -> Loss: 6918559.6741, RMSE: 2630.3155[32m [repeated 5x across cluster][0m


[92mINFO [0m:      aggregate_fit: received 5 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 5 clients (out of 5)


[Server] Round 3 - no global evaluation implemented.


[92mINFO [0m:      aggregate_evaluate: received 5 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 4]
[92mINFO [0m:      configure_fit: strategy sampled 5 clients (out of 5)


[36m(ClientAppActor pid=51112)[0m [Client 0] Evaluate -> Loss: 6701863.7990, RMSE: 2588.7958[32m [repeated 5x across cluster][0m


[92mINFO [0m:      aggregate_fit: received 5 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 5 clients (out of 5)


[Server] Round 4 - no global evaluation implemented.


[92mINFO [0m:      aggregate_evaluate: received 5 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 5]
[92mINFO [0m:      configure_fit: strategy sampled 5 clients (out of 5)


[36m(ClientAppActor pid=51112)[0m [Client 4] Evaluate -> Loss: 6918559.7630, RMSE: 2630.3155[32m [repeated 5x across cluster][0m


[92mINFO [0m:      aggregate_fit: received 5 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 5 clients (out of 5)


[Server] Round 5 - no global evaluation implemented.


[92mINFO [0m:      aggregate_evaluate: received 5 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 6]
[92mINFO [0m:      configure_fit: strategy sampled 5 clients (out of 5)


[36m(ClientAppActor pid=51112)[0m [Client 4] Evaluate -> Loss: 6918559.6741, RMSE: 2630.3155[32m [repeated 5x across cluster][0m


[92mINFO [0m:      aggregate_fit: received 5 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 5 clients (out of 5)


[Server] Round 6 - no global evaluation implemented.


[92mINFO [0m:      aggregate_evaluate: received 5 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 7]
[92mINFO [0m:      configure_fit: strategy sampled 5 clients (out of 5)


[36m(ClientAppActor pid=30760)[0m [Client 2] Evaluate -> Loss: 5841260.3000, RMSE: 2416.8699[32m [repeated 5x across cluster][0m


[92mINFO [0m:      aggregate_fit: received 5 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 5 clients (out of 5)


[Server] Round 7 - no global evaluation implemented.


[92mINFO [0m:      aggregate_evaluate: received 5 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 8]
[92mINFO [0m:      configure_fit: strategy sampled 5 clients (out of 5)


[36m(ClientAppActor pid=51112)[0m [Client 4] Evaluate -> Loss: 6918559.6741, RMSE: 2630.3155[32m [repeated 5x across cluster][0m


[92mINFO [0m:      aggregate_fit: received 5 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 5 clients (out of 5)


[Server] Round 8 - no global evaluation implemented.


[92mINFO [0m:      aggregate_evaluate: received 5 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 9]
[92mINFO [0m:      configure_fit: strategy sampled 5 clients (out of 5)


[36m(ClientAppActor pid=51112)[0m [Client 3] Evaluate -> Loss: 6204028.1972, RMSE: 2490.7887[32m [repeated 5x across cluster][0m


[92mINFO [0m:      aggregate_fit: received 5 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 5 clients (out of 5)


[Server] Round 9 - no global evaluation implemented.


[92mINFO [0m:      aggregate_evaluate: received 5 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 10]
[92mINFO [0m:      configure_fit: strategy sampled 5 clients (out of 5)


[36m(ClientAppActor pid=51112)[0m [Client 1] Evaluate -> Loss: 6637752.4143, RMSE: 2576.3836[32m [repeated 5x across cluster][0m


[92mINFO [0m:      aggregate_fit: received 5 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 5 clients (out of 5)


[Round 10] Saving model for client 4...
[Round 10] Saving model for client 3...
[Round 10] Saving model for client 1...
[Round 10] Saving model for client 0...
[Round 10] Saving model for client 2...
Saving round 10 aggregated_parameters...
[Server] Round 10 - no global evaluation implemented.


[92mINFO [0m:      aggregate_evaluate: received 5 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [SUMMARY]
[92mINFO [0m:      Run finished 10 round(s) in 118.35s
[92mINFO [0m:      	History (loss, distributed):
[92mINFO [0m:      		round 1: 6488664.743893929
[92mINFO [0m:      		round 2: 6488664.76064201
[92mINFO [0m:      		round 3: 6488664.76064201
[92mINFO [0m:      		round 4: 6488664.799720866
[92mINFO [0m:      		round 5: 6488664.76064201
[92mINFO [0m:      		round 6: 6488664.76064201
[92mINFO [0m:      		round 7: 6488664.76064201
[92mINFO [0m:      		round 8: 6488664.76064201
[92mINFO [0m:      		round 9: 6488664.76064201
[92mINFO [0m:      		round 10: 6488664.76064201
[92mINFO [0m:      	History (metrics, distributed, evaluate):
[92mINFO [0m:      	{'rmse': [(1, 8.864726807466392),
[92mINFO [0m:      	          (2, 8.864726817340816),
[92mINFO [0m:      	          (3, 8.864726817340815),
[92mINFO [0m:      	          (4, 8.8

[36m(ClientAppActor pid=51112)[0m [Client 2] Evaluate -> Loss: 5841260.3000, RMSE: 2416.8699[32m [repeated 5x across cluster][0m
[36m(ClientAppActor pid=39456)[0m [Client 0] Evaluate -> Loss: 6701863.7990, RMSE: 2588.7958[32m [repeated 4x across cluster][0m
[FLExperiment] Federated training finished.


# 8. Output the Statistics for each Client

In [26]:
clients = fl_exp.get_clients()
for i, client in enumerate(clients):
    model = EZNet().to(DEVICE)
    model.load_state_dict(client)
    model.eval()

    _, _, test_loader = loaders[i]
    loss, rmse = test_regression(model, test_loader)
    print(f'Client_{i} loss: {loss}, RMSE: {rmse}')

Client_0 loss: 6653907.938461538, RMSE: 2579.516997125923
Client_1 loss: 6628783.492063492, RMSE: 2574.6424008128765
Client_2 loss: 5880526.819672131, RMSE: 2424.979756548935
Client_3 loss: 6147990.48951049, RMSE: 2479.51416400683
Client_4 loss: 6909169.352941177, RMSE: 2628.5298843538335


# 9. Test Train Loss

In [12]:
eznet = EZNet()
train(eznet, loaders[0][0], epochs=100, verbose=True, task_type=TaskType.REGRESSION)
print(test_regression(eznet, loaders[0][1]))

  return F.mse_loss(input, target, reduction=self.reduction)
  return F.mse_loss(input, target, reduction=self.reduction)


Epoch 1: train loss 123.72262602168801
Epoch 2: train loss 121.67161010452921
Epoch 3: train loss 121.84293470337492
Epoch 4: train loss 120.80265381437907
Epoch 5: train loss 120.35047023104265
Epoch 6: train loss 119.62367313620038
Epoch 7: train loss 120.13106822515552
Epoch 8: train loss 119.67532413717694
Epoch 9: train loss 118.71319442676707
Epoch 10: train loss 119.29284074973157
Epoch 11: train loss 119.12419612938758
Epoch 12: train loss 119.29444718925873
Epoch 13: train loss 119.38451577136867
Epoch 14: train loss 119.50294125814574
Epoch 15: train loss 119.35580285239558
Epoch 16: train loss 119.51745840497492
Epoch 17: train loss 119.32753041344232
Epoch 18: train loss 119.46318068662526
Epoch 19: train loss 119.24895842493427
Epoch 20: train loss 119.15882866303502
Epoch 21: train loss 119.53557004521808
Epoch 22: train loss 119.21115307559334
Epoch 23: train loss 119.02346404016865
Epoch 24: train loss 119.19642487295431
Epoch 25: train loss 118.7215384533055
Epoch 26: 