## 🤗 Finetune **Longformer Encoder-Decoder (LED)** on 8K Tokens 🤗

The *Longformer Encoder-Decoder (LED)* was recently added as an extension to [Longformer: The Long-Document Transformer](https://arxiv.org/abs/2004.05150) by Iz Beltagy, Matthew E. Peters, Arman Cohan.

In this notebook we will finetune *LED* for Summarization on [Pubmed](https://huggingface.co/datasets/viewer/?dataset=scientific_papers). *Pubmed* is a long-range summarization dataset, which makes it a good candidate for LED. LED will be finetuned up to an input length of 8K tokens on a single GPU.

We will leverage 🤗`Seq2SeqTrainer`, gradient checkpointing and, as usual, 🤗`datasets`.

First, let's try to get a GPU with at least 15GB RAM.

In [None]:
# crash colab to get more RAM
# !kill -9 -1

To check that we have enough GPU RAM, we can run the following command.
If the randomly allocated GPU is too small, the cell above can be run
to crash the notebook in the hope of getting a better GPU.

In [None]:
!nvidia-smi

NVIDIA-SMI has failed because it couldn't communicate with the NVIDIA driver. Make sure that the latest NVIDIA driver is installed and running.
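The same check can also be scripted so that the notebook fails fast when the GPU is too small. The following is a minimal sketch and not part of the original notebook: the helper names `parse_gpu_memory`, `query_gpus` and `has_enough_gpu_memory` are hypothetical, and the threshold of 15000 MiB is an assumed approximation of the "at least 15GB" mentioned above. It relies on the CSV query output of `nvidia-smi` and returns an empty list when the tool is missing or fails, as in the output shown above.

```python
import shutil
import subprocess

# Assumption: "at least 15GB" approximated as 15000 MiB.
MIN_GPU_MEMORY_MIB = 15_000


def parse_gpu_memory(csv_text):
    """Parse 'name, memory.total' CSV rows into (name, MiB) tuples."""
    gpus = []
    for line in csv_text.strip().splitlines():
        name, memory = line.rsplit(",", 1)
        gpus.append((name.strip(), int(memory.strip())))
    return gpus


def query_gpus():
    """Return the visible GPUs, or an empty list if nvidia-smi is unavailable."""
    if shutil.which("nvidia-smi") is None:
        return []
    result = subprocess.run(
        ["nvidia-smi", "--query-gpu=name,memory.total", "--format=csv,noheader,nounits"],
        capture_output=True,
        text=True,
    )
    if result.returncode != 0:
        return []
    return parse_gpu_memory(result.stdout)


def has_enough_gpu_memory(gpus, minimum_mib=MIN_GPU_MEMORY_MIB):
    """True if at least one GPU reports the required amount of memory."""
    return any(memory >= minimum_mib for _, memory in gpus)


gpus = query_gpus()
print(gpus, has_enough_gpu_memory(gpus))
```

If the last line prints `False`, the runtime either has no GPU or one that is likely too small for the 8K-token setup in this notebook.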



Next, we install 🤗Transformers, 🤗Datasets, and `rouge_score`.



In [None]:
%%capture
!pip install datasets==2.10.1
!pip install transformers==4.2.0
!pip install rouge_score

Let's start by loading and preprocessing the dataset.



In [None]:
import datasets
import random
import pandas as pd

from datasets import load_dataset, load_metric
from functools import partial
from IPython.display import display, HTML
from transformers import AutoTokenizer
from transformers import AutoModelForSeq2SeqLM
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

In [None]:
# Load the metric scoring object early
rouge = load_metric("rouge")



Next, we download the pubmed train and validation dataset ([click to see on 🤗Datasets Hub](https://huggingface.co/datasets/scientific_papers)). This can take a couple of minutes **☕**.

In [None]:
from typing import Tuple, Optional
def get_dataset(data: str, host: Optional[str] = None) -> Tuple:
  """Getting the training and validation data for our models"""

  train_dataset = load_dataset(data, host, split="train")
  val_dataset = load_dataset(data, host, split="validation")

  return (train_dataset, val_dataset)

It's always a good idea to take a look at some data samples. Let's do that here.

In [None]:
centrum_train, centrum_val = get_dataset(data="multi_x_science_sum")
led_train, led_val = get_dataset(data="scientific_papers", host="pubmed")



In [None]:
print(centrum_train, led_train)

Dataset({
    features: ['aid', 'mid', 'abstract', 'related_work', 'ref_abstract'],
    num_rows: 30369
}) Dataset({
    features: ['article', 'abstract', 'section_names'],
    num_rows: 119924
})


In [None]:
def show_random_elements(dataset, num_examples=4):
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."
    picks = []
    for _ in range(num_examples):
        pick = random.randint(0, len(dataset)-1)
        while pick in picks:
            pick = random.randint(0, len(dataset)-1)
        picks.append(pick)
    
    df = pd.DataFrame(dataset[picks])
    for column, typ in dataset.features.items():
        if isinstance(typ, datasets.ClassLabel):
            df[column] = df[column].transform(lambda i: typ.names[i])
    display(HTML(df.to_html()))

show_random_elements(centrum_train)

Unnamed: 0,aid,mid,abstract,related_work,ref_abstract
0,0704.3603,2950224979,"In this work we show that for every @math , such that for all @math where the parameters of the model do not depend on @math . They also provide a rare example where one can prove a polynomial time mixing of Gibbs sampler in a situation where the actual mixing time is slower than @math . Our proof exploits in novel ways the local treelike structure of Erd o s-R 'enyi random graphs, comparison and block dynamics arguments and a recent result of Weitz. Our results extend to much more general families of graphs which are sparse in some average sense and to much more general interactions. In particular, they apply to any graph for which every vertex @math of the graph has a neighborhood @math of radius @math in which the induced sub-graph is a tree union at most @math edges and where for each simple path in @math the sum of the vertex degrees along the path is @math . Moreover, our result apply also in the case of arbitrary external fields and provide the first FPRAS for sampling the Ising distribution in this case. We finally present a non Markov Chain algorithm for sampling the distribution which is effective for a wider range of parameters. In particular, for @math it applies for all external fields and @math , where @math is the critical point for decay of correlation for the Ising model on @math .","Much work has been focused on the problem of understanding the mixing time of the Ising model in various contexts. In a series of results @cite_11 @cite_16 @cite_3 culminating in @cite_21 it was shown that the Gibbs sampler on integer lattice mixes rapidly when the model has the strong spatial mixing property. In @math strong spatial mixing, and therefore rapid mixing, holds in the entire uniqueness regime (see e.g. @cite_7 ). On the regular tree the mixing time is always polynomial but is only @math up to the threshold for extremity @cite_18 . 
For completely general graphs the best known results are given by the Dobrushin condition which establishes rapid mixing when @math where @math is the maximum degree.","{'cite_N': ['@cite_18', '@cite_7', '@cite_21', '@cite_3', '@cite_16', '@cite_11'], 'mid': ['2611336766', '2011373957', '2088794759', '', '', '1770973266'], 'abstract': ['We study discrete time Glauber dynamics for random configurations with local constraints (e.g. proper coloring, Ising and Potts models) on finite graphs with n vertices and of bounded degree. We show that the relaxation time (defined as the reciprocal of the spectral gap 1 - _2 for the dynamics on trees and on certain hyperbolic graphs is polynomial in n. For these hyperbolic graphs, this yields a general polynomial sampling algorithm for random configurations. We then show that if the relaxation time T2 satisfies T2 = O(n), then the correlation coefficient, and the mutual information, between any local function (which dependsonly on the configuration in a fixed window) and the boundary conditions, decays exponentially in the distance between the window and the boundary. For the Ising model on a regular tree, this condition is sharp.', 'Various finite volume mixing conditions in classical statistical mechanics are reviewed and critically analyzed. In particular somefinite size conditions are discussed, together with their implications for the Gibbs measures and for the approach to equilibrium of Glauber dynamics inarbitrarily large volumes. It is shown that Dobrushin-Shlosman's theory ofcomplete analyticity and its dynamical counterpart due to Stroock and Zegarlinski, cannot be applied, in general, to the whole one phase region since it requires mixing properties for regions ofarbitrary shape. An alternative approach, based on previous ideas of Oliveri, and Picco, is developed, which allows to establish results on rapid approach to equilibrium deeply inside the one phase region. 
In particular, in the ferromagnetic case, we considerably improve some previous results by Holley and Aizenman and Holley. Our results are optimal in the sene that, for example, they show for the first time fast convergence of the dynamicsfor any temperature above the critical one for thed-dimensional Ising model with or without an external field. In part II we extensively consider the general case (not necessarily attractive) and we develop a new method, based on renormalizations group ideas and on an assumption of strong mixing in a finite cube, to prove hypercontractivity of the Markov semigroup of the Glauber dynamics.', 'For finite range lattice gases with a finite spin space, it is shown that the Dobrushin-Shlosman mixing condition is equivalent to the existence of a logarithmic Sobolev inequality for the associated (unique) Gibbs state. In addition, implications of these considerations for the ergodic properties of the corresponding Glauber dynamics are examined.', '', '', 'We show that, under the conditions of the Dobrushin Shlosman theorem for uniqueness of the Gibbs state, the reversible stochastic Ising model converges to equilibrium exponentially fast on the L2 space of that Gibbs state. For stochastic Ising models with attractive interactions and under conditions which are somewhat stronger than Dobrushinâ€™s, we prove that the semi-group of the stochastic Ising model converges to equilibrium exponentially fast in the uniform norm. We also give a new, much shorter, proof of a theorem which says that if the semi-group of an attractive spin flip system converges to equilibrium faster than 1 td where d is the dimension of the underlying lattice, then the convergence must be exponentially fast.']}"
1,1806.07585,2919006242,"Extending R. A. Fisher and D. A. Freedman's results on the analysis of covariance, Lin [2013] proposed an ordinary least squares adjusted estimator of the average treatment effect in completely randomized experiments. We further study its statistical properties under the potential outcomes model in the asymptotic regimes allowing for a diverging number of covariates. We show that Lin [2013]'s estimator is consistent when @math and asymptotically normal when @math under mild moment conditions, where @math is the maximum leverage score of the covariate matrix. In the favorable case where leverage scores are all close together, his estimator is consistent when @math and is asymptotically normal when @math . In addition, we propose a bias-corrected estimator that is consistent when @math and is asymptotically normal, with the same variance in the fixed- @math regime, when @math . In the favorable case, the latter condition reduces to @math . Similar to Lin [2013], our results hold for non-random potential outcomes and covariates without any model specification. Our analysis requires novel analytic tools for sampling without replacement, which complement and potentially enrich the theory in other areas such as survey sampling, matrix sketching, and transductive learning.","Theoretical analyses under the finite-population randomization model are challenging due to the lack of probability tools. The closest work to ours is @cite_8 , which allows @math to grow with @math and potentially exceed @math . However, they assume that the potential outcomes have sparse linear representations based on the covariates, and require @math where @math is a measure of sparsity. Under additional regularities conditions, they show that @math is consistent and asymptotically normal with @math being the LASSO coefficients of the covariates. Although the LASSO-adjusted estimator can handle ultra-high dimensional case where @math , it has three limitations. 
First, the requirement @math is stringent. For instance, the PAC-man dataset considered by @cite_8 has @math and @math , so the condition reads @math , which implicitly imposes a strong sparse modelling assumption.","{'cite_N': ['@cite_8'], 'mid': ['2963608360'], 'abstract': ['We provide a principled way for investigators to analyze randomized experiments when the number of covariates is large. Investigators often use linear multivariate regression to analyze randomized experiments instead of simply reporting the difference of means between treatment and control groups. Their aim is to reduce the variance of the estimated treatment effect by adjusting for covariates. If there are a large number of covariates relative to the number of observations, regression may perform poorly because of overfitting. In such cases, the least absolute shrinkage and selection operator (Lasso) may be helpful. We study the resulting Lasso-based treatment effect estimator under the Neymanâ€“Rubin model of randomized experiments. We present theoretical conditions that guarantee that the estimator is more efficient than the simple difference-of-means estimator, and we provide a conservative estimator of the asymptotic variance, which can yield tighter confidence intervals than the difference-of-means estimator. Simulation and data examples show that Lasso-based adjustment can be advantageous even when the number of covariates is less than the number of observations. Specifically, a variant using Lasso for selection and ordinary least squares (OLS) for estimation performs particularly well, and it chooses a smoothing parameter based on combined performance of Lasso and OLS.']}"
2,1301.1590,2950513089,"It has been shown that minimum free energy structure for RNAs and RNA-RNA interaction is often incorrect due to inaccuracies in the energy parameters and inherent limitations of the energy model. In contrast, ensemble based quantities such as melting temperature and equilibrium concentrations can be more reliably predicted. Even structure prediction by sampling from the ensemble and clustering those structures by Sfold [7] has proven to be more reliable than minimum free energy structure prediction. The main obstacle for ensemble based approaches is the computational complexity of the partition function and base pairing probabilities. For instance, the space complexity of the partition function for RNA-RNA interaction is @math and the time complexity is @math which are prohibitively large [4,12]. Our goal in this paper is to give a fast algorithm, based on sparse folding, to calculate an upper bound on the partition function. Our work is based on the recent algorithm of Hazan and Jaakkola [10]. The space complexity of our algorithm is the same as that of sparse folding algorithms, and the time complexity of our algorithm is @math for single RNA and @math for RNA-RNA interaction in practice, in which @math is the running time of sparse folding and @math ( @math ) is a sequence dependent parameter.","Methods to the partition function for interacting RNAs have been proposed in the literature. Instead, methods for comutation of the partition function have been developed, having high both time and space complexity. Most notably, @cite_5 developed an @math --time and @math --space dynamic programming algorithm that computes the partition function of RNA--RNA interaction complexes, thereby providing detailed insights into their thermodynamic properties. 
@cite_4 has developed a algorithm that produces a Boltzmann weighted ensemble of RNAâ€“-RNA interaction structures for the calculation of (and not the partition function) for any given interval on the target RNAs.","{'cite_N': ['@cite_5', '@cite_4'], 'mid': ['2166268906', '2099306789'], 'abstract': ['Recent interests, such as RNA interference and antisense RNA regulation, strongly motivate the problem of predicting whether two nucleic acid strands interact. Motivation: Regulatory non-coding RNAs (ncRNAs) such as microRNAs play an important role in gene regulation. Studies on both prokaryotic and eukaryotic cells show that such ncRNAs usually bind to their target mRNA to regulate the translation of corresponding genes. The specificity of these interactions depends on the stability of intermolecular and intramolecular base pairing. While methods like deep sequencing allow to discover an ever increasing set of ncRNAs, there are no high-throughput methods available to detect their associated targets. Hence, there is an increasing need for precise computational target prediction. In order to predict base-pairing probability of any two bases in interacting nucleic acids, it is necessary to compute the interaction partition function over the whole ensemble. The partition function is a scalar value from which various thermodynamic quantities can be derived. For example, the equilibrium concentration of each complex nucleic acid species and also the melting temperature of interacting nucleic acids can be calculated based on the partition function of the complex. Results: We present a model for analyzing the thermodynamics of two interacting nucleic acid strands considering the most general type of interactions studied in the literature. We also present a corresponding dynamic programming algorithm that computes the partition function over (almost) all physically possible joint secondary structures formed by two interacting nucleic acids in O(n6) time. 
We verify the predictive power of our algorithm by computing (i) the melting temperature for interacting RNA pairs studied in the literature and (ii) the equilibrium concentration for several variants of the OxySâ€“fhlA complex. In both experiments, our algorithm shows high accuracy and outperforms competitors. Availability: Software and web server is available at http: compbio.cs.sfu.ca taverna pirna Contact:cenk@cs.sfu.ca; backofen@informatik.uni-freiburg.de Supplementary information:Supplementary data are avaliable at Bioinformatics online.', 'Motivation: It has been proven that the accessibility of the target sites has a critical influence on RNAâ€“RNA binding, in general and the specificity and efficiency of miRNAs and siRNAs, in particular. Recently, O(N6) time and O(N4) space dynamic programming (DP) algorithms have become available that compute the partition function of RNAâ€“RNA interaction complexes, thereby providing detailed insights into their thermodynamic properties. Results: Modifications to the grammars underlying earlier approaches enables the calculation of interaction probabilities for any given interval on the target RNA. The computation of the â€˜hybrid probabilitiesâ€™ is complemented by a stochastic sampling algorithm that produces a Boltzmann weighted ensemble of RNAâ€“RNA interaction structures. The sampling of k structures requires only negligible additional memory resources and runs in O(kÂ·N3). Availability: The algorithms described here are implemented in C as part of the rip package. The source code of rip2 can be downloaded from http: www.combinatorics.cn cbpc rip.html and http: www.bioinf.uni-leipzig.de Software rip.html. Contact: duck@santafe.edu Supplementary information:Supplementary data are available at Bioinformatics online.']}"
3,cs0204018,2952290399,"We study one dimension in program evolution, namely the evolution of the datatype declarations in a program. To this end, a suite of basic transformation operators is designed. We cover structure-preserving refactorings, but also structure-extending and -reducing adaptations. Both the object programs that are subject to datatype transformations, and the meta programs that encode datatype transformations are functional programs.","There is a large body of research addressing the related problem of database schema evolution @cite_11 as relevant, for example, in database re- and reverse engineering @cite_9 . The schema transformations themselves can be compared with our datatype transformations only at a superficial level because of the different formalisms involved. There exist formal frameworks for the definition of schema transformations and various formalisms have been investigated @cite_6 . An interesting aspect of database schema evolution is that schema evolution necessitates a database instance mapping @cite_18 . Compare this with the evolution of the datatypes in a functional program. Here, the main concern is to update the function declarations for compliance with the new datatypes. It seems that the instance mapping problem is a special case of the program update problem.","{'cite_N': ['@cite_9', '@cite_18', '@cite_6', '@cite_11'], 'mid': ['1544920330', '', '2583677609', '2215315499'], 'abstract': ['The paper presents a DBMS-independent database reverse engineering (DBRE) methodology based on a generic process model and on transformation techniques. DBRE is proposed as a two-phase process consisting in recovering the DBMS-dependent data structures (data structure extraction) then in recovering their semantics (data structure conceptualization). 
The second phase, that is strongly linked with the logical design phase of current database design methodologies, can be performed by application of a selected set of standard schema restructuring techniques, or schema transformations. The paper illustrates the methodology by applying it to various DBRE processes : removing optimization structures, untransfating Relational, COBOL, CODASYL, TOTAL IMAGE and IMS database as well as file structures, and finally conceptual normalization.', '', 'Several methodologies for semantic schema integration have been proposed in the literature, often using some variant of the ER model as the common data model. As part of these methodologies, various transformations have been defined that map between ER schemas which are in some sense equivalent. This paper gives a unifying formalisation of the ER schema transformation process and shows how some common schema transformations can be expressed within this single framework. Our formalism clearly identifies which transformations apply for any instance of the schema and which only for certain instances.', 'Object-oriented programming is well-suited to such data-intensive application domains as CAD CAM, AI, and OIS (office information systems) with multimedia documents. At MCC we have built a prototype object-oriented database system, called ORION. It adds persistence and sharability to objects created and manipulated in applications implemented in an object-oriented programming environment. One of the important requirements of these applications is schema evolution, that is, the ability to dynamically make a wide variety of changes to the database schema. In this paper, following a brief review of the object-oriented data model that we support in ORION, we establish a framework for supporting schema evolution, define the semantics of schema evolution, and discuss its implementation.']}"


In [None]:
# Non-consecutive added token '<doc-sep>' found. Should have index 50266 but has index 50265 in saved vocabulary (Centrum).

def get_tokenizer(host_tokenizer: str):
  """return the tokenizer for LLM training"""

  return AutoTokenizer.from_pretrained(host_tokenizer)


led_tokenizer = get_tokenizer("allenai/led-base-16384")
# centrum_tokenizer = get_tokenizer("ratishsp/Centrum")

Note that for the sake of this notebook, we finetune the "smaller" LED checkpoint ["allenai/led-base-16384"](https://huggingface.co/allenai/led-base-16384). Better performance can however be attained by finetuning ["allenai/led-large-16384"](https://huggingface.co/allenai/led-large-16384) at the cost of higher GPU RAM requirements.

Now, let's write down the input data processing function that will be used to map each data sample to the correct model format.
Here, `article` is the input data and `abstract` is the target data. The data samples are thus tokenized up to the respective maximum lengths of 8192 and 512.

In addition to the usual `attention_mask`, LED can make use of an additional `global_attention_mask` defining which input tokens are attended globally and which are attended only locally, just as in [Longformer](https://huggingface.co/transformers/model_doc/longformer.html). For more information on Longformer's self-attention, please take a look at the corresponding [docs](https://huggingface.co/transformers/model_doc/longformer.html#longformer-self-attention). For summarization, we follow the recommendations of the [paper](https://arxiv.org/abs/2004.05150) and use global attention only for the very first token. Finally, we make sure that no loss is computed on padded tokens by setting their label index to `-100`.
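To make the two masking steps concrete, here is a small, library-free sketch on toy token ids. The helper names are hypothetical, the token ids are made up, and `PAD_TOKEN_ID = 1` is an assumption for illustration; in practice the value should be read from the tokenizer's `pad_token_id`.

```python
PAD_TOKEN_ID = 1      # assumed pad id for illustration; read it from the tokenizer in practice
IGNORE_INDEX = -100   # label value that the loss function skips


def build_global_attention_mask(input_ids):
    """One mask per sample: 1 for the first token (global), 0 elsewhere (local)."""
    masks = []
    for ids in input_ids:
        mask = [0] * len(ids)
        if mask:
            mask[0] = 1
        masks.append(mask)
    return masks


def mask_padding_in_labels(label_ids, pad_token_id=PAD_TOKEN_ID):
    """Replace every pad token in the labels so that it does not contribute to the loss."""
    return [
        [IGNORE_INDEX if token == pad_token_id else token for token in labels]
        for labels in label_ids
    ]


# Toy batch of two already padded samples (made-up token ids).
input_ids = [[0, 713, 16, 2], [0, 250, 2, 1]]
labels = [[0, 9226, 2, 1, 1], [0, 100, 200, 2, 1]]

print(build_global_attention_mask(input_ids))  # [[1, 0, 0, 0], [1, 0, 0, 0]]
print(mask_padding_in_labels(labels))          # [[0, 9226, 2, -100, -100], [0, 100, 200, 2, -100]]
```

This sketch builds an independent mask list for each sample. The notebook's own function instead repeats a reference to a single list, so setting the first position once marks it for every sample in the batch; the resulting values are the same.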

In [None]:
# Setting up input/output parameters
max_input_length = 8192
max_output_length = 512
batch_size = 2

def process_data_to_model_inputs(batch, model_tokenizer):
    # tokenize the inputs and labels
    inputs = model_tokenizer(
        batch["article"],
        padding="max_length",
        truncation=True,
        max_length=max_input_length,
    )
    outputs = model_tokenizer(
        batch["abstract"],
        padding="max_length",
        truncation=True,
        max_length=max_output_length,
    )

    batch["input_ids"] = inputs.input_ids
    batch["attention_mask"] = inputs.attention_mask

    # create 0 global_attention_mask lists
    batch["global_attention_mask"] = len(batch["input_ids"]) * [
        [0 for _ in range(len(batch["input_ids"][0]))]
    ]

    # since above lists are references, the following line changes the 0 index for all samples
    batch["global_attention_mask"][0][0] = 1
    batch["labels"] = outputs.input_ids

    # We have to make sure that the PAD token is ignored
    batch["labels"] = [
        [-100 if token == model_tokenizer.pad_token_id else token for token in labels]
        for labels in batch["labels"]
    ]

    return batch

For the sake of this notebook, we will reduce the training and validation data
to dummy datasets of 250 and 50 samples, respectively. For a full training run, omit the `train_range` and `validation_range` arguments.

Great, having defined the mapping function, let's preprocess the training and validation data.

In [None]:
def prep_and_convert_data(train: datasets.arrow_dataset.Dataset, validation: datasets.arrow_dataset.Dataset, train_range: Optional[int] = None, validation_range: Optional[int] = None) -> Tuple:
  """Processing the training and validation dataset to be trained"""

  processed_model_data = partial(process_data_to_model_inputs, model_tokenizer=led_tokenizer)

  if train_range and validation_range:
    train_dataset = train.select(range(train_range))
    val_dataset = validation.select(range(validation_range))
  else:
    train_dataset = train
    val_dataset = validation

  train_dataset = train_dataset.map(
      processed_model_data,
      batched=True,
      batch_size=batch_size,
      remove_columns=["article", "abstract", "section_names"],
  )
  val_dataset = val_dataset.map(
    processed_model_data,
    batched=True,
    batch_size=batch_size,
    remove_columns=["article", "abstract", "section_names"],
  )
  train_dataset.set_format(
    type="torch",
    columns=["input_ids", "attention_mask", "global_attention_mask", "labels"],
  )
  val_dataset.set_format(
      type="torch",
      columns=["input_ids", "attention_mask", "global_attention_mask", "labels"],
  )

  return (train_dataset, val_dataset)


train_dataset, val_dataset = prep_and_convert_data(train=led_train, validation=led_val, train_range=250, validation_range=50)




We've decided to stick to the smaller model `"allenai/led-base-16384"` for the sake of this notebook. In addition, we directly enable gradient checkpointing and disable the caching mechanism to save memory.

In [None]:
print(val_dataset, val_dataset['labels'])

Dataset({
    features: ['input_ids', 'attention_mask', 'global_attention_mask', 'labels'],
    num_rows: 50
}) tensor([[    0,  3618,     8,  ...,  -100,  -100,  -100],
        [    0,  3618,   627,  ...,  -100,  -100,  -100],
        [    0,  1437, 50118,  ...,  -100,  -100,  -100],
        ...,
        [    0,    52,   266,  ...,  -100,  -100,  -100],
        [    0,  5283,   390,  ...,  -100,  -100,  -100],
        [    0,  9695, 11474,  ...,  -100,  -100,  -100]])


In [None]:
def get_model(model_host: str):
  """Get either the LED or Centrum model"""

  return AutoModelForSeq2SeqLM.from_pretrained(model_host, gradient_checkpointing=True, use_cache=False)

led = get_model(model_host="allenai/led-base-16384")

During training, we want to evaluate the model on Rouge, the most common metric used in summarization, to make sure the model is indeed improving during training. For this, we set fitting generation parameters. We'll use beam search with a small beam of just 2 to save memory. Also, we force the model to generate at least 100 tokens, but no more than 512. In addition, some other generation parameters are set that have been found helpful for generation. For more information on those parameters, please take a look at the [docs](https://huggingface.co/transformers/main_classes/model.html?highlight=generate#transformers.generation_utils.GenerationMixin.generate).
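As a rough illustration of what Rouge-2 measures, the sketch below computes bigram overlap between a prediction and a reference. It is a simplified stand-in, not the implementation used in this notebook: it assumes whitespace tokenization and lowercasing only, whereas the `rouge_score` package applies its own tokenization, and the `.mid` value used later comes from aggregating scores over many samples. The function names are hypothetical.

```python
from collections import Counter


def bigrams(text):
    """Count the bigrams of a lowercased, whitespace-tokenized text."""
    tokens = text.lower().split()
    return Counter(zip(tokens, tokens[1:]))


def rouge2(prediction, reference):
    """Simplified Rouge-2: precision, recall and F-measure of bigram overlap."""
    pred, ref = bigrams(prediction), bigrams(reference)
    overlap = sum((pred & ref).values())  # clipped count of shared bigrams
    precision = overlap / max(sum(pred.values()), 1)
    recall = overlap / max(sum(ref.values()), 1)
    if precision + recall == 0:
        fmeasure = 0.0
    else:
        fmeasure = 2 * precision * recall / (precision + recall)
    return {"precision": precision, "recall": recall, "fmeasure": fmeasure}


scores = rouge2(
    "the model summarizes long papers",
    "the model summarizes very long papers",
)
print({name: round(value, 4) for name, value in scores.items()})
# {'precision': 0.75, 'recall': 0.6, 'fmeasure': 0.6667}
```

Three of the four predicted bigrams appear among the five reference bigrams, which gives a precision of 0.75 and a recall of 0.6.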

In [None]:
# set generate hyperparameters
led.config.num_beams = 2
led.config.max_length = 512
led.config.min_length = 100
led.config.length_penalty = 2.0
led.config.early_stopping = True
led.config.no_repeat_ngram_size = 3

In [None]:
# Compute metrics for rouge
def compute_metrics(pred):
    labels_ids = pred.label_ids
    pred_ids = pred.predictions

    # use the tokenizer loaded above (`led_tokenizer`) to decode predictions and labels
    pred_str = led_tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    labels_ids[labels_ids == -100] = led_tokenizer.pad_token_id
    label_str = led_tokenizer.batch_decode(labels_ids, skip_special_tokens=True)

    rouge_output = rouge.compute(
        predictions=pred_str, references=label_str, rouge_types=["rouge2"]
    )["rouge2"].mid

    return {
        "rouge2_precision": round(rouge_output.precision, 4),
        "rouge2_recall": round(rouge_output.recall, 4),
        "rouge2_fmeasure": round(rouge_output.fmeasure, 4),
    }

Now, we're ready to start training. Let's set up the `Seq2SeqTrainer` and `Seq2SeqTrainingArguments` that we imported above.

In contrast to the usual `Trainer`, the `Seq2SeqTrainer` makes it possible to use the `generate()` function during evaluation. This should be enabled with `predict_with_generate=True`. Because our GPU RAM is limited, we make use of gradient accumulation by setting `gradient_accumulation_steps=4` to have an effective `batch_size` of 2 * 4 = 8.
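The batch-size arithmetic can be checked with a few lines of Python. This is a back-of-the-envelope sketch under stated assumptions: a single GPU, the reduced training set of 250 samples used in this notebook, and variable names chosen only for illustration. The number of optimizer updates is an estimate, since the handling of a trailing partial accumulation window depends on the `Trainer` version.

```python
import math

per_device_batch_size = 2
gradient_accumulation_steps = 4
num_devices = 1          # assumption: a single GPU
num_train_samples = 250  # size of the reduced training set in this notebook

# Samples that contribute to one optimizer update.
effective_batch_size = per_device_batch_size * gradient_accumulation_steps * num_devices

# Forward/backward passes needed to see every sample once.
batches_per_epoch = math.ceil(num_train_samples / (per_device_batch_size * num_devices))

# Rough estimate of optimizer updates per epoch.
approx_updates_per_epoch = max(batches_per_epoch // gradient_accumulation_steps, 1)

print(effective_batch_size, batches_per_epoch, approx_updates_per_epoch)  # 8 125 31
```

Gradient accumulation trades time for memory: only 2 samples are held on the GPU at once, while the gradients of 4 consecutive batches are summed before each weight update.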

The other training arguments are described in the [docs](https://huggingface.co/transformers/main_classes/trainer.html?highlight=trainingarguments#transformers.TrainingArguments).

In [None]:
# fp16 training is left disabled here; set fp16=True to enable mixed precision

training_args = Seq2SeqTrainingArguments(
    predict_with_generate=True,
    evaluation_strategy="steps",
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    fp16=False,
    output_dir="./",
    logging_steps=5,
    eval_steps=10,
    save_steps=10,
    save_total_limit=2,
    gradient_accumulation_steps=4,
    num_train_epochs=1,
)

The training arguments, along with the model, tokenizer, datasets and the `compute_metrics` function can then be passed to the `Seq2SeqTrainer`

In [None]:
trainer = Seq2SeqTrainer(
    model=led,
    tokenizer=led_tokenizer,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
)

and we can start training. This will take about 35 minutes.

In [None]:
trainer.train()


This completes the fine-tuning tutorial for LED. This training script, with some small changes, was used to train [this](https://huggingface.co/patrickvonplaten/led-large-16384-pubmed) checkpoint, called `"patrickvonplaten/led-large-16384-pubmed"`, on a single GPU for about 3 days. Evaluating `"patrickvonplaten/led-large-16384-pubmed"` on Pubmed's test data gives a Rouge-2 score of **19.33**, which is around 1 Rouge-2 point below SOTA performance on Pubmed.

In the Appendix below, the condensed training and evaluation scripts that were used locally to finetune `"patrickvonplaten/led-large-16384-pubmed"` are attached.

# **Appendix**

## Training

In [None]:
#!/usr/bin/env python3
from datasets import load_dataset, load_metric
from transformers import (
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
)

# load rouge
rouge = load_metric("rouge")

# load pubmed
pubmed_train = load_dataset("scientific_papers", "pubmed", ignore_verifications=True, split="train")
pubmed_val = load_dataset("scientific_papers", "pubmed", ignore_verifications=True, split="validation[:10%]")

# uncomment the following lines for a test run
# pubmed_train = pubmed_train.select(range(32))
# pubmed_val = pubmed_val.select(range(32))

# load tokenizer
tokenizer = AutoTokenizer.from_pretrained("allenai/led-large-16384")


# max encoder length is 8192 for PubMed
encoder_max_length = 8192
decoder_max_length = 512
batch_size = 2


def process_data_to_model_inputs(batch):
    # tokenize the inputs and labels
    inputs = tokenizer(
        batch["article"],
        padding="max_length",
        truncation=True,
        max_length=encoder_max_length,
    )
    outputs = tokenizer(
        batch["abstract"],
        padding="max_length",
        truncation=True,
        max_length=decoder_max_length,
    )

    batch["input_ids"] = inputs.input_ids
    batch["attention_mask"] = inputs.attention_mask

    # put global attention on the first token (<s>) of every sample;
    # all other tokens use local attention only
    batch["global_attention_mask"] = [
        [1] + [0] * (len(ids) - 1) for ids in batch["input_ids"]
    ]
    batch["labels"] = outputs.input_ids

    # We have to make sure that the PAD token is ignored
    batch["labels"] = [
        [-100 if token == tokenizer.pad_token_id else token for token in labels]
        for labels in batch["labels"]
    ]

    return batch


# map train data
pubmed_train = pubmed_train.map(
    process_data_to_model_inputs,
    batched=True,
    batch_size=batch_size,
    remove_columns=["article", "abstract", "section_names"],
)

# map val data
pubmed_val = pubmed_val.map(
    process_data_to_model_inputs,
    batched=True,
    batch_size=batch_size,
    remove_columns=["article", "abstract", "section_names"],
)

# set Python list to PyTorch tensor
pubmed_train.set_format(
    type="torch",
    columns=["input_ids", "attention_mask", "global_attention_mask", "labels"],
)

# set Python list to PyTorch tensor
pubmed_val.set_format(
    type="torch",
    columns=["input_ids", "attention_mask", "global_attention_mask", "labels"],
)

# enable fp16 apex training
training_args = Seq2SeqTrainingArguments(
    predict_with_generate=True,
    evaluation_strategy="steps",
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    fp16=True,
    fp16_backend="apex",
    output_dir="./",
    logging_steps=250,
    eval_steps=5000,
    save_steps=500,
    warmup_steps=1500,
    save_total_limit=2,
    gradient_accumulation_steps=4,
)


# compute Rouge score during validation
def compute_metrics(pred):
    labels_ids = pred.label_ids
    pred_ids = pred.predictions

    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    labels_ids[labels_ids == -100] = tokenizer.pad_token_id
    label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)

    rouge_output = rouge.compute(
        predictions=pred_str, references=label_str, rouge_types=["rouge2"]
    )["rouge2"].mid

    return {
        "rouge2_precision": round(rouge_output.precision, 4),
        "rouge2_recall": round(rouge_output.recall, 4),
        "rouge2_fmeasure": round(rouge_output.fmeasure, 4),
    }


# load model + enable gradient checkpointing & disable cache for checkpointing
led = AutoModelForSeq2SeqLM.from_pretrained("allenai/led-large-16384", gradient_checkpointing=True, use_cache=False)

# set generate hyperparameters
led.config.num_beams = 4
led.config.max_length = 512
led.config.min_length = 100
led.config.length_penalty = 2.0
led.config.early_stopping = True
led.config.no_repeat_ngram_size = 3


# instantiate trainer
trainer = Seq2SeqTrainer(
    model=led,
    tokenizer=tokenizer,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=pubmed_train,
    eval_dataset=pubmed_val,
)

# start training
trainer.train()
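
Besides tokenization, the preprocessing function in the training script does two things: it puts global attention on the first token of every sample, and it replaces padding in the labels with `-100` so that those positions do not count towards the loss. `compute_metrics` later reverses the second step before decoding. The sketch below replays both steps on toy token ids without loading a tokenizer. The helper names and the pad id `PAD_ID = 1` are assumptions made for illustration.

```python
PAD_ID = 1            # hypothetical pad token id, for illustration only
IGNORE_INDEX = -100   # label value that the loss skips


def build_global_attention_mask(input_ids):
    """Return 1 for the first token of every sample and 0 elsewhere."""
    # build one fresh row per sample so that rows never share state
    return [[1] + [0] * (len(ids) - 1) for ids in input_ids]


def mask_pad_labels(labels, pad_id=PAD_ID):
    """Replace padding in the labels so it is ignored during training."""
    return [[IGNORE_INDEX if tok == pad_id else tok for tok in seq] for seq in labels]


def unmask_labels(labels, pad_id=PAD_ID):
    """Inverse step: restore the pad id so the labels can be decoded."""
    return [[pad_id if tok == IGNORE_INDEX else tok for tok in seq] for seq in labels]


# toy batch: two padded inputs and two padded targets
input_ids = [[0, 11, 12, 2, 1, 1], [0, 21, 2, 1, 1, 1]]
labels = [[0, 31, 2, 1], [0, 41, 42, 2]]

global_attention_mask = build_global_attention_mask(input_ids)
masked = mask_pad_labels(labels)
restored = unmask_labels(masked)

print(global_attention_mask)
print(masked)
print(restored == labels)
```

The round trip matters because a tokenizer cannot decode `-100`, so the labels have to be mapped back to a valid token id before `batch_decode` is called.
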

## Evaluation

In [None]:
import torch

from datasets import load_dataset, load_metric
from transformers import LEDTokenizer, LEDForConditionalGeneration

# load pubmed
pubmed_test = load_dataset("scientific_papers", "pubmed", ignore_verifications=True, split="test")

# load tokenizer and model (model in half precision on the GPU)
tokenizer = LEDTokenizer.from_pretrained("patrickvonplaten/led-large-16384-pubmed")
model = LEDForConditionalGeneration.from_pretrained("patrickvonplaten/led-large-16384-pubmed").to("cuda").half()


def generate_answer(batch):
    inputs_dict = tokenizer(batch["article"], padding="max_length", max_length=8192, return_tensors="pt", truncation=True)
    input_ids = inputs_dict.input_ids.to("cuda")
    attention_mask = inputs_dict.attention_mask.to("cuda")
    global_attention_mask = torch.zeros_like(attention_mask)
    # put global attention on <s> token
    global_attention_mask[:, 0] = 1

    predicted_abstract_ids = model.generate(input_ids, attention_mask=attention_mask, global_attention_mask=global_attention_mask)
    batch["predicted_abstract"] = tokenizer.batch_decode(predicted_abstract_ids, skip_special_tokens=True)
    return batch


result = pubmed_test.map(generate_answer, batched=True, batch_size=4)

# load rouge
rouge = load_metric("rouge")

print("Result:", rouge.compute(predictions=result["predicted_abstract"], references=result["abstract"], rouge_types=["rouge2"])["rouge2"].mid)
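
To give an intuition for what the Rouge-2 number measures, the sketch below computes bigram overlap between one prediction and one reference. It is a deliberately simplified stand-in and not the `rouge_score` implementation: it assumes plain lowercased whitespace tokenization and scores a single pair, whereas the metric used above also aggregates over the whole test set. The function names `bigrams` and `simple_rouge2` are made up for this example.

```python
from collections import Counter


def bigrams(text):
    """Count adjacent token pairs after lowercasing and whitespace splitting."""
    tokens = text.lower().split()
    return Counter(zip(tokens, tokens[1:]))


def simple_rouge2(prediction, reference):
    """Bigram-overlap precision, recall and F-measure for one text pair."""
    pred, ref = bigrams(prediction), bigrams(reference)
    # clipped overlap: each bigram counts at most as often as it occurs in both texts
    overlap = sum((pred & ref).values())
    precision = overlap / max(sum(pred.values()), 1)
    recall = overlap / max(sum(ref.values()), 1)
    fmeasure = 0.0 if overlap == 0 else 2 * precision * recall / (precision + recall)
    return {"precision": precision, "recall": recall, "fmeasure": fmeasure}


scores = simple_rouge2("the cat sat on a mat", "the cat sat on the mat")
print(scores)
```

In this toy pair, three of the five bigrams on each side match, so precision, recall, and F-measure all come out at about 0.6.
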
