In [None]:
from google.colab import drive
drive.mount("/content/drive")

Mounted at /content/drive


In [None]:
print('Installing torchprofile...')
!pip install torchprofile 1>/dev/null
print('All required packages have been successfully installed!')

Installing torchprofile...
All required packages have been successfully installed!


In [None]:
import copy
import math
import random
import time
from collections import OrderedDict, defaultdict
from typing import Union, List

import numpy as np
import torch
from matplotlib import pyplot as plt
from torch import nn
from torch.optim import *
from torch.optim.lr_scheduler import *
from torch.utils.data import DataLoader
from torchprofile import profile_macs
from torchvision.datasets import *
from torchvision.transforms import *
from tqdm.auto import tqdm

from torchprofile import profile_macs

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

class ModelPruner:
    def __init__(self, model: nn.Module, dataloader: dict, device: torch.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')):
        """
        Initialize the ModelPruner class.

        Parameters:
        - model: The neural network model to prune.
        - dataloader: A dictionary containing 'train' and 'test' DataLoader objects.
        - device: The device to run the model on ('cuda' or 'cpu').
        """
        self.model = model.to(device)
        self.dataloader = dataloader
        self.device = device

    def count_parameters(self, model: nn.Module, count_nonzero_only=False) -> int:
        """Count the parameters in the model."""
        return sum(p.numel() for p in model.parameters() if p.requires_grad and (p.count_nonzero() if count_nonzero_only else True))

    def get_model_size(self, model: nn.Module, count_nonzero_only=False) -> int:
        """Calculate the model size."""
        # Assuming 32 bits (4 bytes) per parameter
        return self.count_parameters(model, count_nonzero_only) * 4

    def evaluate(self, model: nn.Module) -> float:
        """Evaluate the model accuracy on the test dataset."""
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for data in self.dataloader['test']:
                inputs, labels = data
                inputs, labels = inputs.to(self.device), labels.to(self.device)
                outputs = model(inputs)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        return 100 * correct / total

    @staticmethod
    def create_pruning_mask(tensor: torch.Tensor, target_sparsity: float) -> torch.Tensor:
        """Create a binary mask for pruning."""
        tensor_flat = tensor.view(-1)
        threshold = torch.quantile(torch.abs(tensor_flat), target_sparsity)
        mask = torch.abs(tensor) > threshold
        return mask.float()

    def fine_grained_prune(self, target_sparsity: float):
        """Apply fine-grained pruning to the model."""
        for param in self.model.parameters():
            if param.requires_grad and param.dim() > 1:  # Typically for weights of conv and linear layers
                mask = self.create_pruning_mask(param.data, target_sparsity)
                param.data.mul_(mask)

    def fine_tune(self, epochs: int, lr: float = 0.01):
        """Fine-tune the pruned model."""
        optimizer = torch.optim.SGD(self.model.parameters(), lr=lr, momentum=0.9)
        criterion = nn.CrossEntropyLoss()
        self.model.train()
        for epoch in range(epochs):
            for data in self.dataloader['train']:
                inputs, labels = data
                inputs, labels = inputs.to(self.device), labels.to(self.device)
                optimizer.zero_grad()
                outputs = self.model(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

    def prune_and_fine_tune(self, target_sparsity: float, fine_tune_epochs: int):
        """Prune the model and then fine-tune it."""
        print("Initial Model Evaluation:")
        initial_accuracy = self.evaluate(self.model)
        initial_size = self.get_model_size(self.model)
        initial_params = self.count_parameters(self.model)
        print(f"Accuracy: {initial_accuracy:.2f}%, Size: {initial_size} bytes, Parameters: {initial_params}")

        print("\nPruning Model...")
        self.fine_grained_prune(target_sparsity)

        print("Model Evaluation After Pruning:")
        pruned_accuracy = self.evaluate(self.model)
        pruned_size = self.get_model_size(self.model, count_nonzero_only=True)
        pruned_params = self.count_parameters(self.model, count_nonzero_only=True)
        print(f"Accuracy: {pruned_accuracy:.2f}%, Size: {pruned_size} bytes, Parameters: {pruned_params}")

        print("\nFine-tuning Model...")
        self.fine_tune(fine_tune_epochs)

        print("Final Model Evaluation:")
        final_accuracy = self.evaluate(self.model)
        final_size = self.get_model_size(self.model, count_nonzero_only=True)
        final_params = self.count_parameters(self.model, count_nonzero_only=True)
        print(f"Accuracy: {final_accuracy:.2f}%, Size: {final_size} bytes, Parameters: {final_params}")


In [None]:
def download_url(url, model_dir='.', overwrite=False):
    import os, sys, ssl
    from urllib.request import urlretrieve
    ssl._create_default_https_context = ssl._create_unverified_context
    target_dir = url.split('/')[-1]
    model_dir = os.path.expanduser(model_dir)
    try:
        if not os.path.exists(model_dir):
            os.makedirs(model_dir)
        model_dir = os.path.join(model_dir, target_dir)
        cached_file = model_dir
        if not os.path.exists(cached_file) or overwrite:
            sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
            urlretrieve(url, cached_file)
        return cached_file
    except Exception as e:
        # remove lock file so download can be executed next time.
        os.remove(os.path.join(model_dir, 'download.lock'))
        sys.stderr.write('Failed to download from url %s' % url + '\n' + str(e) + '\n')
        return None

In [None]:
class VGG(nn.Module):
  ARCH = [64, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M']

  def __init__(self) -> None:
    super().__init__()

    layers = []
    counts = defaultdict(int)

    def add(name: str, layer: nn.Module) -> None:
      layers.append((f"{name}{counts[name]}", layer))
      counts[name] += 1

    in_channels = 3
    for x in self.ARCH:
      if x != 'M':
        # conv-bn-relu
        add("conv", nn.Conv2d(in_channels, x, 3, padding=1, bias=False))
        add("bn", nn.BatchNorm2d(x))
        add("relu", nn.ReLU(True))
        in_channels = x
      else:
        # maxpool
        add("pool", nn.MaxPool2d(2))

    self.backbone = nn.Sequential(OrderedDict(layers))
    self.classifier = nn.Linear(512, 10)

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    # backbone: [N, 3, 32, 32] => [N, 512, 2, 2]
    x = self.backbone(x)

    # avgpool: [N, 512, 2, 2] => [N, 512]
    x = x.mean([2, 3])

    # classifier: [N, 512] => [N, 10]
    x = self.classifier(x)
    return x

In [None]:
def train(
  model: nn.Module,
  dataloader: DataLoader,
  criterion: nn.Module,
  optimizer: Optimizer,
  scheduler: LambdaLR,
  callbacks = None
) -> None:
  model.train()

  for inputs, targets in tqdm(dataloader, desc='train', leave=False):
    # Move the data from CPU to GPU
    inputs = inputs.cuda()
    targets = targets.cuda()

    # Reset the gradients (from the last iteration)
    optimizer.zero_grad()

    # Forward inference
    outputs = model(inputs)
    loss = criterion(outputs, targets)

    # Backward propagation
    loss.backward()

    # Update optimizer and LR scheduler
    optimizer.step()
    scheduler.step()

    if callbacks is not None:
        for callback in callbacks:
            callback()

In [None]:
checkpoint_url = "https://hanlab18.mit.edu/files/course/labs/vgg.cifar.pretrained.pth"
checkpoint = torch.load(download_url(checkpoint_url), map_location="cpu")
model = VGG()
print(f"=> loading checkpoint '{checkpoint_url}'")
model.load_state_dict(checkpoint['state_dict'])
recover_model = lambda: model.load_state_dict(checkpoint['state_dict'])

Downloading: "https://hanlab18.mit.edu/files/course/labs/vgg.cifar.pretrained.pth" to ./vgg.cifar.pretrained.pth


=> loading checkpoint 'https://hanlab18.mit.edu/files/course/labs/vgg.cifar.pretrained.pth'


In [None]:
from torchvision import datasets, transforms

# Define transformations
transform_train = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # Normalizing the images to [-1, 1]
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # Normalizing the images to [-1, 1]
])

# Load datasets
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)

# Define DataLoader
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=2, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=2, pin_memory=True)

# Your existing model initialization and training code follows here


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:11<00:00, 15337622.11it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [None]:
# Assume model, train_loader, and test_loader are predefined
dataloader = {'train': train_loader, 'test': test_loader}
pruner = ModelPruner(model, dataloader)
pruner.prune_and_fine_tune(target_sparsity=0.95, fine_tune_epochs=5)

Initial Model Evaluation:
Accuracy: 89.48%, Size: 36913448 bytes, Parameters: 9228362

Pruning Model...
Model Evaluation After Pruning:
Accuracy: 10.04%, Size: 36913448 bytes, Parameters: 9228362

Fine-tuning Model...
Final Model Evaluation:
Accuracy: 89.39%, Size: 36913448 bytes, Parameters: 9228362
