In [1]:
!pip install -q flwr torch torchvision tensorboard
!pip install -U "flwr[simulation]"
!pip install timm



In [2]:
from datetime import datetime
from functools import partial
from flwr.server import ServerConfig
import random
import numpy as np
import torch.nn as nn
from torchvision import transforms
import timm
import torch
import os
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import random_split, DataLoader
from torch.utils.tensorboard import SummaryWriter
import flwr as fl
from flwr.simulation import run_simulation
from flwr.common import Context
from flwr.server import ServerApp, ServerAppComponents, ServerConfig
from flwr.client import ClientApp
from flwr.common import parameters_to_ndarrays, ndarrays_to_parameters
from flwr.common import Parameters
from collections import OrderedDict
from wandb_logger import FederatedWandBLogger
import data_utils
import clients
import strategies
import data_preprocessing

In [3]:
def build_optimizer_config(optimizer_type):
    if optimizer_type == clients.OptimizerType.SGD:
        config = {
            "lr": LR,
            "momentum": MOMENTUM,
            "weight_decay": WEIGHT_DECAY,
            "nesterov": NESTEROV
        }

    elif optimizer_type == clients.OptimizerType.SSGD:
        config = {
            "lr": LR,
            "momentum": MOMENTUM,
            "weight_decay": WEIGHT_DECAY,
        }


    elif optimizer_type == clients.OptimizerType.ADAM:
        config = {
            "lr": LR,
            "betas": BETAS,
            "weight_decay": WEIGHT_DECAY,
            "eps": EPSILON
        }

    elif optimizer_type == clients.OptimizerType.ADAMW:
        config = {
            "lr": LR,
            "betas": BETAS,
            "weight_decay": WEIGHT_DECAY
        }

    elif optimizer_type == clients.OptimizerType.SADAMW:
        config = {
            "lr": LR,
            "betas": BETAS,
            "weight_decay": WEIGHT_DECAY,
            "eps": EPSILON
        }

    elif optimizer_type == clients.OptimizerType.RMSPROP:
        config = {
            "lr": LR,
            "alpha": ALPHA,
            "eps": EPSILON,
            "weight_decay": WEIGHT_DECAY,
            "momentum": MOMENTUM,
            "centered": CENTERED
        }

    elif optimizer_type == clients.OptimizerType.ADAGRAD:
        config = {
            "lr": LR,
            "weight_decay": WEIGHT_DECAY,
            "eps": EPSILON
        }
    else:
        raise ValueError(f"Unsupported optimizer type: {optimizer_type}")

    print(f"Optimizer '{optimizer_type}' initialized successfully.")
    return config

In [4]:
def build_scheduler_config(scheduler_type):
    scheduler_type.lower()

    if scheduler_type == "cosine":
        config = {
            "T_max": T_MAX,
            "eta_min": ETA_MIN
        }

    elif scheduler_type == "cosine_restart":
        config = {
            "T_0": T_MAX,
            "T_mult": 1,
            "eta_min": ETA_MIN
        }

    elif scheduler_type == "step":
        config = {
            "step_size": STEP_SIZE,
            "gamma": GAMMA
        }

    elif scheduler_type == "multistep":
        config = {
            "milestones": MILESTONES,
            "gamma": GAMMA
        }

    elif scheduler_type == "exponential":
        config = {
            "gamma": GAMMA
        }

    elif scheduler_type == "reduce_on_plateau":
        config = {
            "mode": "min",
            "factor": FACTOR,
            "patience": PATIENCE,
            "threshold": THRESHOLD
        }

    elif scheduler_type == "constant":
        config = {
            "factor": 1.0,
            "total_iters": TOTAL_ITERS
        }

    elif scheduler_type == "linear":
        config = {
            "start_factor": 0.5,
            "end_factor": LR_END,
            "total_iters": TOTAL_ITERS
        }
    else:
        raise ValueError(f"Unsupported scheduler type: {scheduler_type}")

    print(f"Scheduler '{scheduler_type}' initialized successfully.")
    return config

In [5]:
def seed_everything(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

seed_everything(42)

In [6]:
# THE WHOLE TRAINING SETTINGS ARE HERE!

# GENERAL
BATCH_SIZE = 256
EPOCHS = 15
VAL_SPLIT = 0.1
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BACKBONE_FREEZING = False

# FL
NUM_CLIENTS = 100
NC = 20
LOCAL_EPOCHS = 4
NUM_ROUNDS = 8
FRACTION_FIT = 0.1
FRACTION_EVAL = 0.03
CLIENT_TYPE = clients.ClientType.TALOSPROX
STRATEGY_TYPE = strategies.StrategiesType.YOGI
IID = False

# LOSS
SMOOTHING=0.1

# Optimizer Hyperparameters
LR = 0.00005
MOMENTUM = 0.9
WEIGHT_DECAY = 0.025
NESTEROV = False
BETAS = (0.9, 0.999)
EPSILON = 1e-8
ALPHA = 0.99
CENTERED = False
OPTIMIZER_TYPE  = clients.OptimizerType.SADAMW

# Scheduler Hyperparameters
T_MAX = LOCAL_EPOCHS
ETA_MIN = 0.0
STEP_SIZE = 30
GAMMA = 0.1
MILESTONES = [50, 75]
FACTOR = 0.1
PATIENCE = 5
THRESHOLD = 1e-4
TOTAL_ITERS = 10
LR_END = 0.0
SCHEDULER_TYPE  = "cosine"

TALOS_CONFIG = {
    "final_sparsity": 0.6,
    "num_batches": 3,
    "rounds": 4,
    "calibration_mode": "least_sensitive", # least_sensitive or most_sensitive
    "mode": "pfededit", # full if want to include all layers, head for head only, pfededit for custom topk client selection
    "k": 3, #pfedit setting
    "factor": 0.5 #probability of not doing pfededit
}

# FedProx

MU = 0.1

# YOGI

ETA = 0.5         # Learning rate for Yogi updates
ETA_L = 0.25      # Learning rate for local updates
TAU = 0.1            # Adaptive learning parameter for Yogi
BETA_1 = 0.9         # Momentum term (default value from the original paper)
BETA_2 = 0.999       # Second momentum term (default value from the original paper)

# METAFEDAVG

INNER_LR = 0.01

OPTIMIZER_CONFIG = build_optimizer_config(OPTIMIZER_TYPE)
SCHEDULER_CONFIG = build_scheduler_config(SCHEDULER_TYPE)

Optimizer 'OptimizerType.SADAMW' initialized successfully.
Scheduler 'cosine' initialized successfully.


In [7]:
pipeline = data_preprocessing.CIFAR100Pipeline(val_split=VAL_SPLIT, use_augment=True)
trainset, valset, testset = pipeline.run_pipeline()

trainloader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)
valloader = DataLoader(valset, batch_size=BATCH_SIZE)
testloader = DataLoader(testset, batch_size=BATCH_SIZE)

# Apply sharding
client_datasets_iid = data_utils.iid_split(trainset, NUM_CLIENTS)
client_datasets_noniid = data_utils.non_iid_split(trainset, NUM_CLIENTS, NC, 100)

print("Setup complete. IID and Non-IID splits created.")

Setup complete. IID and Non-IID splits created.


In [8]:
# Create model
def create_dino_vit_s16_for_cifar100(freezing=BACKBONE_FREEZING):
    model = timm.create_model("vit_small_patch16_224_dino", pretrained=True, num_classes=0)

    # Replace the head with CIFAR-100 classification head
    model.head = nn.Linear(model.num_features, 100)

    if freezing:
      # Freeze all parameters except head
      for param in model.parameters():
          param.requires_grad = False

      # Unfreeze only the head
      for param in model.head.parameters():
          param.requires_grad = True
    else:

      for param in model.parameters():
          param.requires_grad = True

      for param in model.head.parameters():
          param.requires_grad = True

    return model

In [9]:
model = create_dino_vit_s16_for_cifar100()

print(next(model.parameters()).device)
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"Trainable params: {trainable:,} / {total:,}")

initial_parameters = [val.cpu().numpy() for _, val in model.state_dict().items()]
flower_parameters = ndarrays_to_parameters(initial_parameters)

  model = create_fn(
The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


cpu
Trainable params: 21,704,164 / 21,704,164


In [10]:
!wandb login f8ad3703c9023ee3f86c7242a87d9280b6c031fb

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin


In [11]:
data_distribution = "IID_DATA" if IID else "NON_IID_DATA"
run_name = f"FEDERATED_{data_distribution}_TALOS_{datetime.now().strftime('%Y%m%d-%H%M%S')}"

if CLIENT_TYPE == clients.ClientType.TALOS:
    client_app = ClientApp(
        client_fn=clients.build_client_talos_fn(
            use_iid=IID,
            optimizer_type=OPTIMIZER_TYPE,
            optimizer_config=OPTIMIZER_CONFIG,
            scheduler_type=SCHEDULER_TYPE,
            scheduler_config=SCHEDULER_CONFIG,
            iid_partitions=client_datasets_iid,
            non_iid_partitions=client_datasets_noniid,
            model_fn=create_dino_vit_s16_for_cifar100,
            device=DEVICE,
            valset=valset,
            batch_size=BATCH_SIZE,
            local_epochs=LOCAL_EPOCHS,
            talos_config=TALOS_CONFIG
        )
    )
    client_config = None
elif CLIENT_TYPE == clients.ClientType.TALOSPROX:
    client_app = ClientApp(
        client_fn=clients.build_client_talos_prox_fn(
            use_iid=IID,
            optimizer_type=OPTIMIZER_TYPE,
            optimizer_config=OPTIMIZER_CONFIG,
            scheduler_type=SCHEDULER_TYPE,
            scheduler_config=SCHEDULER_CONFIG,
            iid_partitions=client_datasets_iid,
            non_iid_partitions=client_datasets_noniid,
            model_fn=create_dino_vit_s16_for_cifar100,
            device=DEVICE,
            valset=valset,
            batch_size=BATCH_SIZE,
            local_epochs=LOCAL_EPOCHS,
            talos_config=TALOS_CONFIG,
            mu=MU
        )
    )
    client_config = {"mu": MU}

if STRATEGY_TYPE == strategies.StrategiesType.FEDAVG:
    strategy_config = None
elif STRATEGY_TYPE == strategies.StrategiesType.METAFEDAVG:
    strategy_config = {"inner_lr": INNER_LR}
else:
    strategy_config = {"eta": ETA, "eta_l": ETA_L, "tau": TAU, "beta_1": BETA_1, "beta_2": BETA_2,}


if not IID:
  DataConfig = {"use_iid": IID, "num_classes": NC}
else:
  DataConfig = {"use_iid": IID}


logger = FederatedWandBLogger(
    project_name="federated-learning-project",
    run_name=run_name,
    global_config={
        # Federated Learning Configuration
        "data_config": DataConfig,
        "local_epochs": LOCAL_EPOCHS,
        "batch_size": BATCH_SIZE,
        "num_clients": NUM_CLIENTS,
        "fraction_fit": FRACTION_FIT,
        "fraction_evaluate": FRACTION_EVAL,
        "num_rounds": NUM_ROUNDS,
        "backbone_freezing": BACKBONE_FREEZING,
        "client_type": CLIENT_TYPE.value,
        "client_config": client_config,
        "strategy_type": STRATEGY_TYPE.value,
        "strategy_config": strategy_config,
        "val_split": VAL_SPLIT,

        # Optimizer Configuration
        "optimizer": OPTIMIZER_TYPE.value,
        "optimizer_config": OPTIMIZER_CONFIG,

        # Scheduler Configuration
        "scheduler": SCHEDULER_TYPE,
        "scheduler_config": SCHEDULER_CONFIG,

        "talos_config": TALOS_CONFIG,
    }
)

if STRATEGY_TYPE == strategies.StrategiesType.FEDAVG:
    strategy = strategies.FedAvgStandard(
        logger=logger,
        initial_parameters=flower_parameters,
        fraction_fit=FRACTION_FIT,
        min_fit_clients=int(FRACTION_FIT*NUM_CLIENTS),
        min_evaluate_clients=int(FRACTION_EVAL*NUM_CLIENTS),
        fraction_evaluate=FRACTION_EVAL,
        min_available_clients=NUM_CLIENTS,
        evaluate_metrics_aggregation_fn = lambda metrics: {
            "accuracy": sum(num * m["val_accuracy"] for num, m in metrics) / sum(num for num, _ in metrics)
        }
    )
elif STRATEGY_TYPE == strategies.StrategiesType.METAFEDAVG:
    strategy = strategies.MetaFedAvg(
        logger=logger,
        inner_lr=INNER_LR,
        initial_parameters=flower_parameters,
        fraction_fit=FRACTION_FIT,
        min_fit_clients=int(FRACTION_FIT*NUM_CLIENTS),
        min_evaluate_clients=int(FRACTION_EVAL*NUM_CLIENTS),
        fraction_evaluate=FRACTION_EVAL,
        min_available_clients=NUM_CLIENTS,
        evaluate_metrics_aggregation_fn = lambda metrics: {
            "accuracy": sum(num * m["val_accuracy"] for num, m in metrics) / sum(num for num, _ in metrics)
        }
    )
else:
    strategy = strategies.FedYogiStandard(
        logger=logger,
        initial_parameters=flower_parameters,
        fraction_fit=FRACTION_FIT,
        min_fit_clients=int(FRACTION_FIT*NUM_CLIENTS),
        min_evaluate_clients=int(FRACTION_EVAL*NUM_CLIENTS),
        fraction_evaluate=FRACTION_EVAL,
        min_available_clients=NUM_CLIENTS,
        eta=ETA,
        eta_l=ETA_L,
        tau=TAU,
        beta_1=BETA_1,
        beta_2=BETA_2,
        evaluate_metrics_aggregation_fn = lambda metrics: {
            "accuracy": sum(num * m["val_accuracy"] for num, m in metrics) / sum(num for num, _ in metrics)
        }
    )

def server_fn(context: Context) -> ServerAppComponents:
    """Construct components that set the ServerApp behaviour.

    You can use the settings in `context.run_config` to parameterize the
    construction of all elements (e.g the strategy or the number of rounds)
    wrapped in the returned ServerAppComponents object.
    """
    config = ServerConfig(num_rounds=NUM_ROUNDS)
    return ServerAppComponents(strategy=strategy, config=config)


server_app = ServerApp(server_fn=server_fn)

  return LooseVersion(v) >= LooseVersion(check)
[34m[1mwandb[0m: Currently logged in as: [33ms339170[0m ([33mpolito-fl[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
  self.scope.user = {"email": email}  # noqa


In [12]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

backend_config = {"client_resources": {"num_cpus": 1, "num_gpus": 0.0}}

if device.type == "cuda":
    backend_config = {"client_resources": {"num_cpus": 1, "num_gpus": 1.0}}

In [None]:
run_simulation(
    server_app=server_app,
    client_app=client_app,
    num_supernodes=NUM_CLIENTS,
    backend_config=backend_config
)

DEBUG:flwr:Asyncio event loop already running.
[92mINFO [0m:      Starting Flower ServerApp, config: num_rounds=8, no round_timeout
[92mINFO [0m:      
[92mINFO [0m:      [INIT]
[92mINFO [0m:      Using initial global parameters provided by strategy
[92mINFO [0m:      Starting evaluation of initial global parameters
[92mINFO [0m:      Evaluation returned no results (`None`)
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 1]
[92mINFO [0m:      configure_fit: strategy sampled 10 clients (out of 100)
[36m(pid=5073)[0m 2025-05-15 13:47:20.720685: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
[36m(pid=5073)[0m E0000 00:00:1747316840.755845    5073 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
[36m(pid=5073)[0m E0000 00:00:1747316840.766880    507

[36m(ClientAppActor pid=5073)[0m LOG: Initializing client with CID=11
[36m(ClientAppActor pid=5073)[0m 11-LOG: Data partition assigned to client 11 -> Non-IID
[36m(ClientAppActor pid=5073)[0m 11-LOG: Model initialized for client 11
[36m(ClientAppActor pid=5073)[0m 11-LOG: Dataloaders initialized for client 11


[36m(ClientAppActor pid=5073)[0m   self.scaler = GradScaler()


[36m(ClientAppActor pid=5073)[0m ⚙️ PFedEdit Mode Activated
[36m(ClientAppActor pid=5073)[0m 11-LOG: PFedEdit Local Layer Selection
[36m(ClientAppActor pid=5073)[0m 11-LOG: Selected Layers (Min Loss): [0, 1, 2]
[36m(ClientAppActor pid=5073)[0m 🟢 Pruning will be applied to the entire model.
[36m(ClientAppActor pid=5073)[0m 11-LOG: Starting TaLoS Mask Calibration with mode: least_sensitive
[36m(ClientAppActor pid=5073)[0m 🔎 Starting multi-round calibration for mode 'full'.
[36m(ClientAppActor pid=5073)[0m 🌀 Calibration Round 1/4
[36m(ClientAppActor pid=5073)[0m 📝 Calculating Fisher Information on 3 batches...


[36m(ClientAppActor pid=5073)[0m   with autocast():


[36m(ClientAppActor pid=5073)[0m ✅ Fisher Information Computation Completed.
[36m(ClientAppActor pid=5073)[0m 🌀 Calibration Round 2/4
[36m(ClientAppActor pid=5073)[0m 📝 Calculating Fisher Information on 3 batches...
[36m(ClientAppActor pid=5073)[0m ✅ Fisher Information Computation Completed.
[36m(ClientAppActor pid=5073)[0m 🌀 Calibration Round 3/4
[36m(ClientAppActor pid=5073)[0m 📝 Calculating Fisher Information on 3 batches...
[36m(ClientAppActor pid=5073)[0m ✅ Fisher Information Computation Completed.
[36m(ClientAppActor pid=5073)[0m 🌀 Calibration Round 4/4
[36m(ClientAppActor pid=5073)[0m 📝 Calculating Fisher Information on 3 batches...
[36m(ClientAppActor pid=5073)[0m ✅ Fisher Information Computation Completed.
[36m(ClientAppActor pid=5073)[0m ✅ Mask Calibration Completed!
[36m(ClientAppActor pid=5073)[0m 🔍 Mapping parameters to their masks...
[36m(ClientAppActor pid=5073)[0m ✅ Mapped 150 parameters to masks.
[36m(ClientAppActor pid=5073)[0m 11-LOG: Init

[36m(ClientAppActor pid=5073)[0m   with autocast():


[36m(ClientAppActor pid=5073)[0m 11-LOG: Starting epoch 2/4
[36m(ClientAppActor pid=5073)[0m 11-LOG: Starting epoch 3/4
[36m(ClientAppActor pid=5073)[0m 11-LOG: Starting epoch 4/4
[36m(ClientAppActor pid=5073)[0m 11-LOG: Completed local training - Loss: 5.8209 | Accuracy: 0.0367
[36m(ClientAppActor pid=5073)[0m LOG: Initializing client with CID=15
[36m(ClientAppActor pid=5073)[0m 15-LOG: Data partition assigned to client 15 -> Non-IID
[36m(ClientAppActor pid=5073)[0m 15-LOG: Model initialized for client 15
[36m(ClientAppActor pid=5073)[0m 15-LOG: Dataloaders initialized for client 15
[36m(ClientAppActor pid=5073)[0m ⚙️ PFedEdit Mode Activated
[36m(ClientAppActor pid=5073)[0m 15-LOG: PFedEdit Local Layer Selection
[36m(ClientAppActor pid=5073)[0m 15-LOG: Selected Layers (Min Loss): [0, 1, 2]
[36m(ClientAppActor pid=5073)[0m 🟢 Pruning will be applied to the entire model.
[36m(ClientAppActor pid=5073)[0m 15-LOG: Starting TaLoS Mask Calibration with mode: least_se

In [14]:
final_parameters = strategy.latest_parameters

# Convert Flower parameters to list of numpy arrays
ndarrays = parameters_to_ndarrays(final_parameters)

# Load them into a PyTorch model
iid_model = create_dino_vit_s16_for_cifar100()
iid_model.to(device)
state_dict = OrderedDict(
    (key, torch.tensor(val)) for key, val in zip(iid_model.state_dict().keys(), ndarrays)
)
iid_model.load_state_dict(state_dict)

correct, total, loss_total = 0, 0, 0.0
criterion = torch.nn.CrossEntropyLoss()

with torch.no_grad():
    for images, labels in testloader:
        images, labels = images.to(device), labels.to(device)
        outputs = iid_model(images)
        loss = criterion(outputs, labels)

        loss_total += loss.item() * labels.size(0)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

test_loss = loss_total / total
test_accuracy = correct / total

logger.log_global_metrics({
    "test_loss": test_loss,
    "test_accuracy": test_accuracy
}, round_number=NUM_ROUNDS)

model_save_path = "final_federated_model.pth"
torch.save(iid_model.state_dict(), model_save_path)

logger.log_model(iid_model, path=model_save_path)

logger.finish()

print(f"✅ Test Accuracy: {test_accuracy:.4f} | Test Loss: {test_loss:.4f}")

  model = create_fn(


0,1
client_0/train_accuracy,▂▁▄▅▅▇▇█
client_0/train_loss,██▅▅▃▂▂▁
client_1/train_accuracy,▁▂▄▅▅▆▇█
client_1/train_loss,█▇▅▃▄▃▁▁
client_2/train_accuracy,▁▁▄▅▅▆█▇
client_2/train_loss,██▄▃▄▂▁▂
client_3/train_accuracy,▁▄▄▅▅▇▇█
client_3/train_loss,█▅▅▄▃▂▂▁
client_4/train_accuracy,▁▂▅▄▆▆██
client_4/train_loss,█▇▄▄▃▃▁▁

0,1
client_0/train_accuracy,0.84412
client_0/train_loss,0.54479
client_1/train_accuracy,0.87695
client_1/train_loss,0.42577
client_2/train_accuracy,0.80312
client_2/train_loss,0.72841
client_3/train_accuracy,0.84741
client_3/train_loss,0.53301
client_4/train_accuracy,0.87139
client_4/train_loss,0.45359


✅ Test Accuracy: 0.4784 | Test Loss: 2.1179
