# Import Library

In [2]:
!pip install ipython-autotime

Collecting ipython-autotime
  Downloading ipython_autotime-0.3.1-py2.py3-none-any.whl (6.8 kB)
Installing collected packages: ipython-autotime
Successfully installed ipython-autotime-0.3.1


In [3]:
!pip install torch_pruning

Collecting torch_pruning
  Downloading torch_pruning-0.2.7-py3-none-any.whl (14 kB)
Installing collected packages: torch-pruning
Successfully installed torch-pruning-0.2.7


In [4]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import models,transforms,datasets
import torch.nn.utils.prune as prune
from copy import deepcopy
import torch_pruning as tp

%load_ext autotime

time: 107 µs (started: 2021-09-20 07:38:15 +00:00)


In [5]:
# Number of target classes (the CIFAR-100 dataset is loaded further down).
num_classes = 100
# Mini-batch size shared by the training and test DataLoaders.
batch_size = 64

time: 984 µs (started: 2021-09-20 07:38:15 +00:00)


In [6]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
# fit(), evaluate() and the pruning helpers read this module-level name directly.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda
time: 71.9 ms (started: 2021-09-20 07:38:15 +00:00)


# Important Functions

In [7]:
def get_total_parameters(model):
    """Return the number of trainable scalar parameters in ``model``.

    Only parameters whose ``requires_grad`` flag is set are counted.
    """
    total = 0
    for tensor in model.parameters():
        if tensor.requires_grad:
            total += tensor.numel()
    return total


def get_pruned_parameters(pruned_model):
    """Return the number of non-zero scalar entries across all parameters of ``pruned_model``.

    Despite the name, this counts the entries that *survive* pruning (the
    non-zero ones), not the entries that were zeroed.

    Args:
        pruned_model: any ``torch.nn.Module``.

    Returns:
        int: total count of non-zero elements over ``pruned_model.parameters()``;
        0 for a module without parameters.
    """
    # count_nonzero avoids building the index tensor that torch.nonzero would allocate per parameter.
    return sum(int(torch.count_nonzero(param)) for param in pruned_model.parameters())

time: 4.15 ms (started: 2021-09-20 07:38:16 +00:00)


In [8]:
def measure_module_sparsity(module, weight=True, bias=False, use_mask=False):
    """Count zero-valued entries in a module's weights and/or biases.

    Args:
        module: the ``torch.nn.Module`` to inspect.
        weight: include tensors whose name contains the weight key.
        bias: include tensors whose name contains the bias key.
        use_mask: when true, inspect the module's buffers and look for
            ``"weight_mask"`` / ``"bias_mask"``; otherwise inspect its
            parameters and look for ``"weight"`` / ``"bias"``.

    Returns:
        tuple: ``(num_zeros, num_elements, sparsity)`` where ``sparsity`` is
        ``num_zeros / num_elements``, or ``0.0`` when no tensor was selected
        (for example ``use_mask=True`` on a module that has no masks).
    """
    if use_mask:
        named_tensors = module.named_buffers()
        weight_key, bias_key = "weight_mask", "bias_mask"
    else:
        named_tensors = module.named_parameters()
        weight_key, bias_key = "weight", "bias"

    num_zeros = 0
    num_elements = 0
    for tensor_name, tensor in named_tensors:
        # Names are matched by substring, and the two checks are independent on purpose:
        # a tensor is counted once per key it matches.
        if weight and weight_key in tensor_name:
            num_zeros += torch.sum(tensor == 0).item()
            num_elements += tensor.nelement()
        if bias and bias_key in tensor_name:
            num_zeros += torch.sum(tensor == 0).item()
            num_elements += tensor.nelement()

    # Nothing selected means nothing is sparse; avoid dividing by zero.
    sparsity = num_zeros / num_elements if num_elements else 0.0

    return num_zeros, num_elements, sparsity

time: 23.6 ms (started: 2021-09-20 07:38:16 +00:00)


In [9]:
def measure_global_sparsity(model,
                            weight=True,
                            bias=False,
                            conv2d_use_mask=False,
                            linear_use_mask=False):
    """Aggregate zero counts over every Conv2d and Linear layer of ``model``.

    Args:
        model: the ``torch.nn.Module`` to inspect; all sub-modules are visited.
        weight: forwarded to ``measure_module_sparsity``.
        bias: forwarded to ``measure_module_sparsity``.
        conv2d_use_mask: ``use_mask`` value used for ``torch.nn.Conv2d`` layers.
        linear_use_mask: ``use_mask`` value used for ``torch.nn.Linear`` layers.

    Returns:
        tuple: ``(num_zeros, num_elements, sparsity)`` summed over the visited
        layers; ``sparsity`` is ``0.0`` when no elements were counted (for
        example a model without Conv2d/Linear layers).
    """
    num_zeros = 0
    num_elements = 0

    for _, module in model.named_modules():
        # Only Conv2d and Linear layers contribute; each type has its own mask switch.
        if isinstance(module, torch.nn.Conv2d):
            use_mask = conv2d_use_mask
        elif isinstance(module, torch.nn.Linear):
            use_mask = linear_use_mask
        else:
            continue

        module_num_zeros, module_num_elements, _ = measure_module_sparsity(
            module, weight=weight, bias=bias, use_mask=use_mask)
        num_zeros += module_num_zeros
        num_elements += module_num_elements

    # Guard the ratio so an empty selection reports 0.0 instead of raising ZeroDivisionError.
    sparsity = num_zeros / num_elements if num_elements else 0.0

    return num_zeros, num_elements, sparsity

time: 15.5 ms (started: 2021-09-20 07:38:16 +00:00)


In [10]:
def remove_parameters(model):
    """Make pruning permanent on every Conv2d and Linear layer of ``model``.

    Calls ``prune.remove`` for the ``"weight"`` and ``"bias"`` of each such
    layer. Layers (or parameters) that were never pruned are skipped.

    Args:
        model: the ``torch.nn.Module`` to process; it is modified in place.

    Returns:
        The same ``model`` object.
    """
    for module in model.modules():
        if not isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
            continue
        for param_name in ("weight", "bias"):
            try:
                prune.remove(module, param_name)
            except ValueError:
                # prune.remove raises ValueError when `param_name` was not pruned on this
                # module; there is nothing to make permanent, so move on. Any other
                # exception is a real failure and is allowed to propagate.
                continue
    return model

time: 17.7 ms (started: 2021-09-20 07:38:16 +00:00)


In [11]:
def fit(model, dataset, optimizer, criterion):
  """Train ``model`` for one pass over the batches yielded by ``dataset``.

  Each batch is moved to the module-level ``device``, then a standard
  zero-grad / forward / loss / backward / optimizer-step cycle is run.
  Nothing is returned; ``model`` and ``optimizer`` are updated in place.
  """
  model.train()
  for inputs, labels in dataset:
    inputs, labels = inputs.to(device), labels.to(device)
    # Clear gradients left over from the previous batch
    optimizer.zero_grad()
    # Forward pass and loss for this batch
    batch_loss = criterion(model(inputs), labels)
    # Backward pass, then apply the update
    batch_loss.backward()
    optimizer.step()

time: 14.8 ms (started: 2021-09-20 07:38:16 +00:00)


In [12]:
def evaluate(model, dataloader):
  """Compute accuracy and mean batch loss of ``model`` over ``dataloader``.

  Relies on the module-level names ``device`` and ``criterion``, which must be
  defined before this function is called.

  Args:
      model: the network to evaluate; it is switched to eval mode.
      dataloader: iterable of ``(data, targets)`` batches exposing a
          ``dataset`` attribute whose length is the number of samples.

  Returns:
      tuple: ``(accuracy_percent, mean_loss)`` where accuracy is
      ``100 * correct / len(dataloader.dataset)`` and ``mean_loss`` is the
      unweighted mean of the per-batch losses.
  """
  model.eval()
  correct = 0
  batch_losses = []
  # Inference only: disabling autograd avoids building a graph for every batch.
  with torch.no_grad():
    for data, targets in dataloader:
      data = data.to(device)
      targets = targets.to(device)
      out = model(data)
      # Loss for this batch
      batch_losses.append(criterion(out, targets).item())
      # Predicted class = index of the largest logit
      _, preds = torch.max(out, 1)
      # Running count of correct predictions
      correct += torch.sum(preds == targets).item()

  return 100*correct/len(dataloader.dataset), np.mean(np.array(batch_losses))

time: 10.3 ms (started: 2021-09-20 07:38:16 +00:00)


In [13]:
def unstructured_global_pruning(model, amount):
  """Return a globally L1-pruned copy of ``model``.

  A deep copy of ``model`` is moved to the module-level ``device``; the
  ``weight`` of every Conv2d and Linear layer in the copy is handed to
  ``prune.global_unstructured`` with ``prune.L1Unstructured`` and ``amount``.
  The pruning is then made permanent through ``remove_parameters``.
  The input ``model`` itself is not modified.
  """
  pruning_model = deepcopy(model).to(device)

  prunable_types = (torch.nn.Conv2d, torch.nn.Linear)
  parameters_to_prune = [
      (layer, "weight")
      for layer in pruning_model.modules()
      if isinstance(layer, prunable_types)
  ]

  prune.global_unstructured(
      parameters = parameters_to_prune,
      pruning_method = prune.L1Unstructured,
      amount = amount)

  return remove_parameters(model = pruning_model)

time: 26.7 ms (started: 2021-09-20 07:38:16 +00:00)


In [14]:
def structured_pruning(model, sparsity_layer, input_shape=(1, 3, 64, 64)):
  """Return a structurally pruned copy of ``model`` using torch_pruning.

  Every Conv2d and Linear layer of the copy is pruned in module order, with
  indices chosen by ``tp.strategy.L1Strategy`` from the layer's weight.

  Args:
      model: the network to prune; it is deep-copied and left untouched.
      sparsity_layer: indexable of per-layer pruning amounts; entry ``i`` is
          used for the ``i``-th Conv2d/Linear layer of the model.
      input_shape: shape of the random example input used to build the
          dependency graph. Defaults to the 64x64 RGB input this notebook uses.

  Returns:
      The pruned copy of ``model``.

  Raises:
      IndexError: if ``sparsity_layer`` has fewer entries than there are
          Conv2d/Linear layers.

  NOTE(review): the last Linear layer is pruned like any other, which
  presumably removes output logits — confirm this is intended.
  """
  prune_model = deepcopy(model)
  prune_model = prune_model.to(device)

  prunable_module_type = (nn.Conv2d, nn.Linear)
  prunable_modules = [m for m in prune_model.modules() if isinstance(m, prunable_module_type)]

  # Fail before any pruning work instead of part-way through the loop.
  if len(sparsity_layer) < len(prunable_modules):
    raise IndexError(
        f"sparsity_layer has {len(sparsity_layer)} entries but the model has "
        f"{len(prunable_modules)} Conv2d/Linear layers")

  strategy = tp.strategy.L1Strategy()
  DG = tp.DependencyGraph().build_dependency(prune_model, example_inputs = torch.randn(*input_shape))

  for i, layer_to_prune in enumerate(prunable_modules):
    # prunable_modules only holds Conv2d/Linear, so anything that is not a Conv2d is a Linear.
    if isinstance(layer_to_prune, nn.Conv2d):
      prune_fn = tp.prune_conv
    else:
      prune_fn = tp.prune_linear

    pruning_idxs = strategy(layer_to_prune.weight, amount = sparsity_layer[i])
    plan = DG.get_pruning_plan(layer_to_prune, prune_fn, pruning_idxs)

    # Executing the plan modifies prune_model in place.
    plan.exec()

  return prune_model

time: 25.5 ms (started: 2021-09-20 07:38:16 +00:00)


# Import Data

In [15]:
# Training-time preprocessing: resize to 64 px, apply random horizontal flip and rotation (up to 15 degrees)
# as augmentation, convert to a tensor, then normalise each channel with the fixed mean/std below.
train_transform = transforms.Compose([transforms.Resize(64),
                                   transforms.RandomHorizontalFlip(),
                                   transforms.RandomRotation(15),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.48,0.4593,0.4155),(0.2774,0.2794,0.2794))])

# CIFAR-100 training split; downloaded into the local "CIFAR100" directory when not already present.
train_set = datasets.CIFAR100(root = "CIFAR100", train = True, download = True, transform = train_transform)

Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to CIFAR100/cifar-100-python.tar.gz


  0%|          | 0/169001437 [00:00<?, ?it/s]

Extracting CIFAR100/cifar-100-python.tar.gz to CIFAR100
time: 6.9 s (started: 2021-09-20 07:38:16 +00:00)


In [16]:
# Test-time preprocessing: same resize and normalisation constants as training, but no random augmentation.
test_transform = transforms.Compose([transforms.Resize(64),
                                  transforms.ToTensor(),
                                  transforms.Normalize((0.48,0.4593,0.4155),(0.2774,0.2794,0.2794))])

# CIFAR-100 test split, read from the same local "CIFAR100" directory.
test_set = datasets.CIFAR100(root = "CIFAR100", train = False, download = True, transform = test_transform)

Files already downloaded and verified
time: 932 ms (started: 2021-09-20 07:38:23 +00:00)


In [17]:
# Training batches are reshuffled every epoch; the test loader keeps the dataset order.
train_dataloader = DataLoader(train_set, batch_size = batch_size, shuffle = True)
test_dataloader = DataLoader(test_set, batch_size = batch_size)

time: 2.84 ms (started: 2021-09-20 07:38:24 +00:00)


In [18]:
# Classification loss. evaluate() reads this module-level name directly (it is not passed as an
# argument), so this cell must run before evaluate() is called.
criterion = nn.CrossEntropyLoss()

time: 1.54 ms (started: 2021-09-20 07:38:24 +00:00)


# Import Model to be Pruned

In [None]:
# Build a torchvision VGG-16 without pretrained weights; a checkpoint is loaded into it in the next cell.
vgg16_model = models.vgg16(pretrained = False)
# NOTE(review): assigning `out_features` only changes the attribute on the module; it does not
# resize the layer's weight. The per-layer sparsity report further down lists 4,096,000 weights
# (4096 x 1000) for the last Linear layer, so the head still produces 1000 logits — confirm intended.
vgg16_model.classifier[6].out_features = num_classes

time: 2.48 s (started: 2021-09-15 06:50:08 +00:00)


In [None]:
# map_location='cpu' lets the checkpoint be deserialised on a machine without the device it was
# saved from; load_state_dict then copies the tensors into the model's existing parameters.
# NOTE(review): the path is an absolute Google Drive location — it must be mounted for this cell to run.
vgg16_model.load_state_dict(torch.load('/content/drive/MyDrive/Sai/VGG16(SGD-0.003)', map_location = 'cpu'))

<All keys matched successfully>

time: 16.4 s (started: 2021-09-15 06:50:10 +00:00)


# Unstructured Pruning

In [None]:
# Global unstructured L1 pruning with amount=0.8 over all Conv2d/Linear weights.
# The result is a pruned copy; vgg16_model itself is left unchanged.
ustp_model = unstructured_global_pruning(vgg16_model, 0.8)

time: 9.94 s (started: 2021-09-15 06:50:27 +00:00)


In [None]:
# Test-set accuracy (percent) and mean batch loss right after unstructured pruning, before any fine-tuning.
acc, loss = evaluate(ustp_model, test_dataloader)
print('Accuracy : ',acc , '\tLoss : ', loss)

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


Accuracy :  73.58 	Loss :  1.1833626514027833
time: 17.5 s (started: 2021-09-15 06:50:37 +00:00)


In [None]:
# Fraction of zero-valued weights over all Conv2d/Linear layers. Masks are not used because
# unstructured_global_pruning() already made the pruning permanent via remove_parameters().
num_zeros, num_elements, sparsity = measure_global_sparsity(
            ustp_model,
            weight=True,
            bias=False,
            conv2d_use_mask=False,
            linear_use_mask=False)

print("Global Sparsity:")
print("{:.2f}".format(sparsity))

Global Sparsity:
0.80
time: 35.5 ms (started: 2021-09-15 06:50:54 +00:00)


In [None]:
# Collect the weight sparsity of each Conv2d/Linear layer, in module order.
# The resulting list is later used as the per-layer amounts for structured pruning.
sparsity_layer = []
i = 0  # running index over prunable layers only (other modules are skipped)
for module_name, module in ustp_model.named_modules():
  if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear):
    # Note: this rebinds num_zeros / num_elements / sparsity from the global report above.
    num_zeros, num_elements, sparsity = measure_module_sparsity(module, weight=True, bias=False, use_mask=False)
    print('Layer', i , num_zeros, num_elements, sparsity)
    sparsity_layer.append(sparsity)
    i+= 1

Layer 0 66 1728 0.03819444444444445
Layer 1 6032 36864 0.1636284722222222
Layer 2 12471 73728 0.16914876302083334
Layer 3 29062 147456 0.1970893012152778
Layer 4 67727 294912 0.2296515570746528
Layer 5 163749 589824 0.2776234944661458
Layer 6 159220 589824 0.2699449327256944
Layer 7 373172 1179648 0.3163418240017361
Layer 8 928557 2359296 0.3935737609863281
Layer 9 935664 2359296 0.3965861002604167
Layer 10 899004 2359296 0.3810475667317708
Layer 11 907191 2359296 0.3845176696777344
Layer 12 996302 2359296 0.42228783501519096
Layer 13 93584979 102760448 0.9107101109563088
Layer 14 10160493 16777216 0.605612576007843
Layer 15 1451613 4096000 0.354397705078125
time: 55.9 ms (started: 2021-09-15 06:50:54 +00:00)


In [None]:
# Round each measured sparsity to 2 decimals and back it off by 0.01, so the structured step
# prunes slightly less than the share of weights that unstructured pruning zeroed.
# The result is clamped at 0.0: a layer with sparsity below 0.005 would otherwise get -0.01,
# which is not a meaningful pruning amount.
# NOTE: this rebinds `sparsity_layer`, so re-running this cell subtracts another 0.01 —
# re-run the cell that measures the per-layer sparsities first.
sparsity_layer = [ max(0.0, round(round(s,2)-0.01,2)) for s in sparsity_layer]
sparsity_layer

[0.03,
 0.15,
 0.16,
 0.19,
 0.22,
 0.27,
 0.26,
 0.31,
 0.38,
 0.39,
 0.37,
 0.37,
 0.41,
 0.9,
 0.6,
 0.34]

time: 9.25 ms (started: 2021-09-15 06:50:54 +00:00)


# Structured Pruning

In [None]:
# Structurally prune a copy of the unstructured-pruned model, using the per-layer amounts computed above.
stp_model = structured_pruning(ustp_model, sparsity_layer)

time: 1.49 s (started: 2021-09-15 06:51:13 +00:00)


In [None]:
# Parameter count of the structurally pruned model, as reported by torch_pruning.
tp.utils.count_params(stp_model)

13203121

time: 6.67 ms (started: 2021-09-15 06:51:14 +00:00)


In [None]:
# NOTE(review): this assignment only changes the `out_features` attribute (and therefore the
# printed repr); it does not resize the layer's weight tensor — confirm the real output width.
stp_model.classifier[6].out_features = 100
# Print the architecture to inspect the reduced channel counts of each layer.
print(stp_model)

VGG(
  (features): Sequential(
    (0): Conv2d(3, 63, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(63, 55, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(55, 108, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(108, 104, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(104, 200, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(200, 187, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(187, 190, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1

In [None]:
# Make sure the pruned model is on the evaluation device, then measure test accuracy/loss
# before fine-tuning (the baseline that fine-tuning has to recover from).
stp_model = stp_model.to(device)
acc, loss = evaluate(stp_model, test_dataloader)
print('Accuracy : ',acc , '\tLoss : ', loss)

Accuracy :  8.92 	Loss :  4.479058964237286
time: 13 s (started: 2021-09-15 06:51:14 +00:00)


# Fine Tuning

In [None]:
# Fine-tuning stage 1: 10 epochs of SGD (lr 1e-3, momentum 0.9) on the structurally pruned model.
sgd_optimizer = torch.optim.SGD(stp_model.parameters(), lr = 0.001, momentum = 0.9)
epochs = 10

for epoch in range(epochs):
  # One training pass over the training loader
  fit(stp_model, train_dataloader, sgd_optimizer, criterion)

  # Accuracy/loss on the training loader (its random augmentation is still applied)
  trn_acc, trn_lss = evaluate(stp_model, train_dataloader)

  # Accuracy/loss on the test loader
  tst_acc, tst_lss = evaluate(stp_model, test_dataloader)

  print(f'Epoch:{epoch+1:2.0f}\t Train_Loss:{trn_lss:.4f}\t Train_Acc:{trn_acc:.4f}\t Test_Loss:{tst_lss:.4f}\t Test_Acc:{tst_acc:.4f}')

Epoch: 1	 Train_Loss:1.0196	 Train_Acc:70.4460	 Test_Loss:1.3063	 Test_Acc:62.6700
Epoch: 2	 Train_Loss:0.7246	 Train_Acc:78.4480	 Test_Loss:1.1228	 Test_Acc:68.8400
Epoch: 3	 Train_Loss:0.6333	 Train_Acc:80.7040	 Test_Loss:1.1397	 Test_Acc:69.1800
Epoch: 4	 Train_Loss:0.5470	 Train_Acc:83.6740	 Test_Loss:1.0626	 Test_Acc:70.0100
Epoch: 5	 Train_Loss:0.4999	 Train_Acc:84.7720	 Test_Loss:1.1162	 Test_Acc:70.3500
Epoch: 6	 Train_Loss:0.4324	 Train_Acc:86.7560	 Test_Loss:1.0917	 Test_Acc:71.2100
Epoch: 7	 Train_Loss:0.4252	 Train_Acc:86.8800	 Test_Loss:1.1379	 Test_Acc:71.1900
Epoch: 8	 Train_Loss:0.3794	 Train_Acc:88.1320	 Test_Loss:1.1187	 Test_Acc:71.6400
Epoch: 9	 Train_Loss:0.3693	 Train_Acc:88.5340	 Test_Loss:1.1188	 Test_Acc:71.6200
Epoch:10	 Train_Loss:0.3161	 Train_Acc:90.2100	 Test_Loss:1.1320	 Test_Acc:71.8000
time: 35min 29s (started: 2021-09-15 06:51:37 +00:00)


In [None]:
# Fine-tuning stage 2: 5 more epochs with the learning rate lowered to 1e-4.
# A new optimizer is created, so SGD momentum state starts from scratch.
sgd_optimizer = torch.optim.SGD(stp_model.parameters(), lr = 0.0001, momentum = 0.9)
epochs = 5

for epoch in range(epochs):
  # One training pass over the training loader
  fit(stp_model, train_dataloader, sgd_optimizer, criterion)

  # Accuracy/loss on the training loader (its random augmentation is still applied)
  trn_acc, trn_lss = evaluate(stp_model, train_dataloader)

  # Accuracy/loss on the test loader
  tst_acc, tst_lss = evaluate(stp_model, test_dataloader)

  # The +11 offset continues the epoch numbering after the 10 epochs of the first stage
  print(f'Epoch:{epoch+11:2.0f}\t Train_Loss:{trn_lss:.4f}\t Train_Acc:{trn_acc:.4f}\t Test_Loss:{tst_lss:.4f}\t Test_Acc:{tst_acc:.4f}')

Epoch:11	 Train_Loss:0.2258	 Train_Acc:93.0340	 Test_Loss:1.1073	 Test_Acc:73.6600
Epoch:12	 Train_Loss:0.2108	 Train_Acc:93.3180	 Test_Loss:1.1376	 Test_Acc:73.8000
Epoch:13	 Train_Loss:0.2001	 Train_Acc:93.7160	 Test_Loss:1.1403	 Test_Acc:74.0100
Epoch:14	 Train_Loss:0.1956	 Train_Acc:93.8680	 Test_Loss:1.1514	 Test_Acc:73.7400
Epoch:15	 Train_Loss:0.1886	 Train_Acc:94.0800	 Test_Loss:1.1544	 Test_Acc:74.0200
time: 17min 42s (started: 2021-09-15 07:27:15 +00:00)


In [20]:
# Fine-tuning stage 3: 5 more epochs with the learning rate lowered to 1e-5.
# NOTE(review): the recorded execution counter and timestamps show this cell ran in a later
# session than the two fine-tuning cells above; how stp_model was restored in that session
# is not visible in this notebook — confirm before relying on Restart & Run All.
sgd_optimizer = torch.optim.SGD(stp_model.parameters(), lr = 0.00001, momentum = 0.9)
epochs = 5

for epoch in range(epochs):
  # One training pass over the training loader
  fit(stp_model, train_dataloader, sgd_optimizer, criterion)

  # Accuracy/loss on the training loader (its random augmentation is still applied)
  trn_acc, trn_lss = evaluate(stp_model, train_dataloader)

  # Accuracy/loss on the test loader
  tst_acc, tst_lss = evaluate(stp_model, test_dataloader)

  # The +16 offset continues the epoch numbering after the 15 epochs of the earlier stages
  print(f'Epoch:{epoch+16:2.0f}\t Train_Loss:{trn_lss:.4f}\t Train_Acc:{trn_acc:.4f}\t Test_Loss:{tst_lss:.4f}\t Test_Acc:{tst_acc:.4f}')

Epoch:16	 Train_Loss:0.1868	 Train_Acc:94.1060	 Test_Loss:1.1579	 Test_Acc:73.9900
Epoch:17	 Train_Loss:0.1825	 Train_Acc:94.3440	 Test_Loss:1.1608	 Test_Acc:74.0500
Epoch:18	 Train_Loss:0.1833	 Train_Acc:94.3000	 Test_Loss:1.1608	 Test_Acc:74.0200
Epoch:19	 Train_Loss:0.1829	 Train_Acc:94.2280	 Test_Loss:1.1626	 Test_Acc:74.0200
Epoch:20	 Train_Loss:0.1803	 Train_Acc:94.3280	 Test_Loss:1.1633	 Test_Acc:74.2100
time: 17min 24s (started: 2021-09-20 07:38:42 +00:00)


# Testing

In [21]:
# Final test-set accuracy (percent) and mean batch loss of the fine-tuned pruned model.
acc, loss = evaluate(stp_model, test_dataloader)
print('Accuracy : ',acc , '\tLoss : ', loss)

Accuracy :  74.21 	Loss :  1.1633469540222434
time: 12.9 s (started: 2021-09-20 07:56:47 +00:00)


# Save Pruned Model

In [22]:
# Save the whole module object rather than only a state_dict: structured pruning changed the
# layer shapes, so the weights no longer fit a freshly constructed torchvision VGG-16.
# Loading this file unpickles arbitrary objects — only load it from a trusted source.
torch.save(stp_model, 'PrunedModel_V1.pt')

time: 125 ms (started: 2021-09-20 07:57:00 +00:00)
