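Notebook used to produce the dataset overview: statistics, example samples and length-distribution plots for the NER and summarization datasets.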

In [None]:
from collections import Counter, defaultdict
from functools import wraps
from pathlib import Path
from typing import Dict, List, Callable, Any, Sequence, Optional

import matplotlib.pyplot as plt
import seaborn as sns
from nltk.tokenize import sent_tokenize
from nltk.tokenize.treebank import TreebankWordDetokenizer
from nltk.util import ngrams

from neural.common.data.datasets import DatasetGenerator
from utils.database import DatabaseConnector

In [None]:
# Directory the figures produced by this notebook are saved to; created up front
# (including missing parents) so the later plt.savefig calls do not fail.
save_path = Path('images/datasets')
save_path.mkdir(parents=True, exist_ok=True)
# Apply seaborn's default theme first, then switch the colour palette to 'muted'.
sns.set_theme()
sns.set_palette('muted')

In [None]:
def all_splits(func: Callable[..., None]) -> Callable[..., None]:
    """Decorator that runs ``func`` once for each dataset split.

    The decorated function must take the split name as its first positional
    parameter. The returned wrapper supplies 'train', 'validation' and 'test'
    in that order and forwards every other positional and keyword argument
    unchanged, so callers invoke the decorated function without a split.

    :param func: function whose first parameter is the split name.
    :return: wrapper accepting the remaining arguments of ``func``; it always returns None.
    """
    @wraps(func)  # keep the original name and docstring on the decorated function
    def wrapper(*args: Any, **kwargs: Any) -> None:
        for split in ('train', 'validation', 'test'):
            func(split, *args, **kwargs)

    return wrapper

In [None]:
def print_entities_examples(dataset_name: str, max_tags: int = 20, skip_sentences: int = 0,
                            entity_type: Optional[str] = None) -> None:
    """Print tokens of the training split that carry a known entity tag.

    Each printed line holds the token followed by its entry from the tags
    dictionary returned by ``DatabaseConnector().get_tags_dict``. Tags missing
    from that dictionary are ignored.

    :param dataset_name: name of the dataset to read.
    :param max_tags: maximum number of tagged tokens to print; printing stops
        exactly at this limit, and nothing is printed when it is not positive.
    :param skip_sentences: number of leading samples to skip.
    :param entity_type: when given, only tags whose dictionary entry has this
        value at index 1 are printed.
    """
    if max_tags <= 0:
        return

    tags_dict = DatabaseConnector().get_tags_dict(dataset_name)
    dataset = DatasetGenerator.generate_dataset(dataset_name, 'train')
    printed = 0
    for i, (sentence, tags) in enumerate(dataset):
        if i < skip_sentences:
            continue
        for token, tag in zip(sentence, tags):
            if tag not in tags_dict:
                continue
            tag_info = tags_dict[tag]
            if entity_type is not None and tag_info[1] != entity_type:
                continue
            print(token, tag_info)
            printed += 1
            # Check the limit per token, not per sentence, so it is never exceeded.
            if printed >= max_tags:
                return

In [None]:
@all_splits
def count_dataset_entities(split: str, dataset_name: str) -> None:
    """Count named entities per category in one split and print the totals.

    Entries of the tags dictionary are unpacked as ``(label, category)`` where
    ``label`` has the form ``<position>-<entity>``. A run consisting of a label
    followed by ``I`` labels of the same entity is counted as one entity. Tags
    missing from the dictionary end the current run.

    :param split: dataset split to read (supplied by ``all_splits``).
    :param dataset_name: name of the dataset to read.
    """
    tags_dict = DatabaseConnector().get_tags_dict(dataset_name)
    counter = Counter()
    for _, tags in DatasetGenerator.generate_dataset(dataset_name, split):
        # An entity cannot span two samples, so the run state starts fresh here.
        current_entity = None
        for tag in tags:
            if tag not in tags_dict:
                current_entity = None
                continue

            label, category = tags_dict[tag]
            # Split on the first hyphen only so entity names may contain hyphens.
            position, entity = label.split('-', 1)
            if position == 'I' and entity == current_entity:
                continue

            # Remember every counted entity, so that a run starting with an
            # 'I' label (no preceding 'B') is still counted only once.
            current_entity = entity
            counter[category] += 1

    print(f'Named entities counts for dataset "{dataset_name}" with {split} split:')
    for category, count in counter.most_common():
        print(f'{category}: {count}')

In [None]:
@all_splits
def print_ner_dataset_distribution(split: str, dataset_name: str = 'gmb') -> None:
    """Print article, sentence and token counts of an NER dataset split.

    Every dataset sample counts as one article. Sentences are counted by
    detokenizing the sample's tokens and running NLTK's ``sent_tokenize`` on
    the resulting text.

    :param split: dataset split to read (supplied by ``all_splits``).
    :param dataset_name: name of the dataset to read; defaults to 'gmb'.
    """
    detokenizer = TreebankWordDetokenizer()
    article_count = 0
    sentence_count = 0
    tokens_count = 0
    for tokens, _ in DatasetGenerator.generate_dataset(dataset_name, split):
        article_count += 1
        sentence_count += len(sent_tokenize(detokenizer.detokenize(tokens)))
        tokens_count += len(tokens)
    print(f'Dataset {dataset_name} stats for {split} split:')
    print('Article count:', article_count)
    print('Sentence count:', sentence_count)
    print('Tokens count:', tokens_count)


@all_splits
def print_ner_dataset_average_sample_length(split: str, dataset_name: str) -> None:
    """Print the mean number of tokens per sample of one split, rounded to an integer.

    :param split: dataset split to read (supplied by ``all_splits``).
    :param dataset_name: name of the dataset to read.
    """
    token_total = 0
    sample_total = 0
    for tokens, _ in DatasetGenerator.generate_dataset(dataset_name, split):
        token_total += len(tokens)
        sample_total += 1

    mean_length = round(token_total / sample_total)
    print(f'Average sample length for {split} split of dataset "{dataset_name}": {mean_length}')


In [None]:
@all_splits
def print_summarization_dataset_distribution(split: str, dataset_name: str) -> None:
    """Print size statistics of a summarization dataset split.

    Reports the number of (article, summary) pairs and, separately for
    articles and summaries, the sentence count, the token count and the
    rounded mean token length. Sentences are counted with NLTK's
    ``sent_tokenize`` on the detokenized text.

    :param split: dataset split to read (supplied by ``all_splits``).
    :param dataset_name: name of the dataset to read.
    """
    detokenizer = TreebankWordDetokenizer()
    pair_total = 0
    article_sentences = 0
    summary_sentences = 0
    article_tokens = 0
    summary_tokens = 0
    for article, summary in DatasetGenerator.generate_dataset(dataset_name, split):
        pair_total += 1
        article_sentences += len(sent_tokenize(detokenizer.detokenize(article)))
        summary_sentences += len(sent_tokenize(detokenizer.detokenize(summary)))
        article_tokens += len(article)
        summary_tokens += len(summary)

    mean_article_length = round(article_tokens / pair_total)
    mean_summary_length = round(summary_tokens / pair_total)

    print(f'Dataset {dataset_name} stats for {split} split:')
    report = (
        ('Pair count:', pair_total),
        ('Articles sentence count:', article_sentences),
        ('Articles tokens count:', article_tokens),
        ('Average article length:', mean_article_length),
        ('Summaries sentence count:', summary_sentences),
        ('Summaries tokens count:', summary_tokens),
        ('Average summary length:', mean_summary_length),
    )
    for label, value in report:
        print(label, value)


@all_splits
def print_summarization_novel_ngrams(split: str, dataset_name: str, n_grams: Sequence[int] = (1, 2, 3, 4)) -> None:
    """Print the average share of summary n-grams that do not occur in the article.

    Tokens are lower-cased before comparison. For every pair and every n-gram
    size, the ratio of distinct summary n-grams absent from the article is
    recorded; pairs whose summary yields no n-gram of that size are skipped.
    The per-size mean is printed as a percentage with two decimals.

    :param split: dataset split to read (supplied by ``all_splits``).
    :param dataset_name: name of the dataset to read.
    :param n_grams: n-gram sizes to evaluate.
    """
    ratios_by_size = defaultdict(list)
    for article, summary in DatasetGenerator.generate_dataset(dataset_name, split):
        article_tokens = [token.lower() for token in article]
        summary_tokens = [token.lower() for token in summary]
        for size in n_grams:
            article_grams = set(ngrams(article_tokens, size))
            summary_grams = set(ngrams(summary_tokens, size))
            if not summary_grams:
                continue
            unseen_grams = summary_grams - article_grams
            ratios_by_size[size].append(len(unseen_grams) / len(summary_grams))

    print(f'Novel n-grams in {split} split of "{dataset_name}" dataset:')
    for size, values in ratios_by_size.items():
        percentage = round(sum(values) / len(values) * 100, 2)
        print(f'Novel {size}-gram ratio: {percentage}%')

In [None]:
def plot_ner_dataset_lengths(splits: Sequence[str] = ('train', 'validation', 'test')) -> None:
    """Plot and save a histogram of sample lengths for the two NER datasets.

    Token counts of every sample from the 'conll2003' and 'gmb' datasets are
    gathered over the requested splits and drawn on one log-scaled histogram,
    which is written to ``save_path / 'ner_datasets_lengths.png'`` and shown.

    :param splits: dataset splits whose samples are included.
    """
    sources = (('CoNLL-2003', 'conll2003'), ('GMB', 'gmb'))
    lengths_by_label = {label: [] for label, _ in sources}
    for split in splits:
        for label, dataset_name in sources:
            samples = DatasetGenerator.generate_dataset(dataset_name, split)
            lengths_by_label[label].extend(len(tokens) for tokens, _ in samples)

    figure = plt.figure(figsize=(15, 7))
    plt.xlim(0, 170)
    plt.yscale('log')
    plt.xlabel('Data sample length', fontsize=14)
    plt.ylabel('Number of samples', fontsize=14)
    histogram = sns.histplot(lengths_by_label, binwidth=2)
    legend_texts = histogram.get_legend().get_texts()
    plt.setp(legend_texts, fontsize=14)
    figure.tight_layout()
    plt.savefig(save_path / 'ner_datasets_lengths.png')
    plt.show()


def print_ner_examples(dataset_name: str, examples_number: int = 3, skip_examples: int = 0) -> None:
    """Print training samples as a line of tokens followed by a line of tag labels.

    Tag labels are taken from index 0 of the tags dictionary entry; tags
    missing from the dictionary are shown as 'O'. Each sample is followed by
    an empty line.

    :param dataset_name: name of the dataset to read.
    :param examples_number: number of samples to print.
    :param skip_examples: number of leading samples to skip.
    """
    tags_dict = DatabaseConnector().get_tags_dict(dataset_name)
    stop_index = examples_number + skip_examples
    samples = DatasetGenerator.generate_dataset(dataset_name, 'train')
    for index, (tokens, tags) in enumerate(samples):
        if index < skip_examples:
            continue
        if index == stop_index:
            break

        labels = []
        for tag in tags:
            if tag in tags_dict:
                labels.append(tags_dict[tag][0])
            else:
                labels.append('O')

        print(' '.join(tokens))
        print(' '.join(labels))
        print()


def plot_summaries_sizes_single_ax(ax: plt.Axes, data: Dict[str, List[int]], name: str, x_lim: int,
                                   bin_width: int = 2) -> None:
    """Draw one length histogram on ``ax`` and style its labels, title and legend.

    :param ax: axes to draw on.
    :param data: lengths to plot, keyed by the label shown in the legend.
    :param name: plural noun for the plotted items, used in axis labels and title.
    :param x_lim: upper limit of the x axis; the lower limit is 0.
    :param bin_width: width of the histogram bins.
    """
    capitalized = name.capitalize()
    ax.set_xlim(0, x_lim)
    ax.set_xlabel(f'{capitalized} lengths', fontsize=26)
    ax.set_ylabel(f'Number of {name}', fontsize=26)
    ax.set_title(f'{capitalized} length comparison.', fontsize=30)
    for tick_kind in ('major', 'minor'):
        ax.tick_params(axis='both', which=tick_kind, labelsize=20)
    histogram = sns.histplot(data, binwidth=bin_width, ax=ax)
    legend_texts = histogram.get_legend().get_texts()
    plt.setp(legend_texts, fontsize=26)


def plot_summaries_dataset_lengths(splits: Sequence[str] = ('train', 'validation', 'test')) -> None:
    """Plot and save article and summary length histograms for the summarization datasets.

    Lengths of all articles and summaries from the 'cnn_dailymail' and 'xsum'
    datasets are gathered over the requested splits. Articles are drawn on the
    upper axes and summaries on the lower one; the figure is written to
    ``save_path / 'summarization_datasets_lengths.png'`` and shown.

    :param splits: dataset splits whose samples are included.
    """
    sources = (('CNN/Daily Mail', 'cnn_dailymail'), ('XSum', 'xsum'))
    article_data = {label: [] for label, _ in sources}
    summary_data = {label: [] for label, _ in sources}

    for split in splits:
        for label, dataset_name in sources:
            for article, summary in DatasetGenerator.generate_dataset(dataset_name, split):
                article_data[label].append(len(article))
                summary_data[label].append(len(summary))

    figure, axes = plt.subplots(2, 1, figsize=(25, 16))
    article_ax, summary_ax = axes
    plot_summaries_sizes_single_ax(article_ax, article_data, 'articles', x_lim=2000, bin_width=20)
    plot_summaries_sizes_single_ax(summary_ax, summary_data, 'summaries', x_lim=150)
    figure.tight_layout()
    plt.savefig(save_path / 'summarization_datasets_lengths.png')
    plt.show()


def print_summarization_examples(dataset_name: str, examples_number: int = 3, skip_examples: int = 0) -> None:
    """Print detokenized (article, summary) pairs from the training split.

    The article and its summary are separated by a line of 15 dashes, and each
    pair is closed by a line of 50 dashes.

    :param dataset_name: name of the dataset to read.
    :param examples_number: number of pairs to print.
    :param skip_examples: number of leading pairs to skip.
    """
    detokenizer = TreebankWordDetokenizer()
    stop_index = examples_number + skip_examples
    pairs = DatasetGenerator.generate_dataset(dataset_name, 'train')
    for index, (article, summary) in enumerate(pairs):
        if index < skip_examples:
            continue
        if index == stop_index:
            break

        article_text = detokenizer.detokenize(article)
        print(article_text)
        print('-' * 15)
        summary_text = detokenizer.detokenize(summary)
        print(summary_text)
        print('-' * 50)

In [None]:
# Show entity-tagged tokens from the 'conll2003' training split (default limit, no type filter).
print_entities_examples('conll2003')

In [None]:
# Show entity-tagged tokens from the 'gmb' training split (default limit, no type filter).
print_entities_examples('gmb')

In [None]:
# Report article, sentence and token counts of the 'gmb' dataset for every split.
print_ner_dataset_distribution()

In [None]:
# Report the rounded mean sample length per split for both NER datasets.
print_ner_dataset_average_sample_length('conll2003')
print_ner_dataset_average_sample_length('gmb')

In [None]:
# Report pair, sentence and token statistics of 'cnn_dailymail' for every split.
print_summarization_dataset_distribution('cnn_dailymail')

In [None]:
# Report pair, sentence and token statistics of 'xsum' for every split.
print_summarization_dataset_distribution('xsum')

In [None]:
# Plot and save the sample-length histogram of the NER datasets over all splits.
plot_ner_dataset_lengths()

In [None]:
# Print per-category named-entity counts of 'conll2003' for every split.
count_dataset_entities('conll2003')

In [None]:
# Print per-category named-entity counts of 'gmb' for every split.
count_dataset_entities('gmb')

In [None]:
# Plot and save the article and summary length histograms of the summarization datasets.
plot_summaries_dataset_lengths()

In [None]:
# Print the first three 'conll2003' training samples with their tag labels.
print_ner_examples('conll2003')

In [None]:
# Print the first three 'gmb' training samples with their tag labels.
print_ner_examples('gmb')

In [None]:
# Print the first 15 article/summary pairs of the 'cnn_dailymail' training split.
print_summarization_examples('cnn_dailymail', examples_number=15)

In [None]:
# Print the first 15 article/summary pairs of the 'xsum' training split.
print_summarization_examples('xsum', examples_number=15)

In [None]:
# Report the novel 1- to 4-gram ratios of 'cnn_dailymail' summaries for every split.
print_summarization_novel_ngrams('cnn_dailymail')

In [None]:
# Report the novel 1- to 4-gram ratios of 'xsum' summaries for every split.
print_summarization_novel_ngrams('xsum')