# Federated Learning with Director

In [96]:
from openfl.interface.interactive_api.federation import Federation
from openfl.interface.interactive_api.experiment import TaskInterface, ModelInterface, FLExperiment
import sys
sys.path.append("../utils/")

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
from torchvision.models.segmentation.deeplabv3 import DeepLabHead
from loss import *
import matplotlib.pyplot as plt
import time
import os
import copy

NUM_CLASSES=1
ROUND_TO_TRAIN=2

# Connect to the Federation

In [116]:
# please use the same identificator that was used in signed certificate
client_id = 'frontend'
director_node_fqdn = 'localhost'
director_port = 50053

federation = Federation(
    client_id=client_id,
    director_node_fqdn=director_node_fqdn,
    director_port=director_port,
    tls=False
)
shard_registry = federation.get_shard_registry()
shard_registry
federation.target_shape
dummy_shard_desc = federation.get_dummy_shard_descriptor(size=10)
dummy_shard_dataset = dummy_shard_desc.get_dataset('train')
sample, target = dummy_shard_dataset[0]
f"Sample shape: {sample.shape}, target shape: {target.shape}"

'Sample shape: (400, 400, 3), target shape: (400, 400)'

## Creating a FL experiment using Interactive API

### Register dataset

We extract User dataset class implementation.
Is it convinient?
What if the dataset is not a class?

In [117]:
from gear_shard_dataset import GearSD
from kvasir_shard_dataset import KvasirSD

fed_dataset = KvasirSD(train_bs=4, valid_bs=8)
fed_dataset.shard_descriptor = dummy_shard_desc
for i, (sample, target) in enumerate(fed_dataset.get_train_loader()):
    print("Sample shape : "+str(sample.shape))
    print("Target shape : "+str(target.shape))



Sample shape : torch.Size([4, 3, 332, 332])
Target shape : torch.Size([4, 1, 332, 332])
Sample shape : torch.Size([4, 3, 332, 332])
Target shape : torch.Size([4, 1, 332, 332])
Sample shape : torch.Size([1, 3, 332, 332])
Target shape : torch.Size([1, 1, 332, 332])


### Describe a model and optimizer

In [132]:
"""
DeepLab model definition 
"""
from torchvision.models.segmentation import deeplabv3_resnet50, deeplabv3_mobilenet_v3_large, deeplabv3_resnet101

class UNet(nn.Module):
    def __init__(self, n_channels=3, n_classes=1):
        super().__init__()
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.up1 = Up(512, 256)
        self.up2 = Up(256, 128)
        self.up3 = Up(128, 64)
        self.outc = nn.Conv2d(64, n_classes, 1)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x = self.up1(x4, x3)
        x = self.up2(x, x2)
        x = self.up3(x, x1)
        x = self.outc(x)
        x = torch.sigmoid(x)
        return x

class DeepLabv3:
    def build_deeplab(self, num_features_fc: int=256, backbone: str="mobilenetv3", pretrained_backbone: bool=True, pretrained_head: bool=True, alpha: float=0):
        """ change the output layer and add a freezing coeff.
            the number of in channel for the DeepLabHead depends on the backbone, 
            MobileNetv3 has 960 out channel whereas resnet101 2048
        Args: 
            backbone: str='mobilenetv3' "pre-trained backbone to download"
        """
        if backbone == "mobilenetv3":
            if pretrained_head:
                self.model = deeplabv3_mobilenet_v3_large(pretrained=True, pretrained_backbone=True)
                out_channel = 960
                self.model.classifier = DeepLabHead(out_channel, NUM_CLASSES)
                self.model.aux_classifier = nn.Identity()
                print("[*] Changing head for {} classes and removing aux classifier".format(NUM_CLASSES))
            else:
                self.model = deeplabv3_mobilenet_v3_large(pretrained_backbone=True, num_classes=NUM_CLASSES)
        elif backbone == "resnet101":
            if pretrained_head:
                self.model= deeplabv3_resnet101(pretrained=True, pretrained_backbone=True)
                out_channel = 2048
                self.model.classifier = DeepLabHead(out_channel, NUM_CLASSES)
                self.model.aux_classifier = nn.Identity()
                print("[*] Changing head for {} classes and removing aux classifier".format(NUM_CLASSES))
            else:
                self.model= deeplabv3_resnet101(pretrained_backbone=True, num_classes=NUM_CLASSES)
        else:
            assert "No such backbone"

        if alpha == 0:
            self.freeze = False
        else:
            self.freeze = True
            self.alpha=alpha
            print("[!] This model will be trained using alpha freezing coef = {} meaning {}/{} layers will be freeze".format(self.alpha, int(self.alpha*sum(1 for x in self.model.parameters())), sum(1 for x in self.model.parameters())))
        return self.model
 
    def freeze(self):
        if self.freeze:
            s = sum(1 for x in self.model.parameters())
            l_freeze = int(s*self.alpha)
            print("{} layers in this model, freezing {} layer\n".format(s, l_freeze))
            for i,param in enumerate(self.model.parameters()):
                param.requires_grad = False
                if l_freeze < i:
                    break
            for name, layer in self.model.named_modules():
                print(name, layer)
    
    def unfreeze(self):
        for i,param in enumerate(self.model.parameters()):
            param.requires_grad = True



In [133]:
d = DeepLabv3()
model_deeplab = d.build_deeplab(alpha=0.7)
model_unet = UNet()

[*] Changing head for 1 classes and removing aux classifier
[!] This model will be trained using alpha freezing coef = 0.7 meaning 135/193 layers will be freeze


In [134]:
model = model_deeplab
# take low learning rate for Tversky loss and to not change so much the current trained weights
optimizer_adam = optim.Adam(model.parameters(), lr=1e-4)

#### Register model

In [135]:
from copy import deepcopy

framework_adapter = 'openfl.plugins.frameworks_adapters.pytorch_adapter.FrameworkAdapterPlugin'
MI = ModelInterface(model=model, optimizer=optimizer_adam, framework_plugin=framework_adapter)

# Save the initial model state
initial_model = deepcopy(model)


### Define and register FL tasks

In [136]:
TI = TaskInterface()
import torch
import numpy as np
import tqdm
from openfl.component.aggregation_functions import Median
from torchmetrics import F1Score

#CRITERION=torch.nn.MSELoss(reduction='mean')
CRITERION=IoULoss
CRITERION_VAL=soft_dice_coef

# The Interactive API supports registering functions definied in main module or imported.
def function_defined_in_notebook(some_parameter):
    print(f'Also I accept a parameter and it is {some_parameter}')

#The Interactive API supports overriding of the aggregation function
aggregation_function = Median()

# Task interface currently supports only standalone functions.
@TI.add_kwargs(**{'some_parameter': 42})
@TI.register_fl_task(model='model', data_loader='train_loader', \
                     device='device', optimizer='optimizer')     
@TI.set_aggregation_function(aggregation_function)
def train(model, train_loader, optimizer, device, loss_fn=CRITERION, some_parameter=None):
    # TODO we can tune the loss functon with the aux output and apply a coeff
    """    
    The following constructions, that may lead to resource race
    is no longer needed:
    
    if not torch.cuda.is_available():
        device = 'cpu'
    else:
        device = 'cuda'        
    """

    # we freeze the layers during the training (otherwise the opt don't load the model correctly afterwards)
    d.freeze()

    print(f'\n\n TASK TRAIN GOT DEVICE {device}\n\n')
    
    function_defined_in_notebook(some_parameter)
    
    train_loader = tqdm.tqdm(train_loader, desc="train")
    
    model.train()
    model.to(device)

    losses = []

    for data, target in train_loader:
        data, target = torch.tensor(data).to(device), torch.tensor(
            target).to(device, dtype=torch.float32)

        optimizer.zero_grad()
        output = model(data)["out"]
        
        #loss = loss_fn().forward(output, target)
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()
        losses.append(loss.detach().cpu().numpy())
    
    d.unfreeze()

    return {'train_loss (dice loss)': np.mean(losses),}

@TI.register_fl_task(model='model', data_loader='val_loader', device='device')     
def validate(model, val_loader, device, loss_fn=CRITERION, val_fn=CRITERION_VAL):

    print(f'\n\n TASK VALIDATE GOT DEVICE {device}\n\n')
    model.eval()
    model.to(device)
    
    val_loader = tqdm.tqdm(val_loader, desc="validate")
    val_score = 0
    total_samples = 0
    losses = []
    with torch.no_grad():
        for data, target in val_loader:
            samples = target.shape[0]
            total_samples += samples
            data, target = torch.tensor(data).to(device), \
                torch.tensor(target).to(device, dtype=torch.int64)

            output = model(data)["out"]
            val = soft_dice_coef(output, target)
            val_score += val.sum().cpu().numpy()
            losses.append(loss_fn(output, target))

    return {'Dice coef': val_score / total_samples, 'val_loss (dice loss)': np.mean(losses)}

## Time to start a federated learning experiment

In [137]:
# create an experimnet in federation
experiment_name = 'gear_test_experiment'
fl_experiment = FLExperiment(federation=federation, experiment_name=experiment_name)

In [138]:
# If I use autoreload I got a pickling error

# The following command zips the workspace and python requirements to be transfered to collaborator nodes
fl_experiment.start(model_provider=MI, 
                    task_keeper=TI,
                    data_loader=fed_dataset,
                    rounds_to_train=ROUND_TO_TRAIN,
                    opt_treatment='CONTINUE_GLOBAL',
                    device_assignment_policy='CUDA_PREFERRED')


In [163]:
# If user want to stop IPython session, then reconnect and check how experiment is going 
# fl_experiment.restore_experiment_state(MI)

fl_experiment.stream_metrics()

KeyboardInterrupt: 

## Now we validate the best model!

In [164]:
best_model = fl_experiment.get_best_model()

  new_state[k] = pt.from_numpy(tensor_dict.pop(k)).to(device)


In [None]:
# We remove exremove_experiment_datamove_experiment_datamove_experiment_datariment data from director
fl_experiment.remove_experiment_data()

In [None]:
best_model.inc.conv[0].weight
# model_unet.inc.conv[0].weight

In [166]:
# Validating initial model
validate(initial_model, fed_dataset.get_valid_loader(), 'cpu')

validate:   0%|          | 0/1 [00:00<?, ?it/s]



 TASK VALIDATE GOT DEVICE cpu




  data, target = torch.tensor(data).to(device), \
  torch.tensor(target).to(device, dtype=torch.int64)
validate: 100%|██████████| 1/1 [00:00<00:00,  1.73it/s]


{'val_loss': 0.00042889933683909476, 'iou_score': 0.999981701374054}

In [167]:
# Validating trained model
validate(best_model, fed_dataset.get_valid_loader(), 'cpu')

validate:   0%|          | 0/1 [00:00<?, ?it/s]



 TASK VALIDATE GOT DEVICE cpu




  data, target = torch.tensor(data).to(device), \
  torch.tensor(target).to(device, dtype=torch.int64)
validate: 100%|██████████| 1/1 [00:00<00:00,  2.41it/s]


{'val_loss': 0.00042889933683909476, 'iou_score': 0.999981701374054}

## We can tune model further!

In [168]:
MI = ModelInterface(model=best_model, optimizer=optimizer_adam, framework_plugin=framework_adapter)
fl_experiment.start(model_provider=MI, task_keeper=TI, data_loader=fed_dataset, rounds_to_train=4, \
                              opt_treatment='CONTINUE_GLOBAL')

In [None]:
best_model = fl_experiment.get_best_model()
# Validating trained model
validate(best_model, fed_dataset.get_valid_loader(), 'cpu')