In [1]:
from collections import OrderedDict
import warnings
from typing import Dict, List, Optional, Tuple

import flwr as fl
from flwr.common import ndarrays_to_parameters, NDArrays, Scalar, Context
from flwr.client import Client, ClientApp, NumPyClient
from flwr.server import ServerApp, ServerConfig, ServerAppComponents
from flwr.server.strategy import FedAvg, FedAdagrad
from flwr.simulation import run_simulation

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from torchvision.transforms import transforms

# Importar o MedMNIST
from medmnist import OrganMNIST3D

In [2]:
# Silence every UserWarning-category warning (motivated by a common
# Matplotlib warning that shows up when using MedMNIST).
warnings.filterwarnings(action="ignore", category=UserWarning)
# Run on the first CUDA GPU when one is available, otherwise on the CPU.
DEVICE = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [3]:
# Number of simulated clients / data partitions.
NUM_PARTITIONS: int = 3
# Mini-batch size (samples per batch).
BATCH_SIZE: int = 32

In [4]:
# torchvision's ToTensor() does not handle volumetric images, so a minimal
# replacement is defined here.
class ToTensor:
    """Callable transform converting a NumPy array into a float32 tensor."""

    def __call__(self, x):
        """Return ``x`` (a NumPy array) as a ``torch.float32`` tensor of the same shape."""
        as_tensor = torch.from_numpy(x)
        return as_tensor.to(torch.float32)

In [5]:
# 1. Model definition (a simple CNN)
class Net(nn.Module):
    """Small 3D CNN: five Conv3d/BatchNorm3d/ReLU stages followed by an MLP head.

    Args:
        in_channels: Number of channels of the input volume.
        num_classes: Number of output logits.

    The classifier head expects ``64 * 4**3`` flattened features, i.e. the
    convolutional stack must end with a 4x4x4 feature map (which is what a
    28x28x28 input yields: 28 -> 26 -> 24 -> 12 -> 10 -> 8 -> 8 -> 4).
    """

    def __init__(self, in_channels=1, num_classes=11):
        super().__init__()
        # Stages are created in a fixed order so parameter names and
        # initialisation order stay stable.
        self.layer1 = self._conv_block(in_channels, 16)
        self.layer2 = self._conv_block(16, 16, pool=True)
        self.layer3 = self._conv_block(16, 64)
        self.layer4 = self._conv_block(64, 64)
        self.layer5 = self._conv_block(64, 64, padding=1, pool=True)

        flat_features = 64 * 4**3
        self.fc = nn.Sequential(
            nn.Linear(flat_features, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes),
        )

    @staticmethod
    def _conv_block(in_ch, out_ch, padding=0, pool=False):
        """Build Conv3d(k=3) -> BatchNorm3d -> ReLU, optionally followed by a 2x max-pool."""
        stages = [
            nn.Conv3d(in_ch, out_ch, kernel_size=3, padding=padding),
            nn.BatchNorm3d(out_ch),
            nn.ReLU(),
        ]
        if pool:
            stages.append(nn.MaxPool3d(kernel_size=2, stride=2))
        return nn.Sequential(*stages)

    def forward(self, x):
        """Map a batch of volumes ``(N, C, D, H, W)`` to class logits ``(N, num_classes)``."""
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4, self.layer5):
            x = stage(x)
        # Flatten everything except the batch dimension before the MLP head.
        flat = x.view(x.shape[0], -1)
        return self.fc(flat)

In [6]:
# 2. Training and evaluation functions (standard)
def train(net, trainloader, epochs):
    """Train ``net`` in place on ``trainloader`` for ``epochs`` passes.

    Uses cross-entropy loss and Adam with a learning rate of 1e-4. Labels are
    expected with shape ``(batch, 1)``; the trailing dimension is squeezed away
    before the loss is computed. Batches are moved to the module-level
    ``DEVICE``. Returns ``None``.
    """
    loss_fn = torch.nn.CrossEntropyLoss()
    adam = torch.optim.Adam(net.parameters(), lr=1e-4)
    net.train()
    for _epoch in range(epochs):
        for batch_images, batch_labels in trainloader:
            inputs = batch_images.to(DEVICE)
            targets = batch_labels.to(DEVICE)
            adam.zero_grad()
            logits = net(inputs)
            # CrossEntropyLoss wants integer class indices of shape (batch,).
            batch_loss = loss_fn(logits, targets.long().squeeze(dim=1))
            batch_loss.backward()
            adam.step()

In [7]:

def test(net, testloader):
    criterion = torch.nn.CrossEntropyLoss()
    correct, total, loss = 0, 0, 0.0
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.cuda().float(), labels.cuda().long()
            outputs = net(images)
            labels = labels.squeeze()
            loss += criterion(outputs, labels).item()
            total += labels.size(0)
            correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
    return loss / len(testloader.dataset), correct / total

In [8]:
# 3. Data preparation and partitioning (THE MOST IMPORTANT PART)
def load_data(partition_id: int, num_partitions: int, seed: int = 42):
    """Build the data loaders for one federated client.

    The OrganMNIST3D train and validation splits are each divided into
    ``num_partitions`` non-overlapping, near-equal parts (sizes differ by at
    most one sample) and the part with index ``partition_id`` is returned.
    The test split is not partitioned.

    Args:
        partition_id: Index of the partition to return.
        num_partitions: Total number of partitions (clients).
        seed: Seed of the generator that drives the random split. Every call
            with the same ``seed`` and ``num_partitions`` produces the same
            split, so partitions requested by different clients in separate
            calls are disjoint and together cover the whole split.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader)``.
    """
    data_transform = transforms.Compose([ToTensor()])

    full_train_dataset = OrganMNIST3D(split="train", transform=data_transform, download=True)
    full_val_dataset = OrganMNIST3D(split="val", transform=data_transform, download=True)

    def partition_lengths(dataset):
        """Sizes of the partitions; the first ``remainder`` ones get one extra sample."""
        base, remainder = divmod(len(dataset), num_partitions)
        return [base + 1 if i < remainder else base for i in range(num_partitions)]

    def select_partition(dataset):
        """Split ``dataset`` reproducibly and return this client's share."""
        # A freshly seeded generator per split: without it each call would
        # draw a different permutation from the global RNG and clients would
        # receive overlapping samples.
        splits = random_split(
            dataset,
            partition_lengths(dataset),
            generator=torch.Generator().manual_seed(seed),
        )
        return splits[partition_id]

    train_loader = DataLoader(
        select_partition(full_train_dataset), batch_size=BATCH_SIZE, shuffle=True
    )
    # Evaluation does not need shuffling.
    val_loader = DataLoader(
        select_partition(full_val_dataset), batch_size=BATCH_SIZE, shuffle=False
    )

    # The test split stays centralised; it goes through the same transform as
    # the train/val splits so all loaders yield the same kind of tensors.
    test_dataset = OrganMNIST3D(split="test", transform=data_transform, download=True)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE)

    return train_loader, val_loader, test_loader

In [9]:
# Sanity check: with a single partition the client holds every train/val sample.
trainloader, valloader, testloader = load_data(partition_id=0, num_partitions=1)
print(f"Train samples: {len(trainloader.dataset)}")
print(f"Val samples: {len(valloader.dataset)}")
print(f"Test samples: {len(testloader.dataset)}")

Train samples: 971
Val samples: 161
Test samples: 610


In [10]:
# Sanity check: with two partitions the first client holds about half of the
# train/val samples; the test split is not partitioned.
trainloader, valloader, testloader = load_data(partition_id=0, num_partitions=2)
print(f"Train samples: {len(trainloader.dataset)}")
print(f"Val samples: {len(valloader.dataset)}")
print(f"Test samples: {len(testloader.dataset)}")

Train samples: 486
Val samples: 81
Test samples: 610


In [11]:
def get_parameters(net) -> List[np.ndarray]:
    """Return every entry of ``net.state_dict()`` as a NumPy array, in state-dict order."""
    arrays = []
    for tensor in net.state_dict().values():
        arrays.append(tensor.cpu().numpy())
    return arrays

def set_parameters(net, parameters: List[np.ndarray]):
    """Load ``parameters`` into ``net``.

    Arrays are paired positionally with the keys of ``net.state_dict()`` and
    loaded with ``strict=True``.
    """
    new_state = OrderedDict()
    for key, array in zip(net.state_dict().keys(), parameters):
        new_state[key] = torch.tensor(array)
    net.load_state_dict(new_state, strict=True)

In [12]:
# 4. Flower client definition
class FlowerClient(NumPyClient):
    """Flower NumPyClient wrapping one model together with its local data loaders."""

    def __init__(self, net, trainloader, testloader):
        """Store the model, the training loader and the loader used for evaluation."""
        self.net, self.trainloader, self.testloader = net, trainloader, testloader

    def get_parameters(self, config):
        """Return the current local model weights as a list of NumPy arrays."""
        return get_parameters(self.net)

    def fit(self, parameters, config):
        """Load the received weights, train for one local epoch and return the result.

        Returns the updated weights, the number of local training samples and
        an empty metrics dict.
        """
        set_parameters(self.net, parameters)
        train(self.net, self.trainloader, epochs=1)
        updated_weights = get_parameters(self.net)
        num_examples = len(self.trainloader.dataset)
        return updated_weights, num_examples, {}

    def evaluate(self, parameters, config):
        """Load the received weights and evaluate them on the local evaluation loader.

        Returns the loss, the number of evaluation samples and a metrics dict
        holding the accuracy.
        """
        set_parameters(self.net, parameters)
        loss, accuracy = test(self.net, self.testloader)
        metrics = {"accuracy": float(accuracy)}
        return float(loss), len(self.testloader.dataset), metrics

In [13]:
# 5. Client factory (client_fn)
def client_fn(context: Context) -> Client:
    """Create the Flower client for the node described by ``context``.

    The node's ``partition-id`` and ``num-partitions`` config entries select
    which train/validation partition the client receives; the validation
    loader is what the client evaluates on.
    """
    model = Net().to(DEVICE)
    node_cfg = context.node_config
    # Each client gets its own training / validation loaders; the centralised
    # test loader returned by load_data is not used here.
    loaders = load_data(
        partition_id=node_cfg['partition-id'],
        num_partitions=node_cfg['num-partitions'],
    )
    train_loader, val_loader = loaders[0], loaders[1]
    return FlowerClient(model, train_loader, val_loader).to_client()


client = ClientApp(client_fn=client_fn)

In [17]:
# Weights of a freshly initialised network, handed to the strategy as the
# initial global parameters.
params = get_parameters(Net())


def server_fn(context: Context) -> ServerAppComponents:
    """Build the server components: a FedAvg strategy and a 10-round config.

    Args:
        context: Flower run context (not used here).

    Returns:
        ``ServerAppComponents`` holding the strategy and the server config.
    """
    # Never require more clients per round than exist in the simulation:
    # a fixed minimum of 3 could not be satisfied if NUM_PARTITIONS were
    # lowered below 3.
    min_round_clients = min(3, NUM_PARTITIONS)
    strategy = FedAvg(
        fraction_fit=0.3,
        fraction_evaluate=0.3,
        min_fit_clients=min_round_clients,
        min_evaluate_clients=min_round_clients,
        min_available_clients=NUM_PARTITIONS,
        initial_parameters=ndarrays_to_parameters(params),
    )
    config = ServerConfig(num_rounds=10)
    return ServerAppComponents(strategy=strategy, config=config)


server = ServerApp(server_fn=server_fn)

In [18]:
# 6. Simulation setup
NUM_CLIENTS = 3

# Reserve one CPU and one GPU per client when running on CUDA; otherwise
# leave the client resources unset.
if DEVICE.type == "cuda":
    backend_config = {"client_resources": {"num_cpus": 1, "num_gpus": 1}}
else:
    backend_config = {"client_resources": None}

In [19]:
# Start the simulation with one supernode per data partition.
run_simulation(
    client_app=client,
    server_app=server,
    backend_config=backend_config,
    num_supernodes=NUM_PARTITIONS,
)

[92mINFO [0m:      Starting Flower ServerApp, config: num_rounds=10, no round_timeout
[92mINFO [0m:      
[92mINFO [0m:      [INIT]
[92mINFO [0m:      Using initial global parameters provided by strategy
[92mINFO [0m:      Starting evaluation of initial global parameters
[92mINFO [0m:      Evaluation returned no results (`None`)
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 1]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)
[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 2]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)
[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)
[92mINFO [0m:      aggregate