# Iterative Pruning: Using _torch.nn.utils.prune_

Experiment includes 'LeNet-300-100' dense neural network using MNIST dataset.

In [61]:
import torch
import torch.nn as nn
import torchvision
import torch.nn.utils.prune as prune
import torch.nn.functional as F
import numpy as np
import torchvision.transforms as transforms

import matplotlib.pyplot as plt
import os

In [2]:
# GPU device configuration-
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Available device: {device}")

Available device: cpu


In [3]:
# Hyper-parameters-
input_size = 784    # 28 x 28, flattened to be 1-D tensor
hidden_size = 100
num_classes = 10
num_epochs = 20
batch_size = 32
learning_rate = 0.0012

In [4]:
# MNIST dataset statistics:
# mean = tensor([0.1307]) & std dev = tensor([0.3081])
mean = np.array([0.1307])
std_dev = np.array([0.3081])

transforms_apply = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean = mean, std = std_dev)
    ])

In [5]:
os.chdir("/home/arjun/Documents/Programs/Python_Codes/PyTorch_Resources/Good_Codes/")

In [6]:
# MNIST dataset-
train_dataset = torchvision.datasets.MNIST(
        root = './data', train = True,
        transform = transforms_apply, download = True
        )

test_dataset = torchvision.datasets.MNIST(
        root = './data', train = False,
        transform = transforms_apply
        )


print(f"len(train_dataset): {len(train_dataset)} & len(test_dataset): {len(test_dataset)}")

len(train_dataset): 60000 & len(test_dataset): 10000


In [7]:
# Create dataloader-
train_loader = torch.utils.data.DataLoader(
        dataset = train_dataset, batch_size = batch_size,
        shuffle = True
        )

test_loader = torch.utils.data.DataLoader(
        dataset = test_dataset, batch_size = batch_size,
        shuffle = False
        )

print(f"len(train_loader) = {len(train_loader)} & len(test_loader) = {len(test_loader)}")
len(train_loader), len(test_loader)

len(train_loader) = 1875 & len(test_loader) = 313


(1875, 313)

In [8]:
class LeNet300(nn.Module):
    def __init__(self):
        super(LeNet300, self).__init__()
        
        # Define layers-
        self.fc1 = nn.Linear(in_features = input_size, out_features = 300)
        self.fc2 = nn.Linear(in_features = 300, out_features = 100)
        self.output = nn.Linear(in_features = 100, out_features = 10)
        
        self.weights_initialization()
    
    
    def forward(self, x):
        out = F.relu(self.fc1(x))
        out = F.relu(self.fc2(out))
        return self.output(out)
    
    
    def weights_initialization(self):
        '''
        When we define all the modules such as the layers in '__init__()'
        method above, these are all stored in 'self.modules()'.
        We go through each module one by one. This is the entire network,
        basically.
        '''
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                nn.init.constant_(m.bias, 0)


In [37]:
# Initialize an instance of LeNet-300-100 dense neural network and load already trained model-
best_model = LeNet300()

In [38]:
# Load trained weights-
best_model.load_state_dict(torch.load("/home/arjun/Deep_Learning_Resources/LTH-Resources/LeNet-300-100_Trained.pth"))

<All keys matched successfully>

In [39]:
# Define loss and optimizer-
loss = nn.CrossEntropyLoss()    # applies softmax for us
optimizer = torch.optim.Adam(best_model.parameters(), lr = learning_rate)

In [12]:
def count_params(model):
    
    tot_params = 0
    for layer_name, param in model.named_parameters():
        # print(f"{layer_name}.shape = {param.shape} has {torch.count_nonzero(param.data)} non-zero params")
        tot_params += torch.count_nonzero(param.data)
    
    return tot_params


In [40]:
orig_params = count_params(best_model)
print(f"Original LeNet-300-100 model has {orig_params} trainable parameters")

Original LeNet-300-100 model has 266610 trainable parameters


In [14]:
for layer, param in best_model.named_parameters():
    print(f"layer.name: {layer} & param.shape = {param.shape}")

layer.name: fc1.weight & param.shape = torch.Size([300, 784])
layer.name: fc1.bias & param.shape = torch.Size([300])
layer.name: fc2.weight & param.shape = torch.Size([100, 300])
layer.name: fc2.bias & param.shape = torch.Size([100])
layer.name: output.weight & param.shape = torch.Size([10, 100])
layer.name: output.bias & param.shape = torch.Size([10])


In [18]:
# To access layer names-
best_model.state_dict().keys()

odict_keys(['fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', 'output.weight', 'output.bias'])

In [25]:
def layer_wise_pruning():
    '''
    Function to compute layer-wise pruning for iterative pruning using LeNet-300-100 model
    and MNIST dataset
    '''
    
    # number of fully-connected dense parameters-
    dense1 = 235200
    dense2 = 30000
    op_layer = 100


    # total number of parameters-
    total_params = dense1 + dense2 + op_layer
    # print(f"Total number of trainable parameters = {total_params}")
    
    # maximum pruning performed is till 0.5% of all parameters-
    max_pruned_params = 0.005 * total_params
        
    loc_tot_params = total_params
    loc_dense1 = dense1
    loc_dense2 = dense2
    loc_op_layer = op_layer

    # variable to count number of pruning rounds-
    n = 0
        
    # Lists to hold percentage of weights pruned in each round for all layers in NN-
    dense1_pruning = []
    dense2_pruning = []
    op_layer_pruning = []
    
    
    while loc_tot_params >= max_pruned_params:
        loc_dense1 *= 0.8   # 20% weights are pruned
        loc_dense2 *= 0.8   # 20% weights are pruned
        loc_op_layer *= 0.9 # 10% weights are pruned

        dense1_pruning.append(((dense1 - loc_dense1) / dense1) * 100)
        dense2_pruning.append(((dense2 - loc_dense2) / dense2) * 100)
        op_layer_pruning.append(((op_layer - loc_op_layer) / op_layer) * 100)

        loc_tot_params = loc_dense1 + loc_dense2 + loc_op_layer

        n += 1

        '''
        print("\nConv1 = {0:.3f}, Conv2 = {1:.3f}, Conv3 = {2:.4f}".format(loc_conv1, loc_conv2, loc_conv3))
        print("Conv4 = {0:.3f}, Conv5 = {1:.3f} & Conv6 = {2:.3f}".format(loc_conv4, loc_conv5, loc_conv6))
        print("Dense1 = {0:.3f}, Dense2 = {1:.3f} & O/p layer = {2:.3f}".format(
            loc_dense1, loc_dense2, loc_op_layer))
        print("Total number of parameters = {0:.3f}\n".format(loc_tot_params))
        '''
    
    # print("\nnumber of pruning rounds = {0}\n\n".format(n))
    
    # Convert from list to np.array-
    dense1_pruning = np.array(dense1_pruning)
    dense2_pruning = np.array(dense2_pruning)
    op_layer_pruning = np.array(op_layer_pruning)

    # Round off numpy arrays to 3 decimal digits-
    dense1_pruning = np.round(dense1_pruning, decimals=3)
    dense2_pruning = np.round(dense2_pruning, decimals=3)
    op_layer_pruning = np.round(op_layer_pruning, decimals=3)

    
    # Python 3 dict to hold layer
    pruning_layers = {}
    
    pruning_layers['dense1_pruning'] = dense1_pruning
    pruning_layers['dense2_pruning'] = dense2_pruning
    pruning_layers['op_layer'] = op_layer_pruning
    
    return pruning_layers, n


In [26]:
pruning_layers, num_pruning_rounds = layer_wise_pruning()

In [28]:
print(f"number of pruning rounds = {num_pruning_rounds} to achieve 0.5% final sparsity")

number of pruning rounds = 24 to achieve 0.5% final sparsity


In [30]:
# Sanity check- percentage of weights to prune for first dense layer in each iterative pruning round-
pruning_layers['dense1_pruning']

array([20.   , 36.   , 48.8  , 59.04 , 67.232, 73.786, 79.028, 83.223,
       86.578, 89.263, 91.41 , 93.128, 94.502, 95.602, 96.482, 97.185,
       97.748, 98.199, 98.559, 98.847, 99.078, 99.262, 99.41 , 99.528])

In [31]:
pruning_layers['op_layer']

array([10.   , 19.   , 27.1  , 34.39 , 40.951, 46.856, 52.17 , 56.953,
       61.258, 65.132, 68.619, 71.757, 74.581, 77.123, 79.411, 81.47 ,
       83.323, 84.991, 86.491, 87.842, 89.058, 90.152, 91.137, 92.023])

In [41]:
def train_model(model, train_loader):
    '''
    Function to perform one epoch of training by using 'train_loader'.
    Returns loss and number of correct predictions for this epoch.
    '''
    running_loss = 0.0
    running_corrects = 0.0

    for batch, (images, labels) in enumerate(train_loader):
        # Reshape image and place it on GPU-
        images = images.reshape(-1, input_size).to(device)
        labels = labels.to(device) 
        outputs = model(images)   # forward pass
        J = loss(outputs, labels) # compute loss
        optimizer.zero_grad()     # empty accumulated gradients
        J.backward()              # perform backpropagation
        optimizer.step()          # update parameters

        # Compute model's performance statistics-
        running_loss += J.item() * images.size(0)
        _, predicted = torch.max(outputs, 1)
        running_corrects += torch.sum(predicted == labels.data)

        '''
        # Print information every 100 steps-
        if (batch + 1) % 100 == 0:
            print(f"epoch {epoch + 1}/{num_epochs}, step {batch + 1}/{num_steps}, loss = {J.item():.4f}")
        '''

    return running_loss, running_corrects


In [42]:
def test_model(model, test_loader):
    total = 0.0
    correct = 0.0
    running_loss_val = 0.0

    with torch.no_grad():
        for images, labels in test_loader:

            # Place features (images) and targets (labels) to GPU-
            images = images.reshape(-1, input_size).to(device)
            labels = labels.to(device)
            # print(f"images.shape = {images.shape}, labels.shape = {labels.shape}")

            # Set model to evaluation mode-
            model.eval()
    
            # Make predictions using trained model-
            outputs = model(images)
            _, y_pred = torch.max(outputs, 1)

            # Compute validation loss-
            J_val = loss(outputs, labels)

            running_loss_val += J_val.item() * labels.size(0)
    
            # Total number of labels-
            total += labels.size(0)

            # Total number of correct predictions-
            correct += (y_pred == labels).sum()

    return (running_loss_val, correct, total)


In [43]:
# User input parameters for Early Stopping in manual implementation-
minimum_delta = 0.001
patience = 5

In [44]:
# Initialize parameters for Early Stopping manual implementation-
best_val_loss = 100
loc_patience = 0

# running_loss = 0.0
# running_corrects = 0

In [45]:
# Python3 lists to store model training metrics-
training_acc = []
validation_acc = []
training_loss = []
validation_loss = []

In [46]:
# Training loop-
for epoch in range(num_epochs):
    running_loss = 0.0
    running_corrects = 0.0
    
    if loc_patience >= patience:
        print("\n'EarlyStopping' called!\n")
        break

    running_loss, running_corrects = train_model(best_model, train_loader)
  
    epoch_loss = running_loss / len(train_dataset)
    epoch_acc = running_corrects.double() / len(train_dataset)
    # epoch_acc = 100 * running_corrects / len(trainset)
    # print(f"\nepoch: {epoch + 1} training loss = {epoch_loss:.4f}, training accuracy = {epoch_acc * 100:.2f}%\n")

    running_loss_val, correct, total = test_model(best_model, test_loader)

    epoch_val_loss = running_loss_val / len(test_dataset)
    val_acc = 100 * (correct / total)
    # print(f"\nepoch: {epoch + 1} training loss = {epoch_loss:.4f}, training accuracy = {epoch_acc * 100:.2f}%, val_loss = {epoch_val_loss:.4f} & val_accuracy = {val_acc:.2f}%\n")

    print(f"\nepoch: {epoch + 1} training loss = {epoch_loss:.4f}, training accuracy = {epoch_acc * 100:.2f}%, val_loss = {epoch_val_loss:.4f} & val_accuracy = {val_acc:.2f}%\n")

    # Code for manual Early Stopping:
    # if np.abs(epoch_val_loss < best_val_loss) >= minimum_delta:
    if (epoch_val_loss < best_val_loss) and np.abs(epoch_val_loss - best_val_loss) >= minimum_delta:
        # print(f"epoch_val_loss = {epoch_val_loss:.4f}, best_val_loss = {best_val_loss:.4f}")
        
        # update 'best_val_loss' variable to lowest loss encountered so far-
        best_val_loss = epoch_val_loss
        
        # reset 'loc_patience' variable-
        loc_patience = 0
        
        print(f"\nSaving model with lowest val_loss = {epoch_val_loss:.4f}")
        
        # Save trained model with validation accuracy-
        # torch.save(model.state_dict, f"LeNet-300-100_Trained_{val_acc}.pth")
        torch.save(best_model.state_dict(), "LeNet-300-100_Trained.pth")
        
    else:  # there is no improvement in monitored metric 'val_loss'
        loc_patience += 1  # number of epochs without any improvement


    training_acc.append(epoch_acc * 100)
    validation_acc.append(val_acc)
    training_loss.append(epoch_loss)
    validation_loss.append(epoch_val_loss)
    


epoch: 1 training loss = 0.0428, training accuracy = 98.70%, val_loss = 0.1167 & val_accuracy = 97.01%


Saving model with lowest val_loss = 0.1167

epoch: 2 training loss = 0.0351, training accuracy = 98.92%, val_loss = 0.1107 & val_accuracy = 97.54%


Saving model with lowest val_loss = 0.1107

epoch: 3 training loss = 0.0360, training accuracy = 98.92%, val_loss = 0.0979 & val_accuracy = 97.89%


Saving model with lowest val_loss = 0.0979

epoch: 4 training loss = 0.0303, training accuracy = 99.11%, val_loss = 0.1028 & val_accuracy = 97.83%


epoch: 5 training loss = 0.0320, training accuracy = 99.07%, val_loss = 0.1109 & val_accuracy = 97.61%


epoch: 6 training loss = 0.0244, training accuracy = 99.31%, val_loss = 0.1262 & val_accuracy = 97.53%


epoch: 7 training loss = 0.0270, training accuracy = 99.25%, val_loss = 0.1307 & val_accuracy = 97.47%


epoch: 8 training loss = 0.0246, training accuracy = 99.26%, val_loss = 0.1249 & val_accuracy = 97.89%


'EarlyStopping' called!



In [48]:
# Initialize a LeNet-300-100 model to contain best weights from above-
best_model = LeNet300()

In [49]:
# Load weights-
best_model.load_state_dict(torch.load("/home/arjun/Documents/Programs/Python_Codes/PyTorch_Resources/Good_Codes/LeNet-300-100_Trained.pth"))

<All keys matched successfully>

In [50]:
# Compute performance of trained model on validation dataset-
running_loss_val, correct, total = test_model(best_model, test_loader)

In [52]:
val_loss = running_loss_val / len(test_dataset)
val_acc = 100 * (correct / total)

In [55]:
print("Trained LeNet-300-100 dense neural network's performance on validation data:")
print(f"val_loss = {val_loss:.4f}, val_accuracy = {val_acc:.2f}%")

Trained LeNet-300-100 dense neural network's performance on validation data:
val_loss = 0.0979, val_accuracy = 97.89%


## Implement Neural Network Pruning using _'torch.nn.utils.prune'_ module


The available pruning methods:

The following child classes inherit from the _BasePruningMethod_:

- _torch.nn.utils.prune.Identity_: utility pruning method that does not prune any units but generates the pruning parametrization with a mask of ones

- _torch.nn.utils.prune.RandomUnstructured_: prune (currently unpruned) entries in a tensor at random

- _torch.nn.utils.prune.L1Unstructured_: prune (currently unpruned) entries in a tensor by zeroing out the ones with the lowest absolute magnitude

- _torch.nn.utils.prune.RandomStructured_: prune entire (currently unpruned) rows or columns in a tensor at random

- _torch.nn.utils.prune.LnStructured_: prune entire (currently unpruned) rows or columns in a tensor based on their Ln-norm (supported values of n correspond to supported values for argument p in torch.norm())
- _torch.nn.utils.prune.CustomFromMask_: prune a tensor using a user-provided mask.


Their 'functional equivalents' are:

- _torch.nn.utils.prune.identity_

- _torch.nn.utils.prune.random_unstructured_

- _torch.nn.utils.prune.l1_unstructured_

- _torch.nn.utils.prune.random_structured_

- _torch.nn.utils.prune.ln_structured_

- _torch.nn.utils.prune.custom_from_mask_

Global pruning, in which entries are compared across multiple tensors, is enabled through "torch.nn.utils.prune.global unstructured"


Refer to the research paper "Streamlining Tensor and Network Pruning in PyTorch" by Michela Paganini et al. for more details and [PyTorch Pruning Tutorial](https://pytorch.org/tutorials/intermediate/pruning_tutorial.html#iterative-pruning).


### Pruning multiple parameters in a model:

By specifying the desired pruning technique and parameters, we can easily prune multiple tensors in a neural network, perhaps according to their type, as can be seen in this example-

In [62]:
# Prune multiple parameters/layers in a given model-
for name, module in best_model.named_modules():
    '''
    # prune 15% of connections in all 2D-conv layers
    if isinstance(module, torch.nn.Conv2d):
        prune.l1_unstructured(module = module, name = 'weight', amount = 0.15)
    '''
    # prune 20% of weights/connections in for all hidden layaers-
    if isinstance(module, torch.nn.Linear) and name != 'output':
        prune.l1_unstructured(module = module, name = 'weight', amount = 0.2)
    
    # prune 10% of weights/connections for output layer-
    elif isinstance(module, torch.nn.Linear) and name == 'output':
        prune.l1_unstructured(module = module, name = 'weight', amount = 0.1)


In [63]:
# Sanity check: verify that all of the defined pruning exists as masks-
print(dict(best_model.named_buffers()).keys())

dict_keys(['fc1.weight_mask', 'fc2.weight_mask', 'output.weight_mask'])


In [72]:
curr_params = count_params(best_model)

In [76]:
curr_params.numpy()

array(266610)

In [80]:
# User input parameters for Early Stopping in manual implementation-
minimum_delta = 0.001
patience = 5

In [81]:
# Initialize parameters for Early Stopping manual implementation-
best_val_loss = 100
loc_patience = 0

# running_loss = 0.0
# running_corrects = 0

In [45]:
# Python3 lists to store model training metrics-
training_acc = []
validation_acc = []
training_loss = []
validation_loss = []

In [82]:
# Training loop-
for epoch in range(num_epochs):
    running_loss = 0.0
    running_corrects = 0.0
    
    if loc_patience >= patience:
        print("\n'EarlyStopping' called!\n")
        break

    running_loss, running_corrects = train_model(best_model, train_loader)
  
    epoch_loss = running_loss / len(train_dataset)
    epoch_acc = running_corrects.double() / len(train_dataset)
    # epoch_acc = 100 * running_corrects / len(trainset)
    # print(f"\nepoch: {epoch + 1} training loss = {epoch_loss:.4f}, training accuracy = {epoch_acc * 100:.2f}%\n")

    running_loss_val, correct, total = test_model(best_model, test_loader)

    epoch_val_loss = running_loss_val / len(test_dataset)
    val_acc = 100 * (correct / total)
    # print(f"\nepoch: {epoch + 1} training loss = {epoch_loss:.4f}, training accuracy = {epoch_acc * 100:.2f}%, val_loss = {epoch_val_loss:.4f} & val_accuracy = {val_acc:.2f}%\n")

    print(f"\nepoch: {epoch + 1} training loss = {epoch_loss:.4f}, training accuracy = {epoch_acc * 100:.2f}%, val_loss = {epoch_val_loss:.4f} & val_accuracy = {val_acc:.2f}%")

    curr_params = count_params(best_model)
    print(f"Number of parameters = {curr_params}\n")
    
    percentage_pruned = ((orig_params - curr_params.numpy()) / orig_params * 100).numpy()
    
    # Code for manual Early Stopping:
    # if np.abs(epoch_val_loss < best_val_loss) >= minimum_delta:
    if (epoch_val_loss < best_val_loss) and np.abs(epoch_val_loss - best_val_loss) >= minimum_delta:
        # print(f"epoch_val_loss = {epoch_val_loss:.4f}, best_val_loss = {best_val_loss:.4f}")
        
        # update 'best_val_loss' variable to lowest loss encountered so far-
        best_val_loss = epoch_val_loss
        
        # reset 'loc_patience' variable-
        loc_patience = 0
        
        print(f"\nSaving model with lowest val_loss = {epoch_val_loss:.4f}")
        
        # Save trained model with validation accuracy-
        # torch.save(model.state_dict, f"LeNet-300-100_Trained_{val_acc}.pth")
        torch.save(best_model.state_dict(), f"LeNet-300-100_{percentage_pruned:.2f}.pth")
        
    else:  # there is no improvement in monitored metric 'val_loss'
        loc_patience += 1  # number of epochs without any improvement

    '''
    training_acc.append(epoch_acc * 100)
    validation_acc.append(val_acc)
    training_loss.append(epoch_loss)
    validation_loss.append(epoch_val_loss)
    '''


epoch: 1 training loss = 0.0266, training accuracy = 99.11%, val_loss = 0.0980 & val_accuracy = 97.94%
Number of parameters = 266610


Saving model with lowest val_loss = 0.0980

epoch: 2 training loss = 0.0266, training accuracy = 99.11%, val_loss = 0.0980 & val_accuracy = 97.94%
Number of parameters = 266610


epoch: 3 training loss = 0.0266, training accuracy = 99.11%, val_loss = 0.0980 & val_accuracy = 97.94%
Number of parameters = 266610


epoch: 4 training loss = 0.0266, training accuracy = 99.11%, val_loss = 0.0980 & val_accuracy = 97.94%
Number of parameters = 266610


epoch: 5 training loss = 0.0266, training accuracy = 99.11%, val_loss = 0.0980 & val_accuracy = 97.94%
Number of parameters = 266610


epoch: 6 training loss = 0.0266, training accuracy = 99.11%, val_loss = 0.0980 & val_accuracy = 97.94%
Number of parameters = 266610


'EarlyStopping' called!

