In [3]:
import torch.nn as nn
import torch
from torchinfo import summary
import math
from thop import profile

In [1]:
from backbone.efficientnet import efficientnet_b0

In [2]:
model = efficientnet_b0(num_classes=2)

In [4]:
input = torch.randn(1, 3, 224, 224)
output = model(input)
print(output.shape) 

torch.Size([1, 2])


In [5]:
macs, params = profile(model, inputs=(input, ))

[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.batchnorm.BatchNorm2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register count_adap_avgpool() for <class 'torch.nn.modules.pooling.AdaptiveAvgPool2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.


In [6]:
print(f"FLOPs: {macs / 1e9} G")  # 打印计算量（以十亿次浮点运算为单位）  
print(f"Params: {params / 1e6} M")  # 打印参数量（以百万为单位）

FLOPs: 0.411552096 G
Params: 4.01011 M


In [7]:
summary(model, (1, 3, 224, 224))

Layer (type:depth-idx)                        Output Shape              Param #
EfficientNet                                  [1, 2]                    --
├─Sequential: 1-1                             [1, 1280, 7, 7]           --
│    └─ConvBNActivation: 2-1                  [1, 32, 112, 112]         --
│    │    └─Conv2d: 3-1                       [1, 32, 112, 112]         864
│    │    └─BatchNorm2d: 3-2                  [1, 32, 112, 112]         64
│    │    └─ReLU: 3-3                         [1, 32, 112, 112]         --
│    └─InvertedResidual: 2-2                  [1, 16, 112, 112]         --
│    │    └─Sequential: 3-4                   [1, 16, 112, 112]         1,448
│    │    └─Identity: 3-5                     [1, 16, 112, 112]         --
│    └─InvertedResidual: 2-3                  [1, 24, 56, 56]           --
│    │    └─Sequential: 3-6                   [1, 24, 56, 56]           6,004
│    │    └─Identity: 3-7                     [1, 24, 56, 56]           --
│    └─Invert