In [2]:
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

# Define a simple neural network with one layer
class SimpleNet(nn.Module):
    """Minimal one-layer network: a single fully connected map from 10 to 5 features."""

    def __init__(self):
        super().__init__()
        # The only learnable layer; its ``weight`` (shape 5x10) is what gets pruned below.
        self.fc = nn.Linear(10, 5)

    def forward(self, x):
        """Apply the linear layer ``fc`` to ``x`` and return its output."""
        out = self.fc(x)
        return out

# Seed PyTorch's RNG so both the weight initialisation and the random pruning
# below are reproducible under Restart Kernel -> Run All.
SEED = 0
torch.manual_seed(SEED)

# Instantiate the network
model = SimpleNet()

# Print the original weights
print("Original weights:\n", model.fc.weight)

# Apply random unstructured pruning: zero out a random 20% of the individual
# entries of the 'fc' weight (the pruning mask is kept as `fc.weight_mask`).
PRUNE_AMOUNT = 0.2
prune.random_unstructured(model.fc, name='weight', amount=PRUNE_AMOUNT)

# Sanity check: the mask zeroes exactly round(amount * numel) entries.
num_masked = int((model.fc.weight_mask == 0).sum())
assert num_masked == round(PRUNE_AMOUNT * model.fc.weight.numel()), num_masked

# Print the pruned weights
print("Pruned weights:\n", model.fc.weight)


Original weights:
 Parameter containing:
tensor([[-0.2156,  0.1311, -0.1006, -0.1321,  0.1374,  0.1605,  0.1152,  0.1794,
         -0.2753,  0.2601],
        [-0.1734, -0.2749,  0.0317, -0.0125, -0.2882,  0.2254, -0.1127, -0.2972,
         -0.2354,  0.2357],
        [-0.0007, -0.0961,  0.0214, -0.2732, -0.1601, -0.1590,  0.2090, -0.0512,
         -0.3037, -0.2496],
        [ 0.1794,  0.1156,  0.0565,  0.1318,  0.0684,  0.1886,  0.0735,  0.2157,
         -0.2980, -0.0148],
        [ 0.0771, -0.1499,  0.0901,  0.2020,  0.2026, -0.1158, -0.0097, -0.1923,
          0.0941,  0.1082]], requires_grad=True)
Pruned weights:
 tensor([[-0.2156,  0.1311, -0.1006, -0.1321,  0.1374,  0.1605,  0.1152,  0.0000,
         -0.2753,  0.2601],
        [-0.1734, -0.2749,  0.0317, -0.0125, -0.2882,  0.2254, -0.1127, -0.2972,
         -0.2354,  0.0000],
        [-0.0007, -0.0961,  0.0214, -0.0000, -0.1601, -0.1590,  0.2090, -0.0512,
         -0.3037, -0.2496],
        [ 0.1794,  0.1156,  0.0000,  0.0000,  0.0

In [4]:
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
import numpy as np

# Define a simple convolutional neural network
class SimpleCNN(nn.Module):
    """Minimal convolutional network consisting of a single Conv2d layer."""

    def __init__(self):
        super().__init__()
        # 3 input channels, 6 output channels, 5x5 kernel -> weight shape (6, 3, 5, 5)
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)

    def forward(self, x):
        """Run ``x`` through ``conv1``; no activation or pooling is applied."""
        features = self.conv1(x)
        return features

# Seed both PyTorch and NumPy so the weight initialisation and any random
# channel selection in this cell are reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# Instantiate the network
model = SimpleCNN()

# Print the original weight shape: (out_channels, in_channels, kernel_h, kernel_w)
print("Original weight shape:", model.conv1.weight.shape)

print("Original weights:\n", model.conv1.weight)

# Custom structured random pruning function (whole slices along any dimension)
def random_structured_pruning(module, name, amount, dim):
    """Randomly prune whole slices of ``module.<name>`` along dimension ``dim``.

    A random subset of the ``weight.shape[dim]`` slices (e.g. output channels for
    ``dim=0`` or input channels for ``dim=1`` of a Conv2d weight) is selected
    without replacement and zeroed via ``torch.nn.utils.prune.custom_from_mask``.

    Args:
        module: Module whose parameter is pruned in place.
        name: Name of the parameter inside ``module`` (e.g. ``'weight'``).
        amount: Fraction in ``[0, 1]`` of slices to prune; the number pruned is
            ``round(amount * weight.shape[dim])``.
        dim: Dimension along which slices are chosen. Any valid dimension of the
            parameter is accepted (including negative indices), for any rank.

    Returns:
        The pruned ``module`` (same convention as the ``torch.nn.utils.prune``
        functions; callers may ignore it).

    Raises:
        ValueError: If ``amount`` is outside ``[0, 1]``.
        IndexError: If ``dim`` is not a valid dimension of the parameter.
    """
    weight = getattr(module, name)
    if not 0.0 <= amount <= 1.0:
        raise ValueError(f"amount must be a fraction in [0, 1], got {amount}")

    num_structures = weight.shape[dim]  # raises IndexError for an invalid dim
    num_pruned = int(round(amount * num_structures))

    # Use torch's RNG (not numpy's) so torch.manual_seed governs this choice the
    # same way it governs prune.random_unstructured; randperm never repeats an index.
    indices = torch.randperm(num_structures, device=weight.device)[:num_pruned]

    # Zero the selected slices along `dim`. index_fill_ works for tensors of any
    # rank and any dim, unlike hard-coded 4-D slicing that only handled dims 0/1.
    mask = torch.ones_like(weight)
    mask.index_fill_(dim, indices, 0)

    # Apply the mask
    return prune.custom_from_mask(module, name, mask)


# Apply structured random pruning to the conv1 layer: zero out 50% of the
# output channels (dim=0 of the (6, 3, 5, 5) weight), i.e. whole filters.
CHANNEL_PRUNE_AMOUNT = 0.5
random_structured_pruning(model.conv1, name='weight', amount=CHANNEL_PRUNE_AMOUNT, dim=0)

# Sanity check: every output channel's mask is all ones or all zeros, and
# exactly round(amount * out_channels) channels were removed.
channel_mask_zero = model.conv1.weight_mask.flatten(start_dim=1) == 0
channel_pruned = channel_mask_zero.all(dim=1)
assert bool((channel_mask_zero.any(dim=1) == channel_pruned).all()), "partially pruned channel"
assert int(channel_pruned.sum()) == round(CHANNEL_PRUNE_AMOUNT * model.conv1.out_channels)

# Print the pruned weights
print("Pruned weights:\n", model.conv1.weight)


Original weight shape: torch.Size([6, 3, 5, 5])
Original weights:
 Parameter containing:
tensor([[[[ 0.0516,  0.0660,  0.0485, -0.0785, -0.0972],
          [-0.0106, -0.0661,  0.1062,  0.1049,  0.0392],
          [ 0.0761, -0.0666, -0.0689, -0.0445, -0.0396],
          [-0.0919, -0.0005, -0.0352, -0.0799, -0.0045],
          [-0.0766,  0.0686,  0.0975,  0.0509,  0.0282]],

         [[ 0.0864,  0.0999, -0.0444, -0.0434,  0.1066],
          [ 0.0083,  0.1107, -0.1127, -0.0742,  0.1089],
          [ 0.0366,  0.0949,  0.0476,  0.0064,  0.0116],
          [ 0.0337, -0.0729,  0.0813, -0.0812, -0.0375],
          [ 0.0723, -0.0875,  0.0646,  0.0665,  0.0863]],

         [[ 0.0896,  0.0253, -0.0151, -0.0530, -0.1117],
          [ 0.0396,  0.0275,  0.0121,  0.0532, -0.0459],
          [ 0.0766,  0.0739,  0.0944, -0.0212,  0.0952],
          [ 0.0074,  0.0677,  0.0134,  0.0938, -0.1080],
          [-0.0464,  0.0449,  0.0432,  0.1038, -0.1011]]],


        [[[ 0.0757, -0.0037, -0.0457, -0.1104,  