# Federated Learning POC: Chest X-ray with ConvNet2

This notebook demonstrates a federated learning setup using NVFlare to train the `ConvNet2` model (or its LoRA variant, `LoRAConvNet2`) on the Chest X-ray dataset distributed across 3 sites (`site1`, `site2`, `site3`).

## Objectives:
1. **Pre-train** the model on a small data subset to initialize the global model.
2. Compare **Full Fine-tuning** vs **LoRA** (Low-Rank Adaptation).
3. Use **Mixed Precision** (AMP) training on the clients for faster execution.
4. Simulate 3 federated sites with `FedAvg`.

In [None]:
import os
import time

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
from torchvision import transforms, datasets

from model import ConvNet2, LoRAConvNet2


## 1. Pre-training Initialization

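Instead of starting from random weights, we pre-train the whole `ConvNet2` model for a couple of epochs on a small random subset (100 images by default) of site 1's training data. Starting every site from the same pre-trained weights gives a more stable initialization, which tends to help convergence in federated settings.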

In [None]:
def get_pretrain_loader(data_path, batch_size=16, num_samples=100, seed=None):
    """Build a DataLoader over a small random subset of site1's training images.

    The subset is only used to warm-start the global model before federated
    training begins.

    Args:
        data_path: Root directory containing ``site1/train`` in the
            ``torchvision.datasets.ImageFolder`` layout.
        batch_size: Mini-batch size of the returned loader.
        num_samples: Number of images to sample. Clamped to the size of the
            site1 training set so a small dataset does not raise.
        seed: Optional seed for the subset selection. ``None`` keeps the
            original behaviour of drawing from NumPy's global RNG.

    Returns:
        A shuffling ``DataLoader`` over the sampled ``Subset``.
    """
    # NOTE(review): these normalization constants presumably mirror the
    # client-side transform in client_xray.py -- confirm they stay in sync.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.2, 0.2, 0.2]),
    ])

    # Use site1 training data for pre-training
    site1_train_path = os.path.join(data_path, "site1", "train")
    full_dataset = datasets.ImageFolder(root=site1_train_path, transform=transform)

    # Sampling without replacement raises ValueError when asked for more items
    # than exist, so never request more than the dataset actually holds.
    n_take = min(num_samples, len(full_dataset))
    rng = np.random if seed is None else np.random.default_rng(seed)
    indices = rng.choice(len(full_dataset), n_take, replace=False)
    subset = Subset(full_dataset, indices)

    return DataLoader(subset, batch_size=batch_size, shuffle=True)

def pretrain_model(model, loader, epochs=2, device="cpu"):
    """Warm-start ``model`` with a few epochs of centralized SGD on ``loader``.

    Args:
        model: Classifier whose outputs are fed to ``nn.CrossEntropyLoss``.
        loader: Iterable of ``(images, labels)`` batches (e.g. a DataLoader).
            It does not need to support ``len()``.
        epochs: Number of passes over ``loader``.
        device: Device (string or ``torch.device``) to train on. The model is
            moved there in place and left there.

    Returns:
        The same ``model`` object, trained in place.
    """
    model.to(device)
    model.train()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

    for epoch in range(epochs):
        running_loss = 0.0
        n_batches = 0
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            n_batches += 1
        # Count batches while iterating: an empty loader then reports NaN
        # instead of raising ZeroDivisionError, and loaders without __len__
        # are supported.
        avg_loss = running_loss / n_batches if n_batches else float("nan")
        print(f"Pre-train Epoch {epoch+1} Loss: {avg_loss:.4f}")

    return model

# Shared settings reused by later cells: data_path, device, initialized_model.
data_path = os.path.abspath("chest_xray")
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

print("---- Starting Pre-training Initialization ----")
# Fresh model and a small site1 subset; pre-training updates base_model in place.
base_model = ConvNet2()
pretrain_loader = get_pretrain_loader(data_path)
initialized_model = pretrain_model(
    model=base_model,
    loader=pretrain_loader,
    device=device,
)
print("---- Pre-training Done ----")

## 2. Define the FedJob Recipe

We can now choose between the **Full** model or the **LoRA** model. 

### Choice: LoRA vs Full
`USE_LORA` is `True` by default, which trains the Low-Rank Adaptation model; set it to `False` for full fine-tuning.

In [None]:
import copy

from nvflare.app_opt.pt.recipes.fedavg import FedAvgRecipe
from nvflare.recipe import SimEnv, add_experiment_tracking

USE_LORA = True  # Set to False for Full fine-tuning
USE_AMP = True
n_clients = 3
num_rounds = 5
batch_size = 16

# Give this job a private copy of the pre-trained weights. The comparison run
# later deep-copies `initialized_model` again as its "full" baseline, so
# anything the LoRA wrapper might do to its base model in place (swapping or
# freezing layers -- not visible from this notebook, check model.py) must not
# leak into that baseline.
if USE_LORA:
    # Wrap the pre-trained weights into LoRA model
    model_to_run = LoRAConvNet2(rank=8, base_model=copy.deepcopy(initialized_model))
    model_type = "lora"
    job_name = "chest-xray-lora"
else:
    model_to_run = copy.deepcopy(initialized_model)
    model_type = "full"
    job_name = "chest-xray-full"

# NOTE(review): USE_AMP is rendered as the literal text "True"/"False". Make sure
# client_xray.py parses it explicitly -- argparse `type=bool` would treat
# "False" as truthy. The client's argument parser is not visible here.
recipe = FedAvgRecipe(
    name=job_name,
    min_clients=n_clients,
    num_rounds=num_rounds,
    model=model_to_run,
    train_script="client_xray.py",
    train_args=f"--batch_size {batch_size} --epochs 1 --data_path {data_path} --model_type {model_type} --use_amp {USE_AMP}",
)

add_experiment_tracking(recipe, tracking_type="tensorboard")

## 3. Run the Job in Simulation

We will now execute the federated learning job across 3 simulated clients.

In [None]:
import time

env = SimEnv(num_clients=n_clients)
print(f"---- Starting {job_name} Experiment ----")
# perf_counter is monotonic, so the measured duration cannot be skewed by
# wall-clock adjustments during a long simulation.
# Note: the "_lora" names are kept for the summary cell, but they time
# whichever job USE_LORA selected.
start_time_lora = time.perf_counter()
run = recipe.execute(env)
end_time_lora = time.perf_counter()
duration_lora = end_time_lora - start_time_lora

print()
print("Job Status is:", run.get_status())
print(f"Experiment duration: {duration_lora:.2f} seconds")
print("Result can be found in :", run.get_result())
print()


## 4. Visualize the results

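Launch TensorBoard to see the training progress. The TensorBoard cell comes after the comparison run in Section 5, so the logs of both jobs are available.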

## 5. Comparison: Normal Federated Learning (Full Fine-tuning)

To understand the impact of LoRA, we'll also run a standard Federated Learning job where all parameters are updated.

In [None]:
import copy
import time

print("---- Starting Full Federated Learning Comparison ----")
# Start the baseline from an untouched copy of the pre-trained weights.
# NOTE: if USE_LORA is False, this repeats the same full fine-tuning setup as
# the first job (under a different job name).
full_model_init = copy.deepcopy(initialized_model)
job_name_full = "chest-xray-full-comparison"

recipe_full = FedAvgRecipe(
    name=job_name_full,
    min_clients=n_clients,
    num_rounds=num_rounds,
    model=full_model_init,
    train_script="client_xray.py",
    train_args=f"--batch_size {batch_size} --epochs 1 --data_path {data_path} --model_type full --use_amp {USE_AMP}",
)

add_experiment_tracking(recipe_full, tracking_type="tensorboard")

# Monotonic clock, matching the timing of the first job.
start_time_full = time.perf_counter()
run_full = recipe_full.execute(env)
end_time_full = time.perf_counter()
duration_full = end_time_full - start_time_full

print("Full FL Job Status:", run_full.get_status())
print(f"Experiment duration: {duration_full:.2f} seconds")
print("Full FL Result:", run_full.get_result())


In [None]:
# Load the TensorBoard notebook extension and serve the simulator's output tree.
# NOTE(review): /tmp/nvflare/simulation is assumed to be where SimEnv writes the
# job workspaces (including the event files from add_experiment_tracking);
# confirm against the NVFlare version in use.
%load_ext tensorboard
%tensorboard --bind_all --logdir /tmp/nvflare/simulation


In [None]:
# Side-by-side summary of the two simulated runs.
print("\n---- FINAL COMPARISON SUMMARY ----")

# The first run used whichever model USE_LORA selected, so label it from that
# flag rather than always calling it LoRA.
if USE_LORA:
    first_experiment = "LoRA Federated Learning"
    first_method = "LoRA (Low-Rank Adaptation)"
    first_comm = "Low (~1% params)"
else:
    first_experiment = "Full Federated Learning"
    first_method = "Full Parameter Tuning"
    first_comm = "High (100% params)"

# NOTE: the communication column is a rough estimate, not measured here.
data = {
    "Experiment": [first_experiment, "Full Federated Learning"],
    "Method": [first_method, "Full Parameter Tuning"],
    "Duration (sec)": [round(duration_lora, 2), round(duration_full, 2)],
    "Est. Communication": [first_comm, "High (100% params)"],
}
df = pd.DataFrame(data)
display(df)
