In [None]:
%run ./../utils/_logger.ipynb
%run ./../utils/_preprocess-utils.ipynb

In [None]:
from pathlib import Path

import numpy as np
import tensorflow as tf
import tensorflow_hub as hub
# Not referenced in this section; presumably imported for its side effect of
# registering the TF Text ops the ALBERT preprocessor SavedModel needs — keep it.
import tensorflow_text as text
from tqdm import tqdm
from tensorflow.keras import Model
from tensorflow.keras.layers import Input

In [None]:
# --- TF Hub assets -----------------------------------------------------------
# Preprocessor: provides `tokenize` and `bert_pack_inputs`, used to encode the corpus below.
TRF_PREPROCESSOR_HANDLE = 'https://tfhub.dev/tensorflow/albert_en_preprocess/3'
# Encoder handle; not used in this section (presumably consumed by a later cell).
TRF_MODEL_HANDLE = 'https://tfhub.dev/tensorflow/albert_en_base/3'

# --- Encoding settings -------------------------------------------------------
# Length every document is packed to (passed as `seq_length` to bert_pack_inputs).
SEQ_LEN = 512
# When True, `input_type_ids` is stored as a third channel next to ids and mask.
USE_TYPE_IDS = True

# --- Cache location ----------------------------------------------------------
TRF_DATA_DIR = f'{PROC_DATA_DIR}/trf'
TRF_DATA_PATH = f'{TRF_DATA_DIR}/trf_data.np'

In [None]:
# Text-cleaning callables handed to `apply_filters` for every document.
# NOTE(review): order presumably matters (applied left to right) — confirm in apply_filters.
DOCUMENT_FILTERS = (
    deaccent,
    lower_to_unicode,
    strip_tags,
    strip_multiple_whitespaces,
)

In [None]:
# Build (or load from the on-disk cache) `trf_data`, the ALBERT-ready encoding of df[doc_col].
#
# Layout: the packed preprocessor outputs are np.stack-ed along axis 1, so axis 1 is the
# channel axis ordered (input_word_ids, input_mask[, input_type_ids]). The packer is called
# with seq_length=SEQ_LEN, so the expected shape is (n_docs, n_channels, SEQ_LEN).
#
# Cache policy:
#   * a cached file is reused only if it loads cleanly AND its trailing shape matches the
#     current USE_TYPE_IDS / SEQ_LEN settings; otherwise it is rebuilt, so changing those
#     settings can no longer silently feed stale data downstream;
#   * the cache is written to a temp file and renamed into place, so an interrupted write
#     cannot leave a truncated file that would later pass `is_file()` and break `np.load`.
# NOTE(review): the cache is not keyed on the documents themselves; delete TRF_DATA_PATH
# by hand if `df` changes.


def _encode_corpus(docs):
    """Clean raw documents and encode them with the TF Hub ALBERT preprocessor.

    Args:
        docs: iterable of raw document strings (here `df[doc_col].values`).

    Returns:
        np.ndarray stacking the packed inputs along axis 1, channels ordered
        (input_word_ids, input_mask[, input_type_ids]) according to USE_TYPE_IDS.
    """
    logger.info("Preprocessing corpus...")
    docs = [apply_filters(doc, filters=DOCUMENT_FILTERS) for doc in tqdm(docs, disable=SILENT)]

    logger.info("Replacing special characters...")
    docs = [sub_pattern(doc, pattern=SUB_PATTERN) for doc in tqdm(docs, disable=SILENT)]

    logger.info("Removing unprintable characters...")
    docs = [remove_unprintable(doc) for doc in tqdm(docs, disable=SILENT)]

    logger.info("Tokenizing corpus...")
    preprocessor = hub.load(TRF_PREPROCESSOR_HANDLE)
    tokenizer = hub.KerasLayer(preprocessor.tokenize)
    inputs_packer = hub.KerasLayer(preprocessor.bert_pack_inputs, arguments=dict(seq_length=SEQ_LEN))

    # Single-segment Keras model: one string input -> tokenize -> pack to SEQ_LEN.
    text_inputs = [Input(shape=(), dtype=tf.string)]
    packed_inputs = inputs_packer([tokenizer(segment) for segment in text_inputs])
    packing_model = Model(text_inputs, packed_inputs)
    encoded = packing_model.predict(tf.constant(docs), batch_size=8, verbose=not SILENT)

    channels = ['input_word_ids', 'input_mask']
    if USE_TYPE_IDS:
        channels.append('input_type_ids')
    return np.stack([encoded[name] for name in channels], axis=1)


def _load_cached_trf_data(path, expected_tail):
    """Return the cached encoded corpus at `path`, or None if it must be rebuilt.

    Args:
        path: location of the cache file written by `_save_trf_data`.
        expected_tail: required shape of axes 1.., i.e. (n_channels, SEQ_LEN).

    Returns:
        The loaded array, or None when the file is missing, unreadable, or was
        built with a different channel count / sequence length.
    """
    if not Path(path).is_file():
        return None
    try:
        with open(path, 'rb') as f:
            cached = np.load(f)
    except (OSError, ValueError, EOFError) as err:
        # e.g. a truncated file left behind by an interrupted (non-atomic) write.
        logger.warning(f"Cached encoded corpus at {path} is unreadable ({err}); rebuilding...")
        return None
    if cached.ndim != 3 or cached.shape[1:] != tuple(expected_tail):
        logger.warning(
            f"Cached encoded corpus has shape {cached.shape}, "
            f"expected trailing dims {tuple(expected_tail)}; rebuilding..."
        )
        return None
    return cached


def _save_trf_data(array, path):
    """Write `array` to `path` via a temp file + rename so readers never see a partial file."""
    target = Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)
    tmp = target.with_name(target.name + '.tmp')
    # Save through a file handle: np.save(<str path>) would append '.npy' to 'trf_data.np'.
    with open(tmp, 'wb') as f:
        np.save(f, array)
    # Path.replace -> os.replace: atomic rename when tmp and target share a filesystem.
    tmp.replace(target)


_expected_tail = (3 if USE_TYPE_IDS else 2, SEQ_LEN)

trf_data = _load_cached_trf_data(TRF_DATA_PATH, _expected_tail)
if trf_data is None:
    trf_data = _encode_corpus(df[doc_col].values)

    logger.info("Storing encoded corpus to disk...")
    _save_trf_data(trf_data, TRF_DATA_PATH)