In [None]:
from transformers import BertGenerationEncoder, BertGenerationDecoder, EncoderDecoderModel
import numpy as np
from datasets import Dataset
from transformers import BertTokenizer
from collections import defaultdict
import torch
from torch.utils.data import DataLoader
from transformers import AdamW
from transformers import get_scheduler
from tqdm.auto import tqdm

In [None]:
import datasets
# Load the two pickled dicts (stored as 0-d object arrays, hence `.item()`).
# NOTE(review): allow_pickle=True unpickles arbitrary Python objects; only load
# these .npy files from a trusted source.
idx_intent = np.load('data/total_idx_intent.npy', allow_pickle=True).item()
idx_titles = np.load('data/total_idx_titles.npy', allow_pickle=True).item()

# Pair every intent with the title stored under the same key, in the
# iteration order of `idx_intent`. A key missing from `idx_titles` raises KeyError.
intents = list(idx_intent.values())
titles = [idx_titles[key] for key in idx_intent]

# split train, test set = 8:2 (the last 20% of the samples are held out)
test_num = int(len(intents) * 0.2)

# Slice with an explicit boundary index: `intents[:-test_num]` would return an
# empty training set (and put everything in the test set) when test_num == 0.
split_at = len(intents) - test_num

train_intent = intents[:split_at]
train_titles = titles[:split_at]
test_intent = intents[split_at:]
test_titles = titles[split_at:]

In [None]:
# Remove one (intent, title) pair from each split by position; the same index
# is popped from both lists of a split so they stay aligned.
# NOTE(review): the indices 1031 and 816 are unexplained here — presumably
# malformed samples; confirm and record the reason.
# NOTE: list.pop mutates in place, so re-running this cell removes additional
# items. Re-run the data-loading cell first if this cell must be repeated.
train_intent.pop(1031)
train_titles.pop(1031)
test_intent.pop(816)
test_titles.pop(816)

In [None]:
from transformers import AutoTokenizer
# Tokenizer matching the pretrained "t5-base" checkpoint.
# NOTE(review): do_lower_case is a BERT-style option; presumably it has no
# effect on the T5 tokenizer — confirm before relying on lower-casing.
tokenizer = AutoTokenizer.from_pretrained("t5-base", do_lower_case=True)

In [None]:
# Column-oriented view of each split: 'titles' holds the source texts and
# 'intents' the target texts. The lists are copied so the mappings are a
# snapshot of the split lists, and both columns exist even for an empty split.
title_intent = defaultdict(list, titles=list(train_titles), intents=list(train_intent))
test_title_intent = defaultdict(list, titles=list(test_titles), intents=list(test_intent))

# Raw-text datasets (no tokenization at this stage). Dataset.from_dict is
# given columns that must have equal lengths.
dataset = Dataset.from_dict(title_intent)
vali_dataset = Dataset.from_dict(test_title_intent)

In [None]:
# Maximum token lengths passed as `max_length` (with truncation) when
# tokenizing the source titles and the target intents respectively.
encoder_max_length=512
decoder_max_length=64

# Task prefix prepended to every title before it is tokenized.
prefix = "summarize: "

def process_data_to_model_inputs(batch):
    """Tokenize one batch of raw examples for seq2seq training.

    Args:
        batch: mapping with a "titles" column (source texts) and an
            "intents" column (target texts).

    Returns:
        The tokenizer output for the prefixed titles, with an extra "labels"
        entry holding the token ids of the intents. No padding argument is
        passed to the tokenizer here.
    """
    # Every source text gets the task prefix before tokenization.
    prefixed_titles = []
    for title in batch["titles"]:
        prefixed_titles.append(prefix + title)

    encoded = tokenizer(prefixed_titles, truncation=True, max_length=encoder_max_length)

    # Targets are tokenized inside the tokenizer's target context.
    with tokenizer.as_target_tokenizer():
        encoded_targets = tokenizer(batch["intents"], truncation=True, max_length=decoder_max_length)

    encoded["labels"] = encoded_targets["input_ids"]
    return encoded

In [None]:
batch_size = 4

# Tokenize the training split in batches of `batch_size` and drop the raw text
# columns so only the tokenizer outputs and labels remain.
train_data = dataset.map(
    process_data_to_model_inputs,
    remove_columns=["titles", "intents"],
    batch_size=batch_size,
    batched=True,
)

In [None]:
# Have the training dataset return torch tensors when indexed.
train_data.set_format("torch")

In [None]:
# Tokenize the validation split the same way as the training split, dropping
# the raw text columns afterwards.
val_data = vali_dataset.map(
    process_data_to_model_inputs,
    remove_columns=["titles", "intents"],
    batch_size=batch_size,
    batched=True,
)

In [None]:
# Have the validation dataset return torch tensors when indexed.
val_data.set_format("torch")

In [None]:
from transformers import AutoModelForSeq2SeqLM
# Start from the pretrained "t5-base" checkpoint (same name as the tokenizer).
model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")

In [None]:
from transformers import DataCollatorForSeq2Seq
# Collator handed to the trainer to assemble batches; the tokenized examples
# were produced without a padding argument, so presumably padding happens
# here per batch — confirm against the collator's documentation.
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

In [None]:
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

# Training configuration for the seq2seq trainer.
# NOTE(review): output_dir is an absolute, machine-specific path, while the
# inference cells load "t5/checkpoint-3000/" and save_steps is 10000 —
# confirm which run produced that checkpoint.
training_args = Seq2SeqTrainingArguments(
    predict_with_generate=True,  # evaluate on generated sequences, as compute_metrics decodes token ids
    evaluation_strategy="steps",
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    # Mixed precision is only valid on CUDA devices; enabling it unconditionally
    # makes this cell raise on a CPU-only machine.
    fp16=torch.cuda.is_available(),
    output_dir="/home/workshop/dataset/fkd/bertGeneration/t5",
    logging_steps=2500,
    save_steps=10000,
    eval_steps=2500,
    learning_rate=7e-5,
    num_train_epochs=3,
)

In [None]:
rouge = datasets.load_metric("rouge")


def compute_metrics(pred):
    """Compute ROUGE-2 precision/recall/F-measure for an evaluation pass.

    Args:
        pred: object exposing `predictions` (generated token ids, or a tuple
            whose first element holds them) and `label_ids` (reference token
            ids in which ignored positions are marked with -100).

    Returns:
        dict with "rouge2_precision", "rouge2_recall" and "rouge2_fmeasure",
        each rounded to 4 decimals and taken from the `mid` aggregate.
    """
    pad_id = tokenizer.pad_token_id

    pred_ids = pred.predictions
    if isinstance(pred_ids, tuple):
        pred_ids = pred_ids[0]

    # -100 is not a valid vocabulary id, so swap it for the pad token before
    # decoding. np.where builds new arrays, leaving the caller's
    # `pred.label_ids` / `pred.predictions` unmodified.
    pred_ids = np.where(pred_ids == -100, pad_id, pred_ids)
    labels_ids = np.where(pred.label_ids == -100, pad_id, pred.label_ids)

    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)

    rouge_output = rouge.compute(predictions=pred_str, references=label_str, rouge_types=["rouge2"])["rouge2"].mid

    return {
        "rouge2_precision": round(rouge_output.precision, 4),
        "rouge2_recall": round(rouge_output.recall, 4),
        "rouge2_fmeasure": round(rouge_output.fmeasure, 4),
    }

In [None]:
# Wire the model, tokenized splits, collator and metric function into the
# trainer, then start fine-tuning.
trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    tokenizer=tokenizer,
    data_collator=data_collator,
    train_dataset=train_data,
    eval_dataset=val_data,
    compute_metrics=compute_metrics,
)
trainer.train()

In [None]:

# Resolve the device in this cell: `device` is otherwise first assigned in a
# later cell, so a top-to-bottom run on a fresh kernel would raise NameError here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Reload the fine-tuned weights from a saved checkpoint for inference.
model = AutoModelForSeq2SeqLM.from_pretrained("t5/checkpoint-3000/").to(device)

In [None]:
# Reload the tokenizer from the same checkpoint directory as the model,
# replacing the `tokenizer` used during training.
tokenizer = AutoTokenizer.from_pretrained("t5/checkpoint-3000/")

In [None]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def generate_summary(batch):
    """Generate a predicted intent for every title in a batch.

    Args:
        batch: mapping with a "titles" column of raw source texts.

    Returns:
        The same `batch` with an added "pred_summary" column holding the
        decoded generations (special tokens stripped).
    """
    # Use the same task prefix as process_data_to_model_inputs; without it the
    # model is queried with inputs formatted differently from its training data.
    prompts = [prefix + title for title in batch["titles"]]

    # Truncate to the encoder length used for training.
    inputs = tokenizer(prompts, padding="max_length", truncation=True, max_length=encoder_max_length, return_tensors="pt")
    input_ids = inputs.input_ids.to(device)
    attention_mask = inputs.attention_mask.to(device)

    # NOTE(review): no generation length is passed, so the model's configured
    # default applies — confirm it is long enough for targets of up to
    # decoder_max_length tokens.
    outputs = model.generate(input_ids, attention_mask=attention_mask)

    output_str = tokenizer.batch_decode(outputs, skip_special_tokens=True)

    batch["pred_summary"] = output_str

    return batch

In [None]:
batch_size = 4  # change to 64 for full evaluation

# Run generation over the raw validation split; "titles" is dropped, so the
# result keeps "intents" and gains "pred_summary".
results = vali_dataset.map(
    generate_summary,
    remove_columns=["titles"],
    batch_size=batch_size,
    batched=True,
)


In [None]:
# ROUGE-2 between the generated summaries and the reference intents (`mid` aggregate).
rouge_output = rouge.compute(predictions=results["pred_summary"], references=results["intents"], rouge_types=["rouge2"])["rouge2"].mid
# Show the score: a bare assignment displays nothing, and the next cell reuses this name.
rouge_output

In [None]:
# ROUGE-1 between the generated summaries and the reference intents (`mid` aggregate).
rouge_output = rouge.compute(predictions=results["pred_summary"], references=results["intents"], rouge_types=["rouge1"])["rouge1"].mid
# Show the score: a bare assignment displays nothing, and the next cell reuses this name.
rouge_output

In [None]:
# ROUGE-L between the generated summaries and the reference intents (`mid` aggregate).
rouge_output = rouge.compute(predictions=results["pred_summary"], references=results["intents"], rouge_types=["rougeL"])["rougeL"].mid
# Show the score: a bare assignment would display nothing in the notebook.
rouge_output