# Detect and crop images
- using a Faster R-CNN model (ResNet-50 FPN backbone) pre-trained on COCO

In [1]:
import torchvision
import numpy
from PIL import Image
import cv2
import torchvision.transforms as T
import matplotlib.pyplot as plt
import os 
# Build torchvision's Faster R-CNN detector (ResNet-50 + FPN backbone) with
# pretrained weights; every function below uses this module-level `model`.
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
# Switch the model to evaluation mode before running inference.
model.eval()

FasterRCNN(
  (transform): GeneralizedRCNNTransform(
      Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
      Resize(min_size=(800,), max_size=1333, mode='bilinear')
  )
  (backbone): BackboneWithFPN(
    (body): IntermediateLayerGetter(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): FrozenBatchNorm2d(64)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(64)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(64)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(256)
          (relu): ReLU(inplace=True)
          (downsample)

In [2]:
# COCO class names, indexed by the integer label the detector returns
# (get_prediction looks names up with COCO_INSTANCE_CATEGORY_NAMES[label]).
# Only 'dog' is filtered on further down in this notebook.
# The 'N/A' entries are presumably placeholders for unused label ids, kept so
# the list positions stay aligned with the ids -- do not remove or reorder.
COCO_INSTANCE_CATEGORY_NAMES = [
    '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
    'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
    'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
    'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A',
    'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
    'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
    'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
    'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
    'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table',
    'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
    'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book',
    'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
]


In [3]:
# Run the detector on one image and collect the bounding boxes / class names
# whose score exceeds the threshold.
# error is set to 1 when nothing is detected or the image cannot be decoded.
def get_prediction(img_path, threshold):
    '''
    Get the predicted classes and bounding-box coordinates for an image.

    :param img_path: path of the target image (non-ASCII, e.g. Korean, paths are supported)
    :param threshold: a detection is kept only if its score is greater than this value
    :return: tuple (pred_boxes, pred_class, src, error)
        pred_boxes: list of [(x1, y1), (x2, y2)] corner pairs
        pred_class: list of class names, parallel to pred_boxes
        src: copy of the image exactly as decoded by OpenCV (None if decoding failed)
        error: 0 on success, 1 if no detection passed the threshold or the
               image could not be decoded (both lists are empty in that case)
    '''
    import torch  # local import: only needed for no_grad() during inference

    ##### Korean path support #####
    # Read the raw bytes ourselves and decode from memory instead of handing
    # the path to OpenCV; the context manager guarantees the file is closed.
    with open(img_path, "rb") as stream:
        raw = numpy.frombuffer(stream.read(), dtype=numpy.uint8)
    img = cv2.imdecode(raw, cv2.IMREAD_UNCHANGED) if raw.size else None

    # cv2.imdecode returns None for data it cannot decode (not an image, truncated file, ...).
    if img is None:
        print(img_path, ' could not be decoded as an image')
        return [], [], None, 1

    src = img.copy()
    error = 0  # value returned to signal a failure (0 = normal)

    # OpenCV decodes into BGR(A) channel order, while the torchvision
    # detection models expect 3-channel RGB input.
    if img.ndim == 2:
        rgb = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
    elif img.shape[2] == 4:
        rgb = cv2.cvtColor(img, cv2.COLOR_BGRA2RGB)
    else:
        rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    tensor = T.ToTensor()(rgb)  # HWC uint8 array -> CHW float tensor
    with torch.no_grad():  # inference only: no autograd graph needed
        detections = model([tensor])[0]

    pred_class = [COCO_INSTANCE_CATEGORY_NAMES[i] for i in detections['labels'].tolist()]
    pred_boxes = [[(b[0], b[1]), (b[2], b[3])] for b in detections['boxes'].detach().numpy()]
    pred_score = list(detections['scores'].detach().numpy())

    ##### exception handling #####
    # Positions of every score above the threshold. enumerate() is used rather
    # than list.index(), which returns the first occurrence and therefore
    # picks the wrong position when two detections share the same score.
    above = [idx for idx, score in enumerate(pred_score) if score > threshold]
    try:
        pred_t = above[-1]
    except IndexError as e:
        print(e, ' 검출 x')
        return [], [], src, 1

    # Keep everything up to and including the last detection above the threshold.
    return pred_boxes[:pred_t + 1], pred_class[:pred_t + 1], src, error


In [4]:
# When a dog is detected, enlarge its bbox and pass the coordinates on to the crop function.
def object_detection_api(img_path, threshold, rect_th=1, text_size=1, text_th=1, margin=20):
    '''
    Get the bbox coordinates of a 'dog' detection, enlarged by a margin.

    :param img_path: path of the target image
    :param threshold: score threshold forwarded to get_prediction
    :param rect_th, text_size, text_th: debug drawing parameters (default=1);
        kept for interface compatibility, currently unused
    :param margin: number of pixels added on every side of the detected bbox
    :return: tuple (left_top, right_bottom, img, error)
        left_top / right_bottom: integer (x, y) corners, clamped to the image
        img: the image returned by get_prediction
        error: 0 on success; 1 if get_prediction failed or no 'dog' was
               detected, in which case the tuple is (0, 0), (0, 0), 0, 1
    '''
    boxes, pred_cls, img, error = get_prediction(img_path, threshold)  # Get predictions

    ##### get_prediction found nothing #####
    if error == 1:
        return (0, 0), (0, 0), 0, error

    height, width = img.shape[:2]
    boxes_broad_lt = boxes_broad_rd = None

    # NOTE(review): when several dogs are detected, the last matching box wins
    # (unchanged from the original behavior) -- confirm this is intended.
    for (lt, rd), cls_name in zip(boxes, pred_cls):
        if cls_name != 'dog':
            continue
        print('pred_class = ', pred_cls)

        ##### enlarge the bbox, keeping it inside the image #####
        boxes_broad_lt = (max(int(lt[0] - margin), 0), max(int(lt[1] - margin), 0))
        boxes_broad_rd = (min(int(rd[0] + margin), width), min(int(rd[1] + margin), height))

    # Detections passed the threshold but none of them was a dog: report it
    # with the same sentinel as a failed prediction instead of raising
    # UnboundLocalError on the return below.
    if boxes_broad_lt is None:
        return (0, 0), (0, 0), 0, 1

    return boxes_broad_lt, boxes_broad_rd, img, error


In [5]:
# Image write function that supports Korean (non-ASCII) paths.
def imwrite(filename, img, params=None):
    '''
    Encode img in memory and write the encoded bytes to filename.

    :param filename: destination path; its extension selects the encoder
    :param img: image array to encode
    :param params: optional encoder parameters forwarded to cv2.imencode
    :return: True if the file was written, False if encoding failed or any
             exception occurred (the exception is printed, not raised)
    '''
    try:
        _, extension = os.path.splitext(filename)
        ok, encoded = cv2.imencode(extension, img, params)
        if not ok:
            return False
        with open(filename, mode='w+b') as out_file:
            encoded.tofile(out_file)
    except Exception as exc:
        print(exc)
        return False
    return True

In [6]:
# Crop the detected dog out of the image, resize it to size x size and save it.
def crop_save(img_path, dst_name, size=240, threshold=0.8):
    '''
    Crop and save the dog region of an image.

    :param img_path: path of the target image
    :param dst_name: file name the cropped image is saved to
    :param size: side length in pixels of the square output image, default = 240
    :param threshold: detection score threshold forwarded to object_detection_api, default = 0.8
    :return: True if the cropped image was written, False if no dog was
             detected, the crop was empty, or writing failed
    '''
    lt, rd, src, error = object_detection_api(img_path, threshold=threshold)
    if error == 1:  # stop early when detection failed
        return False

    roi = src[lt[1]:rd[1], lt[0]:rd[0]]
    # cv2.resize raises on an empty array, so skip degenerate boxes.
    if roi.size == 0:
        print('dst_name = ', dst_name, ' empty crop, skipped')
        return False

    print('dst_name = ', dst_name)
    dst = cv2.resize(roi, dsize=(size, size), interpolation=cv2.INTER_AREA)
    # Propagate the write status instead of silently dropping it.
    return imwrite(dst_name, dst)

In [7]:
# main code
# Set path_file to the folder that holds one sub-folder per image class.
# NOTE: each image is overwritten in place by its crop (the destination path
# equals the source path), so running this twice crops the already-cropped files.

path_file = './' + 'sample'

for folder_idx, foldername in enumerate(os.listdir(path_file)):
    folder_path = path_file + '/' + foldername
    # Skip stray files in the root folder; os.listdir() would raise on them.
    if not os.path.isdir(folder_path):
        continue
    print(folder_idx, foldername)
    for count, filename in enumerate(os.listdir(folder_path)):
        path_src_file = folder_path + '/' + filename
        # Skip nested directories; only regular files can be opened as images.
        if not os.path.isfile(path_src_file):
            continue
        print(count, filename)
        print(path_src_file)
        crop_save(path_src_file, path_src_file)

0 example
0 믹스견_전남-광양-2020-00276.jpg
./sample/example/믹스견_전남-광양-2020-00276.jpg


	nonzero()
Consider using one of the following signatures instead:
	nonzero(*, bool as_tuple) (Triggered internally at  ..\torch\csrc\utils\python_arg_parser.cpp:766.)
  keep = keep.nonzero().squeeze(1)


pred_class =  ['dog', 'bench']
dst_name =  ./sample/example/믹스견_전남-광양-2020-00276.jpg
1 믹스견_전남-광양-2020-00277.jpg
./sample/example/믹스견_전남-광양-2020-00277.jpg
pred_class =  ['dog', 'bench']
dst_name =  ./sample/example/믹스견_전남-광양-2020-00277.jpg
2 믹스견_전남-광양-2020-00278.jpg
./sample/example/믹스견_전남-광양-2020-00278.jpg
pred_class =  ['dog']
dst_name =  ./sample/example/믹스견_전남-광양-2020-00278.jpg
3 믹스견_전남-광양-2020-00279.jpg
./sample/example/믹스견_전남-광양-2020-00279.jpg
pred_class =  ['dog']
dst_name =  ./sample/example/믹스견_전남-광양-2020-00279.jpg
4 믹스견_전남-구례-2020-00143.jpg
./sample/example/믹스견_전남-구례-2020-00143.jpg
pred_class =  ['bear', 'dog']
dst_name =  ./sample/example/믹스견_전남-구례-2020-00143.jpg
