# Training&Evaluation of a developed algorithm

In this python notebook you can try several possible architectures and train&evaluate them.

0. Utilities

In [None]:
import logging
import numpy as np
import pandas as pd
import random
import sys
import torch
from tqdm.auto import tqdm

import pysrc.review.config as cfg

logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s: %(message)s')

# Used memory analysis utility
def sizeof_fmt(num, suffix='B'):
    """Format a byte count as a human-readable string with binary prefixes.

    :param num: size to format (number of bytes, may be negative)
    :param suffix: unit suffix appended after the binary prefix
    :return: str such as "1.5 KiB"; values of 1024**8 and above are reported in 'Yi'
    """
    prefixes = ('', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi')
    position = 0
    while position < len(prefixes):
        if abs(num) < 1024.0:
            return "%3.1f %s%s" % (num, prefixes[position], suffix)
        # Too large for this prefix: scale down and try the next one.
        num /= 1024.0
        position += 1
    return "%.1f %s%s" % (num, 'Yi', suffix)

# Print all allocated variables
def print_mem_usage():
    """Print the ten largest global and the ten largest local variables.

    Sizes come from `sys.getsizeof` and are formatted with `sizeof_fmt`.

    NOTE(review): `locals()` is evaluated inside this function, so the second
    loop reports this function's own locals rather than the caller's — confirm
    that this is the intended behaviour.
    """
    # Globals, sorted by reported size (largest first), top 10 only.
    for name, size in sorted(((name, sys.getsizeof(value)) for name, value in globals().items()),
                             key= lambda x: x[1],
                             reverse=True)[:10]:
        print("Global {:>30}: {:>8}".format(name, sizeof_fmt(size)))    
    # Locals, same ordering and cut-off.
    for name, size in sorted(((name, sys.getsizeof(value)) for name, value in locals().items()),
                             key= lambda x: x[1],
                             reverse=True)[:10]:
        print("Local {:>30}: {:>8}".format(name, sizeof_fmt(size)))

def init_seed(seed):
    """Seed every random number generator used in this notebook with `seed`.

    Covers Python's `random`, NumPy's global generator, and PyTorch (CPU and all CUDA devices).

    :param seed: int
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)

init_seed(cfg.seed)

In [None]:
# Check if CUDA is available
print(torch.cuda.is_available())

1. Create a `Summarizer`&`Classifier` class.

It has several options to set up: 
* with or without features (right now without features works better), 
* `BERT` or `roberta` as basis (no big difference), 

You can also choose `frozen_strategy`:
* `froze_all` in case you don't want to improve bert layers but only the summarization layer, 
* `unfroze_last4` -- fine-tunes only the last four encoder layers of the backbone (layers 8-11), so training is still reasonably fast, 
* `unfroze_all` -- fine-tunes the whole backbone; training is slow, but the results may be better

In [None]:
import os
import torch.nn as nn
import torch
import numpy as np
from transformers import BertModel, RobertaModel
from collections import namedtuple
from transformers import BertTokenizer, RobertaTokenizer
from pathlib import Path

import pysrc.review.config as cfg
from pysrc.review.utils import get_ids_mask


# A special token: `tkn` is the token string, `idx` is its id in the tokenizer vocabulary.
SpecToken = namedtuple('SpecToken', ['tkn', 'idx'])
# Look up the vocabulary id of a single token string with the given tokenizer.
ConvertToken2Id = lambda tokenizer, tkn: tokenizer.convert_tokens_to_ids([tkn])[0]


class Summarizer(nn.Module):
    """Sentence scorer built on a pretrained BERT/RoBERTa backbone.

    Each sentence of the input starts with the backbone's BOS token (`artBOS`).
    The backbone embedding at every BOS position (optionally concatenated with
    an embedding of hand-crafted sentence features) is passed to a `Classifier`
    head that outputs one score in (0, 1) per sentence.
    """

    enc_output: torch.Tensor
    # NOTE: defined on the class, so the array is shared by all instances until rebound.
    rouges_values: np.array = np.zeros(4)
    dec_ids_mask: torch.Tensor
    encdec_ids_mask: torch.Tensor

    def __init__(self, model_type, article_len, with_features=False, num_features=10):
        """
        :param model_type: 'bert' or 'roberta'
        :param article_len: int, length (in tokens) of every input sequence
        :param with_features: bool, whether per-sentence features are fed to the head
        :param num_features: int, number of per-sentence features
        :raises Exception: if model_type is not supported
        """
        super(Summarizer, self).__init__()

        self.article_len = article_len

        if model_type == 'bert':
            self.backbone, self.tokenizer, BOS, EOS, PAD = self.initialize_bert()
        elif model_type == 'roberta':
            self.backbone, self.tokenizer, BOS, EOS, PAD = self.initialize_roberta()
        else:
            raise Exception(f"Wrong model_type argument: {model_type}")

        # Optional MLP mapping raw sentence features to a 50-dimensional embedding.
        if with_features:
            self.features = nn.Sequential(nn.Linear(num_features, 100),
                                          nn.ReLU(),
                                          nn.Linear(100, 100),
                                          nn.ReLU(),
                                          nn.Linear(100, 50))
        else:
            self.features = None

        self.PAD = SpecToken(PAD, ConvertToken2Id(self.tokenizer, PAD))
        self.artBOS = SpecToken(BOS, ConvertToken2Id(self.tokenizer, BOS))
        self.artEOS = SpecToken(EOS, ConvertToken2Id(self.tokenizer, EOS))

        # add special tokens tokenizer
        self.tokenizer.add_special_tokens({'additional_special_tokens': ["<sum>", "</sent>", "</sum>"]})
        self.vocab_size = len(self.tokenizer)
        self.sumBOS = SpecToken("<sum>", ConvertToken2Id(self.tokenizer, "<sum>"))
        self.sumEOS = SpecToken("</sent>", ConvertToken2Id(self.tokenizer, "</sent>"))
        self.sumEOA = SpecToken("</sum>", ConvertToken2Id(self.tokenizer, "</sum>"))
        # The embedding matrix is resized to the tokenizer size plus 200 spare rows.
        self.backbone.resize_token_embeddings(200 + self.vocab_size)

        # tokenizer: expose the special tokens on the tokenizer object itself
        self.tokenizer.PAD = self.PAD
        self.tokenizer.artBOS = self.artBOS
        self.tokenizer.artEOS = self.artEOS
        self.tokenizer.sumBOS = self.sumBOS
        self.tokenizer.sumEOS = self.sumEOS
        self.tokenizer.sumEOA = self.sumEOA
        self.vocab_size = len(self.tokenizer)

        # initialize backbone emb pulling
        def backbone_forward(input_ids, input_mask, input_segment, input_pos):
            return self.backbone(
                input_ids=input_ids,
                attention_mask=input_mask,
                token_type_ids=input_segment,
                position_ids=input_pos,
            )
        # The encoder returns only the first backbone output.
        self.encoder = lambda *args: backbone_forward(*args)[0]

        # initialize decoder; with features the head also receives the 50-d feature embedding
        if not with_features:
            self.decoder = Classifier(cfg.d_hidden)
        else:
            self.decoder = Classifier(cfg.d_hidden + 50)

    def expand_posembs_ifneed(self):
        """Grow the backbone position embeddings when `article_len` exceeds their size.

        The pretrained rows are copied into the first positions of the new
        table; the additional rows keep their fresh `nn.Embedding` initialization.
        """
        print(self.backbone.config.max_position_embeddings, self.article_len)
        if self.article_len > self.backbone.config.max_position_embeddings:
            print("OK")
            old_maxlen = self.backbone.config.max_position_embeddings
            old_w = self.backbone.embeddings.position_embeddings.weight
            logging.info(f"Backbone pos embeddings expanded from {old_maxlen} upto {self.article_len}")
            self.backbone.embeddings.position_embeddings = \
                nn.Embedding(self.article_len, self.backbone.config.hidden_size)
            self.backbone.embeddings.position_embeddings.weight[:old_maxlen].data.copy_(old_w)
            self.backbone.config.max_position_embeddings = self.article_len
        print(self.backbone.config.max_position_embeddings)

    @staticmethod
    def initialize_bert():
        """Load pretrained `bert-base-uncased`.

        :return: backbone, tokenizer, BOS token, EOS token, PAD token
        """
        backbone = BertModel.from_pretrained(
            "bert-base-uncased", output_hidden_states=False
        )
        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)

        BOS = "[CLS]"
        EOS = "[SEP]"
        PAD = "[PAD]"
        return backbone, tokenizer, BOS, EOS, PAD

    @staticmethod
    def initialize_roberta():
        """Load pretrained `roberta-base` and give it a 2-entry token type embedding.

        :return: backbone, tokenizer, BOS token, EOS token, PAD token
        """
        backbone = RobertaModel.from_pretrained(
            'roberta-base', output_hidden_states=False
        )
        # initialize token type emb, by default roberta doesn't have it
        backbone.config.type_vocab_size = 2
        backbone.embeddings.token_type_embeddings = nn.Embedding(2, backbone.config.hidden_size)
        backbone.embeddings.token_type_embeddings.weight.data.normal_(
            mean=0.0, std=backbone.config.initializer_range
        )
        tokenizer = RobertaTokenizer.from_pretrained('roberta-base', do_lower_case=True)
        BOS = "<s>"
        EOS = "</s>"
        PAD = "<pad>"
        return backbone, tokenizer, BOS, EOS, PAD

    def save(self, save_filename):
        """ Save model in filename

        Writes `<cfg.weights_path>/<save_filename>.pth`, creating the folder if needed.

        :param save_filename: str, file name without extension
        """
        state = {
            'encoder_dict': self.backbone.state_dict(),
            'decoder_dict': self.decoder.state_dict(),
        }
        if self.features is not None:
            state['features_dict'] = self.features.state_dict()
        models_folder = os.path.expanduser(cfg.weights_path)
        # os.makedirs (not the non-existent os.mkdirs) creates intermediate folders too.
        os.makedirs(models_folder, exist_ok=True)
        torch.save(state, f"{models_folder}/{save_filename}.pth")

    def load(self, load_filename):
        """ Load model weights saved by `save`

        :param load_filename: str, file name without extension inside cfg.weights_path
        """
        path = f"{os.path.expanduser(cfg.weights_path)}/{load_filename}.pth"
        # Keep tensors on the storage they are deserialized to (no device remapping).
        state = torch.load(path, map_location=lambda storage, location: storage)
        self.backbone.load_state_dict(state['encoder_dict'])
        self.decoder.load_state_dict(state['decoder_dict'])
        if self.features is not None:
            self.features.load_state_dict(state['features_dict'])

    def froze_backbone(self, froze_strategy):
        """ Select which backbone parameters are trainable

        :param froze_strategy: 'froze_all' | 'unfroze_last4' | 'unfroze_all'
        """
        assert froze_strategy in ['froze_all', 'unfroze_last4', 'unfroze_all'],\
            f"incorrect froze_strategy argument: {froze_strategy}"

        if froze_strategy == 'froze_all':
            for name, param in self.backbone.named_parameters():
                param.requires_grad_(False)

        elif froze_strategy == 'unfroze_last4':
            # Only parameters of encoder layers 8-11 stay trainable.
            for name, param in self.backbone.named_parameters():
                param.requires_grad_(True if (
                    'encoder.layer.11' in name or
                    'encoder.layer.10' in name or
                    'encoder.layer.9' in name or
                    'encoder.layer.8' in name
                ) else False)

        elif froze_strategy == 'unfroze_all':
            for param in self.backbone.parameters():
                param.requires_grad_(True)

    def unfroze_head(self):
        """ Make every parameter of the classification head trainable """
        for name, param in self.decoder.named_parameters():
            param.requires_grad_(True)

    @property
    def rouge_1(self):
        return self.rouges_values[0]

    @property
    def rouge_2(self):
        return self.rouges_values[1]

    @property
    def rouge_l(self):
        return self.rouges_values[2]

    @property
    def rouge_mean(self):
        return self.rouges_values[3]

    def forward(self, input_ids, input_mask, input_segment, input_features=None):
        """ Score every sentence of every article in the batch

        :param input_ids: torch.Size([batch_size, article_len])
        :param input_mask: torch.Size([batch_size, article_len])
        :param input_segment: torch.Size([batch_size, article_len])
        :param input_features: per-sentence features, required when the model was built
            with features; one row per BOS token in the batch
        :return:
            scores | 1-D tensor with one sigmoid score per BOS token found in input_ids
        """

        # True at the first token of every sentence.
        cls_mask = (input_ids == self.artBOS.idx)

        # position ids | torch.Size([batch_size, article_len])
        pos_ids = torch\
            .arange(0, self.article_len, dtype=torch.long, device=input_ids.device)\
            .unsqueeze(0)\
            .repeat(len(input_ids), 1)
        # extract bert embeddings | torch.Size([batch_size, article_len, d_bert])
        enc_output = self.encoder(input_ids, input_mask, input_segment, pos_ids)

        if self.features is not None:
            temp_features = self.features(input_features)
            draft_logprobs = self.decoder(torch.cat([enc_output[cls_mask], temp_features], dim=-1))
        else:
            draft_logprobs = self.decoder(enc_output[cls_mask])

        return draft_logprobs

    def evaluate(self, input_ids, input_mask, input_segment, input_features=None):
        """ Score sentences separately for every article in the batch

        :param input_ids: torch.Size([batch_size, article_len])
        :param input_mask: torch.Size([batch_size, article_len])
        :param input_segment: torch.Size([batch_size, article_len])
        :param input_features: per-sentence features, required when the model was built
            with features
        :return:
            list with one 1-D tensor of sentence scores per article
        """

        cls_mask = (input_ids == self.artBOS.idx)

        # position ids | torch.Size([batch_size, article_len])
        pos_ids = torch\
            .arange(0, self.article_len, dtype=torch.long, device=input_ids.device)\
            .unsqueeze(0)\
            .repeat(len(input_ids), 1)
        # extract bert embeddings | torch.Size([batch_size, article_len, d_bert])
        enc_output = self.encoder(input_ids, input_mask, input_segment, pos_ids)

        ans = []
        for eo, cm in zip(enc_output, cls_mask):
            if self.features is not None:
                # NOTE(review): the full `input_features` is concatenated with every
                # article, which presumably only lines up for batch_size == 1 — confirm.
                scores = self.decoder.evaluate(torch.cat([eo[cm], self.features(input_features)], dim=-1))
            else:
                scores = self.decoder.evaluate(eo[cm])
            ans.append(scores)
        return ans


class Classifier(nn.Module):
    """Scoring head: one linear unit followed by a sigmoid.

    Maps inputs of shape [..., hidden_size] to scores of shape [...] in (0, 1).
    """

    def __init__(self, hidden_size):
        super().__init__()
        self.linear1 = nn.Linear(hidden_size, 1)
        self.sigmoid = nn.Sigmoid()

    def _score(self, x):
        # Project to a single logit, drop the trailing unit dimension, squash to (0, 1).
        logits = self.linear1(x)
        return self.sigmoid(logits.squeeze(-1))

    def forward(self, x):
        """Return sigmoid scores for `x` (used during training)."""
        return self._score(x)

    def evaluate(self, x):
        """Return sigmoid scores for `x` (same computation as `forward`)."""
        return self._score(x)

2. `train_fun` -- training function for model without features. \
`train_fun_ft` -- training function for model with features.

In [None]:
from torch.optim.optimizer import Optimizer
from tqdm import tqdm
import math
from pysrc.review.utils import get_enc_lr, get_dec_lr

def backward_step(loss: torch.Tensor, optimizer: Optimizer, model: nn.Module, clip: float, amp_enabled: int):
    """Backpropagate `loss` and clip the gradients of `model` to a maximum norm of `clip`.

    `optimizer` and `amp_enabled` are accepted but not used by this function.

    :return: total gradient norm computed by `clip_grad_norm_` (before clipping)
    """
    loss.backward()
    return torch.nn.utils.clip_grad_norm_(model.parameters(), clip)

def _run_train_epoch(model, dataloader, optimizer, scheduler, criter, device, rank, writer,
                     distributed, with_features):
    """Shared implementation of one training epoch for `train_fun` and `train_fun_ft`."""
    model.train()
    model_ref = model.module if distributed else model
    n_batches = len(dataloader)

    pbar = tqdm(enumerate(dataloader), total=n_batches, leave=False, disable=rank != 0)
    for idx_batch, batch in pbar:
        # Tensors are moved to the device; list-valued entries are handled below.
        batch = [(x.to(device) if isinstance(x, torch.Tensor) else x) for x in batch]
        if with_features:
            input_ids, input_mask, input_segment, target_scores, input_features = batch
        else:
            input_ids, input_mask, input_segment, target_scores = batch

        # Per-article targets (and features) are concatenated into one flat tensor.
        target_scores = torch.cat(target_scores).to(device)
        model_inputs = [input_ids, input_mask, input_segment]
        if with_features:
            model_inputs.append(torch.cat(input_features).to(device))

        # forward pass
        draft_probs = model(*model_inputs)

        # loss
        try:
            loss = criter(draft_probs, target_scores)
        except Exception:
            # Report the offending batch, then propagate the error instead of
            # silently returning None to the caller.
            logging.error(
                f"Loss computation failed at batch {idx_batch}: "
                f"predictions {tuple(draft_probs.shape)}, targets {tuple(target_scores.shape)}, "
                f"segments {tuple(input_segment.shape)}"
            )
            raise

        # backward
        grad_norm = backward_step(loss, optimizer, model, optimizer.clip_value, amp_enabled=cfg.amp_enabled)
        grad_norm = 0 if (math.isinf(grad_norm) or math.isnan(grad_norm)) else grad_norm

        # record a loss value
        pbar.set_description(f"loss:{loss.item():.2f}")
        writer.add_scalar(f"Train/loss", loss.item(), writer.train_step)
        writer.add_scalar("Train/grad_norm", grad_norm, writer.train_step)
        writer.add_scalar("Train/lr_enc", get_enc_lr(optimizer), writer.train_step)
        writer.add_scalar("Train/lr_dec", get_dec_lr(optimizer), writer.train_step)
        writer.train_step += 1

        # make a gradient step every `accumulation_interval` batches and on the last batch
        if (idx_batch + 1) % optimizer.accumulation_interval == 0 or (idx_batch + 1) == n_batches:
            optimizer.step()
            optimizer.zero_grad()
        # NOTE(review): the scheduler advances on every batch, including batches
        # where the optimizer does not step — confirm this is intended.
        scheduler.step()

    # save model, just in case
    if rank == 0:
        model_ref.save('temp')

    return model, optimizer, scheduler, writer


def train_fun(model, dataloader, optimizer, scheduler, criter, device, rank, writer, distributed):
    """Train `model` for one epoch on batches without sentence features.

    Each batch must unpack to (input_ids, input_mask, input_segment, target_scores),
    where target_scores is a sequence of tensors.

    :param optimizer: must expose `clip_value` and `accumulation_interval`
    :param rank: process rank; only rank 0 shows the progress bar and saves the model
    :param writer: must expose `add_scalar` and a `train_step` counter
    :param distributed: if True, the model is saved through `model.module`
    :return: model, optimizer, scheduler, writer
    :raises Exception: whatever `criter` raises is logged and re-raised
    """
    return _run_train_epoch(model, dataloader, optimizer, scheduler, criter, device, rank, writer,
                            distributed, with_features=False)


def train_fun_ft(model, dataloader, optimizer, scheduler, criter, device, rank, writer, distributed):
    """Train `model` for one epoch on batches that also carry sentence features.

    Each batch must unpack to
    (input_ids, input_mask, input_segment, target_scores, input_features),
    where target_scores and input_features are sequences of tensors.
    Other parameters, return value and errors are the same as for `train_fun`.
    """
    return _run_train_epoch(model, dataloader, optimizer, scheduler, criter, device, rank, writer,
                            distributed, with_features=True)

3. `load_data` loads all the needed datafiles for building train dataset.\
The next several steps are only needed if the train/test/val datasets have not been saved yet.

In [None]:
def load_data(ft):
    """Read the raw CSV tables from `cfg.dataset_path`.

    :param ft: bool, whether to also load the tables needed for sentence features
    :return:
        (citations_df, sentences_df, review_files_df, reverse_ref_df) if not ft, otherwise
        the same four followed by (abstracts_df, filelist_df, figures_df, tables_df)
    """
    root2data = Path(os.path.expanduser(cfg.dataset_path))
    logging.info(f'Loading references dataset from {root2data}')

    def read_table(name, sep='\t'):
        # Log before and after, so slow loads and their memory footprint are visible.
        logging.info(f'Loading {name}_df')
        frame = pd.read_csv(root2data / f"{name}.csv", sep=sep)
        logging.info(sizeof_fmt(sys.getsizeof(frame)))
        return frame

    core_names = ('citations', 'sentences', 'review_files', 'reverse_ref')
    tables = [read_table(name) for name in core_names]
    if ft:
        # filelist.csv is the only comma-separated file; the rest are tab-separated.
        tables.append(read_table('abstracts'))
        tables.append(read_table('filelist', sep=','))
        tables.append(read_table('figures'))
        tables.append(read_table('tables'))
    return tuple(tables)

4. `get_rouge` function will be needed anyway. You can uncomment `rouge-l`, if you want. 

In [None]:
from rouge import Rouge

# Module-level singletons used by `get_rouge`: created once because they are reused for every sentence pair.
ROUGE_METER = Rouge()
TOKENIZER = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)

def get_rouge(sent1, sent2):
    """Mean of the ROUGE-1 and ROUGE-2 F-scores between two sentences, scaled by 100.

    Both sentences are tokenized with `TOKENIZER`; only tokens that are purely
    alphabetic or found in the string '.!,?' are kept before scoring.

    :param sent1: str
    :param sent2: str
    :return: float
    """
    def clean(sentence):
        kept = [tok for tok in TOKENIZER.tokenize(sentence) if tok.isalpha() or tok in '.!,?']
        return " ".join(kept)

    scores = ROUGE_METER.get_scores(clean(sent1), clean(sent2))[0]
    f_values = [scores[f'rouge-{x}']["f"] for x in ('1', '2')]  # , 'l')]
    return np.mean(f_values) * 100

In [None]:
get_rouge("I am scout. True!", "No you are not a scout.")

5. Several functions needed to build the datasets. If the datasets already exist, you do not have to run this cell.

In [None]:
from unidecode import unidecode
from nltk.tokenize import sent_tokenize
import re

# Literal substring replacements applied by `sent_standardize`, in insertion order.
# NOTE(review): `sent_standardize` applies them after `unidecode`, so the keys with
# non-ASCII characters presumably never match at that point — confirm.
# NOTE(review): the shorter keys ("''", "``", '´´') are replaced before their longer
# variants ("'''", "```", '´´´'), so the longer keys may never match — confirm.
REPLACE_SYMBOLS = {
    '—': '-',
    '–': '-',
    '―': '-',
    '…': '...',
    '´´': "´",
    '´´´': "´´",
    "''": "'",
    "'''": "'",
    "``": "`",
    "```": "`",
    ":": " : ",
}

def parse_sents(data):
    """Split every text in `data` into sentences and keep those longer than 3 characters.

    :param data: iterable of str
    :return: flat list of sentences, in their original order
    """
    # The filter must run over the tokenized sentences; the previous code iterated
    # the undefined name `text`. A single comprehension also avoids the quadratic
    # sum(list_of_lists, []).
    return [sent for text in data for sent in sent_tokenize(text) if len(sent) > 3]

def sent_standardize(sent):
    """Normalize one sentence.

    Transliterates with `unidecode`, replaces `xref_*` citation markers with a
    space, applies the `REPLACE_SYMBOLS` substitutions and strips the result.

    :param sent: str
    :return: str
    """
    # Applied in this order: bracketed/parenthesized groups first, bare markers last.
    xref_patterns = (
        r"\[(xref_\w*_\w\d*]*)(, xref_\w*_\w\d*)*\]",   # delete [xref,...]
        r"\( (xref_\w*_\w\d*)(; xref_\w*_\w\d*)* \)",   # delete (xref; ...)
        r"\[xref_\w*_\w\d*\]",                          # delete [xref]
        r"xref_\w*_\w\d*",                              # delete [[xref]]
    )
    cleaned = unidecode(sent)
    for pattern in xref_patterns:
        cleaned = re.sub(pattern, " ", cleaned)
    for old, new in REPLACE_SYMBOLS.items():
        cleaned = cleaned.replace(old, new)
    return cleaned.strip()


def standardize(text):
    """Apply `sent_standardize` to every sentence and drop results of 3 characters or fewer.

    :param text: iterable of str
    :return: list of str
    """
    kept = []
    for sent in text:
        cleaned = sent_standardize(sent)
        if len(cleaned) > 3:
            kept.append(cleaned)
    return kept



def preprocess_paper(paper_id, sentences_df, ref_sents_df):
    """Build the (sentences, target scores) pair for one paper.

    The target of a sentence is its mean `get_rouge` against all standardized
    sentences that cite the paper.

    :param paper_id: pmid of the paper
    :param sentences_df: DataFrame with columns 'pmid' and 'sentence'
    :param ref_sents_df: DataFrame with columns 'ref_pmid' and 'sentence'
    :return:
        (paper, scores) — lists of equal length (at most 100 sentences), or
        None when the paper has fewer than 50 usable sentences or no usable
        citing sentences
    """
    paper = standardize(sentences_df[sentences_df['pmid'] == paper_id]['sentence'])
    if len(paper) < 50:
        return None

    ref_sents = standardize(ref_sents_df[ref_sents_df['ref_pmid'] == paper_id]['sentence'])
    if not ref_sents:
        # Without citing sentences the mean is undefined (previously a ZeroDivisionError).
        return None

    # Long papers keep only their first and last 50 sentences.
    if len(paper) > 100:
        paper = list(paper[:50]) + list(paper[-50:])

    preprocessed_score = [sum(get_rouge(sent, ref_sent) for ref_sent in ref_sents) / len(ref_sents)
                          for sent in paper]
    return paper, preprocessed_score

def preprocess_paper_with_features(paper_id, sentences_df, ref_sents_df, abstracts_df, \
                                   figures_df, reverse_ref_df, tables_df):
    """Build (sentences, target scores, features) for one paper.

    For every kept sentence the feature tuple is:
    (sent_id, 1 if the sentence type is "general" else 0, ROUGE vs. the first
    abstract sentence, number of matching rows in `reverse_ref_df`,
    mean ROUGE vs. figure captions, mean ROUGE vs. table captions,
    min ROUGE vs. figure captions, min ROUGE vs. table captions,
    max ROUGE vs. figure captions, max ROUGE vs. table captions).
    A feature that cannot be computed is None.

    :param paper_id: pmid of the paper
    :param sentences_df: DataFrame with columns 'pmid', 'sentence', 'sent_id', 'type'
    :param ref_sents_df: DataFrame with columns 'ref_pmid', 'sentence'
    :param abstracts_df: DataFrame with columns 'pmid', 'abstract'
    :param figures_df: DataFrame with columns 'pmid', 'caption'
    :param reverse_ref_df: DataFrame with columns 'pmid', 'sent_type', 'sent_id'
    :param tables_df: DataFrame with columns 'pmid', 'caption'
    :return:
        (paper, scores, features) — lists of equal length (at most 100), or None
        when the paper has fewer than 50 usable sentences or a target score
        cannot be computed
    """
    # Standardize sentence by sentence so that ids and types stay aligned with the
    # sentences that survive the length filter. Rows are plain tuples, which also
    # avoids label-based pandas lookups such as `sent_ids[i]`.
    paper_rows = sentences_df[sentences_df['pmid'] == paper_id]
    rows = []
    for raw_sent, sent_id, sent_type in zip(paper_rows['sentence'], paper_rows['sent_id'], paper_rows['type']):
        sent = sent_standardize(raw_sent)
        if len(sent) > 3:
            rows.append((sent, sent_id, sent_type))

    if len(rows) < 50:
        return None

    # Long papers keep only their first and last 50 sentences.
    if len(rows) > 100:
        rows = rows[:50] + rows[-50:]

    ref_sents = standardize(ref_sents_df[ref_sents_df['ref_pmid'] == paper_id]['sentence'])
    fig_captions = standardize(figures_df[figures_df['pmid'] == paper_id]['caption'])
    tab_captions = standardize(tables_df[tables_df['pmid'] == paper_id]['caption'])
    abstract = standardize(abstracts_df[abstracts_df['pmid'] == paper_id]['abstract'])
    abstract_text = abstract[0] if abstract else None
    paper_refs = reverse_ref_df[reverse_ref_df['pmid'] == paper_id]

    def _aggregate(label, reducer, sent, text):
        # ROUGE of `sent` against every item of `text`, reduced to one number;
        # None when `text` is empty or scoring fails.
        try:
            scores = [get_rouge(sent, other) for other in text]
            return reducer(scores) if scores else None
        except Exception as e:
            logging.error(f'Exception at {label} {e}')
            return None

    def mean_rouge(sent, text):
        return _aggregate('mean_rouge', lambda scores: sum(scores) / len(scores), sent, text)

    def min_rouge(sent, text):
        return _aggregate('min_rouge', min, sent, text)

    def max_rouge(sent, text):
        return _aggregate('max_rouge', max, sent, text)

    paper, preprocessed_score, features = [], [], []
    for sent, sent_id, sent_type in rows:
        score = mean_rouge(sent, ref_sents)
        if score is None:
            return None

        abst_diff = None
        if abstract_text is not None:
            try:
                abst_diff = get_rouge(sent, abstract_text)
            except Exception as e:
                logging.error(f'Exception at preprocess_paper_with_features {e}')

        num_refs = len(paper_refs[(paper_refs['sent_type'] == sent_type) & (paper_refs['sent_id'] == sent_id)])
        paper.append(sent)
        preprocessed_score.append(score)
        features.append((sent_id, int(sent_type == "general"), abst_diff, num_refs,
                         mean_rouge(sent, fig_captions), mean_rouge(sent, tab_captions),
                         min_rouge(sent, fig_captions), min_rouge(sent, tab_captions),
                         max_rouge(sent, fig_captions), max_rouge(sent, tab_captions)))
    return paper, preprocessed_score, features

6. Download data to make datasets from it. If datasets with features are needed, the commented code here should be uncommented.

In [None]:
# Load only the four core tables (no feature tables).
citations_df, sentences_df, review_files_df, reverse_ref_df = load_data(ft=False)

# Uncomment to load additional features
# citations_df, sentences_df, review_files_df, reverse_ref_df, \
#     abstracts_df, filelist_df, figures_df, tables_df = load_data(ft=True)

logging.info('Done loading references dataset')

In [None]:
from tqdm import tqdm
import matplotlib.pyplot as plt
from collections import Counter

# Count rows of sentences_df per pmid (presumably one row per sentence — confirm)
# and plot the distribution with one bin per integer from -1 to 399.
res = Counter(list(sentences_df['pmid'].values))
plt.hist(res.values(), bins=range(-1, 400))
plt.title('Length of papers')
plt.show()

7. In case datasets are not yet created and `ref_sents_df` is also not yet created, let's create `ref_sents_df`.\
For each paper pmid there is a list of sentences from review papers in which the paper with this `pmid` is cited.

In [None]:
# Cached location of the reference sentences table; it is rebuilt only when the file is missing.
REF_SENTS_DF_PATH = f"{os.path.expanduser(cfg.base_path)}/ref_sents.csv"
if os.path.exists(REF_SENTS_DF_PATH):
    ref_sents_df = pd.read_csv(REF_SENTS_DF_PATH, sep='\t')
else:
    logging.info('Creating reference sentences dataset')
    # Join citations with reverse_ref on (pmid, ref_id) ...
    ref_sents_df = pd.merge(citations_df, reverse_ref_df, left_on = ['pmid', 'ref_id'], right_on = ['pmid', 'ref_id'])
    # ... then attach the sentence text by (pmid, sentence type, sentence id).
    ref_sents_df = pd.merge(ref_sents_df, sentences_df, left_on = ['pmid', 'sent_type', 'sent_id'], right_on = ['pmid', 'type', 'sent_id'])
    # Keep only sentences whose containing paper is listed in review_files_df.
    ref_sents_df = ref_sents_df[ref_sents_df['pmid'].isin(review_files_df['pmid'].values)]    
    ref_sents_df = ref_sents_df.drop_duplicates()
    logging.info(f'Len of unique ref_sents {len(set(ref_sents_df["ref_pmid"]))}')
    ref_sents_df = ref_sents_df[['pmid', 'ref_id', 'pub_type', 'ref_pmid', 'sent_type', 'sent_id', 'sentence']]
    ref_sents_df.to_csv(REF_SENTS_DF_PATH, sep='\t', index=False)

In [None]:
display(ref_sents_df.head())

8. To create a dataset with features use `preprocess_paper_with_features`. Otherwise, use `preprocess_paper`. 

In [None]:
def process_reference_sentences_dataset(sentences_df, ref_sents_df):
    """Preprocess every paper that has both sentences and citing sentences.

    Papers whose preprocessing raises or returns None are skipped with a warning.

    :param sentences_df: DataFrame with a 'pmid' column (passed on to the preprocessing function)
    :param ref_sents_df: DataFrame with a 'ref_pmid' column (passed on to the preprocessing function)
    :return:
        DataFrame with one row per sentence: 'pmid', 'sentence', 'score' and, when
        the preprocessing function returns features, one column per feature
    """
    inter = set(sentences_df['pmid'].values) & set(ref_sents_df['ref_pmid'].values)
    linter = len(inter)
    res = {}
    for i, pmid in enumerate(inter):
        try:
            # For dataset with additional features uncomment this
    #         temp = preprocess_paper_with_features(pmid, sentences_df, ref_sents_df, abstracts_df,
    #                                               figures_df, reverse_ref_df, tables_df)
            temp = preprocess_paper(pmid, sentences_df, ref_sents_df)
        except Exception as e:
            logging.warning(f'Error during processing {pmid} {e}')
            continue
        if temp is None:
            logging.warning(f'temp is None for {pmid}')
            continue
        res[pmid] = temp
        # Single-line progress: index, pmid and mean target score of the paper.
        print(f"\r{i}/{linter} {pmid} {np.mean(res[pmid][1])}", end="")

    logging.info(f'Successfully preprocessed {len(res)} of {len(inter)} papers')

    logging.info(f'Creating train dataset')
    feature_names = ['sent_id', 'sent_type', 'r_abs', 'num_refs',
                     'mean_r_fig', 'mean_r_tab',
                     'min_r_fig', 'min_r_tab',
                     'max_r_fig', 'max_r_tab']
    train_dic = {column: [] for column in ['pmid', 'sentence', 'score'] + feature_names}

    for pmid, stat in tqdm(res.items()):
        has_features = len(stat) != 2
        for row in zip(*stat):
            if has_features:
                sent, score, features = row
            else:
                sent, score = row
            train_dic['pmid'].append(pmid)
            train_dic['sentence'].append(sent)
            train_dic['score'].append(score)
            if has_features:
                for name, val in zip(feature_names, features):
                    train_dic[name].append(val)

    # Columns that stayed empty (e.g. features when none were computed) are dropped.
    non_empty = {}
    for column, values in train_dic.items():
        if values:
            non_empty[column] = values
    train_df = pd.DataFrame(non_empty)
    logging.info(f'Full train dataset {len(train_df)}')
    return train_df

In [None]:
# Cached location of the sentence-level dataset; it is rebuilt (and saved) only when the file is missing.
TRAIN_DATASET_PATH = f'{os.path.expanduser(cfg.base_path)}/dataset.csv'
if os.path.exists(TRAIN_DATASET_PATH):
    train_df = pd.read_csv(TRAIN_DATASET_PATH)
else:
    train_df = process_reference_sentences_dataset(sentences_df, ref_sents_df)
    train_df.to_csv(TRAIN_DATASET_PATH, index=False)

In [None]:
display(train_df.head(1))

9. A function to preprocess input text for `BERT`.

In [None]:
def preprocess_paper_bert(text, max_len, tokenizer):
    """Pack as many leading sentences of `text` as fit into `max_len` tokens.

    Every sentence is wrapped as `artBOS + tokens + artEOS`; consecutive
    sentences get alternating segment ids (0, 1, 0, ...). Packing stops at the
    first sentence that does not fit; the remainder is filled with padding.

    :param text: sequence of str (sentences)
    :param max_len: int, length of the produced sequences
    :param tokenizer: tokenizer exposing `tokenize`, `convert_tokens_to_ids` and the
        special tokens `artBOS`, `artEOS` (with `.tkn`) and `PAD` (with `.idx`)
    :return:
        ids, mask, segments — lists of length max_len;
        n_setns — number of sentences that were packed
    """
    ids, segments, segment_signature = [], [], 0
    n_setns = 0
    for sent in text:
        # Tokenize lazily: sentences after the first one that does not fit are never used,
        # and callers pass the whole remainder of a paper for every window.
        tokens = [tokenizer.artBOS.tkn] + tokenizer.tokenize(sent) + [tokenizer.artEOS.tkn]
        if len(ids) + len(tokens) > max_len:
            break
        n_setns += 1
        ids.extend(tokenizer.convert_tokens_to_ids(tokens))
        segments.extend([segment_signature] * len(tokens))
        segment_signature = (segment_signature + 1) % 2
    mask = [1] * len(ids)

    # Padding continues with the segment id the next sentence would have received.
    pad_len = max(0, max_len - len(ids))
    ids += [tokenizer.PAD.idx] * pad_len
    mask += [0] * pad_len
    segments += [segment_signature] * pad_len

    return ids, mask, segments, n_setns

10. Splitting data into train/test/val.

In [None]:
from sklearn.model_selection import train_test_split

# Split by paper id, so all sentences of one paper end up in the same subset.
# 80% of the papers go to train; the remaining 20% are split 60/40 into test and val.
train_ids, test_ids = train_test_split(list(set(train_df['pmid'].values)), test_size=0.2)
test_ids, val_ids = train_test_split(test_ids, test_size = 0.4)
train, test, val = train_df[train_df['pmid'].isin(train_ids)], \
    train_df[train_df['pmid'].isin(test_ids)], train_df[train_df['pmid'].isin(val_ids)] 

11. Dataset classes.\
`Other*` classes first preprocess all the data and then give out the batches.\
`Ordinary` data classes preprocess data before each batch to give out. 

12. Helpers for creating the model, training and evaluation. 

In [None]:
from torch.utils.data import DataLoader, Dataset, DistributedSampler


class OtherTrainDatasetFeatures(Dataset):
    """ Custom Train Dataset that can also return per-sentence features

    All papers are preprocessed up front into overlapping windows of
    sentences; `__getitem__` only returns a prepared example.
    """

    # Feature columns read from the dataframe when `with_features` is True.
    FEATURE_COLUMNS = ['sent_id', 'sent_type', 'r_abs',
                       'num_refs', 'mean_r_fig', 'mean_r_tab',
                       'min_r_fig', 'min_r_tab',
                       'max_r_fig', 'max_r_tab']

    def __init__(self, dataframe, tokenizer, article_len, with_features=False):
        """
        :param dataframe: DataFrame with columns 'pmid', 'sentence', 'score'
            (plus FEATURE_COLUMNS when with_features is True)
        :param tokenizer: tokenizer passed to `preprocess_paper_bert`
        :param article_len: int, length of every produced token sequence
        :param with_features: bool; when True every example gets a fifth element,
            the NaN-free float feature rows of the sentences in the window.
            The default keeps the previous 4-element examples.
        """
        self.df = dataframe
        self.data = []
        self.names = list(set(dataframe['pmid'].values))
        self.tokenizer = tokenizer
        self.article_len = article_len
        self.with_features = with_features

        for name in tqdm(self.names):
            ex = self.df[self.df['pmid'] == name]
            paper = ex['sentence'].values
            features = None
            if with_features:
                features = np.nan_to_num(ex[self.FEATURE_COLUMNS].values.astype(float))
            total_sents = 0
            while total_sents < len(paper):
                # Every window starts up to 5 sentences before the end of the previous one.
                magic = max(0, total_sents - 5)
                article_ids, article_mask, article_segment, n_setns = \
                    preprocess_paper_bert(paper[magic:], self.article_len, self.tokenizer)
                if n_setns <= 5:
                    # Window holds nothing beyond the overlap: move one sentence forward.
                    total_sents += 1
                    continue
                # Scores are stored on a 0-100 scale; targets are rescaled to 0-1.
                target_scores = ex['score'].values[magic:magic + n_setns] / 100
                example = (article_ids, article_mask, article_segment, target_scores)
                if with_features:
                    example = example + (features[magic:magic + n_setns],)
                self.data.append(example)
                total_sents = magic + n_setns

        self.n_examples = len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

    def __len__(self):
        return self.n_examples

class OtherTrainDataset(Dataset):
    """Eagerly preprocessed training dataset.

    All rows sharing a ``pmid`` form one paper. In the constructor every paper
    is cut into consecutive windows by ``preprocess_paper_bert``; neighbouring
    windows share up to ``OVERLAP`` sentences. ``__getitem__`` is then a plain
    list lookup returning
    ``(article_ids, article_mask, article_segment, target_scores)``, where
    ``target_scores`` is the ``score`` column of the window divided by 100.

    Args:
        dataframe: frame with at least ``pmid``, ``sentence`` and ``score`` columns.
        tokenizer: tokenizer forwarded to ``preprocess_paper_bert``.
        article_len: length limit forwarded to ``preprocess_paper_bert``.
    """

    # Number of sentences shared between two consecutive windows of a paper.
    OVERLAP = 5

    def __init__(self, dataframe, tokenizer, article_len):
        self.df = dataframe
        self.data = []
        self.names = list(set(dataframe['pmid'].values))
        self.tokenizer = tokenizer
        self.article_len = article_len

        # Split the frame by paper once instead of re-filtering it for every pmid.
        papers = {pmid: rows for pmid, rows in dataframe.groupby('pmid', sort=False)}

        for name in tqdm(self.names):
            ex = papers.get(name)
            if ex is None:
                # A key that matches no row (e.g. NaN) contributes no examples.
                continue
            paper = ex['sentence'].values
            scores = ex['score'].values
            total_sents = 0
            while total_sents < len(paper):
                start = max(0, total_sents - self.OVERLAP)
                # n_sents is used as the number of leading sentences of the
                # slice that were encoded (presumably those that fit the limit).
                article_ids, article_mask, article_segment, n_sents = \
                    preprocess_paper_bert(paper[start:], self.article_len, self.tokenizer)
                if n_sents <= self.OVERLAP:
                    # The window would add nothing beyond the overlap:
                    # slide by one sentence and retry.
                    total_sents += 1
                    continue
                target_scores = scores[start:start + n_sents] / 100
                self.data.append((article_ids, article_mask, article_segment, target_scores))
                total_sents = start + n_sents

        self.n_examples = len(self.data)

    def __getitem__(self, idx):
        """Return the precomputed example at position ``idx``."""
        return self.data[idx]

    def __len__(self):
        """Return the number of precomputed windows."""
        return self.n_examples

class TrainDataset(Dataset):
    """Lazy training dataset: one example per paper, encoded on access.

    Unlike the ``Other*`` datasets nothing is preprocessed in the constructor;
    ``preprocess_paper_bert`` runs every time an item is requested.
    """

    def __init__(self, dataframe, tokenizer, article_len):
        unique_pmids = set(dataframe['pmid'].values)
        self.df = dataframe
        self.names = list(unique_pmids)
        self.n_examples = len(unique_pmids)
        self.tokenizer = tokenizer
        self.article_len = article_len

    def __len__(self):
        """Number of distinct papers (pmids) in the frame."""
        return self.n_examples

    def __getitem__(self, idx):
        """Encode the paper at position ``idx`` and return it with its targets.

        Returns:
            ``(article_ids, article_mask, article_segment, target_scores)``;
            the targets are the first ``n_kept`` values of ``score`` divided by 100.
        """
        pmid = self.names[idx]
        rows = self.df.loc[self.df['pmid'] == pmid]
        sentences = rows['sentence'].values

        encoded = preprocess_paper_bert(sentences, self.article_len, self.tokenizer)
        article_ids, article_mask, article_segment, n_kept = encoded

        # Targets only for the sentences reported as encoded.
        target_scores = rows['score'].values[:n_kept] / 100
        return article_ids, article_mask, article_segment, target_scores


class EvalDatasetFeatures(Dataset):
    """Lazy validation/test dataset that also returns per-sentence features.

    One example per paper (all rows sharing a ``pmid``); the paper is encoded
    with ``preprocess_paper_bert`` each time it is requested.
    """

    # Columns that make up the per-sentence feature vector.
    FEATURE_COLUMNS = ['sent_id', 'sent_type', 'r_abs',
                       'num_refs', 'mean_r_fig', 'mean_r_tab',
                       'min_r_fig', 'min_r_tab',
                       'max_r_fig', 'max_r_tab']

    def __init__(self, dataframe, tokenizer, article_len):
        self.df = dataframe
        # Build the pmid set once; it provides both the order and the size.
        self.names = list(set(dataframe['pmid'].values))
        self.n_examples = len(self.names)
        # Reported through logging (not a bare print) so the number has context.
        logging.info('%s: %d papers', type(self).__name__, self.n_examples)
        self.tokenizer = tokenizer
        self.article_len = article_len

    def __getitem__(self, idx):
        """Encode the paper at position ``idx``.

        Returns:
            ``(article_ids, article_mask, article_segment, target_scores, features)``
            where targets (``score`` / 100) and features are cut to the first
            ``n_sents`` sentences reported by ``preprocess_paper_bert``.
        """
        pmid = self.names[idx]
        ex = self.df[self.df['pmid'] == pmid]
        paper = ex['sentence'].values
        # Missing feature values are replaced by 0 via nan_to_num.
        features = np.nan_to_num(ex[self.FEATURE_COLUMNS].values.astype(float))

        article_ids, article_mask, article_segment, n_sents = \
            preprocess_paper_bert(paper, self.article_len, self.tokenizer)

        # form target
        target_scores = ex['score'].values[:n_sents] / 100

        return article_ids, article_mask, article_segment, target_scores, features[:n_sents]

    @staticmethod
    def extract_gold_sents(paper, gold_ids):
        """Return the sentences of ``paper`` whose positions are in ``gold_ids``.

        ``paper`` is split with ``sent_tokenize``; the result keeps document order.
        """
        wanted = set(gold_ids)  # constant-time membership checks
        return [sent for i, sent in enumerate(sent_tokenize(paper)) if i in wanted]

    def __len__(self):
        """Number of distinct papers (pmids) in the frame."""
        return self.n_examples
    
class EvalDataset(Dataset):
    """Lazy validation/test dataset: one example per paper, encoded on access.

    All rows sharing a ``pmid`` form one paper; ``preprocess_paper_bert`` runs
    each time an item is requested.
    """

    def __init__(self, dataframe, tokenizer, article_len):
        self.df = dataframe
        # Build the pmid set once; it provides both the order and the size.
        self.names = list(set(dataframe['pmid'].values))
        self.n_examples = len(self.names)
        # Reported through logging (not a bare print) so the number has context.
        logging.info('%s: %d papers', type(self).__name__, self.n_examples)
        self.tokenizer = tokenizer
        self.article_len = article_len

    def __getitem__(self, idx):
        """Encode the paper at position ``idx``.

        Returns:
            ``(article_ids, article_mask, article_segment, target_scores)``
            where the targets are the first ``n_sents`` values of ``score``
            divided by 100.
        """
        pmid = self.names[idx]
        ex = self.df[self.df['pmid'] == pmid]
        paper = ex['sentence'].values

        article_ids, article_mask, article_segment, n_sents = \
            preprocess_paper_bert(paper, self.article_len, self.tokenizer)

        # form target
        target_scores = ex['score'].values[:n_sents] / 100

        return article_ids, article_mask, article_segment, target_scores

    @staticmethod
    def extract_gold_sents(paper, gold_ids):
        """Return the sentences of ``paper`` whose positions are in ``gold_ids``.

        ``paper`` is split with ``sent_tokenize``; the result keeps document order.
        """
        wanted = set(gold_ids)  # constant-time membership checks
        return [sent for i, sent in enumerate(sent_tokenize(paper)) if i in wanted]

    def __len__(self):
        """Number of distinct papers (pmids) in the frame."""
        return self.n_examples

In [None]:
from pysrc.review.utils import count_parameters
from transformers import AdamW
from torch.nn import BCELoss, BCEWithLogitsLoss, MSELoss

def load_model(model_type, froze_strategy, rank, article_len, features=False):
    """Build a ``Summarizer`` and configure which parts of it are trainable.

    Args:
        model_type: passed to ``Summarizer`` as its first argument.
        froze_strategy: passed to ``Summarizer.froze_backbone``.
        rank: the trainable-parameter count is printed only when this is 0.
        article_len: passed to ``Summarizer`` as its second argument.
        features: passed to ``Summarizer`` as its third argument.

    Returns:
        The configured model (no weights are loaded here).
    """
    summarizer = Summarizer(model_type, article_len, features)
    summarizer.expand_posembs_ifneed()
    # summarizer.load('temp')  # optionally restore saved weights here
    summarizer.froze_backbone(froze_strategy)
    summarizer.unfroze_head()

    is_reporting_rank = rank == 0
    if is_reporting_rank:
        print(f"Model trainable parameters: {count_parameters(summarizer)}")

    return summarizer

def train_collate_fn_ft(batch_data):
    """Collate training examples that carry per-sentence features.

    Args:
        batch_data: sequence of 5-tuples
            ``(ids, mask, segment, target_scores, features)``.

    Returns:
        A 5-tuple: three ``torch.long`` tensors obtained by stacking the ids,
        masks and segments of the batch, then two lists of ``torch.float``
        tensors (one per example) with the target scores and the features.
        Scores and features stay as lists because their length differs per example.
    """
    # The per-batch debug dump (print of the whole batch) was removed: it
    # flooded the notebook output on every training step.
    ids, masks, segments, scores, features = zip(*batch_data)
    return torch.tensor(ids, dtype=torch.long), \
        torch.tensor(masks, dtype=torch.long), \
        torch.tensor(segments, dtype=torch.long), \
        [torch.tensor(e, dtype=torch.float) for e in scores], \
        [torch.tensor(e, dtype=torch.float) for e in features]

def eval_collate_fn_ft(batch_data):
    """Collate evaluation examples that carry per-sentence features.

    Args:
        batch_data: sequence of 5-tuples
            ``(ids, mask, segment, target_scores, features)``.

    Returns:
        A 5-tuple: three ``torch.long`` tensors obtained by stacking the ids,
        masks and segments of the batch, then two lists of ``torch.float``
        tensors (one per example) with the target scores and the features.
        Scores and features stay as lists because their length differs per example.
    """
    # The per-batch debug dump (print of the whole batch) was removed: it
    # flooded the notebook output on every evaluation step.
    ids, masks, segments, scores, features = zip(*batch_data)

    return torch.tensor(ids, dtype=torch.long), \
        torch.tensor(masks, dtype=torch.long), \
        torch.tensor(segments, dtype=torch.long), \
        [torch.tensor(e, dtype=torch.float) for e in scores], \
        [torch.tensor(e, dtype=torch.float) for e in features]

def get_dataloaders(train, val, batch_size,
                    article_len, tokenizer, ddp):
    """Create the training and validation data loaders.

    The training frame is wrapped in ``OtherTrainDataset`` and the validation
    frame in ``EvalDataset``; both are handed to ``create_ddp_loader`` when
    ``ddp`` is truthy and to ``create_loader`` otherwise.

    To train with sentence features, switch to the feature-aware pieces:
    ``OtherTrainDatasetFeatures`` + ``train_collate_fn_ft`` for training
    (``dl_func(train_ds, batch_size, train_collate_fn_ft)``) and
    ``EvalDatasetFeatures`` + ``eval_collate_fn_ft`` for validation.

    Returns:
        ``(train_loader, val_loader)``.
    """
    make_loader = create_ddp_loader if ddp else create_loader

    logging.info('Creating train dataset...')
    train_dataset = OtherTrainDataset(train, tokenizer, article_len)
    logging.info('Applying loader functions to train...')
    train_loader = make_loader(train_dataset, batch_size, train_collate_fn)

    logging.info('Creating val dataset...')
    val_dataset = EvalDataset(val, tokenizer, article_len)
    logging.info('Applying loader functions to val...')
    val_loader = make_loader(val_dataset, batch_size, eval_collate_fn)

    return train_loader, val_loader


def get_tools(model, enc_lr, dec_lr, warmup,
              weight_decay, clip_value,
              accumulation_interval):
    """Create the optimizer, learning-rate scheduler and loss for training.

    Trainable parameters are split by name: those starting with ``'bert.'``
    get learning rate ``enc_lr``, all others get ``dec_lr``. ``clip_value``
    and ``accumulation_interval`` are attached to the optimizer as attributes.

    Returns:
        ``(optimizer, scheduler, criterion)`` — an ``AdamW`` optimizer, a
        ``NoamScheduler`` built with ``warmup``, and an ``MSELoss``.
    """
    encoder_params, decoder_params = [], []
    for param_name, tensor in model.named_parameters():
        if not tensor.requires_grad:
            continue  # frozen weights are not optimized
        bucket = encoder_params if param_name.startswith('bert.') else decoder_params
        bucket.append(tensor)

    param_groups = [
        {'params': encoder_params, 'lr': enc_lr},
        {'params': decoder_params, 'lr': dec_lr},
    ]
    optimizer = AdamW(param_groups, weight_decay=weight_decay)
    # Extra settings ride along on the optimizer object.
    optimizer.clip_value = clip_value
    optimizer.accumulation_interval = accumulation_interval

    return optimizer, NoamScheduler(optimizer, warmup=warmup), MSELoss()


# def setup_multi_gpu(model, optimizer, rank, size):
#     logging.info('Setup distributed settings...')
#     distrib_config = DistributedConfig(local_rank=rank, size=size, amp_enabled=cfg.amp_enabled)
#     setup_distributed(distrib_config)
#     device = choose_device(local_rank=rank)
#     model = model.to(device)
#     model, optimizer = setup_apex_if_enabled(model, optimizer, config=distrib_config)
#     model = setup_distrib_if_enabled(model, config=distrib_config)
#     return model, device, optimizer


def setup_single_gpu(model):
    """Move ``model`` to the first CUDA device if available, otherwise to the CPU.

    Returns:
        ``(model, device)`` — the result of ``model.to(device)`` and the device used.
    """
    logging.info('Setup single-device settings...')
    if torch.cuda.is_available():
        device = torch.device("cuda:0")
    else:
        device = torch.device("cpu")
    return model.to(device), device

In [None]:
def evaluate(model, dataloader, criter, device, rank, writer, distributed):
    """Run one validation pass (model without features) and report the mean loss.

    Args:
        model: callable as ``model(input_ids, input_mask, input_segment)``.
        dataloader: iterable of batches
            ``(input_ids, input_mask, input_segment, target_scores)`` where
            ``target_scores`` is a list with one tensor per example.
        criter: loss called as ``criter(predictions, targets)``.
        device: device the batch tensors are moved to.
        rank: unused; kept for interface compatibility.
        writer: summary writer; every batch loss is logged under ``Eval/loss``
            against ``writer.eval_step``, which is then incremented.
        distributed: if truthy the weights are saved through ``model.module``.

    Returns:
        The model, left in eval mode. Weights are saved to
        ``validated_weights.pth`` after every pass.
    """
    model.eval()
    model_ref = model.module if distributed else model
    loss_val = 0
    mean_sents = 0
    szs = 0
    n_batches = 0

    # Inference only: do not build the autograd graph.
    with torch.no_grad():
        for batch in dataloader:
            input_ids, input_mask, input_segment, target_scores = \
                [(x.to(device) if isinstance(x, torch.Tensor) else x) for x in batch]
            sizes = [dc.shape[0] for dc in target_scores]
            mean_sents += sum(sizes)
            szs += len(sizes)
            target_scores = torch.cat(target_scores).to(device)

            # forward pass
            draft_probs = model(
                input_ids, input_mask, input_segment,
            )

            # loss
            loss = criter(
                draft_probs,
                target_scores,
            )

            # Record the loss on the evaluation step counter; the training
            # counter must not advance during validation.
            loss_val += loss.item()
            step = getattr(writer, "eval_step", 0)
            writer.add_scalar("Eval/loss", loss.item(), step)
            writer.eval_step = step + 1
            n_batches += 1

    # Averages are undefined for an empty loader; report only when there is data.
    if n_batches:
        print("Val loss:", loss_val/n_batches)
    if szs:
        print("Mean sent len:", mean_sents/szs)
    # save model, just in case
    model_ref.save('validated_weights.pth')

    return model

def evaluate_ft(model, dataloader, criter, device, rank, writer, distributed):
    """Run one validation pass (model with features) and report the mean loss.

    Args:
        model: callable as
            ``model(input_ids, input_mask, input_segment, input_features)``.
        dataloader: iterable of batches
            ``(input_ids, input_mask, input_segment, target_scores, input_features)``
            where the last two are lists with one tensor per example.
        criter: loss called as ``criter(predictions, targets)``.
        device: device the batch tensors are moved to.
        rank: unused; kept for interface compatibility.
        writer: summary writer; every batch loss is logged under ``Eval/loss``
            against ``writer.eval_step``, which is then incremented.
        distributed: if truthy the weights are saved through ``model.module``.

    Returns:
        The model, left in eval mode. Weights are saved to
        ``validated_weights.pth`` after every pass.
    """
    model.eval()
    model_ref = model.module if distributed else model
    loss_val = 0
    mean_sents = 0
    szs = 0
    n_batches = 0

    # Inference only: do not build the autograd graph.
    with torch.no_grad():
        for batch in dataloader:
            input_ids, input_mask, input_segment, target_scores, input_features = \
                [(x.to(device) if isinstance(x, torch.Tensor) else x) for x in batch]
            sizes = [dc.shape[0] for dc in target_scores]
            mean_sents += sum(sizes)
            szs += len(sizes)
            input_features = torch.cat(input_features).to(device)
            target_scores = torch.cat(target_scores).to(device)

            # forward pass
            draft_probs = model(
                input_ids, input_mask, input_segment, input_features,
            )

            # loss
            loss = criter(
                draft_probs,
                target_scores,
            )

            # Record the loss on the evaluation step counter; the training
            # counter must not advance during validation.
            loss_val += loss.item()
            step = getattr(writer, "eval_step", 0)
            writer.add_scalar("Eval/loss", loss.item(), step)
            writer.eval_step = step + 1
            n_batches += 1

    # Averages are undefined for an empty loader; report only when there is data.
    if n_batches:
        print("Val loss:", loss_val/n_batches)
    if szs:
        print("Mean sent len:", mean_sents/szs)
    # save model, just in case
    model_ref.save('validated_weights.pth')

    return model

13. Load model

In [None]:
# Length limit handed to the model and to the preprocessing helpers.
MODEL_SIZE = 512

# BERT-based model without sentence features and with the backbone frozen.
model = load_model(
    model_type="bert",
    froze_strategy="froze_all",
    rank=0,
    article_len=MODEL_SIZE,
    features=False,
)
model, device = setup_single_gpu(model)

14. Create dataloaders and start training.

Use `train_fun_ft` for training model with additional features and `train_fun` without.\
Use `evaluate_ft` for evaluating a model with additional features and `evaluate` for a model without them.

In [None]:
from tensorboardX import SummaryWriter

MODEL_PATH = f'{os.path.expanduser(cfg.weights_path)}/learn_simple_berta.pth'

def load_or_train_model(model):
    """Load the saved weights if they exist, otherwise train the model and save it.

    If ``MODEL_PATH`` exists, the weights are restored with
    ``model.load("learn_simple_berta")`` and the model is moved to the
    available device. Otherwise the model is trained for 20 epochs on the
    global ``train`` / ``val`` frames, validated after every epoch and finally
    stored with ``model.save('learn_simple_berta')``.

    Args:
        model: model exposing ``load``, ``save`` and ``tokenizer``.

    Returns:
        The loaded or trained model.
    """
    if os.path.exists(MODEL_PATH):
        logging.info(f'Loading model {MODEL_PATH}')
        model.load("learn_simple_berta")
        model, _ = setup_single_gpu(model)
        return model

    # `device` must be bound in this branch too. It used to be assigned only
    # in the loading branch, which made it a function-local name and raised
    # UnboundLocalError as soon as training referenced it.
    model, device = setup_single_gpu(model)

    logging.info('Create dataloaders...')
    train_loader, valid_loader = get_dataloaders(
        train, val, batch_size=4, article_len=MODEL_SIZE,
        tokenizer=model.tokenizer, ddp=False,
    )

    writer = SummaryWriter(log_dir=f'{os.path.expanduser(cfg.base_path)}/logs')
    writer.train_step, writer.eval_step = 0, 0

    optimizer, scheduler, criter = get_tools(
        model, enc_lr=0.00001, dec_lr=0.001, warmup=5,
        weight_decay=0.005, clip_value=1.0, accumulation_interval=1,
    )

    for epoch in range(1, 20 + 1):
        logging.info(f"{epoch} epoch training...")
        # Use train_fun_ft / evaluate_ft instead for a model with features.
        model, optimizer, scheduler, writer = train_fun(
            model, train_loader, optimizer, scheduler,
            criter, device, 0, writer, False
        )

        logging.info(f"{epoch} epoch validation...")
        model = evaluate(model, valid_loader, criter, device, 0, writer, False)

    # Save trained model
    model.save('learn_simple_berta')
    return model

15. See the example of how model works:

In [None]:
model = load_or_train_model(model)
model_ref = model

In [None]:
display(val.head())

In [None]:
# Take the paper that the first validation row belongs to and encode it.
first_pmid = val['pmid'].iloc[0]
ex = val.loc[val['pmid'] == first_pmid]
paper = ex['sentence'].values
encoded_example = preprocess_paper_bert(paper, MODEL_SIZE, model_ref.tokenizer)
article_ids, article_mask, article_segment, n_sents = encoded_example
# Keep only the sentences (and their scores) reported as encoded.
res_sents, scores = paper[:n_sents], ex['score'].values[:n_sents]

In [None]:
# Batch of one: wrap each encoded sequence in an outer list.
input_ids = torch.tensor([article_ids]).to(device)
input_mask = torch.tensor([article_mask]).to(device)
input_segment = torch.tensor([article_segment]).to(device)

# A model that was just loaded from disk has never been switched to eval
# mode; do it here and skip gradient tracking for this inference call.
model.eval()
with torch.no_grad():
    draft_probs = model(input_ids, input_mask, input_segment)

In [None]:
# Side-by-side view of target scores (rescaled by 1/100) and model predictions.
predicted_scores = draft_probs.detach().cpu().numpy()
ideal_scores = scores / 100
to_show_df = pd.DataFrame(
    {'sentence': res_sents, 'ideal_score': ideal_scores, 'res_score': predicted_scores}
)
display(to_show_df.head())

In [None]:
' '.join(to_show_df[to_show_df['res_score'] > 0.07]['sentence'].values)

In [None]:
to_show_df.to_csv(f'{cfg.base_path}/show_scores.csv')

16. Evaluation.\
First, let's see if the model trained well.\
Then we compute the error between the predicted and the target scores (`MSE`, reported as its square root) on a real example.

In [None]:
# Fallback for running the evaluation section on its own: rebuild the model
# and restore the trained weights when no `model` exists in the kernel yet.
if 'model' not in globals():
    model = load_model(
        model_type="bert",
        froze_strategy="froze_all",
        rank=0,
        article_len=MODEL_SIZE,
        features=False,
    )
    model.load("learn_simple_berta")
    model, device = setup_single_gpu(model)
    model_ref = model

In [None]:
from tensorboardX import SummaryWriter
# Cache file with one row per (pmid, sentence): the sentence, its score and
# all reference sentences attached to that paper joined into one string.
REF_SCORES_PATH = os.path.expanduser(f"{cfg.base_path}/refs_and_scores.csv")

if os.path.exists(REF_SCORES_PATH):
    # Read exactly the path that was checked (and is written below).
    final_ref_show_df = pd.read_csv(REF_SCORES_PATH)
else:
    if 'train_loader' not in globals() or 'valid_loader' not in globals():
        train_loader, valid_loader = get_dataloaders(train, val, 4, MODEL_SIZE, model_ref.tokenizer, ddp=False)

    writer = SummaryWriter(log_dir=os.path.expanduser(cfg.log_dir))
    writer.train_step, writer.eval_step = 0, 0
    # Only the loss criterion is needed here; optimizer and scheduler are unused.
    optimizer, scheduler, criter = get_tools(model, 0.00001, 0.001, 5, 0.005, 1.0, 1)
    # Sanity check on the training data: did the model fit it at all?
    model = evaluate(model, train_loader, criter, device, 0, writer, False,)
    # model = evaluate_ft(model, valid_loader, criter, device, 0, writer, False,)

    # Attach to every training sentence the sentences that reference its paper.
    to_show_ref = pd.merge(train_df, ref_sents_df[['ref_pmid', 'sentence']],
                           left_on=['pmid'], right_on=['ref_pmid'])
    to_show_ref = to_show_ref.rename(columns={'sentence_x': 'sentence', 'sentence_y': 'ref_sentence'})
    to_show_ref = to_show_ref[['pmid', 'sentence', 'ref_sentence', 'score']]

    # Collapse to one row per (pmid, sentence) in a single pass. sort=False
    # keeps the sentences in their original order; the previous loop iterated
    # over a set and re-filtered the whole frame for every pair.
    final_ref_show_df = (
        to_show_ref
        .groupby(['pmid', 'sentence'], sort=False)
        .agg(ref_sentence=('ref_sentence', ' '.join),
             score=('score', lambda s: s.iloc[0]))
        .reset_index()
    )
    final_ref_show_df = final_ref_show_df[['pmid', 'sentence', 'ref_sentence', 'score']]
    final_ref_show_df.to_csv(REF_SCORES_PATH, index=False)

display(final_ref_show_df)

In [None]:
to_test = final_ref_show_df[final_ref_show_df['pmid'].isin(set(val['pmid'].values))]
display(to_test)

In [None]:
# Score every sentence of every validation paper and collect the predictions
# next to the target scores.
res = {'pmid': [], 'sentence': [], 'ref_sentences': [], 'score': [], 'res_score': []}

# Inference only: eval mode, no gradient tracking.
model.eval()
with torch.no_grad():
    # `pmid` instead of `id`, which shadowed the builtin.
    for pmid in tqdm(set(to_test['pmid'].values)):
        ex = to_test[to_test['pmid'] == pmid]
        paper = ex['sentence'].values
        article_ids, article_mask, article_segment, n_sents = \
            preprocess_paper_bert(paper, MODEL_SIZE, model_ref.tokenizer)
        res_sents = paper[:n_sents]
        scores = ex['score'].values[:n_sents] / 100
        input_ids = torch.tensor([article_ids]).to(device)
        input_mask = torch.tensor([article_mask]).to(device)
        input_segment = torch.tensor([article_segment]).to(device)
        draft_probs = model(input_ids, input_mask, input_segment,)
        # The reference text of the paper's first row is stored for every sentence.
        ref_text = ex['ref_sentence'].values[0]
        for sent, sc, res_sc in zip(res_sents, scores, draft_probs.cpu().numpy()):
            res['pmid'].append(pmid)
            res['sentence'].append(sent)
            res['ref_sentences'].append(ref_text)
            res['score'].append(sc)
            res['res_score'].append(res_sc)
res_df = pd.DataFrame(res)

In [None]:
res_df

In [None]:
res_df.to_csv(f"{cfg.base_path}/saved_example_refs.csv")

In [None]:
diff = ((res_df['score'].values - res_df['res_score'].values)**2).mean()

In [None]:
diff**0.5  # RMSE: `diff` holds the mean squared error, this is its square root

## TODO
Clean up the code below.

17. Let's look at the distribution of the original scores to put the `MSE` value into perspective. 

In [None]:
import matplotlib.pyplot as plt
from collections import Counter

# Count how many times each distinct score value occurs in the training data.
# Note: this rebinds `res`, which an earlier cell uses for the results dict.
res = Counter(list(train_df['score'].values))

# Histogram of the occurrence counts (the Counter values), not of the scores.
# NOTE(review): the section title talks about the score distribution — confirm
# whether a histogram of train_df['score'] itself was intended.
plt.hist(res.values(), bins=range(2, 20))
plt.show()

In [None]:
ref_sents_df

In [None]:
paper_ref = sentences_df[sentences_df['pmid']==26194312]['sentence'].values

In [None]:
papers_to_check = list(set(ref_sents_df[ref_sents_df['pmid']==26194312]['ref_pmid'].values))

In [None]:
len(papers_to_check)

18. Test: summarizing several papers into a single review. 

In [None]:
# from .model import Summarizer as BertSum

In [None]:
# Rebuild the model from scratch; this replaces the `model` used by the cells above.
# NOTE(review): the load call below is commented out, so no saved weights are
# restored here — confirm this is intended before trusting the numbers below.
model = Summarizer('bert', MODEL_SIZE)
# model.load('bert_sum')
model.froze_backbone("froze_all")
model.unfroze_head()

In [None]:
model.eval()
model, device = setup_single_gpu(model)

In [None]:
model_ref = model

In [None]:
# Review-level evaluation: for every review paper, score the sentences of the
# papers it references, keep the top-k sentences and compare them with the
# review text via ROUGE.
# Features-related evaluation is disabled right now (the model is called
# without sentence features).
test_stat = {'rev_pmid': [], 'sent_num': [], 'true_rouge': [], 'diff_papers': []}
# pmids present both in sentences_df and among the referenced papers.
inter = set(sentences_df['pmid'].values) & set(ref_sents_df['ref_pmid'].values)
review_papers = list(set(ref_sents_df[ref_sents_df['ref_pmid'].isin(inter)]['pmid'].values))
print('Review papers', len(review_papers))

for rev_id in tqdm(review_papers):
    paper_ref = sentences_df[sentences_df['pmid'] == rev_id]['sentence'].values
    papers_to_check = list(set(ref_sents_df[ref_sents_df['pmid'] == rev_id]['ref_pmid'].values))
    result = {'pmid': [], 'sentence': [], 'score': []}
    for paper_id in papers_to_check:
        ex = test[test['pmid'] == paper_id]
        paper = ex['sentence'].values
        total_sents = 0
        # Slide over the paper in windows that overlap by up to 5 sentences.
        while total_sents < len(paper):
            magic = max(0, total_sents - 5)
            article_ids, article_mask, article_segment, n_setns = \
                preprocess_paper_bert(paper[magic:], MODEL_SIZE, model_ref.tokenizer)
            if n_setns <= 5:
                # Nothing new beyond the overlap: move on by one sentence.
                total_sents += 1
                continue
            old_total = total_sents
            total_sents = magic + n_setns
            input_ids = torch.tensor([article_ids]).to(device)
            input_mask = torch.tensor([article_mask]).to(device)
            input_segment = torch.tensor([article_segment]).to(device)
            # Inference only: no autograd graph, otherwise memory grows with
            # every window that is scored.
            with torch.no_grad():
                draft_probs = model(input_ids, input_mask, input_segment)
            # Keep only the sentences not already covered by the previous window.
            result['pmid'].extend([paper_id] * (total_sents - old_total))
            result['sentence'].extend(list(paper[old_total:total_sents]))
            result['score'].extend(list(draft_probs.cpu().numpy())[old_total - magic:])
    res_df = pd.DataFrame(result)
    sorted_arr = sorted(list(res_df['score'].values))
    # Evaluate summaries built from the 5, 10, ..., 100 best-scored sentences.
    for i in range(5, 103, 5):
        if len(sorted_arr) < i:
            break
        treshold = sorted_arr[-i]
        final_text = res_df[res_df['score'] >= treshold][['pmid', 'sentence']]
        real_score = get_rouge(" ".join(final_text['sentence'].values), " ".join(paper_ref))
        test_stat['rev_pmid'].append(rev_id)
        test_stat['sent_num'].append(i)
        test_stat['true_rouge'].append(real_score)
        test_stat['diff_papers'].append(len(set(final_text['pmid'])))
            

In [None]:
print(*[len(arr) for key, arr in test_stat.items()])

In [None]:
test_stat_df = pd.DataFrame(test_stat)

In [None]:
test_stat_df

In [None]:
test_stat_df = test_stat_df[test_stat_df['rev_pmid'] != 29574033]

In [None]:
test_stat_df.to_csv(f"{cfg.base_path}/simple_right_test_on_review.csv", index=False)

In [None]:
len(result['sentence'])

In [None]:
len(result['score'])

In [None]:
res_df = pd.DataFrame(result)

In [None]:
res_df

In [None]:
treshold = sorted(list(res_df['score'].values))[-5]

In [None]:
final_text = res_df[res_df['score'] >= treshold][['pmid', 'sentence']]

In [None]:
len(set(final_text['pmid']))

In [None]:
" ".join(final_text['sentence'].values)

In [None]:
# Mean ROUGE over all (selected sentence, review sentence) pairs.
pair_scores = [
    get_rouge(sent, ref_sent)
    for sent in final_text['sentence'].values
    for ref_sent in paper_ref
]
num = len(pair_scores)
# With no pairs the mean is undefined: report NaN instead of dividing by zero.
mean_score = sum(pair_scores) / num if num else float('nan')

In [None]:
mean_score

In [None]:
# Joining the DataFrame itself would join its column names; join the sentences.
" ".join(final_text['sentence'].values)

In [None]:
inter = list(set(sentences_df['pmid'].values) & set(ref_sents_df['ref_pmid'].values))

In [None]:
26194312 in inter

In [None]:
ref_sents_df[ref_sents_df['ref_pmid']==25559091]

In [None]:
import pandas as pd

test_stat_df = pd.read_csv(f"{cfg.base_path}/simple_right_test_on_review.csv")
ft_test_stat_df = pd.read_csv(f"{cfg.base_path}/bertsum_test_on_review.csv")

In [None]:
test_stat_df

In [None]:
ft_test_stat_df

In [None]:
df_1 = test_stat_df.assign(model = ['Основная модель']*len(test_stat_df))

In [None]:
df_2 = ft_test_stat_df.assign(model = ['BERTSUM']*len(ft_test_stat_df))

In [None]:
draw_df = pd.concat([df_1, df_2])
draw_df.head()

In [None]:
# Mean and standard deviation of ROUGE and of the number of distinct source
# papers, for every summary size (5, 10, ..., 100 sentences).
sent_nums = list(range(5, 103, 5))

# The evaluation loop in this notebook stores the score as 'true_rouge';
# 'rouge' is still preferred when the loaded file provides it.
rouge_col = 'rouge' if 'rouge' in test_stat_df.columns else 'true_rouge'

# One groupby for all sizes. reindex yields NaN for sizes that no review
# reached, where get_group would have raised a KeyError.
per_sent_num = (
    test_stat_df
    .groupby('sent_num')[[rouge_col, 'diff_papers']]
    .agg(['mean', 'std'])
    .reindex(sent_nums)
)

rouge_means = per_sent_num[(rouge_col, 'mean')].tolist()
rouge_err = per_sent_num[(rouge_col, 'std')].tolist()
papers_means = per_sent_num[('diff_papers', 'mean')].tolist()
papers_err = per_sent_num[('diff_papers', 'std')].tolist()
    

In [None]:
import matplotlib.pyplot as plt

In [None]:
plt.errorbar(list(range(5, 103, 5)), rouge_means, yerr=rouge_err, fmt='-o')

In [None]:
plt.errorbar(list(range(5, 103, 5)), papers_means, yerr=papers_err, fmt='-o')

In [None]:
import seaborn as sns

In [None]:
sns.catplot(x="sent_num", y="rouge", kind="box", hue='model', aspect=1.7, color='lightblue', data=draw_df).set_axis_labels("ЧИСЛО ПРЕДЛОЖЕНИЙ", "ROUGE, %")

In [None]:
sns.catplot(x="sent_num", y="diff_papers", kind="box", aspect=1.5, color = 'lightblue', data=test_stat_df).set_axis_labels("ЧИСЛО ПРЕДЛОЖЕНИЙ", "ЧИСЛО СТАТЕЙ В РЕЗЮМЕ")

In [None]:
sns.catplot(x="sent_num", y="true_rouge", kind="box", aspect=1.7, hue='model', color='lightblue', data=draw_df).set_axis_labels("ЧИСЛО ПРЕДЛОЖЕНИЙ", "ROUGE, %")