## Cài đặt môi trường

In [None]:
!pip install -q flwr[simulation]
!pip install -q flwr-datasets[vision]

In [None]:
from collections import OrderedDict
from typing import Dict, List, Optional, Tuple
import random

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader


import flwr
from flwr.client import Client, ClientApp, NumPyClient
from flwr.common import Context
from flwr.server import ServerApp, ServerConfig, ServerAppComponents
from flwr.server.strategy import Strategy
from flwr.simulation import run_simulation
from flwr_datasets import FederatedDataset

from flwr_datasets import FederatedDataset
from flwr_datasets.visualization import plot_label_distributions

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, models
from torch.utils.data import DataLoader
from flwr_datasets import FederatedDataset
from flwr_datasets.partitioner import DirichletPartitioner


# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Training on {DEVICE}")
print(f"Flower {flwr.__version__} / PyTorch {torch.__version__}")

## Base config

In [None]:
import os

# Directory passed as ``cache_dir`` when the federated dataset is created.
CACHE_DIR = "/kaggle/working/hf_cache"
os.makedirs(CACHE_DIR, exist_ok=True)

# NOTE(review): "0" presumably keeps the datasets library in online mode — confirm.
os.environ["HF_DATASETS_OFFLINE"] = "0"

In [None]:
# Resources requested per simulated client: one CPU always, plus one GPU
# when CUDA is available.
backend_config = {"client_resources": {"num_cpus": 1}}
if DEVICE.type == "cuda":
    backend_config["client_resources"]["num_gpus"] = 1

# Federation size and local training schedule
NUM_PARTITIONS = 20
NUM_ROUND = 1
EPOCH = 2
BATCH_SIZE = 20

# Client sampling settings handed to the strategy
FRACTION_FIT = 0.85
FRACTION_EVALUATE = 0.3
MIN_FIT_CLIENTS = 7
MIN_EVALUATE_CLIENTS = 10
MIN_AVAILABLE_CLIENTS = 10

# Proximal coefficient (noted by the author as the best mu value)
PROX_MU = 0.01

# Dirichlet concentration used when partitioning the data
ALPHA = 0.1

# Learning rates
CLIENT_LR = 10 ** -2.5
SERVER_LR = 10 ** -2


## Load Dataset

In [None]:
def set_seed(seed: int = 42):
    """Seed Python, NumPy and PyTorch RNGs and switch cuDNN to deterministic mode.

    Args:
        seed: Value passed to every random number generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_rng in seeders:
        seed_rng(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


set_seed(42)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, models
from torch.utils.data import DataLoader
from flwr_datasets import FederatedDataset
from flwr_datasets.partitioner import DirichletPartitioner

# Build the CIFAR-10 data loaders (transforms sized for the ResNet18 model)
def load_datasets(partition_id, num_partitions: int):
    """Create the train/validation/test loaders for one federated partition.

    The CIFAR-10 ``train`` split is divided with a Dirichlet partitioner
    (``alpha=ALPHA``, partitioned by ``"label"``) to obtain non-IID clients.
    The selected partition is split 80/20 (seed 42) into a local training and
    a local validation set. The dataset's ``test`` split backs the third
    loader.

    Args:
        partition_id: Index of the partition to load.
        num_partitions: Total number of partitions the train split is cut into.

    Returns:
        Tuple ``(trainloader, valloader, testloader)``.
    """
    # Non-IID partitioning. The partition count now honours the argument
    # instead of the module constant.
    # NOTE(review): min_partition_size reuses the partition count, as in the
    # original code — confirm this is the intended minimum size.
    partitioner = DirichletPartitioner(
        num_partitions=num_partitions,
        partition_by="label",
        alpha=ALPHA,
        min_partition_size=num_partitions,
        self_balancing=True,
    )

    fds = FederatedDataset(
        dataset="cifar10",
        partitioners={"train": partitioner},
        cache_dir=CACHE_DIR,
    )

    partition = fds.load_partition(partition_id)
    partition_train_test = partition.train_test_split(test_size=0.2, seed=42)

    # ImageNet mean/std, shared by the training and evaluation pipelines
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    )

    # Training pipeline: 112x112 resize (cheaper than full size) plus light
    # augmentation
    train_transforms = transforms.Compose([
        transforms.Resize((112, 112)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(degrees=5),
        transforms.ColorJitter(brightness=0.1, contrast=0.1),
        transforms.ToTensor(),
        normalize,
    ])

    # Evaluation pipeline: deterministic, no augmentation
    val_transforms = transforms.Compose([
        transforms.Resize((112, 112)),
        transforms.ToTensor(),
        normalize,
    ])

    def _batch_transform(pipeline, stage):
        """Return a batch-level callable applying ``pipeline`` to every image."""
        def _apply(batch):
            try:
                batch["img"] = [pipeline(img.convert("RGB")) for img in batch["img"]]
                return batch
            except Exception as e:
                # Report which pipeline failed, then propagate the error
                print(f"Error in {stage} transforms: {e}")
                raise
        return _apply

    apply_train_transforms = _batch_transform(train_transforms, "train")
    apply_val_transforms = _batch_transform(val_transforms, "val")

    partition_train_test["train"] = partition_train_test["train"].with_transform(apply_train_transforms)
    partition_train_test["test"] = partition_train_test["test"].with_transform(apply_val_transforms)

    # Loader settings shared by all three loaders. Workers stay at 0 to avoid
    # multiprocessing problems, which also requires persistent_workers=False.
    loader_kwargs = dict(
        batch_size=BATCH_SIZE,
        pin_memory=torch.cuda.is_available(),
        num_workers=0,
        persistent_workers=False,
    )

    trainloader = DataLoader(
        partition_train_test["train"],
        shuffle=True,
        drop_last=True,
        **loader_kwargs,
    )

    valloader = DataLoader(
        partition_train_test["test"],
        shuffle=False,
        drop_last=False,
        **loader_kwargs,
    )

    testset = fds.load_split("test").with_transform(apply_val_transforms)
    testloader = DataLoader(
        testset,
        shuffle=False,
        drop_last=False,
        **loader_kwargs,
    )

    return trainloader, valloader, testloader


In [None]:
from huggingface_hub import login

# Authenticate with the Hugging Face Hub only when a token is provided through
# the HF_TOKEN environment variable, so that no credential (or empty
# placeholder) is hard-coded in the notebook.
# NOTE(review): the dataset used here is presumably public, so running without
# a login should still work — confirm.
import os

_hf_token = os.environ.get("HF_TOKEN", "").strip()
if _hf_token:
    login(token=_hf_token)

# Build the (train, val, test) loaders of every partition once, up front
all_data = [load_datasets(i, NUM_PARTITIONS) for i in range(NUM_PARTITIONS)]

In [None]:
from flwr_datasets import FederatedDataset
from flwr_datasets.visualization import plot_label_distributions
import matplotlib.pyplot as plt

# Build the Dirichlet partitioner
partitioner = DirichletPartitioner(
    num_partitions=NUM_PARTITIONS,
    partition_by="label",
    alpha=ALPHA,
    min_partition_size=NUM_PARTITIONS,
    self_balancing=True
)

# Build the FederatedDataset
fds = FederatedDataset(
    dataset="cifar10",
    partitioners={"train": partitioner},
    cache_dir=CACHE_DIR
)

# Visualise the label distribution of every partition
print(f"Visualizing Non-IID distribution with alpha={ALPHA}")
partitioner = fds.partitioners["train"]
figure, axis, dataframe = plot_label_distributions(
    partitioner=partitioner,
    label_name="label",
    legend=True,
    verbose_labels=True,
    plot_type="bar",
    size_unit="absolute"  # or "percentage"
)

# Customise the plot
plt.title(f'CIFAR-10 Non-IID Distribution (α={ALPHA})')
plt.xlabel('Clients')
plt.ylabel('Number of Samples')
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()

# Print summary statistics
print("\n=== Distribution Statistics ===")
print(f"Total clients: {NUM_PARTITIONS}")
print(f"Alpha (non-IID level): {ALPHA}")
print(f"Dataset: CIFAR-10")
print(f"Classes: {dataframe.columns.tolist()}")

# Detailed statistics; summing along axis=1 yields one total per dataframe row
# (presumably one row per client — verify against plot_label_distributions)
total_samples = dataframe.sum().sum()
print(f"Total samples: {total_samples}")
# The label had lost its leading word; the value printed is an average
print(f"Avg samples per client: {total_samples/NUM_PARTITIONS:.1f}")
print(f"Min samples: {dataframe.sum(axis=1).min()}")
print(f"Max samples: {dataframe.sum(axis=1).max()}")

In [None]:
import numpy as np
from scipy.stats import entropy

# Function to calculate entropy for a list of counts
def calculate_entropy(counts):
    """Return the Shannon entropy, in bits, of a list of label counts.

    Args:
        counts: Occurrence count of every label.

    Returns:
        ``0`` when the counts sum to zero, otherwise the base-2 entropy of
        the normalised counts.
    """
    count_sum = sum(counts)
    if count_sum != 0:
        shares = [value / count_sum for value in counts]
        return entropy(shares, base=2)
    return 0

# Label-distribution entropy of every client (one dataframe row per client)
entropies = [
    calculate_entropy(dataframe.iloc[client_id].tolist())
    for client_id in range(NUM_PARTITIONS)
]

average_entropy = np.mean(entropies)
std_entropy = np.std(entropies)

print(f"Average Entropy: {average_entropy:.4f}")
print(f"Entropy Std: {std_entropy:.4f}")


## Model ResNet18

In [None]:
import torch.nn as nn
from torchvision import models

class ResNet18_GroupNorm(nn.Module):
    """ResNet18 backbone with GroupNorm layers and a small MLP classifier head.

    Every BatchNorm layer of the torchvision ResNet18 is swapped for a
    two-group GroupNorm, the original ``fc`` layer is replaced by an identity,
    and a three-layer classifier with dropout is stacked on top.
    """

    def __init__(self, num_classes: int = 10, pretrained: bool = True):
        """Build the network.

        Args:
            num_classes: Size of the final linear layer.
            pretrained: Load the ImageNet (IMAGENET1K_V1) weights when True.
        """
        super().__init__()

        # Optionally start from the pretrained weights
        if pretrained:
            weights = models.ResNet18_Weights.IMAGENET1K_V1
        else:
            weights = None
        self.backbone = models.resnet18(weights=weights)

        # Swap every BatchNorm for GroupNorm (2 groups per layer)
        self._replace_batchnorm_with_groupnorm(self.backbone)

        # Remember the feature width, then drop the original fc layer
        feat_dim = self.backbone.fc.in_features
        self.backbone.fc = nn.Identity()

        # Classifier head; dropout gets lighter towards the output
        head_layers = [
            nn.Dropout(0.3),
            nn.Linear(feat_dim, 128),
            nn.GroupNorm(2, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
            nn.Linear(128, 64),
            nn.GroupNorm(2, 64),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1),
            nn.Linear(64, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

        self._init_weights()

    def _replace_batchnorm_with_groupnorm(self, module):
        """Recursively replace BatchNorm1d/2d children with 2-group GroupNorm."""
        for name, child in module.named_children():
            if isinstance(child, (nn.BatchNorm2d, nn.BatchNorm1d)):
                setattr(module, name, nn.GroupNorm(2, child.num_features))
                continue
            # Any other module: descend into its children
            self._replace_batchnorm_with_groupnorm(child)

    def _init_weights(self):
        """Kaiming-initialise the classifier's linear layers and zero their biases."""
        linear_layers = (m for m in self.classifier if isinstance(m, nn.Linear))
        for layer in linear_layers:
            nn.init.kaiming_normal_(layer.weight, nonlinearity='relu')
            if layer.bias is not None:
                nn.init.constant_(layer.bias, 0)

    def forward(self, x):
        """Run the backbone, flatten its output and apply the classifier."""
        features = torch.flatten(self.backbone(x), 1)
        return self.classifier(features)


import torch
import copy

def train_fedprox(net, trainloader, epochs: int, global_params=None, mu=PROX_MU, lr=0.01, weight_decay=1e-3):
    """Train a local model with the FedProx proximal term.

    Each batch minimises ``cross_entropy + (mu / 2) * ||w - w_global||^2``
    with SGD. The proximal term is skipped when ``global_params`` is ``None``
    or ``mu`` is not positive.

    Args:
        net: Local model (``nn.Module``); trained in place.
        trainloader: Iterable of batches exposing ``"img"`` and ``"label"``.
        epochs: Number of local epochs.
        global_params: Tensors of the global model, in the same order as
            ``net.parameters()``.
        mu: Proximal coefficient.
        lr: SGD learning rate.
        weight_decay: L2 regularisation of the optimizer.

    Returns:
        ``(epoch_loss, epoch_acc)`` of the last epoch; both are ``0.0`` when
        no batch was processed (``epochs == 0`` or an empty loader).

    Raises:
        ValueError: If ``global_params`` does not hold one tensor per model
            parameter.
    """
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(net.parameters(), lr=lr, weight_decay=weight_decay)
    net.train()

    device = next(net.parameters()).device

    use_prox = global_params is not None and mu > 0
    local_params, anchors = [], []
    if use_prox:
        local_params = list(net.parameters())
        if len(global_params) != len(local_params):
            # zip() would silently pair the wrong tensors otherwise
            raise ValueError(
                f"global_params has {len(global_params)} tensors but the model "
                f"has {len(local_params)} parameters"
            )
        # Detached copies on the model's device, so the anchor is never
        # touched by backward() or by the optimizer
        anchors = [g.detach().clone().to(device) for g in global_params]

    # Defined up front so epochs == 0 still returns a value
    epoch_loss, epoch_acc = 0.0, 0.0

    for epoch in range(epochs):
        correct, total, running_loss = 0, 0, 0.0
        for batch in trainloader:
            images, labels = batch["img"].to(device), batch["label"].to(device)
            optimizer.zero_grad()
            outputs = net(images)
            loss = criterion(outputs, labels)

            # Proximal term: squared L2 distance to the global weights
            if use_prox:
                prox_term = sum(
                    (param - anchor).pow(2).sum()
                    for param, anchor in zip(local_params, anchors)
                )
                loss = loss + (mu / 2) * prox_term

            loss.backward()
            optimizer.step()

            # The reported loss includes the proximal term
            running_loss += loss.item() * labels.size(0)
            total += labels.size(0)
            correct += (outputs.argmax(dim=1) == labels).sum().item()

        if total == 0:
            # E.g. drop_last=True on a partition smaller than one batch
            print(f"Epoch {epoch+1}: no training batches available, skipping")
            epoch_loss, epoch_acc = 0.0, 0.0
            continue

        epoch_loss = running_loss / total
        epoch_acc = correct / total
        print(f"Epoch {epoch+1}: train loss {epoch_loss:.4f}, accuracy {epoch_acc:.4f}")

    return epoch_loss, epoch_acc

def test(net, testloader):
    """Evaluate the model on a loader and return ``(loss, accuracy)``.

    Args:
        net: Model to evaluate; switched to eval mode.
        testloader: Iterable of batches exposing ``"img"`` and ``"label"``.

    Returns:
        Sample-weighted average cross-entropy loss and accuracy, or
        ``(0.0, 0.0)`` when the loader yields no samples.
    """
    criterion = torch.nn.CrossEntropyLoss()

    # Evaluate on the device the model lives on; fall back to the global
    # DEVICE for a model without parameters
    first_param = next(net.parameters(), None)
    device = first_param.device if first_param is not None else DEVICE

    correct, total, total_loss = 0, 0, 0.0
    net.eval()
    with torch.no_grad():
        for batch in testloader:
            images, labels = batch["img"].to(device), batch["label"].to(device)
            outputs = net(images)
            loss = criterion(outputs, labels)
            total_loss += loss.item() * labels.size(0)
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += labels.size(0)

    if total == 0:
        # Nothing to evaluate: avoid dividing by zero
        return 0.0, 0.0

    avg_loss = total_loss / total
    accuracy = correct / total
    return avg_loss, accuracy


In [None]:
def get_parameters(net) -> List[np.ndarray]:
    """Return every ``state_dict`` entry of ``net`` as a NumPy array.

    Entries are returned in ``state_dict`` order and none is skipped, so the
    list stays positionally aligned with ``net.state_dict().keys()`` (empty
    tensors included).

    Args:
        net: Model whose state is exported.

    Returns:
        One array per ``state_dict`` entry.
    """
    return [tensor.detach().cpu().numpy() for tensor in net.state_dict().values()]

def set_parameters(net, parameters: List[np.ndarray]):
    """Load a list of arrays into ``net``, matched to ``state_dict`` by position.

    Args:
        net: Model to update in place.
        parameters: One array per ``state_dict`` entry, in ``state_dict``
            order. An empty list leaves the model untouched.

    Raises:
        ValueError: If the number of arrays differs from the number of
            ``state_dict`` entries. Continuing would silently train or
            evaluate on stale weights.
    """
    if not parameters:
        print("Warning: No parameters received")
        return

    state_dict_keys = list(net.state_dict().keys())
    if len(parameters) != len(state_dict_keys):
        raise ValueError(
            f"Parameter count mismatch: model has {len(state_dict_keys)} "
            f"state entries, received {len(parameters)}"
        )

    state_dict = OrderedDict(
        (key, torch.tensor(value)) for key, value in zip(state_dict_keys, parameters)
    )
    # Keys are taken from the model itself, so strict loading is safe
    net.load_state_dict(state_dict, strict=True)

## Custom FedProx

In [None]:
class FlowerClientFedProx(NumPyClient):
    """Flower NumPy client that trains locally with the FedProx proximal term."""

    def __init__(self, partition_id, net, trainloader, valloader):
        """Store the client's partition id, model and data loaders."""
        self.partition_id = partition_id
        self.net = net
        self.trainloader = trainloader
        self.valloader = valloader

    def get_parameters(self, config):
        """Return the current local model state as NumPy arrays."""
        return get_parameters(self.net)

    def fit(self, parameters, config):
        """Load the global weights, train locally and return the update.

        Args:
            parameters: Global model state as a list of arrays.
            config: Round configuration; ``"proximal_mu"`` overrides the
                default proximal coefficient.

        Returns:
            Updated arrays, the local training-set size and a metrics dict
            with ``"loss"`` and ``"accuracy"``.
        """
        print(f"[Client {self.partition_id}] FedProx fit")

        set_parameters(self.net, parameters)

        # Snapshot the freshly loaded global weights as the proximal anchor.
        # Taking it from net.parameters() guarantees it lines up with the
        # tensors train_fedprox iterates over, and lives on the same device.
        global_params = [p.detach().clone() for p in self.net.parameters()]

        # Proximal coefficient from the server, else the notebook default
        mu = config.get("proximal_mu", PROX_MU)

        # Train with the configured client learning rate, which is also the
        # value recorded in the saved run configuration
        loss, acc = train_fedprox(
            self.net,
            self.trainloader,
            epochs=EPOCH,
            global_params=global_params,
            mu=mu,
            lr=CLIENT_LR,
        )

        return get_parameters(self.net), len(self.trainloader.dataset), {
            "loss": float(loss), "accuracy": float(acc)
        }

    def evaluate(self, parameters, config):
        """Load the given weights and evaluate them on the local validation set.

        Returns:
            Loss, the validation-set size and a metrics dict with ``"loss"``
            and ``"accuracy"``.
        """
        set_parameters(self.net, parameters)
        loss, accuracy = test(self.net, self.valloader)
        return float(loss), len(self.valloader.dataset), {
            "loss": float(loss), "accuracy": float(accuracy)
        }


In [None]:
def client_fn_fedprox(context):
    """Build the FedProx client for the partition named in ``context``.

    Args:
        context: Flower context; ``node_config["partition-id"]`` selects the
            entry of ``all_data`` to use.

    Returns:
        The client produced by ``FlowerClientFedProx(...).to_client()``.
    """
    partition_id = context.node_config["partition-id"]
    train_loader, val_loader, _test_loader = all_data[partition_id]

    model = ResNet18_GroupNorm(num_classes=10).to(DEVICE)
    numpy_client = FlowerClientFedProx(partition_id, model, train_loader, val_loader)
    return numpy_client.to_client()


client_app_fedprox = ClientApp(client_fn=client_fn_fedprox)

In [None]:
from flwr.common import Metrics, Context
from flwr.common import ndarrays_to_parameters

def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
    """Aggregate client accuracies, weighted by each client's example count.

    Args:
        metrics: ``(num_examples, metrics_dict)`` pairs; every dict must hold
            an ``"accuracy"`` entry.

    Returns:
        ``{"accuracy": weighted_mean}``; the mean is ``0.0`` when there are no
        examples at all (empty input or only zero-sized clients).
    """
    total_examples = sum(num_examples for num_examples, _ in metrics)
    if total_examples == 0:
        # Nothing to average: avoid dividing by zero
        return {"accuracy": 0.0}

    weighted_sum = sum(num_examples * m["accuracy"] for num_examples, m in metrics)
    return {"accuracy": weighted_sum / total_examples}

# Server-side model instance; its initial state seeds the federation
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
net = ResNet18_GroupNorm().to(DEVICE)
initial_parameters = ndarrays_to_parameters(get_parameters(net))

In [None]:
from flwr.server.strategy import FedProx
import time

from flwr.server.strategy import FedProx
import time

class CustomFedProx(FedProx):
    """FedProx strategy that also records per-round metrics and timings.

    ``metrics_centralized[i]`` holds the metrics of round ``i + 1``:
    ``train_loss``/``train_acc`` (unweighted mean over the clients' fit
    metrics), ``val_loss``/``val_acc`` (unweighted mean over the clients'
    evaluate metrics), ``eval_duration`` and ``round_total_time``.
    """

    def __init__(self, proximal_mu=0.1, *args, **kwargs):
        """Forward all arguments to ``FedProx`` and create the metric stores."""
        super().__init__(proximal_mu=proximal_mu, *args, **kwargs)
        self.metrics_centralized = []
        self.round_start_times = {}  # round number -> wall-clock start time

    @staticmethod
    def _mean_metric(results, key):
        """Return the unweighted mean of ``key`` over the results, or None."""
        values = [
            res.metrics[key]
            for _, res in results
            if res.metrics and res.metrics.get(key) is not None
        ]
        return sum(values) / len(values) if values else None

    @staticmethod
    def _fmt(value):
        """Format a metric for logging; ``None`` is shown as ``N/A``."""
        return f"{value:.4f}" if value is not None else "N/A"

    def _record(self, rnd, **entries):
        """Merge ``entries`` into the metrics slot of round ``rnd``."""
        while len(self.metrics_centralized) < rnd:
            self.metrics_centralized.append({})
        self.metrics_centralized[rnd - 1].update(entries)

    def configure_fit(self, server_round, parameters, client_manager):
        """Record the round start time, then defer to the parent's ``configure_fit``."""
        self.round_start_times[server_round] = time.time()
        # Client sampling and the fit configuration are left to the parent
        # class, which is expected to supply "proximal_mu" to the clients.
        return super().configure_fit(server_round, parameters, client_manager)

    def aggregate_fit(self, rnd, results, failures):
        """Aggregate the client updates and record the mean training metrics."""
        aggregated_parameters, aggregated_metrics = super().aggregate_fit(rnd, results, failures)

        train_loss = self._mean_metric(results, "loss")
        train_acc = self._mean_metric(results, "accuracy")

        print(f"[Round {rnd}] Train — Loss: {self._fmt(train_loss)}, Acc: {self._fmt(train_acc)}")

        self._record(rnd, train_loss=train_loss, train_acc=train_acc)

        # Pass the parent's aggregated metrics through instead of dropping them
        return aggregated_parameters, aggregated_metrics

    def aggregate_evaluate(self, rnd, results, failures):
        """Aggregate the evaluation results and record validation metrics and timings."""
        eval_start = time.time()
        aggregated_loss, aggregated_metrics = super().aggregate_evaluate(rnd, results, failures)
        # Covers only the aggregation call above, not the clients' evaluation
        eval_duration = time.time() - eval_start

        # Elapsed time since configure_fit of this round; 0 when the start
        # was never recorded
        round_total_time = time.time() - self.round_start_times.get(rnd, time.time())

        val_loss = self._mean_metric(results, "loss")
        val_acc = self._mean_metric(results, "accuracy")

        print(f"[Round {rnd}] Val   — Loss: {self._fmt(val_loss)}, Acc: {self._fmt(val_acc)}, Time: {eval_duration:.2f}s")
        print(f"[Round {rnd}] Total Round Time: {round_total_time:.2f}s")

        self._record(
            rnd,
            val_loss=val_loss,
            val_acc=val_acc,
            eval_duration=eval_duration,
            round_total_time=round_total_time,
            round_duration=round_total_time,  # alias read by the result savers as a fallback
        )

        # Keep the metrics computed by evaluate_metrics_aggregation_fn
        return aggregated_loss, aggregated_metrics


In [None]:
# Create the FedProx strategy
strategyProx = CustomFedProx(
    proximal_mu=PROX_MU,
    initial_parameters=initial_parameters,
    fraction_fit=FRACTION_FIT,
    fraction_evaluate=FRACTION_EVALUATE,
    min_fit_clients=MIN_FIT_CLIENTS,
    min_evaluate_clients=MIN_EVALUATE_CLIENTS,
    min_available_clients=MIN_AVAILABLE_CLIENTS,
    evaluate_metrics_aggregation_fn=weighted_average,
)


def server_fn(context: Context) -> ServerAppComponents:
    """Return the server components: the FedProx strategy and the round count."""
    return ServerAppComponents(
        strategy=strategyProx,
        config=ServerConfig(num_rounds=NUM_ROUND),
    )


# Create the server app for FedProx
server = ServerApp(server_fn=server_fn)

# Run the simulation with the FedProx client app, one supernode per partition
run_simulation(
    server_app=server,
    client_app=client_app_fedprox,
    num_supernodes=NUM_PARTITIONS,
    backend_config=backend_config,
)


In [None]:
import json
import pickle
import pandas as pd
import numpy as np
from datetime import datetime
import os

# =============================================================================
# FUNCTIONS FOR SAVING RESULTS
# =============================================================================

import os
import json
from datetime import datetime

def save_results_to_json_unified(strategy, strategy_name, save_dir="fl_results"):
    """Write a strategy's per-round metrics to a timestamped JSON file.

    The layout is the same for every strategy: a list of per-round records,
    a runtime summary and the run configuration read from module globals
    (``None`` for any that is not defined).

    Args:
        strategy: Object exposing ``metrics_centralized`` (list of dicts).
        strategy_name: Name used in the payload and in the file name.
        save_dir: Output directory, created when missing.

    Returns:
        Path of the written file.
    """
    os.makedirs(save_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    # One normalised record per round; missing metrics become None
    unified_metrics = [
        {
            "round": round_number,
            "train_loss": metrics.get("train_loss"),
            "train_acc": metrics.get("train_acc"),
            "val_loss": metrics.get("val_loss"),
            "val_acc": metrics.get("val_acc"),
            "eval_duration": metrics.get("eval_duration"),
            # Fall back to "round_duration" when the key is absent
            "round_total_time": metrics.get("round_total_time", metrics.get("round_duration")),
        }
        for round_number, metrics in enumerate(strategy.metrics_centralized, start=1)
    ]

    # Runtime summary over the rounds that report a duration
    round_times = [
        record["round_total_time"]
        for record in unified_metrics
        if record["round_total_time"] is not None
    ]
    total_runtime = sum(round_times)
    if round_times:
        mean_round_time = total_runtime / len(round_times)
        fastest, slowest = min(round_times), max(round_times)
    else:
        mean_round_time = 0
        fastest, slowest = None, None

    round_count = len(unified_metrics)
    module_globals = globals()

    results = {
        "strategy_name": strategy_name,
        "timestamp": timestamp,
        "metrics": unified_metrics,
        "runtime_summary": {
            "total_runtime_seconds": total_runtime,
            "average_round_time_seconds": mean_round_time,
            "fastest_round_seconds": fastest,
            "slowest_round_seconds": slowest,
            "num_rounds": round_count,
        },
        "config": {
            "num_rounds": round_count,
            "num_partitions": module_globals.get("NUM_PARTITIONS", None),
            "batch_size": module_globals.get("BATCH_SIZE", None),
            "learning_rate": module_globals.get("CLIENT_LR", None),
            "epochs_per_round": module_globals.get("EPOCH", None),
        },
    }

    filename = f"{save_dir}/{strategy_name}_{timestamp}.json"
    with open(filename, "w") as handle:
        json.dump(results, handle, indent=2)

    print(f"✅ Saved unified JSON (compatible with all strategies) for {strategy_name} to {filename}")
    return filename


def save_results_to_pickle(strategy, strategy_name, save_dir="fl_results"):
    """Pickle a strategy's results, including the strategy object itself.

    The run configuration is read from module globals with ``None`` as the
    fallback, matching ``save_results_to_json_unified``, so the function also
    works where those constants are not defined.

    Args:
        strategy: Object exposing ``metrics_centralized``; it is stored in the
            payload as well, so it must be picklable.
        strategy_name: Name used in the payload and in the file name.
        save_dir: Output directory, created when missing.

    Returns:
        Path of the written file.
    """
    os.makedirs(save_dir, exist_ok=True)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    module_globals = globals()

    results = {
        "strategy_name": strategy_name,
        "timestamp": timestamp,
        "strategy_object": strategy,  # the full object is stored as well
        "metrics": strategy.metrics_centralized,
        "config": {
            "num_rounds": len(strategy.metrics_centralized),
            "num_partitions": module_globals.get("NUM_PARTITIONS", None),
            "batch_size": module_globals.get("BATCH_SIZE", None),
            "learning_rate": module_globals.get("CLIENT_LR", None),
            "epochs_per_round": module_globals.get("EPOCH", None),
        },
    }

    filename = f"{save_dir}/{strategy_name}_{timestamp}.pkl"
    with open(filename, 'wb') as f:
        pickle.dump(results, f)

    print(f"✅ Saved {strategy_name} results to {filename}")
    return filename

def save_results_to_csv(strategy, strategy_name, save_dir="fl_results"):
    """Write a strategy's per-round metrics to a timestamped CSV file.

    Args:
        strategy: Object exposing ``metrics_centralized`` (list of dicts).
        strategy_name: Value of the ``strategy`` column and file-name prefix.
        save_dir: Output directory, created when missing.

    Returns:
        Path of the written file.
    """
    os.makedirs(save_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    metric_keys = ("train_loss", "train_acc", "val_loss", "val_acc")

    metrics_data = []
    for round_number, metrics in enumerate(strategy.metrics_centralized, start=1):
        row = {"round": round_number}
        row.update((key, metrics.get(key)) for key in metric_keys)
        # Fall back to "round_duration" when the key is absent
        row["round_total_time"] = metrics.get("round_total_time", metrics.get("round_duration", None))
        row["strategy"] = strategy_name
        row["timestamp"] = timestamp
        metrics_data.append(row)

    filename = f"{save_dir}/{strategy_name}_{timestamp}.csv"
    pd.DataFrame(metrics_data).to_csv(filename, index=False)

    print(f"✅ Saved {strategy_name} results to {filename}")
    return filename


def save_all_strategies_comparison(strategies_dict, save_dir="fl_results"):
    """Write the per-round metrics of several strategies to one CSV file.

    Args:
        strategies_dict: Mapping ``{strategy_name: strategy_object}``; every
            object exposes ``metrics_centralized``.
        save_dir: Output directory, created when missing.

    Returns:
        Path of the written file.
    """
    os.makedirs(save_dir, exist_ok=True)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    # One row per (strategy, round), strategies in mapping order
    rows = [
        {
            "round": round_number,
            "train_loss": metrics.get("train_loss"),
            "train_acc": metrics.get("train_acc"),
            "val_loss": metrics.get("val_loss"),
            "val_acc": metrics.get("val_acc"),
            "strategy": strategy_name,
            "timestamp": timestamp,
        }
        for strategy_name, strategy in strategies_dict.items()
        for round_number, metrics in enumerate(strategy.metrics_centralized, start=1)
    ]

    filename = f"{save_dir}/comparison_{timestamp}.csv"
    pd.DataFrame(rows).to_csv(filename, index=False)

    print(f"✅ Saved comparison results to {filename}")
    return filename

# =============================================================================
# FUNCTIONS FOR LOADING RESULTS
# =============================================================================

def load_results_from_json(filename):
    """Read and return the results stored in a JSON file."""
    with open(filename, 'r') as handle:
        loaded = json.load(handle)

    print(f"✅ Loaded results from {filename}")
    return loaded

def load_results_from_pickle(filename):
    """Read and return the results stored in a pickle file.

    WARNING: unpickling can execute arbitrary code; only load files that
    come from a trusted source.
    """
    with open(filename, 'rb') as handle:
        loaded = pickle.load(handle)

    print(f"✅ Loaded results from {filename}")
    return loaded

def load_results_from_csv(filename):
    """Read a results CSV file and return it as a DataFrame."""
    frame = pd.read_csv(filename)
    print(f"✅ Loaded results from {filename}")
    return frame

# =============================================================================
# VISUALIZATION FUNCTIONS
# =============================================================================

def plot_comparison_from_csv(csv_filename, save_plot=True):
    """Plot train/validation loss and accuracy per strategy from a results CSV.

    Draws a 2x2 grid (train loss, validation loss, train accuracy, validation
    accuracy) with one line per value of the ``strategy`` column.

    Args:
        csv_filename: CSV with ``round``, ``strategy`` and the four metric
            columns.
        save_plot: When True, also save the figure as a PNG next to the CSV.
    """
    df = pd.read_csv(csv_filename)

    import matplotlib.pyplot as plt

    fig, axes = plt.subplots(2, 2, figsize=(15, 10))

    # (axis, metric column, title, y-axis label, marker) of every panel
    panels = [
        (axes[0, 0], 'train_loss', 'Train Loss Comparison', 'Loss', 'o'),
        (axes[0, 1], 'val_loss', 'Validation Loss Comparison', 'Loss', 's'),
        (axes[1, 0], 'train_acc', 'Train Accuracy Comparison', 'Accuracy', 'o'),
        (axes[1, 1], 'val_acc', 'Validation Accuracy Comparison', 'Accuracy', 's'),
    ]

    for ax, column, title, y_label, marker in panels:
        for strategy in df['strategy'].unique():
            strategy_rows = df[df['strategy'] == strategy]
            ax.plot(strategy_rows['round'], strategy_rows[column],
                    label=f'{strategy}', marker=marker)
        ax.set_title(title)
        ax.set_xlabel('Round')
        ax.set_ylabel(y_label)
        ax.legend()
        ax.grid(True)

    plt.tight_layout()

    if save_plot:
        plot_filename = csv_filename.replace('.csv', '_comparison_plo_100.png')
        plt.savefig(plot_filename, dpi=300, bbox_inches='tight')
        print(f"✅ Saved plot to {plot_filename}")

    plt.show()

def create_summary_report(strategies_dict, save_dir="fl_results"):
    """Write and print a one-row-per-strategy summary of the training runs.

    For every strategy the report holds the final-round metrics, the best
    validation accuracy together with the round in which it was reached, and
    runtime statistics over the rounds that report a duration.

    Args:
        strategies_dict: Mapping ``{strategy_name: strategy_object}``; every
            object exposes ``metrics_centralized`` (list of per-round dicts).
        save_dir: Output directory, created when missing.

    Returns:
        Path of the written CSV file.
    """
    os.makedirs(save_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    summary_data = []

    for strategy_name, strategy in strategies_dict.items():
        metrics = strategy.metrics_centralized

        # Final-round values; 0 when the key (or the whole run) is missing
        last_round = metrics[-1] if metrics else {}
        final_train_acc = last_round.get("train_acc", 0)
        final_val_acc = last_round.get("val_acc", 0)
        final_train_loss = last_round.get("train_loss", 0)
        final_val_loss = last_round.get("val_loss", 0)

        # Keep the real round number next to each accuracy so that rounds
        # without a validation score do not shift the reported best round
        scored_rounds = [
            (round_number, m["val_acc"])
            for round_number, m in enumerate(metrics, start=1)
            if m.get("val_acc") is not None
        ]
        if scored_rounds:
            # max() keeps the earliest round when several share the best value
            best_val_round, best_val_acc = max(scored_rounds, key=lambda pair: pair[1])
        else:
            best_val_round, best_val_acc = 0, 0

        # Per-round durations, falling back to "round_duration"; rounds
        # without a usable duration are skipped
        round_times = []
        for m in metrics:
            elapsed = m.get("round_total_time") or m.get("round_duration")
            if elapsed:
                round_times.append(elapsed)

        total_runtime = sum(round_times)
        prox_round_time = total_runtime / len(round_times) if round_times else 0
        max_round_time = max(round_times) if round_times else 0
        min_round_time = min(round_times) if round_times else 0

        summary_data.append({
            "strategy": strategy_name,
            "final_round": len(metrics),
            "final_train_acc": final_train_acc,
            "final_val_acc": final_val_acc,
            "final_train_loss": final_train_loss,
            "final_val_loss": final_val_loss,
            "best_val_acc": best_val_acc,
            "best_val_round": best_val_round,
            "total_runtime_sec": total_runtime,
            "total_runtime_min": total_runtime / 60,
            # Mean round time; the column name is kept for existing consumers
            "prox_round_time_sec": prox_round_time,
            "max_round_time_sec": max_round_time,
            "min_round_time_sec": min_round_time,
            "timestamp": timestamp
        })

    df_summary = pd.DataFrame(summary_data)

    filename = f"{save_dir}/summary_report_{timestamp}.csv"
    df_summary.to_csv(filename, index=False)

    print(f"✅ Saved summary report to {filename}")
    print("\n📊 SUMMARY REPORT:")
    print(df_summary.to_string(index=False))

    print("\n⏱️ RUNTIME SUMMARY:")
    for _, row in df_summary.iterrows():
        print(f"{row['strategy']}:")
        print(f"  Total time: {row['total_runtime_min']:.1f}min ({row['total_runtime_sec']:.1f}s)")
        print(f"  Prox round: {row['prox_round_time_sec']:.2f}s")
        print(f"  Range: {row['min_round_time_sec']:.2f}s - {row['max_round_time_sec']:.2f}s")

    return filename

# Persist the FedProx metrics as JSON and as CSV
save_results_to_json_unified(strategyProx, "FedProx")
save_results_to_csv(strategyProx, "FedProx")

# Summary report over every strategy that was run (only FedProx here)
strategies_dict = dict(FedProx=strategyProx)

create_summary_report(strategies_dict)