## Process data

In [1]:
buggy_data = ['((x + y) >= (z - 1))',
              '(a && b)',
              '(c > 0)',
              'd',
              '(e > f)']
fixed_data = ['((x + y) > (z - 1))',
              '(a && !(b))',
              '(c > 1)',
              '!(d)',
              '(f > e)']

buggy_codes = [list(x) for x in buggy_data]
fixed_codes = [['<soc>']+list(x)+['<eoc>'] for x in fixed_data]

print("Buggy codes:-")
for x in buggy_codes:
    print(x)
print("====================")
print("Fixed codes:-")
for x in fixed_codes:
    print(x)
print("====================")

vocab = set([x for y in buggy_codes for x in y]+[x for y in fixed_codes for x in y])
token_int_map = dict([(token, i+1) for i, token in enumerate(sorted(vocab-{' '}))])
token_int_map[' '] = 0
int_token_map = dict((i, token) for token, i in token_int_map.items())

print(token_int_map)
print("=====")
print(int_token_map)
print("====================")

vocab_size = len(vocab)
max_buggy_len = max([len(txt) for txt in buggy_codes])
max_fixed_len = max([len(txt) for txt in fixed_codes])
num_dps = len(fixed_codes)

print('Number of data points:', num_dps)
print('Vocabulary size:', vocab_size)
print('Max length in buggy codes:', max_buggy_len)
print('Max length in fixed codes:', max_fixed_len)


import numpy as np


buggy_inputs =  np.zeros((num_dps, max_buggy_len, vocab_size), dtype='float32')
fixed_inputs =  np.zeros((num_dps, max_fixed_len, vocab_size), dtype='float32')
fixed_outputs = np.zeros((num_dps, max_fixed_len, vocab_size), dtype='float32')

buggy_inputs[:, :, 0] = 1.
fixed_inputs[:, :, 0] = 1.
fixed_outputs[:, :, 0] = 1.

for i, (buggy, fixed) in enumerate(zip(buggy_codes, fixed_codes)):
    for t, token in enumerate(buggy):
        buggy_inputs[i, t, 0] = 0.
        buggy_inputs[i, t, token_int_map[token]] = 1.
    for t, token in enumerate(fixed):
        int_value = token_int_map[token]
        fixed_inputs[i, t, 0] = 0.
        fixed_inputs[i, t, int_value] = 1.
        if t > 0:
            fixed_outputs[i, t-1, 0] = 0.
            fixed_outputs[i, t-1, int_value] = 1.

Buggy codes:-
['(', '(', 'x', ' ', '+', ' ', 'y', ')', ' ', '>', '=', ' ', '(', 'z', ' ', '-', ' ', '1', ')', ')']
['(', 'a', ' ', '&', '&', ' ', 'b', ')']
['(', 'c', ' ', '>', ' ', '0', ')']
['d']
['(', 'e', ' ', '>', ' ', 'f', ')']
Fixed codes:-
['<soc>', '(', '(', 'x', ' ', '+', ' ', 'y', ')', ' ', '>', ' ', '(', 'z', ' ', '-', ' ', '1', ')', ')', '<eoc>']
['<soc>', '(', 'a', ' ', '&', '&', ' ', '!', '(', 'b', ')', ')', '<eoc>']
['<soc>', '(', 'c', ' ', '>', ' ', '1', ')', '<eoc>']
['<soc>', '!', '(', 'd', ')', '<eoc>']
['<soc>', '(', 'f', ' ', '>', ' ', 'e', ')', '<eoc>']
{'!': 1, '&': 2, '(': 3, ')': 4, '+': 5, '-': 6, '0': 7, '1': 8, '<eoc>': 9, '<soc>': 10, '=': 11, '>': 12, 'a': 13, 'b': 14, 'c': 15, 'd': 16, 'e': 17, 'f': 18, 'x': 19, 'y': 20, 'z': 21, ' ': 0}
=====
{1: '!', 2: '&', 3: '(', 4: ')', 5: '+', 6: '-', 7: '0', 8: '1', 9: '<eoc>', 10: '<soc>', 11: '=', 12: '>', 13: 'a', 14: 'b', 15: 'c', 16: 'd', 17: 'e', 18: 'f', 19: 'x', 20: 'y', 21: 'z', 0: ' '}
Number of data po

## LSTM Encoder Decoder

In [1]:
from keras.models import Model
from keras.layers import Input, LSTM, Dense
import numpy as np


buggy_data = ['((x + y) >= (z - 1))',
              '(a && b)',
              '(c > 0)',
              'd',
              '(e > f)']
fixed_data = ['((x + y) > (z - 1))',
              '(a && !(b))',
              '(c > 1)',
              '!(d)',
              '(f > e)']


epochs = 100  # Number of epochs to train for.
latent_dim = 256  # Latent dimensionality of the encoding space.

# Vectorize the data.
buggy_codes = [list(x) for x in buggy_data]
fixed_codes = [['<soc>']+list(x)+['<eoc>'] for x in fixed_data]
vocab = set([x for y in buggy_codes for x in y] + [x for y in fixed_codes for x in y])

vocab = sorted(list(vocab))
vocab_size = len(vocab)
max_encoder_seq_length = max([len(x) for x in buggy_codes])
max_decoder_seq_length = max([len(x) for x in fixed_codes])
token_index = dict(
    [(char, i) for i, char in enumerate(vocab)])
reverse_token_index = dict(
    (i, char) for char, i in token_index.items())

num_dps = len(buggy_codes)

print('Number of samples:', num_dps)
print('Vocabulary size:', vocab_size)
print('Max sequence length for inputs:', max_encoder_seq_length)
print('Max sequence length for outputs:', max_decoder_seq_length)
print("====================")
print('Token-integer mapping:-')
print(token_index)
print(reverse_token_index)
print("====================")

encoder_input_data = np.zeros(
    (num_dps, max_encoder_seq_length, vocab_size),
    dtype='float32')
decoder_input_data = np.zeros(
    (num_dps, max_decoder_seq_length, vocab_size),
    dtype='float32')
decoder_target_data = np.zeros(
    (num_dps, max_decoder_seq_length, vocab_size),
    dtype='float32')

for i, (buggy_code, fixed_code) in enumerate(zip(buggy_codes, fixed_codes)):
    for t, char in enumerate(buggy_code):
        encoder_input_data[i, t, token_index[char]] = 1.
    encoder_input_data[i, t + 1:, token_index[' ']] = 1.
    for t, char in enumerate(fixed_code):
        # decoder_target_data is ahead of decoder_input_data by one timestep
        decoder_input_data[i, t, token_index[char]] = 1.
        if t > 0:
            # decoder_target_data will be ahead by one timestep
            # and will not include the start character.
            decoder_target_data[i, t - 1, token_index[char]] = 1.
    decoder_input_data[i, t + 1:, token_index[' ']] = 1.
    decoder_target_data[i, t:, token_index[' ']] = 1.


# Define an input sequence and process it.
encoder_inputs = Input(shape=(None, vocab_size))
encoder = LSTM(latent_dim, return_state=True)
encoder_outputs, state_h, state_c = encoder(encoder_inputs)
# We discard `encoder_outputs` and only keep the states.
encoder_states = [state_h, state_c]

# Set up the decoder, using `encoder_states` as initial state.
decoder_inputs = Input(shape=(None, vocab_size))
# We set up our decoder to return full output sequences,
# and to return internal states as well. We don't use the
# return states in the training model, but we will use them in inference.
decoder_lstm = LSTM(latent_dim, return_sequences=True, return_state=True)
decoder_outputs, _, _ = decoder_lstm(decoder_inputs,
                                     initial_state=encoder_states)
decoder_dense = Dense(vocab_size, activation='softmax')
decoder_outputs = decoder_dense(decoder_outputs)

# Define the model that will turn
# `encoder_input_data` & `decoder_input_data` into `decoder_target_data`
model = Model([encoder_inputs, decoder_inputs], decoder_outputs)

# Run training
model.compile(optimizer='rmsprop', loss='categorical_crossentropy',
              metrics=['accuracy'])
model.fit([encoder_input_data, decoder_input_data], decoder_target_data,
          epochs=epochs)
# Save model
# model.save('s2s.h5')

# Next: inference mode (sampling).
# Here's the drill:
# 1) encode input and retrieve initial decoder state
# 2) run one step of decoder with this initial state
# and a "start of sequence" token as target.
# Output will be the next target token
# 3) Repeat with the current target token and current states

# Define sampling models
encoder_model = Model(encoder_inputs, encoder_states)

decoder_state_input_h = Input(shape=(latent_dim,))
decoder_state_input_c = Input(shape=(latent_dim,))
decoder_states_inputs = [decoder_state_input_h, decoder_state_input_c]
decoder_outputs, state_h, state_c = decoder_lstm(
    decoder_inputs, initial_state=decoder_states_inputs)
decoder_states = [state_h, state_c]
decoder_outputs = decoder_dense(decoder_outputs)
decoder_model = Model(
    [decoder_inputs] + decoder_states_inputs,
    [decoder_outputs] + decoder_states)


def decode_sequence(input_seq):
    # Encode the input as state vectors.
    states_value = encoder_model.predict(input_seq)

    # Generate empty target sequence of length 1.
    target_seq = np.zeros((1, 1, vocab_size))
    # Populate the first character of target sequence with the start character.
    target_seq[0, 0, token_index['<soc>']] = 1.

    # Sampling loop for a batch of sequences
    # (to simplify, here we assume a batch of size 1).
    stop_condition = False
    decoded_sentence = ''
    while not stop_condition:
        output_tokens, h, c = decoder_model.predict(
            [target_seq] + states_value)

        # Sample a token
        sampled_token_index = np.argmax(output_tokens[0, -1, :])
        sampled_char = reverse_token_index[sampled_token_index]
        decoded_sentence += sampled_char

        # Exit condition: either hit max length
        # or find stop character.
        if (sampled_char == '<eoc>' or
           len(decoded_sentence) > max_decoder_seq_length):
            stop_condition = True

        # Update the target sequence (of length 1).
        target_seq = np.zeros((1, 1, vocab_size))
        target_seq[0, 0, sampled_token_index] = 1.

        # Update states
        states_value = [h, c]

    return decoded_sentence


for seq_index in range(5):
    # Take one sequence (part of the training set)
    # for trying out decoding.
    input_seq = encoder_input_data[seq_index: seq_index + 1]
    decoded_sentence = decode_sequence(input_seq)
    print('-')
    print('Input sentence:', buggy_data[seq_index])
    print('Decoded sentence:', decoded_sentence)

Using TensorFlow backend.
W0918 19:53:29.113529 140057484760832 deprecation_wrapper.py:119] From /home/aziz/anaconda3/envs/tf/lib/python3.7/site-packages/keras/backend/tensorflow_backend.py:74: The name tf.get_default_graph is deprecated. Please use tf.compat.v1.get_default_graph instead.

W0918 19:53:29.124272 140057484760832 deprecation_wrapper.py:119] From /home/aziz/anaconda3/envs/tf/lib/python3.7/site-packages/keras/backend/tensorflow_backend.py:517: The name tf.placeholder is deprecated. Please use tf.compat.v1.placeholder instead.

W0918 19:53:29.127001 140057484760832 deprecation_wrapper.py:119] From /home/aziz/anaconda3/envs/tf/lib/python3.7/site-packages/keras/backend/tensorflow_backend.py:4138: The name tf.random_uniform is deprecated. Please use tf.random.uniform instead.



Number of samples: 5
Vocabulary size: 22
Max sequence length for inputs: 20
Max sequence length for outputs: 21
Token-integer mapping:-
{' ': 0, '!': 1, '&': 2, '(': 3, ')': 4, '+': 5, '-': 6, '0': 7, '1': 8, '<eoc>': 9, '<soc>': 10, '=': 11, '>': 12, 'a': 13, 'b': 14, 'c': 15, 'd': 16, 'e': 17, 'f': 18, 'x': 19, 'y': 20, 'z': 21}
{0: ' ', 1: '!', 2: '&', 3: '(', 4: ')', 5: '+', 6: '-', 7: '0', 8: '1', 9: '<eoc>', 10: '<soc>', 11: '=', 12: '>', 13: 'a', 14: 'b', 15: 'c', 16: 'd', 17: 'e', 18: 'f', 19: 'x', 20: 'y', 21: 'z'}


W0918 19:53:29.568046 140057484760832 deprecation_wrapper.py:119] From /home/aziz/anaconda3/envs/tf/lib/python3.7/site-packages/keras/optimizers.py:790: The name tf.train.Optimizer is deprecated. Please use tf.compat.v1.train.Optimizer instead.

W0918 19:53:29.583596 140057484760832 deprecation_wrapper.py:119] From /home/aziz/anaconda3/envs/tf/lib/python3.7/site-packages/keras/backend/tensorflow_backend.py:3295: The name tf.log is deprecated. Please use tf.math.log instead.

W0918 19:53:29.674618 140057484760832 deprecation.py:323] From /home/aziz/anaconda3/envs/tf/lib/python3.7/site-packages/tensorflow/python/ops/math_grad.py:1250: add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where
W0918 19:53:30.584087 140057484760832 deprecation_wrapper.py:119] From /home/aziz/anaconda3/envs/tf/lib/python3.7/site-packages/keras

Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100
Epoch 31/100
Epoch 32/100
Epoch 33/100
Epoch 34/100
Epoch 35/100
Epoch 36/100
Epoch 37/100
Epoch 38/100
Epoch 39/100
Epoch 40/100
Epoch 41/100
Epoch 42/100
Epoch 43/100
Epoch 44/100
Epoch 45/100
Epoch 46/100
Epoch 47/100
Epoch 48/100
Epoch 49/100
Epoch 50/100
Epoch 51/100
Epoch 52/100
Epoch 53/100
Epoch 54/100
Epoch 55/100
Epoch 56/100
Epoch 57/100
Epoch 58/100
Epoch 59/100
Epoch 60/100
Epoch 61/100
Epoch 62/100
Epoch 63/100
Epoch 64/100
Epoch 65/100
Epoch 66/100
Epoch 67/100
Epoch 68/100
Epoch 69/100
Epoch 70/100
Epoch 71/100
Epoch 72/100
Epoch 73/100
Epoch 74/100
Epoch 75/100
Epoch 76/100
Epoch 77/100
Epoch 78

Epoch 89/100
Epoch 90/100
Epoch 91/100
Epoch 92/100
Epoch 93/100
Epoch 94/100
Epoch 95/100
Epoch 96/100
Epoch 97/100
Epoch 98/100
Epoch 99/100
Epoch 100/100
-
Input sentence: ((x + y) >= (z - 1))
Decoded sentence: ((x + y))             
-
Input sentence: (a && b)
Decoded sentence: (a &&&!(b)))<eoc>
-
Input sentence: (c > 0)
Decoded sentence: (f >  ))              
-
Input sentence: d
Decoded sentence: (( >  )<eoc>
-
Input sentence: (e > f)
Decoded sentence: (f >  ))              


In [None]:
# def generate_fixed_ints(enc_dec, bugs, fixed_len, token_map, int_map):
#     gntd_ints = np.zeros(shape=(len(bugs), fixed_len))
#     gntd_ints[:, 0] = token_map["<soc>"]
#     for buggy, generated in zip(bugs, gntd_ints):
#         buggy_input = buggy[np.newaxis]
#         gntd_in_out = generated[np.newaxis]
#         for i in range(1, fixed_len):
#             prediction = enc_dec.predict([buggy_input, gntd_in_out]).argmax(axis=2)
#             if int_map[prediction[:, i][0]] == "<eoc>":
#                 break
#             generated[i] = prediction[:, i]
    
#     return gntd_ints


# def decode_ints(int_matrix, int_map):
#     gntd_codes = []
#     for ints in int_matrix:
#         code = [int_map[x] for x in ints if x != 0]
#         gntd_codes.append(code)
        
#     return gntd_codes


# print('=============')
# print('=============')
# print('=============')
# generated_ints = generate_fixed_ints(seq2seq, buggy_inputs, max_fixed_len, token_int_map, int_token_map)
# generated_codes = decode_ints(generated_ints, int_token_map)
# for buggy, fixed, gnrtd in zip(buggy_codes, fixed_codes, generated_codes):
#     print('=============')
#     print('Buggy code:', ' '.join(buggy[1:-1]))
#     print('Fixed code:', ' '.join(fixed[1:-1]))
#     print('Genration: ', ' '.join(gnrtd[1:]))



def from_mats_to_seqs(mats, int_map):
    gntd_ints = []
    for matrix in mats:
        gntd_seq = []
        for row in matrix:
            for i, token in enumerate(row):
                if token == 1.:
                    gntd_seq.append(i)
                    if int_map[i] == "<eoc>":
                        break
        gntd_ints.append(gntd_seq)
    
    return gntd_ints


def decode_ints(int_matrix, int_map):
    gntd_codes = []
    for ints in int_matrix:
        code = [int_map[x] for x in ints if x != 0]
        gntd_codes.append(code)
        
    return gntd_codes


def generate_fixed_ints(enc_dec, bugs, fixed_len, v_size, token_map, int_map):
    gntd_mats = np.zeros(shape=(len(bugs), fixed_len, v_size))
    gntd_mats[:, 0, token_map["<soc>"]] = 1.
    gntd_mats[:, 1:, 0] = 1.
    print(gntd_mats.shape)
    for j, buggy in enumerate(bugs):  # for seq in dps
        buggy_input = buggy[np.newaxis]
        gntd_in_out = gntd_mats[j]
        gntd_in_out = gntd_in_out[np.newaxis]
        for i in range(1, fixed_len):  # for token in dp
            prediction = enc_dec.predict([buggy_input, gntd_in_out]).argmax(axis=2)
            gntd_mats[j, i, 0] = 0.
            gntd_mats[j, i, prediction[:, i][0]] = 1.
#             print(from_mats_to_seqs(gntd_mats, int_map))
            if int_map[prediction[:, i][0]] == "<eoc>":
                print('hi')
                break
#     gntd_ints = []
#     for matrix in gntd_mats:
#         gntd_seq = []
#         for row in matrix:
#             for i, token in enumerate(row):
#                 if token == 1.:
#                     if int_map[i] == "<eoc>":
#                         break
#                     gntd_seq.append(i)
#             gntd_ints.append(gntd_seq)
    
    return from_mats_to_seqs(gntd_mats, int_map)





print('=============')
print('=============')
print('=============')
generated_ints = generate_fixed_ints(seq2seq, buggy_inputs, max_fixed_len, vocab_size, token_int_map, int_token_map)
generated_codes = decode_ints(generated_ints, int_token_map)
for buggy, fixed, gnrtd in zip(buggy_codes, fixed_codes, generated_codes):
    print('=============')
    print('Buggy code:', ' '.join(buggy))
    print('Fixed code:', ' '.join(fixed))
    print('Genration: ', ' '.join(gnrtd))

In [None]:
encoder_model = Model(encoder_inputs, encoder_states)
decoder_state_input_h = Input(shape=(latent_dim,))
decoder_state_input_c = Input(shape=(latent_dim,))
decoder_states_inputs = [decoder_state_input_h, decoder_state_input_c]
decoder_outputs, state_h, state_c = decoder_lstm(decoder_inputs, initial_state=decoder_states_inputs)
decoder_states = [state_h, state_c]
decoder_outputs = decoder_dense(decoder_outputs)
decoder_model = Model([decoder_inputs] + decoder_states_inputs, [decoder_outputs] + decoder_states)



def decode_sequence(input_seq):
    # Encode the input as state vectors.
    states_value = encoder_model.predict(input_seq)

    # Generate empty target sequence of length 1.
    target_seq = np.zeros((1, 1, num_decoder_tokens))
    # Populate the first character of target sequence with the start character.
    target_seq[0, 0, target_token_index['\t']] = 1.

    # Sampling loop for a batch of sequences
    # (to simplify, here we assume a batch of size 1).
    stop_condition = False
    decoded_sentence = ''
    while not stop_condition:
        output_tokens, h, c = decoder_model.predict(
            [target_seq] + states_value)

        # Sample a token
        sampled_token_index = np.argmax(output_tokens[0, -1, :])
        sampled_char = reverse_target_char_index[sampled_token_index]
        decoded_sentence += sampled_char

        # Exit condition: either hit max length
        # or find stop character.
        if (sampled_char == '\n' or
           len(decoded_sentence) > max_decoder_seq_length):
            stop_condition = True

        # Update the target sequence (of length 1).
        target_seq = np.zeros((1, 1, num_decoder_tokens))
        target_seq[0, 0, sampled_token_index] = 1.

        # Update states
        states_value = [h, c]

    return decoded_sentence

In [28]:
decode_ints(from_mats_to_seqs(fixed_inputs, int_token_map), int_token_map)

[['<soc>',
  '(',
  '(',
  'x',
  '+',
  'y',
  ')',
  '>',
  '(',
  'z',
  '-',
  '1',
  ')',
  ')',
  '<eoc>'],
 ['<soc>', '(', 'a', '&', '&', '!', '(', 'b', ')', ')', '<eoc>'],
 ['<soc>', '(', 'c', '>', '1', ')', '<eoc>'],
 ['<soc>', '!', '(', 'd', ')', '<eoc>'],
 ['<soc>', '(', 'f', '>', 'e', ')', '<eoc>']]

In [29]:
decode_ints(from_mats_to_seqs(fixed_outputs, int_token_map), int_token_map)

[['(', '(', 'x', '+', 'y', ')', '>', '(', 'z', '-', '1', ')', ')', '<eoc>'],
 ['(', 'a', '&', '&', '!', '(', 'b', ')', ')', '<eoc>'],
 ['(', 'c', '>', '1', ')', '<eoc>'],
 ['!', '(', 'd', ')', '<eoc>'],
 ['(', 'f', '>', 'e', ')', '<eoc>']]

## GANs

In [1]:
from keras.layers import Input, Concatenate, Embedding, LSTM, Dense, dot, Activation, concatenate, Lambda
from keras.models import Model
from keras.backend import argmax, cast


def build_discriminator(dimension, v_size, buggy_len, fixed_len):
    buggy_input_layer = Input(shape=(buggy_len,))
    fixed_input_layer = Input(shape=(fixed_len,))
    concatted = Concatenate()([buggy_input_layer, fixed_input_layer])
    embed_lay = Embedding(v_size, dimension, mask_zero=True)(concatted)
    x = LSTM(dimension)(embed_lay)
    out = Dense(1, activation='sigmoid')(x)
    disc = Model([buggy_input_layer, fixed_input_layer], out)
    disc.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'], loss_weights=[0.5])
    
    return disc


def build_generator(dimension, v_size, buggy_len, fixed_len):
    # Encoder
    buggy_input_layer = Input(shape=(buggy_len,))
    enc_embed_lay = Embedding(v_size, dimension, mask_zero=True)(buggy_input_layer)
    encoder_outputs, state_h, state_c = LSTM(dimension, return_sequences=True, return_state=True)(enc_embed_lay)
    # Decoder
    fixed_input_layer = Input(shape=(fixed_len,))
    dec_embed_lay = Embedding(v_size, dimension, mask_zero=True)(fixed_input_layer)
    decoder_outputs = LSTM(dimension, return_sequences=True)(dec_embed_lay, initial_state=[state_h, state_c])
    # Attention
    attention = dot([decoder_outputs, encoder_outputs], axes=[2, 2])
    attention = Activation('softmax', name='attention')(attention)
    context = dot([attention, encoder_outputs], axes=[2, 1])
    decoder_combined_context = concatenate([context, decoder_outputs])
    attention_context_output = Dense(dimension, activation="tanh")(decoder_combined_context)
    # Model output
    model_output = Dense(v_size, activation="softmax")(attention_context_output)
    # Build model
    gen = Model([buggy_input_layer, fixed_input_layer], model_output)
    
    return gen


def build_gan(gen, disc, buggy_len, fixed_len):
    disc.trainable = False
    buggy_input_layer = Input(shape=(buggy_len,))
    fixed_input_layer = Input(shape=(fixed_len,))
    gen_out = gen([buggy_input_layer, fixed_input_layer])
    argmax_layer = Lambda(lambda x: cast(argmax(x, axis=2), dtype='float32'))
    disc_out = disc([buggy_input_layer, argmax_layer(gen_out)])
    gan = Model([buggy_input_layer, fixed_input_layer], [disc_out, gen_out])
    # compile model
    gan.compile(loss=['binary_crossentropy', 'categorical_crossentropy'], optimizer='rmsprop', loss_weights=[1, 100])
    
    return gan


%matplotlib inline
from keras.utils.vis_utils import plot_model
from IPython.display import Image


latent_dim = 512

discriminator = build_discriminator(latent_dim, vocab_size, max_buggy_len, max_fixed_len)
plot_model(discriminator, to_file='discriminator_model_plot.png', show_shapes=True, show_layer_names=True)
# Image('discriminator_model_plot.png')

generator = build_generator(latent_dim, vocab_size, max_buggy_len, max_fixed_len)
plot_model(generator, to_file='generator_model_plot.png', show_shapes=True, show_layer_names=True)
# Image('generator_model_plot.png')

gan = build_gan(generator, discriminator, max_buggy_len, max_fixed_len)
plot_model(gan, to_file='gan_model_plot.png', show_shapes=True, show_layer_names=True)
# gan.summary()
# Image('gan_model_plot.png')


def generate_fixed_ints(gen, bugs, fixed_len, token_map, int_map):
    gntd_ints = np.zeros(shape=(len(bugs), fixed_len))
    gntd_ints[:, 0] = token_map["<soc>"]
    for buggy, generated in zip(bugs, gntd_ints):
        buggy_input = buggy[np.newaxis]
        gntd_in_out = generated[np.newaxis]
        for i in range(1, fixed_len):
            prediction = gen.predict([buggy_input, gntd_in_out]).argmax(axis=2)
            if int_map[prediction[:, i][0]] == "<eoc>":
                break
            generated[i] = prediction[:, i]
    
    return gntd_ints


epochs = 20

for e in range(epochs):
    discriminator.fit([buggy_inputs, fixed_inputs], np.ones(num_dps))
    generated_ints = generate_fixed_ints(generator, buggy_inputs, max_fixed_len, token_int_map, int_token_map)
    discriminator.fit([buggy_inputs, generated_ints], np.zeros(num_dps))
    gan.fit([buggy_inputs, fixed_inputs], [np.ones(num_dps), fixed_outputs])


def decode_ints(int_matrix, int_map):
    gntd_codes = []
    for ints in int_matrix:
        code = [int_map[x] for x in ints if x != 0]
        gntd_codes.append(code)
        
    return gntd_codes


print('=============')
print('=============')
print('=============')
generated_ints = generate_fixed_ints(generator, buggy_inputs, max_fixed_len, token_int_map, int_token_map)
generated_codes = decode_ints(generated_ints, int_token_map)
for buggy, fixed, gnrtd in zip(buggy_codes, fixed_codes, generated_codes):
    print('=============')
    print('Buggy code:', ' '.join(buggy[1:-1]))
    print('Fixed code:', ' '.join(fixed[1:-1]))
    print('Genration: ', ' '.join(gnrtd[1:]))

Buggy codes:-
['(', '(', 'x', '+', 'y', ')', '>', '=', '(', 'z', '-', '1', ')', ')']
['(', 'a', '&', '&', 'b', ')']
['(', 'c', '>', '0', ')']
['d']
['(', 'e', '>', 'f', ')']
Fixed codes:-
['<soc>', '(', '(', 'x', '+', 'y', ')', '>', '(', 'z', '-', '1', ')', ')', '<eoc>']
['<soc>', '(', 'a', '&', '&', '!', '(', 'b', ')', ')', '<eoc>']
['<soc>', '(', 'c', '>', '1', ')', '<eoc>']
['<soc>', '!', '(', 'd', ')', '<eoc>']
['<soc>', '(', 'f', '>', 'e', ')', '<eoc>']
{1: 'f', 2: '+', 3: '>', 4: 'e', 5: 'z', 6: 'y', 7: '1', 8: 'x', 9: '=', 10: '<soc>', 11: 'b', 12: '(', 13: '&', 14: 'a', 15: '-', 16: 'd', 17: '0', 18: '<eoc>', 19: ')', 20: 'c', 21: '!', 0: '<pad/unknown>'}
Number of data points: 5
Vocabulary size: 22
Max length in buggy codes: 14
Max length in fixed codes: 15


Using TensorFlow backend.
W0913 23:20:54.030289 140003070756608 deprecation_wrapper.py:119] From /home/aziz/anaconda3/envs/tf/lib/python3.7/site-packages/keras/backend/tensorflow_backend.py:74: The name tf.get_default_graph is deprecated. Please use tf.compat.v1.get_default_graph instead.

W0913 23:20:54.040325 140003070756608 deprecation_wrapper.py:119] From /home/aziz/anaconda3/envs/tf/lib/python3.7/site-packages/keras/backend/tensorflow_backend.py:517: The name tf.placeholder is deprecated. Please use tf.compat.v1.placeholder instead.

W0913 23:20:54.044531 140003070756608 deprecation_wrapper.py:119] From /home/aziz/anaconda3/envs/tf/lib/python3.7/site-packages/keras/backend/tensorflow_backend.py:4138: The name tf.random_uniform is deprecated. Please use tf.random.uniform instead.

W0913 23:20:54.357945 140003070756608 deprecation.py:323] From /home/aziz/anaconda3/envs/tf/lib/python3.7/site-packages/keras/backend/tensorflow_backend.py:2974: add_dispatch_support.<locals>.wrapper (fro

Epoch 1/1
Epoch 1/1


  'Discrepancy between trainable weights and collected trainable'


Epoch 1/1
Epoch 1/1


  'Discrepancy between trainable weights and collected trainable'


Epoch 1/1
Epoch 1/1
Epoch 1/1
Epoch 1/1
Epoch 1/1
Epoch 1/1
Epoch 1/1
Epoch 1/1
Epoch 1/1
Epoch 1/1
Epoch 1/1
Epoch 1/1
Epoch 1/1
Epoch 1/1
Epoch 1/1
Epoch 1/1
Epoch 1/1
Epoch 1/1
Epoch 1/1
Epoch 1/1
Epoch 1/1
Epoch 1/1
Epoch 1/1
Epoch 1/1
Epoch 1/1
Epoch 1/1
Epoch 1/1
Epoch 1/1
Epoch 1/1
Epoch 1/1
Epoch 1/1
Epoch 1/1
Epoch 1/1
Epoch 1/1
Epoch 1/1
Epoch 1/1
Epoch 1/1
Epoch 1/1
Epoch 1/1
Epoch 1/1
Epoch 1/1
Epoch 1/1
Epoch 1/1
Epoch 1/1
Epoch 1/1
Epoch 1/1
Epoch 1/1
Epoch 1/1
Epoch 1/1
Epoch 1/1
Epoch 1/1
Epoch 1/1
Epoch 1/1
Epoch 1/1
Epoch 1/1
Epoch 1/1
Buggy code: ( x + y ) > = ( z - 1 )
Fixed code: ( ( x + y ) > ( z - 1 ) )
Genration:  ( ( ( x + y y ( y ( ) ) ( z
Buggy code: a & & b
Fixed code: ( a & & ! ( b ) )
Genration:  ( a & & ! ( b ) )
Buggy code: c > 0
Fixed code: ( c > 1 )
Genration:  ( c > ) )
Buggy code: 
Fixed code: ! ( d )
Genration:  ! ( d )
Buggy code: e > f
Fixed code: ( f > e )
Genration:  ( f > ) )
