# Federated PyTorch MNIST/SVHN (VGG16 with GroupNorm)

In [None]:
#!pip install -r requirements.txt

In [None]:
import os
import glob

from PIL import Image

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from openfl.interface.interactive_api.federation import Federation
from openfl.interface.interactive_api.experiment import TaskInterface, DataInterface, ModelInterface, FLExperiment
from copy import deepcopy
import torchvision
from torchvision import transforms as T
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import tqdm

# Seed PyTorch and NumPy with one shared, fixed value so that runs are repeatable.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

## Connect to the Federation

In [None]:
# Create a federation

# Use the same identifier that was used in the signed certificate.
client_id = 'api'
cert_dir = 'cert'
director_node_fqdn = 'director'

# Option 1: run with API layer <-> Director mTLS.
# To enable mTLS, the user must provide the CA root chain and a signed key
# pair to the federation interface:
# cert_chain = f'{cert_dir}/root_ca.crt'
# api_certificate = f'{cert_dir}/{client_id}.crt'
# api_private_key = f'{cert_dir}/{client_id}.key'
#
# federation = Federation(client_id=client_id, director_node_fqdn=director_node_fqdn, director_port='50051',
#                        cert_chain=cert_chain, api_cert=api_certificate, api_private_key=api_private_key)

# --------------------------------------------------------------------------------------------------------------------

# Option 2 (active): run with TLS disabled (trusted environment only).
# Federation can also determine the local FQDN automatically.
federation = Federation(
    client_id=client_id,
    director_node_fqdn=director_node_fqdn,
    director_port='50051',
    tls=False,
)


In [None]:
# Bare expression: the notebook shows the federation's target_shape as this cell's output.
federation.target_shape

In [None]:
# Fetch the shard registry from the federation; the trailing bare name makes
# the notebook display it as this cell's output.
shard_registry = federation.get_shard_registry()
shard_registry

In [None]:
# First, request a dummy shard descriptor (size=10) that holds information
# about the federated dataset, and take its 'train' split.
dummy_shard_desc = federation.get_dummy_shard_descriptor(size=10)
dummy_shard_dataset = dummy_shard_desc.get_dataset('train')

# Sanity check: print the shape of the first sample and then of its target.
sample, target = dummy_shard_dataset[0]
for _item in (sample, target):
    print(_item.shape)

## Creating a FL experiment using Interactive API

### Register dataset

In [None]:
# Normalisation with a single mean/std entry (i.e. one channel is configured).
normalize = T.Normalize(mean=[0.3037], std=[0.2889])

# Training-only augmentation: a horizontal flip wrapped in RandomApply.
# NOTE(review): RandomHorizontalFlip has its own default p=0.5, so inside
# RandomApply(p=.5) the effective flip probability is 0.25 -- confirm intent.
# A T.RandomCrop((32, 32), padding=4) step was tried here and is disabled.
augmentation = T.RandomApply([T.RandomHorizontalFlip()], p=.5)

# Both pipelines convert to a tensor, resize to 32x32 and normalise; only the
# training pipeline inserts the augmentation before normalisation.
# An optional T.Grayscale(num_output_channels=1) step after the resize is
# disabled in both pipelines.
training_transform = T.Compose([
    T.ToTensor(),
    T.Resize((32, 32)),
    augmentation,
    normalize,
])

valid_transform = T.Compose([
    T.ToTensor(),
    T.Resize((32, 32)),
    normalize,
])


In [None]:
class TransformedDataset(Dataset):
    """Wrap an indexable dataset and apply optional transforms on item access.

    Every item of the wrapped dataset must unpack into an ``(img, label)``
    pair. ``transform`` is applied to the image and ``target_transform`` to
    the label; either one is skipped when it is falsy (e.g. ``None``).
    """

    def __init__(self, dataset, transform=None, target_transform=None):
        """Remember the wrapped dataset and the two optional callables."""
        self.dataset = dataset
        self.transform, self.target_transform = transform, target_transform

    def __len__(self):
        """Return the number of items in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, index):
        """Return the ``(img, label)`` pair at ``index`` after transformation."""
        img, label = self.dataset[index]
        # The label is transformed before the image, as in the original order.
        if self.target_transform:
            label = self.target_transform(label)
        if self.transform:
            img = self.transform(img)
        return img, label


In [None]:
class MNISTSVHNDataset(DataInterface):
    """Data interface exposing a shard's 'train' and 'val' splits as loaders.

    All keyword arguments given at construction are kept in ``self.kwargs``;
    the loaders read the batch sizes ``train_bs`` and ``valid_bs`` from it.
    """

    def __init__(self, **kwargs):
        """Keep the keyword arguments for later use by the loader factories."""
        self.kwargs = kwargs

    @property
    def shard_descriptor(self):
        """Return the shard descriptor most recently assigned."""
        return self._shard_descriptor

    @shard_descriptor.setter
    def shard_descriptor(self, shard_descriptor):
        """
        Describe per-collaborator procedures or sharding.

        Called during collaborator initialization; the local shard descriptor
        is set by the Envoy. Assigning it also (re)builds the transformed
        training and validation sets from the 'train' and 'val' splits.
        """
        self._shard_descriptor = shard_descriptor

        train_split = self._shard_descriptor.get_dataset('train')
        self.train_set = TransformedDataset(train_split, transform=training_transform)

        valid_split = self._shard_descriptor.get_dataset('val')
        self.valid_set = TransformedDataset(valid_split, transform=valid_transform)

    def get_train_loader(self, **kwargs):
        """
        Build the loader that is provided to tasks with an optimizer in their contract.

        Shuffling uses a fresh generator seeded with 0 on every call.
        """
        rng = torch.Generator()
        rng.manual_seed(0)
        batch_size = self.kwargs['train_bs']
        return DataLoader(self.train_set, batch_size=batch_size, shuffle=True, generator=rng)

    def get_valid_loader(self, **kwargs):
        """
        Build the loader that is provided to tasks without an optimizer in their contract.
        """
        batch_size = self.kwargs['valid_bs']
        return DataLoader(self.valid_set, batch_size=batch_size)

    def get_train_data_size(self):
        """
        Number of training samples (information for aggregation).
        """
        return len(self.train_set)

    def get_valid_data_size(self):
        """
        Number of validation samples (information for aggregation).
        """
        return len(self.valid_set)
    

In [None]:
# Data interface instance; both keyword arguments are stored in its kwargs and
# read back as the training / validation batch sizes.
fed_dataset = MNISTSVHNDataset(train_bs=64, valid_bs=64)

### Describe the model and optimizer

In [None]:
class VGG16(nn.Module):
    """VGG-style CNN with four convolutional stages and a three-layer classifier.

    Stages 1-2 hold two ``Conv2d -> BatchNorm2d -> ReLU`` units each, stages
    3-4 hold three; every stage ends with a 2x2 max-pool. The classifier
    expects 2048 flattened features, i.e. 512 channels at 2x2 resolution,
    which is what four halvings of a 32x32 input produce.

    Args:
        num_classes: Number of outputs of the last linear layer.
        in_channels: Number of channels of the input images. Defaults to 1,
            the value that used to be hard-coded in the first convolution.
    """

    def __init__(self, num_classes, in_channels=1):
        super().__init__()

        # "Same" padding for the 3x3 / stride-1 convolutions:
        # (w - k + 2*p)/s + 1 = o  =>  p = (s(o-1) - w + k)/2
        # e.g. (1(32-1) - 32 + 3)/2 = 1
        self.block_1 = self._make_block(in_channels, 64, num_convs=2)
        self.block_2 = self._make_block(64, 128, num_convs=2)
        self.block_3 = self._make_block(128, 256, num_convs=3)
        self.block_4 = self._make_block(256, 512, num_convs=3)

        self.classifier = nn.Sequential(
            nn.Linear(2048, 4096),
            nn.ReLU(True),
            nn.Dropout(p=0.65),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(p=0.65),
            nn.Linear(4096, num_classes),
        )

        # Re-initialise every conv / linear layer: Kaiming-uniform weights and
        # zero biases. Modules are visited in registration order, so the
        # random-number stream is consumed in the same order as before.
        for m in self.modules():
            if isinstance(m, (torch.nn.Conv2d, torch.nn.Linear)):
                nn.init.kaiming_uniform_(m.weight, mode='fan_in', nonlinearity='leaky_relu')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    @staticmethod
    def _make_block(in_channels, out_channels, num_convs):
        """Build one stage: ``num_convs`` conv/BN/ReLU units followed by a max-pool.

        Layers are appended in the fixed order conv, norm, ReLU so that the
        normalisation layers sit at Sequential indices 1, 4, 7, ...
        """
        layers = []
        channels = in_channels
        for _ in range(num_convs):
            layers.append(nn.Conv2d(in_channels=channels,
                                    out_channels=out_channels,
                                    kernel_size=(3, 3),
                                    stride=(1, 1),
                                    padding=1))
            layers.append(nn.BatchNorm2d(out_channels))
            layers.append(nn.ReLU())
            channels = out_channels
        layers.append(nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Run the four stages, flatten per sample and apply the classifier."""
        x = self.block_1(x)
        x = self.block_2(x)
        x = self.block_3(x)
        x = self.block_4(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x
    
    
model = VGG16(10)

# Swap the normalisation layer at each listed Sequential index for a
# GroupNorm with 32 groups over that stage's channel count.
for _stage, _channels, _norm_slots in (
    (model.block_1, 64, (1, 4)),
    (model.block_2, 128, (1, 4)),
    (model.block_3, 256, (1, 4, 7)),
    (model.block_4, 512, (1, 4, 7)),
):
    for _slot in _norm_slots:
        _stage[_slot] = nn.GroupNorm(32, _channels)


In [None]:
# Second name for the same model object (no copy is made); the cells below use it.
model_net = model

In [None]:
# Bare expression: the notebook shows the model's repr as this cell's output.
model_net

In [None]:
# Bare expression: materialise and display every parameter tensor of the model.
list(model_net.parameters())

In [None]:
# Only parameters that require gradients are handed to the optimizer.
params_to_update = [p for p in model_net.parameters() if p.requires_grad]

# FEDPROX alternative (disabled):
# from openfl.utilities.optimizers.torch import FedProxAdam
# optimizer = FedProxAdam(params_to_update, lr=1e-4, mu=0.01)

# ORIGINAL setup (active): plain Adam.
optimizer = optim.Adam(params_to_update, lr=1e-4)
# Other optimizers that were tried (disabled):
# optimizer = optim.AdamW(params_to_update, lr=0.001, weight_decay=0.02)
# optimizer = optim.SGD(params_to_update, lr=0.01)

# Learning-rate scheduler (disabled):
# scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

def cross_entropy(output, target):
    """Return the cross-entropy loss of ``output`` against ``target``.

    Uses PyTorch's default cross-entropy settings (mean reduction), the same
    as a freshly constructed ``nn.CrossEntropyLoss()``.
    """
    # A binary_cross_entropy_with_logits variant was tried earlier and dropped.
    return F.cross_entropy(input=output, target=target)

### Register model

In [None]:
# Dotted path of the framework adapter plugin passed to the model interface.
framework_adapter = 'openfl.plugins.frameworks_adapters.pytorch_adapter.FrameworkAdapterPlugin'
model_interface = ModelInterface(
    model=model_net,
    optimizer=optimizer,
    framework_plugin=framework_adapter,
)

# Save the initial model state as an independent deep copy.
initial_model = deepcopy(model_net)

## Define and register FL tasks

In [None]:
task_interface = TaskInterface()

# FEDCURV alternative (disabled):
# from openfl.utilities.fedcurv.torch import FedCurv
# from openfl.component.aggregation_functions import FedCurvWeightedAverage
# import tqdm
#
# fedcurv = FedCurv(model_interface.provide_model(), importance=1e3)

# FEDOPT alternative (disabled):
# from openfl.component.aggregation_functions import AdagradAdaptiveAggregation
# agg_fn = AdagradAdaptiveAggregation(model_interface=model_interface, learning_rate=0.4)
# @task_interface.set_aggregation_function(agg_fn)


# The Interactive API supports registering functions defined in the main module or imported.
def function_defined_in_notebook(some_parameter):
    """Print one line that echoes ``some_parameter``."""
    message = f'Also I accept a parameter and it is {some_parameter}'
    print(message)

# Task interface currently supports only standalone functions.
@task_interface.add_kwargs(**{'some_parameter': 42})
@task_interface.register_fl_task(model='net_model', data_loader='train_loader',
                                 device='device', optimizer='optimizer')
# FedCurv alternative (disabled):
# @task_interface.set_aggregation_function(FedCurvWeightedAverage())
def train(net_model, train_loader, optimizer, device, loss_fn=cross_entropy, some_parameter=None):
    """Train ``net_model`` for two epochs over ``train_loader``.

    Args:
        net_model: Model to train; switched to train mode and moved to the
            selected device.
        train_loader: Iterable yielding ``(data, target)`` batches.
        optimizer: Optimizer stepped once per batch.
        device: Device supplied by the caller. CUDA is still preferred when it
            is available (as before); this value is the fallback otherwise.
        loss_fn: Callable invoked as ``loss_fn(output=..., target=...)``.
        some_parameter: Extra value forwarded to ``function_defined_in_notebook``.

    Returns:
        dict: ``{'train_loss': mean of all per-batch losses over both epochs}``.
    """
    torch.manual_seed(0)
    # fedcurv.on_train_begin(net_model)

    # Previously this was an unconditional device='cuda', which fails on nodes
    # without CUDA. Keep the GPU preference, but fall back to the given device.
    if torch.cuda.is_available():
        device = 'cuda'
    function_defined_in_notebook(some_parameter)

    net_model.train()
    net_model.to(device)

    losses = []
    epochs = 2

    for epoch in range(epochs):
        # A tqdm wrapper closes itself after one full pass, so a single wrapper
        # reused across epochs only tracks the first one; build one per epoch.
        progress = tqdm.tqdm(train_loader, desc="train")
        for data, target in progress:
            # as_tensor avoids the extra copy (and warning) torch.tensor()
            # makes when the batch already is a tensor.
            data = torch.as_tensor(data).to(device)
            target = torch.as_tensor(target).to(device, dtype=torch.int64)
            optimizer.zero_grad()
            output = net_model(data)
            # output = output.logits  # needed for GoogLeNet-style outputs
            loss = loss_fn(output=output, target=target)  # + fedcurv.get_penalty(net_model)
            loss.backward()
            optimizer.step()
            losses.append(loss.detach().cpu().numpy())
    # fedcurv.on_train_end(net_model, train_loader, device)
    return {'train_loss': np.mean(losses),}


@task_interface.register_fl_task(model='net_model', data_loader='val_loader', device='device')
def validate(net_model, val_loader, device):
    """Compute the classification accuracy of ``net_model`` over ``val_loader``.

    Args:
        net_model: Model to evaluate; switched to eval mode and moved to the
            selected device.
        val_loader: Iterable yielding ``(data, target)`` batches.
        device: Device supplied by the caller. CUDA is still preferred when it
            is available (as before); this value is the fallback otherwise.

    Returns:
        dict: ``{'acc': correct predictions / total samples}``.
    """
    torch.manual_seed(0)
    # Previously this was an unconditional torch.device('cuda'), which fails on
    # nodes without CUDA. Keep the GPU preference, but fall back to the given device.
    if torch.cuda.is_available():
        device = torch.device('cuda')
    net_model.eval()
    net_model.to(device)

    val_loader = tqdm.tqdm(val_loader, desc="validate")
    val_score = 0
    total_samples = 0

    with torch.no_grad():
        for data, target in val_loader:
            samples = target.shape[0]
            total_samples += samples
            # as_tensor avoids the extra copy (and warning) torch.tensor()
            # makes when the batch already is a tensor.
            data = torch.as_tensor(data).to(device)
            target = torch.as_tensor(target).to(device, dtype=torch.int64)
            output = net_model(data)
            # History of the prediction step:
            #   original:     pred = output.argmax(dim=1, keepdim=True)
            #   wine example: _, preds = torch.max(outputs, dim=1)
            #                 return torch.tensor(torch.sum(preds == labels).item() / len(preds))
            #   current attempt: class indices from torch.max along dim 1.
            _, pred = torch.max(output, dim=1)
            val_score += pred.eq(target).sum().cpu().numpy()

    return {'acc': val_score / total_samples,}

## Time to start a federated learning experiment

In [None]:
# Create an experiment in the federation.
experiment_name = 'mnist_svhn_VGG16GN_FederatedStreamflow_50rounds_2epoch'
fl_experiment = FLExperiment(
    federation=federation,
    experiment_name=experiment_name,
)

In [None]:
# The following command zips the workspace and Python requirements to be transferred to collaborator nodes.
# It starts the experiment with the registered model, tasks and data interface for 50 rounds.
fl_experiment.start(
    model_provider=model_interface, 
    task_keeper=task_interface,
    data_loader=fed_dataset,
    rounds_to_train=50,
    opt_treatment='CONTINUE_GLOBAL'
)

In [None]:
# If the user wants to stop the IPython session and later reconnect to check how
# the experiment is going, restore the experiment state first:
# fl_experiment.restore_experiment_state(model_interface)

# Stream the experiment's metrics, with TensorBoard logging enabled.
fl_experiment.stream_metrics(tensorboard_logs=True)