# Import

In [None]:
import os
from functools import partial
from typing import Any, Optional

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Function

import torch.utils.data as data

import torchvision
from torchvision import datasets
import torchvision.transforms as transforms

import numpy as np

import time
import matplotlib.pyplot as plt

# Seed PyTorch's global random number generator with a fixed value so that
# randomly initialised weights are repeatable from run to run.
# NOTE(review): only torch's RNG is seeded here; NumPy / Python `random`
# are not seeded anywhere in the visible code.
torch.manual_seed(37)

<torch._C.Generator at 0x7f950568a9b0>

# Separable Convolution

The implementation of DepthwiseConv2D, PointwiseConv2D and SeparableConv2D.

In [None]:
class DepthwiseConv2D(nn.Module):
    """Module wrapper around a single ``nn.Conv2d`` layer.

    All constructor arguments are handed to ``nn.Conv2d`` untouched, so they
    carry exactly the meaning documented there. Note that the convolution is
    only *depthwise* when the caller passes ``groups=in_channels``; with the
    default ``groups=1`` this is an ordinary dense convolution.

    The wrapped layer is stored as ``self.depthwise_conv``.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 padding_mode='zeros',
                 device=None,
                 dtype=None) -> None:
        super().__init__()

        # Gather the settings once, then build the convolution from them.
        conv_settings = dict(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
            device=device,
            dtype=dtype,
        )
        self.depthwise_conv = nn.Conv2d(**conv_settings)

        # Extension point: to experiment with normalisation inside this
        # wrapper, register ``nn.BatchNorm2d(out_channels)`` or
        # ``nn.GroupNorm(16, out_channels)`` here and apply it to the
        # convolution output in ``forward``. Neither is enabled.

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the wrapped convolution to ``x`` and return the result."""
        return self.depthwise_conv(x)


class PointwiseConv2D(nn.Module):
    """1x1 convolution that mixes channels without touching spatial layout.

    Wraps ``nn.Conv2d`` with ``kernel_size=(1, 1)``, stride 1, no padding,
    no dilation and a single group; only the channel counts, the bias flag
    and the ``device``/``dtype`` factory arguments are configurable.

    The wrapped layer is stored as ``self.pointwise_conv``.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 bias=True,
                 device=None,
                 dtype=None) -> None:
        super().__init__()

        # Fixed geometry of a pointwise convolution.
        fixed_geometry = dict(kernel_size=(1, 1),
                              stride=1,
                              padding=0,
                              dilation=1,
                              groups=1,
                              padding_mode='zeros')

        self.pointwise_conv = nn.Conv2d(in_channels=in_channels,
                                        out_channels=out_channels,
                                        bias=bias,
                                        device=device,
                                        dtype=dtype,
                                        **fixed_geometry)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the 1x1 convolution to ``x`` and return the result."""
        return self.pointwise_conv(x)
        

class SeparableConv2D(nn.Module):
    """Depthwise convolution followed by a 1x1 pointwise convolution.

    The first stage is a ``DepthwiseConv2D`` with ``groups=in_channels`` and
    ``out_channels=in_channels`` (one spatial filter per input channel); the
    second stage is a ``PointwiseConv2D`` that maps ``in_channels`` to
    ``out_channels``. ``bias`` is applied to both stages.

    Each stage is reachable under two attribute names (``layer1`` /
    ``depthwise_conv`` and ``layer2`` / ``pointwise_conv``) and both are also
    wrapped in ``self.seperable`` (spelling kept as-is, since it is an
    attribute name). All of these refer to the same two module objects.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 bias=True,
                 padding_mode='zeros',
                 device=None,
                 dtype=None) -> None:
        super().__init__()

        # Stage 1: spatial filtering, one filter group per input channel.
        spatial_stage = DepthwiseConv2D(in_channels=in_channels,
                                        out_channels=in_channels,
                                        kernel_size=kernel_size,
                                        stride=stride,
                                        padding=padding,
                                        dilation=dilation,
                                        groups=in_channels,
                                        bias=bias,
                                        padding_mode=padding_mode,
                                        device=device,
                                        dtype=dtype)

        # Stage 2: 1x1 channel mixing up/down to the requested width.
        channel_stage = PointwiseConv2D(in_channels=in_channels,
                                        out_channels=out_channels,
                                        bias=bias,
                                        device=device,
                                        dtype=dtype)

        # Register the stages under every historical attribute name, in the
        # same order as before, so existing attribute access keeps working.
        self.layer1 = spatial_stage
        self.depthwise_conv = spatial_stage
        self.layer2 = channel_stage
        self.pointwise_conv = channel_stage
        self.seperable = nn.Sequential(spatial_stage, channel_stage)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the depthwise stage, then the pointwise stage."""
        out = self.seperable(x)
        return out

# AlexNet

## AlexNet(Normal)


In [None]:
class AlexNet(nn.Module):
    """AlexNet-style classifier built from plain convolutions.

    Five ReLU-activated convolutions (max-pooled after the 1st, 2nd and 5th)
    feed three fully connected layers that end in 10 logits. ``fc1`` expects
    9216 flattened features (256 channels x 6 x 6), which is what a
    3 x 227 x 227 input produces with these layer settings.
    """

    def __init__(self):
        super().__init__()

        # Feature extractor. One pooling layer is shared by all pooling steps.
        self.conv1 = nn.Conv2d(3, 96, kernel_size=11, stride=4, padding=0)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2)
        self.conv2 = nn.Conv2d(96, 256, kernel_size=5, stride=1, padding=2)
        self.conv3 = nn.Conv2d(256, 384, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(384, 384, kernel_size=3, stride=1, padding=1)
        self.conv5 = nn.Conv2d(384, 256, kernel_size=3, stride=1, padding=1)

        # Classifier head.
        self.fc1 = nn.Linear(9216, 4096)
        self.fc2 = nn.Linear(4096, 4096)
        self.fc3 = nn.Linear(4096, 10)

    def forward(self, x):
        """Return the class logits for a batch of images ``x``."""
        # (convolution, pool afterwards?) in execution order.
        conv_plan = (
            (self.conv1, True),
            (self.conv2, True),
            (self.conv3, False),
            (self.conv4, False),
            (self.conv5, True),
        )
        for conv, pool_after in conv_plan:
            x = F.relu(conv(x))
            if pool_after:
                x = self.maxpool(x)

        # Flatten everything except the batch dimension.
        features = x.reshape(x.shape[0], -1)

        for hidden_layer in (self.fc1, self.fc2):
            features = F.relu(hidden_layer(features))

        return self.fc3(features)

## AlexNet(SeparableConv2D + BN)

In [None]:
class AlexNet_S(nn.Module):
    """AlexNet variant with separable convolutions and batch normalisation.

    ``conv1`` and ``conv2`` remain ordinary ``nn.Conv2d`` layers; ``conv3``
    to ``conv5`` are ``SeparableConv2D``. Every convolution is followed by
    ReLU and then ``BatchNorm2d`` (eps=0.001). The head is three linear
    layers ending in 10 logits; ``fc1`` expects 9216 flattened features.
    """

    def __init__(self):
        super().__init__()

        # Dense convolutions, each paired with its batch-norm layer.
        self.conv1 = nn.Conv2d(3, 96, kernel_size=11, stride=4, padding=0)
        self.bn1 = nn.BatchNorm2d(num_features=96, eps=1e-3)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2)
        self.conv2 = nn.Conv2d(96, 256, kernel_size=5, stride=1, padding=2)
        self.bn2 = nn.BatchNorm2d(num_features=256, eps=1e-3)

        # Separable convolutions, each paired with its batch-norm layer.
        self.conv3 = SeparableConv2D(256, 384, kernel_size=3, stride=1, padding=1)
        self.bn3 = nn.BatchNorm2d(num_features=384, eps=1e-3)
        self.conv4 = SeparableConv2D(384, 384, kernel_size=3, stride=1, padding=1)
        self.bn4 = nn.BatchNorm2d(num_features=384, eps=1e-3)
        self.conv5 = SeparableConv2D(384, 256, kernel_size=3, stride=1, padding=1)
        self.bn5 = nn.BatchNorm2d(num_features=256, eps=1e-3)

        # Classifier head.
        self.fc1 = nn.Linear(9216, 4096)
        self.fc2 = nn.Linear(4096, 4096)
        self.fc3 = nn.Linear(4096, 10)

    def forward(self, x):
        """Return the class logits for a batch of images ``x``."""
        # (convolution, normalisation, pool afterwards?) in execution order.
        blocks = (
            (self.conv1, self.bn1, True),
            (self.conv2, self.bn2, True),
            (self.conv3, self.bn3, False),
            (self.conv4, self.bn4, False),
            (self.conv5, self.bn5, True),
        )
        for conv, norm, pool_after in blocks:
            # Normalisation is applied to the activated output (conv -> ReLU -> BN).
            x = norm(F.relu(conv(x)))
            if pool_after:
                x = self.maxpool(x)

        # Flatten everything except the batch dimension.
        features = x.reshape(x.shape[0], -1)

        for hidden_layer in (self.fc1, self.fc2):
            features = F.relu(hidden_layer(features))

        return self.fc3(features)

## AlexNet(SeparableConv2D + GN)

In [None]:
class AlexNet_L(nn.Module):
    """AlexNet variant with separable convolutions and group normalisation.

    ``conv1`` and ``conv2`` remain ordinary ``nn.Conv2d`` layers; ``conv3``
    to ``conv5`` are ``SeparableConv2D``. Every convolution is followed by
    ReLU and then ``GroupNorm`` with 16 groups (eps=0.001). The head is three
    linear layers ending in 10 logits; ``fc1`` expects 9216 flattened
    features.
    """

    def __init__(self):
        super().__init__()

        # Dense convolutions, each paired with its group-norm layer.
        self.conv1 = nn.Conv2d(3, 96, kernel_size=11, stride=4, padding=0)
        self.gn1 = nn.GroupNorm(num_groups=16, num_channels=96, eps=1e-3)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2)
        self.conv2 = nn.Conv2d(96, 256, kernel_size=5, stride=1, padding=2)
        self.gn2 = nn.GroupNorm(num_groups=16, num_channels=256, eps=1e-3)

        # Separable convolutions, each paired with its group-norm layer.
        self.conv3 = SeparableConv2D(256, 384, kernel_size=3, stride=1, padding=1)
        self.gn3 = nn.GroupNorm(num_groups=16, num_channels=384, eps=1e-3)
        self.conv4 = SeparableConv2D(384, 384, kernel_size=3, stride=1, padding=1)
        self.gn4 = nn.GroupNorm(num_groups=16, num_channels=384, eps=1e-3)
        self.conv5 = SeparableConv2D(384, 256, kernel_size=3, stride=1, padding=1)
        self.gn5 = nn.GroupNorm(num_groups=16, num_channels=256, eps=1e-3)

        # Classifier head.
        self.fc1 = nn.Linear(9216, 4096)
        self.fc2 = nn.Linear(4096, 4096)
        self.fc3 = nn.Linear(4096, 10)

    def forward(self, x):
        """Return the class logits for a batch of images ``x``."""
        # (convolution, normalisation, pool afterwards?) in execution order.
        blocks = (
            (self.conv1, self.gn1, True),
            (self.conv2, self.gn2, True),
            (self.conv3, self.gn3, False),
            (self.conv4, self.gn4, False),
            (self.conv5, self.gn5, True),
        )
        for conv, norm, pool_after in blocks:
            # Normalisation is applied to the activated output (conv -> ReLU -> GN).
            x = norm(F.relu(conv(x)))
            if pool_after:
                x = self.maxpool(x)

        # Flatten everything except the batch dimension.
        features = x.reshape(x.shape[0], -1)

        for hidden_layer in (self.fc1, self.fc2):
            features = F.relu(hidden_layer(features))

        return self.fc3(features)

# VGG16

## VGG-16(Normal)

In [None]:
class VGG16(nn.Module):
    """VGG-16 classifier built from plain 3x3 convolutions.

    Thirteen ReLU-activated convolutions are arranged in five stages
    (2, 2, 3, 3, 3 layers), each stage ending in a 2x2 max-pool. The head is
    three linear layers ending in 10 logits, with dropout (p=0.5) after the
    first two. ``fc1`` expects 25088 flattened features (512 x 7 x 7), i.e.
    inputs whose spatial size reduces to 7x7 after five poolings, such as
    224x224.

    Dropout follows the module's mode: it is active after ``train()`` and
    disabled after ``eval()``.
    """

    def __init__(self):
        super(VGG16, self).__init__()
        self.conv1_1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, padding=1)
        self.conv1_2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1)

        self.conv2_1 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1)
        self.conv2_2 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1)

        self.conv3_1 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1)
        self.conv3_2 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1)
        self.conv3_3 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1)

        self.conv4_1 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, padding=1)
        self.conv4_2 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1)
        self.conv4_3 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1)

        self.conv5_1 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1)
        self.conv5_2 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1)
        self.conv5_3 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1)

        # One pooling layer is shared by all five stages.
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

        self.fc1 = nn.Linear(25088, 4096)
        self.fc2 = nn.Linear(4096, 4096)
        self.fc3 = nn.Linear(4096, 10)

    def forward(self, x):
        """Return the class logits for a batch of images ``x``."""
        stages = (
            (self.conv1_1, self.conv1_2),
            (self.conv2_1, self.conv2_2),
            (self.conv3_1, self.conv3_2, self.conv3_3),
            (self.conv4_1, self.conv4_2, self.conv4_3),
            (self.conv5_1, self.conv5_2, self.conv5_3),
        )
        for stage in stages:
            for conv in stage:
                x = F.relu(conv(x))
            x = self.maxpool(x)

        # Flatten everything except the batch dimension.
        x = x.reshape(x.shape[0], -1)

        # Dropout is included to combat overfitting. ``F.dropout`` defaults
        # to ``training=True``, so the module's own mode must be passed
        # explicitly; otherwise units would still be dropped at random after
        # ``model.eval()`` and evaluation results would be noisy.
        x = F.relu(self.fc1(x))
        x = F.dropout(x, 0.5, training=self.training)
        x = F.relu(self.fc2(x))
        x = F.dropout(x, 0.5, training=self.training)
        return self.fc3(x)

## VGG-16(SeparableConv2D + BN)

In [None]:
class VGG16_S(nn.Module):
    """VGG-16 variant with separable convolutions and batch normalisation.

    All thirteen convolutions are ``SeparableConv2D`` layers, each followed
    by ReLU and then ``BatchNorm2d`` (eps=0.001). They are arranged in five
    stages (2, 2, 3, 3, 3 layers), each stage ending in a 2x2 max-pool. The
    head is three linear layers ending in 10 logits, with dropout (p=0.5)
    after the first two. ``fc1`` expects 25088 flattened features
    (512 x 7 x 7).

    Dropout follows the module's mode: it is active after ``train()`` and
    disabled after ``eval()``.
    """

    def __init__(self):
        super(VGG16_S, self).__init__()
        self.conv1_1 = SeparableConv2D(in_channels=3, out_channels=64, kernel_size=3, padding=1)
        self.bn1_1 = nn.BatchNorm2d(64, eps=0.001)
        self.conv1_2 = SeparableConv2D(in_channels=64, out_channels=64, kernel_size=3, padding=1)
        self.bn1_2 = nn.BatchNorm2d(64, eps=0.001)

        self.conv2_1 = SeparableConv2D(in_channels=64, out_channels=128, kernel_size=3, padding=1)
        self.bn2_1 = nn.BatchNorm2d(128, eps=0.001)
        self.conv2_2 = SeparableConv2D(in_channels=128, out_channels=128, kernel_size=3, padding=1)
        self.bn2_2 = nn.BatchNorm2d(128, eps=0.001)

        self.conv3_1 = SeparableConv2D(in_channels=128, out_channels=256, kernel_size=3, padding=1)
        self.bn3_1 = nn.BatchNorm2d(256, eps=0.001)
        self.conv3_2 = SeparableConv2D(in_channels=256, out_channels=256, kernel_size=3, padding=1)
        self.bn3_2 = nn.BatchNorm2d(256, eps=0.001)
        self.conv3_3 = SeparableConv2D(in_channels=256, out_channels=256, kernel_size=3, padding=1)
        self.bn3_3 = nn.BatchNorm2d(256, eps=0.001)

        self.conv4_1 = SeparableConv2D(in_channels=256, out_channels=512, kernel_size=3, padding=1)
        self.bn4_1 = nn.BatchNorm2d(512, eps=0.001)
        self.conv4_2 = SeparableConv2D(in_channels=512, out_channels=512, kernel_size=3, padding=1)
        self.bn4_2 = nn.BatchNorm2d(512, eps=0.001)
        self.conv4_3 = SeparableConv2D(in_channels=512, out_channels=512, kernel_size=3, padding=1)
        self.bn4_3 = nn.BatchNorm2d(512, eps=0.001)

        self.conv5_1 = SeparableConv2D(in_channels=512, out_channels=512, kernel_size=3, padding=1)
        self.bn5_1 = nn.BatchNorm2d(512, eps=0.001)
        self.conv5_2 = SeparableConv2D(in_channels=512, out_channels=512, kernel_size=3, padding=1)
        self.bn5_2 = nn.BatchNorm2d(512, eps=0.001)
        self.conv5_3 = SeparableConv2D(in_channels=512, out_channels=512, kernel_size=3, padding=1)
        self.bn5_3 = nn.BatchNorm2d(512, eps=0.001)

        # One pooling layer is shared by all five stages.
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

        self.fc1 = nn.Linear(25088, 4096)
        self.fc2 = nn.Linear(4096, 4096)
        self.fc3 = nn.Linear(4096, 10)

    def forward(self, x):
        """Return the class logits for a batch of images ``x``."""
        # Each stage is a sequence of (convolution, normalisation) pairs.
        stages = (
            ((self.conv1_1, self.bn1_1), (self.conv1_2, self.bn1_2)),
            ((self.conv2_1, self.bn2_1), (self.conv2_2, self.bn2_2)),
            ((self.conv3_1, self.bn3_1), (self.conv3_2, self.bn3_2), (self.conv3_3, self.bn3_3)),
            ((self.conv4_1, self.bn4_1), (self.conv4_2, self.bn4_2), (self.conv4_3, self.bn4_3)),
            ((self.conv5_1, self.bn5_1), (self.conv5_2, self.bn5_2), (self.conv5_3, self.bn5_3)),
        )
        for stage in stages:
            for conv, norm in stage:
                # Normalisation is applied to the activated output (conv -> ReLU -> BN).
                x = norm(F.relu(conv(x)))
            x = self.maxpool(x)

        # Flatten everything except the batch dimension.
        x = x.reshape(x.shape[0], -1)

        # Dropout is included to combat overfitting. ``F.dropout`` defaults
        # to ``training=True``, so the module's own mode must be passed
        # explicitly; otherwise units would still be dropped at random after
        # ``model.eval()`` and evaluation results would be noisy.
        x = F.relu(self.fc1(x))
        x = F.dropout(x, 0.5, training=self.training)
        x = F.relu(self.fc2(x))
        x = F.dropout(x, 0.5, training=self.training)
        return self.fc3(x)

## VGG-16(SeparableConv2D + GN)

In [None]:
class VGG16_L(nn.Module):
    """VGG-16 variant with separable convolutions and group normalisation.

    All thirteen convolutions are ``SeparableConv2D`` layers, each followed
    by ReLU and then ``GroupNorm`` with 16 groups (eps=0.001). They are
    arranged in five stages (2, 2, 3, 3, 3 layers), each stage ending in a
    2x2 max-pool. The head is three linear layers ending in 10 logits, with
    dropout (p=0.5) after the first two. ``fc1`` expects 25088 flattened
    features (512 x 7 x 7).

    Dropout follows the module's mode: it is active after ``train()`` and
    disabled after ``eval()``.
    """

    def __init__(self):
        super(VGG16_L, self).__init__()
        self.conv1_1 = SeparableConv2D(in_channels=3, out_channels=64, kernel_size=3, padding=1)
        self.gn1_1 = nn.GroupNorm(16, 64, eps=0.001)
        self.conv1_2 = SeparableConv2D(in_channels=64, out_channels=64, kernel_size=3, padding=1)
        self.gn1_2 = nn.GroupNorm(16, 64, eps=0.001)

        self.conv2_1 = SeparableConv2D(in_channels=64, out_channels=128, kernel_size=3, padding=1)
        self.gn2_1 = nn.GroupNorm(16, 128, eps=0.001)
        self.conv2_2 = SeparableConv2D(in_channels=128, out_channels=128, kernel_size=3, padding=1)
        self.gn2_2 = nn.GroupNorm(16, 128, eps=0.001)

        self.conv3_1 = SeparableConv2D(in_channels=128, out_channels=256, kernel_size=3, padding=1)
        self.gn3_1 = nn.GroupNorm(16, 256, eps=0.001)
        self.conv3_2 = SeparableConv2D(in_channels=256, out_channels=256, kernel_size=3, padding=1)
        self.gn3_2 = nn.GroupNorm(16, 256, eps=0.001)
        self.conv3_3 = SeparableConv2D(in_channels=256, out_channels=256, kernel_size=3, padding=1)
        self.gn3_3 = nn.GroupNorm(16, 256, eps=0.001)

        self.conv4_1 = SeparableConv2D(in_channels=256, out_channels=512, kernel_size=3, padding=1)
        self.gn4_1 = nn.GroupNorm(16, 512, eps=0.001)
        self.conv4_2 = SeparableConv2D(in_channels=512, out_channels=512, kernel_size=3, padding=1)
        self.gn4_2 = nn.GroupNorm(16, 512, eps=0.001)
        self.conv4_3 = SeparableConv2D(in_channels=512, out_channels=512, kernel_size=3, padding=1)
        self.gn4_3 = nn.GroupNorm(16, 512, eps=0.001)

        self.conv5_1 = SeparableConv2D(in_channels=512, out_channels=512, kernel_size=3, padding=1)
        self.gn5_1 = nn.GroupNorm(16, 512, eps=0.001)
        self.conv5_2 = SeparableConv2D(in_channels=512, out_channels=512, kernel_size=3, padding=1)
        self.gn5_2 = nn.GroupNorm(16, 512, eps=0.001)
        self.conv5_3 = SeparableConv2D(in_channels=512, out_channels=512, kernel_size=3, padding=1)
        self.gn5_3 = nn.GroupNorm(16, 512, eps=0.001)

        # One pooling layer is shared by all five stages.
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

        self.fc1 = nn.Linear(25088, 4096)
        self.fc2 = nn.Linear(4096, 4096)
        self.fc3 = nn.Linear(4096, 10)

    def forward(self, x):
        """Return the class logits for a batch of images ``x``."""
        # Each stage is a sequence of (convolution, normalisation) pairs.
        stages = (
            ((self.conv1_1, self.gn1_1), (self.conv1_2, self.gn1_2)),
            ((self.conv2_1, self.gn2_1), (self.conv2_2, self.gn2_2)),
            ((self.conv3_1, self.gn3_1), (self.conv3_2, self.gn3_2), (self.conv3_3, self.gn3_3)),
            ((self.conv4_1, self.gn4_1), (self.conv4_2, self.gn4_2), (self.conv4_3, self.gn4_3)),
            ((self.conv5_1, self.gn5_1), (self.conv5_2, self.gn5_2), (self.conv5_3, self.gn5_3)),
        )
        for stage in stages:
            for conv, norm in stage:
                # Normalisation is applied to the activated output (conv -> ReLU -> GN).
                x = norm(F.relu(conv(x)))
            x = self.maxpool(x)

        # Flatten everything except the batch dimension.
        x = x.reshape(x.shape[0], -1)

        # Dropout is included to combat overfitting. ``F.dropout`` defaults
        # to ``training=True``, so the module's own mode must be passed
        # explicitly; otherwise units would still be dropped at random after
        # ``model.eval()`` and evaluation results would be noisy.
        x = F.relu(self.fc1(x))
        x = F.dropout(x, 0.5, training=self.training)
        x = F.relu(self.fc2(x))
        x = F.dropout(x, 0.5, training=self.training)
        return self.fc3(x)

# Training


In [None]:
def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Make one pass over ``loader``; update weights only if ``optimizer`` is given.

    The caller is responsible for putting the model in train/eval mode and
    for wrapping evaluation passes in ``torch.no_grad()``.

    Returns:
        ``(average_loss, accuracy_percent, elapsed_seconds)``. The loss is
        averaged over batches. Loss and accuracy are NaN when the loader
        yields nothing.
    """
    loss_sum = 0.0
    correct = 0
    total = 0
    num_batches = 0

    start_time = time.time()
    for inputs, labels in loader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        if optimizer is not None:
            optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)

        if optimizer is not None:
            loss.backward()
            optimizer.step()

        loss_sum += loss.item()
        num_batches += 1

        _, predictions = outputs.max(1)
        # .item() keeps the running count a plain Python int instead of a
        # tensor living on ``device``.
        correct += (predictions == labels).sum().item()
        total += predictions.size(0)
    elapsed = time.time() - start_time

    avg_loss = loss_sum / num_batches if num_batches else float("nan")
    accuracy = float(correct) / float(total) * 100 if total else float("nan")
    return avg_loss, accuracy, elapsed


def _save_curve_plot(epochs, train_values, test_values,
                     train_label, test_label, ylabel, filename):
    """Plot a train/test metric pair against epochs, save it and show it."""
    # Use a dedicated figure per plot: with a backend where ``plt.show()``
    # does not discard the current figure, reusing the implicit figure would
    # pile the curves of earlier plots into later image files.
    plt.figure()
    plt.plot(epochs, train_values, "mediumseagreen", label=train_label)
    plt.plot(epochs, test_values, "mediumpurple", label=test_label)
    plt.legend(loc="best")
    plt.xlabel("epochs")
    plt.ylabel(ylabel)
    plt.savefig(filename)
    plt.show()
    plt.close()


def train(model, train_loader, test_loader, device, num_epochs=20):
    """Train ``model``, evaluate it after every epoch and plot the curves.

    Uses cross-entropy loss and Adam (lr=1e-4) with a ``MultiStepLR``
    schedule that multiplies the learning rate by 0.1 at epochs 5 and 15.
    Per-epoch loss, accuracy and wall-clock time are printed for both
    loaders, then written to ``losses.png``, ``accuracy.png`` and
    ``time.png`` in the current working directory.

    Args:
        model: network to train; expected to already be on ``device``.
        train_loader: iterable of ``(inputs, labels)`` training batches.
        test_loader: iterable of ``(inputs, labels)`` evaluation batches.
        device: device the batches are moved to before the forward pass.
        num_epochs: number of passes over ``train_loader`` (default 20).

    Returns:
        A dict of per-epoch lists with keys ``training_losses``,
        ``training_acc``, ``training_time``, ``test_losses``, ``test_acc``
        and ``test_time``.
    """
    # loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[5, 15], gamma=0.1)

    training_losses = []
    training_acc = []
    training_time = []
    test_losses = []
    test_acc = []
    test_time = []

    for epoch in range(num_epochs):
        # train
        model.train()
        avg_train_loss, tr_acc, tr_time = _run_epoch(
            model, train_loader, criterion, device, optimizer)
        training_losses.append(avg_train_loss)
        training_acc.append(tr_acc)
        training_time.append(tr_time)

        print(f"Epoch: {epoch+1}")
        print("Average training loss: %.4f, Accuracy : %.2f%%, Time(s): %.2f"
              % (avg_train_loss, tr_acc, tr_time))

        # evaluate on the test loader without tracking gradients
        model.eval()
        with torch.no_grad():
            avg_test_loss, te_acc, te_time = _run_epoch(
                model, test_loader, criterion, device)
        test_losses.append(avg_test_loss)
        test_acc.append(te_acc)
        test_time.append(te_time)

        # change learning rate (once per epoch, after the optimizer steps)
        scheduler.step()

        print("Average test loss: %.4f, Accuracy : %.2f%%, Time(s): %.2f"
              % (avg_test_loss, te_acc, te_time))

    # plot
    epochs = list(range(1, num_epochs + 1))
    _save_curve_plot(epochs, training_losses, test_losses,
                     "training_losses", "test_losses", "loss", "losses.png")
    _save_curve_plot(epochs, training_acc, test_acc,
                     "training_accuracy", "test_accuracy", "accuracy", "accuracy.png")
    _save_curve_plot(epochs, training_time, test_time,
                     "training_time", "test_time", "time", "time.png")

    return {
        "training_losses": training_losses,
        "training_acc": training_acc,
        "training_time": training_time,
        "test_losses": test_losses,
        "test_acc": test_acc,
        "test_time": test_time,
    }

## CIFAR-10

Below, we test different models on CIFAR-10.

In [None]:
# ---- CIFAR-10 input pipeline ----
# Every image is resized to 227x227 and normalised per channel. The training
# pipeline additionally flips images horizontally with probability 0.7.
# NOTE(review): the mean/std values look like the usual ImageNet statistics
# rather than CIFAR-10's own - confirm this is intended.
batch_size = 64

transform_train = transforms.Compose([
    transforms.Resize((227, 227)),
    transforms.RandomHorizontalFlip(p=0.7),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
transform_test = transforms.Compose([
    transforms.Resize((227, 227)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# Training split: downloaded into ./data if missing, reshuffled every epoch.
train_data = datasets.CIFAR10(root='./data', train=True, download=True,
                              transform=transform_train)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size,
                                           shuffle=True, num_workers=2)

# Test split: fixed order.
test_data = datasets.CIFAR10(root='./data', train=False, download=True,
                             transform=transform_test)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size,
                                          shuffle=False, num_workers=2)

# ---- experiments ----
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Uncomment exactly one group below to train that model on CIFAR-10.

# AlexNet(Normal)
# model1 = AlexNet()
# model1.to(device)
# train(model1, train_loader, test_loader, device)

# AlexNet(SeparableConv2D + BN)
# model2 = AlexNet_S()
# model2.to(device)
# train(model2, train_loader, test_loader, device)

# AlexNet(SeparableConv2D + GN)
# model3 = AlexNet_L()
# model3.to(device)
# train(model3, train_loader, test_loader, device)

# VGG-16(Normal)
# model4 = VGG16()
# model4.to(device)
# train(model4, train_loader, test_loader, device)

# VGG-16(SeparableConv2D + BN)
# model5 = VGG16_S()
# model5.to(device)
# train(model5, train_loader, test_loader, device)

# VGG-16(SeparableConv2D + GN)
# model6 = VGG16_L()
# model6.to(device)
# train(model6, train_loader, test_loader, device)

## STL10

Below, we test different models on STL10.

In [None]:
# ---- STL10 input pipeline ----
# Images are only resized to 227x227 and converted to tensors; the same
# transform (no normalisation, no augmentation) serves both splits.
batch_size = 64

transform = transforms.Compose([
    transforms.Resize((227, 227)),
    transforms.ToTensor(),
])

# Training split: downloaded into ./data if missing, reshuffled every epoch.
train_data = datasets.STL10(root='./data', split='train', download=True,
                            transform=transform)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size,
                                           shuffle=True, num_workers=2)

# Test split: fixed order.
test_data = datasets.STL10(root='./data', split='test', download=True,
                           transform=transform)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size,
                                          shuffle=False, num_workers=2)

# ---- experiments ----
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Uncomment exactly one group below to train that model on STL10.

# AlexNet(Normal)
# model1 = AlexNet()
# model1.to(device)
# train(model1, train_loader, test_loader, device)

# AlexNet(SeparableConv2D + BN)
# model2 = AlexNet_S()
# model2.to(device)
# train(model2, train_loader, test_loader, device)

# AlexNet(SeparableConv2D + GN)
# model3 = AlexNet_L()
# model3.to(device)
# train(model3, train_loader, test_loader, device)

# VGG-16(Normal)
# model4 = VGG16()
# model4.to(device)
# train(model4, train_loader, test_loader, device)

# VGG-16(SeparableConv2D + BN)
# model5 = VGG16_S()
# model5.to(device)
# train(model5, train_loader, test_loader, device)

# VGG-16(SeparableConv2D + GN)
# model6 = VGG16_L()
# model6.to(device)
# train(model6, train_loader, test_loader, device)