In [1]:
import tensorflow as tf
import json
import jieba
import os

In [2]:
# Load the vocabulary mappings produced at training time.
# JSON is UTF-8 by specification (RFC 8259); pin the encoding so the Chinese
# vocabulary loads the same way on platforms whose locale default is not UTF-8.
with open('./word2id.json', 'r', encoding='utf-8') as f:
    word2id = json.load(f)
with open('./id2word.json', 'r', encoding='utf-8') as f:
    id2word = json.load(f)
# JSON object keys are always strings; convert them back to integer ids.
id2word = {int(k): v for k, v in id2word.items()}

In [3]:
class Encoder(tf.keras.Model):
    """Sequence encoder: token-id embedding followed by a single GRU layer.

    NOTE: tf.train.Checkpoint tracks sub-layers by attribute name
    (``embedding``, ``gru``); renaming them would presumably prevent existing
    checkpoints from restoring into this model.
    """

    def __init__(self, vocab_size, embedding_dim, encoder_units, batch_size):
        super().__init__()
        self.batch_size = batch_size
        self.encoder_units = encoder_units
        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
        gru_options = dict(return_sequences=True,
                           return_state=True,
                           recurrent_initializer='glorot_uniform')
        self.gru = tf.keras.layers.GRU(self.encoder_units, **gru_options)

    def call(self, x, hidden):
        """Embed the token ids ``x`` and run the GRU from initial state ``hidden``.

        Returns a tuple ``(output, state)``: the GRU output for every time step
        (return_sequences=True) and the final hidden state (return_state=True).
        """
        embedded = self.embedding(x)
        sequence_output, final_state = self.gru(embedded, initial_state=hidden)
        return sequence_output, final_state

    def initialize_hidden_state(self):
        """Return an all-zero hidden state of shape (batch_size, encoder_units)."""
        state_shape = (self.batch_size, self.encoder_units)
        return tf.zeros(state_shape)

In [4]:
class Attention(tf.keras.layers.Layer):
    """Additive attention: score = V(tanh(W1(values) + W2(query))).

    NOTE: tf.train.Checkpoint tracks the sub-layers by attribute name
    (``W1``, ``W2``, ``V``); renaming them would presumably prevent existing
    checkpoints from restoring into this layer.
    """

    def __init__(self, units):
        super().__init__()
        # W1 projects the attended values, W2 projects the query,
        # V collapses each position to a single scalar score.
        self.W1 = tf.keras.layers.Dense(units)
        self.W2 = tf.keras.layers.Dense(units)
        self.V = tf.keras.layers.Dense(1)

    def call(self, query, values):
        """Attend over ``values`` using ``query``.

        Axis 1 of ``values`` is treated as the time axis (softmax and sum run
        over it). Presumably ``values`` is (batch, time, features) and ``query``
        is (batch, features) — TODO confirm.

        Returns ``(context_vector, attention_weights)``: the attention-weighted
        sum of ``values`` over axis 1, and the softmax-normalized weights.
        """
        # Insert a length-1 time axis so the query broadcasts over every position.
        query_with_time_axis = tf.expand_dims(query, 1)
        energy = tf.nn.tanh(self.W1(values) + self.W2(query_with_time_axis))
        score = self.V(energy)

        attention_weights = tf.nn.softmax(score, axis=1)
        context_vector = tf.reduce_sum(attention_weights * values, axis=1)
        return context_vector, attention_weights

In [5]:
class Decoder(tf.keras.Model):
    """One-step attention decoder producing vocabulary logits.

    NOTE: tf.train.Checkpoint tracks sub-layers by attribute name
    (``embedding``, ``gru``, ``fc``, ``attention``); renaming them would
    presumably prevent existing checkpoints from restoring into this model.
    """

    def __init__(self, vocab_size, embedding_dim, decoder_units, batch_size):
        super().__init__()
        self.batch_size = batch_size
        self.decoder_units = decoder_units
        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
        self.gru = tf.keras.layers.GRU(decoder_units,
                                       return_sequences=True,
                                       return_state=True,
                                       recurrent_initializer='glorot_uniform')
        self.fc = tf.keras.layers.Dense(vocab_size)
        self.attention = Attention(decoder_units)

    def call(self, x, hidden, encoder_output):
        """Run one decoding step.

        ``x`` must have a time dimension of 1, since the embedded tokens are
        concatenated with a length-1 context vector along the feature axis.

        Returns ``(logits, state, attention_weights)`` where ``logits`` has one
        row per batch element and ``vocab_size`` columns.
        """
        context_vector, attention_weights = self.attention(hidden, encoder_output)
        embedded = self.embedding(x)
        # Prepend the attention context to the token embedding as extra features.
        gru_input = tf.concat([tf.expand_dims(context_vector, 1), embedded], axis=-1)
        # NOTE(review): `hidden` is used only as the attention query; the GRU is
        # not given it as initial_state, so it starts from zeros on every step.
        # Kept as-is because changing it would alter a trained model's behaviour.
        gru_output, state = self.gru(gru_input)

        # Drop the length-1 time axis: (batch, 1, units) -> (batch, units).
        flattened = tf.reshape(gru_output, (-1, gru_output.shape[2]))
        logits = self.fc(flattened)
        return logits, state, attention_weights

In [6]:
# Presumably the number of training pairs (and the shuffle-buffer size) — TODO confirm.
BUFFER_SIZE = 114969
BATCH_SIZE = 64
# Derived from BUFFER_SIZE rather than repeating the literal, so the two cannot drift apart.
steps_per_epoch = BUFFER_SIZE // BATCH_SIZE
embedding_dim = 100
# GRU width shared by the encoder and the decoder.
units = 512
vocab_size = 1 + len(word2id)  # one extra index beyond word2id; presumably reserved (e.g. padding id 0) — TODO confirm

In [7]:
# Encoder and decoder share the vocabulary, embedding size, GRU width and batch size.
encoder = Encoder(vocab_size=vocab_size, embedding_dim=embedding_dim,
                  encoder_units=units, batch_size=BATCH_SIZE)
decoder = Decoder(vocab_size=vocab_size, embedding_dim=embedding_dim,
                  decoder_units=units, batch_size=BATCH_SIZE)

In [8]:
# Object-based checkpoint bundling the optimizer and both models.
optimizer = tf.keras.optimizers.RMSprop()
checkpoint_dir = './training_checkpoints'
# Save-path prefix; presumably passed to checkpoint.save() during training.
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(
    optimizer=optimizer,
    encoder=encoder,
    decoder=decoder,
)

In [9]:
# tf.train.latest_checkpoint returns None when the directory holds no checkpoint,
# and restore(None) is a silent no-op; fail loudly instead of predicting with
# randomly initialised weights.
latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
if latest_checkpoint is None:
    raise FileNotFoundError(
        'No checkpoint found in {!r}; train the model first.'.format(checkpoint_dir))
checkpoint.restore(latest_checkpoint)

<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x1a5ae8c1518>

In [14]:
def predict(sentence, max_input_len=20, max_output_len=20):
    """Generate a reply to ``sentence`` using greedy (arg-max) decoding.

    Steps:
    1. Remove ASCII letters, digits and a fixed set of punctuation characters.
    2. Segment the text with ``jieba`` and map each word to an id via ``word2id``.
       Out-of-vocabulary words become the ``<pad>`` id.
    3. Pad or truncate the ids to ``max_input_len`` and encode them.
    4. Decode one token per step, starting from ``<start>`` and feeding each
       predicted id back in. Stop when ``<end>`` is predicted or after
       ``max_output_len`` steps.

    Args:
        sentence: Raw input text.
        max_input_len: Length the id sequence is padded/truncated to. Defaults
            to 20, the previously hard-coded value; presumably it should match
            the length used in training — TODO confirm.
        max_output_len: Maximum number of decoding steps. Defaults to 20.

    Returns:
        The predicted words concatenated with no separator (``<end>`` excluded).
    """
    # Characters stripped before segmentation; translate() deletes them in one pass.
    strip_table = str.maketrans(
        '', '',
        "#@$%^&*():;：；{}[]'_<>-+/~0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ")
    cleaned = sentence.translate(strip_table)

    pad_id = word2id["<pad>"]
    token_ids = [word2id.get(word, pad_id) for word in jieba.cut(cleaned)]
    # padding='post' appends pad ids; Keras' default truncating='pre' keeps the
    # last max_input_len ids when the input is longer.
    padded = tf.keras.preprocessing.sequence.pad_sequences(
        [token_ids], value=pad_id, maxlen=max_input_len, padding='post')
    encoder_input = tf.convert_to_tensor(padded)

    # Only one sentence is encoded, so the batch size is 1.
    hidden = [tf.zeros((1, encoder.encoder_units))]
    encoder_output, encoder_hidden = encoder(encoder_input, hidden)

    decoder_hidden = encoder_hidden
    decoder_input = tf.expand_dims([word2id["<start>"]], 0)

    result = ''
    for _step in range(max_output_len):
        predictions, decoder_hidden, _ = decoder(decoder_input,
                                                 decoder_hidden,
                                                 encoder_output)
        predicted_id = int(tf.argmax(predictions[0]).numpy())

        # The decoder emits len(word2id) + 1 logits. Presumably id2word has
        # len(word2id) entries, leaving an index with no word; plain indexing
        # would raise KeyError if it were predicted, so treat it as empty text.
        word = id2word.get(predicted_id, '')
        if word == "<end>":
            break
        result += word

        decoder_input = tf.expand_dims([predicted_id], 0)

    return result

In [17]:
predict(sentence="百度一下")  # sample query: "look it up on Baidu"

' 你说的什么哦？换种说法行不'