### Step1: Import packages

In [1]:
import torch
import os
from datasets import Dataset, load_dataset
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainer, Seq2SeqTrainingArguments, AutoModelForPreTraining, pipeline
from transformers import T5Tokenizer, T5ForConditionalGeneration

### Step2: Read dataset

In [2]:
#ds = load_dataset("cnn_dailymail", "1.0.0")

In [36]:
train_ds = load_dataset("cnn_dailymail", "1.0.0", split='train[:90%]')
val_ds = load_dataset("cnn_dailymail", "1.0.0", split='train[90%:]')


In [41]:
train_ds.shuffle(seed=42)
val_ds.shuffle(seed=42)

Dataset({
    features: ['article', 'highlights', 'id'],
    num_rows: 28711
})

### Step3: Analyze data

In [42]:
torch.cuda.is_available()

True

In [9]:
#tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small", max_length=1024)
tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-small", max_seq_len=1024)
#model = T5ForConditionalGeneration.from_pretrained("google-t5/t5-small", max_memory = 1024)
model = T5ForConditionalGeneration.from_pretrained("./summary/last-checkpoint-1536", max_memory = 1024)
print("Model weights loaded...\n")

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Model weights loaded...


In [43]:
def process_func(examples):
    contents = ['Generate summary: \n' + e for e in examples['article']]
    inputs = tokenizer(contents, max_length=1024, truncation=True)
    labels = tokenizer(text_target=examples['highlights'], max_length=64, truncation=True)
    inputs['labels'] = labels['input_ids']
    return inputs


In [11]:
tokenized_ds = ds.map(process_func, batched=True)

In [12]:
tokenizer.decode(tokenized_ds['train'][0]['input_ids'])

'Generate summary: LONDON, England (Reuters) -- Harry Potter star Daniel Radcliffe gains access to a reported £20 million ($41.1 million) fortune as he turns 18 on Monday, but he insists the money won\'t cast a spell on him. Daniel Radcliffe as Harry Potter in "Harry Potter and the Order of the Phoenix" To the disappointment of gossip columnists around the world, the young actor says he has no plans to fritter his cash away on fast cars, drink and celebrity parties. "I don\'t plan to be one of those people who, as soon as they turn 18, suddenly buy themselves a massive sports car collection or something similar," he told an Australian interviewer earlier this month. "I don\'t think I\'ll be particularly extravagant. "The things I like buying are things that cost about 10 pounds -- books and CDs and DVDs." At 18, Radcliffe will be able to gamble in a casino, buy a drink in a pub or see the horror film "Hostel: Part II," currently six places below his number one movie on the UK box offic

In [13]:
tokenizer.decode(tokenized_ds['train'][0]['labels'])

"Harry Potter star Daniel Radcliffe gets £20M fortune as he turns 18 Monday. Young actor says he has no plans to fritter his cash away. Radcliffe's earnings from first five Potter films have been held in trust fund.</s>"

In [14]:
ds['train'][0]['highlights']

"Harry Potter star Daniel Radcliffe gets £20M fortune as he turns 18 Monday . Young actor says he has no plans to fritter his cash away . Radcliffe's earnings from first five Potter films have been held in trust fund ."

### Step4: Create model

In [15]:
#model = AutoModelForSeq2SeqLM.from_pretrained("google-t5/t5-small", max_length=1024)
#model = AutoModelForSeq2SeqLM.from_pretrained("./summary/checkpoint-1536", max_length = 1024)
#print("Model weights loaded...\n")

### Step5: Create evaluate function

In [16]:
import numpy as np
from rouge import Rouge

rouge = Rouge()


In [17]:
def compute_metric(evalPred):
    preds, labels = evalPred
    decode_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decode_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    decode_preds = [" ".join(p) for p in decode_preds]
    decode_labels = [" ".join(p) for p in decode_labels]
    scores = rouge.get_scores(decode_preds, decode_labels, avg=True)
    return {
        "rouge-1": scores['rouge-1']['f'],
        "rouge-2": scores['rouge-2']['f'],
        "rouge-l": scores['rouge-l']['f']
    }



### Step6: Set training parameters

In [18]:
args = Seq2SeqTrainingArguments(
    output_dir="./summary",
    learning_rate=3e-4,
    num_train_epochs=3,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    gradient_accumulation_steps=8,
    warmup_steps=128,
    logging_steps=512,
    logging_dir="./logging",
    evaluation_strategy="steps",
    save_strategy="steps",
    save_steps=512,
    save_total_limit=3,     # save the last 3 model
    metric_for_best_model="rouge-l",
    predict_with_generate=True,  # must set True
    #load_best_model_at_end=True
)

### Step7: Create trainer

In [19]:
trainer = Seq2SeqTrainer(
    args=args,
    model=model,
    train_dataset=tokenized_ds['train'],
    eval_dataset=tokenized_ds['test'],
    compute_metrics=compute_metric,
    tokenizer=tokenizer,
    data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer)
)

dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)


### Step8: Train the model

In [20]:
trainer.train() 

Step,Training Loss,Validation Loss


Step,Training Loss,Validation Loss,Rouge-1,Rouge-2,Rouge-l
512,1.7432,1.724787,0.771811,0.411786,0.691962


Step,Training Loss,Validation Loss,Rouge-1,Rouge-2,Rouge-l
512,1.7432,1.724787,0.771811,0.411786,0.691962
1024,1.7311,1.722753,0.772798,0.413633,0.691759


Step,Training Loss,Validation Loss,Rouge-1,Rouge-2,Rouge-l
512,1.7432,1.724787,0.771811,0.411786,0.691962
1024,1.7311,1.722753,0.772798,0.413633,0.691759
1536,1.7333,1.709278,0.772665,0.414348,0.693385


Step,Training Loss,Validation Loss,Rouge-1,Rouge-2,Rouge-l
512,1.7432,1.724787,0.771811,0.411786,0.691962
1024,1.7311,1.722753,0.772798,0.413633,0.691759
1536,1.7333,1.709278,0.772665,0.414348,0.693385
2048,1.8558,1.680838,0.773619,0.416598,0.695329


Step,Training Loss,Validation Loss,Rouge-1,Rouge-2,Rouge-l
512,1.7432,1.724787,0.771811,0.411786,0.691962
1024,1.7311,1.722753,0.772798,0.413633,0.691759
1536,1.7333,1.709278,0.772665,0.414348,0.693385
2048,1.8558,1.680838,0.773619,0.416598,0.695329
2560,1.8427,1.661814,0.773028,0.413879,0.693893


Step,Training Loss,Validation Loss,Rouge-1,Rouge-2,Rouge-l
512,1.7432,1.724787,0.771811,0.411786,0.691962
1024,1.7311,1.722753,0.772798,0.413633,0.691759
1536,1.7333,1.709278,0.772665,0.414348,0.693385
2048,1.8558,1.680838,0.773619,0.416598,0.695329
2560,1.8427,1.661814,0.773028,0.413879,0.693893
3072,1.827,1.664096,0.772125,0.411808,0.693255


Step,Training Loss,Validation Loss,Rouge-1,Rouge-2,Rouge-l
512,1.7432,1.724787,0.771811,0.411786,0.691962
1024,1.7311,1.722753,0.772798,0.413633,0.691759
1536,1.7333,1.709278,0.772665,0.414348,0.693385
2048,1.8558,1.680838,0.773619,0.416598,0.695329
2560,1.8427,1.661814,0.773028,0.413879,0.693893
3072,1.827,1.664096,0.772125,0.411808,0.693255
3584,1.8285,1.653124,0.77231,0.414298,0.693539


Step,Training Loss,Validation Loss,Rouge-1,Rouge-2,Rouge-l
512,1.7432,1.724787,0.771811,0.411786,0.691962
1024,1.7311,1.722753,0.772798,0.413633,0.691759
1536,1.7333,1.709278,0.772665,0.414348,0.693385
2048,1.8558,1.680838,0.773619,0.416598,0.695329
2560,1.8427,1.661814,0.773028,0.413879,0.693893
3072,1.827,1.664096,0.772125,0.411808,0.693255
3584,1.8285,1.653124,0.77231,0.414298,0.693539
4096,1.8193,1.656115,0.771975,0.41302,0.693362


Step,Training Loss,Validation Loss,Rouge-1,Rouge-2,Rouge-l
512,1.7432,1.724787,0.771811,0.411786,0.691962
1024,1.7311,1.722753,0.772798,0.413633,0.691759
1536,1.7333,1.709278,0.772665,0.414348,0.693385
2048,1.8558,1.680838,0.773619,0.416598,0.695329
2560,1.8427,1.661814,0.773028,0.413879,0.693893
3072,1.827,1.664096,0.772125,0.411808,0.693255
3584,1.8285,1.653124,0.77231,0.414298,0.693539
4096,1.8193,1.656115,0.771975,0.41302,0.693362
4608,1.8059,1.660652,0.772301,0.412626,0.693388


Step,Training Loss,Validation Loss,Rouge-1,Rouge-2,Rouge-l
512,1.7432,1.724787,0.771811,0.411786,0.691962
1024,1.7311,1.722753,0.772798,0.413633,0.691759
1536,1.7333,1.709278,0.772665,0.414348,0.693385
2048,1.8558,1.680838,0.773619,0.416598,0.695329
2560,1.8427,1.661814,0.773028,0.413879,0.693893
3072,1.827,1.664096,0.772125,0.411808,0.693255
3584,1.8285,1.653124,0.77231,0.414298,0.693539
4096,1.8193,1.656115,0.771975,0.41302,0.693362
4608,1.8059,1.660652,0.772301,0.412626,0.693388
5120,1.778,1.644493,0.772808,0.413753,0.694583


Step,Training Loss,Validation Loss,Rouge-1,Rouge-2,Rouge-l
512,1.7432,1.724787,0.771811,0.411786,0.691962
1024,1.7311,1.722753,0.772798,0.413633,0.691759
1536,1.7333,1.709278,0.772665,0.414348,0.693385
2048,1.8558,1.680838,0.773619,0.416598,0.695329
2560,1.8427,1.661814,0.773028,0.413879,0.693893
3072,1.827,1.664096,0.772125,0.411808,0.693255
3584,1.8285,1.653124,0.77231,0.414298,0.693539
4096,1.8193,1.656115,0.771975,0.41302,0.693362
4608,1.8059,1.660652,0.772301,0.412626,0.693388
5120,1.778,1.644493,0.772808,0.413753,0.694583


Step,Training Loss,Validation Loss,Rouge-1,Rouge-2,Rouge-l
512,1.7432,1.724787,0.771811,0.411786,0.691962
1024,1.7311,1.722753,0.772798,0.413633,0.691759
1536,1.7333,1.709278,0.772665,0.414348,0.693385
2048,1.8558,1.680838,0.773619,0.416598,0.695329
2560,1.8427,1.661814,0.773028,0.413879,0.693893
3072,1.827,1.664096,0.772125,0.411808,0.693255
3584,1.8285,1.653124,0.77231,0.414298,0.693539
4096,1.8193,1.656115,0.771975,0.41302,0.693362
4608,1.8059,1.660652,0.772301,0.412626,0.693388
5120,1.778,1.644493,0.772808,0.413753,0.694583


Step,Training Loss,Validation Loss,Rouge-1,Rouge-2,Rouge-l
512,1.7432,1.724787,0.771811,0.411786,0.691962
1024,1.7311,1.722753,0.772798,0.413633,0.691759
1536,1.7333,1.709278,0.772665,0.414348,0.693385
2048,1.8558,1.680838,0.773619,0.416598,0.695329
2560,1.8427,1.661814,0.773028,0.413879,0.693893
3072,1.827,1.664096,0.772125,0.411808,0.693255
3584,1.8285,1.653124,0.77231,0.414298,0.693539
4096,1.8193,1.656115,0.771975,0.41302,0.693362
4608,1.8059,1.660652,0.772301,0.412626,0.693388
5120,1.778,1.644493,0.772808,0.413753,0.694583


Step,Training Loss,Validation Loss,Rouge-1,Rouge-2,Rouge-l
512,1.7432,1.724787,0.771811,0.411786,0.691962
1024,1.7311,1.722753,0.772798,0.413633,0.691759
1536,1.7333,1.709278,0.772665,0.414348,0.693385
2048,1.8558,1.680838,0.773619,0.416598,0.695329
2560,1.8427,1.661814,0.773028,0.413879,0.693893
3072,1.827,1.664096,0.772125,0.411808,0.693255
3584,1.8285,1.653124,0.77231,0.414298,0.693539
4096,1.8193,1.656115,0.771975,0.41302,0.693362
4608,1.8059,1.660652,0.772301,0.412626,0.693388
5120,1.778,1.644493,0.772808,0.413753,0.694583


Step,Training Loss,Validation Loss,Rouge-1,Rouge-2,Rouge-l
512,1.7432,1.724787,0.771811,0.411786,0.691962
1024,1.7311,1.722753,0.772798,0.413633,0.691759
1536,1.7333,1.709278,0.772665,0.414348,0.693385
2048,1.8558,1.680838,0.773619,0.416598,0.695329
2560,1.8427,1.661814,0.773028,0.413879,0.693893
3072,1.827,1.664096,0.772125,0.411808,0.693255
3584,1.8285,1.653124,0.77231,0.414298,0.693539
4096,1.8193,1.656115,0.771975,0.41302,0.693362
4608,1.8059,1.660652,0.772301,0.412626,0.693388
5120,1.778,1.644493,0.772808,0.413753,0.694583


Step,Training Loss,Validation Loss,Rouge-1,Rouge-2,Rouge-l
512,1.7432,1.724787,0.771811,0.411786,0.691962
1024,1.7311,1.722753,0.772798,0.413633,0.691759
1536,1.7333,1.709278,0.772665,0.414348,0.693385
2048,1.8558,1.680838,0.773619,0.416598,0.695329
2560,1.8427,1.661814,0.773028,0.413879,0.693893
3072,1.827,1.664096,0.772125,0.411808,0.693255
3584,1.8285,1.653124,0.77231,0.414298,0.693539
4096,1.8193,1.656115,0.771975,0.41302,0.693362
4608,1.8059,1.660652,0.772301,0.412626,0.693388
5120,1.778,1.644493,0.772808,0.413753,0.694583


Step,Training Loss,Validation Loss,Rouge-1,Rouge-2,Rouge-l
512,1.7432,1.724787,0.771811,0.411786,0.691962
1024,1.7311,1.722753,0.772798,0.413633,0.691759
1536,1.7333,1.709278,0.772665,0.414348,0.693385
2048,1.8558,1.680838,0.773619,0.416598,0.695329
2560,1.8427,1.661814,0.773028,0.413879,0.693893
3072,1.827,1.664096,0.772125,0.411808,0.693255
3584,1.8285,1.653124,0.77231,0.414298,0.693539
4096,1.8193,1.656115,0.771975,0.41302,0.693362
4608,1.8059,1.660652,0.772301,0.412626,0.693388
5120,1.778,1.644493,0.772808,0.413753,0.694583


Step,Training Loss,Validation Loss,Rouge-1,Rouge-2,Rouge-l
512,1.7432,1.724787,0.771811,0.411786,0.691962
1024,1.7311,1.722753,0.772798,0.413633,0.691759
1536,1.7333,1.709278,0.772665,0.414348,0.693385
2048,1.8558,1.680838,0.773619,0.416598,0.695329
2560,1.8427,1.661814,0.773028,0.413879,0.693893
3072,1.827,1.664096,0.772125,0.411808,0.693255
3584,1.8285,1.653124,0.77231,0.414298,0.693539
4096,1.8193,1.656115,0.771975,0.41302,0.693362
4608,1.8059,1.660652,0.772301,0.412626,0.693388
5120,1.778,1.644493,0.772808,0.413753,0.694583


Step,Training Loss,Validation Loss,Rouge-1,Rouge-2,Rouge-l
512,1.7432,1.724787,0.771811,0.411786,0.691962
1024,1.7311,1.722753,0.772798,0.413633,0.691759
1536,1.7333,1.709278,0.772665,0.414348,0.693385
2048,1.8558,1.680838,0.773619,0.416598,0.695329
2560,1.8427,1.661814,0.773028,0.413879,0.693893
3072,1.827,1.664096,0.772125,0.411808,0.693255
3584,1.8285,1.653124,0.77231,0.414298,0.693539
4096,1.8193,1.656115,0.771975,0.41302,0.693362
4608,1.8059,1.660652,0.772301,0.412626,0.693388
5120,1.778,1.644493,0.772808,0.413753,0.694583


Step,Training Loss,Validation Loss,Rouge-1,Rouge-2,Rouge-l
512,1.7432,1.724787,0.771811,0.411786,0.691962
1024,1.7311,1.722753,0.772798,0.413633,0.691759
1536,1.7333,1.709278,0.772665,0.414348,0.693385
2048,1.8558,1.680838,0.773619,0.416598,0.695329
2560,1.8427,1.661814,0.773028,0.413879,0.693893
3072,1.827,1.664096,0.772125,0.411808,0.693255
3584,1.8285,1.653124,0.77231,0.414298,0.693539
4096,1.8193,1.656115,0.771975,0.41302,0.693362
4608,1.8059,1.660652,0.772301,0.412626,0.693388
5120,1.778,1.644493,0.772808,0.413753,0.694583


Step,Training Loss,Validation Loss,Rouge-1,Rouge-2,Rouge-l
512,1.7432,1.724787,0.771811,0.411786,0.691962
1024,1.7311,1.722753,0.772798,0.413633,0.691759
1536,1.7333,1.709278,0.772665,0.414348,0.693385
2048,1.8558,1.680838,0.773619,0.416598,0.695329
2560,1.8427,1.661814,0.773028,0.413879,0.693893
3072,1.827,1.664096,0.772125,0.411808,0.693255
3584,1.8285,1.653124,0.77231,0.414298,0.693539
4096,1.8193,1.656115,0.771975,0.41302,0.693362
4608,1.8059,1.660652,0.772301,0.412626,0.693388
5120,1.778,1.644493,0.772808,0.413753,0.694583


KeyboardInterrupt: 

In [None]:
pipe = pipeline('text2text-generation', model=model, tokenizer=tokenizer, device=0)
text = ds['validation'][200]['article']
target = ds['validation'][200]['highlights']
print(target)
print("----------------------------------\n")
pip_res = pipe("Generate summary:\n" + text, max_length = 64)
t5_summary = pip_res[0]['generated_text']
print(t5_summary)
print("----------------------------------\n")
print(" Rouge-L between label and generate summary with t5 model is ", rouge.get_scores(target, t5_summary)[0]['rouge-l'])


### Step8.5: Retrain model if needed

In [None]:
model = AutoModelForSeq2SeqLM.from_pretrained("./summary/last-checkpoint-3072")
print("Model weights loaded...\n")

pipe = pipeline('text2text-generation', model=model, tokenizer=tokenizer, device=0)
text = ds['validation'][200]['article']
target = ds['validation'][200]['highlights']
print(target)
print("----------------------------------\n")
pip_res = pipe("Generate summary:\n" + text, max_length = 64)
t5_summary = pip_res[0]['generated_text']
print(t5_summary)
print("----------------------------------\n")
print(" Rouge-L between label and generate summary with t5 model is ", rouge.get_scores(target, t5_summary)[0]['rouge-l'])


In [None]:
validations = ds['validation']
texts: list[str] = validations['article']
labels: list[str] = validations['highlights']
t5_summaries : list[str] = [pipe(each, max_length = 64)[0]['generated_text'] for each in texts]
rouge.get_scores(labels, t5_summaries, avg=True)['rouge-l']

In [None]:
text_generator = pipeline("text2text-generation")

# List of input prompts
input_prompts = [
    "Once upon a time, there was a king who ruled over a prosperous kingdom.",
    "In a galaxy far, far away, a young Jedi embarked on a journey to defeat the Sith.",
    "The scientist conducted an experiment that would change the course of human history."
]

# Generate results for each input prompt
results = [text_generator(prompt) for prompt in input_prompts]

# Print results
for input_prompt, result in zip(input_prompts, results):
    print("Input Prompt:", input_prompt)
    print("Generated Text:", result)
    print()

### Step9: Evaluate the model

In [None]:
import spacy
from spacy.lang.en.stop_words import STOP_WORDS
from string import punctuation
from heapq import nlargest
from datasets import load_dataset
from rouge import Rouge
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainer, Seq2SeqTrainingArguments

stopwords = list(STOP_WORDS)
nlp = spacy.load('en_core_web_sm')


In [None]:
def select_main_sentence(text, punctuation, nlp):
    summary_length = 3
    doc = nlp(text)
    tokens = [token.text for token in doc]
    punctuation = punctuation + '\n'
    sentence_tokens = [sent for sent in doc.sents]
    
    word_frequencies = {}
    for word in doc:
        if word.text.lower() not in stopwords:
            if word.text.lower() not in punctuation:
                if word.text not in word_frequencies.keys():
                    word_frequencies[word.text] = 1
                else:
                    word_frequencies[word.text] += 1

    sentence_scores = {}
    for sent in sentence_tokens:
        for word in sent:
            if word.text.lower() in word_frequencies.keys():
                if sent not in sentence_scores.keys():
                    sentence_scores[sent] = word_frequencies[word.text.lower()]
                else:
                    sentence_scores[sent] += word_frequencies[word.text.lower()]
         
    summary = nlargest(summary_length, sentence_scores, key = sentence_scores.get)
    return summary

In [None]:
print("\n----------------------------article---------------------------------------\n")
text = ds['validation'][1400]['article']
print(text)
print("\n----------------------------label---------------------------------------\n")
target = ds['validation'][1400]['highlights']
print(target)
print("\n----------------------------generate summary---------------------------------------")
summary = select_main_sentence(text, punctuation, nlp)
generate_summary = ""
for each in summary:
    generate_summary = generate_summary + str(each)
print(generate_summary)
print(" Rouge-L: ", rouge.get_scores(target, generate_summary)[0]['rouge-l'])

print("\n----------------------------generate summary t5 model---------------------------------------")
pip_res = pipe("Generate summary:\n" + text, max_length = 64)
t5_summary = pip_res[0]['generated_text']
print(t5_summary)
print(" Rouge-L between label and generate summary with t5 model is ", rouge.get_scores(target, t5_summary)[0]['rouge-l'])