In [2]:
from __future__ import absolute_import, division, print_function, unicode_literals

import tensorflow as tf

import numpy as np
import os
import time

import time

from sklearn.utils import shuffle

In [3]:
# read titles from disk
titles = []
with open('clickbait_dataset.txt', 'r') as f:
    for l in f.readlines():
        titles.append("<start> " + l.strip() + " <end>")
shuffle(titles, random_state=42)

# tokenize the titles
tokenizer = tf.keras.preprocessing.text.Tokenizer(num_words=5000,
                                                  oov_token='<unk>',
                                                  filters='!"#$%&()*+.,-/:;=?@[\]^_`{|}~ ')
tokenizer.fit_on_texts(titles)
train_seqs = tokenizer.texts_to_sequences(titles)

# pad the titles to get a numpy array
tokenizer.word_index['<pad>'] = 0
tokenizer.index_word[0] = '<pad>'

train_seqs_pad = tf.keras.preprocessing.sequence.pad_sequences(train_seqs, padding='post')

In [4]:
# hyperparameters
num_words = len(tokenizer.word_index) + 1
embedding_dim = 100
rnn_units = 256
batch_size = 64
num_steps = len(train_seqs) // batch_size
max_length = max(len(t) for t in train_seqs)

In [5]:
# create the dataset
dataset = tf.data.Dataset.from_tensor_slices(train_seqs_pad)
dataset = dataset.shuffle(1000).batch(batch_size)
dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)

In [6]:
# define the model
class RNN_Model(tf.keras.Model):
    def __init__(self, embedding_dim, units, vocab_size):
        super(RNN_Model, self).__init__()
        
        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
        self.lstm = tf.keras.layers.LSTMCell(units, recurrent_initializer='glorot_uniform')
        self.dense = tf.keras.layers.Dense(vocab_size)

    def call(self, inputs, states=None):
        encoded_txt = self.embedding(inputs)
        
        states = states if states != None else self.lstm.get_initial_state(inputs=encoded_txt)
        
        outputs, states_nxt = self.lstm(encoded_txt, states)
        x = self.dense(outputs)
        return x, states_nxt

In [7]:
model = RNN_Model(embedding_dim, rnn_units, num_words)

In [8]:
optimizer = tf.keras.optimizers.Adam()
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True)

def loss_function(real, pred):
    mask = tf.math.logical_not(tf.math.equal(real, 0))
    loss_ = loss_object(real, pred)

    mask = tf.cast(mask, dtype=loss_.dtype)
    loss_ *= mask

    return tf.reduce_mean(loss_)

In [9]:
checkpoint_path = "./checkpoints/train"
ckpt = tf.train.Checkpoint(model=model,
                           optimizer=optimizer)
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)

start_epoch = 0
if ckpt_manager.latest_checkpoint:
    ckpt.restore(ckpt_manager.latest_checkpoint)
    start_epoch = int(ckpt_manager.latest_checkpoint.split('-')[-1])

In [10]:
@tf.function
def train_step(title_seq):
    loss = 0
    states = None
    with tf.GradientTape() as tape:
        for i in range(title_seq.shape[1]-1):
            inputs = title_seq[:, i]
            prediction, states = model(inputs, states)
            
            loss += loss_function(title_seq[:, i+1], prediction)

    total_loss = (loss / int(title_seq.shape[1]))

    trainable_variables = model.trainable_variables

    gradients = tape.gradient(loss, trainable_variables)

    optimizer.apply_gradients(zip(gradients, trainable_variables))

    return loss, total_loss

In [43]:
EPOCHS = 20

for epoch in range(start_epoch, EPOCHS):
    start = time.time()
    total_loss = 0
    
    for (batch, title_seq) in enumerate(dataset):
        batch_loss, t_loss = train_step(title_seq)
        total_loss += t_loss

        if batch % 100 == 0:
            print ('Epoch {} Batch {} Loss {:.4f}'.format(
              epoch + 1, batch, batch_loss.numpy() / int(title_seq.shape[1])))

    ckpt_manager.save()

    print ('Epoch {} Loss {:.6f}'.format(epoch + 1,
                                         total_loss/num_steps))
    print ('Time taken for 1 epoch {} sec\n'.format(time.time() - start))

Epoch 12 Batch 0 Loss 1.3055
Epoch 12 Batch 100 Loss 1.2575
Epoch 12 Batch 200 Loss 1.1534
Epoch 12 Loss 1.276757
Time taken for 1 epoch 246.0968861579895 sec

Epoch 13 Batch 0 Loss 1.2494
Epoch 13 Batch 100 Loss 1.2787
Epoch 13 Batch 200 Loss 1.2048
Epoch 13 Loss 1.256216
Time taken for 1 epoch 246.00481534004211 sec

Epoch 14 Batch 0 Loss 1.1709
Epoch 14 Batch 100 Loss 1.2421
Epoch 14 Batch 200 Loss 1.1669
Epoch 14 Loss 1.235014
Time taken for 1 epoch 246.1310214996338 sec

Epoch 15 Batch 0 Loss 1.2308
Epoch 15 Batch 100 Loss 1.1437
Epoch 15 Batch 200 Loss 1.0828
Epoch 15 Loss 1.216074
Time taken for 1 epoch 247.85262060165405 sec

Epoch 16 Batch 0 Loss 1.2200
Epoch 16 Batch 100 Loss 1.1892
Epoch 16 Batch 200 Loss 1.1073
Epoch 16 Loss 1.196643
Time taken for 1 epoch 252.4271650314331 sec

Epoch 17 Batch 0 Loss 1.2199
Epoch 17 Batch 100 Loss 1.1517
Epoch 17 Batch 200 Loss 1.1120
Epoch 17 Loss 1.178150
Time taken for 1 epoch 248.58952474594116 sec

Epoch 18 Batch 0 Loss 1.2537
Epoch 18

In [11]:
def gen_new_title():
    inputs = tf.Variable([tokenizer.word_index['<start>']])
    states = None
    
    result = ['<start>']
    for i in range(max_length):
        nxt_word, states = model(inputs, states)
        
        predicted_id = tf.random.categorical(nxt_word, 1)[0][0].numpy()
        result.append(tokenizer.index_word[predicted_id])
        inputs = tf.Variable([predicted_id])

        if tokenizer.index_word[predicted_id] == '<end>':
            return result
    
    return result

In [12]:
for _ in range(10):
    print(*gen_new_title()[1:-1])

are you way to not watching this holiday opinions <unk> <unk> we've into the hunter cast <unk>
the rock kids just <unk> the chennai rains when she <unk> about taylor swift off
you'll only guess which k pop group from 2019 based on which bts you should visit what your crush
here are all the <unk> to taylor swift's way laugh the investigation into their lives
19 times bbc company was the best romantic on this decade
which character from facebook are you
an old <unk> shot the <unk> art and anne bush and there are <unk>
62 christmas videos and we'll tell you where are 2020
i'm a black jackson fan and we'll predict which high school musical as you
are you more like never <unk> a college level netflix
