# FEDPROX BENCHMARKS

We try to reproduce the results of the paper [Federated Optimization in Heterogeneous Networks](https://openreview.net/pdf?id=SkgwE5Ss3N) on the MNIST dataset. The paper introduces FedProx, a generalization of Federated Averaging (FedAvg) that adds a proximal term to each client's local loss. This term keeps the local models close to the current global model. On MNIST the paper uses a multinomial logistic regression model, which corresponds to [MNIST_LR](../fluke/nets.py:549) here (a simple logistic regression model). The paper reports that FedProx outperforms FedAvg in both convergence speed and accuracy.
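
For reference, in each round $t$ the paper has every selected client $k$ (approximately) minimize its local loss $F_k$ plus a proximal term centred at the current global model $w^t$, with $\mu \ge 0$ controlling the strength of the term:

$$
\min_{w} \; h_k(w; w^t) = F_k(w) + \frac{\mu}{2} \left\lVert w - w^t \right\rVert^2
$$

With $\mu = 0$ this reduces to the plain local loss minimized in FedAvg. The `mu` hyperparameter set below plays the role of $\mu$.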

## Setup of the experiment

### Loading and splitting the dataset

In [None]:
from fluke.data.datasets import Datasets
# Datasets are stored under ../data, which is created on the first run;
# `get` also creates a dedicated sub-folder if the selected dataset is not already present.
dataset = Datasets.get("mnist", path="../data", channel_dim=1)

In [None]:
from fluke.data import DataSplitter
# FedProx targets non-IID data: in the paper each client only holds samples of 2 digits.
# Here we keep an IID split and use a small batch size and a larger number of
# (eligible) clients instead; the paper's 2-digits-per-client setup is approximated
# with client_split=0.03.
splitter = DataSplitter(dataset=dataset,
                        distribution="iid",
                        client_split=0.03,
                        sampling_perc=1)  # sample the whole dataset
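
The cell above uses an IID split, while the paper gives every client samples of only 2 digits. To illustrate what such a label-skewed partition looks like, here is a minimal NumPy sketch. `two_digit_partition` is a hypothetical helper, not part of fluke, and drawing each client's digits uniformly at random is our assumption rather than the paper's exact procedure.

```python
import numpy as np


def two_digit_partition(labels, n_clients, digits_per_client=2, seed=0):
    """Hypothetical helper (not part of fluke): return one index array per client
    such that each client only holds samples of `digits_per_client` digits.
    Assumes every digit has at least as many samples as clients holding it."""
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels)
    classes = np.unique(labels)

    # Each client draws `digits_per_client` distinct digits at random.
    holders = {int(k): [] for k in classes}
    for client in range(n_clients):
        for k in rng.choice(classes, size=digits_per_client, replace=False):
            holders[int(k)].append(client)

    # The samples of each digit are shuffled and split evenly among its holders.
    parts = [[] for _ in range(n_clients)]
    for k, owners in holders.items():
        if not owners:
            continue  # no client drew this digit: its samples stay unassigned
        idx = rng.permutation(np.flatnonzero(labels == k))
        for owner, chunk in zip(owners, np.array_split(idx, len(owners))):
            parts[owner].extend(chunk.tolist())
    return [np.array(p, dtype=int) for p in parts]


if __name__ == "__main__":
    labels = np.repeat(np.arange(10), 100)  # toy labels: 100 samples per digit
    parts = two_digit_partition(labels, n_clients=20)
    # Digits held by the first three clients.
    print([sorted(set(labels[p].tolist())) for p in parts[:3]])
```

Splitting each digit only among the clients that drew it guarantees that no client sees a third digit, which a simpler "sort by label and cut into shards" scheme does not.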

### Setting up the evaluator

In [None]:
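import torch
from fluke.evaluation import ClassificationEval
from fluke import GlobalSettings

# Evaluate every round (first argument = 1) over all the dataset's classes.
evaluator = ClassificationEval(1, n_classes=dataset.num_classes)
GlobalSettings().set_evaluator(evaluator)
# Use the GPU when available, otherwise fall back to the CPU.
GlobalSettings().set_device("cuda" if torch.cuda.is_available() else "cpu")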

### Setting up the hyperparameters and the model

In [None]:
from fluke import DDict
# We set up the hyperparameters according to the paper's description
client_hp = DDict(
    batch_size=10,
    local_epochs=50,
    loss="CrossEntropyLoss",
    mu=1,  # weight of the proximal term; overall best value found by grid search
    optimizer=DDict(lr=0.03),
    scheduler=DDict(gamma=1, step_size=1)  # gamma=1 keeps the learning rate constant
)

alg_hp = DDict(
    client=client_hp,
    server=DDict(weighted=True),  # weight client models by local dataset size when aggregating
    model="MNIST_LR")
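
To make the role of `mu` concrete, here is a minimal NumPy sketch of a FedProx client update for a multinomial logistic regression model trained with mini-batch SGD. The function name, the missing bias term, and the plain-SGD optimizer are simplifying assumptions for illustration; this is not fluke's implementation.

```python
import numpy as np


def softmax(z):
    z = z - z.max(axis=1, keepdims=True)  # subtract the row max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def fedprox_local_update(W_global, X, y, n_classes, mu=1.0, lr=0.03,
                         batch_size=10, local_epochs=1, seed=0):
    """Sketch of a FedProx client update for multinomial logistic regression
    (no bias term): mini-batch SGD on cross-entropy + (mu/2) * ||W - W_global||^2."""
    rng = np.random.default_rng(seed)
    W = W_global.copy()          # the local model starts from the global model
    Y = np.eye(n_classes)[y]     # one-hot targets
    for _ in range(local_epochs):
        order = rng.permutation(len(X))
        for start in range(0, len(X), batch_size):
            b = order[start:start + batch_size]
            P = softmax(X[b] @ W)                   # predicted class probabilities
            grad_ce = X[b].T @ (P - Y[b]) / len(b)  # gradient of the mean cross-entropy
            grad_prox = mu * (W - W_global)         # gradient of the proximal term
            W -= lr * (grad_ce + grad_prox)
    return W
```

With `mu=0` this is plain local SGD, as in FedAvg; larger values pull the local model back toward `W_global`, limiting how far clients drift on heterogeneous data.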

In [None]:
from fluke.algorithms.fedprox import FedProx
algorithm = FedProx(1000, splitter, alg_hp)  # 1000 clients

### Setting up the logger

In [None]:
from fluke.utils.log import Log
logger = Log()
algorithm.set_callbacks(logger)

## Running the experiment

In [None]:
algorithm.run(400, 0.01)  # 400 rounds; 1% of the clients (10 of 1000) take part in each round
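
With 1000 clients and an eligible fraction of 0.01, each of the 400 rounds involves 10 clients, and `server=DDict(weighted=True)` asks the server to weight the aggregation. A minimal NumPy sketch of that server-side bookkeeping follows. It assumes clients are sampled without replacement and that weights are proportional to each client's number of training samples; both are our assumptions about what fluke does, and the helper names are hypothetical.

```python
import numpy as np


def sample_clients(n_clients, eligible_perc, rng):
    """Draw round(eligible_perc * n_clients) distinct clients (at least one)."""
    k = max(1, int(round(eligible_perc * n_clients)))
    return rng.choice(n_clients, size=k, replace=False)


def weighted_average(client_models, n_samples):
    """Average client parameter arrays, weighting each by its number of training samples."""
    coeffs = np.asarray(n_samples, dtype=float)
    coeffs /= coeffs.sum()
    return sum(c * w for c, w in zip(coeffs, client_models))


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    chosen = sample_clients(1000, 0.01, rng)
    print(len(chosen))  # prints 10
```

Between sampling and averaging, each selected client trains locally on its own data with the proximal objective before sending its model back.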

The target global accuracy is about 0.85; our implementation of FedProx reaches a final accuracy of 0.90.