This repo is an archive of experiments exploring model pruning, one of the common ways of compressing AI models. This README is also available as an HTML document in README.html.
If using poetry, the dependencies and their respective versions are stated in pyproject.toml. For pip, see requirements.txt.
# If using poetry
poetry install
# If using pip
pip install -r requirements.txt
Execute this script to train a base (unpruned) model:
python3 train.py --dataset Flowers102 --model mobilenet_v3_small --epochs 8 --batch-size 8 --save-as full --save-path ./models/mobilenet_v3_small-Flowers102-1.pth
Several options are available to control the training process:
--dataset: Which dataset to use. Currently only Flowers102 is supported.
--model: Which CNN or Transformer-based model to train. See config.py for examples of how to refer to each model. Currently 7 models are supported:

Model name          Argument
MobileNetV3Small    mobilenet_v3_small
MobileNetV3Large    mobilenet_v3_large
EfficientNetB1      efficientnet_b1
EfficientNetB3      efficientnet_b3
DenseNet169         densenet169
DenseNet201         densenet201
ViT B16             vit_b_16

--epochs: How many epochs the network is trained for. Defaults to 32.
--batch-size: How many images are in one batch. Used to split the train and test sets.
--save-as: The torch.save() configuration to use. Options are state-dict or full: state-dict saves only the weights, whereas full saves the whole model.
--save-path: Where to save the trained model.
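The training options above could be declared with argparse roughly as follows. This is a minimal sketch, not the actual parser in train.py: the build_parser helper, the choices lists, which options are required, and every default other than the documented 32 epochs are assumptions made for illustration.

```python
import argparse

# Model identifiers taken from the list of supported models above.
SUPPORTED_MODELS = [
    "mobilenet_v3_small", "mobilenet_v3_large",
    "efficientnet_b1", "efficientnet_b3",
    "densenet169", "densenet201", "vit_b_16",
]


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser mirroring the documented train.py options."""
    parser = argparse.ArgumentParser(description="Train a base (unpruned) model.")
    parser.add_argument("--dataset", choices=["Flowers102"], required=True)
    parser.add_argument("--model", choices=SUPPORTED_MODELS, required=True)
    parser.add_argument("--epochs", type=int, default=32)  # documented default
    # Required-ness below is an assumption; the README states no defaults.
    parser.add_argument("--batch-size", type=int, required=True)
    parser.add_argument("--save-as", choices=["state-dict", "full"], required=True)
    parser.add_argument("--save-path", required=True)
    return parser


# Parse the same arguments as the example command, passed as an explicit list.
example = build_parser().parse_args([
    "--dataset", "Flowers102",
    "--model", "mobilenet_v3_small",
    "--epochs", "8",
    "--batch-size", "8",
    "--save-as", "full",
    "--save-path", "./models/mobilenet_v3_small-Flowers102-1.pth",
])
print(example.model, example.epochs, example.batch_size, example.save_as)
```

Restricting --model and --save-as with choices makes a typo fail at parse time rather than after the dataset has been loaded.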
A Bash script, train.sh, is available for automation. The options described above can be edited inside the script. The following example runs the script in the background:
chmod +x ./train.sh
nohup ./train.sh &
To prune, execute the following:
python3 prune.py --path models/mobilenet_v3_small-Flowers102-1.pth --layer-kind Complete --pruning-method L1Unstructured --amount 0.8 --global-unstructured --check-percentage
The arguments for this script are as follows:
--path: Location of the .pth file containing the full model to prune.
--layer-kind: Which layers to prune. Available options are Linear, Conv2d, or Complete (which selects both Linear and Conv2d).
--pruning-method: Which pruning method to use. Currently two are supported: RandomUnstructured and L1Unstructured.
--amount: The fraction of the model to prune. Setting this to 0.5 prunes 50% of the parameters.
--global-unstructured: Set this flag to use global unstructured pruning. ALWAYS ENABLE THIS FLAG FOR NOW!
--check-percentage: Set this flag to see the exact number of parameters pruned in each layer.
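To make the idea of global unstructured pruning concrete, here is a minimal pure-Python sketch of global L1 (magnitude) pruning, followed by the kind of per-layer report that --check-percentage describes. It is not the code in prune.py: the function names and toy weights are hypothetical, and real layers would be tensors rather than lists.

```python
from typing import Dict, List


def global_l1_prune(layers: Dict[str, List[float]], amount: float) -> Dict[str, List[float]]:
    """Zero out the `amount` fraction of weights with the smallest magnitude,
    ranked across all layers together rather than layer by layer."""
    if not 0.0 <= amount <= 1.0:
        raise ValueError("amount must be between 0 and 1")
    # Rank every weight in the model by absolute value.
    ranked = sorted(
        (abs(w), name, i)
        for name, weights in layers.items()
        for i, w in enumerate(weights)
    )
    n_prune = round(amount * len(ranked))
    # Copy the layers so the input is left untouched.
    pruned = {name: list(weights) for name, weights in layers.items()}
    for _, name, i in ranked[:n_prune]:
        pruned[name][i] = 0.0
    return pruned


def pruned_percentage(weights: List[float]) -> float:
    """Percentage of weights in one layer that are exactly zero."""
    return 100.0 * sum(w == 0.0 for w in weights) / len(weights)


# Toy weights (hypothetical): the layer with smaller weights loses more of them.
toy = {
    "conv": [0.9, -0.8, 0.03, -0.6],
    "linear": [0.05, -0.02, 0.5, 0.01],
}
result = global_l1_prune(toy, amount=0.5)
for name, weights in result.items():
    print(f"{name}: {pruned_percentage(weights):.1f}% pruned")
# conv: 25.0% pruned
# linear: 75.0% pruned
```

Because the ranking is global, the overall amount is met across the whole model while individual layers can end up pruned well above or below it, which is why a per-layer check is useful.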
To benchmark latency, first download the Bromelia images and place the contents (files called image_..) inside the inference/ directory. The command to run is:
python3 -m benchmark.lat-check --model efficientnet_b1 --path ./models/efficientnet_b1-Flowers102-1.pth
Arguments:
--model: Which model is being tested for latency. This argument is used to select the appropriate preprocessing pipeline to use for the training and test set.
--path: Path to the .pth file of the model to test.
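A latency benchmark of this kind typically times one prediction per input and summarises the results. The sketch below shows that pattern using only the standard library. It is an assumption about the general approach, not the contents of benchmark/lat-check: measure_latency, fake_model, and the warm-up count are hypothetical names and values.

```python
import statistics
import time
from typing import Callable, Dict, Sequence


def measure_latency(predict: Callable[[object], object],
                    inputs: Sequence[object],
                    warmup: int = 3) -> Dict[str, float]:
    """Time one call of `predict` per input and summarise in milliseconds."""
    # Warm-up calls are excluded so one-off start-up costs do not skew results.
    for item in inputs[:warmup]:
        predict(item)
    timings_ms = []
    for item in inputs:
        start = time.perf_counter()
        predict(item)
        timings_ms.append((time.perf_counter() - start) * 1000.0)
    return {
        "mean_ms": statistics.mean(timings_ms),
        "median_ms": statistics.median(timings_ms),
        "max_ms": max(timings_ms),
    }


# Stand-in for a model forward pass (hypothetical); replace with real inference.
def fake_model(x):
    return sum(i * i for i in range(10_000))


stats = measure_latency(fake_model, inputs=list(range(20)))
print({key: round(value, 3) for key, value in stats.items()})
```

If inference runs on a GPU, execution is usually asynchronous, so a synchronization call is generally needed before reading the timer.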
To benchmark accuracy, execute the following script:
python3 -m benchmark.acc-check --model mobilenet_v3_small --dataset Flowers102 --batch-size 8 --path ./models/mobilenet_v3_small-Flowers102-1.pth
Arguments:
--model: Which model is being tested for accuracy. This argument is used to select the appropriate preprocessing pipeline to use for the training and test set.
--dataset: Which dataset to use.
--batch-size: How many images per batch.
--path: Path to the .pth file of the model to test.
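The README does not state which accuracy metric is reported. Assuming it is top-1 accuracy accumulated over batches, the calculation can be sketched as below. The top1_accuracy function and the toy scores are hypothetical, and plain lists stand in for model outputs.

```python
from typing import Iterable, List, Sequence, Tuple

# One batch: (per-class scores for each sample, true label for each sample).
Batch = Tuple[List[Sequence[float]], List[int]]


def top1_accuracy(batches: Iterable[Batch]) -> float:
    """Fraction of samples whose highest-scoring class equals the true label."""
    correct = 0
    total = 0
    for scores, labels in batches:
        for row, label in zip(scores, labels):
            # Index of the largest score is the predicted class.
            predicted = max(range(len(row)), key=row.__getitem__)
            correct += int(predicted == label)
            total += 1
    if total == 0:
        raise ValueError("no samples to evaluate")
    return correct / total


# Two toy batches of hypothetical scores for a 3-class problem.
batches = [
    ([[0.1, 0.7, 0.2], [0.8, 0.1, 0.1]], [1, 0]),
    ([[0.3, 0.3, 0.4], [0.6, 0.3, 0.1]], [2, 1]),
]
print(f"top-1 accuracy: {top1_accuracy(batches):.2%}")
# top-1 accuracy: 75.00%
```

Counting correct predictions and samples across all batches, rather than averaging per-batch accuracies, keeps the result correct when the last batch is smaller than the others.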