## Prepare dataset

In this notebook, there is dataset preparation for both training and evaluation. 

TrainingDataset - Speakleash wolne_lektury_corpus
EvalDataset - Speakleash 1000_novels_corpus_CLARIN-PL

Both datasets contains polish poems and books.


### Datasets overview

### Imports and consts

In [1]:
from speakleash import Speakleash
import os
from transformers import AutoTokenizer
from torch.utils.data import Dataset
import torch
from typing import Iterator, List


In [2]:
TRAINING_DATASET = "wolne_lektury_corpus"
EVAL_DATASET = "1000_novels_corpus_CLARIN-PL"

TOKENIZER = "dkleczek/bert-base-polish-uncased-v1"

SPEAKLEASH_DATA_DIR = "./speakleash"
DATASETS_DIR = "./datasets"

os.makedirs(SPEAKLEASH_DATA_DIR, exist_ok=True)
os.makedirs(DATASETS_DIR, exist_ok=True)

### Load Speakleash

In [3]:
sl = Speakleash(SPEAKLEASH_DATA_DIR)

In [12]:
training_data = sl.get(TRAINING_DATASET).data
eval_data = sl.get(EVAL_DATASET).data

### Tokenizer

In [5]:
tokenizer = AutoTokenizer.from_pretrained("dkleczek/bert-base-polish-uncased-v1")

### Dataset creation

In [None]:
class LanguageModelingDataset(Dataset):
    def __init__(self, inputs, labels):
        self.X = [torch.tensor(seq, dtype=torch.long) for seq in inputs]
        self.y = [torch.tensor(seq, dtype=torch.long) for seq in labels]

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]


def is_line_valuable(line: str) -> bool:
    return len(line) > 20


def parse_book_to_lines(book: str) -> Iterator[str]:
    lines = []
    splited = book.split("\n")
    for line in splited:
        if is_line_valuable(line):
            lines.append(line)

    return lines


def make_chunks(tokens, seq_len=128, stride=64):
    chunks_input = []
    chunks_label = []

    for i in range(0, len(tokens) - seq_len + 1, stride):
        chunk = tokens[i : i + seq_len]
        chunks_input.append(chunk[:-1])
        chunks_label.append(chunk[1:])
    return chunks_input, chunks_label


def prepare_dataset(texts: List[str], seq_len=128, stride=64, subset=None) -> Dataset:
    all_inputs = []
    all_labels = []

    for i, text in enumerate(texts):
        lines = parse_book_to_lines(text)
        if not lines:
            print(f"Skipping empty book {i}")
            continue

        encodings = tokenizer(
            lines,
            padding=False,
            truncation=True,
            max_length=seq_len,
        )

        if not encodings["input_ids"]:
            print(f"No tokens for book {i}, skipping")
            continue

        book_tokens = []
        for ids in encodings["input_ids"]:
            book_tokens.extend(ids)

        inputs, labels = make_chunks(book_tokens, seq_len=seq_len, stride=stride)
        all_inputs.extend(inputs)
        all_labels.extend(labels)

        if i % 25 == 0:
            print(f"Parsed: {i + 1}")

    if subset is not None:
        all_inputs = all_inputs[:subset]
        all_labels = all_labels[:subset]

    dataset = LanguageModelingDataset(all_inputs, all_labels)
    return dataset

In [None]:
training_dataset = prepare_dataset(training_data, subset=2)
torch.save(
    {
        "inputs": training_dataset.X,
        "labels": training_dataset.y,
    },
    os.path.join(DATASETS_DIR, "training_data_small.pt"),
)

print("Saved training dataset")

Parsed: 1
Saved training dataset


In [13]:
eval_dataset = prepare_dataset(eval_data, subset=10000)
torch.save(
    {
        "inputs": eval_dataset.X,
        "labels": eval_dataset.y,
    },
    os.path.join(DATASETS_DIR, "eval_data_10000_docs.pt"),
)

print("Saved eval dataset")

Parsed: 1
Parsed: 26
Parsed: 51
Parsed: 76
Parsed: 101
Parsed: 126
Parsed: 151
Parsed: 176
Parsed: 201
Parsed: 226
Parsed: 251
Parsed: 276
Parsed: 301
Parsed: 326
Parsed: 351
Parsed: 376
Parsed: 401
Parsed: 426
Parsed: 451
Parsed: 476
Parsed: 501
Parsed: 526
Parsed: 551
Parsed: 576
Parsed: 601
Parsed: 626
Parsed: 651
Parsed: 676
Parsed: 701
Parsed: 726
Parsed: 751
Parsed: 776
Parsed: 801
Parsed: 826
Parsed: 851
Parsed: 876
Parsed: 901
Parsed: 926
Parsed: 951
Parsed: 976
Saved eval dataset
