* ISIC Challenge datasets

1. Import the libs

In [14]:
import os
import signal
import socket
import threading
import time
from collections import defaultdict

import flwr as fl
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset
from torch.utils.data import DataLoader, TensorDataset

2. Statistics of the datasets

All images were preprocessed with `data_processing/preprocess.py` and resized to a uniform 256 x 256 pixels, which also helps reduce domain shift between images.

ISIC 2020 - The dataset we will use for federated learning

The dataset has only 2 labels for the images: malignant and benign.

In [2]:
# Load the image-level ground truth (one row per image) and summarise it.
df = pd.read_csv("labels/ISIC_2020_Training_GroundTruth.csv")

label_col = df['benign_malignant']
num_images, num_patients = df.shape[0], df['patient_id'].nunique()
num_classes, class_counts = label_col.nunique(), label_col.value_counts()

# Headline counts and label balance.
print(f"Total images: {num_images}")
print(f"Total patients: {num_patients}")
print(f"Classes: {num_classes}")
print("Class distribution:\n", class_counts)

# Metadata distributions used to design the hospital split below.
print("\nAge statistics:")
print(df['age_approx'].describe())
print("\nSex distribution:\n", df['sex'].value_counts())
print("\nAnatomical site distribution:\n", df['anatom_site_general_challenge'].value_counts())

Total images: 33126
Total patients: 2056
Classes: 2
Class distribution:
 benign_malignant
benign       32542
malignant      584
Name: count, dtype: int64

Age statistics:
count    33058.000000
mean        48.870016
std         14.380360
min          0.000000
25%         40.000000
50%         50.000000
75%         60.000000
max         90.000000
Name: age_approx, dtype: float64

Sex distribution:
 sex
male      17080
female    15981
Name: count, dtype: int64

Anatomical site distribution:
 anatom_site_general_challenge
torso              16845
lower extremity     8417
upper extremity     4983
head/neck           1855
palms/soles          375
oral/genital         124
Name: count, dtype: int64


3. Split the data for Federated Learning

We simulate a realistic setting in which hospitals serve different patient populations (different ages, social backgrounds, etc.). The data is split across 4 clients (hospitals), each receiving a different split. All images of a given patient stay in the same client, since people tend to return to the hospital they are used to.

Each hospital has a different "story" to try to mimic the real world:

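1) Hospital A - Big-city hospital with the largest share of patients (45% of all patients). Patients with missing metadata are meant to be assigned here as well, since record-keeping mistakes are more likely at large, busy hospitals. (Note: the current split code does not route these patients explicitly. A patient with a missing age cannot enter Hospital B, and a patient with no recorded site cannot enter Hospital C; otherwise they are assigned like everyone else.)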
2) Hospital B - Hospital focused on younger patients (patients younger than 40, capped at 20% of all patients).
3) Hospital C - Hospital focused on the upper body. It receives the remaining patients whose lesions are most often on the head/neck, the oral/genital area or an upper extremity (capped at 15% of all patients).
4) Hospital D - Hospital with all remaining patients; it can be thought of as a regular dermatological hospital in a city.

In [4]:
def split_patients_federated(csv_path, seed=42, output_path=None):
    """Assign every patient to one of four simulated hospitals (FL clients).

    Patients, not images, are the unit of assignment, so all images of one
    patient end up in the same hospital. Assignment happens in this order:

    1. Hospital B (Youth): patients whose mean ``age_approx`` is < 40,
       randomly capped at 20% of all patients.
    2. Hospital C (Upper body): patients not in B whose most frequent
       ``anatom_site_general_challenge`` is head/neck, oral/genital or upper
       extremity, randomly capped at 15% of all patients.
    3. Hospital A (Urban): a random 45% of all patients, drawn from those left.
    4. Hospital D: everyone else.

    Patients with a missing age are never placed in B (NaN < 40 is False), and
    patients with no recorded site get the site "unknown" and are never placed
    in C. Otherwise they are assigned like any other patient.

    Side effect: writes the image-level table with an added ``hospital_id``
    column (1=A, 2=B, 3=C, 4=D) to ``output_path``. By default this overwrites
    ``csv_path`` itself.

    Args:
        csv_path: image-level label CSV with at least ``patient_id``,
            ``age_approx``, ``anatom_site_general_challenge`` and
            ``benign_malignant`` columns.
        seed: ``random_state`` for all patient sampling. Also seeds NumPy's
            global RNG.
        output_path: destination of the annotated CSV; ``None`` means
            ``csv_path``.

    Returns:
        dict mapping hospital name -> image-level DataFrame of its patients.
    """
    np.random.seed(seed)
    df = pd.read_csv(csv_path)

    def _dominant_site(sites):
        # Most frequent site of a patient; pandas returns modes in sorted order,
        # so ties resolve to the first one. All-NaN -> "unknown".
        modes = sites.mode()
        return modes.iloc[0] if not modes.empty else "unknown"

    # One row per patient: mean age, dominant lesion site, and "malignant" if
    # any of the patient's images is malignant.
    patients = df.groupby("patient_id").agg({
        "age_approx": "mean",
        "anatom_site_general_challenge": _dominant_site,
        "benign_malignant": lambda x: "malignant" if "malignant" in x.values else "benign",
    }).reset_index()

    num_patients = len(patients)
    hospital_a_size = int(0.45 * num_patients)
    hospital_b_size = int(0.20 * num_patients)
    hospital_c_size = int(0.15 * num_patients)

    # Hospital B -> young patients (<40).
    young = patients[patients["age_approx"] < 40]
    if len(young) > hospital_b_size:
        young = young.sample(hospital_b_size, random_state=seed)
    hospital_b_patients = set(young["patient_id"])

    # Hospital C -> upper-body focus, drawn from patients not already in B.
    available_patients = patients[~patients["patient_id"].isin(hospital_b_patients)]
    head_sites = ["head/neck", "oral/genital", "upper extremity"]
    upper_body_patients = available_patients[
        available_patients["anatom_site_general_challenge"].isin(head_sites)
    ]
    if len(upper_body_patients) > hospital_c_size:
        upper_body_patients = upper_body_patients.sample(hospital_c_size, random_state=seed)
    hospital_c_patients = set(upper_body_patients["patient_id"])

    # Hospital A -> big-city hospital: a random share of whoever is left.
    available_patients = patients[
        ~patients["patient_id"].isin(hospital_b_patients | hospital_c_patients)
    ]
    # B and C take at most 35% of patients, so this min() is only a safety net.
    n_hospital_a = min(hospital_a_size, len(available_patients))
    hospital_a_patients = set(
        available_patients.sample(n_hospital_a, random_state=seed)["patient_id"]
    )

    # Hospital D -> the rest.
    hospital_d_patients = set(patients["patient_id"]) - (
        hospital_a_patients | hospital_b_patients | hospital_c_patients
    )

    hospitals = {
        "Hospital_A_Urban": df[df["patient_id"].isin(hospital_a_patients)],
        "Hospital_B_Youth": df[df["patient_id"].isin(hospital_b_patients)],
        "Hospital_C_Upper_Body": df[df["patient_id"].isin(hospital_c_patients)],
        "Hospital_D_Rural": df[df["patient_id"].isin(hospital_d_patients)],
    }

    df["hospital_id"] = 0
    df.loc[df["patient_id"].isin(hospital_a_patients), "hospital_id"] = 1
    df.loc[df["patient_id"].isin(hospital_b_patients), "hospital_id"] = 2
    df.loc[df["patient_id"].isin(hospital_c_patients), "hospital_id"] = 3
    df.loc[df["patient_id"].isin(hospital_d_patients), "hospital_id"] = 4
    df.to_csv(csv_path if output_path is None else output_path, index=False)

    # --- Print summaries ---
    print("\n=== Hospital Summary ===")
    for name, hdf in hospitals.items():
        counts = hdf["benign_malignant"].value_counts()
        mean_age = hdf["age_approx"].mean()
        site_modes = hdf["anatom_site_general_challenge"].mode()
        # An empty hospital (or one with no recorded sites) has no mode.
        top_site = site_modes.iloc[0] if not site_modes.empty else "n/a"
        print(f"\n{name}:")
        print(f" Patients: {hdf['patient_id'].nunique():5d} | Images: {len(hdf):5d}")
        print(f" Avg. age: {mean_age:.1f} | Dominant body part: {top_site}")
        print(f" Benign: {counts.get('benign',0):5d} | Malignant: {counts.get('malignant',0):4d}")

    return hospitals

In [5]:
# Assign each patient to a simulated hospital. Besides returning the
# per-hospital frames, split_patients_federated writes a `hospital_id` column
# back to the labels CSV, which the data-loading cells below read.
seed = 42
csv_path = "labels/ISIC_2020_Training_GroundTruth.csv"
client_splits = split_patients_federated(csv_path=csv_path, seed=seed)


=== Hospital Summary ===

Hospital_A_Urban:
 Patients:   925 | Images: 14505
 Avg. age: 54.8 | Dominant body part: torso
 Benign: 14212 | Malignant:  293

Hospital_B_Youth:
 Patients:   411 | Images:  8141
 Avg. age: 31.1 | Dominant body part: torso
 Benign:  8074 | Malignant:   67

Hospital_C_Upper_Body:
 Patients:   163 | Images:  1429
 Avg. age: 58.2 | Dominant body part: upper extremity
 Benign:  1367 | Malignant:   62

Hospital_D_Rural:
 Patients:   557 | Images:  9051
 Avg. age: 54.0 | Dominant body part: torso
 Benign:  8889 | Malignant:  162


4. Write the model to train

In [6]:
class SmallCNN(nn.Module):
    """Compact 4-stage CNN classifier.

    Each stage is Conv3x3(padding=1) -> BatchNorm -> ReLU -> MaxPool2d(2), so
    the spatial size shrinks by 16 overall. The classifier flattens the final
    256-channel feature map, so the model only accepts square inputs of side
    ``input_size``.

    Args:
        num_classes: number of output logits (default 2: benign / malignant).
        in_channels: number of input image channels (default 3, RGB).
        input_size: height == width of the input images (default 256, the size
            of the preprocessed ISIC images). Must be >= 16.
    """

    def __init__(self, num_classes=2, in_channels=3, input_size=256):
        super(SmallCNN, self).__init__()
        if input_size < 16:
            raise ValueError(f"input_size must be >= 16 (four 2x poolings), got {input_size}")
        # MaxPool2d(2) floors at every stage; four flooring halvings equal
        # input_size // 16.
        feat_side = input_size // 16

        self.features = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=3, padding=1),  # C x S x S
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),  # 32 x S/2 x S/2

            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),  # 64 x S/4 x S/4

            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),  # 128 x S/8 x S/8

            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),  # 256 x S/16 x S/16
        )

        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(256 * feat_side * feat_side, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        """Return raw logits of shape (batch, num_classes)."""
        x = self.features(x)
        x = self.classifier(x)
        return x

5. Create the Federated Learning set-up

In [7]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)

Using device: cuda


In [8]:
class SkinLesionDataset(Dataset):
    """Map-style dataset of lesion images with binary labels.

    Each row of ``df`` must provide ``image_name`` (file stem; ``.jpg`` is
    appended) and ``benign_malignant``. Items are ``(image, label)``, where:

    - ``image`` is a float32 tensor in channel-first layout, scaled to [0, 1];
    - ``label`` is 0 for "benign" and 1 for any other value.

    Args:
        df: per-image metadata. Its index is reset so positional lookup works.
        dir_path: directory containing the image files.
        transform: optional callable applied to the image tensor (e.g.
            augmentation); ``None`` leaves the tensor unchanged.
    """

    def __init__(self, df, dir_path='../train/', transform=None):
        self.df = df.reset_index(drop=True)
        self.dir_path = dir_path
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        # os.path.join works whether or not dir_path ends with a separator.
        path = os.path.join(self.dir_path, row['image_name'] + ".jpg")
        # The context manager closes the file handle. convert() returns a fully
        # loaded copy, so the image data stays valid after the file is closed.
        with Image.open(path) as raw:
            rgb = raw.convert("RGB")
        pixels = np.array(rgb) / 255.0
        img = torch.tensor(pixels.transpose(2, 0, 1), dtype=torch.float32)
        if self.transform is not None:
            img = self.transform(img)
        label = 0 if row['benign_malignant'] == 'benign' else 1
        return img, label

In [9]:
dir_path = '../train/'

def load_client_dataset(csv_path, hospital_id, dir_path='../train/', test_size=0.2, seed=42):
    """Build the local train/validation datasets for one hospital.

    Reads the labels CSV, keeps the rows whose ``hospital_id`` matches, and
    splits them into ``1 - test_size`` / ``test_size`` fractions of images.

    Args:
        csv_path: labels CSV with ``hospital_id`` and ``benign_malignant``
            columns.
        hospital_id: hospital whose rows are loaded.
        dir_path: image directory passed to ``SkinLesionDataset``.
        test_size: fraction of the hospital's images used for validation.
        seed: ``random_state`` for the split.

    Returns:
        ``(train_dataset, val_dataset)``, or ``(None, None)`` if the hospital
        has no rows.
    """
    df = pd.read_csv(csv_path)
    client_df = df[df["hospital_id"] == hospital_id]
    if client_df.empty:
        return None, None

    labels = client_df['benign_malignant']
    # Stratification needs at least 2 images per class; fall back to a plain
    # random split for hospitals too small to satisfy that.
    stratify = labels if labels.value_counts().min() >= 2 else None
    # NOTE(review): this split is per image, so one patient's images can land
    # in both train and val; consider a patient-grouped split.
    train_df, val_df = train_test_split(
        client_df, test_size=test_size, random_state=seed, stratify=stratify
    )

    train_dataset = SkinLesionDataset(train_df, dir_path)
    val_dataset = SkinLesionDataset(val_df, dir_path)
    return train_dataset, val_dataset

In [10]:
labels_csv = "labels/ISIC_2020_Training_GroundTruth.csv"
df = pd.read_csv(labels_csv)
# Sort so the client order is deterministic (ascending hospital_id) instead of
# depending on which hospital appears first in the CSV.
hospital_ids = np.sort(df["hospital_id"].unique())

clients_train_data = []
clients_test_data = []
client_hospital_ids = []  # hospital_id backing each entry of the two lists above

for hid in hospital_ids:
    train_ds, val_ds = load_client_dataset(labels_csv, hospital_id=hid)
    if train_ds is None:
        continue
    clients_train_data.append(train_ds)
    clients_test_data.append(val_ds)
    client_hospital_ids.append(hid)

In [11]:
def train_fn(model, trainloader, epochs, lr, device):
    """Train ``model`` in place with Adam and cross-entropy.

    A fresh Adam optimizer is created on every call, so optimizer state does
    not carry over between calls (e.g. between federated rounds).

    Args:
        model: torch module, expected to already be on ``device``.
        trainloader: iterable of ``(inputs, integer_labels)`` batches.
        epochs: number of passes over ``trainloader``.
        lr: Adam learning rate.
        device: device the batches are moved to.

    Returns:
        float: loss of the last processed mini-batch, or NaN when no batch was
        processed (``epochs <= 0`` or an empty ``trainloader``).
    """
    model.train()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    loss = None
    for _ in range(epochs):
        for X_batch, y_batch in trainloader:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            optimizer.zero_grad()
            outputs = model(X_batch)
            loss = criterion(outputs, y_batch)
            loss.backward()
            optimizer.step()
    # .item() only once at the end to avoid a device sync on every batch.
    return float("nan") if loss is None else loss.item()

def test_fn(model, testloader, device):
    model.eval()
    criterion = nn.CrossEntropyLoss()
    correct, total, loss_total = 0, 0, 0
    with torch.no_grad():
        for X_batch, y_batch in testloader:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            outputs = model(X_batch)
            loss = criterion(outputs, y_batch)
            loss_total += loss.item() * len(y_batch)
            preds = outputs.argmax(dim=1)
            correct += (preds == y_batch).sum().item()
            total += len(y_batch)
    return loss_total / total, correct / total

In [12]:
class FlowerClient(fl.client.NumPyClient):
    """Flower NumPy client for one simulated hospital.

    Wraps a local model plus the hospital's train/validation datasets.
    Exchanges weights with the server as a list of NumPy arrays in
    ``model.state_dict()`` order (so BatchNorm buffers are included).

    Args:
        model: torch module, expected to already be on ``device``.
        train_dataset: dataset used for local training (shuffled each epoch).
        test_dataset: dataset used for local evaluation.
        device: torch device for training and evaluation.
        batch_size: mini-batch size for both loaders (default 32).
    """

    # Fallbacks used when the server's fit config omits these keys
    # (e.g. a strategy configured without ``on_fit_config_fn``).
    DEFAULT_LOCAL_EPOCHS = 1
    DEFAULT_LR = 1e-3

    def __init__(self, model, train_dataset, test_dataset, device, batch_size=32):
        self.model = model
        self.trainloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        self.testloader = DataLoader(test_dataset, batch_size=batch_size)
        self.device = device

    def get_parameters(self, config=None):
        """Return the model state as a list of NumPy arrays.

        Flower 1.x calls ``get_parameters(config)``. The argument is accepted
        and ignored; it defaults to ``None`` so zero-argument calls still work.
        """
        return [val.cpu().numpy() for val in self.model.state_dict().values()]

    def set_parameters(self, parameters):
        """Load NumPy arrays (in ``state_dict()`` order) into the model."""
        keys = list(self.model.state_dict().keys())
        # zip() would silently truncate on a mismatch; fail loudly instead.
        if len(keys) != len(parameters):
            raise ValueError(f"Expected {len(keys)} parameter arrays, got {len(parameters)}")
        state_dict = {k: torch.tensor(v) for k, v in zip(keys, parameters)}
        self.model.load_state_dict(state_dict)

    def fit(self, parameters, config):
        """Train locally from the global weights.

        Returns the new weights, the number of training examples, and metrics.
        """
        self.set_parameters(parameters)
        epochs = int(config.get("local_epochs", self.DEFAULT_LOCAL_EPOCHS))
        lr = float(config.get("lr", self.DEFAULT_LR))
        train_loss = train_fn(self.model, self.trainloader, epochs=epochs, lr=lr, device=self.device)
        return self.get_parameters(config), len(self.trainloader.dataset), {"train_loss": float(train_loss)}

    def evaluate(self, parameters, config):
        """Evaluate the global weights on the local validation set."""
        self.set_parameters(parameters)
        loss, accuracy = test_fn(self.model, self.testloader, self.device)
        return float(loss), len(self.testloader.dataset), {"accuracy": float(accuracy)}

In [16]:
# Hyper-parameters sent to every client on each round through `fit(config)`.
LOCAL_EPOCHS = 1
LEARNING_RATE = 1e-3


def fit_config(server_round):
    """Return the per-round training config consumed by the clients' ``fit``.

    ``server_round`` is unused; the same config is sent every round.
    """
    return {"local_epochs": LOCAL_EPOCHS, "lr": LEARNING_RATE}


strategy = fl.server.strategy.FedAvg(
    fraction_fit=1.0,
    min_fit_clients=4,
    min_available_clients=4,
    # Without this, clients receive an empty fit config.
    on_fit_config_fn=fit_config,
)


def start_flower_server(server_address="localhost:8080", num_rounds=5):
    """Run a blocking Flower server with the module-level ``strategy``.

    NOTE(review): Flower reports ``start_server`` as deprecated (see the
    warning in the cell output); consider migrating to the SuperLink-based
    workflow.
    """
    fl.server.start_server(
        server_address=server_address,
        config=fl.server.ServerConfig(num_rounds=num_rounds),
        strategy=strategy,
    )

In [17]:
# One independently initialised model per hospital. zip() pairs each train set
# with its validation set without assuming exactly 4 non-empty hospitals.
clients = [
    FlowerClient(SmallCNN().to(device), train_dataset, test_dataset, device)
    for train_dataset, test_dataset in zip(clients_train_data, clients_test_data)
]

In [18]:
# --- Start the server in a background thread ---
server_thread = threading.Thread(target=start_flower_server, daemon=True)
server_thread.start()

# --- Start clients ---
def start_client(client_obj):
    fl.client.start_numpy_client("localhost:8080", client=client_obj)

client_threads = []
for client_obj in clients:  # your FlowerClient instances
    t = threading.Thread(target=start_client, args=(client_obj,))
    t.start()
    client_threads.append(t)

# Optional: wait for clients to finish
for t in client_threads:
    t.join()

	Instead, use the `flower-superlink` CLI command to start a SuperLink as shown below:

		$ flower-superlink --insecure

	To view usage and all available options, run:

		$ flower-superlink --help

	Using `start_server()` is deprecated.

            This is a deprecated feature. It will be removed
            entirely in future versions of Flower.
        
Exception in thread Thread-7 (start_client):
Traceback (most recent call last):
  File "/usr/lib/python3.10/threading.py", line 1016, in _bootstrap_inner
Exception in thread Thread-8 (start_client):
Traceback (most recent call last):
  File "/usr/lib/python3.10/threading.py", line 1016, in _bootstrap_inner
Exception in thread Thread-9 (start_client):
Traceback (most recent call last):
  File "/usr/lib/python3.10/threading.py", line 1016, in _bootstrap_inner
Exception in thread Thread-10 (start_client):
Traceback (most recent call last):
  File "/usr/lib/python3.10/threading.py", line 1016, in _bootstrap_inner
    self.run()
  File "/u

    default_handler = signal.signal(sig, graceful_exit_handler)  # type: ignore
  File "/usr/lib/python3.10/signal.py", line 56, in signal
    handler = _signal.signal(_enum_to_int(signalnum), _enum_to_int(handler))
ValueError: signal only works in main thread of the main interpreter
