# Keras Seq2Seq Formality Transfer Model

References:
- https://blog.keras.io/a-ten-minute-introduction-to-sequence-to-sequence-learning-in-keras.html
- https://github.com/lukas/ml-class/blob/master/videos/seq2seq/train.py

## TODO

- Experiment with using GloVe word embeddings rather than one-hot character encoding
- Experiment with one-hot n-gram embeddings as well
- Experiment with LSTM/GRU effectiveness

In [None]:
from keras.layers import GRU, TimeDistributed, RepeatVector, Dense
from keras.models import Sequential
from keras.callbacks import Callback, ModelCheckpoint

from IPython import display
%matplotlib inline
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker

import numpy as np
import time
import os

## Load training data

In [None]:
# Data-set configuration.
num_samples = 1024  # maximum number of (input, target) pairs to load
train_prop = 0.9  # fraction of the pairs used for training; the rest is held out
checkpoints_dir = os.path.abspath("keras-checkpoints")

# labelled.txt stores the corpus as alternating lines: each even-indexed line
# is the input and the line right after it is the matching target.
# The encoding is pinned so the result does not depend on the platform locale.
with open("labelled.txt", "r", encoding="utf-8") as data_file:
    raw_data = data_file.read()
lines = [x.strip() for x in raw_data.split("\n")]

# Longest line in the whole file (not only the sampled pairs); every sequence
# is later padded to this many time steps.
max_length = max(len(i) for i in lines)

# Stop at len(lines) - 1 so lines[i + 1] always exists.  A file that ends with
# a newline yields an odd number of entries after split("\n"), which used to
# raise IndexError on the final, unpaired element.
pairs = [
    (lines[i], lines[i + 1])
    for i in range(0, min(len(lines) - 1, num_samples * 2), 2)
]

class CharacterTable:
    """Bidirectional mapping between characters and one-hot rows.

    Given a set of characters, this table can:
    + encode a string into a one-hot matrix (one row per character),
    + decode a one-hot / probability matrix back into a string,
    + decode a plain sequence of integer indices back into a string.
    """

    def __init__(self, chars):
        """Build the lookup tables.

        # Arguments
            chars: Iterable of characters that can appear in the input.
                Duplicates are dropped and the remainder is sorted, so the
                index assigned to each character is deterministic.
        """
        self.chars = sorted(set(chars))
        self.char_indices = {char: index for index, char in enumerate(self.chars)}
        self.indices_char = {index: char for index, char in enumerate(self.chars)}

    def encode(self, C, num_rows):
        """One-hot encode the string C.

        # Arguments
            C: String to encode; every character must be in the table.
            num_rows: Number of rows in the returned matrix. Rows past
                len(C) stay all-zero, which keeps every sample the same size.

        # Returns
            Array of shape (num_rows, len(self.chars)).
        """
        one_hot = np.zeros((num_rows, len(self.chars)))
        for position, char in enumerate(C):
            column = self.char_indices[char]
            one_hot[position, column] = 1
        return one_hot

    def decode(self, x, calc_argmax=True):
        """Turn a matrix of scores (or a sequence of indices) into text.

        # Arguments
            x: Per-position score rows when calc_argmax is true, otherwise
                an iterable of character indices.
            calc_argmax: Whether to reduce x along its last axis with argmax
                before looking the characters up.

        # Returns
            The decoded string with surrounding whitespace stripped.
        """
        indices = x.argmax(axis=-1) if calc_argmax else x
        decoded = [self.indices_char[index] for index in indices]
        return ''.join(decoded).strip()

# Vocabulary: every character in the corpus except the newline separator,
# plus an explicit NUL character.
ctable = CharacterTable(raw_data.replace("\n", "") + "\x00")

# One one-hot matrix per pair, all padded to max_length time steps.
tensor_shape = (len(pairs), max_length, len(ctable.chars))
X = np.zeros(tensor_shape)
Y = np.zeros(tensor_shape)

for row, (source_text, target_text) in enumerate(pairs):
    X[row] = ctable.encode(source_text, max_length)
    Y[row] = ctable.encode(target_text, max_length)

# The first train_prop share of the pairs is used for training and the
# remainder is held out; the data is split in file order, without shuffling.
train_num = int(train_prop * len(pairs))

X_train, Y_train = X[:train_num], Y[:train_num]
X_test, Y_test = X[train_num:], Y[train_num:]

## Define Model

In [None]:
hidden_size = 256

# Encoder/decoder stack:
#   1. a GRU reads the one-hot input sequence and emits a single vector,
#   2. RepeatVector copies that vector once per output time step,
#   3. a second GRU unrolls it into a sequence,
#   4. a per-time-step softmax picks a character from the vocabulary.
# Maybe replace GRUs with LSTMs (better performance but slower)
model = Sequential([
    GRU(hidden_size, input_shape=(max_length, len(ctable.chars))),
    RepeatVector(max_length),
    GRU(hidden_size, return_sequences=True),
    TimeDistributed(Dense(len(ctable.chars), activation="softmax")),
])

model.compile(
    optimizer="adam",
    loss="categorical_crossentropy",
    metrics=["accuracy"],
)

model.summary()

## Configure Graph

In [None]:
class Plotter(Callback):
    """Keras callback that redraws a loss-vs-epoch chart after every epoch.

    The figure is pushed to the notebook once on construction under a unique
    display id and then updated in place, so a single chart keeps refreshing
    instead of a new one being printed per epoch.
    """

    # Class-wide counter used to give every instance its own display id.
    id_counter = 0

    def __init__(self):
        """Create the figure, style its axes and show it (still empty)."""
        super().__init__()
        self.id = f"plot-{Plotter.id_counter}"
        Plotter.id_counter += 1
        self.fig, self.ax = plt.subplots(figsize=(24, 6), dpi=80)
        self.ax.set_xlabel("Epoch")
        self.ax.set_ylabel("Loss")
        self.ax.xaxis.set_major_locator(ticker.MultipleLocator(base=10.0))
        self.ax.xaxis.set_minor_locator(ticker.MultipleLocator(base=1.0))
        self.ax.yaxis.set_major_locator(ticker.MultipleLocator(base=0.1))
        self.ax.yaxis.set_minor_locator(ticker.MultipleLocator(base=0.02))
        self.ax.grid(which="major", color="#888888")
        self.ax.grid(which="minor", color="#bbbbbb")

        self.fig.patch.set_facecolor("white")
        # Shrink the axes to 80% width to leave room for the legend, which is
        # anchored outside the plot area on the right.
        box = self.ax.get_position()
        self.ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])
        display.display(self.fig, display_id=self.id)

        self.epochs = []
        self.train_losses = []
        self.val_losses = []

    def on_epoch_end(self, epoch, logs=None):
        """Record this epoch's losses and refresh the displayed figure.

        # Arguments
            epoch: Index of the epoch that just finished.
            logs: Metrics dict supplied by Keras. A missing "loss" or
                "val_loss" entry is recorded as NaN (leaving a gap in that
                curve) instead of raising.
        """
        logs = logs or {}

        # Iterate over a snapshot: removing an artist mutates the axes' line
        # collection, and removing while iterating it directly can skip
        # entries and leave stale curves behind.
        for old_line in list(self.ax.lines):
            old_line.remove()

        self.epochs.append(epoch)
        self.train_losses.append(logs.get("loss", float("nan")))
        self.val_losses.append(logs.get("val_loss", float("nan")))

        t_line, = self.ax.plot(self.epochs, self.train_losses, c="#55CDFC")
        v_line, = self.ax.plot(self.epochs, self.val_losses, c="#F7A8B8")
        self.ax.autoscale()
        # Pin the lower y bound at zero; the upper bound stays automatic.
        self.ax.set_ylim(0, None)
        self.ax.legend(
            [t_line, v_line],
            ["Training", "Validation"],
            loc='center left',
            bbox_to_anchor=(1, 0.5),
        )
        display.update_display(self.fig, display_id=self.id)

## Delete Existing Checkpoints &mdash; WARNING: permanently deletes every file in the checkpoints directory


In [None]:
# Make sure the checkpoint directory exists, then delete everything in it so
# the coming training run starts from a clean slate.
os.makedirs(checkpoints_dir, exist_ok=True)

stale_paths = [
    os.path.join(checkpoints_dir, entry_name)
    for entry_name in os.listdir(checkpoints_dir)
]
for stale_path in stale_paths:
    os.unlink(stale_path)

## Train

In [None]:
# Training hyperparameters.
batch_size = 1
epochs = 100

# Fixed sentence that is re-translated during training to eyeball progress;
# it is encoded once, as a batch containing a single sample.
demo_str = "You should seek advice from a medical professional."
demo_vec = np.expand_dims(ctable.encode(demo_str, max_length), axis=0)

# Write weights-only checkpoints, and only when the training loss improves.
# The epoch number and loss are substituted into the file name.
filepath = os.path.join(
    checkpoints_dir,
    "formal2casual-{epoch:02d}-{loss:.4f}.ckpt",
)
checkpoint = ModelCheckpoint(
    filepath,
    monitor="loss",
    mode="min",
    save_best_only=True,
    save_weights_only=True,
    verbose=0,
)

class PrintDemo(Callback):
    """Callback that prints the model's translation of the demo sentence."""

    def on_epoch_begin(self, epoch, logs=None):
        """Print the current prediction for demo_str before each epoch.

        Nothing is printed before the very first epoch (epoch 0).
        """
        if epoch != 0:
            prediction = self.model.predict(demo_vec, verbose=0)
            # The batch holds a single sample; decode that one sequence.
            pred_str = ctable.decode(prediction[0])

            print(f"\"{demo_str}\" -> \"{pred_str}\"\n")


# Train on the training split and validate on the held-out split after every
# epoch.  The callbacks save improving weights, print a sample translation
# and keep the loss chart up to date.
model.fit(
    x=X_train,
    y=Y_train,
    validation_data=(X_test, Y_test),
    epochs=epochs,
    batch_size=batch_size,
    callbacks=[checkpoint, PrintDemo(), Plotter()],
)

    

## Load Checkpoint

In [None]:
def _latest_checkpoint(directory):
    """Return the most recently written checkpoint in *directory*, or None.

    `tf` is never imported in this notebook, so the lookup is done with the
    standard library instead of `tf.train.latest_checkpoint`.  A checkpoint
    may be a single file or a TensorFlow-format group of files
    (`<prefix>.index` plus `<prefix>.data-*`); the group is collapsed to its
    shared prefix, which is the form `load_weights` expects.
    """
    newest_mtime = {}
    for name in os.listdir(directory):
        # "checkpoint" is TensorFlow's bookkeeping file, not a set of weights.
        if name == "checkpoint":
            continue
        path = os.path.join(directory, name)
        if not os.path.isfile(path):
            continue
        if name.endswith(".index"):
            prefix = path[: -len(".index")]
        elif ".data-" in name:
            prefix = path[: path.rindex(".data-")]
        else:
            prefix = path
        modified = os.path.getmtime(path)
        newest_mtime[prefix] = max(newest_mtime.get(prefix, 0.0), modified)
    if not newest_mtime:
        return None
    return max(newest_mtime, key=newest_mtime.get)


latest = _latest_checkpoint(checkpoints_dir)
if latest is None:
    # Fail with a clear message rather than passing None to load_weights.
    raise FileNotFoundError(f"No checkpoints found in {checkpoints_dir}")
model.load_weights(latest)

## Test

In [None]:
# Translate a single sentence with the trained model and time the round trip
# (encoding + prediction + decoding).
test_str = "You should seek advice from a medical professional."
# perf_counter is a monotonic, high-resolution clock meant for measuring
# intervals.
start = time.perf_counter()
# Encode test_str itself.  The cell previously encoded demo_str from the
# training cell, so editing test_str had no effect on the prediction, and the
# cell failed with NameError whenever the training cell had not been run.
test_vec = np.array([ctable.encode(test_str, max_length)])

pred_vec = model.predict(test_vec, verbose=0)
pred_str = ctable.decode(pred_vec[0])
elapsed = time.perf_counter() - start
print(test_str)
print("->")
print(pred_str)
# ".5f" gives five decimal places; the earlier "5f" only set a minimum width.
print(f"Took {elapsed:.5f} seconds")