<a href="https://colab.research.google.com/github/lokwq/TextBrewer/blob/add_note_examples/sst2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This notebook shows how to fine-tune a model on the SST-2 dataset and how to distill the fine-tuned model with TextBrewer.

Detailed docs can be found here:
https://github.com/airaria/TextBrewer

In [1]:
import torch
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
!pip install transformers #4.8.2
!pip install datasets
!pip install textbrewer

In [3]:
import os
import torch
from transformers import BertForSequenceClassification, BertTokenizer,BertConfig
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from transformers import Trainer, TrainingArguments
from datasets import load_dataset,load_metric

### Prepare dataset to train

In [None]:
# Load the train / validation / test splits of GLUE SST-2 with `load_dataset`.
train_dataset, val_dataset, test_dataset = (
    load_dataset('glue', 'sst2', split=split_name)
    for split_name in ('train', 'validation', 'test')
)

In [None]:
def _copy_label_column(examples):
    """Return a batch update exposing the values of `label` under the key `labels`."""
    return {'labels': examples['label']}


# For every split: add a `labels` column, then drop the original `label` column.
train_dataset = train_dataset.map(_copy_label_column, batched=True).remove_columns(['label'])
val_dataset = val_dataset.map(_copy_label_column, batched=True).remove_columns(['label'])
test_dataset = test_dataset.map(_copy_label_column, batched=True).remove_columns(['label'])

In [None]:
# The classifier and its tokenizer are both created from the same pretrained checkpoint.
pretrained_name = 'bert-base-uncased'
model = BertForSequenceClassification.from_pretrained(pretrained_name)
tokenizer = BertTokenizer.from_pretrained(pretrained_name)

In [None]:
MAX_LENGTH = 128


def _tokenize_batch(examples):
    """Tokenize the `sentence` field, truncating and padding to MAX_LENGTH."""
    return tokenizer(
        examples['sentence'],
        truncation=True,
        padding='max_length',
        max_length=MAX_LENGTH,
    )


# Apply the same tokenization to every split.
train_dataset = train_dataset.map(_tokenize_batch, batched=True)
val_dataset = val_dataset.map(_tokenize_batch, batched=True)
test_dataset = test_dataset.map(_tokenize_batch, batched=True)

In [11]:
# Return only these four columns, as torch tensors, from every split.
model_columns = ['input_ids', 'token_type_ids', 'attention_mask', 'labels']
for split_dataset in (train_dataset, val_dataset, test_dataset):
    split_dataset.set_format(type='torch', columns=list(model_columns))

In [12]:
def compute_metrics(pred):
    """Compute accuracy and macro-averaged precision / recall / F1.

    Args:
        pred: object exposing `label_ids` (reference labels) and `predictions`
            (scores whose last axis is reduced with argmax to get predicted labels).

    Returns:
        dict with the keys 'accuracy', 'f1', 'precision' and 'recall'.
    """
    y_true = pred.label_ids
    y_pred = pred.predictions.argmax(-1)
    macro_precision, macro_recall, macro_f1, _ = precision_recall_fscore_support(
        y_true, y_pred, average='macro'
    )
    scores = {'accuracy': accuracy_score(y_true, y_pred)}
    scores['f1'] = macro_f1
    scores['precision'] = macro_precision
    scores['recall'] = macro_recall
    return scores

In [None]:
# Fine-tune the teacher model.
# The two paths below assume Google Drive is mounted at /content/drive; point
# them somewhere else if it is not.
training_args = TrainingArguments(
    output_dir='/content/drive/MyDrive/results',   # output directory
    learning_rate=1e-4,
    num_train_epochs=5,
    per_device_train_batch_size=32,                # batch size per device during training
    per_device_eval_batch_size=32,                 # batch size per device during evaluation
    logging_dir='/content/drive/MyDrive/logs',     # directory for the training logs
    logging_steps=100,
    do_train=True,
    do_eval=True,
    no_cuda=False,
    load_best_model_at_end=True,
    evaluation_strategy="epoch",
    # load_best_model_at_end can only restore a checkpoint that was saved at an
    # evaluation point, so checkpoints are saved on the same schedule as evaluation.
    save_strategy="epoch",
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
)

train_out = trainer.train()

# After training, the logs and checkpoints are in your own Drive; to store them
# elsewhere, change the paths in the training arguments above.

In [None]:
# Write the state dict of `model` to Drive.
teacher_weights = model.state_dict()
torch.save(teacher_weights, '/content/drive/MyDrive/sst2_teacher_model.pt')


### Start distillation

In [None]:
# Batches for distillation. shuffle=True draws a new example order every epoch,
# so training does not see the dataset in the same fixed order each time.
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)

In [None]:
import textbrewer
from textbrewer import GeneralDistiller
from textbrewer import TrainingConfig, DistillationConfig
from transformers import BertForSequenceClassification, BertConfig, AdamW,BertTokenizer
from transformers import get_linear_schedule_with_warmup

Initialize the student model from a BertConfig and prepare the teacher model.

bert_config_L3.json refers to a 3-layer Bert.

bert_config.json refers to a standard 12-layer Bert.

In [None]:
# Student: a randomly initialised BERT built from the 3-layer config file.
# The config paths are absolute paths into a TextBrewer checkout on Google Drive.
# NOTE(review): the configs live under "bert_base_cased_config", while the teacher
# fine-tuned earlier starts from 'bert-base-uncased'. Confirm that vocab_size in
# these JSON files matches that checkpoint; otherwise load_state_dict below will
# report a size mismatch.
bert_config_T3 = BertConfig.from_json_file('/content/drive/MyDrive/TextBrewer-master/examples/student_config/bert_base_cased_config/bert_config_L3.json')
# Hidden states must be returned so the distiller can match intermediate layers.
bert_config_T3.output_hidden_states = True

student_model = BertForSequenceClassification(bert_config_T3)
student_model.to(device=device)


# Teacher: same architecture family, weights restored from the fine-tuned checkpoint.
bert_config = BertConfig.from_json_file('/content/drive/MyDrive/TextBrewer-master/examples/student_config/bert_base_cased_config/bert_config.json')
bert_config.output_hidden_states = True
teacher_model = BertForSequenceClassification(bert_config)
# map_location lets a checkpoint saved from a GPU run load on a CPU-only runtime too.
teacher_state = torch.load('/content/drive/MyDrive/sst2_teacher_model.pt', map_location=device)
teacher_model.load_state_dict(teacher_state)
teacher_model.to(device=device);  # trailing ';' keeps the long module repr out of the cell output

The cell below distills the teacher model into the student model you prepared.

After the cell has finished running, the distilled model will be in 'saved_model' in the Colab file list.

In [None]:
num_epochs = 20
# One optimizer step per batch, for every epoch.
num_training_steps = num_epochs * len(train_dataloader)

# Only the student's parameters are optimized.
optimizer = AdamW(student_model.parameters(), lr=1e-4)

# The scheduler is handed over as a class plus its keyword arguments
# (everything except 'optimizer'); warm-up covers the first 10% of the steps.
scheduler_class = get_linear_schedule_with_warmup
scheduler_args = dict(
    num_warmup_steps=int(0.1 * num_training_steps),
    num_training_steps=num_training_steps,
)


def simple_adaptor(batch, model_outputs):
    """Map a model's outputs to the dict the distiller reads.

    Args:
        batch: the input batch (not used here).
        model_outputs: object exposing `logits` and `hidden_states`.

    Returns:
        dict with 'logits' and 'hidden' taken from `model_outputs`.
    """
    adapted = dict(logits=model_outputs.logits)
    adapted['hidden'] = model_outputs.hidden_states
    return adapted

# Intermediate-layer matching: MSE between the hidden states at the listed
# teacher / student indices of the 'hidden' feature returned by the adaptor.
distill_config = DistillationConfig(
    intermediate_matches=[
        {'layer_T': 0, 'layer_S': 0, 'feature': 'hidden', 'loss': 'hidden_mse', 'weight': 1},
        {'layer_T': 8, 'layer_S': 2, 'feature': 'hidden', 'loss': 'hidden_mse', 'weight': 1},
    ])
# Use the same device the models were moved to, so training also works on a
# CPU-only runtime instead of relying on the config's default device.
train_config = TrainingConfig(device=device)

distiller = GeneralDistiller(
    train_config=train_config, distill_config=distill_config,
    model_T=teacher_model, model_S=student_model,
    adaptor_T=simple_adaptor, adaptor_S=simple_adaptor)


with distiller:
    distiller.train(optimizer, train_dataloader, num_epochs, scheduler_class=scheduler_class, scheduler_args=scheduler_args, callback=None)

In [None]:
# Rebuild the student architecture and load the distilled weights into it.
# gs4210 is the distilled model weights file. It is read from Drive here, so copy
# the checkpoint there first or adjust the path.
test_model = BertForSequenceClassification(bert_config_T3)
# map_location='cpu' lets a checkpoint written during GPU training load on any runtime.
distilled_state = torch.load('/content/drive/MyDrive/gs4210.pkl', map_location='cpu')
test_model.load_state_dict(distilled_state)

<All keys matched successfully>

In [None]:
from torch.utils.data import DataLoader
# Batches of 8 examples from the validation split, in dataset order.
eval_dataloader = DataLoader(dataset=val_dataset, batch_size=8)

In [None]:
# Measure the accuracy of the distilled student on the validation batches.
metric = load_metric("accuracy")
test_model.to(device)
test_model.eval()
for batch in eval_dataloader:
    # Move every tensor of the batch to the model's device (the previous
    # comprehension only copied the dict without moving anything).
    batch = {k: v.to(device) for k, v in batch.items()}
    with torch.no_grad():
        outputs = test_model(**batch)

    logits = outputs.logits
    predictions = torch.argmax(logits, dim=-1)
    metric.add_batch(predictions=predictions, references=batch["labels"])

metric.compute()