In [16]:
import torch
from datasets import load_dataset

import pandas as pd
import gc
from transformers import AutoTokenizer, T5ForConditionalGeneration

import evaluate
import nltk
from nltk.tokenize import sent_tokenize

# Fetch NLTK's "punkt" data package (presumably what `sent_tokenize` relies on for
# sentence splitting; newer NLTK releases may need "punkt_tab" -- TODO confirm).
nltk.download("punkt")

[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\Shindy\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [3]:
# Load version 3.0.0 of the CNN/DailyMail dataset.
full_dataset = load_dataset(
    "cnn_dailymail", "3.0.0"
)  # Note: no cache_dir is passed here, so the `datasets` library chooses the cache location.

# Use a small sample of the data during this lab, for speed.
sample_size = 100
sample = (
    full_dataset["train"]
    # Keep only articles whose first 25 characters contain "CNN".
    .filter(lambda r: "CNN" in r["article"][:25])
    # Seeded shuffle (seed=42) so the selection is repeatable.
    .shuffle(seed=42)
    # Take the first `sample_size` rows of the shuffled data.
    .select(range(sample_size))
)
sample

Downloading builder script:   0%|          | 0.00/8.33k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/9.88k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/15.1k [00:00<?, ?B/s]

Downloading data files:   0%|          | 0/5 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/159M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/376M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/12.3M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/661k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/572k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/287113 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/13368 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/11490 [00:00<?, ? examples/s]

Filter:   0%|          | 0/287113 [00:00<?, ? examples/s]

Dataset({
    features: ['article', 'highlights', 'id'],
    num_rows: 100
})

In [5]:
# Show the sample as a pandas DataFrame (columns: article, highlights, id).
display(sample.to_pandas())

Unnamed: 0,article,highlights,id
0,(CNN) -- A magnitude 6.7 earthquake rattled Pa...,Papua New Guinea is on the so-called Ring of F...,8093dba7bc2260c26f18939826909ef27549c758
1,(CNN) -- Pakistan took big steps towards level...,Australia collapse to 88 all out on opening da...,67d626156f971d0bf55e5f2a48e1ed965eb622a6
2,(CNN) -- Federal prosecutors are pushing to fo...,Jared Loughner is refusing the government's re...,0d02fb8f0d406db956b128a5c1cc7bf3f13860a6
3,"Centennial, Colorado (CNN) -- McKayla Hicks sa...",Shooting victim McKayla Hicks went to hearing ...,39aee887c6d34bd311c826142b14037e6f2639ee
4,(CNN) -- Double-amputee sprinter Oscar Pistori...,Oscar Pistorius to become first double-amputee...,cc83ecdf08f0b598c3b97b3e2819c7e0ae7ca4f2
...,...,...,...
95,(CNN) -- Samuel Eto'o netted a superb hat-tric...,Samuel Eto'o scored a hat-trick as Inter Milan...,6c1924f5852b6980a0835877d3f9591a00c70f37
96,Washington (CNN) -- President Barack Obama's r...,Obama raised almost $30 million less than Romn...,0a5691b8fe654b6b2cdace5ab87aff2ee4c23577
97,(CNN) -- Violence swept across Syria on Friday...,NEW: U.N. Secretary-General Ban Ki-moon joins ...,2cc6e4db9758192ac467bbd7424782e4c92206c1
98,(CNN) -- New HIV infections have fallen worldw...,New infections in sub-Saharan Africa 15 percen...,acb2148184f83ecb516ad19a1b0a0e1bc5047237


In [6]:
# Print the first article of the sample together with its reference summary,
# which is stored in the "highlights" column.
example_article = sample["article"][0]
example_summary = sample["highlights"][0]
print(f"Article:\n{example_article}\n")
print(f"Summary:\n{example_summary}")

Article:

Summary:
Papua New Guinea is on the so-called Ring of Fire .
It's on an arc of fault lines that is prone to frequent earthquakes .


In [7]:
def batch_generator(data: list, batch_size: int):
    """
    Creates batches of size `batch_size` from a list.

    :param data: Sequence to split; it must support `len()` and slicing.
    :param batch_size: Maximum number of items per batch. Must be positive.
    :return: Generator yielding consecutive slices of `data`, in order.
             The last batch may be shorter than `batch_size`.
    :raises ValueError: If `batch_size` is not positive (raised when iteration starts).
    """
    # A non-positive batch size would never advance through `data`,
    # so reject it instead of looping forever.
    if batch_size <= 0:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    for start in range(0, len(data), batch_size):
        # Slicing clamps at the end of `data`, so the final batch is simply shorter.
        yield data[start : start + batch_size]

In [10]:
def summarize_with_t5(
    model_checkpoint: str, articles: list, batch_size: int = 8
) -> list:
    """
    Compute summaries using a T5 model.
    This is similar to a `pipeline` for a T5 model but does tokenization manually.

    Each article is prefixed with "summarize: ", tokenized (truncated to at most
    1024 tokens) and passed to `model.generate` with 2 beams and a maximum
    generated length of 40 tokens. The model runs on "cuda:0" when CUDA is
    available, otherwise on the CPU.

    :param model_checkpoint: Name for a model checkpoint in Hugging Face, such as "t5-small" or "t5-base"
    :param articles: List of strings, where each string represents one article.
    :param batch_size: Number of articles tokenized and summarized per forward pass.
    :return: List of strings, where each string represents one article's generated summary
    """
    summary_articles = ["summarize: " + article for article in articles]

    # Nothing to summarize: skip loading the model and tokenizer altogether.
    if not summary_articles:
        return []

    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    model = T5ForConditionalGeneration.from_pretrained(model_checkpoint).to(device)
    tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, model_max_length=1024)

    def perform_inference(batch: list) -> list:
        # Pad to the longest article in the batch; cut anything beyond 1024 tokens.
        inputs = tokenizer(
            batch, max_length=1024, return_tensors="pt", padding=True, truncation=True
        )

        summary_ids = model.generate(
            inputs.input_ids.to(device),
            attention_mask=inputs.attention_mask.to(device),
            num_beams=2,
            min_length=0,
            max_length=40,
        )
        return tokenizer.batch_decode(summary_ids, skip_special_tokens=True)

    res = []
    try:
        for batch in batch_generator(summary_articles, batch_size=batch_size):
            res += perform_inference(batch)

            # Release per-batch memory before the next batch is processed.
            torch.cuda.empty_cache()
            gc.collect()
    finally:
        # clean up -- also runs when inference fails, so the model is not left
        # referenced after an error.
        del tokenizer
        del model
        torch.cuda.empty_cache()
        gc.collect()
    return res

In [11]:
# Summarize every sampled article with the "t5-small" checkpoint (default batch size).
t5_small_summaries = summarize_with_t5("t5-small", sample["article"])

In [12]:
# The dataset's "highlights" column serves as the ground-truth summaries.
reference_summaries = sample["highlights"]

In [13]:
# Side-by-side table: model output next to the ground-truth summary.
display(
    pd.DataFrame(
        {"generated": t5_small_summaries, "reference": reference_summaries}
    )
)

Unnamed: 0,generated,reference
0,a magnitude 6.7 earthquake rattles Papua new G...,Papua New Guinea is on the so-called Ring of F...
1,the two-Test cricket series is being played in...,Australia collapse to 88 all out on opening da...
2,federal prosecutors want jared Lee Loughner to...,Jared Loughner is refusing the government's re...
3,"new: ""he tried to kill people,"" a 17-year-old ...",Shooting victim McKayla Hicks went to hearing ...
4,double-amputee sprinter Oscar Pistorius will c...,Oscar Pistorius to become first double-amputee...
...,...,...
95,holders Inter Milan thrash Werder Bremen 4-0 i...,Samuel Eto'o scored a hat-trick as Inter Milan...
96,president's re-election campaign raises $71 mi...,Obama raised almost $30 million less than Romn...
97,"at least 75 people were killed in protests, an...",NEW: U.N. Secretary-General Ban Ki-moon joins ...
98,new infections have fallen by 17 percent in th...,New infections in sub-Saharan Africa 15 percen...


In [14]:
# Exact-match accuracy: the fraction of generated summaries that are
# string-identical to the reference summary at the same index.
accuracy = sum(
    (
        1.0
        for i in range(len(reference_summaries))
        if t5_small_summaries[i] == reference_summaries[i]
    ),
    0.0,
) / len(reference_summaries)

print(f"Achieved accuracy {accuracy}!")

Achieved accuracy 0.0!


In [18]:
# Load the "rouge" metric from the `evaluate` library; used by the cells below.
rouge_score = evaluate.load("rouge")

In [19]:
def compute_rouge_score(generated: list, reference: list) -> dict:
    """
    Score a batch of generated summaries against their references with ROUGE.

    Thin wrapper around the Hugging Face `rouge_score` metric, which expects
    the sentences of each summary to be separated by newlines. Every summary
    is stripped, split into sentences, and re-joined one sentence per line
    before scoring. Stemming is always enabled.

    :param generated: Summaries (list of strings) produced by the model
    :param reference: Ground-truth summaries (list of strings) for comparison
    :return: Whatever `rouge_score.compute` returns for the prepared inputs
    """

    def one_sentence_per_line(texts: list) -> list:
        # Put each sentence of each summary on its own line.
        prepared = []
        for text in texts:
            sentences = sent_tokenize(text.strip())
            prepared.append("\n".join(sentences))
        return prepared

    predictions = one_sentence_per_line(generated)
    references = one_sentence_per_line(reference)
    return rouge_score.compute(
        predictions=predictions,
        references=references,
        use_stemmer=True,
    )

In [20]:
# ROUGE scores of the t5-small summaries against the reference highlights.
compute_rouge_score(t5_small_summaries, reference_summaries)

{'rouge1': 0.30947467673029944,
 'rouge2': 0.10605629582246004,
 'rougeL': 0.2210018955526185,
 'rougeLsum': 0.28223859874195334}

In [21]:
# Sanity check: scoring the references against themselves gives 1.0 on every metric.
compute_rouge_score(reference_summaries, reference_summaries)

{'rouge1': 1.0, 'rouge2': 1.0, 'rougeL': 1.0, 'rougeLsum': 1.0}

In [22]:
# Lower-bound check: one empty prediction per reference summary.
compute_rouge_score(
    generated=[""] * len(reference_summaries),
    reference=reference_summaries,
)

{'rouge1': 0.0, 'rouge2': 0.0, 'rougeL': 0.0, 'rougeLsum': 0.0}

In [23]:
# Without stemming, "beat"/"beating" and "record"/"records" count as different
# tokens: only 4 of the 6 unigrams and 2 of the 5 bigrams match.
rouge_score.compute(
    predictions=["Large language models beat world record"],
    references=["Large language models beating world records"],
    use_stemmer=False,
)

{'rouge1': 0.6666666666666666,
 'rouge2': 0.4000000000000001,
 'rougeL': 0.6666666666666666,
 'rougeLsum': 0.6666666666666666}

In [24]:
# Same pair with stemming enabled: the word variants now match and every score is 1.0.
rouge_score.compute(
    predictions=["Large language models beat world record"],
    references=["Large language models beating world records"],
    use_stemmer=True,
)

{'rouge1': 1.0, 'rouge2': 1.0, 'rougeL': 1.0, 'rougeLsum': 1.0}

In [25]:
# Six-word prediction against a one-word reference: only 1 of 6 predicted
# unigrams matches, and the reported rouge1 equals the F-measure 2/7.
rouge_score.compute(
    predictions=["Large language models beat world record"],
    references=["Large"],
    use_stemmer=True,
)

{'rouge1': 0.2857142857142857,
 'rouge2': 0.0,
 'rougeL': 0.2857142857142857,
 'rougeLsum': 0.2857142857142857}

In [26]:
# Roles swapped (one-word prediction, six-word reference): the reported scores
# are identical to the previous cell.
rouge_score.compute(
    predictions=["Large"],
    references=["Large language models beat world record"],
    use_stemmer=True,
)

{'rouge1': 0.2857142857142857,
 'rouge2': 0.0,
 'rougeL': 0.2857142857142857,
 'rougeLsum': 0.2857142857142857}

In [27]:
# A two-word prefix of the reference: both predicted words match but only 2 of
# the 6 reference words are covered, giving rouge1 = 0.5.
rouge_score.compute(
    predictions=["Large language"],
    references=["Large language models beat world record"],
    use_stemmer=True,
)

{'rouge1': 0.5, 'rouge2': 0.33333333333333337, 'rougeL': 0.5, 'rougeLsum': 0.5}

In [28]:
# Same words in a different order: rouge1 stays at 1.0, while the order-sensitive
# rouge2 and rougeL/rougeLsum scores drop.
rouge_score.compute(
    predictions=["Models beat large language world record"],
    references=["Large language models beat world record"],
    use_stemmer=True,
)

{'rouge1': 1.0,
 'rouge2': 0.6,
 'rougeL': 0.6666666666666666,
 'rougeLsum': 0.6666666666666666}