In [1]:
import os
import cv2
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torch import nn

from nets.deeplabv3_plus import DeepLab
from utils.utils import cvtColor, preprocess_input, resize_image

## Generate validation masks

In [2]:
class Detection(object):
    """Wrap a DeepLab segmentation network and predict per-pixel class-index masks.

    Settings are taken from ``_defaults`` and may be overridden per instance
    with keyword arguments, e.g. ``Detection(cuda=False)``.
    """

    # model_path  : checkpoint file handed to torch.load
    # num_classes : number of classes passed to DeepLab
    # input_shape : [height, width] of the network input
    # cuda        : request GPU inference (honoured only when CUDA is available)
    _defaults = {
        "model_path": 'logs/best_model.pth',
        "num_classes": 2,
        "input_shape": [512, 512],
        "cuda": True,
    }

    def __init__(self, **kwargs):
        """Apply the defaults, then the keyword overrides, then load the network."""
        self.__dict__.update(self._defaults)
        for name, value in kwargs.items():
            setattr(self, name, value)
        # Drawing colours, one entry per class index (not used elsewhere in this class).
        self.colors = [(0, 0, 0), (128, 0, 0)]
        self.generate()

    def generate(self):
        """Build the network, load the checkpoint weights and switch to eval mode."""
        self.net = DeepLab(num_classes=self.num_classes)

        # Resolve the GPU request against what the machine really offers:
        # calling .cuda() without a usable CUDA device raises, so fall back
        # to the CPU instead of crashing when cuda=True on a CPU-only host.
        self.cuda = bool(self.cuda) and torch.cuda.is_available()
        device = torch.device('cuda' if self.cuda else 'cpu')
        self.net.load_state_dict(torch.load(self.model_path, map_location=device))
        self.net = self.net.eval()

        if self.cuda:
            self.net = nn.DataParallel(self.net)
            self.net = self.net.cuda()

    def get_detection(self, image):
        """Segment ``image`` and return the predicted class-index mask.

        Args:
            image: input picture; it is passed through ``cvtColor`` first
                (only RGB prediction is supported, other modes are converted).

        Returns:
            A PIL image built from a uint8 array of per-pixel class indices,
            resized back to the height/width of the input image.
        """
        # Prediction only supports RGB; every other image type is converted to RGB.
        image = cvtColor(image)
        original_h, original_w = np.array(image).shape[:2]

        # resize_image returns the network-sized picture plus the size (nw, nh)
        # of the real image content inside it.
        image_data, nw, nh = resize_image(image, (self.input_shape[1], self.input_shape[0]))
        # HWC -> CHW, then add the batch dimension.
        # NOTE(review): preprocess_input is a project helper; presumably it
        # normalises pixel values - confirm in utils.utils.
        image_data = np.expand_dims(
            np.transpose(preprocess_input(np.array(image_data, np.float32)), (2, 0, 1)), 0)

        with torch.no_grad():
            images = torch.from_numpy(image_data)
            if self.cuda:
                images = images.cuda()

            # First (and only) item of the batch: (classes, H, W).
            pr = self.net(images)[0]
            # Move classes last and turn the scores into per-pixel probabilities.
            pr = F.softmax(pr.permute(1, 2, 0), dim=-1).cpu().numpy()
            # Crop away the gray padding bars so only the real image content remains.
            top = int((self.input_shape[0] - nh) // 2)
            left = int((self.input_shape[1] - nw) // 2)
            pr = pr[top:int(top + nh), left:int(left + nw)]
            # Resize the probability maps back to the original image size.
            pr = cv2.resize(pr, (original_w, original_h), interpolation=cv2.INTER_LINEAR)
            # Pick the most likely class for every pixel.
            pr = pr.argmax(axis=-1)

        return Image.fromarray(np.uint8(pr))

## From confusion matrix (hist) to mIoU

In [3]:
# Notation: each label image has width W and height H
def fast_hist(a, b, n):
    """Build the n x n confusion matrix for one label/prediction pair.

    Args:
        a: flattened ground-truth labels, shape (H*W,).
        b: flattened integer predictions, shape (H*W,).
        n: number of classes.

    Returns:
        An (n, n) integer array whose entry [i, j] counts the pixels with
        label i that were predicted as j; the diagonal holds the correctly
        classified pixels.

    Raises:
        ValueError: if a prediction at a validly-labelled pixel lies outside
            [0, n).
    """
    # Keep only pixels whose label is a valid class index; any other label
    # value is left out of the statistics.
    k = (a >= 0) & (a < n)
    labels = a[k].astype(int)
    preds = b[k]
    # An out-of-range prediction would otherwise be folded into the wrong cell
    # (label 0 / pred n lands on label 1 / pred 0) or break the reshape below
    # with an opaque message, so reject it explicitly.
    if np.any((preds < 0) | (preds >= n)):
        raise ValueError(
            'predictions must lie in [0, {}) for pixels with a valid label'.format(n))
    # Encode each (label, prediction) pair as label * n + prediction, count the
    # occurrences of the n**2 possible codes, then fold them into a matrix.
    return np.bincount(n * labels + preds, minlength=n ** 2).reshape(n, n)

def per_class_iu(hist):
    """Return the intersection-over-union of every class of a confusion matrix.

    For each class the IoU is diagonal / (row sum + column sum - diagonal).
    The denominator is clamped to at least 1, so a class with no pixels at all
    gets an IoU of 0 rather than NaN.
    """
    intersection = np.diag(hist)
    union = hist.sum(1) + hist.sum(0) - intersection
    return intersection / np.maximum(union, 1)

def cal_miou(gt_dir, pred_dir, png_name_list, num_classes):
    """Accumulate a confusion matrix over a set of images and report the mIoU.

    Args:
        gt_dir: directory holding the ground-truth label PNGs.
        pred_dir: directory holding the predicted mask PNGs.
        png_name_list: image names without extension; ``<name>.png`` is read
            from both directories.
        num_classes: number of classes, i.e. the side of the confusion matrix.

    Returns:
        The mean of the per-class IoU values as a fraction in [0, 1]. The same
        value is printed as a percentage rounded to two decimals.
    """
    # Confusion matrix accumulated over all images.
    hist = np.zeros((num_classes, num_classes))
    for name in png_name_list:
        gt_path = os.path.join(gt_dir, name + ".png")
        pred_path = os.path.join(pred_dir, name + ".png")
        # Load the prediction and its label as flat arrays; the context
        # managers release the file handles right away.
        with Image.open(pred_path) as pred_img:
            pred = np.array(pred_img).flatten()
        with Image.open(gt_path) as gt_img:
            label = np.array(gt_img).flatten()
        # A pair whose sizes disagree cannot be compared pixel by pixel: skip it.
        if len(label) != len(pred):
            print(
                'Skipping: len(gt) = {:d}, len(pred) = {:d}, {:s}, {:s}'.format(
                    len(label), len(pred), gt_path, pred_path))
            continue
        hist += fast_hist(label, pred, num_classes)
    # Per-class IoU over the whole set, then the mean across classes.
    # nanmean is kept for safety; per_class_iu as written does not produce NaN.
    per_class = per_class_iu(hist)
    miou = np.nanmean(per_class)
    print('===> miou: ' + str(round(miou * 100, 2)))
    return miou

## From mask to boundary IoU

In [4]:
# Mask -> boundary
def mask_to_boundary(mask, dilation_ratio=0.02):
    """Extract the boundary band of a 2-D mask.

    Args:
        mask: 2-D mask array (its shape is unpacked as height, width).
        dilation_ratio: band thickness as a fraction of the image diagonal.

    Returns:
        ``mask`` minus its eroded version, i.e. the mask pixels that lie
        within the band along the mask contour (G_d intersects G in the paper).
    """
    height, width = mask.shape
    # Band thickness in pixels: a fraction of the image diagonal, at least 1.
    diagonal = np.sqrt(height ** 2 + width ** 2)
    thickness = max(int(round(dilation_ratio * diagonal)), 1)
    # Pad with a 1-pixel border of zeros so a mask cut off by the image border
    # is treated as boundary there as well.
    padded = cv2.copyMakeBorder(mask, 1, 1, 1, 1, cv2.BORDER_CONSTANT, value=0)
    structuring = np.ones((3, 3), dtype=np.uint8)
    eroded = cv2.erode(padded, structuring, iterations=thickness)
    # Strip the padding again to get back to the input size.
    interior = eroded[1 : height + 1, 1 : width + 1]
    return mask - interior

# Compute the boundary IoU
def boundary_iou(gt, dt, dilation_ratio=0.02):
    """Compute the boundary IoU between a ground-truth and a predicted mask.

    Args:
        gt: ground-truth mask, forwarded to ``mask_to_boundary``.
        dt: predicted mask, forwarded to ``mask_to_boundary``.
        dilation_ratio: boundary thickness as a fraction of the image diagonal.

    Returns:
        intersection / union of the two boundary regions, or 0 when neither
        mask has any boundary pixel.
    """
    # Work on boolean maps: multiplying or adding the raw integer boundary maps
    # can wrap around for small integer dtypes (e.g. uint8: 255 + 1 == 0) and
    # silently drop pixels when the masks are not strictly 0/1.
    # Assumes non-negative mask values, as with image masks.
    gt_boundary = mask_to_boundary(gt, dilation_ratio) > 0
    dt_boundary = mask_to_boundary(dt, dilation_ratio) > 0
    intersection = (gt_boundary & dt_boundary).sum()
    union = (gt_boundary | dt_boundary).sum()
    if union < 1:
        return 0
    return intersection / union

## Calculate mIoU and boundary IoU

In [5]:
num_classes = 2
# Folder that contains the dataset.
dataset_path = 'weizmann_horse_db'

# Read the validation image ids. The context manager closes the file, and
# blank lines are skipped so they cannot turn into an empty image id.
with open(os.path.join(dataset_path, "datasets/val.txt"), 'r') as id_file:
    image_ids = [line.strip() for line in id_file if line.strip()]
gt_dir = os.path.join(dataset_path, "mask/")
pred_dir = "detection"

boundary_iou_list = []
# exist_ok avoids the check-then-create race of a separate os.path.exists test.
os.makedirs(pred_dir, exist_ok=True)

Detect = Detection()

for image_id in image_ids:
    # Predict the mask for one validation image.
    image_path = os.path.join(dataset_path, "horse/"+image_id+".jpg")
    with Image.open(image_path) as image:
        pred_image = Detect.get_detection(image)
    # Predictions are written to disk because cal_miou reads them from pred_dir.
    pred_path = os.path.join(pred_dir, image_id + ".png")
    pred_image.save(pred_path)

    # Ground-truth mask for the same image.
    gt_path = os.path.join(dataset_path, "mask/"+image_id+".png")
    with Image.open(gt_path) as img_gt:
        gt = np.array(img_gt)

    # The in-memory prediction equals what was just saved (PNG is lossless),
    # so there is no need to read the file back.
    dt = np.array(pred_image)

    b_iou = boundary_iou(gt, dt, dilation_ratio=0.02)
    boundary_iou_list.append(b_iou)
print("----------------------------")
cal_miou(gt_dir, pred_dir, image_ids, num_classes)  # compute and print the mIoU
print('===> boundary iou: ' + str(round(np.nanmean(boundary_iou_list) * 100, 2)))
print("----------------------------")

----------------------------
===> miou: 94.35
===> boundary iou: 74.62
----------------------------
