Carregando os pacotes necessários:

In [1]:
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
import numpy as np

import flwr as fl
import pandas as pd

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
acc_train = []
acc_test = []
loss_train = []
loss_test = []

Carregando o dataset CIFAR-10 (conjuntos de treinamento e teste):

In [2]:
def load_data():
    transform = transforms.Compose([transforms.ToTensor(), 
                                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    trainset = CIFAR10(".", train = True, download = True, transform = transform)
    testset = CIFAR10(".", train = False, download = True, transform = transform)
    trainloader = DataLoader(trainset, batch_size = 8, shuffle = True)
    testloader = DataLoader(testset, batch_size = 8)
    return trainloader, testloader

### Redes Neurais

Modelo 1:

In [3]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, 3,
                               padding = 1, 
                               stride = 1)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(64, 128, 3,
                               padding = 1)
        self.conv3 = nn.Conv2d(128, 256, 3,
                               padding = 1)
        self.conv4 = nn.Conv2d(256, 512, 3,
                               padding = 1)
        self.fc1 = nn.Linear(256, 10)
        self.bn1 = nn.BatchNorm2d(num_features = 256)
        #self.bn2 = nn.BatchNorm2d(num_features = 512)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.bn1(self.conv3(x))))
        #x = self.pool(F.relu(self.conv4(x)))
        x = x.view(-1, 256)
        x = self.fc1(x)
        return x

Modelo 2: 

In [4]:
class CNN(nn.Module):
    """Simple CNN adapted"""

    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout2d(0.25)
        self.dropout2 = nn.Dropout2d(0.5)
        self.fc1 = nn.Linear(12544, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return x

Função de treinamento da rede:

In [5]:
def train(net, trainloader, epochs):
    criterion = torch.nn.CrossEntropyLoss()
    #criterion = NoPeekLoss()
    optimizer = torch.optim.SGD(net.parameters(), lr = 0.001, momentum = 0.9)
    
    DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f"Training {epochs} epoch(s) w/ {len(trainloader)} mini-batches each")
    
    for epoch in range(epochs):
        print()
        loss_epoch: float = 0.0
        correct: float = 0.0
        num_examples_train: int = 0
        batch_idx: int = 0
        total: int = 0
            
        for images, labels in trainloader:
            
            if batch_idx < len(trainloader)-1:
                images, labels = images.to(DEVICE), labels.to(DEVICE)

                num_examples_train += len(images)
                batch_idx += 1

                optimizer.zero_grad()
                outputs = net(images)
                loss = criterion(outputs, labels)   
                #intermediate_parameters = []
                #for param in net.parameters():
                #    intermediate_parameters.append(param.view(-1))
                #intermediate_parameters = torch.cat(intermediate_parameters)                    
                #loss = criterion(images, intermediate_parameters, net(images), labels).item() 
                loss.backward()
                optimizer.step()
                loss_epoch += loss.item()
                _, predicted = torch.max(outputs.data, 1)
                correct += (predicted == labels).sum().item()
                total += labels.size(0)
                if batch_idx % 10 == 8:
                    print(
                        "Train Epoch: {} [{}/{} ({:.0f}%)] Loss: {:.6f}\t\t\t\t\t\t\t\t".format(
                            epoch,
                            num_examples_train,
                            len(trainloader) * trainloader.batch_size,
                            100.0
                            * num_examples_train
                            / len(trainloader)
                            / trainloader.batch_size,
                            loss.item(),
                        ),
                        end="\r",
                        flush=True,
                    )
            
        loss_train.append(loss.item())
        acc_train.append(correct / total    )

Definindo a função de teste do servidor:

In [6]:
def test(net, testloader):
    criterion = torch.nn.CrossEntropyLoss()
    #criterion = NoPeekLoss()
    correct, total, loss = 0, 0, 0.0
    
    DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    with torch.no_grad():
        for data in testloader:
            images, labels = data[0].to(DEVICE), data[1].to(DEVICE)
            outputs = net(images)
            loss += criterion(outputs, labels).item()    
            #intermediate_parameters = []
            #for param in net.parameters():
            #    intermediate_parameters.append(param.view(-1))
            #intermediate_parameters = torch.cat(intermediate_parameters)
            #loss += criterion(images, intermediate_parameters, outputs, labels).item()    
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = correct / total    
    acc_test.append(accuracy)
    loss_test.append(loss/len(testloader))
    return loss/len(testloader), accuracy

Definindo a função cliente:

In [7]:
from torchvision import models

class CifarClient(fl.client.NumPyClient):
    def __init__(self, cid, train_loader, test_loader, epochs, device: torch.device = torch.device(DEVICE)):
        self.model = CNN().to(device)   
        
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.device = device
        self.epochs = epochs

    def set_parameters(self, parameters):
        params_dict = zip(self.model.state_dict().keys(), parameters)
        state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
        self.model.load_state_dict(state_dict, strict=True)

    def set_weights(self, weights):
        state_dict = OrderedDict({k: torch.Tensor(v) for k, v in zip(self.model.state_dict().keys(), weights)})
        self.model.load_state_dict(state_dict, strict = True)
        
    def get_parameters(self) -> fl.common.Weights:
        return [val.cpu().numpy() for _, val in self.model.state_dict().items()]

    def fit(self, parameters, config):
        self.set_parameters(parameters)
        #intermediate_parameters  = self.get_weights()
        train(self.model, self.train_loader, epochs = self.epochs)
        return self.get_parameters(), len(self.train_loader), {}

    def evaluate(self, parameters, config):        
        self.set_parameters(parameters)
        loss, accuracy = test(self.model, self.test_loader)   
        return float(loss), len(self.test_loader), {"accuracy":float(accuracy)}

Executando o servidor:

In [8]:
current_client = 1
n_clients = 2
train_batch_size = 32
test_batch_size = 1000
epochs = 10

train_loader, test_loader = load_data()

df = pd.DataFrame(list())
df.to_csv('accuracy_1.csv')
df.to_csv('loss_1.csv')

client = CifarClient(
        cid = current_client,
        train_loader = train_loader,
        test_loader = test_loader,
        epochs = epochs,
        device = DEVICE,
)

fl.client.start_numpy_client("[::]:8081", client = client)

Files already downloaded and verified
Files already downloaded and verified


DEBUG flower 2021-08-19 00:35:35,280 | connection.py:36 | ChannelConnectivity.IDLE
INFO flower 2021-08-19 00:35:35,281 | app.py:61 | Opened (insecure) gRPC connection
DEBUG flower 2021-08-19 00:35:35,281 | connection.py:36 | ChannelConnectivity.READY


Training 10 epoch(s) w/ 6250 mini-batches each


  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)




DEBUG flower 2021-08-19 00:41:37,703 | connection.py:68 | Insecure gRPC channel closed
INFO flower 2021-08-19 00:41:37,704 | app.py:72 | Disconnect and shut down


In [9]:
print("Accuracy Test", acc_test)
print("Loss Train", loss_train)
print("Loss Test", loss_test)

Accuracy Test [0.6213]
Loss Train [1.650738000869751, 1.6103835105895996, 0.7185038924217224, 0.8863129615783691, 0.7790144085884094, 0.6811245083808899, 1.0082215070724487, 0.9230215549468994, 1.0828007459640503, 1.1398513317108154]
Loss Test [1.0870145292401314]
