In [1]:
import numpy as np
import tensorflow as tf
import logging

In [2]:
from transformer_modules import *
from utils import *

In [3]:
np.random.seed(20)

In [4]:
#building a toy dataset, where output sequence is the same as input
N_ = 1000  #dataset size
n_seq = 15  #sizes of sequences
max_val = 10  #number of distinct 'characters', input and output will be numbers from 0 to max_val
inp_data = np.random.randint(low=0, high=max_val, size = (N_,n_seq))  #random sequences as inputs
out_data = np.concatenate([np.full([N_, 1], np.max(inp_data)+1), inp_data], axis = 1)  #outputs should be shifted by 1 character to the right, 
                                                                                       #insert a dummy character in the beginning

data = (inp_data, out_data)
batch_size=1
batch_iter = gen_batch(data, batch_size=batch_size)

In [5]:
#helper function for decoder output placeholder feeding
def get_labels(dec_inp_batch):
    '''decoder output is the same as the decoder input shifted by one position.
    The input argument should be given already the shifted sequence, i.e. the output ground truth'''
    lbls = np.zeros([*dec_inp_batch.shape, max_val+1])
    for i in range(lbls.shape[0]):
        for j in range(lbls.shape[1]):
            lbls[i,j,dec_inp_batch[i,j]] = 1
    return lbls

In [6]:
#input and output placeholders
enc_inp_seq = tf.placeholder(shape=(batch_size, n_seq), dtype=tf.int32, name='enc_inp_plhd') 
dec_inp_seq = tf.placeholder(shape=(batch_size, n_seq), dtype=tf.int32, name='dec_inp_plhd') 
out_lbls = tf.placeholder(shape=(batch_size, n_seq, max_val+1), dtype=tf.int32, name='out_plhd') 

In [7]:
d_model = 50 #size of the embedding vectors

In [8]:
#declare the embeddings
inp_embs = WordEmbeddings(d_model=d_model, d_vocab=max_val, name='inp_embs')
out_embs = WordEmbeddings(d_model=d_model, d_vocab=max_val+1, name='out_embs')

#declare the tensors to be fed to model as encoder and decoder inputs
tsf_enc_inp = inp_embs(enc_inp_seq)
tsf_dec_inp = out_embs(dec_inp_seq)

In [9]:
#mask for keeping the decoder from having access to it's own subsequent entries
mask = subsequent_mask(n_seq)
mask = tf.expand_dims(tf.constant(mask), axis=0)

In [10]:
#placeholder for decoder mask
dec_mask = tf.placeholder_with_default(input=mask, shape = mask.get_shape().as_list())

In [11]:
#setting logger level to 5 to skip the debugging messages
logger.setLevel(5)

In [12]:
#Defining the model
model = TransformerModel(d_inp_vocab=max_val, d_out_vocab=max_val+1, d_model=d_model, n_blocks=1, n_heads=5, d_ff=100, dropout=0)
out = model(tsf_enc_inp, tsf_dec_inp, tgt_mask=dec_mask)

In [13]:
#defining the loss and the train step
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=out_lbls, logits=out, dim=-1))
global_step = tf.Variable(0, trainable=False)
train_step = tf.train.AdamOptimizer(learning_rate=0.003, name = 'adam').minimize(loss, global_step=global_step)

In [14]:
#defining an op to track the accuracy
correct_preds = tf.equal(tf.argmax(out, axis=-1), tf.argmax(out_lbls, axis=-1))
accuracy = tf.reduce_mean(tf.cast(correct_preds, tf.float32))

In [15]:
config = tf.ConfigProto(device_count = {'GPU': 0}) #train on cpu
sess = tf.Session(config=config)
sess.run(tf.global_variables_initializer())

In [16]:
n_epochs = 3
for epoch in range(n_epochs):
    batch_iter = gen_batch(data, batch_size=batch_size)
    for i, next_batch in enumerate(batch_iter):
        if len(next_batch[0]) != 0:
            loss_i, _, acc_i = sess.run([loss, train_step, accuracy], feed_dict={enc_inp_seq: next_batch[0], dec_inp_seq:next_batch[1][:,:-1], out_lbls: get_labels(next_batch[1][:,1:])})
            if i % 500 == 0:
                print ('Step {}, loss value: {}, accuracy: {}'.format(i,loss_i, acc_i))

Step 0, loss value: 5.388062477111816, accuracy: 0.0
Step 500, loss value: 0.029122380539774895, accuracy: 1.0
Step 0, loss value: 0.0014398357598111033, accuracy: 1.0
Step 500, loss value: 0.0009113454143516719, accuracy: 1.0
Step 0, loss value: 0.0003404961316846311, accuracy: 1.0
Step 500, loss value: 0.00011632432870101184, accuracy: 1.0


In [17]:
#function for inference. Greedy
def predict(enc_inp):
    N_seq = 15
    dec_inp = np.zeros([1,N_seq])
    dec_inp[:,0]=10 #first decoder input should be the dummy character, the rest will be masked by _mask
    
    output = np.zeros(N_seq).astype(np.int8) #an array for the final output
    
    for i in range(N_seq):
        _mask = np.zeros([1,1,N_seq,N_seq])  
        _mask[:,:,:,:i+1] = 1 #keep access only to first i elements of the sequence, mask the rest
        pred_i = sess.run(tf.argmax(tf.nn.softmax(out, axis=-1), axis=-1), feed_dict={enc_inp_seq: enc_inp, dec_inp_seq:dec_inp, dec_mask:_mask}) #prediction at the i-th iteration
        dec_inp[:,i] = pred_i[0][i] #next decoder input should include the last predicition
        output[i] = pred_i[0][i] #i-th element of the prediction
        print ('At iteration {}, sequence prediction: {}, instance prediction: {}'.format(i, pred_i[0], output[i]))
    print ('\n')
    print ('Encoder input: {}'.format(enc_inp[0]))
    print ('Decoder output: {}'.format(output))
    print ('Prediction was correct: {}'.format(np.all(output == enc_inp[0])))
    return output.astype(np.int8)
    

In [18]:
rand_inp = np.random.randint(low=0, high=10, size=(1,15)) #some random sequence for input
rand_inp

array([[7, 7, 0, 8, 1, 2, 2, 3, 0, 2, 9, 7, 1, 6, 2]])

In [19]:
predict(rand_inp)

At iteration 0, sequence prediction: [7 7 7 8 1 7 7 3 1 2 7 8 1 7 7], instance prediction: 7
At iteration 1, sequence prediction: [7 7 0 8 1 2 2 3 0 2 9 8 1 6 7], instance prediction: 7
At iteration 2, sequence prediction: [0 7 0 8 1 2 2 3 0 2 0 0 1 6 2], instance prediction: 0
At iteration 3, sequence prediction: [0 7 0 8 1 2 2 3 0 2 0 7 1 6 2], instance prediction: 8
At iteration 4, sequence prediction: [0 7 0 8 1 2 2 3 0 2 0 0 1 6 2], instance prediction: 1
At iteration 5, sequence prediction: [0 7 0 8 1 2 2 3 0 2 9 7 1 6 2], instance prediction: 2
At iteration 6, sequence prediction: [0 7 0 8 1 2 2 3 0 2 9 7 1 6 2], instance prediction: 2
At iteration 7, sequence prediction: [0 7 0 8 1 2 2 3 0 2 9 7 1 6 2], instance prediction: 3
At iteration 8, sequence prediction: [0 7 0 8 1 2 2 3 0 2 9 7 1 6 2], instance prediction: 0
At iteration 9, sequence prediction: [0 7 0 8 1 2 2 3 0 2 9 7 1 6 2], instance prediction: 2
At iteration 10, sequence prediction: [1 7 0 8 1 2 2 3 0 2 9 7 1 6 2],

array([7, 7, 0, 8, 1, 2, 2, 3, 0, 2, 9, 7, 1, 6, 2], dtype=int8)