In [1]:
# https://github.com/pytorch/pytorch/issues/47990

In [56]:
import torch
import torch.nn as nn
import copy

In [57]:
# Seed the global RNG so the random input (and every layer initialised after
# it) is identical on each Restart & Run All.
torch.manual_seed(0)
# Random image-like batch: 2 samples, 3 channels, 256x256 pixels.
x = torch.randn(2, 3, 256, 256)
x.shape

torch.Size([2, 3, 256, 256])

In [58]:
# Reference result from the built-in convolution layer.
conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, padding=1)
conv_out = conv1(x)
print(conv_out.size())

torch.Size([2, 64, 256, 256])


In [96]:
class CustomConv(nn.Module):
    """Experimental 2-D convolution built from ``nn.Unfold`` and matrix products.

    The input is unfolded into patch columns. A per-position gate is computed
    by applying ``dweight`` to the sigmoid of those columns, the columns are
    scaled by that gate, and the result is multiplied by the flattened
    ``weight``. Without the gate (i.e. multiplying ``weight`` with the raw
    patch columns) the computation reduces to an ordinary convolution.

    Args:
        in_channel: Number of input channels.
        out_channel: Number of output channels.
        kernel_size: Side length of the square kernel (int).
        stride: Step of the sliding window (int).
        padding: Zero padding added to both sides of each spatial dim (int).
        bias: If True, ``self.bias`` is added to the output. The parameter is
            always created; this flag only controls whether it is used.

    All parameters (``weight``, ``dweight``, ``bias``) are initialised to ones.
    """

    def __init__(self, in_channel, out_channel, kernel_size=3, stride=1, padding=1, bias=True):
        super(CustomConv, self).__init__()
        self.kernel_size = kernel_size
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.stride = stride
        self.padding = padding
        self.weight = nn.Parameter(torch.ones(self.out_channel, self.in_channel, self.kernel_size, self.kernel_size))
        # Single-output-channel kernel used to compute the gate.
        self.dweight = nn.Parameter(torch.ones(1, self.in_channel, self.kernel_size, self.kernel_size))
        self.bias = nn.Parameter(torch.ones(self.out_channel))
        self.isbias = bias

        self.Unfold = nn.Unfold(kernel_size=(self.kernel_size, self.kernel_size),
                                stride=self.stride, padding=self.padding)

    def forward(self, x):
        """Apply the gated convolution to a 4-D input ``(batch, channels, h, w)``.

        Returns:
            Tensor of shape ``(batch, out_channel, h_out, w_out)`` where
            ``h_out = (h + 2 * padding - kernel_size) // stride + 1`` and
            ``w_out`` is computed the same way from ``w``.
        """
        batch_size, _, h, w = x.shape
        # Spatial size produced by the sliding window. It equals (h, w) only
        # for "same"-style settings such as kernel 3 / stride 1 / padding 1,
        # so it must be computed rather than copied from the input.
        h_out = (h + 2 * self.padding - self.kernel_size) // self.stride + 1
        w_out = (w + 2 * self.padding - self.kernel_size) // self.stride + 1

        # (batch, in_channel * k * k, h_out * w_out) patch columns.
        cols = self.Unfold(x)
        # One gate value per sliding-window position: (batch, 1, h_out * w_out).
        gate = self.dweight.view(1, -1) @ torch.sigmoid(cols)
        # Broadcast the gate over the patch dimension, then apply the kernels.
        out = self.weight.view(self.out_channel, -1) @ (cols * gate)
        if self.isbias:
            out = out + self.bias[None, :, None]
        return out.view(batch_size, self.out_channel, h_out, w_out)

In [97]:
# Build the custom layer and give it independent copies of conv1's kernel
# and bias so both layers start from the same values.
cus_conv = CustomConv(in_channel=3, out_channel=64)
for param_name in ("weight", "bias"):
    setattr(cus_conv, param_name, copy.deepcopy(getattr(conv1, param_name)))

In [98]:
# Count the elements where the custom layer and nn.Conv2d differ by more
# than 1e-4 in absolute value.
cus_conv_out = cus_conv(x)
(cus_conv_out - conv_out).abs().gt(0.0001).sum()

torch.Size([2, 1, 65536])


tensor(8388520)

In [9]:
# Tiny all-ones input (2 samples, 3 channels, 4x4) so results can be checked by hand.
x = torch.ones((2, 3, 4, 4))
x.shape

torch.Size([2, 3, 4, 4])

In [10]:

# Reference conv layer whose kernel and bias are replaced by all-ones
# parameters of the same shape.
conv1 = nn.Conv2d(in_channels=3, out_channels=4, kernel_size=3, padding=1)
conv1.weight = nn.Parameter(torch.ones_like(conv1.weight))
conv1.bias = nn.Parameter(torch.ones_like(conv1.bias))

In [11]:
# Forward the input through the reference layer and report the output size.
conv_out = conv1(x)
print(conv_out.size())

torch.Size([2, 4, 4, 4])


In [12]:
# Display the full output tensor of the reference conv1 layer.
conv_out

tensor([[[[13., 19., 19., 13.],
          [19., 28., 28., 19.],
          [19., 28., 28., 19.],
          [13., 19., 19., 13.]],

         [[13., 19., 19., 13.],
          [19., 28., 28., 19.],
          [19., 28., 28., 19.],
          [13., 19., 19., 13.]],

         [[13., 19., 19., 13.],
          [19., 28., 28., 19.],
          [19., 28., 28., 19.],
          [13., 19., 19., 13.]],

         [[13., 19., 19., 13.],
          [19., 28., 28., 19.],
          [19., 28., 28., 19.],
          [13., 19., 19., 13.]]],


        [[[13., 19., 19., 13.],
          [19., 28., 28., 19.],
          [19., 28., 28., 19.],
          [13., 19., 19., 13.]],

         [[13., 19., 19., 13.],
          [19., 28., 28., 19.],
          [19., 28., 28., 19.],
          [13., 19., 19., 13.]],

         [[13., 19., 19., 13.],
          [19., 28., 28., 19.],
          [19., 28., 28., 19.],
          [13., 19., 19., 13.]],

         [[13., 19., 19., 13.],
          [19., 28., 28., 19.],
          [19., 28., 28.

In [13]:
cus_conv = CustomConv(3, 4)
# Reuse conv1's Parameter objects directly (shared, not copied) so both
# layers compute with exactly the same kernel and bias.
cus_conv.weight = conv1.weight
# The attribute must be spelled `bias`: nn.Module silently accepts any new
# attribute name, so a misspelling registers an extra unused parameter and
# leaves the real bias at its constructor default.
cus_conv.bias = conv1.bias

In [14]:
# Run the custom layer on x.
cus_conv_out = cus_conv(x)
# print(cus_conv_out.shape)
# Inspect every output channel at batch index 0, spatial position (0, 0).
cus_conv_out[0,:,0,0]

tensor([13., 13., 13., 13.], grad_fn=<SelectBackward0>)

In [15]:
# Display the full custom-layer output for visual comparison with conv_out.
cus_conv_out

tensor([[[[13., 19., 19., 13.],
          [19., 28., 28., 19.],
          [19., 28., 28., 19.],
          [13., 19., 19., 13.]],

         [[13., 19., 19., 13.],
          [19., 28., 28., 19.],
          [19., 28., 28., 19.],
          [13., 19., 19., 13.]],

         [[13., 19., 19., 13.],
          [19., 28., 28., 19.],
          [19., 28., 28., 19.],
          [13., 19., 19., 13.]],

         [[13., 19., 19., 13.],
          [19., 28., 28., 19.],
          [19., 28., 28., 19.],
          [13., 19., 19., 13.]]],


        [[[13., 19., 19., 13.],
          [19., 28., 28., 19.],
          [19., 28., 28., 19.],
          [13., 19., 19., 13.]],

         [[13., 19., 19., 13.],
          [19., 28., 28., 19.],
          [19., 28., 28., 19.],
          [13., 19., 19., 13.]],

         [[13., 19., 19., 13.],
          [19., 28., 28., 19.],
          [19., 28., 28., 19.],
          [13., 19., 19., 13.]],

         [[13., 19., 19., 13.],
          [19., 28., 28., 19.],
          [19., 28., 28.

In [16]:
# Count matching elements using an absolute tolerance rather than exact float
# equality: the two layers accumulate their sums in different orders, so
# bit-identical results are not guaranteed in general. 1e-4 is the same
# threshold used for the mismatch count earlier in the notebook.
torch.isclose(conv_out, cus_conv_out, rtol=0.0, atol=1e-4).sum()

tensor(128)

In [17]:
# Reproduce conv1 by hand: unfold the input into patch columns, multiply by
# the flattened kernels, add the bias, then restore the 4-D layout.
unfold = nn.Unfold(kernel_size=(3, 3), padding=1)
x1 = conv1.weight.view(4, -1) @ unfold(x)
print(type(x1))
x1 += conv1.bias[None, :, None]
x1 = x1.view(2, 4, 4, 4)
x1

<class 'torch.Tensor'>


tensor([[[[13., 19., 19., 13.],
          [19., 28., 28., 19.],
          [19., 28., 28., 19.],
          [13., 19., 19., 13.]],

         [[13., 19., 19., 13.],
          [19., 28., 28., 19.],
          [19., 28., 28., 19.],
          [13., 19., 19., 13.]],

         [[13., 19., 19., 13.],
          [19., 28., 28., 19.],
          [19., 28., 28., 19.],
          [13., 19., 19., 13.]],

         [[13., 19., 19., 13.],
          [19., 28., 28., 19.],
          [19., 28., 28., 19.],
          [13., 19., 19., 13.]]],


        [[[13., 19., 19., 13.],
          [19., 28., 28., 19.],
          [19., 28., 28., 19.],
          [13., 19., 19., 13.]],

         [[13., 19., 19., 13.],
          [19., 28., 28., 19.],
          [19., 28., 28., 19.],
          [13., 19., 19., 13.]],

         [[13., 19., 19., 13.],
          [19., 28., 28., 19.],
          [19., 28., 28., 19.],
          [13., 19., 19., 13.]],

         [[13., 19., 19., 13.],
          [19., 28., 28., 19.],
          [19., 28., 28.

In [65]:
# Broadcasting demo: the size-1 middle dimension of `a` is expanded to match
# the size-3 middle dimension of `b` in the elementwise product.
a = torch.randn((2, 1, 10))
b = torch.randn((2, 3, 10))
torch.mul(a, b)

tensor([[[ 0.3379, -1.5741,  0.3580,  2.6894, -0.3282,  0.8057, -0.7230,
           0.4657, -2.6248,  1.7149],
         [ 0.3394, -0.0491, -1.4133, -3.0067,  0.2556,  0.4303,  0.3314,
           0.3267, -0.3379, -1.4812],
         [ 0.2844, -4.3359, -0.2769, -3.2841, -0.0730,  0.0537, -0.9952,
           0.0076,  2.1398, -4.4930]],

        [[-0.3693, -1.8441,  0.0707,  0.5148,  0.5557,  0.4578,  0.0133,
          -0.0719, -0.0821, -0.0231],
         [ 0.5681, -1.2018,  0.0698,  1.0384, -0.0440,  0.9311,  0.0981,
          -0.0240,  0.0908,  0.6570],
         [-1.3673, -2.6551, -0.0150,  0.6079, -0.2202, -0.0183, -0.0511,
          -0.0457, -0.1287,  0.8067]]])