In [1]:
import torch
torch.cuda.is_available()

True

In [2]:
!pip install transformers[sentencepiece] datasets sacrebleu rouge_score py7zr -q

In [3]:
!pip install --upgrade accelerate
!pip uninstall -y transformers accelerate
!pip install transformers accelerate

Collecting accelerate
  Using cached accelerate-0.29.3-py3-none-any.whl.metadata (18 kB)
Using cached accelerate-0.29.3-py3-none-any.whl (297 kB)
Installing collected packages: accelerate
  Attempting uninstall: accelerate
    Found existing installation: accelerate 0.28.0
    Uninstalling accelerate-0.28.0:
      Successfully uninstalled accelerate-0.28.0
Successfully installed accelerate-0.29.3
Found existing installation: transformers 4.39.1
Uninstalling transformers-4.39.1:
  Successfully uninstalled transformers-4.39.1
Found existing installation: accelerate 0.29.3
Uninstalling accelerate-0.29.3:
  Successfully uninstalled accelerate-0.29.3


In [None]:
# Import necessary modules
from transformers import pipeline, AutoModelForSeq2SeqLM, AutoTokenizer
import matplotlib.pyplot as plt
import pandas as pd
from datasets import load_dataset, load_metric, concatenate_datasets
import nltk
from nltk.tokenize import sent_tokenize
from tqdm import tqdm
import torch

nltk.download("punkt")

  from .autonotebook import tqdm as notebook_tqdm
[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\priks\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [None]:
token = "User Access Tokens"
import subprocess
subprocess.run(["huggingface-cli", "login", "--token", token], check=True)

In [3]:
device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cuda'

In [4]:
# Define the checkpoint or model identifier for the BART model
model_ckpt = "google-t5/t5-small"

# Initialize a tokenizer using the pretrained BART model's tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)

# Load the pretrained BART model for sequence-to-sequence tasks
# and move it to a specified device (e.g., GPU if available)
model = AutoModelForSeq2SeqLM.from_pretrained(model_ckpt).to(device)


In [4]:
# load the dataset, using samsum dataset from hugging face

from datasets import load_dataset

dataset_name = "samsum"
dataset = load_dataset(dataset_name)

In [5]:
dataset

DatasetDict({
    train: Dataset({
        features: ['id', 'dialogue', 'summary'],
        num_rows: 14732
    })
    test: Dataset({
        features: ['id', 'dialogue', 'summary'],
        num_rows: 819
    })
    validation: Dataset({
        features: ['id', 'dialogue', 'summary'],
        num_rows: 818
    })
})

In [6]:
print(f"text: \n{dataset['train'][1]['dialogue']}, \n\nsummary: \n{dataset['train'][1]['summary']}")

text: 
Olivia: Who are you voting for in this election? 
Oliver: Liberals as always.
Olivia: Me too!!
Oliver: Great, 

summary: 
Olivia and Olivier are voting for liberals in this election. 


In [7]:
# Split the dataset into training, testing, and validation sets
train_dataset = dataset['train']
test_dataset = dataset['test']
validation_dataset = dataset['validation']

# Print information about the datasets
print("Training dataset:")
print(train_dataset)
print("Testing dataset:")
print(test_dataset)
print("Validation dataset:")
print(validation_dataset)

Training dataset:
Dataset({
    features: ['id', 'dialogue', 'summary'],
    num_rows: 14732
})
Testing dataset:
Dataset({
    features: ['id', 'dialogue', 'summary'],
    num_rows: 819
})
Validation dataset:
Dataset({
    features: ['id', 'dialogue', 'summary'],
    num_rows: 818
})


In [9]:
prefix = "summarize: "

def convert_examples_to_features(example_batch):
    # Tokenize the dialogue from the example batch
    inputs = [prefix + doc for doc in example_batch['dialogue']]
    model_inputs = tokenizer(
        inputs,                     # Extract dialogue from the example batch
        max_length=1024,             # Set maximum sequence length for input
        truncation=True             # Truncate sequences that exceed max_length
    )

    # Tokenize the summary from the example batch as target
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(
            example_batch['summary'],  # Extract summary from the example batch
            max_length=128,             # Set maximum sequence length for summary
            truncation=True            # Truncate sequences that exceed max_length
        )

    
    model_inputs["labels"] = labels["input_ids"]
    # Return a dictionary containing input and target encodings
    return model_inputs


In [10]:
dataset_samsum_pt = dataset.map(convert_examples_to_features, batched = True) # batched=True to process multiple elements of the dataset at once

In [11]:
dataset_samsum_pt

DatasetDict({
    train: Dataset({
        features: ['id', 'dialogue', 'summary', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 14732
    })
    test: Dataset({
        features: ['id', 'dialogue', 'summary', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 819
    })
    validation: Dataset({
        features: ['id', 'dialogue', 'summary', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 818
    })
})

In [12]:
dataset_samsum_pt["train"]

Dataset({
    features: ['id', 'dialogue', 'summary', 'input_ids', 'attention_mask', 'labels'],
    num_rows: 14732
})

In [13]:
dataset_samsum_pt["train"][0]

{'id': '13818513',
 'dialogue': "Amanda: I baked  cookies. Do you want some?\r\nJerry: Sure!\r\nAmanda: I'll bring you tomorrow :-)",
 'summary': 'Amanda baked cookies and will bring Jerry some tomorrow.',
 'input_ids': [21603,
  10,
  21542,
  10,
  27,
  13635,
  5081,
  5,
  531,
  25,
  241,
  128,
  58,
  16637,
  10,
  10625,
  55,
  21542,
  10,
  27,
  31,
  195,
  830,
  25,
  5721,
  3,
  10,
  18,
  61,
  1],
 'attention_mask': [1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1],
 'labels': [21542, 13635, 5081, 11, 56, 830, 16637, 128, 5721, 5, 1]}

In [14]:
# Training

from transformers import DataCollatorForSeq2Seq

# create a batch of examples using DataCollatorForSeq2Seq. It’s more efficient to dynamically pad the sentences to the longest length in a batch during collation, instead of padding the whole dataset to the maximum length.
seq2seq_data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)
     

In [15]:
from transformers import Trainer, TrainingArguments

# Adjusted parameters for better GPU utilization and checkpoint saving
training_args = TrainingArguments(
    output_dir='t5-small-samsum',
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    num_train_epochs=64,             # total number of training epochs
    per_device_train_batch_size=8,   # batch size
    per_device_eval_batch_size=32,    # batch size for evaluation
    warmup_steps=500,
    weight_decay=0.01,
    save_total_limit=3,               # Save the last 5 checkpoints
    gradient_accumulation_steps=4,    # Gradient accumulation
    logging_dir='./logs',             # Directory for logs
    save_strategy="epoch",            # Save checkpoint at the end of each epoch
    load_best_model_at_end=True,      # Load the best model based on early stopping
    metric_for_best_model="eval_loss",# Metric to use for early stopping
    greater_is_better=False,          # Whether the metric should be minimized or maximized
    fp16=True,                        # Mixed precision training
)

trainer = Trainer(
    model=model,                         # the instantiated 🤗 Transformers model to be trained
    args=training_args,                  # training arguments, defined above
    tokenizer=tokenizer,                 # the tokenizer used to preprocess the data
    data_collator=seq2seq_data_collator, # data collator used for batching and padding
    train_dataset=dataset_samsum_pt["train"],         # training dataset
    eval_dataset=dataset_samsum_pt["validation"],     # evaluation dataset
)



dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)


In [16]:
# Start training
trainer.train()

  2%|▏         | 460/29440 [01:40<1:39:29,  4.85it/s]
  2%|▏         | 460/29440 [01:44<1:39:29,  4.85it/s]

{'eval_loss': 1.9598251581192017, 'eval_runtime': 3.9133, 'eval_samples_per_second': 209.032, 'eval_steps_per_second': 6.644, 'epoch': 1.0}


  2%|▏         | 500/29440 [01:53<1:42:00,  4.73it/s] 

{'loss': 2.4944, 'grad_norm': 1.68755042552948, 'learning_rate': 1.9960000000000002e-05, 'epoch': 1.09}


                                                     
  3%|▎         | 921/29440 [03:27<1:27:56,  5.41it/s]

{'eval_loss': 1.8660821914672852, 'eval_runtime': 3.8432, 'eval_samples_per_second': 212.845, 'eval_steps_per_second': 6.765, 'epoch': 2.0}


  3%|▎         | 1001/29440 [03:45<1:41:26,  4.67it/s]

{'loss': 2.0902, 'grad_norm': 1.9294962882995605, 'learning_rate': 1.9655148583275746e-05, 'epoch': 2.17}


                                                      
  5%|▍         | 1381/29440 [05:11<1:45:45,  4.42it/s]

{'eval_loss': 1.821037769317627, 'eval_runtime': 3.8688, 'eval_samples_per_second': 211.438, 'eval_steps_per_second': 6.721, 'epoch': 3.0}


  5%|▌         | 1500/29440 [05:37<1:31:54,  5.07it/s] 

{'loss': 2.0173, 'grad_norm': 1.5648648738861084, 'learning_rate': 1.930960608154803e-05, 'epoch': 3.26}


                                                      
  6%|▋         | 1842/29440 [06:56<1:41:06,  4.55it/s]

{'eval_loss': 1.800935983657837, 'eval_runtime': 3.9288, 'eval_samples_per_second': 208.207, 'eval_steps_per_second': 6.618, 'epoch': 4.0}


  7%|▋         | 2000/29440 [07:32<1:35:14,  4.80it/s] 

{'loss': 1.9623, 'grad_norm': 2.0426533222198486, 'learning_rate': 1.896406357982032e-05, 'epoch': 4.34}


                                                      
  8%|▊         | 2302/29440 [08:41<1:38:34,  4.59it/s]

{'eval_loss': 1.778729796409607, 'eval_runtime': 3.817, 'eval_samples_per_second': 214.303, 'eval_steps_per_second': 6.812, 'epoch': 5.0}


  8%|▊         | 2500/29440 [09:25<1:31:04,  4.93it/s] 

{'loss': 1.9331, 'grad_norm': 1.7702367305755615, 'learning_rate': 1.8619212163096064e-05, 'epoch': 5.43}


                                                      
  9%|▉         | 2763/29440 [10:40<2:26:28,  3.04it/s]

{'eval_loss': 1.7636957168579102, 'eval_runtime': 4.1186, 'eval_samples_per_second': 198.612, 'eval_steps_per_second': 6.313, 'epoch': 6.0}


 10%|█         | 3000/29440 [12:03<2:26:50,  3.00it/s] 

{'loss': 1.903, 'grad_norm': 1.3852741718292236, 'learning_rate': 1.8274360746371806e-05, 'epoch': 6.51}


                                                      
 11%|█         | 3223/29440 [13:25<2:33:28,  2.85it/s]

{'eval_loss': 1.751402735710144, 'eval_runtime': 4.1196, 'eval_samples_per_second': 198.564, 'eval_steps_per_second': 6.311, 'epoch': 7.0}


 12%|█▏        | 3500/29440 [15:02<2:33:30,  2.82it/s] 

{'loss': 1.881, 'grad_norm': 1.5601893663406372, 'learning_rate': 1.7928818244644094e-05, 'epoch': 7.6}


                                                      
 13%|█▎        | 3684/29440 [16:10<2:29:31,  2.87it/s]

{'eval_loss': 1.739009141921997, 'eval_runtime': 4.0537, 'eval_samples_per_second': 201.79, 'eval_steps_per_second': 6.414, 'epoch': 8.0}


 14%|█▎        | 4000/29440 [18:00<2:24:26,  2.94it/s] 

{'loss': 1.8648, 'grad_norm': 1.3928889036178589, 'learning_rate': 1.758327574291638e-05, 'epoch': 8.69}


                                                      
 14%|█▍        | 4144/29440 [18:54<2:29:17,  2.82it/s]

{'eval_loss': 1.7349522113800049, 'eval_runtime': 4.0981, 'eval_samples_per_second': 199.604, 'eval_steps_per_second': 6.344, 'epoch': 9.0}


 15%|█▌        | 4501/29440 [20:36<1:21:23,  5.11it/s] 

{'loss': 1.8463, 'grad_norm': 1.2980334758758545, 'learning_rate': 1.7237733241188667e-05, 'epoch': 9.77}


                                                      
 16%|█▌        | 4605/29440 [21:02<1:18:18,  5.29it/s]

{'eval_loss': 1.7242380380630493, 'eval_runtime': 3.91, 'eval_samples_per_second': 209.207, 'eval_steps_per_second': 6.65, 'epoch': 10.0}


 17%|█▋        | 5001/29440 [22:32<1:25:37,  4.76it/s] 

{'loss': 1.8302, 'grad_norm': 1.7003198862075806, 'learning_rate': 1.6892190739460956e-05, 'epoch': 10.86}


                                                      
 17%|█▋        | 5065/29440 [22:50<1:35:39,  4.25it/s]

{'eval_loss': 1.7188748121261597, 'eval_runtime': 3.9509, 'eval_samples_per_second': 207.042, 'eval_steps_per_second': 6.581, 'epoch': 11.0}


 19%|█▊        | 5501/29440 [24:25<1:17:30,  5.15it/s] 

{'loss': 1.8119, 'grad_norm': 1.3742313385009766, 'learning_rate': 1.6546648237733244e-05, 'epoch': 11.94}


                                                      
 19%|█▉        | 5526/29440 [24:34<1:17:38,  5.13it/s]

{'eval_loss': 1.7098318338394165, 'eval_runtime': 3.8162, 'eval_samples_per_second': 214.347, 'eval_steps_per_second': 6.813, 'epoch': 12.0}


                                                       
 20%|██        | 5986/29440 [26:17<1:16:47,  5.09it/s]

{'eval_loss': 1.7076162099838257, 'eval_runtime': 3.8562, 'eval_samples_per_second': 212.127, 'eval_steps_per_second': 6.742, 'epoch': 13.0}


 20%|██        | 6000/29440 [26:21<1:25:33,  4.57it/s]

{'loss': 1.8007, 'grad_norm': 1.4094339609146118, 'learning_rate': 1.620110573600553e-05, 'epoch': 13.03}


                                                      
 22%|██▏       | 6447/29440 [28:00<1:13:43,  5.20it/s]

{'eval_loss': 1.7057257890701294, 'eval_runtime': 3.8089, 'eval_samples_per_second': 214.76, 'eval_steps_per_second': 6.826, 'epoch': 14.0}


 22%|██▏       | 6500/29440 [28:11<1:13:05,  5.23it/s]

{'loss': 1.7903, 'grad_norm': 1.3425556421279907, 'learning_rate': 1.5855563234277817e-05, 'epoch': 14.12}


                                                      
 23%|██▎       | 6907/29440 [29:42<1:31:30,  4.10it/s]

{'eval_loss': 1.6984028816223145, 'eval_runtime': 3.8495, 'eval_samples_per_second': 212.495, 'eval_steps_per_second': 6.754, 'epoch': 15.0}


 24%|██▍       | 7000/29440 [30:02<1:17:10,  4.85it/s]

{'loss': 1.778, 'grad_norm': 1.34147047996521, 'learning_rate': 1.5510020732550106e-05, 'epoch': 15.2}


                                                      
 25%|██▌       | 7368/29440 [31:24<1:18:44,  4.67it/s]

{'eval_loss': 1.6943695545196533, 'eval_runtime': 3.8173, 'eval_samples_per_second': 214.289, 'eval_steps_per_second': 6.811, 'epoch': 16.0}


 25%|██▌       | 7500/29440 [31:53<1:23:43,  4.37it/s]

{'loss': 1.7639, 'grad_norm': 1.3456470966339111, 'learning_rate': 1.5164478230822392e-05, 'epoch': 16.29}


                                                      
 27%|██▋       | 7828/29440 [33:06<1:16:09,  4.73it/s]

{'eval_loss': 1.690664291381836, 'eval_runtime': 3.8162, 'eval_samples_per_second': 214.35, 'eval_steps_per_second': 6.813, 'epoch': 17.0}


 27%|██▋       | 8000/29440 [33:44<1:06:30,  5.37it/s]

{'loss': 1.7596, 'grad_norm': 1.4662257432937622, 'learning_rate': 1.4818935729094679e-05, 'epoch': 17.37}


                                                      
 28%|██▊       | 8289/29440 [34:48<1:08:19,  5.16it/s]

{'eval_loss': 1.689630389213562, 'eval_runtime': 3.8466, 'eval_samples_per_second': 212.654, 'eval_steps_per_second': 6.759, 'epoch': 18.0}


 29%|██▉       | 8500/29440 [35:34<1:07:22,  5.18it/s]

{'loss': 1.746, 'grad_norm': 1.3186774253845215, 'learning_rate': 1.4473393227366967e-05, 'epoch': 18.46}


                                                      
 30%|██▉       | 8749/29440 [36:30<1:10:23,  4.90it/s]

{'eval_loss': 1.6861369609832764, 'eval_runtime': 3.8245, 'eval_samples_per_second': 213.887, 'eval_steps_per_second': 6.798, 'epoch': 19.0}


 31%|███       | 9000/29440 [37:25<1:09:59,  4.87it/s]

{'loss': 1.7342, 'grad_norm': 1.6053637266159058, 'learning_rate': 1.4129232895646165e-05, 'epoch': 19.54}


                                                      
 31%|███▏      | 9210/29440 [38:13<1:04:16,  5.25it/s]

{'eval_loss': 1.6859763860702515, 'eval_runtime': 3.8361, 'eval_samples_per_second': 213.237, 'eval_steps_per_second': 6.778, 'epoch': 20.0}


 32%|███▏      | 9500/29440 [39:14<1:06:00,  5.03it/s]

{'loss': 1.732, 'grad_norm': 1.9439330101013184, 'learning_rate': 1.3783690393918455e-05, 'epoch': 20.63}


                                                      
 33%|███▎      | 9670/29440 [39:55<1:04:40,  5.10it/s]

{'eval_loss': 1.6808369159698486, 'eval_runtime': 3.8567, 'eval_samples_per_second': 212.098, 'eval_steps_per_second': 6.741, 'epoch': 21.0}


 34%|███▍      | 10001/29440 [41:06<1:01:40,  5.25it/s]

{'loss': 1.719, 'grad_norm': 1.5432207584381104, 'learning_rate': 1.343814789219074e-05, 'epoch': 21.72}


                                                       
 34%|███▍      | 10131/29440 [41:37<1:03:02,  5.11it/s]

{'eval_loss': 1.6759803295135498, 'eval_runtime': 3.8232, 'eval_samples_per_second': 213.956, 'eval_steps_per_second': 6.801, 'epoch': 22.0}


 36%|███▌      | 10500/29440 [42:55<1:07:55,  4.65it/s]

{'loss': 1.7152, 'grad_norm': 1.4856846332550049, 'learning_rate': 1.309260539046303e-05, 'epoch': 22.8}


                                                       
 36%|███▌      | 10591/29440 [43:20<1:07:12,  4.67it/s]

{'eval_loss': 1.677799105644226, 'eval_runtime': 3.8832, 'eval_samples_per_second': 210.653, 'eval_steps_per_second': 6.696, 'epoch': 23.0}


 37%|███▋      | 11000/29440 [44:49<1:13:10,  4.20it/s]

{'loss': 1.7082, 'grad_norm': 1.4780641794204712, 'learning_rate': 1.2747062888735315e-05, 'epoch': 23.89}


                                                       
 38%|███▊      | 11052/29440 [45:04<1:05:08,  4.71it/s]

{'eval_loss': 1.6762239933013916, 'eval_runtime': 3.8343, 'eval_samples_per_second': 213.338, 'eval_steps_per_second': 6.781, 'epoch': 24.0}


 39%|███▉      | 11500/29440 [46:40<1:07:31,  4.43it/s]

{'loss': 1.7003, 'grad_norm': 2.406327486038208, 'learning_rate': 1.2401520387007605e-05, 'epoch': 24.97}


                                                       
 39%|███▉      | 11512/29440 [46:46<1:04:41,  4.62it/s]

{'eval_loss': 1.6707149744033813, 'eval_runtime': 3.8475, 'eval_samples_per_second': 212.608, 'eval_steps_per_second': 6.758, 'epoch': 25.0}


                                                       
 41%|████      | 11973/29440 [49:18<1:38:03,  2.97it/s]

{'eval_loss': 1.6721521615982056, 'eval_runtime': 4.0619, 'eval_samples_per_second': 201.384, 'eval_steps_per_second': 6.401, 'epoch': 26.0}


 41%|████      | 12000/29440 [49:28<1:36:32,  3.01it/s]

{'loss': 1.6952, 'grad_norm': 1.5470000505447388, 'learning_rate': 1.205597788527989e-05, 'epoch': 26.06}


                                                       
 42%|████▏     | 12433/29440 [51:58<1:26:48,  3.27it/s]

{'eval_loss': 1.6701104640960693, 'eval_runtime': 3.9891, 'eval_samples_per_second': 205.057, 'eval_steps_per_second': 6.518, 'epoch': 27.0}


 42%|████▏     | 12500/29440 [52:22<1:35:28,  2.96it/s]

{'loss': 1.6848, 'grad_norm': 1.2745064496994019, 'learning_rate': 1.1710435383552178e-05, 'epoch': 27.14}


                                                       
 44%|████▍     | 12894/29440 [54:43<1:38:14,  2.81it/s]

{'eval_loss': 1.6671485900878906, 'eval_runtime': 4.0669, 'eval_samples_per_second': 201.136, 'eval_steps_per_second': 6.393, 'epoch': 28.0}


 44%|████▍     | 13000/29440 [55:20<1:31:02,  3.01it/s]

{'loss': 1.6814, 'grad_norm': 1.4625930786132812, 'learning_rate': 1.136558396682792e-05, 'epoch': 28.23}


                                                       
 45%|████▌     | 13354/29440 [57:22<1:31:44,  2.92it/s]

{'eval_loss': 1.6667605638504028, 'eval_runtime': 4.0479, 'eval_samples_per_second': 202.078, 'eval_steps_per_second': 6.423, 'epoch': 29.0}


 46%|████▌     | 13500/29440 [58:12<1:28:13,  3.01it/s]

{'loss': 1.6743, 'grad_norm': 1.8524802923202515, 'learning_rate': 1.1020041465100208e-05, 'epoch': 29.32}


                                                       
 47%|████▋     | 13815/29440 [1:00:02<1:31:10,  2.86it/s]

{'eval_loss': 1.6637248992919922, 'eval_runtime': 4.0446, 'eval_samples_per_second': 202.242, 'eval_steps_per_second': 6.428, 'epoch': 30.0}


 48%|████▊     | 14000/29440 [1:01:05<1:27:00,  2.96it/s]

{'loss': 1.6742, 'grad_norm': 1.3432433605194092, 'learning_rate': 1.0675190048375952e-05, 'epoch': 30.4}


                                                         
 48%|████▊     | 14275/29440 [1:02:36<54:00,  4.68it/s]

{'eval_loss': 1.664025068283081, 'eval_runtime': 3.8539, 'eval_samples_per_second': 212.252, 'eval_steps_per_second': 6.746, 'epoch': 31.0}


 49%|████▉     | 14500/29440 [1:03:41<1:34:11,  2.64it/s]

{'loss': 1.6652, 'grad_norm': 1.300162434577942, 'learning_rate': 1.0329647546648239e-05, 'epoch': 31.49}


                                                         
 50%|█████     | 14736/29440 [1:05:05<1:22:37,  2.97it/s]

{'eval_loss': 1.6624494791030884, 'eval_runtime': 4.0649, 'eval_samples_per_second': 201.233, 'eval_steps_per_second': 6.396, 'epoch': 32.0}


 51%|█████     | 15000/29440 [1:06:36<1:26:04,  2.80it/s]

{'loss': 1.6582, 'grad_norm': 1.4229928255081177, 'learning_rate': 9.984796129923982e-06, 'epoch': 32.57}


                                                         
 52%|█████▏    | 15196/29440 [1:07:47<1:16:49,  3.09it/s]

{'eval_loss': 1.6605582237243652, 'eval_runtime': 4.0635, 'eval_samples_per_second': 201.306, 'eval_steps_per_second': 6.398, 'epoch': 33.0}


 53%|█████▎    | 15500/29440 [1:09:33<1:20:31,  2.89it/s]

{'loss': 1.6575, 'grad_norm': 1.6251088380813599, 'learning_rate': 9.639253628196269e-06, 'epoch': 33.66}


                                                         
 53%|█████▎    | 15657/29440 [1:10:33<1:25:37,  2.68it/s]

{'eval_loss': 1.6604702472686768, 'eval_runtime': 4.1072, 'eval_samples_per_second': 199.164, 'eval_steps_per_second': 6.33, 'epoch': 34.0}


 54%|█████▍    | 16000/29440 [1:12:31<1:18:12,  2.86it/s]

{'loss': 1.6499, 'grad_norm': 1.4407778978347778, 'learning_rate': 9.293711126468557e-06, 'epoch': 34.74}


                                                         
 55%|█████▍    | 16117/29440 [1:13:15<1:16:33,  2.90it/s]

{'eval_loss': 1.661665916442871, 'eval_runtime': 4.0198, 'eval_samples_per_second': 203.491, 'eval_steps_per_second': 6.468, 'epoch': 35.0}


 56%|█████▌    | 16500/29440 [1:15:24<1:08:35,  3.14it/s]

{'loss': 1.6455, 'grad_norm': 1.643412470817566, 'learning_rate': 8.948168624740844e-06, 'epoch': 35.83}


                                                         
 56%|█████▋    | 16578/29440 [1:15:53<1:07:33,  3.17it/s]

{'eval_loss': 1.660108208656311, 'eval_runtime': 4.0307, 'eval_samples_per_second': 202.941, 'eval_steps_per_second': 6.45, 'epoch': 36.0}


 58%|█████▊    | 17000/29440 [1:18:12<1:07:26,  3.07it/s]

{'loss': 1.6506, 'grad_norm': 1.5966081619262695, 'learning_rate': 8.602626123013132e-06, 'epoch': 36.92}


                                                         
 58%|█████▊    | 17038/29440 [1:18:29<1:08:02,  3.04it/s]

{'eval_loss': 1.6594067811965942, 'eval_runtime': 4.0078, 'eval_samples_per_second': 204.103, 'eval_steps_per_second': 6.487, 'epoch': 37.0}


                                                         
 59%|█████▉    | 17499/29440 [1:21:05<1:02:50,  3.17it/s]

{'eval_loss': 1.6556426286697388, 'eval_runtime': 4.0185, 'eval_samples_per_second': 203.559, 'eval_steps_per_second': 6.47, 'epoch': 38.0}


 59%|█████▉    | 17500/29440 [1:21:06<5:50:40,  1.76s/it]

{'loss': 1.637, 'grad_norm': 1.663530707359314, 'learning_rate': 8.257083621285419e-06, 'epoch': 38.0}


                                                         
 61%|██████    | 17959/29440 [1:23:44<1:08:08,  2.81it/s]

{'eval_loss': 1.6569980382919312, 'eval_runtime': 4.0468, 'eval_samples_per_second': 202.137, 'eval_steps_per_second': 6.425, 'epoch': 39.0}


 61%|██████    | 18000/29440 [1:23:59<1:06:47,  2.85it/s]

{'loss': 1.6374, 'grad_norm': 1.4430961608886719, 'learning_rate': 7.911541119557707e-06, 'epoch': 39.09}


                                                         
 63%|██████▎   | 18420/29440 [1:26:27<1:04:29,  2.85it/s]

{'eval_loss': 1.6557620763778687, 'eval_runtime': 4.2037, 'eval_samples_per_second': 194.591, 'eval_steps_per_second': 6.185, 'epoch': 40.0}


 63%|██████▎   | 18500/29440 [1:26:56<1:05:19,  2.79it/s]

{'loss': 1.6303, 'grad_norm': 1.2979075908660889, 'learning_rate': 7.565998617829994e-06, 'epoch': 40.17}


                                                         
 64%|██████▍   | 18880/29440 [1:29:10<1:01:30,  2.86it/s]

{'eval_loss': 1.6556789875030518, 'eval_runtime': 4.1448, 'eval_samples_per_second': 197.358, 'eval_steps_per_second': 6.273, 'epoch': 41.0}


 65%|██████▍   | 19000/29440 [1:29:53<59:34,  2.92it/s]  

{'loss': 1.6311, 'grad_norm': 1.56690514087677, 'learning_rate': 7.220456116102281e-06, 'epoch': 41.26}


                                                         
 66%|██████▌   | 19341/29440 [1:31:56<1:00:09,  2.80it/s]

{'eval_loss': 1.6552847623825073, 'eval_runtime': 4.199, 'eval_samples_per_second': 194.808, 'eval_steps_per_second': 6.192, 'epoch': 42.0}


 66%|██████▌   | 19500/29440 [1:32:32<36:16,  4.57it/s]  

{'loss': 1.6234, 'grad_norm': 1.461303949356079, 'learning_rate': 6.874913614374569e-06, 'epoch': 42.35}


                                                       
 67%|██████▋   | 19801/29440 [1:33:42<32:38,  4.92it/s]

{'eval_loss': 1.6569994688034058, 'eval_runtime': 3.8839, 'eval_samples_per_second': 210.611, 'eval_steps_per_second': 6.694, 'epoch': 43.0}


 68%|██████▊   | 20000/29440 [1:34:25<35:03,  4.49it/s]  

{'loss': 1.619, 'grad_norm': 1.6194016933441162, 'learning_rate': 6.530062197650312e-06, 'epoch': 43.43}


                                                       
 69%|██████▉   | 20262/29440 [1:35:28<32:45,  4.67it/s]

{'eval_loss': 1.6536585092544556, 'eval_runtime': 3.9719, 'eval_samples_per_second': 205.947, 'eval_steps_per_second': 6.546, 'epoch': 44.0}


 70%|██████▉   | 20500/29440 [1:36:21<34:00,  4.38it/s]  

{'loss': 1.6214, 'grad_norm': 1.7984278202056885, 'learning_rate': 6.184519695922599e-06, 'epoch': 44.52}


                                                       
 70%|███████   | 20722/29440 [1:37:14<30:54,  4.70it/s]

{'eval_loss': 1.6528544425964355, 'eval_runtime': 3.909, 'eval_samples_per_second': 209.261, 'eval_steps_per_second': 6.651, 'epoch': 45.0}


 71%|███████▏  | 21000/29440 [1:38:17<32:17,  4.36it/s]  

{'loss': 1.6183, 'grad_norm': 1.718948483467102, 'learning_rate': 5.838977194194887e-06, 'epoch': 45.6}


                                                       
 72%|███████▏  | 21183/29440 [1:39:00<29:41,  4.63it/s]

{'eval_loss': 1.6542410850524902, 'eval_runtime': 3.9424, 'eval_samples_per_second': 207.485, 'eval_steps_per_second': 6.595, 'epoch': 46.0}


 73%|███████▎  | 21500/29440 [1:40:09<30:04,  4.40it/s]  

{'loss': 1.609, 'grad_norm': 1.2070860862731934, 'learning_rate': 5.493434692467174e-06, 'epoch': 46.69}


                                                       
 74%|███████▎  | 21643/29440 [1:40:45<31:29,  4.13it/s]

{'eval_loss': 1.6543306112289429, 'eval_runtime': 3.8757, 'eval_samples_per_second': 211.058, 'eval_steps_per_second': 6.708, 'epoch': 47.0}


 75%|███████▍  | 22001/29440 [1:42:05<26:45,  4.63it/s]  

{'loss': 1.6159, 'grad_norm': 1.6577644348144531, 'learning_rate': 5.148583275742916e-06, 'epoch': 47.77}


                                                       
 75%|███████▌  | 22104/29440 [1:42:40<40:30,  3.02it/s]

{'eval_loss': 1.6529765129089355, 'eval_runtime': 4.0347, 'eval_samples_per_second': 202.742, 'eval_steps_per_second': 6.444, 'epoch': 48.0}


 76%|███████▋  | 22500/29440 [1:44:39<37:20,  3.10it/s]  

{'loss': 1.6101, 'grad_norm': 1.44333016872406, 'learning_rate': 4.803040774015204e-06, 'epoch': 48.86}


                                                       
 77%|███████▋  | 22564/29440 [1:44:59<26:31,  4.32it/s]

{'eval_loss': 1.652403712272644, 'eval_runtime': 3.8275, 'eval_samples_per_second': 213.716, 'eval_steps_per_second': 6.793, 'epoch': 49.0}


 78%|███████▊  | 23000/29440 [1:47:13<36:50,  2.91it/s]  

{'loss': 1.6083, 'grad_norm': 1.515410304069519, 'learning_rate': 4.457498272287491e-06, 'epoch': 49.95}


                                                       
 78%|███████▊  | 23025/29440 [1:47:26<36:46,  2.91it/s]

{'eval_loss': 1.6514601707458496, 'eval_runtime': 4.0382, 'eval_samples_per_second': 202.565, 'eval_steps_per_second': 6.439, 'epoch': 50.0}


                                                         
 80%|███████▉  | 23485/29440 [1:50:05<33:02,  3.00it/s]

{'eval_loss': 1.652807593345642, 'eval_runtime': 4.0039, 'eval_samples_per_second': 204.3, 'eval_steps_per_second': 6.494, 'epoch': 51.0}


 80%|███████▉  | 23500/29440 [1:50:11<33:35,  2.95it/s]  

{'loss': 1.605, 'grad_norm': 1.4280529022216797, 'learning_rate': 4.111955770559779e-06, 'epoch': 51.03}


                                                       
 81%|████████▏ | 23946/29440 [1:52:40<30:08,  3.04it/s]

{'eval_loss': 1.6526485681533813, 'eval_runtime': 4.0497, 'eval_samples_per_second': 201.99, 'eval_steps_per_second': 6.42, 'epoch': 52.0}


 82%|████████▏ | 24000/29440 [1:52:59<31:03,  2.92it/s]  

{'loss': 1.6011, 'grad_norm': 1.3336507081985474, 'learning_rate': 3.7664132688320666e-06, 'epoch': 52.12}


                                                       
 83%|████████▎ | 24406/29440 [1:55:14<26:32,  3.16it/s]

{'eval_loss': 1.6515066623687744, 'eval_runtime': 4.0076, 'eval_samples_per_second': 204.11, 'eval_steps_per_second': 6.488, 'epoch': 53.0}


 83%|████████▎ | 24500/29440 [1:55:46<27:41,  2.97it/s]  

{'loss': 1.6028, 'grad_norm': 1.296498417854309, 'learning_rate': 3.420870767104354e-06, 'epoch': 53.2}


                                                       
 84%|████████▍ | 24867/29440 [1:57:49<25:24,  3.00it/s]

{'eval_loss': 1.6517422199249268, 'eval_runtime': 4.0118, 'eval_samples_per_second': 203.897, 'eval_steps_per_second': 6.481, 'epoch': 54.0}


 85%|████████▍ | 25000/29440 [1:58:33<23:33,  3.14it/s]  

{'loss': 1.6015, 'grad_norm': 1.4270243644714355, 'learning_rate': 3.0753282653766416e-06, 'epoch': 54.29}


                                                       
 86%|████████▌ | 25327/29440 [2:00:31<21:48,  3.14it/s]

{'eval_loss': 1.6511937379837036, 'eval_runtime': 4.0136, 'eval_samples_per_second': 203.807, 'eval_steps_per_second': 6.478, 'epoch': 55.0}


 87%|████████▋ | 25500/29440 [2:01:30<22:36,  2.91it/s]  

{'loss': 1.601, 'grad_norm': 1.39304780960083, 'learning_rate': 2.7297857636489287e-06, 'epoch': 55.37}


                                                       
 88%|████████▊ | 25788/29440 [2:03:12<20:32,  2.96it/s]

{'eval_loss': 1.6504225730895996, 'eval_runtime': 4.0355, 'eval_samples_per_second': 202.702, 'eval_steps_per_second': 6.443, 'epoch': 56.0}


 88%|████████▊ | 26000/29440 [2:04:24<18:54,  3.03it/s]  

{'loss': 1.6007, 'grad_norm': 1.4652317762374878, 'learning_rate': 2.384934346924672e-06, 'epoch': 56.46}


                                                       
 89%|████████▉ | 26248/29440 [2:05:51<17:38,  3.02it/s]

{'eval_loss': 1.651258111000061, 'eval_runtime': 4.0149, 'eval_samples_per_second': 203.74, 'eval_steps_per_second': 6.476, 'epoch': 57.0}


 90%|█████████ | 26500/29440 [2:07:16<16:33,  2.96it/s]  

{'loss': 1.5948, 'grad_norm': 1.5659147500991821, 'learning_rate': 2.039391845196959e-06, 'epoch': 57.55}


                                                       
 91%|█████████ | 26709/29440 [2:08:30<17:26,  2.61it/s]

{'eval_loss': 1.651141881942749, 'eval_runtime': 3.959, 'eval_samples_per_second': 206.616, 'eval_steps_per_second': 6.567, 'epoch': 58.0}


 92%|█████████▏| 27000/29440 [2:09:57<13:31,  3.01it/s]  

{'loss': 1.5973, 'grad_norm': 1.9024454355239868, 'learning_rate': 1.6938493434692469e-06, 'epoch': 58.63}


                                                       
 92%|█████████▏| 27169/29440 [2:10:58<12:52,  2.94it/s]

{'eval_loss': 1.6515285968780518, 'eval_runtime': 4.0842, 'eval_samples_per_second': 200.285, 'eval_steps_per_second': 6.366, 'epoch': 59.0}


 93%|█████████▎| 27500/29440 [2:12:51<11:04,  2.92it/s]  

{'loss': 1.5929, 'grad_norm': 1.3006001710891724, 'learning_rate': 1.3483068417415344e-06, 'epoch': 59.72}


                                                       
 94%|█████████▍| 27630/29440 [2:13:40<09:43,  3.10it/s]

{'eval_loss': 1.6514065265655518, 'eval_runtime': 4.0568, 'eval_samples_per_second': 201.636, 'eval_steps_per_second': 6.409, 'epoch': 60.0}


 95%|█████████▌| 28000/29440 [2:15:46<08:05,  2.97it/s]  

{'loss': 1.5955, 'grad_norm': 1.7669113874435425, 'learning_rate': 1.0027643400138219e-06, 'epoch': 60.8}


                                                       
 95%|█████████▌| 28090/29440 [2:16:20<07:40,  2.93it/s]

{'eval_loss': 1.650673270225525, 'eval_runtime': 4.0477, 'eval_samples_per_second': 202.091, 'eval_steps_per_second': 6.423, 'epoch': 61.0}


 97%|█████████▋| 28500/29440 [2:18:40<05:29,  2.85it/s]

{'loss': 1.5931, 'grad_norm': 1.3935017585754395, 'learning_rate': 6.579129232895646e-07, 'epoch': 61.89}


                                                       
 97%|█████████▋| 28551/29440 [2:19:02<05:04,  2.92it/s]

{'eval_loss': 1.6507068872451782, 'eval_runtime': 4.0268, 'eval_samples_per_second': 203.138, 'eval_steps_per_second': 6.457, 'epoch': 62.0}


 99%|█████████▊| 29000/29440 [2:21:32<02:21,  3.10it/s]

{'loss': 1.5939, 'grad_norm': 2.4237279891967773, 'learning_rate': 3.123704215618521e-07, 'epoch': 62.98}


                                                       
 99%|█████████▊| 29011/29440 [2:21:39<02:14,  3.18it/s]

{'eval_loss': 1.6507459878921509, 'eval_runtime': 3.9931, 'eval_samples_per_second': 204.855, 'eval_steps_per_second': 6.511, 'epoch': 63.0}


                                                       
100%|██████████| 29440/29440 [2:24:09<00:00,  3.16it/s]

{'eval_loss': 1.6507415771484375, 'eval_runtime': 4.0337, 'eval_samples_per_second': 202.79, 'eval_steps_per_second': 6.446, 'epoch': 63.93}


There were missing keys in the checkpoint model loaded: ['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight', 'lm_head.weight'].
100%|██████████| 29440/29440 [2:24:10<00:00,  3.40it/s]

{'train_runtime': 8650.2567, 'train_samples_per_second': 108.997, 'train_steps_per_second': 3.403, 'train_loss': 1.7147044596464738, 'epoch': 63.93}





TrainOutput(global_step=29440, training_loss=1.7147044596464738, metrics={'train_runtime': 8650.2567, 'train_samples_per_second': 108.997, 'train_steps_per_second': 3.403, 'train_loss': 1.7147044596464738, 'epoch': 63.93})

In [17]:
trainer.push_to_hub()

model.safetensors:   0%|          | 0.00/242M [00:00<?, ?B/s]
[A

[A[A

training_args.bin: 100%|██████████| 4.86k/4.86k [00:01<00:00, 4.48kB/s]
spiece.model: 100%|██████████| 792k/792k [00:02<00:00, 345kB/s] 4MB/s]
model.safetensors: 100%|██████████| 242M/242M [01:13<00:00, 3.29MB/s] 

Upload 3 LFS files: 100%|██████████| 3/3 [01:15<00:00, 25.07s/it]


CommitInfo(commit_url='https://huggingface.co/Prikshit7766/t5-small-samsum/commit/b1fc10c541454cb9ca44558042f8a2d1fe2ce8b5', commit_message='End of training', commit_description='', oid='b1fc10c541454cb9ca44558042f8a2d1fe2ce8b5', pr_url=None, pr_revision=None, pr_num=None)

In [12]:
# Evaluation

def generate_batch_sized_chunks(list_of_elements, batch_size):
    """split the dataset into smaller batches that we can process simultaneously
    Yield successive batch-sized chunks from list_of_elements."""
    for i in range(0, len(list_of_elements), batch_size):
        yield list_of_elements[i : i + batch_size]



def calculate_metric_on_test_ds(dataset, metric, model, tokenizer,
                               batch_size=16, device=device,
                               column_text="article",
                               column_summary="highlights"):
    article_batches = list(generate_batch_sized_chunks(dataset[column_text], batch_size))
    target_batches = list(generate_batch_sized_chunks(dataset[column_summary], batch_size))

    for article_batch, target_batch in tqdm(
        zip(article_batches, target_batches), total=len(article_batches)):

        inputs = tokenizer(article_batch, max_length=1024,  truncation=True,
                        padding="max_length", return_tensors="pt")

        summaries = model.generate(input_ids=inputs["input_ids"].to(device),
                         attention_mask=inputs["attention_mask"].to(device),
                         length_penalty=0.8, num_beams=8, max_length=128)
        ''' parameter for length penalty ensures that the model does not generate sequences that are too long. '''

        # Finally, we decode the generated texts,
        # replace the  token, and add the decoded texts with the references to the metric.
        decoded_summaries = [tokenizer.decode(s, skip_special_tokens=True,
                                clean_up_tokenization_spaces=True)
               for s in summaries]

        decoded_summaries = [d.replace("", " ") for d in decoded_summaries]


        metric.add_batch(predictions=decoded_summaries, references=target_batch)

    #  Finally compute and return the ROUGE scores.
    score = metric.compute()
    return score

In [14]:
rouge_names = ["rouge1", "rouge2", "rougeL", "rougeLsum"]
rouge_metric = load_metric('rouge')

  rouge_metric = load_metric('rouge')
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.


In [8]:
final_test_data = concatenate_datasets([test_dataset, validation_dataset])

In [9]:
final_test_data

Dataset({
    features: ['id', 'dialogue', 'summary'],
    num_rows: 1637
})

In [10]:
# Load model directly
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("Prikshit7766/t5-small-samsum")
model = AutoModelForSeq2SeqLM.from_pretrained("Prikshit7766/t5-small-samsum").to(device)

In [15]:
score = calculate_metric_on_test_ds(
    final_test_data, rouge_metric, model, tokenizer, batch_size = 32, column_text = 'dialogue', column_summary= 'summary'
)

rouge_dict = dict((rn, score[rn].mid.fmeasure ) for rn in rouge_names )

rouge_dict

100%|██████████| 52/52 [30:08<00:00, 34.78s/it]


{'rouge1': 0.022735986144018477,
 'rouge2': 0.00038145898148848575,
 'rougeL': 0.022369386115261396,
 'rougeLsum': 0.022368230911790993}

In [17]:
#Prediction

gen_kwargs = {"length_penalty": 0.8, "num_beams":8, "max_length": 128}



sample_text = dataset["test"][0]["dialogue"]

reference = dataset["test"][0]["summary"]

pipe = pipeline("summarization", model="t5-small-samsum",tokenizer=tokenizer)

##
print("Dialogue:")
print(sample_text)


print("\nReference Summary:")
print(reference)


print("\nModel Summary:")
print(pipe(sample_text, **gen_kwargs)[0]["summary_text"])

Dialogue:
Hannah: Hey, do you have Betty's number?
Amanda: Lemme check
Hannah: <file_gif>
Amanda: Sorry, can't find it.
Amanda: Ask Larry
Amanda: He called her last time we were at the park together
Hannah: I don't know him well
Hannah: <file_gif>
Amanda: Don't be shy, he's very nice
Hannah: If you say so..
Hannah: I'd rather you texted him
Amanda: Just text him 🙂
Hannah: Urgh.. Alright
Hannah: Bye
Amanda: Bye bye

Reference Summary:
Hannah needs Betty's number but Amanda doesn't have it. She needs to contact Larry.

Model Summary:
Betty's number is Lemme. Amanda can't find it. Larry called Betty last time they were at the park together.
