In [1]:
import torch
from torchvision import transforms as T

import torch.optim as optim

from dataset import build_cifar
from utils import Trainer, plot_sampledata, plot_misclassified

In [11]:
from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchinfo import summary


class BaseNet(nn.Module):
    """``nn.Module`` base class that adds a torchinfo-based ``summarize`` helper."""

    def summarize(self, device: torch.device, input_size: tuple = (1, 3, 32, 32)):
        """Move the model to ``device`` (in place) and print its torchinfo summary.

        Args:
            device: device the model is moved to before the summary is computed.
            input_size: shape of the dummy input forwarded to ``torchinfo.summary``.
        """
        placed_model = self.to(device)
        report = summary(placed_model, input_size=input_size)
        print(report)


# depthwise separable convolution
class DepthwiseSeparable(nn.Module):
    """Depthwise-separable convolution: a per-channel conv followed by a 1x1 conv.

    Args:
        nin: number of input channels (also the group count of the depthwise conv).
        nout: number of output channels produced by the pointwise conv.
        kernel_size: kernel size of the depthwise conv.
        padding: padding of the depthwise conv.
        bias: whether both convolutions carry a bias term.
    """

    def __init__(self, nin, nout, kernel_size=3, padding=1, bias=False):
        super().__init__()

        # groups=nin -> each input channel is convolved with its own filter
        self.depthwise = nn.Conv2d(
            nin,
            nin,
            kernel_size=kernel_size,
            padding=padding,
            groups=nin,
            bias=bias,
        )
        # 1x1 conv mixes the channels and maps nin -> nout
        self.pointwise = nn.Conv2d(nin, nout, kernel_size=1, bias=bias)

    def forward(self, x):
        return self.pointwise(self.depthwise(x))


class Net(BaseNet):
    """10-class image classifier: three conv blocks, 1x1 transitions and a 1x1-conv head.

    The spatial-size comments below assume a 3x32x32 input (the default used by
    ``summarize``); the global-average-pooling head also accepts other sizes as
    long as the convolutions leave a non-empty feature map.

    Args:
        drop: probability passed to every ``nn.Dropout2d`` layer (0 disables dropout).

    ``forward`` returns log-probabilities of shape ``(batch, 10)``.
    """

    def __init__(self, drop: float = 0):
        super(Net, self).__init__()

        # Block 1: 32x32 -> 16x16 (only the last conv has stride 2)
        self.layer1 = nn.Sequential(
            nn.Conv2d(3, 16, 3, padding=1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Dropout2d(drop),
            DepthwiseSeparable(16, 32, 3, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Dropout2d(drop),
            nn.Conv2d(32, 64, 3, padding=1, stride=2, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Dropout2d(drop),
        )  # 64 x 16 x 16

        # 1x1 conv squeezes channels 64 -> 8 before the next block
        self.transition1 = nn.Sequential(
            nn.Conv2d(64, 8, 1, bias=False),
            nn.ReLU()
        )
        # Block 2: the unpadded dilated conv (dilation=3 -> effective kernel 7)
        # shrinks 16x16 -> 10x10
        self.layer2 = nn.Sequential(
            nn.Conv2d(8, 16, 3, padding=1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Dropout2d(drop),
            DepthwiseSeparable(16, 32, 3, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Dropout2d(drop),
            nn.Conv2d(32, 64, 3, dilation=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Dropout2d(drop),
        )  # 64 x 10 x 10
        # 1x1 conv squeezes channels 64 -> 16
        self.transition2 = nn.Sequential(
            nn.Conv2d(64, 16, 1, bias=False),
            nn.ReLU()
        )

        # Block 3: three unpadded 3x3 convs shrink 10x10 -> 8x8 -> 6x6 -> 4x4
        self.layer3 = nn.Sequential(
            nn.Conv2d(16, 32, 3, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Dropout2d(drop),
            DepthwiseSeparable(32, 64, 3, padding=0, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Dropout2d(drop),
            DepthwiseSeparable(64, 128, 3, padding=0, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Dropout2d(drop),
        )  # 128 x 4 x 4; theoretical receptive field per position: 39 px

        # Output head: global average pooling to 1x1 (identical to AvgPool2d(4) on
        # the 4x4 map above, but independent of the final map size), then a 1x1
        # conv acts as the 128 -> 10 classifier. Module indices (out.0 / out.1)
        # are unchanged, so existing state_dicts still load.
        self.out = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(
                in_channels=128, out_channels=10, kernel_size=(1, 1), bias=False
            ),
        )

    def forward(self, x):
        """Return per-class log-probabilities of shape ``(batch, 10)``."""
        x = self.layer1(x)
        x = self.transition1(x)
        x = self.layer2(x)
        x = self.transition2(x)
        x = self.layer3(x)

        x = self.out(x)  # (batch, 10, 1, 1)
        # flatten(1) always keeps the batch axis; view(-1, 10) would silently fold
        # leftover spatial positions into the batch dimension.
        x = x.flatten(1)

        return F.log_softmax(x, dim=1)


In [12]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")

# Build the network with its default dropout and place it on that device.
model = Net().to(device)

In [13]:
model.summarize(device, input_size=(1, 3, 32, 32))  # summary for one 3x32x32 image


Layer (type:depth-idx)                   Output Shape              Param #
Net                                      [36, 10]                  --
├─Sequential: 1-1                        [1, 64, 28, 28]           --
│    └─Conv2d: 2-1                       [1, 16, 32, 32]           432
│    └─BatchNorm2d: 2-2                  [1, 16, 32, 32]           32
│    └─ReLU: 2-3                         [1, 16, 32, 32]           --
│    └─Dropout2d: 2-4                    [1, 16, 32, 32]           --
│    └─DepthwiseSeparable: 2-5           [1, 32, 32, 32]           --
│    │    └─Conv2d: 3-1                  [1, 16, 32, 32]           144
│    │    └─Conv2d: 3-2                  [1, 32, 32, 32]           512
│    └─BatchNorm2d: 2-6                  [1, 32, 32, 32]           64
│    └─ReLU: 2-7                         [1, 32, 32, 32]           --
│    └─Dropout2d: 2-8                    [1, 32, 32, 32]           --
│    └─Conv2d: 2-9                       [1, 64, 28, 28]           18,432
│    └─B