In [None]:
from typing import Dict, List, Tuple
import tensorflow as tf
import flwr as fl
import numpy as np
from flwr.common import Metrics
from flwr.simulation.ray_transport.utils import enable_tf_gpu_growth
from flwr_datasets import FederatedDataset
from flwr_datasets.partitioner import DirichletPartitioner
from flwr.server.strategy import DPFedAvgFixed
from sklearn.feature_selection import mutual_info_regression
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
import logging
from typing import Optional, Union
import flwr as fl
from flwr.common import Parameters, NDArrays
from flwr.server.client_proxy import ClientProxy
from flwr.common.typing import Config
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from torchvision import transforms, datasets
import numpy as np
from collections import deque
import matplotlib.pyplot as plt
from sklearn.feature_selection import mutual_info_regression
from flwr.common import EvaluateIns, EvaluateRes, FitIns, FitRes, Parameters, Scalar
from flwr.common.dp import add_gaussian_noise
from flwr.common.logger import warn_deprecated_feature
from flwr.common.parameter import ndarrays_to_parameters, parameters_to_ndarrays
from flwr.server.client_manager import ClientManager
from flwr.server.client_proxy import ClientProxy
from flwr.server.strategy import DPFedAvgFixed
from sklearn.feature_selection import mutual_info_regression
from typing import Dict, List, Optional
import random
from flwr.server.client_proxy import ClientProxy
from flwr.server.criterion import Criterion
import threading



# Constants
# --- Federated-learning setup ---
NUM_CLIENTS: int = 10  # number of simulated clients / data partitions
K_CLIENTS: int = 5  # passed to the strategy as min_fit_clients
BATCH_SIZE: int = 16  # per-client DataLoader batch size
NUM_ROUNDS: int = 1000  # federated rounds run by the simulation
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
LEARNING_RATE: float = 0.001  # Adam learning rate used by each client

# Privacy parameters (consumed by AdaptiveLDP)
INITIAL_EPSILON: float = 2.0  # starting epsilon of every AdaptiveLDP instance
MIN_EPSILON: float = 0.1  # lower clip bound for epsilon; also floors it in the noise scale
MAX_EPSILON: float = 5.0  # upper clip bound for epsilon
TARGET_ACCURACY: float = 0.85  # accuracy target compared against in adjust_epsilon
ADJUST_RATE: float = 0.05  # step size of the multiplicative epsilon update
NOISE_CLIP: float = 0.1  # upper bound on the Laplace noise scale
WINDOW_SIZE: int = 3  # length of the rolling accuracy/noise/leakage windows


def write_to_file(filename, data):
    """Append ``data`` to ``filename`` as one line (one value per line).

    The file is opened in append mode (created if missing), and ``data`` is
    formatted exactly as an f-string would format it, followed by a newline.
    """
    with open(filename, mode="a") as handle:
        handle.write("{}\n".format(data))

class Net(nn.Module):
    """Two-stage CNN classifier with 10 output logits for single-channel images.

    Each stage is conv(3x3, padding=1) -> batch norm -> ReLU -> 2x2 max pool ->
    channel dropout. ``fc1`` expects 64 feature maps of size 7x7 after the two
    poolings, i.e. inputs such as 28x28 MNIST digits.
    """

    def __init__(self):
        super(Net, self).__init__()
        # Stage 1: 1 -> 32 channels, spatial size preserved by padding=1.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(num_features=32)
        # Stage 2: 32 -> 64 channels.
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(num_features=64)
        # Channel-wise dropout applied after each stage's pooling.
        self.dropout1 = nn.Dropout2d(p=0.25)
        self.dropout2 = nn.Dropout2d(p=0.5)
        # Classifier head on the flattened 64 x 7 x 7 feature maps.
        self.fc1 = nn.Linear(in_features=64 * 7 * 7, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=10)

    def forward(self, x):
        relu = nn.functional.relu
        pool = nn.functional.max_pool2d
        stages = (
            (self.conv1, self.bn1, self.dropout1),
            (self.conv2, self.bn2, self.dropout2),
        )
        for conv, norm, drop in stages:
            x = drop(pool(relu(norm(conv(x))), 2))

        hidden = relu(self.fc1(x.flatten(start_dim=1)))
        return self.fc2(hidden)


class AdaptiveLDP:
    """Local differential privacy via Laplace input noise, with an adaptive epsilon.

    ``add_noise`` perturbs a tensor with Laplace noise whose scale grows as epsilon
    shrinks (capped at NOISE_CLIP). ``compute_privacy_leakage`` estimates the mutual
    information between clean and noisy values as a leakage proxy.
    ``adjust_epsilon`` rescales epsilon once WINDOW_SIZE readings are available.
    """

    def __init__(self, epsilon=INITIAL_EPSILON):
        self.epsilon = epsilon
        # Rolling windows holding the most recent WINDOW_SIZE readings.
        self.accuracy_history, self.noise_history, self.leakage_history = (
            deque(maxlen=WINDOW_SIZE) for _ in range(3)
        )
        # Noise bookkeeping: mean |noise| of the last call, running sum, call count.
        self.current_noise = 0.0
        self.total_noise_added = 0.0
        self.noise_samples = 0

    def add_noise(self, data: torch.Tensor) -> torch.Tensor:
        """Return ``data`` plus Laplace noise, clamped to [-1, 1].

        The sensitivity is the largest absolute entry of ``data``, capped at 1.0.
        The Laplace scale is sensitivity / max(epsilon, MIN_EPSILON), capped at
        NOISE_CLIP. Noise is drawn from NumPy's global RNG and cast to
        ``data``'s dtype and device.
        """
        sensitivity = min(data.abs().max().item(), 1.0)
        effective_eps = max(self.epsilon, MIN_EPSILON)
        scale = min(sensitivity / effective_eps, NOISE_CLIP)
        sample = np.random.laplace(loc=0, scale=scale, size=data.shape)
        noise = torch.tensor(sample, dtype=data.dtype, device=data.device)

        mean_abs = noise.abs().mean().item()
        self.current_noise = mean_abs
        self.noise_history.append(mean_abs)
        self.total_noise_added += mean_abs
        self.noise_samples += 1
        return (data + noise).clamp(min=-1.0, max=1.0)

    def compute_privacy_leakage(self, original: torch.Tensor, noisy: torch.Tensor) -> float:
        """Estimate leakage as the mutual information between clean and noisy values.

        Both tensors are flattened. When there are 10 or fewer elements, 0.0 is
        returned without calling the estimator.
        """
        clean = original.cpu().numpy().flatten()
        perturbed = noisy.cpu().numpy().flatten()
        if clean.size <= 10:
            return 0.0
        mi = mutual_info_regression(clean.reshape(-1, 1), perturbed)
        return float(mi[0])

    def adjust_epsilon(self, accuracy: float, leakage: float):
        """Record an (accuracy, leakage) reading and rescale epsilon once the window is full.

        The update is epsilon <- clip(epsilon * (1 + ADJUST_RATE * ((TARGET_ACCURACY -
        mean_accuracy) - 0.1 * mean_leakage)), MIN_EPSILON, MAX_EPSILON).
        Accuracy below target raises epsilon, which means less noise; leakage
        lowers it.
        """
        self.accuracy_history.append(accuracy)
        self.leakage_history.append(leakage)
        if len(self.accuracy_history) < WINDOW_SIZE:
            return

        shortfall = TARGET_ACCURACY - np.mean(self.accuracy_history)
        penalty = 0.1 * np.mean(self.leakage_history)
        step = ADJUST_RATE * (shortfall - penalty)
        self.epsilon = np.clip(self.epsilon * (1.0 + step), MIN_EPSILON, MAX_EPSILON)

class PrivacyClient(fl.client.NumPyClient):
    """Flower NumPy client that trains ``model`` on LDP-perturbed input batches.

    Every training batch is perturbed by an ``AdaptiveLDP`` instance before the
    forward pass. After the local epoch, epsilon is adapted from the epoch's
    training accuracy and the mean estimated leakage.

    Model weights are exchanged as the full ``state_dict`` (including BatchNorm
    buffers), in ``state_dict`` order.

    NOTE(review): ``main`` builds a brand-new client (fresh ``Net``, Adam state and
    ``AdaptiveLDP``) on every ``client_fn`` call. If the simulation engine calls
    ``client_fn`` each round, the WINDOW_SIZE accuracy/leakage window never fills
    and epsilon never adapts. Confirm this, and persist the LDP state per ``cid``
    if cross-round adaptation is intended.
    """

    def __init__(self, cid: str, model: nn.Module, train_loader: DataLoader):
        self.cid = cid
        self.model = model.to(DEVICE)
        self.train_loader = train_loader
        self.optimizer = optim.Adam(self.model.parameters(), lr=LEARNING_RATE)
        self.criterion = nn.CrossEntropyLoss()
        self.ldp = AdaptiveLDP()
        self.metrics_history = []

    def get_parameters(self, config: Config) -> NDArrays:
        """Return every ``state_dict`` entry (weights and buffers) as a NumPy array."""
        _ = config
        return [tensor.detach().cpu().numpy() for tensor in self.model.state_dict().values()]

    def set_parameters(self, parameters: NDArrays) -> None:
        """Load ``parameters`` (``state_dict`` order) into the local model.

        An empty list leaves the local model unchanged.

        Raises:
            ValueError: if the number of arrays differs from the model's state_dict.
        """
        if not parameters:
            return
        keys = list(self.model.state_dict().keys())
        if len(parameters) != len(keys):
            raise ValueError(
                f"Expected {len(keys)} arrays for the model state, got {len(parameters)}"
            )
        # np.array copies, so each tensor owns writable memory. load_state_dict then
        # copies (and casts if needed) into the existing on-device tensors.
        state_dict = {
            key: torch.from_numpy(np.array(value)) for key, value in zip(keys, parameters)
        }
        self.model.load_state_dict(state_dict, strict=True)

    def fit(self, parameters: NDArrays, config: Config):
        """Run one local epoch on LDP-perturbed data, starting from the global model.

        Args:
            parameters: global model arrays in ``state_dict`` order.
            config: server-provided fit configuration (unused).

        Returns:
            ``(updated arrays, number of training examples, metrics)``. The metrics
            hold "loss", "accuracy", "epsilon" and "leakage".

        Raises:
            ValueError: if the training loader yields no samples.
        """
        _ = config
        # Continue from the server's aggregated model rather than local-only weights.
        self.set_parameters(parameters)
        self.model.train()
        total_loss, correct, total = 0.0, 0, 0
        privacy_leakage_values = []

        for data, target in self.train_loader:
            data, target = data.to(DEVICE), target.to(DEVICE)
            # Perturb inputs locally before they reach the model.
            noisy_data = self.ldp.add_noise(data)
            privacy_leakage_values.append(self.ldp.compute_privacy_leakage(data, noisy_data))

            self.optimizer.zero_grad()
            output = self.model(noisy_data)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()

            total_loss += loss.item()
            correct += (output.argmax(dim=1) == target).sum().item()
            total += target.size(0)

        if total == 0:
            raise ValueError(f"Client {self.cid} has no training samples")

        accuracy = correct / total
        avg_loss = total_loss / len(self.train_loader)  # mean of per-batch losses
        avg_leakage = float(np.mean(privacy_leakage_values))
        self.ldp.adjust_epsilon(accuracy, avg_leakage)

        metrics = {
            "loss": avg_loss,
            "accuracy": accuracy,
            "epsilon": float(self.ldp.epsilon),
            "leakage": avg_leakage,
        }
        self.metrics_history.append(metrics)
        return self.get_parameters({}), total, metrics

class FedAvg_Privacy(fl.server.strategy.FedAvg):
    """FedAvg strategy that also aggregates the clients' training and privacy metrics.

    After each successful fit round, it computes the example-weighted loss and
    accuracy and the unweighted mean epsilon and leakage. These are appended to
    ``loss.txt``, ``accuracy.txt``, ``epsilon.txt`` and ``leakage.txt``, recorded
    in ``metrics_history``, and returned as the round's fit metrics.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.metrics_history = []

    def aggregate_fit(self, server_round, results, failures):
        """Aggregate parameters with FedAvg, then compute and log per-round client metrics.

        Returns:
            ``(parameters, metrics)``. ``metrics`` is empty when nothing was
            aggregated (no results, or FedAvg declined to aggregate) or when the
            clients reported zero examples.
        """
        aggregated = super().aggregate_fit(server_round, results, failures)
        if aggregated is None:
            return None, {}
        parameters, _ = aggregated
        # FedAvg reports "nothing to aggregate" as (None, {}). Computing metrics
        # then would divide by zero.
        if parameters is None or not results:
            return parameters, {}

        total_samples = sum(fit_res.num_examples for _, fit_res in results)
        if total_samples == 0:
            return parameters, {}

        def weighted_mean(key):
            """Average a client metric weighted by each client's example count."""
            return sum(
                fit_res.metrics[key] * fit_res.num_examples for _, fit_res in results
            ) / total_samples

        metrics = {
            "round": server_round,
            "loss": weighted_mean("loss"),
            "accuracy": weighted_mean("accuracy"),
            "epsilon": float(np.mean([fit_res.metrics["epsilon"] for _, fit_res in results])),
            "leakage": float(np.mean([fit_res.metrics["leakage"] for _, fit_res in results])),
        }
        for key in ("loss", "accuracy", "epsilon", "leakage"):
            write_to_file(f"{key}.txt", metrics[key])
        self.metrics_history.append(metrics)
        return parameters, metrics

def plot_training_metrics(history):
    """Show accuracy, epsilon and privacy leakage per recorded round, side by side.

    Args:
        history: sequence of dicts, each with "accuracy", "epsilon" and "leakage" keys.
    """
    panels = (
        ("accuracy", "Accuracy"),
        ("epsilon", "Epsilon"),
        ("leakage", "Privacy Leakage"),
    )
    plt.figure(figsize=(15, 5))
    for position, (key, title) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.plot([entry[key] for entry in history])
        plt.title(title)
    plt.show()

def split_mnist_dirichlet_flwr(num_clients=NUM_CLIENTS, alpha=0.5, seed=42):
    """Split the MNIST "train" split across clients with a label-based Dirichlet partitioner.

    Args:
        num_clients: number of partitions.
        alpha: Dirichlet concentration; smaller values give more skewed label mixes.
        seed: partitioner seed.

    Returns:
        ``(fds, federated_data)``: the FederatedDataset and a dict mapping
        ``"client_<i>"`` to partition ``i``.
    """
    fds = FederatedDataset(
        dataset="mnist",
        partitioners={
            "train": DirichletPartitioner(
                num_partitions=num_clients,
                partition_by="label",
                alpha=alpha,
                seed=seed,
            )
        },
    )
    federated_data = {}
    for idx in range(num_clients):
        federated_data[f"client_{idx}"] = fds.load_partition(idx)
    # Return both the dataset handle and the loaded partitions.
    return fds, federated_data

def load_data(num_clients: int):
    """Build a shuffled training DataLoader for each client's MNIST partition.

    Args:
        num_clients: number of Dirichlet partitions to create and load.

    Returns:
        dict mapping client index ``i`` (``0 .. num_clients - 1``) to a
        ``DataLoader`` over ``(image, label)`` pairs with batch size BATCH_SIZE.
        Images are float tensors scaled to [0, 1] with a channel axis inserted at
        dim 1; labels are int64 class indices.
    """
    _, federated_data = split_mnist_dirichlet_flwr(num_clients)
    client_loaders = {}

    for i in range(num_clients):
        partition = federated_data[f"client_{i}"]
        # A partition is a dataset of rows, so tuple-unpacking it would iterate
        # rows. Select the feature columns by name instead.
        # NOTE(review): assumes the Hugging Face "mnist" schema ("image" holding
        # grayscale images, "label" holding ints); confirm if the dataset changes.
        images = np.stack([np.asarray(img, dtype=np.uint8) for img in partition["image"]])
        images = torch.from_numpy(images).unsqueeze(1).float() / 255.0
        labels = torch.tensor(partition["label"], dtype=torch.long)
        # Inputs are only scaled to [0, 1], with no mean/std standardization.
        # AdaptiveLDP.add_noise clamps perturbed inputs to [-1, 1].

        dataset = TensorDataset(images, labels)
        client_loaders[i] = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

    return client_loaders

class SimpleClientManager(ClientManager):
    """In-memory client registry whose ``sample`` draws a reproducible subset each round.

    Sampling uses ``random.Random(self.seed)`` over the sorted client ids, and
    ``self.seed`` is incremented after every draw. Successive rounds therefore get
    different client subsets, but the sequence is the same on every run.
    """

    def __init__(self) -> None:
        self.clients: Dict[str, ClientProxy] = {}
        # Condition variable used to wake wait_for() when clients (un)register.
        self._cv = threading.Condition()
        # Seed for the next sample() call, so the clients selected per round are reproducible.
        self.seed = 0

    def __len__(self) -> int:
        """Return the number of registered clients."""
        return len(self.clients)

    def num_available(self) -> int:
        """Return the number of registered clients (all are treated as available)."""
        return len(self)

    def wait_for(self, num_clients: int, timeout: int = 86400) -> bool:
        """Block until at least ``num_clients`` are registered or ``timeout`` seconds pass.

        Returns:
            True if enough clients registered, False on timeout.
        """
        with self._cv:
            return self._cv.wait_for(
                lambda: len(self.clients) >= num_clients, timeout=timeout
            )

    def register(self, client: ClientProxy) -> bool:
        """Add ``client`` and wake waiters; return False if its cid is already registered."""
        if client.cid in self.clients:
            return False

        self.clients[client.cid] = client
        with self._cv:
            self._cv.notify_all()

        return True

    def unregister(self, client: ClientProxy) -> None:
        """Remove ``client`` if it is registered, and wake waiters."""
        if client.cid in self.clients:
            del self.clients[client.cid]

            with self._cv:
                self._cv.notify_all()

    def all(self) -> Dict[str, ClientProxy]:
        """Return the live cid -> ClientProxy mapping (not a copy)."""
        return self.clients

    def sample(
        self,
        num_clients: int,
        min_num_clients: Optional[int] = None,
        criterion: Optional[Criterion] = None,
    ) -> List[ClientProxy]:
        """Sample ``num_clients`` distinct clients, reproducibly per call.

        Args:
            num_clients: number of clients to return.
            min_num_clients: block until this many are registered (default: num_clients).
            criterion: optional filter; only clients for which ``criterion.select``
                returns True are eligible.

        Returns:
            The sampled clients, or an empty list if fewer than ``num_clients``
            are eligible.
        """
        if min_num_clients is None:
            min_num_clients = num_clients
        self.wait_for(min_num_clients)

        # Sort so the draw does not depend on registration order.
        available_cids = sorted(self.clients)
        if criterion is not None:
            available_cids = [
                cid for cid in available_cids if criterion.select(self.clients[cid])
            ]

        if num_clients > len(available_cids):
            logging.getLogger(__name__).info(
                "Sampling failed: %s eligible clients, %s requested.",
                len(available_cids),
                num_clients,
            )
            return []

        rng = random.Random(self.seed)
        self.seed += 1
        sampled_cids = rng.sample(available_cids, num_clients)
        return [self.clients[cid] for cid in sampled_cids]

def main():
    """Run the federated LDP simulation, then plot the per-round server metrics."""
    client_data = load_data(NUM_CLIENTS)
    client_manager = SimpleClientManager()

    def client_fn(cid: str):
        # Assumes cid is the string form of a partition index (0 .. NUM_CLIENTS - 1).
        return PrivacyClient(cid, Net(), client_data[int(cid)])

    # NOTE(review): fraction_fit keeps FedAvg's default (1.0), so every available
    # client trains each round and K_CLIENTS only acts as the minimum. Set
    # fraction_fit=K_CLIENTS / NUM_CLIENTS if K clients per round is intended.
    strategy = FedAvg_Privacy(
        min_available_clients=NUM_CLIENTS,
        min_fit_clients=K_CLIENTS,
        # PrivacyClient implements no evaluate(), so skip federated evaluation rounds.
        fraction_evaluate=0.0,
    )
    # The client manager is handed to the simulation/server; FedAvg's constructor
    # does not accept a ``client_manager`` argument.
    fl.simulation.start_simulation(
        client_fn=client_fn,
        num_clients=NUM_CLIENTS,
        config=fl.server.ServerConfig(num_rounds=NUM_ROUNDS),
        strategy=strategy,
        client_manager=client_manager,
    )
    if hasattr(strategy, "metrics_history"):
        plot_training_metrics(strategy.metrics_history)


if __name__ == "__main__":
    main()
