In [1]:
from io import open
import unicodedata
import string
import re
import random
import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [2]:
SOS_token = 0
EOS_token = 1

In [37]:
class Lang:
    def __init__(self,name):
        self.name = name
        self.W2I = {'SOS':SOS_token,'EOS':EOS_token}
        self.I2W = {SOS_token:'SOS',EOS_token:'EOS'}
        self.W2C = {}
        self.n_words = 2
    def addSentence(self,s):
        for word in s.split(' '):
            self.addWord(word)
    def addWord(self,w):
        if w not in self.W2I:
            self.W2I[w] = self.n_words
            self.W2C[w] = 1
            self.I2W[self.n_words] = w
            self.n_words+=1
        else:
            self.W2C[w]+=1
    def printAllWords(self):
        words = list(self.W2I.keys())
        for word in words:
            print(word)

In [38]:
L = Lang('Eng')

In [39]:
L.addWord('NLP')

In [41]:
L.addSentence('How are you today')

In [42]:
L.printAllWords()

SOS
EOS
NLP
How
are
you
today


In [43]:
def u2a(s):
    return ''.join(
        c for c in unicodedata.normalize('NFD',s)
        if unicodedata.category(c) != 'Mn'
    )

In [44]:
def normalizeString(s):
    s = u2a(s.lower().strip())
    s = re.sub(r'([.!?])',r'\1',s)
    s = re.sub(r'[^a-zA-Z.!?]+',r' ',s)
    return s

In [45]:
print(normalizeString('asfdieojo98793259'))

asfdieojo 


In [46]:
def readLangs():
    lines = open('eng-fra.txt',encoding='utf-8').read().strip().split('\n')
    pairs = [[normalizeString(s) for s in l.split('\t')] for l in lines]
    input_lang = Lang('eng')
    output_lang = Lang('fra')
    return input_lang, output_lang,pairs

In [47]:
I,O,P = readLangs()

In [34]:
P

[['go.', 'va !'],
 ['run!', 'cours !'],
 ['run!', 'courez !'],
 ['wow!', 'ca alors !'],
 ['fire!', 'au feu !'],
 ['help!', 'a l aide !'],
 ['jump.', 'saute.'],
 ['stop!', 'ca suffit !'],
 ['stop!', 'stop !'],
 ['stop!', 'arrete toi !'],
 ['wait!', 'attends !'],
 ['wait!', 'attendez !'],
 ['i see.', 'je comprends.'],
 ['i try.', 'j essaye.'],
 ['i won!', 'j ai gagne !'],
 ['i won!', 'je l ai emporte !'],
 ['oh no!', 'oh non !'],
 ['attack!', 'attaque !'],
 ['attack!', 'attaquez !'],
 ['cheers!', 'sante !'],
 ['cheers!', 'a votre sante !'],
 ['cheers!', 'merci !'],
 ['get up.', 'leve toi.'],
 ['got it!', 'j ai pige !'],
 ['got it!', 'compris !'],
 ['got it?', 'pige ?'],
 ['got it?', 'compris ?'],
 ['got it?', 't as capte ?'],
 ['hop in.', 'monte.'],
 ['hop in.', 'montez.'],
 ['hug me.', 'serre moi dans tes bras !'],
 ['hug me.', 'serrez moi dans vos bras !'],
 ['i fell.', 'je suis tombee.'],
 ['i fell.', 'je suis tombe.'],
 ['i know.', 'je sais.'],
 ['i left.', 'je suis parti.'],
 ['i le

In [50]:
def prepareData(I,O,P):
    MAX_LENGTH = 0
    for pair in P:
        I.addSentence(pair[0])
        O.addSentence(pair[1])
        MAX_LENGTH = max(MAX_LENGTH,len(pair[0].split()),len(pair[1].split()))
    return I,O,MAX_LENGTH

In [51]:
input_lang,output_lang,MAX_LENGTH = prepareData(I,O,P)

In [52]:
MAX_LENGTH

59

In [53]:
input_lang.n_words

20753

In [54]:
output_lang.n_words

29481

In [56]:
output_lang.printAllWords()

SOS
EOS
va
!
cours
courez
ca
alors
au
feu
a
l
aide
saute.
suffit
stop
arrete
toi
attends
attendez
je
comprends.
j
essaye.
ai
gagne
emporte
oh
non
attaque
attaquez
sante
votre
merci
leve
toi.
pige
compris
?
t
as
capte
monte.
montez.
serre
moi
dans
tes
bras
serrez
vos
suis
tombee.
tombe.
sais.
parti.
partie.
perdu.
ans.
vais
bien.
va.
ecoutez
impossible
en
aucun
cas.
c
est
hors
de
question
il
n
pas
exclu
aucune
maniere
vraiment
vrai
ah
bon
on
nous
avons
gagne.
gagnames.
emporte.
emportames.
demande
tom.
fantastique
sois
calme
soyez
calmes
detendu
juste
justes
equitable
equitables
gentil.
gentil
gentille
gentils
gentilles
degage
appelle
appellez
appelez
entrez
entre.
entre
allez
viens
venez
laisse
tomber
laissez
le
sortez
sors
sors.
casse
pars
te
faire
foutre
fous
camp
d
ici.
disparais
vous
doucement
la
revoyure.
un
peu
tiens
tenez
laissa
tomber.
court.
moi.
aidez
ne
bouge
plus
quittez
pas.
du
meme
avis.
essayai.
tente.
irai.
gras.
gros.
forme.
touche
touchee
malade.
triste.
timide.
mouil

In [58]:
print(random.choice(P))

['you were jealous weren t you?', 'vous etiez jalouses n est ce pas ?']


In [59]:
pairs = P

In [61]:
class EncoderRNN(nn.Module):
    def __init__(self,vocabSize,hidden_size):
        super(EncoderRNN,self).__init__()
        self.hidden_size = hidden_size
        self.E = nn.Embedding(vocabSize,hidden_size)
        self.gru = nn.GRU(hidden_size,hidden_size,
                          batch_first=True,bidirectional=True)
    def forward(self,input,hidden):
        emb = self.E(input).view(1,1,-1)
        output,hidden = self.gru(emb,hidden)
        return output,hidden
    def initHidden(self):
        return torch.zeros(2,1,self.hidden_size,device=device)

In [73]:
class DecoderRNN(nn.Module):
    def __init__(self,hidden_size,vocabSize,max_length = MAX_LENGTH):
        super(DecoderRNN,self).__init__()
        self.hidden_size = hidden_size
        self.output_size = vocabSize
        self.max_length = max_length
        self.E = nn.Embedding(self.output_size,self.hidden_size)
        self.attn = nn.Linear(self.hidden_size*2,self.max_length)
        self.attn_combine = nn.Linear(self.hidden_size*3,self.hidden_size)
        self.gru = nn.GRU(self.hidden_size,self.hidden_size)
        self.out = nn.Linear(self.hidden_size,self.output_size)
    
    def forward(self,input,hidden,encoder_outputs):
        emb = self.E(input).view(1,1,-1)
        attn_w = F.softmax(self.attn(torch.cat((emb[0],hidden[0]),1)),dim=1)
        attn_A = torch.bmm(attn_w.unsqueeze(0),
                          encoder_outputs.unsqueeze(0))
        
        output = torch.cat((emb[0],attn_A[0]),1)
        output = self.attn_combine(output).unsqueeze(0)
        output = F.relu(output)
        output,hidden = self.gru(output,hidden)
        output = F.log_softmax(self.out(output[0]),dim=1)
        return output,hidden,attn_w
    def initHidden(self):
        return torch.zeros(1,1,self.hidden_size,device=device)
        

In [74]:
def indexesFromSentence(lang,s):
    return[lang.W2I[w] for w in s.split()]

def tensorFromSentence(lang,s):
    idx = indexesFromSentence(lang,s)
    idx.append(EOS_token)
    return torch.tensor(idx,dtype=torch.long,device=device).view(-1,1)

def tensorsFromPair(pair):
    input_tensor = tensorFromSentence(input_lang,pair[0])
    output_tensor = tensorFromSentence(output_lang,pair[1])
    return (input_tensor,output_tensor)
    
    

In [75]:
def train(input_tensor,target_tensor,encoder,decoder,
         encoder_optimizer,decoder_optimizer,loss_fn,
         max_length=MAX_LENGTH):
    encoder_hidden = encoder.initHidden()
    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()
    input_length = input_tensor.size(0)
    target_length= target_tensor.size(0)
    encoder_outputs = torch.zeros(max_length,2*encoder.hidden_size,device=device)
    loss = 0
    for ei in range(input_length):
        encoder_output,encoder_hidden = encoder(
            input_tensor[ei],encoder_hidden)
        out_reshaped = encoder_output.view(1,1,2,encoder.hidden_size)
        out_fwd = out_reshaped[:,:,0,:]
        out_bck = out_reshaped[:,:,1,:]
        encoder_outputs[ei] = torch.cat((out_fwd[0,0],out_bck[0,0]),0)
    decoder_input = torch.tensor([[SOS_token]],device=device)
    h_reshaped = encoder_hidden.view(1,2,1,encoder.hidden_size)
    decoder_hidden = h_reshaped[:,0,:,:]
    
    for di in range(target_length):
        decoder_output,decoder_hidden,decoder_attention = decoder(
            decoder_input,decoder_hidden,encoder_outputs
        )
        topv,topi = decoder_output.topk(1)
        decoder_input = topi.squeeze().detach()
        loss+=loss_fn(decoder_output,target_tensor[di])
        if decoder_input.item() == EOS_token:
            break
    loss.backward()
    encoder_optimizer.step()
    decoder_optimizer.step()
    return loss.item()/target_length

In [76]:
def trainIters(encoder,decoder,n_iters,lr=0.001):
    totalLoss = 0
    encoder_optimizer = optim.SGD(encoder.parameters(),lr=lr)
    decoder_optimizer = optim.SGD(decoder.parameters(),lr=lr)
    training_pairs = [tensorsFromPair(random.choice(pairs))
                     for i in range(n_iters)]
    loss_fn = nn.NLLLoss()
    for iter in range(n_iters):
        training_pair = training_pairs[iter]
        input_tensor = training_pair[0]
        target_tensor = training_pair[1]
        loss = train(input_tensor,target_tensor,encoder,decoder,
                        encoder_optimizer,decoder_optimizer,loss_fn
                    )
        totalLoss+=loss
        print(totalLoss/(iter+1))
    

In [77]:
hidden_size = 128
encoder = EncoderRNN(input_lang.n_words,hidden_size).to(device)
decoder = DecoderRNN(hidden_size,output_lang.n_words).to(device)

In [79]:
trainIters(encoder,decoder,100)

10.21520767211914
10.22722511291504
10.250142015729631
10.26163343702044
10.263259793236143
10.261796781751846
10.26548384711856
10.26921136379242
10.271564641410922
10.266349892086453
10.268826141742744
10.267096716007858
10.265554630479825
10.26649846873365
10.26811267342689
10.270549121683983
10.269697360605491
10.269082019228536
10.267825202321902
10.266020746065943
10.263043179430438
10.261924746392141
10.263807271730872
10.263661757967347
10.263896354543183
10.260336628811304
10.258958515782735
10.258485208877802
10.25808669184846
10.115888229542596
10.11835326532276
10.122867564635532
10.126922264751336
10.130830448206654
10.133727392850526
10.136519477473291
10.109667699046248
10.112864568013126
10.050788963421049
10.054305492310986
10.057625568049414
10.061293109107297
10.065090292406355
10.07001546859804
10.011983816282328
10.017049541951454
10.021477287976811
10.025915306072617
9.946124751481412
9.906746815506688
9.846730302813855
9.74835137934969
9.757165055716802
9.7657228

In [82]:
def evaluate(encoder,decoder,s,max_length=MAX_LENGTH):
    with torch.no_grad():
        input_tensor = tensorFromSentence(input_lang,s)
        input_length = input_tensor.size()[0]
        encoder_hidden = encoder.initHidden()
        encoder_outputs = torch.zeros(max_length,2*encoder.hidden_size,device=device)
        
        for ei in range(input_length):
            encoder_output,encoder_hidden = encoder(
                input_tensor[ei],encoder_hidden)
            out_reshaped = encoder_output.view(1,1,2,encoder.hidden_size)
            out_fwd = out_reshaped[:,:,0,:]
            out_bck = out_reshaped[:,:,1,:]
            encoder_outputs[ei] = torch.cat((out_fwd[0,0],out_bck[0,0]),0)
        decoder_input = torch.tensor([[SOS_token]],device=device)
        h_reshaped = encoder_hidden.view(1,2,1,encoder.hidden_size)
        decoder_hidden = h_reshaped[:,0,:,:]
        
        decoded_words = []
        decoder_att = torch.zeros(max_length,max_length)
        
        for di in range(max_length):
            decoder_output,decoder_hidden,decoder_attention = decoder(
                decoder_input,decoder_hidden,encoder_outputs
            )
            decoder_att[di] = decoder_attention.data
            topv,topi = decoder_output.topk(1)
            if topi.item() == EOS_token:
                decoded_words.append('<EOS>')
                break
            else:
                decoded_words.append(output_lang.I2W[topi.item()])
            decoder_input = topi.squeeze().detach()
        return decoded_words
        
        

In [84]:
print(evaluate(encoder,decoder,pairs[0][0]),pairs[0][1])

['tralala.', 'tralala.', 'splendeur.', 'croirait', '<EOS>'] va !
