# PyTorch Pruning: _LeNet-5_ Conv Net

Trained on MNIST dataset using __global unstructured absolute magnitude__ based pruning.

[PyTorch Pruning Tutorial](https://pytorch.org/tutorials/intermediate/pruning_tutorial.html#serializing-a-pruned-model)

In [1]:
import torch
import torch.nn as nn
import torchvision
import torch.nn.utils.prune as prune
import torch.nn.functional as F
import numpy as np
import torchvision.transforms as transforms

import matplotlib.pyplot as plt
import os

In [2]:
%matplotlib inline

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Available device: {device}")

Available device: cpu


In [4]:
# Hyper-parameters-
# input_size = 784    # 28 x 28, flattened to be 1-D tensor
# hidden_size = 100
num_classes = 10
num_epochs = 20
batch_size = 32
learning_rate = 0.0012

In [5]:
# MNIST dataset statistics:
# mean = tensor([0.1307]) & std dev = tensor([0.3081])
mean = np.array([0.1307])
std_dev = np.array([0.3081])

transforms_apply = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean = mean, std = std_dev)
    ])

In [6]:
os.chdir("/home/arjun/Documents/Programs/Python_Codes/PyTorch_Resources/Good_Codes/")

In [7]:
# MNIST dataset-
train_dataset = torchvision.datasets.MNIST(
        root = './data', train = True,
        transform = transforms_apply, download = True
        )

test_dataset = torchvision.datasets.MNIST(
        root = './data', train = False,
        transform = transforms_apply
        )


print(f"len(train_dataset): {len(train_dataset)} & len(test_dataset): {len(test_dataset)}")

len(train_dataset): 60000 & len(test_dataset): 10000


In [8]:
# Create dataloader-
train_loader = torch.utils.data.DataLoader(
        dataset = train_dataset, batch_size = batch_size,
        shuffle = True
        )

test_loader = torch.utils.data.DataLoader(
        dataset = test_dataset, batch_size = batch_size,
        shuffle = False
        )

print(f"len(train_loader) = {len(train_loader)} & len(test_loader) = {len(test_loader)}")
len(train_loader), len(test_loader)

len(train_loader) = 1875 & len(test_loader) = 313


(1875, 313)

### Create a model:

In this tutorial, we use the [LeNet CNN](http://yann.lecun.com/exdb/publis/pdf/lecun-98.pdf) architecture from LeCun et al., 1998.

In [9]:
class LeNet5(nn.Module):
    '''
    Implements a variation of LeNet-5 CNN. It is LeNet-4.
    '''
    def __init__(self):
        super(LeNet5, self).__init__()
        
        self.conv1 = nn.Conv2d(
            in_channels = 1, out_channels = 6,
            kernel_size = 3, padding = 1,
            stride = 1
        )
        
        self.conv2 = nn.Conv2d(
            in_channels = 6, out_channels = 16,
            kernel_size = 3, padding = 1,
            stride = 1
        )
        
        self.conv3 = nn.Conv2d(
            in_channels = 16, out_channels = 120,
            kernel_size = 3, padding = 1,
            stride = 1
        )
        
        self.pool = nn.MaxPool2d(
            kernel_size = 2, stride = 2
        )
        
        self.flatten = nn.Flatten()
        # self.fc1 = nn.Linear(in_features = 512, out_features = 256)
        # self.fc2 = nn.Linear(in_features = 120, out_features = 84)
        # self.op = nn.Linear(in_features = 84, out_features = 10)
        self.op = nn.Linear(in_features = 1080, out_features = 10)
        
        self.weights_initialization()
        
    
    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        x = self.flatten(x)
        return self.op(x)
    
    
    def shape_computation(self, x):
        x = self.conv1(x)
        print(f"conv1.shape = {x.shape}")
        
        x = self.pool(x)
        print(f"pool.shape = {x.shape}")
        
        x = self.conv2(x)
        print(f"conv2.shape = {x.shape}")
        
        x = self.pool(x)
        print(f"pool.shape = {x.shape}")
        
        x = self.conv3(x)
        print(f"conv3.shape = {x.shape}")
        
        x = self.pool(x)
        print(f"pool.shape = {x.shape}")
        
        x = self.flatten(x)
        print(f"flatten.shape = {x.shape}")
        
        x = self.op(x)
        print(f"output.shape = {x.shape}")
        
    
    def weights_initialization(self):
        '''
        When we define all the modules such as the layers in '__init__()'
        method above, these are all stored in 'self.modules()'.
        We go through each module one by one. This is the entire network,
        basically.
        '''
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                nn.init.constant_(m.bias, 0)

        

In [10]:
# model = LeNet5().to(device = device)

In [None]:
# Define loss and optimizer-
# loss = nn.CrossEntropyLoss()    # applies softmax for us
# optimizer = torch.optim.Adam(model.parameters(), lr = learning_rate)

In [10]:
def count_params(model):
    
    tot_params = 0
    for layer_name, param in model.named_parameters():
        # print(f"{layer_name}.shape = {param.shape} has {torch.count_nonzero(param.data)} non-zero params")
        tot_params += torch.count_nonzero(param.data)
    
    return tot_params


In [None]:
orig_params = count_params(model)
print(f"Unpruned LeNet-4 model has {orig_params} trainable parameters")

In [19]:
for layer, param in model.named_parameters():
    print(f"layer.name: {layer} & param.shape = {param.shape}")

layer.name: conv1.weight & param.shape = torch.Size([6, 1, 3, 3])
layer.name: conv1.bias & param.shape = torch.Size([6])
layer.name: conv2.weight & param.shape = torch.Size([16, 6, 3, 3])
layer.name: conv2.bias & param.shape = torch.Size([16])
layer.name: conv3.weight & param.shape = torch.Size([120, 16, 3, 3])
layer.name: conv3.bias & param.shape = torch.Size([120])
layer.name: op.weight & param.shape = torch.Size([10, 1080])
layer.name: op.bias & param.shape = torch.Size([10])


### Train model:

In [12]:
def train_model(model, train_loader):
    '''
    Function to perform one epoch of training by using 'train_loader'.
    Returns loss and number of correct predictions for this epoch.
    '''
    running_loss = 0.0
    running_corrects = 0.0

    for batch, (images, labels) in enumerate(train_loader):
        # Reshape image and place it on GPU-
        # images = images.reshape(-1, input_size).to(device)
        images = images.to(device)
        labels = labels.to(device) 
        outputs = model(images)   # forward pass
        J = loss(outputs, labels) # compute loss
        optimizer.zero_grad()     # empty accumulated gradients
        J.backward()              # perform backpropagation
        optimizer.step()          # update parameters

        # Compute model's performance statistics-
        running_loss += J.item() * images.size(0)
        _, predicted = torch.max(outputs, 1)
        running_corrects += torch.sum(predicted == labels.data)

        '''
        # Print information every 100 steps-
        if (batch + 1) % 100 == 0:
            print(f"epoch {epoch + 1}/{num_epochs}, step {batch + 1}/{num_steps}, loss = {J.item():.4f}")
        '''

    return running_loss, running_corrects


In [13]:
def test_model(model, test_loader):
    total = 0.0
    correct = 0.0
    running_loss_val = 0.0

    with torch.no_grad():
        for images, labels in test_loader:

            # Place features (images) and targets (labels) to GPU-
            # images = images.reshape(-1, input_size).to(device)
            images = images.to(device)
            labels = labels.to(device)
            # print(f"images.shape = {images.shape}, labels.shape = {labels.shape}")

            # Set model to evaluation mode-
            model.eval()
    
            # Make predictions using trained model-
            outputs = model(images)
            _, y_pred = torch.max(outputs, 1)

            # Compute validation loss-
            J_val = loss(outputs, labels)

            running_loss_val += J_val.item() * labels.size(0)
    
            # Total number of labels-
            total += labels.size(0)

            # Total number of correct predictions-
            correct += (y_pred == labels).sum()

    return (running_loss_val, correct, total)


In [14]:
# User input parameters for Early Stopping in manual implementation-
minimum_delta = 0.001
patience = 5


# Initialize parameters for Early Stopping manual implementation-
best_val_loss = 100
loc_patience = 0

In [23]:
# Python3 lists to store model training metrics-
training_acc = []
validation_acc = []
training_loss = []
validation_loss = []

In [None]:
# Training loop-
for epoch in range(num_epochs):
    running_loss = 0.0
    running_corrects = 0.0
    
    if loc_patience >= patience:
        print("\n'EarlyStopping' called!\n")
        break

    running_loss, running_corrects = train_model(model, train_loader)
  
    # Compute training loss and accuracy for one epoch-
    epoch_loss = running_loss / len(train_dataset)
    epoch_acc = running_corrects.double() / len(train_dataset)
    # epoch_acc = 100 * running_corrects / len(trainset)
    # print(f"\nepoch: {epoch + 1} training loss = {epoch_loss:.4f}, training accuracy = {epoch_acc * 100:.2f}%\n")

    running_loss_val, correct, total = test_model(model, test_loader)

    # Compute validation loss and accuracy-
    epoch_val_loss = running_loss_val / len(test_dataset)
    val_acc = 100 * (correct / total)
    # print(f"\nepoch: {epoch + 1} training loss = {epoch_loss:.4f}, training accuracy = {epoch_acc * 100:.2f}%, val_loss = {epoch_val_loss:.4f} & val_accuracy = {val_acc:.2f}%\n")

    print(f"\nepoch: {epoch + 1} training loss = {epoch_loss:.4f}, training accuracy = {epoch_acc * 100:.2f}%, val_loss = {epoch_val_loss:.4f} & val_accuracy = {val_acc:.2f}%\n")

    # Code for manual Early Stopping:
    # if np.abs(epoch_val_loss < best_val_loss) >= minimum_delta:
    if (epoch_val_loss < best_val_loss) and np.abs(epoch_val_loss - best_val_loss) >= minimum_delta:
        # print(f"epoch_val_loss = {epoch_val_loss:.4f}, best_val_loss = {best_val_loss:.4f}")
        
        # update 'best_val_loss' variable to lowest loss encountered so far-
        best_val_loss = epoch_val_loss
        
        # reset 'loc_patience' variable-
        loc_patience = 0
        
        print(f"Saving model with lowest val_loss = {epoch_val_loss:.4f}\n")
        
        # Save trained model with validation accuracy-
        # torch.save(model.state_dict, f"LeNet-300-100_Trained_{val_acc}.pth")
        torch.save(model.state_dict(), "LeNet-4_Trained.pth")
        
    else:  # there is no improvement in monitored metric 'val_loss'
        loc_patience += 1  # number of epochs without any improvement


    training_acc.append(epoch_acc * 100)
    validation_acc.append(val_acc)
    training_loss.append(epoch_loss)
    validation_loss.append(epoch_val_loss)


In [15]:
# Initialize and load best weights-
best_model = LeNet5().to(device = device)
best_model.load_state_dict(torch.load("/home/arjun/Documents/Programs/Python_Codes/PyTorch_Resources/Good_Codes/Pruning_codes_and_resources/LeNet-4_Trained.pth"))

<All keys matched successfully>

In [16]:
# Define loss and optimizer-
loss = nn.CrossEntropyLoss()    # applies softmax for us
optimizer = torch.optim.Adam(best_model.parameters(), lr = learning_rate)

In [17]:
print(f"number of non-zero parameter in unpruned LeNet-4 CNN = {count_params(best_model)}")

number of non-zero parameter in unpruned LeNet-4 CNN = 29150


In [16]:
# Compute 'best weights' metrics on validation dataset-
running_loss_val, correct, total = test_model(best_model, test_loader)

# Compute validation loss and accuracy-
val_loss = running_loss_val / len(test_dataset)
val_acc = 100 * (correct / total)

print("Best trained LeNet-4 metrics on validation dataset:")
print(f"val_loss = {val_loss:.4f} & val_acc = {val_acc:.2f}%")

Best trained LeNet-4 metrics on validation dataset:
val_loss = 0.0247 & val_acc = 99.21%


In [17]:
for layer_name in best_model.state_dict().keys():
    print(layer_name, best_model.state_dict()[layer_name].shape)

conv1.weight torch.Size([6, 1, 3, 3])
conv1.bias torch.Size([6])
conv2.weight torch.Size([16, 6, 3, 3])
conv2.bias torch.Size([16])
conv3.weight torch.Size([120, 16, 3, 3])
conv3.bias torch.Size([120])
op.weight torch.Size([10, 1080])
op.bias torch.Size([10])


In [18]:
best_model.state_dict().keys()

odict_keys(['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias', 'conv3.weight', 'conv3.bias', 'op.weight', 'op.bias'])

## Global pruning:

A common and powerful technique is to prune the model all at once, by removing (for example) the lowest 20% of connections across the whole model, instead of removing the lowest 20% of connections in each layer. This is likely to result in different pruning percentages per layer. Let’s see how to do that using global_unstructured from ``torch.nn.utils.prune``.

In [18]:
# Define the parameters/layers to be pruned-
parameters_to_prune = (
    (best_model.conv1, 'weight'),
    (best_model.conv2, 'weight'),
    (best_model.conv3, 'weight'),
    (best_model.op, 'weight')
)

'''
prune.global_unstructured(
    parameters_to_prune,
    pruning_method = prune.L1Unstructured,
    amount = 0.2,
)
'''

'\nprune.global_unstructured(\n    parameters_to_prune,\n    pruning_method = prune.L1Unstructured,\n    amount = 0.2,\n)\n'

In [19]:
def compute_sparsity(best_model):
    conv1_sparsity = (torch.sum(best_model.conv1.weight == 0) / best_model.conv1.weight.nelement()) * 100
    conv2_sparsity = (torch.sum(best_model.conv2.weight == 0) / best_model.conv2.weight.nelement()) * 100
    conv3_sparsity = (torch.sum(best_model.conv3.weight == 0) / best_model.conv3.weight.nelement()) * 100
    op_sparsity = (torch.sum(best_model.op.weight == 0) / best_model.op.weight.nelement()) * 100

    # Compute global sparsity-
    num = torch.sum(best_model.conv1.weight == 0) + torch.sum(best_model.conv2.weight == 0) + torch.sum(best_model.conv3.weight == 0) \
            + torch.sum(best_model.op.weight == 0)

    denom = best_model.conv1.weight.nelement() + best_model.conv2.weight.nelement() + best_model.conv3.weight.nelement() \
            + best_model.op.weight.nelement()

    global_sparsity = num / denom * 100

    '''
    print(f"conv1.weight has {conv1_sparsity:.2f}% sparsity")
    print(f"conv2.weight has {conv2_sparsity:.2f}% sparsity")
    print(f"conv3.weight has {conv3_sparsity:.2f}% sparsity")
    print(f"op.weight has {op_sparsity:.2f}% sparsity")
    '''
    print(f"LeNet-4 Global Sparsity = {global_sparsity:.2f}%")
    
    return None


In [20]:
compute_sparsity(best_model)

LeNet-4 Global Sparsity = 0.00%


In [59]:
prune.global_unstructured(
    parameters_to_prune,
    pruning_method = prune.L1Unstructured,
    amount = 0.4,
)

In [61]:
compute_sparsity(best_model)

LeNet-4 Global Sparsity = 52.00%


In [60]:
count_params(best_model)

tensor(29150)

### Iterative Global Pruning algorithm:

Take an already trained model and repear the following steps _x_ times-

- prune p% of smallest magnitude weights in an __unstructured global manner__

- fine-tune pruned model to recover performance

In [21]:
# User input parameters for Early Stopping in manual implementation-
minimum_delta = 0.001
patience = 3


# Initialize parameters for Early Stopping manual implementation-
best_val_loss = 100
loc_patience = 0

In [51]:
'''
# Python3 lists to store model training metrics-
training_acc = []
validation_acc = []
training_loss = []
validation_loss = []
'''

In [22]:
# Define global pruning rates-
prune_rates_global = [0.2, 0.3, 0.4, 0.5, 0.6]

In [None]:
'''
for iter_prune_round in range(5):
    print(prune_rates_global[iter_prune_round])
'''

In [23]:
for iter_prune_round in range(5):
    print(f"\n\nIterative Global pruning round = {iter_prune_round + 1} with global sparsity = {prune_rates_global[iter_prune_round] * 100}%")
    
    # Prune globally-
    prune.global_unstructured(
        parameters_to_prune,
        pruning_method = prune.L1Unstructured,
        amount = prune_rates_global[iter_prune_round]
    )
        
    # Print current global sparsity level-
    compute_sparsity(best_model)

    
    # Fine-training loop-
    print("\nFine-tuning pruned model to recover model's performance\n")
    
    # Initialize parameters for Early Stopping manual implementation-
    best_val_loss = 100
    loc_patience = 0
    
    for epoch in range(num_epochs):
        running_loss = 0.0
        running_corrects = 0.0
    
        if loc_patience >= patience:
            print("\n'EarlyStopping' called!\n")
            break

        running_loss, running_corrects = train_model(best_model, train_loader)
  
        # Compute training loss and accuracy for one epoch-
        epoch_loss = running_loss / len(train_dataset)
        epoch_acc = running_corrects.double() / len(train_dataset)
    
        running_loss_val, correct, total = test_model(best_model, test_loader)

        # Compute validation loss and accuracy-
        epoch_val_loss = running_loss_val / len(test_dataset)
        val_acc = 100 * (correct / total)
    
        print(f"\nepoch: {epoch + 1} training loss = {epoch_loss:.4f}, training accuracy = {epoch_acc * 100:.2f}%, val_loss = {epoch_val_loss:.4f} & val_accuracy = {val_acc:.2f}%\n")

        # Code for manual Early Stopping:
        if (epoch_val_loss < best_val_loss) and np.abs(epoch_val_loss - best_val_loss) >= minimum_delta:
        
            # update 'best_val_loss' variable to lowest loss encountered so far-
            best_val_loss = epoch_val_loss
        
            # reset 'loc_patience' variable-
            loc_patience = 0
        
            print(f"Saving model with lowest val_loss = {epoch_val_loss:.4f} for iterative pruning round = {iter_prune_round + 1}\n")
        
            # Save trained model with validation accuracy-
            # torch.save(model.state_dict, f"LeNet-300-100_Trained_{val_acc}.pth")
            torch.save(best_model.state_dict(), "LeNet-4_Trained.pth")
        
        else:  # there is no improvement in monitored metric 'val_loss'
            loc_patience += 1  # number of epochs without any improvement

        



Iterative Global pruning round = 1 with global sparsity = 20.0%
LeNet-4 Global Sparsity = 20.00%

Fine-tuning pruned model to recover model's performance


epoch: 1 training loss = 0.0088, training accuracy = 99.72%, val_loss = 0.0315 & val_accuracy = 99.09%

Saving model with lowest val_loss = 0.0315 for iterative pruning round = 1


epoch: 2 training loss = 0.0065, training accuracy = 99.77%, val_loss = 0.0350 & val_accuracy = 99.13%


epoch: 3 training loss = 0.0058, training accuracy = 99.79%, val_loss = 0.0326 & val_accuracy = 99.26%


epoch: 4 training loss = 0.0065, training accuracy = 99.78%, val_loss = 0.0366 & val_accuracy = 99.13%


'EarlyStopping' called!



Iterative Global pruning round = 2 with global sparsity = 30.0%
LeNet-4 Global Sparsity = 44.00%

Fine-tuning pruned model to recover model's performance


epoch: 1 training loss = 0.0046, training accuracy = 99.83%, val_loss = 0.0326 & val_accuracy = 99.30%

Saving model with lowest val_loss = 0.0326 for iterative pr

In [28]:
# Original unpruned weights-
best_model.state_dict()['conv1.weight_orig'][0, 0, :, :]

tensor([[-0.2519,  0.6400,  0.0337],
        [ 0.3763,  0.5416, -0.1623],
        [-0.0104, -0.2920, -0.4409]])

In [31]:
# Pruned mask-
best_model.state_dict()['conv1.weight_mask'][0, 0, :, :]

tensor([[0., 1., 0.],
        [0., 1., 0.],
        [0., 0., 0.]])

In [32]:
# Element-wise multiplication between weights & masks-
best_model.state_dict()['conv1.weight_orig'][0, 0, :, :] * best_model.state_dict()['conv1.weight_mask'][0, 0, :, :]

tensor([[-0.0000, 0.6400, 0.0000],
        [0.0000, 0.5416, -0.0000],
        [-0.0000, -0.0000, -0.0000]])

In [30]:
# Pruned weights-
best_model.conv1.weight[0, 0, :, :]

tensor([[-0.0000, 0.6400, 0.0000],
        [0.0000, 0.5416, -0.0000],
        [-0.0000, -0.0000, -0.0000]])

In [24]:
best_model.state_dict().keys()

odict_keys(['conv1.bias', 'conv1.weight_orig', 'conv1.weight_mask', 'conv2.bias', 'conv2.weight_orig', 'conv2.weight_mask', 'conv3.bias', 'conv3.weight_orig', 'conv3.weight_mask', 'op.bias', 'op.weight_orig', 'op.weight_mask'])

### Remove pruning re-parametrization:

In [33]:
prune.remove(best_model.conv1, 'weight')
prune.remove(best_model.conv2, 'weight')
prune.remove(best_model.conv3, 'weight')
prune.remove(best_model.op, 'weight')

Linear(in_features=1080, out_features=10, bias=True)

In [36]:
# list(best_model.conv1.named_parameters())

In [47]:
for layer_name in best_model.state_dict().keys():
    if 'bias' in layer_name:
        continue
    else:
        loc_sparsity = (torch.sum(best_model.state_dict()[layer_name] == 0) / best_model.state_dict()[layer_name].nelement()) * 100
        # print(f"layer = {layer_name} has {best_model.state_dict()[layer_name].shape}")
        print(f"layer = {layer_name} has {loc_sparsity:.2f}% sparsity")

layer = conv1.weight has 75.93% sparsity
layer = conv2.weight has 90.39% sparsity
layer = conv3.weight has 94.13% sparsity
layer = op.weight has 92.25% sparsity


In [49]:
for layer_name, param in best_model.named_parameters():
    print(f"layer: {layer_name} has shape: {param.shape}")

layer: conv1.bias has shape: torch.Size([6])
layer: conv1.weight has shape: torch.Size([6, 1, 3, 3])
layer: conv2.bias has shape: torch.Size([16])
layer: conv2.weight has shape: torch.Size([16, 6, 3, 3])
layer: conv3.bias has shape: torch.Size([120])
layer: conv3.weight has shape: torch.Size([120, 16, 3, 3])
layer: op.bias has shape: torch.Size([10])
layer: op.weight has shape: torch.Size([10, 1080])
