In [5]:
import flwr as fl
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.transforms as transforms 
from torchvision.datasets import CIFAR10, FashionMNIST, SVHN
 
# Constants
BATCH_SIZE = 64
LEARNING_RATE = 0.001
EPOCHS = 5
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
# Define the model with dataset-specific layers
class CNN_Model(nn.Module):
    def __init__(self, dataset_name):
        super(CNN_Model, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.pool = nn.AvgPool2d(2, 2)
        self.fc_config = nn.Linear(64 * 8 * 8, 128)
 
        # Dataset-specific layers
        if dataset_name == "cifar10":
            self.fc_op = nn.Sequential(
                nn.Linear(128, 64),
                nn.ReLU(),
                nn.Linear(64, 10)
            )
        elif dataset_name == "fashion_mnist":
            self.fc_op = nn.Sequential(
                nn.Linear(128, 32),
                nn.ReLU(),
                nn.Linear(32, 10)
            )
        elif dataset_name == "svhn":
            self.fc_op = nn.Sequential(
                nn.Linear(128, 64),
                nn.ReLU(),
                nn.Linear(64, 10)
            )
 
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool(x)
        x = F.relu(self.conv2(x))
        x = self.pool(x)
        x = x.view(-1, 64 * 8 * 8)
        x = F.relu(self.fc_config(x))
        x = self.fc_op(x)
        return F.softmax(x, dim=1)
 
# Define the Flower client
class FLClient(fl.client.NumPyClient):
    def __init__(self, model, train_loader, test_loader, client_id):
        self.model = model.to(DEVICE)
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = optim.Adam(self.model.parameters(), lr=LEARNING_RATE)
        self.client_id = client_id
 
    def get_parameters(self, config=None):
        params = {
            k: v.cpu().numpy()
            for k, v in self.model.state_dict().items()
            if not k.startswith("fc_op")
        }
        return list(params.values())
 
    def set_parameters(self, parameters):
        state_dict = self.model.state_dict()
        shared_keys = [k for k in state_dict.keys() if not k.startswith("fc_op")]
        for key, param in zip(shared_keys, parameters):
            state_dict[key] = torch.tensor(param)
        self.model.load_state_dict(state_dict, strict=False)
 
    def fine_tune(self):
        self.model.train()
        for epoch in range(2):  # Fine-tune for 2 local epochs
            for images, labels in self.train_loader:
                images, labels = images.to(DEVICE), labels.to(DEVICE)
                self.optimizer.zero_grad()
                outputs = self.model(images)
                loss = self.criterion(outputs, labels)
                loss.backward()
                self.optimizer.step()
 
    def fit(self, parameters, config):
        self.set_parameters(parameters)  # Set global model parameters
        self.model.train()
        for epoch in range(EPOCHS):  # Train on local data
            for images, labels in self.train_loader:
                images, labels = images.to(DEVICE), labels.to(DEVICE)
                self.optimizer.zero_grad()
                outputs = self.model(images)
                loss = self.criterion(outputs, labels)
                loss.backward()
                self.optimizer.step()
        # Perform fine-tuning
        self.fine_tune()
 
        return self.get_parameters(config), len(self.train_loader.dataset), {}
 
    def evaluate(self, parameters, config):
        self.set_parameters(parameters)
        self.model.eval()
        correct, total = 0, 0
        loss = 0.0
        with torch.no_grad():
            for images, labels in self.test_loader:
                images, labels = images.to(DEVICE), labels.to(DEVICE)
                outputs = self.model(images)
                loss += self.criterion(outputs, labels).item()
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        accuracy = correct / total
        print(f"Client {self.client_id}: Accuracy = {accuracy}")
        return float(loss) / len(self.test_loader), len(self.test_loader.dataset), {"accuracy": accuracy}
 
# Load the datasets
def load_data(dataset_name):
    transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    if dataset_name == "cifar10":
        train_dataset = CIFAR10(root='./data', train=True, download=True, transform=transform)
        test_dataset = CIFAR10(root='./data', train=False, download=True, transform=transform)
    elif dataset_name == "fashion_mnist":
        transform = transforms.Compose([
            transforms.Grayscale(num_output_channels=3),
            transforms.Resize((32, 32)),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ])
        train_dataset = FashionMNIST(root='./data', train=True, download=True, transform=transform)
        test_dataset = FashionMNIST(root='./data', train=False, download=True, transform=transform)
    elif dataset_name == "svhn":
        train_dataset = SVHN(root='./data', split='train', download=True, transform=transform)
        test_dataset = SVHN(root='./data', split='test', download=True, transform=transform)
    else:
        raise ValueError(f"Unknown dataset: {dataset_name}")
 
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
    return train_loader, test_loader
 
# Updated client_fn to avoid using Context's client_id
def client_fn(cid: str) -> fl.client.Client:
    if cid == "0":
        return cifar_client
    elif cid == "1":
        return fashion_client
    elif cid == "2":
        return svhn_client
 
# Start federated learning
def start_federated_learning():
    global cifar_client, fashion_client, svhn_client
    cifar_train_loader, cifar_test_loader = load_data("cifar10")
    fashion_train_loader, fashion_test_loader = load_data("fashion_mnist")
    svhn_train_loader, svhn_test_loader = load_data("svhn")
 
    cifar_client = FLClient(CNN_Model("cifar10"), cifar_train_loader, cifar_test_loader, "CIFAR-10")
    fashion_client = FLClient(CNN_Model("fashion_mnist"), fashion_train_loader, fashion_test_loader, "Fashion-MNIST")
    svhn_client = FLClient(CNN_Model("svhn"), svhn_train_loader, svhn_test_loader, "SVHN")
 
    fl.simulation.start_simulation(
        client_fn=client_fn,
        num_clients=3,
        client_resources={"num_cpus": 2, "num_gpus": 1},
        config=fl.server.ServerConfig(num_rounds=10),  # Increased number of rounds
        strategy=fl.server.strategy.FedAvg(
            fraction_fit=1.0,
            min_fit_clients=3,
            min_available_clients=3,
        )
    )
 
if __name__ == "__main__":
    start_federated_learning()

Files already downloaded and verified
Files already downloaded and verified
Using downloaded and verified file: ./data/train_32x32.mat
Using downloaded and verified file: ./data/test_32x32.mat


[92mINFO [0m:      Starting Flower simulation, config: num_rounds=10, no round_timeout
2024-12-04 17:56:55,005	INFO worker.py:1752 -- Started a local Ray instance.
[92mINFO [0m:      Flower VCE: Ray initialized with resources: {'object_store_memory': 6006627532.0, 'memory': 12013255067.0, 'CPU': 12.0, 'node:131.227.65.97': 1.0, 'GPU': 1.0, 'accelerator_type:P4000': 1.0, 'node:__internal_head__': 1.0}
[92mINFO [0m:      Optimize your simulation with Flower VCE: https://flower.ai/docs/framework/how-to-run-simulations.html
[92mINFO [0m:      Flower VCE: Resources for each Virtual Client: {'num_cpus': 2, 'num_gpus': 1}
[92mINFO [0m:      Flower VCE: Creating VirtualClientEngineActorPool with 1 actors
[92mINFO [0m:      [INIT]
[92mINFO [0m:      Requesting initial parameters from one random client
[36m(ClientAppActor pid=2043586)[0m   return torch.load(io.BytesIO(b))
[36m(ClientAppActor pid=2043586)[0m   return torch.load(io.BytesIO(b))
[92mINFO [0m:      Received initia

[36m(ClientAppActor pid=2043586)[0m Client CIFAR-10: Accuracy = 0.1807


[36m(ClientAppActor pid=2043586)[0m 
[36m(ClientAppActor pid=2043586)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2043586)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2043586)[0m         


[36m(ClientAppActor pid=2043586)[0m Client SVHN: Accuracy = 0.193799938537185


[36m(ClientAppActor pid=2043586)[0m 
[36m(ClientAppActor pid=2043586)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2043586)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2043586)[0m         
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 2]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)


[36m(ClientAppActor pid=2043586)[0m Client Fashion-MNIST: Accuracy = 0.4491


[36m(ClientAppActor pid=2043586)[0m 
[36m(ClientAppActor pid=2043586)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2043586)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2043586)[0m         
[36m(ClientAppActor pid=2043586)[0m 
[36m(ClientAppActor pid=2043586)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2043586)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2043586)[0m         
[36m(ClientAppActor pid=2043586)[0m 
[36m(ClientAppActor pid=2043586)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2043586)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2043586)[0m         
[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=2043

[36m(ClientAppActor pid=2043586)[0m Client Fashion-MNIST: Accuracy = 0.5915


[36m(ClientAppActor pid=2043586)[0m 
[36m(ClientAppActor pid=2043586)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2043586)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2043586)[0m         


[36m(ClientAppActor pid=2043586)[0m Client SVHN: Accuracy = 0.32890288875230483


[36m(ClientAppActor pid=2043586)[0m 
[36m(ClientAppActor pid=2043586)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2043586)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2043586)[0m         
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 3]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)


[36m(ClientAppActor pid=2043586)[0m Client CIFAR-10: Accuracy = 0.1911


[36m(ClientAppActor pid=2043586)[0m 
[36m(ClientAppActor pid=2043586)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2043586)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2043586)[0m         
[36m(ClientAppActor pid=2043586)[0m 
[36m(ClientAppActor pid=2043586)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2043586)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2043586)[0m         
[36m(ClientAppActor pid=2043586)[0m 
[36m(ClientAppActor pid=2043586)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2043586)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2043586)[0m         
[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=2043

[36m(ClientAppActor pid=2043586)[0m Client Fashion-MNIST: Accuracy = 0.6183


[36m(ClientAppActor pid=2043586)[0m 
[36m(ClientAppActor pid=2043586)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2043586)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2043586)[0m         


[36m(ClientAppActor pid=2043586)[0m Client CIFAR-10: Accuracy = 0.2204


[36m(ClientAppActor pid=2043586)[0m 
[36m(ClientAppActor pid=2043586)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2043586)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2043586)[0m         
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 4]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)


[36m(ClientAppActor pid=2043586)[0m Client SVHN: Accuracy = 0.36270743700061464


[36m(ClientAppActor pid=2043586)[0m 
[36m(ClientAppActor pid=2043586)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2043586)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2043586)[0m         
[36m(ClientAppActor pid=2043586)[0m 
[36m(ClientAppActor pid=2043586)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2043586)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2043586)[0m         
[36m(ClientAppActor pid=2043586)[0m 
[36m(ClientAppActor pid=2043586)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2043586)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2043586)[0m         
[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=2043

[36m(ClientAppActor pid=2043586)[0m Client Fashion-MNIST: Accuracy = 0.6605


[36m(ClientAppActor pid=2043586)[0m 
[36m(ClientAppActor pid=2043586)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2043586)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2043586)[0m         


[36m(ClientAppActor pid=2043586)[0m Client CIFAR-10: Accuracy = 0.2369


[36m(ClientAppActor pid=2043586)[0m 
[36m(ClientAppActor pid=2043586)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2043586)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2043586)[0m         
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 5]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)


[36m(ClientAppActor pid=2043586)[0m Client SVHN: Accuracy = 0.3990857406269207


[36m(ClientAppActor pid=2043586)[0m 
[36m(ClientAppActor pid=2043586)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2043586)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2043586)[0m         
[36m(ClientAppActor pid=2043586)[0m 
[36m(ClientAppActor pid=2043586)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2043586)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2043586)[0m         
[36m(ClientAppActor pid=2043586)[0m 
[36m(ClientAppActor pid=2043586)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2043586)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2043586)[0m         
[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=2043

[36m(ClientAppActor pid=2043586)[0m Client SVHN: Accuracy = 0.3857559926244622


[36m(ClientAppActor pid=2043586)[0m 
[36m(ClientAppActor pid=2043586)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2043586)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2043586)[0m         


[36m(ClientAppActor pid=2043586)[0m Client CIFAR-10: Accuracy = 0.2528


[36m(ClientAppActor pid=2043586)[0m 
[36m(ClientAppActor pid=2043586)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2043586)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2043586)[0m         
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 6]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)


[36m(ClientAppActor pid=2043586)[0m Client Fashion-MNIST: Accuracy = 0.7268


[36m(ClientAppActor pid=2043586)[0m 
[36m(ClientAppActor pid=2043586)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2043586)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2043586)[0m         
[36m(ClientAppActor pid=2043586)[0m 
[36m(ClientAppActor pid=2043586)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2043586)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2043586)[0m         
[36m(ClientAppActor pid=2043586)[0m 
[36m(ClientAppActor pid=2043586)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2043586)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2043586)[0m         
[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=2043

[36m(ClientAppActor pid=2043586)[0m Client CIFAR-10: Accuracy = 0.2646


[36m(ClientAppActor pid=2043586)[0m 
[36m(ClientAppActor pid=2043586)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2043586)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2043586)[0m         


[36m(ClientAppActor pid=2043586)[0m Client SVHN: Accuracy = 0.40711432083589427


[36m(ClientAppActor pid=2043586)[0m 
[36m(ClientAppActor pid=2043586)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2043586)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2043586)[0m         
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 7]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)


[36m(ClientAppActor pid=2043586)[0m Client Fashion-MNIST: Accuracy = 0.7159


[36m(ClientAppActor pid=2043586)[0m 
[36m(ClientAppActor pid=2043586)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2043586)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2043586)[0m         
[36m(ClientAppActor pid=2043586)[0m 
[36m(ClientAppActor pid=2043586)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2043586)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2043586)[0m         
[36m(ClientAppActor pid=2043586)[0m 
[36m(ClientAppActor pid=2043586)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2043586)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2043586)[0m         
[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=2043

[36m(ClientAppActor pid=2043586)[0m Client CIFAR-10: Accuracy = 0.2695


[36m(ClientAppActor pid=2043586)[0m 
[36m(ClientAppActor pid=2043586)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2043586)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2043586)[0m         


[36m(ClientAppActor pid=2043586)[0m Client Fashion-MNIST: Accuracy = 0.726


[36m(ClientAppActor pid=2043586)[0m 
[36m(ClientAppActor pid=2043586)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2043586)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2043586)[0m         
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 8]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)


[36m(ClientAppActor pid=2043586)[0m Client SVHN: Accuracy = 0.40884296250768287


[36m(ClientAppActor pid=2043586)[0m 
[36m(ClientAppActor pid=2043586)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2043586)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2043586)[0m         
[36m(ClientAppActor pid=2043586)[0m 
[36m(ClientAppActor pid=2043586)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2043586)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2043586)[0m         
[36m(ClientAppActor pid=2043586)[0m 
[36m(ClientAppActor pid=2043586)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2043586)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2043586)[0m         
[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=2043

[36m(ClientAppActor pid=2043586)[0m Client CIFAR-10: Accuracy = 0.2719


[36m(ClientAppActor pid=2043586)[0m 
[36m(ClientAppActor pid=2043586)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2043586)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2043586)[0m         


[36m(ClientAppActor pid=2043586)[0m Client SVHN: Accuracy = 0.4061923786109404


[36m(ClientAppActor pid=2043586)[0m 
[36m(ClientAppActor pid=2043586)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2043586)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2043586)[0m         
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 9]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)


[36m(ClientAppActor pid=2043586)[0m Client Fashion-MNIST: Accuracy = 0.7152


[36m(ClientAppActor pid=2043586)[0m 
[36m(ClientAppActor pid=2043586)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2043586)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2043586)[0m         
[36m(ClientAppActor pid=2043586)[0m 
[36m(ClientAppActor pid=2043586)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2043586)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2043586)[0m         
[36m(ClientAppActor pid=2043586)[0m 
[36m(ClientAppActor pid=2043586)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2043586)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2043586)[0m         
[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=2043

[36m(ClientAppActor pid=2043586)[0m Client SVHN: Accuracy = 0.40331130915795943


[36m(ClientAppActor pid=2043586)[0m 
[36m(ClientAppActor pid=2043586)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2043586)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2043586)[0m         


[36m(ClientAppActor pid=2043586)[0m Client CIFAR-10: Accuracy = 0.2716


[36m(ClientAppActor pid=2043586)[0m 
[36m(ClientAppActor pid=2043586)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2043586)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2043586)[0m         
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 10]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)


[36m(ClientAppActor pid=2043586)[0m Client Fashion-MNIST: Accuracy = 0.7575


[36m(ClientAppActor pid=2043586)[0m 
[36m(ClientAppActor pid=2043586)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2043586)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2043586)[0m         
[36m(ClientAppActor pid=2043586)[0m 
[36m(ClientAppActor pid=2043586)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2043586)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2043586)[0m         
[36m(ClientAppActor pid=2043586)[0m 
[36m(ClientAppActor pid=2043586)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2043586)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2043586)[0m         
[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=2043

[36m(ClientAppActor pid=2043586)[0m Client Fashion-MNIST: Accuracy = 0.7636


[36m(ClientAppActor pid=2043586)[0m 
[36m(ClientAppActor pid=2043586)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2043586)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2043586)[0m         


[36m(ClientAppActor pid=2043586)[0m Client SVHN: Accuracy = 0.4001613398893669


[36m(ClientAppActor pid=2043586)[0m 
[36m(ClientAppActor pid=2043586)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2043586)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2043586)[0m         
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [SUMMARY]
[92mINFO [0m:      Run finished 10 round(s) in 3533.15s
[92mINFO [0m:      	History (loss, distributed):
[92mINFO [0m:      		round 1: 2.2923804188299663
[92mINFO [0m:      		round 2: 2.260058136786089
[92mINFO [0m:      		round 3: 2.2267038313596603
[92mINFO [0m:      		round 4: 2.18524963034649
[92mINFO [0m:      		round 5: 2.152322626460022
[92mINFO [0m:      		round 6: 2.1329300689836512
[92mINFO [0m:      		round 7: 2.10424193578595
[92mINFO [0m:      		round 8: 2.089408822774193
[92mINFO [0m:      		round 9: 2.0678136107639182
[92mINFO [0m:      		round 10: 2.05

[36m(ClientAppActor pid=2043586)[0m Client CIFAR-10: Accuracy = 0.2715


In [8]:
import flwr as fl
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.transforms as transforms 
from torchvision.datasets import CIFAR10, FashionMNIST, SVHN
 
# Constants
BATCH_SIZE = 64
LEARNING_RATE = 0.001
EPOCHS = 5
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
# Define the model with dataset-specific layers
class CNN_Model(nn.Module):
    def __init__(self, dataset_name):
        super(CNN_Model, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.pool = nn.AvgPool2d(2, 2)
        self.fc_config = nn.Linear(64 * 8 * 8, 128)
 
        # Dataset-specific layers
        if dataset_name == "cifar10":
            self.fc_op = nn.Sequential(
                nn.Linear(128, 64),
                nn.ReLU(),
                nn.Linear(64, 10)
            )
        elif dataset_name == "fashion_mnist":
            self.fc_op = nn.Sequential(
                nn.Linear(128, 32),
                nn.ReLU(),
                nn.Linear(32, 10)
            )
        elif dataset_name == "svhn":
            self.fc_op = nn.Sequential(
                nn.Linear(128, 64),
                nn.ReLU(),
                nn.Linear(64, 10)
            )
 
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool(x)
        x = F.relu(self.conv2(x))
        x = self.pool(x)
        x = x.view(-1, 64 * 8 * 8)
        x = F.relu(self.fc_config(x))
        x = self.fc_op(x)
        return F.softmax(x, dim=1)
 
# Define the Flower client
class FLClient(fl.client.NumPyClient):
    def __init__(self, model, train_loader, test_loader, client_id):
        self.model = model.to(DEVICE)
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = optim.Adam(self.model.parameters(), lr=LEARNING_RATE)
        self.client_id = client_id
 
    def get_parameters(self, config=None):
        params = {
            k: v.cpu().numpy()
            for k, v in self.model.state_dict().items()
            if not k.startswith("fc_op")
        }
        return list(params.values())
 
    def set_parameters(self, parameters):
        state_dict = self.model.state_dict()
        shared_keys = [k for k in state_dict.keys() if not k.startswith("fc_op")]
        for key, param in zip(shared_keys, parameters):
            state_dict[key] = torch.tensor(param)
        self.model.load_state_dict(state_dict, strict=False)
 
    def fine_tune(self):
        self.model.train()
        for epoch in range(2):  # Fine-tune for 2 local epochs
            for images, labels in self.train_loader:
                images, labels = images.to(DEVICE), labels.to(DEVICE)
                self.optimizer.zero_grad()
                outputs = self.model(images)
                loss = self.criterion(outputs, labels)
                loss.backward()
                self.optimizer.step()
 
    def fit(self, parameters, config):
        self.set_parameters(parameters)  # Set global model parameters
        self.model.train()
        for epoch in range(EPOCHS):  # Train on local data
            for images, labels in self.train_loader:
                images, labels = images.to(DEVICE), labels.to(DEVICE)
                self.optimizer.zero_grad()
                outputs = self.model(images)
                loss = self.criterion(outputs, labels)
                loss.backward()
                self.optimizer.step()
        # Perform fine-tuning
        self.fine_tune()
 
        return self.get_parameters(config), len(self.train_loader.dataset), {}
 
    def evaluate(self, parameters, config):
        self.set_parameters(parameters)
        self.model.eval()
        correct, total = 0, 0
        loss = 0.0
        with torch.no_grad():
            for images, labels in self.test_loader:
                images, labels = images.to(DEVICE), labels.to(DEVICE)
                outputs = self.model(images)
                loss += self.criterion(outputs, labels).item()
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        accuracy = correct / total
        print(f"Client {self.client_id}: Accuracy = {accuracy}")
        return float(loss) / len(self.test_loader), len(self.test_loader.dataset), {"accuracy": accuracy}
 
# Load the datasets
def load_data(dataset_name):
    transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    if dataset_name == "cifar10":
        train_dataset = CIFAR10(root='./data', train=True, download=True, transform=transform)
        test_dataset = CIFAR10(root='./data', train=False, download=True, transform=transform)
    elif dataset_name == "fashion_mnist":
        transform = transforms.Compose([
            transforms.Grayscale(num_output_channels=3),
            transforms.Resize((32, 32)),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ])
        train_dataset = FashionMNIST(root='./data', train=True, download=True, transform=transform)
        test_dataset = FashionMNIST(root='./data', train=False, download=True, transform=transform)
    elif dataset_name == "svhn":
        train_dataset = SVHN(root='./data', split='train', download=True, transform=transform)
        test_dataset = SVHN(root='./data', split='test', download=True, transform=transform)
    else:
        raise ValueError(f"Unknown dataset: {dataset_name}")
 
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
    return train_loader, test_loader
 
# Updated client_fn to avoid using Context's client_id
def client_fn(cid: str) -> fl.client.Client:
    if cid == "0":
        return cifar_client
    elif cid == "1":
        return fashion_client
    elif cid == "2":
        return svhn_client
 
# Start federated learning
def start_federated_learning():
    global cifar_client, fashion_client, svhn_client
    cifar_train_loader, cifar_test_loader = load_data("cifar10")
    fashion_train_loader, fashion_test_loader = load_data("fashion_mnist")
    svhn_train_loader, svhn_test_loader = load_data("svhn")
 
    cifar_client = FLClient(CNN_Model("cifar10"), cifar_train_loader, cifar_test_loader, "CIFAR-10")
    fashion_client = FLClient(CNN_Model("fashion_mnist"), fashion_train_loader, fashion_test_loader, "Fashion-MNIST")
    svhn_client = FLClient(CNN_Model("svhn"), svhn_train_loader, svhn_test_loader, "SVHN")
 
    fl.simulation.start_simulation(
        client_fn=client_fn,
        num_clients=3,
        client_resources={"num_cpus": 2, "num_gpus": 1},
        config=fl.server.ServerConfig(num_rounds=5),  # Increased number of rounds
        strategy=fl.server.strategy.FedAvg(
            fraction_fit=1.0,
            min_fit_clients=3,
            min_available_clients=3,
        )
    )
 
if __name__ == "__main__":
    start_federated_learning()

Files already downloaded and verified
Files already downloaded and verified
Using downloaded and verified file: ./data/train_32x32.mat
Using downloaded and verified file: ./data/test_32x32.mat


[92mINFO [0m:      Starting Flower simulation, config: num_rounds=5, no round_timeout
2024-12-11 19:18:09,185	INFO worker.py:1752 -- Started a local Ray instance.
[92mINFO [0m:      Flower VCE: Ray initialized with resources: {'accelerator_type:P4000': 1.0, 'node:__internal_head__': 1.0, 'memory': 9905858151.0, 'node:131.227.65.97': 1.0, 'CPU': 12.0, 'GPU': 1.0, 'object_store_memory': 4952929075.0}
[92mINFO [0m:      Optimize your simulation with Flower VCE: https://flower.ai/docs/framework/how-to-run-simulations.html
[92mINFO [0m:      Flower VCE: Resources for each Virtual Client: {'num_cpus': 2, 'num_gpus': 1}
[92mINFO [0m:      Flower VCE: Creating VirtualClientEngineActorPool with 1 actors
[92mINFO [0m:      [INIT]
[92mINFO [0m:      Requesting initial parameters from one random client
[36m(ClientAppActor pid=3499243)[0m   return torch.load(io.BytesIO(b))
[36m(ClientAppActor pid=3499243)[0m   return torch.load(io.BytesIO(b))
[92mINFO [0m:      Received initial 

[36m(ClientAppActor pid=3499243)[0m Client CIFAR-10: Accuracy = 0.1501


[36m(ClientAppActor pid=3499243)[0m 
[36m(ClientAppActor pid=3499243)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3499243)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3499243)[0m         


[36m(ClientAppActor pid=3499243)[0m Client Fashion-MNIST: Accuracy = 0.4043


[36m(ClientAppActor pid=3499243)[0m 
[36m(ClientAppActor pid=3499243)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3499243)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3499243)[0m         
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 2]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)


[36m(ClientAppActor pid=3499243)[0m Client SVHN: Accuracy = 0.2884142593730793


[36m(ClientAppActor pid=3499243)[0m 
[36m(ClientAppActor pid=3499243)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3499243)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3499243)[0m         
[36m(ClientAppActor pid=3499243)[0m 
[36m(ClientAppActor pid=3499243)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3499243)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3499243)[0m         
[36m(ClientAppActor pid=3499243)[0m 
[36m(ClientAppActor pid=3499243)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3499243)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3499243)[0m         
[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=3499

[36m(ClientAppActor pid=3499243)[0m Client CIFAR-10: Accuracy = 0.1711


[36m(ClientAppActor pid=3499243)[0m 
[36m(ClientAppActor pid=3499243)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3499243)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3499243)[0m         


[36m(ClientAppActor pid=3499243)[0m Client SVHN: Accuracy = 0.3879840196681008


[36m(ClientAppActor pid=3499243)[0m 
[36m(ClientAppActor pid=3499243)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3499243)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3499243)[0m         
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 3]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)


[36m(ClientAppActor pid=3499243)[0m Client Fashion-MNIST: Accuracy = 0.6743


[36m(ClientAppActor pid=3499243)[0m 
[36m(ClientAppActor pid=3499243)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3499243)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3499243)[0m         
[36m(ClientAppActor pid=3499243)[0m 
[36m(ClientAppActor pid=3499243)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3499243)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3499243)[0m         
[36m(ClientAppActor pid=3499243)[0m 
[36m(ClientAppActor pid=3499243)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3499243)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3499243)[0m         
[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=3499

[36m(ClientAppActor pid=3499243)[0m Client Fashion-MNIST: Accuracy = 0.6861


[36m(ClientAppActor pid=3499243)[0m 
[36m(ClientAppActor pid=3499243)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3499243)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3499243)[0m         


[36m(ClientAppActor pid=3499243)[0m Client SVHN: Accuracy = 0.432659803318992


[36m(ClientAppActor pid=3499243)[0m 
[36m(ClientAppActor pid=3499243)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3499243)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3499243)[0m         
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 4]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)


[36m(ClientAppActor pid=3499243)[0m Client CIFAR-10: Accuracy = 0.2122


[36m(ClientAppActor pid=3499243)[0m 
[36m(ClientAppActor pid=3499243)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3499243)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3499243)[0m         
[36m(ClientAppActor pid=3499243)[0m 
[36m(ClientAppActor pid=3499243)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3499243)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3499243)[0m         
[36m(ClientAppActor pid=3499243)[0m 
[36m(ClientAppActor pid=3499243)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3499243)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3499243)[0m         
[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=3499

[36m(ClientAppActor pid=3499243)[0m Client CIFAR-10: Accuracy = 0.2157


[36m(ClientAppActor pid=3499243)[0m 
[36m(ClientAppActor pid=3499243)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3499243)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3499243)[0m         


[36m(ClientAppActor pid=3499243)[0m Client Fashion-MNIST: Accuracy = 0.7149


[36m(ClientAppActor pid=3499243)[0m 
[36m(ClientAppActor pid=3499243)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3499243)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3499243)[0m         
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 5]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)


[36m(ClientAppActor pid=3499243)[0m Client SVHN: Accuracy = 0.4506760909649662


[36m(ClientAppActor pid=3499243)[0m 
[36m(ClientAppActor pid=3499243)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3499243)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3499243)[0m         
[36m(ClientAppActor pid=3499243)[0m 
[36m(ClientAppActor pid=3499243)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3499243)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3499243)[0m         
[36m(ClientAppActor pid=3499243)[0m 
[36m(ClientAppActor pid=3499243)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3499243)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3499243)[0m         
[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=3499

[36m(ClientAppActor pid=3499243)[0m Client Fashion-MNIST: Accuracy = 0.7145


[36m(ClientAppActor pid=3499243)[0m 
[36m(ClientAppActor pid=3499243)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3499243)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3499243)[0m         


[36m(ClientAppActor pid=3499243)[0m Client SVHN: Accuracy = 0.4594345421020283


[36m(ClientAppActor pid=3499243)[0m 
[36m(ClientAppActor pid=3499243)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3499243)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3499243)[0m         
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [SUMMARY]
[92mINFO [0m:      Run finished 5 round(s) in 1631.11s
[92mINFO [0m:      	History (loss, distributed):
[92mINFO [0m:      		round 1: 2.2919704759266084
[92mINFO [0m:      		round 2: 2.2569667891916025
[92mINFO [0m:      		round 3: 2.213540887183102
[92mINFO [0m:      		round 4: 2.1674095973126404
[92mINFO [0m:      		round 5: 2.1347596536828917
[92mINFO [0m:      


[36m(ClientAppActor pid=3499243)[0m Client CIFAR-10: Accuracy = 0.247


In [8]:
import flwr as fl
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10, FashionMNIST, SVHN

# Constants
BATCH_SIZE = 128
LEARNING_RATE = 0.001
EPOCHS = 10
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define the model with dataset-specific layers
class CNN_Model(nn.Module):
    def __init__(self, dataset_name):
        super(CNN_Model, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.pool = nn.AvgPool2d(2, 2)
        self.fc_config = nn.Linear(64 * 8 * 8, 128)

        # Dataset-specific layers
        if dataset_name == "cifar10":
            self.fc_op = nn.Sequential(
                nn.Linear(128, 64),
                nn.ReLU(),
                nn.Linear(64, 10)
            )
        elif dataset_name == "fashion_mnist":
            self.fc_op = nn.Sequential(
                nn.Linear(128, 32),
                nn.ReLU(),
                nn.Linear(32, 10)
            )
        elif dataset_name == "svhn":
            self.fc_op = nn.Sequential(
                nn.Linear(128, 64),
                nn.ReLU(),
                nn.Linear(64, 10)
            )

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool(x)
        x = F.relu(self.conv2(x))
        x = self.pool(x)
        x = x.view(-1, 64 * 8 * 8)
        x = F.relu(self.fc_config(x))
        x = self.fc_op(x)
        return F.softmax(x, dim=1)

# Define the Flower client with FedProx
class FLClient(fl.client.NumPyClient):
    def __init__(self, model, train_loader, test_loader, client_id):
        self.model = model.to(DEVICE)
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = optim.Adam(self.model.parameters(), lr=LEARNING_RATE)
        self.client_id = client_id

    def get_parameters(self, config=None):
        params = {
            k: v.cpu().numpy()
            for k, v in self.model.state_dict().items()
            if not k.startswith("fc_op")
        }
        return list(params.values())

    def set_parameters(self, parameters):
        state_dict = self.model.state_dict()
        shared_keys = [k for k in state_dict.keys() if not k.startswith("fc_op")]
        for key, param in zip(shared_keys, parameters):
            state_dict[key] = torch.tensor(param)
        self.model.load_state_dict(state_dict, strict=False)

    def fit(self, parameters, config):
        # Set global parameters on the local model
        self.set_parameters(parameters)
        self.model.train()
        
        # Save global model parameters for proximity
        global_params = [param.clone().detach().to(DEVICE) for param in self.model.parameters()]
        
        for epoch in range(EPOCHS):
            for images, labels in self.train_loader:
                images, labels = images.to(DEVICE), labels.to(DEVICE)
                self.optimizer.zero_grad()
                outputs = self.model(images)
                loss = self.criterion(outputs, labels)
                
                # Add proximal term to the loss
                prox_term = 0.0
                for param, global_param in zip(self.model.parameters(), global_params):
                    prox_term += torch.norm(param - global_param) ** 2
                loss += (0.01 / 2) * prox_term  # 0.01 is the proximal coefficient
                
                loss.backward()
                self.optimizer.step()
        
        return self.get_parameters(config), len(self.train_loader.dataset), {}

    def evaluate(self, parameters, config):
        self.set_parameters(parameters)
        self.model.eval()
        correct, total = 0, 0
        loss = 0.0
        with torch.no_grad():
            for images, labels in self.test_loader:
                images, labels = images.to(DEVICE), labels.to(DEVICE)
                outputs = self.model(images)
                loss += self.criterion(outputs, labels).item()
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        accuracy = correct / total
        print(f"Client {self.client_id}: Accuracy = {accuracy}")
        return float(loss) / len(self.test_loader), len(self.test_loader.dataset), {"accuracy": accuracy}

# Load the datasets
def load_data(dataset_name):
    transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    if dataset_name == "cifar10":
        train_dataset = CIFAR10(root='./data', train=True, download=True, transform=transform)
        test_dataset = CIFAR10(root='./data', train=False, download=True, transform=transform)
    elif dataset_name == "fashion_mnist": 
        transform = transforms.Compose([
            transforms.Grayscale(num_output_channels=3),
            transforms.Resize((32, 32)),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ])
        train_dataset = FashionMNIST(root='./data', train=True, download=True, transform=transform)
        test_dataset = FashionMNIST(root='./data', train=False, download=True, transform=transform)
    elif dataset_name == "svhn":
        train_dataset = SVHN(root='./data', split='train', download=True, transform=transform)
        test_dataset = SVHN(root='./data', split='test', download=True, transform=transform)
    else:
        raise ValueError(f"Unknown dataset: {dataset_name}")

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
    return train_loader, test_loader

# Updated client_fn to avoid using Context's client_id
def client_fn(cid: str) -> fl.client.Client:
    if cid == "0":
        return cifar_client
    elif cid == "1":
        return fashion_client
    elif cid == "2":
        return svhn_client

# Start federated learning
def start_federated_learning():
    global cifar_client, fashion_client, svhn_client
    cifar_train_loader, cifar_test_loader = load_data("cifar10")
    fashion_train_loader, fashion_test_loader = load_data("fashion_mnist")
    svhn_train_loader, svhn_test_loader = load_data("svhn")

    cifar_client = FLClient(CNN_Model("cifar10"), cifar_train_loader, cifar_test_loader, "CIFAR-10")
    fashion_client = FLClient(CNN_Model("fashion_mnist"), fashion_train_loader, fashion_test_loader, "Fashion-MNIST")
    svhn_client = FLClient(CNN_Model("svhn"), svhn_train_loader, svhn_test_loader, "SVHN")

    fl.simulation.start_simulation(
        client_fn=client_fn,
        num_clients=3,
        client_resources={"num_cpus": 4, "num_gpus": 1},
        config=fl.server.ServerConfig(num_rounds=20),  # Increased number of rounds
        strategy=fl.server.strategy.FedAvg(
            fraction_fit=1.0,
            min_fit_clients=3,
            min_available_clients=3,
        )
    )

if __name__ == "__main__":
    start_federated_learning()


Files already downloaded and verified
Files already downloaded and verified
Using downloaded and verified file: ./data/train_32x32.mat
Using downloaded and verified file: ./data/test_32x32.mat


[92mINFO [0m:      Starting Flower simulation, config: num_rounds=20, no round_timeout
2024-12-04 19:38:04,941	INFO worker.py:1752 -- Started a local Ray instance.
[92mINFO [0m:      Flower VCE: Ray initialized with resources: {'memory': 11841161627.0, 'object_store_memory': 5920580812.0, 'GPU': 1.0, 'CPU': 12.0, 'node:131.227.65.97': 1.0, 'node:__internal_head__': 1.0, 'accelerator_type:P4000': 1.0}
[92mINFO [0m:      Optimize your simulation with Flower VCE: https://flower.ai/docs/framework/how-to-run-simulations.html
[92mINFO [0m:      Flower VCE: Resources for each Virtual Client: {'num_cpus': 4, 'num_gpus': 1}
[92mINFO [0m:      Flower VCE: Creating VirtualClientEngineActorPool with 1 actors
[92mINFO [0m:      [INIT]
[92mINFO [0m:      Requesting initial parameters from one random client
[36m(ClientAppActor pid=2062115)[0m   return torch.load(io.BytesIO(b))
[36m(ClientAppActor pid=2062115)[0m   return torch.load(io.BytesIO(b))
[92mINFO [0m:      Received initia

[36m(ClientAppActor pid=2062115)[0m Client Fashion-MNIST: Accuracy = 0.2716


[36m(ClientAppActor pid=2062115)[0m 
[36m(ClientAppActor pid=2062115)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2062115)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2062115)[0m         


[36m(ClientAppActor pid=2062115)[0m Client CIFAR-10: Accuracy = 0.134


[36m(ClientAppActor pid=2062115)[0m 
[36m(ClientAppActor pid=2062115)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2062115)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2062115)[0m         
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 2]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)


[36m(ClientAppActor pid=2062115)[0m Client SVHN: Accuracy = 0.15565457897971727


[36m(ClientAppActor pid=2062115)[0m 
[36m(ClientAppActor pid=2062115)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2062115)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2062115)[0m         
[36m(ClientAppActor pid=2062115)[0m 
[36m(ClientAppActor pid=2062115)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2062115)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2062115)[0m         
[36m(ClientAppActor pid=2062115)[0m 
[36m(ClientAppActor pid=2062115)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2062115)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2062115)[0m         
[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=2062

[36m(ClientAppActor pid=2062115)[0m Client Fashion-MNIST: Accuracy = 0.3026


[36m(ClientAppActor pid=2062115)[0m 
[36m(ClientAppActor pid=2062115)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2062115)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2062115)[0m         


[36m(ClientAppActor pid=2062115)[0m Client CIFAR-10: Accuracy = 0.1729


[36m(ClientAppActor pid=2062115)[0m 
[36m(ClientAppActor pid=2062115)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2062115)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2062115)[0m         
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 3]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)


[36m(ClientAppActor pid=2062115)[0m Client SVHN: Accuracy = 0.15872771972956362


[36m(ClientAppActor pid=2062115)[0m 
[36m(ClientAppActor pid=2062115)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2062115)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2062115)[0m         
[36m(ClientAppActor pid=2062115)[0m 
[36m(ClientAppActor pid=2062115)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2062115)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2062115)[0m         
[36m(ClientAppActor pid=2062115)[0m 
[36m(ClientAppActor pid=2062115)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2062115)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2062115)[0m         
[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=2062

[36m(ClientAppActor pid=2062115)[0m Client CIFAR-10: Accuracy = 0.1893


[36m(ClientAppActor pid=2062115)[0m 
[36m(ClientAppActor pid=2062115)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2062115)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2062115)[0m         


[36m(ClientAppActor pid=2062115)[0m Client Fashion-MNIST: Accuracy = 0.4575


[36m(ClientAppActor pid=2062115)[0m 
[36m(ClientAppActor pid=2062115)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2062115)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2062115)[0m         
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 4]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)


[36m(ClientAppActor pid=2062115)[0m Client SVHN: Accuracy = 0.2143899815611555


[36m(ClientAppActor pid=2062115)[0m 
[36m(ClientAppActor pid=2062115)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2062115)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2062115)[0m         
[36m(ClientAppActor pid=2062115)[0m 
[36m(ClientAppActor pid=2062115)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2062115)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2062115)[0m         
[36m(ClientAppActor pid=2062115)[0m 
[36m(ClientAppActor pid=2062115)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2062115)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2062115)[0m         
[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=2062

[36m(ClientAppActor pid=2062115)[0m Client CIFAR-10: Accuracy = 0.2034


[36m(ClientAppActor pid=2062115)[0m 
[36m(ClientAppActor pid=2062115)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2062115)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2062115)[0m         


[36m(ClientAppActor pid=2062115)[0m Client Fashion-MNIST: Accuracy = 0.4619


[36m(ClientAppActor pid=2062115)[0m 
[36m(ClientAppActor pid=2062115)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2062115)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2062115)[0m         
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 5]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)


[36m(ClientAppActor pid=2062115)[0m Client SVHN: Accuracy = 0.2825752919483712


[36m(ClientAppActor pid=2062115)[0m 
[36m(ClientAppActor pid=2062115)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2062115)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2062115)[0m         
[36m(ClientAppActor pid=2062115)[0m 
[36m(ClientAppActor pid=2062115)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2062115)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2062115)[0m         
[36m(ClientAppActor pid=2062115)[0m 
[36m(ClientAppActor pid=2062115)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2062115)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2062115)[0m         
[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=2062

[36m(ClientAppActor pid=2062115)[0m Client SVHN: Accuracy = 0.2974031960663798


[36m(ClientAppActor pid=2062115)[0m 
[36m(ClientAppActor pid=2062115)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2062115)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2062115)[0m         


[36m(ClientAppActor pid=2062115)[0m Client Fashion-MNIST: Accuracy = 0.4824


[36m(ClientAppActor pid=2062115)[0m 
[36m(ClientAppActor pid=2062115)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2062115)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2062115)[0m         
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 6]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)


[36m(ClientAppActor pid=2062115)[0m Client CIFAR-10: Accuracy = 0.2218


[36m(ClientAppActor pid=2062115)[0m 
[36m(ClientAppActor pid=2062115)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2062115)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2062115)[0m         
[36m(ClientAppActor pid=2062115)[0m 
[36m(ClientAppActor pid=2062115)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2062115)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2062115)[0m         
[36m(ClientAppActor pid=2062115)[0m 
[36m(ClientAppActor pid=2062115)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2062115)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2062115)[0m         
[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=2062

[36m(ClientAppActor pid=2062115)[0m Client CIFAR-10: Accuracy = 0.2323


[36m(ClientAppActor pid=2062115)[0m 
[36m(ClientAppActor pid=2062115)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2062115)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2062115)[0m         


[36m(ClientAppActor pid=2062115)[0m Client Fashion-MNIST: Accuracy = 0.4421


[36m(ClientAppActor pid=2062115)[0m 
[36m(ClientAppActor pid=2062115)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2062115)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2062115)[0m         
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 7]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)


[36m(ClientAppActor pid=2062115)[0m Client SVHN: Accuracy = 0.3401582667486171


[36m(ClientAppActor pid=2062115)[0m 
[36m(ClientAppActor pid=2062115)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2062115)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2062115)[0m         
[36m(ClientAppActor pid=2062115)[0m 
[36m(ClientAppActor pid=2062115)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2062115)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2062115)[0m         
[36m(ClientAppActor pid=2062115)[0m 
[36m(ClientAppActor pid=2062115)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2062115)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2062115)[0m         
[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=2062

[36m(ClientAppActor pid=2062115)[0m Client Fashion-MNIST: Accuracy = 0.4799


[36m(ClientAppActor pid=2062115)[0m 
[36m(ClientAppActor pid=2062115)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2062115)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2062115)[0m         


[36m(ClientAppActor pid=2062115)[0m Client SVHN: Accuracy = 0.3606714812538414


[36m(ClientAppActor pid=2062115)[0m 
[36m(ClientAppActor pid=2062115)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2062115)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2062115)[0m         
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 8]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)


[36m(ClientAppActor pid=2062115)[0m Client CIFAR-10: Accuracy = 0.2576


[36m(ClientAppActor pid=2062115)[0m 
[36m(ClientAppActor pid=2062115)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2062115)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2062115)[0m         
[36m(ClientAppActor pid=2062115)[0m 
[36m(ClientAppActor pid=2062115)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2062115)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2062115)[0m         
[36m(ClientAppActor pid=2062115)[0m 
[36m(ClientAppActor pid=2062115)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2062115)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2062115)[0m         
[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=2062

[36m(ClientAppActor pid=2062115)[0m Client Fashion-MNIST: Accuracy = 0.4748


[36m(ClientAppActor pid=2062115)[0m 
[36m(ClientAppActor pid=2062115)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2062115)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2062115)[0m         


[36m(ClientAppActor pid=2062115)[0m Client SVHN: Accuracy = 0.37188844499078055


[36m(ClientAppActor pid=2062115)[0m 
[36m(ClientAppActor pid=2062115)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2062115)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2062115)[0m         
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 9]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)


[36m(ClientAppActor pid=2062115)[0m Client CIFAR-10: Accuracy = 0.2529


[36m(ClientAppActor pid=2062115)[0m 
[36m(ClientAppActor pid=2062115)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2062115)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2062115)[0m         
[36m(ClientAppActor pid=2062115)[0m 
[36m(ClientAppActor pid=2062115)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2062115)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2062115)[0m         
[36m(ClientAppActor pid=2062115)[0m 
[36m(ClientAppActor pid=2062115)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2062115)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2062115)[0m         
[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=2062

[36m(ClientAppActor pid=2062115)[0m Client SVHN: Accuracy = 0.36528119237861095


[36m(ClientAppActor pid=2062115)[0m 
[36m(ClientAppActor pid=2062115)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2062115)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2062115)[0m         


[36m(ClientAppActor pid=2062115)[0m Client CIFAR-10: Accuracy = 0.2563


[36m(ClientAppActor pid=2062115)[0m 
[36m(ClientAppActor pid=2062115)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2062115)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2062115)[0m         
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 10]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)


[36m(ClientAppActor pid=2062115)[0m Client Fashion-MNIST: Accuracy = 0.4763


[36m(ClientAppActor pid=2062115)[0m 
[36m(ClientAppActor pid=2062115)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2062115)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2062115)[0m         
[36m(ClientAppActor pid=2062115)[0m 
[36m(ClientAppActor pid=2062115)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2062115)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2062115)[0m         
[36m(ClientAppActor pid=2062115)[0m 
[36m(ClientAppActor pid=2062115)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2062115)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2062115)[0m         
[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=2062

[36m(ClientAppActor pid=2062115)[0m Client CIFAR-10: Accuracy = 0.2567


[36m(ClientAppActor pid=2062115)[0m 
[36m(ClientAppActor pid=2062115)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2062115)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2062115)[0m         


[36m(ClientAppActor pid=2062115)[0m Client Fashion-MNIST: Accuracy = 0.4991


[36m(ClientAppActor pid=2062115)[0m 
[36m(ClientAppActor pid=2062115)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2062115)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2062115)[0m         
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 11]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)


[36m(ClientAppActor pid=2062115)[0m Client SVHN: Accuracy = 0.39278580208973574


[36m(ClientAppActor pid=2062115)[0m 
[36m(ClientAppActor pid=2062115)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2062115)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2062115)[0m         
[36m(ClientAppActor pid=2062115)[0m 
[36m(ClientAppActor pid=2062115)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2062115)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2062115)[0m         
[36m(ClientAppActor pid=2062115)[0m 
[36m(ClientAppActor pid=2062115)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2062115)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2062115)[0m         
[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=2062

[36m(ClientAppActor pid=2062115)[0m Client CIFAR-10: Accuracy = 0.2412


[36m(ClientAppActor pid=2062115)[0m 
[36m(ClientAppActor pid=2062115)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2062115)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2062115)[0m         


[36m(ClientAppActor pid=2062115)[0m Client SVHN: Accuracy = 0.41629532882606024


[36m(ClientAppActor pid=2062115)[0m 
[36m(ClientAppActor pid=2062115)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2062115)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2062115)[0m         
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 12]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)


[36m(ClientAppActor pid=2062115)[0m Client Fashion-MNIST: Accuracy = 0.4922


[36m(ClientAppActor pid=2062115)[0m 
[36m(ClientAppActor pid=2062115)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2062115)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2062115)[0m         
[36m(ClientAppActor pid=2062115)[0m 
[36m(ClientAppActor pid=2062115)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2062115)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2062115)[0m         
[36m(ClientAppActor pid=2062115)[0m 
[36m(ClientAppActor pid=2062115)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2062115)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2062115)[0m         
[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=2062

[36m(ClientAppActor pid=2062115)[0m Client SVHN: Accuracy = 0.41256914566687153


[36m(ClientAppActor pid=2062115)[0m 
[36m(ClientAppActor pid=2062115)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2062115)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2062115)[0m         


[36m(ClientAppActor pid=2062115)[0m Client CIFAR-10: Accuracy = 0.2526


[36m(ClientAppActor pid=2062115)[0m 
[36m(ClientAppActor pid=2062115)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2062115)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2062115)[0m         
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 13]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)


[36m(ClientAppActor pid=2062115)[0m Client Fashion-MNIST: Accuracy = 0.4727


[36m(ClientAppActor pid=2062115)[0m 
[36m(ClientAppActor pid=2062115)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2062115)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2062115)[0m         
[36m(ClientAppActor pid=2062115)[0m 
[36m(ClientAppActor pid=2062115)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2062115)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2062115)[0m         
[36m(ClientAppActor pid=2062115)[0m 
[36m(ClientAppActor pid=2062115)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2062115)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2062115)[0m         
[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=2062

[36m(ClientAppActor pid=2062115)[0m Client CIFAR-10: Accuracy = 0.2509


[36m(ClientAppActor pid=2062115)[0m 
[36m(ClientAppActor pid=2062115)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2062115)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2062115)[0m         


[36m(ClientAppActor pid=2062115)[0m Client SVHN: Accuracy = 0.4152581438229871


[36m(ClientAppActor pid=2062115)[0m 
[36m(ClientAppActor pid=2062115)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2062115)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2062115)[0m         
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 14]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)


[36m(ClientAppActor pid=2062115)[0m Client Fashion-MNIST: Accuracy = 0.4394


[36m(ClientAppActor pid=2062115)[0m 
[36m(ClientAppActor pid=2062115)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2062115)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2062115)[0m         
[36m(ClientAppActor pid=2062115)[0m 
[36m(ClientAppActor pid=2062115)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2062115)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2062115)[0m         
[36m(ClientAppActor pid=2062115)[0m 
[36m(ClientAppActor pid=2062115)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2062115)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2062115)[0m         
[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=2062

[36m(ClientAppActor pid=2062115)[0m Client SVHN: Accuracy = 0.4187538414259373


[36m(ClientAppActor pid=2062115)[0m 
[36m(ClientAppActor pid=2062115)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2062115)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2062115)[0m         


[36m(ClientAppActor pid=2062115)[0m Client Fashion-MNIST: Accuracy = 0.4504


[36m(ClientAppActor pid=2062115)[0m 
[36m(ClientAppActor pid=2062115)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2062115)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2062115)[0m         
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 15]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)


[36m(ClientAppActor pid=2062115)[0m Client CIFAR-10: Accuracy = 0.2487


[36m(ClientAppActor pid=2062115)[0m 
[36m(ClientAppActor pid=2062115)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2062115)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2062115)[0m         
[36m(ClientAppActor pid=2062115)[0m 
[36m(ClientAppActor pid=2062115)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2062115)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2062115)[0m         
[36m(ClientAppActor pid=2062115)[0m 
[36m(ClientAppActor pid=2062115)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2062115)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2062115)[0m         
[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=2062

[36m(ClientAppActor pid=2062115)[0m Client Fashion-MNIST: Accuracy = 0.4661


[36m(ClientAppActor pid=2062115)[0m 
[36m(ClientAppActor pid=2062115)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2062115)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2062115)[0m         


[36m(ClientAppActor pid=2062115)[0m Client CIFAR-10: Accuracy = 0.2506


[36m(ClientAppActor pid=2062115)[0m 
[36m(ClientAppActor pid=2062115)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2062115)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2062115)[0m         
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 16]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)


[36m(ClientAppActor pid=2062115)[0m Client SVHN: Accuracy = 0.4287415488629379


[36m(ClientAppActor pid=2062115)[0m 
[36m(ClientAppActor pid=2062115)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2062115)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2062115)[0m         
[36m(ClientAppActor pid=2062115)[0m 
[36m(ClientAppActor pid=2062115)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2062115)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2062115)[0m         
[36m(ClientAppActor pid=2062115)[0m 
[36m(ClientAppActor pid=2062115)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2062115)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2062115)[0m         
[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=2062

[36m(ClientAppActor pid=2062115)[0m Client CIFAR-10: Accuracy = 0.2511


[36m(ClientAppActor pid=2062115)[0m 
[36m(ClientAppActor pid=2062115)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2062115)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2062115)[0m         


[36m(ClientAppActor pid=2062115)[0m Client SVHN: Accuracy = 0.43012446220036876


[36m(ClientAppActor pid=2062115)[0m 
[36m(ClientAppActor pid=2062115)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2062115)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2062115)[0m         
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 17]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)


[36m(ClientAppActor pid=2062115)[0m Client Fashion-MNIST: Accuracy = 0.4632


[36m(ClientAppActor pid=2062115)[0m 
[36m(ClientAppActor pid=2062115)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2062115)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2062115)[0m         
[36m(ClientAppActor pid=2062115)[0m 
[36m(ClientAppActor pid=2062115)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2062115)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2062115)[0m         
[36m(ClientAppActor pid=2062115)[0m 
[36m(ClientAppActor pid=2062115)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2062115)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2062115)[0m         
[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=2062

[36m(ClientAppActor pid=2062115)[0m Client CIFAR-10: Accuracy = 0.2505


[36m(ClientAppActor pid=2062115)[0m 
[36m(ClientAppActor pid=2062115)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2062115)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2062115)[0m         


[36m(ClientAppActor pid=2062115)[0m Client Fashion-MNIST: Accuracy = 0.4881


[36m(ClientAppActor pid=2062115)[0m 
[36m(ClientAppActor pid=2062115)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2062115)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2062115)[0m         
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 18]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)


[36m(ClientAppActor pid=2062115)[0m Client SVHN: Accuracy = 0.4208282114320836


[36m(ClientAppActor pid=2062115)[0m 
[36m(ClientAppActor pid=2062115)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2062115)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2062115)[0m         
[36m(ClientAppActor pid=2062115)[0m 
[36m(ClientAppActor pid=2062115)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2062115)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2062115)[0m         
[36m(ClientAppActor pid=2062115)[0m 
[36m(ClientAppActor pid=2062115)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2062115)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2062115)[0m         
[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=2062

[36m(ClientAppActor pid=2062115)[0m Client SVHN: Accuracy = 0.42904886293792255


[36m(ClientAppActor pid=2062115)[0m 
[36m(ClientAppActor pid=2062115)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2062115)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2062115)[0m         


[36m(ClientAppActor pid=2062115)[0m Client CIFAR-10: Accuracy = 0.248


[36m(ClientAppActor pid=2062115)[0m 
[36m(ClientAppActor pid=2062115)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2062115)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2062115)[0m         
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 19]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)


[36m(ClientAppActor pid=2062115)[0m Client Fashion-MNIST: Accuracy = 0.4409


[36m(ClientAppActor pid=2062115)[0m 
[36m(ClientAppActor pid=2062115)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2062115)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2062115)[0m         
[36m(ClientAppActor pid=2062115)[0m 
[36m(ClientAppActor pid=2062115)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2062115)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2062115)[0m         
[36m(ClientAppActor pid=2062115)[0m 
[36m(ClientAppActor pid=2062115)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2062115)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2062115)[0m         
[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=2062

[36m(ClientAppActor pid=2062115)[0m Client Fashion-MNIST: Accuracy = 0.4734


[36m(ClientAppActor pid=2062115)[0m 
[36m(ClientAppActor pid=2062115)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2062115)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2062115)[0m         


[36m(ClientAppActor pid=2062115)[0m Client SVHN: Accuracy = 0.42828057775046097


[36m(ClientAppActor pid=2062115)[0m 
[36m(ClientAppActor pid=2062115)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2062115)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2062115)[0m         
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 20]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)


[36m(ClientAppActor pid=2062115)[0m Client CIFAR-10: Accuracy = 0.2482


[36m(ClientAppActor pid=2062115)[0m 
[36m(ClientAppActor pid=2062115)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2062115)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2062115)[0m         
[36m(ClientAppActor pid=2062115)[0m 
[36m(ClientAppActor pid=2062115)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2062115)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2062115)[0m         
[36m(ClientAppActor pid=2062115)[0m 
[36m(ClientAppActor pid=2062115)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2062115)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2062115)[0m         
[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=2062

[36m(ClientAppActor pid=2062115)[0m Client Fashion-MNIST: Accuracy = 0.4806


[36m(ClientAppActor pid=2062115)[0m 
[36m(ClientAppActor pid=2062115)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2062115)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2062115)[0m         


[36m(ClientAppActor pid=2062115)[0m Client CIFAR-10: Accuracy = 0.2528


[36m(ClientAppActor pid=2062115)[0m 
[36m(ClientAppActor pid=2062115)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2062115)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2062115)[0m         
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [SUMMARY]
[92mINFO [0m:      Run finished 20 round(s) in 9588.11s
[92mINFO [0m:      	History (loss, distributed):
[92mINFO [0m:      		round 1: 2.292483091673799
[92mINFO [0m:      		round 2: 2.2738414826886726
[92mINFO [0m:      		round 3: 2.2569204008615764
[92mINFO [0m:      		round 4: 2.2354987290361032
[92mINFO [0m:      		round 5: 2.21899238995141
[92mINFO [0m:      		round 6: 2.2032743558466485
[92mINFO [0m:      		round 7: 2.186733261612544
[92mINFO [0m:      		round 8: 2.174674830982889
[92mINFO [0m:      		round 9: 2.169691954038025
[92mINFO [0m:      		round 10: 2.1

[36m(ClientAppActor pid=2062115)[0m Client SVHN: Accuracy = 0.4267440073755378


In [1]:
#####SWATSSS

In [2]:
import flwr as fl

import torch

import torch.nn as nn

import torch.optim as optim

import torch.nn.functional as F

from torch.utils.data import DataLoader

import torchvision.transforms as transforms 

from torchvision.datasets import CIFAR10, FashionMNIST, SVHN
 
# Constants

BATCH_SIZE = 64

LEARNING_RATE = 0.001

EPOCHS = 5

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

SWITCH_EPOCH = 3  # The epoch at which we switch from Adam to SGD
 
# Define the model with dataset-specific layers

class CNN_Model(nn.Module):

    def __init__(self, dataset_name):

        super(CNN_Model, self).__init__()

        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)

        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)

        self.pool = nn.AvgPool2d(2, 2)

        self.fc_config = nn.Linear(64 * 8 * 8, 128)
 
        # Dataset-specific layers

        if dataset_name == "cifar10":

            self.fc_op = nn.Sequential(

                nn.Linear(128, 64),

                nn.ReLU(),

                nn.Linear(64, 10)

            )

        elif dataset_name == "fashion_mnist":

            self.fc_op = nn.Sequential(

                nn.Linear(128, 32),

                nn.ReLU(),

                nn.Linear(32, 10)

            )

        elif dataset_name == "svhn":

            self.fc_op = nn.Sequential(

                nn.Linear(128, 64),

                nn.ReLU(),

                nn.Linear(64, 10)

            )
 
    def forward(self, x):

        x = F.relu(self.conv1(x))

        x = self.pool(x)

        x = F.relu(self.conv2(x))

        x = self.pool(x)

        x = x.view(-1, 64 * 8 * 8)

        x = F.relu(self.fc_config(x))

        x = self.fc_op(x)

        return F.softmax(x, dim=1)
 
# Define the Flower client

class FLClient(fl.client.NumPyClient):

    def __init__(self, model, train_loader, test_loader, client_id):

        self.model = model.to(DEVICE)

        self.train_loader = train_loader

        self.test_loader = test_loader

        self.criterion = nn.CrossEntropyLoss()

        self.optimizer = optim.Adam(self.model.parameters(), lr=LEARNING_RATE)

        self.sgd = None

        self.switched = False

        self.client_id = client_id
 
    def switch_to_sgd(self):

        # Switch to SGD optimizer

        self.sgd = optim.SGD(self.model.parameters(), lr=self.optimizer.param_groups[0]['lr'])

        self.switched = True

        print(f"Client {self.client_id}: Switched to SGD")
 
    def step_optimizer(self, epoch):

        # Switch from Adam to SGD if the conditions are met

        if not self.switched and epoch >= SWITCH_EPOCH:

            self.switch_to_sgd()

        return self.sgd if self.switched else self.optimizer
 
    def get_parameters(self, config=None):

        params = {

            k: v.cpu().numpy()

            for k, v in self.model.state_dict().items()

            if not k.startswith("fc_op")

        }

        return list(params.values())
 
    def set_parameters(self, parameters):

        state_dict = self.model.state_dict()

        shared_keys = [k for k in state_dict.keys() if not k.startswith("fc_op")]

        for key, param in zip(shared_keys, parameters):

            state_dict[key] = torch.tensor(param)

        self.model.load_state_dict(state_dict, strict=False)
 
    def fine_tune(self):

        self.model.train()

        for epoch in range(2):  # Fine-tune for 2 local epochs

            optimizer = self.step_optimizer(epoch)

            for images, labels in self.train_loader:

                images, labels = images.to(DEVICE), labels.to(DEVICE)

                optimizer.zero_grad()

                outputs = self.model(images)

                loss = self.criterion(outputs, labels)

                loss.backward()

                optimizer.step()
 
    def fit(self, parameters, config):

        self.set_parameters(parameters)  # Set global model parameters

        self.model.train()

        for epoch in range(EPOCHS):  # Train on local data

            optimizer = self.step_optimizer(epoch)

            for images, labels in self.train_loader:

                images, labels = images.to(DEVICE), labels.to(DEVICE)

                optimizer.zero_grad()

                outputs = self.model(images)

                loss = self.criterion(outputs, labels)

                loss.backward()

                optimizer.step()

        # Perform fine-tuning

        self.fine_tune()
 
        return self.get_parameters(config), len(self.train_loader.dataset), {}
 
    def evaluate(self, parameters, config):

        self.set_parameters(parameters)

        self.model.eval()

        correct, total = 0, 0

        loss = 0.0

        with torch.no_grad():

            for images, labels in self.test_loader:

                images, labels = images.to(DEVICE), labels.to(DEVICE)

                outputs = self.model(images)

                loss += self.criterion(outputs, labels).item()

                _, predicted = torch.max(outputs.data, 1)

                total += labels.size(0)

                correct += (predicted == labels).sum().item()

        accuracy = correct / total

        print(f"Client {self.client_id}: Accuracy = {accuracy}")

        return float(loss) / len(self.test_loader), len(self.test_loader.dataset), {"accuracy": accuracy}
 
# Load the datasets

def load_data(dataset_name):

    transform = transforms.Compose([

        transforms.Resize((32, 32)),

        transforms.ToTensor(),

        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),

    ])

    if dataset_name == "cifar10":

        train_dataset = CIFAR10(root='./data', train=True, download=True, transform=transform)

        test_dataset = CIFAR10(root='./data', train=False, download=True, transform=transform)

    elif dataset_name == "fashion_mnist":

        transform = transforms.Compose([

            transforms.Grayscale(num_output_channels=3),

            transforms.Resize((32, 32)),

            transforms.ToTensor(),

            transforms.Normalize((0.5,), (0.5,)),

        ])

        train_dataset = FashionMNIST(root='./data', train=True, download=True, transform=transform)

        test_dataset = FashionMNIST(root='./data', train=False, download=True, transform=transform)

    elif dataset_name == "svhn":

        train_dataset = SVHN(root='./data', split='train', download=True, transform=transform)

        test_dataset = SVHN(root='./data', split='test', download=True, transform=transform)

    else:

        raise ValueError(f"Unknown dataset: {dataset_name}")
 
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

    return train_loader, test_loader
 
# Updated client_fn to avoid using Context's client_id

def client_fn(cid: str) -> fl.client.Client:

    if cid == "0":

        return cifar_client

    elif cid == "1":

        return fashion_client

    elif cid == "2":

        return svhn_client
 
# Start federated learning

def start_federated_learning():

    global cifar_client, fashion_client, svhn_client

    cifar_train_loader, cifar_test_loader = load_data("cifar10")

    fashion_train_loader, fashion_test_loader = load_data("fashion_mnist")

    svhn_train_loader, svhn_test_loader = load_data("svhn")
 
    cifar_client = FLClient(CNN_Model("cifar10"), cifar_train_loader, cifar_test_loader, "CIFAR-10")

    fashion_client = FLClient(CNN_Model("fashion_mnist"), fashion_train_loader, fashion_test_loader, "Fashion-MNIST")

    svhn_client = FLClient(CNN_Model("svhn"), svhn_train_loader, svhn_test_loader, "SVHN")
 
    fl.simulation.start_simulation(

        client_fn=client_fn,

        num_clients=3,

        client_resources={"num_cpus": 2, "num_gpus": 1},

        config=fl.server.ServerConfig(num_rounds=10),  # Increased number of rounds

        strategy=fl.server.strategy.FedAvg(

            fraction_fit=1.0,

            min_fit_clients=3,

            min_available_clients=3,

        )

    )
 
if __name__ == "__main__":

    start_federated_learning()

 

2024-12-09 18:51:43,831	INFO util.py:154 -- Outdated packages:
  ipywidgets==6.0.0 found, needs ipywidgets>=8
Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.


Files already downloaded and verified
Files already downloaded and verified
Using downloaded and verified file: ./data/train_32x32.mat
Using downloaded and verified file: ./data/test_32x32.mat


[92mINFO [0m:      Starting Flower simulation, config: num_rounds=10, no round_timeout
2024-12-09 18:53:15,213	INFO worker.py:1752 -- Started a local Ray instance.
[92mINFO [0m:      Flower VCE: Ray initialized with resources: {'GPU': 1.0, 'object_store_memory': 5514640588.0, 'memory': 11029281179.0, 'accelerator_type:P4000': 1.0, 'node:__internal_head__': 1.0, 'CPU': 12.0, 'node:131.227.65.97': 1.0}
[92mINFO [0m:      Optimize your simulation with Flower VCE: https://flower.ai/docs/framework/how-to-run-simulations.html
[92mINFO [0m:      Flower VCE: Resources for each Virtual Client: {'num_cpus': 2, 'num_gpus': 1}
[92mINFO [0m:      Flower VCE: Creating VirtualClientEngineActorPool with 1 actors
[92mINFO [0m:      [INIT]
[92mINFO [0m:      Requesting initial parameters from one random client
[36m(ClientAppActor pid=3050191)[0m   return torch.load(io.BytesIO(b))
[36m(ClientAppActor pid=3050191)[0m   return torch.load(io.BytesIO(b))
[36m(ClientAppActor pid=3050191)[0

[36m(ClientAppActor pid=3050191)[0m Client SVHN: Switched to SGD


[36m(ClientAppActor pid=3050191)[0m 
[36m(ClientAppActor pid=3050191)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3050191)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3050191)[0m         


[36m(ClientAppActor pid=3050191)[0m Client CIFAR-10: Switched to SGD


[36m(ClientAppActor pid=3050191)[0m 
[36m(ClientAppActor pid=3050191)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3050191)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3050191)[0m         


[36m(ClientAppActor pid=3050191)[0m Client Fashion-MNIST: Switched to SGD


[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=3050191)[0m 
[36m(ClientAppActor pid=3050191)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3050191)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3050191)[0m         


[36m(ClientAppActor pid=3050191)[0m Client SVHN: Accuracy = 0.3045098340503995


[36m(ClientAppActor pid=3050191)[0m 
[36m(ClientAppActor pid=3050191)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3050191)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3050191)[0m         


[36m(ClientAppActor pid=3050191)[0m Client Fashion-MNIST: Accuracy = 0.3957


[36m(ClientAppActor pid=3050191)[0m 
[36m(ClientAppActor pid=3050191)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3050191)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3050191)[0m         
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 2]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)


[36m(ClientAppActor pid=3050191)[0m Client CIFAR-10: Accuracy = 0.1683


[36m(ClientAppActor pid=3050191)[0m 
[36m(ClientAppActor pid=3050191)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3050191)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3050191)[0m         


[36m(ClientAppActor pid=3050191)[0m Client SVHN: Switched to SGD


[36m(ClientAppActor pid=3050191)[0m 
[36m(ClientAppActor pid=3050191)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3050191)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3050191)[0m         


[36m(ClientAppActor pid=3050191)[0m Client Fashion-MNIST: Switched to SGD


[36m(ClientAppActor pid=3050191)[0m 
[36m(ClientAppActor pid=3050191)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3050191)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3050191)[0m         


[36m(ClientAppActor pid=3050191)[0m Client CIFAR-10: Switched to SGD


[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=3050191)[0m 
[36m(ClientAppActor pid=3050191)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3050191)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3050191)[0m         


[36m(ClientAppActor pid=3050191)[0m Client Fashion-MNIST: Accuracy = 0.5435


[36m(ClientAppActor pid=3050191)[0m 
[36m(ClientAppActor pid=3050191)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3050191)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3050191)[0m         


[36m(ClientAppActor pid=3050191)[0m Client SVHN: Accuracy = 0.37161954517516904


[36m(ClientAppActor pid=3050191)[0m 
[36m(ClientAppActor pid=3050191)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3050191)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3050191)[0m         
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 3]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)


[36m(ClientAppActor pid=3050191)[0m Client CIFAR-10: Accuracy = 0.2193


[36m(ClientAppActor pid=3050191)[0m 
[36m(ClientAppActor pid=3050191)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3050191)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3050191)[0m         


[36m(ClientAppActor pid=3050191)[0m Client SVHN: Switched to SGD


[36m(ClientAppActor pid=3050191)[0m 
[36m(ClientAppActor pid=3050191)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3050191)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3050191)[0m         


[36m(ClientAppActor pid=3050191)[0m Client Fashion-MNIST: Switched to SGD


[36m(ClientAppActor pid=3050191)[0m 
[36m(ClientAppActor pid=3050191)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3050191)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3050191)[0m         


[36m(ClientAppActor pid=3050191)[0m Client CIFAR-10: Switched to SGD


[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=3050191)[0m 
[36m(ClientAppActor pid=3050191)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3050191)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3050191)[0m         


[36m(ClientAppActor pid=3050191)[0m Client SVHN: Accuracy = 0.4134142593730793


[36m(ClientAppActor pid=3050191)[0m 
[36m(ClientAppActor pid=3050191)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3050191)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3050191)[0m         


[36m(ClientAppActor pid=3050191)[0m Client CIFAR-10: Accuracy = 0.2121


[36m(ClientAppActor pid=3050191)[0m 
[36m(ClientAppActor pid=3050191)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3050191)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3050191)[0m         
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 4]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)


[36m(ClientAppActor pid=3050191)[0m Client Fashion-MNIST: Accuracy = 0.5989


[36m(ClientAppActor pid=3050191)[0m 
[36m(ClientAppActor pid=3050191)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3050191)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3050191)[0m         


[36m(ClientAppActor pid=3050191)[0m Client Fashion-MNIST: Switched to SGD


[36m(ClientAppActor pid=3050191)[0m 
[36m(ClientAppActor pid=3050191)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3050191)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3050191)[0m         


[36m(ClientAppActor pid=3050191)[0m Client CIFAR-10: Switched to SGD


[36m(ClientAppActor pid=3050191)[0m 
[36m(ClientAppActor pid=3050191)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3050191)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3050191)[0m         


[36m(ClientAppActor pid=3050191)[0m Client SVHN: Switched to SGD


[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=3050191)[0m 
[36m(ClientAppActor pid=3050191)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3050191)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3050191)[0m         


[36m(ClientAppActor pid=3050191)[0m Client SVHN: Accuracy = 0.45152120467117396


[36m(ClientAppActor pid=3050191)[0m 
[36m(ClientAppActor pid=3050191)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3050191)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3050191)[0m         


[36m(ClientAppActor pid=3050191)[0m Client CIFAR-10: Accuracy = 0.2332


[36m(ClientAppActor pid=3050191)[0m 
[36m(ClientAppActor pid=3050191)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3050191)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3050191)[0m         
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 5]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)


[36m(ClientAppActor pid=3050191)[0m Client Fashion-MNIST: Accuracy = 0.6636


[36m(ClientAppActor pid=3050191)[0m 
[36m(ClientAppActor pid=3050191)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3050191)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3050191)[0m         


[36m(ClientAppActor pid=3050191)[0m Client Fashion-MNIST: Switched to SGD


[36m(ClientAppActor pid=3050191)[0m 
[36m(ClientAppActor pid=3050191)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3050191)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3050191)[0m         


[36m(ClientAppActor pid=3050191)[0m Client SVHN: Switched to SGD


[36m(ClientAppActor pid=3050191)[0m 
[36m(ClientAppActor pid=3050191)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3050191)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3050191)[0m         


[36m(ClientAppActor pid=3050191)[0m Client CIFAR-10: Switched to SGD


[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=3050191)[0m 
[36m(ClientAppActor pid=3050191)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3050191)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3050191)[0m         


[36m(ClientAppActor pid=3050191)[0m Client SVHN: Accuracy = 0.4560540872771973


[36m(ClientAppActor pid=3050191)[0m 
[36m(ClientAppActor pid=3050191)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3050191)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3050191)[0m         


[36m(ClientAppActor pid=3050191)[0m Client CIFAR-10: Accuracy = 0.2575


[36m(ClientAppActor pid=3050191)[0m 
[36m(ClientAppActor pid=3050191)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3050191)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3050191)[0m         
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 6]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)


[36m(ClientAppActor pid=3050191)[0m Client Fashion-MNIST: Accuracy = 0.6545


[36m(ClientAppActor pid=3050191)[0m 
[36m(ClientAppActor pid=3050191)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3050191)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3050191)[0m         


[36m(ClientAppActor pid=3050191)[0m Client SVHN: Switched to SGD


[36m(ClientAppActor pid=3050191)[0m 
[36m(ClientAppActor pid=3050191)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3050191)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3050191)[0m         


[36m(ClientAppActor pid=3050191)[0m Client CIFAR-10: Switched to SGD


[36m(ClientAppActor pid=3050191)[0m 
[36m(ClientAppActor pid=3050191)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3050191)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3050191)[0m         


[36m(ClientAppActor pid=3050191)[0m Client Fashion-MNIST: Switched to SGD


[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=3050191)[0m 
[36m(ClientAppActor pid=3050191)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3050191)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3050191)[0m         


[36m(ClientAppActor pid=3050191)[0m Client Fashion-MNIST: Accuracy = 0.7324


[36m(ClientAppActor pid=3050191)[0m 
[36m(ClientAppActor pid=3050191)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3050191)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3050191)[0m         


[36m(ClientAppActor pid=3050191)[0m Client CIFAR-10: Accuracy = 0.2513


[36m(ClientAppActor pid=3050191)[0m 
[36m(ClientAppActor pid=3050191)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3050191)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3050191)[0m         
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 7]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)


[36m(ClientAppActor pid=3050191)[0m Client SVHN: Accuracy = 0.47887215734480637


[36m(ClientAppActor pid=3050191)[0m 
[36m(ClientAppActor pid=3050191)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3050191)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3050191)[0m         


[36m(ClientAppActor pid=3050191)[0m Client SVHN: Switched to SGD


[36m(ClientAppActor pid=3050191)[0m 
[36m(ClientAppActor pid=3050191)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3050191)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3050191)[0m         


[36m(ClientAppActor pid=3050191)[0m Client Fashion-MNIST: Switched to SGD


[36m(ClientAppActor pid=3050191)[0m 
[36m(ClientAppActor pid=3050191)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3050191)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3050191)[0m         


[36m(ClientAppActor pid=3050191)[0m Client CIFAR-10: Switched to SGD


[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=3050191)[0m 
[36m(ClientAppActor pid=3050191)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3050191)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3050191)[0m         


[36m(ClientAppActor pid=3050191)[0m Client SVHN: Accuracy = 0.48152274124154887


[36m(ClientAppActor pid=3050191)[0m 
[36m(ClientAppActor pid=3050191)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3050191)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3050191)[0m         


[36m(ClientAppActor pid=3050191)[0m Client Fashion-MNIST: Accuracy = 0.7341


[36m(ClientAppActor pid=3050191)[0m 
[36m(ClientAppActor pid=3050191)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3050191)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3050191)[0m         
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 8]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)


[36m(ClientAppActor pid=3050191)[0m Client CIFAR-10: Accuracy = 0.2641


[36m(ClientAppActor pid=3050191)[0m 
[36m(ClientAppActor pid=3050191)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3050191)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3050191)[0m         


[36m(ClientAppActor pid=3050191)[0m Client SVHN: Switched to SGD


[36m(ClientAppActor pid=3050191)[0m 
[36m(ClientAppActor pid=3050191)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3050191)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3050191)[0m         


[36m(ClientAppActor pid=3050191)[0m Client Fashion-MNIST: Switched to SGD


[36m(ClientAppActor pid=3050191)[0m 
[36m(ClientAppActor pid=3050191)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3050191)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3050191)[0m         


[36m(ClientAppActor pid=3050191)[0m Client CIFAR-10: Switched to SGD


[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=3050191)[0m 
[36m(ClientAppActor pid=3050191)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3050191)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3050191)[0m         


[36m(ClientAppActor pid=3050191)[0m Client SVHN: Accuracy = 0.48570989551321453


[36m(ClientAppActor pid=3050191)[0m 
[36m(ClientAppActor pid=3050191)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3050191)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3050191)[0m         


[36m(ClientAppActor pid=3050191)[0m Client CIFAR-10: Accuracy = 0.2694


[36m(ClientAppActor pid=3050191)[0m 
[36m(ClientAppActor pid=3050191)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3050191)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3050191)[0m         
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 9]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)


[36m(ClientAppActor pid=3050191)[0m Client Fashion-MNIST: Accuracy = 0.7869


[36m(ClientAppActor pid=3050191)[0m 
[36m(ClientAppActor pid=3050191)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3050191)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3050191)[0m         


[36m(ClientAppActor pid=3050191)[0m Client Fashion-MNIST: Switched to SGD


[36m(ClientAppActor pid=3050191)[0m 
[36m(ClientAppActor pid=3050191)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3050191)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3050191)[0m         


[36m(ClientAppActor pid=3050191)[0m Client SVHN: Switched to SGD


[36m(ClientAppActor pid=3050191)[0m 
[36m(ClientAppActor pid=3050191)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3050191)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3050191)[0m         


[36m(ClientAppActor pid=3050191)[0m Client CIFAR-10: Switched to SGD


[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=3050191)[0m 
[36m(ClientAppActor pid=3050191)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3050191)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3050191)[0m         


[36m(ClientAppActor pid=3050191)[0m Client Fashion-MNIST: Accuracy = 0.8054


[36m(ClientAppActor pid=3050191)[0m 
[36m(ClientAppActor pid=3050191)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3050191)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3050191)[0m         


[36m(ClientAppActor pid=3050191)[0m Client SVHN: Accuracy = 0.49312384757221883


[36m(ClientAppActor pid=3050191)[0m 
[36m(ClientAppActor pid=3050191)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3050191)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3050191)[0m         
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 10]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)


[36m(ClientAppActor pid=3050191)[0m Client CIFAR-10: Accuracy = 0.2671


[36m(ClientAppActor pid=3050191)[0m 
[36m(ClientAppActor pid=3050191)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3050191)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3050191)[0m         


[36m(ClientAppActor pid=3050191)[0m Client CIFAR-10: Switched to SGD


[36m(ClientAppActor pid=3050191)[0m 
[36m(ClientAppActor pid=3050191)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3050191)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3050191)[0m         


[36m(ClientAppActor pid=3050191)[0m Client Fashion-MNIST: Switched to SGD


[36m(ClientAppActor pid=3050191)[0m 
[36m(ClientAppActor pid=3050191)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3050191)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3050191)[0m         


[36m(ClientAppActor pid=3050191)[0m Client SVHN: Switched to SGD


[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=3050191)[0m 
[36m(ClientAppActor pid=3050191)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3050191)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3050191)[0m         


[36m(ClientAppActor pid=3050191)[0m Client Fashion-MNIST: Accuracy = 0.7722


[36m(ClientAppActor pid=3050191)[0m 
[36m(ClientAppActor pid=3050191)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3050191)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3050191)[0m         


[36m(ClientAppActor pid=3050191)[0m Client CIFAR-10: Accuracy = 0.274


[36m(ClientAppActor pid=3050191)[0m 
[36m(ClientAppActor pid=3050191)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3050191)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3050191)[0m         
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [SUMMARY]
[92mINFO [0m:      Run finished 10 round(s) in 3284.93s
[92mINFO [0m:      	History (loss, distributed):
[92mINFO [0m:      		round 1: 2.2808811926651433
[92mINFO [0m:      		round 2: 2.2205961677201715
[92mINFO [0m:      		round 3: 2.158824339132976
[92mINFO [0m:      		round 4: 2.113066815545423
[92mINFO [0m:      		round 5: 2.085564257232204
[92mINFO [0m:      		round 6: 2.045375787749561
[92mINFO [0m:      		round 7: 2.0267888823173137
[92mINFO [0m:      		round 8: 2.004824337948899
[92mINFO [0m:      		round 9: 1.987430869516504
[92mINFO [0m:      		round 10: 1.9

[36m(ClientAppActor pid=3050191)[0m Client SVHN: Accuracy = 0.4969268592501537


In [9]:
import flwr as fl

import torch

import torch.nn as nn

import torch.optim as optim

import torch.nn.functional as F

from torch.utils.data import DataLoader

import torchvision.transforms as transforms 

from torchvision.datasets import CIFAR10, FashionMNIST, SVHN
 
# Constants

BATCH_SIZE = 64

LEARNING_RATE = 0.001

EPOCHS = 5

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

SWITCH_EPOCH = 3  # The epoch at which we switch from Adam to SGD
 
# Define the model with dataset-specific layers

class CNN_Model(nn.Module):

    def __init__(self, dataset_name):

        super(CNN_Model, self).__init__()

        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)

        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)

        self.pool = nn.AvgPool2d(2, 2)

        self.fc_config = nn.Linear(64 * 8 * 8, 128)
 
        # Dataset-specific layers

        if dataset_name == "cifar10":

            self.fc_op = nn.Sequential(

                nn.Linear(128, 64),

                nn.ReLU(),

                nn.Linear(64, 10)

            )

        elif dataset_name == "fashion_mnist":

            self.fc_op = nn.Sequential(

                nn.Linear(128, 32),

                nn.ReLU(),

                nn.Linear(32, 10)

            )

        elif dataset_name == "svhn":

            self.fc_op = nn.Sequential(

                nn.Linear(128, 64),

                nn.ReLU(),

                nn.Linear(64, 10)

            )
 
    def forward(self, x):

        x = F.relu(self.conv1(x))

        x = self.pool(x)

        x = F.relu(self.conv2(x))

        x = self.pool(x)

        x = x.view(-1, 64 * 8 * 8)

        x = F.relu(self.fc_config(x))

        x = self.fc_op(x)

        return F.softmax(x, dim=1)
 
# Define the Flower client

class FLClient(fl.client.NumPyClient):

    def __init__(self, model, train_loader, test_loader, client_id):

        self.model = model.to(DEVICE)

        self.train_loader = train_loader

        self.test_loader = test_loader

        self.criterion = nn.CrossEntropyLoss()

        self.optimizer = optim.Adam(self.model.parameters(), lr=LEARNING_RATE)

        self.sgd = None

        self.switched = False

        self.client_id = client_id
 
    def switch_to_sgd(self):

        # Switch to SGD optimizer

        self.sgd = optim.SGD(self.model.parameters(), lr=self.optimizer.param_groups[0]['lr'])

        self.switched = True

        print(f"Client {self.client_id}: Switched to SGD")
 
    def step_optimizer(self, epoch):

        # Switch from Adam to SGD if the conditions are met

        if not self.switched and epoch >= SWITCH_EPOCH:

            self.switch_to_sgd()

        return self.sgd if self.switched else self.optimizer
 
    def get_parameters(self, config=None):

        params = {

            k: v.cpu().numpy()

            for k, v in self.model.state_dict().items()

            if not k.startswith("fc_op")

        }

        return list(params.values())
 
    def set_parameters(self, parameters):

        state_dict = self.model.state_dict()

        shared_keys = [k for k in state_dict.keys() if not k.startswith("fc_op")]

        for key, param in zip(shared_keys, parameters):

            state_dict[key] = torch.tensor(param)

        self.model.load_state_dict(state_dict, strict=False)
 
    def fine_tune(self):

        self.model.train()

        for epoch in range(2):  # Fine-tune for 2 local epochs

            optimizer = self.step_optimizer(epoch)

            for images, labels in self.train_loader:

                images, labels = images.to(DEVICE), labels.to(DEVICE)

                optimizer.zero_grad()

                outputs = self.model(images)

                loss = self.criterion(outputs, labels)

                loss.backward()

                optimizer.step()
 
    def fit(self, parameters, config):

        self.set_parameters(parameters)  # Set global model parameters

        self.model.train()

        for epoch in range(EPOCHS):  # Train on local data

            optimizer = self.step_optimizer(epoch)

            for images, labels in self.train_loader:

                images, labels = images.to(DEVICE), labels.to(DEVICE)

                optimizer.zero_grad()

                outputs = self.model(images)

                loss = self.criterion(outputs, labels)

                loss.backward()

                optimizer.step()

        # Perform fine-tuning

        self.fine_tune()
 
        return self.get_parameters(config), len(self.train_loader.dataset), {}
 
    def evaluate(self, parameters, config):

        self.set_parameters(parameters)

        self.model.eval()

        correct, total = 0, 0

        loss = 0.0

        with torch.no_grad():

            for images, labels in self.test_loader:

                images, labels = images.to(DEVICE), labels.to(DEVICE)

                outputs = self.model(images)

                loss += self.criterion(outputs, labels).item()

                _, predicted = torch.max(outputs.data, 1)

                total += labels.size(0)

                correct += (predicted == labels).sum().item()

        accuracy = correct / total

        print(f"Client {self.client_id}: Accuracy = {accuracy}")

        return float(loss) / len(self.test_loader), len(self.test_loader.dataset), {"accuracy": accuracy}
 
# Load the datasets

def load_data(dataset_name):

    transform = transforms.Compose([

        transforms.Resize((32, 32)),

        transforms.ToTensor(),

        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),

    ])

    if dataset_name == "cifar10":

        train_dataset = CIFAR10(root='./data', train=True, download=True, transform=transform)

        test_dataset = CIFAR10(root='./data', train=False, download=True, transform=transform)

    elif dataset_name == "fashion_mnist":

        transform = transforms.Compose([

            transforms.Grayscale(num_output_channels=3),

            transforms.Resize((32, 32)),

            transforms.ToTensor(),

            transforms.Normalize((0.5,), (0.5,)),

        ])

        train_dataset = FashionMNIST(root='./data', train=True, download=True, transform=transform)

        test_dataset = FashionMNIST(root='./data', train=False, download=True, transform=transform)

    elif dataset_name == "svhn":

        train_dataset = SVHN(root='./data', split='train', download=True, transform=transform)

        test_dataset = SVHN(root='./data', split='test', download=True, transform=transform)

    else:

        raise ValueError(f"Unknown dataset: {dataset_name}")
 
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

    return train_loader, test_loader
 
# Updated client_fn to avoid using Context's client_id

def client_fn(cid: str) -> fl.client.Client:

    if cid == "0":

        return cifar_client

    elif cid == "1":

        return fashion_client

    elif cid == "2":

        return svhn_client
 
# Start federated learning

def start_federated_learning():

    global cifar_client, fashion_client, svhn_client

    cifar_train_loader, cifar_test_loader = load_data("cifar10")

    fashion_train_loader, fashion_test_loader = load_data("fashion_mnist")

    svhn_train_loader, svhn_test_loader = load_data("svhn")
 
    cifar_client = FLClient(CNN_Model("cifar10"), cifar_train_loader, cifar_test_loader, "CIFAR-10")

    fashion_client = FLClient(CNN_Model("fashion_mnist"), fashion_train_loader, fashion_test_loader, "Fashion-MNIST")

    svhn_client = FLClient(CNN_Model("svhn"), svhn_train_loader, svhn_test_loader, "SVHN")
 
    fl.simulation.start_simulation(

        client_fn=client_fn,

        num_clients=3,

        client_resources={"num_cpus": 2, "num_gpus": 1},

        config=fl.server.ServerConfig(num_rounds=5),  # Increased number of rounds

        strategy=fl.server.strategy.FedAvg(

            fraction_fit=1.0,

            min_fit_clients=3,

            min_available_clients=3,

        )

    )
 
if __name__ == "__main__":

    start_federated_learning()

 

Files already downloaded and verified
Files already downloaded and verified
Using downloaded and verified file: ./data/train_32x32.mat
Using downloaded and verified file: ./data/test_32x32.mat


[92mINFO [0m:      Starting Flower simulation, config: num_rounds=5, no round_timeout
2024-12-11 19:45:36,273	INFO worker.py:1752 -- Started a local Ray instance.
[92mINFO [0m:      Flower VCE: Ray initialized with resources: {'memory': 9886003200.0, 'GPU': 1.0, 'accelerator_type:P4000': 1.0, 'node:__internal_head__': 1.0, 'object_store_memory': 4943001600.0, 'CPU': 12.0, 'node:131.227.65.97': 1.0}
[92mINFO [0m:      Optimize your simulation with Flower VCE: https://flower.ai/docs/framework/how-to-run-simulations.html
[92mINFO [0m:      Flower VCE: Resources for each Virtual Client: {'num_cpus': 2, 'num_gpus': 1}
[92mINFO [0m:      Flower VCE: Creating VirtualClientEngineActorPool with 1 actors
[92mINFO [0m:      [INIT]
[92mINFO [0m:      Requesting initial parameters from one random client
[36m(ClientAppActor pid=3505856)[0m   return torch.load(io.BytesIO(b))
[36m(ClientAppActor pid=3505856)[0m   return torch.load(io.BytesIO(b))
[92mINFO [0m:      Received initial 

[36m(ClientAppActor pid=3505856)[0m Client Fashion-MNIST: Switched to SGD


[36m(ClientAppActor pid=3505856)[0m 
[36m(ClientAppActor pid=3505856)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3505856)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3505856)[0m         


[36m(ClientAppActor pid=3505856)[0m Client CIFAR-10: Switched to SGD


[36m(ClientAppActor pid=3505856)[0m 
[36m(ClientAppActor pid=3505856)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3505856)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3505856)[0m         


[36m(ClientAppActor pid=3505856)[0m Client SVHN: Switched to SGD


[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=3505856)[0m 
[36m(ClientAppActor pid=3505856)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3505856)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3505856)[0m         


[36m(ClientAppActor pid=3505856)[0m Client Fashion-MNIST: Accuracy = 0.2447


[36m(ClientAppActor pid=3505856)[0m 
[36m(ClientAppActor pid=3505856)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3505856)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3505856)[0m         


[36m(ClientAppActor pid=3505856)[0m Client CIFAR-10: Accuracy = 0.1605


[36m(ClientAppActor pid=3505856)[0m 
[36m(ClientAppActor pid=3505856)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3505856)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3505856)[0m         
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 2]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)


[36m(ClientAppActor pid=3505856)[0m Client SVHN: Accuracy = 0.2034419176398279


[36m(ClientAppActor pid=3505856)[0m 
[36m(ClientAppActor pid=3505856)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3505856)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3505856)[0m         


[36m(ClientAppActor pid=3505856)[0m Client CIFAR-10: Switched to SGD


[36m(ClientAppActor pid=3505856)[0m 
[36m(ClientAppActor pid=3505856)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3505856)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3505856)[0m         


[36m(ClientAppActor pid=3505856)[0m Client Fashion-MNIST: Switched to SGD


[36m(ClientAppActor pid=3505856)[0m 
[36m(ClientAppActor pid=3505856)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3505856)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3505856)[0m         


[36m(ClientAppActor pid=3505856)[0m Client SVHN: Switched to SGD


[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=3505856)[0m 
[36m(ClientAppActor pid=3505856)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3505856)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3505856)[0m         


[36m(ClientAppActor pid=3505856)[0m Client SVHN: Accuracy = 0.42820374923171484


[36m(ClientAppActor pid=3505856)[0m 
[36m(ClientAppActor pid=3505856)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3505856)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3505856)[0m         


[36m(ClientAppActor pid=3505856)[0m Client CIFAR-10: Accuracy = 0.2495


[36m(ClientAppActor pid=3505856)[0m 
[36m(ClientAppActor pid=3505856)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3505856)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3505856)[0m         
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 3]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)


[36m(ClientAppActor pid=3505856)[0m Client Fashion-MNIST: Accuracy = 0.5791


[36m(ClientAppActor pid=3505856)[0m 
[36m(ClientAppActor pid=3505856)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3505856)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3505856)[0m         


[36m(ClientAppActor pid=3505856)[0m Client CIFAR-10: Switched to SGD


[36m(ClientAppActor pid=3505856)[0m 
[36m(ClientAppActor pid=3505856)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3505856)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3505856)[0m         


[36m(ClientAppActor pid=3505856)[0m Client Fashion-MNIST: Switched to SGD


[36m(ClientAppActor pid=3505856)[0m 
[36m(ClientAppActor pid=3505856)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3505856)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3505856)[0m         


[36m(ClientAppActor pid=3505856)[0m Client SVHN: Switched to SGD


[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=3505856)[0m 
[36m(ClientAppActor pid=3505856)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3505856)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3505856)[0m         


[36m(ClientAppActor pid=3505856)[0m Client Fashion-MNIST: Accuracy = 0.6591


[36m(ClientAppActor pid=3505856)[0m 
[36m(ClientAppActor pid=3505856)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3505856)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3505856)[0m         


[36m(ClientAppActor pid=3505856)[0m Client CIFAR-10: Accuracy = 0.2647


[36m(ClientAppActor pid=3505856)[0m 
[36m(ClientAppActor pid=3505856)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3505856)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3505856)[0m         
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 4]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)


[36m(ClientAppActor pid=3505856)[0m Client SVHN: Accuracy = 0.5234711124769514


[36m(ClientAppActor pid=3505856)[0m 
[36m(ClientAppActor pid=3505856)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3505856)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3505856)[0m         


[36m(ClientAppActor pid=3505856)[0m Client SVHN: Switched to SGD


[36m(ClientAppActor pid=3505856)[0m 
[36m(ClientAppActor pid=3505856)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3505856)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3505856)[0m         


[36m(ClientAppActor pid=3505856)[0m Client CIFAR-10: Switched to SGD


[36m(ClientAppActor pid=3505856)[0m 
[36m(ClientAppActor pid=3505856)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3505856)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3505856)[0m         


[36m(ClientAppActor pid=3505856)[0m Client Fashion-MNIST: Switched to SGD


[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=3505856)[0m 
[36m(ClientAppActor pid=3505856)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3505856)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3505856)[0m         


[36m(ClientAppActor pid=3505856)[0m Client CIFAR-10: Accuracy = 0.297


[36m(ClientAppActor pid=3505856)[0m 
[36m(ClientAppActor pid=3505856)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3505856)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3505856)[0m         


[36m(ClientAppActor pid=3505856)[0m Client SVHN: Accuracy = 0.5464044253226797


[36m(ClientAppActor pid=3505856)[0m 
[36m(ClientAppActor pid=3505856)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3505856)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3505856)[0m         
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 5]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)


[36m(ClientAppActor pid=3505856)[0m Client Fashion-MNIST: Accuracy = 0.6641


[36m(ClientAppActor pid=3505856)[0m 
[36m(ClientAppActor pid=3505856)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3505856)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3505856)[0m         


[36m(ClientAppActor pid=3505856)[0m Client Fashion-MNIST: Switched to SGD


[36m(ClientAppActor pid=3505856)[0m 
[36m(ClientAppActor pid=3505856)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3505856)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3505856)[0m         


[36m(ClientAppActor pid=3505856)[0m Client CIFAR-10: Switched to SGD


[36m(ClientAppActor pid=3505856)[0m 
[36m(ClientAppActor pid=3505856)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3505856)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3505856)[0m         


[36m(ClientAppActor pid=3505856)[0m Client SVHN: Switched to SGD


[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=3505856)[0m 
[36m(ClientAppActor pid=3505856)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3505856)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3505856)[0m         


[36m(ClientAppActor pid=3505856)[0m Client SVHN: Accuracy = 0.5675706822372465


[36m(ClientAppActor pid=3505856)[0m 
[36m(ClientAppActor pid=3505856)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3505856)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3505856)[0m         


[36m(ClientAppActor pid=3505856)[0m Client Fashion-MNIST: Accuracy = 0.692


[36m(ClientAppActor pid=3505856)[0m 
[36m(ClientAppActor pid=3505856)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3505856)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3505856)[0m         
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [SUMMARY]
[92mINFO [0m:      Run finished 5 round(s) in 1612.60s
[92mINFO [0m:      	History (loss, distributed):
[92mINFO [0m:      		round 1: 2.285966410083734
[92mINFO [0m:      		round 2: 2.228248428496121
[92mINFO [0m:      		round 3: 2.149272329851232
[92mINFO [0m:      		round 4: 2.092918914707557
[92mINFO [0m:      		round 5: 2.0643550829975954
[92mINFO [0m:      


[36m(ClientAppActor pid=3505856)[0m Client CIFAR-10: Accuracy = 0.289


In [3]:
import flwr as fl

import torch

import torch.nn as nn

import torch.optim as optim

import torch.nn.functional as F

from torch.utils.data import DataLoader

import torchvision.transforms as transforms 

from torchvision.datasets import CIFAR10, FashionMNIST, SVHN
 
# Constants

BATCH_SIZE = 64

LEARNING_RATE = 0.001

EPOCHS = 5

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

SWITCH_EPOCH = 3  # The epoch at which we switch from Adam to SGD
 
# Define the model with dataset-specific layers

class CNN_Model(nn.Module):

    def __init__(self, dataset_name):

        super(CNN_Model, self).__init__()

        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)

        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)

        self.pool = nn.AvgPool2d(2, 2)

        self.fc_config = nn.Linear(64 * 8 * 8, 128)
 
        # Dataset-specific layers

        if dataset_name == "cifar10":

            self.fc_op = nn.Sequential(

                nn.Linear(128, 64),

                nn.ReLU(),

                nn.Linear(64, 10)

            )

        elif dataset_name == "fashion_mnist":

            self.fc_op = nn.Sequential(

                nn.Linear(128, 32),

                nn.ReLU(),

                nn.Linear(32, 10)

            )

        elif dataset_name == "svhn":

            self.fc_op = nn.Sequential(

                nn.Linear(128, 64),

                nn.ReLU(),

                nn.Linear(64, 10)

            )
 
    def forward(self, x):

        x = F.relu(self.conv1(x))

        x = self.pool(x)

        x = F.relu(self.conv2(x))

        x = self.pool(x)

        x = x.view(-1, 64 * 8 * 8)

        x = F.relu(self.fc_config(x))

        x = self.fc_op(x)

        return F.softmax(x, dim=1)
 
# Define the Flower client

class FLClient(fl.client.NumPyClient):

    def __init__(self, model, train_loader, test_loader, client_id):

        self.model = model.to(DEVICE)

        self.train_loader = train_loader

        self.test_loader = test_loader

        self.criterion = nn.CrossEntropyLoss()

        self.optimizer = optim.Adam(self.model.parameters(), lr=LEARNING_RATE)

        self.sgd = None

        self.switched = False

        self.client_id = client_id
 
    def switch_to_sgd(self):

        # Switch to SGD optimizer

        self.sgd = optim.SGD(self.model.parameters(), lr=self.optimizer.param_groups[0]['lr'])

        self.switched = True

        print(f"Client {self.client_id}: Switched to SGD")
 
    def step_optimizer(self, epoch):

        # Switch from Adam to SGD if the conditions are met

        if not self.switched and epoch >= SWITCH_EPOCH:

            self.switch_to_sgd()

        return self.sgd if self.switched else self.optimizer
 
    def get_parameters(self, config=None):

        params = {

            k: v.cpu().numpy()

            for k, v in self.model.state_dict().items()

            if not k.startswith("fc_op")

        }

        return list(params.values())
 
    def set_parameters(self, parameters):

        state_dict = self.model.state_dict()

        shared_keys = [k for k in state_dict.keys() if not k.startswith("fc_op")]

        for key, param in zip(shared_keys, parameters):

            state_dict[key] = torch.tensor(param)

        self.model.load_state_dict(state_dict, strict=False)
 
    def fine_tune(self):

        self.model.train()

        for epoch in range(2):  # Fine-tune for 2 local epochs

            optimizer = self.step_optimizer(epoch)

            for images, labels in self.train_loader:

                images, labels = images.to(DEVICE), labels.to(DEVICE)

                optimizer.zero_grad()

                outputs = self.model(images)

                loss = self.criterion(outputs, labels)

                loss.backward()

                optimizer.step()
 
    def fit(self, parameters, config):

        self.set_parameters(parameters)  # Set global model parameters

        self.model.train()

        for epoch in range(EPOCHS):  # Train on local data

            optimizer = self.step_optimizer(epoch)

            for images, labels in self.train_loader:

                images, labels = images.to(DEVICE), labels.to(DEVICE)

                optimizer.zero_grad()

                outputs = self.model(images)

                loss = self.criterion(outputs, labels)

                loss.backward()

                optimizer.step()

        # Perform fine-tuning

        self.fine_tune()
 
        return self.get_parameters(config), len(self.train_loader.dataset), {}
 
    def evaluate(self, parameters, config):

        self.set_parameters(parameters)

        self.model.eval()

        correct, total = 0, 0

        loss = 0.0

        with torch.no_grad():

            for images, labels in self.test_loader:

                images, labels = images.to(DEVICE), labels.to(DEVICE)

                outputs = self.model(images)

                loss += self.criterion(outputs, labels).item()

                _, predicted = torch.max(outputs.data, 1)

                total += labels.size(0)

                correct += (predicted == labels).sum().item()

        accuracy = correct / total

        print(f"Client {self.client_id}: Accuracy = {accuracy}")

        return float(loss) / len(self.test_loader), len(self.test_loader.dataset), {"accuracy": accuracy}
 
# Load the datasets

def load_data(dataset_name):

    transform = transforms.Compose([

        transforms.Resize((32, 32)),

        transforms.ToTensor(),

        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),

    ])

    if dataset_name == "cifar10":

        train_dataset = CIFAR10(root='./data', train=True, download=True, transform=transform)

        test_dataset = CIFAR10(root='./data', train=False, download=True, transform=transform)

    elif dataset_name == "fashion_mnist":

        transform = transforms.Compose([

            transforms.Grayscale(num_output_channels=3),

            transforms.Resize((32, 32)),

            transforms.ToTensor(),

            transforms.Normalize((0.5,), (0.5,)),

        ])

        train_dataset = FashionMNIST(root='./data', train=True, download=True, transform=transform)

        test_dataset = FashionMNIST(root='./data', train=False, download=True, transform=transform)

    elif dataset_name == "svhn":

        train_dataset = SVHN(root='./data', split='train', download=True, transform=transform)

        test_dataset = SVHN(root='./data', split='test', download=True, transform=transform)

    else:

        raise ValueError(f"Unknown dataset: {dataset_name}")
 
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

    return train_loader, test_loader
 
# Updated client_fn to avoid using Context's client_id

def client_fn(cid: str) -> fl.client.Client:

    if cid == "0":

        return cifar_client

    elif cid == "1":

        return fashion_client

    elif cid == "2":

        return svhn_client
 
# Start federated learning

def start_federated_learning():

    global cifar_client, fashion_client, svhn_client

    cifar_train_loader, cifar_test_loader = load_data("cifar10")

    fashion_train_loader, fashion_test_loader = load_data("fashion_mnist")

    svhn_train_loader, svhn_test_loader = load_data("svhn")
 
    cifar_client = FLClient(CNN_Model("cifar10"), cifar_train_loader, cifar_test_loader, "CIFAR-10")

    fashion_client = FLClient(CNN_Model("fashion_mnist"), fashion_train_loader, fashion_test_loader, "Fashion-MNIST")

    svhn_client = FLClient(CNN_Model("svhn"), svhn_train_loader, svhn_test_loader, "SVHN")
 
    fl.simulation.start_simulation(

        client_fn=client_fn,

        num_clients=3,

        client_resources={"num_cpus": 2, "num_gpus": 1},

        config=fl.server.ServerConfig(num_rounds=5),  # Updated to 5 rounds

        strategy=fl.server.strategy.FedAvg(

            fraction_fit=1.0,

            min_fit_clients=3,

            min_available_clients=3,

        )

    )
 
if __name__ == "__main__":

    start_federated_learning()


Files already downloaded and verified
Files already downloaded and verified
Using downloaded and verified file: ./data/train_32x32.mat
Using downloaded and verified file: ./data/test_32x32.mat


[92mINFO [0m:      Starting Flower simulation, config: num_rounds=5, no round_timeout
2024-12-11 16:46:30,829	INFO worker.py:1752 -- Started a local Ray instance.
[92mINFO [0m:      Flower VCE: Ray initialized with resources: {'object_store_memory': 5350174310.0, 'accelerator_type:P4000': 1.0, 'node:__internal_head__': 1.0, 'node:131.227.65.97': 1.0, 'CPU': 12.0, 'memory': 10700348622.0, 'GPU': 1.0}
[92mINFO [0m:      Optimize your simulation with Flower VCE: https://flower.ai/docs/framework/how-to-run-simulations.html
[92mINFO [0m:      Flower VCE: Resources for each Virtual Client: {'num_cpus': 2, 'num_gpus': 1}
[92mINFO [0m:      Flower VCE: Creating VirtualClientEngineActorPool with 1 actors
[92mINFO [0m:      [INIT]
[92mINFO [0m:      Requesting initial parameters from one random client
[36m(ClientAppActor pid=3463827)[0m   return torch.load(io.BytesIO(b))
[36m(ClientAppActor pid=3463827)[0m   return torch.load(io.BytesIO(b))
[92mINFO [0m:      Received initial

[36m(ClientAppActor pid=3463827)[0m Client CIFAR-10: Switched to SGD


[36m(ClientAppActor pid=3463827)[0m 
[36m(ClientAppActor pid=3463827)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3463827)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3463827)[0m         


[36m(ClientAppActor pid=3463827)[0m Client Fashion-MNIST: Switched to SGD


[36m(ClientAppActor pid=3463827)[0m 
[36m(ClientAppActor pid=3463827)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3463827)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3463827)[0m         


[36m(ClientAppActor pid=3463827)[0m Client SVHN: Switched to SGD


[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=3463827)[0m 
[36m(ClientAppActor pid=3463827)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3463827)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3463827)[0m         


[36m(ClientAppActor pid=3463827)[0m Client SVHN: Accuracy = 0.1810464044253227


[36m(ClientAppActor pid=3463827)[0m 
[36m(ClientAppActor pid=3463827)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3463827)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3463827)[0m         


[36m(ClientAppActor pid=3463827)[0m Client CIFAR-10: Accuracy = 0.1318


[36m(ClientAppActor pid=3463827)[0m 
[36m(ClientAppActor pid=3463827)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3463827)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3463827)[0m         
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 2]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)


[36m(ClientAppActor pid=3463827)[0m Client Fashion-MNIST: Accuracy = 0.4094


[36m(ClientAppActor pid=3463827)[0m 
[36m(ClientAppActor pid=3463827)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3463827)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3463827)[0m         


[36m(ClientAppActor pid=3463827)[0m Client CIFAR-10: Switched to SGD


[36m(ClientAppActor pid=3463827)[0m 
[36m(ClientAppActor pid=3463827)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3463827)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3463827)[0m         


[36m(ClientAppActor pid=3463827)[0m Client SVHN: Switched to SGD


[36m(ClientAppActor pid=3463827)[0m 
[36m(ClientAppActor pid=3463827)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3463827)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3463827)[0m         


[36m(ClientAppActor pid=3463827)[0m Client Fashion-MNIST: Switched to SGD


[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=3463827)[0m 
[36m(ClientAppActor pid=3463827)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3463827)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3463827)[0m         


[36m(ClientAppActor pid=3463827)[0m Client SVHN: Accuracy = 0.3721573448063921


[36m(ClientAppActor pid=3463827)[0m 
[36m(ClientAppActor pid=3463827)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3463827)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3463827)[0m         


[36m(ClientAppActor pid=3463827)[0m Client Fashion-MNIST: Accuracy = 0.5428


[36m(ClientAppActor pid=3463827)[0m 
[36m(ClientAppActor pid=3463827)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3463827)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3463827)[0m         
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 3]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)


[36m(ClientAppActor pid=3463827)[0m Client CIFAR-10: Accuracy = 0.1953


[36m(ClientAppActor pid=3463827)[0m 
[36m(ClientAppActor pid=3463827)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3463827)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3463827)[0m         


[36m(ClientAppActor pid=3463827)[0m Client SVHN: Switched to SGD


[36m(ClientAppActor pid=3463827)[0m 
[36m(ClientAppActor pid=3463827)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3463827)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3463827)[0m         


[36m(ClientAppActor pid=3463827)[0m Client Fashion-MNIST: Switched to SGD


[36m(ClientAppActor pid=3463827)[0m 
[36m(ClientAppActor pid=3463827)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3463827)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3463827)[0m         


[36m(ClientAppActor pid=3463827)[0m Client CIFAR-10: Switched to SGD


[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=3463827)[0m 
[36m(ClientAppActor pid=3463827)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3463827)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3463827)[0m         


[36m(ClientAppActor pid=3463827)[0m Client CIFAR-10: Accuracy = 0.2153


[36m(ClientAppActor pid=3463827)[0m 
[36m(ClientAppActor pid=3463827)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3463827)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3463827)[0m         


[36m(ClientAppActor pid=3463827)[0m Client Fashion-MNIST: Accuracy = 0.5989


[36m(ClientAppActor pid=3463827)[0m 
[36m(ClientAppActor pid=3463827)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3463827)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3463827)[0m         
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 4]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)


[36m(ClientAppActor pid=3463827)[0m Client SVHN: Accuracy = 0.44821757836508913


[36m(ClientAppActor pid=3463827)[0m 
[36m(ClientAppActor pid=3463827)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3463827)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3463827)[0m         


[36m(ClientAppActor pid=3463827)[0m Client SVHN: Switched to SGD


[36m(ClientAppActor pid=3463827)[0m 
[36m(ClientAppActor pid=3463827)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3463827)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3463827)[0m         


[36m(ClientAppActor pid=3463827)[0m Client CIFAR-10: Switched to SGD


[36m(ClientAppActor pid=3463827)[0m 
[36m(ClientAppActor pid=3463827)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3463827)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3463827)[0m         


[36m(ClientAppActor pid=3463827)[0m Client Fashion-MNIST: Switched to SGD


[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=3463827)[0m 
[36m(ClientAppActor pid=3463827)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3463827)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3463827)[0m         


[36m(ClientAppActor pid=3463827)[0m Client SVHN: Accuracy = 0.49930854333128455


[36m(ClientAppActor pid=3463827)[0m 
[36m(ClientAppActor pid=3463827)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3463827)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3463827)[0m         


[36m(ClientAppActor pid=3463827)[0m Client Fashion-MNIST: Accuracy = 0.6158


[36m(ClientAppActor pid=3463827)[0m 
[36m(ClientAppActor pid=3463827)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3463827)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3463827)[0m         
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 5]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)


[36m(ClientAppActor pid=3463827)[0m Client CIFAR-10: Accuracy = 0.2828


[36m(ClientAppActor pid=3463827)[0m 
[36m(ClientAppActor pid=3463827)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3463827)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3463827)[0m         


[36m(ClientAppActor pid=3463827)[0m Client CIFAR-10: Switched to SGD


[36m(ClientAppActor pid=3463827)[0m 
[36m(ClientAppActor pid=3463827)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3463827)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3463827)[0m         


[36m(ClientAppActor pid=3463827)[0m Client Fashion-MNIST: Switched to SGD


[36m(ClientAppActor pid=3463827)[0m 
[36m(ClientAppActor pid=3463827)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3463827)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3463827)[0m         


[36m(ClientAppActor pid=3463827)[0m Client SVHN: Switched to SGD


[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=3463827)[0m 
[36m(ClientAppActor pid=3463827)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3463827)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3463827)[0m         


[36m(ClientAppActor pid=3463827)[0m Client CIFAR-10: Accuracy = 0.3091


[36m(ClientAppActor pid=3463827)[0m 
[36m(ClientAppActor pid=3463827)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3463827)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3463827)[0m         


[36m(ClientAppActor pid=3463827)[0m Client SVHN: Accuracy = 0.5221266133988937


[36m(ClientAppActor pid=3463827)[0m 
[36m(ClientAppActor pid=3463827)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3463827)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3463827)[0m         
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [SUMMARY]
[92mINFO [0m:      Run finished 5 round(s) in 1629.96s
[92mINFO [0m:      	History (loss, distributed):
[92mINFO [0m:      		round 1: 2.290308430598742
[92mINFO [0m:      		round 2: 2.238324681197446
[92mINFO [0m:      		round 3: 2.188418283666788
[92mINFO [0m:      		round 4: 2.1304967574452203
[92mINFO [0m:      		round 5: 2.0855737343181335
[92mINFO [0m:      


[36m(ClientAppActor pid=3463827)[0m Client Fashion-MNIST: Accuracy = 0.6906


In [7]:
import flwr as fl
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.transforms as transforms 
from torchvision.datasets import CIFAR10, FashionMNIST, SVHN

# Constants
BATCH_SIZE = 64
LEARNING_RATE = 0.001
EPOCHS = 5
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define the model with dataset-specific layers
class CNN_Model(nn.Module):
    def __init__(self, dataset_name):
        super(CNN_Model, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.pool = nn.AvgPool2d(2, 2)
        self.dropout = nn.Dropout(0.5) if dataset_name == "cifar10" else nn.Identity()
        self.fc_config = nn.Linear(64 * 8 * 8, 128)

        # Dataset-specific layers
        if dataset_name == "cifar10":
            self.fc_op = nn.Sequential(
                nn.Linear(128, 64),
                nn.ReLU(),
                self.dropout,
                nn.Linear(64, 10)
            )
        elif dataset_name == "fashion_mnist":
            self.fc_op = nn.Sequential(
                nn.Linear(128, 32),
                nn.ReLU(),
                nn.Linear(32, 10)
            )
        elif dataset_name == "svhn":
            self.fc_op = nn.Sequential(
                nn.Linear(128, 64),
                nn.ReLU(),
                nn.Linear(64, 10)
            )

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool(x)
        x = F.relu(self.conv2(x))
        x = self.pool(x)
        x = x.view(-1, 64 * 8 * 8)
        x = F.relu(self.fc_config(x))
        x = self.fc_op(x)
        return F.softmax(x, dim=1)

# Define the Flower client
class FLClient(fl.client.NumPyClient):
    def __init__(self, model, train_loader, test_loader, client_id):
        self.model = model.to(DEVICE)
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = optim.Adam(self.model.parameters(), lr=LEARNING_RATE)
        self.client_id = client_id

    def get_parameters(self, config=None):
        params = {
            k: v.cpu().numpy()
            for k, v in self.model.state_dict().items()
            if not k.startswith("fc_op")
        }
        return list(params.values())

    def set_parameters(self, parameters):
        state_dict = self.model.state_dict()
        shared_keys = [k for k in state_dict.keys() if not k.startswith("fc_op")]
        for key, param in zip(shared_keys, parameters):
            state_dict[key] = torch.tensor(param)
        self.model.load_state_dict(state_dict, strict=False)

    def fine_tune(self):
        self.model.train()
        for epoch in range(2):  # Fine-tune for 2 local epochs
            for images, labels in self.train_loader:
                images, labels = images.to(DEVICE), labels.to(DEVICE)
                self.optimizer.zero_grad()
                outputs = self.model(images)
                loss = self.criterion(outputs, labels)
                loss.backward()
                self.optimizer.step()

    def fit(self, parameters, config):
        self.set_parameters(parameters)  # Set global model parameters
        self.model.train()
        for epoch in range(EPOCHS):  # Train on local data
            for images, labels in self.train_loader:
                images, labels = images.to(DEVICE), labels.to(DEVICE)
                self.optimizer.zero_grad()
                outputs = self.model(images)
                loss = self.criterion(outputs, labels)
                loss.backward()
                self.optimizer.step()
        # Perform fine-tuning
        self.fine_tune()

        return self.get_parameters(config), len(self.train_loader.dataset), {}

    def evaluate(self, parameters, config):
        self.set_parameters(parameters)
        self.model.eval()
        correct, total = 0, 0
        loss = 0.0
        with torch.no_grad():
            for images, labels in self.test_loader:
                images, labels = images.to(DEVICE), labels.to(DEVICE)
                outputs = self.model(images)
                loss += self.criterion(outputs, labels).item()
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        accuracy = correct / total
        print(f"Client {self.client_id}: Accuracy = {accuracy}")
        return float(loss) / len(self.test_loader), len(self.test_loader.dataset), {"accuracy": accuracy}

# Load the datasets
def load_data(dataset_name):
    transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    if dataset_name == "cifar10":
        train_dataset = CIFAR10(root='./data', train=True, download=True, transform=transform)
        test_dataset = CIFAR10(root='./data', train=False, download=True, transform=transform)
    elif dataset_name == "fashion_mnist":
        transform = transforms.Compose([
            transforms.Grayscale(num_output_channels=3),
            transforms.Resize((32, 32)),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ])
        train_dataset = FashionMNIST(root='./data', train=True, download=True, transform=transform)
        test_dataset = FashionMNIST(root='./data', train=False, download=True, transform=transform)
    elif dataset_name == "svhn":
        train_dataset = SVHN(root='./data', split='train', download=True, transform=transform)
        test_dataset = SVHN(root='./data', split='test', download=True, transform=transform)
    else:
        raise ValueError(f"Unknown dataset: {dataset_name}")

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
    return train_loader, test_loader

# Updated client_fn to avoid using Context's client_id
def client_fn(cid: str) -> fl.client.Client:
    if cid == "0":
        return cifar_client
    elif cid == "1":
        return fashion_client
    elif cid == "2":
        return svhn_client

# Start federated learning
def start_federated_learning():
    global cifar_client, fashion_client, svhn_client
    cifar_train_loader, cifar_test_loader = load_data("cifar10")
    fashion_train_loader, fashion_test_loader = load_data("fashion_mnist")
    svhn_train_loader, svhn_test_loader = load_data("svhn")

    cifar_client = FLClient(CNN_Model("cifar10"), cifar_train_loader, cifar_test_loader, "CIFAR-10")
    fashion_client = FLClient(CNN_Model("fashion_mnist"), fashion_train_loader, fashion_test_loader, "Fashion-MNIST")
    svhn_client = FLClient(CNN_Model("svhn"), svhn_train_loader, svhn_test_loader, "SVHN")

    fl.simulation.start_simulation(
        client_fn=client_fn,
        num_clients=3,
        client_resources={"num_cpus": 2, "num_gpus": 1},
        config=fl.server.ServerConfig(num_rounds=10),  # Increased number of rounds
        strategy=fl.server.strategy.FedAvg(
            fraction_fit=1.0,
            min_fit_clients=3,
            min_available_clients=3,
        )
    )

if __name__ == "__main__":
    start_federated_learning()


Files already downloaded and verified
Files already downloaded and verified
Using downloaded and verified file: ./data/train_32x32.mat
Using downloaded and verified file: ./data/test_32x32.mat


[92mINFO [0m:      Starting Flower simulation, config: num_rounds=10, no round_timeout
2024-12-11 17:36:25,297	INFO worker.py:1752 -- Started a local Ray instance.
[92mINFO [0m:      Flower VCE: Ray initialized with resources: {'object_store_memory': 4941254246.0, 'memory': 9882508494.0, 'node:131.227.65.97': 1.0, 'CPU': 12.0, 'node:__internal_head__': 1.0, 'accelerator_type:P4000': 1.0, 'GPU': 1.0}
[92mINFO [0m:      Optimize your simulation with Flower VCE: https://flower.ai/docs/framework/how-to-run-simulations.html
[92mINFO [0m:      Flower VCE: Resources for each Virtual Client: {'num_cpus': 2, 'num_gpus': 1}
[92mINFO [0m:      Flower VCE: Creating VirtualClientEngineActorPool with 1 actors
[92mINFO [0m:      [INIT]
[92mINFO [0m:      Requesting initial parameters from one random client
[36m(ClientAppActor pid=3477374)[0m   return torch.load(io.BytesIO(b))
[36m(ClientAppActor pid=3477374)[0m   return torch.load(io.BytesIO(b))
[92mINFO [0m:      Received initial

[36m(ClientAppActor pid=3477374)[0m Client SVHN: Accuracy = 0.2796558082360172


[36m(ClientAppActor pid=3477374)[0m 
[36m(ClientAppActor pid=3477374)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3477374)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3477374)[0m         


[36m(ClientAppActor pid=3477374)[0m Client CIFAR-10: Accuracy = 0.1131


[36m(ClientAppActor pid=3477374)[0m 
[36m(ClientAppActor pid=3477374)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3477374)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3477374)[0m         
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 2]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)


[36m(ClientAppActor pid=3477374)[0m Client Fashion-MNIST: Accuracy = 0.2142


[36m(ClientAppActor pid=3477374)[0m 
[36m(ClientAppActor pid=3477374)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3477374)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3477374)[0m         
[36m(ClientAppActor pid=3477374)[0m 
[36m(ClientAppActor pid=3477374)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3477374)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3477374)[0m         
[36m(ClientAppActor pid=3477374)[0m 
[36m(ClientAppActor pid=3477374)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3477374)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3477374)[0m         
[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=3477

[36m(ClientAppActor pid=3477374)[0m Client Fashion-MNIST: Accuracy = 0.3774


[36m(ClientAppActor pid=3477374)[0m 
[36m(ClientAppActor pid=3477374)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3477374)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3477374)[0m         


[36m(ClientAppActor pid=3477374)[0m Client CIFAR-10: Accuracy = 0.1734


[36m(ClientAppActor pid=3477374)[0m 
[36m(ClientAppActor pid=3477374)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3477374)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3477374)[0m         
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 3]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)


[36m(ClientAppActor pid=3477374)[0m Client SVHN: Accuracy = 0.41810079901659497


[36m(ClientAppActor pid=3477374)[0m 
[36m(ClientAppActor pid=3477374)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3477374)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3477374)[0m         
[36m(ClientAppActor pid=3477374)[0m 
[36m(ClientAppActor pid=3477374)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3477374)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3477374)[0m         
[36m(ClientAppActor pid=3477374)[0m 
[36m(ClientAppActor pid=3477374)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3477374)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3477374)[0m         
[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=3477

[36m(ClientAppActor pid=3477374)[0m Client SVHN: Accuracy = 0.42198063921327594


[36m(ClientAppActor pid=3477374)[0m 
[36m(ClientAppActor pid=3477374)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3477374)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3477374)[0m         


[36m(ClientAppActor pid=3477374)[0m Client CIFAR-10: Accuracy = 0.1888


[36m(ClientAppActor pid=3477374)[0m 
[36m(ClientAppActor pid=3477374)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3477374)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3477374)[0m         
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 4]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)


[36m(ClientAppActor pid=3477374)[0m Client Fashion-MNIST: Accuracy = 0.4082


[36m(ClientAppActor pid=3477374)[0m 
[36m(ClientAppActor pid=3477374)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3477374)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3477374)[0m         
[36m(ClientAppActor pid=3477374)[0m 
[36m(ClientAppActor pid=3477374)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3477374)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3477374)[0m         
[36m(ClientAppActor pid=3477374)[0m 
[36m(ClientAppActor pid=3477374)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3477374)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3477374)[0m         
[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=3477

[36m(ClientAppActor pid=3477374)[0m Client SVHN: Accuracy = 0.40769053472649047


[36m(ClientAppActor pid=3477374)[0m 
[36m(ClientAppActor pid=3477374)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3477374)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3477374)[0m         


[36m(ClientAppActor pid=3477374)[0m Client Fashion-MNIST: Accuracy = 0.4926


[36m(ClientAppActor pid=3477374)[0m 
[36m(ClientAppActor pid=3477374)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3477374)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3477374)[0m         
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 5]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)


[36m(ClientAppActor pid=3477374)[0m Client CIFAR-10: Accuracy = 0.2065


[36m(ClientAppActor pid=3477374)[0m 
[36m(ClientAppActor pid=3477374)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3477374)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3477374)[0m         
[36m(ClientAppActor pid=3477374)[0m 
[36m(ClientAppActor pid=3477374)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3477374)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3477374)[0m         
[36m(ClientAppActor pid=3477374)[0m 
[36m(ClientAppActor pid=3477374)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3477374)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3477374)[0m         
[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=3477

[36m(ClientAppActor pid=3477374)[0m Client SVHN: Accuracy = 0.44387676705593115


[36m(ClientAppActor pid=3477374)[0m 
[36m(ClientAppActor pid=3477374)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3477374)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3477374)[0m         


[36m(ClientAppActor pid=3477374)[0m Client CIFAR-10: Accuracy = 0.2265


[36m(ClientAppActor pid=3477374)[0m 
[36m(ClientAppActor pid=3477374)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3477374)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3477374)[0m         
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 6]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)


[36m(ClientAppActor pid=3477374)[0m Client Fashion-MNIST: Accuracy = 0.5503


[36m(ClientAppActor pid=3477374)[0m 
[36m(ClientAppActor pid=3477374)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3477374)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3477374)[0m         
[36m(ClientAppActor pid=3477374)[0m 
[36m(ClientAppActor pid=3477374)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3477374)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3477374)[0m         
[36m(ClientAppActor pid=3477374)[0m 
[36m(ClientAppActor pid=3477374)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3477374)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3477374)[0m         
[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=3477

[36m(ClientAppActor pid=3477374)[0m Client Fashion-MNIST: Accuracy = 0.5587


[36m(ClientAppActor pid=3477374)[0m 
[36m(ClientAppActor pid=3477374)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3477374)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3477374)[0m         


[36m(ClientAppActor pid=3477374)[0m Client CIFAR-10: Accuracy = 0.2323


[36m(ClientAppActor pid=3477374)[0m 
[36m(ClientAppActor pid=3477374)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3477374)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3477374)[0m         
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 7]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)


[36m(ClientAppActor pid=3477374)[0m Client SVHN: Accuracy = 0.45413337430854334


[36m(ClientAppActor pid=3477374)[0m 
[36m(ClientAppActor pid=3477374)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3477374)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3477374)[0m         
[36m(ClientAppActor pid=3477374)[0m 
[36m(ClientAppActor pid=3477374)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3477374)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3477374)[0m         
[36m(ClientAppActor pid=3477374)[0m 
[36m(ClientAppActor pid=3477374)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3477374)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3477374)[0m         
[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=3477

[36m(ClientAppActor pid=3477374)[0m Client CIFAR-10: Accuracy = 0.243


[36m(ClientAppActor pid=3477374)[0m 
[36m(ClientAppActor pid=3477374)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3477374)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3477374)[0m         


[36m(ClientAppActor pid=3477374)[0m Client SVHN: Accuracy = 0.46469729563614015


[36m(ClientAppActor pid=3477374)[0m 
[36m(ClientAppActor pid=3477374)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3477374)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3477374)[0m         
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 8]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)


[36m(ClientAppActor pid=3477374)[0m Client Fashion-MNIST: Accuracy = 0.5659


[36m(ClientAppActor pid=3477374)[0m 
[36m(ClientAppActor pid=3477374)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3477374)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3477374)[0m         
[36m(ClientAppActor pid=3477374)[0m 
[36m(ClientAppActor pid=3477374)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3477374)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3477374)[0m         
[36m(ClientAppActor pid=3477374)[0m 
[36m(ClientAppActor pid=3477374)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3477374)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3477374)[0m         
[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=3477

[36m(ClientAppActor pid=3477374)[0m Client Fashion-MNIST: Accuracy = 0.5525


[36m(ClientAppActor pid=3477374)[0m 
[36m(ClientAppActor pid=3477374)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3477374)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3477374)[0m         


[36m(ClientAppActor pid=3477374)[0m Client CIFAR-10: Accuracy = 0.2522


[36m(ClientAppActor pid=3477374)[0m 
[36m(ClientAppActor pid=3477374)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3477374)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3477374)[0m         
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 9]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)


[36m(ClientAppActor pid=3477374)[0m Client SVHN: Accuracy = 0.4719960049170252


[36m(ClientAppActor pid=3477374)[0m 
[36m(ClientAppActor pid=3477374)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3477374)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3477374)[0m         
[36m(ClientAppActor pid=3477374)[0m 
[36m(ClientAppActor pid=3477374)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3477374)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3477374)[0m         
[36m(ClientAppActor pid=3477374)[0m 
[36m(ClientAppActor pid=3477374)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3477374)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3477374)[0m         
[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=3477

[36m(ClientAppActor pid=3477374)[0m Client CIFAR-10: Accuracy = 0.2577


[36m(ClientAppActor pid=3477374)[0m 
[36m(ClientAppActor pid=3477374)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3477374)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3477374)[0m         


[36m(ClientAppActor pid=3477374)[0m Client SVHN: Accuracy = 0.4641979102642901


[36m(ClientAppActor pid=3477374)[0m 
[36m(ClientAppActor pid=3477374)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3477374)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3477374)[0m         
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 10]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)


[36m(ClientAppActor pid=3477374)[0m Client Fashion-MNIST: Accuracy = 0.5438


[36m(ClientAppActor pid=3477374)[0m 
[36m(ClientAppActor pid=3477374)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3477374)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3477374)[0m         
[36m(ClientAppActor pid=3477374)[0m 
[36m(ClientAppActor pid=3477374)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3477374)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3477374)[0m         
[36m(ClientAppActor pid=3477374)[0m 
[36m(ClientAppActor pid=3477374)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3477374)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3477374)[0m         
[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=3477

[36m(ClientAppActor pid=3477374)[0m Client Fashion-MNIST: Accuracy = 0.5683


[36m(ClientAppActor pid=3477374)[0m 
[36m(ClientAppActor pid=3477374)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3477374)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3477374)[0m         


[36m(ClientAppActor pid=3477374)[0m Client SVHN: Accuracy = 0.468000921942225


[36m(ClientAppActor pid=3477374)[0m 
[36m(ClientAppActor pid=3477374)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3477374)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3477374)[0m         
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [SUMMARY]
[92mINFO [0m:      Run finished 10 round(s) in 3290.23s
[92mINFO [0m:      	History (loss, distributed):
[92mINFO [0m:      		round 1: 2.2972366015501593
[92mINFO [0m:      		round 2: 2.2784153685049544
[92mINFO [0m:      		round 3: 2.249284768228714
[92mINFO [0m:      		round 4: 2.2202483703289655
[92mINFO [0m:      		round 5: 2.1866156974433557
[92mINFO [0m:      		round 6: 2.162673748756688
[92mINFO [0m:      		round 7: 2.140796797593429
[92mINFO [0m:      		round 8: 2.1157980077295098
[92mINFO [0m:      		round 9: 2.108499870196716
[92mINFO [0m:      		round 10: 2

[36m(ClientAppActor pid=3477374)[0m Client CIFAR-10: Accuracy = 0.2598


In [None]:
####FEDPROXXX

In [11]:
import flwr as fl
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10, FashionMNIST, SVHN

# Constants
BATCH_SIZE = 128  # Default batch size for proper processing
LEARNING_RATE = 0.001  # Default learning rate for stability
EPOCHS = 10  # Standard epochs for training
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define the model with dataset-specific layers
class CNN_Model(nn.Module):
    def __init__(self, dataset_name):
        super(CNN_Model, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.pool = nn.AvgPool2d(2, 2)
        self.fc_config = nn.Linear(64 * 8 * 8, 128)

        # Dataset-specific layers
        if dataset_name == "cifar10":
            self.fc_op = nn.Sequential(
                nn.Linear(128, 64),
                nn.ReLU(),
                nn.Linear(64, 10)
            )
        elif dataset_name == "fashion_mnist":
            self.fc_op = nn.Sequential(
                nn.Linear(128, 32),
                nn.ReLU(),
                nn.Linear(32, 10)
            )
        elif dataset_name == "svhn":
            self.fc_op = nn.Sequential(
                nn.Linear(128, 64),
                nn.ReLU(),
                nn.Linear(64, 10)
            )

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool(x)
        x = F.relu(self.conv2(x))
        x = self.pool(x)
        x = x.view(-1, 64 * 8 * 8)
        x = F.relu(self.fc_config(x))
        x = self.fc_op(x)
        return F.softmax(x, dim=1)

# Define the Flower client with PFL using FedProx
class FLClient(fl.client.NumPyClient):
    def __init__(self, model, train_loader, test_loader, client_id):
        self.model = model.to(DEVICE)
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = optim.Adam(self.model.parameters(), lr=LEARNING_RATE)  # Default optimizer for proper training
        self.client_id = client_id

    def get_parameters(self, config=None):
        params = {
            k: v.cpu().numpy()
            for k, v in self.model.state_dict().items()
            if not k.startswith("fc_op")
        }
        return list(params.values())

    def set_parameters(self, parameters):
        state_dict = self.model.state_dict()
        shared_keys = [k for k in state_dict.keys() if not k.startswith("fc_op")]
        for key, param in zip(shared_keys, parameters):
            state_dict[key] = torch.tensor(param)
        self.model.load_state_dict(state_dict, strict=False)

    def fit(self, parameters, config):
        # Set global parameters on the local model
        self.set_parameters(parameters)
        self.model.train()
        
        # Save global model parameters for proximity
        global_params = [param.clone().detach().to(DEVICE) for param in self.model.parameters()]
        
        for epoch in range(EPOCHS):
            for images, labels in self.train_loader:
                images, labels = images.to(DEVICE), labels.to(DEVICE)
                self.optimizer.zero_grad()
                outputs = self.model(images)
                loss = self.criterion(outputs, labels)
                
                # Add proximal term to the loss
                prox_term = 0.0
                for param, global_param in zip(self.model.parameters(), global_params):
                    prox_term += torch.norm(param - global_param) ** 2
                loss += (0.01 / 2) * prox_term  # 0.01 is the proximal coefficient
                
                loss.backward()
                self.optimizer.step()
        
        # Fine-tune locally for personalization
        for epoch in range(2):  # Fine-tune for 2 local epochs
            for images, labels in self.train_loader:
                images, labels = images.to(DEVICE), labels.to(DEVICE)
                self.optimizer.zero_grad()
                outputs = self.model(images)
                loss = self.criterion(outputs, labels)
                loss.backward()
                self.optimizer.step()
        
        return self.get_parameters(config), len(self.train_loader.dataset), {}

    def evaluate(self, parameters, config):
        self.set_parameters(parameters)
        self.model.eval()
        correct, total = 0, 0
        loss = 0.0
        with torch.no_grad():
            for images, labels in self.test_loader:
                images, labels = images.to(DEVICE), labels.to(DEVICE)
                outputs = self.model(images)
                loss += self.criterion(outputs, labels).item()
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        accuracy = correct / total
        print(f"Client {self.client_id}: Accuracy = {accuracy}")
        return float(loss) / len(self.test_loader), len(self.test_loader.dataset), {"accuracy": accuracy}

# Load the datasets
def load_data(dataset_name):
    transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    if dataset_name == "cifar10":
        train_dataset = CIFAR10(root='./data', train=True, download=True, transform=transform)
        test_dataset = CIFAR10(root='./data', train=False, download=True, transform=transform)
    elif dataset_name == "fashion_mnist": 
        transform = transforms.Compose([
            transforms.Grayscale(num_output_channels=3),
            transforms.Resize((32, 32)),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ])
        train_dataset = FashionMNIST(root='./data', train=True, download=True, transform=transform)
        test_dataset = FashionMNIST(root='./data', train=False, download=True, transform=transform)
    elif dataset_name == "svhn":
        train_dataset = SVHN(root='./data', split='train', download=True, transform=transform)
        test_dataset = SVHN(root='./data', split='test', download=True, transform=transform)
    else:
        raise ValueError(f"Unknown dataset: {dataset_name}")

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
    return train_loader, test_loader

# Updated client_fn to avoid using Context's client_id
def client_fn(cid: str) -> fl.client.Client:
    if cid == "0":
        return cifar_client
    elif cid == "1":
        return fashion_client
    elif cid == "2":
        return svhn_client

# Start federated learning
def start_federated_learning():
    global cifar_client, fashion_client, svhn_client
    cifar_train_loader, cifar_test_loader = load_data("cifar10")
    fashion_train_loader, fashion_test_loader = load_data("fashion_mnist")
    svhn_train_loader, svhn_test_loader = load_data("svhn")

    cifar_client = FLClient(CNN_Model("cifar10"), cifar_train_loader, cifar_test_loader, "CIFAR-10")
    fashion_client = FLClient(CNN_Model("fashion_mnist"), fashion_train_loader, fashion_test_loader, "Fashion-MNIST")
    svhn_client = FLClient(CNN_Model("svhn"), svhn_train_loader, svhn_test_loader, "SVHN")

    fl.simulation.start_simulation(
        client_fn=client_fn,
        num_clients=3,
        client_resources={"num_cpus": 4, "num_gpus": 1},
        config=fl.server.ServerConfig(num_rounds=10),  # Increased number of rounds for proper results
        strategy=fl.server.strategy.FedAvg(
            fraction_fit=1.0,
            min_fit_clients=3,
            min_available_clients=3,
        )
    )

if __name__ == "__main__":
    start_federated_learning()


Files already downloaded and verified
Files already downloaded and verified
Using downloaded and verified file: ./data/train_32x32.mat
Using downloaded and verified file: ./data/test_32x32.mat


[92mINFO [0m:      Starting Flower simulation, config: num_rounds=10, no round_timeout
2024-12-12 19:41:14,201	INFO worker.py:1752 -- Started a local Ray instance.
[92mINFO [0m:      Flower VCE: Ray initialized with resources: {'object_store_memory': 5006184038.0, 'memory': 10012368078.0, 'GPU': 1.0, 'accelerator_type:P4000': 1.0, 'node:__internal_head__': 1.0, 'node:131.227.65.97': 1.0, 'CPU': 12.0}
[92mINFO [0m:      Optimize your simulation with Flower VCE: https://flower.ai/docs/framework/how-to-run-simulations.html
[92mINFO [0m:      Flower VCE: Resources for each Virtual Client: {'num_cpus': 4, 'num_gpus': 1}
[92mINFO [0m:      Flower VCE: Creating VirtualClientEngineActorPool with 1 actors
[92mINFO [0m:      [INIT]
[92mINFO [0m:      Requesting initial parameters from one random client
[36m(ClientAppActor pid=3810323)[0m   return torch.load(io.BytesIO(b))
[36m(ClientAppActor pid=3810323)[0m   return torch.load(io.BytesIO(b))
[92mINFO [0m:      Received initia

[36m(ClientAppActor pid=3810323)[0m Client CIFAR-10: Accuracy = 0.1413


[36m(ClientAppActor pid=3810323)[0m 
[36m(ClientAppActor pid=3810323)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3810323)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3810323)[0m         


[36m(ClientAppActor pid=3810323)[0m Client Fashion-MNIST: Accuracy = 0.1808


[36m(ClientAppActor pid=3810323)[0m 
[36m(ClientAppActor pid=3810323)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3810323)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3810323)[0m         
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 2]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)


[36m(ClientAppActor pid=3810323)[0m Client SVHN: Accuracy = 0.3015519360786724


[36m(ClientAppActor pid=3810323)[0m 
[36m(ClientAppActor pid=3810323)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3810323)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3810323)[0m         
[36m(ClientAppActor pid=3810323)[0m 
[36m(ClientAppActor pid=3810323)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3810323)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3810323)[0m         
[36m(ClientAppActor pid=3810323)[0m 
[36m(ClientAppActor pid=3810323)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3810323)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3810323)[0m         
[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=3810

[36m(ClientAppActor pid=3810323)[0m Client SVHN: Accuracy = 0.32671327596803934


[36m(ClientAppActor pid=3810323)[0m 
[36m(ClientAppActor pid=3810323)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3810323)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3810323)[0m         


[36m(ClientAppActor pid=3810323)[0m Client CIFAR-10: Accuracy = 0.1729


[36m(ClientAppActor pid=3810323)[0m 
[36m(ClientAppActor pid=3810323)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3810323)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3810323)[0m         
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 3]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)


[36m(ClientAppActor pid=3810323)[0m Client Fashion-MNIST: Accuracy = 0.3102


[36m(ClientAppActor pid=3810323)[0m 
[36m(ClientAppActor pid=3810323)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3810323)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3810323)[0m         
[36m(ClientAppActor pid=3810323)[0m 
[36m(ClientAppActor pid=3810323)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3810323)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3810323)[0m         
[36m(ClientAppActor pid=3810323)[0m 
[36m(ClientAppActor pid=3810323)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3810323)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3810323)[0m         
[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=3810

[36m(ClientAppActor pid=3810323)[0m Client SVHN: Accuracy = 0.3581361401352182


[36m(ClientAppActor pid=3810323)[0m 
[36m(ClientAppActor pid=3810323)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3810323)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3810323)[0m         


[36m(ClientAppActor pid=3810323)[0m Client CIFAR-10: Accuracy = 0.199


[36m(ClientAppActor pid=3810323)[0m 
[36m(ClientAppActor pid=3810323)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3810323)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3810323)[0m         
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 4]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)


[36m(ClientAppActor pid=3810323)[0m Client Fashion-MNIST: Accuracy = 0.4092


[36m(ClientAppActor pid=3810323)[0m 
[36m(ClientAppActor pid=3810323)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3810323)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3810323)[0m         
[36m(ClientAppActor pid=3810323)[0m 
[36m(ClientAppActor pid=3810323)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3810323)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3810323)[0m         
[36m(ClientAppActor pid=3810323)[0m 
[36m(ClientAppActor pid=3810323)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3810323)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3810323)[0m         
[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=3810

[36m(ClientAppActor pid=3810323)[0m Client Fashion-MNIST: Accuracy = 0.4595


[36m(ClientAppActor pid=3810323)[0m 
[36m(ClientAppActor pid=3810323)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3810323)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3810323)[0m         


[36m(ClientAppActor pid=3810323)[0m Client CIFAR-10: Accuracy = 0.1961


[36m(ClientAppActor pid=3810323)[0m 
[36m(ClientAppActor pid=3810323)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3810323)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3810323)[0m         
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 5]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)


[36m(ClientAppActor pid=3810323)[0m Client SVHN: Accuracy = 0.3770359557467732


[36m(ClientAppActor pid=3810323)[0m 
[36m(ClientAppActor pid=3810323)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3810323)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3810323)[0m         
[36m(ClientAppActor pid=3810323)[0m 
[36m(ClientAppActor pid=3810323)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3810323)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3810323)[0m         
[36m(ClientAppActor pid=3810323)[0m 
[36m(ClientAppActor pid=3810323)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3810323)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3810323)[0m         
[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=3810

[36m(ClientAppActor pid=3810323)[0m Client CIFAR-10: Accuracy = 0.1907


[36m(ClientAppActor pid=3810323)[0m 
[36m(ClientAppActor pid=3810323)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3810323)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3810323)[0m         


[36m(ClientAppActor pid=3810323)[0m Client Fashion-MNIST: Accuracy = 0.5523


[36m(ClientAppActor pid=3810323)[0m 
[36m(ClientAppActor pid=3810323)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3810323)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3810323)[0m         
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 6]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)


[36m(ClientAppActor pid=3810323)[0m Client SVHN: Accuracy = 0.36393669330055317


[36m(ClientAppActor pid=3810323)[0m 
[36m(ClientAppActor pid=3810323)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3810323)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3810323)[0m         
[36m(ClientAppActor pid=3810323)[0m 
[36m(ClientAppActor pid=3810323)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3810323)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3810323)[0m         
[36m(ClientAppActor pid=3810323)[0m 
[36m(ClientAppActor pid=3810323)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3810323)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3810323)[0m         
[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=3810

[36m(ClientAppActor pid=3810323)[0m Client Fashion-MNIST: Accuracy = 0.5345


[36m(ClientAppActor pid=3810323)[0m 
[36m(ClientAppActor pid=3810323)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3810323)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3810323)[0m         


[36m(ClientAppActor pid=3810323)[0m Client SVHN: Accuracy = 0.360940381069453


[36m(ClientAppActor pid=3810323)[0m 
[36m(ClientAppActor pid=3810323)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3810323)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3810323)[0m         
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 7]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)


[36m(ClientAppActor pid=3810323)[0m Client CIFAR-10: Accuracy = 0.1988


[36m(ClientAppActor pid=3810323)[0m 
[36m(ClientAppActor pid=3810323)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3810323)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3810323)[0m         
[36m(ClientAppActor pid=3810323)[0m 
[36m(ClientAppActor pid=3810323)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3810323)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3810323)[0m         
[36m(ClientAppActor pid=3810323)[0m 
[36m(ClientAppActor pid=3810323)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3810323)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3810323)[0m         
[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=3810

[36m(ClientAppActor pid=3810323)[0m Client Fashion-MNIST: Accuracy = 0.5721


[36m(ClientAppActor pid=3810323)[0m 
[36m(ClientAppActor pid=3810323)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3810323)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3810323)[0m         


[36m(ClientAppActor pid=3810323)[0m Client CIFAR-10: Accuracy = 0.2131


[36m(ClientAppActor pid=3810323)[0m 
[36m(ClientAppActor pid=3810323)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3810323)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3810323)[0m         
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 8]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)


[36m(ClientAppActor pid=3810323)[0m Client SVHN: Accuracy = 0.36570374923171484


[36m(ClientAppActor pid=3810323)[0m 
[36m(ClientAppActor pid=3810323)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3810323)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3810323)[0m         
[36m(ClientAppActor pid=3810323)[0m 
[36m(ClientAppActor pid=3810323)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3810323)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3810323)[0m         
[36m(ClientAppActor pid=3810323)[0m 
[36m(ClientAppActor pid=3810323)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3810323)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3810323)[0m         
[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=3810

[36m(ClientAppActor pid=3810323)[0m Client CIFAR-10: Accuracy = 0.2072


[36m(ClientAppActor pid=3810323)[0m 
[36m(ClientAppActor pid=3810323)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3810323)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3810323)[0m         


[36m(ClientAppActor pid=3810323)[0m Client SVHN: Accuracy = 0.34857098955132143


[36m(ClientAppActor pid=3810323)[0m 
[36m(ClientAppActor pid=3810323)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3810323)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3810323)[0m         
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 9]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)


[36m(ClientAppActor pid=3810323)[0m Client Fashion-MNIST: Accuracy = 0.5492


[36m(ClientAppActor pid=3810323)[0m 
[36m(ClientAppActor pid=3810323)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3810323)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3810323)[0m         
[36m(ClientAppActor pid=3810323)[0m 
[36m(ClientAppActor pid=3810323)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3810323)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3810323)[0m         
[36m(ClientAppActor pid=3810323)[0m 
[36m(ClientAppActor pid=3810323)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3810323)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3810323)[0m         
[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=3810

[36m(ClientAppActor pid=3810323)[0m Client SVHN: Accuracy = 0.34526736324523666


[36m(ClientAppActor pid=3810323)[0m 
[36m(ClientAppActor pid=3810323)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3810323)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3810323)[0m         


[36m(ClientAppActor pid=3810323)[0m Client Fashion-MNIST: Accuracy = 0.5692


[36m(ClientAppActor pid=3810323)[0m 
[36m(ClientAppActor pid=3810323)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3810323)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3810323)[0m         
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 10]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)


[36m(ClientAppActor pid=3810323)[0m Client CIFAR-10: Accuracy = 0.2051


[36m(ClientAppActor pid=3810323)[0m 
[36m(ClientAppActor pid=3810323)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3810323)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3810323)[0m         
[36m(ClientAppActor pid=3810323)[0m 
[36m(ClientAppActor pid=3810323)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3810323)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3810323)[0m         
[36m(ClientAppActor pid=3810323)[0m 
[36m(ClientAppActor pid=3810323)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3810323)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3810323)[0m         
[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=3810

[36m(ClientAppActor pid=3810323)[0m Client Fashion-MNIST: Accuracy = 0.5794


[36m(ClientAppActor pid=3810323)[0m 
[36m(ClientAppActor pid=3810323)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3810323)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3810323)[0m         


[36m(ClientAppActor pid=3810323)[0m Client SVHN: Accuracy = 0.3524124154886294


[36m(ClientAppActor pid=3810323)[0m 
[36m(ClientAppActor pid=3810323)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3810323)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3810323)[0m         
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [SUMMARY]
[92mINFO [0m:      Run finished 10 round(s) in 5081.95s
[92mINFO [0m:      	History (loss, distributed):
[92mINFO [0m:      		round 1: 2.293990036116014
[92mINFO [0m:      		round 2: 2.270926991560752
[92mINFO [0m:      		round 3: 2.250725599171052
[92mINFO [0m:      		round 4: 2.229209118952503
[92mINFO [0m:      		round 5: 2.2110769580682255
[92mINFO [0m:      		round 6: 2.197544544239026
[92mINFO [0m:      		round 7: 2.1829120097083607
[92mINFO [0m:      		round 8: 2.173876832365916
[92mINFO [0m:      		round 9: 2.1604292001759267
[92mINFO [0m:      		round 10: 2.1

[36m(ClientAppActor pid=3810323)[0m Client CIFAR-10: Accuracy = 0.1912


In [None]:
import flwr as fl
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10, FashionMNIST, SVHN

# Constants
BATCH_SIZE = 128  # Default batch size for proper processing
LEARNING_RATE = 0.001  # Default learning rate for stability
EPOCHS = 10  # Standard epochs for training
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define the model with dataset-specific layers
class CNN_Model(nn.Module):
    def __init__(self, dataset_name):
        super(CNN_Model, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.pool = nn.AvgPool2d(2, 2)
        self.fc_config = nn.Linear(64 * 8 * 8, 128)

        # Dataset-specific layers
        if dataset_name == "cifar10":
            self.fc_op = nn.Sequential(
                nn.Linear(128, 64),
                nn.ReLU(),
                nn.Linear(64, 10)
            )
        elif dataset_name == "fashion_mnist":
            self.fc_op = nn.Sequential(
                nn.Linear(128, 32),
                nn.ReLU(),
                nn.Linear(32, 10)
            )
        elif dataset_name == "svhn":
            self.fc_op = nn.Sequential(
                nn.Linear(128, 64),
                nn.ReLU(),
                nn.Linear(64, 10)
            )

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool(x)
        x = F.relu(self.conv2(x))
        x = self.pool(x)
        x = x.view(-1, 64 * 8 * 8)
        x = F.relu(self.fc_config(x))
        x = self.fc_op(x)
        return F.softmax(x, dim=1)

# Define the Flower client with PFL using FedProx
class FLClient(fl.client.NumPyClient):
    def __init__(self, model, train_loader, test_loader, client_id):
        self.model = model.to(DEVICE)
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = optim.Adam(self.model.parameters(), lr=LEARNING_RATE)  # Default optimizer for proper training
        self.client_id = client_id

    def get_parameters(self, config=None):
        params = {
            k: v.cpu().numpy()
            for k, v in self.model.state_dict().items()
            if not k.startswith("fc_op")
        }
        return list(params.values())

    def set_parameters(self, parameters):
        state_dict = self.model.state_dict()
        shared_keys = [k for k in state_dict.keys() if not k.startswith("fc_op")]
        for key, param in zip(shared_keys, parameters):
            state_dict[key] = torch.tensor(param)
        self.model.load_state_dict(state_dict, strict=False)

    def fit(self, parameters, config):
        # Set global parameters on the local model
        self.set_parameters(parameters)
        self.model.train()
        
        # Save global model parameters for proximity
        global_params = [param.clone().detach().to(DEVICE) for param in self.model.parameters()]
        
        for epoch in range(EPOCHS):
            for images, labels in self.train_loader:
                images, labels = images.to(DEVICE), labels.to(DEVICE)
                self.optimizer.zero_grad()
                outputs = self.model(images)
                loss = self.criterion(outputs, labels)
                
                # Add proximal term to the loss
                prox_term = 0.0
                for param, global_param in zip(self.model.parameters(), global_params):
                    prox_term += torch.norm(param - global_param) ** 2
                loss += (0.01 / 2) * prox_term  # 0.01 is the proximal coefficient
                
                loss.backward()
                self.optimizer.step()
        
        # Fine-tune locally for personalization
        for epoch in range(2):  # Fine-tune for 2 local epochs
            for images, labels in self.train_loader:
                images, labels = images.to(DEVICE), labels.to(DEVICE)
                self.optimizer.zero_grad()
                outputs = self.model(images)
                loss = self.criterion(outputs, labels)
                loss.backward()
                self.optimizer.step()
        
        return self.get_parameters(config), len(self.train_loader.dataset), {}

    def evaluate(self, parameters, config):
        self.set_parameters(parameters)
        self.model.eval()
        correct, total = 0, 0
        loss = 0.0
        with torch.no_grad():
            for images, labels in self.test_loader:
                images, labels = images.to(DEVICE), labels.to(DEVICE)
                outputs = self.model(images)
                loss += self.criterion(outputs, labels).item()
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        accuracy = correct / total
        print(f"Client {self.client_id}: Accuracy = {accuracy}")
        return float(loss) / len(self.test_loader), len(self.test_loader.dataset), {"accuracy": accuracy}

# Load the datasets
def load_data(dataset_name):
    transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    if dataset_name == "cifar10":
        train_dataset = CIFAR10(root='./data', train=True, download=True, transform=transform)
        test_dataset = CIFAR10(root='./data', train=False, download=True, transform=transform)
    elif dataset_name == "fashion_mnist": 
        transform = transforms.Compose([
            transforms.Grayscale(num_output_channels=3),
            transforms.Resize((32, 32)),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ])
        train_dataset = FashionMNIST(root='./data', train=True, download=True, transform=transform)
        test_dataset = FashionMNIST(root='./data', train=False, download=True, transform=transform)
    elif dataset_name == "svhn":
        train_dataset = SVHN(root='./data', split='train', download=True, transform=transform)
        test_dataset = SVHN(root='./data', split='test', download=True, transform=transform)
    else:
        raise ValueError(f"Unknown dataset: {dataset_name}")

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
    return train_loader, test_loader

# Updated client_fn to avoid using Context's client_id
def client_fn(cid: str) -> fl.client.Client:
    if cid == "0":
        return cifar_client
    elif cid == "1":
        return fashion_client
    elif cid == "2":
        return svhn_client

# Start federated learning
def start_federated_learning():
    global cifar_client, fashion_client, svhn_client
    cifar_train_loader, cifar_test_loader = load_data("cifar10")
    fashion_train_loader, fashion_test_loader = load_data("fashion_mnist")
    svhn_train_loader, svhn_test_loader = load_data("svhn")

    cifar_client = FLClient(CNN_Model("cifar10"), cifar_train_loader, cifar_test_loader, "CIFAR-10")
    fashion_client = FLClient(CNN_Model("fashion_mnist"), fashion_train_loader, fashion_test_loader, "Fashion-MNIST")
    svhn_client = FLClient(CNN_Model("svhn"), svhn_train_loader, svhn_test_loader, "SVHN")

    fl.simulation.start_simulation(
        client_fn=client_fn,
        num_clients=3,
        client_resources={"num_cpus": 4, "num_gpus": 1},
        config=fl.server.ServerConfig(num_rounds=20),  # Increased number of rounds for proper results
        strategy=fl.server.strategy.FedAvg(
            fraction_fit=1.0,
            min_fit_clients=3,
            min_available_clients=3,
        )
    )

if __name__ == "__main__":
    start_federated_learning()


2025-01-03 21:22:44,790	INFO util.py:154 -- Outdated packages:
  ipywidgets==6.0.0 found, needs ipywidgets>=8
Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.


In [None]:
import flwr as fl
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10, FashionMNIST, SVHN

# Constants
BATCH_SIZE = 64  # Default batch size for proper processing
LEARNING_RATE = 0.001  # Default learning rate for stability
EPOCHS = 5  # Standard epochs for training
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define the model with dataset-specific layers
class CNN_Model(nn.Module):
    def __init__(self, dataset_name):
        super(CNN_Model, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.pool = nn.AvgPool2d(2, 2)
        self.fc_config = nn.Linear(64 * 8 * 8, 128)

        # Dataset-specific layers
        if dataset_name == "cifar10":
            self.fc_op = nn.Sequential(
                nn.Linear(128, 64),
                nn.ReLU(),
                nn.Linear(64, 10)
            )
        elif dataset_name == "fashion_mnist":
            self.fc_op = nn.Sequential(
                nn.Linear(128, 32),
                nn.ReLU(),
                nn.Linear(32, 10)
            )
        elif dataset_name == "svhn":
            self.fc_op = nn.Sequential(
                nn.Linear(128, 64),
                nn.ReLU(),
                nn.Linear(64, 10)
            )

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool(x)
        x = F.relu(self.conv2(x))
        x = self.pool(x)
        x = x.view(-1, 64 * 8 * 8)
        x = F.relu(self.fc_config(x))
        x = self.fc_op(x)
        return F.softmax(x, dim=1)

# Define the Flower client with PFL using FedProx
class FLClient(fl.client.NumPyClient):
    def __init__(self, model, train_loader, test_loader, client_id):
        self.model = model.to(DEVICE)
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = optim.Adam(self.model.parameters(), lr=LEARNING_RATE)  # Default optimizer for proper training
        self.client_id = client_id

    def get_parameters(self, config=None):
        params = {
            k: v.cpu().numpy()
            for k, v in self.model.state_dict().items()
            if not k.startswith("fc_op")
        }
        return list(params.values())

    def set_parameters(self, parameters):
        state_dict = self.model.state_dict()
        shared_keys = [k for k in state_dict.keys() if not k.startswith("fc_op")]
        for key, param in zip(shared_keys, parameters):
            state_dict[key] = torch.tensor(param)
        self.model.load_state_dict(state_dict, strict=False)

    def fit(self, parameters, config):
        # Set global parameters on the local model
        self.set_parameters(parameters)
        self.model.train()
        
        # Save global model parameters for proximity
        global_params = [param.clone().detach().to(DEVICE) for param in self.model.parameters()]
        
        for epoch in range(EPOCHS):
            for images, labels in self.train_loader:
                images, labels = images.to(DEVICE), labels.to(DEVICE)
                self.optimizer.zero_grad()
                outputs = self.model(images)
                loss = self.criterion(outputs, labels)
                
                # Add proximal term to the loss
                prox_term = 0.0
                for param, global_param in zip(self.model.parameters(), global_params):
                    prox_term += torch.norm(param - global_param) ** 2
                loss += (0.01 / 2) * prox_term  # 0.01 is the proximal coefficient
                
                loss.backward()
                self.optimizer.step()
        
        # Fine-tune locally for personalization
        for epoch in range(2):  # Fine-tune for 2 local epochs
            for images, labels in self.train_loader:
                images, labels = images.to(DEVICE), labels.to(DEVICE)
                self.optimizer.zero_grad()
                outputs = self.model(images)
                loss = self.criterion(outputs, labels)
                loss.backward()
                self.optimizer.step()
        
        return self.get_parameters(config), len(self.train_loader.dataset), {}

    def evaluate(self, parameters, config):
        self.set_parameters(parameters)
        self.model.eval()
        correct, total = 0, 0
        loss = 0.0
        with torch.no_grad():
            for images, labels in self.test_loader:
                images, labels = images.to(DEVICE), labels.to(DEVICE)
                outputs = self.model(images)
                loss += self.criterion(outputs, labels).item()
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        accuracy = correct / total
        print(f"Client {self.client_id}: Accuracy = {accuracy}")
        return float(loss) / len(self.test_loader), len(self.test_loader.dataset), {"accuracy": accuracy}

# Load the datasets
def load_data(dataset_name):
    transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    if dataset_name == "cifar10":
        train_dataset = CIFAR10(root='./data', train=True, download=True, transform=transform)
        test_dataset = CIFAR10(root='./data', train=False, download=True, transform=transform)
    elif dataset_name == "fashion_mnist": 
        transform = transforms.Compose([
            transforms.Grayscale(num_output_channels=3),
            transforms.Resize((32, 32)),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ])
        train_dataset = FashionMNIST(root='./data', train=True, download=True, transform=transform)
        test_dataset = FashionMNIST(root='./data', train=False, download=True, transform=transform)
    elif dataset_name == "svhn":
        train_dataset = SVHN(root='./data', split='train', download=True, transform=transform)
        test_dataset = SVHN(root='./data', split='test', download=True, transform=transform)
    else:
        raise ValueError(f"Unknown dataset: {dataset_name}")

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
    return train_loader, test_loader

# Updated client_fn to avoid using Context's client_id
def client_fn(cid: str) -> fl.client.Client:
    if cid == "0":
        return cifar_client
    elif cid == "1":
        return fashion_client
    elif cid == "2":
        return svhn_client

# Start federated learning
def start_federated_learning():
    global cifar_client, fashion_client, svhn_client
    cifar_train_loader, cifar_test_loader = load_data("cifar10")
    fashion_train_loader, fashion_test_loader = load_data("fashion_mnist")
    svhn_train_loader, svhn_test_loader = load_data("svhn")

    cifar_client = FLClient(CNN_Model("cifar10"), cifar_train_loader, cifar_test_loader, "CIFAR-10")
    fashion_client = FLClient(CNN_Model("fashion_mnist"), fashion_train_loader, fashion_test_loader, "Fashion-MNIST")
    svhn_client = FLClient(CNN_Model("svhn"), svhn_train_loader, svhn_test_loader, "SVHN")

    fl.simulation.start_simulation(
        client_fn=client_fn,
        num_clients=3,
        client_resources={"num_cpus": 4, "num_gpus": 1},
        config=fl.server.ServerConfig(num_rounds=5),  # Increased number of rounds for proper results
        strategy=fl.server.strategy.FedAvg(
            fraction_fit=1.0,
            min_fit_clients=3,
            min_available_clients=3,
        )
    )

if __name__ == "__main__":
    start_federated_learning()


In [None]:
##ADaptive lambda 

In [12]:
import flwr as fl
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10, FashionMNIST, SVHN
 
# Constants
BATCH_SIZE = 32  # Default batch size for proper processing
LEARNING_RATE = 0.001  # Default learning rate for stability
EPOCHS = 2  # Standard epochs for training
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
# Define the model with dataset-specific layers
class CNN_Model(nn.Module):
    def __init__(self, dataset_name):
        super(CNN_Model, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.pool = nn.AvgPool2d(2, 2)
        self.fc_config = nn.Linear(64 * 8 * 8, 128)
 
        # Dataset-specific layers
        if dataset_name == "cifar10":
            self.fc_op = nn.Sequential(
                nn.Linear(128, 64),
                nn.ReLU(),
                nn.Linear(64, 10)
            )
        elif dataset_name == "fashion_mnist":
            self.fc_op = nn.Sequential(
                nn.Linear(128, 32),
                nn.ReLU(),
                nn.Linear(32, 10)
            )
        elif dataset_name == "svhn":
            self.fc_op = nn.Sequential(
                nn.Linear(128, 64),
                nn.ReLU(),
                nn.Linear(64, 10)
            )
 
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool(x)
        x = F.relu(self.conv2(x))
        x = self.pool(x)
        x = x.view(-1, 64 * 8 * 8)
        x = F.relu(self.fc_config(x))
        x = self.fc_op(x)
        return F.softmax(x, dim=1)
 
# Helper functions for heterogeneity and dynamic lambda
def compute_statistical_heterogeneity(global_params, local_params):
    heterogeneity = 0.0
    for global_param, local_param in zip(global_params, local_params):
        heterogeneity += torch.norm(local_param - global_param) ** 2
    return heterogeneity
 
def update_lambda(heterogeneity, alpha=0.1, epsilon=1e-5):
    return alpha / (heterogeneity + epsilon)
 
# Define the Flower client with PFL using FedProx
class FLClient(fl.client.NumPyClient):
    def __init__(self, model, train_loader, test_loader, client_id):
        self.model = model.to(DEVICE)
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = optim.Adam(self.model.parameters(), lr=LEARNING_RATE)
        self.client_id = client_id
 
    def get_parameters(self, config=None):
        params = {
            k: v.cpu().numpy()
            for k, v in self.model.state_dict().items()
            if not k.startswith("fc_op")
        }
        return list(params.values())
 
    def set_parameters(self, parameters):
        state_dict = self.model.state_dict()
        shared_keys = [k for k in state_dict.keys() if not k.startswith("fc_op")]
        for key, param in zip(shared_keys, parameters):
            state_dict[key] = torch.tensor(param)
        self.model.load_state_dict(state_dict, strict=False)
 
    def fit(self, parameters, config):
        # Set global parameters on the local model
        self.set_parameters(parameters)
        self.model.train()
 
        # Save global model parameters for proximity
        global_params = [param.clone().detach().to(DEVICE) for param in self.model.parameters()]
        for epoch in range(EPOCHS):
            for images, labels in self.train_loader:
                images, labels = images.to(DEVICE), labels.to(DEVICE)
                self.optimizer.zero_grad()
                outputs = self.model(images)
                loss = self.criterion(outputs, labels)
 
                # Calculate heterogeneity and update lambda
                heterogeneity = compute_statistical_heterogeneity(global_params, self.model.parameters())
                lambda_val = update_lambda(heterogeneity)
 
                # Add proximal term with dynamic lambda
                prox_term = 0.0
                for param, global_param in zip(self.model.parameters(), global_params):
                    prox_term += torch.norm(param - global_param) ** 2
                loss += (lambda_val / 2) * prox_term
 
                loss.backward()
                self.optimizer.step()
 
        # Fine-tune locally for personalization
        for epoch in range(2):  # Fine-tune for 2 local epochs
            for images, labels in self.train_loader:
                images, labels = images.to(DEVICE), labels.to(DEVICE)
                self.optimizer.zero_grad()
                outputs = self.model(images)
                loss = self.criterion(outputs, labels)
                loss.backward()
                self.optimizer.step()
        return self.get_parameters(config), len(self.train_loader.dataset), {}
 
    def evaluate(self, parameters, config):
        self.set_parameters(parameters)
        self.model.eval()
        correct, total = 0, 0
        loss = 0.0
        with torch.no_grad():
            for images, labels in self.test_loader:
                images, labels = images.to(DEVICE), labels.to(DEVICE)
                outputs = self.model(images)
                loss += self.criterion(outputs, labels).item()
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        accuracy = correct / total
        print(f"Client {self.client_id}: Accuracy = {accuracy}")
        return float(loss) / len(self.test_loader), len(self.test_loader.dataset), {"accuracy": accuracy}
 
# Load the datasets
def load_data(dataset_name):
    transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    if dataset_name == "cifar10":
        train_dataset = CIFAR10(root='./data', train=True, download=True, transform=transform)
        test_dataset = CIFAR10(root='./data', train=False, download=True, transform=transform)
    elif dataset_name == "fashion_mnist": 
        transform = transforms.Compose([
            transforms.Grayscale(num_output_channels=3),
            transforms.Resize((32, 32)),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ])
        train_dataset = FashionMNIST(root='./data', train=True, download=True, transform=transform)
        test_dataset = FashionMNIST(root='./data', train=False, download=True, transform=transform)
    elif dataset_name == "svhn":
        train_dataset = SVHN(root='./data', split='train', download=True, transform=transform)
        test_dataset = SVHN(root='./data', split='test', download=True, transform=transform)
    else:
        raise ValueError(f"Unknown dataset: {dataset_name}")
 
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
    return train_loader, test_loader
 
# Updated client_fn to avoid using Context's client_id
def client_fn(cid: str) -> fl.client.Client:
    if cid == "0":
        return cifar_client
    elif cid == "1":
        return fashion_client
    elif cid == "2":
        return svhn_client
 
# Start federated learning
def start_federated_learning():
    global cifar_client, fashion_client, svhn_client
    cifar_train_loader, cifar_test_loader = load_data("cifar10")
    fashion_train_loader, fashion_test_loader = load_data("fashion_mnist")
    svhn_train_loader, svhn_test_loader = load_data("svhn")
 
    cifar_client = FLClient(CNN_Model("cifar10"), cifar_train_loader, cifar_test_loader, "CIFAR-10")
    fashion_client = FLClient(CNN_Model("fashion_mnist"), fashion_train_loader, fashion_test_loader, "Fashion-MNIST")
    svhn_client = FLClient(CNN_Model("svhn"), svhn_train_loader, svhn_test_loader, "SVHN")
 
    fl.simulation.start_simulation(
        client_fn=client_fn,
        num_clients=3,
        client_resources={"num_cpus": 4, "num_gpus": 1},
        config=fl.server.ServerConfig(num_rounds=3),
        strategy=fl.server.strategy.FedAvg(
            fraction_fit=1.0,
            min_fit_clients=3,
            min_available_clients=3,
        )
    )
 
if __name__ == "__main__":
    start_federated_learning()

Files already downloaded and verified
Files already downloaded and verified
Using downloaded and verified file: ./data/train_32x32.mat
Using downloaded and verified file: ./data/test_32x32.mat


[92mINFO [0m:      Starting Flower simulation, config: num_rounds=3, no round_timeout
2024-12-12 21:08:56,137	INFO worker.py:1752 -- Started a local Ray instance.
[92mINFO [0m:      Flower VCE: Ray initialized with resources: {'GPU': 1.0, 'CPU': 12.0, 'node:131.227.65.97': 1.0, 'memory': 10003867239.0, 'node:__internal_head__': 1.0, 'accelerator_type:P4000': 1.0, 'object_store_memory': 5001933619.0}
[92mINFO [0m:      Optimize your simulation with Flower VCE: https://flower.ai/docs/framework/how-to-run-simulations.html
[92mINFO [0m:      Flower VCE: Resources for each Virtual Client: {'num_cpus': 4, 'num_gpus': 1}
[92mINFO [0m:      Flower VCE: Creating VirtualClientEngineActorPool with 1 actors
[92mINFO [0m:      [INIT]
[92mINFO [0m:      Requesting initial parameters from one random client
[36m(ClientAppActor pid=3827597)[0m   return torch.load(io.BytesIO(b))
[36m(ClientAppActor pid=3827597)[0m   return torch.load(io.BytesIO(b))
[92mINFO [0m:      Received initial

[36m(ClientAppActor pid=3827597)[0m Client SVHN: Accuracy = 0.25069145666871545


[36m(ClientAppActor pid=3827597)[0m 
[36m(ClientAppActor pid=3827597)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3827597)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3827597)[0m         


[36m(ClientAppActor pid=3827597)[0m Client CIFAR-10: Accuracy = 0.1723


[36m(ClientAppActor pid=3827597)[0m 
[36m(ClientAppActor pid=3827597)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3827597)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3827597)[0m         
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 2]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)


[36m(ClientAppActor pid=3827597)[0m Client Fashion-MNIST: Accuracy = 0.2473


[36m(ClientAppActor pid=3827597)[0m 
[36m(ClientAppActor pid=3827597)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3827597)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3827597)[0m         
[36m(ClientAppActor pid=3827597)[0m 
[36m(ClientAppActor pid=3827597)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3827597)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3827597)[0m         
[36m(ClientAppActor pid=3827597)[0m 
[36m(ClientAppActor pid=3827597)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3827597)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3827597)[0m         
[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=3827

[36m(ClientAppActor pid=3827597)[0m Client SVHN: Accuracy = 0.3509142593730793


[36m(ClientAppActor pid=3827597)[0m 
[36m(ClientAppActor pid=3827597)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3827597)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3827597)[0m         


[36m(ClientAppActor pid=3827597)[0m Client CIFAR-10: Accuracy = 0.2039


[36m(ClientAppActor pid=3827597)[0m 
[36m(ClientAppActor pid=3827597)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3827597)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3827597)[0m         
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 3]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)


[36m(ClientAppActor pid=3827597)[0m Client Fashion-MNIST: Accuracy = 0.3661


[36m(ClientAppActor pid=3827597)[0m 
[36m(ClientAppActor pid=3827597)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3827597)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3827597)[0m         
[36m(ClientAppActor pid=3827597)[0m 
[36m(ClientAppActor pid=3827597)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3827597)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3827597)[0m         
[36m(ClientAppActor pid=3827597)[0m 
[36m(ClientAppActor pid=3827597)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3827597)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3827597)[0m         
[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=3827

[36m(ClientAppActor pid=3827597)[0m Client Fashion-MNIST: Accuracy = 0.5361


[36m(ClientAppActor pid=3827597)[0m 
[36m(ClientAppActor pid=3827597)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3827597)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3827597)[0m         


[36m(ClientAppActor pid=3827597)[0m Client SVHN: Accuracy = 0.3898279041180086


[36m(ClientAppActor pid=3827597)[0m 
[36m(ClientAppActor pid=3827597)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3827597)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3827597)[0m         
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [SUMMARY]
[92mINFO [0m:      Run finished 3 round(s) in 747.66s
[92mINFO [0m:      	History (loss, distributed):
[92mINFO [0m:      		round 1: 2.296437486170988
[92mINFO [0m:      		round 2: 2.275086521070226
[92mINFO [0m:      		round 3: 2.2498299371511137
[92mINFO [0m:      


[36m(ClientAppActor pid=3827597)[0m Client CIFAR-10: Accuracy = 0.204


In [None]:
import flwr as fl
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10, FashionMNIST, SVHN
 
# Constants
BATCH_SIZE = 128  # Default batch size for proper processing
LEARNING_RATE = 0.001  # Default learning rate for stability
EPOCHS = 10  # Standard epochs for training
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
# Define the model with dataset-specific layers
class CNN_Model(nn.Module):
    def __init__(self, dataset_name):
        super(CNN_Model, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.pool = nn.AvgPool2d(2, 2)
        self.fc_config = nn.Linear(64 * 8 * 8, 128)
 
        # Dataset-specific layers
        if dataset_name == "cifar10":
            self.fc_op = nn.Sequential(
                nn.Linear(128, 64),
                nn.ReLU(),
                nn.Linear(64, 10)
            )
        elif dataset_name == "fashion_mnist":
            self.fc_op = nn.Sequential(
                nn.Linear(128, 32),
                nn.ReLU(),
                nn.Linear(32, 10)
            )
        elif dataset_name == "svhn":
            self.fc_op = nn.Sequential(
                nn.Linear(128, 64),
                nn.ReLU(),
                nn.Linear(64, 10)
            )
 
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool(x)
        x = F.relu(self.conv2(x))
        x = self.pool(x)
        x = x.view(-1, 64 * 8 * 8)
        x = F.relu(self.fc_config(x))
        x = self.fc_op(x)
        return F.softmax(x, dim=1)
 
# Helper functions for heterogeneity and dynamic lambda
def compute_statistical_heterogeneity(global_params, local_params):
    heterogeneity = 0.0
    for global_param, local_param in zip(global_params, local_params):
        heterogeneity += torch.norm(local_param - global_param) ** 2
    return heterogeneity
 
def update_lambda(heterogeneity, alpha=0.1, epsilon=1e-5):
    return alpha / (heterogeneity + epsilon)
 
# Define the Flower client with PFL using FedProx
class FLClient(fl.client.NumPyClient):
    def __init__(self, model, train_loader, test_loader, client_id):
        self.model = model.to(DEVICE)
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = optim.Adam(self.model.parameters(), lr=LEARNING_RATE)
        self.client_id = client_id
 
    def get_parameters(self, config=None):
        params = {
            k: v.cpu().numpy()
            for k, v in self.model.state_dict().items()
            if not k.startswith("fc_op")
        }
        return list(params.values())
 
    def set_parameters(self, parameters):
        state_dict = self.model.state_dict()
        shared_keys = [k for k in state_dict.keys() if not k.startswith("fc_op")]
        for key, param in zip(shared_keys, parameters):
            state_dict[key] = torch.tensor(param)
        self.model.load_state_dict(state_dict, strict=False)
 
    def fit(self, parameters, config):
        # Set global parameters on the local model
        self.set_parameters(parameters)
        self.model.train()
 
        # Save global model parameters for proximity
        global_params = [param.clone().detach().to(DEVICE) for param in self.model.parameters()]
        for epoch in range(EPOCHS):
            for images, labels in self.train_loader:
                images, labels = images.to(DEVICE), labels.to(DEVICE)
                self.optimizer.zero_grad()
                outputs = self.model(images)
                loss = self.criterion(outputs, labels)
 
                # Calculate heterogeneity and update lambda
                heterogeneity = compute_statistical_heterogeneity(global_params, self.model.parameters())
                lambda_val = update_lambda(heterogeneity)
 
                # Add proximal term with dynamic lambda
                prox_term = 0.0
                for param, global_param in zip(self.model.parameters(), global_params):
                    prox_term += torch.norm(param - global_param) ** 2
                loss += (lambda_val / 2) * prox_term
 
                loss.backward()
                self.optimizer.step()
 
        # Fine-tune locally for personalization
        for epoch in range(2):  # Fine-tune for 2 local epochs
            for images, labels in self.train_loader:
                images, labels = images.to(DEVICE), labels.to(DEVICE)
                self.optimizer.zero_grad()
                outputs = self.model(images)
                loss = self.criterion(outputs, labels)
                loss.backward()
                self.optimizer.step()
        return self.get_parameters(config), len(self.train_loader.dataset), {}
 
    def evaluate(self, parameters, config):
        self.set_parameters(parameters)
        self.model.eval()
        correct, total = 0, 0
        loss = 0.0
        with torch.no_grad():
            for images, labels in self.test_loader:
                images, labels = images.to(DEVICE), labels.to(DEVICE)
                outputs = self.model(images)
                loss += self.criterion(outputs, labels).item()
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        accuracy = correct / total
        print(f"Client {self.client_id}: Accuracy = {accuracy}")
        return float(loss) / len(self.test_loader), len(self.test_loader.dataset), {"accuracy": accuracy}
 
# Load the datasets
def load_data(dataset_name):
    transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    if dataset_name == "cifar10":
        train_dataset = CIFAR10(root='./data', train=True, download=True, transform=transform)
        test_dataset = CIFAR10(root='./data', train=False, download=True, transform=transform)
    elif dataset_name == "fashion_mnist": 
        transform = transforms.Compose([
            transforms.Grayscale(num_output_channels=3),
            transforms.Resize((32, 32)),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ])
        train_dataset = FashionMNIST(root='./data', train=True, download=True, transform=transform)
        test_dataset = FashionMNIST(root='./data', train=False, download=True, transform=transform)
    elif dataset_name == "svhn":
        train_dataset = SVHN(root='./data', split='train', download=True, transform=transform)
        test_dataset = SVHN(root='./data', split='test', download=True, transform=transform)
    else:
        raise ValueError(f"Unknown dataset: {dataset_name}")
 
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
    return train_loader, test_loader
 
# Updated client_fn to avoid using Context's client_id
def client_fn(cid: str) -> fl.client.Client:
    if cid == "0":
        return cifar_client
    elif cid == "1":
        return fashion_client
    elif cid == "2":
        return svhn_client
 
# Start federated learning
def start_federated_learning():
    global cifar_client, fashion_client, svhn_client
    cifar_train_loader, cifar_test_loader = load_data("cifar10")
    fashion_train_loader, fashion_test_loader = load_data("fashion_mnist")
    svhn_train_loader, svhn_test_loader = load_data("svhn")
 
    cifar_client = FLClient(CNN_Model("cifar10"), cifar_train_loader, cifar_test_loader, "CIFAR-10")
    fashion_client = FLClient(CNN_Model("fashion_mnist"), fashion_train_loader, fashion_test_loader, "Fashion-MNIST")
    svhn_client = FLClient(CNN_Model("svhn"), svhn_train_loader, svhn_test_loader, "SVHN")
 
    fl.simulation.start_simulation(
        client_fn=client_fn,
        num_clients=3,
        client_resources={"num_cpus": 4, "num_gpus": 1},
        config=fl.server.ServerConfig(num_rounds=10),
        strategy=fl.server.strategy.FedAvg(
            fraction_fit=1.0,
            min_fit_clients=3,
            min_available_clients=3,
        )
    )
 
if __name__ == "__main__":
    start_federated_learning()

Files already downloaded and verified
Files already downloaded and verified
Using downloaded and verified file: ./data/train_32x32.mat
Using downloaded and verified file: ./data/test_32x32.mat


[92mINFO [0m:      Starting Flower simulation, config: num_rounds=10, no round_timeout
2024-12-12 21:35:24,797	INFO worker.py:1752 -- Started a local Ray instance.
[92mINFO [0m:      Flower VCE: Ray initialized with resources: {'GPU': 1.0, 'object_store_memory': 5134589952.0, 'memory': 10269179904.0, 'accelerator_type:P4000': 1.0, 'node:__internal_head__': 1.0, 'CPU': 12.0, 'node:131.227.65.97': 1.0}
[92mINFO [0m:      Optimize your simulation with Flower VCE: https://flower.ai/docs/framework/how-to-run-simulations.html
[92mINFO [0m:      Flower VCE: Resources for each Virtual Client: {'num_cpus': 4, 'num_gpus': 1}
[92mINFO [0m:      Flower VCE: Creating VirtualClientEngineActorPool with 1 actors
[92mINFO [0m:      [INIT]
[92mINFO [0m:      Requesting initial parameters from one random client
[36m(ClientAppActor pid=3832905)[0m   return torch.load(io.BytesIO(b))
[36m(ClientAppActor pid=3832905)[0m   return torch.load(io.BytesIO(b))
[92mINFO [0m:      Received initia

[36m(ClientAppActor pid=3832905)[0m Client Fashion-MNIST: Accuracy = 0.2978


[36m(ClientAppActor pid=3832905)[0m 
[36m(ClientAppActor pid=3832905)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3832905)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3832905)[0m         


[36m(ClientAppActor pid=3832905)[0m Client CIFAR-10: Accuracy = 0.1515


[36m(ClientAppActor pid=3832905)[0m 
[36m(ClientAppActor pid=3832905)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3832905)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3832905)[0m         
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 2]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)


[36m(ClientAppActor pid=3832905)[0m Client SVHN: Accuracy = 0.21876920712968653


[36m(ClientAppActor pid=3832905)[0m 
[36m(ClientAppActor pid=3832905)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3832905)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3832905)[0m         
[36m(ClientAppActor pid=3832905)[0m 
[36m(ClientAppActor pid=3832905)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3832905)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3832905)[0m         
[36m(ClientAppActor pid=3832905)[0m 
[36m(ClientAppActor pid=3832905)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3832905)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3832905)[0m         
[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=3832

[36m(ClientAppActor pid=3832905)[0m Client SVHN: Accuracy = 0.3451521204671174


[36m(ClientAppActor pid=3832905)[0m 
[36m(ClientAppActor pid=3832905)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3832905)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3832905)[0m         


[36m(ClientAppActor pid=3832905)[0m Client Fashion-MNIST: Accuracy = 0.5811


[36m(ClientAppActor pid=3832905)[0m 
[36m(ClientAppActor pid=3832905)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3832905)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3832905)[0m         
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 3]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)


[36m(ClientAppActor pid=3832905)[0m Client CIFAR-10: Accuracy = 0.2003


[36m(ClientAppActor pid=3832905)[0m 
[36m(ClientAppActor pid=3832905)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3832905)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3832905)[0m         
[36m(ClientAppActor pid=3832905)[0m 
[36m(ClientAppActor pid=3832905)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=3832905)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=3832905)[0m         
