In [1]:
import logging
import argparse 
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import tensorflow as tf
from tensorflow import keras
import tensorflow_datasets as tfds

logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.INFO)
log = logging.getLogger()
%config Completer.use_jedi = False # make autocompletion works in jupyter

# Settings shared by the cells below, held in an argparse.Namespace for attribute access.
args = argparse.Namespace(
    data_folder='./data-ignored/imdb/',  # where the dataset is downloaded / cached
    val_fraction=0.25,                   # share of all examples used for validation
    vocab_size=2500,                     # token cap for the text encoder
    epochs=50,
)

# Make sure the data directory (and any missing parents) exists; harmless if it already does.
Path(args.data_folder).mkdir(exist_ok=True, parents=True)

# Load the IMDB reviews into args.data_folder; as_supervised makes every
# element a (text, label) pair.
ds, info = tfds.load(
    'imdb_reviews',
    data_dir=args.data_folder,
    as_supervised=True,
    with_info=True,
)
train_ds_len = tf.data.experimental.cardinality(ds['train']).numpy()
test_ds_len = tf.data.experimental.cardinality(ds['test']).numpy()
print(train_ds_len)
# Peek at a single raw example.
for d in ds['train'].take(1):
    print(d)

# The official 'train' split is used as-is for training. The official 'test'
# split is cut in two: its first `val_size` examples become the validation set
# and the remainder the held-out test set. `val_size` is args.val_fraction of
# the combined train + test size.
val_size = int(args.val_fraction * (train_ds_len + test_ds_len))
train_dataset = ds['train']
val_dataset = ds['test'].take(val_size)
test_dataset = ds['test'].skip(val_size)

2021-11-11 16:11:51,768 : INFO : No config specified, defaulting to first: imdb_reviews/plain_text
2021-11-11 16:11:53,408 : INFO : Load pre-computed DatasetInfo (eg: splits, num examples,...) from GCS: imdb_reviews/plain_text/1.0.0
2021-11-11 16:11:54,844 : INFO : Load dataset info from /var/folders/0q/20bc6kc571l0p1wnm1zwcpw40000gn/T/tmp66ce8ex3tfds
2021-11-11 16:11:54,862 : INFO : Field info.config_name from disk and from code do not match. Keeping the one from code.
2021-11-11 16:11:54,866 : INFO : Field info.config_description from disk and from code do not match. Keeping the one from code.
2021-11-11 16:11:54,869 : INFO : Field info.citation from disk and from code do not match. Keeping the one from code.
2021-11-11 16:11:54,874 : INFO : Field info.splits from disk and from code do not match. Keeping the one from code.
2021-11-11 16:11:54,878 : INFO : Field info.module_name from disk and from code do not match. Keeping the one from code.
2021-11-11 16:11:54,890 : INFO : Generatin

[1mDownloading and preparing dataset 80.23 MiB (download: 80.23 MiB, generated: Unknown size, total: 80.23 MiB) to ./data-ignored/imdb/imdb_reviews/plain_text/1.0.0...[0m


Dl Completed...: 0 url [00:00, ? url/s]

Dl Size...: 0 MiB [00:00, ? MiB/s]

2021-11-11 16:11:55,887 : INFO : Downloading http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz into data-ignored/imdb/downloads/ai.stanfor.edu_amaas_sentime_aclImdb_v1PaujRp-TxjBWz59jHXsMDm5WiexbxzaFQkEnXc3Tvo8.tar.gz.tmp.0ea1b3b6676a4d9d903d5167f5496ade...


KeyboardInterrupt: 

### Baseline

In [37]:
import functools

@functools.lru_cache()
def get_encoder():
    """Build the text-vectorization layer and fit its vocabulary on the training text.

    The result is memoized by ``functools.lru_cache``. Because the function
    takes no arguments, the layer is built and adapted once and every later
    call returns that same object. Changes made afterwards to
    ``args.vocab_size`` or ``train_dataset`` are therefore not picked up until
    ``get_encoder.cache_clear()`` is called.

    Returns:
        A ``TextVectorization`` layer limited to ``args.vocab_size`` tokens and
        adapted to the text component of ``train_dataset``.
    """
    # Imported locally so this function does not rely on a later notebook cell
    # having been executed before it is called.
    from tensorflow.keras.layers.experimental.preprocessing import TextVectorization

    encoder = TextVectorization(max_tokens=args.vocab_size)
    # adapt() only needs the review text, so the labels are dropped.
    encoder.adapt(train_dataset.map(lambda text, label: text))
    return encoder

In [39]:
from tensorflow.keras.layers.experimental.preprocessing import TextVectorization

def _ensure_batched(dataset, batch_size):
    """Return `dataset` grouped into batches unless it already looks batched.

    `dataset` is expected to yield (text, label) pairs. When the text component
    is a scalar (rank 0) the elements are single examples and get batched;
    otherwise the dataset is returned untouched so it is never batched twice.
    """
    text_spec = dataset.element_spec[0]
    if text_spec.shape.rank == 0:
        return dataset.batch(batch_size)
    return dataset


def baseline(epochs=None, batch_size=64):
    """Build, compile and train the baseline bidirectional-LSTM binary classifier.

    The model consumes raw review strings: the shared text encoder from
    `get_encoder()` is its first layer, followed by an embedding, a
    bidirectional LSTM and two dense layers ending in a sigmoid unit.
    Training reads the module-level `train_dataset` and validates on
    `val_dataset`.

    Args:
        epochs: Maximum number of training epochs. Defaults to `args.epochs`.
        batch_size: Number of examples per batch. Only applied when the
            datasets still yield single, unbatched examples.

    Returns:
        The trained `keras.models.Sequential` model.
    """
    encoder = get_encoder()
    if epochs is None:
        epochs = args.epochs

    model = keras.models.Sequential()
    # The encoder has to be part of the model: the datasets yield raw strings,
    # and the Embedding layer can only look up integer token ids.
    model.add(encoder)
    model.add(keras.layers.Embedding(
        input_dim=len(encoder.get_vocabulary()),
        output_dim=64,
        # Use masking to handle the variable sequence lengths
        mask_zero=True))
    model.add(keras.layers.Bidirectional(keras.layers.LSTM(64)))
    model.add(keras.layers.Dense(64, activation='relu'))
    model.add(keras.layers.Dense(1, activation='sigmoid'))

    model.compile(optimizer=keras.optimizers.Nadam(learning_rate=1e-3),
                  loss='binary_crossentropy',
                  metrics=['accuracy'])

    monitor = 'val_loss'
    # Stop once val_loss has not improved for 10 epochs and roll back to the best weights.
    early_stopping = keras.callbacks.EarlyStopping(
        monitor=monitor, patience=10, mode='auto', restore_best_weights=True, verbose=1)
    # Cut the learning rate by 10x whenever val_loss stalls for 3 epochs.
    reduce_lr_on_plateau = keras.callbacks.ReduceLROnPlateau(
        monitor=monitor, factor=0.1, patience=3, min_delta=1e-4, mode='auto', verbose=1)

    # fit() treats every dataset element as one batch, so unbatched datasets
    # must be batched first. The callbacks only take effect when passed here.
    model.fit(
        _ensure_batched(train_dataset, batch_size),
        validation_data=_ensure_batched(val_dataset, batch_size),
        epochs=epochs,
        callbacks=[early_stopping, reduce_lr_on_plateau],
    )
    return model


# Keep a handle on the trained model so later cells can evaluate or reuse it.
baseline_model = baseline()

tf.Tensor(
[  11   14   34  412  384   18   90   28    1    8   33 1322    1   42
  487    1  191   24   85  152   19   11  217  316   28   65  240  214
    8  489   54   65   85  112   96   22    1   11   93  642  743   11
   18    7   34  394    1  170 2464  408    2   88 1216  137   66  144
   51    2    1    1   66  245   65    1   16    1    1    1    1 1426
    1    3   40    1 1579   17    1   14  158   19    4 1216  891    1
    8    4   18   12   14    1    5   99  146 1241   10  237  704   12
   48   24   93   39   11    1  152   39 1322    1   50  398   10   96
 1155  851  141    9], shape=(116,), dtype=int64)
