# In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class AveragePoolingAsConv(nn.Module):
    """Average pooling expressed as a depthwise (grouped) 2-D convolution.

    Each input channel is convolved on its own (``groups == channels``) with
    a kernel whose entries all equal ``1 / (kernel_h * kernel_w)``, so every
    output value is the mean of one window. Padding is zero padding and the
    padded zeros count towards the mean, which corresponds to
    ``F.avg_pool2d`` with its default ``count_include_pad=True``.

    The convolution is created lazily because the channel count is only
    known once an input is seen. It is rebuilt whenever a later input has a
    different channel count, dtype or device than the cached convolution.

    Args:
        kernel_size: Pooling window size; an int or an ``(h, w)`` pair.
        stride: Window step; defaults to ``kernel_size`` when ``None``.
        padding: Zero padding added to both sides of each spatial dim.

    Note:
        The kernel is stored as the ordinary ``nn.Conv2d`` weight parameter
        of ``self.avg_pool_conv``, so it is trainable like any other conv
        weight unless the caller freezes it.
    """

    def __init__(self, kernel_size, stride=None, padding=0):
        super().__init__()
        if stride is None:
            stride = kernel_size
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        # Built on the first forward pass, once the channel count is known.
        self.avg_pool_conv = None

    def _conv_matches(self, x):
        """Return True if the cached convolution can be applied to ``x``."""
        conv = self.avg_pool_conv
        if conv is None:
            return False
        weight = conv.weight
        # Non-floating inputs can never match a conv weight; do not let them
        # trigger a pointless rebuild before the convolution itself raises.
        dtype_ok = weight.dtype == x.dtype or not x.is_floating_point()
        return (
            conv.in_channels == x.size(1)
            and weight.device == x.device
            and dtype_ok
        )

    def _build_conv(self, x):
        """Create a depthwise averaging convolution fitted to ``x``."""
        channels = x.size(1)
        conv = nn.Conv2d(
            channels,
            channels,
            self.kernel_size,
            stride=self.stride,
            padding=self.padding,
            groups=channels,
            bias=False,
        )
        if x.is_floating_point():
            conv = conv.to(device=x.device, dtype=x.dtype)
        else:
            conv = conv.to(device=x.device)
        # Conv2d normalises kernel_size to a pair, so this also covers
        # rectangular windows.
        kernel_h, kernel_w = conv.kernel_size
        with torch.no_grad():
            conv.weight.fill_(1.0).div_(kernel_h * kernel_w)
        return conv

    def forward(self, x):
        """Return the average-pooled version of ``x``.

        Args:
            x: Input tensor; dimension 1 is taken as the channel dimension.

        Returns:
            The output of the depthwise averaging convolution applied to
            ``x``.
        """
        if not self._conv_matches(x):
            self.avg_pool_conv = self._build_conv(x)
        return self.avg_pool_conv(x)


# Demo: pool a random single-image batch (3 channels, 32x32) with a 2x2
# window and stride 2, then print the resulting shape.
kernel_size, stride, padding = 2, 2, 0
input_tensor = torch.randn(1, 3, 32, 32)

avg_pool_as_conv = AveragePoolingAsConv(
    kernel_size,
    stride=stride,
    padding=padding,
)
output_tensor = avg_pool_as_conv(input_tensor)
print(output_tensor.shape)

# Output: torch.Size([1, 3, 16, 16])
