# A Jupyter-based Federated Learning wizard for Distributed Medical Applications across Institutions

User input (AI researcher module): in this notebook, AI developers can customize how the data is loaded and which model architecture is trained at each collaborating institution within the federation. The implementation, preferably written in [PyTorch](https://pytorch.org/) for this first version, is then integrated into the [Flower](https://flower.dev/) application.

<img src="../figures/federation_figure.JPG" alt="Overview" style="width: 800px;"/>

## How to?

In this notebook, the overall structure of the code is pre-filled, so the user only has to complete the missing parts.<br/>The parts that need completion are marked with `#TODO` comments, mostly between the `### START CODE HERE ###` and `### END CODE HERE ###` markers. Docstrings (text in triple quotation marks) state the inputs and/or outputs that a function or class must provide to be integrated successfully into the [Flower](https://flower.dev/) application.

## Dependencies

In [None]:
import numpy as np  # needed by CustomDataset (np.arange)
import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader
import flwr as fl

# import ...

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Data preparation

In [None]:
class CustomDataset(torch.utils.data.Dataset):
    def __init__(self):
        ### START CODE HERE ###
        
        self.X = #TODO
        self.Y = #TODO
        
        ### END CODE HERE ###
        
        self.indexes = np.arange(0, len(self.X))
        
    def __len__(self):
        """Required output: length of the dataset"""
        return len(self.X)  # consistent with self.indexes; also works if X is a list
    
    def __getitem__(self, idx):
        """Required input: index to access
        Required output: image corresponding to the input index and its label"""
        
        image = self.X[idx]
        label = self.Y[idx]
        
        return image, label
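
As an illustration of how the template above might be completed, here is a minimal sketch that fills `self.X` and `self.Y` with random data. The class name `SyntheticDataset`, the image shape (1 x 28 x 28) and the two-class labels are assumptions made for this example only; a real institution would load its own images and labels instead.

```python
import numpy as np
import torch
from torch.utils.data import Dataset


class SyntheticDataset(Dataset):
    """Hypothetical stand-in for CustomDataset, filled with random data."""

    def __init__(self, n_samples: int = 32, seed: int = 0):
        rng = np.random.default_rng(seed)
        # Assumed shapes: single-channel 28x28 images and two classes.
        images = rng.random((n_samples, 1, 28, 28), dtype=np.float32)
        labels = rng.integers(0, 2, size=n_samples, dtype=np.int64)
        self.X = torch.from_numpy(images)
        self.Y = torch.from_numpy(labels)
        self.indexes = np.arange(0, len(self.X))

    def __len__(self):
        """Length of the dataset."""
        return len(self.X)

    def __getitem__(self, idx):
        """Image at the given index and its label."""
        return self.X[idx], self.Y[idx]


dataset = SyntheticDataset()
image, label = dataset[0]
```

Returning tensors from `__getitem__` lets a `DataLoader` batch the samples with its default collate function.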

## Data loading

In [None]:
def load_data():
    """Load data (training, validation and test sets).
    Required outputs: loaders of each set and dictionary containing the length of each corresponding set
    """
    ### START CODE HERE ###
    trainset = #TODO
    valset = #TODO
    testset = #TODO
    
    trainloader = #TODO
    valloader = #TODO
    testloader = #TODO
    ### END CODE HERE ###
    
    num_examples = {"trainset": len(trainset), "valset": len(valset), "testset": len(testset)}
    
    return trainloader, valloader, testloader, num_examples
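
One possible way to complete `load_data` is sketched below. It assumes a single synthetic dataset that is split 70/15/15 into training, validation and test sets; the function name `load_data_sketch`, the split sizes and the batch size are illustrative choices, not requirements of the template.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split


def load_data_sketch(batch_size: int = 8):
    """Build three loaders from synthetic data and report the set sizes."""
    generator = torch.Generator().manual_seed(0)
    # Synthetic data (assumption): 100 single-channel 28x28 images, 2 classes.
    X = torch.rand(100, 1, 28, 28, generator=generator)
    Y = torch.randint(0, 2, (100,), generator=generator)
    full = TensorDataset(X, Y)

    # Fixed-size split; the generator makes the split reproducible.
    trainset, valset, testset = random_split(full, [70, 15, 15], generator=generator)

    trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True)
    valloader = DataLoader(valset, batch_size=batch_size)
    testloader = DataLoader(testset, batch_size=batch_size)

    num_examples = {"trainset": len(trainset), "valset": len(valset), "testset": len(testset)}
    return trainloader, valloader, testloader, num_examples


trainloader, valloader, testloader, num_examples = load_data_sketch()
first_images, first_labels = next(iter(trainloader))
```

Only the training loader is shuffled; the validation and test loaders keep a fixed order so that evaluation is repeatable.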

## Model definition

In [None]:
class Model(nn.Module):
    """Model architecture. Inputs in the init function can be added if needed."""
    def __init__(self) -> None:
        super(Model, self).__init__()
        
        ### START CODE HERE ###
        #TODO
        ### END CODE HERE ###
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Required input: tensor of images to predict
        Required output: output of the model given the images tensor in input"""
            
        ### START CODE HERE ###
        #TODO
        ### END CODE HERE ###
        
        return x
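
For orientation, a completed model could look like the following small convolutional network. The architecture, the class name `SmallCNN`, the 1 x 28 x 28 input size and the two output classes are all assumptions for this sketch; any `nn.Module` whose `forward` maps a batch of images to a batch of outputs fits the template.

```python
import torch
import torch.nn as nn


class SmallCNN(nn.Module):
    """Hypothetical example architecture for 1x28x28 inputs."""

    def __init__(self, num_classes: int = 2) -> None:
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=3, padding=1),  # keeps 28x28
            nn.ReLU(),
            nn.MaxPool2d(2),  # 28x28 -> 14x14
        )
        self.classifier = nn.Linear(8 * 14 * 14, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.features(x)
        x = torch.flatten(x, start_dim=1)  # keep the batch dimension
        return self.classifier(x)  # raw logits, one row per image


net = SmallCNN()
logits = net(torch.zeros(4, 1, 28, 28))
```

The model returns raw logits, which is what `nn.CrossEntropyLoss` expects if that loss is chosen for training.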

## Training loop

In [None]:
def train(net, trainloader, valloader, epochs):
    """Train the network on the training set, evaluating it on the validation set at each epoch."""
    criterion = #TODO Define loss function
    optimizer = #TODO Define optimizer
    
    for i_epoch in range(epochs):
        ### START CODE HERE ###
        #TODO
        ### END CODE HERE ###
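
The training loop can be completed in many ways; the sketch below shows one common pattern, assuming a classification task with `nn.CrossEntropyLoss` and SGD. The name `train_sketch`, the learning rate and the returned list of validation losses are additions for this example and are not part of the template's `train` signature.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def train_sketch(net, trainloader, valloader, epochs, lr=0.01):
    """Train on the training set and record the mean validation loss per epoch."""
    criterion = nn.CrossEntropyLoss()  # assumed loss function
    optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.9)  # assumed optimizer
    net.to(DEVICE)
    history = []
    for i_epoch in range(epochs):
        net.train()
        for images, labels in trainloader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            loss = criterion(net(images), labels)
            loss.backward()
            optimizer.step()

        # Validation pass: no gradients, loss averaged over all samples.
        net.eval()
        val_loss, n_samples = 0.0, 0
        with torch.no_grad():
            for images, labels in valloader:
                images, labels = images.to(DEVICE), labels.to(DEVICE)
                val_loss += criterion(net(images), labels).item() * labels.size(0)
                n_samples += labels.size(0)
        history.append(val_loss / n_samples)
    return history


# Usage on synthetic data (4 features, 2 classes) with a linear model.
torch.manual_seed(0)
data = TensorDataset(torch.rand(32, 4), torch.randint(0, 2, (32,)))
loader = DataLoader(data, batch_size=8)
history = train_sketch(nn.Linear(4, 2), loader, loader, epochs=2)
```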

## Test function

In [None]:
def test(net, testloader):
    """Evaluate the network on the entire test set.
    Required outputs: loss and accuracy computed over the test set"""
    criterion = #TODO Define loss function
    
    with torch.no_grad():
        ### START CODE HERE ###
        #TODO
        ### END CODE HERE ###
    
    return loss, accuracy
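
A minimal sketch of a completed test function follows, again assuming a classification task with `nn.CrossEntropyLoss`. The name `evaluate_sketch` is hypothetical; the function returns the mean loss per sample and the fraction of correct predictions, which matches the `loss, accuracy` pair the template returns.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def evaluate_sketch(net, testloader):
    """Return mean loss and accuracy over the whole test set."""
    criterion = nn.CrossEntropyLoss(reduction="sum")  # sum, then divide by set size
    net.to(DEVICE)
    net.eval()
    total_loss, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = net(images)
            total_loss += criterion(outputs, labels).item()
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += labels.size(0)
    return total_loss / total, correct / total


# Usage on synthetic data (4 features, 2 classes) with an untrained linear model.
torch.manual_seed(0)
data = TensorDataset(torch.rand(20, 4), torch.randint(0, 2, (20,)))
loss, accuracy = evaluate_sketch(nn.Linear(4, 2), DataLoader(data, batch_size=5))
```

Summing the loss per batch and dividing by the total number of samples gives a correct mean even when the last batch is smaller than the others.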

## Aggregation function

In [None]:
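# https://flower.dev/docs/implementing-strategies.html
# Note: the argument names below follow older Flower releases. From Flower 1.0 on,
# fraction_eval is named fraction_evaluate and min_eval_clients is named min_evaluate_clients.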

strategy = fl.server.strategy.FedAvg(
    # ... other FedAvg arguments
    fraction_fit=1,
    fraction_eval=1,
    min_eval_clients=2,
    min_available_clients=2,
    # on_evaluate_config_fn=evaluate_config,
)

config = {"num_rounds": 3}
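
Conceptually, federated averaging combines the parameters returned by the clients into a weighted average, where each client's weight is proportional to the number of training examples it used. The sketch below illustrates that arithmetic with NumPy; it is not Flower's implementation, and the function name `federated_average` and the toy parameter values are made up for this example.

```python
import numpy as np


def federated_average(client_weights, client_sizes):
    """Average per-layer parameters, weighted by each client's example count."""
    total = sum(client_sizes)
    n_layers = len(client_weights[0])
    return [
        sum(weights[i] * (size / total) for weights, size in zip(client_weights, client_sizes))
        for i in range(n_layers)
    ]


# Two hypothetical clients, each with two "layers" of parameters.
client_a = [np.array([1.0, 1.0]), np.array([0.0])]  # trained on 100 examples
client_b = [np.array([3.0, 5.0]), np.array([4.0])]  # trained on 300 examples

averaged = federated_average([client_a, client_b], client_sizes=[100, 300])
# Client B holds 75% of the data, so the result sits closer to its parameters:
# layer 0 -> [2.5, 4.0], layer 1 -> [3.0]
```

This is why `load_data` returns `num_examples`: the size of each local training set determines how much that institution's update contributes to the aggregated model.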