In [12]:
import os
# set available cuda
os.environ["CUDA_VISIBLE_DEVICES"] = "1, 3"

# basic package
import logging
import sys
from glob import glob
import numpy as np
import pandas as pd
import torch
from torch.nn.functional import pad
from torch.utils.data import DataLoader, Dataset

# handle .nii.gz files
import nibabel as nib

# advanced function package (medical imaging)
from monai import config
from monai.data import decollate_batch
from monai.handlers import CheckpointLoader, MeanDice, StatsHandler
from monai.networks.nets import UNet
from monai.transforms import (
    Activations,
    AsDiscrete,
    Compose,
    LoadImage,
    CenterSpatialCrop,
    ScaleIntensity,
    EnsureType,
    Transpose,
)
from ignite.handlers import Checkpoint

# subnet
from unet_unit.unet_unit import unit_model


# Sanity-check the GPU setup. Device index 0 below is the first device listed
# in CUDA_VISIBLE_DEVICES, not necessarily physical GPU 0.
print(torch.cuda.is_available())
print(torch.cuda.device_count())
if torch.cuda.is_available():
    # these two calls raise on a machine without a usable CUDA device
    print(torch.cuda.get_device_name(0))
    print(torch.cuda.current_device())

True
2
GeForce GTX 1080
0


## Data lists, transforms and data loader

In [13]:
# Route log records (e.g. from MONAI / ignite handlers) to the notebook output.
logging.basicConfig(stream=sys.stdout, level=logging.INFO)

# Experiment folder and its output sub-folders.
root_dir = "./exp_04"
output_dir = os.path.join(root_dir, 'output')
threshold_dir = os.path.join(output_dir, 'threshold')
singlenet_dir = os.path.join(output_dir, 'single_net')
for leaf_dir in (threshold_dir, singlenet_dir):
    # makedirs also creates root_dir/output_dir and is a no-op for existing folders
    os.makedirs(leaf_dir, exist_ok=True)

data_dir = "./dataset/working_data"
test_dir = os.path.join(data_dir, 'Test')

# One image and several expert segmentations per case folder. Both lists are
# sorted by full path, so they share the same case order.
images_test = sorted(glob(os.path.join(test_dir, "case*", "image.nii.gz")))
segs_test = sorted(glob(os.path.join(test_dir, "case*", "task01_seg*.nii.gz")))

# The dataset below reads segs_test[idx * 6 + k]. A missing or extra mask in any
# case would silently shift every later case onto the wrong image, so fail early.
seg_counts = [
    len(glob(os.path.join(os.path.dirname(image_path), "task01_seg*.nii.gz")))
    for image_path in images_test
]
if len(segs_test) != 6 * len(images_test) or any(count != 6 for count in seg_counts):
    raise ValueError(
        f"expected 6 task01_seg*.nii.gz files per case; found {len(segs_test)} "
        f"segmentations for {len(images_test)} images (per case: {seg_counts})"
    )

# Test-time pipelines. Loading happens inside the Dataset (its LoadImage), so
# these start from the loaded array. Transpose((2, 0, 1)) moves the last axis to
# the front; EnsureType() comes last (presumably yielding a tensor — MONAI default).
imtrans_test = Compose([ScaleIntensity(), Transpose((2, 0, 1)), EnsureType()])

# Segmentations are only reordered, never intensity-scaled.
segtrans_test = Compose([Transpose((2, 0, 1)), EnsureType()])


# Test-set dataset: yields one image together with its list of expert segmentations.
class evaluator_Dataset(Dataset):
    """Pair each image with ``segs_per_image`` consecutive entries of ``seg_list``.

    ``seg_list`` must be grouped per image: the masks of image ``idx`` sit at
    positions ``idx * segs_per_image + k`` for ``k`` in ``range(segs_per_image)``.

    Args:
        img_list: image file paths.
        seg_list: segmentation file paths, grouped per image as described above.
        img_transform: optional callable applied to every loaded image.
        seg_transform: optional callable applied to every loaded segmentation.
        loader: callable ``path -> (data, meta_data)``; defaults to MONAI ``LoadImage()``.
        segs_per_image: number of segmentations belonging to each image.

    Raises:
        ValueError: if ``len(seg_list) != segs_per_image * len(img_list)``.
    """

    def __init__(self, img_list, seg_list, img_transform=None, seg_transform=None,
                 loader=None, segs_per_image=6):
        if len(seg_list) != segs_per_image * len(img_list):
            raise ValueError(
                f"expected {segs_per_image} segmentations per image "
                f"({segs_per_image * len(img_list)} in total), got {len(seg_list)}"
            )
        # __getitem__ unpacks every loader result as (data, meta_data)
        self.loader = LoadImage() if loader is None else loader
        self.img_list = img_list
        self.seg_list = seg_list
        self.img_transform = img_transform
        self.seg_transform = seg_transform
        self.segs_per_image = segs_per_image

    def __len__(self):
        return len(self.img_list)

    def _load(self, path, transform):
        """Load ``path`` and apply ``transform`` when one is given; return (data, meta)."""
        data, meta_data = self.loader(path)
        if transform is not None:
            data = transform(data)
        return data, meta_data

    def __getitem__(self, idx):
        """Return ``(img, [segs...], img_meta_data, [seg_meta_data...])`` for case ``idx``."""
        img, img_meta_data = self._load(self.img_list[idx], self.img_transform)

        seg, seg_meta_data = [], []
        first = idx * self.segs_per_image
        for k in range(self.segs_per_image):
            tmp_seg, tmp_seg_meta_data = self._load(self.seg_list[first + k], self.seg_transform)
            seg.append(tmp_seg)
            seg_meta_data.append(tmp_seg_meta_data)

        return img, seg, img_meta_data, seg_meta_data

# Test loader. Keep batch_size=1: the evaluation loop rejects any other batch size.
# shuffle=False keeps batches in images_test order, which the loop relies on
# when it pairs batches with case names by position.
dataset_test = evaluator_Dataset(
    images_test,
    segs_test,
    img_transform=imtrans_test,
    seg_transform=segtrans_test,
)
loader_test = DataLoader(
    dataset_test,
    batch_size=1,
    shuffle=False,
    num_workers=4,
    pin_memory=torch.cuda.is_available(),
)

## load models

In [14]:
# Evaluation device: the first visible CUDA device when available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Six identically-configured 2-D UNet subnets; subnet k is paired with expert k.
unet_kwargs = dict(
    dimensions=2,
    in_channels=1,
    out_channels=1,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
)
model = [UNet(**unet_kwargs).to(device) for _subnet_idx in range(6)]

# True when every subnet has already been trained and its checkpoint sits in
# <root_dir>/logs_XX; set to False to (re)train each subnet with unit_model first.
pretrained = True

# (Train and) load subnet k (1-based) from the single *.pt file in <root_dir>/logs_0k.
for subnet_id, net in enumerate(model, start=1):
    if not pretrained:
        unit_model(subnet_id, device)
    log_dir = os.path.join(root_dir, f"logs_{subnet_id:02d}")
    ckpt_paths = glob(os.path.join(log_dir, "*.pt"))
    # zero files means nothing to load; several would make the choice ambiguous
    if len(ckpt_paths) != 1:
        raise ValueError(
            f"expected exactly one *.pt checkpoint in {log_dir}, found {len(ckpt_paths)}"
        )
    # map_location lets a checkpoint saved on another GPU (or on CUDA) load onto `device`
    checkpoint = torch.load(ckpt_paths[0], map_location=device)
    Checkpoint.load_objects(to_load={"net": net}, checkpoint=checkpoint)

## Helper functions

In [15]:
# Dice overlap between two masks.
def Dice_score(inputs, targets, smooth=0, empty_score=float("nan")):
    """Return the Dice coefficient ``(2*|A∩B| + smooth) / (|A| + |B| + smooth)`` as a float.

    No activation or thresholding is applied here; callers are expected to pass
    already-binarised masks. Sums run over all elements, so any shape works as
    long as ``inputs * targets`` broadcasts.

    Args:
        inputs: predicted mask tensor.
        targets: ground-truth mask tensor.
        smooth: additive smoothing term for numerator and denominator.
        empty_score: value returned when the denominator is zero (both masks
            empty with ``smooth == 0``); the default NaN equals the plain 0/0 result.

    Returns:
        The Dice score as a Python float.
    """
    intersection = (inputs * targets).sum()
    denominator = inputs.sum() + targets.sum() + smooth
    if denominator == 0:
        return empty_score
    return ((2.0 * intersection + smooth) / denominator).item()

# Dice computed separately for every index along the first (channel) dimension.
def channel_parallel_Dice_score(inputs, targets, smooth=0):
    """Return a 1-D tensor holding one Dice score per index of dim 0.

    Both tensors must be channel-first with the same size along dim 0.
    ``smooth`` is forwarded to ``Dice_score`` for every channel.

    Raises:
        ValueError: if the two tensors have a different number of channels
            (``zip`` would otherwise silently drop the extra channels).
    """
    if inputs.shape[0] != targets.shape[0]:
        raise ValueError(
            f"channel count mismatch: inputs has {inputs.shape[0]}, targets has {targets.shape[0]}"
        )
    return torch.tensor([
        Dice_score(pred, gt, smooth=smooth)
        for pred, gt in zip(torch.split(inputs, 1), torch.split(targets, 1))
    ])


# Save a tensor as a NIfTI file inside a per-case sub-folder.
def save_nifti(tensor_img, meta_data, data_dir, file_name = None):
    """Write ``tensor_img`` to ``<data_dir>/<case>/<file_name>``.

    ``<case>`` is the name of the folder containing ``meta_data['filename_or_obj']``.
    When ``file_name`` is not given (or empty), that source file's basename is used.
    The affine is taken from ``meta_data['affine']``. Missing folders are created.
    """
    source_path = meta_data['filename_or_obj']
    case_name = os.path.basename(os.path.dirname(source_path))
    if not file_name:
        file_name = os.path.basename(source_path)

    case_dir = os.path.join(data_dir, case_name)
    # creates data_dir as well when needed; no error if the folders already exist
    os.makedirs(case_dir, exist_ok=True)

    full_name = os.path.join(case_dir, file_name)
    # detach() so tensors that still track gradients can be converted to numpy
    nib.Nifti1Image(
        np.array(tensor_img.detach().cpu()), np.array(meta_data['affine'])
    ).to_filename(full_name)

# Per-case record sheets; rows are the case folder names in images_test order
# (e.g. ".../case47/image.nii.gz" -> "case47").
case_index = [path.split('/')[-2] for path in images_test]

# Dice of the averaged output at each binarisation threshold: columns "0.1" ... "0.9".
threshold_dicesheet = pd.DataFrame(
    columns=[f"{thresh:.1f}" for thresh in (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9)],
    index=case_index,
)
# Dice of each single subnet against its matching expert: columns "01" ... "06".
single_net_dicesheet = pd.DataFrame(
    columns=[f"{subnet:02d}" for subnet in range(1, 7)],
    index=case_index,
)
# corr_matrix[j, i] accumulates the Dice of subnet i's mask against expert j's segmentation.
corr_matrix = np.zeros((6, 6))

## evaluation process

In [None]:
# Post-processing transforms.
# One binarising transform per threshold 0.1 ... 0.9; applied to both the
# averaged subnet output and the averaged ground truth.
post_trans_avgd_out = [
    Compose([EnsureType(), AsDiscrete(threshold_values=True, logit_thresh=thresh)])
    for thresh in (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9)
]
# A single subnet's raw output -> sigmoid -> binary mask at 0.5.
post_trans_single_out = Compose(
    [
        EnsureType(),
        Activations(sigmoid=True),
        AsDiscrete(threshold_values=True, logit_thresh=0.5),
    ]
)
# Centre crop to (1, 640, 640); used for inputs that are not already 640 x 640.
cut_trans = Compose([CenterSpatialCrop((1, 640, 640)), EnsureType()])


# Put every subnet into evaluation mode before inference.
for net in model:
    net.eval()

# Spatial size the subnets are run at; keep in sync with the (1, 640, 640) ROI of cut_trans.
_EVAL_HW = (640, 640)


def _pad_to_input_size(cropped, full_hw, roi_hw=_EVAL_HW):
    """Zero-pad a centre-cropped prediction back to the spatial size ``full_hw``.

    Args:
        cropped: tensor whose last two dims are the cropped (H, W).
        full_hw: (H, W) of the uncropped input.
        roi_hw: (H, W) of the crop ROI.

    Returns:
        ``cropped`` zero-padded so that its last two dims equal ``full_hw``.
    """
    padding = []
    # F.pad lists the last dim first: (W_before, W_after, H_before, H_after).
    for full, kept, roi in zip(reversed(tuple(full_hw)),
                               reversed(tuple(cropped.shape[-2:])),
                               reversed(tuple(roi_hw))):
        # NOTE(review): assumes MONAI CenterSpatialCrop starts the crop at
        # max(full // 2 - roi // 2, 0); a 960 x 640 input then gets (0, 0, 160, 160).
        before = max(full // 2 - roi // 2, 0)
        padding += [before, full - kept - before]
    return pad(cropped, padding, "constant", 0)


with torch.no_grad():

    # Running sum of the per-threshold dice rows; divided by the number of cases later.
    mean_dice = torch.zeros((1, len(post_trans_avgd_out)))

    # loader_test is not shuffled, so batches follow the case order of the sheet index.
    for item, case in zip(loader_test, threshold_dicesheet.index):
        img, segs = item[0], item[1]
        # batch_size is 1, so entry 0 of each decollated batch is this case's metadata
        img_meta_data = decollate_batch(item[2])[0]
        segs_meta_data = decollate_batch(item[3])[0]

        if img.shape[0] != 1:
            raise ValueError(f"evaluation expects batch_size=1, got {img.shape[0]}")
        img = img.to(device)
        segs = [seg.to(device) for seg in segs]

        # Sum of the binary subnet masks; averaged below into a soft prediction.
        output = torch.zeros(img.shape[1:], device=device)
        for i, net in enumerate(model):
            if img.shape[-2:] != _EVAL_HW:
                # run on the centre crop, then zero-pad back to the input size
                single_output = post_trans_single_out(net(cut_trans(img)))
                single_output = _pad_to_input_size(single_output, tuple(img.shape[-2:]))
            else:
                single_output = post_trans_single_out(net(img))

            # Dice of this subnet against every expert: column i of the cross matrix.
            dices = [Dice_score(single_output, gt) for gt in segs]
            corr_matrix[:, i] += dices
            # the matching expert's score goes into the single-net sheet
            single_net_dicesheet.loc[case, f'{i+1:02d}'] = dices[i]
            save_nifti(single_output.squeeze(dim=0).permute([1, 2, 0]), segs_meta_data[i], singlenet_dir)

            output += single_output.squeeze(dim=0)
        output.div_(len(model))

        # Continuous ground truth: the average of all expert masks.
        ctn_gt = torch.zeros(img.shape[1:], device=device)
        for gt in segs:
            ctn_gt += gt.squeeze(dim=0)
        ctn_gt.div_(len(segs))

        # Binarise the soft prediction and the soft ground truth at every threshold.
        threshed_output = [trans(output) for trans in post_trans_avgd_out]
        threshed_gt = [trans(ctn_gt) for trans in post_trans_avgd_out]

        # Each map has a leading dim of 1, so cat gives one channel per threshold
        # and the result holds one dice value per threshold.
        dice_row = channel_parallel_Dice_score(torch.cat(threshed_output), torch.cat(threshed_gt))
        mean_dice += dice_row

        # store plain floats rather than 0-d tensors so the CSV contains numbers
        threshold_dicesheet.loc[case] = dice_row.tolist()

        # the mean dice of this case over all thresholds
        print(f"the {case} mean dice:", f"{dice_row.mean().item():.4f}")

        # save the thresholded outputs and ground truths (point_1 ... point_N)
        for t_idx, (out_mask, gt_mask) in enumerate(zip(threshed_output, threshed_gt), start=1):
            save_nifti(out_mask.permute([1, 2, 0]), img_meta_data, threshold_dir, f"output_threshold=point_{t_idx}.nii.gz")
            save_nifti(gt_mask.permute([1, 2, 0]), img_meta_data, threshold_dir, f"gt_threshold=point_{t_idx}.nii.gz")

# Persist the per-case dice tables.
for sheet, sheet_dir, csv_name in (
    (single_net_dicesheet, singlenet_dir, "single_net_dicesheet.csv"),
    (threshold_dicesheet, threshold_dir, "threshold_dicesheet.csv"),
):
    sheet.to_csv(os.path.join(sheet_dir, csv_name))

# Cross-expert dice: the loop summed one value per case, so divide by the case count.
num_test_cases = len(images_test)
corr_matrix = corr_matrix / num_test_cases
np.savetxt(os.path.join(singlenet_dir, "corr_dice_matrix.csv"), corr_matrix, delimiter=',')

# Overall score: per-threshold mean over cases (in place), then the mean over thresholds.
mean_dice.div_(num_test_cases)
print("overall mean dice:", f"{mean_dice.mean().item():.4f}")

the case47 mean dice: 0.9147
the case48 mean dice: 0.5486
the case49 mean dice: 0.9154
