# Task-specific knowledge distillation for BERT using Transformers & Amazon SageMaker
### Text Classification Example using `BERT-Base` as Teacher and `BERT-Tiny` as Student

Welcome to our end-to-end task-specific knowledge distillation Text-Classification example using Transformers, PyTorch & Amazon SageMaker. Distillation is the process of training a small "student" to mimic a larger "teacher". In this example, we will use [BERT-base](https://huggingface.co/textattack/bert-base-uncased-SST-2) as Teacher and [BERT-Tiny](https://huggingface.co/google/bert_uncased_L-2_H-128_A-2) as Student. We will use [Text-Classification](https://huggingface.co/tasks/text-classification) as the task for task-specific knowledge distillation and the [Stanford Sentiment Treebank v2 (SST-2)](https://paperswithcode.com/dataset/sst) dataset for training.


There are two different types of knowledge distillation: task-agnostic knowledge distillation (right) and task-specific knowledge distillation (left). In this example we are going to use task-specific knowledge distillation.

![knowledge-distillation](./imgs/knowledge-distillation.png)
_Task-specific distillation (left) versus task-agnostic distillation (right). Figure from FastFormers by Y. Kim and H. Awadalla [arXiv:2010.13382]._


In task-specific knowledge distillation, a "second step of distillation" is used to "fine-tune" the model on a given dataset. This idea comes from the [DistilBERT paper](https://arxiv.org/pdf/1910.01108.pdf), which showed that a student distilled a second time during fine-tuning performed better than simply fine-tuning the distilled language model:

> We also studied whether we could add another step of distillation during the adaptation phase by fine-tuning DistilBERT on SQuAD using a BERT model previously fine-tuned on SQuAD as a teacher for an additional term in the loss (knowledge distillation). In this setting, there are thus two successive steps of distillation, one during the pre-training phase and one during the adaptation phase. In this case, we were able to reach interesting performances given the size of the model: 79.8 F1 and 70.4 EM, i.e. within 3 points of the full model.

If you are interested in these topics, you should definitely read:
* [DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter](https://arxiv.org/abs/1910.01108)
* [FastFormers: Highly Efficient Transformer Models for Natural Language Understanding](https://arxiv.org/abs/2010.13382)

The [FastFormers paper](https://arxiv.org/abs/2010.13382) in particular contains great research on what works and what doesn't when using knowledge distillation.

---

Huge thanks to [Lewis Tunstall](https://www.linkedin.com/in/lewis-tunstall/) and his great blog post [Weeknotes: Distilling distilled transformers](https://lewtun.github.io/blog/weeknotes/nlp/huggingface/transformers/2021/01/17/wknotes-distillation-and-generation.html#fn-1).


## Installation

```python
# %pip install sagemaker huggingface_hub
```

This example uses the [Hugging Face Hub](https://huggingface.co/models) as a remote model versioning service. To be able to push our model to the Hub, you need to register on [Hugging Face](https://huggingface.co/join).
If you already have an account, you can skip this step.
Once you have an account, we will use the `notebook_login` util from the `huggingface_hub` package to log into our account and store our token (access key) on disk.

```python
# from huggingface_hub import notebook_login

# notebook_login()
```

## Setup & Configuration

In this step we define the global configuration used across the whole end-to-end fine-tuning process: the SageMaker session, the IAM execution role and the default S3 bucket.

```python
import sagemaker

sess = sagemaker.Session()
# sagemaker session bucket -> used for uploading data, models and logs
# sagemaker will automatically create this bucket if it does not exist
sagemaker_session_bucket=None
if sagemaker_session_bucket is None and sess is not None:
    # set to default bucket if a bucket name is not given
    sagemaker_session_bucket = sess.default_bucket()

role = sagemaker.get_execution_role()
sess = sagemaker.Session(default_bucket=sagemaker_session_bucket)

print(f"sagemaker role arn: {role}")
print(f"sagemaker bucket: {sess.default_bucket()}")
print(f"sagemaker session region: {sess.boto_region_name}")
```

```
sagemaker role arn: arn:aws:iam::079002598131:role/service-role/AmazonSageMaker-ExecutionRole-20220804T150518
sagemaker bucket: sagemaker-us-east-1-079002598131
sagemaker session region: us-east-1
```


_Note: The execution role is only available when running a notebook within SageMaker (SageMaker Notebook Instances or Studio). If you run `get_execution_role` in a notebook not on SageMaker, expect a region error._

You can uncomment the cell below and provide an IAM role name with SageMaker permissions to set up your environment outside of SageMaker.

```python
# import sagemaker
# import boto3
# import os

# os.environ["AWS_DEFAULT_REGION"]="region"

# # This ROLE needs to exist with your associated AWS Credentials and needs permission for SageMaker
# ROLE_NAME='role-name-of-your-iam-role-with-right-permissions'

# iam_client = boto3.client('iam')
# role = iam_client.get_role(RoleName=ROLE_NAME)['Role']['Arn']
# sess = sagemaker.Session()

# print(f"sagemaker role arn: {role}")
# print(f"sagemaker bucket: {sess.default_bucket()}")
# print(f"sagemaker session region: {sess.boto_region_name}")
```

## `DistillationTrainer`


Normally, when fine-tuning a Transformer model with PyTorch, you should use the `Trainer` API. The [Trainer](https://huggingface.co/docs/transformers/v4.16.1/en/main_classes/trainer#transformers.Trainer) class provides an API for feature-complete training in PyTorch for most standard use cases.

In our example we cannot use the `Trainer` out of the box, since we need to pass in two models, the `Teacher` and the `Student`, and compute the loss from the outputs of both. But we can subclass the `Trainer` to create a `DistillationTrainer` that takes care of this and only overrides the [compute_loss](https://github.com/huggingface/transformers/blob/c4ad38e5ac69e6d96116f39df789a2369dd33c21/src/transformers/trainer.py#L1962) and `__init__` methods. In addition, we need to subclass `TrainingArguments` to include our distillation hyperparameters (`alpha` and `temperature`).

The [DistillationTrainer](https://github.com/philschmid/knowledge-distillation-transformers-pytorch-sagemaker/blob/e8d04240d0ebbd7bd0741d196e8902a69a34b414/scripts/train.py#L28) and [DistillationTrainingArguments](https://github.com/philschmid/knowledge-distillation-transformers-pytorch-sagemaker/blob/e8d04240d0ebbd7bd0741d196e8902a69a34b414/scripts/train.py#L21) are implemented directly in the [training script](./scripts/train.py):
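Written out, the objective that `compute_loss` in the class below optimizes is the following, where $z_s$ and $z_t$ are the student and teacher logits, $y$ is the label, $T$ is `temperature` and $\alpha$ is `alpha`. This is our reading of the code, not a formula stated in the training script:

$$
\mathcal{L} = \alpha \, \mathcal{L}_{\mathrm{CE}}(z_s, y) + (1 - \alpha) \, T^2 \, \mathrm{KL}\Big(\operatorname{softmax}(z_t / T) \,\Big\|\, \operatorname{softmax}(z_s / T)\Big)
$$

The KL term is averaged over the batch (`reduction="batchmean"`), and the $T^2$ factor is the customary scaling that keeps the gradient magnitude of the soft term roughly comparable as $T$ changes.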

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import Trainer, TrainingArguments


class DistillationTrainingArguments(TrainingArguments):
    def __init__(self, *args, alpha=0.5, temperature=2.0, **kwargs):
        super().__init__(*args, **kwargs)
        # weight of the student's hard-label loss (1 - alpha goes to the distillation loss)
        self.alpha = alpha
        # temperature used to soften teacher and student probabilities
        self.temperature = temperature


class DistillationTrainer(Trainer):
    def __init__(self, *args, teacher_model=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.teacher = teacher_model
        # place teacher on same device as student
        self._move_model_to_device(self.teacher, self.model.device)
        # the teacher is only used for inference: eval mode disables dropout
        self.teacher.eval()

    def compute_loss(self, model, inputs, return_outputs=False):
        # compute student output (includes the cross-entropy loss against the labels)
        outputs_student = model(**inputs)
        student_loss = outputs_student.loss
        # compute teacher output without tracking gradients
        with torch.no_grad():
            outputs_teacher = self.teacher(**inputs)

        # student and teacher must predict over the same label set
        assert outputs_student.logits.size() == outputs_teacher.logits.size()

        # Soften probabilities and compute distillation loss
        loss_function = nn.KLDivLoss(reduction="batchmean")
        loss_logits = (
            loss_function(
                F.log_softmax(outputs_student.logits / self.args.temperature, dim=-1),
                F.softmax(outputs_teacher.logits / self.args.temperature, dim=-1),
            )
            * (self.args.temperature ** 2)
        )
        # Return weighted combination of student loss and distillation loss
        loss = self.args.alpha * student_loss + (1.0 - self.args.alpha) * loss_logits
        return (loss, outputs_student) if return_outputs else loss
```
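To see what `temperature` and `alpha` do without a GPU or `transformers`, here is a minimal pure-Python sketch of the same loss on toy logits. The helper names (`softmax`, `cross_entropy`, `kl_div`, `distillation_loss`) are hypothetical and only mirror the PyTorch calls in `compute_loss`; they are not part of the training script.

```python
import math


def softmax(logits, temperature=1.0):
    # divide logits by the temperature before normalizing; larger T gives a flatter distribution
    scaled = [z / temperature for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(logits, label):
    # hard-label loss for one example (outputs_student.loss is its mean over the batch)
    return -math.log(softmax(logits)[label])


def kl_div(p, q):
    # KL(p || q) = sum_i p_i * (log p_i - log q_i); nn.KLDivLoss computes this
    # when given log(q) as input and p as target
    return sum(pi * (math.log(pi) - math.log(qi)) for pi, qi in zip(p, q) if pi > 0)


def distillation_loss(student_logits, teacher_logits, labels, alpha=0.5, temperature=2.0):
    batch = len(labels)
    student_loss = sum(cross_entropy(s, y) for s, y in zip(student_logits, labels)) / batch
    # "batchmean": sum over classes, mean over examples
    soft_loss = sum(
        kl_div(softmax(t, temperature), softmax(s, temperature))
        for s, t in zip(student_logits, teacher_logits)
    ) / batch
    return alpha * student_loss + (1.0 - alpha) * soft_loss * temperature ** 2


if __name__ == "__main__":
    student = [[0.5, 0.2], [1.0, -0.3]]
    teacher = [[2.0, -1.0], [0.1, 0.4]]
    labels = [0, 1]
    print(softmax([2.0, -1.0], 1.0), softmax([2.0, -1.0], 4.0))
    print(distillation_loss(student, teacher, labels, alpha=0.5, temperature=4.0))
```

For example, `softmax([2.0, -1.0], 1.0)` puts about 95% of the mass on the first class, while at `temperature=4.0` that drops to about 68%. This softening exposes how the teacher ranks the "wrong" classes, which is the extra signal the student learns from.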



## Creating an Estimator with our Teacher & Student Model

In this example, we will use [BERT-base](https://huggingface.co/textattack/bert-base-uncased-SST-2) as Teacher and [BERT-Tiny](https://huggingface.co/google/bert_uncased_L-2_H-128_A-2) as Student. Our Teacher is already fine-tuned on our dataset, which lets us start the distillation training job directly rather than first fine-tuning the teacher and distilling it afterwards.

_**IMPORTANT**: This example only works with a `Teacher` & `Student` combination whose tokenizers produce the same output, since the same tokenized inputs are passed to both models._
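A quick way to sanity-check a candidate pair is to tokenize a few sample sentences with both tokenizers and compare the resulting `input_ids`. The helper below (`tokenizers_match`) is a hypothetical sketch, not part of the training script; it accepts any callable that returns a mapping with an `"input_ids"` key, as Hugging Face tokenizers do. Matching on a handful of samples is a necessary check, not a proof that the vocabularies are identical.

```python
def tokenizers_match(teacher_tokenizer, student_tokenizer, samples):
    """Return True if both tokenizers map every sample to identical input_ids.

    Works with any callable returning a mapping with an "input_ids" key,
    e.g. a Hugging Face tokenizer called on a list of strings.
    """
    teacher_ids = teacher_tokenizer(samples)["input_ids"]
    student_ids = student_tokenizer(samples)["input_ids"]
    return list(teacher_ids) == list(student_ids)


# Example usage (requires `transformers` and network access to the Hub):
# from transformers import AutoTokenizer
# teacher_tok = AutoTokenizer.from_pretrained("textattack/bert-base-uncased-SST-2")
# student_tok = AutoTokenizer.from_pretrained("google/bert_uncased_L-2_H-128_A-2")
# print(tokenizers_match(teacher_tok, student_tok, ["a gripping film", "dull and overlong"]))
```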

Additionally, the [FastFormers: Highly Efficient Transformer Models for Natural Language Understanding](https://arxiv.org/abs/2010.13382) paper describes a related phenomenon:
> In our experiments, we have observed that distilled models do not work well when distilled to a different model type. Therefore, we restricted our setup to avoid distilling RoBERTa model to BERT or vice versa. The major difference between the two model groups is the input token (sub-word) embedding. We think that different input embedding spaces result in different output embedding spaces, and knowledge transfer with different spaces does not work well.

```python
from sagemaker.huggingface import HuggingFace
# from huggingface_hub import HfFolder

# hyperparameters, which are passed into the training job
hyperparameters={
    'teacher_id':'textattack/bert-base-uncased-SST-2',
    'student_id':'google/bert_uncased_L-2_H-128_A-2',
    'dataset_id':'glue',
    'dataset_config':'sst2',
    # distillation parameters
    'alpha': 0.5,
    'temperature': 4,
    # hpo parameters
    "run_hpo": True,
    "n_trials": 1, # was 100
    # push to hub config
    # 'push_to_hub': True,
    # 'hub_model_id': 'tiny-bert-sst2-distilled',
    # 'hub_token': HfFolder.get_token()
}

# create the Estimator
huggingface_estimator = HuggingFace(
    entry_point          = 'train.py',
    source_dir           = './scripts',
    instance_type        = 'ml.p4d.24xlarge',
    instance_count       = 1,
    role                 = role,
    transformers_version = '4.17',
    pytorch_version      = '1.10',
    py_version           = 'py38',
    hyperparameters      = hyperparameters,
)
```
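SageMaker passes each entry of `hyperparameters` to the entry point as a command-line argument (`--alpha 0.5`, `--run_hpo True`, ...), always as strings. The sketch below shows how a script *could* parse them with `argparse`; it is a hypothetical illustration, not the actual parser in `scripts/train.py`, and the `str_to_bool` helper and the default values are assumptions. Note that with `parse_known_args`, a misspelled key such as `--temparature` is silently ignored and the default is used instead.

```python
import argparse


def str_to_bool(value):
    # hypothetical helper: SageMaker hands booleans over as strings such as "True"
    return str(value).lower() in ("true", "1", "yes")


def parse_args(argv=None):
    # hypothetical parser mirroring the keys of the hyperparameters dict above
    parser = argparse.ArgumentParser()
    parser.add_argument("--teacher_id", type=str)
    parser.add_argument("--student_id", type=str)
    parser.add_argument("--dataset_id", type=str)
    parser.add_argument("--dataset_config", type=str)
    parser.add_argument("--alpha", type=float, default=0.5)
    parser.add_argument("--temperature", type=float, default=2.0)
    parser.add_argument("--run_hpo", type=str_to_bool, default=False)
    parser.add_argument("--n_trials", type=int, default=1)
    # parse_known_args collects unrecognized flags instead of raising an error
    args, unknown = parser.parse_known_args(argv)
    return args, unknown
```

For example, `parse_args(["--temperature", "4", "--run_hpo", "True"])` yields `temperature=4.0` and `run_hpo=True`, whereas `parse_args(["--temparature", "4"])` leaves `temperature` at its default and returns the typo in `unknown`.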

## Start our Training with Knowledge Distillation and Hyperparameter Optimization

```python
# start the training job; no S3 inputs are passed to fit() because the dataset
# is referenced via the dataset_id / dataset_config hyperparameters
# setting wait to False to not expose the HF Token
huggingface_estimator.fit(wait=False)
```

```
INFO:sagemaker.image_uris:image_uri is not presented, retrieving image_uri based on instance_type, framework etc.
INFO:sagemaker.image_uris:image_uri is not presented, retrieving image_uri based on instance_type, framework etc.
INFO:sagemaker:Creating training-job with name: huggingface-pytorch-training-2023-03-10-02-50-51-919
```


If you enable the `push_to_hub`, `hub_model_id` and `hub_token` hyperparameters above, the Hugging Face Hub integration with TensorBoard lets you inspect the training progress directly on the Hub and test checkpoints during training.

```python
# from huggingface_hub import HfApi

# whoami = HfApi().whoami()
# username = whoami['name']

# print(f"https://huggingface.co/{username}/{hyperparameters['hub_model_id']}")
```

We were able to achieve an `accuracy` of 0.8337, which is a strong result for such a small model. Our distilled `Tiny-Bert` has 96% fewer parameters than the teacher `bert-base` and runs ~46.5x faster while preserving roughly 90% of the teacher's accuracy on the SST-2 dataset.

| Model      | Parameters | Speed-up | Accuracy |
|------------|------------|----------|----------|
| BERT-base  | 109M       | 1x       | 93%      |
| tiny-BERT  | 4M         | 46.5x    | 83%      |
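The headline ratios can be recomputed from these numbers. This quick check uses the rounded table values (109M and 4M parameters, 93% teacher accuracy) together with the student's 0.8337 accuracy, so the results are approximate.

```python
teacher_params, student_params = 109e6, 4e6  # rounded values from the table
teacher_acc, student_acc = 0.93, 0.8337      # teacher from the table, student from the run

param_reduction = 1 - student_params / teacher_params  # share of parameters removed
accuracy_retained = student_acc / teacher_acc          # share of teacher accuracy kept

print(f"parameter reduction: {param_reduction:.1%}")    # parameter reduction: 96.3%
print(f"accuracy retained:   {accuracy_retained:.1%}")  # accuracy retained:   89.6%
```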

_Note: The [FastFormers paper](https://arxiv.org/abs/2010.13382) found that the biggest boost in performance is observed when the student has 6 or more layers. The [google/bert_uncased_L-2_H-128_A-2](https://huggingface.co/google/bert_uncased_L-2_H-128_A-2) model we used has only 2, so switching our student to, e.g., `distilbert-base-uncased` (6 layers) should give us better accuracy._