In [1]:
import re
import numpy as np

In [2]:
# Maximum number of tokens allowed in a sentence or a title; longer pairs are dropped.
MAX_LEN=50
# A pair is kept only when the first token of its sentence is one of these words.
START_WORDS=('i','you','he','she','we','they','us')
# Tokenizer: runs of word characters. Compiled once instead of once per input line.
word_re=re.compile(r"\w+")

pairs=[]
skipped=0
with open('train.txt',encoding='utf-8') as f:
    # Iterate over the file lazily rather than materialising it with readlines().
    for line in f:
        fields=line.split('\t')
        if len(fields)!=2:
            # A blank or malformed line (not exactly "sentence<TAB>title") would make
            # tuple-unpacking raise ValueError and abort the whole cell; count it and move on.
            skipped+=1
            continue

        sent,title=fields
        sent=word_re.findall(sent.lower())
        title=word_re.findall(title.lower())
        if len(sent)>MAX_LEN or len(title)>MAX_LEN:
            continue

        # `sent` is empty when the sentence field holds no word characters,
        # so check it before indexing the first token.
        if sent and sent[0] in START_WORDS:
            pairs.append([sent,title])

# Only reported when something was actually skipped, so clean data prints exactly as before.
if skipped:
    print(f"skipped {skipped} malformed line(s)")

print(len(pairs))
for sent,title in pairs[:5]:
    print('-'*30)
    print(sent)
    print(title)


1724
------------------------------
['us', 'business', 'leaders', 'lashed', 'out', 'wednesday', 'at', 'legislation', 'that', 'would', 'penalize', 'companies', 'for', 'employing', 'illegal', 'immigrants']
['us', 'business', 'attacks', 'tough', 'immigration', 'law']
------------------------------
['us', 'first', 'lady', 'laura', 'bush', 'and', 'us', 'secretary', 'of', 'state', 'condoleezza', 'rice', 'will', 'represent', 'the', 'united', 'states', 'later', 'this', 'month', 'at', 'the', 'inauguration', 'of', 'liberia', 's', 'president', 'elect', 'ellen', 'johnson', 'sirleaf', 'the', 'white', 'house', 'said', 'wednesday']
['laura', 'bush', 'unk', 'rice', 'to', 'attend', 'sirleaf', 's', 'inauguration', 'in', 'liberia']
------------------------------
['us', 'auto', 'sales', 'will', 'likely', 'be', 'weaker', 'in', 'a', 'senior', 'executive', 'at', 'ford', 'motor', 'company', 'said', 'wednesday']
['ford', 'executive', 'sees', 'weaker', 'us', 'auto', 'sales', 'in']
------------------------------

In [3]:
# Token -> integer id. The three special tokens take the first ids.
en_vocab={'<pad>':0,'<bos>':1,'<eos>':2}

# Ids are handed out in first-seen order: for every pair, the sentence tokens
# are visited before the title tokens.
for sent,title in pairs:
    for token in sent+title:
        # len(en_vocab) is evaluated before a possible insertion, so it is
        # always the next unused id; known tokens keep their existing id.
        en_vocab.setdefault(token,len(en_vocab))

# Next free id, i.e. the vocabulary size.
en_idx=len(en_vocab)



In [4]:
# Build three fixed-width id matrices from the token pairs:
#   train_sents  - encoder input, width MAX_LEN+1, stored back to front
#   train_titles - decoder input, width MAX_LEN+2, starts with <bos>
#   train_labels - decoder target, width MAX_LEN+2, the title shifted left by one
train_sents,train_titles,train_labels=[],[],[]

for sent,title in pairs:
    sent_pad=["<pad>"]*(MAX_LEN-len(sent))
    title_pad=["<pad>"]*(MAX_LEN-len(title))

    # Reversed source layout: padding first, then <eos>, then the words in reverse order.
    src_tokens=sent_pad+["<eos>"]+sent[::-1]
    # Decoder input and target differ by one position (teacher forcing);
    # the target gets one extra <pad> so both have the same width.
    dec_in_tokens=["<bos>"]+title+["<eos>"]+title_pad
    dec_out_tokens=title+["<eos>"]+title_pad+["<pad>"]

    train_sents.append([en_vocab[t] for t in src_tokens])
    train_titles.append([en_vocab[t] for t in dec_in_tokens])
    train_labels.append([en_vocab[t] for t in dec_out_tokens])

train_sents=np.array(train_sents)
train_titles=np.array(train_titles)
train_labels=np.array(train_labels)

for matrix in (train_sents,train_titles,train_labels):
    print(matrix.shape)

(1724, 51)
(1724, 52)
(1724, 52)


# 配置模型

In [5]:
import paddle
import paddle.nn.functional as F

In [6]:
# Model and training hyper-parameters.
# NOTE(review): `layers` is not referenced by the Encoder/Decoder cells shown in
# this notebook (they pass a literal layer count) - confirm whether it is still needed.
embedding_size,layers=128,1
epochs,batch_size=20,16

# Number of distinct token ids, special tokens included.
en_vocab_size=len(en_vocab)
en_vocab_size

7939

In [7]:
class Encoder(paddle.nn.Layer):
    """Embeds source token ids and runs them through a 2-layer Transformer encoder.

    Sizes come from the module-level `en_vocab_size` and `embedding_size`.
    """

    def __init__(self):
        super().__init__()

        # Token-id -> vector lookup table.
        self.emb=paddle.nn.Embedding(
            num_embeddings=en_vocab_size,
            embedding_dim=embedding_size,
        )
        # Prototype layer: 8 attention heads, feed-forward width 512.
        self.layer=paddle.nn.TransformerEncoderLayer(
            d_model=embedding_size,
            nhead=8,
            dim_feedforward=512,
        )
        # Stack of two such layers.
        self.encoder=paddle.nn.TransformerEncoder(self.layer,num_layers=2)

    def forward(self,x):
        """Encode a batch of token ids.

        x: integer ids, [batch_size, seq_len] (seq_len is MAX_LEN+1 for the
           training matrix built in this notebook).
        Returns the encoder output, [batch_size, seq_len, embedding_size].
        No attention mask is passed to the encoder stack.
        """
        embedded=self.emb(x)
        return self.encoder(embedded)

In [8]:
class Decoder(paddle.nn.Layer):
    """Embeds target token ids, attends over the encoder output with a 2-layer
    Transformer decoder, and projects the result to vocabulary logits.

    Sizes come from the module-level `en_vocab_size` and `embedding_size`.
    """

    def __init__(self):
        super(Decoder,self).__init__()
        # Token-id -> vector lookup table for the target side.
        self.emb=paddle.nn.Embedding(en_vocab_size,embedding_size)
        # Prototype layer: 8 attention heads, feed-forward width 512.
        self.layer=paddle.nn.TransformerDecoderLayer(embedding_size,8,512)
        self.decoder=paddle.nn.TransformerDecoder(self.layer,2)

        # Projects each decoder state to one score per vocabulary entry.
        self.outlinear=paddle.nn.Linear(embedding_size,en_vocab_size)

    def forward(self,x,en_repr):
        """Score the next token for each position of `x`.

        x:       integer ids, [batch_size, 1] as called from this notebook.
        en_repr: encoder output used as attention memory.
        Returns logits of shape [batch_size, en_vocab_size] for single-token
        input. If the sequence axis is longer than 1 it is left in place.
        No target or memory mask is passed to the decoder stack.
        """
        x=self.emb(x)
        #[batch_size,1,embedding_size]

        x=self.decoder(x,en_repr)
        #[batch_size,1,embedding_size]

        x=self.outlinear(x)
        #[batch_size,1,en_vocab_size]

        # Drop only the length-1 sequence axis. A bare squeeze() would also
        # remove the batch axis when batch_size is 1 and break callers that
        # index or reduce along axis 1.
        x=paddle.squeeze(x,axis=1)
        return x

# 训练

In [None]:
encoder=Encoder()
decoder=Decoder()
# A single Adam optimizer updates the weights of both networks.
opt=paddle.optimizer.Adam(
    learning_rate=0.001,
    parameters=encoder.parameters()+decoder.parameters(),
)

# Whole batches only: a trailing partial batch is never visited.
n_batches=train_sents.shape[0]//batch_size

for epoch in range(epochs):
    print(f"epoch:{epoch}")

    # New random sample order for every epoch.
    perm=np.random.permutation(len(train_sents))

    for iteration in range(n_batches):
        # Row indices of this mini-batch, taken from the shuffled order.
        batch=perm[iteration*batch_size:(iteration+1)*batch_size]
        x_title=train_titles[batch]
        y=train_labels[batch]

        en_repr=encoder(paddle.to_tensor(train_sents[batch]))

        # Teacher forcing, one target position at a time: the decoder gets the
        # gold token at position i and is scored against the label at position i.
        # NOTE(review): each call passes only the current token plus the encoder
        # output, so the decoder never sees earlier target tokens - confirm this
        # is intended.
        loss=paddle.zeros([1])
        for i in range(MAX_LEN+2):
            word=paddle.to_tensor(x_title[:,i:i+1])
            label=paddle.to_tensor(y[:,i])

            #logits:[batch_size,en_vocab_size]
            logits=decoder(word,en_repr)
            loss+=F.cross_entropy(logits,label)

        # Log the average per-position loss every 50 steps.
        if iteration%50==0:
            print(f"step: {iteration},loss: {(loss/(MAX_LEN+2)).numpy()[0]}")

        loss.backward()
        opt.step()
        opt.clear_grad()

        

epoch:0
step: 0,loss: 9.082996368408203
step: 50,loss: 1.4103511571884155


# 预测

In [None]:
# Greedy decoding on a random handful of training sentences.
encoder.eval()
decoder.eval()

# Never ask for more samples than exist (choice(..., replace=False) would raise).
nums=min(10,len(train_sents))
indices=np.random.choice(len(train_sents),nums,replace=False)

# id -> token lookup, built once instead of rebuilding list(en_vocab) for every decoded token.
idx2word={idx:token for token,idx in en_vocab.items()}

x_sents=paddle.to_tensor(train_sents[indices])
# Every sequence starts from <bos>; shape [nums,1].
word=paddle.to_tensor([[en_vocab["<bos>"]]]*nums)

res=[]
# Inference only: skip autograd bookkeeping.
with paddle.no_grad():
    en_repr=encoder(x_sents)
    for i in range(MAX_LEN+2):
        #[nums,en_vocab_size]
        logits=decoder(word,en_repr)
        # Most likely token per sequence, shape [nums].
        word=paddle.argmax(logits,axis=1)
        # Store as NumPy so the later np.stack does not depend on implicit tensor conversion.
        res.append(word.numpy())
        # Back to [nums,1] to serve as the next decoder input.
        word=paddle.unsqueeze(word,axis=-1)

# [nums, MAX_LEN+2]: one row of predicted ids per sampled sentence.
res=np.stack(res,axis=-1)

for i in range(nums):
    x_sent=' '.join(pairs[indices[i]][0])
    ground_truth=' '.join(pairs[indices[i]][1])

    pred_words=[]
    for w in res[i]:
        w=idx2word[int(w)]
        # The prediction ends at the first <eos>; anything generated after it is discarded.
        if w=="<eos>":
            break
        if w!="<pad>":
            pred_words.append(w)
    # join() avoids the stray leading space that string concatenation produced.
    pred=' '.join(pred_words)

    print('-'*30)
    print("sent:",x_sent)
    print("true:",ground_truth)
    print("pred:",pred)

