In [58]:
import keras
from keras.models import Model
from keras.preprocessing import image
import numpy as np
import tensorflow as tf
import os
import cv2

from ssd import SSD300
from ssd_utils import BBoxUtility

In [2]:
# Body-part classes the detector was trained on. Detector labels 1..3 map to
# these in order (see `crop`); the extra slot is presumably the background class.
classes = ['upper', 'lower', 'full']
NUM_CLASSES = 1 + len(classes)

# Build the 300x300 RGB-sized SSD and load the trained weights by layer name.
input_shape = (300,
               300,
               3)
model = SSD300(input_shape,
               num_classes=NUM_CLASSES)
model.load_weights('models/ssd_2.28_20180530.h5',
                   by_name=True)

# Decoder that turns raw SSD predictions into (label, conf, box) rows.
bbox_util = BBoxUtility(NUM_CLASSES)

In [80]:
def crop(crop_type, image, result, conf_threshold=0.6):
    """Crop the detected body region of the requested type out of an image.

    # Arguments
    crop_type: "upper" / "lower" / "full"
    image: ndarray indexed as [row (y), column (x), ...]
    result: one detection row: label, conf, xmin, ymin, xmax, ymax, where the
        box coordinates are fractions of the image width/height
    conf_threshold: minimum confidence required to crop (default 0.6)

    # Returns
    image: ndarray, a copy of the region; or None when the detection is below
        the threshold, its label does not match `crop_type`, or the box is
        empty after being clamped to the image
    """
    # Detector label id for each crop type (label 0 is not a crop type).
    label_ids = {"upper": 1.0, "lower": 2.0, "full": 3.0}
    label, conf, xmin, ymin, xmax, ymax = result

    # Written as `not (>=)` so a NaN confidence is rejected, as before.
    if not (conf >= conf_threshold) or label_ids.get(crop_type) != label:
        return None

    height, width = image.shape[:2]

    def to_pixel(frac, size):
        # Clamp to [0, size]: predicted boxes can extend past the border, and a
        # negative index would otherwise wrap around in numpy slicing.
        return min(max(int(round(frac * size)), 0), size)

    x0, x1 = to_pixel(xmin, width), to_pixel(xmax, width)
    y0, y1 = to_pixel(ymin, height), to_pixel(ymax, height)
    if x1 <= x0 or y1 <= y0:
        return None  # empty region: nothing to save

    # Copy only the region rather than the whole frame.
    return image[y0:y1, x0:x1].copy()

In [81]:
m = 0  # number of images cropped so far
crop_type = "upper"  # upper/lower/full: upper body, lower body, or full body crop
crop_dir = 'img/crop'
# cv2.imwrite does not create missing folders, so make sure the target exists.
os.makedirs(crop_dir, exist_ok=True)

for (root, dirs, files) in os.walk('img/test'):
    # Sort so the processing order (and progress output) is reproducible.
    for f in sorted(files):
        path = os.path.join(root, f)
        raw = cv2.imread(path)
        if raw is None:
            # cv2.imread returns None for files it cannot decode (e.g. hidden files).
            continue
        # Not named `image`, so the keras.preprocessing `image` import is not clobbered.
        # NOTE(review): raw cv2 (BGR, uint8) pixels are fed straight to the model;
        # confirm this matches the preprocessing used when training these weights.
        resized = cv2.resize(raw, (300, 300))
        inputs = np.expand_dims(resized, axis=0)
        preds = model.predict(inputs, batch_size=1, verbose=0)
        detections = bbox_util.detection_out(preds)[0]
        if len(detections) == 0:
            continue  # no detection for this image; indexing [0] would fail
        # Only the most likely detection is considered; multiple detections
        # in one image are ignored.
        result = detections[0]
        region = crop(crop_type, resized, result)
        if region is not None:
            cv2.imwrite(os.path.join(crop_dir, "crop_" + f), region)
            m += 1
            if m % 2 == 0:
                print('{0} images have {1} part cropped .'.format(m, crop_type))

2 images have upper part cropped .
4 images have upper part cropped .
6 images have upper part cropped .
8 images have upper part cropped .
