In [13]:
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import platform
import time
import pathlib
import os
from google.colab import drive

In [17]:
# Mount Google Drive into the Colab runtime at /content/drive.
drive.mount('/content/drive')
# NOTE(review): this path is outside the /content/drive mount point — confirm the
# corpus is uploaded to the runtime directly rather than read from Drive.
filename = '/content/Persian-WikiText-1.txt'
# Read the whole corpus as UTF-8. The context manager closes the file handle as
# soon as the read finishes instead of leaving it open for the kernel's lifetime.
with open(filename, 'r', encoding='utf-8') as corpus_file:
    text = corpus_file.read()

Mounted at /content/drive


In [18]:
# Normalise the corpus in a single expression: keep only the first 1000000
# characters, lower-case them, and replace every newline with a space.
text = text[:1000000].lower().replace('\n', ' ')

In [19]:
# Character-level vocabulary: every distinct character of the corpus, sorted.
vocab = sorted(set(text))
# Character -> integer id, where the id is the character's position in `vocab`.
charToInt = dict(zip(vocab, range(len(vocab))))
# Integer id -> character, stored as an array so that ids can index it directly.
intToChar = np.array(vocab)
# The corpus encoded as one integer id per character.
textAsInt = np.array([charToInt[symbol] for symbol in text])

In [20]:
# Wrap the encoded corpus in a tf.data.Dataset; slicing the 1-D id array yields
# one character id per dataset element.
charDataset = tf.data.Dataset.from_tensor_slices(textAsInt) 

In [21]:
# Group the id stream into chunks of 101 ids. split_input_target slices each
# chunk into a 100-id input and a 100-id target shifted by one position, hence
# the extra element. drop_remainder=True discards a final chunk shorter than 101.
sequences = charDataset.batch(101, drop_remainder=True)

In [22]:
def split_input_target(chunk):
    """Split a chunk into an (input, target) pair offset by one position.

    The input is every element of ``chunk`` except the last; the target is
    every element except the first, so ``target[i]`` is the element that
    follows ``input[i]`` in the original chunk.

    Args:
        chunk: Any sliceable sequence (e.g. a tensor, array, list or string).

    Returns:
        A tuple ``(chunk[:-1], chunk[1:])``.
    """
    return chunk[:-1], chunk[1:]

In [23]:
# Apply split_input_target to every chunk, producing (input, target) pairs.
dataset = sequences.map(split_input_target)

In [24]:
# Shuffle the (input, target) pairs using a 10000-element buffer, then group
# them into batches of 64; an incomplete final batch is dropped. The batch size
# has to agree with the batch dimension declared in the Embedding layer's
# batch_input_shape.
# NOTE: this rebinds `dataset`, so re-running this cell on its own would batch
# already-batched data — re-run the cell that maps split_input_target first.
shuffled_pairs = dataset.shuffle(10000)
dataset = shuffled_pairs.batch(64, drop_remainder=True)

In [25]:
# Character-level language model: Embedding -> LSTM -> Dense.
# The Dense layer has no activation and emits one raw score (logit) per
# vocabulary entry at every time step.
# NOTE(review): stateful=True asks Keras to reuse the LSTM state across
# successive batches, while the training batches are shuffled — confirm that
# this combination is intended.
model = tf.keras.models.Sequential([
    # Maps each character id to a 256-dimensional vector; the batch size is
    # fixed at 64 and the sequence length is left unspecified.
    tf.keras.layers.Embedding(
        input_dim=len(vocab),
        output_dim=256,
        batch_input_shape=[64, None],
    ),
    # return_sequences=True keeps an output for every time step.
    tf.keras.layers.LSTM(
        units=1024,
        return_sequences=True,
        stateful=True,
        recurrent_initializer=tf.keras.initializers.GlorotNormal(),
    ),
    tf.keras.layers.Dense(len(vocab)),
])

In [26]:
def loss(labels, logits):
    """Sparse categorical cross-entropy between integer labels and raw logits.

    Args:
        labels: Ground-truth class ids, passed as ``y_true``.
        logits: Unnormalised model outputs, passed as ``y_pred``.

    Returns:
        The result of ``tf.keras.losses.sparse_categorical_crossentropy``
        evaluated with ``from_logits=True``.
    """
    crossentropy = tf.keras.losses.sparse_categorical_crossentropy
    return crossentropy(labels, logits, from_logits=True)

# Train with Adam at a learning rate of 0.001, minimising the custom `loss`.
adam_optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
model.compile(optimizer=adam_optimizer, loss=loss)

In [27]:
# Checkpoints go to a relative directory, resolved against the current working
# directory. NOTE(review): this is not under the /content/drive mount — confirm
# the checkpoints do not need to outlive the runtime.
checkpoint_dir = 'tmp/checkpoints'
# '{epoch}' is left as a placeholder in the file name passed to ModelCheckpoint.
checkpoint_prefix = os.path.join(checkpoint_dir, 'LSTM_Model_{epoch}')
os.makedirs(checkpoint_dir, exist_ok=True)

# Save only the model weights (not the full model) at each checkpoint.
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    checkpoint_prefix,
    save_weights_only=True,
)

In [28]:
# Train for 20 epochs over the batched dataset, invoking the checkpoint
# callback during training; the returned object is kept in `history`.
training_callbacks = [checkpoint_callback]
history = model.fit(dataset, epochs=20, callbacks=training_callbacks)

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20
