In [1]:
import torch
from torch import nn
import torch.nn.functional as F
from torch import autograd
import copy

In [2]:
def im2col(image, ksize, stride):
    bs, channel, width, height = image.shape
    res = image.unfold(2,ksize,stride).unfold(3,ksize,stride).swapaxes(1,5).swapaxes(1,3)
    size = res.shape
    return res.reshape(size[0], size[1] * size[2], size[3] * size[4] * size[5])

In [3]:
class SConv2d(autograd.Function):
    @staticmethod
    def forward(ctx, input, inputS, weight, weightS, bias=None, stride=1, padding=0, dilation=1, groups=1):
        # col_weights = weight.reshape(weight.shape[0], -1).swapaxes(0,1)
        # input = F.pad(input,tuple(4*[padding]))
        # bs, xc, xw, xh = input.shape
        # oc, _, kw, kh = weight.shape
        # ow, oh = xw - kw + 1, xh - kh + 1

        # col_image = F.unfold(input,(kw,kh)).transpose(1,2)
        # conv_out = col_image.matmul(w.view(w.size(0),-1).t()).transpose(1,2)
        # conv_out = F.fold(conv_out, (ow, oh), (1,1))
        # ctx.save_for_backward(col_image, weight, bias)
        conv_out = F.conv2d(input, weight, bias, stride, padding, dilation, groups)
        padded_input = F.pad(input,tuple(4*[padding]))
        ctx.save_for_backward(padded_input, weight, bias, torch.IntTensor([padding], device=padded_input.device))
        return conv_out, torch.ones_like(conv_out)
    
    @staticmethod
    def backward(ctx, grad_output, grad_outputS):
        input, weight, bias, padding = ctx.saved_tensors
        col_image = F.unfold(input,(3,3)).transpose(1,2)
        bs, channels, ow, oh = grad_output.shape
        oc, ic, kw, kh = weight.shape
        # col_grad_output = grad_output.view(bs, channels, -1)
        grad_w = grad_output.view(bs, channels, -1).bmm(col_image).sum(dim=0).view(weight.shape)
        grad_wS = grad_outputS.view(bs, channels, -1).bmm(col_image**2).sum(dim=0).view(weight.shape) # SSSS

        if bias is None:
            grad_b = None
        else:
            grad_b = grad_output.sum(axis=[0,2,3])

        grad_output_padded = F.pad(grad_output,tuple(4*[kw-1-padding.item()]))
        col_grad = F.unfold(grad_output_padded,(kh,kw)).transpose(1,2)
        grad_outputS_padded = F.pad(grad_outputS,tuple(4*[kw-1-padding.item()])) # SSSS
        col_gradS = F.unfold(grad_outputS_padded,(kh,kw)).transpose(1,2)
        
        flipped_w = weight.flip([2,3]).swapaxes(0,1)
        col_flip = flipped_w.reshape(flipped_w.size(0),-1)
        grad_i = col_grad.matmul(col_flip.t()).transpose(1,2)
        grad_i = F.fold(grad_i, (ow, oh), (1,1))
        grad_iS = col_gradS.matmul(col_flip.t()).transpose(1,2)
        grad_iS = F.fold(grad_iS, (ow, oh), (1,1))

        return grad_i, grad_iS, grad_w, grad_wS, grad_b, None, None, None, None

In [4]:
img = torch.randn(3,5,4,4).requires_grad_()
imgS = torch.ones(3,5,4,4).requires_grad_()
w = torch.randn(2,5,3,3).requires_grad_()
wS = torch.ones(2,5,3,3).requires_grad_()
bias = torch.zeros(2).requires_grad_()
res1,_ = SConv2d.apply(img,imgS,w,wS,bias,1,1)

In [5]:
res1.sum().backward()

In [6]:
T1 = copy.deepcopy(img.grad.data)
T2 = copy.deepcopy(w.grad.data)
T3 = copy.deepcopy(bias.grad.data)

In [7]:
res2 = F.conv2d(img,w,bias,1,1)
img.grad.zero_()
bias.grad.zero_()
w.grad.zero_()
res2.sum().backward()

In [8]:
print((img.grad - T1).abs().max())
print((w.grad - T2).abs().max())
print((bias.grad - T3).abs().max())

tensor(9.5367e-07)
tensor(1.9073e-06)
tensor(0.)


In [41]:
pool = nn.MaxPool2d(2,return_indices=True)
a = torch.randn(2,3,4,4)
r, i = pool(a)
print(a)

tensor([[[[-0.0386, -0.6824, -0.5217, -0.6673],
          [ 0.0108, -1.2389, -0.8433, -0.2332],
          [-2.1736,  0.5494,  0.8251,  1.1906],
          [ 0.2444, -1.0515,  1.6519, -1.5744]],

         [[ 0.4461,  1.6695, -1.1383,  0.2146],
          [ 1.6683, -0.1425, -0.4049,  1.1140],
          [ 0.1828,  2.0047, -0.9005, -0.7964],
          [ 0.5337,  1.0797,  0.0670,  1.3408]],

         [[ 0.5395,  0.9447, -0.7407, -1.6365],
          [ 0.9528,  0.5593, -0.6031,  0.2240],
          [-2.2354, -0.6138, -1.0768,  1.3020],
          [ 0.8636, -1.4690,  0.3154,  0.3527]]],


        [[[ 0.5729,  0.7669,  1.2477, -1.4297],
          [ 1.1332,  0.5756, -0.8515,  0.9711],
          [ 0.6159, -2.3166,  0.8089,  1.0219],
          [ 0.0121, -0.4188,  0.1152, -0.1013]],

         [[ 0.4683, -1.1388,  0.2252,  1.2950],
          [-0.0072,  0.7322, -0.5640, -0.4931],
          [ 0.9604, -0.4985,  0.1336,  0.2823],
          [-1.5844, -1.2147,  1.1050, -1.0086]],

         [[ 1.8520,  0.1835,

In [42]:
print(r)

tensor([[[[ 0.0108, -0.2332],
          [ 0.5494,  1.6519]],

         [[ 1.6695,  1.1140],
          [ 2.0047,  1.3408]],

         [[ 0.9528,  0.2240],
          [ 0.8636,  1.3020]]],


        [[[ 1.1332,  1.2477],
          [ 0.6159,  1.0219]],

         [[ 0.7322,  1.2950],
          [ 0.9604,  1.1050]],

         [[ 1.8520,  0.9110],
          [ 0.2145,  0.8290]]]])


In [43]:
print(i)

tensor([[[[ 4,  7],
          [ 9, 14]],

         [[ 1,  7],
          [ 9, 15]],

         [[ 4,  7],
          [12, 11]]],


        [[[ 4,  2],
          [ 8, 11]],

         [[ 5,  3],
          [ 8, 14]],

         [[ 0,  2],
          [ 9, 11]]]])


In [44]:
print(a.view(2,1,-1)[torch.LongTensor([0,0]),torch.LongTensor([0,0]),torch.LongTensor([5,2])])

tensor([-1.2389, -0.5217])


In [45]:
i.view(2,1,-1)

tensor([[[ 4,  7,  9, 14,  1,  7,  9, 15,  4,  7, 12, 11]],

        [[ 4,  2,  8, 11,  5,  3,  8, 14,  0,  2,  9, 11]]])

In [56]:
def parse_indice(indice):
    bs, ch, w, h = indice.shape
    length = w * h
    BD = torch.LongTensor(list(range(bs))).expand([length,ch,bs]).swapaxes(0,2).reshape(-1)
    CD = torch.LongTensor(list(range(ch))).expand([bs,length,ch]).swapaxes(1,2).reshape(-1)
    return [BD, CD, indice.view(-1)]

T = parse_indice(i)
a.view([2,3,-1])[T].view(i.shape)

tensor([[[[ 0.0108, -0.2332],
          [ 0.5494,  1.6519]],

         [[ 1.6695,  1.1140],
          [ 2.0047,  1.3408]],

         [[ 0.9528,  0.2240],
          [ 0.8636,  1.3020]]],


        [[[ 1.1332,  1.2477],
          [ 0.6159,  1.0219]],

         [[ 0.7322,  1.2950],
          [ 0.9604,  1.1050]],

         [[ 1.8520,  0.9110],
          [ 0.2145,  0.8290]]]])

In [57]:
(1,1) * 2

(1, 1, 1, 1)