In [30]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import string
import random
import re
import unicodedata
from collections import Counter
from io import open

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch import optim


In [31]:
class Lang:
    def __init__(self, name):
        self.name = name
        self.word2index = {}
        self.word2count = {}
        self.index2word = {0: 'SOS', 1: 'EOS'}
        self.n_words = 2
    
    def addSentence(self, sentence):
        for word in sentence.split(' '):
            self.addWord(word)
    
    def addWord(self, word):
        if word not in self.word2index:
            self.word2index[word] = self.n_words
            self.word2count[word] = 1
            self.index2word[self.n_words] = word
            self.n_words += 1
        else:
            self.word2count[word] += 1

In [32]:
def unicode_to_ascii(s):
    return ''.join(
        c for c in unicodedata.normalize('NFD', s) if unicodedata.category(c) != 'Mn'
    )

def normalize_string(s):
    s = unicode_to_ascii(s.lower().strip()) # lowercase all characters and strip both ends
    s = re.sub(r"([.!?])", r" \1", s) # remove special characters except ? and !
    s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)
    return s

In [33]:
s = "             J'ai pigé ! 1234!@#%^*&(*(?))                  "
normalize_string(s)

'j ai pige ! ! ? '

In [34]:
def read_file(lang1, lang2, reverse=False):
    print('Reading lines...')
    
    lines = open('./data/%s-%s.txt' % (lang1, lang2), encoding='utf-8').read().strip().split('\n')
    pairs = [[normalize_string(s) for s in l.split('\t')] for l in lines]
    
    # check for reverse
    if reverse:
        pairs = [list(reversed(p)) for p in pairs]
        original_lang = Lang(lang1)
        translated_lang = Lang(lang2)
    else:
        original_lang = Lang(lang2)
        translated_lang = Lang(lang1)
        
    return original_lang, translated_lang, pairs

In [35]:
MAX_LEN = 10

eng_prefixes = ('i am', 'i m')

def filter_pair(p):
    return (len(p[0].split(' ')) < MAX_LEN and len(p[1].split(' ')) < MAX_LEN and p[1].startswith(eng_prefixes))

def filter_pairs(pairs):
    return [pair for pair in pairs if filter_pair(pair)]

In [36]:
def prepare_data(lang1, lang2, reverse=False):
    original, translated, pairs = read_file(lang1, lang2, reverse)
    print('Total number of pairs:', len(pairs))
    pairs = filter_pairs(pairs)
    print('Trimmed to %s number of pairs' % len(pairs))
    print('Counting words...')
    for pair in pairs:
        original.addSentence(pair[0])
        translated.addSentence(pair[1])
    print(original.name, original.n_words)
    print(translated.name, translated.n_words)
    
    return original, translated, pairs
    
original, translated, pairs = prepare_data('eng', 'fra', reverse=True)
print(pairs[0])

Reading lines...
Total number of pairs: 135842
Trimmed to 4534 number of pairs
Counting words...
eng 2405
fra 1792
['j ai ans .', 'i m .']
