We'll write a simple seq2seq template using TensorFlow. For demonstration, we tackle the g2p task. G2p is the task of converting graphemes (spelling) to phonemes (pronunciation). It's a very good fit for this purpose, as it's simple enough for you to get up and running quickly. If you want to know more about g2p, see my [repo](https://github.com/kyubyong/g2p).

In [1]:
__author__ = "kyubyong"
__address__ = "https://github.com/kyubyong/nlp_made_easy"
__email__ = "kbpark.linguist@gmail.com"

In [2]:
import numpy as np
import tensorflow as tf
from tqdm import tqdm_notebook as tqdm
from distance import levenshtein
import os
import math

In [3]:
tf.__version__  # this notebook targets TensorFlow 1.x (it uses tf.contrib, tf.py_func and tf.Session)

'1.12.0'

# Hyperparameters

In [65]:
class Hparams:
    '''Hyperparameters and vocabularies shared by every cell below (accessed through `hp`).'''
    batch_size = 128
    # Max encoder length, counting the "</s>" that `encode` appends; longer training words are dropped.
    enc_maxlen = 20
    # Max decoder length, counting the "<s>"/"</s>" boundary; also the number of greedy decoding steps in `Net.eval`.
    dec_maxlen = 20
    num_epochs = 10
    # GRU state size, used for both the encoder and the decoder cell.
    hidden_units = 128
    # "<pad>" is index 0, matching the padding value 0 used by `padded_batch`.
    # Characters outside a-z map to "<unk>" when encoded.
    graphemes = ["<pad>", "<unk>", "</s>"] + list("abcdefghijklmnopqrstuvwxyz")
    # The position of each phoneme in this list is its class index (one-hot input and softmax output),
    # so reordering it changes what a saved checkpoint's outputs mean.
    phonemes = ["<pad>", "<unk>", "<s>", "</s>"] + ['AA0', 'AA1', 'AA2', 'AE0', 'AE1', 'AE2', 'AH0', 'AH1', 'AH2', 'AO0',
                    'AO1', 'AO2', 'AW0', 'AW1', 'AW2', 'AY0', 'AY1', 'AY2', 'B', 'CH', 'D', 'DH',
                    'EH0', 'EH1', 'EH2', 'ER0', 'ER1', 'ER2', 'EY0', 'EY1', 'EY2', 'F', 'G', 'HH',
                    'IH0', 'IH1', 'IH2', 'IY0', 'IY1', 'IY2', 'JH', 'K', 'L', 'M', 'N', 'NG', 'OW0', 'OW1',
                    'OW2', 'OY0', 'OY1', 'OY2', 'P', 'R', 'S', 'SH', 'T', 'TH', 'UH0', 'UH1', 'UH2', 'UW',
                    'UW0', 'UW1', 'UW2', 'V', 'W', 'Y', 'Z', 'ZH']
    # Adam learning rate.
    lr = 0.001
    # NOTE(review): not referenced by any cell visible here; evaluation actually runs once per epoch.
    eval_steps = 500
    # Checkpoints are saved to and restored from this directory.
    logdir = "log/04"
hp = Hparams()

# Prepare Data

In [66]:
import nltk
# nltk.download('cmudict')# <- if you haven't downloaded, do this.
from nltk.corpus import cmudict
# Maps each word to its list of pronunciations (each a list of phoneme strings); consumed by `prepare_data`.
cmu = cmudict.dict()

In [67]:
def load_vocab():
    '''Builds forward and reverse lookup tables for the vocabularies in `hp`.

    Indices follow the order of `hp.graphemes` and `hp.phonemes`.

    Returns
    g2idx: grapheme -> index
    idx2g: index -> grapheme
    p2idx: phoneme -> index
    idx2p: index -> phoneme
    '''
    def _tables(tokens):
        # One pass builds index -> token; the reverse map is derived from it.
        idx2token = dict(enumerate(tokens))
        token2idx = {token: idx for idx, token in idx2token.items()}
        return token2idx, idx2token

    g2idx, idx2g = _tables(hp.graphemes)
    p2idx, idx2p = _tables(hp.phonemes)
    return g2idx, idx2g, p2idx, idx2p

In [68]:
def prepare_data(seed=None, pron_dict=None):
    '''Shuffles the pronunciation dictionary and splits it 80% / 10% / 10% into train / eval / test.

    seed: optional int. If given, the shuffle uses its own `random.Random(seed)`, so the split
        is reproducible. If None (the default), the global `random` module is used.
    pron_dict: optional mapping word -> list of pronunciations (each a list of phoneme strings).
        Defaults to the CMU dictionary loaded into the global `cmu`. Only the first
        pronunciation of each word is used.

    Returns
    train_words, eval_words, test_words: lists of space-separated graphemes, e.g. "w o r d".
    train_prons, eval_prons, test_prons: the matching space-separated phonemes, e.g. "W ER1 D".
    '''
    import random

    if pron_dict is None:
        pron_dict = cmu
    # Keep each word with its pronunciation so a single shuffle keeps them aligned.
    pairs = [(" ".join(word), " ".join(word_prons[0])) for word, word_prons in pron_dict.items()]
    rng = random if seed is None else random.Random(seed)
    rng.shuffle(pairs)
    words = [word for word, _ in pairs]
    prons = [pron for _, pron in pairs]

    num_train, num_test = int(len(words)*.8), int(len(words)*.1)
    # Explicit boundary instead of `-num_test`: when num_test == 0, `[num_train:-0]` is empty and
    # `[-0:]` is the WHOLE list, which would copy every training word into the test set.
    eval_end = len(words) - num_test
    train_words, eval_words, test_words = words[:num_train], words[num_train:eval_end], words[eval_end:]
    train_prons, eval_prons, test_prons = prons[:num_train], prons[num_train:eval_end], prons[eval_end:]
    return train_words, eval_words, test_words, train_prons, eval_prons, test_prons

In [69]:
train_words, eval_words, test_words, train_prons, eval_prons, test_prons = prepare_data()
# Peek at one (randomly shuffled) training pair: space-separated graphemes, then phonemes.
print(train_words[0])
print(train_prons[0])

q u a l i t a t i v e
K W AA1 L AH0 T EY2 T IH0 V


In [70]:
def drop_lengthy_samples(words, prons, enc_maxlen, dec_maxlen):
    """Keeps only the (word, pron) pairs that fit the model's maximum lengths.

    A pair survives when its token count plus one boundary token (the "</s>" on the
    encoder side, "<s>"/"</s>" on the decoder side) is at most enc_maxlen / dec_maxlen.

    Returns two parallel lists: the kept words and the kept prons.
    """
    def fits(text, maxlen):
        return len(text.split()) + 1 <= maxlen

    kept = [(w, p) for w, p in zip(words, prons)
            if fits(w, enc_maxlen) and fits(p, dec_maxlen)]
    return [w for w, _ in kept], [p for _, p in kept]

In [71]:
# Drop training pairs that would not fit enc_maxlen / dec_maxlen once the boundary token is added.
train_words, train_prons = drop_lengthy_samples(train_words, train_prons, hp.enc_maxlen, hp.dec_maxlen)
# We do NOT apply this constraint to eval and test datasets.

# Data Loader

In [72]:
def encode(inp, type, dict):
    '''Maps a space-separated token string to a list of vocabulary indices.

    inp: bytes or str. Space-separated tokens, e.g. b"w o r d" or "W ER1 D".
        `tf.data.Dataset.from_generator` passes numpy bytes, which are decoded as UTF-8;
        a plain str is used as-is.
    type: "x" for encoder input (tokens + "</s>"); anything else is treated as
        decoder-side ("<s>" + tokens + "</s>").
    dict: token -> index mapping. Must contain "<unk>" (used for out-of-vocabulary
        tokens) and the boundary tokens needed for the chosen type.

    Returns
    list of int indices.
    '''
    # np.bytes_ subclasses bytes, so generator input is still decoded.
    inp_str = inp.decode("utf-8") if isinstance(inp, bytes) else inp
    if type=="x": tokens = inp_str.split() + ["</s>"]
    else: tokens = ["<s>"] + inp_str.split() + ["</s>"]

    unk = dict["<unk>"]
    return [dict.get(t, unk) for t in tokens]
    

In [73]:
def generator_fn(words, prons):
    '''Yields one encoded (inputs, targets) pair per word for tf.data.

    words: 1d byte array. e.g., [b"w o r d", ]
    prons: 1d byte array. e.g., [b'W ER1 D', ]

    yields
    xs: tuple of
        x: list of encoded x. encoder input (ends with "</s>")
        x_seqlen: scalar. len(x)
        word: the raw word, unchanged
    ys: tuple of
        decoder_input: list of decoder inputs ("<s>" + phonemes)
        y: list of encoded y. label (phonemes + "</s>")
        y_seqlen: scalar. len(y)
        pron: the raw pronunciation, unchanged
    '''
    g2idx, _, p2idx, _ = load_vocab()
    for word, pron in zip(words, prons):
        x = encode(word, "x", g2idx)
        y_full = encode(pron, "y", p2idx)  # "<s>" ... "</s>"
        # Teacher forcing: the input is the target shifted right by one step.
        decoder_input = y_full[:-1]
        y = y_full[1:]
        yield (x, len(x), word), (decoder_input, y, len(y), pron)
        

In [74]:
def input_fn(words, prons, batch_size, shuffle=False):
    '''Builds an endlessly repeating dataset of padded mini-batches.

    words: list of words. e.g., ["w o r d", ]
    prons: list of prons. e.g., ['W ER1 D',]
    batch_size: scalar.
    shuffle: boolean. Shuffle with a buffer of 128 * batch_size samples.

    Because `repeat()` comes before batching, the last batch of a pass wraps around
    into the start of the next pass.
    '''
    # Structure mirrors generator_fn: (x, x_seqlen, word), (decoder_input, y, y_seqlen, pron)
    output_shapes = (([None], (), ()),
                     ([None], [None], (), ()))
    output_types = ((tf.int32, tf.int32, tf.string),
                    (tf.int32, tf.int32, tf.int32, tf.string))
    padding_values = ((0, 0, ''),
                      (0, 0, 0, ''))

    dataset = tf.data.Dataset.from_generator(
        generator_fn,
        output_types=output_types,
        output_shapes=output_shapes,
        args=(words, prons))  # tf.data hands these to the generator as numpy byte-string arrays

    if shuffle:
        dataset = dataset.shuffle(128 * batch_size)

    return (dataset
            .repeat()  # iterate forever; callers decide how many batches make an epoch
            .padded_batch(batch_size, output_shapes, padding_values)
            .prefetch(1))

In [75]:
def get_batch(words, prons, batch_size, shuffle=False):
    '''Gets training / evaluation mini-batches.

    words: list of space-separated grapheme strings, e.g. ["w o r d", ].
    prons: list of space-separated phoneme strings, e.g. ['W ER1 D', ].
    batch_size: scalar.
    shuffle: boolean. Whether to shuffle samples (used for training).

    Returns
    batches: tf.data.Dataset of padded mini-batches built by `input_fn`. It repeats
        forever, so callers must stop after `num_batches` batches to cover the data once.
    num_batches: number of mini-batches needed to cover every sample once.
    num_samples: number of samples, i.e. len(words).
    '''
    batches = input_fn(words, prons, batch_size, shuffle=shuffle)
    num_batches = calc_num_batches(len(words), batch_size)
    return batches, num_batches, len(words)


# Model

In [76]:
def convert_idx_to_token_tensor(inputs, idx2token):
    '''Turns an int32 tensor of indices into a string tensor via a Python callback.

    inputs: 1d int32 tensor. indices.
    idx2token: dictionary mapping index -> token.

    Returns
    string tensor holding all tokens joined by single spaces (one string, not one per index).
    '''
    def _join_tokens(indices):
        tokens = [idx2token[idx] for idx in indices]
        return " ".join(tokens)

    return tf.py_func(_join_tokens, [inputs], tf.string)


In [77]:
class Net:
    '''GRU encoder-decoder (seq2seq) for grapheme-to-phoneme conversion.

    Tokens are fed as one-hot vectors, and the encoder's last hidden state initializes
    the decoder. Variables live under the "encode" / "decode" variable scopes with
    AUTO_REUSE, so the training and inference graphs share them.
    '''
    def __init__(self, hp):
        self.g2idx, self.idx2g, self.p2idx, self.idx2p = load_vocab()
        self.hp = hp

    def encode(self, xs):
        '''
        xs: tuple of
            x: (N, T). int32
            seqlens: (N,). int32
            words: (N,). string

        returns
        last hidden: (N, hidden_units). float32
        words: (N,). string
        '''
        with tf.variable_scope("encode", reuse=tf.AUTO_REUSE):
            x, seqlens, words = xs
            x = tf.one_hot(x, len(self.g2idx))
            cell = tf.contrib.rnn.GRUCell(self.hp.hidden_units)
            # Passing sequence_length makes the final state the one at each sample's last real token.
            _, last_hidden = tf.nn.dynamic_rnn(cell, x, seqlens, dtype=tf.float32)

        return last_hidden, words

    def decode(self, ys, h0=None):
        '''
        ys: tuple of
            decoder_inputs: (N, T). int32
            y: (N, T). int32
            seqlens: (N,). int32 (not used here)
            prons: (N,). string.
        h0: initial hidden state. (N, hidden_units)

        returns
        logits: (N, T, len(p2idx)). float32. before softmax
        y_hat: (N, T). int32.
        y: (N, T). int32. label.
        prons: (N,). string. ground truth phonemes
        last_hidden: (N, hidden_units). This is for autoregressive inference
        '''
        decoder_inputs, y, seqlens, prons = ys

        with tf.variable_scope("decode", reuse=tf.AUTO_REUSE):
            inputs = tf.one_hot(decoder_inputs, len(self.p2idx))

            cell = tf.contrib.rnn.GRUCell(self.hp.hidden_units)
            # No sequence_length: padded steps are still computed, but `train` masks them out of the loss.
            outputs, last_hidden = tf.nn.dynamic_rnn(cell, inputs, initial_state=h0, dtype=tf.float32)

            # projection
            logits = tf.layers.dense(outputs, len(self.p2idx))
            y_hat = tf.to_int32(tf.argmax(logits, axis=-1))

        return logits, y_hat, y, prons, last_hidden

    def train(self, xs, ys):
        '''Builds the teacher-forced training graph.

        Returns
        loss: scalar float32. Mean cross-entropy over non-"<pad>" target positions.
        train_op: Adam update op; it also increments global_step.
        global_step: the global step variable.
        '''
        # forward
        last_hidden, words = self.encode(xs)
        logits, y_hat, y, prons, last_hidden = self.decode(ys, h0=last_hidden)

        # train scheme
        ce = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=y)
        nonpadding = tf.to_float(tf.not_equal(y, self.p2idx["<pad>"])) # 0: <pad>
        loss = tf.reduce_sum(ce*nonpadding) / (tf.reduce_sum(nonpadding)+1e-7)

        global_step = tf.train.get_or_create_global_step()
        # Use this instance's hyperparameters, not the notebook-global `hp`.
        train_op = tf.train.AdamOptimizer(self.hp.lr).minimize(loss, global_step=global_step)

        return loss, train_op, global_step

    def eval(self, xs, ys):
        '''Predicts autoregressively (greedy decoding) for `self.hp.dec_maxlen` steps.

        xs: encoder inputs, as in `encode`.
        ys: tuple of (decoder_inputs, y, seqlens, prons). The ground-truth decoder inputs
            are ignored: decoding starts from "<s>" and feeds back its own predictions.

        Returns
        y_hats: (N, dec_maxlen). int32. Decoding is not cut short at "</s>", so callers
            should discard everything after the first "</s>".
        word: string. Graphemes of one randomly chosen sample in the batch.
        pred: string. That sample's prediction, as space-separated phonemes.
        pron: string. That sample's ground-truth phonemes.
        '''
        _, y, seqlens, prons = ys
        batch_size = tf.shape(xs[0])[0]
        decoder_inputs = tf.ones((batch_size, 1), tf.int32) * self.p2idx["<s>"]
        ys = (decoder_inputs, y, seqlens, prons)

        last_hidden, words = self.encode(xs)

        h0 = last_hidden
        y_hats = []
        print("Inference graph is being built. Please be patient.")
        # The graph is unrolled for a fixed number of steps. A Python `if` on a tensor cannot
        # stop decoding early, because tensor values are unknown while the graph is built.
        for _step in tqdm(range(self.hp.dec_maxlen)):
            _, y_hat, _, _, h0 = self.decode(ys, h0)  # y_hat: (N, 1)
            ys = (y_hat, y, seqlens, prons)
            # Squeeze only the time axis, so a batch of size 1 keeps its batch dimension.
            y_hats.append(tf.squeeze(y_hat, axis=1))
        y_hats = tf.stack(y_hats, 1)

        # monitor a random sample; maxval is exclusive, so the batch size itself covers every row
        n = tf.random_uniform((), 0, batch_size, tf.int32)
        word = words[n]
        pred = convert_idx_to_token_tensor(y_hats[n], self.idx2p)
        pron = prons[n]

        return y_hats, word, pred, pron

    

# Train & Evaluate

In [78]:
def calc_num_batches(total_num, batch_size):
    '''Number of batches needed to cover `total_num` samples (ceiling division).'''
    full_batches, leftover = divmod(total_num, batch_size)
    return full_batches + (1 if leftover else 0)

In [79]:
# evaluation metric
def per(ref, hyp):
    '''Calc phoneme error rate (PER).

    ref: list of ground-truth pronunciations, each a space-separated phoneme string,
        e.g. ["B L AA1 K HH AW2 S", ...]
    hyp: list of predicted phoneme index sequences (e.g. rows of `Net.eval`'s y_hats).
        Each prediction is cut at its first "</s>".
    Pairs are formed with zip, so hypotheses beyond len(ref) are ignored. This drops the
    extra rows produced when the final batch wraps around the repeated dataset.

    Returns
    Total Levenshtein distance / total number of reference phonemes, rounded to
    2 decimals. float('nan') when ref holds no phonemes (PER is undefined).
    '''
    _, _, _, idx2p = load_vocab()
    num_phonemes, num_errors = 0, 0
    for r, h in zip(ref, hyp):
        r = r.split()
        h = " ".join(idx2p[idx] for idx in h)
        h = h.split("</s>")[0].strip().split()

        num_phonemes += len(r)
        num_errors += levenshtein(h, r)
    if num_phonemes == 0:
        # Avoid ZeroDivisionError on an empty reference set.
        return float("nan")
    return round(num_errors / num_phonemes, 2)

In [80]:
# Clear the default graph before building the training graph.
tf.reset_default_graph()
# prepare batches
train_batches, num_train_batches, num_train_samples = get_batch(train_words, train_prons,
                         hp.batch_size, shuffle=True)
eval_batches, num_eval_batches, num_eval_samples = get_batch(eval_words, eval_prons,
                         hp.batch_size, shuffle=False)

In [81]:
# create a iterator of the correct shape and type
# NOTE: `iter` shadows the Python builtin of the same name; later cells rely on this name.
iter = tf.data.Iterator.from_structure(train_batches.output_types, train_batches.output_shapes)

# create the initialisation operations
# Both datasets feed this one iterator; running an init op switches which dataset it reads.
train_init_op = iter.make_initializer(train_batches)
eval_init_op = iter.make_initializer(eval_batches)

In [82]:
# variable specs
def print_variable_specs(fpath):
    '''Prints the total number of elements across all global variables.

    fpath: currently unused (writing a per-variable listing to this file is disabled).
    The count includes every global variable, e.g. the Adam slot variables and the
    global step created by `Net.train`, not only the model weights.
    '''
    num_params = 0
    for var in tf.global_variables():
        var_size = 1
        for dim in var.shape:
            var_size *= dim
        num_params += var_size
    print("num_params:", num_params)

In [83]:
# Load model
net = Net(hp)
# Training and evaluation read the same `xs, ys` from the single iterator, i.e. whichever
# dataset it was last initialized with. Both graphs share variables through AUTO_REUSE scopes.
xs, ys = iter.get_next()
loss, train_op, global_step = net.train(xs, ys)
y_hat, word, pred, pron = net.eval(xs, ys)

Inference graph is being built. Please be patient.


HBox(children=(IntProgress(value=0, max=20), HTML(value='')))

In [84]:
# Session
saver = tf.train.Saver()
with tf.Session() as sess:
    ckpt = tf.train.latest_checkpoint(hp.logdir)
    if ckpt is None:
        sess.run(tf.global_variables_initializer())
        print("Variables initialized")
    else:
        saver.restore(sess, ckpt)
        print("Restored from file: ", ckpt)

    print_variable_specs('specs')

    sess.run(train_init_op)
    total_steps = hp.num_epochs*num_train_batches
    _gs = sess.run(global_step)
    # One iteration == one training step, so stop when global_step reaches total_steps.
    # range's end is exclusive; `total_steps+1` would run one extra, never-evaluated step.
    for _ in tqdm(range(_gs, total_steps)):
        # training
        _, _gs, _loss = sess.run([train_op, global_step, loss])

        epoch = math.ceil(_gs / num_train_batches)

        if _gs and _gs % num_train_batches == 0: # Be careful that you should evaluate at every epoch due to train_init_op
            print("epoch=", epoch, "is done!")
            # Point the shared iterator at the eval set, decode it once, then switch back below.
            sess.run(eval_init_op)
            _y_hats = []
            for _batch in range(num_eval_batches):
                _y_hat, _word, _pred, _pron = sess.run([y_hat, word, pred, pron])
                _y_hats.extend(_y_hat.tolist())

            # sample monitor (a random sample from the last eval batch)
            print("wrd:", _word.decode("utf-8"))
            print("exp:", _pron.decode("utf-8"))
            print("got:", _pred.decode("utf-8"))

            _per = per(eval_prons, _y_hats)
            print("per=%.2f"%_per)
            print()

            sess.run(train_init_op)

            # save
            os.makedirs(hp.logdir, exist_ok=True)
            fname = os.path.join(hp.logdir, "my_model_loss_%.2f_per_%.2f" % (_loss, _per))
            saver.save(sess, fname, global_step=_gs)

    print("Training Done!")

Variables initialized
num_params: 444513


HBox(children=(IntProgress(value=0, max=7721), HTML(value='')))

epoch= 1 is done!
wrd: m a p e l
exp: M AE1 P AH0 L
got: M AE1 P L </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s>
per=0.56

epoch= 2 is done!
wrd: s t e a r i c
exp: S T IY1 R IH0 K
got: S T EH1 R IY0 </s> </s> </s> </s> </s> </s> Z </s> </s> N </s> Z </s> N </s>
per=0.41

epoch= 3 is done!
wrd: s c o l d e d
exp: S K OW1 L D AH0 D
got: S K OW1 D L D </s> </s> </s> </s> </s> L T V UW0 T </s> UW0 V UW0
per=0.34

epoch= 4 is done!
wrd: s c o l d e d
exp: S K OW1 L D AH0 D
got: S K OW1 L D </s> </s> D </s> </s> L </s> T L T V EY0 T W EH1
per=0.30

epoch= 5 is done!
wrd: n o a
exp: N OW1 AH0
got: N OW1 </s> </s> </s> </s> </s> EH1 L EH1 N T </s> EH1 T </s> </s> EH1 T </s>
per=0.26

epoch= 6 is done!
wrd: c o n f i d e n t i a l l y
exp: K AA2 N F AH0 D EH1 N SH AH0 L IY0
got: K AH0 N F EH1 D AH0 N T AH0 L IY0 </s> </s> </s> </s> </s> </s> AW1 </s>
per=0.24

epoch= 7 is done!
wrd: d o l i n g
exp: D OW1 L IH0 NG
got: D OW1 L IH0 NG </s> </s> </s> </s> G </s>

# Inference

In [85]:
# Build a fresh graph for inference on the test split.
tf.reset_default_graph()
test_batches, num_test_batches, num_test_samples  = get_batch(test_words, test_prons,
                                                              hp.batch_size,
                                                              shuffle=False)
# NOTE: rebinding `iter` again shadows the builtin; the next cell reads it.
iter = tf.data.Iterator.from_structure(test_batches.output_types, test_batches.output_shapes)

# create the initialisation operations
test_init_op = iter.make_initializer(test_batches)

In [88]:
# Load model
xs, ys = iter.get_next()
net = Net(hp)
# Only the inference graph is built here; its variables are restored from the checkpoint in the next cell.
y_hat, _, _, _ = net.eval(xs, ys)

Inference graph is being built. Please be patient.


HBox(children=(IntProgress(value=0, max=20), HTML(value='')))

In [89]:
# saver for restoration
ckpt = tf.train.latest_checkpoint(hp.logdir)
print(ckpt)
# Fail with a clear message instead of an opaque error from saver.restore(sess, None).
if ckpt is None:
    raise FileNotFoundError("No checkpoint found in %s; run the training cells first." % hp.logdir)
# saver = tf.train.import_meta_graph(ckpt + ".meta")# <- Do NOT use this as we'll use a distinct graph.
saver = tf.train.Saver()

with tf.Session() as sess:

    saver.restore(sess, ckpt); print("checkpoint restored")
    sess.run(test_init_op)

    _y_hats = []
    for _ in range(num_test_batches):
        _y_hat = sess.run(y_hat)
        _y_hats.extend(_y_hat.tolist())

    # per() zips with test_prons, so wrapped-around rows from the last batch are ignored.
    _per = per(test_prons, _y_hats)

    print("per=%.2f"%_per)

    # save: one 4-line record per test word (wrd / exp / got / blank)
    g2idx, idx2g, p2idx, idx2p = load_vocab()

    with open("result", 'w', encoding="utf-8") as fout:
        fout.write("per: %.2f\n" % _per)
        for w, r, h in zip(test_words, test_prons, _y_hats):
            w = w.replace(" ", "")
            h = " ".join(idx2p[idx] for idx in h)
            h = h.split("</s>")[0].strip()  # keep only the phonemes before the first "</s>"
            fout.write("wrd: {}\nexp: {}\ngot: {}\n\n".format(w, r, h))

    print("Done!")

log/04/my_model_loss_0.40_per_0.19-7720
INFO:tensorflow:Restoring parameters from log/04/my_model_loss_0.40_per_0.19-7720
checkpoint restored
per=0.20
Done!


Let's see some results.

In [90]:
# Show the last 100 lines (25 test samples, 4 lines each) of the result file; `with` closes the handle.
with open('result', 'r') as fin:
    result_lines = fin.read().splitlines()
result_lines[-100:]

['wrd: campau',
 'exp: K AA1 M P AW0',
 'got: K AE1 M P OW2',
 '',
 'wrd: tension',
 'exp: T EH1 N SH AH0 N',
 'got: T EH1 N S IY0 AH0 N',
 '',
 'wrd: pithy',
 'exp: P IH1 TH IY0',
 'got: P IH1 TH IY0',
 '',
 'wrd: blaisdell',
 'exp: B L EY1 S D AH0 L',
 'got: B L EY1 S D AH0 L',
 '',
 'wrd: reflectone',
 'exp: R IY0 F L EH1 K T OW2 N',
 'got: R IY0 F L EH1 K T AH0 N',
 '',
 'wrd: cherishing',
 'exp: CH EH1 R IH0 SH IH0 NG',
 'got: CH EH1 R IH0 SH IH0 NG',
 '',
 'wrd: necessitate',
 'exp: N AH0 S EH1 S AH0 T EY2 T',
 'got: N EH2 S AH0 S EH1 T IH0 T',
 '',
 'wrd: swiatkowski',
 'exp: S V IY0 AH0 T K AO1 F S K IY0',
 'got: S W IH0 T AO1 K S W IH0 K',
 '',
 'wrd: tendons',
 'exp: T EH1 N D AH0 N Z',
 'got: T EH1 N D AH0 N Z',
 '',
 'wrd: nucleonic',
 'exp: N UW2 K L IY0 AA1 N IH0 K',
 'got: N AH0 K L EH1 N IH0 K',
 '',
 'wrd: nutone',
 'exp: N UW1 T OW2 N',
 'got: N UW1 T OW2 N',
 '',
 'wrd: demaree',
 'exp: D EH0 M ER0 IY1',
 'got: D IH0 M AA1 R IY0',
 '',
 'wrd: soltau',
 'exp: S OW1 L 