<a href="https://colab.research.google.com/github/DavoodSZ1993/Dive_into_Deep_Learning/blob/main/14_4_anchor_boxes.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [5]:
!pip install d2l==1.0.0-alpha1.post0 --quiet

## 14.4 Anchor Boxes

In [6]:
%matplotlib inline
import torch
from d2l import torch as d2l

torch.set_printoptions(2)

### 14.4.1 Generating Multiple Anchor Boxes

In [18]:
def multibox_prior(data, sizes, ratios):
  """Generate anchor boxes of several shapes centered on every pixel.

  All entries of `sizes` are combined with the first ratio, and the first size
  is combined with each remaining ratio, giving
  `len(sizes) + len(ratios) - 1` boxes per pixel.

  Args:
    data: Tensor whose last two dimensions are (height, width). Only its
      shape and device are read.
    sizes: Sequence of anchor scales.
    ratios: Sequence of anchor aspect ratios.

  Returns:
    Tensor of shape (1, height * width * boxes_per_pixel, 4). Each row is
    (xmin, ymin, xmax, ymax); pixel centers are expressed as fractions of the
    input width/height, so coordinates are in normalized units.
  """
  height, width = data.shape[-2:]
  dev = data.device
  anchors_per_pixel = len(sizes) + len(ratios) - 1
  scales = torch.tensor(sizes, device=dev)
  aspects = torch.tensor(ratios, device=dev)

  # Pixel centers in normalized coordinates: (index + 0.5) * (1 / extent).
  half_pixel = 0.5
  ys = (torch.arange(height, device=dev) + half_pixel) * (1.0 / height)
  xs = (torch.arange(width, device=dev) + half_pixel) * (1.0 / width)
  grid_y, grid_x = torch.meshgrid(ys, xs, indexing='ij')
  cy = grid_y.reshape(-1)
  cx = grid_x.reshape(-1)

  # Box widths/heights: every size with the first ratio, then the first size
  # with the remaining ratios. Widths are rescaled by height / width.
  first_root = torch.sqrt(aspects[0])
  other_roots = torch.sqrt(aspects[1:])
  widths = torch.cat((scales * first_root,
                      sizes[0] * other_roots)) * height / width
  heights = torch.cat((scales / first_root,
                       sizes[0] / other_roots))

  # Half-extent offsets (-w/2, -h/2, +w/2, +h/2), tiled once per pixel.
  corner_offsets = torch.stack((-widths, -heights, widths, heights), dim=1)
  corner_offsets = corner_offsets.repeat(height * width, 1) / 2

  # Each pixel center is repeated consecutively for all of its anchor boxes.
  centers = torch.stack((cx, cy, cx, cy), dim=1)
  centers = centers.repeat_interleave(anchors_per_pixel, dim=0)
  return (centers + corner_offsets).unsqueeze(0)

In [19]:
from PIL import Image
import torchvision.transforms as T


# Preprocessing pipeline: resize to a fixed 561 x 728 (height x width), then
# convert the image to a tensor.
aug = T.Compose([
    T.Resize((561, 728)),
    T.ToTensor(),
])
img = aug(Image.open('catdog.jpg'))
# Skip the leading dimension; the remaining two are used as height and width.
h, w = img.shape[1:]

In [20]:
print(h, w)
# Construct input data: one random 3-channel image of size h x w.
X = torch.rand((1, 3, h, w))
Y = multibox_prior(
    X,
    sizes=[0.75, 0.5, 0.25],
    ratios=[1, 2, 0.5],
)
Y.shape

561 728


torch.Size([1, 2042040, 4])

In [21]:
# Unfold the flat anchor list into (height, width, boxes per pixel, 4).
# 5 = len(sizes) + len(ratios) - 1 = 3 + 3 - 1 for the call that produced Y.
boxes = Y.reshape(h, w, 5, 4)
# First anchor box of the pixel at row 250, column 250.
boxes[250, 250, 0]

tensor([0.06, 0.07, 0.63, 0.82])

In [22]:
def show_bboxes(axes, bboxes, labels=None, colors=None):
  """Draw bounding boxes, optionally labeled, on the given axes.

  Args:
    axes: Object providing `add_patch` and `text` (e.g. a matplotlib Axes).
    bboxes: Iterable of tensors, one per box, on any device. Each is passed to
      `d2l.bbox_to_rect` as a NumPy array -- presumably in
      (xmin, ymin, xmax, ymax) order; confirm against that helper.
    labels: Optional label or list/tuple of labels. Box `i` is labeled only
      when `labels[i]` exists.
    colors: Optional color or list/tuple of colors, cycled over the boxes.
      Defaults to ['b', 'g', 'r', 'm', 'c'].
  """

  def make_list(obj, default_values=None):
    """Return `default_values` for None; wrap a non-list/tuple in a list."""
    if obj is None:
      obj = default_values
    elif not isinstance(obj, (list, tuple)):
      obj = [obj]
    return obj

  labels = make_list(labels)
  colors = make_list(colors, ['b', 'g', 'r', 'm', 'c'])
  for i, bbox in enumerate(bboxes):
    color = colors[i % len(colors)]
    # Move to host memory before converting: .numpy() raises for tensors on a
    # non-CPU device, while .cpu() leaves CPU tensors untouched.
    rect = d2l.bbox_to_rect(bbox.detach().cpu().numpy(), color)
    axes.add_patch(rect)
    if labels and len(labels) > i:
      # Black text on white boxes, white text on every other color.
      text_color = 'k' if color == 'w' else 'w'
      axes.text(rect.xy[0], rect.xy[1], labels[i],
                va='center', ha='center', fontsize=9, color=text_color,
                bbox=dict(facecolor=color, lw=0))

