https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/brats_segmentation_3d.ipynb

In [1]:
import os
import shutil
import tempfile
import time
import matplotlib.pyplot as plt
from monai.apps import DecathlonDataset
from monai.config import print_config
from monai.data import DataLoader, decollate_batch
from monai.handlers.utils import from_engine
from monai.losses import DiceLoss
from monai.inferers import sliding_window_inference
from monai.metrics import DiceMetric
from monai.networks.nets import SegResNet
from monai.transforms import (
    Activations,
    Activationsd,
    AsDiscrete,
    AsDiscreted,
    Compose,
    Invertd,
    LoadImaged,
    MapTransform,
    NormalizeIntensityd,
    Orientationd,
    RandFlipd,
    RandScaleIntensityd,
    RandShiftIntensityd,
    RandSpatialCropd,
    Spacingd,
    EnsureTyped,
    EnsureChannelFirstd,
)
from monai.utils import set_determinism
import onnxruntime
from tqdm import tqdm

import torch

print_config()
set_determinism(seed=0)

[0;93m2025-10-24 12:19:36.263256874 [W:onnxruntime:Default, device_discovery.cc:164 DiscoverDevicesForPlatform] GPU device discovery failed: device_discovery.cc:89 ReadFileContents Failed to open file: "/sys/class/drm/card0/device/vendor"[m


MONAI version: 1.6.dev2542
Numpy version: 2.1.2
Pytorch version: 2.8.0+cu126
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: 612f3dd3cba4d73cfcea4b5329079e20aa31523d
MONAI __file__: /home/<username>/miniconda3/envs/m_quant_py310/lib/python3.10/site-packages/monai/__init__.py

Optional dependencies:
Pytorch Ignite version: NOT INSTALLED or UNKNOWN VERSION.
ITK version: 5.4.4
Nibabel version: 5.3.2
scikit-image version: 0.25.2
scipy version: 1.15.3
Pillow version: 11.0.0
Tensorboard version: NOT INSTALLED or UNKNOWN VERSION.
gdown version: NOT INSTALLED or UNKNOWN VERSION.
TorchVision version: 0.23.0+cu126
tqdm version: 4.67.1
lmdb version: NOT INSTALLED or UNKNOWN VERSION.
psutil version: 7.0.0
pandas version: 2.3.2
einops version: 0.8.1
transformers version: NOT INSTALLED or UNKNOWN VERSION.
mlflow version: NOT INSTALLED or UNKNOWN VERSION.
pynrrd version: NOT INSTALLED or UNKNOWN VERSION.
clearml version: NOT INSTALLED or UNKNOWN VERSION.

For d

In [25]:
directory = "./seg_ckpts"
if directory is not None:
    os.makedirs(directory, exist_ok=True)
root_dir = tempfile.mkdtemp() if directory is None else directory
print(root_dir)

./seg_ckpts


# Transforms/Preprocessing taken as is from MONAI

In [3]:
class ConvertToMultiChannelBasedOnBratsClassesd(MapTransform):
    """
    Convert labels to multi channels based on brats classes:
    label 1 is the peritumoral edema
    label 2 is the GD-enhancing tumor
    label 3 is the necrotic and non-enhancing tumor core
    The possible classes are TC (Tumor core), WT (Whole tumor)
    and ET (Enhancing tumor).

    """

    def __call__(self, data):
        d = dict(data)
        for key in self.keys:
            result = []
            # merge label 2 and label 3 to construct TC
            result.append(torch.logical_or(d[key] == 2, d[key] == 3))
            # merge labels 1, 2 and 3 to construct WT
            result.append(torch.logical_or(torch.logical_or(d[key] == 2, d[key] == 3), d[key] == 1))
            # label 2 is ET
            result.append(d[key] == 2)
            d[key] = torch.stack(result, axis=0).float()
        return d

train_transform = Compose(
    [
        # load 4 Nifti images and stack them together
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys="image"),
        EnsureTyped(keys=["image", "label"]),
        ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(
            keys=["image", "label"],
            pixdim=(1.0, 1.0, 1.0),
            mode=("bilinear", "nearest"),
        ),
        RandSpatialCropd(keys=["image", "label"], roi_size=[224, 224, 144], random_size=False),
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=1),
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=2),
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
        RandScaleIntensityd(keys="image", factors=0.1, prob=1.0),
        RandShiftIntensityd(keys="image", offsets=0.1, prob=1.0),
    ]
)
val_transform = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys="image"),
        EnsureTyped(keys=["image", "label"]),
        ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(
            keys=["image", "label"],
            pixdim=(1.0, 1.0, 1.0),
            mode=("bilinear", "nearest"),
        ),
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
    ]
)

monai.transforms.spatial.dictionary Orientationd.__init__:labels: Current default value of argument `labels=(('L', 'R'), ('P', 'A'), ('I', 'S'))` was changed in version None from `labels=(('L', 'R'), ('P', 'A'), ('I', 'S'))` to `labels=None`. Default value changed to None meaning that the transform now uses the 'space' of a meta-tensor, if applicable, to determine appropriate axis labels.


# Load the data in a federated manner ready for the Data Valuation pipeline

In [15]:
# MINE!!

import os, copy, time, random, torch, numpy as np                                 
import glob
from tqdm import tqdm
import pandas as pd
from torch.utils.data import DataLoader
from monai.data import CacheDataset
import glob, nibabel as nib, pandas as pd
from monai.data import CacheDataset, DataLoader
from monai.transforms import (
    LoadImaged, EnsureChannelFirstd, Orientationd, ScaleIntensityd,
    RandFlipd, RandSpatialCropd, Compose, SelectItemsd
)

from utils import *
  
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# -----------------------------------------------------------
# 0. paths & meta-data (unchanged) ---------------------------
# -----------------------------------------------------------
BRATS_DIR = "/mnt/d/Datasets/FETS_data/MICCAI_FeTS2022_TrainingData"
CSV_PATH  = f"{BRATS_DIR}/partitioning_1.csv"
MODALITIES = ["flair", "t1", "t1ce", "t2"]
LABEL_KEY  = "seg"

# -----------------------------------------------------------
# 1. read partition file  ➜  { id : [subjects] } ------------
# -----------------------------------------------------------
part_df       = pd.read_csv(CSV_PATH)
partition_map = (
    part_df.groupby("Partition_ID")["Subject_ID"]
           .apply(list).to_dict()
)                               # keys are 1 … 23

VAL_CENTRES = {21, 22, 23}          # ← our hold-out set
# VAL_CENTRES = {22, 23}          # ← our sanity set

# split once, reuse everywhere
train_partitions = {cid: sids for cid, sids in partition_map.items()
                    if cid not in VAL_CENTRES}
val_subjects     = sum((partition_map[cid] for cid in VAL_CENTRES), [])

# -----------------------------------------------------------
# 2. helper to build MONAI-style record dicts ----------------
# -----------------------------------------------------------

def build_records(subject_ids):
    recs = []
    for sid in subject_ids:
        sdir = f"{BRATS_DIR}/{sid}"
        images = [f"{sdir}/{sid}_{m}.nii.gz" for m in MODALITIES]  # 4 modalities
        recs.append({"image": images, "label": f"{sdir}/{sid}_{LABEL_KEY}.nii.gz"})
    return recs


# -----------------------------------------------------------
# 3. MONAI CacheDatasets ------------------------------------
# -----------------------------------------------------------
# ── client-wise training sets ───────────────────────────────
CUT_OFF, FRAC, SEED = 3, 0.2, 42
rng = random.Random(SEED)

train_datasets = {}
for cid, subj_ids in train_partitions.items():
    if cid > CUT_OFF:                                    # keep your cap
        break
    k = max(1, int(len(subj_ids) * FRAC))                # e.g. 30 %
    sample_ids = rng.sample(subj_ids, k)
    train_datasets[cid] = CacheDataset(
        build_records(sample_ids), transform=train_transform, cache_rate=1
    )

# ── single validation dataset made from *all* val subjects ─
val_dataset = CacheDataset(
    build_records(val_subjects), transform=val_transform, cache_rate=1
)

print("train per-centre sizes:", {k: len(v) for k, v in train_datasets.items()})
print("validation size:", len(test_dataset))


Loading dataset: 100%|████████████████████████████████████████████████████████████████| 102/102 [03:18<00:00,  1.95s/it]
Loading dataset: 100%|████████████████████████████████████████████████████████████████████| 1/1 [00:02<00:00,  2.01s/it]
Loading dataset: 100%|████████████████████████████████████████████████████████████████████| 3/3 [00:05<00:00,  1.86s/it]
Loading dataset: 100%|██████████████████████████████████████████████████████████████████| 47/47 [02:48<00:00,  3.59s/it]

train per-centre sizes: {1: 102, 2: 1, 3: 3}
validation size: 47





# Create Model, Loss, Optimizer (as is from MONAI tutorial)


In [17]:
max_epochs = 300
val_interval = 1
VAL_AMP = True

# standard PyTorch program style: create SegResNet, DiceLoss and Adam optimizer
device = torch.device("cuda:0")
global_model = SegResNet(
    blocks_down=[1, 2, 2, 4],
    blocks_up=[1, 1, 1],
    init_filters=16,
    in_channels=4,
    out_channels=3,
    dropout_prob=0.2,
).to(device)
loss_function = DiceLoss(smooth_nr=0, smooth_dr=1e-5, squared_pred=True, to_onehot_y=False, sigmoid=True)
optimizer = torch.optim.Adam(global_model.parameters(), 1e-4, weight_decay=1e-5)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_epochs)

dice_metric = DiceMetric(include_background=True, reduction="mean")
dice_metric_batch = DiceMetric(include_background=True, reduction="mean_batch")

post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])


# define inference method
def inference(input):
    def _compute(input):
        return sliding_window_inference(
            inputs=input,
            roi_size=(240, 240, 160),
            sw_batch_size=1,
            predictor=global_model,
            overlap=0.5,
        )

    if VAL_AMP:
        with torch.autocast("cuda"):
            return _compute(input)
    else:
        return _compute(input)


# use amp to accelerate training
scaler = torch.GradScaler("cuda")
# enable cuDNN benchmark
torch.backends.cudnn.benchmark = True

# Evaluation function

In [19]:
def evaluate_model(model, dataset, device, batch_size=1):
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
    dice_metric.reset()
    model.eval()
    with torch.no_grad():
        for batch in tqdm(loader, desc="Evaluating"):
            inputs = batch["image"].to(device)
            labels = batch["label"].to(device)          # already multi-channel (WT/TC/ET) floats

            # sliding-window inference (your helper)
            logits = inference(inputs)                  # [B, 3, D, H, W]

            # tensor-wise post-processing (no decollate, no meta juggling)
            preds = torch.sigmoid(logits)
            preds = (preds > 0.5).float()

            # accumulate Dice on tensors directly
            dice_metric(y_pred=preds, y=labels)
            dice_metric_batch(y_pred=preds, y=labels)

    mean_dice = dice_metric.aggregate().item()
    metric_batch = dice_metric_batch.aggregate()
    metric_tc = metric_batch[0].item()
    metric_wt = metric_batch[1].item()
    metric_et = metric_batch[2].item()
    dice_metric.reset()
    dice_metric_batch.reset()
    
    return mean_dice, metric_tc, metric_wt, metric_et
    
print("Dice before any training:", evaluate_model(global_model, test_dataset, device)) # quick sanity check


Evaluating: 100%|███████████████████████████████████████████████████████████████████████| 47/47 [00:24<00:00,  1.92it/s]

Dice before any training: (0.015499441884458065, 0.022785522043704987, 0.0017377998447045684, 0.02197500318288803)





# Train federated

In [None]:
from tqdm.auto import tqdm, trange   # trange == tqdm(range())
from collections import OrderedDict

def average_weights(state_dicts, fractions):
    """
    Federated averaging with client fractions (must sum to 1).
    state_dicts: list of state_dicts (same keys)
    fractions:   list of floats, same length, sum≈1
    """
    avg_sd = OrderedDict()
    for k in state_dicts[0].keys():
        avg = 0.0
        for sd, w in zip(state_dicts, fractions):
            avg += sd[k] * w
        avg_sd[k] = avg
    return avg_sd

# ────────────────────────────────────────────────────────────
# 1. one-client update (returns weights + mean loss)          │
# ────────────────────────────────────────────────────────────
def local_train(model, loader, device, lr=1e-4, epochs=1):
    """
    Train a local copy of the global model on one client's DataLoader.
    Uses your DiceLoss (multi-label, sigmoid) and full crops from transforms.
    """
    model = copy.deepcopy(model).to(device)
    model.train()

    # reuse your loss choice; or inline DiceLoss the same way
    crit = DiceLoss(smooth_nr=0, smooth_dr=1e-5, squared_pred=True,
                    to_onehot_y=False, sigmoid=True).to(device)

    opt = torch.optim.Adam(model.parameters(), lr=lr)
    epoch_losses = []

    for _ in range(epochs):
        running = 0.0
        for batch in loader:
            img = batch["image"].to(device)   # [B, 4, D, H, W]
            msk = batch["label"].to(device)   # [B, 3, D, H, W]

            opt.zero_grad(set_to_none=True)
            logits = model(img)               # [B, 3, D, H, W]
            loss = crit(logits, msk)
            loss.backward()
            opt.step()

            running += float(loss.item())
        epoch_losses.append(running / max(1, len(loader)))

    return model.state_dict(), float(np.mean(epoch_losses))


# ────────────────────────────────────────────────────────────
# 2. FedAvg training loop (simple tqdm + clean prints)        │
# ────────────────────────────────────────────────────────────
from tqdm.auto import tqdm, trange

EPOCHS, LOCAL_EPOCHS, LR, BATCH = 50, 1, 1e-4, 1

idxs_users = list(train_datasets.keys())
sizes      = {k: len(ds) for k, ds in train_datasets.items()}
total_n    = sum(sizes.values())
fractions  = [sizes[k] / total_n for k in idxs_users]

print("Dice before any training:", evaluate_model(global_model, test_dataset, device))

best_metric = -1
best_metric_round = -1
best_metrics_rounds_and_time = [[], [], []]
round_loss_values = []
metric_values = []
metric_values_tc = []
metric_values_wt = []
metric_values_et = []

total_start = time.time()
for rnd in trange(1, EPOCHS + 1, desc="Global rounds", position=0, leave=True, dynamic_ncols=True):
    local_weights, client_losses = [], []

    # client bar (line 1)
    for cid in tqdm(idxs_users, desc=" clients", position=1, leave=False, total=len(idxs_users), dynamic_ncols=True):
        loader = DataLoader(train_datasets[cid], batch_size=BATCH, shuffle=True, num_workers=4, pin_memory=True)
        w, loss = local_train(global_model, loader, device, lr=LR, epochs=LOCAL_EPOCHS)
        local_weights.append(w); client_losses.append(loss)

    global_model.load_state_dict(average_weights(local_weights, fractions))

    # eval bar (line 2)
    mean_dice, metric_tc, metric_wt, metric_et = evaluate_model(global_model, test_dataset, device)
    
    metric_values.append(mean_dice)
    metric_values_tc.append(metric_tc)
    metric_values_wt.append(metric_wt)
    metric_values_et.append(metric_et)
    
    mean_loss = float(np.mean(client_losses))
    round_loss_values.append(mean_loss)


    if mean_dice > best_metric:
        best_metric = mean_dice
        best_metric_round = rnd + 1
        best_metrics_rounds_and_time[0].append(best_metric)
        best_metrics_rounds_and_time[1].append(best_metric_epoch)
        best_metrics_rounds_and_time[2].append(time.time() - total_start)
        torch.save(
            global_model.state_dict(),
            os.path.join(root_dir, "best_metric_model.pth"),
        )
        print("saved new best metric model")
    
    tqdm.write(f"Round {rnd:02d}:  mean-loss = {mean_loss:.4f}   mean-Dice = {mean_dice:.4f}    TC-Dice = {metric_tc:.4f}    TC-Dice = {metric_wt:.4f}    TC-Dice = {metric_tc:.4f}")



Evaluating:   0%|          | 0/47 [00:00<?, ?it/s]

Dice before any training: (0.2995908558368683, 0.3623102605342865, 0.03415071591734886, 0.4463581442832947)


Global rounds:   0%|                                                                             | 0/50 [00:00…

 clients:   0%|                                                                                   | 0/3 [00:00…

In [None]:
plt.figure("train", (12, 6))
plt.subplot(1, 2, 1)
plt.title("Federated Round Aggregate Loss")
x = [i + 1 for i in range(len(round_loss_values))]
y = round_loss_values
plt.xlabel("epoch")
plt.plot(x, y, color="red")
plt.subplot(1, 2, 2)
plt.title("Val Mean Dice")
x = [val_interval * (i + 1) for i in range(len(metric_values))]
y = metric_values
plt.xlabel("epoch")
plt.plot(x, y, color="green")
plt.show()

plt.figure("train", (18, 6))
plt.subplot(1, 3, 1)
plt.title("Val Mean Dice TC")
x = [val_interval * (i + 1) for i in range(len(metric_values_tc))]
y = metric_values_tc
plt.xlabel("epoch")
plt.plot(x, y, color="blue")
plt.subplot(1, 3, 2)
plt.title("Val Mean Dice WT")
x = [val_interval * (i + 1) for i in range(len(metric_values_wt))]
y = metric_values_wt
plt.xlabel("epoch")
plt.plot(x, y, color="brown")
plt.subplot(1, 3, 3)
plt.title("Val Mean Dice ET")
x = [val_interval * (i + 1) for i in range(len(metric_values_et))]
y = metric_values_et
plt.xlabel("epoch")
plt.plot(x, y, color="purple")
plt.show()

In [8]:
from tqdm.auto import tqdm, trange   # trange == tqdm(range())
from collections import OrderedDict

def average_weights(state_dicts, fractions):
    """
    Federated averaging with client fractions (must sum to 1).
    state_dicts: list of state_dicts (same keys)
    fractions:   list of floats, same length, sum≈1
    """
    avg_sd = OrderedDict()
    for k in state_dicts[0].keys():
        avg = 0.0
        for sd, w in zip(state_dicts, fractions):
            avg += sd[k] * w
        avg_sd[k] = avg
    return avg_sd

# ────────────────────────────────────────────────────────────
# 1. one-client update (returns weights + mean loss)          │
# ────────────────────────────────────────────────────────────
def local_train(model, loader, device, lr=1e-4, epochs=1):
    """
    Train a local copy of the global model on one client's DataLoader.
    Uses your DiceLoss (multi-label, sigmoid) and full crops from transforms.
    """
    model = copy.deepcopy(model).to(device)
    model.train()

    # reuse your loss choice; or inline DiceLoss the same way
    crit = DiceLoss(smooth_nr=0, smooth_dr=1e-5, squared_pred=True,
                    to_onehot_y=False, sigmoid=True).to(device)

    opt = torch.optim.Adam(model.parameters(), lr=lr)
    epoch_losses = []

    for _ in range(epochs):
        running = 0.0
        for batch in loader:
            img = batch["image"].to(device)   # [B, 4, D, H, W]
            msk = batch["label"].to(device)   # [B, 3, D, H, W]

            opt.zero_grad(set_to_none=True)
            logits = model(img)               # [B, 3, D, H, W]
            loss = crit(logits, msk)
            loss.backward()
            opt.step()

            running += float(loss.item())
        epoch_losses.append(running / max(1, len(loader)))

    return model.state_dict(), float(np.mean(epoch_losses))


# ────────────────────────────────────────────────────────────
# 2. FedAvg training loop (simple tqdm + clean prints)        │
# ────────────────────────────────────────────────────────────
from tqdm.auto import tqdm, trange

EPOCHS, LOCAL_EPOCHS, LR, BATCH = 50, 1, 1e-4, 1

idxs_users = list(train_datasets.keys())
sizes      = {k: len(ds) for k, ds in train_datasets.items()}
total_n    = sum(sizes.values())
fractions  = [sizes[k] / total_n for k in idxs_users]

print("Dice before any training:", evaluate_model(global_model, test_dataset, device))

best_metric = -1
best_metric_epoch = -1
best_metrics_epochs_and_time = [[], [], []]
epoch_loss_values = []
metric_values = []
metric_values_tc = []
metric_values_wt = []
metric_values_et = []

for rnd in trange(1, EPOCHS + 1, desc="Global rounds", position=0, leave=True, dynamic_ncols=True):
    local_weights, client_losses = [], []

    # client bar (line 1)
    for cid in tqdm(idxs_users, desc=" clients", position=1, leave=False, total=len(idxs_users), dynamic_ncols=True):
        loader = DataLoader(train_datasets[cid], batch_size=BATCH, shuffle=True, num_workers=4, pin_memory=True)
        w, loss = local_train(global_model, loader, device, lr=LR, epochs=LOCAL_EPOCHS)
        local_weights.append(w); client_losses.append(loss)

    global_model.load_state_dict(average_weights(local_weights, fractions))

    # eval bar (line 2)
    mean_dice, metric_tc, metric_wt, metric_et = float(evaluate_model(global_model, test_dataset, device))
    
    metric_values.append(mean_dice)
    metric_values_tc.append(metric_tc)
    metric_values_wt.append(metric_wt)
    metric_values_et.append(metric_et)
    
    mean_loss = float(np.mean(client_losses))


    if mean_dice > best_metric:
        best_metric = mean_dice
        best_metric_epoch = epoch + 1
        best_metrics_epochs_and_time[0].append(best_metric)
        best_metrics_epochs_and_time[1].append(best_metric_epoch)
        best_metrics_epochs_and_time[2].append(time.time() - total_start)
        torch.save(
            model.state_dict(),
            os.path.join(root_dir, "best_metric_model.pth"),
        )
        print("saved new best metric model")
    
    tqdm.write(f"Round {rnd:02d}:  mean-loss = {mean_loss:.4f}   mean-Dice = {mean_dice:.4f}    TC-Dice = {metric_tc:.4f}    TC-Dice = {metric_wt:.4f}    TC-Dice = {metric_tc:.4f}")



Evaluating:   0%|          | 0/47 [00:00<?, ?it/s]

Dice before any training: 0.3333894908428192


Global rounds:   0%|                                                                             | 0/50 [00:00…

 clients:   0%|                                                                                   | 0/3 [00:00…

Evaluating:   0%|          | 0/47 [00:00<?, ?it/s]

Round 01:  mean-loss = 0.9034   mean-Dice = 0.3775


 clients:   0%|                                                                                   | 0/3 [00:00…

Evaluating:   0%|          | 0/47 [00:00<?, ?it/s]

Round 02:  mean-loss = 0.8968   mean-Dice = 0.4515


 clients:   0%|                                                                                   | 0/3 [00:00…

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fb6ed313400>
Traceback (most recent call last):
  File "/home/locolinux2/miniconda3/envs/m_quant_py310/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/home/locolinux2/miniconda3/envs/m_quant_py310/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
  File "/home/locolinux2/miniconda3/envs/m_quant_py310/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fb6ed313400>
Traceback (most recent call last):
  File "/home/locolinux2/miniconda3/envs/m_quant_py310/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
 

Evaluating:   0%|          | 0/47 [00:00<?, ?it/s]

Round 03:  mean-loss = 0.8932   mean-Dice = 0.4312


 clients:   0%|                                                                                   | 0/3 [00:00…

Evaluating:   0%|          | 0/47 [00:00<?, ?it/s]

Round 04:  mean-loss = 0.8910   mean-Dice = 0.5128


 clients:   0%|                                                                                   | 0/3 [00:00…

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fb6ed313400>
Traceback (most recent call last):
  File "/home/locolinux2/miniconda3/envs/m_quant_py310/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/home/locolinux2/miniconda3/envs/m_quant_py310/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
  File "/home/locolinux2/miniconda3/envs/m_quant_py310/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fb6ed313400>
Traceback (most recent call last):
  File "/home/locolinux2/miniconda3/envs/m_quant_py310/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
 

Evaluating:   0%|          | 0/47 [00:00<?, ?it/s]

Round 05:  mean-loss = 0.8807   mean-Dice = 0.4922


 clients:   0%|                                                                                   | 0/3 [00:00…

Evaluating:   0%|          | 0/47 [00:00<?, ?it/s]

Round 06:  mean-loss = 0.8760   mean-Dice = 0.5016


 clients:   0%|                                                                                   | 0/3 [00:00…

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fb6ed313400>
Traceback (most recent call last):
  File "/home/locolinux2/miniconda3/envs/m_quant_py310/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/home/locolinux2/miniconda3/envs/m_quant_py310/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
  File "/home/locolinux2/miniconda3/envs/m_quant_py310/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fb6ed313400>
Traceback (most recent call last):
  File "/home/locolinux2/miniconda3/envs/m_quant_py310/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
 

Evaluating:   0%|          | 0/47 [00:00<?, ?it/s]

Round 07:  mean-loss = 0.8719   mean-Dice = 0.5437


 clients:   0%|                                                                                   | 0/3 [00:00…

Evaluating:   0%|          | 0/47 [00:00<?, ?it/s]

Round 08:  mean-loss = 0.8654   mean-Dice = 0.5649


 clients:   0%|                                                                                   | 0/3 [00:00…

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fb6ed313400>
Traceback (most recent call last):
  File "/home/locolinux2/miniconda3/envs/m_quant_py310/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/home/locolinux2/miniconda3/envs/m_quant_py310/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
  File "/home/locolinux2/miniconda3/envs/m_quant_py310/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fb6ed313400>
Traceback (most recent call last):
  File "/home/locolinux2/miniconda3/envs/m_quant_py310/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
 

Evaluating:   0%|          | 0/47 [00:00<?, ?it/s]

Round 09:  mean-loss = 0.8586   mean-Dice = 0.5487


 clients:   0%|                                                                                   | 0/3 [00:00…

Evaluating:   0%|          | 0/47 [00:00<?, ?it/s]

Round 10:  mean-loss = 0.8531   mean-Dice = 0.5218


 clients:   0%|                                                                                   | 0/3 [00:00…

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fb6ed313400>
Traceback (most recent call last):
  File "/home/locolinux2/miniconda3/envs/m_quant_py310/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/home/locolinux2/miniconda3/envs/m_quant_py310/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
  File "/home/locolinux2/miniconda3/envs/m_quant_py310/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fb6ed313400>
Traceback (most recent call last):
  File "/home/locolinux2/miniconda3/envs/m_quant_py310/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
 

Evaluating:   0%|          | 0/47 [00:00<?, ?it/s]

Round 11:  mean-loss = 0.8499   mean-Dice = 0.6257


 clients:   0%|                                                                                   | 0/3 [00:00…

Evaluating:   0%|          | 0/47 [00:00<?, ?it/s]

Round 12:  mean-loss = 0.8371   mean-Dice = 0.6337


 clients:   0%|                                                                                   | 0/3 [00:00…

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fb6ed313400>
Traceback (most recent call last):
  File "/home/locolinux2/miniconda3/envs/m_quant_py310/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/home/locolinux2/miniconda3/envs/m_quant_py310/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
  File "/home/locolinux2/miniconda3/envs/m_quant_py310/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fb6ed313400>
Traceback (most recent call last):
  File "/home/locolinux2/miniconda3/envs/m_quant_py310/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
 

Evaluating:   0%|          | 0/47 [00:00<?, ?it/s]

Round 13:  mean-loss = 0.8288   mean-Dice = 0.6693


 clients:   0%|                                                                                   | 0/3 [00:00…

Evaluating:   0%|          | 0/47 [00:00<?, ?it/s]

Round 14:  mean-loss = 0.8226   mean-Dice = 0.6345


 clients:   0%|                                                                                   | 0/3 [00:00…

Evaluating:   0%|          | 0/47 [00:00<?, ?it/s]

Round 15:  mean-loss = 0.8154   mean-Dice = 0.6589


 clients:   0%|                                                                                   | 0/3 [00:00…

Evaluating:   0%|          | 0/47 [00:00<?, ?it/s]

Round 16:  mean-loss = 0.8061   mean-Dice = 0.7012


 clients:   0%|                                                                                   | 0/3 [00:00…

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fb6ed313400>
Traceback (most recent call last):
  File "/home/locolinux2/miniconda3/envs/m_quant_py310/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/home/locolinux2/miniconda3/envs/m_quant_py310/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
  File "/home/locolinux2/miniconda3/envs/m_quant_py310/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fb6ed313400>
Traceback (most recent call last):
  File "/home/locolinux2/miniconda3/envs/m_quant_py310/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
 

Evaluating:   0%|          | 0/47 [00:00<?, ?it/s]

Round 17:  mean-loss = 0.7940   mean-Dice = 0.6881


 clients:   0%|                                                                                   | 0/3 [00:00…

Evaluating:   0%|          | 0/47 [00:00<?, ?it/s]

Round 18:  mean-loss = 0.7855   mean-Dice = 0.7274


 clients:   0%|                                                                                   | 0/3 [00:00…

Evaluating:   0%|          | 0/47 [00:00<?, ?it/s]

Round 19:  mean-loss = 0.7732   mean-Dice = 0.6659


 clients:   0%|                                                                                   | 0/3 [00:00…

Evaluating:   0%|          | 0/47 [00:00<?, ?it/s]

Round 20:  mean-loss = 0.7670   mean-Dice = 0.6978


 clients:   0%|                                                                                   | 0/3 [00:00…

Evaluating:   0%|          | 0/47 [00:00<?, ?it/s]

Round 21:  mean-loss = 0.7560   mean-Dice = 0.6990


 clients:   0%|                                                                                   | 0/3 [00:00…

Evaluating:   0%|          | 0/47 [00:00<?, ?it/s]

Round 22:  mean-loss = 0.7438   mean-Dice = 0.7574


 clients:   0%|                                                                                   | 0/3 [00:00…

Evaluating:   0%|          | 0/47 [00:00<?, ?it/s]

Round 23:  mean-loss = 0.7346   mean-Dice = 0.7632


 clients:   0%|                                                                                   | 0/3 [00:00…

Evaluating:   0%|          | 0/47 [00:00<?, ?it/s]

Round 24:  mean-loss = 0.7143   mean-Dice = 0.7513


 clients:   0%|                                                                                   | 0/3 [00:00…

Evaluating:   0%|          | 0/47 [00:00<?, ?it/s]

Round 25:  mean-loss = 0.7031   mean-Dice = 0.7307


 clients:   0%|                                                                                   | 0/3 [00:00…

Evaluating:   0%|          | 0/47 [00:00<?, ?it/s]

Round 26:  mean-loss = 0.6931   mean-Dice = 0.7653


 clients:   0%|                                                                                   | 0/3 [00:00…

IOStream.flush timed out
IOStream.flush timed out
IOStream.flush timed out
IOStream.flush timed out


Evaluating:   0%|          | 0/47 [00:00<?, ?it/s]

Round 27:  mean-loss = 0.6779   mean-Dice = 0.7750


 clients:   0%|                                                                                   | 0/3 [00:00…

Evaluating:   0%|          | 0/47 [00:00<?, ?it/s]

Round 28:  mean-loss = 0.6591   mean-Dice = 0.7503


 clients:   0%|                                                                                   | 0/3 [00:00…

Evaluating:   0%|          | 0/47 [00:00<?, ?it/s]

Round 29:  mean-loss = 0.6536   mean-Dice = 0.7419


 clients:   0%|                                                                                   | 0/3 [00:00…

Evaluating:   0%|          | 0/47 [00:00<?, ?it/s]

Round 30:  mean-loss = 0.6399   mean-Dice = 0.7634


 clients:   0%|                                                                                   | 0/3 [00:00…

Evaluating:   0%|          | 0/47 [00:00<?, ?it/s]

Round 31:  mean-loss = 0.6186   mean-Dice = 0.7636


 clients:   0%|                                                                                   | 0/3 [00:00…

Evaluating:   0%|          | 0/47 [00:00<?, ?it/s]

Round 32:  mean-loss = 0.6098   mean-Dice = 0.7789


 clients:   0%|                                                                                   | 0/3 [00:00…

Evaluating:   0%|          | 0/47 [00:00<?, ?it/s]

Round 33:  mean-loss = 0.5843   mean-Dice = 0.7808


 clients:   0%|                                                                                   | 0/3 [00:00…

Evaluating:   0%|          | 0/47 [00:00<?, ?it/s]

Round 34:  mean-loss = 0.5660   mean-Dice = 0.7631


 clients:   0%|                                                                                   | 0/3 [00:00…

Evaluating:   0%|          | 0/47 [00:00<?, ?it/s]

Round 35:  mean-loss = 0.5581   mean-Dice = 0.7692


 clients:   0%|                                                                                   | 0/3 [00:00…

Evaluating:   0%|          | 0/47 [00:00<?, ?it/s]

Round 36:  mean-loss = 0.5615   mean-Dice = 0.7707


 clients:   0%|                                                                                   | 0/3 [00:00…

Evaluating:   0%|          | 0/47 [00:00<?, ?it/s]

Round 37:  mean-loss = 0.5326   mean-Dice = 0.7685


 clients:   0%|                                                                                   | 0/3 [00:00…

Evaluating:   0%|          | 0/47 [00:00<?, ?it/s]

Round 38:  mean-loss = 0.5110   mean-Dice = 0.7670


 clients:   0%|                                                                                   | 0/3 [00:00…

Evaluating:   0%|          | 0/47 [00:00<?, ?it/s]

Round 39:  mean-loss = 0.5016   mean-Dice = 0.7841


 clients:   0%|                                                                                   | 0/3 [00:00…

Evaluating:   0%|          | 0/47 [00:00<?, ?it/s]

Round 40:  mean-loss = 0.4734   mean-Dice = 0.7852


 clients:   0%|                                                                                   | 0/3 [00:00…

Evaluating:   0%|          | 0/47 [00:00<?, ?it/s]

Round 41:  mean-loss = 0.4653   mean-Dice = 0.7988


 clients:   0%|                                                                                   | 0/3 [00:00…

Evaluating:   0%|          | 0/47 [00:00<?, ?it/s]

Round 42:  mean-loss = 0.4404   mean-Dice = 0.7938


 clients:   0%|                                                                                   | 0/3 [00:00…

Evaluating:   0%|          | 0/47 [00:00<?, ?it/s]

Round 43:  mean-loss = 0.4235   mean-Dice = 0.7905


 clients:   0%|                                                                                   | 0/3 [00:00…

Evaluating:   0%|          | 0/47 [00:00<?, ?it/s]

Round 44:  mean-loss = 0.4177   mean-Dice = 0.7959


 clients:   0%|                                                                                   | 0/3 [00:00…

Evaluating:   0%|          | 0/47 [00:00<?, ?it/s]

Round 45:  mean-loss = 0.3944   mean-Dice = 0.8132


 clients:   0%|                                                                                   | 0/3 [00:00…

Evaluating:   0%|          | 0/47 [00:00<?, ?it/s]

Round 46:  mean-loss = 0.3839   mean-Dice = 0.7673


 clients:   0%|                                                                                   | 0/3 [00:00…

Evaluating:   0%|          | 0/47 [00:00<?, ?it/s]

Round 47:  mean-loss = 0.3966   mean-Dice = 0.7511


 clients:   0%|                                                                                   | 0/3 [00:00…

Evaluating:   0%|          | 0/47 [00:00<?, ?it/s]

Round 48:  mean-loss = 0.3874   mean-Dice = 0.8087


 clients:   0%|                                                                                   | 0/3 [00:00…

Evaluating:   0%|          | 0/47 [00:00<?, ?it/s]

Round 49:  mean-loss = 0.3344   mean-Dice = 0.8044


 clients:   0%|                                                                                   | 0/3 [00:00…

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fb6ed313400>
Traceback (most recent call last):
  File "/home/locolinux2/miniconda3/envs/m_quant_py310/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/home/locolinux2/miniconda3/envs/m_quant_py310/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
  File "/home/locolinux2/miniconda3/envs/m_quant_py310/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fb6ed313400>
Traceback (most recent call last):
  File "/home/locolinux2/miniconda3/envs/m_quant_py310/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
 

Evaluating:   0%|          | 0/47 [00:00<?, ?it/s]

Round 50:  mean-loss = 0.3340   mean-Dice = 0.8194


In [14]:
from monai.metrics import DiceMetric
from torch.utils.data import DataLoader
import torch, numpy as np

def evaluate_model_per_class(model, dataset, device, batch_size=1, num_workers=0):
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False,
                        num_workers=num_workers, pin_memory=False)

    # mean per channel across batch  → (C,)  with empties counted
    dice_percls = DiceMetric(include_background=False,
                             reduction="mean_channel",
                             ignore_empty=False)

    dice_percls.reset()
    model.eval()
    with torch.no_grad():
        for batch in loader:
            x = batch["image"].to(device)
            y = batch["label"].to(device)     # [B,3,...] = [TC, WT, ET]
            logits = inference(x)
            preds  = (torch.sigmoid(logits) > 0.5).float()
            dice_percls(y_pred=preds, y=y)

    percls_t = dice_percls.aggregate()        # torch.Size([3])
    percls   = percls_t.detach().cpu().tolist()
    dice_percls.reset()

    metric_tc, metric_wt, metric_et = percls[0], percls[1], percls[2]
    mean_dice = float(np.mean(percls))        # simple arithmetic mean over classes
    return mean_dice, metric_tc, metric_wt, metric_et


mean_dice, metric_tc, metric_wt, metric_et = evaluate_model_per_class(
    global_model, test_dataset, device, batch_size=1, num_workers=0
)
print(f"Mean Dice: {mean_dice:.4f}")
print(f"TC: {metric_tc:.4f}  WT: {metric_wt:.4f}  ET: {metric_et:.4f}")


Mean Dice: 0.8154
TC: 0.5636  WT: 0.4310  ET: 0.6576


In [13]:
def evaluate_model_per_class_manual(model, dataset, device, batch_size=1, num_workers=0, eps=1e-5):
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False,
                        num_workers=num_workers, pin_memory=False)

    tp = torch.zeros(3, device=device)
    p  = torch.zeros(3, device=device)
    q  = torch.zeros(3, device=device)

    model.eval()
    with torch.no_grad():
        for batch in loader:
            x = batch["image"].to(device)
            y = batch["label"].to(device)     # [B,3,...]
            logits = inference(x)
            pred   = (torch.sigmoid(logits) > 0.5).float()

            # flatten to [B, C, -1]
            B, C = pred.shape[:2]
            pr = pred.view(B, C, -1)
            gt = y.view(B, C, -1)

            tp += (pr * gt).sum(dim=(0,2))
            p  += pr.sum(dim=(0,2))
            q  += gt.sum(dim=(0,2))

    dice_c = (2*tp + eps) / (p + q + eps)     # [3]
    metric_tc, metric_wt, metric_et = [float(v) for v in dice_c.tolist()]
    mean_dice = float(dice_c.mean().item())
    return mean_dice, metric_tc, metric_wt, metric_et

mean_dice, metric_tc, metric_wt, metric_et = evaluate_model_per_class_manual(
    global_model, test_dataset, device, batch_size=1, num_workers=0
)
print(f"Mean Dice: {mean_dice:.4f}")
print(f"TC: {metric_tc:.4f}  WT: {metric_wt:.4f}  ET: {metric_et:.4f}")

Mean Dice: 0.8557
TC: 0.8639  WT: 0.8428  ET: 0.8603
