In [1]:
# Import torch before anything from mmcv_ops: loading torch first makes its shared
# libraries available and avoids
# 'ImportError: libc10.so: cannot open shared object file: No such file or directory'
import torch
from mmcv_ops.utils import load_ext
from mmcv_ops.roi_align import RoIAlign

# Load the compiled `_ext` module and expose the four RoIAlign kernels
# (forward/backward for CPU and CUDA) that the walkthrough below refers to.
ext_module = load_ext(
    '_ext',
    [
        'roi_align_forward_cpu',
        'roi_align_backward_cpu',
        'roi_align_forward_cuda',
        'roi_align_backward_cuda',
    ],
)

In [2]:
# Model creation (CPU)
# Seed the RNG so the random features (and therefore every printed output below)
# are reproducible across "Restart & Run All".
torch.manual_seed(0)

# RoIAlign((pooled_h, pooled_w), spatial_scale, sampling_ratio):
# 3x3 output bins, RoI coordinates are scaled by 0.5 into feature-map coordinates,
# sampling_ratio=0 lets the kernel derive the number of sample points per bin from the RoI size.
roi_align_layer_cpu = RoIAlign((3, 3), 0.5, 0)
# random features of shape (batch, channels, height, width)
feats_cpu = torch.randn((1, 2, 10, 10), dtype=torch.float32, requires_grad=True)
# one RoI per row, laid out as (batch_idx, x1, y1, x2, y2); batch_idx 0 selects the first image.
# The RoIs are not differentiated through (RoIAlignFunction.backward returns None for them),
# so they do not need requires_grad.
rois_cpu = torch.tensor([[0, 2, 2, 10, 10]], dtype=torch.float32)

In [3]:
# Show the randomly generated CPU feature map (header, blank line, then the tensor).
print('feats:\n', feats_cpu, sep='\n')

feats:

tensor([[[[ 0.1257, -1.9619, -0.1382,  1.3187,  0.7176, -0.3532, -0.7202,
            2.2736, -0.0474, -1.0741],
          [-0.6901, -1.2496, -1.0259, -0.1270, -0.6580,  0.6004,  0.4073,
           -1.3295,  1.0427,  1.6906],
          [ 1.8739,  0.2285, -0.6754,  1.5749, -0.6862, -0.3564,  0.4297,
            3.6265, -1.4205,  1.5400],
          [ 0.5917,  0.0906,  0.0173,  1.8198, -0.4953, -1.2511,  0.2089,
            0.7291, -0.7137, -0.4959],
          [ 0.4790, -0.9269,  1.4868, -1.5470,  0.7609, -1.0193,  0.3684,
           -1.2842,  1.3593, -0.5944],
          [-0.4893,  1.0714,  2.5024, -1.5468, -0.0141,  0.3112, -1.8697,
            0.4981,  0.0472, -0.7115],
          [-1.1766, -0.8875, -0.4125,  0.3101, -1.6586, -0.2818, -0.2802,
            0.1481, -0.8391,  0.3456],
          [-0.4128,  2.3628, -0.3821, -0.1831,  0.6396, -1.0878, -0.8211,
           -0.7375,  1.0809, -0.0933],
          [-0.6507, -0.0699, -0.2249,  0.8217, -0.3815, -0.8674, -0.9957,
            0.

### 1. `forward` Function in RoIAlign
The following is the forward function in RoIAlign Module.
```python
    def forward(self, input: torch.Tensor, rois: torch.Tensor) -> torch.Tensor:
        """
        Args:
            input: NCHW images
            rois: Bx5 boxes. First column is the index into N.\
                The other 4 columns are xyxy.
        """
        return roi_align(input, rois, self.output_size, self.spatial_scale,
                         self.sampling_ratio, self.pool_mode, self.aligned)
```
It simply delegates to the `roi_align` function.

### 2. `forward` Function in RoIAlignFunction
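`roi_align` is the autograd entry point of `RoIAlignFunction` (mmcv defines it as `RoIAlignFunction.apply`), so calling it runs the `forward` method of `RoIAlignFunction` shown below.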

```python
    def forward(ctx: Any,
                input: torch.Tensor,
                rois: torch.Tensor,
                output_size: int,
                spatial_scale: float = 1.0,
                sampling_ratio: int = 0,
                pool_mode: str = 'avg',
                aligned: bool = True) -> torch.Tensor:
        device = input.device
        # In our settings, ctx.output_size = (3, 3)
        ctx.output_size = _pair(output_size)
        # spatial_scale = 0.5
        ctx.spatial_scale = spatial_scale
        # If sampling_ratio > 0, we sample sampling_ratio * sampling_ratio points in each bin;
        # otherwise the number of points per bin is derived from the RoI size (see ROIAlignForward)
        ctx.sampling_ratio = sampling_ratio
        assert pool_mode in ('max', 'avg')
        # pool_mode is 'avg' here, so ctx.pool_mode = 1
        ctx.pool_mode = 0 if pool_mode == 'max' else 1
        # aligned = True
        ctx.aligned = aligned
        # input_shape = (1, 2, 10, 10)
        ctx.input_shape = input.size()

        assert rois.size(1) == 5, 'RoI must be (idx, x1, y1, x2, y2)!'
        # output_shape = (1, 2, 3, 3)
        output_shape = (rois.size(0), input.size(1), ctx.output_size[0],
                        ctx.output_size[1])
        # output
        output = input.new_zeros(output_shape)
        if ctx.pool_mode == 0:
            argmax_y = input.new_zeros(output_shape)
            argmax_x = input.new_zeros(output_shape)
        else:
            argmax_y = input.new_zeros(0)
            argmax_x = input.new_zeros(0)
        if device == 'cpu':
            roi_align_forward = ext_module.roi_align_forward_cpu
        else:
            roi_align_forward = ext_module.roi_align_forward_cuda
        roi_align_forward(input,
                          rois,
                          output,
                          argmax_y,
                          argmax_x,
                          aligned_height=ctx.output_size[0],
                          aligned_width=ctx.output_size[1],
                          spatial_scale=ctx.spatial_scale,
                          sampling_ratio=ctx.sampling_ratio,
                          pool_mode=ctx.pool_mode,
                          aligned=ctx.aligned)

        ctx.save_for_backward(rois, argmax_y, argmax_x)
        return output
```

### 3. `roi_align_forward_cpu`
Let's look at `roi_align_forward_cpu` first.

```cpp
void roi_align_forward_cpu(Tensor input, Tensor rois, Tensor output,
                                Tensor argmax_y, Tensor argmax_x,
                                int aligned_height, int aligned_width,
                                float spatial_scale, int sampling_ratio,
                                int pool_mode, bool aligned) {
  // output_size = 1 * 2 * 3 * 3 = 18
  int output_size = output.numel();
  // channels = 2
  int channels = input.size(1);
  // height = 10
  int height = input.size(2);
  // width = 10
  int width = input.size(3);

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      input.scalar_type(), "ROIAlign_forward", [&] {
        ROIAlignForward<scalar_t>(
            output_size, input.data_ptr<scalar_t>(), rois.data_ptr<scalar_t>(),
            output.data_ptr<scalar_t>(), argmax_y.data_ptr<scalar_t>(),
            argmax_x.data_ptr<scalar_t>(), aligned_height, aligned_width,
            static_cast<scalar_t>(spatial_scale), sampling_ratio, pool_mode,
            aligned, channels, height, width);
      });
}
```

### 4. ROIAlignForward
In `roi_align_forward_cpu`, it calls `ROIAlignForward` function.

```cpp
template <typename T>
void ROIAlignForward(const int nthreads, const T* input, const T* rois,
                     T* output, T* argmax_y, T* argmax_x,
                     const int pooled_height, const int pooled_width,
                     const T spatial_scale, const int sampling_ratio,
                     const int pool_mode,  // 0 - max pool, 1 - avg pool
                     const bool aligned, const int channels, const int height,
                     const int width) {
  // nthreads are the output size, which is equal to 
  // nrois * channels * pooled_width * pooled_height
  int n_rois = nthreads / channels / pooled_width / pooled_height;
  // (n, c, ph, pw) is an element in the pooled output
  // can be parallelized using omp
  // #pragma omp parallel for num_threads(32)
  // we iterate each roi
  for (int n = 0; n < n_rois; n++) {
    // Note that in C++, the Tensor type is organised by 1D array
    // Therefore, we need to calculate the start index of output for n-th roi
    // output is (n_rois, channels, pooled_height, pooled_width)
    // For n-th roi, the start index should be `n * channels * pooled_height * pooled_width`
    int index_n = n * channels * pooled_width * pooled_height;
    // offset of n-th roi
    const T* offset_rois = rois + n * 5;
    // get batch idx
    int roi_batch_ind = offset_rois[0];

    // Do not use rounding; this implementation detail is critical
    T offset = aligned ? (T)0.5 : (T)0.0;
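    // offset: with aligned=True all coordinates are shifted by half a pixel (offset = 0.5),
    // with aligned=False no shift is applied (offset = 0).
    // After convolution the feature map is a downsampled version of the image, and
    // spatial_scale (0.5 here) converts image coordinates into feature-map coordinates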
    // roi_start_w = 0.5, roi_start_h = 0.5, roi_end_w = 4.5, roi_end_h = 4.5
    T roi_start_w = offset_rois[1] * spatial_scale - offset;
    T roi_start_h = offset_rois[2] * spatial_scale - offset;
    T roi_end_w = offset_rois[3] * spatial_scale - offset;
    T roi_end_h = offset_rois[4] * spatial_scale - offset;
    // the width and height of roi in feature map
    // roi_width = 4, roi_height = 4
    T roi_width = roi_end_w - roi_start_w;
    T roi_height = roi_end_h - roi_start_h;
    if (aligned) {
      AT_ASSERTM(roi_width >= 0 && roi_height >= 0,
                 "ROIs in ROIAlign cannot have non-negative size!");
    } else {  // for backward-compatibility only
      roi_width = std::max(roi_width, (T)1.);
      roi_height = std::max(roi_height, (T)1.);
    }
    // Notice that T is float
    // Each output unit length corresponds to the length of the roi
    // In our case, bin_size_h = 1.33, bin_size_w = 1.33
    T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
    T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);

    // We use roi_bin_grid to sample the grid and mimic integral
    // roi_bin_grid_h = 2
    // roi_bin_grid_w = 2
    int roi_bin_grid_h = (sampling_ratio > 0)
                             ? sampling_ratio
                             : ceilf(roi_height / pooled_height);  // e.g., = 2
    int roi_bin_grid_w =
        (sampling_ratio > 0) ? sampling_ratio : ceilf(roi_width / pooled_width);

    // When the grid is empty, output zeros == 0/1, instead of NaN.
    // count = 4
    const T count = std::max(roi_bin_grid_h * roi_bin_grid_w, 1);  // e.g. = 4

    // we want to precalculate indices and weights shared by all channels,
    // this is the key point of optimization
    std::vector<PreCalc<T>> pre_calc(roi_bin_grid_h * roi_bin_grid_w *
                                     pooled_width * pooled_height);
    pre_calc_for_bilinear_interpolate(
        height, width, pooled_height, pooled_width, roi_bin_grid_h,
        roi_bin_grid_w, roi_start_h, roi_start_w, bin_size_h, bin_size_w,
        roi_bin_grid_h, roi_bin_grid_w, pre_calc);
    // before diving into more details, you should read section 5 & 6
    // iteration for each channel
    for (int c = 0; c < channels; c++) {
      // index_n is the start index of n-th roi
      // index_n_c is the start index of c-th channel of n-th roi
      int index_n_c = index_n + c * pooled_width * pooled_height;
      // ptr of feats
      const T* offset_input =
          input + (roi_batch_ind * channels + c) * height * width;
      int pre_calc_index = 0;
      // iteration for pooled_height
      for (int ph = 0; ph < pooled_height; ph++) {
        // iteration for pooled_width
        for (int pw = 0; pw < pooled_width; pw++) {
          // index of pooled pixel
          int index = index_n_c + ph * pooled_width + pw;

          T output_val = 0.;
          T maxval = -10000;
          T maxidx_y = -1.f, maxidx_x = -1.f;
          for (int iy = 0; iy < roi_bin_grid_h; iy++) {
            // y coordinate
            const T y = roi_start_h + ph * bin_size_h +
                        static_cast<T>(iy + .5f) * bin_size_h /
                            static_cast<T>(roi_bin_grid_h);
            for (int ix = 0; ix < roi_bin_grid_w; ix++) {
              // x coordinate
              const T x = roi_start_w + pw * bin_size_w +
                          static_cast<T>(ix + .5f) * bin_size_w /
                              static_cast<T>(roi_bin_grid_w);
              PreCalc<T> pc = pre_calc[pre_calc_index];
              T val = pc.w1 * offset_input[pc.pos1] +
                      pc.w2 * offset_input[pc.pos2] +
                      pc.w3 * offset_input[pc.pos3] +
                      pc.w4 * offset_input[pc.pos4];
              if (val > maxval) {
                maxval = val;
                maxidx_y = y;
                maxidx_x = x;
              }
              output_val += val;
              pre_calc_index += 1;
            }
          }
          if (pool_mode == 0) {
            // We do max pooling inside a bin
            output[index] = maxval;
            argmax_y[index] = maxidx_y;
            argmax_x[index] = maxidx_x;
          } else if (pool_mode == 1) {
            // We do average (integral) pooling inside a bin
            output[index] = output_val / count;
          }  // if
        }    // for pw
      }      // for ph
    }        // for c
  }          // for n
}
```

### 5. PreCalc
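One important data structure used by `ROIAlignForward` (called from `roi_align_forward_cpu`) is `PreCalc`. It caches the four neighbouring pixel offsets and the bilinear weights of one sampling point.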
```cpp
template <typename T>
struct PreCalc {
  // positions and weights
  int pos1;
  int pos2;
  int pos3;
  int pos4;
  T w1;
  T w2;
  T w3;
  T w4;
};
```

### 6. `pre_calc_for_bilinear_interpolate` Function

```cpp
template <typename T>
void pre_calc_for_bilinear_interpolate(
    const int height, const int width, const int pooled_height,
    const int pooled_width, const int iy_upper, const int ix_upper,
    T roi_start_h, T roi_start_w, T bin_size_h, T bin_size_w,
    int roi_bin_grid_h, int roi_bin_grid_w, std::vector<PreCalc<T>>& pre_calc) {
  // height = 10, width = 10, pooled_height = 3, pooled_width = 3
  // iy_upper = 2, ix_upper = 2, roi_start_h = 0.5, roi_start_w = 0.5
  // bin_size_h = 1.33, bin_size_w = 1.33, roi_bin_grid_h = 2, roi_bin_grid_w = 2

  // index counter
  int pre_calc_index = 0;
  // iteration for pooled_height
  for (int ph = 0; ph < pooled_height; ph++) {
    // iteration for pooled_width
    for (int pw = 0; pw < pooled_width; pw++) {
      // iteration for sampling points
      // we sample iy_upper points along the y axis of each bin
      for (int iy = 0; iy < iy_upper; iy++) {
        // yy is the y coordinate of sampling point 
        // roi_start_h is y coordinate of top left corner of roi
        // bin_size_h is gap in roi for per pooling pixel
        // roi_start_h + ph * bin_size_h is the y coordinate of the top edge of the ph-th bin
        // static_cast<T>(iy + .5f) shifts the sample to the center of its sub-cell
        // static_cast<T>(iy + .5f) / static_cast<T>(roi_bin_grid_h) is the fractional position of the iy-th sample inside the bin
        // static_cast<T>(iy + .5f) * bin_size_h / static_cast<T>(roi_bin_grid_h) is the offset of the iy-th sample from the top edge of the ph-th bin
        const T yy = roi_start_h + ph * bin_size_h +
                     static_cast<T>(iy + .5f) * bin_size_h /
                         static_cast<T>(roi_bin_grid_h);  // e.g., 0.5, 1.5
        // iteration for sampling points
        // we sample ix_upper points along the x axis of each bin
        for (int ix = 0; ix < ix_upper; ix++) {
          // xx is the x coordinate of sampling point
          // the same with yy
          const T xx = roi_start_w + pw * bin_size_w +
                       static_cast<T>(ix + .5f) * bin_size_w /
                           static_cast<T>(roi_bin_grid_w);

          T x = xx;
          T y = yy;
          // sampling points too far outside the feature map (y < -1, y > height, x < -1, x > width) get zero weights
          if (y < -1.0 || y > height || x < -1.0 || x > width) {
            // empty
            PreCalc<T> pc;
            pc.pos1 = 0;
            pc.pos2 = 0;
            pc.pos3 = 0;
            pc.pos4 = 0;
            pc.w1 = 0;
            pc.w2 = 0;
            pc.w3 = 0;
            pc.w4 = 0;
            pre_calc[pre_calc_index] = pc;
            pre_calc_index += 1;
            continue;
          }

          if (y <= 0) {
            y = 0;
          }
          if (x <= 0) {
            x = 0;
          }
          // find the four neighbouring pixels used for bilinear interpolation
          int y_low = (int)y;
          int x_low = (int)x;
          int y_high;
          int x_high;

          if (y_low >= height - 1) {
            y_high = y_low = height - 1;
            y = (T)y_low;
          } else {
            y_high = y_low + 1;
          }

          if (x_low >= width - 1) {
            x_high = x_low = width - 1;
            x = (T)x_low;
          } else {
            x_high = x_low + 1;
          }
          // y_low-y distance
          T ly = y - y_low;
          // x_low-x distance
          T lx = x - x_low;
          // y_high-y distance & x_high-x distance
          T hy = 1. - ly, hx = 1. - lx;
          // the weights of bilinear interpolation
          // Remember in pixels, the denominator is 1
          // w1, w2, w3, w4 should be (x2-x)(y2-y), (y2-y)(x-x1), (y-y1)(x2-x), (y-y1)(x-x1), respectively
          T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;

          // save weights and indices
          PreCalc<T> pc;
          pc.pos1 = y_low * width + x_low;
          pc.pos2 = y_low * width + x_high;
          pc.pos3 = y_high * width + x_low;
          pc.pos4 = y_high * width + x_high;
          pc.w1 = w1;
          pc.w2 = w2;
          pc.w3 = w3;
          pc.w4 = w4;
          pre_calc[pre_calc_index] = pc;

          pre_calc_index += 1;
        }
      }
    }
  }
}
```

In [4]:
# Run RoIAlign on the CPU features and show the pooled (n_rois, C, 3, 3) result.
align_roi_cpu = roi_align_layer_cpu(feats_cpu, rois_cpu)
print('align roi cpu:\n', align_roi_cpu, sep='\n')

align roi cpu:

tensor([[[[-8.4081e-01, -2.2266e-01, -2.3874e-01],
          [ 1.2684e-01,  6.8415e-01, -3.6468e-02],
          [-7.8885e-04,  2.4938e-01, -2.0245e-03]],

         [[-5.0896e-01,  6.5235e-01,  1.5244e-01],
          [-9.3811e-01, -1.1204e+00,  5.1612e-01],
          [-2.4151e-01, -1.0674e+00, -5.7716e-02]]]],
       grad_fn=<RoIAlignFunctionBackward>)


In [5]:
# Model creation (GPU)
# Seed the RNG so the CUDA features (and the printed outputs below) are reproducible
# across "Restart & Run All"; torch.manual_seed also seeds the CUDA generators.
torch.manual_seed(0)

# Same configuration as the CPU layer: 3x3 output bins, spatial_scale 0.5,
# sampling_ratio 0 (number of sample points per bin derived from the RoI size).
roi_align_layer_cuda = RoIAlign((3, 3), 0.5, 0).cuda()
# random features of shape (batch, channels, height, width) on the first GPU
feats_cuda = torch.randn((1, 2, 10, 10), dtype=torch.float32, requires_grad=True, device='cuda:0')
# one RoI per row, laid out as (batch_idx, x1, y1, x2, y2); batch_idx 0 selects the first image.
# No requires_grad: RoIAlignFunction.backward returns None for the RoIs.
rois_cuda = torch.tensor([[0, 2, 2, 10, 10]], dtype=torch.float32, device='cuda:0')

In [6]:
# Show the randomly generated CUDA feature map (header, blank line, then the tensor).
print('feats:\n', feats_cuda, sep='\n')

feats:

tensor([[[[-0.0386, -0.6785,  0.7538, -1.1464,  0.3887,  0.4440, -0.2841,
           -0.9236, -0.0969,  0.6881],
          [ 0.3273, -0.3086,  1.2009, -0.0527,  0.2667, -0.0526,  1.8422,
            0.2237, -0.6699, -1.1457],
          [-0.7323,  0.3871, -1.2246,  0.5657, -0.5494,  0.6029,  1.1535,
           -0.4653, -0.6445, -0.2874],
          [-1.0605, -2.9157, -0.3628, -0.3066, -1.9257,  1.0506,  1.5880,
           -0.1601,  0.5609, -0.5376],
          [-1.6469, -1.0870, -1.1721, -0.8651, -0.8287, -0.2735, -0.3042,
            0.3658, -0.7599,  0.4345],
          [ 0.0796, -0.1492,  0.0307,  0.3946, -0.5278,  0.7277,  0.7529,
            0.2595,  1.0894, -1.7413],
          [-0.3076,  0.2513, -0.7556, -0.9918,  0.9225,  1.3222, -1.6966,
           -1.6236, -0.3125, -1.1513],
          [-1.2115,  1.6396,  0.3104,  0.9622, -0.7934, -0.1700,  0.1579,
           -1.1424, -1.6679,  1.1257],
          [ 0.4306, -1.0570, -0.0604, -0.6088,  1.6677,  0.2019, -0.0521,
           -0.

In [7]:
# Run RoIAlign on the CUDA features and show the pooled result (it stays on cuda:0).
align_roi_cuda = roi_align_layer_cuda(feats_cuda, rois_cuda)
print('align roi cuda:\n', align_roi_cuda, sep='\n')

align roi cuda:

tensor([[[[ 0.0317,  0.2840,  0.0640],
          [-1.1160, -0.3321, -0.7238],
          [-1.3078, -0.7450, -0.8620]],

         [[ 0.4854, -0.0372, -0.1647],
          [ 0.7459, -0.6621, -0.4197],
          [ 0.8964,  0.5971,  0.2348]]]], device='cuda:0',
       grad_fn=<RoIAlignFunctionBackward>)


### 7. roi_align_forward_cuda

Let's see `roi_align_forward_cuda` then.

```cpp
void roi_align_forward_cuda(Tensor input, Tensor rois, Tensor output,
                                       Tensor argmax_y, Tensor argmax_x,
                                       int aligned_height, int aligned_width,
                                       float spatial_scale, int sampling_ratio,
                                       int pool_mode, bool aligned) {
  int output_size = output.numel();
  int channels = input.size(1);
  int height = input.size(2);
  int width = input.size(3);

  at::cuda::CUDAGuard device_guard(input.device());
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      input.scalar_type(), "roi_align_forward_cuda_kernel", [&] {
        roi_align_forward_cuda_kernel<scalar_t>
            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
                output_size, input.data_ptr<scalar_t>(),
                rois.data_ptr<scalar_t>(), output.data_ptr<scalar_t>(),
                argmax_y.data_ptr<scalar_t>(), argmax_x.data_ptr<scalar_t>(),
                aligned_height, aligned_width,
                static_cast<scalar_t>(spatial_scale), sampling_ratio, pool_mode,
                aligned, channels, height, width);
      });

  AT_CUDA_CHECK(cudaGetLastError());
}
```

### 8. roi_align_forward_cuda_kernel
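`roi_align_forward_cuda_kernel` performs almost the same computation as the CPU `ROIAlignForward` (Section 4), but it calls `bilinear_interpolate` directly instead of using a `PreCalc` cache.
The parallelism is organised differently. `CUDA_1D_KERNEL_LOOP(index, nthreads)` iterates over the flat output index. Each iteration decodes `index` into `(n, c, ph, pw)` and computes that single output value, so the work is spread over all `n_rois * channels * pooled_height * pooled_width` output elements, not over RoIs.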

```cpp
template <typename T>
__global__ void roi_align_forward_cuda_kernel(
    const int nthreads, const T* input, const T* rois, T* output, T* argmax_y,
    T* argmax_x, const int pooled_height, const int pooled_width,
    const T spatial_scale, const int sampling_ratio,
    const int pool_mode,  // 0 - max pool, 1 - avg pool
    const bool aligned, const int channels, const int height, const int width) {
  CUDA_1D_KERNEL_LOOP(index, nthreads) {
    // (n, c, ph, pw) is an element in the pooled output
    int pw = index % pooled_width;
    int ph = (index / pooled_width) % pooled_height;
    int c = (index / pooled_width / pooled_height) % channels;
    int n = index / pooled_width / pooled_height / channels;

    const T* offset_rois = rois + n * 5;
    int roi_batch_ind = offset_rois[0];

    // Do not using rounding; this implementation detail is critical
    T offset = aligned ? (T)0.5 : (T)0.0;
    T roi_start_w = offset_rois[1] * spatial_scale - offset;
    T roi_start_h = offset_rois[2] * spatial_scale - offset;
    T roi_end_w = offset_rois[3] * spatial_scale - offset;
    T roi_end_h = offset_rois[4] * spatial_scale - offset;

    T roi_width = roi_end_w - roi_start_w;
    T roi_height = roi_end_h - roi_start_h;
    if (!aligned) {  // for backward-compatibility only
      roi_width = max(roi_width, (T)1.);
      roi_height = max(roi_height, (T)1.);
    }

    T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
    T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);

    const T* offset_input =
        input + (roi_batch_ind * channels + c) * height * width;

    // We use roi_bin_grid to sample the grid and mimic integral
    int roi_bin_grid_h =
        (sampling_ratio > 0)
            ? sampling_ratio
            : static_cast<int>(ceilf(roi_height / pooled_height));
    int roi_bin_grid_w =
        (sampling_ratio > 0)
            ? sampling_ratio
            : static_cast<int>(ceilf(roi_width / pooled_width));

    if (pool_mode == 0) {
      // We do max pooling inside a bin
      T maxval = -FLT_MAX;
      T maxidx_y = -1.f, maxidx_x = -1.f;
      for (int iy = 0; iy < roi_bin_grid_h; iy++) {
        const T y = roi_start_h + ph * bin_size_h +
                    static_cast<T>(iy + .5f) * bin_size_h /
                        static_cast<T>(roi_bin_grid_h);
        for (int ix = 0; ix < roi_bin_grid_w; ix++) {
          const T x = roi_start_w + pw * bin_size_w +
                      static_cast<T>(ix + .5f) * bin_size_w /
                          static_cast<T>(roi_bin_grid_w);
          T val =
              bilinear_interpolate(offset_input, height, width, y, x, index);
          if (val > maxval) {
            maxval = val;
            maxidx_y = y;
            maxidx_x = x;
          }
        }
      }
      output[index] = maxval;
      argmax_y[index] = maxidx_y;
      argmax_x[index] = maxidx_x;
    } else if (pool_mode == 1) {
      // We do average pooling inside a bin
      const T count = max(roi_bin_grid_h * roi_bin_grid_w, 1);
      T output_val = 0.;
      for (int iy = 0; iy < roi_bin_grid_h; iy++) {
        const T y = roi_start_h + ph * bin_size_h +
                    static_cast<T>(iy + .5f) * bin_size_h /
                        static_cast<T>(roi_bin_grid_h);
        for (int ix = 0; ix < roi_bin_grid_w; ix++) {
          const T x = roi_start_w + pw * bin_size_w +
                      static_cast<T>(ix + .5f) * bin_size_w /
                          static_cast<T>(roi_bin_grid_w);
          T val =
              bilinear_interpolate(offset_input, height, width, y, x, index);
          output_val += val;
        }
      }
      output[index] = output_val / count;
    }
  }
}
```

### 9. roi_align_backward_cpu (preview; the same function is revisited in Section 11)

```cpp
void roi_align_backward_cpu(Tensor grad_output, Tensor rois,
                                 Tensor argmax_y, Tensor argmax_x,
                                 Tensor grad_input, int aligned_height,
                                 int aligned_width, float spatial_scale,
                                 int sampling_ratio, int pool_mode,
                                 bool aligned) {
  int output_size = grad_output.numel();
  int channels = grad_input.size(1);
  int height = grad_input.size(2);
  int width = grad_input.size(3);

  // get stride values to ensure indexing into gradients is correct.
  int n_stride = grad_output.stride(0);
  int c_stride = grad_output.stride(1);
  int h_stride = grad_output.stride(2);
  int w_stride = grad_output.stride(3);

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      grad_output.scalar_type(), "ROIAlign_backward", [&] {
        ROIAlignBackward<scalar_t>(
            output_size, grad_output.data_ptr<scalar_t>(),
            rois.data_ptr<scalar_t>(), argmax_y.data_ptr<scalar_t>(),
            argmax_x.data_ptr<scalar_t>(), grad_input.data_ptr<scalar_t>(),
            aligned_height, aligned_width, static_cast<scalar_t>(spatial_scale),
            sampling_ratio, pool_mode, aligned, channels, height, width,
            n_stride, c_stride, h_stride, w_stride);
      });
}
```

### 10. `backward` Function in `RoIAlignFunction`
We also need to understand the backward pass. Note that `grad_output` has shape (n_rois, C, pooled_height, pooled_width); in our setting it is (1, 2, 3, 3).

```python
    def backward(ctx: Any, grad_output: torch.Tensor) -> tuple:
        rois, argmax_y, argmax_x = ctx.saved_tensors
        device = rois.device
        grad_input = grad_output.new_zeros(ctx.input_shape)
        # complex head architecture may cause grad_output uncontiguous.
        grad_output = grad_output.contiguous()
        if device == 'cpu':
            roi_align_backward = ext_module.roi_align_backward_cpu
        else:
            roi_align_backward = ext_module.roi_align_backward_cuda
        roi_align_backward(grad_output,
                           rois,
                           argmax_y,
                           argmax_x,
                           grad_input,
                           aligned_height=ctx.output_size[0],
                           aligned_width=ctx.output_size[1],
                           spatial_scale=ctx.spatial_scale,
                           sampling_ratio=ctx.sampling_ratio,
                           pool_mode=ctx.pool_mode,
                           aligned=ctx.aligned)
        return grad_input, None, None, None, None, None, None
```

### 11. roi_align_backward_cpu
This is the entrance to back propagation.

```cpp
void roi_align_backward_cpu(Tensor grad_output, Tensor rois,
                                 Tensor argmax_y, Tensor argmax_x,
                                 Tensor grad_input, int aligned_height,
                                 int aligned_width, float spatial_scale,
                                 int sampling_ratio, int pool_mode,
                                 bool aligned) {
  int output_size = grad_output.numel();
  int channels = grad_input.size(1);
  int height = grad_input.size(2);
  int width = grad_input.size(3);

  // get stride values to ensure indexing into gradients is correct.
  int n_stride = grad_output.stride(0);
  int c_stride = grad_output.stride(1);
  int h_stride = grad_output.stride(2);
  int w_stride = grad_output.stride(3);

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      grad_output.scalar_type(), "ROIAlign_backward", [&] {
        ROIAlignBackward<scalar_t>(
            output_size, grad_output.data_ptr<scalar_t>(),
            rois.data_ptr<scalar_t>(), argmax_y.data_ptr<scalar_t>(),
            argmax_x.data_ptr<scalar_t>(), grad_input.data_ptr<scalar_t>(),
            aligned_height, aligned_width, static_cast<scalar_t>(spatial_scale),
            sampling_ratio, pool_mode, aligned, channels, height, width,
            n_stride, c_stride, h_stride, w_stride);
      });
}
```

### 12. ROIAlignBackward

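The backward formula follows directly from the forward pass. Every output value is a weighted sum of four neighbouring input pixels, so its upstream gradient is scattered back to those same pixels with the same bilinear weights. In average mode the weights are divided by `count`; in max mode the gradient goes only to the recorded `argmax` sampling point.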

```cpp
template <typename T>
void ROIAlignBackward(const int nthreads, const T* grad_output, const T* rois,
                      const T* argmax_y, const T* argmax_x, T* grad_input,
                      const int pooled_height, const int pooled_width,
                      const T spatial_scale, const int sampling_ratio,
                      const int pool_mode,  // 0 - max pool, 1 - avg pool
                      const bool aligned, const int channels, const int height,
                      const int width, const int n_stride, const int c_stride,
                      const int h_stride, const int w_stride) {
  // nthreads = grad_output.numel() = 1 * 2 * 3 * 3 = 18
  // For now, grad_out is a grad matrix whose dimension is (1, 2, 3, 3)
  // the dimension of rois is (1, 5)
  // argmax_y, argmax_x is memorized in forward step
  // grad_input is a zero matrix whose dimension is (1, 2, 10, 10)
  // pooled_height = 3
  // pooled_width = 3
  // spatial_scale = 0.5
  // sampling_ratio = 0
  // pool_mode = 1 (avg pooling, as used in this notebook)
  // aligned = True
  // channels = 2
  // height = 10
  // width = 10
  // n_stride = 18
  // c_stride = 9
  // h_stride = 3
  // w_stride = 1
  for (int index = 0; index < nthreads; index++) {
    // (n, c, ph, pw) is an element in the pooled output
    int pw = index % pooled_width;
    int ph = (index / pooled_width) % pooled_height;
    int c = (index / pooled_width / pooled_height) % channels;
    int n = index / pooled_width / pooled_height / channels;
    // for n-th roi
    const T* offset_rois = rois + n * 5;
    // get batch idx
    int roi_batch_ind = offset_rois[0];

    // Do not use rounding; this implementation detail is critical
    T offset = aligned ? (T)0.5 : (T)0.0;
    // calculate the roi coordinate, width and height
    T roi_start_w = offset_rois[1] * spatial_scale - offset;
    T roi_start_h = offset_rois[2] * spatial_scale - offset;
    T roi_end_w = offset_rois[3] * spatial_scale - offset;
    T roi_end_h = offset_rois[4] * spatial_scale - offset;

    T roi_width = roi_end_w - roi_start_w;
    T roi_height = roi_end_h - roi_start_h;
    if (aligned) {
      AT_ASSERTM(roi_width >= 0 && roi_height >= 0,
                 "ROIs in ROIAlign do not have non-negative size!");
    } else {  // for backward-compatibility only
      roi_width = std::max(roi_width, (T)1.);
      roi_height = std::max(roi_height, (T)1.);
    }
    // Notice that T is float
    // Each output unit length corresponds to the length of the roi
    // In our case, bin_size_h = 1.33, bin_size_w = 1.33
    T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
    T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);

    // get n-th feature map's grad
    T* offset_grad_input =
        grad_input + ((roi_batch_ind * channels + c) * height * width);
    // get (n, c, ph, pw) element in grad_output
    int output_offset = n * n_stride + c * c_stride;
    const T* offset_grad_output = grad_output + output_offset;
    const T grad_output_this_bin =
        offset_grad_output[ph * h_stride + pw * w_stride];

    if (pool_mode == 0) {
      // We do max pooling inside a bin
      // get max point
      T y = argmax_y[index], x = argmax_x[index];
      if (y != -1.f) {
        T w1, w2, w3, w4;
        int x_low, x_high, y_low, y_high;
        // Before diving into more details, you should read Sec. 13 (bilinear_interpolate_gradient)
        bilinear_interpolate_gradient(height, width, y, x, w1, w2, w3, w4,
                                      x_low, x_high, y_low, y_high, index);
        // w * grad_output
        T g1 = grad_output_this_bin * w1;
        T g2 = grad_output_this_bin * w2;
        T g3 = grad_output_this_bin * w3;
        T g4 = grad_output_this_bin * w4;

        if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
          // atomic add is not needed for now since it is single threaded
          add(offset_grad_input + y_low * width + x_low, static_cast<T>(g1));
          add(offset_grad_input + y_low * width + x_high, static_cast<T>(g2));
          add(offset_grad_input + y_high * width + x_low, static_cast<T>(g3));
          add(offset_grad_input + y_high * width + x_high, static_cast<T>(g4));
        }  // if
      }    // mode
    } else if (pool_mode == 1) {
      // We do average (integral) pooling inside a bin
      // We use roi_bin_grid to sample the grid and mimic integral
      int roi_bin_grid_h =
          (sampling_ratio > 0)
              ? sampling_ratio
              : ceilf(roi_height / pooled_height);  // e.g., = 2
      int roi_bin_grid_w = (sampling_ratio > 0)
                               ? sampling_ratio
                               : ceilf(roi_width / pooled_width);

      const T count = roi_bin_grid_h * roi_bin_grid_w;  // e.g. = 4
      for (int iy = 0; iy < roi_bin_grid_h; iy++) {
        const T y = roi_start_h + ph * bin_size_h +
                    static_cast<T>(iy + .5f) * bin_size_h /
                        static_cast<T>(roi_bin_grid_h);  // e.g., 0.5, 1.5
        for (int ix = 0; ix < roi_bin_grid_w; ix++) {
          const T x = roi_start_w + pw * bin_size_w +
                      static_cast<T>(ix + .5f) * bin_size_w /
                          static_cast<T>(roi_bin_grid_w);

          T w1, w2, w3, w4;
          int x_low, x_high, y_low, y_high;

          bilinear_interpolate_gradient(height, width, y, x, w1, w2, w3, w4,
                                        x_low, x_high, y_low, y_high, index);

          T g1 = grad_output_this_bin * w1 / count;
          T g2 = grad_output_this_bin * w2 / count;
          T g3 = grad_output_this_bin * w3 / count;
          T g4 = grad_output_this_bin * w4 / count;

          if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
            // atomic add is not needed for now since it is single threaded
            add(offset_grad_input + y_low * width + x_low, static_cast<T>(g1));
            add(offset_grad_input + y_low * width + x_high, static_cast<T>(g2));
            add(offset_grad_input + y_high * width + x_low, static_cast<T>(g3));
            add(offset_grad_input + y_high * width + x_high,
                static_cast<T>(g4));
          }  // if
        }    // ix
      }      // iy
    }        // mode
  }          // for
}  // ROIAlignBackward
```


### 13. bilinear_interpolate_gradient
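This function computes the bilinear interpolation weights and the four neighbouring pixel indices needed to scatter a gradient back to the input. It mirrors the interpolation step of the forward pass.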

```cpp
template <typename T>
void bilinear_interpolate_gradient(const int height, const int width, T y, T x,
                                   T& w1, T& w2, T& w3, T& w4, int& x_low,
                                   int& x_high, int& y_low, int& y_high,
                                   const int index /* index for debug only*/) {
  // A sample point farther than one pixel outside the feature map receives no
  // gradient: all weights are zero and every corner index is marked -1.
  const bool outside = y < -1.0 || y > height || x < -1.0 || x > width;
  if (outside) {
    w1 = w2 = w3 = w4 = 0.;
    x_low = x_high = y_low = y_high = -1;
    return;
  }

  // Places one coordinate on the pixel grid. Non-positive values are snapped
  // to 0, `lo` is the truncated cell index and `hi` its successor. At the
  // far edge both collapse onto the last cell and the coordinate is pinned
  // to it, so the fractional offset along this axis becomes 0.
  auto snap_axis = [](T& coord, const int extent, int& lo, int& hi) {
    if (coord <= 0) coord = 0;
    lo = (int)coord;
    if (lo < extent - 1) {
      hi = lo + 1;
    } else {
      hi = lo = extent - 1;
      coord = (T)lo;
    }
  };
  snap_axis(y, height, y_low, y_high);
  snap_axis(x, width, x_low, x_high);

  // Fractional position of the sample inside its cell, and the complements.
  const T frac_y = y - y_low;
  const T frac_x = x - x_low;
  const T rest_y = 1. - frac_y;
  const T rest_x = 1. - frac_x;

  // The forward pass combines the four neighbours as
  //   val = w1 * in[y_low][x_low]  + w2 * in[y_low][x_high]
  //       + w3 * in[y_high][x_low] + w4 * in[y_high][x_high]
  // so these same weights are the partial derivatives d(val)/d(in[..]).
  w1 = rest_y * rest_x;  // (y_low,  x_low)
  w2 = rest_y * frac_x;  // (y_low,  x_high)
  w3 = frac_y * rest_x;  // (y_high, x_low)
  w4 = frac_y * frac_x;  // (y_high, x_high)
}
```

### 13.1 add
A plain (non-atomic) accumulate helper, `*address += val`, used by the single-threaded CPU backward pass.
```cpp
template <class T>
inline void add(T* address, const T& val) {
  // Non-atomic accumulate: read the current value, add `val`, write it back.
  // No synchronization is performed, so concurrent callers targeting the same
  // address would race.
  T& slot = *address;
  slot += val;
}
```


In [8]:
# Gradient of the sum of all pooled outputs w.r.t. the CPU input features.
# .backward() accumulates into .grad, so clear any value left over from a
# previous run of this cell; otherwise re-running would double the result.
feats_cpu.grad = None
align_roi_sum_cpu = align_roi_cpu.sum()
# retain_graph=True keeps the RoIAlign node's saved buffers alive, so this cell
# can be re-executed instead of failing with
# "Trying to backward through the graph a second time".
align_roi_sum_cpu.backward(retain_graph=True)
print(feats_cpu.grad)

tensor([[[[0.0069, 0.0556, 0.0625, 0.0625, 0.0556, 0.0069, 0.0000, 0.0000,
           0.0000, 0.0000],
          [0.0556, 0.4444, 0.5000, 0.5000, 0.4444, 0.0556, 0.0000, 0.0000,
           0.0000, 0.0000],
          [0.0625, 0.5000, 0.5625, 0.5625, 0.5000, 0.0625, 0.0000, 0.0000,
           0.0000, 0.0000],
          [0.0625, 0.5000, 0.5625, 0.5625, 0.5000, 0.0625, 0.0000, 0.0000,
           0.0000, 0.0000],
          [0.0556, 0.4444, 0.5000, 0.5000, 0.4444, 0.0556, 0.0000, 0.0000,
           0.0000, 0.0000],
          [0.0069, 0.0556, 0.0625, 0.0625, 0.0556, 0.0069, 0.0000, 0.0000,
           0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000

### 14. roi_align_backward_cuda

```cpp
// Host-side launcher for the RoIAlign backward pass on CUDA.
//
// Scatters grad_output (one value per pooled output element) into grad_input
// by launching roi_align_backward_cuda_kernel on the current CUDA stream of
// grad_output's device. The kernel only adds into grad_input, so the caller
// is expected to pass an already-initialized (presumably zero-filled) buffer.
//
//   grad_output     gradient w.r.t. the pooled output; numel() elements
//   rois            RoIs as rows of (batch_idx, x1, y1, x2, y2)
//   argmax_y/x      sample locations recorded by forward (used when
//                   pool_mode == 0)
//   grad_input      gradient w.r.t. the input features; sizes (N, C, H, W)
//                   are read from it
//   pool_mode       0 - max pool, 1 - avg pool
//   aligned         shift RoI coordinates by half a pixel (pool_mode == 1)
void roi_align_backward_cuda(Tensor grad_output, Tensor rois,
                                        Tensor argmax_y, Tensor argmax_x,
                                        Tensor grad_input, int aligned_height,
                                        int aligned_width, float spatial_scale,
                                        int sampling_ratio, int pool_mode,
                                        bool aligned) {
  // NOTE(review): numel() is 64-bit; the kernel takes an int element count.
  int output_size = grad_output.numel();

  // Handle possibly empty gradients (e.g. zero RoIs): there is nothing to
  // scatter. GET_BLOCKS(0) presumably yields a zero-block grid, which is an
  // invalid CUDA launch configuration, so return before launching.
  if (output_size == 0) {
    AT_CUDA_CHECK(cudaGetLastError());
    return;
  }

  int channels = grad_input.size(1);
  int height = grad_input.size(2);
  int width = grad_input.size(3);

  at::cuda::CUDAGuard device_guard(grad_output.device());
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      grad_output.scalar_type(), "roi_align_backward_cuda_kernel", [&] {
        roi_align_backward_cuda_kernel<scalar_t>
            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
                output_size, grad_output.data_ptr<scalar_t>(),
                rois.data_ptr<scalar_t>(), argmax_y.data_ptr<scalar_t>(),
                argmax_x.data_ptr<scalar_t>(), grad_input.data_ptr<scalar_t>(),
                aligned_height, aligned_width,
                static_cast<scalar_t>(spatial_scale), sampling_ratio, pool_mode,
                aligned, channels, height, width);
      });

  AT_CUDA_CHECK(cudaGetLastError());
}
```

### 15. roi_align_backward_cuda_kernel
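It is almost the same as `roi_align_backward_cpu_kernel`. The key difference is that the CUDA kernel must use `atomicAdd` instead of a plain `add`. Each loop iteration handles one pooled output element, and the iterations run in parallel threads. Different output elements (e.g. neighbouring bins or overlapping RoIs) can add into the same input pixel at the same time, and a non-atomic add would lose some of those updates.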

```cpp
// Backward pass of RoIAlign: scatters each pooled-output gradient back onto
// the input feature map through bilinear-interpolation weights.
//
// `index` ranges over [0, nthreads) via CUDA_1D_KERNEL_LOOP (defined
// elsewhere); each value identifies one element of grad_output, decoded as
// (n, c, ph, pw) where n selects the RoI.
//   rois        rows of 5 values: (batch_idx, x1, y1, x2, y2)
//   grad_input  indexed as contiguous (batch, channels, height, width);
//               only ever added to, never overwritten
//   argmax_y/x  per-output sample location recorded by forward; read only
//               when pool_mode == 0
// Updates go through atomicAdd because different outputs can hit the same
// input pixel.
template <typename T>
__global__ void roi_align_backward_cuda_kernel(
    const int nthreads, const T* grad_output, const T* rois, const T* argmax_y,
    const T* argmax_x, T* grad_input, const int pooled_height,
    const int pooled_width, const T spatial_scale, const int sampling_ratio,
    const int pool_mode,  // 0 - max pool, 1 - avg pool
    const bool aligned, const int channels, const int height, const int width) {
  CUDA_1D_KERNEL_LOOP(index, nthreads) {
    // (n, c, ph, pw) is an element in the pooled output
    int pw = index % pooled_width;
    int ph = (index / pooled_width) % pooled_height;
    int c = (index / pooled_width / pooled_height) % channels;
    int n = index / pooled_width / pooled_height / channels;

    const T grad_output_this_bin = grad_output[index];

    // Column 0 of the RoI row holds the batch index (stored as T, truncated
    // to int here); it selects which image's (c)-th channel plane to update.
    const T* offset_rois = rois + n * 5;
    int roi_batch_ind = offset_rois[0];
    T* offset_grad_input =
        grad_input + ((roi_batch_ind * channels + c) * height * width);

    if (pool_mode == 0) {
      // Max pooling: the whole gradient goes to the single sample location
      // forward recorded for this bin. A value of -1 marks "no location"
      // (presumably set by forward for empty bins - TODO confirm).
      T y = argmax_y[index], x = argmax_x[index];
      if (y != -1.f) {
        T w1, w2, w3, w4;
        int x_low, x_high, y_low, y_high;
        bilinear_interpolate_gradient(height, width, y, x, w1, w2, w3, w4,
                                      x_low, x_high, y_low, y_high, index);

        // Negative indices mean the sample fell outside the feature map.
        if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
          atomicAdd(offset_grad_input + y_low * width + x_low,
                    grad_output_this_bin * w1);
          atomicAdd(offset_grad_input + y_low * width + x_high,
                    grad_output_this_bin * w2);
          atomicAdd(offset_grad_input + y_high * width + x_low,
                    grad_output_this_bin * w3);
          atomicAdd(offset_grad_input + y_high * width + x_high,
                    grad_output_this_bin * w4);
        }
      }
    } else if (pool_mode == 1) {
      // Average pooling: recompute the same sample grid as forward and give
      // each sample an equal 1/count share of the bin's gradient.
      // Do not use rounding; this implementation detail is critical
      T offset = aligned ? (T)0.5 : (T)0.0;
      T roi_start_w = offset_rois[1] * spatial_scale - offset;
      T roi_start_h = offset_rois[2] * spatial_scale - offset;
      T roi_end_w = offset_rois[3] * spatial_scale - offset;
      T roi_end_h = offset_rois[4] * spatial_scale - offset;

      T roi_width = roi_end_w - roi_start_w;
      T roi_height = roi_end_h - roi_start_h;
      if (!aligned) {  // for backward-compatibility only
        // Legacy mode forces every RoI to be at least 1x1.
        roi_width = max(roi_width, (T)1.);
        roi_height = max(roi_height, (T)1.);
      }

      T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height);
      T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width);

      // We use roi_bin_grid to sample the grid and mimic integral
      // sampling_ratio > 0 fixes the samples per bin along each axis;
      // otherwise it adapts to the bin size (ceil of RoI extent / bins).
      int roi_bin_grid_h =
          (sampling_ratio > 0)
              ? sampling_ratio
              : static_cast<int>(ceilf(roi_height / pooled_height));
      int roi_bin_grid_w =
          (sampling_ratio > 0)
              ? sampling_ratio
              : static_cast<int>(ceilf(roi_width / pooled_width));

      // We do average (integral) pooling inside a bin
      const T count = roi_bin_grid_h * roi_bin_grid_w;  // e.g. = 4

      // Sample points sit at the centres of a roi_bin_grid_h x
      // roi_bin_grid_w sub-grid of bin (ph, pw).
      for (int iy = 0; iy < roi_bin_grid_h; iy++) {
        const T y = roi_start_h + ph * bin_size_h +
                    static_cast<T>(iy + .5f) * bin_size_h /
                        static_cast<T>(roi_bin_grid_h);
        for (int ix = 0; ix < roi_bin_grid_w; ix++) {
          const T x = roi_start_w + pw * bin_size_w +
                      static_cast<T>(ix + .5f) * bin_size_w /
                          static_cast<T>(roi_bin_grid_w);

          T w1, w2, w3, w4;
          int x_low, x_high, y_low, y_high;
          bilinear_interpolate_gradient(height, width, y, x, w1, w2, w3, w4,
                                        x_low, x_high, y_low, y_high, index);

          // Negative indices mean the sample fell outside the feature map.
          if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) {
            atomicAdd(offset_grad_input + y_low * width + x_low,
                      grad_output_this_bin * w1 / count);
            atomicAdd(offset_grad_input + y_low * width + x_high,
                      grad_output_this_bin * w2 / count);
            atomicAdd(offset_grad_input + y_high * width + x_low,
                      grad_output_this_bin * w3 / count);
            atomicAdd(offset_grad_input + y_high * width + x_high,
                      grad_output_this_bin * w4 / count);
          }
        }
      }
    }
  }
}
```

In [9]:
# Gradient of the sum of all pooled outputs w.r.t. the CUDA input features.
# .backward() accumulates into .grad, so clear any value left over from a
# previous run of this cell; otherwise re-running would double the result.
feats_cuda.grad = None
align_roi_sum_cuda = align_roi_cuda.sum()
# retain_graph=True keeps the RoIAlign node's saved buffers alive, so this cell
# can be re-executed instead of failing with
# "Trying to backward through the graph a second time".
align_roi_sum_cuda.backward(retain_graph=True)
print(feats_cuda.grad)

tensor([[[[0.0069, 0.0556, 0.0625, 0.0625, 0.0556, 0.0069, 0.0000, 0.0000,
           0.0000, 0.0000],
          [0.0556, 0.4444, 0.5000, 0.5000, 0.4444, 0.0556, 0.0000, 0.0000,
           0.0000, 0.0000],
          [0.0625, 0.5000, 0.5625, 0.5625, 0.5000, 0.0625, 0.0000, 0.0000,
           0.0000, 0.0000],
          [0.0625, 0.5000, 0.5625, 0.5625, 0.5000, 0.0625, 0.0000, 0.0000,
           0.0000, 0.0000],
          [0.0556, 0.4444, 0.5000, 0.5000, 0.4444, 0.0556, 0.0000, 0.0000,
           0.0000, 0.0000],
          [0.0069, 0.0556, 0.0625, 0.0625, 0.0556, 0.0069, 0.0000, 0.0000,
           0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000