<a href="https://colab.research.google.com/github/Shivanandroy/simpleT5/blob/main/examples/simpleT5-summarization.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install simplet5

In [15]:
# let's get a dataset
import pandas as pd
from sklearn.model_selection import train_test_split

# pandas reads the CSV straight from the URL.
path = "https://raw.githubusercontent.com/Shivanandroy/T5-Finetuning-PyTorch/main/data/news_summary.csv"
df = pd.read_csv(filepath_or_buffer=path)

# Preview the first five rows.
df.head(n=5)

Unnamed: 0,headlines,text
0,upGrad learner switches to career in ML & Al w...,"Saurav Kant, an alumnus of upGrad and IIIT-B's..."
1,Delhi techie wins free food from Swiggy for on...,Kunal Shah's credit card bill payment platform...
2,New Zealand end Rohit Sharma-led India's 12-ma...,New Zealand defeated India by 8 wickets in the...
3,Aegon life iTerm insurance plan helps customer...,"With Aegon Life iTerm Insurance plan, customer..."
4,"Have known Hirani for yrs, what if MeToo claim...",Speaking about the sexual harassment allegatio...


In [16]:
# simpleT5 expects dataframe to have 2 columns: "source_text" and "target_text"
df = df.rename(columns={"headlines":"target_text", "text":"source_text"})
# .copy() makes the two-column selection an independent frame, so the column
# assignment below cannot trigger pandas' SettingWithCopyWarning.
df = df[['source_text', 'target_text']].copy()

# T5 model expects a task related prefix: since it is a summarization task, we will add a prefix "summarize: "
# Rows that already start with the prefix are left alone, so re-running this
# cell does not produce "summarize: summarize: ...".
task_prefix = "summarize: "
already_prefixed = df['source_text'].str.startswith(task_prefix, na=False)
df['source_text'] = df['source_text'].where(already_prefixed, task_prefix + df['source_text'])
df

Unnamed: 0,source_text,target_text
0,"summarize: Saurav Kant, an alumnus of upGrad a...",upGrad learner switches to career in ML & Al w...
1,summarize: Kunal Shah's credit card bill payme...,Delhi techie wins free food from Swiggy for on...
2,summarize: New Zealand defeated India by 8 wic...,New Zealand end Rohit Sharma-led India's 12-ma...
3,summarize: With Aegon Life iTerm Insurance pla...,Aegon life iTerm insurance plan helps customer...
4,summarize: Speaking about the sexual harassmen...,"Have known Hirani for yrs, what if MeToo claim..."
...,...,...
98396,summarize: A CRPF jawan was on Tuesday axed to...,CRPF jawan axed to death by Maoists in Chhatti...
98397,"summarize: 'Uff Yeh', the first song from the ...",First song from Sonakshi Sinha's 'Noor' titled...
98398,"summarize: According to reports, a new version...",'The Matrix' film to get a reboot: Reports
98399,summarize: A new music video shows rapper Snoo...,Snoop Dogg aims gun at clown dressed as Trump ...


In [18]:
# Hold out 20% of the rows. random_state pins the shuffle so the same rows land
# in train/test on every run, which keeps the slices used for training stable.
train_df, test_df = train_test_split(df, test_size=0.2, random_state=42)
train_df.shape, test_df.shape

((78720, 2), (19681, 2))

In [19]:
from simplet5 import SimpleT5

# Start from the pretrained t5-base weights.
model = SimpleT5()
model.from_pretrained(model_type="t5", model_name="t5-base")

# Fine-tune on the first 5,000 training rows and evaluate on the first 100
# held-out rows (presumably small slices to keep the run short).
model.train(
    train_df=train_df[:5000],
    eval_df=test_df[:100],
    source_max_token_len=128,
    target_max_token_len=50,
    batch_size=8,
    max_epochs=3,
    use_gpu=True,
)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type                       | Params
-----------------------------------------------------
0 | model | T5ForConditionalGeneration | 222 M 
-----------------------------------------------------
222 M     Trainable params
0         Non-trainable params
222 M     Total params
891.614   Total estimated model params size (MB)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

Global seed set to 42




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…




In [24]:
# let's load the trained model for inferencing:
# The checkpoint folder name embeds the training loss, so it differs from run
# to run. Use the folder from the recorded run when it exists; otherwise fall
# back to the most recently modified folder under outputs/.
from pathlib import Path

model_dir = "outputs/SimpleT5-epoch-2-train-loss-0.9526"
if not Path(model_dir).is_dir():
    checkpoint_dirs = [p for p in Path("outputs").glob("*") if p.is_dir()]
    if not checkpoint_dirs:
        raise FileNotFoundError("No checkpoint folder found under outputs/ - run the training cell first.")
    model_dir = str(max(checkpoint_dirs, key=lambda p: p.stat().st_mtime))
model.load_model("t5", model_dir, use_gpu=True)

text_to_summarize="""summarize: Rahul Gandhi has replied to Goa CM Manohar Parrikar's letter, 
which accused the Congress President of using his "visit to an ailing man for political gains". 
"He's under immense pressure from the PM after our meeting and needs to demonstrate his loyalty by attacking me," 
Gandhi wrote in his letter. Parrikar had clarified he didn't discuss Rafale deal with Rahul.
"""
model.predict(text_to_summarize)

["Rahul responds to Parrikar's letter accusing him of visiting ailing man"]

In [25]:
# for faster inference on cpu, quantization, onnx support:
# The checkpoint folder name embeds the training loss, so it differs from run
# to run. Use the folder from the recorded run when it exists; otherwise fall
# back to the most recently modified folder under outputs/.
from pathlib import Path

onnx_source_dir = "outputs/SimpleT5-epoch-2-train-loss-0.9526"
if not Path(onnx_source_dir).is_dir():
    checkpoint_dirs = [p for p in Path("outputs").glob("*") if p.is_dir()]
    if not checkpoint_dirs:
        raise FileNotFoundError("No checkpoint folder found under outputs/ - run the training cell first.")
    onnx_source_dir = str(max(checkpoint_dirs, key=lambda p: p.stat().st_mtime))
model.convert_and_load_onnx_model(model_dir=onnx_source_dir)

[KExporting to onnx... |################################| 3/3
[KQuantizing... |################################| 3/3
[?25h

Setting up onnx model...
Done!


In [26]:
%%time
# Time one prediction through the ONNX model on the same input text used for
# model.predict; %%time must stay on the first line of the cell.
model.onnx_predict(text_to_summarize)

CPU times: user 754 ms, sys: 23.5 ms, total: 777 ms
Wall time: 799 ms


"Rahul responds to Parrikar's letter accusing him of visiting Goa"