In [1]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow_text as text
# import warnings
# warnings.filterwarnings("ignore", category=np.VisibleDeprecationWarning)

# Preprocessing

## Import the data

In [2]:
def load_data(
    path_to_train_test: str,
    batch_size: int = 32,
    seed: int = 42,
    validation_split: float = 0.2,
) -> tuple:
    """Load the train/validation/test text datasets from disk.

    ``path_to_train_test`` must contain ``train`` and ``test`` sub-directories
    in the layout read by ``tf.keras.utils.text_dataset_from_directory``.
    The ``train`` directory is split into training and validation subsets;
    the ``test`` directory is loaded whole.

    Args:
        path_to_train_test: Directory containing ``train`` and ``test``.
        batch_size: Number of examples per batch for all three datasets.
        seed: Seed for shuffling and for the train/validation split. The
            same seed is used for both subsets so they are complementary,
            and for the test set so its order is reproducible.
        validation_split: Fraction of ``train`` held out for validation.

    Returns:
        A 4-tuple ``(train_ds, val_ds, test_ds, class_names)``. The three
        datasets are cached and prefetched; ``class_names`` is read from the
        training dataset returned by ``text_dataset_from_directory``.
    """
    autotune = tf.data.AUTOTUNE

    def _cache_and_prefetch(ds):
        # Cache elements after the first pass and overlap loading with compute.
        return ds.cache().prefetch(buffer_size=autotune)

    train_dir = f'{path_to_train_test}/train'
    # The 'training' and 'validation' subsets are only complementary when
    # both calls share the same validation_split and seed.
    split_kwargs = dict(
        batch_size=batch_size,
        validation_split=validation_split,
        seed=seed,
    )
    raw_train_ds = tf.keras.utils.text_dataset_from_directory(
        train_dir, subset='training', **split_kwargs)
    raw_val_ds = tf.keras.utils.text_dataset_from_directory(
        train_dir, subset='validation', **split_kwargs)
    # Seeded so the (shuffled-by-default) test order is reproducible.
    raw_test_ds = tf.keras.utils.text_dataset_from_directory(
        f'{path_to_train_test}/test',
        batch_size=batch_size,
        seed=seed)

    # Read class_names before wrapping: the datasets returned by
    # .cache()/.prefetch() do not carry this attribute.
    class_names = raw_train_ds.class_names

    return (
        _cache_and_prefetch(raw_train_ds),
        _cache_and_prefetch(raw_val_ds),
        _cache_and_prefetch(raw_test_ds),
        class_names,
    )


In [3]:
# Load the text-only dataset: train/validation split from `train`, plus `test`.
train_ds, val_ds, test_ds, class_names = load_data(
    path_to_train_test='data/text_only')

Found 1582 files belonging to 2 classes.
Using 1266 files for training.
Found 1582 files belonging to 2 classes.
Using 316 files for validation.
Found 396 files belonging to 2 classes.


## Tokenise the data

In [4]:
# Restore the tokenizer from its TensorFlow SavedModel directory.
# NOTE(review): `tensorflow_text` is imported in the first cell but not used
# by name; presumably it is needed so this SavedModel's text ops are
# registered before loading — confirm.
model_name = 'saved_models/fake_news_bert_tokenizer'
tokenizer = tf.saved_model.load(export_dir=model_name)

Work out the input sequence length from the token counts of a sample of training examples.

In [5]:
# Collect per-example token counts for a sample of the training data.
# NOTE(review): only the first batch is tokenised (`take(1)`, i.e. one
# `batch_size` batch), so the maximum derived from `lengths` in the next cell
# is a sample estimate, not the maximum over the whole training set.
lengths = []

for txts, _ in train_ds.take(1).prefetch(2):
    for txt in txts.numpy():
        # Tokenise one example at a time, passed as a one-element list of
        # decoded strings.
        tokens = tokenizer.tokenize([txt.decode('utf-8')])
        # Assumes the tokenizer returns a 2-D RaggedTensor [examples, tokens],
        # so this is the example's token count — TODO confirm.
        lengths.append(tokens.row_lengths())
        # Progress indicator: one dot per tokenised example.
        print('.', end='', flush=True)

................................

In [6]:
# Histogram of per-example token counts from the sampled examples, with the
# observed maximum marked by a vertical line.
all_lengths = np.concatenate(lengths)
# Plain Python int rather than a NumPy scalar.
max_length = int(all_lengths.max())

fig, ax = plt.subplots()
ax.hist(all_lengths, bins=np.linspace(0, 500, 101))
# axvline spans the full y-range without changing the histogram's y-limits.
ax.axvline(max_length, color='C1')
ax.set_title(f'Maximum tokens per example: {max_length}')
ax.set_xlabel('Tokens per example')
ax.set_ylabel('Number of examples')
plt.show()

: 

# Define the Model