# SCAFFOLD Benchmark

We try to reproduce the results of the paper [SCAFFOLD: Stochastic Controlled Averaging for Federated Learning](10.48550/arXiv.1910.06378) on the EMNIST dataset. The paper introduces SCAFFOLD, an algorithm that corrects client drift by estimating it as the difference between the global and the local control variates. The algorithm is tested on the EMNIST dataset with [SCAFFOLD_2FC](../fluke/nets.py:1261) (a 2-layer MLP). The paper assesses SCAFFOLD's performance by comparing convergence speed.
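
To make the drift correction concrete, here is a minimal sketch of one SCAFFOLD round on a toy scalar problem. It is not the `fluke` implementation: the function names, the quadratic client objectives, and the full-participation setting are all assumptions made for illustration. The control variate update follows what the paper calls "Option II".

```python
def scaffold_local_update(x, c_global, c_local, grad_fn, lr, steps):
    """Run `steps` drift-corrected gradient steps starting from the server model x."""
    y = x
    for _ in range(steps):
        # Corrected direction: local gradient, minus local variate, plus global variate.
        y = y - lr * (grad_fn(y) - c_local + c_global)
    # "Option II" update of the client control variate.
    c_local_new = c_local - c_global + (x - y) / (steps * lr)
    return y, c_local_new


def scaffold_round(x, c_global, c_locals, grad_fns, lr, steps, global_lr=1.0):
    """One round with all clients participating (an assumption of this sketch)."""
    results = [
        scaffold_local_update(x, c_global, c_i, g, lr, steps)
        for c_i, g in zip(c_locals, grad_fns)
    ]
    ys = [y for y, _ in results]
    new_c_locals = [c_new for _, c_new in results]
    n = len(grad_fns)
    # Server model moves by the average client displacement.
    x_new = x + global_lr * sum(y - x for y in ys) / n
    # Global variate moves by the average change of the client variates.
    c_global_new = c_global + sum(
        c_new - c_old for c_new, c_old in zip(new_c_locals, c_locals)
    ) / n
    return x_new, c_global_new, new_c_locals


# Toy clients: f_i(y) = 0.5 * (y - t_i)^2, so the gradient is y - t_i.
targets = [-1.0, 3.0]
grad_fns = [lambda y, t=t: y - t for t in targets]

x, c, c_locals = 0.0, 0.0, [0.0, 0.0]
for _ in range(50):
    x, c, c_locals = scaffold_round(x, c, c_locals, grad_fns, lr=0.1, steps=5)
print(x, c, c_locals)
```

In this toy setting the server model approaches the average of the two client minima, and each client variate approaches that client's gradient at the shared solution, which is the quantity the correction term is meant to track.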

## Setup

### Loading and Splitting the dataset

In [None]:
from fluke.data.datasets import Datasets

dataset = Datasets.get("emnist", path="../data")

In [None]:
from fluke.data import DataSplitter
from fluke.data import DDict

data = DDict(
    dataset=dataset,
    distribution="iid",
    sampling_perc=1,
    client_split=0.2,
    keep_test=True,
    server_test=True,
    server_split=0.0,
    uniform_test=True
)

splitter = DataSplitter(**data)
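
As a rough illustration of what an IID split with a per-client test fraction means, the sketch below shuffles example indices, deals them evenly to clients, and holds out part of each shard for local testing. It is not how `DataSplitter` is implemented. The helper `iid_split` and its parameters are hypothetical, and it assumes that `client_split=0.2` denotes the fraction of each client's data reserved for its local test set.

```python
import random


def iid_split(n_examples, n_clients, client_test_frac, seed=0):
    """Shuffle indices and deal them evenly to clients (hypothetical helper)."""
    rng = random.Random(seed)
    indices = list(range(n_examples))
    rng.shuffle(indices)
    # Client i receives every n_clients-th index of the shuffled list.
    shards = [indices[i::n_clients] for i in range(n_clients)]
    splits = []
    for shard in shards:
        # Hold out a fraction of each shard as the client's local test set.
        n_test = round(len(shard) * client_test_frac)
        splits.append({"test": shard[:n_test], "train": shard[n_test:]})
    return splits


splits = iid_split(n_examples=1000, n_clients=100, client_test_frac=0.2)
print(len(splits), len(splits[0]["train"]), len(splits[0]["test"]))
```

Because every shard is drawn from the same shuffled pool, each client sees roughly the same label distribution, which is the property the `"iid"` setting refers to.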

### Setting up the evaluator

In [None]:
from fluke.evaluation import ClassificationEval, Evaluator
from fluke import GlobalSettings

evaluator = ClassificationEval(1, n_classes=dataset.num_classes)
GlobalSettings().set_evaluator(evaluator)
#GlobalSettings().set_device("cuda")

### Defining the model

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from fluke.nets import EncoderHeadNet

class Scaffold_2FC_E(nn.Module):
    """Encoder for the :class:`Scaffold_2FC` network.
    
    See Also:
        - :class:`Scaffold_2FC`
        - :class:`Scaffold_2FC_D`
    """
    def __init__(self):
        super(Scaffold_2FC_E, self).__init__()
        self.output_size = 512
        self.fc1 = nn.Linear(28*28, 1024)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(1024, 512)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.view(-1, 28*28)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x
    
class Scaffold_2FC_D(nn.Module):
    """Head for the :class:`Scaffold_2FC` network.
    
    See Also:
        - :class:`Scaffold_2FC`
        - :class:`Scaffold_2FC_E`
    """
    
    def __init__(self):
        super(Scaffold_2FC_D, self).__init__()
        self.output_size = 47
        #self.fc3 = nn.Linear(1024, 512)
        #self.relud = nn.ReLU()
        self.fc4 = nn.Linear(512, 47)
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        #x = self.fc3(x)
        #x = self.relud(x)
        x = self.fc4(x)
        return x
    
# Scaffold: https://arxiv.org/abs/1910.06378 (EMNIST) 
class Scaffold_2FC(EncoderHeadNet):
    """A fully connected network for EMNIST classification. This network attempts to recreate the
    architecture used in the [SCAFFOLD]_ paper. Since the paper gives no specific details about the
    architecture, the encoder stacks two fully connected layers (1024 and 512 units) and the head
    maps the 512 features to the 47 output classes.
    
    See Also:
        - :class:`Scaffold_2FC_E`
        - :class:`Scaffold_2FC_D`
    
    References:
        .. [SCAFFOLD] Sai Praneeth Karimireddy, Satyen Kale, Mehryar Mohri, Sashank J. Reddi, Sebastian U. Stich, Ananda Theertha Suresh. SCAFFOLD: Stochastic Controlled Averaging for Federated Learning. 
            In arXiv (2019).
    """
    
    def __init__(self):
        super(Scaffold_2FC, self).__init__(Scaffold_2FC_E(), Scaffold_2FC_D())
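
For a sense of the model's size, the short calculation below counts the trainable parameters of the three linear layers defined above (weights plus biases). The helper `linear_params` is introduced here only for illustration. The count matters for SCAFFOLD because each control variate has the same shape as the model.

```python
def linear_params(in_features, out_features):
    """Weights plus biases of a fully connected layer."""
    return in_features * out_features + out_features


# (in_features, out_features) for fc1, fc2 (encoder) and fc4 (head).
layers = [(28 * 28, 1024), (1024, 512), (512, 47)]
per_layer = [linear_params(i, o) for i, o in layers]
total = sum(per_layer)
print(per_layer, total)  # [803840, 524800, 24111] 1352751
```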

### Setting up hyperparameters and model

In [None]:
client_hp = DDict(
    batch_size=20,
    local_epochs=5,
    loss="CrossEntropyLoss",
    optimizer=DDict(
        lr=0.1),
    scheduler=DDict(
        gamma=1,
        step_size=1)
)

alg_hp = DDict(
    client=client_hp,
    model=Scaffold_2FC(),
    server=DDict(weighted=True))
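
The scheduler settings above can be read as follows, assuming the scheduler follows the semantics of PyTorch's `StepLR` (the learning rate is multiplied by `gamma` every `step_size` epochs). The function `step_lr` is a hypothetical stand-in written in plain Python, not a call into `fluke` or PyTorch.

```python
def step_lr(base_lr, gamma, step_size, epoch):
    """Learning rate at `epoch` under a StepLR-style decay (assumed semantics)."""
    return base_lr * gamma ** (epoch // step_size)


# With gamma=1 the decay factor is always 1, so the rate stays constant.
schedule = [step_lr(0.1, gamma=1, step_size=1, epoch=e) for e in range(5)]
print(schedule)

# For contrast, halving the rate every epoch.
decayed = [step_lr(0.1, gamma=0.5, step_size=1, epoch=e) for e in range(3)]
print(decayed)
```

Under that assumption, `gamma=1` effectively disables the decay, so every local epoch uses the learning rate 0.1.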

In [None]:
from fluke.algorithms.scaffold import SCAFFOLD

algorithm = SCAFFOLD(100, splitter, alg_hp)

### Setting up the logger

In [None]:
from fluke.utils.log import Log
logger = Log()
algorithm.set_callbacks(logger)

## Running the experiment

In [None]:
algorithm.run(40, 0.2)

Our objective is to reach a test accuracy of 0.5 within 10 rounds of training. We achieve this level of performance in 4 rounds.
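
One way to check this criterion programmatically is to scan the per-round test accuracies for the first value that reaches the target. The helper below is a hypothetical sketch: how the accuracies are extracted from the logger is not shown, and the numbers in the example are placeholders, not results from this experiment.

```python
def rounds_to_target(accuracies, target):
    """Return the 1-based round at which accuracy first reaches `target`, or None."""
    for round_idx, acc in enumerate(accuracies, start=1):
        if acc >= target:
            return round_idx
    return None


# Placeholder values for illustration only, NOT measured in this notebook.
example_accuracies = [0.2, 0.4, 0.6]
print(rounds_to_target(example_accuracies, target=0.5))
print(rounds_to_target(example_accuracies, target=0.9))
```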