In [None]:
import os
import sys
import numpy as np
import tensorflow as tf


from matplotlib import pyplot as plt
from scipy.misc import imread, imresize

# The notebook is pinned to exactly this TensorFlow release: any other version
# (older *or* newer) is rejected, so the error reports what was found instead
# of telling the user to "upgrade".
REQUIRED_TF_VERSION = '1.4.0'
if tf.__version__ != REQUIRED_TF_VERSION:
  raise ImportError('This notebook requires tensorflow v%s, but v%s is installed.'
                    % (REQUIRED_TF_VERSION, tf.__version__))

%matplotlib inline

In [None]:
# Project root: config.json, image_config.json and the training/ outputs are
# resolved relative to it. Set the SKIN_LESION_SEG_DIR environment variable to
# run the notebook on another machine; the original location stays the default.
BASE_DIR = os.environ.get('SKIN_LESION_SEG_DIR', '/home/wenfeng/all-files/skin-lesion-seg-v2')

In [None]:
import crf
import inputs
import my_utils
import evaluation

from sklearn.model_selection import KFold
%load_ext autoreload
%autoreload 2

In [None]:
# Both settings files live directly under BASE_DIR and are read with the same
# loader: `config` comes from config.json, `image_config` from image_config.json.
config, image_config = (
    my_utils.load_config(os.path.join(BASE_DIR, json_name))
    for json_name in ('config.json', 'image_config.json')
)

In [None]:
class RestoredModel:
    """Detection model restored from a serialized ``GraphDef`` file.

    The graph definition is imported into a private ``tf.Graph`` (``self.graph``)
    under a ``/cpu`` device scope, and handles to four of its tensors are kept:

    * ``image_ph``  -- ``image_tensor:0``, the input fed by ``inference_box``
    * ``bboxes``    -- ``detection_boxes:0``
    * ``scores``    -- ``detection_scores:0``
    * ``n_bboxes``  -- ``num_detections:0``
    """

    def __init__(self, ckpt_file):
        """Parse ``ckpt_file`` and import it into ``self.graph``.

        Args:
            ckpt_file: path to a binary serialized ``GraphDef``. Despite the
                parameter name, the file is parsed as a graph definition, not
                restored as a variable checkpoint.
        """
        self.graph = tf.Graph()
        with self.graph.as_default() as g:
            with tf.device('/cpu'):
                od_graph_def = tf.GraphDef()
                # Only the read needs the file handle; release it before importing.
                with tf.gfile.GFile(ckpt_file, 'rb') as fid:
                    od_graph_def.ParseFromString(fid.read())
                # name='' keeps the original tensor names (no import prefix),
                # which the lookups below rely on.
                tf.import_graph_def(od_graph_def, name='')

                self.image_ph = g.get_tensor_by_name('image_tensor:0')
                self.bboxes = g.get_tensor_by_name('detection_boxes:0')
                self.scores = g.get_tensor_by_name('detection_scores:0')
                self.n_bboxes = g.get_tensor_by_name('num_detections:0')

    def inference_box(self, image):
        """Run the detector on ``image`` and return one bounding box.

        Must be called while a session is installed as the default session
        (e.g. inside ``with tf.Session(graph=model.graph):``).

        Args:
            image: array with a ``shape`` attribute. A rank-3 array is treated
                as a single image and gets a leading batch axis; anything else
                is fed unchanged.

        Returns:
            The ``[0, 0]`` entry of the ``detection_boxes`` output, i.e. the
            first box of the first image in the batch.
            NOTE(review): presumably the highest-scoring detection -- confirm
            against the API that exported the graph.

        Raises:
            RuntimeError: if there is no default TensorFlow session.
        """
        sess = tf.get_default_session()
        if sess is None:
            raise RuntimeError(
                'inference_box() needs a default TensorFlow session; call it inside '
                '`with tf.Session(graph=model.graph):`.')
        batch = image[None] if len(image.shape) == 3 else image
        boxes = sess.run(self.bboxes, feed_dict={self.image_ph: batch})
        return boxes[0, 0]

In [None]:
# Cross-validation fold evaluated by this notebook; it selects both the
# exported graph below and the train/test split further down.
fold = 0
PATH_TO_CKPT = os.path.join(
    BASE_DIR,
    'training/train/%d/output_inference_graph.pb/frozen_inference_graph.pb' % fold
)

In [None]:
# Display the resolved path of the exported graph for the selected fold.
PATH_TO_CKPT

In [None]:
# Restore the exported graph; `mm` is the model handle used by every inference cell below.
mm = RestoredModel(PATH_TO_CKPT)

In [None]:
# Load the raw 'dermquest' data, then take the train and test parts of the
# selected fold. Both calls share the fold count and split seed from `config`;
# the test part is requested with type_='test'.
dermquest = inputs.load_raw_data('dermquest', config)
train_data, test_data = (
    inputs.get_kth_fold(dermquest, fold, config['n_folds'], seed=config['split_seed'], **split_kwargs)
    for split_kwargs in ({}, {'type_': 'test'})
)

In [None]:
# Sizes of the full dataset, the training part and the test part of the fold.
len(dermquest), len(train_data), len(test_data)

In [None]:
# Load the same training example with each resize option of
# `inputs.load_one_example` (none, smallest_to, highest_to, size) and print
# the resulting image shape for comparison.
base = train_data.listing[0]
for resize_kwargs in ({}, {'smallest_to': 400}, {'highest_to': 400}, {'size': (400, 400)}):
    img1, _, _ = inputs.load_one_example(base, **resize_kwargs)
    print(img1.shape)

In [None]:
path_base = test_data.listing[31]


def draw_bboxes(ax, bbox_gt, bbox_pred):
    """Overlay the predicted (blue) and ground-truth (red) boxes on ``ax``.

    Both boxes are given as (top, left, height, width).
    """
    for bbox, color, name in ((bbox_pred, 'b', 'predicted'), (bbox_gt, 'r', 'ground truth')):
        top, left, height, width = bbox
        ax.add_patch(plt.Rectangle((left, top), width, height, alpha=0.2, color=color, label=name))
    ax.legend()


image = imread(path_base + '_orig.jpg')
label = imread(path_base + '_contour.png')
# Pixels stored as 255 in the contour file become 1 before the box is computed.
label[label == 255] = 1
bbox_gt = my_utils.calc_bbox(label)

# Only the forward pass needs the (CPU-only) session; plotting happens afterwards.
with mm.graph.as_default() as g:
    with tf.Session(graph=g, config=tf.ConfigProto(device_count={'GPU': 0})):
        bbox_pred = mm.inference_box(image)

# Convert the raw model output once and reuse it for both panels and the IoU.
bbox_pred_tlwh = my_utils.bbox_xy_to_tlwh(bbox_pred, image.shape[:2])

plt.figure(figsize=(20, 10))

ax_image = plt.subplot(211)
ax_image.imshow(image)
draw_bboxes(ax_image, bbox_gt, bbox_pred_tlwh)

# Display only: push the 0-valued pixels to 255 so that, with the gray colormap,
# the pixels relabelled to 1 above render dark on a white background.
label[label == 0] = 255
ax_label = plt.subplot(212)
ax_label.imshow(label, cmap='gray')
draw_bboxes(ax_label, bbox_gt, bbox_pred_tlwh)

plt.show()
print(my_utils.calc_bbox_iou(bbox_gt, bbox_pred_tlwh))

In [None]:
# Bounding-box IoU between ground truth and prediction for every test example.
# Examples whose IoU is below 0.5 are flagged with an arrow in the printout.
with mm.graph.as_default() as g:
    with tf.Session(graph=g, config=tf.ConfigProto(device_count={'GPU': 0})):
        for i, base in enumerate(test_data.listing):
            image, label, bbox_gt = inputs.load_one_example(base)

            # The raw model output is converted to (top, left, height, width)
            # using the image height/width before it is compared.
            bbox_pred = mm.inference_box(image)
            bbox_pred = my_utils.bbox_xy_to_tlwh(bbox_pred, image.shape[:2])
            iou_i = my_utils.calc_bbox_iou(bbox_gt, bbox_pred)
            if iou_i < 0.5:
                print(i, iou_i, '----------->')
            else:
                print(i, iou_i)

In [None]:
def update_dict(target, to_update):
    """Add each value of ``to_update`` onto the same key of ``target`` (in place)."""
    for key in to_update:
        target[key] += to_update[key]


# Confusion-matrix counters accumulated over the whole test fold.
result = {
    'TP': 0,
    'TN': 0,
    'FP': 0,
    'FN': 0
}

with mm.graph.as_default() as g:
    with tf.Session(graph=g, config=tf.ConfigProto(device_count={'GPU': 0})):
        for i, base in enumerate(test_data.listing):
            image, label, bbox_gt = inputs.load_one_example(base, highest_to=600)
            # The second return value is bound to a real name: `_` is IPython's
            # last-output variable and should not be used as a working variable.
            result_i, prediction = evaluation.inference_with_restored_model(mm, image, label,
                                                                            bbox_gt=bbox_gt,
                                                                            verbose=False,
                                                                            times=3,
                                                                            gt_prob=0.51)
            if prediction['IoU'] < 0.5:
                # Flag the example; it is still counted in the totals below.
                print('---->')
            # Accumulate the raw counters first, then extend `result_i` with its
            # derived metrics (handy when inspecting a single example).
            update_dict(result, result_i)
            result_i.update(my_utils.metric_many_from_counter(result_i))

# Derive the metrics from the accumulated counters and show everything.
result.update(my_utils.metric_many_from_counter(result))
print(result)

In [None]:
def show_one_result(image, label, label_pred, bbox_gt, bbox_pred):
    """Plot an image with its boxes, the predicted label and the true label.

    Three stacked panels are drawn: the image with both boxes overlaid, the
    predicted label map, and the ground-truth label map (both in gray).

    Args:
        image: image shown in the top panel.
        label: ground-truth label map (bottom panel).
        label_pred: predicted label map (middle panel).
        bbox_gt: ground-truth box as (top, left, height, width).
        bbox_pred: predicted box as (top, left, height, width).
    """
    plt.figure(figsize=(30, 20))

    ax_image = plt.subplot(311)
    ax_image.imshow(image)
    ax_image.set_title('image with bounding boxes')
    # Same colour coding as the single-image comparison cell of this notebook:
    # ground truth in red, prediction in blue.
    for bbox, color, name in ((bbox_gt, 'r', 'ground truth'), (bbox_pred, 'b', 'predicted')):
        top, left, height, width = bbox
        ax_image.add_patch(plt.Rectangle((left, top), width, height, alpha=0.2, color=color, label=name))
    ax_image.legend()

    ax_pred = plt.subplot(312)
    ax_pred.imshow(label_pred, cmap='gray')
    ax_pred.set_title('predicted label')

    ax_label = plt.subplot(313)
    ax_label.imshow(label, cmap='gray')
    ax_label.set_title('ground-truth label')

    plt.show()

In [None]:
# Open a session on the restored graph with GPUs disabled and enter its context
# by hand, so that it stays the default session across the following cells
# until `sess.__exit__` is called.
cpu_only = tf.ConfigProto(device_count={'GPU': 0})
sess = tf.Session(graph=mm.graph, config=cpu_only)
sess.__enter__()

In [None]:
%%timeit -r 5 -n 1
image, label, bbox_gt = inputs.load_one_example(test_data.listing[7], highest_to=600)
result, prediction = evaluation.inference_with_restored_model(mm, image, label, bbox_gt, times=3, gt_prob=0.51)
# print(result)
label_pred, bbox_pred = prediction['label'], prediction['bbox']
# show_one_result(image, label, label_pred, bbox_gt, bbox_pred)

In [None]:
# Leave the session context that was entered by hand with `sess.__enter__()`;
# the three None arguments signal that no exception is being propagated.
sess.__exit__(None, None, None)

In [None]:
# Shape of the predicted label map (`label_pred` is taken from `prediction['label']`).
label_pred.shape