In [2]:
import torch
from torch import nn

In [7]:
class Involution(nn.Module):
    """Involution layer (Li et al., CVPR 2021, "Involution: Inverting the
    Inherence of Convolution for Visual Recognition").

    For every output position a K x K kernel is generated from the input
    itself (Eqn. 6) and multiplied with the K x K neighbourhood of that
    position (Eqn. 4). The C input channels are split into ``group`` groups;
    all channels within a group share the same generated kernel.

    Args:
        in_channels: number of input channels C; must be divisible by ``group``.
        out_channels: accepted for signature compatibility but NOT used -- the
            output always has ``in_channels`` channels.
        kernel: spatial kernel size K.
        ratio: channel reduction ratio of the kernel-generation bottleneck
            (hidden width is ``max(1, in_channels // ratio)``).
        stride: spatial stride. When > 1 the kernel is generated from an
            average-pooled copy of the input and the output is downsampled.
        group: number of channel groups G.
        dilation: dilation of the K x K neighbourhood.
        padding: zero padding used when extracting neighbourhoods. ``None``
            (default) picks ``dilation * (kernel - 1) // 2``, which preserves
            the spatial size for odd ``kernel`` (this equals 1 for kernel=3,
            dilation=1).

    Raises:
        ValueError: if ``in_channels`` is not divisible by ``group``.

    NOTE(review): as far as I know the paper's Eqn. (6) inserts BN + ReLU
    between the reduce and span 1x1 convs; here they are applied back to back
    with no nonlinearity, so together they form a single affine map. Left
    unchanged to keep parameters/outputs compatible -- confirm intent.
    """

    def __init__(self, in_channels,
                 out_channels,
                 kernel,
                 ratio=16,
                 stride=1, group=1, dilation=1, padding=None):
        super().__init__()
        if in_channels % group != 0:
            raise ValueError(
                f"in_channels ({in_channels}) must be divisible by group ({group})")
        if padding is None:
            # "Same" padding for odd kernels; note the product is taken before
            # the floor division so e.g. kernel=2, dilation=2 still gives 1.
            padding = dilation * (kernel - 1) // 2
        self.G = group
        self.K = kernel

        # ceil_mode=True makes the pooled size ceil(H / stride), which matches
        # the number of positions Unfold yields with "same" padding even when
        # H or W is not a multiple of stride.
        self.o = nn.AvgPool2d(stride, stride, ceil_mode=True) if stride > 1 else nn.Identity()
        # Guard against in_channels < ratio, which would otherwise produce a
        # 0-channel bottleneck and constant (bias-only) kernels.
        hidden = max(1, in_channels // ratio)
        self.reduce = nn.Conv2d(in_channels, hidden, 1)
        self.span = nn.Conv2d(hidden, kernel * kernel * group, 1)
        self.unfold = nn.Unfold(kernel, dilation, padding, stride)

    def forward(self, x):
        """Apply involution to ``x`` of shape (B, C, H, W).

        Returns a tensor of shape (B, C, H_out, W_out), where (H_out, W_out)
        is the spatial size of the generated kernel map ((H, W) when stride=1).
        """
        B, C = x.shape[:2]
        # kernel generation, Eqn.(6): one K*K kernel per group and position
        kernel = self.span(self.reduce(self.o(x)))  # B, K*K*G, H_out, W_out
        # Take the output size from the generated kernel rather than the input,
        # so strided configurations reshape correctly.
        H_out, W_out = kernel.shape[-2:]
        kernel = kernel.view(B, self.G, self.K * self.K, H_out, W_out).unsqueeze(2)  # B,G,1,K*K,H_out,W_out
        # K x K neighbourhood of every output position
        x_unfolded = self.unfold(x)  # B, C*K*K, H_out*W_out
        x_unfolded = x_unfolded.view(B, self.G, C // self.G, self.K * self.K, H_out, W_out)
        # Multiply-Add operation, Eqn.(4): weight each neighbour, sum over K*K
        out = (kernel * x_unfolded).sum(dim=3)  # B, G, C/G, H_out, W_out
        return out.view(B, C, H_out, W_out)

In [10]:
# Involution preserves the channel count: the output has in_channels (256)
# channels, and out_channels=512 has no effect on the result.
test = Involution(in_channels=256, out_channels=512, kernel=3)

In [11]:
torch.manual_seed(0)  # reproducible random input
test.eval()  # switch modes *before* the forward pass; calling it afterwards does not affect `out`
x = torch.randn(1, 256, 10, 10)
with torch.no_grad():  # inference only; no autograd graph needed
    out = test(x)


In [12]:
# Sanity check: with stride=1 the layer preserves the (B, C, H, W) shape.
assert out.shape == x.shape, f"unexpected output shape {tuple(out.shape)}"
out.shape

torch.Size([1, 256, 10, 10])