<a href="https://colab.research.google.com/github/Abdelrahman188/Web-development/blob/master/Channel_prunning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Pruning**

# Setup

First, install the required packages and download the datasets and pretrained model. Here we use CIFAR10 dataset and VGG network which is the same as what we used in the Lab 0 tutorial.

In [1]:
print('Installing torchprofile...')
!pip install torchprofile 1>/dev/null
print('All required packages have been successfully installed!')

Installing torchprofile...
All required packages have been successfully installed!


In [2]:
import copy
import math
import random
import time
from collections import OrderedDict, defaultdict
from typing import Union, List

import numpy as np
import torch
from matplotlib import pyplot as plt
from torch import nn
from torch.optim import *
from torch.optim.lr_scheduler import *
from torch.utils.data import DataLoader
from torchprofile import profile_macs
from torchvision.datasets import *
from torchvision.transforms import *
from tqdm.auto import tqdm

from torchprofile import profile_macs

assert torch.cuda.is_available(), \
"The current runtime does not have CUDA support." \
"Please go to menu bar (Runtime - Change runtime type) and select GPU"

In [3]:
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)

<torch._C.Generator at 0x7d9b36294cf0>

In [4]:
def download_url(url, model_dir='.', overwrite=False):
    import os, sys, ssl
    from urllib.request import urlretrieve
    ssl._create_default_https_context = ssl._create_unverified_context
    target_dir = url.split('/')[-1]
    model_dir = os.path.expanduser(model_dir)
    try:
        if not os.path.exists(model_dir):
            os.makedirs(model_dir)
        model_dir = os.path.join(model_dir, target_dir)
        cached_file = model_dir
        if not os.path.exists(cached_file) or overwrite:
            sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
            urlretrieve(url, cached_file)
        return cached_file
    except Exception as e:
        # remove lock file so download can be executed next time.
        os.remove(os.path.join(model_dir, 'download.lock'))
        sys.stderr.write('Failed to download from url %s' % url + '\n' + str(e) + '\n')
        return None

In [39]:
class VGG(nn.Module):
  ARCH = [64, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M']

  def __init__(self) -> None:
    super().__init__()

    layers = []
    counts = defaultdict(int)

    def add(name: str, layer: nn.Module) -> None:
      layers.append((f"{name}{counts[name]}", layer))
      counts[name] += 1

    in_channels = 3
    for x in self.ARCH:
      if x != 'M':
        # conv-bn-relu
        add("conv", nn.Conv2d(in_channels, x, 3, padding=1, bias=False))
        add("bn", nn.BatchNorm2d(x))
        add("relu", nn.ReLU(True))
        in_channels = x
      else:
        # maxpool
        add("pool", nn.MaxPool2d(2))

    self.backbone = nn.Sequential(OrderedDict(layers))
    self.classifier = nn.Linear(512, 10)

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    # backbone: [N, 3, 32, 32] => [N, 512, 2, 2]
    x = self.backbone(x)

    # avgpool: [N, 512, 2, 2] => [N, 512]
    x = x.mean([2, 3])

    # classifier: [N, 512] => [N, 10]
    x = self.classifier(x)
    return x

In [46]:
def train(
  model: nn.Module,
  dataloader: DataLoader,
  criterion: nn.Module,
  optimizer: Optimizer,
  scheduler: LambdaLR,
  callbacks = None
) -> None:
  model.train()

  for inputs, targets in tqdm(dataloader, desc='train', leave=False):
    # Move the data from CPU to GPU
    inputs = inputs.cuda()
    targets = targets.cuda()

    # Reset the gradients (from the last iteration)
    optimizer.zero_grad()

    # Forward inference
    outputs = model(inputs)
    loss = criterion(outputs, targets)

    # Backward propagation
    loss.backward()
    enforce_sparsity(pruned_model)

    # Update optimizer and LR scheduler
    optimizer.step()
    scheduler.step()

    if callbacks is not None:
        for callback in callbacks:
            callback()

In [47]:
@torch.inference_mode()
def evaluate(
  model: nn.Module,
  dataloader: DataLoader,
  verbose=True,
) -> float:
  model.eval()

  num_samples = 0
  num_correct = 0

  for inputs, targets in tqdm(dataloader, desc="eval", leave=False,
                              disable=not verbose):
    # Move the data from CPU to GPU
    inputs = inputs.cuda()
    targets = targets.cuda()

    # Inference
    outputs = model(inputs)

    # Convert logits to class indices
    outputs = outputs.argmax(dim=1)

    # Update metrics
    num_samples += targets.size(0)
    num_correct += (outputs == targets).sum()

  return (num_correct / num_samples * 100).item()

Helper Functions (Flops, Model Size calculation, etc.)

In [48]:
def get_model_macs(model, inputs) -> int:
    return profile_macs(model, inputs)


def get_sparsity(tensor: torch.Tensor) -> float:
    """
    calculate the sparsity of the given tensor
        sparsity = #zeros / #elements = 1 - #nonzeros / #elements
    """
    return 1 - float(tensor.count_nonzero()) / tensor.numel()


def get_model_sparsity(model: nn.Module) -> float:
    """
    calculate the sparsity of the given model
        sparsity = #zeros / #elements = 1 - #nonzeros / #elements
    """
    num_nonzeros, num_elements = 0, 0
    for param in model.parameters():
        num_nonzeros += param.count_nonzero()
        num_elements += param.numel()
    return 1 - float(num_nonzeros) / num_elements

def get_num_parameters(model: nn.Module, count_nonzero_only=False) -> int:
    """
    calculate the total number of parameters of model
    :param count_nonzero_only: only count nonzero weights
    """
    num_counted_elements = 0
    for param in model.parameters():
        if count_nonzero_only:
            num_counted_elements += param.count_nonzero()
        else:
            num_counted_elements += param.numel()
    return num_counted_elements


def get_model_size(model: nn.Module, data_width=32, count_nonzero_only=False) -> int:
    """
    calculate the model size in bits
    :param data_width: #bits per element
    :param count_nonzero_only: only count nonzero weights
    """
    return get_num_parameters(model, count_nonzero_only) * data_width

Byte = 8
KiB = 1024 * Byte
MiB = 1024 * KiB
GiB = 1024 * MiB

Define misc functions for verification.

Load the pretrained model and the CIFAR-10 dataset.

In [12]:
def model_to_drive(url):
  from google.colab import drive
  import torch
  import requests
  drive.mount('/content/drive')
  checkpoint_url = url
# Download the checkpoint file
  response = requests.get(checkpoint_url)
  checkpoint_path = '/content/drive/MyDrive/vgg.cifar.pretrained.pth'  # Path in Google Drive
  with open(checkpoint_path, 'wb') as f:
      f.write(response.content)
model_to_drive("https://hanlab18.mit.edu/files/course/labs/vgg.cifar.pretrained.pth")

Mounted at /content/drive


In [13]:
checkpoint_path = '/content/drive/MyDrive/vgg.cifar.pretrained.pth'
checkpoint = torch.load(checkpoint_path, map_location="cpu")
model = VGG().cuda()
print(f"=> loading checkpoint '{checkpoint_path}'")
model.load_state_dict(checkpoint['state_dict'])
recover_model = lambda: model.load_state_dict(checkpoint['state_dict'])

  checkpoint = torch.load(checkpoint_path, map_location="cpu")


=> loading checkpoint '/content/drive/MyDrive/vgg.cifar.pretrained.pth'


In [14]:
image_size = 32
transforms = {
    "train": Compose([
        RandomCrop(image_size, padding=4),
        RandomHorizontalFlip(),
        ToTensor(),
    ]),
    "test": ToTensor(),
}
dataset = {}
for split in ["train", "test"]:
  dataset[split] = CIFAR10(
    root="data/cifar10",
    train=(split == "train"),
    download=True,
    transform=transforms[split],
  )
dataloader = {}
for split in ['train', 'test']:
  dataloader[split] = DataLoader(
    dataset[split],
    batch_size=512,
    shuffle=(split == 'train'),
    num_workers=0,
    pin_memory=True,
  )

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to data/cifar10/cifar-10-python.tar.gz


100%|██████████| 170M/170M [00:13<00:00, 13.0MB/s]


Extracting data/cifar10/cifar-10-python.tar.gz to data/cifar10
Files already downloaded and verified


In [92]:
pip install tensorboard




In [96]:
from torch.utils.tensorboard import SummaryWriter


In [97]:
writer = SummaryWriter()

# Let's First Evaluate the Accuracy and Model Size of Dense Model

Neural networks have become ubiquitous in many applications. Here we have loaded a pretrained VGG model for classifying images in CIFAR10 dataset.

Let's first evaluate the accuracy and model size of this model.

In [153]:
dense_model_accuracy = evaluate(model, dataloader['test'])
dense_model_size = get_model_size(model)
print(f"dense model has accuracy={dense_model_accuracy:.2f}%")
print(f"dense model has size={dense_model_size/MiB:.2f} MiB")

eval:   0%|          | 0/20 [00:00<?, ?it/s]

dense model has accuracy=92.95%
dense model has size=35.20 MiB


# Let's see the distribution of weight values

In [127]:
recover_model()
dense_model_accuracy = evaluate(model, dataloader['test'])
print(f"dense model has accuracy={dense_model_accuracy:.2f}%")

eval:   0%|          | 0/20 [00:00<?, ?it/s]

dense model has accuracy=92.95%


In [128]:
from torch.nn.utils import prune


In [129]:
def check_sparsity(model):
    total_params = 0
    zero_params = 0

    for name, module in model.named_modules():
        if hasattr(module, 'weight'):
            total_params += module.weight.numel()
            zero_params += torch.sum(module.weight == 0).item()

    sparsity = 100.0 * zero_params / total_params if total_params > 0 else 0
    print(f"Model Sparsity: {sparsity:.2f}%")

    return sparsity


In [130]:
pruned_model=copy.deepcopy(model)

In [131]:
def apply_channel_pruning(model, sparsity, n):
  for name, module in model.named_modules():
    if isinstance(module, torch.nn.Conv2d):
      prune.ln_structured(module, name="weight", amount=sparsity, n=n, dim=0)

In [132]:
apply_channel_pruning(pruned_model,0.6,1)

In [133]:
for name, module in pruned_model.named_modules():
    if isinstance(module, nn.Conv2d):
        print(f"{name} pruned: {torch.sum(module.weight == 0).item()} channels set to zero")

backbone.conv0 pruned: 1026 channels set to zero
backbone.conv1 pruned: 44352 channels set to zero
backbone.conv2 pruned: 177408 channels set to zero
backbone.conv3 pruned: 354816 channels set to zero
backbone.conv4 pruned: 707328 channels set to zero
backbone.conv5 pruned: 1414656 channels set to zero
backbone.conv6 pruned: 1414656 channels set to zero
backbone.conv7 pruned: 1414656 channels set to zero


In [134]:
channel_model_accuracy = evaluate(pruned_model, dataloader['test'])
print(f"dense model has accuracy={channel_model_accuracy:.2f}%")

eval:   0%|          | 0/20 [00:00<?, ?it/s]

dense model has accuracy=10.00%


In [135]:
pruning_masks = {name: module.weight_mask.clone() for name, module in pruned_model.named_modules() if isinstance(module, nn.Conv2d)}

In [136]:
check_sparsity(pruned_model)

Model Sparsity: 59.93%


59.929955775234134

In [137]:
def enforce_sparsity(model):
    """Ensure pruned weights remain zero after optimizer updates."""
    for name, module in model.named_modules():
        if isinstance(module, (nn.Conv2d, nn.Linear)) and hasattr(module, 'weight_mask'):
            module.weight.data *= module.weight_mask  # Zero-out pruned weights


In [138]:
num_finetune_epochs = 5
optimizer = torch.optim.SGD(pruned_model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, num_finetune_epochs)
criterion = nn.CrossEntropyLoss()
best_accuracy = 0
for epoch in range(num_finetune_epochs):
    train(pruned_model, dataloader['train'], criterion, optimizer, scheduler)
    accuracy = evaluate(pruned_model, dataloader['test'])
    is_best = accuracy > best_accuracy
    if is_best:
        best_accuracy = accuracy
    print(f'Epoch {epoch+1} Accuracy {accuracy:.2f}% / Best Accuracy: {best_accuracy:.2f}%')
    writer.add_scalar('Test Accuracy', accuracy, epoch)

train:   0%|          | 0/98 [00:00<?, ?it/s]

eval:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch 1 Accuracy 87.62% / Best Accuracy: 87.62%


train:   0%|          | 0/98 [00:00<?, ?it/s]

eval:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch 2 Accuracy 88.95% / Best Accuracy: 88.95%


train:   0%|          | 0/98 [00:00<?, ?it/s]

eval:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch 3 Accuracy 89.57% / Best Accuracy: 89.57%


train:   0%|          | 0/98 [00:00<?, ?it/s]

eval:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch 4 Accuracy 89.63% / Best Accuracy: 89.63%


train:   0%|          | 0/98 [00:00<?, ?it/s]

eval:   0%|          | 0/20 [00:00<?, ?it/s]

Epoch 5 Accuracy 90.10% / Best Accuracy: 90.10%


In [139]:
evaluate(pruned_model, dataloader['test'])

eval:   0%|          | 0/20 [00:00<?, ?it/s]

90.0999984741211

In [140]:
for name, module in pruned_model.named_modules():
    if isinstance(module, nn.Conv2d):
        print(f"{name} pruned: {torch.sum(module.weight == 0).item()} channels set to zero")

backbone.conv0 pruned: 1026 channels set to zero
backbone.conv1 pruned: 44352 channels set to zero
backbone.conv2 pruned: 177408 channels set to zero
backbone.conv3 pruned: 354816 channels set to zero
backbone.conv4 pruned: 707328 channels set to zero
backbone.conv5 pruned: 1414656 channels set to zero
backbone.conv6 pruned: 1414656 channels set to zero
backbone.conv7 pruned: 1414656 channels set to zero


In [141]:
check_sparsity(pruned_model)

Model Sparsity: 59.93%


59.929955775234134

In [150]:
torch.save(pruned_model.state_dict(), "model_after.pth")

In [151]:

# Get size in MB
size_before = os.path.getsize("model_after.pth") / (1024 * 1024)
print(f"Model size after pruning: {size_before:.2f} MB")

Model size after pruning: 70.41 MB


In [152]:
get_model_size(model)

295307584

In [158]:
get_model_size(pruned_model,count_nonzero_only=True)/MiB

tensor(35.2034, device='cuda:0')

In [168]:
# Remove pruning masks and make changes permanent
for name, module in pruned_model.named_modules():
    if isinstance(module, nn.Conv2d) and hasattr(module, "weight_mask"):
        prune.remove(module, "weight")  # Removes the pruning mask

In [169]:
get_model_sparsity(pruned_model)

0.5991201905603616

In [170]:
check_sparsity(pruned_model)

Model Sparsity: 59.93%


59.929955775234134