In [None]:
import sys
import numpy as np
from typing import Tuple
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import re
import string

In [None]:
def load_file(file_path: str) -> Tuple[Tuple[str, ...], Tuple[str, ...]]:
    """Load a file of ``factor=expansion`` lines into two parallel tuples.

    Each non-blank line is split on its first ``=``: the left-hand side is the
    model input and the right-hand side is the ground truth. Surrounding
    whitespace is stripped from every line and blank lines are skipped.

    :param file_path: path to the data file
    :return factors: (LHS) inputs to the model
            expansions: (RHS) ground truth
            Both tuples are empty when the file contains no data lines.
    :raises ValueError: if a non-blank line does not contain ``=``
    """
    factors = []
    expansions = []
    # The context manager guarantees the handle is closed even on error.
    with open(file_path, "r") as data_file:
        for line_number, raw_line in enumerate(data_file, start=1):
            line = raw_line.strip()
            if not line:
                # Tolerate blank lines (e.g. a trailing newline at end of file).
                continue
            factor, separator, expansion = line.partition("=")
            if not separator:
                raise ValueError(
                    f"{file_path}:{line_number}: expected 'factor=expansion', "
                    f"got {line!r}"
                )
            factors.append(factor)
            expansions.append(expansion)
    return tuple(factors), tuple(expansions)

In [None]:
# Keep each factored expression together with its expansion as one
# (factored, expanded) tuple, so that downstream code can index pair[0] and
# pair[1] and so that slicing text_pairs never separates an input from its
# target.
factors, expansions = load_file("train.txt")
text_pairs = list(zip(factors, expansions))

In [None]:
# Split text_pairs by position: the leading block is the training set, the
# next 15% is the validation set, and whatever remains is the test set.
num_val_samples = int(0.15 * len(text_pairs))
num_train_samples = len(text_pairs) - 2 * num_val_samples
val_end = num_train_samples + num_val_samples
train_pairs, val_pairs, test_pairs = (
    text_pairs[:num_train_samples],
    text_pairs[num_train_samples:val_end],
    text_pairs[val_end:],
)

In [None]:
vocab_size = 4000
sequence_length = 29

# Settings shared by the source and target vectorizers.
# NOTE(review): TextVectorization's default standardization lower-cases and
# strips punctuation, and the default split is on whitespace. If the data are
# algebraic expressions, operators and parentheses would be dropped - confirm
# this is intended.
vectorizer_options = dict(max_tokens=vocab_size, output_mode="int")

source_vectorization = layers.TextVectorization(
    output_sequence_length=sequence_length, **vectorizer_options
)
# The target is one token longer because format_dataset slices one token off
# each end to build the decoder input and the labels.
target_vectorization = layers.TextVectorization(
    output_sequence_length=sequence_length + 1, **vectorizer_options
)

# Assumes every element of train_pairs is indexable as (factored, expanded) -
# TODO confirm against how text_pairs is built.
train_factored_texts = []
train_expansion_texts = []
for pair in train_pairs:
    train_factored_texts.append(pair[0])
    train_expansion_texts.append(pair[1])

# Learn each vocabulary from the training split only.
source_vectorization.adapt(train_factored_texts)
target_vectorization.adapt(train_expansion_texts)

In [None]:
# Number of examples per batch; read by make_dataset when batching.
batch_size = 12

def format_dataset(factored, expanded):
    """Vectorize a batch of string pairs into model inputs and labels.

    :param factored: batch of source strings, passed to source_vectorization
    :param expanded: batch of target strings, passed to target_vectorization
    :return: ``(inputs, labels)`` where ``inputs`` is a dict keyed by the
             model's Input layer names ("Factors" and "Expansions") and
             ``labels`` is the vectorized target shifted one step ahead of
             the "Expansions" input.
    """
    factored = source_vectorization(factored)
    expanded = target_vectorization(expanded)
    # The dict keys must equal the names given to keras.Input, otherwise the
    # tensors cannot be matched to the right model inputs by name.
    model_inputs = {
        "Factors": factored,
        # Decoder input: every target token except the last one.
        "Expansions": expanded[:, :-1],
    }
    # Labels: every target token except the first, i.e. the next token for
    # each decoder input position.
    return model_inputs, expanded[:, 1:]

def make_dataset(pairs):
    """Build a batched, vectorized tf.data.Dataset from ``pairs``.

    :param pairs: sequence whose elements are indexable as
                  ``(factored, expanded)``; element 0 is the source text and
                  element 1 is the target text
    :return: dataset yielding the ``(inputs, labels)`` batches produced by
             format_dataset, ``batch_size`` examples at a time
    """
    # Use the split that was passed in. Reading the global training texts
    # here would make every dataset (including validation) a copy of the
    # training data.
    factored_texts = [pair[0] for pair in pairs]
    expanded_texts = [pair[1] for pair in pairs]
    dataset = tf.data.Dataset.from_tensor_slices((factored_texts, expanded_texts))
    dataset = dataset.batch(batch_size)
    # Vectorize after batching so the vectorizers run on whole batches.
    dataset = dataset.map(format_dataset, num_parallel_calls=4)
    return dataset

# Build the training and validation datasets from their respective splits.
train_ds, val_ds = (make_dataset(split) for split in (train_pairs, val_pairs))

In [None]:
embed_dim = 50
latent_dim = 400

# Encoder: embed the source token ids (id 0 is masked) and summarize the
# sequence with a bidirectional GRU whose two directions are summed.
source = keras.Input(shape=(None,), dtype="int64", name="Factors")
source_embedding = layers.Embedding(vocab_size, embed_dim, mask_zero=True)
encoder_gru = layers.Bidirectional(layers.GRU(latent_dim), merge_mode="sum")
x = source_embedding(source)
encoded_source = encoder_gru(x)

In [None]:
# Decoder: embed the target tokens seen so far, run a GRU initialized with
# the encoder output, and emit a softmax over the vocabulary at every step.
past_target = keras.Input(shape=(None,), dtype="int64", name="Expansions")
target_embedding = layers.Embedding(vocab_size, embed_dim, mask_zero=True)
decoder_gru = layers.GRU(latent_dim, return_sequences=True)
output_dropout = layers.Dropout(0.5)
output_projection = layers.Dense(vocab_size, activation="softmax")

x = target_embedding(past_target)
x = decoder_gru(x, initial_state=encoded_source)
x = output_dropout(x)
target_next_step = output_projection(x)

seq2seq_rnn = keras.Model(inputs=[source, past_target], outputs=target_next_step)

In [None]:
# Configure training: RMSprop optimizer, sparse categorical cross-entropy
# (labels are integer token ids, not one-hot vectors), and per-token accuracy.
seq2seq_rnn.compile(
    optimizer="rmsprop",
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"])

In [None]:
# Print the layer-by-layer architecture and parameter counts.
seq2seq_rnn.summary()

Model: "model_6"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 Factors (InputLayer)           [(None, None)]       0           []                               
                                                                                                  
 Expansions (InputLayer)        [(None, None)]       0           []                               
                                                                                                  
 embedding_12 (Embedding)       (None, None, 50)     200000      ['Factors[0][0]']                
                                                                                                  
 embedding_13 (Embedding)       (None, None, 50)     200000      ['Expansions[0][0]']             
                                                                                            

In [None]:
# Train for one epoch, evaluating on val_ds at the end of the epoch.
# NOTE(review): the recorded output for this cell shows "Epoch 1/7", which
# does not match epochs=1, and reports accuracy 1.0000 early in the first
# epoch - the output looks stale or degenerate; re-run and verify.
seq2seq_rnn.fit(train_ds, epochs=1, validation_data=val_ds)

Epoch 1/7
 1888/21875 [=>............................] - ETA: 1:46:52 - loss: 5.3057e-06 - accuracy: 1.0000