# Thai to English Transliteration with Seq2Seq model

In [458]:
import time
import sys
import os
import random
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch import optim
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
from torchvision import transforms, utils
import numpy as np
from tqdm.auto import tqdm
from collections import Counter
from matplotlib import pyplot as plt
from collections import OrderedDict

SEED = 0
np.random.seed(SEED)
random.seed(SEED)
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True


In [2]:
# Check if GPUs are in the machine, otherwise assign device as CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

device(type='cuda', index=0)

The csv file contains two columns indicates Thai text and its corresponding English tranliteration

In [3]:
DATA_PATH = '../dataset/data.csv'
!head $DATA_PATH

กองพันทหารปืนใหญ่	kongphanthahanpuenyai
วิฑูรย์	withun
เมตาบอลิสม	metabolisom
บ้านหนองเลา	bannonglao
อายุษ	ayut
แทมปา	thaempa
ประเทศกรีนแลนด์	prathetkrinlaen
พรรคคองเกรส	phakkhongkerot
การสูบ	kansup
บ้านเทพพยัคฆ์ใต้	banthepphayaktai


In [757]:

def load_data(data_path):
    with open(data_path, 'r', encoding='utf-8-sig') as f:
        lines = f.read().split('\n')

    input_texts = []
    target_texts = []
    for line in lines:
        line = line.replace(u'\xa0',' ')
        input_text, target_text = line.split('\t')
        input_texts.append(input_text)
        target_texts.append(target_text)

    return input_texts, target_texts


In [758]:
%time input_texts, target_texts = load_data(DATA_PATH)


CPU times: user 555 ms, sys: 171 ms, total: 726 ms
Wall time: 724 ms


In [759]:
# Define special characters
UNK_token = '<UNK>'
PAD_token = '<PAD>'
START_token = '<start>'
END_token = '<end>'
MAX_LENGTH = 60

class Language:
    def __init__(self, name, is_input=False):
        self.name = name
        self.characters = set()
        self.n_chars = 0
        self.char2index = {}
        self.index2char = {}

        if is_input == True:
            self.index2char = { 0: PAD_token, 1: UNK_token, 2: START_token, 3: END_token }
            self.char2index = { ch:i for i, ch in self.index2char.items() } #reverse dictionary
            self.n_chars = 4
        else:
            self.index2char = { 0: PAD_token, 1: START_token, 2: END_token }
            self.char2index = { ch:i for i, ch in self.index2char.items() } #reverse dictionary
            self.n_chars = 3

    def addText(self, text):
        for character in text:
            self.addCharacter(character)
    
    def addCharacter(self, character):
        if character not in self.char2index.keys():
            self.char2index[character] = self.n_chars
            self.index2char[self.n_chars] = character
            self.n_chars += 1
            
            
def indexesFromText(lang, text):
    """returns indexes for all character given the text in the specified language"""
    return [lang.char2index[char] for char in text]

def tensorFromText(lang, text):
    """construct a tensor given the text in the specified language"""
    indexes = indexesFromText(lang, text)
    indexes.append(lang.char2index[END_token])
    
    no_padded_seq_length = len(indexes) # Number of characters in the text (including <END> token)
    # Add padding token to make all tensors in the same length
    for i in range(len(indexes), MAX_LENGTH): # padding
        indexes.append(lang.char2index[PAD_token])
        
    return torch.tensor(indexes, dtype=torch.long).to(device), no_padded_seq_length

def filterPair(p1, p2):
    """filter for the pair the both texts has length less than `MAX_LENGTH`"""
    return len(p1) < MAX_LENGTH and len(p2) < MAX_LENGTH

def tensorsFromPair(pair, lang1, lang2):
    """construct two tensors from a pair of source and target text specified by source and target language"""
    input_tensor, input_length = tensorFromText(lang1, pair[0])
    target_tensor, target_length = tensorFromText(lang2, pair[1])
    return input_tensor, target_tensor, input_length, target_length



class ToTensor(object):
    """Convert ndarrays in sample to Tensors."""

    def __call__(self, sample):
        input_text, target_text, lang_th, lang_th_romanized = sample['input_text'], sample['target_text'], \
                                                              sample['lang_th'], sample['lang_th_romanized']

        input_tensor, target_tensor, input_length, target_length = tensorsFromPair([input_text, target_text], 
                                                                                   lang_th, 
                                                                                   lang_th_romanized)
        
        return {
                'input_text': input_text,
                'target_text': target_text,
                'input_length': input_length,
                'target_length': target_length,
                'input_tensor': input_tensor,
                'target_tensor': target_tensor
               }
    
    
class ThaiRomanizationDataset(Dataset):
    """Thai Romanization Dataset class"""
    def __init__(self, 
                 data_path=DATA_PATH, 
                 transform=transforms.Compose([ ToTensor() ])):

        input_texts, target_texts = load_data(data_path)
        
        self.input_texts = input_texts
        self.target_texts = target_texts
        self.transform = transform
        self.lang_th = None
        self.lang_th_romanized = None
        self.counter = Counter()
        self.pairs = []
        self.prepareData()

    def prepareData(self):
        self.lang_th = Language('th', is_input=True)
        self.lang_th_romanized = Language('th_romanized', is_input=False)
        for i in range(len(self.input_texts)):
            
            input_text = self.input_texts[i]
            target_text = self.target_texts[i]
            
            # Count the number of input and target sequences with length `x`
            self.counter.update({ 
                                  'len_input_{}'.format(len(input_text)): 1, 
                                  'len_target_{}'.format(len(target_text)): 1 
                                })
            
            if filterPair(input_text, target_text):
                self.pairs.append((input_text, target_text))
                self.lang_th.addText(input_text)
                self.lang_th_romanized.addText(target_text)    

    def __len__(self):
        return len(self.pairs)
    
    def __getitem__(self, idx):
        
        sample = dict()
        sample['input_text'] = self.pairs[idx][0]
        sample['target_text'] = self.pairs[idx][1]
        
        sample['lang_th'] = self.lang_th
        sample['lang_th_romanized'] = self.lang_th_romanized

        if self.transform:
            sample = self.transform(sample)

        return sample

In [760]:
thai_romanization_dataset = ThaiRomanizationDataset()

In [761]:
thai_romanization_dataset.lang_th.index2char.items()

dict_items([(0, '<PAD>'), (1, '<UNK>'), (2, '<start>'), (3, '<end>'), (4, 'ก'), (5, 'อ'), (6, 'ง'), (7, 'พ'), (8, 'ั'), (9, 'น'), (10, 'ท'), (11, 'ห'), (12, 'า'), (13, 'ร'), (14, 'ป'), (15, 'ื'), (16, 'ใ'), (17, 'ญ'), (18, '่'), (19, 'ว'), (20, 'ิ'), (21, 'ฑ'), (22, 'ู'), (23, 'ย'), (24, '์'), (25, 'เ'), (26, 'ม'), (27, 'ต'), (28, 'บ'), (29, 'ล'), (30, 'ส'), (31, '้'), (32, 'ุ'), (33, 'ษ'), (34, 'แ'), (35, 'ะ'), (36, 'ศ'), (37, 'ี'), (38, 'ด'), (39, 'ค'), (40, 'ฆ'), (41, 'โ'), (42, '๋'), (43, 'ช'), (44, 'ไ'), (45, 'จ'), (46, 'ภ'), (47, 'ณ'), (48, 'ำ'), (49, 'ฝ'), (50, 'ข'), (51, 'ผ'), (52, 'ฒ'), (53, 'ซ'), (54, 'ธ'), (55, '.'), (56, ' '), (57, 'ถ'), (58, 'ฐ'), (59, 'ฏ'), (60, 'ฟ'), (61, 'ึ'), (62, '็'), (63, 'ฮ'), (64, '๊'), (65, 'ฉ'), (66, 'ฎ'), (67, '2'), (68, 'ฌ'), (69, 'ฬ'), (70, '1'), (71, '4'), (72, '8'), (73, '3'), (74, 'ฯ'), (75, 'ๆ'), (76, '5'), (77, '๙'), (78, '7'), (79, '0'), (80, 'ฺ'), (81, '6'), (82, '9'), (83, 'ฤ'), (84, '-'), (85, 'ฅ'), (86, 'ๅ'), (87, 'ฃ'), (88, 'ฦ'), (

In [762]:
thai_romanization_dataset.lang_th_romanized.index2char.items()

dict_items([(0, '<PAD>'), (1, '<start>'), (2, '<end>'), (3, 'k'), (4, 'o'), (5, 'n'), (6, 'g'), (7, 'p'), (8, 'h'), (9, 'a'), (10, 't'), (11, 'u'), (12, 'e'), (13, 'y'), (14, 'i'), (15, 'w'), (16, 'm'), (17, 'b'), (18, 'l'), (19, 's'), (20, 'r'), (21, 'd'), (22, 'c'), (23, 'f'), (24, '-'), (25, ' '), (26, '2'), (27, '1'), (28, '4'), (29, '8'), (30, '3'), (31, '5'), (32, '7'), (33, '0'), (34, '6'), (35, '9'), (36, '"'), (37, '!'), (38, '('), (39, ')')])

## Seq2Seq Model architecture

## 1. Encoder

Encoder 
    - Embedding layer :(vocaburay_size, embedding_size) 
        Input: (batch_size, sequence_length)
        Output: (batch_size, sequence_length, embebeding_size)
      
    - Bi-LSTM layer : (input_size, hidden_size, num_layers, batch_first=True)
        Input: (input=(batch_size, seq_len, embebeding_size),  hidden)
        Output: (output=(batch_size, seq_len, hidden_size),
                 (h_n, c_n))
     
     
__Steps:__

1. Receives a batch of source sequences (batch_size, MAX_LENGTH) and a 1-D array of the length for each sequence (batch_size).
     
2. Sort sequences in the batch by sequence length (number of tokens in the sequence where <PAD> token is excluded).

3. Feed the batch of sorted sequences into the Embedding Layer to maps source character indices into vectors. (batch_size,  sequence_length, embebeding_size)

4. Use `pack_padded_sequence` to let LSTM packed input with same length at time step $t$ together. This will reduce time required for training by avoid feeding `<PAD>` token to the LSTMs.


5. Returns LSTM outputs in the unsorted order, and the LSTM hidden state vectors.
     

In [763]:


class Encoder(nn.Module):
    
    def __init__(self, vocabulary_size, embedding_size, hidden_size, dropout=0.5):
        """Constructor"""
        super(Encoder, self).__init__()
        
        self.hidden_size = hidden_size
        self.character_embedding = nn.Embedding(vocabulary_size, embedding_size)
        self.lstm = nn.LSTM(input_size=embedding_size, 
                            hidden_size=hidden_size // 2, 
                            bidirectional=True,
                            batch_first=True)
        
        self.dropout = nn.Dropout(dropout)


    def forward(self, sequences, sequences_lengths):
        batch_size = sequences.size(0)
        self.hidden = self.init_hidden(batch_size) # batch_size

        # sequences :(batch_size, sequence_length=MAX_LENGTH)
        # sequences_lengths: (batch_size)  # an 1-D indicating length of each sequence (excluded <PAD> token) in `seq`
        
        # 1. Firstly we sort `sequences_lengths` according to theirs values and keep list of indexes to perform sorting
        sequences_lengths = np.sort(sequences_lengths)[::-1] # sort in ascending order and reverse it
        index_sorted = np.argsort(-sequences_lengths) # use negation in sort in descending order
        index_unsort = np.argsort(index_sorted) # to unsorted sequence
        
        
        # 2. Then, we change position of sequence in `sequences` according to `index_sorted`
        index_sorted = torch.from_numpy(index_sorted)
        sequences = sequences.index_select(0, index_sorted.to(device))
        
        # 3. Feed to Embedding Layer
        
        sequences = self.character_embedding(sequences)
        sequences = self.dropout(sequences)
        
#         print('sequences',sequences.size(), sequences)
            
        # 3. Use function: pack_padded_sequence to let LSTM packed input with same length at time step T together
        
        # Quick fix: Use seq_len.copy(), instead of seq_len to fix `Torch.from_numpy not support negative strides`
        # ndarray.copy() will alocate new memory for numpy array which make it normal, I mean the stride is not negative any more.

        sequences_packed = nn.utils.rnn.pack_padded_sequence(sequences,
                                                             sequences_lengths.copy(),
                                                             batch_first=True)
#         print('sequences_packed', sequences_packed)

        # 4. Feed to LSTM
        sequences_output, self.hidden = self.lstm(sequences_packed, self.hidden)
        
        # 5. Unpack
        sequences_output, _ = nn.utils.rnn.pad_packed_sequence(sequences_output, batch_first=True)

        # 6. Un-sort by length
        index_unsort = torch.from_numpy(index_unsort).to(device)
        sequences_output = sequences_output.index_select(0, Variable(index_unsort))

#         print('hidden shape', self.hidden[0].shape, self.hidden[0], self.hidden[1].shape, self.hidden[1])
        return sequences_output, self.hidden
    
    def init_hidden(self, batch_size):
        h_0 = torch.zeros([2, batch_size, self.hidden_size // 2], requires_grad=True).to(device)
        c_0 = torch.zeros([2, batch_size, self.hidden_size // 2], requires_grad=True).to(device)
        
        return (h_0, c_0)
    
def save_model(epoch, loss):
    torch.save({
        'epoch': epoch,
        'loss': loss,
        'model_state_dict': model.state_dict(),
        'char_to_ix': thai_romanization_dataset.lang_th.char2index,
        'ix_to_char': thai_romanization_dataset.lang_th.index2char,
        'target_char_to_ix': thai_romanization_dataset.lang_th_romanized.char2index,
        'ix_to_target_char':thai_romanization_dataset.lang_th_romanized.index2char
    #            'optimizerE_state_dict': encoder_optimizer.state_dict(),
    #            'optimizerD_state_dict': decoder_optimizer.state_dict(),
    }, "thai2rom-pytorch.attn.best.tar".format(epoch))
    

## Decoder

   
Decoder architecture

    - Embedding layer :(vocabulary_size, embebeding_size)
        Input: (batch_size, sequence_length=1)
        Output: (batch_size, sequence_length=1, embebeding_size)
    - RNN layer :input_size=embebeding_size, hidden_size, num_layers, batch_first=True)
        Input: (input=(batch_size, input_size=embedding_dimension), hidden:tuple=encoder_hidden
        Output: (batch_size, seq_len, hidden_size), (h_n, c_n)
    - Attention Layer: (in_features=hidden_size, out_features=hidden_size, bias=True)
    - Linear Layer: (in_features, out_features=vocabulary_size)
        Input: (batch_size, hidden_size)
        Output: (batch_size, vocabulary_size)
    
    - Softmax layer
        Input: (batch_size, vocabulary_size)
        Output: (batch_size, vocabulary_size)



For the Attention mechanishm in the Decoder, Luong-style attention [[Luong et. al (2015)](https://arxiv.org/abs/1508.04025)] is used. 



__Steps:__

1. Receives a batch of <start> token (batch_size, 1) and a batch of Encoder's hidden state.
     
2. Embed input into vectors.

3. Feed vectors from (2) to the LSTM.

4. Feed the output of LSTM at time step $t_1$ and Encoder output to the Attention Layer.

5. Attention layer, returns weights for Encoder's hidden states in every time step (masked out the time step with <PAD> token), then multiply with Encoder's hidden states to obtain a context vector
    
6. Concatenate both decoder hidden state and the context vector, feed to a linear layer, and return its output.

7. Decoder then returns, final output, decoder's hidden state, attention weights, and context vector at time step $t$

In [764]:

class Attn(nn.Module):
    def __init__(self, method, hidden_size):
        super(Attn, self).__init__()

        self.method = method
        self.hidden_size = hidden_size

        if self.method == 'general':
            self.attn = nn.Linear(self.hidden_size, hidden_size)

        elif self.method == 'concat':
            self.attn = nn.Linear(self.hidden_size * 2, hidden_size)
            self.other = nn.Parameter(torch.FloatTensor(1, hidden_size))

    def forward(self, hidden, encoder_outputs, mask):
        # hidden: B x 1 x h ; 
        # encoder_outputs: B x S x h

        # Calculate energies for each encoder output
        if self.method == 'dot':
            attn_energies = torch.bmm(encoder_outputs, hidden.transpose(1, 2)).squeeze(2)  # B x S
        elif self.method == 'general':
            attn_energies = self.attn(encoder_outputs.view(-1, encoder_outputs.size(-1)))  # (B * S) x h
            attn_energies = torch.bmm(attn_energies.view(*encoder_outputs.size()),
                                      hidden.transpose(1, 2)).squeeze(2)  # B x S
        elif self.method == 'concat':
            attn_energies = self.attn(
                torch.cat((hidden.expand(*encoder_outputs.size()), encoder_outputs), 2))  # B x S x h
            attn_energies = torch.bmm(attn_energies,
                                      self.other.unsqueeze(0).expand(*hidden.size()).transpose(1, 2)).squeeze(2)

        attn_energies = attn_energies.masked_fill(mask == 0, -1e10)

        # Normalize energies to weights in range 0 to 1
        return F.softmax(attn_energies, 1)

class AttentionDecoder(nn.Module): 
    
    def __init__(self, vocabulary_size, embedding_size, hidden_size, dropout=0.5):
        """Constructor"""
        super(AttentionDecoder, self).__init__()
        self.vocabulary_size = vocabulary_size
        self.hidden_size = hidden_size
        self.character_embedding = nn.Embedding(vocabulary_size, embedding_size)
        self.lstm = nn.LSTM(input_size=embedding_size + self.hidden_size,
                            hidden_size=hidden_size,
                            bidirectional=False,
                            batch_first=True)
        
        self.attn = Attn(method="general", hidden_size=self.hidden_size)
        self.linear = nn.Linear(hidden_size * 2, vocabulary_size)
        
        self.dropout = nn.Dropout(dropout)

        
    def forward(self, input, last_hidden, last_context, encoder_outputs, mask):
        """"Defines the forward computation of the decoder"""
        # input: (B, 1) ,
        # last_hidden: (num_layers * num_directions, B, hidden_dim)
        # last_context: (B, 1, hidden_dim)
        # encoder_outputs: (B, S, hidden_dim)
        
        embedded = self.character_embedding(input)
        embedded = self.dropout(embedded)
        
        # embedded: (batch_size, emb_dim)
        rnn_input = torch.cat((embedded, last_context), 2)

        output, hidden = self.lstm(rnn_input, last_hidden)        
        attn_weights = self.attn(output, encoder_outputs, mask)  # B x S
    
        #  context = (B, 1, S) x (B, S, hidden_dim)
        #  context = (B, 1, hidden_dim)
        context = attn_weights.unsqueeze(1).bmm(encoder_outputs)  
        
        output = torch.cat((context.squeeze(1), output.squeeze(1)), 1)
        output = self.linear(output)
        
        return output, hidden, context, attn_weights


## Seq2Seq model

This class encapsulate _Decoder_ and _Encoder_ class.

__Steps:__

1. The input sequcence $X$ is fed into the encoder to receive one hidden state vector.

2. The initial decoder hidden state is set to be the hidden state vector of the encoder

3. Add a batch of `<start>` tokens (batch_size, 1) as the first input $y_1$
    
4. Then, decode within a loop:
    - Inserting the input token $y_t$, previous hidden state, $s_{t-1}$, and the context vector $z$ into the decoder
    - Receiveing a prediction $\hat{y}$ and a new hidden state $s_t$
    - Then, either use teacher forcing to let groundtruth target character as the input for the decoder at time step $t+1$, or let the result from decoder as the input for the next time step.

In [806]:
class Seq2Seq(nn.Module): 

    def __init__(self, encoder, decoder, device):
        super().__init__()
        
        self.encoder = encoder
        self.decoder = decoder
        self.device = device
        self.pad_idx = 0

        assert encoder.hidden_size == decoder.hidden_size
    
    def create_mask(self, source_seq):
        mask = (source_seq != self.pad_idx)
        return mask
        
  
    def forward(self, source_seq, source_seq_len, target_seq, teacher_forcing_ratio = 0.5):
        """
            Parameters:
                - source_seq: (batch_size x MAX_LENGTH) 
                - source_seq_len: (batch_size x 1)
                - target_seq: (batch_size x MAX_LENGTH)

            Returns
        """
        batch_size = source_seq.size(0)
        start_token = thai_romanization_dataset.lang_th_romanized.char2index["<start>"]
        end_token = thai_romanization_dataset.lang_th_romanized.char2index["<end>"]
        max_len = MAX_LENGTH
        target_vocab_size = self.decoder.vocabulary_size

        # init a tensor to store decoder outputs
        outputs = torch.zeros(max_len, batch_size, target_vocab_size).to(device)
        
        if target_seq is None:
            assert teacher_forcing_ratio == 0, "Must be zero during inference"
            inference = True
        else:
            inference = False

    
        # feed mini-batch source sequences into the `Encoder`
        encoder_outputs, encoder_hidden = encoder(source_seq, source_seq_len)

        # create a Tensor of first input for the decoder
        decoder_input = torch.tensor([[start_token] * batch_size]).view(batch_size, 1).to(device)
        
        # Initiate decoder output as the last state encoder's hidden state
        decoder_hidden_0 = torch.cat([encoder_hidden[0][0], encoder_hidden[0][1]], dim=1).unsqueeze(dim=0)
        decoder_hidden_1 = torch.cat([encoder_hidden[1][0], encoder_hidden[1][1]], dim=1).unsqueeze(dim=0)
        decoder_hidden = (decoder_hidden_0, decoder_hidden_1) # (hidden state, cell state)

        # define a context vector
        decoder_context = Variable(torch.zeros(encoder_outputs.size(0), encoder_outputs.size(2))).unsqueeze(1).to(device)
        
        max_source_len = encoder_outputs.size(1)
        mask = self.create_mask(source_seq[:, 0:max_source_len])

        for di in range(max_len):
            decoder_output, decoder_hidden, decoder_context, attn_weights = decoder(decoder_input,
                                                                                    decoder_hidden,
                                                                                    decoder_context,
                                                                                    encoder_outputs,
                                                                                    mask)
            # decoder_output: (batch_size, target_vocab_size)

            topv, topi = decoder_output.topk(1)
            outputs[di] = decoder_output.to(device)

            teacher_force = random.random() < teacher_forcing_ratio

            decoder_input = target_seq[:, di].reshape(batch_size, 1) if teacher_force else topi.detach() 

            if inference and decoder_input == end_token:
                return outputs[:di]
       
        return outputs

Initializae model


In [797]:
SEED = 0
BATCH_SIZE = 256
TRAIN_RATIO = 0.8

N = len(thai_romanization_dataset)

print('Number of samples: ', N)
train_split_idx = int(TRAIN_RATIO * N)

print('split at index:', train_split_idx)
indices = list(range(N))

# Random Split
np.random.seed(SEED)
np.random.shuffle(indices)
train_indices, val_indices = indices[:train_split_idx], indices[train_split_idx:]

print('train_indices', train_indices[0:5])
print('val_indices', val_indices[0:5])

train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(val_indices)
                                   
train_dataset_loader = torch.utils.data.DataLoader(
                                             thai_romanization_dataset,
                                             batch_size=BATCH_SIZE, 
                                             sampler=train_sampler,
                                             num_workers=0)

val_dataset_loader = torch.utils.data.DataLoader(
                                             thai_romanization_dataset,
                                             batch_size=BATCH_SIZE,
                                             shuffle=False,
                                             sampler=valid_sampler,
                                             num_workers=0)


print('Number of train mini-batches', len(train_dataset_loader))
print('Number of val mini-batches', len(val_dataset_loader))


Number of samples:  648206
split at index: 518564
train_indices [118842, 551164, 200228, 93841, 142270]
val_indices [419695, 455015, 88739, 260808, 64788]
Number of train mini-batches 2026
Number of val mini-batches 507


In [807]:
INPUT_DIM = len(thai_romanization_dataset.lang_th.char2index)
OUTPUT_DIM = len(thai_romanization_dataset.lang_th_romanized.char2index)

ENC_EMB_DIM = 128
ENC_HID_DIM = 256
ENC_DROPOUT = 0.5

DEC_EMB_DIM = 128
DEC_HID_DIM = 256
DEC_DROPOUT = 0.5

encoder = Encoder(INPUT_DIM, ENC_EMB_DIM, ENC_HID_DIM).to(device)
decoder = AttentionDecoder(OUTPUT_DIM, DEC_HID_DIM, DEC_HID_DIM).to(device)

model = Seq2Seq(encoder, decoder, device).to(device)


In [808]:
def init_weights(m):
    for name, param in m.named_parameters():
        if 'weight' in name:
            nn.init.normal_(param.data, mean=0, std=0.01)
        else:
            nn.init.constant_(param.data, 0)
            
model.apply(init_weights)

Seq2Seq(
  (encoder): Encoder(
    (character_embedding): Embedding(94, 128)
    (lstm): LSTM(128, 128, batch_first=True, bidirectional=True)
    (dropout): Dropout(p=0.5)
  )
  (decoder): AttentionDecoder(
    (character_embedding): Embedding(40, 256)
    (lstm): LSTM(512, 256, batch_first=True)
    (attn): Attn(
      (attn): Linear(in_features=256, out_features=256, bias=True)
    )
    (linear): Linear(in_features=512, out_features=40, bias=True)
    (dropout): Dropout(p=0.5)
  )
)

In [809]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'The model has {count_parameters(model):,} trainable parameters')

The model has 1,161,256 trainable parameters


In [810]:

learning_rate = 0.001
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss(ignore_index = 0)


## Training

In [811]:
print_loss_every = 500

def train(model, iterator, optimizer, criterion, clip):
    
    model.train()
    
    epoch_loss = 0
    for i, batch in tqdm(enumerate(iterator), total = len(iterator)):
        optimizer.zero_grad()

        source_seq, source_seq_len = batch['input_tensor'], batch['input_length']
        batch_size = source_seq.size(0)
        
        # target_seq: (batch_size , MAX_LENGTH)
        # output: (MAX_LENGTH , batch_size , target_vocab_size)
        target_seq = batch['target_tensor']

        output = model(source_seq, source_seq_len, target_seq)
        
        # target_seq -> (MAX_LENGTH , batch_size)
        target_seq = target_seq.transpose(0, 1)

        # target_seq -> ((MAX_LENGTH - 1) * batch_size)
        target_seq = target_seq[1:].contiguous().view(-1)

        # output -> ((MAX_LENGTH -1) * batch_size, target_vocab_size)        
        output = output[1:].view(-1, output.shape[-1])

        loss = criterion(output, target_seq)
        
        if i % print_loss_every == 0:
            print('Loss ', loss.item())

        loss.backward()
        
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        
        optimizer.step()
        
        epoch_loss += loss.item()
        
        
    return epoch_loss / len(iterator)

In [858]:
def evaluate(model, iterator, criterion):
    
    model.eval()

    epoch_loss = 0
    
    with torch.no_grad():
    
        for i, batch in enumerate(iterator):

            source_seq, source_seq_len = batch['input_tensor'], batch['input_length']
            batch_size = source_seq.size(0)

            # target_seq: (batch_size , MAX_LENGTH)
            # output: (MAX_LENGTH , batch_size , target_vocab_size)
            target_seq = batch['target_tensor']
            output = model(source_seq, source_seq_len, target_seq)
        
            # target_seq -> (MAX_LENGTH , batch_size)
            target_seq = target_seq.transpose(0, 1)

            # target_seq -> ((MAX_LENGTH - 1) * batch_size)
            target_seq = target_seq[1:].contiguous().view(-1)

            # output -> ((MAX_LENGTH -1) * batch_size, target_vocab_size)        
            output = output[1:].view(-1, output.shape[-1])

            loss = criterion(output, target_seq)
            epoch_loss += loss.item()
        
    return epoch_loss / len(iterator)

def inference(model, text):
    model.eval()

    input_seq =  [ch for ch in text] +  ['<end>']
    numericalized = [thai_romanization_dataset.lang_th.char2index[ch] for ch in input_seq] 
    sentence_length = [len(numericalized)]

    tensor = torch.LongTensor(numericalized).view(1, -1).to(device) 
    translation_tensor_logits = model(tensor, sentence_length, None, 0) 
    
    try:
        translation_tensor = torch.argmax(translation_tensor_logits.squeeze(1), 1).cpu().numpy()
        translation_indices = [t for t in translation_tensor]
        translation = [thai_romanization_dataset.lang_th_romanized.index2char[t] for t in translation_tensor]
    except:
        translation_indices = [0]
        translation = ['<pad>']
    return ''.join(translation), translation_indices

def show_inference_example(model, input_texts, target_texts):
    for index, input_text in enumerate(input_texts):
        prediction, indices = inference(model, input_text)
        print('groundtruth: {}'.format(target_texts[index]))
        print(' prediction: {} {}\n'.format(prediction, indices))
    

In [813]:

def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs

In [814]:
N_EPOCHS = 50
CLIP = 5

best_valid_loss = float('inf')


for epoch in range(N_EPOCHS):
    model.train()
    start_time = time.time()
    
    train_loss = train(model, train_dataset_loader, optimizer, criterion, CLIP)
    valid_loss = evaluate(model, val_dataset_loader, criterion)
    show_inference_example(model,
                           thai_romanization_dataset.input_texts[5000:5010],
                           thai_romanization_dataset.target_texts[5000:5010])
    end_time = time.time()
    
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        save_model(epoch, best_valid_loss)
    print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. PPL: {math.exp(valid_loss):7.3f}')

HBox(children=(IntProgress(value=0, max=2026), HTML(value='')))

Loss  3.6888835430145264
Loss  2.2496068477630615
Loss  1.0013138055801392
Loss  0.7338398098945618
Loss  0.6599687337875366
input: บ้านปางสามัคคี [28, 31, 12, 9, 14, 12, 6, 30, 12, 26, 8, 39, 39, 37, 3]
groundtruth: banpangsamakkhi
 prediction: hanpangsamakkhi [8, 9, 5, 7, 9, 5, 6, 19, 9, 16, 9, 3, 3, 8, 14]

input: บ้านหนองนาเวียง [28, 31, 12, 9, 11, 9, 5, 6, 9, 12, 25, 19, 37, 23, 6, 3]
groundtruth: bannongnawiang
 prediction: bannongnawiang [17, 9, 5, 5, 4, 5, 6, 5, 9, 15, 14, 9, 5, 6]

input: มณฑลกานสู้ [26, 47, 21, 29, 4, 12, 9, 30, 22, 31, 3]
groundtruth: monthonkansu
 prediction: hanthonkansu [8, 9, 5, 10, 8, 4, 5, 3, 9, 5, 19, 11]

input: บ้านโนนหมากมุ่น [28, 31, 12, 9, 41, 9, 9, 11, 26, 12, 4, 26, 32, 18, 9, 3]
groundtruth: bannonmakmun
 prediction: bannonmakmun [17, 9, 5, 5, 4, 5, 16, 9, 3, 16, 11, 5]

input: นัว [9, 8, 19, 3]
groundtruth: nua
 prediction: hua [8, 11, 9]

input: การเลือกตั้งพิเศษ [4, 12, 13, 25, 29, 15, 5, 4, 27, 8, 31, 6, 7, 20, 25, 36, 33, 3]
groundtruth: 

HBox(children=(IntProgress(value=0, max=2026), HTML(value='')))

Loss  0.6301097869873047
Loss  0.6022240519523621
Loss  0.5084996819496155
Loss  0.4069646894931793
Loss  0.5963090658187866
input: บ้านปางสามัคคี [28, 31, 12, 9, 14, 12, 6, 30, 12, 26, 8, 39, 39, 37, 3]
groundtruth: banpangsamakkhi
 prediction: banpangsamakkhi [17, 9, 5, 7, 9, 5, 6, 19, 9, 16, 9, 3, 3, 8, 14]

input: บ้านหนองนาเวียง [28, 31, 12, 9, 11, 9, 5, 6, 9, 12, 25, 19, 37, 23, 6, 3]
groundtruth: bannongnawiang
 prediction: bannongnawiang [17, 9, 5, 5, 4, 5, 6, 5, 9, 15, 14, 9, 5, 6]

input: มณฑลกานสู้ [26, 47, 21, 29, 4, 12, 9, 30, 22, 31, 3]
groundtruth: monthonkansu
 prediction: monthonkansu [16, 4, 5, 10, 8, 4, 5, 3, 9, 5, 19, 11]

input: บ้านโนนหมากมุ่น [28, 31, 12, 9, 41, 9, 9, 11, 26, 12, 4, 26, 32, 18, 9, 3]
groundtruth: bannonmakmun
 prediction: bannonmakmun [17, 9, 5, 5, 4, 5, 16, 9, 3, 16, 11, 5]

input: นัว [9, 8, 19, 3]
groundtruth: nua
 prediction: hua [8, 11, 9]

input: การเลือกตั้งพิเศษ [4, 12, 13, 25, 29, 15, 5, 4, 27, 8, 31, 6, 7, 20, 25, 36, 33, 3]
groundtruth

HBox(children=(IntProgress(value=0, max=2026), HTML(value='')))

Loss  0.5242689251899719
Loss  0.3869246542453766
Loss  0.3381291329860687
Loss  0.3808336555957794
Loss  0.38804715871810913
input: บ้านปางสามัคคี [28, 31, 12, 9, 14, 12, 6, 30, 12, 26, 8, 39, 39, 37, 3]
groundtruth: banpangsamakkhi
 prediction: banpangsamakkhi [17, 9, 5, 7, 9, 5, 6, 19, 9, 16, 9, 3, 3, 8, 14]

input: บ้านหนองนาเวียง [28, 31, 12, 9, 11, 9, 5, 6, 9, 12, 25, 19, 37, 23, 6, 3]
groundtruth: bannongnawiang
 prediction: bannongnawiang [17, 9, 5, 5, 4, 5, 6, 5, 9, 15, 14, 9, 5, 6]

input: มณฑลกานสู้ [26, 47, 21, 29, 4, 12, 9, 30, 22, 31, 3]
groundtruth: monthonkansu
 prediction: monthonlakansu [16, 4, 5, 10, 8, 4, 5, 18, 9, 3, 9, 5, 19, 11]

input: บ้านโนนหมากมุ่น [28, 31, 12, 9, 41, 9, 9, 11, 26, 12, 4, 26, 32, 18, 9, 3]
groundtruth: bannonmakmun
 prediction: bannonmakmun [17, 9, 5, 5, 4, 5, 16, 9, 3, 16, 11, 5]

input: นัว [9, 8, 19, 3]
groundtruth: nua
 prediction: hua [8, 11, 9]

input: การเลือกตั้งพิเศษ [4, 12, 13, 25, 29, 15, 5, 4, 27, 8, 31, 6, 7, 20, 25, 36, 33, 3]
g

HBox(children=(IntProgress(value=0, max=2026), HTML(value='')))

Loss  0.40508025884628296
Loss  0.42880457639694214
Loss  0.376893013715744
Loss  0.35563620924949646
Loss  0.3870902955532074
input: บ้านปางสามัคคี [28, 31, 12, 9, 14, 12, 6, 30, 12, 26, 8, 39, 39, 37, 3]
groundtruth: banpangsamakkhi
 prediction: banpangsamakkhi [17, 9, 5, 7, 9, 5, 6, 19, 9, 16, 9, 3, 3, 8, 14]

input: บ้านหนองนาเวียง [28, 31, 12, 9, 11, 9, 5, 6, 9, 12, 25, 19, 37, 23, 6, 3]
groundtruth: bannongnawiang
 prediction: bannongnawiang [17, 9, 5, 5, 4, 5, 6, 5, 9, 15, 14, 9, 5, 6]

input: มณฑลกานสู้ [26, 47, 21, 29, 4, 12, 9, 30, 22, 31, 3]
groundtruth: monthonkansu
 prediction: monthonkansu [16, 4, 5, 10, 8, 4, 5, 3, 9, 5, 19, 11]

input: บ้านโนนหมากมุ่น [28, 31, 12, 9, 41, 9, 9, 11, 26, 12, 4, 26, 32, 18, 9, 3]
groundtruth: bannonmakmun
 prediction: bannonmakmun [17, 9, 5, 5, 4, 5, 16, 9, 3, 16, 11, 5]

input: นัว [9, 8, 19, 3]
groundtruth: nua
 prediction: hua [8, 11, 9]

input: การเลือกตั้งพิเศษ [4, 12, 13, 25, 29, 15, 5, 4, 27, 8, 31, 6, 7, 20, 25, 36, 33, 3]
groundtru

HBox(children=(IntProgress(value=0, max=2026), HTML(value='')))

Loss  0.36941730976104736
Loss  0.5102496147155762
Loss  0.39104074239730835
Loss  0.3989707827568054
Loss  0.3984375
input: บ้านปางสามัคคี [28, 31, 12, 9, 14, 12, 6, 30, 12, 26, 8, 39, 39, 37, 3]
groundtruth: banpangsamakkhi
 prediction: banpangsamakkhi [17, 9, 5, 7, 9, 5, 6, 19, 9, 16, 9, 3, 3, 8, 14]

input: บ้านหนองนาเวียง [28, 31, 12, 9, 11, 9, 5, 6, 9, 12, 25, 19, 37, 23, 6, 3]
groundtruth: bannongnawiang
 prediction: bannongnawiang [17, 9, 5, 5, 4, 5, 6, 5, 9, 15, 14, 9, 5, 6]

input: มณฑลกานสู้ [26, 47, 21, 29, 4, 12, 9, 30, 22, 31, 3]
groundtruth: monthonkansu
 prediction: monthonkansu [16, 4, 5, 10, 8, 4, 5, 3, 9, 5, 19, 11]

input: บ้านโนนหมากมุ่น [28, 31, 12, 9, 41, 9, 9, 11, 26, 12, 4, 26, 32, 18, 9, 3]
groundtruth: bannonmakmun
 prediction: bannonmakmun [17, 9, 5, 5, 4, 5, 16, 9, 3, 16, 11, 5]

input: นัว [9, 8, 19, 3]
groundtruth: nua
 prediction: hua [8, 11, 9]

input: การเลือกตั้งพิเศษ [4, 12, 13, 25, 29, 15, 5, 4, 27, 8, 31, 6, 7, 20, 25, 36, 33, 3]
groundtruth: kanlu

HBox(children=(IntProgress(value=0, max=2026), HTML(value='')))

Loss  0.44033315777778625
Loss  0.41107574105262756
Loss  0.4545617699623108
Loss  0.448952317237854
Loss  0.3322313129901886
input: บ้านปางสามัคคี [28, 31, 12, 9, 14, 12, 6, 30, 12, 26, 8, 39, 39, 37, 3]
groundtruth: banpangsamakkhi
 prediction: banpangsamakkhi [17, 9, 5, 7, 9, 5, 6, 19, 9, 16, 9, 3, 3, 8, 14]

input: บ้านหนองนาเวียง [28, 31, 12, 9, 11, 9, 5, 6, 9, 12, 25, 19, 37, 23, 6, 3]
groundtruth: bannongnawiang
 prediction: bannongnawiang [17, 9, 5, 5, 4, 5, 6, 5, 9, 15, 14, 9, 5, 6]

input: มณฑลกานสู้ [26, 47, 21, 29, 4, 12, 9, 30, 22, 31, 3]
groundtruth: monthonkansu
 prediction: monthonkansu [16, 4, 5, 10, 8, 4, 5, 3, 9, 5, 19, 11]

input: บ้านโนนหมากมุ่น [28, 31, 12, 9, 41, 9, 9, 11, 26, 12, 4, 26, 32, 18, 9, 3]
groundtruth: bannonmakmun
 prediction: bannonmakmun [17, 9, 5, 5, 4, 5, 16, 9, 3, 16, 11, 5]

input: นัว [9, 8, 19, 3]
groundtruth: nua
 prediction: hua [8, 11, 9]

input: การเลือกตั้งพิเศษ [4, 12, 13, 25, 29, 15, 5, 4, 27, 8, 31, 6, 7, 20, 25, 36, 33, 3]
groundtrut

HBox(children=(IntProgress(value=0, max=2026), HTML(value='')))

Loss  0.3360288739204407
Loss  0.3428659439086914
Loss  0.32382258772850037
Loss  0.34376171231269836
Loss  0.30948951840400696
input: บ้านปางสามัคคี [28, 31, 12, 9, 14, 12, 6, 30, 12, 26, 8, 39, 39, 37, 3]
groundtruth: banpangsamakkhi
 prediction: banpangsamakkhi [17, 9, 5, 7, 9, 5, 6, 19, 9, 16, 9, 3, 3, 8, 14]

input: บ้านหนองนาเวียง [28, 31, 12, 9, 11, 9, 5, 6, 9, 12, 25, 19, 37, 23, 6, 3]
groundtruth: bannongnawiang
 prediction: bannongnawiang [17, 9, 5, 5, 4, 5, 6, 5, 9, 15, 14, 9, 5, 6]

input: มณฑลกานสู้ [26, 47, 21, 29, 4, 12, 9, 30, 22, 31, 3]
groundtruth: monthonkansu
 prediction: monthonkansu [16, 4, 5, 10, 8, 4, 5, 3, 9, 5, 19, 11]

input: บ้านโนนหมากมุ่น [28, 31, 12, 9, 41, 9, 9, 11, 26, 12, 4, 26, 32, 18, 9, 3]
groundtruth: bannonmakmun
 prediction: bannonmakmun [17, 9, 5, 5, 4, 5, 16, 9, 3, 16, 11, 5]

input: นัว [9, 8, 19, 3]
groundtruth: nua
 prediction: nua [5, 11, 9]

input: การเลือกตั้งพิเศษ [4, 12, 13, 25, 29, 15, 5, 4, 27, 8, 31, 6, 7, 20, 25, 36, 33, 3]
groundtr

HBox(children=(IntProgress(value=0, max=2026), HTML(value='')))

Loss  0.32041195034980774
Loss  0.30718448758125305
Loss  0.2989685535430908
Loss  0.39838504791259766
Loss  0.4312022924423218
input: บ้านปางสามัคคี [28, 31, 12, 9, 14, 12, 6, 30, 12, 26, 8, 39, 39, 37, 3]
groundtruth: banpangsamakkhi
 prediction: banpangsamakkhi [17, 9, 5, 7, 9, 5, 6, 19, 9, 16, 9, 3, 3, 8, 14]

input: บ้านหนองนาเวียง [28, 31, 12, 9, 11, 9, 5, 6, 9, 12, 25, 19, 37, 23, 6, 3]
groundtruth: bannongnawiang
 prediction: bannongnawiang [17, 9, 5, 5, 4, 5, 6, 5, 9, 15, 14, 9, 5, 6]

input: มณฑลกานสู้ [26, 47, 21, 29, 4, 12, 9, 30, 22, 31, 3]
groundtruth: monthonkansu
 prediction: monthonkansu [16, 4, 5, 10, 8, 4, 5, 3, 9, 5, 19, 11]

input: บ้านโนนหมากมุ่น [28, 31, 12, 9, 41, 9, 9, 11, 26, 12, 4, 26, 32, 18, 9, 3]
groundtruth: bannonmakmun
 prediction: bannonmakmun [17, 9, 5, 5, 4, 5, 16, 9, 3, 16, 11, 5]

input: นัว [9, 8, 19, 3]
groundtruth: nua
 prediction: nua [5, 11, 9]

input: การเลือกตั้งพิเศษ [4, 12, 13, 25, 29, 15, 5, 4, 27, 8, 31, 6, 7, 20, 25, 36, 33, 3]
groundtr

HBox(children=(IntProgress(value=0, max=2026), HTML(value='')))

Loss  0.31754180788993835
Loss  0.35052692890167236
Loss  0.31456610560417175
Loss  0.3781377375125885
Loss  0.30663156509399414
input: บ้านปางสามัคคี [28, 31, 12, 9, 14, 12, 6, 30, 12, 26, 8, 39, 39, 37, 3]
groundtruth: banpangsamakkhi
 prediction: banpangsamakkhi [17, 9, 5, 7, 9, 5, 6, 19, 9, 16, 9, 3, 3, 8, 14]

input: บ้านหนองนาเวียง [28, 31, 12, 9, 11, 9, 5, 6, 9, 12, 25, 19, 37, 23, 6, 3]
groundtruth: bannongnawiang
 prediction: bannongnawiang [17, 9, 5, 5, 4, 5, 6, 5, 9, 15, 14, 9, 5, 6]

input: มณฑลกานสู้ [26, 47, 21, 29, 4, 12, 9, 30, 22, 31, 3]
groundtruth: monthonkansu
 prediction: monthonkansu [16, 4, 5, 10, 8, 4, 5, 3, 9, 5, 19, 11]

input: บ้านโนนหมากมุ่น [28, 31, 12, 9, 41, 9, 9, 11, 26, 12, 4, 26, 32, 18, 9, 3]
groundtruth: bannonmakmun
 prediction: bannonmakmun [17, 9, 5, 5, 4, 5, 16, 9, 3, 16, 11, 5]

input: นัว [9, 8, 19, 3]
groundtruth: nua
 prediction: nua [5, 11, 9]

input: การเลือกตั้งพิเศษ [4, 12, 13, 25, 29, 15, 5, 4, 27, 8, 31, 6, 7, 20, 25, 36, 33, 3]
groundt

HBox(children=(IntProgress(value=0, max=2026), HTML(value='')))

Loss  0.37358129024505615
Loss  0.378834992647171
Loss  0.3593375086784363
Loss  0.36714065074920654
Loss  0.3307659924030304
input: บ้านปางสามัคคี [28, 31, 12, 9, 14, 12, 6, 30, 12, 26, 8, 39, 39, 37, 3]
groundtruth: banpangsamakkhi
 prediction: banpangsamakkhi [17, 9, 5, 7, 9, 5, 6, 19, 9, 16, 9, 3, 3, 8, 14]

input: บ้านหนองนาเวียง [28, 31, 12, 9, 11, 9, 5, 6, 9, 12, 25, 19, 37, 23, 6, 3]
groundtruth: bannongnawiang
 prediction: bannongnawiang [17, 9, 5, 5, 4, 5, 6, 5, 9, 15, 14, 9, 5, 6]

input: มณฑลกานสู้ [26, 47, 21, 29, 4, 12, 9, 30, 22, 31, 3]
groundtruth: monthonkansu
 prediction: monthonkansu [16, 4, 5, 10, 8, 4, 5, 3, 9, 5, 19, 11]

input: บ้านโนนหมากมุ่น [28, 31, 12, 9, 41, 9, 9, 11, 26, 12, 4, 26, 32, 18, 9, 3]
groundtruth: bannonmakmun
 prediction: bannonmakmun [17, 9, 5, 5, 4, 5, 16, 9, 3, 16, 11, 5]

input: นัว [9, 8, 19, 3]
groundtruth: nua
 prediction: nua [5, 11, 9]

input: การเลือกตั้งพิเศษ [4, 12, 13, 25, 29, 15, 5, 4, 27, 8, 31, 6, 7, 20, 25, 36, 33, 3]
groundtrut

HBox(children=(IntProgress(value=0, max=2026), HTML(value='')))

Loss  0.3663148581981659
Loss  0.37442147731781006
Loss  0.2600252628326416
Loss  0.27780547738075256
Loss  0.31423357129096985
input: บ้านปางสามัคคี [28, 31, 12, 9, 14, 12, 6, 30, 12, 26, 8, 39, 39, 37, 3]
groundtruth: banpangsamakkhi
 prediction: banpangsamakkhi [17, 9, 5, 7, 9, 5, 6, 19, 9, 16, 9, 3, 3, 8, 14]

input: บ้านหนองนาเวียง [28, 31, 12, 9, 11, 9, 5, 6, 9, 12, 25, 19, 37, 23, 6, 3]
groundtruth: bannongnawiang
 prediction: bannongnawiang [17, 9, 5, 5, 4, 5, 6, 5, 9, 15, 14, 9, 5, 6]

input: มณฑลกานสู้ [26, 47, 21, 29, 4, 12, 9, 30, 22, 31, 3]
groundtruth: monthonkansu
 prediction: monthonkansu [16, 4, 5, 10, 8, 4, 5, 3, 9, 5, 19, 11]

input: บ้านโนนหมากมุ่น [28, 31, 12, 9, 41, 9, 9, 11, 26, 12, 4, 26, 32, 18, 9, 3]
groundtruth: bannonmakmun
 prediction: bannonmakmun [17, 9, 5, 5, 4, 5, 16, 9, 3, 16, 11, 5]

input: นัว [9, 8, 19, 3]
groundtruth: nua
 prediction: nua [5, 11, 9]

input: การเลือกตั้งพิเศษ [4, 12, 13, 25, 29, 15, 5, 4, 27, 8, 31, 6, 7, 20, 25, 36, 33, 3]
groundtr

HBox(children=(IntProgress(value=0, max=2026), HTML(value='')))

Loss  0.3236123025417328
Loss  0.37696489691734314
Loss  0.2822015583515167
Loss  0.3587813377380371
Loss  0.3410130441188812
input: บ้านปางสามัคคี [28, 31, 12, 9, 14, 12, 6, 30, 12, 26, 8, 39, 39, 37, 3]
groundtruth: banpangsamakkhi
 prediction: banpangsamakkhi [17, 9, 5, 7, 9, 5, 6, 19, 9, 16, 9, 3, 3, 8, 14]

input: บ้านหนองนาเวียง [28, 31, 12, 9, 11, 9, 5, 6, 9, 12, 25, 19, 37, 23, 6, 3]
groundtruth: bannongnawiang
 prediction: bannongnawiang [17, 9, 5, 5, 4, 5, 6, 5, 9, 15, 14, 9, 5, 6]

input: มณฑลกานสู้ [26, 47, 21, 29, 4, 12, 9, 30, 22, 31, 3]
groundtruth: monthonkansu
 prediction: monthonkansu [16, 4, 5, 10, 8, 4, 5, 3, 9, 5, 19, 11]

input: บ้านโนนหมากมุ่น [28, 31, 12, 9, 41, 9, 9, 11, 26, 12, 4, 26, 32, 18, 9, 3]
groundtruth: bannonmakmun
 prediction: bannonmakmun [17, 9, 5, 5, 4, 5, 16, 9, 3, 16, 11, 5]

input: นัว [9, 8, 19, 3]
groundtruth: nua
 prediction: nua [5, 11, 9]

input: การเลือกตั้งพิเศษ [4, 12, 13, 25, 29, 15, 5, 4, 27, 8, 31, 6, 7, 20, 25, 36, 33, 3]
groundtrut

HBox(children=(IntProgress(value=0, max=2026), HTML(value='')))

Loss  0.29775699973106384
Loss  0.4047224819660187
Loss  0.30073168873786926
Loss  0.4221619963645935
Loss  0.3110000491142273
input: บ้านปางสามัคคี [28, 31, 12, 9, 14, 12, 6, 30, 12, 26, 8, 39, 39, 37, 3]
groundtruth: banpangsamakkhi
 prediction: banpangsamakkhi [17, 9, 5, 7, 9, 5, 6, 19, 9, 16, 9, 3, 3, 8, 14]

input: บ้านหนองนาเวียง [28, 31, 12, 9, 11, 9, 5, 6, 9, 12, 25, 19, 37, 23, 6, 3]
groundtruth: bannongnawiang
 prediction: bannongnawiang [17, 9, 5, 5, 4, 5, 6, 5, 9, 15, 14, 9, 5, 6]

input: มณฑลกานสู้ [26, 47, 21, 29, 4, 12, 9, 30, 22, 31, 3]
groundtruth: monthonkansu
 prediction: monthonkansu [16, 4, 5, 10, 8, 4, 5, 3, 9, 5, 19, 11]

input: บ้านโนนหมากมุ่น [28, 31, 12, 9, 41, 9, 9, 11, 26, 12, 4, 26, 32, 18, 9, 3]
groundtruth: bannonmakmun
 prediction: bannonmakmun [17, 9, 5, 5, 4, 5, 16, 9, 3, 16, 11, 5]

input: นัว [9, 8, 19, 3]
groundtruth: nua
 prediction: nua [5, 11, 9]

input: การเลือกตั้งพิเศษ [4, 12, 13, 25, 29, 15, 5, 4, 27, 8, 31, 6, 7, 20, 25, 36, 33, 3]
groundtru

HBox(children=(IntProgress(value=0, max=2026), HTML(value='')))

Loss  0.35722050070762634
Loss  0.31930023431777954
Loss  0.2946299910545349
Loss  0.3571363091468811
Loss  0.29988110065460205
input: บ้านปางสามัคคี [28, 31, 12, 9, 14, 12, 6, 30, 12, 26, 8, 39, 39, 37, 3]
groundtruth: banpangsamakkhi
 prediction: banpangsamakkhi [17, 9, 5, 7, 9, 5, 6, 19, 9, 16, 9, 3, 3, 8, 14]

input: บ้านหนองนาเวียง [28, 31, 12, 9, 11, 9, 5, 6, 9, 12, 25, 19, 37, 23, 6, 3]
groundtruth: bannongnawiang
 prediction: bannongnawiang [17, 9, 5, 5, 4, 5, 6, 5, 9, 15, 14, 9, 5, 6]

input: มณฑลกานสู้ [26, 47, 21, 29, 4, 12, 9, 30, 22, 31, 3]
groundtruth: monthonkansu
 prediction: monthonkansu [16, 4, 5, 10, 8, 4, 5, 3, 9, 5, 19, 11]

input: บ้านโนนหมากมุ่น [28, 31, 12, 9, 41, 9, 9, 11, 26, 12, 4, 26, 32, 18, 9, 3]
groundtruth: bannonmakmun
 prediction: bannonmakmun [17, 9, 5, 5, 4, 5, 16, 9, 3, 16, 11, 5]

input: นัว [9, 8, 19, 3]
groundtruth: nua
 prediction: nua [5, 11, 9]

input: การเลือกตั้งพิเศษ [4, 12, 13, 25, 29, 15, 5, 4, 27, 8, 31, 6, 7, 20, 25, 36, 33, 3]
groundtr

HBox(children=(IntProgress(value=0, max=2026), HTML(value='')))

Loss  0.29727524518966675
Loss  0.27846983075141907
Loss  0.3773447871208191
Loss  0.4190753698348999
Loss  0.370210736989975
input: บ้านปางสามัคคี [28, 31, 12, 9, 14, 12, 6, 30, 12, 26, 8, 39, 39, 37, 3]
groundtruth: banpangsamakkhi
 prediction: banpangsamakkhi [17, 9, 5, 7, 9, 5, 6, 19, 9, 16, 9, 3, 3, 8, 14]

input: บ้านหนองนาเวียง [28, 31, 12, 9, 11, 9, 5, 6, 9, 12, 25, 19, 37, 23, 6, 3]
groundtruth: bannongnawiang
 prediction: bannongnawiang [17, 9, 5, 5, 4, 5, 6, 5, 9, 15, 14, 9, 5, 6]

input: มณฑลกานสู้ [26, 47, 21, 29, 4, 12, 9, 30, 22, 31, 3]
groundtruth: monthonkansu
 prediction: monthonkansu [16, 4, 5, 10, 8, 4, 5, 3, 9, 5, 19, 11]

input: บ้านโนนหมากมุ่น [28, 31, 12, 9, 41, 9, 9, 11, 26, 12, 4, 26, 32, 18, 9, 3]
groundtruth: bannonmakmun
 prediction: bannonmakmun [17, 9, 5, 5, 4, 5, 16, 9, 3, 16, 11, 5]

input: นัว [9, 8, 19, 3]
groundtruth: nua
 prediction: nua [5, 11, 9]

input: การเลือกตั้งพิเศษ [4, 12, 13, 25, 29, 15, 5, 4, 27, 8, 31, 6, 7, 20, 25, 36, 33, 3]
groundtrut

HBox(children=(IntProgress(value=0, max=2026), HTML(value='')))

Loss  0.3112165033817291
Loss  0.32027021050453186
Loss  0.29929471015930176
Loss  0.3692239224910736
Loss  0.324817419052124
input: บ้านปางสามัคคี [28, 31, 12, 9, 14, 12, 6, 30, 12, 26, 8, 39, 39, 37, 3]
groundtruth: banpangsamakkhi
 prediction: banpangsamakkhi [17, 9, 5, 7, 9, 5, 6, 19, 9, 16, 9, 3, 3, 8, 14]

input: บ้านหนองนาเวียง [28, 31, 12, 9, 11, 9, 5, 6, 9, 12, 25, 19, 37, 23, 6, 3]
groundtruth: bannongnawiang
 prediction: bannongnawiang [17, 9, 5, 5, 4, 5, 6, 5, 9, 15, 14, 9, 5, 6]

input: มณฑลกานสู้ [26, 47, 21, 29, 4, 12, 9, 30, 22, 31, 3]
groundtruth: monthonkansu
 prediction: monthonkansu [16, 4, 5, 10, 8, 4, 5, 3, 9, 5, 19, 11]

input: บ้านโนนหมากมุ่น [28, 31, 12, 9, 41, 9, 9, 11, 26, 12, 4, 26, 32, 18, 9, 3]
groundtruth: bannonmakmun
 prediction: bannonmakmun [17, 9, 5, 5, 4, 5, 16, 9, 3, 16, 11, 5]

input: นัว [9, 8, 19, 3]
groundtruth: nua
 prediction: 9ua [35, 11, 9]

input: การเลือกตั้งพิเศษ [4, 12, 13, 25, 29, 15, 5, 4, 27, 8, 31, 6, 7, 20, 25, 36, 33, 3]
groundtru

HBox(children=(IntProgress(value=0, max=2026), HTML(value='')))

Loss  0.35912320017814636
Loss  0.3263883888721466
Loss  0.2888607978820801
Loss  0.38656988739967346
Loss  0.282650887966156
input: บ้านปางสามัคคี [28, 31, 12, 9, 14, 12, 6, 30, 12, 26, 8, 39, 39, 37, 3]
groundtruth: banpangsamakkhi
 prediction: banpangsamakkhi [17, 9, 5, 7, 9, 5, 6, 19, 9, 16, 9, 3, 3, 8, 14]

input: บ้านหนองนาเวียง [28, 31, 12, 9, 11, 9, 5, 6, 9, 12, 25, 19, 37, 23, 6, 3]
groundtruth: bannongnawiang
 prediction: bannongnawiang [17, 9, 5, 5, 4, 5, 6, 5, 9, 15, 14, 9, 5, 6]

input: มณฑลกานสู้ [26, 47, 21, 29, 4, 12, 9, 30, 22, 31, 3]
groundtruth: monthonkansu
 prediction: monthonkansu [16, 4, 5, 10, 8, 4, 5, 3, 9, 5, 19, 11]

input: บ้านโนนหมากมุ่น [28, 31, 12, 9, 41, 9, 9, 11, 26, 12, 4, 26, 32, 18, 9, 3]
groundtruth: bannonmakmun
 prediction: bannonmakmun [17, 9, 5, 5, 4, 5, 16, 9, 3, 16, 11, 5]

input: นัว [9, 8, 19, 3]
groundtruth: nua
 prediction: nua [5, 11, 9]

input: การเลือกตั้งพิเศษ [4, 12, 13, 25, 29, 15, 5, 4, 27, 8, 31, 6, 7, 20, 25, 36, 33, 3]
groundtrut

HBox(children=(IntProgress(value=0, max=2026), HTML(value='')))

Loss  0.354021817445755
Loss  0.3350050747394562
Loss  0.2979723811149597
Loss  0.3013467788696289
Loss  0.35500696301460266
input: บ้านปางสามัคคี [28, 31, 12, 9, 14, 12, 6, 30, 12, 26, 8, 39, 39, 37, 3]
groundtruth: banpangsamakkhi
 prediction: banpangsamakkhi [17, 9, 5, 7, 9, 5, 6, 19, 9, 16, 9, 3, 3, 8, 14]

input: บ้านหนองนาเวียง [28, 31, 12, 9, 11, 9, 5, 6, 9, 12, 25, 19, 37, 23, 6, 3]
groundtruth: bannongnawiang
 prediction: bannongnawiang [17, 9, 5, 5, 4, 5, 6, 5, 9, 15, 14, 9, 5, 6]

input: มณฑลกานสู้ [26, 47, 21, 29, 4, 12, 9, 30, 22, 31, 3]
groundtruth: monthonkansu
 prediction: monthonkansu [16, 4, 5, 10, 8, 4, 5, 3, 9, 5, 19, 11]

input: บ้านโนนหมากมุ่น [28, 31, 12, 9, 41, 9, 9, 11, 26, 12, 4, 26, 32, 18, 9, 3]
groundtruth: bannonmakmun
 prediction: bannonmakmun [17, 9, 5, 5, 4, 5, 16, 9, 3, 16, 11, 5]

input: นัว [9, 8, 19, 3]
groundtruth: nua
 prediction: nua [5, 11, 9]

input: การเลือกตั้งพิเศษ [4, 12, 13, 25, 29, 15, 5, 4, 27, 8, 31, 6, 7, 20, 25, 36, 33, 3]
groundtruth

HBox(children=(IntProgress(value=0, max=2026), HTML(value='')))

Loss  0.3370044231414795
Loss  0.3926723599433899
Loss  0.375838965177536
Loss  0.3750090003013611
Loss  0.328673779964447
input: บ้านปางสามัคคี [28, 31, 12, 9, 14, 12, 6, 30, 12, 26, 8, 39, 39, 37, 3]
groundtruth: banpangsamakkhi
 prediction: banpangsamakkhi [17, 9, 5, 7, 9, 5, 6, 19, 9, 16, 9, 3, 3, 8, 14]

input: บ้านหนองนาเวียง [28, 31, 12, 9, 11, 9, 5, 6, 9, 12, 25, 19, 37, 23, 6, 3]
groundtruth: bannongnawiang
 prediction: bannongnawiang [17, 9, 5, 5, 4, 5, 6, 5, 9, 15, 14, 9, 5, 6]

input: มณฑลกานสู้ [26, 47, 21, 29, 4, 12, 9, 30, 22, 31, 3]
groundtruth: monthonkansu
 prediction: monthonkansu [16, 4, 5, 10, 8, 4, 5, 3, 9, 5, 19, 11]

input: บ้านโนนหมากมุ่น [28, 31, 12, 9, 41, 9, 9, 11, 26, 12, 4, 26, 32, 18, 9, 3]
groundtruth: bannonmakmun
 prediction: bannonmakmun [17, 9, 5, 5, 4, 5, 16, 9, 3, 16, 11, 5]

input: นัว [9, 8, 19, 3]
groundtruth: nua
 prediction: nua [5, 11, 9]

input: การเลือกตั้งพิเศษ [4, 12, 13, 25, 29, 15, 5, 4, 27, 8, 31, 6, 7, 20, 25, 36, 33, 3]
groundtruth: 

HBox(children=(IntProgress(value=0, max=2026), HTML(value='')))

Loss  0.31234925985336304
Loss  0.3727436661720276
Loss  0.29092225432395935
Loss  0.25551044940948486
Loss  0.3458912670612335
input: บ้านปางสามัคคี [28, 31, 12, 9, 14, 12, 6, 30, 12, 26, 8, 39, 39, 37, 3]
groundtruth: banpangsamakkhi
 prediction: banpangsamakkhi [17, 9, 5, 7, 9, 5, 6, 19, 9, 16, 9, 3, 3, 8, 14]

input: บ้านหนองนาเวียง [28, 31, 12, 9, 11, 9, 5, 6, 9, 12, 25, 19, 37, 23, 6, 3]
groundtruth: bannongnawiang
 prediction: bannongnawiang [17, 9, 5, 5, 4, 5, 6, 5, 9, 15, 14, 9, 5, 6]

input: มณฑลกานสู้ [26, 47, 21, 29, 4, 12, 9, 30, 22, 31, 3]
groundtruth: monthonkansu
 prediction: monthonkansu [16, 4, 5, 10, 8, 4, 5, 3, 9, 5, 19, 11]

input: บ้านโนนหมากมุ่น [28, 31, 12, 9, 41, 9, 9, 11, 26, 12, 4, 26, 32, 18, 9, 3]
groundtruth: bannonmakmun
 prediction: bannonmakmun [17, 9, 5, 5, 4, 5, 16, 9, 3, 16, 11, 5]

input: นัว [9, 8, 19, 3]
groundtruth: nua
 prediction: nua [5, 11, 9]

input: การเลือกตั้งพิเศษ [4, 12, 13, 25, 29, 15, 5, 4, 27, 8, 31, 6, 7, 20, 25, 36, 33, 3]
groundtr

HBox(children=(IntProgress(value=0, max=2026), HTML(value='')))

Loss  0.31277844309806824
Loss  0.32897552847862244
Loss  0.2950771749019623
Loss  0.2804759740829468
Loss  0.39502695202827454
input: บ้านปางสามัคคี [28, 31, 12, 9, 14, 12, 6, 30, 12, 26, 8, 39, 39, 37, 3]
groundtruth: banpangsamakkhi
 prediction: banpangsamakkhi [17, 9, 5, 7, 9, 5, 6, 19, 9, 16, 9, 3, 3, 8, 14]

input: บ้านหนองนาเวียง [28, 31, 12, 9, 11, 9, 5, 6, 9, 12, 25, 19, 37, 23, 6, 3]
groundtruth: bannongnawiang
 prediction: bannongnawiang [17, 9, 5, 5, 4, 5, 6, 5, 9, 15, 14, 9, 5, 6]

input: มณฑลกานสู้ [26, 47, 21, 29, 4, 12, 9, 30, 22, 31, 3]
groundtruth: monthonkansu
 prediction: monthonkansu [16, 4, 5, 10, 8, 4, 5, 3, 9, 5, 19, 11]

input: บ้านโนนหมากมุ่น [28, 31, 12, 9, 41, 9, 9, 11, 26, 12, 4, 26, 32, 18, 9, 3]
groundtruth: bannonmakmun
 prediction: bannonmakmun [17, 9, 5, 5, 4, 5, 16, 9, 3, 16, 11, 5]

input: นัว [9, 8, 19, 3]
groundtruth: nua
 prediction: 9ua [35, 11, 9]

input: การเลือกตั้งพิเศษ [4, 12, 13, 25, 29, 15, 5, 4, 27, 8, 31, 6, 7, 20, 25, 36, 33, 3]
groundt

HBox(children=(IntProgress(value=0, max=2026), HTML(value='')))

Loss  0.307510644197464
Loss  0.36824309825897217
Loss  0.32088959217071533
Loss  0.3614412844181061
Loss  0.28250032663345337
input: บ้านปางสามัคคี [28, 31, 12, 9, 14, 12, 6, 30, 12, 26, 8, 39, 39, 37, 3]
groundtruth: banpangsamakkhi
 prediction: banpangsamakkhi [17, 9, 5, 7, 9, 5, 6, 19, 9, 16, 9, 3, 3, 8, 14]

input: บ้านหนองนาเวียง [28, 31, 12, 9, 11, 9, 5, 6, 9, 12, 25, 19, 37, 23, 6, 3]
groundtruth: bannongnawiang
 prediction: bannongnawiang [17, 9, 5, 5, 4, 5, 6, 5, 9, 15, 14, 9, 5, 6]

input: มณฑลกานสู้ [26, 47, 21, 29, 4, 12, 9, 30, 22, 31, 3]
groundtruth: monthonkansu
 prediction: monthonkansu [16, 4, 5, 10, 8, 4, 5, 3, 9, 5, 19, 11]

input: บ้านโนนหมากมุ่น [28, 31, 12, 9, 41, 9, 9, 11, 26, 12, 4, 26, 32, 18, 9, 3]
groundtruth: bannonmakmun
 prediction: bannonmakmun [17, 9, 5, 5, 4, 5, 16, 9, 3, 16, 11, 5]

input: นัว [9, 8, 19, 3]
groundtruth: nua
 prediction: 9ua [35, 11, 9]

input: การเลือกตั้งพิเศษ [4, 12, 13, 25, 29, 15, 5, 4, 27, 8, 31, 6, 7, 20, 25, 36, 33, 3]
groundtr

HBox(children=(IntProgress(value=0, max=2026), HTML(value='')))

Loss  0.27166005969047546
Loss  0.3619347810745239
Loss  0.3666609823703766
Loss  0.29701241850852966
Loss  0.3157553970813751
input: บ้านปางสามัคคี [28, 31, 12, 9, 14, 12, 6, 30, 12, 26, 8, 39, 39, 37, 3]
groundtruth: banpangsamakkhi
 prediction: banpangsamakkhi [17, 9, 5, 7, 9, 5, 6, 19, 9, 16, 9, 3, 3, 8, 14]

input: บ้านหนองนาเวียง [28, 31, 12, 9, 11, 9, 5, 6, 9, 12, 25, 19, 37, 23, 6, 3]
groundtruth: bannongnawiang
 prediction: bannongnawiang [17, 9, 5, 5, 4, 5, 6, 5, 9, 15, 14, 9, 5, 6]

input: มณฑลกานสู้ [26, 47, 21, 29, 4, 12, 9, 30, 22, 31, 3]
groundtruth: monthonkansu
 prediction: monthonkansu [16, 4, 5, 10, 8, 4, 5, 3, 9, 5, 19, 11]

input: บ้านโนนหมากมุ่น [28, 31, 12, 9, 41, 9, 9, 11, 26, 12, 4, 26, 32, 18, 9, 3]
groundtruth: bannonmakmun
 prediction: bannonmakmun [17, 9, 5, 5, 4, 5, 16, 9, 3, 16, 11, 5]

input: นัว [9, 8, 19, 3]
groundtruth: nua
 prediction: 9ua [35, 11, 9]

input: การเลือกตั้งพิเศษ [4, 12, 13, 25, 29, 15, 5, 4, 27, 8, 31, 6, 7, 20, 25, 36, 33, 3]
groundtr

HBox(children=(IntProgress(value=0, max=2026), HTML(value='')))

Loss  0.31817498803138733
Loss  0.2956112325191498
Loss  0.28813546895980835
Loss  0.33040162920951843
Loss  0.35370349884033203
input: บ้านปางสามัคคี [28, 31, 12, 9, 14, 12, 6, 30, 12, 26, 8, 39, 39, 37, 3]
groundtruth: banpangsamakkhi
 prediction: banpangsamakkhi [17, 9, 5, 7, 9, 5, 6, 19, 9, 16, 9, 3, 3, 8, 14]

input: บ้านหนองนาเวียง [28, 31, 12, 9, 11, 9, 5, 6, 9, 12, 25, 19, 37, 23, 6, 3]
groundtruth: bannongnawiang
 prediction: bannongnawiang [17, 9, 5, 5, 4, 5, 6, 5, 9, 15, 14, 9, 5, 6]

input: มณฑลกานสู้ [26, 47, 21, 29, 4, 12, 9, 30, 22, 31, 3]
groundtruth: monthonkansu
 prediction: monthonkansu [16, 4, 5, 10, 8, 4, 5, 3, 9, 5, 19, 11]

input: บ้านโนนหมากมุ่น [28, 31, 12, 9, 41, 9, 9, 11, 26, 12, 4, 26, 32, 18, 9, 3]
groundtruth: bannonmakmun
 prediction: bannonmakmun [17, 9, 5, 5, 4, 5, 16, 9, 3, 16, 11, 5]

input: นัว [9, 8, 19, 3]
groundtruth: nua
 prediction: nua [5, 11, 9]

input: การเลือกตั้งพิเศษ [4, 12, 13, 25, 29, 15, 5, 4, 27, 8, 31, 6, 7, 20, 25, 36, 33, 3]
groundt

HBox(children=(IntProgress(value=0, max=2026), HTML(value='')))

Loss  0.37154972553253174
Loss  0.3275672197341919


KeyboardInterrupt: 

In [791]:
print(thai_romanization_dataset.input_texts[5000:5010])
print(thai_romanization_dataset.target_texts[5000:5010])

['บ้านปางสามัคคี', 'บ้านหนองนาเวียง', 'มณฑลกานสู้', 'บ้านโนนหมากมุ่น', 'นัว', 'การเลือกตั้งพิเศษ', 'มัลกะ', 'ทหารกองหนุน', 'ตำบลวังสรรพรส', 'พร้อย']
['banpangsamakkhi', 'bannongnawiang', 'monthonkansu', 'bannonmakmun', 'nua', 'kanlueaktangphiset', 'manka', 'thahankongnun', 'tambonwangsappharot', 'phroi']


numericalized [19, 20, 21, 22, 13, 23, 24, 3]
sentence_length [8]
1: วิฑูรย์ -> 83454r4iri88h8kak454r4iri88h8k8k454r4iri88h8k8k454r4iri88h8k



## Evaluation on val_set with following metrics:
   
1. F1-score (macro-average) -- Character level

2. Exact Match (EM)

3. Exact Match (EM) - Character level

    

In [845]:
# Functions for model performance evaluation
def precision(pred_chars, target_chars):
    # TP / TP + FP
    pred_chars_multiset = Counter(pred_chars)
    target_chars_multiset = Counter(target_chars)

    overlap = list((pred_chars_multiset & target_chars_multiset).elements())
    n_overlap = len(overlap)

    return n_overlap / max(len(pred_chars), 1)

def recall(pred_chars, target_chars):
    # TP / TP + FN
        
    pred_chars_multiset = Counter(pred_chars)
    target_chars_multiset = Counter(target_chars)

    overlap = list((pred_chars_multiset & target_chars_multiset).elements())
    n_overlap = len(overlap)
    return n_overlap / len(target_chars)

def f1(precision, recall):
    
    return (2  * precision * recall) / (precision + recall)

def em(pred, target):
    if pred == target:
        return 1
    return 0

def em_char(pred_chars, target_chars):
    N_target_chars = len(target_chars)
    N_pred_chars = len(pred_chars)

    score = 0
    for index in range(min(N_pred_chars, N_target_chars)):
        if target_chars[index] == pred_chars[index]:
            score+=1
            
    return score / max(N_target_chars, N_pred_chars)

In [860]:
def evaluate_inference(model, val_indices):
    cumulative_precision = 0
    cumulative_recall = 0
    cumulative_em = 0
    cumulative_em_char = 0
    
    N = len(val_indices)

    epoch_loss = 0
    prediction_results = []
    em_char_score = 0
    for i, val_index in tqdm(enumerate(val_indices), total=N):
        input_text = thai_romanization_dataset.input_texts[val_index]
        target_text = thai_romanization_dataset.target_texts[val_index]

        
        prediction, indices = inference(model, input_text)
        prediction_results.append(prediction)
        if i <= 10:
            print('Example: {}'.format(i+1))
            print('      input: {}'.format(input_text))
            print('groundtruth: {}'.format(target_text))
            print(' prediction: {}\n'.format(prediction))

        pred_chars = [char for char in prediction]
        target_chars = [char for char in target_text]

        cumulative_precision += precision(pred_chars, target_chars)
        cumulative_recall +=  recall(pred_chars, target_chars)
        cumulative_em_char += em_char(pred_chars, target_chars)
        cumulative_em += em(prediction, target_text)

    macro_average_precision = cumulative_precision / N
    macro_average_recall = cumulative_recall /N
    f1_macro_average = f1(macro_average_precision, macro_average_recall) 
    em_score = cumulative_em / N
    em_char_score = cumulative_em_char / N
    print('')
    print('F1 (macro-average) = ', f1_macro_average)
    print('EM = ', em_score)
    print('EM (Character-level) = ', em_char_score)


    return f1_macro_average, em_score, em_char_score, prediction_results

evaluate_inference(model, val_indices)


HBox(children=(IntProgress(value=0, max=129642), HTML(value='')))

Example: 1
      input: ขัตติยา กรีมี
groundtruth: khattiya krimi
 prediction: khattiya krimi

Example: 2
      input: คมสรร บุญทอง
groundtruth: khomsan bunthong
 prediction: khomsan bunthong

Example: 3
      input: ต่อเลขหมาย
groundtruth: tolekmai
 prediction: tolekmai

Example: 4
      input: กิตติศักดิ์ หวังวรวงศ์
groundtruth: kittisak wangworawong
 prediction: mittisak wangworawong

Example: 5
      input: สถาปัตยกรรมศาสตร์
groundtruth: sathapattayakamsat
 prediction: sathapattayakammasat

Example: 6
      input: กิมเจียว แซ่โง้ว
groundtruth: kimchiao sae-ngow
 prediction: kimchiao sae-ngow

Example: 7
      input: เขยื้อน รังสิยีรานนท์
groundtruth: khayuean rangsiyiranon
 prediction: khayuean rangsiyiranon

Example: 8
      input: ขนิษฐา เสมอใจ
groundtruth: khanittha samoechai
 prediction: hhanittha samoechai

Example: 9
      input: กนกกุล หวังบุญผาติ
groundtruth: kanokkun wangbunphati
 prediction: kanokkun wangbunphati

Example: 10
      input: จรัล กิติสาย
groundtruth: charan 

(0.9771169937908014,
 0.7423674426497585,
 0.964700464024947,
 ['khattiya krimi',
  'khomsan bunthong',
  'tolekmai',
  'mittisak wangworawong',
  'sathapattayakammasat',
  'kimchiao sae-ngow',
  'khayuean rangsiyiranon',
  'hhanittha samoechai',
  'kanokkun wangbunphati',
  'charan kitisai',
  '8osai manatsilanon',
  '8esini bamphennorakit',
  'chamrat chan-amphon',
  'kitison sukpradit',
  'chuthamat sikomut',
  'manlaya phasukdi',
  'kanokwan champala',
  'chamnian sawin',
  'chueachan buathet',
  'chakkraphan bunlo',
  'banthakrathiam',
  'mantinan phachonwichaichan',
  'manlaya chiraphatichai',
  'charuni panchaphatthanasiri',
  'khonluatchueam',
  'khanthasisai',
  'kannika iamnoi',
  'kiang saetae',
  'chamrat phumlamchiak',
  'kasem chaowarat',
  'chitti phetsawang',
  'khomsan sangsithawong',
  'kanokwan rangsisena',
  'khwanchit bunmilaptrakun',
  'manya saetia',
  'kopkun rit',
  'chetsada watthanalueang-arun',
  'kittiphong diaotrakun',
  'chaturong chaithiang',
  'morakamo