In [None]:
from fastai.vision.all import *
from fastai.callback.all import *
from fasterai.core.criteria import *
import torch_pruning as tp
from torch_pruning.pruner import function
import torch_pruning as tp
import timm

import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
def onecycle_scheduler(pruning_ratio_dict, steps, start=0, end=1, α=14, β=6):
    """Build a per-step pruning-ratio schedule that follows the one-cycle curve.

    Evaluates `sched_onecycle` at `steps + 1` evenly spaced positions
    (`0/steps, 1/steps, ..., steps/steps`) and scales each value by
    `pruning_ratio_dict`.

    Args:
        pruning_ratio_dict: Factor every schedule value is multiplied by.
            Despite the name, the callers in this notebook pass a plain
            float (e.g. 0.25).
        steps: Number of pruning steps; the result has `steps + 1` entries.
        start, end, α, β: Forwarded unchanged to `sched_onecycle`.

    Returns:
        list: Scaled schedule values; entry `i` corresponds to position `i / steps`.
    """
    schedule = []
    for step in range(steps + 1):
        position = step / float(steps)
        curve_value = sched_onecycle(start, end, position, α, β)
        schedule.append(curve_value * pruning_ratio_dict)
    return schedule

def sched_onecycle(start, end, pos, α=14, β=6):
    """Interpolate between `start` and `end` along a sigmoid-shaped curve.

    The interpolation weight is
    `(1 + exp(β - α)) / (1 + exp(β - α * pos))`,
    which equals exactly 1 at `pos == 1`, so the function returns `end` there.

    Args:
        start: Value the curve is offset from.
        end: Value reached at `pos == 1`.
        pos: Position along the schedule (callers pass `i / steps`).
        α: Steepness factor multiplying `pos` inside the exponential.
        β: Offset inside the exponential.

    Returns:
        `start + (end - start) * weight`.
    """
    numerator = 1 + np.exp(-α + β)
    denominator = 1 + np.exp((-α * pos) + β)
    return start + (end - start) * (numerator / denominator)

In [None]:
def get_dls(size, bs):
    """Create DataLoaders for Imagenette (160px variant) from its noisy-label CSV.

    The archive behind `URLs.IMAGENETTE_160` is fetched/unpacked with
    `untar_data`, `noisy_imagenette.csv` inside it is read, image paths come
    from the `path` column (prefixed with the dataset root) and targets from
    the `noisy_labels_0` column. The train/validation split is delegated to
    `ColSplitter()` with its default column.

    Args:
        size: Output size given to `RandomResizedCrop`.
        bs: Batch size of the returned DataLoaders.

    Returns:
        The DataLoaders built by `DataBlock.dataloaders`.
    """
    data_root = untar_data(URLs.IMAGENETTE_160)
    labels_df = pd.read_csv(data_root / 'noisy_imagenette.csv')

    datablock = DataBlock(
        blocks=(ImageBlock, CategoryBlock),
        splitter=ColSplitter(),
        get_x=ColReader('path', pref=data_root),
        get_y=ColReader('noisy_labels_0'),
        # Per-item augmentation: random crop/resize plus a horizontal flip.
        item_tfms=[RandomResizedCrop(size, min_scale=0.35), FlipItem(0.5)],
        # Per-batch: normalise with the ImageNet statistics.
        batch_tfms=[Normalize.from_stats(*imagenet_stats)],
    )
    return datablock.dataloaders(labels_df, path=data_root, bs=bs)

In [None]:
# Untrained resnet18 in eval mode. The next cell only reads its `default_cfg`
# (input size); `model` is rebound to a different architecture further down.
model = timm.create_model('resnet18', pretrained=False, no_jit=True).eval()

In [None]:
# DataLoaders sized from the model config (index 2 of default_cfg['input_size']),
# with a batch size of 16.
dls = get_dls(model.default_cfg['input_size'][2], 16)

In [None]:
# Model that is trained and pruned in the cells below.
model = timm.create_model('beit_base_patch16_224', pretrained=False, no_jit=True).eval()

ignored_layers = []      # modules handed to the pruner as "do not prune"
num_heads = {}           # qkv projection module -> number of attention heads
pruning_ratio_dict = {}  # per-layer ratios; left empty (that experiment is disabled)
#ratios = [0.265625,0.234375,0.265625,0.265625,0.93359375,0.328125,0.2265625,0.58984375,0.54296875,0.701171875,0.919921875,0.04296875,0.796875,0.240966796875,0.07763671875]

for m in model.modules():
    # Classification head: a Linear layer that emits `num_classes` outputs.
    if isinstance(m, nn.Linear) and m.out_features == model.num_classes:
        ignored_layers.append(m)
        print("Ignore classifier layer: ", m)

    # Everything below only concerns attention modules.
    if not hasattr(m, 'num_heads'):
        continue
    if hasattr(m, 'qkv'):
        num_heads[m.qkv] = m.num_heads
        print("Attention layer: ", m.qkv, m.num_heads)
    elif hasattr(m, 'qkv_proj'):
        num_heads[m.qkv_proj] = m.num_heads

learn = Learner(dls, model, metrics=[accuracy])

In [None]:
# Grab one batch and keep only the inputs; later cells reuse `xb` as the
# example input for pruning and for MAC/parameter counting.
xb, _ = dls.one_batch()

In [None]:
# Train for 3 epochs with the one-cycle policy (max learning rate 1e-3) before pruning.
learn.fit_one_cycle(3, 1e-3)

In [None]:
# One-shot structured pruning of `model` with torch-pruning.
# The example inputs are moved to the device the model currently lives on:
# after `fit_one_cycle` the model sits on the DataLoaders' device, so a
# hard-coded 'cpu' tensor makes the tracing forward pass fail on GPU machines.
pruner = tp.pruner.MetaPruner(
    model,
    xb.to(next(model.parameters()).device),
    global_pruning=False,
    importance=tp.importance.GroupNormImportance(),
    # NOTE(review): with 10000 iterative steps, the single `step()` below only
    # applies a tiny fraction of the 0.5 target ratio — confirm this is intended.
    iterative_steps=10000,
    pruning_ratio=0.5,
    #pruning_ratio_dict=pruning_ratio_dict,
    num_heads=num_heads,
    ignored_layers=ignored_layers,
)
# Interactive alternative: `for g in pruner.step(interactive=True): g.prune()`
pruner.step()

In [None]:
# After pruning, write the head counts from `num_heads` back onto every
# attention module and recompute `head_dim` from the fused qkv projection's
# output width (3 * num_heads * head_dim).
for m in model.modules():
    if hasattr(m, 'num_heads'):
        if hasattr(m, 'qkv'):
            m.num_heads = num_heads[m.qkv]
            m.head_dim = m.qkv.out_features // (3 * m.num_heads)
        elif hasattr(m, 'qkv_proj'):
            # Fixed: the lookup used the non-existent attribute `qqkv_projkv`,
            # which raised AttributeError for modules exposing `qkv_proj`.
            m.num_heads = num_heads[m.qkv_proj]
            m.head_dim = m.qkv_proj.out_features // (3 * m.num_heads)

In [None]:
# Unpruned reference for the MACs/params comparison printed below. It must be
# the same architecture as the pruned `model` (beit_base_patch16_224); the
# previous 'convnext_xxlarge' baseline made the before/after figures incomparable.
model_old = timm.create_model('beit_base_patch16_224', pretrained=False, no_jit=True).eval()
# A freshly created model lives on the CPU, so the example batch goes there too.
base_macs, base_params = tp.utils.count_ops_and_params(model_old, xb.to('cpu'))

In [None]:
# Count MACs/params of the pruned model. The example batch is placed on the
# model's own device instead of a hard-coded 'cpu', because training leaves
# the model on the DataLoaders' device.
pruned_macs, pruned_params = tp.utils.count_ops_and_params(
    model, xb.to(next(model.parameters()).device)
)

In [None]:
# Before => after summary: MACs in units of 1e9, parameters in units of 1e6.
print(
    "MACs: %.4f G => %.4f G" % (base_macs / 1e9, pruned_macs / 1e9),
    "Params: %.4f M => %.4f M" % (base_params / 1e6, pruned_params / 1e6),
    sep="\n",
)

In [None]:
class PruneCallback(Callback):
    """fastai callback that prunes the learner's model a little before every optimizer step.

    A `tp.pruner.MetaPruner` is created in `before_fit` with one iterative
    step per training batch and is advanced once in each `before_step`.

    Args:
        pruning_ratio: Final pruning ratio handed to `MetaPruner`.
        schedule: Callable `(pruning_ratio, total_steps) -> sequence`; its
            result is stored in `sparsity_levels` and used for the per-epoch
            report printed in `after_epoch`.
        criteria: Importance estimator passed to the pruner as `importance`.
        ignored_layers: Modules passed to the pruner as `ignored_layers`.
        *args, **kwargs: Forwarded verbatim to `MetaPruner`
            (e.g. `global_pruning=...`, `num_heads=...`).

    NOTE(review): `schedule` is only used for reporting. The pruner itself is
    built without `iterative_pruning_ratio_scheduler`, so it follows its own
    default schedule and the printed levels may not match the ratio actually
    applied — confirm whether the schedule should also be passed to the pruner.
    """

    def __init__(self, pruning_ratio, schedule, criteria, ignored_layers, *args, **kwargs):
        store_attr()
        self.sparsity_levels = []
        # Kept separately because they are forwarded to MetaPruner in before_fit.
        self.extra_args = args
        self.extra_kwargs = kwargs

    def before_fit(self):
        """Compute the number of training steps and build the pruner."""
        n_batches_per_epoch = len(self.learn.dls.train)
        self.total_training_steps = n_batches_per_epoch * self.learn.n_epoch
        print(self.total_training_steps)

        self.example_inputs, _ = self.learn.dls.one_batch()
        self.sparsity_levels = self.schedule(self.pruning_ratio, self.total_training_steps)

        # Trace with random data on the device the model actually lives on;
        # a hard-coded 'cuda:0' breaks CPU-only and multi-GPU setups.
        model_device = next(self.learn.model.parameters()).device
        self.pruner = tp.pruner.MetaPruner(
            self.learn.model,
            *self.extra_args,
            example_inputs=torch.randn(self.example_inputs.shape, device=model_device),
            importance=self.criteria,
            pruning_ratio=self.pruning_ratio,
            ignored_layers=self.ignored_layers,
            # One pruning step per training batch.
            iterative_steps=self.total_training_steps,
            **self.extra_kwargs,
        )

    def before_step(self):
        """Advance the pruner by one step and apply every returned pruning group."""
        if self.training:
            for g in self.pruner.step(interactive=True):
                g.prune()
        # NOTE(review): attention modules are not re-synchronised here after
        # pruning (num_heads / head_dim); add that if transformer models are
        # pruned through this callback.

    def after_epoch(self):
        """Print the scheduled sparsity level reached at the end of the epoch."""
        completed_steps = (self.epoch + 1) * len(self.learn.dls.train)
        # A schedule shaped like `onecycle_scheduler` has total_steps + 1
        # entries, entry k being the level after k steps, so the level after
        # `completed_steps` steps is at index `completed_steps` (the previous
        # `- 1` reported one step too early). Clamp for shorter schedules.
        level_idx = min(completed_steps, len(self.sparsity_levels) - 1)
        current_sparsity = self.sparsity_levels[level_idx]
        print(f'Sparsity at the end of epoch {self.epoch}: {current_sparsity*100:.2f}%')

In [None]:
# Display every model name registered in timm (long output).
timm.list_models()

In [None]:
# Second experiment: a pretrained resnet18 pruned gradually during training.
model = timm.create_model('resnet18', pretrained=True, no_jit=True).eval()

ignored_layers = []  # modules handed to the pruner as "do not prune"
num_heads = {}       # qkv projection module -> number of attention heads

for m in model.modules():
    # Classification head: a Linear layer that emits `num_classes` outputs.
    if isinstance(m, nn.Linear) and m.out_features == model.num_classes:
        ignored_layers.append(m)
        print("Ignore classifier layer: ", m)

    # Everything below only concerns attention modules.
    if not hasattr(m, 'num_heads'):
        continue
    if hasattr(m, 'qkv'):
        num_heads[m.qkv] = m.num_heads
        print("Attention layer: ", m.qkv, m.num_heads)
    elif hasattr(m, 'qkv_proj'):
        num_heads[m.qkv_proj] = m.num_heads

learn = Learner(dls, model, metrics=[accuracy])

In [None]:
# Fine-tune for 5 epochs (max learning rate 1e-3) before the pruning runs below.
learn.fit_one_cycle(5, 1e-3)

In [None]:
# Prune towards a 0.25 ratio while training for 10 epochs, scoring channels
# with torch-pruning's own GroupNormImportance.
pr_cb = PruneCallback(
    pruning_ratio=0.25,
    schedule=onecycle_scheduler,
    criteria=tp.importance.GroupNormImportance(
        normalizer=None,
        target_types=[nn.modules.conv._ConvNd, nn.Linear],
    ),
    ignored_layers=ignored_layers,
    # Extra keyword arguments are forwarded to the pruner by PruneCallback.
    global_pruning=True,
    num_heads=num_heads,
)
learn.fit_one_cycle(10, 1e-4, cbs=pr_cb)

In [None]:
# Same setup for 15 epochs, but scoring with the notebook-local
# GroupNormImportance.
# NOTE(review): that class is defined in a later cell, so this cell fails on a
# fresh kernel unless the defining cell has been run first.
pr_cb = PruneCallback(
    pruning_ratio=0.25,
    schedule=onecycle_scheduler,
    criteria=GroupNormImportance(
        normalizer=None,
        target_types=[nn.modules.conv._ConvNd, nn.Linear],
    ),
    ignored_layers=ignored_layers,
    # Extra keyword arguments are forwarded to the pruner by PruneCallback.
    global_pruning=True,
    num_heads=num_heads,
)
learn.fit_one_cycle(15, 1e-4, cbs=pr_cb)

In [None]:
# Unpruned reference for the comparison printed below. The learner being
# pruned wraps a resnet18, so the baseline must be a resnet18 too; the previous
# 'tf_efficientnet_b3' baseline made the before/after figures incomparable.
# NOTE: this rebinds `model`; the pruned network stays reachable as `learn.model`.
model = timm.create_model('resnet18', pretrained=False, no_jit=True).eval()
# A freshly created model lives on the CPU, so the example batch goes there too.
base_macs, base_params = tp.utils.count_ops_and_params(model, xb.to('cpu'))

In [None]:
# Count MACs/params of the pruned learner model. The example batch follows
# the model's device instead of a hard-coded 'cuda:0', so the cell also works
# on CPU-only machines.
pruned_macs, pruned_params = tp.utils.count_ops_and_params(
    learn.model, xb.to(next(learn.model.parameters()).device)
)

In [None]:
# Before => after summary: MACs in units of 1e9, parameters in units of 1e6.
print(
    "MACs: %.4f G => %.4f G" % (base_macs / 1e9, pruned_macs / 1e9),
    "Params: %.4f M => %.4f M" % (base_params / 1e6, pruned_params / 1e6),
    sep="\n",
)

In [None]:
import abc
import torch
import torch.nn as nn

import typing

from torch_pruning import function
from torch_pruning.dependency import Group

class Importance(abc.ABC):
    """ Estimate the importance of a tp.Dependency.Group, and return an 1-D per-channel importance score.

        It should accept a group as inputs, and return a 1-D tensor with the same length as the number of channels.
        All groups must be pruned simultaneously and thus their importance should be accumulated across channel groups.

        Subclasses must implement `__call__`; this base class cannot be instantiated.

        Example:
            ```python
            DG = tp.DependencyGraph().build_dependency(model, example_inputs=torch.randn(1,3,224,224)) 
            group = DG.get_pruning_group( model.conv1, tp.prune_conv_out_channels, idxs=[2, 6, 9] )    
            scorer = MagnitudeImportance()    
            imp_score = scorer(group)    
            #imp_score is a 1-D tensor with length 3 for channels [2, 6, 9]  
            min_score = imp_score.min() 
            ``` 
    """
    # `abc.abstractclassmethod` is deprecated and would turn `__call__` into a
    # classmethod; scorers are called on instances, so a plain abstract
    # instance method is the correct declaration.
    @abc.abstractmethod
    def __call__(self, group: Group) -> torch.Tensor:
        """Return the importance scores for `group`.

        Args:
            group: The pruning group to score.

        Returns:
            A 1-D tensor with one score per channel of the group.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError


class GroupNormImportance(Importance):
    """Score the channels of a pruning group by the magnitude of their weights.

    For every (dependency, idxs) entry of the group whose layer is an instance
    of one of `target_types`, a per-channel score `|w| ** p` is computed
    (averaged over the remaining weight dimensions for conv/linear layers,
    element-wise for BatchNorm/LayerNorm affine weights). The per-layer scores
    are merged with `group_reduction` and then rescaled with `normalizer`.

    No p-th root is taken, and conv/linear scores use `.mean(1)`; the
    `.sum(1)` variant is left commented out in `__call__`.
    """

    def __init__(self, 
                 p: int=2, 
                 group_reduction: str="mean", 
                 normalizer: str='mean', 
                 bias=False,
                 target_types:list=[nn.modules.conv._ConvNd, nn.Linear, nn.modules.batchnorm._BatchNorm, nn.LayerNorm]):
        """Store the scoring configuration.

        Args:
            p: Exponent applied to the absolute weight values.
            group_reduction: How per-layer scores are merged in `_reduce`:
                "sum", "mean", "max", "prod", "first", "gate" or None.
            normalizer: Rescaling applied in `_normalize`: None, a callable,
                "sum", "standarization", "mean", "max", "gaussian",
                "sentinel_<fraction>" or "lamp".
            bias: When true, bias terms contribute an additional score entry
                per layer.
            target_types: Layer classes that are scored; other layers in the
                group are skipped.
        """
        # NOTE: the default `target_types` list is created once and shared by
        # every instance relying on the default; nothing in this class mutates it.
        self.p = p
        self.group_reduction = group_reduction
        self.normalizer = normalizer
        self.target_types = target_types
        self.bias = bias

    def _lamp(self, scores): # Layer-adaptive Sparsity for the Magnitude-based Pruning
        """
        Normalizing scheme for LAMP.

        Each score is divided by the sum of itself and all scores that come
        after it in ascending order; the result is returned in the original
        element order and shape.
        """
        # sort scores in an ascending order
        sorted_scores,sorted_idx = scores.view(-1).sort(descending=False)
        # compute cumulative sum
        scores_cumsum_temp = sorted_scores.cumsum(dim=0)
        # shift by one so entry k holds the sum of the k smaller sorted scores
        scores_cumsum = torch.zeros(scores_cumsum_temp.shape,device=scores.device)
        scores_cumsum[1:] = scores_cumsum_temp[:len(scores_cumsum_temp)-1]
        # normalize by cumulative sum
        sorted_scores /= (scores.sum() - scores_cumsum)
        # tidy up and output
        new_scores = torch.zeros(scores_cumsum.shape,device=scores.device)
        new_scores[sorted_idx] = sorted_scores
        
        return new_scores.view(scores.shape)
    
    def _normalize(self, group_importance, normalizer):
        """Rescale `group_importance` according to `normalizer`.

        Args:
            group_importance: Tensor of reduced importance scores.
            normalizer: None (no change), a callable applied to the scores, or
                one of the string modes handled below.

        Returns:
            The rescaled scores.

        Raises:
            NotImplementedError: For an unrecognised string mode.
        """
        if normalizer is None:
            return group_importance
        elif isinstance(normalizer, typing.Callable):
            return normalizer(group_importance)
        elif normalizer == "sum":
            return group_importance / group_importance.sum()
        elif normalizer == "standarization":
            # min-max scaling; the mode name keeps its original spelling because callers pass it as a string
            return (group_importance - group_importance.min()) / (group_importance.max() - group_importance.min()+1e-8)
        elif normalizer == "mean":
            return group_importance / group_importance.mean()
        elif normalizer == "max":
            return group_importance / group_importance.max()
        elif normalizer == 'gaussian':
            return (group_importance - group_importance.mean()) / (group_importance.std()+1e-8)
        elif normalizer.startswith('sentinel'): # normalize the score with the k-th smallest element. e.g. sentinel_0.5 means median normalization
            # position of the sentinel in ascending order: fraction * number of scores, truncated
            sentinel = float(normalizer.split('_')[1]) * len(group_importance)
            sentinel = torch.argsort(group_importance, dim=0, descending=False)[int(sentinel)]
            return group_importance / (group_importance[sentinel]+1e-8)
        elif normalizer=='lamp':
            return self._lamp(group_importance)
        else:
            raise NotImplementedError

    def _reduce(self, group_imp: typing.List[torch.Tensor], group_idxs: typing.List[typing.List[int]]):
        """Merge the per-layer scores into one score tensor.

        Args:
            group_imp: One score tensor per contributing layer.
            group_idxs: For each entry of `group_imp`, the root indices its
                scores are scattered to in the result.

        Returns:
            The merged scores (shaped like `group_imp[0]`), the stacked input
            when `group_reduction` is None, or `group_imp` unchanged when it
            is empty.

        Raises:
            NotImplementedError: For an unrecognised `group_reduction`.
        """
        if len(group_imp) == 0: return group_imp
        # Start from the neutral element of the chosen reduction.
        if self.group_reduction == 'prod':
            reduced_imp = torch.ones_like(group_imp[0])
        elif self.group_reduction == 'max':
            reduced_imp = torch.ones_like(group_imp[0]) * -99999
        else:
            reduced_imp = torch.zeros_like(group_imp[0])

        for i, (imp, root_idxs) in enumerate(zip(group_imp, group_idxs)):
            imp = imp.to(reduced_imp.device)
            if self.group_reduction == "sum" or self.group_reduction == "mean":
                reduced_imp.scatter_add_(0, torch.tensor(root_idxs, device=imp.device), imp) # accumulated importance
            elif self.group_reduction == "max": # keep the max importance
                selected_imp = torch.index_select(reduced_imp, 0, torch.tensor(root_idxs, device=imp.device))
                selected_imp = torch.maximum(input=selected_imp, other=imp)
                reduced_imp.scatter_(0, torch.tensor(root_idxs, device=imp.device), selected_imp)
            elif self.group_reduction == "prod": # product of importance
                selected_imp = torch.index_select(reduced_imp, 0, torch.tensor(root_idxs, device=imp.device))
                torch.mul(selected_imp, imp, out=selected_imp)
                reduced_imp.scatter_(0, torch.tensor(root_idxs, device=imp.device), selected_imp)
            elif self.group_reduction == 'first':
                # only the first entry contributes
                if i == 0:
                    reduced_imp.scatter_(0, torch.tensor(root_idxs, device=imp.device), imp)
            elif self.group_reduction == 'gate':
                # only the last entry contributes
                if i == len(group_imp)-1:
                    reduced_imp.scatter_(0, torch.tensor(root_idxs, device=imp.device), imp)
            elif self.group_reduction is None:
                reduced_imp = torch.stack(group_imp, dim=0) # no reduction
            else:
                raise NotImplementedError
        
        # "mean" divides the accumulated sum by the number of score entries
        if self.group_reduction == "mean":
            reduced_imp /= len(group_imp)
        return reduced_imp
    
    @torch.no_grad()
    def __call__(self, group: Group):
        """Compute the reduced and normalized importance scores for `group`.

        Args:
            group: Pruning group; iterating it yields `(dep, idxs)` pairs and
                `group[i].root_idxs` gives the matching root indices.

        Returns:
            The score tensor produced by `_reduce` followed by `_normalize`,
            or None when no layer in the group contributed a score.
        """
        group_imp = []
        group_idxs = []
        # Iterate over all groups and estimate group importance
        for i, (dep, idxs) in enumerate(group):
            layer = dep.layer
            prune_fn = dep.pruning_fn
            root_idxs = group[i].root_idxs
            # skip layers whose type is not selected for scoring
            if not isinstance(layer, tuple(self.target_types)):
                continue
            ####################
            # Conv/Linear Output
            ####################
            if prune_fn in [
                function.prune_conv_out_channels,
                function.prune_linear_out_channels,
            ]:
                # bring the pruned (output) dimension to the front, one row per selected channel
                if hasattr(layer, "transposed") and layer.transposed:
                    w = layer.weight.data.transpose(1, 0)[idxs].flatten(1)
                else:
                    w = layer.weight.data[idxs].flatten(1)
                #local_imp = w.abs().pow(self.p).sum(1)
                local_imp = w.abs().pow(self.p).mean(1)
                group_imp.append(local_imp)
                group_idxs.append(root_idxs)

                if self.bias and layer.bias is not None:
                    local_imp = layer.bias.data[idxs].abs().pow(self.p)
                    group_imp.append(local_imp)
                    group_idxs.append(root_idxs)

            ####################
            # Conv/Linear Input
            ####################
            elif prune_fn in [
                function.prune_conv_in_channels,
                function.prune_linear_in_channels,
            ]:
                # bring the input dimension to the front, one row per input channel
                if hasattr(layer, "transposed") and layer.transposed:
                    w = (layer.weight.data).flatten(1)
                else:
                    w = (layer.weight.data).transpose(0, 1).flatten(1)
                #local_imp = w.abs().pow(self.p).sum(1)
                local_imp = w.abs().pow(self.p).mean(1)

                # repeat importance for group convolutions
                if prune_fn == function.prune_conv_in_channels and layer.groups != layer.in_channels and layer.groups != 1:
                    local_imp = local_imp.repeat(layer.groups)
                
                # scores were computed for all input channels; keep the selected ones
                local_imp = local_imp[idxs]
                group_imp.append(local_imp)
                group_idxs.append(root_idxs)

            ####################
            # BatchNorm
            ####################
            elif prune_fn == function.prune_batchnorm_out_channels:
                # regularize BN
                if layer.affine:
                    w = layer.weight.data[idxs]
                    local_imp = w.abs().pow(self.p)
                    group_imp.append(local_imp)
                    group_idxs.append(root_idxs)

                    if self.bias and layer.bias is not None:
                        local_imp = layer.bias.data[idxs].abs().pow(self.p)
                        group_imp.append(local_imp)
                        group_idxs.append(root_idxs)
            ####################
            # LayerNorm
            ####################
            elif prune_fn == function.prune_layernorm_out_channels:

                if layer.elementwise_affine:
                    w = layer.weight.data[idxs]
                    local_imp = w.abs().pow(self.p)
                    group_imp.append(local_imp)
                    group_idxs.append(root_idxs)

                    if self.bias and layer.bias is not None:
                        local_imp = layer.bias.data[idxs].abs().pow(self.p)
                        group_imp.append(local_imp)
                        group_idxs.append(root_idxs)

        if len(group_imp) == 0: # skip groups without parameterized layers
            return None

        group_imp = self._reduce(group_imp, group_idxs)
        group_imp = self._normalize(group_imp, self.normalizer)
        return group_imp


In [None]:
# Names of the timm models registered under the 'resnet' module.
timm_models = timm.list_models(module='resnet')

In [None]:
# Display all timm model names except those matching the '*vit*' wildcard.
timm.list_models(exclude_filters='*vit*')

In [None]:
# Create a seresnet18 with timm's defaults.
# NOTE: `m` is also the loop variable of the module-scanning cells above.
m = timm.create_model('seresnet18')

In [None]:
# Display the module tree of the model.
m

In [None]:
from fasterai.prune.all import *

In [None]:
# Random tensor of shape (16, 3, 224, 224) used as the input for the benchmark cells below.
dummy_input = torch.randn(16, 3, 224, 224)

In [None]:
# Recreate the seresnet18 so the benchmark/pruning cells below start from a fresh model.
m = timm.create_model('seresnet18')

In [None]:
# Benchmark the model before pruning.
# NOTE(review): `benchmark` is presumably provided by a fasterai star import — confirm.
benchmark(m, dummy_input)

In [None]:
# Prune the Conv2d layers of `m` using the 'local' context and the `large_final` criterion.
# NOTE(review): `Pruner` and `large_final` presumably come from the fasterai star imports,
# 30 is presumably the target sparsity in percent and `round_to=8` presumably keeps filter
# counts at multiples of 8 — confirm against the fasterai documentation.
# The next cell benchmarks `m` again, so the pruning is presumably applied in place.
pr = Pruner(m, 'local', large_final, layer_type=[nn.Conv2d])
pr.prune_model(30, round_to=8)

In [None]:
# Same benchmark after pruning, for comparison with the pre-pruning benchmark cell.
benchmark(m, dummy_input)

In [None]:
import abc
import torch
import torch.nn as nn

import typing

from torch_pruning import function
from torch_pruning.dependency import Group

# NOTE(review): this cell repeats the Importance definition from an earlier
# cell of this notebook; whichever cell runs last wins.
class Importance(abc.ABC):
    """ Estimate the importance of a tp.Dependency.Group, and return an 1-D per-channel importance score.

        It should accept a group as inputs, and return a 1-D tensor with the same length as the number of channels.
        All groups must be pruned simultaneously and thus their importance should be accumulated across channel groups.

        Subclasses must implement `__call__`; this base class cannot be instantiated.

        Example:
            ```python
            DG = tp.DependencyGraph().build_dependency(model, example_inputs=torch.randn(1,3,224,224)) 
            group = DG.get_pruning_group( model.conv1, tp.prune_conv_out_channels, idxs=[2, 6, 9] )    
            scorer = MagnitudeImportance()    
            imp_score = scorer(group)    
            #imp_score is a 1-D tensor with length 3 for channels [2, 6, 9]  
            min_score = imp_score.min() 
            ``` 
    """
    # `abc.abstractclassmethod` is deprecated and would turn `__call__` into a
    # classmethod; scorers are called on instances, so a plain abstract
    # instance method is the correct declaration.
    @abc.abstractmethod
    def __call__(self, group: Group) -> torch.Tensor:
        """Return the importance scores for `group`.

        Args:
            group: The pruning group to score.

        Returns:
            A 1-D tensor with one score per channel of the group.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError


# NOTE(review): this cell repeats the GroupNormImportance definition from an
# earlier cell of this notebook; whichever cell runs last wins.
class GroupNormImportance(Importance):
    """Score the channels of a pruning group by the magnitude of their weights.

    For every (dependency, idxs) entry of the group whose layer is an instance
    of one of `target_types`, a per-channel score `|w| ** p` is computed
    (averaged over the remaining weight dimensions for conv/linear layers,
    element-wise for BatchNorm/LayerNorm affine weights). The per-layer scores
    are merged with `group_reduction` and then rescaled with `normalizer`.

    No p-th root is taken, and conv/linear scores use `.mean(1)`; the
    `.sum(1)` variant is left commented out in `__call__`.
    """

    def __init__(self, 
                 p: int=2, 
                 group_reduction: str="mean", 
                 normalizer: str='mean', 
                 bias=False,
                 target_types:list=[nn.modules.conv._ConvNd, nn.Linear, nn.modules.batchnorm._BatchNorm, nn.LayerNorm]):
        """Store the scoring configuration.

        Args:
            p: Exponent applied to the absolute weight values.
            group_reduction: How per-layer scores are merged in `_reduce`:
                "sum", "mean", "max", "prod", "first", "gate" or None.
            normalizer: Rescaling applied in `_normalize`: None, a callable,
                "sum", "standarization", "mean", "max", "gaussian",
                "sentinel_<fraction>" or "lamp".
            bias: When true, bias terms contribute an additional score entry
                per layer.
            target_types: Layer classes that are scored; other layers in the
                group are skipped.
        """
        # NOTE: the default `target_types` list is created once and shared by
        # every instance relying on the default; nothing in this class mutates it.
        self.p = p
        self.group_reduction = group_reduction
        self.normalizer = normalizer
        self.target_types = target_types
        self.bias = bias

    def _lamp(self, scores): # Layer-adaptive Sparsity for the Magnitude-based Pruning
        """
        Normalizing scheme for LAMP.

        Each score is divided by the sum of itself and all scores that come
        after it in ascending order; the result is returned in the original
        element order and shape.
        """
        # sort scores in an ascending order
        sorted_scores,sorted_idx = scores.view(-1).sort(descending=False)
        # compute cumulative sum
        scores_cumsum_temp = sorted_scores.cumsum(dim=0)
        # shift by one so entry k holds the sum of the k smaller sorted scores
        scores_cumsum = torch.zeros(scores_cumsum_temp.shape,device=scores.device)
        scores_cumsum[1:] = scores_cumsum_temp[:len(scores_cumsum_temp)-1]
        # normalize by cumulative sum
        sorted_scores /= (scores.sum() - scores_cumsum)
        # tidy up and output
        new_scores = torch.zeros(scores_cumsum.shape,device=scores.device)
        new_scores[sorted_idx] = sorted_scores
        
        return new_scores.view(scores.shape)
    
    def _normalize(self, group_importance, normalizer):
        """Rescale `group_importance` according to `normalizer`.

        Args:
            group_importance: Tensor of reduced importance scores.
            normalizer: None (no change), a callable applied to the scores, or
                one of the string modes handled below.

        Returns:
            The rescaled scores.

        Raises:
            NotImplementedError: For an unrecognised string mode.
        """
        if normalizer is None:
            return group_importance
        elif isinstance(normalizer, typing.Callable):
            return normalizer(group_importance)
        elif normalizer == "sum":
            return group_importance / group_importance.sum()
        elif normalizer == "standarization":
            # min-max scaling; the mode name keeps its original spelling because callers pass it as a string
            return (group_importance - group_importance.min()) / (group_importance.max() - group_importance.min()+1e-8)
        elif normalizer == "mean":
            return group_importance / group_importance.mean()
        elif normalizer == "max":
            return group_importance / group_importance.max()
        elif normalizer == 'gaussian':
            return (group_importance - group_importance.mean()) / (group_importance.std()+1e-8)
        elif normalizer.startswith('sentinel'): # normalize the score with the k-th smallest element. e.g. sentinel_0.5 means median normalization
            # position of the sentinel in ascending order: fraction * number of scores, truncated
            sentinel = float(normalizer.split('_')[1]) * len(group_importance)
            sentinel = torch.argsort(group_importance, dim=0, descending=False)[int(sentinel)]
            return group_importance / (group_importance[sentinel]+1e-8)
        elif normalizer=='lamp':
            return self._lamp(group_importance)
        else:
            raise NotImplementedError

    def _reduce(self, group_imp: typing.List[torch.Tensor], group_idxs: typing.List[typing.List[int]]):
        """Merge the per-layer scores into one score tensor.

        Args:
            group_imp: One score tensor per contributing layer.
            group_idxs: For each entry of `group_imp`, the root indices its
                scores are scattered to in the result.

        Returns:
            The merged scores (shaped like `group_imp[0]`), the stacked input
            when `group_reduction` is None, or `group_imp` unchanged when it
            is empty.

        Raises:
            NotImplementedError: For an unrecognised `group_reduction`.
        """
        if len(group_imp) == 0: return group_imp
        # Start from the neutral element of the chosen reduction.
        if self.group_reduction == 'prod':
            reduced_imp = torch.ones_like(group_imp[0])
        elif self.group_reduction == 'max':
            reduced_imp = torch.ones_like(group_imp[0]) * -99999
        else:
            reduced_imp = torch.zeros_like(group_imp[0])

        for i, (imp, root_idxs) in enumerate(zip(group_imp, group_idxs)):
            imp = imp.to(reduced_imp.device)
            if self.group_reduction == "sum" or self.group_reduction == "mean":
                reduced_imp.scatter_add_(0, torch.tensor(root_idxs, device=imp.device), imp) # accumulated importance
            elif self.group_reduction == "max": # keep the max importance
                selected_imp = torch.index_select(reduced_imp, 0, torch.tensor(root_idxs, device=imp.device))
                selected_imp = torch.maximum(input=selected_imp, other=imp)
                reduced_imp.scatter_(0, torch.tensor(root_idxs, device=imp.device), selected_imp)
            elif self.group_reduction == "prod": # product of importance
                selected_imp = torch.index_select(reduced_imp, 0, torch.tensor(root_idxs, device=imp.device))
                torch.mul(selected_imp, imp, out=selected_imp)
                reduced_imp.scatter_(0, torch.tensor(root_idxs, device=imp.device), selected_imp)
            elif self.group_reduction == 'first':
                # only the first entry contributes
                if i == 0:
                    reduced_imp.scatter_(0, torch.tensor(root_idxs, device=imp.device), imp)
            elif self.group_reduction == 'gate':
                # only the last entry contributes
                if i == len(group_imp)-1:
                    reduced_imp.scatter_(0, torch.tensor(root_idxs, device=imp.device), imp)
            elif self.group_reduction is None:
                reduced_imp = torch.stack(group_imp, dim=0) # no reduction
            else:
                raise NotImplementedError
        
        # "mean" divides the accumulated sum by the number of score entries
        if self.group_reduction == "mean":
            reduced_imp /= len(group_imp)
        return reduced_imp
    
    @torch.no_grad()
    def __call__(self, group: Group):
        """Compute the reduced and normalized importance scores for `group`.

        Args:
            group: Pruning group; iterating it yields `(dep, idxs)` pairs and
                `group[i].root_idxs` gives the matching root indices.

        Returns:
            The score tensor produced by `_reduce` followed by `_normalize`,
            or None when no layer in the group contributed a score.
        """
        group_imp = []
        group_idxs = []
        # Iterate over all groups and estimate group importance
        for i, (dep, idxs) in enumerate(group):
            layer = dep.layer
            prune_fn = dep.pruning_fn
            root_idxs = group[i].root_idxs
            # skip layers whose type is not selected for scoring
            if not isinstance(layer, tuple(self.target_types)):
                continue
            ####################
            # Conv/Linear Output
            ####################
            if prune_fn in [
                function.prune_conv_out_channels,
                function.prune_linear_out_channels,
            ]:
                # bring the pruned (output) dimension to the front, one row per selected channel
                if hasattr(layer, "transposed") and layer.transposed:
                    w = layer.weight.data.transpose(1, 0)[idxs].flatten(1)
                else:
                    w = layer.weight.data[idxs].flatten(1)
                #local_imp = w.abs().pow(self.p).sum(1)
                local_imp = w.abs().pow(self.p).mean(1)
                group_imp.append(local_imp)
                group_idxs.append(root_idxs)

                if self.bias and layer.bias is not None:
                    local_imp = layer.bias.data[idxs].abs().pow(self.p)
                    group_imp.append(local_imp)
                    group_idxs.append(root_idxs)

            ####################
            # Conv/Linear Input
            ####################
            elif prune_fn in [
                function.prune_conv_in_channels,
                function.prune_linear_in_channels,
            ]:
                # bring the input dimension to the front, one row per input channel
                if hasattr(layer, "transposed") and layer.transposed:
                    w = (layer.weight.data).flatten(1)
                else:
                    w = (layer.weight.data).transpose(0, 1).flatten(1)
                #local_imp = w.abs().pow(self.p).sum(1)
                local_imp = w.abs().pow(self.p).mean(1)

                # repeat importance for group convolutions
                if prune_fn == function.prune_conv_in_channels and layer.groups != layer.in_channels and layer.groups != 1:
                    local_imp = local_imp.repeat(layer.groups)
                
                # scores were computed for all input channels; keep the selected ones
                local_imp = local_imp[idxs]
                group_imp.append(local_imp)
                group_idxs.append(root_idxs)

            ####################
            # BatchNorm
            ####################
            elif prune_fn == function.prune_batchnorm_out_channels:
                # regularize BN
                if layer.affine:
                    w = layer.weight.data[idxs]
                    local_imp = w.abs().pow(self.p)
                    group_imp.append(local_imp)
                    group_idxs.append(root_idxs)

                    if self.bias and layer.bias is not None:
                        local_imp = layer.bias.data[idxs].abs().pow(self.p)
                        group_imp.append(local_imp)
                        group_idxs.append(root_idxs)
            ####################
            # LayerNorm
            ####################
            elif prune_fn == function.prune_layernorm_out_channels:

                if layer.elementwise_affine:
                    w = layer.weight.data[idxs]
                    local_imp = w.abs().pow(self.p)
                    group_imp.append(local_imp)
                    group_idxs.append(root_idxs)

                    if self.bias and layer.bias is not None:
                        local_imp = layer.bias.data[idxs].abs().pow(self.p)
                        group_imp.append(local_imp)
                        group_idxs.append(root_idxs)

        if len(group_imp) == 0: # skip groups without parameterized layers
            return None

        group_imp = self._reduce(group_imp, group_idxs)
        group_imp = self._normalize(group_imp, self.normalizer)
        return group_imp
