In [12]:
import cv2
import os
import warnings
import tensorflow as tf
from object_detection.utils import label_map_util
from object_detection.utils import visualization_utils as viz_utils
from object_detection.builders import model_builder
from object_detection.utils import config_util
import platform
import numpy as np
from utils import plot_image

warnings.filterwarnings("ignore")
%load_ext autoreload
%autoreload 2

%matplotlib inline
%reload_ext nb_black
%config IPCompleter.greedy=True

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [13]:
# Name of the fine-tuned model and of the pre-trained model it is based on.
CUSTOM_MODEL_NAME = "my_ssd_mobnet"
PRETRAINED_MODEL_NAME = "ssd_mobilenet_v2_fpnlite_320x320_coco17_tpu-8"
# Download URL of the pre-trained archive: the model name plus ".tar.gz".
PRETRAINED_MODEL = (
    "http://download.tensorflow.org/models/object_detection/tf2/20200711/"
    f"{PRETRAINED_MODEL_NAME}.tar.gz"
)
# File names of the TFRecord generation script and of the label map.
TF_RECORD_SCRIPT_NAME = "generate_tfrecord.py"
LABEL_MAP_NAME = "label_map.pbtxt"
# Operating-system name as reported by platform.system().
current_os = platform.system()

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [14]:
# The project root is the parent of the notebook's working directory.
parent_directory = os.path.dirname(os.getcwd())

# Every entry is (key, path components below the project root); the
# comprehension joins them onto parent_directory in declaration order.
paths = {
    name: os.path.join(parent_directory, *components)
    for name, components in (
        ("SRC_PATH", ("src",)),
        ("DATA_PATH", ("data",)),
        ("APIMODEL_PATH", ("api_models",)),
        ("MODEL_PATH", ("my_models",)),
        ("PRETRAINED_MODEL_PATH", ("pre-trained-models",)),
        ("CHECKPOINT_PATH", ("my_models", CUSTOM_MODEL_NAME)),
        ("OUTPUT_PATH", ("my_models", CUSTOM_MODEL_NAME, "export")),
        ("TFJS_PATH", ("my_models", CUSTOM_MODEL_NAME, "tfjsexport")),
        ("TFLITE_PATH", ("my_models", CUSTOM_MODEL_NAME, "tfliteexport")),
        ("PROTOC_PATH", ("protoc",)),
        ("VIDEOS_PATH", ("videos",)),
    )
}

# Identifiers of the form "PrL-<number>".
mice = [f"PrL-{number}" for number in (2, 3, 4, 5, 8, 9, 10, 16, 19)]

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [15]:
# Individual files referenced later in the notebook, keyed by role.
files = dict(
    PIPELINE_CONFIG=os.path.join(
        parent_directory,
        "my_models",
        CUSTOM_MODEL_NAME,
        "pipeline.config",
    ),
    TF_RECORD_SCRIPT=os.path.join(paths["DATA_PATH"], TF_RECORD_SCRIPT_NAME),
    LABELMAP=os.path.join(paths["DATA_PATH"], LABEL_MAP_NAME),
)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [16]:
# Parse the pipeline config and build the detection model in inference mode
# (is_training=False).
configs = config_util.get_configs_from_pipeline_file(
    files["PIPELINE_CONFIG"],
)
detection_model = model_builder.build(
    model_config=configs["model"],
    is_training=False,
)

# Restore the model weights from checkpoint "ckpt-3" in the checkpoint directory.
# NOTE(review): the checkpoint number is hard-coded; confirm it is the one intended.
ckpt = tf.compat.v2.train.Checkpoint(model=detection_model)
ckpt.restore(
    os.path.join(paths["CHECKPOINT_PATH"], "ckpt-3"),
).expect_partial()


@tf.function
def detect_fn(image):
    """Run ``detection_model`` on ``image`` and return its post-processed detections.

    The input goes through the model's ``preprocess``, ``predict`` and
    ``postprocess`` stages in that order.
    """
    preprocessed, true_shapes = detection_model.preprocess(image)
    raw_predictions = detection_model.predict(preprocessed, true_shapes)
    return detection_model.postprocess(raw_predictions, true_shapes)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [17]:
# Build the class-id -> category lookup from the label map file, then show it
# as the cell's output.
category_index = label_map_util.create_category_index_from_labelmap(
    files["LABELMAP"],
)
category_index

{1: {'id': 1, 'name': 'parent-mice'}, 2: {'id': 2, 'name': 'child-mice'}}

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [18]:
# Path of a single training image, joined with "/" onto the data directory.
IMAGE_PATH = "/".join(
    (
        paths["DATA_PATH"],
        "train",
        "parent_PrL-4-0f93312b-0d14-4741-baa5-a82fe96d695f.jpg",
    )
)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [19]:
def sharpen_image(img):
    """Return ``img`` filtered with a 3x3 sharpening kernel.

    The kernel weights the centre pixel by 9 and each of its eight
    neighbours by -1. ``cv2.filter2D`` is called with ``ddepth=-1``.
    """
    sharpening_kernel = np.array(
        [
            [-1, -1, -1],
            [-1, 9, -1],
            [-1, -1, -1],
        ]
    )
    return cv2.filter2D(img, -1, sharpening_kernel)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [20]:
# range(0, 1) holds only the integer 0, so a membership test cannot validate a
# float in [0, 1]; this evaluates to True.
not (0.5 in range(0, 1))

True

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [21]:
def detect_pup_mouse(image_np: np.ndarray, min_thresh=0.3, max_boxes_to_draw=2):
    """Run the detector on one image and draw the detections on a copy of it.

    Args:
        image_np: Image array for a single frame. A leading batch axis is
            added before it is passed to ``detect_fn``.
        min_thresh: Minimum score passed to the visualizer as
            ``min_score_thresh``. Must lie in the closed interval [0, 1].
        max_boxes_to_draw: Maximum number of boxes the visualizer draws.
            Defaults to 2, the value that used to be hard-coded.

    Returns:
        A ``(detections, image_np_with_detections)`` tuple. ``detections`` is
        the dict returned by ``detect_fn`` with the batch axis removed, every
        entry truncated to ``num_detections`` and converted to a NumPy array,
        ``num_detections`` stored as an int and ``detection_classes`` cast to
        int64. ``image_np_with_detections`` is an annotated copy of
        ``image_np``.

    Raises:
        ValueError: If ``min_thresh`` is not within [0, 1]. NaN is rejected
            too, because the chained comparison below is False for NaN.
    """
    if not 0 <= min_thresh <= 1:
        raise ValueError("min_thresh score should be between 0 and 1 inclusive")

    input_tensor = tf.convert_to_tensor(np.expand_dims(image_np, 0), dtype=tf.float32)
    detections = detect_fn(input_tensor)

    # Drop the batch axis and keep only the first num_detections entries.
    num_detections = int(detections.pop("num_detections"))
    detections = {
        key: np.asarray(value[0, :num_detections]) for key, value in detections.items()
    }
    detections["num_detections"] = num_detections

    # detection_classes should be ints.
    detections["detection_classes"] = detections["detection_classes"].astype(np.int64)

    # Class ids are shifted by one before the category_index lookup.
    label_id_offset = 1
    # Draw on a copy so the caller's array stays untouched.
    image_np_with_detections = image_np.copy()

    viz_utils.visualize_boxes_and_labels_on_image_array(
        image_np_with_detections,
        detections["detection_boxes"],
        detections["detection_classes"] + label_id_offset,
        detections["detection_scores"],
        category_index,
        use_normalized_coordinates=True,
        max_boxes_to_draw=max_boxes_to_draw,
        min_score_thresh=min_thresh,
        agnostic_mode=False,
    )
    return detections, image_np_with_detections

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [22]:
# Play back one recorded video in an OpenCV window until the stream ends or
# the Escape key is pressed.
watch_video = True
cap = cv2.VideoCapture(f'{paths["VIDEOS_PATH"]}/PrL-9.mp4')
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

try:
    while watch_video and cap.isOpened():
        ret, frame = cap.read()

        # ret is False once no further frame can be read.
        if not ret:
            break

        # Detection overlay is currently disabled; to enable it, replace the
        # imshow call below with:
        #     image_np = np.array(frame)
        #     detections, image_np_with_detections = detect_pup_mouse(image_np)
        #     cv2.imshow("object detection", cv2.resize(image_np_with_detections, (800, 600)))
        cv2.imshow("object detection", cv2.resize(frame, (800, 600)))

        # Poll the keyboard after showing the frame; waitKey also gives the
        # window time to process its events.
        client_key_press = cv2.waitKey(1) & 0xFF

        # end video stream by escape key
        if client_key_press == 27:
            break
finally:
    # Always release the capture so the file or camera can be opened again,
    # even if the loop above raised.
    cap.release()
    cv2.destroyAllWindows()
    # platform.system() returns "Darwin" on macOS; the previous check compared
    # the platform module itself to a string and therefore never ran.
    if platform.system() == "Darwin":
        cv2.waitKey(1)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>