# Training & Evaluation of the developed algorithm

This notebook contains the source code for several candidate model architectures, together with the code to train and evaluate them.

In [None]:
import logging
import os
import random
import sys

import numpy as np
import pandas as pd
import torch
from tqdm.auto import tqdm

from review import config as cfg

logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s: %(message)s')

In [None]:
# Report the CUDA version exposed by torch and whether torch can currently use a CUDA device.
print('CUDA version', torch.version.cuda)
print('CUDA is available', torch.cuda.is_available())


def setup_cuda_device(model):
    """Place ``model`` on a single device: ``cuda:0`` when CUDA is available, else the CPU.

    :param model: object exposing ``.to(device)`` (e.g. a ``torch.nn.Module``)
    :return: tuple ``(model, device)`` with the moved model and the chosen ``torch.device``
    """
    logging.info('Setup single-device settings...')
    if torch.cuda.is_available():
        target = torch.device("cuda:0")
    else:
        target = torch.device("cpu")
    return model.to(target), target

In [None]:
# Seed every random number generator used in this notebook (Python, NumPy, torch CPU and
# all CUDA devices) with the same configured value so that runs are repeatable.
logging.info('Fix seed')
seed = cfg.seed
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

# Loading data and preparation of dataset

`load_data` loads all the data files needed to build the train dataset.\
The next several steps only need to be run if no train/test/val datasets have been saved yet.

In [None]:
def sizeof_fmt(num, suffix='B'):
    """Format a size as a human-readable string with binary (1024-based) prefixes.

    :param num: size to format (e.g. number of bytes)
    :param suffix: unit suffix appended after the prefix
    :return: string such as ``'1.5 KiB'``; values beyond the ``Zi`` range are reported in ``Yi``
    """
    prefixes = ('', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi')
    value = num
    idx = 0
    # Scale down until the value fits below 1024 or the known prefixes run out.
    while idx < len(prefixes) and not (abs(value) < 1024.0):
        value = value / 1024.0
        idx += 1
    if idx == len(prefixes):
        return "%.1f %s%s" % (value, 'Yi', suffix)
    return "%3.1f %s%s" % (value, prefixes[idx], suffix)


def load_data(additional_features):
    """Load the tab-separated tables of the references dataset from ``cfg.dataset_path``.

    The citations, sentences, review_files and reverse_ref tables are always read; the
    abstracts, figures and tables tables are read only when ``additional_features`` is truthy.
    The reported size of each table is logged after loading.

    :param additional_features: whether to also load the tables used for extra features
    :return: 7-tuple ``(citations_df, sentences_df, review_files_df, reverse_ref_df,
        abstracts_df, figures_df, tables_df)``; the last three are ``None`` when
        ``additional_features`` is falsy
    """
    dataset_root = os.path.expanduser(cfg.dataset_path)
    logging.info(f'Loading references dataset from {dataset_root}')

    def read_table(name):
        # Each table is stored as "<name>.csv" inside the dataset root.
        logging.info(f'Loading {name}_df')
        frame = pd.read_csv(os.path.join(dataset_root, f"{name}.csv"), sep='\t')
        logging.info(sizeof_fmt(sys.getsizeof(frame)))
        return frame

    core = [read_table(name) for name in ('citations', 'sentences', 'review_files', 'reverse_ref')]
    if additional_features:
        extra = [read_table(name) for name in ('abstracts', 'figures', 'tables')]
    else:
        extra = [None, None, None]
    return tuple(core + extra)

In [None]:
# Switch controlling whether the abstracts / figures / tables tables are loaded and whether
# per-sentence features are built later in the notebook.
ADDITIONAL_FEATURES = False

# With ADDITIONAL_FEATURES off, abstracts_df, figures_df and tables_df are None.
citations_df, sentences_df, review_files_df, reverse_ref_df, abstracts_df, figures_df, tables_df = load_data(
    additional_features=ADDITIONAL_FEATURES)

logging.info('Done loading references dataset')

In [None]:
import matplotlib.pyplot as plt
from collections import Counter

# Count the rows of sentences_df per pmid and plot the distribution of those counts.
res = Counter(list(sentences_df['pmid'].values))
plt.hist(res.values(), bins=range(-1, 400))
plt.title('Length of papers')
plt.show()
# The counter is only needed for the plot.
del res

In [None]:
REF_SENTS_DF_PATH = f"{os.path.expanduser(cfg.base_path)}/ref_sents.csv"
! rm {REF_SENTS_DF_PATH}

# NOTE(review): the `! rm` shell line right above removes the cached CSV first, so this
# existence check presumably never succeeds and the dataset is rebuilt on every run -
# confirm whether the forced rebuild is intended.
if os.path.exists(REF_SENTS_DF_PATH):
    ref_sents_df = pd.read_csv(REF_SENTS_DF_PATH, sep='\t')
else:
    logging.info('Creating reference sentences dataset')
    # Join citations with reverse-reference rows that share the same (pmid, ref_id).
    ref_sents_df = pd.merge(citations_df, reverse_ref_df, left_on=['pmid', 'ref_id'], right_on=['pmid', 'ref_id'])
    # Attach the sentence text by matching paper id, sentence type and sentence id.
    ref_sents_df = pd.merge(ref_sents_df, sentences_df, left_on=['pmid', 'sent_type', 'sent_id'],
                            right_on=['pmid', 'type', 'sent_id'])
    # Keep only rows whose citing pmid is listed in review_files_df.
    ref_sents_df = ref_sents_df[ref_sents_df['pmid'].isin(review_files_df['pmid'].values)]
    ref_sents_df = ref_sents_df.drop_duplicates()
    logging.info(f'Len of unique ref_sents {len(set(ref_sents_df["ref_pmid"]))}')
    # Keep only the columns that are persisted to the cache file.
    ref_sents_df = ref_sents_df[['pmid', 'ref_id', 'pub_type', 'ref_pmid', 'sent_type', 'sent_id', 'sentence']]
    ref_sents_df.to_csv(REF_SENTS_DF_PATH, sep='\t', index=False)

# These frames are not used again in the cells visible below; drop the names to free memory.
logging.info('Cleanup memory')
del citations_df
del review_files_df

display(ref_sents_df.head())

# Rouge
The `get_rouge` function computes the similarity between two sentences as the mean of their `rouge-1` and `rouge-2` F-scores.
`rouge-l` is another possible option.

In [None]:
from rouge import Rouge
from transformers import BertTokenizer

ROUGE_METER = Rouge()
TOKENIZER = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)


def get_rouge(sent1, sent2):
    """Score the similarity of two sentences as the mean ROUGE-1 / ROUGE-2 F-score, times 100.

    Both sentences are tokenized with ``TOKENIZER``; only tokens that are purely alphabetic
    or that occur inside the string ``'.!,?'`` are kept before scoring.

    :param sent1: first sentence (str) or None
    :param sent2: second sentence (str) or None
    :return: score scaled by 100, or None when either sentence is None, tokenizes to
        nothing, or has no tokens left after filtering
    """
    if sent1 is None or sent2 is None:
        return None

    cleaned = []
    for raw in (sent1, sent2):
        tokens = TOKENIZER.tokenize(raw)
        if len(tokens) == 0:
            return None
        cleaned.append(" ".join(tok for tok in tokens if tok.isalpha() or tok in '.!,?'))

    first, second = cleaned
    if len(first) == 0 or len(second) == 0:
        return None

    scores = ROUGE_METER.get_scores(first, second)[0]
    # ROUGE-L is deliberately left out of the average.
    f_scores = [scores['rouge-1']["f"], scores['rouge-2']["f"]]
    return np.mean(f_scores) * 100


def mean_rouge(sent, text):
    """Return the mean ``get_rouge`` score of ``sent`` against every sentence of ``text``.

    :param sent: sentence to score
    :param text: sized collection of reference sentences
    :return: mean score, or None when ``text`` is empty, when any pairwise score is
        undefined (``get_rouge`` returned None), or when scoring raised an exception
    """
    if len(text) == 0:
        return None
    try:
        total = 0
        for ref_sent in text:
            score = get_rouge(sent, ref_sent)
            # An undefined pairwise score makes the mean undefined; handle it explicitly
            # instead of relying on a TypeError from adding None.
            if score is None:
                return None
            total += score
        return total / len(text)
    except Exception as e:
        # Best-effort helper: scoring failures are logged and reported as "no score".
        logging.error(f'Exception at mean_rouge {e}')
        return None


def min_rouge(sent, text):
    """Return the smallest ``get_rouge`` score of ``sent`` against the sentences of ``text``.

    :param sent: sentence to score
    :param text: iterable of reference sentences
    :return: minimal score, or None when ``text`` is empty, when any pairwise score is
        undefined (``get_rouge`` returned None), or when scoring raised an exception
    """
    try:
        lowest = None
        for ref_sent in text:
            score = get_rouge(sent, ref_sent)
            # Handle an undefined score explicitly instead of relying on a TypeError
            # from comparing None with a number.
            if score is None:
                return None
            if lowest is None or score < lowest:
                lowest = score
        return lowest
    except Exception as e:
        # Best-effort helper: scoring failures are logged and reported as "no score".
        logging.error(f'Exception at min_rouge {e}')
        return None


def max_rouge(sent, text):
    """Return the largest ``get_rouge`` score of ``sent`` against the sentences of ``text``.

    :param sent: sentence to score
    :param text: iterable of reference sentences
    :return: maximal score, or None when ``text`` is empty, when any pairwise score is
        undefined (``get_rouge`` returned None), or when scoring raised an exception
    """
    try:
        highest = None
        for ref_sent in text:
            score = get_rouge(sent, ref_sent)
            # Handle an undefined score explicitly instead of relying on a TypeError
            # from comparing None with a number.
            if score is None:
                return None
            if highest is None or score > highest:
                highest = score
        return highest
    except Exception as e:
        # Best-effort helper: scoring failures are logged and reported as "no score".
        logging.error(f'Exception at max_rouge {e}')
        return None

In [None]:
# Quick manual check of get_rouge on a pair of short sentences.
get_rouge("I am scout. True!", "No you are not a scout.")

# Build train / test / validation datasets

To create a dataset with features use `preprocess_paper_with_features`. Otherwise, use `preprocess_paper`.

In case datasets are not yet created and `ref_sents_df` is also not yet created, let's create `ref_sents_df`.\
For each paper pmid there is a list of sentences from review papers in which the paper with this `pmid` is cited.

In [None]:
from unidecode import unidecode
from nltk.tokenize import sent_tokenize
import re

# Substring replacements applied by sent_standardize, one after another in the order
# listed here, so a longer key only sees text in which the shorter keys above it
# have already been replaced.
# NOTE(review): sent_standardize calls unidecode() before applying this table, so the
# non-ASCII keys presumably never match at that point - confirm.
REPLACE_SYMBOLS = {
    '—': '-',
    '–': '-',
    '―': '-',
    '…': '...',
    '´´': "´",
    '´´´': "´´",
    "''": "'",
    "'''": "'",
    "``": "`",
    "```": "`",
    ":": " : ",
}


def parse_sents(data):
    """Split every text of ``data`` into sentences and keep those longer than 3 characters.

    :param data: iterable of texts (str)
    :return: flat list of sentences, in the order they occur
    """
    # A single flattening comprehension avoids the quadratic cost of sum(list_of_lists, []).
    return [sent for text in data for sent in sent_tokenize(text) if len(sent) > 3]


def sent_standardize(sent):
    """Normalize a single sentence.

    The sentence is passed through ``unidecode``, ``xref_...`` citation markers are
    replaced by a space, the substitutions of ``REPLACE_SYMBOLS`` are applied in
    dictionary order, and surrounding whitespace is stripped.

    :param sent: raw sentence (str)
    :return: normalized sentence (str)
    """
    xref_patterns = (
        r"\[(xref_\w*_\w\d*]*)(, xref_\w*_\w\d*)*\]",  # bracketed list: [xref, xref, ...]
        r"\( (xref_\w*_\w\d*)(; xref_\w*_\w\d*)* \)",  # parenthesised list: ( xref; xref )
        r"\[xref_\w*_\w\d*\]",  # single bracketed marker: [xref]
        r"xref_\w*_\w\d*",  # any remaining bare marker
    )
    cleaned = unidecode(sent)
    for pattern in xref_patterns:
        cleaned = re.sub(pattern, " ", cleaned)
    for old, new in REPLACE_SYMBOLS.items():
        cleaned = cleaned.replace(old, new)
    return cleaned.strip()


def standardize(text):
    """Normalize every sentence of ``text`` and drop results of 3 characters or fewer.

    :param text: iterable of raw sentences
    :return: list of normalized sentences, in the original order
    """
    cleaned = []
    for sent in text:
        normalized = sent_standardize(sent)
        if len(normalized) > 3:
            cleaned.append(normalized)
    return cleaned


PAPER_MIN_SENTENCES = 50
PAPER_MAX_SENTENCES = 100


def preprocess_paper(
        paper_id, sentences_df, ref_sents_df,
        paper_min_sents=PAPER_MIN_SENTENCES, paper_max_sents=PAPER_MAX_SENTENCES):
    """Build the training target for one paper: a score for each of its sentences.

    The score of a sentence is its mean ``get_rouge`` similarity to all sentences that
    cite the paper (rows of ``ref_sents_df`` whose ``ref_pmid`` equals ``paper_id``).

    :param paper_id: pmid of the paper
    :param sentences_df: frame with at least ``pmid`` and ``sentence`` columns
    :param ref_sents_df: frame with at least ``ref_pmid`` and ``sentence`` columns
    :param paper_min_sents: papers with fewer standardized sentences are skipped
    :param paper_max_sents: longer papers are cut to their first and last
        ``paper_min_sents`` sentences
    :return: ``(sentences, scores)`` lists of equal length, or None when the paper is
        too short, has no citing sentences, or some pairwise score is undefined
    """
    paper = standardize(sentences_df[sentences_df['pmid'] == paper_id]['sentence'])
    if len(paper) < paper_min_sents:
        return None

    ref_sents = standardize(ref_sents_df[ref_sents_df['ref_pmid'] == paper_id]['sentence'])
    if len(ref_sents) == 0:
        # Without citing sentences the mean is undefined (this used to be a ZeroDivisionError).
        return None

    if len(paper) > paper_max_sents:
        # NOTE(review): head and tail are both cut with paper_min_sents; with the default
        # limits (50 / 100) this yields exactly paper_max_sents sentences - confirm the
        # intent for other limits.
        paper = list(paper[:paper_min_sents]) + list(paper[-paper_min_sents:])

    preprocessed_score = []
    for sent in paper:
        total = 0
        for ref_sent in ref_sents:
            rouge = get_rouge(sent, ref_sent)
            if rouge is None:
                # An undefined pairwise score makes the target undefined (this used to be a TypeError).
                return None
            total += rouge
        preprocessed_score.append(total / len(ref_sents))
    return paper, preprocessed_score


def preprocess_paper_with_features(paper_id, sentences_df, ref_sents_df,
                                   abstracts_df, figures_df, reverse_ref_df, tables_df,
                                   paper_min_sents=PAPER_MIN_SENTENCES, paper_max_sents=PAPER_MAX_SENTENCES):
    """Build training targets and per-sentence features for one paper.

    The target of a sentence is its mean ``get_rouge`` similarity to the sentences citing
    the paper. The feature tuple of a sentence is
    ``(sent_id, is_general, r_abs, num_refs, mean_r_fig, mean_r_tab, min_r_fig, min_r_tab,
    max_r_fig, max_r_tab)``: rouge against the first abstract sentence, number of matching
    ``reverse_ref_df`` rows, and mean / min / max rouge against figure and table captions
    (each of which may be None, e.g. when the paper has no captions).

    :param paper_id: pmid of the paper
    :param sentences_df: frame with ``pmid``, ``sentence``, ``sent_id`` and ``type`` columns
    :param ref_sents_df: frame with ``ref_pmid`` and ``sentence`` columns
    :param abstracts_df: frame with ``pmid`` and ``abstract`` columns
    :param figures_df: frame with ``pmid`` and ``caption`` columns
    :param reverse_ref_df: frame with ``pmid``, ``sent_type`` and ``sent_id`` columns
    :param tables_df: frame with ``pmid`` and ``caption`` columns
    :param paper_min_sents: papers with fewer standardized sentences are skipped
    :param paper_max_sents: longer papers are cut to their first and last
        ``paper_min_sents`` sentences
    :return: ``(sentences, scores, features)`` lists of equal length, or None when the
        paper is too short, has no usable abstract, or some target score is undefined
    """
    paper_rows = sentences_df[sentences_df['pmid'] == paper_id]

    # Standardize sentence by sentence so that a sentence dropped as too short also drops
    # its id and type; filtering only the sentence list would shift the remaining
    # sentences against their ids / types.
    kept = []
    for raw_sent, sent_type, sent_id in zip(paper_rows['sentence'], paper_rows['type'], paper_rows['sent_id']):
        sent = sent_standardize(raw_sent)
        if len(sent) > 3:
            kept.append((sent, sent_type, sent_id))

    if len(kept) < paper_min_sents:
        return None

    if len(kept) > paper_max_sents:
        # NOTE(review): head and tail are both cut with paper_min_sents; with the default
        # limits (50 / 100) this yields exactly paper_max_sents sentences - confirm the
        # intent for other limits.
        kept = kept[:paper_min_sents] + kept[-paper_min_sents:]

    ref_sents = standardize(ref_sents_df[ref_sents_df['ref_pmid'] == paper_id]['sentence'])
    fig_captions = standardize(figures_df[figures_df['pmid'] == paper_id]['caption'])
    tab_captions = standardize(tables_df[tables_df['pmid'] == paper_id]['caption'])

    abstract = standardize(abstracts_df[abstracts_df['pmid'] == paper_id]['abstract'])
    if len(abstract) == 0:
        # The r_abs feature needs an abstract; skip the paper explicitly instead of
        # failing on abstract[0].
        return None

    paper_refs = reverse_ref_df[reverse_ref_df['pmid'] == paper_id]

    papers = []
    preprocessed_score = []
    features = []
    for sent, sent_type, sent_id in kept:
        score = mean_rouge(sent, ref_sents)
        if score is None:
            return None

        r_abs = get_rouge(sent, abstract[0])
        num_refs = len(paper_refs[(paper_refs['sent_type'] == sent_type) & (paper_refs['sent_id'] == sent_id)])
        papers.append(sent)
        preprocessed_score.append(score)
        features.append((sent_id, int(sent_type == "general"), r_abs, num_refs,
                         mean_rouge(sent, fig_captions), mean_rouge(sent, tab_captions),
                         min_rouge(sent, fig_captions), min_rouge(sent, tab_captions),
                         max_rouge(sent, fig_captions), max_rouge(sent, tab_captions)))
    return papers, preprocessed_score, features

In [None]:
FEATURES_NUMBER = 10


def process_reference_sentences_dataset(sentences_df, ref_sents_df, additional_features):
    """Preprocess every paper that has both sentences and citing sentences into one frame.

    Papers whose preprocessing raises or returns None are skipped with a warning.
    When ``additional_features`` is truthy the module-level ``abstracts_df``,
    ``figures_df``, ``reverse_ref_df`` and ``tables_df`` frames are read as well
    (they are not parameters of this function).

    :param sentences_df: frame with the sentences of all papers (``pmid`` column required)
    :param ref_sents_df: frame with citing sentences (``ref_pmid`` column required)
    :param additional_features: whether to compute per-sentence features
    :return: DataFrame with ``pmid``, ``sentence``, ``score`` columns plus one column per
        feature when features were computed; columns without any value are omitted
    """
    inter = set(sentences_df['pmid'].values) & set(ref_sents_df['ref_pmid'].values)
    res = {}
    for pmid in tqdm(inter):
        try:
            temp = (
                preprocess_paper_with_features(
                    pmid, sentences_df, ref_sents_df, abstracts_df, figures_df, reverse_ref_df, tables_df
                )
                if additional_features
                else preprocess_paper(pmid, sentences_df, ref_sents_df)
            )
        except Exception as e:
            logging.warning(f'Error during processing {pmid} {e}')
            continue
        if temp is None:
            logging.warning(f'temp is None for {pmid}')
            continue
        res[pmid] = temp
        # Progress line showing the mean target score of the paper just processed.
        print(f"\r{pmid} {np.mean(temp[1])}", end="")

    logging.info(f'Successfully preprocessed {len(res)} of {len(inter)} papers')

    logging.info('Creating train dataset')
    feature_names = [
        'sent_id', 'sent_type', 'r_abs', 'num_refs',
        'mean_r_fig', 'mean_r_tab', 'min_r_fig', 'min_r_tab', 'max_r_fig', 'max_r_tab'
    ]
    assert len(feature_names) == FEATURES_NUMBER
    # Column order: identifying columns first, then the features in feature_names order.
    train_dic = {column: [] for column in ['pmid', 'sentence', 'score'] + feature_names}

    for pmid, stat in tqdm(res.items()):
        has_features = len(stat) != 2
        for row in zip(*stat):
            if has_features:
                sent, score, features = row
            else:
                sent, score = row
                features = ()
            train_dic['pmid'].append(pmid)
            train_dic['sentence'].append(sent)
            train_dic['score'].append(score)
            for name, val in zip(feature_names, features):
                train_dic[name].append(val)

    # Columns that never received a value (e.g. features when they are disabled) are dropped.
    filled_columns = {k: v for k, v in train_dic.items() if v}
    train_df = pd.DataFrame(filled_columns)
    logging.info(f'Full train dataset {len(train_df)}')
    return train_df

In [None]:
TRAIN_DATASET_PATH = f'{os.path.expanduser(cfg.base_path)}/dataset.csv'
! rm {TRAIN_DATASET_PATH}

# NOTE(review): the `! rm` shell line right above removes the cached CSV first, so this
# existence check presumably never succeeds and the dataset is rebuilt on every run -
# confirm whether the forced rebuild is intended.
if os.path.exists(TRAIN_DATASET_PATH):
    train_df = pd.read_csv(TRAIN_DATASET_PATH)
else:
    train_df = process_reference_sentences_dataset(
        sentences_df, ref_sents_df, additional_features=ADDITIONAL_FEATURES
    )
    # Cache the result so later runs can load it instead of recomputing.
    train_df.to_csv(TRAIN_DATASET_PATH, index=False)

In [None]:
display(train_df.head())

In [None]:
import matplotlib.pyplot as plt
from collections import Counter

# NOTE(review): Counter maps each distinct score value to its number of occurrences, so
# the histogram below shows the distribution of those occurrence counts rather than of
# the scores themselves - confirm this is what the plot is meant to show.
res = Counter(list(train_df['score'].values))

plt.hist(res.values(), bins=range(2, 20))
plt.title('Train dataset scores')
plt.show()

# Splitting data into train/test/val

In [None]:
from sklearn.model_selection import train_test_split

# Split by paper id, not by sentence, so all sentences of a paper land in the same subset.
# 80% of the papers go to train; the remaining 20% is split 60/40 into test and validation.
train_ids, test_ids = train_test_split(list(set(train_df['pmid'].values)), test_size=0.2)
test_ids, val_ids = train_test_split(test_ids, test_size=0.4)

train = train_df[train_df['pmid'].isin(train_ids)]
logging.info(f'Train {len(train)}')
display(train.head(1))

test = train_df[train_df['pmid'].isin(test_ids)]
logging.info(f'Test {len(test)}')
display(test.head(1))

val = train_df[train_df['pmid'].isin(val_ids)]
logging.info(f'Validate {len(val)}')
display(val.head(1))


# Inspect pretrained models: BERT and Roberta

In [None]:
from transformers import BertModel, RobertaModel

# Load the pretrained BERT backbone only to inspect it: number of trainable parameters,
# parameter names and module structure.
backbone = BertModel.from_pretrained(
    "bert-base-uncased", output_hidden_states=False
)
print('BERT pretrained')
print(f'Parameters {sum(p.numel() for p in backbone.parameters() if p.requires_grad)}')
print(', '.join(n for n, p in backbone.named_parameters()))
print(backbone)

# backbone = RobertaModel.from_pretrained(
#     'roberta-base', output_hidden_states=False
# )
# print('ROBERTA pretrained')
# print(f'Parameters {sum(p.numel() for p in backbone.parameters() if p.requires_grad)}')
# print(', '.join(n for n, p in backbone.named_parameters()))
# print(backbone)

# Preprocessing text for BERT

In [None]:
def preprocess_paper_bert(text, max_len, tokenizer):
    """
    Encode a list of sentences as one fixed-length input for a BERT / RoBERTa model.

    Every sentence is wrapped as ``artBOS + tokens + artEOS``. Sentences are appended in
    order until the next one would exceed ``max_len``; that sentence and all following
    ones are left out. The result is padded up to ``max_len``.
    :param text: iterable of sentences (str)
    :param max_len: length of the encoded sequence
    :param tokenizer: tokenizer exposing ``tokenize``, ``convert_tokens_to_ids`` and the
        ``artBOS`` / ``artEOS`` / ``PAD`` special tokens
    :return:
        ids | token ids, padded with ``tokenizer.PAD.idx``
        attention_mask | list(int), 1 for real tokens and 0 for padding
        token_type_ids | list(int), alternating 0 / 1 from sentence to sentence
        n_sents | number of sentences actually encoded
    """
    sents = [
        [tokenizer.artBOS.tkn] + tokenizer.tokenize(sent) + [tokenizer.artEOS.tkn] for sent in text
    ]
    logging.debug(f'sents {sents}')

    ids = []
    token_type_ids = []
    segment_signature = 0
    n_sents = 0
    for i, s in enumerate(sents):
        logging.debug(f'sentence {i} {s}')
        logging.debug(f'ids {len(ids)}')
        logging.debug(f'segments {len(token_type_ids)}')
        logging.debug(f'segment_signature {segment_signature}')
        if len(ids) + len(s) > max_len:
            # The first sentence that does not fit ends the encoding.
            logging.debug(f'break, len(s)={len(s)}')
            break
        n_sents += 1
        ids.extend(tokenizer.convert_tokens_to_ids(s))
        token_type_ids.extend([segment_signature] * len(s))
        # Flip the segment id between 0 and 1 for the next sentence.
        segment_signature = 1 - segment_signature
    attention_mask = [1] * len(ids)

    logging.debug('Padding data')
    pad_len = max(0, max_len - len(ids))
    ids += [tokenizer.PAD.idx] * pad_len
    attention_mask += [0] * pad_len
    # Padding positions reuse the segment id the next sentence would have received.
    token_type_ids += [segment_signature] * pad_len
    assert len(ids) == len(attention_mask)
    assert len(ids) == len(token_type_ids)
    return ids, attention_mask, token_type_ids, n_sents

In [None]:
print('Inspect preprocessing text for Bert')

# NOTE(review): Summarizer is defined in a later cell of this notebook, so that cell has
# to be executed before this one.
tokenizer = Summarizer.initialize_bert_tokenizer()
# Encode the first five sentences of the dataset as a single 512-token input.
text = list(train_df.head(5)['sentence'])
input_ids, attention_mask, token_type_ids, n_sents = preprocess_paper_bert(text, 512, tokenizer)

print(f'input_ids {input_ids}')
print(f'attention_mask {attention_mask}')
print(f'token_type_ids {token_type_ids}')
print(f'n_sents {n_sents}')

# Main model classes

The model in main pubtrends application is loaded using `load_model` function from `review.model` module.

It has several options to set up:
* with or without features (right now without features works better),
* `BERT` or `roberta` as basis (no big difference),

You can also choose the `froze_strategy`:
* `froze_all` -- trains only the summarization layer and leaves the BERT layers untouched,
* `unfroze_last` -- also fine-tunes the last BERT layers while keeping training reasonably fast,
* `unfroze_all` -- fine-tunes all BERT weights; training is slow, but the results may be better.

In [None]:
import os
import torch.nn as nn
from collections import namedtuple
from transformers import BertModel, RobertaModel
from transformers import BertTokenizer, RobertaTokenizer

import review.config as cfg

SpecToken = namedtuple('SpecToken', ['tkn', 'idx'])
ConvertToken2Id = lambda tokenizer, tkn: tokenizer.convert_tokens_to_ids([tkn])[0]

EMBEDDINGS_ADDITIONAL = 100

FEATURES_NN_INTERMEDIATE = 100
FEATURES_NN_OUT = 50
FEATURES_DROPOUT = 0.1

BERT_PARAMS_NOT_INITIALIZED = {
    'cls.seq_relationship.weight', 'cls.seq_relationship.bias',
    'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias',
    'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias',
    'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight'
}


class Summarizer(nn.Module):
    """
    This is the main summarization model: a pretrained BERT / RoBERTa backbone followed by
    a scoring head that assigns one score to every sentence-start (``artBOS``) token.
    See https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_tf_bert.py
    It operates the same input format as original BERT used underneath.
    See forward, evaluate params description.
    """
    enc_output: torch.Tensor
    dec_ids_mask: torch.Tensor
    encdec_ids_mask: torch.Tensor

    def __init__(self, model_type, article_len, additional_features, num_features=FEATURES_NUMBER):
        """
        :param model_type: ``'bert'`` or ``'roberta'``
        :param article_len: length of the encoded article, used to build position ids
        :param additional_features: whether per-sentence features are fed to the head
        :param num_features: number of per-sentence input features
        """
        super(Summarizer, self).__init__()

        print(f'Initialize backbone and tokenizer for {model_type}')
        self.article_len = article_len
        if model_type == 'bert':
            self.backbone = self.initialize_bert()
            self.tokenizer = self.initialize_bert_tokenizer()
        elif model_type == 'roberta':
            self.backbone = self.initialize_roberta()
            self.tokenizer = self.initialize_roberta_tokenizer()
        else:
            raise Exception(f"Wrong model_type argument: {model_type}")
        # Reserve extra embedding rows beyond the base vocabulary for added special tokens.
        self.backbone.resize_token_embeddings(EMBEDDINGS_ADDITIONAL + self.tokenizer.vocab_size)

        if additional_features:
            print('Adding additional features double fully connected nn')
            self.features = nn.Sequential(
                nn.Linear(num_features, FEATURES_NN_INTERMEDIATE),
                nn.LeakyReLU(),
                nn.Dropout(FEATURES_DROPOUT),
                nn.Linear(FEATURES_NN_INTERMEDIATE, FEATURES_NN_OUT)
            )
        else:
            self.features = None

        print('Initialize backbone embeddings pulling')

        def backbone_forward(input_ids, attention_mask, token_type_ids, position_ids):
            return self.backbone(
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
                position_ids=position_ids
            )

        # Only the first backbone output (the per-token hidden states) is used.
        self.encoder = lambda *args: backbone_forward(*args)[0]

        print('Initialize decoder')
        if additional_features:
            self.decoder = Classifier(cfg.d_hidden + FEATURES_NN_OUT)
        else:
            self.decoder = Classifier(cfg.d_hidden)

    def expand_positional_embs_if_need(self):
        """Grow the backbone position embeddings when ``article_len`` exceeds their size.

        The existing embedding rows are copied into the new table; the added rows keep
        their fresh ``nn.Embedding`` initialization.
        """
        print('Expand positional embeddings if need')
        print('Positional embeddings', self.backbone.config.max_position_embeddings, self.article_len)
        if self.article_len > self.backbone.config.max_position_embeddings:
            old_maxlen = self.backbone.config.max_position_embeddings
            old_w = self.backbone.embeddings.position_embeddings.weight
            logging.info(f'Backbone pos embeddings expanded from {old_maxlen} upto {self.article_len}')
            self.backbone.embeddings.position_embeddings = nn.Embedding(
                self.article_len, self.backbone.config.hidden_size
            )
            self.backbone.embeddings.position_embeddings.weight[:old_maxlen].data.copy_(old_w)
            self.backbone.config.max_position_embeddings = self.article_len
            print('New positional embeddings', self.backbone.config.max_position_embeddings)

    @staticmethod
    def initialize_bert():
        """Load the pretrained ``bert-base-uncased`` backbone."""
        return BertModel.from_pretrained(
            "bert-base-uncased", output_hidden_states=False
        )

    @staticmethod
    def initialize_bert_tokenizer():
        """Load the ``bert-base-uncased`` tokenizer and attach the special-token attributes."""
        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
        BOS = "[CLS]"
        EOS = "[SEP]"
        PAD = "[PAD]"
        Summarizer._init_tokenizer(tokenizer, BOS, EOS, PAD)
        return tokenizer

    @staticmethod
    def initialize_roberta():
        """Load the pretrained ``roberta-base`` backbone and give it 2 token type embeddings."""
        backbone = RobertaModel.from_pretrained(
            'roberta-base', output_hidden_states=False
        )
        print('initialize token type emb, by default roberta doesnt have it')
        backbone.config.type_vocab_size = 2
        backbone.embeddings.token_type_embeddings = nn.Embedding(2, backbone.config.hidden_size)
        backbone.embeddings.token_type_embeddings.weight.data.normal_(
            mean=0.0, std=backbone.config.initializer_range
        )
        return backbone

    @staticmethod
    def initialize_roberta_tokenizer():
        """Load the ``roberta-base`` tokenizer and attach the special-token attributes."""
        tokenizer = RobertaTokenizer.from_pretrained('roberta-base', do_lower_case=True)
        BOS = "<s>"
        EOS = "</s>"
        PAD = "<pad>"
        Summarizer._init_tokenizer(tokenizer, BOS, EOS, PAD)
        return tokenizer

    @staticmethod
    def _init_tokenizer(tokenizer, BOS, EOS, PAD):
        """Attach ``SpecToken(tkn, idx)`` attributes for all special tokens to ``tokenizer``.

        Sets ``PAD``, ``artBOS`` and ``artEOS`` from the given token strings, registers the
        additional ``<sum>``, ``</sent>`` and ``</sum>`` tokens and exposes them as
        ``sumBOS``, ``sumEOS`` and ``sumEOA``.
        """
        print('Initializing tokenizer with special tokens')
        PAD = SpecToken(PAD, ConvertToken2Id(tokenizer, PAD))
        artBOS = SpecToken(BOS, ConvertToken2Id(tokenizer, BOS))
        artEOS = SpecToken(EOS, ConvertToken2Id(tokenizer, EOS))
        print('Add special tokens to tokenizer')

        tokenizer.add_special_tokens(dict(additional_special_tokens=["<sum>", "</sent>", "</sum>"]))
        sumBOS = SpecToken("<sum>", ConvertToken2Id(tokenizer, "<sum>"))
        sumEOS = SpecToken("</sent>", ConvertToken2Id(tokenizer, "</sent>"))
        sumEOA = SpecToken("</sum>", ConvertToken2Id(tokenizer, "</sum>"))

        print('Configure tokenizer')
        tokenizer.PAD = PAD
        tokenizer.artBOS = artBOS
        tokenizer.artEOS = artEOS
        tokenizer.sumBOS = sumBOS
        tokenizer.sumEOS = sumEOS
        tokenizer.sumEOA = sumEOA
        print('Done initializing tokenizer with special tokens')

    def save(self, save_filename):
        """ Save model weights as ``<cfg.weights_path>/<save_filename>.pth``

        :param save_filename: str, file name without extension
        """
        state = dict(
            encoder_dict=self.backbone.state_dict(),
            decoder_dict=self.decoder.state_dict()
        )
        if self.features is not None:
            state['features_dict'] = self.features.state_dict()
        models_folder = os.path.expanduser(cfg.weights_path)
        # exist_ok avoids the check-then-create race of a separate existence test.
        os.makedirs(models_folder, exist_ok=True)
        torch.save(state, f"{models_folder}/{save_filename}.pth")

    def load(self, load_filename):
        """ Load model weights from ``<cfg.weights_path>/<load_filename>.pth``

        :param load_filename: str, file name without extension
        """
        path = f"{os.path.expanduser(cfg.weights_path)}/{load_filename}.pth"
        # Keep every tensor on the storage it is deserialized to, regardless of its original device.
        state = torch.load(path, map_location=lambda storage, location: storage)
        self.backbone.load_state_dict(state['encoder_dict'])
        self.decoder.load_state_dict(state['decoder_dict'])
        if self.features is not None:
            self.features.load_state_dict(state['features_dict'])

    def froze_backbone(self, froze_strategy):
        """Select which backbone parameters are trainable.

        :param froze_strategy: one of
            ``'froze_all'`` (no backbone parameter is trained),
            ``'unfroze_last'`` (only encoder layers 9, 10 and 11 are trained),
            ``'unfroze_all'`` (every backbone parameter is trained)
        """
        # Validate before touching any parameter so that a bad argument changes nothing.
        assert froze_strategy in ['froze_all', 'unfroze_last',
                                  'unfroze_all'], f"incorrect froze_strategy argument: {froze_strategy}"

        if froze_strategy == 'froze_all':
            for param in self.backbone.parameters():
                param.requires_grad_(False)

        elif froze_strategy == 'unfroze_last':
            for name, param in self.backbone.named_parameters():
                param.requires_grad_(
                    'encoder.layer.11' in name or
                    'encoder.layer.10' in name or
                    'encoder.layer.9' in name
                )

        elif froze_strategy == 'unfroze_all':
            for param in self.backbone.parameters():
                param.requires_grad_(True)

    def unfroze_head(self):
        """Make every parameter of the scoring head trainable."""
        for param in self.decoder.parameters():
            param.requires_grad_(True)

    def _encode(self, input_ids, attention_mask, token_type_ids):
        """Run the backbone; shared by ``forward`` and ``evaluate``.

        :return: tuple ``(cls_mask, enc_output)`` where ``cls_mask`` marks the ``artBOS``
            positions of ``input_ids`` and ``enc_output`` holds the backbone embeddings
        """
        cls_mask = (input_ids == self.tokenizer.artBOS.idx)

        # Indices of positions of each input sequence tokens in the position embeddings.
        # position ids | torch.Size([batch_size, article_len])
        pos_ids = torch.arange(
            0,
            self.article_len,
            dtype=torch.long,
            device=input_ids.device
        ).unsqueeze(0).repeat(len(input_ids), 1)

        # extract bert embeddings | torch.Size([batch_size, article_len, d_bert])
        enc_output = self.encoder(input_ids, attention_mask, token_type_ids, pos_ids)
        return cls_mask, enc_output

    def forward(self, input_ids, attention_mask, token_type_ids, input_features=None):
        """
        :param input_ids: torch.Size([batch_size, article_len])
        Indices of input sequence tokens in the vocabulary.
        :param attention_mask: torch.Size([batch_size, article_len])
        Mask to avoid performing attention on padding token indices.
        Mask values selected in `[0, 1]`:
        - 1 for tokens that are **not masked**,
        - 0 for tokens that are **masked**.
        :param token_type_ids: torch.Size([batch_size, article_len])
        Segment token indices to indicate first and second portions of the inputs.
        Indices are selected in `[0, 1]`:
        - 0 corresponds to a *sentence A* token,
        - 1 corresponds to a *sentence B* token.
        :param input_features: per-sentence features, one row per ``artBOS`` token of the
        batch in batch order; only used when the model was built with additional features
        :return: scores | 1-D tensor with one score per ``artBOS`` token of the whole batch
        """
        cls_mask, enc_output = self._encode(input_ids, attention_mask, token_type_ids)

        if self.features is not None:
            out_features = self.features(input_features)
            scores = self.decoder(torch.cat([enc_output[cls_mask], out_features], dim=-1))
        else:
            scores = self.decoder(enc_output[cls_mask])

        return scores

    def evaluate(self, input_ids, attention_mask, token_type_ids, input_features=None):
        """See forward for parameters description.

        :return: list with one 1-D score tensor per batch element
        :raises ValueError: when the number of feature rows differs from the number of
        ``artBOS`` tokens in the batch
        """
        cls_mask, enc_output = self._encode(input_ids, attention_mask, token_type_ids)

        out_features = None
        if self.features is not None:
            out_features = self.features(input_features)
            n_cls = int(cls_mask.sum())
            if out_features.shape[0] != n_cls:
                raise ValueError(
                    f"Expected {n_cls} feature rows (one per sentence), got {out_features.shape[0]}"
                )

        scores = []
        offset = 0
        for eo, cm in zip(enc_output, cls_mask):
            sent_embs = eo[cm]
            if out_features is not None:
                # Feature rows follow batch order, so every sample owns the next
                # `n_sents` rows; concatenating the whole feature matrix here would only
                # line up for a batch of size one.
                n_sents = sent_embs.shape[0]
                sent_embs = torch.cat([sent_embs, out_features[offset:offset + n_sents]], dim=-1)
                offset += n_sents
            scores.append(self.decoder.evaluate(sent_embs))
        return scores


class Classifier(nn.Module):
    """Scoring head: a linear projection to a single value followed by a sigmoid."""

    def __init__(self, hidden_size):
        """
        :param hidden_size: size of the last dimension of the input
        """
        super().__init__()
        self.linear = nn.Linear(hidden_size, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Map ``x`` of shape ``[..., hidden_size]`` to scores of shape ``[...]``."""
        logits = self.linear(x)
        return self.sigmoid(logits.squeeze(-1))

    def evaluate(self, x):
        """Compute the same scores as ``forward``."""
        projected = torch.squeeze(self.linear(x), dim=-1)
        return self.sigmoid(projected)


def create_model(model_type, froze_strategy, article_len, additional_features):
    """Build a ``Summarizer``, apply the freezing strategy and report trainable parameters.

    :param model_type: backbone type passed to ``Summarizer``
    :param froze_strategy: strategy passed to ``Summarizer.froze_backbone``
    :param article_len: length of the encoded article
    :param additional_features: whether the model uses per-sentence features
    :return: the configured ``Summarizer``
    """

    def trainable_params(module):
        # Number of scalar parameters of the module that require gradients.
        return sum(p.numel() for p in module.parameters() if p.requires_grad)

    model = Summarizer(model_type, article_len, additional_features)
    model.expand_positional_embs_if_need()
    # An intermediate checkpoint could be restored at this point: model.load('temp')
    model.froze_backbone(froze_strategy)
    model.unfroze_head()
    if additional_features:
        print('Parameters for features NN', trainable_params(model.features))
    print('Parameters for backbone', trainable_params(model.backbone))
    print('Parameters for classifier', trainable_params(model.decoder))
    return model

## Train and evaluate functions

In [None]:
import torch.nn as nn
from torch.optim.optimizer import Optimizer
import math


def get_enc_lr(optimizer):
    """Return the learning rate of the optimizer's first parameter group."""
    first_group = optimizer.param_groups[0]
    return first_group['lr']


def get_dec_lr(optimizer):
    """Return the learning rate of the optimizer's second parameter group."""
    second_group = optimizer.param_groups[1]
    return second_group['lr']


def backward_step(loss: torch.Tensor, optimizer: Optimizer, model: nn.Module, clip: float):
    """Backpropagate ``loss`` and clip the gradient norm of ``model`` to ``clip``.

    The optimizer is accepted but not used here; no optimizer step is taken.

    :param loss: scalar loss tensor
    :param optimizer: unused
    :param model: module whose parameter gradients are clipped in place
    :param clip: maximal allowed total gradient norm
    :return: total gradient norm as returned by ``clip_grad_norm_``
    """
    loss.backward()
    return nn.utils.clip_grad_norm_(model.parameters(), clip)

def batch_to_device(batch, additional_features, device):
    """Unpack a dataloader batch and move its tensors to ``device``.

    Batch items that are tensors are moved; other items are passed through unchanged.
    With ``additional_features`` the batch holds a fifth item, a sequence of feature
    tensors, which is concatenated into one tensor on ``device``.

    :param batch: sequence of 4 items (5 with additional features)
    :param additional_features: whether the batch carries per-sentence features
    :param device: target device
    :return: ``(input_ids, attention_mask, token_type_ids, target_scores, input_features)``;
        ``input_features`` is None without additional features
    """
    moved = []
    for item in batch:
        moved.append(item.to(device) if isinstance(item, torch.Tensor) else item)

    if not additional_features:
        input_ids, attention_mask, token_type_ids, target_scores = moved
        return input_ids, attention_mask, token_type_ids, target_scores, None

    input_ids, attention_mask, token_type_ids, target_scores, feature_parts = moved
    input_features = torch.cat(feature_parts).to(device)
    return input_ids, attention_mask, token_type_ids, target_scores, input_features

def train_fun(
        model,
        dataloader,
        optimizer,
        scheduler,
        criter,
        device,
        writer,
        additional_features
):
    """Run one training epoch over `dataloader`.

    Gradients are accumulated over `optimizer.accumulation_interval` batches
    (and flushed on the last batch); they are cleared only after an optimizer
    step, so accumulation actually takes effect. The scheduler is stepped on
    every batch. Per-batch loss, gradient norm and both learning rates are
    written to `writer` at `writer.train_step`, which is incremented per batch.
    At the end the model is saved under the name 'temp'.

    :param model: model called as `model(input_ids, attention_mask, token_type_ids, input_features)`
    :param dataloader: iterable of collated batches with a defined `len()`
    :param optimizer: optimizer carrying `clip_value` and `accumulation_interval` attributes
    :param scheduler: learning-rate scheduler, stepped once per batch
    :param criter: loss callable `criter(model_scores, target_scores)`
    :param device: torch device for the batch tensors
    :param writer: summary writer with `add_scalar` and a `train_step` counter
    :param additional_features: whether batches carry additional features
    :return: (model, optimizer, scheduler, writer)
    :raises Exception: any error raised by the loss is logged with the batch
        shapes and re-raised
    """
    # Put the model into training mode. Don't be misled--the call to
    # `train` just changes the *mode*, it doesn't *perform* the training.
    # `dropout` and `batchnorm` layers behave differently during training
    # vs. test (source: https://stackoverflow.com/questions/51433378/what-does-model-train-do-in-pytorch)
    model.train()
    loss_val = 0
    target_val = 0
    n_batches = len(dataloader)

    # Start the epoch from clean gradients. PyTorch accumulates gradients across
    # backward passes, so from here on they are cleared only right after an
    # optimizer step; clearing them on every batch would defeat accumulation.
    model.zero_grad()

    for idx_batch, batch in enumerate(dataloader):
        input_ids, attention_mask, token_type_ids, target_scores, input_features = \
            batch_to_device(batch, additional_features, device)

        # forward pass
        logging.debug(f'batch {idx_batch}')
        logging.debug('forward')
        logging.debug(f'input_ids {input_ids.shape}')
        logging.debug(f'attention_mask {attention_mask.shape}')
        logging.debug(f'token_type_ids {token_type_ids.shape}')
        if additional_features:
            logging.debug(f'input_features {input_features.shape}')

        model_scores = model(input_ids, attention_mask, token_type_ids, input_features)
        logging.debug(f'models_scores {model_scores.shape}')
        logging.debug(f'target_scores {len(target_scores)}')

        # Per-example targets are concatenated to line up with the model output.
        target_scores = torch.cat(target_scores).to(device)
        target_val += target_scores.mean().item()

        try:
            loss = criter(model_scores, target_scores)
        except Exception:
            # Typically a shape mismatch: record the context, then fail loudly
            # instead of returning None to a caller that unpacks four values.
            logging.exception(
                f'Loss computation failed on batch {idx_batch}: '
                f'model_scores {model_scores.shape}, target_scores {target_scores.shape}, '
                f'token_type_ids {token_type_ids}'
            )
            raise
        batch_loss = loss.item()
        loss_val += batch_loss
        logging.debug(f'loss {loss}')

        # backward
        logging.debug('backward')
        grad_norm = float(backward_step(loss, optimizer, model, optimizer.clip_value))
        if math.isinf(grad_norm) or math.isnan(grad_norm):
            grad_norm = 0

        # record a loss value
        logging.debug(f'{idx_batch} / {n_batches} train loss {batch_loss}')
        print(f'\r{idx_batch} / {n_batches} train loss {batch_loss}', end='')
        writer.add_scalar("Train/loss", batch_loss, writer.train_step)
        writer.add_scalar("Train/grad_norm", grad_norm, writer.train_step)
        writer.add_scalar("Train/lr_enc", get_enc_lr(optimizer), writer.train_step)
        writer.add_scalar("Train/lr_dec", get_dec_lr(optimizer), writer.train_step)
        writer.train_step += 1

        # make a gradient step
        if (idx_batch + 1) % optimizer.accumulation_interval == 0 or (idx_batch + 1) == n_batches:
            logging.debug('optimizer step')
            optimizer.step()
            optimizer.zero_grad()

        logging.debug('scheduler step')
        scheduler.step()

    # Guard the summary against an empty loader / all-zero targets.
    mean_loss = loss_val / n_batches if n_batches else float('nan')
    relative_loss = 100 * loss_val / target_val if target_val else float('nan')
    print("\rTrain loss:", mean_loss, f"{relative_loss:.5f}%")

    # save model, just in case
    model.save('temp')
    logging.debug('save model to temp')

    return model, optimizer, scheduler, writer


def evaluate_fun(
        model,
        dataloader,
        criter,
        device,
        writer,
        additional_features
):
    """Evaluate `model` on `dataloader` and log the per-batch loss.

    The loss is written to `writer` under "Eval/loss" against its own
    `writer.eval_step` counter, so that evaluation does not advance the
    training step axis. At the end the model is saved under the name
    'validated_weights'.

    :param model: model called as `model(input_ids, attention_mask, token_type_ids, input_features)`
    :param dataloader: iterable of collated batches with a defined `len()`
    :param criter: loss callable `criter(model_scores, target_scores)`
    :param device: torch device for the batch tensors
    :param writer: summary writer with `add_scalar` and an `eval_step` counter
    :param additional_features: whether batches carry additional features
    :return: the model
    """
    # Put the model in evaluation mode--the dropout layers behave differently
    # during evaluation.
    model.eval()
    loss_val = 0
    target_val = 0
    n_batches = len(dataloader)

    for idx_batch, batch in enumerate(dataloader):
        input_ids, attention_mask, token_type_ids, target_scores, input_features = \
            batch_to_device(batch, additional_features, device)
        # Per-example targets are concatenated to line up with the model output.
        target_scores = torch.cat(target_scores).to(device)

        # evaluate pass
        logging.debug('evaluate')
        logging.debug(f'input_ids {input_ids.shape}')
        logging.debug(f'attention_mask {attention_mask.shape}')
        logging.debug(f'token_type_ids {token_type_ids.shape}')
        if additional_features:
            logging.debug(f'input_features {input_features.shape}')

        # Tell pytorch not to bother with constructing the compute graph during
        # the forward pass, since this is only needed for backprop (training).
        with torch.no_grad():
            model_scores = model(input_ids, attention_mask, token_type_ids, input_features)
            loss = criter(model_scores, target_scores)
        logging.debug(f'model_scores {model_scores.shape}')
        logging.debug(f'target_scores {len(target_scores)}')

        batch_loss = loss.item()
        loss_val += batch_loss
        target_val += target_scores.mean().item()

        # record a loss value
        logging.debug(f'{idx_batch} / {n_batches} val loss {batch_loss}')
        print(f'\r{idx_batch} / {n_batches} val loss {batch_loss}', end='')
        writer.add_scalar("Eval/loss", batch_loss, writer.eval_step)
        writer.eval_step += 1

    # Guard the summary against an empty loader / all-zero targets.
    mean_loss = loss_val / n_batches if n_batches else float('nan')
    relative_loss = 100 * loss_val / target_val if target_val else float('nan')
    print("\rValidate loss:", mean_loss, f"{relative_loss:.5f}%")

    # save model, just in case
    model.save('validated_weights')
    logging.debug('save model to validated_weights')

    return model




# Dataset and dataloader classes

In [None]:
from torch.utils.data import Dataset, DataLoader

# Number of sentences by which consecutive chunks of one paper overlap in
# `TrainDataset`; chunks that fit no more than this many sentences are skipped.
OVERLAP = 5


class TrainDataset(Dataset):
    """Training dataset that is fully preprocessed at construction time.

    Every paper (rows sharing a `pmid`) is cut into chunks of sentences that
    `preprocess_paper_bert` can encode in one model run. Consecutive chunks of
    a paper share up to `OVERLAP` sentences. Each stored example is a tuple
    (input_ids, attention_mask, token_type_ids, target_scores[, input_features]).
    """

    def __init__(self, dataframe, tokenizer, article_len, additional_features):
        self.df = dataframe
        self.data = []
        self.pmids = list(set(dataframe['pmid'].values))
        self.tokenizer = tokenizer
        self.article_len = article_len
        self.additional_features = additional_features

        feature_columns = ['sent_id', 'sent_type', 'r_abs', 'num_refs',
                           'mean_r_fig', 'mean_r_tab',
                           'min_r_fig', 'min_r_tab',
                           'max_r_fig', 'max_r_tab']

        # Build the list of model inputs paper by paper.
        for pmid in tqdm(self.pmids):
            paper_rows = self.df[self.df['pmid'] == pmid]
            sentences = paper_rows['sentence'].values
            if additional_features:
                features = np.nan_to_num(paper_rows[feature_columns].values.astype(float))
            else:
                features = None

            # BERT preprocessing cannot encode the whole text at once:
            # only a limited number of sentences fits into a single model run.
            covered = 0
            while covered < len(sentences):
                start = max(0, covered - OVERLAP)
                input_ids, attention_mask, token_type_ids, n_sents = preprocess_paper_bert(
                    sentences[start:], self.article_len, self.tokenizer
                )
                if n_sents > OVERLAP:
                    stop = start + n_sents
                    covered = stop
                    target_scores = paper_rows['score'].values[start: stop] / 100
                    input_features = features[start: stop] if additional_features else None
                    logging.debug(f'Train dataset example {len(self.data)}\n'
                                  f'input_ids {input_ids}\n'
                                  f'attention_mask {attention_mask}\n'
                                  f'token_type_ids {token_type_ids}\n'
                                  f'target_scores {target_scores}\n'
                                  f'features {input_features}')
                    example = (input_ids, attention_mask, token_type_ids, target_scores)
                    if additional_features:
                        example = example + (input_features,)
                    self.data.append(example)
                else:
                    # The chunk adds nothing beyond the overlap: advance by one sentence.
                    covered += 1

        logging.info(f'Train dataset size {len(self.data)}')

    def __getitem__(self, idx):
        return self.data[idx]

    def __len__(self):
        return len(self.data)


class EvalDataset(Dataset):
    """Validation/test dataset with one example per paper, preprocessed lazily.

    Only the leading sentences that `preprocess_paper_bert` reports as encoded
    (`n_sents`) contribute targets and features.
    """

    def __init__(self, dataframe, tokenizer, article_len, additional_features):
        self.df = dataframe
        self.tokenizer = tokenizer
        self.article_len = article_len
        self.additional_features = additional_features
        self.pmids = list(set(dataframe['pmid'].values))
        logging.info(f'Eval dataset size {len(self.pmids)}')

    def __getitem__(self, idx):
        paper_rows = self.df[self.df['pmid'] == self.pmids[idx]]
        sentences = paper_rows['sentence'].values

        if self.additional_features:
            feature_columns = ['sent_id', 'sent_type', 'r_abs',
                               'num_refs', 'mean_r_fig', 'mean_r_tab',
                               'min_r_fig', 'min_r_tab',
                               'max_r_fig', 'max_r_tab']
            features = np.nan_to_num(paper_rows[feature_columns].values.astype(float))
        else:
            features = None

        input_ids, attention_mask, token_type_ids, n_sents = preprocess_paper_bert(
            sentences, self.article_len, self.tokenizer
        )

        # Targets (scaled by 1/100) and features for the encoded sentences only.
        target_scores = paper_rows['score'].values[:n_sents] / 100
        input_features = None
        if self.additional_features:
            input_features = features[:n_sents]

        logging.debug(f'Eval dataset example {idx}\n'
                      f'input_ids {input_ids}\n'
                      f'attention_mask {attention_mask}\n'
                      f'token_type_ids {token_type_ids}\n'
                      f'target_scores {target_scores}\n'
                      f'features {input_features}')

        example = (input_ids, attention_mask, token_type_ids, target_scores)
        if self.additional_features:
            return example + (input_features,)
        return example

    def __len__(self):
        return len(self.pmids)


def create_collate_fn(additional_features):
    """Build the collate function used by the train / eval data loaders.

    :param additional_features: whether dataset examples carry a fifth field with features
    :return: callable turning a list of examples into one batch
    """

    def _as_float_tensors(sequences):
        # One float tensor per example; lengths may differ, so they are not stacked.
        return [torch.tensor(seq, dtype=torch.float) for seq in sequences]

    def _collate_fn(batch_data):
        """
        :param batch_data: list of `TrainDataset` or `EvalDataset` Examples
        :return: one batch of data: three long tensors (input ids, attention mask,
            token type ids), a list of target tensors and, optionally, a list of
            feature tensors
        """
        columns = list(zip(*batch_data))
        batch = [torch.tensor(columns[pos], dtype=torch.long) for pos in range(3)]
        batch.append(_as_float_tensors(columns[3]))
        if additional_features:
            batch.append(_as_float_tensors(columns[4]))
        return batch

    return _collate_fn


# The DataLoader needs to know our batch size for training, so we specify it
# here. For fine-tuning BERT on a specific task, the authors recommend a batch
# size of 16 or 32.
# Passed as `batch_size` to `get_dataloaders`, i.e. used for both the train
# and the validation loader.
BATCH_SIZE = 32


def get_dataloaders(train, val, batch_size, article_len, tokenizer, additional_features):
    """Create the train and validation data loaders.

    Both loaders iterate in dataset order (no shuffling), pin memory and use
    `cfg.num_workers` worker processes.

    :param train: dataframe wrapped into a `TrainDataset`
    :param val: dataframe wrapped into an `EvalDataset`
    :param batch_size: number of examples per batch for both loaders
    :param article_len: forwarded to the datasets
    :param tokenizer: forwarded to the datasets
    :param additional_features: whether examples carry additional features
    :return: (train_dl, val_dl)
    """

    def _make_loader(dataset):
        return DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=False,
            pin_memory=True,
            collate_fn=create_collate_fn(additional_features),
            num_workers=cfg.num_workers,
        )

    logging.info('Creating train dataset...')
    train_ds = TrainDataset(train, tokenizer, article_len, additional_features)
    logging.info('Applying loader functions to train...')
    train_dl = _make_loader(train_ds)

    logging.info('Creating val dataset...')
    val_ds = EvalDataset(val, tokenizer, article_len, additional_features)
    logging.info('Applying loader functions to val...')
    val_dl = _make_loader(val_ds)

    return train_dl, val_dl

# Custom scheduler used in training

In [None]:
from torch.optim.lr_scheduler import _LRScheduler, ExponentialLR


class CustomScheduler(_LRScheduler):
    """Linear warm-up for `warmup` steps, then delegation to an `ExponentialLR`.

    While `timestep < warmup` each group's rate is `timestep * initial_lr / warmup`;
    afterwards both `get_lr` and `step` are forwarded to the wrapped
    `ExponentialLR(optimizer, gamma)`.
    """
    # Class-level default; `step` rebinds it on the instance via `+=`.
    timestep: int = 0

    def __init__(self, optimizer, gamma, warmup=None):
        """
        :param optimizer: wrapped optimizer
        :param gamma: decay factor passed to `ExponentialLR`
        :param warmup: number of warm-up steps; None means no warm-up (0)
        """
        self.optimizer = optimizer
        # NOTE(review): the inner scheduler is created before the initial rates are
        # captured and before the base constructor runs — the statement order here
        # looks deliberate; confirm before reordering.
        self.after_warmup = ExponentialLR(optimizer, gamma=gamma)
        self.initial_lrs = [p_group['lr'] for p_group in self.optimizer.param_groups]
        self.warmup = 0 if warmup is None else warmup
        super(CustomScheduler, self).__init__(optimizer)

    def get_lr(self):
        """Return the warm-up rates while `timestep < warmup`, else the inner scheduler's rates."""
        return [self.timestep * group_init_lr / self.warmup for group_init_lr in
                self.initial_lrs] if self.timestep < self.warmup else self.after_warmup.get_lr()

    def step(self, epoch=None):
        """Advance the warm-up counter, or step the inner scheduler once warm-up is over."""
        if self.timestep < self.warmup:
            self.timestep += 1
            super(CustomScheduler, self).step(epoch)
        else:
            self.after_warmup.step(epoch)


class NoamScheduler(_LRScheduler):
    """
    Noam learning-rate schedule: linear warm-up followed by inverse-square-root decay.

    This is the PyTorch implementation of the schedule introduced in the paper
    "Attention is all you need". Each parameter group's rate is
    `initial_lr * min(t ** -0.5, t * warmup ** -1.5)`, where `t` is the number of
    `step()` calls so far. The `_LRScheduler` constructor performs an initial
    `step()`, so `t` is already 1 when the first rate is computed.
    """

    def __init__(self, optimizer, warmup):
        """
        :param optimizer: wrapped optimizer; its current rates become the base rates
        :param warmup: number of warm-up steps, must be positive
        :raises ValueError: if `warmup` is not positive
        """
        # Explicit check instead of `assert`: asserts are stripped under `python -O`,
        # and a non-positive warmup breaks the `warmup ** -1.5` term.
        if not warmup > 0:
            raise ValueError(f'warmup must be positive, got {warmup!r}')
        self.optimizer = optimizer
        self.initial_lrs = [p_group['lr'] for p_group in self.optimizer.param_groups]
        self.warmup = warmup
        self.timestep = 0
        super(NoamScheduler, self).__init__(optimizer)

    def get_lr(self):
        """Return the scaled learning rate for every parameter group."""
        noam_lr = self.get_noam_lr()
        return [group_init_lr * noam_lr for group_init_lr in self.initial_lrs]

    def get_noam_lr(self):
        """Return the Noam scaling factor for the current `timestep` (expects `timestep >= 1`)."""
        return min(self.timestep ** -0.5, self.timestep * self.warmup ** -1.5)

    def step(self, epoch=None):
        """Advance the step counter, then let the base class apply the new rates."""
        self.timestep += 1
        super(NoamScheduler, self).step(epoch)

# Prepare and configure training

In [None]:
from torch.optim import AdamW
from torch.nn import MSELoss
from tensorboardX import SummaryWriter

# Default learning rates for the two AdamW parameter groups built in
# `prepare_learning_tools` (encoder group first, decoder group second).
ENCODER_LEARNING_RATE = 0.0001
DECODER_LEARNING_RATE = 0.001

# Defaults of `prepare_learning_tools`: scheduler warm-up steps, AdamW weight
# decay, max gradient norm for clipping, and the number of batches between
# optimizer steps.
WARMUP = 5
WEIGHT_DECAY = 0.01
CLIP_VALUE = 1.0
ACCUMULATION_INTERVAL = 1

# Number of training epochs. The BERT authors recommend between 2 and 4 for
# fine-tuning; this notebook currently runs for 10, which may over-fit the
# training data.
EPOCHS_NUMBER = 10


def prepare_learning_tools(
        model,
        enc_lr=ENCODER_LEARNING_RATE,
        dec_lr=DECODER_LEARNING_RATE,
        warmup=WARMUP,
        weight_decay=WEIGHT_DECAY,
        clip_value=CLIP_VALUE,
        accumulation_interval=ACCUMULATION_INTERVAL,
        encoder_prefixes=('bert.', 'backbone.')
):
    """Create the optimizer, scheduler and loss used for training.

    Trainable parameters are split into two disjoint AdamW groups: the encoder
    group (index 0, `enc_lr`) and the decoder group (index 1, `dec_lr`). The
    groups must not overlap — AdamW rejects a parameter that appears in more
    than one group.

    :param model: module whose trainable parameters are optimized
    :param enc_lr: learning rate of the encoder group
    :param dec_lr: learning rate of the decoder group
    :param warmup: warm-up steps for `NoamScheduler`
    :param weight_decay: AdamW weight decay
    :param clip_value: stored as `optimizer.clip_value`
    :param accumulation_interval: stored as `optimizer.accumulation_interval`
    :param encoder_prefixes: parameter-name prefixes that mark encoder parameters.
        NOTE(review): 'backbone.' is included because `create_model` reads the
        encoder via `model.backbone`; confirm against the `Summarizer`
        definition (this also covers the former Roberta TODO if it lives there).
    :return: (optimizer, scheduler, criter)
    """
    prefixes = tuple(encoder_prefixes)
    enc_parameters = []
    dec_parameters = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        # Every trainable parameter lands in exactly one group.
        if name.startswith(prefixes):
            enc_parameters.append(param)
        else:
            dec_parameters.append(param)

    optimizer = AdamW([
        dict(params=enc_parameters, lr=enc_lr),
        dict(params=dec_parameters, lr=dec_lr),
    ], weight_decay=weight_decay)
    # Training settings travel with the optimizer; `train_fun` reads them from there.
    optimizer.clip_value = clip_value
    optimizer.accumulation_interval = accumulation_interval

    scheduler = NoamScheduler(optimizer, warmup=warmup)
    criter = MSELoss()

    return optimizer, scheduler, criter


def load_or_train_model(model, device, additional_features):
    """Load saved weights for `model` if they exist, otherwise train and save it.

    The weights file name depends on `additional_features`. Training reads the
    notebook globals `train`, `val`, `BATCH_SIZE`, `ARTICLE_LENGTH` and
    `EPOCHS_NUMBER`, and logs to a `SummaryWriter` under `cfg.log_path`.

    :param model: model exposing `load`, `save` and `tokenizer`
    :param device: torch device used for training
    :param additional_features: whether the model uses additional features
    :return: the loaded or trained model
    """
    model_name = f'learn_simple_berta_{additional_features}.pth'.lower()
    # Delete the file at this path to force re-training.
    model_path = f'{os.path.expanduser(cfg.weights_path)}/{model_name}'

    if os.path.exists(model_path):
        logging.info(f'Loading model {model_path}')
        model.load(model_name)
        model, _ = setup_cuda_device(model)
        return model

    logging.info('Create dataloaders...')
    train_loader, valid_loader = get_dataloaders(
        train, val, BATCH_SIZE, ARTICLE_LENGTH, model.tokenizer, additional_features
    )

    writer = SummaryWriter(log_dir=os.path.expanduser(cfg.log_path))
    writer.train_step, writer.eval_step = 0, 0
    try:
        optimizer, scheduler, criter = prepare_learning_tools(model)

        logging.info(f"Start training {EPOCHS_NUMBER} epochs...")
        for epoch in tqdm(range(1, EPOCHS_NUMBER + 1)):
            print(f'Epoch {epoch}')
            model, optimizer, scheduler, writer = train_fun(
                model, train_loader, optimizer, scheduler,
                criter, device, writer, additional_features
            )
            model = evaluate_fun(
                model, valid_loader, criter, device, writer, additional_features
            )
    finally:
        # Flush and release the event file even if training is interrupted.
        writer.close()

    logging.info(f"Done training {EPOCHS_NUMBER} epochs...")
    logging.info(f'Save trained model to {model_name}')
    model.save(model_name)

    return model

# Create and train model

In [None]:
# Passed as `article_len` to the model and to `preprocess_paper_bert`.
# Presumably the maximum input length in BERT tokens — TODO confirm.
ARTICLE_LENGTH = 512

# Build the model with the "froze_all" strategy and move it to the GPU when available.
model = create_model("bert", "froze_all", ARTICLE_LENGTH, additional_features=ADDITIONAL_FEATURES)
model, device = setup_cuda_device(model)

In [None]:
# Reuse saved weights when present; otherwise train for EPOCHS_NUMBER epochs and save.
model = load_or_train_model(model, device, additional_features=ADDITIONAL_FEATURES)

# Example of model predictions

In [None]:
print('Prepare data for model')
# Use the first paper of the validation split as the demo example.
ex = val[val['pmid'] == val['pmid'].values[0]]
print('ex', len(ex))
text = ex['sentence'].values

input_ids, attention_mask, token_type_ids, n_sents = preprocess_paper_bert(
    text, ARTICLE_LENGTH, model.tokenizer
)
print('n_sents', n_sents)
# Only the first n_sents sentences were encoded, so only they get a score.
res_sents = text[:n_sents]
scores = ex['score'].values[:n_sents]

input_ids = torch.tensor([input_ids]).to(device)
attention_mask = torch.tensor([attention_mask]).to(device)
token_type_ids = torch.tensor([token_type_ids]).to(device)

if ADDITIONAL_FEATURES:
    features = np.nan_to_num(
        ex[['sent_id', 'sent_type', 'r_abs', 'num_refs',
            'mean_r_fig', 'mean_r_tab',
            'min_r_fig', 'min_r_tab',
            'max_r_fig', 'max_r_tab']].values.astype(float)
    )
    features = torch.tensor(features[:n_sents], dtype=torch.float).to(device)
else:
    features = None

print('Apply model')
# Inference only: a freshly loaded model may still be in training mode, so
# switch dropout off explicitly and skip building the autograd graph.
model.eval()
with torch.no_grad():
    model_scores = model(input_ids, attention_mask, token_type_ids, features)

to_show_df = pd.DataFrame(
    dict(sentence=res_sents, ideal_score=scores / 100, res_score=model_scores.cpu().numpy())
)
display(to_show_df.head())
# Root-mean-square error between target and predicted scores.
print(((to_show_df['ideal_score'].values - to_show_df['res_score'].values) ** 2).mean() ** 0.5)

# Evaluate model performance

In [None]:
# Allows running this section on its own: build and load/train the model only
# if no `model` exists in the notebook namespace yet.
# NOTE(review): ARTICLE_LENGTH and ADDITIONAL_FEATURES must already be defined.
if 'model' not in globals():
    model = create_model("bert", "froze_all", ARTICLE_LENGTH, additional_features=ADDITIONAL_FEATURES)
    model, device = setup_cuda_device(model)
    model = load_or_train_model(model, device, additional_features=ADDITIONAL_FEATURES)

In [None]:
print('Prepare refs_and_scores dataset')
REF_SCORES_PATH = os.path.expanduser(f"{cfg.base_path}/refs_and_scores.csv")
# Uncomment to drop the cached file and force a rebuild. Left enabled, the
# cache below is never hit.
# ! rm {REF_SCORES_PATH}

if os.path.exists(REF_SCORES_PATH):
    final_ref_show_df = pd.read_csv(REF_SCORES_PATH)
else:
    # Attach to every sentence of a paper the sentences that cite that paper.
    to_show_ref = pd.merge(train_df, ref_sents_df[['ref_pmid', 'sentence']],
                           left_on=['pmid'], right_on=['ref_pmid'])
    to_show_ref = to_show_ref.rename(columns=dict(sentence_x='sentence', sentence_y='ref_sentence'))
    to_show_ref = to_show_ref[['pmid', 'sentence', 'ref_sentence', 'score']]

    # One row per (pmid, sentence): citing sentences joined into one string,
    # score taken from the first row of the group. A single groupby replaces
    # filtering the whole frame once per key.
    final_ref_show_df = (
        to_show_ref
        .groupby(['pmid', 'sentence'], sort=False)
        .agg(ref_sentence=('ref_sentence', ' '.join), score=('score', 'first'))
        .reset_index()
    )
    final_ref_show_df.to_csv(REF_SCORES_PATH, index=False)

display(final_ref_show_df)

In [None]:
print('Prepare dataset to estimate performance')
# Keep only the rows of papers that belong to the validation split.
val_pmids = set(val['pmid'].values)
to_test = final_ref_show_df[final_ref_show_df['pmid'].isin(val_pmids)]
if ADDITIONAL_FEATURES:
    # Bring the per-sentence feature columns back from train_df.
    merge_keys = ['pmid', 'sentence']
    feature_columns = ['sent_id', 'sent_type', 'r_abs', 'num_refs',
                       'mean_r_fig', 'mean_r_tab',
                       'min_r_fig', 'min_r_tab',
                       'max_r_fig', 'max_r_tab']
    to_test = pd.merge(
        to_test,
        train_df[merge_keys + feature_columns],
        left_on=merge_keys,
        right_on=merge_keys,
    )
display(to_test)

In [None]:
print('Using model for predictions')
res = dict(pmid=[], sentence=[], ref_sentences=[], score=[], res_score=[])

# Inference only: make sure dropout is off even if the model was just loaded.
model.eval()
for pmid in tqdm(set(to_test['pmid'].values)):
    ex = to_test[to_test['pmid'] == pmid]
    text = ex['sentence'].values
    input_ids, attention_mask, token_type_ids, n_sents = preprocess_paper_bert(
        text, ARTICLE_LENGTH, model.tokenizer
    )
    # Only the first n_sents sentences were encoded, so only they get a score.
    res_sents = text[:n_sents]
    scores = ex['score'].values[:n_sents] / 100
    input_ids = torch.tensor([input_ids]).to(device)
    attention_mask = torch.tensor([attention_mask]).to(device)
    token_type_ids = torch.tensor([token_type_ids]).to(device)

    # No autograd graph is needed for predictions.
    with torch.no_grad():
        if ADDITIONAL_FEATURES:
            features = np.nan_to_num(
                ex[['sent_id', 'sent_type', 'r_abs', 'num_refs',
                    'mean_r_fig', 'mean_r_tab',
                    'min_r_fig', 'min_r_tab',
                    'max_r_fig', 'max_r_tab']].values.astype(float)
            )
            input_features = torch.tensor(features[:n_sents], dtype=torch.float).to(device)
            model_scores = model(input_ids, attention_mask, token_type_ids, input_features)
        else:
            model_scores = model(input_ids, attention_mask, token_type_ids)

    predicted = model_scores.cpu().numpy()
    for sent, sc, res_sc in zip(res_sents, scores, predicted):
        res['pmid'].append(pmid)
        res['sentence'].append(sent)
        # The first row's citing sentences are recorded for every sentence of the paper.
        res['ref_sentences'].append(ex['ref_sentence'].values[0])
        res['score'].append(sc)
        res['res_score'].append(res_sc)

res_df = pd.DataFrame(res)
display(res_df.head())
res_df.to_csv(f"{cfg.base_path}/saved_example_refs.csv")

# The value is the square root of the mean squared error, hence RMSE.
print('RMSE score', ((res_df['score'].values - res_df['res_score'].values) ** 2).mean() ** 0.5)

# Quality analysis

In [None]:
# Allows running the quality analysis on its own: build and load/train the
# model only if no `model` exists in the notebook namespace yet.
# NOTE(review): ARTICLE_LENGTH and ADDITIONAL_FEATURES must already be defined.
if 'model' not in globals():
    model = create_model("bert", "froze_all", ARTICLE_LENGTH, additional_features=ADDITIONAL_FEATURES)
    model, device = setup_cuda_device(model)
    model = load_or_train_model(model, device, additional_features=ADDITIONAL_FEATURES)

In [None]:
print('Searching for review papers')
# Papers that are cited and whose own sentences are available.
papers_with_sentences = set(sentences_df['pmid'].values)
cited_papers = set(ref_sents_df['ref_pmid'].values)
inter = papers_with_sentences & cited_papers
# Papers citing at least one of them.
cites_known_paper = ref_sents_df['ref_pmid'].isin(inter)
review_papers = list(set(ref_sents_df[cites_known_paper]['pmid'].values))
print('Review papers', len(review_papers))

In [None]:
test_stat = dict(rev_pmid=[], sent_num=[], rouge=[], true_rouge=[], diff_papers=[])

# Inference only: make sure dropout is off even if the model was just loaded.
model.eval()
for rev_id in tqdm(review_papers):
    # Sentences of the review itself are the reference for ROUGE.
    paper_ref = sentences_df[sentences_df['pmid'] == rev_id]['sentence'].values
    papers_to_check = list(set(ref_sents_df[ref_sents_df['pmid'] == rev_id]['ref_pmid'].values))
    result = {'pmid': [], 'sentence': [], 'score': []}
    for paper_id in papers_to_check:
        ex = test[test['pmid'] == paper_id]
        text = ex['sentence'].values

        features = np.nan_to_num(
            ex[['sent_id', 'sent_type', 'r_abs',
                'num_refs', 'mean_r_fig', 'mean_r_tab',
                'min_r_fig', 'min_r_tab',
                'max_r_fig', 'max_r_tab']].values.astype(float)
        ) if ADDITIONAL_FEATURES else None

        # Score the paper chunk by chunk, with the same overlap between
        # consecutive chunks as `TrainDataset` uses.
        total_sents = 0
        while total_sents < len(text):
            offset = max(0, total_sents - OVERLAP)
            input_ids, attention_mask, token_type_ids, n_sents = preprocess_paper_bert(
                text[offset:], ARTICLE_LENGTH, model.tokenizer
            )
            if n_sents <= OVERLAP:
                # The chunk adds nothing beyond the overlap: advance by one sentence.
                total_sents += 1
                continue
            old_total = total_sents
            total_sents = offset + n_sents
            input_ids = torch.tensor([input_ids]).to(device)
            attention_mask = torch.tensor([attention_mask]).to(device)
            token_type_ids = torch.tensor([token_type_ids]).to(device)
            # No autograd graph is needed for predictions.
            with torch.no_grad():
                if ADDITIONAL_FEATURES:
                    input_features = torch.tensor(
                        features[offset:total_sents], dtype=torch.float
                    ).to(device)
                    model_scores = model(input_ids, attention_mask, token_type_ids, input_features)
                else:
                    model_scores = model(input_ids, attention_mask, token_type_ids)

            # Keep only the sentences not already covered by the previous chunk.
            new_scores = model_scores.cpu().numpy()[old_total - offset:]
            result['pmid'].extend([paper_id] * (total_sents - old_total))
            result['sentence'].extend(text[old_total:total_sents])
            result['score'].extend(new_scores)

    res_df = pd.DataFrame(result)
    sorted_arr = sorted(res_df['score'].values)
    reference_text = " ".join(paper_ref)
    for i in range(5, 103, 5):
        if len(sorted_arr) < i:
            break
        # Select the sentences whose score reaches the i-th highest score.
        threshold = sorted_arr[-i]
        final_text = res_df[res_df['score'] >= threshold][['pmid', 'sentence']]
        selected_sents = final_text['sentence'].values

        pair_scores = []
        for sent in selected_sents:
            for ref_sent in paper_ref:
                try:
                    pair_scores.append(get_rouge(sent, ref_sent))
                except Exception:
                    # Best effort: pairs that `get_rouge` cannot score are skipped.
                    logging.debug('get_rouge failed for a sentence pair', exc_info=True)
        # NaN instead of a ZeroDivisionError when no pair could be scored.
        mean_score = sum(pair_scores) / len(pair_scores) if pair_scores else float('nan')

        selected_text = " ".join(selected_sents)
        real_score = get_rouge(selected_text, reference_text)
        test_stat['rev_pmid'].append(rev_id)
        test_stat['sent_num'].append(i)
        print(len(selected_text), len(reference_text))

        test_stat['rouge'].append(mean_score)
        test_stat['true_rouge'].append(real_score)
        test_stat['diff_papers'].append(len(set(final_text['pmid'])))

test_stat_df = pd.DataFrame(test_stat)
test_stat_df

In [None]:
# Sanity check: every column of test_stat should have the same length.
column_lengths = [len(column) for column in test_stat.values()]
print(*column_lengths)

In [None]:
# Persist the statistics; the file name encodes whether features were used, so
# the two runs can be compared later.
test_stat_df.to_csv(f"{cfg.base_path}/simple_right_test_on_review_{ADDITIONAL_FEATURES}.csv", index=False)

In [None]:
# Mean and standard deviation of ROUGE and of the number of distinct papers for
# every summary size. Grouping once replaces regrouping per size, and
# `reindex` yields NaN (instead of a KeyError) for sizes with no data.
sent_nums = list(range(5, 103, 5))
grouped_stats = (
    test_stat_df
    .groupby('sent_num')[['rouge', 'diff_papers']]
    .agg(['mean', 'std'])
    .reindex(sent_nums)
)

rouge_means = grouped_stats[('rouge', 'mean')].tolist()
rouge_err = grouped_stats[('rouge', 'std')].tolist()
papers_means = grouped_stats[('diff_papers', 'mean')].tolist()
papers_err = grouped_stats[('diff_papers', 'std')].tolist()

In [None]:
# Mean ROUGE per summary size, with the standard deviation as error bars.
# NOTE(review): `plt` is not imported in the visible part of the notebook —
# presumably `matplotlib.pyplot` is imported in an earlier cell; confirm.
plt.errorbar(list(range(5, 103, 5)), rouge_means, yerr=rouge_err, fmt='-o')
plt.title('Mean rouge value')
plt.show()

In [None]:
# Mean number of distinct source papers per summary size, with the standard
# deviation as error bars.
plt.errorbar(list(range(5, 103, 5)), papers_means, yerr=papers_err, fmt='-o')
plt.title('Mean number of papers')
plt.show()

# Compare models with and without features

In [None]:
import pandas as pd

# Load the statistics saved by the runs without and with additional features.
# (`index=False` is a `to_csv` option; `read_csv` does not accept it.)
df1 = pd.read_csv(f"{cfg.base_path}/simple_right_test_on_review_{False}.csv")
df1 = df1.assign(model='BERTSUM')
df2 = pd.read_csv(f"{cfg.base_path}/simple_right_test_on_review_{True}.csv")
# Label the second frame itself, not a copy of the first one.
df2 = df2.assign(model='BERTSUM with features')
draw_df = pd.concat([df1, df2], ignore_index=True)

In [None]:
import seaborn as sns

# Box plot of ROUGE per summary size, one box per model.
sns.catplot(x="sent_num", y="rouge", kind="box", hue='model', aspect=1.7, color='lightblue',
            data=draw_df).set_axis_labels("Number of sentences", "ROUGE, %")
plt.show()