### Script to generate summaries with a chunking-based approach using a pre-trained BigBird model

Set the `dataset` and `output_path` variables as required before running the notebook.


In [None]:
# Dataset to summarize. Options: "IN" -> IN-Abs, "UK" -> UK-Abs, "N2" -> IN-Ext
dataset = "IN"
# Directory that the generated summaries are written to.
output_path = "./IN_BigBird/"

In [None]:
import pandas as pd
import numpy as np
import glob
import sys
sys.path.insert(0, '../')
from utilities import *
import os
import nltk
import torch
from transformers import BigBirdPegasusForConditionalGeneration, AutoTokenizer

In [None]:
# Create the output directory (and any missing parents). exist_ok=True makes
# this a no-op when it already exists, avoiding the check-then-create race of
# testing os.path.exists() first.
os.makedirs(output_path, exist_ok=True)

In [None]:
# Read the "test" split of the selected dataset: document names, source texts
# and (presumably) reference summaries, then print the size of each list.
names, data_source, data_summary = get_summary_data(dataset, "test")
print(len(names), len(data_source), len(data_summary), sep="\n")
# Per-document required summary length, looked up by document name in the main
# loop. dict_names is kept as an alias of the same dict.
dict_names = get_req_len_dict(dataset, "test")
len_dic = dict_names

In [None]:
# Pick the BigBird-Pegasus checkpoint by suffix (here the "-pubmed" variant).
DATASET_NAME = "pubmed"
MODEL_ID = "google/bigbird-pegasus-large-" + DATASET_NAME
# NOTE(review): CACHE_DIR is defined but never passed to from_pretrained
# (no cache_dir=...), so the default Hugging Face cache is used.
CACHE_DIR = DATASET_NAME
# Hard-coded to the second CUDA device; change if that GPU is not present.
DEVICE = "cuda:1"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = BigBirdPegasusForConditionalGeneration.from_pretrained(MODEL_ID)
model = model.to(DEVICE)


In [None]:
def get_summ(input, max_l, min_l):
    '''
    Generate a summary of one document using a chunking-based approach.

    The document is split into chunks by ``nest_sentences`` (from
    ``utilities``; called with a limit of 4096). Each chunk is summarized
    independently by the global BigBird-Pegasus ``model`` on ``DEVICE``
    (beam search, 5 beams). The chunk summaries are joined with ``'. '``.

    input:  input - document text to summarize
            max_l - maximum generation length (in tokens) applied to EACH chunk
            min_l - minimum generation length (in tokens) applied to EACH chunk;
                    negative values are treated as 0
    output: document summary (str)
    '''
    # generate() validates lengths as plain non-negative ints. The caller
    # passes target - 100, which is negative for short targets, and the
    # lengths may come from a lookup table holding non-Python ints.
    min_len = max(int(min_l), 0)
    # Keep the constraints feasible: never cap below the requested minimum.
    max_len = max(int(max_l), min_len)

    summs = []
    for chunk in nest_sentences(input, 4096):
        # Pad/truncate every chunk to the 4096-token encoder window.
        inputs_dict = tokenizer(chunk, padding="max_length", max_length=4096,
                                return_tensors="pt", truncation=True)
        inputs_dict = {k: v.to(DEVICE) for k, v in inputs_dict.items()}
        # max_length is passed explicitly. Previously max_l was ignored and the
        # model's generation-config default silently capped the output.
        predicted_abstract_ids = model.generate(
            **inputs_dict, min_length=min_len, max_length=max_len, num_beams=5)
        summs.append(tokenizer.decode(predicted_abstract_ids[0],
                                      skip_special_tokens=True))
    return '. '.join(summs)
    

In [None]:
# main loop to generate and save summaries of each document in the test dataset
result = []
for i, doc in enumerate(data_source):
    name = names[i]
    print(str(i) + " : " + name)
    # Required length for this document; the minimum is set 100 below it.
    req_len = len_dic[name]
    summ = get_summ(doc, req_len, req_len - 100)
    result.append(summ)
    # os.path.join works whether or not output_path ends with a separator.
    # The `with` block closes the file even if write() raises, and the
    # explicit encoding avoids platform-dependent defaults for non-ASCII text.
    path = os.path.join(output_path, name)
    with open(path, 'w', encoding='utf-8') as file:
        file.write(summ)
print(result)