In [1]:
# Setup: torch plus the compiled PSRoIAlign CUDA extension and its Python wrappers.
import torch
import psroi_align_cuda

In [2]:
from psroi_align import PSRoIAlignFunction
from torch.autograd import gradcheck

In [3]:
from psroi_align import PSRoIAlign

In [4]:
# Test configuration shared by the cells below.
# roi_size: side of the pooled output grid per RoI (top_data comes out as
#   (n_roi, pooled_dim, roi_size, roi_size), see In [8]).
# pooled_dim: number of output channels per RoI.
# sampling_ratio: forwarded to the kernel; in the CPU reference `forward_cpu`
#   it is the number of bilinear samples per bin along each axis.
roi_size = 3
pooled_dim = 2
sampling_ratio = 2

In [5]:
# Module under test. 1/16. appears to be the spatial scale (the raw-extension and
# CPU-reference cells pass the same value as 0.0625) — TODO confirm argument order.
crop_layer = PSRoIAlign(1/16., roi_size, sampling_ratio, pooled_dim)

In [6]:
# Inputs: one 20x20 feature map with pooled_dim * roi_size**2 = 18 channels
# (position-sensitive: one channel per (output channel, bin) pair), and two RoIs
# given as 4 coordinates each (no batch-index column).
bottom_data = torch.rand([1,pooled_dim * roi_size * roi_size, 20, 20]).cuda()
# Set after `.cuda()` so the CUDA tensor itself is the leaf that receives `.grad`.
bottom_data.requires_grad = True
bottom_rois = torch.tensor([[20., 20., 180., 180.], [30., 30., 250., 210.]]).cuda()
# Preallocated output buffers for the raw `psroi_align_cuda` calls further down.
# NOTE(review): `top_data` is rebound by In [7] and In [9], so the later raw
# extension cells no longer see this zeros buffer — TODO confirm intent.
top_data = torch.zeros(2,pooled_dim,roi_size,roi_size).cuda()
argmax_data = torch.zeros([2,pooled_dim,roi_size,roi_size], dtype=torch.int32).cuda()

In [7]:
top_data = crop_layer(bottom_data, bottom_rois)

In [8]:
top_data.shape

torch.Size([2, 2, 3, 3])

In [9]:
# Flatten each (roi, channel) 3x3 map into a 9-vector: (2, 2, 3, 3) -> (4, 9).
top_data = top_data.view(4,9)

In [10]:
top_data.shape

torch.Size([4, 9])

In [11]:
# Tiny bias-free linear head (9 -> 2) on the pooled features, so there is a
# scalar loss to backpropagate through the PSRoIAlign layer.
l = torch.nn.Linear(9,2,False).cuda()

In [12]:
l1 = l(top_data)

In [13]:
# NOTE(review): random targets with no fixed seed — the loss and gradients
# change on every run.
loss = torch.nn.functional.cross_entropy(l1, torch.randint(2, (4,), dtype=torch.int64).cuda())

In [14]:
# Scale the loss (and therefore every gradient) by 1000, presumably so the
# gradients on bottom_data are easier to inspect.
loss *= 1000

In [15]:
# This cell ran before `loss.backward()` (In [16]), so `bottom_data.grad` was
# still None — hence the AttributeError below. Run it after In [16] or delete it.
bottom_data.grad.sum()

AttributeError: 'NoneType' object has no attribute 'sum'

In [16]:
loss.backward()

In [17]:
bottom_data.grad.max()

tensor(39.3232, device='cuda:0')

In [None]:
bottom_data.grad.sum()

In [None]:
# NOTE(review): the gradcheck cells call PSRoIAlignFunction through `.apply`; if it
# is a new-style autograd Function, constructing an instance with arguments is the
# legacy API — TODO confirm. This instance is used by the `psroialign(...)` cell below.
psroialign = PSRoIAlignFunction(0.0625, roi_size, sampling_ratio, pooled_dim)

In [None]:
# Finite-difference step scaled to the input magnitude (max |x| / 500).
torch.abs(bottom_data).max().item()/500

In [None]:
gradcheck(PSRoIAlignFunction.apply, (bottom_data, bottom_rois, 1/16., roi_size, sampling_ratio, pooled_dim), eps=torch.abs(bottom_data).max().item()/500)

In [None]:
gradcheck(PSRoIAlignFunction.apply, (bottom_data, bottom_rois, 1/16., roi_size, sampling_ratio, pooled_dim), atol=1e-3, eps=1e-3)

In [18]:
# Gradient check of the full module. The inputs are float32 (torch.rand default),
# so the tolerances are loosened (atol=1e-3, eps=1e-3); the truncated warning below
# is presumably torch's "not of double precision" gradcheck warning.
gradcheck(crop_layer.forward, (bottom_data, bottom_rois), atol=1e-3, eps=1e-3)

  'At least one of the inputs that requires gradient '


True

In [None]:
# Forward through the `psroialign` instance created in an earlier cell.
top_data_ = psroialign(bottom_data, bottom_rois)

In [None]:
torch.sum(torch.abs(top_data_))

In [None]:
out = torch.sum(top_data_)

In [None]:
# NOTE(review): `top_diff` is only defined in a later cell, so this fails on a
# fresh kernel. Also `out` is a scalar (torch.sum), so a gradient shaped like
# top_data does not match it — presumably `top_data_.backward(top_diff)` or plain
# `out.backward()` was intended. TODO fix.
out.backward(top_diff)

In [None]:
bottom_data.grad.sum()

In [None]:
# NOTE: `bottom_data.grad` accumulates across backward calls; it still includes
# the gradient from In [16] unless it is zeroed first.
torch.sum(torch.abs(bottom_data.grad))

In [None]:
argmax_data.device

In [None]:
# Raw extension forward. Judging by the clone-then-inspect cells below, it
# presumably writes its results into the `top_data` / `argmax_data` arguments.
# 0.0625, roi_size and sampling_ratio mirror 1/16. and the config cell.
# NOTE(review): by now `top_data` is the (4, 9) autograd view from In [9], not the
# zeros buffer from In [6] — TODO confirm this is intended.
psroi_align_cuda.forward(bottom_data, bottom_rois, top_data, argmax_data, 0.0625, roi_size, sampling_ratio)

In [None]:
top_clone = top_data.clone()

In [None]:
# Same call on cloned inputs, keeping the return value `r` for inspection.
r = psroi_align_cuda.forward(bottom_data.clone(), bottom_rois.clone(), top_clone, argmax_data.clone(), 0.0625, roi_size, sampling_ratio)

In [None]:
torch.sum(torch.abs(top_clone))

In [None]:
r[1].shape

In [None]:
top_data.shape

In [None]:
argmax_data

In [None]:
top_data[0]

In [None]:
# Upstream gradient and an output buffer for the raw backward call. rand_like /
# zeros_like already allocate on the source tensor's device, so `.cuda()` is a no-op.
top_diff = torch.rand_like(top_data).cuda()
bottom_diff = torch.zeros_like(bottom_data).cuda()

In [None]:
top_diff.shape

In [None]:
bottom_diff.shape

In [None]:
argmax_data.shape

In [None]:
top_diff.shape

In [None]:
# Raw extension backward; presumably scatters top_diff into bottom_diff at the
# argmax samples recorded by the forward. 0.0625, 3, 2 appear to be hard-coded
# copies of the spatial scale, roi_size and sampling_ratio.
psroi_align_cuda.backward(top_diff, argmax_data, bottom_rois, bottom_diff, 0.0625, 3, 2)

In [None]:
bottom_diff[0][2]

In [None]:
bottom_diff.sum()

In [None]:
# NOTE(review): `grad` (here and in the cells below), `a` and `b` are not defined
# in any visible cell of this notebook — leftovers from deleted cells; these fail
# on a fresh kernel.
torch.mean(torch.abs(grad[0][0,0])) 

In [None]:
torch.sum(torch.abs(bottom_diff[0,0])) 

In [None]:
grad[0].shape

In [None]:
bottom_diff.shape

In [None]:
torch.sum(torch.abs(bottom_diff))

In [None]:
torch.sum(torch.abs(grad[0]))

In [None]:
a.shape

In [None]:
b.shape

In [None]:
argmax_data.shape

In [None]:
argmax_data[0]

In [None]:
argmax_data[1]

In [None]:
top_data_cpu, argmax_data_cpu = forward_cpu(bottom_data.cpu().numpy(), bottom_rois.cpu().numpy(), 10, 3, 3, 3, 0.0625, 2) 

In [None]:
top_data_cpu

In [None]:
np.sum(np.abs(top_data.cpu().numpy() - top_data_cpu))

In [None]:
np.sum(np.abs(argmax_data.cpu().numpy() - argmax_data_cpu))

In [None]:
import numpy as np
import six
def forward_cpu(bottom_data, bottom_rois, pooled_dim, pooled_width, pooled_height, group_size, spatial_scale, sampling_ratio):
    """NumPy reference for the position-sensitive RoI *max* align forward pass.

    Every output cell ``(n, ctop, ph, pw)`` covers one bin of RoI ``n``. The bin
    is sampled on a regular grid with bilinear interpolation. Only the
    position-sensitive input channel
    ``(ctop * group_size + gh) * group_size + gw`` is read. The largest sample
    becomes the output value, and its flat grid index
    ``iy * roi_bin_grid_w + ix`` is recorded in ``argmax_data``.

    Args:
        bottom_data: 4-D array read as ``bottom_data[0, c, y, x]``. Only batch
            element 0 is used, because the RoIs carry no batch index. It needs
            at least ``pooled_dim * group_size ** 2`` channels.
        bottom_rois: ``(n_roi, 4)`` array of ``(y_min, x_min, y_max, x_max)``.
            The coordinates are multiplied by ``spatial_scale`` to get
            feature-map coordinates.
        pooled_dim: number of output channels.
        pooled_width: number of output bins along x.
        pooled_height: number of output bins along y.
        group_size: side of the position-sensitive bin grid used to pick
            channels.
        spatial_scale: factor mapping RoI coordinates onto the feature map.
        sampling_ratio: integer number of samples per bin along each axis. If
            ``<= 0``, it is chosen adaptively as
            ``ceil(roi_extent / pooled_extent)``.

    Returns:
        ``(top_data, argmax_data)``: float32 and int32 arrays of shape
        ``(n_roi, pooled_dim, pooled_height, pooled_width)``. A bin whose
        samples all fall outside the feature map keeps the value ``-1e20`` and
        argmax ``-1``.
    """
    channels, height, width = bottom_data.shape[1:]
    n_roi = bottom_rois.shape[0]
    print('bottom_data.shape : {}'.format(bottom_data.shape))
    print('bottom_rois.shape : {}'.format(bottom_rois.shape))
    print('channels : {}, height : {}, width : {}'.format(channels, height, width))
    # This used to print `sampling_ratio` under the "n_roi" label.
    print('n_roi : {}'.format(n_roi))

    top_data = np.empty((n_roi, pooled_dim, pooled_height, pooled_width), dtype=np.float32)
    argmax_data = np.empty(top_data.shape, dtype=np.int32)

    for i in range(top_data.size):
        # Unravel the flat C-order index into (n, ctop, ph, pw).
        pw = i % pooled_width
        ph = (i // pooled_width) % pooled_height
        ctop = (i // (pooled_width * pooled_height)) % pooled_dim
        n = i // (pooled_width * pooled_height * pooled_dim)

        roi_start_h = bottom_rois[n, 0] * spatial_scale
        roi_start_w = bottom_rois[n, 1] * spatial_scale
        roi_end_h = bottom_rois[n, 2] * spatial_scale
        roi_end_w = bottom_rois[n, 3] * spatial_scale

        # Degenerate or inverted RoIs are treated as at least 1x1.
        roi_height = max(roi_end_h - roi_start_h, 1.)
        roi_width = max(roi_end_w - roi_start_w, 1.)
        bin_size_h = 1. * roi_height / pooled_height
        bin_size_w = 1. * roi_width / pooled_width

        # Map the bin (ph, pw) onto the group_size x group_size grid to pick the
        # position-sensitive input channel for this output channel.
        gh = int(min(max(np.floor(float(ph) * group_size / pooled_height), 0), group_size - 1))
        gw = int(min(max(np.floor(float(pw) * group_size / pooled_width), 0), group_size - 1))
        c = (ctop * group_size + gh) * group_size + gw

        if sampling_ratio > 0:
            roi_bin_grid_h = roi_bin_grid_w = int(sampling_ratio)
        else:
            # np.ceil of a positive value is integral, so int() is exact here.
            roi_bin_grid_h = int(np.ceil(roi_height / pooled_height))
            roi_bin_grid_w = int(np.ceil(roi_width / pooled_width))

        maxval = -1e20
        maxidx = -1
        for iy in range(roi_bin_grid_h):
            y = roi_start_h + ph * bin_size_h + (iy + .5) * bin_size_h / roi_bin_grid_h
            for ix in range(roi_bin_grid_w):
                x = roi_start_w + pw * bin_size_w + (ix + .5) * bin_size_w / roi_bin_grid_w
                # Samples outside the map contribute nothing. The previous
                # `while` loop hit `continue` here without advancing `ix` and
                # never terminated.
                if y < -1 or y > height or x < -1 or x > width:
                    continue

                # Clamp into the map, then locate the four neighbouring pixels.
                sy = 0. if y <= 0 else y
                sx = 0. if x <= 0 else x
                y_low = int(sy)
                x_low = int(sx)
                if y_low >= height - 1:
                    y_high = y_low = height - 1
                    sy = float(y_low)
                else:
                    y_high = y_low + 1
                if x_low >= width - 1:
                    x_high = x_low = width - 1
                    sx = float(x_low)
                else:
                    x_high = x_low + 1

                ly = sy - y_low
                lx = sx - x_low
                hy = 1. - ly
                hx = 1. - lx

                value = ((hy * hx) * bottom_data[0, c, y_low, x_low]
                         + (hy * lx) * bottom_data[0, c, y_low, x_high]
                         + (ly * hx) * bottom_data[0, c, y_high, x_low]
                         + (ly * lx) * bottom_data[0, c, y_high, x_high])
                if value > maxval:
                    maxval = value
                    maxidx = iy * roi_bin_grid_w + ix

        top_data[n, ctop, ph, pw] = maxval
        argmax_data[n, ctop, ph, pw] = maxidx

    return top_data, argmax_data

In [None]:
_bottom_data_shape = bottom_data.cpu().numpy().shape

In [None]:
top_diff.shape

In [None]:
bottom_diff_cpu = backward_cpu(bottom_rois.cpu().numpy(), argmax_data.cpu().numpy(), _bottom_data_shape, top_diff.cpu().numpy(), 0.0625, sampling_ratio, pooled_dim, roi_size, roi_size, roi_size)

In [None]:
bottom_diff_cpu[0].shape

In [None]:
np.sum(np.abs(bottom_diff_cpu[0]))

In [None]:
def backward_cpu(bottom_rois, argmax_data, _bottom_data_shape, gy, spatial_scale, sampling_ratio, out_c, out_h, out_w, group_size):
    """NumPy reference for the position-sensitive RoI *max* align backward pass.

    Each upstream gradient ``gy[n, ctop, ph, pw]`` is routed back to the single
    bilinear sample that won the max in the forward pass (``argmax_data``). It
    is then split over that sample's four neighbouring pixels according to their
    interpolation weights.

    Args:
        bottom_rois: ``(n_roi, 4)`` array of ``(y_min, x_min, y_max, x_max)``,
            the same layout ``forward_cpu`` reads.
        argmax_data: int array ``(n_roi, out_c, out_h, out_w)`` of flat sample
            indices ``iy * roi_bin_grid_w + ix``. ``-1`` means the bin had no
            in-bounds sample.
        _bottom_data_shape: 4-D shape of the forward input. The gradient has
            this shape, and only batch element 0 is written.
        gy: upstream gradient, indexed like ``argmax_data``.
        spatial_scale: factor mapping RoI coordinates onto the feature map.
        sampling_ratio: samples per bin along each axis. ``<= 0`` selects the
            adaptive grid, as in the forward pass.
        out_c: pooled output channels.
        out_h: pooled output height.
        out_w: pooled output width.
        group_size: side of the position-sensitive bin grid.

    Returns:
        ``(bottom_diff, None, None)``: the float32 input gradient followed by
        two ``None`` placeholders (callers index ``[0]``).
    """
    _, height, width = _bottom_data_shape[1:]
    bottom_diff = np.zeros(_bottom_data_shape, np.float32)

    pooled_dim, pooled_height, pooled_width = out_c, out_h, out_w
    top_diff = gy

    for i in range(top_diff.size):
        # Unravel the flat C-order index into (n, ctop, ph, pw).
        pw = i % pooled_width
        ph = (i // pooled_width) % pooled_height
        ctop = (i // (pooled_width * pooled_height)) % pooled_dim
        n = i // (pooled_width * pooled_height * pooled_dim)

        maxidx = int(argmax_data[n, ctop, ph, pw])
        if maxidx < 0:
            # The forward pass found no in-bounds sample for this bin, so no
            # input contributed to it. Decoding -1 would point at a real sample
            # (iy=0, ix=1) and leak gradient into the input.
            continue

        roi_start_h = bottom_rois[n, 0] * spatial_scale
        roi_start_w = bottom_rois[n, 1] * spatial_scale
        roi_end_h = bottom_rois[n, 2] * spatial_scale
        roi_end_w = bottom_rois[n, 3] * spatial_scale

        # Degenerate or inverted RoIs are treated as at least 1x1.
        roi_width = max(roi_end_w - roi_start_w, 1.)
        roi_height = max(roi_end_h - roi_start_h, 1.)
        bin_size_h = 1. * roi_height / pooled_height
        bin_size_w = 1. * roi_width / pooled_width

        # Same position-sensitive channel selection as the forward pass.
        gh = int(min(max(np.floor(float(ph) * group_size / pooled_height), 0), group_size - 1))
        gw = int(min(max(np.floor(float(pw) * group_size / pooled_width), 0), group_size - 1))
        c = (ctop * group_size + gh) * group_size + gw

        if sampling_ratio > 0:
            roi_bin_grid_h = roi_bin_grid_w = int(sampling_ratio)
        else:
            # np.ceil of a positive value is integral, so int() is exact here.
            roi_bin_grid_h = int(np.ceil(roi_height / pooled_height))
            roi_bin_grid_w = int(np.ceil(roi_width / pooled_width))

        # Recover the winning sample's grid position and recompute its location.
        iy, ix = divmod(maxidx, roi_bin_grid_w)
        y = roi_start_h + ph * bin_size_h + (iy + .5) * bin_size_h / roi_bin_grid_h
        x = roi_start_w + pw * bin_size_w + (ix + .5) * bin_size_w / roi_bin_grid_w

        if y < -1 or y > height or x < -1 or x > width:
            continue

        # Clamp into the map, then locate the four neighbouring pixels.
        y = 0. if y <= 0 else y
        x = 0. if x <= 0 else x
        y_low = int(y)
        x_low = int(x)
        if y_low >= height - 1:
            y_high = y_low = height - 1
            y = float(y_low)
        else:
            y_high = y_low + 1
        if x_low >= width - 1:
            x_high = x_low = width - 1
            x = float(x_low)
        else:
            x_high = x_low + 1

        ly = y - y_low
        lx = x - x_low
        hy = 1. - ly
        hx = 1. - lx

        # After clamping, every index lies in [0, size - 1], so the old
        # non-negativity guard was always true and is not needed.
        top_diff_this_bin = top_diff[n, ctop, ph, pw]
        bottom_diff[0, c, y_low, x_low] += top_diff_this_bin * (hy * hx)
        bottom_diff[0, c, y_low, x_high] += top_diff_this_bin * (hy * lx)
        bottom_diff[0, c, y_high, x_low] += top_diff_this_bin * (ly * hx)
        bottom_diff[0, c, y_high, x_high] += top_diff_this_bin * (ly * lx)

    return bottom_diff, None, None