
1. Define and initialize your neural network model.

2. Record the weights of the model before pruning.

3. Perform the pruning process on the model, setting some weights to zero.

4. Fine-tune the pruned model if needed.

5. Record the weights of the model after pruning.

6. Calculate the weight changes by subtracting the weights after pruning from the weights before pruning.

7. Analyze the weight changes to understand the impact of pruning on the model's weights.

This process will help you observe how pruning has affected the weights in your neural network.

Import the necessary libraries, including PyTorch, to work with your neural network model.

In [44]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.utils.prune as prune

**Create the Neural Network Model:**

Define your neural network model by subclassing PyTorch's nn.Module. This example uses a small feedforward network with three fully connected layers (64 input features → 128 → 64 → 10 outputs) and ReLU activations between them.

In [45]:
class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(64, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

**Initialize the Model:**

Create an instance of the network. Constructing it also initializes every weight and bias with PyTorch's default nn.Linear initialization.

In [46]:
model = NeuralNetwork()
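As a quick check on the layer sizes, the sketch below counts the entries in each weight matrix. So that it runs on its own, it builds a stand-in for NeuralNetwork with nn.Sequential and named layers (an assumption, not the notebook's class); these counts are the pool the 20% pruning budget is later taken from.

```python
from collections import OrderedDict

import torch.nn as nn

# Stand-in for NeuralNetwork: same layer names and sizes, built with nn.Sequential
model = nn.Sequential(OrderedDict([
    ("fc1", nn.Linear(64, 128)), ("relu1", nn.ReLU()),
    ("fc2", nn.Linear(128, 64)), ("relu2", nn.ReLU()),
    ("fc3", nn.Linear(64, 10)),
]))

# nn.Linear(in, out) stores its weight as an (out, in) matrix, so fc1 holds 128 x 64 values
weight_counts = {
    name: param.numel()
    for name, param in model.named_parameters()
    if name.endswith("weight")  # skip the bias vectors
}
for name, count in weight_counts.items():
    print(f"{name}: {count}")
```

fc1 and fc2 each hold 8,192 weights and fc3 holds 640.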

In [47]:
# Print the weights before pruning
print("Weights before pruning:")
for name, param in model.named_parameters():
    if 'weight' in name:  # Keep weight tensors, skip biases
        print(f"{name}: {param.data}")

Weights before pruning:
fc1.weight: tensor([[ 0.0193,  0.0884,  0.0630,  ..., -0.0782,  0.0537,  0.0424],
        [ 0.0483, -0.1067,  0.0960,  ..., -0.0868, -0.0355,  0.0917],
        [ 0.1136, -0.1091,  0.1016,  ...,  0.1048, -0.0786, -0.0734],
        ...,
        [ 0.0657, -0.0409,  0.1167,  ...,  0.1117, -0.0924,  0.0965],
        [ 0.0360,  0.0062,  0.0316,  ..., -0.1010,  0.0425,  0.0933],
        [ 0.0729, -0.0716,  0.1095,  ...,  0.0165,  0.0056, -0.0401]])
fc2.weight: tensor([[ 0.0144,  0.0135, -0.0039,  ...,  0.0365,  0.0299, -0.0451],
        [ 0.0714, -0.0416,  0.0592,  ..., -0.0413, -0.0011, -0.0872],
        [ 0.0005, -0.0438,  0.0728,  ...,  0.0213, -0.0381, -0.0129],
        ...,
        [-0.0072,  0.0538, -0.0630,  ..., -0.0628, -0.0199,  0.0584],
        [ 0.0583,  0.0012, -0.0753,  ...,  0.0472, -0.0534, -0.0863],
        [ 0.0077, -0.0484,  0.0137,  ..., -0.0730,  0.0441,  0.0255]])
fc3.weight: tensor([[-0.0846, -0.0899, -0.0101, -0.0182, -0.1071,  0.1141,  0.0585, 

In [48]:
# Record the weights before pruning
weights_before_pruning = {name: param.data.clone() for name, param in model.named_parameters() if 'weight' in name}


**Define the Pruning Method:**

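PyTorch's torch.nn.utils.prune module provides several pruning methods, such as L1Unstructured (remove the weights with the smallest absolute values), RandomUnstructured (remove weights at random), and structured variants such as LnStructured; you can also define custom methods. In this example, we apply L1Unstructured globally across fc1 and fc2 with prune.global_unstructured.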
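For contrast with the global approach used in this notebook, the same methods are also available as per-layer (local) functions, prune.l1_unstructured and prune.random_unstructured. A minimal sketch, again using a hypothetical nn.Sequential stand-in for NeuralNetwork: each call prunes 20% of that one layer's weights, so both layers end up with the same count.

```python
from collections import OrderedDict

import torch.nn as nn
import torch.nn.utils.prune as prune

# Stand-in for NeuralNetwork: same layer names and sizes
model = nn.Sequential(OrderedDict([
    ("fc1", nn.Linear(64, 128)), ("relu1", nn.ReLU()),
    ("fc2", nn.Linear(128, 64)), ("relu2", nn.ReLU()),
    ("fc3", nn.Linear(64, 10)),
]))

# Local pruning: the 20% is computed per layer, not across layers
prune.l1_unstructured(model.fc1, name="weight", amount=0.2)      # smallest |w| in fc1
prune.random_unstructured(model.fc2, name="weight", amount=0.2)  # random entries of fc2

# Each call registers a binary weight_mask buffer; zeros mark pruned positions
pruned_counts = {}
for layer_name in ("fc1", "fc2"):
    mask = getattr(model, layer_name).weight_mask
    pruned_counts[layer_name] = int((mask == 0).sum())
    print(f"{layer_name}: {pruned_counts[layer_name]} of {mask.numel()} weights pruned")
```

Both layers report 1,638 pruned weights, which is 20% of 8,192 rounded to the nearest integer.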

In [49]:
parameters_to_prune = (
    (model.fc1, 'weight'),
    (model.fc2, 'weight'),
)
prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.2  # Prune 20% of the combined fc1 and fc2 weights
)

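This code pools the weights of fc1 and fc2 (8,192 each, 16,384 in total) and zeroes the 20% with the smallest absolute values: round(0.2 × 16,384) = 3,277 weights. Because the magnitude threshold is shared across both layers, each layer's own sparsity can land above or below 20%. fc3 and all biases are left untouched.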
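To see how the global budget is split between the two layers, the sketch below repeats the pruning call on a hypothetical nn.Sequential stand-in for NeuralNetwork and counts the zeros in each registered mask. The seed is arbitrary and only there for repeatability.

```python
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)  # arbitrary seed, for repeatability only

# Stand-in for NeuralNetwork: same layer names and sizes
model = nn.Sequential(OrderedDict([
    ("fc1", nn.Linear(64, 128)), ("relu1", nn.ReLU()),
    ("fc2", nn.Linear(128, 64)), ("relu2", nn.ReLU()),
    ("fc3", nn.Linear(64, 10)),
]))

layers = {"fc1": model.fc1, "fc2": model.fc2}
prune.global_unstructured(
    [(layer, "weight") for layer in layers.values()],
    pruning_method=prune.L1Unstructured,
    amount=0.2,
)

# Count pruned positions per layer from the masks that global_unstructured registered
per_layer = {name: int((layer.weight_mask == 0).sum()) for name, layer in layers.items()}
total_pruned = sum(per_layer.values())
total_weights = sum(layer.weight_mask.numel() for layer in layers.values())

for name, count in per_layer.items():
    share = count / layers[name].weight_mask.numel()
    print(f"{name}: {count} pruned ({share:.1%} of the layer)")
print(f"total: {total_pruned} of {total_weights}")
```

The total is always 3,277, but the split is not even. PyTorch's default nn.Linear initialization draws weights from a range that shrinks with the layer's fan-in, so fc2 (fan-in 128) starts with smaller weights than fc1 (fan-in 64) and typically loses a larger share. This is consistent with the printed outputs, where fc2 shows more zeros than fc1.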

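**Fine-Tune the Pruned Model:**

After pruning, it's good practice to fine-tune the model to recover some of the lost accuracy. Train the model on your dataset as you normally would, preferably before calling prune.remove: while the pruning reparametrization is active, the mask is reapplied on every forward pass, so the pruned weights stay at zero during training.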
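A minimal fine-tuning sketch under stated assumptions: the random inputs and labels below are hypothetical placeholders for a real dataset, the model is an nn.Sequential stand-in for NeuralNetwork, and the optimizer settings are arbitrary. The point it illustrates is that the optimizer updates weight_orig, while the mask keeps the pruned entries of weight at zero.

```python
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)  # arbitrary seed, for repeatability only

# Stand-in for NeuralNetwork: same layer names and sizes
model = nn.Sequential(OrderedDict([
    ("fc1", nn.Linear(64, 128)), ("relu1", nn.ReLU()),
    ("fc2", nn.Linear(128, 64)), ("relu2", nn.ReLU()),
    ("fc3", nn.Linear(64, 10)),
]))
prune.global_unstructured(
    [(model.fc1, "weight"), (model.fc2, "weight")],
    pruning_method=prune.L1Unstructured,
    amount=0.2,
)

# Hypothetical data standing in for a real dataset: 256 samples, 10 classes
inputs = torch.randn(256, 64)
targets = torch.randint(0, 10, (256,))

# model.parameters() now yields fc1.weight_orig / fc2.weight_orig instead of the weights
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
loss_fn = nn.CrossEntropyLoss()

model.train()
for epoch in range(5):
    for start in range(0, len(inputs), 32):
        x, y = inputs[start:start + 32], targets[start:start + 32]
        optimizer.zero_grad()
        loss = loss_fn(model(x), y)  # forward pre-hook recomputes weight = weight_orig * mask
        loss.backward()
        optimizer.step()

# Pruned positions are still zero after training
for layer in (model.fc1, model.fc2):
    assert torch.all(layer.weight[layer.weight_mask == 0] == 0)
print(f"final batch loss: {loss.item():.4f}")
```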

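**Remove Pruning Masks (Optional):**

While pruning is active, each pruned layer keeps its original values in a weight_orig parameter plus a binary weight_mask buffer, and recomputes weight = weight_orig * weight_mask before every forward pass. prune.remove makes the pruning permanent: it writes the masked values back into an ordinary weight parameter and deletes weight_orig and weight_mask. The step is optional in general, but this notebook depends on it: without it, named_parameters() would return fc1.weight_orig (the unmasked values) rather than fc1.weight, and the before/after comparison further down would not line up by name.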

In [50]:
for module, _ in parameters_to_prune:
    prune.remove(module, 'weight')
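To make the reparametrization visible, the sketch below lists parameter and buffer names while pruning is active and again after prune.remove. It uses a hypothetical nn.Sequential stand-in for NeuralNetwork with the same layer names.

```python
from collections import OrderedDict

import torch.nn as nn
import torch.nn.utils.prune as prune

# Stand-in for NeuralNetwork: same layer names and sizes
model = nn.Sequential(OrderedDict([
    ("fc1", nn.Linear(64, 128)), ("relu1", nn.ReLU()),
    ("fc2", nn.Linear(128, 64)), ("relu2", nn.ReLU()),
    ("fc3", nn.Linear(64, 10)),
]))
prune.global_unstructured(
    [(model.fc1, "weight"), (model.fc2, "weight")],
    pruning_method=prune.L1Unstructured,
    amount=0.2,
)

# While pruning is active: weight_orig is the trainable parameter, weight_mask is a buffer
params_while_pruned = [name for name, _ in model.named_parameters()]
buffers_while_pruned = [name for name, _ in model.named_buffers()]
print("parameters:", params_while_pruned)
print("buffers:   ", buffers_while_pruned)

for layer in (model.fc1, model.fc2):
    prune.remove(layer, "weight")

# After prune.remove: plain weight parameters again, and no mask buffers
params_after_remove = [name for name, _ in model.named_parameters()]
print("parameters:", params_after_remove)
print("buffers:   ", [name for name, _ in model.named_buffers()])
```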

In [51]:
# Print the weights after pruning
print("\nWeights after pruning:")
for name, param in model.named_parameters():
    if 'weight' in name:  # Keep weight tensors, skip biases
        print(f"{name}: {param.data}")


Weights after pruning:
fc1.weight: tensor([[ 0.0000,  0.0884,  0.0630,  ..., -0.0782,  0.0537,  0.0424],
        [ 0.0483, -0.1067,  0.0960,  ..., -0.0868, -0.0355,  0.0917],
        [ 0.1136, -0.1091,  0.1016,  ...,  0.1048, -0.0786, -0.0734],
        ...,
        [ 0.0657, -0.0409,  0.1167,  ...,  0.1117, -0.0924,  0.0965],
        [ 0.0360,  0.0000,  0.0316,  ..., -0.1010,  0.0425,  0.0933],
        [ 0.0729, -0.0716,  0.1095,  ...,  0.0000,  0.0000, -0.0401]])
fc2.weight: tensor([[ 0.0000,  0.0000, -0.0000,  ...,  0.0365,  0.0299, -0.0451],
        [ 0.0714, -0.0416,  0.0592,  ..., -0.0413, -0.0000, -0.0872],
        [ 0.0000, -0.0438,  0.0728,  ...,  0.0213, -0.0381, -0.0000],
        ...,
        [-0.0000,  0.0538, -0.0630,  ..., -0.0628, -0.0000,  0.0584],
        [ 0.0583,  0.0000, -0.0753,  ...,  0.0472, -0.0534, -0.0863],
        [ 0.0000, -0.0484,  0.0000,  ..., -0.0730,  0.0441,  0.0255]])
fc3.weight: tensor([[-0.0846, -0.0899, -0.0101, -0.0182, -0.1071,  0.1141,  0.0585, 

In [52]:
# Record the weights after pruning
weights_after_pruning = {name: param.data.clone() for name, param in model.named_parameters() if 'weight' in name}


In [53]:
# Compare the weight changes
print("Weight changes after pruning:")
for name in weights_before_pruning.keys():
    weight_diff = weights_before_pruning[name] - weights_after_pruning[name]
    print(f"{name} change:")
    print(weight_diff)

Weight changes after pruning:
fc1.weight change:
tensor([[0.0193, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        ...,
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0062, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0165, 0.0056, 0.0000]])
fc2.weight change:
tensor([[ 0.0144,  0.0135, -0.0039,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000, -0.0011,  0.0000],
        [ 0.0005,  0.0000,  0.0000,  ...,  0.0000,  0.0000, -0.0129],
        ...,
        [-0.0072,  0.0000,  0.0000,  ...,  0.0000, -0.0199,  0.0000],
        [ 0.0000,  0.0012,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0077,  0.0000,  0.0137,  ...,  0.0000,  0.0000,  0.0000]])
fc3.weight change:
tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
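Printing whole tensors only shows a corner of each matrix. For step 7, a summary is often easier to read. The sketch below, built on a hypothetical nn.Sequential stand-in for NeuralNetwork, counts the changed entries per layer, checks that every change is a weight set to zero, and compares the largest removed magnitude with the smallest surviving one to expose the single global threshold.

```python
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)  # arbitrary seed, for repeatability only

# Stand-in for NeuralNetwork: same layer names and sizes
model = nn.Sequential(OrderedDict([
    ("fc1", nn.Linear(64, 128)), ("relu1", nn.ReLU()),
    ("fc2", nn.Linear(128, 64)), ("relu2", nn.ReLU()),
    ("fc3", nn.Linear(64, 10)),
]))

def snapshot(m):
    # Detached copies of every weight matrix, keyed by parameter name
    return {n: p.detach().clone() for n, p in m.named_parameters() if n.endswith("weight")}

weights_before = snapshot(model)
prune.global_unstructured(
    [(model.fc1, "weight"), (model.fc2, "weight")],
    pruning_method=prune.L1Unstructured,
    amount=0.2,
)
for layer in (model.fc1, model.fc2):
    prune.remove(layer, "weight")
weights_after = snapshot(model)

changed_counts = {}
for name, before in weights_before.items():
    after = weights_after[name]
    changed = (before - after) != 0
    changed_counts[name] = int(changed.sum())
    # Pruning only zeroes weights, so every changed entry must now be zero
    assert torch.all(after[changed] == 0)
    sparsity = (after == 0).float().mean().item()
    print(f"{name}: {changed_counts[name]} weights zeroed, sparsity {sparsity:.1%}")

# Global L1 pruning removes the smallest magnitudes across fc1 and fc2 together
pruned_names = ("fc1.weight", "fc2.weight")
removed = torch.cat([weights_before[n][weights_after[n] == 0].abs() for n in pruned_names])
kept = torch.cat([weights_after[n][weights_after[n] != 0].abs() for n in pruned_names])
print(f"largest removed |w|: {removed.max().item():.4f}")
print(f"smallest kept |w|:   {kept.min().item():.4f}")
```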

**Save the Pruned Model:**

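You can save the pruned model to a file for future use or deployment. Because prune.remove has already been called, the state_dict uses the ordinary parameter names (fc1.weight and so on) and loads into a fresh NeuralNetwork. Keep in mind that torch.save stores the zeros like any other value, so the file is about the same size as an unpruned checkpoint.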

In [54]:
torch.save(model.state_dict(), 'pruned_model.pth')
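To confirm the saved checkpoint round-trips, the sketch below prunes, saves, reloads into a fresh model, and counts the zeros that survived. It assumes a hypothetical make_model() helper that builds an nn.Sequential stand-in for NeuralNetwork, and it writes to a temporary directory rather than the working directory.

```python
from collections import OrderedDict
import os
import tempfile

import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

def make_model():
    # Hypothetical helper: stand-in for NeuralNetwork with the same layer names and sizes
    return nn.Sequential(OrderedDict([
        ("fc1", nn.Linear(64, 128)), ("relu1", nn.ReLU()),
        ("fc2", nn.Linear(128, 64)), ("relu2", nn.ReLU()),
        ("fc3", nn.Linear(64, 10)),
    ]))

torch.manual_seed(0)  # arbitrary seed, for repeatability only
model = make_model()
prune.global_unstructured(
    [(model.fc1, "weight"), (model.fc2, "weight")],
    pruning_method=prune.L1Unstructured,
    amount=0.2,
)
for layer in (model.fc1, model.fc2):
    prune.remove(layer, "weight")

path = os.path.join(tempfile.mkdtemp(), "pruned_model.pth")
torch.save(model.state_dict(), path)

# After prune.remove the keys are plain "fc1.weight" etc., so a fresh model accepts them
restored = make_model()
restored.load_state_dict(torch.load(path))

zero_counts = {
    name: int((tensor == 0).sum())
    for name, tensor in restored.state_dict().items()
    if name.endswith("weight")
}
print(zero_counts)
print(f"checkpoint size: {os.path.getsize(path)} bytes")  # zeros are stored as regular float32 values
```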
