Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Tensorflow example with SSD-mobile-net
- Loading branch information
Showing
4 changed files
with
375 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
#!/bin/sh
# Fetch the SSD MobileNet v2 COCO (2018_03_29) model files used by the notebook.
# Available models are listed at:
# https://github.com/opencv/opencv/wiki/TensorFlow-Object-Detection-API#use-existing-config-file-for-your-model

# Frozen TensorFlow inference graph, unpacked in place (-c resumes a partial download).
wget -c http://download.tensorflow.org/models/object_detection/ssd_mobilenet_v2_coco_2018_03_29.tar.gz -O - | tar -xz

# Matching OpenCV graph description (.pbtxt) for cv.dnn.readNetFromTensorflow.
wget https://raw.githubusercontent.com/opencv/opencv_extra/master/testdata/dnn/ssd_mobilenet_v2_coco_2018_03_29.pbtxt

# TensorFlow object detection API reference:
# https://github.com/opencv/opencv/wiki/TensorFlow-Object-Detection-API#use-existing-config-file-for-your-model
340 changes: 340 additions & 0 deletions
340
tracking-tensorflow-ssd_mobilenet_v2_coco_2018_03_29.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,340 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import cv2 as cv\n", | ||
"from scipy.spatial import distance\n", | ||
"import numpy as np\n", | ||
"from collections import OrderedDict" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"##### Object Tracking Class" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
class Tracker:
    """Centroid-based multi-object tracker.

    Assigns a persistent integer ID to each detected object and matches
    detections across frames by nearest-centroid distance (greedy matching
    on the pairwise distance matrix).
    """

    def __init__(self, maxLost=30):
        # maxLost: number of consecutive frames an object may go undetected
        # before its track is dropped.
        self.nextObjectID = 0          # ID assigned to the next new object
        self.objects = OrderedDict()   # objectID -> centroid (x, y)
        self.lost = OrderedDict()      # objectID -> consecutive missed-frame count
        self.maxLost = maxLost

    def addObject(self, new_object_location):
        """Register a new object at the given centroid location."""
        self.objects[self.nextObjectID] = new_object_location
        self.lost[self.nextObjectID] = 0  # fresh track starts with no misses
        self.nextObjectID += 1

    def removeObject(self, objectID):
        """Drop all tracking state for a lost object."""
        del self.objects[objectID]
        del self.lost[objectID]

    @staticmethod
    def getLocation(bounding_box):
        """Return the integer centroid of an (x_left, y_top, x_right, y_bottom) box."""
        xlt, ylt, xrb, yrb = bounding_box
        return (int((xlt + xrb) / 2.0), int((ylt + yrb) / 2.0))

    def update(self, detections):
        """Match `detections` (iterable of bounding boxes) to tracked objects.

        Returns the OrderedDict of objectID -> centroid after the update.
        """
        if len(detections) == 0:
            # No detections this frame: age every existing track.
            # BUG FIX: iterate over a snapshot of the keys — removeObject()
            # mutates self.lost, and deleting entries while iterating the
            # live keys view raises RuntimeError in Python 3.
            for objectID in list(self.lost.keys()):
                self.lost[objectID] += 1
                if self.lost[objectID] > self.maxLost:
                    self.removeObject(objectID)

            return self.objects

        # Centroids of this frame's detections.
        new_object_locations = np.zeros((len(detections), 2), dtype="int")
        for (i, detection) in enumerate(detections):
            new_object_locations[i] = self.getLocation(detection)

        if len(self.objects) == 0:
            # Nothing tracked yet: every detection starts a new track.
            for i in range(0, len(detections)):
                self.addObject(new_object_locations[i])
        else:
            objectIDs = list(self.objects.keys())
            previous_object_locations = np.array(list(self.objects.values()))

            # Pairwise distances: rows = existing tracks, cols = new detections.
            D = distance.cdist(previous_object_locations, new_object_locations)

            # Greedy matching: visit tracks in order of their best (smallest)
            # distance to any detection, pairing each with its nearest column.
            row_idx = D.min(axis=1).argsort()
            cols_idx = D.argmin(axis=1)[row_idx]

            assignedRows, assignedCols = set(), set()
            for (row, col) in zip(row_idx, cols_idx):
                if row in assignedRows or col in assignedCols:
                    continue  # track or detection already matched

                objectID = objectIDs[row]
                self.objects[objectID] = new_object_locations[col]
                self.lost[objectID] = 0  # seen this frame; reset miss counter

                assignedRows.add(row)
                assignedCols.add(col)

            unassignedRows = set(range(0, D.shape[0])).difference(assignedRows)
            unassignedCols = set(range(0, D.shape[1])).difference(assignedCols)

            if D.shape[0] >= D.shape[1]:
                # At least as many tracks as detections: age unmatched tracks.
                for row in unassignedRows:
                    objectID = objectIDs[row]
                    self.lost[objectID] += 1

                    if self.lost[objectID] > self.maxLost:
                        self.removeObject(objectID)
            else:
                # More detections than tracks: the extras become new tracks.
                for col in unassignedCols:
                    self.addObject(new_object_locations[col])

        return self.objects
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"#### Loading Object Detector Model" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"##### Tensorflow model for Object Detection and Tracking\n", | ||
"\n", | ||
"Here, the SSD Object Detection Model is used.\n", | ||
"\n", | ||
"For more details about single shot detection (SSD), refer the following:\n", | ||
" - **Liu, W., Anguelov, D., Erhan, D., Szegedy, C., Reed, S., Fu, C. Y., & Berg, A. C. (2016, October). Ssd: Single shot multibox detector. In European conference on computer vision (pp. 21-37). Springer, Cham.**\n", | ||
" - Research paper link: https://arxiv.org/abs/1512.02325\n", | ||
" - The pretrained model: https://github.com/opencv/opencv/wiki/TensorFlow-Object-Detection-API#use-existing-config-file-for-your-model" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
# COCO label map shipped with the SSD MobileNet v2 (2018_03_29) checkpoint.
# Note the gaps in the ids (12, 26, 29, 30, ...): those COCO ids are unused.
coco_object_names = {0: 'background',
                     1: 'person', 2: 'bicycle', 3: 'car', 4: 'motorcycle', 5: 'airplane', 6: 'bus',
                     7: 'train', 8: 'truck', 9: 'boat', 10: 'traffic light', 11: 'fire hydrant',
                     13: 'stop sign', 14: 'parking meter', 15: 'bench', 16: 'bird', 17: 'cat',
                     18: 'dog', 19: 'horse', 20: 'sheep', 21: 'cow', 22: 'elephant', 23: 'bear',
                     24: 'zebra', 25: 'giraffe', 27: 'backpack', 28: 'umbrella', 31: 'handbag',
                     32: 'tie', 33: 'suitcase', 34: 'frisbee', 35: 'skis', 36: 'snowboard',
                     37: 'sports ball', 38: 'kite', 39: 'baseball bat', 40: 'baseball glove',
                     41: 'skateboard', 42: 'surfboard', 43: 'tennis racket', 44: 'bottle',
                     46: 'wine glass', 47: 'cup', 48: 'fork', 49: 'knife', 50: 'spoon',
                     51: 'bowl', 52: 'banana', 53: 'apple', 54: 'sandwich', 55: 'orange',
                     56: 'broccoli', 57: 'carrot', 58: 'hot dog', 59: 'pizza', 60: 'donut',
                     61: 'cake', 62: 'chair', 63: 'couch', 64: 'potted plant', 65: 'bed',
                     67: 'dining table', 70: 'toilet', 72: 'tv', 73: 'laptop', 74: 'mouse',
                     75: 'remote', 76: 'keyboard', 77: 'cell phone', 78: 'microwave', 79: 'oven',
                     80: 'toaster', 81: 'sink', 82: 'refrigerator', 84: 'book', 85: 'clock',
                     86: 'vase', 87: 'scissors', 88: 'teddy bear', 89: 'hair drier', 90: 'toothbrush'}

# Model files (see get_models.sh), label map, and the two detection thresholds:
# `confidence_threshold` filters weak detections, `threshold` is the NMS overlap.
model_info = {
    "config_path": "./tensorflow_model_dir/ssd_mobilenet_v2_coco_2018_03_29.pbtxt",
    "model_weights_path": "./tensorflow_model_dir/ssd_mobilenet_v2_coco_2018_03_29/frozen_inference_graph.pb",
    "object_names": coco_object_names,
    "confidence_threshold": 0.5,
    "threshold": 0.4,
}

# Load the frozen graph with its OpenCV config into a cv.dnn network.
net = cv.dnn.readNetFromTensorflow(model_info["model_weights_path"], model_info["config_path"])
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"metadata": { | ||
"scrolled": true | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
# Fixed seed so each class gets the same box colour on every run.
np.random.seed(12345)

# One random RGB triplet (as a plain list) per known class id.
bbox_colors = {}
for key in model_info['object_names'].keys():
    bbox_colors[key] = np.random.randint(0, 255, size=(3,)).tolist()
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"##### Instantiate the Tracker Class" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 5, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
# Drop a track after it has gone undetected for this many consecutive frames.
maxLost = 5
tracker = Tracker(maxLost=maxLost)
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"##### Initiate opencv video capture object\n", | ||
"\n", | ||
"The `video_src` can take two values:\n", | ||
"1. If `video_src=0`: OpenCV accesses the camera connected through USB\n", | ||
"2. If `video_src='video_file_path'`: OpenCV will access the video file at the given path (can be MP4, AVI, etc format)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 6, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
# Video source: a file path, or 0 to capture from the first attached camera.
video_src = "./data/video_test5.mp4"
cap = cv.VideoCapture(video_src)
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"##### Start object detection and tracking" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 7, | ||
"metadata": { | ||
"scrolled": false | ||
}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"Cannot read the video feed.\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
(H, W) = (None, None)  # input frame height/width, read from the first frame
writer = None          # VideoWriter is created lazily once the frame size is known

while True:
    ok, image = cap.read()

    if not ok:
        print("Cannot read the video feed.")
        break

    if W is None or H is None:
        (H, W) = image.shape[:2]

    # Run the SSD network on a 300x300 blob of the frame (BGR -> RGB swap).
    blob = cv.dnn.blobFromImage(image, size=(300, 300), swapRB=True, crop=False)
    net.setInput(blob)
    detections = net.forward()

    detections_bbox = []  # boxes (after NMS) handed to the tracker

    boxes, confidences, classIDs = [], [], []

    # Each detection row is indexed as [.., class_id, confidence, l, t, r, b];
    # coordinates are scaled by [W, H, W, H] below, i.e. they arrive normalized.
    for detection in detections[0, 0, :, :]:
        classID = detection[1]
        confidence = detection[2]

        if confidence > model_info['confidence_threshold']:
            box = detection[3:7] * np.array([W, H, W, H])

            (left, top, right, bottom) = box.astype("int")
            width = right - left + 1
            height = bottom - top + 1

            boxes.append([int(left), int(top), int(width), int(height)])
            confidences.append(float(confidence))
            classIDs.append(int(classID))

    # Non-maximum suppression to drop overlapping duplicate detections.
    indices = cv.dnn.NMSBoxes(boxes, confidences, model_info["confidence_threshold"], model_info["threshold"])

    if len(indices) > 0:
        for i in indices.flatten():
            x, y, w, h = boxes[i][0], boxes[i][1], boxes[i][2], boxes[i][3]

            detections_bbox.append((x, y, x + w, y + h))

            # Draw the detection box plus a white banner with "name:confidence".
            clr = [int(c) for c in bbox_colors[classIDs[i]]]
            cv.rectangle(image, (x, y), (x + w, y + h), clr, 2)

            label = "{}:{:.4f}".format(model_info["object_names"][classIDs[i]], confidences[i])
            (label_width, label_height), baseLine = cv.getTextSize(label, cv.FONT_HERSHEY_SIMPLEX, 0.5, 2)
            y_label = max(y, label_height)  # keep the banner inside the frame at the top edge
            cv.rectangle(image, (x, y_label - label_height),
                         (x + label_width, y_label + baseLine), (255, 255, 255), cv.FILLED)
            cv.putText(image, label, (x, y_label), cv.FONT_HERSHEY_SIMPLEX, 0.5, clr, 2)

    # Update the tracker with this frame's detections and draw each object's
    # persistent ID at its centroid.
    objects = tracker.update(detections_bbox)

    for (objectID, centroid) in objects.items():
        text = "ID {}".format(objectID)
        cv.putText(image, text, (centroid[0] - 10, centroid[1] - 10), cv.FONT_HERSHEY_SIMPLEX,
                   0.5, (0, 255, 0), 2)
        cv.circle(image, (centroid[0], centroid[1]), 4, (0, 255, 0), -1)

    cv.imshow("image", image)

    if cv.waitKey(1) & 0xFF == ord('q'):
        break

    if writer is None:
        fourcc = cv.VideoWriter_fourcc(*"MJPG")
        writer = cv.VideoWriter("output.avi", fourcc, 30, (W, H), True)
    writer.write(image)

# BUG FIX: if the very first cap.read() fails (as the recorded output shows),
# `writer` is still None and the original unconditional writer.release()
# raised AttributeError. Release it only when it was actually created.
if writer is not None:
    writer.release()
cap.release()
# destroyAllWindows() is safe even when no window was ever shown, whereas
# destroyWindow("image") errors if the loop exited before the first imshow.
cv.destroyAllWindows()
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 8, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "drlnd", | ||
"language": "python", | ||
"name": "drlnd" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.6.8" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |