In [18]:
import unicodedata

import torch
import torch.nn as nn

import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import re
import unicodedata
from nltk.translate.bleu_score import sentence_bleu

import random
device = "cuda" if torch.cuda.is_available() else "cpu"

In [19]:
SOS_TOKEN = 0
EOS_TOKEN = 1
MAX_LENGTH = 10

eng_prefixes = (
    "i am ", "i m ",
    "he is", "he s ",
    "she is", "she s ",
    "you are", "you re ",
    "we are", "we re ",
    "they are", "they re "
)

class Lang:
    def __init__(self,name):
        self.name = name
        self.word2index ={}
        self.word2count ={}
        self.index2word = {
            0:"SOS", 1:"EOS"
        }
        self.n_words = 2 # number of words
        
    def addSentence(self,sentence):
        for word in sentence.split(' '):
            self.addWord(word)
        
    def addWord(self,word):
        if word not in self.word2index:
            self.word2index[word] = self.n_words
            self.word2count[word] = 1
            self.index2word[self.n_words] = word
            self.n_words +=1 
        else:
            self.word2count[word] +=1
            

def unicodeToAscii(s):
    return ''.join(
        c for c in unicodedata.normalize("NFD",s) if unicodedata.category(c) != 'Mn'
    )

def normalizeString(s):
    s = unicodeToAscii(s.lower().strip())
    s = re.sub(r"([.!?])", r" \1", s)
    s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)
    
    return s


def readLangs(lang1,lang2,reverse=False):
    print("Reading lines ... ")
    
    lines = open("./%s-%s.txt" % (lang1,lang2), encoding = 'utf-8').read().strip().split('\n')
    
    pairs = [[normalizeString(s) for s in l.split('\t')] for l in lines]
    
    if reverse:
        pairs = [list(reversed(p)) for p in pairs ]
        input_lang = Lang(lang2)
        output_lang = Lang(lang1)
    else:
        input_lang = Lang(lang1)
        output_lang = Lang(lang2)
        
    
    return input_lang,output_lang,pairs

def filterPair(p):
    return len(p[0].split(' '))<MAX_LENGTH and len(p[1].split(' ')) < MAX_LENGTH and p[1].startswith(eng_prefixes)

def filterPairs(pairs):
    return [pair for pair in pairs if filterPair(pair)]

def prepareData(lang1,lang2,reverse=False):
    input_lang,output_lang,pairs = readLangs(lang1,lang2,reverse)
    print("Reading %s sentence pairs ", len(pairs))
    pairs = filterPairs(pairs)
    print("Trimmed to %s sentence pairs", len(pairs))
    print("Counting words")
    for pair in pairs:
        input_lang.addSentence(pair[0])
        output_lang.addSentence(pair[1])
        
    print("Counted words:")
    print(input_lang.name, input_lang.n_words)
    print(output_lang.name, output_lang.n_words)
    return input_lang, output_lang, pairs

input_lang, output_lang, pairs = prepareData('eng', 'fra', True)
print(random.choice(pairs))

Reading lines ... 
Reading %s sentence pairs  135842
Trimmed to %s sentence pairs 10599
Counting words
Counted words:
fra 4345
eng 2803
['elle a peur de retomber malade .', 'she is afraid of falling ill again .']


In [26]:
## Encoder

class Encoder(nn.Module):
    def __init__(self,input_size,hidden_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        
        self.embeds = nn.Embedding(input_size,hidden_dim)
        self.gru = nn.GRU(hidden_dim,hidden_dim)
        
    def forward(self,x,hidden):
        embeds = self.embeds(x).view(1,1,-1)
        output,hidden = self.gru(embeds,hidden)
        return output,hidden
    
    def initHidden(self):
        return torch.zeros(1,1,self.hidden_dim,device=device)
    
    
## Decoder

class AttnDecoder(nn.Module):
    def __init__(self,hidden_dim,output_size,max_len=MAX_LENGTH):
        super().__init__()
        
        self.hidden_dim = hidden_dim
        self.output_size = output_size
        self.max_len = max_len
        
        self.embeds = nn.Embedding(self.output_size,self.hidden_dim)
        self.attn_matrix = nn.Parameter(data = torch.ones((self.hidden_dim,self.hidden_dim)),requires_grad=True)
        self.gru = nn.GRU(self.hidden_dim*2,self.hidden_dim)
        self.out = nn.Linear(self.hidden_dim,self.output_size)
        
    def forward(self,x,hidden,encoder_outputs):
        embeds = self.embeds(x).view(1,1,-1)
        attn_weights = torch.matmul(torch.matmul(hidden[0],self.attn_matrix),torch.transpose(encoder_outputs,0,1))
        # 1 x max_len
        attn_weights = F.softmax(attn_weights,dim=1)
        
        attn_applied = torch.bmm(attn_weights.unsqueeze(1),encoder_outputs.view(1,-1,self.hidden_dim))
        
        input_gru = torch.cat((attn_applied[0],embeds[0]),dim=1)
        
        output,hidden = self.gru(input_gru.unsqueeze(0),hidden)
        output = F.log_softmax(self.out(output[0]),dim=1)
        
        return output,hidden,attn_weights

In [27]:
def indexexFromSentece(lang,sentence):
    return [lang.word2index[word] for word in sentence.split(' ')]

def tensorFromSentence(lang,sentence):
    indexes = indexexFromSentece(lang,sentence)
    indexes.append(EOS_TOKEN)
    return torch.tensor(indexes,dtype=torch.long,device=device).view(-1,1)

def tensorFromPair(pair):
    input_tensor = tensorFromSentence(input_lang,pair[0])
    target_tensor = tensorFromSentence(output_lang,pair[1])
    return (input_tensor,target_tensor)


In [34]:
hidden_size = 256
teacher_forcing_ratio = 0.5
max_length = MAX_LENGTH
n_iters = 50000
n_epochs = 1
lr = 0.001
device = "cpu"


encoder = Encoder(input_size=input_lang.n_words,hidden_dim=hidden_size).to(device)

decoder = AttnDecoder(hidden_dim=hidden_size,output_size=output_lang.n_words,max_len=max_length).to(device)

encoder_opt = torch.optim.Adam(encoder.parameters(),lr=lr)
decoder_opt = torch.optim.Adam(decoder.parameters(),lr=lr)
loss_fn = nn.NLLLoss()

training_pairs = [tensorFromPair(random.choice(pairs)) for i in range(n_iters)]

for epoch in range(n_epochs+1):
    for iter in range(1,n_iters+1):
        
        training_pair = training_pairs[iter-1]
        input_tensor = training_pair[0]
        target_tensor = training_pair[1]
        
        #encoder
        encoder_hidden = encoder.initHidden()
        
        input_len = input_tensor.size(0)
        target_len = target_tensor.size(0)
        
        encoder_outputs = torch.zeros(max_length,encoder.hidden_dim,device=device)
        
        loss = 0
        
        for ei in range(input_len):
            encoder_output,encoder_hidden = encoder(input_tensor[ei],encoder_hidden)
            encoder_outputs[ei] = encoder_output[0,0]
            
        #decoder
        output_sentence = [output_lang.index2word[t.item()] for t in target_tensor]

        decoder_input = torch.tensor([[SOS_TOKEN]],device=device)
        decoder_hidden = encoder_hidden
        
        use_teacher_forcing = True if random.random()< teacher_forcing_ratio else False
        
        decoded_sentence = []
        if use_teacher_forcing:
            for di in range(target_len):
                decoder_output,decoder_hidden,decoder_att = decoder(decoder_input,decoder_hidden,encoder_outputs)
                
                topv,topi = decoder_output.topk(1)
                decoded_sentence.append(output_lang.index2word[topi.item()])
                
                loss += loss_fn(decoder_output,target_tensor[di])
                decoder_input = target_tensor[di]
            
        else:
            for di in range(target_len):
                decoder_output,decoder_hidden,decoder_att = decoder(decoder_input,decoder_hidden,encoder_outputs)
                
                topv,topi = decoder_output.topk(1)
                decoded_sentence.append(output_lang.index2word[topi.item()])
                decoder_input = topi.squeeze().detach()
                
                loss+= loss_fn(decoder_output,target_tensor[di])
                if decoder_input.item() == EOS_TOKEN:
                    break
                    
        encoder_opt.zero_grad()
        decoder_opt.zero_grad()
    
        loss.backward()
    
        encoder_opt.step()
        decoder_opt.step()
    
        #bleu_score = sentence_bleu([output_sentence[:-1]],decoded_sentence[:-1])
    
        if iter%500 == 0:
              print('epoch: {}, iter: {}, loss: {:.6f}, '.format(epoch, iter, loss.item() / target_len))

epoch: 0, iter: 500, loss: 2.819708, 
epoch: 0, iter: 1000, loss: 1.415829, 
epoch: 0, iter: 1500, loss: 1.989942, 
epoch: 0, iter: 2000, loss: 3.183518, 
epoch: 0, iter: 2500, loss: 1.210539, 
epoch: 0, iter: 3000, loss: 3.101019, 


KeyboardInterrupt: 