We'll write a simple template for seq2seq using PyTorch. For demonstration, we tackle the g2p task. G2p is the task of converting graphemes (spelling) to phonemes (pronunciation). It's a very good task for this purpose, as it's simple enough for you to get up and running quickly. If you want to know more about g2p, see my [repo](https://github.com/kyubyong/g2p).

In [109]:
# Notebook authorship metadata (plain module-level string constants).
__author__ = "kyubyong"
__address__ = "https://github.com/kyubyong/nlp_made_easy"
__email__ = "kbpark.linguist@gmail.com"

In [110]:
import numpy as np
from tqdm import tqdm_notebook as tqdm
from distance import levenshtein
import os
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils import data
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

In [111]:
# Display the installed PyTorch version string.
torch.__version__

'1.2.0'

# Hyperparameters

In [112]:
class Hparams:
    """Hyperparameters and symbol inventories shared by the whole notebook."""

    # --- optimisation ---
    batch_size = 128
    num_epochs = 10
    lr = 0.001

    # --- sequence-length limits (token counts, including the trailing </s>) ---
    enc_maxlen = 20
    dec_maxlen = 20

    # --- model sizes ---
    hidden_units = 128
    emb_units = 64

    # --- vocabularies: a symbol's list position is its integer id, so
    # "<pad>" is id 0 in both inventories ---
    graphemes = ["<pad>", "<unk>", "</s>"] + list("abcdefghijklmnopqrstuvwxyz")
    phonemes = ["<pad>", "<unk>", "<s>", "</s>"] + (
        "AA0 AA1 AA2 AE0 AE1 AE2 AH0 AH1 AH2 AO0 "
        "AO1 AO2 AW0 AW1 AW2 AY0 AY1 AY2 B CH D DH "
        "EH0 EH1 EH2 ER0 ER1 ER2 EY0 EY1 EY2 F G HH "
        "IH0 IH1 IH2 IY0 IY1 IY2 JH K L M N NG OW0 OW1 "
        "OW2 OY0 OY1 OY2 P R S SH T TH UH0 UH1 UH2 UW "
        "UW0 UW1 UW2 V W Y Z ZH"
    ).split()

    # --- output location (not referenced by the code visible in this notebook) ---
    logdir = "log/01"


hp = Hparams()

# Prepare Data

In [113]:
import nltk
# nltk.download('cmudict')# <- if you haven't downloaded, do this.
from nltk.corpus import cmudict
# Load the pronunciation dictionary once; `cmu` is reused by prepare_data().
cmu = cmudict.dict()
# Peek at one entry: a word maps to a list of pronunciations, each of which
# is a list of phoneme symbols (see the cell output).
cmu["refuse"]

[['R', 'AH0', 'F', 'Y', 'UW1', 'Z'],
 ['R', 'EH1', 'F', 'Y', 'UW2', 'Z'],
 ['R', 'IH0', 'F', 'Y', 'UW1', 'Z']]

In [114]:
def load_vocab():
    """Build the symbol<->id lookup tables for graphemes (g) and phonemes (p).

    Ids are the positions of the symbols in `hp.graphemes` / `hp.phonemes`.

    Returns:
        (g2idx, idx2g, p2idx, idx2p): grapheme->id, id->grapheme,
        phoneme->id and id->phoneme dictionaries, in that order.
    """
    def _lookup_tables(symbols):
        # One pass fills both directions of the mapping.
        sym2idx, idx2sym = {}, {}
        for position, symbol in enumerate(symbols):
            sym2idx[symbol] = position
            idx2sym[position] = symbol
        return sym2idx, idx2sym

    g2idx, idx2g = _lookup_tables(hp.graphemes)
    p2idx, idx2p = _lookup_tables(hp.phonemes)
    return g2idx, idx2g, p2idx, idx2p

# Module-level lookup tables; the dataset, the models and eval() read these as globals.
g2idx, idx2g, p2idx, idx2p = load_vocab()

In [115]:
def prepare_data(pron_dict=None):
    """Shuffle the pronunciation dictionary and split it 80/10/10.

    Each word becomes a space-separated grapheme string (e.g. "w o r d") and
    is paired with its FIRST listed pronunciation as a space-separated
    phoneme string (e.g. "W ER1 D").

    Args:
        pron_dict: mapping of word -> list of pronunciations, each a list of
            phoneme symbols. Defaults to the module-level `cmu` dictionary.

    Returns:
        (train_words, eval_words, test_words,
         train_prons, eval_prons, test_prons): six lists. The three word
        lists are index-aligned with the corresponding pron lists, and the
        three splits never overlap.
    """
    from random import shuffle

    if pron_dict is None:
        pron_dict = cmu

    # Build both lists in a single pass so that a word and its pronunciation
    # always share the same index.
    words, prons = [], []
    for word, word_prons in pron_dict.items():
        words.append(" ".join(word))
        prons.append(" ".join(word_prons[0]))

    indices = list(range(len(words)))
    shuffle(indices)
    words = [words[idx] for idx in indices]
    prons = [prons[idx] for idx in indices]

    num_train, num_test = int(len(words) * .8), int(len(words) * .1)
    # Use explicit boundaries instead of negative slices: with num_test == 0
    # (fewer than 10 entries) `words[-0:]` would be the WHOLE list, leaking
    # the training data into the test split and emptying the eval split.
    test_start = len(words) - num_test

    train_words = words[:num_train]
    eval_words = words[num_train:test_start]
    test_words = words[test_start:]

    train_prons = prons[:num_train]
    eval_prons = prons[num_train:test_start]
    test_prons = prons[test_start:]
    return train_words, eval_words, test_words, train_prons, eval_prons, test_prons

In [116]:
# Build the shuffled train/eval/test splits, then print the first training
# pair as a sanity check (graphemes on one line, phonemes on the next).
train_words, eval_words, test_words, train_prons, eval_prons, test_prons = prepare_data()
print(train_words[0])
print(train_prons[0])

f l a p j a c k
F L AE1 P JH AE2 K


In [117]:
def drop_lengthy_samples(words, prons, enc_maxlen, dec_maxlen):
    """Keep only the (word, pron) pairs that fit within both length limits.

    A sample is kept when its token count plus one (the extra slot is for
    the end-of-sequence symbol) does not exceed the respective maximum.

    Args:
        words: space-separated grapheme strings.
        prons: space-separated phoneme strings, aligned with `words`.
        enc_maxlen: maximum encoder length, end symbol included.
        dec_maxlen: maximum decoder length, end symbol included.

    Returns:
        (kept_words, kept_prons): two aligned lists.
    """
    survivors = [
        (word, pron)
        for word, pron in zip(words, prons)
        if len(word.split()) + 1 <= enc_maxlen
        and len(pron.split()) + 1 <= dec_maxlen
    ]
    kept_words = [word for word, _ in survivors]
    kept_prons = [pron for _, pron in survivors]
    return kept_words, kept_prons

In [118]:
# Filter the training split in place using the configured length limits.
train_words, train_prons = drop_lengthy_samples(train_words, train_prons, hp.enc_maxlen, hp.dec_maxlen)
# We do NOT apply this constraint to eval and test datasets.

# Data Loader

In [119]:
def encode(inp, type, dict):
    """Turn a space-separated symbol string into a list of integer ids.

    Args:
        inp: space-separated symbols, e.g. "w o r d" or "W ER1 D".
        type: "x" for encoder input (only "</s>" is appended); any other
            value is treated as "y" ("<s>" is prepended, "</s>" appended).
        dict: symbol->id mapping (g2idx for "x", p2idx for "y"). Symbols
            missing from it are mapped to the id of "<unk>".

    Returns:
        List of ids, one per symbol including the added boundary symbols.
    """
    symbols = inp.split()
    symbols.append("</s>")
    if type != "x":
        symbols.insert(0, "<s>")

    unk_id = dict["<unk>"]
    return [dict.get(symbol, unk_id) for symbol in symbols]

In [120]:
class G2pDataset(data.Dataset):
    """Dataset of (word, pronunciation) pairs encoded as id sequences."""

    def __init__(self, words, prons):
        """
        words: list of words. e.g., ["w o r d", ]
        prons: list of prons. e.g., ['W ER1 D',]
        """
        self.words = words
        self.prons = prons

    def __len__(self):
        """Number of samples (one per word)."""
        return len(self.words)

    def __getitem__(self, idx):
        """Return the encoded sample at position `idx`.

        Returns a 7-tuple:
            x, x_seqlen, word, decoder_input, y, y_seqlen, pron
        where `decoder_input` is the phoneme id sequence without its last
        symbol and `y` is the same sequence without its first symbol, i.e.
        the target shifted one step ahead of the decoder input.
        """
        word = self.words[idx]
        pron = self.prons[idx]

        x = encode(word, "x", g2idx)
        full_target = encode(pron, "y", p2idx)  # <s> ... </s>
        decoder_input = full_target[:-1]
        y = full_target[1:]

        return x, len(x), word, decoder_input, y, len(y), pron

In [121]:
def pad(batch):
    """Collate samples into a batch, right-padding id sequences with zeros.

    Args:
        batch: list of 7-tuples as produced by G2pDataset.__getitem__:
            (x, x_seqlen, word, decoder_input, y, y_seqlen, pron).

    Returns:
        (x, x_seqlens, words, decoder_inputs, y, y_seqlens, prons) where
        x, decoder_inputs and y are LongTensors padded to the longest
        reported length in the batch, and the remaining items are plain
        Python lists.
    """
    def column(field):
        # Gather one field across every sample of the batch.
        return [sample[field] for sample in batch]

    def padded(field, width):
        # Right-pad every id list of the given field with 0 up to `width`.
        return [sample[field] + [0] * (width - len(sample[field])) for sample in batch]

    x_seqlens, y_seqlens = column(1), column(5)
    words, prons = column(2), column(-1)

    x_width = max(x_seqlens)
    y_width = max(y_seqlens)

    x = torch.LongTensor(padded(0, x_width))
    decoder_inputs = torch.LongTensor(padded(3, y_width))
    y = torch.LongTensor(padded(4, y_width))

    return x, x_seqlens, words, decoder_inputs, y, y_seqlens, prons

# Model

In [234]:
class Encoder(nn.Module):
    """GRU encoder that summarises a grapheme id sequence into one vector."""

    def __init__(self, emb_units, hidden_units):
        """
        emb_units: size of the grapheme embeddings.
        hidden_units: size of the GRU hidden state.
        """
        super().__init__()
        self.emb_units = emb_units
        self.hidden_units = hidden_units
        # The embedding table size comes from the module-level grapheme vocabulary.
        self.emb = nn.Embedding(len(g2idx), emb_units)
        self.rnn = nn.GRU(emb_units, hidden_units, batch_first=True)

    def forward(self, x, seqlens):
        """Encode a padded batch of grapheme ids.

        x: batch-first LongTensor of padded grapheme ids.
        seqlens: true (unpadded) length of each row of `x`.

        Returns the final GRU hidden state flattened to one row per sample.
        """
        embedded = self.emb(x)

        # Packing lets the GRU stop at each sequence's true length;
        # enforce_sorted=False means the batch need not be sorted by length.
        packed = pack_padded_sequence(embedded, seqlens, batch_first=True, enforce_sorted=False)
        _, final_state = self.rnn(packed)

        # Move the batch axis to the front and flatten the remaining axes.
        batch_major = final_state.permute(1, 2, 0)
        return batch_major.view(batch_major.size()[0], -1)



In [235]:
class Decoder(nn.Module):
    """GRU decoder that maps phoneme ids and a hidden state to phoneme logits."""

    def __init__(self, emb_units, hidden_units):
        """
        emb_units: size of the phoneme embeddings.
        hidden_units: size of the GRU hidden state.
        """
        super().__init__()

        self.emb_units = emb_units
        self.hidden_units = hidden_units
        # Both the embedding table and the output layer are sized by the
        # module-level phoneme vocabulary.
        self.emb = nn.Embedding(len(p2idx), emb_units)
        self.rnn = nn.GRU(emb_units, hidden_units, batch_first=True)
        self.fc = nn.Linear(hidden_units, len(p2idx))

    def forward(self, decoder_inputs, h0):
        """Run the decoder over `decoder_inputs`, starting from state `h0`.

        decoder_inputs: batch-first LongTensor of phoneme ids.
        h0: initial GRU hidden state.

        Returns (logits, y_hat, last_hidden): logits is (N, T, V), y_hat
        holds the argmax id at every step, and last_hidden is the GRU state
        after the final step.
        """
        embedded = self.emb(decoder_inputs)
        outputs, last_hidden = self.rnn(embedded, h0)

        logits = self.fc(outputs)  # (N, T, V)
        return logits, logits.argmax(-1), last_hidden


In [236]:
class Net(nn.Module):
    """Encoder-decoder wrapper supporting teacher forcing and greedy decoding."""

    def __init__(self, encoder, decoder):
        """
        encoder: callable as encoder(x, seqlens) -> one state row per sample.
        decoder: callable as decoder(inputs, h0) -> (logits, y_hat, last_hidden).
        """
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, x, seqlens, decoder_inputs, teacher_forcing=True, dec_maxlen=None):
        '''
        At training, the ground-truth decoder inputs are fed in one shot
        (teacher forcing).
        At evaluation, only the first column of `decoder_inputs` ("<s>") is
        used, and the decoder is then fed its own previous prediction for
        `dec_maxlen` steps. `dec_maxlen` must be given in that mode.

        Returns (logits, y_hat), concatenated along the time axis.
        '''
        # The decoder expects a leading axis on its initial state.
        state = self.encoder(x, seqlens).unsqueeze(0)

        if teacher_forcing:  # training
            logits, y_hat, _ = self.decoder(decoder_inputs, state)
            return logits, y_hat

        # evaluation: greedy, one step at a time
        step_input = decoder_inputs[:, :1]  # "<s>"
        all_logits, all_preds = [], []
        for _ in range(dec_maxlen):
            # step_logits: (N, 1, V), step_preds: (N, 1)
            step_logits, step_preds, state = self.decoder(step_input, state)
            all_logits.append(step_logits)
            all_preds.append(step_preds)
            step_input = step_preds  # feed the prediction back in

        return torch.cat(all_logits, 1), torch.cat(all_preds, 1)


# Train & Eval functions

In [237]:
def train(model, iterator, optimizer, criterion, device):
    """Run one pass over `iterator`, updating `model` with teacher forcing.

    Args:
        model: called as model(x, x_seqlens, decoder_inputs) -> (logits, y_hat).
        iterator: yields batches shaped like the output of `pad`.
        optimizer: optimizer over the model's parameters.
        criterion: loss taking (N*T, VOCAB) logits and (N*T,) targets.
        device: device the input tensors are moved to.

    Prints the loss of the current batch every 100 steps (step 0 excluded).
    """
    model.train()
    for i, batch in enumerate(iterator):
        x, x_seqlens, _words, decoder_inputs, y, _y_seqlens, _prons = batch

        x = x.to(device)
        decoder_inputs = decoder_inputs.to(device)
        y = y.to(device)

        optimizer.zero_grad()
        logits, _ = model(x, x_seqlens, decoder_inputs)

        # Flatten the batch and time axes so every step is one classification.
        vocab_size = logits.shape[-1]
        flat_logits = logits.view(-1, vocab_size)  # (N*T, VOCAB)
        flat_targets = y.view(-1)  # (N*T,)
        loss = criterion(flat_logits, flat_targets)

        loss.backward()
        optimizer.step()

        if i > 0 and i % 100 == 0:
            print(f"step: {i}, loss: {loss.item()}")
        

In [239]:
def calc_per(Y_true, Y_pred):
    '''Calc phoneme error rate
    Y_true: list of ground truth phoneme sequences. e.g., [["B", "L", "AA1", "K", "HH", "AW2", "S"], ...]
    Y_pred: list of predicted phoneme sequences. e.g., [["B", "L", "AA1", "K", "HH", "AW2", "S"], ...]

    Returns (per, num_errors): `num_errors` is the summed Levenshtein
    distance over all pairs and `per` is that sum divided by the total
    number of ground-truth phonemes, rounded to two decimals.
    When there are no ground-truth phonemes at all, `per` is 0.0 if there
    are no errors either and infinity otherwise, instead of raising
    ZeroDivisionError.
    '''
    num_phonemes, num_errors = 0, 0
    for y_true, y_pred in zip(Y_true, Y_pred):
        num_phonemes += len(y_true)
        num_errors += levenshtein(y_true, y_pred)

    if num_phonemes == 0:
        # Nothing to normalise by (e.g. an empty evaluation set).
        per = 0.0 if num_errors == 0 else float("inf")
        return per, num_errors

    per = round(num_errors / num_phonemes, 2)
    return per, num_errors

In [240]:
def convert_ids_to_phonemes(ids, idx2p, eos_idx=3):
    """Map phoneme ids to symbols, stopping at the end-of-sequence id.

    Args:
        ids: iterable of integer phoneme ids.
        idx2p: id->phoneme mapping.
        eos_idx: id of "</s>"; defaults to 3, its position in `hp.phonemes`.

    Returns:
        List of phoneme symbols preceding the first `eos_idx` (the end
        symbol itself and everything after it, such as padding, is dropped).
        If `eos_idx` never occurs, every id is converted.
    """
    phonemes = []
    for idx in ids:
        if idx == eos_idx:
            break
        phonemes.append(idx2p[idx])
    return phonemes
        
            

def eval(model, iterator, device, dec_maxlen):
    """Greedy-decode every batch of `iterator` and report the phoneme error rate.

    Args:
        model: called as model(x, x_seqlens, decoder_inputs, False, dec_maxlen).
        iterator: yields batches shaped like the output of `pad`.
        device: device the input tensors are moved to.
        dec_maxlen: number of decoding steps per sample.

    Side effects: prints the PER and error count, and overwrites the file
    "result" with ground-truth / prediction line pairs separated by blank
    lines.

    Returns:
        The phoneme error rate computed by `calc_per`.
    """
    model.eval()

    Y_true, Y_pred = [], []
    with torch.no_grad():
        for batch in iterator:
            x, x_seqlens, _words, decoder_inputs, y, _y_seqlens, _prons = batch
            x = x.to(device)
            decoder_inputs = decoder_inputs.to(device)

            # teacher forcing is suppressed: the model decodes from "<s>" on its own.
            _, y_hat = model(x, x_seqlens, decoder_inputs, False, dec_maxlen)

            gold_rows = y.to('cpu').numpy().tolist()
            pred_rows = y_hat.to('cpu').numpy().tolist()
            for gold_ids, pred_ids in zip(gold_rows, pred_rows):
                Y_true.append(convert_ids_to_phonemes(gold_ids, idx2p))
                Y_pred.append(convert_ids_to_phonemes(pred_ids, idx2p))

    # calc per.
    per, num_errors = calc_per(Y_true, Y_pred)
    print("per: %.2f" % per, "num errors: ", num_errors)

    # Dump each reference with its prediction right below it for inspection.
    with open("result", "w") as fout:
        for y_true, y_pred in zip(Y_true, Y_pred):
            fout.write(" ".join(y_true) + "\n" + " ".join(y_pred) + "\n\n")

    return per
            

# Train & Evaluate

In [241]:
# Wrap the train and eval splits as datasets.
train_dataset = G2pDataset(train_words, train_prons)
eval_dataset = G2pDataset(eval_words, eval_prons)

# `pad` is the collate function that zero-pads each batch. Training batches
# are reshuffled; the eval loader keeps the dataset order.
train_iter = data.DataLoader(train_dataset, batch_size=hp.batch_size, shuffle=True, collate_fn=pad)
eval_iter = data.DataLoader(eval_dataset, batch_size=hp.batch_size, shuffle=False, collate_fn=pad)


In [242]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [243]:
# Assemble the seq2seq model from its two halves and move it to the device.
encoder = Encoder(hp.emb_units, hp.hidden_units)
decoder = Decoder(hp.emb_units, hp.hidden_units)
model = Net(encoder, decoder)
model.to(device)

optimizer = optim.Adam(model.parameters(), lr = hp.lr)
# Id 0 is "<pad>" (the value `pad` fills with), so padded target positions
# are excluded from the loss.
criterion = nn.CrossEntropyLoss(ignore_index=0)

# One training pass followed by one evaluation on the eval split per epoch.
for epoch in range(1, hp.num_epochs+1):
    print(f"\nepoch: {epoch}")
    train(model, train_iter, optimizer, criterion, device)
    eval(model, eval_iter, device, hp.dec_maxlen)


epoch: 1
step: 100, loss: 2.729764461517334
step: 200, loss: 2.06953763961792
step: 300, loss: 1.6415880918502808
step: 400, loss: 1.3378574848175049
step: 500, loss: 1.2205088138580322
step: 600, loss: 0.987713098526001
step: 700, loss: 0.9300493597984314
per: 0.39 num errors:  30495

epoch: 2
step: 100, loss: 0.862144947052002
step: 200, loss: 0.8274500370025635
step: 300, loss: 0.7277541160583496
step: 400, loss: 0.8145508766174316
step: 500, loss: 0.6055648922920227
step: 600, loss: 0.6670782566070557
step: 700, loss: 0.7308872938156128
per: 0.30 num errors:  23436

epoch: 3
step: 100, loss: 0.6280044913291931
step: 200, loss: 0.6292356848716736
step: 300, loss: 0.6315799951553345
step: 400, loss: 0.6083210110664368
step: 500, loss: 0.5998032093048096
step: 600, loss: 0.634599506855011
step: 700, loss: 0.6139586567878723
per: 0.25 num errors:  19908

epoch: 4
step: 100, loss: 0.5246961116790771
step: 200, loss: 0.4894694685935974
step: 300, loss: 0.5112945437431335
step: 400, loss

# Inference

In [244]:
# Build the loader for the held-out test split (dataset order kept, padded by `pad`).
test_dataset = G2pDataset(test_words, test_prons)
test_iter = data.DataLoader(test_dataset, batch_size=hp.batch_size, shuffle=False, collate_fn=pad)

In [245]:
# Score the trained model on the test split; this also overwrites the "result" file.
eval(model, test_iter, device, hp.dec_maxlen)

per: 0.18 num errors:  14045


0.18

Check the results.

In [246]:
# Show the last 100 lines of the "result" file written by eval(): within
# each pair, the first line is the ground truth and the second the prediction.
open('result', 'r').read().splitlines()[-100:]

['',
 'G L IY1 M D',
 'G L IY1 M D',
 '',
 'P EY1 D AH0 N',
 'P EY1 D AH0 N',
 '',
 'B L UW1 N AH0 S',
 'B L UW1 N AH0 S',
 '',
 'HH OW1 L B R UH0 K S',
 'HH OW1 L B R UH2 K S',
 '',
 'B AE1 R IH0 S T ER0 Z',
 'B AE1 R IH0 S T ER0 Z',
 '',
 'P EH1 L T',
 'P EH1 L T',
 '',
 'M AA1 R K AH0 L',
 'M AA1 R K AH0 L',
 '',
 'F EY1 G ER0 S T R AH0 M',
 'F EY1 G ER0 S T R AH0 M',
 '',
 'P EH1 R AH0 SH UW2 T',
 'P EH1 R AH0 CH UW0 T',
 '',
 'B EH2 L W OW0 M IY1 N IY0',
 'B EH0 L UW1 M IY0 N IY0',
 '',
 'L AA1 HH AH0 V IH0 CH',
 'L AH0 SH OW1 IH0 K S',
 '',
 'F EY1 G AH0 N',
 'F AE1 G AH0 N',
 '',
 'P IH1 S ER0 EH0 K',
 'P IH0 S AA1 R EH0 K',
 '',
 'R IY1 D ER0 M AH0 N',
 'R IY1 D ER0 M AH0 N',
 '',
 'K AA1 K AH0 T UW2 Z',
 'K AH0 K EY1 T OW0 Z',
 '',
 'R IY0 B AH1 F IH0 NG',
 'R IH0 F AH1 B IH0 NG',
 '',
 'S AW1 TH D AW2 N',
 'S AW1 TH D AW2 N',
 '',
 'B AE1 L AH0 N T R EY2',
 'B AE1 L AH0 N T R EY2',
 '',
 'S L OW1 P S',
 'S L OW1 P S',
 '',
 'V AE1 N D ER0 V L IY2 T',
 'V AE1 N D ER0 V L AY2 T