We'll write a simple template for seq2seq using Tensorflow. For demonstration, we attack the g2p task. G2p is a task of converting graphemes (spelling) to phonemes (pronunciation). It's a very good source for this purpose as it's simple enough for you to up and run. If you want to know more about g2p, see my [repo](https://github.com/kyubyong/g2p)

In [1]:
__author__ = "kyubyong"
__address__ = "https://github.com/kyubyong/nlp_made_easy"
__email__ = "kbpark.linguist@gmail.com"

In [2]:
import numpy as np
import tensorflow as tf
from tqdm import tqdm
from distance import levenshtein
import os

In [3]:
tf.__version__

'1.12.0'

# Hyperparameters

In [37]:
params = dict()
params["batch_size"] = 128
params["test_batch_size"] = 128
params["enc_maxlen"] = 20
params["dec_maxlen"] = 20
params["num_epochs"] = 10
params["hidden_units"] = 128
params["graphemes"] = ["<PAD>", "<UNK>", "<EOS>"] + list("abcdefghijklmnopqrstuvwxyz")
params["phonemes"] = ["<PAD>", "<UNK>", "<BOS>", "<EOS>"] + ['AA0', 'AA1', 'AA2', 'AE0', 'AE1', 'AE2', 'AH0', 'AH1', 'AH2', 'AO0',
                'AO1', 'AO2', 'AW0', 'AW1', 'AW2', 'AY0', 'AY1', 'AY2', 'B', 'CH', 'D', 'DH',
                'EH0', 'EH1', 'EH2', 'ER0', 'ER1', 'ER2', 'EY0', 'EY1', 'EY2', 'F', 'G', 'HH',
                'IH0', 'IH1', 'IH2', 'IY0', 'IY1', 'IY2', 'JH', 'K', 'L', 'M', 'N', 'NG', 'OW0', 'OW1',
                'OW2', 'OY0', 'OY1', 'OY2', 'P', 'R', 'S', 'SH', 'T', 'TH', 'UH0', 'UH1', 'UH2', 'UW',
                'UW0', 'UW1', 'UW2', 'V', 'W', 'Y', 'Z', 'ZH']
params["lr"] = 0.001
params["eval_steps"] = 500
params["logdir"] = "logdir1"

# Prepare Data

In [38]:
import nltk
# nltk.download('cmudict') <- if you haven't downloaded, do this.
from nltk.corpus import cmudict
cmu = cmudict.dict()

In [39]:
def load_vocab():
    g2idx = {g: idx for idx, g in enumerate(params["graphemes"])}
    idx2g = {idx: g for idx, g in enumerate(params["graphemes"])}

    p2idx = {p: idx for idx, p in enumerate(params["phonemes"])}
    idx2p = {idx: p for idx, p in enumerate(params["phonemes"])}

    return g2idx, idx2g, p2idx, idx2p # note that g and p mean grapheme and phoneme, respectively.

In [40]:
def prepare_data():
    words = [" ".join(list(word)) for word, prons in cmu.items()]
    prons = [" ".join(prons[0]) for word, prons in cmu.items()]
    indices = list(range(len(words)))
    from random import shuffle
    shuffle(indices)
    words = [words[idx] for idx in indices]
    prons = [prons[idx] for idx in indices]
    num_train, num_test = int(len(words)*.8), int(len(words)*.1)
    train_words, eval_words, test_words = words[:num_train], \
                                          words[num_train:-num_test],\
                                          words[-num_test:]
    train_prons, eval_prons, test_prons = prons[:num_train], \
                                          prons[num_train:-num_test],\
                                          prons[-num_test:]    
    return train_words, eval_words, test_words, train_prons, eval_prons, test_prons

In [41]:
train_words, eval_words, test_words, train_prons, eval_prons, test_prons = prepare_data()
print(train_words[0])
print(train_prons[0])

s e l f - g o v e r n m e n t
S EH1 L F G AH1 V ER0 N M AH0 N T


In [42]:
def drop_lengthy_samples(words, prons, enc_maxlen, dec_maxlen):
    """We only include such samples less than maxlen."""
    _words, _prons = [], []
    for w, p in zip(words, prons):
        if len(w.split()) + 1 > enc_maxlen: continue
        if len(p.split()) + 1 > dec_maxlen: continue # 1: <EOS>
        _words.append(w)
        _prons.append(p)
    return _words, _prons

In [43]:
train_words, train_prons = drop_lengthy_samples(train_words, train_prons, params["enc_maxlen"], params["dec_maxlen"])
eval_words, eval_prons = drop_lengthy_samples(eval_words, eval_prons, params["enc_maxlen"], params["dec_maxlen"])
test_words, test_prons = drop_lengthy_samples(test_words, test_prons, params["enc_maxlen"], params["dec_maxlen"])

# Data Loader

In [44]:
def generator_fn(words, prons, padding=False, enc_maxlen=0, dec_maxlen=0):
    '''
    words: 1d byte array (when training) or list (when predicting). words. e.g., [b"w o r d", ]
    prons: 1d byte array (when training) or list (when predicting). prons. e.g., [b'W ER1 D', ]
    padding: boolean. If True, zeros's are padded such that the length becomes the maxlen.
    enc_maxlen: If padding is True, this must be not 0.
    dec_maxlen: If padding is True, this must be not 0.
    
    yields
    xs: tuple of
        x: list of encoded x. encoder input
        x_seqlen: scalar.
        word: string
        
    ys: tuple of
        decoder_input: list of decoder inputs
        y: list of encoded y. label.
        y_seqlen: scalar.
        pron: string
    '''
    g2idx, idx2g, p2idx, idx2p = load_vocab()
    for word, pron in zip(words, prons):
        w_str = word.decode("utf-8") if isinstance(word, (bytes)) else word
        p_str = pron.decode("utf-8") if isinstance(pron, (bytes)) else pron
        graphemes = w_str.split() + ["<EOS>"]
        phonemes = ["<BOS>"] + p_str.split() + ["<EOS>"]

        x = [g2idx.get(g, g2idx["<UNK>"]) for g in graphemes]
        y = [p2idx.get(p, p2idx["<UNK>"]) for p in phonemes]
        decoder_input, y = y[:-1], y[1:]
        
        x_seqlen, y_seqlen = len(x), len(y)
        if padding:
            x += [g2idx["<PAD>"]]*(enc_maxlen - len(x))
            decoder_input += [p2idx["<PAD>"]]*(dec_maxlen - len(decoder_input))
            y += [p2idx["<PAD>"]]*(dec_maxlen - len(y))

        yield (x, x_seqlen, word), (decoder_input, y, y_seqlen, pron)

In [45]:
def input_fn(words, prons, batch_size, shuffle=False):
    '''Batchify data
    words: list of words. e.g., ["word", ]
    prons: list of prons. e.g., ['W ER1 D',]
    batch_size: scalar.
    shuffle: boolean
    '''
    shapes = ( ([None], (), ()),
               ([None], [None], (), ())  )
    types = (  (tf.int32, tf.int32, tf.string),
               (tf.int32, tf.int32, tf.int32, tf.string)  )
    paddings = (  (0, 0, ''),
                  (0, 0, 0, '')  )

    dataset = tf.data.Dataset.from_generator(
        generator_fn,
        output_shapes=shapes,
        output_types=types,
        args=(words, prons)) # <- converted to np string arrays
        
    if shuffle:
        dataset = dataset.shuffle(64*batch_size)
    
    dataset = dataset.repeat() # iterate forever
    dataset = dataset.padded_batch(batch_size, shapes, paddings).prefetch(1)

    return dataset

# Model

In [46]:
tf.reset_default_graph()

In [47]:
class Net:
    def __init__(self, params):
        self.g2idx, self.idx2g, self.p2idx, self.idx2p = load_vocab()
        self.params = params
    
    def encode(self, xs):
        '''
        xs: tupple of 
            x: (N, T). int32
            seqlens: (N,). int32
            words: (N,). string
            
        returns
        last hidden: (N, hidden_units). float32
        words: (N,). string
        '''
        with tf.variable_scope("encode", reuse=tf.AUTO_REUSE):
            x, seqlens, words = xs
            x = tf.one_hot(x, len(self.g2idx))
            cell = tf.contrib.rnn.GRUCell(self.params["hidden_units"])
            _, last_hidden = tf.nn.dynamic_rnn(cell, x, seqlens, dtype=tf.float32)
            
        return last_hidden, words
        
    
    def decode(self, ys, h0=None):
        '''
        ys: tupple of 
            decoder_inputs: (N, T). int32
            y: (N, T). int32
            seqlens: (N,). int32
            prons: (N,). string.
        h0: initial hidden state. (N, hidden_units)
        
        returns
        logits: (N, T, len(p2idx)). float32. before softmax
        preds: (N, T). int32.
        y: (N, T). int32. label.
        prons: (N,). string. ground truth phonemes 
        last_hidden: (N, hidden_units). This is for autoregressive inference
        '''
        decoder_inputs, y, seqlens, prons = ys
       
        with tf.variable_scope("decode", reuse=tf.AUTO_REUSE):
            inputs = tf.one_hot(decoder_inputs, len(self.p2idx))
            
            cell = tf.contrib.rnn.GRUCell(self.params["hidden_units"])
            outputs, last_hidden = tf.nn.dynamic_rnn(cell, inputs, initial_state=h0, dtype=tf.float32)

            # projection
            logits = tf.layers.dense(outputs, len(self.p2idx))
            preds = tf.to_int32(tf.argmax(logits, axis=-1))
        
        return logits, preds, y, prons, last_hidden
            
    def train(self, xs, ys):
        # forward
        last_hidden, words = self.encode(xs)
        logits, preds, y, prons, last_hidden = self.decode(ys, h0=last_hidden)
        
        # train scheme
        ce = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=y)
        nonpadding = tf.to_float(tf.not_equal(y, self.p2idx["<PAD>"])) # 0: <pad>
        loss = tf.reduce_sum(ce*nonpadding) / (tf.reduce_sum(nonpadding)+1e-7)

        global_step = tf.train.get_or_create_global_step()
        train_op = tf.train.AdamOptimizer(params["lr"]).minimize(loss, global_step=global_step)
        
        return words, preds, y, prons, loss, train_op, global_step

    
    def infer(self):
        # inputs
        self.x = tf.placeholder(tf.int32, (None, None))
        self.x_seqlens = tf.placeholder(tf.int32, (None,))
        self.words = tf.placeholder(tf.string, (None,))

        decoder_inputs = tf.ones((tf.shape(self.x)[0], 1), tf.int32)*self.p2idx["<BOS>"]
        
        xs = (self.x, self.x_seqlens, self.words)
        ys = (decoder_inputs, None, None, None)
        
        last_hidden, words = self.encode(xs)
        h0 = last_hidden
        Preds = []
        for t in range(self.params["dec_maxlen"]):
            _, preds, _, _, h0 = self.decode(ys, h0)
            if tf.reduce_sum(preds, 1)==0: break
           
            ys = (preds, None, None, None)
            Preds.append(tf.squeeze(preds))
        Preds = tf.stack(Preds, 1)
        return Preds
  

# Train & Evaluate

In [48]:
def calc_num_batches(total_num, batch_size):
    return total_num // batch_size + int(total_num % batch_size != 0) 

In [63]:
# evaluation metric
def per(hyp, ref):
    '''Calc phoneme error rate
    hyp: list of predicted phoneme sequences. e.g., [["B", "L", "AA1", "K", "HH", "AW2", "S"], ...]
    ref: list of ground truth phoneme sequences. e.g., [["B", "L", "AA1", "K", "HH", "AW2", "S"], ...]
    '''
    num_phonemes, num_erros = 0, 0
    for h, r in zip(hyp, ref):
        num_phonemes += len(r)
        num_erros += levenshtein(h, r)
#         print(h, r, levenshtein(h, r), len(r))
    per = round(num_erros / num_phonemes, 2)
    return per

In [50]:
# prepare batches
train_batches = input_fn(train_words, train_prons,
                         params["batch_size"], shuffle=True)
num_train_batches = calc_num_batches(len(train_words), params["batch_size"])

eval_batches = input_fn(eval_words, eval_prons,
                        params["batch_size"], shuffle=False)
num_eval_batches = calc_num_batches(len(eval_words), params["batch_size"])

In [51]:
# create a iterator of the correct shape and type
iter = tf.data.Iterator.from_structure(train_batches.output_types, eval_batches.output_shapes)
xs, ys = iter.get_next()

# create the initialisation operations
train_init_op = iter.make_initializer(train_batches)
eval_init_op = iter.make_initializer(eval_batches)

In [52]:
# Training Session
net = Net(params)
words, preds, y, prons, loss, train_op, global_step = net.train(xs, ys)

saver = tf.train.Saver()
with tf.Session() as sess:
    if tf.train.checkpoint_exists(params["logdir"]):
        ckpt = tf.train.latest_checkpoint(params["logdir"])
        print("Restoring from file: ", ckpt)
        saver.restore(sess, ckpt)
    else:
        print("Initializing from scratch")
        sess.run(tf.global_variables_initializer())

    sess.run(train_init_op)
    for epoch in range(1, params["num_epochs"]+1):
        for _ in range(num_train_batches):
            # training
            _, _gs = sess.run([train_op, global_step])  
            
            # regular evaluation
            if _gs%params["eval_steps"]==0:
                _loss = sess.run(loss)

                sess.run(eval_init_op)
                hyp, ref = [], []                
                for _ in range(num_eval_batches):
                    _words, _preds, _y, _prons = sess.run([words, preds, y, prons]) 
                    hyp.extend(_preds.tolist())
                    ref.extend(_y.tolist())
                    
                ## logging
                _word = _words[0].decode('utf-8')
                _pron = _prons[0].decode('utf-8')
                _pred = " ".join(net.idx2p[each] for each in _preds[0])#.split("<EOS>")[0]
                
                _per = per(hyp, ref)
                
                print("="*10, "epoch=", epoch, "global step=", _gs, "="*10)
                print("train loss= %.2f | eval error rate=%.2f" % (_loss, _per))
                print("wrd:", _word)
                print("exp:", _pron)
                print("got:", _pred)
                print()
                
                # save
                if not os.path.exists(params["logdir"]): os.mkdir(params["logdir"])
                fname = os.path.join(params["logdir"], "my_model_loss_%.2f_per_%.2f" % (_loss, _per))
                saver.save(sess, fname, global_step=_gs)
    print("Training Done!")

Initializing from scratch
train loss= 4.29 | eval error rate=0.99
wrd: u p b e a t
exp: AH1 P B IY2 T
got: F ZH N N ZH R AE1 AE1 AE1 AE1 AE1 AE1 AE1 AE1 AE1 AE1 AE1 AE1

train loss= 1.88 | eval error rate=0.74
wrd: u p b e a t
exp: AH1 P B IY2 T
got: K K AH0 AH0 K <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS>

train loss= 1.27 | eval error rate=0.66
wrd: u p b e a t
exp: AH1 P B IY2 T
got: P P B AH0 T <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS>

train loss= 0.98 | eval error rate=0.62
wrd: u p b e a t
exp: AH1 P B IY2 T
got: P P B AH0 T <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> UW UW UW

train loss= 0.80 | eval error rate=0.60
wrd: u p b e a t
exp: AH1 P B IY2 T
got: AH0 P B EY2 T <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> Z Z

train loss= 0.72 | eval error rate=0.59
wrd: u p b e a t
exp: AH1 P B IY2 T
got: AH0 P B EY2 T <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> <EOS> T T T 

# Inference

In [64]:
tf.reset_default_graph()

# prepare batches
test_batches = list(generator_fn(test_words, test_prons, True, params["enc_maxlen"], params["dec_maxlen"]))
num_test_batches = calc_num_batches(len(test_words), params["test_batch_size"])

net = Net(params)
preds = net.infer()

# saver for restoration
# saver = tf.train.import_meta_graph(mname + ".meta") <- Do NOT use this as we'll use a distinct graph.
saver = tf.train.Saver()
   
with tf.Session() as sess:
    ckpt = tf.train.latest_checkpoint(params["logdir"])
    saver.restore(sess, ckpt); print("checkpoint restored") 
    
    hyp = []  
    for i in tqdm(range(num_test_batches)):
        batch = test_batches[i*params["test_batch_size"] : (i+1)*params["test_batch_size"]]
        x = [xs[0] for xs, _ in batch]
        x_seqlens = [xs[1] for xs, _ in batch]
        words = [xs[2] for xs, _ in batch]
        
        feed_dict = {net.x: x, 
                     net.x_seqlens: x_seqlens,
                     net.words: words}
        _preds = sess.run(preds, feed_dict)
        hyp.extend(_preds.tolist())
    
    ## evaluation
    _hyp = []
    for phonemes in hyp:
        each = []
        for idx in phonemes:
            if idx == net.p2idx["<EOS>"]: break
            each.append(net.idx2p[idx])
        _hyp.append(each)
    
    ref = [pron.split() for pron in test_prons]
    _per = per(_hyp, ref)
    
    # save
    with open("result", 'w') as fout:
        fout.write("per: %.2f\n" % _per)
        for w, r, h in zip(test_words, ref, _hyp):
            w = w.replace(" ", "")
            r = " ".join(r)
            h = " ".join(h)
            fout.write("wrd: {}\nexp: {}\ngot: {}\n\n".format(w, r, h))
            
    print("per:", _per)
    print("Done!")

INFO:tensorflow:Restoring parameters from logdir1/my_model_loss_0.25_per_0.53-7500


  0%|          | 0/97 [00:00<?, ?it/s]

checkpoint restored


100%|██████████| 97/97 [00:02<00:00, 33.86it/s]


per: 0.32
Done!


Let's see some results.

In [65]:
open('result', 'r').read().splitlines()[-100:]

['wrd: pledges',
 'exp: P L EH1 JH IH0 Z',
 'got: P L EH1 JH IH0 Z',
 '',
 'wrd: combe',
 'exp: K OW1 M',
 'got: K AA1 M B',
 '',
 'wrd: suspicions',
 'exp: S AH0 S P IH1 SH AH0 N Z',
 'got: S AH0 S P IY0 OW1 N IH0 S',
 '',
 "wrd: fargo's",
 'exp: F AA1 R G OW2 Z',
 'got: F AA1 R G OW0 Z',
 '',
 'wrd: fizzles',
 'exp: F IH1 Z AH0 L Z',
 'got: F EH1 Z AH0 L Z',
 '',
 'wrd: halon',
 'exp: HH EY1 L AA2 N',
 'got: HH AE1 L AH0 N',
 '',
 'wrd: snydergeneral',
 'exp: S N AY2 D ER0 JH EH1 N ER0 AH0 L',
 'got: S N D EH1 G R AH0 N M AY0 L IY0',
 '',
 'wrd: ero',
 'exp: IH1 R OW0',
 'got: EH1 R OW0',
 '',
 'wrd: brockett',
 'exp: B R AA1 K IH0 T',
 'got: B R AA1 K IH0 T',
 '',
 'wrd: sirna',
 'exp: S ER1 N AH0',
 'got: S ER1 N AH0',
 '',
 'wrd: reuss',
 'exp: R UW1 S',
 'got: R UW1 S',
 '',
 'wrd: saint',
 'exp: S EY1 N T',
 'got: S AE1 N T',
 '',
 'wrd: natividad',
 'exp: N AH0 T IH0 V IH0 D AA1 D',
 'got: N EY1 T V IH0 D V IH0 N D',
 '',
 'wrd: jarman',
 'exp: JH AA1 R M AH0 N',
 'got: JH AA1 