# Import Libraries

In [2]:
!pip install ipython-autotime

Collecting ipython-autotime
  Downloading ipython_autotime-0.3.1-py2.py3-none-any.whl (6.8 kB)
Installing collected packages: ipython-autotime
Successfully installed ipython-autotime-0.3.1


In [3]:
!pip install torch_pruning

Collecting torch_pruning
  Downloading torch_pruning-0.2.7-py3-none-any.whl (14 kB)
Installing collected packages: torch-pruning
Successfully installed torch-pruning-0.2.7


In [4]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import models,transforms,datasets
import torch.nn.utils.prune as prune
from copy import deepcopy
import torch_pruning as tp

%load_ext autotime

time: 88.6 µs (started: 2021-09-20 20:18:05 +00:00)


In [5]:
# Number of CIFAR-100 target classes and the mini-batch size shared by both DataLoaders.
num_classes, batch_size = 100, 64

time: 827 µs (started: 2021-09-20 20:18:05 +00:00)


In [6]:
# Run on CUDA when it is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)

cuda
time: 83.6 ms (started: 2021-09-20 20:18:05 +00:00)


# Helper Functions

In [7]:
def get_total_parameters(model):
    """Return how many scalar values live in the trainable parameters of `model`.

    Only parameters with `requires_grad=True` are counted.
    """
    total = 0
    for tensor in model.parameters():
        if tensor.requires_grad:
            total += tensor.numel()
    return total


def get_pruned_parameters(pruned_model):
    """Count the non-zero entries across all parameters of `pruned_model`.

    Unlike get_total_parameters, `requires_grad` is ignored: every parameter
    tensor is inspected, and entries that are exactly zero (e.g. weights zeroed
    by pruning) are not counted.

    Returns:
        int: total number of non-zero parameter entries (0 for a parameter-less model).
    """
    # model.parameters() never yields None, so no None check is needed.
    # torch.count_nonzero returns the count directly instead of materialising
    # an (nnz x ndim) index tensor as torch.nonzero does.
    return sum(int(torch.count_nonzero(param)) for param in pruned_model.parameters())

time: 7.81 ms (started: 2021-09-20 20:18:05 +00:00)


In [8]:
def measure_module_sparsity(module, weight=True, bias=False, use_mask=False):
    """Measure the fraction of zero entries in a single module.

    Args:
        module: the module to inspect (Conv2d / Linear in this notebook).
        weight: include tensors whose name contains the weight key.
        bias: include tensors whose name contains the bias key.
        use_mask: if True, read the pruning masks registered as buffers
            ("weight_mask" / "bias_mask", present while torch.nn.utils.prune
            re-parametrization is active); otherwise read the parameters
            themselves ("weight" / "bias").

    Returns:
        (num_zeros, num_elements, sparsity). `sparsity` is 0.0 when no tensor
        was counted (e.g. weight=bias=False, or bias requested on a bias-less
        layer) instead of raising ZeroDivisionError.
    """
    if use_mask:
        named_tensors = module.named_buffers()
        weight_key, bias_key = "weight_mask", "bias_mask"
    else:
        named_tensors = module.named_parameters()
        weight_key, bias_key = "weight", "bias"

    num_zeros = 0
    num_elements = 0
    for name, tensor in named_tensors:
        # Substring matching (as before): without use_mask, a pruned-but-not-removed
        # module exposes "weight_orig", which is also matched by "weight".
        if weight and weight_key in name:
            num_zeros += torch.sum(tensor == 0).item()
            num_elements += tensor.nelement()
        if bias and bias_key in name:
            num_zeros += torch.sum(tensor == 0).item()
            num_elements += tensor.nelement()

    sparsity = num_zeros / num_elements if num_elements > 0 else 0.0

    return num_zeros, num_elements, sparsity

time: 16.8 ms (started: 2021-09-20 20:18:05 +00:00)


In [9]:
def measure_global_sparsity(model,
                            weight=True,
                            bias=False,
                            conv2d_use_mask=False,
                            linear_use_mask=False):
    """Aggregate zero counts over every Conv2d and Linear module of `model`.

    Each such module is measured with measure_module_sparsity; Conv2d modules
    use `conv2d_use_mask` and Linear modules use `linear_use_mask` to decide
    whether the pruning masks or the parameters themselves are read.

    Returns:
        (num_zeros, num_elements, sparsity) summed over all measured modules.
        `sparsity` is 0.0 when nothing was counted (e.g. the model has no
        Conv2d/Linear modules) instead of raising ZeroDivisionError.
    """
    num_zeros = 0
    num_elements = 0

    for module in model.modules():
        if isinstance(module, torch.nn.Conv2d):
            use_mask = conv2d_use_mask
        elif isinstance(module, torch.nn.Linear):
            use_mask = linear_use_mask
        else:
            continue

        module_num_zeros, module_num_elements, _ = measure_module_sparsity(
            module, weight=weight, bias=bias, use_mask=use_mask)
        num_zeros += module_num_zeros
        num_elements += module_num_elements

    sparsity = num_zeros / num_elements if num_elements > 0 else 0.0

    return num_zeros, num_elements, sparsity

time: 16.3 ms (started: 2021-09-20 20:18:05 +00:00)


In [10]:
def remove_parameters(model):
    """Make torch.nn.utils.prune pruning permanent on every Conv2d/Linear module.

    prune.remove is applied to "weight" and "bias" of each such module. This folds
    the pruning re-parametrization back into a plain parameter, so the zeros
    stay in the tensor. Tensors that were never pruned, including a missing bias,
    are skipped. `model` is modified in place and also returned.
    """
    for module in model.modules():
        if not isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
            continue
        for name in ("weight", "bias"):
            try:
                prune.remove(module, name)
            except ValueError:
                # prune.remove raises ValueError when `name` has no active pruning
                # on this module; there is nothing to undo. Any other error
                # (previously swallowed by a bare except) now propagates.
                pass
    return model

time: 17.4 ms (started: 2021-09-20 20:18:05 +00:00)


In [11]:
def fit(model, dataset, optimizer, criterion):
  """Train `model` for one pass over `dataset`.

  `dataset` is any iterable of (inputs, targets) batches (a DataLoader in this
  notebook). Each batch is moved to the global `device`; the loss
  `criterion(model(inputs), targets)` is back-propagated and `optimizer` takes
  one step. The model is updated in place and nothing is returned.
  """
  model.train()
  for inputs, labels in dataset:
    inputs, labels = inputs.to(device), labels.to(device)
    optimizer.zero_grad()
    batch_loss = criterion(model(inputs), labels)
    batch_loss.backward()
    optimizer.step()

time: 6.7 ms (started: 2021-09-20 20:18:05 +00:00)


In [12]:
def evaluate(model, dataloader, criterion=None, device=None):
  """Compute accuracy and mean loss of `model` over `dataloader`.

  Args:
    model: classifier returning one row of class scores per input sample.
    dataloader: iterable of (inputs, targets) batches exposing `.dataset`,
      whose length is used as the sample count (a torch DataLoader here).
    criterion: loss function. Defaults to nn.CrossEntropyLoss(), the same loss
      this notebook assigns to its global `criterion`.
    device: device the batches are sent to. Defaults to the device of the
      model's first parameter (CPU for a parameter-less model), so batches
      always land where the model is.

  Returns:
    (accuracy in percent over the whole dataset, unweighted mean of the
    per-batch losses). For an empty dataset, returns (0.0, nan) instead of
    raising ZeroDivisionError.
  """
  if criterion is None:
    criterion = nn.CrossEntropyLoss()
  if device is None:
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

  model.eval()
  correct = 0
  batch_losses = []
  # Inference only: without no_grad, autograd records a graph for every batch,
  # which wastes memory and time (this also runs over the full training set).
  with torch.no_grad():
    for data, targets in dataloader:
      data = data.to(device)
      targets = targets.to(device)
      out = model(data)
      batch_losses.append(criterion(out, targets).item())
      # Index of the highest-scoring class is the prediction.
      preds = out.argmax(dim=1)
      correct += (preds == targets).sum().item()

  n_samples = len(dataloader.dataset)
  if n_samples == 0 or not batch_losses:
    return 0.0, float('nan')
  # NOTE: each batch counts equally, so a smaller final batch is slightly over-weighted.
  return 100 * correct / n_samples, np.mean(np.array(batch_losses))

time: 25.1 ms (started: 2021-09-20 20:18:05 +00:00)


In [13]:
def unstructured_global_pruning(model, amount):
  """Globally L1-prune the Conv2d/Linear weights of a copy of `model`.

  The weights of every Conv2d/Linear layer except the last one (which is left
  dense) are ranked together by absolute value. `amount` of them is zeroed via
  prune.global_unstructured with prune.L1Unstructured; a float `amount` is a
  fraction. The pruning is then made permanent with remove_parameters. The
  copy is moved to the global `device` and returned; `model` is not modified.
  """
  pruned = deepcopy(model).to(device)

  weight_targets = [
      (layer, "weight")
      for _, layer in pruned.named_modules()
      if isinstance(layer, (torch.nn.Conv2d, torch.nn.Linear))
  ][:-1]

  prune.global_unstructured(
      parameters=weight_targets,
      pruning_method=prune.L1Unstructured,
      amount=amount,
  )

  return remove_parameters(model=pruned)

time: 11.8 ms (started: 2021-09-20 20:18:05 +00:00)


In [14]:
def structured_pruning(model, sparsity_layer, input_shape=(1, 3, 64, 64)):
  """Channel-prune a copy of `model` with torch_pruning (0.2.x API), one ratio per layer.

  Args:
    model: network to prune. It is deep-copied and left untouched.
    sparsity_layer: pruning ratios, one per Conv2d/Linear module in the order
      they appear in `model.modules()`. The last such module is never pruned,
      so only the first (number of layers - 1) ratios are used; extra entries
      are ignored.
    input_shape: shape of the random dummy input used to trace the dependency
      graph. The default matches the 64x64 images produced by this notebook's
      transforms.

  Returns:
    The pruned copy, moved to the global `device`.

  Raises:
    ValueError: if fewer ratios than prunable layers are supplied.
  """
  prune_model = deepcopy(model)
  prune_model = prune_model.to(device)

  prunable_modules = [m for m in prune_model.modules() if isinstance(m, (nn.Conv2d, nn.Linear))]
  # Keep the final layer intact so the network's output size is preserved.
  prunable_modules = prunable_modules[:-1]
  if len(sparsity_layer) < len(prunable_modules):
    raise ValueError(
        f"sparsity_layer has {len(sparsity_layer)} ratios but "
        f"{len(prunable_modules)} layers are prunable")

  strategy = tp.strategy.L1Strategy()
  # The dummy input stays on the CPU, exactly as in the original tracing call.
  DG = tp.DependencyGraph().build_dependency(prune_model, example_inputs=torch.randn(*input_shape))

  for layer_to_prune, amount in zip(prunable_modules, sparsity_layer):
    prune_fn = tp.prune_conv if isinstance(layer_to_prune, nn.Conv2d) else tp.prune_linear
    # Output channels / neurons with the smallest L1 norm, for this layer's ratio.
    pruning_idxs = strategy(layer_to_prune.weight, amount=amount)
    if len(pruning_idxs) == 0:
      continue  # the ratio removes nothing for this layer
    plan = DG.get_pruning_plan(layer_to_prune, prune_fn, pruning_idxs)
    plan.exec()

  # NOTE(review): tracing appears to leave the model on the CPU (a later cell has to call
  # .to(device) before evaluating), so move it back as the initial .to(device) intended.
  return prune_model.to(device)

time: 15.2 ms (started: 2021-09-20 20:18:05 +00:00)


# Load CIFAR-100 Data

In [15]:
# Training pipeline: resize to 64 px, light augmentation (random horizontal flip,
# random rotation within +/-15 degrees), then per-channel normalization.
train_transform = transforms.Compose([
    transforms.Resize(64),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.48, 0.4593, 0.4155), std=(0.2774, 0.2794, 0.2794)),
])

train_set = datasets.CIFAR100(
    root="CIFAR100",
    train=True,
    download=True,
    transform=train_transform,
)

Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to CIFAR100/cifar-100-python.tar.gz


  0%|          | 0/169001437 [00:00<?, ?it/s]

Extracting CIFAR100/cifar-100-python.tar.gz to CIFAR100
time: 6.59 s (started: 2021-09-20 20:18:05 +00:00)


In [16]:
# Evaluation pipeline: same resize and normalization as training, without augmentation.
test_transform = transforms.Compose([
    transforms.Resize(64),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.48, 0.4593, 0.4155), std=(0.2774, 0.2794, 0.2794)),
])

test_set = datasets.CIFAR100(
    root="CIFAR100",
    train=False,
    download=True,
    transform=test_transform,
)

Files already downloaded and verified
time: 804 ms (started: 2021-09-20 20:18:12 +00:00)


In [17]:
# Only the training batches are shuffled; evaluation order stays fixed.
train_dataloader = DataLoader(dataset=train_set, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(dataset=test_set, batch_size=batch_size, shuffle=False)

time: 3.68 ms (started: 2021-09-20 20:18:13 +00:00)


In [18]:
# Classification loss, averaged over the samples of each batch.
criterion = nn.CrossEntropyLoss(reduction='mean')

time: 3.32 ms (started: 2021-09-20 20:18:13 +00:00)


# Load the Model to Be Pruned

In [None]:
# Checkpoint to prune (Google Drive path on the Colab runtime this notebook was run on).
MODEL_PATH = '/content/drive/MyDrive/CV Systems/Project1/Pruning/VGG16/PrunedModel.pt'

# torch.load unpickles the whole nn.Module, which can execute arbitrary code: only load
# checkpoints you trust. map_location lets a checkpoint saved on GPU load on a CPU-only
# runtime as well. NOTE(review): newer PyTorch releases default to weights_only=True, which
# rejects full-module pickles; pass weights_only=False there (trusted files only).
model = torch.load(MODEL_PATH, map_location=device)

time: 9.64 s (started: 2021-09-15 18:50:30 +00:00)


In [None]:
# Layer-by-layer architecture of the loaded model (nn.Module's repr).
print(repr(model))

VGG(
  (features): Sequential(
    (0): Conv2d(3, 63, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(63, 55, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(55, 108, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(108, 104, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(104, 200, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(200, 187, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(187, 190, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1

# Unstructured Pruning

In [None]:
# Zero 50% of the prunable Conv2d/Linear weights, ranked globally by magnitude.
ustp_model = unstructured_global_pruning(model, amount=0.5)

time: 756 ms (started: 2021-09-15 19:04:47 +00:00)


In [None]:
# Test-set accuracy (%) and mean per-batch loss right after unstructured pruning.
acc, loss = evaluate(ustp_model, test_dataloader)
print('Accuracy : ', acc, '\tLoss : ', loss)

Accuracy :  73.65 	Loss :  1.0088575296341233
time: 14.7 s (started: 2021-09-15 18:50:52 +00:00)


In [None]:
# Fraction of zero weights over all Conv2d/Linear layers. The last layer, which
# unstructured_global_pruning leaves dense, is included in the denominator.
num_zeros, num_elements, sparsity = measure_global_sparsity(
    ustp_model, weight=True, bias=False, conv2d_use_mask=False, linear_use_mask=False)

print("Global Sparsity:")
print(f"{sparsity:.2f}")

Global Sparsity:
0.49
time: 13.6 ms (started: 2021-09-15 19:04:48 +00:00)


In [None]:
# Per-layer weight sparsity of the unstructured-pruned model, in module order.
# These ratios are turned into structured-pruning targets further down.
sparsity_layer = []
weighted_layers = [m for _, m in ustp_model.named_modules()
                   if isinstance(m, (torch.nn.Conv2d, torch.nn.Linear))]
for i, module in enumerate(weighted_layers):
  num_zeros, num_elements, sparsity = measure_module_sparsity(module, weight=True, bias=False, use_mask=False)
  print('Layer', i, num_zeros, num_elements, sparsity)
  sparsity_layer.append(sparsity)

Layer 0 30 1701 0.01763668430335097
Layer 1 2945 31185 0.0944364277697611
Layer 2 4243 53460 0.0793677515899738
Layer 3 9858 101088 0.09751899335232668
Layer 4 23455 187200 0.12529380341880342
Layer 5 53990 336600 0.16039809863339274
Layer 6 46109 319770 0.14419426462770116
Layer 7 104703 605340 0.17296560610565964
Layer 8 231774 1013148 0.22876618223596157
Layer 9 203450 895806 0.22711390635918938
Layer 10 234236 909891 0.25743303318749167
Layer 11 273627 938961 0.29141465939479916
Layer 12 278287 880821 0.31594046917591656
Layer 13 4854214 6087270 0.7974369462829807
Layer 14 196195 671990 0.2919611898986592
Layer 15 0 163900 0.0
time: 51.8 ms (started: 2021-09-15 19:04:50 +00:00)


In [None]:
# Structured-pruning ratio per layer: measured sparsity rounded to 2 decimals, minus 0.01
# (presumably to stay slightly below the unstructured level), floored at 0.
_ratio_margin = 0.01
sparsity_layer = [max(round(round(layer_sparsity, 2) - _ratio_margin, 2), 0)
                  for layer_sparsity in sparsity_layer]
sparsity_layer

[0.01,
 0.08,
 0.07,
 0.09,
 0.12,
 0.15,
 0.13,
 0.16,
 0.22,
 0.22,
 0.25,
 0.28,
 0.31,
 0.79,
 0.28,
 0]

time: 9.05 ms (started: 2021-09-15 19:04:52 +00:00)


# Structured Pruning

In [None]:
# Remove whole channels/neurons per layer using the ratios computed above.
stp_model = structured_pruning(ustp_model, sparsity_layer=sparsity_layer)

time: 217 ms (started: 2021-09-15 19:04:54 +00:00)


In [None]:
# Total parameter count after structured pruning (shown as the cell output).
stp_param_count = tp.utils.count_params(stp_model)
stp_param_count

5070779

time: 13.6 ms (started: 2021-09-15 19:04:56 +00:00)


In [None]:
# Architecture after structured pruning: compare channel counts with the original model.
print(repr(stp_model))

VGG(
  (features): Sequential(
    (0): Conv2d(3, 63, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(63, 51, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(51, 101, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(101, 95, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(95, 176, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(176, 159, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(159, 166, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, 

In [None]:
# Test-set accuracy (%) and mean per-batch loss of the structured-pruned model, before fine-tuning.
stp_model = stp_model.to(device)
acc, loss = evaluate(stp_model, test_dataloader)
print('Accuracy : ', acc, '\tLoss : ', loss)

Accuracy :  22.55 	Loss :  3.691333208873773
time: 11.8 s (started: 2021-09-15 18:51:25 +00:00)


# Fine-Tuning

In [None]:
# Fine-tuning stage 1: 5 epochs of SGD (lr=1e-3, momentum 0.9), epochs numbered from 1.
# The "train" metrics are computed on train_dataloader, i.e. with training augmentation.
epochs = 5
first_epoch = 1
stp_model = stp_model.to(device)
sgd_optimizer = torch.optim.SGD(stp_model.parameters(), lr=0.001, momentum=0.9)

for epoch in range(epochs):
  fit(stp_model, train_dataloader, sgd_optimizer, criterion)
  trn_acc, trn_lss = evaluate(stp_model, train_dataloader)
  tst_acc, tst_lss = evaluate(stp_model, test_dataloader)
  epoch_no = first_epoch + epoch
  print(f'Epoch:{epoch_no:2.0f}\t Train_Loss:{trn_lss:.4f}\t Train_Acc:{trn_acc:.4f}\t Test_Loss:{tst_lss:.4f}\t Test_Acc:{tst_acc:.4f}')

Epoch: 1	 Train_Loss:0.7550	 Train_Acc:77.7100	 Test_Loss:1.2360	 Test_Acc:66.9100
Epoch: 2	 Train_Loss:0.6192	 Train_Acc:81.6400	 Test_Loss:1.1803	 Test_Acc:68.8000
Epoch: 3	 Train_Loss:0.5873	 Train_Acc:82.6380	 Test_Loss:1.1991	 Test_Acc:69.2300
Epoch: 4	 Train_Loss:0.5924	 Train_Acc:82.4600	 Test_Loss:1.2134	 Test_Acc:69.1100
Epoch: 5	 Train_Loss:0.5021	 Train_Acc:85.1320	 Test_Loss:1.1732	 Test_Acc:70.2100
time: 15min 37s (started: 2021-09-15 19:05:02 +00:00)


In [None]:
# Fine-tuning stage 2: a fresh SGD optimizer with a 10x lower lr (1e-4), 5 more epochs
# numbered from 6. "Train" metrics use train_dataloader (training augmentation applied).
epochs = 5
first_epoch = 6
sgd_optimizer = torch.optim.SGD(stp_model.parameters(), lr=0.0001, momentum=0.9)

for epoch in range(epochs):
  fit(stp_model, train_dataloader, sgd_optimizer, criterion)
  trn_acc, trn_lss = evaluate(stp_model, train_dataloader)
  tst_acc, tst_lss = evaluate(stp_model, test_dataloader)
  epoch_no = first_epoch + epoch
  print(f'Epoch:{epoch_no:2.0f}\t Train_Loss:{trn_lss:.4f}\t Train_Acc:{trn_acc:.4f}\t Test_Loss:{tst_lss:.4f}\t Test_Acc:{tst_acc:.4f}')

Epoch: 6	 Train_Loss:0.3602	 Train_Acc:89.0500	 Test_Loss:1.1275	 Test_Acc:72.7000
Epoch: 7	 Train_Loss:0.3511	 Train_Acc:89.2980	 Test_Loss:1.1235	 Test_Acc:72.9000
Epoch: 8	 Train_Loss:0.3340	 Train_Acc:89.6360	 Test_Loss:1.1497	 Test_Acc:72.9500
Epoch: 9	 Train_Loss:0.3223	 Train_Acc:89.9640	 Test_Loss:1.1454	 Test_Acc:72.7900
Epoch:10	 Train_Loss:0.3139	 Train_Acc:90.3260	 Test_Loss:1.1333	 Test_Acc:73.0900
time: 15min 39s (started: 2021-09-15 19:22:54 +00:00)


In [20]:
# Fine-tuning stage 3: a fresh SGD optimizer with lr 1e-5, 5 final epochs numbered from 11.
# "Train" metrics use train_dataloader (training augmentation applied).
epochs = 5
first_epoch = 11
sgd_optimizer = torch.optim.SGD(stp_model.parameters(), lr=0.00001, momentum=0.9)

for epoch in range(epochs):
  fit(stp_model, train_dataloader, sgd_optimizer, criterion)
  trn_acc, trn_lss = evaluate(stp_model, train_dataloader)
  tst_acc, tst_lss = evaluate(stp_model, test_dataloader)
  epoch_no = first_epoch + epoch
  print(f'Epoch:{epoch_no:2.0f}\t Train_Loss:{trn_lss:.4f}\t Train_Acc:{trn_acc:.4f}\t Test_Loss:{tst_lss:.4f}\t Test_Acc:{tst_acc:.4f}')

Epoch:11	 Train_Loss:0.2683	 Train_Acc:91.6420	 Test_Loss:1.1819	 Test_Acc:73.2200
Epoch:12	 Train_Loss:0.2665	 Train_Acc:91.7580	 Test_Loss:1.1863	 Test_Acc:73.2500
Epoch:13	 Train_Loss:0.2649	 Train_Acc:91.7240	 Test_Loss:1.1833	 Test_Acc:73.1100
Epoch:14	 Train_Loss:0.2628	 Train_Acc:91.7820	 Test_Loss:1.1886	 Test_Acc:73.4000
Epoch:15	 Train_Loss:0.2626	 Train_Acc:91.8920	 Test_Loss:1.1903	 Test_Acc:73.3800
time: 15min 53s (started: 2021-09-20 20:18:23 +00:00)


# Testing

In [21]:
# Final test-set accuracy (%) and mean per-batch loss after all fine-tuning stages.
acc, loss = evaluate(stp_model, test_dataloader)
print('Accuracy : ', acc, '\tLoss : ', loss)

Accuracy :  73.38 	Loss :  1.1903043569652898
time: 12.3 s (started: 2021-09-20 20:34:38 +00:00)


# Save Pruned Model

In [22]:
# Save the whole pruned nn.Module (architecture + weights). Its channel counts differ from the
# model it was pruned from (see the printed architectures), so a bare state_dict would not
# load into an unpruned model definition.
PRUNED_MODEL_PATH = 'PrunedModel_V2.pt'
torch.save(stp_model, PRUNED_MODEL_PATH)

time: 53.9 ms (started: 2021-09-20 20:34:50 +00:00)
