In [1]:
import random
import torch

# Fixed seeds so "Restart & Run All" reproduces the same sizes and boxes.
SEED = 0
random.seed(SEED)
torch.manual_seed(SEED)

# Random problem size: batch n, channels C, feature map H x W,
# pooled output oH x oW, and L boxes per image.
n = random.randint(1, 3)
C = random.randint(10, 20)
H = random.randint(5, 10)
W = random.randint(5, 10)
oH = random.randint(2, 4)
oW = random.randint(2, 4)
L = random.randint(2, 6)
# NOTE: `input` shadows the builtin; the name is kept because later cells use it.
input = torch.rand(n, C, H, W)
boxes = [torch.zeros(L, 4) for _ in range(n)]
for i in range(n):
    boxes[i][:, 0] = torch.rand(L) * (H - oH)       # y1 (top row)
    boxes[i][:, 1] = torch.rand(L) * (W - oW)       # x1 (left column)
    boxes[i][:, 2] = oH + torch.rand(L) * (H - oH)  # box height, becomes y2 below
    boxes[i][:, 3] = oW + torch.rand(L) * (W - oW)  # box width, becomes x2 below

    # Convert (y1, x1, h, w) -> (y1, x1, y2, x2) and keep the far corner inside the map.
    boxes[i][:, 2:] += boxes[i][:, :2]
    boxes[i][:, 2] = torch.clamp(boxes[i][:, 2], max=H - 1)
    boxes[i][:, 3] = torch.clamp(boxes[i][:, 3], max=W - 1)
output_size = (oH, oW)

In [2]:
input.size()  # torch.Size of the batch: (n, C, H, W)

torch.Size([1, 14, 8, 8])

In [3]:
import torch.nn as nn
import torch.nn.functional as F

In [4]:
# Read the problem dimensions back from the actual tensors.
n, C, H, W = input.shape
oH, oW = output_size
L, _ = boxes[0].shape
# A dense (n, L, ...) output only works if every image has the same number of boxes.
assert all(b.shape == boxes[0].shape for b in boxes), "images have different box counts"

# Pre-allocate the pooled output (n, L, C, oH, oW). Zeros rather than randn so any
# bin the pooling loop fails to write shows up as 0 instead of plausible noise.
out = torch.zeros((n, L, C, oH, oW), dtype=input.dtype, device=input.device)
out.shape

torch.Size([1, 2, 14, 2, 4])

In [5]:
# One (L, 4) tensor of (y1, x1, y2, x2) corners per image in the batch.
for img_boxes in boxes:
    print(img_boxes)

tensor([[3.0269, 1.1751, 5.3938, 5.6796],
        [3.5404, 1.2701, 7.0000, 7.0000]])


In [8]:
def _ceil_div(a: int, b: int) -> int:
    """Return ceil(a / b) for b > 0, computed exactly with integer floor division."""
    return -(-a // b)


def roi_max_pool_single(feature_map: torch.Tensor, box: torch.Tensor, output_size) -> torch.Tensor:
    """Max-pool one region of interest of a single feature map into a fixed grid.

    The quantisation follows ``torchvision.ops.roi_pool`` with ``spatial_scale=1``.
    Corners are rounded to integer pixels, and the box is split into bins using
    floor/ceil, so neighbouring bins may overlap. Unlike torchvision, boxes here
    are ordered (y1, x1, y2, x2), not (x1, y1, x2, y2).

    Args:
        feature_map: tensor of shape (C, H, W).
        box: four values (y1, x1, y2, x2), inclusive corner coordinates in pixels.
        output_size: (oH, oW), the number of bins along rows and columns.

    Returns:
        Tensor of shape (C, oH, oW) holding the per-channel maximum of each bin.
        Bins that fall entirely outside the feature map are 0.
    """
    C, H, W = feature_map.shape
    oH, oW = output_size
    pooled = feature_map.new_zeros((C, oH, oW))

    # Snap corners to the nearest pixel (torch.round rounds exact halves to even).
    y1, x1, y2, x2 = (int(v) for v in torch.round(box))
    # Inclusive extent, at least one pixel so a degenerate box still yields bins.
    roi_h = max(y2 - y1 + 1, 1)
    roi_w = max(x2 - x1 + 1, 1)

    for gy in range(oH):
        # Rows (y axis) covered by this bin, clamped to [0, H].
        y_start = min(max(y1 + (gy * roi_h) // oH, 0), H)
        y_end = min(max(y1 + _ceil_div((gy + 1) * roi_h, oH), 0), H)
        for gx in range(oW):
            # Columns (x axis) covered by this bin, clamped to [0, W].
            x_start = min(max(x1 + (gx * roi_w) // oW, 0), W)
            x_end = min(max(x1 + _ceil_div((gx + 1) * roi_w, oW), 0), W)

            if y_end <= y_start or x_end <= x_start:
                continue  # empty bin (box outside the map): keep the 0 from new_zeros
            region = feature_map[:, y_start:y_end, x_start:x_end]
            pooled[:, gy, gx] = region.amax(dim=(1, 2))

    return pooled


# Pool every box of every image: boxes[img_idx][box_idx] -> out[img_idx, box_idx].
for img_idx in range(n):
    for box_idx, box in enumerate(boxes[img_idx]):
        out[img_idx, box_idx] = roi_max_pool_single(input[img_idx], box, output_size)

out.shape

torch.Size([1, 2, 14, 2, 4])