# Introduction to Model Efficiency

### Why Model Efficiency Matters

In the context of AI, efficiency is not just about speed; it's about making AI accessible and practical for everyday applications. The size and complexity of deep learning models have grown tremendously. While this leads to improved performance, it creates challenges for deployment, especially on mobile devices, embedded systems, or any setting where computational resources or power are limited. Model efficiency techniques aim to address these challenges while keeping any loss in accuracy as small as possible.

Growing model complexity poses challenges:
- Storage: Large models consume more storage space.
- Inference Speed: Complex models take longer to process inputs.
- Energy Consumption: Computationally demanding models drain batteries quickly.
- Deployment: Resource-constrained devices (smartphones, IoT) struggle with large models.

### Common Model Definition and Data Loading
Before diving into each model efficiency technique, let's establish a common setup: loading the CIFAR-10 dataset and defining a simple CNN model. This setup is shared by the Pruning (Section 1) and Quantization (Section 2) examples.

```python
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
import torchvision.models as models
import numpy as np
import os

# Simple CNN model for 3x32x32 CIFAR-10 images
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(16, 32, 5)
        self.fc1 = nn.Linear(32 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))  # 32x32 -> 28x28 -> 14x14
        x = self.pool(torch.relu(self.conv2(x)))  # 14x14 -> 10x10 -> 5x5
        x = x.view(-1, 32 * 5 * 5)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# CIFAR-10 data loading; Normalize maps each channel from [0, 1] to [-1, 1]
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)
```

```
Files already downloaded and verified
```


```python
model = SimpleCNN()
torch.save(model.state_dict(), "model_before_pruning.pth")
original_size = os.path.getsize("model_before_pruning.pth")

print(f"Original model size: {original_size} bytes")
```

```
Original model size: 488494 bytes
```

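Where does the 488,494-byte figure come from? A quick parameter count (our own sanity check, not part of the original notebook) suggests that almost all of it is the float32 weights: 121,182 parameters × 4 bytes = 484,728 bytes, with the remaining ~3.7 KB being serialization overhead in the `.pth` file. The snippet redefines `SimpleCNN` so that it runs on its own.

```python
import torch
import torch.nn as nn

# Same architecture as SimpleCNN above (redefined so this snippet is self-contained)
class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(16, 32, 5)
        self.fc1 = nn.Linear(32 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = torch.flatten(x, 1)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)

model = SimpleCNN()

# Parameters (weights + biases) per layer; the pooling layer has none, so drop it
per_layer = {name: sum(p.numel() for p in m.parameters()) for name, m in model.named_children()}
per_layer = {name: n for name, n in per_layer.items() if n > 0}

n_params = sum(per_layer.values())
raw_bytes = sum(p.numel() * p.element_size() for p in model.parameters())  # 4 bytes per float32

print(per_layer)            # {'conv1': 1216, 'conv2': 12832, 'fc1': 96120, 'fc2': 10164, 'fc3': 850}
print(n_params, raw_bytes)  # 121182 484728
```

The fully connected layer `fc1` alone accounts for roughly 79% of the parameters, which is why the later examples that target linear layers have the largest effect on file size.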

## 1. Pruning
Pruning identifies and removes the weights of a neural network that have the least impact on its output, producing a sparser, more compact model.
Unstructured pruning zeroes out individual low-importance weights, while structured pruning is coarser and removes whole structures such as convolutional filters, channels, or neurons.

Benefits:
- Reduces model size
- Can improve inference speed
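To make the two granularities described above concrete, here is a from-scratch sketch of magnitude-based pruning masks (our own illustration on a random tensor shaped like `SimpleCNN.conv1.weight`). It mirrors the idea behind PyTorch's `prune.l1_unstructured` and `prune.ln_structured` (with `n=2, dim=0`) but is not their implementation: the unstructured mask zeroes the individual entries with the smallest absolute value, while the structured mask zeroes whole output channels with the smallest L2 norm. In both cases the tensor keeps its shape; only its values change.

```python
import torch

def unstructured_l1_mask(weight: torch.Tensor, amount: float) -> torch.Tensor:
    """Mask that zeroes the `amount` fraction of entries with the smallest |w|."""
    k = int(round(amount * weight.numel()))
    mask = torch.ones_like(weight)
    if k > 0:
        # Indices of the k smallest-magnitude entries in the flattened tensor
        idx = torch.topk(weight.abs().flatten(), k, largest=False).indices
        mask.view(-1)[idx] = 0
    return mask

def structured_l2_mask(weight: torch.Tensor, amount: float) -> torch.Tensor:
    """Mask that zeroes whole output channels (dim 0) with the smallest L2 norm."""
    k = int(round(amount * weight.shape[0]))
    norms = weight.flatten(1).norm(p=2, dim=1)  # one norm per output channel
    mask = torch.ones_like(weight)
    if k > 0:
        idx = torch.topk(norms, k, largest=False).indices
        mask[idx] = 0  # zero every entry of the selected channels
    return mask

torch.manual_seed(0)
w = torch.randn(16, 3, 5, 5)  # same shape as SimpleCNN.conv1.weight
w_unstructured = w * unstructured_l1_mask(w, 0.3)
w_structured = w * structured_l2_mask(w, 0.5)

print((w_unstructured == 0).sum().item())                          # 360 (of 1200 entries)
print((w_structured.flatten(1).abs().sum(dim=1) == 0).sum().item())  # 8 (of 16 channels)
print(w_unstructured.shape == w.shape == w_structured.shape)          # True
```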

<img src="./imgs/pruning.webp" alt="drawing" width="450"/>

PyTorch's `torch.nn.utils.prune` module supports both structured and unstructured pruning. The example below applies unstructured pruning to the first convolutional layer (`conv1`): `prune.l1_unstructured` zeroes the 30% of its weights with the smallest absolute value (L1 magnitude), and `prune.remove` then makes the change permanent. Since `model` has not been trained yet, this only demonstrates the mechanics.

```python
import torch.nn.utils.prune as prune

# Apply L1 (magnitude) unstructured pruning to the model's first convolutional layer
prune.l1_unstructured(model.conv1, name="weight", amount=0.3)
prune.remove(model.conv1, 'weight')  # Make pruning permanent

torch.save(model.state_dict(), "model_after_pruning.pth")
pruned_size = os.path.getsize("model_after_pruning.pth")

print(f"Model size after pruning: {pruned_size} bytes")
```

```
Model size after pruning: 488480 bytes
```

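How sparse does a pruned layer need to be before storing only its nonzero weights actually saves space? The sketch below (our own illustration, using an `nn.Linear(800, 120)` layer shaped like `fc1` and PyTorch's COO sparse format) compares dense float32 storage with COO storage. For a 2-D weight, COO keeps two int64 indices plus one float32 value per nonzero, i.e. 20 bytes per nonzero versus 4 bytes per dense entry, so it only wins above 80% sparsity.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

def dense_bytes(t: torch.Tensor) -> int:
    return t.numel() * t.element_size()

def coo_bytes(t: torch.Tensor) -> int:
    s = t.to_sparse().coalesce()  # COO: (ndim x nnz) int64 indices + nnz values
    return dense_bytes(s.indices()) + dense_bytes(s.values())

torch.manual_seed(0)
results = {}
for amount in (0.3, 0.9):
    layer = nn.Linear(800, 120)  # same weight shape as SimpleCNN.fc1
    prune.l1_unstructured(layer, name="weight", amount=amount)
    prune.remove(layer, "weight")
    w = layer.weight.detach()
    results[amount] = (dense_bytes(w), coo_bytes(w))
    print(f"sparsity {amount:.0%}: dense {results[amount][0]} B, COO {results[amount][1]} B")
```

At 30% sparsity COO is more than three times larger than the dense tensor; at 90% it is half the size. Formats such as CSR, or general-purpose compression of the checkpoint, shift this break-even point.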

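Unstructured pruning barely changes the file size: `prune` only sets weights to zero, and the saved state dict still stores every layer as a dense float32 tensor. Real savings require either a sparse or compressed storage format, or physically smaller layers. Structured pruning, which zeroes entire channels or filters instead of individual weights, is the natural route to the latter, because whole channels can later be cut out of the network. Note, however, that `prune.ln_structured` on its own still only zeroes the channels, so the checkpoint stays about the same size, as the output below shows.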

```python
# Apply L2-norm structured pruning: zero out the 50% of output channels (dim=0)
# with the smallest L2 norm in each convolutional layer
prune.ln_structured(model.conv1, name="weight", amount=0.5, n=2, dim=0)
prune.remove(model.conv1, 'weight')  # Make pruning permanent

prune.ln_structured(model.conv2, name="weight", amount=0.5, n=2, dim=0)
prune.remove(model.conv2, 'weight')  # Make pruning permanent

# Save the pruned model to disk
torch.save(model.state_dict(), "pruned_model.pth")
pruned_size = os.path.getsize("pruned_model.pth")

print(f"Model size after pruning: {pruned_size} bytes")
```

```
Model size after pruning: 488382 bytes
```

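To turn structured pruning into a real size reduction, the zeroed channels have to be physically removed. The sketch below (our own illustration on freshly initialized `conv1`/`conv2`-shaped layers, not code from the original notebook) keeps only the surviving output channels of `conv1` and the matching input channels of `conv2`. One assumption matters: `prune.ln_structured` prunes only the weight, so a "pruned" channel still outputs its bias; the sketch also zeroes those biases so that the thinner layers compute exactly the same result.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.prune as prune

torch.manual_seed(0)
conv1 = nn.Conv2d(3, 16, 5)   # same shapes as SimpleCNN.conv1 / conv2
conv2 = nn.Conv2d(16, 32, 5)

# Zero the 50% of conv1 output channels with the smallest L2 norm
prune.ln_structured(conv1, name="weight", amount=0.5, n=2, dim=0)
prune.remove(conv1, "weight")

# Surviving channels are those whose filter is not entirely zero
alive = conv1.weight.detach().flatten(1).abs().sum(dim=1) != 0
keep = alive.nonzero().flatten()
with torch.no_grad():
    conv1.bias[~alive] = 0.0  # assumption: also drop the bias of pruned channels

# Build thinner layers and copy the surviving weights into them
small_conv1 = nn.Conv2d(3, len(keep), 5)
small_conv2 = nn.Conv2d(len(keep), 32, 5)
with torch.no_grad():
    small_conv1.weight.copy_(conv1.weight[keep])
    small_conv1.bias.copy_(conv1.bias[keep])
    small_conv2.weight.copy_(conv2.weight[:, keep])  # drop the matching *input* channels
    small_conv2.bias.copy_(conv2.bias)

def conv_stack(c1, c2, x):
    """The conv/ReLU/pool part of SimpleCNN.forward."""
    x = F.max_pool2d(F.relu(c1(x)), 2)
    return F.max_pool2d(F.relu(c2(x)), 2)

def n_params(*layers):
    return sum(p.numel() for layer in layers for p in layer.parameters())

x = torch.randn(4, 3, 32, 32)
with torch.no_grad():
    same = torch.allclose(conv_stack(conv1, conv2, x), conv_stack(small_conv1, small_conv2, x), atol=1e-5)

print(len(keep), same)                                              # 8 True
print(n_params(conv1, conv2), "->", n_params(small_conv1, small_conv2))  # 14048 -> 7040
```

Because `conv2` keeps all 32 output channels here, `fc1` is unaffected; pruning `conv2`'s output channels as well would additionally require slicing the corresponding input features of `fc1`.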

#### Tradeoffs
- **Risk of Losing Important Information:** Aggressive pruning might remove weights that are important for the model’s accuracy, leading to a decrease in performance.
- **Need for Retraining:** Often, a pruned model will require fine-tuning or full retraining to restore or improve its accuracy post-pruning.
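A common pattern for the fine-tuning step is to train while the pruning mask is still attached and call `prune.remove` only afterwards. Until then, `prune` keeps the original tensor as the parameter `weight_orig` plus a `weight_mask` buffer and recomputes `weight = weight_orig * weight_mask` before every forward pass, so pruned entries stay at zero while the others are updated. A minimal sketch on a toy layer with random regression data (both assumptions, chosen only to keep it short):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.prune as prune

torch.manual_seed(0)
layer = nn.Linear(20, 5)
prune.l1_unstructured(layer, name="weight", amount=0.5)  # adds weight_orig + weight_mask

# After pruning, the trainable tensor is weight_orig; optimize it as usual
optimizer = torch.optim.SGD(layer.parameters(), lr=0.1)
x, y = torch.randn(64, 20), torch.randn(64, 5)
for _ in range(10):
    loss = F.mse_loss(layer(x), y)  # the pruning hook re-applies the mask here
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

prune.remove(layer, "weight")  # fold the mask into a plain `weight` parameter
n_zeros = (layer.weight == 0).sum().item()
print(n_zeros, "of", layer.weight.numel(), "weights are zero")  # 50 of 100 weights are zero
```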

## 2. Quantization
Quantization represents a model's weights (and, depending on the method, its activations) with lower-precision data types, for example 8-bit integers instead of 32-bit floating-point numbers. This shrinks the storage footprint of the quantized tensors by roughly 4× and, on hardware with efficient integer kernels, makes the model's arithmetic cheaper and faster.

Benefits:
- Reduces model size
- Accelerates computation
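The int8 conversion is usually an affine mapping: a real value $x$ is stored as $q = \mathrm{clamp}(\mathrm{round}(x/s) + z,\ q_{min},\ q_{max})$ and recovered as $\hat{x} = (q - z)\,s$, where the scale $s$ and zero point $z$ are derived from the observed value range. PyTorch's per-tensor quantized tensors use this scheme. The sketch below is our own minimal min/max version applied to a random weight matrix (an assumption); real quantization observers may choose the range differently.

```python
import torch

QMIN, QMAX = -128, 127  # signed int8 range

def affine_qparams(x: torch.Tensor):
    """Min/max per-tensor scale and zero point (the range is widened to include 0)."""
    x_min = min(x.min().item(), 0.0)
    x_max = max(x.max().item(), 0.0)
    scale = (x_max - x_min) / (QMAX - QMIN)
    zero_point = int(round(QMIN - x_min / scale))
    return scale, zero_point

def quantize(x: torch.Tensor, scale: float, zero_point: int) -> torch.Tensor:
    return torch.clamp(torch.round(x / scale) + zero_point, QMIN, QMAX).to(torch.int8)

def dequantize(q: torch.Tensor, scale: float, zero_point: int) -> torch.Tensor:
    return (q.to(torch.float32) - zero_point) * scale

torch.manual_seed(0)
w = torch.randn(84, 120)  # shaped like SimpleCNN.fc2.weight
scale, zp = affine_qparams(w)
q = quantize(w, scale, zp)
w_hat = dequantize(q, scale, zp)

max_err = (w - w_hat).abs().max().item()  # rounding error is at most half a step (scale / 2)
print(f"scale={scale:.5f}, zero_point={zp}, max error={max_err:.5f}")
print(q.element_size(), "byte per weight instead of", w.element_size())  # 1 byte per weight instead of 4

# Cross-check against PyTorch's built-in per-tensor affine quantization
w_ref = torch.quantize_per_tensor(w, scale, zp, torch.qint8).dequantize()
print(torch.allclose(w_ref, w_hat, atol=scale + 1e-5))
```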

<img src="./imgs/quantization.jpeg" alt="drawing" width="600"/>

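Below, we apply quantization to the same model as before (which now includes the pruning changes). `torch.quantization.quantize_dynamic` converts the weights of the `nn.Linear` layers to int8 ahead of time and quantizes their activations on the fly at inference time. PyTorch's dynamic quantization does not support `nn.Conv2d`, so the convolutional layers remain in float32; quantizing them requires static quantization or QAT.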

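```python
import platform

# Select the quantized kernel backend: 'qnnpack' targets ARM CPUs, 'fbgemm' targets x86 CPUs
preferred_engine = 'qnnpack' if platform.machine().lower() in ('arm64', 'aarch64') else 'fbgemm'
if preferred_engine in torch.backends.quantized.supported_engines:
    torch.backends.quantized.engine = preferred_engine

model.eval()  # Ensure the model is in evaluation mode for quantization
# Dynamic quantization supports nn.Linear (and recurrent layers) but not nn.Conv2d,
# so only the fully connected layers are converted to int8
quantized_model = torch.quantization.quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)

torch.save(quantized_model.state_dict(), "model_after_quantization.pth")
quantized_size = os.path.getsize("model_after_quantization.pth")

print(f"Model size after quantization: {quantized_size} bytes")
```

```
Model size after quantization: 169850 bytes
```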

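A quick way to confirm which layers `quantize_dynamic` actually converted, and how far the quantized outputs drift from the float model, is to inspect the module types and compare predictions on a random batch. This sketch (our own, using a freshly initialized copy of `SimpleCNN` and random inputs as assumptions) shows the convolutions staying as float `nn.Conv2d` modules while the `nn.Linear` layers are replaced by dynamically quantized ones. That is consistent with the 169,850-byte checkpoint above: about 107 KB of int8 `fc` weights plus about 56 KB of float32 convolution parameters, plus serialization overhead.

```python
import torch
import torch.nn as nn

# Same architecture as SimpleCNN above (redefined so this snippet is self-contained)
class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(16, 32, 5)
        self.fc1 = nn.Linear(32 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = torch.flatten(x, 1)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)

# Make sure some quantized kernel backend is active (e.g. fbgemm/x86 or qnnpack)
available = [e for e in torch.backends.quantized.supported_engines if e != "none"]
if available and torch.backends.quantized.engine not in available:
    torch.backends.quantized.engine = available[0]

torch.manual_seed(0)
float_model = SimpleCNN().eval()
q_model = torch.quantization.quantize_dynamic(float_model, {nn.Linear}, dtype=torch.qint8)

# Convolutions keep their float module class; linear layers come from a quantized namespace
print(type(q_model.conv1).__module__, type(q_model.fc1).__module__)

x = torch.randn(8, 3, 32, 32)
with torch.no_grad():
    max_diff = (float_model(x) - q_model(x)).abs().max().item()
print(f"max |float - quantized| logit difference: {max_diff:.5f}")
```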

#### Tradeoffs:
- **Potential Accuracy Loss:** Reducing precision can lead to a loss of model fidelity, particularly if not managed carefully with techniques like Quantization-Aware Training (QAT).
- **Hardware Dependencies:** The benefits of quantization may depend on specific hardware capabilities, limiting its effectiveness across diverse deployment scenarios.
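Quantization-Aware Training addresses the first trade-off by simulating quantization during training so that the weights learn to tolerate it. Its core mechanism is "fake quantization": values are quantized and immediately dequantized in the forward pass, while gradients pass straight through (a straight-through estimator). The sketch below uses a hypothetical `FakeQuantLinear` built on `torch.fake_quantize_per_tensor_affine`, with a symmetric per-tensor int8 scale and random regression data as our own assumptions. PyTorch's full QAT workflow (`torch.ao.quantization.prepare_qat` followed by `convert`) automates this with observers and is not shown here.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FakeQuantLinear(nn.Linear):
    """Hypothetical QAT-style Linear: weights are fake-quantized to int8 in forward."""

    def forward(self, x):
        # Symmetric per-tensor scale so that max|w| maps to 127 (an assumption)
        scale = (self.weight.detach().abs().max() / 127).item()
        # Quantize + dequantize; the backward pass treats this as identity within range
        w_q = torch.fake_quantize_per_tensor_affine(self.weight, scale, 0, -128, 127)
        return F.linear(x, w_q, self.bias)

torch.manual_seed(0)
layer = FakeQuantLinear(20, 5)
x, y = torch.randn(64, 20), torch.randn(64, 5)
optimizer = torch.optim.SGD(layer.parameters(), lr=0.1)

losses = []
for _ in range(20):
    loss = F.mse_loss(layer(x), y)
    optimizer.zero_grad()
    loss.backward()  # gradients reach self.weight through the fake-quant op
    optimizer.step()
    losses.append(loss.item())

print(f"loss: {losses[0]:.4f} -> {losses[-1]:.4f}")
```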

## 3. Knowledge Distillation
Knowledge distillation trains a smaller (student) model to replicate the behavior of a larger (teacher) model. The teacher produces "soft labels": class probability distributions, typically softened with a temperature T > 1 so that the relative probabilities the teacher assigns to the incorrect classes become visible. The student is trained to match these soft labels in addition to the dataset's original hard labels.

Benefits:
- Compresses knowledge into a smaller, more efficient model
- Potential for higher accuracy than training the student directly on the dataset
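The effect of the temperature on soft labels can be seen with a few made-up teacher logits (a hypothetical example, not data from the notebook). At T = 1 the top class takes almost all of the probability mass; dividing the logits by T = 5 (the default used by `train_with_distillation` below) spreads mass onto the other classes, exposing which wrong classes the teacher considers similar.

```python
import torch
import torch.nn.functional as F

# Hypothetical teacher logits for one image over four classes
logits = torch.tensor([6.0, 2.0, 1.0, -1.0])

# Softmax at temperature T: divide the logits by T before normalizing
soft = {T: F.softmax(logits / T, dim=0) for T in (1.0, 5.0)}
for T, probs in soft.items():
    print(f"T={T}:", [round(p, 3) for p in probs.tolist()])
```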

<img src="./imgs/kd.jpeg" alt="drawing" width="600"/>

Below is a simplified example of how to set this up in PyTorch. In this example, we use a pretrained ResNet18 model from torchvision as the teacher model, demonstrating knowledge distillation to a simpler student model on the CIFAR-10 dataset. The ResNet18 model is pretrained on ImageNet, so we'll adapt it to work with CIFAR-10.

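The teacher model is put in evaluation mode so that layers such as BatchNorm use their stored statistics. Its weights are not updated during distillation because its forward pass runs under `torch.no_grad()` and the optimizer only receives the student's parameters.
The student is trained with SGD (learning rate 0.001, momentum 0.9) on a weighted sum of two terms: a KL-divergence loss between the temperature-softened student and teacher outputs, and the standard cross-entropy loss on the true labels.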

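For reference, the objective implemented by `distillation_loss` in the cell below can be written as follows, where $z_s$ and $z_t$ are the student and teacher logits, $\sigma$ is the softmax, $T$ is the temperature, $\alpha$ weights the soft term, $y$ is the true label, and the loss is averaged over the batch:

$$
\mathcal{L} = \alpha\, T^2\, \mathrm{KL}\big(\sigma(z_t / T) \,\|\, \sigma(z_s / T)\big) + (1 - \alpha)\, \mathrm{CE}\big(\sigma(z_s),\, y\big)
$$

The $T^2$ factor compensates for the $1/T^2$ scaling of the soft-term gradients, so the two terms stay comparable as $T$ changes.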
```python
# Data preparation: random subsets of CIFAR-10 keep this demo fast
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainset = Subset(trainset, np.random.choice(len(trainset), 2000, replace=False))
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testset = Subset(testset, np.random.choice(len(testset), 500, replace=False))
testloader = DataLoader(testset, batch_size=64, shuffle=False)

# Define the student model (padding=2 preserves spatial size in each conv: 32x32 -> 16x16 -> 8x8 after pooling)
class StudentCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, 5, padding=2)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(16, 32, 5, padding=2)
        self.fc1 = nn.Linear(32 * 8 * 8, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 32 * 8 * 8)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Define the teacher model: ResNet18 pretrained on ImageNet
# (`weights=` needs torchvision >= 0.13; older versions use `pretrained=True`)
teacher_model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
teacher_model.fc = nn.Linear(teacher_model.fc.in_features, 10)  # new, randomly initialized 10-class head

# Standard supervised training with cross-entropy (no distillation)
def train_model(model, dataloader, epochs=1):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
    model.train()
    for epoch in range(epochs):
        for inputs, labels in dataloader:
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

# Distillation loss: alpha-weighted KL term on temperature-softened outputs
# plus (1 - alpha)-weighted cross-entropy on the true labels
def distillation_loss(outputs, teacher_outputs, labels, T, alpha):
    soft_loss = F.kl_div(
        F.log_softmax(outputs / T, dim=1),
        F.softmax(teacher_outputs / T, dim=1),
        reduction='batchmean',
    ) * (T * T * alpha)  # T^2 keeps soft-target gradients on the same scale as the hard loss
    hard_loss = F.cross_entropy(outputs, labels) * (1. - alpha)
    return soft_loss + hard_loss

def train_with_distillation(student_model, teacher_model, dataloader, T=5.0, alpha=0.5, epochs=1):
    optimizer = optim.SGD(student_model.parameters(), lr=0.001, momentum=0.9)  # student parameters only
    teacher_model.eval()
    student_model.train()
    for epoch in range(epochs):
        for inputs, labels in dataloader:
            student_outputs = student_model(inputs)
            with torch.no_grad():  # the teacher is frozen: no gradients flow into it
                teacher_outputs = teacher_model(inputs)
            loss = distillation_loss(student_outputs, teacher_outputs, labels, T, alpha)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

# Fine-tune the teacher on CIFAR-10 first: its new head is untrained,
# so without this step its soft labels would carry no useful information
train_model(teacher_model, trainloader)

# Baseline: a student trained without distillation (kept for comparison)
baseline_student = StudentCNN()
train_model(baseline_student, trainloader)

# A fresh student trained with distillation
student_model = StudentCNN()
train_with_distillation(student_model, teacher_model, trainloader)

# Save models and compare sizes
torch.save(student_model.state_dict(), "student_model.pth")
student_model_size = os.path.getsize("student_model.pth")

torch.save(teacher_model.state_dict(), "teacher_model.pth")
teacher_model_size = os.path.getsize("teacher_model.pth")

print(f'Student model size: {student_model_size} bytes')
print(f'Teacher model size: {teacher_model_size} bytes')
```

```
Files already downloaded and verified
Files already downloaded and verified
Student model size: 1087436 bytes
Teacher model size: 44804402 bytes
```

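The cell above defines `testloader` but never uses it, so the benefit of distillation is not actually measured. A minimal accuracy helper could look like the sketch below (a hypothetical `evaluate` function of our own, checked here on a toy model rather than on CIFAR-10).

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

@torch.no_grad()
def evaluate(model: torch.nn.Module, dataloader) -> float:
    """Top-1 accuracy of `model` over `dataloader`."""
    was_training = model.training
    model.eval()  # disable dropout, use running BatchNorm statistics
    correct = total = 0
    for inputs, labels in dataloader:
        preds = model(inputs).argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.numel()
    model.train(was_training)  # restore the previous mode
    return correct / total

# Tiny self-check: a linear "model" that predicts class 0 for positive inputs, class 1 otherwise
toy = torch.nn.Linear(1, 2)
with torch.no_grad():
    toy.weight.copy_(torch.tensor([[1.0], [-1.0]]))
    toy.bias.zero_()
x = torch.tensor([[1.0], [2.0], [-1.0], [-3.0]])
y = torch.tensor([0, 0, 1, 0])  # the last label deliberately disagrees with the model
acc = evaluate(toy, DataLoader(TensorDataset(x, y), batch_size=2))
print(acc)  # 0.75
```

With the notebook's objects, `evaluate(student_model, testloader)` and `evaluate(teacher_model, testloader)` would report test accuracy; the numbers will vary from run to run because the training and test subsets are drawn without a fixed seed.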

#### Tradeoffs:
- **Dependence on Teacher Quality:** The success of distillation heavily relies on the quality and performance of the teacher model.
- **Complex Training Process:** Distillation can add complexity to the training process, requiring careful tuning of parameters like temperature and loss balancing.

## Conclusion

Implementing these model efficiency techniques involves a balance between improving computational efficiency and managing potential impacts on model accuracy. Each technique has its place and benefits, but also drawbacks that need careful consideration. Depending on the specific requirements of your application—whether it's speed, size, cost, or accuracy—these techniques can be powerful tools in the machine learning practitioner's toolkit.