## Cài đặt môi trường

In [None]:
!pip install -q flwr[simulation]
!pip install -q flwr-datasets[vision]

In [None]:
from collections import OrderedDict
from typing import Dict, List, Optional, Tuple
import random

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader


import flwr
from flwr.client import Client, ClientApp, NumPyClient
from flwr.common import Context
from flwr.server import ServerApp, ServerConfig, ServerAppComponents
from flwr.server.strategy import Strategy
from flwr.simulation import run_simulation
from flwr_datasets import FederatedDataset

from flwr_datasets import FederatedDataset
from flwr_datasets.visualization import plot_label_distributions

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, models
from torch.utils.data import DataLoader
from flwr_datasets import FederatedDataset
from flwr_datasets.partitioner import DirichletPartitioner


# Run on the GPU when PyTorch can see one, otherwise on the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Training on {DEVICE}")
# Log library versions so results can be tied to the environment they ran in.
print(f"Flower {flwr.__version__} / PyTorch {torch.__version__}")

## Base config

In [None]:
import os

# Directory handed to FederatedDataset as cache_dir when the dataset is loaded.
CACHE_DIR = "/content/hf_cache"
# NOTE(review): "0" presumably keeps Hugging Face datasets in online mode — confirm.
os.environ["HF_DATASETS_OFFLINE"] = "0"
os.makedirs(CACHE_DIR, exist_ok=True)

In [None]:
# Resources requested per simulated client; a GPU is requested only when DEVICE is CUDA.
backend_config = {"client_resources": {"num_cpus": 1}}
if DEVICE.type == "cuda":
    backend_config = {"client_resources": {"num_cpus": 1, "num_gpus": 1}}  # GPU optional

NUM_PARTITIONS = 20  # number of data partitions, also used as the number of simulated nodes
NUM_ROUND = 1  # number of federated rounds given to ServerConfig
EPOCH = 2  # local epochs each client trains per fit call
BATCH_SIZE = 20  # batch size of every DataLoader

# Client-sampling settings forwarded to the FedAvg strategy.
FRACTION_FIT=0.85
FRACTION_EVALUATE=0.3
MIN_FIT_CLIENTS=7
MIN_EVALUATE_CLIENTS=10
MIN_AVAILABLE_CLIENTS=10

ALPHA = 0.3  # alpha passed to DirichletPartitioner

CLIENT_LR = 10**(-2)  # learning rate of the client-side SGD optimizer

## Load Dataset

In [None]:
def set_seed(seed: int = 42):
    """Seed Python, NumPy and PyTorch (CPU and all CUDA devices) with ``seed``.

    Also switches cuDNN to deterministic mode and disables its benchmark
    auto-tuner.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


set_seed(42)  # Only needs to be called once, at the start of the program

In [None]:
def load_datasets(partition_id, num_partitions: int):
    """Build the train/validation/test DataLoaders for one federated client.

    The Fashion-MNIST "train" split is divided into ``num_partitions`` non-IID
    shards by a Dirichlet partitioner (``alpha=ALPHA``). Shard ``partition_id``
    is split 80/20 (seed 42) into a local training and a local validation set;
    the "test" split is loaded whole.

    Args:
        partition_id: Index of the shard to load.
        num_partitions: Number of shards the train split is divided into.

    Returns:
        A ``(trainloader, valloader, testloader)`` tuple. The loaders yield
        dict batches whose "image" entry holds the transformed images.
    """
    # Non-IID split. The partition count now follows the argument instead of
    # silently falling back to the global NUM_PARTITIONS.
    partitioner = DirichletPartitioner(
        num_partitions=num_partitions,
        partition_by="label",
        alpha=ALPHA,
        # NOTE(review): the minimum partition size is tied to the partition
        # count here; kept as-is so the split stays unchanged — confirm intent.
        min_partition_size=num_partitions,
        self_balancing=True,
    )

    fds = FederatedDataset(
        dataset="fashion_mnist",
        partitioners={"train": partitioner},
        cache_dir=CACHE_DIR,
    )

    partition = fds.load_partition(partition_id)
    partition_train_test = partition.train_test_split(test_size=0.2, seed=42)

    # Fashion-MNIST images are grayscale 28x28; they are resized to 32x32 for
    # the ResNet backbone and normalized with single-channel statistics.
    normalize = transforms.Normalize(mean=[0.2860], std=[0.3530])
    train_transforms = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(degrees=5),
        transforms.ToTensor(),
        normalize,
    ])
    val_transforms = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        normalize,
    ])

    def apply_train_transforms(batch):
        # Fashion-MNIST stores its images under the key "image" (not "img").
        try:
            batch["image"] = [train_transforms(img) for img in batch["image"]]
            return batch
        except Exception as e:
            print(f"Error in train transforms: {e}")
            raise

    def apply_val_transforms(batch):
        # Fashion-MNIST stores its images under the key "image" (not "img").
        try:
            batch["image"] = [val_transforms(img) for img in batch["image"]]
            return batch
        except Exception as e:
            print(f"Error in val transforms: {e}")
            raise

    train_split = partition_train_test["train"].with_transform(apply_train_transforms)
    val_split = partition_train_test["test"].with_transform(apply_val_transforms)
    test_split = fds.load_split("test").with_transform(apply_val_transforms)

    # Settings shared by all three loaders.
    loader_options = dict(
        batch_size=BATCH_SIZE,
        pin_memory=torch.cuda.is_available(),
        num_workers=0,
        persistent_workers=False,
    )

    # Only the training loader shuffles and drops the last incomplete batch.
    trainloader = DataLoader(train_split, shuffle=True, drop_last=True, **loader_options)
    valloader = DataLoader(val_split, shuffle=False, drop_last=False, **loader_options)
    testloader = DataLoader(test_split, shuffle=False, drop_last=False, **loader_options)

    return trainloader, valloader, testloader

In [None]:
from huggingface_hub import login

# NOTE(review): the token is an empty placeholder. If authentication is needed,
# supply it from an environment variable or secret store rather than hard-coding it.
login(token="")
# Build the (trainloader, valloader, testloader) tuple of every partition once up
# front; client_fn_fedavg later picks its entry by partition id.
all_data = [load_datasets(i, NUM_PARTITIONS) for i in range(NUM_PARTITIONS)]

In [None]:
from flwr_datasets import FederatedDataset
from flwr_datasets.visualization import plot_label_distributions
import matplotlib.pyplot as plt

# Create the partitioner (same settings as the one built inside load_datasets)
partitioner = DirichletPartitioner(
    num_partitions=NUM_PARTITIONS,
    partition_by="label",
    alpha=ALPHA,
    min_partition_size=NUM_PARTITIONS,
    self_balancing=True
)

# Create the FederatedDataset
fds = FederatedDataset(
    dataset="fashion_mnist",  # CHANGED: cifar10 -> fashion_mnist
    partitioners={"train": partitioner},
    cache_dir=CACHE_DIR
)

# Visualize the label distribution per partition
print(f"Visualizing Non-IID distribution with alpha={ALPHA}")
partitioner = fds.partitioners["train"]
# `dataframe` is reused by the entropy cell, which reads one row per client.
figure, axis, dataframe = plot_label_distributions(
    partitioner=partitioner,
    label_name="label",
    legend=True,
    verbose_labels=True,
    plot_type="bar",  # Added to make the plot clearer
    size_unit="absolute"  # Or "percentage"
)

# Customize plot
plt.title(f'Fashion-MNIST Non-IID Distribution (α={ALPHA})')  # CHANGED: CIFAR-10 -> Fashion-MNIST
plt.xlabel('Clients')
plt.ylabel('Number of Samples')
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()

print(f"Dataset: Fashion-MNIST")

In [None]:
import numpy as np
from scipy.stats import entropy

def calculate_entropy(counts):
    """Return the Shannon entropy, in bits, of the distribution given by ``counts``.

    ``counts`` are raw occurrence counts; they are normalized to probabilities
    first. When the counts sum to zero the result is 0.
    """
    count_sum = sum(counts)
    if count_sum == 0:
        return 0
    # base=2 expresses the entropy in bits.
    return entropy([value / count_sum for value in counts], base=2)

# One entropy value per client, computed from that client's row of label counts.
entropies = [
    calculate_entropy(dataframe.iloc[client_id].tolist())
    for client_id in range(NUM_PARTITIONS)
]

# Mean and spread of the per-client entropies.
average_entropy, std_entropy = np.mean(entropies), np.std(entropies)

print(f"Average Entropy: {average_entropy:.4f}")
print(f"Entropy Std: {std_entropy:.4f}")


## Model ResNet18

In [None]:
import torch.nn as nn
from torchvision import models

class ResNet18_GroupNorm(nn.Module):
    """ResNet-18 for single-channel images with GroupNorm instead of BatchNorm.

    The torchvision ResNet-18 gets a 1-input-channel first convolution, every
    BatchNorm layer is replaced by a 2-group GroupNorm, the original ``fc``
    layer becomes an identity, and a small MLP head produces ``num_classes``
    logits.
    """

    def __init__(self, num_classes: int = 10, pretrained: bool = True):
        super().__init__()

        # Optionally start from the ImageNet weights (adapted to 1 channel below).
        self.backbone = models.resnet18(
            weights=models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None
        )

        # Swap the 3-channel stem convolution for a 1-channel (grayscale) one
        # with the same geometry.
        rgb_stem = self.backbone.conv1
        self.backbone.conv1 = nn.Conv2d(
            in_channels=1,
            out_channels=rgb_stem.out_channels,
            kernel_size=rgb_stem.kernel_size,
            stride=rgb_stem.stride,
            padding=rgb_stem.padding,
            bias=False,
        )

        if pretrained:
            # Average the RGB kernels over the channel axis to obtain the
            # single grayscale kernel.
            with torch.no_grad():
                self.backbone.conv1.weight = nn.Parameter(
                    rgb_stem.weight.mean(dim=1, keepdim=True)
                )

        # Replace BatchNorm with GroupNorm throughout the backbone.
        self._replace_batchnorm_with_groupnorm(self.backbone)

        # Remember the feature width of the original fc layer, then drop that layer.
        feat_dim = self.backbone.fc.in_features
        self.backbone.fc = nn.Identity()

        # Classification head, also normalized with GroupNorm.
        self.classifier = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(feat_dim, 128),
            nn.GroupNorm(2, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),

            nn.Linear(128, 64),
            nn.GroupNorm(2, 64),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1),

            nn.Linear(64, num_classes),
        )

        self._init_weights()

    def _replace_batchnorm_with_groupnorm(self, module):
        """Recursively replace every BatchNorm1d/2d under ``module`` with a 2-group GroupNorm."""
        for name, child in module.named_children():
            if isinstance(child, (nn.BatchNorm2d, nn.BatchNorm1d)):
                setattr(module, name, nn.GroupNorm(2, child.num_features))
                continue
            self._replace_batchnorm_with_groupnorm(child)

    def _init_weights(self):
        """Kaiming-initialize the Linear layers of the head and zero their biases."""
        linear_layers = (
            layer for layer in self.classifier if isinstance(layer, nn.Linear)
        )
        for layer in linear_layers:
            nn.init.kaiming_normal_(layer.weight, nonlinearity='relu')
            if layer.bias is not None:
                nn.init.constant_(layer.bias, 0)

    def forward(self, x):
        """Return the class logits for the image batch ``x``."""
        features = torch.flatten(self.backbone(x), 1)
        return self.classifier(features)

# =====================================
# 3. UPDATE THE TRAIN FUNCTION TO USE CLIENT_LR
# =====================================

def train(net, trainloader, epochs: int):
    """Train ``net`` on ``trainloader`` with SGD at the client learning rate.

    Args:
        net: Model to train in place; expected to already live on ``DEVICE``.
        trainloader: Iterable of dict batches with the keys "image" and "label".
        epochs: Number of passes over ``trainloader``.

    Returns:
        ``(loss, accuracy)`` of the last epoch, both averaged per sample.
        ``(0.0, 0.0)`` is returned when no sample was processed, i.e. when
        ``epochs`` is not positive or the loader yields no batch.
    """
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(net.parameters(), lr=CLIENT_LR, weight_decay=1e-3)
    net.train()

    # Defined up front so the return statement is valid even when no epoch runs.
    epoch_loss, epoch_acc = 0.0, 0.0

    for epoch in range(epochs):
        correct, total, loss_sum = 0, 0, 0.0
        for batch in trainloader:
            # Fashion-MNIST stores its images under the key "image" (not "img").
            images = batch["image"].to(DEVICE)
            labels = batch["label"].to(DEVICE)

            optimizer.zero_grad()
            outputs = net(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Accumulate sample-weighted loss and the number of correct predictions.
            batch_size = labels.size(0)
            loss_sum += loss.item() * batch_size
            total += batch_size
            correct += (outputs.argmax(dim=1) == labels).sum().item()

        if total == 0:
            # A loader that drops its last incomplete batch yields nothing when
            # the dataset is smaller than one batch; avoid dividing by zero.
            print(f"Epoch {epoch+1}: no training batches available, skipping")
            epoch_loss, epoch_acc = 0.0, 0.0
            continue

        epoch_loss = loss_sum / total
        epoch_acc = correct / total
        print(f"Epoch {epoch+1}: train loss {epoch_loss:.4f}, accuracy {epoch_acc:.4f}")

    return epoch_loss, epoch_acc

def test(net, testloader):
    """Evaluate ``net`` on ``testloader`` and return ``(loss, accuracy)``.

    Args:
        net: Model to evaluate; expected to already live on ``DEVICE``.
        testloader: Iterable of dict batches with the keys "image" and "label".

    Returns:
        Per-sample average loss and accuracy. ``(0.0, 0.0)`` is returned when
        the loader yields no sample, instead of failing with a division by zero.
    """
    criterion = torch.nn.CrossEntropyLoss()
    correct, total, total_loss = 0, 0, 0.0
    net.eval()
    with torch.no_grad():
        for batch in testloader:
            # Fashion-MNIST stores its images under the key "image" (not "img").
            images = batch["image"].to(DEVICE)
            labels = batch["label"].to(DEVICE)
            outputs = net(images)
            batch_size = labels.size(0)
            total_loss += criterion(outputs, labels).item() * batch_size
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += batch_size

    if total == 0:
        return 0.0, 0.0
    return total_loss / total, correct / total

In [None]:
def get_parameters(net) -> List[np.ndarray]:
    """Return every ``state_dict`` entry of ``net`` as a NumPy array.

    The list has exactly one array per ``state_dict`` key, in key order, so it
    can be matched back to the keys by position. Empty tensors are reported
    with a warning but are still included: dropping them would shift every
    following entry and break that positional pairing.
    """
    params = []
    for name, val in net.state_dict().items():
        if val.numel() == 0:
            print(f"Warning: Empty tensor found: {name}")
        params.append(val.cpu().numpy())
    return params

def set_parameters(net, parameters: List[np.ndarray]):
    """Load ``parameters`` into ``net``, pairing them with its state_dict keys by position.

    Args:
        net: Model whose weights are overwritten in place.
        parameters: One array per ``state_dict`` key, in key order.

    Raises:
        ValueError: If the number of arrays differs from the number of
            ``state_dict`` entries. Returning silently here would leave the
            model on stale weights while the caller carries on as if the
            update had been applied.

    An empty ``parameters`` list is still treated as "nothing to load": a
    warning is printed and the model is left untouched.
    """
    if not parameters:
        print("Warning: No parameters received")
        return

    state_dict_keys = list(net.state_dict().keys())
    if len(parameters) != len(state_dict_keys):
        raise ValueError(
            f"Parameter count mismatch: model has {len(state_dict_keys)} "
            f"entries, received {len(parameters)}"
        )

    state_dict = OrderedDict(
        (key, torch.tensor(array)) for key, array in zip(state_dict_keys, parameters)
    )
    net.load_state_dict(state_dict, strict=False)

## Custom FedAvg

In [None]:
class FlowerClientFedAvg(NumPyClient):
    """Flower NumPy client that trains and evaluates one partition's model."""

    def __init__(self, partition_id, net, trainloader, valloader):
        self.partition_id = partition_id
        self.net = net
        self.trainloader = trainloader
        self.valloader = valloader

    def get_parameters(self, config):
        """Return the current model weights as a list of NumPy arrays."""
        return get_parameters(self.net)

    def fit(self, parameters, config):
        """Load the given weights, train for EPOCH local epochs, return the update.

        Returns the new weights, the size of the training dataset and the
        loss/accuracy of the last local epoch.
        """
        print(f"[Client {self.partition_id}] FedAvg fit")
        set_parameters(self.net, parameters)
        train_loss, train_acc = train(self.net, self.trainloader, epochs=EPOCH)

        updated_weights = get_parameters(self.net)
        num_examples = len(self.trainloader.dataset)
        fit_metrics = {"loss": float(train_loss), "accuracy": float(train_acc)}
        return updated_weights, num_examples, fit_metrics

    def evaluate(self, parameters, config):
        """Load the given weights and evaluate them on the local validation set."""
        set_parameters(self.net, parameters)
        val_loss, val_acc = test(self.net, self.valloader)

        reported_loss = float(val_loss)
        num_examples = len(self.valloader.dataset)
        eval_metrics = {"loss": float(val_loss), "accuracy": float(val_acc)}
        return reported_loss, num_examples, eval_metrics

In [None]:
def client_fn_fedavg(context):
    """Create the FedAvg client for the partition named in ``context.node_config``."""
    partition_id = context.node_config["partition-id"]
    # The test loader of the partition is not needed on the client side.
    trainloader, valloader, _unused_testloader = all_data[partition_id]

    model = ResNet18_GroupNorm(num_classes=10).to(DEVICE)
    client = FlowerClientFedAvg(partition_id, model, trainloader, valloader)
    return client.to_client()


client_app_fedavg = ClientApp(client_fn=client_fn_fedavg)

In [None]:
from flwr.common import Metrics, Context
from flwr.common import ndarrays_to_parameters

def weighted_average(metrics: "List[Tuple[int, Metrics]]") -> "Metrics":
    """Aggregate client accuracies into one example-weighted average.

    Args:
        metrics: ``(num_examples, client_metrics)`` pairs; every
            ``client_metrics`` must contain the key "accuracy".

    Returns:
        ``{"accuracy": weighted_mean}``. When the total number of examples is
        zero (no results, or every client reported 0 examples) the accuracy is
        reported as ``0.0`` instead of raising ``ZeroDivisionError``.
    """
    total_examples = sum(num_examples for num_examples, _ in metrics)
    if total_examples == 0:
        return {"accuracy": 0.0}

    # Weight each client's accuracy by the number of examples it evaluated.
    weighted_sum = sum(num_examples * m["accuracy"] for num_examples, m in metrics)
    return {"accuracy": weighted_sum / total_examples}

# NOTE(review): `device` duplicates DEVICE and does not appear to be read anywhere — confirm before removing.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Model instance used here only to produce the weights the strategy starts from.
net = ResNet18_GroupNorm().to(DEVICE)
initial_parameters = flwr.common.ndarrays_to_parameters(get_parameters(net))

In [None]:
from flwr.server.strategy import FedAvg
import time # Import the time module

class CustomFedAvg(FedAvg):
    """FedAvg that also records per-round client metrics and wall-clock timings.

    ``metrics_centralized`` holds one dict per round (index ``rnd - 1``) with
    the keys ``train_loss``, ``train_acc``, ``fit_duration``, ``val_loss``,
    ``val_acc``, ``eval_duration`` and ``round_total_time``. Loss and accuracy
    are the unweighted mean of the values the clients reported, or ``None``
    when no client reported that metric. ``fit_duration`` and
    ``eval_duration`` cover only the parent class's aggregation call.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.metrics_centralized = []
        # Timing state: start of the current round and of the whole experiment.
        self.round_start_time = None
        self.experiment_start_time = time.time()

    def configure_fit(self, server_round, parameters, client_manager):
        """Stamp the start of the round, then defer to FedAvg's client selection.

        Stamping here, before clients are configured for training, makes
        ``round_total_time`` include the clients' training time in every round,
        including the first one.
        """
        self.round_start_time = time.time()
        return super().configure_fit(server_round, parameters, client_manager)

    @staticmethod
    def _mean_client_metric(results, key):
        """Return the unweighted mean of ``key`` over the client results, or None."""
        values = [r.metrics.get(key) for _, r in results if key in r.metrics]
        values = [v for v in values if v is not None]
        return sum(values) / len(values) if values else None

    @staticmethod
    def _format_metric(value):
        """Format a metric for logging; a missing value is shown as 'N/A'."""
        return "N/A" if value is None else f"{value:.4f}"

    def _round_record(self, rnd):
        """Return the metrics dict of round ``rnd``, growing the history if needed."""
        while len(self.metrics_centralized) < rnd:
            self.metrics_centralized.append({})
        return self.metrics_centralized[rnd - 1]

    def aggregate_fit(self, rnd, results, failures):
        """Aggregate the fit results and record the clients' training metrics."""
        # Fallback only: configure_fit normally stamps the start of the round.
        if self.round_start_time is None:
            self.round_start_time = time.time()

        fit_start = time.time()
        aggregated_parameters, metrics_aggregated = super().aggregate_fit(rnd, results, failures)
        fit_duration = time.time() - fit_start

        avg_loss = self._mean_client_metric(results, "loss")
        avg_acc = self._mean_client_metric(results, "accuracy")
        print(
            f"[Round {rnd}] Train — Loss: {self._format_metric(avg_loss)}, "
            f"Acc: {self._format_metric(avg_acc)}, Time: {fit_duration:.2f}s"
        )

        record = self._round_record(rnd)
        record["train_loss"] = avg_loss
        record["train_acc"] = avg_acc
        record["fit_duration"] = fit_duration

        return aggregated_parameters, metrics_aggregated

    def aggregate_evaluate(self, rnd, results, failures):
        """Aggregate the evaluation results and record the validation metrics."""
        eval_start = time.time()
        aggregated_loss, metrics_aggregated = super().aggregate_evaluate(rnd, results, failures)
        eval_duration = time.time() - eval_start

        # Total wall-clock time since the round was stamped as started.
        if self.round_start_time is not None:
            round_total_time = time.time() - self.round_start_time
        else:
            round_total_time = 0

        avg_loss = self._mean_client_metric(results, "loss")
        avg_acc = self._mean_client_metric(results, "accuracy")

        # Missing metrics are printed as N/A; formatting None with ':.4f'
        # would raise a TypeError.
        print(
            f"[Round {rnd}] Val   — Loss: {self._format_metric(avg_loss)}, "
            f"Acc: {self._format_metric(avg_acc)}, Time: {eval_duration:.2f}s"
        )
        print(f"[Round {rnd}] Total Round Time: {round_total_time:.2f}s")

        record = self._round_record(rnd)
        record["val_loss"] = avg_loss
        record["val_acc"] = avg_acc
        record["eval_duration"] = eval_duration
        record["round_total_time"] = round_total_time

        # Reset for the next round (configure_fit stamps it again when called).
        self.round_start_time = time.time()

        return aggregated_loss, metrics_aggregated

In [None]:
# Strategy instance handed to the server; it is kept at module level so the
# metrics it records can be read after the simulation finishes.
strategyAvg = CustomFedAvg(
    fraction_fit=FRACTION_FIT,
    fraction_evaluate=FRACTION_EVALUATE,
    min_fit_clients=MIN_FIT_CLIENTS,
    min_evaluate_clients=MIN_EVALUATE_CLIENTS,
    min_available_clients=MIN_AVAILABLE_CLIENTS,
    initial_parameters=initial_parameters,
    evaluate_metrics_aggregation_fn=weighted_average
)

def server_fn(context: Context) -> ServerAppComponents:
    """Return the server components: the shared strategy and a NUM_ROUND-round config."""
    return ServerAppComponents(
        strategy=strategyAvg,
        config=ServerConfig(num_rounds=NUM_ROUND),
    )

# Create the ServerApp
server = ServerApp(server_fn=server_fn)

# Run simulation with one simulated node per data partition
run_simulation(
    server_app=server,
    client_app=client_app_fedavg,
    num_supernodes=NUM_PARTITIONS,
    backend_config=backend_config,
)

In [None]:
import matplotlib.pyplot as plt

# Fetch the per-round metrics recorded by the strategy
history = strategyAvg.metrics_centralized

# Rounds are numbered from 1; each series has one value per round (None if missing).
rounds = list(range(1, len(history) + 1))
train_loss = [m.get("train_loss") for m in history]
val_loss = [m.get("val_loss") for m in history]
train_acc = [m.get("train_acc") for m in history]
val_acc = [m.get("val_acc") for m in history]

# Loss (left panel)
plt.figure(figsize=(10, 4))
plt.subplot(1, 2, 1)
plt.plot(rounds, train_loss, label="Train Loss", marker='o')
plt.plot(rounds, val_loss, label="Val Loss", marker='s')
plt.xlabel("Round")
plt.ylabel("Loss")
plt.title("Loss per Round")
plt.grid(True)
plt.legend()

# Accuracy (right panel)
plt.subplot(1, 2, 2)
plt.plot(rounds, train_acc, label="Train Acc", marker='o')
plt.plot(rounds, val_acc, label="Val Acc", marker='s')
plt.xlabel("Round")
plt.ylabel("Accuracy")
plt.title("Accuracy per Round")
plt.grid(True)
plt.legend()

plt.tight_layout()
plt.show()


In [None]:
import json
import pickle
import pandas as pd
import numpy as np
from datetime import datetime
import os

# =============================================================================
# FUNCTIONS FOR SAVING RESULTS
# =============================================================================

def save_results_to_json(strategy, strategy_name, save_dir="fl_results"):
    """Write the strategy's per-round metrics and a runtime summary to a JSON file.

    The file is named ``<strategy_name>_<timestamp>.json`` inside ``save_dir``
    (created if missing). Returns the path of the written file.
    """
    os.makedirs(save_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    # Elapsed time since the strategy object was created.
    total_experiment_time = time.time() - strategy.experiment_start_time

    history = strategy.metrics_centralized
    num_rounds = len(history)

    def total_of(key):
        # Rounds that did not record `key` contribute 0.
        return sum(entry.get(key, 0) for entry in history)

    total_fit_time = total_of("fit_duration")
    total_eval_time = total_of("eval_duration")
    total_round_time = total_of("round_total_time")
    avg_round_time = total_round_time / num_rounds if history else 0

    runtime_summary = {
        "total_experiment_time": total_experiment_time,
        "total_fit_time": total_fit_time,
        "total_eval_time": total_eval_time,
        "total_round_time": total_round_time,
        "average_round_time": avg_round_time,
        "num_rounds": num_rounds,
    }
    config = {
        "num_rounds": num_rounds,
        "num_partitions": NUM_PARTITIONS,
        "batch_size": BATCH_SIZE,
        "learning_rate": CLIENT_LR,
        "epochs_per_round": EPOCH,
    }
    results = {
        "strategy_name": strategy_name,
        "timestamp": timestamp,
        "metrics": history,
        "runtime_summary": runtime_summary,
        "config": config,
    }

    filename = f"{save_dir}/{strategy_name}_{timestamp}.json"
    with open(filename, 'w') as out_file:
        json.dump(results, out_file, indent=2)

    report_lines = (
        f"✅ Saved {strategy_name} results to {filename}",
        "📊 Runtime Summary:",
        f"   Total experiment time: {total_experiment_time:.2f}s ({total_experiment_time/60:.1f}min)",
        f"   Average round time: {avg_round_time:.2f}s",
        f"   Total training time: {total_fit_time:.2f}s",
        f"   Total evaluation time: {total_eval_time:.2f}s",
    )
    for line in report_lines:
        print(line)

    return filename

def save_results_to_pickle(strategy, strategy_name, save_dir="fl_results"):
    """Pickle the strategy object, its metrics and the run configuration.

    The file is named ``<strategy_name>_<timestamp>.pkl`` inside ``save_dir``
    (created if missing). Returns the path of the written file.
    """
    os.makedirs(save_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    history = strategy.metrics_centralized
    config = {
        "num_rounds": len(history),
        "num_partitions": NUM_PARTITIONS,
        "batch_size": BATCH_SIZE,
        "learning_rate": CLIENT_LR,
        "epochs_per_round": EPOCH,
    }
    payload = {
        "strategy_name": strategy_name,
        "timestamp": timestamp,
        "strategy_object": strategy,  # the whole object is stored as well
        "metrics": history,
        "config": config,
    }

    filename = f"{save_dir}/{strategy_name}_{timestamp}.pkl"
    with open(filename, 'wb') as out_file:
        pickle.dump(payload, out_file)

    print(f"✅ Saved {strategy_name} results to {filename}")
    return filename

def save_results_to_csv(strategy, strategy_name, save_dir="fl_results"):
    """Write one CSV row per round with the strategy's loss/accuracy metrics.

    The file is named ``<strategy_name>_<timestamp>.csv`` inside ``save_dir``
    (created if missing). Returns the path of the written file.
    """
    os.makedirs(save_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    metric_keys = ("train_loss", "train_acc", "val_loss", "val_acc")
    # Rounds are numbered from 1; metrics a round did not record stay empty.
    rows = [
        {
            "round": round_number,
            **{key: entry.get(key) for key in metric_keys},
            "strategy": strategy_name,
            "timestamp": timestamp,
        }
        for round_number, entry in enumerate(strategy.metrics_centralized, start=1)
    ]

    filename = f"{save_dir}/{strategy_name}_{timestamp}.csv"
    pd.DataFrame(rows).to_csv(filename, index=False)

    print(f"✅ Saved {strategy_name} results to {filename}")
    return filename

def save_all_strategies_comparison(strategies_dict, save_dir="fl_results"):
    """Write the per-round metrics of several strategies into one CSV file.

    Args:
        strategies_dict: Mapping ``{strategy_name: strategy_object}``.
        save_dir: Directory the file is written to (created if missing).

    Returns:
        Path of the written ``comparison_<timestamp>.csv`` file.
    """
    os.makedirs(save_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    metric_keys = ("train_loss", "train_acc", "val_loss", "val_acc")
    # One row per (strategy, round); rounds are numbered from 1.
    rows = [
        {
            "round": round_number,
            **{key: entry.get(key) for key in metric_keys},
            "strategy": name,
            "timestamp": timestamp,
        }
        for name, strategy in strategies_dict.items()
        for round_number, entry in enumerate(strategy.metrics_centralized, start=1)
    ]

    filename = f"{save_dir}/comparison_{timestamp}.csv"
    pd.DataFrame(rows).to_csv(filename, index=False)

    print(f"✅ Saved comparison results to {filename}")
    return filename

# =============================================================================
# FUNCTIONS FOR LOADING RESULTS
# =============================================================================

def load_results_from_json(filename):
    """Read a results file written as JSON and return its parsed content."""
    with open(filename, 'r') as in_file:
        payload = json.load(in_file)

    print(f"✅ Loaded results from {filename}")
    return payload

def load_results_from_pickle(filename):
    """Read a pickled results file and return the stored object.

    NOTE(review): unpickling can execute arbitrary code; only load files that
    come from a trusted source.
    """
    with open(filename, 'rb') as in_file:
        payload = pickle.load(in_file)

    print(f"✅ Loaded results from {filename}")
    return payload

def load_results_from_csv(filename):
    """Read a results CSV file and return it as a pandas DataFrame."""
    frame = pd.read_csv(filename)

    print(f"✅ Loaded results from {filename}")
    return frame

# =============================================================================
# VISUALIZATION FUNCTIONS
# =============================================================================

def plot_comparison_from_csv(csv_filename, save_plot=True):
    """Plot train/validation loss and accuracy per strategy from a results CSV.

    The CSV must contain the columns ``strategy``, ``round``, ``train_loss``,
    ``val_loss``, ``train_acc`` and ``val_acc``. A 2x2 grid is drawn with one
    line per strategy in every panel.

    Args:
        csv_filename: Path of the CSV file to read.
        save_plot: When true, the figure is also saved as a PNG next to the
            CSV, named ``<csv path without extension>_comparison_plo_100.png``.
    """
    import matplotlib.pyplot as plt

    df = pd.read_csv(csv_filename)

    # (grid position, CSV column, panel title, y-axis label, line marker)
    panels = [
        ((0, 0), 'train_loss', 'Train Loss Comparison', 'Loss', 'o'),
        ((0, 1), 'val_loss', 'Validation Loss Comparison', 'Loss', 's'),
        ((1, 0), 'train_acc', 'Train Accuracy Comparison', 'Accuracy', 'o'),
        ((1, 1), 'val_acc', 'Validation Accuracy Comparison', 'Accuracy', 's'),
    ]

    fig, axes = plt.subplots(2, 2, figsize=(15, 10))

    # Split the rows by strategy once instead of once per panel.
    per_strategy = [
        (strategy, df[df['strategy'] == strategy])
        for strategy in df['strategy'].unique()
    ]

    for position, column, title, ylabel, marker in panels:
        ax = axes[position]
        for strategy, strategy_data in per_strategy:
            ax.plot(strategy_data['round'], strategy_data[column],
                    label=f'{strategy}', marker=marker)
        ax.set_title(title)
        ax.set_xlabel('Round')
        ax.set_ylabel(ylabel)
        ax.legend()
        ax.grid(True)

    plt.tight_layout()

    if save_plot:
        # Replace only the file extension. A plain str.replace('.csv', ...)
        # would also rewrite '.csv' inside directory names, and would leave
        # the name unchanged (overwriting the input file) when it does not
        # contain '.csv' at all.
        root, _ext = os.path.splitext(csv_filename)
        plot_filename = f"{root}_comparison_plo_100.png"
        plt.savefig(plot_filename, dpi=300, bbox_inches='tight')
        print(f"✅ Saved plot to {plot_filename}")

    plt.show()

def create_summary_report(strategies_dict, save_dir="fl_results"):
    """Write and print a one-row-per-strategy summary of results and runtime.

    Args:
        strategies_dict: Mapping ``{strategy_name: strategy_object}``; each
            strategy exposes ``metrics_centralized`` (one dict per round) and
            optionally ``experiment_start_time``.
        save_dir: Directory the report is written to (created if missing).

    Returns:
        Path of the written ``summary_report_<timestamp>_100.csv`` file.
    """
    os.makedirs(save_dir, exist_ok=True)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    summary_data = []

    for strategy_name, strategy in strategies_dict.items():
        metrics = strategy.metrics_centralized
        num_rounds = len(metrics)
        # Metrics of the last round; an empty history reports zeros.
        last = metrics[-1] if metrics else {}

        # Pair each validation accuracy with its real 1-based round number, so
        # rounds without a val_acc do not shift the reported best round.
        scored_rounds = [
            (round_number, m["val_acc"])
            for round_number, m in enumerate(metrics, start=1)
            if m.get("val_acc") is not None
        ]
        if scored_rounds:
            # max() keeps the first entry on ties, i.e. the earliest best round.
            best_val_round, best_val_acc = max(scored_rounds, key=lambda pair: pair[1])
        else:
            best_val_round, best_val_acc = 0, 0

        # Runtime statistics
        if hasattr(strategy, 'experiment_start_time'):
            total_experiment_time = time.time() - strategy.experiment_start_time
        else:
            total_experiment_time = 0
        total_fit_time = sum(m.get("fit_duration", 0) for m in metrics)
        total_eval_time = sum(m.get("eval_duration", 0) for m in metrics)
        total_round_time = sum(m.get("round_total_time", 0) for m in metrics)
        avg_round_time = total_round_time / num_rounds if metrics else 0
        avg_fit_time = total_fit_time / num_rounds if metrics else 0
        avg_eval_time = total_eval_time / num_rounds if metrics else 0

        # Slowest and fastest round
        round_times = [
            m["round_total_time"] for m in metrics
            if m.get("round_total_time") is not None
        ]
        max_round_time = max(round_times) if round_times else 0
        min_round_time = min(round_times) if round_times else 0

        summary_data.append({
            "strategy": strategy_name,
            "final_round": num_rounds,
            "final_train_acc": last.get("train_acc", 0),
            "final_val_acc": last.get("val_acc", 0),
            "final_train_loss": last.get("train_loss", 0),
            "final_val_loss": last.get("val_loss", 0),
            "best_val_acc": best_val_acc,
            "best_val_round": best_val_round,
            # Runtime information
            "total_experiment_time_sec": total_experiment_time,
            "total_experiment_time_min": total_experiment_time / 60,
            "total_fit_time_sec": total_fit_time,
            "total_eval_time_sec": total_eval_time,
            "avg_round_time_sec": avg_round_time,
            "avg_fit_time_sec": avg_fit_time,
            "avg_eval_time_sec": avg_eval_time,
            "max_round_time_sec": max_round_time,
            "min_round_time_sec": min_round_time,
            "timestamp": timestamp
        })

    df_summary = pd.DataFrame(summary_data)

    filename = f"{save_dir}/summary_report_{timestamp}_100.csv"
    df_summary.to_csv(filename, index=False)

    print(f"✅ Saved summary report to {filename}")
    print("\n📊 SUMMARY REPORT:")
    print(df_summary.to_string(index=False))

    # Short runtime overview per strategy
    print("\n⏱️ RUNTIME SUMMARY:")
    for _, row in df_summary.iterrows():
        print(f"{row['strategy']}:")
        print(f"  Total time: {row['total_experiment_time_min']:.1f}min ({row['total_experiment_time_sec']:.1f}s)")
        print(f"  Avg round: {row['avg_round_time_sec']:.2f}s (fit: {row['avg_fit_time_sec']:.2f}s, eval: {row['avg_eval_time_sec']:.2f}s)")
        print(f"  Range: {row['min_round_time_sec']:.2f}s - {row['max_round_time_sec']:.2f}s")

    return filename

# ADDED: save the results together with the runtime information
print("\n💾 Saving individual strategy results...")

# Save the FedAvg results (the JSON file includes the runtime summary)
save_results_to_json(strategyAvg, "FedAvg")
save_results_to_csv(strategyAvg, "FedAvg")

# ADDED: dictionary of strategies for the summary report
strategies_dict = {
    "FedAvg": strategyAvg
}

# Create the summary report with runtime information
create_summary_report(strategies_dict)