## Mask R-CNN - Inspect Trained Model

Code and visualizations to test, debug, and evaluate the Mask R-CNN model.

In [None]:
import os
import sys
import random
import math
import re
import time
import colorsys
import numpy as np
import cv2
import tensorflow as tf

from skimage.measure import find_contours
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import matplotlib.lines as lines
from matplotlib.patches import Polygon
import IPython.display


import utils
import visualize
from visualize import display_images
import model_res18_ose as modellib
from model import log
%matplotlib inline 

# Restrict CUDA to physical GPU 3; TensorFlow presumably sees it as /gpu:0.
os.environ.update({'CUDA_VISIBLE_DEVICES': '3'})

# Project layout, anchored at the notebook's working directory.
ROOT_DIR = os.getcwd()
MODEL_DIR = os.path.join(ROOT_DIR, "logs")      # logs and trained checkpoints
RESULT_DIR = os.path.join(ROOT_DIR, "results")  # saved results

# Trained weights. The COCO file has to be downloaded into ROOT_DIR
# (see the README); the CXR checkpoint is read from MODEL_DIR.
COCO_MODEL_PATH = os.path.join(ROOT_DIR, "mask_rcnn_coco.h5")
CXR_MODEL_PATH = os.path.join(MODEL_DIR, "res18_kmeans1.h5")

## Configurations

In [None]:
# Project-local `cxr` module (presumably chest X-ray): supplies the base
# configuration here and the CxrDataset class used to load validation data.
import cxr
config = cxr.CxrConfig()


In [None]:
# Inference flavour of the training configuration: every setting is
# inherited, except that a batch is one GPU processing one image.
class InferenceConfig(type(config)):
    """Training configuration overridden for single-image inference."""
    IMAGES_PER_GPU = 1
    GPU_COUNT = 1


config = InferenceConfig()
config.display()

## Notebook Preferences

In [None]:
# How the model is inspected: 'inference' or 'training'.
# TODO: code for 'training' test mode not ready yet
TEST_MODE = "inference"

# Device the network is built on ("/cpu:0" or "/gpu:0"). Choose the CPU when
# the GPU of this machine is busy training another model.
DEVICE = "/gpu:0"

In [None]:
def get_ax(rows=1, cols=1, size=16):
    """Create the Matplotlib axes grid used by the notebook's plots.

    Centralizes figure sizing: every grid cell is ``size`` inches on a side,
    so the figure is ``size * cols`` wide and ``size * rows`` tall.

    Returns the axes object(s) produced by ``plt.subplots``.
    """
    width, height = size * cols, size * rows
    axes = plt.subplots(rows, cols, figsize=(width, height))[1]
    return axes

## Load Validation Dataset

In [None]:
# Build the validation dataset for the configured task.
# NOTE(review): the 'shapes' and 'coco' branches reference `shapes`, `coco`
# and `COCO_DIR`, none of which are imported/defined in this notebook, so
# only the 'cxr' branch can currently run.
if config.NAME == 'shapes':
    dataset = shapes.ShapesDataset()
    dataset.load_shapes(500, config.IMAGE_SHAPE[0], config.IMAGE_SHAPE[1])
elif config.NAME == "coco":
    dataset = coco.CocoDataset()
    dataset.load_coco(COCO_DIR, "minival")
elif config.NAME == "cxr":
    dataset = cxr.CxrDataset()
    dataset.load_cxr(txt='/media/Disk/wangfuyu/Mask_RCNN/data/cxr/val_id.txt')
else:
    # Fail loudly: otherwise `dataset` below is undefined (NameError) or,
    # on a notebook re-run, silently a stale dataset from an earlier run.
    raise ValueError("Unsupported config.NAME: {!r}".format(config.NAME))

# Must call before using the dataset
dataset.prepare()

print("Images: {}\nClasses: {}".format(len(dataset.image_ids), dataset.class_names))

## Load Model

In [None]:
# Create model in inference mode
with tf.device(DEVICE):
    model = modellib.MaskRCNN(mode="inference", model_dir=MODEL_DIR,
                              config=config)

# Pick the weights file for the configured task.
# NOTE(review): SHAPES_MODEL_PATH is not defined in this notebook.
if config.NAME == "shapes":
    weights_path = SHAPES_MODEL_PATH
elif config.NAME == "coco":
    weights_path = COCO_MODEL_PATH
elif config.NAME == "cxr":
    weights_path = CXR_MODEL_PATH
else:
    # Without this, `weights_path` is undefined or left over from a previous run.
    raise ValueError("No weights configured for config.NAME={!r}".format(config.NAME))

# Or, uncomment to load the last model you trained
# weights_path = model.find_last()[1]

# Load weights
print("Loading weights ", weights_path)
model.load_weights(weights_path, by_name=True)

## Run Detection

In [None]:
# Compute IoU: output locations used by this cell.
visual_dir = './visual/cxr'
maskrcnn_results_dir = '/media/Disk/wangfuyu/Mask_RCNN/results_255/cxr'

# exist_ok=True replaces the os.path.exists()/os.makedirs() pair, which is
# racy (the directory can appear between the check and the create).
os.makedirs(visual_dir, exist_ok=True)
os.makedirs(maskrcnn_results_dir, exist_ok=True)

def random_colors(N, bright=True):
    """Return ``N`` RGB colour tuples in random order.

    Hues are spaced evenly around the HSV wheel at full saturation, with
    value 1.0 when ``bright`` is truthy and 0.7 otherwise, which keeps the
    colours visually distinct. They are converted to RGB via ``colorsys``
    and shuffled with the global ``random`` generator.
    """
    value = 0.7
    if bright:
        value = 1.0
    palette = [colorsys.hsv_to_rgb(k / N, 1, value) for k in range(N)]
    random.shuffle(palette)
    return palette


def apply_mask(image, mask, color, alpha=0.5):
    """Alpha-blend ``color`` into ``image`` wherever ``mask == 1``.

    ``color`` components are scaled by 255 before blending, so they are
    expected in [0, 1]. Channels 0-2 of ``image`` are overwritten in place
    (with the usual NumPy cast back to ``image``'s dtype), and the same
    array is returned.
    """
    covered = mask == 1
    for channel in range(3):
        plane = image[:, :, channel]
        blended = plane * (1 - alpha) + alpha * color[channel] * 255
        image[:, :, channel] = np.where(covered, blended, plane)
    return image



def display_instances(image, boxes, masks, class_ids, class_names,
                      save_dir, scores=None, title="",
                      figsize=(16, 16), ax=None):
    """Draw detected instances on a Matplotlib axis and return the blended image.

    Each instance gets a dashed bounding box, a "<class> <score>" caption
    and its mask outline drawn on ``ax``; the masks are also alpha-blended
    into a copy of ``image``, which is returned.

    image: image array; its first two dims are (height, width) and its
        first three channels are blended.
    boxes: [num_instances, (y1, x1, y2, x2)] in image coordinates.
    masks: [height, width, num_instances]
    class_ids: [num_instances]
    class_names: list of class names of the dataset
    save_dir: destination path for the figure. NOTE(review): currently
        unused -- the ``fig.savefig`` call is disabled and callers write
        the returned image themselves.
    scores: (optional) confidence scores for each box
    title: (optional) axis title
    figsize: (optional) figure size used when ``ax`` is not given
    ax: (optional) Matplotlib axis to draw on; a new figure is created
        when None.

    Returns the mask-blended image as uint8.

    Raises ValueError if boxes, masks and class_ids disagree on the number
    of instances.
    """
    # Number of instances
    N = boxes.shape[0]
    if not N:
        print("\n*** No instances to display *** \n")
    elif not boxes.shape[0] == masks.shape[-1] == class_ids.shape[0]:
        # Explicit error instead of `assert`, which is stripped under -O.
        raise ValueError(
            "boxes, masks and class_ids disagree on instance count: "
            "{}, {}, {}".format(boxes.shape[0], masks.shape[-1], class_ids.shape[0]))

    # `is None` rather than `not ax`: the truth value of an array of Axes
    # is ambiguous and raises.
    if ax is None:
        fig = plt.figure(figsize=figsize)
        ax = fig.add_subplot(1, 1, 1)

    # Generate random colors
    colors = random_colors(N)

    # Show area outside image boundaries.
    height, width = image.shape[:2]
    ax.set_ylim(height + 10, -10)
    ax.set_xlim(-10, width + 10)
    ax.axis('off')
    ax.set_title(title)

    masked_image = image.astype(np.uint32).copy()
    for i in range(N):
        color = colors[i]

        # Bounding box
        if not np.any(boxes[i]):
            # Skip this instance. Has no bbox. Likely lost in image cropping.
            continue
        y1, x1, y2, x2 = boxes[i]
        p = patches.Rectangle((x1, y1), x2 - x1, y2 - y1, linewidth=2,
                              alpha=0.7, linestyle="dashed",
                              edgecolor=color, facecolor='none')
        ax.add_patch(p)

        # Label. `is not None` so a legitimate score of 0.0 is still shown.
        label = class_names[class_ids[i]]
        score = scores[i] if scores is not None else None
        caption = "{} {:.3f}".format(label, score) if score is not None else label
        ax.text(x1, y1 + 8, caption,
                color='w', size=11, backgroundcolor="none")

        # Mask
        mask = masks[:, :, i]
        masked_image = apply_mask(masked_image, mask, color)

        # Mask Polygon
        # Pad to ensure proper polygons for masks that touch image edges.
        padded_mask = np.zeros(
            (mask.shape[0] + 2, mask.shape[1] + 2), dtype=np.uint8)
        padded_mask[1:-1, 1:-1] = mask
        for verts in find_contours(padded_mask, 0.5):
            # Subtract the padding and flip (y, x) to (x, y)
            verts = np.fliplr(verts) - 1
            ax.add_patch(Polygon(verts, facecolor="none", edgecolor=color))

    # fig.savefig(save_dir)
    return masked_image.astype(np.uint8)

def _dice(y_pred, y_true):
    smooth = 1.
    y_true_f = y_true.flatten()
    y_pred_f = y_pred.flatten()
    intersection = (y_true_f * y_pred_f).sum()
    score = (2. * intersection + smooth) / (y_true_f.sum() + y_pred_f.sum() + smooth)
    return score

def _fast_hist(label_pred, label_true, num_classes):
    mask = (label_true >= 0) & (label_true < num_classes)
    hist = np.bincount(
        num_classes * label_true[mask].astype(int) +
        label_pred[mask], minlength=num_classes ** 2).reshape(num_classes, num_classes)
    return hist

def compute_IOU(lp, lt, num_classes=2):
    hist = np.zeros((num_classes, num_classes))
    hist += _fast_hist(lp.flatten(), lt.flatten(), num_classes)
    dice = _dice(y_pred=lp, y_true=lt)
    # axis 0: gt, axis 1: prediction
    acc = np.diag(hist).sum() / hist.sum()
    acc_cls = np.diag(hist) / hist.sum(axis=1)
    acc_cls = np.nanmean(acc_cls)
    iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist))
    # mean_iu = np.nanmean(iu)
    freq = hist.sum(axis=1) / hist.sum()
    fwavacc = (freq[freq > 0] * iu[freq > 0]).sum()

    return acc, acc_cls, iu, fwavacc, dice


# Compute VOC-style Average Precision, plus per-image foreground IoU / Dice.
def compute_batch_ap(image_ids, iou_threshold=0.5, txt_path='res18_kmeans2.txt'):
    """Run detection on each image, collect AP and log foreground IoU / Dice.

    For every image the predicted instance masks are OR-ed into one binary
    map and scored against the summed ground-truth masks. Per-image rows and
    summary statistics (overall, and split by filenames starting with 'CA'
    vs. the rest) are written to ``txt_path``.

    Args:
        image_ids: dataset image ids to evaluate.
        iou_threshold: box IoU threshold passed to ``utils.compute_ap``.
        txt_path: path of the text report (overwritten).

    Returns:
        List with the AP of each image, in ``image_ids`` order.
    """
    def _mean_or_nan(values):
        # np.stack/np.mean fail or warn on an empty group (e.g. a split
        # without any 'CA' images); report nan for that group instead.
        return np.mean(values) if len(values) else float('nan')

    APs = []
    iu, dice = np.zeros(len(image_ids)), np.zeros(len(image_ids))
    iu_normal, iu_abnormal = [], []
    dice_normal, dice_abnormal = [], []

    # `with` guarantees the report is flushed and closed, even on failure.
    with open(txt_path, 'w') as txt:
        txt.write('filename' + ' ' + 'image_id' + ' ' + 'iu' + ' ' + 'dice' + '\n')

        for index, image_id in enumerate(image_ids):
            # Load image
            image, image_meta, row_wave, col_wave, gt_class_id, gt_bbox, gt_mask =\
                modellib.load_image_gt(dataset, config,
                                       image_id, use_mini_mask=False)

            # Run object detection; this model variant also takes the
            # row_wave / col_wave arrays returned by load_image_gt.
            results = model.detect([image], np.stack([row_wave]), np.stack([col_wave]), verbose=0)
            r = results[0]

            # Compute AP
            AP, precisions, recalls, overlaps =\
                utils.compute_ap(gt_bbox, gt_class_id,
                                 r['rois'], r['class_ids'], r['scores'], iou_threshold)
            APs.append(AP)

            filename = dataset.image_info[image_id]['filename']

            # Collapse instances into single maps: ground truth by summing
            # over the instance axis, prediction by OR-ing all masks. The
            # prediction canvas follows the ground-truth size instead of a
            # hard-coded 1024x1024.
            gt_mask = np.sum(gt_mask, axis=2)
            pred_masks = r['masks']
            if pred_masks.ndim == 3 and pred_masks.shape[:2] == gt_mask.shape:
                pred_mask = np.any(pred_masks, axis=2).astype(np.int64)
            else:
                # Unexpected mask layout: score it as an empty prediction.
                pred_mask = np.zeros(gt_mask.shape, dtype=np.int64)

            evals = compute_IOU(pred_mask, gt_mask)
            iu[index] = evals[2][1]
            dice[index] = evals[4]
            # Filenames starting with 'CA' are counted as abnormal cases.
            if filename[0:2] == 'CA':
                iu_abnormal.append(iu[index])
                dice_abnormal.append(dice[index])
            else:
                iu_normal.append(iu[index])
                dice_normal.append(dice[index])

            print(filename, image_id, iu[index])
            txt.write(filename + ' ' + str(image_id) + ' ' + str(iu[index])[0:8]
                      + ' ' + str(dice[index])[0:8] + '\n')

        print(iu.mean(), iu.std(), dice.mean(), dice.std())
        txt.write('iu.mean ' + str(iu.mean())[0:8] +
                  ' iu_abnormal ' + str(_mean_or_nan(iu_abnormal))[0:8] +
                  ' iu_normal ' + str(_mean_or_nan(iu_normal))[0:8] + '\n')
        txt.write('dice.mean ' + str(dice.mean())[0:8] +
                  ' dice_abnormal ' + str(_mean_or_nan(dice_abnormal))[0:8] +
                  ' dice_normal ' + str(_mean_or_nan(dice_normal))[0:8] + '\n')
        # Space-separated so the std values are not fused to their labels.
        txt.write('iu.mean ' + str(iu.mean())[0:8] + ' iu.std ' + str(iu.std())[0:8] + '\n')
        txt.write('dice.mean ' + str(dice.mean())[0:8] + ' dice.std ' + str(dice.std())[0:8] + '\n')
    return APs

# Evaluate the full validation set. `image_ids` holds 10 ids sampled at
# random (with replacement) for quick spot checks -- pass it instead of
# dataset.image_ids to evaluate only that subset.
AP_IOU_THRESHOLD = 0.5
image_ids = np.random.choice(dataset.image_ids, 10)
APs = compute_batch_ap(dataset.image_ids, iou_threshold=AP_IOU_THRESHOLD)
# Label derived from the threshold actually used (it used to say "IoU=80").
print("mAP @ IoU={}: ".format(AP_IOU_THRESHOLD), np.mean(APs))

In [None]:
# Output folders for the cropped images, ground-truth masks and Mask R-CNN masks.
images_dir = '/media/Disk/wangfuyu/Mask_RCNN/crop_results/cxr/res18_ose_kmeans1/512_320/images/'
mrcnn_masks_dir = '/media/Disk/wangfuyu/Mask_RCNN/crop_results/cxr/res18_ose_kmeans1/512_320/mrcnn_masks/'
masks_dir = '/media/Disk/wangfuyu/Mask_RCNN/crop_results/cxr/res18_ose_kmeans1/512_320/masks/'

# exist_ok=True replaces the racy os.path.exists()/os.makedirs() pairs.
os.makedirs(images_dir, exist_ok=True)
os.makedirs(mrcnn_masks_dir, exist_ok=True)
os.makedirs(masks_dir, exist_ok=True)

# train_txt = open('/media/Disk/wangfuyu/Mask_RCNN/crop_results/cxr/res18_ose_kmeans1/512_320/train_id.txt', 'w')
# val_txt = open('/media/Disk/wangfuyu/Mask_RCNN/crop_results/cxr/res18_ose_kmeans1/512_320/val_id.txt', 'w')
# val_box_info = open('/media/Disk/wangfuyu/Mask_RCNN/crop_results/cxr/res18_ose_kmeans1/512_320/box_info.txt', 'w')    


def _dice(y_pred, y_true):
    smooth = 1.
    y_true_f = y_true.flatten()
    y_pred_f = y_pred.flatten()
    intersection = (y_true_f * y_pred_f).sum()
    score = (2. * intersection + smooth) / (y_true_f.sum() + y_pred_f.sum() + smooth)
    return score

def _fast_hist(label_pred, label_true, num_classes):
    mask = (label_true >= 0) & (label_true < num_classes)
    hist = np.bincount(
        num_classes * label_true[mask].astype(int) +
        label_pred[mask], minlength=num_classes ** 2).reshape(num_classes, num_classes)
    return hist

def compute_IOU(lp, lt, num_classes=2):
    hist = np.zeros((num_classes, num_classes))
    hist += _fast_hist(lp.flatten(), lt.flatten(), num_classes)
    dice = _dice(y_pred=lp, y_true=lt)
    # axis 0: gt, axis 1: prediction
    acc = np.diag(hist).sum() / hist.sum()
    acc_cls = np.diag(hist) / hist.sum(axis=1)
    acc_cls = np.nanmean(acc_cls)
    iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist))
    # mean_iu = np.nanmean(iu)
    freq = hist.sum(axis=1) / hist.sum()
    fwavacc = (freq[freq > 0] * iu[freq > 0]).sum()

    return acc, acc_cls, iu, fwavacc, dice

    

# Crop a padded window around each ground-truth lung and its best detection.
def compute_batch_ap(image_ids, iou_threshold=0.5, pad_ratio=0.25):
    """Crop each ground-truth instance around its best-matching detection.

    For every ground-truth instance mask of every image, the predicted
    instance with the highest foreground IoU is chosen. Its box is enlarged
    by ``pad_ratio`` of the box height/width on each side and clipped to the
    image. The image, the ground-truth mask and the predicted mask are then
    cropped to that window. The crop box is printed; writing the crops to
    ``images_dir`` / ``masks_dir`` / ``mrcnn_masks_dir`` is disabled
    (commented out below).

    Args:
        image_ids: dataset image ids to process.
        iou_threshold: unused; kept for signature compatibility.
        pad_ratio: fraction of the box height/width added on each side.
    """
    # Target crop size for the (disabled) resize step below.
    IMAGE_MIN_DIM = 320
    IMAGE_MAX_DIM = 512
    for image_id in image_ids:
        # Load image
        image, image_meta, row_wave, col_wave, gt_class_id, gt_bbox, gt_mask =\
            modellib.load_image_gt(dataset, config,
                                   image_id, use_mini_mask=False)
        # Run object detection
        mrcnn_result = model.detect([image], np.stack([row_wave]), np.stack([col_wave]), verbose=0)[0]
        pred_masks = mrcnn_result['masks']

        filename = dataset.image_info[image_id]['filename']
        # filename = dataset.image_info[image_id]['filename'][0:-4]    ## JSRT
        img_h, img_w = image.shape[:2]

        for i in range(gt_mask.shape[2]):
            # Choose the detection whose mask best overlaps this ground-truth
            # instance (left or right lung).
            max_iou = 0.
            final_index = None
            for j in range(pred_masks.shape[2]):
                iou = compute_IOU(pred_masks[:, :, j], gt_mask[:, :, i])[2][1]
                if max_iou < iou:
                    max_iou = iou
                    final_index = j

            savename = filename + '_' + str(i)
            if final_index is None:
                # No detection overlaps this instance; previously this fell
                # through with a (0, 0, 0, 0) box and produced empty crops.
                print(savename + ' skipped: no overlapping detection')
                continue

            y1, x1, y2, x2 = (int(v) for v in mrcnn_result['rois'][final_index])
            # Padding comes from the un-padded box so opposite sides grow by
            # the same amount (the old code derived the bottom/right padding
            # from the already-enlarged top/left edge). Clip to the image
            # size instead of a hard-coded 1024.
            pad_h = int((y2 - y1) * pad_ratio)
            pad_w = int((x2 - x1) * pad_ratio)
            y1, y2 = max(y1 - pad_h, 0), min(y2 + pad_h, img_h)
            x1, x2 = max(x1 - pad_w, 0), min(x2 + pad_w, img_w)

            crop_image = image[y1:y2, x1:x2, :]
            crop_mask = gt_mask[y1:y2, x1:x2, i]
            crop_mrcnn_mask = pred_masks[y1:y2, x1:x2, final_index]

            print(savename + ' ' + str(y1) + ' ' + str(x1) + ' ' + str(y2) + ' ' + str(x2))

            # Export steps (disabled); re-enable to write the crops.
            # crop_image = cv2.resize(crop_image.astype(np.uint8), (IMAGE_MIN_DIM, IMAGE_MAX_DIM),
            #                         interpolation=cv2.INTER_LINEAR)
            # crop_mask = cv2.resize(crop_mask.astype(np.uint8), (IMAGE_MIN_DIM, IMAGE_MAX_DIM),
            #                        interpolation=cv2.INTER_LINEAR)
            # crop_mrcnn_mask = cv2.resize(crop_mrcnn_mask.astype(np.uint8), (IMAGE_MIN_DIM, IMAGE_MAX_DIM),
            #                              interpolation=cv2.INTER_LINEAR)
            # val_box_info.write(savename + ' ' + str(y1) + ' ' + str(x1) +
            #                    ' ' + str(y2) + ' ' + str(x2) + '\n')
            # cv2.imwrite(images_dir + savename + '.jpg', crop_image,
            #             [int(cv2.IMWRITE_PNG_COMPRESSION), 9])
            # cv2.imwrite(masks_dir + savename + '.png', crop_mask,
            #             [int(cv2.IMWRITE_PNG_COMPRESSION), 9])
            # cv2.imwrite(mrcnn_masks_dir + savename + '.png', crop_mrcnn_mask,
            #             [int(cv2.IMWRITE_PNG_COMPRESSION), 9])
            # val_txt.write(savename + '\n')
            # train_txt.write(savename + '\n')
            

# Run the crop extraction over the whole validation set (not a random subset).
compute_batch_ap(image_ids=dataset.image_ids, iou_threshold=0.5)
# for image_id in dataset.image_ids:
#     print(image_id, dataset.image_info[image_id]['filename'])





In [None]:
# Compute IoU
import os
import sys
import random
import math
import re
import time
import colorsys
import numpy as np
import cv2
import tensorflow as tf




gt_mask_dir = '/media/Disk/wangfuyu/Mask_RCNN/data/cxr/binary_masks/'
back_mask_dir = '/media/Disk/wangfuyu/Mask_RCNN/refine/HED/cxr/res18_ose_kmeans1/512_320/back/'
iou_dir = '/media/Disk/wangfuyu/Mask_RCNN/refine/HED/cxr/res18_ose_kmeans1/512_320/iou/'

# Create iou_dir (and therefore its parent 512_320/) BEFORE opening
# record.txt in that parent; in the old order the open() failed whenever
# 512_320/ did not exist yet.
os.makedirs(iou_dir, exist_ok=True)

# Report file kept open at module level for compute_batch_ap below to write into.
txt = open('/media/Disk/wangfuyu/Mask_RCNN/refine/HED/cxr/res18_ose_kmeans1/512_320/record.txt', 'w')

# gt_mask_dir = '/media/Disk/wangfuyu/Mask_RCNN/data/cxr/800/JSRT/binary_masks/'
# back_mask_dir = '/media/Disk/wangfuyu/Mask_RCNN/unet/results/crop_preserve/JSRT/800/renet_C5_wave/512_320/back/'
# iou_dir = '/media/Disk/wangfuyu/Mask_RCNN/unet/results/crop_preserve/JSRT/800/renet_C5_wave/512_320/iou/'
# txt = open('/media/Disk/wangfuyu/Mask_RCNN/unet/results/crop_preserve/JSRT/800/renet_C5_wave/512_320/record.txt', 'w')

def _dice(y_pred, y_true):
    smooth = 1.
    y_true_f = y_true.flatten()
    y_pred_f = y_pred.flatten()
    intersection = (y_true_f * y_pred_f).sum()
    score = (2. * intersection + smooth) / (y_true_f.sum() + y_pred_f.sum() + smooth)
    return score

def _fast_hist(label_pred, label_true, num_classes):
    mask = (label_true >= 0) & (label_true < num_classes)
    hist = np.bincount(
        num_classes * label_true[mask].astype(int) +
        label_pred[mask], minlength=num_classes ** 2).reshape(num_classes, num_classes)
    return hist

def compute_IOU(lp, lt, num_classes=2):
    hist = np.zeros((num_classes, num_classes))
    hist += _fast_hist(lp.flatten(), lt.flatten(), num_classes)
    dice = _dice(y_pred=lp, y_true=lt)
    # axis 0: gt, axis 1: prediction
    acc = np.diag(hist).sum() / hist.sum()
    acc_cls = np.diag(hist) / hist.sum(axis=1)
    acc_cls = np.nanmean(acc_cls)
    iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist))
    # mean_iu = np.nanmean(iu)
    freq = hist.sum(axis=1) / hist.sum()
    fwavacc = (freq[freq > 0] * iu[freq > 0]).sum()

    return acc, acc_cls, iu, fwavacc, dice


# Score the refined full-size masks against the ground truth (IoU / Dice).
def compute_batch_ap(image_ids):
    """Score refined lung masks against the ground truth and log IoU / Dice.

    For each image, ``back_mask_dir + <filename> + '.png'`` is read as
    grayscale, binarised at 127, and compared against the sum of the
    ground-truth instance masks. Per-image rows and the overall mean/std are
    written to the module-level ``txt`` report. Each prediction is copied to
    ``iou_dir`` with its IoU appended to the file name.

    Raises:
        FileNotFoundError: if a prediction PNG is missing or unreadable.
    """
    txt.write('filename' + ' ' + 'image_id' + ' ' + 'iu' + ' ' + 'dice' + '\n')

    iu, dice = np.zeros(len(image_ids)), np.zeros(len(image_ids))

    for index, image_id in enumerate(image_ids):
        # Only the ground-truth masks (last return value) are needed here.
        *_, gt_mask = modellib.load_image_gt(dataset, config,
                                             image_id, use_mini_mask=False)
        gt_mask = np.sum(gt_mask, axis=2)

        filename = dataset.image_info[image_id]['filename']
        # filename = dataset.image_info[image_id]['filename'][0:-4]   ## JSRT

        back_path = back_mask_dir + filename + '.png'
        # IMREAD_GRAYSCALE is the imread flag; cv2.COLOR_BGR2GRAY (used
        # before) is a cvtColor code that only happened to be a valid flag.
        back = cv2.imread(back_path, cv2.IMREAD_GRAYSCALE)
        if back is None:
            # imread signals failure by returning None rather than raising.
            raise FileNotFoundError('Cannot read prediction mask: ' + back_path)
        _, back_mask = cv2.threshold(back, 127, 1, cv2.THRESH_BINARY)

        print(back_mask.shape, gt_mask.shape)
        evals = compute_IOU(back_mask, gt_mask)
        iu[index] = evals[2][1]
        dice[index] = evals[4]
        print(filename, iu[index])

        savename = filename + '_' + str(iu[index])[0:8]
        txt.write(filename + ' ' + str(image_id) + ' ' + str(iu[index])[0:8]
                  + ' ' + str(dice[index])[0:8] + '\n')
        cv2.imwrite(os.path.join(iou_dir, savename + '.png'), back,
                    [int(cv2.IMWRITE_PNG_COMPRESSION), 9])

    print(iu.mean(), iu.std(), dice.mean(), dice.std())
    txt.write('iu.mean ' + str(iu.mean())[0:8] + ' ' + str(iu.std())[0:8] + '\n')
    # Second value is the std (the mean was previously written twice).
    txt.write('dice.mean ' + str(dice.mean())[0:8] + ' ' + str(dice.std())[0:8] + '\n')
    # The report stays open at module level; make sure its content hits disk.
    txt.flush()

compute_batch_ap(image_ids=dataset.image_ids)  # score every validation image

In [None]:
# Compute IoU
import os
import sys
import random
import math
import re
import time
import colorsys
import numpy as np
import cv2
import tensorflow as tf




# Ground-truth crops (same folder as `masks_dir` in the crop-export cell) and
# the corresponding refined predictions (presumably the HED refinement
# output); files are paired by identical file name.
gt_mask_dir = '/media/Disk/wangfuyu/Mask_RCNN/crop_results/cxr/res18_ose_kmeans1/512_320/masks/'
box_mask_dir = '/media/Disk/wangfuyu/Mask_RCNN/refine/HED/cxr/res18_ose_kmeans1/512_320/Segresult/'


def _dice(y_pred, y_true):
    smooth = 1.
    y_true_f = y_true.flatten()
    y_pred_f = y_pred.flatten()
    intersection = (y_true_f * y_pred_f).sum()
    score = (2. * intersection + smooth) / (y_true_f.sum() + y_pred_f.sum() + smooth)
    return score

def _fast_hist(label_pred, label_true, num_classes):
    mask = (label_true >= 0) & (label_true < num_classes)
    hist = np.bincount(
        num_classes * label_true[mask].astype(int) +
        label_pred[mask], minlength=num_classes ** 2).reshape(num_classes, num_classes)
    return hist

def compute_IOU(lp, lt, num_classes=2):
    hist = np.zeros((num_classes, num_classes))
    hist += _fast_hist(lp.flatten(), lt.flatten(), num_classes)
    dice = _dice(y_pred=lp, y_true=lt)
    # axis 0: gt, axis 1: prediction
    acc = np.diag(hist).sum() / hist.sum()
    acc_cls = np.diag(hist) / hist.sum(axis=1)
    acc_cls = np.nanmean(acc_cls)
    iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist))
    # mean_iu = np.nanmean(iu)
    freq = hist.sum(axis=1) / hist.sum()
    fwavacc = (freq[freq > 0] * iu[freq > 0]).sum()

    return acc, acc_cls, iu, fwavacc, dice


# Score the refined crops against their ground-truth crops (IoU / Dice).
def compute_batch_ap():
    """Score every refined crop in ``box_mask_dir`` against its ground truth.

    Files are paired by name between ``box_mask_dir`` and ``gt_mask_dir``.
    Both are read as grayscale, and the prediction is binarised at 127.
    Prints the per-file IoU, then the mean foreground IoU and mean Dice over
    all files that could be read.
    """
    iu, dice = [], []

    # Sorted so the log is in a stable order (os.listdir order is arbitrary).
    for filename in sorted(os.listdir(box_mask_dir)):
        print(filename)
        # IMREAD_GRAYSCALE is the imread flag; cv2.COLOR_BGR2GRAY (used
        # before) is a cvtColor code that only happened to be a valid flag.
        gt_mask = cv2.imread(gt_mask_dir + filename, cv2.IMREAD_GRAYSCALE)
        box = cv2.imread(box_mask_dir + filename, cv2.IMREAD_GRAYSCALE)
        if gt_mask is None or box is None:
            # imread returns None for missing/non-image files instead of raising.
            print('skipping {}: could not read ground truth or prediction'.format(filename))
            continue

        print(gt_mask.shape, box.shape)
        _, box_mask = cv2.threshold(box, 127, 1, cv2.THRESH_BINARY)

        evals = compute_IOU(box_mask, gt_mask)
        iu.append(evals[2][1])
        dice.append(evals[4])
        print(filename, evals[2][1])

    if not iu:
        # np.stack on an empty list would raise here.
        print('No masks evaluated in', box_mask_dir)
        return
    print(np.mean(iu), np.mean(dice))
        

# Score every refined crop found in box_mask_dir against its ground-truth crop.
compute_batch_ap()