In [None]:
## Comparing Pytorch Torchvision MultiScale ROI_Align layer intrinsics with
## native pytorch and tensorflow versions using pickled IO data

## https://github.com/tensorflow/models/blob/master/research/...
## .../object_detection/utils/spatial_transform_ops.py

## Tensorflow Model Garden Research Utils
# def multilevel_roi_align(features, boxes, box_levels, output_size,
#                         num_samples_per_cell_y=1, num_samples_per_cell_x=1,
#                         align_corners=False, extrapolation_value=0.0,
#                         scope=None):

## PyTorch TorchVision intrinsics code
## https://github.com/pytorch/vision/blob/main/torchvision/ops/roi_align.py
## -> https://github.com/pytorch/vision/blob/main/torchvision/csrc/ops/roi_align.cpp
## -> https://github.com/pytorch/vision/blob/main/torchvision/csrc/ops/cpu/roi_align_kernel.cpp
## -> https://github.com/pytorch/vision/blob/main/torchvision/csrc/ops/cpu/roi_align_common.h

In [None]:
import os
## Turn off warnings about missing NVinfer libs in TF
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

## try to stay off the GPU
os.environ['CUDA_VISIBLE_DEVICES'] = ""

import torch
import torchvision
import tensorflow as tf
import pickle
import numpy as np

## The pickled `features` object is a dict type from `collections`; unpickling resolves it, so this import isn't needed.
# from collections import OrderedDict

In [None]:
## Pickle data captured pre-post torchvision layer on live wheathead model inference.
## NOTE: pickle.load can execute arbitrary code -- only open capture files you produced yourself.
## Inputs that were fed to the MultiScaleRoIAlign layer.
with open('frcnn_msroia_inputs_orig.pickle', 'rb') as handle:
    (
        features,
        boxes,
        image_shapes,
        scales,
        map_levels,
        output_size,
        sampling_ratio,
        canonical_scale,
        canonical_level,
    ) = pickle.load(handle)

## Intermediate values and the final output recorded from the same layer call.
with open('frcnn_msroia_outputs_orig.pickle', 'rb') as handle:
    (x_filtered, re_scales, re_map_levels, roi_out) = pickle.load(handle)

In [None]:
## Quick inventory (type / length / shape) of everything loaded from the two pickles.

## --- captured inputs ---
print(
    "features:", type(features), len(features), features.keys(),
    type(features['0']), features['0'].shape,
)
print(
    "boxes:", type(boxes), len(boxes), type(boxes[0]), len(boxes[0]),
    type(boxes[0][0]), boxes[0][0].shape, boxes[0][0][0].dtype,
)
print(
    "image_shapes:", type(image_shapes), len(image_shapes),
    type(image_shapes[0]), len(image_shapes[0]),
)
print("scales:", type(scales), len(scales), type(scales[0]), scales[0])
print("map_levels:", type(map_levels))
print(
    "output_size:", type(output_size), len(output_size),
    type(output_size[0]), type(output_size[1]),
    output_size[0], output_size[1],
)
print("sampling_ratio:", type(sampling_ratio), sampling_ratio)
print("canonical_scale:", type(canonical_scale), canonical_scale)
print("canonical_level:", type(canonical_level), canonical_level)

## --- captured outputs ---
print("x_filtered:", type(x_filtered), len(x_filtered), type(x_filtered[0]), x_filtered[0].shape)
print("re_scales:", type(re_scales), len(re_scales), type(re_scales[0]), re_scales[0])
print("re_map_levels:", type(re_map_levels))
print("roi_out:", type(roi_out), roi_out.shape)

# TorchVision with Intrinsics version (original)

In [None]:
## Reference run: the stock torchvision layer (backed by the compiled roi_align op).
roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'], output_size=7, sampling_ratio=2)
roi_output = roi_pooler(features, boxes, image_shapes)
print("roi_output:", type(roi_output), roi_output.shape)

## Compare this standalone CPU run against the output captured from the layer
## embedded in the model (recorded on CUDA).
if roi_output.eq(roi_out).all():
    print("Output matches exactly!")

# TensorFlow Model Garden MultiScale ROI Align version

In [None]:
## https://github.com/tensorflow/models/blob/master/research/object_detection/utils/shape_utils.py
def tf_combined_static_and_dynamic_shape(tensor):
  """Return the shape of `tensor` as a list, one entry per dimension.

  Dimensions known statically are returned as Python ints; dimensions that are
  `None` in the static shape are replaced by the matching element of
  `tf.shape(tensor)`.
  """
  static_dims = tensor.shape.as_list()
  runtime_dims = tf.shape(tensor)
  return [
      known if known is not None else runtime_dims[axis]
      for axis, known in enumerate(static_dims)
  ]

In [None]:
def tf_gather_valid_indices(tensor, indices, padding_value=0.0):
  """Gather rows of `tensor`, mapping index -1 to a row filled with `padding_value`.

  A single padding row is prepended to `tensor`, and every index is shifted by
  one, so an index of -1 selects the padding row and index i selects row i.
  """
  row_width = tf.shape(tensor)[-1]
  padding_row = padding_value * tf.ones([1, row_width], dtype=tensor.dtype)
  padded_tensor = tf.concat([padding_row, tensor], axis=0)
  padded_tensor = padded_tensor * 1.0
  shifted_indices = indices + 1
  return tf.gather(padded_tensor, shifted_indices)

In [None]:
def tf_valid_indicator(feature_grid_y, feature_grid_x, true_feature_shapes):
  """Flattened boolean mask of the (y, x) grid points that lie inside the feature map.

  A point is valid when 0 <= y < height and 0 <= x < width, where height and
  width are read from the last axis of `true_feature_shapes` (index 0 and 1).
  """
  height = tf.cast(true_feature_shapes[:, :, 0:1], dtype=feature_grid_y.dtype)
  width = tf.cast(true_feature_shapes[:, :, 1:2], dtype=feature_grid_x.dtype)

  rows_inside = tf.logical_and(feature_grid_y >= 0, tf.less(feature_grid_y, height))
  cols_inside = tf.logical_and(feature_grid_x >= 0, tf.less(feature_grid_x, width))

  # Rows get a trailing axis and columns a middle axis so the AND broadcasts
  # into the full y-by-x grid for each box.
  inside = tf.logical_and(tf.expand_dims(rows_inside, 3),
                          tf.expand_dims(cols_inside, 2))
  return tf.reshape(inside, [-1])

In [None]:
def tf_ravel_indices(feature_grid_y, feature_grid_x, num_levels, height, width,
                  box_levels):
  """Convert per-box (y, x) grid coordinates into flat indices.

  The flat index addresses a tensor laid out as
  [batch, level, height, width] and is computed as
  batch * (num_levels * height * width) + level * (height * width)
  + y * width + x, returned as a 1-D tensor.
  """
  grid_y_shape = tf.shape(feature_grid_y)
  batch_size = grid_y_shape[0]
  num_boxes = grid_y_shape[1]
  size_y = grid_y_shape[2]
  size_x = tf.shape(feature_grid_x)[2]

  # Strides of the flattened [batch, level, height, width] layout.
  row_stride = width
  level_stride = height * row_stride
  batch_stride = num_levels * level_stride

  def _int_ones(shape):
    return tf.ones(shape, dtype=tf.int32)

  # Each term is reshaped to its own axis and multiplied by ones to broadcast
  # it to the common [batch, box, y, x] shape.
  batch_term = (
      tf.reshape(tf.range(batch_size) * batch_stride, [batch_size, 1, 1, 1])
      * _int_ones([1, num_boxes, size_y, size_x]))
  level_term = (
      tf.reshape(box_levels * level_stride, [batch_size, num_boxes, 1, 1])
      * _int_ones([1, 1, size_y, size_x]))
  row_term = (
      tf.reshape(feature_grid_y * row_stride, [batch_size, num_boxes, size_y, 1])
      * _int_ones([1, 1, 1, size_x]))
  col_term = (
      tf.reshape(feature_grid_x, [batch_size, num_boxes, 1, size_x])
      * _int_ones([1, 1, size_y, 1]))

  return tf.reshape(batch_term + level_term + row_term + col_term, [-1])

In [None]:
def tf_feature_grid_coordinate_vectors(box_grid_y, box_grid_x):
  """Return the int32 grid cells surrounding each sampling coordinate.

  Returns the tuple (y0, x0, y1, x1) where y0/x0 are floor(coordinate) and
  y1/x1 are floor(coordinate + 1).
  """
  def _floor_as_int32(values):
    return tf.cast(tf.floor(values), dtype=tf.int32)

  return (
      _floor_as_int32(box_grid_y),
      _floor_as_int32(box_grid_x),
      _floor_as_int32(box_grid_y + 1),
      _floor_as_int32(box_grid_x + 1),
  )

In [None]:
def tf_coordinate_vector_1d(start, end, size, align_endpoints):
  """Generate `size` sampling coordinates between `start` and `end`.

  With `align_endpoints` the samples run from `start` to `end` inclusive (or sit
  at the midpoint when `size` is 1). Otherwise the interval is cut into `size`
  equal cells and each sample sits at the centre of its cell. A trailing axis
  of length `size` is appended to the inputs.
  """
  lower = tf.expand_dims(start, -1)
  upper = tf.expand_dims(end, -1)
  span = upper - lower

  if not align_endpoints:
    fractions = tf.linspace(0.0, 1.0, size + 1)[:-1]
    shift = span / (2 * size)
  else:
    fractions = tf.linspace(0.0, 1.0, size)
    shift = 0 if size > 1 else span / 2

  fractions = tf.cast(tf.reshape(fractions, [1, 1, size]), dtype=lower.dtype)
  return lower + shift + fractions * span

In [None]:
def tf_box_grid_coordinate_vectors(boxes, size_y, size_x, align_corners=False):
  """Build the y and x sampling coordinates for every box.

  The last axis of `boxes` is unpacked as (ymin, xmin, ymax, xmax). Returns the
  pair (box_grid_y, box_grid_x) with `size_y` and `size_x` samples per box.
  """
  ymin, xmin, ymax, xmax = tf.unstack(boxes, axis=-1)
  return (
      tf_coordinate_vector_1d(ymin, ymax, size_y, align_corners),
      tf_coordinate_vector_1d(xmin, xmax, size_x, align_corners),
  )

In [None]:
def tf_pad_to_max_size(features):
  """Stack a list of per-level feature maps into one tensor with a level axis.

  Axes 1 and 2 of each feature map are treated as height and width. Every map
  is padded (bottom/right) to the largest height and width, then the maps are
  stacked along a new axis 1. Also returns the unpadded (height, width) of
  each level. A single-level list skips padding entirely.
  """
  if len(features) == 1:
    only_level = features[0]
    return (tf.expand_dims(only_level, 1),
            tf.expand_dims(tf.shape(only_level)[1:3], 0))

  shapes_are_static = all([level.shape.is_fully_defined() for level in features])
  if shapes_are_static:
    # Static shapes: plain Python ints, so the built-in max is enough.
    max_height = max([level.shape[1] for level in features])
    max_width = max([level.shape[2] for level in features])
  else:
    max_height = tf.reduce_max([tf.shape(level)[1] for level in features])
    max_width = tf.reduce_max([tf.shape(level)[2] for level in features])

  padded_levels = [
      tf.image.pad_to_bounding_box(level, 0, 0, max_height, max_width)
      for level in features
  ]
  features_all = tf.stack(padded_levels, axis=1)
  true_feature_shapes = tf.stack([tf.shape(level)[1:3] for level in features])
  return features_all, true_feature_shapes

In [None]:
## let's build the TF version
## https://github.com/tensorflow/models/blob/master/research/object_detection/...
## .../utils/spatial_transform_ops.py

def tf_multilevel_roi_align(features, boxes, box_levels, output_size,
                         num_samples_per_cell_y=1, num_samples_per_cell_x=1,
                         align_corners=False, extrapolation_value=0.0,
                         scope=None):
    """Multi-level RoI-Align ported from the TF Model Garden object-detection utils.

    Args:
        features: list of feature maps, one per level; axes 1 and 2 are read as
            height and width and the last axis as the channel count.
        boxes: [batch, num_boxes, 4] boxes as (ymin, xmin, ymax, xmax),
            normalized so that multiplying by (feature_size - 1) yields
            feature-map coordinates.
        box_levels: [batch, num_boxes] int32 index of the level each box pools from.
        output_size: (height, width) of the pooled output per box.
        num_samples_per_cell_y: sampling points per output cell along y.
        num_samples_per_cell_x: sampling points per output cell along x.
        align_corners: forwarded to the sampling-grid construction.
        extrapolation_value: value used for samples outside the feature map.
        scope: optional name scope; defaults to 'MultiLevelRoIAlign'.

    Returns:
        Tensor of shape
        [batch, num_boxes, output_size[0], output_size[1], num_filters].
    """
    with tf.name_scope(scope if scope is not None else 'MultiLevelRoIAlign'):
        features, true_feature_shapes = tf_pad_to_max_size(features)
        batch_size = tf_combined_static_and_dynamic_shape(features)[0]
        num_levels = features.get_shape().as_list()[1]
        max_feature_height = tf.shape(features)[2]
        max_feature_width = tf.shape(features)[3]
        num_filters = features.get_shape().as_list()[4]
        num_boxes = tf.shape(boxes)[1]

        print("num_levels:", num_levels)
        print("num_filters:", num_filters)
        print("num_boxes:", num_boxes)
        ## note: num_boxes is tensor, .numpy() converts, but no behavior change

        # Convert boxes to absolute co-ordinates.
        true_feature_shapes = tf.cast(true_feature_shapes, dtype=boxes.dtype)
        true_feature_shapes = tf.gather(true_feature_shapes, box_levels)
        boxes = boxes * tf.concat([true_feature_shapes - 1] * 2, axis=-1)

        size_y = output_size[0] * num_samples_per_cell_y
        size_x = output_size[1] * num_samples_per_cell_x
        print("size_x, size_y:", size_x, size_y)
        box_grid_y, box_grid_x = tf_box_grid_coordinate_vectors(
            boxes, size_y=size_y, size_x=size_x, align_corners=align_corners)

        (feature_grid_y0, feature_grid_x0, feature_grid_y1,
         feature_grid_x1) = tf_feature_grid_coordinate_vectors(box_grid_y, box_grid_x)

        # Interleave the low/high neighbours so each sample contributes two
        # consecutive rows (and columns) to the gathered patch.
        feature_grid_y = tf.reshape(
            tf.stack([feature_grid_y0, feature_grid_y1], axis=3),
            [batch_size, num_boxes, -1])
        feature_grid_x = tf.reshape(
            tf.stack([feature_grid_x0, feature_grid_x1], axis=3),
            [batch_size, num_boxes, -1])

        feature_coordinates = tf_ravel_indices(feature_grid_y, feature_grid_x,
                                               num_levels, max_feature_height,
                                               max_feature_width, box_levels)

        # Out-of-range samples are redirected to index -1, which the gather
        # helper maps to a row filled with extrapolation_value.
        valid_indices = tf_valid_indicator(feature_grid_y, feature_grid_x,
                                           true_feature_shapes)
        feature_coordinates = tf.where(valid_indices, feature_coordinates,
                                       -1 * tf.ones_like(feature_coordinates))

        flattened_features = tf.reshape(features, [-1, num_filters])
        flattened_feature_values = tf_gather_valid_indices(flattened_features,
                                                           feature_coordinates,
                                                           extrapolation_value)

        features_per_box = tf.reshape(
            flattened_feature_values,
            [batch_size, num_boxes, size_y * 2, size_x * 2, num_filters])

        # Cast tensors into dtype of features.
        box_grid_y = tf.cast(box_grid_y, dtype=features_per_box.dtype)
        box_grid_x = tf.cast(box_grid_x, dtype=features_per_box.dtype)
        feature_grid_y0 = tf.cast(feature_grid_y0, dtype=features_per_box.dtype)
        feature_grid_x0 = tf.cast(feature_grid_x0, dtype=features_per_box.dtype)

        # Bilinear weights: l* is the fractional distance to the low neighbour,
        # h* its complement.
        ly = box_grid_y - feature_grid_y0
        lx = box_grid_x - feature_grid_x0
        hy = 1.0 - ly
        hx = 1.0 - lx

        kernel_y = tf.reshape(
            tf.stack([hy, ly], axis=3), [batch_size, num_boxes, size_y * 2, 1])

        kernel_x = tf.reshape(
            tf.stack([hx, lx], axis=3), [batch_size, num_boxes, 1, size_x * 2])

        # Multiplier 4 is to make tf.nn.avg_pool behave like sum_pool.
        interpolation_kernel = kernel_y * kernel_x * 4

        # Interpolate the gathered features with computed interpolation kernels.
        # (A stray trailing comma here used to wrap the multiplier in a 1-tuple,
        # which only worked through auto-packing plus the reshape below.)
        features_per_box = features_per_box * tf.expand_dims(
            interpolation_kernel, axis=4)
        features_per_box = tf.reshape(
            features_per_box,
            [batch_size * num_boxes, size_y * 2, size_x * 2, num_filters])

        # This combines the two pooling operations - sum_pool to perform bilinear
        # interpolation and avg_pool to pool the values in each bin.
        pool_window = [1, num_samples_per_cell_y * 2, num_samples_per_cell_x * 2, 1]
        features_per_box = tf.nn.avg_pool(
            features_per_box, pool_window, pool_window, 'VALID')
        features_per_box = tf.reshape(
            features_per_box,
            [batch_size, num_boxes, output_size[0], output_size[1], num_filters])

        return features_per_box

In [None]:
## Side length used to normalize the pixel-space boxes for the TF implementation.
## NOTE(review): presumably the (square) input image size -- confirm against image_shapes.
BoxSize = 1024.0

## Torch features are laid out NCHW; the TF implementation reads NHWC.
np_xt = features['0'].numpy()
tf_x = [tf.transpose(tf.convert_to_tensor(np_xt), perm=[0, 2, 3, 1])]
print("tf_x[0]:", type(tf_x[0]), tf_x[0].shape)

## TF ROI-Align wants normalized boxes ordered (ymin, xmin, ymax, xmax), while the
## torch boxes are (x1, y1, x2, y2) -- swap the columns before normalizing.
np_boxes = boxes[0].numpy().reshape(1, -1, 4)
np_boxes_yxyx = np_boxes[..., [1, 0, 3, 2]]
tf_boxes = tf.convert_to_tensor(np_boxes_yxyx / BoxSize, dtype=tf.float32)
print("tf_boxes:", type(tf_boxes), tf_boxes.shape)

tf_box_levels = tf.zeros(shape=(1, tf_boxes.shape[1]), dtype=tf.int32)
tf_output_size = [7, 7]
## Match the torch pooler above, which was built with sampling_ratio=2.
tf_samples_per_cell = 2

## NOTE(review): the TF code scales normalized boxes by (feature_size - 1) while the
## torch code scales pixel boxes by spatial_scale, so sample positions can still
## differ slightly between the two implementations.
tf_output = tf_multilevel_roi_align(
    tf_x, tf_boxes, tf_box_levels, tf_output_size,
    num_samples_per_cell_y=tf_samples_per_cell,
    num_samples_per_cell_x=tf_samples_per_cell)

print("tf_output:", type(tf_output), tf_output.shape)
## Back to the torch layout [num_boxes, C, H, W] as a numpy array.
tptf_output = tf.transpose(tf.squeeze(tf_output, axis=0), perm=[0, 3, 1, 2]).numpy()

## pytorch torchvision output
print("roi_output:", type(roi_output), roi_output.shape, np.sqrt(np.mean(roi_output.numpy()**2)))

print("tptf_output:", type(tptf_output), tptf_output.shape, np.sqrt(np.mean(tptf_output**2)))

print("x['0']:", type(features['0']), features['0'].shape, np.sqrt(np.mean(features['0'].numpy()**2)))
print("tf_x[0]:", type(tf_x[0]), tf_x[0].shape, np.sqrt(np.mean(tf_x[0].numpy()**2)))


In [None]:
## RMS magnitude of the recorded torch output and of the TF output, followed by
## the share of elements whose difference is exactly zero.
print(
    np.sqrt(np.mean(np.square(roi_out.numpy()))),
    np.sqrt(np.mean(np.square(tptf_output))),
)
A = np.count_nonzero(tptf_output - roi_out.numpy())  # elements that differ
B = roi_out.numel()                                  # total elements
print(A, "/", B, "=", str(round((100 * (B - A) / B), 3)) + "% Matching!")
## Looks like TF version is very different.

# Native Python / Torch version

In [None]:
def pt_initLevelMapper(k_min, k_max, canonical_scale=224, canonical_level=4, eps=1e-6):
    """Factory that builds a pt_LevelMapper from the given level range and constants."""
    return pt_LevelMapper(
        k_min=k_min,
        k_max=k_max,
        canonical_scale=canonical_scale,
        canonical_level=canonical_level,
        eps=eps,
    )


class pt_LevelMapper:
    """Assigns each box to a feature-pyramid level based on its size.

    Implements Eqn.(1) of the FPN paper:
    level = floor(canonical_level + log2(sqrt(area) / canonical_scale)),
    clamped to [k_min, k_max] and returned relative to k_min.
    """

    def __init__(
        self,
        k_min,
        k_max,
        canonical_scale= 224,
        canonical_level= 4,
        eps= 1e-6
    ):
        self.k_min = k_min
        self.k_max = k_max
        self.s0 = canonical_scale
        self.lvl0 = canonical_level
        self.eps = eps

    def __call__(self, boxlists):
        """Return an int64 tensor of zero-based level indices, one per box.

        Args:
            boxlists: list of [N, 4] tensors in (x1, y1, x2, y2) format; the
                result is concatenated in list order.
        """
        # Box area computed inline; the original called `box_area`, which is
        # neither defined nor imported in this notebook.
        areas = torch.cat(
            [
                (boxlist[:, 2] - boxlist[:, 0]) * (boxlist[:, 3] - boxlist[:, 1])
                for boxlist in boxlists
            ]
        )
        s = torch.sqrt(areas)

        # Eqn.(1) in FPN paper; eps guards against log2 landing just below an integer.
        target_lvls = torch.floor(self.lvl0 + torch.log2(s / self.s0) + torch.tensor(self.eps, dtype=s.dtype))
        target_lvls = torch.clamp(target_lvls, min=self.k_min, max=self.k_max)
        return (target_lvls.to(torch.int64) - self.k_min).to(torch.int64)


In [None]:
def pt_infer_scale(feature, original_size):
    """Infer the feature-map-to-image scale, snapped to the nearest power of two.

    The ratio is computed for each of the last two dimensions of `feature`
    against `original_size`; the value for the first of them is returned.
    """
    # assumption: the scale is of the form 2 ** (-k), with k integer
    spatial_dims = feature.shape[-2:]
    candidates = [
        2 ** float(torch.tensor(float(feat_dim) / float(image_dim)).log2().round())
        for feat_dim, image_dim in zip(spatial_dims, original_size)
    ]
    return candidates[0]

In [None]:
def pt_setup_scales(features, image_shapes, canonical_scale, canonical_level):
    """Compute the per-level spatial scales and the matching level mapper.

    Returns:
        (scales, map_levels): one inferred scale per feature map, and a level
        mapper spanning the levels of the first and last feature map.

    Raises:
        ValueError: if `image_shapes` is empty.
    """
    if not image_shapes:
        raise ValueError("images list should not be empty")

    # Largest extent seen along each of the first two shape entries.
    max_x = 0
    max_y = 0
    for shape in image_shapes:
        max_x = max(shape[0], max_x)
        max_y = max(shape[1], max_y)
    original_input_shape = (max_x, max_y)

    scales = [pt_infer_scale(feat, original_input_shape) for feat in features]

    # get the levels in the feature map by leveraging the fact that the network always
    # downsamples by a factor of 2 at each level.
    def _level_from_scale(scale):
        return -torch.log2(torch.tensor(scale, dtype=torch.float32)).item()

    lvl_min = _level_from_scale(scales[0])
    lvl_max = _level_from_scale(scales[-1])

    map_levels = pt_initLevelMapper(
        int(lvl_min),
        int(lvl_max),
        canonical_scale=canonical_scale,
        canonical_level=canonical_level,
    )
    return scales, map_levels

In [None]:
def pt_filter_input(x, featmap_names):
    """Return the values of mapping `x` whose keys are in `featmap_names`, in `x`'s order."""
    return [value for key, value in x.items() if key in featmap_names]

In [None]:
## from torchvision/ops/_utils.py
def pt_check_roi_boxes_shape(boxes):
    """Validate the layout of `boxes` via torch._assert.

    Accepts either a single tensor with 5 columns or a list/tuple of tensors
    with 4 columns each; anything else fails the assertion.
    """
    if isinstance(boxes, torch.Tensor):
        torch._assert(boxes.size(1) == 5, "The boxes tensor shape is not correct as Tensor[K, 5]")
        return

    if not isinstance(boxes, (list, tuple)):
        torch._assert(False, "boxes is expected to be a Tensor[L, 5] or a List[Tensor[K, 4]]")
        return

    for _tensor in boxes:
        torch._assert(
            _tensor.size(1) == 4, "The shape of the tensor in the boxes list is not correct as List[Tensor[L, 4]]"
        )
    return

In [None]:
def pt_convert_to_roi_format(boxes):
    """Concatenate per-image box tensors into one [K, 5] roi tensor.

    Column 0 holds the index of the list entry each box came from; columns 1-4
    hold the original box coordinates.
    """
    stacked_boxes = torch.cat(boxes, dim=0)

    id_columns = []
    for image_index, image_boxes in enumerate(boxes):
        id_columns.append(
            torch.full_like(
                image_boxes[:, :1],
                image_index,
                dtype=stacked_boxes.dtype,
                layout=torch.strided,
                device=stacked_boxes.device,
            )
        )

    return torch.cat([torch.cat(id_columns, dim=0), stacked_boxes], dim=1)

In [None]:
def ppt_bilinear_interpolate_v1(
    input,  # [N, C, H, W]
    roi_batch_ind,  # [K]
    y,  # [K, PH, IY]
    x,  # [K, PW, IX]
    ymask,  # [K, IY]
    xmask,  # [K, IX]
):
    """Bilinearly sample `input` at every (y, x) sampling point of every roi.

    Args:
        input: feature map, [N, C, H, W].
        roi_batch_ind: batch index of each roi, [K].
        y: sampling row coordinates, [K, PH, IY].
        x: sampling column coordinates, [K, PW, IX].
        ymask: optional [K, IY] mask of the row samples actually needed, or None.
        xmask: optional [K, IX] mask of the column samples actually needed;
            must be given whenever `ymask` is.

    Returns:
        Interpolated values of shape [K, C, PH, PW, IY, IX].
    """
    _, channels, height, width = input.size()

    # deal with inverse element out of feature map boundary
    # Negative coordinates are pulled up to 0; coordinates at or past the last
    # row/column use that last row/column for both neighbours.
    # NOTE(review): the compiled torchvision kernel reportedly gives zero weight
    # to samples far outside the map instead of clamping them -- confirm whether
    # that explains any residual mismatch with the intrinsics output.
    y = y.clamp(min=0)
    x = x.clamp(min=0)
    y_low = y.int()
    x_low = x.int()
    y_high = torch.where(y_low >= height - 1, height - 1, y_low + 1)
    y_low = torch.where(y_low >= height - 1, height - 1, y_low)
    y = torch.where(y_low >= height - 1, y.to(input.dtype), y)

    x_high = torch.where(x_low >= width - 1, width - 1, x_low + 1)
    x_low = torch.where(x_low >= width - 1, width - 1, x_low)
    x = torch.where(x_low >= width - 1, x.to(input.dtype), x)

    # Fractional distance to the low neighbour (l*) and its complement (h*).
    ly = y - y_low
    lx = x - x_low
    hy = 1.0 - ly
    hx = 1.0 - lx

    # do bilinear interpolation, but respect the masking!
    # TODO: It's possible the masking here is unnecessary if y and
    # x were clamped appropriately; hard to tell
    def masked_index(
        y,  # [K, PH, IY]
        x,  # [K, PW, IX]
    ):
        # Gathers input[batch, channel, y, x] for every combination of roi,
        # channel, output cell and sample; masked-out samples read index 0.
        #print("NOTE: MI:", type(x), type(y), x.dtype, y.dtype, x.shape, y.shape)
        #print("roi_batch_ind:", roi_batch_ind.dtype)
        #print("channels:", type(channels))
        
        if ymask is not None:
            assert xmask is not None
            y = torch.where(ymask[:, None, :], y, 0)
            x = torch.where(xmask[:, None, :], x, 0)
        return input[
            roi_batch_ind[:, None, None, None, None, None].long(),
            torch.arange(channels, device=input.device)[None, :, None, None, None, None].long(),
            y[:, None, :, None, :, None].long(),  # prev [K, PH, IY]
            x[:, None, None, :, None, :].long(),  # prev [K, PW, IX]
        ]  # [K, C, PH, PW, IY, IX]

    # The four neighbours of each sample: (low, low), (low, high), (high, low), (high, high).
    v1 = masked_index(y_low, x_low)
    v2 = masked_index(y_low, x_high)
    v3 = masked_index(y_high, x_low)
    v4 = masked_index(y_high, x_high)

    # all ws preemptively [K, C, PH, PW, IY, IX]
    def outer_prod(y, x):
        # Broadcast a y-weight and an x-weight into the shared 6-D layout.
        return y[:, None, :, None, :, None] * x[:, None, None, :, None, :]

    w1 = outer_prod(hy, hx)
    w2 = outer_prod(hy, lx)
    w3 = outer_prod(ly, hx)
    w4 = outer_prod(ly, lx)

    val = w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4
    return val

In [None]:
def ppt_maybe_cast(tensor):
    """Upcast `tensor` to float32 when autocast is on and it is a non-double CUDA tensor.

    In every other case the tensor is returned unchanged.
    """
    needs_float32 = (
        torch.is_autocast_enabled()
        and tensor.is_cuda
        and tensor.dtype != torch.double
    )
    return tensor.float() if needs_float32 else tensor

In [None]:
## Pure Pytorch version 1.0
# https://github.com/pytorch/vision/blob/main/torchvision/ops/roi_align.py

def ppt_roi_align_v1(
    input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned):
    """Pure-PyTorch RoI-Align (no compiled op), following torchvision's Python port.

    Args:
        input: feature map, [N, C, H, W].
        rois: [K, 5] tensor; column 0 is the batch index, columns 1-4 are read
            as (x1, y1, x2, y2) before scaling.
        spatial_scale: factor that maps roi coordinates onto the feature map.
        pooled_height: number of output rows per roi.
        pooled_width: number of output columns per roi.
        sampling_ratio: samples per output cell along each axis; values <= 0
            select an adaptive count of ceil(roi_size / pooled_size).
        aligned: when True, shift coordinates by -0.5 and skip the minimum
            roi size of 1.

    Returns:
        Tensor of shape [K, C, pooled_height, pooled_width] in the dtype of `input`.
    """
    orig_dtype = input.dtype

    input = ppt_maybe_cast(input)
    rois = ppt_maybe_cast(rois)

    _, _, height, width = input.size()

    ph = torch.arange(pooled_height, device=input.device)  # [PH]
    pw = torch.arange(pooled_width, device=input.device)  # [PW]

    # input: [N, C, H, W]
    # rois: [K, 5]

    roi_batch_ind = rois[:, 0].int()  # [K]
    offset = 0.5 if aligned else 0.0
    roi_start_w = rois[:, 1] * spatial_scale - offset  # [K]
    roi_start_h = rois[:, 2] * spatial_scale - offset  # [K]
    roi_end_w = rois[:, 3] * spatial_scale - offset  # [K]
    roi_end_h = rois[:, 4] * spatial_scale - offset  # [K]

    roi_width = roi_end_w - roi_start_w  # [K]
    roi_height = roi_end_h - roi_start_h  # [K]
    if not aligned:
        # Unaligned mode forces every roi to be at least 1x1 on the feature map.
        roi_width = torch.clamp(roi_width, min=1.0)  # [K]
        roi_height = torch.clamp(roi_height, min=1.0)  # [K]

    bin_size_h = roi_height / pooled_height  # [K]
    bin_size_w = roi_width / pooled_width  # [K]

    exact_sampling = sampling_ratio > 0

    roi_bin_grid_h = sampling_ratio if exact_sampling else torch.ceil(roi_height / pooled_height)  # scalar or [K]
    roi_bin_grid_w = sampling_ratio if exact_sampling else torch.ceil(roi_width / pooled_width)  # scalar or [K]

    """
    iy, ix = dims(2)
    """

    if exact_sampling:
        # Fixed sample count: no masking needed.
        count = max(roi_bin_grid_h * roi_bin_grid_w, 1)  # scalar
        iy = torch.arange(roi_bin_grid_h, device=input.device)  # [IY]
        ix = torch.arange(roi_bin_grid_w, device=input.device)  # [IX]
        ymask = None
        xmask = None
    else:
        count = torch.clamp(roi_bin_grid_h * roi_bin_grid_w, min=1)  # [K]
        # When doing adaptive sampling, the number of samples we need to do
        # is data-dependent based on how big the ROIs are.  This is a bit
        # awkward because first-class dims can't actually handle this.
        # So instead, we inefficiently suppose that we needed to sample ALL
        # the points and mask out things that turned out to be unnecessary
        iy = torch.arange(height, device=input.device)  # [IY]
        ix = torch.arange(width, device=input.device)  # [IX]
        ymask = iy[None, :] < roi_bin_grid_h[:, None]  # [K, IY]
        xmask = ix[None, :] < roi_bin_grid_w[:, None]  # [K, IX]

    def from_K(t):
        # Lift a per-roi [K] vector to [K, 1, 1] for broadcasting.
        return t[:, None, None]

    # Sample position = roi start + output-cell offset + centre of the sub-cell.
    y = (
        from_K(roi_start_h)
        + ph[None, :, None] * from_K(bin_size_h)
        + (iy[None, None, :] + 0.5) * from_K(bin_size_h / roi_bin_grid_h)
    )  # [K, PH, IY]
    x = (
        from_K(roi_start_w)
        + pw[None, :, None] * from_K(bin_size_w)
        + (ix[None, None, :] + 0.5) * from_K(bin_size_w / roi_bin_grid_w)
    )  # [K, PW, IX]
    ## V1
    val = ppt_bilinear_interpolate_v1(input, roi_batch_ind, y, x, ymask, xmask)  # [K, C, PH, PW, IY, IX]
    ## V2
    #n, c, ph, pw = dims(4)
    #offset_rois = rois[n]
    #roi_batch_ind = offset_rois[0].int()
    #offset_input = input[roi_batch_ind.long()][c]
    #val = ppt_bilinear_interpolate(offset_input, height, width, y, x, ymask, xmask)
    
    
    # Mask out samples that weren't actually adaptively needed
    if not exact_sampling:
        val = torch.where(ymask[:, None, None, None, :, None], val, 0)
        val = torch.where(xmask[:, None, None, None, None, :], val, 0)

    # Average over the samples of each output cell.
    output = val.sum((-1, -2))  # remove IY, IX ~> [K, C, PH, PW]
    if isinstance(count, torch.Tensor):
        output /= count[:, None, None, None]
    else:
        output /= count

    output = output.to(orig_dtype)

    return output

In [None]:
#from torch.nn.modules.utils import _pair
## True  -> dispatch to the compiled torchvision op (matches the reference exactly).
## False -> use the pure python / torch port below (about 98% matching).
UseOpsVer = False

def pt_roi_align(
    input,
    boxes,
    output_size,
    spatial_scale= 1.0,
    sampling_ratio= -1,
    aligned=False,
):
    """RoI-Align entry point mirroring torchvision.ops.roi_align.

    Args:
        input: feature map, [N, C, H, W].
        boxes: either a [K, 5] roi tensor or a list/tuple of [L, 4] box tensors
            (one per image), which is converted to the [K, 5] roi format.
        output_size: int or (height, width) of the pooled output.
        spatial_scale: factor that maps box coordinates onto the feature map.
        sampling_ratio: samples per output cell; <= 0 means adaptive.
        aligned: forwarded to the underlying implementation.

    Returns:
        Tensor of shape [K, C, output_size[0], output_size[1]].
    """
    pt_check_roi_boxes_shape(boxes)
    rois = boxes
    output_size = torch.nn.modules.utils._pair(output_size)
    if not isinstance(rois, torch.Tensor):
        # The original called the undefined name pt_convert_boxes_to_roi_format;
        # the helper defined in this notebook is pt_convert_to_roi_format.
        rois = pt_convert_to_roi_format(rois)

    ## This gets exact matching results.
    if UseOpsVer:
        return torch.ops.torchvision.roi_align(
            input, rois, spatial_scale, output_size[0], output_size[1], sampling_ratio, aligned
        )

    ## New python / torch only version, 98% matching
    # https://github.com/pytorch/vision/blob/main/torchvision/ops/roi_align.py
    return ppt_roi_align_v1(
        input, rois, spatial_scale, output_size[0], output_size[1], sampling_ratio, aligned
    )

In [None]:
def pt_multiscale_roi_align(
    x_filtered,
    boxes,
    output_size,
    sampling_ratio,
    scales,
    mapper):
    """Pool every box from the feature level chosen for it by `mapper`.

    Args:
        x_filtered: list of feature maps, one per level, each [N, C, H, W].
        boxes: list of [L, 4] box tensors, one per image.
        output_size: (height, width) tuple of the pooled output.
        sampling_ratio: samples per output cell, forwarded to pt_roi_align.
        scales: spatial scale of each level, same order as `x_filtered`.
        mapper: callable returning the zero-based level index of every box.

    Returns:
        Tensor of shape [K, C, output_size[0], output_size[1]], with rows in
        the order of the concatenated boxes.

    Raises:
        ValueError: if `scales` or `mapper` is None.
    """
    if scales is None or mapper is None:
        raise ValueError("scales and mapper should not be None")

    num_levels = len(x_filtered)
    rois = pt_convert_to_roi_format(boxes)

    # Single level: no level assignment needed.
    if num_levels == 1:
        return pt_roi_align(
            x_filtered[0],
            rois,
            output_size=output_size,
            spatial_scale=scales[0],
            sampling_ratio=sampling_ratio,
        )

    levels = mapper(boxes)

    num_rois = len(rois)
    num_channels = x_filtered[0].shape[1]

    dtype, device = x_filtered[0].dtype, x_filtered[0].device
    result = torch.zeros(
        (
            num_rois,
            num_channels,
        )
        + output_size,
        dtype=dtype,
        device=device,
    )

    for level, (per_level_feature, scale) in enumerate(zip(x_filtered, scales)):
        idx_in_level = torch.where(levels == level)[0]
        rois_per_level = rois[idx_in_level]

        # Was `roi_align`, a name that does not exist in this notebook.
        result_idx_in_level = pt_roi_align(
            per_level_feature,
            rois_per_level,
            output_size=output_size,
            spatial_scale=scale,
            sampling_ratio=sampling_ratio,
        )

        # Scatter this level's results back to the original roi order.
        result[idx_in_level] = result_idx_in_level.to(result.dtype)

    return result

In [None]:
## From https://pytorch.org/vision/0.12/_modules/torchvision/ops/poolers.html
## in_tools/models/vision/torchvision/ops/poolers.py 

class pt_MultiScaleRoIAlign(torch.nn.Module):
    """Local re-implementation of torchvision's MultiScaleRoIAlign module.

    Scales and the level mapper are computed lazily on the first forward call
    and cached on the instance afterwards.
    """

    def __init__(
        self,
        featmap_names,
        output_size,
        sampling_ratio,
        canonical_scale=224,
        canonical_level=4,
    ):
        super().__init__()
        # A bare int means a square output.
        size_pair = (output_size, output_size) if isinstance(output_size, int) else output_size
        self.output_size = tuple(size_pair)
        self.featmap_names = featmap_names
        self.sampling_ratio = sampling_ratio
        self.canonical_scale = canonical_scale
        self.canonical_level = canonical_level
        # Filled in by the first call to forward().
        self.scales = None
        self.map_levels = None

    def forward(
        self,
        x,
        boxes,
        image_shapes,
    ):
        """Pool `boxes` from the feature maps in `x` selected by featmap_names."""
        x_filtered = pt_filter_input(x, self.featmap_names)

        needs_setup = self.scales is None or self.map_levels is None
        if needs_setup:
            self.scales, self.map_levels = pt_setup_scales(
                x_filtered, image_shapes, self.canonical_scale, self.canonical_level
            )

        return pt_multiscale_roi_align(
            x_filtered,
            boxes,
            self.output_size,
            self.sampling_ratio,
            self.scales,
            self.map_levels,
        )

In [None]:
## Same configuration as the torchvision pooler, but using the local python port.
pt_roi_pooler = pt_MultiScaleRoIAlign(featmap_names=['0'], output_size=7, sampling_ratio=2)

pt_roi_output = pt_roi_pooler(features, boxes, image_shapes)
print("pt_roi_output:", type(pt_roi_output), pt_roi_output.shape)

## Bit-for-bit comparison against the output recorded from the live model.
if pt_roi_output.eq(roi_out).all():
    print("Output matches exactly!")
else:
    print("Output is NOT the same!")

In [None]:
## RMS magnitude of the torchvision output and of the local port's output,
## followed by the share of elements whose difference is exactly zero.
print(
    np.sqrt(np.mean(np.square(roi_output.numpy()))),
    np.sqrt(np.mean(np.square(pt_roi_output.numpy()))),
)
A = torch.count_nonzero(pt_roi_output - roi_out).numpy()  # elements that differ
B = roi_out.numel()                                       # total elements
print(A, "/", B, "=", str(round((100 * (B - A) / B), 3)) + "% Matching!")