In [4]:
from collections import OrderedDict
from typing import List, Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from datasets.utils.logging import disable_progress_bar
from torch.utils.data import DataLoader

import flwr
from flwr.client import Client, ClientApp, NumPyClient
from flwr.common import Metrics, Context
from flwr.server import ServerApp, ServerConfig, ServerAppComponents
from flwr.server.strategy import FedAvg
from flwr.simulation import run_simulation
from flwr_datasets import FederatedDataset

# Use the GPU when PyTorch reports one is available; otherwise fall back to CPU.
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"

# Log the runtime environment so results can be tied to library versions.
print(f"Training on {DEVICE}")
print(f"Flower {flwr.__version__} / PyTorch {torch.__version__}")

# Turn off the Hugging Face `datasets` progress bars to keep cell output short.
disable_progress_bar()

2025-08-07 15:50:56,686	INFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.


Training on cuda
Flower 1.20.0 / PyTorch 2.8.0+cu128


In [None]:
# NUM_CLIENTS: number of partitions the CIFAR-10 train split is divided into
# (presumably one per federated client).
# BATCH_SIZE: mini-batch size for every DataLoader built by load_datasets.
NUM_CLIENTS, BATCH_SIZE = 10, 32

def load_datasets(partition_id: int):
    """Load the CIFAR-10 data loaders for one federated partition.

    The CIFAR-10 ``train`` split is divided into ``NUM_CLIENTS`` partitions by
    ``flwr_datasets.FederatedDataset``. The partition selected by
    ``partition_id`` is split 80/20 (fixed seed 42) into a local training set
    and a local validation set. The full CIFAR-10 ``test`` split is also loaded,
    so every partition is evaluated on the same held-out data.

    Args:
        partition_id: Index of the partition to load; presumably expected to be
            in ``range(NUM_CLIENTS)``. Out-of-range handling is left to
            ``FederatedDataset.load_partition``.

    Returns:
        A tuple ``(trainloader, valloader, testloader)`` of PyTorch
        ``DataLoader`` objects, all using ``BATCH_SIZE``. Only ``trainloader``
        shuffles.
    """
    fds = FederatedDataset(
        dataset="cifar10",
        # An int here asks flwr_datasets for that many IID partitions of the
        # "train" split.
        partitioners={"train": NUM_CLIENTS},
    )
    partition = fds.load_partition(partition_id)

    # Divide this partition into 80% train and 20% validation. The fixed seed
    # keeps the split identical every time the function is called.
    partition_train_test = partition.train_test_split(train_size=0.8, seed=42)

    pytorch_transform = transforms.Compose(
        [
            transforms.ToTensor(),
            # Shift each RGB channel from [0, 1] to [-1, 1].
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )

    def apply_transform(batch):
        """Convert the ``"img"`` column of a batch into normalized tensors.

        Passed to ``with_transform`` so conversion happens lazily on access,
        instead of giving ``transform=`` to torchvision's ``CIFAR10`` dataset.
        The ``pytorch_transform`` pipeline is the same either way.
        """
        batch["img"] = [pytorch_transform(img) for img in batch["img"]]
        return batch

    # Attach the transform to both the local train and validation subsets.
    partition_train_test = partition_train_test.with_transform(apply_transform)
    trainloader = DataLoader(
        partition_train_test["train"],
        batch_size=BATCH_SIZE,
        shuffle=True,
    )
    # No shuffling for evaluation: the data is processed in the same order
    # every time, which keeps evaluation runs consistent.
    valloader = DataLoader(
        partition_train_test["test"],
        batch_size=BATCH_SIZE,
    )

    # Centralized test split shared by all partitions. FederatedDataset
    # exposes whole splits via `load_split`; it has no `load_dataset` method,
    # so the previous call raised AttributeError.
    testset = fds.load_split("test").with_transform(apply_transform)
    testloader = DataLoader(
        testset,
        batch_size=BATCH_SIZE,
    )

    return trainloader, valloader, testloader