Carregando os pacotes necessários:

In [1]:
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10

from typing import Tuple

import flwr as fl

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

Carregando o dataset CIFAR-10 (conjuntos de treinamento e teste):

In [2]:
def load_data():
    transform = transforms.Compose([transforms.ToTensor(), 
                                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    trainset = CIFAR10(".", train = True, download = True, transform = transform)
    testset = CIFAR10(".", train = False, download = True, transform = transform)
    trainloader = DataLoader(trainset, batch_size = 32, shuffle = True)
    testloader = DataLoader(testset, batch_size = 32)
    return trainloader, testloader

### Redes Neurais

Modelo 1:

In [3]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

Modelo 2:

In [4]:
class CNN(nn.Module):
    """Simple CNN adapted"""

    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout2d(0.25)
        self.dropout2 = nn.Dropout2d(0.5)
        self.fc1 = nn.Linear(12544, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return x

Definindo a função cliente:

In [5]:
from torchvision import models

class CifarClient(fl.client.NumPyClient):
    def __init__(self, cid, train_loader, test_loader, epochs, device: torch.device = torch.device(DEVICE)):
        self.model = CNN().to(device)   
        
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.device = device
        self.epochs = epochs
        
    def get_parameters(self):
        return [val.cpu().numpy() for _, val in self.model.state_dict().items()]

    def set_parameters(self, parameters):
        params_dict = zip(self.model.state_dict().keys(), parameters)
        state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
        self.model.load_state_dict(state_dict, strict=True)

    def set_weights(self, weights):
        state_dict = OrderedDict({k: torch.Tensor(v) for k, v in zip(self.model.state_dict().keys(), weights)})
        self.model.load_state_dict(state_dict, strict = True)
        
    def get_weights(self) -> fl.common.Weights:
        return [val.cpu().numpy() for _, val in self.model.state_dict().items()]

    def fit(self, parameters, config):
        self.set_parameters(parameters)
        train(self.model, self.train_loader, epochs = self.epochs)
        return self.get_parameters(), len(self.train_loader), {}

    def evaluate(self, parameters, config):
        self.set_parameters(parameters)
        loss, accuracy = test(self.model, self.test_loader)
        return float(loss), len(self.test_loader), {"accuracy":float(accuracy)}

Definindo a função de avaliação do servidor:

In [6]:
def test(model, test_loader, device: torch.device = torch.device(DEVICE)):
    model.eval()
    
    test_loss: float = 0
    correct: int = 0
    num_test_samples: int = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            num_test_samples += len(data)
            output = model(data)
            test_loss += torch.nn.CrossEntropyLoss()(output, target).item()  
            pred = output.argmax(dim = 1, keepdim = True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= num_test_samples

    return (test_loss, {"accuracy": correct / num_test_samples})

def eval(w):
    trainloader, testloader = load_data()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    
    server = CifarClient(
            cid = 999,
            train_loader = trainloader,
            test_loader = testloader,
            epochs = 1,
            device = DEVICE,
    )
    server.set_weights(w)
    return test(server.model, trainloader, device)

Executando o servidor:

In [7]:
if __name__ == "__main__":
    strategy = fl.server.strategy.FedAvg(eval_fn = eval)
    fl.server.start_server("[::]:8081", config = {"num_rounds": 2}, strategy = strategy)

INFO flower 2021-08-19 00:32:55,979 | app.py:76 | Flower server running (insecure, 2 rounds)
INFO flower 2021-08-19 00:32:55,982 | server.py:118 | Getting initial parameters
INFO flower 2021-08-19 00:33:27,557 | server.py:306 | Received initial parameters from one random client
INFO flower 2021-08-19 00:33:27,558 | server.py:120 | Evaluating initial parameters


Files already downloaded and verified
Files already downloaded and verified


  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
INFO flower 2021-08-19 00:33:39,620 | server.py:127 | initial parameters (loss, other metrics): 0.0720656501197815, {'accuracy': 0.09948}
INFO flower 2021-08-19 00:33:39,621 | server.py:133 | FL starting
DEBUG flower 2021-08-19 00:33:39,621 | server.py:255 | fit_round: strategy sampled 2 clients (out of 2)
DEBUG flower 2021-08-19 00:34:14,715 | server.py:264 | fit_round received 0 results and 2 failures


Files already downloaded and verified
Files already downloaded and verified


INFO flower 2021-08-19 00:34:23,862 | server.py:154 | fit progress: (1, 0.07206574760913849, {'accuracy': 0.09948}, 44.24102469906211)
INFO flower 2021-08-19 00:34:23,863 | server.py:199 | evaluate_round: no clients selected, cancel
DEBUG flower 2021-08-19 00:35:38,367 | server.py:255 | fit_round: strategy sampled 2 clients (out of 2)
DEBUG flower 2021-08-19 00:41:25,813 | server.py:264 | fit_round received 2 results and 0 failures


Files already downloaded and verified
Files already downloaded and verified


INFO flower 2021-08-19 00:41:35,033 | server.py:154 | fit progress: (2, 0.01976533472776413, {'accuracy': 0.8364}, 475.4119321871549)
INFO flower 2021-08-19 00:41:35,034 | server.py:199 | evaluate_round: no clients selected, cancel
INFO flower 2021-08-19 00:41:35,035 | server.py:172 | FL finished in 475.41344188153744
INFO flower 2021-08-19 00:41:35,036 | app.py:109 | app_fit: losses_distributed []
INFO flower 2021-08-19 00:41:35,037 | app.py:110 | app_fit: metrics_distributed {}
INFO flower 2021-08-19 00:41:35,038 | app.py:111 | app_fit: losses_centralized [(0, 0.0720656501197815), (1, 0.07206574760913849), (2, 0.01976533472776413)]
INFO flower 2021-08-19 00:41:35,038 | app.py:112 | app_fit: metrics_centralized {'accuracy': [(0, 0.09948), (1, 0.09948), (2, 0.8364)]}
DEBUG flower 2021-08-19 00:41:35,038 | server.py:205 | evaluate_round: strategy sampled 2 clients (out of 2)
DEBUG flower 2021-08-19 00:41:37,693 | server.py:214 | evaluate_round received 2 results and 0 failures
INFO flow