<a href="https://colab.research.google.com/github/HumayraFerdous/Hybrid-Models/blob/master/Wine_Quality_FL.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [7]:
from IPython.display import display
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.datasets import load_wine
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
#!pip install opencv-python
#!pip install opacus
#!pip install flwr
#!pip install -U "flwr[simulation]"
import ray
from opacus import PrivacyEngine
import flwr as fl
import numpy as np
from flwr.simulation import start_simulation
from flwr.server.strategy import FedAvg
import os
ray.init(ignore_reinit_error=True)

data = load_wine()
X = StandardScaler().fit_transform(data['data'])
y = data['target']

X = torch.tensor(X, dtype=torch.float32)
y = torch.tensor(y, dtype=torch.long)


client_trainloaders, client_testloaders = [], []
for _ in range(3):
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, stratify=y)
    train_ds = TensorDataset(X_train, y_train)
    test_ds = TensorDataset(X_test, y_test)

    train_dl = DataLoader(train_ds, batch_size=32, shuffle=True)
    test_dl = DataLoader(test_ds, batch_size=32)

    client_trainloaders.append(train_dl)
    client_testloaders.append(test_dl)


class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(13, 64)
        self.fc2 = nn.Linear(64, 3)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        return self.fc2(x)


def train(model, dataloader, epochs=1):
    model.train()
    optimizer = optim.SGD(model.parameters(), lr=0.01)
    criterion = nn.CrossEntropyLoss()

    privacy_engine = PrivacyEngine()
    model, optimizer, dataloader = privacy_engine.make_private(
        module=model,
        optimizer=optimizer,
        data_loader=dataloader,
        noise_multiplier=1.0,
        max_grad_norm=1.0,
    )

    for _ in range(epochs):
        for x_batch, y_batch in dataloader:
            optimizer.zero_grad()
            output = model(x_batch)
            loss = criterion(output, y_batch)
            loss.backward()
            optimizer.step()

def evaluate_model(model, dataloader):
    model.eval()
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for x_batch, y_batch in dataloader:
            outputs = model(x_batch)
            preds = torch.argmax(outputs, dim=1)
            all_preds.extend(preds.numpy())
            all_labels.extend(y_batch.numpy())

    acc = accuracy_score(all_labels, all_preds)
    prec = precision_score(all_labels, all_preds, average='weighted', zero_division=0)
    rec = recall_score(all_labels, all_preds, average='weighted', zero_division=0)
    f1 = f1_score(all_labels, all_preds, average='weighted', zero_division=0)
    return acc, prec, rec, f1


class WineClient(fl.client.NumPyClient):
    def __init__(self, model, trainloader, testloader):
        self.model = model
        self.trainloader = trainloader
        self.testloader = testloader

    def get_parameters(self, config):
        return [val.cpu().numpy() for val in self.model.state_dict().values()]

    def set_parameters(self, parameters):
        params_dict = zip(self.model.state_dict().keys(), parameters)
        state_dict = {k: torch.tensor(v) for k, v in params_dict}
        self.model.load_state_dict(state_dict, strict=True)

    def fit(self, parameters, config):
      self.set_parameters(parameters)
      train(self.model, self.trainloader, epochs=1)
      return self.get_parameters(config), len(self.trainloader.dataset), {}

    def evaluate(self, parameters, config):
      self.set_parameters(parameters)
      acc, prec, rec, f1 = evaluate_model(self.model, self.testloader)
      return float(1 - acc), len(self.testloader.dataset), {
        "accuracy": acc,
        "precision": prec,
        "recall": rec,
        "f1": f1
    }


def client_fn(cid: str):
    cid = int(cid)
    model = MLP()
    return WineClient(model, client_trainloaders[cid], client_testloaders[cid])


def weighted_average(metrics):
    total_samples = sum(num_examples for num_examples, _ in metrics)
    avg_metrics = {}
    for metric in ["accuracy", "precision", "recall", "f1"]:
        weighted_sum = sum(
            num_examples * client_metrics[metric]
            for num_examples, client_metrics in metrics
        )
        avg_metrics[metric] = weighted_sum / total_samples
    print(f"[Server Avg] Accuracy: {avg_metrics['accuracy']:.4f}, "
          f"Precision: {avg_metrics['precision']:.4f}, "
          f"Recall: {avg_metrics['recall']:.4f}, F1: {avg_metrics['f1']:.4f}")
    return avg_metrics


strategy = FedAvg(evaluate_metrics_aggregation_fn=weighted_average)


start_simulation(
    client_fn=client_fn,
    num_clients=3,
    config=fl.server.ServerConfig(num_rounds=5),
    strategy=strategy,
)

2025-05-22 15:49:24,112	INFO worker.py:1604 -- Calling ray.init() again after it has already been called.
	Instead, use the `flwr run` CLI command to start a local simulation in your Flower app, as shown for example below:

		$ flwr new  # Create a new Flower app from a template

		$ flwr run  # Run the Flower app in Simulation Mode

	Using `start_simulation()` is deprecated.

            This is a deprecated feature. It will be removed
            entirely in future versions of Flower.
        
[92mINFO [0m:      Starting Flower simulation, config: num_rounds=5, no round_timeout
2025-05-22 15:49:28,227	INFO worker.py:1771 -- Started a local Ray instance.
[92mINFO [0m:      Flower VCE: Ray initialized with resources: {'CPU': 2.0, 'object_store_memory': 3995584512.0, 'node:172.28.0.12': 1.0, 'memory': 7991169024.0, 'node:__internal_head__': 1.0}
[92mINFO [0m:      Optimize your simulation with Flower VCE: https://flower.ai/docs/framework/how-to-run-simulations.html
[92mINFO [0m:

[Server Avg] Accuracy: 0.1543, Precision: 0.1094, Recall: 0.1543, F1: 0.1280


[36m(ClientAppActor pid=9274)[0m 
[36m(ClientAppActor pid=9274)[0m         
[36m(ClientAppActor pid=9274)[0m 
[36m(ClientAppActor pid=9274)[0m         
[36m(ClientAppActor pid=9274)[0m 
[36m(ClientAppActor pid=9274)[0m         
[36m(ClientAppActor pid=9274)[0m 
[36m(ClientAppActor pid=9274)[0m         
[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 3]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=9274)[0m 
[36m(ClientAppActor pid=9274)[0m         
[36m(ClientAppActor pid=9274)[0m 
[36m(ClientAppActor pid=9274)[0m         
[36m(ClientAppActor pid=9274)[0m 
[36m(ClientAppActor pid=9274)[0m         
[36m(ClientAppActor pid=9274)[0m 
[36m(ClientAppActor pid=9274)[0m         
[36m(ClientAppActor pid=9275)[0m 
[36m(ClientAppActor pid=9275)[0m         
[36m(ClientAppActor pid=9275)[0m 
[36m(ClientAppActor pid=9275)[0m        

[Server Avg] Accuracy: 0.1914, Precision: 0.1799, Recall: 0.1914, F1: 0.1598


[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 4]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=9274)[0m 
[36m(ClientAppActor pid=9274)[0m         
[36m(ClientAppActor pid=9274)[0m 
[36m(ClientAppActor pid=9274)[0m         
[36m(ClientAppActor pid=9274)[0m 
[36m(ClientAppActor pid=9274)[0m         
[36m(ClientAppActor pid=9274)[0m 
[36m(ClientAppActor pid=9274)[0m         
[36m(ClientAppActor pid=9275)[0m 
[36m(ClientAppActor pid=9275)[0m         
[36m(ClientAppActor pid=9275)[0m 
[36m(ClientAppActor pid=9275)[0m         
[36m(ClientAppActor pid=9275)[0m 
[36m(ClientAppActor pid=9275)[0m         
[36m(ClientAppActor pid=9275)[0m 
[36m(ClientAppActor pid=9275)[0m         
[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)
[36m(Cli

[Server Avg] Accuracy: 0.2037, Precision: 0.3479, Recall: 0.2037, F1: 0.1836


[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 5]
[92mINFO [0m:      configure_fit: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=9274)[0m 
[36m(ClientAppActor pid=9274)[0m         
[36m(ClientAppActor pid=9274)[0m 
[36m(ClientAppActor pid=9274)[0m         
[36m(ClientAppActor pid=9275)[0m 
[36m(ClientAppActor pid=9275)[0m         
[36m(ClientAppActor pid=9275)[0m 
[36m(ClientAppActor pid=9275)[0m         
[92mINFO [0m:      aggregate_fit: received 3 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 3)
[36m(ClientAppActor pid=9274)[0m 
[36m(ClientAppActor pid=9274)[0m         
[36m(ClientAppActor pid=9274)[0m 
[36m(ClientAppActor pid=9274)[0m         
[36m(ClientAppActor pid=9275)[0m 
[36m(ClientAppActor pid=9275)[0m         
[36m(ClientAppActor pid=9275)[0m 
[36m(ClientAppActor pid=9275)[0m         
[36m(Cli

[Server Avg] Accuracy: 0.2099, Precision: 0.3494, Recall: 0.2099, F1: 0.1866


[92mINFO [0m:      aggregate_evaluate: received 3 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [SUMMARY]
[92mINFO [0m:      Run finished 5 round(s) in 6.15s
[92mINFO [0m:      	History (loss, distributed):
[92mINFO [0m:      		round 1: 0.845679012345679
[92mINFO [0m:      		round 2: 0.808641975308642
[92mINFO [0m:      		round 3: 0.7962962962962963
[92mINFO [0m:      		round 4: 0.7901234567901234
[92mINFO [0m:      		round 5: 0.7901234567901234
[92mINFO [0m:      	History (metrics, distributed, evaluate):
[92mINFO [0m:      	{'accuracy': [(1, 0.15432098765432098),
[92mINFO [0m:      	              (2, 0.19135802469135801),
[92mINFO [0m:      	              (3, 0.2037037037037037),
[92mINFO [0m:      	              (4, 0.20987654320987653),
[92mINFO [0m:      	              (5, 0.20987654320987653)],
[92mINFO [0m:      	 'f1': [(1, 0.12800290486564994),
[92mINFO [0m:      	        (2, 0.1598147837799254),
[92mINFO [0m:      	       

[Server Avg] Accuracy: 0.2099, Precision: 0.3494, Recall: 0.2099, F1: 0.1866


[36m(ClientAppActor pid=9274)[0m 
[36m(ClientAppActor pid=9274)[0m         
[36m(ClientAppActor pid=9274)[0m 
[36m(ClientAppActor pid=9274)[0m         
[36m(ClientAppActor pid=9275)[0m 
[36m(ClientAppActor pid=9275)[0m         
[36m(ClientAppActor pid=9275)[0m 
[36m(ClientAppActor pid=9275)[0m         
[36m(ClientAppActor pid=9275)[0m 
[36m(ClientAppActor pid=9275)[0m         
[36m(ClientAppActor pid=9275)[0m 
[36m(ClientAppActor pid=9275)[0m         


History (loss, distributed):
	round 1: 0.845679012345679
	round 2: 0.808641975308642
	round 3: 0.7962962962962963
	round 4: 0.7901234567901234
	round 5: 0.7901234567901234
History (metrics, distributed, evaluate):
{'accuracy': [(1, 0.15432098765432098),
              (2, 0.19135802469135801),
              (3, 0.2037037037037037),
              (4, 0.20987654320987653),
              (5, 0.20987654320987653)],
 'f1': [(1, 0.12800290486564994),
        (2, 0.1598147837799254),
        (3, 0.18360366324616575),
        (4, 0.18659302577431233),
        (5, 0.18659302577431233)],
 'precision': [(1, 0.10936568752660707),
               (2, 0.1798631476050831),
               (3, 0.3479406130268199),
               (4, 0.34941367699988385),
               (5, 0.3494136769998839)],
 'recall': [(1, 0.15432098765432098),
            (2, 0.19135802469135801),
            (3, 0.2037037037037037),
            (4, 0.20987654320987653),
            (5, 0.20987654320987653)]}