In [1]:
!pip install -q flwr torch torchvision tensorboard
!pip install -U "flwr[simulation]"
!pip install timm

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m540.0/540.0 kB[0m [31m5.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m4.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m35.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m32.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m39.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m2.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m6.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m12.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [None]:
# Upload the data_utils, clients, data_preprocessing, strategies and wandb_logger .py modules before running the import cell below.

In [2]:
from datetime import datetime
from functools import partial
from flwr.server import ServerConfig
import random
import numpy as np
import torch.nn as nn
from torchvision import transforms
import timm
import torch
import os
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import random_split, DataLoader
from torch.utils.tensorboard import SummaryWriter
import flwr as fl
from flwr.simulation import run_simulation
from flwr.common import Context
from flwr.server import ServerApp, ServerAppComponents, ServerConfig
from flwr.client import ClientApp
from flwr.common import parameters_to_ndarrays
from collections import OrderedDict
from wandb_logger import FederatedWandBLogger
import data_utils
import clients
import strategies
import data_preprocessing

In [3]:
def seed_everything(seed=42):
    """Seed every random number source used by this notebook.

    Seeds Python's ``random`` module, NumPy and PyTorch (CPU generator and all
    CUDA devices), exports ``PYTHONHASHSEED`` and puts cuDNN into
    deterministic, non-benchmarking mode.

    Args:
        seed: Integer seed applied to every generator. Defaults to 42.
    """
    # Environment variable write; independent of the generator seeding below.
    os.environ["PYTHONHASHSEED"] = str(seed)

    # Python, NumPy and PyTorch (CPU + CUDA) generators all get the same seed.
    for seed_fn in (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    ):
        seed_fn(seed)

    # Give up cuDNN autotuning in exchange for repeatable kernel selection.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


seed_everything(42)

In [4]:
# --- Federated setup ---
NUM_CLIENTS = 5    # number of client partitions / simulated supernodes
NUM_ROUNDS = 2     # server rounds passed to ServerConfig
LOCAL_EPOCHS = 2   # forwarded to the client builder as `local_epochs`

# --- Data ---
DATA_DIR = './data'
VAL_SPLIT = 0.1    # forwarded to CIFAR100Pipeline(val_split=...)
BATCH_SIZE = 256   # batch size of the train/val/test DataLoaders
NC = 20            # third argument of data_utils.non_iid_split; presumably classes per client — TODO confirm

In [5]:
# Locations for logs and checkpoints (absolute Colab paths).
LOG_DIR = "/content/fed_logs"
CKPT_PATH = "/content/fed_checkpoints/model.pt"

# Build the CIFAR-100 train / validation / test datasets.
pipeline = data_preprocessing.CIFAR100Pipeline(use_augment=True, val_split=VAL_SPLIT)
trainset, valset, testset = pipeline.run_pipeline()

# Only the training loader shuffles; validation and test keep dataset order.
trainloader = DataLoader(trainset, shuffle=True, batch_size=BATCH_SIZE)
valloader, testloader = (
    DataLoader(split, batch_size=BATCH_SIZE) for split in (valset, testset)
)

# Shard the training set across clients, once IID and once non-IID.
# The trailing 100 is presumably the number of CIFAR-100 classes — confirm in data_utils.
client_datasets_iid = data_utils.iid_split(trainset, NUM_CLIENTS)
client_datasets_noniid = data_utils.non_iid_split(trainset, NUM_CLIENTS, NC, 100)

print("Setup complete. IID and Non-IID splits created.")

100%|██████████| 169M/169M [00:05<00:00, 33.6MB/s]


Setup complete. IID and Non-IID splits created.


In [6]:
# Create model
def create_dino_vit_s16_for_cifar100(freezing=True, num_classes=100):
    """Build a DINO-pretrained ViT-S/16 with a new linear classification head.

    Args:
        freezing: When True, every parameter except those of the new head has
            ``requires_grad`` set to False, so only the head is trained.
        num_classes: Output size of the classification head. Defaults to 100
            (CIFAR-100), which is the value that was previously hard-coded.

    Returns:
        The timm model with ``model.head`` replaced by
        ``nn.Linear(model.num_features, num_classes)``.
    """
    # num_classes=0 requests the backbone without a classifier; the head is attached below.
    model = timm.create_model("vit_small_patch16_224_dino", pretrained=True, num_classes=0)

    # Attach a freshly initialised classification head.
    model.head = nn.Linear(model.num_features, num_classes)

    if freezing:
        # Freeze everything, then re-enable gradients for the head only.
        model.requires_grad_(False)
        model.head.requires_grad_(True)

    return model

In [7]:
# Keyword arguments handed to the client builder for the optimizer.
optimizer_config = dict(
    lr=0.01,              # learning rate (required)
    momentum=0.9,         # momentum term
    dampening=0.0,        # dampening for momentum (usually 0)
    weight_decay=0.0005,  # L2 regularization
    nesterov=False,       # whether to use Nesterov accelerated gradients
)

# Keyword arguments handed to the client builder for the LR scheduler.
scheduler_config = dict(
    T_max=100,      # required: max number of iterations (can be total batches or epochs)
    eta_min=0.0,    # minimum learning rate (default is 0.0)
    last_epoch=-1,  # -1 starts the schedule from scratch
)

In [8]:
# Note: the batch size is not passed here; it is set in the client class.
client_app = ClientApp(
    client_fn=clients.build_client_fn(
        # Model and hardware
        model_fn=create_dino_vit_s16_for_cifar100,
        device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        # Data: both partitionings are supplied, `use_iid` selects between them
        use_iid=True,
        iid_partitions=client_datasets_iid,
        non_iid_partitions=client_datasets_noniid,
        valset=valset,
        # Local training
        local_epochs=LOCAL_EPOCHS,
        optimizer_type=clients.OptimizerType.SGD,
        optimizer_config=optimizer_config,
        scheduler_type=clients.SchedulerType.COSINE,
        scheduler_config=scheduler_config,
    )
)

In [9]:
# Provide your Weights & Biases API key for the login cell below (do not commit the key together with the notebook).

In [10]:
!wandb login f8ad3703c9023ee3f86c7242a87d9280b6c031fb

[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin


In [12]:
# Timestamped so repeated runs get distinct names.
run_name = f"FL-run_iid_data_{datetime.now().strftime('%Y%m%d-%H%M%S')}"

logger = FederatedWandBLogger(
    run_name=run_name,
    project_name="federated-learning-project",
    global_config={
        # Optimizer hyper-parameters, copied from optimizer_config under the logged names.
        **{
            logged_name: optimizer_config.get(source_key)
            for logged_name, source_key in (
                ("learning_rate", "lr"),
                ("momentum", "momentum"),
                ("weight_decay", "weight_decay"),
                ("nesterov", "nesterov"),
            )
        },
        # Scheduler settings, copied from scheduler_config.
        "scheduler": clients.SchedulerType.COSINE,
        "scheduler_T_max": scheduler_config.get("T_max"),
        "scheduler_eta_min": scheduler_config.get("eta_min"),
        "scheduler_last_epoch": scheduler_config.get("last_epoch"),
        # Experiment setup.
        "use_iid": True,   # change to False if non-IID
        "local_epochs": LOCAL_EPOCHS,
        "batch_size": BATCH_SIZE,
        "optimizer": clients.OptimizerType.SGD,  # replace with AdamW if needed
        "rounds": NUM_ROUNDS,
    },
)

def weighted_val_accuracy(metrics):
    """Aggregate per-client validation accuracy, weighted by example count.

    Args:
        metrics: Iterable of ``(num_examples, client_metrics)`` pairs, where
            ``client_metrics`` contains a ``"val_accuracy"`` entry.

    Returns:
        ``{"accuracy": weighted_mean}``. If the clients report zero examples
        in total, the accuracy is reported as 0.0 instead of raising
        ZeroDivisionError.
    """
    total_examples = sum(num_examples for num_examples, _ in metrics)
    if total_examples == 0:
        # Nothing was evaluated; avoid dividing by zero.
        return {"accuracy": 0.0}
    weighted_sum = sum(
        num_examples * client_metrics["val_accuracy"]
        for num_examples, client_metrics in metrics
    )
    return {"accuracy": weighted_sum / total_examples}


# Every client must be available and takes part in each fit round.
strategy = strategies.FedAvg(
    logger=logger,
    fraction_fit=1.0,
    min_fit_clients=NUM_CLIENTS,
    min_available_clients=NUM_CLIENTS,
    evaluate_metrics_aggregation_fn=weighted_val_accuracy,
)

  return LooseVersion(v) >= LooseVersion(check)
[34m[1mwandb[0m: Currently logged in as: [33ms339170[0m ([33mpolito-fl[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
  self.scope.user = {"email": email}  # noqa


In [13]:
def server_fn(context: Context) -> ServerAppComponents:
    """Assemble the components that define the ServerApp's behaviour.

    Args:
        context: Flower run context. Its ``run_config`` could be used to
            parameterize the strategy or the number of rounds, but it is not
            read here.

    Returns:
        ServerAppComponents wrapping the notebook-level ``strategy`` and a
        ServerConfig that runs ``NUM_ROUNDS`` rounds.
    """
    return ServerAppComponents(
        config=ServerConfig(num_rounds=NUM_ROUNDS),
        strategy=strategy,
    )


server_app = ServerApp(server_fn=server_fn)

In [14]:
# Prefer the GPU when one is visible to PyTorch.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Each simulated client gets one CPU; on a CUDA machine it also gets a whole GPU.
backend_config = {
    "client_resources": {
        "num_cpus": 1,
        "num_gpus": 1.0 if device.type == "cuda" else 0.0,
    }
}

In [15]:
# Launch the simulation with one supernode per client.
run_simulation(
    num_supernodes=NUM_CLIENTS,
    backend_config=backend_config,
    client_app=client_app,
    server_app=server_app,
)

DEBUG:flwr:Asyncio event loop already running.
[92mINFO [0m:      Starting Flower ServerApp, config: num_rounds=2, no round_timeout
[92mINFO [0m:      
[92mINFO [0m:      [INIT]
[92mINFO [0m:      Requesting initial parameters from one random client
[36m(pid=1869)[0m 2025-05-07 11:11:00.758386: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
[36m(pid=1869)[0m E0000 00:00:1746616260.778781    1869 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
[36m(pid=1869)[0m E0000 00:00:1746616260.784901    1869 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
[36m(ClientAppActor pid=1869)[0m   model = create_fn(
[36m(ClientAppActor pid=1869)[0m   self.scaler = GradScaler()

In [19]:
# `latest_parameters` is an attribute of the project's strategy; presumably the
# most recently aggregated global parameters — verify in strategies.FedAvg.
final_parameters = strategy.latest_parameters
if final_parameters is None:
    raise RuntimeError(
        "strategy.latest_parameters is None: run the simulation cell before evaluating."
    )

# Convert Flower parameters to a list of numpy arrays
ndarrays = parameters_to_ndarrays(final_parameters)

# Rebuild the architecture and load the aggregated weights into it.
iid_model = create_dino_vit_s16_for_cifar100()
iid_model.to(device)

# Arrays are matched to state_dict entries by position, so the counts must
# agree; zip() would otherwise silently drop the surplus entries.
state_keys = list(iid_model.state_dict().keys())
if len(state_keys) != len(ndarrays):
    raise ValueError(
        f"Parameter count mismatch: model has {len(state_keys)} state entries, "
        f"strategy returned {len(ndarrays)} arrays."
    )
state_dict = OrderedDict(
    (key, torch.tensor(val)) for key, val in zip(state_keys, ndarrays)
)
iid_model.load_state_dict(state_dict)

# Switch to inference mode (disables dropout and similar train-only behaviour).
iid_model.eval()

correct, total, loss_total = 0, 0, 0.0
criterion = torch.nn.CrossEntropyLoss()

with torch.no_grad():
    for images, labels in testloader:
        images, labels = images.to(device), labels.to(device)
        outputs = iid_model(images)
        loss = criterion(outputs, labels)

        # The criterion returns a batch mean; re-weight by batch size so the
        # final average is per-sample even when the last batch is smaller.
        batch_size = labels.size(0)
        loss_total += loss.item() * batch_size
        predicted = outputs.argmax(dim=1)
        total += batch_size
        correct += (predicted == labels).sum().item()

if total == 0:
    raise ValueError("testloader yielded no samples; cannot compute test metrics.")

test_loss = loss_total / total
test_accuracy = correct / total

logger.log_global_metrics({
    "test_loss": test_loss,
    "test_accuracy": test_accuracy
}, round_number=NUM_ROUNDS)

# Save the final weights locally, then hand the file to the logger.
model_save_path = "final_federated_model.pth"
torch.save(iid_model.state_dict(), model_save_path)

logger.log_model(iid_model, path=model_save_path)

logger.finish()

print(f"✅ Test Accuracy: {test_accuracy:.4f} | Test Loss: {test_loss:.4f}")

0,1
client_0/train_accuracy,▁█
client_0/train_loss,█▁
client_1/train_accuracy,▁█
client_1/train_loss,█▁
client_2/train_accuracy,▁█
client_2/train_loss,█▁
client_3/train_accuracy,▁█
client_3/train_loss,█▁
client_4/train_accuracy,▁█
client_4/train_loss,█▁

0,1
client_0/train_accuracy,0.68917
client_0/train_loss,1.48847
client_1/train_accuracy,0.69056
client_1/train_loss,1.51969
client_2/train_accuracy,0.68411
client_2/train_loss,1.5296
client_3/train_accuracy,0.67583
client_3/train_loss,1.69556
client_4/train_accuracy,0.68789
client_4/train_loss,1.5225


✅ Test Accuracy: 0.7486 | Test Loss: 1.1065
✅ Test Accuracy: 0.7486 | Test Loss: 1.1065
