In [1]:
import os
import numpy
import cv2
from openpyxl import Workbook 
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns


In [2]:
def cofficent_calculate(pred, gt, threshold=0.5, return_iou=False):
    """Compute the Dice coefficient between a predicted map and a ground-truth mask.

    The prediction is resized to the GT resolution when shapes differ, divided by
    255, min-max normalised when it is not constant, and binarised with
    ``threshold``. The GT is divided by its maximum, so a 0/255 mask becomes 0/1.

    Args:
        pred: 2-D grayscale prediction map; presumably uint8 in [0, 255] as
            returned by ``cv2.imread(..., cv2.IMREAD_GRAYSCALE)`` — TODO confirm.
        gt: 2-D grayscale ground-truth mask of shape (H, W).
        threshold: binarisation threshold applied to the normalised prediction.
        return_iou: if True, return ``(dice, iou)`` instead of only ``dice``.

    Returns:
        The Dice score, or ``(dice, iou)`` when ``return_iou`` is True.
        An all-zero GT yields 0.0 rather than NaN, because the numerator is then
        0. This includes the case where the prediction is also empty; some
        benchmarks score that case as 1.0 instead.

    Raises:
        ValueError: if ``pred`` or ``gt`` is None (e.g. ``cv2.imread`` failed).
    """
    eps = 1e-5

    if pred is None or gt is None:
        raise ValueError('cofficent_calculate received None; an image failed to load')

    if pred.shape != gt.shape:
        # cv2.resize expects the target size as (width, height).
        pred = cv2.resize(pred, (gt.shape[1], gt.shape[0]))

    gt_max = gt.max()
    # A frame without any object has gt.max() == 0; dividing would give NaN everywhere.
    gt = gt / gt_max if gt_max > 0 else gt.astype(float)

    pred = pred / 255
    if pred.max() != pred.min():
        pred = (pred - pred.min()) / (pred.max() - pred.min())

    preds = pred > threshold
    intersection = (preds * gt).sum()
    # Sum of both masks (|P| + |G|), not the set union.
    total = (preds + gt).sum()
    dice = 2 * intersection / (total + eps)
    if return_iou:
        # |P ∪ G| = |P| + |G| - |P ∩ G|
        iou = intersection / (total - intersection + eps)
        return dice, iou
    return dice

In [5]:
# Which SUN-SEG test subset to evaluate; substituted into every path below.
split = 'TestEasyDataset/Unseen'

# Ground-truth masks live at <gt_path>/<case>/<frame>. The trailing slash matters
# because the evaluation loop builds per-case paths with `gt_path + case`.
gt_path = f'/memory/yizhenyu/dataset/SUN/data/SUN-SEG/{split}/GT/'

# Prediction roots, one per compared method; predictions live at <root>/<case>/<frame>.
pred_LSINet = f'/data/yizhenyu/project/video-polyp-seg/SLT-Net_align_Memory/res/cas_long2short/Net_epoch_4_best/SUN-SEG/{split}'
pred_MAST = f'/data/yizhenyu/datasets/pvt_MAST/{split}'
pred_PNSplus = f'/data/yizhenyu/datasets/Benchmark/2022-MIR-PNS+/{split}'
pred_PolypPVT = f'/data/yizhenyu/project/video-polyp-seg/Comparison/Polyp-PVT/results_final/Polyp-PVT/{split}'
pred_SLTNet = f'/data/yizhenyu/project/video-polyp-seg/SLT-Net/res/longterm/SLT-Net/{split}'
pred_PraNet = f'/memory/yizhenyu/results_map/VPS/PraNet/res_epoch_12/PraNet/{split}'
pred_UNet = f'/data/yizhenyu/project/video-polyp-seg/Comparison/pytorch-nested-unet/outputs/unet/{split}'

# Single source of truth for column order: method label -> prediction root.
# Keeping each label next to its directory stops the two derived lists from drifting apart.
_model_roots = {
    'LSINet': pred_LSINet,
    'MAST': pred_MAST,
    'SLTNet': pred_SLTNet,
    'PNSplus': pred_PNSplus,
    'PolypPVT': pred_PolypPVT,
    'PraNet': pred_PraNet,
    'UNet': pred_UNet,
}
# Header row of each output sheet: the frame-name column, then one column per method.
model_names = ['name'] + list(_model_roots.keys())
# Prediction roots, in the same order as model_names[1:].
pred_models = list(_model_roots.values())

In [6]:
# Per-frame Dice for every method, grouped by video case:
#   three_d_data[c][r] = [frame_file_name, dice_method_1, ..., dice_method_K]
#   case_name[c]       = folder name of case c (parallel to three_d_data)
three_d_data = []
case_name = []

# model_names[0] is the frame-name header, so method labels are model_names[1:].
assert len(model_names) == len(pred_models) + 1, 'model_names and pred_models are out of sync'

# sorted(): os.listdir order is arbitrary; sorting makes the sheet order and the
# per-case row order reproducible (frames appear in file-name order).
for case in tqdm(sorted(os.listdir(gt_path))):
    case_dir = os.path.join(gt_path, case)
    # Skip macOS metadata (.DS_Store), other hidden entries and stray non-directories.
    if 'DS_Store' in case or case.startswith('.') or not os.path.isdir(case_dir):
        continue

    case_name.append(case)
    two_d_data = []
    for f in sorted(os.listdir(case_dir)):
        # Hidden files (e.g. .DS_Store) inside a case folder are not masks.
        if f.startswith('.'):
            continue

        gt_file = os.path.join(case_dir, f)
        gt = cv2.imread(gt_file, cv2.IMREAD_GRAYSCALE)
        if gt is None:
            raise FileNotFoundError(f'Could not read ground-truth mask: {gt_file}')

        dices = [f]
        for model_name, model_dir in zip(model_names[1:], pred_models):
            pred_file = os.path.join(model_dir, case, f)
            pred = cv2.imread(pred_file, cv2.IMREAD_GRAYSCALE)
            # cv2.imread returns None instead of raising when a file is missing or unreadable.
            if pred is None:
                raise FileNotFoundError(f'{model_name}: could not read prediction {pred_file}')
            dices.append(cofficent_calculate(pred, gt))

        two_d_data.append(dices)
    three_d_data.append(two_d_data)
        

100%|██████████| 88/88 [31:22<00:00, 21.39s/it]


In [6]:
from openpyxl.styles import PatternFill
import openpyxl
from openpyxl.styles import Color
from openpyxl.formatting.rule import ColorScale, FormatObject, Rule
import numpy as np

from openpyxl.utils import get_column_letter

# Colour scale for the Dice columns: dark red at 0, dark green at 1.
# Fixed numeric stops (rather than per-sheet min/max) keep colours comparable
# across sheets and methods, since Dice lies in [0, 1].
first = FormatObject(type='num', val=0)
last = FormatObject(type='num', val=1)

# One colour per stop: the number of colours must match the number of cfvo entries.
colors = [Color('AA0000'), Color('00AA00')]

# Build the colour-scale rule.
cs = ColorScale(cfvo=[first, last], color=colors)

# Wrap it in a conditional-formatting rule object.
rule = Rule(type='colorScale', colorScale=cs)

assert len(case_name) == len(three_d_data), 'case_name and three_d_data are out of sync'

wb = Workbook()
# Workbook() starts with an empty default sheet; keep a handle so it can be dropped.
default_ws = wb.active
# Column letter of the last Dice column (column A holds the frame name).
last_col = get_column_letter(len(model_names))

for sheet_name, two_d_data in zip(case_name, three_d_data):
    # One sheet per case, titled with the case folder name.
    # NOTE(review): Excel limits titles to 31 chars and forbids []:*?/\ — assumes case names comply.
    ws = wb.create_sheet(title=sheet_name)
    # Header row: 'name' followed by the method labels.
    ws.append(model_names)

    # Data rows: frame file name followed by one Dice score per method.
    for row_data in two_d_data:
        ws.append(row_data)

    # Colour only the score cells: columns B..last, rows 2..(n + 1).
    if two_d_data:
        ws.conditional_formatting.add(f'B2:{last_col}{len(two_d_data) + 1}', rule)

# Drop the empty default sheet, but only if real sheets were added
# (a workbook needs at least one sheet to be saved).
if len(wb.worksheets) > 1:
    wb.remove(default_ws)

# Save the Excel file.
wb.save("Sample_Dice_easy_unseen.xlsx")