In [None]:
from __future__ import absolute_import, division

import logging
import os
import pickle
import sys

import tensorflow as tf
import tensorflow.contrib.eager as tfe
import torch
# NOTE(review): 0 equals logging.NOTSET, so this does not silence TensorFlow's logger;
# its effective level is inherited from the parent logger instead. Together with the
# DEBUG-level logging.basicConfig further down, TF messages may still be printed.
# Confirm the intent (e.g. tf.logging.ERROR if the goal is to hide TF logs).
tf.logging.set_verbosity(0)

In [None]:
# Switch TensorFlow 1.x into eager mode. It must run before any other TF op is created.
# The rest of the notebook relies on eager tensors (e.g. `next(iter(dataset))` in batcher).
tf.enable_eager_execution()

In [None]:
# Locations, relative to this notebook, of the SentEval checkout and its task data.
PATH_TO_SENTEVAL = '../'
PATH_TO_DATA = './data_evaluation'
# Put the SentEval checkout at the front of the import path so `import senteval` resolves to it.
sys.path[0:0] = [PATH_TO_SENTEVAL]

In [None]:
# import SentEval
import senteval

In [None]:
class Embedding(tf.keras.Model):
    """Trainable word-embedding table that maps integer word ids to dense vectors.

    Args:
        V: vocabulary size (number of rows of the table).
        d: embedding dimension (number of columns of the table).
    """

    def __init__(self, V, d):
        super(Embedding, self).__init__()
        # Uniform(-1, 1) initial values. The attribute is still called `W` because
        # object-based checkpoints name their dependencies after attribute names.
        initial_table = tf.random_uniform(shape=[V, d], minval=-1.0, maxval=1.0)
        self.W = tfe.Variable(initial_table)

    def call(self, word_indexes):
        """Return the rows of `W` selected by `word_indexes`, cast to float32."""
        looked_up = tf.nn.embedding_lookup(self.W, word_indexes)
        return tf.cast(looked_up, tf.float32)

In [None]:
class StaticRNN(tf.keras.Model):
    """Unrolled (static) recurrent layer over a padded batch of word vectors.

    Args:
        h: number of hidden units of the recurrent cell.
        cell: 'lstm' selects BasicLSTMCell, 'gru' selects GRUCell, and any other
            value falls back to BasicRNNCell.
    """

    def __init__(self, h, cell):
        super(StaticRNN, self).__init__()
        if cell == 'lstm':
            cell_cls = tf.nn.rnn_cell.BasicLSTMCell
        elif cell == 'gru':
            cell_cls = tf.nn.rnn_cell.GRUCell
        else:
            cell_cls = tf.nn.rnn_cell.BasicRNNCell
        # Stored as `self.cell` so existing checkpoints keep matching.
        self.cell = cell_cls(num_units=h)

    def call(self, state, word_vectors, num_words):
        """Run the cell over every step of `word_vectors`.

        Args:
            state: initial cell state, or None to let static_rnn build one of dtype float32.
            word_vectors: tensor split along axis 1 into one input per time step.
                Assumes (batch, time, dim); TODO confirm with callers.
            num_words: per-example sequence lengths forwarded to static_rnn.

        Returns:
            The (outputs, final_state) pair from tf.nn.static_rnn. `outputs` is a
            list with one tensor per time step.
        """
        per_step_inputs = tf.unstack(word_vectors, axis=1)
        return tf.nn.static_rnn(
            cell=self.cell,
            inputs=per_step_inputs,
            initial_state=state,
            sequence_length=num_words,
            dtype=tf.float32,
        )

In [None]:
class Encoder(tf.keras.Model):
    """Sentence encoder: a word embedding followed by a static RNN.

    Args:
        V: vocabulary size.
        d: word-embedding dimension.
        h: number of RNN hidden units.
        cell: RNN cell type, forwarded to StaticRNN ('lstm', 'gru', otherwise a basic RNN).
    """

    def __init__(self, V, d, h, cell):
        super(Encoder, self).__init__()
        self.word_embedding = Embedding(V, d)
        self.rnn = StaticRNN(h, cell)

    def call(self, word_vector, word_length):
        """Encode a padded batch of word ids into one vector per sentence.

        Args:
            word_vector: padded word ids passed to the embedding. Presumably
                (batch, time), as produced by create_dataset; TODO confirm.
            word_length: 1-D tensor holding the true length of each batch row.

        Returns:
            A tuple (sentence_vectors, final_state, word_embedding):
            sentence_vectors[i] is the RNN output at the last real time step of
            row i; final_state is the RNN final state; word_embedding is the
            Embedding layer.
        """
        word_vectors = self.word_embedding(word_vector)
        rnn_outputs_time, final_state = self.rnn(None, word_vectors, word_length)

        # Select every row's last-step output with a single gather, rather than a
        # Python loop that issues one eager op and host sync per sentence.
        outputs = tf.stack(rnn_outputs_time, axis=1)  # (batch, time, h)
        # Clamp length 0 to step 0. static_rnn emits zeros past sequence_length, so
        # this gives the same zero vector that the old negative-index wrap did.
        last_step = tf.maximum(tf.cast(word_length, tf.int32) - 1, 0)
        batch_index = tf.range(tf.shape(last_step)[0])
        sentence_vectors = tf.gather_nd(outputs, tf.stack([batch_index, last_step], axis=1))
        return sentence_vectors, final_state, self.word_embedding

In [None]:
def create_dataset(sentences, vocab_table, batch_size, field_delimiter=','):
    """Build a padded, batched dataset of word ids from raw sentence strings.

    Each sentence is cut at the first `field_delimiter` (if given), split on
    spaces, and mapped to ids through `vocab_table`.

    Args:
        sentences: list of sentence strings.
        vocab_table: lookup table mapping word strings to integer ids.
        batch_size: number of sentences per padded batch.
        field_delimiter: only the text before the first occurrence of this
            delimiter is encoded. The default ',' keeps the original behaviour;
            pass None to encode the whole sentence. (tf.string_split treats
            every character of the delimiter as a separator.)

    Returns:
        A tf.data.Dataset yielding (word_ids, num_words) pairs, where word_ids
        is padded to the longest sentence of each batch.
    """
    def _to_ids(sentence):
        # Tokenise and look up once, then reuse the ids for both tuple fields.
        if field_delimiter is not None:
            sentence = tf.string_split([sentence], field_delimiter).values[0]
        word_ids = vocab_table.lookup(tf.string_split([sentence]).values)
        return word_ids, tf.size(word_ids)

    dataset = tf.data.Dataset.from_tensor_slices(tf.convert_to_tensor(sentences))
    dataset = dataset.map(_to_ids)
    return dataset.padded_batch(batch_size=batch_size, padded_shapes=([None], []))

In [None]:
def prepare(params, samples):
    """SentEval `prepare` hook; this encoder needs no per-task setup.

    Args:
        params: SentEval params dict (unused).
        samples: sentences of the current task (unused).

    Returns:
        None.
    """
    return None

In [None]:
def batcher(params, batch):
    """SentEval batcher: encode a batch of tokenised sentences.

    Args:
        params: SentEval params dict. Reads params['vocab'] (the word-id lookup
            table) and params['encoder'] (the Encoder model).
        batch: list of sentences, each a list of word tokens.

    Returns:
        A numpy array with one embedding row per sentence in `batch`.
    """
    sentences = []
    for tokens in batch:
        # create_dataset keeps only the text before the first ',', so replace
        # commas with spaces to encode the whole sentence, not just its first clause.
        text = ' '.join(tokens).replace(',', ' ')
        # Empty (or comma-only) sentences become '.' so every row has a token.
        sentences.append(text if text.strip() else '.')
    # Put the whole SentEval batch into a single padded batch. A smaller batch
    # size would silently drop every sentence after the first batch.
    dataset = create_dataset(sentences, params['vocab'], len(sentences))
    word_ids, num_words = next(iter(dataset))
    embeddings, _, _ = params['encoder'](word_ids, num_words)
    # SentEval consumes numpy arrays.
    return embeddings.numpy()

In [None]:
# SentEval configuration: task data location, evaluation protocol and probe classifier.
# NOTE(review): SentEval's PyTorch classifier moves its model and data to CUDA, so it is
# enabled only when a GPU is visible; otherwise SentEval falls back to scikit-learn.
params_senteval = {'task_path': PATH_TO_DATA, 'usepytorch': torch.cuda.is_available(), 'kfold': 5}
params_senteval['classifier'] = {'nhid': 0, 'optim': 'rmsprop', 'batch_size': 128,
                                 'tenacity': 3, 'epoch_size': 2}
# Set up logger
logging.basicConfig(format='%(asctime)s : %(message)s', level=logging.DEBUG)

In [None]:
if __name__ == "__main__":
    # Encoder hyper-parameters. Presumably these must match the values the
    # checkpoint in `checkpoint_dir` was trained with, or variable shapes will differ.
    embedding_dim = 256
    hidden_units = 512
    cell_type = 'gru'

    # Load the English vocabulary; out-of-vocabulary words map to id 0.
    from tensorflow.python.ops import lookup_ops
    english_vocab_file = './data/english_vocab.txt'
    english_vocab_table = lookup_ops.index_table_from_file(english_vocab_file, default_value=0)
    params_senteval['vocab'] = english_vocab_table

    # Rebuild the encoder (and the optimizer it was checkpointed with), then load the latest weights.
    opt = tf.train.AdamOptimizer(learning_rate=0.002)
    encoder_nmt = Encoder(english_vocab_table.size(), embedding_dim, hidden_units, cell_type)
    checkpoint_dir = './encoder_nmt'
    latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
    if latest_checkpoint is None:
        # restore(None) would silently keep the random initial weights and
        # produce meaningless evaluation numbers.
        raise IOError('No checkpoint found in %r' % checkpoint_dir)
    root = tfe.Checkpoint(optimizer=opt, model=encoder_nmt, optimizer_step=tf.train.get_or_create_global_step())
    root.restore(latest_checkpoint)
    params_senteval['encoder'] = encoder_nmt

    # Run the SentEval evaluation tasks.
    se = senteval.engine.SE(params_senteval, batcher, prepare)
    transfer_tasks = ['STS12', 'STS13', 'STS14', 'STS15', 'STS16', 'TREC']
    results = se.eval(transfer_tasks)
    print(results)

In [None]:
# Diagnostic: report whether PyTorch can see a CUDA device.
# NOTE(review): presumably checks the GPU that SentEval's PyTorch classifier
# ('usepytorch' in params_senteval) needs; confirm.
import torch
torch.cuda.is_available()