In [None]:
# Put the OpenAI key into this kernel's environment; a later cell reads it back with os.getenv("OPENAI_API_KEY").
# NOTE(review): `api_key` looks like a placeholder for the real key - confirm, and keep real keys out of the saved notebook.
%env OPENAI_API_KEY = api_key

In [None]:
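# pip install --upgrade openai
# NOTE: this notebook calls openai.ChatCompletion, which openai>=1.0 no longer provides;
# install a pre-1.0 release (e.g. %pip install "openai<1") to run it unchanged.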

In [None]:
import os
import json
import re
import subprocess
import tiktoken

In [None]:
import openai

# Chat model used for every completion request made by this notebook.
MODEL = 'gpt-3.5-turbo'

# The key comes from the environment; os.getenv yields None when the variable is unset.
openai.api_key = os.getenv("OPENAI_API_KEY")

# Tokenizer for MODEL, used later to measure how long the accumulated prompt is.
encoding = tiktoken.encoding_for_model(MODEL)

In [None]:
def get_refined_error_log(stderr):
    """Trim directory prefixes from compiler output lines that mention a ``.java`` file.

    Every line containing ``.java`` is reduced to the text after its last ``/``;
    all other lines are passed through unchanged.

    Args:
        stderr: Raw compiler error output as one string.

    Returns:
        The rewritten output, with lines re-joined by newline characters.
    """
    shortened_lines = []
    for raw_line in stderr.split('\n'):
        if '.java' in raw_line:
            # Keep only what follows the final slash (file name plus message).
            shortened_lines.append(raw_line.rsplit('/', 1)[-1])
        else:
            shortened_lines.append(raw_line)

    return '\n'.join(shortened_lines)

In [None]:
def extract_code(input_string):
    """Pull the Java source out of a model reply.

    The match starts at the first ``package``/``import``/``public``/``private``/
    ``protected`` keyword or ``@annotation`` and runs through the last ``}``
    that is followed only by whitespace up to a line end.

    Keywords are matched as whole words, so prose such as "important" or
    "republic" no longer triggers a match in the middle of an explanation.

    Args:
        input_string: Full text returned by the model.

    Returns:
        The matched code (including any whitespace consumed after the final
        brace), or the literal string ``"No valid code block found!"`` when
        nothing matches; that message is also printed.
    """
    # \b cannot precede '@' (a non-word character), so annotations get their
    # own alternative outside the word-bounded keyword group.
    pattern = r"(?:\b(?:package|import|public|private|protected)\b|@\w+).*\}\s*$"
    match = re.search(pattern, input_string, re.DOTALL | re.MULTILINE)

    if match:
        return match.group(0)  # the entire matched span

    print("No valid code block found!")
    return "No valid code block found!"

In [None]:
def get_fixed_code(java_code, refined_stderr, attempt, chat_history):
    """Ask the chat model for a corrected version of the Java source.

    Args:
        java_code: Source text shown to the model on the first attempt.
        refined_stderr: Compiler error log shown on the first attempt.
        attempt: 1-based attempt number; values below 2 use the code/error
            prompt, anything else uses ``chat_history``.
        chat_history: Text placed in front of the fix request on later attempts.

    Returns:
        The result of ``extract_code`` applied to the first choice's message
        content.
    """
    if attempt >= 2:
        # Later attempts: only the accumulated history plus the request.
        prompt = f"{chat_history}Now based on the above context fix the error and reply with fixed full code so that it can be successfully compiled."
    else:
        # First attempt: the code and its error log, each fenced by triple quotes.
        prompt = (
            'See the code below:\n'
            f'"""\n{java_code}\n"""\n'
            'For the above code I got the below error log:\n'
            f'"""\n{refined_stderr}\n"""\n'
            'Now fix the error and reply with fixed full code so that it can be successfully compiled.'
        )

    conversation = [
        {"role": "system", "content": "Reply with only code, no elaboration."},
        {"role": "user", "content": prompt},
    ]
    response = openai.ChatCompletion.create(
        model=MODEL,
        messages=conversation,
        temperature=1,
    )

    reply_text = response["choices"][0]["message"]["content"]
    return extract_code(reply_text)

In [None]:
# Path-trimmed stderr of the most recent failed compilation. compile_java()
# writes it; compile_code_and_result() reads it when building the next prompt.
chat_history_error_message = ""

def compile_java(attempt, new_java_code_path, class_path, log_folder_success, log_folder_fail, compiled_folder_path, max_attempts=5):
    """Compile one Java file with ``javac`` and record the outcome as a JSON log.

    A run counts as successful when ``javac`` exits with status 0, or when it
    exits non-zero but every reported error is ``package R does not exist``.
    Any other non-zero exit is a failure; the path-trimmed stderr is then
    stored in the module-level ``chat_history_error_message``.

    A JSON log (attempt, file, java_code, stdout, stderr, return_code) is
    written to ``log_folder_success`` on success, or to ``log_folder_fail``
    when the failing attempt equals ``max_attempts``.

    Args:
        attempt: 1-based attempt number, used in messages and in the log.
        new_java_code_path: Path of the ``.java`` file to compile.
        class_path: Value passed to ``javac -cp``.
        log_folder_success: Existing directory for logs of successful runs.
        log_folder_fail: Existing directory for logs of the final failed run.
        compiled_folder_path: Output directory for ``.class`` files; created
            if missing.
        max_attempts: Attempt number at which a failure is logged (default 5).

    Returns:
        True if the compilation is considered successful, otherwise False.
        A ``javac`` run exceeding the 10 second timeout prints a message and
        returns False without writing a log.
    """
    global chat_history_error_message

    # javac's -d target has to exist before compiling.
    os.makedirs(compiled_folder_path, exist_ok=True)

    try:
        # Compile with the provided class path, sending .class files to compiled_folder_path.
        result = subprocess.run(
            ['javac', '-cp', class_path, '-d', compiled_folder_path, new_java_code_path],
            capture_output=True, text=True, timeout=10,
        )
    except subprocess.TimeoutExpired:
        print(f"Compilation of {new_java_code_path} timed out.")
        return False

    if result.returncode == 0:
        print(f"Attempt {attempt}. Successfully compiled \"{new_java_code_path}\".")
        stderr_log = result.stderr
        successful_compile = True
    else:
        # Find javac's "N error(s)" summary wherever it appears: it is not
        # always the final line (an "N warning(s)" line can follow it), and
        # stderr may be empty or carry no summary at all.
        summary = re.search(r"^(\d+) errors?[ \t]*$", result.stderr, re.MULTILINE)
        num_errors = int(summary.group(1)) if summary else None

        r_errors_count = result.stderr.count("error: package R does not exist")

        # The only tolerated non-zero exit: every counted error is the R-package one.
        if num_errors is not None and num_errors > 0 and num_errors == r_errors_count:
            print(f"Attempt {attempt}. Successfully compiled, ignoring {num_errors} 'package R does not exist' errors for \"{new_java_code_path}\".")
            stderr_log = "No error (Ignored 'package R does not exist' errors)"
            successful_compile = True
        else:
            # A non-zero exit without an all-R error list is a failure, even
            # when no error summary could be parsed.
            print(f"Attempt {attempt}. Failed to compile \"{new_java_code_path}\".")
            stderr_log = result.stderr
            # Trim lines mentioning a .java file to the text after the last '/'.
            chat_history_error_message = '\n'.join(
                line.split('/')[-1] if '.java' in line else line
                for line in result.stderr.split('\n')
            )
            successful_compile = False

    if successful_compile or attempt == max_attempts:
        # Read the content of the Java file so the log is self-contained.
        with open(new_java_code_path, 'r') as java_file:
            java_code = java_file.read()

        log_data = {
            "attempt": attempt,
            "file": new_java_code_path,
            "java_code": java_code,
            "stdout": result.stdout,
            "stderr": stderr_log,
            "return_code": result.returncode
        }

        # The log is named after the Java file being compiled.
        log_file_name = os.path.basename(new_java_code_path).replace('.java', '.json')
        target_folder = log_folder_success if successful_compile else log_folder_fail
        log_file_path = os.path.join(target_folder, log_file_name)

        with open(log_file_path, 'w') as log_file:
            json.dump(log_data, log_file, indent=4)

    return successful_compile

In [None]:
def compile_code_and_result(data, fixed_path, class_path, log_folder_success, log_folder_fail, compiled_folder_path):
    """Repeatedly ask the model to fix one Java file until it compiles.

    Up to five attempts are made. Each model reply is written to
    ``fixed_path`` under the original file's base name and compiled with
    ``compile_java``. From the second attempt on, the prompt is built from
    the previous fix plus the compiler error stored in the module-level
    ``chat_history_error_message``.

    Args:
        data: Mapping with the keys ``"file"``, ``"java_code"`` and ``"stderr"``.
        fixed_path: Directory receiving the model's fixed source; created if missing.
        class_path: Class path forwarded to ``compile_java``.
        log_folder_success: Directory for success logs; created if missing.
        log_folder_fail: Directory for failure logs; created if missing.
        compiled_folder_path: Output directory forwarded to ``compile_java``.

    Returns:
        1 if any attempt compiled successfully, otherwise 0.
    """
    global chat_history_error_message

    java_code_path = data["file"]
    java_code = data["java_code"]
    refined_stderr = get_refined_error_log(data["stderr"])

    os.makedirs(log_folder_success, exist_ok=True)
    os.makedirs(log_folder_fail, exist_ok=True)
    # The fixed sources are written below, so their directory must exist too;
    # an empty fixed_path means the current directory and needs no creation.
    if fixed_path:
        os.makedirs(fixed_path, exist_ok=True)

    # The destination file is the same for every attempt.
    new_java_code_path = os.path.join(fixed_path, os.path.basename(java_code_path))

    chat_history = ""
    chat_history_code = ""
    chat_history_error_message = ""

    # Maximum 5 tries to get a successful compilation
    for attempt in range(1, 6):
        if attempt == 1:
            fixed_code = get_fixed_code(java_code, refined_stderr, attempt, chat_history)
        else:
            # Keeping the last fixed code by the model inside java_code
            java_code = fixed_code
            chat_history = f"{chat_history}{chat_history_code}\n\nYou gave the above code fix in your attempt {attempt - 1}. But compiler gave this error:\n\n{chat_history_error_message}\n\n"
            if len(encoding.encode(chat_history)) > 4000:
                # The full history is over the token budget: send only the latest fix and its error.
                last_chat_history = f"{chat_history_code}\n\nYou gave the above code fix. But compiler gave this error:\n\n{chat_history_error_message}\n\n"
                fixed_code = get_fixed_code(java_code, refined_stderr, attempt, last_chat_history)
            else:
                fixed_code = get_fixed_code(java_code, refined_stderr, attempt, chat_history)
        chat_history_code = fixed_code

        # Saving the predicted fixed code by the model
        with open(new_java_code_path, 'w') as file:
            file.write(fixed_code)

        if compile_java(attempt, new_java_code_path, class_path, log_folder_success, log_folder_fail, compiled_folder_path):
            return 1

    return 0

In [None]:
# Zero-shot: re-run every previously failed compilation through the LLM fix loop
# and report how many end up compiling.

dir_type = "zero-shot"

existing_failed_folder_path = f"/home/azmain/GitHub Codes/Type_Inference_with_LLM/Java_Type_Inference/Results/logs/{dir_type}-logs/compile_fail"
fixed_path = f"/home/azmain/GitHub Codes/Type_Inference_with_LLM/Java_Type_Inference/Results/llm_fixes/fixed_codes/{dir_type}-fix/"

jar_path = "/home/azmain/snr_jars/"
class_path = f".:{jar_path}/*"
log_folder_success = f"/home/azmain/GitHub Codes/Type_Inference_with_LLM/Java_Type_Inference/Results/llm_fixes/llm-logs/{dir_type}-logs/compile_success/"
log_folder_fail = f"/home/azmain/GitHub Codes/Type_Inference_with_LLM/Java_Type_Inference/Results/llm_fixes/llm-logs/{dir_type}-logs/compile_fail/"
compiled_folder_path = f"/home/azmain/GitHub Codes/Type_Inference_with_LLM/Java_Type_Inference/Results/llm_fixes/llm-compiled-classes/{dir_type}-compiled/"

# List the failure logs once so the denominator matches the files actually processed.
failed_log_names = [name for name in os.listdir(existing_failed_folder_path) if name.endswith(".json")]

successful_compilations = 0
for filename in failed_log_names:
    # Close the log before the (slow) fix loop starts.
    with open(os.path.join(existing_failed_folder_path, filename), "r") as f:
        data = json.load(f)
    success_count = compile_code_and_result(data, fixed_path, class_path, log_folder_success, log_folder_fail, compiled_folder_path)
    successful_compilations += success_count
    print()

total_files = len(failed_log_names)
# An empty folder would otherwise raise ZeroDivisionError.
success_rate = (successful_compilations / total_files) * 100 if total_files else 0.0

print(f"Compilation success rate: {success_rate:.2f}%")
print(f"Number of successfully compiled files: {successful_compilations} out of {total_files}.")

In [None]:
# One-shot: re-run every previously failed compilation through the LLM fix loop
# and report how many end up compiling.

dir_type = "one-shot"

existing_failed_folder_path = f"/home/azmain/GitHub Codes/Type_Inference_with_LLM/Java_Type_Inference/Results/logs/{dir_type}-logs/compile_fail"
fixed_path = f"/home/azmain/GitHub Codes/Type_Inference_with_LLM/Java_Type_Inference/Results/llm_fixes/fixed_codes/{dir_type}-fix/"

jar_path = "/home/azmain/snr_jars/"
class_path = f".:{jar_path}/*"
log_folder_success = f"/home/azmain/GitHub Codes/Type_Inference_with_LLM/Java_Type_Inference/Results/llm_fixes/llm-logs/{dir_type}-logs/compile_success/"
log_folder_fail = f"/home/azmain/GitHub Codes/Type_Inference_with_LLM/Java_Type_Inference/Results/llm_fixes/llm-logs/{dir_type}-logs/compile_fail/"
compiled_folder_path = f"/home/azmain/GitHub Codes/Type_Inference_with_LLM/Java_Type_Inference/Results/llm_fixes/llm-compiled-classes/{dir_type}-compiled/"

# List the failure logs once so the denominator matches the files actually processed.
failed_log_names = [name for name in os.listdir(existing_failed_folder_path) if name.endswith(".json")]

successful_compilations = 0
for filename in failed_log_names:
    # Close the log before the (slow) fix loop starts.
    with open(os.path.join(existing_failed_folder_path, filename), "r") as f:
        data = json.load(f)
    success_count = compile_code_and_result(data, fixed_path, class_path, log_folder_success, log_folder_fail, compiled_folder_path)
    successful_compilations += success_count
    print()

total_files = len(failed_log_names)
# An empty folder would otherwise raise ZeroDivisionError.
success_rate = (successful_compilations / total_files) * 100 if total_files else 0.0

print(f"Compilation success rate: {success_rate:.2f}%")
print(f"Number of successfully compiled files: {successful_compilations} out of {total_files}.")