In [None]:
#!/usr/bin/env python
# coding=utf-8
# import ipdb
import argparse
import os
import random
import math
import datasets
import evaluate
import torch
import nltk
import numpy as np

from accelerate import Accelerator
from accelerate.utils import set_seed
from accelerate import notebook_launcher
from datasets import load_dataset
from datasets import DatasetDict
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

import transformers
from transformers import (
    CONFIG_MAPPING,
    MODEL_MAPPING,
    AutoConfig,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    SchedulerType,
    get_scheduler,
)

In [None]:
# Notebook-wide switches
# Argument list handed to main(); None means no explicit list is supplied.
str_args = None
# When True, the entry-point cell starts main() through notebook_launcher.
USE_NOTEBOOK_LAUNCHER = False

In [None]:
# Default arguments for notebook runs.
# Disable this cell when the code is executed as a .py file.
str_args = ["--test_file", "./data/public.jsonl", "--output_dir", "./output"]

In [None]:
def parse_args(str_args = None):
    """Build the argument parser for the inference script and parse arguments.

    Args:
        str_args: Optional list of command-line style strings. When ``None``,
            argparse falls back to the process command line.

    Returns:
        argparse.Namespace with the attributes ``seed``, ``test_file``,
        ``output_dir``, ``model_name_or_path``, ``gradient_accumulation_steps``,
        ``batch_size``, ``num_beams``, ``source_prefix``, ``max_source_length``,
        ``max_target_length`` and ``preprocessing_num_workers``.
    """
    parser = argparse.ArgumentParser()
    # Data
    parser.add_argument("--seed", type=int, default=None)
    parser.add_argument("--test_file", type=str, required=True)
    parser.add_argument(
        "--output_dir",
        type=str,
        default="./output"
    )
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        default="google/mt5-small"
    )
    # Training Parameters
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=8,
    )
    # Generation
    # main() reads ``args.num_beams`` when building the generation kwargs, so
    # the option has to exist on the namespace; 1 keeps plain greedy decoding.
    parser.add_argument(
        "--num_beams",
        type=int,
        default=1,
    )
    # Preprocessing
    parser.add_argument(
        "--source_prefix",
        type=str,
        default=None,
    )
    parser.add_argument(
        "--max_source_length",
        type=int,
        default=1024,
    )
    parser.add_argument(
        "--max_target_length",
        type=int,
        default=128,
    )
    parser.add_argument(
        "--preprocessing_num_workers",
        type=int,
        default=None,
    )

    args = parser.parse_args(str_args)
    return args

In [None]:
def main(str_args = None):
    """Generate summaries for the test file and score them with ROUGE.

    Args:
        str_args: Optional list of command-line style strings forwarded to
            ``parse_args``.

    Returns:
        The score dictionary returned by the ``evaluate`` "rouge" metric for
        the decoded predictions against the decoded references.

    Raises:
        ValueError: If the loaded model has no ``decoder_start_token_id``.
    """
    args = parse_args(str_args)

    # Initialize accelerator
    accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps)

    # Seed only when a seed was explicitly requested
    if args.seed is not None:
        set_seed(args.seed)

    # Only the main process creates the output directory; the others wait.
    if accelerator.is_main_process and args.output_dir is not None:
        os.makedirs(args.output_dir, exist_ok=True)
    accelerator.wait_for_everyone()

    # Load Dataset (only a "test" split is registered)
    data_files = {"test": args.test_file}
    raw_datasets = load_dataset("json", data_files=data_files)

    # Load Model
    config = AutoConfig.from_pretrained(args.model_name_or_path)
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True)
    model = AutoModelForSeq2SeqLM.from_pretrained(
        args.model_name_or_path,
        config=config
    )

    # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
    # on a small vocab and want a smaller embedding size, remove this test.
    embedding_size = model.get_input_embeddings().weight.shape[0]
    if len(tokenizer) > embedding_size:
        model.resize_token_embeddings(len(tokenizer))
    if model.config.decoder_start_token_id is None:
        raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")

    prefix = args.source_prefix if args.source_prefix is not None else ""

    # Preprocessing the datasets.
    # The column list must come from the "test" split: it is the only split
    # loaded above, so indexing "train" would raise a KeyError.
    column_names = raw_datasets["test"].column_names
    text_column = 'maintext'
    summary_column = 'title'

    max_target_length = args.max_target_length
    padding = False  # padding is left to the data collator

    def preprocess_function(examples):
        """Tokenize a batch of source texts and reference summaries."""
        inputs = examples[text_column]
        targets = examples[summary_column]
        inputs = [prefix + inp for inp in inputs]
        model_inputs = tokenizer(inputs, max_length=args.max_source_length, padding=padding, truncation=True)

        # Tokenize targets with the `text_target` keyword argument
        labels = tokenizer(text_target=targets, max_length=max_target_length, padding=padding, truncation=True)

        model_inputs["labels"] = labels["input_ids"]
        return model_inputs

    with accelerator.main_process_first():
        test_dataset = raw_datasets["test"].map(
            preprocess_function,
            batched=True,
            num_proc=args.preprocessing_num_workers,
            remove_columns=column_names,
        )

    # Data Collator
    label_pad_token_id = -100
    data_collator = DataCollatorForSeq2Seq(
        tokenizer,
        model=model,
        label_pad_token_id=label_pad_token_id,
        pad_to_multiple_of=None
    )

    # Postprocessing the predictions
    def postprocess_text(preds, labels):
        """Strip whitespace and put each sentence on its own line."""
        preds = [pred.strip() for pred in preds]
        labels = [label.strip() for label in labels]

        # rougeLSum expects newline after each sentence
        # NOTE(review): nltk.sent_tokenize needs NLTK's sentence tokenizer
        # data to be available locally - confirm it is installed.
        preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
        labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]

        return preds, labels

    # Data Loader
    test_dataloader = DataLoader(test_dataset, collate_fn=data_collator, batch_size=args.batch_size)

    # Prepare everything with our `accelerator`.
    model, test_dataloader = accelerator.prepare(
        model, test_dataloader
    )

    # Load the metric before generating so a missing metric dependency fails fast.
    # The original code called `get_rouge`, which is not defined or imported in
    # this file; the already-imported `evaluate` package is used instead.
    # NOTE(review): the default ROUGE tokenizer may be a poor fit for text
    # that is not whitespace-delimited - confirm against the dataset language.
    metric = evaluate.load("rouge")

    # Evaluation
    model.eval()
    preds = []
    refs = []
    gen_kwargs = {"max_length": args.max_target_length}
    # Tolerate a parser without `--num_beams`; when it is absent or None the
    # model's own generation default applies.
    num_beams = getattr(args, "num_beams", None)
    if num_beams is not None:
        gen_kwargs["num_beams"] = num_beams

    progress = tqdm(test_dataloader, disable=not accelerator.is_local_main_process)
    for batch in progress:
        with torch.no_grad():
            generated_tokens = accelerator.unwrap_model(model).generate(
                batch["input_ids"],
                attention_mask=batch["attention_mask"],
                **gen_kwargs,
            )

            # Sequences can differ in length between processes; pad before gathering.
            generated_tokens = accelerator.pad_across_processes(
                generated_tokens, dim=1, pad_index=tokenizer.pad_token_id
            )
            labels = accelerator.pad_across_processes(batch["labels"], dim=1, pad_index=tokenizer.pad_token_id)

            generated_tokens, labels = accelerator.gather_for_metrics((generated_tokens, labels))
            generated_tokens = generated_tokens.cpu().numpy()
            labels = labels.cpu().numpy()

            # -100 marks ignored label positions and cannot be decoded.
            labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
            decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
            decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

            decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
            preds += decoded_preds
            refs += decoded_labels

    result = metric.compute(predictions=preds, references=refs)
    accelerator.print(result)

    if args.output_dir is not None:
        accelerator.wait_for_everyone()
        # Save predictions
        # TODO: predictions are not written to args.output_dir yet.

    return result


In [None]:
# Entry point: run main() directly unless the notebook launcher is requested.
if __name__ == "__main__":
    if not USE_NOTEBOOK_LAUNCHER:
        main(str_args)
    else:
        notebook_launcher(main, (str_args,), num_processes=1)