<a href="https://colab.research.google.com/github/IlvaX/ProjectUnet/blob/main/Global_pruning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Reference: https://pytorch.org/tutorials/intermediate/pruning_tutorial.html

In [None]:
import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F

In [None]:
from tensorflow.keras.utils import normalize
import os
import cv2
from PIL import Image
import numpy as np
from matplotlib import pyplot as plt
from sklearn.preprocessing import MinMaxScaler
from keras.optimizers import Adam
import glob

import tensorflow as tf
from tensorflow import keras
%matplotlib inline


In [None]:
import torch
import torch.nn as nn

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

class UNet(nn.Module):
    """Compact three-level U-Net for dense (per-pixel) classification.

    Expects input of shape (N, in_channels, H, W) with H and W divisible by 8:
    the encoder halves the spatial size three times and the decoder doubles it
    three times, and each skip connection concatenates two tensors that must
    have the same spatial size.

    Returns a tensor of shape (N, num_classes, H, W) holding a softmax over the
    class dimension (dim=1).

    Each decoder stage concatenates its output with the encoder feature map of
    the same resolution along the channel axis. The layer that consumes that
    concatenation is therefore sized for the *combined* channel count.

    Args:
        in_channels: number of channels in the input image (default 1).
        num_classes: number of output channels / classes (default 10).
    """

    def __init__(self, in_channels=1, num_classes=10):
        super(UNet, self).__init__()

        # Encoder: conv -> BN -> ReLU at each level. The pre-pool map is kept
        # in forward() for the skip connection; the pooled map goes deeper.
        self.encoder_conv1 = nn.Conv2d(in_channels, 64, 3, padding=1)
        self.encoder_bn1 = nn.BatchNorm2d(64)
        self.encoder_relu1 = nn.ReLU(inplace=True)
        self.encoder_pool1 = nn.MaxPool2d(2, 2)

        self.encoder_conv2 = nn.Conv2d(64, 128, 3, padding=1)
        self.encoder_bn2 = nn.BatchNorm2d(128)
        self.encoder_relu2 = nn.ReLU(inplace=True)
        self.encoder_pool2 = nn.MaxPool2d(2, 2)

        self.encoder_conv3 = nn.Conv2d(128, 256, 3, padding=1)
        self.encoder_bn3 = nn.BatchNorm2d(256)
        self.encoder_relu3 = nn.ReLU(inplace=True)
        self.encoder_pool3 = nn.MaxPool2d(2, 2)

        # Bridge (bottleneck) at 1/8 resolution.
        self.bridge_conv = nn.Conv2d(256, 512, 3, padding=1)
        self.bridge_bn = nn.BatchNorm2d(512)
        self.bridge_relu = nn.ReLU(inplace=True)

        # Decoder. Input channels account for the preceding concatenation.
        self.decoder_upsample1 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        # Input: upsampled bridge (512 channels); no concatenation yet.
        self.decoder_conv1 = nn.Conv2d(512, 256, 3, padding=1)
        self.decoder_bn1 = nn.BatchNorm2d(256)
        self.decoder_relu1 = nn.ReLU(inplace=True)

        self.decoder_upsample2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        # Input: dec1 (256) concatenated with enc3 (256).
        self.decoder_conv2 = nn.Conv2d(256 + 256, 128, 3, padding=1)
        self.decoder_bn2 = nn.BatchNorm2d(128)
        self.decoder_relu2 = nn.ReLU(inplace=True)

        self.decoder_upsample3 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        # Input: dec2 (128) concatenated with enc2 (128).
        self.decoder_conv3 = nn.Conv2d(128 + 128, 64, 3, padding=1)
        self.decoder_bn3 = nn.BatchNorm2d(64)
        self.decoder_relu3 = nn.ReLU(inplace=True)

        # 1x1 projection to class scores. Input: dec3 (64) concatenated with enc1 (64).
        self.output_conv = nn.Conv2d(64 + 64, num_classes, 1)
        # NOTE(review): a softmax output suits inference. If this model is
        # trained with nn.CrossEntropyLoss (which expects raw logits), this
        # layer should be bypassed during training.
        self.output_softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """Map an (N, in_channels, H, W) batch to (N, num_classes, H, W) class probabilities."""
        # Encoder: full, 1/2 and 1/4 resolution feature maps are kept for the skips.
        enc1 = self.encoder_relu1(self.encoder_bn1(self.encoder_conv1(x)))
        enc1_pool = self.encoder_pool1(enc1)

        enc2 = self.encoder_relu2(self.encoder_bn2(self.encoder_conv2(enc1_pool)))
        enc2_pool = self.encoder_pool2(enc2)

        enc3 = self.encoder_relu3(self.encoder_bn3(self.encoder_conv3(enc2_pool)))
        enc3_pool = self.encoder_pool3(enc3)

        # Bridge at 1/8 resolution.
        bridge = self.bridge_relu(self.bridge_bn(self.bridge_conv(enc3_pool)))

        # Decoder: upsample, convolve, then concatenate the matching encoder map.
        dec1 = self.decoder_relu1(self.decoder_bn1(self.decoder_conv1(self.decoder_upsample1(bridge))))
        dec1_concat = torch.cat((dec1, enc3), dim=1)  # 256 + 256 channels, 1/4 res

        dec2 = self.decoder_relu2(self.decoder_bn2(self.decoder_conv2(self.decoder_upsample2(dec1_concat))))
        dec2_concat = torch.cat((dec2, enc2), dim=1)  # 128 + 128 channels, 1/2 res

        dec3 = self.decoder_relu3(self.decoder_bn3(self.decoder_conv3(self.decoder_upsample3(dec2_concat))))
        dec3_concat = torch.cat((dec3, enc1), dim=1)  # 64 + 64 channels, full res

        # Output: per-pixel class probabilities.
        output = self.output_conv(dec3_concat)
        output = self.output_softmax(output)

        return output

# Build the network, then move its parameters and buffers to the chosen device.
model = UNet()
model = model.to(device)


In [None]:
# Demo: randomly zero 30% of the first encoder convolution's weights.
# `module` was never defined in this notebook; bind it explicitly to the layer
# whose repr appears as this cell's output. (The next cell rebinds `model` to a
# fresh UNet, so this demo does not affect the global-pruning experiment.)
module = model.encoder_conv1
prune.random_unstructured(module, name="weight", amount=0.3)

Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

In [None]:
model = UNet()

# The eight convolution layers whose weights take part in global pruning,
# listed in the order they are registered in UNet.__init__.
conv_layer_names = (
    "encoder_conv1",
    "encoder_conv2",
    "encoder_conv3",
    "bridge_conv",
    "decoder_conv1",
    "decoder_conv2",
    "decoder_conv3",
    "output_conv",
)
parameters_to_prune = tuple(
    (getattr(model, layer_name), "weight") for layer_name in conv_layer_names
)

# Zero the 20% of weights with the smallest absolute value, ranked jointly
# across all listed layers (one global threshold, not 20% of each layer).
prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.2,
)

In [None]:
# Report the percentage of zero-valued weights in each pruned convolution and
# across all of them combined. After pruning, `layer.weight` is the masked
# tensor, so counting exact zeros counts the pruned entries.
conv_layer_names = (
    "encoder_conv1",
    "encoder_conv2",
    "encoder_conv3",
    "bridge_conv",
    "decoder_conv1",
    "decoder_conv2",
    "decoder_conv3",
    "output_conv",
)

total_zeros = 0
total_weights = 0
for layer_name in conv_layer_names:
    weight = getattr(model, layer_name).weight
    zeros = int(torch.sum(weight == 0))
    count = weight.nelement()
    total_zeros += zeros
    total_weights += count
    # Label each line with the real layer name (the output layer was
    # previously mislabelled "fc3.weight", a leftover from the tutorial).
    print(
        "Sparsity in {}.weight: {:.2f}%".format(
            layer_name, 100. * float(zeros) / float(count)
        )
    )

print(
    "Global sparsity: {:.2f}%".format(
        100. * float(total_zeros) / float(total_weights)
    )
)

Sparsity in conv1.weight: 1.39%
Sparsity in conv2.weight: 9.07%
Sparsity in fc1.weight: 12.72%
Sparsity in fc2.weight: 17.99%
Sparsity in fc3.weight: 25.47%
Sparsity in fc3.weight: 18.06%
Sparsity in fc3.weight: 12.87%
Sparsity in fc3.weight: 2.50%
Global sparsity: 20.00%


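Sparsity is the percentage of weights (parameters) in a layer, or in the whole network, that are exactly zero. Pruning sets a chosen fraction of the weights to zero through a mask, effectively removing them from the network, so the sparsity of the pruned parameters measures how much of the model has been pruned. Because `global_unstructured` ranks the weights of all listed layers together against a single L1-magnitude threshold, layers whose weights are smaller in magnitude lose more of them. As a result, per-layer sparsity varies from layer to layer, while the global sparsity matches the requested `amount` (20%).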