# `GeM Pooling` vs `Average Pooling`

BLOG: https://amaarora.github.io/2020/08/30/gempool.html

In this notebook, we implement `GeM Pooling` and compare its results with `Average Pooling` on the Oxford-IIIT Pets dataset using a `ResNet-34` model.
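
As a quick reference, GeM (generalized-mean) pooling reduces each channel $k$ of a feature map to a single value. The notation below is my own shorthand rather than something taken from the notebook code: $\mathcal{X}_k$ is the set of spatial activations in channel $k$, and $p$ is the pooling exponent.

$$
f_k = \left( \frac{1}{|\mathcal{X}_k|} \sum_{x \in \mathcal{X}_k} x^{p} \right)^{1/p}
$$

With $p = 1$ this is ordinary average pooling, and as $p \to \infty$ it approaches max pooling. In this notebook $p$ is a learnable parameter initialised to 3.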

In [1]:
from fastai2.vision.all import *
from nbdev.showdoc import *
import glob
import albumentations
from torchvision import models
from albumentations.pytorch.transforms import ToTensorV2
set_seed(2)

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.hub import load_state_dict_from_url

# `Pets` Dataset

Now we use the wonderful [fastai library](https://github.com/fastai/fastai2) to get the `Pets` dataset.

In [3]:
bs = 4

In [4]:
path = untar_data(URLs.PETS); path

Path('/home/ubuntu/.fastai/data/oxford-iiit-pet')

In [5]:
(path/'images').ls()

(#7381) [Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/keeshond_34.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/Siamese_178.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/german_shorthaired_94.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/Abyssinian_92.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/basset_hound_111.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/Russian_Blue_194.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/staffordshire_bull_terrier_91.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/Persian_69.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/english_setter_33.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/Russian_Blue_155.jpg')...]

## `Dataset`

The implementation of `PetsDataset` is heavily inspired by, and partially copied (the regex part) from, the `fastai2` repo [here](https://github.com/fastai/fastai2/blob/master/nbs/course/lesson1-pets.ipynb).

In [6]:
class PetsDataset:
    def __init__(self, paths, transforms=None):
        self.image_paths = paths
        self.transforms = transforms
        
    def __len__(self): 
        return len(self.image_paths)
    
    def setup(self, pat=r'(.+)_\d+\.jpg$', label2int=None):
        "adds a label dictionary to `self`"
        self.pat = re.compile(pat)
        if label2int is not None:
            self.label2int = label2int
            self.int2label = {i:label for label,i in self.label2int.items()}
        else:
            labels = [os.path.basename(self.pat.search(str(p)).group(1))
                      for p in self.image_paths]
            # sort so that the label -> index mapping is the same on every run
            self.labels = sorted(set(labels))
            self.label2int = {label:i for i,label in enumerate(self.labels)}
            self.int2label = {i:label for label,i in self.label2int.items()}

    
    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        img = Image.open(img_path)
        img = np.array(img)
        
        target = os.path.basename(self.pat.search(str(img_path)).group(1))
        target = self.label2int[target]
        
        if self.transforms:
            img = self.transforms(image=img)['image']      
            
        return img, torch.tensor(target, dtype=torch.long)
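
To make the label extraction concrete, here is a minimal standalone sketch of what the regex does to a few filenames from the directory listing above. The helper name `label_from_path` is hypothetical and not part of the notebook. Sorting the labels before numbering them is a choice made in this sketch so that the mapping is reproducible.

```python
import os
import re

# Same idea as the pattern in `PetsDataset.setup`, with the dot before `jpg` escaped.
PAT = re.compile(r'(.+)_\d+\.jpg$')

def label_from_path(path):
    """Return the breed name encoded in a Pets image filename (hypothetical helper)."""
    match = PAT.search(str(path))
    if match is None:
        raise ValueError(f"unexpected filename: {path}")
    # group(1) still contains the directory part, so keep only the basename
    return os.path.basename(match.group(1))

# Filenames copied from the directory listing shown earlier
paths = [
    '/home/ubuntu/.fastai/data/oxford-iiit-pet/images/keeshond_34.jpg',
    '/home/ubuntu/.fastai/data/oxford-iiit-pet/images/Siamese_178.jpg',
    '/home/ubuntu/.fastai/data/oxford-iiit-pet/images/german_shorthaired_94.jpg',
    '/home/ubuntu/.fastai/data/oxford-iiit-pet/images/staffordshire_bull_terrier_91.jpg',
]

labels = [label_from_path(p) for p in paths]
label2int = {label: i for i, label in enumerate(sorted(set(labels)))}
int2label = {i: label for label, i in label2int.items()}

print(labels)
print(label2int)
```

Because `.+` is greedy, breed names that themselves contain underscores (such as `staffordshire_bull_terrier`) are kept whole, and only the trailing `_<number>.jpg` is stripped.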

In [7]:
image_paths = get_image_files(path/'images')
image_paths

(#7378) [Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/keeshond_34.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/Siamese_178.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/german_shorthaired_94.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/Abyssinian_92.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/basset_hound_111.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/Russian_Blue_194.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/staffordshire_bull_terrier_91.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/Persian_69.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/english_setter_33.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/Russian_Blue_155.jpg')...]

In [8]:
# remove images that do not have exactly 3 channels
from tqdm.notebook import tqdm
run_remove = False
def remove(o):
    img = np.array(Image.open(o))
    # grayscale images load as 2-D arrays, so check `ndim` before indexing the channel axis
    if img.ndim != 3 or img.shape[2] != 3:
        os.remove(o)
if run_remove:
    for o in tqdm(image_paths): remove(o)

In [9]:
image_paths = get_image_files(path/'images')
image_paths

(#7378) [Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/keeshond_34.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/Siamese_178.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/german_shorthaired_94.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/Abyssinian_92.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/basset_hound_111.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/Russian_Blue_194.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/staffordshire_bull_terrier_91.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/Persian_69.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/english_setter_33.jpg'),Path('/home/ubuntu/.fastai/data/oxford-iiit-pet/images/Russian_Blue_155.jpg')...]

In [10]:
# augmentations using `albumentations` library
sz = 224
tfms = albumentations.Compose([
    albumentations.Resize(sz, sz) if sz else albumentations.NoOp(),
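    # NOTE: `random.randint` runs once, when this pipeline is built, so the
    # hole-count argument of each transform is sampled a single time per run
    albumentations.OneOf(
        [albumentations.Cutout(random.randint(1,8), 16, 16),
         albumentations.CoarseDropout(random.randint(1,8), 16, 16)]
    ),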
    albumentations.Normalize(always_apply=True),
    ToTensorV2()
])

In [11]:
dataset = PetsDataset(image_paths, tfms)

In [12]:
# to setup the `label2int` dictionary
dataset.setup()

In [13]:
dataset[0]

(tensor([[[ 0.8618,  0.1597,  0.4166,  ..., -0.6452, -0.3198, -0.2171],
          [ 1.1872,  0.3481,  0.4166,  ..., -0.3027,  0.0912,  0.3138],
          [ 0.8104,  0.6049,  0.0227,  ..., -0.3712, -0.1657, -0.1828],
          ...,
          [ 1.2385,  0.4851,  0.0227,  ...,  0.8789,  1.2214,  0.8961],
          [ 0.7077,  0.9474, -0.6965,  ...,  0.1254,  1.5297,  1.6667],
          [ 0.1083, -0.0801,  0.3652,  ...,  0.2111,  0.5193,  0.6734]],
 
         [[ 0.9230,  0.4328,  0.4503,  ..., -0.2850, -0.0224, -0.0399],
          [ 1.3256,  0.7304,  0.4678,  ..., -0.0399,  0.1527,  0.3277],
          [ 0.8354,  0.8354,  0.3102,  ..., -0.2500, -0.1975, -0.3200],
          ...,
          [ 1.3606,  1.3431,  0.6078,  ...,  0.9755,  1.3957,  1.1331],
          [ 0.7654,  1.0455, -0.0574,  ...,  0.7654,  1.6232,  1.7458],
          [ 0.4153,  0.5903,  0.9230,  ...,  0.7654,  0.8529,  1.0980]],
 
         [[ 0.3393, -0.3578, -0.4275,  ..., -0.7936, -0.4624, -0.3578],
          [ 0.6531, -0.2358,

In [14]:
dataset[0][0].shape

torch.Size([3, 224, 224])

## `DataLoaders`

We split `image_paths` into train and validation sets, holding out the last 20% of the paths for validation.

In [15]:
nval = int(len(image_paths)*0.2)
nval

1475

In [16]:
trn_img_paths = image_paths[:-nval]
val_img_paths = image_paths[-nval:]
assert len(trn_img_paths) + len(val_img_paths) == len(image_paths)
len(trn_img_paths), len(val_img_paths)

(5903, 1475)
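
The split arithmetic above can be checked with a small standalone sketch. The function name `tail_split` is hypothetical, and the list of integers stands in for the 7378 image paths.

```python
def tail_split(items, val_frac=0.2):
    """Hold out the last `val_frac` of `items` for validation (no shuffling)."""
    nval = int(len(items) * val_frac)
    if nval == 0:
        # items[:-0] would be empty, so handle this edge case explicitly
        return list(items), []
    return list(items[:-nval]), list(items[-nval:])

# 7378 dummy items standing in for the image paths
trn, val = tail_split(list(range(7378)))
print(len(trn), len(val))
```

Note that this split is purely positional: which images end up in validation depends on the order returned by `get_image_files`. A seeded shuffle before slicing would be a common alternative, though that is not what this notebook does.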

In [17]:
trn_dataset = PetsDataset(trn_img_paths, transforms=tfms)
val_dataset = PetsDataset(val_img_paths, transforms=tfms)

In [18]:
# use same `label2int` dictionary as in `dataset` for consistency across train and val
trn_dataset.setup(label2int=dataset.label2int)
val_dataset.setup(label2int=dataset.label2int)

In [19]:
trn_loader = torch.utils.data.DataLoader(trn_dataset, batch_size=bs, num_workers=4, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=bs, num_workers=4, shuffle=False)

In [20]:
# make sure everything works so far
next(iter(trn_loader))[0].shape, next(iter(val_loader))[0].shape

(torch.Size([4, 3, 224, 224]), torch.Size([4, 3, 224, 224]))

## `Model` with `GeM` pooling

The implementation of GeM Pooling is copied from the official code accompanying the research paper [here](https://github.com/filipradenovic/cnnimageretrieval-pytorch).

Here, we download the pretrained `ResNet-34` model.

In [21]:
# Vanilla resnet with `BatchNorm`
model_resnet34 = models.resnet34(pretrained=True)
in_features = model_resnet34.fc.in_features

In [22]:
class GeM(nn.Module):
    def __init__(self, p=3, eps=1e-6):
        super(GeM,self).__init__()
        self.p = nn.Parameter(torch.ones(1)*p)
        self.eps = eps

    def forward(self, x):
        return self.gem(x, p=self.p, eps=self.eps)
        
    def gem(self, x, p=3, eps=1e-6):
        return F.avg_pool2d(x.clamp(min=eps).pow(p), (x.size(-2), x.size(-1))).pow(1./p)
        
    def __repr__(self):
        return self.__class__.__name__ + '(' + 'p=' + '{:.4f}'.format(self.p.data.tolist()[0]) + ', ' + 'eps=' + str(self.eps) + ')'
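
To see what the exponent `p` does, here is a minimal plain-Python sketch of the same computation on a single channel, without PyTorch. The function name `gem_1d` and the activation values are made up for illustration. Each line mirrors one step of `GeM.gem` above.

```python
def gem_1d(values, p=3.0, eps=1e-6):
    """Generalized mean of one channel's activations, mirroring `GeM.gem`."""
    clamped = [max(v, eps) for v in values]                        # x.clamp(min=eps)
    mean_of_powers = sum(v ** p for v in clamped) / len(clamped)   # avg_pool2d(x.pow(p))
    return mean_of_powers ** (1.0 / p)                             # .pow(1./p)

acts = [0.0, 1.0, 2.0, 5.0]  # made-up activations of a single channel

for p in (1.0, 3.0, 10.0, 50.0):
    print(p, round(gem_1d(acts, p), 4))
```

With `p=1` the result is the plain average of the (clamped) activations, and as `p` grows the result moves towards the maximum activation. This is why GeM is often described as sitting between average pooling and max pooling.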

In [23]:
features = list(model_resnet34.children())[:-2]
pool = GeM()

In [24]:
class GeM_ResNet(nn.Module):    
    def __init__(self, features, pool):
        super(GeM_ResNet, self).__init__()
        self.features = nn.Sequential(*features)
        self.fc = nn.Linear(in_features, len(trn_dataset.label2int))
        self.pool = pool
    
    def forward(self, x):
        o = self.features(x)
        o = self.pool(o).squeeze(-1).squeeze(-1)
        o = self.fc(o)
        return o
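
The tensor shapes flowing through `forward` can be traced with a little arithmetic. The sketch below is a rough shape calculator, not part of the notebook. It assumes the standard torchvision `ResNet-34` layout (stride-2 stem convolution and max pool, then four stages with strides 1, 2, 2, 2 and 512 output channels) and the 37 classes of the Pets dataset. The function names are hypothetical.

```python
def conv_out(n, kernel, stride, padding):
    """Output size of a conv/pool layer along one spatial dimension."""
    return (n + 2 * padding - kernel) // stride + 1

def trace_shapes(batch=4, size=224, num_classes=37):
    """Trace (assumed) ResNet-34 shapes from the input image to the logits."""
    n = conv_out(size, kernel=7, stride=2, padding=3)   # stem conv: 224 -> 112
    n = conv_out(n, kernel=3, stride=2, padding=1)      # max pool: 112 -> 56
    for stride in (1, 2, 2, 2):                         # layer1..layer4 (first conv of each stage)
        n = conv_out(n, kernel=3, stride=stride, padding=1)
    return {
        'features': (batch, 512, n, n),   # output of self.features
        'pool': (batch, 512, 1, 1),       # GeM pools over the full spatial extent
        'squeeze': (batch, 512),          # after .squeeze(-1).squeeze(-1)
        'fc': (batch, num_classes),       # class logits
    }

shapes = trace_shapes()
for name, shape in shapes.items():
    print(name, shape)
```

The two `squeeze(-1)` calls are what turn the pooled `(batch, 512, 1, 1)` tensor into the `(batch, 512)` input that `nn.Linear` expects.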

In [25]:
model_pool = GeM_ResNet(features, pool)
model_pool

GeM_ResNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=T

As can be seen, the `pool` layer in this network is `GeM` instead of the common `Average Pooling` layer used by default. 

## Training using `PyTorch Lightning`

Finally, we use [PyTorch Lightning](https://github.com/PyTorchLightning/pytorch-lightning) to train the model.

In [27]:
from torch.optim.lr_scheduler import LambdaLR
from pytorch_lightning import LightningModule, Trainer

In [28]:
class Model(LightningModule):
    def __init__(self, base):
        super().__init__()
        self.base = base

    def forward(self, x):
        return self.base(x)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-4)
        scheduler = LambdaLR(optimizer, lambda epoch: 0.95 ** epoch)
        return [optimizer], [scheduler]
        
    def step(self, batch):
        x, y  = batch
        y_hat = self(x)
        loss  = nn.CrossEntropyLoss()(y_hat, y)
        return loss, y, y_hat

    def training_step(self, batch, batch_nb):
        loss, _, _ = self.step(batch)
        return {'loss': loss}

    def validation_step(self, batch, batch_nb):
        loss, y, y_hat = self.step(batch)
        return {'loss': loss, 'y': y.detach(), 'y_hat': y_hat.detach()}

    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
        acc = self.get_accuracy(outputs)
        print(f"Epoch:{self.current_epoch} | Loss:{avg_loss} | Accuracy:{acc}")
        return {'loss': avg_loss}
    
    def get_accuracy(self, outputs):
        from sklearn.metrics import accuracy_score
        y = torch.cat([x['y'] for x in outputs])
        y_hat = torch.cat([x['y_hat'] for x in outputs])
        preds = y_hat.argmax(1)
        return accuracy_score(y.cpu().numpy(), preds.cpu().numpy())
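
For intuition about the scheduler in `configure_optimizers`: `LambdaLR` multiplies the initial learning rate by the value of the lambda at the current epoch. The sketch below reproduces that arithmetic in plain Python for the five epochs trained here. It assumes the scheduler is stepped once per epoch.

```python
base_lr = 1e-4                            # lr passed to Adam
lr_lambda = lambda epoch: 0.95 ** epoch   # same function passed to LambdaLR

# learning rate used during each of the 5 training epochs
lrs = [base_lr * lr_lambda(epoch) for epoch in range(5)]

for epoch, lr in enumerate(lrs):
    print(f"epoch {epoch}: lr = {lr:.3e}")
```

So over five epochs the learning rate only decays by roughly 19%, from `1e-4` to about `8.1e-5`. This is a fairly gentle schedule.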

In [29]:
# define PL versions 
model = Model(model_pool)

In [30]:
debug = False
gpus = torch.cuda.device_count()

### `batch_size=64`

In [31]:
trn_loader = torch.utils.data.DataLoader(trn_dataset, batch_size=64, num_workers=6, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=64, num_workers=6, shuffle=False)

In [32]:
trainer = Trainer(gpus=gpus, max_epochs=5, 
                  num_sanity_val_steps=1 if debug else 0)
trainer.fit(model, train_dataloader=trn_loader, val_dataloaders=val_loader)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]

  | Name | Type       | Params
------------------------------------
0 | base | GeM_ResNet | 21 M  


Epoch:0 | Loss:0.38181644678115845 | Accuracy:0.88

Epoch:1 | Loss:0.30204614996910095 | Accuracy:0.9071186440677966

Epoch:2 | Loss:0.290485143661499 | Accuracy:0.9098305084745762

Epoch:3 | Loss:0.27039361000061035 | Accuracy:0.9166101694915254

Epoch:4 | Loss:0.2537878453731537 | Accuracy:0.9233898305084746

1

## `AvgPool`

In [33]:
# vanilla ResNet-34: keep the default average pooling and replace only the classifier head
model_avg_pool = models.resnet34(pretrained=True)
model_avg_pool.fc = nn.Linear(in_features, len(trn_dataset.label2int))

In [34]:
# define PL versions 
model = Model(model_avg_pool)

In [35]:
trainer = Trainer(gpus=gpus, max_epochs=5, 
                  num_sanity_val_steps=1 if debug else 0)
trainer.fit(model, train_dataloader=trn_loader, val_dataloaders=val_loader)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]

  | Name | Type   | Params
--------------------------------
0 | base | ResNet | 21 M  


Epoch:0 | Loss:0.4223055839538574 | Accuracy:0.888135593220339

Epoch:1 | Loss:0.3181498646736145 | Accuracy:0.9057627118644068

Epoch:2 | Loss:0.26608720421791077 | Accuracy:0.92

Epoch:3 | Loss:0.27744489908218384 | Accuracy:0.920677966101695

Epoch:4 | Loss:0.27841755747795105 | Accuracy:0.9159322033898305

1

# Conclusion

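We can see that `GeM Pooling` performs slightly better in this run: after 5 epochs it reaches a validation accuracy of about 0.923, versus about 0.916 for `Average Pooling` (which peaked at about 0.921 in epoch 3). However, for me the results vary between runs, and the models trained with `Average Pooling` and `GeM Pooling` have comparable accuracies.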
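
To put the gap in perspective, here is a rough back-of-the-envelope check using the final-epoch accuracies printed above and the 1475 validation images. It treats the two accuracies as independent binomial proportions. That is only an approximation, since both models are evaluated on the same images (with random augmentations applied), so read it as a sanity check rather than a proper significance test.

```python
import math

n_val = 1475                      # size of the validation set
acc_gem = 0.9233898305084746      # GeM, epoch 4
acc_avg = 0.9159322033898305      # average pooling, epoch 4

# number of correctly classified validation images for each model
correct_gem = round(acc_gem * n_val)
correct_avg = round(acc_avg * n_val)

def std_err(acc, n):
    """Binomial standard error of an accuracy measured on n samples."""
    return math.sqrt(acc * (1 - acc) / n)

diff = acc_gem - acc_avg
# standard error of the difference, assuming independence (a simplification)
se_diff = math.sqrt(std_err(acc_gem, n_val) ** 2 + std_err(acc_avg, n_val) ** 2)

print(correct_gem, correct_avg, correct_gem - correct_avg)
print(round(diff, 4), round(se_diff, 4))
```

The difference amounts to a handful of validation images and is smaller than one standard error under this rough model, which is consistent with the two pooling layers being comparable on this task.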