In [1]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import torch.nn.utils.prune as prune
import torchvision.models as models
import torchvision.datasets as datasets
import time
import test
import train


  warn(f"Failed to load image Python extension: {e}")


In [2]:
##Dataset - CIFAR10 Dataset for demonstration

num_classes = 10

trn_batch_size = 64
val_batch_size = 64


train_tfms = transforms.Compose([
transforms.RandomCrop(32, padding = 4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))])

valid_tfms = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))])


fullset = datasets.CIFAR10(root='./data10', train=True, download=True, transform=train_tfms)
testset = datasets.CIFAR10(root='./data10', train=False, download=True, transform=valid_tfms)

trainloader = torch.utils.data.DataLoader(fullset, batch_size=trn_batch_size,
                                            shuffle=False, pin_memory=True, num_workers=1)

valloader = torch.utils.data.DataLoader(testset, batch_size=val_batch_size,
                                        shuffle=False, pin_memory=True, num_workers=1)

Files already downloaded and verified
Files already downloaded and verified


In [3]:
# Model ResNet50 for Demonstration purpose, Trained model weights are resnet50_acc70.pth and the weights after pruning and fine-tuning is pmodel.pth, 
# Feel free to download and test them.

device = "cuda" if torch.cuda.is_available() else "cpu"

model = models.resnet50().to(device)
model.fc = nn.Linear(in_features=2048, out_features=num_classes, bias=True).to(device)
model.conv1 = nn.Conv2d(3, 64, 3, 3, bias=False).to(device)
model.load_state_dict(torch.load('resnet50_acc70.pth'))

<All keys matched successfully>

In [4]:

criterion = nn.CrossEntropyLoss()
# Evaluate the model using the imported module
start = time.time()
test.evaluate_model(
    model=model,
    test_loader=valloader,
    criterion=criterion,
    device=device
)
print(f"Inference time Before Pruning: {time.time() - start}s")

Test Loss: 0.7726, Test Accuracy: 72.71%
Inference time Before Pruning: 5.090752363204956s


In [5]:
#Single layer Pruning

prune.l1_unstructured(model.conv1, name="weight", amount=0.25)

# Check the pruned weights and the mask
print("Pruned weights for conv1:\n", model.conv1.weight)
print("Mask applied to conv1 weights:\n", model.conv1.weight_mask)

Pruned weights for conv1:
 tensor([[[[ 0.0000,  0.0543,  0.0509],
          [-0.1251, -0.0647, -0.1438],
          [ 0.0904, -0.0438, -0.1082]],

         [[-0.1177, -0.0740, -0.0850],
          [-0.1907, -0.0997, -0.2078],
          [ 0.0581, -0.0436, -0.0854]],

         [[ 0.0000,  0.0400,  0.0000],
          [-0.0318,  0.0402, -0.0359],
          [ 0.0765, -0.0000, -0.0000]]],


        [[[-0.1408, -0.0958, -0.1608],
          [ 0.1024,  0.1083,  0.0813],
          [ 0.0000, -0.0000,  0.0000]],

         [[-0.1013, -0.0433, -0.1189],
          [ 0.1391,  0.1489,  0.1136],
          [-0.0000, -0.0000,  0.0000]],

         [[-0.0817, -0.0369, -0.0685],
          [ 0.1072,  0.0965,  0.1041],
          [-0.0000, -0.0569,  0.0000]]],


        [[[-0.0295, -0.0000,  0.0785],
          [-0.0274, -0.0000,  0.0481],
          [-0.0444, -0.0000,  0.0802]],

         [[-0.0800, -0.0281,  0.1298],
          [-0.0996, -0.0000,  0.1037],
          [-0.1091, -0.0000,  0.1395]],

         [[-0.133

In [6]:
# Pruning Multiple layers
layers_to_prune = ["conv1", "layer1.0.conv1"]

for layer_name in layers_to_prune:
   
    modules = layer_name.split('.')
    layer = model
    for module in modules:
        if module.isdigit():
            layer = layer[int(module)] 
        else:
            layer = getattr(layer, module)

    prune.l1_unstructured(layer, name="weight", amount=0.25)


In [7]:
# Calculate Inference time after pruning
criterion = nn.CrossEntropyLoss()

start = time.time()
test.evaluate_model(
    model=model,
    test_loader=valloader,
    criterion=criterion,
    device=device
)
print(f"Inference time After Pruning: {time.time() - start}s")

Test Loss: 0.9610, Test Accuracy: 66.47%
Inference time After Pruning: 3.9290804862976074s


In [None]:
# Accuracy drops by almost 6% while less than 2 sec reduction in inference time, So Finetune for 10 Epochs with default masking, 
# in case the pruned weights are still being updated, create your own mask and use the mask as the parameter 'mask_dict'. 
# Example: to create a mask from scratch 
 
# def l1_norm_mask(weight, amount):
#     num_prune = int(amount * x.flatten().numel())
#     threshold = torch.topk(torch.abs(x).flatten().numel(), num_prune, largest=False).values.max()
#     mask = (torch.abs(weight) > threshold).int() # Create binary mask (1 = keep, 0 = prune)
#     return mask

# # Prune 50% of the weights

# linear_layer = model.fc
# weight_matrix = model.fc.weight.detach()
# mask = l1_norm_mask(weight_matrix, 0.5)
# print("L1 norm-based mask:", mask)

batch_size = 64
num_epochs = 10
optimizer_name = 'Adam'
scheduler_name = 'StepLR'
lr = 0.001
step_size = 10
gamma = 0.1
model_path = 'pmodel.pth'

# Train and save the model using the imported module
saved_model_path = train.train_and_save_model(
    model=model,
    train_loader=trainloader,
    num_epochs=num_epochs,
    optimizer_name=optimizer_name,
    scheduler_name=scheduler_name,
    lr=lr,
    step_size=step_size,
    gamma=gamma,
    model_path=model_path,
    mask_dict=None
)

print(f"Trained model saved to {saved_model_path}")


Epoch [1/10], Loss: 0.7952, Accuracy: 72.29%
Epoch [2/10], Loss: 0.7699, Accuracy: 73.31%
Epoch [3/10], Loss: 0.7478, Accuracy: 73.86%
Epoch [4/10], Loss: 0.7287, Accuracy: 74.75%
Epoch [5/10], Loss: 0.7127, Accuracy: 75.32%
Epoch [6/10], Loss: 0.7026, Accuracy: 75.71%
Epoch [7/10], Loss: 0.6894, Accuracy: 76.04%


In [4]:
model

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 