In [2]:
import os, sys
import torch
import torch.nn.functional as F

from torchvision.datasets import ImageFolder
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode
import torch.nn as nn
import timm
import torch_pruning as tp

import pbench
pbench.forward_patch.patch_timm_forward() # patch timm.forward() to support pruning

from tqdm import tqdm
import argparse

import torchvision as tv
import math

In [3]:
from torch.utils.data import Subset
import random

def prepare_imagenet(imagenet_root, train_batch_size=64, val_batch_size=128, num_workers=4, use_imagenet_mean_std=True, interpolation='bicubic', val_resize=256):
    """The imagenet_root should contain train and val folders.
    """
    interpolation = getattr(T.InterpolationMode, interpolation.upper())

    print('Parsing dataset...')
    train_dst = ImageFolder(os.path.join(imagenet_root, 'train'), 
                            transform=pbench.data.presets.ClassificationPresetEval(
                                mean=[0.485, 0.456, 0.406] if use_imagenet_mean_std else [0.5, 0.5, 0.5],
                                std=[0.229, 0.224, 0.225] if use_imagenet_mean_std else [0.5, 0.5, 0.5],
                                crop_size=224,
                                resize_size=val_resize,
                                interpolation=interpolation,
                            )
    )
    val_dst = ImageFolder(os.path.join(imagenet_root, 'val'), 
                          transform=pbench.data.presets.ClassificationPresetEval(
                                mean=[0.485, 0.456, 0.406] if use_imagenet_mean_std else [0.5, 0.5, 0.5],
                                std=[0.229, 0.224, 0.225] if use_imagenet_mean_std else [0.5, 0.5, 0.5],
                                crop_size=224,
                                resize_size=val_resize,
                                interpolation=interpolation,
                            )
    )
    train_loader = torch.utils.data.DataLoader(train_dst, batch_size=train_batch_size, shuffle=True, num_workers=num_workers)
    val_loader = torch.utils.data.DataLoader(val_dst, batch_size=val_batch_size, shuffle=False, num_workers=num_workers)

    cal_subset_size = 2048
    cal_indices = random.sample(range(len(val_loader)*val_batch_size), cal_subset_size)
    cal_dst = Subset(val_dst, cal_indices)
    cal_loader = torch.utils.data.DataLoader(cal_dst, batch_size=val_batch_size, shuffle=False, num_workers=num_workers)
    
    return train_loader, val_loader, cal_loader

In [4]:
import gc

def validate_model(model, val_loader, device):
    model.eval()
    model.to(device)
    correct = 0
    loss = 0
    with torch.no_grad():
        for k, (images, labels) in enumerate(tqdm(val_loader)):  
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss += torch.nn.functional.cross_entropy(outputs, labels, reduction='sum').item()
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()
    
    del images, outputs, predicted
    gc.collect()
    return correct / len(val_loader.dataset), loss / len(val_loader.dataset)

In [5]:
data_path ='data/imagenet'
train_batch_size = 64
val_batch_size = 4
no_imagenet_mean_std = False
val_resize = 256
interpolation = 'bilinear' #'bicubic' 

model = 'mobilenet_v2'  #'resnet101.tv_in1k' #'convnext_base.fb_in1k' 
is_torchvision = True
drop = 0.0
drop_path = 0.0
ckpt = None

taylor_batchs = 512

In [6]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
example_inputs = torch.randn(1,3,224,224)
train_loader, val_loader, cal_loader = prepare_imagenet(data_path, train_batch_size=train_batch_size, val_batch_size=val_batch_size, use_imagenet_mean_std= not no_imagenet_mean_std, val_resize=val_resize, interpolation=interpolation)

if is_torchvision:
        import torchvision
        print(f"Loading torchvision model {model}...")
        model = torchvision.models.__dict__[model](pretrained=True).eval()
else:
    print(f"Loading timm model {model}...")
    model = timm.create_model(model, pretrained=True, drop_rate=drop, drop_path_rate=drop_path).eval()

if ckpt is not None:
    print(f"Loading checkpoint from {ckpt}...")
    ckpt = torch.load(ckpt, map_location='cpu')
    model.load_state_dict(ckpt['model'])
    

Parsing dataset...
Loading torchvision model mobilenet_v2...




In [6]:
def get_resnet_blocks(model):

    model_blocks, ignored_blocks = [], []

    for child in model.children():
        if isinstance(child, nn.Sequential):
            model_blocks.extend([sub_child for sub_child in child.children()])
        else:
            ignored_blocks.append(child)

    return model_blocks, ignored_blocks


In [7]:
def get_mobilenet_blocks(model):

    model_blocks, ignored_blocks = [], []

    for feat in model.features:
        if not isinstance(feat, nn.Sequential):
            model_blocks.append(feat)
        else:
            ignored_blocks.append(feat)

    ignored_blocks.append(model.classifier)

    return model_blocks, ignored_blocks


In [8]:
def get_convnext_blocks(model):

      model_blocks, ignored_blocks = [], []

      for stage in model.stages:
            for block in stage.blocks:
                  if isinstance(block, nn.Module):
                        model_blocks.append(block)
            model_blocks.append(stage.downsample)

      ignored_blocks.extend([model.stem, model.norm_pre, model.head])

      return model_blocks, ignored_blocks

In [9]:
def get_model_blocks(model):
    if model.__class__.__name__ == "MobileNetV2":
        return get_mobilenet_blocks(model)
    elif model.__class__.__name__ == "ResNet":
        return get_resnet_blocks(model)
    elif model.__class__.__name__ == "ConvNeXt":
        return get_convnext_blocks(model)


In [10]:
import copy
import warnings

def selective_block_pruning(trained_model, prune_method, pruning_ratios, data_loader, device):
    

    model = copy.deepcopy(trained_model)
    model_blocks, ignored_layers = get_model_blocks(model)
    model.to(device)

    pruning_info = {i: {"block": model_blocks[i], "pruning_ratio": ratio} for i, ratio in enumerate(pruning_ratios)}

    if prune_method == "Taylor":
        imp = tp.importance.TaylorImportance()

        if isinstance(imp, (tp.importance.GroupTaylorImportance, tp.importance.GroupHessianImportance)):
            model.zero_grad()
            if isinstance(imp, tp.importance.GroupHessianImportance):
                imp.zero_grad() # clear the accumulated gradients
            print("Accumulating gradients for pruning...")
            for k, (imgs, lbls) in enumerate(tqdm(data_loader)):
                if k>=taylor_batchs: break
                imgs = imgs.to(device)
                lbls = lbls.to(device)
                output = model(imgs)
                if isinstance(imp, tp.importance.GroupHessianImportance): # per-sample gradients for hessian
                    loss = torch.nn.functional.cross_entropy(output, lbls, reduction='none')
                    for l in loss:
                        model.zero_grad()
                        l.backward(retain_graph=True)
                        imp.accumulate_grad(model) # accumulate gradients
                elif isinstance(imp, tp.importance.GroupTaylorImportance): # batch gradients for first-order taylor
                    loss = torch.nn.functional.cross_entropy(output, lbls)
                    loss.backward()

    elif prune_method == "Magnitude":
        imp = tp.importance.MagnitudeImportance()

    else:
        warnings.warn(f"Invalid pruning method: '{prune_method}'. Expected 'Taylor' or 'Magnitude'.", UserWarning)
        raise ValueError("Pruning method must be either 'Taylor' or 'Magnitude'.")

    _, original_nparams = tp.utils.count_ops_and_params(model, imgs)

    for i, info in pruning_info.items():
        _, pruning_ratio = info["block"], info["pruning_ratio"]
        if pruning_ratio == 0.0:
            continue

        ignored_layers_block = [pruning_info[j]["block"] for j in range(len(pruning_info)) if j != i]
        combined_ignored_layers = ignored_layers + ignored_layers_block

        count = 0

        while True:

            pruner_group = tp.pruner.MagnitudePruner(
                model,
                example_inputs=imgs,
                importance=imp,
                pruning_ratio=pruning_ratio,
                ignored_layers=combined_ignored_layers,
            )
            pruner_group.step()

            macs, nparams = tp.utils.count_ops_and_params(model, imgs)
            if original_nparams - nparams == 0:
                count += 1
                if count > 1:
                    break
                pruning_ratio = 0.5

            original_nparams = nparams

    del imgs, lbls, output
    torch.cuda.empty_cache()


    return model, macs, nparams


In [11]:
import logging
import numpy as np


def perplexity_analysis_with_contributions(original_model, prune_method, data_loader, device):

    
    model_blocks, ignored_layers = get_model_blocks(original_model)
    blocks_number = len(model_blocks)

    total_block_accuracy = [0.0 for _ in range(blocks_number)]
    params_reduction = []
    macs_reduction = []

    # logging.info("\n=== Computing baseline accuracy without block replacement ===")
    print("\n=== Computing baseline accuracy without block replacement ===")
    baseline_accuracy, baseline_loss = validate_model(original_model, data_loader, device='cpu')
    # logging.info(f"Baseline accuracy: {baseline_accuracy*100:.2f}%")
    print(f"Baseline accuracy: {baseline_accuracy*100:.2f}%")
    
    input_size = [3, 224, 224]
    example_inputs = torch.randn(1, *input_size)
    original_macs, original_nparams = tp.utils.count_ops_and_params(original_model, example_inputs)


    for block_idx in range(blocks_number):

        # logging.info("\n=== Replacing Blocks and Tracking Reductions ===")
        print("\n=== Replacing Blocks and Tracking Reductions ===")

        pruning_ratios = (np.eye(blocks_number) * 0.8)[block_idx]

        pruned_model, macs, nparams = selective_block_pruning(
        original_model, prune_method, pruning_ratios, data_loader, device
        )

        # logging.info(f"\nReplacing Block {block_idx}:")
        # logging.info(f"  - MACs Reduction: {((original_macs - macs) / original_macs * 100):.2f}%")
        # logging.info(f"  - Parameters Reduction: {((original_nparams - nparams)/original_nparams*100):.2f}%")
        print(f"\nReplacing Block {block_idx}:")
        print(f"  - MACs Reduction: {((original_macs - macs) / original_macs * 100):.2f}%")
        print(f"  - Parameters Reduction: {((original_nparams - nparams)/original_nparams*100):.2f}%")
        

        params_reduction.append(original_nparams - nparams)
        macs_reduction.append(original_macs - macs)

        pruned_model.to(device)
        pruned_model.eval()
        
        pruned_accuracy,_ = validate_model(pruned_model, data_loader, device='cpu')

        del pruned_model
        torch.cuda.empty_cache()

        # logging.info(f"Accuracy After Pruning This Block: {pruned_accuracy*100:.2f}%\n")
        print(f"Accuracy After Pruning This Block: {pruned_accuracy*100:.2f}%\n")

        total_block_accuracy[block_idx] += pruned_accuracy

    total_accuracy_reduction = 0.0
    block_reductions = []
    total_params_reduction = 0.0
    total_macs_reduction = 0.0

    for block_idx in range(blocks_number):
        final_average_accuracy = total_block_accuracy[block_idx]
        accuracy_reduction = baseline_accuracy - final_average_accuracy
        if accuracy_reduction < 0:
            print(f"Block {block_idx} improved accuracy by {-accuracy_reduction*100:.2f}% — treated as 0 for importance.")
            accuracy_reduction = 0.0
        block_reductions.append(accuracy_reduction)
        total_accuracy_reduction += accuracy_reduction
        total_params_reduction += params_reduction[block_idx]
        total_macs_reduction += macs_reduction[block_idx] 

    weighted_importance_scores = []

    # logging.info("\n=== Relative Contribution of Each Block ===")
    print("\n=== Relative Contribution of Each Block ===")
    for block_idx in range(blocks_number):
        if total_accuracy_reduction == 0:
            total_accuracy_reduction = 1e-8  # avoid division by zero
        relative_contribution_accuracy = (block_reductions[block_idx] / total_accuracy_reduction) * 100
        relative_contribution_params = (1 - (params_reduction[block_idx] / total_params_reduction)) * 100
        relative_contribution_macs = (1 - (macs_reduction[block_idx] / total_macs_reduction)) * 100

        weight_accuracy = 0.5
        weight_params = 0.3
        weight_macs = 0.2
        weighted_importance = (weight_accuracy * relative_contribution_accuracy) 
        + (weight_params * relative_contribution_params) 
        + (weight_macs * relative_contribution_macs)

        # logging.info(
        # f"Block {block_idx}: Accuracy Decrease Contribution = {relative_contribution_accuracy:.2f}%, "
        # f"Parameter Reduction = {100 - relative_contribution_params:.2f}%, "
        # f"MACs Reduction = {100 - relative_contribution_macs:.2f}%, "
        # f"Weighted Importance Score = {weighted_importance:.2f}"
        # )  
        print(f"Block {block_idx}: Accuracy Decrease Contribution = {relative_contribution_accuracy:.2f}%, ")
        print(f"Parameter Reduction = {100 - relative_contribution_params:.2f}%, ")
        print(f"MACs Reduction = {100 - relative_contribution_macs:.2f}%, ")
        print(f"Weighted Importance Score = {weighted_importance:.2f}") 

        epsilon = 1e-1
        if (abs(100 - relative_contribution_params) < epsilon) and (abs(100 - relative_contribution_macs) < epsilon):
            weighted_importance = -1
            print(f"Block {block_idx} is unprunable")

        weighted_importance_scores.append(weighted_importance)

    return weighted_importance_scores

In [14]:
# import logging
# import numpy as np


# def perplexity_analysis_with_contributions(original_model, prune_method, data_loader, device):

    
#     model_blocks, ignored_layers = get_model_blocks(original_model)
#     blocks_number = len(model_blocks)

#     total_block_accuracy = [0.0 for _ in range(blocks_number)]
#     params_reduction = []
#     macs_reduction = []

#     # logging.info("\n=== Computing baseline accuracy without block replacement ===")
#     print("\n=== Computing baseline accuracy without block replacement ===")
#     baseline_accuracy, baseline_loss = validate_model(original_model, data_loader, device='cpu')
#     # logging.info(f"Baseline accuracy: {baseline_accuracy*100:.2f}%")
#     print(f"Baseline accuracy: {baseline_accuracy*100:.2f}%")
    
#     input_size = [3, 224, 224]
#     example_inputs = torch.randn(1, *input_size).to(device)
#     original_macs, original_nparams = tp.utils.count_ops_and_params(original_model, example_inputs)


#     # logging.info("\n=== Replacing Blocks and Tracking Reductions ===")
#     print("\n=== Replacing Blocks and Tracking Reductions ===")

#     # pruning_ratios = (np.eye(blocks_number) * 0.8)[block_idx]

#     pruned_acc, pruned_macs, pruned_params = selective_block_pruning(
#     original_model, prune_method, data_loader, device
#     )

#     # logging.info(f"\nReplacing Block {block_idx}:")
#     # logging.info(f"  - MACs Reduction: {((original_macs - macs) / original_macs * 100):.2f}%")
#     # logging.info(f"  - Parameters Reduction: {((original_nparams - nparams)/original_nparams*100):.2f}%")

#     params_reduction = [original_nparams - param for param in pruned_params]
#     macs_reduction = [original_macs - macs for macs in pruned_macs]

#     # pruned_model.to(device)
#     # pruned_model.eval()
    
#     # pruned_accuracy,_ = validate_model(pruned_model, data_loader, device='cpu')

#     # del pruned_model
#     # torch.cuda.empty_cache()

#     # logging.info(f"Accuracy After Pruning This Block: {nacc*100:.2f}%\n")
#     total_block_accuracy = pruned_acc

#     total_accuracy_reduction = 0.0
#     block_reductions = []
#     total_params_reduction = 0.0
#     total_macs_reduction = 0.0

#     for block_idx in range(blocks_number):
#         final_average_accuracy = total_block_accuracy[block_idx]
#         accuracy_reduction = baseline_accuracy - final_average_accuracy
#         block_reductions.append(accuracy_reduction)
#         total_accuracy_reduction += accuracy_reduction
#         total_params_reduction += params_reduction[block_idx]
#         total_macs_reduction += macs_reduction[block_idx] 

#     relative_contributions = []
#     weighted_importance_scores = []

#     # logging.info("\n=== Relative Contribution of Each Block ===")
#     print("\n=== Relative Contribution of Each Block ===")
#     for block_idx in range(blocks_number):
#         relative_contribution_accuracy = (block_reductions[block_idx] / total_accuracy_reduction) * 100
#         relative_contribution_params = (1 - (params_reduction[block_idx] / total_params_reduction)) * 100
#         relative_contribution_macs = (1 - (macs_reduction[block_idx] / total_macs_reduction)) * 100

#         weight_accuracy = 0.5
#         weight_params = 0.3
#         weight_macs = 0.2
#         weighted_importance = (weight_accuracy * relative_contribution_accuracy) 
#         + (weight_params * relative_contribution_params) 
#         + (weight_macs * relative_contribution_macs)

#         logging.info(
#         f"Block {block_idx}: Accuracy Decrease Contribution = {relative_contribution_accuracy:.2f}%, "
#         f"Parameter Reduction = {100 - relative_contribution_params:.2f}%, "
#         f"MACs Reduction = {100 - relative_contribution_macs:.2f}%, "
#         f"Weighted Importance Score = {weighted_importance:.2f}"
#         )  
#         print(f"Block {block_idx}: Accuracy Decrease Contribution = {relative_contribution_accuracy:.2f}%, ")
#         print(f"Parameter Reduction = {100 - relative_contribution_params:.2f}%, ")
#         print(f"MACs Reduction = {100 - relative_contribution_macs:.2f}%, ")
#         print(f"Weighted Importance Score = {weighted_importance:.2f}")

#         relative_contributions.append(relative_contribution_accuracy)
#         weighted_importance_scores.append(weighted_importance)

#     return weighted_importance_scores

In [12]:
def prune_model(trained_model, prune_method, pruning_ratios, data_loader, device):
   
    # Make a copy of the trained model
    model = copy.deepcopy(trained_model)
    model.to(device)

    model_blocks, ignored_layers = get_model_blocks(model)


    pruning_info = {i: {"block": model_blocks[i], "pruning_ratio": ratio} 
                    for i, ratio in enumerate(pruning_ratios)}

    if prune_method == 'Taylor':
        imp = tp.importance.TaylorImportance() 
        
        if isinstance(imp, (tp.importance.GroupTaylorImportance, tp.importance.GroupHessianImportance)):
            model.zero_grad()
            if isinstance(imp, tp.importance.GroupHessianImportance):
                imp.zero_grad() # clear the accumulated gradients
            print("Accumulating gradients for pruning...")
            for k, (imgs, lbls) in enumerate(tqdm(data_loader)):
                if k>=taylor_batchs: break
                imgs = imgs.to(device)
                lbls = lbls.to(device)
                output = model(imgs)
                if isinstance(imp, tp.importance.GroupHessianImportance): # per-sample gradients for hessian
                    loss = torch.nn.functional.cross_entropy(output, lbls, reduction='none')
                    for l in loss:
                        model.zero_grad()
                        l.backward(retain_graph=True)
                        imp.accumulate_grad(model) # accumulate gradients
                elif isinstance(imp, tp.importance.GroupTaylorImportance): # batch gradients for first-order taylor
                    loss = torch.nn.functional.cross_entropy(output, lbls)
                    loss.backward()

        original_macs, original_params = tp.utils.count_ops_and_params(model, imgs)

        # Prune each block while ignoring other layers
        for i, info in pruning_info.items():
            pruning_ratio = info["pruning_ratio"]
            
            # Add all blocks to the ignored layers except the block being pruned
            ignored_layers_block = [pruning_info[j]["block"] for j in range(len(pruning_info)) if j != i]

            # Combine fixed ignored layers (conv_stem, bn1, classifier) with the ignored blocks
            combined_ignored_layers = ignored_layers + ignored_layers_block

            # Apply pruning using the combined ignored layers
            pruner_group = tp.pruner.MagnitudePruner( 
                model,
                example_inputs=imgs,
                importance=imp,
                pruning_ratio=pruning_ratio,
                ignored_layers=combined_ignored_layers,
                iterative_steps=1,
            )

            # Step through pruning
            pruner_group.step()

    # Counting MACs and Params after pruning
    macs, nparams = tp.utils.count_ops_and_params(model, imgs)

    # logging.info(f"MACs of the Pruned Model: {macs/ 1e9} G")
    print(f"MACs of the Original Model: {original_macs/ 1e9} G  -->  MACs of the Pruned Model: {macs/ 1e9} G")
    # logging.info(f"# Parameters of the Pruned Model: {nparams/ 1e3} K")
    print(f"# Parameters of the Original Model: {original_params/ 1e3} K  --> # Parameters of the Pruned Model: {nparams/ 1e3} K")

    param_reduction = ((original_params - nparams) / original_params) * 100
    macs_reduction = ((original_macs - macs) / original_macs) * 100

    # Free up GPU memory
    del imgs, lbls, output
    torch.cuda.empty_cache()

    return model, math.ceil(param_reduction), math.ceil(macs_reduction)

In [13]:
def calculate_pruning_ratios_intense(contributions, max_pruning_ratio=0.9, k=5):
    # Normalize the contributions to get values between 0 and 1
    # total_contribution = sum(contributions)
    # Ignore the unprunable blocks
    total_contribution = sum([contribution for contribution in contributions if contribution != -1])

    normalized_contributions = [contribution / total_contribution if contribution != -1 else contribution for contribution in contributions]

    # Apply exponential decay to magnify the effect for less important blocks
    # Assign zero pruning factor for unprunable blocks
    pruning_factors = [np.exp(-k * nc) if nc != -1 else 0 for nc in normalized_contributions]

    # Normalize the pruning factors so they stay within the max pruning ratio
    max_factor = max(pruning_factors)
    normalized_factors = [pf / max_factor for pf in pruning_factors]

    # Scale by the maximum pruning ratio
    pruning_ratios = [max_pruning_ratio * nf for nf in normalized_factors]

    pruning_ratios = [round(num, 2) for num in pruning_ratios]

    return pruning_ratios

In [14]:
warnings.filterwarnings("ignore")
relative_contribution = perplexity_analysis_with_contributions(model, 'Taylor', cal_loader, device)


=== Computing baseline accuracy without block replacement ===


100%|██████████| 512/512 [00:11<00:00, 43.66it/s]


Baseline accuracy: 69.97%

=== Replacing Blocks and Tracking Reductions ===
Accumulating gradients for pruning...


100%|██████████| 512/512 [00:05<00:00, 87.35it/s]



Replacing Block 0:
  - MACs Reduction: 7.64%
  - Parameters Reduction: 0.06%


100%|██████████| 512/512 [00:10<00:00, 47.67it/s]


Accuracy After Pruning This Block: 0.29%


=== Replacing Blocks and Tracking Reductions ===
Accumulating gradients for pruning...


100%|██████████| 512/512 [00:05<00:00, 90.43it/s]



Replacing Block 1:
  - MACs Reduction: 10.42%
  - Parameters Reduction: 0.14%


100%|██████████| 512/512 [00:08<00:00, 62.77it/s]


Accuracy After Pruning This Block: 0.10%


=== Replacing Blocks and Tracking Reductions ===
Accumulating gradients for pruning...


100%|██████████| 512/512 [00:05<00:00, 90.87it/s]



Replacing Block 2:
  - MACs Reduction: 8.82%
  - Parameters Reduction: 0.25%


100%|██████████| 512/512 [00:10<00:00, 48.28it/s]


Accuracy After Pruning This Block: 0.24%


=== Replacing Blocks and Tracking Reductions ===
Accumulating gradients for pruning...


100%|██████████| 512/512 [00:05<00:00, 90.50it/s]



Replacing Block 3:
  - MACs Reduction: 5.32%
  - Parameters Reduction: 0.28%


100%|██████████| 512/512 [00:11<00:00, 46.34it/s]


Accuracy After Pruning This Block: 0.20%


=== Replacing Blocks and Tracking Reductions ===
Accumulating gradients for pruning...


100%|██████████| 512/512 [00:05<00:00, 88.77it/s]



Replacing Block 4:
  - MACs Reduction: 3.69%
  - Parameters Reduction: 0.42%


100%|██████████| 512/512 [00:11<00:00, 45.37it/s]


Accuracy After Pruning This Block: 15.97%


=== Replacing Blocks and Tracking Reductions ===
Accumulating gradients for pruning...


100%|██████████| 512/512 [00:05<00:00, 98.55it/s] 



Replacing Block 5:
  - MACs Reduction: 3.69%
  - Parameters Reduction: 0.42%


100%|██████████| 512/512 [00:11<00:00, 46.17it/s]


Accuracy After Pruning This Block: 29.39%


=== Replacing Blocks and Tracking Reductions ===
Accumulating gradients for pruning...


100%|██████████| 512/512 [00:05<00:00, 91.19it/s] 



Replacing Block 6:
  - MACs Reduction: 2.53%
  - Parameters Reduction: 0.59%


100%|██████████| 512/512 [00:12<00:00, 41.42it/s]


Accuracy After Pruning This Block: 0.10%


=== Replacing Blocks and Tracking Reductions ===
Accumulating gradients for pruning...


100%|██████████| 512/512 [00:05<00:00, 89.03it/s]



Replacing Block 7:
  - MACs Reduction: 3.35%
  - Parameters Reduction: 1.54%


100%|██████████| 512/512 [00:12<00:00, 40.72it/s]


Accuracy After Pruning This Block: 3.61%


=== Replacing Blocks and Tracking Reductions ===
Accumulating gradients for pruning...


100%|██████████| 512/512 [00:05<00:00, 96.29it/s] 



Replacing Block 8:
  - MACs Reduction: 3.35%
  - Parameters Reduction: 1.54%


100%|██████████| 512/512 [00:12<00:00, 41.08it/s]


Accuracy After Pruning This Block: 62.74%


=== Replacing Blocks and Tracking Reductions ===
Accumulating gradients for pruning...


100%|██████████| 512/512 [00:05<00:00, 97.89it/s] 



Replacing Block 9:
  - MACs Reduction: 3.35%
  - Parameters Reduction: 1.54%


100%|██████████| 512/512 [00:12<00:00, 41.06it/s]


Accuracy After Pruning This Block: 50.15%


=== Replacing Blocks and Tracking Reductions ===
Accumulating gradients for pruning...


100%|██████████| 512/512 [00:05<00:00, 97.42it/s] 



Replacing Block 10:
  - MACs Reduction: 4.10%
  - Parameters Reduction: 1.89%


100%|██████████| 512/512 [00:12<00:00, 41.07it/s]


Accuracy After Pruning This Block: 0.15%


=== Replacing Blocks and Tracking Reductions ===
Accumulating gradients for pruning...


100%|██████████| 512/512 [00:05<00:00, 89.20it/s]



Replacing Block 11:
  - MACs Reduction: 7.29%
  - Parameters Reduction: 3.36%


100%|██████████| 512/512 [00:12<00:00, 40.70it/s]


Accuracy After Pruning This Block: 22.95%


=== Replacing Blocks and Tracking Reductions ===
Accumulating gradients for pruning...


100%|██████████| 512/512 [00:05<00:00, 91.06it/s]



Replacing Block 12:
  - MACs Reduction: 7.29%
  - Parameters Reduction: 3.36%


100%|██████████| 512/512 [00:12<00:00, 41.30it/s]


Accuracy After Pruning This Block: 8.94%


=== Replacing Blocks and Tracking Reductions ===
Accumulating gradients for pruning...


100%|██████████| 512/512 [00:05<00:00, 95.04it/s] 



Replacing Block 13:
  - MACs Reduction: 5.00%
  - Parameters Reduction: 4.41%


100%|██████████| 512/512 [00:12<00:00, 39.69it/s]


Accuracy After Pruning This Block: 0.15%


=== Replacing Blocks and Tracking Reductions ===
Accumulating gradients for pruning...


100%|██████████| 512/512 [00:05<00:00, 89.26it/s]



Replacing Block 14:
  - MACs Reduction: 4.92%
  - Parameters Reduction: 9.11%


100%|██████████| 512/512 [00:12<00:00, 41.58it/s]


Accuracy After Pruning This Block: 18.70%


=== Replacing Blocks and Tracking Reductions ===
Accumulating gradients for pruning...


100%|██████████| 512/512 [00:05<00:00, 92.06it/s]



Replacing Block 15:
  - MACs Reduction: 4.92%
  - Parameters Reduction: 9.11%


100%|██████████| 512/512 [00:12<00:00, 41.17it/s]


Accuracy After Pruning This Block: 11.23%


=== Replacing Blocks and Tracking Reductions ===
Accumulating gradients for pruning...


100%|██████████| 512/512 [00:05<00:00, 88.18it/s]



Replacing Block 16:
  - MACs Reduction: 13.53%
  - Parameters Reduction: 25.17%


100%|██████████| 512/512 [00:12<00:00, 41.19it/s]


Accuracy After Pruning This Block: 0.24%


=== Relative Contribution of Each Block ===
Block 0: Accuracy Decrease Contribution = 7.23%, 
Parameter Reduction = 0.09%, 
MACs Reduction = 7.70%, 
Weighted Importance Score = 3.61
Block 1: Accuracy Decrease Contribution = 7.25%, 
Parameter Reduction = 0.23%, 
MACs Reduction = 10.50%, 
Weighted Importance Score = 3.62
Block 2: Accuracy Decrease Contribution = 7.23%, 
Parameter Reduction = 0.39%, 
MACs Reduction = 8.89%, 
Weighted Importance Score = 3.62
Block 3: Accuracy Decrease Contribution = 7.24%, 
Parameter Reduction = 0.45%, 
MACs Reduction = 5.36%, 
Weighted Importance Score = 3.62
Block 4: Accuracy Decrease Contribution = 5.60%, 
Parameter Reduction = 0.66%, 
MACs Reduction = 3.72%, 
Weighted Importance Score = 2.80
Block 5: Accuracy Decrease Contribution = 4.21%, 
Parameter Reduction = 0.66%, 
MACs Reduction = 3.72%, 
Weighted Importance Score = 2.10
Block 6: Accuracy Decrease Contribution = 7.25%, 
Parameter Reduction = 0.94%, 
MACs

In [None]:
# # ConvNext Blocks Contributions

# relative_contribution = [30.936995153473347,
#  2.827140549273021,
#  5.573505654281099,
#  -1,
#  1.2924071082390953,
#  1.050080775444265,
#  0.48465266558966075,
#  -1,
#  0.0,
#  0.5654281098546041,
#  0.5654281098546041,
#  0.40387722132471726,
#  0.24232633279483037,
#  0.6462035541195477,
#  0.16155088852988692,
#  0.24232633279483037,
#  0.6462035541195477,
#  0.0,
#  0.24232633279483037,
#  0.48465266558966075,
#  0.08077544426494346,
#  0.32310177705977383,
#  0.16155088852988692,
#  0.08077544426494346,
#  0.16155088852988692,
#  0.0,
#  0.40387722132471726,
#  0.40387722132471726,
#  0.24232633279483037,
#  0.08077544426494346,
#  0.5654281098546041,
#  0.0,
#  0.0,
#  0.0,
#  0.24232633279483037,
#  -1,
#  0.0,
#  0.40387722132471726,
#  0.48465266558966075,
#  -1]

In [None]:
# # ResNet Blocks Contributions

# relative_contribution = [6.866310160427807,
#  0.1711229946524064,
#  0.6631016042780749,
#  11.336898395721926,
#  1.0053475935828877,
#  0.4598930481283422,
#  3.3262032085561497,
#  16.06417112299465,
#  0.0,
#  0.09625668449197862,
#  0.0320855614973262,
#  0.2352941176470588,
#  0.0427807486631016,
#  0.1497326203208556,
#  0.27807486631016043,
#  0.2566844919786096,
#  0.19251336898395724,
#  0.35294117647058826,
#  0.2994652406417112,
#  0.1497326203208556,
#  0.0855614973262032,
#  0.1176470588235294,
#  0.1176470588235294,
#  0.0,
#  0.2887700534759358,
#  0.0641711229946524,
#  0.2566844919786096,
#  0.4812834224598931,
#  0.41711229946524064,
#  0.5454545454545455,
#  3.048128342245989,
#  1.6363636363636365,
#  0.9625668449197862]

In [None]:
# # MobileNet Blocks Contributions

# relative_contribution = [3.593584505976698,
#  3.618802642860745,
#  3.5986281333535075,
#  3.62132445654915,
#  2.6327734906945075,
#  2.0628435971150454,
#  3.618802642860745,
#  3.4624501941796537,
#  0.5094063650577496,
#  0.9633328289705957,
#  3.618802642860745,
#  2.4839864830786302,
#  3.1573107378826855,
#  3.6288898976143646,
#  2.6932970192162204,
#  3.1119180914914004,
#  3.6238462702375545]

In [None]:
pruning_ratios = calculate_pruning_ratios_intense(relative_contribution, max_pruning_ratio=0.99, k=5)

In [24]:
pruning_ratios

[np.float64(0.99),
 np.float64(0.99),
 np.float64(0.99),
 np.float64(0.99),
 np.float64(0.99),
 np.float64(0.99),
 np.float64(0.99),
 np.float64(0.99),
 np.float64(0.99),
 np.float64(0.99),
 np.float64(0.99),
 np.float64(0.99),
 np.float64(0.99),
 np.float64(0.99),
 np.float64(0.99),
 np.float64(0.99),
 np.float64(0.99)]

In [25]:
warnings.filterwarnings("ignore")
torch.cuda.empty_cache()

pruned_model, param_reduction, macs_reduction = prune_model(model, 'Taylor', pruning_ratios, cal_loader, device)

Accumulating gradients for pruning...


100%|██████████| 512/512 [00:06<00:00, 82.88it/s]


MACs of the Original Model: 0.320236538 G  -->  MACs of the Pruned Model: 0.061527024 G
# Parameters of the Original Model: 3504.872 K  --> # Parameters of the Pruned Model: 1309.808 K


In [None]:
pruned_accuracy,_ = validate_model(pruned_model, val_loader, device='cuda')
pruned_accuracy

In [26]:
torch.save(pruned_model, f'/home/ict317-3/Mohammad/Isomorphic-Pruning/our_pruned_models/MobileNet/pruned_model_{param_reduction}%.pth')

In [22]:
input_size = [3, 224, 224]
example_inputs = torch.randn(1, *input_size, device=device)

model = torch.load("/home/ict317-3/Mohammad/Isomorphic-Pruning/our_pruned_models/ConvNext4.2/pruned_model_80%.pth", weights_only=False)
model.to(device)

macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)
print(f"MACs = {macs/1e9} G, and parameters count = {nparams/1e6} M")

MACs = 3.563852603 G, and parameters count = 18.574907 M
