# Pleas read:

before your start running this jupyter notebook, please click Edit > Notebook Settings and choose any of the available GPUs.

## 0. Import Python Packages

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import thop
import torchvision
from torchvision import datasets, transforms
import matplotlib.pyplot as plt

## 1. Define ResNet18

In [2]:
class Block(nn.Module):
    """Initial implementation courtesy of GeeksForGeeks."""
    def __init__(self, in_channels, out_channels, stride=1):
        super(Block, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out += self.shortcut(x)
        out = self.relu(out)
        return out

class DepthwiseBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super(DepthwiseBlock, self).__init__()
        self.dw1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=stride, padding=1, groups=in_channels, bias=False)
        self.pw1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.dw2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, groups=out_channels, bias=False)
        self.pw2 = nn.Conv2d(out_channels, out_channels, kernel_size=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        out = self.dw1(x)
        out = self.pw1(out)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.dw2(out)
        out = self.pw2(out)
        out = self.bn2(out)
        out += self.shortcut(x)
        out = self.relu(out)
        return out

class InvertedResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1, expandion=6):
        super(InvertedResidualBlock, self).__init__()
        self.use_shortcut = stride == 1 and in_channels == out_channels
        exp_channels = in_channels * expandion
        layers = []
        if expandion != 1:
            layers.extend([
                nn.Conv2d(in_channels, exp_channels, kernel_size=1, bias=False),
                nn.BatchNorm2d(exp_channels),
                nn.ReLU6(inplace=True)
            ])
        dw_channels = exp_channels if expandion != 1 else in_channels
        layers.extend([
            nn.Conv2d(dw_channels, dw_channels, kernel_size=3, stride=stride, 
                     padding=1, groups=dw_channels, bias=False),
            nn.BatchNorm2d(dw_channels),
            nn.ReLU6(inplace=True)
        ])
        layers.extend([
            nn.Conv2d(dw_channels, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels)
        ])        
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        if self.use_shortcut:
            return x + self.conv(x)
        else:
            return self.conv(x)

In [3]:
BLOCK = {
    "basic": Block,
    "depthwise": DepthwiseBlock,
    "inverted": InvertedResidualBlock,
}

In [4]:
class ResNet18(nn.Module):
    """Initial implementation courtesy of GeeksForGeeks"""
    def __init__(self, num_classes=10, in_channels=3, block_type='basic'):
        super(ResNet18, self).__init__()
        self.in_channels = 64
        
        self.block = BLOCK.get(block_type)
        assert self.block, "Invalid block_type given. Supported types: 'basic', 'depthwise', 'inverted'."
        
        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.l1 = self._make_layer(self.block, 64, 2, stride=1)
        self.l2 = self._make_layer(self.block, 128, 2, stride=2)
        self.l3 = self._make_layer(self.block, 256, 2, stride=2)
        self.l4 = self._make_layer(self.block, 512, 2, stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_classes)

    def _make_layer(self, block, out_channels, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_channels, out_channels, stride))
            self.in_channels = out_channels
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.maxpool(out)

        out = self.l1(out)
        out = self.l2(out)
        out = self.l3(out)
        out = self.l4(out)

        out = self.avgpool(out)
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        return out

## 2. Implement training loop and test function

In [5]:
def train(model, device, train_loader, optimizer, epoch, scheduler):
    losses = []
    model.train()
    for batch, (X, y) in enumerate(train_loader):
        X, y = X.to(device), y.to(device)
        optimizer.zero_grad()
        y_hat = model(X)
        loss = F.nll_loss(F.log_softmax(y_hat, dim=1), y.long())
        loss.backward()
        optimizer.step()
        if batch % 100 == 0:
            losses.append(loss.item())
    scheduler.step()
    return losses

def test(model, device, test_loader):
    model.eval()
    loss = 0
    accuracy = 0
    with torch.no_grad():
        for X, y in test_loader:
            X, y = X.to(device), y.to(device)
            output = model(X)
            loss += F.nll_loss(F.log_softmax(output, dim=1), y.long(), reduction='sum').item()
            predicted = F.log_softmax(output, dim=1).argmax(dim=1, keepdim=True)
            accuracy += predicted.eq(y.view_as(predicted)).sum().item()

    loss /= len(test_loader.dataset)
    accuracy /= len(test_loader.dataset)

    return loss, accuracy

In [6]:
def run(
    epochs: int,
    device: torch.device,
    train_dl: torch.utils.data.DataLoader,
    val_dl: torch.utils.data.DataLoader,
    test_dl: torch.utils.data.DataLoader,
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    scheduler: torch.optim.lr_scheduler.StepLR,
):
    train_losses = []
    val_losses = []
    best = float('inf')
    patience = 10
    stop = 0
    for epoch in range(epochs):
        print(f"Epoch {epoch+1}/{epochs}")
        loss = train(model, device, train_dl, optimizer, epoch, scheduler)
        train_losses.append(sum(loss) / len(loss))
        val_loss, val_acc = test(model, device, val_dl)
        val_losses.append(val_loss)
        print(f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_acc:.4f}")
        if val_loss < best:
            best = val_loss
            stop = 0
        else:
            stop += 1
        if stop >= patience:
            print("Early stopping triggered.")
            break
    test_loss, test_accuracy = test(model, device, test_dl)
    return train_losses, val_losses, test_loss, test_accuracy

# 3. CIFAR10

In [7]:
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2470, 0.2435, 0.2616)
TRAIN_TRANSFORM = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(30),
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),    
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])
TEST_TRANSFORM = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

In [8]:
CIFAR10_TRAIN = datasets.CIFAR10("./", train=True, transform=TRAIN_TRANSFORM, download=True)
CIFAR10_TEST = datasets.CIFAR10("./", train=False, transform=TEST_TRANSFORM, download=True)

train_size = int(0.8 * len(CIFAR10_TRAIN))
val_size = len(CIFAR10_TRAIN) - train_size

CIFAR10_TRAIN, CIFAR10_VALIDATION = torch.utils.data.random_split(CIFAR10_TRAIN, [train_size, val_size])

In [9]:
batch = 256
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

train_dl = torch.utils.data.DataLoader(CIFAR10_TRAIN, batch_size=batch, shuffle=True)
val_dl = torch.utils.data.DataLoader(CIFAR10_VALIDATION, batch_size=batch, shuffle=True)
test_dl = torch.utils.data.DataLoader(CIFAR10_TEST, batch_size=batch, shuffle=True)

# 4. Depthwise Convolution

In [23]:
cifar10_epochs = 30
dw_model = ResNet18(in_channels=3, block_type="depthwise").to(device)
dw_optimizer = torch.optim.Adam(dw_model.parameters(), lr=5e-3, weight_decay=1e-4)
dw_scheduler = torch.optim.lr_scheduler.MultiStepLR(dw_optimizer, milestones=[15, 20, 25], gamma=0.1)

dw_train_losses, dw_val_losses, dw_test_loss, dw_test_accuracy = run(cifar10_epochs, device, train_dl, val_dl, test_dl, dw_model, dw_optimizer, dw_scheduler)

Epoch 1/30
Validation Loss: 1.4029, Validation Accuracy: 0.4972
Epoch 2/30
Validation Loss: 1.3013, Validation Accuracy: 0.5429
Epoch 3/30
Validation Loss: 1.1181, Validation Accuracy: 0.5973
Epoch 4/30
Validation Loss: 1.0611, Validation Accuracy: 0.6247
Epoch 5/30
Validation Loss: 1.0055, Validation Accuracy: 0.6500
Epoch 6/30
Validation Loss: 0.9299, Validation Accuracy: 0.6731
Epoch 7/30
Validation Loss: 0.9154, Validation Accuracy: 0.6868
Epoch 8/30
Validation Loss: 0.9805, Validation Accuracy: 0.6633
Epoch 9/30
Validation Loss: 0.8730, Validation Accuracy: 0.6991
Epoch 10/30
Validation Loss: 0.9097, Validation Accuracy: 0.6894
Epoch 11/30
Validation Loss: 0.8729, Validation Accuracy: 0.6986
Epoch 12/30
Validation Loss: 0.8160, Validation Accuracy: 0.7166
Epoch 13/30
Validation Loss: 0.9176, Validation Accuracy: 0.6880
Epoch 14/30
Validation Loss: 0.8012, Validation Accuracy: 0.7205
Epoch 15/30
Validation Loss: 0.7812, Validation Accuracy: 0.7273
Epoch 16/30
Validation Loss: 0.627

In [26]:
print(f"Test Loss: {dw_test_loss:.4f}, Test Accuracy: {dw_test_accuracy:.4f}")

Test Loss: 0.4825, Test Accuracy: 0.8377


# 5. Inverted Residual

In [21]:
cifar10_epochs = 30
inv_model = ResNet18(in_channels=3, block_type="inverted").to(device)
inv_optimizer = torch.optim.Adam(inv_model.parameters(), lr=5e-3, weight_decay=1e-4)
inv_scheduler = torch.optim.lr_scheduler.MultiStepLR(inv_optimizer, milestones=[15, 20, 25], gamma=0.1)

inv_train_losses, inv_val_losses, inv_test_loss, inv_test_accuracy = run(cifar10_epochs, device, train_dl, val_dl, test_dl, inv_model, inv_optimizer, inv_scheduler)

Epoch 1/30
Validation Loss: 1.4009, Validation Accuracy: 0.4927
Epoch 2/30
Validation Loss: 1.4153, Validation Accuracy: 0.5089
Epoch 3/30
Validation Loss: 1.0666, Validation Accuracy: 0.6221
Epoch 4/30
Validation Loss: 1.0816, Validation Accuracy: 0.6214
Epoch 5/30
Validation Loss: 1.0167, Validation Accuracy: 0.6502
Epoch 6/30
Validation Loss: 1.0604, Validation Accuracy: 0.6366
Epoch 7/30
Validation Loss: 0.9309, Validation Accuracy: 0.6655
Epoch 8/30
Validation Loss: 0.8762, Validation Accuracy: 0.6967
Epoch 9/30
Validation Loss: 0.8699, Validation Accuracy: 0.6989
Epoch 10/30
Validation Loss: 0.8577, Validation Accuracy: 0.7003
Epoch 11/30
Validation Loss: 0.8922, Validation Accuracy: 0.6916
Epoch 12/30
Validation Loss: 0.8163, Validation Accuracy: 0.7192
Epoch 13/30
Validation Loss: 0.8520, Validation Accuracy: 0.6964
Epoch 14/30
Validation Loss: 0.7828, Validation Accuracy: 0.7343
Epoch 15/30
Validation Loss: 0.7764, Validation Accuracy: 0.7345
Epoch 16/30
Validation Loss: 0.598

In [27]:
print(f"Test Loss: {inv_test_loss:.4f}, Test Accuracy: {inv_test_accuracy:.4f}")

Test Loss: 0.4199, Test Accuracy: 0.8567


# 6. Standard

In [24]:
cifar10_epochs = 30
basic_model = ResNet18(in_channels=3, block_type="basic").to(device)
basic_optimizer = torch.optim.Adam(basic_model.parameters(), lr=5e-3, weight_decay=1e-4)
basic_scheduler = torch.optim.lr_scheduler.MultiStepLR(basic_optimizer, milestones=[15, 20, 25], gamma=0.5)

basic_train_losses, basic_val_losses, basic_test_loss, basic_test_accuracy = run(cifar10_epochs, device, train_dl, val_dl, test_dl, basic_model, basic_optimizer, basic_scheduler)

Epoch 1/30
Validation Loss: 1.6995, Validation Accuracy: 0.3753
Epoch 2/30
Validation Loss: 1.3043, Validation Accuracy: 0.5352
Epoch 3/30
Validation Loss: 1.2936, Validation Accuracy: 0.5447
Epoch 4/30
Validation Loss: 1.1958, Validation Accuracy: 0.5759
Epoch 5/30
Validation Loss: 1.2257, Validation Accuracy: 0.5751
Epoch 6/30
Validation Loss: 1.0868, Validation Accuracy: 0.6233
Epoch 7/30
Validation Loss: 1.0952, Validation Accuracy: 0.6191
Epoch 8/30
Validation Loss: 0.9393, Validation Accuracy: 0.6706
Epoch 9/30
Validation Loss: 0.9378, Validation Accuracy: 0.6718
Epoch 10/30
Validation Loss: 0.9241, Validation Accuracy: 0.6776
Epoch 11/30
Validation Loss: 0.8648, Validation Accuracy: 0.7019
Epoch 12/30
Validation Loss: 0.8359, Validation Accuracy: 0.7138
Epoch 13/30
Validation Loss: 0.8867, Validation Accuracy: 0.6894
Epoch 14/30
Validation Loss: 0.7985, Validation Accuracy: 0.7246
Epoch 15/30
Validation Loss: 0.9639, Validation Accuracy: 0.6722
Epoch 16/30
Validation Loss: 0.670

In [28]:
print(f"Test Loss: {basic_test_loss:.4f}, Test Accuracy: {basic_test_accuracy:.4f}")

Test Loss: 0.4796, Test Accuracy: 0.8527


# Checkpoint

In [32]:
torch.save(basic_model.state_dict(), "basic_model.pth")
torch.save(dw_model.state_dict(), "dw_model.pth")
torch.save(inv_model.state_dict(), "inv_model.pth")

In [20]:
basic_model = ResNet18(in_channels=3, block_type="basic").to(device)
dw_model = ResNet18(in_channels=3, block_type="depthwise").to(device)
inv_model = ResNet18(in_channels=3, block_type="inverted").to(device)

In [21]:
basic_model.load_state_dict(torch.load("basic_model.pth"))
dw_model.load_state_dict(torch.load("dw_model.pth"))
inv_model.load_state_dict(torch.load("inv_model.pth"))

<All keys matched successfully>

## 7. Analysis


In [22]:
def flops(model):
    model.eval()
    macs, params = thop.profile(model, inputs=(torch.randn(1, 3, 32, 32).to(device),))
    return macs, params

In [23]:
basic_macs, basic_params = flops(basic_model)
dw_macs, dw_params = flops(dw_model)
inv_macs, inv_params = flops(inv_model)


[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.batchnorm.BatchNorm2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.pooling.MaxPool2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register count_adap_avgpool() for <class 'torch.nn.modules.pooling.AdaptiveAvgPool2d'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.batchnorm.BatchNorm2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.pooling.MaxPool2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register count_adap_avg

In [None]:
print("Basic Model:")
print(f"MACs: {basic_macs}, Params: {basic_params}")
print("Depthwise Model:")
print(f"MACs: {dw_macs}, Params: {dw_params}")
print("Inverted Model:")
print(f"MACs: {inv_macs}, Params: {inv_params}")

Basic Model:
MACs: 141000192.0, Params: 11173962.0
Depthwise Model:
MACs: 20406784.0, Params: 1439626.0
Inverted Model:
MACs: 99563008.0, Params: 5901002.0


In [39]:
def latency(model: torch.nn.Module, device: torch.device, input_shape=(1, 3, 32, 32),
                    warmup: int = 50, runs: int = 100) -> dict:
    model.eval()
    
    dummy = torch.randn(*input_shape, device=device)
    assert device.type == "cuda", "Latency measurement requires CUDA device"
    
    torch.backends.cudnn.benchmark = True
    starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
    
    with torch.inference_mode():
        for _ in range(warmup):
            _ = model(dummy)
    torch.cuda.synchronize()
    
    times = []
    with torch.inference_mode():
        for _ in range(runs):
            starter.record()
            _ = model(dummy)
            ender.record()
            torch.cuda.synchronize()
            times.append(starter.elapsed_time(ender))
    
    avg = sum(times) / len(times)
    p50 = sorted(times)[len(times)//2]
    p90 = sorted(times)[int(len(times)*0.9)]

    return {"avg_ms": avg, "p50_ms": p50, "p90_ms": p90, "runs": runs}


In [40]:
input_shape = (1, 3, 32, 32)

basic_latency_single = latency(basic_model, device, input_shape=input_shape)
dw_latency_single = latency(dw_model, device, input_shape=input_shape)
inv_latency_single = latency(inv_model, device, input_shape=input_shape)

basic_latency_batched = latency(basic_model, device, input_shape=(batch, 3, 32, 32))
dw_latency_batched = latency(dw_model, device, input_shape=(batch, 3, 32, 32))
inv_latency_batched = latency(inv_model, device, input_shape=(batch, 3, 32, 32))

print("Latency (per forward pass):")
print("Basic ResNet-18:", basic_latency_single)
print("Depthwise ResNet-18:", dw_latency_single)
print("Inverted ResNet-18:", inv_latency_single)

print("\nLatency (batch of images):")
print("Basic ResNet-18:", basic_latency_batched)
print("Depthwise ResNet-18:", dw_latency_batched)
print("Inverted ResNet-18:", inv_latency_batched)

Latency (per forward pass):
Basic ResNet-18: {'avg_ms': 1.441298553943634, 'p50_ms': 1.4417599439620972, 'p90_ms': 1.4489599466323853, 'runs': 100}
Depthwise ResNet-18: {'avg_ms': 1.567883838415146, 'p50_ms': 1.5667519569396973, 'p90_ms': 1.5841920375823975, 'runs': 100}
Inverted ResNet-18: {'avg_ms': 1.5076278448104858, 'p50_ms': 1.5062719583511353, 'p90_ms': 1.516543984413147, 'runs': 100}

Latency (batch of images):
Basic ResNet-18: {'avg_ms': 5.021609635353088, 'p50_ms': 5.0575361251831055, 'p90_ms': 5.238783836364746, 'runs': 100}
Depthwise ResNet-18: {'avg_ms': 3.460281584262848, 'p50_ms': 3.4611198902130127, 'p90_ms': 3.4662399291992188, 'runs': 100}
Inverted ResNet-18: {'avg_ms': 11.047739572525025, 'p50_ms': 11.040767669677734, 'p90_ms': 11.084799766540527, 'runs': 100}
