In [1]:
import os, sys
import torch
import torchvision
import torch.nn as nn

### Defining the depthwise separable and depthwise convolution blocks

In [21]:
def get_num_parameters(model, trainable_only=False):
    """Return the total number of scalar parameters in ``model``.

    Args:
        model: an ``nn.Module`` whose parameters are counted.
        trainable_only: if True, count only parameters with
            ``requires_grad=True``. Defaults to False, which counts every
            parameter (the original behaviour).

    Returns:
        int: the sum of ``numel()`` over the selected parameters (0 for a
        module without parameters).
    """
    # Parameters shared between submodules are yielded once by
    # ``parameters()``, so they are not double counted.
    return sum(
        param.numel()
        for param in model.parameters()
        if param.requires_grad or not trainable_only
    )

In [63]:
class depthwise_separable_conv(nn.Module):
    """Depthwise separable 2-D convolution: a depthwise conv followed by a 1x1 pointwise conv.

    The depthwise stage applies ``kernels_per_layer`` spatial filters to every
    input channel independently (``groups=in_channels``). The pointwise stage
    mixes the resulting ``kernels_per_layer * in_channels`` channels into
    ``out_channels`` with a 1x1 convolution.

    Args:
        in_channels: number of channels of the input tensor.
        kernels_per_layer: depth multiplier, i.e. filters per input channel in
            the depthwise stage.
        out_channels: number of channels produced by the pointwise stage.
        kernel_size: depthwise kernel size (default ``(3, 3)``).
        stride: depthwise stride (default ``(1, 1)``).
        padding: depthwise padding (default ``(1, 1)``). With the default
            3x3 kernel and stride 1, this keeps the spatial size unchanged.
        bias: whether both convolutions carry a bias term (default True).
    """

    def __init__(self, in_channels, kernels_per_layer, out_channels,
                 kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=True):
        super().__init__()
        hidden_channels = kernels_per_layer * in_channels

        # groups=in_channels -> each input channel is convolved with its own filters.
        self.depthwise_conv = nn.Conv2d(in_channels=in_channels,
                                        out_channels=hidden_channels,
                                        kernel_size=kernel_size,
                                        stride=stride,
                                        padding=padding,
                                        groups=in_channels,
                                        bias=bias)
        # 1x1 conv combines information across channels.
        self.pointwise_conv = nn.Conv2d(in_channels=hidden_channels,
                                        out_channels=out_channels,
                                        kernel_size=(1, 1),
                                        stride=(1, 1),
                                        padding=(0, 0),
                                        bias=bias)

    def forward(self, x):
        """Apply the depthwise convolution, then the pointwise convolution, to ``x``."""
        return self.pointwise_conv(self.depthwise_conv(x))
    
class depthwise_conv(nn.Module):
    """Grouped 2-D convolution; with ``groups == in_channels`` it is a depthwise convolution.

    The layer produces ``kernels_per_layer * in_channels`` output channels.

    Args:
        in_channels: number of channels of the input tensor.
        kernels_per_layer: depth multiplier (filters per input channel when
            depthwise).
        out_channels: accepted for signature compatibility but NOT used. The
            output channel count is always ``kernels_per_layer * in_channels``.
        groups: number of groups passed to ``nn.Conv2d``; it must be positive
            and divide ``in_channels``. ``None`` (the default) means
            ``in_channels``, i.e. a true depthwise convolution.
        **kwarg: forwarded to ``nn.Conv2d`` (e.g. ``kernel_size``, ``stride``,
            ``padding``). ``nn.Conv2d`` requires ``kernel_size``.

    Raises:
        ValueError: if ``groups`` is not positive or does not divide
            ``in_channels``.
    """

    def __init__(self, in_channels, kernels_per_layer, out_channels=None, groups=None, **kwarg):
        super().__init__()

        if groups is None:
            groups = in_channels
        # Explicit checks instead of `assert`, which is stripped under `python -O`.
        if groups <= 0:
            raise ValueError("groups must be a positive integer")
        if in_channels % groups != 0:
            raise ValueError("Groups Error: in_channels should be divisible by groups")

        # Use the validated `groups`; it was previously ignored in favour of a
        # hard-coded groups=in_channels, which made the check above meaningless.
        # kernels_per_layer * in_channels is always divisible by any valid `groups`.
        self.depthwise_conv = nn.Conv2d(in_channels=in_channels,
                                        out_channels=kernels_per_layer * in_channels,
                                        groups=groups,
                                        **kwarg)

    def forward(self, x):
        """Apply the (grouped/depthwise) convolution to ``x``."""
        return self.depthwise_conv(x)

### Testing Depthwise Separable Convolution on MNIST

In [64]:
# Load a single batch of MNIST training data to use as test input for the layers below.
mnist_trainset = torchvision.datasets.MNIST(root='./data',
                                            train=True,
                                            download=True,
                                            transform=torchvision.transforms.ToTensor())
mnist_loader = torch.utils.data.DataLoader(dataset=mnist_trainset, batch_size=64)
# x, y: the first batch of images and targets. Use the built-in iteration
# protocol rather than calling the dunder methods directly.
x, y = next(iter(mnist_loader))

In [68]:
# Build the depthwise separable block and a standard 3x3 convolution with the
# same input/output channels, so their outputs and parameter counts can be compared.
model = depthwise_separable_conv(
    in_channels=x.size(1),
    kernels_per_layer=1,
    out_channels=64,
)
model_original_conv = nn.Conv2d(
    in_channels=x.size(1),
    out_channels=64,
    kernel_size=(3, 3),
    padding=(1, 1),
)

In [69]:
# Run the same batch through both layers and compare the output shapes.
out, out_original_conv = model(x), model_original_conv(x)
out.size(), out_original_conv.size()

(torch.Size([64, 64, 28, 28]), torch.Size([64, 64, 28, 28]))

In [77]:
# Compare parameter counts of the separable block and the standard convolution.
paras = get_num_parameters(model)  # depthwise (3*3*1 + 1) + pointwise (1*1*1*64 + 64) = 138
paras_original_conv = get_num_parameters(model_original_conv)  # 3*3*1*64 + 64 = 640
paras, paras_original_conv

(138, 640)

In [72]:
# List every parameter tensor of the separable block with its element count.
for param_name, param in model.named_parameters():
    print(f"{param_name} {param.numel()}")

depthwise_conv.weight 9
depthwise_conv.bias 1
pointwise_conv.weight 64
pointwise_conv.bias 64


### Testing Depthwise Convolution on MNIST

In [55]:
# NOTE(review): `model_depthwise` is only created in the cell below (In [76]).
# This cell appears to have run against an object left over in the kernel, so it
# raises NameError under Restart & Run All. Move it after the construction cell.
for param_name, _ in model_depthwise.named_parameters():
    print(param_name)

depthwise_conv.weight
depthwise_conv.bias


In [76]:
# Plain depthwise convolution (no pointwise stage) on the same input.
# Note: depthwise_conv does not use `out_channels`; it outputs
# kernels_per_layer * in_channels channels.
model_depthwise = depthwise_conv(
    in_channels=x.size(1),
    kernels_per_layer=1,
    out_channels=64,
    groups=x.size(1),
    kernel_size=(3, 3),
    padding=(0, 0),
    stride=(1, 1),
)

paras_depthwise = get_num_parameters(model_depthwise)  # 3*3*1 + 1 = 10
paras_depthwise

10

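### References

1. [How to modify a Conv2d to depthwise separable convolution (PyTorch Forums)](https://discuss.pytorch.org/t/how-to-modify-a-conv2d-to-depthwise-separable-convolution/15843/10)
2. [MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications (Howard et al., 2017)](https://arxiv.org/pdf/1704.04861.pdf)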