In [None]:
# Automatically reload edited modules (e.g. the local `src` package) before each cell executes.
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import json
import tensorflow as tf
import matplotlib.pyplot as plt
%matplotlib inline

from src import SSD, AnchorGenerator, FeatureExtractor
from src.backbones import mobilenet_v1_base
# from src.anchor_generator import tile_anchors

# Generate anchors

In [None]:
tf.reset_default_graph()

# Load the experiment configuration. The `with` block closes the file handle,
# which `json.load(open(...))` left open until garbage collection.
with open('config.json') as config_file:
    config = json.load(config_file)
input_pipeline_params = config['input_pipeline_params']
# Keep the full config under `config`; `params` is only ever the model sub-config,
# which the cells below read the anchor settings from.
params = config['model_params']

# `image_size` is unpacked as (width, height).
width, height = input_pipeline_params['image_size']
# Image batch placeholder, channels-first: [batch, 3, height, width].
images = tf.placeholder(tf.float32, [None, 3, height, width])

is_training = False  # training flag handed to the backbone through FeatureExtractor below

def backbone(images, is_training):
    """Run the MobileNet v1 base network on `images`.

    Thin wrapper that pins `min_depth=8` and `depth_multiplier=1.0`, so the
    remaining arguments are just the image tensor and the training flag; it is
    passed to `FeatureExtractor` in the cell below.
    """
    base_options = dict(min_depth=8, depth_multiplier=1.0)
    return mobilenet_v1_base(images, is_training, **base_options)
    
feature_extractor = FeatureExtractor(backbone, is_training)

# Forward the anchor hyper-parameters from the model config, using the config
# keys directly as AnchorGenerator keyword names.
anchor_generator = AnchorGenerator(**{
    key: params[key]
    for key in (
        'min_scale',
        'max_scale',
        'aspect_ratios',
        'interpolated_scale_aspect_ratio',
        'reduce_boxes_in_lowest_layer',
    )
})

# Build the graph: backbone feature maps from the image placeholder, then anchors
# from those feature maps and the placeholder.
feature_maps = feature_extractor(images)
anchors = anchor_generator(feature_maps, images)
# Read after the call above; presumably the generator records it there — confirm.
anchor_grid_list = anchor_generator.anchor_grid_list


# more_anchors = tile_anchors(
#     width/height,
#     grid_height=12, grid_width=20, scales=(0.2,) * 5,
#     aspect_ratios=(1.0, 0.6, 0.4, 0.3333, 0.2), 
#     anchor_stride=(1/12, 1/20), 
#     anchor_offset=(1/24, 1/40)
# )

In [None]:
# Output of feature_extractor(images): symbolic graph values, nothing is evaluated until a Session runs.
feature_maps

In [None]:
# Output of anchor_generator(feature_maps, images); symbolic until evaluated in a Session.
anchors

In [None]:
# NOTE(review): presumably the number of anchor shapes per grid location for each feature map — confirm in AnchorGenerator.
anchor_generator.num_basis_anchors

In [None]:
# NOTE(review): presumably an anchor count for each feature map — confirm whether it is per location or in total.
num_anchors_per_feature_map = anchor_generator.num_anchors_per_feature_map
num_anchors_per_feature_map

In [None]:
# Spatial shape of each feature map as recorded by the generator (presumably (height, width) — confirm).
anchor_generator.feature_map_shape_list

In [None]:
# Per-feature-map anchor grids exposed by the generator; evaluated to NumPy in the Session cell below.
anchor_grid_list

In [None]:
# Evaluate the anchor grid tensors to NumPy arrays (one per feature map).
# No feed_dict is passed, so this assumes the anchor grids do not depend on the runtime
# value of the `images` placeholder — TODO confirm if this cell ever raises for an unfed placeholder.
with tf.Session() as sess:
    anchor_boxes = sess.run(anchor_grid_list)

In [None]:
# Select the anchor grid of the fourth feature map (index 3 in anchor_grid_list order);
# the cells below reshape and plot it.
more_anchor_boxes = anchor_boxes[3]

In [None]:
# Arrange this feature map's anchors as (grid_height, grid_width, anchors_per_location, 4),
# assuming they are stored row-major over (y, x, anchor). That is the layout the plotting
# cells below rely on. The numbers are tied to the current config (image size, aspect ratios).
# TODO: derive them from anchor_generator.feature_map_shape_list / num_anchors_per_feature_map.
ANCHOR_GRID_SHAPE = (6, 10, 6, 4)
more_anchor_boxes = more_anchor_boxes.reshape(ANCHOR_GRID_SHAPE)

# Sanity-check the assumed layout: the boxes are (ymin, xmin, ymax, xmax), and every anchor
# in one grid cell should share a center, because the next cells use anchor 0 of each cell
# as that cell's center.
_anchor_centers = 0.5 * (more_anchor_boxes[..., :2] + more_anchor_boxes[..., 2:])
assert np.allclose(_anchor_centers, _anchor_centers[:, :, :1, :]), \
    'anchors in a grid cell do not share a center: check ANCHOR_GRID_SHAPE / layout'

# Show non-clipped anchors

In [None]:
# Split the normalized [ymin, xmin, ymax, xmax] boxes into four coordinate planes,
# one per entry of the last axis.
ymin, xmin, ymax, xmax = (more_anchor_boxes[:, :, :, k] for k in range(4))

# Convert to pixel units: box heights/widths, then box centers.
h = height * (ymax - ymin)
w = width * (xmax - xmin)
cy = height * ymin + h / 2
cx = width * xmin + w / 2

# Pack along a new trailing axis: centers as (cy, cx), sizes as (h, w).
centers = np.stack((cy, cx), axis=-1)
anchor_sizes = np.stack((h, w), axis=-1)

In [None]:
# Plot all anchor centers of the selected feature map (red), highlight one grid
# location (green) and outline every anchor box at that location.
highlight_idx = 1  # index into `unique_centers` of the location whose anchors are drawn

fig, ax = plt.subplots(dpi=100, figsize=(5 * width / height, 5))

# Anchor 0 of each grid cell stands in for the cell's center (all anchors of a cell
# are assumed to share one center); result is (num_cells, 2) as (cy, cx).
unique_centers = centers[:, :, 0, :].reshape(-1, 2)
# Anchor shapes are assumed identical at every cell, so cell (0, 0) provides them:
# (anchors_per_location, 2) as (h, w).
unique_sizes = anchor_sizes[0, 0, :, :]

ax.plot(unique_centers[:, 1], unique_centers[:, 0],
        linestyle='none', marker='o', markersize=3, color='r')

# Local names are prefixed so the previous cell's cy/cx/h/w arrays are not overwritten.
center_y, center_x = unique_centers[highlight_idx]
ax.plot([center_x], [center_y], marker='o', markersize=3, color='g')
for box_h, box_w in unique_sizes:
    ax.add_patch(plt.Rectangle(
        (center_x - 0.5 * box_w, center_y - 0.5 * box_h), box_w, box_h,
        linewidth=1.0, edgecolor='k', facecolor='none',
    ))

# Keep the view on the image frame; boxes are not clipped, so parts may fall outside it.
ax.set_xlim(0, width)
# Pixel rows grow downwards in image coordinates, so put y = 0 at the top
# (the original ylim [0, height] drew the anchor layout upside down).
ax.set_ylim(height, 0)
# Equal scaling on both axes so anchor boxes keep their true aspect ratios.
ax.set_aspect('equal')
ax.set(title='Non-clipped anchors at one grid location',
       xlabel='x (pixels)', ylabel='y (pixels)');