# Experiments comparing traditional pooling operations (max and average pooling) with entropy pooling in a shallow neural network and in LeNet. The experiments use MNIST and FashionMNIST.

In [None]:
# IPython magic: render any matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
import torch
import torchvision
import torchvision.transforms as transforms

In [None]:
# MNIST / FashionMNIST images are single-channel, so Normalize takes a
# one-element mean and std (a 3-channel tuple fails with a broadcast error in
# current torchvision). ToTensor scales pixels to [0, 1]; mean=std=0.5 then
# maps them to [-1, 1].
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5,), (0.5,))])

# To run the same experiments on FashionMNIST, swap the MNIST lines below for
# the commented-out FashionMNIST ones (both train and test).
trainset = torchvision.datasets.MNIST(root='./data', train=True,
                                      download=True, transform=transform)
# trainset = torchvision.datasets.FashionMNIST(root='./data', train=True,
#                                         download=True, transform=transform)

trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=8)

testset = torchvision.datasets.MNIST(root='./data', train=False,
                                     download=True, transform=transform)
# testset = torchvision.datasets.FashionMNIST(root='./data', train=False,
#                                        download=True, transform=transform)

testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                         shuffle=False, num_workers=8)


In [None]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.utils import _pair, _quadruple
import time
from skimage.measure import shannon_entropy
from scipy import stats

from torch.nn.modules.utils import _pair, _quadruple
import time
from skimage.measure import shannon_entropy
from scipy import stats
import numpy as np

class EntropyPool2d(nn.Module):
    """2-D pooling that keeps one element per window, chosen by value frequency.

    On every forward call the frequency of each distinct value is counted over
    the whole padded input (all batch items and channels together), and each
    element is assigned ``count(value) / numel``. Within each pooling window the
    element with the lowest frequency (``entr='high'``: the rarest value) or the
    highest frequency (``entr='low'``: the most common value) is returned. Ties
    are resolved by ``torch.min`` / ``torch.max``.

    The selection itself is computed on a detached copy, so gradients flow only
    through the selected input values, not through the frequency estimate.

    Args:
        kernel_size: window size, int or ``(kh, kw)``.
        stride: window stride, int or ``(sh, sw)``.
        padding: int or 4-tuple ``(left, right, top, bottom)``; ignored when
            ``same`` is True. Padding is applied in ``'reflect'`` mode.
        same: if True, derive the padding from the input size, kernel size and
            stride instead of using ``padding``.
        entr: ``'high'`` or ``'low'``; any other value raises ``ValueError``
            on the first forward call.
    """

    def __init__(self, kernel_size=3, stride=1, padding=0, same=False, entr='high'):
        super(EntropyPool2d, self).__init__()
        self.k = _pair(kernel_size)
        self.stride = _pair(stride)
        self.padding = _quadruple(padding)  # convert to l, r, t, b
        self.same = same
        self.entr = entr

    def _padding(self, x):
        """Return the ``(left, right, top, bottom)`` padding to apply to ``x``."""
        if self.same:
            ih, iw = x.size()[2:]
            if ih % self.stride[0] == 0:
                ph = max(self.k[0] - self.stride[0], 0)
            else:
                ph = max(self.k[0] - (ih % self.stride[0]), 0)
            if iw % self.stride[1] == 0:
                pw = max(self.k[1] - self.stride[1], 0)
            else:
                pw = max(self.k[1] - (iw % self.stride[1]), 0)
            # Split the total padding, putting any odd extra on the right/bottom.
            pl = pw // 2
            pr = pw - pl
            pt = ph // 2
            pb = ph - pt
            padding = (pl, pr, pt, pb)
        else:
            padding = self.padding
        return padding

    def _windows(self, t):
        """Split ``t`` (N, C, H, W) into flattened windows (N, C, out_h, out_w, kh*kw)."""
        t = t.unfold(2, self.k[0], self.stride[0]).unfold(3, self.k[1], self.stride[1])
        return t.contiguous().view(t.size()[:4] + (-1,))

    def forward(self, x):
        # Built from existing PyTorch tensor ops so autograd handles the backward
        # pass; a dedicated C/CUDA kernel would likely be more efficient.
        x = F.pad(x, self._padding(x), mode='reflect')

        # Frequency of each element's value across the whole padded tensor.
        # np.unique sorts on the CPU, so this costs a device round-trip per call.
        values = x.detach().cpu().numpy()
        _, inverse, counts = np.unique(values, return_inverse=True, return_counts=True)
        # NumPy 2 returns `inverse` in the input's shape, NumPy 1 flattened:
        # flatten so the per-element lookup and the normaliser are correct in both.
        inverse = inverse.reshape(-1)
        freq = counts[inverse] / inverse.size
        # Keep the probabilities on the same device as the input (CPU or GPU).
        x_probs = torch.from_numpy(freq).to(device=x.device, dtype=torch.float32)
        x_probs = self._windows(x_probs.view(x.shape))

        if self.entr == 'high':
            # Rarest value in the window.
            _, indices = torch.min(x_probs, dim=-1)
        elif self.entr == 'low':
            # Most common value in the window.
            _, indices = torch.max(x_probs, dim=-1)
        else:
            raise ValueError('Unknown entropy mode: {}'.format(self.entr))

        # Pick the chosen element out of each window of the (differentiable) input.
        pool = torch.gather(self._windows(x), dim=-1, index=indices.unsqueeze(-1))
        return pool.squeeze(-1)
    

In [None]:
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import time
from sklearn.metrics import f1_score

# Pooling-type identifiers accepted by Net1Pool / Net2Pool (`pooling=` argument).
MAX, AVG = 'max', 'avg'
HIGH_ENTROPY, LOW_ENTROPY = 'high_entr', 'low_entr'

class Net1Pool(nn.Module):
    """Shallow CNN: one 5x5 conv (30 maps) + ReLU, one 2x2 pool, one linear layer.

    Expects single-channel 28x28 inputs (conv -> 24x24, pool -> 12x12), which
    is what the classifier's input size ``30 * 12 * 12`` assumes.

    Args:
        num_classes: number of output logits.
        pooling: one of MAX, AVG, HIGH_ENTROPY, LOW_ENTROPY.

    Raises:
        ValueError: for an unknown ``pooling`` value.
    """

    def __init__(self, num_classes=10, pooling=MAX):
        super(Net1Pool, self).__init__()
        self.conv1 = nn.Conv2d(1, 30, 5)

        # Compare by value, not identity: `is` on strings only works by accident
        # of interning.
        if pooling == MAX:
            self.pool = nn.MaxPool2d(2, 2)
        elif pooling == AVG:
            self.pool = nn.AvgPool2d(2, 2)
        elif pooling == HIGH_ENTROPY:
            self.pool = EntropyPool2d(2, 2, entr='high')
        elif pooling == LOW_ENTROPY:
            self.pool = EntropyPool2d(2, 2, entr='low')
        else:
            raise ValueError('Unknown pooling type: {}'.format(pooling))

        self.fc0 = nn.Linear(30 * 12 * 12, num_classes)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        # flatten(1) keeps the batch dimension explicit; a size mismatch now fails
        # in fc0 instead of silently reshaping into a different batch size.
        x = x.flatten(1)
        # Return raw logits: CrossEntropyLoss expects unbounded scores, and a ReLU
        # here would clamp negative logits to 0 and stop their gradient.
        return self.fc0(x)


class Net2Pool(nn.Module):
    """LeNet-style CNN with two conv + pool stages and two fully connected layers.

    conv(1->50, 5x5) + ReLU -> pool -> conv(50->50, 5x5) + ReLU -> pool ->
    fc(800->500) + ReLU -> fc(500->num_classes). Expects single-channel 28x28
    inputs (28 -> 24 -> 12 -> 8 -> 4), matching the ``4*4*50`` flattened size.
    The same pooling module is applied after both conv layers.

    Args:
        num_classes: number of output logits.
        pooling: one of MAX, AVG, HIGH_ENTROPY, LOW_ENTROPY.

    Raises:
        ValueError: for an unknown ``pooling`` value.
    """

    def __init__(self, num_classes=10, pooling=MAX):
        super(Net2Pool, self).__init__()
        self.conv1 = nn.Conv2d(1, 50, 5, 1)
        self.conv2 = nn.Conv2d(50, 50, 5, 1)

        # Compare by value, not identity (`is` on strings relies on interning).
        if pooling == MAX:
            self.pool = nn.MaxPool2d(2, 2)
        elif pooling == AVG:
            self.pool = nn.AvgPool2d(2, 2)
        elif pooling == HIGH_ENTROPY:
            self.pool = EntropyPool2d(2, 2, entr='high')
        elif pooling == LOW_ENTROPY:
            self.pool = EntropyPool2d(2, 2, entr='low')
        else:
            raise ValueError('Unknown pooling type: {}'.format(pooling))

        self.fc1 = nn.Linear(4*4*50, 500)
        self.fc2 = nn.Linear(500, num_classes)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool(x)
        x = F.relu(self.conv2(x))
        x = self.pool(x)

        # Keep the batch dimension explicit so a shape mismatch fails loudly in fc1.
        x = x.flatten(1)
        x = F.relu(self.fc1(x))
        # Raw logits for CrossEntropyLoss.
        return self.fc2(x)

def configure_net(net, device, lr=0.001, momentum=0.9):
    """Move ``net`` to ``device`` and build its loss function and optimizer.

    Args:
        net: the model; moved to ``device`` in place.
        device: target ``torch.device``.
        lr: SGD learning rate (default 0.001).
        momentum: SGD momentum (default 0.9).

    Returns:
        ``(net, optimizer, criterion)``: the same module, an SGD optimizer over
        its parameters, and ``nn.CrossEntropyLoss()``.
    """
    # Move first so the optimizer is built over the parameters on `device`.
    net.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=lr, momentum=momentum)
    return net, optimizer, criterion

def train(net, optimizer, criterion, trainloader, device, epochs=10, logging=2000):
    """Train ``net`` in place for ``epochs`` passes over ``trainloader``.

    Every ``logging`` batches, prints the mean loss over that window and the
    mean wall-clock time per batch within it.

    Args:
        net: model to train (put into training mode).
        optimizer: optimizer over ``net``'s parameters.
        criterion: loss function called as ``criterion(outputs, labels)``.
        trainloader: iterable of ``(inputs, labels)`` batches.
        device: device the batches are moved to.
        epochs: number of passes over ``trainloader``.
        logging: print interval in batches; must be positive.

    Raises:
        ValueError: if ``logging`` is not positive.
    """
    if logging <= 0:
        raise ValueError('logging must be a positive number of batches')

    net.train()
    for epoch in range(epochs):
        running_loss = 0.0
        # Time the whole logging window, not just its last batch.
        window_start = time.time()
        for i, (inputs, labels) in enumerate(trainloader):
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            if i % logging == logging - 1:
                per_batch = (time.time() - window_start) / logging
                print('[%d, %5d] loss: %.3f duration: %.5f' %
                      (epoch + 1, i + 1, running_loss / logging, per_batch))
                running_loss = 0.0
                window_start = time.time()

    print('Finished Training')
    
def test(net, testloader, device):
    """Evaluate the classification accuracy of ``net`` on ``testloader``.

    Prints ``Accuracy: <percent>`` and returns that percentage (0-100). The
    module runs in eval mode for the pass; its previous train/eval mode is
    restored afterwards.

    Args:
        net: model returning per-class scores of shape (batch, classes).
        testloader: iterable of ``(images, labels)`` batches.
        device: device the batches are moved to.

    Raises:
        ValueError: if ``testloader`` yields no samples.
    """
    was_training = net.training
    net.eval()
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for images, labels in testloader:
                images, labels = images.to(device), labels.to(device)
                outputs = net(images)
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
    finally:
        net.train(was_training)

    if total == 0:
        raise ValueError('testloader yielded no samples')
    accuracy = 100 * correct / total
    print('Accuracy: {}'.format(accuracy))
    return accuracy


# `test` is an evaluation helper, not a pytest test: stop pytest collecting it by name.
test.__test__ = False


In [None]:
# Settings shared by every experiment run below. `logging` is the print
# interval (in batches) passed to train(); with batch_size=4 on the MNIST
# training set this is presumably about once per epoch -- confirm dataset size.
epochs, logging, num_classes = 10, 15000, 10

In [None]:
# Net2Pool: train and evaluate once for each pooling type, in this order.
print('- - - - - - -  - -- - - - 2 pool - -  - - - - - - - - - - - - - -')

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

two_pool_runs = [
    ('- - - - - - -  - -- - - - MAX - -  - - - - - - - - - - - - - -', MAX),
    ('- - - - - - -  - -- - - - AVG - -  - - - - - - - - - - - - - -', AVG),
    ('- - - - - - -  - -- - - - HIGH - -  - - - - - - - - - - - - - -', HIGH_ENTROPY),
    ('- - - - - - -  - -- - - - LOW - -  - - - - - - - - - - - - - -', LOW_ENTROPY),
]
for run_banner, run_pooling in two_pool_runs:
    print(run_banner)
    run_model = Net2Pool(num_classes=num_classes, pooling=run_pooling)
    net, optimizer, criterion = configure_net(run_model, device)
    train(net, optimizer, criterion, trainloader, device, epochs=epochs, logging=logging)
    test(net, testloader, device)

In [None]:
# Net1Pool: train and evaluate once for each pooling type, in this order.
# The entropy runs use HIGH_ENTROPY / LOW_ENTROPY -- the names defined with the
# other pooling identifiers (MAX_ENTROPY / MIN_ENTROPY do not exist).
print('- - - - - - -  - -- - - -1 pool - -  - - - - - - - - - - - - - -')

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

one_pool_runs = [
    ('- - - - - - -  - -- - - - MAX - -  - - - - - - - - - - - - - -', MAX),
    ('- - - - - - -  - -- - - - AVG - -  - - - - - - - - - - - - - -', AVG),
    ('- - - - - - -  - -- - - - HIGH - -  - - - - - - - - - - - - - -', HIGH_ENTROPY),
    ('- - - - - - -  - -- - - - LOW - -  - - - - - - - - - - - - - -', LOW_ENTROPY),
]
for run_banner, run_pooling in one_pool_runs:
    print(run_banner)
    run_model = Net1Pool(num_classes=num_classes, pooling=run_pooling)
    net, optimizer, criterion = configure_net(run_model, device)
    train(net, optimizer, criterion, trainloader, device, epochs=epochs, logging=logging)
    test(net, testloader, device)