In [6]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# Enable the autoreload extension in mode 2 so that edits to imported modules
# are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [18]:
import sys
sys.path.append(r"..")
import torch
from pysot.core.config import cfg
from pysot.models.model_builder import ModelBuilder
from pysot.tracker.tracker_builder import build_tracker
import cv2
from pathlib import Path
import PIL.Image as PIL_Image
import tensorflow as tf
import numpy as np
from IPython.display import clear_output, display
from birdbox.tools import process_video, Rect, crop_to_square_crop, cv2_putText, video_drop_down
from birdbox.detection import detect_birds, load_detector, Detection
from birdbox.classifiers import Classifier
from object_detection.utils import visualization_utils


In [8]:
# Name of the object-detection model handed to load_detector() further down.
# The commented-out line is an alternative model name; swap the two to switch models.
# DETECTOR_MODEL_NAME = "centernet_hg104_1024x1024_coco17_tpu-32"
DETECTOR_MODEL_NAME = "ssd_mobilenet_v2_320x320_coco17_tpu-8"

# Name of the species classifier handed to Classifier.load() further down.
CLASSIFIER_NAME = "EfficientNetB0_120x120_1-3-4-5-6-7-8-9-10-11-12"


In [9]:
# Load the species classifier, then build an id -> display-name lookup from the
# entries of its category index (each entry carries an "id" and a "name").
classifier = Classifier.load(CLASSIFIER_NAME)
class_names = {
    entry["id"]: entry["name"]
    for entry in classifier.category_index.values()
}

In [10]:
# Load the object detector. load_detector() returns a pair: the detection
# callable and the detector's category index.
detect_fn, category_index = load_detector(model_name=DETECTOR_MODEL_NAME)

Loading model...Done! Took 8.08270812034607s


In [11]:
def draw_box_with_title(frame, box, title=None):
    """Draw one bounding box, labelled with ``title``, onto ``frame``.

    Args:
        frame: Image array handed to the object-detection visualisation helper.
            The helper's return value is discarded, so the drawing is expected
            to happen on ``frame`` itself.
        box: Object exposing ``top``, ``left``, ``bottom`` and ``right``
            attributes (presumably pixel coordinates, since normalised
            coordinates are not requested -- confirm against ``Rect``).
        title: Text used as the category name of the single box.

    Returns:
        None.
    """
    # The visualisation helper works on batches of boxes with class ids and
    # scores, so the single box is wrapped as a one-element batch of class 1.
    single_class_id = 1
    boxes = np.array([[box.top, box.left, box.bottom, box.right]])
    label_lookup = {single_class_id: {"id": single_class_id, "name": title}}

    visualization_utils.visualize_boxes_and_labels_on_image_array(
        frame,
        boxes,
        [single_class_id],
        [1],
        label_lookup,
        skip_scores=True,
    )


class DetectionAndTrackingProcessor:
    """Per-frame callable that detects a bird, tracks it and labels its species.

    State machine, driven by one call per video frame:

    1. While no target is being tracked, run the detector on the frame. When a
       detection is found, initialise the SiamRPN tracker with it and return.
    2. While a target is tracked, update the tracker. If the tracking score
       drops below ``MIN_SCORE`` the target and the vote history are dropped
       and detection starts over on the next frame.
    3. Otherwise crop the tracked region, classify it, add the result to a
       sliding window of votes and draw the box with the majority label.

    The class reads the notebook-level globals ``detect_fn``, ``classifier``
    and ``class_names`` and the helper ``draw_box_with_title``; those cells
    must have been executed before an instance is called.
    """

    # Built from path segments rather than a back-slash literal so the model
    # files also resolve on non-Windows hosts.
    MODEL_DIR = Path("..") / "models" / "pysot" / "siamrpn_r50_l234_dwxcorr"
    CONFIG_FILE = MODEL_DIR / "config.yaml"
    SNAPSHOT_FILE = MODEL_DIR / "model.pth"

    # Tracking score below which the current target is abandoned.
    MIN_SCORE = 0.5
    # Maximum number of classification votes kept in the sliding window.
    HISTORY_LENGTH = 100
    # Minimum detector score accepted when acquiring a new target.
    DETECTION_MIN_SCORE = 0.1
    # The crop is re-classified only while the tracking score is below this
    # value (or while no vote exists yet).
    RECLASSIFY_BELOW_SCORE = 0.9999

    def __init__(self):
        """Load the pysot configuration and snapshot and build the tracker."""
        # None means "no target yet"; otherwise the box the tracker started from.
        self.init_box = None
        cfg.merge_from_file(self.CONFIG_FILE)
        # Use CUDA only when it is both configured and actually available.
        cfg.CUDA = torch.cuda.is_available() and cfg.CUDA
        device = torch.device('cuda' if cfg.CUDA else 'cpu')
        model = ModelBuilder()
        # Map the stored tensors to CPU first so the snapshot also loads on
        # machines without a GPU; the model is moved to `device` afterwards.
        model.load_state_dict(torch.load(self.SNAPSHOT_FILE, map_location=lambda storage, loc: storage.cpu()))
        model.eval().to(device)
        self.tracker = build_tracker(model)
        # Sliding window of the most recent predicted class names (oldest first).
        self.class_history = []

    def __call__(self, frame):
        """Process one video frame, drawing the tracked box onto it.

        Args:
            frame: Image array in BGR channel order (it is converted with
                ``cv2.COLOR_BGR2RGB`` before detection and classification).

        Returns:
            None. The annotation is drawn via ``draw_box_with_title(frame, ...)``.
        """
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        image_pil = PIL_Image.fromarray(frame_rgb)

        # Acquire a tracking target. Compare against None explicitly so the
        # check does not depend on how Rect defines truthiness.
        if self.init_box is None:
            detections = detect_birds(
                frame_rgb,
                detect_fn,
                min_score=self.DETECTION_MIN_SCORE,
                max_intersection_over_size=0,
            )
            if len(detections) == 0:
                return

            # Presumably scales a normalised detection box to pixel
            # coordinates -- confirm against Rect.__mul__.
            self.init_box = detections[0].box * image_pil.size
            self.tracker.init(frame, self.init_box.left_top_width_height())
            return

        # Track
        outputs = self.tracker.track(frame)
        tracking_score = outputs["best_score"]

        # Target lost: forget it and its votes, re-detect on the next frame.
        if tracking_score < self.MIN_SCORE:
            self.init_box = None
            self.class_history = []
            return

        # Build classification crop
        box = round(Rect.from_left_top_width_height(*outputs['bbox']))
        classifier_box = crop_to_square_crop(box, image_pil.size)
        bird_crop = image_pil.crop(classifier_box)

        # Classify, unless the tracker is near-certain and a vote already exists.
        if tracking_score < self.RECLASSIFY_BELOW_SCORE or len(self.class_history) == 0:
            image_tf = tf.expand_dims(np.array(bird_crop.resize(classifier.image_size)), 0)
            prediction = classifier.model.predict(image_tf)
            # NOTE(review): assumes class ids are 1-based and contiguous so
            # that output index i corresponds to id i + 1 -- confirm.
            current_class = class_names[prediction[0].argmax() + 1]
            if len(self.class_history) == self.HISTORY_LENGTH:
                self.class_history.pop(0)
            self.class_history.append(current_class)

        # Majority vote over the window; ties go to the class listed first in
        # class_names. The history holds at least one vote at this point.
        best_guess = max(class_names.values(), key=self.class_history.count)
        best_count = self.class_history.count(best_guess)

        # Share of the votes actually collected. Dividing by HISTORY_LENGTH
        # would under-report the share until the window is full.
        vote_share = best_count / len(self.class_history)

        draw_box_with_title(frame, box, title=f"{best_guess} {vote_share * 100:.0f}%")

In [12]:
# Build the video-selection drop-down widget and show it; the chosen file name
# can be read later from drop_down.value.
drop_down = video_drop_down()
display(drop_down)

Dropdown(description='Video file:', options=('Blue_tit_vs_chaffinch.mp4', 'Verl_red_robin_in_water.mp4', 'Grou…

In [27]:
# Paths are built from segments rather than back-slash literals so the cell
# also runs on non-Windows hosts.
VIDEO_DIR = Path("..") / "videos"

# Clip to process. Alternatives used during development:
#   video = VIDEO_DIR / drop_down.value                    # drop-down selection
#   video = VIDEO_DIR / "wald" / "great_tit_drinking.mp4"
#   video = VIDEO_DIR / "wald" / "sparrows.mp4"
video = VIDEO_DIR / "wald" / "robin.mp4"

# Per-clip skip values tried earlier, keyed by video.stem:
#   {"Ground_feeding": 0, "Blue_tit_vs_chaffinch": 1, "Verl_red_robin_in_water": 0,
#    "Ground_feeding2": 0, "great_tit_drinking": 1}
skip = 0

# Annotated result file; use `output_path = None` instead to pass no output path.
output_path = VIDEO_DIR / "results" / (video.stem + ".avi")
if output_path is not None:
    # Create the results directory up front so writing the video does not fail
    # on a fresh checkout.
    output_path.parent.mkdir(parents=True, exist_ok=True)

process_video(video, DetectionAndTrackingProcessor(), skip=skip, start=150, output_path=output_path)

