Demonstration code for faster r-cnn object detection. With additional testing functions, to enable testing using PyTest.

In [1]:
# from https://www.learnopencv.com/faster-r-cnn-object-detection-with-pytorch/

Next line is only for this testing environment. wget is not available in Google Golab by default, so need to run shell command to install

In [2]:
!pip install wget



In [3]:
# import necessary libraries
from PIL import Image
import torch
import torchvision
import torchvision.transforms as T
import numpy as np
'''
# following imports are for object_detection_api only
import matplotlib.pyplot as plt
import cv2
'''

'\n# following imports are for object_detection_api only\nimport matplotlib.pyplot as plt\nimport cv2\n'

In [4]:
# following imports are for the test modules
import wget
import json
import pytest

In [5]:
# get the pretrained model from torchvision.models
# Note: pretrained=True will get the pretrained weights for the model.
# model.eval() to use the model for inference
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
model.eval()

FasterRCNN(
  (transform): GeneralizedRCNNTransform(
      Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
      Resize(min_size=(800,), max_size=1333, mode='bilinear')
  )
  (backbone): BackboneWithFPN(
    (body): IntermediateLayerGetter(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): FrozenBatchNorm2d(64)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(64)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(64)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(256)
          (relu): ReLU(inplace=True)
          (downsample)

In [6]:
# Class labels from official PyTorch documentation for the pretrained model
# Note that there are some N/A's
# for complete list check https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/
# we will use the same list for this notebook
COCO_INSTANCE_CATEGORY_NAMES = [
    '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
    'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
    'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
    'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A',
    'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
    'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
    'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
    'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
    'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table',
    'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
    'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book',
    'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
]

In [7]:
def get_prediction(img_path, threshold):
    """
    get_prediction
      parameters:
        - img_path - path of the input image
        - threshold - threshold value for prediction score
      method:
        - Image is obtained from the image path
        - the image is converted to image tensor using PyTorch's Transforms
        - image is passed through the model to get the predictions
        - class, box coordinates are obtained, but only prediction score > threshold
          are chosen.

    """
    img = Image.open(img_path)
    transform = T.Compose([T.ToTensor()])
    img = transform(img)
    pred = model([img])
    pred_class = [COCO_INSTANCE_CATEGORY_NAMES[i] for i in list(pred[0]['labels'].numpy())]
    pred_boxes = [[(i[0], i[1]), (i[2], i[3])] for i in list(pred[0]['boxes'].detach().numpy())]
    pred_score = list(pred[0]['scores'].detach().numpy())
    pred_t = [pred_score.index(x) for x in pred_score if x > threshold][-1]
    pred_boxes = pred_boxes[:pred_t+1]
    pred_class = pred_class[:pred_t+1]
    return pred_boxes, pred_class

API below gives visual representation. Not used.

In [8]:
'''
def object_detection_api(img_path, threshold=0.5, rect_th=3, text_size=3, text_th=3):
    """
    object_detection_api
      parameters:
        - img_path - path of the input image
        - threshold - threshold value for prediction score
        - rect_th - thickness of bounding box
        - text_size - size of the class label text
        - text_th - thichness of the text
      method:
        - prediction is obtained from get_prediction method
        - for each prediction, bounding box is drawn and text is written
          with opencv
        - the final image is displayed
    """
    boxes, pred_cls = get_prediction(img_path, threshold)
    img = cv2.imread(img_path)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    for i in range(len(boxes)):
        cv2.rectangle(img, boxes[i][0], boxes[i][1], color=(0, 255, 0), thickness=rect_th)
        cv2.putText(img, pred_cls[i], boxes[i][0], cv2.FONT_HERSHEY_SIMPLEX,
                    text_size, (0, 255, 0), thickness=text_th)
    plt.figure(figsize=(20, 30))
    plt.imshow(img)
    plt.xticks([])
    plt.yticks([])
    plt.show()
'''

'\ndef object_detection_api(img_path, threshold=0.5, rect_th=3, text_size=3, text_th=3):\n    """\n    object_detection_api\n      parameters:\n        - img_path - path of the input image\n        - threshold - threshold value for prediction score\n        - rect_th - thickness of bounding box\n        - text_size - size of the class label text\n        - text_th - thichness of the text\n      method:\n        - prediction is obtained from get_prediction method\n        - for each prediction, bounding box is drawn and text is written\n          with opencv\n        - the final image is displayed\n    """\n    boxes, pred_cls = get_prediction(img_path, threshold)\n    img = cv2.imread(img_path)\n    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)\n    for i in range(len(boxes)):\n        cv2.rectangle(img, boxes[i][0], boxes[i][1], color=(0, 255, 0), thickness=rect_th)\n        cv2.putText(img, pred_cls[i], boxes[i][0], cv2.FONT_HERSHEY_SIMPLEX,\n                    text_size, (0, 255, 0), 

Next section defines namespace variables to locate the test data. These locations will have to be altered depending upon the environment.

The image file has strict syntax rules:
- expect all images to be .jpg on a different line
- location to be either: 1) relative to TEST_FOLDER;
- or 2) full url
- TEST_FOLDER = ./tests/ in most environments, except Colab
- comments have a hash at beginning of line

In [9]:
TEST_FOLDER = '/content/drive/My Drive/Colab Notebooks/'
IMAGE_FILE = 'images.txt'
DETECTION_THRESHOLD = 0.8

Helper function to take contents of test data file and parse it into a Python list for processing. Includes downloading in required.

Note: the original input text file is unaltered, and the downloaded file is not deleted after use. This means, if this function is run twice, two copies of the file will be downloaded. The OS will handle renaming the downloaded file to avoid overwriting.

In [10]:
def parse_images(test_images=TEST_FOLDER+IMAGE_FILE):
    # parse images list
    try:
        with open(test_images) as file:
            lines = [line.rstrip() for line in file]
        images = [line for line in lines if (line != '' and not(line.startswith('#')))]
    except:
        error_status = 'invalid image file'
        return error_status

    # download into tests folder if not already downloaded
    for i in range(len(images)):
        filename = images[i].rsplit('/', 1)[-1]
        if images[i].startswith('http'):
            try:
                wget.download(images[i], TEST_FOLDER + filename)
                images[i] = TEST_FOLDER + filename
            except:
                pass
        else:
            images[i] = TEST_FOLDER + filename
    
    # if not downloadable, remove from list
    downloaded_images = [image for image in images if not(image.startswith('http'))]
    return downloaded_images

Helper function to run image detection list of images, then return results as a JSON file if successful.

In [11]:
def detect_images(images):
    # run object detection on each image
    detections = []
    for image in images:
        _, labels = get_prediction(image, DETECTION_THRESHOLD)
        detections.append({
            'image_filename': image,
            'detections': labels
        })

    detections_json = json.dumps(detections)
    return detections_json

For Colab only, need to install helper library to run PyTest within a notebook

In [12]:
!pip install ipytest



In [13]:
import ipytest
ipytest.autoconfig()

Test functions to enable PyTest in build pipeline. Fixture (decorator) function inialises as follows:
- loads modules
- creates images list
- runs detection on all images

Test 1: 
- Is the output from detections valid JSON?

Test 2:
- Did all valid images process? (Did all parsed images yield corresponding result?)

Test 3:
- Did every image yield at least one detection?

In [37]:
%%run_pytest[clean]

@pytest.fixture
def images_detected(scope="session"):
    images = parse_images()
    result_json = detect_images(images)
    return images, result_json


def test_valid_json(images_detected):
    images, result_json = images_detected
    error = False
    try:
        r = json.loads(result_json)
        error = False
    except:
        error = True
    assert error == False


def test_confirm_all_processed(images_detected):
    images, result_json = images_detected
    error = False
    if len(images) == len(json.loads(result_json)):
        error = False
    else:
        error = True
    assert error == False


def test_at_least_one_result(images_detected):
    images, result_json = images_detected
    error = False
    print(result_json)
    for result in json.loads(result_json):
        if result['detections'] == []:
            error = True
        else:
            error = error or False
    assert error == False


....                                                                     [100%]
4 passed in 54.48s


These functions will be separated out in the final package:

src:
- object_detection.py: training model, category names, get_prediction, object_detection_api;

tests:
- detect_images.py: parse_images, detect_images;
- test_detections.py: test_prerequisites, test_valid_json, test_confirm_all_processed, test_at_least_one_result;
- images.txt