In [1]:
import matplotlib.pyplot as plt
import glob
import os
import torch
import sys
import numpy as np
from numpy import *
import SimpleITK as sitk
import scipy.ndimage as ndimage

from monai.transforms import (
    AsDiscrete,
    EnsureChannelFirstd,
    Compose,
    LoadImaged,
    Orientationd,
    ScaleIntensityRanged,
    Spacingd,
    CropForegroundd,
    EnsureTyped,
    EnsureType
)
from monai.data import CacheDataset, DataLoader, decollate_batch
from monai.inferers import sliding_window_inference
from monai.metrics import DiceMetric
from dataset import CHD2CHD
from networks.cardiacseg import CardiacSeg

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
transforms = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        # Resample to 1x1x1 spacing; labels use nearest interpolation so class ids stay integral.
        Spacingd(keys=["image", "label"], pixdim=(1, 1, 1), mode=("bilinear", "nearest")),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        # Clip image intensities to [500, 2000] and rescale them linearly to [0, 1].
        ScaleIntensityRanged(
            keys=["image"], a_min=500, a_max=2000,
            b_min=0.0, b_max=1.0, clip=True,
        ),
        # CropForegroundd(keys=["image", "label"], source_key="image"),
        CHD2CHD(),
        EnsureTyped(keys=["image", "label"]),
    ])

# The data root can be overridden via the CHD_DIR environment variable; the default
# keeps the original location.
chd_dir = os.environ.get(
    'CHD_DIR',
    '/home/jianglei/VCL-Project/data/2022Jianglei/dataset/ImageCHD_split_sdf',
)
val_images = sorted(glob.glob(f'{chd_dir}/test/images/*image.nii.gz'))
val_labels = sorted(glob.glob(f'{chd_dir}/test/labels/*label.nii.gz'))

# Images and labels are paired purely by sorted order, and zip() silently drops extras,
# so verify that the two lists have the same length and line up case by case.
if len(val_images) != len(val_labels):
    raise ValueError(
        f'found {len(val_images)} images but {len(val_labels)} labels under {chd_dir}/test'
    )
for image_name, label_name in zip(val_images, val_labels):
    image_case = os.path.basename(image_name)[:-len('image.nii.gz')]
    label_case = os.path.basename(label_name)[:-len('label.nii.gz')]
    if image_case != label_case:
        raise ValueError(f'image/label mismatch: {image_name} vs {label_name}')

test_files = [
    {"image": image_name, "label": label_name}
    for image_name, label_name in zip(val_images, val_labels)
]
print(len(test_files))

test_ds = CacheDataset(data=test_files, transform=transforms, cache_rate=1.0, num_workers=4,)
# No shuffling: test_step/save_result rely on the loader order matching test_files.
test_loader = DataLoader(test_ds, batch_size=1, num_workers=4,)

34


Loading dataset: 100%|██████████| 34/34 [00:33<00:00,  1.00it/s]


In [3]:
# Sliding-window ROI size; the same tuple is also handed to the network as img_size.
feature_size = (128, 128, 128)
# One background class plus seven foreground classes.
num_classes = 8
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


class Args:
    """Minimal argument namespace consumed by vit_params and CardiacSeg."""

    def __init__(self):
        self.arch, self.finetune = 'vit_base', False


args = Args()

def vit_params(args):
    """Return ``(hidden_size, num_heads)`` for the ViT variant named by ``args.arch``.

    Parameters
    ----------
    args : object
        Any object with an ``arch`` attribute; one of ``'vit_tiny'``,
        ``'vit_base'``, ``'vit_large'`` or ``'vit_huge'``.

    Returns
    -------
    tuple of (int, int)
        Transformer hidden size and number of attention heads.

    Raises
    ------
    ValueError
        If ``args.arch`` is not a known variant. The previous if/elif chain had
        no ``else`` and crashed with ``UnboundLocalError`` instead.
    """
    vit_configs = {
        'vit_tiny': (192, 3),
        'vit_base': (768, 12),
        'vit_large': (1152, 16),
        'vit_huge': (1344, 16),
    }
    try:
        return vit_configs[args.arch]
    except KeyError:
        raise ValueError(
            f"unknown arch {args.arch!r}; expected one of {sorted(vit_configs)}"
        ) from None

hidden_size, num_heads = vit_params(args)

# NOTE(review): the keyword `feature_size=16` below is a CardiacSeg hyper-parameter
# (presumably the base feature width, UNETR-style — confirm). It is unrelated to the
# module-level `feature_size` tuple, which is passed as `img_size`.
# NOTE(review): mlp_dim is fixed at 3072 whatever args.arch selects — confirm this is
# intended for variants other than vit_base.
cardiacseg = CardiacSeg(
    in_channels=1,
    out_channels=num_classes,
    img_size=feature_size,
    feature_size=16,
    hidden_size=hidden_size,
    mlp_dim=3072,
    num_heads=num_heads,
    pos_embed="conv",
    norm_name='instance',
    conv_block=True,
    res_block=True,
    dropout_rate=0.0,
    spatial_dims=3,
    args=args,
)

In [4]:
def view(val_data, val_outputs, step=8):
    """Show image / ground-truth / prediction slices side by side.

    Slices are taken along the last spatial axis every ``step`` indices. This axis
    is axial if the data was RAS-oriented by the loader. A slice is plotted only
    when the ground truth or the prediction has a non-zero label in it.

    Parameters
    ----------
    val_data : dict
        A loader batch with ``"image"`` and ``"label"`` tensors; only the first
        sample's first channel (``[0][0]``) is shown.
    val_outputs : torch.Tensor
        Network output for the same batch; predicted classes are the argmax over
        dim 1, and the first sample is shown.
    step : int
        Stride between inspected slices.
    """
    image = val_data["image"][0][0]
    label = val_data["label"][0][0]
    pred = torch.argmax(val_outputs, dim=1).detach().cpu()[0]
    for i in range(0, label.shape[-1], step):
        img_slice = image[..., i]
        gt_slice = label[..., i]
        pred_slice = pred[..., i]
        if gt_slice.sum() + pred_slice.sum() > 0:
            fig, ax = plt.subplots(1, 3, figsize=(9, 3))
            ax[0].set_title('raw image')
            ax[0].imshow(img_slice, cmap='gray')
            ax[1].set_title('ground truth')
            ax[1].imshow(img_slice, cmap='gray')
            ax[1].imshow(gt_slice, alpha=0.5, cmap='hot')
            ax[2].set_title('pred label')
            ax[2].imshow(img_slice, cmap='gray')
            ax[2].imshow(pred_slice, alpha=0.5, cmap='hot')
            plt.show()
            # Release each figure explicitly: with non-inline backends the loop would
            # otherwise keep every figure open and accumulate memory.
            plt.close(fig)
    return

In [5]:
# Predictions: logits -> argmax class map -> one-hot with num_classes channels.
post_pred = Compose([
    EnsureType(),
    AsDiscrete(argmax=True, to_onehot=num_classes),
])
# Labels: integer class map -> one-hot with num_classes channels.
post_label = Compose([
    EnsureType(),
    AsDiscrete(to_onehot=num_classes),
])
# Per-class Dice over the foreground classes only (channel 0 excluded).
dice_metric = DiceMetric(reduction="mean", include_background=False)
def test_step(model, test_loader, output_dir, files=None):
    """Run sliding-window inference on every test case, save predictions, report Dice.

    For each batch, the prediction is written to ``output_dir`` via ``save_result``.
    Per-class Dice over the foreground classes is computed with the module-level
    ``post_pred`` / ``post_label`` / ``dice_metric``. Mean and standard deviation
    (NaN-aware) are then printed overall and for each class.

    Parameters
    ----------
    model : torch.nn.Module
        Network called by ``sliding_window_inference``; expected to be in eval mode.
    test_loader : iterable
        Yields dicts with ``"image"`` and ``"label"`` tensors. It must use batch
        size 1 and yield cases in the same order as ``files``: the batch index
        names the saved file, and each batch contributes one Dice row.
    output_dir : str
        Directory for the saved predictions (created if missing).
    files : list of dict, optional
        File records passed to ``save_result`` for naming outputs. Defaults to the
        global ``test_files`` (the original behaviour).
    """
    if files is None:
        files = test_files
    os.makedirs(output_dir, exist_ok=True)

    roi_size = feature_size
    sw_batch_size = 4
    # One row per case. Columns 0..num_classes-2 hold foreground-class Dice
    # (include_background=False); the last column holds that case's NaN-aware mean.
    dice_list = np.full((len(test_loader), num_classes), np.nan)

    try:
        with torch.no_grad():
            for case_idx, test_data in enumerate(test_loader):
                test_inputs = test_data["image"].to(device)
                test_labels = test_data["label"].to(device)

                test_outputs = sliding_window_inference(test_inputs, roi_size, sw_batch_size, model)
                save_result(case_idx, files, test_outputs, output_dir)

                outputs = [post_pred(pred) for pred in decollate_batch(test_outputs)]
                labels = [post_label(lab) for lab in decollate_batch(test_labels)]
                # squeeze() collapses the batch axis, which assumes batch_size == 1.
                dice = dice_metric(y_pred=outputs, y=labels).squeeze().cpu().numpy()
                dice_list[case_idx] = np.append(dice, np.nanmean(dice))
    finally:
        # Clear the metric's accumulated buffer even if inference fails part-way.
        dice_metric.reset()

    case_means = dice_list[:, -1]
    # nanmean/nanstd for consistency with the per-class statistics below.
    print(f'average dice {np.nanmean(case_means)}, standard deviation {np.nanstd(case_means)}')
    for cls in range(num_classes - 1):
        print(f'class {cls + 1}, average dice {np.nanmean(dice_list[:, cls])}, '
              f'standard deviation {np.nanstd(dice_list[:, cls])}')
    return

In [6]:
def save_result(n, files, val_outputs, output_dir):
    """Resize case ``n``'s predicted label map onto its source image grid and save it.

    Parameters
    ----------
    n : int
        Index into ``files`` of the case being saved.
    files : list of dict
        Records with an ``"image"`` path. The output file name is the image's base
        name with ``image.nii.gz`` replaced by ``predlabel.nii.gz``.
    val_outputs : torch.Tensor
        Network output; classes are the argmax over dim 1, and only the first batch
        element is saved.
    output_dir : str
        Directory the prediction is written to.

    Raises
    ------
    RuntimeError
        Raised by SimpleITK's ``CopyInformation`` if the resized prediction's size
        does not match the source image.
    """
    # softmax is monotonic, so the argmax of the raw outputs gives the same labels
    # without building a full-volume probability tensor first.
    pred = torch.argmax(val_outputs, dim=1)[0].cpu().numpy().astype(np.uint8)

    raw_img_p = files[n]['image']
    raw_name = os.path.basename(raw_img_p)
    save_name = raw_name.replace('image.nii.gz', 'predlabel.nii.gz')

    raw_img = sitk.ReadImage(raw_img_p)
    # GetSize() is (x, y, z). pred's spatial axes are treated as (x, y, z), the same
    # mapping the array-based size lookup used; reading the size directly avoids
    # copying the whole image into a NumPy array.
    raw_size = raw_img.GetSize()
    zoom = tuple(raw / cur for raw, cur in zip(raw_size, pred.shape))
    # order=0 (nearest neighbour) keeps every voxel a valid class id.
    pred_arr = ndimage.zoom(pred, zoom, output=np.uint8, order=0, mode='nearest', prefilter=False)
    # SimpleITK arrays are indexed (z, y, x).
    pred_arr = pred_arr.transpose(2, 1, 0)
    # NOTE(review): the prediction is produced on the loader's reoriented (RAS) grid
    # and is only resized here, not reoriented. If a source image's direction differs
    # from what this mapping assumes, the saved label will be flipped — verify
    # (an np.flip(pred_arr, 2) was previously tried here).

    out = sitk.GetImageFromArray(pred_arr)
    # Copies origin, spacing and direction, and checks that the sizes match.
    out.CopyInformation(raw_img)
    save_pred = os.path.join(output_dir, save_name)
    sitk.WriteImage(out, save_pred)
    print('Done: {}'.format(raw_name))

    return

# Inference

In [7]:
weight_path = '/home/jianglei/VCL-Project/data/2022Jianglei/CardiacSeg/output-sdf-b1in-cardiacseg/metric_model-epoch490-dice0.8008242249488831.pth'
output_dir = "/home/jianglei/VCL-Project/data/2022Jianglei/CardiacSeg/output-sdf-b1in-cardiacseg/output"
model = cardiacseg.to(device)
# Wrapped before loading. NOTE(review): presumably the checkpoint was saved from a
# DataParallel model (keys prefixed "module."), which strict=True requires — confirm.
model = torch.nn.DataParallel(model)
# map_location lets a checkpoint saved on GPU load on a CPU-only machine too.
checkpoint = torch.load(weight_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'], strict=True)
model.eval()
test_step(model, test_loader, output_dir)

Done: ct_1002_image.nii.gz
Done: ct_1003_image.nii.gz
Done: ct_1004_image.nii.gz
Done: ct_1005_image.nii.gz
Done: ct_1010_image.nii.gz
Done: ct_1011_image.nii.gz
Done: ct_1014_image.nii.gz
Done: ct_1016_image.nii.gz
Done: ct_1023_image.nii.gz
Done: ct_1028_image.nii.gz
Done: ct_1030_image.nii.gz
Done: ct_1033_image.nii.gz
Done: ct_1035_image.nii.gz
Done: ct_1036_image.nii.gz
Done: ct_1043_image.nii.gz
Done: ct_1044_image.nii.gz
Done: ct_1046_image.nii.gz
Done: ct_1048_image.nii.gz
Done: ct_1050_image.nii.gz
Done: ct_1054_image.nii.gz
Done: ct_1059_image.nii.gz
Done: ct_1060_image.nii.gz
Done: ct_1063_image.nii.gz
Done: ct_1064_image.nii.gz
Done: ct_1070_image.nii.gz
Done: ct_1083_image.nii.gz
Done: ct_1092_image.nii.gz
Done: ct_1105_image.nii.gz
Done: ct_1112_image.nii.gz
Done: ct_1114_image.nii.gz
Done: ct_1119_image.nii.gz
Done: ct_1135_image.nii.gz
Done: ct_1138_image.nii.gz
Done: ct_1150_image.nii.gz
average dice 0.8014564803735245, standard variation 0.1590832348871658
class 1, av