# Instructions

Run this notebook to:
* Load a VGG-16 model pretrained on the CIFAR-10 dataset from the "pretrainedmodel" folder.
* Apply "Filter Pruning via Geometric Median" (FPGM) to this pretrained model. The model is fine-tuned for 40 epochs, with pruning applied iteratively. At this stage the pruned parameters are only zeroed out, not removed. The resulting model is saved as "vgg_cifar10_pruned_net.pth" in the current working directory.
* Finally, modify the architecture to physically remove the zeroed-out filters, and save the final pruned model as "vgg_cifar10_arch_pruned_net.pth" in the current working directory.

# Selecting device

In [1]:
import torch 
import torch.nn as nn

# Prefer the GPU whenever CUDA is usable; otherwise run everything on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [2]:
! CUDA_VISIBLE_DEVICES=0
#! python ./fpgmdata/testing/pruning_cifar_vgg.py  ./fpgmdata/testing/data/cifar.python --dataset cifar10 --arch vgg --save_path ./logs/vgg_prune_precfg_varience4 --rate_norm 1 --rate_dist 0.2
! python  ./fpgmdata/testing/pruning_cifar_vgg.py  ./fpgmdata/testing/data/cifar.python --dataset cifar10 --arch vgg --save_path ./logs/vgg_pretrain/prune_precfg_epoch40_varience1 --rate_norm 1 --rate_dist 0.2 --use_pretrain --pretrain_path ./pretrainedmodel/checkpoint.pth.tar --use_state_dict --lr 0.001 --epochs 40 --use_precfg

save path : ./logs/vgg_pretrain/prune_precfg_epoch40_varience1
{'data_path': '/kaggle/input/fpgmdata/testing/data/cifar.python', 'dataset': 'cifar10', 'batch_size': 64, 'test_batch_size': 256, 'epochs': 40, 'start_epoch': 0, 'lr': 0.001, 'momentum': 0.9, 'weight_decay': 0.0001, 'resume': '', 'no_cuda': False, 'seed': 1, 'log_interval': 100, 'save_path': './logs/vgg_pretrain/prune_precfg_epoch40_varience1', 'arch': 'vgg', 'depth': 16, 'rate_norm': 1.0, 'rate_dist': 0.2, 'layer_begin': 1, 'layer_end': 1, 'layer_inter': 1, 'epoch_prune': 1, 'dist_type': 'l2', 'use_state_dict': True, 'use_pretrain': True, 'pretrain_path': '/kaggle/input/pretrainedmodel/checkpoint.pth.tar', 'use_precfg': True, 'evaluate': False, 'cuda': True}
Random Seed: 1
python version : 3.10.10 | packaged by conda-forge | (main, Mar 24 2023, 20:08:06) [GCC 11.3.0]
torch  version : 2.0.0
cudnn  version : 8700
Norm Pruning Rate: 1.0
Distance Pruning Rate: 0.2
Layer Begin: 1
Layer End: 1
Layer Inter: 1
Epoch pru

In [3]:
import torch
import sys
sys.path.append("./fpgmdata/testing")
import models

# Build the reference (unpruned) VGG-16 and restore its pretrained weights.
unpruned_model = models.__dict__['vgg'](dataset="cifar10", depth=16)
checkpoint = torch.load("./pretrainedmodel/checkpoint.pth.tar", map_location=device)
unpruned_model.load_state_dict(checkpoint['state_dict'])
unpruned_model.to(device)

# Report the size of every trainable parameter that belongs to a Conv2d layer.
# Names are relative to the owning layer, so each conv prints as "weight".
print('Trainable parameters:')
total = 0
conv_layers = (m for m in unpruned_model.modules() if isinstance(m, torch.nn.Conv2d))
for conv in conv_layers:
    for name, param in conv.named_parameters():
        if not param.requires_grad:
            continue
        print(name, '\t', param.numel())
        total += param.numel()

print()
print('Total', '\t', total)

Trainable parameters:
weight 	 1728
weight 	 36864
weight 	 73728
weight 	 147456
weight 	 294912
weight 	 589824
weight 	 589824
weight 	 1179648
weight 	 2359296
weight 	 2359296
weight 	 2359296
weight 	 2359296
weight 	 2359296

Total 	 14710464


# General function to test a model

In [4]:
import numpy as np

def test_model(model):
    model.eval()
    starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
    timings = []
    #GPU-WARM-UP
    i=0
    for data in testloader:
        if(i>1000):
            break
        images, labels = data
        images = images.to(device)
        _ = model(images)
        i += 1
    
    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            images = images.to(device)
            labels = labels.to(device)
            
            starter.record()
            outputs = model(images)
            ender.record()
            
            # WAIT FOR GPU SYNC
            torch.cuda.synchronize()
            curr_time = starter.elapsed_time(ender)
            timings.append(curr_time)
            
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print('Accuracy of the network on the 10000 test images: '+str(100 * correct / total))
    
    tot = np.sum(timings)
    mean_syn_per_batch = np.sum(timings) / len(timings)
    std_syn_per_batch = np.std(timings)
    print("Total inference time for test data: "+str(tot))
    print("Mean inference time per test batch: "+str(mean_syn_per_batch))
    print("Standard deviation of inference times per batch: "+str(std_syn_per_batch))
    model.train()

# Loading and normalizing images using TorchVision


In [5]:
import torchvision
import torchvision.transforms as transforms

In [6]:
# Training split: pad-and-crop plus horizontal-flip augmentation, then normalization.
train_transform = transforms.Compose([
    transforms.Pad(4),
    transforms.RandomCrop(32),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, transform=train_transform, download=True
)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=32, shuffle=True, num_workers=2
)

# Test split: no augmentation, same normalization, fixed order.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, transform=test_transform, download=True
)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=32, shuffle=False, num_workers=2
)

# Human-readable class names, listed in label-index order.
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:36<00:00, 4692265.14it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [7]:
import math

# Layer layouts keyed by network depth. An integer is the number of output channels
# of a 3x3 convolution; 'M' inserts a 2x2 max-pooling layer.
defaultcfg = {
    11 : [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512],
    13 : [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512],
    16 : [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512],
    19 : [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512],
}

class vgg(nn.Module):
    """VGG-style network with batch normalization for CIFAR-10 / CIFAR-100.

    Args:
        dataset: 'cifar10' (10 classes) or 'cifar100' (100 classes).
        depth: key into ``defaultcfg``; only used when ``cfg`` is None.
        init_weights: when True, apply ``_initialize_weights`` after construction.
        cfg: optional explicit layer layout (same format as ``defaultcfg`` values).
            Its last entry is used as the classifier's input width.

    Raises:
        ValueError: if ``dataset`` is not one of the supported names.
        KeyError: if ``cfg`` is None and ``depth`` is not a key of ``defaultcfg``.
    """

    # Number of output classes for each supported dataset.
    _NUM_CLASSES = {'cifar10': 10, 'cifar100': 100}

    def __init__(self, dataset='cifar10', depth=16, init_weights=True, cfg=None):
        super(vgg, self).__init__()
        # Validate up front: previously an unknown dataset left ``num_classes``
        # unbound and failed with an unrelated error after the layers were built.
        if dataset not in self._NUM_CLASSES:
            raise ValueError(
                "Unsupported dataset %r; expected one of %s"
                % (dataset, sorted(self._NUM_CLASSES))
            )
        num_classes = self._NUM_CLASSES[dataset]

        if cfg is None:
            cfg = defaultcfg[depth]

        self.cfg = cfg

        self.feature = self.make_layers(cfg, True)

        # The classifier expects exactly cfg[-1] features, i.e. a 1x1 spatial map
        # after ``forward``'s average pooling (assumes 32x32 inputs and the usual
        # four 'M' stages - confirm before using other input sizes).
        self.classifier = nn.Sequential(
            nn.Linear(cfg[-1], 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Linear(512, num_classes)
        )
        if init_weights:
            self._initialize_weights()

    def make_layers(self, cfg, batch_norm=False):
        """Build the convolutional feature extractor described by ``cfg``.

        Every integer entry becomes a bias-free 3x3 convolution (padding 1) followed
        by an optional BatchNorm2d and a ReLU; every 'M' becomes a 2x2 max-pool.

        Returns:
            nn.Sequential holding the layers in ``cfg`` order.
        """
        layers = []
        in_channels = 3  # network input is a 3-channel image
        for v in cfg:
            if v == 'M':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1, bias=False)
                if batch_norm:
                    layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
                else:
                    layers += [conv2d, nn.ReLU(inplace=True)]
                in_channels = v
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return the class scores for a batch of images ``x``."""
        x = self.feature(x)
        # Functional pooling: same result as nn.AvgPool2d(2) without creating a
        # new module on every forward pass.
        x = nn.functional.avg_pool2d(x, 2)
        x = x.view(x.size(0), -1)
        y = self.classifier(x)
        return y

    def _initialize_weights(self):
        """Initialize Conv2d, BatchNorm2d and Linear layers in place.

        Conv weights are drawn from N(0, sqrt(2 / (k_h * k_w * out_channels))),
        BatchNorm2d scales are set to 0.5 with zero shift, and Linear weights are
        drawn from N(0, 0.01) with zero bias. Other module types (including the
        classifier's BatchNorm1d) keep PyTorch's default initialization.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(0.5)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()

# Testing the accuracy of the unpruned model

In [8]:
# Baseline: print accuracy and per-batch inference timing of the unpruned network.
test_model(unpruned_model)

Accuracy of the network on the 10000 test images: 93.59
Total inference time for test data: 1548.8746535778046
Mean inference time per test batch: 4.948481321334839
Standard deviation of inference times per batch: 1.0180026892832421


# Loading the pruned (only zeroed out) model

In [9]:
# Instantiate the notebook's VGG-16 (default cfg) and restore the weights written by
# the pruning run; pruned filters are still present, only zeroed out.
pruned_model = vgg().to(device)

# map_location mirrors the unpruned-model load: without it a checkpoint saved on GPU
# cannot be opened on a CPU-only machine.
pruned_model.load_state_dict(
    torch.load(
        "./logs/vgg_pretrain/prune_precfg_epoch40_varience1/checkpoint.pth.tar",
        map_location=device,
    )['state_dict']
)

<All keys matched successfully>

# Saving the pruned (only zeroed out) model

In [10]:
# Serialize the whole module object (architecture + weights), not just its state_dict.
torch.save(obj=pruned_model, f='./vgg_cifar10_pruned_net.pth')

# Let's test the accuracy of the pruned (only zeroed out) model

In [11]:
# Print accuracy and timing of the zero-masked model (architecture still unchanged).
test_model(pruned_model)

Accuracy of the network on the 10000 test images: 93.51
Total inference time for test data: 1762.172775030136
Mean inference time per test batch: 5.6299449681474
Standard deviation of inference times per batch: 2.086842221757349


# Changing the architecture

In [12]:
!pip install torch-pruning
import torch_pruning as tp
    
for name, module in pruned_model.named_modules():
    if isinstance(module, torch.nn.Conv2d): #Iterating over all the conv2d layers of the model
        channel_indices = [] #Stores indices of the channels to prune within this conv layer
        t = module.weight.clone().detach()
        t = t.reshape(t.shape[0], -1)
        z = torch.all(t == 0, dim=1)
        z = z.tolist()
        
        for i, flag in enumerate(z):
            if(flag):
                channel_indices.append(i)

        if(channel_indices == []):
            continue
        
        # 1. build dependency graph for vgg
        DG = tp.DependencyGraph().build_dependency(pruned_model, example_inputs=torch.randn(1,3,32,32).to(device))

        # 2. Specify the to-be-pruned channels. Here we prune those channels indexed by idxs.
        group = DG.get_pruning_group(module, tp.prune_conv_out_channels, idxs=channel_indices)
        #print(group)

        # 3. prune all grouped layers that are coupled with the conv layer (included).
        if DG.check_pruning_group(group): # avoid full pruning, i.e., channels=0.
            group.prune()
    
# 4. Save & Load
pruned_model.zero_grad() # We don't want to store gradient information
torch.save(pruned_model, './vgg_cifar10_arch_pruned_net.pth') # without .state_dict

Collecting torch-pruning
  Downloading torch_pruning-1.1.9-py3-none-any.whl (39 kB)
Installing collected packages: torch-pruning
Successfully installed torch-pruning-1.1.9
[0m

# Let's test the accuracy of the pruned model after the architecture modifications

In [13]:
# Print accuracy and timing again, now that the zeroed filters were physically removed.
test_model(pruned_model)

Accuracy of the network on the 10000 test images: 93.11
Total inference time for test data: 1473.8917746543884
Mean inference time per test batch: 4.708919407841496
Standard deviation of inference times per batch: 1.8891718471936976


# Arch pruned model reload check

In [14]:
# The file holds the whole pickled nn.Module (not a state_dict), so full unpickling
# has to be allowed explicitly on PyTorch versions where weights_only defaults to True.
# Unpickling executes arbitrary code: only do this for files you created yourself.
# map_location lets a model saved on GPU be reloaded on a CPU-only machine.
reloaded_model = torch.load(
    './vgg_cifar10_arch_pruned_net.pth',
    map_location=device,
    weights_only=False,
)

In [15]:
# Sanity check: print accuracy and timing of the model reloaded from disk.
test_model(reloaded_model)

Accuracy of the network on the 10000 test images: 93.11
Total inference time for test data: 1664.7044486999512
Mean inference time per test batch: 5.318544564536586
Standard deviation of inference times per batch: 2.201446803791996
