In [None]:
import torch 
import numpy as np
import torch.nn as nn

In [None]:
def sort_anchors(anchors):
  """Return the rows of `anchors` ordered by aspect ratio, then by area.

  Args:
    anchors: NumPy array of shape (num_anchors, 4) whose columns are read as
      (x1, y1, x2, y2).

  Returns:
    A copy of `anchors` with rows sorted by the height/width ratio (rounded
    to one decimal) and, within equal ratios, by area.
  """
  widths = anchors[:, 2] - anchors[:, 0]
  heights = anchors[:, 3] - anchors[:, 1]
  # np.lexsort treats the LAST key as the primary one: ratio first, area as tie-breaker.
  sort_keys = (widths * heights, np.round(heights / widths, 1))
  order = np.lexsort(sort_keys)
  return anchors[order]

def generate_anchors_reference(base_size,scales,ratios):
  """Build the reference anchors centred on the origin.

  One anchor is produced for every (scale, ratio) combination.

  Args:
    base_size: Side length of the unit anchor, multiplied by each scale.
    scales: Sequence of multiplicative scales.
    ratios: Sequence of height/width aspect ratios.

  Returns:
    NumPy array of shape (len(scales) * len(ratios), 4) with columns
    (x1, y1, x2, y2), ordered by `sort_anchors`.
  """
  scales_flat = np.asarray(scales).reshape(-1)
  ratios_flat = np.asarray(ratios).reshape(-1)

  # 1. Enumerate every (scale, ratio) pair; scales vary fastest.
  base_scales = np.tile(scales_flat, ratios_flat.size)
  base_ratios = np.repeat(ratios_flat, scales_flat.size)

  # 2. height/width == ratio while height*width == (scale * base_size)**2.
  ratio_sqrt = np.sqrt(base_ratios)
  height = base_scales * ratio_sqrt * base_size
  width = base_scales / ratio_sqrt * base_size

  # 3. Corner coordinates around a centre at the origin.
  center_xy = 0
  half_w, half_h = width / 2, height / 2
  anchors = np.column_stack([
      center_xy - half_w,
      center_xy - half_h,
      center_xy + half_w,
      center_xy + half_h,
  ])

  # 4. Canonical ordering.
  return sort_anchors(anchors)

In [None]:
# generate_anchors_reference takes (base_size, scales, ratios); keyword
# arguments keep the scales and the aspect ratios from being swapped.
anchors_ref = generate_anchors_reference(
    256,  # Base size.
    scales=[0.125, 0.25, 0.5, 1, 2],  # Scales.
    ratios=[0.5, 1, 2],  # Aspect ratios.
    )

In [None]:
# Base anchor size; generate_anchors also uses it as the multiplier that turns
# feature-map indices into x/y shifts.
ANCHOR_BASE_SIZE = 16
# Aspect ratios used for the reference anchors.
ANCHOR_RATIOS = [0.5, 1, 2]
# Scales used for the reference anchors.
ANCHOR_SCALES = [0.125, 0.25, 0.5, 1, 2]
# Example feature-map shape; generate_anchors reads index 1 as height and
# index 2 as width.
feature_map_shape=(3,16,16,256)
def generate_anchors(feature_map_shape):
  """Generate every anchor for a feature map by shifting the reference anchors.

  Args:
    feature_map_shape: Sequence whose index 1 is the feature-map height and
      index 2 its width.

  Returns:
    Tensor of shape (height, width, num_anchors_per_point, 4) with
    (x1, y1, x2, y2) boxes. Entry [row, col] holds the reference anchors
    shifted by (col * ANCHOR_BASE_SIZE, row * ANCHOR_BASE_SIZE).
  """
  # 1. Reference anchors. Keyword arguments are used because the helper's
  #    signature is (base_size, scales, ratios); passing ratios and scales
  #    positionally in the other order silently swaps them.
  anchor_reference = generate_anchors_reference(
      ANCHOR_BASE_SIZE, scales=ANCHOR_SCALES, ratios=ANCHOR_RATIOS)
  num_anchors_per_points = anchor_reference.shape[0]

  # 2. Spatial size of the feature map.
  height = feature_map_shape[1]
  width = feature_map_shape[2]

  # 3. Shift of each column / row.
  col_shifts = torch.arange(0, width) * ANCHOR_BASE_SIZE
  row_shifts = torch.arange(0, height) * ANCHOR_BASE_SIZE

  # 4. Flattened grid in row-major order (x varies fastest) so that the final
  #    reshape to (height, width, ...) puts each shift at its own location.
  shift_x = col_shifts.repeat(height)
  shift_y = row_shifts.repeat_interleave(width)

  # 5. One (x, y, x, y) row per location. Stacking along dim=1 gives
  #    (num_shifts, 4) directly; stacking along dim=0 and reshaping would
  #    scramble the coordinates.
  shifts_xy = torch.stack([shift_x, shift_y, shift_x, shift_y], dim=1)
  num_shifts = shifts_xy.shape[0]

  # 6. Broadcast: (1, A, 4) + (num_shifts, 1, 4) -> (num_shifts, A, 4).
  reference = torch.tensor(anchor_reference).reshape(1, num_anchors_per_points, 4)
  all_anchors = reference + shifts_xy.reshape(num_shifts, 1, 4)

  return all_anchors.reshape(height, width, num_anchors_per_points, 4)

In [None]:
def get_width_upright(bboxes): #bboxes: (num_bboxes,4)
  """Convert corner boxes to (width, height, center_x, center_y).

  Uses the inclusive-pixel convention, so a box spanning x1..x2 has width
  x2 - x1 + 1.

  Args:
    bboxes: Array/tensor of shape (num_bboxes, 4) with columns (x1, y1, x2, y2).

  Returns:
    Tuple (width, height, ctx, cty), each with one entry per box.
  """
  left, top = bboxes[:, 0], bboxes[:, 1]
  right, bottom = bboxes[:, 2], bboxes[:, 3]

  width = right - left + 1
  height = bottom - top + 1.

  ctx = left + width * .5
  cty = top + height * .5

  return width, height, ctx, cty

In [None]:
# Encoding `bbox` with respect to an anchor having the same center
# should keep the first two deltas at zero.

def encode(anchors, bboxes):
  """Encode `bboxes` as regression deltas relative to `anchors`.

  Args:
    anchors: Tensor of shape (N, 4) (or broadcastable) with (x1, y1, x2, y2).
    bboxes: Tensor of shape (N, 4) with (x1, y1, x2, y2).

  Returns:
    Tensor of shape (N, 4) with columns (dx, dy, dw, dh): centre offsets
    expressed in anchor widths/heights and log size ratios.
  """
  anchor_w, anchor_h, anchor_cx, anchor_cy = get_width_upright(anchors)
  box_w, box_h, box_cx, box_cy = get_width_upright(bboxes)

  return torch.stack(
      [
          (box_cx - anchor_cx) / anchor_w,  # dx
          (box_cy - anchor_cy) / anchor_h,  # dy
          torch.log(box_w / anchor_w),      # dw
          torch.log(box_h / anchor_h),      # dh
      ],
      dim=1,
  )

In [None]:
# The box `b` is centred inside the anchor `a`, so dx and dy should be zero.
a = torch.tensor([[0, 0, 100, 100]], dtype=torch.float32)
b = torch.tensor([[25, 25, 75, 75]], dtype=torch.float32)
same_center_deltas = encode(a, b)
print('With same center, first two deltas should be zero:\n', same_center_deltas)
print(same_center_deltas.shape)

With same center, first two deltas should be zero:
 tensor([[ 0.0000,  0.0000, -0.6833, -0.6833]])
torch.Size([1, 4])


In [None]:
def decode(anchors, deltas):
  """Apply regression deltas to anchors; inverse of `encode`.

  Args:
    anchors: Tensor of shape (N, 4) (or broadcastable) with (x1, y1, x2, y2).
    deltas: Tensor of shape (N, 4) with columns (dx, dy, dw, dh).

  Returns:
    Tensor of shape (N, 4) with the decoded (x1, y1, x2, y2) boxes.
  """
  anchor_w, anchor_h, anchor_cx, anchor_cy = get_width_upright(anchors)

  # Shift the anchor centre and rescale its size.
  center_x = deltas[:, 0] * anchor_w + anchor_cx
  center_y = deltas[:, 1] * anchor_h + anchor_cy
  half_w = 0.5 * (torch.exp(deltas[:, 2]) * anchor_w)
  half_h = 0.5 * (torch.exp(deltas[:, 3]) * anchor_h)

  # The trailing -1 undoes the +1 that get_width_upright adds to sizes.
  corners = [
      center_x - half_w,
      center_y - half_h,
      center_x + half_w - 1.,
      center_y + half_h - 1.,
  ]
  return torch.stack(corners, dim=1)

In [None]:
# Round-trip check: decoding the encoded boxes against the same anchor
# should give back the original boxes.
anchor = torch.tensor([[0, 0, 100, 100]], dtype=torch.float32)
bboxes = torch.tensor(
    [
        [25, 25, 75, 75],
        [10, -205, 120, 20],
        [-35, 37, 38, 100],
        [-0.2, -0.2, 0.2, 0.2],
        [-25, -50, -5, -20],
    ],
    dtype=torch.float32,
)
print(f"ANCHOR SHAPE: {anchor.shape} === BBOX SHAPE: {bboxes.shape}")
round_trip_error = torch.sum(torch.abs(decode(anchor, encode(anchor, bboxes)) - bboxes))
print('Round-trip looks good:', round_trip_error < 1e-3)

ANCHOR SHAPE: torch.Size([1, 4]) === BBOX SHAPE: torch.Size([5, 4])
Round-trip looks good: tensor(True)


In [None]:
class RPN_conv(nn.Module):
  """Convolutional RPN head: a shared 3x3 conv followed by two 1x1 heads.

  Args:
    num_anchors: Number of anchors predicted at each spatial location.
    out: Channels produced by the shared 3x3 convolution.
    inp: Channels of the incoming feature map.
  """

  def __init__(self,num_anchors,out=512,inp=1024):
    super().__init__()
    self.num_anchors = num_anchors
    self.out_ch = out
    self.in_ch = inp

    # Shared 3x3 conv (stride 1, padding 1 keeps the spatial size).
    self.conv = nn.Conv2d(self.in_ch, self.out_ch, kernel_size=3, stride=1, padding=1)
    # Two class scores and four box deltas per anchor.
    self.prob = nn.Conv2d(self.out_ch, num_anchors * 2, kernel_size=1)
    self.delt = nn.Conv2d(self.out_ch, num_anchors * 4, kernel_size=1)
    self.relu = nn.ReLU()

  def forward(self,x):
    """Run the head on a channels-last feature map.

    Args:
      x: Tensor of shape (N, H, W, C) with C == `inp`.

    Returns:
      Tuple (scores, deltas) of shapes (N*H*W*num_anchors, 2) and
      (N*H*W*num_anchors, 4). The scores are raw head outputs; no softmax
      is applied here.
    """
    channels_first = x.permute(0, 3, 1, 2)  # Conv2d expects (N, C, H, W)
    hidden = self.relu(self.conv(channels_first))

    # Back to channels-last so each output row corresponds to one anchor.
    class_map = self.prob(hidden).permute(0, 2, 3, 1)
    delta_map = self.delt(hidden).permute(0, 2, 3, 1)

    out_prob = class_map.reshape(-1, 2)
    out_delta = delta_map.reshape(-1, 4)
    return (out_prob, out_delta)

In [None]:
# Random stand-in for a backbone feature map, laid out as (N, H, W, C).
feature_map=torch.rand((1,232,232,1024))
# Anchors per location follow from the anchor configuration rather than a
# hard-coded 15, so the head stays in sync with ANCHOR_RATIOS / ANCHOR_SCALES.
model=RPN_conv(len(ANCHOR_RATIOS) * len(ANCHOR_SCALES))
prob,delt=model(feature_map)
prob.shape,delt.shape

(torch.Size([807360, 2]), torch.Size([807360, 4]))

In [None]:
# One prediction per spatial location and per (ratio, scale) anchor.
anchors_per_location = len(ANCHOR_RATIOS) * len(ANCHOR_SCALES)
expected_preds = feature_map.shape[1] * feature_map.shape[2] * anchors_per_location

In [None]:
# Both heads must emit exactly one row per anchor.
for head_output in (delt, prob):
  assert head_output.shape[0] == expected_preds , "Numbers don't match"

In [None]:
# Placeholder anchors with one row per predicted delta. The count is taken
# from `delt` instead of being hard-coded to the 232x232x15 case.
anchors=torch.rand((delt.shape[0],4),dtype=torch.float32)
print(f"ANCHOR SHAPE: {anchors.shape} === BBOX SHAPE: {delt.shape}")
proposals = decode(anchors, delt)
# `prob` holds raw two-class scores. Softmax turns them into probabilities so
# that ranking by column 1 accounts for both scores, not just the class-1 logit.
scores=torch.softmax(prob, dim=1)[:,1].reshape(-1)
proposals.shape , scores.shape

ANCHOR SHAPE: torch.Size([807360, 4]) === BBOX SHAPE: torch.Size([807360, 4])


(torch.Size([807360, 4]), torch.Size([807360]))

In [None]:
def keep_top_n(proposals, scores, topn):
  """Keep the `topn` highest-scoring proposals.

  Args:
    proposals: Tensor of shape (num_proposals, 4).
    scores: Tensor of shape (num_proposals,).
    topn: Maximum number of proposals to keep.

  Returns:
    Tuple (proposals, scores). When there are more than `topn` proposals the
    result is sorted by descending score; otherwise the inputs are returned
    unchanged (and unsorted).
  """
  if proposals.shape[0] <= topn:
    return proposals, scores

  best = torch.argsort(scores, descending=True)[:topn]
  return proposals[best], scores[best]
# Keep the 3000 best-scoring proposals; this rebinds `proposals` and `scores`.
proposals, scores = keep_top_n(proposals, scores, topn=3000)
print(f'PROPOSALS SHAPE: {proposals.shape} \n SCORES SHAPE: {scores.shape}')

PROPOSALS SHAPE: torch.Size([3000, 4]) 
 SCORES SHAPE: torch.Size([3000])


In [None]:
def clip_boxes(bboxes,im_shape): #We usually apply this on proposals after decode
  """Clip boxes so every coordinate lies inside the image.

  Args:
    bboxes: Tensor of shape (num_bboxes, 4) with columns (x1, y1, x2, y2).
    im_shape: Image shape as (height, width, ...), the same convention
      `normalize_boxes` uses.

  Returns:
    Tensor of shape (num_bboxes, 4) with x clamped to [0, width - 1] and
    y clamped to [0, height - 1].
  """
  height, width = im_shape[0], im_shape[1]

  # torch.minimum / torch.maximum only accept tensors on both sides;
  # torch.clamp takes scalar bounds directly.
  x1 = torch.clamp(bboxes[:, 0], min=0, max=width - 1)
  y1 = torch.clamp(bboxes[:, 1], min=0, max=height - 1)
  x2 = torch.clamp(bboxes[:, 2], min=0, max=width - 1)
  y2 = torch.clamp(bboxes[:, 3], min=0, max=height - 1)

  return torch.stack([x1, y1, x2, y2], dim=1)

In [None]:
#def filter_proposals(bbox_preds,class_preds): clip_boxes, apply area and prob filters + nms

In [None]:
def normalize_boxes(proposals,im_shape):
  """Scale box coordinates by the image size.

  Args:
    proposals: Tensor of shape (num_proposals, 4) with columns (x1, y1, x2, y2).
    im_shape: Sequence whose index 0 is the image height and index 1 its width.

  Returns:
    Tensor of shape (num_proposals, 4) where x coordinates are divided by the
    width and y coordinates by the height.
  """
  img_h, img_w = im_shape[0], im_shape[1]
  # Divisor for each of the (x1, y1, x2, y2) columns.
  per_column_divisor = (img_w, img_h, img_w, img_h)
  scaled_columns = [
      proposals[:, col] / divisor
      for col, divisor in enumerate(per_column_divisor)
  ]
  return torch.stack(scaled_columns, dim=1)

def roi_crop(proposals,ft_map,im_shape,pooled_width,pooled_height):
  """Crop each proposal from the feature map and pool it to a fixed size.

  `torch.resized_crop` does not exist, so the crop-and-resize step is done
  with bilinear sampling through `nn.functional.grid_sample`. Each box is
  sampled on a (2 * pooled_height, 2 * pooled_width) grid and then reduced
  by a 2x2 max-pool.

  Args:
    proposals: Float tensor of shape (num_proposals, 4) with (x1, y1, x2, y2)
      in image coordinates.
    ft_map: Float feature map of shape (N, H, W, C). As with the all-zero
      batch ids of the original code, every crop is taken from batch
      element 0.
    im_shape: Image shape as (height, width, ...), forwarded to
      `normalize_boxes`.
    pooled_width: Output width of each pooled crop.
    pooled_height: Output height of each pooled crop.

  Returns:
    Tensor of shape (num_proposals, pooled_height, pooled_width, C).
  """
  bboxes = normalize_boxes(proposals, im_shape)
  num_boxes = bboxes.shape[0]
  channels = ft_map.shape[3]

  if num_boxes == 0:
    return ft_map.new_zeros((0, pooled_height, pooled_width, channels))

  crop_h, crop_w = pooled_height * 2, pooled_width * 2

  # grid_sample expects channels-first input; only batch element 0 is used.
  features = ft_map[:1].permute(0, 3, 1, 2)
  bboxes = bboxes.to(features.dtype)
  x1, y1, x2, y2 = bboxes[:, 0], bboxes[:, 1], bboxes[:, 2], bboxes[:, 3]

  # Evenly spaced sample positions across each box, mapped from the
  # normalised [0, 1] range to grid_sample's [-1, 1] range.
  frac_x = torch.linspace(0, 1, crop_w, dtype=features.dtype, device=features.device)
  frac_y = torch.linspace(0, 1, crop_h, dtype=features.dtype, device=features.device)
  xs = (x1[:, None] + (x2 - x1)[:, None] * frac_x[None, :]) * 2 - 1  # (B, crop_w)
  ys = (y1[:, None] + (y2 - y1)[:, None] * frac_y[None, :]) * 2 - 1  # (B, crop_h)

  # Sampling grid of shape (B, crop_h, crop_w, 2); the last axis is (x, y).
  grid = torch.stack(
      [
          xs[:, None, :].expand(num_boxes, crop_h, crop_w),
          ys[:, :, None].expand(num_boxes, crop_h, crop_w),
      ],
      dim=-1,
  )

  # Stack all boxes along the output height so the feature map never has to
  # be replicated once per box.
  sampled = nn.functional.grid_sample(
      features,
      grid.reshape(1, num_boxes * crop_h, crop_w, 2),
      mode="bilinear",
      padding_mode="zeros",
      align_corners=True,
  )  # (1, C, B * crop_h, crop_w)
  crops = sampled[0].reshape(channels, num_boxes, crop_h, crop_w).permute(1, 0, 2, 3)

  pooled = nn.functional.max_pool2d(crops, kernel_size=2, stride=2)  # (B, C, ph, pw)
  return pooled.permute(0, 2, 3, 1)

In [None]:
#def run_rcnn(pooled, num_classes):
  #pooled: Pooled feature map, with shape `(num_proposals,pool_size, pool_size, feature_map_channels)`.
  #Returns: Tuple of Tensors (`(W * H * proposals, 4)`, `(pool_size ^ 2 * proposals, num_classes)`)

  #1.Run pooled through the tail of ResNet + global average pooling
  #2. Run through fully-connected + softmax