
# Inertia as a Form of Model Compression in Convolutional Neural Networks


> *In submission for Deep Learning Project, Spring 2025*

**Labiba Shahab**

**Ahmed Wali**


> *What you usually achieve with a standard convolution, you can accomplish efficiently using a smaller convolution + peripheral inertia mechanism (inertial filter)*
~ Group 39


## In a Gist:

Standard $d \times d$ convolution layers involve $d^2$ learnable parameters and $d^2$ computations per convolution operation, for each input/output channel pair. In this project, we propose an inertial convolution mechanism that dynamically decides whether a detailed convolution is necessary. The goal is to reduce both computation and the number of learnable parameters while maintaining performance on vision tasks such as MNIST and CIFAR-10 classification.

Instead of learning a full $d \times d$ kernel, we use a $(d-k) \times (d-k)$ core filter to convolve the central patch. The surrounding $d^2-(d-k)^2$ pixels act as an inertial periphery that measures local divergence, or “friction.” If the divergence is high, we re-apply the core filter across the full $d \times d$ region in a second convolution and stack the outputs. If it is low, we skip the detailed computation.
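The mechanism can be made concrete with a small PyTorch sketch. This is a hypothetical illustration under several assumptions, not the authors' `InertialConv2d`: “friction” is taken to be the channel-averaged absolute gap between the mean of the peripheral pixels and the mean of the central patch; the high-friction second pass is read as a dilated convolution that reuses the same core weights so that its receptive field spans the full $d \times d$ region; the two passes are combined by masked addition rather than stacking; and `threshold` is a fixed hyperparameter. The class name `InertialConv2dSketch` is likewise illustrative.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class InertialConv2dSketch(nn.Module):
    """Hypothetical inertial convolution: a learnable (d-k) x (d-k) core plus a
    friction-gated second pass that reuses the same weights over the d x d window."""

    def __init__(self, in_channels, out_channels, d=5, k=2, threshold=0.1):
        super().__init__()
        c = d - k  # core size
        # Assumption: odd d and odd c keep 'same' padding symmetric, and the
        # dilated core must span exactly d x d, i.e. (c - 1) * dilation + 1 == d.
        if d % 2 == 0 or c % 2 == 0 or c < 2 or (d - 1) % (c - 1) != 0:
            raise ValueError("unsupported (d, k) combination for this sketch")
        self.d, self.c, self.threshold = d, c, threshold
        self.dilation = (d - 1) // (c - 1)
        # The only learnable weights: out_channels * in_channels * (d-k)**2.
        self.core = nn.Parameter(torch.empty(out_channels, in_channels, c, c))
        nn.init.kaiming_normal_(self.core)

    def friction(self, x):
        """Assumed friction: |mean(periphery) - mean(core patch)|, averaged over channels."""
        d, c = self.d, self.c
        # Zero-padded local means (count_include_pad defaults to True).
        mean_full = F.avg_pool2d(x, d, stride=1, padding=d // 2)
        mean_core = F.avg_pool2d(x, c, stride=1, padding=c // 2)
        # Mean of the d^2 - c^2 peripheral pixels, recovered from the two window sums.
        mean_peri = (d * d * mean_full - c * c * mean_core) / (d * d - c * c)
        return (mean_peri - mean_core).abs().mean(dim=1, keepdim=True)

    def forward(self, x):
        # Cheap pass: the core filter on the central (d-k) x (d-k) patch.
        out = F.conv2d(x, self.core, padding=self.c // 2)
        # 1 where friction exceeds the threshold, 0 elsewhere; shape (N, 1, H, W).
        mask = (self.friction(x) > self.threshold).to(x.dtype)
        # Detailed pass: same weights, dilated to cover the whole d x d window.
        refined = F.conv2d(x, self.core, padding=self.d // 2, dilation=self.dilation)
        return out + mask * refined


layer = InertialConv2dSketch(3, 8, d=5, k=2)
y = layer(torch.randn(2, 3, 16, 16))
print(y.shape)  # torch.Size([2, 8, 16, 16])
print(sum(p.numel() for p in layer.parameters()))  # 216 = 8 * 3 * 3 * 3
```

Because both passes are computed densely and the mask is applied afterwards, this sketch demonstrates the control flow and the weight sharing but does not by itself save any computation; realizing the best-case savings would require actually skipping the second pass at low-friction locations.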

In the best case, we perform only $(d-k)^2$ computations with $(d-k)^2$ learnable parameters.
In the worst case, we perform up to $d^2$ computations but still learn only $(d-k)^2$ parameters, which effectively compresses the model by estimation, much as pruning would.
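Since a $(d-k) \times (d-k)$ core holds $(d-k)^2$ weights per input/output channel pair, the savings are easy to tabulate. The hypothetical helper `inertial_counts` below does so for one illustrative setting, $d = 5$ and $k = 2$ (chosen here for the example, not taken from the experiments). It counts kernel weights only (no biases) and treats the core as the only learnable part.

```python
def inertial_counts(d, k, in_channels=1, out_channels=1):
    """Hypothetical helper: weight and pixel counts for a d x d window whose
    learnable part is a (d-k) x (d-k) core."""
    if not 0 <= k < d:
        raise ValueError("need 0 <= k < d")
    pairs = in_channels * out_channels      # one kernel per input/output channel pair
    core = (d - k) ** 2
    return {
        "standard_weights": pairs * d * d,   # full d x d kernels
        "core_weights": pairs * core,        # the only learnable weights
        "periphery_pixels": d * d - core,    # inspected for friction, never learned
        "weight_ratio": core / (d * d),
    }


print(inertial_counts(5, 2))
# {'standard_weights': 25, 'core_weights': 9, 'periphery_pixels': 16, 'weight_ratio': 0.36}
print(inertial_counts(5, 2, 64, 64)["core_weights"])  # 36864, versus 102400 for a standard 5x5 layer
```

With these values the learnable weights shrink to 36% of a standard layer's, independently of the channel counts.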

## Background

There has been extensive research on reducing neural network complexity:

Dynamic convolutions and skip-convolutions conditionally skip expensive operations. This project is inspired by [Dynamic Sparse Convolutions](https://arxiv.org/pdf/2102.04906), [Skip Convolutions](https://openaccess.thecvf.com/content/CVPR2021/papers/Habibian_Skip-Convolutions_for_Efficient_Video_Processing_CVPR_2021_paper.pdf), and [Fractional Skipping](https://arxiv.org/abs/2001.00705).

Pruning, quantization, and knowledge distillation reduce parameters, memory, or model depth.

Our approach is inspired by these ideas but focuses on parameter reuse and friction-aware skipping, offering a novel trade-off between learning capacity and computational efficiency.

#### Novelty

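The skipping methods above use conditional computation to save compute. Our project applies a similar mechanism to model compression by estimation: the skip decision lets a smaller, reused learnable core stand in for a full kernel, so parameters shrink as well as computation.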

## Baseline Models

#### CIFAR-10
For this dataset, we needed models that are simple enough to modify, high-performing enough to be credible and reproducible, and modular enough that their `Conv2d` layers can be swapped out to test our inertial convolution. We chose ResNet18, VGG16, and SimpleDLA as the baseline models on CIFAR-10, adapted directly from [kuangliu/pytorch-cifar](https://github.com/kuangliu/pytorch-cifar).

Reproducibility + Imports

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import random
import numpy as np
```

Seeds and cuDNN settings for deterministic results on a GPU

```python
seed = 1
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
```
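Seeding the global RNG makes the loaders' shuffle order reproducible as long as data loading stays in the main process (the default `num_workers=0`). If worker processes are added later, the recipe from PyTorch's reproducibility notes is to give each `DataLoader` its own seeded `torch.Generator` and a `worker_init_fn` that reseeds NumPy and `random` inside each worker. The sketch below follows that recipe; `seed_worker` and `make_loader` are illustrative helper names, and the toy `TensorDataset` stands in for CIFAR-10.

```python
import random

import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset


def seed_worker(worker_id):
    # Inside a worker process, torch.initial_seed() is already distinct per worker
    # (derived from the loader's generator); reuse it to seed NumPy and `random`.
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


def make_loader(dataset, batch_size, seed, shuffle=True, num_workers=0):
    # A dedicated generator makes the shuffle order depend only on `seed`,
    # not on how much of the global RNG other code has consumed.
    g = torch.Generator()
    g.manual_seed(seed)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
                      num_workers=num_workers, worker_init_fn=seed_worker, generator=g)


toy = TensorDataset(torch.arange(10))
first = [batch[0].tolist() for batch in make_loader(toy, 4, seed=1)]
second = [batch[0].tolist() for batch in make_loader(toy, 4, seed=1)]
print(first == second)  # True
```

Applied to this notebook, `make_loader(train_dataset, 128, seed)` would be a drop-in for the training loader.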

Load CIFAR-10 Dataset

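```python
# Training-time augmentation (random crops and flips) plus per-channel normalization
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
# Evaluation uses the same normalization but no random augmentation
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

train_dataset = datasets.CIFAR10('./data', train=True, download=True, transform=transform_train)
test_dataset = datasets.CIFAR10('./data', train=False, download=True, transform=transform_test)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1000, shuffle=False)
```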


Test Function

```python
def test(model, device, test_loader):
    """Evaluate top-1 accuracy on `test_loader`; prints and returns it."""
    model.eval()  # evaluation mode: BatchNorm uses its running statistics
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            pred = output.argmax(dim=1, keepdim=True)  # index of the highest logit
            correct += pred.eq(target.view_as(pred)).sum().item()
    accuracy = correct / len(test_loader.dataset)
    print(f"Accuracy: {accuracy:.10f}")
    return accuracy
```

## Baseline Architectures

The following blocks define the three baseline models adapted from kuangliu's repository.
- To integrate inertial filters, we plan to later replace `nn.Conv2d` in these blocks with `InertialConv2d`, which will be defined in a later section.
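Since the plan is to swap `nn.Conv2d` layers inside these blocks, a small recursive helper could do the replacement without editing each class. The name `replace_convs`, the `make_layer` factory, and the `predicate` filter are hypothetical; because `InertialConv2d` is not defined yet, the example swaps in `nn.Identity` purely to show the mechanics.

```python
import torch.nn as nn


def replace_convs(module, make_layer, predicate=None):
    """Hypothetical helper: recursively swap nn.Conv2d children in place.

    make_layer(conv) receives the existing nn.Conv2d and returns its replacement;
    predicate(conv), if given, decides which convolutions are swapped.
    """
    for name, child in module.named_children():
        if isinstance(child, nn.Conv2d) and (predicate is None or predicate(child)):
            # setattr re-registers the submodule under the same name (also works for "0", "1", ... in nn.Sequential)
            setattr(module, name, make_layer(child))
        else:
            replace_convs(child, make_layer, predicate)
    return module


toy = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(), nn.Sequential(nn.Conv2d(8, 8, 1)))
replace_convs(toy, lambda conv: nn.Identity(), predicate=lambda conv: conv.kernel_size == (3, 3))
print(toy)  # the 3x3 conv is now Identity; the nested 1x1 conv is untouched
```

A predicate such as `kernel_size == (3, 3)` would leave the 1×1 shortcut and `Root` convolutions alone, which seems sensible given that a 1×1 kernel has no periphery to inspect.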

ResNet18 Model

```python
# Adapted from https://github.com/kuangliu/pytorch-cifar/blob/master/models/resnet.py
class BasicBlock(nn.Module):
    expansion = 1
    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, 3, stride, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, 3, 1, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, 1, stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes)
            )
    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        return F.relu(out)

class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super().__init__()
        self.in_planes = 64
        self.conv1 = nn.Conv2d(3, 64, 3, 1, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64,  num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512 * block.expansion, num_classes)
    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for s in strides:
            layers.append(block(self.in_planes, planes, s))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)
    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        return self.linear(out)

def ResNet18():
    return ResNet(BasicBlock, [2, 2, 2, 2])
```

VGG-16 Model

```python
# Adapted from https://github.com/kuangliu/pytorch-cifar/blob/master/models/vgg.py
cfg = {
    'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M',
              512, 512, 512, 'M', 512, 512, 512, 'M'],
}

class VGG(nn.Module):
    def __init__(self, vgg_name='VGG16'):
        super().__init__()
        self.features = self._make_layers(cfg[vgg_name])
        self.classifier = nn.Linear(512, 10)
    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        return self.classifier(x)
    def _make_layers(self, cfg):
        layers = []
        in_channels = 3
        for x in cfg:
            if x == 'M':
                layers += [nn.MaxPool2d(2, 2)]
            else:
                layers += [
                    nn.Conv2d(in_channels, x, 3, padding=1),
                    nn.BatchNorm2d(x),
                    nn.ReLU(inplace=True)
                ]
                in_channels = x
        layers += [nn.AvgPool2d(1)]
        return nn.Sequential(*layers)
```

SimpleDLA Model

```python
# Adapted from https://github.com/kuangliu/pytorch-cifar/blob/master/models/dla_simple.py
class Root(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=1):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, 1,
                              padding=(kernel_size - 1) // 2, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
    def forward(self, xs):
        x = torch.cat(xs, 1)
        return F.relu(self.bn(self.conv(x)))

class Tree(nn.Module):
    def __init__(self, block, in_channels, out_channels, level=1, stride=1):
        super().__init__()
        self.root = Root(2*out_channels, out_channels)
        if level == 1:
            self.left = block(in_channels, out_channels, stride)
            self.right = block(out_channels, out_channels)
        else:
            self.left = Tree(block, in_channels, out_channels, level-1, stride)
            self.right = Tree(block, out_channels, out_channels, level-1, 1)
    def forward(self, x):
        x1 = self.left(x)
        x2 = self.right(x1)
        return self.root([x1, x2])

class SimpleDLA(nn.Module):
    def __init__(self, block=BasicBlock, num_classes=10):
        super().__init__()
        self.base = nn.Sequential(
            nn.Conv2d(3, 16, 3, 1, 1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True)
        )
        self.layer1 = nn.Sequential(
            nn.Conv2d(16, 16, 3, 1, 1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True)
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(16, 32, 3, 1, 1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True)
        )
        self.layer3 = Tree(block, 32, 64, level=1, stride=1)
        self.layer4 = Tree(block, 64, 128, level=2, stride=2)
        self.layer5 = Tree(block, 128, 256, level=2, stride=2)
        self.layer6 = Tree(block, 256, 512, level=1, stride=2)
        self.linear = nn.Linear(512, num_classes)
    def forward(self, x):
        x = self.base(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.layer5(x)
        x = self.layer6(x)
        x = F.avg_pool2d(x, 4)
        x = x.view(x.size(0), -1)
        return self.linear(x)
```

Training Function

```python
def train_model(model, name):
    model = model.to(device)
    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[150, 225], gamma=0.1)
    epochs = 200

    for epoch in range(epochs):
        model.train()  # BatchNorm uses batch statistics during training
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = F.cross_entropy(output, target)
            loss.backward()
            optimizer.step()
        scheduler.step()  # one learning-rate schedule step per epoch
        print(f"Epoch {epoch} complete.")

    print(f"\n{name} Results:")
    acc = test(model, device, test_loader)
    return acc
```
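One detail worth checking in `train_model`: the `MultiStepLR` milestones are `[150, 225]`, but training stops after 200 epochs, so only the first decay can ever fire. The standalone sketch below replays the same scheduler settings with a dummy parameter standing in for a model (`lr_per_epoch` is an illustrative name) and records the learning rate in effect during each epoch.

```python
import torch
import torch.optim as optim


def lr_per_epoch(epochs=200, milestones=(150, 225), base_lr=0.1, gamma=0.1):
    """Return the learning rate used in each epoch of a train_model-style loop."""
    param = torch.nn.Parameter(torch.zeros(1))  # dummy parameter; nothing is trained
    optimizer = optim.SGD([param], lr=base_lr, momentum=0.9)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=list(milestones), gamma=gamma)
    lrs = []
    for _ in range(epochs):
        lrs.append(optimizer.param_groups[0]["lr"])
        optimizer.step()   # same order as the loop: optimizer steps, then scheduler.step()
        scheduler.step()
    return lrs


lrs = lr_per_epoch()
print(lrs[149], lrs[150], lrs[199])  # 0.1 at epoch 149, about 0.01 from epoch 150 to the end
```

If the second decay is intended, the milestones would need to fall below `epochs`.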

Evaluating the Baselines on CIFAR-10

```python
model_rn18 = ResNet18()
acc_rn18 = train_model(model_rn18, "ResNet18")

model_vgg16 = VGG('VGG16')
acc_vgg16 = train_model(model_vgg16, "VGG16")

model_simpledla = SimpleDLA()
acc_simpledla = train_model(model_simpledla, "SimpleDLA")
```

```text
Epoch 0 complete.
Epoch 1 complete.
...
Epoch 52 complete.
Epo

Accuracy: 0.9248000000
```

#### Note:

CIFAR-10 consists of 60,000 32×32 color images in 10 classes.
- **Model**: ResNet18 / VGG-like models
- **Expected Accuracy**: ~92–96%
- **Link**: [kuangliu/pytorch-cifar](https://github.com/kuangliu/pytorch-cifar)

> This repository implements multiple state-of-the-art models, including VGG, ResNet18, and many others, giving us a solid starting point for comparing the results of our implementation against further models.