In [11]:
import torchvision
from PIL import Image
import torch

def get_tp(mask, prd):
    # 都是基于tensor计算
    # tp 取共同的像素，矩阵相乘
    tp = mask * prd
    return tp

def get_fp(mask, prd):
    # fp prd中去除mask的部分
    fp = prd * (1 - mask)
    return fp

def get_fn(mask, prd):
    # FN 取mask去掉prd的部分
    fn = mask * (1 - prd)
    return fn

def get_background(image, tp, fp, fn):
    tp_fp_fn = tp + fp + fn
    background = image * (1 - tp_fp_fn)
    return background

def image_gray2RGB(image):
    image_RGB = torch.zeros(size=(image.shape[0] + 2, image.shape[1], image.shape[2]))
    # 维度上复制三次
    image_RGB[0, :, :] = image
    image_RGB[1, :, :] = image
    image_RGB[2, :, :] = image
    return image_RGB

def image_gray2RGBRed(image):
    image_RGB_RED = torch.zeros(size=(image.shape[0] + 2, image.shape[1], image.shape[2]))
    # 维度上复制三次
    image_RGB_RED[0, :, :] = image
    return image_RGB_RED

def image_gray2RGBwhite(image):
    image_RGB = torch.zeros(size=(image.shape[0] + 2, image.shape[1], image.shape[2]))
    # 维度上复制三次
    image_RGB[0, :, :] = image
    image_RGB[1, :, :] = image
    image_RGB[2, :, :] = image
    return image_RGB



def image_gray2RGBGreen(image):
    image_RGB_GREEN = torch.zeros(size=(image.shape[0] + 2, image.shape[1], image.shape[2]))
    # 维度上复制三次
    image_RGB_GREEN[1, :, :] = image
    return image_RGB_GREEN


def image_gray2RGBlue(image):
    image_RGB_BLUE = torch.zeros(size=(image.shape[0] + 2, image.shape[1], image.shape[2]))
    # 维度上复制三次
    image_RGB_BLUE[2, :, :] = image
    return image_RGB_BLUE


def image_gray2RGBYellow(image):
    image_RGB_BLUE = torch.zeros(size=(image.shape[0] + 2, image.shape[1], image.shape[2]))
    # 维度上复制三次
    image_RGB_BLUE[0, :, :] = image
    image_RGB_BLUE[1, :, :] = image
    return image_RGB_BLUE

def converge_image(image_RGB, tp_RGB, fp_RGB, fn_RGB):
    image = image_RGB + tp_RGB + fp_RGB + fn_RGB
    return image

def save_image(image, dst):
    # image的格式为[H,W,C]
    image = torch.clamp(image * 255, 0, 255).permute(1, 2, 0).byte().cpu().numpy()
    image = Image.fromarray(image)  # PIL.Image接受[H,W,C]这样格式图
    image.save(dst)

# def get_tensor_image_label(image_path, label_path):
#     # 转换格式
#     transform2tensor = val_transform
#     image = rgb_loader(image_path)
#     image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
#     label = binary_loader(label_path)
#     processes = transform2tensor(image=image, mask=label)

#     label = processes['mask']
#     label[label > 0] = 255
#     label = (label / 255).long()
#     return processes['image'], label

# def get_tensor_predict(image_path):
#     # 转换格式
#     image_tensor = binary_loader(image_path)
#     image_tensor = torch.tensor(image_tensor)
#     image_tensor = (image_tensor / 255).long()
#     return image_tensor

transformer = torchvision.transforms.Compose([        
    torchvision.transforms.ToPILImage(),
    torchvision.transforms.Resize((224,224)),
    torchvision.transforms.ToTensor(),
])
temp_transformer = torchvision.transforms.Compose([      
    torchvision.transforms.ToPILImage(),
    torchvision.transforms.Resize((224,224)),
    torchvision.transforms.RandomRotation((-90,-90)),
#     torchvision.transforms.RandomVerticalFlip(p=1),
    torchvision.transforms.RandomHorizontalFlip(p=1),
    torchvision.transforms.ToTensor(),
])

def get_tensor_image(image_path,augment=True):
    # 转换格式
    transform2tensor = transformer
    temp = temp_transformer
    
    image = Image.open(image_path)
    image = image.convert("L")
    if augment:
        image_tensor = transform2tensor(image)
    else:
        image_tensor = temp(image)
    return image_tensor


def test1(origin_image_path, mask_image_path, prd_image_path, save_path):
    origin_image_path = origin_image_path
    mask_image_path = mask_image_path
    prd_image_path = prd_image_path
    save_dst_path = save_path

#     origin_image, mask_image = get_tensor_image_label(image_path=origin_image_path, label_path=mask_image_path)
#     prd_image = get_tensor_predict(image_path=prd_image_path)
    origin_image = get_tensor_image(image_path=origin_image_path,augment=True)
    mask_image = get_tensor_image(image_path=mask_image_path,augment=False)
    prd_image = get_tensor_image(image_path=prd_image_path,augment=False)
    

    # 取TP, FP, FN
    tp = get_tp(mask=mask_image, prd=prd_image)
    fn = get_fn(mask=mask_image, prd=prd_image)
    fp = get_fp(mask=mask_image, prd=prd_image)

    # 获取背景
    background_image = get_background(image=origin_image, tp=tp, fp=fp, fn=fn)

    # 转化为RGB，并取一定的颜色
    background_image_RGB = image_gray2RGB(background_image)
    tp_image_GREEN = image_gray2RGBwhite(tp)
    fp_image_RED = image_gray2RGBRed(fp)
    fn_image_Yellow = image_gray2RGBGreen(fn)

    # 图片融合
    image = converge_image(image_RGB=background_image_RGB, tp_RGB=tp_image_GREEN
                           , fp_RGB=fp_image_RED, fn_RGB=fn_image_Yellow)

    # 保存结果
    save_image(image=image, dst=save_dst_path)

In [5]:
import torch
import os
import cv2

# path = 'XJUN_outs\segmentation'
# img_names = os.listdir(path)
# out_path = 'XJUN_all_samples'

path = 'XJUN_outs\segmentation'
img_names = os.listdir(path)
out_path = 'XJUN_all_samples'

# ISIC['AttU_Net', '_base_WCCE_0.6.png', '_deeplabv3plus.png', '_R2AttU_Net.png', '_U__Net.png', '_UNext.png']
# XJUN[]
# for i in img_names:
#     origin_image_path = os.path.join(path, i, i+'_orignal.png')
#     mask_image_path = os.path.join(path, i, i+'_target.png')
#     prd_image_path = os.path.join(path, i, i+'_R2U_Net.png')
    
    
#     save_path = os.path.join(out_path, i, i+'_R2U_Net.png')
#     test1(origin_image_path, mask_image_path, prd_image_path, save_path)

In [6]:
# 移target
for i in img_names:
    mask_image_path = os.path.join(path, i, i+'_target.png')
    
    image = cv2.imread(mask_image_path)
    image = temp_transformer(image)
    save_path = os.path.join(out_path, i, i+'_target.png')
    save_image(image=image, dst=save_path)

In [19]:
# 移热力图
import torch
import os
import cv2

path = 'ISIC_outs\segmentation'
img_names = os.listdir(path)
out_path = 'ISIC_outs\heat'

orignal_path = 'ISIC_outs'

# MHDM_CE  MHDM_WCCE  Unet_bddice
name = 'MHDM_CE'
for i in img_names:
    mask_image_path = os.path.join(orignal_path, name, i+'_heat.png')
    
    image = cv2.imread(mask_image_path)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    image = transformer(image)
    save_path = os.path.join(out_path, i, i+'_'+name+'.png')
    save_image(image=image, dst=save_path)

In [8]:
path = 'XJUN_outs\segmentation'
img_names = os.listdir(path)

for i in img_names:
    os.mkdir(os.path.join('XJUN_all_samples', i))

FileExistsError: [WinError 183] 当文件已存在时，无法创建该文件。: 'XJUN_all_samples\\SK_0004'