In [None]:
!pip install simpletransformers

In [None]:
# Mount Google Drive at the absolute path /content/drive so later cells can
# reference files under that directory.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import pandas
from simpletransformers.seq2seq import Seq2SeqModel,Seq2SeqArgs

In [None]:
# Training configuration: start from the library's default Seq2SeqArgs and
# override only the options listed in the mapping below.
m_args = Seq2SeqArgs()
_arg_overrides = {
    "num_train_epochs": 100,
    "no_save": True,
    "evaluate_generated_text": True,
    "evaluate_during_training": True,
    "evaluate_during_training_verbose": True,
    "overwrite_output_dir": True,
}
for _option_name, _option_value in _arg_overrides.items():
    setattr(m_args, _option_name, _option_value)

# Build the BART encoder-decoder from the pretrained "facebook/bart-large"
# checkpoint using the arguments above; use_cuda=True requests GPU execution.
model_obj = Seq2SeqModel(
    encoder_decoder_type="bart",
    encoder_decoder_name="facebook/bart-large",
    args=m_args,
    use_cuda=True,
)

In [None]:
import re

def load_dataset(file_path, encoding_standard):
    """Read a CSV of article/summary pairs and return its two text columns.

    Args:
        file_path: Path or URL handed straight to ``pandas.read_csv``.
        encoding_standard: Text encoding used to decode the CSV.

    Returns:
        A ``(base_text, summary_text)`` tuple of plain Python lists taken from
        the ``"base_text"`` and ``"summary_text"`` columns, in row order.

    Raises:
        KeyError: If either expected column is missing from the CSV.
    """
    frame = pandas.read_csv(file_path, encoding=encoding_standard)
    # Pull the columns in a fixed order so the tuple is always (base, summary).
    return tuple(frame[column].tolist() for column in ("base_text", "summary_text"))


# Ordered (compiled pattern, replacement) pairs applied to every text before
# URL handling. Order matters: later rules see the output of earlier ones.
# The text is lower-cased before these run, so patterns and replacements are
# written in lower case.
_PRE_URL_RULES = tuple(
    (re.compile(pattern), replacement)
    for pattern, replacement in (
        (r"\t", " "),                                  # tab characters
        (r"\r", " "),                                  # carriage returns
        (r"\n", " "),                                  # newlines
        (r"__+", " "),                                 # runs of 2+ underscores
        (r"--+", " "),                                 # runs of 2+ hyphens
        (r"~~+", " "),                                 # runs of 2+ tildes
        (r"\+\++", " "),                               # runs of 2+ plus signs
        (r"\.\.+", " "),                               # runs of 2+ periods
        (r"[<>()|&©ø\[\]\'\",;?~*!]", " "),            # individual special characters
        (r"mailto:", " "),                             # mailto: prefix from scraped HTML
        (r"\\x9\d", " "),                              # literal "\x9<digit>" escape text
        (r"inc\d+", "inc_num"),                        # incident numbers -> placeholder
        (r"cm\d+|chg\d+", "cm_num"),                   # CM / CHG numbers -> placeholder
        (r"\.\s+", " "),                               # period followed by whitespace
        (r"-\s+", " "),                                # hyphen followed by whitespace
        (r":\s+", " "),                                # colon followed by whitespace
        (r"\s+.\s+", " "),                             # lone character between whitespace
    )
)

# Scheme, optional slashes, host (captured), then the remainder of the URL up
# to the next whitespace character.
_URL_PATTERN = re.compile(r"https*:/*([^/\s]+)\S*")
_WHITESPACE_RUN = re.compile(r"\s+")
_LONE_CHARACTER = re.compile(r"\s+.\s+")


def _clean_text(raw_value):
    """Normalise a single text value; see ``clean_dataset`` for the rules."""
    text = str(raw_value).lower()

    for pattern, replacement in _PRE_URL_RULES:
        text = pattern.sub(replacement, text)

    # Reduce every URL to its own host. The "\1" template substitutes each
    # match's captured host, so texts with several URLs keep each host, and a
    # text without any URL is simply left unchanged (no exception handling
    # needed).
    text = _URL_PATTERN.sub(r"\1", text)

    # Collapse whitespace runs, then drop lone characters once more because the
    # previous steps can expose new ones.
    text = _WHITESPACE_RUN.sub(" ", text)
    text = _LONE_CHARACTER.sub(" ", text)
    return text


def clean_dataset(text_set):
    """Return a cleaned, lower-cased copy of every entry in ``text_set``.

    Each entry is converted with ``str()`` (so non-string values such as
    ``None`` or NaN become the strings ``"none"`` / ``"nan"``), lower-cased,
    and then passed through these steps in order:

    1. tabs, carriage returns and newlines become spaces;
    2. runs of two or more ``_``, ``-``, ``~``, ``+`` or ``.`` become a space;
    3. the characters ``<>()|&©ø[]'",;?~*!`` become spaces;
    4. ``mailto:`` and literal ``\\x9<digit>`` sequences become spaces;
    5. ``inc<digits>`` becomes ``inc_num``; ``cm<digits>`` / ``chg<digits>``
       become ``cm_num``;
    6. a period, hyphen or colon followed by whitespace becomes a space;
    7. a single character surrounded by whitespace becomes a space;
    8. every ``http(s)`` URL is replaced by its host name;
    9. whitespace runs collapse to one space and step 7 is repeated.

    Leading/trailing whitespace is not stripped.

    Args:
        text_set: Iterable of values to clean (typically a list of strings).

    Returns:
        A new list of cleaned strings, in the same order as ``text_set``.
    """
    return [_clean_text(entry) for entry in text_set]

In [None]:
# Download the paired article/summary CSV, then run both columns through the
# same cleaning routine.
_DATASET_URL = "https://raw.githubusercontent.com/NeelOommen/News-Article-Text-Summarizer/main/dataset/batch1.csv"
raw_base_data_set, raw_summary_data_set = load_dataset(
    file_path=_DATASET_URL,
    encoding_standard="iso-8859-1",
)
cleaned_base_set = clean_dataset(raw_base_data_set)
cleaned_summary_data_set = clean_dataset(raw_summary_data_set)


In [None]:
# Total number of cleaned article texts; the cells below use this count to
# size the training split.
num_data_pairs = len(cleaned_base_set)
# Debug aid: uncomment to cap the run at a handful of pairs.
#num_data_pairs = 10
# Display the count as the cell output.
num_data_pairs

In [None]:
import math

# 90/10 train/evaluation split. Both sizes are derived from num_data_pairs so
# that capping num_data_pairs (e.g. for a quick debug run) keeps the two
# splits consistent without any further hand edits.
training_data_set_size = math.floor(num_data_pairs * 0.9)
eval_data_set_size = num_data_pairs - training_data_set_size

# Only the last expression of a cell is displayed, so show both sizes together.
training_data_set_size, eval_data_set_size

In [None]:
# Training pairs: the first `training_data_set_size` rows, each stored as
# [input text, target summary].
training_data_set = [
    [cleaned_base_set[row], cleaned_summary_data_set[row]]
    for row in range(training_data_set_size)
]

In [None]:
# Evaluation pairs: the `eval_data_set_size` rows that follow the training
# rows, each stored as [input text, target summary].
evaluation_data_set = [
    [cleaned_base_set[row], cleaned_summary_data_set[row]]
    for row in range(training_data_set_size, training_data_set_size + eval_data_set_size)
]

In [None]:
# Both frames share the same two-column layout: the source text first, the
# reference summary second.
_PAIR_COLUMNS = ["input_text", "target_text"]

# Training frame
training_data_frame = pandas.DataFrame(training_data_set, columns=_PAIR_COLUMNS)

# Evaluation frame
evaluation_data_frame = pandas.DataFrame(evaluation_data_set, columns=_PAIR_COLUMNS)

In [None]:
#training the model
# output_dir is absolute: Drive is mounted at '/content/drive', so a relative
# "content/drive/..." path would resolve against the current working directory
# and never reach the mounted Drive.
#
# No extra keyword arguments are passed: per the Simple Transformers docs,
# train_model treats additional keyword arguments as metric functions
# (name -> callable(labels, preds)) to run during evaluation, so values such
# as `save_total_limit=5` or `load_best_model_at_end=True` are not valid here.
# NOTE(review): m_args.no_save is set to True earlier in the notebook; confirm
# whether model checkpoints are meant to be written to this directory.
model_obj.train_model(
    training_data_frame,
    eval_data=evaluation_data_frame,
    output_dir="/content/drive/My Drive/colab_model_data/",
)