## 初始化

引入Onekey相关的库

In [None]:
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
from monai.utils import first, set_determinism
from monai.transforms import (
    AsDiscrete,
    AsDiscreted,
    EnsureChannelFirstd,
    Compose,
    CropForegroundd,
    LoadImaged,
    Orientationd,
    RandCropByPosNegLabeld,
    ScaleIntensityRanged,
    Spacingd,
    EnsureTyped,
    EnsureType,
    Invertd,
)
from monai.handlers.utils import from_engine
from monai.networks.nets import UNet, SegResNet, VNet, UNETR
from monai.networks.layers import Norm
from monai.inferers import sliding_window_inference
from monai.data import CacheDataset, DataLoader, Dataset, decollate_batch
from monai.config import print_config
from monai.apps import download_and_extract
import torch
import matplotlib.pyplot as plt
import tempfile
import shutil
import os
import glob
import sys

### 数据集划分

默认进行随机划分，使用最后的8个作为测试集合。

`seg_idx = 1`可以通过修改seg_idx的值选择训练哪个模型

  1. `1`代表训练T1
  2. `2`代表训练T2

In [None]:
import random
import os
import nibabel as nib
import numpy as np
import json
import pandas as pd
from onekey_algo import get_param_in_cwd
from onekey_algo.custom.components.Radiology import diagnose_3d_image_mask_settings

# Project root: the configured 'radio_dir' with its first character replaced by 'Y'
# (presumably remapping the path onto another drive letter -- confirm).
configured_dir = get_param_in_cwd('radio_dir')
root_dir = 'Y' + configured_dir[1:]

# Sub-directories derived from the project root.
data_dir = os.path.join(root_dir, 'data')
val_data_dir = os.path.join(root_dir, 'test')
inference_dir = os.path.join(root_dir, 'test_results')
model_root = os.path.join(root_dir, 'models')
os.makedirs(model_root, exist_ok=True)
# os.makedirs(inference_dir, exist_ok=True)

# Patch size used for random cropping and sliding-window inference.
roi_size = (96, 96, 48)

# Select the task / modality; images and masks live under <root_dir>/<sel_modal>/.
sel_modal = 'MR-T2'

# label.csv provides an 'ID' column (file name) and a 'group' column (cohort).
data = pd.read_csv(os.path.join(root_dir, 'label.csv'))
train_ids = data.loc[data['group'] == 'train', 'ID']

# Keep only the training cases whose image file is actually present on disk.
train_files = []
for case_id in train_ids:
    image_path = os.path.join(root_dir, sel_modal, 'images', case_id)
    if not os.path.exists(image_path):
        continue
    train_files.append({'image': image_path,
                        'label': os.path.join(root_dir, sel_modal, 'masks', case_id)})
train_files

In [None]:
# Every case whose 'group' is not 'train' goes to the validation/test list,
# provided that its image file exists on disk.
val_ids = data.loc[data['group'] != 'train', 'ID']
val_files = [
    {'image': os.path.join(root_dir, sel_modal, 'images', case_id),
     'label': os.path.join(root_dir, sel_modal, 'masks', case_id)}
    for case_id in val_ids
    if os.path.exists(os.path.join(root_dir, sel_modal, 'images', case_id))
]
# For a quick debugging run the list can be truncated, e.g. val_files = val_files[:20]
val_files

In [None]:
# Report how many cases are in the training list and in the validation/test list.
print(f"训练集：{len(train_files)}，测试集：{len(val_files)}")

## Setup transforms for training and validation

Here we use several transforms to augment the dataset:
1. `LoadImaged` loads the images and labels from NIfTI format files.
1. `EnsureChannelFirstd` adds a channel dimension when the original data doesn't have one, to construct a "channel first" shape.
1. `Orientationd` unifies the data orientation based on the affine matrix (currently commented out in the code below).
1. `Spacingd` adjusts the spacing to `pixdim=(1, 1, 1)` based on the affine matrix.
1. `ScaleIntensityRanged` clips the intensity range to [0, 1600] and scales it to [0, 1].
1. `CropForegroundd` removes all zero borders to focus on the valid body area of the images and labels.
1. `RandCropByPosNegLabeld` randomly crops patch samples from the big image based on the pos / neg ratio.  
The image centers of negative samples must be in the valid body area.
1. `RandAffined` efficiently performs `rotate`, `scale`, `shear`, `translate`, etc. together based on PyTorch affine transform (not used in the pipeline below).
1. `EnsureTyped` converts the numpy array to PyTorch Tensor for further steps.

In [None]:
set_determinism(seed=0)


def _shared_preprocessing():
    """Return a fresh list with the deterministic steps common to training and validation."""
    return [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        EnsureTyped(keys=["image", "label"]),
        # Orientationd(keys=["image", "label"], axcodes="RAS") is currently disabled.
        # Resample to 1x1x1 spacing: bilinear for the image, nearest for the label.
        Spacingd(keys=["image", "label"], pixdim=(1, 1, 1), mode=("bilinear", "nearest")),
        # Clip image intensities to [0, 1600] and rescale them to [0, 1].
        ScaleIntensityRanged(keys=["image"], a_min=0, a_max=1600, b_min=0.0, b_max=1.0, clip=True),
        # Crop image and label to the foreground of the image.
        CropForegroundd(keys=["image", "label"], source_key="image"),
    ]


# Training: shared preprocessing, then 4 random roi_size patches per volume,
# sampled with a positive:negative centre ratio of 1:2.
train_transforms = Compose(
    _shared_preprocessing()
    + [
        RandCropByPosNegLabeld(
            keys=["image", "label"],
            label_key="label",
            spatial_size=roi_size,
            pos=1,
            neg=2,
            num_samples=4,
            # image_key="image" / image_threshold=0 are left at their defaults.
        ),
        EnsureTyped(keys=["image", "label"]),
    ]
)

# Validation: shared preprocessing only, applied to the whole volume.
val_transforms = Compose(
    _shared_preprocessing()
    + [
        EnsureTyped(keys=["image", "label"]),
    ]
)

### Dataloader

检查Transform以及相应的Dataloader。

Here we use CacheDataset to accelerate the training and validation process; it's 10x faster than the regular Dataset.  To achieve the best performance, set `cache_rate=1.0` to cache all the data; if memory is not enough, set a lower value.  Users can also set `cache_num` instead of `cache_rate`; the minimum of the two settings will be used.  Set `num_workers` to enable multiple worker threads during caching.  If you want to try the regular Dataset, just switch to the commented code below.

In [None]:
# Training set: every item is cached (cache_rate=1) using 6 caching workers.
train_ds = CacheDataset(
    data=train_files,
    transform=train_transforms,
    cache_rate=1,
    num_workers=6,
)
# Non-caching alternative: train_ds = Dataset(data=train_files, transform=train_transforms)

# batch_size=4 volumes per step; RandCropByPosNegLabeld yields 4 patches per volume,
# so the network sees 4 x 4 patches per training step.
train_loader = DataLoader(
    train_ds,
    batch_size=4,
    shuffle=True,
    num_workers=12,
)

# Validation set: cached in the same way, served one whole volume at a time.
val_ds = CacheDataset(
    data=val_files,
    transform=val_transforms,
    cache_rate=1,
    num_workers=6,
)
# Non-caching alternative: val_ds = Dataset(data=val_files, transform=val_transforms)
val_loader = DataLoader(
    val_ds,
    batch_size=1,
    num_workers=12,
)

#### 数据可视化

In [None]:
# Pick one validation case and display one slice of channel 0 of its image and label.
val_data_example = val_ds[2]
example_image = val_data_example["image"]
example_label = val_data_example["label"]

# Slice 65 along the last axis is the preferred view; clamp it so that volumes
# with fewer slices do not raise an IndexError.
preferred_slice = 65
image_slice = min(preferred_slice, example_image.shape[-1] - 1)
label_slice = min(preferred_slice, example_label.shape[-1] - 1)

print(f"image shape: {example_image.shape}")
plt.figure("image", (12, 6))
plt.subplot(1, 2, 1)
plt.title("image channel 0")
plt.imshow(example_image[0, :, :, image_slice].detach().cpu(), cmap="gray")
plt.show()

# Show the label of the same case.
print(f"label shape: {example_label.shape}")
plt.figure("label", (6, 6))
plt.subplot(1, 1, 1)
plt.title("label channel 0")
plt.imshow(example_label[0, :, :, label_slice].detach().cpu())
plt.show()

In [None]:
# List the distinct values present in the example label volume.
np.unique(val_data_example['label'])

## 生成 Model, Loss, Optimizer

In [None]:
from monai.metrics import DiceMetric
from monai.losses import DiceLoss, DiceFocalLoss, DiceCELoss

# Use the first GPU when available; otherwise fall back to the CPU instead of
# failing when the model is moved to the device.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
mtype = 'unet'
num_classes = 2
in_channels = 1

# Architecture registry: lower-case model name -> zero-argument factory.
# Factories are lambdas so that only the selected network is instantiated.
_model_builders = {
    'unet': lambda: UNet(
        spatial_dims=3,
        in_channels=in_channels,
        out_channels=num_classes,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
        norm=Norm.BATCH,
    ),
    'segresnet': lambda: SegResNet(
        blocks_down=[1, 2, 2, 4],
        blocks_up=[1, 1, 1],
        init_filters=16,
        in_channels=in_channels,
        out_channels=num_classes,
        dropout_prob=0.2,
    ),
    'unetr': lambda: UNETR(
        in_channels=in_channels,
        out_channels=num_classes,
        img_size=roi_size,
        feature_size=16,
        hidden_size=768,
        mlp_dim=3072,
        num_heads=12,
        pos_embed="perceptron",
        norm_name="instance",
        res_block=True,
        dropout_rate=0.0,
    ),
    'vnet': lambda: VNet(
        spatial_dims=3,
        in_channels=in_channels,
        out_channels=num_classes,
        dropout_prob=0.2,
        dropout_dim=3,
        bias=False,
    ),
}

try:
    _build_model = _model_builders[mtype.lower()]
except KeyError:
    raise ValueError(f'{mtype} not found!') from None
model = _build_model().to(device)

print(f"使用{mtype.upper()}进行训练！")
loss_function = DiceCELoss(to_onehot_y=True, softmax=True)
optimizer = torch.optim.Adam(model.parameters(), 1e-3)
dice_metric = DiceMetric(include_background=False, reduction="mean")

# Resume from an existing checkpoint for this model type / modality, if one is present.
pretrained_path = os.path.join(model_root, f"{mtype}_{sel_modal}.pth")
if os.path.exists(pretrained_path):
    print('加载预训练模型...')
    model.load_state_dict(torch.load(pretrained_path, map_location=device))

### 模型训练

`max_epochs`最大迭代次数，int类型，默认： 600

`val_interval` 多少次训练进行一次validation，默认： 2

In [None]:
max_epochs = 600
val_interval = 2
best_metric = -1
best_metric_epoch = -1
epoch_loss_values = []
metric_values = []
# Post-processing applied before computing the Dice metric: predictions are
# arg-maxed and one-hot encoded, labels are one-hot encoded.
post_pred = Compose([EnsureType(), AsDiscrete(argmax=True, to_onehot=num_classes)])
post_label = Compose([EnsureType(), AsDiscrete(to_onehot=num_classes)])
# Stop when the best validation metric is more than this many epochs old.
early_stopping_epoch = 128
# Number of patches evaluated together during sliding-window inference.
sw_batch_size = 4
# Real number of batches per epoch (includes a final, partially filled batch).
steps_per_epoch = len(train_loader)

for epoch in range(max_epochs):
    print("-" * 10)
    print(f"epoch {epoch + 1}/{max_epochs}")
    model.train()
    epoch_loss = 0
    step = 0
    for step, batch_data in enumerate(train_loader, start=1):
        inputs, labels = (
            batch_data["image"].to(device),
            batch_data["label"].to(device),
        )
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        if step % 2 == 0:
            print(f"{step}/{steps_per_epoch}, train_loss: {loss.item():.4f}")
    # max(..., 1) guards against an empty loader (no batch was processed).
    epoch_loss /= max(step, 1)
    epoch_loss_values.append(epoch_loss)
    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

    if (epoch + 1) % val_interval != 0:
        continue

    model.eval()
    with torch.no_grad():
        for val_data in val_loader:
            val_inputs, val_labels = (
                val_data["image"].to(device),
                val_data["label"].to(device),
            )
            val_outputs = sliding_window_inference(val_inputs, roi_size, sw_batch_size, model)
            val_outputs = [post_pred(i) for i in decollate_batch(val_outputs)]
            val_labels = [post_label(i) for i in decollate_batch(val_labels)]
            # accumulate the metric for the current case
            dice_metric(y_pred=val_outputs, y=val_labels)

        # aggregate the final mean dice result
        metric = dice_metric.aggregate().item()
        # reset the status for the next validation round
        dice_metric.reset()

    metric_values.append(metric)
    # A checkpoint is written at every validation round ...
    torch.save(model.state_dict(), os.path.join(model_root, f"{mtype}_{sel_modal}-Epoch{epoch+1}.pth"))
    # ... and the best one so far is additionally stored under the plain name.
    if metric > best_metric:
        best_metric = metric
        best_metric_epoch = epoch + 1
        torch.save(model.state_dict(),
                   os.path.join(model_root, f"{mtype}_{sel_modal}.pth"))
        print("saved new best metric model")
    print(
        f"current epoch: {epoch + 1} current mean dice: {metric:.4f}"
        f"\nbest mean dice: {best_metric:.4f} "
        f"at epoch: {best_metric_epoch}"
    )
    # Compare 1-based epoch numbers on both sides (best_metric_epoch is 1-based).
    if (epoch + 1) - best_metric_epoch > early_stopping_epoch:
        print(f'Early Stop @{epoch+1}')
        break
print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")

### 打印训练过程

In [None]:
# Plot the training loss per epoch next to the validation Dice per validation round.
loss_epochs = [i + 1 for i in range(len(epoch_loss_values))]
# Validation runs every `val_interval` epochs, so scale the x axis accordingly.
val_epochs = [val_interval * (i + 1) for i in range(len(metric_values))]
panels = [
    ("Epoch Average Loss", loss_epochs, epoch_loss_values),
    ("Val Mean Dice", val_epochs, metric_values),
]

plt.figure("train", (12, 6))
for position, (title, x, y) in enumerate(panels, start=1):
    plt.subplot(1, 2, position)
    plt.title(title)
    plt.xlabel("epoch")
    plt.plot(x, y)

# The output directory may not exist yet; savefig does not create it.
os.makedirs('img', exist_ok=True)
# The keyword is `bbox_inches` (the original `bbox_inch` was a typo).
plt.savefig(f'img/{mtype}_{sel_modal}_train_process.svg', bbox_inches='tight')
plt.show()

In [None]:
from monai.transforms import KeepLargestConnectedComponentd, RemoveSmallObjectsd
import SimpleITK as sitk
import numpy as np
from monai.config import KeysCollection
from monai.transforms import MapTransform


class RemoveSmallObjectsPerLabel(MapTransform):
    """
    Suppress small connected components, channel by channel.

    For every key and every channel (first axis) of ``data[key]`` the channel is
    cast to ``int``, split into fully-connected components with SimpleITK, and
    every component with fewer than ``min_size`` voxels is multiplied by
    ``force2`` (0 by default, i.e. removed). Voxels of all other components keep
    their original value. The result is stored back as a NumPy array with the
    same channel-first layout.

    Args:
        keys: keys of the entries to process.
        allow_missing_keys: skip keys that are absent instead of raising.
        min_size: minimum number of voxels a component needs to be kept.
        verbose: print the component labels and their voxel counts.
        force2: multiplier applied to the voxels of too-small components.
    """

    def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False,
                 min_size=64, verbose: bool = False, force2: int = 0):
        super().__init__(keys, allow_missing_keys)
        self.min_size = min_size
        self.verbose = verbose
        self.force2 = force2

    def __call__(self, data):
        # key_iterator honours allow_missing_keys.
        for k in self.key_iterator(data):
            cleaned_channels = []
            for channel_idx in range(data[k].shape[0]):
                channel = np.array(data[k][channel_idx])

                # Label the connected components of this channel.
                cc_filter = sitk.ConnectedComponentImageFilter()
                cc_filter.SetFullyConnected(True)
                components = sitk.GetArrayFromImage(
                    cc_filter.Execute(sitk.GetImageFromArray(channel.astype(int))))

                # Voxel count of every component label.
                labels, counts = np.unique(components, return_counts=True)
                if self.verbose:
                    by_size = sorted(zip(labels.tolist(), counts.tolist()),
                                     key=lambda item: item[1], reverse=True)
                    print(labels, by_size)

                # Multiplicative mask: 1 everywhere, force2 on too-small components.
                keep_mask = np.ones_like(components)
                too_small = labels[counts < self.min_size]
                keep_mask[np.isin(components, too_small)] = self.force2

                # Apply the mask to this channel only, so the output keeps the
                # input's (channel, ...) shape.
                cleaned_channels.append(keep_mask * channel)
            data[k] = np.array(cleaned_channels)
        return data

# Image-only preprocessing used at inference time (no label key is required).
# NOTE(review): a_max is 6000 here but 1600 in the training transforms -- confirm
# that this difference is intended for the model that is loaded for inference.
val_t = Compose([
    LoadImaged(keys=["image"]),
    EnsureChannelFirstd(keys=["image"]),
    EnsureTyped(keys=["image"]),
    # Orientationd(keys=["image"], axcodes="RAS") is currently disabled.
    Spacingd(keys=["image"], pixdim=(1, 1, 1), mode="bilinear"),
    ScaleIntensityRanged(keys=["image"], a_min=0, a_max=6000, b_min=0.0, b_max=1.0, clip=True),
    CropForegroundd(keys=["image"], source_key="image"),
    EnsureTyped(keys=["image"]),
])

# val_transforms = Compose(
#     [
#         LoadImaged(keys=["image"], allow_missing_keys=True),
#         EnsureChannelFirstd(keys=["image",], allow_missing_keys=True),
# #         Orientationd(keys=["image", "label"], axcodes="RAS"),
#         Spacingd(keys=["image",], pixdim=(0.5, 0.5, 3), mode=("bilinear"), allow_missing_keys=True),
#         ScaleIntensityRanged(
#             keys=["image"], a_min=0, a_max=2500,
#             b_min=0.0, b_max=1.0, clip=True,
#         ),
#         CropForegroundd(keys=["image",], source_key="image", allow_missing_keys=True),
#         EnsureTyped(keys=["image",], allow_missing_keys=True),
#     ]
# )
# Post-processing applied to the "pred" entry after inference.
post_ori_t = Compose(
    [
        # Undo the `val_t` preprocessing on "pred", using the transform history
        # recorded for "image"; the result is returned as a tensor on the CPU.
        Invertd(
            keys="pred",
            transform=val_t,
            orig_keys="image",
            meta_keys="pred_meta_dict",
            orig_meta_keys="image_meta_dict",
            meta_key_postfix="meta_dict",
            nearest_interp=False,
            to_tensor=True,
            device="cpu",
        ),
        # Collapse the channel scores into a label map.
        AsDiscreted(keys="pred", argmax=True),
        # Connected-component clean-up, configured with num_components=4.
        KeepLargestConnectedComponentd(keys='pred', num_components=4)
        # Alternative clean-up (disabled):
#         RemoveSmallObjectsPerLabel(keys='pred', min_size=5000, verbose=True)
    ]
)


### 预测

In [None]:
import glob
import os
from onekey_algo import OnekeyDS
from onekey_algo.segmentation3D.modelzoo.eval_3dsegmentation import init as init3d
from onekey_algo.segmentation3D.modelzoo.eval_3dsegmentation import inference as inference3d

# Inference configuration. NOTE: this rebinds root_dir, model_root, sel_modal,
# mtype, num_classes and data, which earlier cells also define.
root_dir = r'D:\20240510-ChangBoWen'
model_root = os.path.join(root_dir, 'models')
sel_modal = 'CLS2'

# mtype must be set before any path derived from it; otherwise save_dir would use
# a value left over from an earlier cell (or fail in a fresh kernel).
mtype = 'segresnet'
save_dir = os.path.join(root_dir, f'{mtype}_infer')
model_path = os.path.join(model_root, f'{mtype}_{sel_modal}.pth')
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

num_classes = 4
# init3d returns three values; only the first (the model loaded from model_path)
# is used below -- the returned device is overridden on the next line.
m, t, d = init3d('SegResNet', model_path=model_path, num_classes=num_classes, roi_size=roi_size)
d = 'cuda:0'
m = m.to(d)
data = glob.glob(os.path.join(root_dir, r'images', '*.nii.gz'))

# Run inference with the model loaded above (`m`), not the in-memory training `model`.
inference3d(data, m, (val_t, post_ori_t), d, roi_size=roi_size, save_dir=save_dir)

# Per-file variant that writes each result next to its input image:
# for data_ in data:
#     inference3d([data_], m, (val_t, post_ori_t), d,
#                 roi_size=roi_size, save_dir=os.path.dirname(data_), save_name=data_.replace('.nii.gz', '.infer.nii.gz'))

# for data_ in data:
#     inference3d([data_], m, (val_t, post_ori_t), d, 
#                 roi_size=roi_size, save_dir=os.path.dirname(data_), save_name=data_.replace('.nii.gz', '.infer.nii.gz'))

In [None]:
import numpy as np
from monai.metrics import compute_average_surface_distance, compute_hausdorff_distance

def calc_dice(p_cls, l_cls):
    """Return the Dice coefficient of two binary (0/1) masks.

    Args:
        p_cls: predicted binary mask (NumPy array of 0/1 values).
        l_cls: ground-truth binary mask with the same shape.

    Returns:
        ``2 * |P & L| / (|P| + |L|)`` as a float, or ``None`` (after printing a
        message) when both masks are empty.
    """
    s = p_cls + l_cls
    # Voxels where the sum is 2 belong to both masks; >= 1 belongs to at least one.
    inter = int(np.count_nonzero(s >= 2))
    # |union| + |intersection| equals |P| + |L|.
    conv = int(np.count_nonzero(s >= 1)) + inter
    try:
        dice = 2.0 * inter / conv
    except ZeroDivisionError:
        print("conv is zeros when dice = 2.0 * inter / conv")
        dice = None
    return dice

def calc_iou(p_cls, l_cls):
    """Return the intersection-over-union of two binary (0/1) masks.

    Args:
        p_cls: predicted binary mask (NumPy array of 0/1 values).
        l_cls: ground-truth binary mask with the same shape.

    Returns:
        ``|P & L| / |P | L|`` as a float, or ``None`` (after printing a message)
        when both masks are empty.
    """
    s = p_cls + l_cls
    # Voxels where the sum is 2 belong to both masks; >= 1 belongs to at least one.
    inter = int(np.count_nonzero(s >= 2))
    conv = int(np.count_nonzero(s >= 1))
    try:
        iou = inter / conv
    except ZeroDivisionError:
        print("conv is zeros when iou = inter / conv")
        iou = None
    return iou
    
def calc_sa(p_cls, l_cls):
    """Return the segmentation accuracy: 1 - (missed label voxels / label voxels).

    Args:
        p_cls: predicted binary mask (integer NumPy array of 0/1 values).
        l_cls: ground-truth binary mask with the same shape.

    Returns:
        The score as a NumPy float. With NumPy's default error settings an empty
        label mask yields ``nan`` (plus a RuntimeWarning) rather than raising;
        ``None`` is returned only if the division itself raises.
    """
    # Label voxels that the prediction does not cover (false negatives).
    error = np.bitwise_xor(p_cls, l_cls) & l_cls
    try:
        sa = 1 - np.sum(error) / np.sum(l_cls)
    except (ZeroDivisionError, FloatingPointError):
        print("SA segmentation is error!")
        sa = None
    return sa

def calc_os(p_cls, l_cls):
    """Return the over-segmentation rate: false positives / (label voxels + predicted voxels).

    Args:
        p_cls: predicted binary mask (integer NumPy array of 0/1 values).
        l_cls: ground-truth binary mask with the same shape.

    Returns:
        The rate as a NumPy float. With NumPy's default error settings a zero
        denominator yields ``nan`` (plus a RuntimeWarning) rather than raising;
        ``None`` is returned only if the division itself raises.
    """
    # Predicted voxels that are not part of the label (false positives).
    error = np.bitwise_xor(p_cls, l_cls) & p_cls
    try:
        # NOTE(review): calc_us divides by (label voxels + false positives); this
        # function uses (label voxels + predicted voxels) -- confirm which is intended.
        over_s = np.sum(error) / (np.sum(l_cls) + np.sum(p_cls))
    except (ZeroDivisionError, FloatingPointError):
        print("Over segmentation is error!")
        over_s = None
    return over_s

def calc_us(p_cls, l_cls):
    """Return the under-segmentation rate: false negatives / (label voxels + false positives).

    Args:
        p_cls: predicted binary mask (integer NumPy array of 0/1 values).
        l_cls: ground-truth binary mask with the same shape.

    Returns:
        The rate as a NumPy float. With NumPy's default error settings a zero
        denominator yields ``nan`` (plus a RuntimeWarning) rather than raising;
        ``None`` is returned only if the division itself raises.
    """
    # Label voxels missing from the prediction (false negatives).
    error = np.bitwise_xor(p_cls & l_cls, l_cls)
    # Predicted voxels outside the label (false positives).
    false_pos = np.bitwise_xor(p_cls, l_cls) & p_cls
    try:
        us = np.sum(error) / (np.sum(l_cls) + np.sum(false_pos))
    except (ZeroDivisionError, FloatingPointError):
        print("Under segmentation is error!")
        us = None
    return us

def calc_asd(p_cls, l_cls):
    """Return MONAI's average surface distance between two masks as a float."""
    # Prepend two singleton axes before handing the masks to MONAI.
    pred_batched = p_cls[np.newaxis, np.newaxis, :]
    label_batched = l_cls[np.newaxis, np.newaxis, :]
    distance = compute_average_surface_distance(pred_batched, label_batched)
    return float(distance)

def calc_hausdorff_distance(p_cls, l_cls):
    """Return MONAI's Hausdorff distance between two masks as a float."""
    # Prepend two singleton axes before handing the masks to MONAI.
    pred_batched = p_cls[np.newaxis, np.newaxis, :]
    label_batched = l_cls[np.newaxis, np.newaxis, :]
    distance = compute_hausdorff_distance(pred_batched, label_batched)
    return float(distance)

def seg_eval(pred, label, clss=(0, 1)):
    """
    Compute overlap metrics between a prediction and the ground truth, per class.

    Args:
        pred: predicted label map (3D array).
        label: ground-truth label map.
        clss: class values to evaluate, e.g. (0, 1) for a binary task.

    Returns:
        A list with one entry per class; each entry is
        ``[dice, iou, sa, os, us]`` as returned by ``calc_dice``, ``calc_iou``,
        ``calc_sa``, ``calc_os`` and ``calc_us``.
    """
    # The prediction must be three-dimensional; the unpacking enforces that.
    depth, height, width = pred.shape
    shape = (depth, height, width)

    eval_matric = []
    for cls in clss:
        # One-vs-rest binary maps for the current class.
        pred_cls = np.zeros(shape, dtype=np.uint8)
        pred_cls[np.where(pred == cls)] = 1
        label_cls = np.zeros(shape, dtype=np.uint8)
        label_cls[np.where(label == cls)] = 1

        eval_matric.append([
            calc_dice(pred_cls, label_cls),
            calc_iou(pred_cls, label_cls),
            calc_sa(pred_cls, label_cls),
            calc_os(pred_cls, label_cls),
            calc_us(pred_cls, label_cls),
            # calc_asd / calc_hausdorff_distance are available but disabled here.
        ])

    return eval_matric

### 后处理

In [None]:
from glob import glob
import SimpleITK as sitk
import os
import numpy as np
import pandas as pd

# Evaluate the saved predictions of every model against the ground-truth masks.
models = ['unet', 'unetr', 'vnet']
cohort_metric = []   # one single-row frame (mean over cases) per model and cohort
metric_spec = []     # one per-case frame per model and cohort
tn = None            # optional cap on the number of cases per cohort (None = all)
metric_names = ['Dice', 'mIOU', 'SA', 'OS', 'US']

# The loop variable is `model_name` (not `model`) so that the trained network
# object bound to `model` is not overwritten by a string.
for model_name in models:
    root = os.path.join(get_param_in_cwd('radio_dir'), sel_modal, f'{model_name}_infer')

    for fs in (train_files[:tn], val_files[:tn]):
        if not fs:
            # Nothing to evaluate for this cohort; averaging an empty array would fail.
            continue
        all_metrics = []
        for gt in fs:
            gt_mask = gt['label']
            # Predictions are looked up under the same file name as the mask.
            pred_mask = os.path.join(root, os.path.basename(gt_mask))
            all_metrics.append(seg_eval(sitk.GetArrayFromImage(sitk.ReadImage(pred_mask)),
                                        sitk.GetArrayFromImage(sitk.ReadImage(gt_mask)), clss=[0, 1]))
        # all_metrics is (cases, classes, metrics); average over the class axis.
        metric = pd.DataFrame(np.mean(np.array(all_metrics), axis=1), columns=metric_names)
        cohort_metric.append(pd.DataFrame(metric.mean(axis=0)).T)
        metric['model'] = model_name
        case_ids = pd.DataFrame([os.path.basename(f['image']) for f in fs], columns=['ID'])
        info = pd.concat([case_ids, metric], axis=1)
        metric_spec.append(info)

In [None]:
# Stack the per-case metric tables of all models/cohorts into one frame,
# write it to data/<sel_modal>_infer.csv and display it.
info = pd.concat(metric_spec, axis=0)
os.makedirs('data', exist_ok=True)
info.to_csv(f'data/{sel_modal}_infer.csv', index=False)
info

In [None]:
# Attach each case's cohort from group.csv and average the metrics per (group, model).
# numeric_only=True keeps the string 'ID' column out of the mean, which would
# otherwise raise a TypeError on recent pandas versions.
(
    pd.merge(info, pd.read_csv('group.csv'), on='ID', how='left')
    .groupby(['group', 'model'])
    .mean(numeric_only=True)
    .reset_index()
)

In [None]:
# Number of cases in the validation/test list.
len(val_files)