In [1]:
import torch.nn.functional as F
import torch.nn as nn
import torch

from torch.nn.parameter import Parameter
def gem(x, p=3, eps=1e-6):
    return F.avg_pool2d(x.clamp(min=eps).pow(p), (x.size(-2), x.size(-1))).pow(1./p)

class GeM(nn.Module):
    def __init__(self, p=3, eps=1e-6):
        super(GeM,self).__init__()
        self.p = Parameter(torch.ones(1)*p)
        self.eps = eps
        
    def forward(self, x):
        return gem(x, p=self.p, eps=self.eps)   
    
    def __repr__(self):
        return self.__class__.__name__ + '(' + 'p=' + '{:.4f}'.format(self.p.data.tolist()[0]) + ', ' + 'eps=' + str(self.eps) + ')'

In [34]:
x = torch.ones([8, 2048, 4, 4]) 
pool = torch.nn.AdaptiveAvgPool2d(1)
x = pool(x).view(-1, 16384)
lin = nn.Linear(16384, 1)
x = lin(x)
print(x.shape)

torch.Size([1, 1])


In [41]:
%%time
x = torch.ones([8, 2048, 4, 4]) 
pool = torch.nn.AdaptiveAvgPool2d(1)
x = pool(x)
print(x.shape)

torch.Size([8, 2048, 1, 1])
CPU times: user 1.82 ms, sys: 0 ns, total: 1.82 ms
Wall time: 1.79 ms


In [42]:
%%time
x = torch.ones([8, 2048, 4, 4]) 
pool = GeM()
x = pool(x)
print(x.shape)

torch.Size([8, 2048, 1, 1])
CPU times: user 2.34 ms, sys: 0 ns, total: 2.34 ms
Wall time: 1.8 ms


In [None]:
conv
batch
mish 
pool

In [69]:
%%time
fc1 = nn.Conv2d(2048, 2048, 4)
bn = nn.BatchNorm2d(2048)
relu = nn.ReLU()
pool = GeM()


x = torch.ones([8, 2048, 4, 4]) 
x = fc1(x)
x = relu(x)
x = bn(x)
x = pool(x)
x = x.view(8, -1)
lin = nn.Linear(2048, 168)
print(x.shape)

torch.Size([8, 2048])
CPU times: user 580 ms, sys: 62.5 ms, total: 643 ms
Wall time: 428 ms


In [4]:
%%time
fc1 = nn.Conv2d(2048, 1024, 4)
bn = nn.BatchNorm2d(1024)
relu = nn.ReLU()
pool = GeM()


x = torch.ones([8, 2048, 4, 4]) 
x = fc1(x)
x = bn(x)
x = relu(x)
x = pool(x)
x = x.view(8, -1)
lin = nn.Linear(2048, 168)
print(x.shape)

torch.Size([8, 1024])
CPU times: user 291 ms, sys: 29.4 ms, total: 321 ms
Wall time: 215 ms
