<a href="https://colab.research.google.com/github/DEEP-CGPS/P-CNN-P/blob/master/Pruning_CNN_Pytorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

First, the libraries needed for the execution of the algorithm are imported

In [None]:
from torch import nn
import torch.nn.utils.prune as prune
import torch
from torchsummary import summary

Then the model class is defined. It must correspond to a sequential CNN; in this case, it is the one presented in the paper.

In [None]:
class Net(nn.Module):
    """Sequential CNN with five convolution layers and three linear layers.

    Args:
        num_classes: Number of scores produced by the last linear layer.
    """

    def __init__(self, num_classes=13):
        super().__init__()

        # Convolutional part. The position of each layer inside the Sequential
        # determines its parameter name (features.0, features.3, ...).
        conv_stack = [
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        ]
        self.features = nn.Sequential(*conv_stack)

        # Always yields a 6x6 map per channel, which is what the first linear
        # layer (256 * 6 * 6 inputs) expects.
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))

        # Fully connected part ending in one score per class.
        head = [
            nn.Dropout(),
            nn.Linear(256 * 6 * 6, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(512, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, num_classes),
        ]
        self.classifier = nn.Sequential(*head)

    def forward(self, x):
        """Return the class scores for a batch of images ``x``."""
        pooled = self.avgpool(self.features(x))
        # Flatten everything but the batch dimension. The row width is derived
        # from the tensor itself instead of being hard-coded.
        values_per_sample = pooled.nelement() // pooled.shape[0]
        flat = pooled.view(-1, values_per_sample)
        return self.classifier(flat)

If you want to use an already trained CNN, load it with the following code; otherwise, skip the next cell.

In [None]:
# torch.load unpickles the file, so only open checkpoints that come from a
# trusted source. For a network saved as a whole object, the Net class above
# must already be defined when this cell runs.
# NOTE(review): recent PyTorch releases default to weights_only=True, which
# rejects whole-model pickles; pass weights_only=False there if loading fails
# (confirm against the installed version).
# map_location allows a checkpoint that was saved on a GPU to be opened on a
# machine without one.
new_model = torch.load(
    'NetName.pth',
    map_location="cuda:0" if torch.cuda.is_available() else "cpu",
)
# Inference mode: disables dropout.
new_model.eval()

The device on which the network will run (GPU if available, otherwise CPU) is selected.

In [None]:
# Use the first CUDA GPU when one is available; fall back to the CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

The number of network parameters before pruning is displayed.

In [None]:
# Default path: build an untrained network from the class and move it to the
# selected device.
new_model = Net().to(device)
# If a trained network was loaded instead, comment out the line above and
# uncomment the line below so that the loaded network is the one that is used.
# new_model = new_model.to(device)

# Per-layer output shapes and parameter counts for a 3x224x224 input.
print(summary(new_model, (3, 224, 224)))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 55, 55]          23,296
              ReLU-2           [-1, 64, 55, 55]               0
         MaxPool2d-3           [-1, 64, 27, 27]               0
            Conv2d-4          [-1, 192, 27, 27]         307,392
              ReLU-5          [-1, 192, 27, 27]               0
         MaxPool2d-6          [-1, 192, 13, 13]               0
            Conv2d-7          [-1, 384, 13, 13]         663,936
              ReLU-8          [-1, 384, 13, 13]               0
            Conv2d-9          [-1, 256, 13, 13]         884,992
             ReLU-10          [-1, 256, 13, 13]               0
           Conv2d-11          [-1, 256, 13, 13]         590,080
             ReLU-12          [-1, 256, 13, 13]               0
        MaxPool2d-13            [-1, 256, 6, 6]               0
AdaptiveAvgPool2d-14            [-1, 25

Four functions are defined that restructure the convolutional and fully connected layers of the network during pruning.

In [None]:
def RESconv(module):
    """Strip the pruned (all-zero) output filters from a convolution layer.

    Args:
        module: Layer whose ``weight`` holds one filter per index along dim 0
            and whose ``bias`` holds one entry per filter. ``bias`` must not
            be ``None``.

    Returns:
        tuple: ``(xan, y, dw, clist)`` where ``xan`` is the weight restricted
        to the surviving filters, ``y`` the bias entries of those same filters,
        ``dw`` the number of surviving filters and ``clist`` the ascending list
        of removed filter indices.

    Side effects:
        The bias entries of the removed filters are set to zero inside
        ``module``. ``module.weight`` is not modified; installing ``xan`` and
        ``y`` on the layer is left to the caller.
    """
    # Autograd is disabled because the bias is overwritten in place, which
    # PyTorch refuses for a parameter that requires grad.
    with torch.no_grad():
        weight = module.weight
        # A filter counts as pruned only when every one of its weights is zero.
        # Testing the sum instead would also discard live filters whose
        # weights happen to cancel out.
        keep = (weight != 0).flatten(1).any(dim=1)
        clist = [idx for idx, alive in enumerate(keep.tolist()) if not alive]
        dw = len(weight) - len(clist)
        module.bias[~keep] = 0
        # One boolean mask selects both tensors, so weight and bias always stay
        # aligned, even if a surviving filter has a bias of exactly 0.
        xan = weight[keep]
        y = module.bias[keep]
    return xan, y, dw, clist

def CerosConv(module, clist, dw):
    """Remove input channels from a convolution layer.

    Intended for the layer that follows a pruned convolution: the channels
    listed in ``clist`` no longer exist in its input.

    Args:
        module: Convolution whose ``weight`` has its input channels on dim 1.
        clist: Indices of the input channels to remove.
        dw: Number of input channels that remain; stored in
            ``module.in_channels``.

    Returns:
        The weight tensor without the channels in ``clist``; the number of
        filters and the kernel size are unchanged.

    Side effects:
        ``module.in_channels`` is set to ``dw`` and the removed channels are
        zeroed inside ``module.weight``. Installing the returned tensor on the
        layer is left to the caller.
    """
    module.in_channels = dw
    # Autograd is disabled because the weight is overwritten in place, which
    # PyTorch refuses for a parameter that requires grad.
    with torch.no_grad():
        weight = module.weight
        # Channels are dropped by index, not by testing for zeros, so a live
        # kernel slice whose values sum to zero is never removed by accident.
        keep = torch.ones(weight.shape[1], dtype=torch.bool, device=weight.device)
        removed = torch.tensor(clist, dtype=torch.long, device=weight.device)
        keep[removed] = False
        weight[:, ~keep] = 0
        nx = weight[:, keep]
    return nx

def RESFC(module):
    """Strip the pruned (all-zero) output neurons from a fully connected layer.

    Args:
        module: Layer whose ``weight`` holds one row per output neuron and
            whose ``bias`` holds one entry per neuron. ``bias`` must not be
            ``None``.

    Returns:
        tuple: ``(nx, y, dw, clist)`` where ``nx`` is the weight restricted to
        the surviving rows, ``y`` the bias entries of those same rows, ``dw``
        the number of surviving neurons and ``clist`` the ascending list of
        removed neuron indices.

    Side effects:
        The bias entries of the removed neurons are set to zero inside
        ``module``. ``module.weight`` is not modified; installing ``nx`` and
        ``y`` on the layer is left to the caller.
    """
    # Autograd is disabled because the bias is overwritten in place, which
    # PyTorch refuses for a parameter that requires grad.
    with torch.no_grad():
        weight = module.weight
        # A neuron counts as pruned only when its whole weight row is zero.
        # Testing the row sum instead would also discard live neurons whose
        # weights happen to cancel out.
        keep = (weight != 0).flatten(1).any(dim=1)
        clist = [idx for idx, alive in enumerate(keep.tolist()) if not alive]
        dw = len(weight) - len(clist)
        module.bias[~keep] = 0
        # One boolean mask selects both tensors, so weight and bias always stay
        # aligned, even if a surviving neuron has a bias of exactly 0.
        nx = weight[keep]
        y = module.bias[keep]
    return nx, y, dw, clist

def CerosFC(module, clist, imdfc1, fc1):
    """Remove input columns from a fully connected layer.

    Args:
        module: Layer whose ``weight`` has one column per input feature.
        clist: Indices removed from the preceding layer. Their meaning depends
            on ``fc1``.
        imdfc1: Side length of the square feature map that is flattened in
            front of the layer; only used when ``fc1 == 1``.
        fc1: ``1`` when ``clist`` holds channel indices of a flattened
            feature map, in which case each index ``c`` stands for the
            ``imdfc1 * imdfc1`` consecutive columns starting at
            ``c * imdfc1 * imdfc1``. Any other value means ``clist`` already
            holds column indices.

    Returns:
        The weight tensor without the removed columns; the number of rows is
        unchanged.

    Side effects:
        The removed columns are zeroed inside ``module.weight``.
        ``module.in_features`` is not touched; updating it and installing the
        returned tensor is left to the caller.
    """
    if fc1 == 1:
        dim = imdfc1 * imdfc1
        nclist = [c * dim + offset for c in clist for offset in range(dim)]
    else:
        nclist = list(clist)

    # Autograd is disabled because the weight is overwritten in place, which
    # PyTorch refuses for a parameter that requires grad.
    with torch.no_grad():
        weight = module.weight
        # Columns are dropped by index, not by testing for zeros, so a weight
        # that is legitimately 0 in a kept column does not shorten its row.
        keep = torch.ones(weight.shape[1], dtype=torch.bool, device=weight.device)
        removed = torch.tensor(nclist, dtype=torch.long, device=weight.device)
        keep[removed] = False
        weight[:, ~keep] = 0
        nx = weight[:, keep]
    return nx

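In this stage, structured pruning based on the L2 norm is applied to every convolutional and fully connected layer of the CNN except the output layer. After each layer is pruned, the functions defined above restructure it, together with the layer that follows it, so that the removed filters and neurons are actually deleted from the network.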

In [None]:
# Side length of the square feature map that is flattened in front of the first
# Linear layer. Change it if the network does not end its convolutional part
# with a 6x6 map.
imdfc1 = 6
# Fraction of filters / neurons removed from every pruned layer (0.8 = 80%).
pp = 0.8

# Number of modules reported by named_modules(); printed for reference.
dl = sum(1 for _ in new_model.named_modules())
print(dl)

# The last Linear layer produces the class scores, so its neurons are never
# pruned. It is located explicitly rather than assumed to be the very last
# module, so a module placed after it cannot cause it to be pruned.
linear_layers = [m for m in new_model.modules() if isinstance(m, torch.nn.Linear)]
output_layer = linear_layers[-1] if linear_layers else None

# The restructuring helpers overwrite entries of the layer parameters in place;
# PyTorch rejects that for tensors that require grad unless autograd is off.
with torch.no_grad():
    # Pass 1: convolution layers, then the first Linear layer after them.
    # cont: 0 = no convolution pruned yet, 1 = the previous convolution was
    # pruned (clist / dw describe it), 2 = the first Linear layer was adapted.
    cont = 0
    for name, module in new_model.named_modules():
        if isinstance(module, torch.nn.Conv2d):
            if cont == 1:
                # Drop the input channels removed from the previous convolution.
                nx = CerosConv(module, clist, dw)
                module.weight = nn.Parameter(nx)
            # Zero the filters with the smallest L2 norm, make that permanent,
            # then physically remove them from the layer.
            prune.ln_structured(module, name="weight", amount=pp, n=2, dim=0)
            prune.remove(module, 'weight')
            print(list(module.named_parameters()))
            xan, y, dw, clist = RESconv(module)
            module.out_channels = dw
            module.weight = nn.Parameter(xan)
            module.bias = nn.Parameter(y)
            cont = 1
        if isinstance(module, torch.nn.Linear) and cont == 1:
            # First Linear layer: every removed channel of the last convolution
            # corresponds to imdfc1 * imdfc1 input columns.
            module.in_features = dw * imdfc1 * imdfc1
            nx = CerosFC(module, clist, imdfc1, 1)
            module.weight = nn.Parameter(nx)
            cont = 2

    # Pass 2: Linear layers. cont == 1 means the previous Linear layer was
    # pruned and clist / dw describe the neurons that were removed from it.
    cont = 0
    for name, module in new_model.named_modules():
        if isinstance(module, torch.nn.Linear):
            if cont == 1:
                # Drop the input columns fed by the removed neurons.
                module.in_features = dw
                nx = CerosFC(module, clist, imdfc1, 0)
                module.weight = nn.Parameter(nx)
            if module is not output_layer:
                prune.ln_structured(module, name="weight", amount=pp, n=2, dim=0)
                prune.remove(module, 'weight')
                nx, y, dw, clist = RESFC(module)
                module.out_features = dw
                module.weight = nn.Parameter(nx)
                module.bias = nn.Parameter(y)
                cont = 1

24
[('bias', Parameter containing:
tensor([ 0.0194, -0.0263, -0.0378,  0.0312,  0.0278, -0.0241, -0.0129,  0.0387,
         0.0156,  0.0376,  0.0302, -0.0297,  0.0035, -0.0343, -0.0197,  0.0506,
        -0.0373,  0.0040, -0.0087,  0.0237,  0.0370, -0.0123,  0.0467,  0.0211,
         0.0160, -0.0337,  0.0206, -0.0303, -0.0074,  0.0032,  0.0125, -0.0362,
        -0.0257,  0.0477, -0.0209,  0.0442, -0.0286, -0.0240, -0.0489,  0.0148,
         0.0194,  0.0452,  0.0272,  0.0400,  0.0378,  0.0268,  0.0515,  0.0293,
        -0.0313, -0.0086, -0.0514,  0.0342,  0.0236, -0.0474, -0.0410, -0.0395,
         0.0471, -0.0504,  0.0299, -0.0002, -0.0156,  0.0262,  0.0152, -0.0430],
       device='cuda:0', requires_grad=True)), ('weight', Parameter containing:
tensor([[[[-0.0459,  0.0170,  0.0393,  ...,  0.0121,  0.0151, -0.0020],
          [ 0.0234,  0.0428,  0.0232,  ...,  0.0101,  0.0012,  0.0511],
          [-0.0393,  0.0132, -0.0159,  ..., -0.0015, -0.0404,  0.0427],
          ...,
          [-0.

The number of network parameters is shown again, confirming that the network was indeed pruned and restructured.

In [None]:
# Same report as before pruning, for the same 3x224x224 input, so the layer
# sizes and parameter counts can be compared.
print(summary(new_model, (3, 224, 224)))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 13, 55, 55]           4,732
              ReLU-2           [-1, 13, 55, 55]               0
         MaxPool2d-3           [-1, 13, 27, 27]               0
            Conv2d-4           [-1, 38, 27, 27]          12,388
              ReLU-5           [-1, 38, 27, 27]               0
         MaxPool2d-6           [-1, 38, 13, 13]               0
            Conv2d-7           [-1, 77, 13, 13]          26,411
              ReLU-8           [-1, 77, 13, 13]               0
            Conv2d-9           [-1, 51, 13, 13]          35,394
             ReLU-10           [-1, 51, 13, 13]               0
           Conv2d-11           [-1, 51, 13, 13]          23,460
             ReLU-12           [-1, 51, 13, 13]               0
        MaxPool2d-13             [-1, 51, 6, 6]               0
AdaptiveAvgPool2d-14             [-1, 5

The network is saved for use.

In [None]:
# The whole module object is saved, not only its state dict: the pruned layers
# no longer have the sizes that Net() builds, so the stored object itself
# carries the new layer structure.
torch.save(obj=new_model, f='New_Net_pruning.pth')