In [1]:
import datasets

book_builder = datasets.load_dataset_builder("wikitext", "wikitext-103-v1")
print(book_builder.info.description)
print(book_builder.info.splits)
print(book_builder.info.features)

  from .autonotebook import tqdm as notebook_tqdm


 The WikiText language modeling dataset is a collection of over 100 million tokens extracted from the set of verified
 Good and Featured articles on Wikipedia. The dataset is available under the Creative Commons Attribution-ShareAlike
 License.

{'test': SplitInfo(name='test', num_bytes=1295575, num_examples=4358, shard_lengths=None, dataset_name='wikitext'), 'train': SplitInfo(name='train', num_bytes=545141915, num_examples=1801350, shard_lengths=[1653000, 148350], dataset_name='wikitext'), 'validation': SplitInfo(name='validation', num_bytes=1154751, num_examples=3760, shard_lengths=None, dataset_name='wikitext')}
{'text': Value(dtype='string', id=None)}


In [2]:
from tokenizers import SentencePieceBPETokenizer
from transformers import PreTrainedTokenizerFast
import os
import re

def chunk_tokens_and_ids(examples, chunk_size, ids_key, tokens_key):
    chunked_ids = []
    chunked_tokens = []
    for id, token in zip(examples[ids_key], examples[tokens_key]):
        chunked_ids += [id[i:i + chunk_size] for i in range(0, len(id), chunk_size)]
        chunked_tokens += [token[i:i + chunk_size] for i in range(0, len(token), chunk_size)]
    return {f"chunked_{tokens_key}": chunked_tokens, f"chunked_{ids_key}": chunked_ids}

def get_tokens_and_ids(examples, tokenizer, text_key, ids_key, tokens_key):
    tokens = []
    ids = []
    for example in examples[text_key]:
        token = tokenizer.tokenize(example, padding=False)
        id = tokenizer.convert_tokens_to_ids(token)
        tokens.append(token)
        ids.append(id)
    return {tokens_key: tokens, ids_key: ids}

def calc_token_ratio(examples, ids_key, data_key):
    ratios = []
    for id, data in zip(examples[ids_key], examples[data_key]):
        ratios.append(len(id) / len(data))
    return {"token_ratio": ratios}

class DatasetWrapper:
    dataset_name = "wikitext"
    data_config = "wikitext-103-v1"
    download_split = "train"
    download_split_pct = "1%"
    data_key = "text"
    ids_key = "input_ids"
    tokens_key = "tokens"
    max_token_ratio = .33

    def __init__(self, model_max_length=512, processes=None):
        self.tokenizer_filename = f"{self.dataset_name}_{self.download_split_pct}_tokenizer"
        self.model_max_length = model_max_length
        self.processes = processes

    def dataset_info(self):
        data = datasets.load_dataset_builder(self.dataset_name)
        print(data.info.description)
        print(data.info.splits)
        print(data.info.features)

    def process_dataset(self):
        data = self.load_dataset()
        if self.combine_func is not None:
            data = data.map(lambda x: self.combine_func(x, text_key=self.data_key), batched=True, remove_columns=data.column_names)

        tokenizer = self.get_tokenizer(data[self.data_key])
        tokenized = self.tokenize_dataset(data, tokenizer)
        tokenized = self.filter_tokenized_text(tokenized)
        data_chunks = self.chunk_tokens(tokenized)
        return data_chunks

    def combine_func(self, examples, text_key="text"):
        entries = []
        entry = ""
        for sentence in examples[text_key]:
            if re.match("^ \= \w", sentence):
                if entry:
                    entries.append(entry)
                entry = ""
            else:
                entry += sentence
        entries.append(entry)
        return {text_key: entries}

    def load_dataset(self):
        split_str = self.download_split
        if self.download_split_pct:
            split_str = f"{self.download_split}[:{self.download_split_pct}]"
        if not self.data_config:
            data = datasets.load_dataset(self.dataset_name, split=split_str, num_proc=self.processes)
        else:
            data = datasets.load_dataset(self.dataset_name, self.data_config, split=split_str, num_proc=self.processes)
        return data

    def tokenize_dataset(self, data, tokenizer):
        data = data.map(lambda examples: get_tokens_and_ids(examples, tokenizer, self.data_key, self.ids_key, self.tokens_key), batched=True, num_proc=self.processes)
        return data

    def filter_tokenized_text(self, data):
        data = data.map(lambda x: calc_token_ratio(x, self.ids_key, self.data_key), batched=True, num_proc=self.processes)
        data = data.filter(lambda x: x["token_ratio"] < self.max_token_ratio, num_proc=self.processes)
        return data

    def chunk_tokens(self, data):
        data_chunks = data.map(lambda examples: chunk_tokens_and_ids(examples, self.model_max_length, self.ids_key, self.tokens_key), batched=True, remove_columns=data.column_names, num_proc=self.processes)
        return data_chunks

    def get_tokenizer(self, data=None):
        if os.path.exists(self.tokenizer_filename):
            tokenizer = self.load_tokenizer(self.tokenizer_filename)
        else:
            tokenizer = self.train_tokenizer(data, self.tokenizer_filename)
        return tokenizer

    def train_tokenizer(self, text, save_file, vocab_size=5000, min_frequency=2):
        special_tokens = ["<s>", "<pad>", "</s>", "<unk>"]
        sptokenizer = SentencePieceBPETokenizer()
        sptokenizer.train_from_iterator(
            text,
            vocab_size=vocab_size,
            min_frequency=min_frequency,
            show_progress=True,
            special_tokens=special_tokens
        )

        tokenizer = PreTrainedTokenizerFast(tokenizer_object=sptokenizer, special_tokens=special_tokens)
        for attr, token in zip(["bos_token", "pad_token", "eos_token", "unk_token"], special_tokens):
            setattr(tokenizer, attr, token)
            setattr(tokenizer, f"{attr}_id", sptokenizer.token_to_id(token))
        tokenizer.save_pretrained(save_file)
        return tokenizer

    def load_tokenizer(self, tokenizer_file):
        tokenizer = PreTrainedTokenizerFast.from_pretrained(tokenizer_file)
        return tokenizer

In [11]:
from datasketch import MinHash, MinHashLSH
from multiprocessing import Manager
from collections import defaultdict
from itertools import chain

def hash_tokens(tokens):
    m = MinHash(num_perm=256)
    for t in tokens:
        m.update(t.encode())
    return m

def hash_examples(examples, idxs, queue):
    hashes = []
    for idx, ex in zip(idxs, examples["chunked_tokens"]):
        min_hash = hash_tokens(ex)
        hashes.append((idx, min_hash))
    queue.put(hashes)

def index_hashes(examples, index, dup_store):
    for example in examples:
        idx, min_hash = example
        key = idx
        if key in index.keys:
            continue
        process_duplicates(min_hash, key, dup_store, index)
        index.insert(key, min_hash)

def process_duplicates(min_hash, key, dup_store, index):
    close_duplicates = index.query(min_hash)
    # Assign input hash to at most one duplicate cluster
    if len(close_duplicates) > 0:
        for base_duplicate in close_duplicates:
            if base_duplicate in dup_store:
                dup_store[base_duplicate].add(key)
                break
        else:
            dup_store[close_duplicates[0]].add(key)

def find_extremes(cluster, data, thresh=.9):
    extremes = set()
    for dup in cluster:
        for elem in extremes:
            d1 = set(data[dup]["chunked_tokens"])
            d2 = set(data[elem]["chunked_tokens"])
            sim = (d1 & d2) / (d1 | d2)
            if sim > thresh:
                break
        else:
            extremes.add(dup)
    return extremes

class Deduplicator:
    def __init__(self, threshold=.9, num_perm=256, processes=None):
        self.index = MinHashLSH(threshold=threshold, num_perm=num_perm)
        self.dup_store = defaultdict(set)
        self.processes = processes

    def deduplicate(self, data):
        manager = Manager()
        hash_queue = manager.Queue()
        data.map(lambda xs, idxs: hash_examples(xs, idxs, hash_queue), batched=True, with_indices=True, num_proc=self.processes)
        # Used to end the processing later on
        hash_queue.put(None)

        # Add hashes to queue
        while True:
            examples = hash_queue.get()
            if examples is None:
                break
            index_hashes(examples, self.index, self.dup_store)

        duplicate_indices = set(chain.from_iterable(self.dup_store.values()))
        extremes = self.get_extremes(data)
        duplicate_indices = duplicate_indices - extremes
        filtered = data.filter(lambda x, idx: idx not in duplicate_indices, with_indices=True, num_proc=self.processes)
        return filtered

    def get_extremes(self, data):
        extremes = map(lambda x: find_extremes(x, data), self.dup_store.values())
        extremes = set(chain.from_iterable(extremes))
        return extremes

In [14]:
wrapper = DatasetWrapper(512)
data = wrapper.process_dataset()

dup = Deduplicator()
deduped = dup.deduplicate(data)

Found cached dataset wikitext (/Users/vik/.cache/huggingface/datasets/wikitext/wikitext-103-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126)
Loading cached processed dataset at /Users/vik/.cache/huggingface/datasets/wikitext/wikitext-103-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-03792397f5d3a78f.arrow
Loading cached processed dataset at /Users/vik/.cache/huggingface/datasets/wikitext/wikitext-103-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-1e53970faad85a3a.arrow
Loading cached processed dataset at /Users/vik/.cache/huggingface/datasets/wikitext/wikitext-103-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-6c1519a2a7c93855.arrow
Loading cached processed dataset at /Users/vik/.cache/huggingface/datasets/wikitext/wikitext-103-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-8a573e0ed1fb6b50.arrow
Loading cached processed dataset 

In [13]:
deduped

Dataset({
    features: ['chunked_tokens', 'chunked_input_ids'],
    num_rows: 3111
})