
Core Concepts


Variables

  • Model_T: an instance of torch.nn.Module, the teacher model, which usually has more parameters than the student model.

  • Model_S: an instance of torch.nn.Module, the student model, usually smaller than the teacher model for the purpose of model compression and faster inference speed.

  • optimizer: an instance of torch.optim.Optimizer.

  • scheduler: an instance of torch.optim.lr_scheduler, which allows flexible adjustment of the learning rate.

  • dataloader: a data iterator used to generate data batches. A batch can be a tuple or a dict. Inside the distiller, it is consumed roughly as follows:

  for batch in dataloader:
    if batch_postprocessor is not None:
      batch = batch_postprocessor(batch)
    # check the batch datatype (dict or tuple)
    # pass the batch to the model and the adaptor

Note:

  1. During training, the distiller checks whether the batch is a dict; if so, the model is called as model(**batch, **args), otherwise it is called as model(*batch, **args). Hence, if the batch is not a dict, users should make sure the order of the elements in the batch matches the order of the arguments of model.forward. args is used to pass additional parameters.
  2. Users can define a batch_postprocessor function to post-process batches if needed. batch_postprocessor should take a batch and return a batch; a sketch is given below. See the explanation of the train method of Distillers for more details.
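
For example, a minimal batch_postprocessor that moves a dict batch to the GPU might look like the sketch below. The tensor handling is an assumption about a typical setup; TextBrewer itself does not require any particular keys here.

  import torch

  def batch_postprocessor(batch):
    # Illustrative sketch: move every tensor in a dict batch to the GPU.
    # Assumes the batch is a dict whose values are (mostly) tensors.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    return {k: (v.to(device) if isinstance(v, torch.Tensor) else v)
            for k, v in batch.items()}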

Config and Distiller

Configurations

  • TrainingConfig: configuration related to general deep learning model training.
  • DistillationConfig: configuration related to distillation methods.
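
For example, the two configurations might be constructed as in the sketch below. The parameter values are illustrative assumptions; see the TrainingConfig and DistillationConfig documentation for the full argument lists.

  from textbrewer import TrainingConfig, DistillationConfig

  # Illustrative values only; consult the documentation for all options.
  train_config = TrainingConfig(device='cuda')
  distill_config = DistillationConfig(
    temperature=8,            # softmax temperature for the KD loss
    hard_label_weight=0,      # weight of the 'losses' returned by the adaptor
    kd_loss_type='ce',        # cross-entropy between teacher and student logits
    intermediate_matches=[    # optional intermediate feature matching
      {'layer_T': 8, 'layer_S': 2, 'feature': 'hidden',
       'loss': 'hidden_mse', 'weight': 1}])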

Distillers

Distillers are in charge of conducting the actual experiments. The following distillers are available:

  • BasicDistiller: single-teacher single-task distillation, provides basic distillation strategies.
  • GeneralDistiller (Recommended): single-teacher single-task distillation, supports intermediate feature matching. Recommended in most cases; see the sketch after this list.
  • MultiTeacherDistiller: multi-teacher distillation, which distills multiple teacher models (of the same task) into a single student. This class does not support intermediate feature matching.
  • MultiTaskDistiller: multi-task distillation, which distills multiple teacher models (of different tasks) into a single student. This class does not support intermediate feature matching.
  • BasicTrainer: supervised training of a single model on a labeled dataset; not for distillation. It can be used to train a teacher model.
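
As an illustration, a GeneralDistiller is typically constructed and run roughly as in the sketch below. It assumes the configurations from the previous sketch and the user-defined adaptor and callback described in the next section; the exact keyword names of train may differ between TextBrewer versions.

  from textbrewer import GeneralDistiller

  distiller = GeneralDistiller(
    train_config=train_config, distill_config=distill_config,
    model_T=model_T, model_S=model_S,
    adaptor_T=adaptor, adaptor_S=adaptor)

  distiller.train(optimizer=optimizer, scheduler=scheduler,
                  dataloader=dataloader, num_epochs=10,
                  callback=callback, batch_postprocessor=batch_postprocessor)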

User-Defined Functions

In TextBrewer, there are two functions that should be implemented by users: callback and adaptor.

Callback

Optional; can be None. At each checkpoint, after saving the model, the distiller calls the callback function with arguments model=model_S, step=global_step. The callback can be used to evaluate the performance of the student model at each checkpoint. If you evaluate in the callback, remember to call model.eval() inside the callback.

The signature is

callback(model: torch.nn.Module, step: int) -> Any
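
For example, a minimal evaluation callback could look like the sketch below; evaluate and eval_dataloader are hypothetical user-side names, not part of TextBrewer.

  import torch

  def callback(model, step):
    # `model` is the student (model_S); `step` is the global training step.
    model.eval()
    with torch.no_grad():
      accuracy = evaluate(model, eval_dataloader)  # hypothetical helper
    print(f"step {step}: dev accuracy = {accuracy:.4f}")
    model.train()  # restore training mode before distillation continues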

Adaptor

The adaptor converts the model inputs and outputs to the specified format so that they can be recognized by the distiller and the distillation loss can be computed. At each training step, the batch and the model outputs are passed to the adaptor; the adaptor reorganizes the data and returns a dict.

adaptor(batch: Union[Dict,Tuple], model_outputs: Tuple) -> Dict

The functionality of the adaptor is shown in the figure below:

The available keys of the returned dict and their values are:

  • 'logits' : List[torch.Tensor] or torch.Tensor :

    The inputs to the final softmax. Each tensor should have the shape (batch_size, num_labels) or (batch_size, length, num_labels).

  • 'logits_mask': List[torch.Tensor] or torch.Tensor:

    0/1 matrix, which masks logits at specified positions. The positions where mask==0 won't be included in the calculation of loss on logits. Each tensor should have the shape (batch_size, length).

  • 'labels': List[torch.Tensor] or torch.Tensor:

    Ground-truth labels of the examples. Each tensor should have the shape (batch_size,) or (batch_size, length).

    Note:

    • logits_mask only works for logits with shape (batch_size, length, num_labels). It's used to mask in the length dimension, commonly used in sequence labeling tasks.

    • logits, logits_mask and labels should either all be lists of tensors, or all be tensors.

  • 'losses' : List[torch.Tensor] :

    It stores pre-computed losses, for example, the cross-entropy between logits and ground-truth labels. All the losses stored here will be summed, weighted by hard_label_weight, and added to the total loss. Each tensor in the list should be a scalar, i.e., of shape [].

  • 'attention': List[torch.Tensor] :

    List of attention matrices, used to compute intermediate feature matching. Each tensor should have the shape (batch_size, num_heads, length, length) or (batch_size, length, length), depending on what attention loss is used. Details about various loss functions can be found at Intermediate Loss.

  • 'hidden': List[torch.Tensor] :

    List of hidden states used to compute intermediate feature matching. Each tensor should have the shape (batch_size, length, hidden_dim).

  • 'inputs_mask' : torch.Tensor :

    0/1 matrix that masks 'attention' and 'hidden'; it should have the shape (batch_size, length).

These keys are all optional:

  • If there is no 'inputs_mask' or 'logits_mask', then it's considered as no masking, or equivalent to using a mask with all elements equal to 1.
  • If not using intermediate feature matching, you can ignore 'attention' and 'hidden'.
  • If you don't want to add loss of the original hard labels, you can set hard_label_weight=0, and ignore 'losses'.
  • If 'logits' is not provided, the KD loss of the logits will be omitted.
  • 'labels' is required if and only if probability_shift==True.
  • You shouldn't ignore all the keys, otherwise the training won't start :)

Usually 'logits' should be provided, unless you are doing multi-stage training.
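
As a concrete illustration, below is a sketch of an adaptor for a classification task. It assumes the batch is a dict containing 'labels' and 'attention_mask', and that the model returns (logits, hidden_states, attentions), e.g. a Hugging Face BERT-style model called with output_hidden_states=True and output_attentions=True; adapt the unpacking to your own model's outputs.

  def simple_adaptor(batch, model_outputs):
    # Assumed output order: (logits, hidden_states, attentions).
    logits, hidden_states, attentions = model_outputs
    return {
      'logits':      logits,                   # (batch_size, num_labels)
      'hidden':      list(hidden_states),      # each (batch_size, length, hidden_dim)
      'attention':   list(attentions),         # each (batch_size, num_heads, length, length)
      'inputs_mask': batch['attention_mask'],  # (batch_size, length)
      'labels':      batch['labels'],          # (batch_size,)
    }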