## 1. Download data

In [None]:
# Based on https://www.tensorflow.org/tutorials/keras/text_classification
import tensorflow as tf
import os
import shutil

# IMDB sentiment archive hosted on the Stanford AI site.
url = "https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz"

# Download into the current directory (cache_dir='.', no sub-folder) and untar it.
dataset = tf.keras.utils.get_file(
    fname="aclImdb_v1",
    origin=url,
    untar=True,
    cache_dir='.',
    cache_subdir='',
)

# The extracted reviews are looked up in an 'aclImdb' folder that sits next to
# the path returned by get_file.
download_parent = os.path.dirname(dataset)
dataset_dir = os.path.join(download_parent, 'aclImdb')

Downloading data from https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz


In [None]:
# Mount the user's Google Drive at /content/drive.
# The google.colab package is presumably only importable inside a Colab runtime.
# NOTE(review): nothing in the visible notebook reads from the mount — confirm
# this cell is still needed.
from google.colab import drive
drive.mount('/content/drive')

## 2. Remove additional folders

In [None]:
# Delete train/unsup so that it is not picked up as a class folder when the
# training directory is loaded with text_dataset_from_directory.
# Path components are joined individually rather than by string concatenation.
remove_dir = os.path.join(dataset_dir, 'train', 'unsup')

# Guard the delete: once the folder is gone, re-running this cell is a no-op
# instead of raising FileNotFoundError.
if os.path.isdir(remove_dir):
    shutil.rmtree(remove_dir)

## 3. Load data from folders

In [None]:
# Resolve the split folders from `dataset_dir` (set in the download cell) so
# that downloading, clean-up and loading all agree on one dataset location
# instead of relying on a separately hard-coded 'aclImdb/...' path.
train_dir = os.path.join(dataset_dir, 'train')
test_dir = os.path.join(dataset_dir, 'test')

batch_size = 32
# The training and validation subsets must use the same seed and split
# fraction, otherwise the two subsets could overlap.
validation_split = 0.2
seed = 111

# 80% of train/ for fitting ...
train_ds = tf.keras.utils.text_dataset_from_directory(
    train_dir,
    batch_size=batch_size,
    validation_split=validation_split,
    subset='training',
    seed=seed)

# ... and the remaining 20% for validation.
val_ds = tf.keras.utils.text_dataset_from_directory(
    train_dir,
    batch_size=batch_size,
    validation_split=validation_split,
    subset='validation',
    seed=seed)

# The test/ folder is used as-is for the final evaluation.
test_ds = tf.keras.utils.text_dataset_from_directory(
    test_dir,
    batch_size=batch_size)

## 4. View data

In [None]:
# Peek at one batch: the labels are printed first, then the raw review strings.
for review_batch, label_batch in train_ds.take(1):
  print(label_batch, review_batch)

## 5. Preprocess text



In [None]:
import re
import string 

def custom_standardization(input_data):
  """Normalize raw review text before tokenization.

  Steps, in order: lowercase, replace HTML line breaks with a space, then
  delete every character in ``string.punctuation``.

  Args:
    input_data: string tensor of raw review text.

  Returns:
    String tensor of the same shape with the normalized text.
  """
  lowercase = tf.strings.lower(input_data)
  # The '<br />' markup has to be turned into whitespace BEFORE punctuation is
  # stripped. Otherwise only '<', '/' and '>' are removed and the leftover
  # "br" is fused onto neighbouring words: "end.<br /><br />next" would become
  # "endbr br next" instead of "end   next".
  stripped_html = tf.strings.regex_replace(lowercase, '<br />', ' ')
  return tf.strings.regex_replace(stripped_html, '[%s]' % re.escape(string.punctuation), '')

# Vocabulary cap and the fixed number of token ids produced per review.
max_features = 10000
sequence_length = 250

# Turns raw strings into fixed-length integer id sequences, using
# custom_standardization for text clean-up.
vectorize_layer = tf.keras.layers.TextVectorization(
    max_tokens=max_features,
    standardize=custom_standardization,
    output_sequence_length=sequence_length,
    output_mode='int')

# Fit the vocabulary on the training reviews only; the labels are dropped.
train_text = train_ds.map(lambda review, _label: review)
vectorize_layer.adapt(train_text)

In [None]:
# Sanity check: print the integer ids the adapted vectorizer assigns to a short sample.
text = tf.constant(['Hello World'])
sample_ids = vectorize_layer(text)
print(sample_ids)

## 6. Create and train

In [None]:
# The model takes raw strings; vectorization and embedding happen inside it.
inputs = tf.keras.layers.Input((1,), dtype=tf.string)
embedding_layer = tf.keras.layers.Embedding(max_features + 1, 100, trainable=True)

x = vectorize_layer(inputs)
x = embedding_layer(x)

# Parallel convolution branches over the same embedded sequence, one per
# kernel size: Conv1D -> Dropout -> GlobalMaxPool1D.
# Built in a loop instead of four copy-pasted stanzas. The layers are created
# in the same order as before, so auto-generated layer names and positional
# lookups into model.layers resolve to the same layers.
pooled_branches = []
for kernel_size in (2, 3, 4, 5):
    branch = tf.keras.layers.Conv1D(100, kernel_size=kernel_size)(x)
    branch = tf.keras.layers.Dropout(0.4)(branch)
    branch = tf.keras.layers.GlobalMaxPool1D()(branch)
    pooled_branches.append(branch)

# Sum the pooled branch outputs element-wise, then narrow through the dense
# head down to a single sigmoid unit.
outputs = tf.keras.layers.Add()(pooled_branches)
outputs = tf.keras.layers.Dense(100, activation='relu')(outputs)
outputs = tf.keras.layers.Dense(75, activation='relu')(outputs)
outputs = tf.keras.layers.Dense(50, activation='relu')(outputs)
outputs = tf.keras.layers.Dense(1, activation='sigmoid')(outputs)

model = tf.keras.Model(inputs, outputs)

model.summary()
# `metrics` is given as a list, which is the documented form for compile();
# newer Keras releases are stricter and do not accept a bare Metric object.
model.compile(
    optimizer='adam',
    loss=tf.keras.losses.BinaryCrossentropy(),
    metrics=[tf.metrics.BinaryAccuracy()])

# Train on the training subset, reporting validation metrics after every epoch.
epochs = 3
history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=epochs)

## 7. Evaluate model

In [None]:
# Score the trained model on the dataset loaded from the test folder.
test_scores = model.evaluate(test_ds)
loss, accuracy = test_scores

for caption, value in (("Loss: ", loss), ("Accuracy: ", accuracy)):
    print(caption, value)

In [None]:
# Run the trained model on one hand-written review. The bare last expression
# makes the notebook display the model's sigmoid output for that sentence.
text = tf.constant(['This movie was really funny'])
model(text)

## 8. Calculate the gradient of the output with respect to embedded_text

In [None]:
# Sub-model that maps the embedded token sequence to the final prediction.
# It is anchored on the embedding layer object created when the model was
# built (and on model.output) rather than on positional indices into
# model.layers, which silently break if the architecture changes.
new_model = tf.keras.Model(inputs=embedding_layer.output, outputs=model.output)

# Sentence to explain. It is defined here, in its own variable, so this cell
# can be re-run safely: the plotting cell expects `text` to be a list of
# words, so `text` is overwritten further down and cannot double as the input.
# The active sentence is the same review scored in the sample prediction of
# section 7; swap in one of the alternatives to probe other inputs.
sentence = tf.constant(['This movie was really funny'])
#sentence = tf.constant(['I love this film. It is well written and acted and has good cinematography'])
#sentence = tf.constant(["Let me start off by saying that this doesn't seem or feel like a movie"])
#sentence = tf.constant(["I didn't enjoy this film. I thought the acting wasn't very good and the story was boring."])
#sentence = tf.constant(["I'm not going to say that this movie is horrible, because I have seen worse, but it's not even halfway decent."])
#sentence = tf.constant(["This movie was just horrible"])
#sentence = tf.constant(["This was a very dull but enjoyable movie"])
#sentence = tf.constant(["This movie was a disgrace"])

# Only one gradient is taken, so a regular (non-persistent) tape is enough and
# its resources are released as soon as tape.gradient() returns.
with tf.GradientTape() as tape:
    token_ids = vectorize_layer(sentence)
    embedded_text = embedding_layer(token_ids)
    # embedded_text is an intermediate tensor, not a variable, so the tape has
    # to be told explicitly to track it.
    tape.watch(embedded_text)
    y = new_model(embedded_text)
print(y)

# d(prediction) / d(embedded_text): one gradient vector per token position.
grads = tape.gradient(y, embedded_text)

# Whitespace-separated words of the single input sentence, decoded to str.
# Stored as `text` because the plotting cell labels the bars with `text`.
# NOTE(review): words and token positions line up only if every
# whitespace-separated word survives standardization as exactly one token.
text = [word.decode('utf-8') for word in tf.strings.split(sentence)[0].numpy()]

# Per-word saliency: the largest absolute gradient across the embedding
# dimension, cut to the number of words (the remaining positions are padding).
output = tf.reduce_max(tf.abs(grads), axis=-1)[0, :len(text)].numpy()
print(output)

## 9. Visualize results

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.ticker import FormatStrFormatter

# Use seaborn's "whitegrid" style with the grid color overridden to ".98"
# (presumably a very light gray so the grid lines stay faint).
sns.set_style("whitegrid", {"grid.color": ".98"})

def gradientbars(bars, cmap):
    """Fill each bar with a vertical colour ramp instead of a flat colour.

    Every bar's own face is made transparent and an image holding a 256-step
    ramp is drawn behind it, covering exactly the bar's rectangle. The ramp
    is colour-normalised against the axes' y-limits, and the axis limits that
    were in effect on entry are restored at the end.

    Args:
        bars: non-empty sequence of bar patches belonging to the same axes.
        cmap: colormap passed straight through to ``imshow``.
    """
    axes = bars[0].axes
    x_lo, x_hi = axes.get_xlim()
    y_lo, y_hi = axes.get_ylim()

    for patch in bars:
        # Keep the (now transparent) patch above the image drawn at zorder 0.
        patch.set_zorder(1)
        patch.set_facecolor("none")

        left, bottom = patch.get_xy()
        width, height = patch.get_width(), patch.get_height()
        right, top = left + width, bottom + height

        ramp = np.linspace(bottom, top, 256).reshape(256, 1)
        axes.imshow(
            ramp,
            extent=[left, right, bottom, top],
            aspect="auto",
            zorder=0,
            origin='lower',
            vmin=y_lo,
            vmax=y_hi,
            cmap=cmap,
        )

    # Put back the limits captured on entry.
    axes.axis([x_lo, x_hi, y_lo, y_hi])

def erase_xaxis(ax):
    """Turn off the x-axis ticks (bottom and top) and the bottom tick labels.

    Args:
        ax: axes object whose ``tick_params`` method is called.
    """
    hidden_parts = dict(bottom=False, top=False, labelbottom=False)
    # Apply to both major and minor ticks of the x axis.
    ax.tick_params(axis='x', which='both', **hidden_parts)
    
def add_text(rects, text, ax):
    """Write one label just above each bar.

    Args:
        rects: iterable of bar patches, in plotting order.
        text: indexable sequence of labels; ``text[i]`` goes with the i-th bar.
        ax: axes object the labels are drawn on.
    """
    for position, bar in enumerate(rects):
        bar_height = bar.get_height()
        # Horizontally centred on the bar, 1% above its top edge.
        center_x = bar.get_x() + bar.get_width() / 2.
        label_y = 1.01 * bar_height
        ax.text(center_x, label_y, text[position],
                ha='center', va='bottom', size=11)
    
# Draw the per-word saliency bar chart. `y`, `output` and `text` are produced
# by the gradient cell (model prediction, per-word scores, word labels).
fig, ax1 = plt.subplots(1, 1, figsize=(7, 7), dpi=80)
fig.tight_layout()
fig.subplots_adjust(hspace=0.3, wspace=0.5)

# The title shows the model output for the single input sentence.
ax1.set_title('Text w.r.t  \noutput probability:{:f}'.format(y[0][0]))
# One bar per entry of `output`, positioned at 0..len(output)-1.
rects = ax1.bar(np.arange(len(output)), output)
# Shade the bars with the 'Blues' colormap, hide the x-axis ticks and labels,
# then label each bar with the corresponding entry of `text`.
gradientbars(rects, 'Blues')
erase_xaxis(ax1)
add_text(rects, text, ax1)
# Show y tick labels as plain fixed-point numbers with three decimals.
ax1.ticklabel_format(style='plain', axis='y')
ax1.yaxis.set_major_formatter(FormatStrFormatter('%.3f'))

plt.show()
