In [1]:
import collections
import logging
import os
import pathlib
import re
import string
import sys
import tempfile
import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import tensorflow_datasets as tfds
import tensorflow_text as text
import tensorflow as tf
from tensorflow_text.tools.wordpiece_vocab import bert_vocab_from_dataset as bert_vocab

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
tf.get_logger().setLevel('ERROR')
pwd = pathlib.Path.cwd()

In [3]:
AUTOTUNE = tf.data.AUTOTUNE
batch_size = 32
seed = 42

raw_train_ds = tf.keras.utils.text_dataset_from_directory(
    'data/train',
    batch_size=batch_size,
    validation_split=0.2,
    subset='training',
    seed=seed)

class_names = raw_train_ds.class_names
train_ds = raw_train_ds.cache().prefetch(buffer_size=AUTOTUNE)

val_ds = tf.keras.utils.text_dataset_from_directory(
    'data/train',
    batch_size=batch_size,
    validation_split=0.2,
    subset='validation',
    seed=seed)

val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

test_ds = tf.keras.utils.text_dataset_from_directory(
    'data/test',
    batch_size=batch_size)

test_ds = test_ds.cache().prefetch(buffer_size=AUTOTUNE)

Found 1582 files belonging to 2 classes.
Using 1266 files for training.
Found 1582 files belonging to 2 classes.
Using 316 files for validation.
Found 396 files belonging to 2 classes.


Load CSV datasets into `tf.Dataset` format

In [4]:
# train_path = 'data/train.csv'

# # Automatically determines the data types.
# # Specify the label's column name if you have a target variable.
# train_articles = tf.data.experimental.make_csv_dataset(
#     train_path,
#     batch_size=1,  # Adjust based on your needs
#     label_name='is_fake',
#     na_value="?",
#     num_epochs=1,
#     ignore_errors=True,
# )

In [5]:
for texts, labels in train_ds.take(1):
    print(texts.numpy()[0].decode('utf-8'))
    print(labels.numpy()[0])

huma abedin swore under oath she had no emails she lied sioux indians wish dakota pipeline protesters would go home october   daniel greenfield the ecoloons protesting the dakota pipeline have the support of the media and the white house but many of the local sioux dont see them as defenders they just wish they would go home  ask around and youll hear stories of pipeline protesters whove traveled great distances theyve come from japan russia and germany australia israel and serbia and of course there are the allies not exclusively native american or indigenous whove flocked here from all corners of the us demonstrating is their proud daily work the obnoxious leftists of the world have united and they want no pipelines or showers no one makes this clearer than robert fool bear sr  district chairman of cannon ball the town he runs estimated population of  is just a few miles from the action its so close that given the faceoffs with law enforcement you have to pass through a police checkp

Train tokeniser on our data

In [6]:
train_set = train_ds.map(lambda text, label: text)

In [7]:
bert_tokenizer_params=dict(lower_case=True)
reserved_tokens=["[PAD]", "[UNK]", "[START]", "[END]"]

bert_vocab_args = dict(
    # The target vocabulary size
    vocab_size = 8000,
    # Reserved tokens that must be included in the vocabulary
    reserved_tokens=reserved_tokens,
    # Arguments for `text.BertTokenizer`
    bert_tokenizer_params=bert_tokenizer_params,
    # Arguments for `wordpiece_vocab.wordpiece_tokenizer_learner_lib.learn`
    learn_params={},
)

In [8]:
# %%time
train_vocab = bert_vocab.bert_vocab_from_dataset(
    train_set.batch(3).prefetch(2),
    **bert_vocab_args
)

Take a look at the vocab it has come up with

In [9]:
print(train_vocab[:10])
print(train_vocab[100:110])
print(train_vocab[1000:1010])
print(train_vocab[-10:])

['[PAD]', '[UNK]', '[START]', '[END]', '1', '2', '?', '_', 'a', 'b']
['up', 'she', '##ly', 'like', 'just', 'them', 'also', 'do', 'new', 'than']
['leading', 'protests', 'significant', '##el', 'air', 'effort', 'iraqi', 'judge', 'ones', 'organizations']
['witnessing', 'ww', 'yelled', 'yellow', '##1', '##2', '##?', '##_', '##ø', '##⁄']


Save the vocab to a file

In [10]:
def write_vocab_file(filepath, vocab):
    with open(filepath, 'w', encoding='utf-8') as f:
          for token in vocab:
            print(token, file=f)

In [22]:
vocab_file = 'data/train_vocab.txt'
write_vocab_file(vocab_file, train_vocab)

In [12]:
train_tokeniser = text.BertTokenizer('data/train_vocab.txt', **bert_tokenizer_params)

Take 3 examples to see how BERT tokenises them

In [13]:
examples = []
for texts, _ in train_ds.take(1):
    for i in range(3):
        examples.append(texts[i])
for ex in examples:
    print(ex.numpy().decode())

how hillary courts the black vote last majority white election minorities to give democrats white house through  paul bedard washington examiner november   this may be the last year that a republican will be elected president as the growth of liberalleaning minorities all but guarantees that democrats will hold the majority at least for the next four decades thats according to an unusual survey on the impact of minorities mostly latin american done for wallethub and provided to secrets  the study used two models based on population projections and matched to the overwhelming  percent minority turnout for president obama in  and the underwhelming  percent response for george w bush in  the bottom line in no presidential election from  do the republicans win snip snip
obama talks about himself  times in speeches supposedly about hillary  things you need to know about trump and sex slave island by amanda prestigiacomo may   presumptive republican presidential nominee donald trump has thro

In [14]:
token_batch = train_tokeniser.tokenize(examples)
token_batch.merge_dims(-2, -1)

for ex in token_batch.to_list():
    print(ex)

[[117], [88], [2297], [36], [266], [189], [205], [585], [152], [113], [2060], [37], [378], [309], [152], [203], [196], [777], [3095, 1857], [308], [5869, 167], [232], [49], [173], [53], [36], [205], [250], [41], [8], [239], [71], [53], [639], [130], [47], [36], [804], [38], [1120, 6584], [2060], [74], [65], [6840, 43], [41], [309], [71], [1266], [36], [585], [60], [341], [44], [36], [247], [598], [966], [276], [209], [37], [62], [5670], [2950], [45], [36], [1213], [38], [2060], [1215], [521, 6649], [160], [460], [44], [846, 6536, 3404], [39], [1322], [37], [6433], [36], [646], [255], [158], [5946], [632], [45], [600], [645, 919], [39], [4452, 75], [37], [36], [4691], [242], [1849], [2326], [44], [130], [143], [40], [39], [36], [221, 1088, 1729, 3515, 70], [242], [1056], [44], [756], [30], [1105], [40], [36], [2638], [655], [40], [87], [230], [113], [56], [107], [36], [343], [457], [1826], [1826]]
[[143], [1894], [77], [641], [240], [40], [2542], [3149], [77], [88], [257], [55], [225], 

In [15]:
# replace IDs with their text representations
txt_tokens = tf.gather(train_vocab, token_batch)
texts_tokenised = tf.strings.reduce_join(txt_tokens, separator=' ', axis=-1)
for text_tokenised in texts_tokenised:
    print(' '.join([term.decode('utf-8') for term in text_tokenised.numpy()]))

how hillary courts the black vote last majority white election minorities to give democrats white house through paul bed ##ard washington examine ##r november this may be the last year that a republican will be elected president as the growth of liberal ##leaning minorities all but guarantee ##s that democrats will hold the majority at least for the next four decades thats according to an unusual survey on the impact of minorities mostly la ##tin american done for wall ##eth ##ub and provided to secrets the study used two models based on population project ##ions and match ##ed to the overwhelming percent minority turnout for president obama in and the under ##w ##he ##lm ##ing percent response for george w bush in the bottom line in no presidential election from do the republicans win snip snip
obama talks about himself times in speeches supposedly about hillary things you need to know about trump and sex slave island by am ##anda p ##res ##ti ##g ##ia ##com ##o may p ##res ##ump ##ti

In [16]:
# reassemble words from extracted tokens
words = train_tokeniser.detokenize(token_batch)
reassembled_words = tf.strings.reduce_join(words, separator=' ', axis=-1)
for words in reassembled_words:
    print(' '.join([term.decode('utf-8') for term in words.numpy()]))

how hillary courts the black vote last majority white election minorities to give democrats white house through paul bedard washington examiner november this may be the last year that a republican will be elected president as the growth of liberalleaning minorities all but guarantees that democrats will hold the majority at least for the next four decades thats according to an unusual survey on the impact of minorities mostly latin american done for wallethub and provided to secrets the study used two models based on population projections and matched to the overwhelming percent minority turnout for president obama in and the underwhelming percent response for george w bush in the bottom line in no presidential election from do the republicans win snip snip
obama talks about himself times in speeches supposedly about hillary things you need to know about trump and sex slave island by amanda prestigiacomo may presumptive republican presidential nominee donald trump has thrown everything

Customisation and export

In [17]:
START = tf.argmax(tf.constant(reserved_tokens) == '[START]')
END = tf.argmax(tf.constant(reserved_tokens) == '[END]')

def add_start_end(ragged):
    count = ragged.bounding_shape()[0] # num rows
    starts = tf.fill([count, 1], START) # num rows x 1 tensor of STARTs
    ends = tf.fill([count, 1], END)
    return tf.concat([starts, ragged, ends], axis=1)

In [18]:
# words = en_tokeniser.detokenize(add_start_end(token_batch))
# tf.strings.reduce_join(words, separator=' ', axis=-1)

In [19]:
def cleanup_text(reserved_tokens, token_txt):
    # Drop the reserved tokens, except for "[UNK]".
    bad_tokens = [re.escape(tok) for tok in reserved_tokens if tok != "[UNK]"]
    bad_token_re = "|".join(bad_tokens)

    bad_cells = tf.strings.regex_full_match(token_txt, bad_token_re)
    result = tf.ragged.boolean_mask(token_txt, ~bad_cells)

    # Join them into strings.
    result = tf.strings.reduce_join(result, separator=' ', axis=-1)

    return result

In [20]:
class CustomTokenizer(tf.Module):
    def __init__(self, reserved_tokens, vocab_path):
        self.tokenizer = text.BertTokenizer(vocab_path, lower_case=True)
        self._reserved_tokens = reserved_tokens
        self._vocab_path = tf.saved_model.Asset(vocab_path)

        vocab = pathlib.Path(vocab_path) \
            .read_text(encoding='utf-8').splitlines()
        self.vocab = tf.Variable(vocab)

        ## Create the signatures for export:   

        # Include a tokenize signature for a batch of strings. 
        self.tokenize.get_concrete_function(
            tf.TensorSpec(shape=[None], dtype=tf.string))

        # Include `detokenize` and `lookup` signatures for:
        #   * `Tensors` with shapes [tokens] and [batch, tokens]
        #   * `RaggedTensors` with shape [batch, tokens]
        self.detokenize.get_concrete_function(
            tf.TensorSpec(shape=[None, None], dtype=tf.int64))
        self.detokenize.get_concrete_function(
            tf.RaggedTensorSpec(shape=[None, None], dtype=tf.int64))

        self.lookup.get_concrete_function(
            tf.TensorSpec(shape=[None, None], dtype=tf.int64))
        self.lookup.get_concrete_function(
            tf.RaggedTensorSpec(shape=[None, None], dtype=tf.int64))

        # These `get_*` methods take no arguments
        self.get_vocab_size.get_concrete_function()
        self.get_vocab_path.get_concrete_function()
        self.get_reserved_tokens.get_concrete_function()

    @tf.function
    def tokenize(self, strings):
        enc = self.tokenizer.tokenize(strings)
        # Merge the `word` and `word-piece` axes.
        enc = enc.merge_dims(-2,-1)
        enc = add_start_end(enc)
        return enc

    @tf.function
    def detokenize(self, tokenized):
        words = self.tokenizer.detokenize(tokenized)
        return cleanup_text(self._reserved_tokens, words)

    @tf.function
    def lookup(self, token_ids):
        return tf.gather(self.vocab, token_ids)

    @tf.function
    def get_vocab_size(self):
        return tf.shape(self.vocab)[0]

    @tf.function
    def get_vocab_path(self):
        return self._vocab_path

    @tf.function
    def get_reserved_tokens(self):
        return tf.constant(self._reserved_tokens)

In [23]:
tokenizer = CustomTokenizer(reserved_tokens, vocab_file)

In [24]:
model_name = 'fake_news_bert_tokenizer'
tf.saved_model.save(tokenizer, model_name)

In [25]:
reloaded_tokenizer = tf.saved_model.load(model_name)
reloaded_tokenizer.get_vocab_size().numpy()

7656

In [26]:
tokens = reloaded_tokenizer.tokenize(['Hello TensorFlow!'])
tokens.numpy()

array([[   2, 1184,  206, 2777,  427,  565, 3997,    1,    3]],
      dtype=int64)

In [37]:
text_tokens = reloaded_tokenizer.lookup(tokens)
print(' '.join([term.decode('utf-8') for term in text_tokens.numpy()[0]]))
    

[START] hell ##o tens ##or ##f ##low [UNK] [END]


In [28]:
round_trip = reloaded_tokenizer.detokenize(tokens)

print(round_trip.numpy()[0].decode('utf-8'))

hello tensorflow [UNK]
