## Pruning your model

Pruning is a model compression technique in which we discard network connections that contribute little to network performance, with minimal impact on inference accuracy. In a neural network, weights very close to zero contribute very little to the model’s inference. Performing convolution with such weights is nearly equivalent to multiplying by zero. Therefore, removing such weights can eliminate redundant operations, in turn providing higher throughput and a lower memory footprint during both the training and the inference phases of a neural network.

We support two different pruning techniques: (a) unstructured pruning, and (b) structured pruning. Unstructured (or element) pruning is a fine-grained approach that prunes individual weights in a neural network without applying any structural constraints. In structured pruning, we remove a larger set of weights at once while maintaining a dense structure of the model.

Download a model to use as the teacher:
``` bash
$ cd rubicon/models
$ bash download_teacher.sh # we use Bonito_CTC from ONT as the teacher
```

In [1]:
import os
from statistics import mode
import sys
from argparse import ArgumentParser
from argparse import ArgumentDefaultsHelpFormatter
from pathlib import Path
from importlib import import_module

from rubicon.data import load_numpy_shuf,load_numpy_full 
from bonito.data import load_script
from bonito.util import load_symbol, init,load_model
from rubicon.util import load_model_prune,load_model_prune_for_basecall


from rubicon.util import __models__, default_config, default_data

from rubicon.training import load_state, Trainer
from rubicon.prunekdtraining import load_state, PruneKDTrainer
import subprocess
from prettytable import PrettyTable
import toml
import torch
from torch import nn
import numpy as np
from torch.utils.data import DataLoader
from torch.nn.utils import prune
import brevitas.nn as qnn

import shutil
import logging
_logger = logging.getLogger(__name__)

No ROCm runtime is found, using ROCM_HOME='/opt/rocm-5.1.0'


In [2]:
from rubicon.basemodule.prune import measure_module_sparsity,measure_global_sparsity,prune_model_unstructured,prune_model_structured,remove_prune_mask,count_parameters

In [3]:
# Run configuration for the pruning cell below. Every name here is read as a
# plain notebook-level variable by the following cell.

# Knowledge-distillation settings, forwarded to PruneKDTrainer.
temp = 1
alpha = 0.1

# Pruning settings.
structural = False       # True -> prune_model_structured, False -> prune_model_unstructured
l1_scoring = True        # forwarded to both pruning helpers
prune_dim = 0            # only forwarded to prune_model_structured
prune_proportion = 0.1   # must be non-zero, otherwise the pruning cell aborts
# Read by the pruning cell to decide whether trainer.fit() is called with
# remove_mask=True. It was previously undefined, which raised NameError.
remove_mask = False

# Output location.
save_directory = "pruned_results"
workdir = os.path.expanduser(save_directory)
onnx_name = str(prune_proportion)  # not used by the visible cells
force = True             # True -> an existing workdir is deleted before the run

# Model selection.
pretrained = ""          # non-empty -> load this pretrained model instead of building from config
teacher = True           # True -> load a teacher model and train with PruneKDTrainer
teacher_directory = "../rubicon/models/bonito"
seed = 25
config = "../rubicon/models/configs/config.toml"
device = "cuda"
quant = False            # True -> prune qnn.QuantConv1d layers, False -> nn.Conv1d layers
# NOTE(review): `type` shadows the Python builtin; it is kept because the
# pruning cell reads this exact name.
type = "bonito"          # one of "rubiconqabas", "bonito", "bonitostaticquant"

# Training data.
# The pruning cell passes `directory` to the data loaders but it was never
# defined. NOTE(review): default_data (imported from rubicon.util) is
# presumably the default training-data location -- confirm, or point this at
# your own dataset directory.
directory = default_data
full = False             # True -> load_numpy_full, otherwise the chunk-limited loader
chunks = 128
valid_chunks = 128
batch = 128              # DataLoader batch_size

# Optimisation settings, forwarded to the trainer and trainer.fit().
restore_optim = False
save_optim_every = 10
grad_accum_split = 1
epochs = 10
lr = 2e-3

In [None]:
# Prepare the output directory, initialise the run, and load the model
# configuration (either next to a pretrained model or from `config`).
_logger.info("Save path: {}".format(workdir))

if os.path.exists(workdir):
    if force:
        # Discard the results of any previous run.
        shutil.rmtree(workdir)
    else:
        print("[error] %s exists, remove it to continue or enable force variable delete." % workdir)
        # sys.exit raises SystemExit and stops execution right here; the
        # `exit` helper is meant for interactive shells only.
        sys.exit(1)

if not prune_proportion:
    _logger.warning("Specify sparsity using args --prune")
    sys.exit(1)

os.makedirs(workdir, exist_ok=True)
init(seed, device)
# NOTE: `device` is rebound from its string name to a torch.device here, so
# the configuration cell must be re-run before this cell is executed again.
device = torch.device(device)

if pretrained:
    dirname = pretrained
    # Fall back to the bundled models directory when `pretrained` is not a
    # directory on its own.
    if not os.path.isdir(dirname) and os.path.isdir(os.path.join(__models__, dirname)):
        dirname = os.path.join(__models__, dirname)
    config_file = os.path.join(dirname, 'config.toml')
else:
    config_file = config

# NOTE: `config` is rebound from a file path to the parsed TOML dictionary.
config = toml.load(config_file)

# Load the teacher model used for knowledge distillation and dump its
# parameter table to <workdir>/model_teacher.txt.
if teacher:
    teacher_path = os.path.expanduser(teacher_directory)
    _logger.info("Loading Teacher Model:{}".format(teacher_path))
    if not os.path.exists(teacher_path):
        _logger.warning("Teacher model %s does not exists" % teacher_path)
        # sys.exit raises SystemExit and stops execution right here; the
        # `exit` helper is meant for interactive shells only.
        sys.exit(1)
    model_teacher = load_model(teacher_path, device, half=False)

    _logger.info("Total parameters in model teacher:{}".format(sum(p.numel() for p in model_teacher.parameters())) )
    original_stdout = sys.stdout  # saved so it can be restored afterwards
    with open(workdir+'/model_teacher.txt', 'w') as f:
        sys.stdout = f  # anything count_parameters prints goes to the file
        try:
            count_parameters(model_teacher)
        finally:
            # Restore stdout even if count_parameters raises, so it never
            # stays pointed at a closed file.
            sys.stdout = original_stdout



# Build (or load) the student model that will be pruned.
_logger.info("Loading model for pruning")

# Layer class handed to the pruning and sparsity helpers: the quantised
# convolution when `quant` is set, the plain one otherwise.
module_check = qnn.QuantConv1d if quant else nn.Conv1d

# NOTE(review): `type` is the notebook's model-type string and shadows the
# Python builtin of the same name.
if pretrained:
    _logger.info("[using pretrained model {}]".format(pretrained))
    model_student = load_model_prune(pretrained, device, half=False,load_model_type=type)
elif type == "rubiconqabas":
    _logger.info("Training Rubicon QABAS model")
    model_student = load_symbol(config, 'RubiconQabas')(config)
elif type == "bonito":
    _logger.info("ONT Bonito model")
    model_student = load_symbol(config, 'Model')(config)
elif type == "bonitostaticquant":
    _logger.info("ONT Bonito model for static quantization")
    model_student = load_symbol(config, 'BonitoStaticQuant')(config)
else:
    _logger.warning("Please define a model or choose a model using --type")
    # sys.exit raises SystemExit and stops execution right here; the `exit`
    # helper is meant for interactive shells only.
    sys.exit(1)
    

# Resolve the optional learning-rate scheduler named in the configuration.
# Without an "lr_scheduler" section, lr_scheduler_fn is None.
if not config.get("lr_scheduler"):
    lr_scheduler_fn = None
else:
    sched_config = config["lr_scheduler"]
    # "package" and "symbol" name the factory to import; the whole section
    # (including those two keys) is then passed to it as keyword arguments.
    scheduler_package = import_module(sched_config["package"])
    scheduler_factory = getattr(scheduler_package, sched_config["symbol"])
    lr_scheduler_fn = scheduler_factory(**sched_config)

# Prune the student model, fine-tune it (optionally with knowledge
# distillation from the teacher), and strip the pruning masks afterwards.
#
# Files written to workdir by this cell:
#   model_student.txt  - everything printed to stdout while the cell runs
#   model_student.h5   - state dict of the student before pruning
#   model_prune.h5     - state dict right after pruning, before fine-tuning
original_stdout = sys.stdout  # saved so it can be restored afterwards
with open(workdir+'/model_student.txt', 'w') as f:
    # NOTE(review): stdout stays redirected for the whole run, including
    # training, not just for the parameter table -- confirm this is intended.
    sys.stdout = f
    try:
        count_parameters(model_student)
        _logger.info("Total parameters in BASE model=%s"%sum(p.numel() for p in model_student.parameters()))
        torch.save(model_student.state_dict(), workdir+"/model_student.h5")

        if structural:
            _logger.info("Structural pruning")
            prune_model = prune_model_structured(model_student, module_check, prune_proportion, l1_scoring, prune_dim)
        else:
            _logger.info("Unstructural pruning")
            prune_model = prune_model_unstructured(model_student, module_check, prune_proportion, l1_scoring)

        torch.save(prune_model.state_dict(), workdir+"/model_prune.h5")

        _logger.info("***Measuring sparsity***")
        num_zeros, num_elements, sparsity = measure_global_sparsity(
            prune_model, module_check,
            weight=True,
            bias=False,
            conv1d_use_mask=True)
        _logger.info("Global Sparsity:")
        _logger.info("{:.2f}".format(sparsity))

        # `directory` is the notebook-level training-data location; the
        # original read it from an undefined `args` object in the full branch.
        if full:
            _logger.info("Full dataset training")
            train_loader_kwargs, valid_loader_kwargs = load_numpy_full(None, directory)
        elif chunks:
            _logger.info("Not full dataset training with shuffling")
            # load_numpy_shuf is the loader this notebook imports; the
            # previously called `load_numpy` is never imported.
            train_loader_kwargs, valid_loader_kwargs = load_numpy_shuf(
                chunks, valid_chunks, directory
            )
        else:
            _logger.warning("Please define the training data correctly")
            # sys.exit raises SystemExit and stops execution right here.
            sys.exit(1)

        loader_kwargs = {
            "batch_size": batch, "num_workers": 2, "pin_memory": True
        }
        train_loader = DataLoader(**loader_kwargs, **train_loader_kwargs)
        valid_loader = DataLoader(**loader_kwargs, **valid_loader_kwargs)

        if teacher:
            _logger.info("Pruning with Knowledge Distillation")
            trainer = PruneKDTrainer(
                model_teacher,
                prune_model, device, train_loader, valid_loader,
                use_amp=False,
                lr_scheduler_fn=lr_scheduler_fn,
                restore_optim=restore_optim,
                save_optim_every=save_optim_every,
                grad_accum_split=grad_accum_split,
                temp=temp,
                alpha=alpha
            )
        else:
            _logger.info("Training without Knowledge Distillation")
            trainer = Trainer(
                prune_model, device, train_loader, valid_loader,
                use_amp=False,
                lr_scheduler_fn=lr_scheduler_fn,
                restore_optim=restore_optim,
                save_optim_every=save_optim_every,
                grad_accum_split=grad_accum_split
            )

        # `remove_mask` is a notebook-level flag that has to be set in the
        # configuration cell.
        if remove_mask:
            _logger.warning("Removing mask while storing weights")
            trainer.fit(workdir, epochs, lr, remove_mask=True, quant=quant, load_model_type=type)
        else:
            trainer.fit(workdir, epochs, lr, load_model_type=type)
            _logger.info("Measure sparsity after training")
            num_zeros, num_elements, sparsity = measure_global_sparsity(
                prune_model, module_check,
                weight=True,
                bias=False,
                conv1d_use_mask=True)

        # In the remove_mask branch `sparsity` still holds the value measured
        # before training.
        _logger.info("Global Sparsity:")
        _logger.info("{:.2f}".format(sparsity))
        _logger.info("Removing mask after training")
        # Use the same layer class the masks were attached to. The original
        # always passed qnn.QuantConv1d, which differs from the nn.Conv1d
        # layers selected when quant is False.
        prune_model = remove_prune_mask(model_student, module_check)
    finally:
        # Restore stdout even when pruning or training fails, so it never
        # stays pointed at a closed file.
        sys.stdout = original_stdout

11/07/2023 08:31:27 AM [INFO] Save path: pruned_results


[2023-11-07 08:31:27] INFO (__main__/MainThread) Save path: pruned_results
[2023-11-07 08:31:27] INFO (__main__/MainThread) Save path: pruned_results


11/07/2023 08:31:27 AM [INFO] Loading Teacher Model:../rubicon/models/bonito


[2023-11-07 08:31:27] INFO (__main__/MainThread) Loading Teacher Model:../rubicon/models/bonito
[2023-11-07 08:31:27] INFO (__main__/MainThread) Loading Teacher Model:../rubicon/models/bonito


11/07/2023 08:31:28 AM [INFO] Total parameters in model teacher:9738573


[2023-11-07 08:31:28] INFO (__main__/MainThread) Total parameters in model teacher:9738573
[2023-11-07 08:31:28] INFO (__main__/MainThread) Total parameters in model teacher:9738573


11/07/2023 08:31:28 AM [INFO] +--------------------------------------------+------------+
|                  Modules                   | Parameters |
+--------------------------------------------+------------+
|    encoder.encoder.0.conv.0.conv.weight    |    3096    |
|      encoder.encoder.0.conv.1.weight       |    344     |
|       encoder.encoder.0.conv.1.bias        |    344     |
| encoder.encoder.1.conv.0.depthwise.weight  |   316480   |
| encoder.encoder.1.conv.0.pointwise.weight  |   145856   |
|      encoder.encoder.1.conv.1.weight       |    424     |
|       encoder.encoder.1.conv.1.bias        |    424     |
| encoder.encoder.1.conv.4.depthwise.weight  |   390080   |
| encoder.encoder.1.conv.4.pointwise.weight  |   179776   |
|      encoder.encoder.1.conv.5.weight       |    424     |
|       encoder.encoder.1.conv.5.bias        |    424     |
|  encoder.encoder.1.residual.0.conv.weight  |   145856   |
|    encoder.encoder.1.residual.1.weight     |    424     |
|     enco

[2023-11-07 08:31:28] INFO (rubicon.basemodule.prune/MainThread) +--------------------------------------------+------------+
|                  Modules                   | Parameters |
+--------------------------------------------+------------+
|    encoder.encoder.0.conv.0.conv.weight    |    3096    |
|      encoder.encoder.0.conv.1.weight       |    344     |
|       encoder.encoder.0.conv.1.bias        |    344     |
| encoder.encoder.1.conv.0.depthwise.weight  |   316480   |
| encoder.encoder.1.conv.0.pointwise.weight  |   145856   |
|      encoder.encoder.1.conv.1.weight       |    424     |
|       encoder.encoder.1.conv.1.bias        |    424     |
| encoder.encoder.1.conv.4.depthwise.weight  |   390080   |
| encoder.encoder.1.conv.4.pointwise.weight  |   179776   |
|      encoder.encoder.1.conv.5.weight       |    424     |
|       encoder.encoder.1.conv.5.bias        |    424     |
|  encoder.encoder.1.residual.0.conv.weight  |   145856   |
|    encoder.encoder.1.residual.1.w

11/07/2023 08:31:28 AM [INFO] Total Trainable Params: 9738573


[2023-11-07 08:31:28] INFO (rubicon.basemodule.prune/MainThread) Total Trainable Params: 9738573
[2023-11-07 08:31:28] INFO (rubicon.basemodule.prune/MainThread) Total Trainable Params: 9738573


11/07/2023 08:31:28 AM [INFO] Loading model for pruning


[2023-11-07 08:31:28] INFO (__main__/MainThread) Loading model for pruning
[2023-11-07 08:31:28] INFO (__main__/MainThread) Loading model for pruning


11/07/2023 08:31:28 AM [INFO] ONT Bonito model


[2023-11-07 08:31:28] INFO (__main__/MainThread) ONT Bonito model
[2023-11-07 08:31:28] INFO (__main__/MainThread) ONT Bonito model


11/07/2023 08:31:28 AM [INFO] +--------------------------------------------+------------+
|                  Modules                   | Parameters |
+--------------------------------------------+------------+
|    encoder.encoder.0.conv.0.conv.weight    |    3096    |
|      encoder.encoder.0.conv.1.weight       |    344     |
|       encoder.encoder.0.conv.1.bias        |    344     |
| encoder.encoder.1.conv.0.depthwise.weight  |   316480   |
| encoder.encoder.1.conv.0.pointwise.weight  |   145856   |
|      encoder.encoder.1.conv.1.weight       |    424     |
|       encoder.encoder.1.conv.1.bias        |    424     |
| encoder.encoder.1.conv.4.depthwise.weight  |   390080   |
| encoder.encoder.1.conv.4.pointwise.weight  |   179776   |
|      encoder.encoder.1.conv.5.weight       |    424     |
|       encoder.encoder.1.conv.5.bias        |    424     |
|  encoder.encoder.1.residual.0.conv.weight  |   145856   |
|    encoder.encoder.1.residual.1.weight     |    424     |
|     enco

[2023-11-07 08:31:28] INFO (rubicon.basemodule.prune/MainThread) +--------------------------------------------+------------+
|                  Modules                   | Parameters |
+--------------------------------------------+------------+
|    encoder.encoder.0.conv.0.conv.weight    |    3096    |
|      encoder.encoder.0.conv.1.weight       |    344     |
|       encoder.encoder.0.conv.1.bias        |    344     |
| encoder.encoder.1.conv.0.depthwise.weight  |   316480   |
| encoder.encoder.1.conv.0.pointwise.weight  |   145856   |
|      encoder.encoder.1.conv.1.weight       |    424     |
|       encoder.encoder.1.conv.1.bias        |    424     |
| encoder.encoder.1.conv.4.depthwise.weight  |   390080   |
| encoder.encoder.1.conv.4.pointwise.weight  |   179776   |
|      encoder.encoder.1.conv.5.weight       |    424     |
|       encoder.encoder.1.conv.5.bias        |    424     |
|  encoder.encoder.1.residual.0.conv.weight  |   145856   |
|    encoder.encoder.1.residual.1.w

11/07/2023 08:31:28 AM [INFO] Total Trainable Params: 9738573


[2023-11-07 08:31:28] INFO (rubicon.basemodule.prune/MainThread) Total Trainable Params: 9738573
[2023-11-07 08:31:28] INFO (rubicon.basemodule.prune/MainThread) Total Trainable Params: 9738573


11/07/2023 08:31:28 AM [INFO] Total parameters in BASE model=9738573


[2023-11-07 08:31:28] INFO (__main__/MainThread) Total parameters in BASE model=9738573
[2023-11-07 08:31:28] INFO (__main__/MainThread) Total parameters in BASE model=9738573


11/07/2023 08:31:29 AM [INFO] Unstructural pruning


[2023-11-07 08:31:29] INFO (__main__/MainThread) Unstructural pruning
[2023-11-07 08:31:29] INFO (__main__/MainThread) Unstructural pruning
