In [1]:
import os, sys
import torch
import torch.nn.functional as F

from torchvision.datasets import ImageFolder
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode
import torch.nn as nn
import timm
import torch_pruning as tp

import pbench
pbench.forward_patch.patch_timm_forward() # patch timm.forward() to support pruning

from tqdm import tqdm
import argparse

import torchvision as tv
import math

In [2]:
from torch.utils.data import Subset
import random

def prepare_imagenet(imagenet_root, train_batch_size=64, val_batch_size=128, num_workers=4, use_imagenet_mean_std=True, interpolation='bicubic', val_resize=256):
    """Build the train, validation and calibration data loaders.

    The imagenet_root should contain train and val folders.

    Args:
        imagenet_root: Directory holding the ``train`` and ``val`` image folders.
        train_batch_size: Batch size of the (shuffled) training loader.
        val_batch_size: Batch size of the validation and calibration loaders.
        num_workers: Worker processes used by every loader.
        use_imagenet_mean_std: Normalise with the ImageNet mean/std when True,
            otherwise with 0.5 for every channel.
        interpolation: Name of a ``torchvision`` ``InterpolationMode`` member
            (case-insensitive), e.g. ``'bicubic'``.
        val_resize: ``resize_size`` forwarded to the evaluation preset.

    Returns:
        ``(train_loader, val_loader, cal_loader)``. ``cal_loader`` iterates over a
        random subset of at most 1024 validation images.
    """
    interpolation = getattr(T.InterpolationMode, interpolation.upper())

    def _eval_preset():
        # Both splits use the same evaluation preset; a fresh instance is built per split.
        return pbench.data.presets.ClassificationPresetEval(
            mean=[0.485, 0.456, 0.406] if use_imagenet_mean_std else [0.5, 0.5, 0.5],
            std=[0.229, 0.224, 0.225] if use_imagenet_mean_std else [0.5, 0.5, 0.5],
            crop_size=224,
            resize_size=val_resize,
            interpolation=interpolation,
        )

    print('Parsing dataset...')
    train_dst = ImageFolder(os.path.join(imagenet_root, 'train'), transform=_eval_preset())
    val_dst = ImageFolder(os.path.join(imagenet_root, 'val'), transform=_eval_preset())
    train_loader = torch.utils.data.DataLoader(train_dst, batch_size=train_batch_size, shuffle=True, num_workers=num_workers)
    val_loader = torch.utils.data.DataLoader(val_dst, batch_size=val_batch_size, shuffle=False, num_workers=num_workers)

    # Sample calibration indices from the real dataset length. Using
    # len(val_loader) * val_batch_size over-counts whenever the last batch is
    # partial and can produce out-of-range indices; the subset size is also
    # capped so small validation sets do not make random.sample() raise.
    cal_subset_size = min(1024, len(val_dst))
    cal_indices = random.sample(range(len(val_dst)), cal_subset_size)
    cal_dst = Subset(val_dst, cal_indices)
    cal_loader = torch.utils.data.DataLoader(cal_dst, batch_size=val_batch_size, shuffle=False, num_workers=num_workers)

    return train_loader, val_loader, cal_loader

In [3]:
import gc

def validate_model(model, val_loader, device):
    """Evaluate ``model`` on ``val_loader`` and return ``(accuracy, mean_loss)``.

    The model is switched to eval mode and moved to ``device`` in place, so the
    caller's model stays on ``device`` afterwards.

    Args:
        model: Classifier whose output is a ``(batch, classes)`` score tensor.
        val_loader: Loader yielding ``(images, labels)``; ``val_loader.dataset``
            must support ``len()``.
        device: Device used for the forward passes.

    Returns:
        Tuple of top-1 accuracy and summed cross-entropy loss, both divided by
        the dataset length.

    Raises:
        ValueError: If the loader's dataset is empty.
    """
    num_samples = len(val_loader.dataset)
    if num_samples == 0:
        raise ValueError("validate_model() needs a non-empty dataset.")

    model.eval()
    model.to(device)
    correct = 0
    loss = 0
    # Pre-bind the per-batch names so the cleanup below cannot fail when the
    # loader yields no batch at all.
    images = outputs = predicted = None
    with torch.no_grad():
        for images, labels in tqdm(val_loader):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss += torch.nn.functional.cross_entropy(outputs, labels, reduction='sum').item()
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()

    # Drop the references to the last batch before collecting garbage.
    del images, outputs, predicted
    gc.collect()
    return correct / num_samples, loss / num_samples

In [4]:
# --- Experiment configuration (read as globals by the cells below) ---
# Root folder passed to prepare_imagenet(); it must contain `train` and `val` sub-folders.
data_path ='data/imagenet'
train_batch_size = 64
# Batch size shared by the validation and calibration loaders.
val_batch_size = 16
# When True, a mean/std of 0.5 is used instead of the ImageNet statistics.
no_imagenet_mean_std = False
val_resize = 256
interpolation = 'bicubic' #'bilinear' 

# Architecture name. NOTE: the model-loading cell rebinds `model` to the loaded module.
model = 'convnext_base.fb_in1k' #'resnet101.tv_in1k' #'mobilenet_v2' 
# True -> look the name up in torchvision.models, False -> timm.create_model.
is_torchvision = False
drop = 0.0
drop_path = 0.0
# Optional checkpoint path; when set, its 'model' entry is loaded into the network.
ckpt = None

# Maximum number of batches used to accumulate gradients for Taylor importance.
taylor_batchs = 64

In [5]:
# Pick the compute device and create a dummy single-image input.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
example_inputs = torch.randn(1,3,224,224)

# Build the train / validation / calibration loaders from the config cell.
train_loader, val_loader, cal_loader = prepare_imagenet(
    data_path,
    train_batch_size=train_batch_size,
    val_batch_size=val_batch_size,
    use_imagenet_mean_std=not no_imagenet_mean_std,
    val_resize=val_resize,
    interpolation=interpolation,
)

# `model` holds the architecture name on entry and the instantiated network
# (in eval mode) afterwards, so this cell cannot be re-run without first
# re-running the config cell.
if not is_torchvision:
    print(f"Loading timm model {model}...")
    model = timm.create_model(
        model,
        pretrained=True,
        drop_rate=drop,
        drop_path_rate=drop_path,
    ).eval()
else:
    import torchvision
    print(f"Loading torchvision model {model}...")
    model = torchvision.models.__dict__[model](pretrained=True).eval()

# Optionally overwrite the pretrained weights with a local checkpoint.
# `ckpt` is rebound from the path to the loaded dictionary.
if ckpt is not None:
    print(f"Loading checkpoint from {ckpt}...")
    ckpt = torch.load(ckpt, map_location='cpu')
    model.load_state_dict(ckpt['model'])
    

Parsing dataset...
Loading timm model convnext_base.fb_in1k...


In [6]:
def get_resnet_blocks(model):
    """Split the direct children of ``model`` into prunable blocks and ignored modules.

    Every ``nn.Sequential`` child is flattened one level: its own direct
    children become blocks. Any other child is returned as an ignored module.

    Returns:
        ``(model_blocks, ignored_blocks)`` as two lists, in traversal order.
    """
    prunable, skipped = [], []

    for module in model.children():
        if not isinstance(module, nn.Sequential):
            skipped.append(module)
            continue
        prunable += list(module.children())

    return prunable, skipped


In [7]:
def get_mobilenet_blocks(model):
    """Partition ``model.features`` into prunable blocks and ignored modules.

    Entries of ``model.features`` that are not ``nn.Sequential`` are treated as
    blocks; ``nn.Sequential`` entries are ignored. ``model.classifier`` is
    always appended to the ignored list.

    Returns:
        ``(model_blocks, ignored_blocks)`` as two lists, in feature order.
    """
    features = list(model.features)

    prunable = [layer for layer in features if not isinstance(layer, nn.Sequential)]
    skipped = [layer for layer in features if isinstance(layer, nn.Sequential)]
    skipped.append(model.classifier)

    return prunable, skipped


In [8]:
def get_convnext_blocks(model):
    """Collect the prunable blocks of a model exposing ``stages``.

    For each stage, every ``nn.Module`` entry of ``stage.blocks`` is listed
    first, followed by ``stage.downsample``. ``model.stem``, ``model.norm_pre``
    and ``model.head`` are returned as ignored modules.

    Returns:
        ``(model_blocks, ignored_blocks)`` as two lists.
    """
    prunable = []

    for stage in model.stages:
        prunable.extend(blk for blk in stage.blocks if isinstance(blk, nn.Module))
        # The downsample module is listed after the blocks of its own stage.
        prunable.append(stage.downsample)

    skipped = [model.stem, model.norm_pre, model.head]

    return prunable, skipped

In [9]:
def get_model_blocks(model):
    """Dispatch to the block extractor that matches the model's class name.

    Args:
        model: A ``MobileNetV2``, ``ResNet`` or ``ConvNeXt`` instance (matched
            on ``model.__class__.__name__``).

    Returns:
        ``(model_blocks, ignored_blocks)`` as produced by the matching extractor.

    Raises:
        ValueError: If the class name is not one of the supported architectures.
            Previously this case returned ``None`` and callers failed later
            with an unrelated unpacking error.
    """
    architecture = model.__class__.__name__
    if architecture == "MobileNetV2":
        return get_mobilenet_blocks(model)
    if architecture == "ResNet":
        return get_resnet_blocks(model)
    if architecture == "ConvNeXt":
        return get_convnext_blocks(model)
    raise ValueError(
        f"Unsupported model class '{architecture}'. "
        "Expected one of: 'MobileNetV2', 'ResNet', 'ConvNeXt'."
    )


In [10]:
import copy
import warnings

def selective_block_pruning(trained_model, prune_method, pruning_ratios, data_loader, device):
    """Repeatedly prune the selected blocks of a copy of ``trained_model``.

    Each block with a non-zero ratio is pruned on its own (all other blocks and
    the architecture's ignored layers are protected). The pruner is re-applied
    until the parameter count stops changing: after the first step without
    change the ratio is switched to 0.5, after the second the block is done.

    Args:
        trained_model: Source model; it is deep-copied and never modified.
        prune_method: ``'Taylor'`` (gradient based, gradients accumulated over at
            most ``taylor_batchs`` batches, a notebook-level setting) or
            ``'Magnitude'``.
        pruning_ratios: One ratio per block, in the order returned by
            ``get_model_blocks``; blocks with ratio ``0.0`` are skipped.
        data_loader: Loader yielding ``(images, labels)``; used for gradients and
            as the source of the example inputs given to the pruner.
        device: Device the copied model and the batches are moved to.

    Returns:
        ``(pruned_model, macs, nparams)`` where the counts describe the model
        after pruning (or the unpruned copy when every ratio is zero).

    Raises:
        ValueError: If ``prune_method`` is unknown or ``data_loader`` is empty.
    """
    model = copy.deepcopy(trained_model)
    model_blocks, ignored_layers = get_model_blocks(model)
    model.to(device)

    pruning_info = {i: {"block": model_blocks[i], "pruning_ratio": ratio} for i, ratio in enumerate(pruning_ratios)}

    # Pre-bind everything the cleanup at the end deletes, so that the
    # 'Magnitude' path (which never runs the gradient loop) does not fail.
    imgs = lbls = output = None
    example_inputs = None

    if prune_method == "Taylor":
        imp = tp.importance.TaylorImportance()

        if isinstance(imp, (tp.importance.GroupTaylorImportance, tp.importance.GroupHessianImportance)):
            model.zero_grad()
            if isinstance(imp, tp.importance.GroupHessianImportance):
                imp.zero_grad() # clear the accumulated gradients
            print("Accumulating gradients for pruning...")
            for k, (imgs, lbls) in enumerate(tqdm(data_loader)):
                if k>=taylor_batchs: break
                imgs = imgs.to(device)
                lbls = lbls.to(device)
                output = model(imgs)
                # Remember the last batch that really went through the model.
                # `imgs` itself may be a fresh, unmoved batch when the loop
                # stops at `taylor_batchs`.
                example_inputs = imgs
                if isinstance(imp, tp.importance.GroupHessianImportance): # per-sample gradients for hessian
                    loss = torch.nn.functional.cross_entropy(output, lbls, reduction='none')
                    for l in loss:
                        model.zero_grad()
                        l.backward(retain_graph=True)
                        imp.accumulate_grad(model) # accumulate gradients
                elif isinstance(imp, tp.importance.GroupTaylorImportance): # batch gradients for first-order taylor
                    loss = torch.nn.functional.cross_entropy(output, lbls)
                    loss.backward()

    elif prune_method == "Magnitude":
        imp = tp.importance.MagnitudeImportance()

    else:
        warnings.warn(f"Invalid pruning method: '{prune_method}'. Expected 'Taylor' or 'Magnitude'.", UserWarning)
        raise ValueError("Pruning method must be either 'Taylor' or 'Magnitude'.")

    if example_inputs is None:
        # No gradient pass supplied a batch: fetch one just for tracing/counting.
        try:
            example_inputs = next(iter(data_loader))[0].to(device)
        except StopIteration:
            raise ValueError("data_loader yielded no batches; cannot build example inputs for pruning.") from None

    # Start from the unpruned counts so the return value is defined even when
    # every pruning ratio is zero.
    macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)
    previous_nparams = nparams

    for i, info in pruning_info.items():
        pruning_ratio = info["pruning_ratio"]
        if pruning_ratio == 0.0:
            continue

        # Protect every other listed block plus the architecture's fixed layers.
        ignored_layers_block = [pruning_info[j]["block"] for j in range(len(pruning_info)) if j != i]
        combined_ignored_layers = ignored_layers + ignored_layers_block

        stalled_steps = 0

        while True:

            pruner_group = tp.pruner.MagnitudePruner(
                model,
                example_inputs=example_inputs,
                importance=imp,
                pruning_ratio=pruning_ratio,
                ignored_layers=combined_ignored_layers,
            )
            pruner_group.step()

            macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)
            if previous_nparams - nparams == 0:
                stalled_steps += 1
                if stalled_steps > 1:
                    break
                # First stall: retry once with a ratio of 0.5 before giving up.
                pruning_ratio = 0.5

            previous_nparams = nparams

    # Release the batch tensors before asking CUDA to free cached memory.
    del imgs, lbls, output, example_inputs
    torch.cuda.empty_cache()


    return model, macs, nparams


In [11]:
import logging
import numpy as np


def perplexity_analysis_with_contributions(original_model, prune_method, data_loader, device):
    """Score every block by the damage done when only that block is pruned.

    For each block returned by ``get_model_blocks`` the model is copied, that
    block alone is pruned with a ratio of 0.8 via ``selective_block_pruning``,
    and the accuracy, parameter and MACs changes are recorded. The three
    relative contributions are then combined with weights 0.5 / 0.3 / 0.2.

    Note: validation runs with ``device='cpu'`` (hard-coded), so
    ``original_model`` is moved to the CPU by this function; ``device`` is only
    forwarded to ``selective_block_pruning``.

    Args:
        original_model: Trained model to analyse.
        prune_method: Forwarded to ``selective_block_pruning``.
        data_loader: Loader used both for pruning and for accuracy measurement.
        device: Device used while pruning each copy.

    Returns:
        List with one weighted importance score per block.
    """
    model_blocks, _ = get_model_blocks(original_model)
    blocks_number = len(model_blocks)

    print("\n=== Computing baseline accuracy without block replacement ===")
    baseline_accuracy, _ = validate_model(original_model, data_loader, device='cpu')
    print(f"Baseline accuracy: {baseline_accuracy*100:.2f}%")

    input_size = [3, 224, 224]
    example_inputs = torch.randn(1, *input_size)
    original_macs, original_nparams = tp.utils.count_ops_and_params(original_model, example_inputs)

    block_accuracy = []
    params_reduction = []
    macs_reduction = []

    for block_idx in range(blocks_number):

        print("\n=== Replacing Blocks and Tracking Reductions ===")

        # One-hot ratio vector: prune only the current block.
        pruning_ratios = (np.eye(blocks_number) * 0.8)[block_idx]

        pruned_model, macs, nparams = selective_block_pruning(
            original_model, prune_method, pruning_ratios, data_loader, device
        )

        print(f"\nReplacing Block {block_idx}:")
        print(f"  - MACs Reduction: {((original_macs - macs) / original_macs * 100):.2f}%")
        print(f"  - Parameters Reduction: {((original_nparams - nparams)/original_nparams*100):.2f}%")

        params_reduction.append(original_nparams - nparams)
        macs_reduction.append(original_macs - macs)

        # validate_model() puts the model in eval mode and moves it itself.
        pruned_accuracy, _ = validate_model(pruned_model, data_loader, device='cpu')

        del pruned_model
        torch.cuda.empty_cache()

        print(f"Accuracy After Pruning This Block: {pruned_accuracy*100:.2f}%\n")

        block_accuracy.append(pruned_accuracy)

    # Accuracy drop per block; an accuracy gain counts as zero importance.
    block_reductions = []
    for block_idx, accuracy in enumerate(block_accuracy):
        accuracy_reduction = baseline_accuracy - accuracy
        if accuracy_reduction < 0:
            print(f"Block {block_idx} improved accuracy by {-accuracy_reduction*100:.2f}% — treated as 0 for importance.")
            accuracy_reduction = 0.0
        block_reductions.append(accuracy_reduction)

    # Totals used as denominators; a zero total is replaced by a tiny value to
    # avoid division by zero (now also for parameters and MACs).
    total_accuracy_reduction = sum(block_reductions) or 1e-8
    total_params_reduction = sum(params_reduction) or 1e-8
    total_macs_reduction = sum(macs_reduction) or 1e-8

    weight_accuracy = 0.5
    weight_params = 0.3
    weight_macs = 0.2

    weighted_importance_scores = []

    print("\n=== Relative Contribution of Each Block ===")
    for block_idx in range(blocks_number):
        relative_contribution_accuracy = (block_reductions[block_idx] / total_accuracy_reduction) * 100
        relative_contribution_params = (1 - (params_reduction[block_idx] / total_params_reduction)) * 100
        relative_contribution_macs = (1 - (macs_reduction[block_idx] / total_macs_reduction)) * 100

        # The sum must live inside one parenthesised expression. Written as
        # three separate lines, the "+ (...)" lines were stand-alone statements
        # and only the accuracy term reached the score. Scores produced before
        # this fix therefore equal 0.5 * accuracy contribution.
        weighted_importance = (
            weight_accuracy * relative_contribution_accuracy
            + weight_params * relative_contribution_params
            + weight_macs * relative_contribution_macs
        )

        print(f"Block {block_idx}: Accuracy Decrease Contribution = {relative_contribution_accuracy:.2f}%, ")
        print(f"Parameter Reduction = {100 - relative_contribution_params:.2f}%, ")
        print(f"MACs Reduction = {100 - relative_contribution_macs:.2f}%, ")
        print(f"Weighted Importance Score = {weighted_importance:.2f}")

        weighted_importance_scores.append(weighted_importance)

    return weighted_importance_scores

In [12]:
# import logging
# import numpy as np


# def perplexity_analysis_with_contributions(original_model, prune_method, data_loader, device):

    
#     model_blocks, ignored_layers = get_model_blocks(original_model)
#     blocks_number = len(model_blocks)

#     total_block_accuracy = [0.0 for _ in range(blocks_number)]
#     params_reduction = []
#     macs_reduction = []

#     # logging.info("\n=== Computing baseline accuracy without block replacement ===")
#     print("\n=== Computing baseline accuracy without block replacement ===")
#     baseline_accuracy, baseline_loss = validate_model(original_model, data_loader, device='cpu')
#     # logging.info(f"Baseline accuracy: {baseline_accuracy*100:.2f}%")
#     print(f"Baseline accuracy: {baseline_accuracy*100:.2f}%")
    
#     input_size = [3, 224, 224]
#     example_inputs = torch.randn(1, *input_size).to(device)
#     original_macs, original_nparams = tp.utils.count_ops_and_params(original_model, example_inputs)


#     # logging.info("\n=== Replacing Blocks and Tracking Reductions ===")
#     print("\n=== Replacing Blocks and Tracking Reductions ===")

#     # pruning_ratios = (np.eye(blocks_number) * 0.8)[block_idx]

#     pruned_acc, pruned_macs, pruned_params = selective_block_pruning(
#     original_model, prune_method, data_loader, device
#     )

#     # logging.info(f"\nReplacing Block {block_idx}:")
#     # logging.info(f"  - MACs Reduction: {((original_macs - macs) / original_macs * 100):.2f}%")
#     # logging.info(f"  - Parameters Reduction: {((original_nparams - nparams)/original_nparams*100):.2f}%")

#     params_reduction = [original_nparams - param for param in pruned_params]
#     macs_reduction = [original_macs - macs for macs in pruned_macs]

#     # pruned_model.to(device)
#     # pruned_model.eval()
    
#     # pruned_accuracy,_ = validate_model(pruned_model, data_loader, device='cpu')

#     # del pruned_model
#     # torch.cuda.empty_cache()

#     # logging.info(f"Accuracy After Pruning This Block: {nacc*100:.2f}%\n")
#     total_block_accuracy = pruned_acc

#     total_accuracy_reduction = 0.0
#     block_reductions = []
#     total_params_reduction = 0.0
#     total_macs_reduction = 0.0

#     for block_idx in range(blocks_number):
#         final_average_accuracy = total_block_accuracy[block_idx]
#         accuracy_reduction = baseline_accuracy - final_average_accuracy
#         block_reductions.append(accuracy_reduction)
#         total_accuracy_reduction += accuracy_reduction
#         total_params_reduction += params_reduction[block_idx]
#         total_macs_reduction += macs_reduction[block_idx] 

#     relative_contributions = []
#     weighted_importance_scores = []

#     # logging.info("\n=== Relative Contribution of Each Block ===")
#     print("\n=== Relative Contribution of Each Block ===")
#     for block_idx in range(blocks_number):
#         relative_contribution_accuracy = (block_reductions[block_idx] / total_accuracy_reduction) * 100
#         relative_contribution_params = (1 - (params_reduction[block_idx] / total_params_reduction)) * 100
#         relative_contribution_macs = (1 - (macs_reduction[block_idx] / total_macs_reduction)) * 100

#         weight_accuracy = 0.5
#         weight_params = 0.3
#         weight_macs = 0.2
#         weighted_importance = (weight_accuracy * relative_contribution_accuracy) 
#         + (weight_params * relative_contribution_params) 
#         + (weight_macs * relative_contribution_macs)

#         logging.info(
#         f"Block {block_idx}: Accuracy Decrease Contribution = {relative_contribution_accuracy:.2f}%, "
#         f"Parameter Reduction = {100 - relative_contribution_params:.2f}%, "
#         f"MACs Reduction = {100 - relative_contribution_macs:.2f}%, "
#         f"Weighted Importance Score = {weighted_importance:.2f}"
#         )  
#         print(f"Block {block_idx}: Accuracy Decrease Contribution = {relative_contribution_accuracy:.2f}%, ")
#         print(f"Parameter Reduction = {100 - relative_contribution_params:.2f}%, ")
#         print(f"MACs Reduction = {100 - relative_contribution_macs:.2f}%, ")
#         print(f"Weighted Importance Score = {weighted_importance:.2f}")

#         relative_contributions.append(relative_contribution_accuracy)
#         weighted_importance_scores.append(weighted_importance)

#     return weighted_importance_scores

In [12]:
def prune_model(trained_model, prune_method, pruning_ratios, data_loader, device):
    """Prune a copy of ``trained_model`` block by block in a single pass.

    Every block receives its own pruning ratio; while one block is pruned all
    other blocks and the architecture's ignored layers are protected.

    Args:
        trained_model: Source model; it is deep-copied and never modified.
        prune_method: ``'Taylor'`` (gradients accumulated over at most
            ``taylor_batchs`` batches, a notebook-level setting) or ``'Magnitude'``.
        pruning_ratios: One ratio per block, in the order returned by
            ``get_model_blocks``.
        data_loader: Loader yielding ``(images, labels)``; used for gradients and
            as the source of the example inputs given to the pruner.
        device: Device the copied model and the batches are moved to.

    Returns:
        ``(pruned_model, param_reduction, macs_reduction)`` where both
        reductions are percentages rounded up to the next integer.

    Raises:
        ValueError: If ``prune_method`` is unknown or ``data_loader`` is empty.
    """
    # Make a copy of the trained model
    model = copy.deepcopy(trained_model)
    model.to(device)

    model_blocks, ignored_layers = get_model_blocks(model)

    pruning_info = {i: {"block": model_blocks[i], "pruning_ratio": ratio}
                    for i, ratio in enumerate(pruning_ratios)}

    # Pre-bind everything the cleanup at the end deletes, so that paths which
    # never run the gradient loop do not fail with a NameError.
    imgs = lbls = output = None
    example_inputs = None

    if prune_method == 'Taylor':
        imp = tp.importance.TaylorImportance()

        if isinstance(imp, (tp.importance.GroupTaylorImportance, tp.importance.GroupHessianImportance)):
            model.zero_grad()
            if isinstance(imp, tp.importance.GroupHessianImportance):
                imp.zero_grad() # clear the accumulated gradients
            print("Accumulating gradients for pruning...")
            for k, (imgs, lbls) in enumerate(tqdm(data_loader)):
                if k>=taylor_batchs: break
                imgs = imgs.to(device)
                lbls = lbls.to(device)
                output = model(imgs)
                # Remember the last batch that really went through the model.
                # `imgs` itself may be a fresh, unmoved batch when the loop
                # stops at `taylor_batchs`.
                example_inputs = imgs
                if isinstance(imp, tp.importance.GroupHessianImportance): # per-sample gradients for hessian
                    loss = torch.nn.functional.cross_entropy(output, lbls, reduction='none')
                    for l in loss:
                        model.zero_grad()
                        l.backward(retain_graph=True)
                        imp.accumulate_grad(model) # accumulate gradients
                elif isinstance(imp, tp.importance.GroupTaylorImportance): # batch gradients for first-order taylor
                    loss = torch.nn.functional.cross_entropy(output, lbls)
                    loss.backward()

    elif prune_method == 'Magnitude':
        imp = tp.importance.MagnitudeImportance()

    else:
        # Previously any non-'Taylor' method skipped pruning entirely and then
        # crashed on undefined names; fail early with a clear message instead.
        raise ValueError("Pruning method must be either 'Taylor' or 'Magnitude'.")

    if example_inputs is None:
        # No gradient pass supplied a batch: fetch one just for tracing/counting.
        try:
            example_inputs = next(iter(data_loader))[0].to(device)
        except StopIteration:
            raise ValueError("data_loader yielded no batches; cannot build example inputs for pruning.") from None

    original_macs, original_params = tp.utils.count_ops_and_params(model, example_inputs)

    # Prune each block while ignoring other layers
    for i, info in pruning_info.items():
        pruning_ratio = info["pruning_ratio"]

        # Nothing to remove for this block; skip building a pruner for it.
        if pruning_ratio == 0:
            continue

        # Add all blocks to the ignored layers except the block being pruned
        ignored_layers_block = [pruning_info[j]["block"] for j in range(len(pruning_info)) if j != i]

        # Combine the architecture's fixed ignored layers with the ignored blocks
        combined_ignored_layers = ignored_layers + ignored_layers_block

        # Apply pruning using the combined ignored layers
        pruner_group = tp.pruner.MagnitudePruner(
            model,
            example_inputs=example_inputs,
            importance=imp,
            pruning_ratio=pruning_ratio,
            ignored_layers=combined_ignored_layers,
            iterative_steps=1,
        )

        # Step through pruning
        pruner_group.step()

    # Counting MACs and Params after pruning
    macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)

    print(f"MACs of the Original Model: {original_macs/ 1e9} G  -->  MACs of the Pruned Model: {macs/ 1e9} G")
    print(f"# Parameters of the Original Model: {original_params/ 1e3} K  --> # Parameters of the Pruned Model: {nparams/ 1e3} K")

    param_reduction = ((original_params - nparams) / original_params) * 100
    macs_reduction = ((original_macs - macs) / original_macs) * 100

    # Free up GPU memory
    del imgs, lbls, output, example_inputs
    torch.cuda.empty_cache()

    return model, math.ceil(param_reduction), math.ceil(macs_reduction)

In [13]:
def calculate_pruning_ratios_intense(contributions, max_pruning_ratio=0.9, k=5):
    """Map block importance scores to per-block pruning ratios.

    Scores are normalised by their sum, passed through ``exp(-k * score)`` so
    that less important blocks get larger factors, rescaled so the largest
    factor equals 1, multiplied by ``max_pruning_ratio`` and rounded to two
    decimals.

    Args:
        contributions: Iterable with one importance score per block.
        max_pruning_ratio: Ratio assigned to the least important block.
        k: Decay strength; larger values separate the blocks more sharply.

    Returns:
        List of pruning ratios, one per block. Empty input gives ``[]``. When
        the scores sum to zero, every block gets ``max_pruning_ratio``.
    """
    contributions = list(contributions)
    if not contributions:
        return []

    # Normalize the contributions by their sum
    total_contribution = sum(contributions)
    if total_contribution == 0:
        # No block stands out: treat all of them as equally unimportant
        # instead of dividing by zero.
        normalized_contributions = [0.0 for _ in contributions]
    else:
        normalized_contributions = [contribution / total_contribution for contribution in contributions]

    # Apply exponential decay to magnify the effect for less important blocks
    pruning_factors = [np.exp(-k * nc) for nc in normalized_contributions]

    # Normalize the pruning factors so they stay within the max pruning ratio
    max_factor = max(pruning_factors)
    normalized_factors = [pf / max_factor for pf in pruning_factors]

    # Scale by the maximum pruning ratio
    pruning_ratios = [max_pruning_ratio * nf for nf in normalized_factors]

    pruning_ratios = [round(num, 2) for num in pruning_ratios]

    return pruning_ratios

In [None]:
# Silence every Python warning for the rest of the session (blanket filter).
warnings.filterwarnings("ignore")
# Score each block on the calibration subset using Taylor importance; the
# result is the per-block score list consumed by the ratio computation below.
relative_contribution = perplexity_analysis_with_contributions(model, 'Taylor', cal_loader, device)

In [None]:
# ConvNext Blocks Contributions
# relative_contribution = [29.51, 1.75, 4.72, 0.00, 0.81, 0.94, 0.81, 0.00, -0.40, 0.81, 0.54, 1.08, 0.81, -0.27, 0.81, -0.54, -0.13, 1.08, 0.27, 0.81, 0.67, 0.67, 0.40, 
#                          0.13, 0.81, 0.27, 0.67, 0.81, 0.67, 0.67, 0.00, -0.40, 0.00, 0.13, 0.54, 0.00, 0.27, 0.27, 0.00, 0.00]

In [None]:
# # ResNet Blocks Contributions

# relative_contribution = [6.373626373626373,
#  0.4578754578754579,
#  0.9340659340659341,
#  9.945054945054945,
#  1.0622710622710623,
#  0.5494505494505495,
#  2.9487179487179485,
#  13.754578754578755,
#  0.1282051282051282,
#  0.5311355311355311,
#  0.29304029304029305,
#  0.5677655677655677,
#  0.2380952380952381,
#  0.347985347985348,
#  0.4578754578754579,
#  0.4761904761904762,
#  0.49450549450549447,
#  0.6043956043956045,
#  0.3663003663003663,
#  0.4761904761904762,
#  0.18315018315018314,
#  0.29304029304029305,
#  0.10989010989010989,
#  0.29304029304029305,
#  0.42124542124542125,
#  0.2014652014652015,
#  0.2014652014652015,
#  0.402930402930403,
#  0.3663003663003663,
#  0.6227106227106227,
#  3.296703296703297,
#  1.52014652014652,
#  1.0805860805860805]

[6.199222889438361,
 0.30024726245143063,
 0.9360649947015189,
 9.46661956905687,
 1.0243730130695867,
 0.8300953726598375,
 3.1967502649240553,
 13.157894736842104,
 0.24726245143058992,
 0.33557046979865773,
 0.37089367714588484,
 0.4592016955139526,
 0.4238784881667256,
 0.31790886612504415,
 0.688802543270929,
 0.4238784881667256,
 0.4238784881667256,
 0.7594489579653833,
 0.33557046979865773,
 0.282585658777817,
 0.26492405510420347,
 0.4238784881667256,
 0.08830801836806781,
 0.12363122571529496,
 0.44154009184033916,
 0.17661603673613563,
 0.31790886612504415,
 0.40621688449311194,
 0.3885552808194984,
 0.7064641469445425,
 3.07311903920876,
 2.0840692334864004,
 1.3246202755210172]

In [None]:
# ResNet Blocks Contributions

# relative_contribution = [3.5756327251324307,
#  3.5854424171081027,
#  3.5854424171081027,
#  3.5903472630959388,
#  2.687855601334118,
#  2.104178928781636,
#  3.5854424171081027,
#  3.4628212674122034,
#  0.5199136747106141,
#  1.0251128114577202,
#  3.5756327251324307,
#  2.530900529723367,
#  3.1832450461055526,
#  3.5854424171081027,
#  2.717284677261134,
#  3.0998626643123406,
#  3.5854424171081027]

[3.5756327251324307,
 3.5854424171081027,
 3.5854424171081027,
 3.5903472630959388,
 2.687855601334118,
 2.104178928781636,
 3.5854424171081027,
 3.4628212674122034,
 0.5199136747106141,
 1.0251128114577202,
 3.5756327251324307,
 2.530900529723367,
 3.1832450461055526,
 3.5854424171081027,
 2.717284677261134,
 3.0998626643123406,
 3.5854424171081027]

In [54]:
# Turn the per-block scores into per-block pruning ratios (at most 0.98).
# NOTE(review): `relative_contribution` must already exist, either from the
# analysis cell or from one of the cached lists kept as comments above.
pruning_ratios = calculate_pruning_ratios_intense(relative_contribution, max_pruning_ratio=0.98, k=5)

In [55]:
# Silence every Python warning and release cached CUDA memory before pruning.
warnings.filterwarnings("ignore")
torch.cuda.empty_cache()

# Prune a copy of the model with the computed per-block ratios; the two
# reductions are returned as integer percentages.
pruned_model, param_reduction, macs_reduction = prune_model(model, 'Taylor', pruning_ratios, cal_loader, device)

Accumulating gradients for pruning...


100%|██████████| 1024/1024 [00:15<00:00, 64.37it/s]


MACs of the Original Model: 15.360289896 G  -->  MACs of the Pruned Model: 2.679145091 G
# Parameters of the Original Model: 88591.464 K  --> # Parameters of the Pruned Model: 13470.083 K


In [None]:
# Measure the pruned model on the full validation set. Use the device chosen
# in the setup cell instead of a hard-coded 'cuda', so the cell also runs on
# CPU-only machines.
pruned_accuracy,_ = validate_model(pruned_model, val_loader, device=device)

In [56]:
# Persist the whole pruned module (not just a state dict), because pruning
# changed the layer shapes.
# NOTE(review): the destination is a machine-specific absolute path; consider
# moving it to the configuration cell.
pruned_model_path = f'/home/ict317-3/Mohammad/Isomorphic-Pruning/our_pruned_models/ConvNext4.2/pruned_model_{param_reduction}%.pth'
# torch.save() does not create missing directories, so make sure it exists.
os.makedirs(os.path.dirname(pruned_model_path), exist_ok=True)
torch.save(pruned_model, pruned_model_path)