In [1]:
import os
import time

import cv2
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
from PIL import Image
from torchvision.models.detection import FasterRCNN_ResNet50_FPN_V2_Weights

In [2]:
# Seed NumPy's global RNG so the per-class COLORS drawn later are repeatable.
np.random.seed(42)

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Pretrained weights (COCO_V1) for the Faster R-CNN ResNet50 FPN v2 detector.
weights = FasterRCNN_ResNet50_FPN_V2_Weights.COCO_V1

In [3]:
# Preprocessing pipeline applied to every image before inference:
# a single step that converts the input image into a tensor.
transform = transforms.Compose([transforms.ToTensor()])

In [4]:
 # Display the category names stored in the metadata of the pretrained weights.
 weights.meta["categories"]

['__background__',
 'person',
 'bicycle',
 'car',
 'motorcycle',
 'airplane',
 'bus',
 'train',
 'truck',
 'boat',
 'traffic light',
 'fire hydrant',
 'N/A',
 'stop sign',
 'parking meter',
 'bench',
 'bird',
 'cat',
 'dog',
 'horse',
 'sheep',
 'cow',
 'elephant',
 'bear',
 'zebra',
 'giraffe',
 'N/A',
 'backpack',
 'umbrella',
 'N/A',
 'N/A',
 'handbag',
 'tie',
 'suitcase',
 'frisbee',
 'skis',
 'snowboard',
 'sports ball',
 'kite',
 'baseball bat',
 'baseball glove',
 'skateboard',
 'surfboard',
 'tennis racket',
 'bottle',
 'N/A',
 'wine glass',
 'cup',
 'fork',
 'knife',
 'spoon',
 'bowl',
 'banana',
 'apple',
 'sandwich',
 'orange',
 'broccoli',
 'carrot',
 'hot dog',
 'pizza',
 'donut',
 'cake',
 'chair',
 'couch',
 'potted plant',
 'bed',
 'N/A',
 'dining table',
 'N/A',
 'N/A',
 'toilet',
 'N/A',
 'tv',
 'laptop',
 'mouse',
 'remote',
 'keyboard',
 'cell phone',
 'microwave',
 'oven',
 'toaster',
 'sink',
 'refrigerator',
 'N/A',
 'book',
 'clock',
 'vase',
 'scissors',
 'teddy bear',
 'hair drier',
 'toothbrush']

In [5]:
# Class names indexed by the integer label that the detector predicts
# (`predict` looks names up as COCO_INSTANCE_CATEGORY_NAMES[label]).
# 'N/A' entries are placeholders — presumably label ids that the COCO
# annotations do not use; they keep the remaining names at the right index.
COCO_INSTANCE_CATEGORY_NAMES = [
    # 0-9
    "__background__", "person", "bicycle", "car", "motorcycle",
    "airplane", "bus", "train", "truck", "boat",
    # 10-19
    "traffic light", "fire hydrant", "N/A", "stop sign", "parking meter",
    "bench", "bird", "cat", "dog", "horse",
    # 20-29
    "sheep", "cow", "elephant", "bear", "zebra",
    "giraffe", "N/A", "backpack", "umbrella", "N/A",
    # 30-39
    "N/A", "handbag", "tie", "suitcase", "frisbee",
    "skis", "snowboard", "sports ball", "kite", "baseball bat",
    # 40-49
    "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle",
    "N/A", "wine glass", "cup", "fork", "knife",
    # 50-59
    "spoon", "bowl", "banana", "apple", "sandwich",
    "orange", "broccoli", "carrot", "hot dog", "pizza",
    # 60-69
    "donut", "cake", "chair", "couch", "potted plant",
    "bed", "N/A", "dining table", "N/A", "N/A",
    # 70-79
    "toilet", "N/A", "tv", "laptop", "mouse",
    "remote", "keyboard", "cell phone", "microwave", "oven",
    # 80-90
    "toaster", "sink", "refrigerator", "N/A", "book",
    "clock", "vase", "scissors", "teddy bear", "hair drier",
    "toothbrush",
]

In [6]:
# One random 3-component color per class name, each component drawn
# uniformly from [0, 255). Rows are looked up by class label in `draw_boxes`.
COLORS = np.random.uniform(
    low=0,
    high=255,
    size=(len(COCO_INSTANCE_CATEGORY_NAMES), 3),
)

In [7]:
def get_model(device='cpu'):
    """
    Build the Faster R-CNN ResNet50 FPN v2 detector initialised with the
    module-level `weights`, switch it to evaluation mode, and move it to
    `device`.

    Returns the ready-to-use model.
    """
    detector = torchvision.models.detection.fasterrcnn_resnet50_fpn_v2(weights=weights)
    # Inference only: put the network in eval mode before moving it.
    detector = detector.eval()
    return detector.to(device)

In [8]:
def predict(image, model, device, detection_threshold):
    """
    Run a forward pass of `model` on a single image and return the
    detections whose score is at least `detection_threshold`.

    Args:
        image: Image accepted by the module-level `transform`
            (callers in this notebook pass an RGB PIL image or an
            RGB array).
        model: Detection model; called with a batch of one image tensor.
        device: Device the image tensor is moved to before inference.
        detection_threshold: Minimum score for a detection to be kept.

    Returns:
        A tuple `(boxes, pred_classes, labels)`:
        boxes is an int32 NumPy array of the kept bounding boxes,
        pred_classes is the list of class names for the kept detections,
        labels is the tensor of class labels for the kept detections.
    """
    # Transform the image to a tensor and add a batch dimension.
    batch = transform(image).to(device).unsqueeze(0)
    # Get the predictions on the image.
    with torch.no_grad():
        outputs = model(batch)
    detections = outputs[0]

    # Scores and boxes of all predicted objects, as NumPy arrays.
    pred_scores = detections['scores'].detach().cpu().numpy()
    pred_bboxes = detections['boxes'].detach().cpu().numpy()

    # Select boxes AND labels with the same score mask, so the two always
    # line up. Slicing labels by `len(boxes)` would only be correct if the
    # detections happened to be sorted by descending score.
    keep = pred_scores >= detection_threshold
    boxes = pred_bboxes[keep].astype(np.int32)
    all_labels = detections['labels']
    labels = all_labels[torch.from_numpy(keep).to(all_labels.device)]

    # Map every kept label to its class name.
    pred_classes = [COCO_INSTANCE_CATEGORY_NAMES[i] for i in labels.cpu().numpy()]

    return boxes, pred_classes, labels

In [9]:
def draw_boxes(boxes, classes, labels, image):
    """
    Annotate `image` in place with one rectangle and one text label per
    detection, then return it.

    `boxes[i]` supplies the corner coordinates (x1, y1, x2, y2),
    `classes[i]` the text to print, and `labels[i]` the row of `COLORS`
    used for both the rectangle and the text.
    """
    # Line width scales with the image size, but is never thinner than 2.
    line_width = max(round(sum(image.shape) / 2 * 0.003), 2)
    # Font thickness is one less than the line width, but at least 1.
    font_thickness = max(line_width - 1, 1)

    for idx in range(len(boxes)):
        box = boxes[idx]
        # Reverse the channel order of the stored color before drawing.
        draw_color = COLORS[labels[idx]][::-1]
        top_left = (int(box[0]), int(box[1]))
        bottom_right = (int(box[2]), int(box[3]))
        # Put the text 5 pixels above the top-left corner of the box.
        text_origin = (int(box[0]), int(box[1]-5))

        cv2.rectangle(
            img=image,
            pt1=top_left,
            pt2=bottom_right,
            color=draw_color,
            thickness=line_width
        )
        cv2.putText(
            img=image,
            text=classes[idx],
            org=text_origin,
            fontFace=cv2.FONT_HERSHEY_SIMPLEX,
            fontScale=line_width / 3,
            color=draw_color,
            thickness=font_thickness,
            lineType=cv2.LINE_AA
        )

    return image


In [10]:
def detectImage(input, threshold=0.5):
    """
    Run object detection on one image file, save the annotated image,
    and show it in an OpenCV window until a key is pressed.

    The result is written to `outputs/<name>_t<digits>.jpg`, where
    `<name>` is the file name up to its first dot and `<digits>` is
    `threshold` with the decimal point removed (0.5 -> "05").

    Args:
        input: Path of the image file to read.
        threshold: Minimum detection score for a box to be drawn.
    """
    # Define the computation device.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = get_model(device)

    # Read the image; the context manager closes the underlying file.
    with Image.open(input) as source:
        image = source.convert('RGB')
    # Create a BGR copy of the image for annotation with OpenCV.
    image_bgr = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)

    # Detect outputs (`predict` already disables gradient tracking).
    boxes, classes, labels = predict(image, model, device, threshold)
    # Draw bounding boxes.
    annotated = draw_boxes(boxes, classes, labels, image_bgr)

    # basename handles the platform's path separators, not only '/'.
    save_name = f"{os.path.basename(input).split('.')[0]}_t{''.join(str(threshold).split('.'))}"
    # cv2.imwrite does not create missing directories, so make sure the
    # target folder exists before writing.
    os.makedirs("outputs", exist_ok=True)
    cv2.imwrite(f"outputs/{save_name}.jpg", annotated)

    cv2.imshow('Image', annotated)
    cv2.waitKey(0)
    # Close the preview window once a key has been pressed.
    cv2.destroyAllWindows()


In [11]:
def detectVideo(input, threshold=0.5):
    """
    Run object detection on every frame of a video, show the annotated
    frames live, and write them to `outputs/<name>_t<digits>.mp4`.

    `<name>` is the file name up to its first dot and `<digits>` is
    `threshold` with the decimal point removed (0.5 -> "05"). Press `q`
    in the preview window to stop early. The average per-frame
    inference FPS is printed at the end.

    Args:
        input: Path of the video file to read.
        threshold: Minimum detection score for a box to be drawn.
    """
    cap = cv2.VideoCapture(input)
    out = None
    frame_count = 0  # To count total frames.
    total_fps = 0  # To get the final frames per second.

    try:
        if not cap.isOpened():
            print('Error while trying to read video. Please check path again')
            return

        # Define the computation device and load the model only once the
        # video is known to be readable.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        model = get_model(device)

        # Get the frame width and height.
        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # Keep the source frame rate for the output; fall back to 30 when
        # the container does not report a usable value.
        source_fps = cap.get(cv2.CAP_PROP_FPS)
        writer_fps = source_fps if source_fps > 0 else 30

        # basename handles the platform's path separators, not only '/'.
        save_name = f"{os.path.basename(input).split('.')[0]}_t{''.join(str(threshold).split('.'))}"
        # The writer does not create missing directories.
        os.makedirs("outputs", exist_ok=True)
        # Define codec and create VideoWriter object.
        out = cv2.VideoWriter(f"outputs/{save_name}.mp4",
                              cv2.VideoWriter_fourcc(*'mp4v'), writer_fps,
                              (frame_width, frame_height))

        # Read until end of video.
        while True:
            # Capture each frame of the video.
            ret, frame = cap.read()
            if not ret:
                break

            # The model is fed RGB images (as in detectImage), while the
            # captured frame is BGR, so convert a copy for inference and
            # keep the BGR frame for drawing and writing.
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

            # Get the start time.
            start_time = time.time()
            # Get predictions for the current frame
            # (`predict` already disables gradient tracking).
            boxes, classes, labels = predict(
                frame_rgb, model,
                device, threshold
            )
            # Draw boxes on the BGR frame.
            image = draw_boxes(boxes, classes, labels, frame)
            # Get the end time.
            end_time = time.time()

            # Get the fps and accumulate it for the final average.
            fps = 1 / (end_time - start_time)
            total_fps += fps
            frame_count += 1

            # Write the FPS on the current frame.
            cv2.putText(
                img=image,
                text=f"{fps:.3f} FPS",
                org=(15, 30),
                fontFace=cv2.FONT_HERSHEY_SIMPLEX,
                fontScale=1,
                color=(0, 255, 0),
                thickness=2,
                lineType=cv2.LINE_AA
            )
            # Show the annotated frame and append it to the output video.
            cv2.imshow('image', image)
            out.write(image)
            # Press `q` to exit.
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        # Release the capture and the writer (so the output file is
        # finalised), then close all frames and video windows.
        cap.release()
        if out is not None:
            out.release()
        cv2.destroyAllWindows()

    # Calculate and print the average FPS.
    if frame_count == 0:
        print('No frames were read from the video.')
        return
    avg_fps = total_fps / frame_count
    print(f"Average FPS: {avg_fps:.3f}")
    

In [12]:
# model = get_model(device)

In [13]:
# torch.save(model, "fasterrcnn_resnet50_model.pt")
# torch.save(model.state_dict(), "fasterrcnn_resnet50_model_state.pt")

In [14]:
# model_load = torch.load("conv_ant_bee_model2.pt")

In [15]:
# Run detection on a sample image with the default score threshold.
detectImage("input/image_2.jpg", threshold=0.5)

In [16]:
# Run detection on a sample video with the default score threshold.
detectVideo("input/video_2.mp4", threshold=0.5)

Average FPS: 4.907


In [1]:
# !python detect_image.py --input input/image_1.jpg

In [2]:
# !python detect_image.py --input input/image_2.jpg

In [3]:
# !python detect_video.py --input input/video_1.mp4

Average FPS: 4.785


In [4]:
# !python detect_image.py --input data/coco/val2017/000000057782.jpg

In [6]:
# !python detect_image.py --input data/coco/val2017/000000045084.jpg