# Imports

In [164]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset, random_split
from torchvision import models
from torchvision.models import mobilenet_v2
from torchvision.models.mobilenetv2 import InvertedResidual
from tqdm import tqdm
from torchsummary import summary
import time
import os
import torch_pruning as tp

# Loading the dataset (CIFAR-10)

In [6]:
# Fix the seed to ensure reproducibility
torch.manual_seed(42)

# Data augmentation for training (applied only to the train dataset)
train_transform = transforms.Compose([
    transforms.Resize((224, 224)), # Resize to MobileNet input size
    transforms.RandomHorizontalFlip(), # Flip horizontally
    transforms.RandomRotation(15), # Random rotation 15 degress
    transforms.ToTensor(), # Convert to tensor
    transforms.Normalize((0.5,), (0.5,)) # Normalize (mean=0.5, std=0.5)
])

# No augmentation for validation/test (only resizing and normalization)
test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
# Download CIFAR-10 dataset
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)
test_data = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transform)

# Split the test_data into test (6k) and dev (4k)
test_size = 6000
dev_size = 4000
test_dataset, dev_dataset = random_split(test_data, [test_size, dev_size])

# Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=2)
dev_loader = DataLoader(dev_dataset, batch_size=64, shuffle=False, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=2)

# Print dataset sizes
print(f"Train set size: {len(train_dataset)}")
print(f"Dev set size: {len(dev_dataset)}")
print(f"Test set size: {len(test_dataset)}")

Train set size: 50000
Dev set size: 4000
Test set size: 6000


# Loading the model

In [197]:
# Load the pre-trained MobileNet model
base_model = models.mobilenet_v2(weights=models.MobileNet_V2_Weights.IMAGENET1K_V1)

# Unfreeze some of the top layers
for param in base_model.features[:-10].parameters():
    param.requires_grad = False

# Modify the classifier for CIFAR-10 (10 classes)
class MobileNetV2(nn.Module):
    def __init__(self, base_model):
        super(MobileNetV2, self).__init__()
        self.features = base_model.features
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Sequential(
            nn.Linear(1280, 256),
            nn.ReLU(),
            nn.BatchNorm1d(256),
            nn.Dropout(0.5),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.BatchNorm1d(128),
            nn.Dropout(0.5),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.BatchNorm1d(64),
            nn.Dropout(0.5),
            nn.Linear(64, 10)  # CIFAR-10 has 10 classes
        )

    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

model = MobileNetV2(base_model)

In [198]:
model = MobileNetV2(base_model)
model_weights_path = '../pytorch_models/mobilenet_cifar10_fullyTrained.pth'
# load the full model
model.load_state_dict(torch.load(model_weights_path))
# Set the model to evaluation mode
model.eval()
print("Original Model weights loaded successfully!")

Original Model weights loaded successfully!


In [199]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

MobileNetV2(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU6(inplace=True)
    )
    (1): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
          (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (2): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(16, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(96, eps=

In [200]:
# print model summary
print("Original Model Summary:")
summary(model, (3, 224, 224), device=device.type)

Original Model Summary:
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 32, 112, 112]             864
       BatchNorm2d-2         [-1, 32, 112, 112]              64
             ReLU6-3         [-1, 32, 112, 112]               0
            Conv2d-4         [-1, 32, 112, 112]             288
       BatchNorm2d-5         [-1, 32, 112, 112]              64
             ReLU6-6         [-1, 32, 112, 112]               0
            Conv2d-7         [-1, 16, 112, 112]             512
       BatchNorm2d-8         [-1, 16, 112, 112]              32
  InvertedResidual-9         [-1, 16, 112, 112]               0
           Conv2d-10         [-1, 96, 112, 112]           1,536
      BatchNorm2d-11         [-1, 96, 112, 112]             192
            ReLU6-12         [-1, 96, 112, 112]               0
           Conv2d-13           [-1, 96, 56, 56]             864
      BatchNorm

# Required functions for training and inference

In [71]:
# Training loop
def train_model(model, train_loader, dev_loader, criterion, optimizer, num_epochs=10, mask_enforcer=None):
    for epoch in range(num_epochs):
        ### Training Phase ###
        model.train()  # Set model to training mode
        running_loss = 0.0
        correct, total = 0, 0

        loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", leave=True)
        for images, labels in loop:
            images, labels = images.to(device), labels.to(device)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Enforce pruning mask if provided
            if mask_enforcer:
                mask_enforcer.enforce()

            # Compute training metrics
            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

            # Update progress bar
            loop.set_postfix(train_loss=running_loss / total, train_acc=100. * correct / total)

        train_loss = running_loss / len(train_loader)
        train_acc = 100. * correct / total

        ### Validation (Dev) Phase ###
        model.eval()  # Set model to evaluation mode
        dev_loss, dev_correct, dev_total = 0.0, 0, 0

        with torch.no_grad():
            for images, labels in dev_loader:
                images, labels = images.to(device), labels.to(device)

                outputs = model(images)
                loss = criterion(outputs, labels)

                dev_loss += loss.item()
                _, predicted = outputs.max(1)
                dev_total += labels.size(0)
                dev_correct += predicted.eq(labels).sum().item()

        dev_loss /= len(dev_loader)
        dev_acc = 100. * dev_correct / dev_total

        print(f"Epoch {epoch + 1}: Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%, Dev Loss: {dev_loss:.4f}, Dev Acc: {dev_acc:.2f}%")

    print("Training complete!")

In [72]:
def evaluate_model(model, test_loader, criterion, device):
    model.eval()  # Set the model to evaluation mode
    test_loss = 0.0
    correct = 0
    total = 0
    start_time = time.time()
    with torch.no_grad():  # No gradients for validation (Disable gradient calculations for efficiency)
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Update loss
            test_loss += loss.item()

            # Get predictions
            _, predicted = outputs.max(1)

            # Update metrics
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
    end_time = time.time()
    inference_time = end_time - start_time
    avg_loss = test_loss / len(test_loader)
    accuracy = 100. * correct / total

    #print(f"Test Loss: {avg_loss:.4f}, Test Accuracy: {accuracy:.2f}%")
    return avg_loss, accuracy, inference_time

# Test model before pruning

In [201]:
criterion = nn.CrossEntropyLoss()
# Results before pruning
bp_loss, bp_acc, bp_inf_time = evaluate_model(model, test_loader, criterion, device)
print("Results before pruning:")
print(f"Average test Loss: {bp_loss:.4f}, Test Accuracy: {bp_acc:.2f}%, Inference Time for {len(test_dataset)} images: {bp_inf_time:.2f} seconds")
example_inputs = torch.randn(1, 3, 224, 224).to(device)
bp_flops, bp_params = tp.utils.count_ops_and_params(model, example_inputs)
print(f"Original Model FLOPs: {bp_flops / 1e6:.2f} MFLOPs, Parameters: {bp_params / 1e6:.2f} M")

Results before pruning:
Average test Loss: 0.2305, Test Accuracy: 93.33%, Inference Time for 6000 images: 12.88 seconds
Original Model FLOPs: 319.39 MFLOPs, Parameters: 2.59 M


# Apply pruning using torch pruning library
https://pypi.org/project/torch-pruning/#why-torch-pruning

To install use: `pip install torch-pruning`


In [202]:
# List to store layers we want to ignore during pruning
ignored_layers = []

# Loop through all modules in model.features
for m in model.features.modules():
    # Check if the module is an InvertedResidual block with residual connection
    if isinstance(m, InvertedResidual) and m.use_res_connect:
        # Add all layers in this residual block to ignored_layers
        for name, layer in m.named_modules():
            if isinstance(layer, nn.Conv2d):
                ignored_layers.append(layer)

# Also ignore the classifier
ignored_layers.append(model.classifier)

# Now create pruner with ignored_layers
pruner = tp.pruner.MagnitudePruner(
    model,
    example_inputs=example_inputs,
    importance=tp.importance.MagnitudeImportance(p=1),
    global_pruning=True,
    pruning_ratio=0.8,
    ignored_layers=ignored_layers,
)

# Perform pruning
pruner.step()

In [203]:
print("Results before pruning:")
print(f"Average test Loss: {bp_loss:.4f}, Test Accuracy: {bp_acc:.2f}%, Inference Time for {len(test_dataset)} images: {bp_inf_time:.2f} seconds")
print(f"Original Model FLOPs: {bp_flops / 1e6:.2f} MFLOPs, Parameters: {bp_params / 1e6:.2f} M")
criterion = nn.CrossEntropyLoss()
ap_loss, ap_acc, ap_inf = evaluate_model(model, test_loader, criterion, device)
print("Results after pruning:")
print(f"Average test Loss: {ap_loss:.4f}, Test Accuracy: {ap_acc:.2f}%, Inference Time for {len(test_dataset)} images: {ap_inf:.2f} seconds")
# Calculate FLOPs and parameters after pruning
ap_flops, ap_params = tp.utils.count_ops_and_params(model, example_inputs)
print(f"Pruned Model FLOPs: {ap_flops / 1e6:.2f} MFLOPs, Parameters: {ap_params / 1e6:.2f} M")


Results before pruning:
Average test Loss: 0.2305, Test Accuracy: 93.33%, Inference Time for 6000 images: 12.88 seconds
Original Model FLOPs: 319.39 MFLOPs, Parameters: 2.59 M
Results after pruning:
Average test Loss: 4.2113, Test Accuracy: 10.00%, Inference Time for 6000 images: 8.57 seconds
Pruned Model FLOPs: 256.15 MFLOPs, Parameters: 1.30 M


In [205]:
# fine tune the pruned model
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)
# Train the pruned model
train_model(model, train_loader, dev_loader, criterion, optimizer, num_epochs=5)

Epoch 1/5: 100%|██████████| 782/782 [02:12<00:00,  5.90it/s, train_acc=87.2, train_loss=0.00654]


Epoch 1: Train Loss: 0.4183, Train Acc: 87.19%, Dev Loss: 0.3276, Dev Acc: 89.85%


Epoch 2/5: 100%|██████████| 782/782 [02:13<00:00,  5.85it/s, train_acc=90.1, train_loss=0.00511]


Epoch 2: Train Loss: 0.3265, Train Acc: 90.08%, Dev Loss: 0.3057, Dev Acc: 90.15%


Epoch 3/5: 100%|██████████| 782/782 [02:17<00:00,  5.70it/s, train_acc=91.2, train_loss=0.0045] 


Epoch 3: Train Loss: 0.2877, Train Acc: 91.22%, Dev Loss: 0.2780, Dev Acc: 91.03%


Epoch 4/5: 100%|██████████| 782/782 [02:13<00:00,  5.84it/s, train_acc=91.6, train_loss=0.00432]


Epoch 4: Train Loss: 0.2765, Train Acc: 91.57%, Dev Loss: 0.2862, Dev Acc: 91.35%


Epoch 5/5: 100%|██████████| 782/782 [02:11<00:00,  5.94it/s, train_acc=91.9, train_loss=0.00417]


Epoch 5: Train Loss: 0.2663, Train Acc: 91.95%, Dev Loss: 0.2549, Dev Acc: 91.70%
Training complete!


In [206]:
print("Results before pruning:")
print(f"Average test Loss: {bp_loss:.4f}, Test Accuracy: {bp_acc:.2f}%, Inference Time for {len(test_dataset)} images: {bp_inf_time:.2f} seconds")
print(f"Original Model FLOPs: {bp_flops / 1e6:.2f} MFLOPs, Parameters: {bp_params / 1e6:.2f} M")
print("Results after pruning:")
print(f"Average test Loss: {ap_loss:.4f}, Test Accuracy: {ap_acc:.2f}%, Inference Time for {len(test_dataset)} images: {ap_inf:.2f} seconds")
print(f"Pruned Model FLOPs: {ap_flops / 1e6:.2f} MFLOPs, Parameters: {ap_params / 1e6:.2f} M")
print("Results after fine-tuning:")
ft_loss, ft_acc, ft_inf = evaluate_model(model, test_loader, criterion, device)
print(f"Average test Loss: {ft_loss:.4f}, Test Accuracy: {ft_acc:.2f}%, Inference Time for {len(test_dataset)} images: {ft_inf:.2f} seconds")
# calculate FLOPs and parameters after fine-tuning
ft_flops, ft_params = tp.utils.count_ops_and_params(model, example_inputs)
print(f"Fine-tuned Model FLOPs: {ft_flops / 1e6:.2f} MFLOPs, Parameters: {ft_params / 1e6:.2f} M")

Results before pruning:
Average test Loss: 0.2305, Test Accuracy: 93.33%, Inference Time for 6000 images: 12.88 seconds
Original Model FLOPs: 319.39 MFLOPs, Parameters: 2.59 M
Results after pruning:
Average test Loss: 4.2113, Test Accuracy: 10.00%, Inference Time for 6000 images: 8.57 seconds
Pruned Model FLOPs: 256.15 MFLOPs, Parameters: 1.30 M
Results after fine-tuning:
Average test Loss: 0.2807, Test Accuracy: 91.53%, Inference Time for 6000 images: 12.36 seconds
Fine-tuned Model FLOPs: 256.15 MFLOPs, Parameters: 1.30 M


In [207]:
onnx_path = '../onnx_models/MobileNetV2_torch_pruning_80.onnx'
# Export the model to ONNX
torch.onnx.export(
    model=model,
    args=torch.randn(1, 3, 224, 224).to(device),
    f=onnx_path,
    input_names=['input'],
    output_names=['output'],
    export_params=True,
    opset_version=11
)
print(f"Model successfully exported to: {onnx_path}")

Model successfully exported to: ../onnx_models/MobileNetV2_torch_pruning_80.onnx


In [208]:
original_model_path = '../onnx_models/mobilenet_cifar10.onnx'
pruned_30_path = '../onnx_models/MobileNetV2_torch_pruning_30.onnx'
pruned_50_path = '../onnx_models/MobileNetV2_torch_pruning_50.onnx'
pruned_80_path = '../onnx_models/MobileNetV2_torch_pruning_80.onnx'

In [209]:
# compare sizes of the models
def get_model_size(model_path):
    return os.path.getsize(model_path) / (1024 * 1024)  # Size in MB

print(f"Original Model Size: {get_model_size(original_model_path):.2f} MB")
print(f"Pruned Model (30%) Size: {get_model_size(pruned_30_path):.2f} MB")
print(f"Pruned Model (50%) Size: {get_model_size(pruned_50_path):.2f} MB")
print(f"Pruned Model (80%) Size: {get_model_size(pruned_80_path):.2f} MB")

Original Model Size: 9.88 MB
Pruned Model (30%) Size: 7.57 MB
Pruned Model (50%) Size: 6.31 MB
Pruned Model (80%) Size: 4.97 MB
