<a href="https://colab.research.google.com/github/NartoTeroKK/hybrid-federated-learning/blob/main/Hybrid__Federated_Learning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Configuration and imports



## Environment installation and configuration

In [1]:
# Path and version numbers for python depend on operating system
#!which python
#!python --version
%ls -l
# environment variable
#%env PYTHONPATH=
# install virtual environment package
#!pip install virtualenv
# create virtual environment
#!virtualenv myenv

total 4
drwxr-xr-x 1 root root 4096 Nov 21 14:24 sample_data/


In [None]:
!pip install -U flwr flwr["simulation"] torch torchvision

Collecting flwr
  Downloading flwr-1.5.0-py3-none-any.whl (200 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 200.4/200.4 kB 4.3 MB/s eta 0:00:00
Collecting torch
  Downloading torch-2.1.1-cp310-cp310-manylinux1_x86_64.whl (670.2 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 670.2/670.2 MB 1.1 MB/s eta 0:00:00

## YAML Constants



In [None]:
import yaml

# Experiment configuration, kept as YAML text so it can also be written
# verbatim into each results file.
yaml_config = """
num_rounds: 3
num_partitions: 100
batch_size: 4
num_classes: 10
config_fit:
  lr: 0.0004
  momentum: 0.9
  local_epochs: 1
num_cpus: 2
num_gpus: 0
"""

# safe_load only constructs plain Python objects (dicts, lists, scalars).
config = yaml.safe_load(yaml_config)

# Required top-level settings (KeyError if any is missing).
num_rounds, num_partitions, batch_size, num_classes = (
    config[key] for key in ("num_rounds", "num_partitions", "batch_size", "num_classes")
)

# Client-side optimiser settings; the whole sub-dict is handed to the strategy.
config_fit = config["config_fit"]
lr, momentum, local_epochs = (config_fit[key] for key in ("lr", "momentum", "local_epochs"))

# Per-client simulation resources; the defaults apply only when a key is absent.
num_cpus = config.get("num_cpus", 4)
num_gpus = config.get("num_gpus", 0.5)


# SIMULATION

## Import libs



In [None]:
# General import
from pathlib import Path
from PIL import Image
import os
from collections import OrderedDict
from typing import Dict, List, Optional, Tuple, Union
import timeit
# Flower
import flwr as fl
from flwr.common import (
    EvaluateRes,
    FitRes,
    Scalar
)
from flwr.server.client_proxy import ClientProxy
from flwr.common import Config, NDArrays, Scalar
# PyTorch
import torch
from torch.utils.data import random_split, DataLoader, ConcatDataset, Subset
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
# TorchVision
from torchvision.datasets import MNIST, utils
from torchvision.transforms import ToTensor, Compose, Normalize

from logging import DEBUG, INFO
from flwr.common.logger import log
import timeit
from flwr.server.history import History
import concurrent.futures
from flwr.common import (
    Code,
    EvaluateIns,
    EvaluateRes,
    Scalar,
    Status,
    Parameters
)
from flwr.common.logger import log

# (client proxy, result) pairs as produced by one fit / evaluate round, and
# the (results, failures) tuples built from them. FitResultsAndFailures is
# referenced only by the disabled MyServer code in this notebook.
_FitPair = Tuple[ClientProxy, FitRes]
_EvaluatePair = Tuple[ClientProxy, EvaluateRes]

FitResultsAndFailures = Tuple[
    List[_FitPair],
    List[Union[_FitPair, BaseException]],
]

EvaluateResultsAndFailures = Tuple[
    List[_EvaluatePair],
    List[Union[_EvaluatePair, BaseException]],
]
import numpy as np

## Dataset configuration

### FEMNIST class

In [None]:
from torchvision.datasets import MNIST, utils, VisionDataset
from torchvision.datasets.utils import check_integrity

class FEMNIST(MNIST):
    """FEMNIST: Extended MNIST grouped by writer.

    This dataset is derived from the Leaf repository
    (https://github.com/TalwalkarLab/leaf) pre-processing of the Extended MNIST
    dataset, grouping examples by writer. Details about Leaf were published in
    "LEAF: A Benchmark for Federated Settings" https://arxiv.org/abs/1812.01097.

    The processed ``.pt`` file of the selected split is loaded with
    ``torch.load`` and unpacked into ``(data, targets, users_index)``. In this
    notebook ``users_index[i]`` is used as the number of consecutive samples
    belonging to user ``i``.
    """
    resources = [
        ('https://raw.githubusercontent.com/tao-shen/FEMNIST_pytorch/master/femnist.tar.gz',
         '59c65cec646fc57fe92d27d83afdf0ed')]

    def __init__(self, root, train=True, transform=None, target_transform=None,
                 download=False, num_partitions=100):
        # super(MNIST, self) deliberately skips MNIST.__init__ (and its checks
        # for the original MNIST files) and initialises VisionDataset directly.
        super(MNIST, self).__init__(root, transform=transform,
                                    target_transform=target_transform)
        self.train = train

        # NOTE: _check_legacy_exist is torchvision's MNIST helper that looks for
        # the processed training/test files; if they are missing, fetch them.
        if not self._check_legacy_exist():
            if download:
                self.download()
            if not self._check_exists():
                raise RuntimeError('Dataset not found.' +
                                   ' You can use download=True to download it')

        self.data, self.targets, self.users_index = self.load_data(num_partitions)

    def select_users_data(self, num_partitions):
        """Return the first ``num_partitions`` entries of data/targets/users_index.

        NOTE(review): not called in this notebook. ``self.data`` is indexed per
        sample (see ``__getitem__``), so this slices samples, not users.
        """
        data = self.data[:num_partitions]
        targets = self.targets[:num_partitions]
        users_index = self.users_index[:num_partitions]

        return data, targets, users_index

    def load_data(self, num_partitions):
        """Load the processed file for this split (training or test).

        ``num_partitions`` is accepted for interface compatibility but is not
        used: the whole file is returned.
        """
        data_file = self.training_file if self.train else self.test_file
        return torch.load(os.path.join(self.processed_folder, data_file))

    def __getitem__(self, index):
        """Return ``(image, target)`` for sample ``index`` with transforms applied."""
        img, target = self.data[index], int(self.targets[index])
        # PIL mode 'F' = 32-bit float pixels; presumably self.data holds
        # float32 images — TODO confirm against the processed files.
        img = Image.fromarray(img.numpy(), mode='F')
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target

    def download(self):
        """Download and extract the FEMNIST archive unless the processed files exist.

        The archive is stored in ``raw_folder`` and extracted into
        ``processed_folder`` (the directory ``load_data`` reads from).
        """
        # Skip only when the processed files are present. Checking the raw
        # archive (as before) skipped extraction whenever the archive existed
        # but the processed files did not, and load_data then failed.
        if self._check_legacy_exist():
            return

        os.makedirs(self.raw_folder, exist_ok=True)
        os.makedirs(self.processed_folder, exist_ok=True)

        print('Downloading and extracting ...')
        for url, md5 in self.resources:
            filename = url.rpartition('/')[2]
            utils.download_and_extract_archive(
                url,
                download_root=self.raw_folder,
                extract_root=self.processed_folder,
                filename=filename,
                md5=md5,
            )

    def _check_exists(self) -> bool:
        """Return True when every archive in ``resources`` is in ``raw_folder`` with a valid MD5."""
        return all(
            check_integrity(os.path.join(self.raw_folder, os.path.basename(url)))
            for url, _ in self.resources
        )



### Prepare dataset and loaders

In [None]:
def get_femnist(data_path: str = './data'):
    """Build the FEMNIST train and test datasets rooted at ``data_path``.

    Both splits share one ToTensor + Normalize pipeline and are downloaded on
    first use. The module-level ``num_partitions`` is forwarded to the
    dataset constructor.

    Returns:
        ``(trainset, testset)``.
    """
    pipeline = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))])
    shared_kwargs = dict(root=data_path, download=True, transform=pipeline,
                         num_partitions=num_partitions)

    trainset = FEMNIST(train=True, **shared_kwargs)
    testset = FEMNIST(train=False, **shared_kwargs)

    # Report the sample count of each split.
    print(len(trainset.data), len(testset.data))
    return trainset, testset

def split_user_datatset(dataset, num_partitions):
    """Cut ``dataset`` into one contiguous ``Subset`` per user.

    ``dataset.users_index[u]`` is read as the number of consecutive samples
    owned by user ``u``; only the first ``num_partitions`` users are used.

    Returns:
        list of ``Subset`` objects in user order.
    """
    # Running boundaries: user u owns indices [bounds[u], bounds[u + 1]).
    bounds = [0]
    for user in range(num_partitions):
        bounds.append(bounds[-1] + dataset.users_index[user])

    return [Subset(dataset, range(lo, hi)) for lo, hi in zip(bounds, bounds[1:])]


def merge_datasets(trainsets, testsets, num_merged=None):
    """Merge a random selection of per-user partitions into one partition.

    ``num_merged`` partition indices are sampled without replacement. The
    selected train partitions (and the test partitions at the same indices,
    so the train/test pairing is preserved) are removed and concatenated into
    one ``ConcatDataset`` appended at the end of each list.

    Args:
        trainsets: per-user training datasets.
        testsets: per-user test datasets; must have the same length.
        num_merged: number of partitions to merge. Defaults to the
            module-level ``X`` (the previous behaviour).

    Returns:
        ``(trainsets, testsets)`` as new lists; the inputs are not modified.
        With ``num_merged == 0`` the lists are returned unchanged (as copies).

    Raises:
        ValueError: if the lists differ in length (a subclass of the bare
            ``Exception`` raised before), or if ``num_merged`` is out of range
            (raised by ``random.sample``).
    """
    import random

    if len(trainsets) != len(testsets):
        raise ValueError("trainsets and testsets must have the same length")
    if num_merged is None:
        num_merged = X

    rand_indexes = sorted(random.sample(range(len(trainsets)), num_merged))
    print(rand_indexes)

    if not rand_indexes:
        # Nothing to merge (ConcatDataset([]) would raise).
        return list(trainsets), list(testsets)

    # Build new lists instead of overwriting the caller's entries with None.
    selected = set(rand_indexes)
    kept_train = [ds for i, ds in enumerate(trainsets) if i not in selected]
    kept_test = [ds for i, ds in enumerate(testsets) if i not in selected]

    kept_train.append(ConcatDataset([trainsets[i] for i in rand_indexes]))
    kept_test.append(ConcatDataset([testsets[i] for i in rand_indexes]))

    return kept_train, kept_test

def prepare_dataset(num_partitions: int, batch_size: int, val_ratio: float = 0.1,
                    num_workers: int = 2, seed: int = 2023):
    """Build per-client train/validation/test DataLoaders for FEMNIST.

    The datasets are split into one partition per user. When the module-level
    ``X`` is positive, ``merge_datasets`` first folds randomly chosen
    partitions into one merged partition. Each training partition is then
    split into train/validation parts with a fixed-seed generator.

    Args:
        num_partitions: number of user partitions to build.
        batch_size: batch size of every DataLoader.
        val_ratio: fraction of each training partition held out for validation.
        num_workers: worker processes per DataLoader (previously hard-coded 2).
        seed: seed of the train/validation split (previously hard-coded 2023).

    Returns:
        ``(trainloaders, valloaders, testloaders)``, one entry per client.
    """
    full_trainset, full_testset = get_femnist()

    trainsets = split_user_datatset(full_trainset, num_partitions)
    testsets = split_user_datatset(full_testset, num_partitions)

    if X > 0:
        trainsets, testsets = merge_datasets(trainsets, testsets)

    trainloaders = []
    valloaders = []
    for client_trainset in trainsets:
        num_total = len(client_trainset)
        num_val = int(num_total * val_ratio)
        num_train = num_total - num_val

        # A freshly seeded generator per partition keeps every split reproducible.
        for_train, for_val = random_split(
            client_trainset, [num_train, num_val],
            generator=torch.Generator().manual_seed(seed))

        trainloaders.append(DataLoader(for_train, batch_size=batch_size,
                                       shuffle=True, num_workers=num_workers))
        valloaders.append(DataLoader(for_val, batch_size=batch_size,
                                     shuffle=False, num_workers=num_workers))

    # Separate loop variable name so the loaded test set is not shadowed.
    testloaders = [
        DataLoader(client_testset, batch_size=batch_size, shuffle=False,
                   num_workers=num_workers)
        for client_testset in testsets
    ]

    return trainloaders, valloaders, testloaders

## Network model, training and evaluation

In [None]:
class Net(nn.Module):
    """Small CNN classifier returning raw logits.

    conv(1->32, 3x3) -> ReLU -> conv(32->64, 3x3) -> ReLU -> 2x2 max-pool
    -> dropout(0.25) -> flatten -> linear(9216->128) -> ReLU -> dropout(0.5)
    -> linear(128->num_classes).

    The 9216-unit flatten size equals 64 x 12 x 12, which is what two
    unpadded 3x3 convolutions plus a 2x2 pool produce from a 1x28x28 input.
    No softmax is applied; the training code uses CrossEntropyLoss on logits.
    """

    def __init__(self, num_classes: int) -> None:
        super().__init__()
        # Registration order is kept stable: it fixes the state_dict key order
        # that parameter exchange relies on.
        self.conv2d_1 = nn.Conv2d(1, 32, kernel_size=3)
        self.max_pooling = nn.MaxPool2d(2, stride=2)
        self.conv2d_2 = nn.Conv2d(32, 64, kernel_size=3)
        self.dropout_1 = nn.Dropout(0.25)
        self.flatten = nn.Flatten()
        self.linear_1 = nn.Linear(9216, 128)
        self.dropout_2 = nn.Dropout(0.5)
        self.linear_2 = nn.Linear(128, num_classes)
        self.relu = nn.ReLU()

    def forward(self, x):
        # Convolutional feature extractor.
        features = self.relu(self.conv2d_1(x))
        features = self.relu(self.conv2d_2(features))
        features = self.dropout_1(self.max_pooling(features))
        # Fully connected classifier head.
        hidden = self.relu(self.linear_1(self.flatten(features)))
        return self.linear_2(self.dropout_2(hidden))

def train(net, trainloader, optimizer, epochs, device: str):
    """Train ``net`` in place with cross-entropy loss.

    Args:
        net: model to train; it is switched to training mode and moved to ``device``.
        trainloader: iterable of ``(images, labels)`` batches.
        optimizer: optimizer already bound to ``net``'s parameters.
        epochs: number of passes over ``trainloader``.
        device: device the model and each batch are moved to.

    Returns:
        ``(loss, accuracy)`` of the last epoch: the mean per-sample training
        loss and the fraction of correctly classified samples. Both are
        ``0.0`` when ``epochs`` is 0 or the loader yields no samples. (Callers
        that ignored the former ``None`` result are unaffected.)
    """
    criterion = torch.nn.CrossEntropyLoss()
    net.train()
    net.to(device)

    epoch_loss, epoch_acc = 0.0, 0.0
    for _ in range(epochs):
        total, correct, loss_sum = 0, 0, 0.0
        for images, labels in trainloader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = net(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            batch_len = labels.size(0)
            # criterion returns the batch *mean*; weight it by the batch size
            # so the epoch figure is a true per-sample mean.
            loss_sum += loss.item() * batch_len
            total += batch_len
            correct += (outputs.argmax(dim=1) == labels).sum().item()

        # An empty loader previously raised ZeroDivisionError here.
        epoch_loss = loss_sum / total if total else 0.0
        epoch_acc = correct / total if total else 0.0

    return epoch_loss, epoch_acc

from sklearn.metrics import f1_score

def test(net, testloader, device: str):
    """Evaluate ``net`` on every batch of ``testloader``.

    Returns:
        ``(loss, accuracy, f_score)`` where
        - ``loss`` is the SUM of the per-batch mean cross-entropy losses (not
          normalised by the number of batches or samples);
        - ``accuracy`` is the fraction of correctly predicted samples;
        - ``f_score`` is ``sklearn.metrics.f1_score(..., average='weighted')``.
        For a loader that yields no samples, ``accuracy`` and ``f_score`` are 0.0.
    """
    criterion = torch.nn.CrossEntropyLoss()
    correct, total, loss = 0, 0, 0.0
    net.eval()
    net.to(device)
    y_true, y_pred = [], []
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(device), labels.to(device)

            outputs = net(images)
            loss += criterion(outputs, labels).item()
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            y_true.extend(labels.cpu().numpy())
            y_pred.extend(predicted.cpu().numpy())

    if total == 0:
        # Nothing was scored: avoid dividing by zero.
        return loss, 0.0, 0.0

    accuracy = correct / total
    f_score = f1_score(y_true, y_pred, average='weighted')
    return loss, accuracy, f_score

## Helper functions

In [None]:
def get_client_parameters(cid):
    """Return the stored parameters of client ``cid`` as a list of NumPy arrays.

    Reads ``parameters.npy`` from the working directory, takes the entry at
    integer index ``int(cid)`` and converts it with ``parameters_to_ndarrays``.
    """
    # NOTE(review): allow_pickle=True unpickles the file's contents; only load
    # a parameters.npy that this notebook wrote itself.
    stored = np.load("parameters.npy", allow_pickle=True)
    return parameters_to_ndarrays(stored[int(cid)])


## Client

In [None]:
class FlowerClient(fl.client.NumPyClient):
    """Flower NumPy client owning one data partition and its own ``Net``.

    ``fit`` trains on the train loader. ``evaluate`` scores the validation
    loader, except in the final round, where the test loader is used.
    """

    def __init__(self,
                 cid: str,
                 trainloader,
                 valloader,
                 testloader,
                 num_classes) -> None:
        super().__init__()

        self.cid = cid
        self.trainloader = trainloader
        self.valloader = valloader
        self.testloader = testloader

        self.model = Net(num_classes)

        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    def set_parameters(self, parameters):
        """Load a list of NumPy arrays into the model, in ``state_dict`` key order."""
        params_dict = zip(self.model.state_dict().keys(), parameters)
        state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
        # strict=True: every key of the model must receive a value.
        self.model.load_state_dict(state_dict, strict=True)

    def get_parameters(self, config: Dict[str, Scalar]):
        """Return the model's ``state_dict`` values as NumPy arrays (``config`` is unused)."""
        return [val.cpu().numpy() for _, val in self.model.state_dict().items()]

    def fit(self, parameters, config):
        """Train locally, starting from the server-provided ``parameters``.

        ``config`` must provide ``lr``, ``momentum`` and ``local_epochs``.

        Returns:
            ``(updated parameters, number of training examples, {})``.
        """
        self.set_parameters(parameters)

        optimizer = torch.optim.SGD(self.model.parameters(),
                                    lr=config['lr'], momentum=config['momentum'])
        train(self.model, self.trainloader, optimizer, config['local_epochs'], self.device)

        # Flower weights aggregation by num_examples: report samples, not batches
        # (len(self.trainloader) is the batch count).
        return self.get_parameters({}), len(self.trainloader.dataset), {}

    def evaluate(self, parameters: NDArrays, config: Dict[str, Scalar]):
        """Evaluate the server-provided ``parameters`` on local data.

        The test loader is used when ``config['server_round']`` equals
        ``config['num_rounds']`` (final round); otherwise the validation loader.

        Returns:
            ``(loss, number of evaluated examples, {'accuracy', 'f-score'})``.
        """
        self.set_parameters(parameters)

        is_final_round = config['server_round'] == config['num_rounds']
        loader = self.testloader if is_final_round else self.valloader

        loss, accuracy, f_score = test(self.model, loader, self.device)
        # Report samples, not batches, so example-weighted averages are correct.
        return float(loss), len(loader.dataset), {'accuracy': accuracy, 'f-score': f_score}


def generate_client_fn(trainloaders, valloaders, testloaders, num_classes):
    """Return the ``client_fn`` Flower uses to build a client from its ``cid``.

    ``cid`` is converted to ``int`` and used to index all three loader lists,
    so client ``i`` receives the ``i``-th train/validation/test loader.
    """
    def client_fn(cid: str):
        idx = int(cid)
        return FlowerClient(
            cid=cid,
            trainloader=trainloaders[idx],
            valloader=valloaders[idx],
            testloader=testloaders[idx],
            num_classes=num_classes,
        )

    return client_fn


## Strategy

In [None]:
def get_on_fit_config_fn(config: dict):
    """Build the per-round ``fit`` configuration function for the strategy.

    The returned callable maps ``server_round`` to a dict holding ``lr``,
    ``momentum`` and ``local_epochs`` (read from ``config`` at call time) plus
    the current ``server_round``.
    """
    forwarded_keys = ('lr', 'momentum', 'local_epochs')

    def fit_config_fn(server_round: int):
        round_config = {key: config[key] for key in forwarded_keys}
        round_config['server_round'] = server_round
        return round_config

    return fit_config_fn

def get_on_evaluate_config_fn(config):
    """Build the per-round ``evaluate`` configuration function for the strategy.

    The returned callable sends ``num_rounds`` (from ``config``) and the
    current ``server_round`` to clients, and prints a banner in the final round.
    """
    def evaluate_config_fn(server_round: int):
        total_rounds = config['num_rounds']
        if server_round == total_rounds:
            print('______ FINAL EVALUATE _____')
        return {'num_rounds': total_rounds, 'server_round': server_round}

    return evaluate_config_fn


def get_evaluate_fn(num_classes: int, testloader):
    """Build a centralised (server-side) evaluation function.

    The returned ``evaluate_fn(server_round, parameters, config)`` loads
    ``parameters`` (a list of arrays in ``Net.state_dict`` key order) into a
    fresh ``Net`` and evaluates it on ``testloader``.

    Returns:
        ``evaluate_fn``, which returns ``(loss, {'accuracy': ..., 'f-score': ...})``.
    """
    def evaluate_fn(server_round: int, parameters, config):
        model = Net(num_classes)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        params_dict = zip(model.state_dict().keys(), parameters)
        state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
        model.load_state_dict(state_dict, strict=True)

        # test() returns three values; unpacking only two raised ValueError.
        loss, accuracy, f_score = test(model, testloader, device)

        return float(loss), {'accuracy': accuracy, 'f-score': f_score}

    return evaluate_fn


from logging import WARNING
from flwr.common import (
    ndarrays_to_parameters,
    Parameters,
    parameters_to_ndarrays)
from flwr.server.strategy.aggregate import aggregate


class AggregateCustomMetricStrategy(fl.server.strategy.FedAvg):
    """FedAvg that also aggregates the clients' ``accuracy`` and ``f-score``.

    Training aggregation is inherited unchanged from ``FedAvg``; only
    ``aggregate_evaluate`` is overridden.
    """

    def aggregate_evaluate(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, EvaluateRes]],
        failures: List[Union[Tuple[ClientProxy, EvaluateRes], BaseException]],
    ) -> Tuple[Optional[float], Dict[str, Scalar]]:
        """Aggregate evaluation loss, accuracy and F1 across clients.

        The loss comes from ``FedAvg.aggregate_evaluate``. The ``accuracy`` and
        ``f-score`` values each client puts in ``EvaluateRes.metrics`` are
        averaged with every client weighted by its ``num_examples``.

        Returns:
            ``(loss, {"accuracy": ..., "f-score": ...})``, or ``(None, {})``
            when there are no results, when there are failures and
            ``accept_failures`` is False, or when clients report zero examples.
        """
        if not results:
            return None, {}
        # Same rule FedAvg applies: no aggregation when failures are not accepted.
        if not self.accept_failures and failures:
            return None, {}

        examples = [r.num_examples for _, r in results]
        if sum(examples) == 0:
            # A weighted average over zero examples is undefined.
            return None, {}

        aggregated_loss, _ = super().aggregate_evaluate(server_round, results, failures)

        # Weigh each client's metric by the number of examples it evaluated.
        accuracies = [r.metrics["accuracy"] * r.num_examples for _, r in results]
        f_scores = [r.metrics["f-score"] * r.num_examples for _, r in results]

        print("accuracies: ",accuracies,"examples: ",examples)

        aggregated_accuracy = sum(accuracies) / sum(examples)
        aggregated_f_score = sum(f_scores) / sum(examples)
        print(f"Round {server_round} accuracy aggregated from client results: {aggregated_accuracy}")
        print(f"Round {server_round} f-score aggregated from client results: {aggregated_f_score}")

        return aggregated_loss, {"accuracy": aggregated_accuracy, "f-score": aggregated_f_score}

## Server

In [None]:
'''
from flwr.server.server import fit_clients

class MyServer(fl.server.Server):

    # pylint: disable=too-many-locals
    def fit(self, num_rounds: int, timeout: Optional[float]) -> History:
        """Run federated averaging for a number of rounds."""
        history = History()

        # Initialize parameters
        log(INFO, "Initializing global parameters")
        self.parameters = Parameters(
            tensors=[], tensor_type="numpy.ndarray"
        )
        log(INFO, "Evaluating initial parameters")
        res = self.strategy.evaluate(0, parameters=self.parameters)
        if res is not None:
            log(
                INFO,
                "initial parameters (loss, other metrics): %s, %s",
                res[0],
                res[1],
            )
            history.add_loss_centralized(server_round=0, loss=res[0])
            history.add_metrics_centralized(server_round=0, metrics=res[1])

        # Run federated learning for num_rounds
        log(INFO, "FL starting")
        start_time = timeit.default_timer()

        for current_round in range(1, num_rounds + 1):
            # Train model and replace previous global model
            res_fit = self.fit_round(
                server_round=current_round,
                timeout=timeout,
            )
            if res_fit is not None:
                fit_metrics, _ = res_fit  # fit_metrics_aggregated

                history.add_metrics_distributed_fit(
                    server_round=current_round, metrics=fit_metrics
                )

            # Evaluate model on a sample of available clients
            res_fed = self.evaluate_round(server_round=current_round, timeout=timeout)
            print("results federated: ", res_fed)
            if res_fed is not None:
                loss_fed, evaluate_metrics_fed, _ = res_fed
                if loss_fed is not None:
                    history.add_loss_distributed(
                        server_round=current_round, loss=loss_fed
                    )
                    history.add_metrics_distributed(
                        server_round=current_round, metrics=evaluate_metrics_fed
                    )

        # Bookkeeping
        end_time = timeit.default_timer()
        elapsed = end_time - start_time
        log(INFO, "FL finished in %s", elapsed)
        return history

    def fit_round(
        self,
        server_round: int,
        timeout: Optional[float],
    ) -> Optional[
        Tuple[Dict[str, Scalar], FitResultsAndFailures]
    ]:
        """Perform a single round of federated averaging."""
        # Get clients and their respective instructions from strategy
        client_instructions = self.strategy.configure_fit(
            server_round=server_round,
            parameters=self.parameters,
            client_manager=self._client_manager,
        )

        if not client_instructions:
            log(INFO, "fit_round %s: no clients selected, cancel", server_round)
            return None
        log(
            DEBUG,
            "fit_round %s: strategy sampled %s clients (out of %s)",
            server_round,
            len(client_instructions),
            self._client_manager.num_available(),
        )

        # Collect `fit` results from all clients participating in this round
        results, failures = fit_clients(
            client_instructions=client_instructions,
            max_workers=self.max_workers,
            timeout=timeout,
        )
        log(
            DEBUG,
            "fit_round %s received %s results and %s failures",
            server_round,
            len(results),
            len(failures),
        )

        params_array = np.load("parameters.npy", allow_pickle=True)
        for _, res in results:
            cid = res.num_examples
            #print("cid: ", cid, "len parameters: ", len(res.parameters.tensors))
            params_array[cid] = res.parameters
        np.save("parameters.npy", params_array)

        # Aggregate training results
        aggregated_result: Tuple[
            Dict[str, Scalar],
        ] = self.strategy.aggregate_fit(server_round, results, failures)

        metrics_aggregated = aggregated_result
        return metrics_aggregated, (results, failures)
'''

# Main


In [None]:
import datetime

# Alternative experiment sweeps (disabled):
#for X in [80, 60, 40, 20, 0]:
#for num_rounds in [3,5]:
for iteration in range(4):
    # X = number of user partitions merged into a single "centralised" client.
    # It must stay a module-level name: prepare_dataset() and merge_datasets()
    # read it as a global.
    X = 0
    ## num of FL and Centralized clients configured
    # Merging a single partition is equivalent to no merge, and X can never
    # exceed the number of partitions.
    if X <= 1:
        X = 0
    X = min(X, num_partitions)

    # The X merged partitions form one client; every other partition is its own client.
    num_clients = num_partitions - X + 1 if X > 0 else num_partitions

    print("--- ITERATION n. ",iteration + 1," ---\n")

    # Print the configuration
    print("num_rounds:", num_rounds)
    print("num_partitions:", num_partitions)
    print("____________________")
    print("\nX:", X)
    print("num_clients: ", num_clients)
    print("____________________")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Training on {device} using PyTorch {torch.__version__} and Flower {fl.__version__}\n")

    ## 2. Prepare dataset
    trainloaders, validationloaders, testloaders = prepare_dataset(num_partitions, batch_size)
    print("FEMNIST Dataset prepared\n")

    ## 3. Define clients
    client_fn = generate_client_fn(trainloaders, validationloaders, testloaders, num_classes)

    ## 4. Define strategy
    strategy = AggregateCustomMetricStrategy(  # fl.server.strategy.FedAvg
        min_fit_clients=num_clients,
        min_evaluate_clients=num_clients,
        min_available_clients=num_clients,
        on_fit_config_fn=get_on_fit_config_fn(config_fit),
        evaluate_fn=None,   # get_evaluate_fn(num_classes, testloader)
        on_evaluate_config_fn=get_on_evaluate_config_fn(config),
        accept_failures=False,
    )

    ## 5. Start Simulation
    start = timeit.default_timer()
    history = fl.simulation.start_simulation(
        client_fn=client_fn,
        num_clients=num_clients,
        config=fl.server.ServerConfig(num_rounds=num_rounds),
        strategy=strategy,
        client_resources={'num_cpus': num_cpus, 'num_gpus': num_gpus},
        #server=MyServer(client_manager=fl.server.client_manager.SimpleClientManager(), strategy=strategy)
    )
    elapsed = timeit.default_timer() - start
    print("SIMULATION TIME: ",elapsed," s")

    ## 6. Save your results
    output_dir = "/content/outputs/"
    os.makedirs(output_dir, exist_ok=True)
    formatted_date = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

    filename = f"result_{formatted_date}_{device}_X{X}.txt"
    results_path = os.path.join(output_dir, filename)
    results_content = f"Config:{yaml_config}\nX: {X}\nnum_clients: {num_clients}\n\nSimulation time: {elapsed}\nHistory:\n{history}\n"
    with open(results_path, 'w', encoding='utf-8') as file:
        file.write(results_content)
    print("results file created.")

    # Browser download is only available inside Google Colab; elsewhere the
    # file simply stays on disk instead of crashing the run.
    try:
        from google.colab import files
    except ImportError:
        print(f"google.colab not available; results kept at {results_path}")
    else:
        files.download(results_path)
        print("results file downloaded.")
    print("---------------------------------------------------------\n")