<a href="https://colab.research.google.com/github/patrickvonplaten/notebooks/blob/master/BERT2BERT_for_CNN_Dailymail.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## **Warm-starting BERT2BERT for CNN/Dailymail**

***Note***: This notebook only uses a few training, validation, and test data samples for demonstration purposes. To fine-tune an encoder-decoder model on the full training data, the user should change the training and data preprocessing parameters accordingly as highlighted by the comments.


### **Data Preprocessing**


In [1]:
%%capture
!pip install datasets
!pip install transformers
!pip install starcc
!pip install evaluate

In [2]:
import json
from StarCC import PresetConversion
convert = PresetConversion(src='cn', dst='hk', with_phrase=False)

with open("train/train.json", "r") as input_file, open("train/train.can", "w+") as can_file, open("train/train.man", "w+") as man_file:
    for line in input_file.readlines():
        translation = json.loads(line)["translation"]
        can_file.write(translation["yue"] + "\n")
        man_file.write(convert(translation["zh"]) + "\n")

Building prefix dict from the default dictionary ...
Dumping model to file cache /var/folders/kk/n4ff6h1n3t170b1m4zv09yf40000gn/T/jieba.cache
Loading model cost 0.355 seconds.
Prefix dict has been built successfully.


In [3]:
from datasets import Dataset

train_data = None
val_data = None
test_data = None

with open("train/train.can", "r") as can_file, open("train/train.man", "r") as man_file:
    train_data = Dataset.from_dict({"can": can_file.read().splitlines(), "man": man_file.read().splitlines()})
    print(f"Loaded training data.")
    print(f"First line: {train_data[0]}")

with open("para/dev/dev.can", "r") as can_file, open("para/dev/dev.man", "r") as man_file:
    val_data = Dataset.from_dict({"can": can_file.read().splitlines(), "man": man_file.read().splitlines()})
    print(f"Loaded validation data.")
    print(f"First line: {val_data[0]}")

with open("para/test/test.can", "r") as can_file, open("para/test/test.man", "r") as man_file:
    test_data = Dataset.from_dict({"can": can_file.read().splitlines(), "man": man_file.read().splitlines()})
    print(f"Loaded test data.")
    print(f"First line: {test_data[0]}")


Loaded training data.
First line: {'can': '杞人嘅朋友嘆咗一口氣', 'man': '杞人的朋友嘆了一口氣'}
Loaded validation data.
First line: {'can': '啲咁耐就攪掂嘞，真係掯', 'man': '他一會兒工夫就弄好了，真神'}
Loaded test data.
First line: {'can': '筷子放喺你嘅右便', 'man': '筷子放在你的右邊'}


In [4]:
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('Ayaka/bart-base-cantonese')
tokenizer.bos_token = tokenizer.cls_token
tokenizer.eos_token = tokenizer.sep_token

In [26]:
batch_size=16  # change to 16 for full training
encoder_max_length=384
decoder_max_length=384

def process_data_to_model_inputs(batch):
  # tokenize the inputs and labels
  inputs = tokenizer(batch["can"], padding="max_length", truncation=True, max_length=encoder_max_length)
  outputs = tokenizer(batch["man"], padding="max_length", truncation=True, max_length=decoder_max_length)

  batch["input_ids"] = inputs.input_ids
  batch["attention_mask"] = inputs.attention_mask
  batch["decoder_input_ids"] = outputs.input_ids
  batch["decoder_attention_mask"] = outputs.attention_mask
  batch["labels"] = outputs.input_ids.copy()

  # because BERT automatically shifts the labels, the labels correspond exactly to `decoder_input_ids`. 
  # We have to make sure that the PAD token is ignored
#   batch["labels"] = [[-100 if token == tokenizer.pad_token_id else token for token in labels] for labels in batch["labels"]]

  return batch

# only use 32 training examples for notebook - DELETE LINE FOR FULL TRAINING
# train_data = train_data.select(range(32))

train_data = train_data.map(
    process_data_to_model_inputs, 
    batched=True, 
    batch_size=batch_size,
)
train_data.set_format(
    type="torch", columns=["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask", "labels"],
)


# only use 16 training examples for notebook - DELETE LINE FOR FULL TRAINING
val_data = val_data.select(range(16))

val_data = val_data.map(
    process_data_to_model_inputs, 
    batched=True, 
    batch_size=batch_size,
)
val_data.set_format(
    type="torch", columns=["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask", "labels"],
)

Map:   0%|          | 0/32 [00:00<?, ? examples/s]

Map:   0%|          | 0/16 [00:00<?, ? examples/s]

### **Warm-starting the Encoder-Decoder Model**

In [27]:
# from transformers import EncoderDecoderModel
from transformers import BartForConditionalGeneration

bert2bert = BartForConditionalGeneration.from_pretrained('Ayaka/bart-base-cantonese')
# bert2bert = EncoderDecoderModel.from_encoder_decoder_pretrained("Ayaka/bart-base-cantonese", "Ayaka/bart-base-cantonese")

In [28]:
# set special tokens
bert2bert.config.decoder_start_token_id = tokenizer.bos_token_id
bert2bert.config.eos_token_id = tokenizer.eos_token_id
bert2bert.config.pad_token_id = tokenizer.pad_token_id

# sensible parameters for beam search
# bert2bert.config.vocab_size = bert2bert.config.decoder.vocab_size
bert2bert.config.max_length = 200
bert2bert.config.min_length = 3
bert2bert.config.no_repeat_ngram_size = 3
bert2bert.config.early_stopping = True
bert2bert.config.length_penalty = 2.0
bert2bert.config.num_beams = 4

### **Fine-Tuning Warm-Started Encoder-Decoder Models**

In [29]:
import evaluate

# load bleu for validation
bleu = evaluate.load("bleu")

def compute_metrics(pred):
    # print("Start compute_metrics")
    labels_ids = pred.label_ids
    pred_ids = pred.predictions

    # all unnecessary tokens are removed
    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    # labels_ids[labels_ids == -100] = tokenizer.pad_token_id
    label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)

    bleu_output = bleu.compute(predictions=pred_str, references=label_str)

    return {
        "bleu": round(bleu_output, 4),
    }

Cool! Finally, we start training.

In [30]:
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

# set training arguments - these params are not really tuned, feel free to change
training_args = Seq2SeqTrainingArguments(
    output_dir="./",
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    predict_with_generate=True,
    use_mps_device=True,
    do_train=True,
    # evaluation_strategy="steps",
    do_eval=False,
    logging_steps=1000,  # set to 1000 for full training
    save_steps=500,  # set to 500 for full training
    eval_steps=8000,  # set to 8000 for full training
    warmup_steps=2000,  # set to 2000 for full training
    # max_steps=16, # delete for full training
    overwrite_output_dir=True,
    save_total_limit=3,
    fp16=False, 
)

# instantiate trainer
trainer = Seq2SeqTrainer(
    model=bert2bert,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_data,
    eval_dataset=val_data,
)
trainer.train()



  0%|          | 0/6 [00:00<?, ?it/s]

{'train_runtime': 64.4637, 'train_samples_per_second': 1.489, 'train_steps_per_second': 0.093, 'train_loss': 6.857500712076823, 'epoch': 3.0}


TrainOutput(global_step=6, training_loss=6.857500712076823, metrics={'train_runtime': 64.4637, 'train_samples_per_second': 1.489, 'train_steps_per_second': 0.093, 'train_loss': 6.857500712076823, 'epoch': 3.0})

### **Evaluation**

Awesome, we finished training our dummy model. Let's now evaluated the model on the test data. We make use of the dataset's handy `.map()` function to generate a summary of each sample of the test data.

In [33]:
from transformers import BertTokenizer, BartForConditionalGeneration

tokenizer = BertTokenizer.from_pretrained('Ayaka/bart-base-cantonese')
model = BartForConditionalGeneration.from_pretrained("./checkpoint-16").to('mps')

# only use 16 training examples for notebook - DELETE LINE FOR FULL TRAINING
# test_data = test_data.select(range(16))

batch_size = 64  # change to 64 for full evaluation

# map data correctly
def generate_summary(batch):
    # Tokenizer will automatically set [BOS] <text> [EOS]
    # cut off at BERT max length 512
    inputs = tokenizer(batch["can"], padding="max_length", truncation=True, max_length=200, return_tensors="pt")
    input_ids = inputs.input_ids.to('mps')
    attention_mask = inputs.attention_mask.to('mps')

    outputs = model.generate(input_ids, attention_mask=attention_mask)

    # all special tokens including will be removed
    output_str = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    batch["pred"] = [s.replace(" ", "") for s in output_str]
    for man, pred in list(zip(batch["man"][:10], batch["pred"][:10])):
        print("target: " + man)
        print("pred: " + pred)
        print()

    return batch

results = test_data.map(generate_summary, batched=True, batch_size=batch_size)

pred_str = results["pred"]
label_str = results["man"]

bleu_output = bleu.compute(predictions=pred_str, references=label_str)

print(bleu_output)

Map:   0%|          | 0/16 [00:00<?, ? examples/s]

The fully trained *BERT2BERT* model is uploaded to the 🤗model hub under [patrickvonplaten/bert2bert_cnn_daily_mail](https://huggingface.co/patrickvonplaten/bert2bert_cnn_daily_mail). 

The model achieves a ROUGE-2 score of **18.22**, which is even a little better than reported in the paper.

For some summarization examples, the reader is advised to use the online inference API of the model, [here](https://huggingface.co/patrickvonplaten/bert2bert_cnn_daily_mail).