In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Walk the read-only Kaggle input tree and print the full path of every file found.
for root, _subdirs, file_names in os.walk('/kaggle/input'):
    for file_name in file_names:
        full_path = os.path.join(root, file_name)
        print(full_path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [1]:
# Install Flower (with its simulation extra), PyTorch, torchvision and tqdm.
# %pip (rather than !pip) installs into the environment of the running kernel, and the
# quotes stop the shell from treating "[simulation]" as a glob pattern.
%pip install -q "flwr[simulation]" torch torchvision tqdm

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m66.7/66.7 MB[0m [31m26.4 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m4.6 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m106.5 MB/s[0m eta [36m0:00:00[0m00:01[0m0:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m76.2 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m47.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m2.6 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m8.0 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [2]:
# ============================
# FEDERATED LEARNING WITH FLOWER + ViT-B/16
# Modified: Non-IID full-dataset client split + per-client 70/20/10 (holdout 10% used for final global eval)
# Keeps model, strategy, training loops, and most hyperparams unchanged.
# ============================
# !pip install -q "flwr[simulation]" torch torchvision tqdm
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms, models
from torchvision.models import ViT_B_16_Weights
from tqdm import tqdm
from sklearn.metrics import confusion_matrix, classification_report
import numpy as np
import random
import flwr as fl
from flwr.common import parameters_to_ndarrays, ndarrays_to_parameters
from collections import OrderedDict
import warnings

# Suppress every Python warning emitted from here on (keeps the notebook output short).
warnings.filterwarnings("ignore")

# ============================
# HYPERPARAMETERS
# ============================
NUM_CLIENTS = 4            # number of simulated federated clients
NUM_ROUNDS = 10            # federated rounds run by the server
BATCH_SIZE = 32            # batch size for every DataLoader
LOCAL_EPOCHS = 5           # passes over the local train split per round
LEARNING_RATE = 1e-4       # Adam learning rate for the classification head
FRACTION_FIT = 1.0         # fraction of clients sampled for training each round
FRACTION_EVALUATE = 1.0    # fraction of clients sampled for evaluation each round
SEED = 42                  # base seed shared by all RNG sources

# reproducibility: seed torch, NumPy and the stdlib RNG with the same value
for _seed_rng in (torch.manual_seed, np.random.seed, random.seed):
    _seed_rng(SEED)

# Prefer the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# ============================
# TRANSFORMS / DATASET (base)
# ============================
data_dir = '/kaggle/input/leukemia/Original'  # update if needed

# Final two steps shared by both pipelines: convert to a tensor, then normalise each
# channel with the mean/std constants below.
_to_normalized_tensor = [
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]

# Training pipeline: random crop / flip / rotation / colour jitter for augmentation.
_train_augmentations = [
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
]
train_transform = transforms.Compose(_train_augmentations + _to_normalized_tensor)

# Validation / test pipeline: deterministic resize followed by a centre crop.
_eval_preprocessing = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
]
val_test_transform = transforms.Compose(_eval_preprocessing + _to_normalized_tensor)

# Four ImageFolder views over the same directory. They share sample ordering, so an
# index taken from base_dataset addresses the same image in every view; only the
# transform applied on access differs.
base_dataset = datasets.ImageFolder(data_dir)  # no transform: labels / indexing only
full_dataset_for_train = datasets.ImageFolder(data_dir, transform=train_transform)
full_dataset_for_val = datasets.ImageFolder(data_dir, transform=val_test_transform)
full_dataset_for_test = datasets.ImageFolder(data_dir, transform=val_test_transform)

class_names = base_dataset.classes
num_classes = len(class_names)
print(f"Found {len(base_dataset)} samples across {num_classes} classes: {class_names}")

# ============================
# NON-IID SPLIT (Dirichlet)
# ============================
def non_iid_dirichlet_split(base_dataset, num_clients, alpha=0.5, seed=SEED):
    """
    Partition the full dataset across clients in a non-IID way.

    For every class, the class's sample indices are shuffled and divided among the
    clients according to proportions drawn from a symmetric Dirichlet(alpha)
    distribution. Smaller ``alpha`` gives more skewed per-client class mixes.

    Args:
        base_dataset: object exposing ``targets`` (one integer label per sample) and
            ``classes`` (sequence of class names), e.g. a torchvision ImageFolder.
        num_clients: number of partitions to create (must be >= 1).
        alpha: Dirichlet concentration parameter shared by all clients.
        seed: seed for the private random generator used by this function.

    Returns:
        A list of ``num_clients`` lists of integer sample indices. Every sample index
        appears in exactly one list.

    Raises:
        ValueError: if ``num_clients`` is smaller than 1.
    """
    if num_clients < 1:
        raise ValueError("num_clients must be at least 1")

    # A private RandomState yields the same stream that np.random.seed(seed) followed
    # by the global np.random functions would, but leaves the caller's global NumPy
    # RNG state untouched.
    rng = np.random.RandomState(seed)
    labels = np.asarray(base_dataset.targets)
    n_classes = len(base_dataset.classes)

    client_indices = [[] for _ in range(num_clients)]
    for c in range(n_classes):
        idx = np.where(labels == c)[0]
        if len(idx) == 0:
            continue
        rng.shuffle(idx)
        # Draw this class's share for every client.
        proportions = rng.dirichlet([alpha] * num_clients)
        # Truncation rounds every share down, so the counts can only undershoot.
        counts = (proportions * len(idx)).astype(int)
        deficit = len(idx) - counts.sum()
        if deficit > 0:
            # Hand the leftover samples to the client with the largest share.
            counts[np.argmax(proportions)] += deficit
        # Defensive: floating-point error should never overshoot, but stay exact if it does.
        while counts.sum() > len(idx):
            counts[np.argmax(counts)] -= 1
        start = 0
        for i in range(num_clients):
            cnt = int(counts[i])
            if cnt > 0:
                client_indices[i].extend(idx[start:start + cnt].tolist())
            start += cnt

    # Final sanity: give an empty client one sample from the currently largest client,
    # but only when that donor can spare it (so no client is emptied in the process and
    # an empty dataset does not raise).
    for i in range(num_clients):
        if len(client_indices[i]) == 0:
            j = int(np.argmax([len(x) for x in client_indices]))
            if len(client_indices[j]) > 1:
                client_indices[i].append(client_indices[j].pop())
    return client_indices

# Partition the full dataset across clients with a class-wise Dirichlet draw.
# alpha controls the skew: smaller alpha -> more heterogeneous clients. alpha=5.0 is a
# moderate setting; lower it if a more strongly non-IID split is wanted.
CLIENT_RAW_INDICES = non_iid_dirichlet_split(base_dataset, NUM_CLIENTS, alpha=5.0, seed=SEED)

# ============================
# Per-client 70/20/10 split (local)
# ============================
client_train_subsets = []
client_val_subsets = []
client_holdout_subsets = []  # 10% hold-out, untouched during local training/eval
client_sizes = []

for client_id, raw_indices in enumerate(CLIENT_RAW_INDICES):
    # Re-seed per client so each local shuffle is reproducible on its own.
    np.random.seed(SEED + client_id)
    shuffled = np.array(raw_indices)
    np.random.shuffle(shuffled)

    n_total = len(shuffled)
    n_train = int(0.7 * n_total)
    n_val = int(0.2 * n_total)
    if n_total < 3:
        # Very small client: ask for at least one train and one val sample.
        n_train = max(1, n_train)
        n_val = max(1, n_val)
    n_holdout = n_total - n_train - n_val
    # If a client with >= 3 samples would end up with an empty holdout, borrow one
    # sample from val (preferred) or from train.
    if n_total >= 3 and n_holdout == 0:
        if n_val > 1:
            n_val -= 1
            n_holdout = 1
        elif n_train > 1:
            n_train -= 1
            n_holdout = 1

    # Consecutive slices of the shuffled indices: [train | val | holdout]
    val_start = n_train
    holdout_start = n_train + n_val

    # Each Subset points at the ImageFolder view carrying the matching transform.
    train_subset = Subset(full_dataset_for_train, shuffled[:val_start].tolist())
    val_subset = Subset(full_dataset_for_val, shuffled[val_start:holdout_start].tolist())
    holdout_subset = Subset(full_dataset_for_test, shuffled[holdout_start:].tolist())

    client_train_subsets.append(train_subset)
    client_val_subsets.append(val_subset)
    client_holdout_subsets.append(holdout_subset)
    client_sizes.append((len(train_subset), len(val_subset), len(holdout_subset)))

# Report split sizes together with the class mix of each client's full index set.
all_targets = np.array(base_dataset.targets)
print("\nPer-client local dataset sizes (train, val, holdout):")
for i, (tr, va, ho) in enumerate(client_sizes):
    per_class = np.bincount(all_targets[CLIENT_RAW_INDICES[i]], minlength=num_classes)
    counts_str = ", ".join(f"{class_names[j]}: {per_class[j]}" for j in range(num_classes))
    print(f"Client {i} -> Total full: {len(CLIENT_RAW_INDICES[i])} | train: {tr}, val: {va}, holdout(10%): {ho} | class dist in full client data: {counts_str}")

# For debugging: the three splits of all clients together should cover the dataset.
total_assigned = sum(sum(sizes) for sizes in client_sizes)
print(f"\nTotal samples assigned across clients (sum of train+val+holdout): {total_assigned} (should equal dataset size {len(base_dataset)})")

# ============================
# MODEL DEFINITION (unchanged)
# ============================
def get_vit_model(num_classes=4):
    """
    Build a ViT-B/16 initialised with the IMAGENET1K_V1 weights, freeze all of its
    existing parameters, and replace ``heads`` with a trainable Dropout + Linear
    classifier producing ``num_classes`` outputs.

    Returns the model moved to the module-level ``device``.
    """
    backbone = models.vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)

    # Freeze everything that came with the pretrained checkpoint.
    for weight in backbone.parameters():
        weight.requires_grad = False

    # Size the new classifier from the input width of the original head.
    hidden_dim = backbone.heads.head.in_features
    classifier = nn.Sequential(
        nn.Dropout(p=0.5),
        nn.Linear(hidden_dim, num_classes),
    )
    # The replacement head is the only trainable part of the network.
    for weight in classifier.parameters():
        weight.requires_grad = True
    backbone.heads = classifier

    return backbone.to(device)

# ============================
# CLIENT IMPLEMENTATION (unchanged API; data sources updated)
# ============================
class ViTClient(fl.client.NumPyClient):
    """
    Flower NumPy client that trains and evaluates only the classification head
    (``model.heads``) of the ViT model on one client's local data.

    Only the head tensors are exchanged with the server: ``get_parameters`` and
    ``set_parameters`` both operate on ``model.heads.state_dict()`` in key order.
    """

    def __init__(self, trainset, valset):
        """
        Args:
            trainset: dataset used for local training (shuffled every epoch).
            valset: dataset used for local evaluation (iterated in order).
        """
        self.trainset = trainset
        self.valset = valset
        self.model = get_vit_model(num_classes=num_classes)
        self.criterion = nn.CrossEntropyLoss()
        # Only the head parameters are given to the optimizer.
        self.optimizer = optim.Adam(self.model.heads.parameters(), lr=LEARNING_RATE, weight_decay=1e-4)
        self.trainloader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)
        self.valloader = DataLoader(valset, batch_size=BATCH_SIZE, shuffle=False)

    def get_parameters(self, config):
        """Return the head's state-dict tensors as a list of NumPy arrays (state-dict order)."""
        return [val.cpu().numpy() for _, val in self.model.heads.state_dict().items()]

    def set_parameters(self, parameters):
        """Load a list of arrays (same order as ``get_parameters``) into the head."""
        params_dict = zip(self.model.heads.state_dict().keys(), parameters)
        state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
        # strict=True: a missing or unexpected key raises instead of being ignored.
        self.model.heads.load_state_dict(state_dict, strict=True)

    def fit(self, parameters, config):
        """
        Load the received head parameters, train for LOCAL_EPOCHS epochs on the local
        train set, and return (updated head parameters, number of train samples, {}).
        """
        self.set_parameters(parameters)
        self.model.train()
        for _ in range(LOCAL_EPOCHS):
            for images, labels in self.trainloader:
                images, labels = images.to(device), labels.to(device)
                self.optimizer.zero_grad()
                outputs = self.model(images)
                loss = self.criterion(outputs, labels)
                loss.backward()
                self.optimizer.step()
        return self.get_parameters({}), len(self.trainloader.dataset), {}

    def evaluate(self, parameters, config):
        """
        Evaluate the received head parameters on the client's validation set.

        Returns:
            (mean per-sample loss, number of validation samples, {"accuracy": fraction
            correct}). Loss and accuracy are both 0.0 when the validation set is empty.
        """
        self.set_parameters(parameters)
        self.model.eval()
        loss_sum = 0.0
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in self.valloader:
                images, labels = images.to(device), labels.to(device)
                outputs = self.model(images)
                batch_count = labels.size(0)
                # The criterion returns the batch mean; weight it by the batch size so a
                # shorter final batch does not count as much as a full one. This keeps
                # the loss consistent with the per-example count returned below.
                loss_sum += self.criterion(outputs, labels).item() * batch_count
                _, predicted = torch.max(outputs, 1)
                total += batch_count
                correct += (predicted == labels).sum().item()
        if total > 0:
            loss = loss_sum / total
            accuracy = correct / total
        else:
            loss = 0.0
            accuracy = 0.0
        return loss, len(self.valloader.dataset), {"accuracy": accuracy}

# ============================
# GLOBAL EVALUATION HELPERS (per-client holdout)
# ============================
def load_head_params_into_model(model, params):
    """
    Copy ``params`` into ``model.heads``.

    ``params`` is a sequence of arrays paired positionally with the keys of
    ``model.heads.state_dict()``. Loading is strict, so a head key left without a
    value raises.
    """
    head = model.heads
    new_state = OrderedDict()
    for key, array in zip(head.state_dict().keys(), params):
        new_state[key] = torch.tensor(array)
    head.load_state_dict(new_state, strict=True)

def evaluate_on_holdout(server_params, holdout_subset, client_id):
    """
    Evaluate the given server/head parameters on one client's holdout set and print
    the accuracy, confusion matrix and classification report.

    Args:
        server_params: head parameters, either a list of ndarrays or an object that
            ``parameters_to_ndarrays`` can convert.
        holdout_subset: dataset holding the client's held-out samples.
        client_id: identifier used only in the printed messages.

    Returns:
        Accuracy in percent (0.0 for an empty holdout), or None when the parameters
        could not be loaded into the model.
    """
    model = get_vit_model(num_classes=num_classes)
    try:
        # server_params might already be ndarrays
        params = server_params
        load_head_params_into_model(model, params)
    except Exception as e:
        # try converting from Flower Parameter objects if necessary
        try:
            params = parameters_to_ndarrays(server_params)
            load_head_params_into_model(model, params)
        except Exception as e2:
            print(f"Could not load server params for client {client_id}: {e} / {e2}")
            return None

    model.eval()
    loader = DataLoader(holdout_subset, batch_size=BATCH_SIZE, shuffle=False)
    all_preds = []
    all_labels = []
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    accuracy = 100 * correct / total if total > 0 else 0.0
    print(f"\nClient {client_id} - Holdout Evaluation (10%): samples={total}, Accuracy={accuracy:.2f}%")
    if total > 0:
        # Pin the label set explicitly. A non-IID holdout can lack a class entirely;
        # without `labels`, the confusion matrix shrinks and classification_report
        # raises because the observed classes no longer match `target_names`.
        label_ids = list(range(num_classes))
        cm = confusion_matrix(all_labels, all_preds, labels=label_ids)
        print("Confusion Matrix:")
        print(cm)
        print("Classification Report:")
        print(classification_report(
            all_labels,
            all_preds,
            labels=label_ids,
            target_names=class_names,
            zero_division=0,
        ))
    else:
        print("No samples in holdout for this client; skipping detailed report.")
    return accuracy

# ============================
# CUSTOM STRATEGY: Save final params (unchanged)
# ============================
from flwr.server.strategy import FedAvg

class SaveModelStrategy(FedAvg):
    """
    FedAvg variant that remembers the most recent aggregated parameters.

    After every round whose aggregation produced parameters, ``final_params`` is
    overwritten with them: as a list of ndarrays when the conversion succeeds,
    otherwise as the unconverted object. It stays None until then.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.final_params = None

    def aggregate_fit(self, rnd, results, failures):
        """Run the parent's aggregation, record its parameters, and return its result."""
        aggregated_parameters, agg_metrics = super().aggregate_fit(rnd, results, failures)
        if aggregated_parameters is None:
            # Nothing was aggregated this round; keep whatever was stored before.
            return aggregated_parameters, agg_metrics
        try:
            self.final_params = parameters_to_ndarrays(aggregated_parameters)
        except Exception:
            # Conversion failed: keep the raw object so it is not lost.
            self.final_params = aggregated_parameters
        return aggregated_parameters, agg_metrics

# ============================
# STRATEGY & START SERVER
# ============================
def client_fn(cid: str):
    """
    Build the ViTClient for the client whose id is ``cid``.

    ``cid`` is parsed as an integer and used to pick that client's train and
    validation subsets from the module-level lists.
    """
    client_index = int(cid)
    return ViTClient(
        client_train_subsets[client_index],
        client_val_subsets[client_index],
    )

def weighted_average(metrics):
    """
    Combine per-client accuracies into one example-weighted accuracy.

    ``metrics`` is a sequence of ``(num_examples, {"accuracy": value})`` pairs. The
    result is ``{"accuracy": sum(n * acc) / max(1, sum(n))}``, so an empty sequence
    or a zero total example count yields 0.0 rather than a division error.
    """
    weighted_sum = sum(count * client_metrics["accuracy"] for count, client_metrics in metrics)
    example_total = sum(count for count, _ in metrics)
    denominator = max(1, example_total)
    return {"accuracy": weighted_sum / denominator}

# Every client takes part in every fit and evaluate round: the fractions are 1.0 and
# all three minimum-client thresholds equal NUM_CLIENTS.
strategy_settings = dict(
    fraction_fit=FRACTION_FIT,
    fraction_evaluate=FRACTION_EVALUATE,
    min_fit_clients=NUM_CLIENTS,
    min_evaluate_clients=NUM_CLIENTS,
    min_available_clients=NUM_CLIENTS,
    evaluate_metrics_aggregation_fn=weighted_average,
)
strategy = SaveModelStrategy(**strategy_settings)

print("\nStarting Federated Learning Simulation...")
print(f"Clients: {NUM_CLIENTS}, Rounds: {NUM_ROUNDS}, Local Epochs: {LOCAL_EPOCHS}")

# Each virtual client is allotted one CPU, plus half a GPU when running on CUDA.
gpu_share = 0.5 if device.type == "cuda" else 0
per_client_resources = {"num_cpus": 1, "num_gpus": gpu_share}
server_config = fl.server.ServerConfig(num_rounds=NUM_ROUNDS)

history = fl.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=NUM_CLIENTS,
    config=server_config,
    strategy=strategy,
    client_resources=per_client_resources,
)

# ============================
# FINAL GLOBAL EVALUATION ON EACH CLIENT'S HOLDOUT (10%)
# ============================
print("\n" + "="*60)
print("FINAL GLOBAL MODEL EVALUATION ON EACH CLIENT'S 10% HOLDOUT SET")
print("="*60)

# retrieve the final params from strategy (should be ndarrays list)
global_model_params = getattr(strategy, "final_params", None)
if global_model_params is None:
    print("Warning: final parameters not found in strategy. Falling back to last client's parameters.")
    # NOTE: client_fn builds a brand-new client, so these are the head parameters of a
    # freshly constructed model, not the outcome of any federated training.
    last_client = client_fn(str(NUM_CLIENTS - 1))
    global_model_params = last_client.get_parameters({})

# Evaluate per-client holdout and print classification reports
per_client_accuracies = []
for cid_i in range(NUM_CLIENTS):
    holdout_subset = client_holdout_subsets[cid_i]
    acc = evaluate_on_holdout(global_model_params, holdout_subset, cid_i)
    per_client_accuracies.append((cid_i, acc))

print("\nSummary of global model performance on client holdouts:")
for cid_i, acc in per_client_accuracies:
    if acc is None:
        # evaluate_on_holdout returns None when the parameters could not be loaded;
        # formatting None with ":.2f" would raise a TypeError.
        print(f"Client {cid_i}: Holdout Accuracy = N/A (parameters could not be loaded)")
    else:
        print(f"Client {cid_i}: Holdout Accuracy = {acc:.2f}%")

print("\nFederated Learning Completed Successfully!")


2025-11-23 12:25:24.907915: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1763900725.129136      48 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1763900725.191384      48 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


Using device: cuda


	Instead, use the `flwr run` CLI command to start a local simulation in your Flower app, as shown for example below:

		$ flwr new  # Create a new Flower app from a template

		$ flwr run  # Run the Flower app in Simulation Mode

	Using `start_simulation()` is deprecated.

            This is a deprecated feature. It will be removed
            entirely in future versions of Flower.
        
[92mINFO [0m:      Starting Flower simulation, config: num_rounds=10, no round_timeout


Found 3256 samples across 4 classes: ['Benign', 'Early', 'Pre', 'Pro']

Per-client local dataset sizes (train, val, holdout):
Client 0 -> Total full: 984 | train: 688, val: 196, holdout(10%): 100 | class dist in full client data: Benign: 240, Early: 327, Pre: 345, Pro: 72
Client 1 -> Total full: 650 | train: 454, val: 130, holdout(10%): 66 | class dist in full client data: Benign: 74, Early: 227, Pre: 236, Pro: 113
Client 2 -> Total full: 886 | train: 620, val: 177, holdout(10%): 89 | class dist in full client data: Benign: 83, Early: 243, Pre: 250, Pro: 310
Client 3 -> Total full: 736 | train: 515, val: 147, holdout(10%): 74 | class dist in full client data: Benign: 107, Early: 188, Pre: 132, Pro: 309

Total samples assigned across clients (sum of train+val+holdout): 3256 (should equal dataset size 3256)

Starting Federated Learning Simulation...
Clients: 4, Rounds: 10, Local Epochs: 5


2025-11-23 12:25:52,269	INFO worker.py:1771 -- Started a local Ray instance.
[92mINFO [0m:      Flower VCE: Ray initialized with resources: {'node:172.19.2.2': 1.0, 'node:__internal_head__': 1.0, 'CPU': 4.0, 'object_store_memory': 9184966656.0, 'memory': 18369933312.0, 'GPU': 2.0, 'accelerator_type:T4': 1.0}
[92mINFO [0m:      Optimize your simulation with Flower VCE: https://flower.ai/docs/framework/how-to-run-simulations.html
[92mINFO [0m:      Flower VCE: Resources for each Virtual Client: {'num_cpus': 1, 'num_gpus': 0.5}
[92mINFO [0m:      Flower VCE: Creating VirtualClientEngineActorPool with 4 actors
[92mINFO [0m:      [INIT]
[92mINFO [0m:      Requesting initial parameters from one random client
[36m(pid=403)[0m 2025-11-23 12:26:01.827920: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
[36m(pid=403)[0m E0000 00:00:1763900761.86356


FINAL GLOBAL MODEL EVALUATION ON EACH CLIENT'S 10% HOLDOUT SET

Client 0 - Holdout Evaluation (10%): samples=100, Accuracy=93.00%
Confusion Matrix:
[[21  3  0  0]
 [ 0 41  0  1]
 [ 0  2 26  1]
 [ 0  0  0  5]]
Classification Report:
              precision    recall  f1-score   support

      Benign       1.00      0.88      0.93        24
       Early       0.89      0.98      0.93        42
         Pre       1.00      0.90      0.95        29
         Pro       0.71      1.00      0.83         5

    accuracy                           0.93       100
   macro avg       0.90      0.94      0.91       100
weighted avg       0.94      0.93      0.93       100


Client 1 - Holdout Evaluation (10%): samples=66, Accuracy=96.97%
Confusion Matrix:
[[ 6  1  0  0]
 [ 0 24  1  0]
 [ 0  0 24  0]
 [ 0  0  0 10]]
Classification Report:
              precision    recall  f1-score   support

      Benign       1.00      0.86      0.92         7
       Early       0.96      0.96      0.96        25
 