# Training & Evaluation of the Developed Algorithm

This notebook contains the source code for several candidate model architectures, together with the code to train and evaluate them.

In [None]:
import logging
import os
import random
import sys

import numpy as np
import pandas as pd
import torch
from tqdm.auto import tqdm

import review.config as cfg

# Configure the root logger once for the notebook: INFO level, timestamped messages.
logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s: %(message)s')

In [None]:
# Print the CUDA version reported by PyTorch and whether a CUDA device is usable.
logging.info('Check if CUDA is available')
print(f"CUDA version: {torch.version.cuda}")
print(torch.cuda.is_available())


def setup_cuda_device(model):
    """Place ``model`` on the first CUDA device if one is available, else on the CPU.

    :param model: object exposing ``.to(device)`` (e.g. a ``torch.nn.Module``)
    :return: tuple ``(model, device)`` with the moved model and the selected ``torch.device``
    """
    logging.info('Setup single-device settings...')
    if torch.cuda.is_available():
        target = torch.device("cuda:0")
    else:
        target = torch.device("cpu")
    return model.to(target), target

In [None]:
# Seed every random number generator used in the notebook (Python, NumPy, PyTorch CPU and
# all CUDA devices) with the value taken from the project config.
logging.info('Fix seed')
seed = cfg.seed
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

# Loading data and preparation of dataset

`load_data` loads all the data files needed to build the training dataset.\
The next several steps only need to be run if the train/test/val datasets have not been saved yet.

In [None]:
def sizeof_fmt(num, suffix='B'):
    """Format a byte count as a human-readable string with binary (1024-based) prefixes.

    :param num: size to format (may be negative)
    :param suffix: unit suffix appended after the prefix
    :return: string such as ``'1.5 KiB'``; values of 1024**8 and above are reported in ``Yi``
    """
    prefixes = ('', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi')
    value = num
    for prefix in prefixes:
        if abs(value) < 1024.0:
            return f"{value:3.1f} {prefix}{suffix}"
        value = value / 1024.0
    return f"{value:.1f} Yi{suffix}"


def load_data(additional_features):
    """Read the tab-separated dataset tables from ``cfg.dataset_path``.

    Each table is logged before loading and its in-memory size is logged afterwards.

    :param additional_features: when truthy, the abstracts, figures and tables files are
        loaded in addition to the four base tables
    :return: tuple ``(citations_df, sentences_df, review_files_df, reverse_ref_df)``, extended
        with ``(abstracts_df, figures_df, tables_df)`` when ``additional_features`` is truthy
    """
    dataset_root = os.path.expanduser(cfg.dataset_path)
    logging.info(f'Loading references dataset from {dataset_root}')

    def read_table(frame_name, file_name):
        # Load one TSV file and report how much memory the resulting frame takes.
        logging.info(f'Loading {frame_name}')
        frame = pd.read_csv(os.path.join(dataset_root, file_name), sep='\t')
        logging.info(sizeof_fmt(sys.getsizeof(frame)))
        return frame

    base_tables = (
        ('citations_df', "citations.csv"),
        ('sentences_df', "sentences.csv"),
        ('review_files_df', "review_files.csv"),
        ('reverse_ref_df', "reverse_ref.csv"),
    )
    feature_tables = (
        ('abstracts_df', "abstracts.csv"),
        ('figures_df', "figures.csv"),
        ('tables_df', "tables.csv"),
    )

    frames = [read_table(frame_name, file_name) for frame_name, file_name in base_tables]
    if additional_features:
        for frame_name, file_name in feature_tables:
            frames.append(read_table(frame_name, file_name))
    return tuple(frames)

In [None]:
# Alternative: load only the four base tables (no additional features).
# citations_df, sentences_df, review_files_df, reverse_ref_df = load_data(additional_features=False)

# Load the base tables together with the additional feature tables (abstracts, figures, tables).
citations_df, sentences_df, review_files_df, reverse_ref_df, abstracts_df, figures_df, tables_df = load_data(
    additional_features=True)

logging.info('Done loading references dataset')

In [None]:
import matplotlib.pyplot as plt
from collections import Counter

# Count the rows of sentences_df per pmid (i.e. sentences per paper) and plot the distribution.
res = Counter(list(sentences_df['pmid'].values))
plt.hist(res.values(), bins=range(-1, 400))
plt.title('Length of papers')
plt.show()
# The counter is only needed for the plot; drop it to free memory.
del res

In [None]:
# Location of the cached reference-sentences table.
REF_SENTS_DF_PATH = f"{os.path.expanduser(cfg.base_path)}/ref_sents.csv"
# NOTE(review): this shell command deletes the cached file on every run, so the
# os.path.exists() check below is expected to fail and the table is always rebuilt.
# Confirm this is intended; remove the line to reuse the cache.
! rm {REF_SENTS_DF_PATH}

if os.path.exists(REF_SENTS_DF_PATH):
    ref_sents_df = pd.read_csv(REF_SENTS_DF_PATH, sep='\t')
else:
    logging.info('Creating reference sentences dataset')
    # Attach the reverse-reference rows to each citation (same pmid and ref_id) ...
    ref_sents_df = pd.merge(citations_df, reverse_ref_df, left_on=['pmid', 'ref_id'], right_on=['pmid', 'ref_id'])
    # ... and then the sentence text, matched on pmid, sentence type and sentence id.
    ref_sents_df = pd.merge(ref_sents_df, sentences_df, left_on=['pmid', 'sent_type', 'sent_id'],
                            right_on=['pmid', 'type', 'sent_id'])
    # Keep only rows whose citing paper is listed in review_files_df.
    ref_sents_df = ref_sents_df[ref_sents_df['pmid'].isin(review_files_df['pmid'].values)]
    ref_sents_df = ref_sents_df.drop_duplicates()
    logging.info(f'Len of unique ref_sents {len(set(ref_sents_df["ref_pmid"]))}')
    ref_sents_df = ref_sents_df[['pmid', 'ref_id', 'pub_type', 'ref_pmid', 'sent_type', 'sent_id', 'sentence']]
    ref_sents_df.to_csv(REF_SENTS_DF_PATH, sep='\t', index=False)

# These frames are deleted in both branches, so re-running this cell requires reloading them.
logging.info('Cleanup memory')
del citations_df
del review_files_df

display(ref_sents_df.head())

# Rouge
The `get_rouge` function computes the similarity between two sentences as the mean of their ROUGE-1 and ROUGE-2 F-scores.
`rouge-l` is another possible option.

In [None]:
from rouge import Rouge
from transformers import BertTokenizer

# Module-level scorer and tokenizer instances shared by all get_rouge calls.
ROUGE_METER = Rouge()
TOKENIZER = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)


def get_rouge(sent1, sent2):
    """Return the mean of the ROUGE-1 and ROUGE-2 F-scores of two sentences, scaled by 100.

    Both sentences are tokenized with ``TOKENIZER``; only purely alphabetic tokens and the
    punctuation marks ``. ! , ?`` are kept before scoring.

    :param sent1: first sentence (passed to the scorer as the first argument)
    :param sent2: second sentence
    :return: score, or ``None`` when either sentence is ``None``, yields no tokens, or has
        nothing left after filtering
    """

    def normalise(raw):
        # None signals "no tokens at all"; an empty string signals "everything filtered out".
        pieces = TOKENIZER.tokenize(raw)
        if len(pieces) == 0:
            return None
        return " ".join([piece for piece in pieces if piece.isalpha() or piece in '.!,?'])

    if sent1 is None or sent2 is None:
        return None

    first = normalise(sent1)
    if first is None:
        return None
    second = normalise(sent2)
    if second is None:
        return None
    if len(first) == 0 or len(second) == 0:
        return None

    per_metric = ROUGE_METER.get_scores(first, second)[0]
    f_scores = [per_metric[f'rouge-{x}']["f"] for x in ('1', '2')]  # , 'l')]
    return np.mean(f_scores) * 100


def mean_rouge(sent, text):
    """Return the average ``get_rouge`` score of ``sent`` against every sentence in ``text``.

    :param sent: sentence to score
    :param text: sized iterable of reference sentences
    :return: mean score, or ``None`` when ``text`` is empty or scoring fails (for example
        because one pair has no score); failures are logged at ERROR level
    """
    n_refs = len(text)
    if n_refs == 0:
        return None
    try:
        total = 0
        for ref_sent in text:
            total = total + get_rouge(sent, ref_sent)
        return total / n_refs
    except Exception as e:
        logging.error(f'Exception at mean_rouge {e}')
        return None


def min_rouge(sent, text):
    """Return the lowest ``get_rouge`` score of ``sent`` against the sentences in ``text``.

    :param sent: sentence to score
    :param text: iterable of reference sentences
    :return: minimum score, or ``None`` when ``text`` is empty or scoring fails (for example
        because one pair has no score); failures are logged at ERROR level
    """
    untouched = 100000000  # start value; still present afterwards when nothing was scored
    try:
        lowest = untouched
        for ref_sent in text:
            candidate = get_rouge(sent, ref_sent)
            lowest = min(candidate, lowest)
        return None if lowest == untouched else lowest
    except Exception as e:
        logging.error(f'Exception at min_rouge {e}')
        return None


def max_rouge(sent, text):
    """Return the highest ``get_rouge`` score of ``sent`` against the sentences in ``text``.

    :param sent: sentence to score
    :param text: iterable of reference sentences
    :return: maximum score, or ``None`` when ``text`` is empty or scoring fails (for example
        because one pair has no score); failures are logged at ERROR level
    """
    untouched = -100000  # start value; still present afterwards when nothing was scored
    try:
        highest = untouched
        for ref_sent in text:
            candidate = get_rouge(sent, ref_sent)
            highest = max(candidate, highest)
        return None if highest == untouched else highest
    except Exception as e:
        logging.error(f'Exception at max_rouge {e}')
        return None

In [None]:
# Quick sanity check of get_rouge on a hand-written sentence pair.
get_rouge("I am scout. True!", "No you are not a scout.")

# Build train / test/ validate datasets

To create a dataset with features use `preprocess_paper_with_features`. Otherwise, use `preprocess_paper`.

If the datasets and `ref_sents_df` have not been created yet, `ref_sents_df` is built first.\
For each paper `pmid`, it lists the sentences from review papers in which the paper with this `pmid` is cited.

In [None]:
from unidecode import unidecode
from nltk.tokenize import sent_tokenize
import re

# Literal substring replacements applied by sent_standardize, in insertion order.
# NOTE(review): because shorter keys come before the longer keys that contain them
# (e.g. "''" before "'''"), a run of exactly three quotes is already shortened to two by the
# earlier entry and no longer matches the three-character key — confirm the order is intended.
# NOTE(review): sent_standardize calls unidecode() before these replacements; presumably the
# non-ASCII keys are already transliterated by then and never match — verify.
REPLACE_SYMBOLS = {
    '—': '-',
    '–': '-',
    '―': '-',
    '…': '...',
    '´´': "´",
    '´´´': "´´",
    "''": "'",
    "'''": "'",
    "``": "`",
    "```": "`",
    ":": " : ",
}


def parse_sents(data):
    """Split every text in ``data`` into sentences and drop very short ones.

    :param data: iterable of text strings
    :return: flat list of the sentences longer than 3 characters, in input order
    """
    # One flat comprehension instead of sum(list_of_lists, []), which re-copies the
    # accumulated list for every text and is quadratic in the number of sentences.
    return [sent for text in data for sent in sent_tokenize(text) if len(sent) > 3]


def sent_standardize(sent):
    """Normalise one sentence: transliterate, strip ``xref_*`` markers, unify punctuation.

    :param sent: raw sentence string
    :return: cleaned sentence with surrounding whitespace removed
    """
    cleaned = unidecode(sent)

    # Applied in this order: bracketed lists, parenthesised lists, single bracketed
    # markers, and finally any bare marker that is left.
    xref_patterns = (
        r"\[(xref_\w*_\w\d*]*)(, xref_\w*_\w\d*)*\]",  # [xref, ...]
        r"\( (xref_\w*_\w\d*)(; xref_\w*_\w\d*)* \)",  # ( xref; ... )
        r"\[xref_\w*_\w\d*\]",  # [xref]
        r"xref_\w*_\w\d*",  # bare xref
    )
    for pattern in xref_patterns:
        cleaned = re.sub(pattern, " ", cleaned)

    for old, new in REPLACE_SYMBOLS.items():
        cleaned = cleaned.replace(old, new)
    return cleaned.strip()


def standardize(text):
    """Apply ``sent_standardize`` to every sentence and keep results longer than 3 characters.

    :param text: iterable of raw sentences
    :return: list of cleaned sentences, in input order
    """
    kept = []
    for sent in text:
        cleaned = sent_standardize(sent)
        if len(cleaned) > 3:
            kept.append(cleaned)
    return kept


def preprocess_paper(paper_id, sentences_df, ref_sents_df, min_paper_sents=50, max_paper_sents=100):
    """Build the (sentence, target score) pairs of one paper.

    The target score of a sentence is its mean ``get_rouge`` score against all sentences
    of review papers that cite this paper.

    :param paper_id: pmid of the paper
    :param sentences_df: frame with ``pmid`` and ``sentence`` columns
    :param ref_sents_df: frame with ``ref_pmid`` and ``sentence`` columns
    :param min_paper_sents: papers with fewer cleaned sentences are skipped; also the
        number of leading and trailing sentences kept when the paper is truncated
    :param max_paper_sents: papers with more cleaned sentences are truncated
    :return: ``(sentences, scores)``, or ``None`` when the paper is too short, has no usable
        reference sentences, or some sentence pair cannot be scored
    """
    paper = standardize(sentences_df[sentences_df['pmid'] == paper_id]['sentence'])
    if len(paper) < min_paper_sents:
        return None

    ref_sents = standardize(ref_sents_df[ref_sents_df['ref_pmid'] == paper_id]['sentence'])
    if not ref_sents:
        # Nothing to compare against; previously this ended in a ZeroDivisionError.
        return None

    if len(paper) > max_paper_sents:
        paper = list(paper[:min_paper_sents]) + list(paper[-min_paper_sents:])

    preprocessed_score = []
    for sent in paper:
        scores = [get_rouge(sent, ref_sent) for ref_sent in ref_sents]
        if any(score is None for score in scores):
            # get_rouge yields None for unscorable pairs; skip the paper instead of
            # failing with a TypeError inside sum().
            return None
        preprocessed_score.append(sum(scores) / len(scores))
    return paper, preprocessed_score


def preprocess_paper_with_features(paper_id, sentences_df, ref_sents_df,
                                   abstracts_df, figures_df, reverse_ref_df, tables_df,
                                   min_paper_sents=50, max_paper_sents=100):
    """Build sentences, target scores and per-sentence features of one paper.

    Features per sentence, in order: ``sent_id``, 1 if the sentence type is ``"general"``
    else 0, ROUGE against the first abstract entry, number of ``reverse_ref_df`` rows for
    the sentence, and the mean/min/max ROUGE against the figure and table captions.

    :param paper_id: pmid of the paper
    :param sentences_df: frame with ``pmid``, ``sentence``, ``sent_id`` and ``type`` columns
    :param ref_sents_df: frame with ``ref_pmid`` and ``sentence`` columns
    :param abstracts_df: frame with ``pmid`` and ``abstract`` columns
    :param figures_df: frame with ``pmid`` and ``caption`` columns
    :param reverse_ref_df: frame with ``pmid``, ``sent_type`` and ``sent_id`` columns
    :param tables_df: frame with ``pmid`` and ``caption`` columns
    :param min_paper_sents: papers with fewer cleaned sentences are skipped; also the
        number of leading and trailing sentences kept when the paper is truncated
    :param max_paper_sents: papers with more cleaned sentences are truncated
    :return: ``(sentences, scores, features)``, or ``None`` when the paper is too short, has
        no usable abstract, or a sentence has no reference score
    """
    paper_rows = sentences_df[sentences_df['pmid'] == paper_id]

    # Clean sentence by sentence so that the id and type stay attached to the sentence they
    # belong to even when standardize() drops some sentences.
    rows = []
    for raw_sent, sent_id, sent_type in zip(paper_rows['sentence'], paper_rows['sent_id'], paper_rows['type']):
        cleaned = standardize([raw_sent])
        if cleaned:
            rows.append((cleaned[0], sent_id, sent_type))

    if len(rows) < min_paper_sents:
        return None
    if len(rows) > max_paper_sents:
        rows = rows[:min_paper_sents] + rows[-min_paper_sents:]

    abstract = standardize(abstracts_df[abstracts_df['pmid'] == paper_id]['abstract'])
    if not abstract:
        # Missing or too-short abstract: skip the paper explicitly instead of failing on
        # abstract[0].
        return None

    ref_sents = standardize(ref_sents_df[ref_sents_df['ref_pmid'] == paper_id]['sentence'])
    fig_captions = standardize(figures_df[figures_df['pmid'] == paper_id]['caption'])
    tab_captions = standardize(tables_df[tables_df['pmid'] == paper_id]['caption'])
    paper_refs = reverse_ref_df[reverse_ref_df['pmid'] == paper_id]

    papers, preprocessed_score, features = [], [], []
    for sent, sent_id, sent_type in rows:
        score = mean_rouge(sent, ref_sents)
        if score is None:
            return None

        r_abs = get_rouge(sent, abstract[0])
        num_refs = len(paper_refs[(paper_refs['sent_type'] == sent_type) & (paper_refs['sent_id'] == sent_id)])
        papers.append(sent)
        preprocessed_score.append(score)
        features.append((sent_id, int(sent_type == "general"), r_abs, num_refs,
                         mean_rouge(sent, fig_captions), mean_rouge(sent, tab_captions),
                         min_rouge(sent, fig_captions), min_rouge(sent, tab_captions),
                         max_rouge(sent, fig_captions), max_rouge(sent, tab_captions)))
    return papers, preprocessed_score, features

In [None]:
def process_reference_sentences_dataset(sentences_df, ref_sents_df, additional_features=False):
    """Preprocess every paper that has both sentences and citing sentences into one frame.

    With ``additional_features`` the notebook-level frames ``abstracts_df``, ``figures_df``,
    ``reverse_ref_df`` and ``tables_df`` are read as globals and must already be loaded.

    :param sentences_df: frame with the sentences of the papers (``pmid`` column)
    :param ref_sents_df: frame with the citing sentences (``ref_pmid`` column)
    :param additional_features: build the per-sentence feature columns as well
    :return: DataFrame with ``pmid``, ``sentence`` and ``score`` columns plus the feature
        columns when they were computed; columns without values are omitted
    """
    res = {}
    inter = set(sentences_df['pmid'].values) & set(ref_sents_df['ref_pmid'].values)
    for pmid in tqdm(inter):
        try:
            if additional_features:
                temp = preprocess_paper_with_features(
                    pmid, sentences_df, ref_sents_df, abstracts_df, figures_df, reverse_ref_df, tables_df
                )
            else:
                temp = preprocess_paper(pmid, sentences_df, ref_sents_df)
        except Exception as e:
            # Best effort: a paper that cannot be processed is reported and skipped.
            # The arguments match the placeholders; the previous call passed a stray extra
            # argument, which made the logging module report a formatting error.
            logging.warning('Error during processing %s: %s', pmid, e)
            continue
        if temp is None:
            logging.warning(f'temp is None for {pmid}')
            continue
        res[pmid] = temp
        print(f"\r{pmid} {np.mean(res[pmid][1])}", end="")

    logging.info(f'Successfully preprocessed {len(res)} of {len(inter)} papers')

    logging.info('Creating train dataset')
    feature_names = [
        'sent_id', 'sent_type', 'r_abs', 'num_refs',
        'mean_r_fig', 'mean_r_tab', 'min_r_fig', 'min_r_tab', 'max_r_fig', 'max_r_tab'
    ]
    train_dic = {name: [] for name in ['pmid', 'sentence', 'score'] + feature_names}

    for pmid, stat in tqdm(res.items()):
        if len(stat) == 2:
            # Result of preprocess_paper: (sentences, scores).
            for sent, score in zip(*stat):
                train_dic['pmid'].append(pmid)
                train_dic['sentence'].append(sent)
                train_dic['score'].append(score)
        else:
            # Result of preprocess_paper_with_features: (sentences, scores, features).
            for sent, score, features in zip(*stat):
                train_dic['pmid'].append(pmid)
                train_dic['sentence'].append(sent)
                train_dic['score'].append(score)
                for name, val in zip(feature_names, features):
                    train_dic[name].append(val)

    # Drop the columns that were never filled (the feature columns in the plain mode).
    train_df = pd.DataFrame({k: v for k, v in train_dic.items() if v})
    logging.info(f'Full train dataset {len(train_df)}')
    return train_df

In [None]:
# Location of the cached training table.
TRAIN_DATASET_PATH = f'{os.path.expanduser(cfg.base_path)}/dataset.csv'
# NOTE(review): this shell command deletes the cached file on every run, so the
# os.path.exists() check below is expected to fail and the dataset is always rebuilt.
# Confirm this is intended; remove the line to reuse the cache.
! rm {TRAIN_DATASET_PATH}

if os.path.exists(TRAIN_DATASET_PATH):
    train_df = pd.read_csv(TRAIN_DATASET_PATH)
else:
    train_df = process_reference_sentences_dataset(sentences_df, ref_sents_df, additional_features=True)
    train_df.to_csv(TRAIN_DATASET_PATH, index=False)

In [None]:
# Show the first rows of the training table.
display(train_df.head())

# Preprocessing text for BERT
A function to preprocess input text for `BERT` model.

In [None]:
def preprocess_paper_bert(text, max_len, tokenizer):
    """Pack as many leading sentences of ``text`` as fit into ``max_len`` token positions.

    Every sentence is wrapped in the tokenizer's article BOS/EOS tokens. Sentences are added
    whole, in order, until the next one would exceed ``max_len``; the rest of the sequence is
    padding. Segment ids alternate between 0 and 1 from sentence to sentence.

    :param text: sequence of sentence strings
    :param max_len: length of the returned sequences
    :param tokenizer: object providing ``tokenize``, ``convert_tokens_to_ids`` and the
        ``artBOS`` / ``artEOS`` / ``PAD`` special-token attributes
    :return: ``(ids, mask, segments, n_sents)``: token ids, attention mask (1 for real tokens,
        0 for padding), segment ids and the number of sentences that were packed
    """
    bos, eos = tokenizer.artBOS.tkn, tokenizer.artEOS.tkn
    ids, segments, segment_signature = [], [], 0
    n_sents = 0
    # Tokenize lazily: only the sentences that fit (plus the first one that does not) are
    # tokenized, instead of the whole remaining text on every call.
    for sent in text:
        tokens = [bos] + tokenizer.tokenize(sent) + [eos]
        if len(ids) + len(tokens) > max_len:
            break
        n_sents += 1
        ids.extend(tokenizer.convert_tokens_to_ids(tokens))
        segments.extend([segment_signature] * len(tokens))
        segment_signature = (segment_signature + 1) % 2
    mask = [1] * len(ids)

    pad_len = max(0, max_len - len(ids))
    ids += [tokenizer.PAD.idx] * pad_len
    mask += [0] * pad_len
    # Padding carries the segment id the next sentence would have received.
    segments += [segment_signature] * pad_len

    return ids, mask, segments, n_sents

# Splitting data into train/test/val

In [None]:
from sklearn.model_selection import train_test_split

# Split by paper id so that all sentences of a paper end up in the same subset:
# 80% of the pmids go to train, and the remaining 20% are split 60/40 into test and validation.
# NOTE(review): no random_state is passed; presumably the split relies on the global NumPy
# seed set earlier in the notebook — confirm if a reproducible split is required.
train_ids, test_ids = train_test_split(list(set(train_df['pmid'].values)), test_size=0.2)
test_ids, val_ids = train_test_split(test_ids, test_size=0.4)

train = train_df[train_df['pmid'].isin(train_ids)]
logging.info(f'Train {len(train)}')
display(train.head(1))

test = train_df[train_df['pmid'].isin(test_ids)]
logging.info(f'Test {len(test)}')
display(test.head(1))

val = train_df[train_df['pmid'].isin(val_ids)]
logging.info(f'Validate {len(val)}')
display(val.head(1))


# Model training functions

`train_fun` -- training function for model without features. \
`train_fun_ft` -- training function for model with features.

In [None]:
import torch.nn as nn
from torch.optim.optimizer import Optimizer
import math


def get_enc_lr(optimizer):
    """Return the learning rate of the optimizer's first parameter group."""
    first_group = optimizer.param_groups[0]
    return first_group['lr']


def get_dec_lr(optimizer):
    """Return the learning rate of the optimizer's second parameter group."""
    second_group = optimizer.param_groups[1]
    return second_group['lr']


def backward_step(loss: torch.Tensor, optimizer: Optimizer, model: nn.Module, clip: float):
    """Backpropagate ``loss`` and clip the gradient norm of ``model`` to ``clip``.

    :param loss: scalar loss tensor
    :param optimizer: not used by this function
    :param model: module whose parameter gradients are clipped in place
    :param clip: maximum allowed total gradient norm
    :return: the value returned by ``clip_grad_norm_`` (the total gradient norm)
    """
    loss.backward()
    parameters = model.parameters()
    return torch.nn.utils.clip_grad_norm_(parameters, clip)


def train_fun(model, dataloader, optimizer, scheduler, criter, device, writer):
    """Run one pass over ``dataloader``, training the model without additional features.

    :param model: module called as ``model(input_ids, input_mask, input_segment)``; must also
        expose ``save(name)``
    :param dataloader: sized iterable of batches
        ``(input_ids, input_mask, input_segment, target_scores)``
    :param optimizer: optimizer carrying the extra attributes ``clip_value`` and
        ``accumulation_interval`` and at least two parameter groups
    :param scheduler: learning-rate scheduler, stepped once per batch
    :param criter: loss callable ``criter(predictions, targets)``
    :param device: device the batch tensors are moved to
    :param writer: scalar logger with ``add_scalar`` and a mutable ``train_step`` counter
    :return: ``(model, optimizer, scheduler, writer)``, or ``None`` if computing the loss raised
    """
    model.train()

    pbar = tqdm(enumerate(dataloader), total=len(dataloader), leave=False)
    for idx_batch, batch in pbar:

        # Move tensor entries of the batch to the device; non-tensor entries are kept as is.
        input_ids, input_mask, input_segment, target_scores = [
            (x.to(device) if isinstance(x, torch.Tensor) else x) for x in batch
        ]
        # presumably target_scores is a sequence of per-paper tensors — TODO confirm collate_fn
        target_scores = torch.cat(target_scores).to(device)

        # forward pass
        draft_probs = model(input_ids, input_mask, input_segment, )

        try:
            # loss
            loss = criter(draft_probs, target_scores, )
        except Exception:
            # NOTE(review): on failure this prints debug info and returns None instead of the
            # 4-tuple; a caller that unpacks the result will fail — confirm this is intended.
            print(idx_batch, draft_probs.shape, target_scores.shape, input_segment)
            return

        # backward; gradients are clipped to optimizer.clip_value
        grad_norm = backward_step(loss, optimizer, model, optimizer.clip_value)
        # report 0 instead of inf/nan so the logged curve stays readable
        grad_norm = 0 if (math.isinf(grad_norm) or math.isnan(grad_norm)) else grad_norm

        # record the loss, gradient norm and learning rates for this batch
        pbar.set_description(f"loss:{loss.item():.5f}")
        writer.add_scalar(f"Train/loss", loss.item(), writer.train_step)
        writer.add_scalar("Train/grad_norm", grad_norm, writer.train_step)
        writer.add_scalar("Train/lr_enc", get_enc_lr(optimizer), writer.train_step)
        writer.add_scalar("Train/lr_dec", get_dec_lr(optimizer), writer.train_step)
        writer.train_step += 1

        # make a gradient step every accumulation_interval batches and on the last batch
        if (idx_batch + 1) % optimizer.accumulation_interval == 0 or (idx_batch + 1) == len(dataloader):
            optimizer.step()
            optimizer.zero_grad()
        # the scheduler advances on every batch, including batches without an optimizer step
        scheduler.step()

    # save model, just in case
    model.save('temp')

    return model, optimizer, scheduler, writer


def train_fun_ft(model, dataloader, optimizer, scheduler, criter, device, writer):
    """Run one pass over ``dataloader``, training the model with additional features.

    :param model: module called as
        ``model(input_ids, input_mask, input_segment, input_features)``; must also expose
        ``save(name)``
    :param dataloader: sized iterable of batches
        ``(input_ids, input_mask, input_segment, target_scores, input_features)``
    :param optimizer: optimizer carrying the extra attributes ``clip_value`` and
        ``accumulation_interval`` and at least two parameter groups
    :param scheduler: learning-rate scheduler, stepped once per batch
    :param criter: loss callable ``criter(predictions, targets)``
    :param device: device the batch tensors are moved to
    :param writer: scalar logger with ``add_scalar`` and a mutable ``train_step`` counter
    :return: ``(model, optimizer, scheduler, writer)``, or ``None`` if computing the loss raised
    """
    model.train()

    pbar = tqdm(enumerate(dataloader), total=len(dataloader), leave=False)
    for idx_batch, batch in pbar:

        # Move tensor entries of the batch to the device; non-tensor entries are kept as is.
        input_ids, input_mask, input_segment, target_scores, input_features = [
            (x.to(device) if isinstance(x, torch.Tensor) else x) for x in batch]
        # presumably both are sequences of per-paper tensors — TODO confirm collate_fn
        target_scores = torch.cat(target_scores).to(device)
        input_features = torch.cat(input_features).to(device)
        # forward pass
        draft_probs = model(
            input_ids, input_mask, input_segment, input_features,
        )

        try:
            # loss
            loss = criter(draft_probs, target_scores, )
        except Exception:
            # NOTE(review): on failure this prints debug info and returns None instead of the
            # 4-tuple; a caller that unpacks the result will fail — confirm this is intended.
            print(idx_batch, draft_probs.shape, target_scores.shape, input_segment)
            return

        # backward; gradients are clipped to optimizer.clip_value
        grad_norm = backward_step(loss, optimizer, model, optimizer.clip_value)
        # report 0 instead of inf/nan so the logged curve stays readable
        grad_norm = 0 if (math.isinf(grad_norm) or math.isnan(grad_norm)) else grad_norm

        # record the loss, gradient norm and learning rates for this batch
        pbar.set_description(f"loss:{loss.item():.5f}")
        writer.add_scalar(f"Train/loss", loss.item(), writer.train_step)
        writer.add_scalar("Train/grad_norm", grad_norm, writer.train_step)
        writer.add_scalar("Train/lr_enc", get_enc_lr(optimizer), writer.train_step)
        writer.add_scalar("Train/lr_dec", get_dec_lr(optimizer), writer.train_step)
        writer.train_step += 1

        # make a gradient step every accumulation_interval batches and on the last batch
        if (idx_batch + 1) % optimizer.accumulation_interval == 0 or (idx_batch + 1) == len(dataloader):
            optimizer.step()
            optimizer.zero_grad()
        # the scheduler advances on every batch, including batches without an optimizer step
        scheduler.step()

    # save model, just in case
    model.save('temp')

    return model, optimizer, scheduler, writer

# Inspect pretrained models: BERT and Roberta

In [None]:
from transformers import BertModel, RobertaModel

# Load the pretrained BERT backbone and list its trainable parameter count and parameter names.
backbone = BertModel.from_pretrained(
    "bert-base-uncased", output_hidden_states=False
)
print('BERT pretrained')
print(f'Parameters {sum(p.numel() for p in backbone.parameters() if p.requires_grad)}')
print(', '.join(n for n, p in backbone.named_parameters()))
# print(backbone)

# Same inspection for the pretrained RoBERTa backbone (rebinds the `backbone` name).
backbone = RobertaModel.from_pretrained(
    'roberta-base', output_hidden_states=False
)
print('ROBERTA pretrained')
print(f'Parameters {sum(p.numel() for p in backbone.parameters() if p.requires_grad)}')
print(', '.join(n for n, p in backbone.named_parameters()))
# print(backbone)

## Main model classes

The model in main pubtrends application is loaded using `load_model` function from `review.model` module.

It has several options to set up:
* with or without features (right now, the model without features works better),
* `BERT` or `roberta` as the basis (no big difference).

You can also choose the `froze_strategy`:
* `froze_all` -- the BERT layers stay frozen and only the summarization layer is trained,
* `unfroze_last4` -- the last four BERT layers are fine-tuned as well, and training is still not very slow,
* `unfroze_all` -- training is slow, though the results may be better.

In [None]:
import os
import torch.nn as nn
from collections import namedtuple
from transformers import BertModel, RobertaModel
from transformers import BertTokenizer, RobertaTokenizer

import review.config as cfg

# A special token: its string form (``tkn``) and its vocabulary id (``idx``).
SpecToken = namedtuple('SpecToken', ('tkn', 'idx'))


def ConvertToken2Id(tokenizer, tkn):
    """Return the vocabulary id that ``tokenizer`` assigns to the single token ``tkn``."""
    token_ids = tokenizer.convert_tokens_to_ids([tkn])
    return token_ids[0]


class Summarizer(nn.Module):
    """Sentence scorer: a pretrained BERT/RoBERTa encoder followed by a sigmoid classifier.

    The encoder output at every article-BOS position (``[CLS]`` / ``<s>``), optionally
    concatenated with an embedding of additional per-sentence features, is passed to
    ``Classifier`` to obtain one score per BOS token.
    """
    enc_output: torch.Tensor
    dec_ids_mask: torch.Tensor
    encdec_ids_mask: torch.Tensor

    def __init__(self, model_type, article_len, additional_features=False, num_features=10):
        """
        :param model_type: ``'bert'`` or ``'roberta'``
        :param article_len: number of token positions per input sequence
        :param additional_features: add the feature MLP and widen the classifier input by 50
        :param num_features: number of additional features per sentence
        :raises Exception: when ``model_type`` is not one of the supported values
        """
        super(Summarizer, self).__init__()

        self.article_len = article_len

        if model_type == 'bert':
            self.backbone, self.tokenizer, BOS, EOS, PAD = self.initialize_bert()
        elif model_type == 'roberta':
            self.backbone, self.tokenizer, BOS, EOS, PAD = self.initialize_roberta()
        else:
            raise Exception(f"Wrong model_type argument: {model_type}")

        if additional_features:
            # Embeds the additional features into a 50-dimensional vector.
            self.features = nn.Sequential(
                nn.Linear(num_features, 100),
                nn.ReLU(),
                nn.Linear(100, 100),
                nn.ReLU(),
                nn.Linear(100, 50)
            )
        else:
            self.features = None

        self.PAD = SpecToken(PAD, ConvertToken2Id(self.tokenizer, PAD))
        self.artBOS = SpecToken(BOS, ConvertToken2Id(self.tokenizer, BOS))
        self.artEOS = SpecToken(EOS, ConvertToken2Id(self.tokenizer, EOS))

        # add special tokens tokenizer
        self.tokenizer.add_special_tokens({'additional_special_tokens': ["<sum>", "</sent>", "</sum>"]})
        self.vocab_size = len(self.tokenizer)
        self.sumBOS = SpecToken("<sum>", ConvertToken2Id(self.tokenizer, "<sum>"))
        self.sumEOS = SpecToken("</sent>", ConvertToken2Id(self.tokenizer, "</sent>"))
        self.sumEOA = SpecToken("</sum>", ConvertToken2Id(self.tokenizer, "</sum>"))
        # NOTE(review): the embedding matrix gets 200 more rows than the tokenizer has tokens;
        # the reason is not visible here — saved checkpoints presumably depend on this size.
        self.backbone.resize_token_embeddings(200 + self.vocab_size)

        # expose the special tokens on the tokenizer (read by preprocess_paper_bert)
        self.tokenizer.PAD = self.PAD
        self.tokenizer.artBOS = self.artBOS
        self.tokenizer.artEOS = self.artEOS
        self.tokenizer.sumBOS = self.sumBOS
        self.tokenizer.sumEOS = self.sumEOS
        self.tokenizer.sumEOA = self.sumEOA
        self.vocab_size = len(self.tokenizer)

        # initialize backbone emb pulling
        def backbone_forward(input_ids, input_mask, input_segment, input_pos):
            return self.backbone(
                input_ids=input_ids,
                attention_mask=input_mask,
                token_type_ids=input_segment,
                position_ids=input_pos,
            )

        # Only the first backbone output is used.
        self.encoder = lambda *args: backbone_forward(*args)[0]

        # initialize decoder
        if not additional_features:
            self.decoder = Classifier(cfg.d_hidden)
        else:
            self.decoder = Classifier(cfg.d_hidden + 50)

    def expand_posembs_ifneed(self):
        """Grow the backbone position embeddings when ``article_len`` exceeds their size.

        The existing rows are copied into the new table; the added rows keep their fresh
        initialisation.
        """
        config = self.backbone.config
        old_maxlen = config.max_position_embeddings
        print(old_maxlen, self.article_len)
        if self.article_len > old_maxlen:
            print("OK")
            old_w = self.backbone.embeddings.position_embeddings.weight
            logging.info(f"Backbone pos embeddings expanded from {old_maxlen} upto {self.article_len}")
            new_emb = nn.Embedding(self.article_len, config.hidden_size)
            # Keep the new table on the device/dtype of the one it replaces.
            new_emb = new_emb.to(device=old_w.device, dtype=old_w.dtype)
            with torch.no_grad():
                new_emb.weight[:old_maxlen].copy_(old_w)
            self.backbone.embeddings.position_embeddings = new_emb
            config.max_position_embeddings = self.article_len
        print(config.max_position_embeddings)

    @staticmethod
    def initialize_bert():
        """Load the pretrained ``bert-base-uncased`` model and tokenizer.

        :return: ``(backbone, tokenizer, BOS, EOS, PAD)`` with the special tokens as strings
        """
        backbone = BertModel.from_pretrained(
            "bert-base-uncased", output_hidden_states=False
        )
        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)

        BOS = "[CLS]"
        EOS = "[SEP]"
        PAD = "[PAD]"
        return backbone, tokenizer, BOS, EOS, PAD

    @staticmethod
    def initialize_roberta():
        """Load the pretrained ``roberta-base`` model and tokenizer.

        A two-entry token-type embedding is installed and freshly initialised.

        :return: ``(backbone, tokenizer, BOS, EOS, PAD)`` with the special tokens as strings
        """
        backbone = RobertaModel.from_pretrained(
            'roberta-base', output_hidden_states=False
        )
        # initialize token type emb, by default roberta doesn't have it
        backbone.config.type_vocab_size = 2
        backbone.embeddings.token_type_embeddings = nn.Embedding(2, backbone.config.hidden_size)
        backbone.embeddings.token_type_embeddings.weight.data.normal_(
            mean=0.0, std=backbone.config.initializer_range
        )
        tokenizer = RobertaTokenizer.from_pretrained('roberta-base', do_lower_case=True)
        BOS = "<s>"
        EOS = "</s>"
        PAD = "<pad>"
        return backbone, tokenizer, BOS, EOS, PAD

    def save(self, save_filename):
        """ Save model in filename

        The backbone, decoder and (if present) feature-network weights are written to
        ``<cfg.weights_path>/<save_filename>.pth``; the folder is created when missing.

        :param save_filename: str
        """
        state = dict(
            encoder_dict=self.backbone.state_dict(),
            decoder_dict=self.decoder.state_dict()
        )
        if self.features is not None:
            state['features_dict'] = self.features.state_dict()
        models_folder = os.path.expanduser(cfg.weights_path)
        os.makedirs(models_folder, exist_ok=True)
        torch.save(state, f"{models_folder}/{save_filename}.pth")

    def load(self, load_filename):
        """Load weights written by ``save`` from ``<cfg.weights_path>/<load_filename>.pth``.

        :param load_filename: str, file name without the ``.pth`` extension
        """
        path = f"{os.path.expanduser(cfg.weights_path)}/{load_filename}.pth"
        # Keep every tensor on the storage it is deserialised to instead of its saved device.
        state = torch.load(path, map_location=lambda storage, location: storage)
        self.backbone.load_state_dict(state['encoder_dict'])
        self.decoder.load_state_dict(state['decoder_dict'])
        if self.features is not None:
            self.features.load_state_dict(state['features_dict'])

    def froze_backbone(self, froze_strategy):
        """Set which backbone parameters are trainable.

        :param froze_strategy: ``'froze_all'`` (no backbone parameter is trained),
            ``'unfroze_last4'`` (only encoder layers 8-11 are trained) or
            ``'unfroze_all'`` (every backbone parameter is trained)
        """
        assert froze_strategy in ['froze_all', 'unfroze_last4',
                                  'unfroze_all'], f"incorrect froze_strategy argument: {froze_strategy}"

        if froze_strategy == 'froze_all':
            for param in self.backbone.parameters():
                param.requires_grad_(False)

        elif froze_strategy == 'unfroze_last4':
            trainable_layers = ('encoder.layer.11', 'encoder.layer.10', 'encoder.layer.9', 'encoder.layer.8')
            for name, param in self.backbone.named_parameters():
                param.requires_grad_(any(layer in name for layer in trainable_layers))

        elif froze_strategy == 'unfroze_all':
            for param in self.backbone.parameters():
                param.requires_grad_(True)

    def unfroze_head(self):
        """Make every parameter of the classifier head trainable."""
        for param in self.decoder.parameters():
            param.requires_grad_(True)

    def forward(self, input_ids, input_mask, input_segment, input_features=None):
        """ Score every sentence (BOS token) of the batch
        :param input_ids: torch.Size([batch_size, article_len])
        :param input_mask: torch.Size([batch_size, article_len])
        :param input_segment: torch.Size([batch_size, article_len])
        :param input_features: per-sentence features for all BOS tokens of the batch, one row
            per BOS token; required when the model was built with additional features
        :return:
            scores | one value per BOS token of the whole batch, flattened over the batch
        """

        cls_mask = (input_ids == self.artBOS.idx)

        # position ids | torch.Size([batch_size, article_len])
        pos_ids = torch.arange(
            0,
            self.article_len,
            dtype=torch.long,
            device=input_ids.device
        ).unsqueeze(0).repeat(len(input_ids), 1)
        # extract bert embeddings | torch.Size([batch_size, article_len, d_bert])
        enc_output = self.encoder(input_ids, input_mask, input_segment, pos_ids)

        if self.features is not None:
            temp_features = self.features(input_features)
            draft_logprobs = self.decoder(torch.cat([enc_output[cls_mask], temp_features], dim=-1))
        else:
            draft_logprobs = self.decoder(enc_output[cls_mask])

        return draft_logprobs

    def evaluate(self, input_ids, input_mask, input_segment, input_features=None):
        """ Score every sentence (BOS token), returning the scores sample by sample
        :param input_ids: torch.Size([batch_size, article_len])
        :param input_mask: torch.Size([batch_size, article_len])
        :param input_segment: torch.Size([batch_size, article_len])
        :param input_features: per-sentence features for all BOS tokens of the batch, one row
            per BOS token in batch order; required when the model was built with features
        :return:
            list with one score tensor per sample of the batch
        """

        cls_mask = (input_ids == self.artBOS.idx)

        # position ids | torch.Size([batch_size, article_len])
        pos_ids = torch.arange(
            0,
            self.article_len,
            dtype=torch.long,
            device=input_ids.device
        ).unsqueeze(0).repeat(len(input_ids), 1)
        # extract bert embeddings | torch.Size([batch_size, article_len, d_bert])
        enc_output = self.encoder(input_ids, input_mask, input_segment, pos_ids)

        feature_chunks = None
        if self.features is not None:
            # Embed the features once and hand every sample only its own rows; concatenating
            # the full matrix to each sample is only valid for a batch of one.
            feature_emb = self.features(input_features)
            cls_counts = cls_mask.sum(dim=1).tolist()
            if feature_emb.dim() == 2 and feature_emb.shape[0] == sum(cls_counts):
                feature_chunks = torch.split(feature_emb, cls_counts, dim=0)
            else:
                # Unexpected layout: keep the previous behaviour of using the whole matrix.
                feature_chunks = [feature_emb] * len(cls_counts)

        ans = []
        for idx, (eo, cm) in enumerate(zip(enc_output, cls_mask)):
            sent_emb = eo[cm]
            if feature_chunks is not None:
                sent_emb = torch.cat([sent_emb, feature_chunks[idx]], dim=-1)
            ans.append(self.decoder.evaluate(sent_emb))
        return ans


class Classifier(nn.Module):
    """Linear layer with a sigmoid producing one score in (0, 1) per input vector."""

    def __init__(self, hidden_size):
        """
        :param hidden_size: size of the last dimension of the input vectors
        """
        super().__init__()
        self.linear1 = nn.Linear(hidden_size, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return sigmoid scores for ``x``; the trailing dimension of size 1 is squeezed."""
        logits = self.linear1(x).squeeze(-1)
        return self.sigmoid(logits)

    def evaluate(self, x):
        """Return the same scores as ``forward``."""
        return self.forward(x)


def load_model(model_type, froze_strategy, article_len, additional_features=False):
    """Build a ``Summarizer`` and configure which of its parameters are trainable.

    :param model_type: backbone type passed to ``Summarizer``
    :param froze_strategy: strategy passed to ``Summarizer.froze_backbone``
    :param article_len: number of token positions per input sequence
    :param additional_features: build the model with the additional-features branch
    :return: the configured model; the number of trainable parameters is printed
    """
    summarizer = Summarizer(model_type, article_len, additional_features)
    summarizer.expand_posembs_ifneed()
    # Load intermediate model
    #     summarizer.load('temp')
    summarizer.froze_backbone(froze_strategy)
    summarizer.unfroze_head()

    n_trainable = 0
    for param in summarizer.parameters():
        if param.requires_grad:
            n_trainable += param.numel()
    print(f'Parameters {n_trainable}')
    return summarizer

# Dataset classes

In [None]:
from torch.utils.data import Dataset


class TrainDataset(Dataset):
    """ Custom Train Dataset, optionally with additional features.
        All papers are preprocessed in the constructor; items are then served from memory.
        A paper is cut into overlapping windows: each window holds as many whole sentences
        as ``preprocess_paper_bert`` fits into ``article_len`` tokens.
    """

    def __init__(self, dataframe, tokenizer, article_len, additional_features=False):
        """
        :param dataframe: frame with ``pmid``, ``sentence`` and ``score`` columns, plus the
            feature columns when ``additional_features`` is set
        :param tokenizer: tokenizer passed to ``preprocess_paper_bert``
        :param article_len: number of token positions per example
        :param additional_features: also store the per-sentence feature rows with each example
        """
        self.df = dataframe
        self.data = []
        self.names = list(set(dataframe['pmid'].values))
        self.tokenizer = tokenizer
        self.article_len = article_len

        for name in tqdm(self.names):
            ex = self.df[self.df['pmid'] == name]
            papers = ex['sentence'].values
            # Missing feature values (NaN) are replaced by numbers via np.nan_to_num.
            features = np.nan_to_num(
                ex[['sent_id', 'sent_type', 'r_abs', 'num_refs',
                    'mean_r_fig', 'mean_r_tab',
                    'min_r_fig', 'min_r_tab',
                    'max_r_fig', 'max_r_tab']].values.astype(float)
            ) if additional_features else None
            total_sents = 0
            while total_sents < len(papers):
                # Start each window up to 5 sentences before the current position.
                magic = max(0, total_sents - 5)
                article_ids, article_mask, article_segment, n_sents = preprocess_paper_bert(
                    papers[magic:], self.article_len, self.tokenizer
                )
                # A window with 5 sentences or fewer is skipped; advance by one sentence.
                if n_sents <= 5:
                    total_sents += 1
                    continue
                # Scores are divided by 100 before being used as targets.
                target_scores = ex['score'].values[magic:magic + n_sents] / 100
                self.data.append(
                    (article_ids, article_mask, article_segment, target_scores, features[magic:magic + n_sents])
                    if additional_features else
                    (article_ids, article_mask, article_segment, target_scores)
                )
                # Continue after the last sentence that fitted into this window.
                total_sents = magic + n_sents

        self.n_examples = len(self.data)

    def __getitem__(self, idx):
        # Examples were fully built in __init__.
        return self.data[idx]

    def __len__(self):
        return self.n_examples


class EvalDataset(Dataset):
    """Validation/test dataset that preprocesses one paper per lookup.

    Rows of ``dataframe`` sharing a ``pmid`` form one paper. ``__getitem__``
    passes all sentences of the selected paper to ``preprocess_paper_bert`` and
    returns ``(article_ids, article_mask, article_segment, target_scores)``,
    followed by the feature rows when ``additional_features`` is true. Targets
    and features are cut to the ``n_sents`` reported by the preprocessing, and
    targets are the ``score`` column divided by 100.
    """

    def __init__(self, dataframe, tokenizer, article_len, additional_features=False):
        unique_pmids = set(dataframe['pmid'].values)

        self.df = dataframe
        self.n_examples = len(unique_pmids)
        print(self.n_examples)
        self.names = list(unique_pmids)
        self.tokenizer = tokenizer
        self.article_len = article_len
        self.additional_features = additional_features

    def __len__(self):
        return self.n_examples

    def __getitem__(self, idx):
        pmid = self.names[idx]
        rows = self.df[self.df['pmid'] == pmid]
        sentences = rows['sentence'].values

        if self.additional_features:
            feature_columns = [
                'sent_id', 'sent_type', 'r_abs',
                'num_refs', 'mean_r_fig', 'mean_r_tab',
                'min_r_fig', 'min_r_tab',
                'max_r_fig', 'max_r_tab',
            ]
            features = np.nan_to_num(rows[feature_columns].values.astype(float))
        else:
            features = None

        article_ids, article_mask, article_segment, n_sents = preprocess_paper_bert(
            sentences, self.article_len, self.tokenizer
        )

        # Targets: one scaled score per sentence that was kept.
        target_scores = rows['score'].values[:n_sents] / 100

        if not self.additional_features:
            return article_ids, article_mask, article_segment, target_scores
        return article_ids, article_mask, article_segment, target_scores, features[:n_sents]

In [None]:
from torch.optim.lr_scheduler import _LRScheduler, ExponentialLR


class CustomScheduler(_LRScheduler):
    """Linear warm-up followed by exponential decay, computed in closed form.

    With ``t`` the number of ``step()`` calls so far (the base class performs one
    during construction) and ``lr0`` the learning rate each parameter group had
    when the scheduler was created:

    * ``t < warmup``:  ``lr = t * lr0 / warmup``
    * ``t >= warmup``: ``lr = lr0 * gamma ** (t - warmup)``

    The schedule therefore reaches ``lr0`` exactly at ``t == warmup``. The
    previous implementation delegated the decay phase to a nested
    ``ExponentialLR``; at the boundary it read that scheduler's ``get_lr()``
    outside of its ``step()``, so the peak rate was never reached, and this
    scheduler's own bookkeeping (``get_last_lr()``) stopped being updated once
    warm-up was over.

    :param optimizer: wrapped optimizer
    :param gamma: multiplicative decay factor applied per step after warm-up
    :param warmup: number of warm-up steps; ``None`` means no warm-up
    """
    timestep: int = 0

    def __init__(self, optimizer, gamma, warmup=None):
        self.optimizer = optimizer
        self.gamma = gamma
        self.initial_lrs = [p_group['lr'] for p_group in optimizer.param_groups]
        self.warmup = 0 if warmup is None else warmup
        # Must be set before the base constructor, which already calls step().
        self.timestep = 0
        super().__init__(optimizer)

    def get_lr(self):
        """Return the learning rate of every parameter group for the current timestep."""
        if self.timestep < self.warmup:
            return [self.timestep * group_init_lr / self.warmup
                    for group_init_lr in self.initial_lrs]
        decay = self.gamma ** (self.timestep - self.warmup)
        return [group_init_lr * decay for group_init_lr in self.initial_lrs]

    def step(self, epoch=None):
        """Advance the schedule by one step and write the new rates to the optimizer."""
        self.timestep += 1
        super().step(epoch)


class NoamScheduler(_LRScheduler):
    """Scale every group's initial learning rate by a Noam-style factor.

    The factor at timestep ``t`` is ``min(t ** -0.5, t * warmup ** -1.5)``.
    ``t`` is incremented on every ``step()`` call, including the one the base
    class performs during construction.
    """

    def __init__(self, optimizer, warmup):
        assert warmup > 0
        self.optimizer = optimizer
        self.warmup = warmup
        self.timestep = 0
        self.initial_lrs = [group['lr'] for group in optimizer.param_groups]
        super().__init__(optimizer)

    def step(self, epoch=None):
        self.timestep = self.timestep + 1
        super().step(epoch)

    def get_noam_lr(self):
        """Return the multiplicative factor for the current timestep."""
        decay_term = self.timestep ** -0.5
        warmup_term = self.timestep * self.warmup ** -1.5
        return warmup_term if warmup_term < decay_term else decay_term

    def get_lr(self):
        """Return the learning rate of every parameter group for the current timestep."""
        factor = self.get_noam_lr()
        scaled = []
        for group_init_lr in self.initial_lrs:
            scaled.append(group_init_lr * factor)
        return scaled

In [None]:
from torch.optim import AdamW
from torch.nn import MSELoss
from torch.utils.data import DataLoader


def train_collate_fn(batch_data):
    """Collate `TrainDataset` examples without additional features.

    :param batch_data: list of 4-tuples ``(ids, mask, segment, scores)``
    :return:
        tuple of three long tensors (ids, mask, segment, each stacked over the
        batch) and a list with one float tensor of scores per example
    """
    token_ids, attention_masks, segment_ids, scores = zip(*batch_data)
    stacked = tuple(
        torch.tensor(column, dtype=torch.long)
        for column in (token_ids, attention_masks, segment_ids)
    )
    # Scores stay a list: examples may hold different numbers of sentences.
    per_example_scores = [torch.tensor(item, dtype=torch.float) for item in scores]
    return stacked + (per_example_scores,)


def eval_collate_fn(batch_data):
    """Collate `EvalDataset` examples without additional features.

    :param batch_data: list of 4-tuples ``(ids, mask, segment, scores)``
    :return:
        tuple of three long tensors (ids, mask, segment, each stacked over the
        batch) and a list with one float tensor of scores per example
    """
    columns = list(zip(*batch_data))
    ids_column, mask_column, segment_column, score_column = columns

    input_ids = torch.tensor(ids_column, dtype=torch.long)
    input_mask = torch.tensor(mask_column, dtype=torch.long)
    input_segment = torch.tensor(segment_column, dtype=torch.long)

    target_scores = []
    for example_scores in score_column:
        target_scores.append(torch.tensor(example_scores, dtype=torch.float))

    return input_ids, input_mask, input_segment, target_scores


def train_collate_fn_ft(batch_data):
    """Collate `TrainDataset` examples that carry additional features.

    :param batch_data: list of 5-tuples ``(ids, mask, segment, scores, features)``
    :return:
        tuple of three long tensors (ids, mask, segment, each stacked over the
        batch), a list of per-example float score tensors and a list of
        per-example float feature tensors
    """
    token_ids, attention_masks, segment_ids, scores, features = zip(*batch_data)

    def to_float_list(items):
        # One tensor per example: lengths may differ between examples.
        return [torch.tensor(item, dtype=torch.float) for item in items]

    stacked = tuple(
        torch.tensor(column, dtype=torch.long)
        for column in (token_ids, attention_masks, segment_ids)
    )
    return stacked + (to_float_list(scores), to_float_list(features))


def eval_collate_fn_ft(batch_data):
    """Collate `EvalDataset` examples that carry additional features.

    :param batch_data: list of 5-tuples ``(ids, mask, segment, scores, features)``
    :return:
        tuple of three long tensors (ids, mask, segment, each stacked over the
        batch), a list of per-example float score tensors and a list of
        per-example float feature tensors
    """
    columns = list(zip(*batch_data))
    ids_column, mask_column, segment_column, score_column, feature_column = columns

    input_ids = torch.tensor(ids_column, dtype=torch.long)
    input_mask = torch.tensor(mask_column, dtype=torch.long)
    input_segment = torch.tensor(segment_column, dtype=torch.long)

    target_scores = []
    for example_scores in score_column:
        target_scores.append(torch.tensor(example_scores, dtype=torch.float))

    input_features = []
    for example_features in feature_column:
        input_features.append(torch.tensor(example_features, dtype=torch.float))

    return input_ids, input_mask, input_segment, target_scores, input_features


def create_loader(dataset, batch_size, collate_fn, shuffle=False):
    """Wrap ``dataset`` in a ``DataLoader`` with the project-wide settings.

    :param dataset: dataset to iterate over
    :param batch_size: number of examples per batch
    :param collate_fn: function merging a list of examples into one batch
    :param shuffle: whether to reshuffle the data every epoch; defaults to the
        previously hard-coded ``False`` so existing callers are unaffected
    :return:
        ``DataLoader`` with pinned memory and ``cfg.num_workers`` workers
    """
    return DataLoader(
        dataset=dataset, batch_size=batch_size, shuffle=shuffle,
        pin_memory=True, collate_fn=collate_fn, num_workers=cfg.num_workers
    )


def get_dataloaders(train, val, batch_size, article_len, tokenizer, additional_features=False):
    """Build the train and validation data loaders.

    :param train: frame used to build the `TrainDataset`
    :param val: frame used to build the `EvalDataset`
    :param batch_size: batch size of both loaders
    :param article_len: length forwarded to both datasets
    :param tokenizer: tokenizer forwarded to both datasets
    :param additional_features: selects the feature-aware datasets and collate functions
    :return:
        ``(train_dl, val_dl)``
    """
    # Choose the collate functions matching the tuple size the datasets will yield.
    if additional_features:
        collate_for_train, collate_for_val = train_collate_fn_ft, eval_collate_fn_ft
    else:
        collate_for_train, collate_for_val = train_collate_fn, eval_collate_fn

    logging.info('Creating train dataset...')
    train_ds = TrainDataset(train, tokenizer, article_len, additional_features)
    logging.info('Applying loader functions to train...')
    train_dl = create_loader(train_ds, batch_size, collate_for_train)

    logging.info('Creating val dataset...')
    val_ds = EvalDataset(val, tokenizer, article_len, additional_features)
    logging.info('Applying loader functions to val...')
    val_dl = create_loader(val_ds, batch_size, collate_for_val)

    return train_dl, val_dl


def get_tools(model, enc_lr, dec_lr, warmup,
              weight_decay, clip_value,
              accumulation_interval):
    """Create the optimizer, LR scheduler and loss used for training.

    Trainable parameters whose name starts with ``'bert.'`` go into the first
    ``AdamW`` group (learning rate ``enc_lr``); all other trainable parameters
    go into the second group (``dec_lr``). ``clip_value`` and
    ``accumulation_interval`` are attached to the optimizer as plain attributes.

    :return:
        ``(optimizer, scheduler, criter)`` — ``AdamW``, `NoamScheduler`, ``MSELoss``
    """
    # TODO fix for Roberta model
    encoder_params, decoder_params = [], []
    for param_name, tensor in model.named_parameters():
        if not tensor.requires_grad:
            continue
        if param_name.startswith('bert.'):
            encoder_params.append(tensor)
        else:
            decoder_params.append(tensor)

    param_groups = [
        {'params': encoder_params, 'lr': enc_lr},
        {'params': decoder_params, 'lr': dec_lr},
    ]
    optimizer = AdamW(param_groups, weight_decay=weight_decay)
    optimizer.clip_value = clip_value
    optimizer.accumulation_interval = accumulation_interval

    return optimizer, NoamScheduler(optimizer, warmup=warmup), MSELoss()

In [None]:
def evaluate(model, dataloader, criter, device, writer):
    """Run one evaluation pass for a model without additional features.

    The model is put into eval mode and run under ``torch.no_grad()`` so no
    autograd graph is kept. Each batch loss is logged as ``Eval/loss`` against
    ``writer.eval_step`` (created at 0 if the writer has none), leaving the
    training step counter untouched. Afterwards the mean batch loss and the mean
    number of sentences per example are printed and the weights are saved to
    ``'validated_weights.pth'``.

    :param model: model called as ``model(input_ids, input_mask, input_segment)``
    :param dataloader: loader yielding ``(ids, mask, segment, scores)`` batches
    :param criter: loss taking ``(predictions, targets)``
    :param device: device the batch tensors are moved to
    :param writer: summary writer providing ``add_scalar``
    :return:
        the evaluated model
    """
    model.eval()
    loss_val = 0
    total_sents = 0
    total_examples = 0

    pbar = tqdm(dataloader, total=len(dataloader), leave=False)
    with torch.no_grad():
        for batch in pbar:
            input_ids, input_mask, input_segment, target_scores = [
                (x.to(device) if isinstance(x, torch.Tensor) else x) for x in batch
            ]
            # target_scores is a list with one tensor per example.
            total_sents += sum(dc.shape[0] for dc in target_scores)
            total_examples += len(target_scores)
            target_scores = torch.cat(target_scores).to(device)

            # forward pass
            draft_probs = model(input_ids, input_mask, input_segment)

            # loss
            loss = criter(draft_probs, target_scores)

            # record a loss value
            pbar.set_description(f"loss:{loss.item():.5f}")
            loss_val += loss.item()
            eval_step = getattr(writer, 'eval_step', 0)
            writer.add_scalar("Eval/loss", loss.item(), eval_step)
            writer.eval_step = eval_step + 1

    # An empty loader yields NaN instead of a ZeroDivisionError.
    n_batches = len(dataloader)
    print("Val loss:", loss_val / n_batches if n_batches else float('nan'))
    print("Mean sent len:", total_sents / total_examples if total_examples else float('nan'))
    # save model, just in case
    model.save('validated_weights.pth')

    return model


def evaluate_ft(model, dataloader, criter, device, writer):
    """Run one evaluation pass for a model that takes additional features.

    The model is put into eval mode and run under ``torch.no_grad()`` so no
    autograd graph is kept. Each batch loss is logged as ``Eval/loss`` against
    ``writer.eval_step`` (created at 0 if the writer has none), leaving the
    training step counter untouched. Afterwards the mean batch loss and the mean
    number of sentences per example are printed and the weights are saved to
    ``'validated_weights.pth'``.

    :param model: model called as
        ``model(input_ids, input_mask, input_segment, input_features)``
    :param dataloader: loader yielding ``(ids, mask, segment, scores, features)`` batches
    :param criter: loss taking ``(predictions, targets)``
    :param device: device the batch tensors are moved to
    :param writer: summary writer providing ``add_scalar``
    :return:
        the evaluated model
    """
    model.eval()
    loss_val = 0
    total_sents = 0
    total_examples = 0

    pbar = tqdm(dataloader, total=len(dataloader), leave=False)
    with torch.no_grad():
        for batch in pbar:
            input_ids, input_mask, input_segment, target_scores, input_features = [
                (x.to(device) if isinstance(x, torch.Tensor) else x) for x in batch
            ]
            # target_scores / input_features are lists with one tensor per example.
            total_sents += sum(dc.shape[0] for dc in target_scores)
            total_examples += len(target_scores)
            input_features = torch.cat(input_features).to(device)
            target_scores = torch.cat(target_scores).to(device)

            # forward pass
            draft_probs = model(input_ids, input_mask, input_segment, input_features)

            # loss
            loss = criter(draft_probs, target_scores)

            # record a loss value
            pbar.set_description(f"loss:{loss.item():.5f}")
            loss_val += loss.item()
            eval_step = getattr(writer, 'eval_step', 0)
            writer.add_scalar("Eval/loss", loss.item(), eval_step)
            writer.eval_step = eval_step + 1

    # An empty loader yields NaN instead of a ZeroDivisionError.
    n_batches = len(dataloader)
    print("Val loss:", loss_val / n_batches if n_batches else float('nan'))
    print("Mean sent len:", total_sents / total_examples if total_examples else float('nan'))
    # save model, just in case
    model.save('validated_weights.pth')

    return model

# Loading or training the model

In [None]:
# Length shared by the model and by every dataset / preprocessing call below
# (presumably the maximum token count per example — confirm against `preprocess_paper_bert`).
ARTICLE_LENGTH = 512

# Build the "bert" / "froze_all" model with additional features enabled,
# then move it to cuda:0 when available, otherwise to the CPU.
model = load_model("bert", "froze_all", ARTICLE_LENGTH, additional_features=True)
model, device = setup_cuda_device(model)

# Create dataloaders and start training

In [None]:
from tensorboardX import SummaryWriter


def load_or_train_model(model, device, additional_features=False, force_retrain=True):
    """Return a trained model, loading saved weights when allowed and available.

    Relies on the notebook-level ``train``, ``val`` and ``ARTICLE_LENGTH``.

    :param model: model to load weights into or to train
    :param device: device used for training and evaluation
    :param additional_features: selects the feature-aware loaders and
        train/evaluate functions, and is part of the checkpoint file name
    :param force_retrain: when true (the default, matching the previous
        unconditional ``rm``), an existing checkpoint file is deleted first so
        the model is always trained from the given state
    :return:
        the loaded or freshly trained model
    """
    model_name = f'learn_simple_berta_{additional_features}.pth'.lower()
    MODEL_PATH = f'{os.path.expanduser(cfg.weights_path)}/{model_name}'

    if force_retrain:
        # Plain-Python replacement for the IPython `! rm` magic, which only
        # works inside a notebook; a missing file is not an error.
        try:
            os.remove(MODEL_PATH)
        except FileNotFoundError:
            pass

    if os.path.exists(MODEL_PATH):
        logging.info(f'Loading model {MODEL_PATH}')
        # NOTE(review): loads "learn_simple_berta" while training saves `model_name`
        # (with the feature flag and ".pth") — confirm how `model.load` resolves names.
        model.load("learn_simple_berta")
        model, device = setup_cuda_device(model)
        return model

    logging.info('Create dataloaders...')
    train_loader, valid_loader = get_dataloaders(
        train, val, 4, ARTICLE_LENGTH, model.tokenizer, additional_features
    )

    # Pick the routines matching the batch layout once, outside the epoch loop.
    train_epoch = train_fun_ft if additional_features else train_fun
    validate = evaluate_ft if additional_features else evaluate

    optimizer, scheduler, criter = get_tools(model, 0.00001, 0.001, 5, 0.005, 1.0, 1)

    writer = SummaryWriter(log_dir=os.path.expanduser(cfg.log_path))
    writer.train_step, writer.eval_step = 0, 0
    try:
        for epoch in range(1, 20 + 1):
            logging.info(f"{epoch} epoch training...")
            model, optimizer, scheduler, writer = train_epoch(
                model, train_loader, optimizer, scheduler,
                criter, device, writer
            )

            logging.info(f"{epoch} epoch validation...")
            model = validate(model, valid_loader, criter, device, writer)
    finally:
        # Flush and release the event file even if training is interrupted.
        writer.close()

    logging.info(f'Save trained model to {model_name}')
    model.save(model_name)
    return model

In [None]:
# Train (or load) the feature-aware model; the result replaces the notebook-level `model`.
model = load_or_train_model(model, device, additional_features=True)

# Evaluate model performance
TODO: update model with additional features analysis.

In [None]:
# Pick the first value
ex = val[val['pmid'] == list(val['pmid'])[0]]
paper = ex['sentence'].values
article_ids, article_mask, article_segment, n_sents = preprocess_paper_bert(
    paper, ARTICLE_LENGTH, model.tokenizer
)
res_sents = paper[:n_sents]
scores = ex['score'].values[:n_sents]

In [None]:
input_ids = torch.tensor([article_ids]).to(device)
input_mask = torch.tensor([article_mask]).to(device)
input_segment = torch.tensor([article_segment]).to(device)
draft_probs = model(input_ids, input_mask, input_segment)

In [None]:
to_show_df = pd.DataFrame(
    dict(sentence=res_sents, ideal_score=scores / 100, res_score=draft_probs.cpu().detach().numpy())
)
display(to_show_df.head())

In [None]:
' '.join(to_show_df[to_show_df['res_score'] > 0.07]['sentence'].values)

In [None]:
to_show_df.to_csv(f'{cfg.base_path}/show_scores.csv')

# Inspect scores
First, let's see if the model trained well.\
Then we will compute the `MSE` score on a real example.

In [None]:
# Rebuild and reload the model only when the notebook has no `model` defined yet.
if 'model' not in globals():
    model = load_model("bert", "froze_all", ARTICLE_LENGTH, additional_features=True)
    model.load("learn_simple_berta")
    model, device = setup_cuda_device(model)

In [None]:
REF_SCORES_PATH = os.path.expanduser(f"{cfg.base_path}/refs_and_scores.csv")

! rm {REF_SCORES_PATH}

if os.path.exists(REF_SCORES_PATH):
    final_ref_show_df = pd.read_csv(REF_SCORES_PATH)
else:
    if 'train_loader' not in globals() or 'valid_loader' not in globals():
        train_loader, valid_loader = get_dataloaders(train, val, 4, ARTICLE_LENGTH, model.tokenizer)

    writer = SummaryWriter(log_dir=os.path.expanduser(cfg.log_path))
    writer.train_step, writer.eval_step = 0, 0
    optimizer, scheduler, criter = get_tools(model, 0.00001, 0.001, 5, 0.005, 1.0, 1)
    model = evaluate(model, train_loader, criter, device, writer)

    # model = evaluate_ft(model, valid_loader, criter, device, 0, writer, False,)    
    to_show_ref = pd.merge(train_df, ref_sents_df[['ref_pmid', 'sentence']],
                           left_on=['pmid'], right_on=['ref_pmid'])
    to_show_ref = to_show_ref.rename(columns=dict(sentence_x='sentence', sentence_y='ref_sentence'))
    to_show_ref = to_show_ref[['pmid', 'sentence', 'ref_sentence', 'score']]
    final_ref_show_dic = dict(pmid=[], sentence=[], ref_sentence=[], score=[])
    ite = [(pmid, sent) for pmid, sent in to_show_ref[['pmid', 'sentence']].values]

    for pmid, sent in tqdm(set(ite)):
        refs_df = to_show_ref[(to_show_ref['pmid'] == pmid) & (to_show_ref['sentence'] == sent)]
        final_ref_show_dic['pmid'].append(pmid)
        final_ref_show_dic['sentence'].append(sent)
        final_ref_show_dic['ref_sentence'].append(" ".join(refs_df['ref_sentence'].values))
        final_ref_show_dic['score'].append(refs_df['score'].values[0])
    final_ref_show_df = pd.DataFrame(final_ref_show_dic)
    final_ref_show_df.to_csv(REF_SCORES_PATH, index=False)

display(final_ref_show_df)

In [None]:
# Keep only the rows whose paper also occurs in the validation frame.
to_test = final_ref_show_df[final_ref_show_df['pmid'].isin(set(val['pmid'].values))]
display(to_test)

In [None]:
# Collect, for every paper in `to_test`, the target and predicted score of each
# sentence that the preprocessing kept.
res = dict(pmid=[], sentence=[], ref_sentences=[], score=[], res_score=[])

# Inference only: no autograd graph is needed.
with torch.no_grad():
    # `pmid` instead of `id`, so the builtin is not shadowed in the notebook namespace.
    for pmid in tqdm(set(to_test['pmid'].values)):
        ex = to_test[to_test['pmid'] == pmid]
        paper = ex['sentence'].values
        article_ids, article_mask, article_segment, n_sents = preprocess_paper_bert(
            paper, ARTICLE_LENGTH, model.tokenizer
        )
        res_sents = paper[:n_sents]
        scores = ex['score'].values[:n_sents] / 100
        # Batch of one paper.
        input_ids = torch.tensor([article_ids]).to(device)
        input_mask = torch.tensor([article_mask]).to(device)
        input_segment = torch.tensor([article_segment]).to(device)
        draft_probs = model(input_ids, input_mask, input_segment)
        predicted = draft_probs.cpu().detach().numpy()
        # The reference text of the paper's first row is recorded for each sentence.
        first_ref = ex['ref_sentence'].values[0]
        for sent, sc, res_sc in zip(res_sents, scores, predicted):
            res['pmid'].append(pmid)
            res['sentence'].append(sent)
            res['ref_sentences'].append(first_ref)
            res['score'].append(sc)
            res['res_score'].append(res_sc)
res_df = pd.DataFrame(res)

In [None]:
res_df.head()

In [None]:
res_df.to_csv(f"{cfg.base_path}/saved_example_refs.csv")

In [None]:
# Mean squared error between the target and the predicted sentence scores.
diff = ((res_df['score'].values - res_df['res_score'].values) ** 2).mean()

In [None]:
diff ** 0.5  # Square root of the MSE above, i.e. the RMSE (not the MSE itself).

# Investigate MSE
Let's look at the original score distribution to judge how good the `MSE` is.

In [None]:
import matplotlib.pyplot as plt
from collections import Counter

# Count how often each distinct score value occurs in the training frame.
# This rebinds `res`, which held the per-sentence results dict in an earlier cell.
res = Counter(list(train_df['score'].values))

# NOTE(review): this histograms the occurrence counts (`res.values()`), not the score
# values themselves, and only counts between 2 and 19 fall into a bin — confirm this
# is the intended "score distribution".
plt.hist(res.values(), bins=range(2, 20))
plt.title('Scores')
plt.show()

In [None]:
ref_sents_df.head()

In [None]:
sentences_df.head()

In [None]:
# Sentences of one hard-coded paper (pmid 14706945).
paper_ref = sentences_df[sentences_df['pmid'] == 14706945]['sentence'].values
paper_ref

In [None]:
# Distinct `ref_pmid` values found in the rows of one hard-coded paper (pmid 8300981).
papers_to_check = list(set(ref_sents_df[ref_sents_df['pmid'] == 8300981]['ref_pmid'].values))
papers_to_check

## Testing the summarization of several papers into one review

In [None]:
# Replace the notebook-level model with a `Summarizer` and load the saved weights.
model = Summarizer('bert', ARTICLE_LENGTH)
# model.load('bert_sum')
model.load('learn_simple_berta')
model.froze_backbone("froze_all")
model.unfroze_head()

In [None]:
# Inference mode, then move to cuda:0 when available, otherwise to the CPU.
model.eval()
model, device = setup_cuda_device(model)

In [None]:
# For every review paper: score the sentences of the papers it references, take the
# top-k sentences for k = 5, 10, ..., 100 and compare them to the review text with ROUGE.
# Features-related evaluation is commented right now.
test_stat = dict(rev_pmid=[], sent_num=[], true_rouge=[], diff_papers=[])  #'rouge': [],
inter = set(sentences_df['pmid'].values) & set(ref_sents_df['ref_pmid'].values)
review_papers = list(set(ref_sents_df[ref_sents_df['ref_pmid'].isin(inter)]['pmid'].values))
print('Review papers', len(review_papers))

for rev_id in tqdm(review_papers):
    paper_ref = sentences_df[sentences_df['pmid'] == rev_id]['sentence'].values
    papers_to_check = list(set(ref_sents_df[ref_sents_df['pmid'] == rev_id]['ref_pmid'].values))
    result = {'pmid': [], 'sentence': [], 'score': []}
    for paper_id in papers_to_check:
        ex = test[test['pmid'] == paper_id]
        paper = ex['sentence'].values
        #features = np.nan_to_num(ex[['sent_id', 'sent_type', 'r_abs',
        #                   'num_refs', 'mean_r_fig', 'mean_r_tab',
        #                   'min_r_fig', 'min_r_tab',
        #                   'max_r_fig', 'max_r_tab']].values.astype(float))
        total_sents = 0
        while total_sents < len(paper):
            # Start each window 5 sentences before the part already covered.
            magic = max(0, total_sents - 5)
            article_ids, article_mask, article_segment, n_sents = preprocess_paper_bert(
                paper[magic:], ARTICLE_LENGTH, model.tokenizer
            )
            if n_sents <= 5:
                # The window would not reach past the overlap: move on by one sentence.
                total_sents += 1
                continue
            old_total = total_sents
            total_sents = magic + n_sents
            input_ids = torch.tensor([article_ids]).to(device)
            input_mask = torch.tensor([article_mask]).to(device)
            input_segment = torch.tensor([article_segment]).to(device)
            #input_features = [torch.tensor(e, dtype=torch.float) for e in features[magic:total_sents]]
            #input_features = torch.stack(input_features).to(device)
            # Inference only: no autograd graph is needed.
            with torch.no_grad():
                draft_probs = model(
                    input_ids, input_mask, input_segment,  #input_features,
                )
            # Record only the sentences that no earlier window has covered.
            result['pmid'].extend([paper_id] * (total_sents - old_total))
            result['sentence'].extend(paper[old_total:total_sents])
            result['score'].extend(draft_probs.cpu().detach().numpy()[old_total - magic:])
    res_df = pd.DataFrame(result)
    sorted_arr = sorted(res_df['score'].values)
    for i in range(5, 103, 5):
        if len(sorted_arr) < i:
            break
        # Score of the i-th best sentence; ties at the threshold are all kept.
        threshold = sorted_arr[-i]
        final_text = res_df[res_df['score'] >= threshold][['pmid', 'sentence']]
        real_score = get_rouge(" ".join(final_text['sentence'].values), " ".join(paper_ref))
        test_stat['rev_pmid'].append(rev_id)
        test_stat['sent_num'].append(i)
        test_stat['true_rouge'].append(real_score)
        test_stat['diff_papers'].append(len(set(final_text['pmid'])))


In [None]:
# Sanity check: all columns of `test_stat` should have the same length.
print(*[len(arr) for key, arr in test_stat.items()])

In [None]:
test_stat_df = pd.DataFrame(test_stat)

In [None]:
test_stat_df

In [None]:
# Exclude one hard-coded review paper (reason not recorded here).
test_stat_df = test_stat_df[test_stat_df['rev_pmid'] != 29574033]

In [None]:
test_stat_df.to_csv(f"{cfg.base_path}/simple_right_test_on_review.csv", index=False)

In [None]:
# The cells below inspect `result`, `res_df` and `paper_ref` as left by the last
# iteration of the review loop above.
len(result['sentence'])

In [None]:
len(result['score'])

In [None]:
res_df = pd.DataFrame(result)

In [None]:
res_df

In [None]:
# Fifth-highest score; raises IndexError when fewer than five sentences were scored.
threshold = sorted(list(res_df['score'].values))[-5]

In [None]:
final_text = res_df[res_df['score'] >= threshold][['pmid', 'sentence']]

In [None]:
# Number of distinct papers contributing to the selected sentences.
len(set(final_text['pmid']))

In [None]:
" ".join(final_text['sentence'].values)

In [None]:
# Mean ROUGE over all (selected sentence, review sentence) pairs.
# Raises ZeroDivisionError when either side is empty.
mean_score = 0
num = 0
for sent in final_text['sentence'].values:
    for ref_sent in paper_ref:
        mean_score += get_rouge(sent, ref_sent)
        num += 1
mean_score /= num

In [None]:
mean_score

In [None]:
# NOTE(review): iterating a DataFrame yields its column labels, so this evaluates to
# "pmid sentence" rather than the joined sentences — see the cell using ['sentence'] above.
" ".join(final_text)

In [None]:
# Papers that have sentences and are also referenced by some review.
inter = list(set(sentences_df['pmid'].values) & set(ref_sents_df['ref_pmid'].values))

In [None]:
26194312 in inter

In [None]:
ref_sents_df[ref_sents_df['ref_pmid'] == 25559091]

# Quality plots

In [None]:
# Stack two result frames for the comparison plots below.
# NOTE(review): `df_1` and `df_2` are not defined in this part of the notebook;
# the plots expect `sent_num`, `rouge`, `true_rouge` and `model` columns.
draw_df = pd.concat([df_1, df_2])
draw_df.head()

In [None]:
# Mean and standard deviation of ROUGE and of the number of contributing papers
# for every sentence budget 5, 10, ..., 100 (the x values used by the plots below).
sent_nums = list(range(5, 103, 5))

# Where `test_stat` is built, the 'rouge' entry is commented out and only
# 'true_rouge' is filled, so fall back to it when 'rouge' is absent.
rouge_col = 'rouge' if 'rouge' in test_stat_df.columns else 'true_rouge'

# Group once instead of regrouping the frame for every budget; reindex keeps one
# entry per budget (NaN where no review reached it) instead of raising KeyError.
summary = (
    test_stat_df
    .groupby('sent_num')[[rouge_col, 'diff_papers']]
    .agg(['mean', 'std'])
    .reindex(sent_nums)
)

rouge_means = summary[(rouge_col, 'mean')].tolist()
rouge_err = summary[(rouge_col, 'std')].tolist()
papers_means = summary[('diff_papers', 'mean')].tolist()
papers_err = summary[('diff_papers', 'std')].tolist()


In [None]:
import matplotlib.pyplot as plt

In [None]:
# Mean ROUGE per sentence budget, with the standard deviation as error bars.
plt.errorbar(list(range(5, 103, 5)), rouge_means, yerr=rouge_err, fmt='-o')

In [None]:
# Mean number of contributing papers per sentence budget, with error bars.
plt.errorbar(list(range(5, 103, 5)), papers_means, yerr=papers_err, fmt='-o')

In [None]:
import seaborn as sns

In [None]:
# Axis labels are in Russian: "NUMBER OF SENTENCES" / "ROUGE, %".
sns.catplot(x="sent_num", y="rouge", kind="box", hue='model', aspect=1.7, color='lightblue',
            data=draw_df).set_axis_labels("ЧИСЛО ПРЕДЛОЖЕНИЙ", "ROUGE, %")

In [None]:
# Axis labels are in Russian: "NUMBER OF SENTENCES" / "NUMBER OF PAPERS IN THE SUMMARY".
sns.catplot(x="sent_num", y="diff_papers", kind="box", aspect=1.5, color='lightblue',
            data=test_stat_df).set_axis_labels("ЧИСЛО ПРЕДЛОЖЕНИЙ", "ЧИСЛО СТАТЕЙ В РЕЗЮМЕ")

In [None]:
# Axis labels are in Russian: "NUMBER OF SENTENCES" / "ROUGE, %".
sns.catplot(x="sent_num", y="true_rouge", kind="box", aspect=1.7, hue='model', color='lightblue',
            data=draw_df).set_axis_labels("ЧИСЛО ПРЕДЛОЖЕНИЙ", "ROUGE, %")