In [1]:
from reader.MedSegReader import MedSegSimpleReader
from train_tools.EarlyStopper import EarlyStopper
from train_tools.BestSaver import BestSaver
from reader.ctimageio import *
from metrics.multilabel import *
from net.Unet2D import Unet2D

from module.display import *
import torch
import cv2
import pandas as pd
import numpy as np
from sklearn.model_selection import KFold
import matplotlib.pyplot as plt

In [2]:
# Root directory of the MedSeg liver data, relative to the notebook's working directory.
path: str = "MedSeg/Liver/"
# NOTE(review): isFlip=True is forwarded to MedSegSimpleReader; its exact effect is defined there.
ms_reader = MedSegSimpleReader(path, isFlip=True)

In [3]:
def resizeCTMask(source_mask, resize_shape, num_labels=9):
    """Resize an integer label mask without blending neighbouring labels together.

    Each foreground label ``1..num_labels`` is extracted as a 0/1 uint8 image,
    resized on its own with ``cv2.INTER_LINEAR``, and the results are summed back
    into one label map. Any pixel that more than one label claims after resizing
    is set to 0 (background).

    Args:
        source_mask: 2-D array of integer labels, 0 = background.
        resize_shape: target size in OpenCV ``dsize`` order, i.e. ``(width, height)``
            (the same convention used when the CT slices are passed to ``cv2.resize``).
        num_labels: number of foreground labels; labels above this value are dropped
            (default 9, i.e. labels 1..9).

    Returns:
        float64 array of shape ``(height, width)`` with values in ``0..num_labels``.
    """
    labels = range(1, num_labels + 1)

    # Resize every label separately as a binary uint8 image.
    segments = [
        cv2.resize(np.where(source_mask == label, 1, 0).astype(np.uint8), resize_shape,
                   interpolation=cv2.INTER_LINEAR)
        for label in labels
    ]

    # cv2.resize returns arrays shaped (height, width), i.e. dsize reversed; allocating
    # with resize_shape as-is would only work for square targets.
    out_shape = (resize_shape[1], resize_shape[0])
    coverage = np.zeros(out_shape)      # number of labels claiming each pixel
    resized_mask = np.zeros(out_shape)
    for label, segment in zip(labels, segments):
        coverage += segment
        resized_mask += segment * label

    # A pixel claimed by two or more labels now holds a meaningless sum of labels
    # (possibly > num_labels); reset it to background. Testing ``> 1`` directly also
    # handles 3-way overlaps and masks without any background pixels, which the
    # previous ``== 2`` / unique-count check missed.
    resized_mask[coverage > 1] = 0
    return resized_mask

class MSDataset(torch.utils.data.Dataset):
    """In-memory dataset of 2-D CT slices and their segmentation masks.

    All slices of the requested patients are loaded eagerly in ``__init__``.
    ``__getitem__`` returns ``(img, mask)`` where ``img`` is a float32 tensor of
    shape ``(1, H, W)`` and ``mask`` is a one-hot float tensor of shape
    ``(num_classes, H, W)``.

    Args:
        reader: indexable object; ``reader[i]`` returns ``(ct_slices, ct_masks)`` for
            patient ``i`` (presumably sequences of 2-D numpy arrays — confirm in the reader).
        pt_indices: patient indices to load.
        resize_shape: OpenCV-style ``(width, height)`` to resize every slice/mask to, or
            ``None`` to keep the original resolution. On the resize path slices are also
            clipped to [-160, 240].
        targetonly: if True, drop slices whose mask is entirely 0.
        num_classes: number of one-hot channels, background included (default 10).
    """

    def __init__(self, reader, pt_indices, resize_shape=(256, 256), targetonly=False, num_classes=10):
        self.num_classes = num_classes
        self.imgs = []
        self.masks = []

        for index in pt_indices:
            ct_slices, ct_masks = reader[index]  # all slices of one patient
            # ``is not None`` (not ``!= None``): ``!=`` is ambiguous for array-like shapes.
            if resize_shape is not None:
                ct_slices, ct_masks = self.__resize(ct_slices, ct_masks, resize_shape)

            for ct_slice, ct_mask in zip(ct_slices, ct_masks):
                # Keep only slices that contain at least one target segment.
                if targetonly and np.sum(ct_mask) == 0:
                    continue
                self.imgs.append(ct_slice)
                self.masks.append(ct_mask)

        self.imgs = np.array(self.imgs)
        self.masks = np.array(self.masks)

    def __resize(self, ct_slices, ct_masks, resize_shape):
        """Resize slices (bilinear) and masks (label-preserving) to ``resize_shape``;
        slices are additionally clipped to [-160, 240]."""
        resized_slices, resized_masks = [], []
        for ct_slice, ct_mask in zip(ct_slices, ct_masks):
            ct_slice = cv2.resize(ct_slice, resize_shape, interpolation=cv2.INTER_LINEAR)
            ct_mask = resizeCTMask(ct_mask, resize_shape)
            # NOTE(review): clipping only happens on this resize path, so a dataset built
            # with resize_shape=None yields unclipped slices — confirm this is intended.
            ct_slice = np.clip(ct_slice, -160, 240)

            resized_slices.append(ct_slice)
            resized_masks.append(ct_mask)
        return np.array(resized_slices), np.array(resized_masks)

    def __getitem__(self, idx):
        # Add a channel axis to the slice and one-hot encode the integer label mask.
        img = torch.from_numpy(self.imgs[idx]).unsqueeze(0).type(torch.float32)
        mask = torch.from_numpy(self.masks[idx]).type(torch.int64)
        mask = torch.nn.functional.one_hot(mask, num_classes=self.num_classes).permute(2, 0, 1).float()
        return img, mask

    def __len__(self):
        return len(self.imgs)

In [4]:
# Evaluation requires a CUDA device; fail fast when none is visible.
use_cuda = torch.cuda.is_available()
if not use_cuda:
    # RuntimeError rather than PermissionError: a missing GPU is not an OS permission failure.
    raise RuntimeError("Not detect GPU devices")
device = torch.device("cuda")

In [5]:
# Mini-batch size for the test DataLoader built in the fold loop.
batch_size: int = 16
# Star-imported loss function; evaluate() below does not use it.
criterion = multilabel_dice_loss

def evaluate(model, data_loader, o_dataset, device):
    """Evaluate ``model`` at the ground-truth resolution and return Dice scores.

    The model is run on the batches from ``data_loader``. Its argmax label maps are
    resized back to the resolution of the masks in ``o_dataset`` with
    ``resizeCTMask``, one-hot encoded, and compared against those masks.

    Args:
        model: segmentation network returning per-class scores shaped (B, C, h, w).
        data_loader: yields ``(images, _)`` batches; must visit the same slices in the
            same order as ``o_dataset`` (i.e. ``shuffle=False``).
        o_dataset: yields ``(_, one_hot_mask)`` pairs, each mask shaped (C, H, W).
        device: device the input batches are moved to.

    Returns:
        ``(dice_global, no_bg_dice, label_dice)``: the first two as Python floats,
        ``label_dice`` as returned by ``dice_all_labels``.
    """
    model.eval()

    targets = torch.stack([mask for _, mask in o_dataset])
    # Derive the output geometry from the ground truth instead of assuming 512x512 / 10
    # classes; cv2-style dsize is (width, height).
    num_classes = targets.shape[1]
    target_size = (targets.shape[-1], targets.shape[-2])

    # Run inference on all batches and keep the per-pixel predicted label.
    scores = []
    with torch.no_grad():
        for images, _ in data_loader:
            scores.append(model(images.to(device)).cpu())
    label_maps = torch.vstack(scores).argmax(dim=1).numpy()

    # Bring each predicted label map back to ground-truth resolution and one-hot encode it.
    prediction = []
    for label_map in label_maps:
        resized = resizeCTMask(label_map, target_size)
        resized = torch.from_numpy(resized).type(torch.int64)
        one_hot = torch.nn.functional.one_hot(resized, num_classes=num_classes)
        prediction.append(one_hot.permute(2, 0, 1).float())
    prediction = torch.stack(prediction)

    dice_global = multilabel_dice(prediction, targets)  # global Dice over all labels
    no_bg_dice = dice_coeff_no_bg(prediction, targets)
    label_dice = dice_all_labels(prediction, targets)

    return dice_global.item(), no_bg_dice.item(), label_dice

In [6]:
import os
from pathlib import Path
import openpyxl

result_filename = "2DResult.xlsx"

# Remove a results file left over from a previous run, if there is one ...
if Path(result_filename).is_file():
    Path(result_filename).unlink()

# ... and put a fresh, empty workbook in its place.
workbook = openpyxl.Workbook()
workbook.save(result_filename)

In [7]:
sample_indices = list(range(len(ms_reader)))
# NOTE(review): KFold is unshuffled, so the splits are deterministic; presumably they must
# match the splits the per-fold checkpoints were trained on — do not add shuffling here.
kf = KFold(n_splits=5)
result_columns = ["fold_num", "bg", "S1", "S2", "S3", "S4a", "S4b", "S5", "S6", "S7", "S8", "dice_global", "No_bg"]
fold_rows = []  # one row of values per evaluated fold, in fold order

# Write the header right away so the results file exists even before the first fold finishes.
df = pd.DataFrame(columns=result_columns)
df.to_excel(result_filename, index=False)

for fold_num, (train_indices, test_indices) in enumerate(kf.split(sample_indices)):
    print(f"Fold {fold_num+1}")
    model_save_name = f"save_model/Unet2D/Unet2D_fold{fold_num+1}.pth"

    model = Unet2D(n_classes=10).to(device)

    #################
    #  Test Dataset #
    #################
    # Slices at the dataset's default (resized) resolution, fed to the model ...
    dataset_test = MSDataset(ms_reader, test_indices)
    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=batch_size, shuffle=False)
    # ... and the same slices at original resolution, whose masks serve as ground truth.
    o_dataset_test = MSDataset(ms_reader, test_indices, None)

    # map_location lets a checkpoint saved on another device load onto `device`.
    # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
    model.load_state_dict(torch.load(model_save_name, map_location=device))
    dice_global, no_bg_dice, label_dice = evaluate(model, data_loader_test, o_dataset_test, device)
    print(f"Test dice global: {dice_global}")
    print(f"No Background dice: {no_bg_dice}")
    print(f"Each label dice: {label_dice}")
    print()

    # Accumulate rows in memory and rewrite the sheet after every fold (partial results
    # survive a later failure) instead of round-tripping through read_excel and an
    # index-shifting "prepend" trick. Unpacking label_dice also works if it is an array.
    fold_rows.append([fold_num + 1, *label_dice, dice_global, no_bg_dice])
    df = pd.DataFrame(fold_rows, columns=result_columns)
    df.to_excel(result_filename, index=False)

Fold 1
Test dice global: 0.7506187558174133
No Background dice: 0.7546030282974243
Each label dice: [0.9976194500923157, 0.3883272111415863, 0.7677257657051086, 0.8020117878913879, 0.6446415781974792, 0.7794450521469116, 0.7919529676437378, 0.7889350056648254, 0.7706326842308044, 0.774895429611206]

Fold 2
Test dice global: 0.751479983329773
No Background dice: 0.7265217900276184
Each label dice: [0.9971669316291809, 0.7165870666503906, 0.8325670957565308, 0.7559942603111267, 0.6395816206932068, 0.6679778099060059, 0.7434919476509094, 0.7181137800216675, 0.7400827407836914, 0.7032361626625061]

Fold 3
Test dice global: 0.7467681765556335
No Background dice: 0.7101786136627197
Each label dice: [0.997839093208313, 0.6633767485618591, 0.8101839423179626, 0.8221350908279419, 0.6798381805419922, 0.7739604711532593, 0.7231922745704651, 0.7156186699867249, 0.6066797971725464, 0.6748578548431396]

Fold 4
Test dice global: 0.7192608714103699
No Background dice: 0.7095493078231812
Each label dic