Skip to content

Commit

Permalink
Generating bounding boxes on image with labels!
Browse files Browse the repository at this point in the history
  • Loading branch information
Brad Neuberg committed Oct 28, 2015
1 parent 8fd4a54 commit 029ae35
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 29 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Expand Up @@ -7,4 +7,4 @@ data/leveldb
data/landsat/images
data/planetlab/images
data/planetlab/metadata
src/cloudless/inference/regions
src/cloudless/inference/bbox-regions
1 change: 1 addition & 0 deletions requirements.txt
Expand Up @@ -2,3 +2,4 @@ numpy>=1.8.0
plyvel>=0.9
scikit-learn>=0.15.2
matplotlib>=1.3.1
simplejson>=3.8.1
9 changes: 3 additions & 6 deletions src/cloudless/inference/README.md
Expand Up @@ -14,19 +14,16 @@ There are two scripts:

## Steps
- Set env vars CAFFE_HOME and SELECTIVE_SEARCH
- Remove argmax layer from prototxt
```
cd src/cloudless/inference
rm -fr regions/
./localization.py -i cloud_test.jpg --classes cloud-classes.txt --config ../../caffe_model/bvlc_alexnet/bounding_box.prototxt --weights ../../caffe_model/bvlc_alexnet/bvlc_alexnet_finetuned.caffemodel
```

Testing against imagenet (not cloudless) for debugging:
```
cd src/cloudless/inference
rm -fr regions/
./localization.py -i cat.jpg --classes imagenet-classes.txt --config ../../caffe_model/bvlc_alexnet/bounding_box_imagenet.prototxt --weights ../../caffe_model/bvlc_alexnet/bvlc_alexnet.caffemodel
./localization.py -i cat.jpg --classes imagenet-classes.txt --config ../../caffe_model/bvlc_alexnet/bounding_box_imagenet.prototxt --weights ../../caffe_model/bvlc_alexnet/bvlc_alexnet.caffemodel --ks 125 --max_regions 4
open cat-regions.png
```

## TODO
- Actually use generated results to determine segmentation or bounding boxes
This will write out the image with bounding boxes drawn on it, including a JSON file with machine readable info on the top bounding boxes.
112 changes: 90 additions & 22 deletions src/cloudless/inference/localization.py
@@ -1,10 +1,14 @@
#!/usr/bin/env python
import os
import shutil
import sys
import argparse
import time
import re
import random
from decimal import Decimal
from operator import itemgetter
from PIL import Image, ImageDraw, ImageFont

sys.path.append(os.environ.get('SELECTIVE_SEARCH'))

Expand All @@ -15,9 +19,11 @@
os.environ['GLOG_minloglevel'] = '1'

from selective_search import *
import features
from skimage.transform import resize
import caffe
import numpy as np
import simplejson as json

def parse_command_line():
parser = argparse.ArgumentParser(
Expand All @@ -28,12 +34,6 @@ def parse_command_line():
help="input image",
default='cat.jpg'
)
parser.add_argument(
"-o",
"--output",
help="output image with bounding boxes",
default='cat-regions.jpg'
)
parser.add_argument(
"-m",
"--dimension",
Expand Down Expand Up @@ -73,7 +73,7 @@ def parse_command_line():
)
parser.add_argument(
"-r",
"--regions",
"--max_regions",
help="(optional) maximum number of bounding box regions to choose",
type=int,
default=3
Expand All @@ -83,7 +83,7 @@ def parse_command_line():
"--threshold",
help="(optional) percentage threshold of confidence necessary for a bounding box to be included",
type=float,
default=13.0
default=10.0
)
parser.add_argument(
"-D",
Expand All @@ -92,6 +92,20 @@ def parse_command_line():
action="store_true",
default=True
)
parser.add_argument(
"-k",
"--ks",
help="value for the ks argument controlling selective search region formation",
type=int,
default=100
)
parser.add_argument(
"--only_for_class",
help="""only draw bounding boxes for regions that match some class; draws bounding boxes
for any class found if not given""",
type=int,
default=None
)

args = parser.parse_args()

Expand All @@ -107,21 +121,22 @@ def parse_command_line():

return args

# Choose X number of regions that match some threshold.
# Take the original image and draw bounding boxes on them; add labels to the bounding boxes

def gen_regions(image, dims, pad):
def gen_regions(image, dims, pad, ks):
"""
Generates candidate regions for object detection using selective search.
"""

print "Generating cropped regions..."
assert(len(dims) == 3)
img = skimage.io.imread(image)
regions = selective_search(img, ks=[300])
regions = selective_search(image, ks=[ks], feature_masks=[features.SimilarityMask(
size=1,
color=1,
texture=1,
fill=1,
)])

crops = []
for conf, (x0, y0, x1, y1) in regions:
for conf, (y0, x0, y1, x1) in regions:
if x0 - pad >= 0:
x0 = x0 - pad
if y0 - pad >= 0:
Expand All @@ -130,7 +145,8 @@ def gen_regions(image, dims, pad):
x1 = x1 + pad
if y1 + pad <= dims[0]:
y1 = y1 + pad
region = img[x0:x1, y0:y1, :]
# Images are rows, then columns, then channels.
region = image[y0:y1, x0:x1, :]
candidate = resize(region, dims)
crops.append((conf, candidate, region, (x0, y0, x1, y1)))

Expand All @@ -140,18 +156,18 @@ def gen_regions(image, dims, pad):

def get_region_filename(idx):
""" Generates a region filename. """
return "regions/%s.jpg" % idx
return "bbox-regions/%s.jpg" % idx

def dump_regions(crops):
""" Writes out region proposals to the disk in regions/ for debugging. """
if not os.path.exists("regions"):
os.makedirs("regions")
shutil.rmtree("bbox-regions", ignore_errors=True)
os.makedirs("bbox-regions")

for idx, img in enumerate(crops):
fname = get_region_filename(idx)
skimage.io.imsave(fname, img[2])

print "Wrote regions out to disk in regions/"
print "Wrote regions out to disk in bbox-regions/"

def classify(images, config, weights):
""" Classifies our region proposals. """
Expand Down Expand Up @@ -190,19 +206,28 @@ def sort_predictions(classes, predictions, bboxes):
results = []
for idx, pred in enumerate(predictions):
results.append({
"class_idx": np.argmax(pred),
"class": classes[np.argmax(pred)],
"prob": pred[np.argmax(pred)],
"fname": get_region_filename(idx),
"coords": bboxes[idx],
})
results.sort(key=itemgetter("prob"), reverse=True)

print_predictions(classes, results)
return results

def filter_predictions(predictions, max_regions, threshold):
"""
Filters predictions down to just those that are above or equal to a certain threshold, with
a max number of results controlled by 'max_regions'.
"""
results = [entry for entry in predictions if entry["prob"] >= threshold]
results = results[0:max_regions]
return results

def print_predictions(classes, predictions):
""" Prints out the predictions for debugging. """
print "Top predictions:"
for idx, pred in enumerate(predictions):
print("prob: {}, class: {}, file: {}, coords: {}".format(
predictions[idx]["prob"],
Expand All @@ -211,9 +236,48 @@ def print_predictions(classes, predictions):
predictions[idx]["coords"],
))

def draw_bounding_boxes(image_path, image, classes, predictions, only_for_class=None):
image = Image.fromarray(numpy.uint8(image))
dr = ImageDraw.Draw(image)

colors = {}
for idx, pred in enumerate(predictions):
x0, y0, x1, y1 = pred["coords"]

color = "red"
# If we want to display multiple classes, randomly generate a color for it.
if not only_for_class:
class_idx = pred["class_idx"]
color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
if class_idx in colors:
color = colors[class_idx]
colors[class_idx] = color

dr.rectangle(((x0, y0), (x1, y1)), fill=None, outline=color)
dr.text((x0, y0), pred["class"], fill=color)

filename = os.path.splitext(image_path)[0] + "-regions.png"
image.save(filename)

print "Image with drawn bounding boxes saved to %s" % filename

def dump_bounding_box_info(image_path, predictions):
""" Writes out our top predictions to a JSON file for other tools to work with. """
filename = os.path.splitext(image_path)[0] + "-regions.json"
# Make sure we can serialize our Python float values.
for entry in predictions:
entry["prob"] = Decimal("%.7g" % entry["prob"])

with open(filename, "w") as f:
f.write(json.dumps(predictions, use_decimal=True, indent=4, separators=(',', ': ')))

print "Bounding box info saved as JSON to %s" % filename

def main(argv):
args = parse_command_line()
crops = gen_regions(args.image, args.dimension, args.pad)
image = skimage.io.imread(args.image)

crops = gen_regions(image, args.dimension, args.pad, args.ks)

if args.dump_regions:
dump_regions(crops)
Expand All @@ -224,7 +288,11 @@ def main(argv):

bboxes = [entry[3] for entry in crops]
predictions = sort_predictions(classes, predictions, bboxes)
predictions = filter_predictions(predictions, args.max_regions, args.threshold)
print_predictions(classes, predictions)

draw_bounding_boxes(args.image, image, classes, predictions, args.only_for_class)
dump_bounding_box_info(args.image, predictions)

if __name__ == '__main__':
main(sys.argv)

0 comments on commit 029ae35

Please sign in to comment.