This notebook fine-tunes the T5 transformer model for a summarization task. It loads the required libraries and reads a dataset of news articles paired with their headlines. Each text entry is prefixed with "summarize: ", the task prefix T5 was pre-trained with for summarization, so the model treats the input as a summarization task. The T5 model and tokenizer from Hugging Face Transformers are then initialized, and a custom Dataset with DataLoaders handles tokenization and batching. Fine-tuning optimizes the model to generate short, headline-style summaries from the article text. Malformed CSV rows are skipped at load time, and rich is used to print tables and log messages to the console and to save them to a log file.

In [None]:
# Importing libraries
import os
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader

# Importing the T5 modules from huggingface/transformers
from transformers import T5Tokenizer, T5ForConditionalGeneration
from torch import cuda
from rich.table import Column, Table
from rich import box
from rich.console import Console

In [None]:
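# quoting=3 is csv.QUOTE_NONE: quote characters are treated as ordinary text, so a quoted field that
# contains commas is split into several columns. This is why headline and text fragments land in the
# wrong columns in the outputs below. on_bad_lines='skip' drops rows that have more fields than expected.
df = pd.read_csv('news_summary.csv', quoting=3, on_bad_lines='skip')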
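The effect of `quoting=3` can be reproduced without the real file. This sketch parses a hypothetical one-row CSV (the strings are made up) both ways: with the default quoting, commas inside quotes stay in their field; with `csv.QUOTE_NONE`, every comma splits the row, the quote characters are kept as text, and pandas uses the surplus leading fields as an index, which pushes fragments into the `headlines` and `text` columns.

```python
import csv
import io

import pandas as pd

# Hypothetical CSV: both fields are quoted because they contain commas
raw = (
    "headlines,text\n"
    '"Title, with a comma","Body text, also with a comma"\n'
)

# Default quoting (csv.QUOTE_MINIMAL): quoted commas stay inside their field
default_df = pd.read_csv(io.StringIO(raw))

# quoting=3 is csv.QUOTE_NONE: quotes are literal text, so every comma is a delimiter
none_df = pd.read_csv(io.StringIO(raw), quoting=csv.QUOTE_NONE)

print(default_df.loc[0, "headlines"])  # Title, with a comma
print(none_df["headlines"].iloc[0])    # not the title: a fragment of the body text
```

Whether `news_summary.csv` parses cleanly with the default quoting (and which `encoding=` it needs) is not shown in this notebook, so dropping `quoting=3` is something to verify against the actual file.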

In [None]:
df.head()

Unnamed: 0,Unnamed: 1,headlines,text
upGrad learner switches to career in ML & Al with 90% salary hike,"""Saurav Kant",an alumnus of upGrad and IIIT-B's PG Program ...,was a Sr Systems Engineer at Infosys with alm...
New Zealand end Rohit Sharma-led India's 12-match winning streak,"""New Zealand defeated India by 8 wickets in the fourth ODI at Hamilton on Thursday to win their first match of the five-match ODI series. India lost an international match under Rohit Sharma's captaincy after 12 consecutive victories dating back to March 2018. The match witnessed India getting all out for 92",their seventh lowest total in ODI cricket his...,
Rahat Fateh Ali Khan denies getting notice for smuggling currency,"""Pakistani singer Rahat Fateh Ali Khan has denied receiving any notice from the Enforcement Directorate over allegedly smuggling foreign currency out of India. """"It would have been better if the authorities would have served the notice first if any and then publicised this",""""" reads a press release issued on behalf of R...",
"""India get all out for 92","their lowest ODI total in New Zealand""","""India recorded their lowest ODI total in New ...",while their number ten batsman Yuzvendra Chah...
UP cousins fed human excreta for friendship with boys,"""Two minor cousins in Uttar Pradesh's Gorakhpur were allegedly repeatedly burnt with tongs and forced to eat human excreta by their family for being friends with two boys from the same school. The cousins revealed their ordeal to the police and Child Welfare Committee after being brought back to Gorakhpur from Nepal","where they had fled to escape the torture.""",


In [None]:
df.sample(10)

Unnamed: 0,Unnamed: 1,headlines,text
Pogba's 81st-minute goal helps France beat Australia 2-1,"""Midfielder Paul Pogba scored in the 81st minute to help France defeat 36th-ranked Australia 2-1 in their first match of the 2018 FIFA World Cup on Saturday. Forward Kylian Mbappe",aged 19 years and 178 days,became the youngest ever footballer to repres...
Sergio Ramos equals all-time La Liga record of 18 red cards,Real Madrid captain Sergio Ramos now holds the joint-record for most red cards received in La Liga history after being handed his 18th red in his side's 3-0 victory over Deportivo La Coruna on Sunday. The 31-year-old Spanish defender shares the record with former Barcelona and Sevilla defender Pablo Alfaro and former Real Zaragoza defender Xavi Aguado.,,
Anju Bobby George to appeal for 2004 Olympics medal,"""Indian long jumper Anju Bobby George has claimed she was """"robbed"""" of an Olympic medal at the 2004 Athens Games and will appeal to the International Olympic Committee. The three Russian medalists at the event failed subsequent drug tests and fifth-placed George has claimed they were """"inhaling something on the field"""". She will appeal with the fourth and sixth-placed athletes.""",,
"""Ball slips out of SA fast bowler's hand","lands near point fielder""","""During the ninth over of Australia's innings ...",the ball slipped out of fast bowler Kagiso Ra...
Over 250 human skulls found in Mexican mass grave,"""Over 250 human skulls have been found in a mass grave in Mexico's Veracruz",in what was likely a dumping ground where dru...,authorities said. The authorities have been w...
TN CM leads silent march on Jayalalithaa's death anniversary,"""Tamil Nadu CM Edappadi K Palaniswami and Deputy CM O Panneerselvam on Tuesday led a silent procession to pay tribute to former CM J Jayalalithaa on her death anniversary. Most of the AIADMK workers and ministers were dressed in black shirts. Jayalalithaa passed away on December 5",2016,after being hospitalised for several days in ...
Virgin Galactic to launch space flights from Italy,"""Richard Branson-led commercial spaceflight Virgin Galactic has announced that it will launch its commercial space flights from Italy. It has partnered with Italian Space Agency-owned company Altec and Sitael",Italy's private space company to develop a fr...,the companies did not provide a timeline for ...
Deepika features in Variety's Int'l Women's Impact Report,"""Deepika Padukone is the only Indian actress to feature in International Women's Impact Report 2018 by US magazine Variety. The report stated","""""The star of the recent Bollywood blockbuste...",
"""No settlement talks among Qualcomm","Foxconn in Apple dispute""","""Apple's group of contract manufacturers inc...",according to the group's lead attorney. Qua...
"""Indian-American's find reverses wrinkles","hair loss in aged mice""","""Indian-American scientist Keshav Singh and hi...","a kind of cellular atrophy caused by ageing."""


In [None]:
# Add the T5 task prefix to every source text
df["text"] = "summarize: " + df["text"]
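Several rows in the outputs above have an empty `text` field, i.e. NaN. Adding a string to NaN leaves NaN, and the Dataset class later casts it with `str()` to the literal string "nan", which the model would then train on. A minimal sketch, on a hypothetical toy DataFrame, of dropping such rows before adding the prefix:

```python
import numpy as np
import pandas as pd

# Hypothetical toy frame with one missing article
toy = pd.DataFrame({
    "headlines": ["Headline A", "Headline B"],
    "text": ["Article A body.", np.nan],
})

# Prefixing keeps NaN as NaN, so str() would later turn it into "nan"
prefixed = "summarize: " + toy["text"]
print(prefixed.isna().tolist())  # [False, True]

# Drop rows without a source text or target headline, then add the T5 task prefix
clean = toy.dropna(subset=["text", "headlines"]).reset_index(drop=True)
clean["text"] = "summarize: " + clean["text"]
print(clean["text"].tolist())  # ['summarize: Article A body.']
```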

In [None]:
df.head()

Unnamed: 0,Unnamed: 1,headlines,text
upGrad learner switches to career in ML & Al with 90% salary hike,"""Saurav Kant",an alumnus of upGrad and IIIT-B's PG Program ...,summarize: was a Sr Systems Engineer at Infos...
New Zealand end Rohit Sharma-led India's 12-match winning streak,"""New Zealand defeated India by 8 wickets in the fourth ODI at Hamilton on Thursday to win their first match of the five-match ODI series. India lost an international match under Rohit Sharma's captaincy after 12 consecutive victories dating back to March 2018. The match witnessed India getting all out for 92",their seventh lowest total in ODI cricket his...,
Rahat Fateh Ali Khan denies getting notice for smuggling currency,"""Pakistani singer Rahat Fateh Ali Khan has denied receiving any notice from the Enforcement Directorate over allegedly smuggling foreign currency out of India. """"It would have been better if the authorities would have served the notice first if any and then publicised this",""""" reads a press release issued on behalf of R...",
"""India get all out for 92","their lowest ODI total in New Zealand""","""India recorded their lowest ODI total in New ...",summarize: while their number ten batsman Yuz...
UP cousins fed human excreta for friendship with boys,"""Two minor cousins in Uttar Pradesh's Gorakhpur were allegedly repeatedly burnt with tongs and forced to eat human excreta by their family for being friends with two boys from the same school. The cousins revealed their ordeal to the police and Child Welfare Committee after being brought back to Gorakhpur from Nepal","where they had fled to escape the torture.""",


Use rich's Console to display progress

In [None]:
# Define a rich console logger; record=True keeps everything printed so it can be saved to logs.txt later
console = Console(record=True)

def display_df(df):
  """Display a DataFrame as an ASCII table."""

  # Reuse the recording console: a fresh Console() here would leave this table out of the saved log
  table = Table(Column("source_text", justify="center"), Column("target_text", justify="center"), title="Sample Data", pad_edge=False, box=box.ASCII)

  for row in df.values.tolist():
    table.add_row(str(row[0]), str(row[1]))
  console.print(table)

training_logger = Table(Column("Epoch", justify="center"),
                        Column("Steps", justify="center"),
                        Column("Loss", justify="center"),
                        title="Training Status", pad_edge=False, box=box.ASCII)
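`Console(record=True)` keeps a copy of everything printed through that console instance, which is what the later `console.save_text(...)` call writes out; output printed through a different `Console` object is not captured. A small self-contained sketch, writing to in-memory buffers instead of the terminal (an illustration choice, not part of the notebook):

```python
import io

from rich import box
from rich.console import Console
from rich.table import Column, Table

# Recording console writing to a buffer so nothing is shown on screen
recorder = Console(record=True, file=io.StringIO(), width=60)
# A second, non-recording console, like a local Console() created inside a helper
other = Console(file=io.StringIO(), width=60)

table = Table(Column("Epoch", justify="center"), Column("Loss", justify="center"),
              title="Training Status", box=box.ASCII)
table.add_row("0", "2.3456")

recorder.print(table)
other.print("not captured by recorder")

# export_text() returns the recorded output, i.e. what save_text() would write to disk
log_text = recorder.export_text()
print(log_text)
```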


In [None]:
# Setting up the device for GPU usage
device = 'cuda' if cuda.is_available() else 'cpu'

Create a custom Dataset class that reads the DataFrame and tokenizes each example, so that a DataLoader can feed batches to the model during fine-tuning.

In [None]:
class YourDataSetClass(Dataset):

  def __init__(self, dataframe, tokenizer, source_len, target_len, source_text, target_text):
    self.tokenizer = tokenizer
    self.data = dataframe
    self.source_len = source_len
    self.summ_len = target_len
    self.target_text = self.data[target_text]
    self.source_text = self.data[source_text]

  def __len__(self):
    return len(self.target_text)

  def __getitem__(self, index):
    # Cast to str (a missing value becomes the string 'nan')
    source_text = str(self.source_text[index])
    target_text = str(self.target_text[index])

    # Collapse runs of whitespace (newlines, tabs, repeated spaces) into single spaces
    source_text = ' '.join(source_text.split())
    target_text = ' '.join(target_text.split())

    # Pad/truncate every example to a fixed length so the DataLoader can stack them into batches
    source = self.tokenizer.batch_encode_plus([source_text], max_length=self.source_len, truncation=True, padding="max_length", return_tensors='pt')
    target = self.tokenizer.batch_encode_plus([target_text], max_length=self.summ_len, truncation=True, padding="max_length", return_tensors='pt')

    source_ids = source['input_ids'].squeeze()
    source_mask = source['attention_mask'].squeeze()
    target_ids = target['input_ids'].squeeze()
    target_mask = target['attention_mask'].squeeze()

    return {
        'source_ids': source_ids.to(dtype=torch.long),
        'source_mask': source_mask.to(dtype=torch.long),
        'target_ids': target_ids.to(dtype=torch.long),
        'target_ids_y': target_ids.to(dtype=torch.long)
    }
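Each example is padded and truncated to a fixed length because PyTorch's default collate function stacks the tensors under each dictionary key along a new batch dimension, which only works when every sample has the same shape. A toy sketch with random token ids standing in for tokenizer output (the `ToyDataset` class and the sizes are hypothetical):

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToyDataset(Dataset):
    """Hypothetical stand-in returning already padded, fixed-length token ids."""

    def __init__(self, n_items, source_len, target_len):
        self.n_items = n_items
        self.source_len = source_len
        self.target_len = target_len

    def __len__(self):
        return self.n_items

    def __getitem__(self, index):
        return {
            "source_ids": torch.randint(1, 100, (self.source_len,), dtype=torch.long),
            "source_mask": torch.ones(self.source_len, dtype=torch.long),
            "target_ids": torch.randint(1, 100, (self.target_len,), dtype=torch.long),
        }


loader = DataLoader(ToyDataset(10, source_len=16, target_len=6), batch_size=4, shuffle=False)
batch = next(iter(loader))
# The default collate stacks each dict entry into shape (batch_size, seq_len)
shapes = {k: tuple(v.shape) for k, v in batch.items()}
print(shapes)  # {'source_ids': (4, 16), 'source_mask': (4, 16), 'target_ids': (4, 6)}
```

With 10 items and `batch_size=4`, the loader yields 3 batches; the last one holds the remaining 2 items.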

In [None]:
def train(epoch, tokenizer, model, device, loader, optimizer):

  """
  Run one training epoch over `loader`; called from T5Trainer with its model, optimizer, and DataLoader.

  """

  model.train()
  for step, data in enumerate(loader, 0):
    y = data['target_ids'].to(device, dtype = torch.long)
    # Replace padding with -100 so those positions are ignored by the loss
    lm_labels = y.clone().detach()
    lm_labels[y == tokenizer.pad_token_id] = -100
    ids = data['source_ids'].to(device, dtype = torch.long)
    mask = data['source_mask'].to(device, dtype = torch.long)

    # Pass only `labels`: T5ForConditionalGeneration builds decoder_input_ids itself by shifting the labels
    # one step right behind the decoder start token, the same token generate() starts from.
    # A manual y[:, :-1] / y[:, 1:] split would never train the prediction of the first target token.
    outputs = model(input_ids = ids, attention_mask = mask, labels=lm_labels)
    loss = outputs.loss

    if step % 10 == 0:
      training_logger.add_row(str(epoch), str(step), f"{loss.item():.4f}")
      console.print(training_logger)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
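When only `labels` is passed, T5ForConditionalGeneration derives `decoder_input_ids` by shifting the labels one position to the right, putting the decoder start token (the pad id, 0, for the released T5 checkpoints) in front and turning any -100 back into the pad id. The sketch below re-implements that shift in plain PyTorch purely for illustration; `shift_right` is a hypothetical helper, not the library's internal function, and the token ids are made up.

```python
import torch

PAD_ID = 0            # T5's pad token id, also its decoder_start_token_id
IGNORE_INDEX = -100   # label positions with this value are skipped by the loss


def shift_right(labels, start_id=PAD_ID, pad_id=PAD_ID):
    """Illustrative re-implementation of the right shift applied to T5 labels."""
    shifted = labels.new_zeros(labels.shape)
    shifted[:, 1:] = labels[:, :-1].clone()
    shifted[:, 0] = start_id
    # -100 only means something to the loss; the decoder needs a real token id
    shifted.masked_fill_(shifted == IGNORE_INDEX, pad_id)
    return shifted


# Hypothetical target ids: three real tokens, EOS (1), then padding
target_ids = torch.tensor([[37, 52, 18, 1, 0, 0]])

labels = target_ids.clone()
labels[target_ids == PAD_ID] = IGNORE_INDEX  # same masking as in train()

decoder_input_ids = shift_right(labels)
print(labels.tolist())             # [[37, 52, 18, 1, -100, -100]]
print(decoder_input_ids.tolist())  # [[0, 37, 52, 18, 1, 0]]
```

Read column by column: the decoder sees the start token 0 and is trained to predict 37, sees 37 and predicts 52, and so on. A GPT-style slice (decoder input `y[:, :-1]`, labels `y[:, 1:]`) would instead pair decoder input 37 with label 52, so the first target token would never be predicted from the start token that `generate()` begins with.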

In [None]:
def validate(epoch, tokenizer, model, device, loader):

  """
  Generate predictions for every batch in `loader` and return them with the decoded reference texts.
  `epoch` is accepted for symmetry with train() but is not used.
  """
  model.eval()
  predictions = []
  actuals = []
  with torch.no_grad():
      for step, data in enumerate(loader, 0):
          y = data['target_ids'].to(device, dtype = torch.long)
          ids = data['source_ids'].to(device, dtype = torch.long)
          mask = data['source_mask'].to(device, dtype = torch.long)

          # Beam search decoding; max_length caps the generated length in tokens
          generated_ids = model.generate(
              input_ids = ids,
              attention_mask = mask,
              max_length=150,
              num_beams=2,
              repetition_penalty=2.5,
              length_penalty=1.0,
              early_stopping=True
              )
          preds = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) for g in generated_ids]
          target = [tokenizer.decode(t, skip_special_tokens=True, clean_up_tokenization_spaces=True) for t in y]
          if step % 10 == 0:
              console.print(f'Completed {step}')

          predictions.extend(preds)
          actuals.extend(target)
  return predictions, actuals

In [None]:
def T5Trainer(dataframe, source_text, target_text, model_params, output_dir="./outputs/" ):

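  """
  Fine-tune T5 on `dataframe`, then save the model/tokenizer, validation predictions, and logs under `output_dir`.

  """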

  # Set random seeds and deterministic pytorch for reproducibility
  torch.manual_seed(model_params["SEED"]) # pytorch random seed
  np.random.seed(model_params["SEED"]) # numpy random seed
  torch.backends.cudnn.deterministic = True

  # logging
  console.log(f"""[Model]: Loading {model_params["MODEL"]}...\n""")

  # Tokenizer for encoding the text
  tokenizer = T5Tokenizer.from_pretrained(model_params["MODEL"])

  # Define the model. T5ForConditionalGeneration (here t5-base) already includes the language-modeling head used to generate the summary.
  # The model is then moved to `device` (the GPU if CUDA is available, otherwise the CPU).
  model = T5ForConditionalGeneration.from_pretrained(model_params["MODEL"])
  model = model.to(device)

  # logging
  console.log(f"[Data]: Reading data...\n")

  # Keep only the source and target columns
  dataframe = dataframe[[source_text,target_text]]
  display_df(dataframe.head(2))


  # Creation of Dataset and Dataloader
  # Defining the train size. So 80% of the data will be used for training and the rest for validation.
  train_size = 0.8
  train_dataset=dataframe.sample(frac=train_size,random_state = model_params["SEED"])
  val_dataset=dataframe.drop(train_dataset.index).reset_index(drop=True)
  train_dataset = train_dataset.reset_index(drop=True)

  console.print(f"FULL Dataset: {dataframe.shape}")
  console.print(f"TRAIN Dataset: {train_dataset.shape}")
  console.print(f"VALIDATION Dataset: {val_dataset.shape}\n")


  # Creating the Training and Validation dataset for further creation of Dataloader
  training_set = YourDataSetClass(train_dataset, tokenizer, model_params["MAX_SOURCE_TEXT_LENGTH"], model_params["MAX_TARGET_TEXT_LENGTH"], source_text, target_text)
  val_set = YourDataSetClass(val_dataset, tokenizer, model_params["MAX_SOURCE_TEXT_LENGTH"], model_params["MAX_TARGET_TEXT_LENGTH"], source_text, target_text)


  # Defining the parameters for creation of dataloaders
  train_params = {
      'batch_size': model_params["TRAIN_BATCH_SIZE"],
      'shuffle': True,
      'num_workers': 0
      }


  val_params = {
      'batch_size': model_params["VALID_BATCH_SIZE"],
      'shuffle': False,
      'num_workers': 0
      }


  # Create the DataLoaders for training and validation, used below in the training and validation stages.
  training_loader = DataLoader(training_set, **train_params)
  val_loader = DataLoader(val_set, **val_params)


  # Defining the optimizer that will be used to tune the weights of the network in the training session.
  optimizer = torch.optim.Adam(params =  model.parameters(), lr=model_params["LEARNING_RATE"])


  # Training loop
  console.log(f'[Initiating Fine Tuning]...\n')

  for epoch in range(model_params["TRAIN_EPOCHS"]):
      train(epoch, tokenizer, model, device, training_loader, optimizer)

  console.log(f"[Saving Model]...\n")
  #Saving the model after training
  path = os.path.join(output_dir, "model_files")
  model.save_pretrained(path)
  tokenizer.save_pretrained(path)


  # Evaluate on the validation set
  console.log(f"[Initiating Validation]...\n")
  for epoch in range(model_params["VAL_EPOCHS"]):
    predictions, actuals = validate(epoch, tokenizer, model, device, val_loader)
    final_df = pd.DataFrame({'Generated Text':predictions,'Actual Text':actuals})
    final_df.to_csv(os.path.join(output_dir,'predictions.csv'))

  console.save_text(os.path.join(output_dir,'logs.txt'))

  console.log(f"[Validation Completed.]\n")
  console.print(f"""[Model] Model saved @ {os.path.join(output_dir, "model_files")}\n""")
  console.print(f"""[Validation] Generation on Validation data saved @ {os.path.join(output_dir,'predictions.csv')}\n""")
  console.print(f"""[Logs] Logs saved @ {os.path.join(output_dir,'logs.txt')}\n""")
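The train/validation split relies on two pandas details. `sample(frac=..., random_state=...)` returns a reproducible random subset that keeps the original index labels, so `drop(train_dataset.index)` leaves exactly the remaining rows. Both halves are then passed through `reset_index(drop=True)` because the Dataset looks rows up by label (`self.source_text[index]`) while the DataLoader asks for positions 0..n-1; without the reset, some of those labels would be missing and the lookup could raise a KeyError. A quick check on a hypothetical 10-row frame:

```python
import pandas as pd

# Hypothetical frame standing in for dataframe[[source_text, target_text]]
frame = pd.DataFrame({"text": [f"article {i}" for i in range(10)],
                      "headlines": [f"headline {i}" for i in range(10)]})

train_df = frame.sample(frac=0.8, random_state=42)
val_df = frame.drop(train_df.index).reset_index(drop=True)
train_df = train_df.reset_index(drop=True)

print(len(train_df), len(val_df))                        # 8 2
print(set(train_df["text"]).isdisjoint(val_df["text"]))  # True
print(train_df.index.tolist() == list(range(8)))         # True
```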


Model parameters

In [None]:
model_params={
    "MODEL":"t5-base",             # model_type: t5-base/t5-large
    "TRAIN_BATCH_SIZE":8,          # training batch size
    "VALID_BATCH_SIZE":8,          # validation batch size
    "TRAIN_EPOCHS":3,              # number of training epochs
    "VAL_EPOCHS":1,                # number of validation epochs
    "LEARNING_RATE":1e-4,          # learning rate
    "MAX_SOURCE_TEXT_LENGTH":512,  # max length of source text
    "MAX_TARGET_TEXT_LENGTH":50,   # max length of target text
    "SEED": 42                     # set seed for reproducibility

}
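A back-of-the-envelope sketch of the work these settings imply for the `T5Trainer(dataframe=df[:500], ...)` call, assuming `df` still has at least 500 rows after the skipped lines (not verified here):

```python
import math

# Values copied from model_params and the T5Trainer call; n_rows assumes df has >= 500 rows
n_rows = 500
train_size = 0.8
train_batch_size = 8
valid_batch_size = 8
train_epochs = 3
log_every = 10  # train() logs when step % 10 == 0

n_train = round(n_rows * train_size)                          # 400
n_val = n_rows - n_train                                      # 100
steps_per_epoch = math.ceil(n_train / train_batch_size)       # 50
total_steps = steps_per_epoch * train_epochs                  # 150
log_rows_per_epoch = math.ceil(steps_per_epoch / log_every)   # steps 0, 10, 20, 30, 40
log_rows_total = log_rows_per_epoch * train_epochs            # 15
val_batches = math.ceil(n_val / valid_batch_size)             # 13

print(n_train, n_val, steps_per_epoch, total_steps, log_rows_per_epoch, log_rows_total, val_batches)
# 400 100 50 150 5 15 13
```

Because `training_logger` is a single table that keeps accumulating rows and is reprinted in full at every logging step, it would hold 15 rows by the end of training.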

In [None]:
T5Trainer(dataframe=df[:500], source_text="text", target_text="headlines", model_params=model_params, output_dir="outputs")



In [None]:
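# Load the fine-tuned model and tokenizer saved by T5Trainer (they are local to that function,
# so `model` and `tokenizer` are not otherwise defined at notebook level)
model_dir = os.path.join("outputs", "model_files")
tokenizer = T5Tokenizer.from_pretrained(model_dir)
model = T5ForConditionalGeneration.from_pretrained(model_dir).to(device)

# Prediction function to generate summaries
def generate_summary(text, model, tokenizer, max_length=150):
    model.eval()
    # Add the same task prefix used during training
    input_text = "summarize: " + text
    inputs = tokenizer.encode(input_text, return_tensors="pt", max_length=512, truncation=True).to(device)

    # Generate the summary with beam search (4 beams)
    summary_ids = model.generate(inputs, max_length=max_length, num_beams=4, length_penalty=2.0, early_stopping=True)
    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)

    return summary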

# Example usage of the prediction function
sample_text = "Your input text for summarization goes here."
summary = generate_summary(sample_text, model, tokenizer)
print("Generated Summary:", summary)