# Federated PyTorch Histology Tutorial

## Connect to the Federation

In [1]:
# Create a federation
from openfl.interface.interactive_api.federation import Federation

# Use the same identifier that was used in the signed certificate
client_id = 'api'
cert_dir = 'cert'
director_node_fqdn = 'localhost'
# 1) Run with API layer - Director mTLS
# To enable mTLS, the user must provide the CA root chain and a signed key pair to the federation interface
# cert_chain = f'{cert_dir}/root_ca.crt'
# api_certificate = f'{cert_dir}/{client_id}.crt'
# api_private_key = f'{cert_dir}/{client_id}.key'

# federation = Federation(client_id=client_id, director_node_fqdn=director_node_fqdn, director_port='50051',
#                        cert_chain=cert_chain, api_cert=api_certificate, api_private_key=api_private_key)

# --------------------------------------------------------------------------------------------------------------------

# 2) Run with TLS disabled (trusted environment)
# Federation can also determine local fqdn automatically
federation = Federation(client_id=client_id, director_node_fqdn=director_node_fqdn, director_port='50051', tls=False)


In [2]:
# Display the target shape reported by the federation
federation.target_shape

['1']

In [3]:
# Fetch the shard registry from the federation and display it as the cell output
shard_registry = federation.get_shard_registry()
shard_registry

{'env_one': {'shard_info': node_info {
    name: "env_one"
  }
  shard_description: "Histology dataset, shard number 1 out of 2"
  sample_shape: "150"
  sample_shape: "150"
  target_shape: "1",
  'is_online': True,
  'is_experiment_running': False,
  'last_updated': '2022-11-14 23:02:31',
  'current_time': '2022-11-14 23:02:45',
  'valid_duration': seconds: 120,
  'experiment_name': 'ExperimentName Mock'},
 'env_two': {'shard_info': node_info {
    name: "env_two"
  }
  shard_description: "Histology dataset, shard number 1 out of 2"
  sample_shape: "150"
  sample_shape: "150"
  target_shape: "1",
  'is_online': True,
  'is_experiment_running': False,
  'last_updated': '2022-11-14 23:02:33',
  'current_time': '2022-11-14 23:02:45',
  'valid_duration': seconds: 120,
  'experiment_name': 'ExperimentName Mock'}}

In [4]:
# First, request a dummy shard descriptor that holds information about the federated dataset
dummy_shard_desc = federation.get_dummy_shard_descriptor(size=10)
# Take the first (sample, target) pair of the dummy 'train' split
sample, target = dummy_shard_desc.get_dataset('train')[0]

## Creating a FL experiment using Interactive API

In [5]:
from openfl.interface.interactive_api.experiment import TaskInterface, DataInterface, ModelInterface, FLExperiment

### Register dataset

In [6]:
import torchvision
from torchvision import transforms as T

# Per-channel mean/std normalization.
# NOTE(review): defined here but not part of the transforms assigned below.
normalize = T.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

# Random flip / rotation / resized crop, wrapped in RandomApply with p=0.8.
# NOTE(review): also defined here but not part of the transforms assigned below.
augmentation = T.RandomApply(
    [T.RandomHorizontalFlip(),
     T.RandomRotation(10),
     T.RandomResizedCrop(64)], 
    p=.8
)

# Both the training and the validation transform only convert samples to tensors.
training_transform = T.ToTensor()

valid_transform = T.ToTensor()


In [7]:
from torch.utils.data import Dataset


class TransformedDataset(Dataset):
    """Dataset wrapper that applies optional transforms to (image, label) pairs."""

    def __init__(self, dataset, transform=None, target_transform=None):
        """
        Keep a reference to the wrapped dataset and the optional callables.

        Args:
            dataset: indexable collection yielding ``(img, label)`` pairs.
            transform: optional callable applied to the image.
            target_transform: optional callable applied to the label.
        """
        self.dataset = dataset
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        """Number of items in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, index):
        """Return the pair at ``index`` with the configured transforms applied."""
        sample, target = self.dataset[index]
        # The label transform runs first, then the image transform.
        if self.target_transform:
            target = self.target_transform(target)
        if self.transform:
            sample = self.transform(sample)
        return sample, target


In [8]:
class HistologyDataset(DataInterface):
    """
    Data interface exposing a shard's 'train' and 'val' splits as data loaders.

    Keyword arguments passed at construction are stored in ``self.kwargs``;
    the loaders read the batch sizes from its ``train_bs`` / ``valid_bs`` keys.
    """

    def __init__(self, **kwargs):
        """Store the user-supplied settings for later use by the loaders."""
        self.kwargs = kwargs

    @property
    def shard_descriptor(self):
        """Shard descriptor currently attached to this interface."""
        return self._shard_descriptor

    @shard_descriptor.setter
    def shard_descriptor(self, shard_descriptor):
        """
        Attach a shard descriptor and build the per-collaborator datasets.

        This setter is called during collaborator initialization.
        The local shard_descriptor is set by the Envoy.
        """
        self._shard_descriptor = shard_descriptor

        # Wrap each split so that the notebook-level transforms are applied on access.
        train_split = self._shard_descriptor.get_dataset('train')
        self.train_set = TransformedDataset(train_split, transform=training_transform)

        valid_split = self._shard_descriptor.get_dataset('val')
        self.valid_set = TransformedDataset(valid_split, transform=valid_transform)

    def get_train_loader(self, **kwargs):
        """
        Build the training loader.

        Its output is provided to tasks that have an optimizer in their contract.
        """
        batch_size = self.kwargs['train_bs']
        return DataLoader(self.train_set, num_workers=8, batch_size=batch_size, shuffle=True)

    def get_valid_loader(self, **kwargs):
        """
        Build the validation loader.

        Its output is provided to tasks that have no optimizer in their contract.
        """
        batch_size = self.kwargs['valid_bs']
        return DataLoader(self.valid_set, num_workers=8, batch_size=batch_size)

    def get_train_data_size(self):
        """Number of training samples (information for aggregation)."""
        return len(self.train_set)

    def get_valid_data_size(self):
        """Number of validation samples (information for aggregation)."""
        return len(self.valid_set)
    

In [9]:
# Instantiate the data interface; the keyword arguments are the batch sizes used by its loaders
fed_dataset = HistologyDataset(train_bs=64, valid_bs=64)

### Describe the model and optimizer

In [10]:
import os
import glob
from torch.utils.data import Dataset, DataLoader
from PIL import Image

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [11]:
"""
Custom convolutional classifier with concatenated skip connections
"""

class Net(nn.Module):
    """
    Eight 3x3 convolutions arranged in four stages, followed by two linear layers.

    After each stage the stage input is concatenated with the stage output along
    the channel axis and down-sampled with a 2x2 max-pool. ``fc1`` is sized for
    ``1184 * 9 * 9`` features, which corresponds to 150x150 inputs
    (150 -> 75 -> 37 -> 18 -> 9 after the four pools). The network emits 8 logits.
    """

    def __init__(self):
        super().__init__()
        conv_kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1}
        self.conv1 = nn.Conv2d(3, 16, **conv_kwargs)
        self.conv2 = nn.Conv2d(16, 32, **conv_kwargs)
        self.conv3 = nn.Conv2d(32, 64, **conv_kwargs)
        self.conv4 = nn.Conv2d(64, 128, **conv_kwargs)
        # Input channels grow by concatenation: 128 (conv4) + 32 (stage input).
        self.conv5 = nn.Conv2d(128 + 32, 256, **conv_kwargs)
        self.conv6 = nn.Conv2d(256, 512, **conv_kwargs)
        # 512 (conv6) + 128 + 32 carried over from the previous concatenation.
        self.conv7 = nn.Conv2d(512 + 128 + 32, 256, **conv_kwargs)
        self.conv8 = nn.Conv2d(256, 512, **conv_kwargs)
        # 1184 = 512 (conv8) + 512 + 128 + 32 channels after the last concatenation.
        self.fc1 = nn.Linear(1184 * 9 * 9, 128)
        self.fc2 = nn.Linear(128, 8)

    def forward(self, x):
        """
        Compute class logits for a batch of images.

        Args:
            x: image batch with 3 channels (see the class docstring for the
               spatial size ``fc1`` was sized for).

        Returns:
            Tensor with 8 logits per sample.
        """
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        maxpool = F.max_pool2d(x, 2, 2)

        x = F.relu(self.conv3(maxpool))
        x = F.relu(self.conv4(x))
        concat = torch.cat([maxpool, x], dim=1)
        maxpool = F.max_pool2d(concat, 2, 2)

        x = F.relu(self.conv5(maxpool))
        x = F.relu(self.conv6(x))
        concat = torch.cat([maxpool, x], dim=1)
        maxpool = F.max_pool2d(concat, 2, 2)

        x = F.relu(self.conv7(maxpool))
        x = F.relu(self.conv8(x))
        concat = torch.cat([maxpool, x], dim=1)
        maxpool = F.max_pool2d(concat, 2, 2)

        x = maxpool.flatten(start_dim=1)
        # F.dropout defaults to training=True; tie it to the module mode so that
        # dropout is switched off after net.eval() (e.g. during validation).
        x = F.dropout(self.fc1(x), p=0.5, training=self.training)
        x = self.fc2(x)
        return x

# Instantiate the network defined above
model_net = Net()

In [12]:
# Adam optimizer over all model parameters with a learning rate of 1e-4
optimizer_adam = optim.Adam(model_net.parameters(), lr=1e-4)

### Register model

In [13]:
from copy import deepcopy

# Dotted path of the framework adapter plugin handed to ModelInterface
framework_adapter = 'openfl.plugins.frameworks_adapters.pytorch_adapter.FrameworkAdapterPlugin'
# Bundle the model and its optimizer together with the framework adapter
model_interface = ModelInterface(model=model_net, optimizer=optimizer_adam, framework_plugin=framework_adapter)

# Save the initial model state
initial_model = deepcopy(model_net)

## Define and register FL tasks

In [14]:
# Task registry; its decorators are used below to declare the FL tasks
task_interface = TaskInterface()
import torch

import tqdm

# The Interactive API supports registering functions defined in the main module or imported.
def function_defined_in_notebook(some_parameter):
    """Print a message that echoes the parameter it was called with."""
    message = f'Also I accept a parameter and it is {some_parameter}'
    print(message)

# Task interface currently supports only standalone functions.
@task_interface.add_kwargs(**{'some_parameter': 42})
@task_interface.register_fl_task(model='net_model', data_loader='train_loader',
                                 device='device', optimizer='optimizer')
def train(net_model, train_loader, optimizer, device, loss_fn=F.cross_entropy, some_parameter=None):
    """
    Run one pass over ``train_loader`` and report the mean batch loss.

    Args:
        net_model: model to train; switched to train mode and moved to the device.
        train_loader: iterable of ``(data, target)`` batches.
        optimizer: optimizer stepped once per batch.
        device: NOTE(review): the value passed in is overwritten below; CUDA is
            used whenever it is available, otherwise the CPU.
        loss_fn: callable ``loss_fn(output, target)``; defaults to cross entropy.
        some_parameter: extra value forwarded to ``function_defined_in_notebook``.

    Returns:
        dict: ``{'train_loss': <np.mean of the per-batch losses>}``.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    function_defined_in_notebook(some_parameter)

    train_loader = tqdm.tqdm(train_loader, desc="train")
    net_model.train()
    net_model.to(device)

    losses = []

    for data, target in train_loader:
        # torch.as_tensor reuses tensors coming from the loader; torch.tensor()
        # would copy-construct them on every batch and emit a UserWarning.
        data = torch.as_tensor(data).to(device)
        target = torch.as_tensor(target).to(device)
        optimizer.zero_grad()
        output = net_model(data)
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()
        losses.append(loss.detach().cpu().numpy())

    return {'train_loss': np.mean(losses),}


@task_interface.register_fl_task(model='net_model', data_loader='val_loader', device='device')
def validate(net_model, val_loader, device):
    """
    Compute classification accuracy of ``net_model`` over ``val_loader``.

    Args:
        net_model: model to evaluate; switched to eval mode and moved to the device.
        val_loader: iterable of ``(data, target)`` batches.
        device: NOTE(review): the value passed in is overwritten below; CUDA is
            used whenever it is available, otherwise the CPU.

    Returns:
        dict: ``{'acc': correct_predictions / total_samples}``.

    Raises:
        ZeroDivisionError: if ``val_loader`` yields no batches.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    net_model.eval()
    net_model.to(device)

    val_loader = tqdm.tqdm(val_loader, desc="validate")
    val_score = 0
    total_samples = 0

    with torch.no_grad():
        for data, target in val_loader:
            samples = target.shape[0]
            total_samples += samples
            # torch.as_tensor reuses tensors coming from the loader; torch.tensor()
            # would copy-construct them on every batch and emit a UserWarning.
            data = torch.as_tensor(data).to(device)
            target = torch.as_tensor(target).to(device)
            output = net_model(data)
            # Predicted class is the index of the largest logit.
            pred = output.argmax(dim=1)
            val_score += pred.eq(target).sum().cpu().numpy()

    return {'acc': val_score / total_samples,}

## Time to start a federated learning experiment

In [15]:
# Create an experiment in the federation
experiment_name = 'histology_test_experiment'
fl_experiment = FLExperiment(federation=federation, experiment_name=experiment_name)

In [16]:
# The following command zips the workspace and Python requirements to be transferred to collaborator nodes
fl_experiment.start(
    model_provider=model_interface, 
    task_keeper=task_interface,
    data_loader=fed_dataset,
    rounds_to_train=5,
    opt_treatment='CONTINUE_GLOBAL'
)



In [17]:
# If the user stops the IPython session, they can reconnect and check how the experiment is going:
# fl_experiment.restore_experiment_state(model_interface)

# Stream the experiment metrics, with the tensorboard_logs option turned off
fl_experiment.stream_metrics(tensorboard_logs=False)