In [32]:
import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F

In [33]:
# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

class LeNet(nn.Module):
    """LeNet-style CNN: two conv + ReLU + max-pool stages, then three linear layers.

    ``fc1`` takes 16*5*5 = 400 features per sample. With the unpadded 5x5
    convolutions and 2x2 poolings defined here, that is what a 1-channel
    32x32 input produces (32 -> 28 -> 14 -> 10 -> 5 spatially, 16 channels).
    """

    def __init__(self):
        super().__init__()
        # Feature extractor: 1 -> 6 -> 16 channels, 5x5 kernels.
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # Classifier head: 400 -> 128 -> 84 -> 10.
        self.fc1 = nn.Linear(16 * 5 * 5, 128)
        self.fc2 = nn.Linear(128, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Run the network on a batch ``x`` and return the raw ``fc3`` outputs.

        Args:
            x: batched input tensor; dimension 0 is treated as the batch
                dimension when flattening.

        Returns:
            Tensor with one row of 10 values per sample (no softmax applied).
        """
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        # Collapse everything except the batch dimension. Unlike the manual
        # ``nelement() / shape[0]`` computation, this uses no float division
        # and does not divide by zero when the batch is empty.
        x = x.flatten(start_dim=1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
# Build the network, then move its parameters to the selected device.
model = LeNet()
model = model.to(device)

In [42]:
# module=model.conv1
# print(list(module.named_parameters()))
# prune.random_unstructured(module,name="weight",amount=.5)

In [40]:
def prune_filters(model,k):
    """Permanently L1-prune the weights of every Conv2d and Linear layer in ``model``.

    Despite the name, pruning is unstructured: each layer's ``weight`` tensor
    is pruned entry by entry with ``prune.l1_unstructured`` rather than whole
    filters being removed. ``prune.remove`` is then called on the same
    parameter, so the layer ends up with a plain ``weight`` parameter again.
    Biases are not touched.

    Args:
        model: module whose sub-modules (as yielded by ``model.modules()``)
            are scanned for ``nn.Conv2d`` / ``nn.Linear`` layers.
        k: forwarded as ``amount`` to ``prune.l1_unstructured``; per the
            PyTorch docs a float is a fraction of entries to prune and an
            int is an absolute count, applied to each layer independently.

    Returns:
        None. ``model`` is modified in place.
    """
    prunable_types = (nn.Conv2d, nn.Linear)
    for layer in model.modules():
        if not isinstance(layer, prunable_types):
            continue
        prune.l1_unstructured(layer, name="weight", amount=k)
        # Drop the pruning re-parametrization so the zeros become permanent.
        prune.remove(layer, "weight")

In [41]:
# Apply pruning with amount 0.5 to every Conv2d/Linear layer of the model.
prune_filters(model, k=0.5)

In [44]:
# Show conv1's named parameters, then its weight tensor on its own.
module = model.conv1
print([named_param for named_param in module.named_parameters()])
print(module.weight)


[('bias', Parameter containing:
tensor([ 0.1471,  0.0196,  0.1756, -0.1621, -0.1106, -0.0398],
       requires_grad=True)), ('weight', Parameter containing:
tensor([[[[ 0.1956,  0.1691,  0.1322, -0.0000,  0.1843],
          [ 0.1074, -0.0000,  0.0000,  0.1912,  0.0000],
          [ 0.1410,  0.0000,  0.0000, -0.1634, -0.1208],
          [-0.1964,  0.1793, -0.1848, -0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000,  0.1180,  0.1335]]],


        [[[-0.1178,  0.1615,  0.0000,  0.1979,  0.1908],
          [-0.0000, -0.0000,  0.0000, -0.1295, -0.0000],
          [ 0.1419, -0.0000, -0.0000, -0.1334,  0.1887],
          [-0.0000, -0.0000,  0.0000, -0.0000, -0.1444],
          [-0.1625,  0.1633, -0.0000, -0.1501,  0.0000]]],


        [[[ 0.1043, -0.0000, -0.1836, -0.0000, -0.1555],
          [-0.1690,  0.1108, -0.1972, -0.0000,  0.0000],
          [ 0.1410, -0.0000,  0.0000, -0.0000,  0.1178],
          [ 0.1566, -0.1585, -0.1040, -0.1428, -0.1070],
          [ 0.0000,  0.1982,  0.0000,