In [None]:
import matplotlib.pyplot as plt
import glob
import os
import torch
import sys
import numpy as np
from numpy import *
import SimpleITK as sitk

from monai.transforms import (
    AsDiscrete,
    EnsureChannelFirstd,
    Compose,
    LoadImaged,
    Orientationd,
    ScaleIntensityRanged,
    Spacingd,
    CropForegroundd,
    EnsureTyped,
    EnsureType
)
from monai.data import CacheDataset, DataLoader, decollate_batch
from monai.inferers import sliding_window_inference
from monai.metrics import DiceMetric
from dataset import CHD2CHD
from networks.cardiacseg import CardiacSeg

In [2]:
# Preprocessing pipeline: load image/label, move channel first, resample to
# (1, 1, 1) spacing (bilinear for the image, nearest for the label), reorient to
# RAS, clip image intensities to [500, 2000] and rescale them to [0, 1], crop both
# to the image foreground, then apply the project's CHD2CHD transform
# (defined in dataset.py — presumably a label remapping; not visible here).
transforms = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        Spacingd(keys=["image", "label"], pixdim=(1, 1, 1), mode=("bilinear", "nearest")),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        ScaleIntensityRanged(
            keys=["image"], a_min=500, a_max=2000,
            b_min=0.0, b_max=1.0, clip=True,
        ),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        CHD2CHD(),
        EnsureTyped(keys=["image", "label"]),
    ])

# Dataset root; defaults to the original location, override with $CHD_DIR.
chd_dir = os.environ.get('CHD_DIR', '/ImageCHD')
val_images = sorted(glob.glob(f'{chd_dir}/test/images/*image.nii.gz'))
val_labels = sorted(glob.glob(f'{chd_dir}/test/labels/*label.nii.gz'))

# Images and labels are paired purely by sorted order below, so fail loudly on an
# empty glob (wrong root) or a count mismatch — zip() would otherwise silently
# drop the unmatched tail and misalign nothing-but-the-count-check would reveal.
if not val_images:
    raise FileNotFoundError(f'no test images matching *image.nii.gz under {chd_dir}/test/images')
if len(val_images) != len(val_labels):
    raise ValueError(
        f'found {len(val_images)} test images but {len(val_labels)} test labels under {chd_dir}/test'
    )

test_files = [
    {"image": image_name, "label": label_name}
    for image_name, label_name in zip(val_images, val_labels)
]
print(len(test_files))

# Cache every preprocessed case in memory (cache_rate=1.0) so inference does not
# repeat the expensive resampling for each pass.
test_ds = CacheDataset(data=test_files, transform=transforms, cache_rate=1.0, num_workers=4,)
test_loader = DataLoader(test_ds, batch_size=1, num_workers=4,)

22


Loading dataset:   0%|          | 0/22 [00:00<?, ?it/s]Modifying image pixdim from [0.75 0.75 1.25 1.  ] to [ 0.75        0.75        1.25       97.59146351]
Modifying image pixdim from [0.75 0.75 1.25 1.  ] to [ 0.75        0.75        1.25       95.50278137]
Modifying image pixdim from [0.75 0.75 1.25 1.  ] to [ 0.75        0.75        1.25       93.13716364]
Modifying image pixdim from [0.75 0.75 1.25 1.  ] to [ 0.75        0.75        1.25       94.99958881]
Loading dataset:  27%|██▋       | 6/22 [00:53<01:56,  7.31s/it]Modifying image pixdim from [0.75 0.75 1.25 1.  ] to [ 0.75        0.75        1.25       94.06962182]
Loading dataset: 100%|██████████| 22/22 [03:20<00:00,  9.10s/it]


In [None]:
# Network input size, also used as the sliding-window ROI size at inference.
feature_size = (128,) * 3
# Background plus 7 foreground classes.
num_classes = 1 + 7
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


class Args:
    """Minimal stand-in for a parsed-arguments namespace.

    Carries the ViT variant name (``arch``) and the ``finetune`` flag; an
    instance is handed to ``vit_params`` and to the ``CardiacSeg`` constructor.
    """

    def __init__(self):
        self.arch = 'vit_base'
        self.finetune = False


args = Args()

def vit_params(args):
    """Return the ``(hidden_size, num_heads)`` preset for the ViT named by ``args.arch``.

    Supported values of ``args.arch``: 'vit_tiny', 'vit_base', 'vit_large',
    'vit_huge'.

    Args:
        args: any object with an ``arch`` attribute.

    Returns:
        tuple[int, int]: ``(hidden_size, num_heads)``.

    Raises:
        ValueError: if ``args.arch`` is not a supported variant (previously this
            surfaced as an opaque UnboundLocalError).
    """
    presets = {
        'vit_tiny': (192, 3),
        'vit_base': (768, 12),
        'vit_large': (1152, 16),
        'vit_huge': (1344, 16),
    }
    try:
        return presets[args.arch]
    except KeyError:
        raise ValueError(
            f"unsupported arch {args.arch!r}; expected one of {sorted(presets)}"
        ) from None

# ViT width and head count for the architecture selected in `args.arch`.
hidden_size, num_heads = vit_params(args)

# Single-channel input, one output channel per class (background included).
cardiacseg = CardiacSeg(
    in_channels=1,
    out_channels=num_classes,
    img_size=feature_size,
    feature_size=16,
    hidden_size=hidden_size,
    mlp_dim=3072,
    num_heads=num_heads,
    norm_name="batch",
    res_block=True,
    dropout_rate=0.0,
    # Optional adapter / low-rank variants, all disabled here. The keyword
    # `res_adpter` is kept exactly as spelled — it must match CardiacSeg's parameter.
    lora=False,
    res_adpter=False,
    adapterformer=False,
    args=args,
)

In [4]:
def view(val_data, val_outputs, step=8):
    """Plot every ``step``-th slice (along the last axis) of the first case in a batch.

    A slice is drawn only when the ground-truth or the predicted label map is
    non-zero on it. Each figure has three panels: the raw image, the image with
    the ground-truth overlay, and the image with the predicted overlay.

    Args:
        val_data: batch dict with "image" and "label" tensors; batch item 0,
            channel 0 is shown (assumes shape (B, C, X, Y, Z) — TODO confirm).
        val_outputs: network output with classes on dim 1; the prediction is
            its argmax over dim 1.
        step: stride between inspected slices along the last axis.

    Returns:
        None.
    """
    img1 = val_data["image"][0][0]
    lab1 = val_data["label"][0][0]
    lab2 = torch.argmax(val_outputs, dim=1).detach().cpu()[0]
    for i in range(0, lab1.shape[-1], step):
        a = img1[..., i]
        b = lab1[..., i]
        d = lab2[..., i]
        if b.sum() + d.sum() > 0:
            fig, ax = plt.subplots(1, 3, figsize=(9, 3))
            ax[0].set_title('raw image')
            ax[0].imshow(a, cmap='gray')
            ax[1].set_title('ground truth')
            ax[1].imshow(a, cmap='gray')
            ax[1].imshow(b, alpha=0.5, cmap='hot')
            ax[2].set_title('pred label')
            ax[2].imshow(a, cmap='gray')
            ax[2].imshow(d, alpha=0.5, cmap='hot')
            plt.show()
            # Close explicitly: outside the inline backend, figures opened in this
            # loop would otherwise pile up for every slice of every test case.
            plt.close(fig)

In [6]:
# Prediction: argmax over class logits, then one-hot; label: one-hot directly.
post_pred = Compose([EnsureType(), AsDiscrete(argmax=True, to_onehot=num_classes)])
post_label = Compose([EnsureType(), AsDiscrete(to_onehot=num_classes)])
# Dice over foreground classes only (background channel excluded).
dice_metric = DiceMetric(include_background=False, reduction="mean")


def test_step(model, test_loader):
    """Run sliding-window inference on every case and print Dice statistics.

    Each case is visualised with ``view`` and its per-class Dice (background
    excluded) is recorded together with the case mean over those classes. At
    the end the mean/std of the per-case mean and of every class are printed.

    Args:
        model: segmentation network, already moved to ``device`` and in eval mode.
        test_loader: loader yielding dicts with "image" and "label" tensors;
            assumes batch_size=1 (the per-case Dice is squeezed to 1-D).

    Returns:
        None.
    """
    roi_size = feature_size
    sw_batch_size = 4
    # Columns 0..num_classes-2: per-foreground-class Dice; last column: case mean.
    # Pre-filled with NaN so an unfilled row can never leak uninitialised memory.
    dice_list = np.full((len(test_loader), num_classes), np.nan)
    try:
        with torch.no_grad():
            for case_idx, test_data in enumerate(test_loader):
                test_inputs = test_data["image"].to(device)
                test_labels = test_data["label"].to(device)

                test_outputs = sliding_window_inference(test_inputs, roi_size, sw_batch_size, model)
                view(test_data, test_outputs, step=20)

                outputs = [post_pred(pred) for pred in decollate_batch(test_outputs)]
                labels = [post_label(lab) for lab in decollate_batch(test_labels)]
                dice = dice_metric(y_pred=outputs, y=labels).squeeze().cpu().numpy()
                # Per-class Dice may contain NaN (e.g. a class with no ground truth),
                # hence the NaN-aware mean.
                dice_list[case_idx] = np.append(dice, np.nanmean(dice))
    finally:
        # Clear the metric's accumulated buffer even if inference fails midway.
        dice_metric.reset()

    # NaN-aware statistics throughout, consistent with the per-class lines below.
    print(f'average dice {np.nanmean(dice_list[:,-1])}, standard variation {np.nanstd(dice_list[:,-1])}')
    for i in range(num_classes-1):
        print(f'class {i+1}, average dice {np.nanmean(dice_list[:,i])}, standard variation {np.nanstd(dice_list[:,i])}')

# Inference

In [7]:
# Trained checkpoint to evaluate; set it here or via $CARDIACSEG_WEIGHTS.
weight_path = os.environ.get('CARDIACSEG_WEIGHTS', '')
if not weight_path:
    raise ValueError('weight_path is empty: point it (or $CARDIACSEG_WEIGHTS) at a trained checkpoint')

model = cardiacseg.to(device)
# map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
checkpoint = torch.load(weight_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
test_step(model, test_loader)

average dice 0.8918075344779275, standard variation 0.03304845323981812
class 1, average dice 0.9259832934899763, standard variation 0.030582116635415952
class 2, average dice 0.876049055294557, standard variation 0.05941473922174186
class 3, average dice 0.916997644034299, standard variation 0.021506057010733975
class 4, average dice 0.9121511280536652, standard variation 0.023933803539267665
class 5, average dice 0.8861239463090896, standard variation 0.05455964147421084
class 6, average dice 0.8876700618050315, standard variation 0.058084500745306474
class 7, average dice 0.8392996219071475, standard variation 0.09825686648075699
