In [22]:
!pip install rouge_score

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [21]:
!pip install evaluate

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


# Multi XScience

In [1]:
import re

from datasets import load_dataset, load_metric
import evaluate
import nltk
import nltk.data
import numpy as np
import torch
import torch.nn.functional as F
from tqdm.auto import tqdm
from transformers import AdamW, AutoTokenizer, AutoModelForSeq2SeqLM

In [2]:
# Report whether PyTorch's MPS backend is available on this machine.
# NOTE(review): no visible cell moves the model or the inputs to a device,
# so this check is informational only.
torch.backends.mps.is_available()

True

In [3]:
# Fetch the Punkt data that the sentence tokenizer in the baseline section loads.
nltk.download('punkt')

[nltk_data] Downloading package punkt to /Users/luka/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [40]:
# Dataset name passed to load_dataset.
DATASET_NAME = "multi_x_science_sum"
# Separator written between the paper's own abstract and each cited abstract;
# the same string is later registered as a tokenizer token.
DOC_SEP = " ||||| "
# Used both for the preprocessing map() call and for the generation loop.
BATCH_SIZE = 64

## Set up evaluation

In [5]:
# ROUGE metric object used to score both the baseline and the model output.
# NOTE(review): `load_metric` appears to be deprecated in favour of
# `evaluate.load` (the `evaluate` package is already imported) — confirm before
# upgrading `datasets`; the returned score format may differ.
rouge = load_metric("rouge")

  rouge = load_metric("rouge")


Downloading builder script:   0%|          | 0.00/2.17k [00:00<?, ?B/s]

## Load dataset

In [6]:
# Load all splits of the dataset named by DATASET_NAME.
dataset = load_dataset(DATASET_NAME)

Found cached dataset multi_x_science_sum (/Users/luka/.cache/huggingface/datasets/multi_x_science_sum/default/1.1.0/2876ec0401f8f5c5acf7f4857dbc8d6229a390ab428321ab848f03f14b7f9729)


  0%|          | 0/3 [00:00<?, ?it/s]

In [7]:
# Show the available splits with their columns and row counts.
dataset

DatasetDict({
    train: Dataset({
        features: ['aid', 'mid', 'abstract', 'related_work', 'ref_abstract'],
        num_rows: 30369
    })
    test: Dataset({
        features: ['aid', 'mid', 'abstract', 'related_work', 'ref_abstract'],
        num_rows: 5093
    })
    validation: Dataset({
        features: ['aid', 'mid', 'abstract', 'related_work', 'ref_abstract'],
        num_rows: 5066
    })
})

In [8]:
# Peek at one raw training example to see the field layout.
dataset["train"][0]

{'aid': 'math9912167',
 'mid': '1631980677',
 'abstract': 'Author(s): Kuperberg, Greg; Thurston, Dylan P. | Abstract: We give a purely topological definition of the perturbative quantum invariants of links and 3-manifolds associated with Chern-Simons field theory. Our definition is as close as possible to one given by Kontsevich. We will also establish some basic properties of these invariants, in particular that they are universally finite type with respect to algebraically split surgery and with respect to Torelli surgery. Torelli surgery is a mutual generalization of blink surgery of Garoufalidis and Levine and clasper surgery of Habiro.',
 'related_work': 'Two other generalizations that can be considered are invariants of graphs in 3-manifolds, and invariants associated to other flat connections @cite_16 . We will analyze these in future work. Among other things, there should be a general relation between flat bundles and links in 3-manifolds on the one hand and finite covers and b

## Format dataset to our needs

In [9]:
# Matches numbered citation markers such as "@cite_16" in the related-work text.
pat = re.compile("@cite_[0-9]+")

In [10]:
def preprocess_dataset(example):
    """Turn one raw example into a source/target text pair.

    Args:
        example: A single dataset row with the keys ``abstract``,
            ``ref_abstract`` (a dict whose ``abstract`` entry lists the cited
            papers' abstracts) and ``related_work``.

    Returns:
        A dict with two strings:
        ``abstracts`` - the paper's own abstract, then ``DOC_SEP``, then every
        non-empty cited abstract joined by ``DOC_SEP``;
        ``related_work`` - the target text with numbered ``@cite_N`` markers
        collapsed to a plain ``@cite``.
    """
    # Keep only what follows the last "| Abstract: " marker; text without the
    # marker is kept whole.
    own_abstract = example["abstract"].split("| Abstract: ")[-1]
    cited_abstracts = [text for text in example["ref_abstract"]["abstract"] if text]

    return {
        "abstracts": f"{own_abstract}{DOC_SEP}{DOC_SEP.join(cited_abstracts)}",
        "related_work": pat.sub("@cite", example["related_work"]),
    }

In [11]:
def preprocess_dataset_batched(example):
    """Batched variant of the preprocessing, for ``map(..., batched=True)``.

    Args:
        example: A batch given as a dict of lists with the keys ``abstract``,
            ``ref_abstract`` (one dict per row, each with an ``abstract`` list)
            and ``related_work``.

    Returns:
        A dict of lists:
        ``abstracts`` - per row, the paper's own abstract, then ``DOC_SEP``,
        then every non-empty cited abstract joined by ``DOC_SEP``;
        ``related_work`` - per row, the target text with numbered ``@cite_N``
        markers collapsed to a plain ``@cite``.
    """
    abstracts = [
        # Keep only what follows the last "| Abstract: " marker.
        own_abstract.split("| Abstract: ")[-1]
        + DOC_SEP
        + DOC_SEP.join(filter(None, cited["abstract"]))
        for own_abstract, cited in zip(example["abstract"], example["ref_abstract"])
    ]
    related_work = [pat.sub("@cite", text) for text in example["related_work"]]

    return {"abstracts": abstracts, "related_work": related_work}

In [12]:
# Run the batched preprocessing over every split and drop the raw columns, so
# each processed split only carries "abstracts" and "related_work".
# (preprocess_dataset is the per-example variant of the same transformation.)
dataset_processed = {
    split: split_data.map(
        preprocess_dataset_batched,
        remove_columns=split_data.column_names,
        batched=True,
        batch_size=BATCH_SIZE,
    )
    for split, split_data in dataset.items()
}

Loading cached processed dataset at /Users/luka/.cache/huggingface/datasets/multi_x_science_sum/default/1.1.0/2876ec0401f8f5c5acf7f4857dbc8d6229a390ab428321ab848f03f14b7f9729/cache-f3232cb45e6040e6.arrow
Loading cached processed dataset at /Users/luka/.cache/huggingface/datasets/multi_x_science_sum/default/1.1.0/2876ec0401f8f5c5acf7f4857dbc8d6229a390ab428321ab848f03f14b7f9729/cache-da1a80df5f4bf131.arrow
Loading cached processed dataset at /Users/luka/.cache/huggingface/datasets/multi_x_science_sum/default/1.1.0/2876ec0401f8f5c5acf7f4857dbc8d6229a390ab428321ab848f03f14b7f9729/cache-238661e6a318eda9.arrow


## Baseline

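Lead-sentence baseline: take the first 2 sentences of every abstract in each Multi-XScience example (the paper's own abstract and each cited abstract) and concatenate them.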

In [13]:
# Columns and size of the preprocessed test split.
dataset_processed["test"]

Dataset({
    features: ['related_work', 'abstracts'],
    num_rows: 5093
})

In [14]:
# English Punkt tokenizer; its tokenize() is used to split each abstract for the baseline.
punkt_tokenizer = nltk.data.load('tokenizers/punkt/english.pickle')

In [15]:
# Lead-sentence baseline: for every test example, keep the first two
# Punkt-tokenized sentences of each source document and join everything with
# single spaces.
test_pred_baselines = []
for element in dataset_processed["test"]:
    lead_sentences = []
    for each in element["abstracts"].split("|||||"):
        lead_sentences.append(" ".join(punkt_tokenizer.tokenize(each)[:2]).strip())
    test_pred_baselines.append(" ".join(lead_sentences).strip())

In [16]:
# Score the lead-sentence baseline against the reference related-work sections.
scores = rouge.compute(
    predictions=test_pred_baselines,
    references=dataset_processed["test"]["related_work"],
    use_stemmer=True,
)
# Display the ROUGE result as the cell output.
scores

{'rouge1': AggregateScore(low=Score(precision=0.24482187110610437, recall=0.4231868057438022, fmeasure=0.2903145549267753), mid=Score(precision=0.24769720463457645, recall=0.4264784229787808, fmeasure=0.29245088738586544), high=Score(precision=0.25059692375861276, recall=0.42951674345918534, fmeasure=0.2946598249726346)),
 'rouge2': AggregateScore(low=Score(precision=0.04405391930886794, recall=0.07970571327663246, fmeasure=0.053025462922424035), mid=Score(precision=0.04500946103108484, recall=0.08141951669959666, fmeasure=0.05401275674974608), high=Score(precision=0.04609414250738718, recall=0.08307551240628612, fmeasure=0.05517160063438503)),
 'rougeL': AggregateScore(low=Score(precision=0.12238852974260858, recall=0.21482305928198114, fmeasure=0.1453619743002698), mid=Score(precision=0.12366100543098785, recall=0.2167294062130687, fmeasure=0.14639122094115153), high=Score(precision=0.12497420762501354, recall=0.21871095600126428, fmeasure=0.14746105556552414)),
 'rougeLsum': Aggrega

## Model 1: Default Centrum

* **TODO:** Measure the token-length distribution of `dataset_processed["test"]["abstracts"]` so that we can choose a suitable `max_length` (the tokenizer call below currently truncates at 1024 tokens).

In [17]:
# Checkpoint name handed to both from_pretrained calls below.
CHECKPOINT = "ratishsp/Centrum"

In [18]:
# Load the tokenizer and the seq2seq model for CHECKPOINT.
tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)
model = AutoModelForSeq2SeqLM.from_pretrained(CHECKPOINT)

In [19]:
# Register the document separator as a special token, grow the embedding
# matrix to the new vocabulary size, and remember the separator's id for the
# global-attention mask.
# NOTE(review): a token added here gets its embedding row from the resize, not
# from the checkpoint. Check whether the checkpoint's tokenizer already ships a
# document-separator token (PRIMERA-style checkpoints presumably use
# "<doc-sep>") that should be used instead — TODO confirm.
tokenizer.add_tokens(DOC_SEP, special_tokens=True)
model.resize_token_embeddings(len(tokenizer))
docsep_token_id = tokenizer.convert_tokens_to_ids(DOC_SEP)

In [20]:
# Only the test split is tokenized; it is the only split used for generation.
dataset_tokenized = {}

# All test inputs are tokenized in one call and returned as PyTorch tensors,
# with padding enabled and truncation at 1024 tokens.
dataset_tokenized["test"] = tokenizer(
    dataset_processed["test"]["abstracts"],
    padding=True,
    truncation=True,
    max_length=1024,
    return_tensors="pt",
)

In [21]:
# Global attention goes on the tokenizer's CLS token and on every document
# separator; every other position gets 0.
# The mask is built with vectorised tensor comparisons so that it is a torch
# long tensor like `input_ids` / `attention_mask` (the tokenizer was called
# with return_tensors="pt"), instead of looping over each token in Python and
# storing a NumPy array.
test_input_ids = dataset_tokenized["test"]["input_ids"]
dataset_tokenized["test"]["global_attention_mask"] = (
    (test_input_ids == tokenizer.cls_token_id)
    | (test_input_ids == docsep_token_id)
).long()

In [41]:
# Generate summaries for the test split, BATCH_SIZE examples at a time.
led_output_model1 = []
for i in tqdm(range(0, len(dataset_tokenized["test"]["input_ids"]), BATCH_SIZE)):

    input_ids = dataset_tokenized["test"]["input_ids"][i:i+BATCH_SIZE]
    attention_mask = dataset_tokenized["test"]["attention_mask"][i:i+BATCH_SIZE]
    # as_tensor is a no-op for a tensor and converts a NumPy mask, so the model
    # always receives a torch tensor.
    global_attention_mask = torch.as_tensor(
        dataset_tokenized["test"]["global_attention_mask"][i:i+BATCH_SIZE]
    )

    led_output_model1.append(
        model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            global_attention_mask=global_attention_mask,
            no_repeat_ngram_size=3,
            max_length=128,
            num_beams=4,
        )
    )

# generate() returns each batch only as wide as its longest generated sequence,
# so batch widths can differ and a plain torch.cat would fail. Right-pad every
# batch with the pad token to a common width before concatenating; the pads are
# special tokens and are dropped when decoding with skip_special_tokens=True.
longest_output = max(batch_output.shape[1] for batch_output in led_output_model1)
led_output_model1 = torch.cat(
    [
        F.pad(
            batch_output,
            (0, longest_output - batch_output.shape[1]),
            value=tokenizer.pad_token_id,
        )
        for batch_output in led_output_model1
    ]
)

100%|████████████████████████████████████████████████████████████████████████████████████| 80/80 [3:11:58<00:00, 143.98s/it]


In [42]:
# Decode the generated token ids back to text, dropping special tokens.
test_pred_model1 = tokenizer.batch_decode(
    led_output_model1,
    skip_special_tokens=True,
    clean_up_tokenization_spaces=True,
)

In [43]:
# Remove leading/trailing whitespace from every decoded prediction.
test_pred_model1 = [each.strip() for each in test_pred_model1]

In [44]:
# Spot-check the first two generated summaries.
test_pred_model1[:2]

["The long-term goal of our field is the creation and understanding of intelligence. Productive research in AI, both practical and theoretical, benefits from a notion of intelligence that is precise enough to allow the cumulative development of robust systems and general results. This paper outlines a gradual evolution in our formal conception of intelligence, that brings it closer to our informal conception and simultaneously reduces the gap between theory and practice. The article presents experimental results illustrating the agents' dynamic behavior. I. Introduction, 488. — II. The model with automobiles as an example, 489. — III. Examples and applications, 492. — IV.",
 '“Interaction in virtual reality (VR) environments (e.g. grasping and manipulating virtual objects) is essential to ensure a pleasant and immersive experience. In this work, we propose a visually realistic, flexible and robust grasping system that enables real-time interactions in virtual environments. Resulting gr

In [45]:
# Score the model's predictions against the reference related-work sections
# (same metric settings as the baseline; this rebinds `scores`).
scores = rouge.compute(
    predictions=test_pred_model1,
    references=dataset_processed["test"]["related_work"],
    use_stemmer=True,
)
# Display the ROUGE result as the cell output.
scores

{'rouge1': AggregateScore(low=Score(precision=0.2982484009370646, recall=0.3057877687297647, fmeasure=0.2877549383290699), mid=Score(precision=0.3011000320975565, recall=0.308186529440017, fmeasure=0.28974742885739196), high=Score(precision=0.30423223116720244, recall=0.31058012470155627, fmeasure=0.2917896849338474)),
 'rouge2': AggregateScore(low=Score(precision=0.046860562392848866, recall=0.04742509910853629, fmeasure=0.04476525721466253), mid=Score(precision=0.04805986457733996, recall=0.048611032584090635, fmeasure=0.0458330562332894), high=Score(precision=0.04916571328652274, recall=0.0498130873004002, fmeasure=0.046871122190935706)),
 'rougeL': AggregateScore(low=Score(precision=0.15540080811555254, recall=0.1636225890270709, fmeasure=0.15142010390078578), mid=Score(precision=0.15687893432752667, recall=0.16523845626801453, fmeasure=0.1524776573706927), high=Score(precision=0.15834376408384634, recall=0.16680563034698537, fmeasure=0.15364196168214198)),
 'rougeLsum': AggregateS