In [None]:

%%capture
!pip install datasets
!pip install transformers
%%capture
!rm seq2seq_trainer.py
!wget https://raw.githubusercontent.com/huggingface/transformers/master/examples/seq2seq/seq2seq_trainer.py

!pip install git-python==1.0.3
!pip install sacrebleu==1.4.12
!pip install rouge_score

from transformers import Seq2SeqTrainer
from transformers import TrainingArguments
from dataclasses import dataclass, field
from typing import Optional


In [None]:

import datasets
import transformers
from transformers import BertTokenizerFast
from transformers import EncoderDecoderModel
from transformers import Seq2SeqTrainer
from transformers import TrainingArguments
from dataclasses import dataclass, field
from typing import Optional


In [None]:
# Tokenizer shared by the preprocessing, metric and generation cells.
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
# Reuse BERT's [CLS] / [SEP] tokens in the BOS / EOS roles.
tokenizer.bos_token = tokenizer.cls_token
tokenizer.eos_token = tokenizer.sep_token

# arXiv configuration of the `scientific_papers` dataset.
data = datasets.load_dataset('scientific_papers', 'arxiv')

# Untouched splits; later cells derive their working copies from these.
train_data_full = data['train']
val_data_full = data['validation']
test_data_full = data['test']

train_data = train_data_full
val_data = val_data_full
test_data = test_data_full

# Batch size used for dataset.map() and for the per-device train/eval batches.
batch_size = 64


In [None]:
def process_data_to_model_inputs(batch, encoder_max_length=512, decoder_max_length=128):
  """Tokenize a batch of articles/abstracts into encoder-decoder model inputs.

  Args:
    batch: mapping with "article" and "abstract" entries (lists of strings),
      as passed by ``datasets.Dataset.map(..., batched=True)``.
    encoder_max_length: length every article is truncated/padded to.
    decoder_max_length: length every abstract is truncated/padded to.

  Returns:
    The same ``batch`` mapping with ``input_ids``, ``attention_mask``,
    ``decoder_input_ids``, ``decoder_attention_mask`` and ``labels`` added.

  The two length limits used to be read from module-level names that this
  notebook never defines; they are now keyword parameters with defaults, so
  the function works when called with ``batch`` alone.
  """
  # TODO: switch to my_custom_tokenizer (sentence-level [SEP] markers) for fine-tuning.
  inputs = tokenizer(batch["article"], padding="max_length", truncation=True, max_length=encoder_max_length)
  outputs = tokenizer(batch["abstract"], padding="max_length", truncation=True, max_length=decoder_max_length)

  batch["input_ids"] = inputs["input_ids"]
  batch["attention_mask"] = inputs["attention_mask"]
  # NOTE(review): decoder_input_ids and labels carry the same, unshifted token
  # ids; whether the model shifts them internally depends on the installed
  # transformers version -- confirm.
  batch["decoder_input_ids"] = outputs["input_ids"]
  batch["decoder_attention_mask"] = outputs["attention_mask"]

  # Labels are a fresh copy of the target ids with every padding position
  # replaced by -100, the sentinel compute_metrics() later maps back to padding.
  pad_id = tokenizer.pad_token_id
  batch["labels"] = [
      [-100 if token == pad_id else token for token in label_ids]
      for label_ids in outputs["input_ids"]
  ]

  return batch


In [None]:
# Tokenize document into sentences
def my_custom_tokenizer(batch, max_length):
  """Tokenize documents sentence by sentence, closing every sentence with [SEP].

  Each document becomes ``[CLS] sent_1 [SEP] sent_2 [SEP] ...``, truncated and
  then padded so that every row is exactly ``max_length`` ids long.

  Args:
    batch: iterable of documents. Each element is either one document (a
      string) or an iterable of document strings.
    max_length: final length of every returned row, [CLS] included.

  Returns:
    dict with ``input_ids`` and ``attention_mask``: one list of ``max_length``
    ints per document.

  Fixes over the previous version:
    * a string element is treated as one document instead of being iterated
      character by character (which produced one row per character);
    * [CLS] is counted inside ``max_length`` -- it used to be prepended after
      truncation/padding, giving rows of ``max_length + 1``;
    * ``nltk`` is imported here, since the notebook never imports it.
  """
  # Local import keeps this cell self-contained.
  # NOTE: nltk.sent_tokenize needs the NLTK "punkt" tokenizer data to be installed.
  import nltk

  ret = {'input_ids': [], 'attention_mask': []}

  for entry in batch:
    # Accept both a plain document string and a nested list of documents.
    documents = [entry] if isinstance(entry, str) else entry
    for doc in documents:
      # [CLS] opens the document and is part of the length budget.
      input_ids = [tokenizer.cls_token_id]

      for sentence in nltk.sent_tokenize(doc):
        sentence_tokens = tokenizer.tokenize(sentence)
        input_ids += tokenizer.convert_tokens_to_ids(sentence_tokens)
        # Sentence boundary marker.
        input_ids.append(tokenizer.sep_token_id)

      # Truncate to the fixed length; every surviving position is a real token.
      input_ids = input_ids[:max_length]
      attention_masks = [1] * len(input_ids)

      # Right-pad short documents; padded positions get attention 0.
      padding_length = max_length - len(input_ids)
      input_ids += [tokenizer.pad_token_id] * padding_length
      attention_masks += [0] * padding_length

      ret['input_ids'].append(input_ids)
      ret['attention_mask'].append(attention_masks)

  return ret

In [None]:
# Columns handed to the model as torch tensors, and raw columns dropped after tokenization.
_MODEL_COLUMNS = ["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask", "labels"]
_RAW_COLUMNS = ["article", "abstract", "section_names"]


def _prepare_split(split):
    """Tokenize one dataset split and expose the model columns as torch tensors."""
    prepared = split.map(
        process_data_to_model_inputs,
        batched=True,
        batch_size=batch_size,
        remove_columns=_RAW_COLUMNS,
    )
    prepared.set_format(type="torch", columns=_MODEL_COLUMNS)
    return prepared


# Always start from the untouched splits so the cell can be re-run safely.
train_data = train_data_full
val_data = val_data_full

train_data = _prepare_split(train_data)
val_data = _prepare_split(val_data)

In [None]:
# Start from the published BERT2BERT checkpoint.
bert2bert = EncoderDecoderModel.from_pretrained('patrickvonplaten/bert2bert_cnn_daily_mail')

# Freeze the first 250 parameter tensors, in the order `parameters()` yields
# them; everything after that stays trainable.
_frozen_params = list(bert2bert.parameters())[:250]
for param in _frozen_params:
  param.requires_grad = False
# Number of tensors actually frozen (fewer than 250 only if the model is smaller).
count = len(_frozen_params)


In [None]:
@dataclass
class Seq2SeqTrainingArguments(TrainingArguments):
    """``TrainingArguments`` extended with sequence-to-sequence options.

    Each additional field documents itself through ``metadata["help"]``.
    """

    # --- loss / sampling ---
    label_smoothing: Optional[float] = field(
        default=0.0,
        metadata={"help": "The label smoothing epsilon to apply (if not zero)."},
    )
    sortish_sampler: bool = field(
        default=False,
        metadata={"help": "Whether to SortishSamler or not."},
    )

    # --- evaluation / optimisation ---
    predict_with_generate: bool = field(
        default=False,
        metadata={"help": "Whether to use generate to calculate generative metrics (ROUGE, BLEU)."},
    )
    adafactor: bool = field(
        default=False,
        metadata={"help": "whether to use adafactor"},
    )

    # --- regularisation overrides (None leaves the value unset here) ---
    encoder_layerdrop: Optional[float] = field(
        default=None,
        metadata={"help": "Encoder layer dropout probability. Goes into model.config."},
    )
    decoder_layerdrop: Optional[float] = field(
        default=None,
        metadata={"help": "Decoder layer dropout probability. Goes into model.config."},
    )
    dropout: Optional[float] = field(
        default=None,
        metadata={"help": "Dropout probability. Goes into model.config."},
    )
    attention_dropout: Optional[float] = field(
        default=None,
        metadata={"help": "Attention dropout probability. Goes into model.config."},
    )

    # --- scheduling / generation / evaluation switches ---
    lr_scheduler: Optional[str] = field(
        default="linear",
        metadata={"help": "Which lr scheduler to use."},
    )
    generation_config: Optional[str] = field(
        default=None,
        metadata={"help": "Goes into model.config"},
    )
    evaluate_during_training: bool = field(
        default=True,
        metadata={"help": "evaluate during training"},
    )

# ROUGE metric object; compute_metrics() calls rouge.compute() on the decoded text.
rouge = datasets.load_metric("rouge")


In [None]:
def compute_metrics(pred):
    """Compute ROUGE-2 precision/recall/F-measure for a batch of predictions.

    Args:
        pred: object exposing ``predictions`` (generated token ids) and
            ``label_ids`` (reference token ids, with -100 marking ignored
            positions) as array-like objects that support ``.copy()`` and
            boolean-mask assignment (e.g. numpy arrays).

    Returns:
        dict with ``rouge2_precision``, ``rouge2_recall`` and
        ``rouge2_fmeasure``, each rounded to 4 decimals (the ``mid`` aggregate
        returned by the rouge metric).
    """
    pad_id = tokenizer.pad_token_id

    # Work on copies: the previous version overwrote pred.label_ids in place,
    # silently changing the caller's data.
    labels_ids = pred.label_ids.copy()
    pred_ids = pred.predictions.copy()

    # -100 is not a vocabulary id and cannot be decoded; map it to padding,
    # which skip_special_tokens then drops.
    labels_ids[labels_ids == -100] = pad_id
    # NOTE(review): predictions may also contain -100 depending on how the
    # trainer pads/concatenates batches -- harmless if they never do.
    pred_ids[pred_ids == -100] = pad_id

    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)

    rouge_output = rouge.compute(predictions=pred_str, references=label_str, rouge_types=["rouge2"])["rouge2"].mid

    return {
        "rouge2_precision": round(rouge_output.precision, 4),
        "rouge2_recall": round(rouge_output.recall, 4),
        "rouge2_fmeasure": round(rouge_output.fmeasure, 4),
    }

In [None]:
# Hyper-parameters for the fine-tuning run, collected in one place.
# NOTE(review): `evaluate_during_training` is a field declared by this
# notebook's Seq2SeqTrainingArguments subclass -- confirm the installed
# trainer actually acts on it (otherwise eval_steps has no effect).
_training_kwargs = dict(
    output_dir="./",
    overwrite_output_dir=True,
    do_train=True,
    do_eval=True,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    predict_with_generate=True,
    evaluate_during_training=True,
    fp16=True,
    # Step-based schedule: logging / checkpointing / evaluation / LR warm-up.
    logging_steps=1000,
    save_steps=500,
    eval_steps=8000,
    warmup_steps=2000,
    # max_steps=16,  # left disabled, as in the original cell
    save_total_limit=3,
)
training_args = Seq2SeqTrainingArguments(**_training_kwargs)

# Trainer wiring: model, arguments, tokenized splits and the ROUGE-2 metric function.
trainer = Seq2SeqTrainer(
    model=bert2bert,
    args=training_args,
    train_dataset=train_data,
    eval_dataset=val_data,
    compute_metrics=compute_metrics,
)

In [None]:
# Fine-tune bert2bert with the trainer built in the previous cell.
trainer.train()

In [None]:
# Use the model object fine-tuned above for inference on the GPU.
model = bert2bert
model.to("cuda")
# Switch to evaluation mode so dropout is disabled while generating summaries;
# after a training run the module may still be in training mode.
model.eval()

test_data = data['test']

# NOTE: this rebinds the `batch_size` used by the training cells; re-running
# those cells after this one would pick up the smaller value.
batch_size = 5

In [None]:
def generate_summary(batch, max_length=512):
    """Generate a summary for every article in ``batch``.

    Args:
        batch: mapping with an "article" entry (list of strings), as passed by
            ``datasets.Dataset.map(..., batched=True)``.
        max_length: length every article is truncated/padded to before being
            fed to the encoder (default 512, the value previously hard-coded).

    Returns:
        The same ``batch`` mapping with a ``pred`` entry holding the decoded
        summaries (special tokens stripped).
    """
    inputs = tokenizer(batch["article"], padding="max_length", truncation=True, max_length=max_length, return_tensors="pt")

    # Follow the model's own device instead of hard-coding "cuda", so the
    # inputs always end up where the weights are.
    device = model.device
    input_ids = inputs.input_ids.to(device)
    attention_mask = inputs.attention_mask.to(device)

    outputs = model.generate(input_ids, attention_mask=attention_mask)

    batch["pred"] = tokenizer.batch_decode(outputs, skip_special_tokens=True)

    return batch

# Run generate_summary over the test split in batches of `batch_size`; the
# "article" column is dropped from the result, the new "pred" column is added.
results = test_data.map(generate_summary, batched=True, batch_size=batch_size, remove_columns=["article"])

# Generated summaries and the reference abstracts they are compared against.
pred_str = results["pred"]
label_str = results["abstract"]