In [138]:
import os
import math
import numpy as np
import torch
from torch import nn,optim
from torch.utils import data
from torchtext import vocab

In [139]:
from collections import Counter

In [140]:
# Parallel corpus consumed by read_data: one "<french>\t<english>" pair per line.
DATA_PATH="./data/fr-en-small.txt"

In [141]:
# Special tokens registered in both vocabularies (see read_data).
BOS="<bos>"  # first decoder input of every sequence
EOS="<eos>"  # appended to every target sentence by token2index
PAD="<pad>"  # fills encoded sentences up to MAX_LEN
UNK="<unk>"
MAX_LEN=7  # fixed length of every encoded sentence; also caps decoding in translate
EPOCH_SIZE=50  # number of passes over the dataset made by train
BATCH_SIZE=4  # mini-batch size of the DataLoader built in train

# 数据预处理

In [142]:
def tokenize(vocabulary,text):
    '''Split ``text`` on whitespace and record the tokens in ``vocabulary``.

    vocabulary: object with an ``update(iterable)`` method (a ``Counter`` in
        read_data); it is modified in place.
    text: raw sentence string.
    Returns the list of tokens.
    '''
    tokens=text.split()
    vocabulary.update(tokens)
    return tokens

In [143]:
def token2index(vocabulary,text,target=False):
    '''Encode a token list as exactly MAX_LEN vocabulary indices.

    vocabulary: object whose ``stoi`` maps a token to its index.
    text: list of tokens.
    target: when True the sentence is cut to MAX_LEN-1 tokens so that an EOS
        index can be appended before padding; when False it is cut to MAX_LEN
        tokens and only padded.
    Returns a new list of indices, right-padded with the PAD index.
    '''
    pad_index=vocabulary.stoi[PAD]
    indices=[vocabulary.stoi[token] for token in text]
    if target:
        # Keep one slot free for the end-of-sentence marker.
        indices=indices[:MAX_LEN-1]
        indices.append(vocabulary.stoi[EOS])
    else:
        indices=indices[:MAX_LEN]
    indices.extend([pad_index]*(MAX_LEN-len(indices)))
    return indices

In [144]:
def read_data(file_path):
    '''Load a tab-separated parallel corpus and encode it as index tensors.

    Every usable line has the form ``<source>\\t<target>``. Lines are stripped
    and lower-cased before tokenising; lines without a tab (for example a
    trailing blank line) are skipped instead of raising IndexError.

    file_path: path of the corpus file, read as UTF-8.
    Returns ``(input_vocabs, output_vocabs, dataset)`` where the vocabularies
    are built from the token counts plus the UNK/PAD/BOS/EOS specials and
    ``dataset`` is a TensorDataset of (source indices, target indices), both
    of width MAX_LEN.
    '''
    input_counter,output_counter=Counter(),Counter()
    input_data,output_data=[],[]

    # Explicit encoding: the default used by open() depends on the platform locale.
    with open(file_path,"r",encoding="utf-8") as f:
        for line in f:
            fields=line.strip().lower().split("\t")
            if len(fields)<2:
                # Blank or malformed line: there is no sentence pair to read.
                continue
            input_data.append(tokenize(input_counter,fields[0]))
            output_data.append(tokenize(output_counter,fields[1]))

    # NOTE(review): Vocab(counter, specials=...) is the legacy torchtext
    # constructor; confirm the installed torchtext version still provides it.
    input_vocabs=vocab.Vocab(input_counter,specials=[UNK,PAD,BOS,EOS])
    output_vocabs=vocab.Vocab(output_counter,specials=[UNK,PAD,BOS,EOS])

    input_examples=[token2index(input_vocabs,tokens) for tokens in input_data]
    # Targets get an EOS marker before padding (target=True).
    out_examples=[token2index(output_vocabs,tokens,True) for tokens in output_data]
    return input_vocabs,output_vocabs,data.TensorDataset(torch.tensor(input_examples),torch.tensor(out_examples))

In [29]:
# Build both vocabularies and the encoded dataset; batchloss/batch_loss read en_vocab as a global.
fr_vocab,en_vocab,dataset=read_data(DATA_PATH)

In [149]:
# Inspect one encoded pair: source indices, then target indices.
print(dataset[5][0],dataset[5][1])

tensor([ 7,  8, 39,  4,  1,  1,  1]) tensor([ 8,  6, 32,  4,  3,  1,  1])


In [150]:
# First ten entries of the French vocabulary (index -> token).
print(fr_vocab.itos[:10])

['<unk>', '<pad>', '<bos>', '<eos>', '.', 'est', 'elle', 'ils', 'sont', 'il']


In [151]:
# First ten entries of the English vocabulary (index -> token).
print(en_vocab.itos[:10])

['<unk>', '<pad>', '<bos>', '<eos>', '.', 'is', 'are', 'he', 'they', 'she']


In [152]:
# Number of sentence pairs in the dataset.
print(len(dataset))

20


# 模型建立

In [153]:
class Encoder(nn.Module):
    '''Encoder: an embedding layer followed by a multi-layer GRU.'''
    def __init__(self,vocab_size,embedding_size,hidden_size,num_layer):
        super().__init__()
        self.embedding=nn.Embedding(vocab_size,embedding_size)
        self.rnn=nn.GRU(input_size=embedding_size,hidden_size=hidden_size,num_layers=num_layer)
    def forward(self,x):
        '''
        x: batch_size,seq_len (token indices; cast to long here)
        Returns (output, state) exactly as produced by the GRU:
        output: seq_len,batch_size,hidden_size
        state: num_layer,batch_size,hidden_size
        '''
        embedded=self.embedding(x.long())  # batch_size,seq_len,embedding_size
        # The GRU is time-major (batch_first is left at its default), so swap
        # the batch and time axes before feeding it.
        time_major=embedded.permute(1,0,2)
        output,state=self.rnn(time_major)
        return output,state

In [154]:
class Attention(nn.Module):
    '''Attention over encoder outputs, scored by a small tanh MLP.

    For every encoder time step the MLP maps the concatenation of that step's
    encoder output and the decoder hidden state to a scalar score; scores are
    soft-maxed over the time axis and used to average the encoder outputs.
    '''
    def __init__(self,enc_hidden_size,dec_hidden_size,attention_size):
        super(Attention,self).__init__()
        self.weight=nn.Sequential(nn.Linear(enc_hidden_size+dec_hidden_size,attention_size),
                                 nn.Tanh(),
                                 nn.Linear(attention_size,1))
    def forward(self,encode_output,decode_hidden):
        '''
        encode_output: seq_len,batch_size,enc_hidden_size
        decode_hidden: batch_size,dec_hidden_size
        Returns the context vector: batch_size,enc_hidden_size
        '''
        seq_len=encode_output.size(0)
        # Repeat the decoder state along the time axis only. expand_as(encode_output)
        # would also force the last dimension to match, which fails whenever
        # enc_hidden_size != dec_hidden_size.
        query=decode_hidden.unsqueeze(dim=0).expand(seq_len,-1,-1)
        scores=self.weight(torch.cat((encode_output,query),dim=2))  # seq_len,batch_size,1
        w=nn.functional.softmax(scores,dim=0)  # normalised over time steps
        return (w*encode_output).sum(dim=0)
        

In [161]:
class Decoder(nn.Module):
    '''Single-step GRU decoder that attends over the encoder outputs.

    Each call to ``forward`` consumes one token per batch row, builds a context
    vector from the encoder outputs using the top-layer hidden state, and
    advances the GRU by one time step.
    '''
    def __init__(self,vocab_size,embedding_size,enc_hidden_size,dec_hidden_size,attention_size,num_layers):
        super(Decoder,self).__init__()
        self.embedding=nn.Embedding(vocab_size,embedding_size)
        self.attention=Attention(enc_hidden_size,dec_hidden_size,attention_size)
        # The GRU input is the token embedding concatenated with the context vector.
        self.rnn=nn.GRU(embedding_size+enc_hidden_size,dec_hidden_size,num_layers=num_layers)
        self.linear=nn.Linear(dec_hidden_size,vocab_size)
    def forward(self,x,encoder_output,state=None):
        '''
        x: batch_size, (token indices; cast to long here)
        encoder_output: seq_len,batch_size,enc_hidden_size
        state: num_layers,batch_size,dec_hidden_size, or None to start from zeros
        Returns (logits, state):
        logits: batch_size,vocab_size
        state: num_layers,batch_size,dec_hidden_size
        '''
        if state is None:
            # The declared default used to crash on state[-1]; start from the
            # same all-zero state the GRU itself assumes when none is given.
            state=encoder_output.new_zeros(self.rnn.num_layers,x.size(0),self.rnn.hidden_size)
        embedding=self.embedding(x.long())  # batch_size,embedding_size
        # Only the top layer's hidden state is used as the attention query.
        attention=self.attention(encoder_output,state[-1])  # batch_size,enc_hidden_size
        rnn_input=torch.cat((embedding,attention),dim=1).unsqueeze(dim=0)  # 1,batch_size,embedding_size+enc_hidden_size
        output,state=self.rnn(rnn_input,state)  # output: 1,batch_size,dec_hidden_size
        return self.linear(output).squeeze(dim=0),state

In [200]:
def batchloss(encoder,decoder,source,target,loss):
    '''Teacher-forced loss of one mini-batch, averaged over non-PAD target tokens.

    encoder, decoder: called as ``encoder(source)`` and
        ``decoder(step_input, encoder_output, state)``.
    source: batch_size,seq_len
    target: batch_size,seq_len
    loss: criterion returning one value per token; it is multiplied by the
        mask below (train passes CrossEntropyLoss(reduction="none")).
    Reads the global ``en_vocab`` for the BOS/PAD indices and the vocabulary size.
    Returns a scalar tensor.
    '''
    encoder_output,state=encoder(source)
    # Float tensor filled with the BOS index, one entry per batch row.
    decoder_input=torch.ones(source.size(0),)*en_vocab.stoi[BOS]
    step_logits=[]
    for gold_step in target.permute(1,0):
        logits,state=decoder(decoder_input,encoder_output,state)
        step_logits.append(logits)
        # Teacher forcing: the next input is the reference token, not the prediction.
        decoder_input=gold_step
    # Stacking along dim 1 yields batch_size,seq_len,vocab_size, which flattens
    # in the same row order as target.view(-1).
    predict=torch.stack(step_logits,dim=1).view(-1,len(en_vocab))
    label=target.view(-1,)
    not_pad=(label!=en_vocab.stoi[PAD]).float()
    token_losses=loss(predict,label)
    return (token_losses*not_pad).sum()/not_pad.sum()

In [187]:
def batch_loss(encoder, decoder, X, Y, loss):
    '''Step-wise variant of the mini-batch loss with an EOS-driven mask.

    X: batch, seq_len source indices.
    Y: batch, seq_len target indices.
    loss: criterion returning one value per batch row.
    Reads the global ``en_vocab`` for the BOS/EOS indices.
    Returns a tensor of shape (1,): summed masked loss divided by the number
    of counted tokens.
    '''
    bos_index = en_vocab.stoi[BOS]
    eos_index = en_vocab.stoi[EOS]
    n_rows = X.shape[0]
    # The encoder's final state initialises the decoder's hidden state.
    memory, hidden = encoder(X)
    # The decoder input at the first time step is BOS.
    step_input = torch.tensor([bos_index] * n_rows)
    # `alive` is the mask that drops the loss of PAD labels.
    alive = torch.ones(n_rows,)
    total = torch.tensor([0.0])
    counted = 0
    for gold in Y.permute(1,0):  # Y shape: (batch, seq_len)
        logits, hidden = decoder(step_input, memory, hidden)
        total = total + (alive * loss(logits, gold)).sum()
        counted += alive.sum().item()
        step_input = gold  # teacher forcing
        # Everything after EOS is PAD: once a row has produced EOS its mask
        # entry stays 0 for all remaining steps.
        alive = alive * (gold != eos_index).float()
    return total / counted

In [201]:
def train(encoder,decoder,source_vocab,target_vocab,dataset):
    '''Train encoder and decoder jointly for EPOCH_SIZE epochs.

    encoder, decoder: modules optimised with separate Adam optimisers (lr=0.01).
    source_vocab, target_vocab: accepted but not used by this function.
    dataset: dataset yielding (source, target) tensors; it is shuffled and
        batched with BATCH_SIZE.
    Prints, per epoch, the sum of the mini-batch losses returned by batchloss.
    '''
    optimizers=[optim.Adam(module.parameters(),lr=0.01) for module in (encoder,decoder)]
    # Per-token losses are needed so batchloss can mask out padding.
    loss=nn.CrossEntropyLoss(reduction="none")
    dataloader=data.DataLoader(dataset=dataset,batch_size=BATCH_SIZE,shuffle=True)
    for epoch in range(EPOCH_SIZE):
        epoch_loss=0
        for X,Y in dataloader:
            for optimizer in optimizers:
                optimizer.zero_grad()
            batch=batchloss(encoder,decoder,X,Y,loss)
            batch.backward()
            for optimizer in optimizers:
                optimizer.step()
            epoch_loss+=batch.cpu().item()
        print("Epoch:%d,Loss:%f"%(epoch+1,epoch_loss))

In [202]:
# Model hyper-parameters.
EMBEDDING_SIZE=64
HIDDEN_SIZE=64
ATTENTION_SIZE=10
# Encoder and decoder use the same hidden size and layer count (2): batchloss hands
# the encoder's final state to the decoder as its initial state.
encoder=Encoder(len(fr_vocab),EMBEDDING_SIZE,HIDDEN_SIZE,2)
decoder=Decoder(len(en_vocab),EMBEDDING_SIZE,HIDDEN_SIZE,HIDDEN_SIZE,ATTENTION_SIZE,2)
train(encoder,decoder,fr_vocab,en_vocab,dataset)

Epoch:1,Loss:15.396118
Epoch:2,Loss:10.650159
Epoch:3,Loss:7.820441
Epoch:4,Loss:6.193841
Epoch:5,Loss:5.002726
Epoch:6,Loss:4.024870
Epoch:7,Loss:3.434067
Epoch:8,Loss:2.930856
Epoch:9,Loss:2.576673
Epoch:10,Loss:2.305188
Epoch:11,Loss:1.944704
Epoch:12,Loss:1.913826
Epoch:13,Loss:1.598700
Epoch:14,Loss:1.716502
Epoch:15,Loss:1.410523
Epoch:16,Loss:1.260203
Epoch:17,Loss:1.108375
Epoch:18,Loss:0.974600
Epoch:19,Loss:0.855859
Epoch:20,Loss:0.779932
Epoch:21,Loss:0.692020
Epoch:22,Loss:0.666504
Epoch:23,Loss:0.587369
Epoch:24,Loss:0.535799
Epoch:25,Loss:0.485055
Epoch:26,Loss:0.432214
Epoch:27,Loss:0.407595
Epoch:28,Loss:0.388035
Epoch:29,Loss:0.338016
Epoch:30,Loss:0.350799
Epoch:31,Loss:0.376195
Epoch:32,Loss:0.280863
Epoch:33,Loss:0.275644
Epoch:34,Loss:0.257240
Epoch:35,Loss:0.218947
Epoch:36,Loss:0.198931
Epoch:37,Loss:0.168688
Epoch:38,Loss:0.163474
Epoch:39,Loss:0.146729
Epoch:40,Loss:0.129693
Epoch:41,Loss:0.158469
Epoch:42,Loss:0.128901
Epoch:43,Loss:0.105476
Epoch:44,Loss:0.09

In [203]:
def translate(encoder,decoder,fr_vocab,en_vocab,source_text):
    '''Greedily translate one sentence and print source and translation.

    encoder, decoder: trained modules, called as in batchloss.
    fr_vocab, en_vocab: vocabularies exposing ``stoi`` and ``itos``.
    source_text: whitespace-separated source sentence. It is lower-cased for
        the vocabulary lookup because read_data lower-cases the corpus; the
        printed source keeps its original casing.
    Decoding stops at EOS or after MAX_LEN tokens. Returns None.
    '''
    tokens=source_text.strip().lower().split()
    target=[]
    if tokens:
        eos_index=en_vocab.stoi[EOS]
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            in_seq=torch.tensor([[fr_vocab.stoi[w] for w in tokens]])
            encoder_output,state=encoder(in_seq)
            Y=torch.tensor([en_vocab.stoi[BOS]])
            while len(target)<MAX_LEN:
                decoder_output,state=decoder(Y,encoder_output,state)
                # Greedy choice: the highest-scoring token becomes the next input.
                w=decoder_output.argmax(dim=1).item()
                if w==eos_index:
                    break
                target.append(w)
                Y=torch.tensor([w])
    # An empty source skips the model (a zero-length sequence cannot be
    # encoded) and prints an empty translation.
    target_text=" ".join([en_vocab.itos[i] for i in target])
    print("French:%s   English:%s"%(source_text,target_text))

In [206]:
# Translate a sample sentence with the trained model.
translate(encoder,decoder,fr_vocab,en_vocab,"ils crevees .")

French:ils crevees .   English:they are exhausted .


In [165]:
# Size of the English vocabulary, special tokens included.
print(len(en_vocab.itos))

39


In [127]:
# Full French vocabulary (index -> token).
print(fr_vocab.itos)

['<unk>', '<pad>', '<bos>', '<eos>', '.', 'est', 'elle', 'ils', 'sont', 'il', 'mon', 'a', 'c', 'elles', '!', 'acteurs', 'adorable', 'age', 'amis', 'bonne', 'bonnes', 'canadienne', 'crevees', 'de', 'des', 'deux', 'disputent', 'du', 'ennuis', 'environ', 'fait', 'frere', 'genre', 'grands', 'japonaise', 'nageuse', 'oncle', 'personne', 'regardent', 'russes', 'se', 'tort', 'toutes', 'tranquille', 'une', 'velo', 'vieille']


In [207]:
import collections

In [224]:
def blue_score(label,predict,k=2):
    '''BLEU-style score of ``predict`` against the reference ``label``.

    label: reference token sequence.
    predict: predicted token sequence.
    k: highest n-gram order taken into account.

    The score is ``exp(min(0, 1 - len(label)/len(predict)))`` multiplied, for
    every order n = 1..k, by ``precision_n ** (0.5 ** n)``, where a predicted
    n-gram counts as matched at most as often as it occurs in ``label``.

    An empty prediction scores 0.0. Orders longer than the prediction have no
    n-grams to evaluate and are skipped instead of dividing by zero.
    Returns a float in [0, 1].
    '''
    label_len,predict_len=len(label),len(predict)
    if predict_len==0:
        return 0.0
    # Brevity penalty: predictions shorter than the reference are scaled down.
    score=math.exp(min(0,1-label_len/predict_len))
    for i in range(1,k+1):
        num_candidates=predict_len-i+1
        if num_candidates<=0:
            break
        # Tuples keep token boundaries; joining tokens into one string would
        # make ("ab","c") and ("a","bc") indistinguishable.
        label_subs=collections.Counter(tuple(label[j:j+i]) for j in range(label_len-i+1))
        num_match=0
        for j in range(num_candidates):
            sub=tuple(predict[j:j+i])
            if label_subs[sub]>0:
                num_match+=1
                # Consume the reference n-gram so it cannot be matched twice.
                label_subs[sub]-=1
        score*=math.pow(num_match/num_candidates,math.pow(1/2,i))
    return score

In [225]:
# Sanity check: a four-token reference against a three-token prediction.
in_seq='they are watching .'.split()
out_seq='they are .'.split()
print(blue_score(in_seq,out_seq))

0.6025286104785453
