In [None]:
from __future__ import absolute_import, division, print_function, unicode_literals

import tensorflow as tf
from tensorflow import keras
from utils import load_data,get_word_index
import numpy as np

import tensorflow_addons as tfa
from tensorflow.keras.layers import Layer
from tensorflow.keras.regularizers import l2

In [None]:
# Load the training/test splits and the word -> integer-index vocabulary.
(train_data, train_labels), (test_data, test_labels) = load_data()
word_index = get_word_index()

# Shift every vocabulary index up by 3 so that ids 0-3 are free for the
# special tokens registered below.
word_index = {word: index + 3 for word, index in word_index.items()}
word_index.update({
    "<PAD>": 0,
    "<START>": 1,
    "<UNK>": 2,  # unknown
    "<UNUSED>": 3,
})

# Inverse mapping (integer index -> word), used to decode encoded sequences.
reverse_word_index = {index: word for word, index in word_index.items()}

def decode_review(text):
    """Turn a sequence of integer word ids back into a space-separated string.

    Each id is looked up in the module-level ``reverse_word_index``; ids that
    are not present there are rendered as '?'.
    """
    words = (reverse_word_index.get(token_id, '?') for token_id in text)
    return ' '.join(words)

In [None]:
# Both splits are padded identically: the <PAD> id is appended at the end of
# each sequence ('post'), with maxlen=256.
pad_options = dict(value=word_index["<PAD>"], padding='post', maxlen=256)

train_data = keras.preprocessing.sequence.pad_sequences(train_data, **pad_options)
test_data = keras.preprocessing.sequence.pad_sequences(test_data, **pad_options)

论文地址：https://arxiv.org/pdf/1803.01271.pdf  
<img src="./tcn.png" alt="TCN architecture" style="width:400px;height:300px;"> 
结构很简单

In [None]:
class dilatedConv(Layer):
    """Stack of three weight-normalized, causal, dilated 1-D convolutions.

    Args:
        filters: number of output filters of every convolution.
        kernel_size: kernel size of every convolution.
        dilated: indexable of (at least) three dilation rates; entries 0, 1
            and 2 are used by the first, second and third convolution.
        regularizer: kernel regularizer handed to every convolution
            (``None`` means no kernel regularization).
        **kwargs: forwarded to the Keras ``Layer`` constructor.
    """

    def __init__(self, filters, kernel_size, dilated, regularizer, **kwargs):
        # Initialise the base Layer before assigning attributes so that Keras'
        # attribute tracking is fully set up.
        super(dilatedConv, self).__init__(**kwargs)
        self.filters = filters
        self.kernel_size = kernel_size
        self.dilated = dilated
        # Use the regularizer supplied by the caller rather than a hard-coded one.
        self.regularizer = regularizer

    def _make_conv(self, dilation_rate):
        """Build one weight-normalized causal Conv1D with the given dilation rate."""
        return tfa.layers.WeightNormalization(
            keras.layers.Conv1D(self.filters, self.kernel_size,
                                dilation_rate=dilation_rate,
                                padding='causal',
                                kernel_regularizer=self.regularizer)
        )

    def build(self, input_shape):
        """Create the three convolution sub-layers, one per dilation rate."""
        self.conv1 = self._make_conv(self.dilated[0])
        self.conv2 = self._make_conv(self.dilated[1])
        self.conv3 = self._make_conv(self.dilated[2])
        super(dilatedConv, self).build(input_shape)

    def call(self, inputs, **kwargs):
        """Apply the three convolutions in sequence and return the last output."""
        # padding='causal' already left-pads the sequence to offset the
        # (kernel_size - 1) * dilation shrinkage, so no manual tf.pad is needed.
        dconv1 = self.conv1(inputs)
        dconv2 = self.conv2(dconv1)
        dconv3 = self.conv3(dconv2)
        return dconv3

def residual_block(filters, kernel_size, dilated, inputs, regularizer = l2(l=0.0), dropout_rate = 0.5):
    """Build a TCN-style residual block on top of ``inputs``.

    The main path is two stages of (dilatedConv -> ReLU -> Dropout); the
    shortcut path is a kernel-size-1 convolution of ``inputs``. The two paths
    are summed element-wise.

    Args:
        filters: number of filters for the dilated convolutions and the shortcut.
        kernel_size: kernel size of the dilated convolutions.
        dilated: dilation rates forwarded to each ``dilatedConv``.
        inputs: Keras tensor the block is applied to.
        regularizer: kernel regularizer forwarded to each ``dilatedConv``.
        dropout_rate: dropout rate used after each ReLU (defaults to 0.5).

    Returns:
        The Keras tensor holding the sum of the main path and the shortcut.
    """
    # Kernel-size-1 convolution so the shortcut has `filters` channels.
    shortcut = keras.layers.Conv1D(filters, 1)(inputs)

    x = inputs
    for _ in range(2):
        x = dilatedConv(filters, kernel_size, dilated, regularizer)(x)
        x = keras.layers.ReLU()(x)
        x = keras.layers.Dropout(dropout_rate)(x)

    # Merge with a Keras layer so the addition is a regular node of the model graph.
    return keras.layers.Add()([x, shortcut])

In [None]:
# Vocabulary size is the number of entries in word_index (special tokens included).
vocab_size = len(word_index)

# Integer id sequences of length 256 are mapped to 32-dimensional embeddings.
inputs = keras.layers.Input(shape=(256,), name='input')
emb = keras.layers.Embedding(input_dim=vocab_size, output_dim=32)(inputs)

# One residual block: 32 filters, kernel size 3, dilation rates 1, 2 and 4.
tcn = residual_block(filters=32, kernel_size=3, dilated=[1, 2, 4], inputs=emb)

# Global average pooling, then a small dense head ending in a sigmoid unit.
pooled = keras.layers.GlobalAveragePooling1D()(tcn)
dense = keras.layers.Dense(units=16, activation='relu')(pooled)
outputs = keras.layers.Dense(units=1, activation='sigmoid')(dense)

In [None]:
# The dataset is small, so the model overfits very easily.
optimizer = keras.optimizers.Adam(learning_rate = 0.01)
early_stopping = keras.callbacks.EarlyStopping(monitor='val_loss', patience=3)

model = keras.Model(inputs = [inputs], outputs = [outputs])

model.compile(optimizer=optimizer,
              loss='binary_crossentropy',
              metrics=['accuracy', keras.metrics.Precision()]
             )

# Callbacks are an argument of fit(), not compile(); passing them here is what
# actually activates early stopping.
# NOTE: with epochs=2 and patience=3 early stopping can never trigger; increase
# `epochs` for it to have an effect.
history = model.fit(train_data,
                    train_labels,
                    epochs=2,
                    batch_size=256,
                    shuffle=True,
                    validation_data=(test_data, test_labels),
                    callbacks=[early_stopping],
                    verbose=1)