In [None]:
from transformers import BertGenerationEncoder, BertGenerationDecoder, EncoderDecoderModel
import numpy as np
from datasets import Dataset
from transformers import BertTokenizer
from collections import defaultdict
import torch
from torch.utils.data import DataLoader
from transformers import AdamW
from transformers import get_scheduler
from tqdm.auto import tqdm
# from livelossplot import PlotLosses

In [None]:
import datasets
# train_data = datasets.load_dataset("cnn_dailymail", "3.0.0", split="train")
# Each .npy file holds a single pickled Python object (used as a mapping below);
# `.item()` unwraps it from the 0-d object array that np.load returns.
# NOTE(security): allow_pickle=True can execute arbitrary code while loading,
# so only open these files when they come from a trusted source.
idx_intent = np.load('./total_idx_intent.npy', allow_pickle=True).item()
idx_titles = np.load('./total_idx_titles.npy', allow_pickle=True).item()

# Pair every intent with the title stored under the same key so that the two
# lists stay aligned index-by-index.
intents = list(idx_intent.values())
titles = [idx_titles[key] for key in idx_intent]

# split train, test set = 8:2 (the trailing 20% of the examples is held out)
test_num = int(len(intents) * 0.2)

# Slice on an explicit boundary instead of `[:-test_num]` / `[-test_num:]`:
# when test_num is 0 those negative slices flip meaning (`[:-0]` is empty and
# `[-0:]` is the whole list), which would leave the training set empty.
split_idx = len(intents) - test_num

train_intent = intents[:split_idx]
train_titles = titles[:split_idx]
test_intent = intents[split_idx:]
test_titles = titles[split_idx:]

In [None]:
from transformers import BertTokenizerFast
# Shared tokenizer for both the encoder inputs (titles) and the decoder
# targets (intents); loaded from the "bert-base-uncased" checkpoint.
tokenizer = BertTokenizerFast.from_pretrained(
    "bert-base-uncased",
    do_lower_case=True,
)

In [None]:
# Column-oriented view of the training split: one list per column name.
title_intent = defaultdict(list)
for train_title, train_label in zip(train_titles, train_intent):
    title_intent['titles'].append(train_title)
    title_intent['intents'].append(train_label)

# Same layout for the held-out split.
test_title_intent = defaultdict(list)
for test_title, test_label in zip(test_titles, test_intent):
    test_title_intent['titles'].append(test_title)
    test_title_intent['intents'].append(test_label)

# Wrap both splits as `datasets.Dataset` objects.
dataset = Dataset.from_dict(title_intent)
vali_dataset = Dataset.from_dict(test_title_intent)


In [None]:
def map_to_length(x):
    """Annotate one example with token counts and length-threshold flags.

    Args:
        x: a single example with a "titles" and an "intents" text field.

    Returns:
        The same example, extended in place with:
        - "article_len": number of token ids the tokenizer produces for the title,
        - "article_longer_256" / "article_longer_512": 1 if that count exceeds
          the threshold, else 0,
        - "summary_len": number of token ids produced for the intent,
        - "summary_longer_16" / "summary_longer_32": 1 if that count exceeds
          the threshold, else 0.
    """
    def _token_count(text):
        # Length of the id sequence returned by the module-level tokenizer.
        return len(tokenizer(text).input_ids)

    title_tokens = _token_count(x["titles"])
    x["article_len"] = title_tokens
    x["article_longer_256"] = 1 if title_tokens > 256 else 0
    x["article_longer_512"] = 1 if title_tokens > 512 else 0

    intent_tokens = _token_count(x["intents"])
    x["summary_len"] = intent_tokens
    x["summary_longer_16"] = 1 if intent_tokens > 16 else 0
    x["summary_longer_32"] = 1 if intent_tokens > 32 else 0

    return x

In [None]:
# Number of training examples.
sample_size = len(train_intent)

# Per-example token-length statistics for the training split, computed with
# four worker processes.
data_stats = dataset.map(
    function=map_to_length,
    num_proc=4,
)

In [None]:
# Fixed sequence lengths (in tokens) that titles and intents are padded or
# truncated to.
encoder_max_length = 512
decoder_max_length = 64


def process_data_to_model_inputs(batch):
    """Tokenize a batch of title/intent pairs into encoder-decoder inputs.

    Args:
        batch: a batch with "titles" and "intents" text columns.

    Returns:
        The same batch, extended in place with "input_ids" and
        "attention_mask" (from the titles), "decoder_attention_mask" (from the
        intents) and "labels" (the intent token ids, with every pad-token
        position replaced by -100).
    """
    source_enc = tokenizer(
        batch["titles"],
        padding="max_length",
        truncation=True,
        max_length=encoder_max_length,
    )
    target_enc = tokenizer(
        batch["intents"],
        padding="max_length",
        truncation=True,
        max_length=decoder_max_length,
    )

    batch["input_ids"] = source_enc.input_ids
    batch["attention_mask"] = source_enc.attention_mask
    batch["decoder_attention_mask"] = target_enc.attention_mask

    # The labels are the target token ids themselves; padding positions are
    # swapped for -100 so that the PAD token is ignored.
    pad_id = tokenizer.pad_token_id
    masked_labels = []
    for target_ids in target_enc.input_ids:
        masked_labels.append([-100 if tok == pad_id else tok for tok in target_ids])
    batch["labels"] = masked_labels

    return batch

In [None]:
batch_size = 8

# Tokenize the training split batch by batch and drop the raw text columns,
# leaving only the model-ready fields.
train_data = dataset.map(
    function=process_data_to_model_inputs,
    remove_columns=["titles", "intents"],
    batched=True,
    batch_size=batch_size,
)

In [None]:
# Expose the four model-input columns of the training split in "torch" format.
train_data.set_format(
    columns=["input_ids", "attention_mask", "decoder_attention_mask", "labels"],
    type="torch",
)


In [None]:
# Same preprocessing as for the training split, applied to the held-out split.
val_data = vali_dataset.map(
    function=process_data_to_model_inputs,
    remove_columns=["titles", "intents"],
    batched=True,
    batch_size=batch_size,
)

In [None]:
# Expose the four model-input columns of the validation split in "torch" format.
val_data.set_format(
    columns=["input_ids", "attention_mask", "decoder_attention_mask", "labels"],
    type="torch",
)

In [None]:
from transformers import EncoderDecoderModel
# Encoder and decoder are both initialised from the "bert-base-uncased" checkpoint.
bert2bert = EncoderDecoderModel.from_encoder_decoder_pretrained(
    "bert-base-uncased",
    "bert-base-uncased",
)

# Special-token ids are taken from the tokenizer; the vocabulary size is
# copied from the encoder's config.
_special_token_settings = {
    "decoder_start_token_id": tokenizer.cls_token_id,
    "eos_token_id": tokenizer.sep_token_id,
    "pad_token_id": tokenizer.pad_token_id,
    "vocab_size": bert2bert.config.encoder.vocab_size,
}

# Generation-related values stored on the model config.
_generation_settings = {
    "max_length": 16,
    "min_length": 1,
    "no_repeat_ngram_size": 2,
    "early_stopping": True,
    "length_penalty": 2.0,
    "num_beams": 4,
}

# Apply every setting as an attribute on the shared model config, in the
# order listed above.
for _settings in (_special_token_settings, _generation_settings):
    for _name, _value in _settings.items():
        setattr(bert2bert.config, _name, _value)

In [None]:
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
# Training configuration for the Seq2SeqTrainer.
training_args = Seq2SeqTrainingArguments(
    predict_with_generate=True,
    evaluation_strategy="steps",
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    # Mixed-precision training needs a CUDA device. Gate it on availability so
    # the notebook also runs on a CPU-only host, which the later inference
    # cell explicitly supports through its cuda/cpu fallback.
    fp16=torch.cuda.is_available(),
    # Checkpoints are written to the working directory (checkpoint-<step>/);
    # the inference cell later loads one of them from here.
    output_dir="./",
    logging_steps=500,
    save_steps=1000,
    eval_steps=500,
    learning_rate=0.00007,
    num_train_epochs=4,
)


In [None]:
# ROUGE metric object; compute_metrics and the final evaluation call its
# `.compute(...)` method.
# NOTE(review): `datasets.load_metric` is deprecated in newer `datasets`
# releases in favour of the separate `evaluate` package — confirm against the
# installed version.
rouge = datasets.load_metric("rouge")
def compute_metrics(pred):
    """Compute ROUGE-2 precision, recall and F-measure for one evaluation pass.

    Args:
        pred: object exposing `.predictions` (generated token ids) and
            `.label_ids` (reference token ids, where -100 marks ignored
            positions).

    Returns:
        dict with "rouge2_precision", "rouge2_recall" and "rouge2_fmeasure",
        each rounded to four decimals and taken from the `mid` aggregate.
    """
    pad_id = tokenizer.pad_token_id

    # -100 is not a real vocabulary id, so it must become the pad id before
    # decoding. np.where builds a fresh array, which leaves the caller's
    # `pred.label_ids` untouched (the previous in-place assignment mutated it).
    labels_ids = np.where(pred.label_ids == -100, pad_id, pred.label_ids)

    # Defensive: predictions may also carry -100 as filler in some library
    # versions; with no -100 present this is a no-op.
    pred_ids = np.where(pred.predictions == -100, pad_id, pred.predictions)

    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)

    rouge_output = rouge.compute(
        predictions=pred_str,
        references=label_str,
        rouge_types=["rouge2"],
    )["rouge2"].mid

    return {
        "rouge2_precision": round(rouge_output.precision, 4),
        "rouge2_recall": round(rouge_output.recall, 4),
        "rouge2_fmeasure": round(rouge_output.fmeasure, 4),
    }


In [None]:
# Wire the model, tokenizer, datasets and metric function into a trainer,
# then run training.
trainer = Seq2SeqTrainer(
    model=bert2bert,
    args=training_args,
    train_dataset=train_data,
    eval_dataset=val_data,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)
trainer.train()

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Reload the model and tokenizer from the checkpoint directory written at
# step 2000 and move the model onto the selected device.
bert2bert = EncoderDecoderModel.from_pretrained("./checkpoint-2000/")
bert2bert = bert2bert.to(device)
tokenizer = BertTokenizer.from_pretrained("./checkpoint-2000")

In [None]:
def generate_summary(batch):
    """Generate a predicted intent for every title in a batch.

    Args:
        batch: a batch with a "titles" text column.

    Returns:
        The same batch, extended in place with "pred_summary": the decoded
        model output for each title, with special tokens stripped.
    """
    # Pad/truncate every title to 512 tokens (BERT max length) and request
    # tensors from the tokenizer.
    encoded = tokenizer(
        batch["titles"],
        padding="max_length",
        truncation=True,
        max_length=512,
        return_tensors="pt",
    )

    # Both tensors are moved to the model's device before generation.
    generated = bert2bert.generate(
        encoded.input_ids.to(device),
        attention_mask=encoded.attention_mask.to(device),
    )

    batch["pred_summary"] = tokenizer.batch_decode(generated, skip_special_tokens=True)
    return batch

In [None]:
batch_size = 8  # change to 64 for full evaluation

# Run generation over the held-out split; the "titles" column is dropped while
# "intents" is kept alongside the new "pred_summary" column.
results = vali_dataset.map(
    function=generate_summary,
    remove_columns=["titles"],
    batched=True,
    batch_size=batch_size,
)


In [None]:
# ROUGE-2 `mid` aggregate of the generated summaries against the reference
# intents of the held-out split.
rouge_output = rouge.compute(
    predictions=results["pred_summary"],
    references=results["intents"],
    rouge_types=["rouge2"],
)["rouge2"].mid