# 驗證 `model_complexity` 的運算問題

In [1]:
from model_complexity import compute_model_complexity
from torch import nn
from torch.nn import functional as F

class build_model(nn.Module):
    def __init__(self):
        super(build_model, self).__init__()
        
        self.conv = nn.Conv2d(in_channels=1024, out_channels=128, 
                              kernel_size=1, stride=1, padding=0,
                              bias=True)
        self.bn = nn.BatchNorm2d(128)
        self.relu = nn.ReLU(inplace=True)
        
    def forward(self, x):
        x_1 = x[0, :, :, :]
        x_2 = x[1, :, :, :]
        inter_x = torch.matmul(x_1, x_2.permute(0, 2, 1)).unsqueeze(0)
        print(inter_x.size())
        x = self.conv(inter_x)
        x = self.bn(x)
        return self.relu(x)

In [2]:
import torch

model = build_model()
x = torch.randn(2, 1024, 16, 8)
num_params, flops = compute_model_complexity(model, x.size(), verbose=True)
s = model(x)

torch.Size([1, 1024, 16, 16])
  -------------------------------------------------------
  Model complexity with input size torch.Size([2, 1024, 16, 8])
  -------------------------------------------------------
  Conv2d (params=131,200, flops=33,587,200)
  BatchNorm2d (params=256, flops=0)
  ReLU (params=0, flops=0)
  -------------------------------------------------------
  Total (params=131,456, flops=33,587,200)
  -------------------------------------------------------
torch.Size([1, 1024, 16, 16])


In [10]:
from models.AATM import AATM

model = AATM(inplanes=1024, mid_planes=256, num='0')
params, flops = compute_model_complexity(model, (1, 2, 1024, 16, 8), verbose=True)


Build 0 layer mutual spatial attention!
Build 0 layer mutual channel attention!
Build 0 layer appearance spatial attention!
Build 0 layer appearacne channel attention!
  -------------------------------------------------------
  Model complexity with input size (1, 2, 1024, 16, 8)
  -------------------------------------------------------
  AdaptiveAvgPool2d (params=0, flops=0)
  Conv2d (params=2,130,560, flops=528,580,608)
  BatchNorm2d (params=6,406, flops=0)
  ReLU (params=0, flops=0)
  Conv1d (params=131,072, flops=262,144)
  Linear (params=856,576, flops=856,576)
  Sigmoid (params=0, flops=0)
  -------------------------------------------------------
  Total (params=3,124,614, flops=529,699,328)
  -------------------------------------------------------


In [15]:
conv = nn.Conv2d(in_channels=10, out_channels=12, kernel_size=1, stride=1, padding=0)
input = torch.randn(1, 10, 3, 3)
from thop import profile

In [19]:
profile(conv, inputs=(input,))


[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.


(0, 0)

In [108]:
class build_model(nn.Module):
    def __init__(self):
        super(build_model, self).__init__()
        self.sigmoid = nn.Sigmoid()
        self.relu = nn.ReLU(True)
        
#         self.conv = nn.Sequential(nn.Conv1d(in_channels=1024, out_channels=128, 
#                                           kernel_size=1, stride=1, padding=0,
#                                          bias=False),
#                                 nn.Conv2d(in_channels=512, out_channels=128,
#                                          kernel_size=1, stride=1, padding=0,
#                                          bias=False)
#                              )
        
        self.L1 = nn.Linear(in_features=1024, out_features=128)
#                 nn.ReLU(),
        self.L2 = nn.Linear(in_features=128, out_features=1024)
#                 nn.Sigmoid()            
#         self.bn = nn.BatchNorm2d(128)
#         self.relu = nn.ReLU(inplace=
        
    def forward(self, x):
        x = self.L1(x)
        x = self.L2(x)
        return x

In [110]:
input = torch.randn(1,1024)
conv = build_model()
num_params, flops = compute_model_complexity(conv, input.size(), verbose=True)

  -------------------------------------------------------
  Model complexity with input size torch.Size([1, 1024])
  -------------------------------------------------------
  Linear (params=263,296, flops=263,296)
  -------------------------------------------------------
  Total (params=263,296, flops=263,296)
  -------------------------------------------------------


In [91]:
num_params, flops = compute_model_complexity(conv, input.size(), verbose=True)

  -------------------------------------------------------
  Model complexity with input size torch.Size([1, 1024, 16, 8])
  -------------------------------------------------------
  Conv2d (params=3,407,872, flops=436,207,616)
  BatchNorm2d (params=4,096, flops=0)
  ReLU (params=0, flops=0)
  -------------------------------------------------------
  Total (params=3,411,968, flops=436,207,616)
  -------------------------------------------------------
