## Cài đặt môi trường

In [None]:
!pip install -q flwr[simulation]
!pip install -q flwr-datasets[vision]

In [None]:
from collections import OrderedDict
from typing import Dict, List, Optional, Tuple
import random

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader


import flwr
from flwr.client import Client, ClientApp, NumPyClient
from flwr.common import Context
from flwr.server import ServerApp, ServerConfig, ServerAppComponents
from flwr.server.strategy import Strategy
from flwr.simulation import run_simulation
from flwr_datasets import FederatedDataset

from flwr_datasets import FederatedDataset
from flwr_datasets.visualization import plot_label_distributions

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, models
from torch.utils.data import DataLoader
from flwr_datasets import FederatedDataset
from flwr_datasets.partitioner import DirichletPartitioner

import time


# Use the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Report the selected device and the library versions in use.
print(f"Training on {DEVICE}")
print(f"Flower {flwr.__version__} / PyTorch {torch.__version__}")

## Base config

In [None]:
import os

# Directory handed to FederatedDataset as `cache_dir` further down.
CACHE_DIR = "/kaggle/working/hf_cache"
# NOTE(review): presumably keeps Hugging Face `datasets` in online mode — confirm
# it still takes effect when set after the library has been imported.
os.environ["HF_DATASETS_OFFLINE"] = "0"
# Create the cache directory up front; no error if it already exists.
os.makedirs(CACHE_DIR, exist_ok=True)

In [None]:
# Resources requested for every simulated client; a GPU is requested only
# when the training device is CUDA.
backend_config = {
    "client_resources": (
        {"num_cpus": 1, "num_gpus": 1}
        if DEVICE.type == "cuda"
        else {"num_cpus": 1}
    )
}

# Federation size and local training schedule.
NUM_PARTITIONS = 20
NUM_ROUND = 1
EPOCH = 2
BATCH_SIZE = 20

# Client sampling settings forwarded to the strategy.
FRACTION_FIT = 0.85
FRACTION_EVALUATE = 0.3
MIN_FIT_CLIENTS = 7
MIN_EVALUATE_CLIENTS = 10
MIN_AVAILABLE_CLIENTS = 10

# Dirichlet concentration used by the partitioner.
ALPHA = 0.3

CLIENT_LR = 10 ** -2.5  # ≈ 0.00316, following the paper
ETA_VALUES = [0.01, 0.03, 0.001, 0.0001]  # candidate server learning rates

# FedAdam hyper-parameters.
BETA_1 = 0.9
BETA_2 = 0.99
TAU = 1e-3  # adaptivity term

## Load Dataset

In [None]:
def set_seed(seed: int = 42):
    """Seed the Python, NumPy and PyTorch (CPU and CUDA) random generators.

    Also sets ``cudnn.deterministic`` to True and ``cudnn.benchmark`` to False.

    Args:
        seed: Value passed to every seeding function.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


set_seed(42)  # one call at the start of the program is enough

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, models
from torch.utils.data import DataLoader
from flwr_datasets import FederatedDataset
from flwr_datasets.partitioner import DirichletPartitioner

# CIFAR-10 loaders (Dirichlet non-IID split) for the ResNet18 model
def load_datasets(partition_id, num_partitions: int):
    """Build the train, validation and test loaders for one federated client.

    The CIFAR-10 train split is partitioned with a Dirichlet partitioner
    (concentration ``ALPHA``); the requested partition is split 80/20 into
    train/validation with a fixed seed. The test loader wraps the dataset's
    "test" split.

    Args:
        partition_id: Index of the partition to load.
        num_partitions: Total number of partitions to create. This value is
            now actually used for the partitioner (the previous code ignored
            it and read the global ``NUM_PARTITIONS``).

    Returns:
        Tuple ``(trainloader, valloader, testloader)``.
    """
    partitioner = DirichletPartitioner(
        num_partitions=num_partitions,
        partition_by="label",
        alpha=ALPHA,
        # NOTE(review): the minimum partition size is tied to the partition
        # count, exactly as before — confirm this coupling is intended.
        min_partition_size=num_partitions,
        self_balancing=True,
    )

    fds = FederatedDataset(
        dataset="cifar10",
        partitioners={"train": partitioner},
        cache_dir=CACHE_DIR,
    )

    partition = fds.load_partition(partition_id)
    partition_train_test = partition.train_test_split(test_size=0.2, seed=42)

    # Both pipelines resize to 112x112 and apply the same normalisation
    # constants; only the training pipeline adds light augmentation.
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    )
    train_transforms = transforms.Compose([
        transforms.Resize((112, 112)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(degrees=5),
        transforms.ColorJitter(brightness=0.1, contrast=0.1),
        transforms.ToTensor(),
        normalize,
    ])
    val_transforms = transforms.Compose([
        transforms.Resize((112, 112)),
        transforms.ToTensor(),
        normalize,
    ])

    def make_batch_transform(pipeline, label):
        """Return a batch-level transform that applies `pipeline` to each image."""
        def apply(batch):
            try:
                batch["img"] = [pipeline(img.convert("RGB")) for img in batch["img"]]
            except Exception as e:
                # Log which pipeline failed, then let the error propagate.
                print(f"Error in {label} transforms: {e}")
                raise
            return batch
        return apply

    apply_train_transforms = make_batch_transform(train_transforms, "train")
    apply_val_transforms = make_batch_transform(val_transforms, "val")

    pin_memory = torch.cuda.is_available()

    def make_loader(dataset, shuffle, drop_last):
        """Create a DataLoader with the settings shared by all three splits."""
        return DataLoader(
            dataset,
            batch_size=BATCH_SIZE,
            shuffle=shuffle,
            pin_memory=pin_memory,
            num_workers=0,  # kept at 0 to avoid multiprocessing issues
            persistent_workers=False,  # must be False when num_workers is 0
            drop_last=drop_last,
        )

    trainloader = make_loader(
        partition_train_test["train"].with_transform(apply_train_transforms),
        shuffle=True,
        drop_last=True,
    )
    valloader = make_loader(
        partition_train_test["test"].with_transform(apply_val_transforms),
        shuffle=False,
        drop_last=False,
    )
    testloader = make_loader(
        fds.load_split("test").with_transform(apply_val_transforms),
        shuffle=False,
        drop_last=False,
    )

    return trainloader, valloader, testloader


In [None]:
import os

from huggingface_hub import login

# Never hard-code a token in the notebook: read it from the environment and
# skip authentication when none is provided (an empty string is not a usable
# credential).
# NOTE(review): assumes the dataset can be downloaded anonymously — confirm.
hf_token = os.environ.get("HF_TOKEN", "")
if hf_token:
    login(token=hf_token)

# Pre-build the (train, val, test) loaders for every client partition.
all_data = [load_datasets(i, NUM_PARTITIONS) for i in range(NUM_PARTITIONS)]

In [None]:
from flwr_datasets import FederatedDataset
from flwr_datasets.visualization import plot_label_distributions
import matplotlib.pyplot as plt

# Build the partitioner
partitioner = DirichletPartitioner(
    num_partitions=NUM_PARTITIONS,
    partition_by="label",
    alpha=ALPHA,
    min_partition_size=NUM_PARTITIONS,
    self_balancing=True
)

# Build the FederatedDataset
fds = FederatedDataset(
    dataset="cifar10",
    partitioners={"train": partitioner},
    cache_dir=CACHE_DIR
)

# Visualize the label distribution
print(f"Visualizing Non-IID distribution with alpha={ALPHA}")
# NOTE(review): re-reading the partitioner through `fds.partitioners` presumably
# lets FederatedDataset attach the dataset before plotting — confirm before removing.
partitioner = fds.partitioners["train"]
figure, axis, dataframe = plot_label_distributions(
    partitioner=partitioner,
    label_name="label",
    legend=True,
    verbose_labels=True,
    plot_type="bar",
    size_unit="absolute"  # or "percentage"
)

# Customize plot
plt.title(f'CIFAR-10 Non-IID Distribution (α={ALPHA})')
plt.xlabel('Clients')
plt.ylabel('Number of Samples')
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()

# Print summary statistics
print("\n=== Distribution Statistics ===")
print(f"Total clients: {NUM_PARTITIONS}")
print(f"Alpha (non-IID level): {ALPHA}")
print("Dataset: CIFAR-10")
print(f"Classes: {dataframe.columns.tolist()}")

# Detailed statistics
total_samples = dataframe.sum().sum()
print(f"Total samples: {total_samples}")
# Label fixed: this line reports the mean, it previously read "Adam samples per client".
print(f"Average samples per client: {total_samples/NUM_PARTITIONS:.1f}")
print(f"Min samples: {dataframe.sum(axis=1).min()}")
print(f"Max samples: {dataframe.sum(axis=1).max()}")

In [None]:
import numpy as np
from scipy.stats import entropy

# Shannon entropy (in bits) of a list of label counts
def calculate_entropy(counts):
    """Return the base-2 entropy of the distribution given by `counts`.

    Args:
        counts: Iterable of per-label counts.

    Returns:
        0 when the counts sum to zero, otherwise the entropy in bits.
    """
    count_sum = sum(counts)
    if count_sum != 0:
        shares = [value / count_sum for value in counts]
        return entropy(shares, base=2)  # base 2 -> bits
    return 0

# One entropy value per client, computed from that client's row of label counts.
entropies = [
    calculate_entropy(dataframe.iloc[client_id].tolist())
    for client_id in range(NUM_PARTITIONS)
]

# Mean and spread of the per-client entropies.
average_entropy = np.mean(entropies)
std_entropy = np.std(entropies)

print(f"Average Entropy: {average_entropy:.4f}")
print(f"Entropy Std: {std_entropy:.4f}")


## Model ResNet18

In [None]:
import torch.nn as nn
from torchvision import models

class ResNet18_GroupNorm(nn.Module):
    """ResNet-18 backbone whose BatchNorm layers are swapped for GroupNorm.

    The backbone's original ``fc`` layer is replaced by ``nn.Identity`` and a
    three-layer classifier head (Linear/GroupNorm/ReLU/Dropout) is attached.

    Args:
        num_classes: Size of the final linear layer's output.
        pretrained: When True the backbone is created with the
            ``IMAGENET1K_V1`` weights, otherwise with no weights.
    """

    def __init__(self, num_classes: int = 10, pretrained: bool = True):
        super().__init__()

        # Backbone, optionally initialised from pretrained weights.
        if pretrained:
            weights = models.ResNet18_Weights.IMAGENET1K_V1
        else:
            weights = None
        self.backbone = models.resnet18(weights=weights)

        # Swap every BatchNorm layer for a 2-group GroupNorm.
        self._replace_batchnorm_with_groupnorm(self.backbone)

        # Remember the feature width of the original fc, then drop that layer.
        feat_dim = self.backbone.fc.in_features
        self.backbone.fc = nn.Identity()

        # Classifier head; dropout rate decreases towards the output.
        head_layers = [
            nn.Dropout(0.3),
            nn.Linear(feat_dim, 128),
            nn.GroupNorm(2, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
            nn.Linear(128, 64),
            nn.GroupNorm(2, 64),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1),
            nn.Linear(64, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

        self._init_weights()

    def _replace_batchnorm_with_groupnorm(self, module):
        """Recursively replace BatchNorm1d/2d children of `module` with GroupNorm(2, C)."""
        batchnorm_types = (nn.BatchNorm2d, nn.BatchNorm1d)
        for child_name, child in list(module.named_children()):
            if isinstance(child, batchnorm_types):
                setattr(module, child_name, nn.GroupNorm(2, child.num_features))
                continue
            # Not a BatchNorm layer: descend into its own children.
            self._replace_batchnorm_with_groupnorm(child)

    def _init_weights(self):
        """Kaiming-normal init for the head's linear weights; zero their biases."""
        linear_layers = (
            layer for layer in self.classifier if isinstance(layer, nn.Linear)
        )
        for layer in linear_layers:
            nn.init.kaiming_normal_(layer.weight, nonlinearity='relu')
            if layer.bias is not None:
                nn.init.constant_(layer.bias, 0)

    def forward(self, x):
        """Run the backbone, flatten from dim 1 onward, then apply the classifier."""
        features = torch.flatten(self.backbone(x), 1)
        return self.classifier(features)

# =====================================
# 3. CẬP NHẬT HÀM TRAIN VỚI CLIENT_LR
# =====================================

def train(net, trainloader, epochs: int):
    """Train `net` locally with SGD at the client learning rate.

    Args:
        net: Model to train; it is switched to training mode.
        trainloader: Iterable of batches exposing "img" and "label" entries.
        epochs: Number of passes over `trainloader`.

    Returns:
        Tuple ``(loss, accuracy)`` of the last epoch. Both are 0.0 when
        `epochs` is 0 or the loader yields no batches (for example when
        ``drop_last=True`` discards the only, incomplete batch); previously
        those cases raised UnboundLocalError / ZeroDivisionError.
    """
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(net.parameters(), lr=CLIENT_LR, weight_decay=1e-3)
    net.train()

    # Defaults returned when no epoch produces any sample.
    epoch_loss, epoch_acc = 0.0, 0.0

    for epoch in range(epochs):
        correct, total, loss_sum = 0, 0, 0.0
        for batch in trainloader:
            images, labels = batch["img"], batch["label"]
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            outputs = net(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Weight the batch loss by batch size so the epoch figure is a
            # per-sample mean.
            loss_sum += loss.item() * labels.size(0)
            total += labels.size(0)
            correct += (outputs.argmax(dim=1) == labels).sum().item()

        if total == 0:
            epoch_loss, epoch_acc = 0.0, 0.0
            print(f"Epoch {epoch+1}: no training samples, skipping metrics")
            continue

        epoch_loss = loss_sum / total
        epoch_acc = correct / total
        print(f"Epoch {epoch+1}: train loss {epoch_loss:.4f}, accuracy {epoch_acc:.4f}")

    return epoch_loss, epoch_acc

def test(net, testloader):
    """Evaluate `net` on `testloader` and return the mean loss and accuracy.

    Args:
        net: Model to evaluate; it is switched to eval mode.
        testloader: Iterable of batches exposing "img" and "label" entries.

    Returns:
        Tuple ``(avg_loss, accuracy)``, both per-sample means. Returns
        ``(0.0, 0.0)`` when the loader yields no samples instead of raising
        ZeroDivisionError.
    """
    criterion = torch.nn.CrossEntropyLoss()
    correct, total, total_loss = 0, 0, 0.0
    net.eval()
    with torch.no_grad():
        for batch in testloader:
            images, labels = batch["img"], batch["label"]
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = net(images)
            loss = criterion(outputs, labels)
            # Weight by batch size so the final value is a per-sample mean.
            total_loss += loss.item() * labels.size(0)
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += labels.size(0)

    if total == 0:
        return 0.0, 0.0

    avg_loss = total_loss / total
    accuracy = correct / total
    return avg_loss, accuracy


In [None]:
def get_parameters(net) -> List[np.ndarray]:
    """Return every ``state_dict`` tensor of `net` as a NumPy array.

    The list keeps one entry per ``state_dict`` key, in key order, so that
    it can be zipped back onto the keys when the weights are restored.
    Empty tensors are still reported with a warning but are no longer
    dropped: dropping them shifted the remaining arrays relative to the
    keys and made the receiver's length check fail.

    Args:
        net: Module whose ``state_dict`` is exported.

    Returns:
        List of arrays, one per ``state_dict`` entry.
    """
    params = []
    for name, val in net.state_dict().items():
        if val.numel() == 0:
            print(f"Warning: Empty tensor found: {name}")
        params.append(val.cpu().numpy())
    return params

def set_parameters(net, parameters: List[np.ndarray]):
    """Load a list of arrays into `net`, matching them to its state_dict keys by position.

    Nothing is loaded (only a message is printed) when `parameters` is empty
    or when its length differs from the number of state_dict entries.

    Args:
        net: Module to update in place.
        parameters: Arrays in the same order as ``net.state_dict()`` keys.
    """
    if not parameters:
        print("Warning: No parameters received")
        return

    keys = list(net.state_dict().keys())
    expected, received = len(keys), len(parameters)
    print(f"Model has {expected} parameters")
    print(f"Received {received} parameters")

    if received != expected:
        print("Parameter count mismatch!")
        return

    new_state = OrderedDict()
    for key, array in zip(keys, parameters):
        new_state[key] = torch.tensor(array)
    net.load_state_dict(new_state, strict=False)

## Custom FedAdam

In [None]:
class FlowerClientFedAdam(NumPyClient):
    """Flower NumPy client that trains and evaluates one partition's model."""

    def __init__(self, partition_id, net, trainloader, valloader):
        # Keep the partition id (used in log lines), the model and both loaders.
        self.partition_id = partition_id
        self.net = net
        self.trainloader = trainloader
        self.valloader = valloader

    def get_parameters(self, config):
        """Return the current model weights as a list of NumPy arrays."""
        return get_parameters(self.net)

    def fit(self, parameters, config):
        """Load `parameters`, train for EPOCH epochs and report the result.

        Returns:
            Updated weights, the size of the training dataset, and a metrics
            dict with the last epoch's "loss" and "accuracy".
        """
        print(f"[Client {self.partition_id}] FedAdam fit")
        set_parameters(self.net, parameters)
        train_loss, train_acc = train(self.net, self.trainloader, epochs=EPOCH)

        updated_weights = get_parameters(self.net)
        num_examples = len(self.trainloader.dataset)
        metrics = {"loss": float(train_loss), "accuracy": float(train_acc)}
        return updated_weights, num_examples, metrics

    def evaluate(self, parameters, config):
        """Load `parameters` and evaluate on the validation loader.

        Returns:
            The loss, the size of the validation dataset, and a metrics dict
            with "loss" and "accuracy".
        """
        set_parameters(self.net, parameters)
        val_loss, val_acc = test(self.net, self.valloader)

        num_examples = len(self.valloader.dataset)
        metrics = {"loss": float(val_loss), "accuracy": float(val_acc)}
        return float(val_loss), num_examples, metrics

In [None]:
def client_fn_fedadam(context):
    """Create the Flower client for the partition named in `context`.

    Reads the "partition-id" entry of ``context.node_config``, picks the
    matching pre-built loaders from ``all_data`` and wraps a new model in a
    ``FlowerClientFedAdam``.
    """
    partition_id = context.node_config["partition-id"]
    train_loader, val_loader, _unused_test_loader = all_data[partition_id]

    model = ResNet18_GroupNorm(num_classes=10).to(DEVICE)
    numpy_client = FlowerClientFedAdam(partition_id, model, train_loader, val_loader)
    return numpy_client.to_client()


client_app_fedadam = ClientApp(client_fn=client_fn_fedadam)

In [None]:
from flwr.common import Metrics, Context
from flwr.common import ndarrays_to_parameters

def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
    """Aggregate client accuracies into an example-weighted average.

    Args:
        metrics: Pairs of ``(num_examples, client_metrics)``; every
            ``client_metrics`` must contain an "accuracy" entry.

    Returns:
        ``{"accuracy": weighted_mean}``. When the total number of examples is
        zero (empty input or only zero-sized clients) the accuracy is
        reported as 0.0 instead of raising ZeroDivisionError.
    """
    total_examples = sum(num_examples for num_examples, _ in metrics)
    if total_examples == 0:
        return {"accuracy": 0.0}

    # Each client's accuracy counts in proportion to its number of examples.
    weighted_sum = sum(num_examples * m["accuracy"] for num_examples, m in metrics)
    return {"accuracy": weighted_sum / total_examples}

# NOTE(review): `device` duplicates DEVICE and is not read anywhere in the visible code.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Build a fresh model and convert its weights into Flower parameters.
net = ResNet18_GroupNorm().to(DEVICE)
initial_parameters = flwr.common.ndarrays_to_parameters(get_parameters(net))

In [None]:
from flwr.server.strategy import FedAdam
import time

class CustomFedAdam(FedAdam):
    """FedAdam strategy that also records per-round metrics and timings.

    ``metrics_centralized`` holds one dict per round (index ``round - 1``)
    with the unweighted client means of train/validation loss and accuracy
    plus timing information. A mean is ``None`` when no client reported the
    corresponding metric; such values are printed as "N/A" (formatting them
    with ``:.4f`` used to raise TypeError).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.metrics_centralized = []
        # Wall-clock start time of each round, keyed by round number.
        self.round_start_times = {}

    @staticmethod
    def _mean_metric(results, key):
        """Return the unweighted mean of `key` over client results, or None."""
        values = [r.metrics[key] for _, r in results if key in r.metrics]
        return sum(values) / len(values) if values else None

    @staticmethod
    def _fmt_metric(value):
        """Format a metric with four decimals, or "N/A" when it is missing."""
        return "N/A" if value is None else f"{value:.4f}"

    def _store_round_metrics(self, rnd, values):
        """Merge `values` into the record of round `rnd`, creating it if needed."""
        while len(self.metrics_centralized) < rnd:
            self.metrics_centralized.append({})
        self.metrics_centralized[rnd - 1].update(values)

    def configure_fit(self, server_round, parameters, client_manager):
        """Record the round's start time, then defer to FedAdam."""
        self.round_start_times[server_round] = time.time()
        return super().configure_fit(server_round, parameters, client_manager)

    def aggregate_fit(self, rnd, results, failures):
        """Aggregate weights via FedAdam and log the mean training metrics.

        Returns:
            The aggregated parameters and an empty metrics dict.
        """
        aggregated_parameters, _ = super().aggregate_fit(rnd, results, failures)

        avg_loss = self._mean_metric(results, "loss")
        avg_acc = self._mean_metric(results, "accuracy")

        print(f"[Round {rnd}] Train — Loss: {self._fmt_metric(avg_loss)}, "
              f"Acc: {self._fmt_metric(avg_acc)}")

        self._store_round_metrics(rnd, {
            "train_loss": avg_loss,
            "train_acc": avg_acc
        })

        return aggregated_parameters, {}

    def aggregate_evaluate(self, rnd, results, failures):
        """Aggregate evaluation via FedAdam and log validation metrics and timings.

        Returns:
            The aggregated loss and an empty metrics dict.
        """
        # Times only the server-side aggregation call below.
        eval_start = time.time()
        aggregated_loss, _ = super().aggregate_evaluate(rnd, results, failures)
        eval_duration = time.time() - eval_start

        # Elapsed time since configure_fit for this round (0 if it was never recorded).
        round_total_time = time.time() - self.round_start_times.get(rnd, time.time())

        avg_loss = self._mean_metric(results, "loss")
        avg_acc = self._mean_metric(results, "accuracy")

        print(f"[Round {rnd}] Val   — Loss: {self._fmt_metric(avg_loss)}, "
              f"Acc: {self._fmt_metric(avg_acc)}, Eval Time: {eval_duration:.2f}s")
        print(f"[Round {rnd}] Total Round Time: {round_total_time:.2f}s")

        self._store_round_metrics(rnd, {
            "val_loss": avg_loss,
            "val_acc": avg_acc,
            "eval_duration": eval_duration,
            "round_total_time": round_total_time
        })

        return aggregated_loss, {}


In [None]:
class GridSearchFedAdam(FedAdam):
    """FedAdam strategy used for one grid-search run over the server learning rate.

    ``metrics_centralized`` holds one dict per round (index ``round - 1``)
    with the unweighted client means of train/validation loss and accuracy.
    A mean is ``None`` when no client reported the metric; such values are
    printed as "N/A" (formatting them with ``:.4f`` used to raise TypeError).

    Args:
        eta_value: Server learning rate of this run; stored for logging.
        *args, **kwargs: Forwarded unchanged to ``FedAdam``.
    """

    def __init__(self, eta_value, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.eta_value = eta_value
        self.metrics_centralized = []
        self.experiment_start_time = None
        self.experiment_end_time = None

    @staticmethod
    def _mean_metric(results, key):
        """Return the unweighted mean of `key` over client results, or None."""
        values = [r.metrics[key] for _, r in results if key in r.metrics]
        return sum(values) / len(values) if values else None

    @staticmethod
    def _fmt_metric(value):
        """Format a metric with four decimals, or "N/A" when it is missing."""
        return "N/A" if value is None else f"{value:.4f}"

    def _store_round_metrics(self, rnd, values):
        """Merge `values` into the record of round `rnd`, creating it if needed."""
        while len(self.metrics_centralized) < rnd:
            self.metrics_centralized.append({})
        self.metrics_centralized[rnd - 1].update(values)

    def configure_fit(self, server_round, parameters, client_manager):
        """Stamp the experiment start on round 1, then defer to FedAdam."""
        if server_round == 1:
            self.experiment_start_time = time.time()
        return super().configure_fit(server_round, parameters, client_manager)

    def aggregate_fit(self, rnd, results, failures):
        """Aggregate weights via FedAdam and log the mean training metrics.

        Returns:
            The aggregated parameters and an empty metrics dict.
        """
        aggregated_parameters, _ = super().aggregate_fit(rnd, results, failures)

        avg_loss = self._mean_metric(results, "loss")
        avg_acc = self._mean_metric(results, "accuracy")

        print(f"[η={self.eta_value}][Round {rnd}] Train — Loss: {self._fmt_metric(avg_loss)}, "
              f"Acc: {self._fmt_metric(avg_acc)}")

        self._store_round_metrics(rnd, {
            "train_loss": avg_loss,
            "train_acc": avg_acc
        })

        return aggregated_parameters, {}

    def aggregate_evaluate(self, rnd, results, failures):
        """Aggregate evaluation via FedAdam and log the mean validation metrics.

        Returns:
            The aggregated loss and an empty metrics dict.
        """
        aggregated_loss, _ = super().aggregate_evaluate(rnd, results, failures)

        avg_loss = self._mean_metric(results, "loss")
        avg_acc = self._mean_metric(results, "accuracy")

        print(f"[η={self.eta_value}][Round {rnd}] Val   — Loss: {self._fmt_metric(avg_loss)}, "
              f"Acc: {self._fmt_metric(avg_acc)}")

        # The experiment ends with the evaluation of the last configured round.
        if rnd == NUM_ROUND:
            self.experiment_end_time = time.time()

        self._store_round_metrics(rnd, {
            "val_loss": avg_loss,
            "val_acc": avg_acc
        })

        return aggregated_loss, {}

    def get_total_experiment_time(self):
        """Return the experiment duration in seconds, or None if it is incomplete."""
        if self.experiment_start_time is None or self.experiment_end_time is None:
            return None
        return self.experiment_end_time - self.experiment_start_time


In [None]:
import gc
import psutil
def monitor_memory():
    """Print and return the current process's resident memory in MB."""
    rss_bytes = psutil.Process(os.getpid()).memory_info().rss
    memory_mb = rss_bytes / 1024 / 1024
    print(f"🧠 Current memory usage: {memory_mb:.1f} MB")
    return memory_mb

def cleanup_memory():
    """Run garbage collection, clear the CUDA cache if present, and report memory."""
    print("🧹 Cleaning up memory...")

    # Force a collection pass and report how many objects it found.
    num_collected = gc.collect()
    print(f"   Collected {num_collected} objects")

    # Release cached GPU memory when a GPU is in use.
    cuda_in_use = torch.cuda.is_available()
    if cuda_in_use:
        torch.cuda.empty_cache()
        print("   Cleared CUDA cache")

    # Show the memory footprint after cleanup.
    monitor_memory()

In [None]:
def run_single_eta_experiment_with_memory_management(eta_value):
    """Run one federated simulation with the given server learning rate.

    Builds a fresh model for the initial parameters, configures a
    ``GridSearchFedAdam`` strategy with ``eta=eta_value`` and runs the Flower
    simulation, reporting memory usage before and after.

    Args:
        eta_value: Server learning rate to evaluate.

    Returns:
        Tuple ``(strategy, total_time)``: the strategy (carrying its recorded
        metrics) and the simulation's wall-clock duration in seconds.

    Raises:
        Exception: Any error from the simulation is logged, memory is cleaned
            up, and the original exception is re-raised.
    """
    print(f"\n🚀 Starting experiment with η = {eta_value}")
    print("="*60)

    # Memory footprint before the run starts.
    print("Memory before experiment:")
    monitor_memory()

    try:
        # Initial parameters come from a freshly constructed model.
        net = ResNet18_GroupNorm().to(DEVICE)
        initial_parameters = flwr.common.ndarrays_to_parameters(get_parameters(net))

        # Strategy configured with the learning rate under test.
        strategy = GridSearchFedAdam(
            eta_value=eta_value,
            fraction_fit=FRACTION_FIT,
            fraction_evaluate=FRACTION_EVALUATE,
            min_fit_clients=MIN_FIT_CLIENTS,
            min_evaluate_clients=MIN_EVALUATE_CLIENTS,
            min_available_clients=MIN_AVAILABLE_CLIENTS,
            evaluate_metrics_aggregation_fn=weighted_average,
            initial_parameters=initial_parameters,

            eta=eta_value,
            beta_1=BETA_1,
            beta_2=BETA_2,
            tau=TAU,
        )

        def server_fn(context: Context) -> ServerAppComponents:
            config = ServerConfig(num_rounds=NUM_ROUND)
            return ServerAppComponents(strategy=strategy, config=config)

        server = ServerApp(server_fn=server_fn)

        # Run the simulation and time it.
        start_time = time.time()
        run_simulation(
            server_app=server,
            client_app=client_app_fedadam,
            num_supernodes=NUM_PARTITIONS,
            backend_config=backend_config,
        )
        end_time = time.time()

        total_time = end_time - start_time
        print(f"✅ Experiment η = {eta_value} completed in {total_time:.2f} seconds")

        # Memory footprint after the run.
        print("Memory after experiment:")
        monitor_memory()

        # Drop local references before cleaning up.
        del net, initial_parameters, server
        cleanup_memory()

        return strategy, total_time

    except Exception as e:
        print(f"❌ Error in experiment η = {eta_value}: {e}")
        # Clean up even when the run failed.
        cleanup_memory()
        # Bare raise keeps the original traceback intact.
        raise

def run_grid_search_with_memory_management():
    """Run one experiment per value in ``ETA_VALUES`` and collect the results.

    For every eta the function runs the simulation, summarises the recorded
    validation metrics and stores the strategy. A failed experiment is
    recorded with zeroed metrics and an "error" entry. Interim results are
    saved after experiment 1 and after every even-numbered experiment.

    Interim saving is isolated from the experiment itself: a failure while
    saving (or the save helper not being defined yet) is reported but no
    longer overwrites the results of a successful experiment.

    Returns:
        Tuple ``(results, all_strategies, total_grid_search_time)`` where
        ``results`` maps eta to its summary dict, ``all_strategies`` maps
        ``"FedAdam_eta_<eta>"`` to the strategy, and the time is in seconds.
    """
    # Local import: the notebook only imports pandas in a later cell.
    import pandas as pd

    print("🔍 STARTING GRID SEARCH FOR FEDADAM ETA PARAMETER")
    print(f"📊 Testing eta values: {ETA_VALUES}")
    print(f"⚙️  Configuration: {NUM_ROUND} rounds, {NUM_PARTITIONS} clients, {EPOCH} epochs/round")
    print("="*80)

    results = {}
    all_strategies = {}
    grid_search_start = time.time()

    # Initial memory check
    print("Initial memory state:")
    monitor_memory()

    for i, eta in enumerate(ETA_VALUES):
        print(f"\n[{i+1}/{len(ETA_VALUES)}] Testing η = {eta}")

        # Memory check before each experiment
        memory_before = monitor_memory()

        # Clean up first when memory usage is already high
        if memory_before > 8000:  # 8GB
            print("⚠️  WARNING: High memory usage detected. Running cleanup...")
            cleanup_memory()

        try:
            strategy, experiment_time = run_single_eta_experiment_with_memory_management(eta)
        except Exception as e:
            print(f"❌ Error with η = {eta}: {str(e)}")
            results[eta] = {
                "final_val_acc": 0,
                "final_val_loss": float('inf'),
                "best_val_acc": 0,
                "best_round": 0,
                "experiment_time": 0,
                "converged": False,
                "error": str(e)
            }

            # Cleanup after a failed experiment
            cleanup_memory()
            continue

        # Metrics of the last recorded round; missing or None values fall
        # back to neutral defaults so formatting below cannot fail.
        rounds = strategy.metrics_centralized
        final_metrics = rounds[-1] if rounds else {}
        final_val_acc = final_metrics.get("val_acc")
        if final_val_acc is None:
            final_val_acc = 0
        final_val_loss = final_metrics.get("val_loss")
        if final_val_loss is None:
            final_val_loss = float('inf')

        # Best validation accuracy together with the real round number it
        # was reached in (first round wins on ties).
        scored_rounds = [
            (m["val_acc"], round_no)
            for round_no, m in enumerate(rounds, start=1)
            if m.get("val_acc") is not None
        ]
        if scored_rounds:
            best_val_acc, best_round = max(scored_rounds, key=lambda pair: pair[0])
        else:
            best_val_acc, best_round = 0, 0

        results[eta] = {
            "final_val_acc": final_val_acc,
            "final_val_loss": final_val_loss,
            "best_val_acc": best_val_acc,
            "best_round": best_round,
            "experiment_time": experiment_time,
            "converged": final_val_acc > 0.1
        }

        all_strategies[f"FedAdam_eta_{eta}"] = strategy

        print(f"📈 Results for η = {eta}:")
        print(f"   Final Val Acc: {final_val_acc:.4f}")
        print(f"   Best Val Acc: {best_val_acc:.4f} (Round {best_round})")
        print(f"   Experiment Time: {experiment_time:.2f}s")

        # Save interim results (after the first and every even-numbered
        # experiment) so a later crash does not lose everything.
        if i == 0 or (i + 1) % 2 == 0:
            # The save helper lives in a later notebook cell and may not
            # exist yet when this function runs.
            save_fn = globals().get("save_grid_search_results")
            if save_fn is None:
                print("⚠️  save_grid_search_results is not defined yet; skipping interim save")
            else:
                print(f"💾 Saving interim results after experiment {i+1}...")
                try:
                    interim_results_df = pd.DataFrame.from_dict(results, orient='index')
                    interim_results_df.index.name = 'eta'
                    best_eta_so_far = interim_results_df['best_val_acc'].idxmax()

                    save_fn(
                        results,
                        all_strategies,
                        time.time() - grid_search_start,
                        best_eta_so_far,
                        interim_results_df,
                        save_dir=f"fl_gridsearch_interim_{i+1}"
                    )
                except Exception as save_error:
                    # A failed save must not invalidate the experiment result.
                    print(f"⚠️  Could not save interim results: {save_error}")

    grid_search_end = time.time()
    total_grid_search_time = grid_search_end - grid_search_start

    print("\n" + "="*80)
    print("🏆 GRID SEARCH COMPLETED!")
    print(f"⏱️  Total grid search time: {total_grid_search_time/60:.2f} minutes")
    print("="*80)

    # Final cleanup
    cleanup_memory()

    return results, all_strategies, total_grid_search_time

In [None]:
# Execute the grid search with memory management; keeps the per-eta results,
# the strategies and the total duration for the analysis cells below.
print("🚀 Starting Grid Search with Memory Management...")
grid_results, all_strategies_dict, total_time = run_grid_search_with_memory_management()


In [None]:
import pandas as pd
import matplotlib.pyplot as plt

def analyze_grid_search_results(results, total_time):
    """Print a ranking of the grid-search results and plot four summary charts.

    Args:
        results: Mapping of eta to its summary dict (keys used here:
            "best_val_acc", "final_val_acc", "best_round", "experiment_time").
        total_time: Total grid-search duration in seconds.

    Returns:
        Tuple ``(df_results, best_eta)``: the results as a DataFrame sorted by
        best validation accuracy (descending) and the eta of its first row.
    """
    print("\n📊 GRID SEARCH RESULTS ANALYSIS")
    print("="*80)

    # DataFrame indexed by eta, best accuracy first.
    df_results = pd.DataFrame.from_dict(results, orient='index')
    df_results.index.name = 'eta'
    df_results = df_results.sort_values('best_val_acc', ascending=False)

    print("\n🏆 RANKING BY BEST VALIDATION ACCURACY:")
    print("-"*60)
    for rank, (eta, row) in enumerate(df_results.iterrows(), start=1):
        # General format keeps small values such as 0.0001 readable
        # (a fixed three-decimal format printed them as 0.000).
        print(f"{rank:2d}. η = {eta:>7g} | Best Val Acc: {row['best_val_acc']:.4f} | "
              f"Final Val Acc: {row['final_val_acc']:.4f} | Time: {row['experiment_time']:6.1f}s")

    # Best-performing eta
    best_eta = df_results.index[0]
    best_result = df_results.iloc[0]

    print(f"\n🎯 OPTIMAL ETA FOUND:")
    print(f"   η* = {best_eta}")
    print(f"   Best Validation Accuracy: {best_result['best_val_acc']:.4f}")
    print(f"   Achieved at Round: {best_result['best_round']}")
    print(f"   Final Validation Accuracy: {best_result['final_val_acc']:.4f}")
    print(f"   Experiment Time: {best_result['experiment_time']:.2f}s")

    print(f"\n⏱️  TIMING SUMMARY:")
    print(f"   Total Grid Search Time: {total_time/60:.2f} minutes")
    print(f"   Average Time per η: {df_results['experiment_time'].mean():.2f}s")
    print(f"   Fastest Experiment: {df_results['experiment_time'].min():.2f}s")
    print(f"   Slowest Experiment: {df_results['experiment_time'].max():.2f}s")

    # Visualization
    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))

    # Plot in increasing eta order so the line charts run left to right
    # instead of zigzagging through the accuracy-sorted rows.
    df_by_eta = df_results.sort_index()
    eta_values = df_by_eta.index.tolist()
    best_accs = df_by_eta['best_val_acc'].tolist()

    # 1. Best Validation Accuracy vs Eta
    ax1.plot(eta_values, best_accs, 'bo-', linewidth=2, markersize=8)
    ax1.axhline(y=max(best_accs), color='r', linestyle='--', alpha=0.7)
    ax1.axvline(x=best_eta, color='r', linestyle='--', alpha=0.7)
    ax1.set_xlabel('η (Server Learning Rate)')
    ax1.set_ylabel('Best Validation Accuracy')
    ax1.set_title('Grid Search Results: Best Val Accuracy vs η')
    ax1.set_xscale('log')
    ax1.grid(True, alpha=0.3)
    ax1.text(best_eta, max(best_accs), f'  η*={best_eta}\n  Acc={max(best_accs):.4f}',
             verticalalignment='bottom')

    # 2. Final vs Best Accuracy
    final_accs = df_by_eta['final_val_acc'].tolist()
    ax2.scatter(best_accs, final_accs, s=100, alpha=0.7)
    ax2.plot([0, 1], [0, 1], 'k--', alpha=0.5)
    ax2.set_xlabel('Best Validation Accuracy')
    ax2.set_ylabel('Final Validation Accuracy')
    ax2.set_title('Final vs Best Validation Accuracy')
    ax2.grid(True, alpha=0.3)

    # Label every point with its eta
    for eta, best_acc, final_acc in zip(eta_values, best_accs, final_accs):
        ax2.annotate(f'{eta}', (best_acc, final_acc), xytext=(5, 5),
                    textcoords='offset points', fontsize=8)

    # 3. Experiment Time vs Eta
    times = df_by_eta['experiment_time'].tolist()
    ax3.bar(range(len(eta_values)), times, alpha=0.7)
    ax3.set_xticks(range(len(eta_values)))
    ax3.set_xticklabels([f'{eta:g}' for eta in eta_values], rotation=45)
    ax3.set_xlabel('η (Server Learning Rate)')
    ax3.set_ylabel('Experiment Time (seconds)')
    ax3.set_title('Experiment Duration vs η')
    ax3.grid(True, alpha=0.3)

    # 4. Convergence Round vs Eta
    best_rounds = df_by_eta['best_round'].tolist()
    ax4.plot(eta_values, best_rounds, 'go-', linewidth=2, markersize=8)
    ax4.set_xlabel('η (Server Learning Rate)')
    ax4.set_ylabel('Round of Best Performance')
    ax4.set_title('Convergence Speed vs η')
    ax4.set_xscale('log')
    ax4.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    return df_results, best_eta

# Analyze the grid-search results; keeps the ranked DataFrame and the best eta
# for the final save cell.
results_df, optimal_eta = analyze_grid_search_results(grid_results, total_time)


In [None]:
from datetime import datetime
import os
import json
def save_grid_search_results(results, strategies_dict, total_time, optimal_eta, results_df, save_dir="fl_gridsearch_results"):
    """Write the grid-search summary, the results table and per-strategy metrics to disk.

    Three kinds of files are written into `save_dir`, all sharing one
    timestamp: a JSON summary, a CSV dump of `results_df`, and one JSON file
    per strategy that has recorded metrics.

    Non-finite floats (for example the ``inf`` loss stored for a failed
    experiment) are written as ``null`` so the output is valid strict JSON.

    Args:
        results: Mapping of eta to its summary dict.
        strategies_dict: Mapping of strategy name to strategy object.
        total_time: Grid-search duration in seconds.
        optimal_eta: Eta considered best so far.
        results_df: DataFrame written to the CSV file.
        save_dir: Output directory; created if missing.

    Returns:
        True when everything was written, False if any error occurred (the
        error is printed, not raised).
    """
    import math

    def _json_safe(value):
        """Recursively replace non-finite floats with None."""
        if isinstance(value, dict):
            return {key: _json_safe(item) for key, item in value.items()}
        if isinstance(value, (list, tuple)):
            return [_json_safe(item) for item in value]
        if isinstance(value, float) and not math.isfinite(value):
            return None
        return value

    try:
        os.makedirs(save_dir, exist_ok=True)
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

        # 1. Summary of the whole experiment
        summary_results = {
            "experiment_info": {
                "timestamp": timestamp,
                "algorithm": "FedAdam",
                "parameter_optimized": "eta (server_learning_rate)",
                "eta_values_tested": ETA_VALUES,
                "optimal_eta": float(optimal_eta),
                "total_grid_search_time_minutes": total_time / 60,
                "config": {
                    "num_rounds": NUM_ROUND,
                    "num_partitions": NUM_PARTITIONS,
                    "batch_size": BATCH_SIZE,
                    "client_lr": CLIENT_LR,
                    "epochs_per_round": EPOCH,
                    "alpha": ALPHA,
                    "beta_1": BETA_1,
                    "beta_2": BETA_2,
                    "tau": TAU
                }
            },
            "results_by_eta": results,
            "optimal_result": results.get(optimal_eta)
        }

        # Save JSON summary
        json_filename = f"{save_dir}/fedadam_gridsearch_summary_{timestamp}.json"
        with open(json_filename, "w", encoding="utf-8") as f:
            json.dump(_json_safe(summary_results), f, indent=2, default=str)
        print(f"✅ Saved grid search summary to {json_filename}")

        # 2. Detailed results table as CSV
        csv_filename = f"{save_dir}/fedadam_gridsearch_detailed_{timestamp}.csv"
        results_df.to_csv(csv_filename)
        print(f"✅ Saved detailed results to {csv_filename}")

        # 3. Per-strategy metrics (only the recorded metrics are stored, not
        #    the strategy object itself)
        for strategy_name, strategy in strategies_dict.items():
            if hasattr(strategy, 'metrics_centralized') and strategy.metrics_centralized:
                strategy_filename = f"{save_dir}/{strategy_name}_metrics_{timestamp}.json"
                strategy_data = {
                    "eta_value": getattr(strategy, 'eta_value', 'unknown'),
                    "metrics_by_round": strategy.metrics_centralized,
                    "total_rounds": len(strategy.metrics_centralized)
                }
                with open(strategy_filename, "w", encoding="utf-8") as f:
                    json.dump(_json_safe(strategy_data), f, indent=2, default=str)
                print(f"✅ Saved {strategy_name} metrics to {strategy_filename}")

        print(f"\n📁 All results saved to directory: {save_dir}/")
        return True

    except Exception as e:
        # Saving is best effort: report the problem and signal failure.
        print(f"❌ Error saving results: {str(e)}")
        return False



In [None]:
# Save the final results with the full argument list
print("\n💾 Saving final results...")
save_success = save_grid_search_results(
    grid_results,
    all_strategies_dict,
    total_time,
    optimal_eta,
    results_df,  # results table written to the CSV file
    save_dir="fl_gridsearch_final_results"
)

# save_grid_search_results returns False when any write failed.
if save_success:
    print("✅ All results saved successfully!")
else:
    print("❌ Some issues occurred while saving results")

# Final summary and memory report
print(f"\n🎯 EXPERIMENT COMPLETED!")
print(f"📊 Optimal η: {optimal_eta}")
print(f"⏱️  Total time: {total_time/60:.2f} minutes")
print("Final memory state:")
monitor_memory()