# Federated Learning POC: Chest X-ray with ConvNet2

This notebook demonstrates a federated learning setup using NVFlare to train the `ConvNet2` model (or its LoRA variant, `LoRAConvNet2`) on the Chest X-ray dataset distributed across 3 sites (`site1`, `site2`, `site3`).

## Objectives:
1. **Pre-train** the model on a small data subset to initialize the global model.
2. Compare **Full Fine-tuning** vs **LoRA** (Low-Rank Adaptation).
3. Use **Mixed Precision** (AMP) training to speed up local training on GPUs.
4. Simulate 3 federated sites with `FedAvg`.
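Mixed precision runs most operations in float16, whose smallest positive value is 2**-24 (about 6e-8), so very small gradients can underflow to zero. PyTorch's AMP recipe pairs `torch.autocast` with a `GradScaler` (default initial scale `2**16`) that scales the loss before `backward()` and unscales the gradients before the optimizer step. `client_xray.py` is not shown in this notebook, so whether it uses a `GradScaler` is an assumption; the NumPy sketch below only illustrates why loss scaling matters.

```python
import numpy as np

# Hypothetical gradient value below float16's smallest subnormal (2**-24, about 6e-8).
true_grad = 1e-8

# Without scaling: casting straight to float16 flushes the value to zero.
unscaled_fp16 = np.float16(true_grad)

# With scaling: multiply before the float16 cast, then divide in float32.
# This loosely mirrors what a GradScaler does around backward() and optimizer.step().
scale = 2.0 ** 16
scaled_fp16 = np.float16(true_grad * scale)
recovered = np.float32(scaled_fp16) / np.float32(scale)

print("float16, no scaling:", float(unscaled_fp16))
print("float16 with scaling, recovered in float32:", float(recovered))
```

The scaled value lands in float16's normal range, so only a small rounding error remains after unscaling instead of a total loss of the gradient.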

In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
from torchvision import transforms, datasets
from model import ConvNet2, LoRAConvNet2
import numpy as np

## 1. Pre-training Initialization

Instead of starting from random weights, we pre-train the model on a random subset of 100 images from site 1's training split (`site1/train`). Every site then starts federated training from the same non-random global model, which can make early rounds more stable than a random initialization.
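`get_pretrain_loader` in the cell below draws its subset with an unseeded `np.random.choice`, so each run pre-trains on different images and produces a different initial global model; it also raises an error if the folder holds fewer than `num_samples` images. A minimal sketch of reproducible, size-safe index selection, using a hypothetical helper `pick_pretrain_indices` that is not part of the original code (with NumPy, `np.random.default_rng(seed).choice(n, k, replace=False)` is an equivalent option):

```python
import random


def pick_pretrain_indices(dataset_len, num_samples=100, seed=0):
    """Return a reproducible random subset of dataset indices (hypothetical helper)."""
    # Cap the request so a small site folder does not raise an error.
    k = min(num_samples, dataset_len)
    # A local RNG keeps the global random state untouched.
    rng = random.Random(seed)
    return sorted(rng.sample(range(dataset_len), k))


a = pick_pretrain_indices(5000, num_samples=100, seed=42)
b = pick_pretrain_indices(5000, num_samples=100, seed=42)
print(a == b, len(a), len(set(a)))
```

The returned list could be passed to `Subset(full_dataset, indices)` without other changes to the loader.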

In [2]:
def get_pretrain_loader(data_path, batch_size=16, num_samples=100):
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.2, 0.2, 0.2]),
    ])
    
    # Use site1 training data for pre-training
    site1_train_path = os.path.join(data_path, "site1", "train")
    full_dataset = datasets.ImageFolder(root=site1_train_path, transform=transform)
    
    # Take a small subset
    indices = np.random.choice(len(full_dataset), num_samples, replace=False)
    subset = Subset(full_dataset, indices)
    
    return DataLoader(subset, batch_size=batch_size, shuffle=True)

def pretrain_model(model, loader, epochs=2, device="cpu"):
    model.to(device)
    model.train()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
    
    for epoch in range(epochs):
        running_loss = 0.0
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print(f"Pre-train Epoch {epoch+1} Loss: {running_loss/len(loader):.4f}")
    
    return model

data_path = os.path.abspath("chest_xray")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print("---- Starting Pre-training Initialization ----")
base_model = ConvNet2()
pretrain_loader = get_pretrain_loader(data_path)
initialized_model = pretrain_model(base_model, pretrain_loader, device=device)
print("---- Pre-training Done ----")

---- Starting Pre-training Initialization ----
Pre-train Epoch 1 Loss: 0.6932
Pre-train Epoch 2 Loss: 0.6731
---- Pre-training Done ----
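Assuming `site1/train` has two class subfolders (`ImageFolder` assigns one label per subfolder), the epoch-1 loss of 0.6932 sits at chance level: a classifier that gives each of two classes probability 0.5 has cross-entropy `-ln(0.5) = ln(2)`, about 0.6931. The drop to 0.6731 after epoch 2 is only modestly below that, so this initialization is lightly trained.

```python
import math

# Cross-entropy when the true class gets probability 0.5 (uniform over 2 classes).
chance_loss = -math.log(0.5)
print(f"{chance_loss:.4f}")  # 0.6931
```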


## 2. Define the FedJob Recipe

We can now choose between the **Full** model and the **LoRA** model.

### Choice: LoRA vs Full
`USE_LORA = True` (the setting below) wraps the pre-initialized model in `LoRAConvNet2` with rank 8; set it to `False` for full fine-tuning of `ConvNet2`.
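LoRA keeps a pre-trained weight `W` frozen and learns a low-rank update, so an adapted layer computes `W x + (alpha / r) * B A x`, where `A` has shape `(r, d_in)` and `B` has shape `(d_out, r)`. The internals of `LoRAConvNet2` (which layers it adapts, its scaling, its initialization) are not shown in this notebook; the NumPy sketch below follows the common formulation with `B` initialized to zero, using illustrative sizes (`d_in=512`, `d_out=256`) and `alpha=16` as assumptions, with `rank=8` matching the cell below.

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_out, rank, alpha = 512, 256, 8, 16  # illustrative sizes (assumptions)

# Frozen pre-trained weight, standing in for one layer of the initialized model.
W = rng.standard_normal((d_out, d_in)).astype(np.float32)

# Trainable low-rank factors: A small random, B zero, so the update starts at zero.
A = (rng.standard_normal((rank, d_in)) * 0.01).astype(np.float32)
B = np.zeros((d_out, rank), dtype=np.float32)


def lora_forward(x):
    """x: (batch, d_in) -> (batch, d_out); only A and B would be trained."""
    return x @ W.T + (alpha / rank) * (x @ A.T) @ B.T


x = rng.standard_normal((4, d_in)).astype(np.float32)
same_at_init = np.allclose(lora_forward(x), x @ W.T)

full_params = d_out * d_in          # parameters updated by full fine-tuning
lora_params = rank * (d_in + d_out)  # parameters in A and B
print(same_at_init, full_params, lora_params, f"{lora_params / full_params:.2%}")
```

With `B` at zero, the wrapped layer initially reproduces the pre-trained output exactly, and in this example the trainable adapter is under 5% of the layer's parameters.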

In [3]:
from nvflare.app_opt.pt.recipes.fedavg import FedAvgRecipe
from nvflare.recipe import SimEnv, add_experiment_tracking

USE_LORA = True # Set to False for Full fine-tuning
USE_AMP = True
n_clients = 3
num_rounds = 5
batch_size = 16

if USE_LORA:
    # Pass the pre-initialized model directly into the LoRA wrapper.
    # This avoids the load_state_dict key mismatch error.
    model_to_run = LoRAConvNet2(rank=8, base_model=initialized_model)
    
    model_type = "lora"
    job_name = "chest-xray-lora"
else:
    model_to_run = initialized_model
    model_type = "full"
    job_name = "chest-xray-full"

recipe = FedAvgRecipe(
    name=job_name,
    min_clients=n_clients,
    num_rounds=num_rounds,
    model=model_to_run,
    train_script="client_xray.py",
    train_args=f"--batch_size {batch_size} --epochs 1 --data_path {data_path} --model_type {model_type} --use_amp {USE_AMP}",
)

add_experiment_tracking(recipe, tracking_type="tensorboard")

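The f-string in `train_args` renders `USE_AMP` as the literal text `True` or `False`. How `client_xray.py` parses `--use_amp` is not shown; if it used `argparse` with `type=bool`, then `--use_amp False` would still enable AMP, because `bool()` of any non-empty string is `True`. A minimal sketch of an explicit parser, using a hypothetical `str2bool` helper:

```python
import argparse


def str2bool(value):
    """Parse 'True'/'False'-style strings explicitly (hypothetical helper)."""
    lowered = str(value).strip().lower()
    if lowered in ("1", "true", "yes", "y"):
        return True
    if lowered in ("0", "false", "no", "n"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


parser = argparse.ArgumentParser()
parser.add_argument("--use_amp", type=str2bool, default=False)
parser.add_argument("--batch_size", type=int, default=16)

naive = bool("False")  # the pitfall that type=bool would hit
args = parser.parse_args(["--use_amp", "False", "--batch_size", "16"])
print(naive, args.use_amp, args.batch_size)
```

An alternative is `action="store_true"` on the client side, appending `--use_amp` to `train_args` only when `USE_AMP` is `True`.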


## 3. Run the Job in Simulation

We will now execute the federated learning job across 3 simulated clients.
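In each round, the server sends the current global model to the clients, each client trains locally (one epoch here, per `--epochs 1`), and the server combines the returned weights; lines such as `Aggregated 1/3 results` in the log below appear as client results arrive. The pure-Python sketch below is a toy illustration of FedAvg, not NVFlare's implementation: one scalar parameter, hypothetical quadratic local losses, and sample-size weights. The weighting NVFlare actually applies depends on what each client reports, which this notebook does not show.

```python
def fedavg(client_weights, client_sizes):
    """Sample-size-weighted average of per-client parameter dicts."""
    total = sum(client_sizes)
    return {
        k: sum(w[k] * n for w, n in zip(client_weights, client_sizes)) / total
        for k in client_weights[0]
    }


def local_train(w, target, lr=0.1, steps=5):
    """Toy local SGD on f(w) = (w - target)**2 for a single scalar parameter."""
    for _ in range(steps):
        w = w - lr * 2 * (w - target)
    return w


# Hypothetical sites: each has its own local optimum and dataset size.
targets = [1.0, 2.0, 4.0]
sizes = [100, 100, 200]

global_model = {"w": 0.0}
for rnd in range(20):
    # Every client starts from the current global model and trains locally.
    updates = [{"w": local_train(global_model["w"], t)} for t in targets]
    global_model = fedavg(updates, sizes)

weighted_opt = sum(t * n for t, n in zip(targets, sizes)) / sum(sizes)
print(round(global_model["w"], 4), weighted_opt)
```

In this toy setup, every local update pulls toward that client's own optimum, so the global parameter settles at the size-weighted mean of the optima (2.75) rather than at any single site's value.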

In [4]:
env = SimEnv(num_clients=n_clients)
run = recipe.execute(env)

print()
print("Job Status is:", run.get_status())
print("Result can be found in :", run.get_result())
print()



2026-02-11 16:59:29,198 - INFO - model selection weights control: {}




2026-02-11 16:59:38,682 - INFO - Tensorboard records can be found in /tmp/nvflare/simulation/chest-xray-lora/server/simulate_job/tb_events you can view it using `tensorboard --logdir=/tmp/nvflare/simulation/chest-xray-lora/server/simulate_job/tb_events`
2026-02-11 16:59:38,684 - INFO - Initializing BaseModelController workflow.
2026-02-11 16:59:38,684 - INFO - Beginning model controller run.
2026-02-11 16:59:38,685 - INFO - 
                                 Start FedAvg.

2026-02-11 16:59:38,685 - INFO - loading initial model from persistor
2026-02-11 16:59:38,686 - INFO - Both source_ckpt_file_full_name and ckpt_preload_path are not provided. Using the default model weights initialized on the persistor side.
2026-02-11 16:59:38,688 - INFO - 
--------------------------------------------------------------------------------
                                Round 0 started.



2026-02-11 16:59:44,970 - INFO - start task run() with full path: /tmp/nvflare/simulation/chest-xray-lora/site-3/simulate_job/app_site-3/custom/client_xray.py
2026-02-11 16:59:44,978 - INFO - start task run() with full path: /tmp/nvflare/simulation/chest-xray-lora/site-1/simulate_job/app_site-1/custom/client_xray.py
2026-02-11 16:59:44,981 - INFO - start task run() with full path: /tmp/nvflare/simulation/chest-xray-lora/site-2/simulate_job/app_site-2/custom/client_xray.py
2026-02-11 16:59:45,069 - INFO - set transaction info: tx_id='T1f366e97-1c36-401a-a1b5-d2fb2e826885', ref_id='af3b9108-fd37-4e68-8850-fc80299ea746' self.num_receivers=1
2026-02-11 16:59:45,075 - INFO - set transaction info: tx_id='Te31798be-00db-4523-ad13-113cb53dd8c5', ref_id='e36625c8-5a34-479a-a832-e034d63c22e5' self.num_receivers=1
2026-02-11 16:59:45,085 - INFO - set transaction info: tx_id='Teedceb2d-7fb2-4ed4-a843-117e4d1a53bf', ref_id='f36e9731-2520-4dbe-8314-4



2026-02-11 17:00:16,980 - INFO - Accuracy of the network: 50.00 %
2026-02-11 17:00:17,073 - INFO - Accuracy of the network: 50.00 %
2026-02-11 17:00:17,074 - INFO - Accuracy of the network: 60.87 %




2026-02-11 17:01:07,210 - INFO - [1,     4] loss: 0.679
2026-02-11 17:01:17,365 - INFO - Finished Training for site-2
2026-02-11 17:01:17,388 - INFO - site: site-2, sending model to server.
2026-02-11 17:01:17,504 - INFO - set transaction info: tx_id='Tac8a1346-b2e3-4380-b923-e3630ee1cb15', ref_id='ce0e0b1c-9843-4cc9-965d-760dae0172dc' self.num_receivers=1
2026-02-11 17:01:18,316 - INFO - object has been downloaded to all 1 receivers - clear cache
2026-02-11 17:01:18,388 - INFO - Aggregated 1/3 results
2026-02-11 17:01:23,660 - INFO - [1,     7] loss: 0.694
2026-02-11 17:01:24,416 - INFO - [1,     7] loss: 0.690
2026-02-11 17:01:33,726 - INFO - Finished Training for site-3
2026-02-11 17:01:33,743 - INFO - site: site-3, sending model to server.
2026-02-11 17:01:34,024 - INFO - set transaction info: tx_id='T2627a22e-bddb-4fdb-b52d-49aa3e6854ec', ref_id='b5fdf4a3-7e74-481a-989b-b07178bba761' sel



2026-02-11 17:02:07,389 - INFO - Accuracy of the network: 50.00 %
2026-02-11 17:02:07,412 - INFO - Accuracy of the network: 60.87 %
2026-02-11 17:02:07,420 - INFO - Accuracy of the network: 50.00 %




2026-02-11 17:03:23,003 - INFO - [1,     4] loss: 0.660
2026-02-11 17:03:33,098 - INFO - Finished Training for site-2
2026-02-11 17:03:33,118 - INFO - site: site-2, sending model to server.
2026-02-11 17:03:33,516 - INFO - set transaction info: tx_id='Tdad404d9-e284-42a9-847b-c0a369d5c69f', ref_id='3be07476-990b-4e17-af1b-bd2c0ceeb076' self.num_receivers=1
2026-02-11 17:03:34,310 - INFO - object has been downloaded to all 1 receivers - clear cache
2026-02-11 17:03:34,364 - INFO - validation metric 50.0 from client site-2
2026-02-11 17:03:34,497 - INFO - Aggregated 1/3 results
2026-02-11 17:03:43,346 - INFO - [1,     7] loss: 0.685
2026-02-11 17:03:45,217 - INFO - [1,     7] loss: 0.687
2026-02-11 17:03:53,409 - INFO - Finished Training for site-3
2026-02-11 17:03:53,414 - INFO - site: site-3, sending model to server.
2026-02-11 17:03:53,544 - INFO - set transaction info: tx_id='T22ee



2026-02-11 17:04:35,550 - INFO - Accuracy of the network: 60.87 %
2026-02-11 17:04:36,089 - INFO - Accuracy of the network: 50.00 %
2026-02-11 17:04:36,390 - INFO - Accuracy of the network: 50.00 %




2026-02-11 17:06:10,983 - INFO - [1,     4] loss: 0.640
2026-02-11 17:06:21,079 - INFO - Finished Training for site-2
2026-02-11 17:06:21,096 - INFO - site: site-2, sending model to server.
2026-02-11 17:06:21,619 - INFO - set transaction info: tx_id='T3bf872db-365e-4d87-941e-d96f0dff5a9a', ref_id='b884b192-17c9-4950-8904-1251fd8476b8' self.num_receivers=1
2026-02-11 17:06:22,692 - INFO - object has been downloaded to all 1 receivers - clear cache
2026-02-11 17:06:22,730 - INFO - validation metric 50.0 from client site-2
2026-02-11 17:06:22,857 - INFO - Aggregated 1/3 results
2026-02-11 17:06:32,401 - INFO - [1,     7] loss: 0.667
2026-02-11 17:06:33,278 - INFO - [1,     7] loss: 0.675
2026-02-11 17:06:42,447 - INFO - Finished Training for site-3
2026-02-11 17:06:42,451 - INFO - site: site-3, sending model to server.
2026-02-11 17:06:42,640 - INFO - set transaction info: tx_id='Tb8f2



2026-02-11 17:07:30,707 - INFO - Accuracy of the network: 60.87 %
2026-02-11 17:07:32,433 - INFO - Accuracy of the network: 50.00 %
2026-02-11 17:07:32,963 - INFO - Accuracy of the network: 50.00 %




2026-02-11 17:08:44,586 - INFO - [1,     4] loss: 0.629
2026-02-11 17:08:54,662 - INFO - Finished Training for site-2
2026-02-11 17:08:54,676 - INFO - site: site-2, sending model to server.
2026-02-11 17:08:55,117 - INFO - set transaction info: tx_id='Tb0c87ffe-c517-496f-a36b-a6ca2c6e855b', ref_id='ff1fec7b-06ed-494a-bb9b-79082c333e3d' self.num_receivers=1
2026-02-11 17:08:55,909 - INFO - object has been downloaded to all 1 receivers - clear cache
2026-02-11 17:08:55,951 - INFO - validation metric 50.0 from client site-2
2026-02-11 17:08:56,084 - INFO - Aggregated 1/3 results
2026-02-11 17:08:58,457 - INFO - [1,     7] loss: 0.655
2026-02-11 17:09:00,630 - INFO - [1,     7] loss: 0.674
2026-02-11 17:09:08,506 - INFO - Finished Training for site-3
2026-02-11 17:09:08,511 - INFO - site: site-3, sending model to server.
2026-02-11 17:09:08,948 - INFO - set transaction info: tx_id='Td947



2026-02-11 17:09:46,039 - INFO - Accuracy of the network: 57.50 %
2026-02-11 17:09:46,252 - INFO - Accuracy of the network: 50.00 %
2026-02-11 17:09:46,445 - INFO - Accuracy of the network: 60.87 %




2026-02-11 17:10:59,209 - INFO - [1,     4] loss: 0.615
2026-02-11 17:11:09,286 - INFO - Finished Training for site-2
2026-02-11 17:11:09,299 - INFO - site: site-2, sending model to server.
2026-02-11 17:11:09,540 - INFO - set transaction info: tx_id='T63415912-c2a5-4246-b2d0-7cfe126ebea5', ref_id='c2ecfac4-9d1b-4145-883e-a8dfe7813d6b' self.num_receivers=1
2026-02-11 17:11:10,612 - INFO - object has been downloaded to all 1 receivers - clear cache
2026-02-11 17:11:10,663 - INFO - validation metric 50.0 from client site-2
2026-02-11 17:11:10,908 - INFO - Aggregated 1/3 results

Note, get_status returns None in SimEnv. The simulation logs can be found at /tmp/nvflare/simulation/chest-xray-lora
Job Status is: None
Note, get_status returns None in SimEnv. The simulation logs can be found at /tmp/nvflare/simulation/chest-xray-lora
Result can be found in : /tmp/nvflare/simulation/chest-xray-lora



## 4. Visualize the results

Launch TensorBoard to see the training progress.

In [None]:
%load_ext tensorboard
%tensorboard --logdir /tmp/nvflare/simulation/chest-xray-lora/server/simulate_job/tb_events --bind_all

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


Reusing TensorBoard on port 6006 (pid 90743), started 0:02:04 ago. (Use '!kill 90743' to kill it.)