In [7]:
import torch
from torch import nn
from torchinfo import summary
from torchvision import transforms, datasets
from torchvision.ops import StochasticDepth
import math

In [8]:
class SEBlock(nn.Module):
    def __init__(self, in_channels, squeeze_channels):
        super().__init__()

        self.squeeze = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Sequential(
            nn.Linear(in_channels, squeeze_channels),
            nn.SiLU(),
            nn.Linear(squeeze_channels, in_channels),
            nn.Sigmoid()
        )
    
    def forward(self, x):
        SE = self.squeeze(x)
        SE = SE.reshape(x.shape[0], x.shape[1])
        
        SE = self.fc(SE)
        SE = SE.unsqueeze(dim=2).unsqueeze(dim=3)

        x = x*SE
        return x

In [10]:
class DepSepConv(nn.Module):
    def __init__(self, in_channels, squeeze_channels, out_channels, kernel_size, stride):
        super().__init__()

        self.depthwise = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, kernel_size, stride=stride, padding=(kernel_size-1)//2, groups=in_channels, bias=False),
            nn.BatchNorm2d(in_channels),    
            nn.SiLU()
        )

        self.SE_blk = SEBlock(in_channels, squeeze_channels)

        self.seperate = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels)
        )

    def forward(self, x):
        x = self.depthwise(x)
        x = self.SE_blk(x)
        x = self.seperate(x)
        return x

In [20]:
class MBConv(nn.Module):
    def __init__(self, in_channels, exp_channels, out_channels, kernel_size, stride, sd_prob):
        super().__init__()

        self.use_skip = True if stride==1 and in_channels==out_channels else False

        # residual을 확률적으로 건너뛰게 하는 것 (drop_out과 유사)
        # sd_prob가 죽일 확률 # "row"는 data 마다 다른 depth를 가지게 함!
        self.stochastic_depth = StochasticDepth(sd_prob, "row") 
        
        layers = []

        if in_channels!=exp_channels:
            layers += [nn.Sequential(
                nn.Conv2d(in_channels, exp_channels, 1, bias=False),
                nn.BatchNorm2d(exp_channels),
                nn.SiLU()
            )]
        
        layers += [DepSepConv(exp_channels, in_channels//4, out_channels, kernel_size, stride=stride)]

        self.residual = nn.Sequential(*layers)
    
    def forward(self, x):
        if self.use_skip==True:
            residual = self.residual(x)
            residual = self.stochastic_depth(residual)
            return x + residual
        else:
            return self.residual(x)

In [21]:
# 가까운 8의 배수를 찾아줌
def _make_divisible(v, divisor, min_value=None):
    if min_value is None:
        min_value = divisor

    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 
    if new_v < 0.9 * v: 
        new_v += divisor

    return new_v

class EfficientNet(nn.Module):
    def __init__(self, num_classes, depth_mult, width_mult, resize_size, crop_size, stochastic_depth_p=0.2):
        super().__init__()

        cfgs = [#k,  t,   c,  n,  s
                [3,  1,  16,  1,  1],
                [3,  6,  24,  2,  2],
                [5,  6,  40,  2,  2],
                [3,  6,  80,  3,  2],
                [5,  6,  112, 3,  1],
                [5,  6,  192, 4,  2],
                [3,  6,  320, 1,  1]]

        in_channels = _make_divisible(32 * width_mult, 8) # width 조절!

        self.transforms = transforms.Compose([transforms.Resize(resize_size, interpolation=transforms.InterpolationMode.BICUBIC),
                                              transforms.CenterCrop(crop_size),
                                              transforms.ToTensor(),
                                              transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) # resolution 조절!

        self.stem_conv = nn.Sequential(nn.Conv2d(3, in_channels, 3, padding=1, stride=2, bias=False),
                                       nn.BatchNorm2d(in_channels, momentum=0.99, eps=1e-3),
                                       nn.SiLU(inplace=True))

        layers = []
        num_block = 0
        N = sum([math.ceil(cfg[-2] * depth_mult) for cfg in cfgs]) # 총 깊이

        for k, t, c, n, s in cfgs:
            n = math.ceil(n * depth_mult) # depth 조절
            for i in range(n):
                stride = s if i == 0 else 1
                exp_channels = _make_divisible(in_channels * t, 8)
                out_channels = _make_divisible(c * width_mult, 8) # width 조절
                sd_prob = stochastic_depth_p * num_block / (N-1) # 뒤로 갈수록 건너 뛸 확률을 크게
                
                layers += [MBConv(in_channels, exp_channels, out_channels, k, stride, sd_prob)]
                in_channels = out_channels
                num_block += 1

        self.layers = nn.Sequential(*layers)

        # building last several layers
        last_channels = _make_divisible(1280 * width_mult, 8) # width 조절
        self.last_conv = nn.Sequential(nn.Conv2d(in_channels, last_channels, 1, bias=False),
                                       nn.BatchNorm2d(last_channels, momentum=0.99, eps=1e-3),
                                       nn.SiLU(inplace=True))

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        self.classifier = nn.Linear(last_channels, num_classes)
    
    def forward(self, x):
        x = self.stem_conv(x)
        x = self.layers(x)
        x = self.last_conv(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

In [22]:

def efficientnet_b0(num_classes=1000):
    return EfficientNet(num_classes=num_classes, depth_mult=1.0, width_mult=1.0, resize_size=256, crop_size=224)

def efficientnet_b1(num_classes=1000):
    return EfficientNet(num_classes=num_classes, depth_mult=1.1, width_mult=1.0, resize_size=256, crop_size=240)

def efficientnet_b2(num_classes=1000):
    return EfficientNet(num_classes=num_classes, depth_mult=1.2, width_mult=1.1, resize_size=288, crop_size=288)

def efficientnet_b3(num_classes=1000):
    return EfficientNet(num_classes=num_classes, depth_mult=1.4, width_mult=1.2, resize_size=320, crop_size=3003)

def efficientnet_b4(num_classes=1000):
    return EfficientNet(num_classes=num_classes, depth_mult=1.8, width_mult=1.4, resize_size=384, crop_size=380)

def efficientnet_b5(num_classes=1000):
    return EfficientNet(num_classes=num_classes, depth_mult=2.2, width_mult=1.6, resize_size=456, crop_size=456)

def efficientnet_b6(num_classes=1000):
    return EfficientNet(num_classes=num_classes, depth_mult=2.6, width_mult=1.8, resize_size=528, crop_size=528)

def efficientnet_b7(num_classes=1000):
    return EfficientNet(num_classes=num_classes, depth_mult=3.1, width_mult=2.0, resize_size=600, crop_size=600)

In [23]:
model = efficientnet_b7()
summary(model, input_size=(2, 3, 456,456), device='cpu')

Layer (type:depth-idx)                                  Output Shape              Param #
EfficientNet                                            [2, 1000]                 --
├─Sequential: 1-1                                       [2, 64, 228, 228]         --
│    └─Conv2d: 2-1                                      [2, 64, 228, 228]         1,728
│    └─BatchNorm2d: 2-2                                 [2, 64, 228, 228]         128
│    └─SiLU: 2-3                                        [2, 64, 228, 228]         --
├─Sequential: 1-2                                       [2, 640, 15, 15]          --
│    └─MBConv: 2-4                                      [2, 32, 228, 228]         --
│    │    └─Sequential: 3-1                             [2, 32, 228, 228]         4,944
│    └─MBConv: 2-5                                      [2, 32, 228, 228]         --
│    │    └─Sequential: 3-2                             [2, 32, 228, 228]         1,992
│    │    └─StochasticDepth: 3-3                  