## Import libraries

In [2]:
import pandas as pd
import numpy as np
from simplet5 import SimpleT5

Global seed set to 42


## Prepare the data

In [6]:
df = pd.read_csv("/content/news_summary.csv")
df = df.rename(columns={'headlines': 'target_text', 'text': 'source_text'})
df = df[['source_text', 'target_text']]
df['source_text'] = "summarize: " + df['source_text']
df = df.head(5000)
df.head(3)

Unnamed: 0,source_text,target_text
0,"summarize: Saurav Kant, an alumnus of upGrad a...",upGrad learner switches to career in ML & Al w...
1,summarize: Kunal Shah's credit card bill payme...,Delhi techie wins free food from Swiggy for on...
2,summarize: New Zealand defeated India by 8 wic...,New Zealand end Rohit Sharma-led India's 12-ma...


In [7]:
from sklearn.model_selection import train_test_split
train_df, test_df = train_test_split(df, test_size=0.2)
train_df.shape, test_df.shape

((2196, 2), (549, 2))

## Prepare the Model

In [8]:
model = SimpleT5()
model.from_pretrained(model_type='t5', model_name='t5-base')

Downloading:   0%|          | 0.00/773k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.32M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.17k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/850M [00:00<?, ?B/s]

In [9]:
import torch
torch.cuda.empty_cache()
torch.cuda.is_available()

True

In [10]:
model.train(
    train_df=train_df,
    eval_df=test_df,
    source_max_token_len=128,
    target_max_token_len=50,
    batch_size=8,
    max_epochs=2,
    use_gpu=True
)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Missing logger folder: /content/lightning_logs

  | Name  | Type                       | Params
-----------------------------------------------------
0 | model | T5ForConditionalGeneration | 222 M 
-----------------------------------------------------
222 M     Trainable params
0         Non-trainable params
222 M     Total params
891.614   Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Global seed set to 42


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

## Load and test the model

In [11]:
model.load_model("tf", "/content/outputs/simplet5-epoch-1-train-loss-1.146-val-loss-1.2703", use_gpu=True)

In [12]:
test_df['target_text'][0:10]

471     Wife, I will fight for Navnirman of country: N...
1025    Deep-rooted sexism made people doubt me as dir...
767     Windies Women's captain refuses to go to Pak f...
1896    I have stopped smoking weed, I like to be aler...
430     UK company's delivery boy stands on van, throw...
786     CBI may probe ICICI, Goldman India CEOs in Koc...
2086    Nadal wakes up sleeping journalist during his ...
1830    Man charged for allegedly sending 0.5 kg meth ...
599      Kanye West sued by Yeezy fabric supplier for ...
1731    Teen carries mother's body on cycle as locals ...
Name: target_text, dtype: object

In [13]:
for doc in test_df['source_text'][0:10]:
  print(model.predict(doc))

["Will fight together for 'Navnirman': Patel on resolution"]
['Deep-rooted sexism made people doubt me: Kangana']
["Windies Women's team captain Stafanie Taylor refuses trip to Pakistan"]
['Stop smoking weed: singer-songwriter Ed Sheeran']
['Hermes employee throws parcel on building floor in UK']
['CBI to investigate ICICI CEO, Sonjoy Chatterjee in Kochhar-Videocon case']
["It's not interesting today: Nadal wakes up sleeping journalist"]
['Man who sent meth to Apple Store in US facing meth charges']
['Japanese firm sues Kanye West over 14 cr order']
['17-yr-old carried body of deceased mother on bicycle in Odisha']


### Testing random news from the internet

In [14]:
model.predict("summary: At least six people died and 71 others were hospitalised in the last three days due to diarrhoea after allegedly drinking contaminated water from open sources in several villages in Odisha's Rayagada district, officials said on Saturday. A team of 11 doctors visited the affected villages and collected water and blood samples and sent those for examination.")

['6 dead, 71 others hospitalised after drinking contaminated water']