In [14]:
import numpy as np 
import pandas as pd 

import matplotlib.patches as patches

import os
import random
import xml.etree.ElementTree as ET
import torch
import torchvision
from torchvision import datasets, transforms, models
from torch.utils.data import Dataset,DataLoader
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
import cv2
from PIL import Image

import sys
import torch.optim as optim

from time import time


In [2]:
# Model class index -> bird species name. Indices start at 1 because the
# detector reserves index 0 for the background class (see num_classes = 14 + 1).
labels_dict = dict(enumerate([
    'Eurasian_jay',
    'great_spotted_woodpecker',
    'greenfinch',
    'blue_tit',
    'Carduelis',
    'common_redpoll',
    'great_tit',
    'bullfinch',
    'Eurasian_siskin',
    'Eurasian_tree_sparrow',
    'hawfinch',
    'willow_tit',
    'Fieldfare',
    'Common chaffinch',
], start=1))



In [28]:
class BirdsDetection:
    """
    Bird detection on a video stream with a Faster R-CNN (MobileNetV3-Large 320 FPN)
    fine-tuned for 14 bird species.

    Usage: ``BirdsDetection(source)()`` opens ``source``, runs the detector on its
    frames and shows the annotated frames in an OpenCV window (press 'q' to stop).
    """

    def __init__(self, source, out_file = None,
                       weights_path = '/home/costia/faster_rcnn_mobilenet_10feb.pt',
                       conf_lvl = 0.6,labels_dict = labels_dict):
        """
        :param source: anything accepted by ``cv2.VideoCapture`` -- a video file path,
                       a stream URL (e.g. RTSP) or an integer camera index.
        :param out_file: path for an annotated output video; defaults to
                         ``<source without extension>_predicted.avi``. NOTE: ``__call__``
                         currently only displays frames and does not write this file.
        :param weights_path: path to the fine-tuned model state dict.
        :param conf_lvl: minimum detection score for a box to be drawn.
        :param labels_dict: mapping from model class index to bird species name.
        """
        self.source = source
        self.weights_path = weights_path
        self.conf_lvl = conf_lvl
        # Store the mapping so a caller-supplied dict is actually used by class_to_label.
        self.labels_dict = labels_dict
        # The device must be chosen before loading so the checkpoint can be mapped onto it.
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.model = self.load_model()
        if out_file is None:
            # str() so integer camera indices work too; splitext only strips the
            # final extension (not everything after the first dot in the path).
            out_file = os.path.splitext(str(source))[0] + "_predicted.avi"
        self.out_file = out_file
        # NOTE(review): torchvision detection models normalize their inputs internally
        # (GeneralizedRCNNTransform) and expect [0, 1] tensors, so this Normalize is
        # applied on top of that. Keep it only if training used the same pipeline --
        # TODO confirm against the training code.
        self.transform = transforms.Compose([
                            transforms.ToTensor(),
                            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
                        ])

    def get_video_from_source(self):
        """
        Create a video capture object for ``self.source`` to read frames from.

        :returns: ``cv2.VideoCapture`` object.
        """
        return cv2.VideoCapture(self.source)

    def load_model(self):
        """
        Build the Faster R-CNN MobileNetV3 detector, replace its box predictor with a
        14-species head and load the fine-tuned weights from ``self.weights_path``.

        Requires ``self.device`` to be set. The checkpoint is mapped onto that device,
        so a GPU-saved checkpoint also loads on a CPU-only machine. The model is moved
        there once, instead of on every frame.

        :returns: the PyTorch model in eval mode, on ``self.device``.
        """
        # pretrained=True is kept so the module layout matches the checkpoint, which is
        # presumably saved from a model built the same way. It downloads COCO weights on
        # first use; load_state_dict below then overwrites them.
        model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_320_fpn(pretrained=True)
        in_features = model.roi_heads.box_predictor.cls_score.in_features
        num_classes = 14 + 1  # 14 bird species + background
        model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

        # torch.load unpickles the file: only load checkpoints from trusted sources.
        state_dict = torch.load(self.weights_path, map_location=self.device)
        model.load_state_dict(state_dict)
        model.to(self.device)
        model.eval()
        return model

    def score_frame(self, frame):
        """
        Run the detector on a single frame.

        :param frame: BGR image as returned by ``cv2.VideoCapture.read()``.
        :returns: model output for a batch of one: a one-element list whose dict holds
                  the 'boxes', 'labels' and 'scores' tensors for the frame.
        """
        # Convert OpenCV's BGR order to the RGB order the model is fed with.
        rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        batch = self.transform(rgb_frame).unsqueeze(0).to(self.device)
        with torch.no_grad():
            preds = self.model(batch)
        torch.cuda.empty_cache()
        return preds

    def class_to_label(self, label):
        """
        Map a numeric class index to its bird species name.

        :param label: numeric class index predicted by the model.
        :returns: species name from ``self.labels_dict``.
        :raises KeyError: if ``label`` is not a key of ``self.labels_dict``.
        """
        return self.labels_dict[label]

    def plot_boxes(self, preds, frame):
        """
        Draw the confident detections onto ``frame``.

        Only detections with score >= ``self.conf_lvl`` are drawn. cv2.rectangle and
        cv2.putText draw into ``frame`` itself; the same array is returned.

        :param preds: output of ``score_frame``; ``preds[0]`` holds 'boxes', 'labels'
                      and 'scores'.
        :param frame: the BGR frame that was scored.
        :returns: the frame with bounding boxes and species names drawn on it.
        """
        detections = preds[0]
        keep = detections['scores'] >= self.conf_lvl
        boxes = detections['boxes'][keep]
        labels = detections['labels'][keep]

        color = (0, 255, 0)  # green in OpenCV's BGR order
        for box, label in zip(boxes, labels):
            x_min, y_min, x_max, y_max = map(int, box)
            cv2.rectangle(frame, (x_min, y_min), (x_max, y_max), color, 2)
            cv2.putText(frame, self.class_to_label(int(label)), (x_min, y_min),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.9, color, 2)
        return frame

    def __call__(self, process_every=1):
        """
        Run the display loop until the stream ends or 'q' is pressed.

        Each frame is read from the source, scored, annotated and shown in an OpenCV
        window. The capture and windows are always released, even on an error or a
        KeyboardInterrupt.

        :param process_every: run the detector only on every N-th frame (default 1 =
                              every frame). Skipped frames are shown without boxes.
        :raises ValueError: if ``process_every`` < 1.
        :returns: None
        """
        if process_every < 1:
            raise ValueError("process_every must be >= 1")
        player = self.get_video_from_source()
        assert player.isOpened(), "Cannot open video source"
        try:
            frame_idx = 0
            while True:
                ret, frame = player.read()
                if not ret:
                    # read() reports end of stream / lost connection via ret == False;
                    # it does not raise.
                    print("Video ended")
                    break
                if frame_idx % process_every == 0:
                    preds = self.score_frame(frame)
                    frame = self.plot_boxes(preds, frame)
                frame_idx += 1
                cv2.imshow('frame', frame)
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break
        finally:
            player.release()
            cv2.destroyAllWindows()


In [29]:
# Create a birds detection object for the camera stream and run it.
# The stream URL can be overridden with the BIRDS_CAMERA_URL environment variable,
# so camera credentials do not have to be stored in the notebook.
CAMERA_URL = os.environ.get(
    "BIRDS_CAMERA_URL",
    'rtsp://admin:@192.168.1.123/user=admin_password=_channel=1_stream=1.sdp',
)
CONF_LVL = 0.2  # minimum detection score for a box to be drawn

detect = BirdsDetection(source=CAMERA_URL, conf_lvl=CONF_LVL)
detect()

