In [1]:
import tensorflow as tf
from tensorflow import keras

In [46]:
class Encoder(keras.Model):
    def __init__(self,vocab_size,embdedding_dim,encoder_units):
        super(Encoder,self).__init__()
        
        # Embedding 层
        self.encoder_embedding = keras.layers.Embedding(vocab_size,
                                                       embdedding_dim, 
                                                       mask_zero=True,
                                                       name="encoder_embed_layer")
        
        # gru层
        self.encoder_gru = keras.layers.GRU(encoder_units,
                                           return_sequences=True,
                                           return_state=True,
                                           recurrent_initializer="glorot_uniform",
                                           name="encoder_gru_layer")
        
    def call(self,inputs):
        encoder_inps_embed = self.encoder_embedding(inputs)
        encoder_outs,encoder_state = self.encoder_gru(encoder_inps_embed)
        return encoder_outs,encoder_state
        

In [47]:
class Decoder(keras.Model):
    def __init__(self,vocab_size,embdedding_dim,decoder_units):
        super(Decoder,self).__init__()
        
        # Embedding 层
        self.decoder_embedding = keras.layers.Embedding(vocab_size,
                                                       embdedding_dim, 
                                                       mask_zero=True,
                                                       name="decoder_embed_layer")
        
        # attention 层
        self.decoder_atten = keras.layers.Attention(name="decoder_atten_layer")
        
        # gru层
        self.decoder_gru = keras.layers.GRU(decoder_units,
                                           return_sequences=True,
                                           return_state=True,
                                           recurrent_initializer="glorot_uniform",
                                           name="encoder_gru_layer")
        
    def call(self,encoder_outs,decoder_inps,states):
        decoder_inps_embed = self.decoder_embedding(decoder_inps)
        decoder_outs,decoder_state = self.decoder_gru(decoder_inps_embed,
                                                      initial_state=states)
        attention_output = self.decoder_atten([decoder_outs,encoder_outs])
        
        return attention_output,decoder_state

In [50]:
def Seq2Seq(maxlen,embedding_dim,units,vocab_size):
    #Input layer
    encoder_iputs = keras.Input(shape=(maxlen,),name="encoder_input")
    decoder_iputs = keras.Input(shape=(None,),name="decoder_input")  
    
    # Encoder layer
    encoder = Encoder(vocab_size,embedding_dim,units)
    enc_outs,enc_state = encoder(encoder_iputs)
    
    # Decoder layer
    decoder = Decoder(vocab_size,embedding_dim,units)
    atten_outs,dec_state = decoder(enc_outs,decoder_iputs,enc_state)
    
    # Dense layer
    dense_outs = keras.layers.Dense(vocab_size, activation='softmax', name="dense")(atten_outs)
    
    # Seq2Seq model
    model = keras.Model(inputs=[encoder_iputs,decoder_iputs],outputs=dense_outs)
    
    return model

In [5]:
def read_vocab(vocab_path):
    vocab_words= []
    with open(vocab_path, "r", encoding="utf-8") as f:
        for line in f:
            vocab_words.append(line.strip())
    return vocab_words

def read_data(data_path):
    datas= []
    with open(data_path, "r", encoding="utf-8") as f:
        for line in f:
            words = line.strip().split()
            datas.append(words)
    return datas

def process_data_index(datas,vocab2id):
    data_indexs = []
    for words in datas:
        line_index = [vocab2id[w] if w in vocab2id else vocab2id["<UNK>"] for w in words]
        data_indexs.append(line_index)
    return data_indexs

In [7]:
vocab_words = read_vocab("./datasets/ch_word_vocab.txt")
vocab_words[:5]

['呵呵', '不是', '怎么', '了', '开心']

In [8]:
special_words = ["<PAD>", "<UNK>", "<GO>", "<EOS>"]
vocab_words = special_words + vocab_words
vocab_words[:5]

['<PAD>', '<UNK>', '<GO>', '<EOS>', '呵呵']

In [9]:
word2id = {word:i for i,word in enumerate(vocab_words)}
id2word = {i:word for i,word in enumerate(vocab_words)}

In [13]:
word2id

{'<PAD>': 0,
 '<UNK>': 1,
 '<GO>': 2,
 '<EOS>': 3,
 '呵呵': 4,
 '不是': 5,
 '怎么': 6,
 '了': 7,
 '开心': 8,
 '点': 9,
 '哈': 10,
 ',': 11,
 '一切': 12,
 '都会': 13,
 '好': 14,
 '起来': 15,
 '我': 16,
 '还': 17,
 '喜欢': 18,
 '她': 19,
 '怎么办': 20,
 '短信': 21,
 '你': 22,
 '知道': 23,
 '谁': 24,
 '么': 25,
 '许兵': 26,
 '是': 27,
 '这么': 28,
 '假': 29,
 '傻': 30,
 '逼': 31,
 '到底': 32,
 '尼玛': 33,
 '小黄': 34,
 '鸭': 35,
 '有': 36,
 '女朋友': 37,
 '那': 38,
 '男朋友': 39,
 '在': 40,
 '哪': 41,
 '妈': 42,
 '去': 43,
 '大爷': 44,
 '的': 45,
 '骂': 46,
 '一': 47,
 '句': 48,
 '屌丝': 49,
 '鸡': 50,
 '高富帅': 51,
 '今天': 52,
 '生日': 53,
 '敢不敢': 54,
 '呵': 55,
 '女': 56,
 '?': 57,
 '怎么回事': 58,
 '天王': 59,
 '盖地虎': 60,
 '小通': 61,
 '监考': 62,
 '干么': 63,
 '哼': 64,
 '!': 65,
 '不想': 66,
 '就': 67,
 '不': 68,
 '和': 69,
 '玩': 70,
 '要': 71,
 '气死我': 72,
 '吗': 73,
 '坏蛋': 74,
 '恩': 75,
 '也': 76,
 '不能': 77,
 '生气': 78,
 '啦': 79,
 '行': 80,
 '谈': 81,
 '过': 82,
 '恋爱': 83,
 '什么': 84,
 '让': 85,
 '伤心': 86,
 '敢问': 87,
 '性别': 88,
 '小': 89,
 '受': 90,
 '干嘛': 91,
 '为什么': 92,
 '爱情': 93,
 '

In [14]:
id2word

{0: '<PAD>',
 1: '<UNK>',
 2: '<GO>',
 3: '<EOS>',
 4: '呵呵',
 5: '不是',
 6: '怎么',
 7: '了',
 8: '开心',
 9: '点',
 10: '哈',
 11: ',',
 12: '一切',
 13: '都会',
 14: '好',
 15: '起来',
 16: '我',
 17: '还',
 18: '喜欢',
 19: '她',
 20: '怎么办',
 21: '短信',
 22: '你',
 23: '知道',
 24: '谁',
 25: '么',
 26: '许兵',
 27: '是',
 28: '这么',
 29: '假',
 30: '傻',
 31: '逼',
 32: '到底',
 33: '尼玛',
 34: '小黄',
 35: '鸭',
 36: '有',
 37: '女朋友',
 38: '那',
 39: '男朋友',
 40: '在',
 41: '哪',
 42: '妈',
 43: '去',
 44: '大爷',
 45: '的',
 46: '骂',
 47: '一',
 48: '句',
 49: '屌丝',
 50: '鸡',
 51: '高富帅',
 52: '今天',
 53: '生日',
 54: '敢不敢',
 55: '呵',
 56: '女',
 57: '?',
 58: '怎么回事',
 59: '天王',
 60: '盖地虎',
 61: '小通',
 62: '监考',
 63: '干么',
 64: '哼',
 65: '!',
 66: '不想',
 67: '就',
 68: '不',
 69: '和',
 70: '玩',
 71: '要',
 72: '气死我',
 73: '吗',
 74: '坏蛋',
 75: '恩',
 76: '也',
 77: '不能',
 78: '生气',
 79: '啦',
 80: '行',
 81: '谈',
 82: '过',
 83: '恋爱',
 84: '什么',
 85: '让',
 86: '伤心',
 87: '敢问',
 88: '性别',
 89: '小',
 90: '受',
 91: '干嘛',
 92: '为什么',
 93: '爱情',
 9

In [58]:
num_sample = 20000
source_data = read_data("./datasets/ch_source_data_seg.txt")[:num_sample]
target_data = read_data("./datasets/ch_target_data_seg.txt")[:num_sample]

source_data_ids = process_data_index(source_data,word2id)
target_data_ids = process_data_index(target_data,word2id)

In [62]:
source_data_ids

[[4],
 [5],
 [6, 7],
 [8, 9, 10, 11, 12, 13, 14, 15],
 [16, 17, 18, 19, 11, 20],
 [21],
 [22, 23, 24, 25],
 [26, 27, 24],
 [28, 29],
 [26, 27, 30, 31],
 [26, 27, 24],
 [26, 27, 24],
 [26, 27, 24],
 [26, 32, 27, 24],
 [33, 11, 26, 32, 27, 24],
 [34, 35, 11, 22, 36, 37, 25],
 [38, 22, 36, 39, 25],
 [38, 22, 40, 41],
 [22, 42, 27, 24],
 [43, 22, 44, 45],
 [22, 40, 46, 16, 47, 48],
 [22, 44, 45],
 [22, 27, 49, 50],
 [51],
 [22],
 [4],
 [52, 27, 24, 45, 53],
 [22, 54],
 [4],
 [55, 4],
 [22, 27, 56, 45, 7, 57, 58],
 [4],
 [59, 60],
 [61],
 [40, 62, 11, 22, 40, 63],
 [64, 65, 22, 66, 16, 16, 67, 68, 69, 22, 70],
 [22, 71, 72, 73, 57, 74],
 [75, 11, 76, 27, 65, 38, 16, 77, 78, 79, 65, 22, 78, 67, 80],
 [22, 81, 82, 83, 25],
 [84, 85, 22, 28, 86],
 [87, 22, 45, 88],
 [89, 90],
 [27, 73],
 [22, 91],
 [92],
 [22, 36, 93, 7],
 [38, 94, 36, 95, 96, 97],
 [38, 22, 98],
 [22, 99],
 [22, 100, 101, 102, 7],
 [40, 91],
 [103, 16, 91],
 [22, 104, 99, 105, 73],
 [106, 69, 37, 107],
 [22, 45, 108, 27, 24],

In [63]:
target_data_ids

[[27, 37846, 756, 45, 180],
 [38, 27, 84, 49272],
 [16, 6692, 82, 49273, 320, 16, 518],
 [526],
 [16, 438, 22, 328, 19, 49272, 15817, 254, 1764, 49272],
 [928, 180, 16, 76, 855],
 [2143, 5, 16, 49273, 27, 49274],
 [49275, 465, 3504, 89, 762],
 [9100, 9101, 76, 29, 49273, 68, 1715, 45, 2222, 111],
 [684, 22, 2699, 7, 180],
 [27, 16, 9572, 436, 45, 452, 45, 274, 111],
 [27, 49276, 45, 1460, 111],
 [8347],
 [219, 5550, 16, 49277, 1929, 6887, 6999, 27, 49278, 49277],
 [16, 144, 16, 376, 328, 22, 16, 27, 33674, 45, 49279, 2893, 49280, 2080],
 [8990, 27, 56, 45, 520, 49279],
 [277, 282, 49273, 16, 1391, 452, 440, 102, 10569, 16, 514, 180, 180],
 [16, 248],
 [16, 1251, 27, 851, 777, 293, 2252, 45, 452, 22, 111],
 [16, 43, 49273, 346, 22, 44, 45, 49279],
 [22, 416, 417, 31072, 518],
 [49281, 20183, 7, 180, 4123, 27, 19, 491, 1474, 45, 2501, 26132, 7],
 [928],
 [49282, 111],
 [7870],
 [22, 148],
 [27, 16, 22695, 566, 6615, 45, 53, 49279],
 [22, 892, 16, 67, 892, 49273, 748, 49273, 892, 731, 16,

In [64]:
print("word vocab test:",[id2word[i] for i in range(10)])
print("source test:",source_data[10])
print("source index:",source_data_ids[10])
print("target test:",target_data[10])
print("target index:",target_data_ids[10])

word vocab test: ['<PAD>', '<UNK>', '<GO>', '<EOS>', '呵呵', '不是', '怎么', '了', '开心', '点']
source test: ['许兵', '是', '谁']
source index: [26, 27, 24]
target test: ['是', '我', '善良', '可爱', '的', '主人', '的', '老公', '啊']
target index: [27, 16, 9572, 436, 45, 452, 45, 274, 111]


In [65]:
def process_input_data(source_data_ids,target_data_ids, vocab2id):
    source_inputs= []
    decode_inputs= []
    decode_outputs= []
    
    for source,target in zip(source_data_ids,target_data_ids):
        source_inputs.append([word2id["<GO>"]] + source + [word2id["<EOS>"]])
        decode_inputs.append([word2id["<GO>"]] + target)
        decode_outputs.append(target + [word2id["<EOS>"]])
    return source_inputs,decode_inputs,decode_outputs

In [66]:
source_input_ids,target_input_ids,target_output_ids = process_input_data(source_data_ids,target_data_ids,word2id)

In [67]:
print("encoder inputs: ", source_input_ids[:2])
print("decoder inputs: ", target_input_ids[:2])
print("decoder outputs: ", target_output_ids[:2])

encoder inputs:  [[2, 4, 3], [2, 5, 3]]
decoder inputs:  [[2, 27, 37846, 756, 45, 180], [2, 38, 27, 84, 49272]]
decoder outputs:  [[27, 37846, 756, 45, 180, 3], [38, 27, 84, 49272, 3]]


In [68]:
maxlen = 20
source_input_ids = keras.preprocessing.sequence.pad_sequences(source_input_ids, padding='post', maxlen=maxlen)
target_input_ids = keras.preprocessing.sequence.pad_sequences(target_input_ids, padding='post',  maxlen=maxlen)
target_output_ids = keras.preprocessing.sequence.pad_sequences(target_output_ids, padding='post',  maxlen=maxlen)

In [69]:
print("encoder inputs pad_sequence: ", source_input_ids[:2])
print("decoder inputs pad_sequence: ", target_input_ids[:2])
print("decoder outputs pad_sequence: ", target_output_ids[:2])

encoder inputs pad_sequence:  [[2 4 3 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [2 5 3 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]]
decoder inputs pad_sequence:  [[    2    27 37846   756    45   180     0     0     0     0     0     0
      0     0     0     0     0     0     0     0]
 [    2    38    27    84 49272     0     0     0     0     0     0     0
      0     0     0     0     0     0     0     0]]
decoder outputs pad_sequence:  [[   27 37846   756    45   180     3     0     0     0     0     0     0
      0     0     0     0     0     0     0     0]
 [   38    27    84 49272     3     0     0     0     0     0     0     0
      0     0     0     0     0     0     0     0]]


In [70]:
keras.backend.clear_session()
maxlen = 20
embedding_dim = 128
units= 256
vocab_size= len(word2id)

model = Seq2Seq(maxlen,embedding_dim,units,vocab_size)
model.summary()

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
encoder_input (InputLayer)      [(None, 20)]         0                                            
__________________________________________________________________________________________________
encoder (Encoder)               ((None, 20, 256), (N 9273600     encoder_input[0][0]              
__________________________________________________________________________________________________
decoder_input (InputLayer)      [(None, None)]       0                                            
__________________________________________________________________________________________________
decoder (Decoder)               ((None, None, 256),  9273600     encoder[0][0]                    
                                                                 decoder_input[0][0]          

In [72]:
epochs = 5
batch_size = 64
val_rate = 0.2

In [73]:
model.compile(loss=keras.losses.SparseCategoricalCrossentropy(), optimizer="adam")

In [74]:
model.fit([source_input_ids, target_input_ids], target_output_ids, 
          batch_size=batch_size, epochs=epochs, validation_split=val_rate)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<keras.callbacks.History at 0x1ded579e630>

In [80]:
model.save_weights("./outs/chatmodel_weights.h5")

### encode_infer

In [86]:
def encoder_infer(model):
    encoder_model = keras.Model(inputs= model.get_layer("encoder").input,
                               outputs = model.get_layer("encoder").output)
    return encoder_model

In [88]:
encoder_model = encoder_infer(model)
encoder_model.summary()

Model: "model_3"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
encoder_input (InputLayer)   [(None, 20)]              0         
_________________________________________________________________
encoder (Encoder)            ((None, 20, 256), (None,  9273600   
Total params: 9,273,600
Trainable params: 9,273,600
Non-trainable params: 0
_________________________________________________________________


### decoder_infer

In [89]:
encoder_model.get_layer("encoder").output[0]

<KerasTensor: shape=(None, 20, 256) dtype=float32 (created by layer 'encoder')>

In [90]:
encoder_model.get_layer("encoder").output[1]

<KerasTensor: shape=(None, 256) dtype=float32 (created by layer 'encoder')>

In [97]:
def decoder_infer(model,encoder_model):
    encoder_output = encoder_model.get_layer("encoder").output[0]
    maxlen, hidden_units = encoder_output.shape[1:]
    
    dec_input = model.get_layer('decoder_input').input
    enc_output = keras.Input(shape=(maxlen, hidden_units), name='enc_output')
    dec_inp_state = keras.Input(shape=(hidden_units,), name='dec_inp_state')
    
    
    decoder = model.get_layer('decoder')
    dec_outputs, dec_out_state= decoder(enc_output, dec_input, dec_inp_state)
    
    decoder_dense = model.get_layer('dense')
    dense_output = decoder_dense(dec_outputs)
    
    decoder_model = keras.Model(inputs=[enc_output, dec_input, dec_inp_state], 
                          outputs=[dense_output,dec_out_state])
    
    return decoder_model

In [98]:
decoder_model = decoder_infer(model, encoder_model)
decoder_model.summary()

Model: "model_4"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
enc_output (InputLayer)         [(None, 20, 256)]    0                                            
__________________________________________________________________________________________________
decoder_input (InputLayer)      [(None, None)]       0                                            
__________________________________________________________________________________________________
dec_inp_state (InputLayer)      [(None, 256)]        0                                            
__________________________________________________________________________________________________
decoder (Decoder)               ((None, None, 256),  9273600     enc_output[0][0]                 
                                                                 decoder_input[0][0]        