In [1]:
!pip install -q flwr torch torchvision tensorboard
!pip install -U "flwr[simulation]"
!pip install timm



In [2]:
from datetime import datetime
from functools import partial
from flwr.server import ServerConfig
import random
import numpy as np
import torch.nn as nn
from torchvision import transforms
import timm
import torch
import os
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import random_split, DataLoader
from torch.utils.tensorboard import SummaryWriter
import flwr as fl
from flwr.simulation import run_simulation
from flwr.common import Context
from flwr.server import ServerApp, ServerAppComponents, ServerConfig
from flwr.client import ClientApp
from flwr.common import parameters_to_ndarrays, ndarrays_to_parameters
from flwr.common import Parameters
from collections import OrderedDict
from wandb_logger import FederatedWandBLogger
import data_utils
import clients
import strategies
import data_preprocessing

In [3]:
def build_optimizer_config(optimizer_type):
    if optimizer_type == clients.OptimizerType.SGD:
        config = {
            "lr": LR,
            "momentum": MOMENTUM,
            "weight_decay": WEIGHT_DECAY,
            "nesterov": NESTEROV
        }

    elif optimizer_type == clients.OptimizerType.SSGD:
        config = {
            "lr": LR,
            "momentum": MOMENTUM,
            "weight_decay": WEIGHT_DECAY,
        }


    elif optimizer_type == clients.OptimizerType.ADAM:
        config = {
            "lr": LR,
            "betas": BETAS,
            "weight_decay": WEIGHT_DECAY,
            "eps": EPSILON
        }

    elif optimizer_type == clients.OptimizerType.ADAMW:
        config = {
            "lr": LR,
            "betas": BETAS,
            "weight_decay": WEIGHT_DECAY
        }

    elif optimizer_type == clients.OptimizerType.SADAMW:
        config = {
            "lr": LR,
            "betas": BETAS,
            "weight_decay": WEIGHT_DECAY,
            "eps": EPSILON
        }

    elif optimizer_type == clients.OptimizerType.RMSPROP:
        config = {
            "lr": LR,
            "alpha": ALPHA,
            "eps": EPSILON,
            "weight_decay": WEIGHT_DECAY,
            "momentum": MOMENTUM,
            "centered": CENTERED
        }

    elif optimizer_type == clients.OptimizerType.ADAGRAD:
        config = {
            "lr": LR,
            "weight_decay": WEIGHT_DECAY,
            "eps": EPSILON
        }
    else:
        raise ValueError(f"Unsupported optimizer type: {optimizer_type}")

    print(f"Optimizer '{optimizer_type}' initialized successfully.")
    return config

In [4]:
def build_scheduler_config(scheduler_type):
    scheduler_type.lower()

    if scheduler_type == "cosine":
        config = {
            "T_max": T_MAX,
            "eta_min": ETA_MIN
        }

    elif scheduler_type == "cosine_restart":
        config = {
            "T_0": T_MAX,
            "T_mult": 1,
            "eta_min": ETA_MIN
        }

    elif scheduler_type == "step":
        config = {
            "step_size": STEP_SIZE,
            "gamma": GAMMA
        }

    elif scheduler_type == "multistep":
        config = {
            "milestones": MILESTONES,
            "gamma": GAMMA
        }

    elif scheduler_type == "exponential":
        config = {
            "gamma": GAMMA
        }

    elif scheduler_type == "reduce_on_plateau":
        config = {
            "mode": "min",
            "factor": FACTOR,
            "patience": PATIENCE,
            "threshold": THRESHOLD
        }

    elif scheduler_type == "constant":
        config = {
            "factor": 1.0,
            "total_iters": TOTAL_ITERS
        }

    elif scheduler_type == "linear":
        config = {
            "start_factor": 0.5,
            "end_factor": LR_END,
            "total_iters": TOTAL_ITERS
        }
    else:
        raise ValueError(f"Unsupported scheduler type: {scheduler_type}")

    print(f"Scheduler '{scheduler_type}' initialized successfully.")
    return config

In [5]:
def seed_everything(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

seed_everything(42)

In [None]:
# THE WHOLE TRAINING SETTINGS ARE HERE!

# GENERAL
BATCH_SIZE = 32
EPOCHS = 15
VAL_SPLIT = 0.1
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BACKBONE_FREEZING = False

# FL
NUM_CLIENTS = 10
NC = 50
LOCAL_EPOCHS = 1
NUM_ROUNDS = 25
FRACTION_FIT = 0.1
FRACTION_EVAL = 0.03
CLIENT_TYPE = clients.ClientType.TALOSPFEDEDIT
STRATEGY_TYPE = strategies.StrategiesType.YOGI
IID = False

# LOSS
SMOOTHING=0.1

# Optimizer Hyperparameters
LR = 0.00005
MOMENTUM = 0.9
WEIGHT_DECAY = 0.025
NESTEROV = False
BETAS = (0.9, 0.999)
EPSILON = 1e-8
ALPHA = 0.99
CENTERED = False
OPTIMIZER_TYPE  = clients.OptimizerType.SADAMW

# Scheduler Hyperparameters
T_MAX = LOCAL_EPOCHS
ETA_MIN = 0.0
STEP_SIZE = 30
GAMMA = 0.1
MILESTONES = [50, 75]
FACTOR = 0.1
PATIENCE = 5
THRESHOLD = 1e-4
TOTAL_ITERS = 10
LR_END = 0.0
SCHEDULER_TYPE  = "cosine"

TALOS_CONFIG = {
    "final_sparsity": 0.6,
    "num_batches": 3,
    "rounds": 4,
    "calibration_mode": "most_sensitive", # least_sensitive or most_sensitive
    "mode": "full", # full if want to include all layers, head for head only, pfededit for custom topk client selection
    "k": 3, #pfedit setting
}

# FedProx

MU = 0.1

# YOGI

ETA = 0.5         # Learning rate for Yogi updates
ETA_L = 0.25      # Learning rate for local updates
TAU = 0.1            # Adaptive learning parameter for Yogi
BETA_1 = 0.9         # Momentum term (default value from the original paper)
BETA_2 = 0.999       # Second momentum term (default value from the original paper)

# METAFEDAVG

INNER_LR = 0.01

#PFEDEDIT
LOCAL_EPOCHS_PFEDEDIT = 4
TOP_K_LAYERS = 3
#STOCHASTIC_FACTOR = 1
MAX_BATCHES = 4
DETERMINISTIC_ROUND = 10
ALL_ROUNDS_SCHEDULING = True
REVERSE_MODE = False
PFEDEDIT_CONFIG = {
    "local epochs": LOCAL_EPOCHS_PFEDEDIT,
    "top_k_layers": TOP_K_LAYERS,
    "stochastic factor": "linear decrease",
    "max_batches": MAX_BATCHES,
    "deterministic_round": DETERMINISTIC_ROUND,
    "all_rounds_scheduling": ALL_ROUNDS_SCHEDULING,
    "reverse_mode": REVERSE_MODE,
}


OPTIMIZER_CONFIG = build_optimizer_config(OPTIMIZER_TYPE)
SCHEDULER_CONFIG = build_scheduler_config(SCHEDULER_TYPE)

Optimizer 'OptimizerType.SADAMW' initialized successfully.
Scheduler 'cosine' initialized successfully.


In [7]:
pipeline = data_preprocessing.CIFAR100Pipeline(val_split=VAL_SPLIT, use_augment=True)
trainset, valset, testset = pipeline.run_pipeline()

trainloader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)
valloader = DataLoader(valset, batch_size=BATCH_SIZE)
testloader = DataLoader(testset, batch_size=BATCH_SIZE)

# Apply sharding
client_datasets_iid = data_utils.iid_split(trainset, NUM_CLIENTS)
client_datasets_noniid = data_utils.non_iid_split(trainset, NUM_CLIENTS, NC, 100)

print("Setup complete. IID and Non-IID splits created.")

Setup complete. IID and Non-IID splits created.


In [8]:
# Create model
def create_dino_vit_s16_for_cifar100(freezing=BACKBONE_FREEZING):
    model = timm.create_model("vit_small_patch16_224_dino", pretrained=True, num_classes=0)

    # Replace the head with CIFAR-100 classification head
    model.head = nn.Linear(model.num_features, 100)

    if freezing:
      # Freeze all parameters except head
      for param in model.parameters():
          param.requires_grad = False

      # Unfreeze only the head
      for param in model.head.parameters():
          param.requires_grad = True
    else:

      for param in model.parameters():
          param.requires_grad = True

      for param in model.head.parameters():
          param.requires_grad = True

    return model

In [9]:
model = create_dino_vit_s16_for_cifar100()

print(next(model.parameters()).device)
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"Trainable params: {trainable:,} / {total:,}")

initial_parameters = [val.cpu().numpy() for _, val in model.state_dict().items()]
flower_parameters = ndarrays_to_parameters(initial_parameters)

  model = create_fn(
The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


cpu
Trainable params: 21,704,164 / 21,704,164


In [10]:
!wandb login 89e5fee022a3a1cf86f958ee0b3dff6f2aa57aad

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin


In [11]:
data_distribution = "IID_DATA" if IID else "NON_IID_DATA"
run_name = f"FEDERATED_{data_distribution}_TALOS_{datetime.now().strftime('%Y%m%d-%H%M%S')}"

if CLIENT_TYPE == clients.ClientType.TALOS:
    client_app = ClientApp(
        client_fn=clients.build_client_talos_fn(
            use_iid=IID,
            optimizer_type=OPTIMIZER_TYPE,
            optimizer_config=OPTIMIZER_CONFIG,
            scheduler_type=SCHEDULER_TYPE,
            scheduler_config=SCHEDULER_CONFIG,
            iid_partitions=client_datasets_iid,
            non_iid_partitions=client_datasets_noniid,
            model_fn=create_dino_vit_s16_for_cifar100,
            device=DEVICE,
            valset=valset,
            batch_size=BATCH_SIZE,
            local_epochs=LOCAL_EPOCHS,
            talos_config=TALOS_CONFIG,

        )
    )
    client_config = None
elif CLIENT_TYPE == clients.ClientType.TALOSPROX:
    client_app = ClientApp(
        client_fn=clients.build_client_talos_prox_fn(
            use_iid=IID,
            optimizer_type=OPTIMIZER_TYPE,
            optimizer_config=OPTIMIZER_CONFIG,
            scheduler_type=SCHEDULER_TYPE,
            scheduler_config=SCHEDULER_CONFIG,
            iid_partitions=client_datasets_iid,
            non_iid_partitions=client_datasets_noniid,
            model_fn=create_dino_vit_s16_for_cifar100,
            device=DEVICE,
            valset=valset,
            batch_size=BATCH_SIZE,
            local_epochs=LOCAL_EPOCHS,
            talos_config=TALOS_CONFIG,
            mu=MU
        )
    )
    client_config = {"mu": MU}

elif CLIENT_TYPE == clients.ClientType.TALOSPFEDEDIT:
    client_app = ClientApp(
        client_fn=clients.build_client_fn_talos_pfededit(
            use_iid=IID,
            optimizer_type=OPTIMIZER_TYPE,
            optimizer_config=OPTIMIZER_CONFIG,
            scheduler_type=SCHEDULER_TYPE,
            scheduler_config=SCHEDULER_CONFIG,
            iid_partitions=client_datasets_iid,
            non_iid_partitions=client_datasets_noniid,
            model_fn=create_dino_vit_s16_for_cifar100,
            device=DEVICE,
            valset=valset,
            batch_size=BATCH_SIZE,
            local_epochs=LOCAL_EPOCHS,
            talos_config=TALOS_CONFIG,
            pfededit_config = PFEDEDIT_CONFIG,
            rounds_stochastic = NUM_ROUNDS,
            deterministic_round = DETERMINISTIC_ROUND,
            all_rounds_scheduling = ALL_ROUNDS_SCHEDULING,
            reverse_mode = REVERSE_MODE

        )
    )
    client_config = {"mu": MU}

if STRATEGY_TYPE == strategies.StrategiesType.FEDAVG:
    strategy_config = None
elif STRATEGY_TYPE == strategies.StrategiesType.METAFEDAVG:
    strategy_config = {"inner_lr": INNER_LR}
else:
    strategy_config = {"eta": ETA, "eta_l": ETA_L, "tau": TAU, "beta_1": BETA_1, "beta_2": BETA_2,}


if not IID:
  DataConfig = {"use_iid": IID, "num_classes": NC}
else:
  DataConfig = {"use_iid": IID}


logger = FederatedWandBLogger(
    project_name="federated-learning-project",
    run_name=run_name,
    global_config={
        # Federated Learning Configuration
        "data_config": DataConfig,
        "local_epochs": LOCAL_EPOCHS,
        "batch_size": BATCH_SIZE,
        "num_clients": NUM_CLIENTS,
        "fraction_fit": FRACTION_FIT,
        "fraction_evaluate": FRACTION_EVAL,
        "num_rounds": NUM_ROUNDS,
        "backbone_freezing": BACKBONE_FREEZING,
        "client_type": CLIENT_TYPE.value,
        "client_config": client_config,
        "strategy_type": STRATEGY_TYPE.value,
        "strategy_config": strategy_config,
        "val_split": VAL_SPLIT,

        # Optimizer Configuration
        "optimizer": OPTIMIZER_TYPE.value,
        "optimizer_config": OPTIMIZER_CONFIG,

        # Scheduler Configuration
        "scheduler": SCHEDULER_TYPE,
        "scheduler_config": SCHEDULER_CONFIG,

        "talos_config": TALOS_CONFIG,

        "pfededit_config": PFEDEDIT_CONFIG


    }
)

if STRATEGY_TYPE == strategies.StrategiesType.FEDAVG:
    strategy = strategies.FedAvgStandard(
        logger=logger,
        initial_parameters=flower_parameters,
        fraction_fit=FRACTION_FIT,
        min_fit_clients=int(FRACTION_FIT*NUM_CLIENTS),
        min_evaluate_clients=int(FRACTION_EVAL*NUM_CLIENTS),
        fraction_evaluate=FRACTION_EVAL,
        min_available_clients=NUM_CLIENTS,
        evaluate_metrics_aggregation_fn = lambda metrics: {
            "accuracy": sum(num * m["val_accuracy"] for num, m in metrics) / sum(num for num, _ in metrics)
        }
    )
elif STRATEGY_TYPE == strategies.StrategiesType.METAFEDAVG:
    strategy = strategies.MetaFedAvg(
        logger=logger,
        inner_lr=INNER_LR,
        initial_parameters=flower_parameters,
        fraction_fit=FRACTION_FIT,
        min_fit_clients=int(FRACTION_FIT*NUM_CLIENTS),
        min_evaluate_clients=int(FRACTION_EVAL*NUM_CLIENTS),
        fraction_evaluate=FRACTION_EVAL,
        min_available_clients=NUM_CLIENTS,
        evaluate_metrics_aggregation_fn = lambda metrics: {
            "accuracy": sum(num * m["val_accuracy"] for num, m in metrics) / sum(num for num, _ in metrics)
        }
    )
else:
    strategy = strategies.FedYogiStandard(
        logger=logger,
        initial_parameters=flower_parameters,
        fraction_fit=FRACTION_FIT,
        min_fit_clients=int(FRACTION_FIT*NUM_CLIENTS),
        min_evaluate_clients=int(FRACTION_EVAL*NUM_CLIENTS),
        fraction_evaluate=FRACTION_EVAL,
        min_available_clients=NUM_CLIENTS,
        eta=ETA,
        eta_l=ETA_L,
        tau=TAU,
        beta_1=BETA_1,
        beta_2=BETA_2,
        evaluate_metrics_aggregation_fn = lambda metrics: {
            "accuracy": sum(num * m["val_accuracy"] for num, m in metrics) / sum(num for num, _ in metrics)
        }
    )

def server_fn(context: Context) -> ServerAppComponents:
    """Construct components that set the ServerApp behaviour.

    You can use the settings in `context.run_config` to parameterize the
    construction of all elements (e.g the strategy or the number of rounds)
    wrapped in the returned ServerAppComponents object.
    """
    config = ServerConfig(num_rounds=NUM_ROUNDS)
    return ServerAppComponents(strategy=strategy, config=config)


server_app = ServerApp(server_fn=server_fn)

  return LooseVersion(v) >= LooseVersion(check)
[34m[1mwandb[0m: Currently logged in as: [33ms348517giuseppe[0m ([33ms348517giuseppe-politecnico-di-torino[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
  self.scope.user = {"email": email}  # noqa


In [12]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

backend_config = {"client_resources": {"num_cpus": 1, "num_gpus": 0.0}}

if device.type == "cuda":
    backend_config = {"client_resources": {"num_cpus": 1, "num_gpus": 1.0}}

In [13]:
run_simulation(
    server_app=server_app,
    client_app=client_app,
    num_supernodes=NUM_CLIENTS,
    backend_config=backend_config
)

DEBUG:flwr:Asyncio event loop already running.
[92mINFO [0m:      Starting Flower ServerApp, config: num_rounds=25, no round_timeout
[92mINFO [0m:      
[92mINFO [0m:      [INIT]
[92mINFO [0m:      Using initial global parameters provided by strategy
[92mINFO [0m:      Starting evaluation of initial global parameters
[92mINFO [0m:      Evaluation returned no results (`None`)
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 1]
[92mINFO [0m:      configure_fit: strategy sampled 1 clients (out of 10)
[36m(pid=10078)[0m 2025-06-02 15:04:45.103512: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
[36m(pid=10078)[0m E0000 00:00:1748876685.135903   10078 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
[36m(pid=10078)[0m E0000 00:00:1748876685.145920   10

[36m(ClientAppActor pid=10078)[0m LOG: Initializing client with CID=3
[36m(ClientAppActor pid=10078)[0m 3-LOG: Data partition assigned to client 3 -> Non-IID
[36m(ClientAppActor pid=10078)[0m 3-LOG: Model initialized for client 3
[36m(ClientAppActor pid=10078)[0m 3-LOG: Dataloaders initialized for client 3
[36m(ClientAppActor pid=10078)[0m 3-LOG: Initialized with FedProx μ = 0.1
[36m(ClientAppActor pid=10078)[0m Current Round: 1
[36m(ClientAppActor pid=10078)[0m Total Rounds: 10
[36m(ClientAppActor pid=10078)[0m 3-LOG: Using normal mode for scheduling
[36m(ClientAppActor pid=10078)[0m 3-LOG: Round 1 | Dynamic Stochastic Factor: 1.0000
[36m(ClientAppActor pid=10078)[0m 3-LOG: Random Sample: 0.8276
[36m(ClientAppActor pid=10078)[0m 3-LOG: Stochastic Sampling Activated (stochastic_factor=1.0000), Full Model Training
[36m(ClientAppActor pid=10078)[0m 🟢 Pruning will be applied to the entire model.
[36m(ClientAppActor pid=10078)[0m 3-LOG: Starting TaLoS calibration 

[92mINFO [0m:      aggregate_fit: received 1 results and 0 failures


Aggregating fit results for round 1


[92mINFO [0m:      configure_evaluate: no clients selected, skipping evaluation
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 2]
[92mINFO [0m:      configure_fit: strategy sampled 1 clients (out of 10)


[36m(ClientAppActor pid=10078)[0m LOG: Initializing client with CID=3
[36m(ClientAppActor pid=10078)[0m 3-LOG: Data partition assigned to client 3 -> Non-IID
[36m(ClientAppActor pid=10078)[0m 3-LOG: Model initialized for client 3
[36m(ClientAppActor pid=10078)[0m 3-LOG: Dataloaders initialized for client 3
[36m(ClientAppActor pid=10078)[0m 3-LOG: Initialized with FedProx μ = 0.1
[36m(ClientAppActor pid=10078)[0m Current Round: 2
[36m(ClientAppActor pid=10078)[0m Total Rounds: 10
[36m(ClientAppActor pid=10078)[0m 3-LOG: Using normal mode for scheduling
[36m(ClientAppActor pid=10078)[0m 3-LOG: Round 2 | Dynamic Stochastic Factor: 0.8889
[36m(ClientAppActor pid=10078)[0m 3-LOG: Random Sample: 0.0765
[36m(ClientAppActor pid=10078)[0m 3-LOG: Stochastic Sampling Activated (stochastic_factor=0.8889), Full Model Training
[36m(ClientAppActor pid=10078)[0m 🟢 Pruning will be applied to the entire model.
[36m(ClientAppActor pid=10078)[0m 3-LOG: Starting TaLoS calibration 

[92mINFO [0m:      aggregate_fit: received 1 results and 0 failures


Aggregating fit results for round 2


[92mINFO [0m:      configure_evaluate: no clients selected, skipping evaluation
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 3]
[92mINFO [0m:      configure_fit: strategy sampled 1 clients (out of 10)


[36m(ClientAppActor pid=10078)[0m LOG: Initializing client with CID=6
[36m(ClientAppActor pid=10078)[0m 6-LOG: Data partition assigned to client 6 -> Non-IID
[36m(ClientAppActor pid=10078)[0m 6-LOG: Model initialized for client 6
[36m(ClientAppActor pid=10078)[0m 6-LOG: Dataloaders initialized for client 6
[36m(ClientAppActor pid=10078)[0m 6-LOG: Initialized with FedProx μ = 0.1
[36m(ClientAppActor pid=10078)[0m Current Round: 3
[36m(ClientAppActor pid=10078)[0m Total Rounds: 10
[36m(ClientAppActor pid=10078)[0m 6-LOG: Using normal mode for scheduling
[36m(ClientAppActor pid=10078)[0m 6-LOG: Round 3 | Dynamic Stochastic Factor: 0.7778
[36m(ClientAppActor pid=10078)[0m 6-LOG: Random Sample: 0.5250
[36m(ClientAppActor pid=10078)[0m 6-LOG: Stochastic Sampling Activated (stochastic_factor=0.7778), Full Model Training
[36m(ClientAppActor pid=10078)[0m 🟢 Pruning will be applied to the entire model.
[36m(ClientAppActor pid=10078)[0m 6-LOG: Starting TaLoS calibration 

[92mINFO [0m:      aggregate_fit: received 1 results and 0 failures


Aggregating fit results for round 3


[92mINFO [0m:      configure_evaluate: no clients selected, skipping evaluation
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 4]
[92mINFO [0m:      configure_fit: strategy sampled 1 clients (out of 10)


[36m(ClientAppActor pid=10078)[0m LOG: Initializing client with CID=5
[36m(ClientAppActor pid=10078)[0m 
[36m(ClientAppActor pid=10078)[0m 5-LOG: Data partition assigned to client 5 -> Non-IID
[36m(ClientAppActor pid=10078)[0m 5-LOG: Model initialized for client 5
[36m(ClientAppActor pid=10078)[0m 5-LOG: Dataloaders initialized for client 5
[36m(ClientAppActor pid=10078)[0m 5-LOG: Initialized with FedProx μ = 0.1
[36m(ClientAppActor pid=10078)[0m Current Round: 4
[36m(ClientAppActor pid=10078)[0m Total Rounds: 10
[36m(ClientAppActor pid=10078)[0m 5-LOG: Using normal mode for scheduling
[36m(ClientAppActor pid=10078)[0m 5-LOG: Round 4 | Dynamic Stochastic Factor: 0.6667
[36m(ClientAppActor pid=10078)[0m 5-LOG: Random Sample: 0.5005
[36m(ClientAppActor pid=10078)[0m 5-LOG: Stochastic Sampling Activated (stochastic_factor=0.6667), Full Model Training
[36m(ClientAppActor pid=10078)[0m 🟢 Pruning will be applied to the entire model.
[36m(ClientAppActor pid=10078)[

[92mINFO [0m:      aggregate_fit: received 1 results and 0 failures


Aggregating fit results for round 4


[92mINFO [0m:      configure_evaluate: no clients selected, skipping evaluation
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 5]
[92mINFO [0m:      configure_fit: strategy sampled 1 clients (out of 10)


[36m(ClientAppActor pid=10078)[0m LOG: Initializing client with CID=6
[36m(ClientAppActor pid=10078)[0m 6-LOG: Data partition assigned to client 6 -> Non-IID
[36m(ClientAppActor pid=10078)[0m 6-LOG: Model initialized for client 6
[36m(ClientAppActor pid=10078)[0m 6-LOG: Dataloaders initialized for client 6
[36m(ClientAppActor pid=10078)[0m 6-LOG: Initialized with FedProx μ = 0.1
[36m(ClientAppActor pid=10078)[0m Current Round: 5
[36m(ClientAppActor pid=10078)[0m Total Rounds: 10
[36m(ClientAppActor pid=10078)[0m 6-LOG: Using normal mode for scheduling
[36m(ClientAppActor pid=10078)[0m 6-LOG: Round 5 | Dynamic Stochastic Factor: 0.5556
[36m(ClientAppActor pid=10078)[0m 6-LOG: Random Sample: 0.8655
[36m(ClientAppActor pid=10078)[0m 6-LOG: TaLoS Local Layer Selection (stochastic_factor=0.5556)
[36m(ClientAppActor pid=10078)[0m Max Batches: 4
[36m(ClientAppActor pid=10078)[0m 6-LOG: Selected Layers (Min Loss): [7, 4, 1]
[36m(ClientAppActor pid=10078)[0m 🟢 Prunin

[92mINFO [0m:      aggregate_fit: received 1 results and 0 failures


Aggregating fit results for round 5


[92mINFO [0m:      configure_evaluate: no clients selected, skipping evaluation
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 6]
[92mINFO [0m:      configure_fit: strategy sampled 1 clients (out of 10)


[36m(ClientAppActor pid=10078)[0m LOG: Initializing client with CID=4
[36m(ClientAppActor pid=10078)[0m 4-LOG: Data partition assigned to client 4 -> Non-IID
[36m(ClientAppActor pid=10078)[0m 4-LOG: Model initialized for client 4
[36m(ClientAppActor pid=10078)[0m 4-LOG: Dataloaders initialized for client 4
[36m(ClientAppActor pid=10078)[0m 4-LOG: Initialized with FedProx μ = 0.1
[36m(ClientAppActor pid=10078)[0m Current Round: 6
[36m(ClientAppActor pid=10078)[0m Total Rounds: 10
[36m(ClientAppActor pid=10078)[0m 4-LOG: Using normal mode for scheduling
[36m(ClientAppActor pid=10078)[0m 4-LOG: Round 6 | Dynamic Stochastic Factor: 0.4444
[36m(ClientAppActor pid=10078)[0m 4-LOG: Random Sample: 0.5250
[36m(ClientAppActor pid=10078)[0m 4-LOG: TaLoS Local Layer Selection (stochastic_factor=0.4444)
[36m(ClientAppActor pid=10078)[0m Max Batches: 4
[36m(ClientAppActor pid=10078)[0m 4-LOG: Selected Layers (Min Loss): [6, 1, 0]
[36m(ClientAppActor pid=10078)[0m 🟢 Prunin

[92mINFO [0m:      aggregate_fit: received 1 results and 0 failures


Aggregating fit results for round 6


[92mINFO [0m:      configure_evaluate: no clients selected, skipping evaluation
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 7]
[92mINFO [0m:      configure_fit: strategy sampled 1 clients (out of 10)


[36m(ClientAppActor pid=10078)[0m LOG: Initializing client with CID=3
[36m(ClientAppActor pid=10078)[0m 3-LOG: Data partition assigned to client 3 -> Non-IID
[36m(ClientAppActor pid=10078)[0m 3-LOG: Model initialized for client 3
[36m(ClientAppActor pid=10078)[0m 3-LOG: Dataloaders initialized for client 3
[36m(ClientAppActor pid=10078)[0m 3-LOG: Initialized with FedProx μ = 0.1
[36m(ClientAppActor pid=10078)[0m Current Round: 7
[36m(ClientAppActor pid=10078)[0m Total Rounds: 10
[36m(ClientAppActor pid=10078)[0m 3-LOG: Using normal mode for scheduling
[36m(ClientAppActor pid=10078)[0m 3-LOG: Round 7 | Dynamic Stochastic Factor: 0.3333
[36m(ClientAppActor pid=10078)[0m 3-LOG: Random Sample: 0.0125
[36m(ClientAppActor pid=10078)[0m 3-LOG: Stochastic Sampling Activated (stochastic_factor=0.3333), Full Model Training
[36m(ClientAppActor pid=10078)[0m 🟢 Pruning will be applied to the entire model.
[36m(ClientAppActor pid=10078)[0m 3-LOG: Starting TaLoS calibration 

[92mINFO [0m:      aggregate_fit: received 1 results and 0 failures


Aggregating fit results for round 7


[92mINFO [0m:      configure_evaluate: no clients selected, skipping evaluation
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 8]
[92mINFO [0m:      configure_fit: strategy sampled 1 clients (out of 10)


[36m(ClientAppActor pid=10078)[0m LOG: Initializing client with CID=5
[36m(ClientAppActor pid=10078)[0m 5-LOG: Data partition assigned to client 5 -> Non-IID
[36m(ClientAppActor pid=10078)[0m 5-LOG: Model initialized for client 5
[36m(ClientAppActor pid=10078)[0m 5-LOG: Dataloaders initialized for client 5
[36m(ClientAppActor pid=10078)[0m 5-LOG: Initialized with FedProx μ = 0.1
[36m(ClientAppActor pid=10078)[0m Current Round: 8
[36m(ClientAppActor pid=10078)[0m Total Rounds: 10
[36m(ClientAppActor pid=10078)[0m 5-LOG: Using normal mode for scheduling
[36m(ClientAppActor pid=10078)[0m 5-LOG: Round 8 | Dynamic Stochastic Factor: 0.2222
[36m(ClientAppActor pid=10078)[0m 5-LOG: Random Sample: 0.7331
[36m(ClientAppActor pid=10078)[0m 5-LOG: TaLoS Local Layer Selection (stochastic_factor=0.2222)
[36m(ClientAppActor pid=10078)[0m Max Batches: 4
[36m(ClientAppActor pid=10078)[0m 5-LOG: Selected Layers (Min Loss): [10, 0, 6]
[36m(ClientAppActor pid=10078)[0m 🟢 Pruni

[92mINFO [0m:      aggregate_fit: received 1 results and 0 failures


Aggregating fit results for round 8


[92mINFO [0m:      configure_evaluate: no clients selected, skipping evaluation
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 9]
[92mINFO [0m:      configure_fit: strategy sampled 1 clients (out of 10)


[36m(ClientAppActor pid=10078)[0m LOG: Initializing client with CID=8
[36m(ClientAppActor pid=10078)[0m 8-LOG: Data partition assigned to client 8 -> Non-IID
[36m(ClientAppActor pid=10078)[0m 8-LOG: Model initialized for client 8
[36m(ClientAppActor pid=10078)[0m 8-LOG: Dataloaders initialized for client 8
[36m(ClientAppActor pid=10078)[0m 8-LOG: Initialized with FedProx μ = 0.1
[36m(ClientAppActor pid=10078)[0m Current Round: 9
[36m(ClientAppActor pid=10078)[0m Total Rounds: 10
[36m(ClientAppActor pid=10078)[0m 8-LOG: Using normal mode for scheduling
[36m(ClientAppActor pid=10078)[0m 8-LOG: Round 9 | Dynamic Stochastic Factor: 0.1111
[36m(ClientAppActor pid=10078)[0m 8-LOG: Random Sample: 0.5138
[36m(ClientAppActor pid=10078)[0m 8-LOG: TaLoS Local Layer Selection (stochastic_factor=0.1111)
[36m(ClientAppActor pid=10078)[0m Max Batches: 4
[36m(ClientAppActor pid=10078)[0m 8-LOG: Selected Layers (Min Loss): [6, 7, 1]
[36m(ClientAppActor pid=10078)[0m 🟢 Prunin

[92mINFO [0m:      aggregate_fit: received 1 results and 0 failures


[36m(ClientAppActor pid=10078)[0m 8-LOG: Completed local training - Loss: 4.8195 | Accuracy: 0.0437
Aggregating fit results for round 9


[92mINFO [0m:      configure_evaluate: no clients selected, skipping evaluation
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 10]
[92mINFO [0m:      configure_fit: strategy sampled 1 clients (out of 10)


[36m(ClientAppActor pid=10078)[0m LOG: Initializing client with CID=4
[36m(ClientAppActor pid=10078)[0m 4-LOG: Data partition assigned to client 4 -> Non-IID
[36m(ClientAppActor pid=10078)[0m 4-LOG: Model initialized for client 4
[36m(ClientAppActor pid=10078)[0m 4-LOG: Dataloaders initialized for client 4
[36m(ClientAppActor pid=10078)[0m 4-LOG: Initialized with FedProx μ = 0.1
[36m(ClientAppActor pid=10078)[0m Current Round: 10
[36m(ClientAppActor pid=10078)[0m Total Rounds: 10
[36m(ClientAppActor pid=10078)[0m 4-LOG: Using normal mode for scheduling
[36m(ClientAppActor pid=10078)[0m 4-LOG: Round 10 | Dynamic Stochastic Factor: 0.0000
[36m(ClientAppActor pid=10078)[0m 4-LOG: Random Sample: 0.9081
[36m(ClientAppActor pid=10078)[0m 4-LOG: TaLoS Local Layer Selection (stochastic_factor=0.0000)
[36m(ClientAppActor pid=10078)[0m Max Batches: 4
[36m(ClientAppActor pid=10078)[0m 4-LOG: Selected Layers (Min Loss): [6, 10, 1]
[36m(ClientAppActor pid=10078)[0m 🟢 Pru

[92mINFO [0m:      aggregate_fit: received 1 results and 0 failures


[36m(ClientAppActor pid=10078)[0m 4-LOG: Completed local training - Loss: 4.5428 | Accuracy: 0.0772
Aggregating fit results for round 10


[92mINFO [0m:      configure_evaluate: no clients selected, skipping evaluation
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 11]
[92mINFO [0m:      configure_fit: strategy sampled 1 clients (out of 10)


[36m(ClientAppActor pid=10078)[0m LOG: Initializing client with CID=3
[36m(ClientAppActor pid=10078)[0m 3-LOG: Data partition assigned to client 3 -> Non-IID
[36m(ClientAppActor pid=10078)[0m 3-LOG: Model initialized for client 3
[36m(ClientAppActor pid=10078)[0m 3-LOG: Dataloaders initialized for client 3
[36m(ClientAppActor pid=10078)[0m 3-LOG: Initialized with FedProx μ = 0.1
[36m(ClientAppActor pid=10078)[0m Current Round: 11
[36m(ClientAppActor pid=10078)[0m Total Rounds: 10
[36m(ClientAppActor pid=10078)[0m 3-LOG: Using normal mode for scheduling
[36m(ClientAppActor pid=10078)[0m 3-LOG: Round 11 | Dynamic Stochastic Factor: 0.0000
[36m(ClientAppActor pid=10078)[0m 3-LOG: Random Sample: 0.3807
[36m(ClientAppActor pid=10078)[0m 3-LOG: TaLoS Local Layer Selection (stochastic_factor=0.0000)
[36m(ClientAppActor pid=10078)[0m Max Batches: 4
[36m(ClientAppActor pid=10078)[0m 3-LOG: Selected Layers (Min Loss): [9, 2, 1]
[36m(ClientAppActor pid=10078)[0m 🟢 Prun

KeyboardInterrupt: 

In [None]:
final_parameters = strategy.latest_parameters

# Convert Flower parameters to list of numpy arrays
ndarrays = parameters_to_ndarrays(final_parameters)

# Load them into a PyTorch model
iid_model = create_dino_vit_s16_for_cifar100()
iid_model.to(device)
state_dict = OrderedDict(
    (key, torch.tensor(val)) for key, val in zip(iid_model.state_dict().keys(), ndarrays)
)
iid_model.load_state_dict(state_dict)

correct, total, loss_total = 0, 0, 0.0
criterion = torch.nn.CrossEntropyLoss()

with torch.no_grad():
    for images, labels in testloader:
        images, labels = images.to(device), labels.to(device)
        outputs = iid_model(images)
        loss = criterion(outputs, labels)

        loss_total += loss.item() * labels.size(0)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

test_loss = loss_total / total
test_accuracy = correct / total

logger.log_global_metrics({
    "test_loss": test_loss,
    "test_accuracy": test_accuracy
}, round_number=NUM_ROUNDS)

model_save_path = "final_federated_model.pth"
torch.save(iid_model.state_dict(), model_save_path)

logger.log_model(iid_model, path=model_save_path)

logger.finish()

print(f"✅ Test Accuracy: {test_accuracy:.4f} | Test Loss: {test_loss:.4f}")