with random clients 1

In [None]:
import flwr as fl
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.transforms as transforms 
from torchvision.datasets import CIFAR10, FashionMNIST, SVHN
import random

# Constants
BATCH_SIZE = 64
LEARNING_RATE = 0.001
EPOCHS = 5
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define the model with dataset-specific layers
class CNN_Model(nn.Module):
    def __init__(self, dataset_name):
        super(CNN_Model, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.pool = nn.AvgPool2d(2, 2)
        self.fc_config = nn.Linear(64 * 8 * 8, 128)

        # Dataset-specific layers
        if dataset_name == "cifar10":
            self.fc_op = nn.Sequential(
                nn.Linear(128, 64),
                nn.ReLU(),
                nn.Linear(64, 10)
            )
        elif dataset_name == "fashion_mnist":
            self.fc_op = nn.Sequential(
                nn.Linear(128, 32),
                nn.ReLU(),
                nn.Linear(32, 10)
            )
        elif dataset_name == "svhn":
            self.fc_op = nn.Sequential(
                nn.Linear(128, 64),
                nn.ReLU(),
                nn.Linear(64, 10)
            )

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool(x)
        x = F.relu(self.conv2(x))
        x = self.pool(x)
        x = x.view(-1, 64 * 8 * 8)
        x = F.relu(self.fc_config(x))
        x = self.fc_op(x)
        return F.softmax(x, dim=1)

# Define the Flower client
class FLClient(fl.client.NumPyClient):
    def __init__(self, model, train_loader, test_loader, client_id):
        self.model = model.to(DEVICE)
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = optim.Adam(self.model.parameters(), lr=LEARNING_RATE)
        self.client_id = client_id

    def get_parameters(self, config=None):
        params = {
            k: v.cpu().numpy()
            for k, v in self.model.state_dict().items()
            if not k.startswith("fc_op")
        }
        return list(params.values())

    def set_parameters(self, parameters):
        state_dict = self.model.state_dict()
        shared_keys = [k for k in state_dict.keys() if not k.startswith("fc_op")]
        for key, param in zip(shared_keys, parameters):
            state_dict[key] = torch.tensor(param)
        self.model.load_state_dict(state_dict, strict=False)

    def fine_tune(self):
        self.model.train()
        for epoch in range(2):  # Fine-tune for 2 local epochs
            for images, labels in self.train_loader:
                images, labels = images.to(DEVICE), labels.to(DEVICE)
                self.optimizer.zero_grad()
                outputs = self.model(images)
                loss = self.criterion(outputs, labels)
                loss.backward()
                self.optimizer.step()

    def fit(self, parameters, config):
        self.set_parameters(parameters)  # Set global model parameters
        self.model.train()
        for epoch in range(EPOCHS):  # Train on local data
            for images, labels in self.train_loader:
                images, labels = images.to(DEVICE), labels.to(DEVICE)
                self.optimizer.zero_grad()
                outputs = self.model(images)
                loss = self.criterion(outputs, labels)
                loss.backward()
                self.optimizer.step()
        # Perform fine-tuning
        self.fine_tune()

        return self.get_parameters(config), len(self.train_loader.dataset), {}

    def evaluate(self, parameters, config):
        self.set_parameters(parameters)
        self.model.eval()
        correct, total = 0, 0
        loss = 0.0
        with torch.no_grad():
            for images, labels in self.test_loader:
                images, labels = images.to(DEVICE), labels.to(DEVICE)
                outputs = self.model(images)
                loss += self.criterion(outputs, labels).item()
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        accuracy = correct / total
        print(f"Client {self.client_id}: Accuracy = {accuracy}")
        return float(loss) / len(self.test_loader), len(self.test_loader.dataset), {"accuracy": accuracy}

# Load the datasets
def load_data(dataset_name):
    transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    if dataset_name == "cifar10":
        train_dataset = CIFAR10(root='./data', train=True, download=True, transform=transform)
        test_dataset = CIFAR10(root='./data', train=False, download=True, transform=transform)
    elif dataset_name == "fashion_mnist":
        transform = transforms.Compose([
            transforms.Grayscale(num_output_channels=3),
            transforms.Resize((32, 32)),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ])
        train_dataset = FashionMNIST(root='./data', train=True, download=True, transform=transform)
        test_dataset = FashionMNIST(root='./data', train=False, download=True, transform=transform)
    elif dataset_name == "svhn":
        train_dataset = SVHN(root='./data', split='train', download=True, transform=transform)
        test_dataset = SVHN(root='./data', split='test', download=True, transform=transform)

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
    return train_loader, test_loader

def start_federated_learning():
    global clients
    clients = [
        FLClient(CNN_Model("cifar10"), *load_data("cifar10"), "CIFAR-10"),
        FLClient(CNN_Model("fashion_mnist"), *load_data("fashion_mnist"), "Fashion-MNIST"),
        FLClient(CNN_Model("svhn"), *load_data("svhn"), "SVHN"),
    ]

    fl.simulation.start_simulation(
        client_fn=lambda cid: random.choice(clients),
        num_clients=3,
        client_resources={"num_cpus": 2, "num_gpus": 1},
        config=fl.server.ServerConfig(num_rounds=10),
        strategy=fl.server.strategy.FedAvg(
            fraction_fit=0.67,
            min_fit_clients=2,
            min_available_clients=3,
        )
    )

if __name__ == "__main__":
    start_federated_learning()


2025-01-02 20:42:19,283	INFO util.py:154 -- Outdated packages:
  ipywidgets==6.0.0 found, needs ipywidgets>=8
Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.


Files already downloaded and verified
Files already downloaded and verified
Using downloaded and verified file: ./data/train_32x32.mat
Using downloaded and verified file: ./data/test_32x32.mat


[92mINFO [0m:      Starting Flower simulation, config: num_rounds=10, no round_timeout
2025-01-02 20:43:35,656	INFO worker.py:1752 -- Started a local Ray instance.
[92mINFO [0m:      Flower VCE: Ray initialized with resources: {'object_store_memory': 7855559884.0, 'accelerator_type:P4000': 1.0, 'node:__internal_head__': 1.0, 'node:131.227.65.35': 1.0, 'GPU': 1.0, 'CPU': 12.0, 'memory': 15711119771.0}
[92mINFO [0m:      Optimize your simulation with Flower VCE: https://flower.ai/docs/framework/how-to-run-simulations.html
[92mINFO [0m:      Flower VCE: Resources for each Virtual Client: {'num_cpus': 2, 'num_gpus': 1}
[92mINFO [0m:      Flower VCE: Creating VirtualClientEngineActorPool with 1 actors
[92mINFO [0m:      [INIT]
[92mINFO [0m:      Requesting initial parameters from one random client
[36m(ClientAppActor pid=71836)[0m   return torch.load(io.BytesIO(b))
[36m(ClientAppActor pid=71836)[0m   return torch.load(io.BytesIO(b))
[36m(ClientAppActor pid=71836)[0m 
[3

[36m(ClientAppActor pid=71836)[0m Client SVHN: Accuracy = 0.07498463429625077


[36m(ClientAppActor pid=71836)[0m 
[36m(ClientAppActor pid=71836)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=71836)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=71836)[0m         


[36m(ClientAppActor pid=71836)[0m Client CIFAR-10: Accuracy = 0.2658


[36m(ClientAppActor pid=71836)[0m 
[36m(ClientAppActor pid=71836)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=71836)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=71836)[0m         
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 2]
[92mINFO [0m:      configure_fit: strategy sampled 2 clients (out of 3)


[36m(ClientAppActor pid=71836)[0m Client CIFAR-10: Accuracy = 0.2658


[36m(ClientAppActor pid=71836)[0m 
[36m(ClientAppActor pid=71836)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=71836)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=71836)[0m         
[36m(ClientAppActor pid=71836)[0m 
[36m(ClientAppActor pid=71836)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=71836)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=71836)[0m         
[92mINFO [0m:      aggregate_fit: received 2 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=71836)[0m 
[36m(ClientAppActor pid=71836)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=71836)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=71836)[0m         


[36m(ClientAppActor pid=71836)[0m Client SVHN: Accuracy = 0.21266133988936695


[36m(ClientAppActor pid=71836)[0m 
[36m(ClientAppActor pid=71836)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=71836)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=71836)[0m         


[36m(ClientAppActor pid=71836)[0m Client SVHN: Accuracy = 0.21266133988936695


[36m(ClientAppActor pid=71836)[0m 
[36m(ClientAppActor pid=71836)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=71836)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=71836)[0m         
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 3]
[92mINFO [0m:      configure_fit: strategy sampled 2 clients (out of 3)


[36m(ClientAppActor pid=71836)[0m Client SVHN: Accuracy = 0.21266133988936695


[36m(ClientAppActor pid=71836)[0m 
[36m(ClientAppActor pid=71836)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=71836)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=71836)[0m         
[36m(ClientAppActor pid=71836)[0m 
[36m(ClientAppActor pid=71836)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=71836)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=71836)[0m         
[92mINFO [0m:      aggregate_fit: received 2 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=71836)[0m 
[36m(ClientAppActor pid=71836)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=71836)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=71836)[0m         


[36m(ClientAppActor pid=71836)[0m Client Fashion-MNIST: Accuracy = 0.249


[36m(ClientAppActor pid=71836)[0m 
[36m(ClientAppActor pid=71836)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=71836)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=71836)[0m         


[36m(ClientAppActor pid=71836)[0m Client SVHN: Accuracy = 0.3581361401352182


[36m(ClientAppActor pid=71836)[0m 
[36m(ClientAppActor pid=71836)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=71836)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=71836)[0m         
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 4]
[92mINFO [0m:      configure_fit: strategy sampled 2 clients (out of 3)


[36m(ClientAppActor pid=71836)[0m Client Fashion-MNIST: Accuracy = 0.249


[36m(ClientAppActor pid=71836)[0m 
[36m(ClientAppActor pid=71836)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=71836)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=71836)[0m         
[36m(ClientAppActor pid=71836)[0m 
[36m(ClientAppActor pid=71836)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=71836)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=71836)[0m         
[92mINFO [0m:      aggregate_fit: received 2 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=71836)[0m 
[36m(ClientAppActor pid=71836)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=71836)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=71836)[0m         


[36m(ClientAppActor pid=71836)[0m Client Fashion-MNIST: Accuracy = 0.5331


[36m(ClientAppActor pid=71836)[0m 
[36m(ClientAppActor pid=71836)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=71836)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=71836)[0m         


[36m(ClientAppActor pid=71836)[0m Client SVHN: Accuracy = 0.3382375537799631


[36m(ClientAppActor pid=71836)[0m 
[36m(ClientAppActor pid=71836)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=71836)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=71836)[0m         
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 5]
[92mINFO [0m:      configure_fit: strategy sampled 2 clients (out of 3)


[36m(ClientAppActor pid=71836)[0m Client SVHN: Accuracy = 0.3382375537799631


[36m(ClientAppActor pid=71836)[0m 
[36m(ClientAppActor pid=71836)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=71836)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=71836)[0m         
[36m(ClientAppActor pid=71836)[0m 
[36m(ClientAppActor pid=71836)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=71836)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=71836)[0m         
[92mINFO [0m:      aggregate_fit: received 2 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=71836)[0m 
[36m(ClientAppActor pid=71836)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=71836)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=71836)[0m         


[36m(ClientAppActor pid=71836)[0m Client SVHN: Accuracy = 0.23912876459741855


[36m(ClientAppActor pid=71836)[0m 
[36m(ClientAppActor pid=71836)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=71836)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=71836)[0m         


[36m(ClientAppActor pid=71836)[0m Client CIFAR-10: Accuracy = 0.1545


[36m(ClientAppActor pid=71836)[0m 
[36m(ClientAppActor pid=71836)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=71836)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=71836)[0m         
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 6]
[92mINFO [0m:      configure_fit: strategy sampled 2 clients (out of 3)


[36m(ClientAppActor pid=71836)[0m Client SVHN: Accuracy = 0.23912876459741855


[36m(ClientAppActor pid=71836)[0m 
[36m(ClientAppActor pid=71836)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=71836)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=71836)[0m         
[36m(ClientAppActor pid=71836)[0m 
[36m(ClientAppActor pid=71836)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=71836)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=71836)[0m         
[92mINFO [0m:      aggregate_fit: received 2 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=71836)[0m 
[36m(ClientAppActor pid=71836)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=71836)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=71836)[0m         


[36m(ClientAppActor pid=71836)[0m Client SVHN: Accuracy = 0.41679471419791025


[36m(ClientAppActor pid=71836)[0m 
[36m(ClientAppActor pid=71836)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=71836)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=71836)[0m         


[36m(ClientAppActor pid=71836)[0m Client SVHN: Accuracy = 0.41679471419791025


[36m(ClientAppActor pid=71836)[0m 
[36m(ClientAppActor pid=71836)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=71836)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=71836)[0m         
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 7]
[92mINFO [0m:      configure_fit: strategy sampled 2 clients (out of 3)


[36m(ClientAppActor pid=71836)[0m Client Fashion-MNIST: Accuracy = 0.4723


[36m(ClientAppActor pid=71836)[0m 
[36m(ClientAppActor pid=71836)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=71836)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=71836)[0m         
[36m(ClientAppActor pid=71836)[0m 
[36m(ClientAppActor pid=71836)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=71836)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=71836)[0m         
[92mINFO [0m:      aggregate_fit: received 2 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=71836)[0m 
[36m(ClientAppActor pid=71836)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=71836)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=71836)[0m         


[36m(ClientAppActor pid=71836)[0m Client CIFAR-10: Accuracy = 0.1123


[36m(ClientAppActor pid=71836)[0m 
[36m(ClientAppActor pid=71836)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=71836)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=71836)[0m         


[36m(ClientAppActor pid=71836)[0m Client SVHN: Accuracy = 0.3894821757836509


[36m(ClientAppActor pid=71836)[0m 
[36m(ClientAppActor pid=71836)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=71836)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=71836)[0m         
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 8]
[92mINFO [0m:      configure_fit: strategy sampled 2 clients (out of 3)


[36m(ClientAppActor pid=71836)[0m Client Fashion-MNIST: Accuracy = 0.6584


[36m(ClientAppActor pid=71836)[0m 
[36m(ClientAppActor pid=71836)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=71836)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=71836)[0m         
[36m(ClientAppActor pid=71836)[0m 
[36m(ClientAppActor pid=71836)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=71836)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=71836)[0m         
[92mINFO [0m:      aggregate_fit: received 2 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=71836)[0m 
[36m(ClientAppActor pid=71836)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=71836)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=71836)[0m         


[36m(ClientAppActor pid=71836)[0m Client CIFAR-10: Accuracy = 0.1193


[36m(ClientAppActor pid=71836)[0m 
[36m(ClientAppActor pid=71836)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=71836)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=71836)[0m         


[36m(ClientAppActor pid=71836)[0m Client CIFAR-10: Accuracy = 0.1193


[36m(ClientAppActor pid=71836)[0m 
[36m(ClientAppActor pid=71836)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=71836)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=71836)[0m         
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 9]
[92mINFO [0m:      configure_fit: strategy sampled 2 clients (out of 3)


[36m(ClientAppActor pid=71836)[0m Client CIFAR-10: Accuracy = 0.1193


[36m(ClientAppActor pid=71836)[0m 
[36m(ClientAppActor pid=71836)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=71836)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=71836)[0m         
[36m(ClientAppActor pid=71836)[0m 
[36m(ClientAppActor pid=71836)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=71836)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=71836)[0m         
[92mINFO [0m:      aggregate_fit: received 2 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=71836)[0m 
[36m(ClientAppActor pid=71836)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=71836)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=71836)[0m         


[36m(ClientAppActor pid=71836)[0m Client Fashion-MNIST: Accuracy = 0.6138


[36m(ClientAppActor pid=71836)[0m 
[36m(ClientAppActor pid=71836)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=71836)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=71836)[0m         


[36m(ClientAppActor pid=71836)[0m Client Fashion-MNIST: Accuracy = 0.6138


[36m(ClientAppActor pid=71836)[0m 
[36m(ClientAppActor pid=71836)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=71836)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=71836)[0m         
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 10]
[92mINFO [0m:      configure_fit: strategy sampled 2 clients (out of 3)


[36m(ClientAppActor pid=71836)[0m Client CIFAR-10: Accuracy = 0.3648


[36m(ClientAppActor pid=71836)[0m 
[36m(ClientAppActor pid=71836)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=71836)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=71836)[0m         
[36m(ClientAppActor pid=71836)[0m 
[36m(ClientAppActor pid=71836)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=71836)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=71836)[0m         


In [None]:
import flwr as fl
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10, FashionMNIST, SVHN
import random

# Constants
BATCH_SIZE = 128
LEARNING_RATE = 0.001
EPOCHS = 10
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define the model with dataset-specific layers
class CNN_Model(nn.Module):
    def __init__(self, dataset_name):
        super(CNN_Model, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.pool = nn.AvgPool2d(2, 2)
        self.fc_config = nn.Linear(64 * 8 * 8, 128)

        # Dataset-specific layers
        if dataset_name == "cifar10":
            self.fc_op = nn.Sequential(
                nn.Linear(128, 64),
                nn.ReLU(),
                nn.Linear(64, 10)
            )
        elif dataset_name == "fashion_mnist":
            self.fc_op = nn.Sequential(
                nn.Linear(128, 32),
                nn.ReLU(),
                nn.Linear(32, 10)
            )
        elif dataset_name == "svhn":
            self.fc_op = nn.Sequential(
                nn.Linear(128, 64),
                nn.ReLU(),
                nn.Linear(64, 10)
            )

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool(x)
        x = F.relu(self.conv2(x))
        x = self.pool(x)
        x = x.view(-1, 64 * 8 * 8)
        x = F.relu(self.fc_config(x))
        x = self.fc_op(x)
        return F.softmax(x, dim=1)

# Define the Flower client with FedProx
class FLClient(fl.client.NumPyClient):
    def __init__(self, model, train_loader, test_loader, client_id):
        self.model = model.to(DEVICE)
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = optim.Adam(self.model.parameters(), lr=LEARNING_RATE)
        self.client_id = client_id

    def get_parameters(self, config=None):
        params = {
            k: v.cpu().numpy()
            for k, v in self.model.state_dict().items()
            if not k.startswith("fc_op")
        }
        return list(params.values())

    def set_parameters(self, parameters):
        state_dict = self.model.state_dict()
        shared_keys = [k for k in state_dict.keys() if not k.startswith("fc_op")]
        for key, param in zip(shared_keys, parameters):
            state_dict[key] = torch.tensor(param)
        self.model.load_state_dict(state_dict, strict=False)

    def fit(self, parameters, config):
        # Set global parameters on the local model
        self.set_parameters(parameters)
        self.model.train()
        
        # Save global model parameters for proximity
        global_params = [param.clone().detach().to(DEVICE) for param in self.model.parameters()]
        
        for epoch in range(EPOCHS):
            for images, labels in self.train_loader:
                images, labels = images.to(DEVICE), labels.to(DEVICE)
                self.optimizer.zero_grad()
                outputs = self.model(images)
                loss = self.criterion(outputs, labels)
                
                # Add proximal term to the loss
                prox_term = 0.0
                for param, global_param in zip(self.model.parameters(), global_params):
                    prox_term += torch.norm(param - global_param) ** 2
                loss += (0.01 / 2) * prox_term  # 0.01 is the proximal coefficient
                
                loss.backward()
                self.optimizer.step()
        
        return self.get_parameters(config), len(self.train_loader.dataset), {}

    def evaluate(self, parameters, config):
        self.set_parameters(parameters)
        self.model.eval()
        correct, total = 0, 0
        loss = 0.0
        with torch.no_grad():
            for images, labels in self.test_loader:
                images, labels = images.to(DEVICE), labels.to(DEVICE)
                outputs = self.model(images)
                loss += self.criterion(outputs, labels).item()
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        accuracy = correct / total
        print(f"Client {self.client_id}: Accuracy = {accuracy}")
        return float(loss) / len(self.test_loader), len(self.test_loader.dataset), {"accuracy": accuracy}

# Load the datasets
def load_data(dataset_name):
    transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    if dataset_name == "cifar10":
        train_dataset = CIFAR10(root='./data', train=True, download=True, transform=transform)
        test_dataset = CIFAR10(root='./data', train=False, download=True, transform=transform)
    elif dataset_name == "fashion_mnist": 
        transform = transforms.Compose([
            transforms.Grayscale(num_output_channels=3),
            transforms.Resize((32, 32)),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ])
        train_dataset = FashionMNIST(root='./data', train=True, download=True, transform=transform)
        test_dataset = FashionMNIST(root='./data', train=False, download=True, transform=transform)
    elif dataset_name == "svhn":
        train_dataset = SVHN(root='./data', split='train', download=True, transform=transform)
        test_dataset = SVHN(root='./data', split='test', download=True, transform=transform)

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
    return train_loader, test_loader

def start_federated_learning():
    global clients
    clients = [
        FLClient(CNN_Model("cifar10"), *load_data("cifar10"), "CIFAR-10"),
        FLClient(CNN_Model("fashion_mnist"), *load_data("fashion_mnist"), "Fashion-MNIST"),
        FLClient(CNN_Model("svhn"), *load_data("svhn"), "SVHN"),
    ]

    fl.simulation.start_simulation(
        client_fn=lambda cid: random.choice(clients),
        num_clients=3,
        client_resources={"num_cpus": 4, "num_gpus": 1},
        config=fl.server.ServerConfig(num_rounds=10),
        strategy=fl.server.strategy.FedAvg(
            fraction_fit=0.67,
            min_fit_clients=2,
            min_available_clients=3,
        )
    )

if __name__ == "__main__":
    start_federated_learning()
