In [49]:
import html
import os
import pickle

import numpy as np
import pandas as pd
import tensorflow as tf
from IPython.display import Markdown, display
from tensorflow import keras

In [2]:
class HANModule(keras.layers.Layer):
    """Attention-scoring layer: projects the last axis of its input onto one scalar.

    The layer owns a single trainable matrix ``W`` of shape
    ``(num_cells * 2, 1)`` and returns ``dot(x, W)``, so the last axis of the
    input must have size ``num_cells * 2`` and the last axis of the output has
    size 1. In this notebook the result is flattened and passed through a
    softmax to obtain attention weights.

    Args:
        num_cells: Number of units per direction of the upstream recurrent
            layer; the expected feature width is ``num_cells * 2``.
        **kwargs: Forwarded unchanged to ``keras.layers.Layer``.
    """

    def __init__(self, num_cells, **kwargs):
        self.num_cells = num_cells
        super(HANModule, self).__init__(**kwargs)

    def build(self, input_shape):
        """Create the projection matrix ``W``."""
        self.W = self.add_weight(name="W",
                                 shape=(self.num_cells * 2, 1),
                                 initializer="glorot_uniform",
                                 trainable=True)
        super(HANModule, self).build(input_shape)

    def get_config(self):
        """Return the layer config, including ``num_cells``, so the layer can be re-created on load."""
        config = super(HANModule, self).get_config()
        config["num_cells"] = self.num_cells
        return config

    def call(self, x):
        """Return ``dot(x, W)``: one unnormalised score per position of the leading axes."""
        return keras.backend.dot(x, self.W)

    def compute_output_shape(self, input_shape):
        """Return the input shape with its last axis replaced by 1.

        The previous implementation read ``input_shape.shape[1]``; shape tuples
        and ``TensorShape`` objects have no ``.shape`` attribute, and the
        returned rank would not have matched what ``call`` produces.
        """
        shape = tf.TensorShape(input_shape).as_list()
        shape[-1] = 1  # dot(x, W) with W of width 1 collapses the feature axis to size 1
        return tf.TensorShape(shape)

In [3]:
def get_model(input_length, num_sentences, vocab_size, embedding_size, num_cells, num_output, dropout_prob=0.2):
    """Build the two-level (word, then sentence) attention classifier.

    The model takes one integer tensor of shape
    ``(batch, num_sentences, input_length)``. Every sentence is encoded by the
    same embedding / bidirectional-GRU / dense / attention stack, the resulting
    sentence vectors are stacked and encoded by a second bidirectional GRU with
    its own attention, and a softmax ``Dense`` layer produces ``num_output``
    class probabilities.

    Args:
        input_length: Number of token ids per sentence.
        num_sentences: Number of sentences per document.
        vocab_size: Vocabulary size; the embedding table gets ``vocab_size + 1`` rows.
        embedding_size: Width of the word embeddings.
        num_cells: GRU units per direction (encoder outputs are ``2 * num_cells`` wide).
        num_output: Number of output classes.
        dropout_prob: Rate used by every ``Dropout`` layer in the model.

    Returns:
        An uncompiled ``keras.models.Model``. The softmax layers are named
        ``attention_word_<i>`` (one per sentence) and ``attention_sentence`` so
        their outputs can be looked up by name later.
    """
    # NOTE(review): layers are created in a fixed order; Keras auto-generated layer
    # names presumably depend on it, so reordering may affect saved checkpoints - confirm before refactoring.
    input_layer = keras.layers.Input(shape=(num_sentences, input_length,))
    # Split the document tensor into a Python list with one tensor per sentence.
    input_sentences = keras.layers.Lambda(lambda x: tf.unstack(x, axis=1))(input_layer)
    # One extra row beyond vocab_size; presumably for the 0 padding id used by process_inputs - confirm.
    embedding_layer = keras.layers.Embedding(vocab_size + 1, embedding_size)
    out = []
    # Word-level layers are instantiated once and reused for every sentence, so weights are shared.
    sentence_gru_layer = keras.layers.Bidirectional(keras.layers.GRU(num_cells, return_sequences=True))
    sentence_dense_layer = keras.layers.TimeDistributed(keras.layers.Dense(num_cells * 2, activation="tanh"))
    sentence_embedding_dropout_layer = keras.layers.Dropout(dropout_prob)
    sentence_gru_dropout_layer = keras.layers.Dropout(dropout_prob)
    sentence_dense_dropout_layer = keras.layers.Dropout(dropout_prob)
    sentence_han_layer = HANModule(num_cells)
    for i in range(len(input_sentences)):
        embeddings = embedding_layer(input_sentences[i])
        embeddings = sentence_embedding_dropout_layer(embeddings)
        gru_out = sentence_gru_layer(embeddings)
        gru_out = sentence_gru_dropout_layer(gru_out)
        dense_out = sentence_dense_layer(gru_out)
        dense_out = sentence_dense_dropout_layer(dense_out)
        # One score per word, flattened and normalised with softmax into attention weights.
        attention_out = sentence_han_layer(dense_out)
        attention_out = keras.layers.Flatten()(attention_out)
        softmax_out = keras.layers.Activation("softmax", name="attention_word_{0}".format(str(i)))(attention_out)
        # Reshape the weights to (input_length, 1) so they broadcast against the GRU outputs.
        softmax_out = keras.layers.Reshape((input_length, 1))(softmax_out)
        mul_out = keras.layers.Multiply()([softmax_out, gru_out])
        # Attention-weighted sum over the word axis gives one vector per sentence.
        sum_out = keras.layers.Lambda(lambda x: keras.backend.sum(x, axis=1))(mul_out)
        # NOTE(review): the summed tensor already looks 2-D here, so this Flatten is presumably a no-op.
        sum_out = keras.layers.Flatten()(sum_out)
        out.append(sum_out)
    # Re-assemble the per-sentence vectors into one sequence along axis 1.
    stacked_input = keras.layers.Lambda(lambda x: tf.stack(x, axis=1))(out)
    # Sentence-level encoder and attention use their own, separately created layers.
    gru_out = keras.layers.Bidirectional(keras.layers.GRU(num_cells, return_sequences=True))(stacked_input)
    gru_out = keras.layers.Dropout(dropout_prob)(gru_out)
    dense_out = keras.layers.TimeDistributed(keras.layers.Dense(num_cells * 2, activation="tanh"))(gru_out)
    dense_out = keras.layers.Dropout(dropout_prob)(dense_out)
    attention_out = HANModule(num_cells)(dense_out)
    attention_out = keras.layers.Flatten()(attention_out)
    softmax_out = keras.layers.Activation("softmax", name="attention_sentence")(attention_out)
    softmax_out = keras.layers.Reshape((num_sentences, 1))(softmax_out)
    mul_out = keras.layers.Multiply()([softmax_out, gru_out])
    # Attention-weighted sum over the sentence axis gives the document vector.
    sum_out = keras.layers.Lambda(lambda x: keras.backend.sum(x, axis=1))(mul_out)
    out = keras.layers.Dense(num_output, activation="softmax")(sum_out)
    model = keras.models.Model(inputs=input_layer, outputs=out)
    return model

In [4]:
def load_model(model_location):
    """Restore a saved model from ``model_location``.

    ``tf``, ``keras`` and ``HANModule`` are supplied as custom objects so that
    the custom layer and the ``Lambda`` layers that reference those names can
    be deserialised.
    """
    custom_objects = {"tf": tf, "HANModule": HANModule, "keras": keras}
    return keras.models.load_model(model_location, custom_objects=custom_objects)

In [5]:
def process_inputs(input_data, sequence_length, num_sentences, sentence_sep=".",
                   pad_direction="pre", truncate_direction="pre", tokenizer=None):
    """Convert raw documents into a padded integer tensor of token ids.

    Each document is split on ``sentence_sep``, every sentence is tokenised and
    padded/truncated to ``sequence_length`` ids, and the sentences are placed
    into a zero-initialised block of ``num_sentences`` rows.

    Args:
        input_data: Sequence of document strings.
        sequence_length: Number of token ids kept per sentence.
        num_sentences: Number of sentences kept per document.
        sentence_sep: Separator used to split a document into sentences.
        pad_direction: ``"pre"`` or ``"post"``; where zero padding goes, both
            inside a sentence and across the sentence rows of a document.
        truncate_direction: ``"pre"`` or ``"post"``; ``"pre"`` keeps the last
            tokens / last sentences, ``"post"`` keeps the first ones.
        tokenizer: Fitted tokenizer to reuse. When ``None`` a new
            ``keras.preprocessing.text.Tokenizer`` is fitted on ``input_data``.

    Returns:
        ``(x, tokenizer)`` where ``x`` is an ``int32`` array of shape
        ``(len(input_data), num_sentences, sequence_length)``.

    Raises:
        AssertionError: If ``pad_direction`` or ``truncate_direction`` is not
            ``"pre"`` or ``"post"``.
    """
    assert pad_direction in ["pre", "post"]
    assert truncate_direction in ["pre", "post"]
    if tokenizer is None:
        tokenizer = keras.preprocessing.text.Tokenizer()
        tokenizer.fit_on_texts(input_data)
    x = np.zeros((len(input_data), num_sentences, sequence_length), dtype="int32")

    for i, document in enumerate(input_data):
        sentences = document.split(sentence_sep)
        # A trailing separator leaves one empty string at the end; drop it.
        if sentences[-1] == "":
            sentences = sentences[:-1]
        sentences_sequences = tokenizer.texts_to_sequences(sentences)
        padded_sentences = keras.preprocessing.sequence.pad_sequences(sentences_sequences,
                                                                      maxlen=sequence_length,
                                                                      padding=pad_direction,
                                                                      truncating=truncate_direction)
        if truncate_direction == "pre":
            padded_sentences = padded_sentences[-num_sentences:]
        else:
            padded_sentences = padded_sentences[:num_sentences]

        kept = len(padded_sentences)
        if kept == 0:
            # Empty document: leave the row as padding. Without this guard the
            # "pre" branch would index x[i, -0:], which selects every row and
            # fails to broadcast the empty array.
            continue

        if pad_direction == "pre":
            x[i, num_sentences - kept:] = padded_sentences
        else:
            x[i, :kept] = padded_sentences
    return x, tokenizer

In [6]:
def process_output(output_data):
    """Encode class labels as one-hot rows and derive balancing class weights.

    Classes are numbered in the order returned by ``value_counts()``. The
    weight of a class is the largest class count divided by that class's count.

    Args:
        output_data: pandas Series of class labels.

    Returns:
        ``(y, output_mapping, class_weights)``: the one-hot matrix, a dict from
        label to class index, and a dict from class index to weight.
    """
    counts = output_data.value_counts()
    labels = counts.index
    largest = max(counts)
    inverse_frequency = list(largest / counts)

    class_weights = {}
    for position in range(len(inverse_frequency)):
        class_weights[position] = inverse_frequency[position]

    output_mapping = {}
    for position in range(len(labels)):
        output_mapping[labels[position]] = position

    encoded = np.array([output_mapping[label] for label in output_data])
    y = keras.utils.to_categorical(encoded, num_classes=len(output_mapping))
    return y, output_mapping, class_weights

In [7]:
def train(data_file_location, input_col_name, output_col_name,
          embedding_length, pad_direction, truncate_direction, sequence_length, num_sentence, sentence_sep, num_cells, epochs, batch_size):
    """Train the attention classifier on a CSV file and write artifacts to ``model_artifacts/``.

    The CSV is shuffled, cut to its first 20000 rows, and rows with any missing
    value are dropped. The preprocessing metadata (tokenizer, label mapping and
    padding settings) is pickled to ``model_artifacts/model_metadata.pkl`` and a
    checkpoint is written after every epoch to
    ``model_artifacts/weights.<epoch>.hdf5``.

    Args:
        data_file_location: Path of the CSV file to read.
        input_col_name: Column holding the document text.
        output_col_name: Column holding the class label.
        embedding_length: Width of the word embeddings.
        pad_direction: ``"pre"`` or ``"post"``, forwarded to ``process_inputs``.
        truncate_direction: ``"pre"`` or ``"post"``, forwarded to ``process_inputs``.
        sequence_length: Token ids kept per sentence.
        num_sentence: Sentences kept per document.
        sentence_sep: Separator used to split documents into sentences.
        num_cells: GRU units per direction.
        epochs: Number of training epochs.
        batch_size: Training batch size.
    """
    artifacts_dir = "model_artifacts"

    data = pd.read_csv(data_file_location).sample(frac=1.0).reset_index(drop=True).iloc[:20000].dropna()
    input_sentences = data[input_col_name].values.tolist()
    output_classes = data[output_col_name]
    x, tokenizer = process_inputs(input_sentences, sequence_length, num_sentence, sentence_sep, pad_direction,
                                  truncate_direction)
    y, output_mapping, class_weights = process_output(output_classes)
    model_metadata = {"tokenizer": tokenizer, "output_mapping": output_mapping, "sentence_sep": sentence_sep, "pad_direction": pad_direction,
                      "truncate_direction": truncate_direction, "sequence_length": sequence_length, "num_sentence": num_sentence}

    # Create the output directory up front so neither the metadata dump nor the
    # checkpoint callback fails on a fresh checkout.
    os.makedirs(artifacts_dir, exist_ok=True)
    # Context manager guarantees the pickle is flushed and the handle closed.
    with open("model_artifacts/model_metadata.pkl", "wb") as metadata_file:
        pickle.dump(model_metadata, metadata_file)

    model = get_model(sequence_length, num_sentence, len(tokenizer.word_index), embedding_length, num_cells,
                      len(output_mapping))
    model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["acc"])
    checkpoint_callback = keras.callbacks.ModelCheckpoint(
        "model_artifacts/weights.{epoch:02d}.hdf5")
    model.fit(x, y, validation_split=0.1, epochs=epochs, class_weight=class_weights, batch_size=batch_size, callbacks=[checkpoint_callback])

In [8]:
def predict(input_sentences, model, model_metadata):
    """Rank the output classes for each input document, most probable first.

    Args:
        input_sentences: List of document strings. A single string is treated
            as one document rather than being iterated character by character.
        model: Trained model exposing ``predict``.
        model_metadata: Dict with the keys ``tokenizer``, ``output_mapping``,
            ``sentence_sep``, ``pad_direction``, ``truncate_direction``,
            ``sequence_length`` and ``num_sentence``.

    Returns:
        One list per document containing every class label, ordered from the
        highest to the lowest predicted score.
    """
    if isinstance(input_sentences, str):
        input_sentences = [input_sentences]

    # process_inputs returns (array, tokenizer); only the array goes to the model.
    # Previously the whole tuple was handed to model.predict.
    x, _ = process_inputs(input_sentences,
                          model_metadata["sequence_length"],
                          model_metadata["num_sentence"],
                          model_metadata["sentence_sep"],
                          model_metadata["pad_direction"],
                          model_metadata["truncate_direction"],
                          model_metadata["tokenizer"])
    model_output = model.predict(x)

    index_to_label = {index: label for label, index in model_metadata["output_mapping"].items()}
    # argsort is ascending, so each row is reversed to put the best class first.
    ranked_indices = np.argsort(model_output)
    return [[index_to_label[j] for j in row[::-1]] for row in ranked_indices]

In [119]:
def visualize_attention(input_sentence, model, model_metadata):
    """Render one document with its word- and sentence-level attention weights.

    One Markdown line is displayed per sentence slot. The font size of each
    word grows with its word-attention weight, and the colour lightness of the
    line decreases as its sentence-attention weight increases. Padding
    positions are shown as ``<blank>``.

    Args:
        input_sentence: Single document string.
        model: Trained model containing the layers ``attention_sentence`` and
            ``attention_word_<i>`` for every sentence index.
        model_metadata: Dict with the keys ``tokenizer``, ``sentence_sep``,
            ``pad_direction``, ``truncate_direction``, ``sequence_length`` and
            ``num_sentence``.
    """
    tokenizer = model_metadata["tokenizer"]
    num_sentences = model_metadata["num_sentence"]
    # process_inputs returns (array, tokenizer); keep only the token-id array.
    x, _ = process_inputs([input_sentence],
                          model_metadata["sequence_length"],
                          num_sentences,
                          model_metadata["sentence_sep"],
                          model_metadata["pad_direction"],
                          model_metadata["truncate_direction"],
                          tokenizer)

    # One multi-output model and one forward pass instead of building and
    # running a separate sub-model for every attention layer.
    layer_names = ["attention_sentence"] + ["attention_word_{0}".format(i) for i in range(num_sentences)]
    attention_model = keras.models.Model(inputs=model.inputs,
                                         outputs=[model.get_layer(name).output for name in layer_names])
    attention_outputs = attention_model.predict(x)
    sentence_attention = attention_outputs[0][0]
    word_attentions = [layer_output[0] for layer_output in attention_outputs[1:]]

    rev_word_index = {index: word for word, index in tokenizer.word_index.items()}
    rev_word_index[0] = "<blank>"
    for i in range(num_sentences):
        # Words are HTML-escaped so the "<blank>" placeholder is shown as text
        # instead of being parsed as an unknown tag and dropped.
        word_spans = "".join(
            "<span style='font-size:{0}%;'>{1} </span>".format(
                100 + 4 * (weight * 100),
                html.escape(rev_word_index[token_id], quote=False))
            for weight, token_id in zip(word_attentions[i], x[0][i]))
        lightness = 100 - (sentence_attention[i] * 50) - 15
        colorstr = "<span style='color:hsl(360, 100%, {0}%)'>".format(lightness) + word_spans + "</span>"
        display(Markdown(colorstr))
        

In [52]:
# Scratch check that inline HTML <span> styling is rendered by IPython's Markdown display.
colorstr = "<span style='font-size:100%;'>himanshu</span>"
display(Markdown(colorstr))

<span style='font-size:100%;'>himanshu</span>

In [10]:
# Training configuration; these globals are also read by the later demo cell.
sequence_length = 20
num_sentence = 10
embedding_length = 100
# Raw string: "\E" is not a valid escape sequence and triggers a warning in a
# normal literal. NOTE(review): absolute local path - adjust per machine.
data_file_location = r"E:\Electronics_5_sample.csv"
input_col_name = "reviewText"
output_col_name = "overall"
pad_direction = "pre"
truncate_direction = "pre"
sentence_sep = "."
num_cells = 50
epochs = 5
batch_size = 200
train(data_file_location, input_col_name, output_col_name, embedding_length,
      pad_direction, truncate_direction, sequence_length, num_sentence,
      sentence_sep, num_cells, epochs, batch_size)

Instructions for updating:
Colocations handled automatically by placer.
Instructions for updating:
Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.
Train on 17858 samples, validate on 1985 samples
Instructions for updating:
Use tf.cast instead.
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


In [129]:
# Reload the preprocessing metadata written by train(). pickle.load can execute
# arbitrary code, so only point this at files produced by this notebook.
with open("model_artifacts/model_metadata.pkl", "rb") as metadata_file:
    model_metadata = pickle.load(metadata_file)
# NOTE(review): this restores the checkpoint saved after epoch 1 - confirm that
# an earlier epoch is intended rather than the last one.
model = load_model("model_artifacts/weights.01.hdf5")

In [131]:
# Draw a few random reviews and inspect the first one. The sample is unseeded,
# so the displayed review changes on every run.
data = pd.read_csv(data_file_location).sample(frac=1.0).reset_index(drop=True).iloc[:10].dropna()
input_sentences = data[input_col_name].values.tolist()
visualize_attention(model=model, model_metadata=model_metadata, input_sentence=input_sentences[0])
# predict() takes a list of documents, so wrap the single review in a list;
# a bare string would be iterated one character at a time.
print("Prediction: ", "  >  ".join([str(i) for i in predict([input_sentences[0]], model, model_metadata)[0]]))

<span style='color:hsl(360, 100%, 78.4986263513565%)'><span style='font-size:104.7205463051796%;'><blank> </span><span style='font-size:104.63782288134098%;'><blank> </span><span style='font-size:104.82945702970028%;'><blank> </span><span style='font-size:105.22045269608498%;'><blank> </span><span style='font-size:105.7732306420803%;'><blank> </span><span style='font-size:106.47165700793266%;'><blank> </span><span style='font-size:107.31459781527519%;'><blank> </span><span style='font-size:108.31368565559387%;'><blank> </span><span style='font-size:109.49304923415184%;'><blank> </span><span style='font-size:110.89034453034401%;'><blank> </span><span style='font-size:112.55918890237808%;'><blank> </span><span style='font-size:114.57344740629196%;'><blank> </span><span style='font-size:117.0345813035965%;'><blank> </span><span style='font-size:120.08417844772339%;'><blank> </span><span style='font-size:123.92605543136597%;'><blank> </span><span style='font-size:128.86740863323212%;'><blank> </span><span style='font-size:135.40052771568298%;'><blank> </span><span style='font-size:144.3786382675171%;'><blank> </span><span style='font-size:157.43206143379211%;'><blank> </span><span style='font-size:178.07905673980713%;'><blank> </span></span>

<span style='color:hsl(360, 100%, 80.47310099005699%)'><span style='font-size:104.7205463051796%;'><blank> </span><span style='font-size:104.63782288134098%;'><blank> </span><span style='font-size:104.82945702970028%;'><blank> </span><span style='font-size:105.22045269608498%;'><blank> </span><span style='font-size:105.7732306420803%;'><blank> </span><span style='font-size:106.47165700793266%;'><blank> </span><span style='font-size:107.31459781527519%;'><blank> </span><span style='font-size:108.31368565559387%;'><blank> </span><span style='font-size:109.49304923415184%;'><blank> </span><span style='font-size:110.89034453034401%;'><blank> </span><span style='font-size:112.55918890237808%;'><blank> </span><span style='font-size:114.57344740629196%;'><blank> </span><span style='font-size:117.0345813035965%;'><blank> </span><span style='font-size:120.08417844772339%;'><blank> </span><span style='font-size:123.92605543136597%;'><blank> </span><span style='font-size:128.86740863323212%;'><blank> </span><span style='font-size:135.40052771568298%;'><blank> </span><span style='font-size:144.3786382675171%;'><blank> </span><span style='font-size:157.43206143379211%;'><blank> </span><span style='font-size:178.07905673980713%;'><blank> </span></span>

<span style='color:hsl(360, 100%, 81.38334661722183%)'><span style='font-size:104.7205463051796%;'><blank> </span><span style='font-size:104.63782288134098%;'><blank> </span><span style='font-size:104.82945702970028%;'><blank> </span><span style='font-size:105.22045269608498%;'><blank> </span><span style='font-size:105.7732306420803%;'><blank> </span><span style='font-size:106.47165700793266%;'><blank> </span><span style='font-size:107.31459781527519%;'><blank> </span><span style='font-size:108.31368565559387%;'><blank> </span><span style='font-size:109.49304923415184%;'><blank> </span><span style='font-size:110.89034453034401%;'><blank> </span><span style='font-size:112.55918890237808%;'><blank> </span><span style='font-size:114.57344740629196%;'><blank> </span><span style='font-size:117.0345813035965%;'><blank> </span><span style='font-size:120.08417844772339%;'><blank> </span><span style='font-size:123.92605543136597%;'><blank> </span><span style='font-size:128.86740863323212%;'><blank> </span><span style='font-size:135.40052771568298%;'><blank> </span><span style='font-size:144.3786382675171%;'><blank> </span><span style='font-size:157.43206143379211%;'><blank> </span><span style='font-size:178.07905673980713%;'><blank> </span></span>

<span style='color:hsl(360, 100%, 81.94209042936563%)'><span style='font-size:104.7205463051796%;'><blank> </span><span style='font-size:104.63782288134098%;'><blank> </span><span style='font-size:104.82945702970028%;'><blank> </span><span style='font-size:105.22045269608498%;'><blank> </span><span style='font-size:105.7732306420803%;'><blank> </span><span style='font-size:106.47165700793266%;'><blank> </span><span style='font-size:107.31459781527519%;'><blank> </span><span style='font-size:108.31368565559387%;'><blank> </span><span style='font-size:109.49304923415184%;'><blank> </span><span style='font-size:110.89034453034401%;'><blank> </span><span style='font-size:112.55918890237808%;'><blank> </span><span style='font-size:114.57344740629196%;'><blank> </span><span style='font-size:117.0345813035965%;'><blank> </span><span style='font-size:120.08417844772339%;'><blank> </span><span style='font-size:123.92605543136597%;'><blank> </span><span style='font-size:128.86740863323212%;'><blank> </span><span style='font-size:135.40052771568298%;'><blank> </span><span style='font-size:144.3786382675171%;'><blank> </span><span style='font-size:157.43206143379211%;'><blank> </span><span style='font-size:178.07905673980713%;'><blank> </span></span>

<span style='color:hsl(360, 100%, 82.42542710155249%)'><span style='font-size:104.7205463051796%;'><blank> </span><span style='font-size:104.63782288134098%;'><blank> </span><span style='font-size:104.82945702970028%;'><blank> </span><span style='font-size:105.22045269608498%;'><blank> </span><span style='font-size:105.7732306420803%;'><blank> </span><span style='font-size:106.47165700793266%;'><blank> </span><span style='font-size:107.31459781527519%;'><blank> </span><span style='font-size:108.31368565559387%;'><blank> </span><span style='font-size:109.49304923415184%;'><blank> </span><span style='font-size:110.89034453034401%;'><blank> </span><span style='font-size:112.55918890237808%;'><blank> </span><span style='font-size:114.57344740629196%;'><blank> </span><span style='font-size:117.0345813035965%;'><blank> </span><span style='font-size:120.08417844772339%;'><blank> </span><span style='font-size:123.92605543136597%;'><blank> </span><span style='font-size:128.86740863323212%;'><blank> </span><span style='font-size:135.40052771568298%;'><blank> </span><span style='font-size:144.3786382675171%;'><blank> </span><span style='font-size:157.43206143379211%;'><blank> </span><span style='font-size:178.07905673980713%;'><blank> </span></span>

<span style='color:hsl(360, 100%, 83.78017092123628%)'><span style='font-size:111.9294598698616%;'><blank> </span><span style='font-size:111.90865263342857%;'><blank> </span><span style='font-size:112.67248839139938%;'><blank> </span><span style='font-size:114.12331163883209%;'><blank> </span><span style='font-size:116.33408814668655%;'><blank> </span><span style='font-size:119.60434168577194%;'><blank> </span><span style='font-size:131.59812092781067%;'>i </span><span style='font-size:136.33569180965424%;'>bought </span><span style='font-size:131.26491904258728%;'>this </span><span style='font-size:109.70601662993431%;'>lens </span><span style='font-size:146.50352895259857%;'>because </span><span style='font-size:125.86284577846527%;'>lately </span><span style='font-size:114.1014963388443%;'>i've </span><span style='font-size:114.35286849737167%;'>been </span><span style='font-size:118.17687004804611%;'>sort </span><span style='font-size:117.5527423620224%;'>of </span><span style='font-size:126.97702050209045%;'>in </span><span style='font-size:104.15419787168503%;'>love </span><span style='font-size:116.7123094201088%;'>with </span><span style='font-size:120.12903392314911%;'>lenses </span></span>

<span style='color:hsl(360, 100%, 83.67025669664145%)'><span style='font-size:112.06335201859474%;'>i </span><span style='font-size:105.01824989914894%;'>travel </span><span style='font-size:110.46695038676262%;'>a </span><span style='font-size:104.99511882662773%;'>lot </span><span style='font-size:103.52796614170074%;'>and </span><span style='font-size:106.57555684447289%;'>i </span><span style='font-size:106.5852902829647%;'>like </span><span style='font-size:110.94803214073181%;'>visiting </span><span style='font-size:116.63995832204819%;'>a </span><span style='font-size:108.51939022541046%;'>lot </span><span style='font-size:112.74892836809158%;'>of </span><span style='font-size:109.78233367204666%;'>spots </span><span style='font-size:114.30900543928146%;'>most </span><span style='font-size:122.01040834188461%;'>of </span><span style='font-size:143.8704788684845%;'>them </span><span style='font-size:131.27542436122894%;'>being </span><span style='font-size:136.34455502033234%;'>buildings </span><span style='font-size:141.15647077560425%;'>with </span><span style='font-size:139.2597794532776%;'>history </span><span style='font-size:163.90274167060852%;'>etc </span></span>

<span style='color:hsl(360, 100%, 84.14724687114358%)'><span style='font-size:114.69804495573044%;'><blank> </span><span style='font-size:113.96778374910355%;'><blank> </span><span style='font-size:114.41269516944885%;'><blank> </span><span style='font-size:126.57857835292816%;'>the </span><span style='font-size:103.51582989096642%;'>great </span><span style='font-size:123.94513785839081%;'>thing </span><span style='font-size:114.603191614151%;'>is </span><span style='font-size:125.72813928127289%;'>that </span><span style='font-size:130.6143969297409%;'>the </span><span style='font-size:110.13194993138313%;'>lens </span><span style='font-size:114.34132158756256%;'>is </span><span style='font-size:110.30609533190727%;'>very </span><span style='font-size:114.49169665575027%;'>wide </span><span style='font-size:114.37739878892899%;'>angled </span><span style='font-size:109.56038013100624%;'>and </span><span style='font-size:115.13583958148956%;'>has </span><span style='font-size:106.16682022809982%;'>amazing </span><span style='font-size:137.27712035179138%;'>low </span><span style='font-size:156.61287903785706%;'>light </span><span style='font-size:143.53472292423248%;'>performance </span></span>

<span style='color:hsl(360, 100%, 67.3703184723854%)'><span style='font-size:101.2062975205481%;'><blank> </span><span style='font-size:101.27057135105133%;'><blank> </span><span style='font-size:101.42805622890592%;'><blank> </span><span style='font-size:101.67658906430006%;'><blank> </span><span style='font-size:102.02608294785023%;'><blank> </span><span style='font-size:102.4982389062643%;'><blank> </span><span style='font-size:103.13079915940762%;'><blank> </span><span style='font-size:103.98804657161236%;'><blank> </span><span style='font-size:105.18353693187237%;'><blank> </span><span style='font-size:106.93128854036331%;'><blank> </span><span style='font-size:109.67172607779503%;'><blank> </span><span style='font-size:118.67625266313553%;'>i </span><span style='font-size:138.99564445018768%;'>do </span><span style='font-size:291.40182733535767%;'>not </span><span style='font-size:133.835107088089%;'>regret </span><span style='font-size:125.23716688156128%;'>one </span><span style='font-size:113.01716268062592%;'>bit </span><span style='font-size:114.98797535896301%;'>paying </span><span style='font-size:112.52227574586868%;'>for </span><span style='font-size:112.3153530061245%;'>this </span></span>

<span style='color:hsl(360, 100%, 76.30942150950432%)'><span style='font-size:136.3522708415985%;'>try </span><span style='font-size:134.56985354423523%;'>it </span><span style='font-size:118.90083402395248%;'>on </span><span style='font-size:105.55360727012157%;'>macro </span><span style='font-size:105.56655079126358%;'>related </span><span style='font-size:101.52188427746296%;'>photography </span><span style='font-size:103.60101945698261%;'>nor </span><span style='font-size:103.43937873840332%;'>have </span><span style='font-size:105.32187595963478%;'>i </span><span style='font-size:117.23226606845856%;'>tried </span><span style='font-size:102.85878404974937%;'>portraits </span><span style='font-size:107.6732188463211%;'>for </span><span style='font-size:110.98448187112808%;'>which </span><span style='font-size:118.98982673883438%;'>i'm </span><span style='font-size:131.06902539730072%;'>sure </span><span style='font-size:154.63183522224426%;'>it </span><span style='font-size:185.6705605983734%;'>would </span><span style='font-size:133.95880162715912%;'>shine </span><span style='font-size:114.72829282283783%;'>very </span><span style='font-size:107.37561881542206%;'>well </span></span>

Prediction:  5  >  4  >  3  >  2  >  1


In [124]:
# NOTE(review): `a` is not defined anywhere in the visible notebook, so this scratch
# cell raises NameError on a fresh kernel - define `a` first or delete the cell.
[i for i in a if i != [0, 1]]

NameError: name 'a' is not defined