**Install dependencies**

In [1]:
# !pip install -q -U tensorflow-text
# !pip install -q tf-models-official

**Python modules**

In [2]:
import os
import shutil

import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_text as text
from official.nlp import optimization  # to create AdamW optimizer

import matplotlib.pyplot as plt

# Only show ERROR-level (and above) messages from TensorFlow's Python logger,
# keeping notebook output free of INFO/WARNING noise.
tf.get_logger().setLevel('ERROR')

**Project dependencies**

In [6]:
from config import config
from db import db

from bert.handler import get_encoder_and_preprocess_handlers

**TensorFlow Hub settings**

In [16]:
def create_env(model_path=None):
    """Point TensorFlow Hub's download cache at a directory under the model path.

    Unless ``TFHUB_CACHE_DIR`` is set, TensorFlow Hub caches downloaded modules
    in a location of its own choosing that is hard to find. Pinning it to
    ``<model_path>/tfhub_modules`` keeps the downloaded models in a known place
    where later runs can reuse them.

    Downloading from tfhub.dev may require unrestricted internet access (for
    example a VPN where the site is blocked). Make sure that access is available
    the first time this runs.

    Args:
        model_path: Base directory for model files. Defaults to
            ``config.model_path`` when omitted.
    """
    base_path = config.model_path if model_path is None else model_path
    tfhub_path = os.path.join(base_path, "tfhub_modules")
    # exist_ok=True avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(tfhub_path, exist_ok=True)
    # TF Hub consults this variable when resolving a handle, so it must be set
    # before hub.KerasLayer / hub.load is called with a URL.
    os.environ['TFHUB_CACHE_DIR'] = tfhub_path

create_env()

In [5]:
# Load the articles and encode each topic string as an integer class id.
data = db.read()
# Class ids follow the order in which topics first appear in the data.
keyword2index = {keyword: i for i, keyword in enumerate(data.topic.unique())}
index2keyword = {v: k for k, v in keyword2index.items()}
# Use Series.map rather than .replace. Replacing strings with ints via .replace
# relies on implicit downcasting that pandas 2.2 deprecates, which can leave an
# object-dtype column. tf.data cannot convert such a column to a tensor.
data['target'] = data['topic'].map(keyword2index)

In [7]:
# Resolve the TF Hub handles for the BERT encoder and its matching preprocessing model.
(tfhub_handle_encoder,
 tfhub_handle_preprocess) = get_encoder_and_preprocess_handlers()

BERT model selected           : https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-4_H-512_A-8/1
Preprocess model auto-selected: https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3


In [17]:
# Wrap both TF Hub handles as Keras layers. The preprocessor turns raw strings
# into BERT inputs (input_word_ids, ...), and the encoder consumes those inputs.
bert_preprocess_model = hub.KerasLayer(handle=tfhub_handle_preprocess)
bert_model = hub.KerasLayer(handle=tfhub_handle_encoder)

In [24]:
# Sanity check: run the preprocessor on the first article, passed as a batch of one.
bert_preprocess_model([data['article_content'].iloc[0]])

{'input_word_ids': <tf.Tensor: shape=(1, 128), dtype=int32, numpy=
 array([[  101,  2013, 16948,  1010,  1996,  2489, 12204,  5376,  2000,
          9163,  5376,  2000,  3945,  1012, 12464,  1011, 11968,  8043,
          1011,  6434,  1012,  6045, 22074,  1063, 15489,  1011,  2806,
          1024,  2009, 27072,  1065,  1012, 12464,  1011, 11968,  8043,
          1011,  6434,  4487,  2615,  1012,  6045, 22074,  1063, 11687,
          4667,  1011,  2187,  1024,  1015,  1012,  1020,  6633,  1025,
          7785,  1011,  3953,  1024,  1014,  1012,  1019,  6633,  1065,
          1012, 12464,  1011, 11968,  8043,  1011,  6434,  1012,  6045,
         22074,  1045,  1063, 15489,  1011,  2806,  1024,  3671,  1065,
          1012, 12464,  1011, 11968,  8043,  1011,  6434,  1012,  6045,
         22074,  1009,  4957,  1009,  1012,  6045, 22074,  1063,  7785,
          1011,  2327,  1024,  1011,  1014,  1012,  1019,  6633,  1065,
          2005,  1996, 16948,  6410,  1997,  2806,  7175,  2189,  101

In [21]:
# Build (article text, target) example pairs, one per row of `data`.
dataset = tf.data.Dataset.from_tensor_slices(
    (data['article_content'].values, data['target'].values)
)