# Load text

This notebook demonstrates the classification of EEG text reports from the Temple University Hospital EEG Corpus. The basic code structure is based on Example 1 in [this demo](https://www.tensorflow.org/tutorials/load_data/text).

First, let's install and import some useful libraries.

In [1]:
# Be sure you're using the stable versions of both tf and tf-text, for binary compatibility.
!pip install -q -U tensorflow
!pip install -q -U tensorflow-text

In [2]:
import collections
import pathlib
import re
import string

import tensorflow as tf

from tensorflow.keras import layers
from tensorflow.keras import losses
from tensorflow.keras import preprocessing
from tensorflow.keras import utils
from tensorflow.keras.layers.experimental.preprocessing import TextVectorization

import tensorflow_datasets as tfds
import tensorflow_text as tf_text

# Download and explore the dataset

First we'll use a handy tool called `gdown` to download the dataset (just the text reports) from the location on Google Drive where your team has stored it.

In [3]:
!gdown --id 1NuNQw_HT49c0Omb051xko1NvGTh1m0lx

Downloading...
From: https://drive.google.com/uc?id=1NuNQw_HT49c0Omb051xko1NvGTh1m0lx
To: /content/TUAB_txt_sorted.tar
9.37MB [00:00, 43.9MB/s]


The dataset is packed inside the archive file `TUAB_txt_sorted.tar`, so let's extract it (much like unzipping a zip file).

In [4]:
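import tarfile

# Extract every file in the archive into the current working directory.
# The `with` block closes the archive automatically.
with tarfile.open("TUAB_txt_sorted.tar") as tar:
  tar.extractall()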

Now we've extracted a folder called `TUAB_txt_sorted`. Let's use the `pathlib` library to explore this directory.

In [5]:
dataset_dir = pathlib.Path('TUAB_txt_sorted/v2.0.0/edf')
list(dataset_dir.iterdir())

[PosixPath('TUAB_txt_sorted/v2.0.0/edf/train'),
 PosixPath('TUAB_txt_sorted/v2.0.0/edf/eval')]

Locate the training data directory:

In [6]:
train_dir = dataset_dir/'train'
list(train_dir.iterdir())

[PosixPath('TUAB_txt_sorted/v2.0.0/edf/train/normal'),
 PosixPath('TUAB_txt_sorted/v2.0.0/edf/train/abnormal')]

Locate the evaluation data directory:

In [7]:
eval_dir = dataset_dir/'eval'
list(eval_dir.iterdir())

Within the training data, the reports are split into `abnormal` and `normal` classes. Abnormal reports:

In [8]:
abnormal_train_dir = train_dir/'abnormal/01_tcp_ar'
list(abnormal_train_dir.iterdir())

[PosixPath('TUAB_txt_sorted/v2.0.0/edf/train/abnormal/01_tcp_ar/081'),
 PosixPath('TUAB_txt_sorted/v2.0.0/edf/train/abnormal/01_tcp_ar/049'),
 PosixPath('TUAB_txt_sorted/v2.0.0/edf/train/abnormal/01_tcp_ar/012'),
 PosixPath('TUAB_txt_sorted/v2.0.0/edf/train/abnormal/01_tcp_ar/007'),
 PosixPath('TUAB_txt_sorted/v2.0.0/edf/train/abnormal/01_tcp_ar/101'),
 PosixPath('TUAB_txt_sorted/v2.0.0/edf/train/abnormal/01_tcp_ar/032'),
 PosixPath('TUAB_txt_sorted/v2.0.0/edf/train/abnormal/01_tcp_ar/008'),
 PosixPath('TUAB_txt_sorted/v2.0.0/edf/train/abnormal/01_tcp_ar/080'),
 PosixPath('TUAB_txt_sorted/v2.0.0/edf/train/abnormal/01_tcp_ar/086'),
 PosixPath('TUAB_txt_sorted/v2.0.0/edf/train/abnormal/01_tcp_ar/027'),
 PosixPath('TUAB_txt_sorted/v2.0.0/edf/train/abnormal/01_tcp_ar/039'),
 PosixPath('TUAB_txt_sorted/v2.0.0/edf/train/abnormal/01_tcp_ar/016'),
 PosixPath('TUAB_txt_sorted/v2.0.0/edf/train/abnormal/01_tcp_ar/062'),
 PosixPath('TUAB_txt_sorted/v2.0.0/edf/train/abnormal/01_tcp_ar/096'),
 ...

Normal reports:

In [9]:
normal_train_dir = train_dir/'normal/01_tcp_ar'
list(normal_train_dir.iterdir())

Evaluation data (number of subfolders in each class):

In [10]:
abnormal_eval_dir = eval_dir/'abnormal/01_tcp_ar'
print(len(list(abnormal_eval_dir.iterdir())))
normal_eval_dir = eval_dir/'normal/01_tcp_ar'
print(len(list(normal_eval_dir.iterdir())))

52
57


We see from the above output that the data is stored across many subfolders; the documentation for the TUAB set explains this folder structure. Below each of the arbitrarily numbered subfolders listed above is a further hierarchy of folders for individual subjects and recording sessions. You don't need to understand this structure in detail, because we'll use a function that finds the `.txt` files automatically. But let's take a look inside one of them.

In [11]:
sample_file = abnormal_train_dir/'007/00000739/s002_2012_09_26/00000739_s002.txt'
with open(sample_file) as f:
  print(f.read())

CLINICAL HISTORY:  A 64-year-old male found with an empty bottle of Januvia, and glucose of 99.  Given Narcan, Ativan and moxifloxacin.  He has a history of diabetes and neuropathy.  Previous EEG in 2003 is not available.
MEDICATIONS:  Ativan, Narcan, others.
INTRODUCTION:  Digital video EEG was performed at the bedside in the ICU using standard 10-20 system of electrode placement with one channel of EKG.  The patient was poorly responsive.
DESCRIPTION OF THE RECORD:  The background EEG demonstrates a mixed frequency background with spontaneous arousals and generous beta.
The patient drifts off to sleep and there is an increase in background slowing.  Occasional FIRDA was noted.
IMPRESSION:  This is an abnormal EEG due to:  Generalized background slowing.
CLINICAL CORRELATION:  The generous beta described above may be due to this patient's medications.  If epilepsy is an important consideration a follow up study when the patient is out of the ICU may be helpful.






### Load the dataset

Next, we will load the data off disk and prepare it in a format suitable for training. The `text_dataset_from_directory` utility makes this easy: it collects the `.txt` files below each class folder (including those in nested subfolders) and creates a `tf.data.Dataset` whose labels ('normal' and 'abnormal') are inferred automatically from the folder structure. (`tf.data` is a collection of tools for building input pipelines for machine learning.)
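As a rough plain-Python picture of how inferred labels are assigned: the `text_dataset_from_directory` documentation states that, unless `class_names` is passed explicitly, class names come from the subdirectory names in alphanumeric order, so `abnormal` becomes 0 and `normal` becomes 1. The helper `infer_labels` below is hypothetical and only mimics that mapping (it does not read, shuffle or batch anything); the throwaway directory tree is made up to resemble the TUAB layout.

```python
import pathlib
import tempfile


def infer_labels(root):
    """Return (class_names, [(path, label), ...]) for every .txt file below root."""
    root = pathlib.Path(root)
    # Class names are the immediate subdirectories, sorted alphanumerically.
    class_names = sorted(p.name for p in root.iterdir() if p.is_dir())
    examples = []
    for label, name in enumerate(class_names):
        # rglob also searches nested folders (subject and session subfolders).
        for path in sorted((root / name).rglob("*.txt")):
            examples.append((path, label))
    return class_names, examples


# Usage on a tiny temporary directory tree.
with tempfile.TemporaryDirectory() as tmp:
    for rel in ["normal/01_tcp_ar/001/a.txt",
                "abnormal/01_tcp_ar/007/b.txt",
                "abnormal/01_tcp_ar/012/c.txt"]:
        path = pathlib.Path(tmp) / rel
        path.parent.mkdir(parents=True, exist_ok=True)
        path.write_text("IMPRESSION: ...")
    class_names, examples = infer_labels(tmp)

print(class_names)                          # ['abnormal', 'normal']
print([label for _, label in examples])     # [0, 0, 1]
```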

In [12]:
full_train_ds = preprocessing.text_dataset_from_directory(train_dir, batch_size=16)

Found 2717 files belonging to 2 classes.


In [13]:
test_ds = preprocessing.text_dataset_from_directory(eval_dir, batch_size=6)

Found 276 files belonging to 2 classes.


When running a machine learning experiment, it is a best practice to divide your dataset into three splits: [train](https://developers.google.com/machine-learning/glossary#training_set), [validation](https://developers.google.com/machine-learning/glossary#validation_set), and [test](https://developers.google.com/machine-learning/glossary#test-set). There are no strict rules, but usually it's best to put most of your data in the training set (so that there's plenty to learn from). A 70/15/15 percent split is fairly common, and is implemented below. Note that the split is made within the TUAB training folder, so the separate evaluation folder loaded as `test_ds` is not used here.

In [14]:
# Set the size of each subset of data:
n = len(list(full_train_ds)) # Number of batches in original dataset
n_train = int(0.7*n)   # Use about 70% as training data ...
n_val = int(0.15*n)    # ... 15% as validation data ...
n_test = n-n_train-n_val # ... and the rest as test data.
print(f"We have {n} batches in the full dataset.")
print(f"We'll use {n_train} batches in the training set, {n_val} in the validation set, and {n_test} in the test set.")

We have 170 batches in the full dataset.
We'll use 118 batches in the training set, 25 in the validation set, and 27 in the test set.
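The split is counted in batches rather than reports, and `int()` truncates toward zero. Because 0.7 has no exact binary floating-point representation, `0.7 * 170` comes out just below 119, which is why the training split above gets 118 batches rather than 119. This small sketch reproduces the printed numbers from the report count and batch size; using `round()` instead of `int()` would give 119.

```python
import math

n_reports, batch_size = 2717, 16
n = math.ceil(n_reports / batch_size)   # 170 batches; the last one holds only 13 reports
n_train = int(0.7 * n)                  # 0.7 * 170 is slightly below 119, so this is 118
n_val = int(0.15 * n)                   # 0.15 * 170 is about 25.5, truncated to 25
n_test = n - n_train - n_val            # the remainder: 27
print(n, n_train, n_val, n_test)        # 170 118 25 27
print(round(0.7 * n))                   # 119 when rounding instead of truncating
```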


Now we're ready to actually make the split.

In [15]:
# Split the data into training, validation, and test sets:
raw_train_ds = full_train_ds.take(n_train)
raw_val_ds = full_train_ds.skip(n_train).take(n_val)
raw_test_ds = full_train_ds.skip(n_train+n_val)
#raw_test_ds = test_ds

assert len(list(raw_test_ds)) == n_test  # Check that the test dataset has the size we expect.
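One caveat with this `take`/`skip` split: `text_dataset_from_directory` shuffles by default (`shuffle=True`), and depending on the TensorFlow version the shuffle may be repeated every time the dataset is iterated. If it is, `raw_train_ds`, `raw_val_ds` and `raw_test_ds` are not guaranteed to hold fixed, disjoint sets of reports, which could leak test reports into training and inflate the scores reported later. Two ways around this are the built-in `validation_split`, `subset` and `seed` arguments (which give a two-way split), or splitting the file paths once, up front. The helper `split_paths` below is a hypothetical sketch of the second option; the resulting path lists could then be read and wrapped with `tf.data.Dataset.from_tensor_slices`.

```python
import random


def split_paths(paths, train_frac=0.7, val_frac=0.15, seed=42):
    """Split file paths once into disjoint train/val/test lists.

    Sorting first and shuffling with a seeded random.Random makes the split
    reproducible and independent of directory listing order.
    """
    paths = sorted(paths)
    random.Random(seed).shuffle(paths)
    n = len(paths)
    n_train = int(train_frac * n)
    n_val = int(val_frac * n)
    return paths[:n_train], paths[n_train:n_train + n_val], paths[n_train + n_val:]


# Usage with made-up file names standing in for the report paths.
files = [f"report_{i:04d}.txt" for i in range(100)]
train, val, test = split_paths(files)
print(len(train), len(val), len(test))  # 70 15 15
```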

Let's print out a few examples, to get more of a feel for the data.

In [16]:
for text_batch, label_batch in raw_train_ds.take(1):   # Take a single batch from the dataset.
  for i in range(10):                                  # Iterate through the first 10 examples in that batch.
    print("Report: ", text_batch.numpy()[i])
    print("Label:", label_batch.numpy()[i])

Report:  b'CLINICAL HISTORY: 29 year old female with cerebral palsy, multiple seizures, recent PEG tube placement due to aspiration pneumonia. This is an outpatient EEG.\nMEDICATIONS: Dilantin, Topamax, Phenobarbital\nINTRODUCTION: Digital video EEG was performed in lab using standard 10-20 system of electrode placement with 1 channel of EKG.\nDESCRIPTION OF THE RECORD: The overall background is poorly organized. There is rhythmic, 5.3 Hz activity noted in the posterior regions on the right. This activity is somewhat faster on the left at 5.8 Hz. The overall background is high amplitude, poorly organized. Paroxysmal bursts of generalized 6 Hz activity are seen intermittently in the background. As the patient becomes drowsy, these bursts are a bit more prominent. There are rare left frontocentral sharp waves.\nHR: 120 bpm\nIMPRESSION: Abnormal EEG due to:\n1. Background slowing.\n2. Slowing of the alpha rhythm.\n3. An asymmetry in alpha rhythm with a slower rhythm on the right.\n4. Paro

The labels are `0` or `1`. To see which of these correspond to which string label, you can check the `class_names` property on the dataset, as below.


In [17]:
for i, label in enumerate(full_train_ds.class_names):
  print("Label", i, "corresponds to", label)

Label 0 corresponds to abnormal
Label 1 corresponds to normal


### Prepare the dataset for training

Next, you will standardize, tokenize, and vectorize the data using the `preprocessing.TextVectorization` layer.
* Standardization refers to preprocessing the text, typically to remove punctuation or HTML elements to simplify the dataset.

* Tokenization refers to splitting strings into tokens (for example, splitting a sentence into individual words by splitting on whitespace).

* Vectorization refers to converting tokens into numbers so they can be fed into a neural network.

All of these tasks can be accomplished with this layer. You can learn more about each of these in the [API doc](https://www.tensorflow.org/api_docs/python/tf/keras/layers/experimental/preprocessing/TextVectorization).

* The default standardization converts text to lowercase and removes punctuation.

* The default tokenizer splits on whitespace.

* The default vectorization mode is `int`. This outputs integer indices (one per token). This mode can be used to build models that take word order into account. You can also use other modes, like `binary`, to build bag-of-words models.


Here we will use the `binary` mode to build a bag-of-words model (essentially one-hot encoding of whether each word in the vocabulary appears in the report). Then we will use the `int` mode (integer encoding of each word in the report, with order preserved) with a 1D ConvNet.
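To make the two modes concrete, here is a plain-Python imitation on a hypothetical six-entry vocabulary. It is only a sketch: `standardize`, `int_encode`, `binary_encode` and `VOCAB` are made up for illustration. The `int` layout (index 0 for padding, index 1 for the out-of-vocabulary token `[UNK]`) follows what `TextVectorization` reports through `get_vocabulary()`; the `binary` layout here is a simplification.

```python
import re
import string


def standardize(text):
    # Lowercase and strip ASCII punctuation (an approximation of the default).
    return re.sub(f"[{re.escape(string.punctuation)}]", "", text.lower())


# Hypothetical toy vocabulary: 0 = padding, 1 = out-of-vocabulary token.
VOCAB = ["", "[UNK]", "abnormal", "eeg", "normal", "slowing"]
INDEX = {token: i for i, token in enumerate(VOCAB)}


def int_encode(text, seq_len):
    """'int' mode: one index per token, in order, padded with 0 or truncated."""
    ids = [INDEX.get(token, 1) for token in standardize(text).split()]
    return (ids + [0] * seq_len)[:seq_len]


def binary_encode(text):
    """'binary' mode: 1 if an entry of VOCAB[1:] occurs at least once (order is lost)."""
    vector = [0] * (len(VOCAB) - 1)  # no slot for padding
    for token in standardize(text).split():
        vector[INDEX.get(token, 1) - 1] = 1
    return vector


report = "Abnormal EEG: slowing, slowing!"
print(int_encode(report, seq_len=8))  # [2, 3, 5, 5, 0, 0, 0, 0]
print(binary_encode(report))          # [0, 1, 1, 0, 1]
```

The `int` encoding keeps both the order and the repetition of "slowing"; the `binary` encoding only records that it occurred.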

In [18]:
VOCAB_SIZE = 10000

binary_vectorize_layer = TextVectorization(
    max_tokens=VOCAB_SIZE,
    output_mode='binary')

For `int` mode, in addition to the maximum vocabulary size, you need to set an explicit maximum sequence length (`output_sequence_length`), which will cause the layer to pad or truncate sequences to exactly that many values.

In [19]:
MAX_SEQUENCE_LENGTH = 250

int_vectorize_layer = TextVectorization(
    max_tokens=VOCAB_SIZE,
    output_mode='int',
    output_sequence_length=MAX_SEQUENCE_LENGTH)

Next, you will call `adapt` to make each `TextVectorization` layer build its vocabulary from the dataset.

Note: it's important to only use your training data when calling adapt (using the test set would leak information).
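A rough plain-Python analogue of what `adapt` does in `int` mode is to count tokens in the training texts and keep the most frequent ones. The `TextVectorization` documentation notes that `max_tokens` includes the reserved entries (the out-of-vocabulary token, plus padding in `int` mode). The names `build_vocab` and `standardize` are hypothetical, and breaking frequency ties alphabetically is an assumption of this sketch. Because only training texts are counted, a word that appears only in the test set never enters the vocabulary and is mapped to `[UNK]`.

```python
import re
import string
from collections import Counter


def standardize(text):
    # Lowercase and strip ASCII punctuation (an approximation of the default).
    return re.sub(f"[{re.escape(string.punctuation)}]", "", text.lower())


def build_vocab(train_texts, max_tokens):
    """Keep the most frequent training tokens after the two reserved entries."""
    counts = Counter(token for text in train_texts for token in standardize(text).split())
    # Most frequent first; ties broken alphabetically (an assumption here).
    ranked = sorted(counts, key=lambda token: (-counts[token], token))
    return ["", "[UNK]"] + ranked[:max_tokens - 2]


train_texts = ["Normal EEG.", "Abnormal EEG: slowing.", "EEG shows slowing."]
vocab = build_vocab(train_texts, max_tokens=5)
print(vocab)  # ['', '[UNK]', 'eeg', 'slowing', 'abnormal']
```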

In [20]:
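# To avoid some errors caused by non-standard characters, we create a function
# that does some additional 'cleaning' of the text. It transcodes each report
# from US-ASCII to UTF-8; with the default errors='replace', bytes that are not
# valid ASCII are replaced by the Unicode replacement character (U+FFFD).
def clean_text(text, label):
  # The label argument is ignored; it is only there so that the function can be
  # passed directly to Dataset.map on (text, label) pairs.
  return tf.strings.unicode_transcode(text, "US ASCII", "UTF-8")

# Now apply our clean_text function to the training dataset (dropping the labels).
train_text = raw_train_ds.map(clean_text)

# Finally, let the vectorize layers adjust themselves to fit the vocabulary of the training set.
binary_vectorize_layer.adapt(train_text)
int_vectorize_layer.adapt(train_text)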

See the result of using these layers to preprocess data:

In [21]:
def binary_vectorize_text(text, label):
  text = tf.expand_dims(text, -1)
  return binary_vectorize_layer(text), label

In [22]:
def int_vectorize_text(text, label):
  text = tf.expand_dims(text, -1)
  return int_vectorize_layer(text), label

In [23]:
# Retrieve a batch (of 16 reports and labels) from the training dataset
text_batch, label_batch = next(iter(raw_train_ds))
first_report, first_label = text_batch[0], label_batch[0]
print("Report", first_report)
print("Label", first_label)

Report tf.Tensor(b'REASON FOR STUDY:  Episodes of seizures.\nCLINICAL HISTORY:  A 30-year-old man with history of seizures for 3 months, approximately once a month.  Described as single tingling, then right upper extremity numbness and stiffness with loss of consciousness and confusion afterwards.  Past medical history of dementia.\nMEDICATIONS:  Tegretol and Seroquel.\nINTRODUCTION:  A routine EEG is performed using standard 10-20 electrode placement with an anterior temporal and single lead EKG electrode.  The patient was recorded in wakefulness and sleep.  Activating procedures included hyperventilation and photic stimulation.\nDESCRIPTION OF THE RECORD:  The record opens to a well-defined posterior dominant rhythm that reaches 9-10 Hz, which is reactive to eye opening.  There is normal frontocentral beta.  The patient is recorded during wakefulness, stage I and stage II sleep.  Please note, the patient easily falls asleep throughout the record.\nActivating procedures produced no ab

In [24]:
print("'binary' vectorized report:", 
      binary_vectorize_text(first_report, first_label)[0])

'binary' vectorized report: tf.Tensor([[0. 1. 1. ... 0. 0. 0.]], shape=(1, 6206), dtype=float32)


In [25]:
print("'int' vectorized report:",
      int_vectorize_text(first_report, first_label)[0])

'int' vectorized report: tf.Tensor(
[[ 150   54  103  208    3   23   12   18    6 1777  212    7   18    3
    23   54  183  520  293 1076    6  558  215   64  131 1066  238   24
   979 1162  571    4 1910    7  179    3  216    4  243 1148  182  271
    18    3  549   34  431    4  785   49    6  190    9    5   13   51
    44   45   29   52    7   25   92   53    4  131  148   37   29    2
    17   11   98    8   35    4   20  108  107  240   30    4   21   19
    41    3    2   14    2   14  149   10    6  265   62   91   32   88
   195  297   40   70    5  146   10  105  125   16    5   36  127   61
     2   17    5   98   81   35   65  167    4   65  134   20  365  284
     2   17 1351 1275  161  189    2   14  108  107  248   22   46   58
    46   58   75   23   75  117  124  294   42   36   85    4  161  532
    12   48   15    5    6   36   85    4  161    9   22   23   68   50
    58   33   57    0    0    0    0    0    0    0    0    0    0    0
     0    0    0    0    0  

As you can see above, `binary` mode returns an array denoting which tokens exist at least once in the input, while `int` mode replaces each token by an integer, thus preserving their order (the trailing zeros are padding). You can look up the token (string) that each integer corresponds to by calling `.get_vocabulary()` on the layer.
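Going the other way, the list returned by `get_vocabulary()` is enough to turn an `int`-encoded report back into (standardized, possibly truncated) text. The `decode` helper and the toy vocabulary below are hypothetical; with the real layer you would pass `int_vectorize_layer.get_vocabulary()` and the array from `int_vectorize_text(first_report, first_label)[0].numpy()[0]`.

```python
def decode(ids, vocab):
    """Map integer ids back to tokens, dropping the 0 padding entries."""
    return " ".join(vocab[i] for i in ids if i != 0)


# Hypothetical toy vocabulary in the get_vocabulary() layout ('' = padding, '[UNK]' = unknown).
vocab = ["", "[UNK]", "abnormal", "eeg", "normal", "slowing"]
print(decode([2, 3, 1, 5, 0, 0], vocab))  # abnormal eeg [UNK] slowing
```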

In [26]:
print("42 ---> ", int_vectorize_layer.get_vocabulary()[42])
print("44 ---> ", int_vectorize_layer.get_vocabulary()[44])
print("Vocabulary size: {}".format(len(int_vectorize_layer.get_vocabulary())))

42 --->  impression
44 --->  standard
Vocabulary size: 6215


You are nearly ready to train your model. As a final preprocessing step, you will apply the `TextVectorization` layers you created earlier to the train, validation, and test datasets.

In [27]:
binary_train_ds = raw_train_ds.map(binary_vectorize_text)
binary_val_ds = raw_val_ds.map(binary_vectorize_text)
binary_test_ds = raw_test_ds.map(binary_vectorize_text)

int_train_ds = raw_train_ds.map(int_vectorize_text)
int_val_ds = raw_val_ds.map(int_vectorize_text)
int_test_ds = raw_test_ds.map(int_vectorize_text)

# Rule-Based (non-ML) Approach

Looking through the reports, it seems that they usually state quite clearly when the EEG is abnormal. Rather than attempting any machine learning, why don't we just look for that keyword (or related words and phrases) in the text? A very basic version of this approach is implemented below: it searches the start of each report's IMPRESSION section for phrases such as 'abnormal' or 'not normal'.
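Stripped of the dataset plumbing, the rule can be written as a small pure function, which makes it easy to probe on hand-written sentences. This is a sketch of the same idea, with `classify_report` and `ABNORMAL_PATTERN` as illustrative names; it defaults to 'normal' when no abnormal phrase is found. The last example shows a weakness of plain substring matching: "no abnormalities" contains "abnormal", so the rule calls a normal report abnormal.

```python
import re

ABNORMAL_PATTERN = "abnormal|absence of normal|outside of the range of normal|not normal"


def classify_report(report, window=75):
    """Return 0 (abnormal) or 1 (normal) using a simple keyword rule.

    Only the word 'impression' plus the next `window` characters are searched;
    if there is no impression section, the whole report is searched instead.
    """
    text = report.lower()
    match = re.search("impression", text)
    if match is not None:
        start, end = match.span()
        text = text[start:end + window]
    return 0 if re.search(ABNORMAL_PATTERN, text) else 1


print(classify_report("HISTORY: abnormal gait. IMPRESSION: Normal EEG."))      # 1
print(classify_report("IMPRESSION: This is an abnormal EEG due to slowing."))  # 0
print(classify_report("IMPRESSION: This shows no abnormalities."))             # 0, although normal
```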

In [30]:
# First initialise some counters.
n = 0               # total number of reports examined
n_correct = 0       # number of correct predictions
wrong_normal = 0    # reports wrongly predicted as normal (label 1)
wrong_abnormal = 0  # reports wrongly predicted as abnormal (label 0)
show_errors = True  # set to False to stop printing the misclassified reports

# Phrases indicating an abnormal EEG. Only these are searched for, because the
# plain word 'normal' also matches inside 'abnormal'.
abnormal_pattern = 'abnormal|absence of normal|outside of the range of normal|not normal'

# Iterate over all batches, taking the text and labels batch-by-batch.
# N.B. take(-1) has the effect of pulling out all the batches, instead of a specific number, as explained in the docs here: https://www.tensorflow.org/api_docs/python/tf/data/Dataset#take
for text_batch, label_batch in full_train_ds.take(-1):

  # Iterate over the report examples in the batch:
  for ind, text in enumerate(text_batch):
    true_label = int(label_batch[ind].numpy())

    # Get rid of any pesky non-standard characters using the function we created
    # previously (its second argument is unused), then convert the Tensor to a
    # Python string so that we can use some standard Python text analysis on it.
    cleaned_and_decoded_text = clean_text(text, 0).numpy().decode("UTF-8")
    lowered_text = cleaned_and_decoded_text.lower()

    # Only search the word "impression" plus the 75 characters that follow it.
    # If a report has no impression section, search the whole report instead.
    find_impression = re.search("impression", lowered_text)
    if find_impression is not None:
      first_char, last_char = find_impression.span()
      searched_text = lowered_text[first_char:last_char + 75]
    else:
      searched_text = lowered_text

    # Predict 'abnormal' (0) if an abnormal phrase appears, otherwise 'normal' (1).
    if re.search(abnormal_pattern, searched_text):
      predicted_label = 0
    else:
      predicted_label = 1

    # If we predicted correctly, add one to our count of correct predictions.
    if predicted_label == true_label:
      n_correct += 1
    else:
      if predicted_label == 0:
        wrong_abnormal += 1
      else:
        wrong_normal += 1
      # Print the cases where we were wrong, so that they can be inspected.
      if show_errors:
        print(f"This example was classified with label {predicted_label} but its actual label is {true_label}.")
        print("---")
        print(cleaned_and_decoded_text)
        print("---------------------")

    # Add one to our count of the total number of examples examined.
    n += 1

print(f"Accuracy = {round(100*n_correct/n,3)} percent ({n_correct} correct predictions out of {n}). {n - n_correct} misclassified. {wrong_normal} were mislabelled normal while {wrong_abnormal} were mislabelled abnormal")

Accuracy = 100.0 percent (2717 correct predictions out of 2717). 0 misclassified. 0 were mislabelled normal while 0 were mislabelled abnormal


In [29]:
'''
BACKUP VERSION

TODO:
Check labeling.
'''

'''
# First initialise some counters
n = 0
n_correct = 0
n_failed_decode = 0


# Iterate over all batches, taking the text and labels batch-by-batch.
# N.B. take(-1) has the effect of pulling out all the batches, instead of a specific number, as explained in the docs here: https://www.tensorflow.org/api_docs/python/tf/data/Dataset#take
for text_batch, label_batch in full_ds.take(-1):

  # Iterate over the report examples in the batch:
  for ind,text in enumerate(text_batch):

    # Get rid of any pesky non-standard characters using the function we created previously.
    cleaned_text = clean_text(text,0)
    # Then convert it from a tensorflow Tensor to a python string so that we can 
    # use some standard python text analysis on it.
    cleaned_and_decoded_text = cleaned_text.numpy().decode("UTF-8")

    find_impression = re.search("impression", cleaned_and_decoded_text.lower(), flags=re.IGNORECASE)
    first_char,last_char = find_impression.span()

    #find_clinical = re.search("clinical correlation:", cleaned_and_decoded_text.lower(), flags=re.IGNORECASE)
    #if find_clinical != None:
    #  first_char_clin, last_char_clin = find_clinical.span()
    #  last_char = first_char_clin
    #else:
    last_char = last_char +50

    searched_text = cleaned_and_decoded_text.lower()[first_char:last_char]


    is_abnormal = re.search('abnormal|absence of normal|outside of the range of normal|not normal', searched_text.lower(), flags=re.IGNORECASE)
    #also_abnormal = re.search("absence of normal", searched_text.lower(), flags=re.IGNORECASE)
    
    # Check if the word 'abnormal' is in the report, and label it accordingly.
    if is_abnormal:
      predicted_label = 0
    else:
      predicted_label = 1
      
    # If we predicted correctly, add one to our count of correct predictions.
    if predicted_label==label_batch[ind]:
      n_correct = n_correct+1
    else:
      # Uncomment the lines below if you want to inspect the cases where we were wrong.
      # print("--- Wrong example ---")
       print(f"This example was classified with label {predicted_label} but its actual label is {label_batch[ind].numpy()}.")
       print("---")
       print(cleaned_and_decoded_text)
       print("---------------------")
      # pass

    # Add one to our count of the total number of examples examined.
    n = n+1


print(f"Accuracy = {round(100*n_correct/n,3)} percent ({n_correct} correct predictions out of {n}). {n - n_correct} misclassified.")
'''


### Configure the dataset for performance

These are two important methods you should use when loading data to make sure that I/O does not become blocking.

`.cache()` keeps data in memory after it's loaded off disk. This will ensure the dataset does not become a bottleneck while training your model. If your dataset is too large to fit into memory, you can also use this method to create a performant on-disk cache, which is more efficient to read than many small files.

`.prefetch()` overlaps data preprocessing and model execution while training. 

You can learn more about both methods, as well as how to cache data to disk in the [data performance guide](https://www.tensorflow.org/guide/data_performance).

In [31]:
AUTOTUNE = tf.data.AUTOTUNE

def configure_dataset(dataset):
  return dataset.cache().prefetch(buffer_size=AUTOTUNE)

In [32]:
binary_train_ds = configure_dataset(binary_train_ds)
binary_val_ds = configure_dataset(binary_val_ds)
binary_test_ds = configure_dataset(binary_test_ds)

int_train_ds = configure_dataset(int_train_ds)
int_val_ds = configure_dataset(int_val_ds)
int_test_ds = configure_dataset(int_test_ds)

### Train the model
It's time to create our neural network. For the `binary` vectorized data, train a simple bag-of-words linear model:

In [33]:
binary_model = tf.keras.Sequential([layers.Dense(2)])
binary_model.compile(
    loss=losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer='adam',
    metrics=['accuracy'])
history = binary_model.fit(
    binary_train_ds, validation_data=binary_val_ds, epochs=10)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


Next, you will use the `int`-vectorized data to build a 1D ConvNet.

In [34]:
def create_model(vocab_size, num_labels):
  model = tf.keras.Sequential([
      layers.Embedding(vocab_size, 64, mask_zero=True),
      layers.Conv1D(64, 5, padding="valid", activation="elu", strides=2),
      layers.GlobalMaxPooling1D(),
      layers.Dense(num_labels)
  ])
  return model

In [35]:
# vocab_size is VOCAB_SIZE + 1 since 0 is used additionally for padding.
int_model = create_model(vocab_size=VOCAB_SIZE + 1, num_labels=2)
int_model.compile(
    loss=losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer='adam',
    metrics=['accuracy'])
history = int_model.fit(int_train_ds, validation_data=int_val_ds, epochs=5)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


Compare the two models:

In [36]:
print("Linear model on binary vectorized data:")
binary_model.summary()  # summary() prints the table itself (it returns None)

Linear model on binary vectorized data:
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense (Dense)                (None, 2)                 12414     
Total params: 12,414
Trainable params: 12,414
Non-trainable params: 0
_________________________________________________________________


In [37]:
print("ConvNet model on int vectorized data:")
int_model.summary()  # summary() prints the table itself (it returns None)

ConvNet model on int vectorized data:
Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
embedding (Embedding)        (None, None, 64)          640064    
_________________________________________________________________
conv1d (Conv1D)              (None, None, 64)          20544     
_________________________________________________________________
global_max_pooling1d (Global (None, 64)                0         
_________________________________________________________________
dense_1 (Dense)              (None, 2)                 130       
Total params: 660,738
Trainable params: 660,738
Non-trainable params: 0
_________________________________________________________________
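The parameter counts in the two summaries can be checked by hand. The linear model's input width, 6206, is the length of the `binary` vector printed earlier. The embedding has `VOCAB_SIZE + 1 = 10001` rows even though the adapted vocabulary holds only 6215 entries, so most of those rows are never looked up. The helper names below are illustrative; the formulas are the standard ones for dense, embedding and 1D convolution layers with biases.

```python
def dense_params(n_inputs, units):
    return n_inputs * units + units          # weight matrix + one bias per unit


def embedding_params(input_dim, output_dim):
    return input_dim * output_dim            # one vector per token id, no bias


def conv1d_params(kernel_size, in_channels, filters):
    return kernel_size * in_channels * filters + filters


binary_total = dense_params(6206, 2)
conv_total = (embedding_params(10001, 64)
              + conv1d_params(5, 64, 64)
              + dense_params(64, 2))
print(binary_total)  # 12414
print(conv_total)    # 660738
```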


Evaluate both models on the test data:

In [38]:
binary_loss, binary_accuracy = binary_model.evaluate(binary_test_ds)
int_loss, int_accuracy = int_model.evaluate(int_test_ds)

print("Binary model accuracy: {:2.2%}".format(binary_accuracy))
print("Int model accuracy: {:2.2%}".format(int_accuracy))

Binary model accuracy: 99.07%
Int model accuracy: 99.53%


Note: This example dataset represents a rather simple classification problem. More complex datasets and problems bring out subtle but significant differences in preprocessing strategies and model architectures. Be sure to try out different hyperparameters and epochs to compare various approaches.

### Export the model

In the code above, you applied the `TextVectorization` layer to the dataset before feeding text to the model. If you want to make your model capable of processing raw strings (for example, to simplify deploying it), you can include the `TextVectorization` layer inside your model. To do so, you can create a new model using the weights you just trained.

In [39]:
export_model = tf.keras.Sequential(
    [binary_vectorize_layer, binary_model,
     layers.Activation('elu')])

export_model.compile(
    loss=losses.SparseCategoricalCrossentropy(from_logits=False),
    optimizer='adam',
    metrics=['accuracy'])

# Test it with `raw_val_ds`, which yields raw strings
loss, accuracy = export_model.evaluate(raw_val_ds)
print("Accuracy: {:2.2%}".format(accuracy))

Accuracy: 99.50%


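Now your model can take raw strings as input and predict a score for each label using `model.predict`. Because the final activation here is `elu` rather than `softmax`, these scores are not probabilities (they can be negative), and the loss reported by `evaluate` is not a proper cross-entropy; the accuracy, which depends only on which score is larger, is unaffected. Define a function to find the label with the maximum score: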

In [40]:
labels = ['abnormal', 'normal']  # same order as full_train_ds.class_names
def get_string_labels(predicted_scores_batch):
  # Index of the highest score in each row: 0 = abnormal, 1 = normal.
  predicted_int_labels = tf.argmax(predicted_scores_batch, axis=1)
  return [labels[intlab] for intlab in predicted_int_labels.numpy()]

### Run inference on new data

Now we can create a few custom inputs to explore the model's behaviour.

In [41]:
inputs = [
    "This EEG is totally normal",  # normal
    "This recording is markedly abnormal",  # abnormal
    "This shows no abnormalities",  # normal
    "Some ever so slight abnormalities, but then again, who can say what normal really means",  # abnormal
    "They seem fine.",  # normal?
    "They are fine.", # normal
    "This person is fine.",  # normal
    "This person is very unwell.",  # abnormal
    "IMPRESSION: abnormal", # abnormal
    "IMPRESSION: markedly abnormal", # abnormal
    "IMPRESSION: This recording is markedly abnormal", # abnormal
    "They are not normal" #abnormal
]
predicted_scores = export_model.predict(inputs)
print(predicted_scores)
predicted_labels = get_string_labels(predicted_scores)
for text, label, scores in zip(inputs, predicted_labels, predicted_scores):
  print("-----------------------------------")
  print("Question: ", text)
  print("Predicted label: ", label)
  print("Scores: abnormal vs normal")
  print(f"        {scores[0]:.2f} vs {scores[1]:.2f}")

[[-0.19692296  0.2900898 ]
 [ 0.5372987  -0.4347239 ]
 [-0.06418455  0.09787869]
 [-0.25355804  0.35264894]
 [ 0.05671321  0.0017204 ]
 [ 0.02679556  0.0737832 ]
 [-0.02480018  0.07440683]
 [ 0.13213836 -0.10222548]
 [ 0.3157288  -0.308609  ]
 [ 0.41592282 -0.38063073]
 [ 0.5337525  -0.43717706]
 [-0.23353177  0.342391  ]]
-----------------------------------
Question:  This EEG is totally normal
Predicted label:  normal
Scores: abnormal vs normal
        -0.20 vs 0.29
-----------------------------------
Question:  This recording is markedly abnormal
Predicted label:  abnormal
Scores: abnormal vs normal
        0.54 vs -0.43
-----------------------------------
Question:  This shows no abnormalities
Predicted label:  normal
Scores: abnormal vs normal
        -0.06 vs 0.10
-----------------------------------
Question:  Some ever so slight abnormalities, but then again, 

Including the text preprocessing logic inside your model lets you export a model for production that simplifies deployment and reduces the potential for [train/test skew](https://developers.google.com/machine-learning/guides/rules-of-ml#training-serving_skew).

There is a performance difference to keep in mind when choosing where to apply your `TextVectorization` layer. Using it outside of your model enables you to do asynchronous CPU processing and buffering of your data when training on GPU. So, if you're training your model on the GPU, you probably want to go with this option to get the best performance while developing your model, then switch to including the TextVectorization layer inside your model when you're ready to prepare for deployment.

Visit this [tutorial](https://www.tensorflow.org/tutorials/keras/save_and_load) to learn more about saving models.