In [1]:
import sys
import argparse
import sqlite3
import json
import time
import os
import regex as re
import numpy as np
# NOTE: '__file__' below is a quoted string literal, not the __file__ variable.
# The innermost dirname() therefore yields '', abspath('') resolves to the
# current working directory, and the two remaining dirname() calls climb two
# levels above it. Every relative path used later in the notebook is resolved
# from that directory.
# Because the target is relative to the current working directory, re-running
# this cell moves up another two levels each time.
os.chdir(os.path.dirname(os.path.abspath(os.path.dirname(os.path.abspath(os.path.dirname('__file__'))))))

from util.helpers import get_lines, extract_line_number, FailedToGetLineNumberException, _truncate_fix
from util.helpers import tokens_to_source, compilation_errors, apply_fix, InvalidFixLocationException

In [2]:
# Experiment identity: these names select the input data and are also used to
# build the path of the results database at the end of this cell.
data_name = "ids"
data_type = "seeded"
rnn = "lstm"
pretrained_dir_name = None
test_path = "data/network_inputs/iitk-"+data_name+"-1189"
config_path = "models/config.json"

# Read the base configuration. The `with` block closes the file handle
# deterministically instead of leaving it open until garbage collection.
with open(config_path) as config_file:
    config_json = config_file.read()
config = json.loads(config_json)

# Overrides applied on top of the values read from config.json.
config["hidden_size"] = 300
config["rnn_cell"] = rnn
config["n_layers"] = 4
config["dropout_p"] = 0.2
config["embedding_size"] = 50
config["use_attention"] = True
config["encoder_position_embedding"] = None
config["decoder_position_embedding"] = None
config["use_memory"] = None
seed = config["seed"] = 1189
#config["pos_add"] = "cat"

print(json.dumps(config, indent=4))

# The directory name encodes the options that are switched on; each optional
# fragment contributes an empty string when its option is off.
uses_position_embedding = (config["encoder_position_embedding"] is not None
                           or config["decoder_position_embedding"] is not None)

save_path = (data_name
            + ("_att" if config["use_attention"] else "")
            + ("_with_pos" if uses_position_embedding else "")
            + ("_encoder_" + config["encoder_position_embedding"]
                if config["encoder_position_embedding"] is not None else "")
            + ("_decoder_" + config["decoder_position_embedding"]
                if config["decoder_position_embedding"] is not None else "")
            + ("_cat" if config["pos_add"] == "cat" else "")
            + ("_use_stack" if config["use_memory"] == "stack" else "")
            + ("_use_queue" if config["use_memory"] == "queue" else "")
            + "_emb" + str(config["embedding_size"])
            + "_hidden" + str(config["hidden_size"])
            + ("_pretrained" if pretrained_dir_name is not None else ""))
print("Save_path : %s" % save_path)

# SQLite file that the later cells open, extend and summarise.
database = "log/test/" + save_path + "/" + rnn + "/" + data_type + "/" + data_name + "_" + data_type + ".db"

{
    "encoder_max_len": 420,
    "decoder_max_len": 40,
    "embedding_size": 50,
    "hidden_size": 300,
    "input_dropout_p": 0,
    "dropout_p": 0.2,
    "n_layers": 4,
    "bidirectional": false,
    "rnn_cell": "lstm",
    "variable_lengths": true,
    "embedding": null,
    "update_embedding": true,
    "get_context_vector": false,
    "use_attention": true,
    "attn_layers": 1,
    "hard_attn": false,
    "encoder_position_embedding": null,
    "decoder_position_embedding": null,
    "pos_add": "add",
    "use_memory": null,
    "memory_dim": 5,
    "seed": 1189
}
Save_path : ids_att_emb50_hidden300


In [3]:
def _is_stop_signal(fix):
    """Tell whether `fix` is empty once passed through `_truncate_fix`.

    Args:
        fix: A fix string as produced by the network.

    Returns:
        True when `_truncate_fix(fix)` equals the empty string, False
        otherwise. (The function used to fall through and return None in the
        negative case; it now always returns a bool.)
    """
    # presumably an empty truncated fix means the network has nothing more to
    # suggest - confirm against util.helpers._truncate_fix
    return _truncate_fix(fix) == ''

In [4]:
def meets_criterion(incorrect_program_tokens, fix, type_, silent=True):
    """Decide whether a suggested fix is acceptable for the given program.

    The fix is rejected when it is a stop signal, when no line number can be
    extracted from it, when that line number is past the end of the program,
    when the number of `_<id>_` tokens differs between the fix and the line it
    targets, or - for 'replace' fixes only - when the sequence of keyword,
    type, API-call and include tokens differs.

    Args:
        incorrect_program_tokens: Tokenized program, split into lines with
            `get_lines`.
        fix: Raw fix string; it is passed through `_truncate_fix` first.
        type_: Kind of fix. Only the value 'replace' enables the keyword
            comparison.
        silent: When False, print the reason for a token-based rejection.

    Returns:
        True if the fix passes every check, False otherwise.
    """
    lines = get_lines(incorrect_program_tokens)
    fix = _truncate_fix(fix)

    if _is_stop_signal(fix):
        return False

    try:
        fix_line_number = extract_line_number(fix)
    except FailedToGetLineNumberException:
        return False

    if fix_line_number >= len(lines):
        return False

    fix_line = lines[fix_line_number]

    # Raw strings: `\w` is not a valid escape sequence in a normal string
    # literal and triggers a SyntaxWarning on recent Python versions.
    id_regex = r'_<id>_\w*'

    # Make sure number of IDs is the same
    if len(re.findall(id_regex, fix_line)) != len(re.findall(id_regex, fix)):
        if not silent:
            print('number of ids is not the same')
        return False

    keywords_regex = r'_<keyword>_\w+|_<type>_\w+|_<APIcall>_\w+|_<include>_\w+'

    # A replacement must keep the same keyword-like tokens in the same order.
    if type_ == 'replace' and re.findall(keywords_regex, fix_line) != re.findall(keywords_regex, fix):
        if not silent:
            print('important words (keywords, etc.) change drastically')
        return False

    return True

In [5]:
def get_final_results(database, final_iteration=10):
    """Print a summary of how many compiler errors were removed.

    Opens the SQLite file `database`, reads the `programs`, `error_messages`
    and `error_message_strings` tables and prints the totals to stdout.
    Programs that have an `error_message_strings` row with zero errors at
    iteration 0 are excluded from the filtered statistics.

    Args:
        database: Path of the SQLite database file.
        final_iteration: Iteration number treated as the end state. Defaults
            to 10, the value that was previously hard-coded in the queries.

    Returns:
        None. The statistics are only printed.

    Raises:
        KeyError: If a program still has errors at `final_iteration` but has
            no row with errors at iteration 0.
    """
    # Formatted into the SQL text below, so force it to be a plain integer.
    final_iteration = int(final_iteration)

    # Sub-select shared by the filtered statistics: programs that were already
    # error-free at iteration 0.
    clean_at_start = (
        "SELECT p.prog_id FROM programs p "
        "INNER JOIN error_message_strings e ON p.prog_id = e.prog_id "
        "WHERE e.iteration = 0 AND e.error_message_count = 0"
    )
    not_clean_at_start = "prog_id NOT IN (" + clean_at_start + ")"

    message_count_sql = (
        "SELECT COUNT(*) FROM error_messages "
        "WHERE iteration = {iteration} AND " + not_clean_at_start + ";"
    )
    fully_fixed_sql = (
        "SELECT COUNT(DISTINCT prog_id) FROM error_message_strings "
        "WHERE iteration = {iteration} AND error_message_count = 0 AND "
        + not_clean_at_start + ";"
    )
    per_program_sql = (
        "SELECT DISTINCT prog_id, error_message_count FROM error_message_strings "
        "WHERE iteration = {iteration} AND error_message_count > 0 AND "
        + not_clean_at_start + ";"
    )
    assignments_sql = (
        "SELECT COUNT(DISTINCT prob_id) FROM programs p WHERE "
        + not_clean_at_start + ";"
    )
    token_sql = (
        "SELECT code FROM programs p "
        "INNER JOIN error_message_strings e ON p.prog_id = e.prog_id "
        "WHERE e.iteration = 0 AND e.error_message_count <> 0;"
    )

    # `with sqlite3.connect(...)` only manages the transaction and leaves the
    # connection open, so close it explicitly once everything has been read.
    conn = sqlite3.connect(database)
    try:
        c = conn.cursor()

        def scalar(sql):
            """Run an aggregate query and return its single value."""
            return c.execute(sql).fetchone()[0]

        initial_errors = scalar(message_count_sql.format(iteration=0))
        final_errors = scalar(message_count_sql.format(iteration=final_iteration))
        fully_fixed = scalar(fully_fixed_sql.format(iteration=final_iteration))
        assignments = int(scalar(assignments_sql))

        original_errors = {
            prog_id: int(count)
            for prog_id, count in c.execute(per_program_sql.format(iteration=0))
        }
        remaining_errors = c.execute(
            per_program_sql.format(iteration=final_iteration)).fetchall()
        token_counts = [len(code.split()) for (code,) in c.execute(token_sql)]
    finally:
        conn.close()

    partially_fixed = {}
    unfixed = {}
    for prog_id, count in remaining_errors:
        count = int(count)
        if count < original_errors[prog_id]:
            partially_fixed[prog_id] = count
        elif count == original_errors[prog_id]:
            unfixed[prog_id] = count
        else:
            # More errors than at the start: report the program, do not count it.
            print(prog_id, count, original_errors[prog_id])

    # np.mean([]) would emit a RuntimeWarning; the printed value is nan either way.
    avg_token_count = np.mean(token_counts) if token_counts else float('nan')

    print("-------")
    print("Assignments:", assignments)
    print("Program count:", len(token_counts))
    print("Average token count:", avg_token_count)
    print("Error messages:", initial_errors)
    print("-------")

    print("Errors remaining:", final_errors)
    print("Reduction in errors:", (initial_errors - final_errors))
    print("Completely fixed programs:", fully_fixed)
    print("partially fixed programs:", len(partially_fixed))
    print("unfixed programs:", len(unfixed))
    print("-------")

In [6]:
def do_problem(problem_id):
    """Replay the stored fixes for every program of one problem and log the errors.

    For each program of `problem_id` the function starts from the stored code,
    applies the fixes recorded in the `iterations` table (first those of the
    'typo' network, then those of the 'ids' network), recompiles after each
    applied fix and keeps a history of programs and compiler errors. The
    history is then written to the `error_message_strings` and
    `error_messages` tables and committed.

    Relies on module-level state: the open connection `conn`, the dicts
    `reconstruction`, `errors`, `errors_full` and `errors_test` (each gets a
    new entry keyed by `problem_id`) and the counter `total_count` (increased
    by the number of programs processed).

    Args:
        problem_id: Value matched against `programs.prob_id`.

    Returns:
        None.
    """
    global reconstruction, errors, errors_full, total_count, errors_test

    c = conn.cursor()

    reconstruction[problem_id] = {}
    errors[problem_id] = {}
    errors_full[problem_id] = {}
    errors_test[problem_id] = []
    candidate_programs = []

    # Collect all rows before doing anything else: the same cursor `c` is
    # reused for the per-program queries further down.
    for row in c.execute('SELECT user_id, prog_id, code, name_dict, name_seq FROM programs WHERE prob_id = ?', (problem_id,)):
        user_id, prog_id, initial = row[0], row[1], row[2]
        name_dict = json.loads(row[3])
        name_seq = json.loads(row[4])

        candidate_programs.append(
            (user_id, prog_id, initial, name_dict, name_seq,))

    # user_id is ignored and name_seq is unpacked but not used below.
    for _, prog_id, initial, name_dict, name_seq in candidate_programs:
        fixes_suggested_by_typo_network = []
        fixes_suggested_by_undeclared_network = []

        for row in c.execute('SELECT fix FROM iterations WHERE prog_id=? AND network = \'typo\' ORDER BY iteration', (prog_id,)):
            fixes_suggested_by_typo_network.append(row[0])

        for row in c.execute('SELECT fix FROM iterations WHERE prog_id=? AND network = \'ids\' ORDER BY iteration', (prog_id,)):
            fixes_suggested_by_undeclared_network.append(row[0])

        # Index 0 of every history is the unmodified program and its errors.
        reconstruction[problem_id][prog_id] = [initial]
        temp_errors, temp_errors_full = compilation_errors(
            tokens_to_source(initial, name_dict, False))
        errors[problem_id][prog_id] = [temp_errors]
        errors_full[problem_id][prog_id] = [temp_errors_full]

        # Phase 1: 'typo' fixes, applied as replacements. The phase ends at the
        # first fix that fails meets_criterion or that raises the error count.
        # presumably InvalidFixLocationException comes from apply_fix - confirm;
        # it is caught outside the loop, so it also ends the phase.
        try:
            for fix in fixes_suggested_by_typo_network:
                if meets_criterion(reconstruction[problem_id][prog_id][-1], fix, 'replace'):
                    temp_prog = apply_fix(
                        reconstruction[problem_id][prog_id][-1], fix, 'replace')
                    temp_errors, temp_errors_full = compilation_errors(
                        tokens_to_source(temp_prog, name_dict, False))

                    if len(temp_errors) > len(errors[problem_id][prog_id][-1]):
                        break
                    else:
                        reconstruction[problem_id][prog_id].append(temp_prog)
                        errors[problem_id][prog_id].append(temp_errors)
                        errors_full[problem_id][prog_id].append(
                            temp_errors_full)
                else:
                    break

        except InvalidFixLocationException:
            print('Localization failed')

        # Repeat the last state until the histories hold at least 6 entries
        # (indices 0-5). Histories that are already longer are not truncated.
        while len(reconstruction[problem_id][prog_id]) <= 5:
            reconstruction[problem_id][prog_id].append(
                reconstruction[problem_id][prog_id][-1])
            errors[problem_id][prog_id].append(errors[problem_id][prog_id][-1])
            errors_full[problem_id][prog_id].append(
                errors_full[problem_id][prog_id][-1])

        already_fixed = []

        # Phase 2: 'ids' fixes, applied as insertions. Unlike phase 1 there is
        # no meets_criterion check; a fix string seen before is skipped, and
        # the phase ends as soon as a fix raises the error count.
        try:
            for fix in fixes_suggested_by_undeclared_network:
                if fix not in already_fixed:
                    temp_prog = apply_fix(
                        reconstruction[problem_id][prog_id][-1], fix, 'insert')
                    already_fixed.append(fix)
                    temp_errors, temp_errors_full = compilation_errors(
                        tokens_to_source(temp_prog, name_dict, False))

                    if len(temp_errors) > len(errors[problem_id][prog_id][-1]):
                        break
                    else:
                        reconstruction[problem_id][prog_id].append(temp_prog)
                        errors[problem_id][prog_id].append(temp_errors)
                        errors_full[problem_id][prog_id].append(
                            temp_errors_full)
                else:
                    pass

        except InvalidFixLocationException:
            print('Localization failed')

        # Repeat the last state until the histories hold at least 11 entries
        # (indices 0-10).
        # NOTE(review): nothing caps the histories at 11 entries, so a program
        # with many accepted fixes would produce rows with iteration > 10.
        while len(reconstruction[problem_id][prog_id]) <= 10:
            reconstruction[problem_id][prog_id].append(
                reconstruction[problem_id][prog_id][-1])
            errors[problem_id][prog_id].append(errors[problem_id][prog_id][-1])
            errors_full[problem_id][prog_id].append(
                errors_full[problem_id][prog_id][-1])

        errors_test[problem_id].append(errors[problem_id][prog_id])

        # One error_message_strings row per history index k, plus one
        # error_messages row per individual error at that index.
        # NOTE(review): the network column is written as 'typo' for every
        # index, including the states produced by the 'ids' phase - confirm
        # this is intended.
        for k, errors_t, errors_full_t in zip(range(len(errors[problem_id][prog_id])), errors[problem_id][prog_id], errors_full[problem_id][prog_id]):
            c.execute("INSERT INTO error_message_strings VALUES(?, ?, ?, ?, ?)", (
                prog_id, k, 'typo', errors_full_t.decode('utf-8', 'ignore'), len(errors_t)))

            for error_ in errors_t:
                c.execute("INSERT INTO error_messages VALUES(?, ?, ?, ?)",
                            (prog_id, k, 'typo', error_.decode('utf-8', 'ignore'),))

    count_t = len(candidate_programs)
    total_count += count_t

    print('Committing changes to database...')
    conn.commit()
    print('Done!')

    c.close()

In [7]:
def subset(arr1, arr2):
    """Return True when every element of `arr1` is also contained in `arr2`.

    Membership is tested with `in`, so duplicates in `arr1` are not counted
    and an empty `arr1` always gives True.
    """
    return all(item in arr2 for item in arr1)

In [8]:
# Open the results database. `conn` stays open because do_problem() uses it.
conn = sqlite3.connect(database)
c = conn.cursor()

# Table that receives one row per program and iteration.
create_table_sql = '''CREATE TABLE IF NOT EXISTS error_message_strings (
                prog_id text NOT NULL,
                iteration text NOT NULL,
                network text NOT NULL,
                error_message_string text NOT NULL,
                error_message_count integer NOT NULL,
                FOREIGN KEY(prog_id, iteration, network) REFERENCES iterations(prog_id, iteration, network)
             )'''
c.execute(create_table_sql)

# Every distinct problem that has at least one stored program.
problem_ids = [row[0] for row in c.execute('SELECT DISTINCT prob_id FROM programs')]

c.close()

# Module-level accumulators filled by do_problem() and read by the analysis cell.
reconstruction, errors, errors_full, errors_test = {}, {}, {}, {}
fixes_per_stage = [0 for _ in range(10)]
total_count = 0

In [9]:
# Process every problem and time the whole run.
start = time.time()

for problem_id in problem_ids:
    do_problem(problem_id)

time_t = time.time() - start

conn.commit()
conn.close()

print('Total time:', time_t, 'seconds')
print('Total programs processed:', total_count)
if total_count:
    print('Average time per program:', int(float(time_t) / float(total_count) * 1000), 'ms')
else:
    # No programs were processed; dividing by total_count would raise ZeroDivisionError.
    print('Average time per program: n/a (no programs processed)')

Committing changes to database...
Done!
Committing changes to database...
Done!
Committing changes to database...
Done!
Committing changes to database...
Done!
Committing changes to database...
Done!
Committing changes to database...
Done!
Committing changes to database...
Done!
Committing changes to database...
Done!
Committing changes to database...
Done!
Committing changes to database...
Done!
Committing changes to database...
Done!
Committing changes to database...
Done!
Committing changes to database...
Done!
Committing changes to database...
Done!
Committing changes to database...
Done!
Committing changes to database...
Done!
Committing changes to database...
Done!
Committing changes to database...
Done!
Committing changes to database...
Done!
Committing changes to database...
Done!
Committing changes to database...
Done!
Committing changes to database...
Done!
Committing changes to database...
Done!
Committing changes to database...
Done!
Committing changes to database...
Done!


In [10]:
import numpy as np

In [11]:
total_fixes_num = {}
errors_before = {}

# Start the per-stage tally from zero (keeping its configured length) so that
# re-running this cell does not add to the counts of a previous run.
fixes_per_stage = [0] * len(fixes_per_stage)

# Count, per problem and per stage, how many errors each step removed.
for problem_id in errors_test:
    total_fixes_num[problem_id] = 0

    for seq in errors_test[problem_id]:
        error_numbers = [len(x) for x in seq]

        for i in range(len(error_numbers) - 1):
            removed = error_numbers[i] - error_numbers[i + 1]
            # do_problem() never records a step that increases the error count.
            assert removed >= 0
            total_fixes_num[problem_id] += removed
            fixes_per_stage[i] += removed

total_numerator = 0
total_denominator = 0

for problem_id in errors_test:
    # Errors present before any fix, summed over the problem's programs.
    initial_error_count = sum(len(x[0]) for x in errors_test[problem_id])

    print(problem_id)
    print('%d/%d' % (total_fixes_num[problem_id], initial_error_count))
    if initial_error_count:
        print('(%g%%)' % (float(total_fixes_num[problem_id]) / initial_error_count * 100))
    else:
        # No initial errors for this problem: a percentage is undefined.
        print('(n/a)')

    total_numerator += total_fixes_num[problem_id]
    total_denominator += initial_error_count


if total_denominator:
    print(int(float(total_numerator) * 100.0 / float(total_denominator)), '%')
else:
    print('n/a')


for stage in range(len(fixes_per_stage)):
    print('Stage', stage, ':', fixes_per_stage[stage])

get_final_results(database)

prob200
22/259
(8.49421%)
prob99
42/294
(14.2857%)
prob215
91/269
(33.829%)
prob315
40/300
(13.3333%)
prob258
40/275
(14.5455%)
prob235
51/278
(18.3453%)
prob90
43/265
(16.2264%)
prob117
31/282
(10.9929%)
prob266
54/303
(17.8218%)
prob32
92/260
(35.3846%)
prob43
30/301
(9.96678%)
prob370
74/263
(28.1369%)
prob288
80/269
(29.7398%)
prob262
23/280
(8.21429%)
prob331
57/284
(20.0704%)
prob236
40/282
(14.1844%)
prob122
57/291
(19.5876%)
prob73
36/287
(12.5436%)
prob10
28/289
(9.68858%)
prob100
41/296
(13.8514%)
prob321
55/316
(17.4051%)
prob51
26/268
(9.70149%)
prob152
11/267
(4.11985%)
prob06
20/276
(7.24638%)
prob79
44/295
(14.9153%)
prob271
3/140
(2.14286%)
prob47
39/287
(13.5889%)
prob240
32/263
(12.1673%)
prob30
25/222
(11.2613%)
prob89
41/290
(14.1379%)
prob283
60/272
(22.0588%)
prob198
30/298
(10.0671%)
prob244
94/182
(51.6484%)
prob376
36/259
(13.8996%)
prob56
36/283
(12.7208%)
prob154
46/284
(16.1972%)
prob334
46/264
(17.4242%)
prob272
40/292
(13.6986%)
prob249
59/251
(23.506%)
pr