In [1]:
import logging
import os
import sys
import tempfile
from glob import glob

import nibabel as nib
import numpy as np
import torch

from monai import config
from monai.data import ImageDataset, create_test_image_3d, decollate_batch, DataLoader
from monai.inferers import sliding_window_inference
from monai.metrics import DiceMetric
from monai.networks.nets import UNet
from monai.transforms import Activations, EnsureChannelFirst, AsDiscrete, Compose, SaveImage, ScaleIntensity

In [3]:
# Create a small synthetic evaluation set: 5 random 128^3 volumes plus their
# segmentation masks, saved as NIfTI files with an identity affine.
tempdir = 'eval_dataset'
# nib.save does not create parent directories; without this a fresh kernel /
# fresh checkout fails with FileNotFoundError.
os.makedirs(tempdir, exist_ok=True)
for i in range(5):
    im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1)

    image_nii = nib.Nifti1Image(im, np.eye(4))
    nib.save(image_nii, os.path.join(tempdir, f"im{i:d}.nii.gz"))

    seg_nii = nib.Nifti1Image(seg, np.eye(4))
    nib.save(seg_nii, os.path.join(tempdir, f"seg{i:d}.nii.gz"))

In [4]:
# Collect the generated files. Both lists are sorted the same way, so
# im{i} lines up with seg{i}.
images = sorted(glob(os.path.join(tempdir, "im*.nii.gz")))
segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))
# Fail fast on a missing dataset or an image/mask count mismatch instead of
# evaluating an empty or mis-paired set further down.
assert images, f"no images found in {tempdir!r}"
assert len(images) == len(segs), f"{len(images)} images but {len(segs)} segmentations"

In [6]:
# Pre-processing pipelines. Images are intensity-scaled and made channel-first;
# masks are only made channel-first (their label values are not rescaled).
imtrans = Compose([
    ScaleIntensity(),
    EnsureChannelFirst(),
])
segtrans = Compose([
    EnsureChannelFirst(),
])

In [7]:
# Evaluation data: ImageDataset pairs each image path with its mask path and
# applies the transforms above; image_only=False asks it to also return
# metadata alongside the arrays.
use_pinned_memory = torch.cuda.is_available()
val_ds = ImageDataset(
    images,
    segs,
    transform=imtrans,
    seg_transform=segtrans,
    image_only=False,
)
val_loader = DataLoader(
    val_ds,
    batch_size=1,
    num_workers=1,
    pin_memory=use_pinned_memory,
)

In [8]:
# Mean Dice metric, including the background channel. get_not_nans=False means
# aggregate() returns only the metric value, not the count of non-NaN entries.
dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
# Turn raw network outputs into a binary mask: sigmoid, then threshold at 0.5.
post_trans = Compose([
    Activations(sigmoid=True),
    AsDiscrete(threshold=0.5),
])
# Writes each predicted mask as .nii.gz under ./output with a "seg" postfix.
saver = SaveImage(output_dir="./output", output_ext=".nii.gz", output_postfix="seg")
# Prefer the GPU when one is available; otherwise run on the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)

In [9]:
# 3D residual U-Net with single-channel input and output, five levels
# (16 -> 256 feature channels) and stride-2 downsampling between them.
# These hyper-parameters must match the ones used to train the checkpoint
# loaded below, or load_state_dict will reject it.
unet_config = dict(
    spatial_dims=3,
    in_channels=1,
    out_channels=1,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
)
model = UNet(**unet_config).to(device)

In [10]:
# Load the trained weights and evaluate them on every validation volume.
# map_location=device lets a checkpoint saved on a GPU load on a CPU-only
# machine (device falls back to "cpu" above).
# NOTE: torch.load unpickles the file; only load checkpoints you trust.
model.load_state_dict(
    torch.load("best_metric_model_segmentation3d_array.pth", map_location=device)
)
model.eval()

# Sliding-window settings: each volume is processed in 96^3 windows,
# 4 windows per forward pass. They are loop-invariant, so define them once.
roi_size = (96, 96, 96)
sw_batch_size = 4

with torch.no_grad():
    for val_data in val_loader:
        # val_data[0] is the image batch, val_data[1] the mask batch.
        val_images, val_labels = val_data[0].to(device), val_data[1].to(device)
        # Sliding-window inference over one image per iteration.
        val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
        # Split the batch into per-volume tensors and binarize each prediction.
        val_outputs = [post_trans(pred) for pred in decollate_batch(val_outputs)]
        val_labels = decollate_batch(val_labels)
        # Accumulate the Dice score for this iteration.
        dice_metric(y_pred=val_outputs, y=val_labels)
        # Save each predicted mask.
        for val_output in val_outputs:
            saver(val_output)
    # Aggregate the mean Dice over all iterations, then reset for any re-run.
    print("evaluation metric:", dice_metric.aggregate().item())
    dice_metric.reset()

2022-09-22 20:00:56,983 INFO image_writer.py:194 - writing: output\im0\im0_seg.nii.gz
2022-09-22 20:00:57,241 INFO image_writer.py:194 - writing: output\im1\im1_seg.nii.gz
2022-09-22 20:00:57,430 INFO image_writer.py:194 - writing: output\im2\im2_seg.nii.gz
2022-09-22 20:00:57,635 INFO image_writer.py:194 - writing: output\im3\im3_seg.nii.gz
2022-09-22 20:00:57,819 INFO image_writer.py:194 - writing: output\im4\im4_seg.nii.gz
evaluation metric: 0.8846399188041687


In [11]:
from torchsummary import summary

In [36]:
# Layer-by-layer summary for one 96^3 single-channel patch (the sliding-window
# ROI size). Use the active device's type ("cuda" or "cpu") instead of
# hard-coding "cuda", so this cell also runs on CPU-only machines.
summary(model, input_size=(1, 96, 96, 96), batch_size=1, device=device.type)


----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv3d-1        [1, 16, 48, 48, 48]             448
            Conv3d-2        [1, 16, 48, 48, 48]             448
    InstanceNorm3d-3        [1, 16, 48, 48, 48]               0
           Dropout-4        [1, 16, 48, 48, 48]               0
             PReLU-5        [1, 16, 48, 48, 48]               1
            Conv3d-6        [1, 16, 48, 48, 48]           6,928
    InstanceNorm3d-7        [1, 16, 48, 48, 48]               0
           Dropout-8        [1, 16, 48, 48, 48]               0
             PReLU-9        [1, 16, 48, 48, 48]               1
     ResidualUnit-10        [1, 16, 48, 48, 48]               0
           Conv3d-11        [1, 32, 24, 24, 24]          13,856
           Conv3d-12        [1, 32, 24, 24, 24]          13,856
   InstanceNorm3d-13        [1, 32, 24, 24, 24]               0
          Dropout-14        [1, 32, 24,