In [0]:
import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F
import torchvision

In [3]:
vggmodel = torchvision.models.vgg11(pretrained=True)

Downloading: "https://download.pytorch.org/models/vgg11-bbd30ac9.pth" to /root/.cache/torch/checkpoints/vgg11-bbd30ac9.pth


HBox(children=(IntProgress(value=0, max=531456000), HTML(value='')))




**Applying l1 unstructured pruning on convolution and dense layer**

In [0]:
for name,module in vggmodel.named_modules():
  if(isinstance(module,torch.nn.Conv2d)):
    prune.l1_unstructured(module, name='weight', amount=0.2)
  elif(isinstance(module, torch.nn.Linear)):
    prune.l1_unstructured(module, name='weight', amount=0.4)

**Model weight for the first convolution layer before applying pruning**

In [6]:
module.weight

tensor([[[[ 0.2882,  0.0000, -0.3850],
          [ 0.1795,  0.3668, -0.5012],
          [-0.0974,  0.3648, -0.2296]],

         [[ 0.4015, -0.0000, -0.6842],
          [ 0.4442,  0.4478, -0.7949],
          [ 0.1129,  0.4917, -0.3705]],

         [[ 0.2162, -0.0000, -0.3949],
          [ 0.1490,  0.2967, -0.4294],
          [-0.0000,  0.3479, -0.1558]]],


        [[[-0.3528, -0.2549,  0.6734],
          [-0.6027, -0.3453,  0.8054],
          [-0.4197, -0.1347,  0.6622]],

         [[-0.5740, -0.3998,  0.7708],
          [-0.8758, -0.3668,  1.1098],
          [-0.5186, -0.0801,  0.9228]],

         [[-0.0000, -0.2954,  0.2159],
          [-0.1868, -0.2904,  0.3808],
          [-0.0926, -0.0766,  0.3727]]],


        [[[ 0.0000, -0.2592, -0.2389],
          [ 0.3209,  0.2152, -0.2047],
          [-0.0000,  0.2227, -0.0000]],

         [[ 0.0000, -0.4797, -0.5490],
          [ 0.5363,  0.2973, -0.3728],
          [ 0.1776,  0.4145, -0.0000]],

         [[ 0.0000, -0.3427, -0.3938],
     

In [0]:
for name,module in vggmodel.named_modules():
  module._forward_pre_hooks

In [0]:
module = vggmodel.features[0]

**Model weights for first convolutional layer after applying pruning**

We can find more number of zeros in the weights meaning those weights were pruned.

In [0]:
module.weight

tensor([[[[ 0.2882,  0.0000, -0.3850],
          [ 0.1795,  0.3668, -0.5012],
          [-0.0974,  0.3648, -0.2296]],

         [[ 0.4015, -0.0000, -0.6842],
          [ 0.4442,  0.4478, -0.7949],
          [ 0.1129,  0.4917, -0.3705]],

         [[ 0.2162, -0.0000, -0.3949],
          [ 0.1490,  0.2967, -0.4294],
          [-0.0000,  0.3479, -0.1558]]],


        [[[-0.3528, -0.2549,  0.6734],
          [-0.6027, -0.3453,  0.8054],
          [-0.4197, -0.1347,  0.6622]],

         [[-0.5740, -0.3998,  0.7708],
          [-0.8758, -0.3668,  1.1098],
          [-0.5186, -0.0801,  0.9228]],

         [[-0.0000, -0.2954,  0.2159],
          [-0.1868, -0.2904,  0.3808],
          [-0.0926, -0.0766,  0.3727]]],


        [[[ 0.0000, -0.2592, -0.2389],
          [ 0.3209,  0.2152, -0.2047],
          [-0.0000,  0.2227, -0.0000]],

         [[ 0.0000, -0.4797, -0.5490],
          [ 0.5363,  0.2973, -0.3728],
          [ 0.1776,  0.4145, -0.0000]],

         [[ 0.0000, -0.3427, -0.3938],
     

In [8]:
list(module.named_parameters())

[('bias', Parameter containing:
  tensor([ 0.1939,  0.3042,  0.1825, -1.1122,  0.0442, -0.0678,  0.1324, -0.5846,
           0.2210, -0.0130, -0.3794,  0.1256,  0.2415,  0.2491, -0.7849, -1.0575,
           0.2637, -0.1838,  0.1533, -1.1485,  0.0095, -0.8640, -0.3903, -0.4040,
           0.5775,  0.2528,  0.0911,  0.1554, -0.1833, -0.4766, -0.2352, -1.1038,
          -0.5924,  0.2940,  0.1892,  0.4142, -0.0220,  0.1800, -0.5532,  0.2365,
           0.1548, -1.7165, -0.0318,  0.0494,  0.1286,  0.1860,  0.1789,  0.4679,
          -0.0066, -0.0948, -1.4731, -0.7618, -1.2010,  0.1765,  0.2015, -0.0822,
           0.1453,  0.0289, -1.2024,  0.1595, -0.8845, -0.0075,  0.2292, -1.3837],
         requires_grad=True)), ('weight_orig', Parameter containing:
  tensor([[[[ 0.2882,  0.0358, -0.3850],
            [ 0.1795,  0.3668, -0.5012],
            [-0.0974,  0.3648, -0.2296]],
  
           [[ 0.4015, -0.0461, -0.6842],
            [ 0.4442,  0.4478, -0.7949],
            [ 0.1129,  0.4917, -0

In [9]:
list(module.named_buffers())

[('weight_mask', tensor([[[[1., 0., 1.],
            [1., 1., 1.],
            [1., 1., 1.]],
  
           [[1., 0., 1.],
            [1., 1., 1.],
            [1., 1., 1.]],
  
           [[1., 0., 1.],
            [1., 1., 1.],
            [0., 1., 1.]]],
  
  
          [[[1., 1., 1.],
            [1., 1., 1.],
            [1., 1., 1.]],
  
           [[1., 1., 1.],
            [1., 1., 1.],
            [1., 1., 1.]],
  
           [[0., 1., 1.],
            [1., 1., 1.],
            [1., 1., 1.]]],
  
  
          [[[0., 1., 1.],
            [1., 1., 1.],
            [0., 1., 0.]],
  
           [[0., 1., 1.],
            [1., 1., 1.],
            [1., 1., 0.]],
  
           [[0., 1., 1.],
            [1., 1., 1.],
            [1., 1., 0.]]],
  
  
          ...,
  
  
          [[[1., 1., 1.],
            [1., 1., 1.],
            [1., 1., 1.]],
  
           [[1., 1., 1.],
            [1., 1., 1.],
            [1., 1., 1.]],
  
           [[1., 1., 0.],
            [1., 1., 1.]

In [10]:
print(
    "Sparsity in conv1.weight: {:.2f}%".format(
        100. * float(torch.sum(vggmodel.features[0].weight == 0))
        / float(vggmodel.features[0].weight.nelement())
    )
)

Sparsity in conv1.weight: 20.02%


In [27]:
batch_size_train = 64
batch_size_test = 1000
train_loader = torch.utils.data.DataLoader(
  torchvision.datasets.CIFAR10('/content/CIFARtrain/', train=True, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ])),
  batch_size=batch_size_train, shuffle=True)

test_loader = torch.utils.data.DataLoader(
  torchvision.datasets.CIFAR10('/content/CIFARtest/', train=False, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ])),
  batch_size=batch_size_test, shuffle=True)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to /content/CIFARtrain/cifar-10-python.tar.gz


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Extracting /content/CIFARtrain/cifar-10-python.tar.gz to /content/CIFARtrain/
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to /content/CIFARtest/cifar-10-python.tar.gz


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Extracting /content/CIFARtest/cifar-10-python.tar.gz to /content/CIFARtest/
