In [None]:
%env OPENAI_API_KEY = api_key

In [None]:
# pip install --upgrade openai

In [None]:
import glob
import json
import re
import random
import time
import subprocess

from tqdm import tqdm
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score

In [None]:
import openai
import os

openai.api_key = os.getenv("OPENAI_API_KEY")

In [None]:
cd '/home/azmain/alljavajsons'

In [None]:
# Natural-order sort key: runs of digits compare as integers, so "x2" sorts before "x10".
numbers = re.compile(r'(\d+)')


def numericalSort(value):
    """Return ``value`` split on digit runs, with each digit run converted to ``int``.

    The pattern has a capturing group, so ``split`` keeps the digit runs in the
    result; they always land on the odd positions of the returned list.
    """
    pieces = numbers.split(value)
    return [int(piece) if position % 2 else piece for position, piece in enumerate(pieces)]


# Collect the benchmark inputs and their reference logs from the current
# working directory, both in natural (numeric-aware) order.
inputFiles = sorted(glob.glob("*.java.json"), key=numericalSort)
correctOutputFiles = sorted(glob.glob("*.benchmark_log.json"), key=numericalSort)

print(inputFiles)
print('\n\n\n')
print(correctOutputFiles)

In [None]:
def get_codes(inputFiles):
    """Load the source snippet stored in each benchmark JSON file.

    Args:
        inputFiles: iterable of paths to JSON files that contain an
            ``originalContent`` entry.

    Returns:
        A list with ``str(originalContent)`` for every path, in input order.

    Raises:
        OSError: if a file cannot be opened.
        KeyError: if a file has no ``originalContent`` entry.
    """
    codes = []
    for path in inputFiles:
        # Context manager so the handle is closed even if parsing fails.
        with open(path) as handle:
            codes.append(str(json.load(handle)['originalContent']))

    return codes

In [None]:
def get_correct_outputs(correctOutputFiles):
    """Build the expected import statements for each benchmark log file.

    Every entry of the log's ``total_imports`` list is wrapped as
    ``"import <name>;"``. The statement ``"import gen.R;"`` is excluded from
    the expected set (every occurrence of it, should it appear more than once).

    Args:
        correctOutputFiles: iterable of paths to JSON files that contain a
            ``total_imports`` list.

    Returns:
        A list (one item per file, in input order) of lists of import
        statements.

    Raises:
        OSError: if a file cannot be opened.
        KeyError: if a file has no ``total_imports`` entry.
    """
    excluded = "import gen.R;"
    correct_outputs = []
    for path in correctOutputFiles:
        # Context manager so the handle is closed even if parsing fails.
        with open(path) as handle:
            names = json.load(handle)['total_imports']
        statements = ["import " + name + ";" for name in names]
        correct_outputs.append([stmt for stmt in statements if stmt != excluded])

    return correct_outputs

In [None]:
def get_dataset(codes, correct_outputs):
    """Bundle snippets and their expected outputs into one dict.

    The returned dict has the keys ``"codes"`` and ``"correct_outputs"`` and
    refers to the very same objects that were passed in (no copies are made).
    """
    return dict(codes=codes, correct_outputs=correct_outputs)

In [None]:
def get_test_examples_and_y_true(dataset):
    """Pair every snippet with its expected output and collect the expected outputs.

    Pairing is positional and driven by the length of ``dataset["codes"]``.

    Returns:
        ``(test_examples, y_true)`` where ``test_examples`` is a list of
        ``(code, correct_outputs)`` tuples and ``y_true`` is the list of the
        second elements of those tuples.
    """
    test_examples = [
        (dataset["codes"][position], dataset["correct_outputs"][position])
        for position in range(len(dataset["codes"]))
    ]
    y_true = [expected for _, expected in test_examples]

    return test_examples, y_true

In [None]:
# Android benchmark: entries [0, 50) of the discovered file lists.
android_input_files = inputFiles[:50]
print("Total Android Codes: {}\n".format(len(android_input_files)))
print(android_input_files)

# Load snippets and expected imports, then bundle them for evaluation.
android_codes = get_codes(android_input_files)
android_correct_outputs = get_correct_outputs(correctOutputFiles[:50])
android_dataset = get_dataset(android_codes, android_correct_outputs)
android_test_examples, android_y_true = get_test_examples_and_y_true(android_dataset)

In [None]:
# JDK benchmark: entries [50, 73) of the discovered file lists.
jdk_input_files = inputFiles[50:73]
print("Total JDK Codes: {}\n".format(len(jdk_input_files)))
print(jdk_input_files)

# Load snippets and expected imports, then bundle them for evaluation.
jdk_codes = get_codes(jdk_input_files)
jdk_correct_outputs = get_correct_outputs(correctOutputFiles[50:73])
jdk_dataset = get_dataset(jdk_codes, jdk_correct_outputs)
jdk_test_examples, jdk_y_true = get_test_examples_and_y_true(jdk_dataset)

In [None]:
# Hibernate benchmark: entry 73 followed by entries [174, 224) of the
# discovered file lists (the set is not contiguous in sort order).
hibernate_input_files = inputFiles[73:74] + inputFiles[174:224]
hibernate_output_files = correctOutputFiles[73:74] + correctOutputFiles[174:224]
print("Total Hibernate Codes: {}\n".format(len(hibernate_input_files)))
print(hibernate_input_files)

# Load snippets and expected imports, then bundle them for evaluation.
hibernate_codes = get_codes(hibernate_input_files)
hibernate_correct_outputs = get_correct_outputs(hibernate_output_files)
hibernate_dataset = get_dataset(hibernate_codes, hibernate_correct_outputs)
hibernate_test_examples, hibernate_y_true = get_test_examples_and_y_true(hibernate_dataset)

In [None]:
# Joda-Time benchmark: entries [74, 124) of the discovered file lists.
jodatime_input_files = inputFiles[74:124]
print("Total JodaTime Codes: {}\n".format(len(jodatime_input_files)))
print(jodatime_input_files)

# Load snippets and expected imports, then bundle them for evaluation.
jodatime_codes = get_codes(jodatime_input_files)
jodatime_correct_outputs = get_correct_outputs(correctOutputFiles[74:124])
jodatime_dataset = get_dataset(jodatime_codes, jodatime_correct_outputs)
jodatime_test_examples, jodatime_y_true = get_test_examples_and_y_true(jodatime_dataset)

In [None]:
# GWT benchmark: entries [124, 174) of the discovered file lists.
gwt_input_files = inputFiles[124:174]
print("Total GWT Codes: {}\n".format(len(gwt_input_files)))
print(gwt_input_files)

# Load snippets and expected imports, then bundle them for evaluation.
gwt_codes = get_codes(gwt_input_files)
gwt_correct_outputs = get_correct_outputs(correctOutputFiles[124:174])
gwt_dataset = get_dataset(gwt_codes, gwt_correct_outputs)
gwt_test_examples, gwt_y_true = get_test_examples_and_y_true(gwt_dataset)

In [None]:
# XStream benchmark: entries [224, 268) of the discovered file lists.
xstream_input_files = inputFiles[224:268]
print("Total XStream Codes: {}\n".format(len(xstream_input_files)))
print(xstream_input_files)

# Load snippets and expected imports, then bundle them for evaluation.
xstream_codes = get_codes(xstream_input_files)
xstream_correct_outputs = get_correct_outputs(correctOutputFiles[224:268])
xstream_dataset = get_dataset(xstream_codes, xstream_correct_outputs)
xstream_test_examples, xstream_y_true = get_test_examples_and_y_true(xstream_dataset)

In [None]:
def group_imports(predicted_imports):
    """Return a fresh list of lists holding the same import statements.

    Each inner group is copied into a new list, so the result can be modified
    without touching ``predicted_imports``.
    """
    return [list(import_group) for import_group in predicted_imports]

In [None]:
def _java_file_name(code_name, index):
    """Return the output file name for the ``index``-th (0-based) snippet of a benchmark.

    Raises:
        ValueError: if ``code_name`` is not one of the known benchmarks.
    """
    number = index + 1
    if code_name == "android":
        return f"Android{number:02d}.java"
    if code_name == "jdk":
        return f"Class_{number}.java"
    if code_name == "hibernate":
        # The first Hibernate snippet has a fixed name; the rest are numbered
        # with the 0-based index (so the second snippet is hibernate_class_1).
        return "HibernateUtil.java" if index == 0 else f"hibernate_class_{index}.java"
    if code_name == "jodatime":
        return f"JodaTime{number:02d}.java"
    if code_name == "gwt":
        return f"gwt_class_{number}.java"
    if code_name == "xstream":
        return f"xstream_class_{number}.java"
    raise ValueError(f"Unknown code_name: {code_name!r}")


def append_imports(save_directory, code_name, list_of_java_codes, list_of_imports):
    """Insert import lines into each Java snippet and write it to ``save_directory``.

    For every snippet the import lines are placed directly after the first
    line that starts with ``package `` (ignoring leading whitespace); when no
    such line exists they are put at the very top of the file.

    Args:
        save_directory: target directory; it is created if it does not exist.
        code_name: benchmark identifier that selects the file naming scheme
            (``android``, ``jdk``, ``hibernate``, ``jodatime``, ``gwt`` or
            ``xstream``).
        list_of_java_codes: list of Java source strings.
        list_of_imports: list of lists of import lines, parallel to
            ``list_of_java_codes``.

    Returns:
        None. If the two lists differ in length a message is printed and
        nothing is written.

    Raises:
        ValueError: if there is something to write and ``code_name`` is not a
            known benchmark (previously this attempted to open the directory
            itself as a file).
    """
    if len(list_of_java_codes) != len(list_of_imports):
        print("Mismatch between number of Java code strings and import lists!")
        return

    # Resolve every file name first so an unknown code_name fails before
    # anything is created or written.
    file_names = [_java_file_name(code_name, index) for index in range(len(list_of_java_codes))]

    if file_names and save_directory:
        os.makedirs(save_directory, exist_ok=True)

    for file_name, imports, code in zip(file_names, list_of_imports, list_of_java_codes):
        lines = code.split('\n')

        # Index of the package declaration line, or None when there is none.
        package_line_index = next(
            (i for i, line in enumerate(lines) if line.strip().startswith('package ')),
            None,
        )
        insert_at = 0 if package_line_index is None else package_line_index + 1
        lines[insert_at:insert_at] = list(imports)

        with open(os.path.join(save_directory, file_name), "w") as f:
            f.write('\n'.join(lines))

# Base Prompt Implementation

In [None]:
MODEL = 'gpt-3.5-turbo'

In [None]:
def pred_process(y_pred, y_true):
    """Turn per-file import lists into flat binary label vectors for sklearn metrics.

    For every (predicted, expected) pair the imports are compared as sets and
    each distinct import contributes exactly one slot:

    * predicted and expected      -> pred 1, true 1 (true positive)
    * expected but not predicted  -> pred 0, true 1 (false negative)
    * predicted but not expected  -> pred 1, true 0 (false positive)

    The earlier implementation only emitted ``max(len(pred), len(expected))``
    slots per file, so a wrong prediction that coincided with a missed import
    was recorded as a false negative only and its false positive was lost,
    which inflated precision.

    Args:
        y_pred: list of lists of predicted import statements.
        y_true: list of lists of expected import statements, parallel to
            ``y_pred``.

    Returns:
        ``(y_pred_processed, y_true_processed)``: two equally long lists of
        0/1 labels. Both lists are also printed, followed by a blank line.
    """
    y_pred_processed = []
    y_true_processed = []

    for pred, correct_imports in zip(y_pred, y_true):
        predicted = set(pred)
        expected = set(correct_imports)

        true_positives = len(predicted & expected)
        false_negatives = len(expected - predicted)
        false_positives = len(predicted - expected)

        y_pred_processed += [1] * true_positives + [0] * false_negatives + [1] * false_positives
        y_true_processed += [1] * true_positives + [1] * false_negatives + [0] * false_positives

    print(y_pred_processed)
    print(y_true_processed)
    print()
    return y_pred_processed, y_true_processed

In [None]:
def eval_performance(y_pred, y_true):
    """Print accuracy, F1, recall and precision as a JSON object indented by 2.

    Every scorer is called as ``scorer(y_true, y_pred)``. Nothing is returned.
    """
    scorers = (
        ("accuracy", accuracy_score),
        ("f1", f1_score),
        ("recall", recall_score),
        ("precision", precision_score),
    )
    report = {label: scorer(y_true, y_pred) for label, scorer in scorers}
    print(json.dumps(report, indent=2))

In [None]:
def get_prediction(code_snippet, max_retries=None, retry_delay=2):
    """Ask the chat model to make ``code_snippet`` compilable and return its reply text.

    Failed requests are retried after ``retry_delay`` seconds. Each failure is
    now reported on stdout instead of being swallowed silently, so a
    permanent problem (for example an invalid API key) is visible rather than
    looking like a hang.

    NOTE(review): this uses ``openai.ChatCompletion``, which exists only in
    pre-1.0 releases of the ``openai`` package -- confirm the installed
    version before upgrading the SDK.

    Args:
        code_snippet: Java source text to send to the model.
        max_retries: maximum number of retries after the first failed
            attempt. ``None`` (the default) retries indefinitely, which is
            the historical behaviour.
        retry_delay: seconds to wait between attempts.

    Returns:
        The ``content`` string of the first choice in the response.

    Raises:
        Exception: the last error from the API call, re-raised once
            ``max_retries`` is exhausted (never when ``max_retries`` is None).
    """
    failures = 0
    while True:
        try:
            response = openai.ChatCompletion.create(
                model=MODEL,
                messages=[
                    {"role": "system", "content": "Reply with only code, no elaboration."},
                    {"role": "user", "content": f"Make the code below compilable:\n\n{code_snippet}"},
                ],
                temperature=0.5,
            )

            return response["choices"][0]["message"]["content"]

        except Exception as e:
            failures += 1
            if max_retries is not None and failures > max_retries:
                raise
            print(f"OpenAI request failed ({type(e).__name__}: {e}); retrying in {retry_delay}s")
            time.sleep(retry_delay)

In [None]:
def extract_code(input_string):
    """Cut the Java source out of a model reply.

    The match begins at the first occurrence of ``package``, ``import``, an
    ``@annotation``, ``public``, ``private`` or ``protected`` and runs up to
    the last ``}`` that is followed only by whitespace before a line end.
    The keywords are matched without word boundaries, so they can also match
    inside longer words.

    Returns:
        The matched text, or the literal ``"No valid code block found!"`` when
        nothing matches.
    """
    code_regex = re.compile(
        r"(package|import|@[\w]+|public|private|protected).*\}\s*$",
        re.DOTALL | re.MULTILINE,
    )
    found = code_regex.search(input_string)
    return found.group(0) if found else "No valid code block found!"

In [None]:
def get_predictions(dataset):
    """Query the model for every example and pull the import lines from each reply.

    Args:
        dataset: iterable of ``(code_snippet, correct_imports)`` pairs; only
            the snippet is used here.

    Returns:
        ``(y_pred, predicted_codes)``: the import statements found in each
        extracted reply, and the extracted replies themselves, both in
        dataset order.
    """
    import_regex = r"import\s+[\w\., ]+;"
    y_pred, predicted_codes = [], []
    for code_snippet, _expected in tqdm(dataset):
        generated = extract_code(get_prediction(code_snippet))
        predicted_codes.append(generated)
        y_pred.append(re.findall(import_regex, generated))
    return y_pred, predicted_codes

In [None]:
save_directory = "/home/azmain/code_for_compilation_test/base-prompt/"

In [None]:
# Android: query the model, save the generated sources, then score the imports.

print("\nPrediction for Android Classes:\n")
y_pred, predicted_codes = get_predictions(android_test_examples)
print("\nPredicted Import List:", y_pred)
print("\nCorrect Import List:", android_y_true)

# Write each generated class to disk with its extracted imports inserted.
code_name, codes, predicted_imports = "android", predicted_codes, y_pred
import_list = group_imports(predicted_imports)
append_imports(save_directory, code_name, codes, import_list)

# Flatten the per-file import lists into binary labels and report the metrics.
y_pred_processed, y_true_processed = pred_process(y_pred, android_y_true)
eval_performance(y_pred_processed, y_true_processed)

In [None]:
# JDK: query the model, save the generated sources, then score the imports.

print("\nPrediction for JDK Classes:\n")
y_pred, predicted_codes = get_predictions(jdk_test_examples)
print("\nPredicted Import List:", y_pred)
print("\nCorrect Import List:", jdk_y_true)

# Write each generated class to disk with its extracted imports inserted.
code_name, codes, predicted_imports = "jdk", predicted_codes, y_pred
import_list = group_imports(predicted_imports)
append_imports(save_directory, code_name, codes, import_list)

# Flatten the per-file import lists into binary labels and report the metrics.
y_pred_processed, y_true_processed = pred_process(y_pred, jdk_y_true)
eval_performance(y_pred_processed, y_true_processed)

In [None]:
# Hibernate: query the model, save the generated sources, then score the imports.

print("\nPrediction for Hibernate Classes:\n")
y_pred, predicted_codes = get_predictions(hibernate_test_examples)
print("\nPredicted Import List:", y_pred)
print("\nCorrect Import List:", hibernate_y_true)

# Write each generated class to disk with its extracted imports inserted.
code_name, codes, predicted_imports = "hibernate", predicted_codes, y_pred
import_list = group_imports(predicted_imports)
append_imports(save_directory, code_name, codes, import_list)

# Flatten the per-file import lists into binary labels and report the metrics.
y_pred_processed, y_true_processed = pred_process(y_pred, hibernate_y_true)
eval_performance(y_pred_processed, y_true_processed)

In [None]:
# Joda-Time: query the model, save the generated sources, then score the imports.

print("\nPrediction for Joda-Time Classes:\n")
y_pred, predicted_codes = get_predictions(jodatime_test_examples)
print("\nPredicted Import List:", y_pred)
print("\nCorrect Import List:", jodatime_y_true)

# Write each generated class to disk with its extracted imports inserted.
code_name, codes, predicted_imports = "jodatime", predicted_codes, y_pred
import_list = group_imports(predicted_imports)
append_imports(save_directory, code_name, codes, import_list)

# Flatten the per-file import lists into binary labels and report the metrics.
y_pred_processed, y_true_processed = pred_process(y_pred, jodatime_y_true)
eval_performance(y_pred_processed, y_true_processed)

In [None]:
# GWT: query the model, save the generated sources, then score the imports.

print("\nPrediction for GWT Classes:\n")
y_pred, predicted_codes = get_predictions(gwt_test_examples)
print("\nPredicted Import List:", y_pred)
print("\nCorrect Import List:", gwt_y_true)

# Write each generated class to disk with its extracted imports inserted.
code_name, codes, predicted_imports = "gwt", predicted_codes, y_pred
import_list = group_imports(predicted_imports)
append_imports(save_directory, code_name, codes, import_list)

# Flatten the per-file import lists into binary labels and report the metrics.
y_pred_processed, y_true_processed = pred_process(y_pred, gwt_y_true)
eval_performance(y_pred_processed, y_true_processed)

In [None]:
# XStream: query the model, save the generated sources, then score the imports.

print("\nPrediction for XStream Classes\n")
y_pred, predicted_codes = get_predictions(xstream_test_examples)
print("\nPredicted Import List:", y_pred)
print("\nCorrect Import List:", xstream_y_true)

# Write each generated class to disk with its extracted imports inserted.
code_name, codes, predicted_imports = "xstream", predicted_codes, y_pred
import_list = group_imports(predicted_imports)
append_imports(save_directory, code_name, codes, import_list)

# Flatten the per-file import lists into binary labels and report the metrics.
y_pred_processed, y_true_processed = pred_process(y_pred, xstream_y_true)
eval_performance(y_pred_processed, y_true_processed)

# Compilation Rate Calculation

In [None]:
def _count_javac_errors(stderr):
    """Return N from a ``"N error"`` / ``"N errors"`` summary line in ``stderr``, or None if absent."""
    summary = re.search(r"^(\d+) errors?\s*$", stderr, re.MULTILINE)
    return int(summary.group(1)) if summary else None


def compile_java(file_path, class_path, log_folder_success, log_folder_fail, output_folder):
    """Compile one Java file with ``javac`` and write a JSON log of the outcome.

    A file counts as successfully compiled when either

    * ``javac`` exits with status 0, or
    * ``javac`` fails but every reported error is
      ``error: package R does not exist`` or
      ``error: package gen does not exist``.

    Any other non-zero exit is a failure. This includes runs where the
    output has no ``"N error(s)"`` summary line (such runs were previously
    reported as "Compiled with warning" and counted as successes) and runs
    with empty stderr (which previously raised ``IndexError``). The error
    total is read from the ``"N error(s)"`` line itself rather than from the
    last line of stderr, which is not always the error summary.

    Args:
        file_path: path of the ``.java`` file to compile.
        class_path: value passed to ``javac -cp``.
        log_folder_success: directory for logs of successful compilations.
        log_folder_fail: directory for logs of failed compilations.
        output_folder: directory passed to ``javac -d`` for ``.class`` files.

    Returns:
        True if the file is considered compiled, False otherwise. A run that
        exceeds the 10 second timeout returns False and writes no log.
    """
    os.makedirs(output_folder, exist_ok=True)

    try:
        result = subprocess.run(
            ['javac', '-cp', class_path, '-d', output_folder, file_path],
            capture_output=True, text=True, timeout=10,
        )
    except subprocess.TimeoutExpired:
        print(f"Compilation of {file_path} timed out.")
        return False

    if result.returncode == 0:
        print(f"Successfully compiled \"{file_path}\".")
        stderr_log = result.stderr
        successful_compile = True
    else:
        num_errors = _count_javac_errors(result.stderr)
        r_errors_count = result.stderr.count("error: package R does not exist")
        gen_r_errors_count = result.stderr.count("error: package gen does not exist")
        total_r_errors_count = r_errors_count + gen_r_errors_count

        # Tolerate the failure only when every counted error is one of the
        # two "package ... does not exist" errors above.
        if num_errors is not None and num_errors == total_r_errors_count:
            print(f"Successfully compiled, ignoring {num_errors} 'package R does not exist' errors for \"{file_path}\".")
            stderr_log = "No error (Ignored 'package R does not exist' errors)"
            successful_compile = True
        else:
            print(f"Failed to compile \"{file_path}\".")
            stderr_log = result.stderr
            successful_compile = False

    # Undecodable bytes in the source are dropped so logging never fails on them.
    with open(file_path, 'r', errors='ignore') as java_file:
        java_code = java_file.read()

    log_data = {
        "file": file_path,
        "java_code": java_code,
        "stdout": result.stdout,
        "stderr": stderr_log,
        "return_code": result.returncode
    }

    # The log is named after the Java file and filed by outcome.
    log_file_name = os.path.basename(file_path).replace('.java', '.json')
    log_folder = log_folder_success if successful_compile else log_folder_fail
    os.makedirs(log_folder, exist_ok=True)

    with open(os.path.join(log_folder, log_file_name), 'w') as log_file:
        json.dump(log_data, log_file, indent=4)

    return successful_compile

In [None]:
def calculate_success_rate(directory, class_path, log_folder_success, log_folder_fail, output_folder):
    """Compile every ``.java`` file in ``directory`` and report how many succeed.

    Args:
        directory: folder whose ``.java`` files (non-recursive) are compiled.
        class_path: value passed to ``javac -cp``.
        log_folder_success: directory for logs of successful compilations
            (created if missing).
        log_folder_fail: directory for logs of failed compilations (created
            if missing).
        output_folder: directory for generated ``.class`` files.

    Returns:
        ``(success_rate, successful_compilations)`` where ``success_rate`` is
        a percentage in [0, 100]. When the directory contains no ``.java``
        files the result is ``(0.0, 0)``; this was previously a bare ``0``,
        which callers that unpack two values could not handle.
    """
    os.makedirs(log_folder_success, exist_ok=True)
    os.makedirs(log_folder_fail, exist_ok=True)

    # Sorted so files are processed in a stable order from run to run.
    java_files = sorted(f for f in os.listdir(directory) if f.endswith('.java'))
    if not java_files:
        print("No .java files found.")
        return 0.0, 0

    successful_compilations = 0
    for java_file in java_files:
        file_path = os.path.join(directory, java_file)
        if compile_java(file_path, class_path, log_folder_success, log_folder_fail, output_folder):
            successful_compilations += 1

    success_rate = (successful_compilations / len(java_files)) * 100
    return success_rate, successful_compilations

In [None]:
# Paths for the compilation experiment; dir_type selects the prompt variant
# whose generated sources are compiled and whose logs/classes are written.
dir_type = "base-prompt"
dir_path = f"/home/azmain/code_for_compilation_test/{dir_type}/"
jar_path = f"/home/azmain/snr_jars/"
class_path = f".:{jar_path}/*"
log_folder_success = f"/home/azmain/GitHub Codes/Type_Inference_with_LLM/Java_Type_Inference/Results/logs/{dir_type}-logs/compile_success/"
log_folder_fail = f"/home/azmain/GitHub Codes/Type_Inference_with_LLM/Java_Type_Inference/Results/logs/{dir_type}-logs/compile_fail/"
compiled_folder_path = f"/home/azmain/GitHub Codes/Type_Inference_with_LLM/Java_Type_Inference/Results/compiled-classes/{dir_type}-compiled/"

# Compile every generated file and report the share that compiled.
rate, num_successful = calculate_success_rate(
    directory=dir_path,
    class_path=class_path,
    log_folder_success=log_folder_success,
    log_folder_fail=log_folder_fail,
    output_folder=compiled_folder_path,
)
print(f"Compilation success rate: {rate:.2f}%")
print(f"Number of successfully compiled files: {num_successful}")