In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import Subset

In [2]:
class rankAvgPool(nn.Module):
    """Rank-based average pooling for 4-D ``(N, C, H, W)`` tensors.

    Each ``kernel_size x kernel_size`` window is sorted in descending order
    (independently per channel) and the output is the mean of its first
    ``t + 1`` values.  ``t = 0`` therefore reduces to max pooling, and
    ``t = kernel_size ** 2 - 1`` to plain average pooling.

    Args:
        kernel_size (int): Side length of the square pooling window.
        stride (int or pair of ints): Step between windows; forwarded to
            ``F.unfold`` unchanged.
        t (int): Rank cut-off; the ``t + 1`` largest values of each window are
            averaged.  The sentinel ``-1`` selects ``kernel_size // 2``.
        padding (int): Number of zeros added on every side of the two
            spatial dimensions before pooling.
    """

    def __init__(self, kernel_size, stride, t=-1, padding = 0):
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        # -1 means "not specified": fall back to half the kernel side length.
        self.t = kernel_size // 2 if t == -1 else t

    def forward(self, x):
        """Pool ``x`` and return a tensor of shape ``(N, C, out_h, out_w)``.

        ``out_h`` / ``out_w`` follow the usual sliding-window formula
        ``(padded_size - kernel_size) // stride + 1`` and are computed
        separately per axis, so non-square inputs are supported.
        """
        # Zero-pad the last two (spatial) dimensions symmetrically.
        x = F.pad(x, (self.padding,) * 4)
        n_imgs, n_channels, n_height, n_width = x.shape
        k = self.kernel_size

        # F.unfold accepts either an int or a (height, width) pair for stride.
        if isinstance(self.stride, int):
            stride_h = stride_w = self.stride
        else:
            stride_h, stride_w = self.stride

        # Size of the output grid along each axis.  Deriving it from the input
        # shape (instead of sqrt(number of windows)) keeps the row-major window
        # layout intact when height != width.
        out_h = (n_height - k) // stride_h + 1
        out_w = (n_width - k) // stride_w + 1

        # (N, C*k*k, L) -> (N, C, k*k, L): one column of k*k values per window.
        windows = F.unfold(x, kernel_size=k, stride=self.stride)
        windows = windows.reshape(n_imgs, n_channels, k * k, -1)

        # Keep the t+1 largest entries of every window and average them.
        top_ranked = torch.sort(windows, descending=True, dim=2).values[:, :, : (self.t + 1)]
        pooled = top_ranked.mean(dim=2)

        # Windows are enumerated row-major, so L unfolds to (out_h, out_w).
        return pooled.reshape(n_imgs, n_channels, out_h, out_w)