#### Neural Machine Translation (NMT) with Attention

In [18]:
### Import the necessary packages

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import random

In [19]:
### Create the source vocabulary

# Default word tokens
PAD_token = 0  # Used for padding short number sequences
SOS_token = 1  # Start-of-sequence token
EOS_token = 2  # End-of-sequence token
UNK_token = 3  # Unknown word token

class ASR_Vocab(object):
    def __init__(self, digit_seqs):
        super(ASR_Vocab, self).__init__()
        self.digit2index = {}
        self.digit2count = {}
        self.index2digit = {PAD_token: "PAD", SOS_token: "SOS", EOS_token: "EOS", UNK_token: "UNK"}
        
        # Count SOS, EOS, PAD, UNK
        self.num_tokens  = 4
        self.digit_seqs  = digit_seqs
        
    def dig2idx(self, digit):
        if digit in self.digit2index:
            return self.digit2index[digit]
        else:
            return UNK_token
        
    def idx2dig(self, idx):
        if idx in self.index2digit:
            return self.index2digit[idx]
        else:
            return self.index2digit[UNK_token]
            
    def add_digit(self, digit):
        if digit in self.digit2index:
            self.digit2count[digit] += 1
            
        else:
            self.digit2index[digit] = self.num_tokens
            self.index2digit[self.num_tokens] = digit
            self.digit2count[digit] = 1
            self.num_tokens += 1
            
    def build_vocab(self):        
        for seq in self.digit_seqs:
            for digit in seq:      # Ignore EOS token
                self.add_digit(digit)
            
        print("Vocabulary created with %d tokens ..." % self.num_tokens)
        # return self.num_tokens
    
    def vocab_size(self):
        return self.num_tokens
    
    def vocabulary(self):
        for idx in self.index2digit:
            print(idx, self.index2digit[idx])
    
    def encode(self, seq):
        return [self.dig2idx(digit) for digit in seq]
    
    def decode(self, seq):
        return "".join([self.idx2dig(idx) for idx in seq])
    
# Create the vocabulary
with open('../data/ASR/ASR_Labels.txt', 'r') as f:
    asr_labels = f.read().splitlines()
asr_labels = [asr_label[:-1] for asr_label in asr_labels]
    
print("Number of ASR utterances : ", len(asr_labels))
vocab_asr = ASR_Vocab(asr_labels)
vocab_asr.build_vocab()

# Test the vocabulary
print("\nOriginal sequence : ", asr_labels[4])
encoded_seq = vocab_asr.encode(asr_labels[4])
decoded_seq = vocab_asr.decode(encoded_seq)

print("Encoded sequence  : ", encoded_seq)
print("Decoded sequence  : ", decoded_seq)

Number of ASR utterances :  8398
Vocabulary created with 15 tokens ...

Original sequence :  49Z93Z2
Encoded sequence  :  [10, 7, 11, 7, 9, 11, 4]
Decoded sequence  :  49Z93Z2


In [20]:
### Create the target vocabulary

class MT_Vocab(object):
    def __init__(self, digit_seqs):
        super(MT_Vocab, self).__init__()
        self.word2index = {}
        self.word2count = {}
        self.index2word = {PAD_token: "PAD", SOS_token: "SOS", EOS_token: "EOS", UNK_token: "UNK"}
        
        # Count SOS, EOS, PAD, UNK
        self.num_words  = 4
        self.digit_seqs = digit_seqs
        
    def word2idx(self, word):
        if word in self.word2index:
            return self.word2index[word]
        else:
            return UNK_token
        
    def idx2word(self, idx):
        if idx in self.index2word:
            return self.index2word[idx]
        else:
            return self.index2word[UNK_token]
            
    def add_word(self, word):
        if word in self.word2index:
            self.word2count[word] += 1
            
        else:
            self.word2index[word] = self.num_words
            self.index2word[self.num_words] = word
            self.word2count[word] = 1
            self.num_words += 1
            
    def build_vocab(self):        
        for seq in self.digit_seqs:
            for word in seq.split(" "):
                self.add_word(word)

        print("Vocabulary created with %d tokens ..." % self.num_words)
        # return self.num_tokens
    
    def vocab_size(self):
        return self.num_words
    
    def vocabulary(self):
        for idx in self.index2word:
            print(idx, self.index2word[idx])
    
    def encode(self, seq):
        return [self.word2idx(word) for word in seq.split(" ")]
    
    def decode(self, seq):
        return " ".join([self.idx2word(idx) for idx in seq])
    
# Create the vocabulary
with open('../data/ASR/MT_Labels.txt', 'r') as f:
    mt_labels = f.read().splitlines()
mt_labels = [mt_label[:-1] for mt_label in mt_labels]

print("Number of MT sentences : ", len(mt_labels))
vocab_mt = MT_Vocab(mt_labels)
vocab_mt.build_vocab()

# Test the vocabulary
print("\nOriginal sequence : ", mt_labels[4])
encoded_seq = vocab_mt.encode(mt_labels[4])
decoded_seq = vocab_mt.decode(encoded_seq)

print("Encoded sequence  : ", encoded_seq)
print("Decoded sequence  : ", decoded_seq)

Number of MT sentences :  8398
Vocabulary created with 15 tokens ...

Original sequence :  நான்கு ஒன்பது பூஜ்யம் ஒன்பது மூன்று பூஜ்யம் இரண்டு
Encoded sequence  :  [10, 7, 11, 7, 9, 11, 4]
Decoded sequence  :  நான்கு ஒன்பது பூஜ்யம் ஒன்பது மூன்று பூஜ்யம் இரண்டு


In [21]:
### Define the device

USE_CUDA = torch.cuda.is_available()
device = torch.device("cuda" if USE_CUDA else "cpu")
print(torch.cuda.get_device_name(torch.cuda.current_device()))

NVIDIA GeForce RTX 2070 Super with Max-Q Design


In [22]:
# Prepare the data and create the dataloader
max_len = max([len(asr_label) for asr_label in asr_labels]) + 2

X = []
for i in range(len(asr_labels)):
    encoded_seq = [SOS_token] + vocab_asr.encode(asr_labels[i]) + [EOS_token]
    encoded_seq = encoded_seq + [PAD_token] * (max_len - len(encoded_seq))
    
    # Convert to one-hot encoding
    one_hot_enc = np.zeros((max_len, vocab_asr.vocab_size()))
    for j in range(len(encoded_seq)):
        # print(j, encoded_seq[j], vocab_mt.decode([encoded_seq[j]]))
        one_hot_enc[j][encoded_seq[j]] = 1
    
    X.append(one_hot_enc)

Y, Z = [], []
for i in range(len(mt_labels)):
    encoded_seq = [SOS_token] + vocab_mt.encode(mt_labels[i]) + [EOS_token]
    Y.append(encoded_seq + [PAD_token] * (max_len - len(encoded_seq)))
    
    # Convert to one-hot encoding
    one_hot_enc = np.zeros((max_len, vocab_mt.vocab_size()))
    for j in range(len(encoded_seq)):
        # print(j, encoded_seq[j], vocab_mt.decode([encoded_seq[j]]))
        one_hot_enc[j][encoded_seq[j]] = 1
        
    Z.append(one_hot_enc)

X = np.array(X)
Z = np.array(Z)
X = torch.tensor(X, dtype=torch.float)
Y = torch.tensor(Y, dtype=torch.long)
Z = torch.tensor(Z, dtype=torch.float)

print(X.shape, Y.shape, Z.shape)

torch.Size([8398, 9, 15]) torch.Size([8398, 9]) torch.Size([8398, 9, 15])


In [23]:
# Create the dataloader

class DigitsDataset(Dataset):
    def __init__(self, X, Y, Z):
        super(DigitsDataset, self).__init__()
        self.X = X
        self.Y = Y
        self.Z = Z
        
    def __len__(self):
        return len(self.X)
    
    def __getitem__(self, idx):
        return self.X[idx], self.Y[idx], self.Z[idx]
    
dataset = DigitsDataset(X, Y, Z)

In [24]:
# Define the seq2seq model
input_size  = vocab_asr.vocab_size()
hidden_size = 300
output_size = vocab_mt.vocab_size()

n_epochs = 300
batch_size = 512

In [25]:
# Define the encoder
# Encoder inputs : One-hot encoded ASR labels (B X L X U)

class EncoderRNN(nn.Module):
    def __init__(self, input_size, hidden_size, n_layers=2):
        super(EncoderRNN, self).__init__()
        self.input_size = input_size                    # U
        self.hidden_size = hidden_size                  # H
        self.n_layers = n_layers                        
        
        self.U = nn.Linear(input_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, n_layers, batch_first=True)
        
    def _init_hidden(self, batch_size):
        return torch.zeros(self.n_layers, batch_size, self.hidden_size).to(device)
    
    # Train the encoder with the decoder
    def forward(self, input, hidden=None):
        hidden = self._init_hidden(input.size(0)) if hidden is None else hidden
                                                        # 2 x B x H
        input = F.relu(self.U(input))                   # B x L x H
        outputs, hidden = self.gru(input, hidden)       # B x L x H, 2 x B x H
        
        hidden = hidden[-1].unsqueeze(0)                # 1 x B x H
        outputs = outputs[:, -1, :].unsqueeze(1)        # B x 1 x H
        
        return outputs, hidden                          # B x L x H, 1 x B x H                        
    
    # Encoder in inference mode
    def inference(self, input):
        hidden = self._init_hidden(input.size(0))
        input = F.relu(self.U(input))
        outputs, hidden = self.gru(input, hidden)
        
        hidden = hidden[-1].unsqueeze(0)
        outputs = outputs[:, -1, :].unsqueeze(1)
        
        return outputs, hidden

In [26]:
# Define the encoder-decoder attention model

class Attention(nn.Module):
    def __init__(self, hidden_size):
        super(Attention, self).__init__()
        self.hidden_size = hidden_size
        
        self.W1 = nn.Linear(hidden_size, hidden_size)
        self.W2 = nn.Linear(hidden_size, hidden_size)
        self.V  = nn.Linear(hidden_size, 1)
        
    def attend(self, hidden, encoder_states):
        # hidden : 1 x B x H
        # encoder_outputs : B x L x H
        
        hidden = hidden.squeeze(0)                      # B x H
        hidden = hidden.unsqueeze(1)                    # B x 1 x H
        
        # Calculate the attention weights
        attn_weights = self.V(torch.tanh(self.W1(hidden) + self.W2(encoder_states)))
        attn_weights = F.softmax(attn_weights, dim=1)   # B x L x 1
        
        return attn_weights

In [27]:
# Define the decoder
# Decoder inputs : Encoder outputs (B x L x H)

class DecoderRNN(nn.Module):
    def __init__(self, hidden_size, output_size):
        super(DecoderRNN, self).__init__()
        self.hidden_size = hidden_size                  # H
        self.output_size = output_size                  # V
        
        self.U = nn.Linear(output_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, 1, batch_first=True)
        self.V = nn.Linear(hidden_size, output_size)
        self.attn = Attention(hidden_size)
    
    # Decoder in training mode  
    def forward(self, input, last_hidden, encoder_states):  
        input = F.relu(self.U(input))                               # B x 1 x H        
        atn_weights = self.attn.attend(last_hidden, encoder_states) # B x L x 1
        context = atn_weights * encoder_states                      # B x L x 1 * B x L x H = B x L x H
        
        gru_input = torch.cat((input, context), dim=1)              # B x (L+1) x H
        gru_input = torch.sum(gru_input, dim=1).unsqueeze(1)        # B x 1 x H
        
        output, hidden = self.gru(gru_input, last_hidden)           # B x 1 x H, B x 1 x H        
        output = self.V(output)                                     # B x 1 x V
        
        return output, hidden
    
    # Decoder in inference mode
    def inference(self, input, last_hidden, encoder_states):
        input = F.relu(self.U(input))                               # 1 x 1 x H         
        atn_weights = self.attn.attend(last_hidden, encoder_states) # 1 x L x 1
        context = atn_weights * encoder_states                      # 1 x L x 1 * 1 x L x H = 1 x L x H
        
        gru_input = torch.cat((input, context), dim=1)              # B x (L+1) x H
        gru_input = torch.sum(gru_input, dim=1).unsqueeze(1)        # B x 1 x H
        
        output, hidden = self.gru(gru_input, last_hidden)           # 1 x 1 x H, 1 x 1 x H
        output = self.V(output)                                     # 1 x 1 x V
        labels = F.softmax(output, dim=2)                           # 1 x 1 x V
        
        return output, labels, hidden

In [28]:
# Define the se2seq model
class Seq2Seq(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, n_layers=2):
        super(Seq2Seq, self).__init__()
        self.encoder = EncoderRNN(input_size, hidden_size, n_layers)
        self.decoder = DecoderRNN(hidden_size, output_size)
        
    def forward(self, inputs, targets, teacher_forcing_ratio=0.5):
        batch_size = inputs.size(0)
        vocab_size = self.decoder.output_size
        
        encoder_states, hidden = self.encoder(inputs)
        
        target_len = targets.size(1)
        outputs = torch.zeros(batch_size, target_len, vocab_size).to(device)
        decoder_input = torch.zeros(batch_size, 1, vocab_size).to(device)
        decoder_input[:, 0, SOS_token] = 1
        
        use_teacher_forcing = True if random.random() < teacher_forcing_ratio else False
        
        if use_teacher_forcing:
            for t in range(target_len):
                output, hidden = self.decoder(decoder_input, hidden, encoder_states)
                outputs[:, t, :] = output.squeeze(1)
                decoder_input = targets[:, t, :].unsqueeze(1)
        else:
            for t in range(target_len):
                output, hidden = self.decoder(decoder_input, hidden, encoder_states)
                outputs[:, t, :] = output.squeeze(1)
                topv, topi = output.topk(1)
                
                decoder_input = torch.zeros(batch_size, 1, vocab_size).to(device)
                decoder_input[:, 0, topi] = 1
        
        return outputs
    
    def predict(self, input):
        batch_size = input.size(0)
        target_len = input.size(1)
        vocab_size = self.decoder.output_size
        
        outputs = torch.zeros(batch_size, target_len, vocab_size).to(device)
        encoder_states, hidden = self.encoder(input)
        
        decoder_input = torch.zeros(batch_size, 1, vocab_size).to(device)
        decoder_input[:, 0, SOS_token] = 1
        
        for t in range(target_len - 1):
            output, labels, hidden = self.decoder.inference(decoder_input, hidden, encoder_states)
            outputs[:, t, :] = labels.squeeze(1)
            topv, topi = labels.topk(1)
            decoder_input = torch.zeros(batch_size, 1, vocab_size).to(device)
            decoder_input[:, 0, topi] = 1
        
        # Add the EOS token
        outputs[:, target_len - 1, EOS_token] = 1
        
        return outputs

In [29]:
# Instantiate the model
seq2seq = Seq2Seq(input_size, hidden_size, output_size).to(device)

print(seq2seq.parameters)

<bound method Module.parameters of Seq2Seq(
  (encoder): EncoderRNN(
    (U): Linear(in_features=15, out_features=300, bias=True)
    (gru): GRU(300, 300, num_layers=2, batch_first=True)
  )
  (decoder): DecoderRNN(
    (U): Linear(in_features=15, out_features=300, bias=True)
    (gru): GRU(300, 300, batch_first=True)
    (V): Linear(in_features=300, out_features=15, bias=True)
    (attn): Attention(
      (W1): Linear(in_features=300, out_features=300, bias=True)
      (W2): Linear(in_features=300, out_features=300, bias=True)
      (V): Linear(in_features=300, out_features=1, bias=True)
    )
  )
)>


In [30]:
# Define the optimizer and loss function
learning_rate = 0.001

seq2seq_optimizer = optim.Adam(seq2seq.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()

In [31]:
# Train the model
seq2seq.train()
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

print("Training started ...")
for epoch in range(n_epochs):
    for i, data in enumerate(dataloader):
        inputs, labels, targets = data
        inputs = inputs.to(device)
        labels = labels.to(device)
        targets = targets.to(device)
        
        if inputs.size(0) != batch_size:
            continue
        
        # Zero the gradients
        seq2seq_optimizer.zero_grad()
        
        # Forward pass
        outputs = seq2seq(inputs, targets)
        
        loss = 0
        for j in range(outputs.size(1)):
            loss += criterion(outputs[:, j, :], labels[:, j])
        
        # Backward pass
        loss.backward()
        seq2seq_optimizer.step()
        
        print("Epoch: {}/{}, Step: {}/{}, Loss: {}".format(epoch+1, n_epochs, i+1, len(dataloader), np.round(loss.item()/max_len, 4)))

print("Training completed !!!")

Training started ...
Epoch: 1/300, Step: 1/17, Loss: 2.6774
Epoch: 1/300, Step: 2/17, Loss: 2.5516
Epoch: 1/300, Step: 3/17, Loss: 2.4262
Epoch: 1/300, Step: 4/17, Loss: 2.3362
Epoch: 1/300, Step: 5/17, Loss: 2.2657
Epoch: 1/300, Step: 6/17, Loss: 2.1533
Epoch: 1/300, Step: 7/17, Loss: 2.1022
Epoch: 1/300, Step: 8/17, Loss: 1.9719
Epoch: 1/300, Step: 9/17, Loss: 1.9332
Epoch: 1/300, Step: 10/17, Loss: 1.8644
Epoch: 1/300, Step: 11/17, Loss: 1.8543
Epoch: 1/300, Step: 12/17, Loss: 1.7908
Epoch: 1/300, Step: 13/17, Loss: 1.8083
Epoch: 1/300, Step: 14/17, Loss: 1.7511
Epoch: 1/300, Step: 15/17, Loss: 1.7459
Epoch: 1/300, Step: 16/17, Loss: 1.6838
Epoch: 2/300, Step: 1/17, Loss: 1.6398
Epoch: 2/300, Step: 2/17, Loss: 1.6264
Epoch: 2/300, Step: 3/17, Loss: 1.5482
Epoch: 2/300, Step: 4/17, Loss: 1.548
Epoch: 2/300, Step: 5/17, Loss: 1.5442
Epoch: 2/300, Step: 6/17, Loss: 1.6128
Epoch: 2/300, Step: 7/17, Loss: 1.5054
Epoch: 2/300, Step: 8/17, Loss: 1.5421
Epoch: 2/300, Step: 9/17, Loss: 1.563

In [32]:
# Inference

def inference(seq2seq, input, max_len=7):
    # truncate the input to max_len of training samples and remove the unknown tokens
    input = input[:max_len]  
    input = [token for token in input if vocab_asr.dig2idx(token) != UNK_token]                        
    encoded_seq = [SOS_token] + vocab_asr.encode(input) + [EOS_token]
    
    # one-hot encode the input
    one_hot_enc = np.zeros((1, len(encoded_seq), vocab_asr.vocab_size()))
    for i, token in enumerate(encoded_seq):
        one_hot_enc[0, i, token] = 1
        
    one_hot_enc = torch.tensor(one_hot_enc, dtype=torch.float).to(device)
    decoded_seq = seq2seq.predict(one_hot_enc)
    decoded_seq = torch.argmax(decoded_seq, dim=2).flatten().tolist()
    
    decoded_seq = [token for token in decoded_seq if token not in [SOS_token, EOS_token]]
    
    return vocab_mt.decode(decoded_seq)

In [41]:
# Test the model
seq2seq.eval()

input_seq = "O123456"
output_seq = inference(seq2seq, input_seq)

print("Input sequence : {}".format(input_seq))
print("Output sequence : {}".format(output_seq))

Input sequence : O123456
Output sequence : பூஜ்யம் ஒன்று இரண்டு மூன்று நான்கு ஐந்து ஆறு


In [42]:
# Save the model
# torch.save(seq2seq.state_dict(), "seq2seq.pth")