- More flexible distillation: supports feeding different batches to the student and teacher, i.e., the batches for the student and teacher no longer need to be the same. This can be used for distilling models with different vocabularies (e.g., from RoBERTa to BERT). See the documentation for details.
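
  A minimal sketch of the data-preparation side, assuming a RoBERTa teacher and a BERT student (model names are example choices); how the paired batches are handed to the distiller is described in the documentation:

  ```python
  # Illustration only: build separate teacher/student batches when the two
  # models use different vocabularies.
  from transformers import AutoTokenizer

  teacher_tok = AutoTokenizer.from_pretrained("roberta-base")
  student_tok = AutoTokenizer.from_pretrained("bert-base-uncased")

  texts = ["TextBrewer distills knowledge.", "Different vocabularies are fine."]
  teacher_batch = teacher_tok(texts, padding=True, return_tensors="pt")
  student_batch = student_tok(texts, padding=True, return_tensors="pt")
  ```
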
- Faster distillation: users can pre-compute and cache the teacher outputs, then feed the cache to the distiller to save the time of the teacher's forward passes. See the documentation for details.
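
  A minimal sketch of pre-computing the teacher outputs once with `torch.no_grad()`; the model and dataloader below are toy placeholders, and how the cache is passed to the distiller follows the documentation:

  ```python
  import torch
  from torch import nn
  from torch.utils.data import DataLoader, TensorDataset

  # Toy stand-ins for a trained teacher and the training data.
  teacher_model = nn.Linear(16, 3)
  dataloader = DataLoader(TensorDataset(torch.randn(8, 16)), batch_size=4)

  cached_logits = []
  teacher_model.eval()
  with torch.no_grad():  # the teacher forward pass runs only once
      for (inputs,) in dataloader:
          cached_logits.append(teacher_model(inputs).cpu())
  ```
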
- `MultiTaskDistiller` is now a subclass of `GeneralDistiller` and supports the intermediate feature matching loss.
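
  A minimal sketch of an intermediate feature match configuration, following the `intermediate_matches` format shown in the documentation (layer indices and weights are arbitrary example values):

  ```python
  from textbrewer import DistillationConfig

  distill_config = DistillationConfig(
      intermediate_matches=[
          # match teacher layer 8 hidden states to student layer 2 hidden states
          {"layer_T": 8, "layer_S": 2, "feature": "hidden",
           "loss": "hidden_mse", "weight": 1},
      ]
  )
  ```
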
- TensorBoard now records more detailed losses (KD loss, hard label loss, matching losses, etc.).
- `pkd_loss` now accepts tensors of shape `(batch_size, length, hidden_size)` or `(batch_size, hidden_size)`. In the latter case, the loss is computed directly on the input tensors, without taking the hidden states at the first position.
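
  An illustration of the two accepted shapes (a simplified stand-in, not TextBrewer's implementation):

  ```python
  import torch
  import torch.nn.functional as F

  def pkd_style_loss(state_S, state_T):
      if state_S.dim() == 3:  # (batch_size, length, hidden_size)
          # take the hidden states at the first position
          state_S, state_T = state_S[:, 0], state_T[:, 0]
      state_S = F.normalize(state_S, p=2, dim=-1)  # (batch_size, hidden_size)
      state_T = F.normalize(state_T, p=2, dim=-1)
      return F.mse_loss(state_S, state_T)
  ```
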
- Now supports distributed data-parallel training with `torch.nn.parallel.DistributedDataParallel`! You can pass a properly configured `TrainingConfig` to set up distributed training. The detailed usage of `DistributedDataParallel` can be found in the PyTorch docs.
- We also added an example (Chinese NER task) to demonstrate how to use TextBrewer with distributed data-parallel training.
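
  A minimal setup sketch, assuming the script is launched with `torchrun` (or `torch.distributed.launch`) and that `TrainingConfig` exposes a `local_rank` option as in the current documentation:

  ```python
  import os
  import torch
  from textbrewer import TrainingConfig

  torch.distributed.init_process_group(backend="nccl")
  local_rank = int(os.environ.get("LOCAL_RANK", 0))  # set by torchrun
  torch.cuda.set_device(local_rank)

  train_config = TrainingConfig(local_rank=local_rank)
  # Build the teacher/student models and construct the distiller as usual;
  # wrap the student with torch.nn.parallel.DistributedDataParallel if your
  # setup requires doing so manually.
  ```
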
- Now supports mixed precision training with Apex! Just set `fp16` to True in `TrainingConfig`. See the documentation of `TrainingConfig` for details.
- Added a `data_parallel` option to `TrainingConfig` to enable data parallel training within TextBrewer.
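
  A minimal sketch (option names as listed in the `TrainingConfig` documentation; Apex must be installed for mixed precision):

  ```python
  from textbrewer import TrainingConfig

  train_config = TrainingConfig(
      fp16=True,           # mixed precision training via Apex
      data_parallel=True,  # data parallel training inside TextBrewer
  )
  ```
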
- Added an option `is_caching_logits`. If `is_caching_logits` is True, the distiller will cache the batches and the output logits of the teacher model, so that those logits will only be calculated once. This speeds up the distillation process. The feature is only available for `MultiTeacherDistiller`. Be cautious about setting it to True on large datasets, since the batches and logits are kept in memory.
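
  A minimal sketch, assuming the option is set through `DistillationConfig` as in the current documentation:

  ```python
  from textbrewer import DistillationConfig

  distill_config = DistillationConfig(
      is_caching_logits=True,  # cache batches and teacher logits in memory
  )
  ```
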
- Added a new argument `max_grad_norm` to the `train` method. It sets the strength of gradient clipping. Default -1, i.e., no gradient clipping.
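
  A minimal sketch of the call; the distiller, optimizer, and dataloader are assumed to be constructed as usual, and 1.0 is just an example clipping value:

  ```python
  distiller.train(
      optimizer,
      dataloader,
      num_epochs=3,
      max_grad_norm=1.0,  # clip the gradient norm to 1.0; -1 (default) disables clipping
  )
  ```
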
- Added new arguments `scheduler_class` and `scheduler_args` to the `train` method. The old `scheduler` argument may cause convergence problems and is deprecated in favor of `scheduler_class` and `scheduler_args`. See the documentation for details.
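
  A minimal sketch using a scheduler from `transformers`; the distiller, optimizer, and dataloader are assumed to be constructed as usual, and the step counts are example values:

  ```python
  from transformers import get_linear_schedule_with_warmup

  distiller.train(
      optimizer,
      dataloader,
      num_epochs=3,
      scheduler_class=get_linear_schedule_with_warmup,
      scheduler_args={"num_warmup_steps": 100, "num_training_steps": 1000},
  )
  ```
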
- Updated `display_parameters`. Now it won't print the statistics directly to the screen.
- Fixed an incorrect call to `zero_grad()`.
- `TrainingConfig.log_dir` can be set to `None` to disable TensorBoard.
- Added an attribute `print_freq` to the distiller to control the frequency of logging.
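
  A minimal sketch covering this item and the previous one (the value 20 is only an example, and the distiller is assumed to be constructed already):

  ```python
  from textbrewer import TrainingConfig

  train_config = TrainingConfig(log_dir=None)  # disable TensorBoard logging
  distiller.print_freq = 20                    # example value controlling how often progress is logged
  ```
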
- Added a new argument `num_steps` to the `train` method of the distiller. If `num_steps` is specified, the distiller will ignore `num_epochs` and allow a dataloader of unknown size (i.e., one that has no `__len__`).
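
  A minimal sketch; the dataloader here may wrap an `IterableDataset` without `__len__`, and 10000 is just an example step budget:

  ```python
  distiller.train(
      optimizer,
      dataloader,       # may have no __len__
      num_steps=10000,  # stop after 10000 steps; num_epochs is ignored
  )
  ```
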
- Added a new argument `batch_postprocessor` to the `train` method of the distiller to allow post-processing of batches.
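
  A minimal sketch, assuming the argument is named `batch_postprocessor` as in the `train` method documentation; the function receives each batch and returns the (possibly modified) batch before it is fed to the models:

  ```python
  def batch_postprocessor(batch):
      # e.g., move tensors to the right device or drop unused fields
      return batch

  distiller.train(
      optimizer,
      dataloader,
      num_epochs=3,
      batch_postprocessor=batch_postprocessor,
  )
  ```
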