In [1]:
import torch
import torch.nn as nn


## Load Model Model

In [3]:
from torchvision import models
model = models.mobilenet_v2(weights='MobileNet_V2_Weights.IMAGENET1K_V1')
model.to(DEVICE)
model.eval()

MobileNetV2(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU6(inplace=True)
    )
    (1): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
          (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (2): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(16, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(96, eps=

## Prediction

In [21]:
from torchvision.models import MobileNet_V2_Weights
from torchvision import transforms
import urllib
from PIL import Image

url, filename = ("https://github.com/pytorch/hub/raw/master/images/dog.jpg", "dog.jpg")
try: urllib.URLopener().retrieve(url, filename)
except: urllib.request.urlretrieve(url, filename)

# Load the weights and associated transforms
input_image = Image.open(filename)
# input_image.show()
# preprocess = transforms.Compose([
#     transforms.Resize(232),
#     transforms.CenterCrop(224),
#     transforms.ToTensor(),
#     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
# ])
preprocess = MobileNet_V2_Weights.IMAGENET1K_V1.transforms()
input_tensor = preprocess(input_image)
input_batch = input_tensor.unsqueeze(0) # create a mini-batch as expected by the model

with torch.no_grad():
    output = model(input_batch.to(DEVICE))
# Tensor of shape 1000, with confidence scores over ImageNet's 1000 classes
print(output[0])
# The output has unnormalized scores. To get probabilities, you can run a softmax on it.
# probabilities = torch.nn.functional.softmax(output[0], dim=0)
# print(probabilities)

tensor([-7.4507e-01, -1.9598e+00, -1.2297e+00, -2.4998e+00,  6.1957e-01,
        -1.5875e+00, -6.6037e-01,  8.7215e-01,  8.0851e-01, -4.4422e+00,
         2.1560e+00,  2.2114e+00,  2.2380e+00,  2.4647e+00,  8.6744e-01,
         2.8453e+00,  2.2687e+00,  3.6068e+00,  3.5936e+00, -6.2327e-01,
        -1.0283e-01, -1.1037e+00, -8.8632e-01, -1.6670e+00, -1.5433e-01,
        -2.1496e+00, -2.7370e+00, -2.1549e+00, -2.4847e+00, -1.6090e+00,
        -1.7374e+00, -3.4063e+00, -2.0511e+00, -6.3207e-01, -5.3296e-01,
         2.0916e+00,  1.6979e+00,  1.8955e+00, -2.8600e+00,  2.1813e+00,
         5.0884e-01,  7.7504e-01,  1.1569e+00,  2.4549e+00,  1.9504e-01,
         1.1840e+00,  3.3721e+00, -2.3278e+00, -8.8066e-02, -2.3595e+00,
         4.1852e-01, -2.5609e+00,  5.4258e-01, -5.0072e-01, -2.3071e-01,
        -5.4639e-01, -1.9046e+00, -1.0521e+00,  4.8525e-01, -3.4383e-01,
        -1.1849e+00, -2.1489e+00, -2.6166e+00, -2.4700e+00, -4.4042e-01,
        -2.4792e+00, -2.6473e+00, -1.2092e+00, -2.6

In [8]:
output[0].max()

tensor(14.3573, device='cuda:0')

In [24]:
# Read the categories
with open("./imagenet_classes.txt", "r") as f:
    categories = [s.strip() for s in f.readlines()]
# Show top categories per image
top5_prob, top5_catid = torch.topk(probabilities, 5)
for i in range(top5_prob.size(0)):
    print(categories[top5_catid[i]], top5_prob[i].item())

Samoyed 0.8303830623626709
Pomeranian 0.06986550986766815
keeshond 0.012912546284496784
collie 0.01080403383821249
Great Pyrenees 0.009873803704977036


## Evaluate Model

In [22]:
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
from torchvision.models import MobileNet_V2_Weights

# test_transforms = transforms.Compose([
#     transforms.Resize(256),
#     transforms.CenterCrop(224),
#     transforms.ToTensor(),
#     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
# ])



test_dataset = datasets.ImageNet(root='./data',
                                 split='val',
                                transform=preprocess)

In [22]:
test_dataset

Dataset ImageNet
    Number of datapoints: 50000
    Root location: ./data
    Split: val
    StandardTransform
Transform: ImageClassification(
               crop_size=[224]
               resize_size=[256]
               mean=[0.485, 0.456, 0.406]
               std=[0.229, 0.224, 0.225]
               interpolation=InterpolationMode.BILINEAR
           )

In [23]:
test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=False)

In [33]:
def evaluate(model, device, test_loader):
    model.eval()
    
    losses = 0.0
    total_predictions = 0
    true_predictions = 0
    criterion = torch.nn.CrossEntropyLoss()
    with torch.no_grad():
    
        for batch_idx, (inputs, targets) in enumerate(test_loader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets) / inputs.size(0)
            losses += loss.item()
            _, predicted = torch.max(outputs, 1)
            
            batch_total_predictions = outputs.size(0)
            batch_true_predictions = (predicted == targets).sum().item()
            total_predictions += batch_total_predictions
            true_predictions += batch_true_predictions

            print(f'Batch {batch_idx}, Loss: {loss:.4f}, Accuracy: {batch_true_predictions/batch_total_predictions*100:.2f}%')
            
    return true_predictions / total_predictions, losses

In [55]:
accuracy_top1, accuracy_top5, losses = evaluate(model, DEVICE, test_dataloader)
print(f"acc@1: {accuracy_top1*100}%, acc@5: {accuracy_top5*100}%, loss: {losses}")


Batch 0, Loss: 0.0097, Accuracy@1: 84.38%, Accuracy@5: 95.31%
Batch 1, Loss: 0.0076, Accuracy@1: 85.94%, Accuracy@5: 95.31%
Batch 2, Loss: 0.0102, Accuracy@1: 79.69%, Accuracy@5: 96.88%
Batch 3, Loss: 0.0096, Accuracy@1: 79.69%, Accuracy@5: 98.44%
Batch 4, Loss: 0.0164, Accuracy@1: 71.88%, Accuracy@5: 90.62%
Batch 5, Loss: 0.0160, Accuracy@1: 67.19%, Accuracy@5: 90.62%
Batch 6, Loss: 0.0092, Accuracy@1: 85.94%, Accuracy@5: 95.31%
Batch 7, Loss: 0.0036, Accuracy@1: 93.75%, Accuracy@5: 98.44%
Batch 8, Loss: 0.0048, Accuracy@1: 92.19%, Accuracy@5: 98.44%
Batch 9, Loss: 0.0060, Accuracy@1: 90.62%, Accuracy@5: 95.31%
Batch 10, Loss: 0.0023, Accuracy@1: 95.31%, Accuracy@5: 98.44%
Batch 11, Loss: 0.0058, Accuracy@1: 92.19%, Accuracy@5: 96.88%
Batch 12, Loss: 0.0084, Accuracy@1: 82.81%, Accuracy@5: 93.75%
Batch 13, Loss: 0.0086, Accuracy@1: 85.94%, Accuracy@5: 96.88%
Batch 14, Loss: 0.0116, Accuracy@1: 85.94%, Accuracy@5: 95.31%
Batch 15, Loss: 0.0043, Accuracy@1: 95.31%, Accuracy@5: 96.88%
Ba

In [48]:
def evaluate(model, device, test_loader):
    model.eval()
    
    losses = 0.0
    total_predictions = 0
    true_predictions_top1 = 0
    true_predictions_top5 = 0
    criterion = torch.nn.CrossEntropyLoss()
    
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(test_loader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            
            # Compute loss
            loss = criterion(outputs, targets) / inputs.size(0)
            losses += loss.item()
            
            # Top-1 predictions
            _, predicted_top1 = torch.max(outputs, 1)
            batch_true_predictions_top1 = (predicted_top1 == targets).sum().item()
            true_predictions_top1 += batch_true_predictions_top1
            
            # Top-5 predictions
            _, predicted_top5 = torch.topk(outputs, 5, dim=1)
            batch_true_predictions_top5 = sum(
                [targets[i].item() in predicted_top5[i].tolist() for i in range(targets.size(0))]
            )
            true_predictions_top5 += batch_true_predictions_top5
            
            # Update total predictions
            batch_total_predictions = outputs.size(0)
            total_predictions += batch_total_predictions
            
            # Print batch metrics
            print(
                f'Batch {batch_idx}, Loss: {loss:.4f}, '
                f'Accuracy@1: {batch_true_predictions_top1/batch_total_predictions*100:.2f}%, '
                f'Accuracy@5: {batch_true_predictions_top5/batch_total_predictions*100:.2f}%'
            )
    
    # Compute overall accuracies
    accuracy_top1 = true_predictions_top1 / total_predictions
    accuracy_top5 = true_predictions_top5 / total_predictions
    
    return accuracy_top1, accuracy_top5, losses


In [52]:
accuracy_top1, accuracy_top5, losses = evaluate(model, DEVICE, test_dataloader)
print(f"acc@1: {accuracy_top1*100}%, acc@5: {accuracy_top5*100}%, loss: {losses}")


Batch 0, Loss: 0.0177, Accuracy@1: 84.38%, Accuracy@5: 93.75%
Batch 1, Loss: 0.0200, Accuracy@1: 79.69%, Accuracy@5: 93.75%
Batch 2, Loss: 0.0219, Accuracy@1: 78.12%, Accuracy@5: 98.44%
Batch 3, Loss: 0.0237, Accuracy@1: 73.44%, Accuracy@5: 95.31%
Batch 4, Loss: 0.0297, Accuracy@1: 78.12%, Accuracy@5: 92.19%
Batch 5, Loss: 0.0288, Accuracy@1: 64.06%, Accuracy@5: 90.62%
Batch 6, Loss: 0.0358, Accuracy@1: 60.94%, Accuracy@5: 84.38%
Batch 7, Loss: 0.0215, Accuracy@1: 82.81%, Accuracy@5: 93.75%
Batch 8, Loss: 0.0172, Accuracy@1: 85.94%, Accuracy@5: 98.44%
Batch 9, Loss: 0.0209, Accuracy@1: 76.56%, Accuracy@5: 89.06%
Batch 10, Loss: 0.0098, Accuracy@1: 90.62%, Accuracy@5: 100.00%
Batch 11, Loss: 0.0140, Accuracy@1: 89.06%, Accuracy@5: 93.75%
Batch 12, Loss: 0.0168, Accuracy@1: 78.12%, Accuracy@5: 92.19%
Batch 13, Loss: 0.0193, Accuracy@1: 79.69%, Accuracy@5: 92.19%
Batch 14, Loss: 0.0188, Accuracy@1: 79.69%, Accuracy@5: 93.75%
Batch 15, Loss: 0.0186, Accuracy@1: 82.81%, Accuracy@5: 95.31%
B

NameError: name 'acc' is not defined

In [53]:
print(f"acc@1: {accuracy_top1*100}%, acc@5: {accuracy_top5*100}%, loss: {losses}")

acc@1: 59.16400000000001%, acc@5: 82.32199999999999%, loss: 26.93688239902258


In [22]:
import json
test = {"model": {}, "hardware": {}}
with open("./metrics.json", "w") as f:
    json.dump(test, f, indent=2)

## Scaling-based Pruning (Channel-Level Pruning)

In [4]:
for parameter in model.named_parameters():
    # print(parameter)
    # print(parameter.size())
    print(parameter[0], parameter[1].shape)

features.0.0.weight torch.Size([32, 3, 3, 3])
features.0.1.weight torch.Size([32])
features.0.1.bias torch.Size([32])
features.1.conv.0.0.weight torch.Size([32, 1, 3, 3])
features.1.conv.0.1.weight torch.Size([32])
features.1.conv.0.1.bias torch.Size([32])
features.1.conv.1.weight torch.Size([16, 32, 1, 1])
features.1.conv.2.weight torch.Size([16])
features.1.conv.2.bias torch.Size([16])
features.2.conv.0.0.weight torch.Size([96, 16, 1, 1])
features.2.conv.0.1.weight torch.Size([96])
features.2.conv.0.1.bias torch.Size([96])
features.2.conv.1.0.weight torch.Size([96, 1, 3, 3])
features.2.conv.1.1.weight torch.Size([96])
features.2.conv.1.1.bias torch.Size([96])
features.2.conv.2.weight torch.Size([24, 96, 1, 1])
features.2.conv.3.weight torch.Size([24])
features.2.conv.3.bias torch.Size([24])
features.3.conv.0.0.weight torch.Size([144, 24, 1, 1])
features.3.conv.0.1.weight torch.Size([144])
features.3.conv.0.1.bias torch.Size([144])
features.3.conv.1.0.weight torch.Size([144, 1, 3, 3])

In [9]:
for parameter in model.parameters():
    print(parameter[0], parameter[1].shape)
    break

tensor([[[ 0.0132, -0.0043,  0.0148],
         [ 0.0328, -0.0254,  0.0069],
         [ 0.0105, -0.0373, -0.0147]],

        [[ 0.0080, -0.0059,  0.0151],
         [ 0.0200, -0.0329, -0.0021],
         [ 0.0114, -0.0330, -0.0079]],

        [[-0.0252, -0.0202, -0.0100],
         [-0.0112, -0.0293, -0.0152],
         [-0.0265, -0.0334, -0.0242]]], device='cuda:0',
       grad_fn=<SelectBackward0>) torch.Size([3, 3, 3])


In [32]:
import torchvision.models as models

# Iterate through all layers of the model
batchnorm_params = []

for name, module in model.named_modules():
    if isinstance(module, nn.BatchNorm2d):
        batchnorm_params.append((name, module))

# Print the names and parameters of BatchNorm layers
count = 0
for name, module in batchnorm_params:
    print(f"Layer: {name}")
    print(f"Weight: {module.weight.data}")
    print(f"Bias: {module.bias.data}")
    print()
    
    count += 1
    if count==1:
        break   


Layer: features.0.1
Weight: tensor([0.0381, 0.1872, 0.1975, 0.2451, 0.1313, 0.1590, 0.0881, 0.2552, 0.0870,
        0.0119, 0.4129, 0.1137, 0.2245, 0.3014, 0.0114, 0.0104, 0.0128, 0.0043,
        0.0668, 0.4108, 0.3578, 0.1278, 0.5250, 0.0039, 0.4444, 0.1298, 0.3284,
        0.2453, 0.3565, 0.3066, 0.4146, 0.1419], device='cuda:0')
Bias: tensor([-8.4354e-02,  5.6023e-01,  3.5002e-01,  2.8363e-01,  9.7327e-01,
         6.4774e-01,  4.9481e-01,  5.5817e-01,  6.1756e-01, -4.2980e-04,
        -3.0858e-01,  9.5334e-01,  4.4609e-01, -3.8414e-01, -9.3045e-04,
         5.7470e-03, -4.2064e-02, -1.7965e-02,  3.3821e-01,  1.1017e-01,
        -2.5284e-01,  5.0251e-01,  3.7990e-01, -1.5532e-02, -4.6869e-01,
         5.1056e-01, -2.8880e-01,  6.4006e-01, -1.0935e-01, -5.9483e-02,
         3.7479e-01,  2.6511e-01], device='cuda:0')

Layer: features.1.conv.0.1
Weight: tensor([0.0258, 0.5501, 0.7154, 0.4909, 1.1534, 0.4905, 0.5998, 0.8480, 0.5139,
        0.0061, 0.2110, 1.0502, 0.6857, 0.2376, 0.0059

In [85]:
import numpy as np
pruning_ratio = 0.3
# Print the names and parameters of BatchNorm layers
count = 0
for name, module in batchnorm_params:
    print(f"Layer: {name}")
    print(f"Weight: {module.weight.data}")
    print(f"Bias: {module.bias.data}")
    print()
    batchnorm_scale =  module.weight.data
    num_channels = batchnorm_scale.size(0)
    num_pruned_channels = int(num_channels * pruning_ratio)
    
    weights = {k: v for k, v in enumerate(batchnorm_scale.cpu().numpy())}
    sorted_weights = {k: v for k, v in sorted(weights.items(), key=lambda item: item[1])}
    
    keep_idx = list(sorted_weights.keys())[num_pruned_channels:]
    pruned_weights = batchnorm_scale[keep_idx]
    pruned_bias = module.bias.data[keep_idx]
    
    count += 1
    if count==1:
        break   

Layer: features.0.1
Weight: tensor([0.0381, 0.1872, 0.1975, 0.2451, 0.1313, 0.1590, 0.0881, 0.2552, 0.0870,
        0.0119, 0.4129, 0.1137, 0.2245, 0.3014, 0.0114, 0.0104, 0.0128, 0.0043,
        0.0668, 0.4108, 0.3578, 0.1278, 0.5250, 0.0039, 0.4444, 0.1298, 0.3284,
        0.2453, 0.3565, 0.3066, 0.4146, 0.1419], device='cuda:0')
Bias: tensor([-8.4354e-02,  5.6023e-01,  3.5002e-01,  2.8363e-01,  9.7327e-01,
         6.4774e-01,  4.9481e-01,  5.5817e-01,  6.1756e-01, -4.2980e-04,
        -3.0858e-01,  9.5334e-01,  4.4609e-01, -3.8414e-01, -9.3045e-04,
         5.7470e-03, -4.2064e-02, -1.7965e-02,  3.3821e-01,  1.1017e-01,
        -2.5284e-01,  5.0251e-01,  3.7990e-01, -1.5532e-02, -4.6869e-01,
         5.1056e-01, -2.8880e-01,  6.4006e-01, -1.0935e-01, -5.9483e-02,
         3.7479e-01,  2.6511e-01], device='cuda:0')



tensor([0.0881, 0.1137, 0.1278, 0.1298, 0.1313, 0.1419, 0.1590, 0.1872, 0.1975,
        0.2245, 0.2451, 0.2453, 0.2552, 0.3014, 0.3066, 0.3284, 0.3565, 0.3578,
        0.4108, 0.4129, 0.4146, 0.4444, 0.5250], device='cuda:0')

In [5]:
pruning_ratio = 0.3
pruned_channels = {}  # To store pruned indices for each layer
    
# Step 1: Analyze BatchNorm scaling factors
for name, module in model.named_modules():
    if isinstance(module, nn.BatchNorm2d):
        # Get the scale (\gamma) values
        gamma = module.weight.detach().cpu().numpy()
        num_channels = len(gamma)
        
        # Determine the number of channels to prune
        num_prune = int(pruning_ratio * num_channels)
        
        # Identify the indices of the smallest \gamma values
        prune_indices = gamma.argsort()[:num_prune]
        pruned_channels[name] = prune_indices

In [6]:
pruned_channels

{'features.0.1': array([23, 17, 15, 14,  9, 16,  0, 18,  8], dtype=int64),
 'features.1.conv.0.1': array([15, 17, 14,  9, 23, 16,  0, 24, 28], dtype=int64),
 'features.1.conv.2': array([10,  3,  8,  0], dtype=int64),
 'features.2.conv.0.1': array([72, 12, 49, 62, 31, 29, 90, 13, 38, 23, 43, 86, 22, 48,  7, 55, 28,
         2, 41, 66, 65, 78, 74, 17, 21,  9, 94, 63], dtype=int64),
 'features.2.conv.1.1': array([54, 50, 73, 77, 10, 27, 81, 93,  6, 36, 88, 67, 66, 70, 33, 17, 53,
        34, 94, 28, 51, 92, 75, 31, 16, 41, 57, 63], dtype=int64),
 'features.2.conv.3': array([ 6,  4, 23,  3, 21, 20,  1], dtype=int64),
 'features.3.conv.0.1': array([110,  44, 115,  87,  90, 120,  79,  97, 126, 114,   1, 142,  64,
         46, 138,  29,  30, 139,  73, 128, 140,  88,  83,  17,  95,  70,
        118, 106, 130,  45,  42, 125,  12,  54,  35,   8,  81,  41, 137,
        135,  43, 107, 116], dtype=int64),
 'features.3.conv.1.1': array([ 75,  24, 111,  13,  82,  99,  15, 116,   0,  14,  85,  55, 131

In [7]:
def prune_layer(layer, prune_indices):
    if isinstance(layer, nn.Conv2d):
        # Prune output channels
        weight = layer.weight.detach().cpu()
        new_weight = weight[~torch.tensor(prune_indices)].clone()
        
        # Update Conv2d weights
        layer.out_channels = new_weight.size(0)
        layer.weight = nn.Parameter(new_weight)
        
    elif isinstance(layer, nn.BatchNorm2d):
        # Prune BatchNorm parameters
        layer.weight = nn.Parameter(layer.weight.detach()[~torch.tensor(prune_indices)].clone())
        layer.bias = nn.Parameter(layer.bias.detach()[~torch.tensor(prune_indices)].clone())
        layer.running_mean = layer.running_mean.detach()[~torch.tensor(prune_indices)].clone()
        layer.running_var = layer.running_var.detach()[~torch.tensor(prune_indices)].clone()

        layer.num_features = layer.weight.size(0)
    
for name, layer in model.named_modules():
    if name in pruned_channels:
        print(name)
        prune_layer(layer, pruned_channels[name])

features.0.1
features.1.conv.0.1
features.1.conv.2
features.2.conv.0.1
features.2.conv.1.1
features.2.conv.3
features.3.conv.0.1
features.3.conv.1.1
features.3.conv.3
features.4.conv.0.1
features.4.conv.1.1
features.4.conv.3
features.5.conv.0.1
features.5.conv.1.1
features.5.conv.3
features.6.conv.0.1
features.6.conv.1.1
features.6.conv.3
features.7.conv.0.1
features.7.conv.1.1
features.7.conv.3
features.8.conv.0.1
features.8.conv.1.1
features.8.conv.3
features.9.conv.0.1
features.9.conv.1.1
features.9.conv.3
features.10.conv.0.1
features.10.conv.1.1
features.10.conv.3
features.11.conv.0.1
features.11.conv.1.1
features.11.conv.3
features.12.conv.0.1
features.12.conv.1.1
features.12.conv.3
features.13.conv.0.1
features.13.conv.1.1
features.13.conv.3
features.14.conv.0.1
features.14.conv.1.1
features.14.conv.3
features.15.conv.0.1
features.15.conv.1.1
features.15.conv.3
features.16.conv.0.1
features.16.conv.1.1
features.16.conv.3
features.17.conv.0.1
features.17.conv.1.1
features.17.conv

In [13]:
model

MobileNetV2(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(9, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU6(inplace=True)
    )
    (1): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
          (1): BatchNorm2d(9, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (2): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (2): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(16, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(28, eps=1e-

In [44]:
from torchvision import models
model = models.mobilenet_v2(weights='MobileNet_V2_Weights.IMAGENET1K_V1')
model.to(DEVICE)
model.eval()

MobileNetV2(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU6(inplace=True)
    )
    (1): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
          (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (2): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(16, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(96, eps=

In [45]:
total_weights_ori = 0
for i in model._modules['features']._modules['16']._modules['conv']._modules['0']:
    # if isinstance(i, nn.BatchNorm2d):
    #     print(i.weight.flatten().numel())
    #     total_weights_pruned += i.weight.flatten().numel()
    if isinstance(i, nn.Conv2d):
        print(i.weight.flatten().numel())
        total_weights_ori += i.weight.flatten().numel()
    # print(i)
    # print(i.weight.shape)
for i in model._modules['features']._modules['16']._modules['conv']._modules['1']:
    # if isinstance(i, nn.BatchNorm2d):
    #     print(i.weight.flatten().numel())
    #     total_weights_pruned += i.weight.flatten().numel()
    if isinstance(i, nn.Conv2d):
        print(i.weight.flatten().numel())
        total_weights_ori += i.weight.flatten().numel()

153600
8640


In [23]:
batch_norms = []
for i in range(1, 18):
    if i == 1:
        batch_norms.append(f'features.{i}.conv.0.1')
        continue    
    batch_norms.append(f'features.{i}.conv.1.1')

In [24]:
batch_norms

['features.1.conv.0.1',
 'features.2.conv.1.1',
 'features.3.conv.1.1',
 'features.4.conv.1.1',
 'features.5.conv.1.1',
 'features.6.conv.1.1',
 'features.7.conv.1.1',
 'features.8.conv.1.1',
 'features.9.conv.1.1',
 'features.10.conv.1.1',
 'features.11.conv.1.1',
 'features.12.conv.1.1',
 'features.13.conv.1.1',
 'features.14.conv.1.1',
 'features.15.conv.1.1',
 'features.16.conv.1.1',
 'features.17.conv.1.1']

In [17]:
# Only prune 1 layer
batch_norms = [batch_norms[15]]
batch_norms

['features.16.conv.1.1']

In [18]:
prune_percentage = 0.1
pruned_channels = {}  # To store pruned indices for each layer
layer_list = list(model.named_modules())  # Flatten the model's layers for traversal

# Step 1: Analyze BatchNorm scaling factors and determine channels to prune
for name, module in model.named_modules():
    if isinstance(module, nn.BatchNorm2d) and name in batch_norms:
        # Get the scale (\gamma) values
        gamma = module.weight.detach().cpu().numpy()
        num_channels = len(gamma)
        
        # Determine the number of channels to prune
        num_prune = int(prune_percentage * num_channels)
        
        # Identify the indices of the smallest \gamma values
        keep_indices = gamma.argsort()[num_prune:]
        pruned_channels[name] = keep_indices

# Step 2: Prune layers
def prune_layer(layer, prune_indices, is_input=False):
    if isinstance(layer, nn.Conv2d):
        if is_input:
            # Prune input channels
            weight = layer.weight.detach().cpu()
            new_weight = weight[:, torch.tensor(prune_indices)].clone().to(device=DEVICE)
            
            layer.in_channels = new_weight.size(1)
            layer.weight = nn.Parameter(new_weight).to(device=DEVICE)
        else:
            # Prune output channels
            weight = layer.weight.detach().cpu()
            new_weight = weight[torch.tensor(prune_indices)].clone().to(device=DEVICE)
            
            layer.out_channels = new_weight.size(0)
            layer.weight = nn.Parameter(new_weight).to(device=DEVICE)
            
            # Adjust the 'groups' parameter if it's a depthwise convolution
            if layer.groups == layer.in_channels:
                layer.groups = new_weight.size(0)
                layer.in_channels = new_weight.size(0)
                
    
    elif isinstance(layer, nn.BatchNorm2d):
        # Prune BatchNorm parameters
        layer.weight = nn.Parameter(layer.weight.detach()[torch.tensor(prune_indices)].clone()).to(device=DEVICE)
        layer.bias = nn.Parameter(layer.bias.detach()[torch.tensor(prune_indices)].clone()).to(device=DEVICE)
        layer.running_mean = layer.running_mean.detach()[torch.tensor(prune_indices)].clone().to(device=DEVICE)
        layer.running_var = layer.running_var.detach()[torch.tensor(prune_indices)].clone().to(device=DEVICE)
        
        layer.num_features = layer.weight.size(0)

# Traverse the model and prune connected layers
for i, (name, module) in enumerate(layer_list):
    if name in pruned_channels:
        prune_indices = pruned_channels[name]
        
        # Prune the current BatchNorm layer
        prune_layer(module, prune_indices)
        
        # Prune the preceding Conv2d (output channels)
        if i > 0:
            prev_name, prev_module = layer_list[i - 1]
            if isinstance(prev_module, nn.Conv2d):
                prune_layer(prev_module, prune_indices, is_input=False)
                if prev_module.groups == prev_module.in_channels:
                    j = i - 2
                    while j > 0:
                        prev_name, prev_module = layer_list[j]
                        if isinstance(prev_module, nn.BatchNorm2d):
                            prune_layer(prev_module, prune_indices)
                            prev_name, prev_module = layer_list[j-1]
                            if isinstance(prev_module, nn.Conv2d):
                                prune_layer(prev_module, prune_indices, is_input=False)
                                break
                        j -= 1
                        

        # Prune the following Conv2d (input channels)
        j = i
        # if i < len(layer_list) - 1:
        #     next_name, next_module = layer_list[i + 2]      # Next conv2d comes after ReLU6 activation layer
        #     if isinstance(next_module, nn.Conv2d):
        #         prune_layer(next_module, prune_indices, is_input=True)
        while j < len(layer_list) - 1:
            next_name, next_module = layer_list[j + 1]
            if isinstance(next_module, nn.Conv2d):
                prune_layer(next_module, prune_indices, is_input=True)
                break
            j += 1


In [21]:
model

MobileNetV2(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU6(inplace=True)
    )
    (1): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
          (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (2): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(16, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(96, eps=

torchvision.models.mobilenetv2.MobileNetV2

In [3]:
input = torch.randn((1,3,224,224)).to(device="cuda")
net(input)

tensor([[-5.0653e-02,  1.3092e+00,  2.5752e+00,  1.4306e+00,  2.3660e+00,
          1.6254e+00,  2.0543e+00,  1.8195e-01, -2.5556e-01, -1.5043e+00,
         -7.2094e-01,  4.9831e-01,  6.4751e-02,  7.8800e-02, -1.1369e+00,
         -1.2337e-01, -2.3843e-01, -4.3499e-01,  1.3113e+00, -8.5748e-01,
         -5.2804e-01,  7.2479e-01,  2.0583e-02, -3.9940e-02, -3.0934e-01,
         -6.5113e-01, -6.9464e-01, -2.8221e-01, -3.9663e-01,  8.4183e-02,
          3.7517e-01,  1.2465e-01, -1.4552e+00,  1.7520e+00,  2.2929e+00,
         -4.3805e-01, -3.3361e-02,  3.6526e-01, -1.3277e+00,  1.0576e-01,
          2.6324e-01, -1.8242e+00, -7.4608e-01, -3.9004e-01, -1.3254e+00,
         -9.3254e-01,  7.4371e-01, -9.4542e-01, -8.2086e-01, -9.7253e-01,
          8.0435e-01, -8.6156e-01, -3.3841e-01, -5.0917e-01,  1.3228e-02,
          8.0233e-01, -1.0509e+00, -1.2024e+00,  4.5280e-01,  1.5692e-01,
          6.1297e-01, -7.4407e-01, -1.4191e+00, -7.2231e-01,  6.0577e-01,
          1.2504e+00, -1.1276e+00,  1.

In [114]:
module_name = 'features.0'
module = model
for name in module_name.split('.'):
    print(name)
    module = module._modules[name]
    print(module)

features
Sequential(
  (0): Conv2dNormActivation(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU6(inplace=True)
  )
  (1): InvertedResidual(
    (conv): Sequential(
      (0): Conv2dNormActivation(
        (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
        (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU6(inplace=True)
      )
      (1): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (2): InvertedResidual(
    (conv): Sequential(
      (0): Conv2dNormActivation(
        (0): Conv2d(16, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)


In [69]:
for name, module in model.named_modules():
    if isinstance(module, nn.BatchNorm2d):
        # Get the scale values
        gamma = module.weight.detach().cpu().numpy()
        print(gamma.shape)
        num_channels = len(gamma)
        print(num_channels)

(32,)
32
(32,)
32
(16,)
16
(96,)
96
(96,)
96
(24,)
24
(144,)
144
(144,)
144
(24,)
24
(144,)
144
(144,)
144
(32,)
32
(192,)
192
(192,)
192
(32,)
32
(192,)
192
(192,)
192
(32,)
32
(192,)
192
(192,)
192
(64,)
64
(384,)
384
(384,)
384
(64,)
64
(384,)
384
(384,)
384
(64,)
64
(384,)
384
(384,)
384
(64,)
64
(384,)
384
(384,)
384
(96,)
96
(576,)
576
(576,)
576
(96,)
96
(576,)
576
(576,)
576
(96,)
96
(576,)
576
(576,)
576
(160,)
160
(960,)
960
(960,)
960
(160,)
160
(960,)
960
(960,)
960
(160,)
160
(960,)
960
(960,)
960
(320,)
320
(1280,)
1280


In [61]:
for i in model._modules['features']._modules['5']._modules['conv']._modules['1']:
    print(i)
    print(i.weight.shape)
    print(type(i))

Conv2d(192, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=192, bias=False)
torch.Size([192, 1, 3, 3])
<class 'torch.nn.modules.conv.Conv2d'>
BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
torch.Size([192])
<class 'torch.nn.modules.batchnorm.BatchNorm2d'>
ReLU6(inplace=True)


AttributeError: 'ReLU6' object has no attribute 'weight'

In [43]:
total_weights_pruned = 0
for i in model._modules['features']._modules['16']._modules['conv']._modules['0']:
    # if isinstance(i, nn.BatchNorm2d):
    #     print(i.weight.flatten().numel())
    #     total_weights_pruned += i.weight.flatten().numel()
    if isinstance(i, nn.Conv2d):
        print(i.weight.flatten().numel())
        total_weights_pruned += i.weight.flatten().numel()
    # print(i)
    # print(i.weight.shape)
for i in model._modules['features']._modules['16']._modules['conv']._modules['1']:
    # if isinstance(i, nn.BatchNorm2d):
    #     print(i.weight.flatten().numel())
    #     total_weights_pruned += i.weight.flatten().numel()
    if isinstance(i, nn.Conv2d):
        print(i.weight.flatten().numel())
        total_weights_pruned += i.weight.flatten().numel()

138240
7776


In [46]:
total_weights_ori

162240

In [49]:
total_weights_pruned / total_weights_ori 

0.9

In [41]:
for i in model._modules['features']._modules['5']._modules['conv']._modules['0']:
    print(i)
    print(i.weight.shape)

Conv2d(32, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)
torch.Size([192, 32, 1, 1])
BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
torch.Size([192])
ReLU6(inplace=True)


AttributeError: 'ReLU6' object has no attribute 'weight'

In [53]:
prune_percentage = 0.9
for name, module in model.named_modules():
    if isinstance(module, nn.BatchNorm2d):
        # Get the scale (\gamma) values
        gamma = module.weight.detach().cpu().numpy()
        num_channels = len(gamma)
        
        # Determine the number of channels to prune
        num_prune = int(prune_percentage * num_channels)
        
        # Identify the indices of the smallest \gamma values
        keep_indices = gamma.argsort()[num_prune:]
        pruned_channels[name] = keep_indices

In [59]:
gamma

array([1.4253117, 1.4295772, 1.4305606, ..., 1.6586254, 1.7704152,
       1.7989348], dtype=float32)

In [1]:
import json
with open("sensivity_analysis.json", "r") as f:
    data = json.load(f)
data

{'model': {'yolov7': {'inference_time_gpu': [1, 2, 3],
   'inference_time_cpu': 0.8197850699994887,
   'num_parameter': 6033930,
   'memory_usage': 12.06786}}}

In [12]:
data = {}
for bn in batch_norms:
    data[bn] = []

In [16]:
data[bn]

[1]

## Global Pruning

In [8]:
from torchvision import models
import torch
import torch.nn as nn
import numpy as np

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = models.mobilenet_v2(weights='MobileNet_V2_Weights.IMAGENET1K_V1')
model.to(DEVICE)
model.eval()

MobileNetV2(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU6(inplace=True)
    )
    (1): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
          (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (2): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(16, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(96, eps=

In [9]:
batch_norms = []
for i in range(1, 18):
    if i == 1:
        batch_norms.append(f'features.{i}.conv.0.1')
        continue    
    batch_norms.append(f'features.{i}.conv.1.1')

In [22]:
pruning_ratio = 0.1

all_gammas = torch.cat([module.weight.flatten() for name, module in model.named_modules() if isinstance(module, nn.BatchNorm2d) and name in batch_norms])
# for name, module in model.named_modules():
#     if isinstance(module, nn.BatchNorm2d):
#         # Get the scale values
#         gamma = module.weight
        # batch_norms[name] = list(gamma)
        
        # total_channels += len(gamma)
 
prune_target = int(all_gammas.size(0) * pruning_ratio)  
threshold = torch.topk(all_gammas, prune_target, largest=False).values[-1]

In [23]:
threshold

tensor(0.1308, device='cuda:0', grad_fn=<SelectBackward0>)

In [10]:
std_dev = np.array([np.std(module.weight.detach().cpu().numpy()) for name, module in model.named_modules() if isinstance(module, nn.BatchNorm2d) and name in batch_norms])
std_dev

array([0.3246427 , 0.11976993, 0.06118749, 0.07926096, 0.0611471 ,
       0.05634503, 0.04216883, 0.0448107 , 0.04583936, 0.04520043,
       0.05498647, 0.04713114, 0.04968606, 0.05315563, 0.05127232,
       0.05739376, 0.04744095], dtype=float32)

In [11]:
import json
with open('../../notebooks/sensivity_analysis.json', 'r') as f:
    data = json.load(f)
data

{'features.1.conv.0.1': {'top1': [0.71096, 0.5757, 0.03474, 0.0022, 0.00114],
  'top5': [0.89948, 0.813, 0.09788, 0.00736, 0.00558],
  'loss': [14.403162048663944,
   22.223612068686634,
   103.64576007798314,
   131.74684715643525,
   138.54842003434896]},
 'features.2.conv.1.1': {'top1': [0.57006, 0.02824, 0.00234, 0.00102, 0.00162],
  'top5': [0.80558, 0.09398, 0.0128, 0.00766, 0.00534],
  'loss': [22.71068325964734,
   109.75838568434119,
   157.76405614987016,
   167.2277132384479,
   113.75423611328006]},
 'features.3.conv.1.1': {'top1': [0.3765, 0.00972, 0.00236, 0.00208, 0.00204],
  'top5': [0.61622, 0.03884, 0.01028, 0.0079, 0.0078],
  'loss': [38.63741221651435,
   106.57205509021878,
   131.6942581795156,
   130.46310277096927,
   131.80611641332507]},
 'features.4.conv.1.1': {'top1': [0.38188, 0.0451, 0.00462, 0.00156, 0.00108],
  'top5': [0.61692, 0.11048, 0.01654, 0.00704, 0.00516],
  'loss': [38.688165878877044,
   90.7589758131653,
   114.93547889590263,
   107.02099776

In [16]:
# Calculate layer sensivity
top1 = 71.1
top5 = 89.956
scales = {}
for i, bn in enumerate(data.keys()):
    layer_sensivity = data[bn]["top5"][0] - data[bn]["top5"][-1] + (top5 / 100 - data[bn]["top5"][0])
    scales[bn] = np.array(layer_sensivity) * std_dev[i]

In [18]:
scales

{'features.1.conv.0.1': 0.290224070250392,
 'features.2.conv.1.1': 0.10710066755533218,
 'features.3.conv.1.1': 0.05456456013649702,
 'features.4.conv.1.1': 0.0708910028219223,
 'features.5.conv.1.1': 0.0320092864985764,
 'features.6.conv.1.1': 0.020639183368161326,
 'features.7.conv.1.1': 0.037714954476952556,
 'features.8.conv.1.1': 0.03280501537919045,
 'features.9.conv.1.1': 0.0028062857322394873,
 'features.10.conv.1.1': 0.005678077998608354,
 'features.11.conv.1.1': 0.04911940982975065,
 'features.12.conv.1.1': 0.01889392982363701,
 'features.13.conv.1.1': 0.03535361602328718,
 'features.14.conv.1.1': 0.047462662765383724,
 'features.15.conv.1.1': 0.027049223961234093,
 'features.16.conv.1.1': 0.02664218142554164,
 'features.17.conv.1.1': 0.0418163540995121}

In [25]:
all_gammas_scaled = torch.cat([module.weight.flatten() / scales[name] for name, module in model.named_modules() if
                                isinstance(module, nn.BatchNorm2d) and name in batch_norms])

prune_target = int(all_gammas.size(0) * pruning_ratio)  
threshold_scaled = torch.topk(all_gammas_scaled, prune_target, largest=False).values[-1]
threshold_scaled

tensor(3.3956, device='cuda:0', grad_fn=<SelectBackward0>)

In [21]:
print(all_gammas_scaled)

tensor([0.0075, 0.1597, 0.2076,  ..., 0.0059, 0.0096, 0.0089], device='cuda:0',
       grad_fn=<CatBackward0>)

In [None]:
thresholds = {}
for i, bn in enumerate(batch_norms):
    thresholds[bn] = threshold * scale[i]

In [15]:
import numpy as np
for name, module in model.named_modules():
    if isinstance(module, nn.BatchNorm2d) and name in batch_norms:
        print(torch.std(module.weight))
        print(np.std(module.weight.detach().cpu().numpy()))
        

tensor(0.3298, device='cuda:0', grad_fn=<StdBackward0>)
0.3246427
tensor(0.1204, device='cuda:0', grad_fn=<StdBackward0>)
0.11976993
tensor(0.0614, device='cuda:0', grad_fn=<StdBackward0>)
0.061187495
tensor(0.0795, device='cuda:0', grad_fn=<StdBackward0>)
0.07926096
tensor(0.0613, device='cuda:0', grad_fn=<StdBackward0>)
0.061147105
tensor(0.0565, device='cuda:0', grad_fn=<StdBackward0>)
0.056345027
tensor(0.0423, device='cuda:0', grad_fn=<StdBackward0>)
0.042168826
tensor(0.0449, device='cuda:0', grad_fn=<StdBackward0>)
0.044810697
tensor(0.0459, device='cuda:0', grad_fn=<StdBackward0>)
0.04583936
tensor(0.0453, device='cuda:0', grad_fn=<StdBackward0>)
0.04520043
tensor(0.0551, device='cuda:0', grad_fn=<StdBackward0>)
0.054986466
tensor(0.0472, device='cuda:0', grad_fn=<StdBackward0>)
0.047131136
tensor(0.0497, device='cuda:0', grad_fn=<StdBackward0>)
0.049686056
tensor(0.0532, device='cuda:0', grad_fn=<StdBackward0>)
0.05315563
tensor(0.0513, device='cuda:0', grad_fn=<StdBackward0>)

In [359]:
prune_target

713

In [202]:
with open( "./test.txt", "w") as f:
    for v in torch.topk(all_gammas, prune_target, largest=False).values:
        f.writelines(f"{str(v.item())}\n")

In [358]:
threshold_.item()

0.1308184266090393

In [252]:
batch_norms = []
for i in range(1, 18):
    if i == 1:
        batch_norms.append(f'features.{i}.conv.0.1')
        continue    
    batch_norms.append(f'features.{i}.conv.1.1')

In [176]:
batch_norms = [batch_norms[16]]
batch_norms

['features.17.conv.1.1']

In [31]:

layer_list = list(model.named_modules())  # Flatten the model's layers for traversal
num_gamma = {}
num_gamma_pruned = {}
pruned_channels = {}

# scale = 10**7  # Scaling factor for 7 decimals
# print(threshold.item())
# threshold = np.ceil(threshold.detach().cpu().numpy() * scale) / scale
# print(threshold.item())
# threshold = threshold_.detach().cpu().numpy()
# print(threshold + 1e-5)

# Step 1: Analyze BatchNorm scaling factors and determine channels to prune
for name, module in model.named_modules():
    if isinstance(module, nn.BatchNorm2d) and name in batch_norms:
        # Get the scale (\gamma) values
        # gamma = module.weight.detach().cpu().numpy()
        gamma = module.weight.data
        
        # print(module.weight.detach().cpu().numpy().dtype)
        # print(threshold.dtype)

    
        # keep_indices = np.array(np.where(gamma > threshold + 1e-2)[0])
        # keep_indices = torch.where(gamma > threshold_)[0]
        keep_indices = torch.where(gamma / scales[name] > threshold_scaled)[0]
        pruned_channels[name] = keep_indices
        
        num_gamma[name] = len(gamma)
        num_gamma_pruned[name] = len(keep_indices)
        


In [347]:
 x = torch.randn(3).to(device="cuda")
x

tensor([ 1.6179, -1.2895, -0.0481], device='cuda:0')

In [350]:
torch.where(x > -1)

(tensor([0, 2], device='cuda:0'),)

In [302]:
gamma.argsort()

tensor([ 33, 605, 862, 874, 878, 395,  95, 261,  70, 850, 125,  11, 730, 910,
        325, 379, 471, 867,  69, 724, 679, 865,  21, 295, 113, 509,  32, 407,
        404, 436, 313, 799, 489, 921, 663, 362, 219, 123,  56, 660, 369, 189,
        691, 410, 378, 954, 637, 228, 882, 831,  66, 580, 280,  34, 417, 312,
          0, 665, 197, 524, 263, 202, 796, 365, 492, 411, 773, 816, 186, 221,
        277, 196, 815, 421, 842, 291, 833, 560, 735, 547, 134, 669, 614, 621,
        888, 401, 783, 683, 616, 641, 324, 806, 578, 629, 301, 931, 334, 940,
        685, 615, 155, 220, 484, 454, 533, 595, 264, 562, 430, 628, 824, 511,
        104,  10, 287, 772, 686, 351, 439, 464, 895, 327, 883, 565, 106, 870,
        306, 444, 754, 384, 363, 892, 881, 250, 915, 482, 793, 249, 266, 809,
        357, 167, 262, 434, 415, 244, 341, 548, 703, 494, 788,  94, 420,  68,
         50, 294, 323, 160, 105, 210, 690, 141, 610, 800,  28, 292, 936,  87,
         38, 452, 391, 869, 938, 635, 222, 721,  91, 748, 449, 2

In [313]:
gamma[509]

tensor(0.1065, device='cuda:0')

In [None]:
0.10553279519081116
0.10554493218660355
0.10554537177085876

In [351]:
threshold_.item()

0.10554537177085876

In [352]:
test = torch.tensor([0.10554493218660355], device="cuda:0")


In [354]:
test > threshold_

tensor([False], device='cuda:0')

In [298]:
print(len(gamma))
print(len(keep_indices))

960
936


In [30]:
scales.values()

dict_values([0.290224070250392, 0.10710066755533218, 0.05456456013649702, 0.0708910028219223, 0.0320092864985764, 0.020639183368161326, 0.037714954476952556, 0.03280501537919045, 0.0028062857322394873, 0.005678077998608354, 0.04911940982975065, 0.01889392982363701, 0.03535361602328718, 0.047462662765383724, 0.027049223961234093, 0.02664218142554164, 0.0418163540995121])

In [33]:
print(sum(list(num_gamma.values())))
print(sum(list(num_gamma_pruned.values())))


7136
6423


In [32]:
print(list(num_gamma.values()))
print(list(num_gamma_pruned.values()))


[32, 96, 144, 144, 192, 192, 192, 384, 384, 384, 384, 576, 576, 576, 960, 960, 960]
[3, 27, 124, 104, 187, 191, 191, 368, 384, 383, 314, 569, 530, 385, 960, 956, 747]


In [364]:
print(sum(list(num_gamma.values())))
print(sum(list(num_gamma_pruned.values())))


7136
6423


In [365]:
print(list(num_gamma.values()))
print(list(num_gamma_pruned.values()))


[32, 96, 144, 144, 192, 192, 192, 384, 384, 384, 384, 576, 576, 576, 960, 960, 960]
[23, 94, 137, 144, 181, 178, 191, 344, 317, 313, 378, 514, 501, 519, 896, 875, 818]


In [363]:
100 - (6423/7136 * 100 )

9.991591928251125

In [394]:
pruned = []
for i in range(len(list(num_gamma.values()))):
    gamma_pruned = list(num_gamma_pruned.values())[i]
    gamma = list(num_gamma.values())[i]
    perc = 100 - (gamma_pruned/ gamma * 100)
    pruned.append(f"{perc:.2f}")


In [395]:
print(pruned)

['28.12', '2.08', '4.86', '0.00', '5.73', '7.29', '0.52', '10.42', '17.45', '18.49', '1.56', '10.76', '13.02', '9.90', '6.67', '8.85', '14.79']


In [122]:
100 - (950/960 * 100 )

1.0416666666666572

In [80]:

# Step 2: Prune layers
def prune_layer(layer, prune_indices, is_input=False):
    if isinstance(layer, nn.Conv2d):
        if is_input:
            # Prune input channels
            weight = layer.weight.detach().cpu()
            new_weight = weight[:, torch.tensor(prune_indices)].clone().to(device=DEVICE)
            
            layer.in_channels = new_weight.size(1)
            layer.weight = nn.Parameter(new_weight).to(device=DEVICE)
        else:
            # Prune output channels
            weight = layer.weight.detach().cpu()
            new_weight = weight[torch.tensor(prune_indices)].clone().to(device=DEVICE)
            
            layer.out_channels = new_weight.size(0)
            layer.weight = nn.Parameter(new_weight).to(device=DEVICE)
            
            # Adjust the 'groups' parameter if it's a depthwise convolution
            if layer.groups == layer.in_channels:
                layer.groups = new_weight.size(0)
                layer.in_channels = new_weight.size(0)
                
    
    elif isinstance(layer, nn.BatchNorm2d):
        # Prune BatchNorm parameters
        layer.weight = nn.Parameter(layer.weight.detach()[torch.tensor(prune_indices)].clone()).to(device=DEVICE)
        layer.bias = nn.Parameter(layer.bias.detach()[torch.tensor(prune_indices)].clone()).to(device=DEVICE)
        layer.running_mean = layer.running_mean.detach()[torch.tensor(prune_indices)].clone().to(device=DEVICE)
        layer.running_var = layer.running_var.detach()[torch.tensor(prune_indices)].clone().to(device=DEVICE)
        
        layer.num_features = layer.weight.size(0)

# Traverse the model and prune connected layers
for i, (name, module) in enumerate(layer_list):
    if name in pruned_channels:
        prune_indices = pruned_channels[name]
        
        # Prune the current BatchNorm layer
        prune_layer(module, prune_indices)
        
        # Prune the preceding Conv2d (output channels)
        if i > 0:
            prev_name, prev_module = layer_list[i - 1]
            if isinstance(prev_module, nn.Conv2d):
                prune_layer(prev_module, prune_indices, is_input=False)
                if prev_module.groups == prev_module.in_channels:
                    j = i - 2
                    while j > 0:
                        prev_name, prev_module = layer_list[j]
                        if isinstance(prev_module, nn.BatchNorm2d):
                            prune_layer(prev_module, prune_indices)
                            prev_name, prev_module = layer_list[j-1]
                            if isinstance(prev_module, nn.Conv2d):
                                prune_layer(prev_module, prune_indices, is_input=False)
                                break
                        j -= 1
                        

        # Prune the following Conv2d (input channels)
        j = i
        # if i < len(layer_list) - 1:
        #     next_name, next_module = layer_list[i + 2]      # Next conv2d comes after ReLU6 activation layer
        #     if isinstance(next_module, nn.Conv2d):
        #         prune_layer(next_module, prune_indices, is_input=True)
        while j < len(layer_list) - 1:
            next_name, next_module = layer_list[j + 1]
            if isinstance(next_module, nn.Conv2d):
                prune_layer(next_module, prune_indices, is_input=True)
                break
            j += 1


[array([0.0381208 , 0.18724878, 0.19752091, 0.24511185, 0.13127756,
        0.15904109, 0.08813592, 0.25522655, 0.08698447, 0.01191824,
        0.4129336 , 0.11366259, 0.22449361, 0.30144963, 0.01136814,
        0.01036511, 0.01283585, 0.00430115, 0.06677727, 0.41075766,
        0.3577698 , 0.12777077, 0.52496415, 0.00393305, 0.44444874,
        0.12979506, 0.32841888, 0.24531788, 0.35647354, 0.30660468,
        0.4146144 , 0.14192693], dtype=float32),
 array([0.02575562, 0.5501319 , 0.71535563, 0.4909047 , 1.1533738 ,
        0.49047863, 0.5998267 , 0.8479905 , 0.5139333 , 0.00609509,
        0.210984  , 1.0502005 , 0.6857423 , 0.2375618 , 0.00594302,
        0.00494451, 0.00846775, 0.00503939, 0.573383  , 0.19520941,
        0.2612262 , 0.28118408, 0.2269232 , 0.0084268 , 0.08874259,
        1.0205832 , 0.27261943, 0.26077464, 0.12849563, 0.17686303,
        0.20431545, 0.40291643], dtype=float32),
 array([0.5125795 , 0.5345392 , 0.5720613 , 0.4572209 , 0.56945914,
        0.61156434