<a href="https://colab.research.google.com/github/lokwq/TextBrewer/blob/add_note_examples/msra.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This notebook shows how to fine-tune a model on the msra_ner dataset and how to distill the fine-tuned model with TextBrewer.

Detailed docs can be found here:
https://github.com/airaria/TextBrewer

In [1]:
import torch

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
!pip install datasets
!pip install transformers
!pip install seqeval
!pip install textbrewer

In [3]:
import os
import torch
from transformers import BertForSequenceClassification, BertTokenizer,BertConfig,BertForTokenClassification
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from transformers import Trainer, TrainingArguments
from transformers import pipeline
from datasets import load_dataset,load_metric

### Prepare dataset to train

In [4]:
# Experiment settings shared by the cells below.
# `task` selects the label column, which is read as f"{task}_tags". The dataset loaded
# below is msra_ner, so this should stay "ner" (presumably its only tag column — confirm).
task = "ner"
# Pretrained checkpoint used for both the tokenizer and the teacher model.
model_checkpoint = "bert-base-chinese"
# Per-device batch size used for training and evaluation.
batch_size = 8

In [None]:
from datasets import load_dataset, load_metric
# Load the msra_ner dataset; later cells index the result by split name ("train", "test").
datasets = load_dataset("msra_ner")

In [6]:
# Tag names of the task's label column; compute_metrics indexes this list with label ids,
# so position i holds the name of label id i.
label_list = datasets["train"].features[f"{task}_tags"].feature.names
# Display the tag names as the cell output.
label_list

['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC']

In [None]:
from transformers import AutoTokenizer
# Load the tokenizer that belongs to the pretrained checkpoint configured above.
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

In [8]:
def tokenize_and_align_labels(examples, label_all_tokens=True):
    """Tokenize pre-split sentences and align the word-level tags with the tokens.

    Args:
        examples: A batch from the dataset with a "tokens" column (one list of words
            per sentence) and a f"{task}_tags" column (one integer tag per word).
        label_all_tokens: When True, every sub-token of a word receives the word's
            tag; when False, only the first sub-token is tagged and the remaining
            ones get -100. This used to be read from a global that the notebook
            never defined, which raised NameError for multi-token words.

    Returns:
        The tokenizer output for the batch with an added "labels" entry holding, per
        sentence, one label id per token. -100 marks positions that the loss ignores.
    """
    tokenized_inputs = tokenizer(examples["tokens"], truncation=True, is_split_into_words=True)

    labels = []
    for i, word_labels in enumerate(examples[f"{task}_tags"]):
        label_ids = []
        previous_word_idx = None
        for word_idx in tokenized_inputs.word_ids(batch_index=i):
            if word_idx is None:
                # Special tokens have no word id; -100 keeps them out of the loss.
                label_ids.append(-100)
            elif word_idx != previous_word_idx or label_all_tokens:
                # First token of a word, or a continuation token that should be labelled too.
                label_ids.append(word_labels[word_idx])
            else:
                # Continuation token of a word when only the first token is labelled.
                label_ids.append(-100)
            previous_word_idx = word_idx

        labels.append(label_ids)

    tokenized_inputs["labels"] = labels
    return tokenized_inputs

In [9]:
import numpy as np

def compute_metrics(p):
    """Score token-level predictions with the module-level `metric`.

    `p` unpacks into (predictions, labels). The predictions are arg-maxed over axis 2
    to obtain class ids. Positions whose label is -100 are dropped, and the remaining
    ids are mapped to tag names through `label_list` before being passed to
    `metric.compute`.

    Returns:
        A dict with the overall "precision", "recall", "f1" and "accuracy".
    """
    scores, gold_rows = p
    predicted_rows = np.argmax(scores, axis=2)

    decoded_predictions = []
    decoded_labels = []
    for predicted_row, gold_row in zip(predicted_rows, gold_rows):
        kept_predictions = []
        kept_labels = []
        for predicted_id, gold_id in zip(predicted_row, gold_row):
            if gold_id == -100:
                # Ignored index (special tokens): excluded from scoring.
                continue
            kept_predictions.append(label_list[predicted_id])
            kept_labels.append(label_list[gold_id])
        decoded_predictions.append(kept_predictions)
        decoded_labels.append(kept_labels)

    results = metric.compute(predictions=decoded_predictions, references=decoded_labels)

    summary = {}
    summary["precision"] = results["overall_precision"]
    summary["recall"] = results["overall_recall"]
    summary["f1"] = results["overall_f1"]
    summary["accuracy"] = results["overall_accuracy"]
    return summary

In [None]:
from transformers import BertForTokenClassification, TrainingArguments, Trainer

# Teacher model: the pretrained checkpoint with a token-classification head that has
# one output per entry of label_list.
model = BertForTokenClassification.from_pretrained(model_checkpoint, num_labels=len(label_list))
# Move the model to the selected device (being the cell's last expression, the result is also displayed).
model.to(device)

In [11]:
from transformers import DataCollatorForTokenClassification
# Collator built from the tokenizer; it is handed to the Trainer to assemble batches
# from the tokenized examples.
data_collator = DataCollatorForTokenClassification(tokenizer)

In [None]:
# seqeval metric; compute_metrics calls metric.compute(...) on the decoded tag sequences.
metric = load_metric("seqeval")

In [None]:
# Run tokenize_and_align_labels over the loaded dataset in batches; the result is what
# the Trainer cells read their "train" and "test" splits from.
tokenized_datasets = datasets.map(tokenize_and_align_labels, batched=True)

In [14]:
# Training configuration for fine-tuning the teacher model.
args = TrainingArguments(
    output_dir=f"test-{task}",
    # Optimisation settings.
    num_train_epochs=2,
    learning_rate=2e-5,
    weight_decay=0.01,
    # Batch sizes, shared with the rest of the notebook through `batch_size`.
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    # Evaluate at the end of every epoch.
    evaluation_strategy="epoch",
)

In [15]:
# Trainer that fine-tunes the teacher on the train split and scores it on the test split.
trainer = Trainer(
    model=model,
    args=args,
    data_collator=data_collator,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["test"],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

In [None]:
# Fine-tune the teacher model with the Trainer configured above.
trainer.train()

In [None]:
# Save the fine-tuned teacher weights; the distillation section reloads them from this path.
# NOTE(review): the path is on Google Drive — presumably Drive must be mounted first; no mount cell is visible in this notebook.
torch.save(model.state_dict(), '/content/drive/MyDrive/msra_teacher_model.pt')

### Start distillation

In [None]:
import textbrewer
from textbrewer import GeneralDistiller
from textbrewer import TrainingConfig, DistillationConfig
from transformers import BertForTokenClassification, BertConfig, AdamW,BertTokenizer
from transformers import get_linear_schedule_with_warmup
import torch 

Initialize the student model from a `BertConfig` and prepare the teacher model.

`bert_config_L3.json` refers to a 3-layer BERT.

In [None]:
# Student: a randomly initialised BERT described by bert_config_L3.json.
# Hidden states are exposed so the distiller can match intermediate layers.
bert_config_T3 = BertConfig.from_json_file('/content/drive/MyDrive/TextBrewer-master/examples/student_config/bert_base_cased_config/bert_config_L3.json')
bert_config_T3.output_hidden_states = True
bert_config_T3.num_labels = len(label_list)

student_model = BertForTokenClassification(bert_config_T3)

# Teacher: take the architecture from the same checkpoint the teacher was fine-tuned from
# (`model_checkpoint`), so the rebuilt model is guaranteed to fit the saved state dict.
# A config file written for a different checkpoint may not match the saved weights.
bert_config = BertConfig.from_pretrained(model_checkpoint)
bert_config.output_hidden_states = True
bert_config.num_labels = len(label_list)

teacher_model = BertForTokenClassification(bert_config)
# map_location allows weights saved from a GPU to be restored on a CPU-only runtime.
teacher_model.load_state_dict(torch.load('/content/drive/MyDrive/msra_teacher_model.pt', map_location=device))

teacher_model.to(device)
student_model.to(device)


The cell below distills the teacher model into the student model prepared above.

After the cell finishes, the distilled student checkpoints are written to the `saved_models` directory (TextBrewer's default `output_dir`), which appears in the Colab file list.

In [None]:
from torch.utils.data import DataLoader

num_epochs = 20

# Build the batches used for distillation (the notebook previously referenced a
# `train_dataloader` that was never defined). Only the columns the models accept are kept,
# because the raw text/tag columns cannot be collated into tensors. The collator output is
# converted to a plain dict so the distiller can forward each batch as model(**batch).
model_input_columns = {"input_ids", "token_type_ids", "attention_mask", "labels"}
distill_train_dataset = tokenized_datasets["train"].remove_columns(
    [name for name in tokenized_datasets["train"].column_names if name not in model_input_columns]
)
train_dataloader = DataLoader(
    distill_train_dataset,
    shuffle=True,
    batch_size=batch_size,
    collate_fn=lambda features: dict(data_collator(features)),
)
num_training_steps = len(train_dataloader) * num_epochs

optimizer = AdamW(student_model.parameters(), lr=1e-5)

# Linear schedule with warmup over the first 10% of the training steps.
scheduler_class = get_linear_schedule_with_warmup
scheduler_args = {'num_warmup_steps':int(0.1*num_training_steps), 'num_training_steps':num_training_steps}

def simple_adaptor(batch, model_outputs):
    """Expose the model outputs the distiller needs: the logits and all hidden states."""
    return {"logits":model_outputs.logits, 'hidden': model_outputs.hidden_states}

# Match four teacher hidden states (indices 0, 4, 8, 12) to the four student hidden
# states (indices 0-3) with an MSE loss of equal weight.
distill_config = DistillationConfig(
    intermediate_matches=[{"layer_T":0, "layer_S":0, "feature":"hidden", "loss":"hidden_mse", "weight":1},
               {"layer_T":4, "layer_S":1, "feature":"hidden", "loss":"hidden_mse", "weight":1},
               {"layer_T":8, "layer_S":2, "feature":"hidden", "loss":"hidden_mse", "weight":1},
               {"layer_T":12,"layer_S":3, "feature":"hidden", "loss":"hidden_mse", "weight":1}])

# Use the device selected at the top of the notebook instead of TrainingConfig's default,
# so the cell also runs on a CPU-only runtime.
train_config = TrainingConfig(device=device)
distiller = GeneralDistiller(
    train_config=train_config, distill_config=distill_config,
    model_T=teacher_model, model_S=student_model,
    adaptor_T=simple_adaptor, adaptor_S=simple_adaptor)


with distiller:
    distiller.train(optimizer, train_dataloader, num_epochs, scheduler_class=scheduler_class, scheduler_args = scheduler_args, callback=None)

Then evaluate the distilled model.

In [None]:
# Rebuild the student architecture so the distilled weights can be loaded into it.
# NOTE(review): this config path differs from the one used when the student was created
# for distillation — confirm both point to the same bert_config_L3.json.
bert_config_T3 = BertConfig.from_json_file('/content/drive/MyDrive/data/bert_config/bert_config_L3.json')

bert_config_T3.output_hidden_states = True
bert_config_T3.num_labels = len(label_list)
test_model = BertForTokenClassification(bert_config_T3)

# NOTE(review): the checkpoint path is hard-coded; presumably it must be changed to the
# checkpoint file actually written by the distillation run — confirm before running.
# map_location allows weights saved from a GPU to be restored on a CPU-only runtime.
test_model.load_state_dict(torch.load('/content/drive/MyDrive/model/gs2813.pkl', map_location=device))
test_model.to(device)


In [None]:
# Arguments for scoring the distilled student with the Trainer (evaluation only:
# do_train is off and do_eval is on). Note that this rebinds `args` from the fine-tuning section.
args = TrainingArguments(
    "distill-test",  # plain string: the original f-string had no placeholders
    evaluation_strategy = "epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    do_train=False,
    do_eval=True,
    no_cuda=False,
    num_train_epochs=2,
    weight_decay=0.01,
)

In [None]:
# Trainer wrapped around the distilled student; it is only used for evaluation below.
# Note that this rebinds `trainer` from the fine-tuning section.
trainer = Trainer(
    model=test_model,
    args=args,
    data_collator=data_collator,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["test"],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

In [None]:
# Evaluate the distilled student on the eval dataset given to the Trainer (the "test" split);
# the returned metrics are displayed as the cell output.
trainer.evaluate()