In [66]:
import torch
import torch.nn as nn

import torch
import torch.nn as nn

class DepthwiseSeparableConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super(DepthwiseSeparableConv, self).__init__()
        self.depthwise = nn.Conv2d(in_channels, in_channels, kernel_size=kernel_size, stride=stride, padding=padding, groups=in_channels)
        self.bn1 = nn.BatchNorm2d(in_channels)

        self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()

    def forward(self, x):

        x = self.depthwise(x)
        x = self.bn1(x)
        x = self.relu(x)
        
        x = self.pointwise(x)
        x = self.bn2(x)
        x = self.relu(x)
        return x

class CustomDecoder(nn.Module):
    def __init__(self, num_classes=8):
        super(CustomDecoder, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1)
        self.dconv1 = DepthwiseSeparableConv(32, 64, kernel_size=3, stride=1, padding=1)
        self.dconv2 = DepthwiseSeparableConv(64, 128, kernel_size=3, stride=2, padding=1)
        self.dconv3 = DepthwiseSeparableConv(128, 128, kernel_size=3, stride=1, padding=1)
        self.dconv4 = DepthwiseSeparableConv(128, 256, kernel_size=3, stride=2, padding=1)
        self.dconv5 = DepthwiseSeparableConv(256, 256, kernel_size=3, stride=1, padding=1)
        
        self.dconv6 = DepthwiseSeparableConv(256, 512, kernel_size=3, stride=2, padding=1)
        self.fivex_dconv = nn.ModuleList([DepthwiseSeparableConv(512, 512, kernel_size=3, stride=1, padding=1) for _ in range(5)])
        
        self.dconv7 = DepthwiseSeparableConv(512, 1024, kernel_size=3, stride=2, padding=1)
        self.dconv8 = DepthwiseSeparableConv(1024, 1024, kernel_size=3, stride=1, padding=1)
        
        self.dropout_dense = nn.Dropout(p=0.3)
        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(1024, num_classes)
        self.softmax = nn.Softmax()

    def forward(self, x):
        x = self.conv1(x)
        x = self.dconv1(x)
        x = self.dconv2(x)
        x = self.dconv3(x)
        x = self.dconv4(x)
        x = self.dconv5(x)
        x = self.dconv6(x)
        for dconv_layer in self.fivex_dconv: x = dconv_layer(x)   
        x = self.dconv7(x)
        x = self.dconv8(x)
        x = self.global_avg_pool(x)
        x = self.dropout_dense(x)
        x = self.fc(x.reshape(32,1024))
        return self.softmax(x)


# Create an instance of the model
model = CustomDecoder(num_classes=8)
inputs = torch.randint(0, 255, (32, 3, 128, 128),dtype=torch.float32)
out = model(inputs)

  return self.softmax(x)


In [67]:
out.shape

torch.Size([32, 8])