# SwiftFormer pruning example

In this toy example we show how to run channel pruning on a custom model (SwiftFormer-S) with `torch_dag` and `torch_dag_algorithms`.


In [1]:
import logging

# Surface INFO-level log records (torch-dag debugging tools, pruning filters and
# channel-removal messages are emitted at this level).
logging.basicConfig(level="INFO")

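### 1. Get the code (and a pretrained checkpoint) from https://github.com/Amshaker/SwiftFormer

The cell below only clones the repository and installs the pinned `timm` version; the SwiftFormer-S checkpoint has to be downloaded separately and its path set in step 2.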

In [None]:
!git clone https://github.com/Amshaker/SwiftFormer.git
!pip install timm==0.5.4

### 2. Load the model and check which modules are dagable (traceable by torch-dag)

In [3]:
from SwiftFormer.models.swiftformer import SwiftFormer_S, EfficientAdditiveAttnetion
import torch

import torch_dag as td
from torch_dag.commons import look_for_dagable_modules
import torch_dag_algorithms as tda

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Build SwiftFormer-S and load pretrained weights into it.
model = SwiftFormer_S()
checkpoint_path = '/path/to/checkpoint'  # TODO: set to the downloaded SwiftFormer-S checkpoint
# NOTE(review): torch.load unpickles arbitrary objects; only load checkpoints from a trusted source.
checkpoint = torch.load(checkpoint_path, map_location='cpu')
# Weights are expected under 'model'; also accept a file that is a bare state_dict.
if isinstance(checkpoint, dict) and 'model' in checkpoint:
    state_dict = checkpoint['model']
else:
    state_dict = checkpoint
# strict=False tolerates key mismatches, so report them instead of silently
# continuing with (partially) randomly initialised weights.
incompatible_keys = model.load_state_dict(state_dict, strict=False)
if incompatible_keys.missing_keys or incompatible_keys.unexpected_keys:
    logging.warning(
        'Checkpoint/model key mismatch: %d missing, %d unexpected (first missing: %s)',
        len(incompatible_keys.missing_keys),
        len(incompatible_keys.unexpected_keys),
        incompatible_keys.missing_keys[:5],
    )
# Returns (dagable module classes, non-dagable module classes) for inspection.
look_for_dagable_modules(model)

INFO:torch_dag.commons.debugging_tools:<class 'SwiftFormer.models.swiftformer.SwiftFormer'>
  torch.has_cuda,
  torch.has_cudnn,
  torch.has_mps,
  torch.has_mkldnn,
INFO:torch_dag.commons.debugging_tools:FAILURE: <class 'SwiftFormer.models.swiftformer.SwiftFormer'>
INFO:torch_dag.commons.debugging_tools:Tensor type unknown to einops <class 'torch.fx.proxy.Proxy'>
INFO:torch_dag.commons.debugging_tools:<class 'torch.nn.modules.container.Sequential'>
INFO:torch_dag.commons.debugging_tools:SUCCESS: <class 'torch.nn.modules.container.Sequential'>
INFO:torch_dag.commons.debugging_tools:<class 'torch.nn.modules.container.ModuleList'>
INFO:torch_dag.commons.debugging_tools:FAILURE: <class 'torch.nn.modules.container.ModuleList'>
INFO:torch_dag.commons.debugging_tools:Module [ModuleList] is missing the required "forward" function
INFO:torch_dag.commons.debugging_tools:<class 'torch.nn.modules.container.Sequential'>
INFO:torch_dag.commons.debugging_tools:FAILURE: <class 'torch.nn.modules.conta

({SwiftFormer.models.swiftformer.ConvEncoder,
  SwiftFormer.models.swiftformer.Embedding,
  SwiftFormer.models.swiftformer.Mlp,
  SwiftFormer.models.swiftformer.SwiftFormerLocalRepresentation,
  torch.nn.modules.container.Sequential},
 {SwiftFormer.models.swiftformer.EfficientAdditiveAttnetion,
  SwiftFormer.models.swiftformer.SwiftFormer,
  SwiftFormer.models.swiftformer.SwiftFormerEncoder,
  torch.nn.modules.container.ModuleList,
  torch.nn.modules.container.Sequential})

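As we can see, `EfficientAdditiveAttnetion` (the class name is spelled this way in the SwiftFormer code) cannot be traced. We therefore pass it to `build_from_unstructured_module` via `custom_autowrap_torch_module_classes`, so torch-dag keeps it as a single opaque module instead of tracing through it. Later, we also mark it as unprunable in the pruning configuration.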

### 3. Convert the model to a `DagModule` and verify the conversion

In [5]:
# Experiment configuration.
INPUT_SHAPE = (1, 3, 224, 224)  # leading dim is the batch size
PRUNING_PROPORTION = 0.5
NUM_PRUNING_STEPS = 100_000

# Modules torch-dag could not trace; they are auto-wrapped as single units.
custom_module_classes = (EfficientAdditiveAttnetion,)

dag = td.build_from_unstructured_module(
    model,
    custom_autowrap_torch_module_classes=custom_module_classes,
)

# Sanity check: the DagModule should reproduce the original model's outputs.
td.compare_module_outputs(
    first_module=model,
    second_module=dag,
    input_shape=INPUT_SHAPE,
)

  (to_query): Linear(in_features=48, out_features=48, bias=True)
  (to_key): Linear(in_features=48, out_features=48, bias=True)
  (Proj): Linear(in_features=48, out_features=48, bias=True)
  (final): Linear(in_features=48, out_features=48, bias=True)
) of type: <class 'SwiftFormer.models.swiftformer.EfficientAdditiveAttnetion'> is not covered by `torch-dag`. by the DagModule. In particular, pruning support is not guaranteed.
  (to_query): Linear(in_features=64, out_features=64, bias=True)
  (to_key): Linear(in_features=64, out_features=64, bias=True)
  (Proj): Linear(in_features=64, out_features=64, bias=True)
  (final): Linear(in_features=64, out_features=64, bias=True)
) of type: <class 'SwiftFormer.models.swiftformer.EfficientAdditiveAttnetion'> is not covered by `torch-dag`. by the DagModule. In particular, pruning support is not guaranteed.
  (to_query): Linear(in_features=168, out_features=168, bias=True)
  (to_key): Linear(in_features=168, out_features=168, bias=True)
  (Proj): 

### 5. Prepare the converted model for pruning

In [6]:
# Channel-pruning setup. The auto-wrapped custom modules are declared unprunable,
# so (presumably) orbits passing through them are filtered out of pruning.
pruning_config = tda.pruning.ChannelPruning(
    model=dag,
    custom_unprunable_module_classes=custom_module_classes,
    input_shape_without_batch=INPUT_SHAPE[1:],
    pruning_proportion=PRUNING_PROPORTION,
    num_training_steps=NUM_PRUNING_STEPS,
    anneal_losses=False,
)
pruning_model = pruning_config.prepare_for_pruning()

prunable_fraction = pruning_config.prunable_proportion
print(f'Prunable proportion: {prunable_fraction}')

INFO:torch_dag_algorithms.pruning.filters:[[1m[96mNonPrunableCustomModulesFilter[0m] Removing orbit [1m[95mOrbit[0m[[1m[93mcolor[0m=2, [1m[93mdiscovery_stage[0m=OrbitsDiscoveryStage.EXTENDED_ORBIT_DISCOVERY, [1m[93msources[0m=[patch_embed_3, network_0_0_pwconv2, network_0_1_pwconv2, network_0_2_local_representation_pwconv2, network_0_2_linear_fc2], [1m[93msinks[0m=[network_0_0_pwconv1, network_0_1_pwconv1, network_0_2_local_representation_pwconv1, network_0_2_linear_fc1, network_1_proj], [1m[93mnon_border[0m={permute, mul, mul_3, add, reshape, network_0_1_dwconv, network_0_2_attn, network_0_1_norm, network_0_1_layer_scale, reshape_1, network_0_2_local_representation_layer_scale, permute_1, mul_4, network_0_2_layer_scale_1, add_3, network_0_2_layer_scale_2, mul_1, network_0_2_linear_norm1, add_1, network_0_2_local_representation_dwconv, network_0_2_local_representation_norm, network_0_2_linear_drop_1, mul_5, network_0_0_layer_scale, mul_2, patch_embed_4, add_4, add_

Prunable proportion: 0.8543182177960208


### 6. Run training with pruning

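At this point the model should be trained with the pruning objective. To save time, we instead set the channel logits of every orbit to random values, so the channels kept in the next step are chosen arbitrarily rather than learned.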

In [7]:
# Stand-in for pruning-aware training: draw each orbit's channel logits from N(0, 1).
# A fixed seed makes the resulting pruned architecture (and kmapp ratio) reproducible.
LOGITS_SEED = 0
logits_generator = torch.Generator().manual_seed(LOGITS_SEED)

orbits_dict = tda.pruning.get_orbits_dict(pruning_model)
for orbit in orbits_dict.values():
    orbit.debug_logits = torch.randn(orbit.num_channels, generator=logits_generator)

### 7. Remove channels from the model

In [8]:
# Physically remove the pruned channels and compare compute (kmapp) before and after.
PRUNED_MODEL_PATH = '/path/to/model/pruned'  # TODO: set an output location

pre_kmapp = td.commons.compute_static_kmapp(pruning_model, input_shape_without_batch=INPUT_SHAPE[1:])
# Use the same INPUT_SHAPE as the rest of the notebook instead of a duplicated literal.
dag_final = tda.pruning.remove_channels_in_dag(pruning_model, input_shape=INPUT_SHAPE)
post_kmapp = td.commons.compute_static_kmapp(dag_final, input_shape_without_batch=INPUT_SHAPE[1:])

kmapp_ratio = post_kmapp / pre_kmapp
print(f'kmapp ratio (pruned / original): {kmapp_ratio}')

dag_final.save(PRUNED_MODEL_PATH)

INFO:torch_dag_algorithms.pruning.channel_removal_primitives:Pruning conv patch_embed_0: leaving fraction: 0.4583333333333333 of out channels.
INFO:torch_dag_algorithms.pruning.channel_removal_primitives:Pruning conv patch_embed_3: leaving fraction: 1.0 of out channels.
INFO:torch_dag_algorithms.pruning.channel_removal_primitives:Pruning conv network_0_0_pwconv1: leaving fraction: 0.4895833333333333 of out channels.
INFO:torch_dag_algorithms.pruning.channel_removal_primitives:Pruning conv network_0_0_pwconv2: leaving fraction: 1.0 of out channels.
INFO:torch_dag_algorithms.pruning.channel_removal_primitives:Pruning conv network_0_1_pwconv1: leaving fraction: 0.5104166666666666 of out channels.
INFO:torch_dag_algorithms.pruning.channel_removal_primitives:Pruning conv network_0_1_pwconv2: leaving fraction: 1.0 of out channels.
INFO:torch_dag_algorithms.pruning.channel_removal_primitives:Pruning conv network_0_2_local_representation_pwconv1: leaving fraction: 0.5208333333333334 of out cha

0.5796244146040351


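### 8. Results

Note: these numbers do not come from the random-logit shortcut used above. That run gives a kmapp ratio of about 0.58, whereas the table implies 19.96 / 39.43 ≈ 0.51. They presumably come from a full pruning-aware training run.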

|        Model         | Accuracy | kmapps |
|:--------------------:|:--------:| ------ |
|    SwiftFormer-S     |  78.5%   | 39.43  |
|    SwiftFormer-XS    |  75.7%   | 24.18  |
| SwiftFormer-S-pruned |  77.1%  | 19.96  |
