In [1]:
from tqdm import tqdm

In [2]:
import tensorflow as tf
import numpy as np

# ByteNet

In [3]:
def layer_norm(input, causal=False, name=None):
    '''
    Layer Normalization
    
    If the model is causal and using Convnet,
    normalize input only according to depth.
    '''
    with tf.variable_scope('layer_norm', name):
        if causal: # Sub Layer Normalization
            axis_depth = len(input.get_shape()) - 1
            mean, var = tf.nn.moments(input, [axis_depth], keep_dims=True)
            out = (input - mean) / tf.sqrt(var)
            return out
        else: # Layer Normalization
            axes = np.arange(len(input.get_shape()) - 1) + 1
            mean, var = tf.nn.moments(input, axes, keep_dims=True)
            out = (input - mean) / tf.sqrt(var)
            return out

def convolution(input, filter, padding, strides=None, dilation_rate=None, causal=False, name=None):
    '''
    Masked Convolution
    
    See PixelCNN
    '''
    with tf.variable_scope('masked_convolution', name):
        filter_shape = filter.get_shape().as_list()
        filter_len = np.prod(filter_shape[:-2])
        center = filter_len // 2
        if causal:
            mask = np.ones([filter_len] + filter_shape[-2:], dtype='float32')
            mask[center+1: ,: ,:] = 0.
            mask = mask.reshape(filter_shape)
            
            mask = tf.constant(mask, dtype='float32')
            filter = filter * mask


        ret = tf.nn.convolution(input, filter, padding=padding, strides=strides,
                                dilation_rate=dilation_rate, name=name)
        
    return ret


def res_block(input, filter_size=3, dilation_rate=None, causal=False, name='res_block'):
    '''
    Residual block
    
    For details, see Ch3.6(Fig 3. Left) of 'Neural Machine Translation in Linear Time(https://arxiv.org/abs/1610.10099)'.
    '''
    
    with tf.variable_scope(name):
        x = input
        
        # input dimension
        in_dim = input.get_shape().as_list()[-1]
    
        # normalization
        x = layer_norm(x, causal)
        x = tf.nn.relu(x)
        
        # reduce dimension
        w_shape = [1, in_dim, in_dim//2]
        w_stddev = np.sqrt(2./np.prod(w_shape[:-1])) # He's init
        w = tf.get_variable(shape=w_shape, initializer=tf.random_normal_initializer(stddev=w_stddev),
                            name='w1')
        x = tf.nn.convolution(x, w, padding='SAME')
        x = layer_norm(x, causal)
        x = tf.nn.relu(x)
        
        # 1xk conv dilated (with mask)
        w_shape = [filter_size, in_dim//2, in_dim//2]
        if causal:
            w_stddev = np.sqrt(2. / (np.prod(w_shape[1:-1]) * (filter_size//2 + 1)))
        else:
            w_stddev = np.sqrt(2./np.prod(w_shape[:-1])) # He's init
        w = tf.get_variable(shape=w_shape, initializer=tf.random_normal_initializer(stddev=w_stddev),
                            name='w2')
        x = convolution(x, w, padding='SAME', dilation_rate=dilation_rate, causal=causal)
        x = layer_norm(x, causal)
        x = tf.nn.relu(x)
        
        # dimension recover and residual connection
        w_shape = [1, in_dim//2, in_dim]
        w_stddev = np.sqrt(2./np.prod(w_shape[:-1])) # He's init
        w = tf.get_variable(shape=w_shape, initializer=tf.random_normal_initializer(stddev=w_stddev),
                            name='w3')
        x = tf.nn.convolution(x, w, padding='SAME')
        
        # residual connection
        x = x + input
        

    return x


def encoder(input, filter_size=3, num_block_sets=6):
    '''
    Encoder for Character-Level Machine Translation
    
    For details, see Ch6 of 'Neural Machine Translation in Linear Time(https://arxiv.org/abs/1610.10099)'.
    '''
    with tf.variable_scope('encoder'):
        x = input
        for i in range(num_block_sets):
            for j in [1,2,4,8,16]:
                x = res_block(x, filter_size=filter_size, dilation_rate=[j], name='res_block_%d_%d' % (i, j))
        
    return x

def decoder(input, filter_size=3, num_block_sets=6):
    '''
    Decoder for Character-Level Machine Translation
    
    For details, see Ch6 of 'Neural Machine Translation in Linear Time(https://arxiv.org/abs/1610.10099)'.
    '''
    with tf.variable_scope('decoder'):
        x = input
        for i in range(num_block_sets):
            for j in [1,2,4,8,16]:
                x = res_block(x, filter_size=filter_size, dilation_rate=[j],
                              causal=True, name='res_block_%d_%d' % (i, j))
        
    return x

In [4]:
class ByteNet(object):
    """
    ByteNet

    For details, see 'Neural Machine Translation in Linear Time(https://arxiv.org/abs/1610.10099)'.
    """ 

    def __init__(self, input_dim=254, input_max_len=150, latent_dim=200, num_block_sets=4):
        self.input_dim = input_dim
        self.input_max_len = input_max_len
        self.filter_size = 3

        self.latent_dim = latent_dim
        self.num_block_sets = num_block_sets

    def encoder(self, x):
        #
        # inputs
        #
        with tf.variable_scope('input'):
            # make embedding matrix for source and target
            emb_x = tf.get_variable(shape=[self.input_dim, self.latent_dim],
                                    initializer=tf.random_uniform_initializer(-1.0, 1.0),
                                    name='emb_x')

        #
        # encode graph ( atrous convolution )
        #

        # embed table lookup
        enc_emb = tf.nn.embedding_lookup(emb_x, x)
        enc = encoder(enc_emb, filter_size=self.filter_size, num_block_sets=self.num_block_sets)

        return enc

    def decoder(self, enc, y, p_keep_conv):
        #
        # inputs
        #
        with tf.variable_scope('input'):
            emb_y = tf.get_variable(shape=[self.input_dim, self.latent_dim], 
                                    initializer=tf.random_uniform_initializer(-1.0, 1.0),
                                    name='emb_y')
            y_src = tf.pad(y[:,:-1], [[0,0],[1,0]])

        #
        # decode graph ( causal convolution )
        #

        # loop dilated causal conv block
        dec_emb = tf.concat([enc, tf.nn.embedding_lookup(emb_y, y_src)], 2)
        dec = decoder(dec_emb, filter_size=self.filter_size, num_block_sets=self.num_block_sets)


        with tf.variable_scope('output'):
            # additional convolution and relu
            out = layer_norm(dec, causal=True)
            out = tf.nn.relu(out)
            out_dim = out.get_shape().as_list()[-1] # latent_dim * 2
            w_shape = [1, out_dim, out_dim]
            w_stddev = np.sqrt(2./np.prod(w_shape[:-1])) # He's init
            w = tf.get_variable(shape=w_shape, initializer=tf.random_normal_initializer(stddev=w_stddev),
                                name='w1')
            out = tf.nn.convolution(out, w, padding='SAME')

            # final fully convolution layer for softmax
            logits = layer_norm(out, causal=True)
            logits = tf.nn.relu(logits)

            logits = tf.nn.dropout(logits, p_keep_conv)

            w_shape = [1, out_dim, self.input_dim]
            w_stddev = np.sqrt(2./np.prod(w_shape[:-1])) # He's init
            w = tf.get_variable(shape=w_shape, initializer=tf.random_normal_initializer(stddev=w_stddev),
                                name='w2')
            logits = tf.nn.convolution(logits, w, padding='SAME')

        return logits

## Test

In [5]:
from preprocess import MAX_LEN
from batch import batch_iter

Using TensorFlow backend.


INFO:tensorflow:Train data loaded.(total data=486376, total batch=15199)
INFO:tensorflow:Train data loaded.(total data=486376, total batch=15199)


In [6]:
# hyperparameters
latent_dim = 100   # hidden layer dimension
num_block_sets = 2     # dilated blocks

In [7]:
p_keep_conv = tf.placeholder(tf.float32, [])

alpha1 = tf.constant(0.10, dtype=np.float32, name="a1")
alpha2 = tf.constant(0.10, dtype=np.float32, name="a2")
alpha3 = tf.constant(0.05, dtype=np.float32, name="a3")
in_u1 = tf.placeholder(tf.int32, [None, MAX_LEN], name="ull")
in_v1 = tf.placeholder(tf.int32, [None, MAX_LEN], name="vll")
in_u2 = tf.placeholder(tf.int32, [None, MAX_LEN], name="ulu")
in_v2 = tf.placeholder(tf.int32, [None, MAX_LEN], name="vlu")
in_u3 = tf.placeholder(tf.int32, [None, MAX_LEN], name="ulu")
in_v3 = tf.placeholder(tf.int32, [None, MAX_LEN], name="ulu")
labels_u1 = tf.placeholder(tf.int32, [None, MAX_LEN], name="lu1")
labels_v1 = tf.placeholder(tf.int32, [None, MAX_LEN], name="lv1")
labels_u2 = tf.placeholder(tf.int32, [None, MAX_LEN], name="lu2")
weights_ll = tf.placeholder(tf.float32, [None, ], name="wll")
weights_lu = tf.placeholder(tf.float32, [None, ], name="wlu")
weights_uu = tf.placeholder(tf.float32, [None, ], name="wuu")
cu1 = tf.placeholder(tf.float32, [None, ], name="CuLL")
cv1 = tf.placeholder(tf.float32, [None, ], name="CvLL")
cu2 = tf.placeholder(tf.float32, [None, ], name="CuLU")

labels_zero_1 = tf.placeholder(tf.int32, [None, MAX_LEN], name="l0_1")
labels_zero_2 = tf.placeholder(tf.int32, [None, MAX_LEN], name="l0_2")
labels_zero_3 = tf.placeholder(tf.int32, [None, MAX_LEN], name="l0_3")

In [8]:
with tf.variable_scope('model') as scope:
    model = ByteNet(latent_dim=latent_dim, num_block_sets=num_block_sets)
    enc_u1 = model.encoder(in_u1)
    logits_u1 = model.decoder(enc_u1, labels_u1, p_keep_conv)

In [9]:
with tf.variable_scope('model', reuse=True) as scope:
    enc_v1 = model.encoder(in_v1)
    enc_u2 = model.encoder(in_u2)
    enc_v2 = model.encoder(in_v2)
    enc_u3 = model.encoder(in_u3)
    enc_v3 = model.encoder(in_v3)    

In [10]:
with tf.variable_scope('model', reuse=True) as scope:
    logits_v1 = model.decoder(enc_v1, labels_v1, p_keep_conv)
    logits_u2 = model.decoder(enc_u2, labels_u2, p_keep_conv)
    
    scores_u1 = model.decoder(enc_u1, labels_zero_1, p_keep_conv)
    scores_v1 = model.decoder(enc_v1, labels_zero_1, p_keep_conv)
    scores_u2 = model.decoder(enc_u2, labels_zero_2, p_keep_conv)
    scores_v2 = model.decoder(enc_v2, labels_zero_2, p_keep_conv)
    scores_u3 = model.decoder(enc_u3, labels_zero_3, p_keep_conv)
    scores_v3 = model.decoder(enc_v3, labels_zero_3, p_keep_conv)

In [11]:
# vanilla loss
# cross entropy loss with logit and mask 
def vanilla_loss(logits, labels):
    loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits)
    loss = tf.identity(loss)
    loss *= tf.cast(tf.not_equal(labels, tf.zeros_like(labels)), loss.dtype)
    loss = tf.reduce_sum(loss, 1)

    return loss

In [12]:
# distance loss
def distance_loss(scores_u, scores_v):
    loss = tf.nn.softmax_cross_entropy_with_logits(logits=scores_u, labels=tf.nn.softmax(scores_v))
    loss = tf.reduce_sum(loss, 1)
    return loss

In [13]:
loss_function = tf.reduce_mean(cu1 * vanilla_loss(logits_u1, labels_u1))\
                    + tf.reduce_mean(cv1 * vanilla_loss(logits_v1, labels_v1))\
                    + tf.reduce_mean(cu2 * vanilla_loss(logits_u2, labels_u2))

In [14]:
loss_function += tf.reduce_mean(alpha1 * weights_ll * distance_loss(scores_u1, scores_v1))\
                    + tf.reduce_mean(alpha2 * weights_lu * distance_loss(scores_u2, scores_v2))\
                    + tf.reduce_mean(alpha3 * weights_uu * distance_loss(scores_u3, scores_v3))

In [15]:
optimizer = tf.train.AdamOptimizer(1e-3).minimize(loss_function)

In [16]:
sess = tf.Session()
sess.run(tf.global_variables_initializer())

In [17]:
num_epochs = 1
for epoch in range(num_epochs):
    print("======== EPOCH " + str(epoch + 1) + " ========")

    batches = batch_iter(batch_size=32)
    epoch_loss = 0
    
    cnt = 0
    for batch in tqdm(batches):

        u1, v1, lu1, lv1, u3, v3, u2, v2, lu2, w_ll, w_lu, w_uu, c_ull, c_vll, c_ulu = batch
        
        l0_1 = np.zeros(u1.shape)
        l0_2 = np.zeros(u2.shape)
        l0_3 = np.zeros(u3.shape)
        _, loss = sess.run([optimizer, loss_function],
                                feed_dict={in_u1: u1,
                                           in_v1: v1,
                                           in_u2: u2,
                                           in_v2: v2,
                                           in_u3: u3,
                                           in_v3: v3,
                                           labels_u1: lu1,
                                           labels_v1: lv1,
                                           labels_u2: lu2,
                                           weights_ll: w_ll,
                                           weights_lu: w_lu,
                                           weights_uu: w_uu,
                                           cu1: c_ull,
                                           cv1: c_vll,
                                           cu2: c_ulu,
                                           p_keep_conv: 0.9,
                                           labels_zero_1: l0_1,
                                           labels_zero_2: l0_2,
                                           labels_zero_3: l0_3})
        epoch_loss += loss
        print(cnt, loss, end='\r')
        cnt += 1
        
    print()
    print("Epoch_Loss", epoch_loss/cnt)

0it [00:00, ?it/s]



1it [00:11, 11.08s/it]

0 308.109

2it [00:12,  8.25s/it]

1 281.733

3it [00:14,  6.25s/it]

2 281.679

4it [00:15,  4.64s/it]

3 292.044

5it [00:15,  3.34s/it]

4 243.823

6it [00:16,  2.62s/it]

5 239.263

7it [00:16,  1.92s/it]

6 240.612

8it [00:17,  1.44s/it]

7 247.699

9it [00:17,  1.28s/it]

8 222.178

10it [00:19,  1.28s/it]

9 224.268

11it [00:20,  1.18s/it]

10 236.412

12it [00:21,  1.14s/it]

11 222.616

13it [00:22,  1.13s/it]

12 233.502

14it [00:22,  1.13it/s]

13 237.303

15it [00:23,  1.09it/s]

14 237.451

16it [00:23,  1.35it/s]

15 213.148

17it [00:25,  1.18it/s]

16 203.915

18it [00:25,  1.46it/s]

17 197.639

19it [00:25,  1.74it/s]

18 186.948

20it [00:26,  1.47it/s]

19 199.339

21it [00:26,  1.75it/s]

20 215.594

22it [00:27,  2.03it/s]

21 200.431

23it [00:27,  2.28it/s]

22 207.0

24it [00:27,  2.49it/s]

23 188.9

25it [00:28,  2.69it/s]

24 188.479

26it [00:29,  1.56it/s]

25 197.628

27it [00:29,  1.84it/s]

26 204.407

28it [00:30,  2.11it/s]

27 191.432

29it [00:30,  2.39it/s]

28 194.315

30it [00:30,  2.59it/s]

29 194.504

31it [00:30,  2.74it/s]

30 193.394

32it [00:31,  2.83it/s]

31 187.378

33it [00:32,  2.05it/s]

32 185.878

34it [00:32,  2.30it/s]

33 195.037

35it [00:32,  2.51it/s]

34 197.019

36it [00:33,  2.68it/s]

35 185.631

37it [00:33,  2.81it/s]

36 178.46

38it [00:33,  2.89it/s]

37 178.602

39it [00:34,  2.97it/s]

38 182.652

40it [00:34,  3.05it/s]

39 177.332

41it [00:34,  3.08it/s]

40 190.441

42it [00:34,  3.14it/s]

41 183.178

43it [00:35,  3.13it/s]

42 186.9

44it [00:35,  3.10it/s]

43 174.28

45it [00:35,  3.13it/s]

44 175.067

46it [00:36,  3.23it/s]

45 184.005

47it [00:36,  2.23it/s]

46 177.391

48it [00:37,  2.47it/s]

47 199.424

49it [00:37,  2.66it/s]

48 171.778

50it [00:37,  2.80it/s]

49 180.984

51it [00:38,  2.94it/s]

50 184.031

52it [00:38,  3.03it/s]

51 175.251

53it [00:38,  3.09it/s]

52 181.336

54it [00:39,  3.08it/s]

53 181.642

55it [00:39,  3.17it/s]

54 181.887

56it [00:39,  3.19it/s]

55 194.952

57it [00:40,  3.22it/s]

56 175.856

58it [00:40,  3.27it/s]

57 182.711

59it [00:40,  3.29it/s]

58 181.483

60it [00:40,  3.27it/s]

59 182.513

61it [00:41,  3.26it/s]

60 180.92

62it [00:41,  3.23it/s]

61 170.62

63it [00:41,  3.19it/s]

62 169.223

64it [00:42,  3.19it/s]

63 166.679

65it [00:42,  3.21it/s]

64 180.042

66it [00:42,  3.20it/s]

65 189.226

67it [00:43,  3.21it/s]

66 175.274

68it [00:43,  3.23it/s]

67 172.035

69it [00:43,  3.23it/s]

68 171.67

70it [00:44,  3.24it/s]

69 167.363

71it [00:44,  3.30it/s]

70 190.012

72it [00:44,  3.27it/s]

71 175.599

73it [00:44,  3.19it/s]

72 173.563

74it [00:45,  3.17it/s]

73 169.438

75it [00:45,  3.17it/s]

74 180.449

76it [00:45,  3.19it/s]

75 173.309

77it [00:46,  3.23it/s]

76 159.839

78it [00:46,  3.21it/s]

77 169.484

79it [00:46,  3.19it/s]

78 179.402

80it [00:47,  3.18it/s]

79 179.036

81it [00:47,  3.21it/s]

80 181.617

82it [00:47,  3.20it/s]

81 185.076

83it [00:48,  3.19it/s]

82 171.564

84it [00:48,  3.20it/s]

83 178.592

85it [00:48,  3.19it/s]

84 178.79

86it [00:49,  3.20it/s]

85 177.753

87it [00:49,  3.18it/s]

86 188.246

88it [00:49,  3.13it/s]

87 192.264

89it [00:50,  3.19it/s]

88 173.68

90it [00:50,  3.17it/s]

89 203.81

91it [00:50,  3.16it/s]

90 198.816

92it [00:50,  3.18it/s]

91 187.289

93it [00:51,  3.24it/s]

92 186.926

94it [00:51,  3.22it/s]

93 197.502

95it [00:51,  3.19it/s]

94 190.693

96it [00:52,  3.16it/s]

95 196.787

97it [00:52,  3.16it/s]

96 206.199

98it [00:52,  3.18it/s]

97 195.314

99it [00:53,  3.16it/s]

98 200.447

100it [00:54,  1.73it/s]

99 190.685

101it [00:54,  2.01it/s]

100 192.146

102it [00:54,  2.26it/s]

101 202.672

103it [00:55,  2.47it/s]

102 206.752

104it [00:55,  2.66it/s]

103 201.237

105it [00:55,  2.77it/s]

104 195.431

106it [00:56,  2.88it/s]

105 197.53

107it [00:56,  2.98it/s]

106 205.478

108it [00:56,  3.02it/s]

107 212.231

109it [00:57,  3.03it/s]

108 203.982

110it [00:57,  3.07it/s]

109 214.12

111it [00:57,  3.08it/s]

110 213.774

112it [00:58,  3.11it/s]

111 212.612

113it [00:58,  3.14it/s]

112 202.499

114it [00:58,  3.17it/s]

113 202.362

115it [00:59,  3.16it/s]

114 202.785

116it [00:59,  3.19it/s]

115 216.419

117it [00:59,  3.18it/s]

116 219.991

118it [01:00,  3.24it/s]

117 210.208

119it [01:00,  3.22it/s]

118 208.679

120it [01:00,  3.23it/s]

119 207.614

121it [01:00,  3.20it/s]

120 222.494

122it [01:01,  3.21it/s]

121 205.148

123it [01:01,  3.24it/s]

122 203.683

124it [01:01,  3.27it/s]

123 229.119

125it [01:02,  3.22it/s]

124 215.701

126it [01:02,  3.20it/s]

125 203.645

127it [01:02,  3.19it/s]

126 202.974

128it [01:03,  3.19it/s]

127 211.601

129it [01:03,  3.18it/s]

128 209.211

130it [01:03,  3.16it/s]

129 212.35

131it [01:04,  3.20it/s]

130 238.178

132it [01:04,  3.18it/s]

131 203.956

133it [01:04,  3.17it/s]

132 210.179

134it [01:05,  3.17it/s]

133 213.686

135it [01:05,  3.22it/s]

134 206.372

136it [01:05,  3.20it/s]

135 217.62

137it [01:05,  3.21it/s]

136 218.014

138it [01:06,  3.20it/s]

137 213.759

139it [01:06,  3.18it/s]

138 212.675

140it [01:07,  2.21it/s]

139 222.852

141it [01:07,  2.43it/s]

140 208.859

142it [01:07,  2.65it/s]

141 212.541

143it [01:08,  2.81it/s]

142 228.256

144it [01:08,  2.91it/s]

143 213.437

145it [01:08,  3.03it/s]

144 217.326

146it [01:09,  3.05it/s]

145 233.544

147it [01:09,  3.14it/s]

146 205.775

148it [01:09,  3.10it/s]

147 214.944

149it [01:10,  3.14it/s]

148 207.688

150it [01:10,  3.14it/s]

149 219.435

151it [01:10,  3.11it/s]

150 209.22

152it [01:11,  3.16it/s]

151 222.327

153it [01:11,  3.14it/s]

152 216.615

154it [01:11,  3.16it/s]

153 212.116

155it [01:12,  3.17it/s]

154 210.795

156it [01:12,  3.22it/s]

155 208.026

157it [01:12,  3.20it/s]

156 211.792

158it [01:12,  3.16it/s]

157 216.386

159it [01:13,  3.16it/s]

158 219.699

160it [01:13,  3.18it/s]

159 217.963

161it [01:13,  3.17it/s]

160 209.802

162it [01:14,  3.17it/s]

161 220.963

163it [01:14,  3.16it/s]

162 215.754

KeyboardInterrupt: 