In [1]:
# ==============================================================================
# === Cell 1: Imports and All Class/Function Definitions ===
# ==============================================================================
# (This cell contains all your imports and helper code)

import os
import argparse
import json
import time
from typing import List, Tuple, Dict, Optional

import numpy as np
import h5py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.models as models
import torchvision.transforms as T
from torch import optim

# Flower
import flwr as fl

# sklearn metrics + plotting
from sklearn.metrics import (accuracy_score, precision_recall_fscore_support,
                             confusion_matrix, roc_auc_score, classification_report)
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

# ---------------------------
# Configurable defaults
# ---------------------------
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)


# ---------------------------
# Dataset that reads from HDF5 lazily
# ---------------------------
class H5RFShard(Dataset):
    """
    Lazy HDF5 dataset for one client's shard.
    """
    def __init__(self, h5_path: str, split: str = "train", downsample: int = 8, transform=None):
        self.h5_path = h5_path
        self.downsample = int(downsample)
        self.transform = transform
        try:
            with h5py.File(self.h5_path, "r") as f:
                self.length = f["y"].shape[0]
        except (IOError, OSError, FileNotFoundError) as e:
            print(f"Error opening HDF5 file {self.h5_path}: {e}")
            raise

    def __len__(self) -> int:
        return self.length

    def __getitem__(self, idx: int):
        with h5py.File(self.h5_path, "r") as f:
            spec = f["x_spec"][idx]      # shape [2,128,128]
            iq = f["x_iq"][idx]         # shape [2, L]
            y = int(f["y"][idx])
            snr = None
            if "snr" in f:
                snr = f["snr"][idx]

        if self.downsample > 1:
            iq = iq[:, :: self.downsample]

        spec_t = torch.tensor(spec, dtype=torch.float32)
        iq_t = torch.tensor(iq, dtype=torch.float32)
        label_t = torch.tensor(y, dtype=torch.long)
        
        if self.transform:
            spec_t = self.transform(spec_t)

        return {"iq": iq_t, "spec": spec_t, "label": label_t, "snr": snr}


# ---------------------------
# Helper class for augmentations
# ---------------------------
class TransformedDataset(Dataset):
    """Applies a transform to a subset/dataset."""
    def __init__(self, dataset, transform):
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        sample = self.dataset[idx]
        if self.transform:
            sample["spec"] = self.transform(sample["spec"])
        return sample


# ---------------------------
# Model: ResNet (spec) + 1D-CNN-Transformer (iq) fusion
# ---------------------------
class MultiModalNet(nn.Module):
    def __init__(self, num_classes: int):
        super().__init__()
        self.spec_net = models.resnet18(weights=None)
        self.spec_net.conv1 = nn.Conv2d(2, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.spec_net.fc = nn.Identity()

        self.iq_conv = nn.Sequential(
            nn.Conv1d(2, 32, kernel_size=7, stride=2, padding=3),
            nn.BatchNorm1d(32),
            nn.ReLU(inplace=True),
            nn.Conv1d(32, 64, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm1d(64),
            nn.ReLU(inplace=True),
            nn.Conv1d(64, 128, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
        )

        transformer_layer = nn.TransformerEncoderLayer(
            d_model=128, nhead=4, dim_feedforward=256,
            dropout=0.1, activation="relu", batch_first=True
        )
        self.iq_transformer = nn.TransformerEncoder(transformer_layer, num_layers=1)
        
        self.iq_pool_flat = nn.Sequential(
            nn.AdaptiveAvgPool1d(1),
            nn.Flatten()
        )

        self.classifier = nn.Sequential(
            nn.Linear(512 + 128, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes)
        )

    def forward(self, iq: torch.Tensor, spec: torch.Tensor) -> torch.Tensor:
        f_spec = self.spec_net(spec)
        x_iq = self.iq_conv(iq)
        x_iq = x_iq.permute(0, 2, 1)
        x_iq = self.iq_transformer(x_iq)
        x_iq = x_iq.permute(0, 2, 1)
        f_iq = self.iq_pool_flat(x_iq)
        x = torch.cat([f_spec, f_iq], dim=1)
        return self.classifier(x)


# ---------------------------
# Helper utils: parameters convertors
# ---------------------------
def model_to_parameters(model: nn.Module) -> List[np.ndarray]:
    params = []
    for k, v in model.state_dict().items():
        params.append(v.cpu().numpy())
    return params


def parameters_to_model(model: nn.Module, params: List[np.ndarray]):
    state_dict = model.state_dict()
    new_state = {}
    for (k, _), arr in zip(state_dict.items(), params):
        new_state[k] = torch.tensor(arr)
    model.load_state_dict(new_state)


# ---------------------------
# Local training function
# ---------------------------
def train_local(model: nn.Module,
                train_loader: DataLoader,
                device: torch.device,
                epochs: int,
                lr: float,
                mu: float = 0.0,
                global_params: Optional[List[np.ndarray]] = None) -> Tuple[nn.Module, float]: # <-- MODIFIED
    model.train()
    model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs * len(train_loader), eta_min=1e-6)
    use_amp = True if device.type == "cuda" else False
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    if mu > 0.0 and global_params is not None:
        global_tensors = [torch.tensor(p).to(device) for p in global_params]
    else:
        global_tensors = None

    # --- MODIFIED: Track loss ---
    total_loss = 0.0
    total_count = 0
    # ----------------------------

    for ep in range(epochs):
        for batch in train_loader:
            iq = batch["iq"].to(device, non_blocking=True)
            spec = batch["spec"].to(device, non_blocking=True)
            labels = batch["label"].to(device, non_blocking=True)

            optimizer.zero_grad()
            with torch.cuda.amp.autocast(enabled=use_amp):
                outputs = model(iq, spec)
                loss = criterion(outputs, labels)
                if mu > 0.0 and global_tensors is not None:
                    prox_reg = 0.0
                    for (k, v), g in zip(model.state_dict().items(), global_tensors):
                        prox_reg = prox_reg + torch.sum((v.to(device) - g) ** 2)
                    loss = loss + (mu / 2.0) * prox_reg

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()
            
            # --- MODIFIED: Track loss ---
            total_loss += loss.item() * labels.size(0)
            total_count += labels.size(0)
            # ----------------------------

    # --- MODIFIED: Return model and avg loss ---
    avg_loss = (total_loss / total_count) if total_count > 0 else 0.0
    return model, avg_loss
    # -------------------------------------------


# ---------------------------
# Local evaluation function
# ---------------------------
def evaluate_local(model: nn.Module, val_loader: DataLoader, device: torch.device) -> Tuple[float, int, Dict, Dict]:
    model.eval()
    model.to(device)
    criterion = nn.CrossEntropyLoss()
    preds_all = []
    labels_all = []
    snr_all = []
    loss_sum = 0.0
    n = 0
    use_amp = True if device.type == "cuda" else False # Added use_amp flag

    with torch.no_grad():
        for batch in val_loader:
            iq = batch["iq"].to(device, non_blocking=True)
            spec = batch["spec"].to(device, non_blocking=True)
            labels = batch["label"].to(device, non_blocking=True)
            snr = batch["snr"]

            # Added autocast for evaluation
            with torch.cuda.amp.autocast(enabled=use_amp):
                outputs = model(iq, spec)
                loss = criterion(outputs, labels)
            
            loss_sum += loss.item() * labels.size(0)
            n += labels.size(0)
            preds_all.append(outputs.argmax(dim=1).cpu().numpy())
            labels_all.append(labels.cpu().numpy())
            if snr is not None:
                # Ensure snr is numpy-compatible
                if isinstance(snr, torch.Tensor):
                    snr_all.append(snr.cpu().numpy())
                elif isinstance(snr, (list, tuple)):
                    snr_all.append(np.array(snr))


    preds = np.concatenate(preds_all) if preds_all else np.array([])
    labels = np.concatenate(labels_all) if labels_all else np.array([])
    snr_vals = np.concatenate(snr_all) if snr_all else None
    
    acc = float(accuracy_score(labels, preds)) if len(labels) > 0 else 0.0
    prec, rec, f1, _ = precision_recall_fscore_support(labels, preds, average="macro", zero_division=0)
    
    metrics = {"accuracy": acc, "precision": float(prec), "recall": float(rec), "f1": float(f1)}
    results = {"preds": preds, "labels": labels, "snr": snr_vals}
    
    return (loss_sum / n) if n > 0 else 0.0, n, metrics, results


# ---------------------------
# Flower client implementation
# ---------------------------
class FLClient(fl.client.NumPyClient):
    # --- MODIFIED: Added attack parameters ---
    def __init__(self, cid: str, model: nn.Module, h5_path: str, device: torch.device,
                 batch_size: int, downsample: int, local_epochs: int, lr: float, mu: float = 0.0,
                 is_malicious: bool = False, attack_alpha: float = 10.0):
        self.cid = cid
        self.model = model
        self.h5_path = h5_path
        self.device = device
        self.batch_size = batch_size
        self.downsample = downsample
        self.local_epochs = local_epochs
        self.lr = lr
        self.mu = mu
        
        # --- NEW: Store attack flags ---
        self.is_malicious = is_malicious
        self.attack_alpha = attack_alpha
        if self.is_malicious:
            print(f"--- [Client {self.cid}] WARNING: This client is MALICIOUS. (Alpha={self.attack_alpha}) ---")
        # ---------------------------------

        spec_transform = T.Compose([
            T.RandomErasing(p=0.5, scale=(0.02, 0.1), ratio=(0.1, 1.0), value=0),
            T.RandomErasing(p=0.5, scale=(0.02, 0.1), ratio=(1.0, 10.0), value=0)
        ])

        base_ds = H5RFShard(self.h5_path, downsample=self.downsample, transform=None)
        
        # We create two instances of the dataset for the split
        train_ds_base = H5RFShard(self.h5_path, downsample=self.downsample, transform=spec_transform)
        val_ds_base = H5RFShard(self.h5_path, downsample=self.downsample, transform=None)
        
        n = len(base_ds)
        n_train = int(0.9 * n)
        n_val = n - n_train

        # Split using indices
        indices = np.arange(n)
        np.random.shuffle(indices)
        train_indices = indices[:n_train]
        val_indices = indices[n_train:]

        # Use Subset to wrap the correct dataset instance
        train_set = torch.utils.data.Subset(train_ds_base, train_indices)
        val_set = torch.utils.data.Subset(val_ds_base, val_indices)
        
        self.train_loader = DataLoader(train_set, batch_size=self.batch_size, shuffle=True, num_workers=0, pin_memory=True)
        self.val_loader = DataLoader(val_set, batch_size=self.batch_size, shuffle=False, num_workers=0, pin_memory=True)

    def get_parameters(self, config):
        return model_to_parameters(self.model)

    def fit(self, parameters, config):
        # --- MODIFIED: Implement Attack Logic ---
        
        # 1. Store w_t (global model) as tensors
        w_t_tensors = [torch.tensor(p).to(self.device) for p in parameters]

        # 2. Load w_t into model
        parameters_to_model(self.model, parameters)
        
        # 3. Handle FedProx
        global_params = parameters if (self.mu > 0.0) else None
        
        # 4. Train locally to get w_i' (and avg_loss)
        self.model, avg_loss = train_local(
            self.model, 
            self.train_loader, 
            self.device, 
            epochs=self.local_epochs, 
            lr=self.lr,
            mu=self.mu, 
            global_params=global_params
        )

        # 5. Check if malicious
        if self.is_malicious:
            print(f"  [Client {self.cid}] ATTACKING: Reversing and scaling updates...")
            
            # Get w_i' (local model) as tensors
            w_i_prime_tensors = list(self.model.state_dict().values())
            
            poisoned_params_list = []
            
            # Calculate w_poison = w_t - alpha * (w_i' - w_t)
            for wt, wip in zip(w_t_tensors, w_i_prime_tensors):
                delta = wip.to(self.device) - wt
                poison_param = wt - self.attack_alpha * delta
                poisoned_params_list.append(poison_param.cpu().numpy())

            # Send the poisoned parameters
            return poisoned_params_list, len(self.train_loader.dataset), {"local_loss": avg_loss, "attack": True}
        
        else:
            # Send honest parameters
            return model_to_parameters(self.model), len(self.train_loader.dataset), {"local_loss": avg_loss, "attack": False}

    def evaluate(self, parameters, config):
        parameters_to_model(self.model, parameters)
        loss, num_examples, metrics, _ = evaluate_local(self.model, self.val_loader, self.device)
        
        # Add print statement to see local accuracy
        print(f"  [Client {self.cid}] Local validation: Accuracy={metrics['accuracy']:.4f}, Loss={loss:.4f}")
        
        return float(loss), int(num_examples), metrics


# ---------------------------
# Server-side evaluation helpers
# ---------------------------
def server_evaluate_global(model: nn.Module, test_h5: str, batch_size: int, downsample: int, device: torch.device):
    test_ds = H5RFShard(test_h5, downsample=downsample, transform=None) # No augs
    test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=True)
    loss, n, metrics, _ = evaluate_local(model, test_loader, device)
    return loss, metrics


# ---------------------------
# Utilities: plotting metrics
# ---------------------------
def plot_snr_accuracy(results: Dict, out_dir: str):
    if results["snr"] is None or len(results["snr"]) == 0:
        print("SNR data not found in test set, skipping SNR plot.")
        return
    
    df = pd.DataFrame({"snr": results["snr"], "correct": (results["preds"] == results["labels"])})
    min_snr, max_snr = np.floor(df["snr"].min()), np.ceil(df["snr"].max())
    
    if min_snr == max_snr:
        snr_bins = np.array([min_snr, max_snr + 2])
    else:
        snr_bins = np.arange(min_snr, max_snr + 2, 2.0)
        
    df["snr_bin"] = pd.cut(df["snr"], bins=snr_bins, right=False)
    
    if df["snr_bin"].isnull().all():
        print("Could not bin SNR data, skipping plot.")
        return

    bin_acc = df.groupby("snr_bin", observed=True)["correct"].mean()
    bin_counts = df.groupby("snr_bin", observed=True)["correct"].count()
    
    bin_centers = (snr_bins[:-1] + snr_bins[1:]) / 2
    if len(bin_centers) != len(bin_acc):
        bin_centers = bin_centers[:len(bin_acc)]

    bin_acc_plot = bin_acc.reindex(df["snr_bin"].cat.categories).fillna(0)
    bin_counts_plot = bin_counts.reindex(df["snr_bin"].cat.categories).fillna(0)
    
    if len(bin_centers) != len(bin_acc_plot):
        bin_centers = (bin_acc_plot.index.left + bin_acc_plot.index.right) / 2


    fig, ax1 = plt.subplots(figsize=(10, 6))
    ax1.plot(bin_centers, bin_acc_plot, 'bo-', label="Accuracy")
    ax1.set_xlabel("SNR (dB)"); ax1.set_ylabel("Accuracy", color="b")
    ax1.tick_params(axis='y', labelcolor='b'); ax1.set_ylim(0, 1.05); ax1.grid(True, linestyle='--')
    
    ax2 = ax1.twinx()
    ax2.bar(bin_centers, bin_counts_plot, width=1.8, alpha=0.3, color="gray", label="Sample Count")
    ax2.set_ylabel("Sample Count", color="gray"); ax2.tick_params(axis='y', labelcolor='gray')

    plt.title("Accuracy vs. SNR")
    fig.tight_layout()
    plt.savefig(os.path.join(out_dir, "snr_vs_accuracy.png"))
    plt.close()
    print(f"SNR vs. Accuracy plot saved to {os.path.join(out_dir, 'snr_vs_accuracy.png')}")


def plot_confusion_and_report(model: nn.Module, test_h5: str, downsample: int, device: torch.device, out_dir: str):
    test_ds = H5RFShard(test_h5, downsample=downsample, transform=None) # No augs
    test_loader = DataLoader(test_ds, batch_size=64, shuffle=False, num_workers=0, pin_memory=True)
    model.to(device).eval()
    
    loss, n, metrics, results = evaluate_local(model, test_loader, device)
    
    preds = results["preds"]
    labels = results["labels"]

    print("\n--- Final Global Model Evaluation ---")
    print(f"Test Loss: {loss:.4f}")
    print(f"Test Accuracy: {metrics['accuracy']:.4f}")
    
    if len(labels) > 0 and len(preds) > 0:
        print("\nClassification report:\n", classification_report(labels, preds, digits=4, zero_division=0))
        
        cm = confusion_matrix(labels, preds)
        plt.figure(figsize=(8, 6))
        sns.heatmap(cm, annot=True, fmt="d", cmap="Blues")
        plt.xlabel("Predicted"); plt.ylabel("True"); plt.title("Confusion Matrix (Global Model)")
        os.makedirs(out_dir, exist_ok=True)
        plt.savefig(os.path.join(out_dir, "confusion_matrix.png"))
        plt.close()
        print(f"Confusion matrix saved to {os.path.join(out_dir, 'confusion_matrix.png')}")
    else:
        print("No labels or predictions found, skipping classification report and confusion matrix.")

    plot_snr_accuracy(results, out_dir)


# ---------------------------
# Server-side evaluation/saving function factory
# ---------------------------
def get_evaluate_fn(model: nn.Module, test_h5: Optional[str], batch_size: int, downsample: int, device: torch.device, out_dir: str, num_rounds: int, num_classes: int):
    
    best_acc = 0.0 # Track best accuracy
    
    # We need a model instance *on the server* for evaluation
    eval_model = MultiModalNet(num_classes=num_classes).to(device)
    
    def evaluate(server_round: int,
                 parameters: fl.common.NDArrays,
                 config: Dict[str, fl.common.Scalar]) -> Optional[Tuple[float, Dict[str, fl.common.Scalar]]]:
        
        nonlocal best_acc # Use the outer scope's best_acc
        
        parameters_to_model(eval_model, parameters) # Load weights into server's eval model
        loss, metrics = 0.0, {}

        if test_h5 is not None:
            # We create the test_loader *inside* the function
            test_ds = H5RFShard(test_h5, downsample=downsample, transform=None)
            test_loader = DataLoader(test_ds, batch_size=batch_size, num_workers=0)
            
            loss, n, metrics, _ = evaluate_local(eval_model, test_loader, device)
            print(f"Server-side evaluation round {server_round} / {num_rounds}: Loss {loss:.4f} | Acc {metrics['accuracy']:.4f}")
            
            # Save best model logic
            if metrics['accuracy'] > best_acc:
                best_acc = metrics['accuracy']
                os.makedirs(out_dir, exist_ok=True)
                save_path = os.path.join(out_dir, "global_model_best.pth")
                torch.save(eval_model.state_dict(), save_path)
                print(f"✅ Best model saved (Acc={best_acc:.4f})")
        else:
            print(f"Federated round {server_round} / {num_rounds} complete.")

        # Save final model on the last round
        if server_round == num_rounds:
            os.makedirs(out_dir, exist_ok=True)
            save_path = os.path.join(out_dir, "global_model_final.pth")
            torch.save(eval_model.state_dict(), save_path)
            print(f"Final global model saved to {save_path}")

        return loss, metrics
    return evaluate

print("All classes and functions defined.")

  from .autonotebook import tqdm as notebook_tqdm
2025-11-05 09:28:35,769	INFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.


All classes and functions defined.


In [2]:
# ==============================================================================
# === Cell 2: Configuration ===
# ==============================================================================
# (This cell replaces all the command-line arguments)

NUM_CLIENTS = 5
NUM_CLASSES = 7
NUM_ROUNDS = 10
LOCAL_EPOCHS = 3
BATCH_SIZE = 32
DOWNSAMPLE = 8
LR = 1e-4
MU = 0.0 # FedProx coefficient (0.0 = standard FedAvg)

# --- [MODIFIED] Attack Configuration ---
ATTACK_ENABLED = True
MALICIOUS_CLIENT_IDS = ["0", "1"] # <-- Poison clients 0 and 1
ATTACK_ALPHA = 10.0         # Scaling factor (e.g., 10)
# -------------------------------------

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
OUT_DIR = "./fl_output_notebook"
os.makedirs(OUT_DIR, exist_ok=True)

print(f"Configuration loaded. Running on device: {DEVICE}")
print(f"Output directory set to: {OUT_DIR}")
if ATTACK_ENABLED:
    print(f"!!! WARNING: POISONING ATTACK ENABLED !!!")
    print(f"    -> Malicious Clients: {MALICIOUS_CLIENT_IDS}")
    print(f"    -> Attack Alpha: {ATTACK_ALPHA}")

# --- !!! IMPORTANT: UPDATE THESE PATHS !!! ---
CLIENT_DATA_PATHS = {
    "0": r"C:\\Users\\my pc\\Desktop\\UAV authentication using federated learning\\data\\client0.h5",
    "1": r"C:\\Users\\my pc\\Desktop\\UAV authentication using federated learning\\data\\client1.h5",
    "2": r"C:\\Users\\my pc\\Desktop\\UAV authentication using federated learning\\data\\client2.h5",
    "3": r"C:\\Users\\my pc\\Desktop\\UAV authentication using federated learning\\data\\client3.h5",
    "4": r"C:\\Users\\my pc\\Desktop\\UAV authentication using federated learning\\data\\client4.h5",
}

# --- Use one of the client files as the test set ---
GLOBAL_TEST_H5_PATH = CLIENT_DATA_PATHS["4"]
print(f"Training on {NUM_CLIENTS} clients. Using Client 4's data for global evaluation.")
# ---------------------------------------------

Configuration loaded. Running on device: cuda
Output directory set to: ./fl_output_notebook
    -> Malicious Clients: ['0', '1']
    -> Attack Alpha: 10.0
Training on 5 clients. Using Client 4's data for global evaluation.


In [3]:
# ==============================================================================
# === Cell 3: Client Factory ===
# ==============================================================================
# (This function tells Flower's simulation how to create a client)

def client_fn(cid: str) -> fl.client.Client:
    """Create a Flower client instance."""
    
    h5_path = CLIENT_DATA_PATHS[cid]
    if not os.path.exists(h5_path):
        print(f"Warning: Data path not found for client {cid}: {h5_path}")
    
    model = MultiModalNet(num_classes=NUM_CLASSES)
    
    # --- [MODIFIED] Check if this client is in the malicious list ---
    is_malicious = ATTACK_ENABLED and (cid in MALICIOUS_CLIENT_IDS)
    # -------------------------------------------------------------
    
    client = FLClient(
        cid=cid,
        model=model,
        h5_path=h5_path,
        device=DEVICE,
        batch_size=BATCH_SIZE,
        downsample=DOWNSAMPLE,
        local_epochs=LOCAL_EPOCHS,
        lr=LR,
        mu=MU,
        # --- Pass attack flags ---
        is_malicious=is_malicious,
        attack_alpha=ATTACK_ALPHA
        # -------------------------------
    )
    return client.to_client()

print("Client factory `client_fn` defined.")

Client factory `client_fn` defined.


In [4]:
# ==============================================================================
# === Cell 4: Run the Simulation (Training) ===
# ==============================================================================
# (This cell starts and runs the entire federated training process)
# (No changes needed, but added OOM fix)

# We need a model instance on the "server" for saving the final model
server_model = MultiModalNet(num_classes=NUM_CLASSES).to(DEVICE)

# Create the server-side function (it will just save the model, no per-round eval)
eval_fn = get_evaluate_fn(
    model=server_model,
    test_h5=GLOBAL_TEST_H5_PATH,
    batch_size=BATCH_SIZE,
    downsample=DOWNSAMPLE,
    device=DEVICE,
    out_dir=OUT_DIR,
    num_rounds=NUM_ROUNDS,
    num_classes=NUM_CLASSES # Pass num_classes
)

# Define the strategy
strategy = fl.server.strategy.FedAvg(
    fraction_fit=1.0,          # Train on all 5 clients
    fraction_evaluate=1.0,     # Evaluate on all 5 clients' local validation sets
    min_fit_clients=NUM_CLIENTS,
    min_evaluate_clients=NUM_CLIENTS,
    min_available_clients=NUM_CLIENTS,
    evaluate_fn=eval_fn  # This will save the model on the last round
)

# --- [NEW] Add client_resources to prevent OOM errors ---
ray_init_args = {"num_gpus": 1}
# 0.20 * 5 clients = 1.0 GPU
client_resources = {"num_cpus": 1, "num_gpus": 0.20} 
# --------------------------------------------------------

print(f"Starting Flower simulation for {NUM_ROUNDS} rounds...")
history = fl.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=NUM_CLIENTS,
    config=fl.server.ServerConfig(num_rounds=NUM_ROUNDS),
    strategy=strategy,
    # --- [NEW] Pass resources ---
    client_resources=client_resources,
    ray_init_args=ray_init_args
    # ----------------------------
)
print("Simulation finished.")

	Instead, use the `flwr run` CLI command to start a local simulation in your Flower app, as shown for example below:

		$ flwr new  # Create a new Flower app from a template

		$ flwr run  # Run the Flower app in Simulation Mode

	Using `start_simulation()` is deprecated.

            This is a deprecated feature. It will be removed
            entirely in future versions of Flower.
        
[92mINFO [0m:      Starting Flower simulation, config: num_rounds=10, no round_timeout


Starting Flower simulation for 10 rounds...


2025-11-05 09:28:41,574	INFO worker.py:1771 -- Started a local Ray instance.
[92mINFO [0m:      Flower VCE: Ray initialized with resources: {'accelerator_type:G': 1.0, 'node:__internal_head__': 1.0, 'CPU': 12.0, 'object_store_memory': 1461937766.0, 'node:127.0.0.1': 1.0, 'memory': 2923875534.0, 'GPU': 1.0}
[92mINFO [0m:      Optimize your simulation with Flower VCE: https://flower.ai/docs/framework/how-to-run-simulations.html
[92mINFO [0m:      Flower VCE: Resources for each Virtual Client: {'num_cpus': 1, 'num_gpus': 0.2}
[92mINFO [0m:      Flower VCE: Creating VirtualClientEngineActorPool with 5 actors
[92mINFO [0m:      [INIT]
[92mINFO [0m:      Requesting initial parameters from one random client
[36m(ClientAppActor pid=2360)[0m 
[36m(ClientAppActor pid=2360)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2360)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2360)[0m         
[92mINFO [0m: 



  with torch.cuda.amp.autocast(enabled=use_amp):


Server-side evaluation round 0 / 10: Loss 2.0173 | Acc 0.0865


[92mINFO [0m:      initial parameters (loss, other metrics): 2.017263355301504, {'accuracy': 0.08652043969403779, 'precision': 0.15333964632561828, 'recall': 0.14663322139957652, 'f1': 0.026889485975713306}
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 1]
[92mINFO [0m:      configure_fit: strategy sampled 5 clients (out of 5)


✅ Best model saved (Acc=0.0865)


[36m(ClientAppActor pid=2360)[0m 
[36m(ClientAppActor pid=2360)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=2360)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=2360)[0m         




[36m(ClientAppActor pid=21176)[0m 
[36m(ClientAppActor pid=21176)[0m             This is a deprecated feature. It will be removed
[36m(ClientAppActor pid=21176)[0m             entirely in future versions of Flower.
[36m(ClientAppActor pid=21176)[0m         
[36m(ClientAppActor pid=22440)[0m 
[36m(ClientAppActor pid=22440)[0m         
[36m(ClientAppActor pid=32092)[0m 
[36m(ClientAppActor pid=32092)[0m         
[36m(ClientAppActor pid=32880)[0m 
[36m(ClientAppActor pid=32880)[0m         






[36m(ClientAppActor pid=2360)[0m   [Client 0] ATTACKING: Reversing and scaling updates...


[92mINFO [0m:      aggregate_fit: received 5 results and 0 failures


[36m(ClientAppActor pid=21176)[0m   [Client 1] ATTACKING: Reversing and scaling updates...


  with torch.cuda.amp.autocast(enabled=use_amp):
[92mINFO [0m:      fit progress: (1, nan, {'accuracy': 0.021376829947824326, 'precision': 0.0030538328496891895, 'recall': 0.14285714285714285, 'f1': 0.005979835767069809}, 323.3322797000001)
[92mINFO [0m:      configure_evaluate: strategy sampled 5 clients (out of 5)


Server-side evaluation round 1 / 10: Loss nan | Acc 0.0214


[36m(ClientAppActor pid=21176)[0m 
[36m(ClientAppActor pid=21176)[0m         
[36m(ClientAppActor pid=21176)[0m             This is a deprecated feature. It will be removed[32m [repeated 4x across cluster][0m
[36m(ClientAppActor pid=21176)[0m             entirely in future versions of Flower.[32m [repeated 4x across cluster][0m
[36m(ClientAppActor pid=32880)[0m 
[36m(ClientAppActor pid=32880)[0m         
[36m(ClientAppActor pid=32092)[0m 
[36m(ClientAppActor pid=32092)[0m         
[36m(ClientAppActor pid=22440)[0m 
[36m(ClientAppActor pid=22440)[0m         




[36m(ClientAppActor pid=2360)[0m 
[36m(ClientAppActor pid=2360)[0m         


[36m(ClientAppActor pid=21176)[0m   [Client 4] Local validation: Accuracy=0.0157, Loss=nan


[92mINFO [0m:      aggregate_evaluate: received 5 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 2]
[92mINFO [0m:      configure_fit: strategy sampled 5 clients (out of 5)
[36m(ClientAppActor pid=2360)[0m 
[36m(ClientAppActor pid=2360)[0m         
[36m(ClientAppActor pid=2360)[0m             This is a deprecated feature. It will be removed[32m [repeated 5x across cluster][0m
[36m(ClientAppActor pid=2360)[0m             entirely in future versions of Flower.[32m [repeated 5x across cluster][0m
[36m(ClientAppActor pid=32092)[0m 
[36m(ClientAppActor pid=32092)[0m         




[36m(ClientAppActor pid=22440)[0m 
[36m(ClientAppActor pid=22440)[0m         
[36m(ClientAppActor pid=32880)[0m 
[36m(ClientAppActor pid=32880)[0m         
[36m(ClientAppActor pid=21176)[0m 
[36m(ClientAppActor pid=21176)[0m         


[36m(ClientAppActor pid=32880)[0m   [Client 0] ATTACKING: Reversing and scaling updates...
[36m(ClientAppActor pid=2360)[0m   [Client 0] Local validation: Accuracy=0.0223, Loss=nan[32m [repeated 4x across cluster][0m


[92mINFO [0m:      aggregate_fit: received 5 results and 0 failures
  with torch.cuda.amp.autocast(enabled=use_amp):
[92mINFO [0m:      fit progress: (2, nan, {'accuracy': 0.021376829947824326, 'precision': 0.0030538328496891895, 'recall': 0.14285714285714285, 'f1': 0.005979835767069809}, 633.7107759999926)
[92mINFO [0m:      configure_evaluate: strategy sampled 5 clients (out of 5)
[36m(ClientAppActor pid=21176)[0m 
[36m(ClientAppActor pid=21176)[0m         
[36m(ClientAppActor pid=21176)[0m             This is a deprecated feature. It will be removed[32m [repeated 5x across cluster][0m
[36m(ClientAppActor pid=21176)[0m             entirely in future versions of Flower.[32m [repeated 5x across cluster][0m


Server-side evaluation round 2 / 10: Loss nan | Acc 0.0214


[36m(ClientAppActor pid=2360)[0m 
[36m(ClientAppActor pid=2360)[0m         
[36m(ClientAppActor pid=32880)[0m 
[36m(ClientAppActor pid=32880)[0m         


[36m(ClientAppActor pid=2360)[0m   [Client 1] ATTACKING: Reversing and scaling updates...


[36m(ClientAppActor pid=22440)[0m 
[36m(ClientAppActor pid=22440)[0m         
[36m(ClientAppActor pid=32092)[0m 
[36m(ClientAppActor pid=32092)[0m         


[36m(ClientAppActor pid=2360)[0m   [Client 4] Local validation: Accuracy=0.0213, Loss=nan


[92mINFO [0m:      aggregate_evaluate: received 5 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 3]
[92mINFO [0m:      configure_fit: strategy sampled 5 clients (out of 5)
[36m(ClientAppActor pid=32092)[0m 
[36m(ClientAppActor pid=32092)[0m         
[36m(ClientAppActor pid=32092)[0m             This is a deprecated feature. It will be removed[32m [repeated 5x across cluster][0m
[36m(ClientAppActor pid=32092)[0m             entirely in future versions of Flower.[32m [repeated 5x across cluster][0m
[36m(ClientAppActor pid=32880)[0m 
[36m(ClientAppActor pid=32880)[0m         
[36m(ClientAppActor pid=22440)[0m 
[36m(ClientAppActor pid=22440)[0m         




[36m(ClientAppActor pid=21176)[0m 
[36m(ClientAppActor pid=21176)[0m         
[36m(ClientAppActor pid=2360)[0m 
[36m(ClientAppActor pid=2360)[0m         


[36m(ClientAppActor pid=32880)[0m   [Client 1] ATTACKING: Reversing and scaling updates...
[36m(ClientAppActor pid=32092)[0m   [Client 3] Local validation: Accuracy=0.0218, Loss=nan[32m [repeated 4x across cluster][0m


[92mINFO [0m:      aggregate_fit: received 5 results and 0 failures
  with torch.cuda.amp.autocast(enabled=use_amp):
[92mINFO [0m:      fit progress: (3, nan, {'accuracy': 0.021376829947824326, 'precision': 0.0030538328496891895, 'recall': 0.14285714285714285, 'f1': 0.005979835767069809}, 942.2129147000087)
[92mINFO [0m:      configure_evaluate: strategy sampled 5 clients (out of 5)


Server-side evaluation round 3 / 10: Loss nan | Acc 0.0214


[36m(ClientAppActor pid=2360)[0m 
[36m(ClientAppActor pid=2360)[0m         
[36m(ClientAppActor pid=2360)[0m             This is a deprecated feature. It will be removed[32m [repeated 5x across cluster][0m
[36m(ClientAppActor pid=2360)[0m             entirely in future versions of Flower.[32m [repeated 5x across cluster][0m
[36m(ClientAppActor pid=21176)[0m 
[36m(ClientAppActor pid=21176)[0m         
[36m(ClientAppActor pid=32092)[0m 
[36m(ClientAppActor pid=32092)[0m         
[36m(ClientAppActor pid=32880)[0m 
[36m(ClientAppActor pid=32880)[0m         


[36m(ClientAppActor pid=21176)[0m   [Client 0] ATTACKING: Reversing and scaling updates...


[36m(ClientAppActor pid=22440)[0m 
[36m(ClientAppActor pid=22440)[0m         


[36m(ClientAppActor pid=21176)[0m   [Client 4] Local validation: Accuracy=0.0278, Loss=nan


[92mINFO [0m:      aggregate_evaluate: received 5 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 4]
[92mINFO [0m:      configure_fit: strategy sampled 5 clients (out of 5)
[36m(ClientAppActor pid=22440)[0m 
[36m(ClientAppActor pid=22440)[0m         
[36m(ClientAppActor pid=22440)[0m             This is a deprecated feature. It will be removed[32m [repeated 5x across cluster][0m
[36m(ClientAppActor pid=22440)[0m             entirely in future versions of Flower.[32m [repeated 5x across cluster][0m
[36m(ClientAppActor pid=32880)[0m 
[36m(ClientAppActor pid=32880)[0m         
[36m(ClientAppActor pid=32092)[0m 
[36m(ClientAppActor pid=32092)[0m         
[36m(ClientAppActor pid=2360)[0m 
[36m(ClientAppActor pid=2360)[0m         




[36m(ClientAppActor pid=21176)[0m 
[36m(ClientAppActor pid=21176)[0m         


[36m(ClientAppActor pid=32092)[0m   [Client 0] ATTACKING: Reversing and scaling updates...
[36m(ClientAppActor pid=22440)[0m   [Client 1] Local validation: Accuracy=0.0223, Loss=nan[32m [repeated 4x across cluster][0m


[92mINFO [0m:      aggregate_fit: received 5 results and 0 failures
  with torch.cuda.amp.autocast(enabled=use_amp):
[92mINFO [0m:      fit progress: (4, nan, {'accuracy': 0.021376829947824326, 'precision': 0.0030538328496891895, 'recall': 0.14285714285714285, 'f1': 0.005979835767069809}, 1248.6624122000067)
[92mINFO [0m:      configure_evaluate: strategy sampled 5 clients (out of 5)


Server-side evaluation round 4 / 10: Loss nan | Acc 0.0214


[36m(ClientAppActor pid=2360)[0m 
[36m(ClientAppActor pid=2360)[0m         
[36m(ClientAppActor pid=2360)[0m             This is a deprecated feature. It will be removed[32m [repeated 5x across cluster][0m
[36m(ClientAppActor pid=2360)[0m             entirely in future versions of Flower.[32m [repeated 5x across cluster][0m
[36m(ClientAppActor pid=21176)[0m 
[36m(ClientAppActor pid=21176)[0m         
[36m(ClientAppActor pid=32092)[0m 
[36m(ClientAppActor pid=32092)[0m         
[36m(ClientAppActor pid=32880)[0m 
[36m(ClientAppActor pid=32880)[0m         
[36m(ClientAppActor pid=22440)[0m 
[36m(ClientAppActor pid=22440)[0m         


[36m(ClientAppActor pid=2360)[0m   [Client 1] ATTACKING: Reversing and scaling updates...
[36m(ClientAppActor pid=21176)[0m   [Client 4] Local validation: Accuracy=0.0187, Loss=nan


[92mINFO [0m:      aggregate_evaluate: received 5 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 5]
[92mINFO [0m:      configure_fit: strategy sampled 5 clients (out of 5)
[36m(ClientAppActor pid=22440)[0m 
[36m(ClientAppActor pid=22440)[0m         
[36m(ClientAppActor pid=22440)[0m             This is a deprecated feature. It will be removed[32m [repeated 5x across cluster][0m
[36m(ClientAppActor pid=22440)[0m             entirely in future versions of Flower.[32m [repeated 5x across cluster][0m
[36m(ClientAppActor pid=32880)[0m 
[36m(ClientAppActor pid=32880)[0m         
[36m(ClientAppActor pid=32092)[0m 
[36m(ClientAppActor pid=32092)[0m         
[36m(ClientAppActor pid=2360)[0m 
[36m(ClientAppActor pid=2360)[0m         




[36m(ClientAppActor pid=21176)[0m 
[36m(ClientAppActor pid=21176)[0m         


[36m(ClientAppActor pid=32092)[0m   [Client 0] ATTACKING: Reversing and scaling updates...
[36m(ClientAppActor pid=22440)[0m   [Client 1] Local validation: Accuracy=0.0213, Loss=nan[32m [repeated 4x across cluster][0m


[92mINFO [0m:      aggregate_fit: received 5 results and 0 failures
  with torch.cuda.amp.autocast(enabled=use_amp):
[92mINFO [0m:      fit progress: (5, nan, {'accuracy': 0.021376829947824326, 'precision': 0.0030538328496891895, 'recall': 0.14285714285714285, 'f1': 0.005979835767069809}, 1567.1197182000033)
[92mINFO [0m:      configure_evaluate: strategy sampled 5 clients (out of 5)


Server-side evaluation round 5 / 10: Loss nan | Acc 0.0214


[36m(ClientAppActor pid=21176)[0m 
[36m(ClientAppActor pid=21176)[0m         
[36m(ClientAppActor pid=21176)[0m             This is a deprecated feature. It will be removed[32m [repeated 5x across cluster][0m
[36m(ClientAppActor pid=21176)[0m             entirely in future versions of Flower.[32m [repeated 5x across cluster][0m
[36m(ClientAppActor pid=2360)[0m 
[36m(ClientAppActor pid=2360)[0m         
[36m(ClientAppActor pid=32092)[0m 
[36m(ClientAppActor pid=32092)[0m         
[36m(ClientAppActor pid=32880)[0m 
[36m(ClientAppActor pid=32880)[0m         
[36m(ClientAppActor pid=22440)[0m 
[36m(ClientAppActor pid=22440)[0m         


[36m(ClientAppActor pid=21176)[0m   [Client 1] ATTACKING: Reversing and scaling updates...
[36m(ClientAppActor pid=2360)[0m   [Client 4] Local validation: Accuracy=0.0233, Loss=nan


[92mINFO [0m:      aggregate_evaluate: received 5 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 6]
[92mINFO [0m:      configure_fit: strategy sampled 5 clients (out of 5)
[36m(ClientAppActor pid=22440)[0m 
[36m(ClientAppActor pid=22440)[0m         
[36m(ClientAppActor pid=22440)[0m             This is a deprecated feature. It will be removed[32m [repeated 5x across cluster][0m
[36m(ClientAppActor pid=22440)[0m             entirely in future versions of Flower.[32m [repeated 5x across cluster][0m
[36m(ClientAppActor pid=32880)[0m 
[36m(ClientAppActor pid=32880)[0m         
[36m(ClientAppActor pid=32092)[0m 
[36m(ClientAppActor pid=32092)[0m         
[36m(ClientAppActor pid=21176)[0m 
[36m(ClientAppActor pid=21176)[0m         
[36m(ClientAppActor pid=2360)[0m 
[36m(ClientAppActor pid=2360)[0m         


[36m(ClientAppActor pid=21176)[0m   [Client 1] ATTACKING: Reversing and scaling updates...
[36m(ClientAppActor pid=22440)[0m   [Client 1] Local validation: Accuracy=0.0213, Loss=nan[32m [repeated 4x across cluster][0m


[92mINFO [0m:      aggregate_fit: received 5 results and 0 failures
  with torch.cuda.amp.autocast(enabled=use_amp):
[92mINFO [0m:      fit progress: (6, nan, {'accuracy': 0.021376829947824326, 'precision': 0.0030538328496891895, 'recall': 0.14285714285714285, 'f1': 0.005979835767069809}, 1877.3924300000072)
[92mINFO [0m:      configure_evaluate: strategy sampled 5 clients (out of 5)


Server-side evaluation round 6 / 10: Loss nan | Acc 0.0214


[36m(ClientAppActor pid=2360)[0m 
[36m(ClientAppActor pid=2360)[0m         
[36m(ClientAppActor pid=2360)[0m             This is a deprecated feature. It will be removed[32m [repeated 5x across cluster][0m
[36m(ClientAppActor pid=2360)[0m             entirely in future versions of Flower.[32m [repeated 5x across cluster][0m
[36m(ClientAppActor pid=32092)[0m 
[36m(ClientAppActor pid=32092)[0m         
[36m(ClientAppActor pid=32880)[0m 
[36m(ClientAppActor pid=32880)[0m         
[36m(ClientAppActor pid=21176)[0m 
[36m(ClientAppActor pid=21176)[0m         


[36m(ClientAppActor pid=32092)[0m   [Client 0] ATTACKING: Reversing and scaling updates...


[36m(ClientAppActor pid=22440)[0m 
[36m(ClientAppActor pid=22440)[0m         


[36m(ClientAppActor pid=22440)[0m   [Client 4] Local validation: Accuracy=0.0203, Loss=nan


[92mINFO [0m:      aggregate_evaluate: received 5 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 7]
[92mINFO [0m:      configure_fit: strategy sampled 5 clients (out of 5)




[36m(ClientAppActor pid=21176)[0m 
[36m(ClientAppActor pid=21176)[0m         
[36m(ClientAppActor pid=21176)[0m             This is a deprecated feature. It will be removed[32m [repeated 5x across cluster][0m
[36m(ClientAppActor pid=21176)[0m             entirely in future versions of Flower.[32m [repeated 5x across cluster][0m
[36m(ClientAppActor pid=32880)[0m 
[36m(ClientAppActor pid=32880)[0m         
[36m(ClientAppActor pid=2360)[0m 
[36m(ClientAppActor pid=2360)[0m         




[36m(ClientAppActor pid=32092)[0m 
[36m(ClientAppActor pid=32092)[0m         
[36m(ClientAppActor pid=22440)[0m 
[36m(ClientAppActor pid=22440)[0m         


[36m(ClientAppActor pid=21176)[0m   [Client 0] ATTACKING: Reversing and scaling updates...
[36m(ClientAppActor pid=21176)[0m   [Client 0] Local validation: Accuracy=0.0284, Loss=nan[32m [repeated 4x across cluster][0m


[92mINFO [0m:      aggregate_fit: received 5 results and 0 failures
  with torch.cuda.amp.autocast(enabled=use_amp):
[92mINFO [0m:      fit progress: (7, nan, {'accuracy': 0.021376829947824326, 'precision': 0.0030538328496891895, 'recall': 0.14285714285714285, 'f1': 0.005979835767069809}, 2173.885219699994)
[92mINFO [0m:      configure_evaluate: strategy sampled 5 clients (out of 5)


Server-side evaluation round 7 / 10: Loss nan | Acc 0.0214


[36m(ClientAppActor pid=2360)[0m 
[36m(ClientAppActor pid=2360)[0m         
[36m(ClientAppActor pid=2360)[0m             This is a deprecated feature. It will be removed[32m [repeated 5x across cluster][0m
[36m(ClientAppActor pid=2360)[0m             entirely in future versions of Flower.[32m [repeated 5x across cluster][0m
[36m(ClientAppActor pid=32092)[0m 
[36m(ClientAppActor pid=32092)[0m         
[36m(ClientAppActor pid=32880)[0m 
[36m(ClientAppActor pid=32880)[0m         
[36m(ClientAppActor pid=21176)[0m 
[36m(ClientAppActor pid=21176)[0m         


[36m(ClientAppActor pid=32880)[0m   [Client 1] ATTACKING: Reversing and scaling updates...


[36m(ClientAppActor pid=22440)[0m 
[36m(ClientAppActor pid=22440)[0m         


[36m(ClientAppActor pid=2360)[0m   [Client 4] Local validation: Accuracy=0.0187, Loss=nan


[92mINFO [0m:      aggregate_evaluate: received 5 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 8]
[92mINFO [0m:      configure_fit: strategy sampled 5 clients (out of 5)




[36m(ClientAppActor pid=22440)[0m 
[36m(ClientAppActor pid=22440)[0m         
[36m(ClientAppActor pid=22440)[0m             This is a deprecated feature. It will be removed[32m [repeated 5x across cluster][0m
[36m(ClientAppActor pid=22440)[0m             entirely in future versions of Flower.[32m [repeated 5x across cluster][0m
[36m(ClientAppActor pid=21176)[0m 
[36m(ClientAppActor pid=21176)[0m         




[36m(ClientAppActor pid=32880)[0m 
[36m(ClientAppActor pid=32880)[0m         
[36m(ClientAppActor pid=32092)[0m 
[36m(ClientAppActor pid=32092)[0m         
[36m(ClientAppActor pid=2360)[0m 
[36m(ClientAppActor pid=2360)[0m         


[36m(ClientAppActor pid=22440)[0m   [Client 0] ATTACKING: Reversing and scaling updates...
[36m(ClientAppActor pid=22440)[0m   [Client 0] Local validation: Accuracy=0.0208, Loss=nan[32m [repeated 4x across cluster][0m


[92mINFO [0m:      aggregate_fit: received 5 results and 0 failures
  with torch.cuda.amp.autocast(enabled=use_amp):
[92mINFO [0m:      fit progress: (8, nan, {'accuracy': 0.021376829947824326, 'precision': 0.0030538328496891895, 'recall': 0.14285714285714285, 'f1': 0.005979835767069809}, 2469.096992100007)
[92mINFO [0m:      configure_evaluate: strategy sampled 5 clients (out of 5)


Server-side evaluation round 8 / 10: Loss nan | Acc 0.0214


[36m(ClientAppActor pid=32880)[0m 
[36m(ClientAppActor pid=32880)[0m         
[36m(ClientAppActor pid=32880)[0m             This is a deprecated feature. It will be removed[32m [repeated 5x across cluster][0m
[36m(ClientAppActor pid=32880)[0m             entirely in future versions of Flower.[32m [repeated 5x across cluster][0m
[36m(ClientAppActor pid=32092)[0m 
[36m(ClientAppActor pid=32092)[0m         


[36m(ClientAppActor pid=21176)[0m   [Client 1] ATTACKING: Reversing and scaling updates...


[36m(ClientAppActor pid=2360)[0m 
[36m(ClientAppActor pid=2360)[0m         
[36m(ClientAppActor pid=21176)[0m 
[36m(ClientAppActor pid=21176)[0m         
[36m(ClientAppActor pid=22440)[0m 
[36m(ClientAppActor pid=22440)[0m         


[36m(ClientAppActor pid=2360)[0m   [Client 4] Local validation: Accuracy=0.0172, Loss=nan


[92mINFO [0m:      aggregate_evaluate: received 5 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 9]
[92mINFO [0m:      configure_fit: strategy sampled 5 clients (out of 5)
[36m(ClientAppActor pid=22440)[0m 
[36m(ClientAppActor pid=22440)[0m         
[36m(ClientAppActor pid=22440)[0m             This is a deprecated feature. It will be removed[32m [repeated 5x across cluster][0m
[36m(ClientAppActor pid=22440)[0m             entirely in future versions of Flower.[32m [repeated 5x across cluster][0m
[36m(ClientAppActor pid=21176)[0m 
[36m(ClientAppActor pid=21176)[0m         
[36m(ClientAppActor pid=32880)[0m 
[36m(ClientAppActor pid=32880)[0m         
[36m(ClientAppActor pid=32092)[0m 
[36m(ClientAppActor pid=32092)[0m         
[36m(ClientAppActor pid=2360)[0m 
[36m(ClientAppActor pid=2360)[0m         


[36m(ClientAppActor pid=21176)[0m   [Client 0] ATTACKING: Reversing and scaling updates...
[36m(ClientAppActor pid=22440)[0m   [Client 1] Local validation: Accuracy=0.0223, Loss=nan[32m [repeated 4x across cluster][0m


[92mINFO [0m:      aggregate_fit: received 5 results and 0 failures
  with torch.cuda.amp.autocast(enabled=use_amp):
[92mINFO [0m:      fit progress: (9, nan, {'accuracy': 0.021376829947824326, 'precision': 0.0030538328496891895, 'recall': 0.14285714285714285, 'f1': 0.005979835767069809}, 2766.9072826000047)
[92mINFO [0m:      configure_evaluate: strategy sampled 5 clients (out of 5)


Server-side evaluation round 9 / 10: Loss nan | Acc 0.0214


[36m(ClientAppActor pid=32092)[0m 
[36m(ClientAppActor pid=32092)[0m         
[36m(ClientAppActor pid=32092)[0m             This is a deprecated feature. It will be removed[32m [repeated 5x across cluster][0m
[36m(ClientAppActor pid=32092)[0m             entirely in future versions of Flower.[32m [repeated 5x across cluster][0m
[36m(ClientAppActor pid=2360)[0m 
[36m(ClientAppActor pid=2360)[0m         
[36m(ClientAppActor pid=32880)[0m 
[36m(ClientAppActor pid=32880)[0m         
[36m(ClientAppActor pid=21176)[0m 
[36m(ClientAppActor pid=21176)[0m         


[36m(ClientAppActor pid=2360)[0m   [Client 1] ATTACKING: Reversing and scaling updates...


[36m(ClientAppActor pid=22440)[0m 
[36m(ClientAppActor pid=22440)[0m         


[36m(ClientAppActor pid=32880)[0m   [Client 4] Local validation: Accuracy=0.0203, Loss=nan


[92mINFO [0m:      aggregate_evaluate: received 5 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 10]
[92mINFO [0m:      configure_fit: strategy sampled 5 clients (out of 5)




[36m(ClientAppActor pid=21176)[0m 
[36m(ClientAppActor pid=21176)[0m         
[36m(ClientAppActor pid=21176)[0m             This is a deprecated feature. It will be removed[32m [repeated 5x across cluster][0m
[36m(ClientAppActor pid=21176)[0m             entirely in future versions of Flower.[32m [repeated 5x across cluster][0m
[36m(ClientAppActor pid=22440)[0m 
[36m(ClientAppActor pid=22440)[0m         




[36m(ClientAppActor pid=2360)[0m 
[36m(ClientAppActor pid=2360)[0m         
[36m(ClientAppActor pid=32092)[0m 
[36m(ClientAppActor pid=32092)[0m         
[36m(ClientAppActor pid=32880)[0m 
[36m(ClientAppActor pid=32880)[0m         


[36m(ClientAppActor pid=22440)[0m   [Client 1] ATTACKING: Reversing and scaling updates...
[36m(ClientAppActor pid=22440)[0m   [Client 0] Local validation: Accuracy=0.0213, Loss=nan[32m [repeated 4x across cluster][0m


[92mINFO [0m:      aggregate_fit: received 5 results and 0 failures
  with torch.cuda.amp.autocast(enabled=use_amp):


Server-side evaluation round 10 / 10: Loss nan | Acc 0.0214


[92mINFO [0m:      fit progress: (10, nan, {'accuracy': 0.021376829947824326, 'precision': 0.0030538328496891895, 'recall': 0.14285714285714285, 'f1': 0.005979835767069809}, 3057.4336729999923)
[92mINFO [0m:      configure_evaluate: strategy sampled 5 clients (out of 5)


Final global model saved to ./fl_output_notebook\global_model_final.pth


[36m(ClientAppActor pid=2360)[0m 
[36m(ClientAppActor pid=2360)[0m         
[36m(ClientAppActor pid=2360)[0m             This is a deprecated feature. It will be removed[32m [repeated 5x across cluster][0m
[36m(ClientAppActor pid=2360)[0m             entirely in future versions of Flower.[32m [repeated 5x across cluster][0m
[36m(ClientAppActor pid=32880)[0m 
[36m(ClientAppActor pid=32880)[0m         
[36m(ClientAppActor pid=21176)[0m 
[36m(ClientAppActor pid=21176)[0m         
[36m(ClientAppActor pid=22440)[0m 
[36m(ClientAppActor pid=22440)[0m         


[36m(ClientAppActor pid=21176)[0m   [Client 0] ATTACKING: Reversing and scaling updates...


[36m(ClientAppActor pid=32092)[0m 
[36m(ClientAppActor pid=32092)[0m         


[36m(ClientAppActor pid=32092)[0m   [Client 4] Local validation: Accuracy=0.0218, Loss=nan


[92mINFO [0m:      aggregate_evaluate: received 5 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [SUMMARY]
[92mINFO [0m:      Run finished 10 round(s) in 3063.42s
[92mINFO [0m:      	History (loss, distributed):
[92mINFO [0m:      		round 1: nan
[92mINFO [0m:      		round 2: nan
[92mINFO [0m:      		round 3: nan
[92mINFO [0m:      		round 4: nan
[92mINFO [0m:      		round 5: nan
[92mINFO [0m:      		round 6: nan
[92mINFO [0m:      		round 7: nan
[92mINFO [0m:      		round 8: nan
[92mINFO [0m:      		round 9: nan
[92mINFO [0m:      		round 10: nan
[92mINFO [0m:      	History (loss, centralized):
[92mINFO [0m:      		round 0: 2.017263355301504
[92mINFO [0m:      		round 1: nan
[92mINFO [0m:      		round 2: nan
[92mINFO [0m:      		round 3: nan
[92mINFO [0m:      		round 4: nan
[92mINFO [0m:      		round 5: nan
[92mINFO [0m:      		round 6: nan
[92mINFO [0m:      		round 7: nan
[92mINFO [0m:      		round 8: nan
[92mINFO [0



[92mINFO [0m:      	            (1, 0.14285714285714285),
[92mINFO [0m:      	            (2, 0.14285714285714285),
[92mINFO [0m:      	            (3, 0.14285714285714285),
[92mINFO [0m:      	            (4, 0.14285714285714285),
[92mINFO [0m:      	            (5, 0.14285714285714285),
[92mINFO [0m:      	            (6, 0.14285714285714285),
[92mINFO [0m:      	            (7, 0.14285714285714285),
[92mINFO [0m:      	            (8, 0.14285714285714285),
[92mINFO [0m:      	            (9, 0.14285714285714285),
[92mINFO [0m:      	            (10, 0.14285714285714285)]}
[92mINFO [0m:      


Simulation finished.


In [5]:
# ==============================================================================
# === Cell 5: Evaluate the Final Model ===
# ==============================================================================
# (This cell loads the saved model and runs a full evaluation)
# (No changes needed)

# --- Configuration for evaluation ---
MODEL_PATH = os.path.join(OUT_DIR, "global_model_best.pth") # Load best model
EVAL_H5_PATH = CLIENT_DATA_PATHS["4"] 
# --------------------------------------


print(f"\n--- Starting Final Evaluation ---")
print(f"Loading best model from: {MODEL_PATH}")
print(f"Evaluating on data from: {EVAL_H5_PATH}")

# 1. Initialize model
final_model = MultiModalNet(num_classes=NUM_CLASSES)

# 2. Load saved weights
try:
    final_model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
    final_model.to(DEVICE)

    # 3. Run evaluation and plotting
    plot_confusion_and_report(
        model=final_model, 
        test_h5=EVAL_H5_PATH, 
        downsample=DOWNSAMPLE, 
        device=DEVICE, 
        out_dir=OUT_DIR
    )
    print(f"\nEvaluation complete. Plots saved to {OUT_DIR}")

except FileNotFoundError:
    print(f"Error: Model file not found at {MODEL_PATH}.")
    print("Please ensure the simulation in 'Cell 4' ran successfully and saved a model.")
except Exception as e:
    print(f"An error occurred during evaluation: {e}")

# --- Plot Training History ---
if 'history' in locals():
    print("\nPlotting training history...")
    # Extract local loss
    local_loss_hist = []
    rounds = history.metrics_distributed["fit"].keys()
    
    if rounds:
        for r in rounds:
            round_losses = [metrics.get("local_loss", 0.0) for cid, metrics in history.metrics_distributed["fit"][r]]
            if round_losses:
                avg_loss = np.mean(round_losses)
                local_loss_hist.append(avg_loss)
            else:
                local_loss_hist.append(np.nan)

        # Extract global accuracy
        if "evaluate" in history.metrics_centralized and history.metrics_centralized["evaluate"]:
            global_acc_hist = [metrics["accuracy"] for r, metrics in history.metrics_centralized["evaluate"]]
            
            plt.figure(figsize=(10,4))
            plt.subplot(1,2,1)
            plt.plot(rounds, local_loss_hist, marker='o')
            plt.title("Avg Client Loss per Round")
            plt.xlabel("Round"); plt.ylabel("Loss"); plt.grid(True)

            plt.subplot(1,2,2)
            plt.plot(rounds, global_acc_hist, marker='o', color='green')
            plt.title("Global Test Accuracy per Round")
            plt.xlabel("Round"); plt.ylabel("Accuracy"); plt.grid(True)
            plt.tight_layout()
            plt.show()
        else:
            print("No centralized evaluation metrics found in history.")
    else:
        print("No fit metrics found in history.")
else:
    print("No 'history' object found. Skipping training plots.")


--- Starting Final Evaluation ---
Loading best model from: ./fl_output_notebook\global_model_best.pth
Evaluating on data from: C:\\Users\\my pc\\Desktop\\UAV authentication using federated learning\\data\\client4.h5


  with torch.cuda.amp.autocast(enabled=use_amp):



--- Final Global Model Evaluation ---
Test Loss: 2.0173
Test Accuracy: 0.0865

Classification report:
               precision    recall  f1-score   support

           0     0.0000    0.0000    0.0000       422
           1     0.0734    1.0000    0.1367      1428
           2     0.0000    0.0000    0.0000       690
           3     0.0000    0.0000    0.0000      1281
           4     1.0000    0.0263    0.0513     10593
           5     0.0000    0.0000    0.0000      3246
           6     0.0000    0.0000    0.0000      2081

    accuracy                         0.0865     19741
   macro avg     0.1533    0.1466    0.0269     19741
weighted avg     0.5419    0.0865    0.0374     19741

Confusion matrix saved to ./fl_output_notebook\confusion_matrix.png
SNR vs. Accuracy plot saved to ./fl_output_notebook\snr_vs_accuracy.png

Evaluation complete. Plots saved to ./fl_output_notebook

Plotting training history...


KeyError: 'fit'