In [None]:
!pip install -q flwr torch torchvision tensorboard
!pip install -U "flwr[simulation]"
!pip install timm

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m540.0/540.0 kB[0m [31m9.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m4.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m100.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m94.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m58.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m2.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m6.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m11.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [None]:
from datetime import datetime
from functools import partial
from flwr.server import ServerConfig
import random
import numpy as np
import torch.nn as nn
from torchvision import transforms
import timm
import torch
import os
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import random_split, DataLoader
from torch.utils.tensorboard import SummaryWriter
import flwr as fl
from flwr.simulation import run_simulation
from flwr.common import Context
from flwr.server import ServerApp, ServerAppComponents, ServerConfig
from flwr.client import ClientApp
from flwr.common import parameters_to_ndarrays
from collections import OrderedDict
from wandb_logger import FederatedWandBLogger
import data_utils
import clients
import strategies
import data_preprocessing

In [None]:
def build_optimizer_config(optimizer_type):
    """Assemble the keyword arguments for the requested optimizer.

    Values are taken from the notebook-level hyperparameter constants
    (``LR``, ``MOMENTUM``, ``WEIGHT_DECAY``, ...). They are looked up when the
    function is called, so the settings cell must have run beforehand.

    Args:
        optimizer_type: A member of ``clients.OptimizerType``.

    Returns:
        dict: Hyperparameter names mapped to their configured values.

    Raises:
        ValueError: If ``optimizer_type`` matches none of the handled members.
    """
    kinds = clients.OptimizerType

    if optimizer_type == kinds.SGD:
        config = dict(lr=LR, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY, nesterov=NESTEROV)
    elif optimizer_type == kinds.SSGD:
        # Same as SGD but without the nesterov switch.
        config = dict(lr=LR, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)
    elif optimizer_type == kinds.ADAM:
        config = dict(lr=LR, betas=BETAS, weight_decay=WEIGHT_DECAY, eps=EPSILON)
    elif optimizer_type == kinds.ADAMW:
        config = dict(lr=LR, betas=BETAS, weight_decay=WEIGHT_DECAY)
    elif optimizer_type == kinds.RMSPROP:
        config = dict(
            lr=LR,
            alpha=ALPHA,
            eps=EPSILON,
            weight_decay=WEIGHT_DECAY,
            momentum=MOMENTUM,
            centered=CENTERED,
        )
    elif optimizer_type == kinds.ADAGRAD:
        config = dict(lr=LR, weight_decay=WEIGHT_DECAY, eps=EPSILON)
    else:
        raise ValueError(f"Unsupported optimizer type: {optimizer_type}")

    print(f"Optimizer '{optimizer_type}' initialized successfully.")
    return config

In [None]:
def build_scheduler_config(scheduler_type):
    """Assemble the keyword arguments for the requested LR scheduler.

    Values are taken from the notebook-level scheduler constants
    (``T_MAX``, ``ETA_MIN``, ``GAMMA``, ...). They are looked up when the
    function is called, so the settings cell must have run beforehand.

    Args:
        scheduler_type: A member of ``clients.SchedulerType``.

    Returns:
        dict: Scheduler argument names mapped to their configured values.

    Raises:
        ValueError: If ``scheduler_type`` matches none of the handled members.
    """
    kinds = clients.SchedulerType

    if scheduler_type == kinds.COSINE:
        config = dict(T_max=T_MAX, eta_min=ETA_MIN)
    elif scheduler_type == kinds.COSINE_RESTART:
        # The restart period reuses T_MAX; the period multiplier is fixed to 1.
        config = dict(T_0=T_MAX, T_mult=1, eta_min=ETA_MIN)
    elif scheduler_type == kinds.STEP:
        config = dict(step_size=STEP_SIZE, gamma=GAMMA)
    elif scheduler_type == kinds.MULTISTEP:
        config = dict(milestones=MILESTONES, gamma=GAMMA)
    elif scheduler_type == kinds.EXPONENTIAL:
        config = dict(gamma=GAMMA)
    elif scheduler_type == kinds.REDUCE_ON_PLATEAU:
        config = dict(mode="min", factor=FACTOR, patience=PATIENCE, threshold=THRESHOLD)
    elif scheduler_type == kinds.CONSTANT:
        config = dict(factor=1.0, total_iters=TOTAL_ITERS)
    elif scheduler_type == kinds.LINEAR:
        # The start factor is hard-coded to 0.5; only the end factor is configurable.
        config = dict(start_factor=0.5, end_factor=LR_END, total_iters=TOTAL_ITERS)
    else:
        raise ValueError(f"Unsupported scheduler type: {scheduler_type}")

    print(f"Scheduler '{scheduler_type}' initialized successfully.")
    return config

In [None]:
def seed_everything(seed=42):
    """Seed the random number sources used in this notebook.

    Seeds Python's ``random``, NumPy's global generator and PyTorch (default
    generator and all CUDA devices), stores the seed in the ``PYTHONHASHSEED``
    environment variable, and sets cuDNN to deterministic, non-benchmark mode.

    Args:
        seed: Integer seed applied to every source. Defaults to 42.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    os.environ["PYTHONHASHSEED"] = str(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


seed_everything(42)

In [None]:
# All tunable training settings for this notebook live in this cell.
# It must run after the cells defining build_optimizer_config / build_scheduler_config,
# because both are called at the bottom and read the constants defined here.

# GENERAL
BATCH_SIZE = 256  # used for the global DataLoaders and forwarded to the client factory
EPOCHS = 15  # NOTE(review): not referenced by any other cell visible here
VAL_SPLIT = 0.1  # forwarded to data_preprocessing.CIFAR100Pipeline(val_split=...)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BACKBONE_FREEZING = True  # default of create_dino_vit_s16_for_cifar100(freezing=...)

# FL
NUM_CLIENTS = 5  # number of partitions and of simulated supernodes
NC = 20  # third argument of data_utils.non_iid_split; presumably classes per client, TODO confirm
LOCAL_EPOCHS = 5  # forwarded to the client factory as local_epochs
NUM_ROUNDS = 3  # ServerConfig(num_rounds=...)
FRACTION_FIT = 1  # also reused as fraction_evaluate and to derive the min_*_clients values
IID = True  # selects the IID partitions (use_iid) and the run-name tag

# LOSS
SMOOTHING=0.1  # NOTE(review): not referenced by any other cell visible here

# Optimizer Hyperparameters
# Each optimizer type picks the subset it needs in build_optimizer_config.
LR = 0.01
MOMENTUM = 0.9
WEIGHT_DECAY = 0.0005
NESTEROV = False
BETAS = (0.9, 0.999)
EPSILON = 1e-8
ALPHA = 0.99
CENTERED = False
# do not change this
OPTIMIZER_TYPE  = clients.OptimizerType.SSGD

# Scheduler Hyperparameters
# Each scheduler type picks the subset it needs in build_scheduler_config.
T_MAX = 100  # also used as T_0 for the COSINE_RESTART variant
ETA_MIN = 0.0
STEP_SIZE = 30
GAMMA = 0.1
MILESTONES = [50, 75]
FACTOR = 0.1
PATIENCE = 5
THRESHOLD = 1e-4
TOTAL_ITERS = 10
LR_END = 0.0  # used as "end_factor" for the LINEAR variant
SCHEDULER_TYPE  = clients.SchedulerType.COSINE

# Forwarded unchanged to clients.build_client_talos_fn(talos_config=...).
# The meaning of the keys is defined in the project-local `clients` module (not visible here).
TALOS_CONFIG = {
    "final_sparsity": 0.3,
    "num_batches": 3,
    "rounds": 2
}

# Materialise the kwargs dicts for the selected optimizer / scheduler.
# Both calls print a confirmation line.
OPTIMIZER_CONFIG = build_optimizer_config(OPTIMIZER_TYPE)
SCHEDULER_CONFIG = build_scheduler_config(SCHEDULER_TYPE)

Optimizer 'OptimizerType.SSGD' initialized successfully.
Scheduler 'SchedulerType.COSINE' initialized successfully.


In [None]:
# Build the datasets with the project-local pipeline; run_pipeline() returns three
# datasets that are unpacked as train / validation / test.
pipeline = data_preprocessing.CIFAR100Pipeline(val_split=VAL_SPLIT, use_augment=True)
trainset, valset, testset = pipeline.run_pipeline()

# Global (non-federated) loaders. Among the cells visible here only `testloader`
# is consumed later (final evaluation); only the training loader shuffles.
trainloader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)
valloader = DataLoader(valset, batch_size=BATCH_SIZE)
testloader = DataLoader(testset, batch_size=BATCH_SIZE)

# Apply sharding
# Both partitionings are always created; the IID flag decides later which one the clients use.
client_datasets_iid = data_utils.iid_split(trainset, NUM_CLIENTS)
# NOTE(review): the literal 100 is presumably the number of CIFAR-100 classes; confirm against data_utils.
client_datasets_noniid = data_utils.non_iid_split(trainset, NUM_CLIENTS, NC, 100)

print("Setup complete. IID and Non-IID splits created.")

100%|██████████| 169M/169M [00:05<00:00, 31.5MB/s]


Setup complete. IID and Non-IID splits created.


In [None]:
# Create model
def create_dino_vit_s16_for_cifar100(freezing=BACKBONE_FREEZING):
    """Build the pretrained DINO ViT-S/16 with a fresh 100-way linear head.

    Args:
        freezing: When truthy, gradients are disabled for every parameter
            except those of the new head. Defaults to the value
            ``BACKBONE_FREEZING`` had when this function was defined.

    Returns:
        The timm model whose ``head`` attribute is
        ``nn.Linear(model.num_features, 100)``.
    """
    # num_classes=0 requests the network without timm's own classifier layer.
    vit = timm.create_model("vit_small_patch16_224_dino", pretrained=True, num_classes=0)

    # Attach the CIFAR-100 classification head.
    vit.head = nn.Linear(vit.num_features, 100)

    if not freezing:
        return vit

    # Turn gradients off everywhere, then back on for the head only.
    vit.requires_grad_(False)
    vit.head.requires_grad_(True)
    return vit

In [None]:
!wandb login f8ad3703c9023ee3f86c7242a87d9280b6c031fb

[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin


In [None]:
# Tag the run with the data distribution in use and the launch time.
data_distribution = "IID_DATA" if IID else "NON_IID_DATA"
_launch_stamp = datetime.now().strftime('%Y%m%d-%H%M%S')
run_name = f"FEDERATED_{data_distribution}_TALOS_{_launch_stamp}"

# Build the per-client factory first, then wrap it in the Flower ClientApp.
# Both partitionings are handed over; `use_iid` tells the factory which one to use.
_talos_client_fn = clients.build_client_talos_fn(
    # data
    use_iid=IID,
    iid_partitions=client_datasets_iid,
    non_iid_partitions=client_datasets_noniid,
    valset=valset,
    batch_size=BATCH_SIZE,
    # model
    model_fn=create_dino_vit_s16_for_cifar100,
    device=DEVICE,
    # optimisation
    optimizer_config=OPTIMIZER_CONFIG,
    scheduler_type=SCHEDULER_TYPE,
    scheduler_config=SCHEDULER_CONFIG,
    local_epochs=LOCAL_EPOCHS,
    talos_config=TALOS_CONFIG,
)

client_app = ClientApp(client_fn=_talos_client_fn)

# Snapshot of the run settings that is attached to the W&B run.
_wandb_run_config = {
    # Federated learning setup
    "use_iid": IID,
    "local_epochs": LOCAL_EPOCHS,
    "batch_size": BATCH_SIZE,
    "num_clients": NUM_CLIENTS,
    "fraction_fit": FRACTION_FIT,
    # The evaluation fraction deliberately mirrors the fit fraction.
    "fraction_evaluate": FRACTION_FIT,
    "num_rounds": NUM_ROUNDS,
    "backbone_freezing": BACKBONE_FREEZING,
    # Optimizer
    "optimizer": OPTIMIZER_TYPE.value,
    "optimizer_config": OPTIMIZER_CONFIG,
    # Scheduler
    "scheduler": SCHEDULER_TYPE.value,
    "scheduler_config": SCHEDULER_CONFIG,
    # TaLoS
    "talos_config": TALOS_CONFIG,
}

logger = FederatedWandBLogger(
    project_name="federated-learning-project",
    run_name=run_name,
    global_config=_wandb_run_config,
)

def _weighted_val_accuracy(metrics):
    """Average the clients' ``val_accuracy`` weighted by their example counts.

    Args:
        metrics: Iterable of ``(num_examples, metrics_dict)`` pairs, where each
            ``metrics_dict`` holds a ``"val_accuracy"`` entry.

    Returns:
        dict: ``{"accuracy": <weighted mean>}``.
    """
    weighted_sum = sum(num * m["val_accuracy"] for num, m in metrics)
    example_count = sum(num for num, _ in metrics)
    return {"accuracy": weighted_sum / example_count}


# Minimum number of clients per fit / evaluate round, derived from the sampled fraction.
_min_round_clients = int(FRACTION_FIT*NUM_CLIENTS)

strategy = strategies.FedAvgStandard(
    logger=logger,
    # fit
    fraction_fit=FRACTION_FIT,
    min_fit_clients=_min_round_clients,
    # evaluate (mirrors the fit settings)
    fraction_evaluate=FRACTION_FIT,
    min_evaluate_clients=_min_round_clients,
    min_available_clients=NUM_CLIENTS,
    evaluate_metrics_aggregation_fn=_weighted_val_accuracy,
)

def server_fn(context: Context) -> ServerAppComponents:
    """Assemble the components that drive the ServerApp.

    Args:
        context: Run context supplied by Flower. It is not consulted here;
            the strategy and round count come from notebook-level names.

    Returns:
        ServerAppComponents holding the notebook-level ``strategy`` and a
        ``ServerConfig`` set to ``NUM_ROUNDS`` rounds.
    """
    return ServerAppComponents(
        strategy=strategy,
        config=ServerConfig(num_rounds=NUM_ROUNDS),
    )


server_app = ServerApp(server_fn=server_fn)

  self.scope.user = {"email": email}  # noqa


In [None]:
# Resources granted to each simulated client: always one CPU, plus one full
# GPU when CUDA is available (otherwise no GPU at all).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

backend_config = {
    "client_resources": {
        "num_cpus": 1,
        "num_gpus": 1.0 if device.type == "cuda" else 0.0,
    }
}

In [None]:
import importlib

# Re-execute the project-local `clients` module, presumably so that edits to
# clients.py are picked up without restarting the kernel.
# NOTE(review): objects built from `clients` in earlier cells (OPTIMIZER_TYPE,
# SCHEDULER_TYPE, client_app) still reference the pre-reload module objects;
# re-run those cells after reloading, or drop this cell for a clean Run All.
importlib.reload(clients)

<module 'clients' from '/content/clients.py'>

In [None]:
# Launch the Flower simulation: one supernode per client, with the server and client
# apps built in the cells above and the per-client resources from `backend_config`.
run_simulation(
    server_app=server_app,
    client_app=client_app,
    num_supernodes=NUM_CLIENTS,
    backend_config=backend_config
)

[92mINFO [0m:      Starting Flower ServerApp, config: num_rounds=3, no round_timeout
[92mINFO [0m:      
[92mINFO [0m:      [INIT]
[92mINFO [0m:      Requesting initial parameters from one random client
[36m(pid=3165)[0m 2025-05-11 15:45:31.574259: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
[36m(pid=3165)[0m E0000 00:00:1746978331.597226    3165 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
[36m(pid=3165)[0m E0000 00:00:1746978331.603611    3165 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


[36m(ClientAppActor pid=3165)[0m LOG: Initializing client with CID=0
[36m(ClientAppActor pid=3165)[0m 0-LOG: Data partition assigned to client 0 -> IID


[36m(ClientAppActor pid=3165)[0m   model = create_fn(


[36m(ClientAppActor pid=3165)[0m 0-LOG: Model initialized for client 0
[36m(ClientAppActor pid=3165)[0m 0-LOG: Dataloaders initialized for client 0
[36m(ClientAppActor pid=3165)[0m 🟢 Found model head. Pruning will be applied to the head only.
[36m(ClientAppActor pid=3165)[0m 0-LOG: Starting TaLoS Mask Calibration
[36m(ClientAppActor pid=3165)[0m 🔎 Starting multi-round calibration for the model head only.
[36m(ClientAppActor pid=3165)[0m 🌀 Calibration Round 1/2


[36m(ClientAppActor pid=3165)[0m   self.scaler = GradScaler()
[36m(ClientAppActor pid=3165)[0m   with autocast():


[36m(ClientAppActor pid=3165)[0m ✅ Mask updated with 15.00% sparsity
[36m(ClientAppActor pid=3165)[0m 🌀 Calibration Round 2/2
[36m(ClientAppActor pid=3165)[0m ✅ Mask updated with 30.00% sparsity


[92mINFO [0m:      Received initial parameters from one random client
[92mINFO [0m:      Starting evaluation of initial global parameters
[92mINFO [0m:      Evaluation returned no results (`None`)
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 1]
[92mINFO [0m:      configure_fit: strategy sampled 5 clients (out of 5)


[36m(ClientAppActor pid=3165)[0m LOG: Initializing client with CID=0
[36m(ClientAppActor pid=3165)[0m 0-LOG: Data partition assigned to client 0 -> IID
[36m(ClientAppActor pid=3165)[0m 0-LOG: Model initialized for client 0
[36m(ClientAppActor pid=3165)[0m 0-LOG: Dataloaders initialized for client 0
[36m(ClientAppActor pid=3165)[0m 🟢 Found model head. Pruning will be applied to the head only.
[36m(ClientAppActor pid=3165)[0m 0-LOG: Starting TaLoS Mask Calibration
[36m(ClientAppActor pid=3165)[0m 🔎 Starting multi-round calibration for the model head only.
[36m(ClientAppActor pid=3165)[0m 🌀 Calibration Round 1/2
[36m(ClientAppActor pid=3165)[0m ✅ Mask updated with 15.00% sparsity
[36m(ClientAppActor pid=3165)[0m 🌀 Calibration Round 2/2
[36m(ClientAppActor pid=3165)[0m ✅ Mask updated with 30.00% sparsity
[36m(ClientAppActor pid=3165)[0m 0-LOG: Starting local training round with TaLoS
[36m(ClientAppActor pid=3165)[0m 0-LOG: Starting epoch 1/5


[36m(ClientAppActor pid=3165)[0m   with autocast():


[36m(ClientAppActor pid=3165)[0m 0-LOG: Starting epoch 2/5
[36m(ClientAppActor pid=3165)[0m 0-LOG: Starting epoch 3/5
[36m(ClientAppActor pid=3165)[0m 0-LOG: Starting epoch 4/5
[36m(ClientAppActor pid=3165)[0m 0-LOG: Starting epoch 5/5
[36m(ClientAppActor pid=3165)[0m 0-LOG: Completed local training - Loss: 1.5825 | Accuracy: 0.6316
[36m(ClientAppActor pid=3165)[0m LOG: Initializing client with CID=1
[36m(ClientAppActor pid=3165)[0m 1-LOG: Data partition assigned to client 1 -> IID
[36m(ClientAppActor pid=3165)[0m 1-LOG: Model initialized for client 1
[36m(ClientAppActor pid=3165)[0m 1-LOG: Dataloaders initialized for client 1
[36m(ClientAppActor pid=3165)[0m 🟢 Found model head. Pruning will be applied to the head only.
[36m(ClientAppActor pid=3165)[0m 1-LOG: Starting TaLoS Mask Calibration
[36m(ClientAppActor pid=3165)[0m 🔎 Starting multi-round calibration for the model head only.
[36m(ClientAppActor pid=3165)[0m 🌀 Calibration Round 1/2
[36m(ClientAppActor p

[92mINFO [0m:      aggregate_fit: received 5 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 5 clients (out of 5)


[36m(ClientAppActor pid=3165)[0m LOG: Initializing client with CID=2
[36m(ClientAppActor pid=3165)[0m 2-LOG: Data partition assigned to client 2 -> IID
[36m(ClientAppActor pid=3165)[0m 2-LOG: Model initialized for client 2
[36m(ClientAppActor pid=3165)[0m 2-LOG: Dataloaders initialized for client 2
[36m(ClientAppActor pid=3165)[0m 🟢 Found model head. Pruning will be applied to the head only.
[36m(ClientAppActor pid=3165)[0m 2-LOG: Starting TaLoS Mask Calibration
[36m(ClientAppActor pid=3165)[0m 🔎 Starting multi-round calibration for the model head only.
[36m(ClientAppActor pid=3165)[0m 🌀 Calibration Round 1/2
[36m(ClientAppActor pid=3165)[0m ✅ Mask updated with 15.00% sparsity
[36m(ClientAppActor pid=3165)[0m 🌀 Calibration Round 2/2
[36m(ClientAppActor pid=3165)[0m ✅ Mask updated with 30.00% sparsity
[36m(ClientAppActor pid=3165)[0m 2-LOG: Starting evaluation
[36m(ClientAppActor pid=3165)[0m 2-LOG: Evaluation completed - Val Loss: 0.9582 | Val Accuracy: 0.7166

[92mINFO [0m:      aggregate_evaluate: received 5 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 2]
[92mINFO [0m:      configure_fit: strategy sampled 5 clients (out of 5)


[36m(ClientAppActor pid=3165)[0m 1-LOG: Evaluation completed - Val Loss: 0.9582 | Val Accuracy: 0.7166
[36m(ClientAppActor pid=3165)[0m LOG: Initializing client with CID=0
[36m(ClientAppActor pid=3165)[0m 0-LOG: Data partition assigned to client 0 -> IID
[36m(ClientAppActor pid=3165)[0m 0-LOG: Model initialized for client 0
[36m(ClientAppActor pid=3165)[0m 0-LOG: Dataloaders initialized for client 0
[36m(ClientAppActor pid=3165)[0m 🟢 Found model head. Pruning will be applied to the head only.
[36m(ClientAppActor pid=3165)[0m 0-LOG: Starting TaLoS Mask Calibration
[36m(ClientAppActor pid=3165)[0m 🔎 Starting multi-round calibration for the model head only.
[36m(ClientAppActor pid=3165)[0m 🌀 Calibration Round 1/2
[36m(ClientAppActor pid=3165)[0m ✅ Mask updated with 15.00% sparsity
[36m(ClientAppActor pid=3165)[0m 🌀 Calibration Round 2/2
[36m(ClientAppActor pid=3165)[0m ✅ Mask updated with 30.00% sparsity
[36m(ClientAppActor pid=3165)[0m 0-LOG: Starting local trai

[92mINFO [0m:      aggregate_fit: received 5 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 5 clients (out of 5)


[36m(ClientAppActor pid=3165)[0m LOG: Initializing client with CID=0
[36m(ClientAppActor pid=3165)[0m 0-LOG: Data partition assigned to client 0 -> IID
[36m(ClientAppActor pid=3165)[0m 0-LOG: Model initialized for client 0
[36m(ClientAppActor pid=3165)[0m 0-LOG: Dataloaders initialized for client 0
[36m(ClientAppActor pid=3165)[0m 🟢 Found model head. Pruning will be applied to the head only.
[36m(ClientAppActor pid=3165)[0m 0-LOG: Starting TaLoS Mask Calibration
[36m(ClientAppActor pid=3165)[0m 🔎 Starting multi-round calibration for the model head only.
[36m(ClientAppActor pid=3165)[0m 🌀 Calibration Round 1/2
[36m(ClientAppActor pid=3165)[0m ✅ Mask updated with 15.00% sparsity
[36m(ClientAppActor pid=3165)[0m 🌀 Calibration Round 2/2
[36m(ClientAppActor pid=3165)[0m ✅ Mask updated with 30.00% sparsity
[36m(ClientAppActor pid=3165)[0m 0-LOG: Starting evaluation
[36m(ClientAppActor pid=3165)[0m 0-LOG: Evaluation completed - Val Loss: 0.8759 | Val Accuracy: 0.7428

[92mINFO [0m:      aggregate_evaluate: received 5 results and 0 failures
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 3]
[92mINFO [0m:      configure_fit: strategy sampled 5 clients (out of 5)


[36m(ClientAppActor pid=3165)[0m 4-LOG: Evaluation completed - Val Loss: 0.8759 | Val Accuracy: 0.7428
[36m(ClientAppActor pid=3165)[0m LOG: Initializing client with CID=0
[36m(ClientAppActor pid=3165)[0m 0-LOG: Data partition assigned to client 0 -> IID
[36m(ClientAppActor pid=3165)[0m 0-LOG: Model initialized for client 0
[36m(ClientAppActor pid=3165)[0m 0-LOG: Dataloaders initialized for client 0
[36m(ClientAppActor pid=3165)[0m 🟢 Found model head. Pruning will be applied to the head only.
[36m(ClientAppActor pid=3165)[0m 0-LOG: Starting TaLoS Mask Calibration
[36m(ClientAppActor pid=3165)[0m 🔎 Starting multi-round calibration for the model head only.
[36m(ClientAppActor pid=3165)[0m 🌀 Calibration Round 1/2
[36m(ClientAppActor pid=3165)[0m ✅ Mask updated with 15.00% sparsity
[36m(ClientAppActor pid=3165)[0m 🌀 Calibration Round 2/2
[36m(ClientAppActor pid=3165)[0m ✅ Mask updated with 30.00% sparsity
[36m(ClientAppActor pid=3165)[0m 0-LOG: Starting local trai

KeyboardInterrupt: 

In [None]:
# Evaluate the final global model on the test set, log the metrics, and save the weights.

# `latest_parameters` is provided by the project-local strategy class
# (presumably the most recently aggregated global parameters; verify in strategies.py).
final_parameters = strategy.latest_parameters

# Convert Flower parameters to list of numpy arrays
ndarrays = parameters_to_ndarrays(final_parameters)

# Load them into a PyTorch model
# NOTE(review): the variable is called `iid_model` regardless of the IID flag.
iid_model = create_dino_vit_s16_for_cifar100()
iid_model.to(device)
# Arrays are matched to state_dict keys purely by position, which assumes the
# clients serialised their weights in state_dict order (TODO confirm in clients.py).
state_dict = OrderedDict(
    (key, torch.tensor(val)) for key, val in zip(iid_model.state_dict().keys(), ndarrays)
)
iid_model.load_state_dict(state_dict)

# Fix: switch to evaluation mode before measuring test metrics. A freshly built
# nn.Module is in training mode, so train-only layers (e.g. dropout) would
# otherwise stay active during the test pass.
iid_model.eval()

correct, total, loss_total = 0, 0, 0.0
criterion = torch.nn.CrossEntropyLoss()

with torch.no_grad():
    for images, labels in testloader:
        images, labels = images.to(device), labels.to(device)
        outputs = iid_model(images)
        loss = criterion(outputs, labels)

        # The criterion returns the batch mean; re-weight by batch size so the
        # final average is per-sample even when the last batch is smaller.
        loss_total += loss.item() * labels.size(0)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

# Fail with a clear message instead of a bare ZeroDivisionError.
if total == 0:
    raise ValueError("testloader yielded no samples; cannot compute test metrics")

test_loss = loss_total / total
test_accuracy = correct / total

# NOTE(review): metrics are logged against NUM_ROUNDS even if the simulation was
# interrupted before completing all rounds.
logger.log_global_metrics({
    "test_loss": test_loss,
    "test_accuracy": test_accuracy
}, round_number=NUM_ROUNDS)

model_save_path = "final_federated_model.pth"
torch.save(iid_model.state_dict(), model_save_path)

logger.log_model(iid_model, path=model_save_path)

logger.finish()

print(f"✅ Test Accuracy: {test_accuracy:.4f} | Test Loss: {test_loss:.4f}")