# simpleT5: Generating Title

**simpleT5** is built on top of PyTorch Lightning and Transformers.

## Dataset Preparation

**Load the dataset into a pandas DataFrame**

In [None]:
import re
import json
import pandas as pd
from tqdm import tqdm

# Location of the arXiv metadata snapshot read by get_metadata().
data_file = '../input/arxiv/arxiv-metadata-oai-snapshot.json'


def get_metadata(path=None):
    """Lazily yield the raw lines of the arXiv metadata snapshot.

    The file is streamed line by line with ``yield`` instead of being loaded
    in one go, to prevent Python memory issues with a large JSON file.

    Args:
        path: File to read. Defaults to the module-level ``data_file``.

    Yields:
        Each line of the file as a ``str``, trailing newline included.
    """
    if path is None:
        path = data_file
    # Decode explicitly as UTF-8 (the encoding JSON text is specified to use)
    # rather than relying on the platform's default encoding.
    with open(path, 'r', encoding='utf-8') as f:
        yield from f

            
            
# Only papers from these three arXiv categories are used for training.
paper_categories = [
    "cs.AI",  # Artificial Intelligence
    "cs.CV",  # Computer Vision and Pattern Recognition
    "cs.LG",  # Machine Learning
]



def build_dataset(categories=paper_categories):
    """Build a DataFrame of paper titles and abstracts from the metadata file.

    Every line produced by ``get_metadata()`` is parsed as JSON. A record is
    kept only when all of the following hold:

    * its ``categories`` field is a member of ``categories``
      (NOTE(review): this compares the whole field, so a record whose field
      names several categories does not match - confirm this is intended);
    * the last four characters of its ``journal-ref`` field are accepted by
      ``int()``; records with a missing or differently shaped ``journal-ref``
      are skipped;
    * it has both a ``title`` and an ``abstract``.

    Args:
        categories: Container of category strings to keep.

    Returns:
        A ``pandas.DataFrame`` with the columns ``title`` and ``abstract``,
        one row per kept record. In both columns every run of whitespace is
        collapsed to a single space and surrounding whitespace is stripped.
    """
    whitespace = re.compile(r'\s+')
    titles = []
    abstracts = []

    for paper in tqdm(get_metadata()):
        paper_dict = json.loads(paper)
        if paper_dict.get('categories') not in categories:
            continue

        # Keep only records whose journal reference ends in something that
        # parses as an integer. A missing field (None) raises TypeError and an
        # unparsable tail raises ValueError; both mean "skip this record".
        try:
            int(paper_dict.get('journal-ref')[-4:])
        except (TypeError, ValueError):
            continue

        # Check both fields before appending either one, so the two lists
        # can never end up with different lengths.
        title = paper_dict.get('title')
        abstract = paper_dict.get('abstract')
        if title is None or abstract is None:
            continue

        # Collapse every whitespace run (line breaks included) to one space.
        # Deleting newlines outright would fuse the words on either side of
        # each wrapped line.
        titles.append(whitespace.sub(' ', title).strip())
        abstracts.append(whitespace.sub(' ', abstract).strip())

    return pd.DataFrame({'title': titles, 'abstract': abstracts})

In [None]:
# Build the title/abstract DataFrame using the default categories.
papers = build_dataset()

## Training

In [None]:
# Install the simpleT5 package used by the training cells below.
!pip install simplet5

In [None]:
# simpleT5 trains on a two-column frame: "source_text" holds the model input
# (the abstract) and "target_text" the expected output (the title).
papers = papers[['abstract', 'title']].rename(
    columns={'abstract': 'source_text', 'title': 'target_text'}
)

# Prefix each input with the task name so the kind of task being performed
# on the data is identified - here "summarize".
papers['source_text'] = "summarize: " + papers['source_text']

In [None]:
# Split the data into a training set (90%) and a held-out test set (10%).
from sklearn.model_selection import train_test_split

# A fixed seed makes the split identical on every run of the notebook.
train_df, test_df = train_test_split(papers, test_size=0.1, random_state=42)

In [None]:
from simplet5 import SimpleT5

# Create the model wrapper and load the "t5-base" checkpoint into it.
model = SimpleT5()
model.from_pretrained("t5", "t5-base")

# Train on the training split; the held-out split is passed as the
# evaluation set.
model.train(
    train_df=train_df,
    eval_df=test_df,
    source_max_token_len=512,
    target_max_token_len=128,
    max_epochs=5,
    batch_size=8,
    use_gpu=True,
)

## Inferencing
**simpleT5** saves your model after every epoch in the "outputs" folder (the default location).

In [None]:
# List the contents of the Kaggle working directory.
!ls /kaggle/working
from IPython.display import FileLink
# Display a link to a saved checkpoint.
# NOTE(review): the name is hard-coded from one particular run and, unlike
# the path passed to model.load_model() later in the notebook, has no
# "outputs/" prefix - confirm it against the directory listing before use.
display(FileLink("simplet5-epoch-4-train-loss-1.1383-val-loss-1.9299"))

In [None]:
# List the saved checkpoints in the "outputs" folder.
!ls outputs/

In [None]:
# load a trained model
# NOTE(review): the checkpoint name embeds the epoch and loss of one
# particular run - replace it with a name listed under outputs/ for your run.
model.load_model("outputs/SimpleT5-epoch-4-train-loss-1.1577", use_gpu=True)

In [None]:
# Draw ten random rows from the held-out split and compare the generated
# title with the real one for each of them.
sample_abstracts = test_df.sample(10)

for _, row in sample_abstracts.iterrows():
    source_text = row['source_text']
    # predict() returns a sequence; the first element is used as the title.
    generated_title = model.predict(source_text)[0]

    print("===== Abstract =====")
    print(source_text)
    print("\n===== Actual Title =====")
    print(f"{row['target_text']}")
    print("\n===== Generated Title =====")
    print(f"{generated_title}")
    print("\n +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n")