In [5]:
import utils
import os
import pandas as pd
import random

In [6]:
# Constants
TLDR = ' TL;DR '  # separator placed between the article text and its headline
MAX_LEN = 512  # target length, in characters, of each 'article TL;DR headline' string
NUM_ELEMENTS = 50000  # upper bound on the number of cleaned datapoints that are kept
BATCHES = 2  # NOTE(review): not used in the cells visible here; presumably a training setting -- confirm
SAVE_MODEL_PATH = '../trained_models/gpt2-summarization-gpu'  # NOTE(review): not used in the cells visible here -- confirm
DATA_PATH = "../data/cleaned_data/"  # passed to utils.load_article_data; utils.clean_data() runs when this path is missing

In [7]:
# Generate the cleaned data on first use, i.e. when DATA_PATH does not exist yet.
if not os.path.exists(DATA_PATH):
    utils.clean_data()

all_articles_dict = utils.load_article_data(path=DATA_PATH)

# Leave two of the loaded sources out of the merged frame.
# NOTE(review): the reason for excluding these files is not visible here -- confirm.
del all_articles_dict['clean_Articles.csv'], all_articles_dict['clean_CNN_Articels_clean.csv']

# Stack the remaining frames into a single DataFrame.
all_articles_df = pd.concat(list(all_articles_dict.values()))

In [8]:
def strip_nonalnum(word):
    """
    Remove leading and trailing non-alphanumeric characters from a string.

    Characters between the first and the last alphanumeric character are kept as they are.
    Falsy input (e.g. '' or None) is returned unchanged. A string that contains no
    alphanumeric character at all yields '' regardless of its length.
    """
    if not word:
        return word  # nothing to strip
    # Position of the first alphanumeric character, or None when the string has none.
    start = next((i for i, c in enumerate(word) if c.isalnum()), None)
    if start is None:
        # Only punctuation/whitespace: everything is stripped.
        return ''
    # Scan from the right; a hit is guaranteed because `start` exists.
    end = next(i for i in range(len(word) - 1, -1, -1) if word[i].isalnum())
    return word[start:end + 1]

def clean_datapoint(datapoint):
    """
    Given a line from the cleaned data, build a string of the format
    'article TL;DR headline' without any starting or trailing non-alphanumeric characters.

    datapoint[0] is used as the headline and datapoint[1] as the article. The newspaper
    titles ' - The New York Times' and ' - Breitbart' are removed from the headline.
    """
    article = strip_nonalnum(datapoint[1])
    # Remove the newspaper titles BEFORE stripping: dropping a trailing title can expose
    # punctuation (e.g. 'Foo? - The New York Times' -> 'Foo?') that must be stripped as well.
    headline = datapoint[0].replace(' - The New York Times', '').replace(' - Breitbart', '')
    # Use the shared separator constant so this stays in sync with pad_and_truncate_data.
    return article + TLDR + strip_nonalnum(headline)

def pad_and_truncate_data(dataset, max_len=None, separator=None):
    """
    Format data to always contain the separator and the entire headline, truncating the
    article so that the whole string is at most max_len characters long.

    No padding is added: strings shorter than max_len stay shorter. If the headline alone
    does not fit into the budget, the article is dropped entirely and the result is
    separator + headline (which may then exceed max_len).

    Parameters:
        dataset: iterable of strings of the form 'article<separator>headline'.
        max_len: target length in characters; defaults to the notebook constant MAX_LEN.
        separator: string between article and headline; defaults to the notebook constant TLDR.

    Returns:
        A list of formatted strings. Entries that are not strings, or that do not contain
        the separator exactly once, are skipped.
    """
    if max_len is None:
        max_len = MAX_LEN
    if separator is None:
        separator = TLDR
    article_budget = max_len - len(separator)
    result = []
    for d in dataset:
        if not isinstance(d, str):
            continue  # best effort: ignore malformed entries instead of failing
        parts = d.split(separator)
        if len(parts) != 2:
            continue  # separator missing or ambiguous (it also occurs inside the text)
        article, headline = parts
        # Clamp at 0: a negative stop index would slice from the END of the article and
        # keep almost all of it instead of none.
        keep = max(0, article_budget - len(headline))
        result.append(article[:keep] + separator + headline)
    return result


# Clean each element of data and format by: article TL;DR headline
# Rows are accessed by position: index 0 is the headline, index 1 the article (see clean_datapoint).
article_rows = all_articles_df.values.tolist()

# Keep only rows whose two fields are both strings, and apply the NUM_ELEMENTS cap BEFORE
# cleaning so that at most NUM_ELEMENTS rows are processed instead of the whole dataset.
valid_rows = [row for row in article_rows if isinstance(row[0], str) and isinstance(row[1], str)][:NUM_ELEMENTS]
cleaned_articles = [clean_datapoint(row) for row in valid_rows]

# Truncate data to specific length (the headline is always kept in full)
all_articles = pad_and_truncate_data(cleaned_articles)
print(f'Example: {all_articles[0]}')

Example: WASHINGTON  —   Congressional Republicans have a new fear when it comes to their    health care lawsuit against the Obama administration: They might win. The incoming Trump administration could choose to no longer defend the executive branch against the suit, which challenges the administration’s authority to spend billions of dollars on health insurance subsidies for   and   Americans, handing House Republicans a big victory on    issues. Bu TL;DR House Republicans Fret About Winning Their Health Care Suit


In [None]:
# Write data to files to be loaded into a dataset
TRAIN_SPLIT = 0.9

# Shuffle a copy rather than all_articles itself: an in-place shuffle would, on every re-run
# of this cell, reshuffle an already shuffled list and therefore produce a different split.
# With a copy the cell is idempotent and the seeded split is reproducible.
random.seed(11)
shuffled_articles = list(all_articles)
random.shuffle(shuffled_articles)
END_IDX = int(len(shuffled_articles) * TRAIN_SPLIT)

# One example per line; the first TRAIN_SPLIT fraction goes to the training file, the rest to the test file.
for out_path, split_lines in (("../data/train_data.txt", shuffled_articles[:END_IDX]),
                              ("../data/test_data.txt", shuffled_articles[END_IDX:])):
    with open(out_path, "w", encoding='utf-8') as txt_file:
        txt_file.writelines(line + "\n" for line in split_lines)