In [2]:
import torch
import timm
import torch.nn as nn
import torch.nn.functional as F

In [6]:
# Number of outputs requested from the model; passed as `num_classes` to timm.create_model.
NUM_FINETUNE_CLASSES = 1

In [7]:
# Build a pretrained Inception-ResNet-v2 through timm, requesting
# NUM_FINETUNE_CLASSES outputs via `num_classes`.
model = timm.create_model(
    'inception_resnet_v2',
    pretrained=True,
    num_classes=NUM_FINETUNE_CLASSES,
)

In [9]:
class GeM(nn.Module):
    """Generalized-mean (GeM) pooling over the last two dimensions of the input.

    Computes ``mean(clamp(x, min=eps) ** p) ** (1 / p)`` across the full
    spatial extent of ``x`` (``p = 1`` reduces to average pooling of the
    clamped input). ``p`` is stored as an ``nn.Parameter`` so it is registered
    with the module and can be updated by an optimizer.

    Args:
        p: Initial value of the pooling exponent.
        eps: Lower bound applied to the input before exponentiation, so the
            power is never taken of zero or a negative value.
        flatten: If True, ``forward`` flattens the pooled result from
            dimension 1 onward, turning ``(N, C, 1, 1)`` into ``(N, C)`` for a
            batched input. Defaults to False, which keeps the pooled spatial
            dimensions of size 1.
    """

    def __init__(self, p=3, eps=1e-6, flatten=False):
        super().__init__()
        # One-element learnable exponent, initialised to `p`.
        self.p = nn.Parameter(torch.ones(1) * p)
        self.eps = eps
        self.flatten = flatten

    def forward(self, x):
        """Pool ``x`` with the module's current ``p`` and ``eps``.

        Returns the pooled tensor, flattened from dimension 1 when the module
        was created with ``flatten=True``.
        """
        pooled = self.gem(x, p=self.p, eps=self.eps)
        return pooled.flatten(1) if self.flatten else pooled

    def gem(self, x, p=3, eps=1e-6):
        """Return the generalized mean of ``x`` over its last two dimensions.

        The average-pool kernel is set to the full ``(H, W)`` size of ``x``,
        so each pooled spatial dimension has size 1.
        """
        spatial_size = (x.size(-2), x.size(-1))
        return F.avg_pool2d(x.clamp(min=eps).pow(p), spatial_size).pow(1. / p)

    def __repr__(self):
        text = '{}(p={:.4f}, eps={}'.format(
            self.__class__.__name__, self.p.item(), self.eps
        )
        # Only mention `flatten` when enabled so the default repr is unchanged.
        if self.flatten:
            text += ', flatten=True'
        return text + ')'

In [7]:
# Dummy all-ones tensor of shape (1, 3, 224, 224) used to probe the model's output shapes.
img = torch.ones(1, 3, 224, 224)

In [10]:
# Replace the model's `global_pool` attribute with a GeM layer (default p and eps).
# NOTE(review): GeM.gem pools with F.avg_pool2d over the full spatial size and does not
# flatten its result — confirm that the code consuming `global_pool`'s output inside the
# timm model accepts that shape.
model.global_pool = GeM()

In [8]:
# Probe the backbone in eval mode and without autograd. In training mode,
# layers such as BatchNorm normalise with the statistics of this single dummy
# sample and update their running statistics, which would alter the
# pretrained model just by inspecting it.
was_training = model.training
model.eval()
with torch.no_grad():
    feature_output = model.forward_features(img)
model.train(was_training)  # put the model back in whatever mode it was in

In [23]:
# Inspect the input feature count of the model's `classif` layer.
model.classif.in_features

1536

In [9]:
# Shape of the tensor returned by model.forward_features(img).
feature_output.shape

torch.Size([1, 1536, 5, 5])

In [14]:
# Standalone GeM layer with default p and eps, applied directly to the feature map below.
gem_pooling = GeM()

In [17]:
# Pool the feature map with the standalone GeM layer.
out = gem_pooling(feature_output)

In [18]:
# Inspect the shape of the GeM-pooled output.
out.shape

torch.Size([1, 1536, 1, 1])