Pretrained ResNet-18 Model: Modified for CIFAR-10 classification.
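Structured Pruning: Zeroes out 20% of the filters (those with the smallest L2 norm) in every convolutional layer except the downsample convolutions on the residual shortcuts. The filters are set to zero rather than physically removed, so the parameter count and the saved file size do not change.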
Fine-Tuning: Retrains the pruned model to recover accuracy.
Performance Evaluation: Measures accuracy, model size, and inference speed before & after pruning.

Before Pruning: The model had 89.84% accuracy on CIFAR-10.
After Pruning: Accuracy dropped to 88.44% in the recorded run below, because 20% of the filters in each pruned layer were zeroed.
After Fine-Tuning: Accuracy recovered to 91.88% in the recorded run below, slightly above the accuracy before pruning.

In [None]:
# Step 1: Install necessary libraries
!pip install torch torchvision



In [None]:
# Step 2: Import libraries
import os
import time

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.utils.prune as prune
import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10

In [None]:
# Step 3: Load a pre-trained ResNet-18 model
# `ResNet18_Weights.DEFAULT` selects torchvision's default pretrained checkpoint;
# it is downloaded on first use (see the download log in this cell's output).
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 133MB/s]


In [None]:
# Step 4: Modify the final layer for CIFAR-10 (10 classes)
# Swap the classifier head for a new linear layer that keeps the existing input
# width but produces 10 outputs.
model.fc = nn.Linear(in_features=model.fc.in_features, out_features=10)


In [None]:
# Step 5: Adjust input size for CIFAR-10 (32x32)
# Replace the stem convolution with a newly created 3x3, stride-1 convolution.
model.conv1 = nn.Conv2d(
    in_channels=3,
    out_channels=64,
    kernel_size=3,
    stride=1,
    padding=1,
    bias=False,
)
# Remove initial max pooling layer by turning it into a pass-through.
model.maxpool = nn.Identity()

In [None]:
# Step 6: Load CIFAR-10 dataset
# Convert images to tensors, then normalize each of the three channels with
# mean 0.5 and std 0.5.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# Both splits are stored under ./data and share the same preprocessing.
train_dataset = CIFAR10(
    root='./data', train=True, download=True, transform=transform
)
test_dataset = CIFAR10(
    root='./data', train=False, download=True, transform=transform
)

# Only the training batches are shuffled; the test order stays fixed.
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170M/170M [00:13<00:00, 12.5MB/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [None]:
# Step 7: Define loss function and optimizer
# Cross-entropy loss, optimized with momentum SGD over every model parameter.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(params=model.parameters(), lr=0.01, momentum=0.9)

In [None]:
# Step 8: Fine-tune the model on CIFAR-10
# Train on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

num_epochs = 5
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)

        # One optimization step: clear old gradients, forward, backward, update.
        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
    # Report the mean per-batch loss for this epoch.
    print(f"Epoch {epoch+1}, Loss: {running_loss/len(train_loader):.4f}")

Epoch 1, Loss: 0.7267
Epoch 2, Loss: 0.2908
Epoch 3, Loss: 0.1748
Epoch 4, Loss: 0.1025
Epoch 5, Loss: 0.0729


In [None]:
# Step 9: Evaluate the model before pruning
def evaluate_model(model, data_loader):
    """Return the top-1 accuracy (in percent) of ``model`` over ``data_loader``.

    The model is switched to eval mode and gradients are disabled for the pass.
    Batches are moved to the device that holds the model's parameters, so the
    function does not rely on a notebook-level ``device`` variable.

    Args:
        model: ``nn.Module`` mapping a batch of inputs to per-class scores.
        data_loader: Iterable yielding ``(images, labels)`` batches.

    Returns:
        Accuracy as a float between 0 and 100.

    Raises:
        ZeroDivisionError: If ``data_loader`` yields no samples.
    """
    model.eval()

    # Follow the model's own device; fall back to CPU for parameter-free modules.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")

    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in data_loader:
            images, labels = images.to(target_device), labels.to(target_device)
            outputs = model(images)
            # Index of the highest score per sample is the predicted class.
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = 100 * correct / total
    return accuracy

# Baseline: test-set accuracy of the trained model before any pruning is applied.
print("Evaluating model before pruning...")
accuracy_before = evaluate_model(model, test_loader)
print(f"Accuracy before pruning: {accuracy_before:.2f}%")

Evaluating model before pruning...
Accuracy before pruning: 89.84%


In [None]:
# Save model size BEFORE pruning
torch.save(model.state_dict(), "model_before_pruning.pth")
print(f"Original model size: {os.path.getsize('model_before_pruning.pth') / 1024:.2f} KB")


def measure_inference_time(model, data_loader, device):
    """Return the wall-clock seconds for one forward pass over ``data_loader``.

    Defined here, before its first use, so the notebook runs top to bottom on a
    fresh kernel.

    Args:
        model: ``nn.Module`` to time; it is switched to eval mode.
        data_loader: Iterable yielding ``(images, labels)`` batches; labels are ignored.
        device: Device (``torch.device`` or string) the batches are moved to.

    Returns:
        Elapsed time in seconds as a float.
    """
    model.eval()
    target = torch.device(device)
    on_cuda = target.type == "cuda"

    # CUDA kernels run asynchronously: wait for pending work before starting the
    # clock and again before stopping it, so the full forward passes are timed.
    if on_cuda:
        torch.cuda.synchronize(target)
    start_time = time.perf_counter()

    with torch.no_grad():
        for images, _ in data_loader:
            images = images.to(device)
            model(images)  # Forward pass

    if on_cuda:
        torch.cuda.synchronize(target)
    return time.perf_counter() - start_time


# Measure inference time BEFORE pruning
time_before = measure_inference_time(model, test_loader, device)
print(f"Inference Time Before Pruning: {time_before:.4f} seconds")


Original model size: 43726.85 KB
Inference Time Before Pruning: 2.6622 seconds


In [None]:
# Step 10: Count parameters before pruning
def count_parameters(model, nonzero_only=False):
    """Count the scalar entries of the trainable parameters of ``model``.

    Args:
        model: ``nn.Module`` whose parameters are counted.
        nonzero_only: When True, count only entries whose value is not zero.
            Pruning that zeroes weights without changing tensor shapes leaves
            the default count unchanged; this option makes its effect visible.

    Returns:
        The count as an int. Parameters with ``requires_grad=False`` are skipped.
    """
    trainable = (p for p in model.parameters() if p.requires_grad)
    if nonzero_only:
        return sum(int((p != 0).sum().item()) for p in trainable)
    return sum(p.numel() for p in trainable)

# Snapshot the trainable-parameter count now; later cells print it next to the post-pruning count.
params_before = count_parameters(model)

In [None]:
# Step 11: Structured pruning of convolutional filters
# For every Conv2d except the shortcut ("downsample") convolutions, zero the 20% of
# output filters (dim=0) with the smallest L2 norm (n=2).
for name, module in model.named_modules():
    if isinstance(module, torch.nn.Conv2d) and "downsample" not in name:
        prune.ln_structured(module, name='weight', amount=0.2, n=2, dim=0)  # Prune 20%
        # Fold the mask into `weight` and drop the pruning re-parametrization, so the
        # zeros are stored directly in the parameter.
        prune.remove(module, 'weight')  # Make pruning permanent

# NOTE: the pruned filters are zeroed in place, not removed. Tensor shapes, the
# parameter count and the size of the saved state_dict therefore stay the same.


In [None]:
# Save model size AFTER pruning
torch.save(model.state_dict(), "model_after_pruning.pth")
print(f"Pruned model size: {os.path.getsize('model_after_pruning.pth') / 1024:.2f} KB")

# Measure inference time AFTER pruning
time_after = measure_inference_time(model, test_loader, device)
print(f"Inference Time After Pruning: {time_after:.4f} seconds")

# Print parameter comparison
params_after = count_parameters(model)
print(f"Number of parameters before pruning: {params_before}")
print(f"Number of parameters after pruning: {params_after}")

# The two counts above are equal because pruning only zeroes weights and keeps the
# tensor shapes. Counting the non-zero entries shows how much was actually pruned.
nonzero_params_after = sum(
    int((p != 0).sum().item()) for p in model.parameters() if p.requires_grad
)
print(f"Non-zero parameters after pruning: {nonzero_params_after}")


Pruned model size: 43726.72 KB
Inference Time After Pruning: 2.6991 seconds
Number of parameters before pruning: 11173962
Number of parameters after pruning: 11173962


In [None]:
# Step 12: Evaluate the model after pruning
# Same test set and metric as the baseline, measured before any fine-tuning of the pruned model.
print("Evaluating model after pruning...")
accuracy_after = evaluate_model(model, test_loader)
print(f"Accuracy after pruning: {accuracy_after:.2f}%")

Evaluating model after pruning...
Accuracy after pruning: 88.44%


In [None]:
# Step 13: Count parameters after pruning
# NOTE(review): this repeats the parameter comparison already printed by the
# "AFTER pruning" stats cell above; the model has not changed in between.
params_after = count_parameters(model)
print(f"Number of parameters before pruning: {params_before}")
print(f"Number of parameters after pruning: {params_after}")

Number of parameters before pruning: 11173962
Number of parameters after pruning: 11173962


In [None]:
# Step 14: Fine-tune after pruning
print("Fine-tuning the pruned model...")
optimizer = optim.SGD(model.parameters(), lr=0.005, momentum=0.9)  # Lower learning rate for fine-tuning
num_finetune_epochs = 3

# prune.remove() made the pruning permanent, so no mask is attached to the layers
# any more and plain SGD would update the zeroed filters again. Record which output
# filters are entirely zero in the pruned conv layers and re-zero them after every
# optimizer step, so the fine-tuned model keeps its pruned structure.
with torch.no_grad():
    filter_masks = [
        (
            conv,
            (conv.weight.flatten(1).abs().sum(dim=1) != 0)
            .to(conv.weight.dtype)
            .view(-1, 1, 1, 1),
        )
        for conv_name, conv in model.named_modules()
        if isinstance(conv, nn.Conv2d) and "downsample" not in conv_name
    ]

for epoch in range(num_finetune_epochs):
    model.train()
    running_loss = 0.0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # Keep pruned filters at zero (the mask is 1 for kept filters, 0 for pruned ones).
        with torch.no_grad():
            for pruned_conv, keep_mask in filter_masks:
                pruned_conv.weight.mul_(keep_mask)
        running_loss += loss.item()
    print(f"Fine-Tuning Epoch {epoch+1}, Loss: {running_loss/len(train_loader):.4f}")

Fine-tuning the pruned model...
Fine-Tuning Epoch 1, Loss: 0.0191
Fine-Tuning Epoch 2, Loss: 0.0075
Fine-Tuning Epoch 3, Loss: 0.0052


In [None]:
# Step 15: Final evaluation after fine-tuning
# Same test set and metric as the two earlier evaluations, so the three numbers are comparable.
print("Final evaluation after fine-tuning...")
final_accuracy = evaluate_model(model, test_loader)
print(f"Final accuracy after fine-tuning: {final_accuracy:.2f}%")


Final evaluation after fine-tuning...
Final accuracy after fine-tuning: 91.88%


In [None]:
# Save model size after fine-tuning the pruned model
# NOTE(review): this reuses the path "model_after_pruning.pth", so it overwrites the
# checkpoint written right after pruning, and the printed labels still say
# "Pruned" / "After Pruning" although this cell runs after fine-tuning.
# Confirm whether that is intended.
torch.save(model.state_dict(), "model_after_pruning.pth")
print(f"Pruned model size: {os.path.getsize('model_after_pruning.pth') / 1024:.2f} KB")

# Measure inference time after fine-tuning
time_after = measure_inference_time(model, test_loader, device)
print(f"Inference Time After Pruning: {time_after:.4f} seconds")

# Print parameter comparison (count_parameters counts zero-valued weights as well)
params_after = count_parameters(model)
print(f"Number of parameters before pruning: {params_before}")
print(f"Number of parameters after pruning: {params_after}")


Pruned model size: 43726.72 KB
Inference Time After Pruning: 2.6802 seconds
Number of parameters before pruning: 11173962
Number of parameters after pruning: 11173962


In [None]:
import time
import os

# Step 1: Count Model Parameters (Already Done)
# Re-prints the counts computed in earlier cells; nothing is recomputed here.
print(f"Number of parameters before pruning: {params_before}")
print(f"Number of parameters after pruning: {params_after}")

# Step 2: Save the Model & Check File Size
# Writes the current model weights (pruned, then fine-tuned) to a separate checkpoint file.
torch.save(model.state_dict(), "model_pruned.pth")
print(f"Pruned model size: {os.path.getsize('model_pruned.pth') / 1024:.2f} KB")

# Step 3: Measure Inference Speed
def measure_inference_time(model, data_loader, device):
    """Return the wall-clock seconds for one forward pass over ``data_loader``.

    Args:
        model: ``nn.Module`` to time; it is switched to eval mode.
        data_loader: Iterable yielding ``(images, labels)`` batches; labels are ignored.
        device: Device (``torch.device`` or string) the batches are moved to.

    Returns:
        Elapsed time in seconds as a float.
    """
    model.eval()
    target = torch.device(device)
    on_cuda = target.type == "cuda"

    # CUDA kernels run asynchronously: wait for pending work before starting the
    # clock and again before stopping it, so the full forward passes are timed.
    if on_cuda:
        torch.cuda.synchronize(target)
    start_time = time.perf_counter()

    with torch.no_grad():
        for images, _ in data_loader:
            images = images.to(device)
            model(images)  # Forward pass

    if on_cuda:
        torch.cuda.synchronize(target)
    end_time = time.perf_counter()
    return end_time - start_time

# Measure inference time after pruning
# By this cell the model has also been fine-tuned, so this times the final model.
time_after = measure_inference_time(model, test_loader, device)
print(f"Inference Time After Pruning: {time_after:.4f} seconds")


Number of parameters before pruning: 11173962
Number of parameters after pruning: 11173962
Pruned model size: 43725.86 KB
Inference Time After Pruning: 2.6540 seconds
