In [1]:
!pip install scikit-learn



In [2]:
!pip install --force-reinstall --user einops==0.3.2

Collecting einops==0.3.2
  Downloading einops-0.3.2-py3-none-any.whl.metadata (10 kB)
Downloading einops-0.3.2-py3-none-any.whl (25 kB)
Installing collected packages: einops
Successfully installed einops-0.3.2


In [3]:
!pip install torch torchvision nibabel

Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch)
  Downloading nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cusolver-cu12==11.6.1.9 (from torch)
  Downloading nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cusparse-cu12==12.3.1.170 (from torch)
  Downloading nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-nvjitlink-cu12==12.4.127 (from torch)
  Downloading nvidia_nvjitlink_cu12-12.4.127-py3-n

In [4]:
!pip install monai

Collecting monai
  Downloading monai-1.4.0-py3-none-any.whl.metadata (11 kB)
Downloading monai-1.4.0-py3-none-any.whl (1.5 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.5/1.5 MB[0m [31m20.8 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hInstalling collected packages: monai
Successfully installed monai-1.4.0


In [5]:
!pip install thop

Collecting thop
  Downloading thop-0.1.1.post2209072238-py3-none-any.whl.metadata (2.7 kB)
Downloading thop-0.1.1.post2209072238-py3-none-any.whl (15 kB)
Installing collected packages: thop
Successfully installed thop-0.1.1.post2209072238


# **Another try of survey**

In [7]:
#!/usr/bin/env python3
# train_hippocampus.py

import os
import glob
import time
import copy
import random

import numpy as np
import pandas as pd
import nibabel as nib
import torch
from torch.utils.data import Dataset, DataLoader

# use threads instead of processes so we don't hit pickling issues
from multiprocessing.dummy import Pool

from monai.networks.nets import (
    UNet, UNETR, SwinUNETR, AttentionUnet,
    HighResNet, VNet, DynUNet, RegUNet, SegResNet
)
from monai.losses import DiceCELoss
from monai.inferers import sliding_window_inference

from torch.cuda.amp import autocast, GradScaler
from scipy.ndimage import zoom
from scipy.spatial.distance import directed_hausdorff
from sklearn.metrics import confusion_matrix
from thop import profile
from torchmetrics import Precision, Recall, F1Score, Specificity

import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.colors import ListedColormap
from tensorflow.keras import callbacks

# —————————————————————————————————————————————
# 1) Config & constants

DATA_ROOT   = "/kaggle/input/task04-hippocampus/Task04_Hippocampus"
IMAGES_DIR  = os.path.join(DATA_ROOT, "imagesTr")
LABELS_DIR  = os.path.join(DATA_ROOT, "labelsTr")

IN_CHANNELS = 1           # one MRI volume per case
NUM_CLASSES = 3           # 0=background,1=Anterior,2=Posterior
LABEL_NAMES = {
    0: "background",
    1: "Anterior",
    2: "Posterior",
}

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SEED   = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)


# —————————————————————————————————————————————
# 2) Data loading

def collect_pairs(img_dir, lbl_dir):
    imgs = sorted(glob.glob(os.path.join(img_dir, "hippocampus_*.nii*")))
    lbls = sorted(glob.glob(os.path.join(lbl_dir, "hippocampus_*.nii*")))
    img_dict = {os.path.splitext(os.path.basename(p))[0]: p for p in imgs}
    lbl_dict = {os.path.splitext(os.path.basename(p))[0]: p for p in lbls}
    keys = sorted(set(img_dict) & set(lbl_dict))
    return [(img_dict[k], lbl_dict[k]) for k in keys]

def load_case(pair):
    img_path, lbl_path = pair
    img = nib.load(img_path).get_fdata().astype(np.float32)[..., None]
    seg = nib.load(lbl_path).get_fdata().astype(np.float32)
    mn, mx = img.min(), img.max()
    img = (img - mn) / (mx - mn + 1e-8)
    factors = (64/img.shape[0], 64/img.shape[1], 64/img.shape[2], 1)
    img_rs = zoom(img, factors, order=1)
    seg_rs = zoom(seg, factors[:3], order=0)
    return img_rs, seg_rs.astype(np.float32), 1

def parallel_load(pairs, max_cases=None):
    to_load = pairs if max_cases is None else pairs[:max_cases]
    with Pool() as pool:
        results = pool.map(load_case, to_load)
    imgs, segs, flags = zip(*results)
    imgs = [i for i,f in zip(imgs, flags) if f]
    segs = [s for s,f in zip(segs, flags) if f]
    return imgs, segs

class HippocampusDataset(Dataset):
    def __init__(self, images, labels):
        self.images = images
        self.labels = labels
    def __len__(self):
        return len(self.images)
    def __getitem__(self, idx):
        img = torch.from_numpy(self.images[idx]).permute(3,0,1,2)
        seg = torch.from_numpy(self.labels[idx]).long()
        return {"image": img, "label": seg}


# —————————————————————————————————————————————
# 3) Visualization callback

class EnhancedVisCallback(callbacks.Callback):
    def __init__(self, val_loader, num_samples=1):
        super().__init__()
        self.val_loader  = val_loader
        self.num_samples = num_samples
        # black=bg, red=Anterior, blue=Posterior
        self.cmap = ListedColormap(["black","red","blue"])
        self._model = None

    def attach_model(self, model):
        self._model = model

    def on_epoch_end(self, epoch, logs=None):
        if self._model is None: return
        self._model.eval()
        with torch.no_grad():
            batch = next(iter(self.val_loader))
            imgs = batch["image"].to(DEVICE)
            lbls = batch["label"].cpu().numpy()
            out  = sliding_window_inference(imgs, (64,64,64), 4, self._model)
            preds= torch.argmax(out, dim=1).cpu().numpy()
            i, mid = 0, imgs.shape[-1]//2
            im = imgs.cpu().numpy()[i,0,:,:,mid]
            gt = lbls[i,:,:,mid]
            pr = preds[i,:,:,mid]
            fig, ax = plt.subplots(1,3,figsize=(12,4))
            ax[0].imshow(im, cmap="gray");     ax[0].axis("off"); ax[0].set_title("Input")
            ax[1].imshow(im, cmap="gray"); ax[1].imshow(gt, cmap=self.cmap, alpha=0.5); ax[1].axis("off"); ax[1].set_title("GT")
            ax[2].imshow(im, cmap="gray"); ax[2].imshow(pr, cmap=self.cmap, alpha=0.5); ax[2].axis("off"); ax[2].set_title("Pred")
            plt.tight_layout(); plt.show()


# —————————————————————————————————————————————
# 4) Metrics & visualizer

class MetricsTracker:
    def __init__(self):
        self.history = {
            "train_loss":    [],
            "val_dice":      [],
            "val_hausdorff": [],
            "per_class_dice": {f"class_{i}": [] for i in range(1,NUM_CLASSES)},
            "precision":     [],
            "recall":        [],
            "f1_score":      [],
            "specificity":   [],
            "learning_rate": [],
            "epoch_time":    []
        }
        self.precision   = Precision(task="multiclass", num_classes=NUM_CLASSES, average="macro").to(DEVICE)
        self.recall      = Recall   (task="multiclass", num_classes=NUM_CLASSES, average="macro").to(DEVICE)
        self.f1          = F1Score  (task="multiclass", num_classes=NUM_CLASSES, average="macro").to(DEVICE)
        self.specificity = Specificity(task="multiclass", num_classes=NUM_CLASSES, average="macro").to(DEVICE)

    def update(self, d):
        for k,v in d.items():
            if k in self.history:
                self.history[k].append(v)

class TrainingVisualizer:
    def __init__(self, save_dir="visualization_results"):
        self.save_dir = save_dir
        os.makedirs(save_dir, exist_ok=True)
        sns.set_theme()

    def plot_comparison_metrics(self, mh, metrics, fname="metrics.png"):
        fig, ax = plt.subplots(len(metrics),1,figsize=(8,4*len(metrics)))
        if len(metrics)==1: ax=[ax]
        for i,m in enumerate(metrics):
            for name,h in mh.items():
                ax[i].plot(h[m], label=name)
            ax[i].set_title(m.replace("_"," ").title()); ax[i].legend(); ax[i].grid(True)
        plt.tight_layout(); plt.savefig(os.path.join(self.save_dir,fname)); plt.close()

    def plot_class_performance(self, mh, fname="per_class.png"):
        fig, ax = plt.subplots(len(mh),1,figsize=(8,4*len(mh)))
        if len(mh)==1: ax=[ax]
        for i,(name,h) in enumerate(mh.items()):
            for cls, vals in h["per_class_dice"].items():
                ax[i].plot(vals, label=cls)
            ax[i].set_title(f"{name} per-class Dice"); ax[i].legend(); ax[i].grid(True)
        plt.tight_layout(); plt.savefig(os.path.join(self.save_dir,fname)); plt.close()

    def create_training_summary(self, mh, fname="summary.png"):
        fig = plt.figure(figsize=(12,10))
        gs  = fig.add_gridspec(3,2)
        ax1 = fig.add_subplot(gs[0,:])
        for m,h in mh.items(): ax1.plot(h["train_loss"], label=m)
        ax1.set_title("Train Loss"); ax1.legend(); ax1.grid(True)
        ax2 = fig.add_subplot(gs[1,0])
        for m,h in mh.items(): ax2.plot(h["val_dice"], label=m)
        ax2.set_title("Val Dice"); ax2.legend(); ax2.grid(True)
        ax3 = fig.add_subplot(gs[1,1])
        for m,h in mh.items(): ax3.plot(h["val_hausdorff"], label=m)
        ax3.set_title("Val Hausdorff"); ax3.legend(); ax3.grid(True)
        ax4 = fig.add_subplot(gs[2,0])
        mods = list(mh.keys()); x=np.arange(len(mods)); w=0.2
        for i,met in enumerate(["precision","recall","f1_score"]):
            vals=[np.mean(mh[m][met]) for m in mods]
            ax4.bar(x+i*w, vals, w, label=met)
        ax4.set_xticks(x+w); ax4.set_xticklabels(mods); ax4.set_title("Macro metrics"); ax4.legend(); ax4.grid(True)
        ax5 = fig.add_subplot(gs[2,1])
        for m,h in mh.items(): ax5.plot(h["learning_rate"], label=m)
        ax5.set_title("LR"); ax5.legend(); ax5.grid(True)
        plt.tight_layout(); plt.savefig(os.path.join(self.save_dir,fname)); plt.close()

    def plot_confusion_matrices(self, cms, fname="confusion.png"):
        valid = {n:cm for n,cm in cms.items() if isinstance(cm,np.ndarray)}
        if not valid: return
        fig, axs = plt.subplots(1,len(valid),figsize=(5*len(valid),4))
        if len(valid)==1: axs=[axs]
        for ax,(n,cm) in zip(axs, valid.items()):
            sns.heatmap(cm, annot=True, fmt="d", ax=ax,
                        xticklabels=[LABEL_NAMES[i] for i in range(NUM_CLASSES)],
                        yticklabels=[LABEL_NAMES[i] for i in range(NUM_CLASSES)])
            ax.set_title(f"{n} Confusion"); ax.set_xlabel("Pred"); ax.set_ylabel("True")
        plt.tight_layout(); plt.savefig(os.path.join(self.save_dir,fname)); plt.close()

    def create_performance_report(self, mh, save_csv="performance_report.csv"):
        rows=[]
        for m,h in mh.items():
            rows.append({
                "Model": m,
                "Best Dice": max(h["val_dice"]),
                "Final Dice": h["val_dice"][-1],
                "Precision": np.mean(h["precision"]),
                "Recall":    np.mean(h["recall"]),
                "F1 Score":  np.mean(h["f1_score"]),
                "Best Epoch": int(np.argmax(h["val_dice"])+1),
                "Final Loss": h["train_loss"][-1],
                "Final Hausdorff": h["val_hausdorff"][-1],
            })
        df = pd.DataFrame(rows)
        df.to_csv(os.path.join(self.save_dir, save_csv), index=False)
        return df


# —————————————————————————————————————————————
# 5) Utility metrics

def per_class_dice(pred, tgt):
    out = {}
    for cls in range(1, NUM_CLASSES):
        p = (pred==cls).astype(np.float32)
        t = (tgt==cls).astype(np.float32)
        i = (p*t).sum(); u = p.sum()+t.sum()
        out[f"class_{cls}"] = 2*i/(u+1e-8) if u>0 else 1.0
    return out

def hausdorff(pred, tgt):
    ds=[]
    for cls in range(1, NUM_CLASSES):
        pp, tt = np.argwhere(pred==cls), np.argwhere(tgt==cls)
        if len(pp) and len(tt):
            d1 = directed_hausdorff(pp, tt)[0]
            d2 = directed_hausdorff(tt, pp)[0]
            ds.append(max(d1,d2))
    return ds

def compute_model_metrics(model, inp_shape):
    m = copy.deepcopy(model).cpu().eval()
    x = torch.randn(*inp_shape).cpu()
    flops, params = profile(m, inputs=(x,), verbose=False)
    return flops, params


# —————————————————————————————————————————————
# 6) Training & evaluation

def train_model(model, train_loader, val_loader, name,
                max_epochs=50, patience=10, cbs=None):
    model.to(DEVICE)
    loss_fn = DiceCELoss(to_onehot_y=True, softmax=True)
    opt     = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)
    scaler  = torch.amp.GradScaler()
    best_d, no_imp = -1, 0
    tracker = MetricsTracker()
    t0 = time.time()

    if cbs:
        for cb in cbs:
            if hasattr(cb, "attach_model"):
                cb.attach_model(model)

    for ep in range(1, max_epochs+1):
        print(f"Epoch {ep}/{max_epochs}")
        t_ep = time.time()
        model.train()
        running_loss=0
        for b in train_loader:
            imgs = b["image"].to(DEVICE)
            lbls = b["label"].to(DEVICE).unsqueeze(1)
            opt.zero_grad()
            with torch.amp.autocast(device_type='cuda'):
                out  = model(imgs)
                loss = loss_fn(out, lbls)
            scaler.scale(loss).backward()
            scaler.step(opt)
            scaler.update()
            running_loss += loss.item()
        avg_loss = running_loss/len(train_loader)

        model.eval()
        dices, hauss = [], []
        all_p, all_t = [], []
        with torch.no_grad():
            for i,b in enumerate(val_loader):
                imgs = b["image"].to(DEVICE)
                lbls = b["label"].to(DEVICE).flatten()
                out  = sliding_window_inference(imgs,(64,64,64),4,model)
                pr   = torch.argmax(out,dim=1).flatten()
                all_p.append(pr); all_t.append(lbls)

                pr_np = pr.cpu().numpy().reshape(b["label"].shape[1:])
                lb_np = b["label"].cpu().numpy()[0]
                dices.append(np.mean(list(per_class_dice(pr_np,lb_np).values())))
                hauss.append(np.mean(hausdorff(pr_np, lb_np)))

                if i==0:
                    pc = per_class_dice(pr_np, lb_np)
                    for k,v in pc.items():
                        tracker.history["per_class_dice"][k].append(v)

        preds = torch.cat(all_p).to(DEVICE)
        trues = torch.cat(all_t).to(DEVICE)
        prec = tracker.precision(preds,trues).item()
        rec  = tracker.recall   (preds,trues).item()
        f1   = tracker.f1       (preds,trues).item()
        spec = tracker.specificity(preds,trues).item()

        m_d = np.mean(dices)
        m_h = np.mean(hauss)
        tracker.update({
            "train_loss":    avg_loss,
            "val_dice":      m_d,
            "val_hausdorff": m_h,
            "precision":     prec,
            "recall":        rec,
            "f1_score":      f1,
            "specificity":   spec,
            "learning_rate": opt.param_groups[0]["lr"]
        })
        tracker.history["epoch_time"].append(time.time()-t_ep)

        if cbs:
            for cb in cbs:
                cb.on_epoch_end(ep, {"val_dice":m_d,"val_hausdorff":m_h})

        if m_d > best_d:
            best_d = m_d
            torch.save({"model_state_dict":model.state_dict()},
                       f"best_model_{name}.pth")
            no_imp = 0
        else:
            no_imp += 1
            if no_imp >= patience:
                print("Early stopping.")
                break

    print(f"{name} done in {time.time()-t0:.1f}s, best val_dice={best_d:.4f}")
    return best_d, time.time()-t0, tracker.history


def evaluate_best_model(model, name, val_loader, save_dir="visualization_results"):
    ckpt = torch.load(f"best_model_{name}.pth", map_location=DEVICE,weights_only=True)
    model.load_state_dict(ckpt["model_state_dict"])
    model.to(DEVICE).eval()

    all_p, all_t = [], []
    pc_metrics = {f"class_{i}":{"dice":[],"precision":[],"recall":[],"iou":[],"hd95":[]} 
                  for i in range(1,NUM_CLASSES)}

    with torch.no_grad():
        for b in val_loader:
            imgs = b["image"].to(DEVICE)
            lbl  = b["label"].cpu().numpy()[0]
            out  = sliding_window_inference(imgs,(64,64,64),4,model)
            pr   = torch.argmax(out,dim=1).cpu().numpy()[0]
            all_p.append(pr.flatten()); all_t.append(lbl.flatten())
            for cls in range(1,NUM_CLASSES):
                p = (pr==cls).astype(np.float32)
                t = (lbl==cls).astype(np.float32)
                tp = (p*t).sum(); fp=(p*(1-t)).sum(); fn=((1-p)*t).sum()
                dice = 2*tp/(2*tp+fp+fn+1e-8)
                prec = tp/(tp+fp+1e-8); rec=tp/(tp+fn+1e-8)
                iou  = tp/(tp+fp+fn+1e-8)
                if p.sum() and t.sum():
                    d1 = directed_hausdorff(np.argwhere(p), np.argwhere(t))[0]
                    d2 = directed_hausdorff(np.argwhere(t), np.argwhere(p))[0]
                    hd  = max(d1,d2)
                else:
                    hd = np.nan
                for met,val in zip(["dice","precision","recall","iou","hd95"],
                                   [dice,prec,rec,iou,hd]):
                    pc_metrics[f"class_{cls}"][met].append(val)

    preds_flat = np.concatenate(all_p)
    trues_flat = np.concatenate(all_t)
    cm = confusion_matrix(trues_flat, preds_flat)

    # save per-class CSV
    df_pc = pd.DataFrame({
        LABEL_NAMES[int(c.split("_")[1])] : {
            m: np.nanmean(vals) for m,vals in mets.items()
        }
        for c,mets in pc_metrics.items()
    }).T
    os.makedirs(save_dir, exist_ok=True)
    df_pc.to_csv(os.path.join(save_dir, f"{name}_per_class_metrics.csv"))

    macro = {
        "macro_precision": np.mean([np.mean(m["precision"]) for m in pc_metrics.values()]),
        "macro_recall":    np.mean([np.mean(m["recall"])    for m in pc_metrics.values()]),
        "macro_iou":       np.mean([np.mean(m["iou"])       for m in pc_metrics.values()])
    }

    return {"per_class": df_pc.to_dict(orient="index"),
            "macro":    macro,
            "confusion_matrix": cm}


# —————————————————————————————————————————————
# 7) Main()

def main():
    print("Collecting and loading data…")
    pairs = collect_pairs(IMAGES_DIR, LABELS_DIR)
    images, labels = parallel_load(pairs)

    idx = list(range(len(images)))
    random.shuffle(idx)
    split = int(0.8 * len(idx))
    tr_idx, vl_idx = idx[:split], idx[split:]

    train_ds = HippocampusDataset([images[i] for i in tr_idx],
                                  [labels[i] for i in tr_idx])
    val_ds   = HippocampusDataset([images[i] for i in vl_idx],
                                  [labels[i] for i in vl_idx])

    train_loader = DataLoader(train_ds, batch_size=1, shuffle=True,  num_workers=4)
    val_loader   = DataLoader(val_ds,   batch_size=1, shuffle=False, num_workers=4)

    models = {
        "UNet":         UNet(3, IN_CHANNELS, NUM_CLASSES, (16,32,64,128,256), (2,2,2,2)),
        "VNet":         VNet(3, IN_CHANNELS, NUM_CLASSES, dropout_prob_down=0.5, dropout_prob_up=(0.5,0.5)),
        "DynUNet":      DynUNet(3, IN_CHANNELS, NUM_CLASSES,
                                [[3]*3]*5,
                                [[1]*3]+[[2]*3]*4,
                                [[2]*3]*4,
                                [8,16,32,64,128]),
        "RegUNet":      RegUNet(3, IN_CHANNELS, NUM_CLASSES, depth=4, out_channels=NUM_CLASSES),
        "SegResNet":    SegResNet(
                spatial_dims=3,
                init_filters=16,            # base number of feature maps
                in_channels=IN_CHANNELS,    # 1
                out_channels=NUM_CLASSES,   # 3
                blocks_down=(1,1,1,2),
                blocks_up=(1,1,1),
                num_groups=4                 # optional: make sure num_channels % num_groups == 0
            ),
        "UNETR":     UNETR(
                spatial_dims=3,
                in_channels=IN_CHANNELS,
                out_channels=NUM_CLASSES,
                img_size=(64,64,64),
                feature_size=16,
                hidden_size=768,
                mlp_dim=3072,
                num_heads=12
            ),
        "SwinUNETR":    SwinUNETR(img_size=(64,64,64), in_channels=IN_CHANNELS,
                                 out_channels=NUM_CLASSES, feature_size=48),
        "AttentionUnet":AttentionUnet(3, IN_CHANNELS, NUM_CLASSES,
                                      channels=(16,32,64,128,256), strides=(2,2,2,2)),
        "HighResNet":   HighResNet(
            spatial_dims=3,
            in_channels=IN_CHANNELS,
            out_channels=NUM_CLASSES)
    }

    results = {}
    vis_cb = EnhancedVisCallback(val_loader)

    for name, model in models.items():
        print(f"\n=== {name} ===")
        fl, pa = compute_model_metrics(model, (1,IN_CHANNELS,64,64,64))
        print(f"{name}: {fl/1e9:.2f} GFLOPs, {(pa*4)/(1024**2):.2f} MB")

        best_d, tot_t, hist = train_model(
            model, train_loader, val_loader, name,
            max_epochs=50, patience=10, cbs=[]
        )
        perf = evaluate_best_model(model, name, val_loader)
        results[name] = {"history": hist, "confusion_matrix": perf["confusion_matrix"]}

    viz = TrainingVisualizer()
    mh  = {n:r["history"]           for n,r in results.items()}
    cms = {n:r["confusion_matrix"]  for n,r in results.items()}

    viz.plot_comparison_metrics(mh, ["train_loss","val_dice","val_hausdorff"])
    viz.plot_class_performance(mh)
    viz.create_training_summary(mh)
    viz.plot_confusion_matrices(cms)
    df_summary = viz.create_performance_report(mh)

    print("\nFinal performance summary:\n", df_summary)


if __name__ == "__main__":
    main()


Collecting and loading data…

=== UNet ===
UNet: 2.55 GFLOPs, 7.55 MB
Epoch 1/50
Epoch 2/50
Epoch 3/50
Epoch 4/50
Epoch 5/50
Epoch 6/50
Epoch 7/50
Epoch 8/50
Epoch 9/50
Epoch 10/50
Epoch 11/50
Epoch 12/50
Epoch 13/50
Epoch 14/50
Epoch 15/50
Epoch 16/50
Epoch 17/50
Epoch 18/50
Epoch 19/50
Epoch 20/50
Epoch 21/50
Epoch 22/50
Epoch 23/50
Early stopping.
UNet done in 121.8s, best val_dice=0.8338

=== VNet ===
VNet: 96.00 GFLOPs, 173.96 MB
Epoch 1/50
Epoch 2/50
Epoch 3/50
Epoch 4/50
Epoch 5/50
Epoch 6/50
Epoch 7/50


  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


Epoch 8/50
Epoch 9/50
Epoch 10/50
Epoch 11/50
Epoch 12/50
Epoch 13/50
Epoch 14/50
Epoch 15/50
Epoch 16/50
Epoch 17/50
Epoch 18/50
Epoch 19/50
Epoch 20/50
Epoch 21/50
Epoch 22/50
Epoch 23/50
Epoch 24/50
Epoch 25/50
Epoch 26/50
Epoch 27/50
Epoch 28/50
Epoch 29/50
Epoch 30/50
Epoch 31/50
Epoch 32/50
Epoch 33/50
Epoch 34/50
Epoch 35/50
Epoch 36/50
Epoch 37/50
Epoch 38/50
Epoch 39/50
Epoch 40/50
Epoch 41/50
Epoch 42/50
Epoch 43/50
Epoch 44/50
Epoch 45/50
Epoch 46/50
Epoch 47/50
Epoch 48/50
Early stopping.
VNet done in 1324.6s, best val_dice=0.7281

=== DynUNet ===
DynUNet: 16.98 GFLOPs, 10.78 MB
Epoch 1/50
Epoch 2/50
Epoch 3/50
Epoch 4/50
Epoch 5/50
Epoch 6/50
Epoch 7/50
Epoch 8/50
Epoch 9/50
Epoch 10/50
Epoch 11/50
Epoch 12/50
Epoch 13/50
Epoch 14/50
Epoch 15/50
Epoch 16/50
Epoch 17/50
Epoch 18/50
Epoch 19/50
Epoch 20/50
Epoch 21/50
Epoch 22/50
Epoch 23/50
Epoch 24/50
Epoch 25/50
Epoch 26/50
Epoch 27/50
Epoch 28/50
Epoch 29/50
Epoch 30/50
Epoch 31/50
Epoch 32/50
Epoch 33/50
Epoch 34/50
Epo



AttentionUnet: 18.31 GFLOPs, 22.54 MB
Epoch 1/50
Epoch 2/50
Epoch 3/50
Epoch 4/50
Epoch 5/50
Epoch 6/50
Epoch 7/50
Epoch 8/50
Epoch 9/50
Epoch 10/50
Epoch 11/50
Epoch 12/50
Epoch 13/50
Epoch 14/50
Epoch 15/50
Epoch 16/50
Epoch 17/50
Epoch 18/50
Epoch 19/50
Early stopping.
AttentionUnet done in 205.0s, best val_dice=0.7286

=== HighResNet ===
HighResNet: 212.46 GFLOPs, 3.09 MB
Epoch 1/50
Epoch 2/50
Epoch 3/50
Epoch 4/50
Epoch 5/50
Epoch 6/50
Epoch 7/50
Epoch 8/50
Epoch 9/50
Epoch 10/50
Epoch 11/50
Epoch 12/50
Epoch 13/50
Epoch 14/50
Epoch 15/50
Epoch 16/50
Epoch 17/50
Epoch 18/50
Early stopping.
HighResNet done in 902.8s, best val_dice=0.5679

Final performance summary:
            Model  Best Dice  Final Dice  Precision    Recall  F1 Score  \
0           UNet   0.833847    0.830391   0.863981  0.871808  0.866991   
1           VNet   0.728091    0.707552   0.833184  0.771704  0.781677   
2        DynUNet   0.837548    0.836554   0.829523  0.908471  0.863669   
3        RegUNet   0.4190