We'll write a simple seq2seq template using TensorFlow. For demonstration, we tackle the g2p task: converting graphemes (spelling) to phonemes (pronunciation). It's a good fit for this purpose because it's simple enough to get up and running quickly. If you want to know more about g2p, see my [repo](https://github.com/kyubyong/g2p).

In [1]:
__author__ = "kyubyong"
__address__ = "https://github.com/kyubyong/nlp_made_easy"
__email__ = "kbpark.linguist@gmail.com"

In [2]:
import numpy as np
import tensorflow as tf
from tqdm import tqdm_notebook as tqdm
from distance import levenshtein
import os
import math

In [3]:
tf.__version__

'1.12.0'

# Hyperparameters

In [33]:
class Hparams:
    """Hyperparameters and vocabularies shared by the data pipeline and the model."""
    batch_size = 128
    enc_maxlen = 20  # length cap used when filtering training graphemes
    dec_maxlen = 20  # length cap used when filtering training phonemes; also the number of decoding steps at inference
    num_epochs = 10
    hidden_units = 128
    # Grapheme vocabulary: special tokens first, then the lowercase alphabet.
    graphemes = ["<pad>", "<unk>", "</s>", *"abcdefghijklmnopqrstuvwxyz"]
    # Phoneme vocabulary: special tokens first, then the phoneme symbols.
    phonemes = ["<pad>", "<unk>", "<s>", "</s>"] + [
        'AA0', 'AA1', 'AA2',
        'AE0', 'AE1', 'AE2',
        'AH0', 'AH1', 'AH2',
        'AO0', 'AO1', 'AO2',
        'AW0', 'AW1', 'AW2',
        'AY0', 'AY1', 'AY2',
        'B', 'CH', 'D', 'DH',
        'EH0', 'EH1', 'EH2',
        'ER0', 'ER1', 'ER2',
        'EY0', 'EY1', 'EY2',
        'F', 'G', 'HH',
        'IH0', 'IH1', 'IH2',
        'IY0', 'IY1', 'IY2',
        'JH', 'K', 'L', 'M', 'N', 'NG',
        'OW0', 'OW1', 'OW2',
        'OY0', 'OY1', 'OY2',
        'P', 'R', 'S', 'SH', 'T', 'TH',
        'UH0', 'UH1', 'UH2',
        'UW', 'UW0', 'UW1', 'UW2',
        'V', 'W', 'Y', 'Z', 'ZH',
    ]
    lr = 0.001
    eval_steps = 500  # run the teacher-forced sanity check every this many global steps
    logdir = "log/02"  # checkpoint directory


hp = Hparams()

# Prepare Data

In [5]:
import nltk
# nltk.download('cmudict')# <- if you haven't downloaded, do this.
from nltk.corpus import cmudict
# Module-level pronunciation dictionary. prepare_data() iterates over
# cmu.items() and uses the first pronunciation listed for each word.
cmu = cmudict.dict()

In [6]:
def load_vocab():
    """Builds token<->index lookup tables from hp.graphemes and hp.phonemes.

    Returns
    g2idx, idx2g, p2idx, idx2p: dictionaries, where g stands for grapheme
    and p for phoneme. Indices follow the positions in the hp lists.
    """
    idx2g = dict(enumerate(hp.graphemes))
    g2idx = {grapheme: position for position, grapheme in idx2g.items()}

    idx2p = dict(enumerate(hp.phonemes))
    p2idx = {phoneme: position for position, phoneme in idx2p.items()}

    return g2idx, idx2g, p2idx, idx2p

In [7]:
def prepare_data(seed=None, pron_dict=None):
    """Shuffles the pronunciation dictionary and splits it into train/eval/test.

    seed: optional. When given, the shuffle uses a private random.Random(seed)
        so the split is reproducible. When None (default), the module-level
        random.shuffle is used, exactly as before.
    pron_dict: optional mapping word -> list of pronunciations, each a list
        of phoneme strings. Defaults to the module-level `cmu` dictionary.

    Returns
    train_words, eval_words, test_words, train_prons, eval_prons, test_prons:
        lists of space-separated strings, e.g. "w o r d" / "W ER1 D".
        80% of the entries go to train, 10% to test, the remainder to eval.
        Only the first pronunciation of each word is used.
    """
    import random

    if pron_dict is None:
        pron_dict = cmu

    words, prons = [], []
    for word, word_prons in pron_dict.items():
        words.append(" ".join(word))
        prons.append(" ".join(word_prons[0]))

    indices = list(range(len(words)))
    if seed is None:
        random.shuffle(indices)
    else:
        random.Random(seed).shuffle(indices)
    words = [words[idx] for idx in indices]
    prons = [prons[idx] for idx in indices]

    num_train, num_test = int(len(words)*.8), int(len(words)*.1)
    # Use an explicit end index instead of negative slicing: with num_test == 0,
    # `words[-0:]` would be the whole list and `words[num_train:-0]` would be empty.
    eval_end = len(words) - num_test

    train_words, eval_words, test_words = words[:num_train], words[num_train:eval_end], words[eval_end:]
    train_prons, eval_prons, test_prons = prons[:num_train], prons[num_train:eval_end], prons[eval_end:]
    return train_words, eval_words, test_words, train_prons, eval_prons, test_prons

In [8]:
# Build the shuffled train/eval/test splits, then peek at the first training pair.
train_words, eval_words, test_words, train_prons, eval_prons, test_prons = prepare_data()
print(train_words[0])
print(train_prons[0])

c o n s e c r a t e d
K AA1 N S AH0 K R EY2 T AH0 D


In [9]:
def drop_lengthy_samples(words, prons, enc_maxlen, dec_maxlen):
    """Keeps only the (word, pron) pairs that fit within both length limits.

    A pair is kept when its token count plus one (room for the end-of-sequence
    token) does not exceed enc_maxlen for the word and dec_maxlen for the pron.

    Returns
    (kept_words, kept_prons): two lists, in the original order.
    """
    def fits(text, limit):
        return len(text.split()) + 1 <= limit

    kept = [(w, p) for w, p in zip(words, prons)
            if fits(w, enc_maxlen) and fits(p, dec_maxlen)]
    return [w for w, _ in kept], [p for _, p in kept]

In [10]:
# Filter the training split by length; this rebinds train_words / train_prons
# to the filtered lists.
train_words, train_prons = drop_lengthy_samples(train_words, train_prons, hp.enc_maxlen, hp.dec_maxlen)
# We do NOT apply this constraint to eval and test datasets.

# Data Loader

In [11]:
def encode(inp, type, dict):
    '''Turns a UTF-8 byte string of space-separated tokens into a list of indices.
    inp: bytes. e.g., b"w o r d"
    type: "x" appends "</s>"; anything else wraps the tokens in "<s>" ... "</s>".
    dict: token -> index mapping. Must contain "<unk>", which is used for
        tokens missing from the mapping.

    Returns
    list of int indices.
    '''
    pieces = inp.decode("utf-8").split()
    pieces.append("</s>")
    if type != "x":
        pieces.insert(0, "<s>")

    unk_idx = dict["<unk>"]
    return [dict.get(piece, unk_idx) for piece in pieces]
    

In [12]:
def generator_fn(words, prons):
    '''Yields one encoded (source, target) example per word/pron pair.
    words: 1d byte array. e.g., [b"w o r d", ]
    prons: 1d byte array. e.g., [b'W ER1 D', ]

    yields
    xs: tuple of
        x: list of grapheme indices ending in </s>. encoder input
        x_seqlen: scalar. len(x)
        word: the original word entry
    ys: tuple of
        decoder_input: list of phoneme indices starting with <s>, without the final </s>
        y: list of phoneme indices ending in </s>, without the leading <s>. label.
        y_seqlen: scalar. len(y)
        pron: the original pron entry
    '''
    g2idx, _, p2idx, _ = load_vocab()
    for word, pron in zip(words, prons):
        source_ids = encode(word, "x", g2idx)
        target_ids = encode(pron, "y", p2idx)

        # Shift by one position: the decoder reads <s> ... and predicts ... </s>.
        decoder_input = target_ids[:-1]
        label = target_ids[1:]

        yield (source_ids, len(source_ids), word), (decoder_input, label, len(label), pron)
        

In [13]:
def input_fn(words, prons, batch_size, shuffle=False):
    '''Builds an endlessly repeating, padded-batch dataset on top of generator_fn.
    words: list of space-separated grapheme strings. e.g., ["w o r d", ]
    prons: list of space-separated phoneme strings. e.g., ['W ER1 D',]
    batch_size: scalar.
    shuffle: boolean. If True, shuffles with a buffer of 64*batch_size examples.

    Returns
    tf.data.Dataset yielding (xs, ys) batches; int sequences are padded with 0
    and strings with ''.
    '''
    x_shapes = ([None], (), ())
    y_shapes = ([None], [None], (), ())
    x_types = (tf.int32, tf.int32, tf.string)
    y_types = (tf.int32, tf.int32, tf.int32, tf.string)
    x_pads = (0, 0, '')
    y_pads = (0, 0, 0, '')

    dataset = tf.data.Dataset.from_generator(
        generator_fn,
        output_types=(x_types, y_types),
        output_shapes=(x_shapes, y_shapes),
        args=(words, prons))  # <- converted to np string arrays

    if shuffle:
        dataset = dataset.shuffle(64*batch_size)

    # Repeat before batching, so the stream never ends.
    dataset = dataset.repeat()
    dataset = dataset.padded_batch(batch_size, (x_shapes, y_shapes), (x_pads, y_pads))
    return dataset.prefetch(1)

In [14]:
def get_batch(words, prons, batch_size, shuffle=False):
    '''Gets training / evaluation mini-batches
    words: list of space-separated grapheme strings.
    prons: list of space-separated phoneme strings.
    batch_size: scalar
    shuffle: boolean

    Returns
    batches: the dataset built by input_fn
    num_batches: number of mini-batches needed to cover len(words) samples
    num_samples: len(words)
    '''
    dataset = input_fn(words, prons, batch_size, shuffle=shuffle)
    num_samples = len(words)
    return dataset, calc_num_batches(num_samples, batch_size), num_samples


# Model

In [15]:
def convert_idx_to_token_tensor(inputs, idx2token):
    '''Maps a tensor of indices to a single space-joined string tensor.
    inputs: 1d int32 tensor. indices.
    idx2token: dictionary. index -> token

    Returns
    string tensor produced by a tf.py_func wrapper around the lookup.
    '''
    def _lookup_and_join(indices):
        tokens = [idx2token[index] for index in indices]
        return " ".join(tokens)

    return tf.py_func(_lookup_and_join, [inputs], tf.string)


In [16]:
class Net:
    '''GRU encoder-decoder for grapheme-to-phoneme conversion (TF 1.x graph mode).

    All methods build graph nodes; nothing is executed until a session runs them.
    Variables live in the "encode" / "decode" scopes with reuse=tf.AUTO_REUSE, so
    train(), eval() and infer() share the same weights.
    '''
    def __init__(self, hp):
        '''hp: hyperparameter object providing hidden_units, lr and dec_maxlen.'''
        self.g2idx, self.idx2g, self.p2idx, self.idx2p = load_vocab()
        self.hp = hp

    def encode(self, xs):
        '''
        xs: tuple of
            x: (N, T). int32
            seqlens: (N,). int32
            words: (N,). string

        returns
        last hidden: (N, hidden_units). float32
        words: (N,). string
        '''
        with tf.variable_scope("encode", reuse=tf.AUTO_REUSE):
            x, seqlens, words = xs
            x = tf.one_hot(x, len(self.g2idx))
            cell = tf.contrib.rnn.GRUCell(self.hp.hidden_units)
            # seqlens is passed as sequence_length to dynamic_rnn.
            _, last_hidden = tf.nn.dynamic_rnn(cell, x, seqlens, dtype=tf.float32)

        return last_hidden, words

    def decode(self, ys, h0=None):
        '''
        ys: tuple of
            decoder_inputs: (N, T). int32
            y: (N, T). int32
            seqlens: (N,). int32
            prons: (N,). string.
        h0: initial hidden state. (N, hidden_units)

        returns
        logits: (N, T, len(p2idx)). float32. before softmax
        y_hat: (N, T). int32.
        y: (N, T). int32. label.
        prons: (N,). string. ground truth phonemes
        last_hidden: (N, hidden_units). This is for autoregressive inference
        '''
        decoder_inputs, y, seqlens, prons = ys

        with tf.variable_scope("decode", reuse=tf.AUTO_REUSE):
            inputs = tf.one_hot(decoder_inputs, len(self.p2idx))

            cell = tf.contrib.rnn.GRUCell(self.hp.hidden_units)
            # seqlens is not passed here; padded positions are masked out of the
            # loss / accuracy by the callers instead.
            outputs, last_hidden = tf.nn.dynamic_rnn(cell, inputs, initial_state=h0, dtype=tf.float32)

            # projection
            logits = tf.layers.dense(outputs, len(self.p2idx))
            y_hat = tf.to_int32(tf.argmax(logits, axis=-1))

        return logits, y_hat, y, prons, last_hidden

    def train(self, xs, ys):
        '''Builds the teacher-forced training graph.
        Returns
        loss: scalar. cross entropy averaged over non-padding label positions
        train_op: Adam update op
        global_step: global step variable, incremented by train_op
        '''
        # forward
        last_hidden, words = self.encode(xs)
        logits, y_hat, y, prons, last_hidden = self.decode(ys, h0=last_hidden)

        # train scheme
        ce = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=y)
        nonpadding = tf.to_float(tf.not_equal(y, self.p2idx["<pad>"])) # 0: <pad>
        loss = tf.reduce_sum(ce*nonpadding) / (tf.reduce_sum(nonpadding)+1e-7)

        global_step = tf.train.get_or_create_global_step()
        # Use the hyperparameters this instance was built with, not the global `hp`.
        train_op = tf.train.AdamOptimizer(self.hp.lr).minimize(loss, global_step=global_step)

        return loss, train_op, global_step

    def eval(self, xs, ys):
        '''Evaluates in the teacher-forcing manner
        Returns
        acc: accuracy over non-padding label positions. float.
        word: a randomly picked input word from the batch. string.
        pred: predicted phonemes for that word, space-joined. string.
        pron: ground truth phonemes for that word. string.
        '''
        # forward
        last_hidden, words = self.encode(xs)
        logits, y_hat, y, prons, last_hidden = self.decode(ys, h0=last_hidden)

        # we evaluate based on acc
        # note that this is more or less different from the real test
        # because this is calculated from teacher forcing
        hits = tf.to_float(tf.equal(y_hat, y))
        nonpadding = tf.to_float(tf.not_equal(y, self.p2idx["<pad>"]))
        acc = tf.reduce_sum(hits*nonpadding) / ( tf.reduce_sum(nonpadding)+1e-7 )

        # monitor a random sample
        # maxval is exclusive, so the batch size itself is the right upper bound;
        # N-1 would never pick the last sample and is an empty range when N == 1.
        n = tf.random_uniform((), 0, tf.shape(y_hat)[0], tf.int32)
        word = words[n]
        pred = convert_idx_to_token_tensor(y_hat[n], self.idx2p)
        pron = prons[n]

        return acc, word, pred, pron

    def infer(self, xs, ys):
        '''Predicts autoregressively
        input ys is ignored.
        Returns
        y_hat: (N, dec_maxlen). int32. Always decodes dec_maxlen steps; callers
            are expected to cut each row at the first </s>.
        '''
        decoder_inputs = tf.ones((tf.shape(xs[0])[0], 1), tf.int32) * self.p2idx["<s>"]
        ys = (decoder_inputs, None, None, None)

        last_hidden, words = self.encode(xs)

        h0 = last_hidden
        y_hats = []
        print("Inference graph is being built. Please be patient.")
        for t in tqdm(range(self.hp.dec_maxlen)):
            # One decoding step: feed the previous prediction and hidden state.
            _, y_hat, _, _, h0 = self.decode(ys, h0)
            # No early stop here: this loop runs at graph-construction time, when
            # tensor values are unknown, so a Python `if` on a tensor comparison
            # cannot end decoding. The graph is unrolled for dec_maxlen steps.

            ys = (y_hat, None, None, None)
            # y_hat is (N, 1); squeeze only the time axis so a batch of size 1
            # keeps its batch dimension.
            y_hats.append(tf.squeeze(y_hat, axis=1))
        y_hats = tf.stack(y_hats, 1)
        return y_hats

    

# Train & Evaluate

In [17]:
def calc_num_batches(total_num, batch_size):
    """Returns how many batches of batch_size are needed to cover total_num
    samples; a partial final batch counts as one."""
    full_batches, leftover = divmod(total_num, batch_size)
    return full_batches + int(leftover != 0)

In [47]:
# evaluation metric
def per(ref, hyp):
    '''Calc phoneme error rate
    ref: list of ground truth pronunciations as space-separated strings.
        e.g., ["B L AA1 K HH AW2 S", ...]
    hyp: list of predicted phoneme index sequences. e.g., [[22, 46, 5, ...], ...]
        Each one is mapped back to phonemes and cut at the first "</s>".

    ref and hyp are paired with zip, so extra items in the longer one are ignored.

    Returns
    Total edit distance divided by the total number of reference phonemes,
    rounded to 2 decimals. 0.0 if there are no reference phonemes.
    '''
    num_phonemes, num_errors = 0, 0
    _, _, _, idx2p = load_vocab()
    for r, h in zip(ref, hyp):
        r = r.split()
        h = " ".join(idx2p[idx] for idx in h)
        h = h.split("</s>")[0].strip().split()

        num_phonemes += len(r)
        num_errors += levenshtein(h, r)

    if num_phonemes == 0:
        # Nothing to score; avoid dividing by zero.
        return 0.0
    return round(num_errors / num_phonemes, 2)

In [19]:
# Start from an empty default graph so re-running this cell does not pile up nodes.
tf.reset_default_graph()
# prepare batches
# Only the training stream is shuffled; evaluation keeps the order of eval_words.
train_batches, num_train_batches, num_train_samples = get_batch(train_words, train_prons,
                         hp.batch_size, shuffle=True)
eval_batches, num_eval_batches, num_eval_samples = get_batch(eval_words, eval_prons,
                         hp.batch_size, shuffle=False)

In [20]:
# create a iterator of the correct shape and type
iter = tf.data.Iterator.from_structure(train_batches.output_types, train_batches.output_shapes)

# create the initialisation operations
train_init_op = iter.make_initializer(train_batches)
eval_init_op = iter.make_initializer(eval_batches)

In [21]:
# variable specs
def print_variable_specs(fpath):
    def get_size(shp):
        size = 1
        for d in range(len(shp)):
            size *=shp[d]
        return size

    params, num_params = [], 0
    for v in tf.global_variables():
        params.append("{}==={}\n".format(v.name, v.shape))
        num_params += get_size(v.shape)
    print("num_params:", num_params)
#     with open(fpath, 'w') as fout:
#         fout.write("num_params: {}\n".format(num_params))
#         fout.write("\n".join(params))

In [22]:
# Load model
net = Net(hp)
xs, ys = iter.get_next()
loss, train_op, global_step = net.train(xs, ys)
acc, word, pred, pron = net.eval(xs, ys)
y_hat = net.infer(xs, ys)

Inference graph is being built. Please be patient.


HBox(children=(IntProgress(value=0, max=20), HTML(value='')))




In [34]:
# Session
saver = tf.train.Saver()
with tf.Session() as sess:
    ckpt = tf.train.latest_checkpoint(hp.logdir)
    if ckpt is None:
        sess.run(tf.global_variables_initializer())
        print("Variables initialized")
    else:
        saver.restore(sess, ckpt)
        print("Restored from file: ", ckpt)

    print_variable_specs('specs')

    sess.run(train_init_op)
    total_steps = hp.num_epochs*num_train_batches
    _gs = sess.run(global_step)
    for _ in tqdm(range(_gs, total_steps+1)):
        # training
        _, _gs, _loss = sess.run([train_op, global_step,loss]) 

        epoch = math.ceil(_gs / num_train_batches)
            
        if _gs and _gs % hp.eval_steps==0:
            print("# evaluation / sanity check")
            _loss = sess.run(loss) # training loss
            
            sess.run(eval_init_op)
            _acc, _word, _pred, _pron = sess.run([acc, word, pred, pron])

            # monitoring
            print("global step=", _gs)
            print("training loss=", _loss)
            print("eval acc.=", _acc)
            print("wrd =", _word.decode("utf-8"))
            print("exp =", _pron.decode("utf-8"))
            print("got =", _pred.decode("utf-8"))
            print()

            sess.run(train_init_op)

            
        if _gs and _gs % num_train_batches == 0:
            sess.run(eval_init_op)
            _y_hats = []
            for _ in range(num_eval_batches):
                _y_hat = sess.run(y_hat)
                _y_hats.extend(_y_hat.tolist())
            
            _per = per(eval_prons, _y_hats)
            
            print("epoch=", epoch, "is done!")
            print("global step=", _gs)
            print("per=%.2f"%_per)
            print()
                  
            sess.run(train_init_op)
            
            # save
            if not os.path.exists(hp.logdir): os.makedirs(hp.logdir)
            fname = os.path.join(hp.logdir, "my_model_loss_%.2f_per_%.2f" % (_loss, _per))
            saver.save(sess, fname, global_step=_gs)
   
    print("Training Done!")

Variables initialized
num_params: 444513


HBox(children=(IntProgress(value=0, max=7721), HTML(value='')))

# evaluation / sanity check
global step= 500
training loss= 1.8728248
eval acc.= 0.49479166
wrd = a t a t u r k
exp = AE1 T AH0 T ER2 K
got = T T T T ER0 D </s> </s> </s> </s> </s> </s> </s> </s> </s>

epoch= 1 is done!
global step= 772
per=0.60

# evaluation / sanity check
global step= 1000
training loss= 1.2719252
eval acc.= 0.65729165
wrd = r i c k w a r d
exp = R IH1 K W ER0 D
got = R IH1 K ER0 ER0 D </s> </s> </s> </s> </s> </s> </s> </s> </s>

# evaluation / sanity check
global step= 1500
training loss= 1.0375587
eval acc.= 0.73645836
wrd = s i e b e
exp = S IY1 B
got = S IH1 B IY0 </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s>

epoch= 2 is done!
global step= 1544
per=0.43

# evaluation / sanity check
global step= 2000
training loss= 0.7594987
eval acc.= 0.78020835
wrd = l o o p e r
exp = L UW1 P ER0
got = L OW1 P ER0 </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s>

epoch= 3 is done!
global step= 2316
per=0.35

# evaluation / sanity check
global step= 2500
training

# Inference

In [38]:
# NOTE(review): leftover interactive help lookup; it has no effect on the results
# and can be deleted from the finished notebook.
get_batch?

In [40]:
# Inference uses a fresh graph; weights are loaded from a checkpoint in a later cell.
tf.reset_default_graph()
test_batches, num_test_batches, num_test_samples  = get_batch(test_words, test_prons,
                                                              hp.batch_size,
                                                              shuffle=False)
# NOTE(review): `iter` shadows the Python builtin; the next cell reads it.
iter = tf.data.Iterator.from_structure(test_batches.output_types, test_batches.output_shapes)

# create the initialisation operations
test_init_op = iter.make_initializer(test_batches)

In [41]:
# Load model
xs, ys = iter.get_next()
net = Net(hp)
y_hat = net.infer(xs, ys)

Inference graph is being built. Please be patient.


HBox(children=(IntProgress(value=0, max=20), HTML(value='')))

In [51]:
# saver for restoration
ckpt = tf.train.latest_checkpoint(hp.logdir)
print(ckpt)
# saver = tf.train.import_meta_graph(ckpt + ".meta")# <- Do NOT use this as we'll use a distinct graph.
saver = tf.train.Saver()
   
with tf.Session() as sess:
    
    saver.restore(sess, ckpt); print("checkpoint restored") 
    sess.run(test_init_op)

    _y_hats = []
    for _ in range(num_test_batches):
        _y_hat = sess.run(y_hat)
        _y_hats.extend(_y_hat.tolist())
            
    _per = per(test_prons, _y_hats)
            
    print("per=%.2f"%_per)
    
    # save
    g2idx, idx2g, p2idx, idx2p = load_vocab()
    
    with open("result", 'w') as fout:
        fout.write("per: %.2f\n" % _per)
        for w, r, h in zip(test_words, test_prons, _y_hats):
            w = w.replace(" ", "")
            h = " ".join(idx2p[idx] for idx in h)
            h = h.split("</s>")[0].strip()
            fout.write("wrd: {}\nexp: {}\ngot: {}\n\n".format(w, r, h))
            
    print("Done!")

log/02/my_model_loss_0.40_per_0.20-7720
INFO:tensorflow:Restoring parameters from log/02/my_model_loss_0.40_per_0.20-7720
checkpoint restored
per=0.20
Done!


Let's see some results.

In [52]:
# Read the result file inside a context manager so the handle is closed,
# then display the last 100 lines as the cell output.
with open('result', 'r') as fin:
    result_lines = fin.read().splitlines()
result_lines[-100:]

['wrd: thain',
 'exp: TH EY1 N',
 'got: TH EY1 N',
 '',
 'wrd: decommission',
 'exp: D IY0 K AH0 M IH1 SH AH0 N',
 'got: D IH0 K AH0 M S IH1 SH AH0 N',
 '',
 'wrd: retlin',
 'exp: R EH1 T L IH0 N',
 'got: R EH1 T L IH0 N',
 '',
 'wrd: sherbert',
 'exp: SH ER1 B ER0 T',
 'got: SH ER1 B ER0 T',
 '',
 'wrd: turn',
 'exp: T ER1 N',
 'got: T ER1 N',
 '',
 'wrd: firsthand',
 'exp: F ER0 S T HH AE1 N D',
 'got: F ER1 S T HH AE2 N D',
 '',
 'wrd: compassionately',
 'exp: K AH0 M P AE1 SH AH0 N AH0 T L IY0',
 'got: K AH0 M P AE1 S AH0 N AH0 T EY1 L',
 '',
 'wrd: mcfarlan',
 'exp: M AH0 K F AA1 R L AH0 N',
 'got: M AH0 K F AA1 R L AH0 N',
 '',
 'wrd: venerating',
 'exp: V EH1 N ER0 EY2 T IH0 NG',
 'got: V EH1 N ER0 EY2 T IH0 NG',
 '',
 'wrd: stubbed',
 'exp: S T AH1 B D',
 'got: S T AH1 B D',
 '',
 'wrd: sidgraph',
 'exp: S IH1 D G R AE0 F',
 'got: S IH1 D G R AE2 F',
 '',
 'wrd: cobre',
 'exp: K AA1 B R AH0',
 'got: K AA1 B R',
 '',
 'wrd: homicides',
 'exp: HH AA1 M AH0 S AY2 D Z',
 'got: HH A