# Federating centralized machine learning in a private Jupyter notebook to distributed institutions

User input (AI researcher module): in this notebook, AI developers can customize how the data is loaded as well as the model architecture used to train models at each collaborating institution within the federation. This implementation, preferably written in [PyTorch](https://pytorch.org/) for this first version, will then be integrated into the [Flower](https://flower.dev/) application.

<img src="../figures/federation_figure.JPG" alt="Overview" style="width: 800px;"/>

## How to?

In this notebook, the overall structure of the code has been pre-filled so that the user only has to complete the missing parts.<br/>Note that all comments (preceded by a "#" sign) correspond to parts of code that need completion. Comments in quotation marks indicate the specific inputs and/or outputs that are needed for a given function or class to be further successfully integrated into the [Flower](https://flower.dev/) application.

## Dependencies

In [None]:
import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader

# import ...
import numpy as np
import pandas as pd
import torchvision.models as models
import cv2
from collections import OrderedDict
import time

# Device used for tensors in train() and test(): the GPU when CUDA is available,
# otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

## Data preparation

In [None]:
def load_images(df_data, dir_data, input_shape):
    """Load the images listed in ``df_data`` and resize them to ``input_shape``.

    Args:
        df_data: DataFrame with an ``images`` column holding one file name per row;
            names are resolved relative to ``<dir_data>/patches/``.
        dir_data: Root data directory (the ``/patches/`` sub-folder is appended here).
        input_shape: Target size passed unchanged to ``cv2.resize``
            (OpenCV reads it as ``(width, height)``).

    Returns:
        ``np.ndarray`` stacking the resized images in row order. Pixel layout is
        whatever ``cv2.imread`` produces (BGR channel order by default).

    Raises:
        FileNotFoundError: If an image is missing or cannot be decoded.
    """
    dir_data = dir_data + "/patches/"
    print("dir_data: ", dir_data)

    list_images = []
    # Iterate over the column values (positional order) rather than index labels,
    # so the function does not depend on df_data having a 0..n-1 index.
    for file_name in df_data["images"]:
        image_path = dir_data + file_name
        img = cv2.imread(image_path)
        if img is None:
            # cv2.imread does not raise on failure; without this check the error
            # would only surface as an opaque assertion inside cv2.resize.
            raise FileNotFoundError(
                "Could not read image (missing or unreadable file): " + image_path
            )
        list_images.append(cv2.resize(img, input_shape))

    return np.array(list_images)

class CustomDataset(torch.utils.data.Dataset):
    """In-memory image dataset built from a partition CSV file.

    The CSV ``<folder>/<partition_file>.csv`` is read with every column as a string.
    It must provide a ``GT`` column (labels, converted to ``int``) and an ``images``
    column (file names consumed by ``load_images``). Rows of class ``'0'`` are
    randomly down-sampled to the number of class ``'1'`` rows before the images
    are loaded.
    """

    def __init__(self, partition_file, folder, input_shape):
        """Read the partition file, balance it and load all images into memory.

        Args:
            partition_file: CSV file name without the ``.csv`` extension.
            folder: Directory containing the CSV file (also passed to ``load_images``
                as the data root).
            input_shape: Target image size forwarded to ``load_images``.
        """
        ### START CODE HERE ###
        self.input_shape = input_shape

        df_data = pd.read_csv(folder + "/" + partition_file + '.csv', dtype=str, delimiter=',')
        # Balance dataset
        df_data = self._balance_negatives(df_data)
        # Load images for the experiment and their corresponding labels
        self.X = load_images(df_data, folder, self.input_shape)
        self.Y = [int(el) for el in df_data.GT]

        ### END CODE HERE ###

        self.indexes = np.arange(0, len(self.X))

    @staticmethod
    def _balance_negatives(df_data):
        """Randomly drop class-'0' rows until both classes have the same size.

        Labels are strings (the CSV is read with ``dtype=str``), so class counts are
        looked up by their string label instead of by position in ``value_counts()``.
        Nothing is dropped when class '0' is not the majority, or when class '1' is
        absent (there is then no reference size to balance against).

        Returns:
            The (possibly reduced) DataFrame with a fresh 0..n-1 index.
        """
        counts = df_data['GT'].value_counts()
        n_negative = int(counts.get('0', 0))
        n_positive = int(counts.get('1', 0))
        surplus = n_negative - n_positive
        if n_positive > 0 and surplus > 0:
            to_drop = df_data.loc[df_data['GT'] == '0'].sample(n=surplus).index
            df_data = df_data.drop(to_drop)
        # drop=True: the old index is not needed and must not become a column.
        return df_data.reset_index(drop=True)

    def __len__(self):
        """Required output: length of the dataset"""
        return self.X.shape[0]

    def __getitem__(self, idx):
        """Required input: index to access
        Required output: image corresponding to the input index and its label"""

        image = self.X[idx]
        label = self.Y[idx]

        return image, label

## Data loading

In [None]:
def load_data(data_path, input_shape=(224,224), batch_size=64):
    """Load data (training, validation and test sets).
    Required outputs: loaders of each set and dictionary containing the length of each corresponding set

    Args:
        data_path: Directory holding ``train_client.csv``, ``val_client.csv`` and
            ``test.csv`` (passed to ``CustomDataset`` as ``folder``).
        input_shape: Target image size forwarded to ``CustomDataset``.
        batch_size: Number of samples per batch for all three loaders (default 64).

    Returns:
        ``(trainloader, valloader, testloader, num_examples)`` where ``num_examples``
        maps ``"trainset"``, ``"valset"`` and ``"testset"`` to the dataset lengths.
    """
    ### START CODE HERE ###
    trainset = CustomDataset(partition_file="train_client", folder=data_path, input_shape=input_shape)
    valset = CustomDataset(partition_file="val_client", folder=data_path, input_shape=input_shape)
    testset = CustomDataset(partition_file="test", folder=data_path, input_shape=input_shape)

    # Only the training data needs shuffling; evaluation metrics do not depend on
    # sample order, so validation/test loaders iterate in a fixed order.
    trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True)
    valloader = DataLoader(valset, batch_size=batch_size, shuffle=False)
    testloader = DataLoader(testset, batch_size=batch_size, shuffle=False)
    ### END CODE HERE ###

    num_examples = {"trainset" : len(trainset), "valset": len(valset), "testset" : len(testset)}

    return trainloader, valloader, testloader, num_examples

## Model definition

In [None]:
class Model(nn.Module):
    """Model architecture. Inputs in the init function can be added if needed.

    Wraps a torchvision VGG16 (no pretrained weights) whose last classifier layer is
    replaced by a linear layer with ``n_classes`` outputs.
    """
    def __init__(self, input_shape=(224,224), n_classes=2):
        """
        Args:
            input_shape: Stored on the instance; not used to build the network here.
            n_classes: Number of output units of the final linear layer.
        """
        super().__init__()

        ### START CODE HERE ###
        self.input_shape = input_shape
        self.n_classes = n_classes

        # Define model architecture
        # Called without arguments: no pretrained weights are loaded (the default),
        # which avoids the deprecated ``pretrained=`` keyword.
        self.model = models.vgg16()
        # Reuse the width of the layer being replaced instead of hard-coding it.
        head_in_features = self.model.classifier[-1].in_features
        self.model.classifier[-1] = torch.nn.Linear(in_features=head_in_features, out_features=self.n_classes)

        ### END CODE HERE ###

    def forward(self, x):
        """Required input: tensor of images to predict
        Required output: output of the model given the images tensor in input

        The input is expected channels-last, ``(batch, height, width, channels)``;
        it is permuted to channels-first before being fed to the network. The output
        always keeps its batch dimension: ``(batch, n_classes)``.
        """

        ### START CODE HERE ###
        x = x.permute(0, 3, 1, 2)
        # No squeeze here: squeezing would drop the batch dimension for a batch of
        # one sample and break ``torch.max(outputs, 1)`` and the loss downstream.
        x = self.model(x)
        ### END CODE HERE ###

        return x

## Training loop

In [None]:
def train(net, trainloader, valloader, epochs):
    """Train the network on the training set, evaluating it on the validation set at each epoch.

    Args:
        net: Model to optimise in place. Batches are moved to ``DEVICE``; the model is
            not moved here, so it is assumed to already live on ``DEVICE``.
        trainloader: Iterable yielding ``(images, labels)`` batches for training.
        valloader: Loader passed to ``test`` after every epoch.
        epochs: Number of passes over ``trainloader``.

    Returns:
        None. Per-epoch metrics and the total training time are printed.
    """
    criterion = torch.nn.CrossEntropyLoss() #TODO Define loss function
    optimizer = torch.optim.SGD(net.parameters(), lr=0.001) #TODO Define optimizer

    start = time.time()
    for i_epoch in range(epochs):
        ### START CODE HERE ###
        print("Epoch ", i_epoch+1)

        # Make sure dropout (and any other train-only layers) are active, whatever
        # mode the previous evaluation left the model in.
        net.train()

        correct, total, train_loss_epoch = 0, 0, 0.0
        for images, labels in trainloader:
            images = torch.as_tensor(images, dtype=torch.float32).to(DEVICE)
            labels = labels.to(DEVICE)
            optimizer.zero_grad()

            outputs = net(images)
            loss_iteration = criterion(outputs, labels)
            loss_iteration.backward()
            optimizer.step()

            batch_size = labels.size(0)
            predicted = outputs.detach().argmax(dim=1)
            total += batch_size
            correct += (predicted == labels).sum().item()
            # The criterion returns the batch *mean*; weight it by the batch size so
            # that dividing by ``total`` below yields the true per-sample average.
            train_loss_epoch += loss_iteration.item() * batch_size

        # Guard against an empty training loader (avoids a ZeroDivisionError).
        n_samples = max(total, 1)
        train_loss_epoch = train_loss_epoch / n_samples
        train_acc_epoch = correct / n_samples
        val_loss_epoch, val_acc_epoch = test(net, valloader)
        info = "[INFO] Epoch {}/{} - train_loss: {:.6f} - train_acc: {:.6f} - val_loss: {:.6f} - val_acc: {:.6f}".format(
                i_epoch + 1, epochs, train_loss_epoch, train_acc_epoch, val_loss_epoch, val_acc_epoch)
        print(info + "\n")

    end = time.time()
    print("Time to train the whole network: ", end-start, " s")

    ### END CODE HERE ###

## Test function

In [None]:
def test(net, testloader):
    """Validate the network on the entire test set.

    Args:
        net: Model to evaluate. Batches are moved to ``DEVICE``; the model is not
            moved here, so it is assumed to already live on ``DEVICE``.
        testloader: Iterable yielding ``(images, labels)`` batches.

    Returns:
        ``(loss, accuracy)``: the per-sample average cross-entropy loss and the
        fraction of correctly classified samples. ``(0.0, 0.0)`` is returned when the
        loader yields no samples.
    """
    criterion = torch.nn.CrossEntropyLoss() #TODO Define loss function

    correct, total, loss = 0, 0, 0.0
    # Evaluate in eval mode (dropout disabled) and restore the caller's mode
    # afterwards, so calling this in the middle of training has no side effect.
    was_training = net.training
    net.eval()
    try:
        with torch.no_grad():
            ### START CODE HERE ###
            for data in testloader:
                images = torch.as_tensor(data[0], dtype=torch.float32).to(DEVICE)
                labels = data[1].to(DEVICE)
                outputs = net(images)
                batch_size = labels.size(0)
                # The criterion returns the batch mean; weight by batch size so the
                # final division by ``total`` gives a per-sample average.
                loss += criterion(outputs, labels).item() * batch_size
                predicted = outputs.argmax(dim=1)
                total += batch_size
                correct += (predicted == labels).sum().item()
    finally:
        net.train(was_training)

    if total == 0:
        # Empty loader: nothing to average over.
        return 0.0, 0.0
    accuracy = correct / total
    loss = loss / total
    ### END CODE HERE ###

    return loss, accuracy

## Aggregation function

In [None]:
# https://flower.dev/docs/implementing-strategies.html

# NOTE(review): `fl` is not imported in the dependency cell of this notebook —
# presumably the Flower package (`import flwr as fl`); it must be in scope before
# this cell is run.

# Federated averaging over all connected clients.
strategy = fl.server.strategy.FedAvg(
    # ... other FedAvg arguments
    fraction_fit=1,
    fraction_eval=1,
    min_fit_clients=3,
    # Must be at least min_fit_clients: a round starts once this many clients are
    # connected, so a lower value lets it start before enough clients are
    # available to sample min_fit_clients of them.
    min_available_clients=3,
)

# Number of federated rounds to run.
config = {"num_rounds": 3}
# Upper bound on the size of a gRPC message (presumably in bytes — confirm against
# the Flower server/client start-up call that consumes it).
grpc_max_message_length = 895_870_912