# Federated Learning with Director

In [51]:
from openfl.interface.interactive_api.federation import Federation
from openfl.interface.interactive_api.experiment import TaskInterface, ModelInterface, FLExperiment
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy

NUM_CLASSES=1

# Connect to the Federation

In [61]:
# please use the same identificator that was used in signed certificate
client_id = 'frontend'
director_node_fqdn = 'localhost'
director_port = 50051

federation = Federation(
    client_id=client_id,
    director_node_fqdn=director_node_fqdn,
    director_port=director_port,
    tls=False
)
shard_registry = federation.get_shard_registry()
shard_registry
federation.target_shape



['400', '400']

In [24]:
# First, request a dummy_shard_desc that holds information about the federated dataset 
dummy_shard_desc = federation.get_dummy_shard_descriptor(size=10)
dummy_shard_dataset = dummy_shard_desc.get_dataset('train')
sample, target = dummy_shard_dataset[0]
f"Sample shape: {sample.shape}, target shape: {target.shape}"

'Sample shape: (400, 400, 3), target shape: (400, 400)'

## Creating a FL experiment using Interactive API

### Register dataset

We extract User dataset class implementation.
Is it convinient?
What if the dataset is not a class?

In [62]:
from gear_shard_dataset import GearSD
from kvasir_shard_dataset import KvasirSD

fed_dataset = KvasirSD(train_bs=4, valid_bs=8)
fed_dataset.shard_descriptor = dummy_shard_desc
for i, (sample, target) in enumerate(fed_dataset.get_train_loader()):
    print("Sample shape : "+str(sample.shape))
    print("Target shape : "+str(target.shape))

Sample shape : torch.Size([4, 3, 332, 332])
Target shape : torch.Size([4, 1, 332, 332])
Sample shape : torch.Size([4, 3, 332, 332])
Target shape : torch.Size([4, 1, 332, 332])
Sample shape : torch.Size([1, 3, 332, 332])
Target shape : torch.Size([1, 1, 332, 332])


### Describe a model and optimizer

In [63]:
"""
UNet and DeepLab model definition 
"""
from loss import TverskyLoss, IoULoss, FocalLoss, soft_dice_coef, soft_dice_loss, DoubleConv, Down, Up
from torchvision.models.segmentation import deeplabv3_resnet50, deeplabv3_mobilenet_v3_large

class UNet(nn.Module):
    def __init__(self, n_channels=3, n_classes=NUM_CLASSES):
        super().__init__()
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.up1 = Up(512, 256)
        self.up2 = Up(256, 128)
        self.up3 = Up(128, 64)
        self.outc = nn.Conv2d(64, n_classes, 1)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x = self.up1(x4, x3)
        x = self.up2(x, x2)
        x = self.up3(x, x1)
        x = self.outc(x)
        x = torch.sigmoid(x)
        return x

model_unet = UNet()
model_deeplab_resnet = deeplabv3_resnet50(pretrained=False, num_classes=NUM_CLASSES, pretrained_backbone=True)
model_deeplab_mobilenet = deeplabv3_mobilenet_v3_large(pretrained=False, num_classes=NUM_CLASSES, pretrained_backbone=True)


In [64]:
optimizer_adam = optim.Adam(model_deeplab_mobilenet.parameters(), lr=1e-4)

#### Register model

In [65]:
from copy import deepcopy

framework_adapter = 'openfl.plugins.frameworks_adapters.pytorch_adapter.FrameworkAdapterPlugin'
MI = ModelInterface(model=model_deeplab_mobilenet, optimizer=optimizer_adam, framework_plugin=framework_adapter)

# Save the initial model state
initial_model = deepcopy(model_deeplab_mobilenet)

### Define and register FL tasks

In [66]:
TI = TaskInterface()
import torch
import numpy as np
import tqdm
from openfl.component.aggregation_functions import Median

# The Interactive API supports registering functions definied in main module or imported.
def function_defined_in_notebook(some_parameter):
    print(f'Also I accept a parameter and it is {some_parameter}')

#The Interactive API supports overriding of the aggregation function
aggregation_function = Median()

# Task interface currently supports only standalone functions.
@TI.add_kwargs(**{'some_parameter': 42})
@TI.register_fl_task(model='model', data_loader='train_loader', \
                     device='device', optimizer='optimizer')     
@TI.set_aggregation_function(aggregation_function)
def train(model, train_loader, optimizer, device, loss_fn=soft_dice_loss, some_parameter=None):
    
    """    
    The following constructions, that may lead to resource race
    is no longer needed:
    
    if not torch.cuda.is_available():
        device = 'cpu'
    else:
        device = 'cuda'
        
    """

    print(f'\n\n TASK TRAIN GOT DEVICE {device}\n\n')
    
    function_defined_in_notebook(some_parameter)
    
    train_loader = tqdm.tqdm(train_loader, desc="train")
    
    model.train()
    model.to(device)

    losses = []

    for data, target in train_loader:
        data, target = torch.tensor(data).to(device), torch.tensor(
            target).to(device, dtype=torch.float32)

        optimizer.zero_grad()
        
        output = model(data)["out"]
        print("Output of the model {}".format(output))
        loss = loss_fn(output=output, target=target)
        loss.backward()
        optimizer.step()
        losses.append(loss.detach().cpu().numpy())
        
    return {'train_loss': np.mean(losses),}


@TI.register_fl_task(model='model', data_loader='val_loader', device='device')     
def validate(model, val_loader, device):
    print(f'\n\n TASK VALIDATE GOT DEVICE {device}\n\n')
    
    model.eval()
    model.to(device)
    
    val_loader = tqdm.tqdm(val_loader, desc="validate")

    val_score = 0
    total_samples = 0

    with torch.no_grad():
        for data, target in val_loader:
            print("Target Tensor shape  {} Data {} ".format(target.shape, data.shape))
            samples = target.shape[0]
            total_samples += samples
            data, target = torch.tensor(data).to(device), \
                torch.tensor(target).to(device, dtype=torch.int64)
            output = model(data)
            val = soft_dice_coef(output, target)
            val_score += val.sum().cpu().numpy()
            
    return {'dice_coef': val_score / total_samples,}

## Time to start a federated learning experiment

In [67]:
# create an experimnet in federation
experiment_name = 'gear_test_experiment'
fl_experiment = FLExperiment(federation=federation, experiment_name=experiment_name)

In [68]:
# If I use autoreload I got a pickling error

# The following command zips the workspace and python requirements to be transfered to collaborator nodes
fl_experiment.start(model_provider=MI, 
                    task_keeper=TI,
                    data_loader=fed_dataset,
                    rounds_to_train=2,
                    opt_treatment='CONTINUE_GLOBAL',
                    device_assignment_policy='CUDA_PREFERRED')


In [58]:
# If user want to stop IPython session, then reconnect and check how experiment is going 
# fl_experiment.restore_experiment_state(MI)

fl_experiment.stream_metrics()

_MultiThreadedRendezvous: <_MultiThreadedRendezvous of RPC that terminated with:
	status = StatusCode.UNAVAILABLE
	details = "Socket closed"
	debug_error_string = "{"created":"@1652444609.976731078","description":"Error received from peer ipv4:127.0.0.1:50051","file":"src/core/lib/surface/call.cc","file_line":1062,"grpc_message":"Socket closed","grpc_status":14}"
>

## Now we validate the best model!

In [14]:
best_model = fl_experiment.get_best_model()

In [None]:
# We remove exremove_experiment_datamove_experiment_datamove_experiment_datariment data from director
fl_experiment.remove_experiment_data()

In [15]:
best_model.inc.conv[0].weight
# model_unet.inc.conv[0].weight


AttributeError: 'DeepLabV3' object has no attribute 'inc'

In [None]:
# Validating initial model
validate(initial_model, fed_dataset.get_valid_loader(), 'cpu')

In [None]:
# Validating trained model
validate(best_model, fed_dataset.get_valid_loader(), 'cpu')

## We can tune model further!

In [16]:
MI = ModelInterface(model=best_model, optimizer=optimizer_adam, framework_plugin=framework_adapter)
fl_experiment.start(model_provider=MI, task_keeper=TI, data_loader=fed_dataset, rounds_to_train=4, \
                              opt_treatment='CONTINUE_GLOBAL')

In [None]:
best_model = fl_experiment.get_best_model()
# Validating trained model
validate(best_model, fed_dataset.get_valid_loader(), 'cpu')