In [12]:
import random
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset
from torchvision.datasets import MNIST

import flwr as fl

### Carregamento dos dados

In [14]:
def load_data(root='./data',
              n_victim_train=1000,
              n_attack=1000,
              n_victim_test=15,
              batch_size=64,
              seed=None):
    """Download MNIST and build the victim-train, victim-test and attacker loaders.

    A single sample of ``n_victim_train + n_attack`` distinct training indices is
    drawn without replacement. The first ``n_victim_train`` go to the victim
    client and the remainder to the attacker, so the two sets never overlap.
    ``n_victim_test`` indices are sampled from the MNIST test split.

    Args:
        root: Directory where MNIST is stored / downloaded.
        n_victim_train: Number of training images given to the victim client.
        n_attack: Number of training images reserved for the attacker.
        n_victim_test: Number of test images used for evaluation.
        batch_size: Batch size used by all three loaders.
        seed: Optional seed for the index sampling. ``None`` (default) keeps the
            original behaviour of drawing from the global ``random`` state.

    Returns:
        ``(victim_train_dataloader, victim_test_dataloader, attack_dataloader)``.
        The train and attack loaders shuffle; the test loader does not.

    Raises:
        ValueError: from ``random.sample`` if more indices are requested than
            the corresponding split contains.
    """
    # presumably the usual MNIST mean/std normalisation constants
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])

    trainset = MNIST(root=root, train=True, transform=transform, download=True)
    testset = MNIST(root=root, train=False, transform=transform, download=True)

    # A private RNG when seeded; otherwise the module itself (same global state as before).
    rng = random.Random(seed) if seed is not None else random

    # One draw without replacement guarantees victim and attacker indices are disjoint.
    victim_idx = rng.sample(range(len(trainset)), k=n_victim_train + n_attack)
    victim_train_idx = victim_idx[:n_victim_train]
    attack_idx = victim_idx[n_victim_train:]
    victim_test_idx = rng.sample(range(len(testset)), k=n_victim_test)

    victim_train_dataset = Subset(trainset, victim_train_idx)
    attack_dataset = Subset(trainset, attack_idx)
    victim_test_dataset = Subset(testset, victim_test_idx)

    victim_train_dataloader = DataLoader(victim_train_dataset,
                                         batch_size=batch_size,
                                         shuffle=True)
    attack_dataloader = DataLoader(attack_dataset,
                                   batch_size=batch_size,
                                   shuffle=True)
    victim_test_dataloader = DataLoader(victim_test_dataset,
                                        batch_size=batch_size,
                                        shuffle=False)

    return victim_train_dataloader, victim_test_dataloader, attack_dataloader

In [15]:
# Build the three loaders once; every later cell reads these globals.
(victim_train_dataloader,
 victim_test_dataloader,
 attack_dataloader) = load_data()

### Funções de treinamento e teste do modelo

In [18]:
def train(net, victim_train_dataloader, epochs, criterion=None, lr=1e-3):
    """Train ``net`` in place on the batches of ``victim_train_dataloader``.

    Args:
        net: Model to optimise. It is put in training mode so BatchNorm layers
            use batch statistics and update their running estimates.
        victim_train_dataloader: Iterable yielding ``(images, labels)`` batches.
        epochs: Number of full passes over the loader (0 leaves ``net`` unchanged).
        criterion: Loss called as ``criterion(net(images), labels)``. Defaults to
            ``NoPeekLoss(alpha=0.8)``.
        lr: Learning rate for the Adam optimiser.
    """
    if criterion is None:
        # NOTE(review): NoPeekLoss is not defined or imported anywhere in the visible
        # notebook; confirm its source and that it accepts (outputs, labels).
        criterion = NoPeekLoss(alpha=0.8)
    optimizer = optim.Adam(net.parameters(), lr=lr)

    # Send batches to wherever the model lives, instead of an undefined global DEVICE.
    device = next(net.parameters()).device
    net.train()

    for _ in range(epochs):
        for images, labels in victim_train_dataloader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            loss = criterion(net(images), labels)
            loss.backward()
            optimizer.step()

In [19]:
def test(net, victim_test_dataloader):
    """Validate the network on the entire test set."""
    criterion = NoPeekLoss(alpha = 0.8) #Analisar como vai ficar o criterio
    correct, total, loss = 0, 0, 0.0
    with torch.no_grad():
        for data in victim_test_dataloader:
            images, labels = data[0].to(DEVICE), data[1].to(DEVICE)
            outputs = net(images)
            loss += criterion(outputs, labels).item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = correct / total
    return loss, accuracy

### Modelo

In [27]:
class NoPeekFL(nn.Module):
    """Four-block CNN classifier for single-channel (MNIST) images.

    Each block is a 3x3 conv (padding 1), then BatchNorm (blocks 3 and 4
    only), then ReLU, then a 2x2 max-pool. For a 1x28x28 input the spatial
    size shrinks 28 -> 14 -> 7 -> 3 -> 1, so the final feature map is
    [N, 512, 1, 1]. It is flattened to [N, 512] for the linear classifier.

    Args:
        num_classes: Number of output logits (10 for MNIST digits).
    """

    def __init__(self, num_classes: int = 10) -> None:
        super().__init__()
        # Registration order is kept so state_dict key order (used by the
        # Flower client to (de)serialise parameters) is unchanged.
        self.conv1 = nn.Conv2d(in_channels=1,
                               out_channels=64,
                               kernel_size=3,
                               padding=1,
                               stride=1)
        self.conv2 = nn.Conv2d(in_channels=64,
                               out_channels=128,
                               kernel_size=3,
                               padding=1)
        self.conv3 = nn.Conv2d(in_channels=128,
                               out_channels=256,
                               kernel_size=3,
                               padding=1)
        self.conv4 = nn.Conv2d(in_channels=256,
                               out_channels=512,
                               kernel_size=3,
                               padding=1)

        self.bn3 = nn.BatchNorm2d(256)
        self.bn4 = nn.BatchNorm2d(512)

        self.L1 = nn.Linear(512, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a batch [N, 1, H, W] to logits [N, num_classes].

        Shapes in the comments are per-sample for a 28x28 input. Inputs whose
        final feature map is not 1x1 raise a shape error in ``L1``.
        """
        # 1ch -> 64ch, 28x28 -> 14x14
        x = self.conv1(x)                 # [64, 28, 28]
        x = F.relu(x)
        x = F.max_pool2d(x, 2, 2)         # [64, 14, 14]

        # 64ch -> 128ch, 14x14 -> 7x7
        x = self.conv2(x)                 # [128, 14, 14]
        x = F.relu(x)
        x = F.max_pool2d(x, 2, 2)         # [128, 7, 7]

        # 128ch -> 256ch, 7x7 -> 3x3 (pooling floors the odd size)
        x = self.conv3(x)                 # [256, 7, 7]
        x = self.bn3(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2, 2)         # [256, 3, 3]

        # 256ch -> 512ch, 3x3 -> 1x1
        x = self.conv4(x)                 # [512, 3, 3]
        x = self.bn4(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2, 2)         # [512, 1, 1]

        # Flatten per sample. view(-1, 512) would silently fold extra spatial
        # positions into the batch dimension (e.g. 32x32 input -> 4x the rows);
        # flatten(x, 1) keeps N rows, so a wrong input size fails loudly.
        x = torch.flatten(x, 1)
        return self.L1(x)

In [28]:
# Single global model instance; the Flower client below reads and writes its state_dict.
net: nn.Module = NoPeekFL()

### Definição do cliente FL

In [30]:
class ClientFL(fl.client.NumPyClient):
    """Flower NumPy client that trains and evaluates a PyTorch model locally.

    With no arguments it uses the notebook globals ``net``,
    ``victim_train_dataloader`` and ``victim_test_dataloader``. Pass them
    explicitly to reuse the class with other models or data.
    """

    def __init__(self, model=None, trainloader=None, testloader=None):
        super().__init__()
        # Defaults are resolved only when an argument is omitted, so the
        # original zero-argument ``ClientFL()`` call keeps working.
        self.model = net if model is None else model
        self.trainloader = victim_train_dataloader if trainloader is None else trainloader
        self.testloader = victim_test_dataloader if testloader is None else testloader

    def get_parameters(self, config=None):
        """Return every state_dict tensor as a NumPy array, in state_dict order.

        ``config`` is accepted (and ignored) for newer Flower versions that pass it.
        """
        return [val.cpu().numpy() for val in self.model.state_dict().values()]

    def set_parameters(self, parameters):
        """Load ``parameters`` (ordered like ``get_parameters``) into the model."""
        keys = self.model.state_dict().keys()
        # torch.tensor keeps each array's dtype; torch.Tensor always produces the
        # default float dtype (wrong for integer buffers such as BatchNorm counters).
        state_dict = OrderedDict((k, torch.tensor(v)) for k, v in zip(keys, parameters))
        self.model.load_state_dict(state_dict, strict=True)

    def fit(self, parameters, config):
        """Train one local epoch starting from the server's ``parameters``."""
        self.set_parameters(parameters)
        train(self.model, self.trainloader, epochs=1)
        # Flower weights aggregation by the number of training examples, not batches.
        return self.get_parameters(), len(self.trainloader.dataset), {}

    def evaluate(self, parameters, config):
        """Evaluate the server's ``parameters`` on the local test loader."""
        self.set_parameters(parameters)
        loss, accuracy = test(self.model, self.testloader)
        return float(loss), len(self.testloader.dataset), {"accuracy": float(accuracy)}

In [31]:
# NOTE: a Flower server must already be listening at SERVER_ADDRESS. The
# StatusCode.UNAVAILABLE / "failed to connect to all addresses" output below
# is presumably what happens when it is not (or is on another address).
SERVER_ADDRESS = "[::]:8080"
fl.client.start_numpy_client(SERVER_ADDRESS, client=ClientFL())

DEBUG flower 2021-07-28 16:13:32,937 | connection.py:36 | ChannelConnectivity.IDLE
INFO flower 2021-07-28 16:13:32,940 | app.py:61 | Opened (insecure) gRPC connection
DEBUG flower 2021-07-28 16:13:32,941 | connection.py:36 | ChannelConnectivity.TRANSIENT_FAILURE
DEBUG flower 2021-07-28 16:13:33,146 | connection.py:68 | Insecure gRPC channel closed


_MultiThreadedRendezvous: <_MultiThreadedRendezvous of RPC that terminated with:
	status = StatusCode.UNAVAILABLE
	details = "failed to connect to all addresses"
	debug_error_string = "{"created":"@1627499612.939922600","description":"Failed to pick subchannel","file":"src/core/ext/filters/client_channel/client_channel.cc","file_line":3008,"referenced_errors":[{"created":"@1627499612.939919500","description":"failed to connect to all addresses","file":"src/core/ext/filters/client_channel/lb_policy/pick_first/pick_first.cc","file_line":397,"grpc_status":14}]}"
>