In [None]:
# default_exp tokenizers

In [None]:
#hide
from nbdev.showdoc import *

In [None]:
#hide
%load_ext autoreload
%autoreload 2

# Tokenizers

> Contains the tokenizers used in the Reformer paper experiments, converted to fastai `Transform`s.

In [None]:
#export
import six
from fastai.text.all import *

## ByteTextTokenizer

A tokenizer that maps each byte of a string's UTF-8 encoding to an id, so a non-ASCII character becomes several ids (one per byte). This is the tokenizer used for the language-modelling tasks in the Reformer paper, based on the implementation in the [tensor2tensor library](https://github.com/tensorflow/tensor2tensor/blob/5f9dd2db6d7797162e53adf152310ed13e9fc711/tensor2tensor/data_generators/text_encoder.py#L176).

In [None]:
#export
class ByteTextTokenizer(Transform):
    """
        Encodes each byte of a string's UTF-8 encoding to an id, and decodes ids back to text.

        Ids `0..numres-1` are reserved for the special tokens `<pad>`, `<eos>`, `<bos>` (in that order);
        a byte value `b` is mapped to id `b + numres`. Non-ASCII characters therefore become
        several ids (one per UTF-8 byte).

        Args:
            is_lm: if True, encoding returns `LMTensorText`, otherwise `TensorText`.
            add_bos: prepend `bos_token_id` to every encoded sequence.
            add_eos: append `eos_token_id` to every encoded sequence.

        Credit: https://github.com/tensorflow/tensor2tensor/blob/5f9dd2db6d7797162e53adf152310ed13e9fc711/tensor2tensor/data_generators/text_encoder.py#L176
    """
    def __init__(self, is_lm=True, add_bos=False, add_eos=False):
        store_attr('is_lm, add_bos, add_eos')
        self.pad_token, self.eos_token, self.bos_token = '<pad>', '<eos>', '<bos>'
        # reserved ids, in the same order as `reserved_toks` below
        self.pad_token_id, self.eos_token_id, self.bos_token_id = 0,1,2
        self.reserved_toks = [self.pad_token, self.eos_token, self.bos_token]
        self.reserved_tokens_bytes = [bytes(rtok, 'ascii') for rtok in self.reserved_toks]
        # byte values are shifted up by this offset so they never collide with the reserved ids
        self.numres = len(self.reserved_toks)
        self.int2byte = six.int2byte

    def _finalize(self, ids):
        "Add the optional bos/eos ids to `ids` and wrap them in the tensor type selected by `is_lm`."
        if self.add_bos: ids = [self.bos_token_id] + ids
        if self.add_eos: ids = ids + [self.eos_token_id]
        return LMTensorText(ids) if self.is_lm else TensorText(ids)

    @typedispatch
    def __call__(self, o:list, **kwargs):
        "Encode a list of strings as ONE concatenated id sequence (bos/eos are added once, not per string)."
        return self._finalize([c + self.numres for s in o for c in s.encode("utf-8")])

    @typedispatch
    def __call__(self, o:str, **kwargs):
        "Encode a single string to an id sequence."
        return self._finalize([c + self.numres for c in o.encode("utf-8")])

    def encodes(self,o):
        "Pipeline encode hook: delegate to the type-dispatched `__call__`."
        return self.__call__(o)

    def _decode_ids(self, ids):
        "Map a flat sequence of ids back to a `str`; invalid UTF-8 byte runs are replaced rather than raising."
        chunks = []
        for id_ in ids:
            id_ = int(id_)  # accept 0-d tensor elements as well as Python ints
            if 0 <= id_ < self.numres: chunks.append(self.reserved_tokens_bytes[id_])
            else: chunks.append(self.int2byte(id_ - self.numres))
        return b"".join(chunks).decode("utf-8", "replace")

    def _decode_tensor(self, t):
        "Decode a 1-D id tensor to a string, or a 2-D (batch, seq) tensor row by row (same output as the tuple case)."
        rows = t.tolist()
        # `tolist()` on a batch yields nested lists, which the flat decoder cannot handle
        if t.ndim > 1: return TitledStr(tuple(self._decode_ids(row) for row in rows))
        return TitledStr(self._decode_ids(rows))

    # fastai's Transform metaclass merges these same-named `decodes` definitions into one
    # type-dispatched method, so the redefinitions below are intentional.
    def decodes(self, o:tuple):
        "Decode a batch (tuple of id sequences); the result is a `TitledStr` built from the tuple of strings."
        return TitledStr(tuple(self._decode_ids(i) for i in o))

    def decodes(self, o:list):
        "Decode a flat list of ids to a string."
        return TitledStr(self._decode_ids(o))

    def decodes(self, o:TensorText):
        return self._decode_tensor(o)

    def decodes(self, o:LMTensorText):
        return self._decode_tensor(o)

    @property
    def vocab_size(self):
        "Number of ids: 256 possible byte values plus the reserved special-token ids."
        return 2**8 + self.numres

In [None]:
wonder = "I wonder how the moon got it's shine?"
tok = ByteTextTokenizer()
tok_wonder = tok(wonder)

# a single string and a one-element list of strings tokenize identically
assert (tok(wonder) == tok([wonder])).all()
# language-model mode (the default) returns exactly LMTensorText
assert type(tok_wonder) is LMTensorText
# one id per UTF-8 byte, no bos/eos by default
assert len(tok_wonder) == len(wonder.encode("utf-8")) == 37
# decoding round-trips back to the original string
assert tok.decode(tok_wonder) == wonder
# non-ASCII characters become several ids (one per UTF-8 byte) and still round-trip
assert len(tok("café")) == len("café".encode("utf-8")) == 5
assert tok.decode(tok("café")) == "café"

In [None]:
# add_bos prepends exactly one <bos> id
tok2 = ByteTextTokenizer(add_bos=True)
tok_wonder2 = tok2(wonder)
assert tok_wonder2[0] == tok2.bos_token_id
assert tok2.bos_token_id == 2
assert len(tok_wonder2) == len(tok_wonder) + 1
# reserved ids decode to their token text
assert tok2.decode(tok_wonder2) == '<bos>' + wonder

# add_eos appends exactly one <eos> id
tok3 = ByteTextTokenizer(add_eos=True)
tok_wonder3 = tok3(wonder)
assert tok_wonder3[-1] == tok3.eos_token_id
assert tok3.eos_token_id == 1
assert len(tok_wonder3) == len(tok_wonder) + 1
assert tok3.decode(tok_wonder3) == wonder + '<eos>'

In [None]:
#hide
# Convert the project's notebooks into library modules (prints one "Converted ..." line per notebook).
from nbdev.export import notebook2script; notebook2script()

Converted 00_core.ipynb.
Converted 01_layers.ipynb.
Converted 02_attention.ipynb.
Converted 03_transformer.ipynb.
Converted 04_reformer.ipynb.
Converted 05_tokenizers.ipynb.
Converted 05_tokenizers_misktak.ipynb.
Converted 06_data.ipynb.
Converted index.ipynb.
