In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import string
import random
import re
import unicodedata
from collections import Counter
from io import open

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch import optim


In [2]:
class Lang:
    def __init__(self, name):
        self.name = name
        self.word2index = {}
        self.word2count = {}
        self.index2word = {0: 'SOS', 1: 'EOS'}
        self.n_words = 2
    
    def addSentence(self, sentence):
        for word in sentence.split(' '):
            self.addWord(word)
    
    def addWord(self, word):
        if word not in self.word2index:
            self.word2index[word] = self.n_words
            self.word2count[word] = 1
            self.index2word[self.n_words] = word
            self.n_words += 1
        else:
            self.word2count[word] += 1

In [3]:
def unicode_to_ascii(s):
    return ''.join(
        c for c in unicodedata.normalize('NFD', s) if unicodedata.category(c) != 'Mn'
    )

def normalize_string(s):
    s = unicode_to_ascii(s.lower().strip()) # lowercase all characters and strip both ends
    s = re.sub(r"([.!?])", r" \1", s) # remove special characters except ? and !
    s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)
    return s

In [4]:
s = "             J'ai pigé ! 1234!@#%^*&(*(?))                  "
normalize_string(s)

'j ai pige ! ! ? '

In [5]:
def read_file(lang1, lang2, reverse=False):
    print('Reading lines...')
    
    lines = open('./data/%s-%s.txt' % (lang1, lang2), encoding='utf-8').read().strip().split('\n')
    pairs = [[normalize_string(s) for s in l.split('\t')] for l in lines]
    
    # check for reverse
    if reverse:
        pairs = [list(reversed(p)) for p in pairs]
        original_lang = Lang(lang1)
        translated_lang = Lang(lang2)
    else:
        original_lang = Lang(lang2)
        translated_lang = Lang(lang1)
        
    return original_lang, translated_lang, pairs

In [6]:
MAX_LEN = 10

eng_prefixes = ('i am', 'i m')

def filter_pair(p):
    return (len(p[0].split(' ')) < MAX_LEN and len(p[1].split(' ')) < MAX_LEN and p[1].startswith(eng_prefixes))

def filter_pairs(pairs):
    return [pair for pair in pairs if filter_pair(pair)]

In [7]:
def prepare_data(lang1, lang2, reverse=False):
    original, translated, pairs = read_file(lang1, lang2, reverse)
    print('Total number of pairs:', len(pairs))
    pairs = filter_pairs(pairs)
    print('Trimmed to %s number of pairs' % len(pairs))
    print('Counting words...')
    for pair in pairs:
        original.addSentence(pair[0])
        translated.addSentence(pair[1])
    print(original.name, original.n_words)
    print(translated.name, translated.n_words)
    
    return original, translated, pairs
    
original, translated, pairs = prepare_data('eng', 'fra', reverse=True)
print(pairs[0])

Reading lines...
Total number of pairs: 135842
Trimmed to 4534 number of pairs
Counting words...
eng 2405
fra 1792
['j ai ans .', 'i m .']


In [17]:
class Encoder(nn.Module):
    def __init__(self, input_size, hidden_size):
        self.input_size = input_size
        self.hidden_size = hidden_size
        
        self.embed = nn.Embedding(self.input_size, self.hidden_size)
        self.gru = nn.GRU(self.hidden_size, self.hidden_size)
        
    def forward(self, x, h):
        embedding = self.embed(x)
        output, h = self.gru(embedding, h)
        
        return output, h
    
    def init_hidden_weights(self):
        return torch.zeros(1, 1, self.hidden_size, device=device)

In [18]:
class Decoder(nn.Module):
    def __init__(self, hidden_size, output_size):
        self.hidden_size = hidden_size
        self.output_size = output_size
        
        self.embed = nn.Embedding(self.output_size, self.hidden_size)
        self.gru = nn.GRU(self.hidden_size, self.hidden_size)
        self.fc = nn.Linear(self.hidden_size, self.output_size)
        self.softmax = nn.LogSoftmax(dim=1)
        
    def forward(self, x, h):
        out = F.relu(self.embed(x).view(1, 1, -1))
        out, h = self.gru(out, h)
        out = self.softmax(self.fc(out[0]))
    
    def init_hidden_weights(self):
        return torch.zeros(1, 1, self.hidden_size, device=device)

In [21]:
# preparing training data
EOS_TOKEN=1
SOS_TOKEN=0

def index_from_sentence(lang, sentence):
    return [lang.word2index[word] for word in sentence.split(' ')]

def tensor_from_sentence(lang, sentence):
    idx = index_from_sentence(lang, sentence)
    idx.append(EOS_TOKEN)
    
    return torch.tensor(idx, type=torch.long, device=device).view(-1, 1)

def tensor_from_pair(pair):
    feature = tensor_from_sentence(lang, pair[0])
    target = tensor_from_sentence(lang, pair[1])
    
    return (feature, target)

In [None]:
force_teacher_ratio = 0.5

def train(input_tensor, target_tensor, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion, max_length=MAX_LEN):
    encoder_hidden = encoder.init_hidden_weights()
    
    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()
    
    input_length = input_tensor.size(0)
    output_length = target_tensor.size(0)
    
    encoder_outputs = torch.zeros(max_length, encoder.hidden_size, device=device)
    
    loss = 0
    for i in range(input_length):
        encoder_output, encoder_hidden = encoder(input_tensor[i], encoder_hidden)
        encoder_outputs[i] = encoder_output[0,0]
    
    decoder_input = torch.tensor([[SOS_TOKEN]], device=device)
    decoder_hidden = encoder_hidden
    
    use_teacher_forcing = True if random.random() < force_teacher_ratio else False
    if use_teacher_forcing:
        for i in range(target_length):
            