https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/brats_segmentation_3d.ipynb

In [4]:
import os
import shutil
import tempfile
import time
import matplotlib.pyplot as plt
from monai.apps import DecathlonDataset
from monai.config import print_config
from monai.data import DataLoader, decollate_batch
from monai.handlers.utils import from_engine
from monai.losses import DiceLoss
from monai.inferers import sliding_window_inference
from monai.metrics import DiceMetric
from monai.networks.nets import SegResNet
from monai.transforms import (
    Activations,
    Activationsd,
    AsDiscrete,
    AsDiscreted,
    Compose,
    Invertd,
    LoadImaged,
    MapTransform,
    NormalizeIntensityd,
    Orientationd,
    RandFlipd,
    RandScaleIntensityd,
    RandShiftIntensityd,
    RandSpatialCropd,
    Spacingd,
    EnsureTyped,
    EnsureChannelFirstd,
)
from monai.utils import set_determinism
import onnxruntime
from tqdm import tqdm

import torch

print_config()
set_determinism(seed=0)

MONAI version: 1.6.dev2542
Numpy version: 2.1.2
Pytorch version: 2.8.0+cu126
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: 612f3dd3cba4d73cfcea4b5329079e20aa31523d
MONAI __file__: /home/<username>/miniconda3/envs/m_quant_py310/lib/python3.10/site-packages/monai/__init__.py

Optional dependencies:
Pytorch Ignite version: NOT INSTALLED or UNKNOWN VERSION.
ITK version: 5.4.4
Nibabel version: 5.3.2
scikit-image version: 0.25.2
scipy version: 1.15.3
Pillow version: 11.0.0
Tensorboard version: NOT INSTALLED or UNKNOWN VERSION.
gdown version: NOT INSTALLED or UNKNOWN VERSION.
TorchVision version: 0.23.0+cu126
tqdm version: 4.67.1
lmdb version: NOT INSTALLED or UNKNOWN VERSION.
psutil version: 7.0.0
pandas version: 2.3.2
einops version: 0.8.1
transformers version: NOT INSTALLED or UNKNOWN VERSION.
mlflow version: NOT INSTALLED or UNKNOWN VERSION.
pynrrd version: NOT INSTALLED or UNKNOWN VERSION.
clearml version: NOT INSTALLED or UNKNOWN VERSION.

For d

In [9]:
directory = "/mnt/d/Datasets/tmp"
if directory is not None:
    os.makedirs(directory, exist_ok=True)
root_dir = tempfile.mkdtemp() if directory is None else directory
print(root_dir)

/mnt/d/Datasets/tmp


In [11]:
class ConvertToMultiChannelBasedOnBratsClassesd(MapTransform):
    """
    Convert labels to multi channels based on brats classes:
    label 1 is the peritumoral edema
    label 2 is the GD-enhancing tumor
    label 3 is the necrotic and non-enhancing tumor core
    The possible classes are TC (Tumor core), WT (Whole tumor)
    and ET (Enhancing tumor).

    """

    def __call__(self, data):
        d = dict(data)
        for key in self.keys:
            result = []
            # merge label 2 and label 3 to construct TC
            result.append(torch.logical_or(d[key] == 2, d[key] == 3))
            # merge labels 1, 2 and 3 to construct WT
            result.append(torch.logical_or(torch.logical_or(d[key] == 2, d[key] == 3), d[key] == 1))
            # label 2 is ET
            result.append(d[key] == 2)
            d[key] = torch.stack(result, axis=0).float()
        return d

train_transform = Compose(
    [
        # load 4 Nifti images and stack them together
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys="image"),
        EnsureTyped(keys=["image", "label"]),
        ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(
            keys=["image", "label"],
            pixdim=(1.0, 1.0, 1.0),
            mode=("bilinear", "nearest"),
        ),
        RandSpatialCropd(keys=["image", "label"], roi_size=[224, 224, 144], random_size=False),
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=1),
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=2),
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
        RandScaleIntensityd(keys="image", factors=0.1, prob=1.0),
        RandShiftIntensityd(keys="image", offsets=0.1, prob=1.0),
    ]
)
val_transform = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys="image"),
        EnsureTyped(keys=["image", "label"]),
        ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(
            keys=["image", "label"],
            pixdim=(1.0, 1.0, 1.0),
            mode=("bilinear", "nearest"),
        ),
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
    ]
)

# MINE!!

In [21]:
# MINE!!

import os, copy, time, random, torch, numpy as np                                 
import glob
from tqdm import tqdm
import pandas as pd
from torch.utils.data import DataLoader
from monai.data import CacheDataset
import glob, nibabel as nib, pandas as pd
from monai.data import CacheDataset, DataLoader
from monai.transforms import (
    LoadImaged, EnsureChannelFirstd, Orientationd, ScaleIntensityd,
    RandFlipd, RandSpatialCropd, Compose, SelectItemsd
)

from utils import *
  
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# -----------------------------------------------------------
# 0. paths & meta-data (unchanged) ---------------------------
# -----------------------------------------------------------
BRATS_DIR = "/mnt/d/Datasets/FETS_data/MICCAI_FeTS2022_TrainingData"
CSV_PATH  = f"{BRATS_DIR}/partitioning_1.csv"
MODALITIES = ["flair", "t1", "t1ce", "t2"]
LABEL_KEY  = "seg"

# -----------------------------------------------------------
# 1. read partition file  ➜  { id : [subjects] } ------------
# -----------------------------------------------------------
part_df       = pd.read_csv(CSV_PATH)
partition_map = (
    part_df.groupby("Partition_ID")["Subject_ID"]
           .apply(list).to_dict()
)                               # keys are 1 … 23

VAL_CENTRES = {21, 22, 23}          # ← our hold-out set
# VAL_CENTRES = {22, 23}          # ← our sanity set

# split once, reuse everywhere
train_partitions = {cid: sids for cid, sids in partition_map.items()
                    if cid not in VAL_CENTRES}
val_subjects     = sum((partition_map[cid] for cid in VAL_CENTRES), [])

# -----------------------------------------------------------
# 2. helper to build MONAI-style record dicts ----------------
# -----------------------------------------------------------

def build_records(subject_ids):
    recs = []
    for sid in subject_ids:
        sdir = f"{BRATS_DIR}/{sid}"
        images = [f"{sdir}/{sid}_{m}.nii.gz" for m in MODALITIES]  # 4 modalities
        recs.append({"image": images, "label": f"{sdir}/{sid}_{LABEL_KEY}.nii.gz"})
    return recs


# -----------------------------------------------------------
# 3. MONAI CacheDatasets ------------------------------------
# -----------------------------------------------------------
# ── client-wise training sets ───────────────────────────────
CUT_OFF, FRAC, SEED = 3, 0.1, 42
rng = random.Random(SEED)

train_datasets = {}
for cid, subj_ids in train_partitions.items():
    if cid > CUT_OFF:                                    # keep your cap
        break
    k = max(1, int(len(subj_ids) * FRAC))                # e.g. 30 %
    sample_ids = rng.sample(subj_ids, k)
    train_datasets[cid] = CacheDataset(
        build_records(sample_ids), transform=train_transform, cache_rate=1
    )

# ── single validation dataset made from *all* val subjects ─
test_dataset = CacheDataset(
    build_records(val_subjects), transform=val_transform, cache_rate=1
)

print("train per-centre sizes:", {k: len(v) for k, v in train_datasets.items()})
print("validation size:", len(test_dataset))


Loading dataset: 100%|██████████████████████████████████████████████████████████████████| 51/51 [01:19<00:00,  1.55s/it]
Loading dataset: 100%|████████████████████████████████████████████████████████████████████| 1/1 [00:01<00:00,  1.59s/it]
Loading dataset: 100%|████████████████████████████████████████████████████████████████████| 1/1 [00:01<00:00,  1.43s/it]
Loading dataset: 100%|██████████████████████████████████████████████████████████████████| 80/80 [03:12<00:00,  2.41s/it]

train per-centre sizes: {1: 51, 2: 1, 3: 1}
validation size: 80





In [22]:
max_epochs = 300
val_interval = 1
VAL_AMP = True

# standard PyTorch program style: create SegResNet, DiceLoss and Adam optimizer
device = torch.device("cuda:0")
global_model = SegResNet(
    blocks_down=[1, 2, 2, 4],
    blocks_up=[1, 1, 1],
    init_filters=16,
    in_channels=4,
    out_channels=3,
    dropout_prob=0.2,
).to(device)
loss_function = DiceLoss(smooth_nr=0, smooth_dr=1e-5, squared_pred=True, to_onehot_y=False, sigmoid=True)
optimizer = torch.optim.Adam(global_model.parameters(), 1e-4, weight_decay=1e-5)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_epochs)

dice_metric = DiceMetric(include_background=True, reduction="mean")
dice_metric_batch = DiceMetric(include_background=True, reduction="mean_batch")

post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])


# define inference method
def inference(input):
    def _compute(input):
        return sliding_window_inference(
            inputs=input,
            roi_size=(240, 240, 160),
            sw_batch_size=1,
            predictor=global_model,
            overlap=0.5,
        )

    if VAL_AMP:
        with torch.autocast("cuda"):
            return _compute(input)
    else:
        return _compute(input)


# use amp to accelerate training
scaler = torch.GradScaler("cuda")
# enable cuDNN benchmark
torch.backends.cudnn.benchmark = True

In [27]:
def evaluate_model(model, dataset, device, batch_size=1):
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
    dice_metric.reset()
    model.eval()
    with torch.no_grad():
        for batch in tqdm(loader, desc="Evaluating"):
            inputs = batch["image"].to(device)
            labels = batch["label"].to(device)          # already multi-channel (WT/TC/ET) floats

            # sliding-window inference (your helper)
            logits = inference(inputs)                  # [B, 3, D, H, W]

            # tensor-wise post-processing (no decollate, no meta juggling)
            preds = torch.sigmoid(logits)
            preds = (preds > 0.5).float()

            # accumulate Dice on tensors directly
            dice_metric(y_pred=preds, y=labels)

    mean_dice = dice_metric.aggregate().item()
    dice_metric.reset()
    return mean_dice

print("Dice before any training:", evaluate_model(global_model, test_dataset, device))


Evaluating: 100%|███████████████████████████████████████████████████████████████████████| 80/80 [00:38<00:00,  2.06it/s]


Dice before any training: 0.02203102968633175


In [30]:
from tqdm.auto import tqdm, trange   # trange == tqdm(range())
from collections import OrderedDict

def average_weights(state_dicts, fractions):
    """
    Federated averaging with client fractions (must sum to 1).
    state_dicts: list of state_dicts (same keys)
    fractions:   list of floats, same length, sum≈1
    """
    avg_sd = OrderedDict()
    for k in state_dicts[0].keys():
        avg = 0.0
        for sd, w in zip(state_dicts, fractions):
            avg += sd[k] * w
        avg_sd[k] = avg
    return avg_sd

# ────────────────────────────────────────────────────────────
# 1. one-client update (returns weights + mean loss)          │
# ────────────────────────────────────────────────────────────
def local_train(model, loader, device, lr=1e-4, epochs=1):
    """
    Train a local copy of the global model on one client's DataLoader.
    Uses your DiceLoss (multi-label, sigmoid) and full crops from transforms.
    """
    model = copy.deepcopy(model).to(device)
    model.train()

    # reuse your loss choice; or inline DiceLoss the same way
    crit = DiceLoss(smooth_nr=0, smooth_dr=1e-5, squared_pred=True,
                    to_onehot_y=False, sigmoid=True).to(device)

    opt = torch.optim.Adam(model.parameters(), lr=lr)
    epoch_losses = []

    for _ in range(epochs):
        running = 0.0
        for batch in loader:
            img = batch["image"].to(device)   # [B, 4, D, H, W]
            msk = batch["label"].to(device)   # [B, 3, D, H, W]

            opt.zero_grad(set_to_none=True)
            logits = model(img)               # [B, 3, D, H, W]
            loss = crit(logits, msk)
            loss.backward()
            opt.step()

            running += float(loss.item())
        epoch_losses.append(running / max(1, len(loader)))

    return model.state_dict(), float(np.mean(epoch_losses))


# ────────────────────────────────────────────────────────────
# 2. FedAvg training loop (simple tqdm + clean prints)        │
# ────────────────────────────────────────────────────────────
from tqdm.auto import tqdm, trange

EPOCHS, LOCAL_EPOCHS, LR, BATCH = 50, 1, 1e-4, 1

idxs_users = list(train_datasets.keys())
sizes      = {k: len(ds) for k, ds in train_datasets.items()}
total_n    = sum(sizes.values())
fractions  = [sizes[k] / total_n for k in idxs_users]

print("Dice before any training:", evaluate_model(global_model, test_dataset, device))

for rnd in trange(1, EPOCHS + 1, desc="Global rounds"):
    local_weights, client_losses = [], []

    for cid in tqdm(idxs_users, desc=" clients", leave=False):
        loader = DataLoader(
            train_datasets[cid],
            batch_size=BATCH,
            shuffle=True,
            num_workers=4,
            pin_memory=True
        )
        w, loss = local_train(global_model, loader, device, lr=LR, epochs=LOCAL_EPOCHS)
        local_weights.append(w)
        client_losses.append(loss)

    # FedAvg aggregation
    global_model.load_state_dict(average_weights(local_weights, fractions))

    mean_loss = float(np.mean(client_losses))
    mean_dice = float(evaluate_model(global_model, test_dataset, device))
    print(f"Round {rnd:02d}:  mean-loss = {mean_loss:.4f}   mean-Dice = {mean_dice:.4f}")



Evaluating:   0%|          | 0/80 [00:00<?, ?it/s]

Dice before any training: 0.02203102968633175


Global rounds:   0%|          | 0/50 [00:00<?, ?it/s]

 clients:   0%|          | 0/3 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/80 [00:00<?, ?it/s]

Round 01:  mean-loss = 0.9634   mean-Dice = 0.3505


 clients:   0%|          | 0/3 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/80 [00:00<?, ?it/s]

Round 02:  mean-loss = 0.9272   mean-Dice = 0.3788


 clients:   0%|          | 0/3 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/80 [00:00<?, ?it/s]

Round 03:  mean-loss = 0.9216   mean-Dice = 0.4779


 clients:   0%|          | 0/3 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/80 [00:00<?, ?it/s]

Round 04:  mean-loss = 0.9172   mean-Dice = 0.5140


 clients:   0%|          | 0/3 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/80 [00:00<?, ?it/s]

Round 05:  mean-loss = 0.9131   mean-Dice = 0.4599


 clients:   0%|          | 0/3 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/80 [00:00<?, ?it/s]

Round 06:  mean-loss = 0.9102   mean-Dice = 0.5396


 clients:   0%|          | 0/3 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/80 [00:00<?, ?it/s]

Round 07:  mean-loss = 0.9060   mean-Dice = 0.4988


 clients:   0%|          | 0/3 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/80 [00:00<?, ?it/s]

Round 08:  mean-loss = 0.9027   mean-Dice = 0.5111


 clients:   0%|          | 0/3 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/80 [00:00<?, ?it/s]

Round 09:  mean-loss = 0.8987   mean-Dice = 0.5040


 clients:   0%|          | 0/3 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/80 [00:00<?, ?it/s]

Round 10:  mean-loss = 0.8933   mean-Dice = 0.5754


 clients:   0%|          | 0/3 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/80 [00:00<?, ?it/s]

Round 11:  mean-loss = 0.8874   mean-Dice = 0.5816


 clients:   0%|          | 0/3 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/80 [00:00<?, ?it/s]

Round 12:  mean-loss = 0.8820   mean-Dice = 0.6345


 clients:   0%|          | 0/3 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/80 [00:00<?, ?it/s]

Round 13:  mean-loss = 0.8761   mean-Dice = 0.6461


 clients:   0%|          | 0/3 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/80 [00:00<?, ?it/s]

Round 14:  mean-loss = 0.8704   mean-Dice = 0.6531


 clients:   0%|          | 0/3 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/80 [00:00<?, ?it/s]

Round 15:  mean-loss = 0.8649   mean-Dice = 0.5641


 clients:   0%|          | 0/3 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/80 [00:00<?, ?it/s]

Round 16:  mean-loss = 0.8606   mean-Dice = 0.6270


 clients:   0%|          | 0/3 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/80 [00:00<?, ?it/s]

Round 17:  mean-loss = 0.8542   mean-Dice = 0.6362


 clients:   0%|          | 0/3 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/80 [00:00<?, ?it/s]

Round 18:  mean-loss = 0.8449   mean-Dice = 0.6595


 clients:   0%|          | 0/3 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f86f43170a0>
Traceback (most recent call last):
  File "/home/locolinux2/miniconda3/envs/m_quant_py310/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/home/locolinux2/miniconda3/envs/m_quant_py310/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
  File "/home/locolinux2/miniconda3/envs/m_quant_py310/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f86f43170a0>
Traceback (most recent call last):
  File "/home/locolinux2/miniconda3/envs/m_quant_py310/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
 

Evaluating:   0%|          | 0/80 [00:00<?, ?it/s]

Round 19:  mean-loss = 0.8384   mean-Dice = 0.6942


 clients:   0%|          | 0/3 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/80 [00:00<?, ?it/s]

Round 20:  mean-loss = 0.8299   mean-Dice = 0.7099


 clients:   0%|          | 0/3 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f86f43170a0>
Traceback (most recent call last):
  File "/home/locolinux2/miniconda3/envs/m_quant_py310/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/home/locolinux2/miniconda3/envs/m_quant_py310/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
  File "/home/locolinux2/miniconda3/envs/m_quant_py310/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f86f43170a0>
Traceback (most recent call last):
  File "/home/locolinux2/miniconda3/envs/m_quant_py310/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
 

Evaluating:   0%|          | 0/80 [00:00<?, ?it/s]

Round 21:  mean-loss = 0.8221   mean-Dice = 0.6646


 clients:   0%|          | 0/3 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/80 [00:00<?, ?it/s]

Round 22:  mean-loss = 0.8133   mean-Dice = 0.6698


 clients:   0%|          | 0/3 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f86f43170a0>
Traceback (most recent call last):
  File "/home/locolinux2/miniconda3/envs/m_quant_py310/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/home/locolinux2/miniconda3/envs/m_quant_py310/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
  File "/home/locolinux2/miniconda3/envs/m_quant_py310/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f86f43170a0>
Traceback (most recent call last):
  File "/home/locolinux2/miniconda3/envs/m_quant_py310/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
 

Evaluating:   0%|          | 0/80 [00:00<?, ?it/s]

Round 23:  mean-loss = 0.8063   mean-Dice = 0.6718


 clients:   0%|          | 0/3 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/80 [00:00<?, ?it/s]

Round 24:  mean-loss = 0.7949   mean-Dice = 0.6567


 clients:   0%|          | 0/3 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f86f43170a0>
Traceback (most recent call last):
  File "/home/locolinux2/miniconda3/envs/m_quant_py310/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
  File "/home/locolinux2/miniconda3/envs/m_quant_py310/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1647, in _shutdown_workers
    if w.is_alive():
  File "/home/locolinux2/miniconda3/envs/m_quant_py310/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f86f43170a0>
Traceback (most recent call last):
  File "/home/locolinux2/miniconda3/envs/m_quant_py310/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1664, in __del__
    self._shutdown_workers()
 

Evaluating:   0%|          | 0/80 [00:00<?, ?it/s]

Round 25:  mean-loss = 0.7885   mean-Dice = 0.7232


 clients:   0%|          | 0/3 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/80 [00:00<?, ?it/s]

Round 26:  mean-loss = 0.7737   mean-Dice = 0.7024


 clients:   0%|          | 0/3 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/80 [00:00<?, ?it/s]

Round 27:  mean-loss = 0.7653   mean-Dice = 0.7398


 clients:   0%|          | 0/3 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/80 [00:00<?, ?it/s]

Round 28:  mean-loss = 0.7520   mean-Dice = 0.7463


 clients:   0%|          | 0/3 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/80 [00:00<?, ?it/s]

Round 29:  mean-loss = 0.7405   mean-Dice = 0.7442


 clients:   0%|          | 0/3 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/80 [00:00<?, ?it/s]

Round 30:  mean-loss = 0.7263   mean-Dice = 0.6844


 clients:   0%|          | 0/3 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/80 [00:00<?, ?it/s]

Round 31:  mean-loss = 0.7226   mean-Dice = 0.7661


 clients:   0%|          | 0/3 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/80 [00:00<?, ?it/s]

Round 32:  mean-loss = 0.7022   mean-Dice = 0.7551


 clients:   0%|          | 0/3 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/80 [00:00<?, ?it/s]

Round 33:  mean-loss = 0.6848   mean-Dice = 0.7496


 clients:   0%|          | 0/3 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/80 [00:00<?, ?it/s]

Round 34:  mean-loss = 0.6746   mean-Dice = 0.7385


 clients:   0%|          | 0/3 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/80 [00:00<?, ?it/s]

Round 35:  mean-loss = 0.6617   mean-Dice = 0.7824


 clients:   0%|          | 0/3 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/80 [00:00<?, ?it/s]

Round 36:  mean-loss = 0.6449   mean-Dice = 0.7524


 clients:   0%|          | 0/3 [00:00<?, ?it/s]

IOStream.flush timed out
IOStream.flush timed out
IOStream.flush timed out
IOStream.flush timed out


Evaluating:   0%|          | 0/80 [00:00<?, ?it/s]

Round 37:  mean-loss = 0.6332   mean-Dice = 0.7225


 clients:   0%|          | 0/3 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/80 [00:00<?, ?it/s]

Round 38:  mean-loss = 0.6197   mean-Dice = 0.7937


 clients:   0%|          | 0/3 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/80 [00:00<?, ?it/s]

Round 39:  mean-loss = 0.5910   mean-Dice = 0.7935


 clients:   0%|          | 0/3 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/80 [00:00<?, ?it/s]

Round 40:  mean-loss = 0.5726   mean-Dice = 0.7829


 clients:   0%|          | 0/3 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/80 [00:00<?, ?it/s]

Round 41:  mean-loss = 0.5639   mean-Dice = 0.7851


 clients:   0%|          | 0/3 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/80 [00:00<?, ?it/s]

Round 42:  mean-loss = 0.5466   mean-Dice = 0.7931


 clients:   0%|          | 0/3 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/80 [00:00<?, ?it/s]

Round 43:  mean-loss = 0.5356   mean-Dice = 0.8000


 clients:   0%|          | 0/3 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/80 [00:00<?, ?it/s]

Round 44:  mean-loss = 0.5103   mean-Dice = 0.7888


 clients:   0%|          | 0/3 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/80 [00:00<?, ?it/s]

Round 45:  mean-loss = 0.4990   mean-Dice = 0.7706


 clients:   0%|          | 0/3 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/80 [00:00<?, ?it/s]

Round 46:  mean-loss = 0.4990   mean-Dice = 0.7882


 clients:   0%|          | 0/3 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/80 [00:00<?, ?it/s]

Round 47:  mean-loss = 0.4674   mean-Dice = 0.8039


 clients:   0%|          | 0/3 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/80 [00:00<?, ?it/s]

Round 48:  mean-loss = 0.4414   mean-Dice = 0.7950


 clients:   0%|          | 0/3 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/80 [00:00<?, ?it/s]

Round 49:  mean-loss = 0.4243   mean-Dice = 0.7762


 clients:   0%|          | 0/3 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/80 [00:00<?, ?it/s]

Round 50:  mean-loss = 0.4072   mean-Dice = 0.8286
