In [1]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import torch.nn.utils.prune as prune
import torchvision.models as models
import torchvision.datasets as datasets
import time
import test
import train
import torch_pruning as tp

  warn(f"Failed to load image Python extension: {e}")


In [2]:
## Dataset - CIFAR10, used here for demonstration

# Mini-batch sizes for the two loaders.
trn_batch_size = 64
val_batch_size = 64

# Per-channel statistics handed to Normalize; both pipelines use the same values.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.247, 0.243, 0.261)

# Training pipeline: random 32x32 crop (after 4px padding) and random horizontal
# flip, followed by tensor conversion and normalisation.
train_tfms = transforms.Compose(
    [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
    ]
)

# Validation pipeline: tensor conversion and normalisation only (no augmentation).
valid_tfms = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
    ]
)

# train=True split with the augmenting transforms, train=False split with the plain ones.
fullset = datasets.CIFAR10(
    root='./data10',
    train=True,
    download=True,
    transform=train_tfms,
)
testset = datasets.CIFAR10(
    root='./data10',
    train=False,
    download=True,
    transform=valid_tfms,
)

# Neither loader shuffles, so batches are produced in dataset order.
trainloader = torch.utils.data.DataLoader(
    fullset,
    batch_size=trn_batch_size,
    shuffle=False,
    pin_memory=True,
    num_workers=1,
)

valloader = torch.utils.data.DataLoader(
    testset,
    batch_size=val_batch_size,
    shuffle=False,
    pin_memory=True,
    num_workers=1,
)

Files already downloaded and verified
Files already downloaded and verified


In [3]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Build an untrained ResNet-50, move it to the target device and switch it to
# evaluation mode before the checkpoint weights are copied in.
# NOTE(review): the evaluation output recorded in this notebook shows ~0%
# accuracy on CIFAR-10, which suggests this checkpoint was not trained for
# CIFAR-10's label space - confirm which weights 'resnet50.pth' holds.
model = models.resnet50().to(device).eval()

# map_location remaps the stored tensors onto `device`, so a checkpoint saved
# from a GPU session still loads on a CPU-only machine.
state_dict = torch.load('resnet50.pth', map_location=device)

# Kept as the final expression so the cell displays the key-matching result.
model.load_state_dict(state_dict)

<All keys matched successfully>

In [4]:
criterion = nn.CrossEntropyLoss()

# Evaluate the model using the imported module and time one full pass.
# NOTE(review): `trainloader` (the training split with random crop/flip) is
# passed as `test_loader`, so the reported "Test" metrics are computed on
# augmented training images; `valloader` may be what was intended - confirm.
on_cuda = str(device).startswith("cuda")

if on_cuda:
    # CUDA kernels run asynchronously; drain earlier work so it is not timed.
    torch.cuda.synchronize()

# perf_counter is a monotonic, high-resolution clock meant for measuring intervals.
start = time.perf_counter()
test.evaluate_model(
    model=model,
    test_loader=trainloader,
    criterion=criterion,
    device=device
)
if on_cuda:
    # Wait for every queued kernel to finish before reading the clock.
    torch.cuda.synchronize()
print(f"Inference time Before Pruning: {time.perf_counter() - start}s")

Test Loss: 25.2504, Test Accuracy: 0.01%
Inference time Before Pruning: 17.52191662788391s


In [5]:

# 1. Build the dependency graph for the ResNet-50 held in `model`. This requires
#    a dummy input for forwarding (a single 3x224x224 random image here).
DG = tp.DependencyGraph().build_dependency(model, example_inputs=torch.randn(1,3,224,224).to(device))

# 2. Get the group for pruning model.conv1 with the specified channel idxs
group = DG.get_pruning_group( model.conv1, tp.prune_conv_out_channels, idxs=[2, 6, 9] )

# 3. Do the pruning
if DG.check_pruning_group(group): # avoid over-pruning, i.e., channels=0.
    group.prune()
else:
    # Make a rejected group visible; otherwise the cells below would report
    # "after pruning" numbers for a model that was never pruned.
    print("Pruning group was rejected by check_pruning_group; the model is unchanged.")

# 4. Save & Load
model.zero_grad() # clear gradients to avoid a large file size
torch.save(model, 'model.pth') # !! no .state_dict for saving: the whole module is pickled
# Load the pruned model back onto `device`. The file was written on the line
# above, so full unpickling (weights_only=False) is acceptable here; recent
# PyTorch releases default to weights_only=True, which refuses to unpickle a
# whole nn.Module.
try:
    model = torch.load('model.pth', map_location=device, weights_only=False)
except TypeError:
    # Older PyTorch releases do not accept the weights_only argument.
    model = torch.load('model.pth', map_location=device)

In [6]:
# Display the module tree of the reloaded model so its layer shapes can be inspected.
model

ResNet(
  (conv1): Conv2d(3, 61, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(61, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(61, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(61, 256, kernel_size=(1, 1), stride=(1, 

In [7]:
# Calculate inference time after pruning
# NOTE(review): `trainloader` (the training split with random crop/flip) is
# passed as `test_loader`, so the reported "Test" metrics are computed on
# augmented training images; `valloader` may be what was intended - confirm.
criterion = nn.CrossEntropyLoss()

on_cuda = str(device).startswith("cuda")

if on_cuda:
    # CUDA kernels run asynchronously; drain earlier work so it is not timed.
    torch.cuda.synchronize()

# perf_counter is a monotonic, high-resolution clock meant for measuring intervals.
start = time.perf_counter()
test.evaluate_model(
    model=model,
    test_loader=trainloader,
    criterion=criterion,
    device=device
)
if on_cuda:
    # Wait for every queued kernel to finish before reading the clock.
    torch.cuda.synchronize()
print(f"Inference time After Pruning: {time.perf_counter() - start}s")

Test Loss: 24.6318, Test Accuracy: 0.03%
Inference time After Pruning: 14.433615446090698s
