In [1]:
from collections import OrderedDict
import warnings

import flwr as fl
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from torchvision.transforms import transforms

# Importar o MedMNIST
from medmnist import OrganMNIST3D

  from .autonotebook import tqdm as notebook_tqdm
2025-06-11 18:14:22,578	INFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.


In [2]:
class ToTensor:
    def __call__(self, x):
        return torch.from_numpy(x).float()

In [5]:
# Desativar um aviso comum do Matplotlib no MedMNIST
warnings.filterwarnings("ignore", category=UserWarning)
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# 1. Definição do Modelo (uma CNN simples)
class Net(nn.Module):
    def __init__(self, in_channels=1, num_classes=11):
        super(Net, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv3d(in_channels, 16, kernel_size=3),
            nn.BatchNorm3d(16),
            nn.ReLU())
        self.layer2 = nn.Sequential(
            nn.Conv3d(16, 16, kernel_size=3),
            nn.BatchNorm3d(16),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.layer3 = nn.Sequential(
            nn.Conv3d(16, 64, kernel_size=3),
            nn.BatchNorm3d(64),
            nn.ReLU())
        self.layer4 = nn.Sequential(
            nn.Conv3d(64, 64, kernel_size=3),
            nn.BatchNorm3d(64),
            nn.ReLU())
        self.layer5 = nn.Sequential(
            nn.Conv3d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm3d(64),
            nn.ReLU(),
            nn.MaxPool3d(kernel_size=2, stride=2))
        self.fc = nn.Sequential(
            nn.Linear(64 * 4 * 4, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes))

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.layer5(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

# 2. Funções de Treino e Teste (padrão)
def train(net, trainloader, epochs):
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
    for _ in range(epochs):
        for images, labels in trainloader:
            optimizer.zero_grad()
            criterion(net(images.to(DEVICE)), labels.squeeze().long().to(DEVICE)).backward()
            optimizer.step()

def test(net, testloader):
    criterion = torch.nn.CrossEntropyLoss()
    correct, total, loss = 0, 0, 0.0
    with torch.no_grad():
        for images, labels in testloader:
            outputs = net(images.to(DEVICE))
            labels = labels.squeeze().long().to(DEVICE)
            loss += criterion(outputs, labels).item()
            total += labels.size(0)
            correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
    return loss / len(testloader.dataset), correct / total

# 3. Preparação e Particionamento dos Dados (A PARTE MAIS IMPORTANTE)
def load_data(num_partitions):
    # Transformações para o dataset
    data_transform = transforms.Compose([ToTensor()])
    
    # Carregar o dataset de treino completo do MedMNIST
    full_train_dataset = OrganMNIST3D(split="train", transform=data_transform, download=True)

    # Calcular o tamanho de cada partição
    num_images = len(full_train_dataset)
    partition_size = num_images // num_partitions
    lengths = [partition_size] * num_partitions
    remainder = num_images % num_partitions
    for i in range(remainder):
        lengths[i] += 1
    
    # Usar a função do PyTorch para dividir o dataset em partições não sobrepostas
    # Esta é a alternativa ao Partitioner do Flower
    partitions = random_split(full_train_dataset, lengths)

    # Criar um DataLoader para cada partição
    train_loaders = [DataLoader(part, batch_size=32, shuffle=True) for part in partitions]
    
    # Carregar o dataset de teste (geralmente é centralizado e não particionado)
    test_loader = DataLoader(OrganMNIST3D(split="test", download=True), batch_size=32)
    
    return train_loaders, test_loader

# 4. Definição do Cliente Flower
class FlowerClient(fl.client.NumPyClient):
    def __init__(self, model, trainloader, testloader):
        self.model = model
        self.trainloader = trainloader
        self.testloader = testloader

    def get_parameters(self, config):
        return [val.cpu().numpy() for _, val in self.model.state_dict().items()]

    def set_parameters(self, parameters):
        params_dict = zip(self.model.state_dict().keys(), parameters)
        state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
        self.model.load_state_dict(state_dict, strict=True)

    def fit(self, parameters, config):
        self.set_parameters(parameters)
        train(self.model, self.trainloader, epochs=1)
        return self.get_parameters(config={}), len(self.trainloader.dataset), {}

    def evaluate(self, parameters, config):
        self.set_parameters(parameters)
        loss, accuracy = test(self.model, self.testloader)
        return float(loss), len(self.testloader.dataset), {"accuracy": float(accuracy)}

# 5. Função para criar clientes (client_fn)
def client_fn(cid: str) -> FlowerClient:
    """Cria um Flower client para um dado client ID."""
    model = Net().to(DEVICE)
    # Cada cliente recebe seu próprio DataLoader de treino
    train_loader = train_loaders[int(cid)]
    # O testloader pode ser compartilhado
    test_loader = test_loader_global
    
    return FlowerClient(model, train_loader, test_loader)

# 6. Início da Simulação
if __name__ == "__main__":
    NUM_CLIENTS = 10
    
    # Carregar e particionar os dados ANTES da simulação
    train_loaders, test_loader_global = load_data(num_partitions=NUM_CLIENTS)

    # Definir a estratégia de agregação (FedAvg)
    strategy = fl.server.strategy.FedAvg(
        fraction_fit=1.0,  # Usar 100% dos clientes para treino em cada rodada
        min_fit_clients=NUM_CLIENTS,
        min_available_clients=NUM_CLIENTS,
    )

    # Iniciar a simulação
    fl.simulation.start_simulation(
        client_fn=client_fn,
        num_clients=NUM_CLIENTS,
        config=fl.server.ServerConfig(num_rounds=3),
        strategy=strategy,
    )

	Instead, use the `flwr run` CLI command to start a local simulation in your Flower app, as shown for example below:

		$ flwr new  # Create a new Flower app from a template

		$ flwr run  # Run the Flower app in Simulation Mode

	Using `start_simulation()` is deprecated.

            This is a deprecated feature. It will be removed
            entirely in future versions of Flower.
        
[92mINFO [0m:      Starting Flower simulation, config: num_rounds=3, no round_timeout
2025-06-11 18:31:10,205	INFO worker.py:1771 -- Started a local Ray instance.
[92mINFO [0m:      Flower VCE: Ray initialized with resources: {'accelerator_type:G': 1.0, 'node:__internal_head__': 1.0, 'node:143.106.45.41': 1.0, 'CPU': 12.0, 'object_store_memory': 15811032268.0, 'memory': 31622064539.0, 'GPU': 1.0}
[92mINFO [0m:      Optimize your simulation with Flower VCE: https://flower.ai/docs/framework/how-to-run-simulations.html
[92mINFO [0m:      No `client_resources` specified. Using minimal resources

RuntimeError: Simulation crashed.