Copyright (c) Microsoft Corporation. All rights reserved.

Licensed under the MIT License.

## Extractive Summarization on CNN/DM Dataset using Transformer Version of BertSum


### Summary

This notebook demonstrates how to fine-tune Transformers for extractive text summarization. Utility functions and classes in the NLP Best Practices repo are used to facilitate data preprocessing, model training, model scoring, result postprocessing, and model evaluation.

BertSum refers to [Fine-tune BERT for Extractive Summarization](https://arxiv.org/pdf/1903.10318.pdf) and its [published example](https://github.com/nlpyang/BertSum/). The Transformer version of BertSum is our modification of BertSum; its source code is available at [https://github.com/daden-ms/BertSum/](https://github.com/daden-ms/BertSum/).

Extractive summarization is usually used for document summarization, where each input document consists of multiple sentences. Preprocessing the training data involves assigning a label of 0 or 1 to each document sentence based on the given summary. The summarization problem is thereby simplified to classifying whether each document sentence should be included in the summary.
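As a rough illustration of this labeling step, the sketch below greedily adds the document sentence that most improves the overlap with the reference summary and labels the chosen sentences 1. It is a simplified, hypothetical stand-in for `greedy_selection` (which this notebook imports from `utils_nlp.dataset.sentence_selection`), not that function's implementation: it assumes sentences are already tokenized into word lists and, as an assumption, scores candidates with set-based unigram plus bigram F1 as a proxy for ROUGE-1 + ROUGE-2, so its choices may differ from the library's.

```python
from typing import List, Set, Tuple


def _ngrams(tokens: List[str], n: int) -> Set[Tuple[str, ...]]:
    """Return the set of n-grams in a token list."""
    return {tuple(tokens[i : i + n]) for i in range(len(tokens) - n + 1)}


def _f1(candidate: Set, reference: Set) -> float:
    """Set-based F1 between candidate and reference n-grams."""
    overlap = len(candidate & reference)
    if overlap == 0:
        return 0.0
    precision = overlap / len(candidate)
    recall = overlap / len(reference)
    return 2 * precision * recall / (precision + recall)


def greedy_oracle_labels(
    doc_sents: List[List[str]], summary: List[str], max_sel: int = 3
) -> List[int]:
    """Label each document sentence 1 (in the oracle summary) or 0."""
    ref1, ref2 = _ngrams(summary, 1), _ngrams(summary, 2)
    selected: List[int] = []
    best_score = 0.0
    for _ in range(max_sel):
        best_idx = None
        for i in range(len(doc_sents)):
            if i in selected:
                continue
            chosen = [doc_sents[j] for j in selected + [i]]
            # Union per-sentence n-grams so no bigram spans a sentence boundary.
            cand1 = set().union(*(_ngrams(s, 1) for s in chosen))
            cand2 = set().union(*(_ngrams(s, 2) for s in chosen))
            # Assumption: unigram + bigram F1 approximates ROUGE-1 + ROUGE-2.
            score = _f1(cand1, ref1) + _f1(cand2, ref2)
            if score > best_score:
                best_score, best_idx = score, i
        if best_idx is None:  # no remaining sentence improves the score: stop early
            break
        selected.append(best_idx)
    return [1 if i in selected else 0 for i in range(len(doc_sents))]
```

For example, with the document `[["the", "cat", "sat"], ["dogs", "bark", "loudly"], ["the", "cat", "ran"]]` and the summary `["the", "cat", "sat"]`, the function returns `[1, 0, 0]`: after the first sentence is picked, no other sentence raises the score, so selection stops before reaching `max_sel`.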

The figure below illustrates how BertSum can be fine-tuned for the extractive summarization task. A [CLS] token is inserted at the beginning of each sentence and a [SEP] token at the end. Interval segment embeddings (alternating between two segment ids for consecutive sentences) and positional embeddings are added to the token embeddings to form the input of the BERT model. The [CLS] token representation is used as the sentence embedding, and only the [CLS] tokens are used as input to the summarization layer. The summarization layer predicts the probability of each sentence being included in the summary. Techniques like trigram blocking can be used to improve model accuracy.

<img src="https://nlpbp.blob.core.windows.net/images/BertSum.PNG">
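The input construction described above can be sketched in plain Python. This minimal sketch assumes the sentences are already split into (sub)tokens, and the function name `build_bertsum_input` is hypothetical. It skips truncation to the model's maximum length and the conversion to vocabulary ids, which `encode_single` in this notebook performs with the tokenizer. Positional embeddings are added inside the BERT model, so only the tokens, the interval segment ids, and the [CLS] positions that the summarization layer gathers are produced here.

```python
from typing import List, Tuple


def build_bertsum_input(
    sentences: List[List[str]],
) -> Tuple[List[str], List[int], List[int]]:
    """Wrap each tokenized sentence in [CLS] ... [SEP] and return the tokens,
    the interval segment ids and the positions of the [CLS] tokens."""
    tokens: List[str] = []
    segment_ids: List[int] = []
    cls_positions: List[int] = []
    for i, sent in enumerate(sentences):
        cls_positions.append(len(tokens))  # index of this sentence's [CLS]
        piece = ["[CLS]"] + sent + ["[SEP]"]
        tokens.extend(piece)
        # Interval segment embedding: alternate 0 / 1 between consecutive sentences.
        segment_ids.extend([i % 2] * len(piece))
    return tokens, segment_ids, cls_positions
```

For the three sentences `[["a", "b"], ["c"], ["d", "e", "f"]]`, it returns segment ids `[0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0]` and [CLS] positions `[0, 4, 7]`.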


### Before You Start

The running times shown in this notebook were measured on a Standard_NC24s_v3 Azure Ubuntu Virtual Machine with 4 NVIDIA Tesla V100 GPUs. 
> **Tip**: If you want to run through the notebook quickly, you can set the **`QUICK_RUN`** flag in the cell below to **`True`** to run the notebook on a small subset of the data and a smaller number of epochs. 

Using a single NVIDIA Tesla V100 GPU with 16 GB of GPU memory:
- for data preprocessing, it takes about 1 minute to preprocess the data for a quick run and ~20 minutes otherwise. This estimate assumes that the chosen transformer model is "distilbert-base-uncased" and the sentence selection method is "greedy", which is the default. Preprocessing can take significantly longer with the "combination" sentence selection method, which can achieve better model performance.

- for model fine-tuning, it takes about 2 minutes for a quick run and ~3 hours otherwise. This estimate assumes the chosen encoder method is "transformer". Fine-tuning can be faster with another encoder method, which may result in worse model performance.

### Additional Notes

* **ROUGE Evaluation**: To run ROUGE evaluation, refer to the compute_rouge_perl section of [summarization_evaluation.ipynb](./summarization_evaluation.ipynb) for setup.

* **Distributed Training**:
Please note that the Jupyter notebook only supports PyTorch [DataParallel](https://pytorch.org/docs/master/nn.html#dataparallel). Faster training and larger batch sizes can be achieved with PyTorch [DistributedDataParallel](https://pytorch.org/docs/master/notes/ddp.html) (DDP). The script [extractive_summarization_cnndm_distributed_train.py](./extractive_summarization_cnndm_distributed_train.py) shows an example of how to use DDP.



In [1]:
%load_ext autoreload

%autoreload 2

In [2]:
## Set QUICK_RUN = True to run the notebook on a small subset of data and a smaller number of epochs.
QUICK_RUN = True
## Set USE_PREPROCSSED_DATA = True to skip the data preprocessing
USE_PREPROCSSED_DATA = False

### Configuration


In [3]:
import os
import shutil
import sys
from tempfile import TemporaryDirectory
import torch

nlp_path = os.path.abspath("../../")
if nlp_path not in sys.path:
    sys.path.insert(0, nlp_path)

from utils_nlp.dataset.swiss import SwissSummarizationDataset
from utils_nlp.dataset.bundes import BundesSummarizationDataset

from utils_nlp.models.transformers.datasets import SummarizationDataset
import nltk
from nltk import tokenize

import pandas as pd
import scrapbook as sb
import pprint

In [4]:
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

# This script reuses some code from https://github.com/nlpyang/BertSum

nlp_path = os.path.abspath("../../")
if nlp_path not in sys.path:
    sys.path.insert(0, nlp_path)

import functools
import itertools
import logging
import os
import pickle
from multiprocessing import Pool, cpu_count

import numpy as np
import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
from transformers import AutoTokenizer, BertModel, DistilBertModel, XLMRobertaModel, XLMRobertaTokenizer

from utils_nlp.common.pytorch_utils import (
    compute_training_steps,
    get_device,
    move_model_to_device,
    parallelize_model,
)
from utils_nlp.dataset.sentence_selection import combination_selection, greedy_selection
from utils_nlp.models.transformers.abstractive_summarization_bertsum import (
    fit_to_block_size,
)

from utils_nlp.models.transformers.bertsum import model_builder
from utils_nlp.models.transformers.bertsum.data_loader import (
    Batch,
    ChunkDataLoader,
    IterableDistributedSampler,
)
from utils_nlp.models.transformers.bertsum.dataset import (
    ExtSumProcessedDataset,
    ExtSumProcessedIterableDataset,
)
from utils_nlp.models.transformers.bertsum.model_builder import BertSumExt
from utils_nlp.models.transformers.common import Transformer



In [5]:
MODEL_CLASS = {"bert-base-uncased": BertModel,
               "bert-base-german-cased": BertModel,
               "distilbert-base-uncased": DistilBertModel,
               "dbmdz/bert-base-german-uncased": BertModel,
               "bert-base-german-dbmdz-cased": BertModel,
               "bert-base-multilingual-cased": BertModel,
               "distilbert-base-german-cased": DistilBertModel,
               "bert-base-german-dbmdz-uncased": BertModel,
               "severinsimmler/bert-adapted-german-press": BertModel,
               "xlm-roberta-large-finetuned-conll03-german": XLMRobertaModel, # not sure about this one...
              }


logger = logging.getLogger(__name__)


class Bunch(object):
    """ Class which converts a dictionary to an object """

    def __init__(self, adict):
        self.__dict__.update(adict)

In [6]:
def get_dataloader(
    data_iter,
    shuffle=True,
    is_labeled=False,
    batch_size=3000,
    world_size=1,
    rank=0,
    local_rank=-1,
):
    """
    Function to get a data iterator over a list of data objects.

    Args:
        data_iter (generator): Data generator.
        shuffle (bool): Whether the data is shuffled. Defaults to True.
        is_labeled (bool): Whether the data objects are labeled data.
                            Defaults to False.
        batch_size (int): Number of tokens per batch. Defaults to 3000.
        world_size (int): Total number of GPUs that will be used. Defaults to 1.
        rank (int): Rank of the current GPU. Defaults to 0.
        local_rank (int): Local rank of the current GPU for distributed training.
            Defaults to -1, which means non-distributed training.

    Returns:
        ChunkDataLoader: An iterator over batches of the data.
    """
    sampler = IterableDistributedSampler(world_size, rank, local_rank)
    return ChunkDataLoader(
        data_iter, batch_size, shuffle=shuffle, is_labeled=is_labeled, sampler=sampler
    )
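Because `batch_size` here counts tokens rather than examples, each batch holds a variable number of documents. The sketch below shows one common way to enforce such a token budget. It assumes every example is padded to the longest example in its batch, so a batch is closed once `number of examples × longest length` would exceed the budget. The function `token_budget_batches` is hypothetical and is not taken from `ChunkDataLoader`, whose exact batching rule may differ.

```python
from typing import Iterable, Iterator, List, Sequence


def token_budget_batches(
    examples: Iterable[Sequence[int]], max_tokens: int
) -> Iterator[List[Sequence[int]]]:
    """Group examples so that the padded batch size
    (number of examples x longest example) stays within max_tokens."""
    batch: List[Sequence[int]] = []
    longest = 0
    for ex in examples:
        new_longest = max(longest, len(ex))
        # Padded size if ex were added to the current batch.
        if batch and new_longest * (len(batch) + 1) > max_tokens:
            yield batch
            batch, new_longest = [], len(ex)
        batch.append(ex)
        longest = new_longest
    if batch:
        yield batch
```

With example lengths 3, 3, 3, 5 and a budget of 9 tokens, the first three examples form one batch (3 × 3 = 9) and the length-5 example starts a second batch. An example longer than the budget is still emitted on its own rather than dropped.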

In [7]:
def get_pred(
    example,
    sent_scores,
    cal_lead=False,
    sentence_separator="<q>",
    block_trigram=True,
    top_n=3,
):
    """
        Get the summarization prediction for the paragraph example based on the scores
        returned by the transformer summarization model.

        Args:
            example (dict): The object with a "src_txt" field holding the paragraph
                that requires summarization. The "src_txt" is a list of strings.
            sent_scores (numpy.ndarray): Scores of how likely each sentence is to be
                included in the summary.
            cal_lead (bool, optional): Whether the prediction uses the first few
                sentences as the summary. Defaults to False.
            sentence_separator (str, optional): Separator used in the generated summary.
                Defaults to '<q>'.
            block_trigram (bool, optional): Whether to skip a candidate sentence that
                shares a trigram with an already selected sentence. Defaults to True.
            top_n (int, optional): The maximum number of sentences that the summary
                should include. Defaults to 3.

        Returns:
            A list containing one string, which is the summary for the example.
    """

    def _get_ngrams(n, text):
        ngram_set = set()
        text_length = len(text)
        max_index_ngram_start = text_length - n
        for i in range(max_index_ngram_start + 1):
            ngram_set.add(tuple(text[i : i + n]))
        return ngram_set

    def _block_tri(c, p):
        tri_c = _get_ngrams(3, c.split())
        for s in p:
            tri_s = _get_ngrams(3, s.split())
            if len(tri_c.intersection(tri_s)) > 0:
                return True
        return False

    selected_ids = np.argsort(-sent_scores)
    # selected_ids = np.argsort(-sent_scores, 1)
    if cal_lead:
        selected_ids = range(len(example["clss"]))

    pred = []
    _pred = []
    final_selections = []
    for j in selected_ids[: len(example["src_txt"])]:
        if j >= len(example["src_txt"]):
            continue
        candidate = example["src_txt"][j].strip()
        if block_trigram:
            if not _block_tri(candidate, _pred):
                _pred.append(candidate)
                final_selections.append(j)
        else:
            _pred.append(candidate)
            final_selections.append(j)

        # only select the top n
        if len(_pred) == top_n:
            break

    sorted_selections = sorted(final_selections)
    _pred = []
    for i in sorted_selections:
        _pred.append(example["src_txt"][i].strip())
    _pred = sentence_separator.join(_pred)
    pred.append(_pred.strip())
    return pred

In [8]:
class ExtSumProcessedData:
    """Class for loading data preprocessed as in
    :class:`utils_nlp.models.transformers.datasets.SummarizationDataset`"""

    @staticmethod
    def save_data(data_iter, is_test=False, save_path="./", chunk_size=None):
        """ Save the preprocessed data into files with specified chunk size

        Args:
            data_iter (iterator): Data iterator returned from
                :class:`utils_nlp.models.transformers.datasets.SummarizationDataset`
            is_test (bool): Boolean value which indicates whether target data
                is included. If set to True, the file name contains "test", otherwise,
                the file name contains "train". Defaults to False.
            save_path (str): Directory where the data should be saved. Defaults to "./".
            chunk_size (int): The number of examples that should be included in each
                file. Defaults to None, which means only one file is used.

        Returns:
            a list of strings which are the files the data is saved to.
        """
        os.makedirs(save_path, exist_ok=True)

        def _chunks(iterable, chunk_size):
            iterator = filter(None, iterable)
            for first in iterator:
                if chunk_size:
                    yield itertools.chain(
                        [first], itertools.islice(iterator, chunk_size - 1)
                    )
                else:
                    yield itertools.chain([first], itertools.islice(iterator, None))

        chunks = _chunks(data_iter, chunk_size)
        filename_list = []
        for i, chunked_data in enumerate(chunks):
            filename = f"{i}_test" if is_test else f"{i}_train"
            torch.save(list(chunked_data), os.path.join(save_path, filename))
            filename_list.append(os.path.join(save_path, filename))
        return filename_list

    def _get_files(self, root):
        train_files = []
        test_files = []
        files = [
            os.path.join(root, f)
            for f in os.listdir(root)
            if os.path.isfile(os.path.join(root, f))
        ]
        for fname in files:
            if fname.find("train") != -1:
                train_files.append(fname)
            elif fname.find("test") != -1:
                test_files.append(fname)

        return train_files, test_files

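    def splits(self, root, train_iterable=False):
        """Get the train and test datasets from the folder.

        Args:
            root (str): Directory where the data can be loaded.
            train_iterable (bool, optional): Whether the train dataset is returned
                as an ExtSumProcessedIterableDataset. Defaults to False.

        Returns:
            Tuple of the train dataset (ExtSumProcessedIterableDataset if
            train_iterable is True, otherwise ExtSumProcessedDataset) and the
            test dataset (ExtSumProcessedDataset).
        """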
        train_files, test_files = self._get_files(root)
        if train_iterable:
            return (
                ExtSumProcessedIterableDataset(train_files, is_shuffle=True),
                ExtSumProcessedDataset(test_files, is_shuffle=False),
            )
        else:
            return (
                ExtSumProcessedDataset(train_files, is_shuffle=True),
                ExtSumProcessedDataset(test_files, is_shuffle=False),
            )
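`save_data` splits the example stream with `_chunks`, which yields lazy `itertools.chain` objects that all share one underlying iterator. This works only because each chunk is fully consumed by `torch.save(list(...))` before the next chunk is requested. A variant that materializes every chunk as a list avoids that ordering constraint. The helper below (`chunked`) is a hypothetical sketch, not part of `utils_nlp`; like `_chunks`, it drops falsy items such as `None`, and `chunk_size=None` puts everything into a single chunk.

```python
import itertools
from typing import Iterable, Iterator, List, Optional, TypeVar

T = TypeVar("T")


def chunked(iterable: Iterable[T], chunk_size: Optional[int] = None) -> Iterator[List[T]]:
    """Yield lists of at most `chunk_size` truthy items from `iterable`.

    A `chunk_size` of None (or 0) yields everything as a single chunk.
    """
    iterator = filter(None, iterable)  # drop falsy items, e.g. None
    while True:
        # islice(iterator, None) takes all remaining items.
        chunk = list(itertools.islice(iterator, chunk_size or None))
        if not chunk:
            return
        yield chunk
```

For example, `list(chunked([1, None, 2, 3, 4, 5], 2))` gives `[[1, 2], [3, 4], [5]]`.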


In [9]:





def preprocess_single_add_oracleids(input_data, oracle_mode="greedy", selections=3):
    """ Preprocess a single data point by computing its oracle summary, i.e. the
        ids of the source sentences selected as the extractive target.

        Args:
            input_data (dict): An item from `SummarizationDataset`.
            oracle_mode (str, optional): Sentence selection method, either
                "greedy" or "combination". Defaults to "greedy".
            selections (int, optional): The number of sentences used as summary.
                Defaults to 3.
        Returns:
            Dictionary of fields "src", "src_txt", "tgt", "tgt_txt" and "oracle_ids".
            "oracle_ids" is only added when the input contains a target ("tgt").
    """

    oracle_ids = None
    if "tgt" in input_data:
        if oracle_mode == "greedy":
            oracle_ids = greedy_selection(
                input_data["src"], input_data["tgt"], selections
            )
        elif oracle_mode == "combination":
            oracle_ids = combination_selection(
                input_data["src"], input_data["tgt"], selections
            )
        input_data["oracle_ids"] = oracle_ids
    # input_data["src_txt"] = tokenize.sent_tokenize(input_data["src_txt"])
    return input_data




In [10]:
def parallel_preprocess(input_data, preprocess, num_pool=-1):
    """
    Process data in parallel using multiple CPU processes.

    Args:
        input_data (list): List of input items to process.
        preprocess (func): Function applied to each input item.
        num_pool (int, optional): Number of CPUs to use. Defaults to -1 and all
            available CPUs are used.

    Returns:
        list: List of processed items, in the same order as the input.

    """
    if num_pool == -1:
        num_pool = cpu_count()

    # Never start more worker processes than there are items, and at least one.
    num_pool = max(1, min(num_pool, len(input_data)))

    with Pool(num_pool) as p:
        # Split the work evenly across workers; chunksize must be at least 1.
        results = p.map(
            preprocess, input_data, chunksize=max(1, len(input_data) // num_pool)
        )

    return results


In [27]:
class ExtSumProcessor:
    """Class for preprocessing extractive summarization data."""

    def __init__(
        self,
        model_name="bert-base-german-dbmdz-cased",
        to_lower=False,
        cache_dir=".",
        max_nsents=200,
        max_src_ntokens=2000,
        min_nsents=3,
        min_src_ntokens=5,
    ):
        """ Initialize the preprocessor.

        Args:
            model_name (str, optional): Transformer model name used in preprocessing.
                Check MODEL_CLASS for supported models.
                Defaults to "bert-base-german-dbmdz-cased".
            to_lower (bool, optional): Whether to convert all letters to lower case
                during tokenization. This is determined by whether a cased model is used.
                Defaults to False, which corresponds to a cased model.
            cache_dir (str, optional): Directory to cache the tokenizer.
                Defaults to ".".
            max_nsents (int, optional): Max number of sentences that can be used
                as input. Defaults to 200.
            max_src_ntokens (int, optional): Max number of tokens per sentence that
                can be used as input. Defaults to 2000.
            min_nsents (int, optional): Minimum number of sentences that are required
                as input. During training, an input with fewer sentences than this
                value is skipped and cannot be used as a valid input. Defaults to 3.
            min_src_ntokens (int, optional): Minimum number of tokens that are required
                for an input sentence. During training, a sentence that does not have
                more tokens than this value is skipped and cannot be used as a valid
                sentence. Defaults to 5.

        """
        self.model_name = model_name
        if "roberta" in self.model_name:
            self.tokenizer = XLMRobertaTokenizer.from_pretrained(
                model_name,
                do_lower_case=to_lower,
                cache_dir=cache_dir,
                output_loading_info=False,
                eos_token="",
                bos_token="",
                sep_token="[SEP]",
                cls_token="[CLS]",
                pad_token="[PAD]",

            )
            self.tokenizer = XLMRobertaTokenizer.from_pretrained(
                model_name,
                do_lower_case=to_lower,
                cache_dir=cache_dir,
                output_loading_info=False,
            )
            self.sep_vid = self.tokenizer.sep_token_id
            self.cls_vid = self.tokenizer.cls_token_id
            self.pad_vid = self.tokenizer.pad_token_id
        else:
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                do_lower_case=to_lower,
                cache_dir=cache_dir,
                output_loading_info=False,
            )
            self.sep_vid = self.tokenizer.vocab["[SEP]"]
            self.cls_vid = self.tokenizer.vocab["[CLS]"]
            self.pad_vid = self.tokenizer.vocab["[PAD]"]
        
        print("sep_vid: ", self.sep_vid)
        print("cls_vid: ", self.cls_vid)
        print("pad_vid: ", self.pad_vid)
        self.max_nsents = max_nsents
        self.max_src_ntokens = max_src_ntokens
        self.min_nsents = min_nsents
        self.min_src_ntokens = min_src_ntokens

    @staticmethod
    def list_supported_models():
        return list(MODEL_CLASS)

    @property
    def model_name(self):
        return self._model_name

    @model_name.setter
    def model_name(self, value):
        if value not in self.list_supported_models():
            raise ValueError(
                "Model name {} is not supported by ExtSumProcessor. "
                "Call 'ExtSumProcessor.list_supported_models()' to get all supported "
                "model names.".format(value)
            )

        self._model_name = value

    @staticmethod
    def get_inputs(batch, device, model_name, train_mode=True):
        """
        Creates an input dictionary given a model name.

        Args:
            batch (object): A Batch containing input ids, segment ids, sentence class
                ids, masks for the input ids, masks for sentence class ids and source
                text. If train_mode is True, it also contains the labels and target
                text.
            device (torch.device): A PyTorch device.
            model_name (str): Model name used to format the inputs.
            train_mode (bool, optional): Training mode flag.
                Defaults to True.

        Returns:
            dict: Dictionary containing input ids, segment ids, sentence class ids,
            masks for the input ids, masks for the sentence class ids and labels.
            Labels are only returned when train_mode is True.
        """
        batch = batch.to(device)
        inputs = {
            "x": batch.src,
            "segs": batch.segs,
            "clss": batch.clss,
            "mask": batch.mask,
            "mask_cls": batch.mask_cls,
        }
        if train_mode:
            # labels must be the last entry
            inputs["labels"] = batch.labels
            # Log the batch at debug level instead of printing it on every step.
            logger.debug("model inputs: %s", inputs)
        return inputs

    def preprocess(self, input_data_list, oracle_mode="greedy", selections=3):
        """ Preprocess multiple data points.

            Args:
                input_data_list (SummarizationDataset): The dataset to be preprocessed.
                oracle_mode (str, optional): Sentence selection method.
                    Defaults to "greedy".
                selections (int, optional): The number of sentences used as summary.
                    Defaults to 3.

            Returns:
                List of dictionaries, each being an input item with "oracle_ids"
                added when the item contains a target.
        """
        # Pass the caller's settings through instead of hard-coding the defaults.
        preprocess = functools.partial(
            preprocess_single_add_oracleids,
            oracle_mode=oracle_mode,
            selections=selections,
        )
        return parallel_preprocess(input_data_list, preprocess)

    def collate(self, data, block_size, device, train_mode=True):
        """ Collate function for PyTorch data loaders.
            Args:
                data (list): A list of samples from SummarizationDataset.
                block_size (int): Maximum input length for the model.
                device (torch.device): Device the batch is moved to.
                train_mode (bool): Whether the collate function is used for training
                    or not. Defaults to True.

            Returns:
                `Batch` object: a data minibatch as the input of a model, or None
                if `data` is empty.

        """

        if len(data) == 0:
            return None
        else:
            if train_mode is True and "tgt" in data[0] and "oracle_ids" in data[0]:
                encoded_text = [self.encode_single(d, block_size) for d in data]
                batch = Batch(list(filter(None, encoded_text)), True)
            else:
                encoded_text = [
                    self.encode_single(d, block_size, train_mode) for d in data
                ]
                # src, labels, segs, clss, src_txt, tgt_txt =  zip(*encoded_text)
                # new_data = [list(i) for i in list(zip(*encoded_text))]
                # batch =  Batch(new_data)
                filtered_list = list(filter(None, encoded_text))
                # if len(filtered_list) != len(data):
                #    raise ValueError("no test data shouldn't be skipped")
                batch = Batch(filtered_list)
            return batch.to(device)

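    def encode_single(self, d, block_size, train_mode=True):
        """ Encode a single sample.
            Args:
                d (dict): A data sample from SummarizationDataset.
                block_size (int): Maximum input length for the model.
                train_mode (bool): Whether labels are built and the training-only
                    source filtering is applied. Defaults to True.

            Returns:
                Tuple of (subtoken ids, labels, segment ids, [CLS] positions,
                source text, target text), or None if the sample is filtered out.

        """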

        src = d["src"]

        if len(src) == 0:
            raise ValueError("source doesn't have any sentences")

        original_src_txt = [" ".join(s) for s in src]
        # no filtering for prediction
        idxs = [i for i, s in enumerate(src)]
        src = [src[i] for i in idxs]

        tgt_txt = None
        labels = None
        if (
            train_mode and "oracle_ids" in d and "tgt" in d and "tgt_txt" in d
        ):  # is not None and tgt is not None:
            labels = [0] * len(src)
            for l in d["oracle_ids"]:
                labels[l] = 1

            # source filtering for only training
            idxs = [i for i, s in enumerate(src) if (len(s) > self.min_src_ntokens)]
            src = [src[i][: self.max_src_ntokens] for i in idxs]
            src = src[: self.max_nsents]
            labels = [labels[i] for i in idxs]
            labels = labels[: self.max_nsents]

            if len(src) < self.min_nsents:
                return None
            if len(labels) == 0:
                return None
            # Join target sentences with the same "<q>" separator used by get_pred.
            tgt_txt = "<q>".join([" ".join(tt) for tt in d["tgt"]])

        src_txt = [" ".join(sent) for sent in src]
        text = " [SEP] [CLS] ".join(src_txt)
        src_subtokens = self.tokenizer.tokenize(text)
        # src_subtokens = src_subtokens[:510]
        src_subtokens = (
            ["[CLS]"]
            + fit_to_block_size(
                src_subtokens, block_size - 2, self.tokenizer.pad_token_id
            )
            + ["[SEP]"]
        )
        src_subtoken_idxs = self.tokenizer.convert_tokens_to_ids(src_subtokens)
        _segs = [-1] + [i for i, t in enumerate(src_subtoken_idxs) if t == self.sep_vid]
        segs = [_segs[i] - _segs[i - 1] for i in range(1, len(_segs))]
        segments_ids = []
        for i, s in enumerate(segs):
            if i % 2 == 0:
                segments_ids += s * [0]
            else:
                segments_ids += s * [1]
        cls_ids = [i for i, t in enumerate(src_subtoken_idxs) if t == self.cls_vid]
        if labels:
            labels = labels[: len(cls_ids)]
        src_txt = [original_src_txt[i] for i in idxs]
        return src_subtoken_idxs, labels, segments_ids, cls_ids, src_txt, tgt_txt


In [28]:




class ExtractiveSummarizer(Transformer):
    """class which performs extractive summarization fine tuning and prediction """

    def __init__(
        self,
        processor,
        model_name="bert-base-german-dbmdz-cased",
        encoder="transformer",
        max_pos_length=512,
        cache_dir=".",
    ):
        """Initialize an ExtractiveSummarizer.

        Args:
            processor (ExtSumProcessor): The data processor used with this model.
            model_name (str, optional): Transformer model name used in preprocessing.
                Check MODEL_CLASS for supported models.
                Defaults to "bert-base-german-dbmdz-cased".
            encoder (str, optional): Encoder algorithm used by summarization layer.
                There are four options:
                    - baseline: it uses a smaller transformer model to replace the BERT
                        model, with a transformer summarization layer.
                    - classifier: it uses pretrained BERT and fine-tunes BERT with a
                        simple logistic classification summarization layer.
                    - transformer: it uses pretrained BERT and fine-tunes BERT with a
                        transformer summarization layer.
                    - RNN: it uses pretrained BERT and fine-tunes BERT with an LSTM
                        summarization layer.
                Defaults to "transformer".
            max_pos_length (int, optional): Maximum number of token positions the
                model accepts. Defaults to 512.
            cache_dir (str, optional): Directory to cache the pretrained model.
                Defaults to ".".
        """

        # Validate the name first so an unsupported model raises a clear ValueError
        # instead of a KeyError from MODEL_CLASS.
        if model_name not in self.list_supported_models():
            raise ValueError(
                "Model name {} is not supported by ExtractiveSummarizer. "
                "Call 'ExtractiveSummarizer.list_supported_models()' to get all "
                "supported model names.".format(model_name)
            )
        model = MODEL_CLASS[model_name].from_pretrained(
            model_name, cache_dir=cache_dir, num_labels=0, output_loading_info=False
        )
        super().__init__(model_name=model_name, model=model, cache_dir=cache_dir)
        self.processor = processor
        self.max_pos_length = max_pos_length
        self.model_class = MODEL_CLASS[model_name]
        default_summarizer_layer_parameters = {
            "ff_size": 512,
            "heads": 4,
            "dropout": 0.1,
            "inter_layers": 2,
            "hidden_size": 128,
            "rnn_size": 512,
            "param_init": 0.0,
            "param_init_glorot": True,
        }

        args = Bunch(default_summarizer_layer_parameters)
        self.model = BertSumExt(
            encoder, args, self.model_class, model_name, max_pos_length, None, cache_dir
        )

    @staticmethod
    def list_supported_models():
        return list(MODEL_CLASS)

    def fit(
        self,
        train_dataset,
        num_gpus=None,
        gpu_ids=None,
        batch_size=3000,
        local_rank=-1,
        max_steps=5e5,
        warmup_steps=1e5,
        learning_rate=2e-3,
        optimization_method="adam",
        max_grad_norm=0,
        beta1=0.9,
        beta2=0.999,
        decay_method="noam",
        gradient_accumulation_steps=1,
        report_every=50,
        verbose=True,
        seed=None,
        save_every=-1,
        world_size=1,
        rank=0,
        use_preprocessed_data=False,
        **kwargs,
    ):
        """
        Fine-tune pre-trained transformer models for extractive summarization.

        Args:
            train_dataset (ExtSumProcessedIterableDataset): Training dataset.
            num_gpus (int, optional): The number of GPUs to use.
                If None, all available GPUs will be used. If set to 0 or GPUs are not
                available, CPU device will be used. Defaults to None.
            gpu_ids (list): List of GPU IDs to be used.
                If set to None, the first num_gpus GPUs will be used.
                Defaults to None.
            batch_size (int, optional): Maximum number of tokens in each batch.
                Defaults to 3000.
            local_rank (int, optional): Local_rank for distributed training on GPUs.
                Defaults to -1, which means non-distributed training.
            max_steps (int, optional): Maximum number of training steps.
                Defaults to 5e5.
            warmup_steps (int, optional): Number of steps taken to increase learning
                rate from 0 to `learning_rate`. Defaults to 1e5.
            learning_rate (float, optional): Learning rate of the optimizer.
                Defaults to 2e-3.
            optimization_method (string, optional): Optimization method used in
                fine tuning. Defaults to "adam".
            max_grad_norm (float, optional): Maximum gradient norm for gradient
                clipping.
                Defaults to 0.
            gradient_accumulation_steps (int, optional): Number of batches to accumulate
                gradients on between each model parameter update. Defaults to 1.
            decay_method (string, optional): Learning rate decay method.
                Defaults to 'noam'.
            report_every (int, optional): The interval by steps to print out the
                training log.
                Defaults to 50.
            beta1 (float, optional): The exponential decay rate for the first moment
                estimates.
                Defaults to 0.9.
            beta2 (float, optional): The exponential decay rate for the second-moment
                estimates.
                This value should be set close to 1.0 on problems with a sparse
                gradient.
                Defaults to 0.999.
            verbose (bool, optional): Whether to print out the training log.
                Defaults to True.
            seed (int, optional): Random seed used to improve reproducibility.
                Defaults to None.
            world_size (int, optional): Total number of GPUs used in distributed
                training. Defaults to 1.
            rank (int, optional): Global rank of the current GPU in distributed
                training. It's calculated with the rank of the current node in
                the cluster/world and the `local_rank` of the device in the current
                node. See an example in :file: `examples/text_summarization/
                extractive_summarization_cnndm_distributed_train.py`.
                Defaults to 0.
        """

        # get device
        device, num_gpus = get_device(
            num_gpus=num_gpus, gpu_ids=gpu_ids, local_rank=local_rank
        )
        # move model
        self.model = move_model_to_device(model=self.model, device=device)

        # init optimizer
        print("before optimizer")
        optimizer = model_builder.build_optim(
            self.model,
            optimization_method,
            learning_rate,
            max_grad_norm,
            beta1,
            beta2,
            decay_method,
            warmup_steps,
        )
        print("before parallelize_model")
        self.model = parallelize_model(
            model=self.model,
            device=device,
            num_gpus=num_gpus,
            gpu_ids=gpu_ids,
            local_rank=local_rank,
        )

        # batch_size is the number of examples in a batch (passed to DataLoader)
        print("before parallelize_model")

        if local_rank == -1:
            print("RandomSampler")
            sampler = RandomSampler(train_dataset)
        else:
            print("DistributedSample")

            sampler = DistributedSampler(
                train_dataset, num_replicas=world_size, rank=rank
            )

        def collate_fn(data):
            return self.processor.collate(
                data, block_size=self.max_pos_length, device=device
            )
        train_dataloader = DataLoader(
            train_dataset,
            sampler=sampler,
            batch_size=batch_size,
            collate_fn=collate_fn,
        )

        # compute the max number of training steps
        max_steps = compute_training_steps(
            train_dataloader,
            max_steps=max_steps,
            gradient_accumulation_steps=gradient_accumulation_steps,
        )
        print("before fine tune")
        print()
        super().fine_tune(
            train_dataloader=train_dataloader,
            get_inputs=ExtSumProcessor.get_inputs,
            device=device,
            num_gpus=num_gpus,
            max_steps=max_steps,
            max_grad_norm=max_grad_norm,
            gradient_accumulation_steps=gradient_accumulation_steps,
            optimizer=optimizer,
            scheduler=None,
            verbose=verbose,
            seed=seed,
            report_every=report_every,
            clip_grad_norm=False,
            save_every=save_every,
        )

    def predict(
        self,
        test_dataset,
        num_gpus=None,
        gpu_ids=None,
        batch_size=16,
        sentence_separator="<q>",
        top_n=3,
        block_trigram=True,
        cal_lead=False,
        verbose=True,
        local_rank=-1,
    ):
        """
        Predict the summarization for the input data iterator.

        Args:
            test_dataset (Dataset): Dataset for which the summary to be predicted
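            num_gpus (int, optional): The number of GPUs used in prediction.
                If None, all available GPUs will be used. If set to 0 or GPUs are not
                available, CPU device will be used. Defaults to None.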
            gpu_ids (list): List of GPU IDs to be used.
                If set to None, the first num_gpus GPUs will be used.
                Defaults to None.
            batch_size (int, optional): The number of test examples in each batch.
                Defaults to 16.
            sentence_separator (str, optional): String to be inserted between
                sentences in the prediction. Defaults to '<q>'.
            top_n (int, optional): The number of sentences that should be selected
                from the paragraph as summary. Defaults to 3.
            block_trigram (bool, optional): Boolean value which specifies whether
                to skip any candidate sentence that shares a trigram with the
                already selected sentences. Defaults to True.
            cal_lead (bool, optional): Boolean value which specifies whether the
                prediction uses the first few sentences as summary. Defaults to False.
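            verbose (bool, optional): Whether to print out the prediction log.
                Defaults to True.
            local_rank (int, optional): Local rank for distributed prediction on GPUs.
                Defaults to -1, which means non-distributed prediction.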

        Returns:
            List of strings which are the summaries

        """

        device, num_gpus = get_device(
            num_gpus=num_gpus, gpu_ids=gpu_ids, local_rank=local_rank
        )

        def collate_processed_data(dict_list):
            if dict_list is None or len(dict_list) <= 0:
                return None
            tuple_batch = [list(d.values()) for d in dict_list]
            # generate mask and mask_cls, and only select tensors for the model input;
            # labels are never used in prediction, so set is_labeled to False
            batch = Batch(tuple_batch, is_labeled=False)
            return batch

        def collate(data):
            return self.processor.collate(
                data, block_size=self.max_pos_length, train_mode=False, device=device
            )

        if len(test_dataset) == 0:
            return None
        if "segs" in test_dataset[0]:
            collate_fn = collate_processed_data
        else:
            collate_fn = collate

        test_sampler = SequentialSampler(test_dataset)
        test_dataloader = DataLoader(
            test_dataset,
            sampler=test_sampler,
            batch_size=batch_size,
            collate_fn=collate_fn,
        )
        sent_scores = self.predict_scores(
            test_dataloader, num_gpus=num_gpus, gpu_ids=gpu_ids, verbose=verbose
        )

        sent_scores_list = list(sent_scores)
        scores_list = []
        for i in sent_scores_list:
            scores_list.extend(i)
        prediction = []
        for i in range(len(test_dataset)):
            temp_pred = get_pred(
                test_dataset[i],
                scores_list[i],
                cal_lead=cal_lead,
                sentence_separator=sentence_separator,
                block_trigram=block_trigram,
                top_n=top_n,
            )
            prediction.extend(temp_pred)

        # release GPU memories
        self.model.cpu()
        torch.cuda.empty_cache()

        return prediction

    def predict_scores(self, test_dataloader, num_gpus=1, gpu_ids=None, verbose=True):
        """
        Scores a dataset using a fine-tuned model and a given dataloader.

        Args:
            test_dataloader (Dataloader): Dataloader for scoring the data.
            num_gpus (int, optional): The number of GPUs to use.
                If None, all available GPUs will be used.
                If set to 0 or GPUs are not available, CPU device will be used.
                Defaults to 1.
            gpu_ids (list): List of GPU IDs to be used.
                If set to None, the first num_gpus GPUs will be used.
                Defaults to None.
            verbose (bool, optional): Whether to print out the prediction log.
                Defaults to True.

        Returns:
            list: Predicted sentence scores, one element per batch of `test_dataloader`.
        """

        preds = list(
            super().predict(
                eval_dataloader=test_dataloader,
                get_inputs=ExtSumProcessor.get_inputs,
                num_gpus=num_gpus,
                gpu_ids=gpu_ids,
                verbose=verbose,
            )
        )
        return preds

    def save_model(self, full_name=None):
        """
        Save the trained model.

        Args:
            full_name (str, optional): File name to save the model's `state_dict()`.
                If it's None, the model is going to be saved under "fine_tuned"
                folder of the cached directory of the object. Defaults to None.
        """
        model_to_save = (
            self.model.module if hasattr(self.model, "module") else self.model
        )  # Take care of distributed/parallel training

        if full_name is None:
            output_model_dir = os.path.join(self.cache_dir, "fine_tuned")
            os.makedirs(self.cache_dir, exist_ok=True)
            os.makedirs(output_model_dir, exist_ok=True)
            full_name = os.path.join(output_model_dir, self.model_name)

        logger.info("Saving model checkpoint to %s", full_name)
        try:
            print("saving through pytorch")
            torch.save(model_to_save.state_dict(), full_name)
        except OSError:
            # fall back to plain pickle if torch.save fails with an OS error;
            # any other exception propagates unchanged
            print("saving as pickle")
            with open(full_name, "wb") as f:
                pickle.dump(model_to_save.state_dict(), f)
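The `top_n` and `block_trigram` arguments of `predict` decide which sentences end up in each summary once the sentence scores are available. As a minimal sketch (not the `get_pred` implementation, whose details are not shown here), trigram blocking can be written as: visit sentences from the highest score down, skip any candidate that shares a trigram with a sentence already selected, and stop after `top_n` sentences. `get_ngrams` and `select_sentences` are hypothetical helpers, and lowercased whitespace tokenization is an assumption.

```python
from typing import List, Set, Tuple


def get_ngrams(n: int, tokens: List[str]) -> Set[Tuple[str, ...]]:
    """Return the set of n-grams in a token list."""
    return {tuple(tokens[i : i + n]) for i in range(len(tokens) - n + 1)}


def select_sentences(
    sentences: List[str],
    scores: List[float],
    top_n: int = 3,
    block_trigram: bool = True,
) -> List[str]:
    """Pick up to top_n sentences by score, optionally with trigram blocking (sketch)."""
    # Visit sentence indices from the highest score to the lowest.
    order = sorted(range(len(sentences)), key=lambda i: scores[i], reverse=True)
    selected: List[str] = []
    seen_trigrams: Set[Tuple[str, ...]] = set()
    for idx in order:
        if len(selected) >= top_n:
            break
        trigrams = get_ngrams(3, sentences[idx].lower().split())
        # Skip a candidate that repeats any trigram already in the summary.
        if block_trigram and trigrams & seen_trigrams:
            continue
        selected.append(sentences[idx])
        seen_trigrams |= trigrams
    return selected


sents = ["the cat sat on the mat", "the cat sat on the sofa", "dogs bark loudly"]
scores = [0.9, 0.8, 0.5]
print("\n".join(select_sentences(sents, scores, top_n=2)))
```

Here the second sentence repeats "the cat sat", so it is skipped and "dogs bark loudly" is selected instead; with `block_trigram=False` the near-duplicate would be kept. Joining the result with a separator plays the role of `sentence_separator`.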


In [29]:
pd.DataFrame({"model_name": ExtractiveSummarizer.list_supported_models()})

|    | model_name |
|---:|:---|
| 0 | bert-base-uncased |
| 1 | bert-base-german-cased |
| 2 | distilbert-base-uncased |
| 3 | dbmdz/bert-base-german-uncased |
| 4 | bert-base-german-dbmdz-cased |
| 5 | bert-base-multilingual-cased |
| 6 | distilbert-base-german-cased |
| 7 | bert-base-german-dbmdz-uncased |
| 8 | severinsimmler/bert-adapted-german-press |
| 9 | xlm-roberta-large-finetuned-conll03-german |


In [30]:
# Transformer model being used
MODEL_NAME = "xlm-roberta-large-finetuned-conll03-german"

# Data being used
DATA_NAME = "bundes"

In [31]:
# notebook parameters
# the cache data path during fine tuning
CACHE_DIR = TemporaryDirectory().name
processor = ExtSumProcessor(model_name=MODEL_NAME, cache_dir=CACHE_DIR)



sep_vid:  2
cls_vid:  0
pad_vid:  1


### Data Preprocessing

The dataset we used for this notebook is the CNN/DM dataset, which contains documents and accompanying questions from CNN and Daily Mail news articles. The highlights in each article are used as the summary. The dataset consists of ~289K training examples, ~11K validation examples and ~11K test examples. You can choose [Option 1] below to preprocess the data or [Option 2] to use the preprocessed version from the [BERTSum published example](https://github.com/nlpyang/BertSum/). You don't need to download either dataset manually, as the code below handles downloading. The functions defined in [cnndm.py](../../utils_nlp/dataset/cnndm.py) are specific to the CNN/DM dataset preprocessed by harvardnlp. However, they provide a skeleton of how to preprocess text into the format that the model preprocessor takes: sentence tokenization and word tokenization.

##### Details of Data Preprocessing

The purpose of preprocessing is to convert the input articles into the format that model fine-tuning needs. Assuming you have (1) all articles and (2) target summaries, each stored in its own file with one example per line, the steps to preprocess the data are:
1. sentence tokenization
2. word tokenization
3. **label** the sentences in the article with 1 meaning the sentence is selected and 0 meaning the sentence is not selected. The algorithms for the sentence selection are "greedy" and "combination" and can be found in [sentence_selection.py](../../utils_nlp/dataset/sentence_selection.py)
4. convert each example to the desired format for extractive summarization
    - filter out sentences in the example that are shorter than min_src_ntokens. If the number of remaining sentences is less than min_nsents, the example is discarded.
    - truncate each sentence in the example if its length is greater than max_src_ntokens
    - truncate the sentences in the example and the labels if the total number of sentences is greater than max_nsents
    - [CLS] and [SEP] are inserted before and after each sentence
    - WordPiece tokenization or Byte Pair Encoding (BPE) subword tokenization
    - truncate the example to 512 tokens
    - convert the tokens into token indices corresponding to the transformer tokenizer's vocabulary
    - segment ids are generated and added
    - [CLS] token positions are logged
    - [CLS] token labels are truncated if their positions exceed 512, which is the maximum input length that can be taken by the transformer model
    
    
Note that the original BERTSum paper uses Stanford CoreNLP for data preprocessing, while here we use NLTK.
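The last conversion steps listed above can be illustrated with a toy example. The sketch below inserts [CLS] and [SEP] around each sentence, assigns interval segment ids that alternate 0/1 from sentence to sentence, logs the [CLS] positions, and truncates to 512 tokens, dropping labels whose [CLS] falls beyond the limit. It assumes sentences are already word-tokenized and skips subword tokenization and the mapping to vocabulary ids; `build_example` and its output keys are hypothetical and only loosely follow BertSum's naming.

```python
from typing import Dict, List


def build_example(
    sentences: List[List[str]], labels: List[int], max_len: int = 512
) -> Dict[str, List]:
    """Flatten tokenized sentences into a BertSum-style input (simplified sketch)."""
    tokens: List[str] = []
    segs: List[int] = []
    clss: List[int] = []
    for i, sent in enumerate(sentences):
        clss.append(len(tokens))  # position of this sentence's [CLS] token
        piece = ["[CLS]"] + sent + ["[SEP]"]
        tokens.extend(piece)
        # Interval segment ids: even-numbered sentences get 0, odd-numbered get 1.
        segs.extend([i % 2] * len(piece))
    # Truncate to the model's maximum input length.
    tokens, segs = tokens[:max_len], segs[:max_len]
    # Keep only the [CLS] positions (and their labels) that survived truncation.
    kept = [j for j, pos in enumerate(clss) if pos < max_len]
    return {
        "src": tokens,
        "segs": segs,
        "clss": [clss[j] for j in kept],
        "labels": [labels[j] for j in kept],
    }


example = build_example([["The", "cat", "sat", "."], ["It", "rained", "."]], [1, 0])
print(example["segs"])  # [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
print(example["clss"])  # [0, 6]
```

The [CLS] positions are what let the summarization layer pick out one vector per sentence later on.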

In [32]:
# the data path used to save the downloaded data file
DATA_PATH = '/home/ubuntu/data/mnt/bundes_dataset/'
# The number of lines at the head of data file used for preprocessing. -1 means all the lines.
TOP_N = 500
if not QUICK_RUN:
    TOP_N = -1

In [33]:
validation = False

if validation:
    train_dataset, validation_dataset, test_dataset = BundesSummarizationDataset(top_n=TOP_N, validation=True, language='german')
else:
    train_dataset, test_dataset = BundesSummarizationDataset(top_n=TOP_N, validation=False, language='german')


source[0]:  Berlin, 26. Oktober 2017
Durch die Pflegereform konnte die Zahl der zusätzlichen Betreuungskräfte in stationären Pflegeeinrichtungen auf rund 60.000 Betreuungskräfte mehr als verdoppelt werden. Das ergibt sich aus der aktuellen Ausgabenentwicklung.
Bundesgesundheitsminister Hermann Gröhe : "Dass wir die Zahl der Betreuungskräfte in Pflegeeinrichtungen mehr als verdoppeln konnten, ist zusammen mit dem Abbau unnötiger Bürokratie und der Bezahlung von Tariflöhnen ein wichtiger Schritt. Weitere Schritte müssen folgen, dazu gehören insbesondere weitere Verbesserungen bei der Ausstattung mit Pflegefachkräften und angemessene Löhne auch in nicht tariflich-gebundenen Pflegeeinrichtungen überall in Deutschland. Gute Arbeitsbedingungen für alle, die in der Pflege täglich enormes leisten, werden ein zentrales Thema auch in dieser Wahlperiode sein - dafür kämpfe ich."
Die Zahl der zusätzlichen Betreuungskräfte in stationären Pflegeeinrichtungen konnte infolge der Pflegestärkungsgesetze

In [34]:

len(train_dataset), len(test_dataset)

(475, 25)

### Preprocess the data.

In [35]:

ext_sum_train = processor.preprocess(train_dataset, oracle_mode="greedy")
ext_sum_test = processor.preprocess(test_dataset, oracle_mode="greedy")
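`oracle_mode="greedy"` controls how the 0/1 sentence labels are produced. As a rough sketch of the greedy idea, assuming a simplified ROUGE on whitespace tokens: repeatedly add the sentence that most increases ROUGE-1 F1 plus ROUGE-2 F1 between the selected sentences and the reference summary, and stop when no remaining sentence improves the score or `max_sents` is reached. `greedy_oracle` and `rouge_n_f1` are hypothetical helpers; the actual selection code in sentence_selection.py may differ, for example in text cleaning and in how n-grams are counted across sentence boundaries.

```python
from collections import Counter
from typing import List


def rouge_n_f1(candidate: List[str], reference: List[str], n: int) -> float:
    """F1 of clipped n-gram overlap between two token lists (simplified ROUGE-N)."""
    cand = Counter(tuple(candidate[i : i + n]) for i in range(len(candidate) - n + 1))
    ref = Counter(tuple(reference[i : i + n]) for i in range(len(reference) - n + 1))
    overlap = sum((cand & ref).values())
    if overlap == 0:
        return 0.0
    precision = overlap / sum(cand.values())
    recall = overlap / sum(ref.values())
    return 2 * precision * recall / (precision + recall)


def greedy_oracle(
    doc: List[List[str]], summary: List[List[str]], max_sents: int = 3
) -> List[int]:
    """Return 0/1 labels marking the sentences chosen by greedy ROUGE selection."""
    reference = [tok for sent in summary for tok in sent]
    selected: List[int] = []
    best_score = 0.0
    for _ in range(max_sents):
        best_idx = -1
        for i in range(len(doc)):
            if i in selected:
                continue
            # Concatenate the candidate sentences in document order.
            cand = [tok for j in sorted(selected + [i]) for tok in doc[j]]
            score = rouge_n_f1(cand, reference, 1) + rouge_n_f1(cand, reference, 2)
            if score > best_score:
                best_score, best_idx = score, i
        if best_idx == -1:  # no remaining sentence improves the score
            break
        selected.append(best_idx)
    return [1 if i in selected else 0 for i in range(len(doc))]


doc = [["the", "cat", "sat"], ["dogs", "bark"], ["the", "mat", "was", "red"]]
summary = [["the", "cat", "sat", "on", "the", "mat"]]
print(greedy_oracle(doc, summary))  # [1, 0, 1]
```

Adding the second sentence would lower the score, so selection stops after two sentences even though `max_sents` allows three.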


### Save the data.

In [36]:
ext_sum_test[0]['tgt']

[['Heute',
  'wurden',
  'die',
  'Verträge',
  'mit',
  'den',
  'an',
  'den',
  'Max',
  'Planck',
  'Schools',
  'beteiligten',
  'Universitäten',
  'unterzeichnet',
  ':',
  '„',
  'Damit',
  'treten',
  'wir',
  'in',
  'Konkurrenz',
  'zu',
  'Oxford',
  ',',
  'Cambridge',
  ',',
  'der',
  '‚Ivy',
  'League',
  '‘',
  'oder',
  'ähnlichen',
  'Eliteuniversitäten',
  '“',
  ',',
  'sagt',
  'Staatssekretär',
  'Michael',
  'Meister',
  '.']]

In [37]:
SAVE_DATA = False


# save and load preprocessed data

if SAVE_DATA:
    save_path = os.path.join(DATA_PATH, DATA_NAME + "_processed")
    os.makedirs(save_path, exist_ok=True)

    torch.save(ext_sum_train, os.path.join(save_path, "train_full.pt"))
    torch.save(ext_sum_test, os.path.join(save_path, "test_full.pt"))

In [38]:
len(ext_sum_train)

475

#### Inspect Data

##### [Option 2] Reuse cached preprocessed data

In [39]:
if USE_PREPROCSSED_DATA:
    save_path = os.path.join(DATA_PATH)
    ext_sum_train = torch.load(os.path.join(save_path, "train_full.pt"))
    ext_sum_test = torch.load(os.path.join(save_path, "test_full.pt"))
    

### Model training
To start model training, we need to create an instance of ExtractiveSummarizer.

Potentially, RoBERTa-based models and XLNet can also be supported, but this needs to be tested.
#### Choose the encoder algorithm.
There are four options:
- baseline: uses a smaller transformer model in place of BERT, with a transformer summarization layer
- classifier: uses pretrained BERT, fine-tuned with a **simple logistic classification** summarization layer
- transformer: uses pretrained BERT, fine-tuned with a **transformer** summarization layer
- RNN: uses pretrained BERT, fine-tuned with an **LSTM** summarization layer
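The **classifier** option is the simplest of the four layers. To our understanding of the original BertSum code, it projects each sentence's [CLS] vector to a single logit with a linear layer, applies a sigmoid, and multiplies by a mask so that padded sentence slots score 0. The pure-Python sketch below illustrates that computation only; it assumes the [CLS] vectors are already extracted, and `classify_sentences` and its toy weights are hypothetical, not the module used by ExtractiveSummarizer.

```python
import math
from typing import List


def classify_sentences(
    cls_vectors: List[List[float]],
    mask: List[int],
    weights: List[float],
    bias: float,
) -> List[float]:
    """Score each sentence as sigmoid(w . h + b), zeroing padded positions (sketch)."""
    scores = []
    for h, m in zip(cls_vectors, mask):
        logit = sum(w * x for w, x in zip(weights, h)) + bias
        prob = 1.0 / (1.0 + math.exp(-logit))
        scores.append(prob * m)  # padded sentence slots (mask 0) score 0
    return scores


# Two sentence slots: the first is real, the second is padding.
print(classify_sentences([[0.0, 0.0], [1.0, 2.0]], [1, 0], [0.5, 0.5], 0.0))  # [0.5, 0.0]
```

In BertSum these per-sentence probabilities are trained with a binary cross-entropy loss against the 0/1 oracle labels.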

In [40]:
BATCH_SIZE = 5 # batch size, unit is the number of samples
MAX_POS_LENGTH = 512


# GPU used for training
NUM_GPUS = torch.cuda.device_count()

# Encoder name. Options are: baseline, classifier, transformer, rnn.
ENCODER = "transformer"

# Learning rate
LEARNING_RATE=2e-3

# How often training statistics are reported, in steps.
REPORT_EVERY=50

# total number of steps for training
MAX_STEPS=1e2
# number of steps for warm up
WARMUP_STEPS=5e2
    
if not QUICK_RUN:
    MAX_STEPS=5e4
    WARMUP_STEPS=5e3
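`fit` defaults to `decay_method='noam'`, so `LEARNING_RATE` is not the rate actually applied at every step. Assuming the schedule used by the original BertSum optimizer, the rate at step `s` is `learning_rate * min(s ** -0.5, s * warmup_steps ** -1.5)`: it grows linearly during warm-up, peaks at `s == warmup_steps`, and then decays with the inverse square root of the step. `noam_lr` below is a hypothetical helper for checking what the settings above imply; it is not part of the repo.

```python
def noam_lr(step: int, learning_rate: float, warmup_steps: float) -> float:
    """Noam-style learning rate (assumed BertSum variant, no model-size factor)."""
    step = max(step, 1)  # avoid 0 ** -0.5 at the very first step
    return learning_rate * min(step ** -0.5, step * warmup_steps ** -1.5)


for step in (100, 5000, 50000):
    print(step, noam_lr(step, 2e-3, 5e3))
```

With `LEARNING_RATE = 2e-3` and `WARMUP_STEPS = 5e3`, the peak rate under this assumption is `2e-3 / sqrt(5000)`, about 2.8e-5. In the `QUICK_RUN` setting, `WARMUP_STEPS` (500) exceeds `MAX_STEPS` (100), so training would end while the rate is still ramping up.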
 

In [41]:
summarizer = ExtractiveSummarizer(processor, MODEL_NAME, ENCODER, MAX_POS_LENGTH, CACHE_DIR)



hello~


In [42]:
summarizer.model_name

'xlm-roberta-large-finetuned-conll03-german'

In [43]:
summarizer.fit(
    ext_sum_train,
    num_gpus=NUM_GPUS,
    batch_size=BATCH_SIZE,
    gradient_accumulation_steps=2,
    max_steps=MAX_STEPS,
    learning_rate=LEARNING_RATE,
    warmup_steps=WARMUP_STEPS,
    verbose=True,
    report_every=REPORT_EVERY,
    clip_grad_norm=False,
    use_preprocessed_data=False,
)


Iteration:   0%|          | 0/95 [00:00<?, ?it/s]

before optimizer
before parallelize_model
before parallelize_model
RandomSampler
before fine tune

<torch.utils.data.dataloader.DataLoader object at 0x7f6948f00320>
before the while


TypeError: not a string

In [None]:
summarizer.save_model(
    os.path.join(
        CACHE_DIR,
        "extsum_modelname_{0}_usepreprocess{1}_steps_{2}.pt".format(
            MODEL_NAME, USE_PREPROCSSED_DATA, MAX_STEPS
        ),
    )
)

### Model Evaluation

[ROUGE](https://en.wikipedia.org/wiki/ROUGE_(metric)) (Recall-Oriented Understudy for Gisting Evaluation) is commonly used for evaluating text summarization.
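To make the metric concrete, ROUGE-N compares the n-grams of a candidate summary with those of a reference: recall is the fraction of reference n-grams that appear in the candidate, precision is the fraction of candidate n-grams that appear in the reference, and F1 is their harmonic mean. The sketch below is a simplified illustration, assuming lowercase whitespace tokenization with no stemming; `rouge_n` is a hypothetical helper, not the implementation behind `compute_rouge_python`, so its numbers will not match exactly. The `r`/`p`/`f` keys are chosen to resemble how `rouge_scores['rouge-2']['f']` is read in this notebook.

```python
from collections import Counter
from typing import Dict, List


def ngram_counts(tokens: List[str], n: int) -> Counter:
    """Count the n-grams in a token list."""
    return Counter(tuple(tokens[i : i + n]) for i in range(len(tokens) - n + 1))


def rouge_n(candidate: str, reference: str, n: int = 2) -> Dict[str, float]:
    """Simplified ROUGE-N: clipped n-gram overlap as recall, precision and F1."""
    cand = ngram_counts(candidate.lower().split(), n)
    ref = ngram_counts(reference.lower().split(), n)
    # Each n-gram counts at most min(candidate count, reference count) times.
    overlap = sum((cand & ref).values())
    r = overlap / max(sum(ref.values()), 1)
    p = overlap / max(sum(cand.values()), 1)
    f = 2 * p * r / (p + r) if p + r > 0 else 0.0
    return {"r": r, "p": p, "f": f}


print(rouge_n("the cat sat on the mat", "the cat sat on a mat", n=2))
```

Three of the five reference bigrams appear in the candidate, so recall, precision and F1 are all 0.6 for this pair.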

In [None]:
# for loading a previously saved model

model_filename = "dist_extsum_model.pt"
model_filepath = "/home/ubuntu/mnt/train/distilbert-base-german-cased/2007142250/"
model_path = os.path.join(model_filepath, model_filename)
summarizer = ExtractiveSummarizer(processor, MODEL_NAME, ENCODER, MAX_POS_LENGTH, CACHE_DIR)
summarizer.model.load_state_dict(torch.load(model_path, map_location="cpu"))

In [None]:
if "segs" in ext_sum_test[0]: # preprocessed_data
    source = [i['src_txt'] for i in ext_sum_test]
    target = ["\n".join(i['tgt_txt'].split("<q>")) for i in ext_sum_test]
else:
    source = []
    target = []
    for i in ext_sum_test:
        source.append(i["src_txt"])
        # join the words of each target sentence, then join sentences with newlines
        target.append("\n".join(" ".join(sent) for sent in i["tgt"]))

In [None]:
%%time
sentence_separator = "\n"
prediction = summarizer.predict(ext_sum_test, num_gpus=NUM_GPUS, batch_size=BATCH_SIZE, sentence_separator=sentence_separator)

In [None]:
rouge_scores = compute_rouge_python(cand=prediction, ref=target)
pprint.pprint(rouge_scores)

In [None]:
prediction[0].replace("\n", " ")

In [None]:
with open('sample_results.txt', 'w', encoding='utf-8') as f:
    for i in range(len(prediction)):
        source_output = " ".join(source[i]) 
        f.write("Source Text: \n")
        f.write("\"" + source_output + "\" \n")
        f.write("\n")
        f.write("Source target: \n")
        f.write("\"" + target[i] + "\" \n")
        f.write("\n")
        f.write("Model Prediction: \n")
        f.write("\"" + prediction[i].replace("\n", " ") + "\" \n")        
        f.write("\n")
        f.write("======================================")        
        f.write("\n \n")

In [None]:
target[10]

In [None]:
prediction[10]

In [None]:
# for testing
sb.glue("rouge_2_f_score", rouge_scores['rouge-2']['f'])

## Prediction on a single input sample

In [None]:
source = """
Italien erlaubt nach tagelangem Zögern den etwa 180 Migranten auf dem privaten Rettungsschiff "Ocean Viking" den Wechsel auf das italienische Quarantäne-Schiff "Moby Zaza". Die Übernahme der aus Seenot geretteten Menschen sei für Montag geplant, hieß es am Samstagabend aus Quellen im Innenministerium in Rom. Zuvor hatte sich die Lage auf dem Schiff der Organisation SOS Méditerranée, das sich in internationalen Gewässern vor Sizilien befindet, zugespitzt.
Die Betreiber berichteten demnach von einem Hungerstreik unter den Geflüchteten. Verena Papke, Geschäftsführerin von SOS Méditerranée für Deutschland, hatte am Freitag von mehreren Suizidversuchen gesprochen. Die "Ocean Viking" hatte zudem den Notstand an Bord ausgerufen. Bis dahin waren mehrere Bitten um Zuweisung eines sicheren Hafens in Malta und Italien erfolglos geblieben.

Corona-Abstriche bei den Migranten geplant
Die Crew sandte die dringende Anfrage an die Behörden beider Länder zur Aufnahme von rund 45 Menschen, die in schlechter Verfassung seien. Italien schickte daraufhin am Samstag einen Psychiater und einen kulturellen Mediator aus Pozzallo für mehrere Stunden an Bord, berichteten beide Seiten. Danach kam die Erlaubnis aus Rom zur Übernahme auf die "Moby Zaza". Die Lage an Bord habe sich jedoch etwas entspannt, hieß es aus der italienischen Hauptstadt. Am Sonntag seien zunächst Corona-Abstriche bei den Migranten geplant.

Wie SOS Méditerranée am Samstag schrieb, nahm das Schiff in insgesamt vier Einsätzen am 25. und am 30. Juni etwa 180 Menschen aus dem Mittelmeer an Bord. Italien und Malta hatten sich in der Corona-Pandemie zu nicht sicheren Häfen erklärt. Trotzdem brechen Migranten von Libyen und Tunesien in Richtung Europa auf. Rom und Valletta nahmen zuletzt zwar wieder Menschen von privaten Schiffen auf, doch die Länder zögern mit der Zuweisung von Häfen oft lange. Sie fordern von anderen EU-Staaten regelmäßig Zusagen über die Weiterverteilung der Menschen."""

In [None]:
test_dataset = SummarizationDataset(
    None,
    source=[source],
    source_preprocessing=[tokenize.sent_tokenize],
    word_tokenize=nltk.word_tokenize,
    language='german'
)
processor = ExtSumProcessor(model_name=MODEL_NAME, cache_dir=CACHE_DIR)
preprocessed_dataset = processor.preprocess(test_dataset)

In [None]:
preprocessed_dataset[0].keys()

In [None]:
prediction = summarizer.predict(preprocessed_dataset, num_gpus=0, batch_size=1, sentence_separator="\n")

In [None]:
prediction

## Clean up temporary folders

In [None]:
if os.path.exists(DATA_PATH):
    shutil.rmtree(DATA_PATH, ignore_errors=True)
if os.path.exists(CACHE_DIR):
    shutil.rmtree(CACHE_DIR, ignore_errors=True)
if USE_PREPROCSSED_DATA:
    if os.path.exists(PROCESSED_DATA_PATH):
        shutil.rmtree(PROCESSED_DATA_PATH, ignore_errors=True)

In [44]:
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-large-finetuned-conll03-german")

# model = AutoModelForTokenClassification.from_pretrained("xlm-roberta-large-finetuned-conll03-german")

In [45]:
tokenizer.eos_token = '[SEP]'

In [49]:
tokenizer.sp_model

<sentencepiece.SentencePieceProcessor; proxy of <Swig Object of type 'sentencepiece::SentencePieceProcessor *' at 0x7f6948ff4750> >

In [None]:
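from transformers import AutoModelForSequenceClassification

# AutoModel returns the bare encoder, which has no classification head and does not
# accept `labels`; the loss computed in the cells below needs a classification head.
model = AutoModelForSequenceClassification.from_pretrained("xlm-roberta-large-finetuned-conll03-german")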

In [None]:
model.train()

In [None]:
from transformers import AdamW
optimizer = AdamW(model.parameters(), lr=1e-3)

In [None]:
from transformers import AutoTokenizer


In [None]:
from transformers import XLMRobertaTokenizer
tokenizer = XLMRobertaTokenizer.from_pretrained("xlm-roberta-large-finetuned-conll03-german")
text_batch = ["Ich liebe Pixar.", "Ich hasse Pixar."]
encoding = tokenizer(text_batch, return_tensors='pt', padding=True, truncation=True)
input_ids = encoding['input_ids']
attention_mask = encoding['attention_mask']

In [None]:
labels = torch.tensor([1,0]).unsqueeze(0)


In [None]:
outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
loss = outputs[0]
loss.backward()
optimizer.step()