In [1]:
import torch
from torch import nn 

# Depth wise seperable convolutions

1. Depthwise convs
2. pointwise convs

In [2]:
class Depthwise_convs(nn.Module):
    def __init__(self,in_channels,out_channels,stride):
        super(Depthwise_convs,self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels,in_channels,
                        kernel_size=(3),
                        stride=stride,
                        padding=1,
                        groups=in_channels,
                        bias=False
                        ),
                        
            nn.BatchNorm2d(in_channels),
            nn.ReLU() )
        
    def forward(self, x: torch.Tensor):
        return self.conv(x)

In [3]:
class pointwise_convs(nn.Module):
    def __init__(self,in_channels,out_channels):
        super(pointwise_convs,self).__init__()
        
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels,out_channels,kernel_size=(1),stride=1,bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        ) 

    def forward(self, x : torch.Tensor):
        return self.conv(x)

In [4]:
class depthwise_seperable_cons(nn.Module):
    def __init__(self, in_channels,out_channels,stride):
        super(depthwise_seperable_cons,self).__init__()

        self.depthwise = Depthwise_convs(in_channels,out_channels,stride)
        self.pointwise = pointwise_convs(in_channels,out_channels)

    def forward(self, x : torch.Tensor):
        x1 = self.depthwise(x)
        x2 = self.pointwise(x1)
        return x2

In [8]:
class MobileNet(nn.Module):
    def __init__(self, num_classes = 1000):
        super(MobileNet,self).__init__()

        self.model = nn.Sequential(

            nn.Conv2d(in_channels=3,
                        out_channels=32,
                        kernel_size=3,
                        stride=2,
                        bias=False),

            nn.ReLU(inplace=True),

            depthwise_seperable_cons(32,64,1),
            depthwise_seperable_cons(64,128,2),
            depthwise_seperable_cons(128,128,1),
            depthwise_seperable_cons(128,256,2),
            depthwise_seperable_cons(256,256,1),
            depthwise_seperable_cons(256,512,2),

            depthwise_seperable_cons(512,512,1),
            depthwise_seperable_cons(512,512,1),
            depthwise_seperable_cons(512,512,1),
            depthwise_seperable_cons(512,512,1),
            depthwise_seperable_cons(512,512,1),
            
            depthwise_seperable_cons(512,1024,2),
            depthwise_seperable_cons(1024,1024,1),

            nn.AdaptiveAvgPool2d(1),
        )

        self.fc = nn.Linear(1024,num_classes)

        self.softmax = nn.Softmax(dim=1)

    def forward(self, x : torch.Tensor):
        x1 = self.model(x)
        print(f'this is the shape after model construction : {x1.shape}')
        x2 = x1.reshape(x1.size(0),-1)
        x3 = self.fc(x2)
        x4 = self.softmax(x3)
        
        return x4
    
mobilenet_instance = MobileNet(num_classes=1000)

In [9]:
img_tensor = torch.randn(1,3,224,224)

output = mobilenet_instance(img_tensor)
output.shape

this is the shape after model construction : torch.Size([1, 1024, 1, 1])


torch.Size([1, 1000])

In [10]:
num_parms = sum(p.numel() for p in mobilenet_instance.parameters() if p.requires_grad)
print(f'the no of trainable params : {num_parms}')

the no of trainable params : 4231912


In [None]:
4,231,912