# Model compression

Deep neural networks are the state-of-the-art models for many tasks, especially on unstructured data such as images and text. One issue with deep learning models is that they are highly overparameterized, which makes them difficult to deploy on devices with limited memory. Compressing a neural network without sacrificing its performance is an active research topic in machine learning. Parameter pruning, quantization, and knowledge distillation are among the approaches widely used today to make neural networks compact. In this notebook, we demonstrate a parameter pruning technique on a fine-tuned ResNet-50 model.

## Model Pruning

Pruning removes redundant parameters, i.e. parameters to which the model's performance is not sensitive. One way to reduce the number of parameters during training is L1 regularisation, which pushes many weights towards zero. However, L1 regularisation gives little control over how many parameters actually end up pruned. Early work in deep learning explored various pruning techniques, for example Yann LeCun et al. in the paper [**Optimal Brain Damage**](https://proceedings.neurips.cc/paper/1989/hash/6c9882bbac1c7093bd25041881277658-Abstract.html), which removes connections based on the Hessian of the loss function. Today there are many different methods for pruning neural networks: approaches that find sparse subnetworks as part of training, such as the [**lottery ticket hypothesis**](https://arxiv.org/abs/1803.03635), and more hardware-efficient pruning methods such as [**N:M Fine-grained Structured Sparse Neural Networks**](https://arxiv.org/abs/2102.04010). In this notebook, we demonstrate a much simpler pruning method called **magnitude weight pruning.**

### Magnitude Weight pruning

Ref : https://arxiv.org/pdf/1506.02626.pdf

In magnitude weight pruning, weights are removed from a network after training until a desired sparsity level is reached. The underlying assumption is that weights with larger absolute values are more important. To reach X% sparsity, we set to zero the X% of weights with the smallest absolute values. This can be done per layer (local pruning) or across the entire model (global pruning). With global pruning, one must make sure that no single layer has all of its weights pruned, since that would cut off the flow of information through the network. To recover performance, the pruned model is usually fine-tuned starting from the remaining weights; repeating this prune-then-fine-tune cycle is called "**iterative magnitude pruning**". There are many other ways to prune weights, such as pattern-based, vector-level, kernel-level, and channel-level pruning, but they are beyond the scope of this notebook. Unstructured magnitude pruning produces irregular sparsity patterns that standard dense hardware cannot easily exploit, so to make pruning effective and hardware efficient one may need one of these structured techniques rather than simple magnitude-based pruning.
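A minimal NumPy sketch of one-shot magnitude pruning that contrasts the local and global variants described above. The two "layers" are toy arrays chosen for illustration, and `local_magnitude_prune`, `global_magnitude_prune`, `_prune_smallest` and `sparsity` are hypothetical helper names, not part of any library. Exactly `round(amount * n)` weights are zeroed; ties are broken by sort order.

```python
import numpy as np

def _prune_smallest(flat, k):
    """Return a 0/1 mask over `flat` that zeroes its k smallest-magnitude entries."""
    mask = np.ones_like(flat)
    if k > 0:
        idx = np.argsort(np.abs(flat), kind="stable")[:k]  # indices of the k smallest |w|
        mask[idx] = 0.0
    return mask

def local_magnitude_prune(layers, amount):
    """Prune the same fraction `amount` of weights inside every layer."""
    pruned = []
    for w in layers:
        k = int(round(amount * w.size))
        pruned.append(w * _prune_smallest(w.ravel(), k).reshape(w.shape))
    return pruned

def global_magnitude_prune(layers, amount):
    """Prune the fraction `amount` of weights with the smallest |w| across all layers."""
    flat = np.concatenate([w.ravel() for w in layers])
    k = int(round(amount * flat.size))
    mask = _prune_smallest(flat, k)
    pruned, start = [], 0
    for w in layers:
        # Slice this layer's part of the global mask back out
        pruned.append(w * mask[start:start + w.size].reshape(w.shape))
        start += w.size
    return pruned

def sparsity(w):
    """Fraction of entries that are exactly zero."""
    return float(np.mean(w == 0))

# Toy "network": layer_a has uniformly small weights, layer_b uniformly large ones
layer_a = np.array([[0.01, -0.02], [0.03, -0.04]])
layer_b = np.array([[0.5, -0.6], [0.7, -0.8]])

local = local_magnitude_prune([layer_a, layer_b], amount=0.5)
glob = global_magnitude_prune([layer_a, layer_b], amount=0.5)
print([sparsity(w) for w in local])  # [0.5, 0.5]
print([sparsity(w) for w in glob])   # [1.0, 0.0] -> layer_a is removed entirely
```

The global result illustrates the failure mode mentioned above: a layer whose weights are all small in magnitude can be wiped out completely, so after global pruning it is worth checking (or capping) the per-layer sparsity.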





# Problem statement

This is a multiclass image classification problem. The data contains images from 6 categories: 'buildings', 'forest', 'glacier', 'mountain', 'sea', 'street'. The aim is to develop a machine learning model that correctly classifies an input image into one of these categories.

In this notebook, we try to find a model with a fixed percentage of sparsity without losing much of the expressivity of the original dense model.

## Modeling Approach : Transfer learning

Ref : https://www.kaggle.com/code/darraghcaffrey/transfer-learning-with-resnet50-91-5-test-acc <br>
      https://arxiv.org/abs/1512.03385

We use an ImageNet-pretrained ResNet-50 model and fine-tune it on the training set.

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from collections import OrderedDict
import cv2
import os
from PIL import Image


import torch
from torch import optim
from torch.utils.data import random_split, DataLoader
from torch import nn
import torch.nn.functional as F
from torchvision.utils import make_grid
from torchvision import transforms, models, datasets
import torch.nn.utils.prune as prune

In [None]:
train_dir = "../input/intel-image-classification/seg_train/seg_train/"
test_dir = "../input/intel-image-classification/seg_test/seg_test/"
pred_dir = "../input/intel-image-classification/seg_pred/seg_pred/"


pred_files = [os.path.join(pred_dir, f) for f in os.listdir(pred_dir)]

In [None]:
cat_counts = {}
for cat in os.listdir(train_dir):
    counts = len(os.listdir(os.path.join(train_dir, cat)))
    cat_counts[cat] = counts
print(cat_counts)

In [None]:
# ImageNet channel means and standard deviations used to train the pretrained ResNet
mean = [0.485, 0.456, 0.406] 
std = [0.229, 0.224, 0.225]

In [None]:
# Torchvision train and test transforms 
train_transforms = transforms.Compose([transforms.Resize((150, 150)), # Resize all images 
                                       transforms.RandomResizedCrop(150),# Crop
                                       transforms.RandomRotation(30), # Rotate 
                                       transforms.RandomHorizontalFlip(), # Flip
                                       transforms.ToTensor(), # Convert
                                       transforms.Normalize(torch.Tensor(mean), torch.Tensor(std)) # Normalize
                                       ])



test_transforms = transforms.Compose([transforms.Resize((150, 150)),
                                     transforms.CenterCrop(150),
                                     transforms.ToTensor(),
                                     transforms.Normalize(torch.Tensor(mean),torch.Tensor(std))
                                     ])

# Temporary ImageFolder dataset, split below into training and validation sets
tmp_data = datasets.ImageFolder(train_dir, transform=train_transforms)
# len(tmp_data): 14034

# Randomly split into 10,000 training images and the rest for validation;
# the fixed seed makes the split reproducible.
# Note: both subsets share tmp_data's transform, so the validation images also
# receive the random training augmentations.
train_data, val_data = random_split(tmp_data, [10000, len(tmp_data) - 10000], generator=torch.Generator().manual_seed(42))
# Test set with the (deterministic) test transforms
test_data = datasets.ImageFolder(test_dir, transform=test_transforms)


# PyTorch dataloaders: batch size 64, shuffling only the training set
train_loader = torch.utils.data.DataLoader(train_data, batch_size=64, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_data, batch_size=64)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=64)

In [None]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device


In [None]:
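# Load ResNet-50 with ImageNet-pretrained weights
# (torchvision >= 0.13 deprecates `pretrained=True` in favour of
#  `weights=models.ResNet50_Weights.IMAGENET1K_V1`)
model = models.resnet50(pretrained=True)
# NOTE: this loop is meant to freeze the backbone, but the attribute is misspelled
# (`required_grad` instead of `requires_grad`), so it only sets an unused attribute:
# nothing is frozen and the whole network is fine-tuned. Correcting the spelling on
# its own would also freeze model.fc and leave no trainable layer on the forward path.
for param in model.parameters():
    param.required_grad = False
# Input dimension of the final fully connected layer (2048 for ResNet-50)
features = model.fc.in_features


# Custom classifier head mapping the 2048 pooled features to the 6 classes
classifier = nn.Sequential(OrderedDict([('fc1', nn.Linear(features, 512)),
                                        ('relu', nn.ReLU()),
                                        ('drop', nn.Dropout(0.05)),
                                        ('fc2', nn.Linear(512, 6)),
                                        ]))

# Alternative: add ('output', nn.LogSoftmax(dim=1)) and train with nn.NLLLoss
# NOTE: torchvision's ResNet.forward() calls self.fc, not self.classifier, so this
# head is registered (and pruned later) but never used in the forward pass. The
# model still outputs the 1000 ImageNet logits, and the loss is computed over them
# with targets 0-5. Replacing `model.fc` with the head would route the output
# through it; any code below that refers to `model.fc` or `model.classifier` by
# name would then need to be checked.
model.classifier = classifier
# Move the model to the GPU (if available)
model.to(device)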

In [None]:
# Cross-entropy loss on the raw logits
criterion = nn.CrossEntropyLoss()
# Plain SGD over all model parameters
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
# Multiply the learning rate by 0.05 after epochs 10 and 15
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[10, 15], gamma=0.05)
# scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.1, steps_per_epoch=10, epochs=20)

In [None]:
epochs = 20

# Per-epoch average losses, kept for plotting
avg_epoch_tr_loss = []
avg_epoch_val_loss = []
# np.Inf was removed in NumPy 2.0; np.inf works in all versions
val_loss_min = np.inf

for epoch in range(epochs):
    # ---- Training ----
    model.train()
    # Reset every epoch so that the printed values are per-epoch averages
    tr_losses, tr_accuracy = [], []
    for data, label in train_loader:
        # Move the batch to the GPU (if available)
        data, label = data.to(device), label.to(device)
        optimizer.zero_grad()           # clear gradients from the previous step
        logit = model(data)             # forward pass
        loss = criterion(logit, label)  # cross-entropy loss
        loss.backward()                 # backpropagate the gradients
        optimizer.step()                # step in the opposite direction of the gradient
        # .item() returns a Python float detached from the computation graph
        tr_losses.append(loss.item())
        tr_accuracy.append(label.eq(logit.argmax(dim=1)).float().mean().item())

    avg_tr_loss = np.mean(tr_losses)
    avg_epoch_tr_loss.append(avg_tr_loss)
    print(f"\nEpoch No: {epoch + 1}, Training Loss: {avg_tr_loss:.2f}")
    print(f"Epoch No: {epoch + 1}, Training Accuracy: {np.mean(tr_accuracy):.2f}\n")

    # ---- Validation ----
    model.eval()
    val_losses, val_accuracy = [], []
    with torch.no_grad():  # no gradients are needed for evaluation
        for data, label in val_loader:
            data, label = data.to(device), label.to(device)
            logit = model(data)
            val_losses.append(criterion(logit, label).item())
            val_accuracy.append(label.eq(logit.argmax(dim=1)).float().mean().item())

    avg_val_loss = np.mean(val_losses)
    avg_epoch_val_loss.append(avg_val_loss)
    print(f"\nEpoch No: {epoch + 1}, Epoch Val Loss: {avg_val_loss:.2f}")
    print(f"Epoch No: {epoch + 1}, Epoch Val Accuracy: {np.mean(val_accuracy):.2f}\n")

    # Checkpoint the model whenever the validation loss improves
    if avg_val_loss <= val_loss_min:
        print("Val Loss Decreased... Saving model")
        torch.save(model.state_dict(), "./model_state.pt")
        val_loss_min = avg_val_loss

    # Step the scheduler once per epoch and print the updated learning rate
    scheduler.step()
    print("Learning Rate Set To: {:.10f}".format(optimizer.param_groups[0]["lr"]), "\n")

# Magnitude Weight Pruning in PyTorch

PyTorch provides built-in support for several pruning techniques in the `torch.nn.utils.prune` module, which we use here to implement model pruning.

Ref : https://pytorch.org/tutorials/intermediate/pruning_tutorial.html

First, we calculate the existing sparsity of the ResNet-50 model, i.e. the percentage of weights that are exactly zero. Bias parameters are not considered here or in the pruning below.
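Concretely, the sparsity reported below is a single global ratio over the weight tensors $W_l$ of all Conv2d and Linear layers, not an average of per-layer percentages:

$$
\text{sparsity} = 100 \cdot \frac{\sum_{l} \left|\{\, w \in W_l : w = 0 \,\}\right|}{\sum_{l} |W_l|}
$$

Large layers therefore dominate the result: a layer with 100 weights at 50% sparsity and a layer with 10,000 weights at 0% sparsity give a global sparsity of $100 \cdot 50 / 10100 \approx 0.50\%$, not the 25% that a plain average of the two percentages would suggest.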

## Global pruning

In [None]:
def calc_global_sparsity_resnet50(model):
    """Percentage of zero-valued weights over all Conv2d and Linear layers (biases ignored).

    For ResNet-50 this covers the stem conv1, every bottleneck conv1/conv2/conv3,
    every downsample convolution, model.fc and the classifier's fc1/fc2.
    """
    layers = [m for m in model.modules() if isinstance(m, (nn.Conv2d, nn.Linear))]

    num_zero_weights = sum((layer.weight == 0).sum().item() for layer in layers)
    total_weights = sum(layer.weight.nelement() for layer in layers)

    return 100. * num_zero_weights / total_weights


In [None]:
print("sparsity of the model before pruning: {:.2f}%".format(calc_global_sparsity_resnet50(model)))

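Now we globally prune 20% of the weights of the fine-tuned model: `prune.global_unstructured` with `prune.L1Unstructured` ranks the weights of all listed layers together by absolute value and zeroes the smallest 20%.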

In [None]:
# Prune the `weight` of every Conv2d and Linear layer (biases are left untouched)
parameters_to_prune = [
    (module, 'weight')
    for module in model.modules()
    if isinstance(module, (nn.Conv2d, nn.Linear))
]

# Apply global L1 (magnitude) pruning: zero the 20% of weights with the smallest
# absolute value, ranked jointly across all listed layers
prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.2,
)
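The pruning call above does not delete anything from the network. As described in the PyTorch pruning tutorial referenced above, each pruned module keeps its original tensor as the parameter `weight_orig`, gets a binary buffer `weight_mask`, and `weight` becomes a plain attribute recomputed as `weight_orig * weight_mask` by a forward pre-hook. The sketch below shows this on two small, randomly initialised `nn.Linear` layers (`fc_a` and `fc_b` are toy stand-ins for the ResNet layers, chosen only for illustration).

```python
import torch
from torch import nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
# Two toy layers standing in for the ResNet-50 layers: 12 + 6 = 18 weights in total
fc_a, fc_b = nn.Linear(4, 3), nn.Linear(3, 2)

prune.global_unstructured(
    [(fc_a, "weight"), (fc_b, "weight")],
    pruning_method=prune.L1Unstructured,
    amount=0.2,  # round(0.2 * 18) = 4 weights are zeroed across both layers
)

# The dense tensor survives as a trainable parameter ...
print([name for name, _ in fc_a.named_parameters()])  # ['bias', 'weight_orig']
# ... and the 0/1 mask is stored as a buffer
print([name for name, _ in fc_a.named_buffers()])     # ['weight_mask']
# `weight` is now an ordinary tensor equal to weight_orig * weight_mask
assert torch.equal(fc_a.weight, fc_a.weight_orig * fc_a.weight_mask)

total_zeros = int((fc_a.weight == 0).sum() + (fc_b.weight == 0).sum())
print(total_zeros)  # 4
```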

### Sparsity after pruning

Let's analyse what percentage of each layer's weights has been pruned away.

In [None]:
def layer_sparsity(module):
    """Percentage of zero entries in a module's weight tensor."""
    return 100. * float(torch.sum(module.weight == 0)) / float(module.weight.nelement())

# Report every Conv2d / Linear weight in definition order, grouped by stage.
# Names are PyTorch module paths: "layer1.0.conv1" is conv1 of the first bottleneck
# block in layer1, "layer1.0.downsample.0" is that block's downsampling convolution,
# and "fc", "classifier.fc1", "classifier.fc2" are the three linear layers.
current_stage = None
for name, module in model.named_modules():
    if not isinstance(module, (nn.Conv2d, nn.Linear)):
        continue
    stage = name.split(".")[0]
    if stage != current_stage:
        print("==" * 40)
        current_stage = stage
    print("Sparsity in {}.weight: {:.2f}%".format(name, layer_sparsity(module)))

In [None]:
print("Sparsity of the model after pruning: {:.2f}%".format(calc_global_sparsity_resnet50(model)))

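## Fine-tuning the pruned model

We now fine-tune the pruned model so that the remaining weights can compensate for the ones that were removed. The pruning masks stay attached to the layers, so the pruned weights remain zero throughout fine-tuning.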
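Why the pruned weights stay at zero: the forward pre-hook recomputes `weight = weight_orig * weight_mask` before every forward pass, so the gradient reaching the pruned entries of `weight_orig` is multiplied by zero. A minimal sketch on a toy `nn.Linear` layer trained with plain SGD; the layer size, random data and step count are arbitrary choices for illustration. Even with momentum or weight decay, the effective `weight` would still be zero at the pruned positions; only the stored `weight_orig` values there could drift.

```python
import torch
from torch import nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
layer = nn.Linear(8, 4)  # toy stand-in for a pruned ResNet layer (32 weights)
prune.l1_unstructured(layer, name="weight", amount=0.5)  # zero the 16 smallest |w|
opt = torch.optim.SGD(layer.parameters(), lr=0.1)  # updates bias and weight_orig

x, y = torch.randn(16, 8), torch.randint(0, 4, (16,))
for _ in range(5):
    opt.zero_grad()
    loss = nn.functional.cross_entropy(layer(x), y)
    loss.backward()
    opt.step()

# Pruned positions are still exactly zero after training ...
print(int((layer.weight == 0).sum()))  # 16
# ... because their gradient w.r.t. weight_orig is zero
print(bool((layer.weight_orig.grad[layer.weight_mask == 0] == 0).all()))  # True
```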

In [None]:
epochs = 20

# Note: `optimizer` and `scheduler` carry over from the first training run, so
# fine-tuning runs at the already-decayed learning rate (0.001 * 0.05**2 = 2.5e-6).
# Re-creating them here would restart the learning-rate schedule.

# Per-epoch average losses, kept for plotting
avg_epoch_tr_loss = []
avg_epoch_val_loss = []
# np.Inf was removed in NumPy 2.0; np.inf works in all versions
val_loss_min = np.inf

for epoch in range(epochs):
    # ---- Training ----
    model.train()
    # Reset every epoch so that the printed values are per-epoch averages
    tr_losses, tr_accuracy = [], []
    for data, label in train_loader:
        # Move the batch to the GPU (if available)
        data, label = data.to(device), label.to(device)
        optimizer.zero_grad()           # clear gradients from the previous step
        logit = model(data)             # forward pass (pruning masks are applied here)
        loss = criterion(logit, label)  # cross-entropy loss
        loss.backward()                 # backpropagate the gradients
        optimizer.step()                # step in the opposite direction of the gradient
        # .item() returns a Python float detached from the computation graph
        tr_losses.append(loss.item())
        tr_accuracy.append(label.eq(logit.argmax(dim=1)).float().mean().item())

    avg_tr_loss = np.mean(tr_losses)
    avg_epoch_tr_loss.append(avg_tr_loss)
    print(f"\nEpoch No: {epoch + 1}, Training Loss: {avg_tr_loss:.2f}")
    print(f"Epoch No: {epoch + 1}, Training Accuracy: {np.mean(tr_accuracy):.2f}\n")

    # ---- Validation ----
    model.eval()
    val_losses, val_accuracy = [], []
    with torch.no_grad():  # no gradients are needed for evaluation
        for data, label in val_loader:
            data, label = data.to(device), label.to(device)
            logit = model(data)
            val_losses.append(criterion(logit, label).item())
            val_accuracy.append(label.eq(logit.argmax(dim=1)).float().mean().item())

    avg_val_loss = np.mean(val_losses)
    avg_epoch_val_loss.append(avg_val_loss)
    print(f"\nEpoch No: {epoch + 1}, Epoch Val Loss: {avg_val_loss:.2f}")
    print(f"Epoch No: {epoch + 1}, Epoch Val Accuracy: {np.mean(val_accuracy):.2f}\n")

    # Checkpoint the pruned model whenever the validation loss improves. It is saved
    # under a separate name so the dense checkpoint is not overwritten; until
    # prune.remove is called, this state_dict holds weight_orig/weight_mask entries.
    if avg_val_loss <= val_loss_min:
        print("Val Loss Decreased... Saving model")
        torch.save(model.state_dict(), "./pruned_model_state.pt")
        val_loss_min = avg_val_loss

    # Step the scheduler once per epoch and print the updated learning rate
    scheduler.step()
    print("Learning Rate Set To: {:.10f}".format(optimizer.param_groups[0]["lr"]), "\n")

We can observe that the **fine-tuned, pruned, and re-fine-tuned model** performs almost as well as the **fine-tuned dense model**. If the desired performance is not reached, or if more sparsity is needed, this prune-and-fine-tune cycle can be repeated iteratively until the desired result is achieved.
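When the cycle is repeated with the same `amount`, overall sparsity grows geometrically rather than linearly. This assumes, as the PyTorch pruning tutorial describes for iterative pruning, that each new `amount` is applied only to the weights that are still unpruned. A small sketch of the arithmetic; `cumulative_sparsity` and `amount_per_round` are hypothetical helper names.

```python
def cumulative_sparsity(amount_per_round, rounds):
    """Overall sparsity after `rounds` rounds, each pruning `amount_per_round`
    of the weights that are still unpruned."""
    return 1.0 - (1.0 - amount_per_round) ** rounds

def amount_per_round(target_sparsity, rounds):
    """Per-round amount needed to reach `target_sparsity` after `rounds` rounds."""
    return 1.0 - (1.0 - target_sparsity) ** (1.0 / rounds)

for k in range(1, 4):
    print(k, round(cumulative_sparsity(0.2, k), 3))
# 1 0.2
# 2 0.36
# 3 0.488

print(round(amount_per_round(0.8, 4), 3))  # 0.331
```

So three rounds at `amount=0.2` give about 48.8% sparsity rather than 60%, and reaching 80% sparsity in four rounds needs roughly 33% per round.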

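Now we make the pruning permanent with `prune.remove`, which folds each mask into the corresponding `weight` tensor and removes the `weight_orig` parameter, the `weight_mask` buffer and the forward pre-hook.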
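A minimal sketch of what `prune.remove` changes, on a toy `nn.Conv2d` standing in for a ResNet convolution (the layer shape is an arbitrary choice for illustration). The layer has 8 * 3 * 3 * 3 = 216 weights, so `amount=0.2` zeroes round(43.2) = 43 of them.

```python
import torch
from torch import nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
conv = nn.Conv2d(3, 8, kernel_size=3)  # toy stand-in for a ResNet convolution
prune.l1_unstructured(conv, name="weight", amount=0.2)
print(sorted(conv.state_dict().keys()))  # ['bias', 'weight_mask', 'weight_orig']

# Bake the mask into `weight` and drop weight_orig, weight_mask and the pre-hook
prune.remove(conv, "weight")
print(sorted(conv.state_dict().keys()))  # ['bias', 'weight']
print(isinstance(conv.weight, nn.Parameter))  # True
print(int((conv.weight == 0).sum()))  # 43
```

The weights are still stored as a dense tensor after `prune.remove`; actual memory savings would require a sparse storage format (for example `Tensor.to_sparse()`) or structured pruning.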

In [None]:
# Make pruning permanent for every pruned Conv2d / Linear layer
# (the same modules that were passed to prune.global_unstructured)
for module in model.modules():
    if isinstance(module, (nn.Conv2d, nn.Linear)):
        prune.remove(module, 'weight')

To confirm, let's check the sparsity of the pruned and fine-tuned model:

In [None]:
print("Sparsity of the model after pruning and fine tuning: {:.2f}%".format(calc_global_sparsity_resnet50(model)))

# End Notes

* This notebook demonstrates magnitude weight pruning on a fine-tuned ResNet-50 model.
* We obtained 20% sparsity in a single pruning iteration without a large drop in performance.
