# Distillation + Pruning + Quantization

The following code snippets allow you to distil a `microsoft/xtremedistil-l6-h384-uncased` teacher into a `microsoft/xtremedistil-l6-h256-uncased` student.

At the end of training, the student is [dynamically quantized](https://pytorch.org/docs/stable/generated/torch.ao.quantization.quantize_dynamic.html#torch.ao.quantization.quantize_dynamic) and some of its weights are pruned based on their magnitude.

If you do not want to perform quantization and/or pruning, simply remove the corresponding callback from the configuration.
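To make the two compression steps concrete, here is a minimal pure-Python sketch of magnitude-based pruning followed by int8 quantization on a toy weight vector. It is not the `bert_squeeze` or PyTorch implementation: the helper names (`prune_by_magnitude`, `quantize_int8`, `dequantize`) are hypothetical, and both the "zero every weight whose absolute value is below the threshold" rule and the symmetric per-tensor scaling scheme are assumptions chosen for illustration.

```python
def prune_by_magnitude(weights, threshold):
    """Zero every weight whose absolute value is below `threshold` (assumed rule)."""
    return [0.0 if abs(w) < threshold else w for w in weights]


def quantize_int8(weights):
    """Symmetric per-tensor int8 quantization (assumed scheme).

    Returns the integer values and the scale needed to map them back to floats.
    """
    max_abs = max(abs(w) for w in weights)
    scale = max_abs / 127 if max_abs > 0 else 1.0
    quantized = [max(-127, min(127, round(w / scale))) for w in weights]
    return quantized, scale


def dequantize(quantized, scale):
    """Map int8 values back to approximate float weights."""
    return [q * scale for q in quantized]


# Toy weights, not taken from any real model.
weights = [0.05, -0.31, 0.18, 0.92, -0.12, 0.40]

pruned = prune_by_magnitude(weights, threshold=0.2)
quantized, scale = quantize_int8(pruned)
restored = dequantize(quantized, scale)

# Fraction of weights set to zero by pruning.
sparsity = pruned.count(0.0) / len(pruned)
```

With the toy values above, half of the weights fall below the threshold and are zeroed, and each restored weight differs from its pruned value by at most half a quantization step.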

In [1]:
from bert_squeeze.assistants import DistilAssistant
from lightning.pytorch import Trainer

In [2]:
# We use xtremedistil checkpoints because they are lightweight; feel free
# to swap in any other base model.
config_assistant = {
    "name": "distil",
    "teacher_kwargs": {
        "pretrained_model": "microsoft/xtremedistil-l6-h384-uncased",
        "num_labels": 2
    },
    "student_kwargs": {
        "pretrained_model": "microsoft/xtremedistil-l6-h256-uncased",
        "num_labels": 2
    },
    "data_kwargs": {
        "teacher_module": {
            "dataset_config": {
                "path": "linxinyuan/cola",
            }
        }
    },
    "callbacks": [
        {
            "_target_": "bert_squeeze.utils.callbacks.pruning.ThresholdBasedPruning",
            "threshold": 0.2,
            "start_pruning_epoch": -1
        },
        {
            "_target_": "bert_squeeze.utils.callbacks.quantization.DynamicQuantization"
        }
    ]
}
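
As noted in the introduction, quantization or pruning can be skipped by removing the corresponding callback entry. A minimal sketch of doing this programmatically is shown below; `without_callback` is a hypothetical helper (not part of `bert_squeeze`), and the list simply mirrors the `"callbacks"` entries of the configuration.

```python
# Mirror of the "callbacks" entries used in the configuration.
callbacks_config = [
    {
        "_target_": "bert_squeeze.utils.callbacks.pruning.ThresholdBasedPruning",
        "threshold": 0.2,
        "start_pruning_epoch": -1,
    },
    {
        "_target_": "bert_squeeze.utils.callbacks.quantization.DynamicQuantization",
    },
]


def without_callback(callbacks, class_name):
    """Return a new list without the entries whose `_target_` ends with `class_name`.

    Hypothetical helper for illustration only.
    """
    return [cb for cb in callbacks if not cb["_target_"].endswith(class_name)]


# Keep pruning, skip quantization.
pruning_only = without_callback(callbacks_config, "DynamicQuantization")

# Keep quantization, skip pruning.
quantization_only = without_callback(callbacks_config, "ThresholdBasedPruning")
```

The same filtering can be applied to `config_assistant["callbacks"]` before the assistant is instantiated.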

In [3]:
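# Instantiate the assistant from the configuration defined above.
assistant = DistilAssistant(**config_assistant)

model = assistant.model
callbacks = assistant.callbacks
train_dataloader = assistant.data.train_dataloader()
test_dataloader = assistant.data.test_dataloader()

# `max_steps=2` keeps this demonstration short; raise it for a real training run.
basic_trainer = Trainer(
    max_steps=2,
    callbacks=callbacks
)

# The test split is reused as the validation set here.
basic_trainer.fit(
    model=model,
    train_dataloaders=train_dataloader,
    val_dataloaders=test_dataloader
)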



INFO:root:Dataset 'linxinyuan/cola' successfully loaded.


INFO:root:Dataset 'linxinyuan/cola' successfully loaded.


GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name         | Type               | Params
----------------------------------------------------
0 | teacher      | LtCustomBert       | 22.9 M
1 | student      | LtCustomBert       | 12.8 M
2 | loss_ce      | LabelSmoothingLoss | 0     
3 | loss_distill | MSELoss            | 0     
----------------------------------------------------
35.7 M    Trainable params
0         Non-trainable params
35.7 M    Total params
142.718   Total estimated model params size (MB)


`Trainer.fit` stopped: `max_steps=2` reached.


Pruning model...
INFO:root:Model quantized and saved - size (MB): 142.818233
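
The model summary above lists two loss modules, `loss_ce` (a `LabelSmoothingLoss`) and `loss_distill` (an `MSELoss`). A distillation objective built from two such terms could be sketched as follows. This is an illustration under stated assumptions, not the library's code: plain cross-entropy stands in for the label-smoothed loss, the MSE is taken between student and teacher logits, and the `alpha` weighting is hypothetical.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(logits, label):
    """Hard-label loss (plain cross-entropy, standing in for label smoothing)."""
    return -math.log(softmax(logits)[label])


def mse(a, b):
    """Mean squared error between two equally sized lists."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def distillation_loss(student_logits, teacher_logits, label, alpha=0.5):
    """Weighted sum of the hard-label term and the teacher-matching term.

    `alpha` is a hypothetical mixing weight chosen for this sketch.
    """
    hard = cross_entropy(student_logits, label)
    soft = mse(student_logits, teacher_logits)
    return alpha * hard + (1 - alpha) * soft


# Toy logits for a two-class problem (num_labels = 2).
loss = distillation_loss(
    student_logits=[0.2, 1.1],
    teacher_logits=[0.0, 1.5],
    label=1,
)
```

With `alpha=1.0` only the hard-label term remains, and with `alpha=0.0` the student is trained purely to match the teacher's logits.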
