In [None]:
import torch
import torch.nn.functional as F
import torch.nn as nn

# Adaptive Average Pooling layer
class AdaptiveAVGPool(nn.Module):
    """Global average pooling followed by flattening and L2 normalisation.

    The spatial dimensions are always reduced to 1x1 by
    ``nn.AdaptiveAvgPool2d(1)``. ``kernel_size``, ``stride`` and ``padding``
    are accepted for signature compatibility only; they are not used.
    """

    def __init__(self, kernel_size=3, stride=2, padding=0):
        super().__init__()
        self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d(1)

    def forward(self, x):
        # Average over the last two dims, leaving size-1 spatial axes.
        pooled = self.AdaptiveAvgPool(x)
        # Drop the singleton spatial axes, keeping dim 0 intact.
        descriptor = torch.flatten(pooled, start_dim=1)
        # Unit L2 norm along dim 1 of the flattened descriptor.
        return F.normalize(descriptor, p=2, dim=1)
    
class GeMPool(nn.Module):
    """Generalized-mean (GeM) pooling with flattening and L2 normalisation.

    Adapted from https://github.com/filipradenovic/cnnimageretrieval-pytorch.
    Flattening and normalisation are folded in so that the module can be used
    directly as a single aggregation layer.

    Args:
        p: initial value of the learnable pooling exponent.
        eps: lower bound applied to the input before it is raised to ``p``.
    """

    def __init__(self, p=3, eps=1e-6):
        super().__init__()
        # Learnable exponent, stored as a 1-element tensor.
        self.p = nn.Parameter(torch.ones(1) * p)
        self.eps = eps

    def forward(self, x):
        # Clamping keeps the base >= eps, so the fractional power below is
        # applied to positive values (assuming eps > 0).
        powered = x.clamp(min=self.eps).pow(self.p)
        # A kernel covering the full last two dims averages over them.
        window = (x.size(-2), x.size(-1))
        pooled = F.avg_pool2d(powered, window)
        generalized_mean = pooled.pow(1. / self.p)
        return F.normalize(generalized_mean.flatten(1), p=2, dim=1)