## Cài đặt môi trường

In [None]:
!pip install -q flwr[simulation]
!pip install -q flwr-datasets[vision]

In [None]:
from collections import OrderedDict
from typing import Dict, List, Optional, Tuple
import random

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader


import flwr
from flwr.client import Client, ClientApp, NumPyClient
from flwr.common import Context
from flwr.server import ServerApp, ServerConfig, ServerAppComponents
from flwr.server.strategy import Strategy
from flwr.simulation import run_simulation
from flwr_datasets import FederatedDataset

from flwr_datasets import FederatedDataset
from flwr_datasets.visualization import plot_label_distributions

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, models
from torch.utils.data import DataLoader
from flwr_datasets import FederatedDataset
from flwr_datasets.partitioner import DirichletPartitioner


# Use the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Training on {DEVICE}")
# Log the library versions in use.
print(f"Flower {flwr.__version__} / PyTorch {torch.__version__}")

## Base config

In [None]:
import os

# Directory handed to FederatedDataset as `cache_dir`.
CACHE_DIR = "/kaggle/working/hf_cache"
# Presumably keeps Hugging Face `datasets` in online mode so the data can be downloaded — TODO confirm.
os.environ["HF_DATASETS_OFFLINE"] = "0"
os.makedirs(CACHE_DIR, exist_ok=True)

In [None]:
# Per-client resources; passed to run_simulation as `backend_config`.
backend_config = {"client_resources": {"num_cpus": 1}}
if DEVICE.type == "cuda":
    backend_config = {"client_resources": {"num_cpus": 1, "num_gpus": 1}}  # GPU optional

# Number of data partitions (one per simulated client) and FL rounds per experiment.
NUM_PARTITIONS = 20
NUM_ROUND = 30
# Local epochs each client trains per round, and the DataLoader batch size.
EPOCH = 2
BATCH_SIZE = 20


# Client sampling settings forwarded to the CustomFedProx strategy.
FRACTION_FIT=0.85
FRACTION_EVALUATE=0.3
MIN_FIT_CLIENTS=7
MIN_EVALUATE_CLIENTS=10
MIN_AVAILABLE_CLIENTS=10

# Default FedProx proximal coefficient (default of train_fedprox and CustomFedProx).
PROX_MU = 0.1

# Dirichlet concentration used by the partitioner.
ALPHA = 0.1

# NOTE(review): CLIENT_LR and SERVER_LR are not referenced anywhere else in the visible code;
# train_fedprox falls back to its own default lr=0.01 — confirm whether they should be wired in.
CLIENT_LR = 10**(-2.5)
SERVER_LR = 10**(-2)


## Load Dataset

In [None]:
def set_seed(seed: int = 42):
    """Seed every random number generator used here and configure cuDNN.

    Args:
        seed: Value handed to Python's `random`, NumPy, and PyTorch
            (CPU and all CUDA devices).
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    # Trade cuDNN autotuning for run-to-run reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


set_seed(42)  # Only needs to be called once, at the start of the program

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, models
from torch.utils.data import DataLoader
from flwr_datasets import FederatedDataset
from flwr_datasets.partitioner import DirichletPartitioner

# Data loading and transforms for CIFAR-10, used with the ResNet18 model
def load_datasets(partition_id, num_partitions: int):
    """Build the train/val/test DataLoaders for one non-IID CIFAR-10 partition.

    The "train" split of CIFAR-10 is divided with a Dirichlet partitioner
    (concentration ALPHA); the requested partition is then split 80/20 into a
    local training set and a local validation set. The dataset's "test" split
    is returned unpartitioned as the third loader.

    Args:
        partition_id: Index of the partition to load.
        num_partitions: Total number of partitions to create.

    Returns:
        (trainloader, valloader, testloader). Each batch is a dict whose "img"
        entry holds the transformed images and whose "label" entry holds the
        class labels.
    """
    # Non-IID split. The partition count now honours the `num_partitions`
    # argument instead of silently using the global NUM_PARTITIONS.
    # NOTE(review): min_partition_size is set to the partition count, as in the
    # original code — confirm this is the intended minimum number of samples.
    partitioner = DirichletPartitioner(
        num_partitions=num_partitions,
        partition_by="label",
        alpha=ALPHA,
        min_partition_size=num_partitions,
        self_balancing=True,
    )

    fds = FederatedDataset(
        dataset="cifar10",
        partitioners={"train": partitioner},
        cache_dir=CACHE_DIR,
    )

    partition = fds.load_partition(partition_id)
    partition_train_test = partition.train_test_split(test_size=0.2, seed=42)

    # ImageNet statistics, matching the ImageNet-pretrained backbone.
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    )

    # Training pipeline: upscale to 112x112 plus light augmentation.
    train_transforms = transforms.Compose([
        transforms.Resize((112, 112)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(degrees=5),
        transforms.ColorJitter(brightness=0.1, contrast=0.1),
        transforms.ToTensor(),
        normalize,
    ])

    # Validation/test pipeline: deterministic, no augmentation.
    val_transforms = transforms.Compose([
        transforms.Resize((112, 112)),
        transforms.ToTensor(),
        normalize,
    ])

    def make_batch_transform(pipeline):
        """Return a batch-level transform applying `pipeline` to every image."""
        def apply(batch):
            batch["img"] = [pipeline(img.convert("RGB")) for img in batch["img"]]
            return batch
        return apply

    apply_train_transforms = make_batch_transform(train_transforms)
    apply_val_transforms = make_batch_transform(val_transforms)

    train_split = partition_train_test["train"].with_transform(apply_train_transforms)
    val_split = partition_train_test["test"].with_transform(apply_val_transforms)
    testset = fds.load_split("test").with_transform(apply_val_transforms)

    # num_workers stays at 0 to avoid multiprocessing problems; with no worker
    # processes persistent_workers has to be False.
    loader_kwargs = dict(
        batch_size=BATCH_SIZE,
        pin_memory=torch.cuda.is_available(),
        num_workers=0,
        persistent_workers=False,
    )

    # Only the training loader shuffles and drops the incomplete last batch.
    trainloader = DataLoader(train_split, shuffle=True, drop_last=True, **loader_kwargs)
    valloader = DataLoader(val_split, shuffle=False, drop_last=False, **loader_kwargs)
    testloader = DataLoader(testset, shuffle=False, drop_last=False, **loader_kwargs)

    return trainloader, valloader, testloader


In [None]:
import os

from huggingface_hub import login

# Authenticate with the Hugging Face Hub only when a token is supplied through
# the HF_TOKEN environment variable; credentials are never hard-coded here.
_hf_token = os.environ.get("HF_TOKEN")
if _hf_token:
    login(token=_hf_token)

# Pre-build the (trainloader, valloader, testloader) triple of every partition.
all_data = [load_datasets(i, NUM_PARTITIONS) for i in range(NUM_PARTITIONS)]

In [None]:
from flwr_datasets import FederatedDataset
from flwr_datasets.visualization import plot_label_distributions
import matplotlib.pyplot as plt

# Build the partitioner (same settings as in load_datasets)
partitioner = DirichletPartitioner(
    num_partitions=NUM_PARTITIONS,
    partition_by="label",
    alpha=ALPHA,
    min_partition_size=NUM_PARTITIONS,
    self_balancing=True
)

# Build the FederatedDataset
fds = FederatedDataset(
    dataset="cifar10",
    partitioners={"train": partitioner},
    cache_dir=CACHE_DIR
)

# Visualize the label distribution
print(f"Visualizing Non-IID distribution with alpha={ALPHA}")
partitioner = fds.partitioners["train"]
figure, axis, dataframe = plot_label_distributions(
    partitioner=partitioner,
    label_name="label",
    legend=True,
    verbose_labels=True,
    plot_type="bar",  # bar chart for readability
    size_unit="absolute"  # or "percentage"
)

# Customize plot
plt.title(f'CIFAR-10 Non-IID Distribution (α={ALPHA})')
plt.xlabel('Clients')
plt.ylabel('Number of Samples')
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()

# Print summary statistics
print("\n=== Distribution Statistics ===")
print(f"Total clients: {NUM_PARTITIONS}")
print(f"Alpha (non-IID level): {ALPHA}")
print(f"Dataset: CIFAR-10")
print(f"Classes: {dataframe.columns.tolist()}")

# Detailed statistics
# assumes `dataframe` has one row per partition and one column per label — TODO confirm
total_samples = dataframe.sum().sum()
print(f"Total samples: {total_samples}")
print(f"Avg samples per client: {total_samples/NUM_PARTITIONS:.1f}")
print(f"Min samples: {dataframe.sum(axis=1).min()}")
print(f"Max samples: {dataframe.sum(axis=1).max()}")

In [None]:
import numpy as np
from scipy.stats import entropy

# Function to calculate entropy for a list of counts
def calculate_entropy(counts):
    """Return the Shannon entropy, in bits, of a list of label counts.

    Args:
        counts: Per-label sample counts.

    Returns:
        The base-2 entropy of the normalized counts, or 0 when the counts
        sum to zero.
    """
    n_samples = sum(counts)
    if n_samples != 0:
        distribution = [count / n_samples for count in counts]
        return entropy(distribution, base=2)  # base 2 -> result in bits
    return 0

# One entropy value per client, computed from that client's row of label counts.
entropies = [
    calculate_entropy(dataframe.iloc[row_idx].tolist())
    for row_idx in range(NUM_PARTITIONS)
]

# Mean and spread of the per-client entropies.
average_entropy, std_entropy = np.mean(entropies), np.std(entropies)

print(f"Average Entropy: {average_entropy:.4f}")
print(f"Entropy Std: {std_entropy:.4f}")


## Model ResNet18

In [None]:
import torch.nn as nn
from torchvision import models

class ResNet18_GroupNorm(nn.Module):
    """ResNet18 backbone with GroupNorm in place of BatchNorm and an MLP head.

    Every BatchNorm layer of the torchvision ResNet18 is replaced by a newly
    constructed 2-group GroupNorm, the original `fc` layer is replaced by an
    identity, and a three-layer classifier with dropout is stacked on top.

    Args:
        num_classes: Size of the final output layer.
        pretrained: Load the IMAGENET1K_V1 weights into the backbone when True.
    """

    def __init__(self, num_classes: int = 10, pretrained: bool = True):
        super().__init__()

        # Backbone, optionally initialized from ImageNet weights.
        if pretrained:
            weights = models.ResNet18_Weights.IMAGENET1K_V1
        else:
            weights = None
        self.backbone = models.resnet18(weights=weights)

        # Swap every BatchNorm for a 2-group GroupNorm.
        self._replace_batchnorm_with_groupnorm(self.backbone)

        # Remember the feature width, then strip the original fc layer.
        feat_dim = self.backbone.fc.in_features
        self.backbone.fc = nn.Identity()

        # Classification head: dropout shrinks from 0.3 to 0.1 toward the output.
        head_layers = [
            nn.Dropout(0.3),
            nn.Linear(feat_dim, 128),
            nn.GroupNorm(2, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
            nn.Linear(128, 64),
            nn.GroupNorm(2, 64),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1),
            nn.Linear(64, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

        self._init_weights()

    def _replace_batchnorm_with_groupnorm(self, module):
        """Recursively replace BatchNorm1d/2d children with 2-group GroupNorm."""
        for child_name, child in module.named_children():
            if isinstance(child, (nn.BatchNorm2d, nn.BatchNorm1d)):
                setattr(module, child_name, nn.GroupNorm(2, child.num_features))
                continue
            # Not a BatchNorm layer: descend into its children.
            self._replace_batchnorm_with_groupnorm(child)

    def _init_weights(self):
        """Kaiming-initialize the classifier's linear layers and zero their biases."""
        linear_layers = (
            layer for layer in self.classifier if isinstance(layer, nn.Linear)
        )
        for layer in linear_layers:
            nn.init.kaiming_normal_(layer.weight, nonlinearity='relu')
            if layer.bias is not None:
                nn.init.constant_(layer.bias, 0)

    def forward(self, x):
        """Run the backbone, flatten its features, and apply the classifier."""
        features = torch.flatten(self.backbone(x), 1)
        return self.classifier(features)


import torch
import copy

def train_fedprox(net, trainloader, epochs: int, global_params=None, mu=PROX_MU, lr=0.01, weight_decay=1e-3):
    """Train a local model with SGD and the FedProx proximal term.

    Args:
        net: Local model (nn.Module) to train in place.
        trainloader: Iterable of batches; each batch is a dict with "img" and
            "label" tensors.
        epochs: Number of local epochs.
        global_params: Tensors of the global model, in the same order as
            `net.parameters()`. When None, no proximal term is added.
        mu: Proximal coefficient; values <= 0 disable the proximal term.
        lr: SGD learning rate.
        weight_decay: L2 regularization passed to the optimizer.

    Returns:
        (epoch_loss, epoch_acc) of the last epoch. The loss includes the
        proximal term when it is active. Both are 0.0 when no epoch ran or
        when the loader produced no batches.
    """
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(net.parameters(), lr=lr, weight_decay=weight_decay)
    net.train()

    device = next(net.parameters()).device

    use_prox = global_params is not None and mu > 0
    if use_prox:
        # Frozen snapshot on the model's device, so the reference point is
        # neither updated by the optimizer nor part of the autograd graph.
        global_params = [p.detach().clone().to(device) for p in global_params]

    # Defined up front so that epochs == 0 still returns a value.
    epoch_loss, epoch_acc = 0.0, 0.0

    for epoch in range(epochs):
        correct, total, running_loss = 0, 0, 0.0
        for batch in trainloader:
            images, labels = batch["img"].to(device), batch["label"].to(device)
            optimizer.zero_grad()
            outputs = net(images)
            loss = criterion(outputs, labels)

            if use_prox:
                # (mu / 2) * ||w - w_global||^2, computed as a plain sum of
                # squares: no square root is taken, so the gradient stays
                # well-defined when the weights still equal the global ones.
                prox_loss = sum(
                    (param - global_param).pow(2).sum()
                    for param, global_param in zip(net.parameters(), global_params)
                )
                loss = loss + (mu / 2) * prox_loss

            loss.backward()
            optimizer.step()

            running_loss += loss.item() * labels.size(0)
            total += labels.size(0)
            correct += (outputs.argmax(dim=1) == labels).sum().item()

        if total == 0:
            # The loader yielded no batches (e.g. fewer samples than one batch
            # with drop_last=True); report zeros instead of dividing by zero.
            print(f"Epoch {epoch+1}: no training batches available")
            epoch_loss, epoch_acc = 0.0, 0.0
            continue

        epoch_loss = running_loss / total
        epoch_acc = correct / total
        print(f"Epoch {epoch+1}: train loss {epoch_loss:.4f}, accuracy {epoch_acc:.4f}")

    return epoch_loss, epoch_acc

def test(net, testloader):
    """Evaluate the model on a loader and return (average loss, accuracy).

    Args:
        net: Model (nn.Module) to evaluate; it is switched to eval mode.
        testloader: Iterable of batches; each batch is a dict with "img" and
            "label" tensors.

    Returns:
        (avg_loss, accuracy), both weighted by the number of samples. Returns
        (0.0, 0.0) when the loader yields no samples.
    """
    criterion = torch.nn.CrossEntropyLoss()
    # Evaluate on whichever device the model lives on, as train_fedprox does.
    device = next(net.parameters()).device
    correct, total, total_loss = 0, 0, 0.0
    net.eval()
    with torch.no_grad():
        for batch in testloader:
            images, labels = batch["img"].to(device), batch["label"].to(device)
            outputs = net(images)
            loss = criterion(outputs, labels)
            total_loss += loss.item() * labels.size(0)
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += labels.size(0)

    if total == 0:
        # Nothing to evaluate: avoid dividing by zero.
        print("Warning: evaluation loader produced no samples")
        return 0.0, 0.0

    return total_loss / total, correct / total


In [None]:
def get_parameters(net) -> List[np.ndarray]:
    """Return every state_dict tensor of `net` as a NumPy array.

    The list follows the state_dict key order and contains one array per key,
    including zero-element tensors, so it always lines up with the key list
    that set_parameters zips against.

    Args:
        net: Model whose state_dict is exported.

    Returns:
        One NumPy array per state_dict entry, in state_dict order.
    """
    return [tensor.detach().cpu().numpy() for tensor in net.state_dict().values()]

def set_parameters(net, parameters: List[np.ndarray]):
    """Load a list of NumPy arrays into `net`, matched to its state_dict keys by position.

    Args:
        net: Model to update in place.
        parameters: One array per state_dict entry, in state_dict order.

    Raises:
        ValueError: If the number of arrays differs from the number of
            state_dict entries. Silently keeping the old weights would let a
            client train or evaluate on stale parameters.
    """
    if not parameters:
        # Nothing was sent: keep the current weights.
        print("Warning: No parameters received")
        return

    state_dict_keys = list(net.state_dict().keys())
    if len(parameters) != len(state_dict_keys):
        raise ValueError(
            f"Parameter count mismatch: model has {len(state_dict_keys)} "
            f"state_dict entries but received {len(parameters)} arrays"
        )

    state_dict = OrderedDict(
        (key, torch.tensor(value)) for key, value in zip(state_dict_keys, parameters)
    )
    # Every key is supplied, so strict loading is safe and catches any drift.
    net.load_state_dict(state_dict, strict=True)

## Custom FedProx

In [None]:
class FlowerClientFedProx(NumPyClient):
    """Flower NumPyClient that trains locally with the FedProx proximal term."""

    def __init__(self, partition_id, net, trainloader, valloader):
        """Store the client's partition id, model, and data loaders."""
        self.partition_id = partition_id
        self.net = net
        self.trainloader = trainloader
        self.valloader = valloader

    def get_parameters(self, config):
        """Return the current model weights as a list of NumPy arrays."""
        return get_parameters(self.net)

    def fit(self, parameters, config):
        """Load the global weights, train locally, and return the updated weights.

        Returns:
            (updated weights, number of training examples, metrics dict with
            "loss" and "accuracy").
        """
        print(f"[Client {self.partition_id}] FedProx fit")

        # Tensor copy of the global weights: the anchor of the proximal term.
        anchor_weights = [torch.tensor(array).to(DEVICE) for array in parameters]

        set_parameters(self.net, parameters)

        # Proximal coefficient sent by the server, with a local fallback.
        proximal_mu = config.get("proximal_mu", 0.1)

        train_loss, train_acc = train_fedprox(
            self.net,
            self.trainloader,
            epochs=EPOCH,
            global_params=anchor_weights,
            mu=proximal_mu,
        )

        updated_weights = get_parameters(self.net)
        num_examples = len(self.trainloader.dataset)
        fit_metrics = {"loss": float(train_loss), "accuracy": float(train_acc)}
        return updated_weights, num_examples, fit_metrics

    def evaluate(self, parameters, config):
        """Load the given weights and evaluate them on the local validation set.

        Returns:
            (loss, number of validation examples, metrics dict with "loss"
            and "accuracy").
        """
        set_parameters(self.net, parameters)
        val_loss, val_acc = test(self.net, self.valloader)
        eval_metrics = {"loss": float(val_loss), "accuracy": float(val_acc)}
        return float(val_loss), len(self.valloader.dataset), eval_metrics


In [None]:
def client_fn_fedprox(context):
    """Create the Flower client for the partition named in the node config.

    Args:
        context: Flower context; its node_config supplies "partition-id".

    Returns:
        The FlowerClientFedProx for that partition, converted with to_client().
    """
    partition_id = context.node_config["partition-id"]
    # The test loader of the triple is not used by the client.
    train_dl, val_dl, _unused_test_dl = all_data[partition_id]

    model = ResNet18_GroupNorm(num_classes=10).to(DEVICE)
    client = FlowerClientFedProx(partition_id, model, train_dl, val_dl)
    return client.to_client()


In [None]:
from flwr.common import Metrics, Context
from flwr.common import ndarrays_to_parameters

def weighted_average(metrics: List[Tuple[int, "Metrics"]]) -> "Metrics":
    """Aggregate client accuracies into an example-weighted average.

    Args:
        metrics: (num_examples, metrics dict) pairs, one per client; each
            metrics dict must contain "accuracy".

    Returns:
        {"accuracy": weighted mean}. The accuracy is 0.0 when the total number
        of examples is zero (no clients, or only clients with no examples).
    """
    total_examples = sum(num_examples for num_examples, _ in metrics)
    if total_examples == 0:
        # Nothing to weight by: avoid dividing by zero.
        return {"accuracy": 0.0}

    # Weight each client's accuracy by the number of examples it used.
    weighted_sum = sum(num_examples * m["accuracy"] for num_examples, m in metrics)
    return {"accuracy": weighted_sum / total_examples}

# NOTE(review): `device` is not referenced anywhere else in the visible code (DEVICE is used instead).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Build one model and serialize its weights as Flower Parameters.
# NOTE(review): the grid search builds its own initial parameters per experiment, so this
# module-level `initial_parameters` looks unused in the visible code — confirm before removing.
net = ResNet18_GroupNorm().to(DEVICE)
initial_parameters = flwr.common.ndarrays_to_parameters(get_parameters(net))

In [None]:
from flwr.server.strategy import FedProx

class CustomFedProx(FedProx):
    """FedProx strategy that also logs and stores per-round client metrics.

    After each fit/evaluate aggregation, the unweighted mean of the clients'
    "loss" and "accuracy" metrics is printed and stored in
    `metrics_centralized`, a list with one dict per round holding the keys
    "train_loss", "train_acc", "val_loss" and "val_acc".
    """

    def __init__(self, proximal_mu=PROX_MU, *args, **kwargs):
        # Positional arguments are expanded before the keyword, which is also
        # how Python evaluates the call when they are written the other way round.
        super().__init__(*args, proximal_mu=proximal_mu, **kwargs)
        # One dict of averaged client metrics per round (index = round - 1).
        self.metrics_centralized = []

    def configure_fit(self, server_round, parameters, client_manager):
        """Configure the next round of training.

        Delegates to FedProx, which samples the clients and adds "proximal_mu"
        to each client's fit config; going through the parent also keeps any
        `on_fit_config_fn` passed to the strategy in effect.
        """
        return super().configure_fit(server_round, parameters, client_manager)

    def _record_round_metrics(self, rnd, results, prefix, label):
        """Average the clients' loss/accuracy, print them, and store them for round `rnd`.

        Args:
            rnd: 1-based round number.
            results: (client, result) pairs; each result exposes a `metrics` dict.
            prefix: Key prefix in the stored dict ("train" or "val").
            label: Name shown in the log line.
        """
        losses = [r.metrics["loss"] for _, r in results if "loss" in r.metrics]
        accs = [r.metrics["accuracy"] for _, r in results if "accuracy" in r.metrics]

        avg_loss = sum(losses) / len(losses) if losses else None
        avg_acc = sum(accs) / len(accs) if accs else None

        if avg_loss is not None and avg_acc is not None:
            print(f"[Round {rnd}] {label:<5} — Loss: {avg_loss:.4f}, Acc: {avg_acc:.4f}")
        else:
            print(f"[Round {rnd}] {label:<5} — Loss: N/A, Acc: N/A")

        # Grow the history until the slot for this round exists, even if an
        # earlier round never reported.
        while len(self.metrics_centralized) < rnd:
            self.metrics_centralized.append({})
        self.metrics_centralized[rnd - 1][f"{prefix}_loss"] = avg_loss
        self.metrics_centralized[rnd - 1][f"{prefix}_acc"] = avg_acc

    def aggregate_fit(self, rnd, results, failures):
        """Aggregate the fit results and record the averaged training metrics.

        Returns:
            The aggregated parameters and metrics produced by FedProx.
        """
        aggregated_parameters, aggregated_metrics = super().aggregate_fit(rnd, results, failures)
        self._record_round_metrics(rnd, results, prefix="train", label="Train")
        return aggregated_parameters, aggregated_metrics

    def aggregate_evaluate(self, rnd, results, failures):
        """Aggregate the evaluate results and record the averaged validation metrics.

        Returns:
            The aggregated loss and metrics produced by FedProx, so the output
            of `evaluate_metrics_aggregation_fn` is no longer thrown away.
        """
        aggregated_loss, aggregated_metrics = super().aggregate_evaluate(rnd, results, failures)
        self._record_round_metrics(rnd, results, prefix="val", label="Val")
        return aggregated_loss, aggregated_metrics


In [None]:
from typing import Dict, List, Optional, Tuple, Any
from datetime import datetime
import json

class FedProxGridSearch:
    """Grid search over the FedProx proximal coefficient (mu).

    For every candidate mu a complete Flower simulation is run with a
    CustomFedProx strategy; the per-round validation metrics are summarized
    and the best mu is chosen according to a selection metric.
    """

    # selection_metric name -> (key in a result dict, function picking the best mu)
    _SELECTION_METRICS = {
        "best_val_acc": ("best_val_acc", max),
        "final_val_acc": ("final_val_acc", max),
        "stability": ("convergence_stability", min),
    }

    def __init__(self, mu_values: List[float], num_rounds: int = 10,
                 num_partitions: int = 10, epochs_per_round: int = 2):
        """
        Initialize grid search.

        Args:
            mu_values: List of mu values to search
            num_rounds: Number of FL rounds per experiment
            num_partitions: Number of simulated clients (supernodes)
            epochs_per_round: Local epochs per round. Only stored and saved
                with the results; the visible client code reads the global
                EPOCH constant instead.
        """
        self.mu_values = mu_values
        self.num_rounds = num_rounds
        self.num_partitions = num_partitions
        self.epochs_per_round = epochs_per_round

        # Results storage: mu -> result dict of that experiment
        self.results = {}
        self.best_mu = None
        self.best_performance = None

    @staticmethod
    def _summarize_metrics(mu: float, metrics: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Turn the per-round metric history of one experiment into a result dict."""
        # Rounds without a usable validation accuracy (missing or None) are skipped.
        val_accs = [m["val_acc"] for m in metrics if m.get("val_acc") is not None]
        last_round = metrics[-1] if metrics else {}

        final_val_acc = last_round.get("val_acc")
        if final_val_acc is None:
            final_val_acc = 0.0
        final_val_loss = last_round.get("val_loss")
        if final_val_loss is None:
            final_val_loss = float('inf')

        # Convergence stability: std of the last five accuracies (lower = more stable).
        if len(val_accs) >= 5:
            convergence_stability = float(np.std(val_accs[-5:]))
        else:
            convergence_stability = float('inf')

        return {
            "mu": mu,
            "final_val_acc": final_val_acc,
            "final_val_loss": final_val_loss,
            "best_val_acc": max(val_accs) if val_accs else 0.0,
            "convergence_stability": convergence_stability,
            "metrics_history": metrics,
            "success": True
        }

    def run_single_experiment(self, mu: float, data_loaders, net_fn, device,
                            backend_config: Dict[str, Any]) -> Dict[str, Any]:
        """
        Run a single FedProx experiment with given mu.

        Args:
            mu: Proximal parameter to test
            data_loaders: List of (trainloader, valloader, testloader) for each
                client. Not read here: the client function takes its loaders
                from the module-level data.
            net_fn: Function to create network
            device: Computing device
            backend_config: Backend configuration for simulation

        Returns:
            Dictionary with experiment results. "success" is False and "error"
            holds the message when the simulation raised.
        """
        print(f"\n{'='*60}")
        print(f"Running FedProx with μ = {mu}")
        print(f"{'='*60}")

        # Initial parameters come from the state_dict, the same layout the
        # clients use when they load and return weights.
        net = net_fn().to(device)
        initial_parameters = flwr.common.ndarrays_to_parameters(
            [value.cpu().numpy() for value in net.state_dict().values()]
        )

        # Create strategy
        strategy = CustomFedProx(
                fraction_fit=FRACTION_FIT,
                fraction_evaluate=FRACTION_EVALUATE,
                min_fit_clients=MIN_FIT_CLIENTS,
                min_evaluate_clients=MIN_EVALUATE_CLIENTS,
                min_available_clients=MIN_AVAILABLE_CLIENTS,
                evaluate_metrics_aggregation_fn=weighted_average,
                initial_parameters=initial_parameters,
                proximal_mu=mu
        )

        # Create server function
        def server_fn(context: Context) -> ServerAppComponents:
            config = ServerConfig(num_rounds=self.num_rounds)
            return ServerAppComponents(strategy=strategy, config=config)

        # Create apps
        server_app = ServerApp(server_fn=server_fn)
        client_app_fedprox = ClientApp(client_fn=client_fn_fedprox)

        # A failing experiment must not abort the whole grid search, so any
        # error is reported in the returned dict instead of being raised.
        try:
            run_simulation(
                server_app=server_app,
                client_app=client_app_fedprox,
                num_supernodes=self.num_partitions,
                backend_config=backend_config,
            )
            return self._summarize_metrics(mu, strategy.metrics_centralized)

        except Exception as e:
            print(f"Error in experiment with μ={mu}: {e}")
            return {
                "mu": mu,
                "final_val_acc": 0.0,
                "final_val_loss": float('inf'),
                "best_val_acc": 0.0,
                "convergence_stability": float('inf'),
                "metrics_history": [],
                "success": False,
                "error": str(e)
            }

    def run_grid_search(self, data_loaders, net_fn, device, backend_config: Dict[str, Any],
                       selection_metric: str = "best_val_acc") -> Dict[str, Any]:
        """
        Run complete grid search over mu values.

        Args:
            data_loaders: Data loaders for all clients
            net_fn: Function to create network
            device: Computing device
            backend_config: Backend configuration
            selection_metric: Metric to use for best mu selection; one of
                "best_val_acc", "final_val_acc" or "stability"

        Returns:
            Dictionary mapping each mu to its result dict

        Raises:
            ValueError: If selection_metric is not a supported name. Checked
                before any experiment runs, so a typo fails immediately.
        """
        if selection_metric not in self._SELECTION_METRICS:
            raise ValueError(
                f"Unknown selection_metric {selection_metric!r}; "
                f"expected one of {sorted(self._SELECTION_METRICS)}"
            )
        result_key, pick_best = self._SELECTION_METRICS[selection_metric]

        print(f"\n🔍 Starting FedProx Grid Search")
        print(f"Testing μ values: {self.mu_values}")
        print(f"Selection metric: {selection_metric}")

        # Run experiments
        for mu in self.mu_values:
            result = self.run_single_experiment(
                mu=mu,
                data_loaders=data_loaders,
                net_fn=net_fn,
                device=device,
                backend_config=backend_config
            )

            self.results[mu] = result

            # Print summary
            if result["success"]:
                print(f"μ={mu}: Val_Acc={result['final_val_acc']:.4f}, "
                      f"Best_Val_Acc={result['best_val_acc']:.4f}, "
                      f"Stability={result['convergence_stability']:.6f}")
            else:
                print(f"μ={mu}: FAILED - {result.get('error', 'Unknown error')}")

        # Find best mu among the experiments that completed
        successful_results = {k: v for k, v in self.results.items() if v["success"]}

        if successful_results:
            self.best_mu = pick_best(
                successful_results,
                key=lambda k: successful_results[k][result_key],
            )
            self.best_performance = successful_results[self.best_mu]

            print(f"\n🎯 Best μ found: {self.best_mu}")
            # Look the value up by result key: "stability" is stored under
            # "convergence_stability".
            print(f"Best performance: {self.best_performance[result_key]:.4f}")
        else:
            print(f"\n❌ No successful experiments found!")
            self.best_mu = None
            self.best_performance = None

        return self.results

    def plot_results(self, save_path: Optional[str] = None):
        """Plot final accuracy, best accuracy and stability against mu.

        Args:
            save_path: When given, the figure is also written to this path.
        """
        if not self.results:
            print("No results to plot!")
            return

        successful_results = {k: v for k, v in self.results.items() if v["success"]}

        if not successful_results:
            print("No successful results to plot!")
            return

        mu_values = list(successful_results)

        # (result key, line format, legend label, y label, title) per panel
        panels = [
            ("final_val_acc", 'bo-', 'Final Val Acc',
             'Final Validation Accuracy', 'Final Validation Accuracy vs μ'),
            ("best_val_acc", 'go-', 'Best Val Acc',
             'Best Validation Accuracy', 'Best Validation Accuracy vs μ'),
            ("convergence_stability", 'ro-', 'Convergence Stability',
             'Convergence Stability (Std)', 'Convergence Stability vs μ'),
        ]

        fig, axes = plt.subplots(1, 3, figsize=(15, 5))

        for ax, (key, fmt, label, ylabel, title) in zip(axes, panels):
            values = [successful_results[mu][key] for mu in mu_values]
            ax.plot(mu_values, values, fmt, label=label)
            # Mark the selected mu when one has been chosen.
            if self.best_mu is not None:
                ax.axvline(x=self.best_mu, color='red', linestyle='--', alpha=0.7,
                           label=f'Best μ={self.best_mu}')
            ax.set_xlabel('μ (Proximal Parameter)')
            ax.set_ylabel(ylabel)
            ax.set_title(title)
            ax.grid(True, alpha=0.3)
            ax.legend()

        plt.tight_layout()

        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f"Plot saved to {save_path}")

        plt.show()

    def save_results(self, save_path: str):
        """Save the grid search configuration and results as JSON.

        Args:
            save_path: Destination file path.
        """
        results_summary = {
            "grid_search_config": {
                "mu_values": self.mu_values,
                "num_rounds": self.num_rounds,
                "num_partitions": self.num_partitions,
                "epochs_per_round": self.epochs_per_round
            },
            "best_mu": self.best_mu,
            "best_performance": self.best_performance,
            "all_results": self.results,
            "timestamp": datetime.now().isoformat()
        }

        # NOTE: infinite values are written as `Infinity`, which strict JSON
        # parsers reject.
        with open(save_path, 'w', encoding='utf-8') as f:
            json.dump(results_summary, f, indent=2)

        print(f"Results saved to {save_path}")

In [None]:
# Candidate proximal coefficients, spanning several orders of magnitude.
mu_values = [0.001, 0.01, 0.1, 1]

# Create grid search
grid_search = FedProxGridSearch(
    mu_values=mu_values,
    num_rounds=NUM_ROUND,  # FL rounds per experiment (the full NUM_ROUND, not a shortened run)
    num_partitions=NUM_PARTITIONS,
    epochs_per_round=2
)


In [None]:
def create_resnet():
    """Build a fresh 10-class ResNet18_GroupNorm; serves as the grid search's model factory."""
    model = ResNet18_GroupNorm(num_classes=10)
    return model

# Run one full simulation per mu value and pick the mu with the highest best validation accuracy.
results = grid_search.run_grid_search(
    data_loaders=all_data,
    net_fn=create_resnet,
    device=DEVICE,
    backend_config=backend_config,
    selection_metric="best_val_acc"
)


In [None]:
# Plot the grid search results and save them.
# NOTE(review): the plot goes to /kaggle/working while the JSON uses a relative path — confirm this is intended.
grid_search.plot_results("/kaggle/working/fedprox_grid_search.png")
grid_search.save_results("fedprox_grid_search_results.json")