In [11]:
import os
import sys
import SimpleITK as sitk
import numpy as np
from PIL import Image
import cv2
from matplotlib import pyplot as plot
import pandas as pd
import xlrd

In [14]:
# Preview: render slice 0 (first index along axis 0) of a 3D U-Net prediction
# volume stored as NIfTI; the bare expression on the last line displays it.
# NOTE(review): Image.fromarray only accepts some dtypes — if this errors,
# cast first as the .npy preview below does.
path = "/Users/WangHao/Desktop/TODO/Data/3dunet/prediction/_20210128075939_0818060.nii.gz"
data = sitk.GetArrayFromImage(
    sitk.ReadImage(path)
)
image = data[0, ...]
Image.fromarray(image)

In [13]:
# Preview: render element 0 along axis 0 of an RVM prediction saved with
# np.save, converted to uint8 so PIL can build an image from it.
# NOTE(review): allow_pickle=True lets a crafted .npy run arbitrary code on
# load — only use it on trusted files.
path = "/Users/WangHao/Desktop/TODO/Data/rvm/prediction/_20210128075939_0818310.npy"
data = np.load(path, allow_pickle=True)
image = data[0].astype(np.uint8)
Image.fromarray(image)

In [114]:
# Compute a per-frame stenosis ratio for every label volume in `result_dir`.
# Per 2D frame: label 1 is the ring mask and label 2 the inner mask; their
# union is the outer-contour mask. The ratio is computed from pixel areas as
# 1 - area(inner) / area(outer). A second variant erodes the outer mask first
# to remove the thickness of a normal vessel wall.
# NOTE(review): label meanings come from the original comments; confirm them
# against the annotation protocol.
model_name = "3dunet"
result_dir = f'/Users/WangHao/Desktop/TODO/Data/{model_name}/label'

# 3x3 square structuring element. The original call passed kernel=(3, 3);
# OpenCV's Python bindings convert a tuple into a 2x1 column matrix, so the
# intended 3x3 erosion never actually happened.
ERODE_KERNEL = np.ones((3, 3), dtype=np.uint8)
ERODE_ITERATIONS = 3
verbose = False  # set True to print both ratios for every frame


def _load_label_volume(path):
    """Load a label volume from a ``.nii.gz`` or ``.npy`` file.

    Iterating the returned array yields the per-frame 2D label images
    (presumably shaped (frames, H, W) — TODO confirm).

    Raises:
        ValueError: if the extension is neither ``.nii.gz`` nor ``.npy``.
    """
    if path.endswith(".nii.gz"):
        return sitk.GetArrayFromImage(sitk.ReadImage(path))
    if path.endswith(".npy"):
        # allow_pickle=True kept from the original; only load trusted files.
        return np.load(path, allow_pickle=True)
    # Raise instead of sys.exit(): inside Jupyter sys.exit only produces a
    # confusing SystemExit report rather than a clear error.
    raise ValueError(f"unsupported data format {os.path.splitext(path)[-1]!r}: {path}")


def _one_minus_ratio(inner_area, outer_area):
    """Return 1 - inner_area / outer_area, or NaN when outer_area is 0."""
    if outer_area == 0:
        return np.nan
    return 1.0 - inner_area / outer_area


def _stenosis_ratios(frame, kernel, iterations):
    """Return (ratio, eroded_ratio) for one 2D label frame.

    ``ratio`` uses the outer mask as-is; ``eroded_ratio`` uses the outer mask
    after ``iterations`` erosions with ``kernel``. Each value is NaN when its
    outer mask is empty (the bare division produced nan/-inf with a warning).
    """
    # 0/1 masks as uint8: cv2.erode only accepts 8U/16U/16S/32F/64F depths,
    # which a raw label array (e.g. int64) may not have.
    inner_mask = (frame == 2).astype(np.uint8)
    outer_mask = ((frame == 1) | (frame == 2)).astype(np.uint8)
    outer_eroded = cv2.erode(outer_mask, kernel=kernel, iterations=iterations)

    inner_area = int(np.count_nonzero(inner_mask))
    return (_one_minus_ratio(inner_area, int(np.count_nonzero(outer_mask))),
            _one_minus_ratio(inner_area, int(np.count_nonzero(outer_eroded))))


# Skip hidden entries such as macOS ".DS_Store".
file_list = sorted(name for name in os.listdir(result_dir) if not name.startswith("."))

pn_all_3 = {}   # file name -> pd.Series of per-frame ratios
pne_all_3 = {}  # file name -> pd.Series of per-frame ratios (eroded outer mask)
for file_name in file_list:
    data = _load_label_volume(os.path.join(result_dir, file_name))
    ratios = [_stenosis_ratios(img, ERODE_KERNEL, ERODE_ITERATIONS) for img in data]

    if verbose:
        for pred_narrow, pred_erode_narrow in ratios:
            print(f'pred_narrow:{pred_narrow:.2f} | pred_erode_narrow:{pred_erode_narrow:.2f}')

    # dtype=float keeps an empty volume from yielding an object-dtype Series.
    pn_all_3[file_name] = pd.Series([r[0] for r in ratios], dtype=float)
    pne_all_3[file_name] = pd.Series([r[1] for r in ratios], dtype=float)

In [None]:
# Optional export of the per-frame stenosis ratios to Excel:
# `<model_name>.xlsx` from `pn_all` and `<model_name>_erode.xlsx` from
# `pne_all`, each with one sheet named "diameter stenosis" (rows = Series
# index, columns = dict keys). Disabled by default; flip EXPORT_EXCEL to run.
# NOTE(review): `pn_all` / `pne_all` are not created by any cell visible here
# (the computation cell fills `pn_all_3` / `pne_all_3`); confirm their source
# before enabling.
EXPORT_EXCEL = False


def _export_stenosis_sheet(results, xlsx_path):
    """Write a mapping of column name -> Series to one Excel sheet; return the DataFrame."""
    table = pd.DataFrame(data=results)
    with pd.ExcelWriter(xlsx_path) as writer:
        table.to_excel(writer, sheet_name="diameter stenosis", index=True,
                       header=True, startrow=0, startcol=0)
    return table


if EXPORT_EXCEL:
    df_pn_all = _export_stenosis_sheet(pn_all, f'{model_name}.xlsx')
    df_pne_all = _export_stenosis_sheet(pne_all, f'{model_name}_erode.xlsx')

In [None]:
# Save one stenosis-vs-frame plot per file for the raw ratios (`pn_all`) and
# for the eroded-mask variant (`pne_all`), marking each curve's peak in red.
# NOTE(review): `pn_all` / `pne_all` are not created by any cell visible here
# (the computation cell fills `pn_all_3` / `pne_all_3`); confirm their source.
save_dir = f"/Users/WangHao/Desktop/TODO/Data/{model_name}/figs"
os.makedirs(save_dir, exist_ok=True)


def _annotate_max(ax, series):
    """Write "(x, peak)" in red at the peak of `series` on `ax`.

    NaN frames are ignored (max() over a list containing NaN is
    order-dependent); nothing is drawn for an empty or all-NaN series.
    """
    values = np.asarray(series.values, dtype=float)
    if values.size == 0 or np.isnan(values).all():
        return
    pos = int(np.nanargmax(values))  # first occurrence of the max, like list.index
    peak_x, peak_y = series.index[pos], float(values[pos])
    # Build the label explicitly: str() of a tuple holding NumPy scalars
    # renders as "np.float64(...)" under NumPy >= 2.
    ax.text(peak_x, peak_y, f"({peak_x}, {round(peak_y, 2)})", color='r')


def _save_stenosis_plot(series, fig_title, out_dir):
    """Plot `series` against its index, save `<out_dir>/<fig_title>.png`, then close it."""
    fig, ax = plot.subplots()
    ax.set(title=fig_title, xlabel='frame', ylabel='diameter stenosis')
    ax.plot(list(series.index), list(series.values))
    _annotate_max(ax, series)
    ax.grid(True)
    fig.savefig(f'{out_dir}/{fig_title}.png')
    # One figure per file: close it so they don't accumulate in memory
    # (results go to disk rather than being shown inline).
    plot.close(fig)


for file_name, series in pn_all.items():
    _save_stenosis_plot(series, file_name + model_name, save_dir)

for file_name, series in pne_all.items():
    _save_stenosis_plot(series, file_name + model_name + 'erode', save_dir)

print("运行完成")  # "Finished running"

In [None]:
# Per-case comparison: a 1x4 row of stenosis-vs-frame curves for transbts
# (`pn_all`), rvm (`pn_all_1`), 3dunet (`pn_all_2`) and label (`pn_all_3`),
# each with its peak marked in red; one PNG per case, named after the
# `pn_all` key up to its first ".".
# NOTE(review): the four dicts are paired by position (insertion order), not
# by key, because file names/extensions may differ between sources. Only
# `pn_all_3` is built by a visible cell — confirm the others come from the
# same sorted case list.
save_dir = f"/Users/WangHao/Desktop/TODO/Data/figs"
os.makedirs(save_dir, exist_ok=True)

comparison_sources = [("transbts", pn_all), ("rvm", pn_all_1),
                      ("3dunet", pn_all_2), ("label", pn_all_3)]


def _annotate_panel_peak(ax, series):
    """Write "(x, peak)" in red at the peak of `series`; skip empty/all-NaN series.

    NaN frames are ignored, and the label is formatted explicitly because
    str() of NumPy scalars in a tuple renders "np.float64(...)" on NumPy >= 2.
    """
    values = np.asarray(series.values, dtype=float)
    if values.size == 0 or np.isnan(values).all():
        return
    pos = int(np.nanargmax(values))  # first occurrence of the max
    peak_x, peak_y = series.index[pos], float(values[pos])
    ax.text(peak_x, peak_y, f"({peak_x}, {round(peak_y, 2)})", color='r')


def _save_comparison_figure(panels, fig_title, out_dir):
    """Draw (panel_title, series) pairs in one row, save `<out_dir>/<fig_title>.png`, close."""
    # rc_context scopes the smaller font to this figure instead of changing
    # the global rcParams for every later cell.
    with plot.rc_context({"font.size": 8}):
        # No figure number: plot.figure(idx) would reactivate an already-open
        # figure with that number (e.g. one left by an earlier cell) and draw
        # these panels on top of it.
        fig, axes = plot.subplots(1, len(panels), figsize=(18, 4), dpi=300, squeeze=False)
        for ax, (panel_title, series) in zip(axes[0], panels):
            ax.set(title=panel_title, xlabel='frame', ylabel='diameter stenosis')
            ax.plot(list(series.index), list(series.values))
            _annotate_panel_peak(ax, series)
            ax.grid(True)
        fig.savefig(f'{out_dir}/{fig_title}.png')
    plot.close(fig)


# Positional pairing is only meaningful if every source has the same cases.
case_counts = {title: len(results) for title, results in comparison_sources}
if len(set(case_counts.values())) != 1:
    raise ValueError(f"sources have different numbers of cases: {case_counts}")

panel_titles = [title for title, _ in comparison_sources]
for case_name, *case_series in zip(pn_all, *(results.values() for _, results in comparison_sources)):
    _save_comparison_figure(list(zip(panel_titles, case_series)),
                            case_name.split(".")[0], save_dir)