<a href="https://colab.research.google.com/github/NikitasTsingenopoulos/Multi-XSience/blob/main/Multi_XScience.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install -U datasets
!pip install -U evaluate
!pip install rouge_score

In [None]:
from transformers import (AutoTokenizer, LEDConfig, LEDForConditionalGeneration)
from datasets import load_dataset,Dataset, DatasetDict
import torch

In [None]:
!git clone https://github.com/yaolu/Multi-XScience.git

In [None]:
!gunzip /content/Multi-XScience/data/test.json.gz
!gunzip /content/Multi-XScience/data/train.json.gz
!gunzip /content/Multi-XScience/data/val.json.gz

In [None]:
import json

# Locations of the decompressed Multi-XScience splits.
train_path = '/content/Multi-XScience/data/train.json'
val_path = '/content/Multi-XScience/data/val.json'
test_path = '/content/Multi-XScience/data/test.json'


def _read_json(path):
    """Parse the UTF-8 encoded JSON file at `path` and return the decoded object."""
    with open(path, 'r', encoding='utf-8') as handle:
        return json.load(handle)


# Each split is parsed fully into memory.
dataset_train = _read_json(train_path)
dataset_val = _read_json(val_path)
dataset_test = _read_json(test_path)

In [None]:
# Checkpoint shared by the tokenizer, the config and the model weights.
MODEL_NAME = 'allenai/PRIMERA'

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Loaded for inspection; the model call below does not receive this object.
config = LEDConfig.from_pretrained(MODEL_NAME)
model = LEDForConditionalGeneration.from_pretrained(MODEL_NAME)
model.gradient_checkpointing_enable()

# Token ids reused by the preprocessing and generation cells.
PAD_TOKEN_ID = tokenizer.pad_token_id
DOCSEP_TOKEN_ID = tokenizer.convert_tokens_to_ids("<doc-sep>")

In [None]:
def process_document(documents):
    """Tokenize a list of texts into one right-padded batch of input ids.

    Args:
        documents: iterable of strings, one per example.

    Returns:
        A 2-D tensor (batch first) of token ids. Each row is
        ``bos + tokens + eos``; shorter rows are filled with ``PAD_TOKEN_ID``.
    """
    encoded_rows = []
    for raw_text in documents:
        # Collapse newlines and any run of whitespace into single spaces.
        cleaned = " ".join(raw_text.replace("\n", " ").split())
        token_ids = tokenizer.encode(cleaned, truncation=True, max_length=4096)
        # Drop the first and last id returned by the tokenizer, then wrap the
        # remainder in explicit bos/eos ids.
        wrapped = [tokenizer.bos_token_id, *token_ids[1:-1], tokenizer.eos_token_id]
        encoded_rows.append(torch.tensor(wrapped))

    return torch.nn.utils.rnn.pad_sequence(
        encoded_rows, batch_first=True, padding_value=PAD_TOKEN_ID
    )

In [None]:
def batch_process(batch):
    """Generate a summary for every example in a batch.

    Written as a callback for ``Dataset.map(..., batched=True)``.

    Args:
        batch: mapping from column name to a list of per-example values.
            Reads ``batch["abstract"]``, ``batch["ref_abstract"]`` (per example
            a dict whose values are dicts carrying an ``"abstract"`` entry) and
            ``batch["related_work"]``.

    Returns:
        dict with two lists of equal length:
        ``"generated_summaries"`` (decoded model outputs) and
        ``"gt_summaries"`` (``batch["related_work"]``, passed through).
    """
    # Separator placed between the documents of one example. This relies on the
    # tokenizer mapping the literal "<doc-sep>" to DOCSEP_TOKEN_ID; without a
    # separator in the text the global-attention assignment below on
    # DOCSEP_TOKEN_ID never matches any position.
    doc_separator = " <doc-sep> "

    inputs = []
    for abstract, ref_dict in zip(batch["abstract"], batch["ref_abstract"]):
        documents = [abstract]
        for inner in (ref_dict or {}).values():
            # Skip citations that are missing or whose abstract is None/empty;
            # a None abstract would otherwise break the string join.
            if isinstance(inner, dict) and inner.get("abstract"):
                documents.append(inner["abstract"])
        inputs.append(doc_separator.join(documents))

    # Tokenize and pad, then keep the ids on the same device as the model.
    input_ids = process_document(inputs).to(model.device)

    # Mask out padding explicitly instead of leaving it to be inferred.
    attention_mask = (input_ids != PAD_TOKEN_ID).long()

    # Global attention on the first token (<s>) and on every document separator.
    global_attention_mask = torch.zeros_like(input_ids)
    global_attention_mask[:, 0] = 1
    global_attention_mask[input_ids == DOCSEP_TOKEN_ID] = 1

    generated_ids = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        global_attention_mask=global_attention_mask,
        use_cache=True,
        max_length=1024,
        num_beams=5,
    )
    generated_str = tokenizer.batch_decode(
        generated_ids.tolist(), skip_special_tokens=True
    )

    return {
        'generated_summaries': generated_str,
        # Reference text: the related-work paragraph of each example.
        'gt_summaries': batch['related_work'],
    }

In [None]:
import random

# Wrap the raw test records so they can be indexed and mapped over.
test_dataset = Dataset.from_list(dataset_test)

# Draw distinct row indices. random.sample picks without replacement
# (random.choices can return the same index more than once), and the dedicated
# seeded generator makes the selection repeatable across kernel restarts.
SAMPLE_SIZE = 5
SAMPLE_SEED = 42
docs = random.Random(SAMPLE_SEED).sample(
    range(len(dataset_test)), k=min(SAMPLE_SIZE, len(dataset_test))
)
docs

In [None]:
# Restrict the test set to the sampled rows, then run generation over them
# two examples per call.
dataset_small = test_dataset.select(docs)
result_small = dataset_small.map(
    batch_process,
    batched=True,
    batch_size=2,
)

In [None]:
import evaluate

# Load the ROUGE metric through `evaluate`; used below to score the generated summaries.
rouge = evaluate.load("rouge")

In [None]:
import textwrap

# Print every generated summary wrapped at 160 columns, followed by blank lines.
for generated_text in result_small['generated_summaries']:
    print(textwrap.fill(generated_text, width=160))
    print("\n")

In [None]:
# Score the generated summaries against the reference texts.
score = rouge.compute(
    predictions=result_small["generated_summaries"],
    references=result_small["gt_summaries"],
)

# Report the three ROUGE variants to four decimal places.
for label, metric_key in (("ROUGE-1", "rouge1"), ("ROUGE-2", "rouge2"), ("ROUGE-L", "rougeL")):
    print(f"{label}: {score[metric_key]:.4f}")

In [None]:
# Inspect the record at position 4 of the sampled subset.
example = dataset_small[4]

display(example)
# NOTE(review): `example_summ` is not defined in the visible cells, so the call below stays disabled.
#display(example_summ)