In [1]:
import torch
import torch.nn as nn

In [2]:
from ResNet import *

In [3]:
model = resnet18()

In [63]:
class SCAN(nn.Module):
    """
    SCAN
    """
    
    def __init__(self, channels, stride=(1,1,1), final_channels=512, num_classes=100):
        super(SCAN, self).__init__()
        
        # activation func
        self._relu = nn.ReLU(inplace=True)
        
        # attention module
        self._attconv = nn.Conv2d(channels, channels, kernel_size=3, stride=2, padding=1, bias=False)
        self._attbn0 = nn.BatchNorm2d(channels)
        self._attdeconv = nn.ConvTranspose2d(channels, channels, kernel_size=3, stride=2, padding=1, bias=False)
        self._attbn1 = nn.BatchNorm2d(channels)
        
        # bottleneck module
        self._bot1x1_0 = nn.Conv2d(channels, channels, kernel_size=1, stride=stride[0], bias=False)
        self._botbn0 = nn.BatchNorm2d(channels)
        self._bot3x3 = nn.Conv2d(channels, channels, kernel_size=3, stride=stride[1], padding=1, bias=False)
        self._botbn1 = nn.BatchNorm2d(channels)
        self._bot1x1_1 = nn.Conv2d(channels, final_channels, kernel_size=1, stride=stride[2], bias=False)
        self._botbn2 = nn.BatchNorm2d(final_channels)
        
        # classifier module
        self._globalavgpool = nn.AdaptiveAvgPool2d(1)
        self._shallow_classifier = nn.Conv2d(final_channels, num_classes, kernel_size=1, stride=1, bias=True)
    
    def forward(self, x):
        # attention
        att = self._relu(self._attbn0(self._attconv(x)))

        att = self._relu(self._attbn1(self._attdeconv(att, x.shape)))
        x = x * torch.sigmoid(att)
        
        # bottleneck
        x = self._relu(self._botbn0(self._bot1x1_0(x)))
        x = self._relu(self._botbn1(self._bot3x3(x)))
        feature = self._relu(self._botbn2(self._bot1x1_1(x)))
        #print(feature.shape)
        # classifier
        x = self._globalavgpool(feature)
        x = self._shallow_classifier(x).squeeze()
        
        return x, feature

In [64]:
sample1 = torch.randn(1, 64, 112, 112)
sample2 = torch.randn(1, 128, 56, 56)
sample3 = torch.randn(1, 256, 28, 28)

In [65]:
scan_module1 = SCAN(channels=64, bottle_stride=(2,2,2))
scan_module2 = SCAN(channels=128, bottle_stride=(2,2,1))
scan_module3 = SCAN(channels=256, bottle_stride=(1,2,1))

In [66]:
result1 = scan_module1(sample1)
result2 = scan_module2(sample2)
result3 = scan_module3(sample3)

In [117]:
from pthflops import count_ops

def counter(model, sample, verbose=True):
    model = model.eval()
    
    M = 1000000
    params_num = 0
    for params in model.parameters():
        params_num += params.view(-1).shape[0]    

    flops = count_ops(model, sample, verbose=verbose)
    print("flops: {:.4f}M, params: {:.4f}M".format(flops/1000000, params_num/1000000))
    
    return params, flops

params, flops = counter(scan_module3, sample3)

Operation                                            OPS         
---------------------------------------------------  ----------  
SCAN/Conv2d[_attconv]/onnx::Conv                     115655680   
SCAN/BatchNorm2d[_attbn0]/onnx::BatchNormalization   100352      
SCAN/ReLU[_relu]/onnx::Relu                          100352      
SCAN/BatchNorm2d[_attbn1]/onnx::BatchNormalization   401408      
SCAN/ReLU[_relu]/onnx::Relu                          401408      
SCAN/Conv2d[_bot1x1_0]/onnx::Conv                    51380224    
SCAN/BatchNorm2d[_botbn0]/onnx::BatchNormalization   401408      
SCAN/ReLU[_relu]/onnx::Relu                          401408      
SCAN/Conv2d[_bot3x3]/onnx::Conv                      115605504   
SCAN/BatchNorm2d[_botbn1]/onnx::BatchNormalization   100352      
SCAN/ReLU[_relu]/onnx::Relu                          100352      
SCAN/Conv2d[_bot1x1_1]/onnx::Conv                    25690112    
SCAN/BatchNorm2d[_botbn2]/onnx::BatchNormalization   200704      
SCAN/ReLU[

In [109]:
class EPE(nn.Module):
    """
    EPE Module
    """
    def __init__(self, channels, stride=(1,1,1), final_channels=512, expansion=2, num_class=100):
        super(EPE, self).__init__()
        
        # activation func
        self._relu = nn.ReLU(inplace=True)
        
        # expansion module
        mid_channels = channels * expansion
        self._expansion_conv = nn.Conv2d(channels, mid_channels, kernel_size=1, stride=stride[0], bias=False)
        self._bn0 = nn.BatchNorm2d(mid_channels)
        
        # conv module
        self._depthwise_conv = nn.Conv2d(mid_channels, mid_channels, kernel_size=3, stride=stride[1], padding=1, bias=False, groups=channels)
        self._bn1 = nn.BatchNorm2d(mid_channels)
        
        self._projection_conv = nn.Conv2d(mid_channels, final_channels, kernel_size=1, stride=stride[2], bias=False)
        self._bn2 = nn.BatchNorm2d(final_channels)
        
        # classifier module
        self._globalavgpool = nn.AdaptiveAvgPool2d(1)
        self._shallow_classifier = nn.Conv2d(final_channels, num_class, kernel_size=1, stride=1, bias=True)
    
    def forward(self, x):
        # conv
        x = self._relu(self._bn0(self._expansion_conv(x)))
        x = self._relu(self._bn1(self._depthwise_conv(x)))
        x = self._relu(self._bn2(self._projection_conv(x)))

        # classifier
        x = self._globalavgpool(x)
        x = self._shallow_classifier(x).squeeze()
        return x

In [110]:
epe_module1 = EPE(64, stride=(2,2,2))
epe_module2 = EPE(128, stride=(2,2,1))
epe_module3 = EPE(256, stride=(1,2,1))

In [114]:
params, flops = counter(epe_module3, sample3)

Operation                                        OPS         
-----------------------------------------------  ----------  
EPE/Conv2d[_expansion_conv]/onnx::Conv           103161856   
EPE/BatchNorm2d[_bn0]/onnx::BatchNormalization   802816      
EPE/ReLU[_relu]/onnx::Relu                       802816      
EPE/Conv2d[_depthwise_conv]/onnx::Conv           1806336     
EPE/BatchNorm2d[_bn1]/onnx::BatchNormalization   200704      
EPE/ReLU[_relu]/onnx::Relu                       200704      
EPE/Conv2d[_projection_conv]/onnx::Conv          51380224    
EPE/BatchNorm2d[_bn2]/onnx::BatchNormalization   200704      
EPE/ReLU[_relu]/onnx::Relu                       200704      
EPE/Conv2d[_shallow_classifier]/onnx::Conv       51200       
----------------------------------------------   ---------   
Input size: (1, 256, 28, 28)
158,808,064 FLOPs or approx. 0.16 GFLOPs
flops: 158.8081M, params: 0.4568M


In [88]:
from torchvision.models.resnet import resnet18

In [89]:
model = resnet18()

In [90]:
sample = torch.randn(1,3,32,32)

In [91]:
counter(model,sample)

Operation                                                                                                OPS       
-------------------------------------------------------------------------------------------------------  --------  
ResNet/Conv2d[conv1]/onnx::Conv                                                                          2424832   
ResNet/BatchNorm2d[bn1]/onnx::BatchNormalization                                                         32768     
ResNet/ReLU[relu]/onnx::Relu                                                                             32768     
ResNet/MaxPool2d[maxpool]/onnx::MaxPool                                                                  32768     
ResNet/Sequential[layer1]/BasicBlock[0]/Conv2d[conv1]/onnx::Conv                                         2359296   
ResNet/Sequential[layer1]/BasicBlock[0]/BatchNorm2d[bn1]/onnx::BatchNormalization                        8192      
ResNet/Sequential[layer1]/BasicBlock[0]/ReLU[relu]/onnx::Relu           

(Parameter containing:
 tensor([-2.4426e-02,  2.3967e-02, -2.9111e-02, -2.1427e-02,  4.3984e-02,
          3.0813e-02,  1.2565e-02,  8.0713e-03, -1.9495e-02,  2.4364e-02,
          4.3021e-03,  1.5280e-02, -2.3674e-02,  6.3198e-03,  4.2103e-02,
         -4.1981e-02,  3.6257e-02,  3.5632e-02,  3.1763e-03,  3.1044e-02,
         -1.8174e-02, -1.3629e-02, -2.9928e-02, -2.2692e-02, -2.7269e-03,
          3.6816e-02,  2.3959e-02,  4.3799e-02,  2.9688e-02,  3.7647e-02,
          1.5104e-02, -7.5375e-03,  1.3225e-02,  5.7249e-03, -1.1666e-02,
         -2.9981e-02,  3.2902e-02, -3.2995e-02, -3.0110e-02,  1.1996e-02,
          1.6513e-02,  3.5449e-02,  3.8253e-02, -3.9126e-04,  5.3069e-03,
         -3.0493e-02,  3.5185e-02, -1.2021e-02,  4.3209e-03, -1.9483e-02,
          2.4133e-02, -3.6570e-02,  1.1128e-02,  7.5300e-03, -1.9603e-02,
         -3.9263e-02,  2.1353e-03, -2.6687e-02,  3.1525e-02,  1.4052e-02,
         -7.3651e-03,  3.4925e-02,  9.1222e-03, -2.0324e-02,  1.2228e-02,
          5.383