In [19]:
# default_exp data

In [20]:
#hide
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [21]:
#hide
from nbdev.showdoc import *

In [22]:
#export
from fastcore.all import *
from fastai.basics import Transform, ItemTransform
from fastai.text.all import *
from transformers import AutoTokenizer, AutoConfig

# Data

> Transforms and DataBlocks.

## Transforms

In [23]:
#export
class TokBatchTransform(Transform):
    """Tokenize a whole batch of raw samples at once with a Huggingface tokenizer.

    Parameters
    - pretrained_model_name / tokenizer_cls / config: used to build the tokenizer
      via `tokenizer_cls.from_pretrained` when `tokenizer` is not given.
    - tokenizer: an already constructed tokenizer; takes precedence over the above.
    - is_lm: stored on the instance; not read by `encodes`.
    - with_labels: store the first label (or the tokenized text targets) in the
      tokenizer output under the `'labels'` key instead of returning it separately.
    - padding / truncation / max_length: forwarded to every tokenizer call.
    - do_targets: when the second element of each sample is a `str`, tokenize it
      (without special tokens) and use the resulting `input_ids` as targets.
    - **kwargs: extra keyword arguments forwarded to every tokenizer call.
    """
    def __init__(self, pretrained_model_name=None, tokenizer_cls=AutoTokenizer, 
                 config=None, tokenizer=None, is_lm=False, with_labels=False,
                 padding=True, truncation=True, max_length=None, 
                 do_targets=False, **kwargs):
        if tokenizer is None:
            tokenizer = tokenizer_cls.from_pretrained(pretrained_model_name, config=config)
        self.tokenizer = tokenizer
        self.kwargs = kwargs
        # stores every named argument (including the resolved `tokenizer`) on `self`
        store_attr()

    def _tokenize(self, *texts, add_special_tokens=True):
        "Call the tokenizer on `texts` with the padding/truncation settings stored on `self`"
        return self.tokenizer(*texts,
                              add_special_tokens=add_special_tokens,
                              padding=self.padding,
                              truncation=self.truncation,
                              max_length=self.max_length,
                              return_tensors='pt',
                              **self.kwargs)

    def encodes(self, batch):
        """Tokenize `batch` and return a tuple whose first element is the tokenizer output.

        `batch` is either a list of raw texts or a list of tuples of
        ({text or (text1, text2)}, {targets...}).
        """
        # Check the sample container first: indexing into a raw (possibly empty) string
        # to probe for a text pair would fail on "".
        if not is_listy(batch[0]): # batch is list of texts
            texts = (list(batch),)
            batch = [(s, ) for s in batch]
        elif is_listy(batch[0][0]): # 1st element is a (text1, text2) pair
            texts = ([s[0][0] for s in batch], [s[0][1] for s in batch])
        else:
            texts = ([s[0] for s in batch],)

        inps = self._tokenize(*texts, add_special_tokens=True)

        n_extra = len(batch[0]) - 1  # number of elements following the text(s)
        if self.do_targets and n_extra > 0 and isinstance(batch[0][1], str):
            # text targets: tokenize them the same way as the inputs, minus special tokens
            targets = self._tokenize([s[1] for s in batch], add_special_tokens=False).input_ids
            # anything after the text target is collated like ordinary labels
            extra = tuple(default_collate([s[2:] for s in batch])) if n_extra > 1 else ()
            if self.with_labels:
                inps['labels'] = targets
                res = (inps, ) + extra
            else:
                res = (inps, targets) + extra
        else:
            # inps are batched, collate targets into batches too
            labels = tuple(default_collate([s[1:] for s in batch])) if n_extra > 0 else ()
            if self.with_labels and len(labels) > 0:
                # TODO consider cases when there are multiple labels
                inps['labels'] = labels[0]
                res = (inps, )
            else:
                # also covers label-less batches (e.g. inference) when `with_labels` is set
                res = (inps, ) + labels
        return res
    
    def decodes(self, x:TensorText):
        "Turn a tensor of token ids back into text, dropping special tokens"
        return TitledStr(self.tokenizer.decode(x.cpu(), skip_special_tokens=True))

In [24]:
#export
class Undict(ItemTransform):
    "Unpack the dict produced by the tokenizer into a tuple when decoding a batch"
    
    def decodes(self, b):
        """Return `(TensorText(input_ids), [labels], *b[1:])` built from the dict in `b[0]`.

        If `b[0]` is not dict-like the batch is returned unchanged (as a tuple).
        """
        # this is done hacky way to make show_batch work both when labels are separate and when in dict
        # should be a better way
        x = b[0]
        # nothing to unpack; a membership test with a str key would fail on non-mappings
        if not hasattr(x, 'keys'): return tuple(b)
        # start from an empty tuple so a dict without 'input_ids' does not leave `res` unbound
        res = ()
        if 'input_ids' in x: res += (TensorText(x['input_ids']), )
        if 'labels' in x: res += (x['labels'], )
        return res + tuple(b[1:])

## DataBlocks

In [25]:
#export
class TransformersTextBlock(TransformBlock):
    "A `TransformBlock` for texts using pretrained tokenizers from Huggingface"
    @delegates(TokBatchTransform)
    def __init__(self, pretrained_model_name=None, tokenizer_cls=AutoTokenizer, 
                 config=None, tokenizer=None, is_lm=False, **kwargs):
        """Build the block around a `TokBatchTransform`.

        `is_lm` selects `LMDataLoader` over `SortedDL` as the dataloader type; all other
        arguments (including `**kwargs`) are passed on to `TokBatchTransform`.
        """
        # `is_lm` is a named parameter, so it never ends up in `kwargs`; forward it
        # explicitly so the transform's stored `is_lm` matches the block.
        before_batch_tfm = TokBatchTransform(pretrained_model_name=pretrained_model_name,
                                             tokenizer_cls=tokenizer_cls,
                                             config=config, tokenizer=tokenizer,
                                             is_lm=is_lm, **kwargs)
        # Tokenization runs on the whole list of samples (`before_batch`), and
        # `create_batch` is set to `fa_convert`.
        super().__init__(dl_type=LMDataLoader if is_lm else SortedDL,
                         dls_kwargs={'before_batch': before_batch_tfm,
                                     'create_batch': fa_convert},
                         batch_tfms=Undict())

#     @classmethod
#     def from_pretrained(cls, ):
#         pass

#     @classmethod
#     def from_tokenizer(cls, ):
#         pass

#     @classmethod
#     def from_config(cls, ):
#         pass

In [26]:
#hide
# Run nbdev's notebook export; the cell output lists each notebook it converted.
from nbdev.export import notebook2script; notebook2script()

Converted 00_data.ipynb.
Converted 01_learner.ipynb.
Converted index.ipynb.
