# Import Library

In [2]:
!pip install ipython-autotime

Collecting ipython-autotime
  Downloading ipython_autotime-0.3.1-py2.py3-none-any.whl (6.8 kB)
Installing collected packages: ipython-autotime
Successfully installed ipython-autotime-0.3.1


In [3]:
!pip install torch_pruning

Collecting torch_pruning
  Downloading torch_pruning-0.2.7-py3-none-any.whl (14 kB)
Installing collected packages: torch-pruning
Successfully installed torch-pruning-0.2.7


In [4]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import models,transforms,datasets
import torch.nn.utils.prune as prune
from copy import deepcopy
import torch_pruning as tp

%load_ext autotime

time: 84.4 µs (started: 2021-09-20 07:59:09 +00:00)


In [5]:
# Number of target classes (the notebook works on CIFAR-100).
# NOTE(review): not referenced by any other cell visible in this notebook.
num_classes = 100
# Mini-batch size shared by the train and test DataLoaders.
batch_size = 64

time: 914 µs (started: 2021-09-20 07:59:09 +00:00)


In [6]:
# Use the GPU whenever PyTorch can see one, otherwise stay on the CPU.
# `device` is read as a global by the training / evaluation / pruning helpers.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

cuda
time: 82.5 ms (started: 2021-09-20 07:59:09 +00:00)


# Important Functions

In [7]:
def get_total_parameters(model):
    """Return the number of trainable scalar parameters in ``model``.

    Only tensors whose ``requires_grad`` flag is set are counted; frozen
    parameters are skipped.
    """
    total = 0
    for tensor in model.parameters():
        if tensor.requires_grad:
            total += tensor.numel()
    return total


def get_pruned_parameters(pruned_model):
    """Return how many parameter entries of ``pruned_model`` are non-zero.

    Every tensor yielded by ``pruned_model.parameters()`` is included,
    regardless of its ``requires_grad`` flag.

    Args:
        pruned_model: a ``torch.nn.Module`` (typically after pruning has been
            made permanent, so pruned weights are stored as exact zeros).

    Returns:
        int: total count of non-zero scalar entries (0 for a parameter-free model).
    """
    # count_nonzero gives the count directly instead of materialising the
    # index tensor that torch.nonzero would build just to read its length.
    return sum(int(torch.count_nonzero(p)) for p in pruned_model.parameters())

time: 5.14 ms (started: 2021-09-20 07:59:09 +00:00)


In [8]:
def measure_module_sparsity(module, weight=True, bias=False, use_mask=False):
    """Count zero entries in a module's weights and/or biases.

    Args:
        module: the ``torch.nn.Module`` to inspect (its sub-modules are
            included, because the named_* iterators recurse).
        weight: include tensors whose name contains the weight key.
        bias: include tensors whose name contains the bias key.
        use_mask: if True, inspect the pruning-mask buffers (names containing
            ``"weight_mask"`` / ``"bias_mask"``); otherwise inspect the
            parameters (names containing ``"weight"`` / ``"bias"``).

    Returns:
        tuple: ``(num_zeros, num_elements, sparsity)`` where ``sparsity`` is
        ``num_zeros / num_elements``, or ``0.0`` when nothing was selected
        (for example ``use_mask=True`` on a module that has no masks).

    Note:
        Matching is by substring, so while a pruning re-parametrisation is
        active the parameter named ``weight_orig`` is what gets counted in the
        ``use_mask=False`` mode.
    """
    if use_mask:
        named_tensors = module.named_buffers()
        weight_key, bias_key = "weight_mask", "bias_mask"
    else:
        named_tensors = module.named_parameters()
        weight_key, bias_key = "weight", "bias"

    num_zeros = 0
    num_elements = 0
    for tensor_name, tensor in named_tensors:
        # Two independent checks (not elif): a name that matches both keys is
        # counted for each enabled flag.
        if weight and weight_key in tensor_name:
            num_zeros += torch.sum(tensor == 0).item()
            num_elements += tensor.nelement()
        if bias and bias_key in tensor_name:
            num_zeros += torch.sum(tensor == 0).item()
            num_elements += tensor.nelement()

    # Guard the ratio: with no selected tensors the division would raise.
    sparsity = num_zeros / num_elements if num_elements else 0.0

    return num_zeros, num_elements, sparsity

time: 19.2 ms (started: 2021-09-20 07:59:10 +00:00)


In [9]:
def measure_global_sparsity(model,
                            weight=True,
                            bias=False,
                            conv2d_use_mask=False,
                            linear_use_mask=False):
    """Aggregate zero counts over every Conv2d and Linear layer of ``model``.

    Args:
        model: the ``torch.nn.Module`` to inspect.
        weight: forwarded to ``measure_module_sparsity``.
        bias: forwarded to ``measure_module_sparsity``.
        conv2d_use_mask: ``use_mask`` value applied to ``Conv2d`` layers.
        linear_use_mask: ``use_mask`` value applied to ``Linear`` layers.

    Returns:
        tuple: ``(num_zeros, num_elements, sparsity)`` summed over the selected
        layers; ``sparsity`` is ``0.0`` when no element was counted.
    """
    num_zeros = 0
    num_elements = 0

    for module in model.modules():
        # Pick the mask setting by layer type; every other module type is skipped.
        if isinstance(module, torch.nn.Conv2d):
            use_mask = conv2d_use_mask
        elif isinstance(module, torch.nn.Linear):
            use_mask = linear_use_mask
        else:
            continue

        module_num_zeros, module_num_elements, _ = measure_module_sparsity(
            module, weight=weight, bias=bias, use_mask=use_mask)
        num_zeros += module_num_zeros
        num_elements += module_num_elements

    # A model without Conv2d/Linear layers would otherwise divide by zero.
    sparsity = num_zeros / num_elements if num_elements else 0.0

    return num_zeros, num_elements, sparsity

time: 14.6 ms (started: 2021-09-20 07:59:10 +00:00)


In [10]:
def remove_parameters(model):
    """Make pruning permanent on every Conv2d / Linear layer of ``model``.

    For each such layer, ``prune.remove`` is attempted on ``"weight"`` and
    ``"bias"``. Layers or parameters that were never pruned are left untouched.

    Args:
        model: the ``torch.nn.Module`` to clean up (modified in place).

    Returns:
        The same ``model`` object, for call chaining.
    """
    for module in model.modules():
        if not isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
            continue
        for param_name in ("weight", "bias"):
            try:
                prune.remove(module, param_name)
            except ValueError:
                # prune.remove raises ValueError when this parameter has no
                # pruning attached; that is expected here and safe to skip.
                # Any other exception is a real problem and propagates.
                pass
    return model

time: 10.1 ms (started: 2021-09-20 07:59:10 +00:00)


In [11]:
def fit(model, dataset, optimizer, criterion):
    """Train ``model`` for one pass over ``dataset``.

    ``dataset`` is iterated as ``(inputs, labels)`` batches; each batch is
    moved to the global ``device``, and one optimizer step is taken per batch.
    Nothing is returned.
    """
    model.train()
    for batch_inputs, batch_labels in dataset:
        batch_inputs = batch_inputs.to(device)
        batch_labels = batch_labels.to(device)

        # Clear gradients left over from the previous step.
        optimizer.zero_grad()

        # Forward pass and loss.
        batch_loss = criterion(model(batch_inputs), batch_labels)

        # Backward pass and parameter update.
        batch_loss.backward()
        optimizer.step()

time: 12 ms (started: 2021-09-20 07:59:10 +00:00)


In [12]:
def evaluate(model, dataloader):
    """Compute accuracy and average loss of ``model`` over ``dataloader``.

    The model is switched to eval mode and the loop runs under
    ``torch.no_grad()`` so no autograd graph is built while scoring.

    Relies on two notebook globals: ``device`` (where batches are moved) and
    ``criterion`` (the loss function, defined in a later cell).

    Args:
        model: the network to score; expected to already live on ``device``.
        dataloader: yields ``(inputs, targets)`` batches and exposes
            ``.dataset`` with a length.

    Returns:
        tuple: ``(accuracy_percent, mean_loss)`` where ``accuracy_percent`` is
        ``100 * correct / len(dataloader.dataset)`` and ``mean_loss`` is the
        unweighted mean of the per-batch ``criterion`` values.
    """
    model.eval()
    correct = 0
    batch_losses = []
    with torch.no_grad():
        for data, targets in dataloader:
            data = data.to(device)
            targets = targets.to(device)
            out = model(data)
            # Per-batch loss value.
            batch_losses.append(criterion(out, targets).item())
            # Predicted class = index of the largest score along dim 1.
            _, preds = torch.max(out, 1)
            correct += torch.sum(preds == targets).item()

    return 100 * correct / len(dataloader.dataset), np.mean(np.array(batch_losses))

time: 9.5 ms (started: 2021-09-20 07:59:10 +00:00)


In [13]:
def unstructured_global_pruning(model, amount):
    """Return a copy of ``model`` with global L1 unstructured pruning applied.

    The input model is deep-copied and moved to the global ``device``. The
    weights of every Conv2d / Linear layer except the last one found are
    pooled, ``amount`` of them is pruned with ``prune.L1Unstructured``, and the
    pruning is then made permanent via ``remove_parameters``.
    """
    pruning_model = deepcopy(model).to(device)

    prunable_types = (torch.nn.Conv2d, torch.nn.Linear)
    candidates = [
        (layer, "weight")
        for layer in pruning_model.modules()
        if isinstance(layer, prunable_types)
    ]
    # The final Conv2d/Linear layer is excluded from pruning.
    candidates = candidates[:-1]

    prune.global_unstructured(
        parameters=candidates,
        pruning_method=prune.L1Unstructured,
        amount=amount,
    )

    return remove_parameters(model=pruning_model)

time: 10.8 ms (started: 2021-09-20 07:59:10 +00:00)


In [14]:
def structured_pruning(model, sparsity_layer, example_input_shape=(1, 3, 64, 64)):
    """Return a structurally pruned copy of ``model`` using torch_pruning.

    The model is deep-copied and moved to the global ``device``. Every Conv2d /
    Linear layer except the last one found is pruned in order, using an L1
    strategy with a per-layer ratio.

    Args:
        model: the network to prune; it is not modified.
        sparsity_layer: sequence of pruning ratios, indexed in the order the
            Conv2d / Linear layers are returned by ``modules()``. It must have
            at least as many entries as layers being pruned; extra entries are
            ignored.
        example_input_shape: shape of the random tensor used to trace the
            dependency graph. Defaults to the 64x64 RGB input this notebook
            feeds the model.

    Returns:
        The pruned copy of the model.

    Raises:
        IndexError: if ``sparsity_layer`` is shorter than the list of layers
            being pruned.
    """
    prune_model = deepcopy(model).to(device)

    prunable_module_type = (nn.Conv2d, nn.Linear)
    prunable_modules = [m for m in prune_model.modules() if isinstance(m, prunable_module_type)]
    # The final Conv2d/Linear layer is left untouched.
    prunable_modules = prunable_modules[:-1]

    strategy = tp.strategy.L1Strategy()
    # The tracing input has to sit on the same device as the model; a CPU
    # tensor fed to a model that was just moved to the GPU fails in forward.
    example_inputs = torch.randn(*example_input_shape).to(device)
    DG = tp.DependencyGraph().build_dependency(prune_model, example_inputs=example_inputs)

    for layer_idx, layer_to_prune in enumerate(prunable_modules):
        # Only Conv2d and Linear layers are in the list, so two cases suffice.
        if isinstance(layer_to_prune, nn.Conv2d):
            prune_fn = tp.prune_conv
        else:
            prune_fn = tp.prune_linear

        pruning_idxs = strategy(layer_to_prune.weight, amount=sparsity_layer[layer_idx])
        plan = DG.get_pruning_plan(layer_to_prune, prune_fn, pruning_idxs)
        plan.exec()

    return prune_model

time: 24.9 ms (started: 2021-09-20 07:59:10 +00:00)


# Import Data

In [15]:
# Training pipeline: resize to 64, random horizontal flip, random rotation of up
# to 15 degrees, tensor conversion, then per-channel normalisation.
# NOTE(review): the Normalize constants are presumably dataset mean/std values — confirm their origin.
train_transform = transforms.Compose([transforms.Resize(64),
                                   transforms.RandomHorizontalFlip(),
                                   transforms.RandomRotation(15),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.48,0.4593,0.4155),(0.2774,0.2794,0.2794))])

# CIFAR-100 training split, downloaded into ./CIFAR100 when not already present.
train_set = datasets.CIFAR100(root = "CIFAR100", train = True, download = True, transform = train_transform)

Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to CIFAR100/cifar-100-python.tar.gz


  0%|          | 0/169001437 [00:00<?, ?it/s]

Extracting CIFAR100/cifar-100-python.tar.gz to CIFAR100
time: 6.81 s (started: 2021-09-20 07:59:10 +00:00)


In [16]:
# Test pipeline: same resize and normalisation constants as the training
# transform, but without any random augmentation.
test_transform = transforms.Compose([transforms.Resize(64),
                                  transforms.ToTensor(),
                                  transforms.Normalize((0.48,0.4593,0.4155),(0.2774,0.2794,0.2794))])

# CIFAR-100 test split, read from the same ./CIFAR100 root directory.
test_set = datasets.CIFAR100(root = "CIFAR100", train = False, download = True, transform = test_transform)

Files already downloaded and verified
time: 902 ms (started: 2021-09-20 07:59:17 +00:00)


In [17]:
# Training batches are reshuffled every epoch; test batches keep dataset order.
train_dataloader = DataLoader(train_set, batch_size = batch_size, shuffle = True)
test_dataloader = DataLoader(test_set, batch_size = batch_size)

time: 2.4 ms (started: 2021-09-20 07:59:17 +00:00)


In [18]:
# Loss used for fine-tuning. evaluate() reads this name as a global, so this
# cell must run before evaluate() is called.
criterion = nn.CrossEntropyLoss()

time: 4.13 ms (started: 2021-09-20 07:59:17 +00:00)


# Import Model to be Pruned

In [None]:
# Load the full pickled model produced by the previous pruning round.
# - The path points into a mounted Google Drive; Drive must be mounted first.
# - torch.load unpickles the file, which can execute arbitrary code: only load
#   checkpoints from a trusted source.
# - map_location places the tensors on the device selected above, so a
#   checkpoint saved from a GPU session also loads on a CPU-only runtime.
model = torch.load('/content/drive/MyDrive/Sai/PrunedModel_V2.pt', map_location=device)

time: 9.89 s (started: 2021-09-20 03:34:01 +00:00)


In [None]:
# Show the layer structure of the loaded network before pruning.
print(model)

VGG(
  (features): Sequential(
    (0): Conv2d(3, 63, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(63, 51, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(51, 101, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(101, 95, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(95, 176, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(176, 159, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(159, 166, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, 

# Unstructured Pruning

In [None]:
# Globally prune 30% of the pooled Conv2d/Linear weights by L1 magnitude
# (the last such layer is excluded inside the helper). `model` is left unchanged.
ustp_model = unstructured_global_pruning(model, 0.3)

time: 422 ms (started: 2021-09-20 03:34:11 +00:00)


In [None]:
# Test-set accuracy (%) and mean loss right after unstructured pruning, before any fine-tuning.
acc, loss = evaluate(ustp_model, test_dataloader)
print('Accuracy : ',acc , '\tLoss : ', loss)

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


Accuracy :  72.95 	Loss :  1.0824515999882085
time: 11.9 s (started: 2021-09-20 03:34:11 +00:00)


In [None]:
# Fraction of exactly-zero weights over all Conv2d/Linear layers. Parameters are
# inspected directly (no masks), since pruning was already made permanent.
num_zeros, num_elements, sparsity = measure_global_sparsity(
            ustp_model,
            weight=True,
            bias=False,
            conv2d_use_mask=False,
            linear_use_mask=False)

print("Global Sparsity:")
print("{:.2f}".format(sparsity))

Global Sparsity:
0.29
time: 7.55 ms (started: 2021-09-20 03:34:23 +00:00)


In [None]:
# Collect the weight sparsity of each Conv2d/Linear layer, in module order.
# The resulting list is later used as the per-layer ratio for structured pruning.
sparsity_layer = []
i = 0  # running index of Conv2d/Linear layers, used only for the printout
for module_name, module in ustp_model.named_modules():
  if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear):
    num_zeros, num_elements, sparsity = measure_module_sparsity(module, weight=True, bias=False, use_mask=False)
    print('Layer', i , num_zeros, num_elements, sparsity)
    sparsity_layer.append(sparsity)
    i+= 1

Layer 0 26 1701 0.015285126396237508
Layer 1 2911 28917 0.10066742746481308
Layer 2 4070 46359 0.08779309303479367
Layer 3 9182 86355 0.10632852758960107
Layer 4 20062 150480 0.13332004253056884
Layer 5 41999 251856 0.16675798869195096
Layer 6 35759 237546 0.1505350542631743
Layer 7 80301 445212 0.18036575833535484
Layer 8 152107 667818 0.2277671461386186
Layer 9 120759 549045 0.21994372046007157
Layer 10 130232 535815 0.24305403917396862
Layer 11 137134 509571 0.2691165706054701
Layer 12 124564 440370 0.2828621386561301
Layer 13 606203 895230 0.6771477720809177
Layer 14 19398 102747 0.18879383339659553
Layer 15 0 118100 0.0
time: 18.1 ms (started: 2021-09-20 03:34:23 +00:00)


In [None]:
# Derive the per-layer structured-pruning ratios from the measured unstructured
# sparsity of `ustp_model`.
#
# The sparsities are re-measured here instead of rewriting `sparsity_layer` from
# its own previous value, so re-running this cell gives the same list every time
# rather than subtracting another 0.01 on each execution.
measured_sparsity = [
    measure_module_sparsity(module, weight=True, bias=False, use_mask=False)[2]
    for module in ustp_model.modules()
    if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear))
]

# Round to two decimals, subtract 0.01 and clamp at 0.
# NOTE(review): the 0.01 margin presumably keeps the structured step slightly
# more conservative than the measured sparsity — confirm intent.
sparsity_layer = [max(round(round(s, 2) - 0.01, 2), 0) for s in measured_sparsity]
sparsity_layer

[0.01,
 0.09,
 0.08,
 0.1,
 0.12,
 0.16,
 0.14,
 0.17,
 0.22,
 0.21,
 0.23,
 0.26,
 0.27,
 0.67,
 0.18,
 0]

time: 11.8 ms (started: 2021-09-20 03:34:23 +00:00)


# Structured Pruning

In [None]:
# Remove whole channels/neurons layer by layer, using the per-layer ratios above.
# Works on a copy; `ustp_model` is left unchanged.
stp_model = structured_pruning(ustp_model, sparsity_layer)

time: 306 ms (started: 2021-09-20 03:34:23 +00:00)


In [None]:
# Parameter count of the structurally pruned model.
tp.utils.count_params(stp_model)

2921463

time: 7.72 ms (started: 2021-09-20 03:34:24 +00:00)


In [None]:
# Show the layer shapes after structured pruning (channel counts shrink).
print(stp_model)

VGG(
  (features): Sequential(
    (0): Conv2d(3, 63, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(63, 47, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(47, 93, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(93, 86, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(86, 155, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(155, 134, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(134, 143, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ce

In [None]:
# Test-set accuracy (%) and mean loss right after structured pruning, before fine-tuning.
stp_model = stp_model.to(device)
acc, loss = evaluate(stp_model, test_dataloader)
print('Accuracy : ',acc , '\tLoss : ', loss)

Accuracy :  27.13 	Loss :  3.699010852036203
time: 10.7 s (started: 2021-09-20 03:34:24 +00:00)


# Fine Tuning

In [None]:
# Fine-tuning stage 1 (logged as epochs 1-10): SGD with lr=0.001, momentum=0.9.
# NOTE: this cell trains `stp_model` in place, so re-running it continues
# training rather than reproducing the same result.
sgd_optimizer = torch.optim.SGD(stp_model.parameters(), lr = 0.001, momentum = 0.9)
epochs = 10
stp_model = stp_model.to(device)

for epoch in range(epochs):
  # One optimisation pass over the training set.
  fit(stp_model, train_dataloader, sgd_optimizer, criterion)

  # Metrics on the training set (samples still go through the random train-time transforms).
  trn_acc, trn_lss = evaluate(stp_model, train_dataloader)

  # Metrics on the test set.
  tst_acc, tst_lss = evaluate(stp_model, test_dataloader)

  print(f'Epoch:{epoch+1:2.0f}\t Train_Loss:{trn_lss:2.4f}\t Train_Acc:{trn_acc:2.4f}\t Test_Loss:{tst_lss:2.4f}\t Test_Acc:{tst_acc:2.4f}')

Epoch: 1	 Train_Loss:1.1268	 Train_Acc:70.8960	 Test_Loss:1.4153	 Test_Acc:61.8900
Epoch: 2	 Train_Loss:0.9984	 Train_Acc:75.0400	 Test_Loss:1.3162	 Test_Acc:64.9500
Epoch: 3	 Train_Loss:0.9259	 Train_Acc:75.7380	 Test_Loss:1.3045	 Test_Acc:64.9200
Epoch: 4	 Train_Loss:0.8981	 Train_Acc:76.2620	 Test_Loss:1.2949	 Test_Acc:66.0100
Epoch: 5	 Train_Loss:0.8581	 Train_Acc:77.6020	 Test_Loss:1.2821	 Test_Acc:66.2300
Epoch: 6	 Train_Loss:0.8341	 Train_Acc:78.3780	 Test_Loss:1.2872	 Test_Acc:66.5500
Epoch: 7	 Train_Loss:0.7925	 Train_Acc:78.7560	 Test_Loss:1.2962	 Test_Acc:66.0800
Epoch: 8	 Train_Loss:0.7616	 Train_Acc:80.2640	 Test_Loss:1.2416	 Test_Acc:67.8000
Epoch: 9	 Train_Loss:0.7793	 Train_Acc:79.5740	 Test_Loss:1.2917	 Test_Acc:66.7600
Epoch:10	 Train_Loss:0.7283	 Train_Acc:81.3200	 Test_Loss:1.2439	 Test_Acc:67.6900
time: 28min 46s (started: 2021-09-20 03:34:46 +00:00)


In [None]:
# Fine-tuning stage 2 (logged as epochs 11-20): lr stays at 0.001. A new SGD
# object is created, so the momentum state starts fresh.
# NOTE: trains `stp_model` in place; the "+11" in the log label is a hard-coded offset.
sgd_optimizer = torch.optim.SGD(stp_model.parameters(), lr = 0.001, momentum = 0.9)
epochs = 10
stp_model = stp_model.to(device)

for epoch in range(epochs):
  # One optimisation pass over the training set.
  fit(stp_model, train_dataloader, sgd_optimizer, criterion)

  # Metrics on the training set (samples still go through the random train-time transforms).
  trn_acc, trn_lss = evaluate(stp_model, train_dataloader)

  # Metrics on the test set.
  tst_acc, tst_lss = evaluate(stp_model, test_dataloader)

  print(f'Epoch:{epoch+11:2.0f}\t Train_Loss:{trn_lss:2.4f}\t Train_Acc:{trn_acc:2.4f}\t Test_Loss:{tst_lss:2.4f}\t Test_Acc:{tst_acc:2.4f}')

Epoch:11	 Train_Loss:0.6775	 Train_Acc:81.9980	 Test_Loss:1.2298	 Test_Acc:68.1400
Epoch:12	 Train_Loss:0.6716	 Train_Acc:82.2080	 Test_Loss:1.2375	 Test_Acc:68.0900
Epoch:13	 Train_Loss:0.6982	 Train_Acc:82.1140	 Test_Loss:1.2631	 Test_Acc:67.9400
Epoch:14	 Train_Loss:0.6540	 Train_Acc:83.0020	 Test_Loss:1.2878	 Test_Acc:67.4400
Epoch:15	 Train_Loss:0.6475	 Train_Acc:83.0520	 Test_Loss:1.2876	 Test_Acc:67.9400
Epoch:16	 Train_Loss:0.6360	 Train_Acc:83.8620	 Test_Loss:1.2658	 Test_Acc:68.5400
Epoch:17	 Train_Loss:0.5882	 Train_Acc:85.1840	 Test_Loss:1.2471	 Test_Acc:69.2500
Epoch:18	 Train_Loss:0.5807	 Train_Acc:84.9120	 Test_Loss:1.2453	 Test_Acc:68.8000
Epoch:19	 Train_Loss:0.5917	 Train_Acc:85.0360	 Test_Loss:1.3006	 Test_Acc:68.6700
Epoch:20	 Train_Loss:0.5530	 Train_Acc:85.8000	 Test_Loss:1.3027	 Test_Acc:68.2400
time: 28min 47s (started: 2021-09-20 04:04:59 +00:00)


In [None]:
# Fine-tuning stage 3 (logged as epochs 21-30): learning rate lowered to 0.0001,
# again with a freshly created SGD optimizer.
# NOTE: trains `stp_model` in place; the "+21" in the log label is a hard-coded offset.
sgd_optimizer = torch.optim.SGD(stp_model.parameters(), lr = 0.0001, momentum = 0.9)
epochs = 10
stp_model = stp_model.to(device)

for epoch in range(epochs):
  # One optimisation pass over the training set.
  fit(stp_model, train_dataloader, sgd_optimizer, criterion)

  # Metrics on the training set (samples still go through the random train-time transforms).
  trn_acc, trn_lss = evaluate(stp_model, train_dataloader)

  # Metrics on the test set.
  tst_acc, tst_lss = evaluate(stp_model, test_dataloader)

  print(f'Epoch:{epoch+21:2.0f}\t Train_Loss:{trn_lss:2.4f}\t Train_Acc:{trn_acc:2.4f}\t Test_Loss:{tst_lss:2.4f}\t Test_Acc:{tst_acc:2.4f}')

Epoch:21	 Train_Loss:0.4130	 Train_Acc:89.6100	 Test_Loss:1.1706	 Test_Acc:71.1800
Epoch:22	 Train_Loss:0.3907	 Train_Acc:90.0520	 Test_Loss:1.1719	 Test_Acc:71.2400
Epoch:23	 Train_Loss:0.3789	 Train_Acc:90.4460	 Test_Loss:1.1699	 Test_Acc:71.8500
Epoch:24	 Train_Loss:0.3711	 Train_Acc:90.4500	 Test_Loss:1.1764	 Test_Acc:71.8200
Epoch:25	 Train_Loss:0.3660	 Train_Acc:90.5540	 Test_Loss:1.1789	 Test_Acc:71.6600
Epoch:26	 Train_Loss:0.3616	 Train_Acc:90.7740	 Test_Loss:1.1824	 Test_Acc:71.7700
Epoch:27	 Train_Loss:0.3496	 Train_Acc:91.1840	 Test_Loss:1.1786	 Test_Acc:71.8300
Epoch:28	 Train_Loss:0.3436	 Train_Acc:91.0280	 Test_Loss:1.1836	 Test_Acc:71.7100
Epoch:29	 Train_Loss:0.3431	 Train_Acc:91.1860	 Test_Loss:1.1813	 Test_Acc:71.7800
Epoch:30	 Train_Loss:0.3356	 Train_Acc:91.5020	 Test_Loss:1.1885	 Test_Acc:71.8200
time: 28min 51s (started: 2021-09-20 04:34:55 +00:00)


In [22]:
# Fine-tuning stage 4 (logged as epochs 31-35): learning rate lowered to 0.00001
# for five more epochs, with a freshly created SGD optimizer.
# NOTE: trains `stp_model` in place; the "+31" in the log label is a hard-coded offset.
sgd_optimizer = torch.optim.SGD(stp_model.parameters(), lr = 0.00001, momentum = 0.9)
epochs = 5
stp_model = stp_model.to(device)

for epoch in range(epochs):
  # One optimisation pass over the training set.
  fit(stp_model, train_dataloader, sgd_optimizer, criterion)

  # Metrics on the training set (samples still go through the random train-time transforms).
  trn_acc, trn_lss = evaluate(stp_model, train_dataloader)

  # Metrics on the test set.
  tst_acc, tst_lss = evaluate(stp_model, test_dataloader)

  print(f'Epoch:{epoch+31:2.0f}\t Train_Loss:{trn_lss:2.4f}\t Train_Acc:{trn_acc:2.4f}\t Test_Loss:{tst_lss:2.4f}\t Test_Acc:{tst_acc:2.4f}')

Epoch:31	 Train_Loss:0.3311	 Train_Acc:91.7280	 Test_Loss:1.1821	 Test_Acc:71.9600
Epoch:32	 Train_Loss:0.3286	 Train_Acc:91.6280	 Test_Loss:1.1844	 Test_Acc:71.9800
Epoch:33	 Train_Loss:0.3265	 Train_Acc:91.7120	 Test_Loss:1.1825	 Test_Acc:72.0900
Epoch:34	 Train_Loss:0.3252	 Train_Acc:91.7360	 Test_Loss:1.1808	 Test_Acc:72.1400
Epoch:35	 Train_Loss:0.3250	 Train_Acc:91.6940	 Test_Loss:1.1816	 Test_Acc:72.1400
time: 14min 26s (started: 2021-09-20 08:00:43 +00:00)


# Testing

In [23]:
# Final test-set accuracy (%) and mean loss of the fine-tuned pruned model.
acc, loss = evaluate(stp_model, test_dataloader)
print('Accuracy : ',acc , '\tLoss : ', loss)

Accuracy :  72.14 	Loss :  1.181633984017524
time: 11.1 s (started: 2021-09-20 08:16:09 +00:00)


# Save Pruned Model

In [24]:
# Save the whole model object (not just a state_dict) to the current working directory.
# NOTE(review): the input checkpoint was read from Google Drive; this file is
# written locally — copy it to persistent storage if it should outlive the session.
torch.save(stp_model, 'PrunedModel_V3.pt')

time: 37.7 ms (started: 2021-09-20 08:16:21 +00:00)
