In [None]:
import torch
import os
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm
import torchvision.transforms.v2 as v2
from monai.metrics import HausdorffDistanceMetric, get_confusion_matrix, compute_confusion_matrix_metric, compute_iou
import torch.nn.functional as F

In [None]:
from improved_diffusion.ss_unet import UNetModel_WithSSF
from improved_diffusion.script_util import create_gaussian_diffusion

In [None]:
def square_pad(img, image_size=224):
    '''
    Zero-pad an image to a square, then resize it to image_size x image_size.

    The shorter spatial side is padded on both sides with zeros. When the
    amount of padding is odd, the extra row/column goes to the bottom/right.

    Args:
        img: image tensor whose last two dims are (H, W), e.g. 3*H*W or a
            batched 1*3*H*W tensor (the evaluation loop passes the latter).
        image_size: side length of the returned square image.

    Returns:
        The padded image resized to (..., image_size, image_size). Leading
        dimensions are passed through unchanged; no batch dimension is added.
    '''
    height, width = img.shape[-2], img.shape[-1]
    side = max(height, width)
    missing_rows = side - height
    missing_cols = side - width
    # floor half before / ceil half after, matching inv_square_pad's crop
    top = missing_rows // 2
    bottom = missing_rows - top
    left = missing_cols // 2
    right = missing_cols - left
    pad_then_resize = v2.Compose([
        v2.Pad([left, top, right, bottom], 0, 'constant'),
        v2.Resize((image_size, image_size), antialias=True),
    ])
    return pad_then_resize(img)


In [None]:
def inv_square_pad(img, net_ge_img_mask):
    '''
    Undo square_pad for a network-generated mask.

    The mask is resized back to the padded square size max(H, W), and the
    rows/columns that square_pad added are cropped away. The crop uses the
    same split as square_pad: odd padding puts the extra row/column at the
    bottom/right.

    Args:
        img: the original (un-padded) image. Only its last two dims (H, W)
            are read, to work out how much padding was added.
        net_ge_img_mask: network output mask of shape (..., S, S), e.g.
            N*1*S*S for an ensemble of N samples.

    Returns:
        The mask resized and cropped to (..., H, W).
    '''
    height, width = img.shape[-2], img.shape[-1]
    side = max(height, width)
    top = (side - height) // 2
    left = (side - width) // 2
    # The bottom/right padding is the remainder, so the kept window is exactly
    # height x width starting at (top, left).
    restored = v2.Resize((side, side), antialias=True)(net_ge_img_mask)
    return restored[..., top:top + height, left:left + width]

In [None]:
def edge_map(mask):
    '''
    Extract the inner boundary of a binary mask.

    A foreground pixel (value 1) is marked as an edge when its 3x3
    neighbourhood contains both foreground and background pixels. The image
    border is zero-padded, so foreground pixels on the border count as edges.

    Args:
        mask: 0/1 tensor in a layout accepted by conv2d with one input
            channel, e.g. (N, 1, H, W). Float, integer and bool dtypes are
            accepted.

    Returns:
        The edge map (1 on edges, 0 elsewhere) in the same dtype as `mask`,
        with every size-1 dimension squeezed away. For a single
        (1, 1, H, W) mask the result is a 2-D (H, W) tensor.

    Raises:
        AssertionError: if `mask` contains values other than 0 and 1.
    '''
    assert torch.all((mask == 0) | (mask == 1))  # must be a 0/1 mask

    # conv2d needs a floating dtype, and the kernel must share the input's
    # dtype; count neighbours in a float working dtype.
    work_dtype = mask.dtype if mask.is_floating_point() else torch.float32
    mask_f = mask.to(work_dtype)
    kernel = torch.ones((1, 1, 3, 3), device=mask.device, dtype=work_dtype)
    # padding=1 keeps the spatial size unchanged
    neighbour_sum = torch.nn.functional.conv2d(mask_f, kernel, padding=1)
    # 0 < sum < 9  <=>  the 3x3 window is neither all 0 nor all 1
    mixed = (neighbour_sum > 0) & (neighbour_sum < 9)
    # keep only foreground pixels whose neighbourhood is mixed
    edge = (mixed.to(work_dtype) * mask_f).to(mask.dtype)
    # squeeze so the result can be overlaid as a 2-D map
    return edge.squeeze()

In [None]:
image_size = 256  # side length of the square network input
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# UNet hyper-parameters
model_channnels = 128  # (sic) name kept: the model-construction cell reads it
in_channels = 4
out_channels = 1
num_res_blocks = 1
attn_resolutions = []  # if used, the default is [16]
dropout = 0.0
# Channel multipliers are only configured for 256x256 inputs. Fail fast
# instead of silently handing `None` to the UNet constructor.
_CHANNEL_MULT_BY_SIZE = {256: (1, 1, 2, 2, 4, 4)}
if image_size not in _CHANNEL_MULT_BY_SIZE:
    raise ValueError(
        "no channel_mult configured for image_size={}; supported sizes: {}".format(
            image_size, sorted(_CHANNEL_MULT_BY_SIZE)
        )
    )
channel_mult = _CHANNEL_MULT_BY_SIZE[image_size]
dims = 2
num_classes = None
num_heads = 4  # not used in model
num_heads_upsample = -1  # not used in model
use_checkpoint = False
use_scale_shift_norm = False

# diffusion hyper-parameters
steps = 1000
learn_sigma = False
predict_xstart = False
predict_xstart = False

In [None]:
# Location of the trained Diff-UNet checkpoint.
diff_unet_root = "./"
diff_unet_path = os.path.join(
    diff_unet_root,
    "./final_result/diff_unet_v1_withgan_withss.pt",
)
# 95th-percentile Hausdorff distance, computed with the background channel included.
HD95 = HausdorffDistanceMetric(percentile=95., include_background=True)

In [None]:
# Build the Diff-UNet from the config-cell hyper-parameters and load the trained weights.
DIFF_UNET = UNetModel_WithSSF(
    model_channels=model_channnels,
    in_channels=in_channels,
    out_channels=out_channels,
    channel_mult=channel_mult,
    num_res_blocks=num_res_blocks,
    attention_resolutions=attn_resolutions,
    dropout=dropout,
    dims=dims,
    num_classes=num_classes,
    num_heads=num_heads,
    num_heads_upsample=num_heads_upsample,
    use_checkpoint=use_checkpoint,
    use_scale_shift_norm=use_scale_shift_norm,
)
DIFF_UNET.load_resunet(if_pre=False, in_channels=3)
# map_location lets a checkpoint saved from a GPU load when DEVICE fell back to 'cpu'.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
state_dict = torch.load(diff_unet_path, map_location=DEVICE)
DIFF_UNET.load_state_dict(state_dict=state_dict)
DIFF_UNET.to(DEVICE)

# timestep_respacing="10" presumably respaces the `steps`-step schedule to 10
# sampling steps for fast DDIM inference (use timestep_respacing="" for all steps).
diffusion = create_gaussian_diffusion(steps=steps, timestep_respacing="10", learn_sigma=learn_sigma, predict_xstart=predict_xstart)

In [None]:
from dataset import ISIC_ori_test_Dataset

# Raw string: in a plain literal "\D" and "\I" are invalid escape sequences
# (SyntaxWarning on Python 3.12+), and a segment starting with e.g. "n" or "t"
# would be silently turned into a control character.
test_data_path = r"d:\DATA\ISIC2016"
# `image_size` comes from the config cell; it is not redefined here, so the
# sampling size cannot drift from the size the model was built with.
testdata = ISIC_ori_test_Dataset(test_data_path)
# batch_size=1: presumably the original images have varying sizes and cannot be
# stacked into larger batches (the evaluation loop pads each one individually).
ori_test_loader = DataLoader(testdata, batch_size=1, shuffle=False)
img_save_path = "./final_result/images"
os.makedirs(img_save_path, exist_ok=True)

In [None]:
def thres_bi(x):
    '''
    Binarise `x` at 0.5: 1. where x > 0.5, 0. everywhere else.

    Used as `denoised_fn` in the diffusion sampling call of the evaluation loop.
    '''
    is_foreground = x > 0.5
    return torch.where(is_foreground, 1., 0.)

In [None]:
# Evaluate the segmentation model on every image from `ori_test_loader`.
# For each image a visual summary (input / ground truth / prediction, stacked
# vertically) is written to "test.png". The file is overwritten on every
# iteration, so pause at a breakpoint to inspect the latest result and get a
# rough feel for model performance.
# Set `save_per_image = True` to also keep each prediction as
# "<img_save_path>/<id>_gen.png".

SegNet = DIFF_UNET
# nn.Module defines a `type()` cast method, so `SegNet.type.__name__` would be
# the string "type"; the report needs the model's class name.
model_name = type(SegNet).__name__
num_ensemble = 5  # diffusion samples drawn (and averaged) per test image
save_per_image = False

totaldice = 0.
totalsens = 0.
totalacc = 0.
totalhd95 = 0.
totaliou = 0.
stepValidcnt = 0
hd95_cnt = 0  # images whose HD95 is finite (see below)

with torch.no_grad():
    SegNet.eval()
    for img, real_mask, img_id in tqdm(ori_test_loader):
        img_id = img_id[0]  # batch_size=1: unwrap the collated id
        img, real_mask = img.to(DEVICE), real_mask.to(DEVICE)

        # Sample `num_ensemble` masks for the square-padded image, map them back
        # to the original geometry, and average them into a soft ensemble mask.
        img_pad = square_pad(img, image_size).repeat(num_ensemble, 1, 1, 1)
        mask_shape = (num_ensemble, 1, img_pad.shape[-2], img_pad.shape[-1])
        fake_mask = diffusion.ddim_sample_loop(
            model=SegNet,
            shape=mask_shape,
            denoised_fn=thres_bi,
            clip_denoised=True,
            model_kwargs={'img': img_pad},
            progress=False,
        )
        fake_mask = inv_square_pad(img, fake_mask)
        fake_mask = torch.mean(fake_mask, 0, keepdim=True)

        # HD95 is measured on copies rescaled so the longer side equals
        # `image_size`, i.e. at a common resolution across images.
        scale = image_size / max(fake_mask.shape[-1], fake_mask.shape[-2])
        fake_mask_256 = F.interpolate(fake_mask, scale_factor=scale, mode='bilinear')
        real_mask_256 = F.interpolate(real_mask, scale_factor=scale, mode='nearest')
        # Threshold the ensemble average at 0.5, keeping the original dtype.
        fake_mask_256 = (fake_mask_256 > 0.5).to(fake_mask_256.dtype)
        fake_mask = (fake_mask > 0.5).to(fake_mask.dtype)

        # Metrics
        real_mask_int = real_mask.int()
        conf_mat = get_confusion_matrix(fake_mask, real_mask_int)
        batch_dice = compute_confusion_matrix_metric('f1 score', conf_mat)
        batch_sens = compute_confusion_matrix_metric('sensitivity', conf_mat)
        batch_acc = compute_confusion_matrix_metric('accuracy', conf_mat)
        batch_hd95 = HD95(fake_mask_256, real_mask_256.int())
        batch_iou = compute_iou(fake_mask, real_mask_int)

        # .item(): accumulate plain Python floats rather than device tensors
        totaldice += batch_dice.mean().item()
        totalsens += batch_sens.mean().item()
        totalacc += batch_acc.mean().item()
        totaliou += batch_iou.mean().item()
        # MONAI reports a non-finite HD95 (nan, or inf in older releases) when the
        # prediction or the ground truth is empty. A single such image would turn
        # the average into nan/inf, so HD95 is averaged over finite cases only.
        hd95_val = batch_hd95.mean().item()
        if np.isfinite(hd95_val):
            totalhd95 += hd95_val
            hd95_cnt += 1

        stepValidcnt += 1
        img = torch.clip(img, 0., 1.)  # without this, img may contain values > 1
        saved_img = torch.cat(
            (img[0], real_mask.squeeze(0).repeat(3, 1, 1), fake_mask.squeeze(0).repeat(3, 1, 1)),
            dim=1,
        ).permute(1, 2, 0).cpu().numpy()
        plt.imsave("test.png", np.clip(saved_img, 0., 1.))
        if save_per_image:
            gen_img = fake_mask.squeeze(0).repeat(3, 1, 1).permute(1, 2, 0).cpu().numpy()
            plt.imsave(os.path.join(img_save_path, img_id + "_gen.png"), np.clip(gen_img, 0., 1.))

if stepValidcnt == 0:
    print("{} Model: no test images were evaluated".format(model_name))
else:
    print("{} Model Average Dice is {}".format(model_name, totaldice / stepValidcnt))
    print("{} Model Average Sensitivity is {}".format(model_name, totalsens / stepValidcnt))
    print("{} Model Average Accuracy is {}".format(model_name, totalacc / stepValidcnt))
    hd95_avg = totalhd95 / hd95_cnt if hd95_cnt else float('nan')
    print("{} Model Average HD95 is {} (finite on {}/{} images)".format(model_name, hd95_avg, hd95_cnt, stepValidcnt))
    print("{} Model Average Iou is {}".format(model_name, totaliou / stepValidcnt))