# News Headline generation model using T5

### Library used
simpleT5: https://github.com/Shivanandroy/simpleT5


In [23]:

import numpy as np 
import pandas as pd
from sklearn.model_selection import train_test_split
import os

# List every file available under the attached Kaggle input datasets.
for root, _subdirs, names in os.walk('/kaggle/input'):
    for path in (os.path.join(root, name) for name in names):
        print(path)

/kaggle/input/news-summary/news_summary_more.csv
/kaggle/input/news-summary/news_summary.csv


In [24]:
# %pip install -q simplet5   # uncomment if simpleT5 is missing; %pip installs into this kernel's environment

Uncomment the cell above if simpleT5 is not installed.

# Read Input Dataset

In [25]:
# Only the article body ("text") and its headline ("headlines") are needed.
# Read as latin-1 (presumably the file is not valid UTF-8 — confirm before changing).
df = pd.read_csv(
    "../input/news-summary/news_summary.csv",
    usecols=['headlines', 'text'],
    encoding='latin-1',
)

In [26]:
df.head(5)  # peek at the first five raw rows

Unnamed: 0,headlines,text
0,Daman & Diu revokes mandatory Rakshabandhan in...,The Administration of Union Territory Daman an...
1,Malaika slams user who trolled her for 'divorc...,Malaika Arora slammed an Instagram user who tr...
2,'Virgin' now corrected to 'Unmarried' in IGIMS...,The Indira Gandhi Institute of Medical Science...
3,Aaj aapne pakad liya: LeT man Dujana before be...,Lashkar-e-Taiba's Kashmir commander Abu Dujana...
4,Hotel staff to get training to spot signs of s...,Hotels in Maharashtra will train their staff t...


# T5 data preparation: rename the training columns to `source_text` & `target_text`

In [27]:
# simpleT5 trains on a `source_text` -> `target_text` pair:
# the article body is the input, the headline is the generation target.
df = (
    df.rename(columns={"headlines": "target_text", "text": "source_text"})
      .loc[:, ['source_text', 'target_text']]
)


In [28]:
df.head(5)  # confirm the renamed / reordered columns

Unnamed: 0,source_text,target_text
0,The Administration of Union Territory Daman an...,Daman & Diu revokes mandatory Rakshabandhan in...
1,Malaika Arora slammed an Instagram user who tr...,Malaika slams user who trolled her for 'divorc...
2,The Indira Gandhi Institute of Medical Science...,'Virgin' now corrected to 'Unmarried' in IGIMS...
3,Lashkar-e-Taiba's Kashmir commander Abu Dujana...,Aaj aapne pakad liya: LeT man Dujana before be...
4,Hotels in Maharashtra will train their staff t...,Hotel staff to get training to spot signs of s...


In [29]:
# T5 expects a task-related prefix; since this is a summarization task, every
# input is prefixed with "summarize: ".
# Only rows that do not already carry the prefix are changed, so re-running this
# cell does not produce "summarize: summarize: ..." inputs.
TASK_PREFIX = "summarize: "
already_prefixed = df['source_text'].str.startswith(TASK_PREFIX, na=False)
df['source_text'] = df['source_text'].where(already_prefixed, TASK_PREFIX + df['source_text'])
df

Unnamed: 0,source_text,target_text
0,summarize: The Administration of Union Territo...,Daman & Diu revokes mandatory Rakshabandhan in...
1,summarize: Malaika Arora slammed an Instagram ...,Malaika slams user who trolled her for 'divorc...
2,summarize: The Indira Gandhi Institute of Medi...,'Virgin' now corrected to 'Unmarried' in IGIMS...
3,summarize: Lashkar-e-Taiba's Kashmir commander...,Aaj aapne pakad liya: LeT man Dujana before be...
4,summarize: Hotels in Maharashtra will train th...,Hotel staff to get training to spot signs of s...
...,...,...
4509,summarize: Fruit juice concentrate maker Rasna...,Rasna seeking ?250 cr revenue from snack categ...
4510,summarize: Former Indian cricketer Sachin Tend...,Sachin attends Rajya Sabha after questions on ...
4511,"summarize: Aamir Khan, while talking about rea...",Shouldn't rob their childhood: Aamir on kids r...
4512,summarize: The Maharashtra government has init...,"Asha Bhosle gets ?53,000 power bill for unused..."


In [30]:
# 70/30 split. A fixed random_state makes the split reproducible across kernel
# restarts; without it the per-epoch val-losses, and therefore the checkpoint
# directory names written under outputs/, change on every run.
train_df, test_df = train_test_split(df, test_size=0.3, random_state=42)
train_df.shape, test_df.shape

((3159, 2), (1355, 2))

# Using SimpleT5 for Model Training - Instantiate, Download Pre-trained Model

In [31]:
from simplet5 import SimpleT5

# Create the simpleT5 wrapper and load the pre-trained "t5-base" model
# (downloaded on first use, so this cell presumably needs network access).
model = SimpleT5()
model.from_pretrained(model_type="t5", model_name="t5-base")


# Model Training

In [32]:
# Fine-tune on train_df. test_df doubles as the validation set: its loss after
# each epoch appears in the checkpoint directory names written to outputs/.
SOURCE_MAX_TOKENS = 128  # token limit passed to simpleT5 for each article input
TARGET_MAX_TOKENS = 50   # token limit passed to simpleT5 for each headline target
BATCH_SIZE = 8
MAX_EPOCHS = 3

model.train(
    train_df=train_df,
    eval_df=test_df,
    source_max_token_len=SOURCE_MAX_TOKENS,
    target_max_token_len=TARGET_MAX_TOKENS,
    batch_size=BATCH_SIZE,
    max_epochs=MAX_EPOCHS,
    use_gpu=True,
)

Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

# List of all the trained models

In [36]:
! ( cd outputs; ls )

simplet5-epoch-0-train-loss-1.5659-val-loss-1.2959
simplet5-epoch-0-train-loss-1.5824-val-loss-1.2673
simplet5-epoch-1-train-loss-1.1468-val-loss-1.2883
simplet5-epoch-1-train-loss-1.1642-val-loss-1.2329
simplet5-epoch-2-train-loss-0.9164-val-loss-1.3125
simplet5-epoch-2-train-loss-0.9309-val-loss-1.2804


In [44]:
# Optional: to keep the trained model outside this session, uncomment the line below.
# It zips one checkpoint directory (pytorch_model.bin weights plus tokenizer/config
# files) so the archive can be downloaded.
# !zip -r ngga.zip ./outputs/simplet5-epoch-2-train-loss-0.9309-val-loss-1.2804

  adding: outputs/simplet5-epoch-2-train-loss-0.9309-val-loss-1.2804/ (stored 0%)
  adding: outputs/simplet5-epoch-2-train-loss-0.9309-val-loss-1.2804/pytorch_model.bin (deflated 8%)
  adding: outputs/simplet5-epoch-2-train-loss-0.9309-val-loss-1.2804/special_tokens_map.json (deflated 83%)
  adding: outputs/simplet5-epoch-2-train-loss-0.9309-val-loss-1.2804/tokenizer_config.json (deflated 80%)
  adding: outputs/simplet5-epoch-2-train-loss-0.9309-val-loss-1.2804/tokenizer.json (deflated 59%)
  adding: outputs/simplet5-epoch-2-train-loss-0.9309-val-loss-1.2804/config.json (deflated 63%)
  adding: outputs/simplet5-epoch-2-train-loss-0.9309-val-loss-1.2804/spiece.model (deflated 48%)


# Selecting the model

In [34]:
# Checkpoint directory names embed the epoch and the train/val losses of a specific
# run, so a hard-coded name breaks after every retrain. Instead, pick the checkpoint
# with the lowest validation loss (measured on test_df during training).
CHECKPOINT_ROOT = "outputs"
checkpoint_dirs = [
    name
    for name in os.listdir(CHECKPOINT_ROOT)
    if "val-loss-" in name and os.path.isdir(os.path.join(CHECKPOINT_ROOT, name))
]
if not checkpoint_dirs:
    raise FileNotFoundError(f"no simpleT5 checkpoints found under {CHECKPOINT_ROOT!r}; run the training cell first")

# e.g. "simplet5-epoch-1-train-loss-1.1642-val-loss-1.2329" -> 1.2329
best_checkpoint = min(checkpoint_dirs, key=lambda name: float(name.rsplit("val-loss-", 1)[1]))
model.load_model("t5", os.path.join(CHECKPOINT_ROOT, best_checkpoint), use_gpu=True)
best_checkpoint

# Results

In [57]:
text_to_summarize="""Carts set on fire by miscreants during a protest in Ranchi over controversial remarks on Prophet Muhammad (PTI photo)An uneasy calm prevailed in Jharkhand’s Ranchi on Saturday, less than 24 hours after a protest against suspended BJP spokesperson Nupur Sharma’s inflammatory remarks on Prophet Mohammad turned violent.
No violent incident has been reported in Ranchi after irate protesters swept through parts of the city on Friday, pelting stones and committing arson.
Demonstrators bearing placards and shouting slogans demanded the BJP leader’s arrest for her controversial remarks on Prophet Mohammad."""
# Every training input was prefixed with "summarize: " during data preparation, so
# inference inputs need the same prefix to match what the model was fine-tuned on.
model.predict("summarize: " + text_to_summarize)

['Carts set on fire during protest over Prophet Mohammad']

# Generating a headline for a live news article (text extracted with newspaper3k)

In [54]:
# %pip install -q newspaper3k   # uncomment if newspaper3k is missing; %pip installs into this kernel's environment

In [58]:
# A fixed URL keeps the notebook runnable top-to-bottom (input() blocks "Run All"
# and makes the result depend on whatever was typed). Replace it with any other
# news article URL to try the model on different live news.
url = "https://www.ndtv.com/india-news/agnipath-protests-1-dead-over-15-injured-in-telanganas-secunderabad-3075437"

Enter:  https://www.ndtv.com/india-news/agnipath-protests-1-dead-over-15-injured-in-telanganas-secunderabad-3075437


In [61]:
from newspaper import Article

# Fetch the page, parse it, then run newspaper3k's NLP step, which fills in
# `article.summary`; that summary (not the full text) is what gets fed to the model.
article = Article(url)
for stage in (article.download, article.parse, article.nlp):
    stage()
g = article.summary

In [63]:
# `g` holds the newspaper3k summary of the article. An empty summary (e.g. failed
# extraction) would make the model "summarize" nothing, so fail clearly instead.
if not (g and g.strip()):
    raise ValueError("newspaper3k returned an empty summary; check the URL / article parsing")
# Match the "summarize: " prefix used on every training input.
model.predict("summarize: " + g)

['Telangana teenager killed as police fire at angry mob']