In [1]:
# Imports
import datasets
import pandas as pd

from IPython.display import display, HTML
from datasets import ClassLabel
from transformers import EncoderDecoderModel, BertTokenizer, BertTokenizerFast
from seq2seq_trainer import Seq2SeqTrainer
from seq2seq_training_args import Seq2SeqTrainingArguments

# Disabled environment setup. This is a bare string literal, not executed code:
# none of the commands inside ever run, and because the literal is the last
# expression of the cell its repr is echoed as the cell output.
# The two wget lines fetch seq2seq_trainer.py and seq2seq_training_args.py,
# the modules this cell imports Seq2SeqTrainer / Seq2SeqTrainingArguments from.
'''
!pip install --upgrade tensorflow
!pip install torch==1.7.1+cu101 torchvision==0.8.2+cu101 torchaudio===0.7.2 -f https://download.pytorch.org/whl/torch_stable.html
!pip install datasets
!pip install transformers
!pip install git-python==1.0.3
!pip install rouge_score
!pip install sacrebleu
!pip install wget

!python -m wget https://raw.githubusercontent.com/huggingface/transformers/master/examples/seq2seq/seq2seq_trainer.py
!python -m wget https://raw.githubusercontent.com/huggingface/transformers/master/examples/seq2seq/seq2seq_training_args.py
'''

'\n!pip install --upgrade tensorflow\n!pip install torch==1.7.1+cu101 torchvision==0.8.2+cu101 torchaudio===0.7.2 -f https://download.pytorch.org/whl/torch_stable.html\n!pip install datasets\n!pip install transformers\n!pip install git-python==1.0.3\n!pip install rouge_score\n!pip install sacrebleu\n!pip install wget\n\n!python -m wget https://raw.githubusercontent.com/huggingface/transformers/master/examples/seq2seq/seq2seq_trainer.py\n!python -m wget https://raw.githubusercontent.com/huggingface/transformers/master/examples/seq2seq/seq2seq_training_args.py\n'

In [2]:
# Load data
# CNN/DailyMail 3.0.0: the full train split, the first 10% of the validation
# split and the first 5% of the test split.
cnn_dailymail = ("cnn_dailymail", "3.0.0")
train_data = datasets.load_dataset(*cnn_dailymail, split="train")
val_data = datasets.load_dataset(*cnn_dailymail, split="validation[:10%]")
test_data = datasets.load_dataset(*cnn_dailymail, split="test[:5%]")

# Show the dataset description. A bare expression in the middle of a cell is
# evaluated and then discarded (only the last expression of a cell is echoed),
# so the description has to be printed explicitly to be visible.
print(train_data.info.description)

# Preview the first training example as an HTML table, without the "id" column.
df = pd.DataFrame(train_data[:1]).drop(columns=["id"])

# For any ClassLabel feature, replace the integer class ids by their label names.
for column, typ in train_data.features.items():
    if isinstance(typ, ClassLabel):
        df[column] = df[column].transform(lambda i: typ.names[i])

display(HTML(df.to_html()))

KeyboardInterrupt: 

In [None]:
# Load tokenizer
# A single tokenizer is used for both the articles and the summaries.
tokenizer_checkpoint = "bert-base-multilingual-cased"
tokenizer = BertTokenizerFast.from_pretrained(tokenizer_checkpoint)

In [None]:
# Prepare data
# Token budgets used when tokenizing: articles are padded/truncated to
# encoder_max_length tokens, highlights to decoder_max_length tokens.
encoder_max_length, decoder_max_length = 512, 128
# Batch size shared by the training-data map, the trainer arguments and the
# test-set generation (the original inline note mentions 16 as an alternative).
batch_size = 4

def process_data_to_model_inputs(batch):
    """Tokenize a batch of articles and highlights into model features.

    Reads ``batch["article"]`` and ``batch["highlights"]``, tokenizes both with
    the module-level ``tokenizer`` (padded/truncated to ``encoder_max_length``
    and ``decoder_max_length`` respectively) and stores the results on the
    batch under ``input_ids``, ``attention_mask``, ``decoder_input_ids``,
    ``decoder_attention_mask`` and ``labels``.

    ``labels`` holds the decoder token ids with every pad token id replaced
    by -100 (presumably so padded positions are ignored by the loss).

    Returns:
        The same ``batch`` object with the five fields added.
    """
    def _encode(texts, limit):
        # Fixed-length encoding: pad up to `limit`, truncate anything longer.
        return tokenizer(texts, padding="max_length", truncation=True, max_length=limit)

    source = _encode(batch["article"], encoder_max_length)
    target = _encode(batch["highlights"], decoder_max_length)
    pad_id = tokenizer.pad_token_id

    batch["input_ids"] = source.input_ids
    batch["attention_mask"] = source.attention_mask
    batch["decoder_input_ids"] = target.input_ids
    batch["decoder_attention_mask"] = target.attention_mask
    # Fresh nested lists: the decoder ids stored above are left untouched.
    batch["labels"] = [
        [-100 if token == pad_id else token for token in ids]
        for ids in target.input_ids
    ]

    return batch

In [None]:
# Training data
# Keep the first 50,000 training examples and tokenize them in batches; the
# raw text and id columns are dropped once the model features exist.
train_data = train_data.select(range(50000)).map(
    process_data_to_model_inputs,
    batched=True,
    batch_size=batch_size,
    remove_columns=["article", "highlights", "id"],
)

# Return only the model features, as torch tensors.
train_data.set_format(
    type="torch",
    columns=["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask", "labels"],
)

In [None]:
# Validation data
# Keep the first 500 validation examples and tokenize them in batches (default
# map batch size); the raw text and id columns are dropped afterwards.
val_data = val_data.select(range(500)).map(
    process_data_to_model_inputs,
    batched=True,
    remove_columns=["article", "highlights", "id"],
)

# Return only the model features, as torch tensors.
val_data.set_format(
    type="torch",
    columns=["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask", "labels"],
)

In [None]:
# Load German data (placeholder: this cell is empty, no German data is loaded yet)

In [None]:
# Load models
# Build an encoder-decoder model whose encoder and decoder both start from the
# same multilingual BERT checkpoint (weights not tied), write it to disk, then
# reload it from that directory.
bert_checkpoint = "bert-base-multilingual-cased"
tf2tf = EncoderDecoderModel.from_encoder_decoder_pretrained(
    bert_checkpoint,
    bert_checkpoint,
    tie_encoder_decoder=False,
)
tf2tf.save_pretrained("bert2bert_multilingual")
tf2tf = EncoderDecoderModel.from_pretrained("bert2bert_multilingual")

In [None]:
# Configure models
# Special tokens come from the tokenizer: [CLS] starts decoding, [SEP] ends it,
# and the pad id matches the one used when tokenizing the data.
model_config = tf2tf.config
model_config.decoder_start_token_id = tokenizer.cls_token_id
model_config.eos_token_id = tokenizer.sep_token_id
model_config.pad_token_id = tokenizer.pad_token_id
# Mirror the encoder's vocabulary size on the top-level config.
model_config.vocab_size = model_config.encoder.vocab_size

In [None]:
# Configure beam search
# Generation settings stored on the model config, applied in the listed order.
beam_search_settings = {
    "max_length": 142,
    "min_length": 56,
    "no_repeat_ngram_size": 3,
    "early_stopping": True,
    "length_penalty": 2.0,
    "num_beams": 4,
}
for setting_name, setting_value in beam_search_settings.items():
    setattr(tf2tf.config, setting_name, setting_value)

In [None]:
# Prepare metric
# ROUGE scorer used by compute_metrics and by the final test-set evaluation.
# NOTE(review): `datasets.load_metric` is deprecated in newer `datasets`
# releases — confirm the installed version still provides it.
rouge = datasets.load_metric("rouge")

def compute_metrics(pred):
    """Compute ROUGE-2 precision, recall and F-measure for generated summaries.

    Args:
        pred: Evaluation output exposing ``predictions`` (generated token ids)
            and ``label_ids`` (reference token ids in which some positions are
            marked with -100).

    Returns:
        dict with the keys ``rouge2_precision``, ``rouge2_recall`` and
        ``rouge2_fmeasure``, each rounded to 4 decimals (the ``mid`` value of
        the aggregated ROUGE-2 score).
    """
    import numpy as np  # local import: numpy is not imported at the top of the notebook

    pad_id = tokenizer.pad_token_id

    # -100 is not a real token id, so swap in the pad id before decoding.
    # np.where builds a new array; the previous in-place assignment silently
    # rewrote pred.label_ids for whoever handed `pred` to this function.
    labels_ids = np.where(pred.label_ids == -100, pad_id, pred.label_ids)

    pred_str = tokenizer.batch_decode(pred.predictions, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)

    rouge_output = rouge.compute(predictions=pred_str, references=label_str, rouge_types=["rouge2"])["rouge2"].mid

    return {
        "rouge2_precision": round(rouge_output.precision, 4),
        "rouge2_recall": round(rouge_output.recall, 4),
        "rouge2_fmeasure": round(rouge_output.fmeasure, 4),
    }

In [None]:
# Try CUDA
import torch

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
cuda_available = torch.cuda.is_available()
device = torch.device("cuda") if cuda_available else torch.device("cpu")
# Release cached GPU memory left over from earlier runs in this kernel.
torch.cuda.empty_cache()
print("Device:", device)

In [None]:
# Setup arguments
# Train/eval batch sizes reuse `batch_size`; evaluation and checkpointing both
# happen every 1000 steps, logging every 100, and only one checkpoint is kept.
training_args = Seq2SeqTrainingArguments(
    predict_with_generate=True,
    evaluation_strategy="steps",
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    # Mixed precision needs a CUDA device. The notebook explicitly supports a
    # CPU fallback (see the "Try CUDA" cell), so only enable fp16 when a GPU is
    # actually available instead of unconditionally.
    fp16=torch.cuda.is_available(),
    output_dir="./",
    warmup_steps=1000,
    save_steps=1000,
    logging_steps=100,
    eval_steps=1000,
    save_total_limit=1
)

# To continue from a saved checkpoint instead of the freshly initialised model:
# tf2tf = EncoderDecoderModel.from_pretrained("./checkpoint-1000")

In [None]:
# Start training
# Move the model to the device chosen in the "Try CUDA" cell rather than a
# hard-coded "cuda", which raises on machines without a GPU.
tf2tf.to(device)

trainer = Seq2SeqTrainer(
    model=tf2tf,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_data,
    eval_dataset=val_data,
)

trainer.train()

In [None]:
# Evaluate training
# NOTE(review): this rebinds `tf2tf` and `tokenizer` to the published
# "patrickvonplaten/bert2bert_cnn_daily_mail" checkpoint, so everything below
# scores that checkpoint, not the model trained in this notebook — confirm
# this is intended.
# The model goes to the device chosen in the "Try CUDA" cell instead of a
# hard-coded "cuda", so the cell also works without a GPU.
tf2tf = EncoderDecoderModel.from_pretrained("patrickvonplaten/bert2bert_cnn_daily_mail").to(device)
tokenizer = BertTokenizer.from_pretrained("patrickvonplaten/bert2bert_cnn_daily_mail")

def generate_summary(batch):
    """Generate a summary for every article in ``batch``.

    Tokenizes ``batch["article"]`` (padded/truncated to 512 tokens), runs
    ``tf2tf.generate`` on the result and stores the decoded text, with special
    tokens stripped, under ``batch["pred_summary"]``.

    Returns:
        The same ``batch`` object with ``pred_summary`` added.
    """
    inputs = tokenizer(batch["article"], padding="max_length", truncation=True, max_length=512, return_tensors="pt")

    # Send the inputs to whatever device the model lives on. The previous
    # hard-coded "cuda" failed on CPU-only machines and could disagree with
    # the model's actual placement.
    model_device = tf2tf.device
    input_ids = inputs.input_ids.to(model_device)
    attention_mask = inputs.attention_mask.to(model_device)

    outputs = tf2tf.generate(input_ids, attention_mask=attention_mask)
    batch["pred_summary"] = tokenizer.batch_decode(outputs, skip_special_tokens=True)

    return batch

# Generate summaries for the test subset; the article text is dropped once the
# prediction has been added.
results = test_data.map(generate_summary, batched=True, batch_size=batch_size, remove_columns=["article"])

# Print the first prediction followed by its reference highlights.
first_result = results[0]
print(first_result["pred_summary"])
print(first_result["highlights"])
print("====================")

# Aggregated ROUGE-2 ("mid" value) over the whole test subset; left as the last
# expression so the notebook displays it.
rouge.compute(
    predictions=results["pred_summary"],
    references=results["highlights"],
    rouge_types=["rouge2"],
)["rouge2"].mid