In [1]:
from collections import OrderedDict
import warnings
from typing import Dict, List, Optional, Tuple

from flwr.common import ndarrays_to_parameters, NDArrays, Scalar, Context
from flwr.client import Client, ClientApp, NumPyClient
from flwr.server import ServerApp, ServerConfig, ServerAppComponents
from flwr.server.strategy import FedAvg, FedAdagrad
from flwr.simulation import run_simulation

import numpy as np
import torch
from torch.utils.data import DataLoader, random_split
from torchvision.transforms import transforms

# Importar o MedMNIST
from medmnist import OrganMNIST3D

In [2]:
from medical_fl import train, test, evaluate
from medical_fl.data import load_data_iid, load_data_niid
from medical_fl.utils import get_parameters, set_parameters
from medical_fl.model import GenericCNN as Net
from medical_fl.client import FlowerClient


In [3]:
# Original intent: hide a common Matplotlib warning seen with MedMNIST.
# Note that this filter applies to every UserWarning, not just that one.
warnings.filterwarnings(action="ignore", category=UserWarning)

# Use the first CUDA device when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda:0")
else:
    DEVICE = torch.device("cpu")

# Notebook-wide constants.
NUM_PARTITIONS, BATCH_SIZE = 5, 32

In [4]:
# Torch's ToTensor() does not handle volumetric images, so a minimal replacement is used.
class ToTensor:
    """Callable transform that turns a NumPy array into a float32 tensor.

    No rescaling or axis reordering is applied; only the container and dtype change.
    """

    def __call__(self, x):
        """Return ``x`` (anything ``torch.from_numpy`` accepts) as a ``torch.float32`` tensor."""
        as_tensor = torch.from_numpy(x)
        return as_tensor.to(torch.float32)

In [5]:
def print_samples_per_client(NUM_PARTITIONS = 3, iid = False, alpha = 0.5):
    """Print split sizes and the training-label distribution of every partition.

    Args:
        NUM_PARTITIONS: number of partitions to load and report on.
        iid: when True the partitions come from ``load_data_iid``; otherwise
            from ``load_data_niid``.
        alpha: forwarded as ``alpha`` to ``load_data_niid`` (unused when
            ``iid`` is True). Defaults to 0.5, the value that used to be
            hard-coded here.
    """
    for partition_id in range(NUM_PARTITIONS):
        if iid:
            trainloader, valloader, testloader = load_data_iid(
                partition_id, NUM_PARTITIONS, transforms=ToTensor()
            )
        else:
            trainloader, valloader, testloader = load_data_niid(
                partition_id, NUM_PARTITIONS, transforms=ToTensor(), alpha=alpha
            )

        print(f"Train samples: {len(trainloader.dataset)}")
        print(f"Val samples: {len(valloader.dataset)}")
        print(f"Test samples: {len(testloader.dataset)}")
        # dataset[:][1] is read as the labels of the whole training split
        # (assumes the dataset supports slice indexing - TODO confirm in medical_fl.data).
        print(f"Distribuição: {np.unique(trainloader.dataset[:][1], return_counts=True)}")
        print("____________________________________________________________")
print_samples_per_client(NUM_PARTITIONS=NUM_PARTITIONS, iid = False)


Train samples: 280
Val samples: 47
Test samples: 610
Distribuição: (array([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10]), array([69, 26,  4, 61,  1,  5,  4, 17, 59, 34]))
____________________________________________________________
Train samples: 248
Val samples: 42
Test samples: 610
Distribuição: (array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), array([74,  5, 19, 72,  1, 45, 14,  8,  9,  1]))
____________________________________________________________
Train samples: 144
Val samples: 20
Test samples: 610
Distribuição: (array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10]), array([ 1, 11, 30, 14,  9, 20,  2,  2,  6,  4, 45]))
____________________________________________________________
Train samples: 128
Val samples: 22
Test samples: 610
Distribuição: (array([ 0,  1,  2,  4,  5,  6,  7,  8,  9, 10]), array([ 9,  7,  5, 15,  7,  2, 21,  4, 50,  8]))
____________________________________________________________
Train samples: 171
Val samples: 30
Test samples: 610
Distribuição: (array([ 0,  1,  2,  3,  4,

In [6]:
# NOTE(review): everything below is disabled code kept inside a triple-quoted string, so this
# cell only evaluates that string (and the notebook echoes it as the cell output).
# The quoted code uses names that are not defined or imported in the cells visible here
# (plt, partitions, tmp, ALPHA, Subset); they would have to exist before it could be re-enabled.
'''# --- 4. Verificação Visual (Opcional, mas muito útil) ---

def plot_client_distributions(partitions, dataset, num_clients_to_show=5):
    labels = np.array(dataset.targets)
    num_classes = len(np.unique(labels))
    
    fig, ax = plt.subplots(figsize=(12, 6))
    
    for i in range(min(num_clients_to_show, len(partitions))):
        client_indices = partitions[i]
        client_labels = labels[client_indices]
        # Conta a ocorrência de cada classe para este cliente
        label_counts = np.bincount(client_labels, minlength=num_classes)
        ax.bar(np.arange(num_classes) + i * 0.1, label_counts, width=0.1, label=f'Cliente {i}')
        
    ax.set_xticks(np.arange(num_classes))
    ax.set_xticklabels(tmp.info['labels'].items(), rotation=45, ha="right")
    ax.set_ylabel("Número de Amostras")
    ax.set_title(f"Distribuição de Rótulos por Cliente (alpha={ALPHA})")
    ax.legend()
    plt.tight_layout()
    plt.show()

plot_client_distributions(partitions, tmp)

# Se você mudar para um alpha grande, verá barras muito mais uniformes
# partitions_iid = partition_dataset_by_dirichlet(trainset, NUM_CLIENTS, alpha=100)
# plot_client_distributions(partitions_iid, trainset, alpha=100)

# --- 5. Como Usar as Partições ---

# Para criar um DataLoader para um cliente específico (ex: cliente 0)
client_0_indices = partitions[0]
client_0_dataset = Subset(tmp, client_0_indices)
client_0_loader = torch.utils.data.DataLoader(client_0_dataset, batch_size=32, shuffle=True)

# Agora você pode usar `client_0_loader` para treinar o modelo do cliente 0'''

'# --- 4. Verificação Visual (Opcional, mas muito útil) ---\n\ndef plot_client_distributions(partitions, dataset, num_clients_to_show=5):\n    labels = np.array(dataset.targets)\n    num_classes = len(np.unique(labels))\n\n    fig, ax = plt.subplots(figsize=(12, 6))\n\n    for i in range(min(num_clients_to_show, len(partitions))):\n        client_indices = partitions[i]\n        client_labels = labels[client_indices]\n        # Conta a ocorrência de cada classe para este cliente\n        label_counts = np.bincount(client_labels, minlength=num_classes)\n        ax.bar(np.arange(num_classes) + i * 0.1, label_counts, width=0.1, label=f\'Cliente {i}\')\n\n    ax.set_xticks(np.arange(num_classes))\n    ax.set_xticklabels(tmp.info[\'labels\'].items(), rotation=45, ha="right")\n    ax.set_ylabel("Número de Amostras")\n    ax.set_title(f"Distribuição de Rótulos por Cliente (alpha={ALPHA})")\n    ax.legend()\n    plt.tight_layout()\n    plt.show()\n\nplot_client_distributions(partitions, tmp)\n

In [23]:
# 5. Client factory (client_fn)
def client_fn(context: Context) -> Client:
    """Build the Flower client for the partition named in ``context.node_config``.

    Reads ``'partition-id'`` and ``'num-partitions'`` from the node config,
    loads that partition with ``load_data_iid`` and wraps a fresh ``Net`` in a
    ``FlowerClient``.
    """
    # The model is created before the data is loaded (same ordering as before).
    model = Net().to(DEVICE)

    node_config = context.node_config
    pid = node_config['partition-id']
    total_partitions = node_config['num-partitions']

    # Only the train and validation loaders are handed to the client;
    # the third loader returned by load_data_iid is discarded.
    train_loader, val_loader, _ = load_data_iid(
        partition_id=pid,
        num_partitions=total_partitions,
        transforms=ToTensor(),
    )

    flower_client = FlowerClient(pid, model, train_loader, val_loader)
    return flower_client.to_client()

client = ClientApp(client_fn=client_fn)

In [24]:
def fit_config(server_round: int) -> Dict[str, int]:
    """Return the training configuration dict for a given round.

    Every round uses the same number of local epochs (3); only
    ``server_round`` changes between calls. There is no epoch schedule.

    Args:
        server_round: round number supplied by the strategy (this function is
            registered as ``on_fit_config_fn``).

    Returns:
        A new dict with the keys ``"server_round"`` and ``"local_epochs"``.
    """
    return {
        "server_round": server_round,
        "local_epochs": 3,
    }

In [9]:
# Load partition 0 (of 5) via load_data_niid for ad-hoc inspection.
# NOTE(review): this rebinds the notebook-level name `test`, shadowing the `test`
# function imported from medical_fl in the imports cell.
test = load_data_niid(partition_id=0, num_partitions=5, transforms=ToTensor())

In [22]:
# Print the batch sizes produced by the validation loader of every partition.
for partition_id in range(NUM_PARTITIONS):
    # Use a loop-local name instead of assigning to `test`: rebinding `test` here would
    # leave it pointing at the last partition, making later `test[...]` look-ups depend
    # on the order in which cells were executed.
    partition_loaders = load_data_niid(
        partition_id=partition_id, num_partitions=NUM_PARTITIONS, transforms=ToTensor()
    )
    print('_______________________')
    print(partition_id)
    # Index 1 is the validation loader (loaders come back as train, val, test).
    for img, label in partition_loaders[1]:
        print(len(label))

_______________________
0
19
_______________________
1
32
2
_______________________
2
32
1
_______________________
3
28
_______________________
4
32
15


In [17]:
# Scratch check: adds up the per-batch sizes printed in the previous cell's output.
19+32+32+2+1+28+32+15

161

In [11]:
# Peek at the label (element 1) of sample 0 in the dataset behind the second loader of
# `test`, converted to an int64 tensor. `test` here is the tuple of loaders assigned in an
# earlier cell, not the `test` function imported from medical_fl.
torch.tensor(test[1].dataset[0][1]).long()

tensor([10])

In [25]:
# Initial global weights: the parameters of a freshly constructed Net.
params = get_parameters(Net())

def server_fn(context: Context) -> ServerAppComponents:
    """Assemble the FedAvg strategy and the server config for the ServerApp.

    ``context`` is accepted as part of the ``server_fn`` signature but is not
    read here. The strategy starts from the module-level ``params``, evaluates
    with ``evaluate`` and sends ``fit_config`` output to clients each round.
    """
    strategy_kwargs = dict(
        fraction_fit=0.3,
        fraction_evaluate=0.3,
        min_fit_clients=3,
        min_evaluate_clients=3,
        min_available_clients=NUM_PARTITIONS,
        initial_parameters=ndarrays_to_parameters(params),
        evaluate_fn=evaluate,
        on_fit_config_fn=fit_config,
    )
    strategy = FedAvg(**strategy_kwargs)

    server_config = ServerConfig(num_rounds=10)
    return ServerAppComponents(strategy=strategy, config=server_config)

server = ServerApp(server_fn=server_fn)

In [None]:
# 6. Start the simulation
NUM_CLIENTS = NUM_PARTITIONS

# Request one CPU and one GPU per client when running on CUDA;
# otherwise leave the client resources unset (None).
backend_config = {
    "client_resources": {"num_cpus": 1, "num_gpus": 1} if DEVICE.type == "cuda" else None
}

# Launch the simulation with one supernode per partition.
run_simulation(
    client_app=client,
    server_app=server,
    backend_config=backend_config,
    num_supernodes=NUM_CLIENTS,
)

[92mINFO [0m:      Starting Flower ServerApp, config: num_rounds=10, no round_timeout
[92mINFO [0m:      
[92mINFO [0m:      [INIT]
[92mINFO [0m:      Using initial global parameters provided by strategy
[92mINFO [0m:      Starting evaluation of initial global parameters
[92mINFO [0m:      initial parameters (loss, other metrics): 0.07878883236744365, {'accuracy': 0.11311475409836065}
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 1]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 5)


32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
2
Server-side evaluation loss 0.07878883236744365 / accuracy 0.11311475409836065
[36m(ClientAppActor pid=245517)[0m [Client 1, round 1] fit, config: {'server_round': 1, 'local_epochs': 3}
[36m(ClientAppActor pid=245517)[0m [Client 2, round 1] fit, config: {'server_round': 1, 'local_epochs': 3}
[36m(ClientAppActor pid=245517)[0m [Client 4, round 1] fit, config: {'server_round': 1, 'local_epochs': 3}


[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      fit progress: (1, 0.07741040988046614, {'accuracy': 0.21311475409836064}, 7.511805022018962)
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 5)


32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
2
Server-side evaluation loss 0.07741040988046614 / accuracy 0.21311475409836064


[91mERROR [0m:     An exception was raised when processing a message by RayBackend
[91mERROR [0m:     [36mray::ClientAppActor.run()[39m (pid=245517, ip=143.106.45.41, actor_id=8a08bd26e97a5adcfabd6c9d01000000, repr=<flwr.simulation.ray_transport.ray_actor.ClientAppActor object at 0x7f62a379a360>)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/arthur/miniconda3/envs/flwr/lib/python3.12/site-packages/flwr/client/client_app.py", line 144, in __call__
    return self._call(message, context)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/arthur/miniconda3/envs/flwr/lib/python3.12/site-packages/flwr/client/client_app.py", line 128, in ffn
    out_message = handle_legacy_message_from_msgtype(
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/arthur/miniconda3/envs/flwr/lib/python3.12/site-packages/flwr/client/message_handler/message_handler.py", line 135, in handle_legacy_message_from_msgtype
    evaluate_res = maybe_call_evaluate(
       

[36m(ClientAppActor pid=245517)[0m [Client 0] evaluate, config: {}
[36m(ClientAppActor pid=245517)[0m 32
[36m(ClientAppActor pid=245517)[0m 1
[36m(ClientAppActor pid=245517)[0m [Client 1] evaluate, config: {}
[36m(ClientAppActor pid=245517)[0m 32


[92mINFO [0m:      aggregate_evaluate: received 2 results and 1 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 2]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 5)


[36m(ClientAppActor pid=245517)[0m [Client 4] evaluate, config: {}
[36m(ClientAppActor pid=245517)[0m 32
[36m(ClientAppActor pid=245517)[0m [Client 1, round 2] fit, config: {'server_round': 2, 'local_epochs': 3}
[36m(ClientAppActor pid=245517)[0m [Client 2, round 2] fit, config: {'server_round': 2, 'local_epochs': 3}
[36m(ClientAppActor pid=245517)[0m [Client 4, round 2] fit, config: {'server_round': 2, 'local_epochs': 3}


[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      fit progress: (2, 0.06334711625927784, {'accuracy': 0.30327868852459017}, 11.02219041902572)
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 5)


32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
2
Server-side evaluation loss 0.06334711625927784 / accuracy 0.30327868852459017


[91mERROR [0m:     An exception was raised when processing a message by RayBackend
[91mERROR [0m:     [36mray::ClientAppActor.run()[39m (pid=245517, ip=143.106.45.41, actor_id=8a08bd26e97a5adcfabd6c9d01000000, repr=<flwr.simulation.ray_transport.ray_actor.ClientAppActor object at 0x7f62a379a360>)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/arthur/miniconda3/envs/flwr/lib/python3.12/site-packages/flwr/client/client_app.py", line 144, in __call__
    return self._call(message, context)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/arthur/miniconda3/envs/flwr/lib/python3.12/site-packages/flwr/client/client_app.py", line 128, in ffn
    out_message = handle_legacy_message_from_msgtype(
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/arthur/miniconda3/envs/flwr/lib/python3.12/site-packages/flwr/client/message_handler/message_handler.py", line 135, in handle_legacy_message_from_msgtype
    evaluate_res = maybe_call_evaluate(
       

[36m(ClientAppActor pid=245517)[0m [Client 0] evaluate, config: {}
[36m(ClientAppActor pid=245517)[0m 32
[36m(ClientAppActor pid=245517)[0m 1
[36m(ClientAppActor pid=245517)[0m [Client 2] evaluate, config: {}
[36m(ClientAppActor pid=245517)[0m 32


[92mINFO [0m:      aggregate_evaluate: received 2 results and 1 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 3]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 5)


[36m(ClientAppActor pid=245517)[0m [Client 4] evaluate, config: {}
[36m(ClientAppActor pid=245517)[0m 32
[36m(ClientAppActor pid=245517)[0m [Client 0, round 3] fit, config: {'server_round': 3, 'local_epochs': 3}
[36m(ClientAppActor pid=245517)[0m [Client 2, round 3] fit, config: {'server_round': 3, 'local_epochs': 3}
[36m(ClientAppActor pid=245517)[0m [Client 4, round 3] fit, config: {'server_round': 3, 'local_epochs': 3}


[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      fit progress: (3, 0.04456765417192803, {'accuracy': 0.6262295081967213}, 14.632271370966919)
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 5)


32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
2
Server-side evaluation loss 0.04456765417192803 / accuracy 0.6262295081967213
[36m(ClientAppActor pid=245517)[0m [Client 1] evaluate, config: {}
[36m(ClientAppActor pid=245517)[0m 32
[36m(ClientAppActor pid=245517)[0m [Client 2] evaluate, config: {}
[36m(ClientAppActor pid=245517)[0m 32


[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 4]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 5)


[36m(ClientAppActor pid=245517)[0m [Client 4] evaluate, config: {}
[36m(ClientAppActor pid=245517)[0m 32
[36m(ClientAppActor pid=245517)[0m [Client 0, round 4] fit, config: {'server_round': 4, 'local_epochs': 3}
[36m(ClientAppActor pid=245517)[0m [Client 2, round 4] fit, config: {'server_round': 4, 'local_epochs': 3}
[36m(ClientAppActor pid=245517)[0m [Client 4, round 4] fit, config: {'server_round': 4, 'local_epochs': 3}


[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      fit progress: (4, 0.03660112257863654, {'accuracy': 0.660655737704918}, 18.238490827032365)
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 5)


32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
2
Server-side evaluation loss 0.03660112257863654 / accuracy 0.660655737704918
[36m(ClientAppActor pid=245517)[0m [Client 2] evaluate, config: {}
[36m(ClientAppActor pid=245517)[0m 32
[36m(ClientAppActor pid=245517)[0m [Client 3] evaluate, config: {}
[36m(ClientAppActor pid=245517)[0m 32


[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 5]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 5)


[36m(ClientAppActor pid=245517)[0m [Client 4] evaluate, config: {}
[36m(ClientAppActor pid=245517)[0m 32
[36m(ClientAppActor pid=245517)[0m [Client 0, round 5] fit, config: {'server_round': 5, 'local_epochs': 3}
[36m(ClientAppActor pid=245517)[0m [Client 2, round 5] fit, config: {'server_round': 5, 'local_epochs': 3}
[36m(ClientAppActor pid=245517)[0m [Client 3, round 5] fit, config: {'server_round': 5, 'local_epochs': 3}


[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      fit progress: (5, 0.03296549916267395, {'accuracy': 0.7081967213114754}, 21.945261709974147)
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 5)


32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
2
Server-side evaluation loss 0.03296549916267395 / accuracy 0.7081967213114754


[91mERROR [0m:     An exception was raised when processing a message by RayBackend
[91mERROR [0m:     [36mray::ClientAppActor.run()[39m (pid=245517, ip=143.106.45.41, actor_id=8a08bd26e97a5adcfabd6c9d01000000, repr=<flwr.simulation.ray_transport.ray_actor.ClientAppActor object at 0x7f62a379a360>)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/arthur/miniconda3/envs/flwr/lib/python3.12/site-packages/flwr/client/client_app.py", line 144, in __call__
    return self._call(message, context)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/arthur/miniconda3/envs/flwr/lib/python3.12/site-packages/flwr/client/client_app.py", line 128, in ffn
    out_message = handle_legacy_message_from_msgtype(
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/arthur/miniconda3/envs/flwr/lib/python3.12/site-packages/flwr/client/message_handler/message_handler.py", line 135, in handle_legacy_message_from_msgtype
    evaluate_res = maybe_call_evaluate(
       

[36m(ClientAppActor pid=245517)[0m [Client 0] evaluate, config: {}
[36m(ClientAppActor pid=245517)[0m 32
[36m(ClientAppActor pid=245517)[0m 1
[36m(ClientAppActor pid=245517)[0m [Client 1] evaluate, config: {}
[36m(ClientAppActor pid=245517)[0m 32


[92mINFO [0m:      aggregate_evaluate: received 2 results and 1 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 6]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 5)


[36m(ClientAppActor pid=245517)[0m [Client 2] evaluate, config: {}
[36m(ClientAppActor pid=245517)[0m 32
[36m(ClientAppActor pid=245517)[0m [Client 1, round 6] fit, config: {'server_round': 6, 'local_epochs': 3}
[36m(ClientAppActor pid=245517)[0m [Client 3, round 6] fit, config: {'server_round': 6, 'local_epochs': 3}
[36m(ClientAppActor pid=245517)[0m [Client 4, round 6] fit, config: {'server_round': 6, 'local_epochs': 3}


[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      fit progress: (6, 0.027166647109829013, {'accuracy': 0.7622950819672131}, 25.554735091980547)
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 5)


32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
2
Server-side evaluation loss 0.027166647109829013 / accuracy 0.7622950819672131


[91mERROR [0m:     An exception was raised when processing a message by RayBackend
[91mERROR [0m:     [36mray::ClientAppActor.run()[39m (pid=245517, ip=143.106.45.41, actor_id=8a08bd26e97a5adcfabd6c9d01000000, repr=<flwr.simulation.ray_transport.ray_actor.ClientAppActor object at 0x7f62a379a360>)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/arthur/miniconda3/envs/flwr/lib/python3.12/site-packages/flwr/client/client_app.py", line 144, in __call__
    return self._call(message, context)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/arthur/miniconda3/envs/flwr/lib/python3.12/site-packages/flwr/client/client_app.py", line 128, in ffn
    out_message = handle_legacy_message_from_msgtype(
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/arthur/miniconda3/envs/flwr/lib/python3.12/site-packages/flwr/client/message_handler/message_handler.py", line 135, in handle_legacy_message_from_msgtype
    evaluate_res = maybe_call_evaluate(
       

[36m(ClientAppActor pid=245517)[0m [Client 0] evaluate, config: {}
[36m(ClientAppActor pid=245517)[0m 32
[36m(ClientAppActor pid=245517)[0m 1
[36m(ClientAppActor pid=245517)[0m [Client 3] evaluate, config: {}
[36m(ClientAppActor pid=245517)[0m 32


[92mINFO [0m:      aggregate_evaluate: received 2 results and 1 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 7]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 5)


[36m(ClientAppActor pid=245517)[0m [Client 4] evaluate, config: {}
[36m(ClientAppActor pid=245517)[0m 32
[36m(ClientAppActor pid=245517)[0m [Client 0, round 7] fit, config: {'server_round': 7, 'local_epochs': 3}
[36m(ClientAppActor pid=245517)[0m [Client 3, round 7] fit, config: {'server_round': 7, 'local_epochs': 3}
[36m(ClientAppActor pid=245517)[0m [Client 4, round 7] fit, config: {'server_round': 7, 'local_epochs': 3}


[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      fit progress: (7, 0.02388942256325581, {'accuracy': 0.7918032786885246}, 29.2625474530505)
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 5)


32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
32
2
Server-side evaluation loss 0.02388942256325581 / accuracy 0.7918032786885246
