# Multi XScience

In [1]:
import re

from datasets import load_dataset, load_metric
import evaluate
import nltk
import nltk.data
import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm
from transformers import (
    AdamW, AutoTokenizer, AutoModelForSeq2SeqLM, Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
)

In [2]:
torch.backends.mps.is_available()

True

In [3]:
nltk.download('punkt')

[nltk_data] Downloading package punkt to /Users/luka/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [4]:
# Name of the dataset handed to `load_dataset`.
DATASET_NAME = "multi_x_science_sum"
# Separator placed between the source abstract and each cited abstract;
# it is later registered with the tokenizer as a special token.
DOC_SEP = " ||||| "
# Batch size shared by the `Dataset.map` calls and the generation loops.
BATCH_SIZE = 64
# Maximum number of tokens kept for the encoder input.
MAX_LENGTH_ENC = 512  #4096
# Maximum number of tokens kept for the target (labels) sequence.
MAX_LENGTH_DEC = 256
# N = 16

## Set up evaluation

In [5]:
# NOTE(review): this call emits a warning (see the cell output); `load_metric`
# is presumably deprecated in favour of `evaluate.load("rouge")`, and `evaluate`
# is already imported. Confirm the returned score format before switching.
rouge = load_metric("rouge")

  rouge = load_metric("rouge")


## Load dataset

In [6]:
dataset = load_dataset(DATASET_NAME)

Found cached dataset multi_x_science_sum (/Users/luka/.cache/huggingface/datasets/multi_x_science_sum/default/1.1.0/2876ec0401f8f5c5acf7f4857dbc8d6229a390ab428321ab848f03f14b7f9729)


  0%|          | 0/3 [00:00<?, ?it/s]

In [7]:
dataset

DatasetDict({
    train: Dataset({
        features: ['aid', 'mid', 'abstract', 'related_work', 'ref_abstract'],
        num_rows: 30369
    })
    test: Dataset({
        features: ['aid', 'mid', 'abstract', 'related_work', 'ref_abstract'],
        num_rows: 5093
    })
    validation: Dataset({
        features: ['aid', 'mid', 'abstract', 'related_work', 'ref_abstract'],
        num_rows: 5066
    })
})

In [8]:
dataset["train"][0]

{'aid': 'math9912167',
 'mid': '1631980677',
 'abstract': 'Author(s): Kuperberg, Greg; Thurston, Dylan P. | Abstract: We give a purely topological definition of the perturbative quantum invariants of links and 3-manifolds associated with Chern-Simons field theory. Our definition is as close as possible to one given by Kontsevich. We will also establish some basic properties of these invariants, in particular that they are universally finite type with respect to algebraically split surgery and with respect to Torelli surgery. Torelli surgery is a mutual generalization of blink surgery of Garoufalidis and Levine and clasper surgery of Habiro.',
 'related_work': 'Two other generalizations that can be considered are invariants of graphs in 3-manifolds, and invariants associated to other flat connections @cite_16 . We will analyze these in future work. Among other things, there should be a general relation between flat bundles and links in 3-manifolds on the one hand and finite covers and b

## Format dataset to our needs

In [9]:
pat = re.compile("@cite_[0-9]+")

In [10]:
def preprocess_dataset_batched(example):
    """Turn one batch of raw rows into model input / target text.

    Args:
        example: Batched mapping with the columns "abstract" (list of str),
            "ref_abstract" (list of dicts holding an "abstract" list) and
            "related_work" (list of str).

    Returns:
        dict with two equally long lists:
            "abstracts": the source abstract (text after the last
                "| Abstract: " marker) followed by DOC_SEP and the non-empty
                cited abstracts, themselves joined with DOC_SEP.
            "related_work": the target text with every numbered citation
                marker ("@cite_<n>") collapsed to plain "@cite".
    """
    merged_abstracts = []
    for main_abstract, references in zip(
        example["abstract"], example["ref_abstract"]
    ):
        # Keep only the text after the author/abstract marker, if present.
        body = main_abstract.split("| Abstract: ")[-1]
        # Empty cited abstracts are skipped.
        cited = DOC_SEP.join(text for text in references["abstract"] if text)
        merged_abstracts.append(body + DOC_SEP + cited)

    targets = [pat.sub("@cite", text) for text in example["related_work"]]

    return {"abstracts": merged_abstracts, "related_work": targets}

In [11]:
# Run the batched text preprocessing over every split; the raw columns are
# dropped so only the columns produced by the mapper remain.
dataset_processed = dataset.map(
    preprocess_dataset_batched,
    batched=True,
    batch_size=BATCH_SIZE,
    remove_columns=dataset["train"].column_names,
)

Loading cached processed dataset at /Users/luka/.cache/huggingface/datasets/multi_x_science_sum/default/1.1.0/2876ec0401f8f5c5acf7f4857dbc8d6229a390ab428321ab848f03f14b7f9729/cache-f3232cb45e6040e6.arrow
Loading cached processed dataset at /Users/luka/.cache/huggingface/datasets/multi_x_science_sum/default/1.1.0/2876ec0401f8f5c5acf7f4857dbc8d6229a390ab428321ab848f03f14b7f9729/cache-da1a80df5f4bf131.arrow
Loading cached processed dataset at /Users/luka/.cache/huggingface/datasets/multi_x_science_sum/default/1.1.0/2876ec0401f8f5c5acf7f4857dbc8d6229a390ab428321ab848f03f14b7f9729/cache-238661e6a318eda9.arrow


In [12]:
dataset_processed

DatasetDict({
    train: Dataset({
        features: ['related_work', 'abstracts'],
        num_rows: 30369
    })
    test: Dataset({
        features: ['related_work', 'abstracts'],
        num_rows: 5093
    })
    validation: Dataset({
        features: ['related_work', 'abstracts'],
        num_rows: 5066
    })
})

## Model 2: Fine-tune Centrum

In [20]:
CHECKPOINT = "ratishsp/Centrum"
CHECKPOINT_MODEL = "./checkpoint_512/checkpoint-950"

In [21]:
# The tokenizer comes from the base CHECKPOINT identifier, while the model
# weights are loaded from the CHECKPOINT_MODEL path.
tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)
model = AutoModelForSeq2SeqLM.from_pretrained(CHECKPOINT_MODEL)

In [22]:
# Register the document separator as a special token of the tokenizer.
tokenizer.add_tokens(DOC_SEP, special_tokens=True)
# Resize the model's token embeddings to the tokenizer's current vocabulary size.
model.resize_token_embeddings(len(tokenizer))
# Id of the separator token; used when building the global attention mask.
docsep_token_id = tokenizer.convert_tokens_to_ids(DOC_SEP)

In [23]:
def tokenize_dataset_batched(example):
    """Tokenize one batch of preprocessed rows for seq2seq training/evaluation.

    Args:
        example: Batched mapping with the text columns "abstracts" (encoder
            input) and "related_work" (target).

    Returns:
        The tokenizer output for the encoder input ("input_ids",
        "attention_mask"), extended with:
            "labels": target token ids padded/truncated to MAX_LENGTH_DEC,
                with padding positions replaced by -100.
            "global_attention_mask": float32 array shaped like "input_ids",
                1 where the token is the CLS token or the document separator
                and 0 elsewhere.
    """
    # Encoder input, padded/truncated to a fixed length.
    output = tokenizer(
        example["abstracts"],
        padding="max_length",
        truncation=True,
        max_length=MAX_LENGTH_ENC,
        return_tensors="pt",
    )

    # Decoder targets.
    labels = tokenizer(
        example["related_work"],
        padding="max_length",
        truncation=True,
        max_length=MAX_LENGTH_DEC,
        return_tensors="pt",
        return_attention_mask=False,
    ).input_ids
    # Padding must not contribute to the loss: torch ignores -100 in the loss
    # computation. Done as one tensor op instead of a per-token Python loop.
    output["labels"] = labels.masked_fill(
        labels == tokenizer.pad_token_id, -100
    )

    # Global attention on the CLS token and on every document separator,
    # computed element-wise on the whole batch at once.
    input_ids = output["input_ids"]
    is_global = (
        (input_ids == tokenizer.cls_token_id)
        | (input_ids == docsep_token_id)
    )
    output["global_attention_mask"] = is_global.numpy().astype(np.float32)

    return output

In [24]:
# Tokenize every split in batches; the text columns are dropped so only the
# columns returned by the tokenization function remain.
dataset_tokenized = dataset_processed.map(
    tokenize_dataset_batched,
    batched=True,
    batch_size=BATCH_SIZE,
    remove_columns=dataset_processed["train"].column_names,
)

Loading cached processed dataset at /Users/luka/.cache/huggingface/datasets/multi_x_science_sum/default/1.1.0/2876ec0401f8f5c5acf7f4857dbc8d6229a390ab428321ab848f03f14b7f9729/cache-9c485eb4fb0ff28b.arrow
Loading cached processed dataset at /Users/luka/.cache/huggingface/datasets/multi_x_science_sum/default/1.1.0/2876ec0401f8f5c5acf7f4857dbc8d6229a390ab428321ab848f03f14b7f9729/cache-b20b09fa3e8c6b66.arrow


Map:   0%|          | 0/5066 [00:00<?, ? examples/s]

In [25]:
dataset_tokenized

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'labels', 'global_attention_mask'],
        num_rows: 30369
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask', 'labels', 'global_attention_mask'],
        num_rows: 5093
    })
    validation: Dataset({
        features: ['input_ids', 'attention_mask', 'labels', 'global_attention_mask'],
        num_rows: 5066
    })
})

## Evaluation

## Evaluation with a sample of 1000 per split

In [54]:
# Sample 1000 for evaluation
# Draw a reproducible random subset of 1000 row indices from every split.
# One seeded generator is shared, so the splits are sampled in dict order.
rng = np.random.RandomState(0)

eval_idx = {
    split: rng.permutation(dataset_processed[split].num_rows)[:1000]
    for split in dataset_processed.keys()
}

In [55]:
# Generate predictions for the sampled rows of every split.
model_pred = {}

for split in dataset_tokenized.keys():
    model_pred[split] = []
    each_dataset = dataset_tokenized[split]
    # Pick the sampled rows once per split (order of eval_idx is preserved).
    # Converting whole columns to tensors inside the batch loop would re-read
    # the entire split on every iteration.
    eval_subset = each_dataset.select(eval_idx[split])

    for i in tqdm(range(0, len(eval_subset), BATCH_SIZE)):
        # Slicing a Dataset returns a dict mapping column name -> batch values.
        batch = eval_subset[i:i+BATCH_SIZE]

        input_ids = torch.as_tensor(batch["input_ids"])
        attention_mask = torch.as_tensor(batch["attention_mask"])
        global_attention_mask = torch.as_tensor(batch["global_attention_mask"])

        model_output = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            global_attention_mask=global_attention_mask,
            no_repeat_ngram_size=3,
            max_length=128,
            num_beams=4,
        )

        # Decode batch by batch so differing output lengths never need to be
        # concatenated.
        model_pred[split].extend(
            tokenizer.batch_decode(
                model_output,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
            )
        )

100%|█████████████████████████████████████████████████████████████████████| 16/16 [26:46<00:00, 100.43s/it]
100%|██████████████████████████████████████████████████████████████████████| 16/16 [24:56<00:00, 93.56s/it]
100%|██████████████████████████████████████████████████████████████████████| 16/16 [24:33<00:00, 92.07s/it]


In [56]:
model_pred["train"][:2]

['Neural representation learning @cite is a well-studied area in computer vision, and has been applied to many computer vision tasks, such as clustering, classification, visualization, and word sense disambiguation. In this paper, we propose a new approach for learning graph embeddings, which is based on structural measures of node similarities for generation of training data. The model learns nodes that are able to approximate a given measure such as the shortest path distance or any other. Our work differs from previous work in that our model learns low dimensional vectors to represent vertices appearing in a graph and, unlike existing work,',
 'In @cite, the authors propose a type inference algorithm for Datalog with negation. The algorithm is based on a type system where equalities are tracked, and present a type-based algorithm. The results show that it is optimal for object-oriented Datalogy without negation, in the sense that the inferred type is as tight as possible.']

In [58]:
# Score the sampled predictions of every split against their references.
for split in dataset_tokenized.keys():

    # Strip surrounding whitespace and store the cleaned predictions back.
    cleaned_predictions = list(map(str.strip, model_pred[split]))
    model_pred[split] = cleaned_predictions

    # An array allows fancy indexing with the sampled row indices.
    all_references = np.array(dataset_processed[split]["related_work"])

    scores = rouge.compute(
        predictions=cleaned_predictions,
        references=all_references[eval_idx[split]],
        use_stemmer=True,
    )
    print(split, scores, "", sep="\n")

train
{'rouge1': AggregateScore(low=Score(precision=0.34746422014300804, recall=0.273249424007993, fmeasure=0.28604606562659757), mid=Score(precision=0.3559469499811089, recall=0.27943006179438545, fmeasure=0.2904711339670817), high=Score(precision=0.36410203631150506, recall=0.286097831511261, fmeasure=0.2954159014835912)), 'rouge2': AggregateScore(low=Score(precision=0.05959426010148445, recall=0.04487425598194418, fmeasure=0.04742389772735309), mid=Score(precision=0.06308240138774394, recall=0.04758573403068603, fmeasure=0.04994374901741308), high=Score(precision=0.06657248003259754, recall=0.05063040323261355, fmeasure=0.052788689654841556)), 'rougeL': AggregateScore(low=Score(precision=0.19306343998654202, recall=0.15141867755368166, fmeasure=0.15794940024314422), mid=Score(precision=0.1980041613985281, recall=0.15521645582490146, fmeasure=0.16071709105215864), high=Score(precision=0.2029374466991714, recall=0.15935784407547057, fmeasure=0.1636477424420036)), 'rougeLsum': Aggregat

### Evaluation with a sample of 2000 per split

In [60]:
# Sample 1000 for evaluation
# Draw a reproducible random subset of 2000 row indices from every split.
# One seeded generator is shared, so the splits are sampled in dict order.
rng = np.random.RandomState(0)

eval_idx = {}
for split in dataset_processed.keys():
    eval_idx[split] = rng.permutation(dataset_processed[split].shape[0])[:2000]

In [61]:
# Generate predictions for the sampled rows of every split.
model_pred = {}

for split in dataset_tokenized.keys():
    model_pred[split] = []
    each_dataset = dataset_tokenized[split]
    # Pick the sampled rows once per split (order of eval_idx is preserved).
    # Converting whole columns to tensors inside the batch loop would re-read
    # the entire split on every iteration.
    eval_subset = each_dataset.select(eval_idx[split])

    for i in tqdm(range(0, len(eval_subset), BATCH_SIZE)):
        # Slicing a Dataset returns a dict mapping column name -> batch values.
        batch = eval_subset[i:i+BATCH_SIZE]

        input_ids = torch.as_tensor(batch["input_ids"])
        attention_mask = torch.as_tensor(batch["attention_mask"])
        global_attention_mask = torch.as_tensor(batch["global_attention_mask"])

        model_output = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            global_attention_mask=global_attention_mask,
            no_repeat_ngram_size=3,
            max_length=128,
            num_beams=4,
        )

        # Decode batch by batch so differing output lengths never need to be
        # concatenated.
        model_pred[split].extend(
            tokenizer.batch_decode(
                model_output,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
            )
        )

100%|██████████████████████████████████████████████████████████████████████| 32/32 [52:18<00:00, 98.07s/it]
100%|██████████████████████████████████████████████████████████████████████| 32/32 [48:18<00:00, 90.56s/it]
100%|██████████████████████████████████████████████████████████████████████| 32/32 [48:13<00:00, 90.41s/it]


In [62]:
# Score the sampled predictions of every split against their references.
for split in dataset_tokenized.keys():

    # Strip surrounding whitespace and store the cleaned predictions back.
    cleaned_predictions = list(map(str.strip, model_pred[split]))
    model_pred[split] = cleaned_predictions

    # An array allows fancy indexing with the sampled row indices.
    all_references = np.array(dataset_processed[split]["related_work"])

    scores = rouge.compute(
        predictions=cleaned_predictions,
        references=all_references[eval_idx[split]],
        use_stemmer=True,
    )
    print(split, scores, "", sep="\n")

train
{'rouge1': AggregateScore(low=Score(precision=0.35391011334189376, recall=0.2737611459771506, fmeasure=0.28916646809152835), mid=Score(precision=0.35940752474614257, recall=0.2783476730762118, fmeasure=0.2924012530012049), high=Score(precision=0.36499001737976594, recall=0.282727809774108, fmeasure=0.29584997001026225)), 'rouge2': AggregateScore(low=Score(precision=0.06142128512453346, recall=0.046026066629214435, fmeasure=0.048951201528433894), mid=Score(precision=0.06388302825783397, recall=0.04801530560445494, fmeasure=0.050789875545480236), high=Score(precision=0.06669488169547416, recall=0.05021880024980211, fmeasure=0.05288513509201786)), 'rougeL': AggregateScore(low=Score(precision=0.1953893435988529, recall=0.1513425921569353, fmeasure=0.15910107951683983), mid=Score(precision=0.19862920523131505, recall=0.15400884197487658, fmeasure=0.16101314096447117), high=Score(precision=0.2018729706985693, recall=0.15688351413995336, fmeasure=0.16308482736658378)), 'rougeLsum': Aggr

# Sandbox

In [76]:
%%time

# NOTE(review): `trainer` is not defined in any cell visible in this notebook,
# so this cell presumably relies on leftover kernel state - confirm or restore
# the cell that builds the trainer before re-running.
train_result = trainer.train()

***** Running training *****
  Num examples = 16
  Num Epochs = 2
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 8
  Gradient Accumulation steps = 1
  Total optimization steps = 4
  Number of trainable parameters = 152408832


Step,Training Loss,Validation Loss
2,2.2539,4.312594
4,1.969,4.312594


***** Running Evaluation *****
  Num examples = 16
  Batch size = 8
Saving model checkpoint to ./checkpoint-2
Configuration saved in ./checkpoint-2/config.json
Configuration saved in ./checkpoint-2/generation_config.json
Model weights saved in ./checkpoint-2/pytorch_model.bin
***** Running Evaluation *****
  Num examples = 16
  Batch size = 8
Saving model checkpoint to ./checkpoint-4
Configuration saved in ./checkpoint-4/config.json
Configuration saved in ./checkpoint-4/generation_config.json
Model weights saved in ./checkpoint-4/pytorch_model.bin


Training completed. Do not forget to share your model on huggingface.co/models =)




CPU times: user 9min 10s, sys: 7min 41s, total: 16min 52s
Wall time: 3min 58s


In [77]:
trainer.log_metrics("train", train_result.metrics)

***** train metrics *****
  epoch                    =        2.0
  total_flos               =    80472GF
  train_loss               =     2.1585
  train_runtime            = 0:03:58.61
  train_samples_per_second =      0.134
  train_steps_per_second   =      0.017


In [None]:
# Generate summaries for the whole test split, one batch at a time.
test_split = dataset_tokenized["test"]

generated_batches = []
for i in tqdm(range(0, len(test_split), BATCH_SIZE)):

    # Slice the dataset itself so only this batch's rows are materialised,
    # rather than loading every full column on each iteration.
    batch = test_split[i:i+BATCH_SIZE]

    generated_batches.append(
        model.generate(
            input_ids=torch.as_tensor(batch["input_ids"]),
            attention_mask=torch.as_tensor(batch["attention_mask"]),
            global_attention_mask=torch.as_tensor(batch["global_attention_mask"]),
            no_repeat_ngram_size=3,
            max_length=128,
            num_beams=4,
        )
    )

# Batches can come back with different sequence lengths; right-pad each one
# with the pad token so they can be concatenated along the batch dimension.
# The padding is dropped again when decoding with skip_special_tokens=True.
longest = max(batch_output.shape[1] for batch_output in generated_batches)
led_output_model1 = torch.cat(
    [
        F.pad(
            batch_output,
            (0, longest - batch_output.shape[1]),
            value=tokenizer.pad_token_id,
        )
        for batch_output in generated_batches
    ]
)

In [39]:
test_pred_model1 = tokenizer.batch_decode(
    led_output_model1,
    skip_special_tokens=True,
    clean_up_tokenization_spaces=True,
)

In [40]:
test_pred_model1 = [each.strip() for each in test_pred_model1]

In [41]:
test_pred_model1[:2]

["The long-term goal of our field is the creation and understanding of intelligence. Productive research in AI, both practical and theoretical, benefits from a notion of intelligence that is precise enough to allow the cumulative development of robust systems and general results. This paper outlines a gradual evolution in our formal conception of intelligence, that brings it closer to our informal conception and simultaneously reduces the gap between theory and practice. The article presents experimental results illustrating the agents' dynamic behavior. I. Introduction, 488. — II. The model with automobiles as an example, 489. — III. Examples and applications, 492. — IV.",
 '“Interaction in virtual reality (VR) environments is essential to ensure a pleasant and immersive experience. In this work, we propose a visually realistic, flexible and robust grasping system that enables real-time interactions in virtual environments. Resulting grasps are visually realistic because hand is autom

In [43]:
scores = rouge.compute(
    predictions=test_pred_model1,
    references=dataset_processed["test"]["related_work"],
    use_stemmer=True,
)
scores

{'rouge1': AggregateScore(low=Score(precision=0.2933487638578587, recall=0.29200709832037636, fmeasure=0.2815715983753184), mid=Score(precision=0.31049575936025475, recall=0.30798508735871666, fmeasure=0.29432358485593413), high=Score(precision=0.3280951076358564, recall=0.3250701177873674, fmeasure=0.30592633047634554)),
 'rouge2': AggregateScore(low=Score(precision=0.04350232238478724, recall=0.04349789242499366, fmeasure=0.041407873852706695), mid=Score(precision=0.05183538594707471, recall=0.052840695590428914, fmeasure=0.04971452590462587), high=Score(precision=0.06038594419192698, recall=0.06344514610297083, fmeasure=0.058800723232261924)),
 'rougeL': AggregateScore(low=Score(precision=0.15518199067772578, recall=0.15768831899322766, fmeasure=0.15026841424341691), mid=Score(precision=0.16485430398692857, recall=0.16904587700574214, fmeasure=0.15823517427974215), high=Score(precision=0.1745321488156862, recall=0.18136510192458355, fmeasure=0.16781062293071153)),
 'rougeLsum': Aggr

# Sandbox

In [None]:
%%time

# NOTE(review): unexecuted scratch cell. Unpacking the whole test split would
# hand every column (including `labels`) and all rows to a single `generate`
# call; the batched loop in the following cell is presumably the intended
# replacement - confirm and consider removing this cell.
led_output_model1 = model.generate(
    **dataset_tokenized["test"],
    # input_ids=dataset_tokenized["test"]["input_ids"][:n],
    # attention_mask=dataset_tokenized["test"]["attention_mask"][:n],
    # global_attention_mask=dataset_tokenized["test"]["global_attention_mask"][:n],
    no_repeat_ngram_size=3,
    max_length=128,
    num_beams=4,
)

In [41]:
# Generate summaries for the whole test split, one batch at a time.
test_split = dataset_tokenized["test"]

generated_batches = []
for i in tqdm(range(0, len(test_split), BATCH_SIZE)):

    # Slice the dataset itself so only this batch's rows are materialised,
    # rather than loading every full column on each iteration.
    batch = test_split[i:i+BATCH_SIZE]

    generated_batches.append(
        model.generate(
            # as_tensor makes the call work whether the dataset yields plain
            # lists or tensors.
            input_ids=torch.as_tensor(batch["input_ids"]),
            attention_mask=torch.as_tensor(batch["attention_mask"]),
            global_attention_mask=torch.as_tensor(batch["global_attention_mask"]),
            no_repeat_ngram_size=3,
            max_length=128,
            num_beams=4,
        )
    )

# Batches can come back with different sequence lengths; right-pad each one
# with the pad token so they can be concatenated along the batch dimension.
# The padding is dropped again when decoding with skip_special_tokens=True.
longest = max(batch_output.shape[1] for batch_output in generated_batches)
led_output_model1 = torch.cat(
    [
        F.pad(
            batch_output,
            (0, longest - batch_output.shape[1]),
            value=tokenizer.pad_token_id,
        )
        for batch_output in generated_batches
    ]
)

100%|████████████████████████████████████████████████████████████████████████████████████| 80/80 [3:11:58<00:00, 143.98s/it]


In [42]:
test_pred_model1 = tokenizer.batch_decode(
    led_output_model1,
    skip_special_tokens=True,
    clean_up_tokenization_spaces=True,
)

In [43]:
test_pred_model1 = [each.strip() for each in test_pred_model1]

In [44]:
test_pred_model1[:2]

["The long-term goal of our field is the creation and understanding of intelligence. Productive research in AI, both practical and theoretical, benefits from a notion of intelligence that is precise enough to allow the cumulative development of robust systems and general results. This paper outlines a gradual evolution in our formal conception of intelligence, that brings it closer to our informal conception and simultaneously reduces the gap between theory and practice. The article presents experimental results illustrating the agents' dynamic behavior. I. Introduction, 488. — II. The model with automobiles as an example, 489. — III. Examples and applications, 492. — IV.",
 '“Interaction in virtual reality (VR) environments (e.g. grasping and manipulating virtual objects) is essential to ensure a pleasant and immersive experience. In this work, we propose a visually realistic, flexible and robust grasping system that enables real-time interactions in virtual environments. Resulting gr

In [45]:
scores = rouge.compute(
    predictions=test_pred_model1,
    references=dataset_processed["test"]["related_work"],
    use_stemmer=True,
)
scores

{'rouge1': AggregateScore(low=Score(precision=0.2982484009370646, recall=0.3057877687297647, fmeasure=0.2877549383290699), mid=Score(precision=0.3011000320975565, recall=0.308186529440017, fmeasure=0.28974742885739196), high=Score(precision=0.30423223116720244, recall=0.31058012470155627, fmeasure=0.2917896849338474)),
 'rouge2': AggregateScore(low=Score(precision=0.046860562392848866, recall=0.04742509910853629, fmeasure=0.04476525721466253), mid=Score(precision=0.04805986457733996, recall=0.048611032584090635, fmeasure=0.0458330562332894), high=Score(precision=0.04916571328652274, recall=0.0498130873004002, fmeasure=0.046871122190935706)),
 'rougeL': AggregateScore(low=Score(precision=0.15540080811555254, recall=0.1636225890270709, fmeasure=0.15142010390078578), mid=Score(precision=0.15687893432752667, recall=0.16523845626801453, fmeasure=0.1524776573706927), high=Score(precision=0.15834376408384634, recall=0.16680563034698537, fmeasure=0.15364196168214198)),
 'rougeLsum': AggregateS