<a href="https://colab.research.google.com/github/artsasse/fedkan/blob/main/Flower_MNIST_MLP.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Using a Federated MLP to classify MNIST

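This notebook is based mainly on the Flower tutorial "Use a federated learning strategy", available at https://flower.ai/docs/framework/tutorial-series-use-a-federated-learning-strategy-pytorch.html. Note that the model trained below is a KAN (from the efficient-kan package); the original MLP is kept only as commented-out code.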

## Dependencies

In [None]:
pip install -q flwr[simulation] flwr-datasets[vision] torch torchvision

In [None]:
from collections import OrderedDict
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

import flwr
from flwr.client import Client, ClientApp, NumPyClient
from flwr.server import ServerApp, ServerConfig, ServerAppComponents
from flwr.server.strategy import FedAvg, FedAdagrad
from flwr.simulation import run_simulation
from flwr_datasets import FederatedDataset
from flwr.common import ndarrays_to_parameters, NDArrays, Scalar, Context

# NOTE (SASSE): the notebook runtime has to be switched to a GPU runtime before "cuda" can be used.
# DEVICE = torch.device("cuda")  # Try "cuda" to train on GPU
DEVICE = torch.device("cpu")  # Run training on CPU
# Report the selected device and the library versions in the cell output.
print(f"Training on {DEVICE}")
print(f"Flower {flwr.__version__} / PyTorch {torch.__version__}")

Training on cpu
Flower 1.10.0 / PyTorch 2.3.1+cu121


## Data loading

In [None]:
# Number of partitions the MNIST "train" split is divided into; also used below
# as the number of simulated clients (supernodes).
NUM_PARTITIONS = 10
# TODO (SASSE): confirm the batch size
# Batch size used by every DataLoader built in load_datasets.
BATCH_SIZE = 32


def load_datasets(partition_id: int, num_partitions: int):
    """Build the data loaders for one client partition of MNIST.

    The "train" split of the "mnist" dataset is partitioned into
    ``num_partitions`` parts; partition ``partition_id`` is split 80% / 20%
    (fixed seed 42) into local training and validation data. The dataset's
    "test" split is loaded as a whole and is not partitioned.

    Every image is converted to a tensor, normalised with mean 0.5 and
    standard deviation 0.5, and flattened into a 1D tensor.

    Returns:
        A tuple ``(trainloader, valloader, testloader)`` of DataLoaders using
        ``BATCH_SIZE``; only the training loader shuffles.
    """
    image_pipeline = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
            # Flatten the image into a 1D tensor
            transforms.Lambda(torch.flatten),
        ]
    )

    def apply_transforms(batch):
        # Used through dataset.with_transform(...) instead of passing
        # `transform=` to a torchvision MNIST dataset; the transform pipeline
        # itself is the same.
        batch["image"] = list(map(image_pipeline, batch["image"]))
        return batch

    fds = FederatedDataset(dataset="mnist", partitioners={"train": num_partitions})

    # Divide the data of this node: 80% train, 20% validation.
    local_splits = (
        fds.load_partition(partition_id)
        .train_test_split(test_size=0.2, seed=42)
        .with_transform(apply_transforms)
    )
    global_test = fds.load_split("test").with_transform(apply_transforms)

    return (
        DataLoader(local_splits["train"], batch_size=BATCH_SIZE, shuffle=True),
        DataLoader(local_splits["test"], batch_size=BATCH_SIZE),
        DataLoader(global_test, batch_size=BATCH_SIZE),
    )

## Model training/evaluation (PyTorch)

In [None]:
# class Net(nn.Module):

#     def __init__(self) -> None:
#         super(Net, self).__init__()
#         self.layer1 = nn.Linear(28 * 28, 200)  # 28 x 28 pixels
#         self.layer2 = nn.Linear(200, 200)  # 2 hidden layers with 200 neurons each
#         self.layer3 = nn.Linear(200, 10)  # 10 classes
#         self.relu = nn.ReLU()
#         self.softmax = nn.Softmax(dim=1)

#     def forward(self, x: torch.Tensor) -> torch.Tensor:
#         x = self.relu(self.layer1(x))
#         x = self.relu(self.layer2(x))
#         x = self.softmax(self.layer3(x))
#         return x

# Import KAN in the cell that uses it so the class definition does not depend
# on a later notebook cell having been executed first (Restart & Run All).
# Requires the efficient-kan package to be installed before this cell runs.
from efficient_kan import KAN


class Net(KAN):
    """KAN-based classifier used as the federated model in this notebook."""

    def __init__(self) -> None:
        # Layer sizes handed to KAN: 28 * 28 inputs (flattened image),
        # 64 hidden units, 10 outputs (one per class).
        super().__init__([28 * 28, 64, 10])


def get_parameters(net) -> List[np.ndarray]:
    """Return every entry of ``net.state_dict()`` as a NumPy array.

    The arrays are listed in the iteration order of the state dict; each
    tensor is moved to the CPU before conversion.
    """
    arrays = []
    for tensor in net.state_dict().values():
        arrays.append(tensor.cpu().numpy())
    return arrays


def set_parameters(net, parameters: List[np.ndarray]):
    """Load a list of NumPy arrays into ``net`` in place.

    Args:
        net: Module whose state dict is overwritten.
        parameters: Arrays in the same order as ``net.state_dict().keys()``
            (the order produced by ``get_parameters``).

    Raises:
        RuntimeError: Raised by ``load_state_dict(strict=True)`` when keys are
            missing or a shape does not match.
    """
    params_dict = zip(net.state_dict().keys(), parameters)
    # torch.as_tensor keeps the dtype and the shape of each array. The legacy
    # torch.Tensor(...) constructor always produces float32 and does not
    # reliably preserve 0-d arrays (e.g. integer counter buffers).
    # load_state_dict copies the values into the existing parameters, so the
    # module does not end up aliasing the NumPy memory.
    state_dict = OrderedDict((k, torch.as_tensor(v)) for k, v in params_dict)
    net.load_state_dict(state_dict, strict=True)


def train(net, trainloader, epochs: int):
    """Train the network on the training set.

    Args:
        net: Model to optimise in place; it is switched to training mode.
        trainloader: DataLoader yielding dict batches with the keys "image"
            and "label".
        epochs: Number of full passes over ``trainloader``.

    Prints one line per epoch with the training loss and accuracy. The
    reported loss is the sum of the per-batch losses divided by the number of
    samples in ``trainloader.dataset``, the same normalisation ``test`` uses.
    """
    criterion = torch.nn.CrossEntropyLoss()
    # Open question (SASSE): do Adam and SGD affect the KAN differently?
    # The defaults for learning rate (lr) and momentum are being used.
    # One of the requirements to guarantee convergence is a decreasing lr.
    optimizer = torch.optim.Adam(net.parameters())
    net.train()
    for epoch in range(epochs):
        correct, total, epoch_loss = 0, 0, 0.0
        for batch in trainloader:
            images, labels = batch["image"], batch["label"]
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            # Single forward pass, reused for both the loss and the metrics.
            outputs = net(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # Metrics
            # .item() yields a plain float, so the autograd graph of each batch
            # is not kept alive and the loss prints as a number.
            epoch_loss += loss.item()
            total += labels.size(0)
            correct += (torch.max(outputs.detach(), 1)[1] == labels).sum().item()
        epoch_loss /= len(trainloader.dataset)
        epoch_acc = correct / total
        print(f"Epoch {epoch+1}: train loss {epoch_loss}, accuracy {epoch_acc}")


def test(net, testloader):
    """Evaluate the network on every batch of ``testloader``.

    The model is put in eval mode and gradients are disabled.

    Returns:
        A tuple ``(loss, accuracy)``: the summed per-batch cross-entropy
        divided by the number of samples in ``testloader.dataset``, and the
        fraction of correctly classified samples.
    """
    loss_fn = torch.nn.CrossEntropyLoss()
    running_loss = 0.0
    n_correct = 0
    n_seen = 0
    net.eval()
    with torch.no_grad():
        for batch in testloader:
            images = batch["image"].to(DEVICE)
            labels = batch["label"].to(DEVICE)
            logits = net(images)
            running_loss += loss_fn(logits, labels).item()
            # Predicted class = index of the largest output per sample.
            predicted = torch.max(logits.data, 1).indices
            n_seen += labels.size(0)
            n_correct += (predicted == labels).sum().item()
    mean_loss = running_loss / len(testloader.dataset)
    accuracy = n_correct / n_seen
    return mean_loss, accuracy

KAN Model:

In [None]:
pip install git+https://github.com/Blealtan/efficient-kan.git

Collecting git+https://github.com/Blealtan/efficient-kan.git
  Cloning https://github.com/Blealtan/efficient-kan.git to /tmp/pip-req-build-mkgni0oh
  Running command git clone --filter=blob:none --quiet https://github.com/Blealtan/efficient-kan.git /tmp/pip-req-build-mkgni0oh
  Resolved https://github.com/Blealtan/efficient-kan.git to commit 7b6ce1c87f18c8bc90c208f6b494042344216b11
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting pytest>=8.2.0 (from efficient-kan==0.1.0)
  Downloading pytest-8.3.2-py3-none-any.whl.metadata (7.5 kB)
Downloading pytest-8.3.2-py3-none-any.whl (341 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m341.8/341.8 kB[0m [31m6.7 MB/s[0m eta [36m0:00:00[0m
[?25hBuilding wheels for collected packages: efficient-kan
  Building wheel for efficient-kan (pyproject.toml) ... [?25l[?25hdone
  Created wheel for effic

In [None]:
from efficient_kan import KAN

In [None]:
# Bare expression: the notebook displays the imported KAN class as this cell's
# output. Presumably a leftover import sanity check; safe to delete.
KAN

## Flower Architecture

### Flower client

In [None]:
class FlowerClient(NumPyClient):
    """Flower client that trains and evaluates one model on one data partition."""

    def __init__(self, pid, net, trainloader, valloader):
        """Store the partition ID, the model and the local data loaders."""
        self.pid = pid  # partition ID of a client
        self.net = net
        self.trainloader = trainloader
        self.valloader = valloader

    def get_parameters(self, config):
        """Return the current local model weights as a list of NumPy arrays."""
        print(f"[Client {self.pid}] get_parameters")
        return get_parameters(self.net)

    def fit(self, parameters, config):
        """Train locally, starting from the given global parameters.

        Args:
            parameters: Global model weights sent by the server.
            config: Dict that must contain "server_round" and "local_epochs".

        Returns:
            ``(updated_weights, num_examples, metrics)``; the metrics dict is
            empty.
        """
        # Read values from config
        server_round = config["server_round"]
        local_epochs = config["local_epochs"]

        # Use values provided by the config
        print(f"[Client {self.pid}, round {server_round}] fit, config: {config}")
        set_parameters(self.net, parameters)
        train(self.net, self.trainloader, epochs=local_epochs)
        # Report the number of training examples (not the number of batches),
        # which the strategy uses to weight this client's update.
        return get_parameters(self.net), len(self.trainloader.dataset), {}

    def evaluate(self, parameters, config):
        """Evaluate the given global parameters on the local validation data.

        Returns:
            ``(loss, num_examples, {"accuracy": accuracy})``.
        """
        print(f"[Client {self.pid}] evaluate, config: {config}")
        set_parameters(self.net, parameters)
        loss, accuracy = test(self.net, self.valloader)
        # Number of validation examples (not batches) for correct weighting.
        return float(loss), len(self.valloader.dataset), {"accuracy": float(accuracy)}


def client_fn(context: Context) -> Client:
    """Create the Flower client for the node described by ``context``.

    The node config supplies "partition-id" and "num-partitions", which select
    the data partition; every call builds a fresh ``Net`` on ``DEVICE``.
    """
    model = Net().to(DEVICE)
    node_cfg = context.node_config
    pid = node_cfg["partition-id"]
    n_parts = node_cfg["num-partitions"]
    # The global test loader is not needed on the client side.
    train_dl, val_dl, _ = load_datasets(pid, n_parts)
    flower_client = FlowerClient(pid, model, train_dl, val_dl)
    return flower_client.to_client()


# Create the ClientApp
client = ClientApp(client_fn=client_fn)

### Server-side parameter **initialization**

In [None]:
# Create an instance of the model and extract its freshly initialised weights
# as NumPy arrays; server_fn hands them to the strategy as the initial global
# parameters.
params = get_parameters(Net())

### Server-side parameter **evaluation**

In [None]:
# The `evaluate` function will be called by Flower after every round
def evaluate(
    server_round: int,
    parameters: NDArrays,
    config: Dict[str, Scalar],
) -> Optional[Tuple[float, Dict[str, Scalar]]]:
    """Server-side evaluation of the global parameters on the MNIST test split.

    Builds a fresh ``Net``, loads ``parameters`` into it and evaluates it with
    ``test``. ``server_round`` and ``config`` are accepted but not used.

    Returns:
        ``(loss, {"accuracy": accuracy})``.
    """
    model = Net().to(DEVICE)
    # Only the third loader (the unpartitioned test split) is needed here.
    testloader = load_datasets(0, NUM_PARTITIONS)[2]
    # Update model with the latest parameters
    set_parameters(model, parameters)
    loss, accuracy = test(model, testloader)
    print(f"Server-side evaluation loss {loss} / accuracy {accuracy}")
    metrics = {"accuracy": accuracy}
    return loss, metrics

### Training **Configuration** (e.g. epochs)




In [None]:
def fit_config(server_round: int):
    """Return the training configuration dict for the given round.

    Round 1 trains for one local epoch; every later round (``server_round``
    of 2 or more) trains for two local epochs.

    Returns:
        Dict with the keys "server_round" (the current round of federated
        learning) and "local_epochs".
    """
    local_epochs = 2 if server_round >= 2 else 1
    return {"server_round": server_round, "local_epochs": local_epochs}

### Flower **Server**

In [None]:
def server_fn(context: Context) -> ServerAppComponents:
    """Build the server components: a FedAvg strategy and a 3-round config.

    The strategy starts from the module-level ``params``, evaluates centrally
    with ``evaluate`` and sends ``fit_config(round)`` to the clients it
    samples for training. ``context`` is accepted but not used.
    """
    strategy_kwargs = dict(
        fraction_fit=0.3,
        fraction_evaluate=0.3,
        min_fit_clients=3,
        min_evaluate_clients=3,
        min_available_clients=NUM_PARTITIONS,
        initial_parameters=ndarrays_to_parameters(params),
        evaluate_fn=evaluate,
        on_fit_config_fn=fit_config,  # Pass the fit_config function
    )
    # Create FedAvg strategy
    fedavg = FedAvg(**strategy_kwargs)
    return ServerAppComponents(strategy=fedavg, config=ServerConfig(num_rounds=3))


# Create the ServerApp
server = ServerApp(server_fn=server_fn)

## Simulation

### Run Simulation

In [None]:
# NOTE(review): duplicates the definition in the "Data loading" cell (same
# value, 10); keep the two in sync or remove one of them.
NUM_PARTITIONS = 10

In [None]:
# Resources allocated to each simulated client. `backend_config` was not
# defined anywhere in the notebook, so this cell failed on a fresh kernel.
# Each client gets one CPU and no GPU by default.
backend_config = {"client_resources": {"num_cpus": 1, "num_gpus": 0.0}}
# When training on GPU, assign an entire GPU to each client.
if DEVICE.type == "cuda":
    backend_config = {"client_resources": {"num_cpus": 1, "num_gpus": 1.0}}

# Run simulation
run_simulation(
    server_app=server,
    client_app=client,
    num_supernodes=NUM_PARTITIONS,
    backend_config=backend_config,
)

[92mINFO [0m:      Starting Flower ServerApp, config: num_rounds=3, no round_timeout
[92mINFO [0m:      
[92mINFO [0m:      [INIT]
[92mINFO [0m:      Using initial global parameters provided by strategy
[92mINFO [0m:      Evaluating initial global parameters
  self.pid = _posixsubprocess.fork_exec(
  self.pid = _posixsubprocess.fork_exec(
[36m(pid=157070)[0m 2024-08-22 23:39:10.352060: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
[36m(pid=157070)[0m 2024-08-22 23:39:10.411109: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
[36m(pid=157070)[0m 2024-08-22 23:39:10.427705: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for

Server-side evaluation loss 0.07205835869312287 / accuracy 0.1201


[36m(ClientAppActor pid=157070)[0m see the appropriate new directories, set the environment variable
[36m(ClientAppActor pid=157070)[0m `JUPYTER_PLATFORM_DIRS=1` and then run `jupyter --paths`.
[36m(ClientAppActor pid=157070)[0m The use of platformdirs will be the default in `jupyter_core` v6
[36m(ClientAppActor pid=157070)[0m   from jupyter_core.paths import jupyter_data_dir, jupyter_runtime_dir, secure_write


[36m(ClientAppActor pid=157070)[0m [Client 0, round 1] fit, config: {'server_round': 1, 'local_epochs': 3}
[36m(ClientAppActor pid=157070)[0m Epoch 1: train loss 0.028885068371891975, accuracy 0.7454166666666666
[36m(ClientAppActor pid=157070)[0m Epoch 2: train loss 0.013063565827906132, accuracy 0.8835416666666667
[36m(ClientAppActor pid=157070)[0m Epoch 3: train loss 0.010519752278923988, accuracy 0.9014583333333334
[36m(ClientAppActor pid=157070)[0m [Client 3, round 1] fit, config: {'server_round': 1, 'local_epochs': 3}
[36m(ClientAppActor pid=157070)[0m Epoch 1: train loss 0.02827400527894497, accuracy 0.7435416666666667
[36m(ClientAppActor pid=157070)[0m Epoch 2: train loss 0.012408753857016563, accuracy 0.8902083333333334
[36m(ClientAppActor pid=157070)[0m Epoch 3: train loss 0.009528066962957382, accuracy 0.9127083333333333
[36m(ClientAppActor pid=157070)[0m [Client 5, round 1] fit, config: {'server_round': 1, 'local_epochs': 3}
[36m(ClientAppActor pid=157070)

[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      fit progress: (1, 0.009746716725081206, {'accuracy': 0.909}, 85.01334493600007)
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 10)


Server-side evaluation loss 0.009746716725081206 / accuracy 0.909
[36m(ClientAppActor pid=157070)[0m [Client 8] evaluate, config: {}
[36m(ClientAppActor pid=157070)[0m [Client 4] evaluate, config: {}
[36m(ClientAppActor pid=157070)[0m [Client 6] evaluate, config: {}


[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 2]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 10)


[36m(ClientAppActor pid=157070)[0m [Client 3, round 2] fit, config: {'server_round': 2, 'local_epochs': 3}
[36m(ClientAppActor pid=157070)[0m Epoch 1: train loss 0.009585177525877953, accuracy 0.91
[36m(ClientAppActor pid=157070)[0m Epoch 2: train loss 0.007805875968188047, accuracy 0.9277083333333334
[36m(ClientAppActor pid=157070)[0m Epoch 3: train loss 0.006518766283988953, accuracy 0.9408333333333333
[36m(ClientAppActor pid=157070)[0m [Client 6, round 2] fit, config: {'server_round': 2, 'local_epochs': 3}
[36m(ClientAppActor pid=157070)[0m Epoch 1: train loss 0.01050153374671936, accuracy 0.8910416666666666
[36m(ClientAppActor pid=157070)[0m Epoch 2: train loss 0.008600138127803802, accuracy 0.916875
[36m(ClientAppActor pid=157070)[0m Epoch 3: train loss 0.006984667852520943, accuracy 0.9354166666666667
[36m(ClientAppActor pid=157070)[0m [Client 8, round 2] fit, config: {'server_round': 2, 'local_epochs': 3}
[36m(ClientAppActor pid=157070)[0m Epoch 1: train loss

[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      fit progress: (2, 0.007842272404953837, {'accuracy': 0.9287}, 171.77642727200146)
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 10)


Server-side evaluation loss 0.007842272404953837 / accuracy 0.9287
[36m(ClientAppActor pid=157070)[0m [Client 4] evaluate, config: {}
[36m(ClientAppActor pid=157070)[0m [Client 5] evaluate, config: {}
[36m(ClientAppActor pid=157070)[0m [Client 7] evaluate, config: {}


[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 3]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 10)


[36m(ClientAppActor pid=157070)[0m [Client 6, round 3] fit, config: {'server_round': 3, 'local_epochs': 3}
[36m(ClientAppActor pid=157070)[0m Epoch 1: train loss 0.00757924560457468, accuracy 0.92375
[36m(ClientAppActor pid=157070)[0m Epoch 2: train loss 0.006012601312249899, accuracy 0.9433333333333334
[36m(ClientAppActor pid=157070)[0m Epoch 3: train loss 0.0046373400837183, accuracy 0.9595833333333333
[36m(ClientAppActor pid=157070)[0m [Client 7, round 3] fit, config: {'server_round': 3, 'local_epochs': 3}
[36m(ClientAppActor pid=157070)[0m Epoch 1: train loss 0.008686114102602005, accuracy 0.9183333333333333
[36m(ClientAppActor pid=157070)[0m Epoch 2: train loss 0.006823414005339146, accuracy 0.9377083333333334
[36m(ClientAppActor pid=157070)[0m Epoch 3: train loss 0.005371880251914263, accuracy 0.9541666666666667
[36m(ClientAppActor pid=157070)[0m [Client 8, round 3] fit, config: {'server_round': 3, 'local_epochs': 3}
[36m(ClientAppActor pid=157070)[0m Epoch 1:

[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures


[36m(ClientAppActor pid=157070)[0m Epoch 3: train loss 0.005151687655597925, accuracy 0.9535416666666666


[92mINFO [0m:      fit progress: (3, 0.006593332336656749, {'accuracy': 0.9386}, 255.52648013900034)
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 10)


Server-side evaluation loss 0.006593332336656749 / accuracy 0.9386
[36m(ClientAppActor pid=157070)[0m [Client 2] evaluate, config: {}
[36m(ClientAppActor pid=157070)[0m [Client 3] evaluate, config: {}
[36m(ClientAppActor pid=157070)[0m [Client 6] evaluate, config: {}


[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [SUMMARY]
[92mINFO [0m:      Run finished 3 round(s) in 265.57s
[92mINFO [0m:      	History (loss, distributed):
[92mINFO [0m:      		round 1: 0.010639411028888492
[92mINFO [0m:      		round 2: 0.00832694680016074
[92mINFO [0m:      		round 3: 0.006698395094523827
[92mINFO [0m:      	History (loss, centralized):
[92mINFO [0m:      		round 0: 0.07205835869312287
[92mINFO [0m:      		round 1: 0.009746716725081206
[92mINFO [0m:      		round 2: 0.007842272404953837
[92mINFO [0m:      		round 3: 0.006593332336656749
[92mINFO [0m:      	History (metrics, centralized):
[92mINFO [0m:      	{'accuracy': [(0, 0.1201), (1, 0.909), (2, 0.9287), (3, 0.9386)]}
[92mINFO [0m:      
