In [1]:
import sys
import numpy as np
import re

In [2]:
# Add the parent directory to the module search path (presumably where the
# local test_function_generator module imported in the next cell lives).
# The path is relative, so it is resolved against the kernel's working directory.
sys.path.append('../')

In [3]:
from test_function_generator import GenerateFunctions

In [4]:
# Synthetic-function generator. 10000 is presumably the number of samples to
# generate (the training log reports 8000 train + 2000 validation samples) —
# TODO confirm against GenerateFunctions.
gf = GenerateFunctions(10000)

In [5]:
# data: source strings of generated functions, each carrying a ':test:' line
# in its docstring.
# targets: source strings of the unittest methods to be predicted.
# NOTE(review): the rest of the notebook zips the two lists, so they are assumed
# to be index-aligned — confirm against GenerateFunctions.
data = gf.generate_functions()
targets = gf.generate_test_functions()

In [6]:
# Preview the first three generated function sources.
data[:3]

['\n\ndef dvlhfjo(qtherctg, ozjzrwmzj, yjfjbf, uu, pliu = False):\n    """\n    Comment\n    :test: dvlhfjo(44, 8, 47, 88) = False\n    """\n    return pliu',
 '\n\ndef bhwodllq_blkgnbmixd_n(pewadkae, ipekjnrhze = False):\n    """\n    Comment\n    :test: bhwodllq_blkgnbmixd_n(24) is not True\n    """\n    return ipekjnrhze',
 '\n\ndef yib_(dhcyp, va, elatrzbn, dme = True):\n    """\n    Comment\n    :test: yib_(91, 18, 75) is True\n    """\n    return dme']

In [7]:
# Preview the first three target test methods.
targets[:3]

['\n\n    def test_dvlhfjo(self):\n        self.assertFalse(dvlhfjo(44, 8, 47, 88))',
 '\n\n    def test_bhwodllq_blkgnbmixd_n(self):\n        self.assertFalse(bhwodllq_blkgnbmixd_n(24))',
 '\n\n    def test_yib_(self):\n        self.assertTrue(yib_(91, 18, 75))']

In [8]:
def extract_intent(x):
    """Return the test description that follows the ``:test:`` marker in ``x``.

    The marker must appear after leading whitespace at the start of a line;
    everything after the marker (and the whitespace following it) up to the
    end of that line is returned.

    Args:
        x: Source text of a generated function whose docstring contains a
            ``:test:`` line.

    Returns:
        The text after ``:test:``, e.g. ``'yib_(91, 18, 75) is True'``.

    Raises:
        ValueError: If ``x`` contains no ``:test:`` line. Without this check
            the failure would surface as an opaque ``AttributeError`` on
            ``None``.
    """
    match = re.search(r'^\s+:test:\s+(.*)$', x, re.MULTILINE)
    if match is None:
        raise ValueError('no ":test:" line found in: {!r}'.format(x))
    return match.group(1)

In [9]:
# One intent string per generated function, in the same order as `data`.
intents = list(map(extract_intent, data))

In [10]:
# Preview the first three extracted intents.
intents[:3]

['dvlhfjo(44, 8, 47, 88) = False',
 'bhwodllq_blkgnbmixd_n(24) is not True',
 'yib_(91, 18, 75) is True']

In [11]:
# Wrap every target in a start ('$') and end ('¤') marker character so the
# decoder knows where a sequence begins and when to stop.
# The guard keeps this cell idempotent: `targets` is rebound from itself, so
# re-running the cell would otherwise wrap the markers again and
# change the target vocabulary and the maximum sequence length.
targets = [
    x if x.startswith('$') and x.endswith('¤') else '$' + x + '¤'
    for x in targets
]

### Vectorize data

In [12]:
# Character vocabularies: every distinct character seen on each side, sorted
# so that index assignment is deterministic.
input_characters = sorted({ch for text in intents for ch in text})
target_characters = sorted({ch for text in targets for ch in text})


In [13]:
# Vocabulary sizes; these become the one-hot depth of the encoder and decoder tensors.
num_encoder_tokens = len(input_characters)
num_decoder_tokens = len(target_characters)

In [14]:
# Longest sequence on each side, in characters; used as the time dimension of
# the one-hot tensors (shorter sequences stay zero-padded).
max_encoder_seq_length = max(map(len, intents))
max_decoder_seq_length = max(map(len, targets))

In [15]:
# Character -> one-hot position lookups, numbered in sorted vocabulary order.
input_token_index = dict(zip(input_characters, range(len(input_characters))))
target_token_index = dict(zip(target_characters, range(len(target_characters))))

In [16]:
# Zero-initialised one-hot tensors laid out as (sample, time step, character).
# decoder_target_data is filled one step ahead of decoder_input_data
# in the vectorisation loop.
encoder_input_data = np.zeros(
    shape=(len(intents), max_encoder_seq_length, num_encoder_tokens),
    dtype=np.float32,
)
decoder_input_data = np.zeros(
    shape=(len(targets), max_decoder_seq_length, num_decoder_tokens),
    dtype=np.float32,
)
decoder_target_data = np.zeros(
    shape=(len(intents), max_decoder_seq_length, num_decoder_tokens),
    dtype=np.float32,
)

In [17]:
# One-hot encode every (intent, target) pair.
for i, (input_text, target_text) in enumerate(zip(intents, targets)):
    input_ids = np.array([input_token_index[ch] for ch in input_text], dtype=int)
    target_ids = np.array([target_token_index[ch] for ch in target_text], dtype=int)

    # Encoder input: one character per time step.
    encoder_input_data[i, np.arange(len(input_ids)), input_ids] = 1.
    # Decoder input: the full target sequence, start marker included.
    decoder_input_data[i, np.arange(len(target_ids)), target_ids] = 1.
    # Decoder target: the same sequence shifted one step to the left, so that
    # step t of the target holds the character the decoder sees at step t + 1.
    decoder_target_data[i, np.arange(len(target_ids) - 1), target_ids[1:]] = 1.

## Model

In [18]:
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, LSTM, Dense

In [19]:
# Number of units in the encoder and decoder LSTMs (size of the h and c state vectors).
latent_dim = 256

In [20]:
# Encoder: consumes a variable-length sequence of one-hot characters.
encoder_inputs = Input(shape=(None, num_encoder_tokens))
# return_state=True makes the layer return its final hidden and cell states
# alongside its output.
encoder = LSTM(latent_dim, return_state=True)
encoder_outputs, state_h, state_c = encoder(encoder_inputs)
# Only the final states are passed on to the decoder; encoder_outputs is not
# used anywhere else in this notebook.
encoder_states = [state_h, state_c]

In [21]:
# Decoder: consumes the one-hot target sequence and starts from the encoder's final states.
decoder_inputs = Input(shape=(None, num_decoder_tokens))
# return_sequences=True yields an output for every time step; the returned
# states are ignored here and only used by the inference model later.
decoder_lstm = LSTM(latent_dim, return_sequences=True, return_state=True)
decoder_outputs, _, _ = decoder_lstm(decoder_inputs, initial_state=encoder_states)
# Per-step softmax over the target character vocabulary.
decoder_dense = Dense(num_decoder_tokens, activation='softmax')
decoder_outputs = decoder_dense(decoder_outputs)

In [22]:
# Training model: (encoder input, decoder input) -> per-step character distribution.
model = Model([encoder_inputs, decoder_inputs], decoder_outputs)

In [23]:
# Categorical cross-entropy pairs with the softmax output and the one-hot decoder targets.
model.compile(optimizer='rmsprop', loss='categorical_crossentropy')
model.summary()

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            (None, None, 45)     0                                            
__________________________________________________________________________________________________
input_2 (InputLayer)            (None, None, 50)     0                                            
__________________________________________________________________________________________________
lstm (LSTM)                     [(None, 256), (None, 309248      input_1[0][0]                    
__________________________________________________________________________________________________
lstm_1 (LSTM)                   [(None, None, 256),  314368      input_2[0][0]                    
                                                                 lstm[0][1]                       
          

In [None]:
# Train on the one-hot data: mini-batches of 16 samples, up to 100 epochs,
# with 20% of the samples held out for validation.
model.fit([encoder_input_data, decoder_input_data], decoder_target_data, batch_size=16, epochs=100, validation_split=0.2)

Train on 8000 samples, validate on 2000 samples
Epoch 1/100
Epoch 2/100
Epoch 4/100
Epoch 8/100
Epoch 13/100

In [84]:
# Persist the trained model to an .h5 file; the path is relative to the working directory.
# NOTE(review): presumably requires ../models to exist already — confirm.
model.save('../models/s2s.h5')



## Inference

In [85]:
# Inference encoder: maps an input sequence to the encoder LSTM's final [h, c] states.
encoder_model = Model(encoder_inputs, encoder_states)

# At inference time the decoder is run one step at a time, so its states are
# fed in explicitly as inputs and returned as outputs.
decoder_state_input_h = Input(shape=(latent_dim,))
decoder_state_input_c = Input(shape=(latent_dim,))
decoder_states_inputs = [decoder_state_input_h, decoder_state_input_c]
# Reuses the decoder_lstm and decoder_dense layer objects from the training model.
# NOTE: this rebinds decoder_outputs, state_h and state_c, which earlier cells
# bound to tensors of the training graph; re-running those cells after this one
# without re-running the encoder/decoder cells would pick up these tensors instead.
decoder_outputs, state_h , state_c = decoder_lstm(decoder_inputs, initial_state=decoder_states_inputs)
decoder_states = [state_h, state_c]
decoder_outputs = decoder_dense(decoder_outputs)
# Inputs: [previous character, h, c]; outputs: [character distribution, new h, new c].
decoder_model = Model([decoder_inputs] + decoder_states_inputs, [decoder_outputs] + decoder_states)

In [86]:
# Inverse lookups (one-hot position -> character), used to turn predictions back into text.
reverse_input_char_index = dict(zip(input_token_index.values(), input_token_index.keys()))
reverse_target_char_index = dict(zip(target_token_index.values(), target_token_index.keys()))

In [87]:
# Inspect the target vocabulary (index -> character).
reverse_target_char_index

{0: '\n',
 1: ' ',
 2: '$',
 3: '(',
 4: ')',
 5: ',',
 6: '.',
 7: '0',
 8: '1',
 9: '2',
 10: '3',
 11: '4',
 12: '5',
 13: '6',
 14: '7',
 15: '8',
 16: '9',
 17: ':',
 18: 'A',
 19: 'D',
 20: 'E',
 21: 'F',
 22: 'N',
 23: 'R',
 24: 'S',
 25: 'T',
 26: '_',
 27: 'a',
 28: 'b',
 29: 'c',
 30: 'd',
 31: 'e',
 32: 'f',
 33: 'g',
 34: 'h',
 35: 'i',
 36: 'j',
 37: 'k',
 38: 'l',
 39: 'm',
 40: 'n',
 41: 'o',
 42: 'p',
 43: 'q',
 44: 'r',
 45: 's',
 46: 't',
 47: 'u',
 48: 'v',
 49: 'w',
 50: 'x',
 51: 'y',
 52: 'z',
 53: '¤'}

In [92]:
def decode_sequence(input_seq):
    """Greedily decode one encoded intent into target text.

    The encoder states for ``input_seq`` seed the decoder, which is then run
    one character at a time, always feeding back the most probable character,
    starting from the ``'$'`` start marker.

    Args:
        input_seq: One-hot encoded intent as passed to ``encoder_model``; the
            notebook passes a single-sample slice of ``encoder_input_data``.

    Returns:
        The generated characters as a string. Decoding stops once the end
        marker ``'¤'`` has been produced (it is included in the result) or
        once the result is longer than ``max_decoder_seq_length``.
    """
    # The encoder's final states are the decoder's initial states.
    states = encoder_model.predict(input_seq)

    # First decoder input: a one-hot start marker for a batch of one.
    decoder_input = np.zeros((1, 1, num_decoder_tokens))
    decoder_input[0, 0, target_token_index['$']] = 1.

    decoded_sentence = ''
    while True:
        token_probs, next_h, next_c = decoder_model.predict([decoder_input] + states)

        # Greedy choice: take the most probable character of the last step.
        best_index = np.argmax(token_probs[0, -1, :])
        best_char = reverse_target_char_index[best_index]
        decoded_sentence += best_char

        # Stop on the end marker, or when the output exceeds the longest
        # target seen during vectorisation.
        if best_char == '¤':
            break
        if len(decoded_sentence) > max_decoder_seq_length:
            break

        # Feed the chosen character and the updated states into the next step.
        decoder_input = np.zeros((1, 1, num_decoder_tokens))
        decoder_input[0, 0, best_index] = 1.
        states = [next_h, next_c]
    return decoded_sentence

In [104]:
# Decode the first 100 training samples and print each intent next to the model's output.
for seq_index in range(100):
    # Slice (rather than index) to keep the leading batch dimension of size 1.
    input_seq = encoder_input_data[seq_index: seq_index+1]
    decoded_sentence = decode_sequence(input_seq)
    print('-')
    print('Input sentence:', intents[seq_index])
    print('Decoded sentence:', decoded_sentence)

-
Input sentence: dmjcfy_cva(38, 63, 24, 7) = True
Decoded sentence: START


         tttt________________lf            ssssesertrtttte(((____________________,,,                seesertttttt_
-
Input sentence: ijhl__k_(86, 32) is not True
Decoded sentence: START


         tttt________________lf            ssssesertrtttte(((____________________,,,                seesertttttt_
-
Input sentence: mopfyk_er_lon_rltr_(63) = 62
Decoded sentence: START


         tttt________________lf            ssssesertrtttte(((____________________,,,                seesertttttt_
-
Input sentence: rj___f_fwxpcz_hb_k(65, 10, 61, 84, 86) = 27
Decoded sentence: START


         tttt________________lf            ssssesertrtttte(((____________________,,,                seesertttttt_
-
Input sentence: gnybehjvbjnvok__(79, 100, 98, 72, 83) = True
Decoded sentence: START


         tttt________________lf            ssssesertrtttte(((____________________,,,                seesertttttt_
-
Input sentence: ht_bgc_zph_m

In [102]:
# Ad-hoc check: shape of a single-sample slice as passed to decode_sequence.
np.shape(encoder_input_data[0:1])

(1, 47, 45)

In [103]:
# Ad-hoc check: the first intent currently held by the kernel.
intents[0]

'dmjcfy_cva(38, 63, 24, 7) = True'