In [13]:
import logging
import os
import sys
import tempfile
from glob import glob

import nibabel as nib
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter

import monai
from monai.data import ImageDataset, create_test_image_3d, decollate_batch, DataLoader
from monai.inferers import sliding_window_inference
from monai.metrics import DiceMetric
from monai.transforms import Activations, EnsureChannelFirst, AsDiscrete, Compose, RandRotate90, RandSpatialCrop, ScaleIntensity
from monai.visualize import plot_2d_or_3d_image

In [2]:
# Generate a synthetic dataset: 40 random volumes of size (128, 128, 128), each
# paired with a single-class segmentation mask, saved as NIfTI files.
tempdir = 'Temp_dataset'
# Make sure the output folder exists before nib.save writes into it
# (exist_ok=True keeps re-running this cell harmless).
os.makedirs(tempdir, exist_ok=True)
# Fixed seed so the generated volumes are identical on every run.
# Note: this seeds only the synthetic data, not augmentation or training.
data_rng = np.random.RandomState(0)
for i in range(40):
    im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, random_state=data_rng)

    # Identity affine (np.eye(4)): no spacing/orientation information is attached.
    n = nib.Nifti1Image(im, np.eye(4))
    nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz"))

    n = nib.Nifti1Image(seg, np.eye(4))
    nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))

In [3]:
# Collect the image and label file paths. Both lists are ordered with the same
# lexicographic key on the "{i}.nii.gz" suffix, so images[k] and segs[k] refer to
# the same case (assuming the folder holds only matching im{i}/seg{i} pairs).
image_pattern = os.path.join(tempdir, "im*.nii.gz")
seg_pattern = os.path.join(tempdir, "seg*.nii.gz")
images = glob(image_pattern)
images.sort()
segs = glob(seg_pattern)
segs.sort()

In [5]:
# Training-time preprocessing for images and their segmentation labels.
def _make_train_spatial_augs():
    """Build a fresh list of random spatial augmentations for one pipeline.

    Returns:
        ``[RandSpatialCrop, RandRotate90]``:
        - a random crop of fixed size (96, 96, 96); ``random_size=False`` keeps
          the crop size fixed at that ROI instead of sampling it;
        - a 90-degree rotation applied with probability 0.5 in the plane spanned
          by spatial axes 0 and 2 (the first and third spatial dimensions).
    """
    return [
        RandSpatialCrop((96, 96, 96), random_size=False),
        RandRotate90(prob=0.5, spatial_axes=(0, 2)),
    ]


# Images: rescale intensities (to [0, 1] by default), move/add the channel axis
# to the front, then apply the random spatial augmentations.
# Labels: channel-first only (no intensity scaling), then the same augmentations
# in the same order, so the random draws of both pipelines line up.
# NOTE(review): pixel-aligned image/label crops also rely on ImageDataset seeding
# both pipelines identically — confirm against the MONAI ImageDataset docs.
train_imtrans = Compose(
    [ScaleIntensity(), EnsureChannelFirst()] + _make_train_spatial_augs()
)
train_segtrans = Compose([EnsureChannelFirst()] + _make_train_spatial_augs())

In [6]:
# Validation preprocessing is deterministic: no cropping or rotation, so whole
# volumes are passed through. Images are intensity-scaled and made channel-first;
# labels are only made channel-first.
val_imtrans = Compose(
    [
        ScaleIntensity(),
        EnsureChannelFirst(),
    ]
)
val_segtrans = Compose(
    [
        EnsureChannelFirst(),
    ]
)

In [8]:
# Training data: the first 20 cases, shuffled into batches of 4 and loaded by
# 8 worker processes. Memory is pinned only when a CUDA device is available.
train_ds = ImageDataset(
    images[:20],
    segs[:20],
    transform=train_imtrans,
    seg_transform=train_segtrans,
)
train_loader = DataLoader(
    train_ds,
    batch_size=4,
    shuffle=True,
    num_workers=8,
    pin_memory=torch.cuda.is_available(),
)

In [10]:
# Validation data: the last 20 cases (disjoint from the training slice when
# exactly 40 cases were generated), one full volume per batch, not shuffled.
val_ds = ImageDataset(
    images[-20:],
    segs[-20:],
    transform=val_imtrans,
    seg_transform=val_segtrans,
)
val_loader = DataLoader(
    val_ds,
    batch_size=1,
    num_workers=4,
    pin_memory=torch.cuda.is_available(),
)

In [11]:
# Validation metric: mean Dice score. include_background=True keeps channel 0,
# which is the only output channel of the network used here.
# get_not_nans=False makes aggregate() return just the metric, without the
# count of non-NaN entries.
dice_metric = DiceMetric(
    include_background=True,
    reduction="mean",
    get_not_nans=False,
)
# Convert raw network outputs into a binary mask: sigmoid, then threshold at 0.5.
post_trans = Compose(
    [
        Activations(sigmoid=True),
        AsDiscrete(threshold=0.5),
    ]
)

In [12]:
# Model, loss and optimizer: a 3D MONAI UNet (1 input channel -> 1 output
# channel), Dice loss applied after a sigmoid, and Adam with learning rate 1e-3.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

unet_kwargs = dict(
    spatial_dims=3,
    in_channels=1,
    out_channels=1,
    channels=(16, 32, 64, 128, 256),  # feature channels per resolution level
    strides=(2, 2, 2, 2),             # downsampling stride between levels
    num_res_units=2,                  # residual units per block
)
model = monai.networks.nets.UNet(**unet_kwargs).to(device)
loss_function = monai.losses.DiceLoss(sigmoid=True)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

In [14]:
# Standard PyTorch training loop with periodic sliding-window validation.
# The best checkpoint (by mean Dice) is saved to disk, and losses, metrics and
# example volumes are logged to TensorBoard.
max_epochs = 5
val_interval = 2          # run validation every `val_interval` epochs
roi_size = (96, 96, 96)   # sliding-window patch size; matches the training crop size
sw_batch_size = 4         # number of windows evaluated per forward pass
best_metric = -1
best_metric_epoch = -1
epoch_loss_values = list()
metric_values = list()
# Batches per epoch, counting a final partial batch. Used for the progress display
# and to give every training step a unique TensorBoard x-value (floor division
# of len(train_ds) by the batch size would under-count and make steps collide).
epoch_len = len(train_loader)
writer = SummaryWriter()
try:
    for epoch in range(max_epochs):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{max_epochs}")
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = batch_data[0].to(device), batch_data[1].to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            # .item() copies the scalar off the device; do it once per step.
            loss_value = loss.item()
            epoch_loss += loss_value
            print(f"{step}/{epoch_len}, train_loss: {loss_value:.4f}")
            writer.add_scalar("train_loss", loss_value, epoch_len * epoch + step)
        # Guard against an empty loader (would otherwise divide by zero).
        epoch_loss /= max(step, 1)
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():
                val_images = None
                val_labels = None
                val_outputs = None
                for val_data in val_loader:
                    val_images, val_labels = val_data[0].to(device), val_data[1].to(device)
                    # Whole (uncropped) validation volumes are tiled into roi_size patches.
                    val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
                    # Split the batch into per-sample tensors and binarise each one.
                    val_outputs = [post_trans(i) for i in decollate_batch(val_outputs)]
                    # accumulate the metric for the current iteration
                    dice_metric(y_pred=val_outputs, y=val_labels)
                # aggregate the final mean Dice over all validation samples
                metric = dice_metric.aggregate().item()
                # reset the accumulated state for the next validation round
                dice_metric.reset()
                metric_values.append(metric)
                if metric > best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                    torch.save(model.state_dict(), "best_metric_model_segmentation3d_array.pth")
                    print("saved new best metric model")
                print(
                    "current epoch: {} current mean dice: {:.4f} best mean dice: {:.4f} at epoch {}".format(
                        epoch + 1, metric, best_metric, best_metric_epoch
                    )
                )
                writer.add_scalar("val_mean_dice", metric, epoch + 1)
                # plot the last validation sample's image, label and model output
                # as GIF images in TensorBoard
                plot_2d_or_3d_image(val_images, epoch + 1, writer, index=0, tag="image")
                plot_2d_or_3d_image(val_labels, epoch + 1, writer, index=0, tag="label")
                plot_2d_or_3d_image(val_outputs, epoch + 1, writer, index=0, tag="output")

    print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")
finally:
    # Always flush and close the TensorBoard writer, even if training raises.
    writer.close()

----------
epoch 1/5
1/5, train_loss: 0.6528
2/5, train_loss: 0.6035
3/5, train_loss: 0.5523
4/5, train_loss: 0.5589
5/5, train_loss: 0.5027
epoch 1 average loss: 0.5740
----------
epoch 2/5
1/5, train_loss: 0.5185
2/5, train_loss: 0.5524
3/5, train_loss: 0.4978
4/5, train_loss: 0.4990
5/5, train_loss: 0.4820
epoch 2 average loss: 0.5099
saved new best metric model
current epoch: 2 current mean dice: 0.5774 best mean dice: 0.5774 at epoch 2
----------
epoch 3/5
1/5, train_loss: 0.5014
2/5, train_loss: 0.4777
3/5, train_loss: 0.4466
4/5, train_loss: 0.4978
5/5, train_loss: 0.4914
epoch 3 average loss: 0.4830
----------
epoch 4/5
1/5, train_loss: 0.4648
2/5, train_loss: 0.4874
3/5, train_loss: 0.4355
4/5, train_loss: 0.4689
5/5, train_loss: 0.5098
epoch 4 average loss: 0.4733
saved new best metric model
current epoch: 4 current mean dice: 0.9042 best mean dice: 0.9042 at epoch 4
----------
epoch 5/5
1/5, train_loss: 0.4755
2/5, train_loss: 0.4222
3/5, train_loss: 0.4470
4/5, train_loss: 