In [1]:
import os, sys
import torch
import torch.nn.functional as F

from torchvision.datasets import ImageFolder
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode
import torch.nn as nn
import timm
import torch_pruning as tp

import pbench
pbench.forward_patch.patch_timm_forward() # patch timm.forward() to support pruning

from tqdm import tqdm
import argparse

import torchvision as tv

In [2]:
from torch.utils.data import Subset
import random

def prepare_imagenet(imagenet_root, train_batch_size=64, val_batch_size=128, num_workers=4, use_imagenet_mean_std=True, interpolation='bicubic', val_resize=256):
    """The imagenet_root should contain train and val folders.
    """
    interpolation = getattr(T.InterpolationMode, interpolation.upper())

    print('Parsing dataset...')
    train_dst = ImageFolder(os.path.join(imagenet_root, 'train'), 
                            transform=pbench.data.presets.ClassificationPresetEval(
                                mean=[0.485, 0.456, 0.406] if use_imagenet_mean_std else [0.5, 0.5, 0.5],
                                std=[0.229, 0.224, 0.225] if use_imagenet_mean_std else [0.5, 0.5, 0.5],
                                crop_size=224,
                                resize_size=val_resize,
                                interpolation=interpolation,
                            )
    )
    val_dst = ImageFolder(os.path.join(imagenet_root, 'val'), 
                          transform=pbench.data.presets.ClassificationPresetEval(
                                mean=[0.485, 0.456, 0.406] if use_imagenet_mean_std else [0.5, 0.5, 0.5],
                                std=[0.229, 0.224, 0.225] if use_imagenet_mean_std else [0.5, 0.5, 0.5],
                                crop_size=224,
                                resize_size=val_resize,
                                interpolation=interpolation,
                            )
    )
    train_loader = torch.utils.data.DataLoader(train_dst, batch_size=train_batch_size, shuffle=True, num_workers=num_workers)
    val_loader = torch.utils.data.DataLoader(val_dst, batch_size=val_batch_size, shuffle=False, num_workers=num_workers)

    cal_subset_size = 1024
    cal_indices = random.sample(range(len(val_loader)*val_batch_size), cal_subset_size)
    cal_dst = Subset(val_dst, cal_indices)
    cal_loader = torch.utils.data.DataLoader(cal_dst, batch_size=val_batch_size, shuffle=False, num_workers=num_workers)
    
    return train_loader, val_loader, cal_loader

In [3]:
import gc

def validate_model(model, val_loader, device):
    model.eval()
    model.to(device)
    correct = 0
    loss = 0
    with torch.no_grad():
        for k, (images, labels) in enumerate(tqdm(val_loader)):  
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss += torch.nn.functional.cross_entropy(outputs, labels, reduction='sum').item()
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()
    
    del images, outputs, predicted
    gc.collect()
    return correct / len(val_loader.dataset), loss / len(val_loader.dataset)

In [4]:
data_path ='data/imagenet'
train_batch_size = 64
val_batch_size = 1
no_imagenet_mean_std = False
val_resize = 256
interpolation = 'bilinear' #'bicubic'

model = 'resnet101.tv_in1k' #'mobilenet_v2' #'convnext_base.fb_in1k'
is_torchvision = False
drop = 0.0
drop_path = 0.0
ckpt = None

taylor_batchs = 1024

In [5]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
example_inputs = torch.randn(1,3,224,224)
train_loader, val_loader, cal_loader = prepare_imagenet(data_path, train_batch_size=train_batch_size, val_batch_size=val_batch_size, use_imagenet_mean_std= not no_imagenet_mean_std, val_resize=val_resize, interpolation=interpolation)

if is_torchvision:
        import torchvision
        print(f"Loading torchvision model {model}...")
        model = torchvision.models.__dict__[model](pretrained=True).eval()
else:
    print(f"Loading timm model {model}...")
    model = timm.create_model(model, pretrained=True, drop_rate=drop, drop_path_rate=drop_path).eval()

if ckpt is not None:
    print(f"Loading checkpoint from {ckpt}...")
    ckpt = torch.load(ckpt, map_location='cpu')
    model.load_state_dict(ckpt['model'])
    

Parsing dataset...
Loading timm model resnet101.tv_in1k...


In [6]:
def get_resnet_blocks(model):

    model_blocks, ignored_blocks = [], []

    for child in model.children():
        if isinstance(child, nn.Sequential):
            model_blocks.extend([sub_child for sub_child in child.children()])
        else:
            ignored_blocks.append(child)

    return model_blocks, ignored_blocks


In [7]:
def get_mobilenet_blocks(model):

    model_blocks, ignored_blocks = [], []

    for feat in model.features:
        if not isinstance(feat, nn.Sequential):
            model_blocks.append(feat)
        else:
            ignored_blocks.append(feat)

    ignored_blocks.append(model.classifier)

    return model_blocks, ignored_blocks


In [8]:
def get_convnext_blocks(model):

      model_blocks, ignored_blocks = [], []

      for stage in model.stages:
            for block in stage.blocks:
                  if isinstance(block, nn.Module):
                        model_blocks.append(block)
            model_blocks.append(stage.downsample)

      ignored_blocks.extend([model.stem, model.norm_pre, model.head])

      return model_blocks, ignored_blocks

In [13]:
def get_model_blocks(model):
    if model.__class__.__name__ == "MobileNetV2":
        return get_mobilenet_blocks(model)
    elif model.__class__.__name__ == "ResNet":
        return get_resnet_blocks(model)
    elif model.__class__.__name__ == "ConvNeXt":
        return get_convnext_blocks(model)


In [11]:
import copy
import warnings

def selective_block_pruning(trained_model, prune_method, pruning_ratios, data_loader, device):
    

    model = copy.deepcopy(trained_model)
    model_blocks, ignored_layers = get_model_blocks(model)
    model.to(device)

    pruning_info = {i: {"block": model_blocks[i], "pruning_ratio": ratio} for i, ratio in enumerate(pruning_ratios)}

    if prune_method == "Taylor":
        imp = tp.importance.TaylorImportance()

        if isinstance(imp, (tp.importance.GroupTaylorImportance, tp.importance.GroupHessianImportance)):
            model.zero_grad()
            if isinstance(imp, tp.importance.GroupHessianImportance):
                imp.zero_grad() # clear the accumulated gradients
            print("Accumulating gradients for pruning...")
            for k, (imgs, lbls) in enumerate(tqdm(data_loader)):
                if k>=taylor_batchs: break
                imgs = imgs.to(device)
                lbls = lbls.to(device)
                output = model(imgs)
                if isinstance(imp, tp.importance.GroupHessianImportance): # per-sample gradients for hessian
                    loss = torch.nn.functional.cross_entropy(output, lbls, reduction='none')
                    for l in loss:
                        model.zero_grad()
                        l.backward(retain_graph=True)
                        imp.accumulate_grad(model) # accumulate gradients
                elif isinstance(imp, tp.importance.GroupTaylorImportance): # batch gradients for first-order taylor
                    loss = torch.nn.functional.cross_entropy(output, lbls)
                    loss.backward()

    elif prune_method == "Magnitude":
        imp = tp.importance.MagnitudeImportance()

    else:
        warnings.warn(f"Invalid pruning method: '{prune_method}'. Expected 'Taylor' or 'Magnitude'.", UserWarning)
        raise ValueError("Pruning method must be either 'Taylor' or 'Magnitude'.")

    _, original_nparams = tp.utils.count_ops_and_params(model, imgs)

    for i, info in pruning_info.items():
        _, pruning_ratio = info["block"], info["pruning_ratio"]
        if pruning_ratio == 0.0:
            continue

        ignored_layers_block = [pruning_info[j]["block"] for j in range(len(pruning_info)) if j != i]
        combined_ignored_layers = ignored_layers + ignored_layers_block

        count = 0

        while True:

            pruner_group = tp.pruner.MagnitudePruner(
                model,
                example_inputs=imgs,
                importance=imp,
                pruning_ratio=pruning_ratio,
                ignored_layers=combined_ignored_layers,
            )
            pruner_group.step()

            macs, nparams = tp.utils.count_ops_and_params(model, imgs)
            if original_nparams - nparams == 0:
                count += 1
                if count > 1:
                    break
                pruning_ratio = 0.5

            original_nparams = nparams

    del imgs, lbls, output
    torch.cuda.empty_cache()


    return model, macs, nparams


In [12]:
import logging
import numpy as np


def perplexity_analysis_with_contributions(original_model, prune_method, data_loader, device):

    
    model_blocks, ignored_layers = get_model_blocks(original_model)
    blocks_number = len(model_blocks)

    total_block_accuracy = [0.0 for _ in range(blocks_number)]
    params_reduction = []
    macs_reduction = []

    # logging.info("\n=== Computing baseline accuracy without block replacement ===")
    print("\n=== Computing baseline accuracy without block replacement ===")
    baseline_accuracy, baseline_loss = validate_model(original_model, data_loader, device='cpu')
    # logging.info(f"Baseline accuracy: {baseline_accuracy*100:.2f}%")
    print(f"Baseline accuracy: {baseline_accuracy*100:.2f}%")
    
    input_size = [3, 224, 224]
    example_inputs = torch.randn(1, *input_size)
    original_macs, original_nparams = tp.utils.count_ops_and_params(original_model, example_inputs)


    for block_idx in range(blocks_number):

        # logging.info("\n=== Replacing Blocks and Tracking Reductions ===")
        print("\n=== Replacing Blocks and Tracking Reductions ===")

        pruning_ratios = (np.eye(blocks_number) * 0.8)[block_idx]

        pruned_model, macs, nparams = selective_block_pruning(
        original_model, prune_method, pruning_ratios, data_loader, device
        )

        # logging.info(f"\nReplacing Block {block_idx}:")
        # logging.info(f"  - MACs Reduction: {((original_macs - macs) / original_macs * 100):.2f}%")
        # logging.info(f"  - Parameters Reduction: {((original_nparams - nparams)/original_nparams*100):.2f}%")
        print(f"\nReplacing Block {block_idx}:")
        print(f"  - MACs Reduction: {((original_macs - macs) / original_macs * 100):.2f}%")
        print(f"  - Parameters Reduction: {((original_nparams - nparams)/original_nparams*100):.2f}%")

        params_reduction.append(original_nparams - nparams)
        macs_reduction.append(original_macs - macs)

        pruned_model.to(device)
        pruned_model.eval()
        
        pruned_accuracy,_ = validate_model(pruned_model, data_loader, device='cpu')

        del pruned_model
        torch.cuda.empty_cache()

        logging.info(f"Accuracy After Pruning This Block: {pruned_accuracy*100:.2f}%\n")
        total_block_accuracy[block_idx] += pruned_accuracy

    total_accuracy_reduction = 0.0
    block_reductions = []
    total_params_reduction = 0.0
    total_macs_reduction = 0.0

    for block_idx in range(blocks_number):
        final_average_accuracy = total_block_accuracy[block_idx]
        accuracy_reduction = baseline_accuracy - final_average_accuracy
        block_reductions.append(accuracy_reduction)
        total_accuracy_reduction += accuracy_reduction
        total_params_reduction += params_reduction[block_idx]
        total_macs_reduction += macs_reduction[block_idx] 

    relative_contributions = []
    weighted_importance_scores = []

    # logging.info("\n=== Relative Contribution of Each Block ===")
    print("\n=== Relative Contribution of Each Block ===")
    for block_idx in range(blocks_number):
        relative_contribution_accuracy = (block_reductions[block_idx] / total_accuracy_reduction) * 100
        relative_contribution_params = (1 - (params_reduction[block_idx] / total_params_reduction)) * 100
        relative_contribution_macs = (1 - (macs_reduction[block_idx] / total_macs_reduction)) * 100

        weight_accuracy = 0.5
        weight_params = 0.3
        weight_macs = 0.2
        weighted_importance = (weight_accuracy * relative_contribution_accuracy) 
        + (weight_params * relative_contribution_params) 
        + (weight_macs * relative_contribution_macs)

        # logging.info(
        # f"Block {block_idx}: Accuracy Decrease Contribution = {relative_contribution_accuracy:.2f}%, "
        # f"Parameter Reduction = {100 - relative_contribution_params:.2f}%, "
        # f"MACs Reduction = {100 - relative_contribution_macs:.2f}%, "
        # f"Weighted Importance Score = {weighted_importance:.2f}"
        # )  
        print(f"Block {block_idx}: Accuracy Decrease Contribution = {relative_contribution_accuracy:.2f}%, ")
        print(f"Parameter Reduction = {100 - relative_contribution_params:.2f}%, ")
        print(f"MACs Reduction = {100 - relative_contribution_macs:.2f}%, ")
        print(f"Weighted Importance Score = {weighted_importance:.2f}") 

        relative_contributions.append(relative_contribution_accuracy)
        weighted_importance_scores.append(weighted_importance)

    return weighted_importance_scores

In [12]:
# import logging
# import numpy as np


# def perplexity_analysis_with_contributions(original_model, prune_method, data_loader, device):

    
#     model_blocks, ignored_layers = get_model_blocks(original_model)
#     blocks_number = len(model_blocks)

#     total_block_accuracy = [0.0 for _ in range(blocks_number)]
#     params_reduction = []
#     macs_reduction = []

#     # logging.info("\n=== Computing baseline accuracy without block replacement ===")
#     print("\n=== Computing baseline accuracy without block replacement ===")
#     baseline_accuracy, baseline_loss = validate_model(original_model, data_loader, device='cpu')
#     # logging.info(f"Baseline accuracy: {baseline_accuracy*100:.2f}%")
#     print(f"Baseline accuracy: {baseline_accuracy*100:.2f}%")
    
#     input_size = [3, 224, 224]
#     example_inputs = torch.randn(1, *input_size).to(device)
#     original_macs, original_nparams = tp.utils.count_ops_and_params(original_model, example_inputs)


#     # logging.info("\n=== Replacing Blocks and Tracking Reductions ===")
#     print("\n=== Replacing Blocks and Tracking Reductions ===")

#     # pruning_ratios = (np.eye(blocks_number) * 0.8)[block_idx]

#     pruned_acc, pruned_macs, pruned_params = selective_block_pruning(
#     original_model, prune_method, data_loader, device
#     )

#     # logging.info(f"\nReplacing Block {block_idx}:")
#     # logging.info(f"  - MACs Reduction: {((original_macs - macs) / original_macs * 100):.2f}%")
#     # logging.info(f"  - Parameters Reduction: {((original_nparams - nparams)/original_nparams*100):.2f}%")

#     params_reduction = [original_nparams - param for param in pruned_params]
#     macs_reduction = [original_macs - macs for macs in pruned_macs]

#     # pruned_model.to(device)
#     # pruned_model.eval()
    
#     # pruned_accuracy,_ = validate_model(pruned_model, data_loader, device='cpu')

#     # del pruned_model
#     # torch.cuda.empty_cache()

#     # logging.info(f"Accuracy After Pruning This Block: {nacc*100:.2f}%\n")
#     total_block_accuracy = pruned_acc

#     total_accuracy_reduction = 0.0
#     block_reductions = []
#     total_params_reduction = 0.0
#     total_macs_reduction = 0.0

#     for block_idx in range(blocks_number):
#         final_average_accuracy = total_block_accuracy[block_idx]
#         accuracy_reduction = baseline_accuracy - final_average_accuracy
#         block_reductions.append(accuracy_reduction)
#         total_accuracy_reduction += accuracy_reduction
#         total_params_reduction += params_reduction[block_idx]
#         total_macs_reduction += macs_reduction[block_idx] 

#     relative_contributions = []
#     weighted_importance_scores = []

#     # logging.info("\n=== Relative Contribution of Each Block ===")
#     print("\n=== Relative Contribution of Each Block ===")
#     for block_idx in range(blocks_number):
#         relative_contribution_accuracy = (block_reductions[block_idx] / total_accuracy_reduction) * 100
#         relative_contribution_params = (1 - (params_reduction[block_idx] / total_params_reduction)) * 100
#         relative_contribution_macs = (1 - (macs_reduction[block_idx] / total_macs_reduction)) * 100

#         weight_accuracy = 0.5
#         weight_params = 0.3
#         weight_macs = 0.2
#         weighted_importance = (weight_accuracy * relative_contribution_accuracy) 
#         + (weight_params * relative_contribution_params) 
#         + (weight_macs * relative_contribution_macs)

#         logging.info(
#         f"Block {block_idx}: Accuracy Decrease Contribution = {relative_contribution_accuracy:.2f}%, "
#         f"Parameter Reduction = {100 - relative_contribution_params:.2f}%, "
#         f"MACs Reduction = {100 - relative_contribution_macs:.2f}%, "
#         f"Weighted Importance Score = {weighted_importance:.2f}"
#         )  
#         print(f"Block {block_idx}: Accuracy Decrease Contribution = {relative_contribution_accuracy:.2f}%, ")
#         print(f"Parameter Reduction = {100 - relative_contribution_params:.2f}%, ")
#         print(f"MACs Reduction = {100 - relative_contribution_macs:.2f}%, ")
#         print(f"Weighted Importance Score = {weighted_importance:.2f}")

#         relative_contributions.append(relative_contribution_accuracy)
#         weighted_importance_scores.append(weighted_importance)

#     return weighted_importance_scores

In [22]:
def prune_model(trained_model, prune_method, pruning_ratios, data_loader, device):
   
    # Make a copy of the trained model
    model = copy.deepcopy(trained_model)
    model.to(device)

    model_blocks, ignored_layers = get_model_blocks(model)


    pruning_info = {i: {"block": model_blocks[i], "pruning_ratio": ratio} 
                    for i, ratio in enumerate(pruning_ratios)}

    if prune_method == 'Taylor':
        imp = tp.importance.TaylorImportance() 
        
        if isinstance(imp, (tp.importance.GroupTaylorImportance, tp.importance.GroupHessianImportance)):
            model.zero_grad()
            if isinstance(imp, tp.importance.GroupHessianImportance):
                imp.zero_grad() # clear the accumulated gradients
            print("Accumulating gradients for pruning...")
            for k, (imgs, lbls) in enumerate(tqdm(data_loader)):
                if k>=taylor_batchs: break
                imgs = imgs.to(device)
                lbls = lbls.to(device)
                output = model(imgs)
                if isinstance(imp, tp.importance.GroupHessianImportance): # per-sample gradients for hessian
                    loss = torch.nn.functional.cross_entropy(output, lbls, reduction='none')
                    for l in loss:
                        model.zero_grad()
                        l.backward(retain_graph=True)
                        imp.accumulate_grad(model) # accumulate gradients
                elif isinstance(imp, tp.importance.GroupTaylorImportance): # batch gradients for first-order taylor
                    loss = torch.nn.functional.cross_entropy(output, lbls)
                    loss.backward()

        original_macs, original_params = tp.utils.count_ops_and_params(model, imgs)

        # Prune each block while ignoring other layers
        for i, info in pruning_info.items():
            pruning_ratio = info["pruning_ratio"]
            
            # Add all blocks to the ignored layers except the block being pruned
            ignored_layers_block = [pruning_info[j]["block"] for j in range(len(pruning_info)) if j != i]

            # Combine fixed ignored layers (conv_stem, bn1, classifier) with the ignored blocks
            combined_ignored_layers = ignored_layers + ignored_layers_block

            # Apply pruning using the combined ignored layers
            pruner_group = tp.pruner.MagnitudePruner( 
                model,
                example_inputs=imgs,
                importance=imp,
                pruning_ratio=pruning_ratio,
                ignored_layers=combined_ignored_layers,
                iterative_steps=1,
            )

            # Step through pruning
            pruner_group.step()

    # Counting MACs and Params after pruning
    macs, nparams = tp.utils.count_ops_and_params(model, imgs)

    # logging.info(f"MACs of the Pruned Model: {macs/ 1e9} G")
    print(f"MACs of the Original Model: {original_macs/ 1e9} G  -->  MACs of the Pruned Model: {macs/ 1e9} G")
    # logging.info(f"# Parameters of the Pruned Model: {nparams/ 1e3} K")
    print(f"# Parameters of the Original Model: {original_params/ 1e3} K  --> # Parameters of the Pruned Model: {nparams/ 1e3} K")

    # Free up GPU memory
    del imgs, lbls, output
    torch.cuda.empty_cache()

    return model

In [14]:
def calculate_pruning_ratios_intense(contributions, max_pruning_ratio=0.9, k=5):
    # Normalize the contributions to get values between 0 and 1
    total_contribution = sum(contributions)
    normalized_contributions = [contribution / total_contribution for contribution in contributions]

    # Apply exponential decay to magnify the effect for less important blocks
    pruning_factors = [np.exp(-k * nc) for nc in normalized_contributions]

    # Normalize the pruning factors so they stay within the max pruning ratio
    max_factor = max(pruning_factors)
    normalized_factors = [pf / max_factor for pf in pruning_factors]

    # Scale by the maximum pruning ratio
    pruning_ratios = [max_pruning_ratio * nf for nf in normalized_factors]

    pruning_ratios = [round(num, 2) for num in pruning_ratios]

    return pruning_ratios

In [None]:
warnings.filterwarnings("ignore")
relative_contribution = perplexity_analysis_with_contributions(model, 'Taylor', cal_loader, device)


=== Computing baseline accuracy without block replacement ===


100%|██████████| 128/128 [00:46<00:00,  2.74it/s]


Baseline accuracy: 84.47%

=== Replacing Blocks and Tracking Reductions ===
Accumulating gradients for pruning...


100%|██████████| 128/128 [00:10<00:00, 12.44it/s]



Replacing Block 0:
  - MACs Reduction: 2.67%
  - Parameters Reduction: 0.15%


100%|██████████| 128/128 [00:43<00:00,  2.92it/s]



=== Replacing Blocks and Tracking Reductions ===
Accumulating gradients for pruning...


100%|██████████| 128/128 [00:10<00:00, 12.73it/s]



Replacing Block 1:
  - MACs Reduction: 2.67%
  - Parameters Reduction: 0.15%


100%|██████████| 128/128 [00:44<00:00,  2.90it/s]



=== Replacing Blocks and Tracking Reductions ===
Accumulating gradients for pruning...


100%|██████████| 128/128 [00:10<00:00, 12.64it/s]



Replacing Block 2:
  - MACs Reduction: 2.67%
  - Parameters Reduction: 0.15%


100%|██████████| 128/128 [00:43<00:00,  2.91it/s]



=== Replacing Blocks and Tracking Reductions ===
Accumulating gradients for pruning...


100%|██████████| 128/128 [00:10<00:00, 12.80it/s]



Replacing Block 3:
  - MACs Reduction: 0.00%
  - Parameters Reduction: 0.00%


100%|██████████| 128/128 [00:46<00:00,  2.75it/s]



=== Replacing Blocks and Tracking Reductions ===
Accumulating gradients for pruning...


100%|██████████| 128/128 [00:09<00:00, 12.80it/s]



Replacing Block 4:
  - MACs Reduction: 2.67%
  - Parameters Reduction: 0.59%


100%|██████████| 128/128 [00:45<00:00,  2.84it/s]



=== Replacing Blocks and Tracking Reductions ===
Accumulating gradients for pruning...


100%|██████████| 128/128 [00:10<00:00, 12.71it/s]



Replacing Block 5:
  - MACs Reduction: 2.67%
  - Parameters Reduction: 0.59%


100%|██████████| 128/128 [00:44<00:00,  2.85it/s]



=== Replacing Blocks and Tracking Reductions ===
Accumulating gradients for pruning...


100%|██████████| 128/128 [00:10<00:00, 12.77it/s]



Replacing Block 6:
  - MACs Reduction: 2.67%
  - Parameters Reduction: 0.59%


100%|██████████| 128/128 [00:44<00:00,  2.87it/s]



=== Replacing Blocks and Tracking Reductions ===
Accumulating gradients for pruning...


100%|██████████| 128/128 [00:10<00:00, 12.73it/s]



Replacing Block 7:
  - MACs Reduction: 0.00%
  - Parameters Reduction: 0.00%


100%|██████████| 128/128 [00:46<00:00,  2.75it/s]



=== Replacing Blocks and Tracking Reductions ===
Accumulating gradients for pruning...


100%|██████████| 128/128 [00:09<00:00, 12.87it/s]



Replacing Block 8:
  - MACs Reduction: 2.68%
  - Parameters Reduction: 2.37%


100%|██████████| 128/128 [00:45<00:00,  2.83it/s]



=== Replacing Blocks and Tracking Reductions ===
Accumulating gradients for pruning...


100%|██████████| 128/128 [00:10<00:00, 12.77it/s]



Replacing Block 9:
  - MACs Reduction: 2.68%
  - Parameters Reduction: 2.37%


100%|██████████| 128/128 [00:45<00:00,  2.82it/s]



=== Replacing Blocks and Tracking Reductions ===
Accumulating gradients for pruning...


100%|██████████| 128/128 [00:10<00:00, 12.74it/s]



Replacing Block 10:
  - MACs Reduction: 2.68%
  - Parameters Reduction: 2.37%


100%|██████████| 128/128 [00:45<00:00,  2.82it/s]



=== Replacing Blocks and Tracking Reductions ===
Accumulating gradients for pruning...


100%|██████████| 128/128 [00:09<00:00, 12.84it/s]



Replacing Block 11:
  - MACs Reduction: 2.68%
  - Parameters Reduction: 2.37%


100%|██████████| 128/128 [00:45<00:00,  2.81it/s]



=== Replacing Blocks and Tracking Reductions ===
Accumulating gradients for pruning...


100%|██████████| 128/128 [00:09<00:00, 12.83it/s]



Replacing Block 12:
  - MACs Reduction: 2.68%
  - Parameters Reduction: 2.37%


100%|██████████| 128/128 [00:45<00:00,  2.82it/s]



=== Replacing Blocks and Tracking Reductions ===
Accumulating gradients for pruning...


100%|██████████| 128/128 [00:10<00:00, 12.72it/s]



Replacing Block 13:
  - MACs Reduction: 2.68%
  - Parameters Reduction: 2.37%


100%|██████████| 128/128 [00:45<00:00,  2.79it/s]



=== Replacing Blocks and Tracking Reductions ===
Accumulating gradients for pruning...


100%|██████████| 128/128 [00:10<00:00, 12.77it/s]



Replacing Block 14:
  - MACs Reduction: 2.68%
  - Parameters Reduction: 2.37%


100%|██████████| 128/128 [00:45<00:00,  2.82it/s]



=== Replacing Blocks and Tracking Reductions ===
Accumulating gradients for pruning...


100%|██████████| 128/128 [00:09<00:00, 12.86it/s]



Replacing Block 15:
  - MACs Reduction: 2.68%
  - Parameters Reduction: 2.37%


100%|██████████| 128/128 [00:45<00:00,  2.82it/s]



=== Replacing Blocks and Tracking Reductions ===
Accumulating gradients for pruning...


100%|██████████| 128/128 [00:10<00:00, 12.67it/s]



Replacing Block 16:
  - MACs Reduction: 2.68%
  - Parameters Reduction: 2.37%


100%|██████████| 128/128 [00:45<00:00,  2.81it/s]



=== Replacing Blocks and Tracking Reductions ===
Accumulating gradients for pruning...


100%|██████████| 128/128 [00:09<00:00, 12.81it/s]



Replacing Block 17:
  - MACs Reduction: 2.68%
  - Parameters Reduction: 2.37%


100%|██████████| 128/128 [00:45<00:00,  2.83it/s]



=== Replacing Blocks and Tracking Reductions ===
Accumulating gradients for pruning...


100%|██████████| 128/128 [00:10<00:00, 12.62it/s]



Replacing Block 18:
  - MACs Reduction: 2.68%
  - Parameters Reduction: 2.37%


100%|██████████| 128/128 [00:45<00:00,  2.84it/s]



=== Replacing Blocks and Tracking Reductions ===
Accumulating gradients for pruning...


100%|██████████| 128/128 [00:10<00:00, 12.64it/s]



Replacing Block 19:
  - MACs Reduction: 2.68%
  - Parameters Reduction: 2.37%


100%|██████████| 128/128 [00:45<00:00,  2.82it/s]



=== Replacing Blocks and Tracking Reductions ===
Accumulating gradients for pruning...


100%|██████████| 128/128 [00:10<00:00, 12.79it/s]



Replacing Block 20:
  - MACs Reduction: 2.68%
  - Parameters Reduction: 2.37%


100%|██████████| 128/128 [00:45<00:00,  2.84it/s]



=== Replacing Blocks and Tracking Reductions ===
Accumulating gradients for pruning...


100%|██████████| 128/128 [00:10<00:00, 12.76it/s]



Replacing Block 21:
  - MACs Reduction: 2.68%
  - Parameters Reduction: 2.37%


100%|██████████| 128/128 [00:44<00:00,  2.84it/s]



=== Replacing Blocks and Tracking Reductions ===
Accumulating gradients for pruning...


100%|██████████| 128/128 [00:10<00:00, 12.75it/s]



Replacing Block 22:
  - MACs Reduction: 2.68%
  - Parameters Reduction: 2.37%


100%|██████████| 128/128 [00:45<00:00,  2.80it/s]



=== Replacing Blocks and Tracking Reductions ===
Accumulating gradients for pruning...


100%|██████████| 128/128 [00:09<00:00, 12.80it/s]



Replacing Block 23:
  - MACs Reduction: 2.68%
  - Parameters Reduction: 2.37%


100%|██████████| 128/128 [00:45<00:00,  2.82it/s]



=== Replacing Blocks and Tracking Reductions ===
Accumulating gradients for pruning...


100%|██████████| 128/128 [00:10<00:00, 12.78it/s]



Replacing Block 24:
  - MACs Reduction: 2.68%
  - Parameters Reduction: 2.37%


100%|██████████| 128/128 [00:45<00:00,  2.82it/s]



=== Replacing Blocks and Tracking Reductions ===
Accumulating gradients for pruning...


100%|██████████| 128/128 [00:09<00:00, 12.85it/s]



Replacing Block 25:
  - MACs Reduction: 2.68%
  - Parameters Reduction: 2.37%


100%|██████████| 128/128 [00:45<00:00,  2.81it/s]



=== Replacing Blocks and Tracking Reductions ===
Accumulating gradients for pruning...


100%|██████████| 128/128 [00:10<00:00, 12.71it/s]



Replacing Block 26:
  - MACs Reduction: 2.68%
  - Parameters Reduction: 2.37%


100%|██████████| 128/128 [00:45<00:00,  2.82it/s]



=== Replacing Blocks and Tracking Reductions ===
Accumulating gradients for pruning...


100%|██████████| 128/128 [00:10<00:00, 12.64it/s]



Replacing Block 27:
  - MACs Reduction: 2.68%
  - Parameters Reduction: 2.37%


100%|██████████| 128/128 [00:45<00:00,  2.82it/s]



=== Replacing Blocks and Tracking Reductions ===
Accumulating gradients for pruning...


100%|██████████| 128/128 [00:10<00:00, 12.67it/s]



Replacing Block 28:
  - MACs Reduction: 2.68%
  - Parameters Reduction: 2.37%


100%|██████████| 128/128 [00:45<00:00,  2.84it/s]



=== Replacing Blocks and Tracking Reductions ===
Accumulating gradients for pruning...


100%|██████████| 128/128 [00:10<00:00, 12.66it/s]



Replacing Block 29:
  - MACs Reduction: 2.68%
  - Parameters Reduction: 2.37%


100%|██████████| 128/128 [00:45<00:00,  2.84it/s]



=== Replacing Blocks and Tracking Reductions ===
Accumulating gradients for pruning...


100%|██████████| 128/128 [00:10<00:00, 12.69it/s]



Replacing Block 30:
  - MACs Reduction: 2.68%
  - Parameters Reduction: 2.37%


100%|██████████| 128/128 [00:45<00:00,  2.82it/s]



=== Replacing Blocks and Tracking Reductions ===
Accumulating gradients for pruning...


100%|██████████| 128/128 [00:10<00:00, 12.78it/s]



Replacing Block 31:
  - MACs Reduction: 2.68%
  - Parameters Reduction: 2.37%


100%|██████████| 128/128 [00:45<00:00,  2.81it/s]



=== Replacing Blocks and Tracking Reductions ===
Accumulating gradients for pruning...


100%|██████████| 128/128 [00:09<00:00, 12.81it/s]



Replacing Block 32:
  - MACs Reduction: 2.68%
  - Parameters Reduction: 2.37%


100%|██████████| 128/128 [00:45<00:00,  2.82it/s]



=== Replacing Blocks and Tracking Reductions ===
Accumulating gradients for pruning...


100%|██████████| 128/128 [00:10<00:00, 12.70it/s]



Replacing Block 33:
  - MACs Reduction: 2.68%
  - Parameters Reduction: 2.37%


100%|██████████| 128/128 [00:45<00:00,  2.81it/s]



=== Replacing Blocks and Tracking Reductions ===
Accumulating gradients for pruning...


100%|██████████| 128/128 [00:10<00:00, 12.70it/s]



Replacing Block 34:
  - MACs Reduction: 2.68%
  - Parameters Reduction: 2.37%


100%|██████████| 128/128 [00:45<00:00,  2.80it/s]



=== Replacing Blocks and Tracking Reductions ===
Accumulating gradients for pruning...


100%|██████████| 128/128 [00:10<00:00, 12.76it/s]



Replacing Block 35:
  - MACs Reduction: 0.00%
  - Parameters Reduction: 0.00%


100%|██████████| 128/128 [00:46<00:00,  2.75it/s]



=== Replacing Blocks and Tracking Reductions ===
Accumulating gradients for pruning...


100%|██████████| 128/128 [00:10<00:00, 12.62it/s]



Replacing Block 36:
  - MACs Reduction: 2.68%
  - Parameters Reduction: 9.47%


100%|██████████| 128/128 [00:45<00:00,  2.80it/s]



=== Replacing Blocks and Tracking Reductions ===
Accumulating gradients for pruning...


100%|██████████| 128/128 [00:10<00:00, 12.70it/s]



Replacing Block 37:
  - MACs Reduction: 2.68%
  - Parameters Reduction: 9.47%


100%|██████████| 128/128 [00:45<00:00,  2.81it/s]



=== Replacing Blocks and Tracking Reductions ===
Accumulating gradients for pruning...


100%|██████████| 128/128 [00:10<00:00, 12.72it/s]



Replacing Block 38:
  - MACs Reduction: 2.68%
  - Parameters Reduction: 9.47%


100%|██████████| 128/128 [00:45<00:00,  2.81it/s]



=== Replacing Blocks and Tracking Reductions ===
Accumulating gradients for pruning...


100%|██████████| 128/128 [00:10<00:00, 12.73it/s]



Replacing Block 39:
  - MACs Reduction: 0.00%
  - Parameters Reduction: 0.00%


100%|██████████| 128/128 [00:46<00:00,  2.74it/s]



=== Relative Contribution of Each Block ===
Block 0: Accuracy Decrease Contribution = 59.03%, 
Parameter Reduction = 0.16%, 
MACs Reduction = 2.77%, 
Weighted Importance Score = 29.51
Block 1: Accuracy Decrease Contribution = 3.50%, 
Parameter Reduction = 0.16%, 
MACs Reduction = 2.77%, 
Weighted Importance Score = 1.75
Block 2: Accuracy Decrease Contribution = 9.43%, 
Parameter Reduction = 0.16%, 
MACs Reduction = 2.77%, 
Weighted Importance Score = 4.72
Block 3: Accuracy Decrease Contribution = 0.00%, 
Parameter Reduction = 0.00%, 
MACs Reduction = 0.00%, 
Weighted Importance Score = 0.00
Block 4: Accuracy Decrease Contribution = 1.62%, 
Parameter Reduction = 0.63%, 
MACs Reduction = 2.78%, 
Weighted Importance Score = 0.81
Block 5: Accuracy Decrease Contribution = 1.89%, 
Parameter Reduction = 0.63%, 
MACs Reduction = 2.78%, 
Weighted Importance Score = 0.94
Block 6: Accuracy Decrease Contribution = 1.62%, 
Parameter Reduction = 0.63%, 
MACs Reduction = 2.78%, 
Weighted Importance 

In [None]:
pruning_ratios = calculate_pruning_ratios_intense(relative_contribution, max_pruning_ratio=0.5, k=5)

In [16]:
pruning_ratios = [np.float64(0.02),
 np.float64(0.4),
 np.float64(0.3),
 np.float64(0.47),
 np.float64(0.44),
 np.float64(0.43),
 np.float64(0.44),
 np.float64(0.47),
 np.float64(0.49),
 np.float64(0.44),
 np.float64(0.45),
 np.float64(0.43),
 np.float64(0.44),
 np.float64(0.49),
 np.float64(0.44),
 np.float64(0.5),
 np.float64(0.48),
 np.float64(0.43),
 np.float64(0.46),
 np.float64(0.44),
 np.float64(0.44),
 np.float64(0.44),
 np.float64(0.45),
 np.float64(0.47),
 np.float64(0.44),
 np.float64(0.46),
 np.float64(0.44),
 np.float64(0.44),
 np.float64(0.44),
 np.float64(0.44),
 np.float64(0.47),
 np.float64(0.49),
 np.float64(0.47),
 np.float64(0.47),
 np.float64(0.45),
 np.float64(0.47),
 np.float64(0.46),
 np.float64(0.46),
 np.float64(0.47),
 np.float64(0.47)]

In [23]:
warnings.filterwarnings("ignore")

pruned_model = prune_model(model, 'Taylor', pruning_ratios, cal_loader, device)

Accumulating gradients for pruning...


100%|██████████| 1024/1024 [00:16<00:00, 63.36it/s]


MACs of the Original Model: 15.360289896 G  -->  MACs of the Pruned Model: 8.893172991 G
# Parameters of the Original Model: 88591.464 K  --> # Parameters of the Pruned Model: 50298.111 K


In [24]:
model

ConvNeXt(
  (stem): Sequential(
    (0): Conv2d(3, 128, kernel_size=(4, 4), stride=(4, 4))
    (1): LayerNorm2d((128,), eps=1e-06, elementwise_affine=True)
  )
  (stages): Sequential(
    (0): ConvNeXtStage(
      (downsample): Identity()
      (blocks): Sequential(
        (0): ConvNeXtBlock(
          (conv_dw): Conv2d(128, 128, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=128)
          (norm): LayerNorm((128,), eps=1e-06, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=128, out_features=512, bias=True)
            (act): GELU()
            (drop1): Dropout(p=0.0, inplace=False)
            (norm): Identity()
            (fc2): Linear(in_features=512, out_features=128, bias=True)
            (drop2): Dropout(p=0.0, inplace=False)
          )
          (shortcut): Identity()
          (drop_path): Identity()
        )
        (1): ConvNeXtBlock(
          (conv_dw): Conv2d(128, 128, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), g

In [26]:
pruned_model

ConvNeXt(
  (stem): Sequential(
    (0): Conv2d(3, 128, kernel_size=(4, 4), stride=(4, 4))
    (1): LayerNorm2d((128,), eps=1e-06, elementwise_affine=True)
  )
  (stages): Sequential(
    (0): ConvNeXtStage(
      (downsample): Identity()
      (blocks): Sequential(
        (0): ConvNeXtBlock(
          (conv_dw): Conv2d(128, 128, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=128)
          (norm): LayerNorm((128,), eps=1e-06, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=128, out_features=501, bias=True)
            (act): GELU()
            (drop1): Dropout(p=0.0, inplace=False)
            (norm): Identity()
            (fc2): Linear(in_features=501, out_features=128, bias=True)
            (drop2): Dropout(p=0.0, inplace=False)
          )
          (shortcut): Identity()
          (drop_path): Identity()
        )
        (1): ConvNeXtBlock(
          (conv_dw): Conv2d(128, 128, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), g

In [30]:
pruned_accuracy,_ = validate_model(pruned_model, val_loader, device='cuda')

100%|██████████| 50000/50000 [04:02<00:00, 205.84it/s]


In [32]:
torch.save(pruned_model, '/home/ict317-3/Mohammad/Isomorphic-Pruning/our_pruned_models/ConvNext4.2/pruned_model_43%.pth')