In [1]:
import os
import re
from typing import Dict, Tuple
import numpy as np

import tensorflow as tf
from tensorflow.data import Dataset, AUTOTUNE
from tensorflow import keras

import keras.layers as l
from keras import models, callbacks, utils, losses

In [27]:
text = ''
with open('medium_articles.txt', 'r', encoding='utf-8') as file:
    text = file.read(250000)

def get_target(seq: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
    features = seq[:-1]
    target = seq[1:]
    return features, target

BATCH_SIZE = 32

words = list(filter(None, [re.sub('[^a-zA-Z0-9 ,-]', '', s).strip() for s in text.split('.')]))
alp = np.array(sorted(set(' '.join(words).split(' '))))

word_index = {char: i for i, char in enumerate(alp)}
index_word = {i: char for i, char in enumerate(alp)}

sequences = Dataset.from_tensor_slices(np.array([word_index[word] for word in ' '.join(words).split()])).batch(BATCH_SIZE, drop_remainder=True)
dataset = sequences.map(get_target)

data = dataset.batch(BATCH_SIZE, drop_remainder=True).repeat()
data = data.prefetch(AUTOTUNE)

In [28]:
model = keras.Sequential([
    l.Embedding(len(alp), BATCH_SIZE, batch_input_shape=[BATCH_SIZE, None]),
    l.Bidirectional(l.LSTM(150, return_sequences=True)),
    l.Dropout(0.2),
    l.LSTM(512, return_sequences=True, stateful=True),
    l.Dense(len(alp) / 2, activation='relu', kernel_regularizer=keras.regularizers.l2(0.01)),
    l.Dense(len(alp), activation='softmax')
])

In [29]:
model.compile(optimizer='adam', loss=losses.SparseCategoricalCrossentropy(from_logits=True), metrics=['accuracy'])
model.fit(data, epochs=45, verbose=1, steps_per_epoch= len(sequences) // BATCH_SIZE)

Epoch 1/45
Epoch 2/45
Epoch 3/45
Epoch 4/45
Epoch 5/45
Epoch 6/45
Epoch 7/45
Epoch 8/45
Epoch 9/45
Epoch 10/45
Epoch 11/45
Epoch 12/45
Epoch 13/45
Epoch 14/45
Epoch 15/45
Epoch 16/45
Epoch 17/45
Epoch 18/45
Epoch 19/45
Epoch 20/45
Epoch 21/45
Epoch 22/45
Epoch 23/45
Epoch 24/45
Epoch 25/45
Epoch 26/45
Epoch 27/45
Epoch 28/45
Epoch 29/45
Epoch 30/45
Epoch 31/45
Epoch 32/45
Epoch 33/45
Epoch 34/45
Epoch 35/45
Epoch 36/45
Epoch 37/45
Epoch 38/45
Epoch 39/45
Epoch 40/45
Epoch 41/45
Epoch 42/45
Epoch 43/45
Epoch 44/45
Epoch 45/45


<keras.src.callbacks.History at 0x2a53f95f100>

In [30]:
model.fit(data, epochs=65, initial_epoch=45, verbose=1, steps_per_epoch=len(sequences) // BATCH_SIZE)

Epoch 46/65
Epoch 47/65
Epoch 48/65
Epoch 49/65
Epoch 50/65
Epoch 51/65
Epoch 52/65
Epoch 53/65
Epoch 54/65
Epoch 55/65
Epoch 56/65
Epoch 57/65
Epoch 58/65
Epoch 59/65
Epoch 60/65
Epoch 61/65
Epoch 62/65
Epoch 63/65
Epoch 64/65
Epoch 65/65


<keras.src.callbacks.History at 0x2a5c25b1f40>

In [32]:
def gen_next(sample, model, tokenizer, vocabulary, n_next, temperature, batch_size, word):
    if word:
        sample_vector = [tokenizer[word] for word in sample.split()]
    else:
        sample_vector = [tokenizer[char] for char in sample]
    predicted = sample_vector
    sample_tensor = tf.expand_dims(sample_vector, 0)
    sample_tensor = tf.repeat(sample_tensor, batch_size, axis=0)
    for i in range(n_next):
        pred = model(sample_tensor)
        pred = pred[0].numpy() / temperature
        pred = tf.random.categorical(pred, num_samples=1)[-1, 0].numpy()
        predicted.append(pred)
        sample_tensor = predicted[-99:]
        sample_tensor = tf.expand_dims([pred], 0)
        sample_tensor = tf.repeat(sample_tensor, batch_size, axis=0)
    pred_seq = [vocabulary[i] for i in predicted]
    generated = ' '.join(pred_seq) if word else ''.join(pred_seq)
    return generated

In [37]:
print(gen_next(
    sample='Where',
    model=model,
    tokenizer=word_index,
    vocabulary=index_word,
    n_next=20,
    temperature=0.4,
    batch_size=BATCH_SIZE,
    word=True
))

Where quietly significantly true falls discriminating whenever acquisitions Donald honestly necessarily phenomena Finally, ScienceCan attacking doctor initiatives Back phrases whisper SME


In [50]:
print(gen_next(
    sample='Who',
    model=model,
    tokenizer=word_index,
    vocabulary=index_word,
    n_next=20,
    temperature=0.2,
    batch_size=BATCH_SIZE,
    word=True
))

Who happy tell trail controversies, netcnnLiege-v-Benfica01 papers, plain CCPS, diagnoses paywall achieving personally Jones wore tying urgency Waste reference quarantining severe


In [58]:
print(gen_next(
    sample='Face',
    model=model,
    tokenizer=word_index,
    vocabulary=index_word,
    n_next=20,
    temperature=0.6,
    batch_size=BATCH_SIZE,
    word=True
))

Face define drilling focuses premier June offline Cantt fight thrived copywriter brought govern traced Fantastic thicker chi complex Back components, Development,
