# Install and Import Libraries

In [2]:
!pip install ipython-autotime

Collecting ipython-autotime
  Downloading ipython_autotime-0.3.1-py2.py3-none-any.whl (6.8 kB)
Installing collected packages: ipython-autotime
Successfully installed ipython-autotime-0.3.1


In [3]:
!pip install torch_pruning

Collecting torch_pruning
  Downloading torch_pruning-0.2.7-py3-none-any.whl (14 kB)
Installing collected packages: torch-pruning
Successfully installed torch-pruning-0.2.7


In [4]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import models,transforms,datasets
import torch.nn.utils.prune as prune
from copy import deepcopy
import torch_pruning as tp

%load_ext autotime

time: 88.8 µs (started: 2021-09-20 20:39:46 +00:00)


In [5]:
# Number of target classes. Not referenced by any other cell visible in this notebook.
num_classes = 100
# Mini-batch size shared by the train and test DataLoaders.
batch_size = 64

time: 3.16 ms (started: 2021-09-20 20:39:46 +00:00)


In [6]:
# Use the GPU when one is available, otherwise fall back to the CPU.
# `device` is a notebook-global read by fit(), evaluate() and both pruning helpers.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda
time: 69.1 ms (started: 2021-09-20 20:39:46 +00:00)


In [7]:
# Colab-only step: mount Google Drive so the checkpoint under /content/drive can be read later.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
time: 2.98 ms (started: 2021-09-20 20:39:46 +00:00)


# Important Functions

In [8]:
def get_total_parameters(model):
    """Return the number of trainable scalar values in ``model``.

    Only parameters whose ``requires_grad`` flag is set are counted.
    """
    total = 0
    for tensor in model.parameters():
        if tensor.requires_grad:
            total += tensor.numel()
    return total


def get_pruned_parameters(pruned_model):
    """Return how many parameter entries of ``pruned_model`` are non-zero.

    Every tensor yielded by ``pruned_model.parameters()`` is inspected,
    regardless of its ``requires_grad`` flag.
    """
    return sum(
        torch.nonzero(tensor).size(0)
        for tensor in pruned_model.parameters()
        if tensor is not None
    )

time: 7.46 ms (started: 2021-09-20 20:39:46 +00:00)


In [9]:
def measure_module_sparsity(module, weight=True, bias=False, use_mask=False):
    """Count zero-valued entries in a module's weights and/or biases.

    Args:
        module: The ``nn.Module`` to inspect.
        weight: Include tensors whose name contains the weight key.
        bias: Include tensors whose name contains the bias key.
        use_mask: When true, inspect the buffers named ``*weight_mask`` /
            ``*bias_mask``; otherwise inspect the parameters whose name
            contains ``weight`` / ``bias``.

    Returns:
        ``(num_zeros, num_elements, sparsity)`` where ``sparsity`` is
        ``num_zeros / num_elements``, or ``0.0`` when no tensor was selected
        (for example ``use_mask=True`` on a module that carries no masks).
    """
    if use_mask:
        named_tensors = module.named_buffers()
        weight_key, bias_key = "weight_mask", "bias_mask"
    else:
        # Matching is by substring, so any parameter whose name merely contains
        # "weight"/"bias" (e.g. a "weight_orig" entry) is counted as well.
        named_tensors = module.named_parameters()
        weight_key, bias_key = "weight", "bias"

    num_zeros = 0
    num_elements = 0
    for tensor_name, tensor in named_tensors:
        # Two independent checks, mirroring the selection flags one by one.
        if weight and weight_key in tensor_name:
            num_zeros += torch.sum(tensor == 0).item()
            num_elements += tensor.nelement()
        if bias and bias_key in tensor_name:
            num_zeros += torch.sum(tensor == 0).item()
            num_elements += tensor.nelement()

    # Guard against division by zero when nothing matched the selection.
    sparsity = num_zeros / num_elements if num_elements else 0.0

    return num_zeros, num_elements, sparsity

time: 19.7 ms (started: 2021-09-20 20:39:46 +00:00)


In [10]:
def measure_global_sparsity(model,
                            weight=True,
                            bias=False,
                            conv2d_use_mask=False,
                            linear_use_mask=False):
    """Aggregate zero counts over every Conv2d and Linear layer of ``model``.

    Args:
        model: The network to inspect; all sub-modules are visited.
        weight: Forwarded to ``measure_module_sparsity``.
        bias: Forwarded to ``measure_module_sparsity``.
        conv2d_use_mask: ``use_mask`` value applied to ``Conv2d`` layers.
        linear_use_mask: ``use_mask`` value applied to ``Linear`` layers.

    Returns:
        ``(num_zeros, num_elements, sparsity)`` summed over the visited
        layers. ``sparsity`` is ``0.0`` when no element was counted (for
        example a model without Conv2d/Linear layers).
    """
    num_zeros = 0
    num_elements = 0

    for module in model.modules():
        # Pick the mask setting for this layer type; skip every other module.
        if isinstance(module, torch.nn.Conv2d):
            use_mask = conv2d_use_mask
        elif isinstance(module, torch.nn.Linear):
            use_mask = linear_use_mask
        else:
            continue

        module_num_zeros, module_num_elements, _ = measure_module_sparsity(
            module, weight=weight, bias=bias, use_mask=use_mask)
        num_zeros += module_num_zeros
        num_elements += module_num_elements

    # Guard against division by zero when nothing was counted.
    sparsity = num_zeros / num_elements if num_elements else 0.0

    return num_zeros, num_elements, sparsity

time: 12.4 ms (started: 2021-09-20 20:39:46 +00:00)


In [11]:
def remove_parameters(model):
    """Make pruning permanent on every Conv2d and Linear layer of ``model``.

    ``prune.remove`` is attempted for both ``weight`` and ``bias`` of each
    such layer. Parameters that were never pruned are skipped.

    Args:
        model: The network to process; it is modified in place.

    Returns:
        The same ``model`` object, for call chaining.
    """
    for module in model.modules():
        if not isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
            continue
        for param_name in ("weight", "bias"):
            try:
                prune.remove(module, param_name)
            except ValueError:
                # prune.remove raises ValueError when this parameter carries
                # no pruning re-parametrization; that is expected here. Any
                # other exception is a real error and is left to propagate.
                pass
    return model

time: 17.8 ms (started: 2021-09-20 20:39:46 +00:00)


In [12]:
def fit(model, dataset, optimizer, criterion):
    """Train ``model`` for one pass over ``dataset``.

    Each ``(inputs, labels)`` batch is moved to the notebook-global
    ``device``, the loss from ``criterion`` is back-propagated and
    ``optimizer`` takes one step. Nothing is returned.
    """
    # Switch to training mode before touching any batch.
    model.train()
    for batch in dataset:
        inputs, labels = batch
        inputs = inputs.to(device)
        labels = labels.to(device)
        # Clear gradients left over from the previous step.
        optimizer.zero_grad()
        # Forward pass, loss, backward pass.
        criterion(model(inputs), labels).backward()
        # Apply the parameter update.
        optimizer.step()

time: 11.9 ms (started: 2021-09-20 20:39:46 +00:00)


In [13]:
def evaluate(model, dataloader):
  model.eval()
  acc = 0
  loss = []
  for data, targets in dataloader:
    data = data.to(device)
    targets = targets.to(device)
    out = model(data)
    #Get loss
    l = criterion(out, targets)
    loss.append(l.item())
    #Get index of class label
    _,preds = torch.max(out.data,1)
    #Get accuracy
    acc += torch.sum(preds == targets).item()

  return 100*acc/len(dataloader.dataset), np.mean(np.array(loss))

time: 12.5 ms (started: 2021-09-20 20:39:46 +00:00)


In [14]:
def unstructured_global_pruning(model, amount):
    """Return a globally L1-pruned copy of ``model``.

    The input model is deep-copied and moved to the notebook-global
    ``device``. The weights of every Conv2d/Linear layer except the last one
    found are pruned together with ``prune.L1Unstructured``; the pruning is
    then made permanent through ``remove_parameters``.

    Args:
        model: Source network; it is not modified.
        amount: Passed unchanged to ``prune.global_unstructured``.
    """
    pruned_copy = deepcopy(model).to(device)

    prunable_types = (torch.nn.Conv2d, torch.nn.Linear)
    weight_targets = [
        (layer, "weight")
        for _, layer in pruned_copy.named_modules()
        if isinstance(layer, prunable_types)
    ]
    # The last collected layer is left out of pruning.
    weight_targets = weight_targets[:-1]

    prune.global_unstructured(
        parameters=weight_targets,
        pruning_method=prune.L1Unstructured,
        amount=amount,
    )

    return remove_parameters(model=pruned_copy)

time: 13.2 ms (started: 2021-09-20 20:39:46 +00:00)


In [15]:
def structured_pruning(model, sparsity_layer, example_inputs=None):
    """Return a structurally pruned copy of ``model`` using torch_pruning.

    The model is deep-copied and moved to the notebook-global ``device``.
    Every Conv2d/Linear layer except the last one found is pruned with the
    L1 strategy, layer ``i`` using ``sparsity_layer[i]`` as its amount.

    Args:
        model: Source network; it is not modified.
        sparsity_layer: Indexable of per-layer pruning amounts, in module
            order. Entries beyond the number of pruned layers are ignored.
        example_inputs: Tensor used to build the dependency graph. Defaults
            to ``torch.randn(1, 3, 64, 64)``, the shape previously
            hard-coded here.

    Returns:
        The pruned copy.

    Raises:
        IndexError: If ``sparsity_layer`` has fewer entries than there are
            layers to prune.
    """
    prune_model = deepcopy(model)
    prune_model = prune_model.to(device)

    prunable_module_type = (nn.Conv2d, nn.Linear)
    prunable_modules = [m for m in prune_model.modules() if isinstance(m, prunable_module_type)]
    # The last prunable layer is left untouched.
    prunable_modules = prunable_modules[:-1]

    if example_inputs is None:
        # NOTE(review): this default tensor is created on the CPU while the
        # model copy was moved to `device` - confirm build_dependency accepts
        # that, or pass example_inputs already placed on the right device.
        example_inputs = torch.randn(1, 3, 64, 64)

    strategy = tp.strategy.L1Strategy()
    DG = tp.DependencyGraph().build_dependency(prune_model, example_inputs=example_inputs)

    for layer_index, layer_to_prune in enumerate(prunable_modules):
        # Choose the pruning function that matches the layer type.
        if isinstance(layer_to_prune, nn.Conv2d):
            prune_fn = tp.prune_conv
        else:
            prune_fn = tp.prune_linear

        pruning_idxs = strategy(layer_to_prune.weight, amount=sparsity_layer[layer_index])
        plan = DG.get_pruning_plan(layer_to_prune, prune_fn, pruning_idxs)
        plan.exec()

    return prune_model

time: 40.8 ms (started: 2021-09-20 20:39:46 +00:00)


# Import Data

In [16]:
# Training-time preprocessing: resize to 64, random horizontal flip and random rotation
# (up to 15 degrees) as augmentation, tensor conversion, then per-channel normalization.
# NOTE(review): the same mean/std tuples are repeated in the test transform and their
# origin is not shown here - confirm they match the statistics of this dataset.
train_transform = transforms.Compose([transforms.Resize(64),
                                   transforms.RandomHorizontalFlip(),
                                   transforms.RandomRotation(15),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.48,0.4593,0.4155),(0.2774,0.2794,0.2794))])

# CIFAR-100 training split, downloaded into ./CIFAR100 when not already present.
train_set = datasets.CIFAR100(root = "CIFAR100", train = True, download = True, transform = train_transform)

Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to CIFAR100/cifar-100-python.tar.gz


  0%|          | 0/169001437 [00:00<?, ?it/s]

Extracting CIFAR100/cifar-100-python.tar.gz to CIFAR100
time: 6.98 s (started: 2021-09-20 20:39:46 +00:00)


In [17]:
# Test-time preprocessing: same resize and normalization as training, without augmentation.
test_transform = transforms.Compose([transforms.Resize(64),
                                  transforms.ToTensor(),
                                  transforms.Normalize((0.48,0.4593,0.4155),(0.2774,0.2794,0.2794))])

# CIFAR-100 test split, stored under the same root directory as the training split.
test_set = datasets.CIFAR100(root = "CIFAR100", train = False, download = True, transform = test_transform)

Files already downloaded and verified
time: 942 ms (started: 2021-09-20 20:39:53 +00:00)


In [18]:
# Training batches are reshuffled on every pass; test batches keep the dataset order.
train_dataloader = DataLoader(train_set, batch_size = batch_size, shuffle = True)
test_dataloader = DataLoader(test_set, batch_size = batch_size)

time: 2.56 ms (started: 2021-09-20 20:39:54 +00:00)


In [19]:
# Loss function. evaluate() reads this notebook-global name directly; fit() receives it as an argument.
criterion = nn.CrossEntropyLoss()

time: 1.85 ms (started: 2021-09-20 20:39:54 +00:00)


# Import Model to be Pruned

In [20]:
# Load the previously pruned model (V3) from Google Drive; the path requires the Drive mount above.
# NOTE(review): torch.load unpickles arbitrary Python objects - only load checkpoints from a trusted source.
# NOTE(review): no map_location is passed - confirm the checkpoint loads on a CPU-only runtime.
model = torch.load('/content/drive/MyDrive/CV Systems/Project1/Pruning/V3/PrunedModel_V3.pt')

time: 9.68 s (started: 2021-09-20 20:39:54 +00:00)


In [21]:
# Show the layer layout of the loaded model as the baseline for the pruning steps below.
print(model)

VGG(
  (features): Sequential(
    (0): Conv2d(3, 63, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(63, 47, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(47, 93, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(93, 86, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(86, 155, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(155, 134, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(134, 143, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ce

# Unstructured Pruning

In [39]:
# Globally prune weights by L1 magnitude with amount=0.2 on a copy of `model`;
# the helper skips the last Conv2d/Linear layer and leaves `model` untouched.
ustp_model = unstructured_global_pruning(model, 0.2)

time: 157 ms (started: 2021-09-20 22:16:09 +00:00)


In [40]:
# Test-set accuracy (percent) and mean batch loss after unstructured pruning.
acc, loss = evaluate(ustp_model, test_dataloader)
print('Accuracy : ',acc , '\tLoss : ', loss)

Accuracy :  72.02 	Loss :  1.1774694274185569
time: 11.2 s (started: 2021-09-20 22:16:11 +00:00)


In [41]:
# Fraction of zero-valued weights over all Conv2d/Linear layers. Parameters are read
# directly (no masks) because unstructured_global_pruning already made the pruning permanent.
num_zeros, num_elements, sparsity = measure_global_sparsity(
            ustp_model,
            weight=True,
            bias=False,
            conv2d_use_mask=False,
            linear_use_mask=False)

print("Global Sparsity:")
print("{:.2f}".format(sparsity))

Global Sparsity:
0.19
time: 12 ms (started: 2021-09-20 22:16:25 +00:00)


In [42]:
# Collect the weight sparsity of every Conv2d/Linear layer, in module order.
# Note: this rebinds the notebook-global names num_zeros, num_elements and sparsity
# that the previous cell set to the global figures.
sparsity_layer = []
i = 0
for module_name, module in ustp_model.named_modules():
  if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear):
    num_zeros, num_elements, sparsity = measure_module_sparsity(module, weight=True, bias=False, use_mask=False)
    print('Layer', i , num_zeros, num_elements, sparsity)
    sparsity_layer.append(sparsity)
    i+= 1

Layer 0 31 1701 0.018224573780129337
Layer 1 2813 26649 0.10555743179856655
Layer 2 3495 39339 0.08884313276900785
Layer 3 7529 71982 0.10459559334278014
Layer 4 15212 119970 0.12679836625823124
Layer 5 28666 186930 0.15335152196009202
Layer 6 23506 172458 0.13629985271776315
Layer 7 51347 319176 0.160873624583302
Layer 8 84601 435240 0.19437781453910485
Layer 9 63407 340470 0.18623373571827181
Layer 10 67499 328248 0.20563415466354706
Layer 11 67268 292716 0.22980636521406414
Layer 12 55918 239778 0.23320738349640083
Layer 13 88863 218834 0.40607492437189835
Layer 14 4163 28101 0.1481441941567916
Layer 15 0 96900 0.0
time: 31.9 ms (started: 2021-09-20 22:16:27 +00:00)


In [43]:
# Turn the measured sparsities into per-layer structured-pruning amounts:
# round to two decimals, subtract 0.01, and clamp at 0.
# NOTE(review): this overwrites sparsity_layer with a function of itself, so re-running
# the cell subtracts another 0.01 - re-run the previous cell first.
sparsity_layer = [ max(round(round(i,2)-0.01,2),0) for i in sparsity_layer]
sparsity_layer

[0.01,
 0.1,
 0.08,
 0.09,
 0.12,
 0.14,
 0.13,
 0.15,
 0.18,
 0.18,
 0.2,
 0.22,
 0.22,
 0.4,
 0.14,
 0]

time: 6.61 ms (started: 2021-09-20 22:16:29 +00:00)


# Structured Pruning

In [44]:
# Structurally prune a copy of the unstructured-pruned model, layer i using sparsity_layer[i].
# structured_pruning skips the last Conv2d/Linear layer, so the final list entry is unused.
stp_model = structured_pruning(ustp_model, sparsity_layer)

time: 87.9 ms (started: 2021-09-20 22:16:33 +00:00)


In [45]:
# Parameter count of the structurally pruned model, as reported by torch_pruning.
tp.utils.count_params(stp_model)

2018653

time: 4.4 ms (started: 2021-09-20 22:16:35 +00:00)


In [46]:
# Show the layer layout after structured pruning, to compare with the model printed earlier.
print(stp_model)

VGG(
  (features): Sequential(
    (0): Conv2d(3, 63, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(63, 43, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(43, 86, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(86, 79, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(79, 137, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(137, 116, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(116, 125, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ce

In [47]:
# Test-set accuracy and loss right after structured pruning, before any fine-tuning.
# The model is moved to `device` again here; structured_pruning already moved its copy.
stp_model = stp_model.to(device)
acc, loss = evaluate(stp_model, test_dataloader)
print('Accuracy : ',acc , '\tLoss : ', loss)

Accuracy :  33.81 	Loss :  3.2660441337877018
time: 10.4 s (started: 2021-09-20 22:16:41 +00:00)


# Fine Tuning

In [48]:
# Fine-tuning stage 1 (epochs 1-10): SGD with lr=0.001 and momentum=0.9.
# NOTE(review): this loop is repeated in the next three cells with only the learning rate,
# epoch count and printed epoch offset changed - a shared helper would remove the duplication.
sgd_optimizer = torch.optim.SGD(stp_model.parameters(), lr = 0.001, momentum = 0.9)
epochs = 10
stp_model = stp_model.to(device)

for epoch in range(epochs):
  # One training pass over the training set
  fit(stp_model, train_dataloader, sgd_optimizer, criterion)

  # Metrics on the training loader; its transform includes random flip/rotation,
  # so these numbers are measured on augmented inputs.
  trn_acc, trn_lss = evaluate(stp_model, train_dataloader)

  # Metrics on the test loader
  tst_acc, tst_lss = evaluate(stp_model, test_dataloader)

  print(f'Epoch:{epoch+1:2.0f}\t Train_Loss:{trn_lss:2.4f}\t Train_Acc:{trn_acc:2.4f}\t Test_Loss:{tst_lss:2.4f}\t Test_Acc:{tst_acc:2.4f}')

Epoch: 1	 Train_Loss:1.1246	 Train_Acc:75.8460	 Test_Loss:1.5187	 Test_Acc:63.6300
Epoch: 2	 Train_Loss:1.0815	 Train_Acc:77.5200	 Test_Loss:1.5073	 Test_Acc:64.4900
Epoch: 3	 Train_Loss:1.0240	 Train_Acc:77.4460	 Test_Loss:1.4803	 Test_Acc:64.4600
Epoch: 4	 Train_Loss:0.9688	 Train_Acc:78.8800	 Test_Loss:1.4475	 Test_Acc:65.1400
Epoch: 5	 Train_Loss:0.8795	 Train_Acc:80.9940	 Test_Loss:1.4195	 Test_Acc:66.3600
Epoch: 6	 Train_Loss:0.9608	 Train_Acc:79.7540	 Test_Loss:1.4527	 Test_Acc:65.9900
Epoch: 7	 Train_Loss:0.9325	 Train_Acc:80.1680	 Test_Loss:1.4379	 Test_Acc:65.5000
Epoch: 8	 Train_Loss:0.8624	 Train_Acc:80.5520	 Test_Loss:1.4211	 Test_Acc:65.9200
Epoch: 9	 Train_Loss:0.9322	 Train_Acc:79.3200	 Test_Loss:1.4591	 Test_Acc:65.2200
Epoch:10	 Train_Loss:0.7955	 Train_Acc:82.8620	 Test_Loss:1.3725	 Test_Acc:67.3700
time: 26min 55s (started: 2021-09-20 22:16:55 +00:00)


In [49]:
# Fine-tuning stage 2 (epochs 11-20): same lr=0.001 and momentum=0.9 as stage 1.
# A new optimizer object is created, so momentum state from the previous stage is not carried over.
sgd_optimizer = torch.optim.SGD(stp_model.parameters(), lr = 0.001, momentum = 0.9)
epochs = 10
stp_model = stp_model.to(device)

for epoch in range(epochs):
  # One training pass over the training set
  fit(stp_model, train_dataloader, sgd_optimizer, criterion)

  # Metrics on the (augmented) training loader
  trn_acc, trn_lss = evaluate(stp_model, train_dataloader)

  # Metrics on the test loader
  tst_acc, tst_lss = evaluate(stp_model, test_dataloader)

  # The +11 offset continues the epoch numbering from stage 1.
  print(f'Epoch:{epoch+11:2.0f}\t Train_Loss:{trn_lss:2.4f}\t Train_Acc:{trn_acc:2.4f}\t Test_Loss:{tst_lss:2.4f}\t Test_Acc:{tst_acc:2.4f}')

Epoch:11	 Train_Loss:0.8838	 Train_Acc:81.1440	 Test_Loss:1.4257	 Test_Acc:66.3200
Epoch:12	 Train_Loss:0.8733	 Train_Acc:80.6080	 Test_Loss:1.4603	 Test_Acc:65.4700
Epoch:13	 Train_Loss:0.7653	 Train_Acc:83.1300	 Test_Loss:1.3879	 Test_Acc:66.8400
Epoch:14	 Train_Loss:0.8025	 Train_Acc:82.7200	 Test_Loss:1.4141	 Test_Acc:66.7600
Epoch:15	 Train_Loss:0.7645	 Train_Acc:83.6300	 Test_Loss:1.4034	 Test_Acc:67.1100
Epoch:16	 Train_Loss:0.7638	 Train_Acc:82.4880	 Test_Loss:1.4078	 Test_Acc:66.5500
Epoch:17	 Train_Loss:0.7685	 Train_Acc:83.8540	 Test_Loss:1.3992	 Test_Acc:66.5300
Epoch:18	 Train_Loss:0.7446	 Train_Acc:84.1040	 Test_Loss:1.3703	 Test_Acc:67.1700
Epoch:19	 Train_Loss:0.7591	 Train_Acc:83.9820	 Test_Loss:1.4110	 Test_Acc:66.8000
Epoch:20	 Train_Loss:0.7283	 Train_Acc:84.1960	 Test_Loss:1.4195	 Test_Acc:66.3300
time: 26min 34s (started: 2021-09-20 22:44:28 +00:00)


In [50]:
# Fine-tuning stage 3 (epochs 21-30): learning rate lowered to 0.0001, momentum=0.9.
# A new optimizer object is created, so momentum state from the previous stage is not carried over.
sgd_optimizer = torch.optim.SGD(stp_model.parameters(), lr = 0.0001, momentum = 0.9)
epochs = 10
stp_model = stp_model.to(device)

for epoch in range(epochs):
  # One training pass over the training set
  fit(stp_model, train_dataloader, sgd_optimizer, criterion)

  # Metrics on the (augmented) training loader
  trn_acc, trn_lss = evaluate(stp_model, train_dataloader)

  # Metrics on the test loader
  tst_acc, tst_lss = evaluate(stp_model, test_dataloader)

  # The +21 offset continues the epoch numbering from stage 2.
  print(f'Epoch:{epoch+21:2.0f}\t Train_Loss:{trn_lss:2.4f}\t Train_Acc:{trn_acc:2.4f}\t Test_Loss:{tst_lss:2.4f}\t Test_Acc:{tst_acc:2.4f}')

Epoch:21	 Train_Loss:0.5267	 Train_Acc:88.6980	 Test_Loss:1.2794	 Test_Acc:69.4500
Epoch:22	 Train_Loss:0.5088	 Train_Acc:89.0960	 Test_Loss:1.2766	 Test_Acc:69.4900
Epoch:23	 Train_Loss:0.4916	 Train_Acc:89.4540	 Test_Loss:1.2661	 Test_Acc:69.9600
Epoch:24	 Train_Loss:0.4848	 Train_Acc:89.3520	 Test_Loss:1.2718	 Test_Acc:69.7200
Epoch:25	 Train_Loss:0.4720	 Train_Acc:89.6780	 Test_Loss:1.2759	 Test_Acc:69.9700
Epoch:26	 Train_Loss:0.4699	 Train_Acc:89.4980	 Test_Loss:1.2839	 Test_Acc:69.5200
Epoch:27	 Train_Loss:0.4611	 Train_Acc:89.9620	 Test_Loss:1.2763	 Test_Acc:69.9100
Epoch:28	 Train_Loss:0.4521	 Train_Acc:89.9520	 Test_Loss:1.2830	 Test_Acc:69.9900
Epoch:29	 Train_Loss:0.4512	 Train_Acc:90.0140	 Test_Loss:1.2826	 Test_Acc:69.7000
Epoch:30	 Train_Loss:0.4425	 Train_Acc:90.2080	 Test_Loss:1.2875	 Test_Acc:69.8300
time: 26min 57s (started: 2021-09-20 23:11:04 +00:00)


In [54]:
# Fine-tuning stage 4 (epochs 31-35): learning rate lowered to 0.00001, momentum=0.9, five epochs.
# A new optimizer object is created, so momentum state from the previous stage is not carried over.
sgd_optimizer = torch.optim.SGD(stp_model.parameters(), lr = 0.00001, momentum = 0.9)
epochs = 5
stp_model = stp_model.to(device)

for epoch in range(epochs):
  # One training pass over the training set
  fit(stp_model, train_dataloader, sgd_optimizer, criterion)

  # Metrics on the (augmented) training loader
  trn_acc, trn_lss = evaluate(stp_model, train_dataloader)

  # Metrics on the test loader
  tst_acc, tst_lss = evaluate(stp_model, test_dataloader)

  # The +31 offset continues the epoch numbering from stage 3.
  print(f'Epoch:{epoch+31:2.0f}\t Train_Loss:{trn_lss:2.4f}\t Train_Acc:{trn_acc:2.4f}\t Test_Loss:{tst_lss:2.4f}\t Test_Acc:{tst_acc:2.4f}')

Epoch:31	 Train_Loss:0.4336	 Train_Acc:90.3460	 Test_Loss:1.2791	 Test_Acc:69.9200
Epoch:32	 Train_Loss:0.4293	 Train_Acc:90.5520	 Test_Loss:1.2778	 Test_Acc:69.9800
Epoch:33	 Train_Loss:0.4275	 Train_Acc:90.6640	 Test_Loss:1.2777	 Test_Acc:69.9600
Epoch:34	 Train_Loss:0.4272	 Train_Acc:90.6860	 Test_Loss:1.2763	 Test_Acc:70.0600
Epoch:35	 Train_Loss:0.4266	 Train_Acc:90.5700	 Test_Loss:1.2756	 Test_Acc:70.1500
time: 13min 23s (started: 2021-09-20 23:38:46 +00:00)


# Testing

In [55]:
# Final test-set accuracy (percent) and mean batch loss of the fine-tuned pruned model.
acc, loss = evaluate(stp_model, test_dataloader)
print('Accuracy : ',acc , '\tLoss : ', loss)

Accuracy :  70.15 	Loss :  1.2756433906448874
time: 10.4 s (started: 2021-09-20 23:52:10 +00:00)


# Save Pruned Model

In [56]:
# Save the whole model object (not just a state_dict) to a path relative to the working directory.
# NOTE(review): the input checkpoint was read from Google Drive, but this writes to the
# working directory - confirm the file is copied somewhere persistent.
torch.save(stp_model, 'PrunedModel_V4.pt')

time: 29.6 ms (started: 2021-09-20 23:52:20 +00:00)
