# Pruning a PyTorch Model

Note: Do not change the order of execution of the following cells; each cell relies on the model state left by the cells before it.

In [None]:
# Creating a Simple PyTorch Model
import torch
import torch.nn as nn


class ExampleModel(nn.Module):
    """Toy network with two parallel conv branches feeding one linear layer.

    Both branches read the same 3-channel input. With a 28x28 image each
    branch yields a 16x14x14 feature map; the two maps are concatenated on
    the channel axis, giving 32 * 14 * 14 = 6272 features for ``linear``.
    """

    def __init__(self):
        super().__init__()
        # The two branches share the same convolution geometry.
        branch_cfg = dict(kernel_size=3, stride=2, padding=1)
        self.conv1 = nn.Conv2d(3, 16, **branch_cfg)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(3, 16, **branch_cfg)
        self.relu2 = nn.ReLU()
        self.flatten = nn.Flatten()
        # 6272 is right for the unpruned model on 28x28 inputs; it stops
        # matching as soon as conv output channels are pruned.
        self.linear = nn.Linear(6272, 10)

    def forward(self, x):
        """Run ``x`` through both branches and return the linear output."""
        branch_a = self.relu1(self.conv1(x))
        branch_b = self.relu2(self.conv2(x))
        merged = torch.cat((branch_a, branch_b), dim=1)
        return self.linear(self.flatten(merged))


# Single shared instance: the cells below prune this object's layers in place.
model = ExampleModel()

In [2]:
# Smoke-test the unpruned model with one random 3-channel 28x28 sample.
# The tensor is named `dummy_input` (not `input`) so the builtin input()
# is not shadowed for the rest of the notebook session.
dummy_input = torch.randn(1, 3, 28, 28)
output = model(dummy_input)

print(output.shape)

torch.Size([1, 10])


In [3]:
# Prune output channels of the conv1 layer of the model and verify the model
from metinor.profiler.handlers.rms import conv_channelwise_rms, linear_weights_rms
from metinor.optimization.pruning.strategies import (
    prune_output_channels_by_rms,
    prune_input_channels_by_rms,
    change_linear_input,
    change_linear_output,
)
from metinor.optimization.pruning import verify_model
from metinor.visualization import draw_graph, calculate_max_depth

# Build the graph of the (still unpruned) model; verify_model() below is
# given this graph together with the model and a sample input.
num_channels_prune = 2
input_shape = (1, 3, 28, 28)
depth_model = calculate_max_depth(model)
graph_options = dict(
    input_size=input_shape,
    device="cpu",
    expand_nested=True,
    graph_dir="TB",
    roll=True,
    show_shapes=True,
    depth=depth_model,
    criteria=0.0254,
)
graph = draw_graph(model, **graph_options)

# Prune `num_channels_prune` output channels of conv1, using the
# per-output-channel RMS values as the selection scores.
conv1_op_rms = conv_channelwise_rms(model.conv1, rms_dim="op_channels")
pruned_mod, _, _ = prune_output_channels_by_rms(
    model.conv1, conv1_op_rms, num_channels_prune
)

# Check the pruned model against the graph with dummy data. The printed
# value identifies a node that failed, or is None when nothing failed.
failed_node = verify_model(model, torch.randn(input_shape), graph)
print(failed_node)

140127068521520-




In [4]:
# The verification above reports one failed node: the linear layer, whose
# in_features no longer matches the flattened, concatenated conv output.

# Shrink the linear layer's input side. 392 = [6272 / (16 * 2)] * 2, i.e.
# the 196 features each channel contributes times the 2 pruned channels.
# NOTE(review): 392 appears to be the number of input features removed
# (leaving 6272 - 392 = 5880), not the resulting in_features -- confirm
# against the change_linear_input documentation.
linear_ip_rms = linear_weights_rms(model.linear, rms_dim="ip")
pruned_mod, _, _ = change_linear_input(model.linear, linear_ip_rms, 392)


# Re-check the model with dummy data; None means no node failed.
failed_node = verify_model(model, torch.randn(input_shape), graph)
print(failed_node)

None


In [5]:
# Smoke-test the model again after conv1 and the linear layer were adjusted.
# The tensor is named `dummy_input` (not `input`) so the builtin input()
# is not shadowed for the rest of the notebook session.
dummy_input = torch.randn(1, 3, 28, 28)
output = model(dummy_input)

print(output.shape)

torch.Size([1, 10])


In [6]:
# To motivate conv input-channel pruning, feed a single-channel input.
# The tensor is named `dummy_input` (not `input`) so the builtin input()
# is not shadowed for the rest of the notebook session.
dummy_input = torch.randn(1, 1, 28, 28)

# conv1 and conv2 still expect 3 input channels at this point, so this
# forward pass is expected to raise a RuntimeError.
output = model(dummy_input)

RuntimeError: Given groups=1, weight of size [14, 3, 3, 3], expected input[1, 1, 28, 28] to have 3 channels, but got 1 channels instead

In [7]:
# To handle this, prune input channels of both conv1 and conv2. Each branch
# reads the same input tensor, so both get the same number of input
# channels removed (2 of their 3), scored by per-input-channel RMS.
num_channels_prune = 2
for conv in (model.conv1, model.conv2):
    pruned_mod, _, _ = prune_input_channels_by_rms(
        conv,
        conv_channelwise_rms(conv, rms_dim="ip_channels"),
        num_channels_prune,
    )

In [8]:
# Confirm the conv input-channel pruning: a single-channel input should now
# pass through the model.
# The tensor is named `dummy_input` (not `input`) so the builtin input()
# is not shadowed for the rest of the notebook session.
dummy_input = torch.randn(1, 1, 28, 28)

# Forward pass through the pruned model with dummy data.
output = model(dummy_input)