In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import Subset

In [2]:
class gatedPool(nn.Module):
    """Gated mixture of max-pooling and average-pooling.

    Computes ``alpha * max_pool2d(x) + (1 - alpha) * avg_pool2d(x)`` where the
    gate ``alpha = sigmoid(conv2d(x, mask))`` comes from a learnable
    ``kernel_size x kernel_size`` mask. The mask is applied with the same
    stride and padding as the pooling windows, so the gate and the pooled maps
    have identical spatial size.

    Args:
        in_channel: number of input channels. Only used by ``'l/c'``, where the
            mask has shape ``(in_channel, in_channel, kernel_size, kernel_size)``
            and the input must therefore have ``in_channel`` channels.
        kernel_size: square pooling window / mask size.
        stride: pooling stride, also used by the gating convolution.
        padding: pooling padding, also used by the gating convolution.
        learn_option: ``'l/c'`` learns a mask that mixes all input channels into
            one gate per output channel; ``'l'`` learns a single
            ``(1, 1, kernel_size, kernel_size)`` mask shared by every channel,
            each channel being gated only by itself.

    Raises:
        NameError: if ``learn_option`` is neither ``'l/c'`` nor ``'l'``.
    """

    def __init__(self, in_channel, kernel_size, stride, padding=0, learn_option='l/c'):
        super().__init__()

        if learn_option == 'l/c':
            mask_shape = (in_channel, in_channel, kernel_size, kernel_size)
        elif learn_option == 'l':
            mask_shape = (1, 1, kernel_size, kernel_size)
        else:
            raise NameError(learn_option)

        # Learnable gating weights, initialised from a standard normal.
        self.mask = nn.Parameter(torch.randn(*mask_shape).float())

        self.learn_option = learn_option
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding

    def forward(self, x):
        """Apply gated pooling to a 4-D input ``(batch, channels, H, W)``."""
        if self.learn_option == 'l/c':
            return self.layer_channel(x)
        if self.learn_option == 'l':
            return self.layer(x)
        # Unknown mode (e.g. attribute changed after construction): fail loudly
        # instead of silently returning None.
        raise NameError(self.learn_option)

    def _mix(self, x, gate_logits):
        """Blend max- and avg-pooled ``x`` using ``sigmoid(gate_logits)`` as the weight."""
        alpha = torch.sigmoid(gate_logits)
        pooled_max = F.max_pool2d(x, self.kernel_size, self.stride, self.padding)
        pooled_avg = F.avg_pool2d(x, self.kernel_size, self.stride, self.padding)
        return alpha * pooled_max + (1 - alpha) * pooled_avg

    def layer(self, x):
        """Gated pooling where one shared mask gates each channel independently."""
        batch, channels, height, width = x.shape
        # Fold channels into the batch axis so a single conv applies the shared
        # (1, 1, k, k) mask to every channel at once, instead of a per-channel
        # loop with repeated (quadratic) concatenation.
        gate_logits = F.conv2d(x.reshape(batch * channels, 1, height, width),
                               self.mask,
                               stride=self.stride,
                               padding=self.padding)
        gate_logits = gate_logits.reshape(batch, channels, *gate_logits.shape[-2:])
        return self._mix(x, gate_logits)

    def layer_channel(self, x):
        """Gated pooling where a full (C -> C) mask mixes all input channels into each gate."""
        # The padding must match the pooling's, otherwise the gate and the pooled
        # maps have different spatial sizes and cannot be multiplied.
        gate_logits = F.conv2d(x, self.mask, stride=self.stride, padding=self.padding)
        # Return the result directly: re-wrapping it with the legacy
        # torch.Tensor(...) constructor is redundant and can reject
        # non-default dtypes/devices.
        return self._mix(x, gate_logits)