In [1]:
# MobileNet-style image classifier built from depthwise-separable convolution blocks (PyTorch)

In [2]:
import torch
from torch import nn

In [6]:
# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [22]:
class Conv(nn.Module):
    """Conv2d -> optional BatchNorm2d -> optional LeakyReLU(0.2).

    ``groups`` is a flag, not a group count. When it is truthy the convolution
    is depthwise (``groups=in_channels``) and ``out_channels`` is used as given.
    Otherwise it is a regular convolution whose output width is
    ``int(out_channels * width_multiplier)``.

    Args:
        in_channels: channels of the input feature map.
        out_channels: requested output channels (see ``groups``/``width_multiplier``).
        kernel_size, stride, padding: forwarded positionally to ``nn.Conv2d``.
        use_norm: append a ``BatchNorm2d`` after the convolution.
        use_activation: append a ``LeakyReLU(0.2)`` at the end.
        groups: truthy -> depthwise convolution.
        width_multiplier: channel scale factor for the non-depthwise case.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=(3, 3),
                 stride=(1, 1),
                 padding=1,
                 use_norm=True,
                 use_activation=True,
                 groups=False,
                 width_multiplier=1):
        super().__init__()
        self.use_norm = use_norm
        self.use_activation = use_activation

        if groups:
            # Depthwise: one group per input channel; width multiplier is not applied.
            n_groups, n_filters = in_channels, out_channels
        else:
            n_groups, n_filters = 1, int(out_channels * width_multiplier)

        self.conv1 = nn.Conv2d(in_channels, n_filters, kernel_size, stride, padding,
                               groups=n_groups)

        # Optional stages are only created (and registered as submodules) when enabled.
        if use_norm:
            self.norm = nn.BatchNorm2d(n_filters)
        if use_activation:
            self.activation = nn.LeakyReLU(0.2)

    def forward(self, x):
        out = self.conv1(x)
        out = self.norm(out) if self.use_norm else out
        return self.activation(out) if self.use_activation else out

In [23]:
class MobileNet_Block(nn.Module):
    """Depthwise-separable block: depthwise 3x3 ``Conv`` followed by pointwise 1x1 ``Conv``.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: nominal output channels. The block actually emits
            ``int(out_channels * width_multiplier)`` channels, because the scaling
            is done inside ``Conv``.
        stride: spatial stride, applied by the depthwise 3x3 conv.
        width_multiplier: forwarded to both convs. ``Conv`` ignores it for the
            depthwise (``groups=True``) one.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 stride=(1, 1),
                 width_multiplier=1):
        super(MobileNet_Block, self).__init__()

        # Downsample in the depthwise conv (as MobileNet does), not in the 1x1 conv.
        # A strided 1x1 conv computes the depthwise output at full resolution and,
        # for stride 2, then discards 3 of every 4 positions. A 3x3/pad-1 conv and a
        # 1x1/pad-0 conv both produce floor((H - 1) / s) + 1 outputs, so tensor
        # shapes and parameters are unchanged by moving the stride here.
        self.conv1 = Conv(in_channels,
                          in_channels,
                          stride=stride,
                          groups=True,
                          width_multiplier=width_multiplier)
        self.conv2 = Conv(in_channels,
                          out_channels,
                          kernel_size=(1, 1),
                          padding=0,
                          width_multiplier=width_multiplier)

    def forward(self, x):
        return self.conv2(self.conv1(x))

In [None]:
# Smoke-test a single block: stride 2 halves 224x224 -> 112x112, and the
# pointwise conv maps 3 -> 32 channels (width_multiplier=1).
x = torch.randn(2, 3, 224, 224).to(device)
eb = MobileNet_Block(3, 32, stride=2, width_multiplier=1).to(device)
with torch.no_grad():  # shape check only; no autograd graph needed
    z = eb(x)
assert z.shape == (2, 32, 112, 112), f"unexpected block output shape {tuple(z.shape)}"
z.shape

In [36]:
# Width multiplier used on every other entry below.
dw = 0.7

# Network body specification consumed by Mobile_Net:
#   list  [out_channels, stride, width_multiplier]           -> one MobileNet_Block
#   tuple (repeats, out_channels, stride, width_multiplier)  -> a repeated group of blocks
config = [
    [32, (1, 1), dw],
    [32, (1, 1), 1],
    [64, (2, 2), dw],
    [128, (1, 1), 1],
    [128, (1, 1), dw],
    [128, (1, 1), 1],
    [128, (2, 2), dw],
    [256, (1, 1), 1],
    [256, (1, 1), dw],
    [256, (1, 1), 1],
    [256, (2, 2), dw],
    (5, 512, (1, 1), dw),
    [512, (2, 2), 1],
    [1024, (1, 1), dw],
    [1024, (1, 1), 1],
]

In [43]:
class Mobile_Net(nn.Module):
    """MobileNet-style classifier assembled from ``MobileNet_Block`` layers.

    Layout: a stride-2 stem block (-> 32 channels), then one or more blocks per
    ``config`` entry, global average pooling, and a linear head.

    Args:
        in_channels: channels of the input image tensor.
        out_channels_model: output size of the final ``nn.Linear`` (e.g. number of classes).
        config: body specification. The default is the module-level ``config``
            as it existed when this class was defined. Each entry is either:
              - a list ``[out_channels, stride, width_multiplier]``, giving one block
                that emits ``int(out_channels * width_multiplier)`` channels; or
              - a tuple ``(repeats, out_channels, stride, width_multiplier)``, giving
                ``repeats`` pairs of blocks. The first block of each pair uses
                ``width_multiplier``; the second maps back to ``out_channels``.
                ``stride`` is applied once, on the very first block of the group.

    Raises:
        TypeError: if a config entry is neither a list nor a tuple.
    """

    def __init__(self,
                 in_channels,
                 out_channels_model,
                 config=config):
        super(Mobile_Net, self).__init__()

        self.layers = nn.ModuleList()
        # Stem: stride-2 block from the input image to 32 channels.
        self.layers.append(MobileNet_Block(in_channels, 32, stride=(2, 2)))
        in_channels = 32

        for layer in config:
            if isinstance(layer, list):
                out_channels, stride, width_multiplier = layer
                self.layers.append(MobileNet_Block(
                    in_channels,
                    out_channels,
                    stride,
                    width_multiplier
                ))
                # Must mirror the channel rounding done inside Conv: int(out * wm).
                in_channels = int(out_channels * width_multiplier)

            elif isinstance(layer, tuple):
                repeats, out_channels, stride, width_multiplier = layer
                # The group's stride used to be unpacked and silently dropped. It is
                # now applied to the first block only, so the group downsamples once.
                block_stride = stride
                for _ in range(repeats):
                    self.layers.append(
                        MobileNet_Block(in_channels, out_channels,
                                        stride=block_stride,
                                        width_multiplier=width_multiplier)
                    )
                    block_stride = (1, 1)
                    self.layers.append(
                        MobileNet_Block(int(out_channels * width_multiplier), out_channels)
                    )
                    # Updated per pair so that repeats == 0 leaves in_channels untouched.
                    in_channels = out_channels

            else:
                raise TypeError(
                    "config entries must be a list [out, stride, wm] or a tuple "
                    f"(repeats, out, stride, wm); got {type(layer).__name__}: {layer!r}"
                )

        self.adaptive_avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(in_channels, out_channels_model)

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        x = self.adaptive_avg_pool(x)
        # (N, C, 1, 1) -> (N, C)
        x = self.fc(torch.flatten(x, 1))
        return x

In [47]:
def test(batch_size=2, num_classes=1000, image_size=224):
    """Smoke-test ``Mobile_Net`` end to end.

    Builds ``Mobile_Net(3, num_classes)`` and runs a random 3-channel batch of
    ``image_size`` x ``image_size`` images through it. Prints the output shape
    (as before) and raises ``AssertionError`` unless it is ``(batch_size, num_classes)``.
    """
    x = torch.randn(batch_size, 3, image_size, image_size).to(device)
    mobile_net = Mobile_Net(3, num_classes).to(device)
    # eval(): BatchNorm uses running stats, so tiny batches/feature maps don't error.
    mobile_net.eval()
    with torch.no_grad():
        z = mobile_net(x)
    print(z.shape)
    assert z.shape == (batch_size, num_classes), f"unexpected output shape {tuple(z.shape)}"

In [48]:
# Build the full network on a random batch and print its output shape.
test()

torch.Size([2, 1000])
