In [1]:
%load_ext autoreload
%autoreload 2

import pandas as pd
import numpy as np
import tensorflow as tf
from tensor2tensor import models
from tensor2tensor import problems
from tensor2tensor.layers import common_layers
from tensor2tensor.utils import trainer_lib
from tensor2tensor.utils import t2t_model
from tensor2tensor.utils import registry
from tensor2tensor.utils import metrics

# Short alias for the estimator mode keys (used later as Modes.EVAL).
Modes = tf.estimator.ModeKeys

# Switch TensorFlow to eager execution via the contrib eager module.
tfe = tf.contrib.eager
tfe.enable_eager_execution()


For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
If you depend on functionality not listed there, please file an issue.

INFO:tensorflow:Entry Point [tensor2tensor.envs.tic_tac_toe_env:TicTacToeEnv] registered with id [T2TEnv-TicTacToeEnv-v0]


### Register Problem

In [2]:
from tensor2tensor.data_generators import semantic_search
# Instantiate the problem class; its generate_data() is called in the next cell.
search_problem = semantic_search.SemanticSearch()

## Generating Data 

In [3]:
# Skip this cell if the data has already been generated.
search_problem.generate_data(
    data_dir='/tf/t2t_data',
    tmp_dir='/tf/datagen',
)

[('gs://conala/conala-mined.jsonl', 'conala-mined.jsonl'),
 ('gs://conala/conala-train.json', 'conala-train.json')]

### Building and Training the Model

In [2]:
# Central run configuration: registry names for problem / model / hparams set,
# plus the training-loop settings read by the cells below.
PARAMS = {
    'T2T_Problem': 'semantic_search',
    'T2T_Model': 'transformer',
    'T2T_HPARAMS': 'transformer_base_single_gpu',
    'train_steps': 10000000,
    'eval_steps': 1000,
    'keep_checkpoint_max': 3,
}

In [3]:
import tensorflow as tf
from tensor2tensor.utils import trainer_lib
from tensor2tensor.utils.trainer_lib import create_run_config, create_experiment, create_hparams
from tensor2tensor.utils import registry
from tensor2tensor import models, problems

# Start from the named hyperparameter set, then override batch size,
# learning rate and warmup steps for this run.
hparams = create_hparams(PARAMS['T2T_HPARAMS'])
hparams.batch_size = 1024
hparams.learning_rate = .4
hparams.learning_rate_warmup_steps = 400

# NOTE(review): FLAGS is bound to the `tf.flags` module itself, not to
# `tf.flags.FLAGS`, so the assignments below set plain module attributes.
# Confirm that this is what downstream code actually reads.
FLAGS = tf.flags
FLAGS.schedule = "train_and_evaluate"
FLAGS.model = PARAMS['T2T_Model']
FLAGS.problem = FLAGS.problems = PARAMS['T2T_Problem']

In [4]:
# Filesystem layout for this run. All locations are absolute paths under /tf.
PARAMS['TMP_DIR'] = '/tf/datagen/'    # raw downloads (tmp_dir of generate_data)
PARAMS['DATA_DIR'] = '/tf/t2t_data'   # generated data (data_dir of generate_data)
# Model directory: passed as model_dir and searched for checkpoints below.
PARAMS['TRAIN_DIR'] = '/tf/t2t_train/intent_to_code/conala/'
# Previously 'tf/t2t_train/semantic_search' (no leading slash), which resolved
# relative to the kernel's working directory instead of under /tf like the rest.
PARAMS['OUTPUT_DIR'] = '/tf/t2t_train/semantic_search'

In [5]:
# Build the run configuration with TRAIN_DIR as the model directory.
# PARAMS['keep_checkpoint_max'] was defined but never forwarded, so the
# library default was silently used; pass it through explicitly.
RUN_CONFIG = create_run_config(
    hparams,
    model_dir=PARAMS['TRAIN_DIR'],
    keep_checkpoint_max=PARAMS['keep_checkpoint_max'],
)

Instructions for updating:
When switching to tf.estimator.Estimator, use tf.estimator.RunConfig instead.
INFO:tensorflow:Configuring DataParallelism to replicate the model.
INFO:tensorflow:schedule=continuous_train_and_eval
INFO:tensorflow:worker_gpu=1
INFO:tensorflow:sync=False
INFO:tensorflow:datashard_devices: ['gpu:0']
INFO:tensorflow:caching_devices: None
INFO:tensorflow:ps_devices: ['gpu:0']


In [None]:
# Assemble the experiment from the run config, hyperparameters and the
# model / problem names configured in PARAMS, reading data from DATA_DIR.
# NOTE(review): despite the `_fn` suffix, the returned object is used directly
# (train_and_evaluate() is called on it), so it is not a factory function.
exp_fn = create_experiment(
        run_config=RUN_CONFIG,
        hparams=hparams,
        model_name=PARAMS['T2T_Model'],
        problem_name=PARAMS['T2T_Problem'],
        data_dir=PARAMS['DATA_DIR'],
        train_steps=PARAMS['train_steps'],
        eval_steps=PARAMS['eval_steps']
    )
# Starts training with periodic evaluation; with train_steps at 10,000,000 this
# is presumably meant to be interrupted manually — TODO confirm.
exp_fn.train_and_evaluate() 

## Getting latest checkpoint

In [6]:
# Look up the most recent checkpoint under TRAIN_DIR and display its path.
ckpt_path = tf.train.latest_checkpoint(
    PARAMS['TRAIN_DIR']
)
ckpt_path

'/tf/t2t_train/intent_to_code/conala/model.ckpt-51000'

### Evaluating Translation

In [23]:
# Look up the model class by its registered name and instantiate it in EVAL
# mode with the hyperparameters built earlier.
translate_model = registry.model(PARAMS['T2T_Model'])(hparams, Modes.EVAL)

# Per-feature text encoders for the problem, loaded from DATA_DIR; the
# "inputs" encoder is used by encode() / decode() below.
encoders = problems.problem(PARAMS['T2T_Problem']).feature_encoders(PARAMS['DATA_DIR'])

def encode(input_str, output_str=None):
    """Build the features dict for inference from a raw input string.

    Args:
        input_str: Text to encode with the "inputs" encoder.
        output_str: Accepted for signature compatibility; not used.

    Returns:
        A dict with a single "inputs" entry holding the encoded ids followed
        by the EOS id, reshaped to [1, -1, 1].
    """
    token_ids = encoders["inputs"].encode(input_str)
    token_ids = token_ids + [1]  # append the EOS id
    # Reshape to 3D ([1, length, 1]).
    return {"inputs": tf.reshape(token_ids, [1, -1, 1])}

def decode(integers):
    """Convert a sequence of ids back into a string.

    Args:
        integers: Array-like of ids. Size-1 dimensions are squeezed away, so
            the [1, length, 1]-style layout used for the inputs is accepted
            as well as a flat list.

    Returns:
        The text produced by the "inputs" encoder for the ids that precede
        the first EOS id (1), or for all ids when no EOS is present.
    """
    # np.squeeze turns a single-id sequence into a 0-d array, which cannot be
    # iterated; atleast_1d keeps it a one-element sequence.
    ids = np.atleast_1d(np.squeeze(integers)).tolist()
    if 1 in ids:
        ids = ids[:ids.index(1)]
    # Pass the plain list: squeezing again would collapse a one-token result
    # into a 0-d array that the encoder cannot iterate over.
    return encoders["inputs"].decode(ids)

# Restore the trained weights and translate a single input string.
def translate(inputs):
    """Run beam-search inference on `inputs` and return the decoded text.

    Variables are restored from the latest checkpoint found in
    PARAMS['TRAIN_DIR'] as they are created.
    """
    features = encode(inputs)
    latest_ckpt = tf.train.latest_checkpoint(PARAMS['TRAIN_DIR'])
    with tfe.restore_variables_on_create(latest_ckpt):
        predictions = translate_model.infer(features, beam_size=4, alpha=0.6)
        output_ids = predictions["outputs"]
        return decode(output_ids)

In [47]:
# Load the mined CoNaLa examples into a DataFrame and preview the first rows.
# NOTE(review): the file name contains ".jsonl" but it is read as a regular
# JSON document (no lines=True) — confirm the format of the ".prod" file.
conala_df = pd.read_json("/tf/datagen/conala-mined.jsonl.prod")
conala_df.head()

Unnamed: 0,id,intent,intent_tokens,parent_answer_post_id,prob,question_id,slot_map,snippet,snippet_tokens
0,34705205_34705233_0,Sort a nested list by two elements,"[Sort, a, nested, list, by, two, elements]",34705233,0.869,34705205,{},"sorted(l, key=lambda x: (-int(x[1]), x[0]))","[sorted, (, l, ,, key, =, lambda, x, :, (, -, ..."
1,13905936_13905946_0,converting integer to list in python,"[converting, integer, to, list, in, python]",13905946,0.85267,13905936,{},[int(x) for x in str(num)],"[[, int, (, x, ), for, x, in, str, (, num, ), ]]"
2,13837848_13838041_0,Converting byte string in unicode string,"[Converting, byte, string, in, unicode, string]",13838041,0.852143,13837848,{},c.decode('unicode_escape'),"[c, ., decode, (, 'unicode_escape', )]"
3,23490152_23490179_0,List of arguments with argparse,"[List, of, arguments, with, argparse]",23490179,0.850829,23490152,{},"parser.add_argument('-t', dest='table', help='...","[parser, ., add_argument, (, '-t', ,, dest, =,..."
4,2721782_2721807_0,How to convert a Date string to a DateTime obj...,"[How, to, convert, a, Date, string, to, a, Dat...",2721807,0.840372,2721782,{},"datetime.datetime.strptime(s, '%Y-%m-%dT%H:%M:...","[datetime, ., datetime, ., strptime, (, s, ,, ..."


In [25]:
# Show one (intent, snippet) pair from the loaded frame.
# The mined frame previewed by conala_df.head() in this notebook has no
# `rewritten_intent` column, so plain attribute access raises AttributeError
# on a fresh run. Use it when present and non-null, otherwise use `intent`.
example_row = conala_df.iloc[1]
intent = example_row.get('rewritten_intent')
if pd.isnull(intent):
    intent = example_row.intent
code = example_row.snippet
print(intent)
print(code)

decode a hex string '4a4b4c' to UTF-8.
bytes.fromhex('4a4b4c').decode('utf-8')


In [49]:
def _join_tokens(tokens):
    """Space-join tokens, writing a bare newline token as "#NEWLINE"."""
    return " ".join("#NEWLINE" if tok == "\n" else tok for tok in tokens)


# One single-line string per example, for both the snippet and the intent side.
code_rows = [_join_tokens(tokens) for tokens in conala_df.snippet_tokens]
intent_rows = [_join_tokens(tokens) for tokens in conala_df.intent_tokens]

In [50]:
# Write the space-joined snippet tokens (built from the dataset's
# snippet_tokens column), one example per line.
with open('/tf/datagen/conala-train-mined.code', 'w+') as f:
    f.writelines(snippet_line + '\n' for snippet_line in code_rows)

In [51]:
# Write the intents, one per line, so test_conala.sh can decode them.
# A generator is used instead of `for intent in ...` so this cell no longer
# overwrites the notebook-level `intent` variable set in an earlier cell.
with open('/tf/datagen/conala-train-mined.intent', 'w+') as f:
    f.writelines(intent_line + '\n' for intent_line in intent_rows)

In [52]:
# Sanity check: both lists should be the same length so that line i of the
# .code file corresponds to line i of the .intent file.
len(code_rows), len(intent_rows)

(593891, 593891)