In [1]:
import re
import numpy as np

In [2]:
MAX_LEN=50

pairs=[]
with open('train.txt',encoding='utf-8') as f:
    for line in f.readlines():
        word_re=re.compile(r"\w+")
        sent,title=line.split('\t')
        sent=word_re.findall(sent.lower())
        title=word_re.findall(title.lower())
        if len(sent)>MAX_LEN or len(title)>MAX_LEN:
            continue

        if sent[0] in ['i','you','he','she','we','they','us']:
            pairs.append([sent,title])
        
print(len(pairs))
for sent,title in pairs[:5]:
    print('-'*30)
    print(sent)
    print(title)


1724
------------------------------
['us', 'business', 'leaders', 'lashed', 'out', 'wednesday', 'at', 'legislation', 'that', 'would', 'penalize', 'companies', 'for', 'employing', 'illegal', 'immigrants']
['us', 'business', 'attacks', 'tough', 'immigration', 'law']
------------------------------
['us', 'first', 'lady', 'laura', 'bush', 'and', 'us', 'secretary', 'of', 'state', 'condoleezza', 'rice', 'will', 'represent', 'the', 'united', 'states', 'later', 'this', 'month', 'at', 'the', 'inauguration', 'of', 'liberia', 's', 'president', 'elect', 'ellen', 'johnson', 'sirleaf', 'the', 'white', 'house', 'said', 'wednesday']
['laura', 'bush', 'unk', 'rice', 'to', 'attend', 'sirleaf', 's', 'inauguration', 'in', 'liberia']
------------------------------
['us', 'auto', 'sales', 'will', 'likely', 'be', 'weaker', 'in', 'a', 'senior', 'executive', 'at', 'ford', 'motor', 'company', 'said', 'wednesday']
['ford', 'executive', 'sees', 'weaker', 'us', 'auto', 'sales', 'in']
------------------------------

In [3]:
en_vocab={}

en_vocab['<pad>'],en_vocab['<bos>'],en_vocab['<eos>']=0,1,2
en_idx=3

for sent,title in pairs:
    for w in sent:
        if w not in en_vocab:
            en_vocab[w]=en_idx
            en_idx+=1
    
    for w in title:
        if w not in en_vocab:
            en_vocab[w]=en_idx
            en_idx+=1



In [4]:
train_sents=[]
train_titles=[]
train_labels=[]

for sent,title in pairs:
    train_sent=sent+["<eos>"]+["<pad>"]*(MAX_LEN-len(sent))
    train_title=["<bos>"]+title+["<eos>"]+["<pad>"]*(MAX_LEN-len(title))
    train_label=title+["<eos>"]+["<pad>"]*(MAX_LEN-len(title)+1)

    train_sent.reverse()

    train_sents.append([en_vocab[w] for w in train_sent])
    train_titles.append([en_vocab[w] for w in train_title])
    train_labels.append([en_vocab[w] for w in train_label])

train_sents=np.array(train_sents)
train_titles=np.array(train_titles)
train_labels=np.array(train_labels)

print(train_sents.shape)
print(train_titles.shape)
print(train_labels.shape)

(1724, 51)
(1724, 52)
(1724, 52)


# 配置模型

In [5]:
import paddle
import paddle.nn.functional as F

In [6]:
embedding_size=128
hidden_size=256
layers=1
epochs=20
batch_size=16
en_vocab_size=len(en_vocab)
en_vocab_size

7939

In [7]:
class Encoder(paddle.nn.Layer):
    def __init__(self):
        super(Encoder,self).__init__()

        self.emb=paddle.nn.Embedding(en_vocab_size,embedding_size)
        self.lstm=paddle.nn.LSTM(embedding_size,hidden_size=hidden_size,num_layers=layers)

    def forward(self,x):
        # print(x.shape)
        x=self.emb(x)
        #x:[batch_size,MAX_LEN+1,embedding_size]
        x,(_,_)=self.lstm(x)
        #x:[batch_size,MAX_LEN+1,hidden_size]
        return x
        

In [8]:
class AttentionDecoder(paddle.nn.Layer):
    
    def __init__(self):
        super(AttentionDecoder,self).__init__()
        self.emb=paddle.nn.Embedding(en_vocab_size,embedding_size)

        self.attention_linear1=paddle.nn.Linear(hidden_size*2,hidden_size)
        self.attention_linear2=paddle.nn.Linear(hidden_size,1)

        self.lstm=paddle.nn.LSTM(embedding_size+hidden_size,hidden_size=hidden_size,num_layers=layers)

        self.outlinear=paddle.nn.Linear(hidden_size,en_vocab_size)
    
    def forward(self,word,hidden,cell,en_repr):

        #word:[batch_size,1]
        #hidden,cell:[batch_size,1,hidden_size]
        #en_repr:[batch_size,MAX_LEN+1,hidden_size]

        word=self.emb(word)
        #word:[batch_size,1,embedding_size]

        attention_inputs=paddle.concat([paddle.tile(hidden,[1,MAX_LEN+1,1]),en_repr],axis=-1)
        #[batch_size,MAX_LEN+1,hidden_size*2]
        #last word of title+ every word of sent

        attention_inputs=self.attention_linear1(attention_inputs)
        attention_inputs=F.tanh(attention_inputs)

        attention_inputs=self.attention_linear2(attention_inputs)
        #[batch_size,MAX_LEN+1,1]

        attention_inputs=paddle.squeeze(attention_inputs)
        attention_inputs=F.softmax(attention_inputs)
        weights=paddle.unsqueeze(attention_inputs,axis=-1)

        weights=paddle.expand_as(weights,en_repr)
        
        context_vector=paddle.multiply(weights,en_repr)
        context_vector=paddle.sum(context_vector,axis=1)
        #[batch_size,hidden_size]

        context_vector=paddle.unsqueeze(context_vector,axis=1)
        #[batch_size,1,hidden_size]

        #word:[batch_size,1,embedding_size]
        lstm_inputs=paddle.concat([word,context_vector],axis=-1)

        hidden=paddle.transpose(hidden,[1,0,2])
        cell=paddle.transpose(cell,[1,0,2])
        output,(hidden,cell)=self.lstm(lstm_inputs,(hidden,cell))
        hidden=paddle.transpose(hidden,[1,0,2])
        cell=paddle.transpose(cell,[1,0,2])

        #hidden,cell:[batch_size,1,hidden_size]
        output=self.outlinear(hidden)
        #[batch_size,1,en_vocab_size]
        
        output=paddle.squeeze(output)
        #[batch_size,en_vocab_size]

        return output,(hidden,cell)

# 训练

In [9]:
encoder=Encoder()
decoder=AttentionDecoder()
opt=paddle.optimizer.Adam(learning_rate=0.001,parameters=encoder.parameters()+decoder.parameters())

for epoch in range(epochs):
    print(f"epoch:{epoch}")

    perm=np.random.permutation(len(train_sents))
    train_sents_shuffle=train_sents[perm]
    train_labels_shuffle=train_labels[perm]
    train_titles_shuffle=train_titles[perm]

    for iteration in range(train_sents.shape[0]//batch_size):
        x_sents=train_sents_shuffle[iteration*batch_size:(iteration+1)*batch_size]
        x_title=train_titles_shuffle[iteration*batch_size:(iteration+1)*batch_size]
        y=train_labels_shuffle[iteration*batch_size:(iteration+1)*batch_size]

        x_sents=paddle.to_tensor(x_sents)
        en_repr=encoder(x_sents)

        hidden=paddle.zeros([batch_size,1,hidden_size])
        cell=paddle.zeros([batch_size,1,hidden_size])

        loss=paddle.zeros([1])
        for i in range(MAX_LEN+2):
            word=x_title[:,i:i+1]
            label=y[:,i]
            word=paddle.to_tensor(word)
            label=paddle.to_tensor(label)

            logits,(hidden,cell)=decoder(word,hidden,cell,en_repr)
            #logits:[batch_size,en_vocab_size]
            loss+=F.cross_entropy(logits,label)

            # word=paddle.argmax(logits,axis=1)#多余
            #word:[batch_size,1]

        if iteration%50==0:
            print(f"step: {iteration},loss: {(loss/(MAX_LEN+2)).numpy()[0]}")

        loss.backward()
        opt.step()
        opt.clear_grad()

        

W0807 11:05:58.148721  6171 gpu_resources.cc:61] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 11.2, Runtime API Version: 10.1
W0807 11:05:58.152729  6171 gpu_resources.cc:91] device: 0, cuDNN Version: 7.6.


epoch:0
step: 0,loss: 8.987388610839844
step: 50,loss: 1.599442720413208
step: 100,loss: 1.350903868675232
epoch:1
step: 0,loss: 1.182124376296997
step: 50,loss: 1.2501391172409058
step: 100,loss: 1.2571748495101929
epoch:2
step: 0,loss: 1.0913419723510742
step: 50,loss: 1.0296783447265625
step: 100,loss: 1.1578869819641113
epoch:3
step: 0,loss: 1.1276209354400635
step: 50,loss: 1.0401471853256226
step: 100,loss: 1.1024224758148193
epoch:4
step: 0,loss: 1.1479531526565552
step: 50,loss: 1.168302297592163
step: 100,loss: 1.0650542974472046
epoch:5
step: 0,loss: 0.9886360764503479
step: 50,loss: 1.061771035194397
step: 100,loss: 1.052024245262146
epoch:6
step: 0,loss: 0.9993051290512085
step: 50,loss: 1.0444704294204712
step: 100,loss: 0.9912765622138977
epoch:7
step: 0,loss: 0.9642326831817627
step: 50,loss: 1.0607023239135742
step: 100,loss: 0.9567527174949646
epoch:8
step: 0,loss: 0.9924283027648926
step: 50,loss: 1.1005970239639282
step: 100,loss: 1.0446380376815796
epoch:9
step: 0,l

# 预测

In [10]:
encoder.eval()
decoder.eval()

nums=10
indices=np.random.choice(len(train_sents),nums,replace=False)
x_sents=train_sents[indices]

word=[[en_vocab["<bos>"]]]*nums
hidden=paddle.zeros([nums,1,hidden_size])
cell=paddle.zeros([nums,1,hidden_size])

res=[]
x_sents=paddle.to_tensor(x_sents)
word=paddle.to_tensor(word)
en_repr=encoder(x_sents)
for i in range(MAX_LEN+2):
    #[nums,1]
    logits,(hidden,cell)=decoder(word,hidden,cell,en_repr)
    #[nums,en_vocab_size]
    word=paddle.argmax(logits,axis=1)

    # print(word.shape)
    #[nums]
    res.append(word)

    word=paddle.unsqueeze(word,axis=-1)

res=np.stack(res,axis=-1)

for i in range(nums):
    x_sent=' '.join(pairs[indices[i]][0])
    ground_truth=' '.join(pairs[indices[i]][1])
    pred=""
    for w in res[i]:
        w=list(en_vocab)[w]
        if w!="<pad>" and w!="<eos>":
            pred+=" "+w
    
    print('-'*30)
    print("sent:",x_sent)
    print("true:",ground_truth)
    print("pred:",pred)



------------------------------
sent: us forces have arrested the spiritual head of a kurdish islamist group ali abdul aziz and other people in the northern town of halabja an official of the group said sunday
true: islamist group in iraqi kurdistan says us army arrested spiritual guide
pred:  us airways to visit of the a be a victory
------------------------------
sent: us federal reserve lrb fed rrb opened a two day meeting on tuesday to discuss the monetary policy in the future with most analysts predicting that the us short term interest rates will be surely raised by percentage point
true: us fed opens meeting to discuss future monetary policy
pred:  us airways to visit of the a be a victory
------------------------------
sent: us secretary of state condoleezza rice said wednesday she was satisfied with iraqi government efforts to include minority sunni muslims in the unk process
true: rice satisfied with inclusive iraqi political process
pred:  us airways to visit on us troops to 