<a href="https://colab.research.google.com/github/NikitasTsingenopoulos/WikiSum/blob/main/WikiSum.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install -U datasets
!pip install -U evaluate
!pip install rouge_score

In [None]:
pip install bert-score

In [None]:
from transformers import (AutoTokenizer, LEDConfig, LEDForConditionalGeneration)
from datasets import load_dataset, Dataset
import torch
import pandas as pd

In [None]:
import nltk
# Download the NLTK resources requested below; sent_tokenize is used later in
# the notebook to split generated summaries into sentences.
nltk.download('punkt')
nltk.download('punkt_tab')

In [None]:
from bert_score import score
from nltk.tokenize import sent_tokenize

In [None]:
from huggingface_hub import hf_hub_download

In [None]:
# Hugging Face Hub dataset repository and the parquet file backing each split.
repo_id = "d0rj/wikisum"
splits = {
    'train': 'data/train-00000-of-00001-b28959cff7dcaf55.parquet',
    'validation': 'data/validation-00000-of-00001-21f2c9acfa77bab4.parquet',
    'test': 'data/test-00000-of-00001-52a8a7cd640a9fff.parquet',
}

In [None]:
# Fetch each split's parquet file from the Hub (train, validation, test, in
# that order) and read it into a pandas DataFrame.
dataset_train, dataset_val, dataset_test = (
    pd.read_parquet(
        hf_hub_download(
            repo_id=repo_id,
            repo_type="dataset",
            filename=splits[split_name],
        )
    )
    for split_name in ("train", "validation", "test")
)

In [None]:
# Convert each pandas split into a Hugging Face Dataset.
# The names are rebound in place, so the isinstance guard keeps this cell safe
# to re-run: a split that has already been converted is passed through instead
# of being handed to Dataset.from_pandas a second time.
dataset_train, dataset_val, dataset_test = (
    split if isinstance(split, Dataset) else Dataset.from_pandas(split)
    for split in (dataset_train, dataset_val, dataset_test)
)

In [None]:
# Tokenizer, config and weights are all loaded from the same checkpoint.
PRIMERA_CHECKPOINT = 'allenai/PRIMERA'

tokenizer = AutoTokenizer.from_pretrained(PRIMERA_CHECKPOINT)
config = LEDConfig.from_pretrained(PRIMERA_CHECKPOINT)  # not referenced by the cells shown here

model = LEDForConditionalGeneration.from_pretrained(PRIMERA_CHECKPOINT)
model.gradient_checkpointing_enable()

# Special-token ids used when building inputs and attention masks.
DOCSEP_TOKEN_ID = tokenizer.convert_tokens_to_ids("<doc-sep>")
PAD_TOKEN_ID = tokenizer.pad_token_id

In [None]:
def process_document(documents, max_input_length=4096):
    """Tokenise a batch of articles into one padded ``input_ids`` tensor.

    Each article is whitespace-normalised, encoded with the module-level
    ``tokenizer``, followed by a single ``<doc-sep>`` token and wrapped in
    BOS/EOS. The per-article sequences are then right-padded with
    ``PAD_TOKEN_ID`` to the length of the longest one.

    Args:
        documents: Iterable of article strings.
        max_input_length: Upper bound on the final length of each sequence,
            counting BOS, ``<doc-sep>`` and EOS. Defaults to 4096, the limit
            the notebook already passed to the tokenizer.

    Returns:
        A 2-D ``torch.Tensor`` of token ids, one row per article.
    """
    input_ids_all = []

    for data in documents:
        # str.split() with no argument splits on any whitespace run (newlines
        # included), so this collapses all whitespace to single spaces.
        cleaned_text = " ".join(data.split())

        # encode() adds its own BOS/EOS and counts them towards max_length;
        # they are stripped with [1:-1] and re-added below around <doc-sep>.
        # Passing max_input_length - 1 reserves the slot for <doc-sep>, so a
        # truncated article ends up at exactly max_input_length tokens.
        # Previously it came out one token longer (4097 for the default),
        # which may exceed the encoder's position limit — confirm against the
        # checkpoint's max_encoder_position_embeddings.
        content_ids = tokenizer.encode(
            cleaned_text,
            truncation=True,
            max_length=max_input_length - 1,
        )[1:-1]

        input_ids = (
            [tokenizer.bos_token_id]
            + content_ids
            + [DOCSEP_TOKEN_ID, tokenizer.eos_token_id]
        )
        input_ids_all.append(torch.tensor(input_ids))

    return torch.nn.utils.rnn.pad_sequence(
        input_ids_all, batch_first=True, padding_value=PAD_TOKEN_ID
    )

In [None]:
def bert_apply(generated, reference_summary, top_n=4):
    """Condense a generated summary to its sentences closest to the reference.

    The generated text is split into sentences, each sentence is scored
    against the full reference summary with BERTScore, and the ``top_n``
    sentences with the highest F1 are joined with single spaces.

    Note that the sentences are emitted in descending-F1 order, not in the
    order they appeared in ``generated``.

    Args:
        generated: Generated summary text.
        reference_summary: Ground-truth summary each sentence is scored against.
        top_n: Maximum number of sentences to keep (default 4).

    Returns:
        The selected sentences joined by spaces, or ``""`` when ``generated``
        contains no sentences.
    """
    # Nothing to rank: return early instead of handing empty candidate and
    # reference lists to the scorer.
    if not generated or not generated.strip():
        return ""

    summary_sents = sent_tokenize(generated)
    if not summary_sents:
        return ""

    # Create a list of the same reference summary, repeated for each candidate sentence
    reference = [reference_summary] * len(summary_sents)

    # Score sentences; only F1 is used for ranking.
    _, _, F1 = score(summary_sents, reference, lang="en", verbose=True)

    # Rank on plain floats and select the top N; sorted() is stable, so ties
    # keep their original relative order.
    ranked = sorted(
        zip(summary_sents, F1.tolist()),
        key=lambda pair: pair[1],
        reverse=True,
    )
    return " ".join(sent for sent, _ in ranked[:top_n])

In [None]:
def batch_process(batch):
    """Summarise a batch of articles and post-filter the output with BERTScore.

    Intended for ``Dataset.map(..., batched=True)``.

    Args:
        batch: Mapping with an ``'article'`` list (input texts) and a
            ``'summary'`` list (ground-truth summaries) of equal length.

    Returns:
        Dict with ``'generated_summaries'`` (the generated text after
        ``bert_apply``) and ``'gt_summaries'`` (the reference summaries,
        passed through unchanged).
    """
    # Put the inputs on the model's device so the same code runs whether the
    # model sits on CPU or has been moved to an accelerator.
    input_ids = process_document(batch['article']).to(model.device)

    # Mask out the right-padding explicitly rather than leaving it to
    # generate() to infer.
    attention_mask = (input_ids != PAD_TOKEN_ID).long()

    # Global attention on the leading <s> token and on every <doc-sep> token.
    global_attention_mask = torch.zeros_like(input_ids)
    global_attention_mask[:, 0] = 1
    global_attention_mask[input_ids == DOCSEP_TOKEN_ID] = 1

    generated_ids = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        global_attention_mask=global_attention_mask,
        use_cache=True,
        max_length=512,
        min_length=50,
        num_beams=5,
        length_penalty=2,
        no_repeat_ngram_size=3,
        early_stopping=True
    )
    generated_str = tokenizer.batch_decode(
        generated_ids.tolist(), skip_special_tokens=True
    )

    # Keep only the sentences of each summary that score best against its reference.
    generated_str_bert = [
        bert_apply(generated, ref)
        for generated, ref in zip(generated_str, batch['summary'])
    ]

    return {
        'generated_summaries': generated_str_bert,
        'gt_summaries': batch['summary'],
    }

In [None]:
import random

# Use a dedicated, seeded generator so Restart & Run All evaluates the same
# five test articles every time, without touching the global random state.
SAMPLE_SEED = 42
docs = random.Random(SAMPLE_SEED).sample(range(len(dataset_test)), k=5)
docs

In [None]:
# Restrict the test split to the sampled rows, then run the summarisation
# pipeline over them two articles at a time.
dataset_small = dataset_test.select(docs)
result_small = dataset_small.map(
    batch_process,
    batched=True,
    batch_size=2,
)

In [None]:
import evaluate

# Load the "rouge" metric through the evaluate library; it is used below to
# score the generated summaries against the references.
rouge = evaluate.load("rouge")

In [None]:
# Score the generated summaries against the ground-truth summaries.
rouge_scores = rouge.compute(
    predictions=result_small["generated_summaries"],
    references=result_small["gt_summaries"],
)

# Report each metric to four decimal places.
print(f"ROUGE-1: {rouge_scores['rouge1']:.4f}")
print(f"ROUGE-2: {rouge_scores['rouge2']:.4f}")
print(f"ROUGE-L: {rouge_scores['rougeL']:.4f}")

In [None]:
# Show each reference summary next to the summary produced for it.
# Iterating over the datasets themselves (instead of a hard-coded range of 5)
# keeps this cell correct if the sample size is changed.
for reference_row, generated_row in zip(dataset_small, result_small):
    display(reference_row['summary'])
    display(generated_row['generated_summaries'])
    print("\n")
    print("\n")