In [1]:
%load_ext autoreload
%autoreload 2

import pandas as pd
import numpy as np
import tensorflow as tf
from tensor2tensor import models
from tensor2tensor import problems
from tensor2tensor.layers import common_layers
from tensor2tensor.utils import trainer_lib
from tensor2tensor.utils import t2t_model
from tensor2tensor.utils import registry
from tensor2tensor.utils import metrics

# `tfe` aliases the TF1 contrib eager module; it is used later for
# `tfe.restore_variables_on_create` when decoding from a checkpoint.
tfe = tf.contrib.eager
# Switch this kernel to eager execution. NOTE(review): TF1 requires this to run
# before any graph ops are created, so keep it in the first code cell.
tfe.enable_eager_execution()
# Shorthand for estimator mode keys (e.g. `Modes.EVAL` when building the model).
Modes = tf.estimator.ModeKeys


For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
If you depend on functionality not listed there, please file an issue.

INFO:tensorflow:Entry Point [tensor2tensor.envs.tic_tac_toe_env:TicTacToeEnv] registered with id [T2TEnv-TicTacToeEnv-v0]


### Register Problem

In [9]:
from tensor2tensor.data_generators import semantic_search

# Problem instance used by the data-generation cell. NOTE(review): presumably
# importing the module is what registers `semantic_search` with the t2t
# registry (looked up by name later) — confirm.
search_problem = semantic_search.SemanticSearch()

## Generating Data 

In [None]:
# Skip this cell if the data has already been generated
# (data_dir=/tf/t2t_data, tmp_dir=/tf/datagen).
search_problem.generate_data(tmp_dir='/tf/datagen', data_dir='/tf/t2t_data')

### Build and Train the Model

In [2]:
# Experiment configuration: t2t registry names plus training/eval budget.
# Filesystem paths are added to this dict in a later cell.
PARAMS = {
    'T2T_Problem': 'semantic_search',
    'T2T_Model': 'transformer',
    'T2T_HPARAMS': 'transformer_base_single_gpu',
    'train_steps': 10000000,
    'eval_steps': 1000,
    'keep_checkpoint_max': 3,
}

In [3]:
import tensorflow as tf
from tensor2tensor.utils import trainer_lib
from tensor2tensor.utils.trainer_lib import create_run_config, create_experiment, create_hparams
from tensor2tensor.utils import registry
from tensor2tensor import models, problems

hparams = create_hparams(PARAMS['T2T_HPARAMS'])

# NOTE(review): `tf.flags` appears to be the flags *module*, not the parsed
# `tf.flags.FLAGS` object, so these assignments only attach attributes to the
# module. Nothing in this notebook reads them (create_experiment receives the
# model and problem names explicitly). Kept as-is for parity.
FLAGS = tf.flags
FLAGS.problems = FLAGS.problem = PARAMS['T2T_Problem']
FLAGS.model = PARAMS['T2T_Model']
FLAGS.schedule = "train_and_evaluate"

# Training overrides on top of the base hparams set:
# batch size, warmup length and learning rate.
hparams.batch_size = 1024
hparams.learning_rate_warmup_steps = 400
hparams.learning_rate = .4

In [4]:
# Filesystem layout; every path is absolute under /tf.
PARAMS['TMP_DIR'] = '/tf/datagen/'
PARAMS['DATA_DIR'] = '/tf/t2t_data'
# NOTE(review): checkpoints live in the CoNaLa intent-to-code directory even
# though the problem is `semantic_search` — confirm this is intended.
PARAMS['TRAIN_DIR'] = '/tf/t2t_train/intent_to_code/conala/'
# Leading slash added: the previous 'tf/...' resolved relative to the kernel's
# working directory, unlike every other path in this cell.
PARAMS['OUTPUT_DIR'] = '/tf/t2t_train/semantic_search'

In [5]:
# Estimator run config: checkpoints go to TRAIN_DIR.
# NOTE(review): `hparams` is passed positionally as the first argument; in some
# tensor2tensor versions that parameter is `model_name` — confirm against the
# installed version.
RUN_CONFIG = create_run_config(
    hparams,
    model_dir=PARAMS['TRAIN_DIR'],
    # Honour the configured checkpoint retention; previously
    # PARAMS['keep_checkpoint_max'] was defined but never used, so the library
    # default applied.
    keep_checkpoint_max=PARAMS['keep_checkpoint_max'],
)

Instructions for updating:
When switching to tf.estimator.Estimator, use tf.estimator.RunConfig instead.
INFO:tensorflow:Configuring DataParallelism to replicate the model.
INFO:tensorflow:schedule=continuous_train_and_eval
INFO:tensorflow:worker_gpu=1
INFO:tensorflow:sync=False
INFO:tensorflow:datashard_devices: ['gpu:0']
INFO:tensorflow:caching_devices: None
INFO:tensorflow:ps_devices: ['gpu:0']


In [None]:
# Build the t2t experiment for this problem/model pair, then call its
# train_and_evaluate(). NOTE(review): create_experiment likely adds entries to
# `hparams`, so re-running this cell without re-creating hparams may fail.
exp_fn = create_experiment(
    run_config=RUN_CONFIG,
    hparams=hparams,
    model_name=PARAMS['T2T_Model'],
    problem_name=PARAMS['T2T_Problem'],
    data_dir=PARAMS['DATA_DIR'],
    train_steps=PARAMS['train_steps'],
    eval_steps=PARAMS['eval_steps'],
)
exp_fn.train_and_evaluate()

## Getting the Latest Checkpoint

In [6]:
# Newest checkpoint under TRAIN_DIR; the bare name on the last line displays it.
ckpt_path = tf.train.latest_checkpoint(PARAMS['TRAIN_DIR'])
ckpt_path

'/tf/t2t_train/intent_to_code/conala/model.ckpt-51000'

### Evaluating Translation

In [23]:
# Rebuild the model architecture in EVAL mode for interactive decoding; the
# weights are restored from the checkpoint inside `translate`.
# NOTE(review): `hparams` was created without a problem name, so problem
# hparams may only be attached by the create_experiment cell — if that cell is
# skipped on a fresh kernel, inference may fail. Confirm.
translate_model = registry.model(PARAMS['T2T_Model'])(
    hparams, Modes.EVAL
)

# Vocabulary encoders for the problem, loaded from DATA_DIR.
encoders = (
    problems.problem(PARAMS['T2T_Problem'])
    .feature_encoders(PARAMS['DATA_DIR'])
)

def encode(input_str, output_str=None):
    """Turn a raw input string into a features dict for ``model.infer``.

    Args:
        input_str: text to encode with ``encoders["inputs"]``.
        output_str: accepted for signature compatibility; unused.

    Returns:
        ``{"inputs": tensor}`` where the tensor has shape ``[1, length, 1]``
        (length includes the appended EOS id).
    """
    eos_id = 1
    ids_with_eos = encoders["inputs"].encode(input_str) + [eos_id]
    # Batch of one, with a trailing singleton axis: [1, length, 1].
    features = {"inputs": tf.reshape(ids_with_eos, [1, -1, 1])}
    return features

def decode(integers, encoder=None):
    """Convert model output ids back into a string.

    Args:
        integers: token ids, e.g. ``translate_model.infer(...)["outputs"]``;
            any array-like whose squeezed form is 1-D or a scalar.
        encoder: optional object with a ``decode(list_of_ids)`` method.
            Defaults to ``encoders["targets"]``: model outputs are target-side
            ids (for a shared vocabulary this is the same encoder as
            ``encoders["inputs"]``).

    Returns:
        The decoded string, truncated before the first EOS id (1), if any.
    """
    if encoder is None:
        encoder = encoders["targets"]
    # atleast_1d: a single-token output squeezes to a 0-d array, and list() of
    # a 0-d array raises TypeError. tolist() also yields plain Python ints.
    ids = np.atleast_1d(np.squeeze(integers)).tolist()
    eos_id = 1
    if eos_id in ids:
        ids = ids[:ids.index(eos_id)]
    return encoder.decode(ids)

# Restore and translate!
def translate(inputs, beam_size=4, alpha=0.6):
    """Decode one input string using the newest checkpoint in TRAIN_DIR.

    Args:
        inputs: raw input text.
        beam_size: beam width passed to ``translate_model.infer``.
        alpha: passed through to ``translate_model.infer`` (presumably the
            beam-search length penalty).

    Returns:
        The decoded prediction string.

    Raises:
        ValueError: if no checkpoint exists under ``PARAMS['TRAIN_DIR']``.
    """
    ckpt = tf.train.latest_checkpoint(PARAMS['TRAIN_DIR'])
    if ckpt is None:
        # Fail loudly instead of handing None to restore_variables_on_create.
        raise ValueError("No checkpoint found in %s" % PARAMS['TRAIN_DIR'])
    encoded_inputs = encode(inputs)
    # Variables created inside this context are initialised from `ckpt`.
    with tfe.restore_variables_on_create(ckpt):
        model_output = translate_model.infer(encoded_inputs,
                                             beam_size=beam_size,
                                             alpha=alpha)["outputs"]
    return decode(model_output)

In [14]:
# Load the CoNaLa test split and preview the first rows.
conala_df = pd.read_json("/tf/datagen/conala-test.json.prod")
conala_df.head(n=5)

Unnamed: 0,intent,intent_tokens,question_id,rewritten_intent,slot_map,snippet,snippet_tokens
0,How can I send a signal from a python program?,"[send, a, signal, `, signal.SIGUSR1, `, to, th...",15080500,send a signal `signal.SIGUSR1` to the current ...,{},"os.kill(os.getpid(), signal.SIGUSR1)","[os, ., kill, (, os, ., getpid, (, ), ,, signa..."
1,Decode Hex String in Python 3,"[decode, a, hex, string, '4a4b4c, ', to, UTF-8...",3283984,decode a hex string '4a4b4c' to UTF-8.,{},bytes.fromhex('4a4b4c').decode('utf-8'),"[bytes, ., fromhex, (, '4a4b4c', ), ., decode,..."
2,check if all elements in a list are identical,"[check, if, all, elements, in, list, `, myList...",3844801,check if all elements in list `myList` are ide...,{},all(x == myList[0] for x in myList),"[all, (, x, ==, myList, [, 0, ], for, x, in, m..."
3,Format string dynamically,"[format, number, of, spaces, between, strings,...",4302166,format number of spaces between strings `Pytho...,{},"print('%*s : %*s' % (20, 'Python', 20, 'Very G...","[print, (, '%*s#SPACE#:#SPACE#%*s', %, (, 20, ..."
4,How to convert a string from CP-1251 to UTF-8?,"[How, to, convert, a, string, from, CP-1251, t...",7555335,,,d.decode('cp1251').encode('utf8'),"[d, ., decode, (, 'cp1251', ), ., encode, (, '..."


In [25]:
# Inspect one example (row 1): the rewritten intent and its reference snippet.
intent, code = conala_df.iloc[1][['rewritten_intent', 'snippet']]
print(intent, code, sep='\n')

decode a hex string '4a4b4c' to UTF-8.
bytes.fromhex('4a4b4c').decode('utf-8')


In [None]:
# Translate every test intent, preferring the hand-rewritten intent when present.
# `rewritten_intent` is empty for some rows (row 4 in the preview above). Such
# values may come back from read_json as None or as NaN; NaN is truthy, so a
# plain truthiness check would pass NaN to `translate`. Require a non-empty str.
test_result = []
for row in conala_df.itertuples():
    rewritten = row.rewritten_intent
    intent = rewritten if isinstance(rewritten, str) and rewritten else row.intent
    code = translate(intent)
    test_result.append(code)

In [None]:
# Write the predicted translations, one per line, in `conala_df` row order.
# Explicit UTF-8: predictions may contain non-ASCII text, and the platform
# default encoding (e.g. ASCII under a POSIX locale) would fail on it.
# NOTE(review): a prediction containing '\n' would break line alignment with
# the intents file — confirm predictions are single-line.
with open('/tf/datagen/translation.txt', 'w', encoding='utf-8') as f:
    for prediction in test_result:
        f.write(prediction + '\n')

In [None]:
# Write one intent per line so test_conala.sh can decode them.
# Prefer the rewritten intent; a missing value may be None or NaN (NaN is
# truthy), so require a non-empty str before using it.
# Explicit UTF-8 so non-ASCII intents don't depend on the locale's encoding.
with open('/tf/datagen/test_intent.txt', 'w', encoding='utf-8') as f:
    for row in conala_df.itertuples():
        rewritten = row.rewritten_intent
        intent = rewritten if isinstance(rewritten, str) and rewritten else row.intent
        f.write(intent + '\n')