In [None]:
import heapq
from collections import defaultdict
from importlib import import_module
from pathlib import Path
from typing import Any, Callable, List, Tuple, Dict

import torch
import torch.nn as nn
from torch import Tensor
from tqdm.auto import tqdm

import neural.common.utils as utils
from neural.common.data.vocab import Vocab, VocabBuilder
from neural.common.scores import ROUGE, F1Score
from neural.ner.dataloader import NERDataset, NERDataLoader
from neural.summarization.dataloader import SummarizationDataset, SummarizationDataLoader
from utils.database import DatabaseConnector

In [None]:
# Root directory of the saved models; each evaluation cell below appends a model sub-directory to it.
models_path = Path('../data/saved/models')
# Device used by load_trained_model and the show_examples_* helpers (read as a notebook-level global).
# NOTE(review): use_cuda=True requests CUDA; whether get_device falls back to CPU is not visible here — confirm.
device = utils.get_device(use_cuda=True)
# Seed the random generators through the project helper so repeated runs are comparable.
utils.set_random_seed(0)

In [None]:
def load_trained_model(model_id: str, path_to_model: Path, weights_name: str = None, **params: Any) -> nn.Module:
    """Rebuild a trained model from its saved arguments and weights, ready for inference.

    The training module ``neural.train_<model_id>`` is imported and its
    ``create_model_from_args`` factory is called with the arguments stored under
    ``path_to_model``. Weights are read from ``<path_to_model>/<stem>.pt`` where
    ``<stem>`` is ``path_to_model.stem``.

    :param model_id: suffix of the training module to import (``neural.train_<model_id>``).
    :param path_to_model: directory holding the saved arguments and the ``.pt`` file.
    :param weights_name: prefix of the ``<prefix>_state_dict`` entry inside the checkpoint;
        falls back to the directory stem when omitted or empty.
    :param params: extra keyword arguments forwarded to ``create_model_from_args``.
    :return: the model, moved to the notebook-level ``device`` and switched to eval mode.
    """
    train_module = import_module(f'neural.train_{model_id}')
    saved_args = utils.load_args_from_file(path_to_model)
    directory_stem = path_to_model.stem

    network = train_module.create_model_from_args(saved_args, **params)

    # The checkpoint is a mapping that stores the state dict under a '<name>_state_dict' key.
    checkpoint = torch.load(path_to_model / f'{directory_stem}.pt', map_location=device)
    state_key = f'{weights_name or directory_stem}_state_dict'
    network.load_state_dict(checkpoint[state_key])

    network.to(device)
    network.eval()
    # Drop the checkpoint reference explicitly so its tensors can be released right away.
    del checkpoint

    return network


def count_model_parameters(model: nn.Module) -> None:
    """Print the number of learnable and constant (non-trainable) parameters of ``model``.

    Parameters are grouped by name with everything from the first ``.weight`` /
    ``.bias`` onwards stripped, so a layer's weight and bias are reported as a single
    entry. Each non-empty group is printed together with its total.

    The model is switched to training mode while the parameters are inspected and is
    afterwards returned to the mode it was in on entry, so calling this function no
    longer silently forces a training model into eval mode.

    :param model: module whose parameters are counted.
    """
    learnable_parameters = defaultdict(int)
    constant_parameters = defaultdict(int)

    was_training = model.training
    # NOTE(review): train() is called before reading requires_grad, presumably because some
    # project models change which parameters are frozen in train() — confirm.
    model.train()
    try:
        for name, parameter in model.named_parameters():
            parameter_name = name.split('.weight')[0].split('.bias')[0]
            bucket = learnable_parameters if parameter.requires_grad else constant_parameters
            bucket[parameter_name] += torch.numel(parameter)
    finally:
        # Restore the caller's mode; a model that arrived in eval mode gets the same
        # train() -> eval() sequence as before.
        if not was_training:
            model.eval()

    for title, counts in (('Learnable parameters:', learnable_parameters),
                          ('Constant parameters:', constant_parameters)):
        if len(counts) > 0:
            print(title)
            for name, count in counts.items():
                print(f'{name}: {count}')
            print(f'Sum: {sum(counts.values())}')


def show_examples_summarization(model: nn.Module, loader: SummarizationDataLoader, vocab: Vocab,
                                predict_tokens: Callable[[nn.Module, Tuple[Any, ...]], Tensor],
                                use_oov: bool = True, examples_number: int = 3) -> None:
    """Score every example produced by ``loader`` with ROUGE-1 and print the best and worst ones.

    For each batch the inputs are moved to the notebook-level ``device``, ``predict_tokens``
    produces the predicted token tensor, and every column of that tensor (index ``i`` on
    dimension 1) is scored against the matching column of the targets.

    :param model: trained summarization model passed through to ``predict_tokens``.
    :param loader: iterable of batches; the targets are the last element of a batch, or the
        second to last when ``use_oov`` is set (the last one is then the per-example OOV list).
    :param vocab: vocabulary used for scoring and for turning token tensors into strings. With
        ``use_oov`` it is temporarily extended with each example's OOV words.
    :param predict_tokens: callable ``(model, inputs) -> tokens`` adapting the model's call signature.
    :param use_oov: whether batches carry per-example OOV word lists.
    :param examples_number: how many best and how many worst examples to print.
    """
    scorer = ROUGE(vocab, 'rouge1')
    examples = []
    with torch.no_grad():
        for inputs in tqdm(loader):
            inputs = utils.convert_input_to_device(inputs, device)
            if use_oov:
                targets, oov_list = inputs[-2:]
            else:
                oov_list = None
                targets = inputs[-1]

            tokens = predict_tokens(model, inputs)
            for i in range(tokens.shape[1]):
                if use_oov:
                    utils.add_words_to_vocab(vocab, oov_list[i])

                try:
                    score_out = tokens[:, i].unsqueeze(dim=1)
                    score_target = targets[:, i].unsqueeze(dim=1)
                    score = scorer.score(score_out, score_target)

                    # NOTE(review): 3 presumably is the EOS index (the models are loaded with
                    # eos_index=3) — confirm against utils.clean_predicted_tokens.
                    score_out = utils.clean_predicted_tokens(score_out, 3)
                    score_out = utils.remove_unnecessary_padding(score_out)
                    score_out = utils.tensor_to_string(vocab, score_out)
                    score_target = utils.remove_unnecessary_padding(score_target)
                    score_target = utils.tensor_to_string(vocab, score_target)
                finally:
                    # Always undo the temporary vocab extension, even when scoring/decoding fails
                    # or the cell is interrupted, so the shared vocab is not left polluted.
                    if use_oov:
                        utils.remove_words_from_vocab(vocab, oov_list[i])

                heapq.heappush(examples, (score['ROUGE-1'], (score_out, score_target)))

    print_summarization_examples(examples, examples_number, best=True)
    print_summarization_examples(examples, examples_number, best=False)


def print_summarization_examples(examples: List[Tuple[float, Tuple[str, str]]], examples_number: int,
                                 best: bool) -> None:
    """Print the top or bottom ``examples_number`` summarization examples by ROUGE-1.

    :param examples: ``(score, (prediction, target))`` tuples.
    :param examples_number: how many examples to print.
    :param best: ``True`` prints the highest-scoring examples, ``False`` the lowest-scoring ones.
    """
    label, select = ('Best', heapq.nlargest) if best else ('Worst', heapq.nsmallest)

    print(f'{label} examples:')
    for rouge, texts in select(examples_number, examples):
        predicted, reference = texts
        print('ROUGE-1:', rouge)
        print('Prediction:', predicted)
        print()
        print('Target:', reference)
        print(50 * '-')


def predict_pointer_generator(model: nn.Module, inputs: Tuple[Any, ...]) -> Tensor:
    """Run the pointer-generator model on one batch and return its predicted tokens.

    :param model: callable taking ``(texts, texts_lengths, texts_extended, oov_size)`` and
        returning a 4-tuple whose second element holds the predicted tokens.
    :param inputs: 7-element batch ``(texts, texts_lengths, summaries, summaries_lengths,
        texts_extended, targets, oov_list)``; summaries and targets are not used here.
    :return: the predicted tokens (second element of the model output).
    """
    texts, texts_lengths, _, _, texts_extended, _, oov_list = inputs
    # The model receives the size of the largest per-example OOV list in the batch.
    largest_oov = max(oov_list, key=len)
    outputs = model(texts, texts_lengths, texts_extended, len(largest_oov))
    _, tokens, _, _ = outputs

    return tokens


def predict_rl(model: nn.Module, inputs: Tuple[Any, ...]) -> Tensor:
    """Run the reinforcement-learning summarizer on one batch and return its predicted tokens.

    :param model: callable taking ``(texts, texts_lengths, texts_extended, oov_size)`` and
        returning a 3-tuple whose second element holds the predicted tokens.
    :param inputs: 7-element batch ``(texts, texts_lengths, summaries, summaries_lengths,
        texts_extended, targets, oov_list)``; summaries and targets are not used here.
    :return: the predicted tokens (second element of the model output).
    """
    texts, texts_lengths, _, _, texts_extended, _, oov_list = inputs
    # The model receives the size of the largest per-example OOV list in the batch.
    largest_oov = max(oov_list, key=len)
    outputs = model(texts, texts_lengths, texts_extended, len(largest_oov))
    _, tokens, _ = outputs

    return tokens


def predict_transformer(model: nn.Module, inputs: Tuple[Any, ...]) -> Tensor:
    """Run the transformer summarizer on one batch and return its predicted tokens.

    :param model: callable taking the source texts and returning a 2-tuple whose second
        element holds the predicted tokens.
    :param inputs: 5-element batch; only the first element (the texts) is passed to the model.
    :return: the predicted tokens (second element of the model output).
    """
    source_texts, _, _, _, _ = inputs
    outputs = model(source_texts)
    _, predicted = outputs

    return predicted


def show_examples_ner(model: nn.Module, loader: NERDataLoader, vocab: Vocab, tags_dict: Dict[int, Tuple[str, str]],
                      predict_tokens: Callable[[nn.Module, Tuple[Any, ...]], Tensor], examples_number: int = 3) -> None:
    """Score every sentence produced by ``loader`` with F1 and print the best and worst ones.

    For each batch the inputs are moved to the notebook-level ``device`` and ``predict_tokens``
    produces the predicted tag tensor. Every column (index ``i`` on dimension 1) is then
    compared with the matching target column, keeping only positions whose target is ``>= 0``.
    Sentences whose remaining targets are all ``0`` are skipped.

    :param model: trained NER model passed through to ``predict_tokens``.
    :param loader: iterable of batches whose first two elements are the text and the target tags.
    :param vocab: vocabulary used to turn the text tensor into a string.
    :param tags_dict: maps a tag id to a tuple whose first element is the printed tag name; ids
        missing from the mapping are printed as ``'O'``.
    :param predict_tokens: callable ``(model, inputs) -> tags`` adapting the model's call signature.
    :param examples_number: how many best and how many worst examples to print.
    """

    def tags_to_string(tags: Tensor) -> str:
        # Render a 1-D tag tensor as space-separated tag names ('O' for ids not in tags_dict).
        names = []
        for tag in tags:
            tag_id = tag.item()
            names.append(tags_dict[tag_id][0] if tag_id in tags_dict else 'O')
        return ' '.join(names)

    labels = list(tags_dict.keys())
    scorer = F1Score(labels)
    examples = []
    with torch.no_grad():
        for inputs in tqdm(loader):
            inputs = utils.convert_input_to_device(inputs, device)
            text, targets = inputs[:2]
            tokens = predict_tokens(model, inputs)
            for i in range(tokens.shape[1]):
                score_out = tokens[:, i].unsqueeze(dim=1)
                score_target = targets[:, i].unsqueeze(dim=1)

                # NOTE(review): negative targets are presumably padding positions — confirm.
                valid_positions = score_target >= 0
                score_out = score_out[valid_positions]
                score_target = score_target[valid_positions]

                # Skip sentences without any non-zero target tag. This is what the former
                # `not any(score_target) != 0` evaluated to, expressed as one tensor
                # reduction instead of a Python-level loop over the elements.
                if not bool((score_target != 0).any()):
                    continue

                score = scorer.score(score_out, score_target)

                # NOTE(review): ids <= 0 are dropped from the text, presumably padding — confirm.
                texts = text[:, i]
                texts = texts[texts > 0]

                score_out = tags_to_string(score_out)
                score_target = tags_to_string(score_target)
                texts = utils.tensor_to_string(vocab, texts)

                heapq.heappush(examples, (score['F1'], (texts, score_out, score_target)))

    print_ner_examples(examples, examples_number, best=True)
    print_ner_examples(examples, examples_number, best=False)


def print_ner_examples(examples: List[Tuple[float, Tuple[str, str, str]]], examples_number: int,
                       best: bool) -> None:
    """Print the top or bottom ``examples_number`` NER examples by F1.

    :param examples: ``(score, (text, prediction, target))`` tuples; the annotation now
        reflects the three strings that are actually unpacked below.
    :param examples_number: how many examples to print.
    :param best: ``True`` prints the highest-scoring examples, ``False`` the lowest-scoring ones.
    """
    examples_type, examples_generator = ('Best', heapq.nlargest) if best else ('Worst', heapq.nsmallest)

    print(f'{examples_type} examples:')
    for score, (text, prediction, target) in examples_generator(examples_number, examples):
        print('F1:', score)
        print('Text:', text)
        print('Prediction:', prediction)
        print('Target:', target)
        print(50 * '-')


def predict_bilstm_cnn(model: nn.Module, inputs: Tuple[Any, ...]) -> Tensor:
    """Run the BiLSTM-CNN tagger on one batch and return the highest-scoring tag per position.

    :param model: callable taking ``(words, chars, word_features, char_features)`` and
        returning per-tag scores on its last dimension.
    :param inputs: 5-element batch ``(words, tags, chars, word_features, char_features)``;
        the gold tags are not used here.
    :return: indices of the maximum score along the last dimension of the model output.
    """
    words, _, chars, word_features, char_features = inputs
    scores = model(words, chars, word_features, char_features)

    return torch.argmax(scores, dim=-1)


def predict_bilstm_crf(model: nn.Module, inputs: Tuple[Any, ...]) -> Tensor:
    """Run the BiLSTM-CRF tagger on one batch and return its predicted tags.

    :param model: callable taking ``(words, chars, tags, mask)`` and returning a
        ``(loss, predictions)`` pair.
    :param inputs: 5-element batch ``(words, tags, chars, _, _)``; the last two elements
        are not used here.
    :return: the predictions (second element of the model output); the loss is discarded.
    """
    words, tags, chars, _, _ = inputs
    # Float mask with 1.0 where the gold tag is non-negative and 0.0 elsewhere.
    valid_mask = (tags >= 0).float()
    _, predictions = model(words, chars, tags, valid_mask)

    return predictions


def predict_id_cnn(model: nn.Module, inputs: Tuple[Any, ...]) -> Tensor:
    """Run the ID-CNN tagger on one batch and return the highest-scoring tag per position.

    :param model: callable taking ``(words, word_features)`` and returning a sequence of
        score tensors; only the last one is used.
    :param inputs: 5-element batch ``(words, tags, _, word_features, _)``; the gold tags
        and the third and fifth elements are not used here.
    :return: indices of the maximum score along the last dimension of the final model output.
    """
    words, _, _, word_features, _ = inputs
    all_outputs = model(words, word_features)
    final_scores = all_outputs[-1]

    return torch.argmax(final_scores, dim=-1)


In [None]:
# Pointer-generator summarizer: load the saved model, report its parameter counts and print
# the best/worst test examples ranked by ROUGE-1.
# NOTE(review): bos/eos/unk indices 2/3/1 are presumably the vocab's special-token ids — confirm.
pointer_generator = load_trained_model('pointer_generator', models_path / 'pointer_generator', bos_index=2, eos_index=3,
                                       unk_index=1)
# NOTE(review): presumably switches the model's coverage mechanism on before evaluation — confirm in the model class.
pointer_generator.activate_coverage()
count_model_parameters(pointer_generator)
vocab = VocabBuilder.build_vocab('cnn_dailymail', 'summarization', vocab_size=50000)
# NOTE(review): 400 and 100 look like maximum article / summary lengths — confirm against SummarizationDataset.
# get_oov=True matches show_examples_summarization's default use_oov=True (batches carry OOV lists).
dataset = SummarizationDataset('cnn_dailymail', 'test', 400, 100, vocab, get_oov=True)
dataloader = SummarizationDataLoader(dataset, batch_size=32)
show_examples_summarization(pointer_generator, dataloader, vocab, predict_pointer_generator)

In [None]:
# Reinforcement-learning summarizer: load the saved model, report its parameter counts and print
# the best/worst test examples ranked by ROUGE-1.
# weights_name='rl_model' makes load_trained_model read the 'rl_model_state_dict' checkpoint entry
# instead of one named after the 'reinforcement_learning' directory.
reinforcement_summarization = load_trained_model('rl', models_path / 'reinforcement_learning', weights_name='rl_model',
                                                 bos_index=2, eos_index=3, unk_index=1)
count_model_parameters(reinforcement_summarization)
vocab = VocabBuilder.build_vocab('cnn_dailymail', 'summarization', vocab_size=50000)
# NOTE(review): 800 and 100 look like maximum article / summary lengths (the other summarization
# cells use 400) — confirm against SummarizationDataset.
dataset = SummarizationDataset('cnn_dailymail', 'test', 800, 100, vocab, get_oov=True)
dataloader = SummarizationDataLoader(dataset, batch_size=32)
show_examples_summarization(reinforcement_summarization, dataloader, vocab, predict_rl)

In [None]:
# Transformer summarizer: load the saved model, report its parameter counts and print the
# best/worst test examples ranked by ROUGE-1.
transformer = load_trained_model('transformer', models_path / 'transformer', bos_index=2, eos_index=3)
count_model_parameters(transformer)
vocab = VocabBuilder.build_vocab('cnn_dailymail', 'summarization', vocab_size=50000)
# The dataset is built without get_oov, so the examples are shown with use_oov=False below.
dataset = SummarizationDataset('cnn_dailymail', 'test', 400, 100, vocab)
dataloader = SummarizationDataLoader(dataset, batch_size=16)
show_examples_summarization(transformer, dataloader, vocab, predict_transformer, use_oov=False)

In [None]:
# BiLSTM-CNN tagger on CoNLL-2003: load the saved model, report its parameter counts and print
# the best/worst test sentences ranked by F1.
# NOTE(review): the "+ 1" presumably adds a class for tags outside tags_dict (printed as 'O') — confirm.
tags_count = DatabaseConnector().get_tag_count('conll2003') + 1
tags_dict = DatabaseConnector().get_tags_dict('conll2003')
vocab = VocabBuilder.build_vocab('conll2003', 'ner', vocab_type='char', digits_to_zero=True)
bilstm_cnn = load_trained_model('bilstm_cnn', models_path / 'bilstm_cnn', tags_count=tags_count, vocab=vocab)
count_model_parameters(bilstm_cnn)
dataset = NERDataset('conll2003', 'test', vocab)
# NOTE(review): the character padding options presumably have to match the model's character CNN
# (kernel size 3) — confirm against the training configuration.
dataloader = NERDataLoader(dataset, batch_size=128, two_sided_char_padding=True, conv_kernel_size=3)
show_examples_ner(bilstm_cnn, dataloader, vocab, tags_dict, predict_bilstm_cnn)

In [None]:
# BiLSTM-CRF tagger on CoNLL-2003: load the saved model, report its parameter counts and print
# the best/worst test sentences ranked by F1.
# NOTE(review): the "+ 1" presumably adds a class for tags outside tags_dict (printed as 'O') — confirm.
tags_count = DatabaseConnector().get_tag_count('conll2003') + 1
tags_dict = DatabaseConnector().get_tags_dict('conll2003')
vocab = VocabBuilder.build_vocab('conll2003', 'ner', vocab_type='char', digits_to_zero=True)
bilstm_crf = load_trained_model('bilstm_crf', models_path / 'bilstm_crf', tags_count=tags_count, vocab=vocab)
count_model_parameters(bilstm_crf)
dataset = NERDataset('conll2003', 'test', vocab)
dataloader = NERDataLoader(dataset, batch_size=128)
show_examples_ner(bilstm_crf, dataloader, vocab, tags_dict, predict_bilstm_crf)

In [None]:
# ID-CNN tagger on CoNLL-2003: load the saved model, report its parameter counts and print
# the best/worst test sentences ranked by F1.
# NOTE(review): the "+ 1" presumably adds a class for tags outside tags_dict (printed as 'O') — confirm.
tags_count = DatabaseConnector().get_tag_count('conll2003') + 1
tags_dict = DatabaseConnector().get_tags_dict('conll2003')
vocab = VocabBuilder.build_vocab('conll2003', 'ner', vocab_type='char', digits_to_zero=True)
id_cnn = load_trained_model('id_cnn', models_path / 'id_cnn', tags_count=tags_count, vocab=vocab)
count_model_parameters(id_cnn)
dataset = NERDataset('conll2003', 'test', vocab)
dataloader = NERDataLoader(dataset, batch_size=128)
show_examples_ner(id_cnn, dataloader, vocab, tags_dict, predict_id_cnn)