In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import Subset

In [2]:
class stochasticPool(nn.Module):
    """Stochastic 2-D pooling layer.

    The input is zero-padded and passed through ReLU, then split into
    ``kernel_size x kernel_size`` windows taken every ``stride`` pixels.
    For every window (per image, per channel) one activation is sampled
    with probability proportional to its value, and that activation is the
    pooled output. Windows whose activations sum to zero (or to a
    non-finite value) produce 0.

    Sampling is performed on every call, independent of ``self.training``.

    Args:
        kernel_size (int): side length of the square pooling window.
        stride (int): step between neighbouring windows.
        padding (int): number of zeros added on each of the four sides
            before pooling. Defaults to 0.
    """

    def __init__(self, kernel_size, stride, padding = 0):
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding

    def forward(self, x):
        """Pool a batch of feature maps.

        Args:
            x (torch.Tensor): 4-D tensor indexed as (image, channel, row, column).

        Returns:
            torch.Tensor: tensor of shape ``(N, C, H_out, W_out)`` with the
            same dtype and device as ``x``, where
            ``H_out = (H + 2 * padding - kernel_size) // stride + 1`` and
            likewise for ``W_out``.
        """
        x = F.pad(x, (self.padding,) * 4)
        n_imgs, n_channels, n_height, n_width = x.shape

        # Number of window positions along each axis. Height and width are
        # tracked separately so rectangular inputs are supported.
        out_height = (n_height - self.kernel_size) // self.stride + 1
        out_width = (n_width - self.kernel_size) // self.stride + 1

        if n_imgs == 0 or n_channels == 0:
            # Nothing to sample from; torch.multinomial rejects empty input.
            return x.new_zeros((n_imgs, n_channels, out_height, out_width))

        x = F.relu(x)
        windows = F.unfold(x, kernel_size=self.kernel_size, stride=self.stride)

        # unfold yields (N, C * k * k, L) with the blocks L in row-major order.
        # Rearrange to one row per (image, channel, window): (N * C * L, k * k).
        window_size = self.kernel_size ** 2
        n_regions = windows.shape[-1]
        windows = (
            windows.view(n_imgs, n_channels, window_size, n_regions)
            .permute(0, 1, 3, 2)
            .reshape(-1, window_size)
        )

        # The sampled index is a discrete choice, so it is computed outside
        # autograd; gradients flow through the gather below.
        with torch.no_grad():
            totals = windows.sum(dim=1, keepdim=True)
            # Rows that cannot define a distribution (all zeros, NaN, inf).
            dead = ~(torch.isfinite(totals) & (totals > 0))
            # multinomial needs a positive finite row sum, so dead rows get
            # uniform placeholder weights; their result is zeroed afterwards.
            weights = torch.where(dead, torch.ones_like(windows), windows)
            choice = torch.multinomial(weights, 1)

        pooled = windows.gather(1, choice).masked_fill(dead, 0)

        return pooled.view(n_imgs, n_channels, out_height, out_width)
    