In [None]:
!pip install -q flwr torch torchvision tensorboard
!pip install -U "flwr[simulation]"
!pip install timm



In [None]:
from datetime import datetime
from functools import partial
from flwr.server import ServerConfig
import random
import numpy as np
import torch.nn as nn
from torchvision import transforms
import timm
import torch
import os
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import random_split, DataLoader
from torch.utils.tensorboard import SummaryWriter
import flwr as fl
from flwr.simulation import run_simulation
from flwr.common import Context
from flwr.server import ServerApp, ServerAppComponents, ServerConfig
from flwr.client import ClientApp
from flwr.common import parameters_to_ndarrays
from collections import OrderedDict
from wandb_logger import FederatedWandBLogger
import data_utils
import clients
import strategies
import data_preprocessing

In [None]:
def build_optimizer_config(optimizer_type):
    """Assemble the keyword arguments for the optimizer named by ``optimizer_type``.

    Hyper-parameter values are taken from the module-level constants (``LR``,
    ``MOMENTUM``, ``WEIGHT_DECAY``, ``NESTEROV``, ``BETAS``, ``EPSILON``,
    ``ALPHA``, ``CENTERED``). They are looked up when this function runs, so the
    configuration cell must define them before calling it. Each optimizer only
    reads the subset of constants it needs.

    Note: despite the printed message, no optimizer object is created here; only
    the configuration dict is built.

    Args:
        optimizer_type: A ``clients.OptimizerType`` member.

    Returns:
        dict: keyword arguments for the selected optimizer.

    Raises:
        ValueError: If ``optimizer_type`` is not a supported member.
    """
    kinds = clients.OptimizerType
    if optimizer_type == kinds.SGD:
        config = dict(lr=LR, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY, nesterov=NESTEROV)
    elif optimizer_type == kinds.ADAM:
        config = dict(lr=LR, betas=BETAS, weight_decay=WEIGHT_DECAY, eps=EPSILON)
    elif optimizer_type == kinds.ADAMW:
        config = dict(lr=LR, betas=BETAS, weight_decay=WEIGHT_DECAY)
    elif optimizer_type == kinds.RMSPROP:
        config = dict(
            lr=LR,
            alpha=ALPHA,
            eps=EPSILON,
            weight_decay=WEIGHT_DECAY,
            momentum=MOMENTUM,
            centered=CENTERED,
        )
    elif optimizer_type == kinds.ADAGRAD:
        config = dict(lr=LR, weight_decay=WEIGHT_DECAY, eps=EPSILON)
    else:
        raise ValueError(f"Unsupported optimizer type: {optimizer_type}")

    print(f"Optimizer '{optimizer_type}' initialized successfully.")
    return config

In [None]:
def build_scheduler_config(scheduler_type):
    """Assemble the keyword arguments for the LR scheduler named by ``scheduler_type``.

    Values come from module-level constants (``T_MAX``, ``ETA_MIN``,
    ``STEP_SIZE``, ``GAMMA``, ``MILESTONES``, ``FACTOR``, ``PATIENCE``,
    ``THRESHOLD``, ``TOTAL_ITERS``, ``LR_END``) read at call time. A few entries
    are fixed here rather than configurable:

    - ``COSINE_RESTART`` reuses ``T_MAX`` as ``T_0`` with ``T_mult=1``.
    - ``REDUCE_ON_PLATEAU`` uses ``mode="min"``.
    - ``CONSTANT`` uses ``factor=1.0``.
    - ``LINEAR`` starts at ``start_factor=0.5``.

    Note: only the configuration dict is built; no scheduler object is created.

    Args:
        scheduler_type: A ``clients.SchedulerType`` member.

    Returns:
        dict: keyword arguments for the selected scheduler.

    Raises:
        ValueError: If ``scheduler_type`` is not a supported member.
    """
    kinds = clients.SchedulerType
    if scheduler_type == kinds.COSINE:
        config = dict(T_max=T_MAX, eta_min=ETA_MIN)
    elif scheduler_type == kinds.COSINE_RESTART:
        config = dict(T_0=T_MAX, T_mult=1, eta_min=ETA_MIN)
    elif scheduler_type == kinds.STEP:
        config = dict(step_size=STEP_SIZE, gamma=GAMMA)
    elif scheduler_type == kinds.MULTISTEP:
        config = dict(milestones=MILESTONES, gamma=GAMMA)
    elif scheduler_type == kinds.EXPONENTIAL:
        config = dict(gamma=GAMMA)
    elif scheduler_type == kinds.REDUCE_ON_PLATEAU:
        config = dict(mode="min", factor=FACTOR, patience=PATIENCE, threshold=THRESHOLD)
    elif scheduler_type == kinds.CONSTANT:
        config = dict(factor=1.0, total_iters=TOTAL_ITERS)
    elif scheduler_type == kinds.LINEAR:
        config = dict(start_factor=0.5, end_factor=LR_END, total_iters=TOTAL_ITERS)
    else:
        raise ValueError(f"Unsupported scheduler type: {scheduler_type}")

    print(f"Scheduler '{scheduler_type}' initialized successfully.")
    return config

In [None]:
def seed_everything(seed=42):
    """Make runs reproducible by seeding every random source the notebook uses.

    - Seeds Python's ``random``, NumPy's global RNG, and PyTorch's CPU and
      all-CUDA-device generators.
    - Exports ``PYTHONHASHSEED``. Python fixes hash randomization at interpreter
      start-up, so this only affects processes launched afterwards
      (presumably the simulation workers, if they inherit the environment —
      TODO confirm).
    - Forces deterministic cuDNN kernels with benchmarking turned off.

    Args:
        seed: Integer seed applied to every generator.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)

    for apply_seed in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic, cudnn.benchmark = True, False


seed_everything(42)

In [None]:
# All training settings live in this cell.
# build_optimizer_config / build_scheduler_config (defined in earlier cells) read
# these globals when they are called at the bottom of this cell, so every
# constant must be assigned above those two calls.

# GENERAL
BATCH_SIZE = 32  # batch size for the DataLoaders and the client factories
EPOCHS = 15  # NOTE(review): not referenced by the visible cells; possibly unused
VAL_SPLIT = 0.1  # passed to CIFAR100Pipeline as val_split
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # device handed to every client
BACKBONE_FREEZING = True  # True -> only the model's classification head is trainable

# FL
NUM_CLIENTS = 100  # number of data partitions and of simulated supernodes
NC = 50  # passed to non_iid_split and logged as "num_classes"; presumably classes per client in the non-IID split (confirm in data_utils)
LOCAL_EPOCHS = 4  # local epochs each sampled client runs per round
NUM_ROUNDS = 8  # number of federated rounds (ServerConfig.num_rounds)
FRACTION_FIT = 0.1  # fraction of clients sampled for training each round
FRACTION_EVAL = 0.03  # fraction of clients sampled for evaluation each round
CLIENT_TYPE = clients.ClientType.FEDPROX  # selects the STANDARD or FedProx client factory
STRATEGY_TYPE = strategies.StrategiesType.FEDAVG  # FEDAVG; any other value selects FedYogiStandard
IID = False  # False -> clients train on the non-IID partition

# LOSS
SMOOTHING=0.1  # presumably label smoothing; not referenced by the visible cells (confirm in clients)

# Optimizer Hyperparameters (read by build_optimizer_config; each optimizer uses only its own subset)
LR = 0.03  # all optimizers
MOMENTUM = 0.9  # SGD, RMSPROP
WEIGHT_DECAY = 1e-4  # all optimizers
NESTEROV = False  # SGD
BETAS = (0.9, 0.999)  # ADAM, ADAMW
EPSILON = 1e-8  # ADAM, RMSPROP, ADAGRAD
ALPHA = 0.99  # RMSPROP
CENTERED = False  # RMSPROP
OPTIMIZER_TYPE  = clients.OptimizerType.SGD

# Scheduler Hyperparameters (read by build_scheduler_config)
T_MAX = 8  # COSINE T_max / COSINE_RESTART T_0; equals NUM_ROUNDS, presumably because the scheduler steps once per round (confirm in clients)
ETA_MIN = 1e-6  # COSINE, COSINE_RESTART
STEP_SIZE = 30  # STEP
GAMMA = 0.1  # STEP, MULTISTEP, EXPONENTIAL
MILESTONES = [50, 75]  # MULTISTEP
FACTOR = 0.1  # REDUCE_ON_PLATEAU
PATIENCE = 5  # REDUCE_ON_PLATEAU
THRESHOLD = 1e-4  # REDUCE_ON_PLATEAU
TOTAL_ITERS = 10  # CONSTANT, LINEAR
LR_END = 0.0  # LINEAR end_factor
SCHEDULER_TYPE  = clients.SchedulerType.COSINE

# FedProx (used only when CLIENT_TYPE is FEDPROX)

MU = 0.1  # passed as `mu` to clients.build_client_fedprox_fn

# YOGI (passed to FedYogiStandard; used only when STRATEGY_TYPE is not FEDAVG)

ETA = 0.01           # FedYogi `eta`: server-side learning rate
ETA_L = 0.0316      # FedYogi `eta_l` (client-side LR); NOTE(review): the client optimizers are configured from LR, not this value
TAU = 0.1            # FedYogi `tau`: adaptivity constant; NOTE(review): Flower's default is 1e-3, so confirm 0.1 is intended
BETA_1 = 0.9         # FedYogi `beta_1`: first-moment (momentum) decay
BETA_2 = 0.999       # FedYogi `beta_2`: second-moment decay; NOTE(review): Flower's FedYogi default is 0.99, so confirm

OPTIMIZER_CONFIG = build_optimizer_config(OPTIMIZER_TYPE)
SCHEDULER_CONFIG = build_scheduler_config(SCHEDULER_TYPE)

Optimizer 'OptimizerType.SGD' initialized successfully.
Scheduler 'SchedulerType.COSINE' initialized successfully.


In [None]:
# Load CIFAR-100 train/val/test splits through the project's preprocessing pipeline.
# NOTE(review): use_augment=True; presumably augmentation only touches the training
# split (confirm in data_preprocessing.CIFAR100Pipeline).
pipeline = data_preprocessing.CIFAR100Pipeline(val_split=VAL_SPLIT, use_augment=True)
trainset, valset, testset = pipeline.run_pipeline()

trainloader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)
valloader = DataLoader(valset, batch_size=BATCH_SIZE)
testloader = DataLoader(testset, batch_size=BATCH_SIZE)

# Shard the training set across clients. The IID partition is always built
# because both partitions are handed to the client factory. The non-IID one is
# only built when IID is False.
client_datasets_iid = data_utils.iid_split(trainset, NUM_CLIENTS)
client_datasets_noniid = None
if not IID:
    # 100 = number of CIFAR-100 classes (presumably the total class count expected
    # by non_iid_split; confirm in data_utils).
    client_datasets_noniid = data_utils.non_iid_split(trainset, NUM_CLIENTS, NC, 100)

# Report only the splits that were actually created.
splits_created = "IID split created." if IID else "IID and Non-IID splits created."
print(f"Setup complete. {splits_created}")

Setup complete. IID and Non-IID splits created.


In [None]:
# Model factory: DINO-pretrained ViT-S/16 backbone with a new classification head.
def create_dino_vit_s16_for_cifar100(freezing=None, num_classes=100):
    """Build a DINO-pretrained ViT-S/16 (timm) with a fresh linear classification head.

    Args:
        freezing: If truthy, every parameter is frozen except those of the new head.
            ``None`` (the default) means "use the current module-level
            ``BACKBONE_FREEZING``". The value is looked up at call time instead of
            being bound as a default when this cell runs, so editing the
            configuration cell takes effect without re-running this cell.
        num_classes: Output size of the new head (100 for CIFAR-100).

    Returns:
        The timm model with ``model.head`` replaced by
        ``nn.Linear(model.num_features, num_classes)``.
    """
    if freezing is None:
        freezing = BACKBONE_FREEZING

    # num_classes=0 asks timm for the pretrained backbone without its classifier.
    model = timm.create_model("vit_small_patch16_224_dino", pretrained=True, num_classes=0)

    # Attach the new classification head.
    model.head = nn.Linear(model.num_features, num_classes)

    if freezing:
        # Freeze everything, then re-enable gradients for the head only.
        for param in model.parameters():
            param.requires_grad = False
        for param in model.head.parameters():
            param.requires_grad = True

    return model

In [None]:
# Authenticate with Weights & Biases without storing a credential in the notebook.
# The key comes from the WANDB_API_KEY environment variable, or is prompted for
# with hidden input.
# NOTE(review): the key that was previously hard-coded on this line is exposed in
# the notebook's history; revoke/rotate it on wandb.ai.
import getpass
import subprocess

wandb_api_key = os.environ.get("WANDB_API_KEY") or getpass.getpass("W&B API key: ")
# Pass an argument list (no shell) so the key is never interpolated into a shell string.
subprocess.run(["wandb", "login", wandb_api_key], check=True)

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin


In [None]:
# Unique, human-readable W&B run name: data distribution + launch timestamp.
data_distribution = "IID_DATA" if IID else "NON_IID_DATA"
run_name = f"FEDERATED_{data_distribution}_{datetime.now().strftime('%Y%m%d-%H%M%S')}"

# Arguments shared by both client factories; only the FedProx one also takes `mu`.
# NOTE(review): every client receives the same global `valset`, so evaluating on
# several sampled clients presumably repeats the same computation. The round-1
# evaluation log shows identical metrics for different clients; confirm in clients.
common_client_kwargs = dict(
    use_iid=IID,
    optimizer_type=OPTIMIZER_TYPE,
    scheduler_type=SCHEDULER_TYPE,
    optimizer_config=OPTIMIZER_CONFIG,
    scheduler_config=SCHEDULER_CONFIG,
    iid_partitions=client_datasets_iid,
    non_iid_partitions=client_datasets_noniid,
    model_fn=create_dino_vit_s16_for_cifar100,
    device=DEVICE,
    valset=valset,
    local_epochs=LOCAL_EPOCHS,
    batch_size=BATCH_SIZE,
)

if CLIENT_TYPE == clients.ClientType.STANDARD:
    client_app = ClientApp(client_fn=clients.build_client_fn(**common_client_kwargs))
    client_config = None
elif CLIENT_TYPE == clients.ClientType.FEDPROX:
    client_app = ClientApp(
        client_fn=clients.build_client_fedprox_fn(**common_client_kwargs, mu=MU)
    )
    client_config = {"mu": MU}
else:
    # Fail loudly instead of silently falling back to FedProx for unknown types.
    raise ValueError(f"Unsupported client type: {CLIENT_TYPE}")


# Server-side hyper-parameters recorded in the run config. FedAvg has none;
# any other strategy type records the Yogi settings.
strategy_config = (
    None
    if STRATEGY_TYPE == strategies.StrategiesType.FEDAVG
    else {"eta": ETA, "eta_l": ETA_L, "tau": TAU, "beta_1": BETA_1, "beta_2": BETA_2}
)

# Data-partition description for the run config. NC (presumably the number of
# classes per client) is only recorded for the non-IID split.
DataConfig = {"use_iid": IID}
if not IID:
    DataConfig["num_classes"] = NC

# Experiment tracker. global_config should hold every setting needed to tell runs
# apart (presumably forwarded to the W&B run config; confirm in wandb_logger).
logger = FederatedWandBLogger(
    project_name="federated-learning-project",
    run_name=run_name,
    global_config={
        # Federated Learning Configuration
        "data_config": DataConfig,
        "num_rounds": NUM_ROUNDS,  # previously missing: runs with different round budgets were indistinguishable
        "local_epochs": LOCAL_EPOCHS,
        "batch_size": BATCH_SIZE,
        "num_clients": NUM_CLIENTS,
        "fraction_fit": FRACTION_FIT,
        "fraction_evaluate": FRACTION_EVAL,
        "backbone_freezing": BACKBONE_FREEZING,
        "client_type": CLIENT_TYPE.value,
        "client_config": client_config,
        "strategy_type": STRATEGY_TYPE.value,
        "strategy_config": strategy_config,
        "val_split": VAL_SPLIT,

        # Optimizer Configuration
        "optimizer": OPTIMIZER_TYPE.value,
        "optimizer_config": OPTIMIZER_CONFIG,

        # Scheduler Configuration
        "scheduler": SCHEDULER_TYPE.value,
        "scheduler_config": SCHEDULER_CONFIG,
    }
)

def weighted_val_accuracy(metrics):
    """Aggregate per-client evaluation metrics into an example-weighted accuracy.

    Args:
        metrics: List of ``(num_examples, metrics_dict)`` pairs, one per evaluating
            client. Each ``metrics_dict`` must contain ``"val_accuracy"``.

    Returns:
        ``{"accuracy": float}``: the mean of ``val_accuracy`` weighted by
        ``num_examples``. When no examples were evaluated, the accuracy is reported
        as 0.0 instead of raising ZeroDivisionError inside the server loop.
    """
    total_examples = sum(num for num, _ in metrics)
    if total_examples == 0:
        return {"accuracy": 0.0}
    weighted_sum = sum(num * m["val_accuracy"] for num, m in metrics)
    return {"accuracy": weighted_sum / total_examples}


# Settings shared by both strategies. The minimum client counts mirror the
# sampling fractions. In Flower, min_available_clients is the number of
# connected clients required before sampling starts.
common_strategy_kwargs = dict(
    logger=logger,
    fraction_fit=FRACTION_FIT,
    min_fit_clients=int(FRACTION_FIT * NUM_CLIENTS),
    min_evaluate_clients=int(FRACTION_EVAL * NUM_CLIENTS),
    fraction_evaluate=FRACTION_EVAL,
    min_available_clients=NUM_CLIENTS,
    evaluate_metrics_aggregation_fn=weighted_val_accuracy,
)

if STRATEGY_TYPE == strategies.StrategiesType.FEDAVG:
    strategy = strategies.FedAvgStandard(**common_strategy_kwargs)
else:
    # NOTE(review): every non-FEDAVG strategy type is mapped to FedYogi.
    strategy = strategies.FedYogiStandard(
        **common_strategy_kwargs,
        eta=ETA,
        eta_l=ETA_L,
        tau=TAU,
        beta_1=BETA_1,
        beta_2=BETA_2,
    )

def server_fn(context: Context) -> ServerAppComponents:
    """Provide the Flower ServerApp with its strategy and round budget.

    ``context`` is required by the ServerApp callback signature but is not read.
    The strategy is the module-level ``strategy`` object and the number of rounds
    comes from ``NUM_ROUNDS``.
    """
    return ServerAppComponents(
        config=ServerConfig(num_rounds=NUM_ROUNDS),
        strategy=strategy,
    )


server_app = ServerApp(server_fn=server_fn)

  return LooseVersion(v) >= LooseVersion(check)
[34m[1mwandb[0m: Currently logged in as: [33ms339170[0m ([33mpolito-fl[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
  self.scope.user = {"email": email}  # noqa


In [None]:
# Reuse the device chosen in the configuration cell, so the simulation resources
# and the final evaluation always agree with the DEVICE handed to the clients.
device = DEVICE

# Resources reserved per simulated client. num_gpus=1.0 reserves a whole GPU per
# client, so on a single-GPU machine clients presumably run one at a time.
backend_config = {
    "client_resources": {
        "num_cpus": 1,
        "num_gpus": 1.0 if device.type == "cuda" else 0.0,
    }
}

In [None]:
# Launch the Flower simulation: one supernode per client partition
# (NUM_CLIENTS), each limited by the per-client resources in backend_config.
run_simulation(
    client_app=client_app,
    server_app=server_app,
    backend_config=backend_config,
    num_supernodes=NUM_CLIENTS,
)

DEBUG:flwr:Asyncio event loop already running.
[92mINFO [0m:      Starting Flower ServerApp, config: num_rounds=8, no round_timeout
[92mINFO [0m:      
[92mINFO [0m:      [INIT]
[92mINFO [0m:      Requesting initial parameters from one random client
[36m(pid=50957)[0m 2025-05-12 11:17:19.802809: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
[36m(pid=50957)[0m E0000 00:00:1747048639.840505   50957 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
[36m(pid=50957)[0m E0000 00:00:1747048639.851141   50957 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


[36m(ClientAppActor pid=50957)[0m LOG: Initializing client with CID=80
[36m(ClientAppActor pid=50957)[0m 80-LOG: Data partition assigned to client 80 -> Non-IID


[36m(ClientAppActor pid=50957)[0m   model = create_fn(


[36m(ClientAppActor pid=50957)[0m 80-LOG: Model initialized for client 80
[36m(ClientAppActor pid=50957)[0m 80-LOG: Dataloaders initialized for client 80


[36m(ClientAppActor pid=50957)[0m   self.scaler = GradScaler()
[92mINFO [0m:      Received initial parameters from one random client
[92mINFO [0m:      Starting evaluation of initial global parameters
[92mINFO [0m:      Evaluation returned no results (`None`)
[92mINFO [0m:      
[92mINFO [0m:      [ROUND 1]
[92mINFO [0m:      configure_fit: strategy sampled 10 clients (out of 100)


[36m(ClientAppActor pid=50957)[0m LOG: Initializing client with CID=12
[36m(ClientAppActor pid=50957)[0m 12-LOG: Data partition assigned to client 12 -> Non-IID
[36m(ClientAppActor pid=50957)[0m 12-LOG: Model initialized for client 12
[36m(ClientAppActor pid=50957)[0m 12-LOG: Dataloaders initialized for client 12
[36m(ClientAppActor pid=50957)[0m 12-LOG: Starting local training round with FedProx (mu=0.1)
[36m(ClientAppActor pid=50957)[0m 12-LOG: Starting epoch 1/4


[36m(ClientAppActor pid=50957)[0m   with autocast():


[36m(ClientAppActor pid=50957)[0m 12-LOG: Starting epoch 2/4
[36m(ClientAppActor pid=50957)[0m 12-LOG: Starting epoch 3/4
[36m(ClientAppActor pid=50957)[0m 12-LOG: Starting epoch 4/4
[36m(ClientAppActor pid=50957)[0m 12-LOG: Completed local training - Loss: 25.9132 | Accuracy: 0.5361
[36m(ClientAppActor pid=50957)[0m LOG: Initializing client with CID=24
[36m(ClientAppActor pid=50957)[0m 24-LOG: Data partition assigned to client 24 -> Non-IID
[36m(ClientAppActor pid=50957)[0m 24-LOG: Model initialized for client 24
[36m(ClientAppActor pid=50957)[0m 24-LOG: Dataloaders initialized for client 24
[36m(ClientAppActor pid=50957)[0m 24-LOG: Starting local training round with FedProx (mu=0.1)
[36m(ClientAppActor pid=50957)[0m 24-LOG: Starting epoch 1/4
[36m(ClientAppActor pid=50957)[0m 24-LOG: Starting epoch 2/4
[36m(ClientAppActor pid=50957)[0m 24-LOG: Starting epoch 3/4
[36m(ClientAppActor pid=50957)[0m 24-LOG: Starting epoch 4/4
[36m(ClientAppActor pid=50957)[0m 

[92mINFO [0m:      aggregate_fit: received 10 results and 0 failures
[92mINFO [0m:      configure_evaluate: strategy sampled 3 clients (out of 100)


[36m(ClientAppActor pid=50957)[0m LOG: Initializing client with CID=14
[36m(ClientAppActor pid=50957)[0m 14-LOG: Data partition assigned to client 14 -> Non-IID
[36m(ClientAppActor pid=50957)[0m 14-LOG: Model initialized for client 14
[36m(ClientAppActor pid=50957)[0m 14-LOG: Dataloaders initialized for client 14
[36m(ClientAppActor pid=50957)[0m 14-LOG: Starting evaluation
[36m(ClientAppActor pid=50957)[0m 14-LOG: Evaluation completed - Val Loss: 10.1311 | Val Accuracy: 0.3090
[36m(ClientAppActor pid=50957)[0m LOG: Initializing client with CID=20
[36m(ClientAppActor pid=50957)[0m 20-LOG: Data partition assigned to client 20 -> Non-IID
[36m(ClientAppActor pid=50957)[0m 20-LOG: Model initialized for client 20
[36m(ClientAppActor pid=50957)[0m 20-LOG: Dataloaders initialized for client 20
[36m(ClientAppActor pid=50957)[0m 20-LOG: Starting evaluation
[36m(ClientAppActor pid=50957)[0m 20-LOG: Evaluation completed - Val Loss: 10.1311 | Val Accuracy: 0.3090
[36m(Clie

KeyboardInterrupt: 

In [None]:
# Evaluate the final aggregated global model on the test set, log the metrics,
# then save and upload the weights.
final_parameters = strategy.latest_parameters
if final_parameters is None:
    # NOTE(review): assumes the strategy keeps latest_parameters at None until a
    # round has been aggregated (e.g. when the simulation cell was interrupted).
    raise RuntimeError(
        "strategy.latest_parameters is None: the simulation has not produced "
        "aggregated parameters yet."
    )

# Convert Flower parameters to list of numpy arrays
ndarrays = parameters_to_ndarrays(final_parameters)

# Rebuild the architecture and load the aggregated weights into it. Arrays are
# matched to state_dict() entries by position, so the counts must agree; zip()
# alone would silently drop any surplus arrays.
# NOTE: despite its name, iid_model holds the global model for both IID and non-IID runs.
iid_model = create_dino_vit_s16_for_cifar100()
iid_model.to(device)
model_keys = list(iid_model.state_dict().keys())
if len(model_keys) != len(ndarrays):
    raise ValueError(
        f"Parameter count mismatch: model has {len(model_keys)} state_dict entries, "
        f"received {len(ndarrays)} arrays."
    )
state_dict = OrderedDict(
    (key, torch.tensor(val)) for key, val in zip(model_keys, ndarrays)
)
iid_model.load_state_dict(state_dict)
# nn.Module defaults to training mode; switch layers such as dropout to inference behaviour.
iid_model.eval()

correct, total, loss_total = 0, 0, 0.0
criterion = torch.nn.CrossEntropyLoss()

with torch.no_grad():
    for images, labels in testloader:
        images, labels = images.to(device), labels.to(device)
        outputs = iid_model(images)
        loss = criterion(outputs, labels)

        # criterion returns the batch mean; re-weight by batch size so the final
        # average is per-sample even when the last batch is smaller.
        loss_total += loss.item() * labels.size(0)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

if total == 0:
    raise RuntimeError("testloader yielded no samples; cannot compute test metrics.")

test_loss = loss_total / total
test_accuracy = correct / total

# NOTE(review): logged at step NUM_ROUNDS even if the simulation stopped early.
logger.log_global_metrics({
    "test_loss": test_loss,
    "test_accuracy": test_accuracy
}, round_number=NUM_ROUNDS)

model_save_path = "final_federated_model.pth"
torch.save(iid_model.state_dict(), model_save_path)

logger.log_model(iid_model, path=model_save_path)

logger.finish()

print(f"✅ Test Accuracy: {test_accuracy:.4f} | Test Loss: {test_loss:.4f}")